From 30d8b42fd35707dea5cc4c3360568114c7b84e19 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Mon, 4 Oct 2021 11:27:21 +0200 Subject: [PATCH] Add checks that len(key)<=maxKeyLen Add checks that the key is not bigger than maximum key length for the tree maxLevels size, where maximum key len = ceil(maxLevels/8). This is because if the key bits length is bigger than the maxLevels of the tree, two different keys that their difference is at the end, will collision in the same leaf of the tree (at the max depth). --- addbatch_test.go | 58 ++++---- circomproofs_test.go | 3 +- .../go-data-generator/generator_test.go | 3 +- tree.go | 50 +++++-- tree_test.go | 139 +++++++++++++----- utils.go | 4 +- vt.go | 50 ++++--- vt_test.go | 102 +++++++------ 8 files changed, 263 insertions(+), 146 deletions(-) diff --git a/addbatch_test.go b/addbatch_test.go index d02fd10..e6e4cb2 100644 --- a/addbatch_test.go +++ b/addbatch_test.go @@ -39,12 +39,12 @@ func debugTime(descr string, time1, time2 time.Duration) { func testInit(c *qt.C, n int) (*Tree, *Tree) { database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database1, 100, HashFunctionPoseidon) + tree1, err := NewTree(database1, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) bLen := HashFunctionPoseidon.Len() @@ -70,11 +70,11 @@ func TestAddBatchTreeEmpty(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 var keys, values [][]byte for i := 0; i < nLeafs; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) @@ -93,7 +93,7 @@ func TestAddBatchTreeEmpty(t *testing.T) { database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck tree2.dbgInit() @@ -120,11 +120,11 @@ func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 for i := 0; i < nLeafs; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) @@ -135,7 +135,7 @@ func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) { database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -167,13 +167,13 @@ func TestAddBatchTestVector1(t *testing.T) { c := qt.New(t) database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database1, 100, HashFunctionBlake2b) + tree1, err := NewTree(database1, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionBlake2b) + tree2, err := NewTree(database2, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -207,13 +207,13 @@ func TestAddBatchTestVector1(t *testing.T) { // 2nd test vectors database1, err = badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err = NewTree(database1, 100, HashFunctionBlake2b) + tree1, err = NewTree(database1, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck database2, err = badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err = NewTree(database2, 100, HashFunctionBlake2b) + tree2, err = NewTree(database2, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -255,13 +255,13 @@ func TestAddBatchTestVector2(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database, 100, HashFunctionPoseidon) + tree1, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -300,13 +300,13 @@ func TestAddBatchTestVector3(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database, 100, HashFunctionPoseidon) + tree1, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -349,13 +349,13 @@ func TestAddBatchTreeEmptyRandomKeys(t *testing.T) { database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database1, 100, HashFunctionBlake2b) + tree1, err := NewTree(database1, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionBlake2b) + tree2, err := NewTree(database2, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -699,7 +699,7 @@ func TestAddBatchNotEmptyUnbalanced(t *testing.T) { database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck tree2.dbgInit() @@ -776,7 +776,7 @@ func benchAdd(t *testing.T, ks, vs [][]byte) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 140, HashFunctionBlake2b) + tree, err := NewTree(database, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck @@ -796,7 +796,7 @@ func benchAddBatch(t *testing.T, ks, vs [][]byte) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 140, HashFunctionBlake2b) + tree, err := NewTree(database, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck @@ -829,7 +829,7 @@ func TestDbgStats(t *testing.T) { // 1 database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database1, 100, HashFunctionBlake2b) + tree1, err := NewTree(database1, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck @@ -843,7 +843,7 @@ func TestDbgStats(t *testing.T) { // 2 database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionBlake2b) + tree2, err := NewTree(database2, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -856,7 +856,7 @@ func TestDbgStats(t *testing.T) { // 3 database3, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree3, err := NewTree(database3, 100, HashFunctionBlake2b) + tree3, err := NewTree(database3, 256, HashFunctionBlake2b) c.Assert(err, qt.IsNil) defer tree3.db.Close() //nolint:errcheck @@ -891,7 +891,7 @@ func TestLoadVT(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck @@ -927,11 +927,11 @@ func TestAddKeysWithEmptyValues(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 var keys, values [][]byte for i := 0; i < nLeafs; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) @@ -948,7 +948,7 @@ func TestAddKeysWithEmptyValues(t *testing.T) { database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck tree2.dbgInit() @@ -962,7 +962,7 @@ func TestAddKeysWithEmptyValues(t *testing.T) { // use tree3 to add nil value array database3, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree3, err := NewTree(database3, 100, HashFunctionPoseidon) + tree3, err := NewTree(database3, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree3.db.Close() //nolint:errcheck diff --git a/circomproofs_test.go b/circomproofs_test.go index 39ad1e9..c8345a1 100644 --- a/circomproofs_test.go +++ b/circomproofs_test.go @@ -17,14 +17,13 @@ func TestCircomVerifierProof(t *testing.T) { c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() - testVector := [][]int64{ {1, 11}, {2, 22}, {3, 33}, {4, 44}, } + bLen := 1 for i := 0; i < len(testVector); i++ { k := BigIntToBytes(bLen, big.NewInt(testVector[i][0])) v := BigIntToBytes(bLen, big.NewInt(testVector[i][1])) diff --git a/testvectors/circom/go-data-generator/generator_test.go b/testvectors/circom/go-data-generator/generator_test.go index 91e45d9..ff55129 100644 --- a/testvectors/circom/go-data-generator/generator_test.go +++ b/testvectors/circom/go-data-generator/generator_test.go @@ -18,14 +18,13 @@ func TestGenerator(t *testing.T) { tree, err := arbo.NewTree(database, 4, arbo.HashFunctionPoseidon) c.Assert(err, qt.IsNil) - bLen := tree.HashFunction().Len() - testVector := [][]int64{ {1, 11}, {2, 22}, {3, 33}, {4, 44}, } + bLen := 1 for i := 0; i < len(testVector); i++ { k := arbo.BigIntToBytes(bLen, big.NewInt(testVector[i][0])) v := arbo.BigIntToBytes(bLen, big.NewInt(testVector[i][1])) diff --git a/tree.go b/tree.go index 6f856d0..0e11b8f 100644 --- a/tree.go +++ b/tree.go @@ -229,8 +229,10 @@ func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]int, err } // store root (from the vt) to db - if err := wTx.Set(dbKeyRoot, vt.root.h); err != nil { - return nil, err + if vt.root != nil { + if err := wTx.Set(dbKeyRoot, vt.root.h); err != nil { + return nil, err + } } // update nLeafs @@ -310,14 +312,34 @@ func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error { return nil } -func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) { - keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd +// keyPathFromKey returns the keyPath and checks that the key is not bigger +// than maximum key length for the tree maxLevels size. +// This is because if the key bits length is bigger than the maxLevels of the +// tree, two different keys that their difference is at the end, will collision +// in the same leaf of the tree (at the max depth). +func keyPathFromKey(maxLevels int, k []byte) ([]byte, error) { + maxKeyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd + if len(k) > maxKeyLen { + return nil, fmt.Errorf("len(k) can not be bigger than ceil(maxLevels/8), where"+ + " len(k): %d, maxLevels: %d, max key len=ceil(maxLevels/8): %d. Might need"+ + " a bigger tree depth (maxLevels>=%d) in order to input keys of length %d", + len(k), maxLevels, maxKeyLen, len(k)*8, len(k)) //nolint:gomnd + } + keyPath := make([]byte, maxKeyLen) //nolint:gomnd copy(keyPath[:], k) + return keyPath, nil +} + +func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) { + keyPath, err := keyPathFromKey(t.maxLevels, k) + if err != nil { + return nil, err + } path := getPath(t.maxLevels, keyPath) // go down to the leaf var siblings [][]byte - _, _, siblings, err := t.down(wTx, k, root, siblings, path, fromLvl, false) + _, _, siblings, err = t.down(wTx, k, root, siblings, path, fromLvl, false) if err != nil { return nil, err } @@ -590,8 +612,10 @@ func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error { return ErrSnapshotNotEditable } - keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd - copy(keyPath[:], k) + keyPath, err := keyPathFromKey(t.maxLevels, k) + if err != nil { + return err + } path := getPath(t.maxLevels, keyPath) root, err := t.RootWithTx(wTx) @@ -647,8 +671,10 @@ func (t *Tree) GenProof(k []byte) ([]byte, []byte, []byte, bool, error) { // GenProofWithTx does the same than the GenProof method, but allowing to pass // the db.ReadTx that is used. func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, bool, error) { - keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd - copy(keyPath[:], k) + keyPath, err := keyPathFromKey(t.maxLevels, k) + if err != nil { + return nil, nil, nil, false, err + } path := getPath(t.maxLevels, keyPath) root, err := t.RootWithTx(rTx) @@ -782,8 +808,10 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) { // ErrKeyNotFound, and in the leafK & leafV parameters will be placed the data // found in the tree in the leaf that was on the path going to the input key. func (t *Tree) GetWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, error) { - keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd - copy(keyPath[:], k) + keyPath, err := keyPathFromKey(t.maxLevels, k) + if err != nil { + return nil, nil, err + } path := getPath(t.maxLevels, keyPath) root, err := t.RootWithTx(rTx) diff --git a/tree_test.go b/tree_test.go index f6592c4..a8912c5 100644 --- a/tree_test.go +++ b/tree_test.go @@ -2,7 +2,9 @@ package arbo import ( "encoding/hex" + "math" "math/big" + "runtime" "testing" "time" @@ -60,7 +62,7 @@ func TestAddTestVectors(t *testing.T) { func testAdd(c *qt.C, hashFunc HashFunction, testVectors []string) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 10, hashFunc) + tree, err := NewTree(database, 256, hashFunc) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck @@ -68,7 +70,7 @@ func testAdd(c *qt.C, hashFunc HashFunction, testVectors []string) { c.Assert(err, qt.IsNil) c.Check(hex.EncodeToString(root), qt.Equals, testVectors[0]) - bLen := hashFunc.Len() + bLen := 32 err = tree.Add( BigIntToBytes(bLen, big.NewInt(1)), BigIntToBytes(bLen, big.NewInt(2))) @@ -92,11 +94,11 @@ func TestAddBatch(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 for i := 0; i < 1000; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(0)) @@ -110,7 +112,7 @@ func TestAddBatch(t *testing.T) { database, err = badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database, 100, HashFunctionPoseidon) + tree2, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -133,11 +135,11 @@ func TestAddDifferentOrder(t *testing.T) { c := qt.New(t) database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database1, 100, HashFunctionPoseidon) + tree1, err := NewTree(database1, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck - bLen := tree1.HashFunction().Len() + bLen := 32 for i := 0; i < 16; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(0)) @@ -148,7 +150,7 @@ func TestAddDifferentOrder(t *testing.T) { database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck @@ -173,11 +175,11 @@ func TestAddRepeatedIndex(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 k := BigIntToBytes(bLen, big.NewInt(int64(3))) v := BigIntToBytes(bLen, big.NewInt(int64(12))) @@ -191,11 +193,11 @@ func TestUpdate(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 k := BigIntToBytes(bLen, big.NewInt(int64(20))) v := BigIntToBytes(bLen, big.NewInt(int64(12))) if err := tree.Add(k, v); err != nil { @@ -244,11 +246,11 @@ func TestAux(t *testing.T) { // TODO split in proper tests c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 k := BigIntToBytes(bLen, big.NewInt(int64(1))) v := BigIntToBytes(bLen, big.NewInt(int64(0))) err = tree.Add(k, v) @@ -283,11 +285,11 @@ func TestGet(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 for i := 0; i < 10; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) @@ -307,11 +309,11 @@ func TestGenProofAndVerify(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() - 1 + bLen := 32 for i := 0; i < 10; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) @@ -339,11 +341,11 @@ func TestDumpAndImportDump(t *testing.T) { c := qt.New(t) database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree1, err := NewTree(database1, 100, HashFunctionPoseidon) + tree1, err := NewTree(database1, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree1.db.Close() //nolint:errcheck - bLen := tree1.HashFunction().Len() + bLen := 32 for i := 0; i < 16; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) v := BigIntToBytes(bLen, big.NewInt(int64(i*2))) @@ -357,7 +359,7 @@ func TestDumpAndImportDump(t *testing.T) { database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree2, err := NewTree(database2, 100, HashFunctionPoseidon) + tree2, err := NewTree(database2, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree2.db.Close() //nolint:errcheck err = tree2.ImportDump(e) @@ -376,11 +378,11 @@ func TestRWMutex(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 var keys, values [][]byte for i := 0; i < 1000; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) @@ -469,7 +471,7 @@ func TestAddBatchFullyUsed(t *testing.T) { var keys, values [][]byte for i := 0; i < 16; i++ { - k := BigIntToBytes(32, big.NewInt(int64(i))) + k := BigIntToBytes(1, big.NewInt(int64(i))) v := k keys = append(keys, k) @@ -492,10 +494,10 @@ func TestAddBatchFullyUsed(t *testing.T) { // get all key-values and check that are equal between both trees for i := 0; i < 16; i++ { - auxK1, auxV1, err := tree1.Get(BigIntToBytes(32, big.NewInt(int64(i)))) + auxK1, auxV1, err := tree1.Get(BigIntToBytes(1, big.NewInt(int64(i)))) c.Assert(err, qt.IsNil) - auxK2, auxV2, err := tree2.Get(BigIntToBytes(32, big.NewInt(int64(i)))) + auxK2, auxV2, err := tree2.Get(BigIntToBytes(1, big.NewInt(int64(i)))) c.Assert(err, qt.IsNil) c.Assert(auxK1, qt.DeepEquals, auxK2) @@ -504,7 +506,7 @@ func TestAddBatchFullyUsed(t *testing.T) { // try adding one more key to both trees (through Add & AddBatch) and // expect not being added due the tree is already full - k := BigIntToBytes(32, big.NewInt(int64(16))) + k := BigIntToBytes(1, big.NewInt(int64(16))) v := k err = tree1.Add(k, v) c.Assert(err, qt.Equals, ErrMaxVirtualLevel) @@ -518,13 +520,13 @@ func TestSetRoot(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) expectedRoot := "13742386369878513332697380582061714160370929283209286127733983161245560237407" // fill the tree - bLen := tree.HashFunction().Len() + bLen := 32 var keys, values [][]byte for i := 0; i < 1000; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) @@ -574,11 +576,11 @@ func TestSnapshot(t *testing.T) { c := qt.New(t) database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) // fill the tree - bLen := tree.HashFunction().Len() + bLen := 32 var keys, values [][]byte for i := 0; i < 1000; i++ { k := BigIntToBytes(bLen, big.NewInt(int64(i))) @@ -624,11 +626,11 @@ func TestGetFromSnapshotExpectArboErrKeyNotFound(t *testing.T) { database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) c.Assert(err, qt.IsNil) - tree, err := NewTree(database, 100, HashFunctionPoseidon) + tree, err := NewTree(database, 256, HashFunctionPoseidon) c.Assert(err, qt.IsNil) defer tree.db.Close() //nolint:errcheck - bLen := tree.HashFunction().Len() + bLen := 32 k := BigIntToBytes(bLen, big.NewInt(int64(3))) root, err := tree.Root() @@ -646,7 +648,7 @@ func TestKeyLen(t *testing.T) { c.Assert(err, qt.IsNil) // maxLevels is 100, keyPath length = ceil(maxLevels/8) = 13 maxLevels := 100 - tree, err := NewTree(database, maxLevels, HashFunctionPoseidon) + tree, err := NewTree(database, maxLevels, HashFunctionBlake2b) c.Assert(err, qt.IsNil) // expect no errors when adding a key of only 4 bytes (when the @@ -672,6 +674,75 @@ func TestKeyLen(t *testing.T) { invalids, err := tree.AddBatch([][]byte{k}, [][]byte{v}) c.Assert(err, qt.IsNil) c.Assert(len(invalids), qt.Equals, 0) + + // expect errors when adding a key bigger than maximum capacity of the + // tree: ceil(maxLevels/8) + maxLevels = 32 + database, err = badgerdb.New(badgerdb.Options{Path: c.TempDir()}) + c.Assert(err, qt.IsNil) + tree, err = NewTree(database, maxLevels, HashFunctionBlake2b) + c.Assert(err, qt.IsNil) + + maxKeyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd + k = BigIntToBytes(maxKeyLen+1, big.NewInt(1)) + v = BigIntToBytes(maxKeyLen+1, big.NewInt(1)) + + expectedErrMsg := "len(k) can not be bigger than ceil(maxLevels/8)," + + " where len(k): 5, maxLevels: 32, max key len=ceil(maxLevels/8): 4." + + " Might need a bigger tree depth (maxLevels>=40) in order to input" + + " keys of length 5" + + err = tree.Add(k, v) + c.Assert(err.Error(), qt.Equals, expectedErrMsg) + + err = tree.Update(k, v) + c.Assert(err.Error(), qt.Equals, expectedErrMsg) + + _, _, _, _, err = tree.GenProof(k) + c.Assert(err.Error(), qt.Equals, expectedErrMsg) + + _, _, err = tree.Get(k) + c.Assert(err.Error(), qt.Equals, expectedErrMsg) + + // check AddBatch with few key-values + invalids, err = tree.AddBatch([][]byte{k}, [][]byte{v}) + c.Assert(err, qt.IsNil) + c.Assert(len(invalids), qt.Equals, 1) + + // check AddBatch with many key-values + nCPU := flp2(runtime.NumCPU()) + nKVs := nCPU + 1 + var ks, vs [][]byte + for i := 0; i < nKVs; i++ { + ks = append(ks, BigIntToBytes(maxKeyLen+1, big.NewInt(1))) + vs = append(vs, BigIntToBytes(maxKeyLen+1, big.NewInt(1))) + } + invalids, err = tree.AddBatch(ks, vs) + c.Assert(err, qt.IsNil) + c.Assert(len(invalids), qt.Equals, nKVs) + + // check that with maxKeyLen it can be added + k = BigIntToBytes(maxKeyLen, big.NewInt(1)) + err = tree.Add(k, v) + c.Assert(err, qt.IsNil) + + // check CheckProof check with key longer + kAux, vAux, packedSiblings, existence, err := tree.GenProof(k) + c.Assert(err, qt.IsNil) + c.Assert(existence, qt.IsTrue) + + root, err := tree.Root() + c.Assert(err, qt.IsNil) + verif, err := CheckProof(tree.HashFunction(), kAux, vAux, root, packedSiblings) + c.Assert(err, qt.IsNil) + c.Assert(verif, qt.IsTrue) + + // use a similar key but with one zero, expect that CheckProof fails on + // the verification + kAux = append(kAux, 0) + verif, err = CheckProof(tree.HashFunction(), kAux, vAux, root, packedSiblings) + c.Assert(err, qt.IsNil) + c.Assert(verif, qt.IsFalse) } func BenchmarkAdd(b *testing.B) { diff --git a/utils.go b/utils.go index d7a225f..c8fe02e 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,8 @@ package arbo -import "math/big" +import ( + "math/big" +) // SwapEndianness swaps the order of the bytes in the byte slice. func SwapEndianness(b []byte) []byte { diff --git a/vt.go b/vt.go index 939c6b6..55bac06 100644 --- a/vt.go +++ b/vt.go @@ -37,22 +37,32 @@ type kv struct { v []byte } -func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) { +func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, []int, error) { if len(ks) != len(vs) { - return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)", + return nil, nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)", len(ks), len(vs)) } - kvs := make([]kv, len(ks)) + var invalids []int + var kvs []kv for i := 0; i < len(ks); i++ { - keyPath := make([]byte, int(math.Ceil(float64(p.maxLevels)/float64(8)))) //nolint:gomnd - copy(keyPath[:], ks[i]) - kvs[i].pos = i - kvs[i].keyPath = keyPath - kvs[i].k = ks[i] - kvs[i].v = vs[i] + keyPath, err := keyPathFromKey(p.maxLevels, ks[i]) + if err != nil { + // TODO in a future iteration, invalids will contain + // the reason of the error of why each index is + // invalid. + invalids = append(invalids, i) + continue + } + + var kvsI kv + kvsI.pos = i + kvsI.keyPath = keyPath + kvsI.k = ks[i] + kvsI.v = vs[i] + kvs = append(kvs, kvsI) } - return kvs, nil + return kvs, invalids, nil } // vt stands for virtual tree. It's a tree that does not have any computed hash @@ -94,9 +104,9 @@ func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) { l := int(math.Log2(float64(nCPU))) - kvs, err := t.params.keysValuesToKvs(ks, vs) + kvs, invalids, err := t.params.keysValuesToKvs(ks, vs) if err != nil { - return nil, err + return invalids, err } buckets := splitInBuckets(kvs, nCPU) @@ -186,7 +196,6 @@ func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) { } wg.Wait() - var invalids []int for i := 0; i < len(invalidsInBucket); i++ { invalids = append(invalids, invalidsInBucket[i]...) } @@ -284,7 +293,10 @@ func upFromNodes(ns []*node) (*node, error) { // add adds a key&value as a leaf in the VirtualTree func (t *vt) add(fromLvl int, k, v []byte) error { - leaf := newLeafNode(t.params, k, v) + leaf, err := newLeafNode(t.params, k, v) + if err != nil { + return err + } if t.root == nil { t.root = leaf return nil @@ -366,16 +378,18 @@ func (t *vt) computeHashes() ([][2][]byte, error) { return pairs, nil } -func newLeafNode(p *params, k, v []byte) *node { - keyPath := make([]byte, p.hashFunction.Len()) - copy(keyPath[:], k) +func newLeafNode(p *params, k, v []byte) (*node, error) { + keyPath, err := keyPathFromKey(p.maxLevels, k) + if err != nil { + return nil, err + } path := getPath(p.maxLevels, keyPath) n := &node{ k: k, v: v, path: path, } - return n + return n, nil } type virtualNodeType int diff --git a/vt_test.go b/vt_test.go index 2ec9f23..4f60454 100644 --- a/vt_test.go +++ b/vt_test.go @@ -2,6 +2,7 @@ package arbo import ( "encoding/hex" + "math" "math/big" "testing" @@ -9,28 +10,62 @@ import ( "go.vocdoni.io/dvote/db/badgerdb" ) +// testVirtualTree adds the given key-values and tests the vt root against the +// Tree +func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) { + c.Assert(len(keys), qt.Equals, len(values)) + + // normal tree, to have an expected root value + database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) + c.Assert(err, qt.IsNil) + tree, err := NewTree(database, maxLevels, HashFunctionSha256) + c.Assert(err, qt.IsNil) + for i := 0; i < len(keys); i++ { + err := tree.Add(keys[i], values[i]) + c.Assert(err, qt.IsNil) + } + + // virtual tree + vTree := newVT(maxLevels, HashFunctionSha256) + + c.Assert(vTree.root, qt.IsNil) + + for i := 0; i < len(keys); i++ { + err := vTree.add(0, keys[i], values[i]) + c.Assert(err, qt.IsNil) + } + + // compute hashes, and check Root + _, err = vTree.computeHashes() + c.Assert(err, qt.IsNil) + root, err := tree.Root() + c.Assert(err, qt.IsNil) + c.Assert(vTree.root.h, qt.DeepEquals, root) +} + func TestVirtualTreeTestVectors(t *testing.T) { c := qt.New(t) - bLen := 32 + maxLevels := 32 + keyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd keys := [][]byte{ - BigIntToBytes(bLen, big.NewInt(1)), - BigIntToBytes(bLen, big.NewInt(33)), - BigIntToBytes(bLen, big.NewInt(1234)), - BigIntToBytes(bLen, big.NewInt(123456789)), + BigIntToBytes(keyLen, big.NewInt(1)), + BigIntToBytes(keyLen, big.NewInt(33)), + BigIntToBytes(keyLen, big.NewInt(1234)), + BigIntToBytes(keyLen, big.NewInt(123456789)), } values := [][]byte{ - BigIntToBytes(bLen, big.NewInt(2)), - BigIntToBytes(bLen, big.NewInt(44)), - BigIntToBytes(bLen, big.NewInt(9876)), - BigIntToBytes(bLen, big.NewInt(987654321)), + BigIntToBytes(keyLen, big.NewInt(2)), + BigIntToBytes(keyLen, big.NewInt(44)), + BigIntToBytes(keyLen, big.NewInt(9876)), + BigIntToBytes(keyLen, big.NewInt(987654321)), } // check the root for different batches of leafs - testVirtualTree(c, 10, keys[:1], values[:1]) - testVirtualTree(c, 10, keys[:2], values[:2]) - testVirtualTree(c, 10, keys[:3], values[:3]) - testVirtualTree(c, 10, keys[:4], values[:4]) + testVirtualTree(c, maxLevels, keys[:1], values[:1]) + testVirtualTree(c, maxLevels, keys[:2], values[:2]) + testVirtualTree(c, maxLevels, keys[:3], values[:3]) + testVirtualTree(c, maxLevels, keys[:4], values[:4]) // test with hardcoded values testvectorKeys := []string{ @@ -53,8 +88,8 @@ func TestVirtualTreeTestVectors(t *testing.T) { } // check the root for different batches of leafs - testVirtualTree(c, 10, keys[:1], values[:1]) - testVirtualTree(c, 10, keys, values) + testVirtualTree(c, 256, keys[:1], values[:1]) + testVirtualTree(c, 256, keys, values) } func TestVirtualTreeRandomKeys(t *testing.T) { @@ -69,45 +104,14 @@ func TestVirtualTreeRandomKeys(t *testing.T) { values[i] = randomBytes(32) } - testVirtualTree(c, 100, keys, values) -} - -func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) { - c.Assert(len(keys), qt.Equals, len(values)) - - // normal tree, to have an expected root value - database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()}) - c.Assert(err, qt.IsNil) - tree, err := NewTree(database, maxLevels, HashFunctionSha256) - c.Assert(err, qt.IsNil) - for i := 0; i < len(keys); i++ { - err := tree.Add(keys[i], values[i]) - c.Assert(err, qt.IsNil) - } - - // virtual tree - vTree := newVT(maxLevels, HashFunctionSha256) - - c.Assert(vTree.root, qt.IsNil) - - for i := 0; i < len(keys); i++ { - err := vTree.add(0, keys[i], values[i]) - c.Assert(err, qt.IsNil) - } - - // compute hashes, and check Root - _, err = vTree.computeHashes() - c.Assert(err, qt.IsNil) - root, err := tree.Root() - c.Assert(err, qt.IsNil) - c.Assert(vTree.root.h, qt.DeepEquals, root) + testVirtualTree(c, 256, keys, values) } func TestVirtualTreeAddBatch(t *testing.T) { c := qt.New(t) nLeafs := 2000 - maxLevels := 100 + maxLevels := 256 keys := make([][]byte, nLeafs) values := make([][]byte, nLeafs) @@ -151,7 +155,7 @@ func TestVirtualTreeAddBatchFullyUsed(t *testing.T) { var keys, values [][]byte for i := 0; i < 128; i++ { - k := BigIntToBytes(32, big.NewInt(int64(i))) + k := BigIntToBytes(1, big.NewInt(int64(i))) v := k keys = append(keys, k)