From ed0cf70d571001290860218e38733b27e03e5861 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Mon, 20 Sep 2021 18:09:32 +0200 Subject: [PATCH] Update VT goroutines errs & Update Pack&UnpackSibl - Update VT goroutines errs to avoid race condition - Update pack & unpack siblings to use 2-byte for full length & bitmap bytes length - Add check in UnpackSiblings to avoid panic --- .gitignore | 1 + tree.go | 35 +++++++++++++++++++++-------------- tree_test.go | 2 +- vt.go | 4 ++-- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 9c31a5f..95cbb01 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ err-dump +covprofile diff --git a/tree.go b/tree.go index b3c5d02..b11aa93 100644 --- a/tree.go +++ b/tree.go @@ -73,7 +73,7 @@ var ( // Tree defines the struct that implements the MerkleTree functionalities type Tree struct { - sync.RWMutex + sync.Mutex db db.Database maxLevels int @@ -659,11 +659,12 @@ func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, } // PackSiblings packs the siblings into a byte array. -// [ 1 byte | L bytes | S * N bytes ] -// [ bitmap length (L) | bitmap | N non-zero siblings ] +// [ 2 byte | 2 byte | L bytes | S * N bytes ] +// [ full length | bitmap length (L) | bitmap | N non-zero siblings ] // Where the bitmap indicates if the sibling is 0 or a value from the siblings // array. And S is the size of the output of the hash function used for the -// Tree. +// Tree. The 2 2-byte that define the full length and bitmap length, are +// encoded in little-endian. func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte { var b []byte var bitmap []bool @@ -680,19 +681,28 @@ func PackSiblings(hashFunc HashFunction, siblings [][]byte) []byte { bitmapBytes := bitmapToBytes(bitmap) l := len(bitmapBytes) - res := make([]byte, l+1+len(b)) - res[0] = byte(l) // set the bitmapBytes length - copy(res[1:1+l], bitmapBytes) - copy(res[1+l:], b) + fullLen := 4 + l + len(b) //nolint:gomnd + res := make([]byte, fullLen) + binary.LittleEndian.PutUint16(res[0:2], uint16(fullLen)) // set full length + binary.LittleEndian.PutUint16(res[2:4], uint16(l)) // set the bitmapBytes length + copy(res[4:4+l], bitmapBytes) + copy(res[4+l:], b) return res } // UnpackSiblings unpacks the siblings from a byte array. func UnpackSiblings(hashFunc HashFunction, b []byte) ([][]byte, error) { - l := b[0] - bitmapBytes := b[1 : 1+l] + fullLen := binary.LittleEndian.Uint16(b[0:2]) + l := binary.LittleEndian.Uint16(b[2:4]) // bitmap bytes length + if len(b) != int(fullLen) { + return nil, + fmt.Errorf("error unpacking siblings. Expected len: %d, current len: %d", + fullLen, len(b)) + } + + bitmapBytes := b[4 : 4+l] bitmap := bytesToBitmap(bitmapBytes) - siblingsBytes := b[1+l:] + siblingsBytes := b[4+l:] iSibl := 0 emptySibl := make([]byte, hashFunc.Len()) var siblings [][]byte @@ -845,9 +855,6 @@ func (t *Tree) GetNLeafsWithTx(rTx db.ReadTx) (int, error) { // Snapshot returns a read-only copy of the Tree from the given root func (t *Tree) Snapshot(fromRoot []byte) (*Tree, error) { - t.RLock() - defer t.RUnlock() - // allow to define which root to use if fromRoot == nil { var err error diff --git a/tree_test.go b/tree_test.go index 530dca1..d0ce615 100644 --- a/tree_test.go +++ b/tree_test.go @@ -311,7 +311,7 @@ func TestGenProofAndVerify(t *testing.T) { c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := tree.HashFunction().Len() - 1 for i := 0; i < 10; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) diff --git a/vt.go b/vt.go index a18bbc8..9232ec6 100644 --- a/vt.go +++ b/vt.go @@ -176,7 +176,7 @@ func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) { bucketVT := newVT(t.params.maxLevels, t.params.hashFunction) bucketVT.root = nodesAtL[cpu] for j := 0; j < len(buckets[cpu]); j++ { - if err = bucketVT.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil { + if err := bucketVT.add(l, buckets[cpu][j].k, buckets[cpu][j].v); err != nil { invalidsInBucket[cpu] = append(invalidsInBucket[cpu], buckets[cpu][j].pos) } } @@ -321,7 +321,7 @@ func (t *vt) computeHashes() ([][2][]byte, error) { bucketVT := newVT(t.params.maxLevels, t.params.hashFunction) bucketVT.params.dbg = newDbgStats() bucketVT.root = nodesAtL[cpu] - + var err error bucketPairs[cpu], err = bucketVT.root.computeHashes(l-1, t.params.maxLevels, bucketVT.params, bucketPairs[cpu]) if err != nil {