Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add decode_calldata to ABI functions #317

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions boa/contracts/abi/abi_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from warnings import warn

from eth.abc import ComputationAPI
from sortedcontainers import SortedDict
from vyper.semantics.analysis.base import FunctionVisibility, StateMutability
from vyper.utils import method_id

Expand Down Expand Up @@ -105,6 +106,16 @@ def prepare_calldata(self, *args, **kwargs) -> bytes:
return encoded_args
return self.method_id + encoded_args

def decode_calldata(self, calldata: bytes) -> tuple:
"""Decode the calldata for the function call."""
calldata_method_id = calldata[:4]
if calldata_method_id != self.method_id:
raise ValueError(
f"The calldata 0x{calldata_method_id.hex()} does not match "
f"the method_id 0x{self.method_id.hex()}"
)
return abi_decode(self.signature, calldata[4:])

def _merge_kwargs(self, *args, **kwargs) -> list:
"""Merge positional and keyword arguments into a single list."""
if len(kwargs) + len(args) != self.argument_count:
Expand Down Expand Up @@ -166,11 +177,15 @@ def create(
return ABIOverload(functions)

def __init__(self, functions: list[ABIFunction]):
self.functions = functions
self.functions = SortedDict([(f.method_id, f) for f in functions])

@cached_property
def name(self) -> str | None:
return self.functions[0].name
"""
Gets the name of the overloaded function.
Note that all overloads have the same name by definition.
"""
return next(iter(self.functions.values())).name

def prepare_calldata(self, *args, disambiguate_signature=None, **kwargs) -> bytes:
"""Prepare the calldata for the function that matches the given arguments."""
Expand All @@ -179,6 +194,12 @@ def prepare_calldata(self, *args, disambiguate_signature=None, **kwargs) -> byte
)
return function.prepare_calldata(*args, **kwargs)

def decode_calldata(self, calldata: bytes) -> tuple:
"""Decode the calldata for the function that matches the given arguments."""
calldata_method_id = calldata[:4]
function = self.functions[calldata_method_id]
return function.decode_calldata(calldata)

def __call__(
self,
*args,
Expand All @@ -201,12 +222,12 @@ def _pick_overload(
self, *args, disambiguate_signature=None, **kwargs
) -> ABIFunction:
"""Pick the function that matches the given arguments."""
fns = self.functions.values()

if disambiguate_signature is None:
matches = [f for f in self.functions if f.is_encodable(*args, **kwargs)]
matches = [f for f in fns if f.is_encodable(*args, **kwargs)]
else:
matches = [
f for f in self.functions if disambiguate_signature == f.full_signature
]
matches = [f for f in fns if disambiguate_signature == f.full_signature]
assert len(matches) <= 1, "ABI signature must be unique"

assert self.name, "Constructor does not have a name."
Expand Down
9 changes: 5 additions & 4 deletions tests/unitary/contracts/abi/test_abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test(n: uint256) -> uint256:
assert re.match(r"^ +\(unknown method id .*\.0x29e99f07\)$", error)


def test_prepare_calldata():
def test_calldata():
code = """
@external
def overloaded(n: uint256 = 0) -> uint256:
Expand All @@ -215,9 +215,10 @@ def argumented(n: uint256) -> uint256:
"""
abi_contract, _ = load_via_abi(code)
assert abi_contract.overloaded.prepare_calldata() == b"\x07\x8e\xec\xb4"
assert (
abi_contract.argumented.prepare_calldata(0) == b"\xedu\x96\x8d" + b"\x00" * 32
)
assert abi_contract.overloaded.decode_calldata(b"\x07\x8e\xec\xb4") == ()
argumented_calldata = b"\xedu\x96\x8d" + b"\x00" * 32
assert abi_contract.argumented.prepare_calldata(0) == argumented_calldata
assert abi_contract.argumented.decode_calldata(argumented_calldata) == (0,)
assert len(abi_contract.abi) == 3
assert abi_contract.deployer.abi == abi_contract.abi

Expand Down
Loading