diff --git a/blockchain/encoder_initializer.go b/blockchain/encoder_initializer.go index fbc82ab4d8..93efac7a82 100644 --- a/blockchain/encoder_initializer.go +++ b/blockchain/encoder_initializer.go @@ -20,6 +20,7 @@ func RegisterCoreTypesToEncoder() { reflect.TypeOf(core.DeployAccountTransaction{}), reflect.TypeOf(core.Cairo0Class{}), reflect.TypeOf(core.Cairo1Class{}), + reflect.TypeOf(core.StateContract{}), } for _, t := range types { diff --git a/clients/feeder/feeder.go b/clients/feeder/feeder.go index 82e99390ad..010adee3bc 100644 --- a/clients/feeder/feeder.go +++ b/clients/feeder/feeder.go @@ -92,7 +92,7 @@ func NopBackoff(d time.Duration) time.Duration { } // NewTestClient returns a client and a function to close a test server. -func NewTestClient(t *testing.T, network *utils.Network) *Client { +func NewTestClient(t testing.TB, network *utils.Network) *Client { srv := newTestServer(t, network) t.Cleanup(srv.Close) ua := "Juno/v0.0.1-test Starknet Implementation" @@ -117,7 +117,7 @@ func NewTestClient(t *testing.T, network *utils.Network) *Client { return c } -func newTestServer(t *testing.T, network *utils.Network) *httptest.Server { +func newTestServer(t testing.TB, network *utils.Network) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { queryMap, err := url.ParseQuery(r.URL.RawQuery) if err != nil { diff --git a/cmd/juno/dbcmd.go b/cmd/juno/dbcmd.go index 4fe5cd3a81..5c79dea982 100644 --- a/cmd/juno/dbcmd.go +++ b/cmd/juno/dbcmd.go @@ -206,7 +206,7 @@ func dbSize(cmd *cobra.Command, args []string) error { totalSize += bucketItem.Size totalCount += bucketItem.Count - if utils.AnyOf(b, db.StateTrie, db.ContractStorage, db.Class, db.ContractNonce, db.ContractDeploymentHeight) { + if utils.AnyOf(b, db.StateTrie, db.ContractStorage, db.Class, db.Contract) { withoutHistorySize += bucketItem.Size withHistorySize += bucketItem.Size diff --git a/core/contract.go b/core/contract.go index 2af1fd8c4c..84378c6b51 100644 --- a/core/contract.go +++ b/core/contract.go @@ -7,6 +7,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/encoder" ) // contract storage has fixed height at 251 @@ -17,187 +18,211 @@ var ( ErrContractAlreadyDeployed = errors.New("contract already deployed") ) -// NewContractUpdater creates an updater for the contract instance at the given address. -// Deploy should be called for contracts that were just deployed to the network. -func NewContractUpdater(addr *felt.Felt, txn db.Transaction) (*ContractUpdater, error) { - contractDeployed, err := deployed(addr, txn) - if err != nil { - return nil, err - } +type OnValueChanged = func(location, oldValue *felt.Felt) error - if !contractDeployed { - return nil, ErrContractNotDeployed - } +// StateContract represents a contract instance. +// The usage of a `StateContract` is as follows: +// 1. Create or obtain `StateContract` instance from the database. +// 2. Update the contract fields +// 3. Commit the contract to the database +type StateContract struct { + // ClassHash is the hash of the contract's class + ClassHash *felt.Felt + // Nonce is the contract's nonce + Nonce *felt.Felt + // DeployHeight is the height at which the contract is deployed + DeployHeight uint64 + // Address that this contract instance is deployed to + Address *felt.Felt `cbor:"-"` + // dirtyStorage is a map of storage locations that have been updated + dirtyStorage map[felt.Felt]*felt.Felt `cbor:"-"` +} - return &ContractUpdater{ - Address: addr, - txn: txn, - }, nil +// NewStateContract creates a new contract instance. +func NewStateContract( + addr *felt.Felt, + classHash *felt.Felt, + nonce *felt.Felt, + deployHeight uint64, +) *StateContract { + sc := &StateContract{ + Address: addr, + ClassHash: classHash, + Nonce: nonce, + DeployHeight: deployHeight, + dirtyStorage: make(map[felt.Felt]*felt.Felt), + } + + return sc } -// DeployContract sets up the database for a new contract. -func DeployContract(addr, classHash *felt.Felt, txn db.Transaction) (*ContractUpdater, error) { - contractDeployed, err := deployed(addr, txn) +// StorageRoot returns the root of the contract's storage trie. +func (c *StateContract) StorageRoot(txn db.Transaction) (*felt.Felt, error) { + storageTrie, err := storage(c.Address, txn) if err != nil { return nil, err } - if contractDeployed { - return nil, ErrContractAlreadyDeployed - } + return storageTrie.Root() +} - err = setClassHash(txn, addr, classHash) - if err != nil { - return nil, err +// UpdateStorage updates the storage of a contract. +// Note that this does not modify the storage trie, which must be committed separately. +func (c *StateContract) UpdateStorage(key, value *felt.Felt) { + if c.dirtyStorage == nil { + c.dirtyStorage = make(map[felt.Felt]*felt.Felt) } - c, err := NewContractUpdater(addr, txn) - if err != nil { - return nil, err + c.dirtyStorage[*key] = value +} + +// GetStorage retrieves the value of a storage location from the contract's storage +func (c *StateContract) GetStorage(key *felt.Felt, txn db.Transaction) (*felt.Felt, error) { + if c.dirtyStorage != nil { + if val, ok := c.dirtyStorage[*key]; ok { + return val, nil + } } - err = c.UpdateNonce(&felt.Zero) + // get from db + storage, err := storage(c.Address, txn) if err != nil { return nil, err } - return c, nil + return storage.Get(key) } -// ContractAddress computes the address of a Starknet contract. -func ContractAddress(callerAddress, classHash, salt *felt.Felt, constructorCallData []*felt.Felt) *felt.Felt { - prefix := new(felt.Felt).SetBytes([]byte("STARKNET_CONTRACT_ADDRESS")) - callDataHash := crypto.PedersenArray(constructorCallData...) - - // https://docs.starknet.io/architecture-and-concepts/smart-contracts/contract-address/ - return crypto.PedersenArray( - prefix, - callerAddress, - salt, - classHash, - callDataHash, - ) +// logOldValue is a helper function to record the history of a contract's value +func (c *StateContract) logOldValue(key []byte, oldValue *felt.Felt, height uint64, txn db.Transaction) error { + return txn.Set(logDBKey(key, height), oldValue.Marshal()) } -func deployed(addr *felt.Felt, txn db.Transaction) (bool, error) { - _, err := ContractClassHash(addr, txn) - if errors.Is(err, db.ErrKeyNotFound) { - return false, nil - } - if err != nil { - return false, err - } - return true, nil +// LogStorage records the history of the contract's storage +func (c *StateContract) LogStorage(location, oldVal *felt.Felt, height uint64, txn db.Transaction) error { + key := storageLogKey(c.Address, location) + return c.logOldValue(key, oldVal, height, txn) } -// ContractUpdater is a helper to update an existing contract instance. -type ContractUpdater struct { - // Address that this contract instance is deployed to - Address *felt.Felt - // txn to access the database - txn db.Transaction +// LogNonce records the history of the contract's nonce +func (c *StateContract) LogNonce(height uint64, txn db.Transaction) error { + key := nonceLogKey(c.Address) + return c.logOldValue(key, c.Nonce, height, txn) } -// Purge eliminates the contract instance, deleting all associated data from storage -// assumes storage is cleared in revert process -func (c *ContractUpdater) Purge() error { - addrBytes := c.Address.Marshal() - buckets := []db.Bucket{db.ContractNonce, db.ContractClassHash} - - for _, bucket := range buckets { - if err := c.txn.Delete(bucket.Key(addrBytes)); err != nil { - return err - } - } - - return nil +// LogClassHash records the history of the contract's class hash +func (c *StateContract) LogClassHash(height uint64, txn db.Transaction) error { + key := classHashLogKey(c.Address) + return c.logOldValue(key, c.ClassHash, height, txn) } -// ContractNonce returns the amount transactions sent from this contract. -// Only account contracts can have a non-zero nonce. -func ContractNonce(addr *felt.Felt, txn db.Transaction) (*felt.Felt, error) { - key := db.ContractNonce.Key(addr.Marshal()) - var nonce *felt.Felt - if err := txn.Get(key, func(val []byte) error { - nonce = new(felt.Felt) - nonce.SetBytes(val) - return nil - }); err != nil { - return nil, err - } - return nonce, nil -} +// BufferedCommit creates a buffered transaction and commits the contract to the database +func (c *StateContract) BufferedCommit(txn db.Transaction, logChanges bool, blockNum uint64) (*db.BufferedTransaction, error) { + bufferedTxn := db.NewBufferedTransaction(txn) -// UpdateNonce updates the nonce value in the database. -func (c *ContractUpdater) UpdateNonce(nonce *felt.Felt) error { - nonceKey := db.ContractNonce.Key(c.Address.Marshal()) - return c.txn.Set(nonceKey, nonce.Marshal()) -} - -// ContractRoot returns the root of the contract storage. -func ContractRoot(addr *felt.Felt, txn db.Transaction) (*felt.Felt, error) { - cStorage, err := storage(addr, txn) - if err != nil { + if err := c.Commit(bufferedTxn, logChanges, blockNum); err != nil { return nil, err } - return cStorage.Root() -} -type OnValueChanged = func(location, oldValue *felt.Felt) error + return bufferedTxn, nil +} -// UpdateStorage applies a change-set to the contract storage. -func (c *ContractUpdater) UpdateStorage(diff map[felt.Felt]*felt.Felt, cb OnValueChanged) error { - cStorage, err := storage(c.Address, c.txn) +func (c *StateContract) Commit(txn db.Transaction, logChanges bool, blockNum uint64) error { + storageTrie, err := storage(c.Address, txn) if err != nil { return err } - // apply the diff - for key, value := range diff { - oldValue, pErr := cStorage.Put(&key, value) - if pErr != nil { - return pErr + + for key, value := range c.dirtyStorage { + oldVal, err := storageTrie.Put(&key, value) + if err != nil { + return err } - if oldValue != nil { - if err = cb(&key, oldValue); err != nil { + if oldVal != nil && logChanges { + if err = c.LogStorage(&key, oldVal, blockNum, txn); err != nil { return err } } } - return cStorage.Commit() + if err := storageTrie.Commit(); err != nil { + return err + } + + contractBytes, err := encoder.Marshal(c) + if err != nil { + return err + } + + return txn.Set(db.Contract.Key(c.Address.Marshal()), contractBytes) } -func ContractStorage(addr, key *felt.Felt, txn db.Transaction) (*felt.Felt, error) { - cStorage, err := storage(addr, txn) +// Purge eliminates the contract instance, deleting all associated data from database +// assumes storage is cleared in revert process +func (c *StateContract) Purge(txn db.Transaction) error { + addrBytes := c.Address.Marshal() + + return txn.Delete(db.Contract.Key(addrBytes)) +} + +func storageLogKey(contractAddress, storageLocation *felt.Felt) []byte { + return db.ContractStorageHistory.Key(contractAddress.Marshal(), storageLocation.Marshal()) +} + +func nonceLogKey(contractAddress *felt.Felt) []byte { + return db.ContractNonceHistory.Key(contractAddress.Marshal()) +} + +func classHashLogKey(contractAddress *felt.Felt) []byte { + return db.ContractClassHashHistory.Key(contractAddress.Marshal()) +} + +// GetContract is a wrapper around getContract which checks if a contract is deployed +func GetContract(addr *felt.Felt, txn db.Transaction) (*StateContract, error) { + contract, err := getContract(addr, txn) if err != nil { + if errors.Is(err, db.ErrKeyNotFound) { + return nil, ErrContractNotDeployed + } return nil, err } - return cStorage.Get(key) + + return contract, nil } -// ContractClassHash returns hash of the class that the contract at the given address instantiates. -func ContractClassHash(addr *felt.Felt, txn db.Transaction) (*felt.Felt, error) { - key := db.ContractClassHash.Key(addr.Marshal()) - var classHash *felt.Felt +// getContract gets a contract instance from the database. +func getContract(addr *felt.Felt, txn db.Transaction) (*StateContract, error) { + key := db.Contract.Key(addr.Marshal()) + var contract StateContract if err := txn.Get(key, func(val []byte) error { - classHash = new(felt.Felt) - classHash.SetBytes(val) + if err := encoder.Unmarshal(val, &contract); err != nil { + return err + } + + contract.Address = addr + contract.dirtyStorage = make(map[felt.Felt]*felt.Felt) + return nil }); err != nil { return nil, err } - return classHash, nil + return &contract, nil } -func setClassHash(txn db.Transaction, addr, classHash *felt.Felt) error { - classHashKey := db.ContractClassHash.Key(addr.Marshal()) - return txn.Set(classHashKey, classHash.Marshal()) -} +// ContractAddress computes the address of a Starknet contract. +func ContractAddress(callerAddress, classHash, salt *felt.Felt, constructorCallData []*felt.Felt) *felt.Felt { + prefix := new(felt.Felt).SetBytes([]byte("STARKNET_CONTRACT_ADDRESS")) + callDataHash := crypto.PedersenArray(constructorCallData...) -// Replace replaces the class that the contract instantiates -func (c *ContractUpdater) Replace(classHash *felt.Felt) error { - return setClassHash(c.txn, c.Address, classHash) + // https://docs.starknet.io/architecture-and-concepts/smart-contracts/contract-address/ + return crypto.PedersenArray( + prefix, + callerAddress, + salt, + classHash, + callDataHash, + ) } // storage returns the [core.Trie] that represents the diff --git a/core/contract_test.go b/core/contract_test.go index 8ace83ba7e..70a4608554 100644 --- a/core/contract_test.go +++ b/core/contract_test.go @@ -11,10 +11,6 @@ import ( "github.com/stretchr/testify/require" ) -var NoopOnValueChanged = func(location, oldValue *felt.Felt) error { - return nil -} - func TestContractAddress(t *testing.T) { tests := []struct { callerAddress *felt.Felt @@ -59,124 +55,128 @@ func TestNewContract(t *testing.T) { t.Cleanup(func() { require.NoError(t, txn.Discard()) }) + + blockNumber := uint64(10) addr := new(felt.Felt).SetUint64(234) classHash := new(felt.Felt).SetBytes([]byte("class hash")) - t.Run("cannot create Contract instance if un-deployed", func(t *testing.T) { - _, err = core.NewContractUpdater(addr, txn) - require.EqualError(t, err, core.ErrContractNotDeployed.Error()) + t.Run("cannot get contract if un-deployed", func(t *testing.T) { + _, err = core.GetContract(addr, txn) + require.ErrorIs(t, err, core.ErrContractNotDeployed) }) - contract, err := core.DeployContract(addr, classHash, txn) - require.NoError(t, err) + var contract *core.StateContract + t.Run("commit contract", func(t *testing.T) { + contract = core.NewStateContract(addr, classHash, &felt.Zero, blockNumber) + require.NoError(t, contract.Commit(txn, true, blockNumber)) + }) - t.Run("redeploy should fail", func(t *testing.T) { - _, err := core.DeployContract(addr, classHash, txn) - require.EqualError(t, err, core.ErrContractAlreadyDeployed.Error()) + t.Run("get contract from db", func(t *testing.T) { + contract, err = core.GetContract(addr, txn) + require.NoError(t, err) }) - t.Run("a call to contract should fail with a committed txn", func(t *testing.T) { - assert.NoError(t, txn.Commit()) - t.Run("ClassHash()", func(t *testing.T) { - _, err := core.ContractClassHash(addr, txn) - assert.Error(t, err) - }) - t.Run("Root()", func(t *testing.T) { - _, err := core.ContractRoot(addr, txn) - assert.Error(t, err) - }) - t.Run("Nonce()", func(t *testing.T) { - _, err := core.ContractNonce(addr, txn) - assert.Error(t, err) - }) - t.Run("Storage()", func(t *testing.T) { - _, err := core.ContractStorage(addr, classHash, txn) - assert.Error(t, err) - }) - t.Run("UpdateNonce()", func(t *testing.T) { - assert.Error(t, contract.UpdateNonce(&felt.Zero)) - }) - t.Run("UpdateStorage()", func(t *testing.T) { - assert.Error(t, contract.UpdateStorage(nil, NoopOnValueChanged)) - }) + t.Run("check contract fields", func(t *testing.T) { + assert.Equal(t, addr, contract.Address) + assert.Equal(t, classHash, contract.ClassHash) + assert.Equal(t, &felt.Zero, contract.Nonce) + assert.Equal(t, blockNumber, contract.DeployHeight) }) } -func TestNonceAndClassHash(t *testing.T) { +func TestUpdateContract(t *testing.T) { testDB := pebble.NewMemTest(t) txn, err := testDB.NewTransaction(true) require.NoError(t, err) + blockNumber := uint64(10) addr := new(felt.Felt).SetUint64(44) classHash := new(felt.Felt).SetUint64(37) - contract, err := core.DeployContract(addr, classHash, txn) + contract := core.NewStateContract(addr, classHash, &felt.Zero, blockNumber) + require.NoError(t, contract.Commit(txn, true, blockNumber)) + + contract, err = core.GetContract(addr, txn) require.NoError(t, err) - t.Run("initial nonce should be 0", func(t *testing.T) { - got, err := core.ContractNonce(addr, txn) - require.NoError(t, err) - assert.Equal(t, new(felt.Felt), got) + t.Run("verify initial nonce", func(t *testing.T) { + require.Equal(t, &felt.Zero, contract.Nonce) }) - t.Run("UpdateNonce()", func(t *testing.T) { - require.NoError(t, contract.UpdateNonce(classHash)) - got, err := core.ContractNonce(addr, txn) + t.Run("update contract nonce", func(t *testing.T) { + newNonce := new(felt.Felt).SetUint64(1) + contract.Nonce = newNonce + require.NoError(t, contract.Commit(txn, true, blockNumber)) + + contract, err = core.GetContract(addr, txn) require.NoError(t, err) - assert.Equal(t, classHash, got) + + require.Equal(t, newNonce, contract.Nonce) }) - t.Run("ClassHash()", func(t *testing.T) { - got, err := core.ContractClassHash(addr, txn) - require.NoError(t, err) - assert.Equal(t, classHash, got) + t.Run("verify initial class hash", func(t *testing.T) { + require.Equal(t, classHash, contract.ClassHash) }) - t.Run("Replace()", func(t *testing.T) { - replaceWith := utils.HexToFelt(t, "0xDEADBEEF") - require.NoError(t, contract.Replace(replaceWith)) - got, err := core.ContractClassHash(addr, txn) + t.Run("update class hash", func(t *testing.T) { + newHash := new(felt.Felt).SetUint64(1) + contract.ClassHash = newHash + require.NoError(t, contract.Commit(txn, true, blockNumber)) + + contract, err = core.GetContract(addr, txn) require.NoError(t, err) - assert.Equal(t, replaceWith, got) + + require.Equal(t, newHash, contract.ClassHash) }) } -func TestUpdateStorageAndStorage(t *testing.T) { +func TestContractStorage(t *testing.T) { testDB := pebble.NewMemTest(t) txn, err := testDB.NewTransaction(true) require.NoError(t, err) + blockNumber := uint64(10) addr := new(felt.Felt).SetUint64(44) classHash := new(felt.Felt).SetUint64(37) - contract, err := core.DeployContract(addr, classHash, txn) - require.NoError(t, err) + contract := core.NewStateContract(addr, classHash, &felt.Zero, blockNumber) + require.NoError(t, contract.Commit(txn, true, blockNumber)) + + t.Run("get initial storage", func(t *testing.T) { + gotValue, err := contract.GetStorage(addr, txn) + require.NoError(t, err) + assert.Equal(t, &felt.Zero, gotValue) + }) t.Run("apply storage diff", func(t *testing.T) { - oldRoot, err := core.ContractRoot(addr, txn) + oldRoot, err := contract.StorageRoot(txn) require.NoError(t, err) - require.NoError(t, contract.UpdateStorage(map[felt.Felt]*felt.Felt{*addr: classHash}, NoopOnValueChanged)) + contract.UpdateStorage(addr, classHash) + require.NoError(t, contract.Commit(txn, false, blockNumber)) - gotValue, err := core.ContractStorage(addr, addr, txn) + contract, err = core.GetContract(addr, txn) + require.NoError(t, err) + + gotValue, err := contract.GetStorage(addr, txn) require.NoError(t, err) assert.Equal(t, classHash, gotValue) - newRoot, err := core.ContractRoot(addr, txn) + newRoot, err := contract.StorageRoot(txn) require.NoError(t, err) assert.NotEqual(t, oldRoot, newRoot) }) t.Run("delete key from storage with storage diff", func(t *testing.T) { - require.NoError(t, contract.UpdateStorage(map[felt.Felt]*felt.Felt{*addr: new(felt.Felt)}, NoopOnValueChanged)) + contract.UpdateStorage(addr, new(felt.Felt)) + require.NoError(t, contract.Commit(txn, false, blockNumber)) - val, err := core.ContractStorage(addr, addr, txn) + contract, err = core.GetContract(addr, txn) require.NoError(t, err) - require.Equal(t, &felt.Zero, val) - sRoot, err := core.ContractRoot(addr, txn) + gotValue, err := contract.GetStorage(addr, txn) require.NoError(t, err) - assert.Equal(t, new(felt.Felt), sRoot) + assert.Equal(t, &felt.Zero, gotValue) }) } @@ -185,13 +185,14 @@ func TestPurge(t *testing.T) { txn, err := testDB.NewTransaction(true) require.NoError(t, err) + blockNumber := uint64(10) addr := new(felt.Felt).SetUint64(44) classHash := new(felt.Felt).SetUint64(37) - contract, err := core.DeployContract(addr, classHash, txn) - require.NoError(t, err) + contract := core.NewStateContract(addr, classHash, &felt.Zero, blockNumber) + require.NoError(t, contract.Commit(txn, false, blockNumber)) - require.NoError(t, contract.Purge()) - _, err = core.NewContractUpdater(addr, txn) + require.NoError(t, contract.Purge(txn)) + _, err = core.GetContract(addr, txn) assert.ErrorIs(t, err, core.ErrContractNotDeployed) } diff --git a/core/history.go b/core/history.go deleted file mode 100644 index f14db1702e..0000000000 --- a/core/history.go +++ /dev/null @@ -1,134 +0,0 @@ -package core - -import ( - "bytes" - "encoding/binary" - "errors" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/utils" -) - -var ErrCheckHeadState = errors.New("check head state") - -type history struct { - txn db.Transaction -} - -func logDBKey(key []byte, height uint64) []byte { - return binary.BigEndian.AppendUint64(key, height) -} - -func (h *history) logOldValue(key, value []byte, height uint64) error { - return h.txn.Set(logDBKey(key, height), value) -} - -func (h *history) deleteLog(key []byte, height uint64) error { - return h.txn.Delete(logDBKey(key, height)) -} - -func (h *history) valueAt(key []byte, height uint64) ([]byte, error) { - it, err := h.txn.NewIterator() - if err != nil { - return nil, err - } - - for it.Seek(logDBKey(key, height)); it.Valid(); it.Next() { - seekedKey := it.Key() - // seekedKey size should be `len(key) + sizeof(uint64)` and seekedKey should match key prefix - if len(seekedKey) != len(key)+8 || !bytes.HasPrefix(seekedKey, key) { - break - } - - seekedHeight := binary.BigEndian.Uint64(seekedKey[len(key):]) - if seekedHeight < height { - // last change happened before the height we are looking for - // check head state - break - } else if seekedHeight == height { - // a log exists for the height we are looking for, so the old value in this log entry is not useful. - // advance the iterator and see we can use the next entry. If not, ErrCheckHeadState will be returned - continue - } - - val, itErr := it.Value() - if err = utils.RunAndWrapOnError(it.Close, itErr); err != nil { - return nil, err - } - // seekedHeight > height - return val, nil - } - - return nil, utils.RunAndWrapOnError(it.Close, ErrCheckHeadState) -} - -func storageLogKey(contractAddress, storageLocation *felt.Felt) []byte { - return db.ContractStorageHistory.Key(contractAddress.Marshal(), storageLocation.Marshal()) -} - -// LogContractStorage logs the old value of a storage location for the given contract which changed on height `height` -func (h *history) LogContractStorage(contractAddress, storageLocation, oldValue *felt.Felt, height uint64) error { - key := storageLogKey(contractAddress, storageLocation) - return h.logOldValue(key, oldValue.Marshal(), height) -} - -// DeleteContractStorageLog deletes the log at the given height -func (h *history) DeleteContractStorageLog(contractAddress, storageLocation *felt.Felt, height uint64) error { - return h.deleteLog(storageLogKey(contractAddress, storageLocation), height) -} - -// ContractStorageAt returns the value of a storage location of the given contract at the height `height` -func (h *history) ContractStorageAt(contractAddress, storageLocation *felt.Felt, height uint64) (*felt.Felt, error) { - key := storageLogKey(contractAddress, storageLocation) - value, err := h.valueAt(key, height) - if err != nil { - return nil, err - } - - return new(felt.Felt).SetBytes(value), nil -} - -func nonceLogKey(contractAddress *felt.Felt) []byte { - return db.ContractNonceHistory.Key(contractAddress.Marshal()) -} - -func (h *history) LogContractNonce(contractAddress, oldValue *felt.Felt, height uint64) error { - return h.logOldValue(nonceLogKey(contractAddress), oldValue.Marshal(), height) -} - -func (h *history) DeleteContractNonceLog(contractAddress *felt.Felt, height uint64) error { - return h.deleteLog(nonceLogKey(contractAddress), height) -} - -func (h *history) ContractNonceAt(contractAddress *felt.Felt, height uint64) (*felt.Felt, error) { - key := nonceLogKey(contractAddress) - value, err := h.valueAt(key, height) - if err != nil { - return nil, err - } - - return new(felt.Felt).SetBytes(value), nil -} - -func classHashLogKey(contractAddress *felt.Felt) []byte { - return db.ContractClassHashHistory.Key(contractAddress.Marshal()) -} - -func (h *history) LogContractClassHash(contractAddress, oldValue *felt.Felt, height uint64) error { - return h.logOldValue(classHashLogKey(contractAddress), oldValue.Marshal(), height) -} - -func (h *history) DeleteContractClassHashLog(contractAddress *felt.Felt, height uint64) error { - return h.deleteLog(classHashLogKey(contractAddress), height) -} - -func (h *history) ContractClassHashAt(contractAddress *felt.Felt, height uint64) (*felt.Felt, error) { - key := classHashLogKey(contractAddress) - value, err := h.valueAt(key, height) - if err != nil { - return nil, err - } - - return new(felt.Felt).SetBytes(value), nil -} diff --git a/core/history_pkg_test.go b/core/history_pkg_test.go deleted file mode 100644 index b883f4c0ed..0000000000 --- a/core/history_pkg_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package core - -import ( - "testing" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db/pebble" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestHistory(t *testing.T) { - testDB := pebble.NewMemTest(t) - txn, err := testDB.NewTransaction(true) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, txn.Discard()) - }) - - history := &history{txn: txn} - contractAddress := new(felt.Felt).SetUint64(123) - - for desc, test := range map[string]struct { - logger func(location, oldValue *felt.Felt, height uint64) error - getter func(location *felt.Felt, height uint64) (*felt.Felt, error) - deleter func(location *felt.Felt, height uint64) error - }{ - "contract storage": { - logger: func(location, oldValue *felt.Felt, height uint64) error { - return history.LogContractStorage(contractAddress, location, oldValue, height) - }, - getter: func(location *felt.Felt, height uint64) (*felt.Felt, error) { - return history.ContractStorageAt(contractAddress, location, height) - }, - deleter: func(location *felt.Felt, height uint64) error { - return history.DeleteContractStorageLog(contractAddress, location, height) - }, - }, - "contract nonce": { - logger: history.LogContractNonce, - getter: history.ContractNonceAt, - deleter: history.DeleteContractNonceLog, - }, - "contract class hash": { - logger: history.LogContractClassHash, - getter: history.ContractClassHashAt, - deleter: history.DeleteContractClassHashLog, - }, - } { - location := new(felt.Felt).SetUint64(456) - - t.Run(desc, func(t *testing.T) { - t.Run("no history", func(t *testing.T) { - _, err := test.getter(location, 1) - assert.ErrorIs(t, err, ErrCheckHeadState) - }) - - value := new(felt.Felt).SetUint64(789) - - t.Run("log value changed at height 5 and 10", func(t *testing.T) { - assert.NoError(t, test.logger(location, &felt.Zero, 5)) - assert.NoError(t, test.logger(location, value, 10)) - }) - - t.Run("get value before height 5", func(t *testing.T) { - oldValue, err := test.getter(location, 1) - require.NoError(t, err) - assert.Equal(t, &felt.Zero, oldValue) - }) - - t.Run("get value between height 5-10 ", func(t *testing.T) { - oldValue, err := test.getter(location, 7) - require.NoError(t, err) - assert.Equal(t, value, oldValue) - }) - - t.Run("get value on height that change happened ", func(t *testing.T) { - oldValue, err := test.getter(location, 5) - require.NoError(t, err) - assert.Equal(t, value, oldValue) - - _, err = test.getter(location, 10) - assert.ErrorIs(t, err, ErrCheckHeadState) - }) - - t.Run("get value after height 10 ", func(t *testing.T) { - _, err := test.getter(location, 13) - assert.ErrorIs(t, err, ErrCheckHeadState) - }) - - t.Run("get a random location ", func(t *testing.T) { - _, err := test.getter(new(felt.Felt).SetUint64(37), 13) - assert.ErrorIs(t, err, ErrCheckHeadState) - }) - - require.NoError(t, test.deleter(location, 10)) - - t.Run("get after delete", func(t *testing.T) { - _, err := test.getter(location, 7) - assert.ErrorIs(t, err, ErrCheckHeadState) - }) - }) - } -} diff --git a/core/state.go b/core/state.go index effde8b518..5cbdcc5c45 100644 --- a/core/state.go +++ b/core/state.go @@ -15,14 +15,16 @@ import ( "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/encoder" + "github.com/NethermindEth/juno/utils" "github.com/sourcegraph/conc/pool" ) const globalTrieHeight = 251 var ( - stateVersion = new(felt.Felt).SetBytes([]byte(`STARKNET_STATE_V0`)) - leafVersion = new(felt.Felt).SetBytes([]byte(`CONTRACT_CLASS_LEAF_V0`)) + stateVersion = new(felt.Felt).SetBytes([]byte(`STARKNET_STATE_V0`)) + leafVersion = new(felt.Felt).SetBytes([]byte(`CONTRACT_CLASS_LEAF_V0`)) + ErrCheckHeadState = errors.New("check head state") ) var _ StateHistoryReader = (*State)(nil) @@ -45,46 +47,120 @@ type StateReader interface { } type State struct { - *history txn db.Transaction + + // This map holds the contract objects which are being updated in the current state update. + contracts map[felt.Felt]*StateContract } func NewState(txn db.Transaction) *State { return &State{ - history: &history{txn: txn}, - txn: txn, + txn: txn, + contracts: make(map[felt.Felt]*StateContract), } } -// putNewContract creates a contract storage instance in the state and stores the relation between contract address and class hash to be -// queried later with [GetContractClass]. -func (s *State) putNewContract(stateTrie *trie.Trie, addr, classHash *felt.Felt, blockNumber uint64) error { - contract, err := DeployContract(addr, classHash, s.txn) +// ContractClassHash returns class hash of a contract at a given address. +func (s *State) ContractClassHash(addr *felt.Felt) (*felt.Felt, error) { + contract, err := GetContract(addr, s.txn) if err != nil { - return err - } - - numBytes := MarshalBlockNumber(blockNumber) - if err = s.txn.Set(db.ContractDeploymentHeight.Key(addr.Marshal()), numBytes); err != nil { - return err + return nil, err } - return s.updateContractCommitment(stateTrie, contract) -} - -// ContractClassHash returns class hash of a contract at a given address. -func (s *State) ContractClassHash(addr *felt.Felt) (*felt.Felt, error) { - return ContractClassHash(addr, s.txn) + return contract.ClassHash, nil } // ContractNonce returns nonce of a contract at a given address. func (s *State) ContractNonce(addr *felt.Felt) (*felt.Felt, error) { - return ContractNonce(addr, s.txn) + contract, err := GetContract(addr, s.txn) + if err != nil { + return nil, err + } + + return contract.Nonce, nil } // ContractStorage returns value of a key in the storage of the contract at the given address. func (s *State) ContractStorage(addr, key *felt.Felt) (*felt.Felt, error) { - return ContractStorage(addr, key, s.txn) + contract, err := GetContract(addr, s.txn) + if err != nil { + return nil, err + } + + return contract.GetStorage(key, s.txn) +} + +func (s *State) ContractClassHashAt(addr *felt.Felt, blockNumber uint64) (*felt.Felt, error) { + return s.contractValueAt(classHashLogKey, addr, blockNumber) +} + +func (s *State) ContractStorageAt(addr, loc *felt.Felt, blockNumber uint64) (*felt.Felt, error) { + return s.contractValueAt(func(a *felt.Felt) []byte { return storageLogKey(a, loc) }, addr, blockNumber) +} + +func (s *State) ContractNonceAt(addr *felt.Felt, blockNumber uint64) (*felt.Felt, error) { + return s.contractValueAt(nonceLogKey, addr, blockNumber) +} + +func (s *State) deleteLog(key []byte, height uint64) error { + return s.txn.Delete(logDBKey(key, height)) +} + +func (s *State) DeleteContractStorageLog(contractAddress, storageLocation *felt.Felt, height uint64) error { + return s.deleteLog(storageLogKey(contractAddress, storageLocation), height) +} + +func (s *State) DeleteContractNonceLog(contractAddress *felt.Felt, height uint64) error { + return s.deleteLog(nonceLogKey(contractAddress), height) +} + +func (s *State) DeleteContractClassHashLog(contractAddress *felt.Felt, height uint64) error { + return s.deleteLog(classHashLogKey(contractAddress), height) +} + +func (s *State) contractValueAt(keyFunc func(*felt.Felt) []byte, addr *felt.Felt, blockNumber uint64) (*felt.Felt, error) { + key := keyFunc(addr) + value, err := s.valueAt(key, blockNumber) + if err != nil { + return nil, err + } + + return new(felt.Felt).SetBytes(value), nil +} + +func (s *State) valueAt(key []byte, height uint64) ([]byte, error) { + it, err := s.txn.NewIterator() + if err != nil { + return nil, err + } + + for it.Seek(logDBKey(key, height)); it.Valid(); it.Next() { + seekedKey := it.Key() + // seekedKey size should be `len(key) + sizeof(uint64)` and seekedKey should match key prefix + if len(seekedKey) != len(key)+8 || !bytes.HasPrefix(seekedKey, key) { + break + } + + seekedHeight := binary.BigEndian.Uint64(seekedKey[len(key):]) + if seekedHeight < height { + // last change happened before the height we are looking for + // check head state + break + } else if seekedHeight == height { + // a log exists for the height we are looking for, so the old value in this log entry is not useful. + // advance the iterator and see we can use the next entry. If not, ErrCheckHeadState will be returned + continue + } + + val, itErr := it.Value() + if err = utils.RunAndWrapOnError(it.Close, itErr); err != nil { + return nil, err + } + // seekedHeight > height + return val, nil + } + + return nil, utils.RunAndWrapOnError(it.Close, ErrCheckHeadState) } // Root returns the state commitment. @@ -222,15 +298,27 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses // register deployed contracts for addr, classHash := range update.StateDiff.DeployedContracts { - if err = s.putNewContract(stateTrie, &addr, classHash, blockNumber); err != nil { + // check if contract is already deployed + _, err := GetContract(&addr, s.txn) + if err == nil { + return ErrContractAlreadyDeployed + } + + if !errors.Is(err, ErrContractNotDeployed) { return err } + + s.contracts[addr] = NewStateContract(&addr, classHash, &felt.Zero, blockNumber) } - if err = s.updateContracts(stateTrie, blockNumber, update.StateDiff, true); err != nil { + if err = s.updateContracts(blockNumber, update.StateDiff, true); err != nil { return err } + if err = s.Commit(stateTrie, true, blockNumber); err != nil { + return fmt.Errorf("state commit: %v", err) + } + if err = storageCloser(); err != nil { return err } @@ -238,6 +326,58 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses return s.verifyStateUpdateRoot(update.NewRoot) } +func (s *State) GetReverseStateDiff(blockNumber uint64, diff *StateDiff) (*StateDiff, error) { + reversed := *diff + + // storage diffs + reversed.StorageDiffs = make(map[felt.Felt]map[felt.Felt]*felt.Felt, len(diff.StorageDiffs)) + for addr, storageDiffs := range diff.StorageDiffs { + reversedDiffs := make(map[felt.Felt]*felt.Felt, len(storageDiffs)) + for key := range storageDiffs { + value := &felt.Zero + if blockNumber > 0 { + oldValue, err := s.ContractStorageAt(&addr, &key, blockNumber-1) + if err != nil { + return nil, err + } + value = oldValue + } + reversedDiffs[key] = value + } + reversed.StorageDiffs[addr] = reversedDiffs + } + + // nonces + reversed.Nonces = make(map[felt.Felt]*felt.Felt, len(diff.Nonces)) + for addr := range diff.Nonces { + oldNonce := &felt.Zero + if blockNumber > 0 { + var err error + oldNonce, err = s.ContractNonceAt(&addr, blockNumber-1) + if err != nil { + return nil, err + } + } + reversed.Nonces[addr] = oldNonce + } + + // replaced + reversed.ReplacedClasses = make(map[felt.Felt]*felt.Felt, len(diff.ReplacedClasses)) + for addr := range diff.ReplacedClasses { + classHash := &felt.Zero + if blockNumber > 0 { + var err error + classHash, err = s.ContractClassHashAt(&addr, blockNumber-1) + if err != nil { + return nil, err + } + } + reversed.ReplacedClasses[addr] = classHash + } + + return &reversed, nil +} + var ( noClassContractsClassHash = new(felt.Felt).SetUint64(0) @@ -246,60 +386,152 @@ var ( } ) -func (s *State) updateContracts(stateTrie *trie.Trie, blockNumber uint64, diff *StateDiff, logChanges bool) error { - // replace contract instances - for addr, classHash := range diff.ReplacedClasses { - oldClassHash, err := s.replaceContract(stateTrie, &addr, classHash) +// Commit updates the state by committing the dirty contracts to the database. +func (s *State) Commit( + stateTrie *trie.Trie, + logChanges bool, + blockNumber uint64, +) error { + type bufferedTransactionWithAddress struct { + txn *db.BufferedTransaction + addr *felt.Felt + } + + // // sort the contracts in descending storage diff order + keys := slices.SortedStableFunc(maps.Keys(s.contracts), func(a, b felt.Felt) int { + return len(s.contracts[a].dirtyStorage) - len(s.contracts[b].dirtyStorage) + }) + + contractPools := pool.NewWithResults[*bufferedTransactionWithAddress]().WithErrors().WithMaxGoroutines(runtime.GOMAXPROCS(0)) + for _, addr := range keys { + contractPools.Go(func() (*bufferedTransactionWithAddress, error) { + txn, err := s.contracts[addr].BufferedCommit(s.txn, logChanges, blockNumber) + if err != nil { + return nil, err + } + + return &bufferedTransactionWithAddress{ + txn: txn, + addr: &addr, + }, nil + }) + } + + bufferedTxns, err := contractPools.Wait() + if err != nil { + return err + } + + // we sort bufferedTxns in ascending contract address order to achieve an additional speedup + sort.Slice(bufferedTxns, func(i, j int) bool { + return bufferedTxns[i].addr.Cmp(bufferedTxns[j].addr) < 0 + }) + + for _, bufferedTxn := range bufferedTxns { + if err := bufferedTxn.txn.Flush(); err != nil { + return err + } + } + + for _, contract := range s.contracts { + if err := s.updateContractCommitment(stateTrie, contract); err != nil { + return err + } + } + + // finally, clear the contracts map + s.contracts = make(map[felt.Felt]*StateContract) + + return nil +} + +func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges bool) error { + if err := s.updateContractClasses(blockNumber, diff.ReplacedClasses, logChanges); err != nil { + return err + } + + if err := s.updateContractNonces(blockNumber, diff.Nonces, logChanges); err != nil { + return err + } + + return s.updateContractStorages(blockNumber, diff.StorageDiffs) +} + +func (s *State) updateContractClasses( + blockNumber uint64, + replacedClasses map[felt.Felt]*felt.Felt, + logChanges bool, +) error { + for addr, classHash := range replacedClasses { + contract, err := s.getContract(addr) if err != nil { return err } if logChanges { - if err = s.LogContractClassHash(&addr, oldClassHash, blockNumber); err != nil { + if err := contract.LogClassHash(blockNumber, s.txn); err != nil { return err } } + + contract.ClassHash = classHash } + return nil +} - // update contract nonces - for addr, nonce := range diff.Nonces { - oldNonce, err := s.updateContractNonce(stateTrie, &addr, nonce) +func (s *State) updateContractNonces( + blockNumber uint64, + nonces map[felt.Felt]*felt.Felt, + logChanges bool, +) error { + for addr, nonce := range nonces { + contract, err := s.getContract(addr) if err != nil { return err } if logChanges { - if err = s.LogContractNonce(&addr, oldNonce, blockNumber); err != nil { + if err := contract.LogNonce(blockNumber, s.txn); err != nil { return err } } - } - - // update contract storages - return s.updateContractStorages(stateTrie, diff.StorageDiffs, blockNumber, logChanges) -} -// replaceContract replaces the class that a contract at a given address instantiates -func (s *State) replaceContract(stateTrie *trie.Trie, addr, classHash *felt.Felt) (*felt.Felt, error) { - contract, err := NewContractUpdater(addr, s.txn) - if err != nil { - return nil, err + contract.Nonce = nonce } + return nil +} - oldClassHash, err := ContractClassHash(addr, s.txn) - if err != nil { - return nil, err - } +func (s *State) updateContractStorages( + blockNumber uint64, + storageDiffs map[felt.Felt]map[felt.Felt]*felt.Felt, +) error { + for addr, diff := range storageDiffs { + contract, err := s.getContract(addr) + if err != nil { + if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) { + contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, blockNumber) + s.contracts[addr] = contract + } else { + return err + } + } - if err = contract.Replace(classHash); err != nil { - return nil, err + contract.dirtyStorage = diff } + return nil +} - if err = s.updateContractCommitment(stateTrie, contract); err != nil { - return nil, err +func (s *State) getContract(addr felt.Felt) (*StateContract, error) { + contract, ok := s.contracts[addr] + if !ok { + var err error + contract, err = GetContract(&addr, s.txn) + if err != nil { + return nil, err + } + s.contracts[addr] = contract } - - return oldClassHash, nil + return contract, nil } type DeclaredClass struct { @@ -342,150 +574,18 @@ func (s *State) Class(classHash *felt.Felt) (*DeclaredClass, error) { return &class, nil } -func (s *State) updateStorageBuffered(contractAddr *felt.Felt, updateDiff map[felt.Felt]*felt.Felt, blockNumber uint64, logChanges bool) ( - *db.BufferedTransaction, error, -) { - // to avoid multiple transactions writing to s.txn, create a buffered transaction and use that in the worker goroutine - bufferedTxn := db.NewBufferedTransaction(s.txn) - bufferedState := NewState(bufferedTxn) - bufferedContract, err := NewContractUpdater(contractAddr, bufferedTxn) - if err != nil { - return nil, err - } - - onValueChanged := func(location, oldValue *felt.Felt) error { - if logChanges { - return bufferedState.LogContractStorage(contractAddr, location, oldValue, blockNumber) - } - return nil - } - - if err = bufferedContract.UpdateStorage(updateDiff, onValueChanged); err != nil { - return nil, err - } - - return bufferedTxn, nil -} - -// updateContractStorage applies the diff set to the Trie of the -// contract at the given address in the given Txn context. -func (s *State) updateContractStorages(stateTrie *trie.Trie, diffs map[felt.Felt]map[felt.Felt]*felt.Felt, - blockNumber uint64, logChanges bool, -) error { - type bufferedTransactionWithAddress struct { - txn *db.BufferedTransaction - addr *felt.Felt - } - - // make sure all noClassContracts are deployed - for addr := range diffs { - if _, ok := noClassContracts[addr]; !ok { - continue - } - - _, err := NewContractUpdater(&addr, s.txn) - if err != nil { - if !errors.Is(err, ErrContractNotDeployed) { - return err - } - // Deploy noClassContract - err = s.putNewContract(stateTrie, &addr, noClassContractsClassHash, blockNumber) - if err != nil { - return err - } - } - } - - // sort the contracts in decending diff size order - // so we start with the heaviest update first - keys := slices.SortedStableFunc(maps.Keys(diffs), func(a, b felt.Felt) int { return len(diffs[a]) - len(diffs[b]) }) - - // update per-contract storage Tries concurrently - contractUpdaters := pool.NewWithResults[*bufferedTransactionWithAddress]().WithErrors().WithMaxGoroutines(runtime.GOMAXPROCS(0)) - for _, key := range keys { - contractAddr := key - contractUpdaters.Go(func() (*bufferedTransactionWithAddress, error) { - bufferedTxn, err := s.updateStorageBuffered(&contractAddr, diffs[contractAddr], blockNumber, logChanges) - if err != nil { - return nil, err - } - return &bufferedTransactionWithAddress{txn: bufferedTxn, addr: &contractAddr}, nil - }) - } - - bufferedTxns, err := contractUpdaters.Wait() - if err != nil { - return err - } - - // we sort bufferedTxns in ascending contract address order to achieve an additional speedup - sort.Slice(bufferedTxns, func(i, j int) bool { - return bufferedTxns[i].addr.Cmp(bufferedTxns[j].addr) < 0 - }) - - // flush buffered txns - for _, txnWithAddress := range bufferedTxns { - if err := txnWithAddress.txn.Flush(); err != nil { - return err - } - } - - for addr := range diffs { - contract, err := NewContractUpdater(&addr, s.txn) - if err != nil { - return err - } - - if err = s.updateContractCommitment(stateTrie, contract); err != nil { - return err - } - } - - return nil -} - -// updateContractNonce updates nonce of the contract at the -// given address in the given Txn context. -func (s *State) updateContractNonce(stateTrie *trie.Trie, addr, nonce *felt.Felt) (*felt.Felt, error) { - contract, err := NewContractUpdater(addr, s.txn) - if err != nil { - return nil, err - } - - oldNonce, err := ContractNonce(addr, s.txn) - if err != nil { - return nil, err - } - - if err = contract.UpdateNonce(nonce); err != nil { - return nil, err - } - - if err = s.updateContractCommitment(stateTrie, contract); err != nil { - return nil, err - } - - return oldNonce, nil -} - // updateContractCommitment recalculates the contract commitment and updates its value in the global state Trie -func (s *State) updateContractCommitment(stateTrie *trie.Trie, contract *ContractUpdater) error { - root, err := ContractRoot(contract.Address, s.txn) - if err != nil { - return err - } - - cHash, err := ContractClassHash(contract.Address, s.txn) - if err != nil { - return err - } - - nonce, err := ContractNonce(contract.Address, s.txn) +func (s *State) updateContractCommitment(stateTrie *trie.Trie, contract *StateContract) error { + rootKey, err := contract.StorageRoot(s.txn) if err != nil { return err } - commitment := calculateContractCommitment(root, cHash, nonce) + commitment := calculateContractCommitment( + rootKey, + contract.ClassHash, + contract.Nonce, + ) _, err = stateTrie.Put(contract.Address, commitment) return err @@ -517,17 +617,15 @@ func (s *State) updateDeclaredClassesTrie(declaredClasses map[felt.Felt]*felt.Fe // ContractIsAlreadyDeployedAt returns if contract at given addr was deployed at blockNumber func (s *State) ContractIsAlreadyDeployedAt(addr *felt.Felt, blockNumber uint64) (bool, error) { - var deployedAt uint64 - if err := s.txn.Get(db.ContractDeploymentHeight.Key(addr.Marshal()), func(bytes []byte) error { - deployedAt = binary.BigEndian.Uint64(bytes) - return nil - }); err != nil { - if errors.Is(err, db.ErrKeyNotFound) { + contract, err := GetContract(addr, s.txn) + if err != nil { + if errors.Is(err, ErrContractNotDeployed) { return false, nil } return false, err } - return deployedAt <= blockNumber, nil + + return contract.DeployHeight <= blockNumber, nil } func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { @@ -542,12 +640,11 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { reversedDiff, err := s.GetReverseStateDiff(blockNumber, update.StateDiff) if err != nil { - return fmt.Errorf("error getting reverse state diff: %v", err) + return fmt.Errorf("build reverse diff: %v", err) } - err = s.performStateDeletions(blockNumber, update.StateDiff) - if err != nil { - return fmt.Errorf("error performing state deletions: %v", err) + if err = s.performStateDeletions(blockNumber, reversedDiff); err != nil { + return fmt.Errorf("perform state deletions: %v", err) } stateTrie, storageCloser, err := s.storage() @@ -555,55 +652,32 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { return err } - if err = s.updateContracts(stateTrie, blockNumber, reversedDiff, false); err != nil { + if err = s.updateContracts(blockNumber, reversedDiff, false); err != nil { return fmt.Errorf("update contracts: %v", err) } - if err = storageCloser(); err != nil { - return err + if err = s.Commit(stateTrie, false, blockNumber); err != nil { + return fmt.Errorf("state commit: %v", err) } // purge deployed contracts for addr := range update.StateDiff.DeployedContracts { - if err = s.purgeContract(&addr); err != nil { + if err = s.purgeContract(stateTrie, &addr); err != nil { return fmt.Errorf("purge contract: %v", err) } } - if err = s.purgeNoClassContracts(); err != nil { + if err = s.purgeNoClassContracts(stateTrie); err != nil { + return fmt.Errorf("purge no class contract: %v", err) + } + + if err = storageCloser(); err != nil { return err } return s.verifyStateUpdateRoot(update.OldRoot) } -func (s *State) purgeNoClassContracts() error { - // As noClassContracts are not in StateDiff.DeployedContracts we can only purge them if their storage no longer exists. - // Updating contracts with reverse diff will eventually lead to the deletion of noClassContract's storage key from db. Thus, - // we can use the lack of key's existence as reason for purging noClassContracts. - for addr := range noClassContracts { - noClassC, err := NewContractUpdater(&addr, s.txn) - if err != nil { - if !errors.Is(err, ErrContractNotDeployed) { - return err - } - continue - } - - r, err := ContractRoot(noClassC.Address, s.txn) - if err != nil { - return fmt.Errorf("contract root: %v", err) - } - - if r.Equal(&felt.Zero) { - if err = s.purgeContract(&addr); err != nil { - return fmt.Errorf("purge contract: %v", err) - } - } - } - return nil -} - func (s *State) removeDeclaredClasses(blockNumber uint64, v0Classes []*felt.Felt, v1Classes map[felt.Felt]*felt.Felt) error { totalCapacity := len(v0Classes) + len(v1Classes) classHashes := make([]*felt.Felt, 0, totalCapacity) @@ -639,107 +713,78 @@ func (s *State) removeDeclaredClasses(blockNumber uint64, v0Classes []*felt.Felt return classesCloser() } -func (s *State) purgeContract(addr *felt.Felt) error { - contract, err := NewContractUpdater(addr, s.txn) - if err != nil { - return err - } - - state, storageCloser, err := s.storage() +func (s *State) purgeContract(stateTrie *trie.Trie, addr *felt.Felt) error { + contract, err := GetContract(addr, s.txn) if err != nil { return err } - if err = s.txn.Delete(db.ContractDeploymentHeight.Key(addr.Marshal())); err != nil { - return err - } - - if _, err = state.Put(contract.Address, &felt.Zero); err != nil { + if _, err = stateTrie.Put(contract.Address, &felt.Zero); err != nil { return err } - if err = contract.Purge(); err != nil { + if err = contract.Purge(s.txn); err != nil { return err } - return storageCloser() + return nil } -func (s *State) GetReverseStateDiff(blockNumber uint64, diff *StateDiff) (*StateDiff, error) { - reversed := *diff - +func (s *State) performStateDeletions(blockNumber uint64, diff *StateDiff) error { // storage diffs - reversed.StorageDiffs = make(map[felt.Felt]map[felt.Felt]*felt.Felt, len(diff.StorageDiffs)) for addr, storageDiffs := range diff.StorageDiffs { - reversedDiffs := make(map[felt.Felt]*felt.Felt, len(storageDiffs)) for key := range storageDiffs { - value := &felt.Zero - if blockNumber > 0 { - oldValue, err := s.ContractStorageAt(&addr, &key, blockNumber-1) - if err != nil { - return nil, err - } - value = oldValue + if err := s.DeleteContractStorageLog(&addr, &key, blockNumber); err != nil { + return err } - reversedDiffs[key] = value } - reversed.StorageDiffs[addr] = reversedDiffs } // nonces - reversed.Nonces = make(map[felt.Felt]*felt.Felt, len(diff.Nonces)) for addr := range diff.Nonces { - oldNonce := &felt.Zero - if blockNumber > 0 { - var err error - oldNonce, err = s.ContractNonceAt(&addr, blockNumber-1) - if err != nil { - return nil, err - } + if err := s.DeleteContractNonceLog(&addr, blockNumber); err != nil { + return err } - reversed.Nonces[addr] = oldNonce } - // replaced - reversed.ReplacedClasses = make(map[felt.Felt]*felt.Felt, len(diff.ReplacedClasses)) + // replaced classes for addr := range diff.ReplacedClasses { - classHash := &felt.Zero - if blockNumber > 0 { - var err error - classHash, err = s.ContractClassHashAt(&addr, blockNumber-1) - if err != nil { - return nil, err - } + if err := s.DeleteContractClassHashLog(&addr, blockNumber); err != nil { + return err } - reversed.ReplacedClasses[addr] = classHash } - return &reversed, nil + return nil } -func (s *State) performStateDeletions(blockNumber uint64, diff *StateDiff) error { - // storage diffs - for addr, storageDiffs := range diff.StorageDiffs { - for key := range storageDiffs { - if err := s.DeleteContractStorageLog(&addr, &key, blockNumber); err != nil { +func (s *State) purgeNoClassContracts(stateTrie *trie.Trie) error { + // As noClassContracts are not in StateDiff.DeployedContracts we can only purge them if their storage no longer exists. + // Updating contracts with reverse diff will eventually lead to the deletion of noClassContract's storage key from db. Thus, + // we can use the lack of key's existence as reason for purging noClassContracts. + for addr := range noClassContracts { + contract, err := GetContract(&addr, s.txn) + if err != nil { + if !errors.Is(err, ErrContractNotDeployed) { return err } + continue } - } - // nonces - for addr := range diff.Nonces { - if err := s.DeleteContractNonceLog(&addr, blockNumber); err != nil { - return err + rootKey, err := contract.StorageRoot(s.txn) + if err != nil { + return fmt.Errorf("get root key: %v", err) } - } - // replaced classes - for addr := range diff.ReplacedClasses { - if err := s.DeleteContractClassHashLog(&addr, blockNumber); err != nil { - return err + if rootKey.Equal(&felt.Zero) { + if err = s.purgeContract(stateTrie, &addr); err != nil { + return fmt.Errorf("purge contract: %v", err) + } } } return nil } + +func logDBKey(key []byte, height uint64) []byte { + return binary.BigEndian.AppendUint64(key, height) +} diff --git a/core/state_test.go b/core/state_test.go index 6b96d64b3b..e2d04ee2ea 100644 --- a/core/state_test.go +++ b/core/state_test.go @@ -41,6 +41,7 @@ func TestMain(m *testing.M) { _ = encoder.RegisterType(reflect.TypeOf(core.Cairo0Class{})) _ = encoder.RegisterType(reflect.TypeOf(core.Cairo1Class{})) + _ = encoder.RegisterType(reflect.TypeOf(core.StateContract{})) code := m.Run() @@ -440,17 +441,22 @@ func TestRevert(t *testing.T) { require.NoError(t, state.Update(1, su1, nil)) t.Run("revert a replaced class", func(t *testing.T) { + replacedVal := utils.HexToFelt(t, "0xDEADBEEF") replaceStateUpdate := &core.StateUpdate{ NewRoot: utils.HexToFelt(t, "0x30b1741b28893b892ac30350e6372eac3a6f32edee12f9cdca7fbe7540a5ee"), OldRoot: su1.NewRoot, StateDiff: &core.StateDiff{ ReplacedClasses: map[felt.Felt]*felt.Felt{ - su1FirstDeployedAddress: utils.HexToFelt(t, "0xDEADBEEF"), + su1FirstDeployedAddress: replacedVal, }, }, } require.NoError(t, state.Update(2, replaceStateUpdate, nil)) + classHash, err := state.ContractClassHash(new(felt.Felt).Set(&su1FirstDeployedAddress)) + require.NoError(t, err) + assert.Equal(t, replacedVal, classHash) + require.NoError(t, state.Revert(2, replaceStateUpdate)) classHash, sErr := state.ContractClassHash(new(felt.Felt).Set(&su1FirstDeployedAddress)) require.NoError(t, sErr) @@ -458,20 +464,25 @@ func TestRevert(t *testing.T) { }) t.Run("revert a nonce update", func(t *testing.T) { + replacedVal := utils.HexToFelt(t, "0xDEADBEEF") nonceStateUpdate := &core.StateUpdate{ NewRoot: utils.HexToFelt(t, "0x6683657d2b6797d95f318e7c6091dc2255de86b72023c15b620af12543eb62c"), OldRoot: su1.NewRoot, StateDiff: &core.StateDiff{ Nonces: map[felt.Felt]*felt.Felt{ - su1FirstDeployedAddress: utils.HexToFelt(t, "0xDEADBEEF"), + su1FirstDeployedAddress: replacedVal, }, }, } require.NoError(t, state.Update(2, nonceStateUpdate, nil)) - require.NoError(t, state.Revert(2, nonceStateUpdate)) nonce, sErr := state.ContractNonce(new(felt.Felt).Set(&su1FirstDeployedAddress)) require.NoError(t, sErr) + assert.Equal(t, replacedVal, nonce) + + require.NoError(t, state.Revert(2, nonceStateUpdate)) + nonce, sErr = state.ContractNonce(new(felt.Felt).Set(&su1FirstDeployedAddress)) + require.NoError(t, sErr) assert.Equal(t, &felt.Zero, nonce) }) @@ -702,3 +713,168 @@ func TestRevertDeclaredClasses(t *testing.T) { _, err = state.Class(sierraHash) require.ErrorIs(t, err, db.ErrKeyNotFound) } + +func TestHistory(t *testing.T) { + testDB := pebble.NewMemTest(t) + txn, err := testDB.NewTransaction(true) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, txn.Discard()) + }) + + state := core.NewState(txn) + addr := &felt.Zero + location := new(felt.Felt).SetUint64(456) + value := new(felt.Felt).SetUint64(789) + + t.Run("no history", func(t *testing.T) { + _, err := state.ContractNonceAt(new(felt.Felt).SetUint64(1), 1) + require.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractClassHashAt(new(felt.Felt).SetUint64(1), 1) + require.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractStorageAt(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(1), 1) + require.ErrorIs(t, err, core.ErrCheckHeadState) + }) + + contract := core.NewStateContract(&felt.Zero, &felt.Zero, &felt.Zero, 0) + t.Run("log value changed at height 5 and 10", func(t *testing.T) { + assert.NoError(t, contract.LogNonce(5, txn)) + assert.NoError(t, contract.LogClassHash(5, txn)) + assert.NoError(t, contract.LogStorage(location, &felt.Zero, 5, txn)) + + contract.Nonce = value + contract.ClassHash = value + + assert.NoError(t, contract.LogNonce(10, txn)) + assert.NoError(t, contract.LogClassHash(10, txn)) + assert.NoError(t, contract.LogStorage(location, value, 10, txn)) + }) + + t.Run("get value before height 5", func(t *testing.T) { + oldValue, err := state.ContractStorageAt(addr, location, 1) + require.NoError(t, err) + assert.Equal(t, &felt.Zero, oldValue) + + oldValue, err = state.ContractNonceAt(addr, 1) + require.NoError(t, err) + assert.Equal(t, &felt.Zero, oldValue) + + oldValue, err = state.ContractClassHashAt(addr, 1) + require.NoError(t, err) + assert.Equal(t, &felt.Zero, oldValue) + }) + + t.Run("get value between height 5-10", func(t *testing.T) { + oldValue, err := state.ContractStorageAt(addr, location, 7) + require.NoError(t, err) + assert.Equal(t, value, oldValue) + + oldValue, err = state.ContractNonceAt(addr, 7) + require.NoError(t, err) + assert.Equal(t, value, oldValue) + + oldValue, err = state.ContractClassHashAt(addr, 7) + require.NoError(t, err) + assert.Equal(t, value, oldValue) + }) + + t.Run("get value on height that change happened", func(t *testing.T) { + oldValue, err := state.ContractStorageAt(addr, location, 5) + require.NoError(t, err) + assert.Equal(t, value, oldValue) + + _, err = state.ContractStorageAt(addr, location, 10) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + oldValue, err = state.ContractNonceAt(addr, 5) + require.NoError(t, err) + assert.Equal(t, value, oldValue) + + _, err = state.ContractNonceAt(addr, 10) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + oldValue, err = state.ContractClassHashAt(addr, 5) + require.NoError(t, err) + assert.Equal(t, value, oldValue) + + _, err = state.ContractClassHashAt(addr, 10) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + }) + + t.Run("get value after height 10 ", func(t *testing.T) { + _, err = state.ContractStorageAt(addr, location, 13) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractNonceAt(addr, 13) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractClassHashAt(addr, 13) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + }) + + t.Run("get a random location ", func(t *testing.T) { + _, err = state.ContractStorageAt(new(felt.Felt).SetUint64(37), new(felt.Felt).SetUint64(37), 13) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractNonceAt(new(felt.Felt).SetUint64(37), 13) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractClassHashAt(new(felt.Felt).SetUint64(37), 13) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + }) + + t.Run("delete storage and get value after delete", func(t *testing.T) { + assert.NoError(t, state.DeleteContractClassHashLog(addr, 10)) + assert.NoError(t, state.DeleteContractNonceLog(addr, 10)) + assert.NoError(t, state.DeleteContractStorageLog(addr, location, 10)) + + _, err = state.ContractStorageAt(addr, location, 10) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractNonceAt(addr, 10) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractClassHashAt(addr, 10) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + }) +} + +func BenchmarkStateUpdate(b *testing.B) { + client := feeder.NewTestClient(b, &utils.Mainnet) + gw := adaptfeeder.New(client) + + su0, err := gw.StateUpdate(context.Background(), 0) + require.NoError(b, err) + + su1, err := gw.StateUpdate(context.Background(), 1) + require.NoError(b, err) + + su2, err := gw.StateUpdate(context.Background(), 2) + require.NoError(b, err) + + stateUpdates := []*core.StateUpdate{su0, su1, su2} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + // Create a new test database for each iteration + testDB := pebble.NewMemTest(b) + txn, err := testDB.NewTransaction(true) + require.NoError(b, err) + + state := core.NewState(txn) + b.StartTimer() + + for i, su := range stateUpdates { + err = state.Update(uint64(i), su, nil) + if err != nil { + b.Fatalf("Error updating state: %v", err) + } + } + + b.StopTimer() + require.NoError(b, txn.Discard()) + } +} diff --git a/db/buckets.go b/db/buckets.go index 3918eb5f29..ef8aa1b8a8 100644 --- a/db/buckets.go +++ b/db/buckets.go @@ -9,30 +9,38 @@ type Bucket byte // keys like Bolt or MDBX does. We use a global prefix list as a poor // man's bucket alternative. const ( - StateTrie Bucket = iota // state metadata (e.g., the state root) - Peer // maps peer ID to peer multiaddresses - ContractClassHash // maps contract addresses and class hashes - ContractStorage // contract storages - Class // maps class hashes to classes - ContractNonce // contract nonce - ChainHeight // Latest height of the blockchain - BlockHeaderNumbersByHash - BlockHeadersByNumber - TransactionBlockNumbersAndIndicesByHash // maps transaction hashes to block number and index - TransactionsByBlockNumberAndIndex // maps block number and index to transaction - ReceiptsByBlockNumberAndIndex // maps block number and index to transaction receipt - StateUpdatesByBlockNumber + // StateTrie -> Latest state trie's root key + // StateTrie + ContractAddr -> Contract's commitment value + // StateTrie + ContractAddr + Trie node path -> Trie node value + StateTrie Bucket = iota + Peer // Peer + PeerID bytes -> Encoded peer multiaddresses + ContractClassHash // (Legacy) ContractClassHash + ContractAddr -> Contract's class hash value + // ContractStorage + ContractAddr -> Latest contract storage trie's root key + // ContractStorage + ContractAddr + Trie node path -> Trie node value + ContractStorage + Class // Class + Class hash -> Class object + ContractNonce // (Legacy) ContractNonce + ContractAddr -> Contract's nonce value + ChainHeight // ChainHeight -> Latest height of the blockchain + BlockHeaderNumbersByHash // BlockHeaderNumbersByHash + BlockHash -> Block number + BlockHeadersByNumber // BlockHeadersByNumber + BlockNumber -> Block header object + TransactionBlockNumbersAndIndicesByHash // TransactionBlockNumbersAndIndicesByHash + TransactionHash -> Encoded(BlockNumber, Index) + TransactionsByBlockNumberAndIndex // TransactionsByBlockNumberAndIndex + Encoded(BlockNumber, Index) -> Encoded(Transaction) + ReceiptsByBlockNumberAndIndex // ReceiptsByBlockNumberAndIndex + Encoded(BlockNumber, Index) -> Encoded(Receipt) + StateUpdatesByBlockNumber // StateUpdatesByBlockNumber + BlockNumber -> Encoded(StateUpdate) + // ClassesTrie -> Latest classes trie's root key + // ClassesTrie + ClassHash -> PoseidonHash(leafVersion, compiledClassHash) ClassesTrie - ContractStorageHistory - ContractNonceHistory - ContractClassHashHistory - ContractDeploymentHeight - L1Height - SchemaVersion - Pending - BlockCommitments - Temporary // used temporarily for migrations - SchemaIntermediateState + ContractStorageHistory // ContractStorageHistory + ContractAddr + BlockHeight + StorageLocation -> StorageValue + ContractNonceHistory // ContractNonceHistory + ContractAddr + BlockHeight -> Contract's nonce value + ContractClassHashHistory // ContractClassHashHistory + ContractAddr + BlockHeight -> Contract's class hash value + ContractDeploymentHeight // (Legacy) ContractDeploymentHeight + ContractAddr -> BlockHeight + L1Height // L1Height -> Latest height of the L1 chain + SchemaVersion // SchemaVersion -> DB schema version + Pending // Pending -> Pending block + BlockCommitments // BlockCommitments + BlockNumber -> Block commitments + Temporary // used temporarily for migrations + SchemaIntermediateState // used for db schema metadata + Contract // Contract + ContractAddr -> Encoded(Contract) ) // Key flattens a prefix and series of byte arrays into a single []byte. diff --git a/db/buckets_enumer.go b/db/buckets_enumer.go index 0501198b61..0abf339cf5 100644 --- a/db/buckets_enumer.go +++ b/db/buckets_enumer.go @@ -7,11 +7,11 @@ import ( "strings" ) -const _BucketName = "StateTriePeerContractClassHashContractStorageClassContractNonceChainHeightBlockHeaderNumbersByHashBlockHeadersByNumberTransactionBlockNumbersAndIndicesByHashTransactionsByBlockNumberAndIndexReceiptsByBlockNumberAndIndexStateUpdatesByBlockNumberClassesTrieContractStorageHistoryContractNonceHistoryContractClassHashHistoryContractDeploymentHeightL1HeightSchemaVersionPendingBlockCommitmentsTemporarySchemaIntermediateState" +const _BucketName = "StateTriePeerContractClassHashContractStorageClassContractNonceChainHeightBlockHeaderNumbersByHashBlockHeadersByNumberTransactionBlockNumbersAndIndicesByHashTransactionsByBlockNumberAndIndexReceiptsByBlockNumberAndIndexStateUpdatesByBlockNumberClassesTrieContractStorageHistoryContractNonceHistoryContractClassHashHistoryContractDeploymentHeightL1HeightSchemaVersionPendingBlockCommitmentsTemporarySchemaIntermediateStateContract" -var _BucketIndex = [...]uint16{0, 9, 13, 30, 45, 50, 63, 74, 98, 118, 157, 190, 219, 244, 255, 277, 297, 321, 345, 353, 366, 373, 389, 398, 421} +var _BucketIndex = [...]uint16{0, 9, 13, 30, 45, 50, 63, 74, 98, 118, 157, 190, 219, 244, 255, 277, 297, 321, 345, 353, 366, 373, 389, 398, 421, 429} -const _BucketLowerName = "statetriepeercontractclasshashcontractstorageclasscontractnoncechainheightblockheadernumbersbyhashblockheadersbynumbertransactionblocknumbersandindicesbyhashtransactionsbyblocknumberandindexreceiptsbyblocknumberandindexstateupdatesbyblocknumberclassestriecontractstoragehistorycontractnoncehistorycontractclasshashhistorycontractdeploymentheightl1heightschemaversionpendingblockcommitmentstemporaryschemaintermediatestate" +const _BucketLowerName = "statetriepeercontractclasshashcontractstorageclasscontractnoncechainheightblockheadernumbersbyhashblockheadersbynumbertransactionblocknumbersandindicesbyhashtransactionsbyblocknumberandindexreceiptsbyblocknumberandindexstateupdatesbyblocknumberclassestriecontractstoragehistorycontractnoncehistorycontractclasshashhistorycontractdeploymentheightl1heightschemaversionpendingblockcommitmentstemporaryschemaintermediatestatecontract" func (i Bucket) String() string { if i >= Bucket(len(_BucketIndex)-1) { @@ -48,9 +48,10 @@ func _BucketNoOp() { _ = x[BlockCommitments-(21)] _ = x[Temporary-(22)] _ = x[SchemaIntermediateState-(23)] + _ = x[Contract-(24)] } -var _BucketValues = []Bucket{StateTrie, Peer, ContractClassHash, ContractStorage, Class, ContractNonce, ChainHeight, BlockHeaderNumbersByHash, BlockHeadersByNumber, TransactionBlockNumbersAndIndicesByHash, TransactionsByBlockNumberAndIndex, ReceiptsByBlockNumberAndIndex, StateUpdatesByBlockNumber, ClassesTrie, ContractStorageHistory, ContractNonceHistory, ContractClassHashHistory, ContractDeploymentHeight, L1Height, SchemaVersion, Pending, BlockCommitments, Temporary, SchemaIntermediateState} +var _BucketValues = []Bucket{StateTrie, Peer, ContractClassHash, ContractStorage, Class, ContractNonce, ChainHeight, BlockHeaderNumbersByHash, BlockHeadersByNumber, TransactionBlockNumbersAndIndicesByHash, TransactionsByBlockNumberAndIndex, ReceiptsByBlockNumberAndIndex, StateUpdatesByBlockNumber, ClassesTrie, ContractStorageHistory, ContractNonceHistory, ContractClassHashHistory, ContractDeploymentHeight, L1Height, SchemaVersion, Pending, BlockCommitments, Temporary, SchemaIntermediateState, Contract} var _BucketNameToValueMap = map[string]Bucket{ _BucketName[0:9]: StateTrie, @@ -101,6 +102,8 @@ var _BucketNameToValueMap = map[string]Bucket{ _BucketLowerName[389:398]: Temporary, _BucketName[398:421]: SchemaIntermediateState, _BucketLowerName[398:421]: SchemaIntermediateState, + _BucketName[421:429]: Contract, + _BucketLowerName[421:429]: Contract, } var _BucketNames = []string{ @@ -128,6 +131,7 @@ var _BucketNames = []string{ _BucketName[373:389], _BucketName[389:398], _BucketName[398:421], + _BucketName[421:429], } // BucketString retrieves an enum value from the enum constants string name. diff --git a/db/pebble/db.go b/db/pebble/db.go index 5974edf720..77aed603d7 100644 --- a/db/pebble/db.go +++ b/db/pebble/db.go @@ -60,7 +60,7 @@ func NewMem() (db.DB, error) { } // NewMemTest opens a new in-memory database, panics on error -func NewMemTest(t *testing.T) db.DB { +func NewMemTest(t testing.TB) db.DB { memDB, err := NewMem() if err != nil { t.Fatalf("create in-memory db: %v", err) diff --git a/migration/migration.go b/migration/migration.go index f3a8dd2a72..b0173dac37 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -66,6 +66,7 @@ var defaultMigrations = []Migration{ NewBucketMover(db.Temporary, db.ContractStorage), NewBucketMigrator(db.StateUpdatesByBlockNumber, changeStateDiffStruct).WithBatchSize(100), //nolint:mnd NewBucketMigrator(db.Class, migrateCairo1CompiledClass).WithBatchSize(1_000), //nolint:mnd + MigrationFunc(MigrateContractFields), } var ErrCallWithNewTransaction = errors.New("call with new transaction") @@ -714,3 +715,160 @@ func migrateCairo1CompiledClass(txn db.Transaction, key, value []byte, _ *utils. return txn.Set(key, value) } + +func MigrateContractFields(txn db.Transaction, _ *utils.Network) error { + contracts := make(map[felt.Felt]*core.StateContract) + + it, err := txn.NewIterator() + if err != nil { + return err + } + + if err := collectContractNonces(txn, contracts); err != nil { + return err + } + + if err := collectContractClassHashes(txn, contracts); err != nil { + return err + } + + if err := collectContractDeploymentHeights(txn, contracts); err != nil { + return err + } + + if err := storeUpdatedContracts(txn, contracts); err != nil { + return err + } + + return it.Close() +} + +func collectContractNonces(txn db.Transaction, contracts map[felt.Felt]*core.StateContract) error { + it, err := txn.NewIterator() + if err != nil { + return err + } + + noncePrefix := db.ContractNonce.Key() + for it.Seek(noncePrefix); it.Valid(); it.Next() { + key := it.Key() + if !bytes.Equal(key[:len(noncePrefix)], noncePrefix) { + break + } + + addr := key[len(noncePrefix):] + if len(addr) != felt.Bytes { + return fmt.Errorf("invalid address length: %d", len(addr)) + } + + addrFelt := new(felt.Felt).SetBytes(addr) + + value, err := it.Value() + if err != nil { + return err + } + + contract := &core.StateContract{ + Nonce: new(felt.Felt).SetBytes(value), + } + contracts[*addrFelt] = contract + + if err := txn.Delete(key); err != nil { + return err + } + } + + return it.Close() +} + +func collectContractClassHashes(txn db.Transaction, contracts map[felt.Felt]*core.StateContract) error { + it, err := txn.NewIterator() + if err != nil { + return err + } + + classHashPrefix := db.ContractClassHash.Key() + for it.Seek(classHashPrefix); it.Valid(); it.Next() { + key := it.Key() + if !bytes.Equal(key[:len(classHashPrefix)], classHashPrefix) { + break + } + + addr := key[len(classHashPrefix):] + if len(addr) != felt.Bytes { + return fmt.Errorf("invalid address length: %d", len(addr)) + } + addrFelt := new(felt.Felt).SetBytes(addr) + + // this should never happen because collectContractNonces should have collected all the contracts + if _, ok := contracts[*addrFelt]; !ok { + return fmt.Errorf("contract not found for address: %s", addrFelt) + } + + value, err := it.Value() + if err != nil { + return err + } + + contracts[*addrFelt].ClassHash = new(felt.Felt).SetBytes(value) + + if err := txn.Delete(key); err != nil { + return err + } + } + + return it.Close() +} + +func collectContractDeploymentHeights(txn db.Transaction, contracts map[felt.Felt]*core.StateContract) error { + it, err := txn.NewIterator() + if err != nil { + return err + } + + deployHeightPrefix := db.ContractDeploymentHeight.Key() + for it.Seek(deployHeightPrefix); it.Valid(); it.Next() { + key := it.Key() + if !bytes.Equal(key[:len(deployHeightPrefix)], deployHeightPrefix) { + break + } + + addr := key[len(deployHeightPrefix):] + if len(addr) != felt.Bytes { + return fmt.Errorf("invalid address length: %d", len(addr)) + } + addrFelt := new(felt.Felt).SetBytes(addr) + + // this should never happen because collectContractNonces should have collected all the contracts + if _, ok := contracts[*addrFelt]; !ok { + return fmt.Errorf("contract not found for address: %s", addrFelt) + } + + value, err := it.Value() + if err != nil { + return err + } + + contracts[*addrFelt].DeployHeight = binary.BigEndian.Uint64(value) + + if err := txn.Delete(key); err != nil { + return err + } + } + + return it.Close() +} + +func storeUpdatedContracts(txn db.Transaction, contracts map[felt.Felt]*core.StateContract) error { + for addr, contract := range contracts { + contractBytes, err := encoder.Marshal(contract) + if err != nil { + return err + } + + if err := txn.Set(db.Contract.Key(addr.Marshal()), contractBytes); err != nil { + return err + } + } + return nil +} diff --git a/migration/migration_test.go b/migration/migration_test.go index f037974fb1..85065462aa 100644 --- a/migration/migration_test.go +++ b/migration/migration_test.go @@ -2,9 +2,14 @@ package migration_test import ( "context" + "encoding/binary" "testing" + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebble" + "github.com/NethermindEth/juno/encoder" "github.com/NethermindEth/juno/migration" "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/require" @@ -40,3 +45,46 @@ func TestMigrateIfNeeded(t *testing.T) { require.Equal(t, meta, postVersion) }) } + +func TestMigrateContractFields(t *testing.T) { + testDB := pebble.NewMemTest(t) + txn, err := testDB.NewTransaction(true) + require.NoError(t, err) + + // Test data + contracts := []struct { + addr *felt.Felt + nonce *felt.Felt + classHash *felt.Felt + deploymentHeight uint64 + }{ + {new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(11), new(felt.Felt).SetUint64(111), 1111}, + {new(felt.Felt).SetUint64(2), new(felt.Felt).SetUint64(22), new(felt.Felt).SetUint64(222), 2222}, + {new(felt.Felt).SetUint64(3), new(felt.Felt).SetUint64(33), new(felt.Felt).SetUint64(333), 3333}, + } + + // Set up initial data + for _, c := range contracts { + addrBytes := c.addr.Marshal() + hBytes := make([]byte, 8) + binary.BigEndian.PutUint64(hBytes, c.deploymentHeight) + + require.NoError(t, txn.Set(db.ContractNonce.Key(addrBytes), c.nonce.Marshal())) + require.NoError(t, txn.Set(db.ContractClassHash.Key(addrBytes), c.classHash.Marshal())) + require.NoError(t, txn.Set(db.ContractDeploymentHeight.Key(addrBytes), hBytes)) + } + + // Run migration + require.NoError(t, migration.MigrateContractFields(txn, nil)) + + // Verify results + for _, c := range contracts { + var contract core.StateContract + require.NoError(t, txn.Get(db.Contract.Key(c.addr.Marshal()), func(value []byte) error { + return encoder.Unmarshal(value, &contract) + })) + require.Equal(t, c.nonce, contract.Nonce) + require.Equal(t, c.classHash, contract.ClassHash) + require.Equal(t, c.deploymentHeight, contract.DeployHeight) + } +}