diff --git a/.env.sample b/.env.sample index dd89db4..3def4ce 100644 --- a/.env.sample +++ b/.env.sample @@ -1,12 +1,19 @@ # .env.sample -# URLs for DB connection -ETHEREUM_DB_URL= -GNOSIS_DB_URL= +# DB connection +DB_URL= -# URLs for Node provider connection -ETHEREUM_NODE_URL= -GNOSIS_NODE_URL= +# Node provider connection +NODE_URL= + +# connecting to Solver Slippage DB +SOLVER_SLIPPAGE_DB_URL= + +# configure chain sleep time, e.g. CHAIN_SLEEP_TIME=60 +CHAIN_SLEEP_TIME= + +# add chain name, e.g. CHAIN_NAME=Ethereum +CHAIN_NAME= # optional -INFURA_KEY=infura_key_here +INFURA_KEY= diff --git a/contracts/erc20_abi.py b/contracts/erc20_abi.py index 03a62f2..9a0c632 100644 --- a/contracts/erc20_abi.py +++ b/contracts/erc20_abi.py @@ -1,3 +1,6 @@ +""" +ERC20 ABI contract +""" erc20_abi = [ { "constant": True, diff --git a/requirements.txt b/requirements.txt index 247a178..ebf26cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,6 @@ black==23.3.0 mypy==1.4.1 pylint==3.2.5 pytest==7.4.0 -setuptools \ No newline at end of file +setuptools +pandas-stubs +types-psycopg2 diff --git a/src/balanceof_imbalances.py b/src/balanceof_imbalances.py index f317f8d..c78f965 100644 --- a/src/balanceof_imbalances.py +++ b/src/balanceof_imbalances.py @@ -1,15 +1,8 @@ -# mypy: disable-error-code="call-overload, arg-type, operator" -import sys -import os - -# for debugging purposes -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - from web3 import Web3 -from web3.types import TxReceipt +from web3.types import TxReceipt, HexStr from eth_typing import ChecksumAddress from typing import Dict, Optional, Set -from src.config import ETHEREUM_NODE_URL +from src.config import NODE_URL from src.constants import SETTLEMENT_CONTRACT_ADDRESS, NATIVE_ETH_TOKEN_ADDRESS from contracts.erc20_abi import erc20_abi @@ -17,11 +10,14 @@ class BalanceOfImbalances: - def __init__(self, ETHEREUM_NODE_URL: str): - self.web3 = Web3(Web3.HTTPProvider(ETHEREUM_NODE_URL)) + def __init__(self, NODE_URL: str): + self.web3 = Web3(Web3.HTTPProvider(NODE_URL)) def get_token_balance( - self, token_address: str, account: str, block_identifier: int + self, + token_address: ChecksumAddress, + account: ChecksumAddress, + block_identifier: int, ) -> Optional[int]: """Retrieve the ERC-20 token balance of an account at a given block.""" token_contract = self.web3.eth.contract(address=token_address, abi=erc20_abi) @@ -33,7 +29,9 @@ def get_token_balance( print(f"Error fetching balance for token {token_address}: {e}") return None - def get_eth_balance(self, account: str, block_identifier: int) -> Optional[int]: + def get_eth_balance( + self, account: ChecksumAddress, block_identifier: int + ) -> Optional[int]: """Get the ETH balance for a given account and block number.""" try: return self.web3.eth.get_balance(account, block_identifier=block_identifier) @@ -41,9 +39,9 @@ def get_eth_balance(self, account: str, block_identifier: int) -> Optional[int]: print(f"Error fetching ETH balance: {e}") return None - def extract_token_addresses(self, tx_receipt: Dict) -> Set[str]: + def extract_token_addresses(self, tx_receipt: TxReceipt) -> Set[ChecksumAddress]: """Extract unique token addresses from 'Transfer' events in a transaction receipt.""" - token_addresses = set() + token_addresses: Set[ChecksumAddress] = set() transfer_topics = { self.web3.keccak(text="Transfer(address,address,uint256)").hex(), self.web3.keccak(text="ERC20Transfer(address,address,uint256)").hex(), @@ -51,10 +49,10 @@ def extract_token_addresses(self, tx_receipt: Dict) -> Set[str]: } for log in tx_receipt["logs"]: if log["topics"][0].hex() in transfer_topics: - token_addresses.add(log["address"]) + token_addresses.add(self.web3.to_checksum_address(log["address"])) return token_addresses - def get_transaction_receipt(self, tx_hash: str) -> Optional[TxReceipt]: + def get_transaction_receipt(self, tx_hash: HexStr) -> Optional[TxReceipt]: """Fetch the transaction receipt for the given hash.""" try: return self.web3.eth.get_transaction_receipt(tx_hash) @@ -66,35 +64,45 @@ def get_balances( self, token_addresses: Set[ChecksumAddress], block_number: int ) -> Dict[ChecksumAddress, Optional[int]]: """Get balances for all tokens at the given block number.""" - balances = {} - balances[NATIVE_ETH_TOKEN_ADDRESS] = self.get_eth_balance( - SETTLEMENT_CONTRACT_ADDRESS, block_number + balances: Dict[ChecksumAddress, Optional[int]] = {} + balances[ + self.web3.to_checksum_address(NATIVE_ETH_TOKEN_ADDRESS) + ] = self.get_eth_balance( + self.web3.to_checksum_address(SETTLEMENT_CONTRACT_ADDRESS), block_number ) for token_address in token_addresses: balances[token_address] = self.get_token_balance( - token_address, SETTLEMENT_CONTRACT_ADDRESS, block_number + token_address, + self.web3.to_checksum_address(SETTLEMENT_CONTRACT_ADDRESS), + block_number, ) return balances def calculate_imbalances( self, - prev_balances: Dict[str, Optional[int]], - final_balances: Dict[str, Optional[int]], - ) -> Dict[str, int]: + prev_balances: Dict[ChecksumAddress, Optional[int]], + final_balances: Dict[ChecksumAddress, Optional[int]], + ) -> Dict[ChecksumAddress, int]: """Calculate imbalances between previous and final balances.""" - imbalances = {} + imbalances: Dict[ChecksumAddress, int] = {} for token_address in prev_balances: if ( prev_balances[token_address] is not None and final_balances[token_address] is not None ): - imbalance = final_balances[token_address] - prev_balances[token_address] + # need to ensure prev_balance and final_balance contain values + # to prevent subtraction from None + prev_balance = prev_balances[token_address] + assert prev_balance is not None + final_balance = final_balances[token_address] + assert final_balance is not None + imbalance = final_balance - prev_balance imbalances[token_address] = imbalance return imbalances - def compute_imbalances(self, tx_hash: str) -> Dict[str, int]: + def compute_imbalances(self, tx_hash: HexStr) -> Dict[ChecksumAddress, int]: """Compute token imbalances before and after a transaction.""" tx_receipt = self.get_transaction_receipt(tx_hash) if tx_receipt is None: @@ -116,7 +124,7 @@ def compute_imbalances(self, tx_hash: str) -> Dict[str, int]: def main(): tx_hash = input("Enter transaction hash: ") - bo = BalanceOfImbalances(ETHEREUM_NODE_URL) + bo = BalanceOfImbalances(NODE_URL) imbalances = bo.compute_imbalances(tx_hash) print("Token Imbalances:") for token_address, imbalance in imbalances.items(): diff --git a/src/config.py b/src/config.py index 3d2e71b..4cdf822 100644 --- a/src/config.py +++ b/src/config.py @@ -1,10 +1,75 @@ import os +from typing import Optional +from sqlalchemy import text +from sqlalchemy.exc import OperationalError +from sqlalchemy import create_engine, Engine from dotenv import load_dotenv +from src.helper_functions import get_logger + load_dotenv() -ETHEREUM_NODE_URL = os.getenv("ETHEREUM_NODE_URL") -GNOSIS_NODE_URL = os.getenv("GNOSIS_NODE_URL") +NODE_URL = os.getenv("NODE_URL") + +logger = get_logger("raw_token_imbalances") + +# Utilized by imbalances_script for computing for single tx hash +CHAIN_RPC_ENDPOINTS = { + "Ethereum": os.getenv("ETHEREUM_NODE_URL"), + "Gnosis": os.getenv("GNOSIS_NODE_URL"), +} + + +def get_env_int(var_name: str) -> int: + """ + Function for safe conversion to int (prevents None -> int conversion issues raised by mypy) + Retrieve environment variable and convert to int. Raise an error if not set. + """ + value = os.getenv(var_name) + if value is None: + raise ValueError(f"Environment variable {var_name} is not set.") + try: + return int(value) + except ValueError: + raise ValueError(f"Environment variable {var_name} must be a int.") + + +CHAIN_SLEEP_TIME = get_env_int("CHAIN_SLEEP_TIME") + + +def create_backend_db_connection(chain_name: str) -> Engine: + """function that creates a connection to the CoW db.""" + read_db_url = os.getenv("DB_URL") + + if not read_db_url: + raise ValueError(f"No database URL found for chain: {chain_name}") + + return create_engine(f"postgresql+psycopg2://{read_db_url}") + + +def create_solver_slippage_db_connection() -> Engine: + """function that creates a connection to the CoW db.""" + solver_db_url = os.getenv("SOLVER_SLIPPAGE_DB_URL") + if not solver_db_url: + raise ValueError( + "Solver slippage database URL not found in environment variables." + ) + + return create_engine(f"postgresql+psycopg2://{solver_db_url}") -CHAIN_RPC_ENDPOINTS = {"Ethereum": ETHEREUM_NODE_URL, "Gnosis": GNOSIS_NODE_URL} -CHAIN_SLEEP_TIMES = {"Ethereum": 60, "Gnosis": 120} +def check_db_connection(connection: Engine, chain_name: Optional[str] = None) -> Engine: + """ + Check if the database connection is still active. If not, create a new one. + """ + try: + if connection: + with connection.connect() as conn: # Use connection.connect() to get a Connection object + conn.execute(text("SELECT 1")) + except OperationalError: + # if connection is closed, create new one + connection = ( + create_backend_db_connection(chain_name) + if chain_name + else create_solver_slippage_db_connection() + ) + return connection diff --git a/src/constants.py b/src/constants.py index bbe179e..63fa527 100644 --- a/src/constants.py +++ b/src/constants.py @@ -1,3 +1,4 @@ +""" Constants used for the token imbalances project """ from web3 import Web3 SETTLEMENT_CONTRACT_ADDRESS = Web3.to_checksum_address( diff --git a/src/daemon.py b/src/daemon.py index d1ff665..b3595b6 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -1,104 +1,260 @@ -# mypy: disable-error-code="import, arg-type" +""" +Running this daemon computes raw imbalances for finalized blocks by calling imbalances_script.py. +""" import os import time +from typing import List, Tuple import pandas as pd from web3 import Web3 -from typing import List -from threading import Thread -from sqlalchemy import create_engine +from sqlalchemy import text from sqlalchemy.engine import Engine from src.imbalances_script import RawTokenImbalances -from src.config import CHAIN_RPC_ENDPOINTS, CHAIN_SLEEP_TIMES +from src.config import ( + CHAIN_SLEEP_TIME, + NODE_URL, + create_backend_db_connection, + create_solver_slippage_db_connection, + check_db_connection, + logger, +) -def get_web3_instance(chain_name: str) -> Web3: - return Web3(Web3.HTTPProvider(CHAIN_RPC_ENDPOINTS[chain_name])) +def get_web3_instance() -> Web3: + """ + returns a Web3 instance for the given blockchain via chain name. + """ + return Web3(Web3.HTTPProvider(NODE_URL)) def get_finalized_block_number(web3: Web3) -> int: - return web3.eth.block_number - 64 - - -def create_db_connection(chain_name: str): - """function that creates a connection to the CoW db.""" - if chain_name == "Ethereum": - db_url = os.getenv("ETHEREUM_DB_URL") - elif chain_name == "Gnosis": - db_url = os.getenv("GNOSIS_DB_URL") - - return create_engine(f"postgresql+psycopg2://{db_url}") + """ + Get the number of the most recent finalized block. + """ + return web3.eth.block_number - 67 -def fetch_transaction_hashes( - db_connection: Engine, start_block: int, end_block: int -) -> List[str]: - """Fetch transaction hashes beginning start_block.""" +def fetch_tx_data( + backend_db_connection: Engine, chain_name: str, start_block: int, end_block: int +) -> List[Tuple[str, int, int]]: + """Fetch transaction hashes beginning from start_block to end_block.""" + backend_db_connection = check_db_connection(backend_db_connection, chain_name) query = f""" - SELECT tx_hash - FROM settlements + SELECT tx_hash, auction_id, block_number + FROM settlements WHERE block_number >= {start_block} AND block_number <= {end_block} """ - - db_hashes = pd.read_sql(query, db_connection) + db_data = pd.read_sql(query, backend_db_connection) # converts hashes at memory location to hex - db_hashes["tx_hash"] = db_hashes["tx_hash"].apply(lambda x: f"0x{x.hex()}") + db_data["tx_hash"] = db_data["tx_hash"].apply(lambda x: f"0x{x.hex()}") - return db_hashes["tx_hash"].tolist() + # return (tx hash, auction id) as tx_data + tx_data = [ + (row["tx_hash"], row["auction_id"], row["block_number"]) + for index, row in db_data.iterrows() + ] + return tx_data -def process_transactions(chain_name: str) -> None: - web3 = get_web3_instance(chain_name) - rt = RawTokenImbalances(web3, chain_name) - sleep_time = CHAIN_SLEEP_TIMES.get(chain_name) - db_connection = create_db_connection(chain_name) +def record_exists( + solver_slippage_db_engine: Engine, + tx_hash_bytes: bytes, + token_address_bytes: bytes, +) -> bool: + """ + Check if a record with the given (tx_hash, token_address) already exists in the database. + """ + solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine) + query = text( + """ + SELECT 1 FROM raw_token_imbalances + WHERE tx_hash = :tx_hash AND token_address = :token_address + """ + ) + try: + with solver_slippage_db_engine.connect() as connection: + result = connection.execute( + query, {"tx_hash": tx_hash_bytes, "token_address": token_address_bytes} + ) + record_exists = result.fetchone() is not None + return record_exists + except Exception as e: + logger.error("Error checking record existence: %s", e) + return False + + +def write_token_imbalances_to_db( + chain_name: str, + solver_slippage_db_engine: Engine, + auction_id: int, + block_number: int, + tx_hash: str, + token_address: str, + imbalance: float, +) -> None: + """ + Write token imbalances to the database if the (tx_hash, token_address) combination does not already exist. + """ + solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine) + tx_hash_bytes = bytes.fromhex(tx_hash[2:]) + token_address_bytes = bytes.fromhex(token_address[2:]) + if not record_exists(solver_slippage_db_engine, tx_hash_bytes, token_address_bytes): + insert_sql = text( + """ + INSERT INTO raw_token_imbalances (auction_id, chain_name, block_number, tx_hash, token_address, imbalance) + VALUES (:auction_id, :chain_name, :block_number, :tx_hash, :token_address, :imbalance) + """ + ) + try: + with solver_slippage_db_engine.connect() as connection: + connection.execute( + insert_sql, + { + "auction_id": auction_id, + "chain_name": chain_name, + "block_number": block_number, + "tx_hash": tx_hash_bytes, + "token_address": token_address_bytes, + "imbalance": imbalance, + }, + ) + connection.commit() + logger.debug("Record inserted successfully.") + except Exception as e: + logger.error("Error inserting record: %s", e) + else: + logger.info( + "Record with tx_hash %s and token_address %s already exists.", + tx_hash, + token_address, + ) - previous_block = get_finalized_block_number(web3) - unprocessed_txs = [] # type: List - print(f"{chain_name} Daemon started.") +def get_start_block( + chain_name: str, solver_slippage_db_engine: Engine, web3: Web3 +) -> int: + """ + Retrieve the most recent block already present in raw_token_imbalances table, + delete entries for that block, and return this block number as start_block. + If no entries are present, fallback to get_finalized_block_number(). + """ + try: + solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine) + + query_max_block = text( + """ + SELECT MAX(block_number) FROM raw_token_imbalances + WHERE chain_name = :chain_name + """ + ) + + with solver_slippage_db_engine.connect() as connection: + result = connection.execute(query_max_block, {"chain_name": chain_name}) + row = result.fetchone() + max_block = ( + row[0] if row is not None else None + ) # Fetch the maximum block number + if max_block is not None: + logger.debug("Fetched max block number from database: %d", max_block) + + # If no entries present, fallback to get_finalized_block_number() + if max_block is None: + return get_finalized_block_number(web3) + + # delete entries for the max block from the table + delete_sql = text( + """ + DELETE FROM raw_token_imbalances WHERE chain_name = :chain_name AND block_number = :block_number + """ + ) + try: + connection.execute( + delete_sql, {"chain_name": chain_name, "block_number": max_block} + ) + connection.commit() + logger.debug( + "Successfully deleted entries for block number: %s", max_block + ) + except Exception as e: + logger.debug( + "Failed to delete entries for block number %s: %s", max_block, e + ) + + return max_block + except Exception as e: + logger.error("Error accessing database: %s", e) + return get_finalized_block_number(web3) + + +def process_transactions(chain_name: str) -> None: + """ + Process transactions to compute imbalances for a given blockchain via chain name. + """ + web3 = get_web3_instance() + rt = RawTokenImbalances(web3, chain_name) + backend_db_connection = create_backend_db_connection(chain_name) + solver_slippage_db_connection = create_solver_slippage_db_connection() + start_block = get_start_block(chain_name, solver_slippage_db_connection, web3) + previous_block = start_block + unprocessed_txs: List[Tuple[str, int, int]] = [] + logger.info("%s Daemon started. Start block: %d", chain_name, start_block) while True: try: latest_block = get_finalized_block_number(web3) - new_txs = fetch_transaction_hashes( - db_connection, previous_block, latest_block + new_txs = fetch_tx_data( + backend_db_connection, chain_name, previous_block, latest_block ) - # add any unprocessed hashes for processing, then clear list of unprocessed + # add any unprocessed txs for processing, then clear list of unprocessed all_txs = new_txs + unprocessed_txs unprocessed_txs.clear() - for tx in all_txs: - print(f"Processing transaction on {chain_name}: {tx}") + for tx, auction_id, block_number in all_txs: + logger.info("Processing transaction on %s: %s", chain_name, tx) try: imbalances = rt.compute_imbalances(tx) - print(f"Token Imbalances on {chain_name}:") + # append imbalances to a single log message + log_message = [f"Token Imbalances on {chain_name} for tx {tx}:"] for token_address, imbalance in imbalances.items(): - print(f"Token: {token_address}, Imbalance: {imbalance}") + # ignore tokens that have null imbalances + if imbalance != 0: + write_token_imbalances_to_db( + chain_name, + solver_slippage_db_connection, + auction_id, + block_number, + tx, + token_address, + imbalance, + ) + log_message.append( + f"Token: {token_address}, Imbalance: {imbalance}" + ) + logger.info("\n".join(log_message)) except ValueError as e: - print(e) - unprocessed_txs.append(tx) + logger.error("ValueError: %s", e) + unprocessed_txs.append((tx, auction_id, block_number)) - print("Done checks..") previous_block = latest_block + 1 except ConnectionError as e: - print(f"Connection error processing transactions on {chain_name}: {e}") + logger.error( + "Connection error processing transactions on %s: %s", chain_name, e + ) except Exception as e: - print(f"Error processing transactions on {chain_name}: {e}") - - time.sleep(sleep_time) + logger.error("Error processing transactions on %s: %s", chain_name, e) + if CHAIN_SLEEP_TIME is not None: + time.sleep(CHAIN_SLEEP_TIME) def main() -> None: - threads = [] - - for chain_name in CHAIN_RPC_ENDPOINTS.keys(): - thread = Thread(target=process_transactions, args=(chain_name,), daemon=True) - thread.start() - threads.append(thread) - - for thread in threads: - thread.join() + """ + Main function to start the daemon for a blockchain. + """ + chain_name = os.getenv("CHAIN_NAME") + if chain_name is None: + logger.error("CHAIN_NAME environment variable is not set.") + return + process_transactions(chain_name) if __name__ == "__main__": diff --git a/src/helper_functions.py b/src/helper_functions.py new file mode 100644 index 0000000..07cf062 --- /dev/null +++ b/src/helper_functions.py @@ -0,0 +1,47 @@ +""" +This file contains some auxiliary functions +""" +from __future__ import annotations +import sys +import logging +from typing import Optional + + +def get_logger(filename: Optional[str] = None) -> logging.Logger: + """ + get_logger() returns a logger object that can write to a file, terminal or only file if needed. + """ + logger = logging.getLogger() + logger.setLevel(logging.INFO) + + # Clear any existing handlers to avoid duplicate logs + if logger.hasHandlers(): + logger.handlers.clear() + + # Create formatter + formatter = logging.Formatter("%(levelname)s - %(message)s") + + # Handler for stdout (INFO and lower) + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(logging.INFO) + stdout_handler.setFormatter(formatter) + + # ERROR and above logs will not be logged to stdout + stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR) + + # Handler for stderr (ERROR and higher) + stderr_handler = logging.StreamHandler(sys.stderr) + stderr_handler.setLevel(logging.ERROR) + stderr_handler.setFormatter(formatter) + + # Add handlers to the logger + logger.addHandler(stdout_handler) + logger.addHandler(stderr_handler) + + if filename: + file_handler = logging.FileHandler(filename + ".log", mode="w") + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger diff --git a/src/imbalances_script.py b/src/imbalances_script.py index cfd5d5b..5c1508d 100644 --- a/src/imbalances_script.py +++ b/src/imbalances_script.py @@ -15,16 +15,19 @@ adding the transfer value to existing inflow/outflow for the token addresses. 7. Returning to calculate_imbalances(), which finds the imbalance for all token addresses using inflow-outflow. -8. If actions are not None, it denotes an ETH transfer event, which involves reducing WETH withdrawal - amount- > update_weth_imbalance(). The ETH imbalance is also calculated via -> update_native_eth_imbalance(). -9. update_sdai_imbalance() is called in each iteration and only completes if there is an SDAI transfer - involved which has special handling for its events. +8. If actions are not None, it denotes an ETH transfer event, which involves reducing WETH + withdrawal amount- > update_weth_imbalance(). The ETH imbalance is also calculated + via -> update_native_eth_imbalance(). +9. update_sdai_imbalance() is called in each iteration and only completes if there is an SDAI + transfer involved which has special handling for its events. """ -from web3.datastructures import AttributeDict from typing import Dict, List, Optional, Tuple + from web3 import Web3 +from web3.datastructures import AttributeDict from web3.types import TxReceipt -from src.config import CHAIN_RPC_ENDPOINTS + +from src.config import CHAIN_RPC_ENDPOINTS, logger from src.constants import ( SETTLEMENT_CONTRACT_ADDRESS, NATIVE_ETH_TOKEN_ADDRESS, @@ -54,14 +57,14 @@ def find_chain_with_tx(tx_hash: str) -> Tuple[str, Web3]: for chain_name, url in CHAIN_RPC_ENDPOINTS.items(): web3 = Web3(Web3.HTTPProvider(url)) if not web3.is_connected(): - print(f"Could not connect to {chain_name}.") + logger.warning("Could not connect to %s.", chain_name) continue try: web3.eth.get_transaction_receipt(tx_hash) - print(f"Transaction found on {chain_name}.") + logger.info("Transaction found on %s.", chain_name) return chain_name, web3 - except Exception as e: - print(f"Transaction not found on {chain_name}: {e}") + except Exception as ex: + logger.debug("Transaction not found on %s: %s", chain_name, ex) raise ValueError(f"Transaction hash {tx_hash} not found on any chain.") @@ -74,10 +77,12 @@ def _to_int(value: str | int) -> int: else int(value) ) except ValueError: - print(f"Error converting value {value} to integer.") + logger.error("Error converting value %s to integer.", value) class RawTokenImbalances: + """Class for computing token imbalances.""" + def __init__(self, web3: Web3, chain_name: str): self.web3 = web3 self.chain_name = chain_name @@ -88,8 +93,8 @@ def get_transaction_receipt(self, tx_hash: str) -> Optional[TxReceipt]: """ try: return self.web3.eth.get_transaction_receipt(tx_hash) - except Exception as e: - print(f"Error getting transaction receipt: {e}") + except Exception as ex: + logger.error("Error getting transaction receipt: %s", ex) return None def get_transaction_trace(self, tx_hash: str) -> Optional[List[Dict]]: @@ -98,7 +103,7 @@ def get_transaction_trace(self, tx_hash: str) -> Optional[List[Dict]]: res = self.web3.tracing.trace_transaction(tx_hash) return res except Exception as err: - print(f"Error occurred while fetching transaction trace: {err}") + logger.error("Error occurred while fetching transaction trace: %s", err) return None def extract_actions(self, traces: List[AttributeDict], address: str) -> List[Dict]: @@ -149,7 +154,7 @@ def extract_events(self, tx_receipt: Dict) -> Dict[str, List[Dict]]: k: v for k, v in event_topics.items() if k not in transfer_topics } - events = {name: [] for name in EVENT_TOPICS} # type: dict + events: Dict[str, List[Dict]] = {name: [] for name in EVENT_TOPICS} for log in tx_receipt["logs"]: log_topic = log["topics"][0].hex() if log_topic in transfer_topics.values(): @@ -187,7 +192,7 @@ def decode_event( else: # Withdrawal event return from_address, None, value except Exception as e: - print(f"Error decoding event: {str(e)}") + logger.error("Error decoding event: %s", str(e)) return None, None, None def process_event( @@ -256,7 +261,7 @@ def decode_sdai_event(self, event: Dict) -> int | None: value = int(value_hex, 16) return value except Exception as e: - print(f"Error decoding sDAI event: {str(e)}") + logger.error(f"Error decoding sDAI event: {str(e)}") return None def process_sdai_event( @@ -317,18 +322,18 @@ def compute_imbalances(self, tx_hash: str) -> Dict[str, int]: return imbalances -# main method for finding imbalance for a single tx hash def main() -> None: + """main function for finding imbalance for a single tx hash.""" tx_hash = input("Enter transaction hash: ") chain_name, web3 = find_chain_with_tx(tx_hash) rt = RawTokenImbalances(web3, chain_name) try: imbalances = rt.compute_imbalances(tx_hash) - print(f"Token Imbalances on {chain_name}:") + logger.info(f"Token Imbalances on {chain_name}:") for token_address, imbalance in imbalances.items(): - print(f"Token: {token_address}, Imbalance: {imbalance}") + logger.info(f"Token: {token_address}, Imbalance: {imbalance}") except ValueError as e: - print(e) + logger.error(e) if __name__ == "__main__": diff --git a/tests/basic_test.py b/tests/basic_test.py index 4598d5e..f83b857 100644 --- a/tests/basic_test.py +++ b/tests/basic_test.py @@ -1,3 +1,4 @@ +""" Runs a basic test for raw imbalance calculation edge-cases. """ import pytest from src.imbalances_script import RawTokenImbalances @@ -33,6 +34,9 @@ ], ) def test_imbalances(tx_hash, expected_imbalances): + """ + Asserts imbalances match for main script with test values provided. + """ rt = RawTokenImbalances() imbalances, _ = rt.compute_imbalances(tx_hash) for token_address, expected_imbalance in expected_imbalances.items():