s2: Clean up matchlen assembly (#825)
* s2: Clean up matchlen assembly
klauspost authored Jun 13, 2023
1 parent a3a5dce commit d9eae82
Showing 2 changed files with 287 additions and 337 deletions; only the diff of s2/_generate/gen.go is shown below.
169 changes: 77 additions & 92 deletions s2/_generate/gen.go
@@ -51,6 +51,7 @@ func main() {
bmi1: false,
bmi2: false,
snappy: false,
avx2: false,
outputMargin: 9,
}
o.genEncodeBlockAsm("encodeBlockAsm", 14, 6, 6, limit14B)
@@ -150,6 +151,7 @@ type options struct {
bmi1 bool
bmi2 bool
skipOutput bool
avx2 bool
maxLen int
maxOffset int
outputMargin int // Should be at least 5.
@@ -614,6 +616,9 @@ func (o options) genEncodeBlockAsm(name string, tableBits, skipLog, hashBytes, m
panic(err)
}
MOVQ(U32(0), ri.Addr)
if o.avx2 {
VZEROUPPER()
}
RET()
}
Label("match_dst_size_check_" + name)
@@ -697,6 +702,9 @@ func (o options) genEncodeBlockAsm(name string, tableBits, skipLog, hashBytes, m
panic(err)
}
MOVQ(U32(0), ri.Addr)
if o.avx2 {
VZEROUPPER()
}
RET()
Label("match_nolit_dst_ok_" + name)
}
@@ -753,6 +761,9 @@ func (o options) genEncodeBlockAsm(name string, tableBits, skipLog, hashBytes, m
if err != nil {
panic(err)
}
if o.avx2 {
VZEROUPPER()
}
MOVQ(U32(0), ri.Addr)
RET()
Label("emit_remainder_ok_" + name)
@@ -801,6 +812,9 @@ func (o options) genEncodeBlockAsm(name string, tableBits, skipLog, hashBytes, m
JAE(ok)
})
}
if o.avx2 {
VZEROUPPER()
}
Store(length, ReturnIndex(0))
RET()
}
@@ -1273,6 +1287,9 @@ func (o options) genEncodeBetterBlockAsm(name string, lTableBits, sTableBits, sk
panic(err)
}
MOVQ(U32(0), ri.Addr)
if o.avx2 {
VZEROUPPER()
}
RET()
}
Label("match_dst_size_check_" + name)
@@ -1385,6 +1402,9 @@ func (o options) genEncodeBetterBlockAsm(name string, lTableBits, sTableBits, sk
panic(err)
}
MOVQ(U32(0), ri.Addr)
if o.avx2 {
VZEROUPPER()
}
RET()
}
}
@@ -1538,6 +1558,9 @@ func (o options) genEncodeBetterBlockAsm(name string, lTableBits, sTableBits, sk
panic(err)
}
MOVQ(U32(0), ri.Addr)
if o.avx2 {
VZEROUPPER()
}
RET()
Label("emit_remainder_ok_" + name)
}
@@ -1579,6 +1602,9 @@ func (o options) genEncodeBetterBlockAsm(name string, lTableBits, sTableBits, sk
JAE(ok)
})
Store(length, ReturnIndex(0))
if o.avx2 {
VZEROUPPER()
}
RET()
}

@@ -2696,6 +2722,9 @@ func (o options) genMatchLen() {
Load(Param("a").Len(), length)
l := o.matchLen("standalone", aBase, bBase, length, LabelRef("gen_match_len_end"))
Label("gen_match_len_end")
if o.avx2 {
VZEROUPPER()
}
Store(l.As64(), ReturnIndex(0))
RET()
}
@@ -2706,11 +2735,13 @@ func (o options) genMatchLen() {
// Uses 2 GP registers.
func (o options) matchLen(name string, a, b, len reg.GPVirtual, end LabelRef) reg.GPVirtual {
Comment("matchLen")
if false {
return o.matchLenAlt(name, a, b, len, end)
}
tmp, matched := GP64(), GP32()
XORL(matched, matched)
if o.avx2 {
// Not faster...
o.matchLenAVX2(name+"Avx2", a, b, len, LabelRef("avx2_continue_"+name), end, matched)
}
Label("avx2_continue_" + name)

CMPL(len.As32(), U8(8))
JB(LabelRef("matchlen_match4_" + name))
@@ -2740,7 +2771,6 @@ func (o options) matchLen(name string, a, b, len reg.GPVirtual, end LabelRef) re
LEAL(Mem{Base: matched, Disp: 8}, matched)
CMPL(len.As32(), U8(8))
JAE(LabelRef("matchlen_loopback_" + name))
JZ(end)

// Less than 8 bytes left.
// Test 4 bytes...
@@ -2750,23 +2780,25 @@ func (o options) matchLen(name string, a, b, len reg.GPVirtual, end LabelRef) re
MOVL(Mem{Base: a, Index: matched, Scale: 1}, tmp.As32())
CMPL(Mem{Base: b, Index: matched, Scale: 1}, tmp.As32())
JNE(LabelRef("matchlen_match2_" + name))
SUBL(U8(4), len.As32())
LEAL(Mem{Base: len.As32(), Disp: -4}, len.As32())
LEAL(Mem{Base: matched, Disp: 4}, matched)

// Test 2 bytes...
Label("matchlen_match2_" + name)
CMPL(len.As32(), U8(2))
JB(LabelRef("matchlen_match1_" + name))
CMPL(len.As32(), U8(1))
// If we don't have 1, branch appropriately
JE(LabelRef("matchlen_match1_" + name))
JB(end)
// 2 or 3
MOVW(Mem{Base: a, Index: matched, Scale: 1}, tmp.As16())
CMPW(Mem{Base: b, Index: matched, Scale: 1}, tmp.As16())
JNE(LabelRef("matchlen_match1_" + name))
SUBL(U8(2), len.As32())
LEAL(Mem{Base: matched, Disp: 2}, matched)
SUBL(U8(2), len.As32())
JZ(end)

// Test 1 byte...
Label("matchlen_match1_" + name)
CMPL(len.As32(), U8(1))
JB(end)
MOVB(Mem{Base: a, Index: matched, Scale: 1}, tmp.As8())
CMPB(Mem{Base: b, Index: matched, Scale: 1}, tmp.As8())
JNE(end)
@@ -2780,94 +2812,47 @@ func (o options) matchLen(name string, a, b, len reg.GPVirtual, end LabelRef) re
// Will jump to end when done and returns the length.
// Uses 3 GP registers.
// It is better on longer matches.
func (o options) matchLenAlt(name string, a, b, len reg.GPVirtual, end LabelRef) reg.GPVirtual {
Comment("matchLenAlt")
tmp, tmp2, matched := GP64(), GP64(), GP32()
XORL(matched, matched)

CMPL(len.As32(), U8(16))
JB(LabelRef("matchlen_short_" + name))
func (o options) matchLenAVX2(name string, a, b, len reg.GPVirtual, cont, end LabelRef, dst reg.GPVirtual) {
Comment("matchLenAVX2")

Label("matchlen_loopback_" + name)
MOVQ(Mem{Base: a}, tmp)
MOVQ(Mem{Base: a, Disp: 8}, tmp2)
XORQ(Mem{Base: b, Disp: 0}, tmp)
XORQ(Mem{Base: b, Disp: 8}, tmp2)
endTest := func(xored reg.GPVirtual, disp int, ok LabelRef) {
TESTQ(xored, xored)
JZ(ok)
// Not all match.
BSFQ(xored, xored)
SARQ(U8(3), xored)
LEAL(Mem{Base: matched, Index: xored, Scale: 1, Disp: disp}, matched)
JMP(end)
}
endTest(tmp, 0, LabelRef("matchlen_loop_tmp2_"+name))
Label("matchlen_loop_tmp2_" + name)
endTest(tmp2, 8, LabelRef("matchlen_loop_"+name))

// All 16 byte matched, update and loop.
Label("matchlen_loop_" + name)
SUBL(U8(16), len.As32())
ADDL(U8(16), matched)
ADDQ(U8(16), a)
ADDQ(U8(16), b)
CMPL(len.As32(), U8(16))
JAE(LabelRef("matchlen_loopback_" + name))

// Test 4 bytes at the time...
Label("matchlen_short_" + name)
lenoff := 0
if true {
lenoff = 4
SUBL(U8(4), len.As32())
JC(LabelRef("matchlen_single_resume_" + name))

Label("matchlen_four_loopback_" + name)
assert(func(ok LabelRef) {
CMPL(len.As32(), U32(math.MaxInt32))
JB(ok)
})

MOVL(Mem{Base: a}, tmp.As32())
XORL(Mem{Base: b}, tmp.As32())
{
JZ(LabelRef("matchlen_four_loopback_next" + name))
BSFL(tmp.As32(), tmp.As32())
SARQ(U8(3), tmp)
LEAL(Mem{Base: matched, Index: tmp, Scale: 1}, matched)
JMP(end)
}
Label("matchlen_four_loopback_next" + name)
ADDL(U8(4), matched)
ADDQ(U8(4), a)
ADDQ(U8(4), b)
SUBL(U8(4), len.As32())
JNC(LabelRef("matchlen_four_loopback_" + name))
equalMaskBits := GP64()
Label(name + "loop")
{
CMPQ(len, U8(32))
JB(cont)
Comment("load 32 bytes into YMM registers")
adata := YMM()
bdata := YMM()
equalMaskBytes := YMM()
VMOVDQU(Mem{Base: a}, adata)
VMOVDQU(Mem{Base: b}, bdata)
Comment("compare bytes in adata and bdata, like 'bytewise XNOR'",
"if the byte is the same in adata and bdata, VPCMPEQB will store 0xFF in the same position in equalMaskBytes")
VPCMPEQB(adata, bdata, equalMaskBytes)
Comment("like convert byte to bit, store equalMaskBytes into general reg")
VPMOVMSKB(equalMaskBytes, equalMaskBits.As32())
CMPL(equalMaskBits.As32(), U32(0xffffffff))
JNE(LabelRef(name + "cal_prefix"))
ADDQ(U8(32), a)
ADDQ(U8(32), b)
ADDL(U8(32), dst)
SUBQ(U8(32), len)
JZ(end)
JMP(LabelRef(name + "loop"))
}

// Test one at the time
Label("matchlen_single_resume_" + name)
if true {
// Less than 16 bytes left.
if lenoff > 0 {
ADDL(U8(lenoff), len.As32())
Label(name + "cal_prefix")
{
NOTQ(equalMaskBits)
if o.bmi1 {
TZCNTQ(equalMaskBits, equalMaskBits)
} else {
BSFQ(equalMaskBits, equalMaskBits)
}
TESTL(len.As32(), len.As32())
JZ(end)

Label("matchlen_single_loopback_" + name)
MOVB(Mem{Base: a}, tmp.As8())
CMPB(Mem{Base: b}, tmp.As8())
JNE(end)
INCL(matched)
INCQ(a)
INCQ(b)
DECL(len.As32())
JNZ(LabelRef("matchlen_single_loopback_" + name))
ADDL(equalMaskBits.As32(), dst)
}
JMP(end)
return matched
return
}

func (o options) cvtLZ4BlockAsm(lz4s bool) {
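
Illustrative sketch (not part of the commit): the matchLen generator above emits an 8/4/2/1-byte compare cascade. The Go function below mirrors that control flow so it can be read without tracing the avo calls; matchLenScalar, its assumption that len(b) >= len(a), and the sample inputs are all hypothetical.

package main

import (
	"encoding/binary"
	"fmt"
	"math/bits"
)

// matchLenScalar returns the length of the common prefix of a and b,
// mirroring the generated matchlen_loopback / match4 / match2 / match1 flow.
// It assumes len(b) >= len(a).
func matchLenScalar(a, b []byte) int {
	matched := 0
	// 8 bytes at a time: XOR the words; the first set bit locates the mismatch.
	for len(a) >= 8 {
		if diff := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b); diff != 0 {
			return matched + bits.TrailingZeros64(diff)/8
		}
		matched += 8
		a, b = a[8:], b[8:]
	}
	// Fewer than 8 bytes left: test 4, then 2, then 1 byte at the current offset.
	if len(a) >= 4 && binary.LittleEndian.Uint32(a) == binary.LittleEndian.Uint32(b) {
		matched += 4
		a, b = a[4:], b[4:]
	}
	if len(a) >= 2 && binary.LittleEndian.Uint16(a) == binary.LittleEndian.Uint16(b) {
		matched += 2
		a, b = a[2:], b[2:]
	}
	if len(a) >= 1 && a[0] == b[0] {
		matched++
	}
	return matched
}

func main() {
	fmt.Println(matchLenScalar([]byte("hello world!"), []byte("hello there!"))) // prints 6
}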
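
Illustrative sketch (not part of the commit): matchLenAVX2 compares 32 bytes per iteration. VPCMPEQB sets 0xFF for every equal byte, VPMOVMSKB collapses those bytes into a 32-bit mask, and on a partial match the trailing-zero count of the inverted mask (TZCNT, or BSF without BMI1) gives the number of matching bytes in the block; the VZEROUPPER instructions added before RET clear the upper YMM halves to avoid the AVX-to-SSE transition penalty. Go has no AVX2 intrinsics, so the sketch below builds the same 32-bit mask one byte at a time purely to show the mask arithmetic; matchLen32 is a hypothetical name, and the real code's tail handling (falling through to the scalar cascade) is omitted.

package main

import (
	"fmt"
	"math/bits"
)

// matchLen32 mirrors the 32-byte loop of matchLenAVX2: a full mask
// (0xffffffff) means the whole block matched; otherwise the first zero
// bit of the mask is the number of matching bytes in this block.
func matchLen32(a, b []byte) int {
	matched := 0
	for len(a) >= 32 {
		var equalMaskBits uint32
		for i := 0; i < 32; i++ { // stands in for VPCMPEQB + VPMOVMSKB
			if a[i] == b[i] {
				equalMaskBits |= 1 << i
			}
		}
		if equalMaskBits != 0xffffffff {
			// NOT + TZCNT/BSF: index of the first mismatching byte.
			return matched + bits.TrailingZeros32(^equalMaskBits)
		}
		matched += 32
		a, b = a[32:], b[32:]
	}
	// The generated code continues with the scalar 8/4/2/1 cascade here.
	return matched
}

func main() {
	a := make([]byte, 64)
	b := make([]byte, 64)
	b[20] = 1 // first difference at index 20
	fmt.Println(matchLen32(a, b)) // prints 20
}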
