Skip to content

Commit

Permalink
Update keyPath to be ceil(maxLevels/8)
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaucube committed Oct 1, 2021
1 parent 9eb7c8e commit 0921cac
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 23 deletions.
30 changes: 8 additions & 22 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,10 @@ func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error {
}

func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) {
keyPath := make([]byte, t.hashFunction.Len())
// if len(k) > t.hashFunction.Len() { // WIP
// return nil, fmt.Errorf("len(k) > hashFunction.Len()")
// }
keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
copy(keyPath[:], k)

path := getPath(t.maxLevels, keyPath)

// go down to the leaf
var siblings [][]byte
_, _, siblings, err := t.down(wTx, k, root, siblings, path, fromLvl, false)
Expand Down Expand Up @@ -593,12 +590,7 @@ func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error {
return ErrSnapshotNotEditable
}

var err error

keyPath := make([]byte, t.hashFunction.Len())
// if len(k) > t.hashFunction.Len() { // WIP
// return fmt.Errorf("len(k) > hashFunction.Len()")
// }
keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
copy(keyPath[:], k)
path := getPath(t.maxLevels, keyPath)

Expand Down Expand Up @@ -655,18 +647,15 @@ func (t *Tree) GenProof(k []byte) ([]byte, []byte, []byte, bool, error) {
// GenProofWithTx does the same than the GenProof method, but allowing to pass
// the db.ReadTx that is used.
func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, bool, error) {
keyPath := make([]byte, t.hashFunction.Len())
// if len(k) > t.hashFunction.Len() { // WIP
// return nil, nil, nil, false, fmt.Errorf("len(k) > hashFunction.Len()")
// }
keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
copy(keyPath[:], k)
path := getPath(t.maxLevels, keyPath)

root, err := t.RootWithTx(rTx)
if err != nil {
return nil, nil, nil, false, err
}

path := getPath(t.maxLevels, keyPath)
// go down to the leaf
var siblings [][]byte
_, value, siblings, err := t.down(rTx, k, root, siblings, path, 0, true)
Expand Down Expand Up @@ -793,18 +782,15 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
// ErrKeyNotFound, and in the leafK & leafV parameters will be placed the data
// found in the tree in the leaf that was on the path going to the input key.
func (t *Tree) GetWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, error) {
keyPath := make([]byte, t.hashFunction.Len())
// if len(k) > t.hashFunction.Len() { // WIP
// return nil, nil, fmt.Errorf("len(k) > hashFunction.Len()")
// }
keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
copy(keyPath[:], k)
path := getPath(t.maxLevels, keyPath)

root, err := t.RootWithTx(rTx)
if err != nil {
return nil, nil, err
}

path := getPath(t.maxLevels, keyPath)
// go down to the leaf
var siblings [][]byte
_, value, _, err := t.down(rTx, k, root, siblings, path, 0, true)
Expand All @@ -827,7 +813,7 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool,
return false, err
}

keyPath := make([]byte, hashFunc.Len())
keyPath := make([]byte, len(siblings))
copy(keyPath[:], k)

key, _, err := newLeafValue(hashFunc, k, v)
Expand Down
34 changes: 34 additions & 0 deletions tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,40 @@ func TestGetFromSnapshotExpectArboErrKeyNotFound(t *testing.T) {
c.Assert(err, qt.Equals, ErrKeyNotFound) // and not equal to db.ErrKeyNotFound
}

func TestKeyLen(t *testing.T) {
c := qt.New(t)
database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
// maxLevels is 100, keyPath length = ceil(maxLevels/8) = 13
maxLevels := 100
tree, err := NewTree(database, maxLevels, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)

// expect no errors when adding a key of only 4 bytes (when the
// required length of keyPath for 100 levels would be 13 bytes)
bLen := 4
k := BigIntToBytes(bLen, big.NewInt(1))
v := BigIntToBytes(bLen, big.NewInt(1))

err = tree.Add(k, v)
c.Assert(err, qt.IsNil)

err = tree.Update(k, v)
c.Assert(err, qt.IsNil)

_, _, _, _, err = tree.GenProof(k)
c.Assert(err, qt.IsNil)

_, _, err = tree.Get(k)
c.Assert(err, qt.IsNil)

k = BigIntToBytes(bLen, big.NewInt(2))
v = BigIntToBytes(bLen, big.NewInt(2))
invalids, err := tree.AddBatch([][]byte{k}, [][]byte{v})
c.Assert(err, qt.IsNil)
c.Assert(len(invalids), qt.Equals, 0)
}

func BenchmarkAdd(b *testing.B) {
bLen := 32 // for both Poseidon & Sha256
// prepare inputs
Expand Down
2 changes: 1 addition & 1 deletion vt.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) {
}
kvs := make([]kv, len(ks))
for i := 0; i < len(ks); i++ {
keyPath := make([]byte, p.hashFunction.Len())
keyPath := make([]byte, int(math.Ceil(float64(p.maxLevels)/float64(8)))) //nolint:gomnd
copy(keyPath[:], ks[i])
kvs[i].pos = i
kvs[i].keyPath = keyPath
Expand Down

0 comments on commit 0921cac

Please sign in to comment.