Skip to content

Commit

Permalink
Bytecode as bytes + hashmap internal getter
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielSchiavini committed May 13, 2024
1 parent 5b45fe1 commit 7a4f7fa
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 13 deletions.
20 changes: 19 additions & 1 deletion boa_zksync/compile.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import subprocess
from os import path
from pathlib import Path
from shutil import which
from tempfile import TemporaryDirectory

Expand Down Expand Up @@ -35,14 +37,30 @@ def compile_zksync(
if source_code is None:
with open(filename) as file:
source_code = file.read()

kwargs = output[filename.removeprefix("./")]
bytecode = bytes.fromhex(kwargs.pop("bytecode").removeprefix("0x"))
return ZksyncCompilerData(
contract_name, source_code, compiler_args, **output[filename]
contract_name, source_code, compiler_args, bytecode, **kwargs
)


def compile_zksync_source(
source_code: str, name: str, compiler_args=None
) -> ZksyncCompilerData:
"""
Compile a contract from source code.
:param source_code: The source code of the contract.
:param name: The (file)name of the contract. If this is a file name, the
contract name will be the file name without the extension.
:param compiler_args: Extra arguments to pass to the compiler.
:return: The compiled contract.
"""
if path.exists(name):
# We need to accept filenames because of the way `boa.load` works
contract_name = Path(name).stem
return compile_zksync(contract_name, name, compiler_args, source_code)

with TemporaryDirectory() as tempdir:
filename = f"{tempdir}/{name}.vy"
with open(filename, "w") as file:
Expand Down
30 changes: 21 additions & 9 deletions boa_zksync/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ class _ZksyncInternal(ABIFunction):
"""

@cached_property
def _override_bytecode(self):
def _override_bytecode(self) -> bytes:
data = self.contract.compiler_data
source = "\n".join((data.source_code, self.source_code))
compiled = compile_zksync_source(source, self.name, data.compiler_args)
return to_bytes(compiled.bytecode)
return compiled.bytecode

@property
def source_code(self):
Expand All @@ -80,7 +80,7 @@ def __call__(self, *args, **kwargs):
try:
return super().__call__(*args, **kwargs)
finally:
env.set_code(self.contract.address, to_bytes(self.contract._bytecode))
env.set_code(self.contract.address, self.contract.compiler_data.bytecode)


class ZksyncInternalFunction(_ZksyncInternal):
Expand All @@ -96,6 +96,7 @@ def __init__(self, fn: ContractFunctionT, contract: ZksyncContract):
if fn.return_type
else []
),
"stateMutability": fn.mutability.value,
"name": f"__boa_private_{fn.name}__",
"type": "function",
}
Expand All @@ -110,29 +111,40 @@ def source_code(self):

class ZksyncInternalVariable(_ZksyncInternal):
def __init__(self, var: VarInfo, name: str, contract: ZksyncContract):
inputs, output = var.typ.getter_signature
abi = {
"anonymous": False,
"inputs": [],
"outputs": [{"name": name, "type": var.typ.abi_type.selector_name()}],
"inputs": [
{"name": f"arg{index}", "type": arg.abi_type.selector_name()}
for index, arg in enumerate(inputs)
],
"outputs": [{"name": name, "type": output.abi_type.selector_name()}],
"name": f"__boa_private_{name}__",
"constant": True,
"type": "function",
}
super().__init__(abi, contract._name)
self.contract = contract
self.var = var
self.var_name = name

def get(self):
return self.__call__()
def get(self, *args):
return self.__call__(*args)

@cached_property
def source_code(self):
args, arg_getter = "", ""
inputs, output = self.var.typ.getter_signature
if inputs:
arg_getter = "".join([f"[arg{i}]" for i in range(len(inputs))])
args = ", ".join([f"arg{i}: {arg.abi_type.selector_name()}" for i, arg in enumerate(inputs)])

return textwrap.dedent(
f"""
@external
@payable
def __boa_private_{self.var_name}__() -> {self.var.typ.abi_type.selector_name()}:
return self.{self.var_name}
def __boa_private_{self.var_name}__({args}) -> {output.abi_type.selector_name()}:
return self.{self.var_name}{arg_getter}
"""
)

Expand Down
2 changes: 1 addition & 1 deletion boa_zksync/deployer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def deploy(self, *args, value=0, **kwargs) -> ZksyncContract:
), "ZksyncDeployer can only be used in zkSync environments"

address, _ = env.deploy_code(
bytecode=to_bytes(self.compiler_data.bytecode),
bytecode=self.compiler_data.bytecode,
value=value,
constructor_calldata=(
self.constructor.prepare_calldata(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion boa_zksync/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def create(self):

def _reset_fork(self, block_identifier="latest"):
if isinstance(self._rpc, EraTestNode) and (inner_rpc := self._rpc.inner_rpc):
del self._rpc
del self._rpc # close the old rpc
self._rpc = inner_rpc

def fork(
Expand Down
2 changes: 1 addition & 1 deletion boa_zksync/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ class ZksyncCompilerData:
contract_name: str
source_code: str
compiler_args: list[str]
bytecode: bytes
method_identifiers: dict
abi: list[dict]
bytecode: str
bytecode_runtime: str
warnings: list[str]
factory_deps: list[str]
Expand Down
2 changes: 2 additions & 0 deletions tests/test_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def get_name_of(addr: HasName) -> String[32]:
def test_private(zksync_env):
code = """
bar: uint256
map: HashMap[uint256, uint256]
@internal
def foo(x: uint256) -> uint256:
Expand All @@ -177,6 +178,7 @@ def foo(x: uint256) -> uint256:
"""
contract = boa.loads(code)
assert contract._storage.bar.get() == 0
assert contract._storage.map.get(0) == 0
assert contract.internal.foo(123) == 123
assert contract._storage.bar.get() == 123
assert contract.eval("self.bar = 456") is None
Expand Down

0 comments on commit 7a4f7fa

Please sign in to comment.