Skip to content

Commit

Permalink
Update VT goroutines errs & Update Pack&UnpackSibl
Browse files Browse the repository at this point in the history
- Update VT goroutines errs to avoid race condition
- Update pack & unpack siblings to use 2-byte for full length & bitmap
bytes length
- Add check in UnpackSiblings to avoid panic
  • Loading branch information
arnaucube committed Sep 21, 2021
1 parent f09b0b0 commit ed0cf70
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
err-dump
covprofile
35 changes: 21 additions & 14 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ var (

// Tree defines the struct that implements the MerkleTree functionalities
type Tree struct {
sync.RWMutex
sync.Mutex

db db.Database
maxLevels int
Expand Down Expand Up @@ -659,11 +659,12 @@ func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte,
}

// PackSiblings packs the siblings into a byte array.
// [ 1 byte | L bytes | S * N bytes ]
// [ bitmap length (L) | bitmap | N non-zero siblings ]
// [ 2 byte | 2 byte | L bytes | S * N bytes ]
// [ full length | bitmap length (L) | bitmap | N non-zero siblings ]
// Where the bitmap indicates if the sibling is 0 or a value from the siblings
// array. And S is the size of the output of the hash function used for the
// Tree.
// Tree. The 2 2-byte that define the full length and bitmap length, are
// encoded in little-endian.
func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
var b []byte
var bitmap []bool
Expand All @@ -680,19 +681,28 @@ func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte {
bitmapBytes := bitmapToBytes(bitmap)
l := len(bitmapBytes)

res := make([]byte, l+1+len(b))
res[0] = byte(l) // set the bitmapBytes length
copy(res[1:1+l], bitmapBytes)
copy(res[1+l:], b)
fullLen := 4 + l + len(b) //nolint:gomnd
res := make([]byte, fullLen)
binary.LittleEndian.PutUint16(res[0:2], uint16(fullLen)) // set full length
binary.LittleEndian.PutUint16(res[2:4], uint16(l)) // set the bitmapBytes length
copy(res[4:4+l], bitmapBytes)
copy(res[4+l:], b)
return res
}

// UnpackSiblings unpacks the siblings from a byte array.
func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) {
l := b[0]
bitmapBytes := b[1 : 1+l]
fullLen := binary.LittleEndian.Uint16(b[0:2])
l := binary.LittleEndian.Uint16(b[2:4]) // bitmap bytes length
if len(b) != int(fullLen) {
return nil,
fmt.Errorf("error unpacking siblings. Expected len: %d, current len: %d",
fullLen, len(b))
}

bitmapBytes := b[4 : 4+l]
bitmap := bytesToBitmap(bitmapBytes)
siblingsBytes := b[1+l:]
siblingsBytes := b[4+l:]
iSibl := 0
emptySibl := make([]byte, hashFunc.Len())
var siblings [][]byte
Expand Down Expand Up @@ -845,9 +855,6 @@ func (t *Tree) GetNLeafsWithTx(rTx db.ReadTx) (int, error) {

// Snapshot returns a read-only copy of the Tree from the given root
func (t *Tree) Snapshot(fromRoot []byte) (*Tree, error) {
t.RLock()
defer t.RUnlock()

// allow to define which root to use
if fromRoot == nil {
var err error
Expand Down
2 changes: 1 addition & 1 deletion tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func TestGenProofAndVerify(t *testing.T) {
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck

bLen := tree.HashFunction().Len()
bLen := tree.HashFunction().Len() - 1
for i := 0; i < 10; i++ {
k := BigIntToBytes(bLen, big.NewInt(int64(i)))
v := BigIntToBytes(bLen, big.NewInt(int64(i*2)))
Expand Down
4 changes: 2 additions & 2 deletions vt.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) {
bucketVT := newVT(t.params.maxLevels, t.params.hashFunction)
bucketVT.root = nodesAtL[cpu]
for j := 0; j < len(buckets[cpu]); j++ {
if err = bucketVT.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil {
if err := bucketVT.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil {
invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos)
}
}
Expand Down Expand Up @@ -321,7 +321,7 @@ func (t *vt) computeHashes() ([][2][]byte, error) {
bucketVT := newVT(t.params.maxLevels, t.params.hashFunction)
bucketVT.params.dbg = newDbgStats()
bucketVT.root = nodesAtL[cpu]

var err error
bucketPairs[cpu], err = bucketVT.root.computeHashes(l-1,
t.params.maxLevels, bucketVT.params, bucketPairs[cpu])
if err != nil {
Expand Down

0 comments on commit ed0cf70

Please sign in to comment.