Skip to content

Commit

Permalink
Update boa and add support for call traces
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielSchiavini committed Sep 19, 2024
1 parent f13116d commit 1122473
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 24 deletions.
6 changes: 3 additions & 3 deletions boa_zksync/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(self, fn: ContractFunctionT, contract: ZksyncContract):
"name": f"__boa_private_{fn.name}__",
"type": "function",
}
super().__init__(abi, contract._name)
super().__init__(abi, contract.contract_name)
self.contract = contract
self.func_t = fn

Expand All @@ -163,7 +163,7 @@ def __init__(self, var: VarInfo, name: str, contract: ZksyncContract):
"constant": True,
"type": "function",
}
super().__init__(abi, contract._name)
super().__init__(abi, contract.contract_name)
self.contract = contract
self.var = var
self.var_name = name
Expand Down Expand Up @@ -198,7 +198,7 @@ def __init__(self, code: str, contract: ZksyncContract):
"name": "__boa_debug__",
"type": "function",
}
super().__init__(abi, contract._name)
super().__init__(abi, contract.contract_name)
self.contract = contract
self.code = code

Expand Down
17 changes: 9 additions & 8 deletions boa_zksync/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@

from boa_zksync.deployer import ZksyncDeployer
from boa_zksync.node import EraTestNode
from boa_zksync.types import DeployTransaction, ZksyncComputation, ZksyncMessage
from boa_zksync.types import DeployTransaction, ZksyncComputation, ZksyncMessage, ZERO_ADDRESS, \
CONTRACT_DEPLOYER_ADDRESS

