Skip to content

Commit

Permalink
Add checks that len(key)<=maxKeyLen
Browse files Browse the repository at this point in the history
Add checks that the key is not bigger than maximum key length for the tree
maxLevels size, where maximum key len = ceil(maxLevels/8).

This is because if the key bits length is bigger than the maxLevels of the
tree, two different keys that their difference is at the end, will collision in
the same leaf of the tree (at the max depth).
  • Loading branch information
arnaucube committed Oct 4, 2021
1 parent 0921cac commit 30d8b42
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 146 deletions.
58 changes: 29 additions & 29 deletions addbatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ func debugTime(descr string, time1, time2 time.Duration) {
func testInit(c *qt.C, n int) (*Tree, *Tree) {
database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree1, err := NewTree(database1, 100, HashFunctionPoseidon)
tree1, err := NewTree(database1, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)

database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionPoseidon)
tree2, err := NewTree(database2, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)

bLen := HashFunctionPoseidon.Len()
Expand All @@ -70,11 +70,11 @@ func TestAddBatchTreeEmpty(t *testing.T) {

database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree, err := NewTree(database, 100, HashFunctionPoseidon)
tree, err := NewTree(database, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck

bLen := tree.HashFunction().Len()
bLen := 32
var keys, values [][]byte
for i := 0; i < nLeafs; i++ {
k := BigIntToBytes(bLen, big.NewInt(int64(i)))
Expand All @@ -93,7 +93,7 @@ func TestAddBatchTreeEmpty(t *testing.T) {

database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionPoseidon)
tree2, err := NewTree(database2, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
tree2.dbgInit()
Expand All @@ -120,11 +120,11 @@ func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) {

database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree, err := NewTree(database, 100, HashFunctionPoseidon)
tree, err := NewTree(database, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck

bLen := tree.HashFunction().Len()
bLen := 32
for i := 0; i < nLeafs; i++ {
k := BigIntToBytes(bLen, big.NewInt(int64(i)))
v := BigIntToBytes(bLen, big.NewInt(int64(i*2)))
Expand All @@ -135,7 +135,7 @@ func TestAddBatchTreeEmptyNotPowerOf2(t *testing.T) {

database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionPoseidon)
tree2, err := NewTree(database2, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck

Expand Down Expand Up @@ -167,13 +167,13 @@ func TestAddBatchTestVector1(t *testing.T) {
c := qt.New(t)
database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree1, err := NewTree(database1, 100, HashFunctionBlake2b)
tree1, err := NewTree(database1, 256, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck

database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionBlake2b)
tree2, err := NewTree(database2, 256, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck

Expand Down Expand Up @@ -207,13 +207,13 @@ func TestAddBatchTestVector1(t *testing.T) {
// 2nd test vectors
database1, err = badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree1, err = NewTree(database1, 100, HashFunctionBlake2b)
tree1, err = NewTree(database1, 256, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck

database2, err = badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err = NewTree(database2, 100, HashFunctionBlake2b)
tree2, err = NewTree(database2, 256, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck

Expand Down Expand Up @@ -255,13 +255,13 @@ func TestAddBatchTestVector2(t *testing.T) {

database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree1, err := NewTree(database, 100, HashFunctionPoseidon)
tree1, err := NewTree(database, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck

database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionPoseidon)
tree2, err := NewTree(database2, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck

Expand Down Expand Up @@ -300,13 +300,13 @@ func TestAddBatchTestVector3(t *testing.T) {

database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree1, err := NewTree(database, 100, HashFunctionPoseidon)
tree1, err := NewTree(database, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck

database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionPoseidon)
tree2, err := NewTree(database2, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck

Expand Down Expand Up @@ -349,13 +349,13 @@ func TestAddBatchTreeEmptyRandomKeys(t *testing.T) {

database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree1, err := NewTree(database1, 100, HashFunctionBlake2b)
tree1, err := NewTree(database1, 256, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck

database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionBlake2b)
tree2, err := NewTree(database2, 256, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck

Expand Down Expand Up @@ -699,7 +699,7 @@ func TestAddBatchNotEmptyUnbalanced(t *testing.T) {

database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionPoseidon)
tree2, err := NewTree(database2, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
tree2.dbgInit()
Expand Down Expand Up @@ -776,7 +776,7 @@ func benchAdd(t *testing.T, ks, vs [][]byte) {

database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree, err := NewTree(database, 140, HashFunctionBlake2b)
tree, err := NewTree(database, 256, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck

Expand All @@ -796,7 +796,7 @@ func benchAddBatch(t *testing.T, ks, vs [][]byte) {

database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree, err := NewTree(database, 140, HashFunctionBlake2b)
tree, err := NewTree(database, 256, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck

Expand Down Expand Up @@ -829,7 +829,7 @@ func TestDbgStats(t *testing.T) {
// 1
database1, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree1, err := NewTree(database1, 100, HashFunctionBlake2b)
tree1, err := NewTree(database1, 256, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
defer tree1.db.Close() //nolint:errcheck

Expand All @@ -843,7 +843,7 @@ func TestDbgStats(t *testing.T) {
// 2
database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionBlake2b)
tree2, err := NewTree(database2, 256, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck

Expand All @@ -856,7 +856,7 @@ func TestDbgStats(t *testing.T) {
// 3
database3, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree3, err := NewTree(database3, 100, HashFunctionBlake2b)
tree3, err := NewTree(database3, 256, HashFunctionBlake2b)
c.Assert(err, qt.IsNil)
defer tree3.db.Close() //nolint:errcheck

Expand Down Expand Up @@ -891,7 +891,7 @@ func TestLoadVT(t *testing.T) {

database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree, err := NewTree(database, 100, HashFunctionPoseidon)
tree, err := NewTree(database, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck

Expand Down Expand Up @@ -927,11 +927,11 @@ func TestAddKeysWithEmptyValues(t *testing.T) {

database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree, err := NewTree(database, 100, HashFunctionPoseidon)
tree, err := NewTree(database, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck

bLen := tree.HashFunction().Len()
bLen := 32
var keys, values [][]byte
for i := 0; i < nLeafs; i++ {
k := BigIntToBytes(bLen, big.NewInt(int64(i)))
Expand All @@ -948,7 +948,7 @@ func TestAddKeysWithEmptyValues(t *testing.T) {

database2, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree2, err := NewTree(database2, 100, HashFunctionPoseidon)
tree2, err := NewTree(database2, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree2.db.Close() //nolint:errcheck
tree2.dbgInit()
Expand All @@ -962,7 +962,7 @@ func TestAddKeysWithEmptyValues(t *testing.T) {
// use tree3 to add nil value array
database3, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
tree3, err := NewTree(database3, 100, HashFunctionPoseidon)
tree3, err := NewTree(database3, 256, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)
defer tree3.db.Close() //nolint:errcheck

Expand Down
3 changes: 1 addition & 2 deletions circomproofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@ func TestCircomVerifierProof(t *testing.T) {
c.Assert(err, qt.IsNil)
defer tree.db.Close() //nolint:errcheck

bLen := tree.HashFunction().Len()

testVector := [][]int64{
{1, 11},
{2, 22},
{3, 33},
{4, 44},
}
bLen := 1
for i := 0; i < len(testVector); i++ {
k := BigIntToBytes(bLen, big.NewInt(testVector[i][0]))
v := BigIntToBytes(bLen, big.NewInt(testVector[i][1]))
Expand Down
3 changes: 1 addition & 2 deletions testvectors/circom/go-data-generator/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ func TestGenerator(t *testing.T) {
tree, err := arbo.NewTree(database, 4, arbo.HashFunctionPoseidon)
c.Assert(err, qt.IsNil)

bLen := tree.HashFunction().Len()

testVector := [][]int64{
{1, 11},
{2, 22},
{3, 33},
{4, 44},
}
bLen := 1
for i := 0; i < len(testVector); i++ {
k := arbo.BigIntToBytes(bLen, big.NewInt(testVector[i][0]))
v := arbo.BigIntToBytes(bLen, big.NewInt(testVector[i][1]))
Expand Down
50 changes: 39 additions & 11 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,10 @@ func (t *Tree) AddBatchWithTx(wTx db.WriteTx, keys, values [][]byte) ([]int, err
}

// store root (from the vt) to db
if err := wTx.Set(dbKeyRoot, vt.root.h); err != nil {
return nil, err
if vt.root != nil {
if err := wTx.Set(dbKeyRoot, vt.root.h); err != nil {
return nil, err
}
}

// update nLeafs
Expand Down Expand Up @@ -310,14 +312,34 @@ func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error {
return nil
}

func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) {
keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
// keyPathFromKey returns the keyPath and checks that the key is not bigger
// than maximum key length for the tree maxLevels size.
// This is because if the key bits length is bigger than the maxLevels of the
// tree, two different keys that their difference is at the end, will collision
// in the same leaf of the tree (at the max depth).
func keyPathFromKey(maxLevels int, k []byte) ([]byte, error) {
maxKeyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd
if len(k) > maxKeyLen {
return nil, fmt.Errorf("len(k) can not be bigger than ceil(maxLevels/8), where"+
" len(k): %d, maxLevels: %d, max key len=ceil(maxLevels/8): %d. Might need"+
" a bigger tree depth (maxLevels>=%d) in order to input keys of length %d",
len(k), maxLevels, maxKeyLen, len(k)*8, len(k)) //nolint:gomnd
}
keyPath := make([]byte, maxKeyLen) //nolint:gomnd
copy(keyPath[:], k)
return keyPath, nil
}

func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) {
keyPath, err := keyPathFromKey(t.maxLevels, k)
if err != nil {
return nil, err
}
path := getPath(t.maxLevels, keyPath)

// go down to the leaf
var siblings [][]byte
_, _, siblings, err := t.down(wTx, k, root, siblings, path, fromLvl, false)
_, _, siblings, err = t.down(wTx, k, root, siblings, path, fromLvl, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -590,8 +612,10 @@ func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error {
return ErrSnapshotNotEditable
}

keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
copy(keyPath[:], k)
keyPath, err := keyPathFromKey(t.maxLevels, k)
if err != nil {
return err
}
path := getPath(t.maxLevels, keyPath)

root, err := t.RootWithTx(wTx)
Expand Down Expand Up @@ -647,8 +671,10 @@ func (t *Tree) GenProof(k []byte) ([]byte, []byte, []byte, bool, error) {
// GenProofWithTx does the same than the GenProof method, but allowing to pass
// the db.ReadTx that is used.
func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, bool, error) {
keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
copy(keyPath[:], k)
keyPath, err := keyPathFromKey(t.maxLevels, k)
if err != nil {
return nil, nil, nil, false, err
}
path := getPath(t.maxLevels, keyPath)

root, err := t.RootWithTx(rTx)
Expand Down Expand Up @@ -782,8 +808,10 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
// ErrKeyNotFound, and in the leafK & leafV parameters will be placed the data
// found in the tree in the leaf that was on the path going to the input key.
func (t *Tree) GetWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, error) {
keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
copy(keyPath[:], k)
keyPath, err := keyPathFromKey(t.maxLevels, k)
if err != nil {
return nil, nil, err
}
path := getPath(t.maxLevels, keyPath)

root, err := t.RootWithTx(rTx)
Expand Down
Loading

0 comments on commit 30d8b42

Please sign in to comment.