Skip to content

Commit

Permalink
Use debug_traceCall to calculate outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielSchiavini committed Apr 22, 2024
1 parent 8927cd3 commit 75d2e85
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 51 deletions.
22 changes: 6 additions & 16 deletions boa_zksync/deployer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,8 @@
from boa.rpc import to_bytes
from boa.util.abi import Address

from boa_zksync.types import ZksyncCompilerData


class ZksyncDeployer(ABIContractFactory):
def __init__(self, compiler_data: ZksyncCompilerData, name: str, filename: str):
super().__init__(
name,
compiler_data.abi,
functions=[
ABIFunction(item, name)
for item in compiler_data.abi
if item.get("type") == "function"
],
filename=filename,
)
self.compiler_data = compiler_data

def deploy(self, *args, value=0, **kwargs):
env = Env.get_singleton()
from boa_zksync.environment import ZksyncEnv
Expand All @@ -46,8 +31,9 @@ def deploy(self, *args, value=0, **kwargs):
address=address,
filename=self._filename,
env=env,
compiler_data=self.compiler_data,
)
env.register_contract(address, self)
env.register_contract(address, abi_contract)
return abi_contract

def deploy_as_blueprint(self, *args, **kwargs):
Expand All @@ -59,5 +45,9 @@ def deploy_as_blueprint(self, *args, **kwargs):

@cached_property
def constructor(self):
"""
Get the constructor function of the contract.
:raises: StopIteration if the constructor is not found.
"""
ctor_abi = next(i for i in self.abi if i["type"] == "constructor")
return ABIFunction(ctor_abi, contract_name=self._name)
33 changes: 12 additions & 21 deletions boa_zksync/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from boa.environment import _AddressType
from boa.interpret import json
from boa.network import NetworkEnv, _EstimateGasFailed
from boa.rpc import RPC, EthereumRPC, fixup_dict, to_bytes, to_hex
from boa.rpc import RPC, EthereumRPC
from boa.util.abi import Address
from boa_zksync.compile import compile_zksync, compile_zksync_source
from eth.constants import ZERO_ADDRESS
from eth.exceptions import VMError

