Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: namedtuple decoding for vvmcontract structs #356

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
45bb007
set base_path in loads_partial_vvm
charles-cooper Dec 26, 2024
c6389f2
add namedtuple parsing for structs
charles-cooper Dec 26, 2024
84fbf5a
small refactor for get_logs()
charles-cooper Dec 26, 2024
fe7cbdb
add decode_log for abi contracts
charles-cooper Dec 26, 2024
9da2a89
fix error handling in vvm deployer
charles-cooper Dec 26, 2024
655850d
fix lint
charles-cooper Dec 27, 2024
49dfd2c
add a note
charles-cooper Dec 27, 2024
3f241fe
rename a variable
charles-cooper Dec 27, 2024
6d53c63
fix: marshal output for tuple return
charles-cooper Dec 31, 2024
bf6c380
thread name= to VVMDeployer
charles-cooper Jan 2, 2025
d1d10b7
update _loads_partial_vvm to not trample VVMDeployer.name
charles-cooper Jan 2, 2025
c2ed9fa
handle tuples inside lists
charles-cooper Jan 3, 2025
594ed67
lint
charles-cooper Jan 3, 2025
970894f
handle namedtuples with `from` field
charles-cooper Jan 4, 2025
a67d641
fail more gracefully in decode_log when event abi not found
charles-cooper Jan 4, 2025
926c6c3
use namedtuple(rename=True)
charles-cooper Jan 4, 2025
2a3e564
add strict=True param to get_logs
charles-cooper Jan 6, 2025
f91d853
test VVMDeployer does not stomp cache
charles-cooper Jan 9, 2025
60dc4e4
add backwards compatibility
charles-cooper Jan 9, 2025
025c94f
add test for logs
charles-cooper Jan 9, 2025
56842d1
add address field to the log
charles-cooper Jan 9, 2025
a30932c
add tests for namedtuple structs
charles-cooper Jan 9, 2025
04c6113
add out-of-order indexed field
charles-cooper Jan 9, 2025
5273787
add tests for log address in subcall
charles-cooper Jan 9, 2025
417d0f0
fix lint
charles-cooper Jan 9, 2025
ef44c4d
add test for proper BoaError in VVMDeployer.deploy
charles-cooper Jan 9, 2025
06074bf
add tests, forward contract_name properly
charles-cooper Jan 9, 2025
16a8083
add a note
charles-cooper Jan 9, 2025
bf6a75f
fix lint
charles-cooper Jan 9, 2025
fc31b21
update existing tests
charles-cooper Jan 9, 2025
83459ba
update another test
charles-cooper Jan 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 176 additions & 15 deletions boa/contracts/abi/abi_contract.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from collections import defaultdict
from collections import defaultdict, namedtuple
from functools import cached_property
from typing import Any, Optional, Union
from warnings import warn

from eth.abc import ComputationAPI
from vyper.semantics.analysis.base import FunctionVisibility, StateMutability
from vyper.utils import method_id
from vyper.utils import keccak256, method_id

from boa.contracts.base_evm_contract import (
BoaError,
Expand Down Expand Up @@ -47,7 +47,7 @@ def argument_count(self) -> int:

@property
def signature(self) -> str:
return f"({_format_abi_type(self.argument_types)})"
return _format_abi_type(self.argument_types)

@cached_property
def return_type(self) -> list:
Expand Down Expand Up @@ -134,13 +134,25 @@ def __call__(self, *args, value=0, gas=None, sender=None, **kwargs):
contract=self.contract,
)

match self.contract.marshal_to_python(computation, self.return_type):
val = self.contract.marshal_to_python(computation, self.return_type)

# this property should be guaranteed by abi_decode inside marshal_to_python,
# assert it again just for clarity
# note that val should be a tuple.
assert len(self._abi["outputs"]) == len(val)

match val:
case ():
return None
case (single,):
return single
return _parse_complex(self._abi["outputs"][0], single, name=self.name)
case multiple:
return tuple(multiple)
item_abis = self._abi["outputs"]
cls = type(multiple) # should be tuple
return cls(
_parse_complex(abi, item, name=self.name)
for (abi, item) in zip(item_abis, multiple)
)


