Skip to content

Commit

Permalink
provide transpose128Rev function
Browse files Browse the repository at this point in the history
  • Loading branch information
emmansun committed May 24, 2023
1 parent 158bee9 commit 2bc005e
Show file tree
Hide file tree
Showing 6 changed files with 938 additions and 30 deletions.
229 changes: 219 additions & 10 deletions _asm/transpose_amd64_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,13 @@ func transpose128() {
tmp := XMM()
b := GP8()
o := GP32()
cc := GP64()

Comment("Initialize rr, current row")
rr := zero()
Label("row_loop")
Comment("Initialize cc, current col")
cc := zero()
XORQ(cc, cc)
Label("col_loop")

Comment("Initialize (rr * ncols + cc) / 8, here ncols=128")
Expand Down Expand Up @@ -203,6 +204,214 @@ func transpose128() {
RET()
}

// transpose128Rev emits the avo program for
//
//	func transpose128Rev(in, out *byte)
//
// a bit-level transpose of a 128x128 bit matrix (16 bytes per row) that also
// reverses the order of the four 32-row bands of the input while transposing:
// input rows 96..127 are read first and written at bit offset 0..31 of every
// output row, then input rows 64..95 at offset 32..63, input rows 32..63 at
// offset 64..95, and finally input rows 0..31 at offset 96..127.
//
// The four bands share one emission sequence; they differ only in the label
// suffix, the displacement applied to rr to obtain the source row, and the
// value of rr at which the band's row loop ends. They are therefore driven
// from a table instead of being written out four times. The emitted
// instructions, labels and assembly comments are the same, in the same order,
// as the hand-unrolled version, so regenerating the assembly yields identical
// output.
func transpose128Rev() {
	// transpose128Rev function
	TEXT("transpose128Rev", NOSPLIT, "func(in, out *byte)")
	Doc("Bit level matrix transpose, b0-b1-b2-b3, 128x128")

	in := Mem{Base: Load(Param("in"), GP64())}
	out := Mem{Base: Load(Param("out"), GP64())}

	tmp := XMM() // 16 gathered bytes, one per source row
	b := GP8()   // scratch byte moved from memory into tmp
	o := GP32()  // PMOVMSKB result; only the low 16 bits are stored

	Comment("Initialize rr, current row, 96")
	rr := zero()
	cc := GP64()
	addr := GP64()

	// One entry per 32-row band, in emission order. rr is never reset between
	// bands: each band starts at the value where the previous band's row loop
	// stopped, which is why rowEnd is cumulative.
	bands := []struct {
		suffix   string // appended to "row_loop_" / "col_loop_" to keep labels unique
		subtract bool   // true: source row = rr - delta; false: source row = rr + delta
		delta    uint64 // displacement between rr and the source row
		rowEnd   uint8  // the band's row loop runs while rr < rowEnd
	}{
		{suffix: "b3", subtract: false, delta: 96, rowEnd: 32},
		{suffix: "b2", subtract: false, delta: 32, rowEnd: 64},
		{suffix: "b1", subtract: true, delta: 32, rowEnd: 96},
		{suffix: "b0", subtract: true, delta: 96, rowEnd: 128},
	}

	for _, band := range bands {
		rowLabel := "row_loop_" + band.suffix
		colLabel := "col_loop_" + band.suffix

		Label(rowLabel)
		Comment("Initialize cc, current col")
		XORQ(cc, cc)
		Label(colLabel)

		// Byte offset of the first source byte: ((rr +/- delta) * 128 + cc) / 8.
		Comment("Initialize (rr * ncols + cc) / 8, here ncols=128")
		MOVQ(rr, addr)
		if band.subtract {
			SUBQ(Imm(band.delta), addr)
		} else {
			ADDQ(Imm(band.delta), addr)
		}
		Comment("Multiple with ncols")
		SHLQ(Imm(7), addr)
		ADDQ(cc, addr)
		SHRQ(Imm(3), addr)

		// Gather the same byte column from 16 consecutive source rows; each
		// row is 16 bytes wide, hence the stride of 16.
		Comment("Construct one XMM with first byte of first 16 rows")
		for lane := 0; lane < 16; lane++ {
			MOVB(in.Idx(addr, 1), b)
			PINSRB(Imm(uint64(lane)), b.As32(), tmp)
			Comment("Add ncols / 8")
			ADDQ(Imm(16), addr)
		}

		// Byte offset of the first destination word: ((cc + 7) * 128 + rr) / 8.
		// The most significant bit of every gathered byte is emitted first, so
		// the stores start at output row cc+7 and walk back towards row cc.
		Comment("Initialize ((cc + 7) * nrows + rr) / 8, here nrows = 128")
		MOVQ(cc, addr)
		ADDQ(Imm(7), addr)
		Comment("Multiple with nrows")
		SHLQ(Imm(7), addr)
		ADDQ(rr, addr)
		SHRQ(Imm(3), addr)

		Comment("Get the most significant bit of each 8-bit element in the XMM, and store the returned 2 bytes")
		for bit := 0; bit < 8; bit++ {
			PMOVMSKB(tmp, o)
			MOVW(o.As16(), out.Idx(addr, 1))
			// Shift the next lower bit of every byte into the top position.
			PSLLQ(Imm(1), tmp)
			Comment("Sub nrows / 8")
			SUBQ(Imm(16), addr)
		}

		// Advance by the 8 bit-columns just handled.
		Comment("Compare cc with ncols, here ncols=128")
		ADDQ(Imm(8), cc)
		CMPQ(cc, Imm(128))
		JL(LabelRef(colLabel))

		// Advance by the 16 rows just handled. The emitted comment text is
		// kept verbatim; the actual bound is this band's rowEnd.
		Comment("Compare rr with nrows, here nrows=128")
		ADDQ(Imm(16), rr)
		CMPQ(rr, U8(band.rowEnd))
		JL(LabelRef(rowLabel))
	}

	RET()
}

func xor32x128() {
// xor32x128 function
TEXT("xor32x128", NOSPLIT, "func(x, y, out *byte)")
Expand Down Expand Up @@ -728,22 +937,21 @@ func sbox128() {
PANDN(f, t8) // e9

Comment("e10=^(g1 & l1)")
MOVOU(buffer.Offset(1*16), t9)
PAND(t7, t9)
PANDN(f, t9) // e10
MOVOU(buffer.Offset(1*16), t1)
PAND(t7, t1)
PANDN(f, t1) // e10

Comment("r6=e9 ^ e10")
PXOR(t9, t8) // r6 = e9 ^ e10
PXOR(t1, t8) // r6 = e9 ^ e10

Comment("e11=^(g0 & l0)")
MOVOU(buffer, t10)
PAND(t11, t10)
PANDN(f, t10) // e11
Comment("r7=e11 ^ e10")
PXOR(t10, t9) // r7 = e11 ^ e10
Comment("store r6 r7")
PXOR(t10, t1) // r7 = e11 ^ e10
Comment("store r6")
MOVOU(t8, buffer.Offset(28*16))
MOVOU(t9, buffer.Offset(29*16))

Comment("e12=^(m6 & k3)")
MOVOU(buffer.Offset((8+6)*16), t7) // m6
Expand Down Expand Up @@ -787,13 +995,13 @@ func sbox128() {
PXOR(t10, t9) // r11 = e17 ^ e16 = t9

Comment("start output function")
// t1 = r7
// t7 = r8
// t11 = r9
// t8 = r10
// t9 = r11
Comment("[t1]=r7 ^ r9")
MOVOU(buffer.Offset((22+7)*16), t1) // r7
PXOR(t1, t11) // t11 = r7 ^ r9
PXOR(t1, t11) // t11 = r7 ^ r9
Comment("t2=t1 ^ r1")
MOVOU(buffer.Offset((22+1)*16), t2) // r1
PXOR(t11, t2) // t2 = r1 ^ t11
Expand Down Expand Up @@ -1166,6 +1374,7 @@ func main() {
transpose64()
transpose64Rev()
transpose128()
transpose128Rev()
xor32x128()
xor32x128avx()
xorRoundKey128()
Expand Down
8 changes: 1 addition & 7 deletions bs128.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,5 @@ func (bs bs128) EncryptBlocks(xk []uint32, dst, src []byte) {
b2 = bs.xor32(b2, bs.l(bs.tao(bs.xorRK(xk[i*4+2], rk, b3, b0, b1), buffer), buffer))
b3 = bs.xor32(b3, bs.l(bs.tao(bs.xorRK(xk[i*4+3], rk, b0, b1, b2), buffer), buffer))
}
copy(rk, b0)
copy(state[:], b3)
copy(state[96*bitSize:], rk)
copy(rk, b1)
copy(state[32*bitSize:], b2)
copy(state[64*bitSize:], rk)
transpose128(&state[0], &dst[0])
transpose128Rev(&state[0], &dst[0])
}
19 changes: 19 additions & 0 deletions bs128_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,25 @@ func BenchmarkL128(b *testing.B) {
}
}

// BenchmarkXorRK measures one BS128.xorRK call per iteration on zero-filled
// buffers of 32*BS128.bytes() bytes each.
func BenchmarkXorRK(b *testing.B) {
	// NOTE(review): assumes BS128.bytes() returns the same value on every
	// call, so it is evaluated once and reused for all four buffers.
	size := 32 * BS128.bytes()
	b0 := make([]byte, size)
	b1 := make([]byte, size)
	b2 := make([]byte, size)
	rk := make([]byte, size)
	k := uint32(0xa3b1bac6)

	// Exclude the buffer allocations above from the measured time.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		BS128.xorRK(k, rk, b0, b1, b2)
	}
}

// BenchmarkXor32 measures one BS128.xor32 call per iteration on two
// zero-filled buffers of 32*BS128.bytes() bytes each.
func BenchmarkXor32(b *testing.B) {
	// NOTE(review): assumes BS128.bytes() returns the same value on every
	// call, so it is evaluated once and reused for both buffers.
	size := 32 * BS128.bytes()
	b0 := make([]byte, size)
	b1 := make([]byte, size)

	// Exclude the buffer allocations above from the measured time.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		BS128.xor32(b0, b1)
	}
}

func TestBS128EncryptBlocks(t *testing.T) {
bitSize := BS128.bytes()
key := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10}
Expand Down
3 changes: 3 additions & 0 deletions transpose128_amd64.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 2bc005e

Please sign in to comment.