From 9bd1e430e4e659adf893b6e2a03866ddd6e3db74 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 13 Dec 2024 20:48:08 +0800 Subject: [PATCH 01/30] this works --- core/trie/bitarray.go | 134 +++++++++++++++++++++++++++++++++++++ core/trie/bitarray_test.go | 98 +++++++++++++++++++++++++++ 2 files changed, 232 insertions(+) create mode 100644 core/trie/bitarray.go create mode 100644 core/trie/bitarray_test.go diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go new file mode 100644 index 0000000000..286e38d7ae --- /dev/null +++ b/core/trie/bitarray.go @@ -0,0 +1,134 @@ +package trie + +import ( + "encoding/binary" + + "github.com/NethermindEth/juno/core/felt" +) + +const ( + mask64 = uint64(1 << 63) +) + +type bitArray struct { + len uint8 + words [4]uint64 // Little endian (i.e. words[0] is the least significant) +} + +func (b *bitArray) Len() uint8 { + return b.len +} + +func (b *bitArray) Bytes() [32]byte { + var res [32]byte + + switch { + case b.len == 0: + return res + case b.len == 255: + binary.BigEndian.PutUint64(res[0:8], b.words[3]) + binary.BigEndian.PutUint64(res[8:16], b.words[2]) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + case b.len >= 192: + rem := 256 - uint(b.len) + mask := uint64(1<<(64-rem)) - 1 + binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask) + binary.BigEndian.PutUint64(res[8:16], b.words[2]) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + case b.len >= 128: + rem := 192 - b.len + mask := uint64(1<<(64-rem)) - 1 + binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + case b.len >= 64: + rem := 128 - b.len + mask := uint64(1<<(64-rem)) - 1 + binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + default: + rem := 64 - b.len + mask := uint64(1<<(64-rem)) - 1 + binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) + } + + return res +} + +func (b *bitArray) SetFelt(f *felt.Felt) *bitArray { + res := f.Bytes() + b.words[3] = binary.BigEndian.Uint64(res[0:8]) + b.words[2] = binary.BigEndian.Uint64(res[8:16]) + b.words[1] = binary.BigEndian.Uint64(res[16:24]) + b.words[0] = binary.BigEndian.Uint64(res[24:32]) + b.len = felt.Bits - 1 + return b +} + +func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { + if n >= b.len { + return b.clear() + } + + switch { + case n == 0: + return b.set(x) + case n >= 192: + b.rsh192(x) + n -= 192 + b.words[0] >>= n + b.len -= n + case n >= 128: + b.rsh128(x) + n -= 128 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] >>= n + b.len -= n + case n >= 64: + b.rsh64(x) + n -= 64 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) + b.words[3] >>= n + b.len -= n + default: + b.set(x) + b.words[3] = (b.words[3] >> n) | (b.words[2] << (64 - n)) + b.words[2] = (b.words[2] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[0] << (64 - n)) + b.words[0] >>= n + b.len -= n + } + + return b +} + +func (b *bitArray) set(x *bitArray) *bitArray { + b.len = x.len + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] + return b +} + +func (b *bitArray) rsh64(x *bitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, x.words[3], x.words[2], x.words[1] +} + +func (b *bitArray) rsh128(x *bitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, x.words[3], x.words[2] +} + +func (b *bitArray) rsh192(x *bitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] +} + +func (b *bitArray) clear() *bitArray { + b.len = 0 + b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 + return b +} diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go new file mode 100644 index 0000000000..6a5e974df6 --- /dev/null +++ b/core/trie/bitarray_test.go @@ -0,0 +1,98 @@ +package trie + +import ( + "bytes" + "encoding/binary" + "testing" +) + +var maxBitArray = [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF} + +func TestBytes(t *testing.T) { + tests := []struct { + name string + bitArray bitArray + want [32]byte + }{ + // { + // name: "length == 0", + // bitArray: bitArray{len: 0, words: maxBitArray}, + // want: [32]byte{}, + // }, + // { + // name: "length < 64", + // bitArray: bitArray{len: 38, words: maxBitArray}, + // want: func() [32]byte { + // var b [32]byte + // binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) + // return b + // }(), + // }, + // { + // name: "64 <= length < 128", + // bitArray: bitArray{len: 100, words: maxBitArray}, + // want: func() [32]byte { + // var b [32]byte + // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) + // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + // return b + // }(), + // }, + // { + // name: "128 <= length < 192", + // bitArray: bitArray{len: 130, words: maxBitArray}, + // want: func() [32]byte { + // var b [32]byte + // binary.BigEndian.PutUint64(b[8:16], 0x3) + // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + // return b + // }(), + // }, + { + name: "192 <= length < 255", + bitArray: bitArray{len: 201, words: maxBitArray}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x1FF) + binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + return b + }(), + }, + // { + // name: "length == 254", + // bitArray: bitArray{len: 254, words: maxBitArray}, + // want: func() [32]byte { + // var b [32]byte + // binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + // return b + // }(), + // }, + // { + // name: "length == 255", + // bitArray: bitArray{len: 255, words: maxBitArray}, + // want: func() [32]byte { + // var b [32]byte + // binary.BigEndian.PutUint64(b[0:8], 0xFFFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + // return b + // }(), + // }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.bitArray.Bytes() + if !bytes.Equal(got[:], tt.want[:]) { + t.Errorf("bitArray.Bytes() = %v, want %v", got, tt.want) + } + }) + } +} From a68d346cc03a3bb177a248ec0444234d26305e8e Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 13 Dec 2024 22:40:09 +0800 Subject: [PATCH 02/30] one failed but im getting closer --- core/trie/bitarray.go | 55 +++++++++-------- core/trie/bitarray_test.go | 122 ++++++++++++++++++------------------- 2 files changed, 89 insertions(+), 88 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 286e38d7ae..6887140d41 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -11,46 +11,42 @@ const ( ) type bitArray struct { - len uint8 - words [4]uint64 // Little endian (i.e. words[0] is the least significant) -} - -func (b *bitArray) Len() uint8 { - return b.len + pos uint8 // position of the most significant bit + words [4]uint64 // little endian (i.e. words[0] is the least significant) } func (b *bitArray) Bytes() [32]byte { var res [32]byte switch { - case b.len == 0: + case b.pos == 0: return res - case b.len == 255: + case b.pos == 255: binary.BigEndian.PutUint64(res[0:8], b.words[3]) binary.BigEndian.PutUint64(res[8:16], b.words[2]) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.len >= 192: - rem := 256 - uint(b.len) - mask := uint64(1<<(64-rem)) - 1 + case b.pos >= 192: + rem := 255 - uint(b.pos) + mask := ^mask64 >> (rem - 1) binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask) binary.BigEndian.PutUint64(res[8:16], b.words[2]) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.len >= 128: - rem := 192 - b.len - mask := uint64(1<<(64-rem)) - 1 + case b.pos >= 128: + rem := 191 - b.pos + mask := ^mask64 >> (rem - 1) binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.len >= 64: - rem := 128 - b.len - mask := uint64(1<<(64-rem)) - 1 + case b.pos >= 64: + rem := 127 - b.pos + mask := ^mask64 >> (rem - 1) binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask) binary.BigEndian.PutUint64(res[24:32], b.words[0]) default: - rem := 64 - b.len - mask := uint64(1<<(64-rem)) - 1 + rem := 63 - b.pos + mask := ^mask64 >> (rem - 1) binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) } @@ -63,12 +59,17 @@ func (b *bitArray) SetFelt(f *felt.Felt) *bitArray { b.words[2] = binary.BigEndian.Uint64(res[8:16]) b.words[1] = binary.BigEndian.Uint64(res[16:24]) b.words[0] = binary.BigEndian.Uint64(res[24:32]) - b.len = felt.Bits - 1 + b.pos = felt.Bits - 1 return b } +// Rsh shifts the bit array to the right by n bits. func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { - if n >= b.len { + if b.pos == 0 { + return b + } + + if n >= b.pos { return b.clear() } @@ -79,13 +80,13 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { b.rsh192(x) n -= 192 b.words[0] >>= n - b.len -= n + b.pos -= n case n >= 128: b.rsh128(x) n -= 128 b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) b.words[1] >>= n - b.len -= n + b.pos -= n case n >= 64: b.rsh64(x) n -= 64 @@ -93,21 +94,21 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) b.words[3] >>= n - b.len -= n + b.pos -= n default: b.set(x) b.words[3] = (b.words[3] >> n) | (b.words[2] << (64 - n)) b.words[2] = (b.words[2] >> n) | (b.words[1] << (64 - n)) b.words[1] = (b.words[1] >> n) | (b.words[0] << (64 - n)) b.words[0] >>= n - b.len -= n + b.pos -= n } return b } func (b *bitArray) set(x *bitArray) *bitArray { - b.len = x.len + b.pos = x.pos b.words[0] = x.words[0] b.words[1] = x.words[1] b.words[2] = x.words[2] @@ -128,7 +129,7 @@ func (b *bitArray) rsh192(x *bitArray) { } func (b *bitArray) clear() *bitArray { - b.len = 0 + b.pos = 0 b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 return b } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 6a5e974df6..695acc8811 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -14,77 +14,77 @@ func TestBytes(t *testing.T) { bitArray bitArray want [32]byte }{ - // { - // name: "length == 0", - // bitArray: bitArray{len: 0, words: maxBitArray}, - // want: [32]byte{}, - // }, - // { - // name: "length < 64", - // bitArray: bitArray{len: 38, words: maxBitArray}, - // want: func() [32]byte { - // var b [32]byte - // binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) - // return b - // }(), - // }, - // { - // name: "64 <= length < 128", - // bitArray: bitArray{len: 100, words: maxBitArray}, - // want: func() [32]byte { - // var b [32]byte - // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) - // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) - // return b - // }(), - // }, - // { - // name: "128 <= length < 192", - // bitArray: bitArray{len: 130, words: maxBitArray}, - // want: func() [32]byte { - // var b [32]byte - // binary.BigEndian.PutUint64(b[8:16], 0x3) - // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) - // return b - // }(), - // }, + { + name: "length == 0", + bitArray: bitArray{pos: 0, words: maxBitArray}, + want: [32]byte{}, + }, + { + name: "length < 64", + bitArray: bitArray{pos: 38, words: maxBitArray}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[24:32], 0x7FFFFFFFFF) + return b + }(), + }, + { + name: "64 <= length < 128", + bitArray: bitArray{pos: 100, words: maxBitArray}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[16:24], 0x7FFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + return b + }(), + }, + { + name: "128 <= length < 192", + bitArray: bitArray{pos: 130, words: maxBitArray}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[8:16], 0x7) + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + return b + }(), + }, { name: "192 <= length < 255", - bitArray: bitArray{len: 201, words: maxBitArray}, + bitArray: bitArray{pos: 201, words: maxBitArray}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x3FF) + binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + return b + }(), + }, + { + name: "length == 254", + bitArray: bitArray{pos: 254, words: maxBitArray}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + return b + }(), + }, + { + name: "length == 255", + bitArray: bitArray{pos: 255, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[0:8], 0x1FF) + binary.BigEndian.PutUint64(b[0:8], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) return b }(), }, - // { - // name: "length == 254", - // bitArray: bitArray{len: 254, words: maxBitArray}, - // want: func() [32]byte { - // var b [32]byte - // binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) - // return b - // }(), - // }, - // { - // name: "length == 255", - // bitArray: bitArray{len: 255, words: maxBitArray}, - // want: func() [32]byte { - // var b [32]byte - // binary.BigEndian.PutUint64(b[0:8], 0xFFFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - // binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) - // return b - // }(), - // }, } for _, tt := range tests { From 1ee61592feecf1f342432d9b3bc80fa1f55dcd0d Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 13 Dec 2024 22:52:42 +0800 Subject: [PATCH 03/30] this works --- core/trie/bitarray.go | 26 +++++++++++++++++--------- core/trie/bitarray_test.go | 4 +--- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 6887140d41..da18c6cd78 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -10,8 +10,10 @@ const ( mask64 = uint64(1 << 63) ) +var maxBitArray = [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF} + type bitArray struct { - pos uint8 // position of the most significant bit + pos uint8 // position of the current most significant bit (0-255) words [4]uint64 // little endian (i.e. words[0] is the least significant) } @@ -27,26 +29,32 @@ func (b *bitArray) Bytes() [32]byte { binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) case b.pos >= 192: - rem := 255 - uint(b.pos) - mask := ^mask64 >> (rem - 1) + // For positions >= 192, we need to mask the most significant word (words[3]) + // to zero out bits beyond the current position. + // Example: if pos = 201, then rem = 255 - 201 = 54 + // mask = ^mask64 >> (54 - 1) = ^(1<<63) >> 53 + // This creates a mask like: 0000000000000000000000000000000000000000000000000000001111111111 + // When applied to words[3], it preserves only the 10 least significant bits + shift := 255 - b.pos + mask := ^mask64 >> (shift - 1) binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask) binary.BigEndian.PutUint64(res[8:16], b.words[2]) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) case b.pos >= 128: - rem := 191 - b.pos - mask := ^mask64 >> (rem - 1) + shift := 191 - b.pos + mask := ^mask64 >> (shift - 1) binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) case b.pos >= 64: - rem := 127 - b.pos - mask := ^mask64 >> (rem - 1) + shift := 127 - b.pos + mask := ^mask64 >> (shift - 1) binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask) binary.BigEndian.PutUint64(res[24:32], b.words[0]) default: - rem := 63 - b.pos - mask := ^mask64 >> (rem - 1) + shift := 63 - b.pos + mask := ^mask64 >> (shift - 1) binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 695acc8811..6ef2fa1095 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -6,8 +6,6 @@ import ( "testing" ) -var maxBitArray = [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF} - func TestBytes(t *testing.T) { tests := []struct { name string @@ -33,7 +31,7 @@ func TestBytes(t *testing.T) { bitArray: bitArray{pos: 100, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[16:24], 0x7FFFFFFFFF) + binary.BigEndian.PutUint64(b[16:24], 0x1FFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) return b }(), From 0c708a559c5078e7c0ac1d22afa1471466dfb529 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 13 Dec 2024 22:58:24 +0800 Subject: [PATCH 04/30] add bytes benchmark --- core/trie/bitarray_test.go | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 6ef2fa1095..fd747dd19f 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -94,3 +94,40 @@ func TestBytes(t *testing.T) { }) } } + +func BenchmarkBitArrayBytes(b *testing.B) { + testCases := []struct { + name string + ba bitArray + }{ + { + name: "empty", + ba: bitArray{pos: 0, words: maxBitArray}, + }, + { + name: "pos_38", + ba: bitArray{pos: 38, words: maxBitArray}, + }, + { + name: "pos_100", + ba: bitArray{pos: 100, words: maxBitArray}, + }, + { + name: "pos_201", + ba: bitArray{pos: 201, words: maxBitArray}, + }, + { + name: "pos_255", + ba: bitArray{pos: 255, words: maxBitArray}, + }, + } + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = tc.ba.Bytes() + } + }) + } +} From bcad34cc51d2a61ba32d882003da31f682933583 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 13 Dec 2024 23:51:24 +0800 Subject: [PATCH 05/30] looks gud --- core/trie/bitarray.go | 65 ++++++++++++++++++----------------- core/trie/bitarray_test.go | 69 ++++++++++++-------------------------- 2 files changed, 53 insertions(+), 81 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index da18c6cd78..bdbd158986 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -12,49 +12,48 @@ const ( var maxBitArray = [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF} +// bitArray is a structure that represents a bit array with a max length of 255 bits. +// The reason why 255 bits is the max length is because we only need up to 252 bits for the felt. +// Though words can be used to represent 256 bits, we don't want to add an additional byte for the length. +// It uses a little endian representation to do bitwise operations of the words efficiently. +// Unlike normal bit arrays, it has a `len` field that represents the number of used bits. +// For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. type bitArray struct { - pos uint8 // position of the current most significant bit (0-255) + len uint8 // number of used bits words [4]uint64 // little endian (i.e. words[0] is the least significant) } +// Bytes returns the bytes representation of the bit array in big endian format. func (b *bitArray) Bytes() [32]byte { var res [32]byte switch { - case b.pos == 0: + case b.len == 0: return res - case b.pos == 255: - binary.BigEndian.PutUint64(res[0:8], b.words[3]) - binary.BigEndian.PutUint64(res[8:16], b.words[2]) - binary.BigEndian.PutUint64(res[16:24], b.words[1]) - binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.pos >= 192: - // For positions >= 192, we need to mask the most significant word (words[3]) - // to zero out bits beyond the current position. - // Example: if pos = 201, then rem = 255 - 201 = 54 - // mask = ^mask64 >> (54 - 1) = ^(1<<63) >> 53 - // This creates a mask like: 0000000000000000000000000000000000000000000000000000001111111111 - // When applied to words[3], it preserves only the 10 least significant bits - shift := 255 - b.pos - mask := ^mask64 >> (shift - 1) + case b.len >= 192: + // len is 0-based, so 255 (not 256) represents all bits used + // subtracting from 255 ensures correct mask when len=255 + // For example, when len is 255, it means all bits from index 0 + // to 254 are used (total of 255 bits). + // So when we create the mask, we shift 255 - 255 = 0 bits to the right. + // This creates a mask that covers all bits from index 0 to 254. + mask := ^mask64 >> (255 - b.len) binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask) binary.BigEndian.PutUint64(res[8:16], b.words[2]) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.pos >= 128: - shift := 191 - b.pos - mask := ^mask64 >> (shift - 1) + case b.len >= 128: + // Similar pattern for 191 boundary (3 words × 64 bits - 1) + mask := ^mask64 >> (191 - b.len) binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.pos >= 64: - shift := 127 - b.pos - mask := ^mask64 >> (shift - 1) + case b.len >= 64: + mask := ^mask64 >> (127 - b.len) binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask) binary.BigEndian.PutUint64(res[24:32], b.words[0]) default: - shift := 63 - b.pos - mask := ^mask64 >> (shift - 1) + mask := ^mask64 >> (63 - b.len) binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) } @@ -67,17 +66,17 @@ func (b *bitArray) SetFelt(f *felt.Felt) *bitArray { b.words[2] = binary.BigEndian.Uint64(res[8:16]) b.words[1] = binary.BigEndian.Uint64(res[16:24]) b.words[0] = binary.BigEndian.Uint64(res[24:32]) - b.pos = felt.Bits - 1 + b.len = felt.Bits - 1 return b } // Rsh shifts the bit array to the right by n bits. func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { - if b.pos == 0 { + if b.len == 0 { return b } - if n >= b.pos { + if n >= b.len { return b.clear() } @@ -88,13 +87,13 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { b.rsh192(x) n -= 192 b.words[0] >>= n - b.pos -= n + b.len -= n case n >= 128: b.rsh128(x) n -= 128 b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) b.words[1] >>= n - b.pos -= n + b.len -= n case n >= 64: b.rsh64(x) n -= 64 @@ -102,21 +101,21 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) b.words[3] >>= n - b.pos -= n + b.len -= n default: b.set(x) b.words[3] = (b.words[3] >> n) | (b.words[2] << (64 - n)) b.words[2] = (b.words[2] >> n) | (b.words[1] << (64 - n)) b.words[1] = (b.words[1] >> n) | (b.words[0] << (64 - n)) b.words[0] >>= n - b.pos -= n + b.len -= n } return b } func (b *bitArray) set(x *bitArray) *bitArray { - b.pos = x.pos + b.len = x.len b.words[0] = x.words[0] b.words[1] = x.words[1] b.words[2] = x.words[2] @@ -137,7 +136,7 @@ func (b *bitArray) rsh192(x *bitArray) { } func (b *bitArray) clear() *bitArray { - b.pos = 0 + b.len = 0 b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 return b } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index fd747dd19f..cdd23b9481 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -3,6 +3,7 @@ package trie import ( "bytes" "encoding/binary" + "math/bits" "testing" ) @@ -14,34 +15,34 @@ func TestBytes(t *testing.T) { }{ { name: "length == 0", - bitArray: bitArray{pos: 0, words: maxBitArray}, + bitArray: bitArray{len: 0, words: maxBitArray}, want: [32]byte{}, }, { name: "length < 64", - bitArray: bitArray{pos: 38, words: maxBitArray}, + bitArray: bitArray{len: 38, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[24:32], 0x7FFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) return b }(), }, { name: "64 <= length < 128", - bitArray: bitArray{pos: 100, words: maxBitArray}, + bitArray: bitArray{len: 100, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[16:24], 0x1FFFFFFFFF) + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) return b }(), }, { name: "128 <= length < 192", - bitArray: bitArray{pos: 130, words: maxBitArray}, + bitArray: bitArray{len: 130, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[8:16], 0x7) + binary.BigEndian.PutUint64(b[8:16], 0x3) binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) return b @@ -49,10 +50,10 @@ func TestBytes(t *testing.T) { }, { name: "192 <= length < 255", - bitArray: bitArray{pos: 201, words: maxBitArray}, + bitArray: bitArray{len: 201, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[0:8], 0x3FF) + binary.BigEndian.PutUint64(b[0:8], 0x1FF) binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) @@ -61,10 +62,10 @@ func TestBytes(t *testing.T) { }, { name: "length == 254", - bitArray: bitArray{pos: 254, words: maxBitArray}, + bitArray: bitArray{len: 254, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) @@ -73,10 +74,10 @@ func TestBytes(t *testing.T) { }, { name: "length == 255", - bitArray: bitArray{pos: 255, words: maxBitArray}, + bitArray: bitArray{len: 255, words: maxBitArray}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[0:8], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) @@ -91,42 +92,14 @@ func TestBytes(t *testing.T) { if !bytes.Equal(got[:], tt.want[:]) { t.Errorf("bitArray.Bytes() = %v, want %v", got, tt.want) } - }) - } -} -func BenchmarkBitArrayBytes(b *testing.B) { - testCases := []struct { - name string - ba bitArray - }{ - { - name: "empty", - ba: bitArray{pos: 0, words: maxBitArray}, - }, - { - name: "pos_38", - ba: bitArray{pos: 38, words: maxBitArray}, - }, - { - name: "pos_100", - ba: bitArray{pos: 100, words: maxBitArray}, - }, - { - name: "pos_201", - ba: bitArray{pos: 201, words: maxBitArray}, - }, - { - name: "pos_255", - ba: bitArray{pos: 255, words: maxBitArray}, - }, - } - - for _, tc := range testCases { - b.Run(tc.name, func(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - _ = tc.ba.Bytes() + // check if the received bytes has the same bit count as the bitArray.len + count := 0 + for _, b := range got { + count += bits.OnesCount8(b) + } + if count != int(tt.bitArray.len) { + t.Errorf("bitArray.Bytes() bit count = %v, want %v", count, tt.bitArray.len) } }) } From 3bac94471fee742c9143a816b616c0fd1c5a0f98 Mon Sep 17 00:00:00 2001 From: weiihann Date: Sat, 14 Dec 2024 18:40:04 +0800 Subject: [PATCH 06/30] add Rsh test --- core/trie/bitarray.go | 100 +++++++++++++++++++++++------------ core/trie/bitarray_test.go | 103 +++++++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+), 33 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index bdbd158986..bcb4c69498 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -2,82 +2,108 @@ package trie import ( "encoding/binary" + "math" "github.com/NethermindEth/juno/core/felt" ) const ( - mask64 = uint64(1 << 63) + maxUint64 = uint64(math.MaxUint64) ) -var maxBitArray = [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF} +var maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} // bitArray is a structure that represents a bit array with a max length of 255 bits. -// The reason why 255 bits is the max length is because we only need up to 252 bits for the felt. -// Though words can be used to represent 256 bits, we don't want to add an additional byte for the length. // It uses a little endian representation to do bitwise operations of the words efficiently. // Unlike normal bit arrays, it has a `len` field that represents the number of used bits. // For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. +// The reason why 255 bits is the max length is because we only need up to 251 bits for a given trie key. +// Though words can be used to represent 256 bits, we don't want to add an additional byte for the length. type bitArray struct { len uint8 // number of used bits words [4]uint64 // little endian (i.e. words[0] is the least significant) } -// Bytes returns the bytes representation of the bit array in big endian format. +// Bytes returns the bytes representation of the bit array in big endian format func (b *bitArray) Bytes() [32]byte { var res [32]byte switch { case b.len == 0: + // all zeros return res case b.len >= 192: - // len is 0-based, so 255 (not 256) represents all bits used - // subtracting from 255 ensures correct mask when len=255 - // For example, when len is 255, it means all bits from index 0 - // to 254 are used (total of 255 bits). - // So when we create the mask, we shift 255 - 255 = 0 bits to the right. - // This creates a mask that covers all bits from index 0 to 254. - mask := ^mask64 >> (255 - b.len) + // Create mask for top word: keeps only valid bits above 192 + // e.g., if len=200, keeps lowest 8 bits (200-192) + mask := maxUint64 >> (256 - uint16(b.len)) binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask) binary.BigEndian.PutUint64(res[8:16], b.words[2]) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) case b.len >= 128: - // Similar pattern for 191 boundary (3 words × 64 bits - 1) - mask := ^mask64 >> (191 - b.len) + // Mask for bits 128-191: keeps only valid bits above 128 + // e.g., if len=150, keeps lowest 22 bits (150-128) + mask := maxUint64 >> (192 - b.len) binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask) binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) case b.len >= 64: - mask := ^mask64 >> (127 - b.len) + // You get the idea + mask := maxUint64 >> (128 - b.len) binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask) binary.BigEndian.PutUint64(res[24:32], b.words[0]) default: - mask := ^mask64 >> (63 - b.len) + mask := maxUint64 >> (64 - b.len) binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) } return res } -func (b *bitArray) SetFelt(f *felt.Felt) *bitArray { +func (b *bitArray) SetFelt(length uint8, f *felt.Felt) *bitArray { + b.setFelt(f) + b.len = length + return b +} + +func (b *bitArray) SetFelt251(f *felt.Felt) *bitArray { + b.setFelt(f) + b.len = 251 + return b +} + +func (b *bitArray) setFelt(f *felt.Felt) { res := f.Bytes() b.words[3] = binary.BigEndian.Uint64(res[0:8]) b.words[2] = binary.BigEndian.Uint64(res[8:16]) b.words[1] = binary.BigEndian.Uint64(res[16:24]) b.words[0] = binary.BigEndian.Uint64(res[24:32]) - b.len = felt.Bits - 1 - return b } -// Rsh shifts the bit array to the right by n bits. +func (b *bitArray) PrefixEqual(x *bitArray) bool { + if b.len == x.len { + return b.Equal(x) + } + + var long, short *bitArray + long, short = b, x + + if b.len < x.len { + long, short = x, b + } + + return long.Rsh(long, long.len-short.len).Equal(short) +} + +// Rsh sets b = x >> n and returns b. func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { - if b.len == 0 { - return b + if x.len == 0 { + return b.set(x) } - if n >= b.len { - return b.clear() + if n >= x.len { + x.clear() + return b.set(x) } switch { @@ -85,35 +111,43 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { return b.set(x) case n >= 192: b.rsh192(x) + b.len = x.len - n n -= 192 b.words[0] >>= n - b.len -= n case n >= 128: b.rsh128(x) + b.len = x.len - n n -= 128 b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) b.words[1] >>= n - b.len -= n case n >= 64: b.rsh64(x) + b.len = x.len - n n -= 64 b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) - b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) - b.words[3] >>= n - b.len -= n + b.words[2] >>= n default: b.set(x) - b.words[3] = (b.words[3] >> n) | (b.words[2] << (64 - n)) - b.words[2] = (b.words[2] >> n) | (b.words[1] << (64 - n)) - b.words[1] = (b.words[1] >> n) | (b.words[0] << (64 - n)) - b.words[0] >>= n b.len -= n + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) + b.words[3] >>= n } return b } +// Eq checks if two bit arrays are equal +func (b *bitArray) Equal(x *bitArray) bool { + return b.len == x.len && + b.words[0] == x.words[0] && + b.words[1] == x.words[1] && + b.words[2] == x.words[2] && + b.words[3] == x.words[3] +} + func (b *bitArray) set(x *bitArray) *bitArray { b.len = x.len b.words[0] = x.words[0] diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index cdd23b9481..b740c93143 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -104,3 +104,106 @@ func TestBytes(t *testing.T) { }) } } + +func TestRsh(t *testing.T) { + tests := []struct { + name string + initial *bitArray + shiftBy uint8 + expected *bitArray + }{ + { + name: "zero length array", + initial: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + shiftBy: 5, + expected: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by 0", + initial: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + shiftBy: 0, + expected: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "shift by more than length", + initial: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + shiftBy: 65, + expected: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by less than 64", + initial: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + shiftBy: 32, + expected: &bitArray{ + len: 96, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x00000000FFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by exactly 64", + initial: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + shiftBy: 64, + expected: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "shift by 128", + initial: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + }, + shiftBy: 128, + expected: &bitArray{ + len: 123, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by 192", + initial: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + }, + shiftBy: 192, + expected: &bitArray{ + len: 59, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(bitArray).Rsh(tt.initial, tt.shiftBy) + if !result.Equal(tt.expected) { + t.Errorf("Rsh() got = %+v, want %+v", result, tt.expected) + } + }) + } +} From 0356e730cf5e2634886c48d3ce6d4ca9e791c781 Mon Sep 17 00:00:00 2001 From: weiihann Date: Sat, 14 Dec 2024 19:20:55 +0800 Subject: [PATCH 07/30] add Truncate --- core/trie/bitarray.go | 98 +++++++++---- core/trie/bitarray_test.go | 281 +++++++++++++++++++++++++++++++++++++ 2 files changed, 354 insertions(+), 25 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index bcb4c69498..3a7387f320 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -60,31 +60,15 @@ func (b *bitArray) Bytes() [32]byte { return res } -func (b *bitArray) SetFelt(length uint8, f *felt.Felt) *bitArray { - b.setFelt(f) - b.len = length - return b -} - -func (b *bitArray) SetFelt251(f *felt.Felt) *bitArray { - b.setFelt(f) - b.len = 251 - return b -} - -func (b *bitArray) setFelt(f *felt.Felt) { - res := f.Bytes() - b.words[3] = binary.BigEndian.Uint64(res[0:8]) - b.words[2] = binary.BigEndian.Uint64(res[8:16]) - b.words[1] = binary.BigEndian.Uint64(res[16:24]) - b.words[0] = binary.BigEndian.Uint64(res[24:32]) -} - func (b *bitArray) PrefixEqual(x *bitArray) bool { if b.len == x.len { return b.Equal(x) } + if b.len == 0 || x.len == 0 { + return true + } + var long, short *bitArray long, short = b, x @@ -95,20 +79,64 @@ func (b *bitArray) PrefixEqual(x *bitArray) bool { return long.Rsh(long, long.len-short.len).Equal(short) } +// Truncate sets b to the first 'length' bits of x and returns b. +// If length >= x.len, b is a copy of x. +// Any bits beyond the specified length are cleared to zero. +// For example: +// +// x = 11001011 (len=8) +// Truncate(x, 4) = 1011 (len=4) +// Truncate(x, 10) = 11001011 (len=8, original x) +// Truncate(x, 0) = 0 (len=0) +func (b *bitArray) Truncate(x *bitArray, length uint8) *bitArray { + if length >= x.len { + return b.Set(x) + } + + b.Set(x) + b.len = length + + // Clear all words beyond what's needed + switch { + case length == 0: + b.words = [4]uint64{0, 0, 0, 0} + case length <= 64: + mask := maxUint64 >> (64 - length) + b.words[0] &= mask + b.words[1] = 0 + b.words[2] = 0 + b.words[3] = 0 + case length <= 128: + mask := maxUint64 >> (128 - length) + b.words[1] &= mask + b.words[2] = 0 + b.words[3] = 0 + case length <= 192: + mask := maxUint64 >> (192 - length) + b.words[2] &= mask + b.words[3] = 0 + default: + mask := maxUint64 >> (256 - uint16(length)) + b.words[3] &= mask + } + + return b +} + // Rsh sets b = x >> n and returns b. func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { if x.len == 0 { - return b.set(x) + return b.Set(x) } if n >= x.len { x.clear() - return b.set(x) + return b.Set(x) } switch { case n == 0: - return b.set(x) + return b.Set(x) case n >= 192: b.rsh192(x) b.len = x.len - n @@ -128,7 +156,7 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) b.words[2] >>= n default: - b.set(x) + b.Set(x) b.len -= n b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) @@ -148,7 +176,27 @@ func (b *bitArray) Equal(x *bitArray) bool { b.words[3] == x.words[3] } -func (b *bitArray) set(x *bitArray) *bitArray { +func (b *bitArray) SetFelt(length uint8, f *felt.Felt) *bitArray { + b.setFelt(f) + b.len = length + return b +} + +func (b *bitArray) SetFelt251(f *felt.Felt) *bitArray { + b.setFelt(f) + b.len = 251 + return b +} + +func (b *bitArray) setFelt(f *felt.Felt) { + res := f.Bytes() + b.words[3] = binary.BigEndian.Uint64(res[0:8]) + b.words[2] = binary.BigEndian.Uint64(res[8:16]) + b.words[1] = binary.BigEndian.Uint64(res[16:24]) + b.words[0] = binary.BigEndian.Uint64(res[24:32]) +} + +func (b *bitArray) Set(x *bitArray) *bitArray { b.len = x.len b.words[0] = x.words[0] b.words[1] = x.words[1] diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index b740c93143..0e031ff57c 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -207,3 +207,284 @@ func TestRsh(t *testing.T) { }) } } + +func TestPrefixEqual(t *testing.T) { + tests := []struct { + name string + a *bitArray + b *bitArray + want bool + }{ + { + name: "equal lengths, equal values", + a: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + b: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: true, + }, + { + name: "equal lengths, different values", + a: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + b: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, + }, + want: false, + }, + { + name: "different lengths, a longer but same prefix", + a: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + b: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: true, + }, + { + name: "different lengths, b longer but same prefix", + a: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + b: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + want: true, + }, + { + name: "different lengths, different prefix", + a: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + b: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, + }, + want: false, + }, + { + name: "zero length arrays", + a: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + b: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: true, + }, + { + name: "one zero length array", + a: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + b: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: true, + }, + { + name: "max length difference", + a: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + }, + b: &bitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.a.PrefixEqual(tt.b); got != tt.want { + t.Errorf("PrefixEqual() = %v, want %v", got, tt.want) + } + // Test symmetry: a.PrefixEqual(b) should equal b.PrefixEqual(a) + if got := tt.b.PrefixEqual(tt.a); got != tt.want { + t.Errorf("PrefixEqual() symmetric test = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTruncate(t *testing.T) { + tests := []struct { + name string + initial bitArray + length uint8 + expected bitArray + }{ + { + name: "truncate to zero", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 0, + expected: bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "truncate within first word - 32 bits", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 32, + expected: bitArray{ + len: 32, + words: [4]uint64{0x00000000FFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "truncate to single bit", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 1, + expected: bitArray{ + len: 1, + words: [4]uint64{0x0000000000000001, 0, 0, 0}, + }, + }, + { + name: "truncate across words - 100 bits", + initial: bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + length: 100, + expected: bitArray{ + len: 100, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x0000000FFFFFFFFF, 0, 0}, + }, + }, + { + name: "truncate at word boundary - 64 bits", + initial: bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + length: 64, + expected: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "truncate at word boundary - 128 bits", + initial: bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + length: 128, + expected: bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "truncate in third word - 150 bits", + initial: bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + length: 150, + expected: bitArray{ + len: 150, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x3FFFFF, 0}, + }, + }, + { + name: "truncate in fourth word - 220 bits", + initial: bitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + }, + length: 220, + expected: bitArray{ + len: 220, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFF}, + }, + }, + { + name: "truncate max length - 251 bits", + initial: bitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + }, + length: 251, + expected: bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + }, + }, + { + name: "truncate sparse bits", + initial: bitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 100, + expected: bitArray{ + len: 100, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x0000000555555555, 0, 0}, + }, + }, + { + name: "no change when new length equals current length", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 64, + expected: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "no change when new length greater than current length", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 128, + expected: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(bitArray).Truncate(&tt.initial, tt.length) + if !result.Equal(&tt.expected) { + t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) + } + }) + } +} From b32479fe71607b12d843ac5525bf7367644ee7d1 Mon Sep 17 00:00:00 2001 From: weiihann Date: Sun, 15 Dec 2024 17:09:17 +0800 Subject: [PATCH 08/30] add CommonMSBs --- core/trie/bitarray.go | 132 +++++++++++++++- core/trie/bitarray_test.go | 314 ++++++++++++++++++++++++++++++++++++- 2 files changed, 441 insertions(+), 5 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 3a7387f320..c542c1e02a 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -1,17 +1,23 @@ package trie import ( + "bytes" "encoding/binary" "math" + "math/bits" "github.com/NethermindEth/juno/core/felt" ) const ( maxUint64 = uint64(math.MaxUint64) + byteBits = 8 ) -var maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} +var ( + maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} + emptyBitArray = &bitArray{len: 0, words: [4]uint64{0, 0, 0, 0}} +) // bitArray is a structure that represents a bit array with a max length of 255 bits. // It uses a little endian representation to do bitwise operations of the words efficiently. @@ -60,7 +66,24 @@ func (b *bitArray) Bytes() [32]byte { return res } -func (b *bitArray) PrefixEqual(x *bitArray) bool { +// EqualMSBs checks if two bit arrays share the same most significant bits, where the length of +// the check is determined by the shorter array. Returns true if either array has +// length 0, or if the first min(b.len, x.len) MSBs are identical. +// +// For example: +// +// a = 1101 (len=4) +// b = 11010111 (len=8) +// a.EqualMSBs(b) = true // First 4 MSBs match +// +// a = 1100 (len=4) +// b = 1101 (len=4) +// a.EqualMSBs(b) = false // All bits compared, not equal +// +// a = 1100 (len=4) +// b = [] (len=0) +// a.EqualMSBs(b) = true // Zero length is always a prefix match +func (b *bitArray) EqualMSBs(x *bitArray) bool { if b.len == x.len { return b.Equal(x) } @@ -70,8 +93,8 @@ func (b *bitArray) PrefixEqual(x *bitArray) bool { } var long, short *bitArray - long, short = b, x + long, short = b, x if b.len < x.len { long, short = x, b } @@ -123,6 +146,66 @@ func (b *bitArray) Truncate(x *bitArray, length uint8) *bitArray { return b } +// CommonMSBs sets b to the longest sequence of matching most significant bits between two bit arrays. +// For example: +// +// x = 1101 0111 (len=8) +// y = 1101 0000 (len=8) +// CommonMSBs(x,y) = 1101 (len=4) +func (b *bitArray) CommonMSBs(x, y *bitArray) *bitArray { + if x.len == 0 || y.len == 0 { + return emptyBitArray + } + + long, short := x, y + if x.len < y.len { + long, short = y, x + } + + // Align arrays by right-shifting longer array and then XOR to find differences + // Example: + // short = 1101 (len=4) + // long = 1101 0111 (len=8) + // + // Step 1: Right shift longer array by 4 + // short = 1100 + // long = 1101 + // + // Step 2: XOR shows difference at last bit + // 1100 (short) + // 1101 (aligned long) + // ---- XOR + // 0001 (difference at last position) + // We can then use the position of the first set bit and right-shift to get the common MSBs + diff := long.len - short.len + b.Rsh(long, diff).Xor(b, short) + divergentBit := findFirstSetBit(b) + + return b.Rsh(short, divergentBit) +} + +// findFirstSetBit returns the position of the first '1' bit in the array, +// scanning from most significant to least significant bit. +// +// The bit position is counted from the least significant bit, starting at 0. +// For example: +// +// array = 0000 0000 ... 0100 (len=251) +// findFirstSetBit() = 2 // third bit from right is set +func findFirstSetBit(b *bitArray) uint8 { + if b.len == 0 { + return 0 + } + + for i := 3; i >= 0; i-- { + if word := b.words[i]; word != 0 { + return uint8((i+1)*64 - bits.LeadingZeros64(word)) + } + } + + return 0 +} + // Rsh sets b = x >> n and returns b. func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { if x.len == 0 { @@ -167,6 +250,15 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { return b } +// Xor sets b = x ^ y and returns b. +func (b *bitArray) Xor(x, y *bitArray) *bitArray { + b.words[0] = x.words[0] ^ y.words[0] + b.words[1] = x.words[1] ^ y.words[1] + b.words[2] = x.words[2] ^ y.words[2] + b.words[3] = x.words[3] ^ y.words[3] + return b +} + // Eq checks if two bit arrays are equal func (b *bitArray) Equal(x *bitArray) bool { return b.len == x.len && @@ -176,6 +268,21 @@ func (b *bitArray) Equal(x *bitArray) bool { b.words[3] == x.words[3] } +// Write serializes the bitArray into a bytes buffer in the following format: +// - First byte: length of the bit array (0-255) +// - Remaining bytes: the necessary bytes included in big endian order +// Example: +// +// bitArray{len: 10, words: [4]uint64{0x03FF}} -> [0x0A, 0x03, 0xFF] +func (b *bitArray) Write(buf *bytes.Buffer) (int, error) { + if err := buf.WriteByte(b.len); err != nil { + return 0, err + } + + n, err := buf.Write(b.activeBytes()) + return n + 1, err +} + func (b *bitArray) SetFelt(length uint8, f *felt.Felt) *bitArray { b.setFelt(f) b.len = length @@ -205,6 +312,25 @@ func (b *bitArray) Set(x *bitArray) *bitArray { return b } +// byteCount returns the minimum number of bytes needed to represent the bit array. +// It rounds up to the nearest byte. +func (b *bitArray) byteCount() uint8 { + // Cast to uint16 to avoid overflow + return uint8((uint16(b.len) + uint16(byteBits-1)) / uint16(byteBits)) +} + +// activeBytes returns a slice containing only the bytes that are actually used +// by the bit array, excluding leading zero bytes. The returned slice is in +// big-endian order. +// +// Example: +// +// len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF] +func (b *bitArray) activeBytes() []byte { + wordsBytes := b.Bytes() + return wordsBytes[32-b.byteCount():] +} + func (b *bitArray) rsh64(x *bitArray) { b.words[3], b.words[2], b.words[1], b.words[0] = 0, x.words[3], x.words[2], x.words[1] } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 0e031ff57c..825277e6e6 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -5,6 +5,8 @@ import ( "encoding/binary" "math/bits" "testing" + + "github.com/stretchr/testify/assert" ) func TestBytes(t *testing.T) { @@ -172,6 +174,18 @@ func TestRsh(t *testing.T) { words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, }, + { + name: "shift by 127", + initial: &bitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, + }, + shiftBy: 127, + expected: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + }, { name: "shift by 128", initial: &bitArray{ @@ -315,11 +329,11 @@ func TestPrefixEqual(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := tt.a.PrefixEqual(tt.b); got != tt.want { + if got := tt.a.EqualMSBs(tt.b); got != tt.want { t.Errorf("PrefixEqual() = %v, want %v", got, tt.want) } // Test symmetry: a.PrefixEqual(b) should equal b.PrefixEqual(a) - if got := tt.b.PrefixEqual(tt.a); got != tt.want { + if got := tt.b.EqualMSBs(tt.a); got != tt.want { t.Errorf("PrefixEqual() symmetric test = %v, want %v", got, tt.want) } }) @@ -488,3 +502,299 @@ func TestTruncate(t *testing.T) { }) } } + +func TestWrite(t *testing.T) { + tests := []struct { + name string + bitArray bitArray + want []byte // Expected bytes after writing + }{ + { + name: "empty bit array", + bitArray: bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: []byte{0}, // Just the length byte + }, + { + name: "8 bits", + bitArray: bitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + want: []byte{8, 0xFF}, // length byte + 1 data byte + }, + { + name: "10 bits requiring 2 bytes", + bitArray: bitArray{ + len: 10, + words: [4]uint64{0x3FF, 0, 0, 0}, // 1111111111 in binary + }, + want: []byte{10, 0x3, 0xFF}, // length byte + 2 data bytes + }, + { + name: "64 bits", + bitArray: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: append( + []byte{64}, // length byte + []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}..., // 8 data bytes + ), + }, + { + name: "251 bits", + bitArray: bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + }, + want: func() []byte { + b := make([]byte, 33) // 1 length byte + 32 data bytes + b[0] = 251 // length byte + // First byte is 0x07 (from the most significant bits) + b[1] = 0x07 + // Rest of the bytes are 0xFF + for i := 2; i < 33; i++ { + b[i] = 0xFF + } + return b + }(), + }, + { + name: "sparse bits", + bitArray: bitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 in binary + }, + want: []byte{16, 0xAA, 0xAA}, // length byte + 2 data bytes + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := new(bytes.Buffer) + gotN, err := tt.bitArray.Write(buf) + assert.NoError(t, err) + + // Check number of bytes written + if gotN != len(tt.want) { + t.Errorf("Write() wrote %d bytes, want %d", gotN, len(tt.want)) + } + + // Check written bytes + if got := buf.Bytes(); !bytes.Equal(got, tt.want) { + t.Errorf("Write() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCommonPrefix(t *testing.T) { + tests := []struct { + name string + x *bitArray + y *bitArray + want *bitArray + }{ + { + name: "empty arrays", + x: emptyBitArray, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "one empty array", + x: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "identical arrays - single word", + x: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "identical arrays - multiple words", + x: &bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + y: &bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + want: &bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + }, + { + name: "different lengths with common prefix - first word", + x: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: &bitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + want: &bitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different lengths with common prefix - multiple words", + x: &bitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, + }, + y: &bitArray{ + len: 127, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0, 0}, + }, + want: &bitArray{ + len: 127, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "different at first bit", + x: &bitArray{ + len: 64, + words: [4]uint64{0x7FFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "different in middle of first word", + x: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFF0FFFFFFF, 0, 0, 0}, + }, + y: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: &bitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in second word", + x: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFF0FFFFFFF, 0, 0}, + }, + y: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + want: &bitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in third word", + x: &bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + y: &bitArray{ + len: 192, + words: [4]uint64{0, 0, 0xFFFFFFFFFFFFFF0F, 0}, + }, + want: &bitArray{ + len: 56, + words: [4]uint64{0xFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in last word", + x: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + }, + y: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFF0FFFFFFF}, + }, + want: &bitArray{ + len: 27, + words: [4]uint64{0x7FFFFFF}, + }, + }, + { + name: "sparse bits with common prefix", + x: &bitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA, 0, 0}, + }, + y: &bitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAA000, 0, 0}, + }, + want: &bitArray{ + len: 52, + words: [4]uint64{0xAAAAAAAAAAAAA, 0, 0, 0}, + }, + }, + { + name: "max length difference", + x: &bitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, + }, + y: &bitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + want: &bitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(bitArray) + gotSymmetric := new(bitArray) + + got.CommonMSBs(tt.x, tt.y) + if !got.Equal(tt.want) { + t.Errorf("CommonMSBs() = %v, want %v", got, tt.want) + } + + // Test symmetry: x.CommonMSBs(y) should equal y.CommonMSBs(x) + gotSymmetric.CommonMSBs(tt.y, tt.x) + if !gotSymmetric.Equal(tt.want) { + t.Errorf("CommonMSBs() symmetric test = %v, want %v", gotSymmetric, tt.want) + } + }) + } +} From a560990de868f0b98c44db4cb1e233809803d34b Mon Sep 17 00:00:00 2001 From: weiihann Date: Sun, 15 Dec 2024 17:12:54 +0800 Subject: [PATCH 09/30] minor comments --- core/trie/bitarray.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index c542c1e02a..a229c974ba 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -102,8 +102,8 @@ func (b *bitArray) EqualMSBs(x *bitArray) bool { return long.Rsh(long, long.len-short.len).Equal(short) } -// Truncate sets b to the first 'length' bits of x and returns b. -// If length >= x.len, b is a copy of x. +// Truncate sets b to the first 'length' bits of x (starting from the least significant bit). +// If length >= x.len, b is an exact copy of x. // Any bits beyond the specified length are cleared to zero. // For example: // From e8e5eb0d56a12708f369b3ff9a4a808ee797c90e Mon Sep 17 00:00:00 2001 From: weiihann Date: Sun, 15 Dec 2024 19:05:56 +0800 Subject: [PATCH 10/30] add UnmarshalBinary --- core/trie/bitarray.go | 25 ++++++++++++++++++++++++- core/trie/bitarray_test.go | 13 +++++++++++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index a229c974ba..0cca0ccf71 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -268,7 +268,7 @@ func (b *bitArray) Equal(x *bitArray) bool { b.words[3] == x.words[3] } -// Write serializes the bitArray into a bytes buffer in the following format: +// Write serialises the bitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order // Example: @@ -283,6 +283,21 @@ func (b *bitArray) Write(buf *bytes.Buffer) (int, error) { return n + 1, err } +// UnmarshalBinary deserialises the bitArray from a bytes buffer in the following format: +// - First byte: length of the bit array (0-255) +// - Remaining bytes: the necessary bytes included in big endian order +// Example: +// +// [0x0A, 0x03, 0xFF] -> bitArray{len: 10, words: [4]uint64{0x03FF}} +func (b *bitArray) UnmarshalBinary(data []byte) error { + b.len = data[0] + + var bs [32]byte + copy(bs[32-b.byteCount():], data[1:]) + b.SetBytes32(bs) + return nil +} + func (b *bitArray) SetFelt(length uint8, f *felt.Felt) *bitArray { b.setFelt(f) b.len = length @@ -295,6 +310,14 @@ func (b *bitArray) SetFelt251(f *felt.Felt) *bitArray { return b } +func (b *bitArray) SetBytes32(data [32]byte) *bitArray { + b.words[3] = binary.BigEndian.Uint64(data[0:8]) + b.words[2] = binary.BigEndian.Uint64(data[8:16]) + b.words[1] = binary.BigEndian.Uint64(data[16:24]) + b.words[0] = binary.BigEndian.Uint64(data[24:32]) + return b +} + func (b *bitArray) setFelt(f *felt.Felt) { res := f.Bytes() b.words[3] = binary.BigEndian.Uint64(res[0:8]) diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 825277e6e6..649002c1e5 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -503,7 +503,7 @@ func TestTruncate(t *testing.T) { } } -func TestWrite(t *testing.T) { +func TestWriteAndUnmarshalBinary(t *testing.T) { tests := []struct { name string bitArray bitArray @@ -584,9 +584,18 @@ func TestWrite(t *testing.T) { } // Check written bytes - if got := buf.Bytes(); !bytes.Equal(got, tt.want) { + got := buf.Bytes() + if !bytes.Equal(got, tt.want) { t.Errorf("Write() = %v, want %v", got, tt.want) } + + gotBitArray := new(bitArray) + if err := gotBitArray.UnmarshalBinary(got); err != nil { + t.Errorf("UnmarshalBinary() = %v", err) + } + if !gotBitArray.Equal(&tt.bitArray) { + t.Errorf("UnmarshalBinary() = %v, want %v", gotBitArray, tt.bitArray) + } }) } } From 993a9f193d80c724fd960118355561ed4dcd06d6 Mon Sep 17 00:00:00 2001 From: weiihann Date: Sun, 15 Dec 2024 19:13:02 +0800 Subject: [PATCH 11/30] more bitArray public --- core/trie/bitarray.go | 67 +++++---- core/trie/bitarray_test.go | 278 ++++++++++++++++++------------------- 2 files changed, 177 insertions(+), 168 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 0cca0ccf71..57f3d28072 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -16,22 +16,31 @@ const ( var ( maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} - emptyBitArray = &bitArray{len: 0, words: [4]uint64{0, 0, 0, 0}} + emptyBitArray = &BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}} ) -// bitArray is a structure that represents a bit array with a max length of 255 bits. +// BitArray is a structure that represents a bit array with a max length of 255 bits. // It uses a little endian representation to do bitwise operations of the words efficiently. // Unlike normal bit arrays, it has a `len` field that represents the number of used bits. // For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. // The reason why 255 bits is the max length is because we only need up to 251 bits for a given trie key. // Though words can be used to represent 256 bits, we don't want to add an additional byte for the length. -type bitArray struct { +type BitArray struct { len uint8 // number of used bits words [4]uint64 // little endian (i.e. words[0] is the least significant) } +func (b *BitArray) Felt() *felt.Felt { + bs := b.Bytes() + return new(felt.Felt).SetBytes(bs[:]) +} + +func (b *BitArray) Len() uint8 { + return b.len +} + // Bytes returns the bytes representation of the bit array in big endian format -func (b *bitArray) Bytes() [32]byte { +func (b *BitArray) Bytes() [32]byte { var res [32]byte switch { @@ -83,7 +92,7 @@ func (b *bitArray) Bytes() [32]byte { // a = 1100 (len=4) // b = [] (len=0) // a.EqualMSBs(b) = true // Zero length is always a prefix match -func (b *bitArray) EqualMSBs(x *bitArray) bool { +func (b *BitArray) EqualMSBs(x *BitArray) bool { if b.len == x.len { return b.Equal(x) } @@ -92,7 +101,7 @@ func (b *bitArray) EqualMSBs(x *bitArray) bool { return true } - var long, short *bitArray + var long, short *BitArray long, short = b, x if b.len < x.len { @@ -111,7 +120,7 @@ func (b *bitArray) EqualMSBs(x *bitArray) bool { // Truncate(x, 4) = 1011 (len=4) // Truncate(x, 10) = 11001011 (len=8, original x) // Truncate(x, 0) = 0 (len=0) -func (b *bitArray) Truncate(x *bitArray, length uint8) *bitArray { +func (b *BitArray) Truncate(x *BitArray, length uint8) *BitArray { if length >= x.len { return b.Set(x) } @@ -152,7 +161,7 @@ func (b *bitArray) Truncate(x *bitArray, length uint8) *bitArray { // x = 1101 0111 (len=8) // y = 1101 0000 (len=8) // CommonMSBs(x,y) = 1101 (len=4) -func (b *bitArray) CommonMSBs(x, y *bitArray) *bitArray { +func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { if x.len == 0 || y.len == 0 { return emptyBitArray } @@ -192,7 +201,7 @@ func (b *bitArray) CommonMSBs(x, y *bitArray) *bitArray { // // array = 0000 0000 ... 0100 (len=251) // findFirstSetBit() = 2 // third bit from right is set -func findFirstSetBit(b *bitArray) uint8 { +func findFirstSetBit(b *BitArray) uint8 { if b.len == 0 { return 0 } @@ -207,7 +216,7 @@ func findFirstSetBit(b *bitArray) uint8 { } // Rsh sets b = x >> n and returns b. -func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { +func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { if x.len == 0 { return b.Set(x) } @@ -251,7 +260,7 @@ func (b *bitArray) Rsh(x *bitArray, n uint8) *bitArray { } // Xor sets b = x ^ y and returns b. -func (b *bitArray) Xor(x, y *bitArray) *bitArray { +func (b *BitArray) Xor(x, y *BitArray) *BitArray { b.words[0] = x.words[0] ^ y.words[0] b.words[1] = x.words[1] ^ y.words[1] b.words[2] = x.words[2] ^ y.words[2] @@ -260,7 +269,7 @@ func (b *bitArray) Xor(x, y *bitArray) *bitArray { } // Eq checks if two bit arrays are equal -func (b *bitArray) Equal(x *bitArray) bool { +func (b *BitArray) Equal(x *BitArray) bool { return b.len == x.len && b.words[0] == x.words[0] && b.words[1] == x.words[1] && @@ -268,13 +277,13 @@ func (b *bitArray) Equal(x *bitArray) bool { b.words[3] == x.words[3] } -// Write serialises the bitArray into a bytes buffer in the following format: +// Write serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order // Example: // -// bitArray{len: 10, words: [4]uint64{0x03FF}} -> [0x0A, 0x03, 0xFF] -func (b *bitArray) Write(buf *bytes.Buffer) (int, error) { +// BitArray{len: 10, words: [4]uint64{0x03FF}} -> [0x0A, 0x03, 0xFF] +func (b *BitArray) Write(buf *bytes.Buffer) (int, error) { if err := buf.WriteByte(b.len); err != nil { return 0, err } @@ -283,13 +292,13 @@ func (b *bitArray) Write(buf *bytes.Buffer) (int, error) { return n + 1, err } -// UnmarshalBinary deserialises the bitArray from a bytes buffer in the following format: +// UnmarshalBinary deserialises the BitArray from a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order // Example: // -// [0x0A, 0x03, 0xFF] -> bitArray{len: 10, words: [4]uint64{0x03FF}} -func (b *bitArray) UnmarshalBinary(data []byte) error { +// [0x0A, 0x03, 0xFF] -> BitArray{len: 10, words: [4]uint64{0x03FF}} +func (b *BitArray) UnmarshalBinary(data []byte) error { b.len = data[0] var bs [32]byte @@ -298,19 +307,19 @@ func (b *bitArray) UnmarshalBinary(data []byte) error { return nil } -func (b *bitArray) SetFelt(length uint8, f *felt.Felt) *bitArray { +func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { b.setFelt(f) b.len = length return b } -func (b *bitArray) SetFelt251(f *felt.Felt) *bitArray { +func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { b.setFelt(f) b.len = 251 return b } -func (b *bitArray) SetBytes32(data [32]byte) *bitArray { +func (b *BitArray) SetBytes32(data [32]byte) *BitArray { b.words[3] = binary.BigEndian.Uint64(data[0:8]) b.words[2] = binary.BigEndian.Uint64(data[8:16]) b.words[1] = binary.BigEndian.Uint64(data[16:24]) @@ -318,7 +327,7 @@ func (b *bitArray) SetBytes32(data [32]byte) *bitArray { return b } -func (b *bitArray) setFelt(f *felt.Felt) { +func (b *BitArray) setFelt(f *felt.Felt) { res := f.Bytes() b.words[3] = binary.BigEndian.Uint64(res[0:8]) b.words[2] = binary.BigEndian.Uint64(res[8:16]) @@ -326,7 +335,7 @@ func (b *bitArray) setFelt(f *felt.Felt) { b.words[0] = binary.BigEndian.Uint64(res[24:32]) } -func (b *bitArray) Set(x *bitArray) *bitArray { +func (b *BitArray) Set(x *BitArray) *BitArray { b.len = x.len b.words[0] = x.words[0] b.words[1] = x.words[1] @@ -337,7 +346,7 @@ func (b *bitArray) Set(x *bitArray) *bitArray { // byteCount returns the minimum number of bytes needed to represent the bit array. // It rounds up to the nearest byte. -func (b *bitArray) byteCount() uint8 { +func (b *BitArray) byteCount() uint8 { // Cast to uint16 to avoid overflow return uint8((uint16(b.len) + uint16(byteBits-1)) / uint16(byteBits)) } @@ -349,24 +358,24 @@ func (b *bitArray) byteCount() uint8 { // Example: // // len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF] -func (b *bitArray) activeBytes() []byte { +func (b *BitArray) activeBytes() []byte { wordsBytes := b.Bytes() return wordsBytes[32-b.byteCount():] } -func (b *bitArray) rsh64(x *bitArray) { +func (b *BitArray) rsh64(x *BitArray) { b.words[3], b.words[2], b.words[1], b.words[0] = 0, x.words[3], x.words[2], x.words[1] } -func (b *bitArray) rsh128(x *bitArray) { +func (b *BitArray) rsh128(x *BitArray) { b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, x.words[3], x.words[2] } -func (b *bitArray) rsh192(x *bitArray) { +func (b *BitArray) rsh192(x *BitArray) { b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] } -func (b *bitArray) clear() *bitArray { +func (b *BitArray) clear() *BitArray { b.len = 0 b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 return b diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 649002c1e5..2c8265e09a 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -11,18 +11,18 @@ import ( func TestBytes(t *testing.T) { tests := []struct { - name string - bitArray bitArray - want [32]byte + name string + ba BitArray + want [32]byte }{ { - name: "length == 0", - bitArray: bitArray{len: 0, words: maxBitArray}, - want: [32]byte{}, + name: "length == 0", + ba: BitArray{len: 0, words: maxBitArray}, + want: [32]byte{}, }, { - name: "length < 64", - bitArray: bitArray{len: 38, words: maxBitArray}, + name: "length < 64", + ba: BitArray{len: 38, words: maxBitArray}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) @@ -30,8 +30,8 @@ func TestBytes(t *testing.T) { }(), }, { - name: "64 <= length < 128", - bitArray: bitArray{len: 100, words: maxBitArray}, + name: "64 <= length < 128", + ba: BitArray{len: 100, words: maxBitArray}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) @@ -40,8 +40,8 @@ func TestBytes(t *testing.T) { }(), }, { - name: "128 <= length < 192", - bitArray: bitArray{len: 130, words: maxBitArray}, + name: "128 <= length < 192", + ba: BitArray{len: 130, words: maxBitArray}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[8:16], 0x3) @@ -51,8 +51,8 @@ func TestBytes(t *testing.T) { }(), }, { - name: "192 <= length < 255", - bitArray: bitArray{len: 201, words: maxBitArray}, + name: "192 <= length < 255", + ba: BitArray{len: 201, words: maxBitArray}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x1FF) @@ -63,8 +63,8 @@ func TestBytes(t *testing.T) { }(), }, { - name: "length == 254", - bitArray: bitArray{len: 254, words: maxBitArray}, + name: "length == 254", + ba: BitArray{len: 254, words: maxBitArray}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) @@ -75,8 +75,8 @@ func TestBytes(t *testing.T) { }(), }, { - name: "length == 255", - bitArray: bitArray{len: 255, words: maxBitArray}, + name: "length == 255", + ba: BitArray{len: 255, words: maxBitArray}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) @@ -90,18 +90,18 @@ func TestBytes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.bitArray.Bytes() + got := tt.ba.Bytes() if !bytes.Equal(got[:], tt.want[:]) { - t.Errorf("bitArray.Bytes() = %v, want %v", got, tt.want) + t.Errorf("BitArray.Bytes() = %v, want %v", got, tt.want) } - // check if the received bytes has the same bit count as the bitArray.len + // check if the received bytes has the same bit count as the BitArray.len count := 0 for _, b := range got { count += bits.OnesCount8(b) } - if count != int(tt.bitArray.len) { - t.Errorf("bitArray.Bytes() bit count = %v, want %v", count, tt.bitArray.len) + if count != int(tt.ba.len) { + t.Errorf("BitArray.Bytes() bit count = %v, want %v", count, tt.ba.len) } }) } @@ -110,102 +110,102 @@ func TestBytes(t *testing.T) { func TestRsh(t *testing.T) { tests := []struct { name string - initial *bitArray + initial *BitArray shiftBy uint8 - expected *bitArray + expected *BitArray }{ { name: "zero length array", - initial: &bitArray{ + initial: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, shiftBy: 5, - expected: &bitArray{ + expected: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, }, { name: "shift by 0", - initial: &bitArray{ + initial: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, shiftBy: 0, - expected: &bitArray{ + expected: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, }, { name: "shift by more than length", - initial: &bitArray{ + initial: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, shiftBy: 65, - expected: &bitArray{ + expected: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, }, { name: "shift by less than 64", - initial: &bitArray{ + initial: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, shiftBy: 32, - expected: &bitArray{ + expected: &BitArray{ len: 96, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x00000000FFFFFFFF, 0, 0}, }, }, { name: "shift by exactly 64", - initial: &bitArray{ + initial: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, shiftBy: 64, - expected: &bitArray{ + expected: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, }, { name: "shift by 127", - initial: &bitArray{ + initial: &BitArray{ len: 255, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, }, shiftBy: 127, - expected: &bitArray{ + expected: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, }, { name: "shift by 128", - initial: &bitArray{ + initial: &BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, }, shiftBy: 128, - expected: &bitArray{ + expected: &BitArray{ len: 123, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, }, { name: "shift by 192", - initial: &bitArray{ + initial: &BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, }, shiftBy: 192, - expected: &bitArray{ + expected: &BitArray{ len: 59, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, @@ -214,7 +214,7 @@ func TestRsh(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := new(bitArray).Rsh(tt.initial, tt.shiftBy) + result := new(BitArray).Rsh(tt.initial, tt.shiftBy) if !result.Equal(tt.expected) { t.Errorf("Rsh() got = %+v, want %+v", result, tt.expected) } @@ -225,17 +225,17 @@ func TestRsh(t *testing.T) { func TestPrefixEqual(t *testing.T) { tests := []struct { name string - a *bitArray - b *bitArray + a *BitArray + b *BitArray want bool }{ { name: "equal lengths, equal values", - a: &bitArray{ + a: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, @@ -243,11 +243,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "equal lengths, different values", - a: &bitArray{ + a: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, }, @@ -255,11 +255,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "different lengths, a longer but same prefix", - a: &bitArray{ + a: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, @@ -267,11 +267,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "different lengths, b longer but same prefix", - a: &bitArray{ + a: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, @@ -279,11 +279,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "different lengths, different prefix", - a: &bitArray{ + a: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, }, @@ -291,11 +291,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "zero length arrays", - a: &bitArray{ + a: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, @@ -303,11 +303,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "one zero length array", - a: &bitArray{ + a: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - b: &bitArray{ + b: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, @@ -315,11 +315,11 @@ func TestPrefixEqual(t *testing.T) { }, { name: "max length difference", - a: &bitArray{ + a: &BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, }, - b: &bitArray{ + b: &BitArray{ len: 1, words: [4]uint64{0x1, 0, 0, 0}, }, @@ -343,150 +343,150 @@ func TestPrefixEqual(t *testing.T) { func TestTruncate(t *testing.T) { tests := []struct { name string - initial bitArray + initial BitArray length uint8 - expected bitArray + expected BitArray }{ { name: "truncate to zero", - initial: bitArray{ + initial: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, length: 0, - expected: bitArray{ + expected: BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, }, { name: "truncate within first word - 32 bits", - initial: bitArray{ + initial: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, length: 32, - expected: bitArray{ + expected: BitArray{ len: 32, words: [4]uint64{0x00000000FFFFFFFF, 0, 0, 0}, }, }, { name: "truncate to single bit", - initial: bitArray{ + initial: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, length: 1, - expected: bitArray{ + expected: BitArray{ len: 1, words: [4]uint64{0x0000000000000001, 0, 0, 0}, }, }, { name: "truncate across words - 100 bits", - initial: bitArray{ + initial: BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, length: 100, - expected: bitArray{ + expected: BitArray{ len: 100, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x0000000FFFFFFFFF, 0, 0}, }, }, { name: "truncate at word boundary - 64 bits", - initial: bitArray{ + initial: BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, length: 64, - expected: bitArray{ + expected: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, }, { name: "truncate at word boundary - 128 bits", - initial: bitArray{ + initial: BitArray{ len: 192, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, }, length: 128, - expected: bitArray{ + expected: BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, }, { name: "truncate in third word - 150 bits", - initial: bitArray{ + initial: BitArray{ len: 192, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, }, length: 150, - expected: bitArray{ + expected: BitArray{ len: 150, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x3FFFFF, 0}, }, }, { name: "truncate in fourth word - 220 bits", - initial: bitArray{ + initial: BitArray{ len: 255, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, }, length: 220, - expected: bitArray{ + expected: BitArray{ len: 220, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFF}, }, }, { name: "truncate max length - 251 bits", - initial: bitArray{ + initial: BitArray{ len: 255, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, }, length: 251, - expected: bitArray{ + expected: BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, }, }, { name: "truncate sparse bits", - initial: bitArray{ + initial: BitArray{ len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, }, length: 100, - expected: bitArray{ + expected: BitArray{ len: 100, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x0000000555555555, 0, 0}, }, }, { name: "no change when new length equals current length", - initial: bitArray{ + initial: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, length: 64, - expected: bitArray{ + expected: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, }, { name: "no change when new length greater than current length", - initial: bitArray{ + initial: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, length: 128, - expected: bitArray{ + expected: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, @@ -495,7 +495,7 @@ func TestTruncate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := new(bitArray).Truncate(&tt.initial, tt.length) + result := new(BitArray).Truncate(&tt.initial, tt.length) if !result.Equal(&tt.expected) { t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) } @@ -505,13 +505,13 @@ func TestTruncate(t *testing.T) { func TestWriteAndUnmarshalBinary(t *testing.T) { tests := []struct { - name string - bitArray bitArray - want []byte // Expected bytes after writing + name string + ba BitArray + want []byte // Expected bytes after writing }{ { name: "empty bit array", - bitArray: bitArray{ + ba: BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, @@ -519,7 +519,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { }, { name: "8 bits", - bitArray: bitArray{ + ba: BitArray{ len: 8, words: [4]uint64{0xFF, 0, 0, 0}, }, @@ -527,7 +527,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { }, { name: "10 bits requiring 2 bytes", - bitArray: bitArray{ + ba: BitArray{ len: 10, words: [4]uint64{0x3FF, 0, 0, 0}, // 1111111111 in binary }, @@ -535,7 +535,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { }, { name: "64 bits", - bitArray: bitArray{ + ba: BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, @@ -546,7 +546,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { }, { name: "251 bits", - bitArray: bitArray{ + ba: BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, }, @@ -564,7 +564,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { }, { name: "sparse bits", - bitArray: bitArray{ + ba: BitArray{ len: 16, words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 in binary }, @@ -575,7 +575,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { buf := new(bytes.Buffer) - gotN, err := tt.bitArray.Write(buf) + gotN, err := tt.ba.Write(buf) assert.NoError(t, err) // Check number of bytes written @@ -589,12 +589,12 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { t.Errorf("Write() = %v, want %v", got, tt.want) } - gotBitArray := new(bitArray) + gotBitArray := new(BitArray) if err := gotBitArray.UnmarshalBinary(got); err != nil { t.Errorf("UnmarshalBinary() = %v", err) } - if !gotBitArray.Equal(&tt.bitArray) { - t.Errorf("UnmarshalBinary() = %v, want %v", gotBitArray, tt.bitArray) + if !gotBitArray.Equal(&tt.ba) { + t.Errorf("UnmarshalBinary() = %v, want %v", gotBitArray, tt.ba) } }) } @@ -603,9 +603,9 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { func TestCommonPrefix(t *testing.T) { tests := []struct { name string - x *bitArray - y *bitArray - want *bitArray + x *BitArray + y *BitArray + want *BitArray }{ { name: "empty arrays", @@ -615,7 +615,7 @@ func TestCommonPrefix(t *testing.T) { }, { name: "one empty array", - x: &bitArray{ + x: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, @@ -624,165 +624,165 @@ func TestCommonPrefix(t *testing.T) { }, { name: "identical arrays - single word", - x: &bitArray{ + x: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, }, { name: "identical arrays - multiple words", - x: &bitArray{ + x: &BitArray{ len: 192, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 192, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 192, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, }, }, { name: "different lengths with common prefix - first word", - x: &bitArray{ + x: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 32, words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 32, words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, }, }, { name: "different lengths with common prefix - multiple words", - x: &bitArray{ + x: &BitArray{ len: 255, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, }, - y: &bitArray{ + y: &BitArray{ len: 127, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 127, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0, 0}, }, }, { name: "different at first bit", - x: &bitArray{ + x: &BitArray{ len: 64, words: [4]uint64{0x7FFFFFFFFFFFFFFF, 0, 0, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 0, words: [4]uint64{0, 0, 0, 0}, }, }, { name: "different in middle of first word", - x: &bitArray{ + x: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFF0FFFFFFF, 0, 0, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 32, words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, }, }, { name: "different in second word", - x: &bitArray{ + x: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFF0FFFFFFF, 0, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 128, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 32, words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, }, }, { name: "different in third word", - x: &bitArray{ + x: &BitArray{ len: 192, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 192, words: [4]uint64{0, 0, 0xFFFFFFFFFFFFFF0F, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 56, words: [4]uint64{0xFFFFFFFFFFFFFF, 0, 0, 0}, }, }, { name: "different in last word", - x: &bitArray{ + x: &BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, }, - y: &bitArray{ + y: &BitArray{ len: 251, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFF0FFFFFFF}, }, - want: &bitArray{ + want: &BitArray{ len: 27, words: [4]uint64{0x7FFFFFF}, }, }, { name: "sparse bits with common prefix", - x: &bitArray{ + x: &BitArray{ len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA, 0, 0}, }, - y: &bitArray{ + y: &BitArray{ len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAA000, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 52, words: [4]uint64{0xAAAAAAAAAAAAA, 0, 0, 0}, }, }, { name: "max length difference", - x: &bitArray{ + x: &BitArray{ len: 255, words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, }, - y: &bitArray{ + y: &BitArray{ len: 1, words: [4]uint64{0x1, 0, 0, 0}, }, - want: &bitArray{ + want: &BitArray{ len: 1, words: [4]uint64{0x1, 0, 0, 0}, }, @@ -791,8 +791,8 @@ func TestCommonPrefix(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := new(bitArray) - gotSymmetric := new(bitArray) + got := new(BitArray) + gotSymmetric := new(BitArray) got.CommonMSBs(tt.x, tt.y) if !got.Equal(tt.want) { From 82bd95a3eff0fa8eae49df0cfdd5556f0aaf94ff Mon Sep 17 00:00:00 2001 From: weiihann Date: Sun, 15 Dec 2024 22:20:09 +0800 Subject: [PATCH 12/30] add IsBitSet --- core/trie/bitarray.go | 10 ++++ core/trie/bitarray_test.go | 109 +++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 57f3d28072..43e5800d56 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -277,6 +277,16 @@ func (b *BitArray) Equal(x *BitArray) bool { b.words[3] == x.words[3] } +// IsBitSit returns true if bit n-th is set, where n = 0 is LSB. +// The n must be <= 255. +func (b *BitArray) IsBitSet(n uint8) bool { + if n >= b.len { + return false + } + + return (b.words[n/64] & (1 << (n % 64))) != 0 +} + // Write serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 2c8265e09a..229479bdb2 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -807,3 +807,112 @@ func TestCommonPrefix(t *testing.T) { }) } } + +func TestIsBitSet(t *testing.T) { + tests := []struct { + name string + ba BitArray + pos uint8 + want bool + }{ + { + name: "empty array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + pos: 0, + want: false, + }, + { + name: "first bit set", + ba: BitArray{ + len: 64, + words: [4]uint64{1, 0, 0, 0}, + }, + pos: 0, + want: true, + }, + { + name: "last bit in first word", + ba: BitArray{ + len: 64, + words: [4]uint64{1 << 63, 0, 0, 0}, + }, + pos: 63, + want: true, + }, + { + name: "first bit in second word", + ba: BitArray{ + len: 128, + words: [4]uint64{0, 1, 0, 0}, + }, + pos: 64, + want: true, + }, + { + name: "bit beyond length", + ba: BitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + pos: 65, + want: false, + }, + { + name: "alternating bits", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 in binary + }, + pos: 1, + want: true, + }, + { + name: "alternating bits - unset position", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 in binary + }, + pos: 0, + want: false, + }, + { + name: "bit in last word", + ba: BitArray{ + len: 251, + words: [4]uint64{0, 0, 0, 1 << 59}, + }, + pos: 251, + want: false, // position 251 is beyond the highest valid bit (250) + }, + { + name: "highest valid bit (255)", + ba: BitArray{ + len: 255, + words: [4]uint64{0, 0, 0, 1 << 62}, // bit 255 set + }, + pos: 254, + want: true, + }, + { + name: "position at length boundary", + ba: BitArray{ + len: 100, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + pos: 100, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.IsBitSet(tt.pos) + if got != tt.want { + t.Errorf("IsBitSet(%d) = %v, want %v", tt.pos, got, tt.want) + } + }) + } +} From 6da60b9c3199242c5b8d935d05625f57f38ee561 Mon Sep 17 00:00:00 2001 From: weiihann Date: Sun, 15 Dec 2024 23:13:54 +0800 Subject: [PATCH 13/30] fix lint and comments --- core/trie/bitarray.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 43e5800d56..db6c702ac5 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -11,20 +11,19 @@ import ( const ( maxUint64 = uint64(math.MaxUint64) - byteBits = 8 + bits8 = 8 ) var ( maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} - emptyBitArray = &BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}} + emptyBitArray = new(BitArray) ) -// BitArray is a structure that represents a bit array with a max length of 255 bits. +// BitArray is a structure that represents a bit array with length representing the number of used bits. // It uses a little endian representation to do bitwise operations of the words efficiently. -// Unlike normal bit arrays, it has a `len` field that represents the number of used bits. // For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. -// The reason why 255 bits is the max length is because we only need up to 251 bits for a given trie key. -// Though words can be used to represent 256 bits, we don't want to add an additional byte for the length. +// The max length is 255 bits (uint8), because our use case only need up to 251 bits for a given trie key. +// Although words can be used to represent 256 bits, we don't want to add an additional byte for the length. type BitArray struct { len uint8 // number of used bits words [4]uint64 // little endian (i.e. words[0] is the least significant) @@ -40,6 +39,8 @@ func (b *BitArray) Len() uint8 { } // Bytes returns the bytes representation of the bit array in big endian format +// +//nolint:mnd func (b *BitArray) Bytes() [32]byte { var res [32]byte @@ -120,6 +121,8 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { // Truncate(x, 4) = 1011 (len=4) // Truncate(x, 10) = 11001011 (len=8, original x) // Truncate(x, 0) = 0 (len=0) +// +//nolint:mnd func (b *BitArray) Truncate(x *BitArray, length uint8) *BitArray { if length >= x.len { return b.Set(x) @@ -173,7 +176,7 @@ func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { // Align arrays by right-shifting longer array and then XOR to find differences // Example: - // short = 1101 (len=4) + // short = 1100 (len=4) // long = 1101 0111 (len=8) // // Step 1: Right shift longer array by 4 @@ -216,6 +219,8 @@ func findFirstSetBit(b *BitArray) uint8 { } // Rsh sets b = x >> n and returns b. +// +//nolint:mnd func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { if x.len == 0 { return b.Set(x) @@ -358,7 +363,7 @@ func (b *BitArray) Set(x *BitArray) *BitArray { // It rounds up to the nearest byte. func (b *BitArray) byteCount() uint8 { // Cast to uint16 to avoid overflow - return uint8((uint16(b.len) + uint16(byteBits-1)) / uint16(byteBits)) + return uint8((uint16(b.len) + (bits8 - 1)) / uint16(bits8)) } // activeBytes returns a slice containing only the bytes that are actually used From 5a6094acb1a664a52e68646daf2580afcfed5d22 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 16 Dec 2024 22:09:50 +0800 Subject: [PATCH 14/30] Felt return value --- core/trie/bitarray.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index db6c702ac5..c371bee996 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -29,9 +29,12 @@ type BitArray struct { words [4]uint64 // little endian (i.e. words[0] is the least significant) } -func (b *BitArray) Felt() *felt.Felt { +func (b *BitArray) Felt() felt.Felt { bs := b.Bytes() - return new(felt.Felt).SetBytes(bs[:]) + + var f felt.Felt + f.SetBytes(bs[:]) + return f } func (b *BitArray) Len() uint8 { From bfb2bfa0a3a7f822f59a2eefab34d8a072dc48da Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 17 Dec 2024 14:53:42 +0800 Subject: [PATCH 15/30] add felt tests --- core/trie/bitarray.go | 99 ++++++++++++++----------- core/trie/bitarray_test.go | 144 +++++++++++++++++++++++++++++++++++-- 2 files changed, 195 insertions(+), 48 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index c371bee996..03aa84f700 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -29,11 +29,13 @@ type BitArray struct { words [4]uint64 // little endian (i.e. words[0] is the least significant) } -func (b *BitArray) Felt() felt.Felt { - bs := b.Bytes() +func NewBitArray(val uint64) *BitArray { + return new(BitArray).SetUint64(val) +} +func (b *BitArray) Felt() felt.Felt { var f felt.Felt - f.SetBytes(bs[:]) + f.SetBytes(b.Bytes()) return f } @@ -44,13 +46,13 @@ func (b *BitArray) Len() uint8 { // Bytes returns the bytes representation of the bit array in big endian format // //nolint:mnd -func (b *BitArray) Bytes() [32]byte { +func (b *BitArray) Bytes() []byte { var res [32]byte switch { case b.len == 0: // all zeros - return res + return res[:] case b.len >= 192: // Create mask for top word: keeps only valid bits above 192 // e.g., if len=200, keeps lowest 8 bits (200-192) @@ -76,7 +78,7 @@ func (b *BitArray) Bytes() [32]byte { binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) } - return res + return res[:] } // EqualMSBs checks if two bit arrays share the same most significant bits, where the length of @@ -199,28 +201,6 @@ func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { return b.Rsh(short, divergentBit) } -// findFirstSetBit returns the position of the first '1' bit in the array, -// scanning from most significant to least significant bit. -// -// The bit position is counted from the least significant bit, starting at 0. -// For example: -// -// array = 0000 0000 ... 0100 (len=251) -// findFirstSetBit() = 2 // third bit from right is set -func findFirstSetBit(b *BitArray) uint8 { - if b.len == 0 { - return 0 - } - - for i := 3; i >= 0; i-- { - if word := b.words[i]; word != 0 { - return uint8((i+1)*64 - bits.LeadingZeros64(word)) - } - } - - return 0 -} - // Rsh sets b = x >> n and returns b. // //nolint:mnd @@ -316,13 +296,21 @@ func (b *BitArray) Write(buf *bytes.Buffer) (int, error) { // Example: // // [0x0A, 0x03, 0xFF] -> BitArray{len: 10, words: [4]uint64{0x03FF}} -func (b *BitArray) UnmarshalBinary(data []byte) error { +func (b *BitArray) UnmarshalBinary(data []byte) { b.len = data[0] var bs [32]byte copy(bs[32-b.byteCount():], data[1:]) - b.SetBytes32(bs) - return nil + b.setBytes32(bs[:]) +} + +func (b *BitArray) Set(x *BitArray) *BitArray { + b.len = x.len + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] + return b } func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { @@ -337,11 +325,15 @@ func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { return b } -func (b *BitArray) SetBytes32(data [32]byte) *BitArray { - b.words[3] = binary.BigEndian.Uint64(data[0:8]) - b.words[2] = binary.BigEndian.Uint64(data[8:16]) - b.words[1] = binary.BigEndian.Uint64(data[16:24]) - b.words[0] = binary.BigEndian.Uint64(data[24:32]) +func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { + b.setBytes32(data) + b.len = length + return b +} + +func (b *BitArray) SetUint64(data uint64) *BitArray { + b.words[0] = data + b.len = 64 return b } @@ -353,13 +345,12 @@ func (b *BitArray) setFelt(f *felt.Felt) { b.words[0] = binary.BigEndian.Uint64(res[24:32]) } -func (b *BitArray) Set(x *BitArray) *BitArray { - b.len = x.len - b.words[0] = x.words[0] - b.words[1] = x.words[1] - b.words[2] = x.words[2] - b.words[3] = x.words[3] - return b +func (b *BitArray) setBytes32(data []byte) { + _ = data[31] + b.words[3] = binary.BigEndian.Uint64(data[0:8]) + b.words[2] = binary.BigEndian.Uint64(data[8:16]) + b.words[1] = binary.BigEndian.Uint64(data[16:24]) + b.words[0] = binary.BigEndian.Uint64(data[24:32]) } // byteCount returns the minimum number of bytes needed to represent the bit array. @@ -398,3 +389,25 @@ func (b *BitArray) clear() *BitArray { b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 return b } + +// findFirstSetBit returns the position of the first '1' bit in the array, +// scanning from most significant to least significant bit. +// +// The bit position is counted from the least significant bit, starting at 0. +// For example: +// +// array = 0000 0000 ... 0100 (len=251) +// findFirstSetBit() = 2 // third bit from right is set +func findFirstSetBit(b *BitArray) uint8 { + if b.len == 0 { + return 0 + } + + for i := 3; i >= 0; i-- { + if word := b.words[i]; word != 0 { + return uint8((i+1)*64 - bits.LeadingZeros64(word)) + } + } + + return 0 +} diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 229479bdb2..96af163f34 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -6,7 +6,9 @@ import ( "math/bits" "testing" + "github.com/NethermindEth/juno/core/felt" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBytes(t *testing.T) { @@ -91,7 +93,7 @@ func TestBytes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := tt.ba.Bytes() - if !bytes.Equal(got[:], tt.want[:]) { + if !bytes.Equal(got, tt.want[:]) { t.Errorf("BitArray.Bytes() = %v, want %v", got, tt.want) } @@ -589,10 +591,8 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { t.Errorf("Write() = %v, want %v", got, tt.want) } - gotBitArray := new(BitArray) - if err := gotBitArray.UnmarshalBinary(got); err != nil { - t.Errorf("UnmarshalBinary() = %v", err) - } + var gotBitArray BitArray + gotBitArray.UnmarshalBinary(got) if !gotBitArray.Equal(&tt.ba) { t.Errorf("UnmarshalBinary() = %v, want %v", gotBitArray, tt.ba) } @@ -916,3 +916,137 @@ func TestIsBitSet(t *testing.T) { }) } } + +func TestFeltConversion(t *testing.T) { + tests := []struct { + name string + ba BitArray + length uint8 + want string // hex representation of felt + }{ + { + name: "empty bit array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + length: 0, + want: "0x0", + }, + { + name: "single word", + ba: BitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 64, + want: "0xffffffffffffffff", + }, + { + name: "two words", + ba: BitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + length: 128, + want: "0xffffffffffffffffffffffffffffffff", + }, + { + name: "three words", + ba: BitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + length: 192, + want: "0xffffffffffffffffffffffffffffffffffffffffffffffff", + }, + { + name: "max length (251 bits)", + ba: BitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + }, + length: 255, + want: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + }, + { + name: "sparse bits", + ba: BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 128, + want: "0x5555555555555555aaaaaaaaaaaaaaaa", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test Felt() conversion + gotFelt := tt.ba.Felt() + assert.Equal(t, tt.want, gotFelt.String()) + + // Test SetFelt() conversion (round trip) + var newBA BitArray + newBA.SetFelt(tt.length, &gotFelt) + assert.Equal(t, tt.ba.len, newBA.len) + assert.Equal(t, tt.ba.words, newBA.words) + }) + } +} + +func TestSetFeltValidation(t *testing.T) { + tests := []struct { + name string + feltStr string + length uint8 + shouldMatch bool + }{ + { + name: "valid felt with matching length", + feltStr: "0xf", + length: 4, + shouldMatch: true, + }, + { + name: "felt larger than specified length", + feltStr: "0xff", + length: 4, + shouldMatch: false, + }, + { + name: "zero felt with non-zero length", + feltStr: "0x0", + length: 8, + shouldMatch: true, + }, + { + name: "max felt with max length", + feltStr: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + length: 251, + shouldMatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var f felt.Felt + _, err := f.SetString(tt.feltStr) + require.NoError(t, err) + + var ba BitArray + ba.SetFelt(tt.length, &f) + + // Convert back to felt and compare + roundTrip := ba.Felt() + if tt.shouldMatch { + assert.True(t, roundTrip.Equal(&f), + "expected %s, got %s", f.String(), roundTrip.String()) + } else { + assert.False(t, roundTrip.Equal(&f), + "values should not match: original %s, roundtrip %s", + f.String(), roundTrip.String()) + } + }) + } +} From 812e0ad32e34a4d012c15418cad2b8b6bdf1ab0f Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 17 Dec 2024 15:11:01 +0800 Subject: [PATCH 16/30] use const for 0xFF...FF --- core/trie/bitarray.go | 2 +- core/trie/bitarray_test.go | 162 ++++++++++++++++++------------------- 2 files changed, 82 insertions(+), 82 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 03aa84f700..7b377ddf27 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -10,7 +10,7 @@ import ( ) const ( - maxUint64 = uint64(math.MaxUint64) + maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF bits8 = 8 ) diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 96af163f34..abfccc7a1c 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -37,7 +37,7 @@ func TestBytes(t *testing.T) { want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) - binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], maxUint64) return b }(), }, @@ -47,8 +47,8 @@ func TestBytes(t *testing.T) { want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[8:16], 0x3) - binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) return b }(), }, @@ -58,9 +58,9 @@ func TestBytes(t *testing.T) { want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x1FF) - binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) return b }(), }, @@ -70,9 +70,9 @@ func TestBytes(t *testing.T) { want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) return b }(), }, @@ -82,9 +82,9 @@ func TestBytes(t *testing.T) { want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[8:16], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFFFFFFFFF) - binary.BigEndian.PutUint64(b[24:32], 0xFFFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) return b }(), }, @@ -132,19 +132,19 @@ func TestRsh(t *testing.T) { name: "shift by 0", initial: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, shiftBy: 0, expected: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, { name: "shift by more than length", initial: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, shiftBy: 65, expected: &BitArray{ @@ -156,60 +156,60 @@ func TestRsh(t *testing.T) { name: "shift by less than 64", initial: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, shiftBy: 32, expected: &BitArray{ len: 96, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x00000000FFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, 0x00000000FFFFFFFF, 0, 0}, }, }, { name: "shift by exactly 64", initial: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, shiftBy: 64, expected: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, { name: "shift by 127", initial: &BitArray{ len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, }, shiftBy: 127, expected: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, }, { name: "shift by 128", initial: &BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, }, shiftBy: 128, expected: &BitArray{ len: 123, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, }, { name: "shift by 192", initial: &BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, }, shiftBy: 192, expected: &BitArray{ len: 59, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, } @@ -235,11 +235,11 @@ func TestPrefixEqual(t *testing.T) { name: "equal lengths, equal values", a: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, b: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, want: true, }, @@ -247,7 +247,7 @@ func TestPrefixEqual(t *testing.T) { name: "equal lengths, different values", a: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, b: &BitArray{ len: 64, @@ -259,11 +259,11 @@ func TestPrefixEqual(t *testing.T) { name: "different lengths, a longer but same prefix", a: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, b: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, want: true, }, @@ -271,11 +271,11 @@ func TestPrefixEqual(t *testing.T) { name: "different lengths, b longer but same prefix", a: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, b: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, want: true, }, @@ -283,7 +283,7 @@ func TestPrefixEqual(t *testing.T) { name: "different lengths, different prefix", a: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, b: &BitArray{ len: 64, @@ -307,7 +307,7 @@ func TestPrefixEqual(t *testing.T) { name: "one zero length array", a: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, b: &BitArray{ len: 0, @@ -319,7 +319,7 @@ func TestPrefixEqual(t *testing.T) { name: "max length difference", a: &BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, }, b: &BitArray{ len: 1, @@ -353,7 +353,7 @@ func TestTruncate(t *testing.T) { name: "truncate to zero", initial: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, length: 0, expected: BitArray{ @@ -365,7 +365,7 @@ func TestTruncate(t *testing.T) { name: "truncate within first word - 32 bits", initial: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, length: 32, expected: BitArray{ @@ -377,7 +377,7 @@ func TestTruncate(t *testing.T) { name: "truncate to single bit", initial: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, length: 1, expected: BitArray{ @@ -389,72 +389,72 @@ func TestTruncate(t *testing.T) { name: "truncate across words - 100 bits", initial: BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, length: 100, expected: BitArray{ len: 100, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x0000000FFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, 0x0000000FFFFFFFFF, 0, 0}, }, }, { name: "truncate at word boundary - 64 bits", initial: BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, length: 64, expected: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, { name: "truncate at word boundary - 128 bits", initial: BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, length: 128, expected: BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, }, { name: "truncate in third word - 150 bits", initial: BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, length: 150, expected: BitArray{ len: 150, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x3FFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, 0x3FFFFF, 0}, }, }, { name: "truncate in fourth word - 220 bits", initial: BitArray{ len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, }, length: 220, expected: BitArray{ len: 220, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0xFFFFFFF}, }, }, { name: "truncate max length - 251 bits", initial: BitArray{ len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, }, length: 251, expected: BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, }, }, { @@ -473,24 +473,24 @@ func TestTruncate(t *testing.T) { name: "no change when new length equals current length", initial: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, length: 64, expected: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, { name: "no change when new length greater than current length", initial: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, length: 128, expected: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, } @@ -539,7 +539,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { name: "64 bits", ba: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, want: append( []byte{64}, // length byte @@ -550,7 +550,7 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { name: "251 bits", ba: BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, }, want: func() []byte { b := make([]byte, 33) // 1 length byte + 32 data bytes @@ -617,7 +617,7 @@ func TestCommonPrefix(t *testing.T) { name: "one empty array", x: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, y: emptyBitArray, want: emptyBitArray, @@ -626,37 +626,37 @@ func TestCommonPrefix(t *testing.T) { name: "identical arrays - single word", x: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, y: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, want: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, }, { name: "identical arrays - multiple words", x: &BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, y: &BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, want: &BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, }, { name: "different lengths with common prefix - first word", x: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, y: &BitArray{ len: 32, @@ -671,15 +671,15 @@ func TestCommonPrefix(t *testing.T) { name: "different lengths with common prefix - multiple words", x: &BitArray{ len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, }, y: &BitArray{ len: 127, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFFF, 0, 0}, }, want: &BitArray{ len: 127, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFFF, 0, 0}, }, }, { @@ -690,7 +690,7 @@ func TestCommonPrefix(t *testing.T) { }, y: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, want: &BitArray{ len: 0, @@ -705,7 +705,7 @@ func TestCommonPrefix(t *testing.T) { }, y: &BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, want: &BitArray{ len: 32, @@ -716,11 +716,11 @@ func TestCommonPrefix(t *testing.T) { name: "different in second word", x: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFF0FFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, 0xFFFFFFFF0FFFFFFF, 0, 0}, }, y: &BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, want: &BitArray{ len: 32, @@ -731,7 +731,7 @@ func TestCommonPrefix(t *testing.T) { name: "different in third word", x: &BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, y: &BitArray{ len: 192, @@ -746,11 +746,11 @@ func TestCommonPrefix(t *testing.T) { name: "different in last word", x: &BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, }, y: &BitArray{ len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFF0FFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFF0FFFFFFF}, }, want: &BitArray{ len: 27, @@ -776,7 +776,7 @@ func TestCommonPrefix(t *testing.T) { name: "max length difference", x: &BitArray{ len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, }, y: &BitArray{ len: 1, @@ -855,7 +855,7 @@ func TestIsBitSet(t *testing.T) { name: "bit beyond length", ba: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, pos: 65, want: false, @@ -900,7 +900,7 @@ func TestIsBitSet(t *testing.T) { name: "position at length boundary", ba: BitArray{ len: 100, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, pos: 100, want: false, @@ -937,7 +937,7 @@ func TestFeltConversion(t *testing.T) { name: "single word", ba: BitArray{ len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{maxUint64, 0, 0, 0}, }, length: 64, want: "0xffffffffffffffff", @@ -946,7 +946,7 @@ func TestFeltConversion(t *testing.T) { name: "two words", ba: BitArray{ len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, }, length: 128, want: "0xffffffffffffffffffffffffffffffff", @@ -955,7 +955,7 @@ func TestFeltConversion(t *testing.T) { name: "three words", ba: BitArray{ len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, }, length: 192, want: "0xffffffffffffffffffffffffffffffffffffffffffffffff", @@ -964,7 +964,7 @@ func TestFeltConversion(t *testing.T) { name: "max length (251 bits)", ba: BitArray{ len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, }, length: 255, want: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", From 383da715866da00e2ca35122eff53e4fbc5ff08a Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 17 Dec 2024 16:19:27 +0800 Subject: [PATCH 17/30] add MSBs --- core/trie/bitarray.go | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 7b377ddf27..bbcf15f414 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -3,6 +3,8 @@ package trie import ( "bytes" "encoding/binary" + "encoding/hex" + "fmt" "math" "math/bits" @@ -29,8 +31,8 @@ type BitArray struct { words [4]uint64 // little endian (i.e. words[0] is the least significant) } -func NewBitArray(val uint64) *BitArray { - return new(BitArray).SetUint64(val) +func NewBitArray(length uint8, val uint64) *BitArray { + return new(BitArray).SetUint64(length, val) } func (b *BitArray) Felt() felt.Felt { @@ -275,6 +277,14 @@ func (b *BitArray) IsBitSet(n uint8) bool { return (b.words[n/64] & (1 << (n % 64))) != 0 } +func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + return b.Rsh(x, x.len-n) +} + // Write serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order @@ -331,12 +341,26 @@ func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { return b } -func (b *BitArray) SetUint64(data uint64) *BitArray { +func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { b.words[0] = data - b.len = 64 + b.len = length return b } +func (b *BitArray) EncodedLen() uint { + return b.byteCount() + 1 +} + +func (b *BitArray) Copy() BitArray { + var res BitArray + res.Set(b) + return res +} + +func (b *BitArray) String() string { + return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes())) +} + func (b *BitArray) setFelt(f *felt.Felt) { res := f.Bytes() b.words[3] = binary.BigEndian.Uint64(res[0:8]) @@ -355,9 +379,9 @@ func (b *BitArray) setBytes32(data []byte) { // byteCount returns the minimum number of bytes needed to represent the bit array. // It rounds up to the nearest byte. -func (b *BitArray) byteCount() uint8 { +func (b *BitArray) byteCount() uint { // Cast to uint16 to avoid overflow - return uint8((uint16(b.len) + (bits8 - 1)) / uint16(bits8)) + return (uint(b.len) + (bits8 - 1)) / uint(bits8) } // activeBytes returns a slice containing only the bytes that are actually used From ff7a5cdd642cc73ab5703b128268c3dcc481c4ab Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 17 Dec 2024 17:06:58 +0800 Subject: [PATCH 18/30] add MSBs() and rename Truncate to LSBs --- core/trie/bitarray.go | 39 ++++++---- core/trie/bitarray_test.go | 155 ++++++++++++++++++++++++++++++++----- 2 files changed, 160 insertions(+), 34 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index bbcf15f414..7f5aa11f2c 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -17,7 +17,7 @@ const ( ) var ( - maxBitArray = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} + maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} emptyBitArray = new(BitArray) ) @@ -119,18 +119,18 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { return long.Rsh(long, long.len-short.len).Equal(short) } -// Truncate sets b to the first 'length' bits of x (starting from the least significant bit). -// If length >= x.len, b is an exact copy of x. +// LSBs sets b to the least significant 'n' bits of x. +// If n >= x.len, b is an exact copy of x. // Any bits beyond the specified length are cleared to zero. // For example: // // x = 11001011 (len=8) -// Truncate(x, 4) = 1011 (len=4) -// Truncate(x, 10) = 11001011 (len=8, original x) -// Truncate(x, 0) = 0 (len=0) +// LSBs(x, 4) = 1011 (len=4) +// LSBs(x, 10) = 11001011 (len=8, original x) +// LSBs(x, 0) = 0 (len=0) // //nolint:mnd -func (b *BitArray) Truncate(x *BitArray, length uint8) *BitArray { +func (b *BitArray) LSBs(x *BitArray, length uint8) *BitArray { if length >= x.len { return b.Set(x) } @@ -165,6 +165,23 @@ func (b *BitArray) Truncate(x *BitArray, length uint8) *BitArray { return b } +// MSBs sets b to the most significant 'n' bits of x. +// If n >= x.len, b is an exact copy of x. +// Any bits beyond the specified length are cleared to zero. +// For example: +// +// x = 11001011 (len=8) +// MSBs(x, 4) = 1100 (len=4) +// MSBs(x, 10) = 11001011 (len=8, original x) +// MSBs(x, 0) = 0 (len=0) +func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + return b.Rsh(x, x.len-n) +} + // CommonMSBs sets b to the longest sequence of matching most significant bits between two bit arrays. // For example: // @@ -277,14 +294,6 @@ func (b *BitArray) IsBitSet(n uint8) bool { return (b.words[n/64] & (1 << (n % 64))) != 0 } -func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { - if n >= x.len { - return b.Set(x) - } - - return b.Rsh(x, x.len-n) -} - // Write serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index abfccc7a1c..c90223ab6a 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -11,6 +11,10 @@ import ( "github.com/stretchr/testify/require" ) +const ( + ones63 = 0x7FFFFFFFFFFFFFFF +) + func TestBytes(t *testing.T) { tests := []struct { name string @@ -19,12 +23,12 @@ func TestBytes(t *testing.T) { }{ { name: "length == 0", - ba: BitArray{len: 0, words: maxBitArray}, + ba: BitArray{len: 0, words: maxBits}, want: [32]byte{}, }, { name: "length < 64", - ba: BitArray{len: 38, words: maxBitArray}, + ba: BitArray{len: 38, words: maxBits}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) @@ -33,7 +37,7 @@ func TestBytes(t *testing.T) { }, { name: "64 <= length < 128", - ba: BitArray{len: 100, words: maxBitArray}, + ba: BitArray{len: 100, words: maxBits}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) @@ -43,7 +47,7 @@ func TestBytes(t *testing.T) { }, { name: "128 <= length < 192", - ba: BitArray{len: 130, words: maxBitArray}, + ba: BitArray{len: 130, words: maxBits}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[8:16], 0x3) @@ -54,7 +58,7 @@ func TestBytes(t *testing.T) { }, { name: "192 <= length < 255", - ba: BitArray{len: 201, words: maxBitArray}, + ba: BitArray{len: 201, words: maxBits}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x1FF) @@ -66,7 +70,7 @@ func TestBytes(t *testing.T) { }, { name: "length == 254", - ba: BitArray{len: 254, words: maxBitArray}, + ba: BitArray{len: 254, words: maxBits}, want: func() [32]byte { var b [32]byte binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) @@ -78,10 +82,10 @@ func TestBytes(t *testing.T) { }, { name: "length == 255", - ba: BitArray{len: 255, words: maxBitArray}, + ba: BitArray{len: 255, words: maxBits}, want: func() [32]byte { var b [32]byte - binary.BigEndian.PutUint64(b[0:8], 0x7FFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[0:8], ones63) binary.BigEndian.PutUint64(b[8:16], maxUint64) binary.BigEndian.PutUint64(b[16:24], maxUint64) binary.BigEndian.PutUint64(b[24:32], maxUint64) @@ -180,7 +184,7 @@ func TestRsh(t *testing.T) { name: "shift by 127", initial: &BitArray{ len: 255, - words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, }, shiftBy: 127, expected: &BitArray{ @@ -342,7 +346,7 @@ func TestPrefixEqual(t *testing.T) { } } -func TestTruncate(t *testing.T) { +func TestLSBs(t *testing.T) { tests := []struct { name string initial BitArray @@ -497,7 +501,7 @@ func TestTruncate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := new(BitArray).Truncate(&tt.initial, tt.length) + result := new(BitArray).LSBs(&tt.initial, tt.length) if !result.Equal(&tt.expected) { t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) } @@ -505,6 +509,119 @@ func TestTruncate(t *testing.T) { } } +func TestMSBs(t *testing.T) { + tests := []struct { + name string + x *BitArray + n uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + n: 0, + want: emptyBitArray, + }, + { + name: "get all bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get more bits than available", + x: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "get half of available bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 32, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF00000000 >> 32, 0, 0, 0}, + }, + }, + { + name: "get MSBs across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + n: 100, + want: &BitArray{ + len: 100, + words: [4]uint64{maxUint64, maxUint64 >> 28, 0, 0}, + }, + }, + { + name: "get MSBs from max length array", + x: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get zero bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 0, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{0x5555555555555555, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).MSBs(tt.x, tt.n) + if !got.Equal(tt.want) { + t.Errorf("MSBs() = %v, want %v", got, tt.want) + } + + if got.len != tt.want.len { + t.Errorf("MSBs() = %v, want %v", got, tt.want) + } + }) + } +} + func TestWriteAndUnmarshalBinary(t *testing.T) { tests := []struct { name string @@ -671,22 +788,22 @@ func TestCommonPrefix(t *testing.T) { name: "different lengths with common prefix - multiple words", x: &BitArray{ len: 255, - words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, }, y: &BitArray{ len: 127, - words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, ones63, 0, 0}, }, want: &BitArray{ len: 127, - words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFFF, 0, 0}, + words: [4]uint64{maxUint64, ones63, 0, 0}, }, }, { name: "different at first bit", x: &BitArray{ len: 64, - words: [4]uint64{0x7FFFFFFFFFFFFFFF, 0, 0, 0}, + words: [4]uint64{ones63, 0, 0, 0}, }, y: &BitArray{ len: 64, @@ -776,7 +893,7 @@ func TestCommonPrefix(t *testing.T) { name: "max length difference", x: &BitArray{ len: 255, - words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, }, y: &BitArray{ len: 1, @@ -961,12 +1078,12 @@ func TestFeltConversion(t *testing.T) { want: "0xffffffffffffffffffffffffffffffffffffffffffffffff", }, { - name: "max length (251 bits)", + name: "251 bits", ba: BitArray{ - len: 255, + len: 251, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, }, - length: 255, + length: 251, want: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", }, { From 012f092c4b8b7ecc5b6c941bff5647acecc609b3 Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 17 Dec 2024 23:48:36 +0800 Subject: [PATCH 19/30] all tests passed --- core/state.go | 9 +- core/trie/bitarray.go | 21 +-- core/trie/key.go | 187 -------------------------- core/trie/key_test.go | 229 -------------------------------- core/trie/node.go | 42 +++--- core/trie/node_test.go | 4 +- core/trie/proof.go | 109 ++++++--------- core/trie/proof_test.go | 26 +++- core/trie/storage.go | 25 ++-- core/trie/storage_test.go | 20 +-- core/trie/trie.go | 91 +++++++------ core/trie/trie_pkg_test.go | 41 +++--- core/trie/trie_test.go | 76 ++++++++++- migration/migration.go | 16 +-- migration/migration_pkg_test.go | 15 ++- 15 files changed, 290 insertions(+), 621 deletions(-) delete mode 100644 core/trie/key.go delete mode 100644 core/trie/key_test.go diff --git a/core/state.go b/core/state.go index 378ba65bec..c17ff13f3e 100644 --- a/core/state.go +++ b/core/state.go @@ -139,10 +139,11 @@ func (s *State) globalTrie(bucket db.Bucket, newTrie trie.NewTrieFunc) (*trie.Tr // fetch root key rootKeyDBKey := dbPrefix - var rootKey *trie.Key + var rootKey *trie.BitArray err := s.txn.Get(rootKeyDBKey, func(val []byte) error { - rootKey = new(trie.Key) - return rootKey.UnmarshalBinary(val) + rootKey = new(trie.BitArray) + rootKey.UnmarshalBinary(val) + return nil }) // if some error other than "not found" @@ -169,7 +170,7 @@ func (s *State) globalTrie(bucket db.Bucket, newTrie trie.NewTrieFunc) (*trie.Tr if resultingRootKey != nil { var rootKeyBytes bytes.Buffer - _, marshalErr := resultingRootKey.WriteTo(&rootKeyBytes) + _, marshalErr := resultingRootKey.Write(&rootKeyBytes) if marshalErr != nil { return marshalErr } diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 7f5aa11f2c..11a55edae5 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -109,14 +109,13 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { return true } - var long, short *BitArray - - long, short = b, x - if b.len < x.len { - long, short = x, b + // Compare only the first min(b.len, x.len) bits + minLen := b.len + if x.len < minLen { + minLen = x.len } - return long.Rsh(long, long.len-short.len).Equal(short) + return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) } // LSBs sets b to the least significant 'n' bits of x. @@ -229,8 +228,7 @@ func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { } if n >= x.len { - x.clear() - return b.Set(x) + return b.clear() } switch { @@ -277,6 +275,13 @@ func (b *BitArray) Xor(x, y *BitArray) *BitArray { // Eq checks if two bit arrays are equal func (b *BitArray) Equal(x *BitArray) bool { + // TODO(weiihann): this is really not a good thing to do... + if b == nil && x == nil { + return true + } else if b == nil || x == nil { + return false + } + return b.len == x.len && b.words[0] == x.words[0] && b.words[1] == x.words[1] && diff --git a/core/trie/key.go b/core/trie/key.go deleted file mode 100644 index 0d0ca7aa88..0000000000 --- a/core/trie/key.go +++ /dev/null @@ -1,187 +0,0 @@ -package trie - -import ( - "bytes" - "encoding/hex" - "errors" - "fmt" - "math/big" - - "github.com/NethermindEth/juno/core/felt" -) - -var NilKey = &Key{len: 0, bitset: [32]byte{}} - -type Key struct { - len uint8 - bitset [32]byte -} - -func NewKey(length uint8, keyBytes []byte) Key { - k := Key{len: length} - if len(keyBytes) > len(k.bitset) { - panic("bytes does not fit in bitset") - } - copy(k.bitset[len(k.bitset)-len(keyBytes):], keyBytes) - return k -} - -func (k *Key) bytesNeeded() uint { - const byteBits = 8 - return (uint(k.len) + (byteBits - 1)) / byteBits -} - -func (k *Key) inUseBytes() []byte { - return k.bitset[len(k.bitset)-int(k.bytesNeeded()):] -} - -func (k *Key) unusedBytes() []byte { - return k.bitset[:len(k.bitset)-int(k.bytesNeeded())] -} - -func (k *Key) WriteTo(buf *bytes.Buffer) (int64, error) { - if err := buf.WriteByte(k.len); err != nil { - return 0, err - } - - n, err := buf.Write(k.inUseBytes()) - return int64(1 + n), err -} - -func (k *Key) UnmarshalBinary(data []byte) error { - k.len = data[0] - k.bitset = [32]byte{} - copy(k.inUseBytes(), data[1:1+k.bytesNeeded()]) - return nil -} - -func (k *Key) EncodedLen() uint { - return k.bytesNeeded() + 1 -} - -func (k *Key) Len() uint8 { - return k.len -} - -func (k *Key) Felt() felt.Felt { - var f felt.Felt - f.SetBytes(k.bitset[:]) - return f -} - -func (k *Key) Equal(other *Key) bool { - if k == nil && other == nil { - return true - } else if k == nil || other == nil { - return false - } - return k.len == other.len && k.bitset == other.bitset -} - -// IsBitSet returns whether the bit at the given position is 1. -// Position 0 represents the least significant (rightmost) bit. -func (k *Key) IsBitSet(position uint8) bool { - const LSB = uint8(0x1) - byteIdx := position / 8 - byteAtIdx := k.bitset[len(k.bitset)-int(byteIdx)-1] - bitIdx := position % 8 - return ((byteAtIdx >> bitIdx) & LSB) != 0 -} - -// shiftRight removes n least significant bits from the key by performing a right shift -// operation and reducing the key length. For example, if the key contains bits -// "1111 0000" (length=8) and n=4, the result will be "1111" (length=4). -// -// The operation is destructive - it modifies the key in place. -func (k *Key) shiftRight(n uint8) { - if k.len < n { - panic("deleting more bits than there are") - } - - if n == 0 { - return - } - - var bigInt big.Int - bigInt.SetBytes(k.bitset[:]) - bigInt.Rsh(&bigInt, uint(n)) - bigInt.FillBytes(k.bitset[:]) - k.len -= n -} - -// MostSignificantBits returns a new key with the most significant n bits of the current key. -func (k *Key) MostSignificantBits(n uint8) (*Key, error) { - if n > k.len { - return nil, errors.New("cannot get more bits than the key length") - } - - keyCopy := k.Copy() - keyCopy.shiftRight(k.len - n) - return &keyCopy, nil -} - -// Truncate truncates key to `length` bits by clearing the remaining upper bits -func (k *Key) Truncate(length uint8) { - k.len = length - - unusedBytes := k.unusedBytes() - clear(unusedBytes) - - // clear upper bits on the last used byte - inUseBytes := k.inUseBytes() - unusedBitsCount := 8 - (k.len % 8) - if unusedBitsCount != 8 && len(inUseBytes) > 0 { - inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount - } -} - -func (k *Key) String() string { - return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:])) -} - -// Copy returns a deep copy of the key -func (k *Key) Copy() Key { - newKey := Key{len: k.len} - copy(newKey.bitset[:], k.bitset[:]) - return newKey -} - -func (k *Key) Bytes() [32]byte { - var result [32]byte - copy(result[:], k.bitset[:]) - return result -} - -// findCommonKey finds the set of common MSB bits in two key bitsets. -func findCommonKey(longerKey, shorterKey *Key) (Key, bool) { - divergentBit := findDivergentBit(longerKey, shorterKey) - - if divergentBit == 0 { - return *NilKey, false - } - - commonKey := *shorterKey - commonKey.shiftRight(shorterKey.Len() - divergentBit + 1) - return commonKey, divergentBit == shorterKey.Len()+1 -} - -// findDivergentBit finds the first bit that is different between two keys, -// starting from the most significant bit of both keys. -func findDivergentBit(longerKey, shorterKey *Key) uint8 { - divergentBit := uint8(0) - for divergentBit <= shorterKey.Len() && - longerKey.IsBitSet(longerKey.Len()-divergentBit) == shorterKey.IsBitSet(shorterKey.Len()-divergentBit) { - divergentBit++ - } - return divergentBit -} - -func isSubset(longerKey, shorterKey *Key) bool { - divergentBit := findDivergentBit(longerKey, shorterKey) - return divergentBit == shorterKey.Len()+1 -} - -func FeltToKey(length uint8, key *felt.Felt) Key { - keyBytes := key.Bytes() - return NewKey(length, keyBytes[:]) -} diff --git a/core/trie/key_test.go b/core/trie/key_test.go deleted file mode 100644 index 3867678e6e..0000000000 --- a/core/trie/key_test.go +++ /dev/null @@ -1,229 +0,0 @@ -package trie_test - -import ( - "bytes" - "testing" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestKeyEncoding(t *testing.T) { - tests := map[string]struct { - Len uint8 - Bytes []byte - }{ - "multiple of 8": { - Len: 4 * 8, - Bytes: []byte{0xDE, 0xAD, 0xBE, 0xEF}, - }, - "0 len": { - Len: 0, - Bytes: []byte{}, - }, - "odd len": { - Len: 3, - Bytes: []byte{0x03}, - }, - } - - for desc, test := range tests { - t.Run(desc, func(t *testing.T) { - key := trie.NewKey(test.Len, test.Bytes) - - var keyBuffer bytes.Buffer - n, err := key.WriteTo(&keyBuffer) - require.NoError(t, err) - assert.Equal(t, len(test.Bytes)+1, int(n)) - - keyBytes := keyBuffer.Bytes() - require.Len(t, keyBytes, int(n)) - assert.Equal(t, test.Len, keyBytes[0]) - assert.Equal(t, test.Bytes, keyBytes[1:]) - - var decodedKey trie.Key - require.NoError(t, decodedKey.UnmarshalBinary(keyBytes)) - assert.Equal(t, key, decodedKey) - }) - } -} - -func BenchmarkKeyEncoding(b *testing.B) { - val, err := new(felt.Felt).SetRandom() - require.NoError(b, err) - valBytes := val.Bytes() - - key := trie.NewKey(felt.Bits, valBytes[:]) - buffer := bytes.Buffer{} - buffer.Grow(felt.Bytes + 1) - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _, err := key.WriteTo(&buffer) - require.NoError(b, err) - require.NoError(b, key.UnmarshalBinary(buffer.Bytes())) - buffer.Reset() - } -} - -func TestTruncate(t *testing.T) { - tests := map[string]struct { - key trie.Key - newLen uint8 - expectedKey trie.Key - }{ - "truncate to 12 bits": { - key: trie.NewKey(16, []byte{0xF3, 0x14}), - newLen: 12, - expectedKey: trie.NewKey(12, []byte{0x03, 0x14}), - }, - "truncate to 9 bits": { - key: trie.NewKey(16, []byte{0xF3, 0x14}), - newLen: 9, - expectedKey: trie.NewKey(9, []byte{0x01, 0x14}), - }, - "truncate to 3 bits": { - key: trie.NewKey(16, []byte{0xF3, 0x14}), - newLen: 3, - expectedKey: trie.NewKey(3, []byte{0x04}), - }, - "truncate to multiple of 8": { - key: trie.NewKey(251, []uint8{ - 0x7, 0x40, 0x33, 0x8c, 0xbc, 0x9, 0xeb, 0xf, 0xb7, 0xab, - 0xc5, 0x20, 0x35, 0xc6, 0x4d, 0x4e, 0xa5, 0x78, 0x18, 0x9e, 0xd6, 0x37, 0x47, 0x91, 0xd0, - 0x6e, 0x44, 0x1e, 0xf7, 0x7f, 0xf, 0x5f, - }), - newLen: 248, - expectedKey: trie.NewKey(248, []uint8{ - 0x0, 0x40, 0x33, 0x8c, 0xbc, 0x9, 0xeb, 0xf, 0xb7, 0xab, - 0xc5, 0x20, 0x35, 0xc6, 0x4d, 0x4e, 0xa5, 0x78, 0x18, 0x9e, 0xd6, 0x37, 0x47, 0x91, 0xd0, - 0x6e, 0x44, 0x1e, 0xf7, 0x7f, 0xf, 0x5f, - }), - }, - } - - for desc, test := range tests { - t.Run(desc, func(t *testing.T) { - copyKey := test.key - copyKey.Truncate(test.newLen) - assert.Equal(t, test.expectedKey, copyKey) - }) - } -} - -func TestKeyTest(t *testing.T) { - key := trie.NewKey(44, []byte{0x10, 0x02}) - for i := 0; i < int(key.Len()); i++ { - assert.Equal(t, i == 1 || i == 12, key.IsBitSet(uint8(i)), i) - } -} - -func TestIsBitSet(t *testing.T) { - tests := map[string]struct { - key trie.Key - position uint8 - expected bool - }{ - "single byte, LSB set": { - key: trie.NewKey(8, []byte{0x01}), - position: 0, - expected: true, - }, - "single byte, MSB set": { - key: trie.NewKey(8, []byte{0x80}), - position: 7, - expected: true, - }, - "single byte, middle bit set": { - key: trie.NewKey(8, []byte{0x10}), - position: 4, - expected: true, - }, - "single byte, bit not set": { - key: trie.NewKey(8, []byte{0xFE}), - position: 0, - expected: false, - }, - "multiple bytes, LSB set": { - key: trie.NewKey(16, []byte{0x00, 0x02}), - position: 1, - expected: true, - }, - "multiple bytes, MSB set": { - key: trie.NewKey(16, []byte{0x01, 0x00}), - position: 8, - expected: true, - }, - "multiple bytes, no bits set": { - key: trie.NewKey(16, []byte{0x00, 0x00}), - position: 7, - expected: false, - }, - "check all bits in pattern": { - key: trie.NewKey(8, []byte{0xA5}), // 10100101 - position: 0, - expected: true, - }, - } - - // Additional test for 0xA5 pattern - key := trie.NewKey(8, []byte{0xA5}) // 10100101 - expectedBits := []bool{true, false, true, false, false, true, false, true} - for i, expected := range expectedBits { - assert.Equal(t, expected, key.IsBitSet(uint8(i)), "bit %d in 0xA5", i) - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - result := tc.key.IsBitSet(tc.position) - assert.Equal(t, tc.expected, result) - }) - } -} - -func TestMostSignificantBits(t *testing.T) { - tests := []struct { - name string - key trie.Key - n uint8 - want trie.Key - expectErr bool - }{ - { - name: "Valid case", - key: trie.NewKey(8, []byte{0b11110000}), - n: 4, - want: trie.NewKey(4, []byte{0b00001111}), - expectErr: false, - }, - { - name: "Request more bits than available", - key: trie.NewKey(8, []byte{0b11110000}), - n: 10, - want: trie.Key{}, - expectErr: true, - }, - { - name: "Zero bits requested", - key: trie.NewKey(8, []byte{0b11110000}), - n: 0, - want: trie.NewKey(0, []byte{}), - expectErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.key.MostSignificantBits(tt.n) - if (err != nil) != tt.expectErr { - t.Errorf("MostSignificantBits() error = %v, expectErr %v", err, tt.expectErr) - return - } - if !tt.expectErr && !got.Equal(&tt.want) { - t.Errorf("MostSignificantBits() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/core/trie/node.go b/core/trie/node.go index 2ef176f92a..47f43e6eab 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -12,14 +12,14 @@ import ( // https://docs.starknet.io/architecture-and-concepts/network-architecture/starknet-state/#trie_construction type Node struct { Value *felt.Felt - Left *Key - Right *Key + Left *BitArray + Right *BitArray LeftHash *felt.Felt RightHash *felt.Felt } // Hash calculates the hash of a [Node] -func (n *Node) Hash(path *Key, hashFunc hashFunc) *felt.Felt { +func (n *Node) Hash(path *BitArray, hashFunc hashFunc) *felt.Felt { if path.Len() == 0 { // we have to deference the Value, since the Node can released back // to the NodePool and be reused anytime @@ -34,32 +34,32 @@ func (n *Node) Hash(path *Key, hashFunc hashFunc) *felt.Felt { } // Hash calculates the hash of a [Node] -func (n *Node) HashFromParent(parentKey, nodeKey *Key, hashFunc hashFunc) *felt.Felt { +func (n *Node) HashFromParent(parentKey, nodeKey *BitArray, hashFunc hashFunc) *felt.Felt { path := path(nodeKey, parentKey) return n.Hash(&path, hashFunc) } -func (n *Node) WriteTo(buf *bytes.Buffer) (int64, error) { +func (n *Node) WriteTo(buf *bytes.Buffer) (int, error) { if n.Value == nil { return 0, errors.New("cannot marshal node with nil value") } - totalBytes := int64(0) + var totalBytes int valueB := n.Value.Bytes() wrote, err := buf.Write(valueB[:]) - totalBytes += int64(wrote) + totalBytes += wrote if err != nil { return totalBytes, err } if n.Left != nil { - wrote, errInner := n.Left.WriteTo(buf) + wrote, errInner := n.Left.Write(buf) totalBytes += wrote if errInner != nil { return totalBytes, errInner } - wrote, errInner = n.Right.WriteTo(buf) // n.Right is non-nil by design + wrote, errInner = n.Right.Write(buf) // n.Right is non-nil by design totalBytes += wrote if errInner != nil { return totalBytes, errInner @@ -74,14 +74,14 @@ func (n *Node) WriteTo(buf *bytes.Buffer) (int64, error) { leftHashB := n.LeftHash.Bytes() wrote, err = buf.Write(leftHashB[:]) - totalBytes += int64(wrote) + totalBytes += wrote if err != nil { return totalBytes, err } rightHashB := n.RightHash.Bytes() wrote, err = buf.Write(rightHashB[:]) - totalBytes += int64(wrote) + totalBytes += wrote if err != nil { return totalBytes, err } @@ -109,17 +109,13 @@ func (n *Node) UnmarshalBinary(data []byte) error { } if n.Left == nil { - n.Left = new(Key) - n.Right = new(Key) + n.Left = new(BitArray) + n.Right = new(BitArray) } - if err := n.Left.UnmarshalBinary(data); err != nil { - return err - } + n.Left.UnmarshalBinary(data) data = data[n.Left.EncodedLen():] - if err := n.Right.UnmarshalBinary(data); err != nil { - return err - } + n.Right.UnmarshalBinary(data) data = data[n.Right.EncodedLen():] if n.LeftHash == nil { @@ -156,11 +152,11 @@ func (n *Node) Update(other *Node) error { return fmt.Errorf("conflicting Values: %v != %v", n.Value, other.Value) } - if n.Left != nil && other.Left != nil && !n.Left.Equal(NilKey) && !other.Left.Equal(NilKey) && !n.Left.Equal(other.Left) { + if n.Left != nil && other.Left != nil && !n.Left.Equal(emptyBitArray) && !other.Left.Equal(emptyBitArray) && !n.Left.Equal(other.Left) { return fmt.Errorf("conflicting Left keys: %v != %v", n.Left, other.Left) } - if n.Right != nil && other.Right != nil && !n.Right.Equal(NilKey) && !other.Right.Equal(NilKey) && !n.Right.Equal(other.Right) { + if n.Right != nil && other.Right != nil && !n.Right.Equal(emptyBitArray) && !other.Right.Equal(emptyBitArray) && !n.Right.Equal(other.Right) { return fmt.Errorf("conflicting Right keys: %v != %v", n.Right, other.Right) } @@ -176,10 +172,10 @@ func (n *Node) Update(other *Node) error { if other.Value != nil { n.Value = other.Value } - if other.Left != nil && !other.Left.Equal(NilKey) { + if other.Left != nil && !other.Left.Equal(emptyBitArray) { n.Left = other.Left } - if other.Right != nil && !other.Right.Equal(NilKey) { + if other.Right != nil && !other.Right.Equal(emptyBitArray) { n.Right = other.Right } if other.LeftHash != nil { diff --git a/core/trie/node_test.go b/core/trie/node_test.go index ccb52b3eac..b222732f4b 100644 --- a/core/trie/node_test.go +++ b/core/trie/node_test.go @@ -22,7 +22,7 @@ func TestNodeHash(t *testing.T) { node := trie.Node{ Value: new(felt.Felt).SetBytes(valueBytes), } - path := trie.NewKey(6, []byte{42}) + path := trie.NewBitArray(6, 42) - assert.Equal(t, expected, node.Hash(&path, crypto.Pedersen), "TestTrieNode_Hash failed") + assert.Equal(t, expected, node.Hash(path, crypto.Pedersen), "TestTrieNode_Hash failed") } diff --git a/core/trie/proof.go b/core/trie/proof.go index bc4b66d0d9..ff16371c06 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -40,14 +40,14 @@ func (b *Binary) String() string { type Edge struct { Child *felt.Felt // child hash - Path *Key // path from parent to child + Path *BitArray // path from parent to child } func (e *Edge) Hash(hash hashFunc) *felt.Felt { - length := make([]byte, len(e.Path.bitset)) - length[len(e.Path.bitset)-1] = e.Path.len + var length [32]byte + length[31] = e.Path.len pathFelt := e.Path.Felt() - lengthFelt := new(felt.Felt).SetBytes(length) + lengthFelt := new(felt.Felt).SetBytes(length[:]) return new(felt.Felt).Add(hash(e.Child, &pathFelt), lengthFelt) } @@ -71,7 +71,7 @@ func (t *Trie) Prove(key *felt.Felt, proof *ProofNodeSet) error { return err } - var parentKey *Key + var parentKey *BitArray for i, sNode := range nodesFromRoot { sNodeEdge, sNodeBinary, err := storageNodeToProofNode(t, parentKey, sNode) @@ -137,10 +137,9 @@ func (t *Trie) GetRangeProof(leftKey, rightKey *felt.Felt, proofSet *ProofNodeSe // - Any node's computed hash doesn't match its expected hash // - The path bits don't match the key bits // - The proof ends before processing all key bits -func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash hashFunc) (*felt.Felt, error) { - key := FeltToKey(globalTrieHeight, keyFelt) +func VerifyProof(root, key *felt.Felt, proof *ProofNodeSet, hash hashFunc) (*felt.Felt, error) { + keyBits := new(BitArray).SetFelt(globalTrieHeight, key) expectedHash := root - keyLen := key.Len() var curPos uint8 for { @@ -156,17 +155,17 @@ func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash hashFunc) ( switch node := proofNode.(type) { case *Binary: // Binary nodes represent left/right choices - if key.Len() <= curPos { - return nil, fmt.Errorf("key length less than current position, key length: %d, current position: %d", key.Len(), curPos) + if keyBits.Len() <= curPos { + return nil, fmt.Errorf("key length less than current position, key length: %d, current position: %d", keyBits.Len(), curPos) } // Determine the next node to traverse based on the next bit position expectedHash = node.LeftHash - if key.IsBitSet(keyLen - curPos - 1) { + if keyBits.IsBitSet(keyBits.Len() - curPos - 1) { expectedHash = node.RightHash } curPos++ case *Edge: // Edge nodes represent paths between binary nodes - if !verifyEdgePath(&key, node.Path, curPos) { + if !verifyEdgePath(keyBits, node.Path, curPos) { return &felt.Zero, nil } @@ -176,7 +175,7 @@ func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash hashFunc) ( } // We've consumed all bits in our path - if curPos >= keyLen { + if curPos >= keyBits.Len() { return expectedHash, nil } } @@ -235,18 +234,18 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * } nodes := NewStorageNodeSet() - firstKey := FeltToKey(globalTrieHeight, first) + firstKey := new(BitArray).SetFelt(globalTrieHeight, first) // Special case: there is a provided proof but no key-value pairs, make sure regenerated trie has no more values // Empty range proof with more elements on the right is not accepted in this function. // This is due to snap sync specification detail, where the responder must send an existing key (if any) if the requested range is empty. if len(keys) == 0 { - rootKey, val, err := proofToPath(root, &firstKey, proof, nodes) + rootKey, val, err := proofToPath(root, firstKey, proof, nodes) if err != nil { return false, err } - if val != nil || hasRightElement(rootKey, &firstKey, nodes) { + if val != nil || hasRightElement(rootKey, firstKey, nodes) { return false, errors.New("more entries available") } @@ -254,17 +253,17 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * } last := keys[len(keys)-1] - lastKey := FeltToKey(globalTrieHeight, last) + lastKey := new(BitArray).SetFelt(globalTrieHeight, last) // Special case: there is only one element and two edge keys are the same - if len(keys) == 1 && firstKey.Equal(&lastKey) { - rootKey, val, err := proofToPath(root, &firstKey, proof, nodes) + if len(keys) == 1 && firstKey.Equal(lastKey) { + rootKey, val, err := proofToPath(root, firstKey, proof, nodes) if err != nil { return false, err } - elementKey := FeltToKey(globalTrieHeight, keys[0]) - if !firstKey.Equal(&elementKey) { + elementKey := new(BitArray).SetFelt(globalTrieHeight, keys[0]) + if !firstKey.Equal(elementKey) { return false, errors.New("correct proof but invalid key") } @@ -272,7 +271,7 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * return false, errors.New("correct proof but invalid value") } - return hasRightElement(rootKey, &firstKey, nodes), nil + return hasRightElement(rootKey, firstKey, nodes), nil } // In all other cases, we require two edge paths available. @@ -281,12 +280,12 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * return false, errors.New("last key is less than first key") } - rootKey, _, err := proofToPath(root, &firstKey, proof, nodes) + rootKey, _, err := proofToPath(root, firstKey, proof, nodes) if err != nil { return false, err } - lastRootKey, _, err := proofToPath(root, &lastKey, proof, nodes) + lastRootKey, _, err := proofToPath(root, lastKey, proof, nodes) if err != nil { return false, err } @@ -311,11 +310,11 @@ func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof * return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", root.String(), recomputedRoot.String()) } - return hasRightElement(rootKey, &lastKey, nodes), nil + return hasRightElement(rootKey, lastKey, nodes), nil } // isEdge checks if the storage node is an edge node. -func isEdge(parentKey *Key, sNode StorageNode) bool { +func isEdge(parentKey *BitArray, sNode StorageNode) bool { sNodeLen := sNode.key.len if parentKey == nil { // Root return sNodeLen != 0 @@ -326,7 +325,7 @@ func isEdge(parentKey *Key, sNode StorageNode) bool { // storageNodeToProofNode converts a StorageNode to the ProofNode(s). // Juno's Trie has nodes that are Binary AND Edge, whereas the protocol requires nodes that are Binary XOR Edge. // We need to convert the former to the latter for proof generation. -func storageNodeToProofNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge, *Binary, error) { +func storageNodeToProofNode(tri *Trie, parentKey *BitArray, sNode StorageNode) (*Edge, *Binary, error) { var edge *Edge if isEdge(parentKey, sNode) { edgePath := path(sNode.key, parentKey) @@ -375,8 +374,8 @@ func storageNodeToProofNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge // proofToPath converts a Merkle proof to trie node path. All necessary nodes will be resolved and leave the remaining // as hashes. The given edge proof can be existent or non-existent. -func proofToPath(root *felt.Felt, key *Key, proof *ProofNodeSet, nodes *StorageNodeSet) (*Key, *felt.Felt, error) { - rootKey, val, err := buildPath(root, key, 0, nil, proof, nodes) +func proofToPath(root *felt.Felt, keyBits *BitArray, proof *ProofNodeSet, nodes *StorageNodeSet) (*BitArray, *felt.Felt, error) { + rootKey, val, err := buildPath(root, keyBits, 0, nil, proof, nodes) if err != nil { return nil, nil, err } @@ -400,7 +399,7 @@ func proofToPath(root *felt.Felt, key *Key, proof *ProofNodeSet, nodes *StorageN sn := NewPartialStorageNode(edge.Path, edge.Child) // Handle leaf edge case (single key trie) - if edge.Path.Len() == key.Len() { + if edge.Path.Len() == keyBits.Len() { if err := nodes.Put(*sn.key, sn); err != nil { return nil, nil, fmt.Errorf("failed to store leaf edge: %w", err) } @@ -433,12 +432,12 @@ func proofToPath(root *felt.Felt, key *Key, proof *ProofNodeSet, nodes *StorageN // It returns the current node's key and any leaf value found along this path. func buildPath( nodeHash *felt.Felt, - key *Key, + key *BitArray, curPos uint8, curNode *StorageNode, proof *ProofNodeSet, nodes *StorageNodeSet, -) (*Key, *felt.Felt, error) { +) (*BitArray, *felt.Felt, error) { // We reached the leaf if curPos == key.Len() { leafKey := key.Copy() @@ -451,7 +450,7 @@ func buildPath( proofNode, ok := proof.Get(*nodeHash) if !ok { // non-existent proof node - return NilKey, nil, nil + return emptyBitArray, nil, nil } switch pn := proofNode.(type) { @@ -470,23 +469,19 @@ func buildPath( func handleBinaryNode( binary *Binary, nodeHash *felt.Felt, - key *Key, + key *BitArray, curPos uint8, curNode *StorageNode, proof *ProofNodeSet, nodes *StorageNodeSet, -) (*Key, *felt.Felt, error) { +) (*BitArray, *felt.Felt, error) { // If curNode is nil, it means that this current binary node is the root node. // Or, it's an internal binary node and the parent is also a binary node. // A standalone binary proof node always corresponds to a single storage node. // If curNode is not nil, it means that the parent node is an edge node. // In this case, the key of the storage node is based on the parent edge node. if curNode == nil { - nodeKey, err := key.MostSignificantBits(curPos) - if err != nil { - return nil, nil, err - } - curNode = NewPartialStorageNode(nodeKey, nodeHash) + curNode = NewPartialStorageNode(new(BitArray).MSBs(key, curPos), nodeHash) } curNode.node.LeftHash = binary.LeftHash curNode.node.RightHash = binary.RightHash @@ -523,23 +518,19 @@ func handleBinaryNode( // the current node's key and any leaf value found along this path. func handleEdgeNode( edge *Edge, - key *Key, + key *BitArray, curPos uint8, proof *ProofNodeSet, nodes *StorageNodeSet, -) (*Key, *felt.Felt, error) { +) (*BitArray, *felt.Felt, error) { // Verify the edge path matches the key path if !verifyEdgePath(key, edge.Path, curPos) { - return NilKey, nil, nil + return emptyBitArray, nil, nil } // The next node position is the end of the edge path nextPos := curPos + edge.Path.Len() - nodeKey, err := key.MostSignificantBits(nextPos) - if err != nil { - return nil, nil, fmt.Errorf("failed to get MSB for internal edge: %w", err) - } - curNode := NewPartialStorageNode(nodeKey, edge.Child) + curNode := NewPartialStorageNode(new(BitArray).MSBs(key, nextPos), edge.Child) // This is an edge leaf, stop traversing the trie if nextPos == key.Len() { @@ -562,24 +553,12 @@ func handleEdgeNode( } // verifyEdgePath checks if the edge path matches the key path at the current position. -func verifyEdgePath(key, edgePath *Key, curPos uint8) bool { - if key.Len() < curPos+edgePath.Len() { - return false - } - - // Ensure the bits between segment of the key and the node path match - start := key.Len() - curPos - edgePath.Len() - end := key.Len() - curPos - for i := start; i < end; i++ { - if key.IsBitSet(i) != edgePath.IsBitSet(i-start) { - return false // paths diverge - this proves non-membership - } - } - return true +func verifyEdgePath(key, edgePath *BitArray, curPos uint8) bool { + return new(BitArray).LSBs(key, key.Len()-curPos).EqualMSBs(edgePath) } // buildTrie builds a trie from a list of storage nodes and a list of keys and values. -func buildTrie(height uint8, rootKey *Key, nodes []*StorageNode, keys, values []*felt.Felt) (*Trie, error) { +func buildTrie(height uint8, rootKey *BitArray, nodes []*StorageNode, keys, values []*felt.Felt) (*Trie, error) { tr, err := NewTriePedersen(newMemStorage(), height) if err != nil { return nil, err @@ -607,9 +586,9 @@ func buildTrie(height uint8, rootKey *Key, nodes []*StorageNode, keys, values [] // hasRightElement checks if there is a right sibling for the given key in the trie. // This function assumes that the entire path has been resolved. -func hasRightElement(rootKey, key *Key, nodes *StorageNodeSet) bool { +func hasRightElement(rootKey, key *BitArray, nodes *StorageNodeSet) bool { cur := rootKey - for cur != nil && !cur.Equal(NilKey) { + for cur != nil && !cur.Equal(emptyBitArray) { sn, ok := nodes.Get(*cur) if !ok { return false diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index 94eaabc549..0f9c54543a 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -13,6 +13,30 @@ import ( "github.com/stretchr/testify/require" ) +func TestFix(t *testing.T) { + numKeys := 1000 + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) + + records := make([]*keyValue, numKeys) + for i := 1; i < numKeys+1; i++ { + key := new(felt.Felt).SetUint64(uint64(i)) + records[i-1] = &keyValue{key: key, value: key} + _, err := tempTrie.Put(key, key) + require.NoError(t, err) + } + + sort.Slice(records, func(i, j int) bool { + return records[i].key.Cmp(records[j].key) < 0 + }) + + require.NoError(t, tempTrie.Commit()) +} + func TestProve(t *testing.T) { t.Parallel() @@ -360,7 +384,7 @@ func TestOneElementRangeProof(t *testing.T) { }) } -// TestAllElementsProof tests the range proof with all elements and nil proof. +// TestAllElementsRangeProof tests the range proof with all elements and nil proof. func TestAllElementsRangeProof(t *testing.T) { t.Parallel() diff --git a/core/trie/storage.go b/core/trie/storage.go index c4e5ae0915..6fe994fe3b 100644 --- a/core/trie/storage.go +++ b/core/trie/storage.go @@ -42,17 +42,17 @@ func NewStorage(txn db.Transaction, prefix []byte) *Storage { // dbKey creates a byte array to be used as a key to our KV store // it simply appends the given key to the configured prefix -func (t *Storage) dbKey(key *Key, buffer *bytes.Buffer) (int64, error) { +func (t *Storage) dbKey(key *BitArray, buffer *bytes.Buffer) (int, error) { _, err := buffer.Write(t.prefix) if err != nil { return 0, err } - keyLen, err := key.WriteTo(buffer) - return int64(len(t.prefix)) + keyLen, err + keyLen, err := key.Write(buffer) + return len(t.prefix) + keyLen, err } -func (t *Storage) Put(key *Key, value *Node) error { +func (t *Storage) Put(key *BitArray, value *Node) error { buffer := getBuffer() defer bufferPool.Put(buffer) keyLen, err := t.dbKey(key, buffer) @@ -69,7 +69,7 @@ func (t *Storage) Put(key *Key, value *Node) error { return t.txn.Set(encodedBytes[:keyLen], encodedBytes[keyLen:]) } -func (t *Storage) Get(key *Key) (*Node, error) { +func (t *Storage) Get(key *BitArray) (*Node, error) { buffer := getBuffer() defer bufferPool.Put(buffer) _, err := t.dbKey(key, buffer) @@ -87,7 +87,7 @@ func (t *Storage) Get(key *Key) (*Node, error) { return node, err } -func (t *Storage) Delete(key *Key) error { +func (t *Storage) Delete(key *BitArray) error { buffer := getBuffer() defer bufferPool.Put(buffer) _, err := t.dbKey(key, buffer) @@ -97,21 +97,22 @@ func (t *Storage) Delete(key *Key) error { return t.txn.Delete(buffer.Bytes()) } -func (t *Storage) RootKey() (*Key, error) { - var rootKey *Key +func (t *Storage) RootKey() (*BitArray, error) { + var rootKey *BitArray if err := t.txn.Get(t.prefix, func(val []byte) error { - rootKey = new(Key) - return rootKey.UnmarshalBinary(val) + rootKey = new(BitArray) + rootKey.UnmarshalBinary(val) + return nil }); err != nil { return nil, err } return rootKey, nil } -func (t *Storage) PutRootKey(newRootKey *Key) error { +func (t *Storage) PutRootKey(newRootKey *BitArray) error { buffer := getBuffer() defer bufferPool.Put(buffer) - _, err := newRootKey.WriteTo(buffer) + _, err := newRootKey.Write(buffer) if err != nil { return err } diff --git a/core/trie/storage_test.go b/core/trie/storage_test.go index 809ded4791..37a4e8e447 100644 --- a/core/trie/storage_test.go +++ b/core/trie/storage_test.go @@ -15,7 +15,7 @@ import ( func TestStorage(t *testing.T) { testDB := pebble.NewMemTest(t) prefix := []byte{37, 44} - key := trie.NewKey(44, nil) + key := trie.NewBitArray(44, 0) value, err := new(felt.Felt).SetRandom() require.NoError(t, err) @@ -27,7 +27,7 @@ func TestStorage(t *testing.T) { t.Run("put a node", func(t *testing.T) { require.NoError(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - return tTxn.Put(&key, node) + return tTxn.Put(key, node) })) }) @@ -35,7 +35,7 @@ func TestStorage(t *testing.T) { require.NoError(t, testDB.View(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) var got *trie.Node - got, err = tTxn.Get(&key) + got, err = tTxn.Get(key) require.NoError(t, err) assert.Equal(t, node, got) return err @@ -46,7 +46,7 @@ func TestStorage(t *testing.T) { // Successfully delete a node and return an error to force a roll back. require.Error(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - err = tTxn.Delete(&key) + err = tTxn.Delete(key) require.NoError(t, err) return errors.New("should rollback") })) @@ -56,7 +56,7 @@ func TestStorage(t *testing.T) { require.NoError(t, testDB.View(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) var got *trie.Node - got, err = tTxn.Get(&key) + got, err = tTxn.Get(key) assert.Equal(t, node, got) return err })) @@ -66,23 +66,23 @@ func TestStorage(t *testing.T) { // Delete a node. require.NoError(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - return tTxn.Delete(&key) + return tTxn.Delete(key) })) // Node should no longer exist in the database. require.EqualError(t, testDB.View(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - _, err = tTxn.Get(&key) + _, err = tTxn.Get(key) return err }), db.ErrKeyNotFound.Error()) }) - rootKey := trie.NewKey(8, []byte{0x2}) + rootKey := trie.NewBitArray(8, 2) t.Run("put root key", func(t *testing.T) { require.NoError(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - return tTxn.PutRootKey(&rootKey) + return tTxn.PutRootKey(rootKey) })) }) @@ -91,7 +91,7 @@ func TestStorage(t *testing.T) { tTxn := trie.NewStorage(txn, prefix) gotRootKey, err := tTxn.RootKey() require.NoError(t, err) - assert.Equal(t, rootKey, *gotRootKey) + assert.Equal(t, rootKey, gotRootKey) return nil })) }) diff --git a/core/trie/trie.go b/core/trie/trie.go index c21168c505..02d930bd4f 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -37,12 +37,12 @@ type hashFunc func(*felt.Felt, *felt.Felt) *felt.Felt // [specification]: https://docs.starknet.io/architecture-and-concepts/network-architecture/starknet-state/#merkle_patricia_trie type Trie struct { height uint8 - rootKey *Key + rootKey *BitArray maxKey *felt.Felt storage *Storage hash hashFunc - dirtyNodes []*Key + dirtyNodes []*BitArray rootKeyIsDirty bool } @@ -96,32 +96,35 @@ func RunOnTempTriePoseidon(height uint8, do func(*Trie) error) error { return do(trie) } -// feltToKey Converts a key, given in felt, to a trie.Key which when followed on a [Trie], +// FeltToKey converts a key, given in felt, to a trie.Key which when followed on a [Trie], // leads to the corresponding [Node] -func (t *Trie) FeltToKey(k *felt.Felt) Key { - return FeltToKey(t.height, k) +func (t *Trie) FeltToKey(k *felt.Felt) BitArray { + var ba BitArray + ba.SetFelt(t.height, k) + return ba } // path returns the path as mentioned in the [specification] for commitment calculations. // path is suffix of key that diverges from parentKey. For example, // for a key 0b1011 and parentKey 0b10, this function would return the path object of 0b0. -func path(key, parentKey *Key) Key { - path := *key - // drop parent key, and one more MSB since left/right relation already encodes that information - if parentKey != nil { - path.Truncate(path.Len() - parentKey.Len() - 1) +func path(key, parentKey *BitArray) BitArray { + if parentKey == nil { + return key.Copy() } - return path + + var pathKey BitArray + pathKey.LSBs(key, key.Len()-parentKey.Len()-1) + return pathKey } // storageNode is the on-disk representation of a [Node], // where key is the storage key and node is the value. type StorageNode struct { - key *Key + key *BitArray node *Node } -func (sn *StorageNode) Key() *Key { +func (sn *StorageNode) Key() *BitArray { return sn.key } @@ -135,7 +138,7 @@ func (sn *StorageNode) String() string { func (sn *StorageNode) Update(other *StorageNode) error { // First validate all fields for conflicts - if sn.key != nil && other.key != nil && !sn.key.Equal(NilKey) && !other.key.Equal(NilKey) { + if sn.key != nil && other.key != nil && !sn.key.Equal(emptyBitArray) && !other.key.Equal(emptyBitArray) { if !sn.key.Equal(other.key) { return fmt.Errorf("keys do not match: %s != %s", sn.key, other.key) } @@ -149,47 +152,47 @@ func (sn *StorageNode) Update(other *StorageNode) error { } // After validation, perform update - if other.key != nil && !other.key.Equal(NilKey) { + if other.key != nil && !other.key.Equal(emptyBitArray) { sn.key = other.key } return nil } -func NewStorageNode(key *Key, node *Node) *StorageNode { +func NewStorageNode(key *BitArray, node *Node) *StorageNode { return &StorageNode{key: key, node: node} } // NewPartialStorageNode creates a new StorageNode with a given key and value, // where the right and left children are nil. -func NewPartialStorageNode(key *Key, value *felt.Felt) *StorageNode { +func NewPartialStorageNode(key *BitArray, value *felt.Felt) *StorageNode { return &StorageNode{ key: key, node: &Node{ Value: value, - Left: NilKey, - Right: NilKey, + Left: emptyBitArray, + Right: emptyBitArray, }, } } // StorageNodeSet wraps OrderedSet to provide specific functionality for StorageNodes type StorageNodeSet struct { - set *utils.OrderedSet[Key, *StorageNode] + set *utils.OrderedSet[BitArray, *StorageNode] } func NewStorageNodeSet() *StorageNodeSet { return &StorageNodeSet{ - set: utils.NewOrderedSet[Key, *StorageNode](), + set: utils.NewOrderedSet[BitArray, *StorageNode](), } } -func (s *StorageNodeSet) Get(key Key) (*StorageNode, bool) { +func (s *StorageNodeSet) Get(key BitArray) (*StorageNode, bool) { return s.set.Get(key) } // Put adds a new StorageNode or updates an existing one. -func (s *StorageNodeSet) Put(key Key, node *StorageNode) error { +func (s *StorageNodeSet) Put(key BitArray, node *StorageNode) error { if node == nil { return errors.New("cannot put nil node") } @@ -219,7 +222,7 @@ func (s *StorageNodeSet) Size() int { // nodesFromRoot enumerates the set of [Node] objects that need to be traversed from the root // of the Trie to the node which is given by the key. // The [storageNode]s are returned in descending order beginning with the root. -func (t *Trie) nodesFromRoot(key *Key) ([]StorageNode, error) { +func (t *Trie) nodesFromRoot(key *BitArray) ([]StorageNode, error) { var nodes []StorageNode cur := t.rootKey for cur != nil { @@ -238,8 +241,7 @@ func (t *Trie) nodesFromRoot(key *Key) ([]StorageNode, error) { node: node, }) - subset := isSubset(key, cur) - if cur.Len() >= key.Len() || !subset { + if cur.Len() >= key.Len() || !key.EqualMSBs(cur) { return nodes, nil } @@ -269,12 +271,12 @@ func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { } // GetNodeFromKey returns the node for a given key. -func (t *Trie) GetNodeFromKey(key *Key) (*Node, error) { +func (t *Trie) GetNodeFromKey(key *BitArray) (*Node, error) { return t.storage.Get(key) } // check if we are updating an existing leaf, if yes avoid traversing the trie -func (t *Trie) updateLeaf(nodeKey Key, node *Node, value *felt.Felt) (*felt.Felt, error) { +func (t *Trie) updateLeaf(nodeKey BitArray, node *Node, value *felt.Felt) (*felt.Felt, error) { // Check if we are updating an existing leaf if !value.IsZero() { if existingLeaf, err := t.storage.Get(&nodeKey); err == nil { @@ -291,7 +293,7 @@ func (t *Trie) updateLeaf(nodeKey Key, node *Node, value *felt.Felt) (*felt.Felt return nil, nil } -func (t *Trie) handleEmptyTrie(old felt.Felt, nodeKey Key, node *Node, value *felt.Felt) (*felt.Felt, error) { +func (t *Trie) handleEmptyTrie(old felt.Felt, nodeKey BitArray, node *Node, value *felt.Felt) (*felt.Felt, error) { if value.IsZero() { return nil, nil // no-op } @@ -303,7 +305,7 @@ func (t *Trie) handleEmptyTrie(old felt.Felt, nodeKey Key, node *Node, value *fe return &old, nil } -func (t *Trie) deleteExistingKey(sibling StorageNode, nodeKey Key, nodes []StorageNode) (*felt.Felt, error) { +func (t *Trie) deleteExistingKey(sibling StorageNode, nodeKey BitArray, nodes []StorageNode) (*felt.Felt, error) { if nodeKey.Equal(sibling.key) { // we have to deference the Value, since the Node can released back // to the NodePool and be reused anytime @@ -316,7 +318,7 @@ func (t *Trie) deleteExistingKey(sibling StorageNode, nodeKey Key, nodes []Stora return nil, nil } -func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent StorageNode) { +func (t *Trie) replaceLinkWithNewParent(key *BitArray, commonKey BitArray, siblingParent StorageNode) { if siblingParent.node.Left.Equal(key) { *siblingParent.node.Left = commonKey } else { @@ -325,8 +327,9 @@ func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent S } // TODO(weiihann): not a good idea to couple proof verification logic with trie logic -func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode, sibling StorageNode, siblingIsParentProof bool) error { - commonKey, _ := findCommonKey(nodeKey, sibling.key) +func (t *Trie) insertOrUpdateValue(nodeKey *BitArray, node *Node, nodes []StorageNode, sibling StorageNode, siblingIsParentProof bool) error { + var commonKey BitArray + commonKey.CommonMSBs(nodeKey, sibling.key) newParent := &Node{} var leftChild, rightChild *Node @@ -499,19 +502,19 @@ func (t *Trie) PutWithProof(key, value *felt.Felt, proof []*StorageNode) (*felt. } // Put updates the corresponding `value` for a `key` -func (t *Trie) PutInner(key *Key, node *Node) error { +func (t *Trie) PutInner(key *BitArray, node *Node) error { if err := t.storage.Put(key, node); err != nil { return err } return nil } -func (t *Trie) setRootKey(newRootKey *Key) { +func (t *Trie) setRootKey(newRootKey *BitArray) { t.rootKey = newRootKey t.rootKeyIsDirty = true } -func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo +func (t *Trie) updateValueIfDirty(key *BitArray) (*Node, error) { //nolint:gocyclo node, err := t.storage.Get(key) if err != nil { return nil, err @@ -525,7 +528,7 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo shouldUpdate := false for _, dirtyNode := range t.dirtyNodes { if key.Len() < dirtyNode.Len() { - shouldUpdate = isSubset(dirtyNode, key) + shouldUpdate = key.EqualMSBs(dirtyNode) if shouldUpdate { break } @@ -533,9 +536,9 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo } // Update inner proof nodes - if node.Left.Equal(NilKey) && node.Right.Equal(NilKey) { // leaf + if node.Left.Equal(emptyBitArray) && node.Right.Equal(emptyBitArray) { // leaf shouldUpdate = false - } else if node.Left.Equal(NilKey) || node.Right.Equal(NilKey) { // inner + } else if node.Left.Equal(emptyBitArray) || node.Right.Equal(emptyBitArray) { // inner shouldUpdate = true } if !shouldUpdate { @@ -544,11 +547,11 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo var leftIsProof, rightIsProof bool var leftHash, rightHash *felt.Felt - if node.Left.Equal(NilKey) { // key could be nil but hash cannot be + if node.Left.Equal(emptyBitArray) { // key could be nil but hash cannot be leftIsProof = true leftHash = node.LeftHash } - if node.Right.Equal(NilKey) { + if node.Right.Equal(emptyBitArray) { rightIsProof = true rightHash = node.RightHash } @@ -645,7 +648,7 @@ func (t *Trie) deleteLast(nodes []StorageNode) error { return err } - var siblingKey Key + var siblingKey BitArray if parent.node.Left.Equal(last.key) { siblingKey = *parent.node.Right } else { @@ -712,7 +715,7 @@ func (t *Trie) Commit() error { } // RootKey returns db key of the [Trie] root node -func (t *Trie) RootKey() *Key { +func (t *Trie) RootKey() *BitArray { return t.rootKey } @@ -734,7 +737,7 @@ The following can be printed: The spacing to represent the levels of the trie can remain the same. */ -func (t *Trie) dump(level int, parentP *Key) { +func (t *Trie) dump(level int, parentP *BitArray) { if t.rootKey == nil { fmt.Printf("%sEMPTY\n", strings.Repeat("\t", level)) return diff --git a/core/trie/trie_pkg_test.go b/core/trie/trie_pkg_test.go index 5426cbcafa..533450a687 100644 --- a/core/trie/trie_pkg_test.go +++ b/core/trie/trie_pkg_test.go @@ -55,16 +55,17 @@ func TestTrieKeys(t *testing.T) { // Check parent and its left right children l := tempTrie.FeltToKey(leftKey) r := tempTrie.FeltToKey(rightKey) - commonKey, isSame := findCommonKey(&l, &r) - require.False(t, isSame) + var commonKey BitArray + commonKey.CommonMSBs(&l, &r) // Common key should be 0b100, length 251-2; - expectKey := NewKey(251-2, []byte{0x4}) + // expectKey := NewKey(251-2, []byte{0x4}) + expectKey := NewBitArray(249, 4) - assert.Equal(t, expectKey, commonKey) + assert.Equal(t, expectKey, &commonKey) // Current rootKey should be the common key - assert.Equal(t, expectKey, *tempTrie.rootKey) + assert.Equal(t, expectKey, tempTrie.rootKey) parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) @@ -98,12 +99,12 @@ func TestTrieKeys(t *testing.T) { // Check parent and its left right children l := tempTrie.FeltToKey(leftKey) r := tempTrie.FeltToKey(rightKey) - commonKey, isSame := findCommonKey(&l, &r) - require.False(t, isSame) + var commonKey BitArray + commonKey.CommonMSBs(&l, &r) - expectKey := NewKey(251-2, []byte{0x4}) + expectKey := NewBitArray(249, 4) - assert.Equal(t, expectKey, commonKey) + assert.Equal(t, expectKey, &commonKey) parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) @@ -139,8 +140,8 @@ func TestTrieKeys(t *testing.T) { newKey := new(felt.Felt).SetUint64(0b101) _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) - commonKey := NewKey(250, []byte{0x2}) - parentNode, pErr := tempTrie.storage.Get(&commonKey) + commonKey := NewBitArray(250, 2) + parentNode, pErr := tempTrie.storage.Get(commonKey) require.NoError(t, pErr) assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Right) @@ -150,8 +151,8 @@ func TestTrieKeys(t *testing.T) { newKey := new(felt.Felt).SetUint64(0b110) _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) - commonKey := NewKey(250, []byte{0x3}) - parentNode, pErr := tempTrie.storage.Get(&commonKey) + commonKey := NewBitArray(250, 3) + parentNode, pErr := tempTrie.storage.Get(commonKey) require.NoError(t, pErr) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) assert.Equal(t, tempTrie.FeltToKey(rightKey), *parentNode.Right) @@ -166,15 +167,15 @@ func TestTrieKeys(t *testing.T) { _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) - commonKey := NewKey(248, []byte{}) - parentNode, err := tempTrie.storage.Get(&commonKey) + commonKey := NewBitArray(248, 0) + parentNode, err := tempTrie.storage.Get(commonKey) require.NoError(t, err) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) - expectRightKey := NewKey(249, []byte{0x1}) + expectRightKey := NewBitArray(249, 1) - assert.Equal(t, expectRightKey, *parentNode.Right) + assert.Equal(t, expectRightKey, parentNode.Right) }) }) } @@ -239,11 +240,11 @@ func TestTrieKeysAfterDeleteSubtree(t *testing.T) { _, err = tempTrie.Put(test.deleteKey, zeroVal) require.NoError(t, err) - newRootKey := NewKey(251-2, []byte{0x1}) + newRootKey := NewBitArray(249, 1) - assert.Equal(t, newRootKey, *tempTrie.rootKey) + assert.Equal(t, newRootKey, tempTrie.rootKey) - rootNode, err := tempTrie.storage.Get(&newRootKey) + rootNode, err := tempTrie.storage.Get(newRootKey) require.NoError(t, err) assert.Equal(t, tempTrie.FeltToKey(rightKey), *rootNode.Right) diff --git a/core/trie/trie_test.go b/core/trie/trie_test.go index 51d589ab63..7384bf558b 100644 --- a/core/trie/trie_test.go +++ b/core/trie/trie_test.go @@ -164,7 +164,81 @@ func TestPutZero(t *testing.T) { var keys []*felt.Felt // put random 64 keys and record roots - for range 64 { + for i := 0; i < 64; i++ { + key, value := new(felt.Felt), new(felt.Felt) + + _, err = key.SetRandom() + require.NoError(t, err) + + t.Logf("key: %s", key.String()) + + _, err = value.SetRandom() + require.NoError(t, err) + + t.Logf("value: %s", value.String()) + + _, err = tempTrie.Put(key, value) + require.NoError(t, err) + + keys = append(keys, key) + + var root *felt.Felt + root, err = tempTrie.Root() + require.NoError(t, err) + + roots = append(roots, root) + } + + t.Run("adding a zero value to a non-existent key should not change Trie", func(t *testing.T) { + var key, root *felt.Felt + key, err = new(felt.Felt).SetRandom() + require.NoError(t, err) + + _, err = tempTrie.Put(key, new(felt.Felt)) + require.NoError(t, err) + + root, err = tempTrie.Root() + require.NoError(t, err) + + assert.Equal(t, true, root.Equal(roots[len(roots)-1])) + }) + + t.Run("remove keys one by one, check roots", func(t *testing.T) { + var gotRoot *felt.Felt + // put zero in reverse order and check roots still match + for i := range 64 { + root := roots[len(roots)-1-i] + + gotRoot, err = tempTrie.Root() + require.NoError(t, err) + + assert.Equal(t, root, gotRoot) + + key := keys[len(keys)-1-i] + _, err = tempTrie.Put(key, new(felt.Felt)) + require.NoError(t, err) + } + }) + + t.Run("empty roots should match", func(t *testing.T) { + actualEmptyRoot, err := tempTrie.Root() + require.NoError(t, err) + + assert.Equal(t, true, actualEmptyRoot.Equal(emptyRoot)) + }) + return nil + })) +} + +func TestTrie(t *testing.T) { + require.NoError(t, trie.RunOnTempTriePedersen(251, func(tempTrie *trie.Trie) error { + emptyRoot, err := tempTrie.Root() + require.NoError(t, err) + var roots []*felt.Felt + var keys []*felt.Felt + + // put random 64 keys and record roots + for i := 0; i < 64; i++ { key, value := new(felt.Felt), new(felt.Felt) _, err = key.SetRandom() diff --git a/migration/migration.go b/migration/migration.go index 107bd40f10..97ce613f58 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -511,7 +511,7 @@ func calculateL1MsgHashes(txn db.Transaction, n *utils.Network) error { return processBlocks(txn, processBlockFunc) } -func bitset2Key(bs *bitset.BitSet) *trie.Key { +func bitset2BitArray(bs *bitset.BitSet) *trie.BitArray { bsWords := bs.Words() if len(bsWords) > felt.Limbs { panic("key too long to fit in Felt") @@ -524,9 +524,7 @@ func bitset2Key(bs *bitset.BitSet) *trie.Key { } f := new(felt.Felt).SetBytes(bsBytes[:]) - fBytes := f.Bytes() - k := trie.NewKey(uint8(bs.Len()), fBytes[:]) - return &k + return new(trie.BitArray).SetFelt(uint8(bs.Len()), f) } func migrateTrieRootKeysFromBitsetToTrieKeys(txn db.Transaction, key, value []byte, _ *utils.Network) error { @@ -535,8 +533,8 @@ func migrateTrieRootKeysFromBitsetToTrieKeys(txn db.Transaction, key, value []by if err := bs.UnmarshalBinary(value); err != nil { return err } - trieKey := bitset2Key(&bs) - _, err := trieKey.WriteTo(&tempBuf) + trieKey := bitset2BitArray(&bs) + _, err := trieKey.Write(&tempBuf) if err != nil { return err } @@ -574,8 +572,8 @@ func migrateTrieNodesFromBitsetToTrieKey(target db.Bucket) BucketMigratorDoFunc Value: n.Value, } if n.Left != nil { - trieNode.Left = bitset2Key(n.Left) - trieNode.Right = bitset2Key(n.Right) + trieNode.Left = bitset2BitArray(n.Left) + trieNode.Right = bitset2BitArray(n.Right) } if _, err := trieNode.WriteTo(&tempBuf); err != nil { @@ -594,7 +592,7 @@ func migrateTrieNodesFromBitsetToTrieKey(target db.Bucket) BucketMigratorDoFunc } var keyBuffer bytes.Buffer - if _, err := bitset2Key(&bs).WriteTo(&keyBuffer); err != nil { + if _, err := bitset2BitArray(&bs).Write(&keyBuffer); err != nil { return err } diff --git a/migration/migration_pkg_test.go b/migration/migration_pkg_test.go index e2d5613c48..688643386c 100644 --- a/migration/migration_pkg_test.go +++ b/migration/migration_pkg_test.go @@ -260,8 +260,11 @@ func TestMigrateTrieRootKeysFromBitsetToTrieKeys(t *testing.T) { require.NoError(t, migrateTrieRootKeysFromBitsetToTrieKeys(memTxn, key, bsBytes, &utils.Mainnet)) - var trieKey trie.Key - err = memTxn.Get(key, trieKey.UnmarshalBinary) + var trieKey trie.BitArray + err = memTxn.Get(key, func(data []byte) error { + trieKey.UnmarshalBinary(data) + return nil + }) require.NoError(t, err) require.Equal(t, bs.Len(), uint(trieKey.Len())) require.Equal(t, felt.Zero, trieKey.Felt()) @@ -357,7 +360,7 @@ func TestMigrateCairo1CompiledClass(t *testing.T) { } } -func TestMigrateTrieNodesFromBitsetToTrieKey(t *testing.T) { +func TestMigrateTrieNodesFromBitsetToBitArray(t *testing.T) { migrator := migrateTrieNodesFromBitsetToTrieKey(db.ClassesTrie) memTxn := db.NewMemTransaction() @@ -388,9 +391,9 @@ func TestMigrateTrieNodesFromBitsetToTrieKey(t *testing.T) { require.ErrorIs(t, err, db.ErrKeyNotFound) var nodeKeyBuf bytes.Buffer - newNodeKey := bitset2Key(bs) - wrote, err = newNodeKey.WriteTo(&nodeKeyBuf) - require.True(t, wrote > 0) + newNodeKey := bitset2BitArray(bs) + bWrite, err := newNodeKey.Write(&nodeKeyBuf) + require.True(t, bWrite > 0) require.NoError(t, err) var trieNode trie.Node From 4915a23e0b612fb5cd7be334930b88176a072adf Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 18 Dec 2024 10:35:09 +0800 Subject: [PATCH 20/30] improve comments --- core/trie/bitarray.go | 107 ++++++++++++++++++++----------------- core/trie/bitarray_test.go | 3 ++ 2 files changed, 60 insertions(+), 50 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 11a55edae5..a2ee47b05d 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -16,12 +16,9 @@ const ( bits8 = 8 ) -var ( - maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} - emptyBitArray = new(BitArray) -) +var emptyBitArray = new(BitArray) -// BitArray is a structure that represents a bit array with length representing the number of used bits. +// Represents a bit array with length representing the number of used bits. // It uses a little endian representation to do bitwise operations of the words efficiently. // For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. // The max length is 255 bits (uint8), because our use case only need up to 251 bits for a given trie key. @@ -35,6 +32,7 @@ func NewBitArray(length uint8, val uint64) *BitArray { return new(BitArray).SetUint64(length, val) } +// Returns the felt representation of the bit array. func (b *BitArray) Felt() felt.Felt { var f felt.Felt f.SetBytes(b.Bytes()) @@ -45,7 +43,7 @@ func (b *BitArray) Len() uint8 { return b.len } -// Bytes returns the bytes representation of the bit array in big endian format +// Returns the bytes representation of the bit array in big endian format // //nolint:mnd func (b *BitArray) Bytes() []byte { @@ -83,42 +81,7 @@ func (b *BitArray) Bytes() []byte { return res[:] } -// EqualMSBs checks if two bit arrays share the same most significant bits, where the length of -// the check is determined by the shorter array. Returns true if either array has -// length 0, or if the first min(b.len, x.len) MSBs are identical. -// -// For example: -// -// a = 1101 (len=4) -// b = 11010111 (len=8) -// a.EqualMSBs(b) = true // First 4 MSBs match -// -// a = 1100 (len=4) -// b = 1101 (len=4) -// a.EqualMSBs(b) = false // All bits compared, not equal -// -// a = 1100 (len=4) -// b = [] (len=0) -// a.EqualMSBs(b) = true // Zero length is always a prefix match -func (b *BitArray) EqualMSBs(x *BitArray) bool { - if b.len == x.len { - return b.Equal(x) - } - - if b.len == 0 || x.len == 0 { - return true - } - - // Compare only the first min(b.len, x.len) bits - minLen := b.len - if x.len < minLen { - minLen = x.len - } - - return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) -} - -// LSBs sets b to the least significant 'n' bits of x. +// Sets b to the least significant 'n' bits of x. // If n >= x.len, b is an exact copy of x. // Any bits beyond the specified length are cleared to zero. // For example: @@ -164,7 +127,42 @@ func (b *BitArray) LSBs(x *BitArray, length uint8) *BitArray { return b } -// MSBs sets b to the most significant 'n' bits of x. +// Checks if the current bit array share the same most significant bits with another, where the length of +// the check is determined by the shorter array. Returns true if either array has +// length 0, or if the first min(b.len, x.len) MSBs are identical. +// +// For example: +// +// a = 1101 (len=4) +// b = 11010111 (len=8) +// a.EqualMSBs(b) = true // First 4 MSBs match +// +// a = 1100 (len=4) +// b = 1101 (len=4) +// a.EqualMSBs(b) = false // All bits compared, not equal +// +// a = 1100 (len=4) +// b = [] (len=0) +// a.EqualMSBs(b) = true // Zero length is always a prefix match +func (b *BitArray) EqualMSBs(x *BitArray) bool { + if b.len == x.len { + return b.Equal(x) + } + + if b.len == 0 || x.len == 0 { + return true + } + + // Compare only the first min(b.len, x.len) bits + minLen := b.len + if x.len < minLen { + minLen = x.len + } + + return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) +} + +// Sets b to the most significant 'n' bits of x. // If n >= x.len, b is an exact copy of x. // Any bits beyond the specified length are cleared to zero. // For example: @@ -181,7 +179,7 @@ func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { return b.Rsh(x, x.len-n) } -// CommonMSBs sets b to the longest sequence of matching most significant bits between two bit arrays. +// Sets b to the longest sequence of matching most significant bits between two bit arrays. // For example: // // x = 1101 0111 (len=8) @@ -219,7 +217,7 @@ func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { return b.Rsh(short, divergentBit) } -// Rsh sets b = x >> n and returns b. +// Sets b = x >> n and returns b. // //nolint:mnd func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { @@ -264,7 +262,7 @@ func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { return b } -// Xor sets b = x ^ y and returns b. +// Sets b = x ^ y and returns b. func (b *BitArray) Xor(x, y *BitArray) *BitArray { b.words[0] = x.words[0] ^ y.words[0] b.words[1] = x.words[1] ^ y.words[1] @@ -273,7 +271,7 @@ func (b *BitArray) Xor(x, y *BitArray) *BitArray { return b } -// Eq checks if two bit arrays are equal +// Checks if two bit arrays are equal func (b *BitArray) Equal(x *BitArray) bool { // TODO(weiihann): this is really not a good thing to do... if b == nil && x == nil { @@ -289,7 +287,7 @@ func (b *BitArray) Equal(x *BitArray) bool { b.words[3] == x.words[3] } -// IsBitSit returns true if bit n-th is set, where n = 0 is LSB. +// Returns true if bit n-th is set, where n = 0 is LSB. // The n must be <= 255. func (b *BitArray) IsBitSet(n uint8) bool { if n >= b.len { @@ -299,7 +297,7 @@ func (b *BitArray) IsBitSet(n uint8) bool { return (b.words[n/64] & (1 << (n % 64))) != 0 } -// Write serialises the BitArray into a bytes buffer in the following format: +// Serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order // Example: @@ -314,7 +312,7 @@ func (b *BitArray) Write(buf *bytes.Buffer) (int, error) { return n + 1, err } -// UnmarshalBinary deserialises the BitArray from a bytes buffer in the following format: +// Deserialises the BitArray from a bytes buffer in the following format: // - First byte: length of the bit array (0-255) // - Remaining bytes: the necessary bytes included in big endian order // Example: @@ -328,6 +326,7 @@ func (b *BitArray) UnmarshalBinary(data []byte) { b.setBytes32(bs[:]) } +// Sets b to the same value as x. func (b *BitArray) Set(x *BitArray) *BitArray { b.len = x.len b.words[0] = x.words[0] @@ -337,40 +336,48 @@ func (b *BitArray) Set(x *BitArray) *BitArray { return b } +// Sets b to the bytes representation of a felt. func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { b.setFelt(f) b.len = length return b } +// Sets b to the bytes representation of a felt with length 251. func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { b.setFelt(f) b.len = 251 return b } +// Interprets the data as the big-endian bytes, sets b to that value and returns b. +// If the data is larger than 32 bytes, only the first 32 bytes are used. func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { b.setBytes32(data) b.len = length return b } +// Sets b to the uint64 representation of a bit array. func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { b.words[0] = data b.len = length return b } +// Returns the length of the encoded bit array in bytes. func (b *BitArray) EncodedLen() uint { return b.byteCount() + 1 } +// Returns a deep copy of b. func (b *BitArray) Copy() BitArray { var res BitArray res.Set(b) return res } +// Returns a string representation of the bit array. func (b *BitArray) String() string { return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes())) } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index c90223ab6a..479df49fd1 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -3,6 +3,7 @@ package trie import ( "bytes" "encoding/binary" + "math" "math/bits" "testing" @@ -11,6 +12,8 @@ import ( "github.com/stretchr/testify/require" ) +var maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} + const ( ones63 = 0x7FFFFFFFFFFFFFFF ) From 7e037d5611b90a05c30ca579bc8f2267ceb95501 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 18 Dec 2024 10:49:33 +0800 Subject: [PATCH 21/30] fix lint --- core/trie/node.go | 4 +++- core/trie/trie.go | 8 +++++++- core/trie/trie_pkg_test.go | 2 -- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/core/trie/node.go b/core/trie/node.go index 47f43e6eab..171a14385a 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -156,7 +156,9 @@ func (n *Node) Update(other *Node) error { return fmt.Errorf("conflicting Left keys: %v != %v", n.Left, other.Left) } - if n.Right != nil && other.Right != nil && !n.Right.Equal(emptyBitArray) && !other.Right.Equal(emptyBitArray) && !n.Right.Equal(other.Right) { + if n.Right != nil && other.Right != nil && + !n.Right.Equal(emptyBitArray) && !other.Right.Equal(emptyBitArray) && + !n.Right.Equal(other.Right) { return fmt.Errorf("conflicting Right keys: %v != %v", n.Right, other.Right) } diff --git a/core/trie/trie.go b/core/trie/trie.go index 02d930bd4f..f0b51cabe0 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -327,7 +327,13 @@ func (t *Trie) replaceLinkWithNewParent(key *BitArray, commonKey BitArray, sibli } // TODO(weiihann): not a good idea to couple proof verification logic with trie logic -func (t *Trie) insertOrUpdateValue(nodeKey *BitArray, node *Node, nodes []StorageNode, sibling StorageNode, siblingIsParentProof bool) error { +func (t *Trie) insertOrUpdateValue( + nodeKey *BitArray, + node *Node, + nodes []StorageNode, + sibling StorageNode, + siblingIsParentProof bool, +) error { var commonKey BitArray commonKey.CommonMSBs(nodeKey, sibling.key) diff --git a/core/trie/trie_pkg_test.go b/core/trie/trie_pkg_test.go index 533450a687..04037c4d7f 100644 --- a/core/trie/trie_pkg_test.go +++ b/core/trie/trie_pkg_test.go @@ -135,7 +135,6 @@ func TestTrieKeys(t *testing.T) { require.NoError(t, err) newVal := new(felt.Felt).SetUint64(12) - //nolint: dupl t.Run("Add to left branch", func(t *testing.T) { newKey := new(felt.Felt).SetUint64(0b101) _, err = tempTrie.Put(newKey, newVal) @@ -146,7 +145,6 @@ func TestTrieKeys(t *testing.T) { assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Right) }) - //nolint: dupl t.Run("Add to right branch", func(t *testing.T) { newKey := new(felt.Felt).SetUint64(0b110) _, err = tempTrie.Put(newKey, newVal) From 8625dc4d017e9604f1e8e86e12e95d2a380e0c90 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 18 Dec 2024 11:04:28 +0800 Subject: [PATCH 22/30] minor chore --- core/trie/bitarray.go | 6 +++--- core/trie/bitarray_test.go | 26 +++++++++++++------------- core/trie/proof_test.go | 24 ------------------------ core/trie/trie.go | 1 + 4 files changed, 17 insertions(+), 40 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index a2ee47b05d..3762d844ce 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -83,7 +83,6 @@ func (b *BitArray) Bytes() []byte { // Sets b to the least significant 'n' bits of x. // If n >= x.len, b is an exact copy of x. -// Any bits beyond the specified length are cleared to zero. // For example: // // x = 11001011 (len=8) @@ -164,7 +163,6 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { // Sets b to the most significant 'n' bits of x. // If n >= x.len, b is an exact copy of x. -// Any bits beyond the specified length are cleared to zero. // For example: // // x = 11001011 (len=8) @@ -406,7 +404,7 @@ func (b *BitArray) byteCount() uint { } // activeBytes returns a slice containing only the bytes that are actually used -// by the bit array, excluding leading zero bytes. The returned slice is in +// by the bit array, as specified by the length. The returned slice is in // big-endian order. // // Example: @@ -448,11 +446,13 @@ func findFirstSetBit(b *BitArray) uint8 { return 0 } + // Start from the most significant and move towards the least significant for i := 3; i >= 0; i-- { if word := b.words[i]; word != 0 { return uint8((i+1)*64 - bits.LeadingZeros64(word)) } } + // All bits are zero, no set bit found return 0 } diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 479df49fd1..4c57794b06 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -15,7 +15,7 @@ import ( var maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} const ( - ones63 = 0x7FFFFFFFFFFFFFFF + ones63 = 0x7FFFFFFFFFFFFFFF // 63 bits of 1 ) func TestBytes(t *testing.T) { @@ -231,7 +231,7 @@ func TestRsh(t *testing.T) { } } -func TestPrefixEqual(t *testing.T) { +func TestEqualMSBs(t *testing.T) { tests := []struct { name string a *BitArray @@ -357,7 +357,7 @@ func TestLSBs(t *testing.T) { expected BitArray }{ { - name: "truncate to zero", + name: "zero", initial: BitArray{ len: 64, words: [4]uint64{maxUint64, 0, 0, 0}, @@ -369,7 +369,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate within first word - 32 bits", + name: "get 32 LSBs", initial: BitArray{ len: 64, words: [4]uint64{maxUint64, 0, 0, 0}, @@ -381,7 +381,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate to single bit", + name: "get 1 LSB", initial: BitArray{ len: 64, words: [4]uint64{maxUint64, 0, 0, 0}, @@ -389,11 +389,11 @@ func TestLSBs(t *testing.T) { length: 1, expected: BitArray{ len: 1, - words: [4]uint64{0x0000000000000001, 0, 0, 0}, + words: [4]uint64{0x1, 0, 0, 0}, }, }, { - name: "truncate across words - 100 bits", + name: "get 100 LSBs across words", initial: BitArray{ len: 128, words: [4]uint64{maxUint64, maxUint64, 0, 0}, @@ -405,7 +405,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate at word boundary - 64 bits", + name: "get 64 LSBs at word boundary", initial: BitArray{ len: 128, words: [4]uint64{maxUint64, maxUint64, 0, 0}, @@ -417,7 +417,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate at word boundary - 128 bits", + name: "get 128 LSBs at word boundary", initial: BitArray{ len: 192, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, @@ -429,7 +429,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate in third word - 150 bits", + name: "get 150 LSBs in third word", initial: BitArray{ len: 192, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, @@ -441,7 +441,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate in fourth word - 220 bits", + name: "get 220 LSBs in fourth word", initial: BitArray{ len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, @@ -453,7 +453,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate max length - 251 bits", + name: "get 251 LSBs", initial: BitArray{ len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, @@ -465,7 +465,7 @@ func TestLSBs(t *testing.T) { }, }, { - name: "truncate sparse bits", + name: "get 100 LSBs from sparse bits", initial: BitArray{ len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index 0f9c54543a..046b1b1bca 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -13,30 +13,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestFix(t *testing.T) { - numKeys := 1000 - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) - - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) - require.NoError(t, err) - - records := make([]*keyValue, numKeys) - for i := 1; i < numKeys+1; i++ { - key := new(felt.Felt).SetUint64(uint64(i)) - records[i-1] = &keyValue{key: key, value: key} - _, err := tempTrie.Put(key, key) - require.NoError(t, err) - } - - sort.Slice(records, func(i, j int) bool { - return records[i].key.Cmp(records[j].key) < 0 - }) - - require.NoError(t, tempTrie.Commit()) -} - func TestProve(t *testing.T) { t.Parallel() diff --git a/core/trie/trie.go b/core/trie/trie.go index f0b51cabe0..7eaa46dd7f 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -108,6 +108,7 @@ func (t *Trie) FeltToKey(k *felt.Felt) BitArray { // path is suffix of key that diverges from parentKey. For example, // for a key 0b1011 and parentKey 0b10, this function would return the path object of 0b0. func path(key, parentKey *BitArray) BitArray { + // drop parent key, and one more MSB since left/right relation already encodes that information if parentKey == nil { return key.Copy() } From 1f69b43528bba61f86a3fbfbf4f4f96092c61e93 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 19 Dec 2024 12:53:03 +0800 Subject: [PATCH 23/30] improvements --- core/state.go | 2 +- core/trie/bitarray.go | 73 ++++++++++++++++++-------------------- core/trie/node_test.go | 2 +- core/trie/proof.go | 1 + core/trie/storage_test.go | 16 ++++----- core/trie/trie_pkg_test.go | 18 +++++----- 6 files changed, 55 insertions(+), 57 deletions(-) diff --git a/core/state.go b/core/state.go index c17ff13f3e..27c20f0572 100644 --- a/core/state.go +++ b/core/state.go @@ -139,7 +139,7 @@ func (s *State) globalTrie(bucket db.Bucket, newTrie trie.NewTrieFunc) (*trie.Tr // fetch root key rootKeyDBKey := dbPrefix - var rootKey *trie.BitArray + var rootKey *trie.BitArray // TODO: use value instead of pointer err := s.txn.Get(rootKeyDBKey, func(val []byte) error { rootKey = new(trie.BitArray) rootKey.UnmarshalBinary(val) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 3762d844ce..7f8d5481a0 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -11,10 +11,7 @@ import ( "github.com/NethermindEth/juno/core/felt" ) -const ( - maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF - bits8 = 8 -) +const maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF var emptyBitArray = new(BitArray) @@ -28,8 +25,10 @@ type BitArray struct { words [4]uint64 // little endian (i.e. words[0] is the least significant) } -func NewBitArray(length uint8, val uint64) *BitArray { - return new(BitArray).SetUint64(length, val) +func NewBitArray(length uint8, val uint64) BitArray { + var b BitArray + b.SetUint64(length, val) + return b } // Returns the felt representation of the bit array. @@ -81,8 +80,8 @@ func (b *BitArray) Bytes() []byte { return res[:] } -// Sets b to the least significant 'n' bits of x. -// If n >= x.len, b is an exact copy of x. +// Sets the bit array to the least significant 'n' bits of x. +// If length >= x.len, the bit array is an exact copy of x. // For example: // // x = 11001011 (len=8) @@ -91,35 +90,35 @@ func (b *BitArray) Bytes() []byte { // LSBs(x, 0) = 0 (len=0) // //nolint:mnd -func (b *BitArray) LSBs(x *BitArray, length uint8) *BitArray { - if length >= x.len { +func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { return b.Set(x) } b.Set(x) - b.len = length + b.len = n // Clear all words beyond what's needed switch { - case length == 0: + case n == 0: b.words = [4]uint64{0, 0, 0, 0} - case length <= 64: - mask := maxUint64 >> (64 - length) + case n <= 64: + mask := maxUint64 >> (64 - n) b.words[0] &= mask b.words[1] = 0 b.words[2] = 0 b.words[3] = 0 - case length <= 128: - mask := maxUint64 >> (128 - length) + case n <= 128: + mask := maxUint64 >> (128 - n) b.words[1] &= mask b.words[2] = 0 b.words[3] = 0 - case length <= 192: - mask := maxUint64 >> (192 - length) + case n <= 192: + mask := maxUint64 >> (192 - n) b.words[2] &= mask b.words[3] = 0 default: - mask := maxUint64 >> (256 - uint16(length)) + mask := maxUint64 >> (256 - uint16(n)) b.words[3] &= mask } @@ -161,8 +160,8 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) } -// Sets b to the most significant 'n' bits of x. -// If n >= x.len, b is an exact copy of x. +// Sets the bit array to the most significant 'n' bits of x. +// If n >= x.len, the bit array is an exact copy of x. // For example: // // x = 11001011 (len=8) @@ -177,7 +176,7 @@ func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { return b.Rsh(x, x.len-n) } -// Sets b to the longest sequence of matching most significant bits between two bit arrays. +// Sets the bit array to the longest sequence of matching most significant bits between two bit arrays. // For example: // // x = 1101 0111 (len=8) @@ -185,7 +184,7 @@ func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { // CommonMSBs(x,y) = 1101 (len=4) func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { if x.len == 0 || y.len == 0 { - return emptyBitArray + return b.clear() } long, short := x, y @@ -215,7 +214,7 @@ func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { return b.Rsh(short, divergentBit) } -// Sets b = x >> n and returns b. +// Sets the bit array to x >> n and returns the bit array. // //nolint:mnd func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { @@ -260,7 +259,7 @@ func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { return b } -// Sets b = x ^ y and returns b. +// Sets the bit array to x ^ y and returns the bit array. func (b *BitArray) Xor(x, y *BitArray) *BitArray { b.words[0] = x.words[0] ^ y.words[0] b.words[1] = x.words[1] ^ y.words[1] @@ -324,7 +323,7 @@ func (b *BitArray) UnmarshalBinary(data []byte) { b.setBytes32(bs[:]) } -// Sets b to the same value as x. +// Sets the bit array to the same value as x. func (b *BitArray) Set(x *BitArray) *BitArray { b.len = x.len b.words[0] = x.words[0] @@ -334,21 +333,21 @@ func (b *BitArray) Set(x *BitArray) *BitArray { return b } -// Sets b to the bytes representation of a felt. +// Sets the bit array to the bytes representation of a felt. func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { b.setFelt(f) b.len = length return b } -// Sets b to the bytes representation of a felt with length 251. +// Sets the bit array to the bytes representation of a felt with length 251. func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { b.setFelt(f) b.len = 251 return b } -// Interprets the data as the big-endian bytes, sets b to that value and returns b. +// Interprets the data as the big-endian bytes, sets the bit array to that value and returns it. // If the data is larger than 32 bytes, only the first 32 bytes are used. func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { b.setBytes32(data) @@ -356,7 +355,7 @@ func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { return b } -// Sets b to the uint64 representation of a bit array. +// Sets the bit array to the uint64 representation of a bit array. func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { b.words[0] = data b.len = length @@ -368,7 +367,7 @@ func (b *BitArray) EncodedLen() uint { return b.byteCount() + 1 } -// Returns a deep copy of b. +// Returns a deep copy of the bit array. func (b *BitArray) Copy() BitArray { var res BitArray res.Set(b) @@ -396,16 +395,16 @@ func (b *BitArray) setBytes32(data []byte) { b.words[0] = binary.BigEndian.Uint64(data[24:32]) } -// byteCount returns the minimum number of bytes needed to represent the bit array. +// Returns the minimum number of bytes needed to represent the bit array. // It rounds up to the nearest byte. func (b *BitArray) byteCount() uint { + const bits8 = 8 // Cast to uint16 to avoid overflow return (uint(b.len) + (bits8 - 1)) / uint(bits8) } -// activeBytes returns a slice containing only the bytes that are actually used -// by the bit array, as specified by the length. The returned slice is in -// big-endian order. +// Returns a slice containing only the bytes that are actually used by the bit array, +// as specified by the length. The returned slice is in big-endian order. // // Example: // @@ -433,9 +432,7 @@ func (b *BitArray) clear() *BitArray { return b } -// findFirstSetBit returns the position of the first '1' bit in the array, -// scanning from most significant to least significant bit. -// +// Returns the position of the first '1' bit in the array, scanning from most significant to least significant bit. // The bit position is counted from the least significant bit, starting at 0. // For example: // diff --git a/core/trie/node_test.go b/core/trie/node_test.go index b222732f4b..cc1bb06eda 100644 --- a/core/trie/node_test.go +++ b/core/trie/node_test.go @@ -24,5 +24,5 @@ func TestNodeHash(t *testing.T) { } path := trie.NewBitArray(6, 42) - assert.Equal(t, expected, node.Hash(path, crypto.Pedersen), "TestTrieNode_Hash failed") + assert.Equal(t, expected, node.Hash(&path, crypto.Pedersen), "TestTrieNode_Hash failed") } diff --git a/core/trie/proof.go b/core/trie/proof.go index ff16371c06..894991a8a9 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -48,6 +48,7 @@ func (e *Edge) Hash(hash hashFunc) *felt.Felt { length[31] = e.Path.len pathFelt := e.Path.Felt() lengthFelt := new(felt.Felt).SetBytes(length[:]) + // TODO: no need to return reference, just return value to avoid heap allocation return new(felt.Felt).Add(hash(e.Child, &pathFelt), lengthFelt) } diff --git a/core/trie/storage_test.go b/core/trie/storage_test.go index 37a4e8e447..21302f1308 100644 --- a/core/trie/storage_test.go +++ b/core/trie/storage_test.go @@ -27,7 +27,7 @@ func TestStorage(t *testing.T) { t.Run("put a node", func(t *testing.T) { require.NoError(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - return tTxn.Put(key, node) + return tTxn.Put(&key, node) })) }) @@ -35,7 +35,7 @@ func TestStorage(t *testing.T) { require.NoError(t, testDB.View(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) var got *trie.Node - got, err = tTxn.Get(key) + got, err = tTxn.Get(&key) require.NoError(t, err) assert.Equal(t, node, got) return err @@ -46,7 +46,7 @@ func TestStorage(t *testing.T) { // Successfully delete a node and return an error to force a roll back. require.Error(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - err = tTxn.Delete(key) + err = tTxn.Delete(&key) require.NoError(t, err) return errors.New("should rollback") })) @@ -56,7 +56,7 @@ func TestStorage(t *testing.T) { require.NoError(t, testDB.View(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) var got *trie.Node - got, err = tTxn.Get(key) + got, err = tTxn.Get(&key) assert.Equal(t, node, got) return err })) @@ -66,13 +66,13 @@ func TestStorage(t *testing.T) { // Delete a node. require.NoError(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - return tTxn.Delete(key) + return tTxn.Delete(&key) })) // Node should no longer exist in the database. require.EqualError(t, testDB.View(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - _, err = tTxn.Get(key) + _, err = tTxn.Get(&key) return err }), db.ErrKeyNotFound.Error()) }) @@ -82,7 +82,7 @@ func TestStorage(t *testing.T) { t.Run("put root key", func(t *testing.T) { require.NoError(t, testDB.Update(func(txn db.Transaction) error { tTxn := trie.NewStorage(txn, prefix) - return tTxn.PutRootKey(rootKey) + return tTxn.PutRootKey(&rootKey) })) }) @@ -91,7 +91,7 @@ func TestStorage(t *testing.T) { tTxn := trie.NewStorage(txn, prefix) gotRootKey, err := tTxn.RootKey() require.NoError(t, err) - assert.Equal(t, rootKey, gotRootKey) + assert.Equal(t, &rootKey, gotRootKey) return nil })) }) diff --git a/core/trie/trie_pkg_test.go b/core/trie/trie_pkg_test.go index 04037c4d7f..d9d13b1e4c 100644 --- a/core/trie/trie_pkg_test.go +++ b/core/trie/trie_pkg_test.go @@ -62,10 +62,10 @@ func TestTrieKeys(t *testing.T) { // expectKey := NewKey(251-2, []byte{0x4}) expectKey := NewBitArray(249, 4) - assert.Equal(t, expectKey, &commonKey) + assert.Equal(t, expectKey, commonKey) // Current rootKey should be the common key - assert.Equal(t, expectKey, tempTrie.rootKey) + assert.Equal(t, &expectKey, tempTrie.rootKey) parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) @@ -104,7 +104,7 @@ func TestTrieKeys(t *testing.T) { expectKey := NewBitArray(249, 4) - assert.Equal(t, expectKey, &commonKey) + assert.Equal(t, &expectKey, &commonKey) parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) @@ -140,7 +140,7 @@ func TestTrieKeys(t *testing.T) { _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) commonKey := NewBitArray(250, 2) - parentNode, pErr := tempTrie.storage.Get(commonKey) + parentNode, pErr := tempTrie.storage.Get(&commonKey) require.NoError(t, pErr) assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Right) @@ -150,7 +150,7 @@ func TestTrieKeys(t *testing.T) { _, err = tempTrie.Put(newKey, newVal) require.NoError(t, err) commonKey := NewBitArray(250, 3) - parentNode, pErr := tempTrie.storage.Get(commonKey) + parentNode, pErr := tempTrie.storage.Get(&commonKey) require.NoError(t, pErr) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) assert.Equal(t, tempTrie.FeltToKey(rightKey), *parentNode.Right) @@ -166,14 +166,14 @@ func TestTrieKeys(t *testing.T) { require.NoError(t, err) commonKey := NewBitArray(248, 0) - parentNode, err := tempTrie.storage.Get(commonKey) + parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) expectRightKey := NewBitArray(249, 1) - assert.Equal(t, expectRightKey, parentNode.Right) + assert.Equal(t, &expectRightKey, parentNode.Right) }) }) } @@ -240,9 +240,9 @@ func TestTrieKeysAfterDeleteSubtree(t *testing.T) { newRootKey := NewBitArray(249, 1) - assert.Equal(t, newRootKey, tempTrie.rootKey) + assert.Equal(t, &newRootKey, tempTrie.rootKey) - rootNode, err := tempTrie.storage.Get(newRootKey) + rootNode, err := tempTrie.storage.Get(&newRootKey) require.NoError(t, err) assert.Equal(t, tempTrie.FeltToKey(rightKey), *rootNode.Right) From d50aacfba9e9e5e86824faadab932e7c3442444d Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 20 Dec 2024 11:56:48 +0800 Subject: [PATCH 24/30] ensure unused bits are zero when setting bitarray --- core/trie/bitarray.go | 64 +++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 7f8d5481a0..2abfc9be27 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -43,39 +43,14 @@ func (b *BitArray) Len() uint8 { } // Returns the bytes representation of the bit array in big endian format -// -//nolint:mnd func (b *BitArray) Bytes() []byte { var res [32]byte - switch { - case b.len == 0: - // all zeros - return res[:] - case b.len >= 192: - // Create mask for top word: keeps only valid bits above 192 - // e.g., if len=200, keeps lowest 8 bits (200-192) - mask := maxUint64 >> (256 - uint16(b.len)) - binary.BigEndian.PutUint64(res[0:8], b.words[3]&mask) - binary.BigEndian.PutUint64(res[8:16], b.words[2]) - binary.BigEndian.PutUint64(res[16:24], b.words[1]) - binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.len >= 128: - // Mask for bits 128-191: keeps only valid bits above 128 - // e.g., if len=150, keeps lowest 22 bits (150-128) - mask := maxUint64 >> (192 - b.len) - binary.BigEndian.PutUint64(res[8:16], b.words[2]&mask) - binary.BigEndian.PutUint64(res[16:24], b.words[1]) - binary.BigEndian.PutUint64(res[24:32], b.words[0]) - case b.len >= 64: - // You get the idea - mask := maxUint64 >> (128 - b.len) - binary.BigEndian.PutUint64(res[16:24], b.words[1]&mask) - binary.BigEndian.PutUint64(res[24:32], b.words[0]) - default: - mask := maxUint64 >> (64 - b.len) - binary.BigEndian.PutUint64(res[24:32], b.words[0]&mask) - } + b.truncateToLength() + binary.BigEndian.PutUint64(res[0:8], b.words[3]) + binary.BigEndian.PutUint64(res[8:16], b.words[2]) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) return res[:] } @@ -335,15 +310,17 @@ func (b *BitArray) Set(x *BitArray) *BitArray { // Sets the bit array to the bytes representation of a felt. func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { - b.setFelt(f) b.len = length + b.setFelt(f) + b.truncateToLength() return b } // Sets the bit array to the bytes representation of a felt with length 251. func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { - b.setFelt(f) b.len = 251 + b.setFelt(f) + b.truncateToLength() return b } @@ -352,6 +329,7 @@ func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { b.setBytes32(data) b.len = length + b.truncateToLength() return b } @@ -359,6 +337,7 @@ func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { b.words[0] = data b.len = length + b.truncateToLength() return b } @@ -432,6 +411,27 @@ func (b *BitArray) clear() *BitArray { return b } +// Truncates the bit array to the specified length, ensuring that any unused bits are all zeros. +// +//nolint:mnd +func (b *BitArray) truncateToLength() { + switch { + case b.len == 0: + b.words = [4]uint64{0, 0, 0, 0} + case b.len <= 64: + b.words[0] &= maxUint64 >> (64 - b.len) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + case b.len <= 128: + b.words[1] &= maxUint64 >> (128 - b.len) + b.words[2], b.words[3] = 0, 0 + case b.len <= 192: + b.words[2] &= maxUint64 >> (192 - b.len) + b.words[3] = 0 + default: + b.words[3] &= maxUint64 >> (256 - uint16(b.len)) + } +} + // Returns the position of the first '1' bit in the array, scanning from most significant to least significant bit. // The bit position is counted from the least significant bit, starting at 0. // For example: From 2692494c39cc49713af69ad501fb4fce401a2de6 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 20 Dec 2024 12:01:49 +0800 Subject: [PATCH 25/30] update comment --- core/trie/bitarray.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 2abfc9be27..85aa4fcca4 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -437,7 +437,7 @@ func (b *BitArray) truncateToLength() { // For example: // // array = 0000 0000 ... 0100 (len=251) -// findFirstSetBit() = 2 // third bit from right is set +// findFirstSetBit() = 3 // third bit from right is set func findFirstSetBit(b *BitArray) uint8 { if b.len == 0 { return 0 From f0539b1c06023e50d3fedd50c6e3a29a9abb3fcd Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 26 Dec 2024 11:51:54 +0800 Subject: [PATCH 26/30] Implement trie nodes use hashFn --- core/crypto/hash.go | 5 + core/trie2/bitarray.go | 455 +++++++++++++++++++++++++++++++++++++++++ core/trie2/node.go | 142 +++++++++++++ 3 files changed, 602 insertions(+) create mode 100644 core/crypto/hash.go create mode 100644 core/trie2/bitarray.go create mode 100644 core/trie2/node.go diff --git a/core/crypto/hash.go b/core/crypto/hash.go new file mode 100644 index 0000000000..00abf93dd1 --- /dev/null +++ b/core/crypto/hash.go @@ -0,0 +1,5 @@ +package crypto + +import "github.com/NethermindEth/juno/core/felt" + +type HashFn func(*felt.Felt, *felt.Felt) *felt.Felt diff --git a/core/trie2/bitarray.go b/core/trie2/bitarray.go new file mode 100644 index 0000000000..62bacbdcb5 --- /dev/null +++ b/core/trie2/bitarray.go @@ -0,0 +1,455 @@ +package trie2 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "math" + "math/bits" + + "github.com/NethermindEth/juno/core/felt" +) + +const maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF + +var emptyBitArray = new(BitArray) + +// Represents a bit array with length representing the number of used bits. +// It uses a little endian representation to do bitwise operations of the words efficiently. +// For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. +// The max length is 255 bits (uint8), because our use case only need up to 251 bits for a given trie key. +// Although words can be used to represent 256 bits, we don't want to add an additional byte for the length. +type BitArray struct { + len uint8 // number of used bits + words [4]uint64 // little endian (i.e. words[0] is the least significant) +} + +func NewBitArray(length uint8, val uint64) BitArray { + var b BitArray + b.SetUint64(length, val) + return b +} + +// Returns the felt representation of the bit array. +func (b *BitArray) Felt() felt.Felt { + var f felt.Felt + f.SetBytes(b.Bytes()) + return f +} + +func (b *BitArray) Len() uint8 { + return b.len +} + +// Returns the bytes representation of the bit array in big endian format +func (b *BitArray) Bytes() []byte { + var res [32]byte + + b.truncateToLength() + binary.BigEndian.PutUint64(res[0:8], b.words[3]) + binary.BigEndian.PutUint64(res[8:16], b.words[2]) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + + return res[:] +} + +// Sets the bit array to the least significant 'n' bits of x. +// If length >= x.len, the bit array is an exact copy of x. +// For example: +// +// x = 11001011 (len=8) +// LSBs(x, 4) = 1011 (len=4) +// LSBs(x, 10) = 11001011 (len=8, original x) +// LSBs(x, 0) = 0 (len=0) +// +//nolint:mnd +func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + b.Set(x) + b.len = n + + // Clear all words beyond what's needed + switch { + case n == 0: + b.words = [4]uint64{0, 0, 0, 0} + case n <= 64: + mask := maxUint64 >> (64 - n) + b.words[0] &= mask + b.words[1] = 0 + b.words[2] = 0 + b.words[3] = 0 + case n <= 128: + mask := maxUint64 >> (128 - n) + b.words[1] &= mask + b.words[2] = 0 + b.words[3] = 0 + case n <= 192: + mask := maxUint64 >> (192 - n) + b.words[2] &= mask + b.words[3] = 0 + default: + mask := maxUint64 >> (256 - uint16(n)) + b.words[3] &= mask + } + + return b +} + +// Checks if the current bit array share the same most significant bits with another, where the length of +// the check is determined by the shorter array. Returns true if either array has +// length 0, or if the first min(b.len, x.len) MSBs are identical. +// +// For example: +// +// a = 1101 (len=4) +// b = 11010111 (len=8) +// a.EqualMSBs(b) = true // First 4 MSBs match +// +// a = 1100 (len=4) +// b = 1101 (len=4) +// a.EqualMSBs(b) = false // All bits compared, not equal +// +// a = 1100 (len=4) +// b = [] (len=0) +// a.EqualMSBs(b) = true // Zero length is always a prefix match +func (b *BitArray) EqualMSBs(x *BitArray) bool { + if b.len == x.len { + return b.Equal(x) + } + + if b.len == 0 || x.len == 0 { + return true + } + + // Compare only the first min(b.len, x.len) bits + minLen := b.len + if x.len < minLen { + minLen = x.len + } + + return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) +} + +// Sets the bit array to the most significant 'n' bits of x. +// If n >= x.len, the bit array is an exact copy of x. +// For example: +// +// x = 11001011 (len=8) +// MSBs(x, 4) = 1100 (len=4) +// MSBs(x, 10) = 11001011 (len=8, original x) +// MSBs(x, 0) = 0 (len=0) +func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + return b.Rsh(x, x.len-n) +} + +// Sets the bit array to the longest sequence of matching most significant bits between two bit arrays. +// For example: +// +// x = 1101 0111 (len=8) +// y = 1101 0000 (len=8) +// CommonMSBs(x,y) = 1101 (len=4) +func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { + if x.len == 0 || y.len == 0 { + return b.clear() + } + + long, short := x, y + if x.len < y.len { + long, short = y, x + } + + // Align arrays by right-shifting longer array and then XOR to find differences + // Example: + // short = 1100 (len=4) + // long = 1101 0111 (len=8) + // + // Step 1: Right shift longer array by 4 + // short = 1100 + // long = 1101 + // + // Step 2: XOR shows difference at last bit + // 1100 (short) + // 1101 (aligned long) + // ---- XOR + // 0001 (difference at last position) + // We can then use the position of the first set bit and right-shift to get the common MSBs + diff := long.len - short.len + b.Rsh(long, diff).Xor(b, short) + divergentBit := findFirstSetBit(b) + + return b.Rsh(short, divergentBit) +} + +// Sets the bit array to x >> n and returns the bit array. +// +//nolint:mnd +func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { + if x.len == 0 { + return b.Set(x) + } + + if n >= x.len { + return b.clear() + } + + switch { + case n == 0: + return b.Set(x) + case n >= 192: + b.rsh192(x) + b.len = x.len - n + n -= 192 + b.words[0] >>= n + case n >= 128: + b.rsh128(x) + b.len = x.len - n + n -= 128 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] >>= n + case n >= 64: + b.rsh64(x) + b.len = x.len - n + n -= 64 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] >>= n + default: + b.Set(x) + b.len -= n + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) + b.words[3] >>= n + } + + return b +} + +// Sets the bit array to x ^ y and returns the bit array. +func (b *BitArray) Xor(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] ^ y.words[0] + b.words[1] = x.words[1] ^ y.words[1] + b.words[2] = x.words[2] ^ y.words[2] + b.words[3] = x.words[3] ^ y.words[3] + return b +} + +// Checks if two bit arrays are equal +func (b *BitArray) Equal(x *BitArray) bool { + // TODO(weiihann): this is really not a good thing to do... + if b == nil && x == nil { + return true + } else if b == nil || x == nil { + return false + } + + return b.len == x.len && + b.words[0] == x.words[0] && + b.words[1] == x.words[1] && + b.words[2] == x.words[2] && + b.words[3] == x.words[3] +} + +// Returns true if bit n-th is set, where n = 0 is LSB. +// The n must be <= 255. +func (b *BitArray) IsBitSet(n uint8) bool { + if n >= b.len { + return false + } + + return (b.words[n/64] & (1 << (n % 64))) != 0 +} + +// Serialises the BitArray into a bytes buffer in the following format: +// - First byte: length of the bit array (0-255) +// - Remaining bytes: the necessary bytes included in big endian order +// Example: +// +// BitArray{len: 10, words: [4]uint64{0x03FF}} -> [0x0A, 0x03, 0xFF] +func (b *BitArray) Write(buf *bytes.Buffer) (int, error) { + if err := buf.WriteByte(b.len); err != nil { + return 0, err + } + + n, err := buf.Write(b.activeBytes()) + return n + 1, err +} + +// Deserialises the BitArray from a bytes buffer in the following format: +// - First byte: length of the bit array (0-255) +// - Remaining bytes: the necessary bytes included in big endian order +// Example: +// +// [0x0A, 0x03, 0xFF] -> BitArray{len: 10, words: [4]uint64{0x03FF}} +func (b *BitArray) UnmarshalBinary(data []byte) { + b.len = data[0] + + var bs [32]byte + copy(bs[32-b.byteCount():], data[1:]) + b.setBytes32(bs[:]) +} + +// Sets the bit array to the same value as x. +func (b *BitArray) Set(x *BitArray) *BitArray { + b.len = x.len + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] + return b +} + +// Sets the bit array to the bytes representation of a felt. +func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { + b.len = length + b.setFelt(f) + b.truncateToLength() + return b +} + +// Sets the bit array to the bytes representation of a felt with length 251. +func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { + b.len = 251 + b.setFelt(f) + b.truncateToLength() + return b +} + +// Interprets the data as the big-endian bytes, sets the bit array to that value and returns it. +// If the data is larger than 32 bytes, only the first 32 bytes are used. +func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { + b.setBytes32(data) + b.len = length + b.truncateToLength() + return b +} + +// Sets the bit array to the uint64 representation of a bit array. +func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { + b.words[0] = data + b.len = length + b.truncateToLength() + return b +} + +// Returns the length of the encoded bit array in bytes. +func (b *BitArray) EncodedLen() uint { + return b.byteCount() + 1 +} + +// Returns a deep copy of the bit array. +func (b *BitArray) Copy() BitArray { + var res BitArray + res.Set(b) + return res +} + +// Returns a string representation of the bit array. +func (b *BitArray) String() string { + return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes())) +} + +func (b *BitArray) setFelt(f *felt.Felt) { + res := f.Bytes() + b.words[3] = binary.BigEndian.Uint64(res[0:8]) + b.words[2] = binary.BigEndian.Uint64(res[8:16]) + b.words[1] = binary.BigEndian.Uint64(res[16:24]) + b.words[0] = binary.BigEndian.Uint64(res[24:32]) +} + +func (b *BitArray) setBytes32(data []byte) { + _ = data[31] + b.words[3] = binary.BigEndian.Uint64(data[0:8]) + b.words[2] = binary.BigEndian.Uint64(data[8:16]) + b.words[1] = binary.BigEndian.Uint64(data[16:24]) + b.words[0] = binary.BigEndian.Uint64(data[24:32]) +} + +// Returns the minimum number of bytes needed to represent the bit array. +// It rounds up to the nearest byte. +func (b *BitArray) byteCount() uint { + const bits8 = 8 + // Cast to uint16 to avoid overflow + return (uint(b.len) + (bits8 - 1)) / uint(bits8) +} + +// Returns a slice containing only the bytes that are actually used by the bit array, +// as specified by the length. The returned slice is in big-endian order. +// +// Example: +// +// len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF] +func (b *BitArray) activeBytes() []byte { + wordsBytes := b.Bytes() + return wordsBytes[32-b.byteCount():] +} + +func (b *BitArray) rsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, x.words[3], x.words[2], x.words[1] +} + +func (b *BitArray) rsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, x.words[3], x.words[2] +} + +func (b *BitArray) rsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] +} + +func (b *BitArray) clear() *BitArray { + b.len = 0 + b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 + return b +} + +// Truncates the bit array to the specified length, ensuring that any unused bits are all zeros. +// +//nolint:mnd +func (b *BitArray) truncateToLength() { + switch { + case b.len == 0: + b.words = [4]uint64{0, 0, 0, 0} + case b.len <= 64: + b.words[0] &= maxUint64 >> (64 - b.len) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + case b.len <= 128: + b.words[1] &= maxUint64 >> (128 - b.len) + b.words[2], b.words[3] = 0, 0 + case b.len <= 192: + b.words[2] &= maxUint64 >> (192 - b.len) + b.words[3] = 0 + default: + b.words[3] &= maxUint64 >> (256 - uint16(b.len)) + } +} + +// Returns the position of the first '1' bit in the array, scanning from most significant to least significant bit. +// The bit position is counted from the least significant bit, starting at 0. +// For example: +// +// array = 0000 0000 ... 0100 (len=251) +// findFirstSetBit() = 3 // third bit from right is set +func findFirstSetBit(b *BitArray) uint8 { + if b.len == 0 { + return 0 + } + + // Start from the most significant and move towards the least significant + for i := 3; i >= 0; i-- { + if word := b.words[i]; word != 0 { + return uint8((i+1)*64 - bits.LeadingZeros64(word)) + } + } + + // All bits are zero, no set bit found + return 0 +} diff --git a/core/trie2/node.go b/core/trie2/node.go new file mode 100644 index 0000000000..c1878433f7 --- /dev/null +++ b/core/trie2/node.go @@ -0,0 +1,142 @@ +package trie2 + +import ( + "bytes" + "fmt" + "strings" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" +) + +var ( + _ node = (*internalNode)(nil) + _ node = (*edgeNode)(nil) + _ node = (*hashNode)(nil) + _ node = (*valueNode)(nil) +) + +type node interface { + hash(crypto.HashFn) *felt.Felt // TODO(weiihann): return felt value instead of pointers + cache() (*hashNode, bool) + encode(*bytes.Buffer) error + String() string +} + +type ( + internalNode struct { + children [2]node // 0 = left, 1 = right + flags nodeFlag + } + edgeNode struct { + child node + path *BitArray + flags nodeFlag + } + hashNode struct{ *felt.Felt } + valueNode struct{ *felt.Felt } +) + +type nodeFlag struct { + hash *hashNode + dirty bool +} + +func (n *internalNode) hash(hf crypto.HashFn) *felt.Felt { + return hf(n.children[0].hash(hf), n.children[1].hash(hf)) +} + +func (n *edgeNode) hash(hf crypto.HashFn) *felt.Felt { + var length [32]byte + length[31] = n.path.len + pathFelt := n.path.Felt() + lengthFelt := new(felt.Felt).SetBytes(length[:]) + return new(felt.Felt).Add(hf(n.child.hash(hf), &pathFelt), lengthFelt) +} + +func (n hashNode) hash(crypto.HashFn) *felt.Felt { return n.Felt } +func (n valueNode) hash(crypto.HashFn) *felt.Felt { return n.Felt } + +func (n *internalNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } +func (n *edgeNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } +func (n hashNode) cache() (*hashNode, bool) { return nil, true } +func (n valueNode) cache() (*hashNode, bool) { return nil, true } + +func (n *internalNode) String() string { + return fmt.Sprintf("Internal[\n left: %s\n right: %s\n]", + indent(n.children[0].String()), + indent(n.children[1].String())) +} + +func (n *edgeNode) String() string { + return fmt.Sprintf("Edge{\n path: %s\n child: %s\n}", + n.path.String(), + indent(n.child.String())) +} + +func (n hashNode) String() string { + return fmt.Sprintf("Hash(%s)", n.Felt.String()) +} + +func (n valueNode) String() string { + return fmt.Sprintf("Value(%s)", n.Felt.String()) +} + +func (n *internalNode) encode(buf *bytes.Buffer) error { + if err := n.children[0].encode(buf); err != nil { + return err + } + + if err := n.children[1].encode(buf); err != nil { + return err + } + + return nil +} + +func (n *edgeNode) encode(buf *bytes.Buffer) error { + if _, err := n.path.Write(buf); err != nil { + return err + } + + if err := n.child.encode(buf); err != nil { + return err + } + + return nil +} + +func (n hashNode) encode(buf *bytes.Buffer) error { + if _, err := buf.Write(n.Felt.Marshal()); err != nil { + return err + } + + return nil +} + +func (n valueNode) encode(buf *bytes.Buffer) error { + if _, err := buf.Write(n.Felt.Marshal()); err != nil { + return err + } + + return nil +} + +func (n *edgeNode) PathMatches(key *BitArray) bool { + return n.path.EqualMSBs(key) +} + +func (n *edgeNode) CommonPath(key *BitArray) BitArray { + var commonPath BitArray + commonPath.CommonMSBs(n.path, key) + return commonPath +} + +// Helper function to indent each line of a string +func indent(s string) string { + lines := strings.Split(s, "\n") + for i, line := range lines { + lines[i] = " " + line + } + return strings.Join(lines, "\n") +} From fbe00b0904c55910ccff4902533828ff6a6e9437 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 26 Dec 2024 18:58:31 +0800 Subject: [PATCH 27/30] Add LSBsAtPos() --- core/trie/bitarray.go | 20 +++++++ core/trie/bitarray_test.go | 112 +++++++++++++++++++++++++++++++++++++ core/trie/proof.go | 2 +- core/trie/trie.go | 2 +- 4 files changed, 134 insertions(+), 2 deletions(-) diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index 85aa4fcca4..d56e7cb47f 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -100,6 +100,26 @@ func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { return b } +// Returns the least significant bits of `x` with `pos` as the most significant bit. +// `pos` is counted from the most significant bit, starting at 0. +// For example: +// +// x = 11001011 (len=8) +// LSBsAtPos(x, 1) = 1001011 (len=7) +// LSBsAtPos(x, 10) = 0 (len=0) +// LSBsAtPos(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) LSBsAtPos(x *BitArray, pos uint8) *BitArray { + if pos == 0 { + return b.Set(x) + } + + if pos > x.Len() { + return b.clear() + } + + return b.LSBs(x, x.Len()-pos) +} + // Checks if the current bit array share the same most significant bits with another, where the length of // the check is determined by the shorter array. Returns true if either array has // length 0, or if the first min(b.len, x.len) MSBs are identical. diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 4c57794b06..5584c1174e 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -1170,3 +1170,115 @@ func TestSetFeltValidation(t *testing.T) { }) } } + +func TestLSBsAtPos(t *testing.T) { + tests := []struct { + name string + x *BitArray + pos uint8 + want *BitArray + }{ + { + name: "zero position", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "position beyond length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 65, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get last 4 bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + pos: 4, + want: &BitArray{ + len: 4, + words: [4]uint64{0x0F, 0, 0, 0}, // 1111 + }, + }, + { + name: "get bits across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get bits from max length array", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + pos: 200, + want: &BitArray{ + len: 51, + words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "empty array", + x: emptyBitArray, + pos: 1, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 + }, + pos: 8, + want: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + }, + { + name: "position equals length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).LSBsAtPos(tt.x, tt.pos) + if !got.Equal(tt.want) { + t.Errorf("LSBsAtPos() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/core/trie/proof.go b/core/trie/proof.go index 894991a8a9..b7825b5b87 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -555,7 +555,7 @@ func handleEdgeNode( // verifyEdgePath checks if the edge path matches the key path at the current position. func verifyEdgePath(key, edgePath *BitArray, curPos uint8) bool { - return new(BitArray).LSBs(key, key.Len()-curPos).EqualMSBs(edgePath) + return new(BitArray).LSBsAtPos(key, curPos).EqualMSBs(edgePath) } // buildTrie builds a trie from a list of storage nodes and a list of keys and values. diff --git a/core/trie/trie.go b/core/trie/trie.go index 7eaa46dd7f..59d99a1d2c 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -114,7 +114,7 @@ func path(key, parentKey *BitArray) BitArray { } var pathKey BitArray - pathKey.LSBs(key, key.Len()-parentKey.Len()-1) + pathKey.LSBsAtPos(key, parentKey.Len()+1) return pathKey } From 4f7b564af9fa3dc03f85b04a991296ee76b0ca7c Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 22:24:25 +0800 Subject: [PATCH 28/30] Update() works on TrieD d --- core/trie/bitarray.go | 128 +++++++++++++++++++-- core/trie/bitarray_test.go | 6 +- core/trie/proof.go | 2 +- core/trie/trie.go | 2 +- core/trie2/bitarray.go | 168 +++++++++++++++++++++++++-- core/trie2/errors.go | 5 + core/trie2/node.go | 51 +++++---- core/trie2/trie.go | 225 +++++++++++++++++++++++++++++++++++++ core/trie2/trie_test.go | 89 +++++++++++++++ 9 files changed, 635 insertions(+), 41 deletions(-) create mode 100644 core/trie2/errors.go create mode 100644 core/trie2/trie.go create mode 100644 core/trie2/trie_test.go diff --git a/core/trie/bitarray.go b/core/trie/bitarray.go index d56e7cb47f..c6a34a41fe 100644 --- a/core/trie/bitarray.go +++ b/core/trie/bitarray.go @@ -11,7 +11,10 @@ import ( "github.com/NethermindEth/juno/core/felt" ) -const maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF +const ( + maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF + maxUint8 = uint8(math.MaxUint8) +) var emptyBitArray = new(BitArray) @@ -105,10 +108,10 @@ func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { // For example: // // x = 11001011 (len=8) -// LSBsAtPos(x, 1) = 1001011 (len=7) -// LSBsAtPos(x, 10) = 0 (len=0) -// LSBsAtPos(x, 0) = 11001011 (len=8, original x) -func (b *BitArray) LSBsAtPos(x *BitArray, pos uint8) *BitArray { +// LSBsFromMSB(x, 1) = 1001011 (len=7) +// LSBsFromMSB(x, 10) = 0 (len=0) +// LSBsFromMSB(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) LSBsFromMSB(x *BitArray, pos uint8) *BitArray { if pos == 0 { return b.Set(x) } @@ -251,6 +254,85 @@ func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { b.words[3] >>= n } + b.truncateToLength() + return b +} + +// Lsh sets the bit array to x << n and returns the bit array. +// +//nolint:mnd +func (b *BitArray) Lsh(x *BitArray, n uint8) *BitArray { + b.Set(x) + + if x.len == 0 || n == 0 { + return b + } + + // If the result will overflow, we set the length to the max length + // but we still shift `n` bits + if n > maxUint8-x.len { + b.len = maxUint8 + } else { + b.len = x.len + n + } + + switch { + case n == 0: + return b + case n >= 192: + b.lsh192(x) + n -= 192 + b.words[3] <<= n + case n >= 128: + b.lsh128(x) + n -= 128 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] <<= n + case n >= 64: + b.lsh64(x) + n -= 64 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] <<= n + default: + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] = (b.words[1] << n) | (b.words[0] >> (64 - n)) + b.words[0] <<= n + } + + b.truncateToLength() + return b +} + +// Sets the bit array to the concatenation of x and y and returns the bit array. +// For example: +// +// x = 000 (len=3) +// y = 111 (len=3) +// Append(x,y) = 000111 (len=6) +func (b *BitArray) Append(x, y *BitArray) *BitArray { + if x.len == 0 { + return b.Set(y) + } + if y.len == 0 { + return b.Set(x) + } + + // First copy x + b.Set(x) + + // Then shift left by y's length and OR with y + return b.Lsh(b, y.len).Or(b, y) +} + +// Sets the bit array to x | y and returns the bit array. +func (b *BitArray) Or(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] | y.words[0] + b.words[1] = x.words[1] | y.words[1] + b.words[2] = x.words[2] | y.words[2] + b.words[3] = x.words[3] | y.words[3] + b.len = x.len return b } @@ -280,13 +362,31 @@ func (b *BitArray) Equal(x *BitArray) bool { } // Returns true if bit n-th is set, where n = 0 is LSB. -// The n must be <= 255. func (b *BitArray) IsBitSet(n uint8) bool { + return b.BitSet(n) == 1 +} + +// Returns the bit value at position n, where n = 0 is LSB. +// If n is out of bounds, returns 0. +func (b *BitArray) BitSet(n uint8) uint8 { if n >= b.len { - return false + return 0 + } + + if (b.words[n/64] & (1 << (n % 64))) != 0 { + return 1 } - return (b.words[n/64] & (1 << (n % 64))) != 0 + return 0 +} + +// Returns the bit value at the most significant bit +func (b *BitArray) MSB() uint8 { + return b.BitSet(b.Len() - 1) +} + +func (b *BitArray) IsEmpty() bool { + return b.len == 0 } // Serialises the BitArray into a bytes buffer in the following format: @@ -425,6 +525,18 @@ func (b *BitArray) rsh192(x *BitArray) { b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] } +func (b *BitArray) lsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[2], x.words[1], x.words[0], 0 +} + +func (b *BitArray) lsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[1], x.words[0], 0, 0 +} + +func (b *BitArray) lsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[0], 0, 0, 0 +} + func (b *BitArray) clear() *BitArray { b.len = 0 b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 5584c1174e..fb7c1534f6 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -1171,7 +1171,7 @@ func TestSetFeltValidation(t *testing.T) { } } -func TestLSBsAtPos(t *testing.T) { +func TestLSBsFromMSB(t *testing.T) { tests := []struct { name string x *BitArray @@ -1275,9 +1275,9 @@ func TestLSBsAtPos(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := new(BitArray).LSBsAtPos(tt.x, tt.pos) + got := new(BitArray).LSBsFromMSB(tt.x, tt.pos) if !got.Equal(tt.want) { - t.Errorf("LSBsAtPos() = %v, want %v", got, tt.want) + t.Errorf("LSBsFromMSB() = %v, want %v", got, tt.want) } }) } diff --git a/core/trie/proof.go b/core/trie/proof.go index b7825b5b87..e68734ceac 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -555,7 +555,7 @@ func handleEdgeNode( // verifyEdgePath checks if the edge path matches the key path at the current position. func verifyEdgePath(key, edgePath *BitArray, curPos uint8) bool { - return new(BitArray).LSBsAtPos(key, curPos).EqualMSBs(edgePath) + return new(BitArray).LSBsFromMSB(key, curPos).EqualMSBs(edgePath) } // buildTrie builds a trie from a list of storage nodes and a list of keys and values. diff --git a/core/trie/trie.go b/core/trie/trie.go index 59d99a1d2c..ef203799da 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -114,7 +114,7 @@ func path(key, parentKey *BitArray) BitArray { } var pathKey BitArray - pathKey.LSBsAtPos(key, parentKey.Len()+1) + pathKey.LSBsFromMSB(key, parentKey.Len()+1) return pathKey } diff --git a/core/trie2/bitarray.go b/core/trie2/bitarray.go index 62bacbdcb5..2d408d550d 100644 --- a/core/trie2/bitarray.go +++ b/core/trie2/bitarray.go @@ -11,7 +11,10 @@ import ( "github.com/NethermindEth/juno/core/felt" ) -const maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF +const ( + maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF + maxUint8 = uint8(math.MaxUint8) +) var emptyBitArray = new(BitArray) @@ -56,16 +59,17 @@ func (b *BitArray) Bytes() []byte { } // Sets the bit array to the least significant 'n' bits of x. +// n is counted from the least significant bit, starting at 0. // If length >= x.len, the bit array is an exact copy of x. // For example: // // x = 11001011 (len=8) -// LSBs(x, 4) = 1011 (len=4) -// LSBs(x, 10) = 11001011 (len=8, original x) -// LSBs(x, 0) = 0 (len=0) +// LSBsFromLSB(x, 4) = 1011 (len=4) +// LSBsFromLSB(x, 10) = 11001011 (len=8, original x) +// LSBsFromLSB(x, 0) = 0 (len=0) // //nolint:mnd -func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { +func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { if n >= x.len { return b.Set(x) } @@ -100,6 +104,25 @@ func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { return b } +// Returns the least significant bits of `x` with `n` counted from the most significant bit, starting at 0. +// For example: +// +// x = 11001011 (len=8) +// LSBs(x, 1) = 1001011 (len=7) +// LSBs(x, 10) = 0 (len=0) +// LSBs(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { + if n == 0 { + return b.Set(x) + } + + if n > x.Len() { + return b.clear() + } + + return b.LSBsFromLSB(x, x.Len()-n) +} + // Checks if the current bit array share the same most significant bits with another, where the length of // the check is determined by the shorter array. Returns true if either array has // length 0, or if the first min(b.len, x.len) MSBs are identical. @@ -231,6 +254,85 @@ func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { b.words[3] >>= n } + b.truncateToLength() + return b +} + +// Lsh sets the bit array to x << n and returns the bit array. +// +//nolint:mnd +func (b *BitArray) Lsh(x *BitArray, n uint8) *BitArray { + b.Set(x) + + if x.len == 0 || n == 0 { + return b + } + + // If the result will overflow, we set the length to the max length + // but we still shift `n` bits + if n > maxUint8-x.len { + b.len = maxUint8 + } else { + b.len = x.len + n + } + + switch { + case n == 0: + return b + case n >= 192: + b.lsh192(x) + n -= 192 + b.words[3] <<= n + case n >= 128: + b.lsh128(x) + n -= 128 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] <<= n + case n >= 64: + b.lsh64(x) + n -= 64 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] <<= n + default: + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] = (b.words[1] << n) | (b.words[0] >> (64 - n)) + b.words[0] <<= n + } + + b.truncateToLength() + return b +} + +// Sets the bit array to the concatenation of x and y and returns the bit array. +// For example: +// +// x = 000 (len=3) +// y = 111 (len=3) +// Append(x,y) = 000111 (len=6) +func (b *BitArray) Append(x, y *BitArray) *BitArray { + if x.len == 0 { + return b.Set(y) + } + if y.len == 0 { + return b.Set(x) + } + + // First copy x + b.Set(x) + + // Then shift left by y's length and OR with y + return b.Lsh(b, y.len).Or(b, y) +} + +// Sets the bit array to x | y and returns the bit array. +func (b *BitArray) Or(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] | y.words[0] + b.words[1] = x.words[1] | y.words[1] + b.words[2] = x.words[2] | y.words[2] + b.words[3] = x.words[3] | y.words[3] + b.len = x.len return b } @@ -260,13 +362,49 @@ func (b *BitArray) Equal(x *BitArray) bool { } // Returns true if bit n-th is set, where n = 0 is LSB. -// The n must be <= 255. -func (b *BitArray) IsBitSet(n uint8) bool { +func (b *BitArray) IsBitSetFromLSB(n uint8) bool { + return b.BitSetFromLSB(n) == 1 +} + +// Returns the bit value at position n, where n = 0 is LSB. +// If n is out of bounds, returns 0. +func (b *BitArray) BitSetFromLSB(n uint8) uint8 { if n >= b.len { - return false + return 0 + } + + if (b.words[n/64] & (1 << (n % 64))) != 0 { + return 1 + } + + return 0 +} + +func (b *BitArray) IsBitSet(n uint8) bool { + return b.BitSet(n) == 1 +} + +// Returns the bit value at position n, where n = 0 is MSB. +// If n is out of bounds, returns 0. +func (b *BitArray) BitSet(n uint8) uint8 { + if n >= b.Len() { + return 0 } - return (b.words[n/64] & (1 << (n % 64))) != 0 + return b.BitSetFromLSB(b.Len() - n - 1) +} + +// Returns the bit value at the most significant bit +func (b *BitArray) MSB() uint8 { + return b.BitSet(0) +} + +func (b *BitArray) LSB() uint8 { + return b.BitSetFromLSB(0) +} + +func (b *BitArray) IsEmpty() bool { + return b.len == 0 } // Serialises the BitArray into a bytes buffer in the following format: @@ -405,6 +543,18 @@ func (b *BitArray) rsh192(x *BitArray) { b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] } +func (b *BitArray) lsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[2], x.words[1], x.words[0], 0 +} + +func (b *BitArray) lsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[1], x.words[0], 0, 0 +} + +func (b *BitArray) lsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[0], 0, 0, 0 +} + func (b *BitArray) clear() *BitArray { b.len = 0 b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 diff --git a/core/trie2/errors.go b/core/trie2/errors.go new file mode 100644 index 0000000000..306fe62059 --- /dev/null +++ b/core/trie2/errors.go @@ -0,0 +1,5 @@ +package trie2 + +import "errors" + +var ErrCommitted = errors.New("trie is committed") diff --git a/core/trie2/node.go b/core/trie2/node.go index c1878433f7..5e072f7aae 100644 --- a/core/trie2/node.go +++ b/core/trie2/node.go @@ -10,7 +10,7 @@ import ( ) var ( - _ node = (*internalNode)(nil) + _ node = (*binaryNode)(nil) _ node = (*edgeNode)(nil) _ node = (*hashNode)(nil) _ node = (*valueNode)(nil) @@ -19,12 +19,12 @@ var ( type node interface { hash(crypto.HashFn) *felt.Felt // TODO(weiihann): return felt value instead of pointers cache() (*hashNode, bool) - encode(*bytes.Buffer) error + write(*bytes.Buffer) error String() string } type ( - internalNode struct { + binaryNode struct { children [2]node // 0 = left, 1 = right flags nodeFlag } @@ -37,12 +37,21 @@ type ( valueNode struct{ *felt.Felt } ) +const ( + binaryNodeType byte = iota + edgeNodeType + hashNodeType + valueNodeType +) + type nodeFlag struct { hash *hashNode dirty bool } -func (n *internalNode) hash(hf crypto.HashFn) *felt.Felt { +func newFlag() nodeFlag { return nodeFlag{dirty: false} } + +func (n *binaryNode) hash(hf crypto.HashFn) *felt.Felt { return hf(n.children[0].hash(hf), n.children[1].hash(hf)) } @@ -57,13 +66,13 @@ func (n *edgeNode) hash(hf crypto.HashFn) *felt.Felt { func (n hashNode) hash(crypto.HashFn) *felt.Felt { return n.Felt } func (n valueNode) hash(crypto.HashFn) *felt.Felt { return n.Felt } -func (n *internalNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } -func (n *edgeNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } -func (n hashNode) cache() (*hashNode, bool) { return nil, true } -func (n valueNode) cache() (*hashNode, bool) { return nil, true } +func (n *binaryNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } +func (n *edgeNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } +func (n hashNode) cache() (*hashNode, bool) { return nil, true } +func (n valueNode) cache() (*hashNode, bool) { return nil, true } -func (n *internalNode) String() string { - return fmt.Sprintf("Internal[\n left: %s\n right: %s\n]", +func (n *binaryNode) String() string { + return fmt.Sprintf("Binary[\n left: %s\n right: %s\n]", indent(n.children[0].String()), indent(n.children[1].String())) } @@ -82,31 +91,31 @@ func (n valueNode) String() string { return fmt.Sprintf("Value(%s)", n.Felt.String()) } -func (n *internalNode) encode(buf *bytes.Buffer) error { - if err := n.children[0].encode(buf); err != nil { +func (n *binaryNode) write(buf *bytes.Buffer) error { + if err := n.children[0].write(buf); err != nil { return err } - if err := n.children[1].encode(buf); err != nil { + if err := n.children[1].write(buf); err != nil { return err } return nil } -func (n *edgeNode) encode(buf *bytes.Buffer) error { +func (n *edgeNode) write(buf *bytes.Buffer) error { if _, err := n.path.Write(buf); err != nil { return err } - if err := n.child.encode(buf); err != nil { + if err := n.child.write(buf); err != nil { return err } return nil } -func (n hashNode) encode(buf *bytes.Buffer) error { +func (n hashNode) write(buf *bytes.Buffer) error { if _, err := buf.Write(n.Felt.Marshal()); err != nil { return err } @@ -114,7 +123,7 @@ func (n hashNode) encode(buf *bytes.Buffer) error { return nil } -func (n valueNode) encode(buf *bytes.Buffer) error { +func (n valueNode) write(buf *bytes.Buffer) error { if _, err := buf.Write(n.Felt.Marshal()); err != nil { return err } @@ -122,11 +131,15 @@ func (n valueNode) encode(buf *bytes.Buffer) error { return nil } -func (n *edgeNode) PathMatches(key *BitArray) bool { +// TODO(weiihann): check if we want to return a pointer or a value +func (n *binaryNode) copy() *binaryNode { cpy := *n; return &cpy } +func (n *edgeNode) copy() *edgeNode { cpy := *n; return &cpy } + +func (n *edgeNode) pathMatches(key *BitArray) bool { return n.path.EqualMSBs(key) } -func (n *edgeNode) CommonPath(key *BitArray) BitArray { +func (n *edgeNode) commonPath(key *BitArray) BitArray { var commonPath BitArray commonPath.CommonMSBs(n.path, key) return commonPath diff --git a/core/trie2/trie.go b/core/trie2/trie.go new file mode 100644 index 0000000000..2a60ced1fe --- /dev/null +++ b/core/trie2/trie.go @@ -0,0 +1,225 @@ +package trie2 + +import ( + "fmt" + + "github.com/NethermindEth/juno/core/felt" +) + +type Trie struct { + height uint8 + root node + reader interface{} // TODO(weiihann): implement reader + // committed bool +} + +// TODO(weiihann): implement this +func NewTrie(height uint8) *Trie { + return &Trie{height: height} +} + +func (t *Trie) Update(key, value *felt.Felt) error { + // if t.commited { + // return ErrCommitted + // } + return t.update(key, value) +} + +func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { + k := t.FeltToKey(key) + // TODO(weiihann): get the value directly from the reader + val, root, didResolve, err := t.get(t.root, &k) + // In Starknet, a non-existent key is mapped to felt.Zero + if val == nil { + val = &felt.Zero + } + if err == nil && didResolve { + t.root = root + } + return val, err +} + +func (t *Trie) Delete(key *felt.Felt) error { + panic("TODO(weiihann): implement me") +} + +// Traverses the trie recursively to find the value that corresponds to the key. +func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { + switch n := n.(type) { + case *edgeNode: + if !n.pathMatches(key) { + return nil, nil, false, nil + } + val, child, didResolve, err := t.get(n.child, key.LSBs(key, n.path.Len())) + if err == nil && didResolve { + n = n.copy() + n.child = child + } + return val, n, didResolve, err + case *binaryNode: + bit := key.MSB() + val, child, didResolve, err := t.get(n.children[bit], key.LSBs(key, 1)) + if err == nil && didResolve { + n = n.copy() + n.children[bit] = child + } + return val, n, didResolve, err + case hashNode: + panic("TODO(weiihann): implement me") + case valueNode: + return n.Felt, n, false, nil + case nil: + return nil, nil, false, nil + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +func (t *Trie) update(key, value *felt.Felt) error { + k := t.FeltToKey(key) + if value.IsZero() { + _, n, err := t.delete(t.root, &k) + if err != nil { + return err + } + t.root = n + } else { + _, n, err := t.insert(t.root, &k, valueNode{Felt: value}) + if err != nil { + return err + } + t.root = n + } + return nil +} + +func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { + // We reach the end of the key + if key.Len() == 0 { + if v, ok := n.(valueNode); ok { + return v.Equal(value.(valueNode).Felt), value, nil + } + return true, value, nil + } + + switch n := n.(type) { + case *edgeNode: + match := n.commonPath(key) + // If the whole key matches, just keep this edge node as it is and update the value + if match.Len() == n.path.Len() { + dirty, newNode, err := t.insert(n.child, key.LSBs(key, match.Len()), value) + if !dirty || err != nil { + return false, n, err + } + return true, &edgeNode{ + path: n.path, + child: newNode, + flags: newFlag(), + }, nil + } + // Otherwise branch out at the bit index where they differ + branch := &binaryNode{flags: newFlag()} + var err error + _, branch.children[n.path.BitSet(match.Len())], err = t.insert(nil, new(BitArray).LSBs(n.path, match.Len()+1), n.child) + if err != nil { + return false, n, err + } + + _, branch.children[key.BitSet(match.Len())], err = t.insert(nil, new(BitArray).LSBs(key, match.Len()+1), value) + if err != nil { + return false, n, err + } + + // Replace this edge node with the new binary node if it occurs at the current MSB + if match.IsEmpty() { + return true, branch, nil + } + + return true, &edgeNode{path: new(BitArray).MSBs(key, match.Len()), child: branch, flags: newFlag()}, nil + + case *binaryNode: + bit := key.MSB() + dirty, newNode, err := t.insert(n.children[bit], new(BitArray).LSBs(key, 1), value) + if !dirty || err != nil { + return false, n, err + } + n = n.copy() + n.flags = newFlag() + n.children[bit] = newNode + return true, n, nil + case nil: + if key.IsEmpty() { + return true, value, nil + } + return true, &edgeNode{path: key, child: value, flags: newFlag()}, nil + case hashNode: + panic("TODO(weiihann): implement me") + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +func (t *Trie) delete(n node, key *BitArray) (bool, node, error) { + switch n := n.(type) { + case *edgeNode: + match := n.commonPath(key) + // Mismatched, don't do anything + if match.Len() < n.path.Len() { + return false, n, nil + } + // If the whole key matches, just delete the edge node + if match.Len() == key.Len() { + return true, nil, nil + } + + // Otherwise, we need to delete the child node + dirty, child, err := t.delete(n.child, key.LSBs(key, match.Len())) + if !dirty || err != nil { + return false, n, err + } + switch child := child.(type) { + case *edgeNode: + return true, &edgeNode{path: n.path, child: child.child, flags: newFlag()}, nil + default: + return true, &edgeNode{path: n.path, child: child, flags: newFlag()}, nil + } + case *binaryNode: + bit := key.MSB() + dirty, newNode, err := t.delete(n.children[bit], key.LSBs(key, 1)) + if !dirty || err != nil { + return false, n, err + } + n = n.copy() + n.flags = newFlag() + n.children[bit] = newNode + + if newNode != nil { + return true, n, nil + } + + // TODO(weiihann): combine this binary node with the child + + return true, n, nil + case valueNode: + return true, nil, nil + case nil: + return false, nil, nil + case hashNode: + panic("TODO(weiihann): implement me") + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +func (t *Trie) String() string { + if t.root == nil { + return "" + } + return t.root.String() +} + +func (t *Trie) FeltToKey(f *felt.Felt) BitArray { + var key BitArray + key.SetFelt(t.height, f) + return key +} diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go new file mode 100644 index 0000000000..5f91ecc40e --- /dev/null +++ b/core/trie2/trie_test.go @@ -0,0 +1,89 @@ +package trie2 + +import ( + "math/rand" + "testing" + + "github.com/NethermindEth/juno/core/felt" + "github.com/stretchr/testify/require" +) + +func TestUpdate(t *testing.T) { + trie := NewTrie(251) + + key := new(felt.Felt).SetUint64(1) + value := new(felt.Felt).SetUint64(2) + err := trie.Update(key, value) + require.NoError(t, err) + + got, err := trie.Get(key) + require.NoError(t, err) + require.Equal(t, value, got) +} + +func TestUpdateRandom(t *testing.T) { + tr, records := randomTrie(t, 1000) + + for _, record := range records { + got, err := tr.Get(record.key) + require.NoError(t, err) + + if !got.Equal(record.value) { + t.Fatalf("expected %s, got %s", record.value, got) + } + } +} + +func Test4KeysTrieD(t *testing.T) { + tr, _ := build4KeysTrieD(t) + t.Log(tr.String()) +} + +type keyValue struct { + key *felt.Felt + value *felt.Felt +} + +func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) { + rrand := rand.New(rand.NewSource(3)) + + tr := NewTrie(251) + records := make([]*keyValue, n) + + for i := 0; i < n; i++ { + key := new(felt.Felt).SetUint64(uint64(rrand.Uint32() + 1)) + records[i] = &keyValue{key: key, value: key} + err := tr.Update(key, key) + require.NoError(t, err) + } + + return tr, records +} + +func build4KeysTrieD(t *testing.T) (*Trie, []*keyValue) { + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(4)}, + {key: new(felt.Felt).SetUint64(4), value: new(felt.Felt).SetUint64(5)}, + {key: new(felt.Felt).SetUint64(6), value: new(felt.Felt).SetUint64(6)}, + {key: new(felt.Felt).SetUint64(7), value: new(felt.Felt).SetUint64(7)}, + } + + return buildTrie(t, records), records +} + +func buildTrie(t *testing.T, records []*keyValue) *Trie { + if len(records) == 0 { + t.Fatal("records must have at least one element") + } + + tempTrie := NewTrie(251) + + for _, record := range records { + err := tempTrie.Update(record.key, record.value) + t.Log("--------------------------------") + t.Log(tempTrie.String()) + require.NoError(t, err) + } + + return tempTrie +} From 2401a74898a1c6406008166d07ff3c0a81296f2b Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 31 Dec 2024 13:58:50 +0800 Subject: [PATCH 29/30] add docs --- core/trie2/bitarray.go | 26 +++++++++---- core/trie2/node.go | 1 + core/trie2/trie.go | 81 +++++++++++++++++++++++++++++------------ core/trie2/trie_test.go | 59 +++++++++++++++++++++++------- 4 files changed, 124 insertions(+), 43 deletions(-) diff --git a/core/trie2/bitarray.go b/core/trie2/bitarray.go index 2d408d550d..6fe218df9e 100644 --- a/core/trie2/bitarray.go +++ b/core/trie2/bitarray.go @@ -363,12 +363,12 @@ func (b *BitArray) Equal(x *BitArray) bool { // Returns true if bit n-th is set, where n = 0 is LSB. func (b *BitArray) IsBitSetFromLSB(n uint8) bool { - return b.BitSetFromLSB(n) == 1 + return b.BitFromLSB(n) == 1 } // Returns the bit value at position n, where n = 0 is LSB. // If n is out of bounds, returns 0. -func (b *BitArray) BitSetFromLSB(n uint8) uint8 { +func (b *BitArray) BitFromLSB(n uint8) uint8 { if n >= b.len { return 0 } @@ -381,26 +381,26 @@ func (b *BitArray) BitSetFromLSB(n uint8) uint8 { } func (b *BitArray) IsBitSet(n uint8) bool { - return b.BitSet(n) == 1 + return b.Bit(n) == 1 } // Returns the bit value at position n, where n = 0 is MSB. // If n is out of bounds, returns 0. -func (b *BitArray) BitSet(n uint8) uint8 { +func (b *BitArray) Bit(n uint8) uint8 { if n >= b.Len() { return 0 } - return b.BitSetFromLSB(b.Len() - n - 1) + return b.BitFromLSB(b.Len() - n - 1) } // Returns the bit value at the most significant bit func (b *BitArray) MSB() uint8 { - return b.BitSet(0) + return b.Bit(0) } func (b *BitArray) LSB() uint8 { - return b.BitSetFromLSB(0) + return b.BitFromLSB(0) } func (b *BitArray) IsEmpty() bool { @@ -479,6 +479,18 @@ func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { return b } +// Sets the bit array to a single bit. +func (b *BitArray) SetBit(bit bool) *BitArray { + b.len = 1 + if bit { + b.words[0] = 1 + } else { + b.words[0] = 0 + } + b.truncateToLength() + return b +} + // Returns the length of the encoded bit array in bytes. func (b *BitArray) EncodedLen() uint { return b.byteCount() + 1 diff --git a/core/trie2/node.go b/core/trie2/node.go index 5e072f7aae..2f2277b36b 100644 --- a/core/trie2/node.go +++ b/core/trie2/node.go @@ -139,6 +139,7 @@ func (n *edgeNode) pathMatches(key *BitArray) bool { return n.path.EqualMSBs(key) } +// Returns the common bits between the current node and the given key, starting from the most significant bit func (n *edgeNode) commonPath(key *BitArray) BitArray { var commonPath BitArray commonPath.CommonMSBs(n.path, key) diff --git a/core/trie2/trie.go b/core/trie2/trie.go index 2a60ced1fe..2e97770488 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -18,6 +18,8 @@ func NewTrie(height uint8) *Trie { return &Trie{height: height} } +// Modifies or inserts a key-value pair in the trie. +// If value is zero, the key is deleted from the trie. func (t *Trie) Update(key, value *felt.Felt) error { // if t.commited { // return ErrCommitted @@ -25,6 +27,9 @@ func (t *Trie) Update(key, value *felt.Felt) error { return t.update(key, value) } +// Retrieves the value associated with the given key. +// Returns felt.Zero if the key doesn't exist. +// May update the trie's internal structure if nodes need to be resolved. func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { k := t.FeltToKey(key) // TODO(weiihann): get the value directly from the reader @@ -39,11 +44,17 @@ func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { return val, err } +// Removes the given key from the trie. func (t *Trie) Delete(key *felt.Felt) error { - panic("TODO(weiihann): implement me") + k := t.FeltToKey(key) + _, n, err := t.delete(t.root, new(BitArray), &k) + if err != nil { + return err + } + t.root = n + return nil } -// Traverses the trie recursively to find the value that corresponds to the key. func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { switch n := n.(type) { case *edgeNode: @@ -75,10 +86,12 @@ func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { } } +// Modifies the trie by either inserting/updating a value or deleting a key. +// The operation is determined by whether the value is zero (delete) or non-zero (insert/update). func (t *Trie) update(key, value *felt.Felt) error { k := t.FeltToKey(key) if value.IsZero() { - _, n, err := t.delete(t.root, &k) + _, n, err := t.delete(t.root, new(BitArray), &k) if err != nil { return err } @@ -104,8 +117,8 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { switch n := n.(type) { case *edgeNode: - match := n.commonPath(key) - // If the whole key matches, just keep this edge node as it is and update the value + match := n.commonPath(key) // get the matching bits between the current node and the key + // If the match is the same as the path, just keep this edge node as it is and update the value if match.Len() == n.path.Len() { dirty, newNode, err := t.insert(n.child, key.LSBs(key, match.Len()), value) if !dirty || err != nil { @@ -117,15 +130,15 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { flags: newFlag(), }, nil } - // Otherwise branch out at the bit index where they differ + // Otherwise branch out at the bit position where they differ branch := &binaryNode{flags: newFlag()} var err error - _, branch.children[n.path.BitSet(match.Len())], err = t.insert(nil, new(BitArray).LSBs(n.path, match.Len()+1), n.child) + _, branch.children[n.path.Bit(match.Len())], err = t.insert(nil, new(BitArray).LSBs(n.path, match.Len()+1), n.child) if err != nil { return false, n, err } - _, branch.children[key.BitSet(match.Len())], err = t.insert(nil, new(BitArray).LSBs(key, match.Len()+1), value) + _, branch.children[key.Bit(match.Len())], err = t.insert(nil, new(BitArray).LSBs(key, match.Len()+1), value) if err != nil { return false, n, err } @@ -135,22 +148,27 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { return true, branch, nil } + // Otherwise, create a new edge node with the path being the common path and the branch as the child return true, &edgeNode{path: new(BitArray).MSBs(key, match.Len()), child: branch, flags: newFlag()}, nil case *binaryNode: + // Go to the child node based on the MSB of the key bit := key.MSB() dirty, newNode, err := t.insert(n.children[bit], new(BitArray).LSBs(key, 1), value) if !dirty || err != nil { return false, n, err } + // Replace the child node with the new node n = n.copy() n.flags = newFlag() n.children[bit] = newNode return true, n, nil case nil: + // We reach the end of the key, return the value node if key.IsEmpty() { return true, value, nil } + // Otherwise, return a new edge node with the path being the key and the value as the child return true, &edgeNode{path: key, child: value, flags: newFlag()}, nil case hashNode: panic("TODO(weiihann): implement me") @@ -159,7 +177,7 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { } } -func (t *Trie) delete(n node, key *BitArray) (bool, node, error) { +func (t *Trie) delete(n node, prefix, key *BitArray) (bool, node, error) { switch n := n.(type) { case *edgeNode: match := n.commonPath(key) @@ -167,25 +185,28 @@ func (t *Trie) delete(n node, key *BitArray) (bool, node, error) { if match.Len() < n.path.Len() { return false, n, nil } - // If the whole key matches, just delete the edge node + // If the whole key matches, remove the entire edge node if match.Len() == key.Len() { return true, nil, nil } - // Otherwise, we need to delete the child node - dirty, child, err := t.delete(n.child, key.LSBs(key, match.Len())) + // Otherwise, key is longer than current node path, so we need to delete the child. + // Child can never be nil because it's guaranteed that we have at least 2 other values in the subtrie. + keyPrefix := new(BitArray).MSBs(key, n.path.Len()) + dirty, child, err := t.delete(n.child, new(BitArray).Append(prefix, keyPrefix), key.LSBs(key, n.path.Len())) if !dirty || err != nil { return false, n, err } switch child := child.(type) { case *edgeNode: - return true, &edgeNode{path: n.path, child: child.child, flags: newFlag()}, nil + return true, &edgeNode{path: new(BitArray).Append(n.path, child.path), child: child.child, flags: newFlag()}, nil default: - return true, &edgeNode{path: n.path, child: child, flags: newFlag()}, nil + return true, &edgeNode{path: new(BitArray).Set(n.path), child: child, flags: newFlag()}, nil } case *binaryNode: bit := key.MSB() - dirty, newNode, err := t.delete(n.children[bit], key.LSBs(key, 1)) + keyPrefix := new(BitArray).MSBs(key, 1) + dirty, newNode, err := t.delete(n.children[bit], new(BitArray).Append(prefix, keyPrefix), key.LSBs(key, 1)) if !dirty || err != nil { return false, n, err } @@ -193,13 +214,25 @@ func (t *Trie) delete(n node, key *BitArray) (bool, node, error) { n.flags = newFlag() n.children[bit] = newNode + // If the child node is not nil, that means we still have 2 children in this binary node if newNode != nil { return true, n, nil } - // TODO(weiihann): combine this binary node with the child + // Otherwise, we need to combine this binary node with the other child + other := bit ^ 1 + bitPrefix := new(BitArray).SetBit(other == 1) + if cn, ok := n.children[other].(*edgeNode); ok { // other child is an edge node, append the bit prefix to the child path + return true, &edgeNode{ + path: new(BitArray).Append(bitPrefix, cn.path), + child: cn.child, + flags: newFlag(), + }, nil + } - return true, n, nil + // other child is not an edge node, create a new edge node with the bit prefix as the path + // containing the other child as the child + return true, &edgeNode{path: bitPrefix, child: n.children[other], flags: newFlag()}, nil case valueNode: return true, nil, nil case nil: @@ -211,15 +244,17 @@ func (t *Trie) delete(n node, key *BitArray) (bool, node, error) { } } +// Converts a Felt value into a BitArray representation suitable for +// use as a trie key with the specified height. +func (t *Trie) FeltToKey(f *felt.Felt) BitArray { + var key BitArray + key.SetFelt(t.height, f) + return key +} + func (t *Trie) String() string { if t.root == nil { return "" } return t.root.String() } - -func (t *Trie) FeltToKey(f *felt.Felt) BitArray { - var key BitArray - key.SetFelt(t.height, f) - return key -} diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go index 5f91ecc40e..f3001b694b 100644 --- a/core/trie2/trie_test.go +++ b/core/trie2/trie_test.go @@ -9,16 +9,16 @@ import ( ) func TestUpdate(t *testing.T) { - trie := NewTrie(251) + tr, records := nonRandomTrie(t, 1000) - key := new(felt.Felt).SetUint64(1) - value := new(felt.Felt).SetUint64(2) - err := trie.Update(key, value) - require.NoError(t, err) + for _, record := range records { + err := tr.Update(record.key, record.value) + require.NoError(t, err) - got, err := trie.Get(key) - require.NoError(t, err) - require.Equal(t, value, got) + got, err := tr.Get(record.key) + require.NoError(t, err) + require.Equal(t, record.value, got) + } } func TestUpdateRandom(t *testing.T) { @@ -34,9 +34,30 @@ func TestUpdateRandom(t *testing.T) { } } -func Test4KeysTrieD(t *testing.T) { - tr, _ := build4KeysTrieD(t) - t.Log(tr.String()) +func TestDelete(t *testing.T) { + tr, records := nonRandomTrie(t, 10000) + + for _, record := range records { + err := tr.Delete(record.key) + require.NoError(t, err) + + got, err := tr.Get(record.key) + require.NoError(t, err) + require.Equal(t, got, &felt.Zero) + } +} + +func TestDeleteRandom(t *testing.T) { + tr, records := randomTrie(t, 10000) + + for i := len(records) - 1; i >= 0; i-- { + err := tr.Delete(records[i].key) + require.NoError(t, err) + + got, err := tr.Get(records[i].key) + require.NoError(t, err) + require.Equal(t, got, &felt.Zero) + } } type keyValue struct { @@ -44,6 +65,20 @@ type keyValue struct { value *felt.Felt } +func nonRandomTrie(t *testing.T, numKeys int) (*Trie, []*keyValue) { + tr := NewTrie(251) + records := make([]*keyValue, numKeys) + + for i := 1; i < numKeys+1; i++ { + key := new(felt.Felt).SetUint64(uint64(i)) + records[i-1] = &keyValue{key: key, value: key} + err := tr.Update(key, key) + require.NoError(t, err) + } + + return tr, records +} + func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) { rrand := rand.New(rand.NewSource(3)) @@ -80,8 +115,6 @@ func buildTrie(t *testing.T, records []*keyValue) *Trie { for _, record := range records { err := tempTrie.Update(record.key, record.value) - t.Log("--------------------------------") - t.Log(tempTrie.String()) require.NoError(t, err) } From 0bd5aa9992ded53c9a30d2afadf9dbb0b51889eb Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 2 Jan 2025 11:14:51 +0800 Subject: [PATCH 30/30] Add hasher --- core/trie2/hasher.go | 85 +++++++++++++++++++++++++++++ core/trie2/trie.go | 18 +++++- core/trie2/trie_test.go | 118 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 216 insertions(+), 5 deletions(-) create mode 100644 core/trie2/hasher.go diff --git a/core/trie2/hasher.go b/core/trie2/hasher.go new file mode 100644 index 0000000000..7152cd12ec --- /dev/null +++ b/core/trie2/hasher.go @@ -0,0 +1,85 @@ +package trie2 + +import ( + "fmt" + "sync" + + "github.com/NethermindEth/juno/core/crypto" +) + +// hasher handles node hashing for the trie. It supports both sequential and parallel +// hashing modes. +type hasher struct { + hashFn crypto.HashFn // The hash function to use + parallel bool // Whether to hash binary node children in parallel +} + +func newHasher(hash crypto.HashFn, parallel bool) hasher { + return hasher{ + hashFn: hash, + parallel: parallel, + } +} + +// hash computes the hash of a node and returns both the hash node and a cached +// version of the original node. If the node already has a cached hash, returns +// that instead of recomputing. +func (h *hasher) hash(n node) (node, node) { + if hash, _ := n.cache(); hash != nil { + return hash, n + } + + switch n := n.(type) { + case *edgeNode: + collapsed, cached := h.hashEdgeChild(n) + hn := &hashNode{Felt: collapsed.hash(h.hashFn)} + cached.flags.hash = hn + return hn, cached + case *binaryNode: + collapsed, cached := h.hashBinaryChildren(n) + hn := &hashNode{Felt: collapsed.hash(h.hashFn)} + cached.flags.hash = hn + return hn, cached + case valueNode, hashNode: + return n, n + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +func (h *hasher) hashEdgeChild(n *edgeNode) (collapsed, cached *edgeNode) { + collapsed, cached = n.copy(), n.copy() + + switch n.child.(type) { + case *edgeNode, *binaryNode: + collapsed.child, cached.child = h.hash(n.child) + } + + return collapsed, cached +} + +func (h *hasher) hashBinaryChildren(n *binaryNode) (collapsed, cached *binaryNode) { + collapsed, cached = n.copy(), n.copy() + + if h.parallel { // TODO(weiihann): double check this parallel strategy + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + collapsed.children[0], cached.children[0] = h.hash(n.children[0]) + }() + + go func() { + defer wg.Done() + collapsed.children[1], cached.children[1] = h.hash(n.children[1]) + }() + + wg.Wait() + } else { + collapsed.children[0], cached.children[0] = h.hash(n.children[0]) + collapsed.children[1], cached.children[1] = h.hash(n.children[1]) + } + + return collapsed, cached +} diff --git a/core/trie2/trie.go b/core/trie2/trie.go index 2e97770488..cff79c3a3a 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -3,6 +3,7 @@ package trie2 import ( "fmt" + "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" ) @@ -11,11 +12,12 @@ type Trie struct { root node reader interface{} // TODO(weiihann): implement reader // committed bool + hashFn crypto.HashFn } // TODO(weiihann): implement this -func NewTrie(height uint8) *Trie { - return &Trie{height: height} +func NewTrie(height uint8, hashFn crypto.HashFn) *Trie { + return &Trie{height: height, hashFn: hashFn} } // Modifies or inserts a key-value pair in the trie. @@ -55,6 +57,13 @@ func (t *Trie) Delete(key *felt.Felt) error { return nil } +// Returns the root hash of the trie. Calling this method will also cache the hash of each node in the trie. +func (t *Trie) Hash() *felt.Felt { + hash, cached := t.hashRoot() + t.root = cached + return hash.(*hashNode).Felt +} + func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { switch n := n.(type) { case *edgeNode: @@ -244,6 +253,11 @@ func (t *Trie) delete(n node, prefix, key *BitArray) (bool, node, error) { } } +func (t *Trie) hashRoot() (node, node) { + h := newHasher(t.hashFn, false) // TODO(weiihann): handle parallel hashing + return h.hash(t.root) +} + // Converts a Felt value into a BitArray representation suitable for // use as a trie key with the specified height. func (t *Trie) FeltToKey(f *felt.Felt) BitArray { diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go index f3001b694b..6361eba4fa 100644 --- a/core/trie2/trie_test.go +++ b/core/trie2/trie_test.go @@ -4,7 +4,9 @@ import ( "math/rand" "testing" + "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/require" ) @@ -60,13 +62,123 @@ func TestDeleteRandom(t *testing.T) { } } +// The expected hashes are taken from Pathfinder's tests +func TestHash(t *testing.T) { + t.Run("one leaf", func(t *testing.T) { + tr := NewTrie(251, crypto.Pedersen) + err := tr.Update(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)) + require.NoError(t, err) + hash := tr.Hash() + + expected := "0x2ab889bd35e684623df9b4ea4a4a1f6d9e0ef39b67c1293b8a89dd17e351330" + require.Equal(t, expected, hash.String(), "expected %s, got %s", expected, hash.String()) + }) + + t.Run("two leaves", func(t *testing.T) { + tr := NewTrie(251, crypto.Pedersen) + err := tr.Update(new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(2)) + require.NoError(t, err) + err = tr.Update(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(3)) + require.NoError(t, err) + root := tr.Hash() + + expected := "0x79acdb7a3d78052114e21458e8c4aecb9d781ce79308193c61a2f3f84439f66" + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) + + t.Run("three leaves", func(t *testing.T) { + tr := NewTrie(251, crypto.Pedersen) + + keys := []*felt.Felt{ + new(felt.Felt).SetUint64(16), + new(felt.Felt).SetUint64(17), + new(felt.Felt).SetUint64(19), + } + + vals := []*felt.Felt{ + new(felt.Felt).SetUint64(10), + new(felt.Felt).SetUint64(11), + new(felt.Felt).SetUint64(12), + } + + for i := range keys { + err := tr.Update(keys[i], vals[i]) + require.NoError(t, err) + } + + expected := "0x7e2184e9e1a651fd556b42b4ff10e44a71b1709f641e0203dc8bd2b528e5e81" + root := tr.Hash() + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) + + t.Run("double binary", func(t *testing.T) { + // (249,0,x3) + // | + // (0, 0, x3) + // / \ + // (0,0,x1) (1, 1, 5) + // / \ | + // (2) (3) (5) + + keys := []*felt.Felt{ + new(felt.Felt).SetUint64(0), + new(felt.Felt).SetUint64(1), + new(felt.Felt).SetUint64(3), + } + + vals := []*felt.Felt{ + new(felt.Felt).SetUint64(2), + new(felt.Felt).SetUint64(3), + new(felt.Felt).SetUint64(5), + } + + tr := NewTrie(251, crypto.Pedersen) + for i := range keys { + err := tr.Update(keys[i], vals[i]) + require.NoError(t, err) + } + + expected := "0x6a316f09913454294c6b6751dea8449bc2e235fdc04b2ab0e1ac7fea25cc34f" + root := tr.Hash() + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) + + t.Run("binary root", func(t *testing.T) { + // (0, 0, x) + // / \ + // (250, 0, cc) (250, 11111.., dd) + // | | + // (cc) (dd) + + keys := []*felt.Felt{ + new(felt.Felt).SetUint64(0), + utils.HexToFelt(t, "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), + } + + vals := []*felt.Felt{ + utils.HexToFelt(t, "0xcc"), + utils.HexToFelt(t, "0xdd"), + } + + tr := NewTrie(251, crypto.Pedersen) + for i := range keys { + err := tr.Update(keys[i], vals[i]) + require.NoError(t, err) + } + + expected := "0x542ced3b6aeef48339129a03e051693eff6a566d3a0a94035fa16ab610dc9e2" + root := tr.Hash() + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) +} + type keyValue struct { key *felt.Felt value *felt.Felt } func nonRandomTrie(t *testing.T, numKeys int) (*Trie, []*keyValue) { - tr := NewTrie(251) + tr := NewTrie(251, crypto.Pedersen) records := make([]*keyValue, numKeys) for i := 1; i < numKeys+1; i++ { @@ -82,7 +194,7 @@ func nonRandomTrie(t *testing.T, numKeys int) (*Trie, []*keyValue) { func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) { rrand := rand.New(rand.NewSource(3)) - tr := NewTrie(251) + tr := NewTrie(251, crypto.Pedersen) records := make([]*keyValue, n) for i := 0; i < n; i++ { @@ -111,7 +223,7 @@ func buildTrie(t *testing.T, records []*keyValue) *Trie { t.Fatal("records must have at least one element") } - tempTrie := NewTrie(251) + tempTrie := NewTrie(251, crypto.Pedersen) for _, record := range records { err := tempTrie.Update(record.key, record.value)