ZERO_ADDRESS = "0x0000000000000000000000000000000000000000"
_CONTRACT_DEPLOYER_ADDRESS = "0x0000000000000000000000000000000000008006"
with open(Path(__file__).parent / "IContractDeployer.json") as f:
CONTRACT_DEPLOYER = ABIContractFactory.from_abi_dict(
json.load(f), "ContractDeployer"
Expand Down Expand Up @@ -132,11 +131,11 @@ def execute_code(
"debug_traceCall",
[args.as_json_dict(), "latest", {"tracer": "callTracer"}],
)
traced_computation = ZksyncComputation.from_call_trace(trace_call)
traced_computation = ZksyncComputation.from_call_trace(self, trace_call)
except (RPCError, HTTPError):
output = self._rpc.fetch("eth_call", [args.as_json_dict(), "latest"])
traced_computation = ZksyncComputation(
args, bytes.fromhex(output.removeprefix("0x"))
self, args, bytes.fromhex(output.removeprefix("0x"))
)

if is_modifying:
Expand All @@ -147,11 +146,13 @@ def execute_code(
assert (
traced_computation.is_error == trace.is_error
), f"VMError mismatch: {traced_computation.error} != {trace.error}"
return ZksyncComputation.from_debug_trace(trace.raw_trace)
return ZksyncComputation.from_debug_trace(self, trace.raw_trace)

except _EstimateGasFailed:
if not traced_computation.is_error: # trace gives more information
return ZksyncComputation(args, error=VMError("Estimate gas failed"))
return ZksyncComputation(
self, args, error=VMError("Estimate gas failed")
)

return traced_computation

Expand Down Expand Up @@ -199,7 +200,7 @@ def deploy_code(
bytecode_hash = _hash_code(bytecode)
tx = DeployTransaction(
sender=sender,
to=_CONTRACT_DEPLOYER_ADDRESS,
to=CONTRACT_DEPLOYER_ADDRESS,
gas=gas or 0,
gas_price=gas_price,
max_priority_fee_per_gas=kwargs.pop("max_priority_fee_per_gas", gas_price),
Expand Down
42 changes: 36 additions & 6 deletions boa_zksync/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import dataclass, field
from functools import cached_property
from typing import Optional
from typing import TYPE_CHECKING, Optional

import rlp
from boa.contracts.call_trace import TraceFrame
from boa.contracts.vyper.vyper_contract import VyperDeployer
from boa.interpret import compiler_data
from boa.rpc import fixup_dict, to_bytes, to_hex
Expand All @@ -15,6 +16,12 @@
from vyper.compiler import CompilerData
from vyper.compiler.settings import OptimizationLevel

if TYPE_CHECKING:
from boa_zksync import ZksyncEnv

ZERO_ADDRESS = "0x0000000000000000000000000000000000000000"
CONTRACT_DEPLOYER_ADDRESS = "0x0000000000000000000000000000000000008006"

_EIP712_TYPE = bytes.fromhex("71")
_EIP712_TYPES_SPEC = {
"EIP712Domain": [
Expand Down Expand Up @@ -218,9 +225,14 @@ def as_json_dict(self, sender_field="from"):
def as_tx_params(self):
return self.as_json_dict(sender_field="from_")

@property
def is_create(self) -> bool:
return self.to == CONTRACT_DEPLOYER_ADDRESS


@dataclass
class ZksyncComputation:
env: "ZksyncEnv"
msg: ZksyncMessage
output: bytes | None = None
error: VMError | None = None
Expand All @@ -231,7 +243,7 @@ class ZksyncComputation:
value: int = 0

@classmethod
def from_call_trace(cls, output: dict) -> "ZksyncComputation":
def from_call_trace(cls, env: "ZksyncEnv", output: dict) -> "ZksyncComputation":
"""Recursively constructs a ZksyncComputation from a debug_traceCall output."""
error = None
if output.get("error") is not None:
Expand All @@ -240,6 +252,7 @@ def from_call_trace(cls, output: dict) -> "ZksyncComputation":
error = Revert(output["revertReason"])

return cls(
env=env,
msg=ZksyncMessage(
sender=Address(output["from"]),
to=Address(output["to"]),
Expand All @@ -249,15 +262,17 @@ def from_call_trace(cls, output: dict) -> "ZksyncComputation":
),
output=to_bytes(output["output"]),
error=error,
children=[cls.from_call_trace(call) for call in output.get("calls", [])],
children=[
cls.from_call_trace(env, call) for call in output.get("calls", [])
],
gas_used=int(output["gasUsed"], 16),
revert_reason=output.get("revertReason"),
type=output.get("type", "Call"),
value=int(output.get("value", "0x"), 16),
)

@classmethod
def from_debug_trace(cls, output: dict):
def from_debug_trace(cls, env: "ZksyncEnv", output: dict):
"""
Finds the actual transaction computation, since zksync has system
contract calls in the trace.
Expand All @@ -270,12 +285,12 @@ def _find(calls: list[dict]):
if found := _find(trace["calls"]):
return found
if trace["to"] == to and trace["from"] == sender:
return cls.from_call_trace(trace)
return cls.from_call_trace(env, trace)

if result := _find(output["calls"]):
return result
# in production mode the result is not always nested
return cls.from_call_trace(output)
return cls.from_call_trace(env, output)

@property
def is_success(self) -> bool:
Expand All @@ -302,3 +317,18 @@ def raise_if_error(self) -> None:

def get_gas_used(self):
return self.gas_used

@property
def net_gas_used(self) -> int:
return self.get_gas_used()

@property
def call_trace(self) -> TraceFrame:
return self._get_call_trace()

def _get_call_trace(self, depth=0) -> TraceFrame:
address = self.msg.to
contract = self.env.lookup_contract(address)
source = contract.trace_source(self) if contract else None
children = [child._get_call_trace(depth + 1) for child in self.children]
return TraceFrame(self, source, depth, children)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ keywords = [
]
classifiers = ["Topic :: Software Development"]

dependencies = ["titanoboa>=0.2.0"]
dependencies = ["titanoboa>=0.2.2"]

[project.optional-dependencies]
forking-recommended = ["ujson"]
Expand Down
5 changes: 4 additions & 1 deletion tests/test_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ def _javascript_call(js_func: str, *args, timeout_message: str) -> Any:
if method == "evm_snapshot":
return 1

if method == "eth_requestAccounts":
return [ZERO_ADDRESS]

if method == "evm_revert":
assert args[1:] == ([1],), f"Bad args passed to mock: {args}"
return None
Expand All @@ -26,7 +29,7 @@ def _javascript_call(js_func: str, *args, timeout_message: str) -> Any:

raise KeyError(args)

if js_func == "loadSigner":
if js_func in "loadSigner":
return ZERO_ADDRESS

raise KeyError(js_func)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_from_debug_trace_nested():
},
],
}
assert ZksyncComputation.from_debug_trace(output).output == result
assert ZksyncComputation.from_debug_trace(boa.env, output).output == result


def test_from_debug_trace_production_mode():
Expand All @@ -51,4 +51,4 @@ def test_from_debug_trace_production_mode():
"calls": [],
**_required_fields,
}
assert ZksyncComputation.from_debug_trace(output).output == result
assert ZksyncComputation.from_debug_trace(boa.env, output).output == result
21 changes: 18 additions & 3 deletions tests/test_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
from boa import BoaError
from boa.contracts.base_evm_contract import StackTrace
from boa.contracts.call_trace import TraceFrame

STARTING_SUPPLY = 100

Expand Down Expand Up @@ -149,8 +150,8 @@ def get_name_of(addr: HasName) -> String[32]:
with pytest.raises(BoaError) as ctx:
caller_contract.get_name_of(called_contract)

(trace,) = ctx.value.args
assert trace == StackTrace(
(call_trace, stack_trace) = ctx.value.args
assert stack_trace == StackTrace(
[
" Test an error(<CalledContract interface at "
f"{called_contract.address}> (file CalledContract).name() -> ['string'])",
Expand All @@ -164,7 +165,21 @@ def get_name_of(addr: HasName) -> String[32]:
"['string'])",
]
)

assert isinstance(call_trace, TraceFrame)
assert str(call_trace).split("\n") == [
f'[E] [24549] CallerContract.get_name_of(addr = "{called_contract.address}") <0x>',
' [E] [23618] Unknown contract 0x0000000000000000000000000000000000008002.0x4de2e468',
' [566] Unknown contract 0x000000000000000000000000000000000000800B.0x29f172ad',
' [1909] Unknown contract 0x000000000000000000000000000000000000800B.0x06bed036',
' [159] Unknown contract 0x0000000000000000000000000000000000008010.0x00000000',
' [449] Unknown contract 0x000000000000000000000000000000000000800B.0xa225efcb',
' [2226] Unknown contract 0x0000000000000000000000000000000000008002.0x4de2e468',
' [427] Unknown contract 0x000000000000000000000000000000000000800B.0xa851ae78',
' [398] Unknown contract 0x0000000000000000000000000000000000008004.0xe516761e',
' [E] [2592] Unknown contract 0x0000000000000000000000000000000000008009.0xb47fade1',
f' [E] [1401] CallerContract.get_name_of(addr = "{called_contract.address}") <0x>',
' [E] [403] CalledContract.name() <0x>'
]

def test_private(zksync_env):
code = """
Expand Down

0 comments on commit 1122473

Please sign in to comment.