From c2e7402d3ff35f75e18667ef7412893415741f83 Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Wed, 10 Jul 2024 13:53:24 -0400 Subject: [PATCH 1/2] addressed issue #7 & trace handling --- src/config.py | 41 ++++++++---------- src/daemon.py | 92 +++++++++++++++++++++++----------------- src/imbalances_script.py | 62 ++++++++++++++++----------- 3 files changed, 107 insertions(+), 88 deletions(-) diff --git a/src/config.py b/src/config.py index 4cdf822..3f7c4f3 100644 --- a/src/config.py +++ b/src/config.py @@ -1,5 +1,4 @@ import os -from typing import Optional from sqlalchemy import text from sqlalchemy.exc import OperationalError from sqlalchemy import create_engine, Engine @@ -18,6 +17,11 @@ "Gnosis": os.getenv("GNOSIS_NODE_URL"), } +CREATE_DB_URLS = { + "backend": os.getenv("DB_URL"), + "solver_slippage": os.getenv("SOLVER_SLIPPAGE_DB_URL"), +} + def get_env_int(var_name: str) -> int: """ @@ -36,28 +40,19 @@ def get_env_int(var_name: str) -> int: CHAIN_SLEEP_TIME = get_env_int("CHAIN_SLEEP_TIME") -def create_backend_db_connection(chain_name: str) -> Engine: - """function that creates a connection to the CoW db.""" - read_db_url = os.getenv("DB_URL") - - if not read_db_url: - raise ValueError(f"No database URL found for chain: {chain_name}") - - return create_engine(f"postgresql+psycopg2://{read_db_url}") - - -def create_solver_slippage_db_connection() -> Engine: - """function that creates a connection to the CoW db.""" - solver_db_url = os.getenv("SOLVER_SLIPPAGE_DB_URL") - if not solver_db_url: - raise ValueError( - "Solver slippage database URL not found in environment variables." - ) +def create_db_connection(db_type: str) -> Engine: + """ + Function that creates a connection to the specified database. + db_type should be either "backend" or "solver_slippage". + """ + db_url = CREATE_DB_URLS.get(db_type) + if not db_url: + raise ValueError(f"{db_type} database URL not found in environment variables.") - return create_engine(f"postgresql+psycopg2://{solver_db_url}") + return create_engine(f"postgresql+psycopg2://{db_url}") -def check_db_connection(connection: Engine, chain_name: Optional[str] = None) -> Engine: +def check_db_connection(connection: Engine, db_type: str) -> Engine: """ Check if the database connection is still active. If not, create a new one. """ @@ -68,8 +63,8 @@ def check_db_connection(connection: Engine, chain_name: Optional[str] = None) -> except OperationalError: # if connection is closed, create new one connection = ( - create_backend_db_connection(chain_name) - if chain_name - else create_solver_slippage_db_connection() + create_db_connection("backend") + if db_type == "backend" + else create_db_connection("solver_slippage") ) return connection diff --git a/src/daemon.py b/src/daemon.py index b3595b6..1c77508 100644 --- a/src/daemon.py +++ b/src/daemon.py @@ -12,8 +12,7 @@ from src.config import ( CHAIN_SLEEP_TIME, NODE_URL, - create_backend_db_connection, - create_solver_slippage_db_connection, + create_db_connection, check_db_connection, logger, ) @@ -34,10 +33,12 @@ def get_finalized_block_number(web3: Web3) -> int: def fetch_tx_data( - backend_db_connection: Engine, chain_name: str, start_block: int, end_block: int + backend_db_connection: Engine, start_block: int, end_block: int ) -> List[Tuple[str, int, int]]: """Fetch transaction hashes beginning from start_block to end_block.""" - backend_db_connection = check_db_connection(backend_db_connection, chain_name) + + backend_db_connection = check_db_connection(backend_db_connection, "backend") + query = f""" SELECT tx_hash, auction_id, block_number FROM settlements @@ -57,14 +58,17 @@ def fetch_tx_data( def record_exists( - solver_slippage_db_engine: Engine, + solver_slippage_connection: Engine, tx_hash_bytes: bytes, token_address_bytes: bytes, ) -> bool: """ Check if a record with the given (tx_hash, token_address) already exists in the database. """ - solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine) + solver_slippage_connection = check_db_connection( + solver_slippage_connection, "solver_slippage" + ) + query = text( """ SELECT 1 FROM raw_token_imbalances @@ -72,7 +76,7 @@ def record_exists( """ ) try: - with solver_slippage_db_engine.connect() as connection: + with solver_slippage_connection.connect() as connection: result = connection.execute( query, {"tx_hash": tx_hash_bytes, "token_address": token_address_bytes} ) @@ -85,7 +89,7 @@ def record_exists( def write_token_imbalances_to_db( chain_name: str, - solver_slippage_db_engine: Engine, + solver_slippage_connection: Engine, auction_id: int, block_number: int, tx_hash: str, @@ -95,10 +99,15 @@ def write_token_imbalances_to_db( """ Write token imbalances to the database if the (tx_hash, token_address) combination does not already exist. """ - solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine) + solver_slippage_connection = check_db_connection( + solver_slippage_connection, "solver_slippage" + ) + tx_hash_bytes = bytes.fromhex(tx_hash[2:]) token_address_bytes = bytes.fromhex(token_address[2:]) - if not record_exists(solver_slippage_db_engine, tx_hash_bytes, token_address_bytes): + if not record_exists( + solver_slippage_connection, tx_hash_bytes, token_address_bytes + ): insert_sql = text( """ INSERT INTO raw_token_imbalances (auction_id, chain_name, block_number, tx_hash, token_address, imbalance) @@ -106,7 +115,7 @@ def write_token_imbalances_to_db( """ ) try: - with solver_slippage_db_engine.connect() as connection: + with solver_slippage_connection.connect() as connection: connection.execute( insert_sql, { @@ -131,7 +140,7 @@ def write_token_imbalances_to_db( def get_start_block( - chain_name: str, solver_slippage_db_engine: Engine, web3: Web3 + chain_name: str, solver_slippage_connection: Engine, web3: Web3 ) -> int: """ Retrieve the most recent block already present in raw_token_imbalances table, @@ -139,7 +148,9 @@ def get_start_block( If no entries are present, fallback to get_finalized_block_number(). """ try: - solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine) + solver_slippage_connection = check_db_connection( + solver_slippage_connection, "solver_slippage" + ) query_max_block = text( """ @@ -148,7 +159,7 @@ def get_start_block( """ ) - with solver_slippage_db_engine.connect() as connection: + with solver_slippage_connection.connect() as connection: result = connection.execute(query_max_block, {"chain_name": chain_name}) row = result.fetchone() max_block = ( @@ -176,7 +187,7 @@ def get_start_block( "Successfully deleted entries for block number: %s", max_block ) except Exception as e: - logger.debug( + logger.error( "Failed to delete entries for block number %s: %s", max_block, e ) @@ -192,8 +203,8 @@ def process_transactions(chain_name: str) -> None: """ web3 = get_web3_instance() rt = RawTokenImbalances(web3, chain_name) - backend_db_connection = create_backend_db_connection(chain_name) - solver_slippage_db_connection = create_solver_slippage_db_connection() + backend_db_connection = create_db_connection("backend") + solver_slippage_db_connection = create_db_connection("solver_slippage") start_block = get_start_block(chain_name, solver_slippage_db_connection, web3) previous_block = start_block unprocessed_txs: List[Tuple[str, int, int]] = [] @@ -202,10 +213,8 @@ def process_transactions(chain_name: str) -> None: while True: try: latest_block = get_finalized_block_number(web3) - new_txs = fetch_tx_data( - backend_db_connection, chain_name, previous_block, latest_block - ) - # add any unprocessed txs for processing, then clear list of unprocessed + new_txs = fetch_tx_data(backend_db_connection, previous_block, latest_block) + # Add any unprocessed txs for processing, then clear list of unprocessed all_txs = new_txs + unprocessed_txs unprocessed_txs.clear() @@ -213,28 +222,30 @@ def process_transactions(chain_name: str) -> None: logger.info("Processing transaction on %s: %s", chain_name, tx) try: imbalances = rt.compute_imbalances(tx) - # append imbalances to a single log message - log_message = [f"Token Imbalances on {chain_name} for tx {tx}:"] - for token_address, imbalance in imbalances.items(): - # ignore tokens that have null imbalances - if imbalance != 0: - write_token_imbalances_to_db( - chain_name, - solver_slippage_db_connection, - auction_id, - block_number, - tx, - token_address, - imbalance, - ) - log_message.append( - f"Token: {token_address}, Imbalance: {imbalance}" - ) - logger.info("\n".join(log_message)) + # Append imbalances to a single log message + if imbalances is not None: + log_message = [f"Token Imbalances on {chain_name} for tx {tx}:"] + for token_address, imbalance in imbalances.items(): + # Ignore tokens that have null imbalances + if imbalance != 0: + write_token_imbalances_to_db( + chain_name, + solver_slippage_db_connection, + auction_id, + block_number, + tx, + token_address, + imbalance, + ) + log_message.append( + f"Token: {token_address}, Imbalance: {imbalance}" + ) + logger.info("\n".join(log_message)) + else: + raise ValueError("Imbalances computation returned None.") except ValueError as e: logger.error("ValueError: %s", e) unprocessed_txs.append((tx, auction_id, block_number)) - previous_block = latest_block + 1 except ConnectionError as e: logger.error( @@ -242,6 +253,7 @@ def process_transactions(chain_name: str) -> None: ) except Exception as e: logger.error("Error processing transactions on %s: %s", chain_name, e) + if CHAIN_SLEEP_TIME is not None: time.sleep(CHAIN_SLEEP_TIME) diff --git a/src/imbalances_script.py b/src/imbalances_script.py index 5c1508d..2992208 100644 --- a/src/imbalances_script.py +++ b/src/imbalances_script.py @@ -21,6 +21,7 @@ 9. update_sdai_imbalance() is called in each iteration and only completes if there is an SDAI transfer involved which has special handling for its events. """ + from typing import Dict, List, Optional, Tuple from web3 import Web3 @@ -291,35 +292,43 @@ def update_sdai_imbalance( if event["address"] == SDAI_TOKEN_ADDRESS: self.process_sdai_event(event, imbalances, is_deposit=False) - def compute_imbalances(self, tx_hash: str) -> Dict[str, int]: - """Compute token imbalances for a given transaction hash.""" - tx_receipt = self.get_transaction_receipt(tx_hash) - if tx_receipt is None: - raise ValueError( - f"Transaction hash {tx_hash} not found on chain {self.chain_name}." - ) - # find trace and actions from trace to track native ETH events - traces = self.get_transaction_trace(tx_hash) - native_eth_imbalance = None - actions = [] - if traces is not None: + def compute_imbalances(self, tx_hash: str) -> Optional[Dict[str, int]]: + try: + tx_receipt = self.get_transaction_receipt(tx_hash) + if not tx_receipt: + logger.error("No transaction receipt found for %s", tx_hash) + return None + + traces = self.get_transaction_trace(tx_hash) + if traces is None: + logger.error( + "Error fetching transaction trace for %s. Marking transaction as unprocessed.", + tx_hash, + ) + return None + + events = self.extract_events(tx_receipt) + imbalances = self.calculate_imbalances(events, SETTLEMENT_CONTRACT_ADDRESS) + + native_eth_imbalance = None + actions = [] actions = self.extract_actions(traces, SETTLEMENT_CONTRACT_ADDRESS) native_eth_imbalance = self.calculate_native_eth_imbalance( actions, SETTLEMENT_CONTRACT_ADDRESS ) - events = self.extract_events(tx_receipt) - imbalances = self.calculate_imbalances(events, SETTLEMENT_CONTRACT_ADDRESS) - - if actions: - self.update_weth_imbalance( - events, actions, imbalances, SETTLEMENT_CONTRACT_ADDRESS - ) - self.update_native_eth_imbalance(imbalances, native_eth_imbalance) + if actions: + self.update_weth_imbalance( + events, actions, imbalances, SETTLEMENT_CONTRACT_ADDRESS + ) + self.update_native_eth_imbalance(imbalances, native_eth_imbalance) - self.update_sdai_imbalance(events, imbalances) + self.update_sdai_imbalance(events, imbalances) + return imbalances - return imbalances + except Exception as e: + logger.error("Error computing imbalances for %s: %s", tx_hash, e) + return None def main() -> None: @@ -329,9 +338,12 @@ def main() -> None: rt = RawTokenImbalances(web3, chain_name) try: imbalances = rt.compute_imbalances(tx_hash) - logger.info(f"Token Imbalances on {chain_name}:") - for token_address, imbalance in imbalances.items(): - logger.info(f"Token: {token_address}, Imbalance: {imbalance}") + if imbalances is not None: + logger.info(f"Token Imbalances on {chain_name}:") + for token_address, imbalance in imbalances.items(): + logger.info(f"Token: {token_address}, Imbalance: {imbalance}") + else: + raise ValueError("Imbalances computation returned None.") except ValueError as e: logger.error(e) From d00a7510ec1424846272f2f54212bb119c1b6cb0 Mon Sep 17 00:00:00 2001 From: Shubh Agarwal Date: Wed, 10 Jul 2024 19:12:29 -0400 Subject: [PATCH 2/2] minor fix --- src/config.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/config.py b/src/config.py index 3f7c4f3..85d6cdc 100644 --- a/src/config.py +++ b/src/config.py @@ -62,9 +62,5 @@ def check_db_connection(connection: Engine, db_type: str) -> Engine: conn.execute(text("SELECT 1")) except OperationalError: # if connection is closed, create new one - connection = ( - create_db_connection("backend") - if db_type == "backend" - else create_db_connection("solver_slippage") - ) + connection = create_db_connection(db_type) return connection