class ABIOverload:
Expand Down Expand Up @@ -234,16 +246,19 @@ def __init__(
name: str,
abi: list[dict],
functions: list[ABIFunction],
events: list[dict],
address: Address,
filename: Optional[str] = None,
env=None,
nowarn=False,
):
super().__init__(name, env, filename=filename, address=address)
self._abi = abi
self._functions = functions
self._events = events

self._bytecode = self.env.get_code(address)
if not self._bytecode:
if not self._bytecode and not nowarn:
warn(
f"Requested {self} but there is no bytecode at that address!",
stacklevel=2,
Expand Down Expand Up @@ -276,6 +291,96 @@ def method_id_map(self):
if not function.is_constructor
}

@cached_property
def event_for(self):
# [{"name": "Bar", "inputs":
# [{"name": "x", "type": "uint256", "indexed": false},
# {"name": "y", "type": "tuple", "components":
# [{"name": "x", "type": "uint256"}], "indexed": false}],
# "anonymous": false, "type": "event"},
# }]
ret = {}
for event_abi in self._events:
event_signature = ",".join(
_abi_from_json(item) for item in event_abi["inputs"]
)
event_name = event_abi["name"]
event_signature = f"{event_name}({event_signature})"
event_id = int(keccak256(event_signature.encode()).hex(), 16)
ret[event_id] = event_abi
return ret

def decode_log(self, log_entry):
# low level log id
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
_log_id, address, topics, data = log_entry
assert self._address.canonical_address == address
event_hash = topics[0]

if event_hash not in self.event_for:
# our abi is wrong, we can't decode it. fail loudly.
msg = f"can't find event with hash {hex(event_hash)} in abi"
msg += f" (possible events: {self.event_for})"
raise ValueError(msg)

event_abi = self.event_for[event_hash]

topic_abis = []
arg_abis = []

# add `address` to the tuple. this is prevented from being an
# actual fieldname in vyper and solidity since it is a reserved keyword
# in both languages. if for some reason some abi actually has a field
# named `address`, it will be renamed by namedtuple(rename=True).
tuple_names = ["address"]

for item_abi in event_abi["inputs"]:
is_topic = item_abi["indexed"]
assert isinstance(is_topic, bool)
if not is_topic:
arg_abis.append(item_abi)
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
else:
topic_abis.append(item_abi)

tuple_names.append(item_abi["name"])

tuple_typ = namedtuple(event_abi["name"], tuple_names, rename=True)

decoded_topics = []
for topic_abi, t in zip(topic_abis, topics[1:]):
# convert to bytes for abi decoder
encoded_topic = t.to_bytes(32, "big")
decoded_topics.append(abi_decode(_abi_from_json(topic_abi), encoded_topic))

args_selector = _format_abi_type(
[_abi_from_json(arg_abi) for arg_abi in arg_abis]
)

decoded_args = abi_decode(args_selector, data)

topics_ix = 0
args_ix = 0

xs = [Address(address)]

# re-align the evm topic + args lists with the way they appear in the
# abi ex. Transfer(indexed address, address, indexed address)
for item_abi in event_abi["inputs"]:
is_topic = item_abi["indexed"]
if is_topic:
abi = topic_abis[topics_ix]
topic = decoded_topics[topics_ix]
# topic abi is currently never complex, but use _parse_complex
# as future-proofing mechanism
xs.append(_parse_complex(abi, topic))
topics_ix += 1
else:
abi = arg_abis[args_ix]
arg = decoded_args[args_ix]
xs.append(_parse_complex(abi, arg))
args_ix += 1

return tuple_typ(*xs)