from boa_zksync.compile import compile_zksync, compile_zksync_source
from boa_zksync.deployer import ZksyncDeployer
from boa_zksync.node import EraTestNode
from boa_zksync.types import DeployTransaction, ZksyncComputation, ZksyncMessage
Expand Down Expand Up @@ -103,25 +103,16 @@ def execute_code(
sender = self._check_sender(self._get_sender(sender))
args = ZksyncMessage(sender, to_address, gas or 0, value, data)

if not is_modifying:
output = self._rpc.fetch("eth_call", [args.as_json_dict(), "latest"])
return ZksyncComputation(args, to_bytes(output))

try:
receipt, trace = self._send_txn(**args.as_tx_params())
except _EstimateGasFailed:
return ZksyncComputation(args, error=VMError("Estimate gas failed"))

try:
# when calling create_from_blueprint, the address is not returned
# we get it from the logs by searching for the event. todo: remove this hack
deploy_topic = '0x290afdae231a3fc0bbae8b1af63698b0a1d79b21ad17df0342dfb952fe74f8e5'
output = next(x['topics'][3] for x in receipt['logs'] if x['topics'][0] == deploy_topic)
except StopIteration:
# TODO: This does not return the correct value either.
output = trace.returndata
trace_call = self._rpc.fetch("debug_traceCall", [args.as_json_dict(), "latest"])
traced_computation = ZksyncComputation.from_trace(trace_call)
if is_modifying:
try:
receipt, trace = self._send_txn(**args.as_tx_params())
assert traced_computation.is_error == trace.is_error, f"VMError mismatch: {traced_computation.error} != {trace.error}"
except _EstimateGasFailed:
return ZksyncComputation(args, error=VMError("Estimate gas failed"))

return ZksyncComputation(args, to_bytes(output))
return traced_computation

def deploy_code(
self,
Expand Down Expand Up @@ -205,7 +196,7 @@ def create_deployer(

if not compiler_data.abi:
logging.warning("No ABI found in compiled contract")
return ZksyncDeployer(compiler_data, name or filename, filename=filename)
return ZksyncDeployer.from_abi_dict(compiler_data.abi, name, filename, compiler_data)


def _hash_code(bytecode: bytes) -> bytes:
Expand Down
33 changes: 29 additions & 4 deletions boa_zksync/types.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from collections import namedtuple
from dataclasses import dataclass, field

import rlp
from boa.rpc import to_bytes, fixup_dict, to_hex
from boa.util.abi import Address
from eth.exceptions import VMError
from eth.exceptions import VMError, Revert
from eth_account import Account
from eth_account.datastructures import SignedMessage
from eth_account.messages import encode_typed_data
from rlp.sedes import BigEndianInt, Binary, List


_EIP712_TYPE = bytes.fromhex("71")
_EIP712_TYPES_SPEC = {
"EIP712Domain": [
Expand Down Expand Up @@ -167,11 +165,16 @@ class ZksyncCompilerData:
@dataclass
class ZksyncMessage:
sender: Address
to: str
to: Address
gas: int
value: int
data: bytes

@property
def code_address(self) -> bytes:
# this is used by boa to find the contract address for stack traces
return to_bytes(self.to)

def as_json_dict(self, sender_field="from"):
return fixup_dict({
sender_field: self.sender,
Expand All @@ -192,6 +195,28 @@ class ZksyncComputation:
error: VMError | None = None
children: list["ZksyncComputation"] = field(default_factory=list)

@classmethod
def from_trace(cls, output: dict) -> "ZksyncComputation":
""" Recursively constructs a ZksyncComputation from a debug_traceCall output. """
error = None
if output.get("error") is not None:
error = VMError(output["error"])
if output.get("revertReason") is not None:
error = Revert(output["revertReason"])

return cls(
msg=ZksyncMessage(
sender=Address(output["from"]),
to=Address(output["to"]),
gas=int(output["gas"], 16),
value=int(output["value"], 16),
data=to_bytes(output["input"]),
),
output=to_bytes(output["output"]),
error=error,
children=[cls.from_trace(call) for call in output.get("calls", [])],
)

@property
def is_success(self) -> bool:
"""
Expand Down
65 changes: 55 additions & 10 deletions tests/test_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

import pytest

from boa import loads, loads_partial
import boa
from boa import BoaError
from boa.contracts.base_evm_contract import StackTrace

from boa_zksync.util import find_free_port, wait_url, stop_subprocess

STARTING_SUPPLY = 100
Expand Down Expand Up @@ -35,14 +38,15 @@ def __init__(t: uint256):
self.balances[self] = t
@external
def update_total_supply(t: uint16):
def update_total_supply(t: uint16) -> uint256:
self.totalSupply += convert(t, uint256)
return self.totalSupply
@external
def raise_exception(t: uint256):
raise "oh no!"
"""
return loads(code, STARTING_SUPPLY)
return boa.loads(code, STARTING_SUPPLY, name="SimpleContract")


def test_total_supply(simple_contract):
Expand Down Expand Up @@ -70,8 +74,8 @@ def some_function() -> uint256:
def create_child(blueprint: address, salt: bytes32, val: uint256) -> address:
return create_from_blueprint(blueprint, val, salt=salt)
"""
blueprint = loads_partial(blueprint_code).deploy_as_blueprint()
factory = loads(factory_code)
blueprint = boa.loads_partial(blueprint_code, name="Blueprint").deploy_as_blueprint()
factory = boa.loads(factory_code, name="Factory")

salt = b"\x00" * 32

Expand All @@ -80,7 +84,7 @@ def create_child(blueprint: address, salt: bytes32, val: uint256) -> address:
# assert child_contract_address == get_create2_address(
# blueprint_bytecode, factory.address, salt
# ).some_function()
child = loads_partial(blueprint_code).at(child_contract_address)
child = boa.loads_partial(blueprint_code).at(child_contract_address)
assert child.some_function() == 5


Expand All @@ -103,12 +107,12 @@ def some_function() -> uint256:
def create_child(blueprint: address, val: uint256) -> address:
return create_from_blueprint(blueprint, val)
"""
blueprint = loads_partial(blueprint_code).deploy_as_blueprint()
factory = loads(factory_code)
blueprint = boa.loads_partial(blueprint_code, name="blueprint").deploy_as_blueprint()
factory = boa.loads(factory_code, name="factory")

child_contract_address = factory.create_child(blueprint.address, 5)

child = loads_partial(blueprint_code).at(child_contract_address)
child = boa.loads_partial(blueprint_code).at(child_contract_address)
assert child.some_function() == 5


Expand All @@ -124,5 +128,46 @@ def foo() -> uint256:
def bar() -> uint256:
return self.foo()
"""
contract = loads(code)
contract = boa.loads(code)
assert contract.bar() == 123


def test_stack_trace():
called_contract = boa.loads(
"""
@internal
@view
def _get_name() -> String[32]:
assert False, "Test an error"
return "crvUSD"
@external
@view
def name() -> String[32]:
return self._get_name()
""", name="CalledContract"
)
caller_contract = boa.loads(
"""
interface HasName:
def name() -> String[32]: view
@external
@view
def get_name_of(addr: HasName) -> String[32]:
return addr.name()
""", name="CallerContract"
)

# boa.reverts does not give us the stack trace, use pytest.raises instead
with pytest.raises(BoaError) as ctx:
caller_contract.get_name_of(called_contract)

trace, = ctx.value.args
assert trace == StackTrace([
f" (<CalledContract interface at {called_contract.address}>.name() -> ['string'])",
f" (<CallerContract interface at {caller_contract.address}>.get_name_of(address) -> ['string'])",
" <Unknown contract 0x0000000000000000000000000000000000008009>", # MsgValueSimulator
" <Unknown contract 0x0000000000000000000000000000000000008002>", # AccountCodeStorage
f" (<CallerContract interface at {caller_contract.address}>.get_name_of(address) -> ['string'])",
])
2 changes: 2 additions & 0 deletions tests/test_zksync_fork.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ def foo() -> bool:
"""
c = boa.loads_partial(code).at("0xB27cCfd5909f46F5260Ca01BA27f591868D08704")
assert c.foo() is True
c = boa.loads(code)
assert c.foo() is True

0 comments on commit 75d2e85

Please sign in to comment.