From 3af5390001e59ba767378047add0df5e26193d9f Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 7 May 2024 16:22:20 -0400 Subject: [PATCH] refactor[test]: change fixture scope in examples (#3995) roughly 5x performance increase per CPU in the `tests/functional/examples/` directory (testing locally: 27s -> 7s) --- tests/evm_backends/pyevm_env.py | 5 +++- tests/evm_backends/revm_env.py | 3 +++ .../examples/auctions/test_blind_auction.py | 2 +- .../auctions/test_simple_open_auction.py | 4 ++-- .../examples/company/test_company.py | 2 +- .../crowdfund/test_crowdfund_example.py | 2 +- .../examples/factory/test_factory.py | 6 ++--- .../test_on_chain_market_maker.py | 17 +++++++------- .../test_safe_remote_purchase.py | 4 ++-- .../examples/storage/test_advanced_storage.py | 2 +- .../examples/storage/test_storage.py | 2 +- .../examples/tokens/test_erc1155.py | 2 +- .../functional/examples/tokens/test_erc20.py | 4 ++-- .../examples/tokens/test_erc4626.py | 4 ++-- .../functional/examples/tokens/test_erc721.py | 2 +- .../functional/examples/voting/test_ballot.py | 2 +- .../functional/examples/wallet/test_wallet.py | 23 ++++++++----------- 17 files changed, 44 insertions(+), 42 deletions(-) diff --git a/tests/evm_backends/pyevm_env.py b/tests/evm_backends/pyevm_env.py index 6638308ff9..6c510278a7 100644 --- a/tests/evm_backends/pyevm_env.py +++ b/tests/evm_backends/pyevm_env.py @@ -1,3 +1,4 @@ +import copy import logging from contextlib import contextmanager from typing import Optional @@ -65,7 +66,7 @@ def _state(self) -> StateAPI: def _vm(self) -> VirtualMachineAPI: return self._chain.get_vm() - @cached_property + @property def _context(self) -> ExecutionContext: context = self._state.execution_context assert isinstance(context, ExecutionContext) # help mypy @@ -74,10 +75,12 @@ def _context(self) -> ExecutionContext: @contextmanager def anchor(self): snapshot_id = self._state.snapshot() + ctx = copy.copy(self._state.execution_context) try: yield finally: self._state.revert(snapshot_id) + self._state.execution_context = ctx def get_balance(self, address: str) -> int: return self._state.get_balance(_addr(address)) diff --git a/tests/evm_backends/revm_env.py b/tests/evm_backends/revm_env.py index c23a74e158..5c8b8aba08 100644 --- a/tests/evm_backends/revm_env.py +++ b/tests/evm_backends/revm_env.py @@ -31,6 +31,7 @@ def __init__( @contextmanager def anchor(self): snapshot_id = self._evm.snapshot() + block = BlockEnv(number=self._evm.env.block.number, timestamp=self._evm.env.block.timestamp) try: yield finally: @@ -40,6 +41,8 @@ def anchor(self): # snapshot_id is reverted by the transaction already. # revm updates are needed to make the journal more robust. pass + self._evm.set_block_env(block) + # self._evm.set_tx_env(tx) def get_balance(self, address: str) -> int: return self._evm.get_balance(address) diff --git a/tests/functional/examples/auctions/test_blind_auction.py b/tests/functional/examples/auctions/test_blind_auction.py index 06f0656f1d..eda84e1217 100644 --- a/tests/functional/examples/auctions/test_blind_auction.py +++ b/tests/functional/examples/auctions/test_blind_auction.py @@ -9,7 +9,7 @@ TEST_INCREMENT = 1 -@pytest.fixture +@pytest.fixture(scope="module") def auction_contract(env, get_contract): with open("examples/auctions/blind_auction.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/auctions/test_simple_open_auction.py b/tests/functional/examples/auctions/test_simple_open_auction.py index 430294fa79..68b208a9b8 100644 --- a/tests/functional/examples/auctions/test_simple_open_auction.py +++ b/tests/functional/examples/auctions/test_simple_open_auction.py @@ -5,12 +5,12 @@ EXPIRY = 16 -@pytest.fixture +@pytest.fixture(scope="module") def auction_start(env): return env.timestamp + 1 -@pytest.fixture +@pytest.fixture(scope="module") def auction_contract(env, get_contract, auction_start): with open("examples/auctions/simple_open_auction.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/company/test_company.py b/tests/functional/examples/company/test_company.py index 35b4951471..e302735d7c 100644 --- a/tests/functional/examples/company/test_company.py +++ b/tests/functional/examples/company/test_company.py @@ -1,7 +1,7 @@ import pytest -@pytest.fixture +@pytest.fixture(scope="module") def c(env, get_contract): with open("examples/stock/company.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/crowdfund/test_crowdfund_example.py b/tests/functional/examples/crowdfund/test_crowdfund_example.py index ff0d85d61e..510dd80c82 100644 --- a/tests/functional/examples/crowdfund/test_crowdfund_example.py +++ b/tests/functional/examples/crowdfund/test_crowdfund_example.py @@ -1,7 +1,7 @@ import pytest -@pytest.fixture +@pytest.fixture(scope="module") def c(env, get_contract): with open("examples/crowdfund.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/factory/test_factory.py b/tests/functional/examples/factory/test_factory.py index ecfc0bf557..5964d70478 100644 --- a/tests/functional/examples/factory/test_factory.py +++ b/tests/functional/examples/factory/test_factory.py @@ -4,7 +4,7 @@ import vyper -@pytest.fixture +@pytest.fixture(scope="module") def create_token(get_contract): with open("examples/tokens/ERC20.vy") as f: code = f.read() @@ -15,7 +15,7 @@ def create_token(): return create_token -@pytest.fixture +@pytest.fixture(scope="module") def create_exchange(env, get_contract): with open("examples/factory/Exchange.vy") as f: code = f.read() @@ -29,7 +29,7 @@ def create_exchange(token, factory): return create_exchange -@pytest.fixture +@pytest.fixture(scope="module") def factory(get_contract): with open("examples/factory/Exchange.vy") as f: code = f.read() diff --git a/tests/functional/examples/market_maker/test_on_chain_market_maker.py b/tests/functional/examples/market_maker/test_on_chain_market_maker.py index 071afce5d6..9dddc37ceb 100644 --- a/tests/functional/examples/market_maker/test_on_chain_market_maker.py +++ b/tests/functional/examples/market_maker/test_on_chain_market_maker.py @@ -3,14 +3,6 @@ from tests.utils import ZERO_ADDRESS - -@pytest.fixture -def market_maker(get_contract): - with open("examples/market_maker/on_chain_market_maker.vy") as f: - contract_code = f.read() - return get_contract(contract_code) - - TOKEN_NAME = "Vypercoin" TOKEN_SYMBOL = "FANG" TOKEN_DECIMALS = 18 @@ -18,7 +10,7 @@ def market_maker(get_contract): TOKEN_TOTAL_SUPPLY = TOKEN_INITIAL_SUPPLY * (10**TOKEN_DECIMALS) -@pytest.fixture +@pytest.fixture(scope="module") def erc20(get_contract): with open("examples/tokens/ERC20.vy") as f: contract_code = f.read() @@ -27,6 +19,13 @@ def erc20(get_contract): ) +@pytest.fixture(scope="module") +def market_maker(get_contract, erc20): + with open("examples/market_maker/on_chain_market_maker.vy") as f: + contract_code = f.read() + return get_contract(contract_code) + + def test_initial_state(market_maker): assert market_maker.totalEthQty() == 0 assert market_maker.totalTokenQty() == 0 diff --git a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py index bb89375530..c4cfdc29eb 100644 --- a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py +++ b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py @@ -15,14 +15,14 @@ from eth_utils import to_wei -@pytest.fixture +@pytest.fixture(scope="module") def contract_code(get_contract): with open("examples/safe_remote_purchase/safe_remote_purchase.vy") as f: contract_code = f.read() return contract_code -@pytest.fixture +@pytest.fixture(scope="module") def get_balance(env): def get_balance(): a0, a1 = env.accounts[:2] diff --git a/tests/functional/examples/storage/test_advanced_storage.py b/tests/functional/examples/storage/test_advanced_storage.py index 51e5a1729e..4a41cb415c 100644 --- a/tests/functional/examples/storage/test_advanced_storage.py +++ b/tests/functional/examples/storage/test_advanced_storage.py @@ -4,7 +4,7 @@ INITIAL_VALUE = 4 -@pytest.fixture +@pytest.fixture(scope="module") def adv_storage_contract(get_contract): with open("examples/storage/advanced_storage.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/storage/test_storage.py b/tests/functional/examples/storage/test_storage.py index cdb71c5810..631bdc4dbe 100644 --- a/tests/functional/examples/storage/test_storage.py +++ b/tests/functional/examples/storage/test_storage.py @@ -3,7 +3,7 @@ INITIAL_VALUE = 4 -@pytest.fixture +@pytest.fixture(scope="module") def storage_contract(get_contract): with open("examples/storage/storage.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/tokens/test_erc1155.py b/tests/functional/examples/tokens/test_erc1155.py index afbfa8d56d..0a51c115bb 100644 --- a/tests/functional/examples/tokens/test_erc1155.py +++ b/tests/functional/examples/tokens/test_erc1155.py @@ -29,7 +29,7 @@ mintConflictBatch = [1, 2, 3] -@pytest.fixture +@pytest.fixture(scope="module") def erc1155(get_contract, env, tx_failed): owner, a1, a2, a3, a4, a5 = env.accounts[0:6] with open("examples/tokens/ERC1155ownable.vy") as f: diff --git a/tests/functional/examples/tokens/test_erc20.py b/tests/functional/examples/tokens/test_erc20.py index aef43768cb..b3dc2fe238 100644 --- a/tests/functional/examples/tokens/test_erc20.py +++ b/tests/functional/examples/tokens/test_erc20.py @@ -13,14 +13,14 @@ TOKEN_INITIAL_SUPPLY = 0 -@pytest.fixture +@pytest.fixture(scope="module") def c(get_contract): with open("examples/tokens/ERC20.vy") as f: code = f.read() return get_contract(code, *[TOKEN_NAME, TOKEN_SYMBOL, TOKEN_DECIMALS, TOKEN_INITIAL_SUPPLY]) -@pytest.fixture +@pytest.fixture(scope="module") def c_bad(get_contract): # Bad contract is used for overflow checks on totalSupply corrupted with open("examples/tokens/ERC20.vy") as f: diff --git a/tests/functional/examples/tokens/test_erc4626.py b/tests/functional/examples/tokens/test_erc4626.py index f0fb79efae..f6ff71f51a 100644 --- a/tests/functional/examples/tokens/test_erc4626.py +++ b/tests/functional/examples/tokens/test_erc4626.py @@ -7,7 +7,7 @@ TOKEN_INITIAL_SUPPLY = 0 -@pytest.fixture +@pytest.fixture(scope="module") def token(get_contract): with open("examples/tokens/ERC20.vy") as f: return get_contract( @@ -15,7 +15,7 @@ def token(get_contract): ) -@pytest.fixture +@pytest.fixture(scope="module") def vault(get_contract, token): with open("examples/tokens/ERC4626.vy") as f: return get_contract(f.read(), token.address) diff --git a/tests/functional/examples/tokens/test_erc721.py b/tests/functional/examples/tokens/test_erc721.py index 1ed26f64dc..3c1c5e71f9 100644 --- a/tests/functional/examples/tokens/test_erc721.py +++ b/tests/functional/examples/tokens/test_erc721.py @@ -11,7 +11,7 @@ ERC721_SIG = "0x80ac58cd" -@pytest.fixture +@pytest.fixture(scope="module") def c(get_contract, env): with open("examples/tokens/ERC721.vy") as f: code = f.read() diff --git a/tests/functional/examples/voting/test_ballot.py b/tests/functional/examples/voting/test_ballot.py index 2135feff72..9c82c5156b 100644 --- a/tests/functional/examples/voting/test_ballot.py +++ b/tests/functional/examples/voting/test_ballot.py @@ -6,7 +6,7 @@ PROPOSAL_2_NAME = b"Trump" + b"\x00" * 27 -@pytest.fixture +@pytest.fixture(scope="module") def c(get_contract): with open("examples/voting/ballot.vy") as f: contract_code = f.read() diff --git a/tests/functional/examples/wallet/test_wallet.py b/tests/functional/examples/wallet/test_wallet.py index c639974a31..6dfb838d8a 100644 --- a/tests/functional/examples/wallet/test_wallet.py +++ b/tests/functional/examples/wallet/test_wallet.py @@ -5,9 +5,10 @@ from eth_utils import is_same_address, to_bytes, to_checksum_address, to_int from tests.utils import ZERO_ADDRESS +from vyper.utils import keccak256 -@pytest.fixture +@pytest.fixture(scope="module") def c(env, get_contract): a0, a1, a2, a3, a4, a5, a6 = env.accounts[:7] with open("examples/wallet/wallet.vy") as f: @@ -19,20 +20,16 @@ def c(env, get_contract): return c -@pytest.fixture -def sign(keccak): - def _sign(seq, to, value, data, key): - keys = KeyAPI() - comb = seq.to_bytes(32, "big") + b"\x00" * 12 + to + value.to_bytes(32, "big") + data - h1 = keccak(comb) - h2 = keccak(b"\x19Ethereum Signed Message:\n32" + h1) - sig = keys.ecdsa_sign(h2, key) - return [28 if sig.v == 1 else 27, sig.r, sig.s] +def sign(seq, to, value, data, key): + keys = KeyAPI() + comb = seq.to_bytes(32, "big") + b"\x00" * 12 + to + value.to_bytes(32, "big") + data + h1 = keccak256(comb) + h2 = keccak256(b"\x19Ethereum Signed Message:\n32" + h1) + sig = keys.ecdsa_sign(h2, key) + return [28 if sig.v == 1 else 27, sig.r, sig.s] - return _sign - -def test_approve(env, c, tx_failed, sign): +def test_approve(env, c, tx_failed): a0, a1, a2, a3, a4, a5, a6 = env.accounts[:7] k0, k1, k2, k3, k4, k5, k6, k7 = env._keys[:8] env.set_balance(a1, 10**18)