def marshal_to_python(self, computation, abi_type: list[str]) -> tuple[Any, ...]:
"""
Convert the output of a contract call to a Python object.
Expand All @@ -286,7 +391,7 @@ def marshal_to_python(self, computation, abi_type: list[str]) -> tuple[Any, ...]
if computation.is_error:
return self.handle_error(computation)

schema = f"({_format_abi_type(abi_type)})"
schema = _format_abi_type(abi_type)
try:
return abi_decode(schema, computation.output)
except ABIError as e:
Expand Down Expand Up @@ -360,17 +465,27 @@ def functions(self):
if item.get("type") == "function"
]

@property
def events(self):
return [item for item in self.abi if item.get("type") == "event"]

@classmethod
def from_abi_dict(cls, abi, name="<anonymous contract>", filename=None):
return cls(name, abi, filename)

def at(self, address: Address | str) -> ABIContract:
def at(self, address: Address | str, nowarn=False) -> ABIContract:
"""
Create an ABI contract object for a deployed contract at `address`.
"""
address = Address(address)
contract = ABIContract(
self._name, self._abi, self.functions, address, self.filename
self._name,
self._abi,
self.functions,
self.events,
address,
self.filename,
nowarn=nowarn,
)

contract.env.register_contract(address, contract)
Expand All @@ -390,15 +505,15 @@ def __repr__(self):

@cached_property
def args_abi_type(self):
return f"({_format_abi_type(self.function.argument_types)})"
return _format_abi_type(self.function.argument_types)

@cached_property
def _argument_names(self) -> list[str]:
return [arg["name"] for arg in self.function._abi["inputs"]]

@cached_property
def return_abi_type(self):
return f"({_format_abi_type(self.function.return_type)})"
return _format_abi_type(self.function.return_type)


def _abi_from_json(abi: dict) -> str:
Expand All @@ -407,6 +522,12 @@ def _abi_from_json(abi: dict) -> str:
:param abi: The ABI type to parse.
:return: The schema string for the given abi type.
"""
# {"stateMutability": "view", "type": "function", "name": "foo",
# "inputs": [],
# "outputs": [{"name": "", "type": "tuple",
# "components": [{"name": "x", "type": "uint256"}]}]
# }

if "components" in abi:
components = ",".join([_abi_from_json(item) for item in abi["components"]])
if abi["type"].startswith("tuple"):
Expand All @@ -416,11 +537,51 @@ def _abi_from_json(abi: dict) -> str:
return abi["type"]


def _parse_complex(abi: dict, value: Any, name=None) -> str:
"""
Parses an ABI type into its schema string.
:param abi: The ABI type to parse.
:return: The schema string for the given abi type.
"""
# simple case
if "components" not in abi:
return value

# https://docs.soliditylang.org/en/latest/abi-spec.html#handling-tuple-types
type_ = abi["type"]
assert type_.startswith("tuple")
# number of nested arrays (we don't care if dynamic or static)
depth = type_.count("[")

# complex case
# construct a namedtuple type on the fly
components = abi["components"]
typname = name or abi["name"] or "user_struct"
component_names = [item["name"] for item in components]

typ = namedtuple(typname, component_names, rename=True) # type: ignore[misc]

def _leaf(tuple_vals):
components_parsed = [
_parse_complex(item_abi, item)
for (item_abi, item) in zip(components, tuple_vals)
]

return typ(*components_parsed)

def _go(val, depth):
if depth == 0:
return _leaf(val)
return [_go(val, depth - 1) for val in val]

return _go(value, depth)


def _format_abi_type(types: list) -> str:
"""
Converts a list of ABI types into a comma-separated string.
"""
return ",".join(
item if isinstance(item, str) else f"({_format_abi_type(item)})"
for item in types
ret = ",".join(
item if isinstance(item, str) else _format_abi_type(item) for item in types
)
return f"({ret})"
45 changes: 44 additions & 1 deletion boa/contracts/base_evm_contract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional

from eth.abc import ComputationAPI

Expand All @@ -13,6 +13,11 @@
from boa.vm.py_evm import titanoboa_computation


@dataclass
class RawEvent:
event_data: Any


class _BaseEVMContract:
"""
Base class for EVM (Ethereum Virtual Machine) contract:
Expand Down Expand Up @@ -57,6 +62,44 @@ def address(self) -> Address:
raise RuntimeError("Contract address is not set")
return self._address

# ## handling events
def _get_logs(self, computation, include_child_logs):
if computation is None:
return []

if include_child_logs:
return list(computation.get_raw_log_entries())

return computation._log_entries

def get_logs(self, computation=None, include_child_logs=True, strict=True):
if computation is None:
computation = self._computation

entries = self._get_logs(computation, include_child_logs)

# py-evm log format is (log_id, topics, data)
# sort on log_id
entries = sorted(entries)

ret = []
for e in entries:
logger_address = e[1]
c = self.env.lookup_contract(logger_address)
if c is not None:
try:
decoded_log = c.decode_log(e)
except Exception as exc:
if strict:
raise exc
else:
decoded_log = RawEvent(e)
else:
decoded_log = RawEvent(e)
ret.append(decoded_log)

return ret


class StackTrace(list): # list[str|ErrorDetail]
def __str__(self):
Expand Down
Loading
Loading