Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

addressed issue #7 & trace handling #11

Merged
merged 2 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 16 additions & 25 deletions src/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from typing import Optional
from sqlalchemy import text
from sqlalchemy.exc import OperationalError
from sqlalchemy import create_engine, Engine
Expand All @@ -18,6 +17,11 @@
"Gnosis": os.getenv("GNOSIS_NODE_URL"),
}

CREATE_DB_URLS = {
"backend": os.getenv("DB_URL"),
"solver_slippage": os.getenv("SOLVER_SLIPPAGE_DB_URL"),
}


def get_env_int(var_name: str) -> int:
"""
Expand All @@ -36,28 +40,19 @@ def get_env_int(var_name: str) -> int:
CHAIN_SLEEP_TIME = get_env_int("CHAIN_SLEEP_TIME")


def create_backend_db_connection(chain_name: str) -> Engine:
"""function that creates a connection to the CoW db."""
read_db_url = os.getenv("DB_URL")

if not read_db_url:
raise ValueError(f"No database URL found for chain: {chain_name}")

return create_engine(f"postgresql+psycopg2://{read_db_url}")


def create_solver_slippage_db_connection() -> Engine:
"""function that creates a connection to the CoW db."""
solver_db_url = os.getenv("SOLVER_SLIPPAGE_DB_URL")
if not solver_db_url:
raise ValueError(
"Solver slippage database URL not found in environment variables."
)
def create_db_connection(db_type: str) -> Engine:
"""
Function that creates a connection to the specified database.
db_type should be either "backend" or "solver_slippage".
"""
db_url = CREATE_DB_URLS.get(db_type)
if not db_url:
raise ValueError(f"{db_type} database URL not found in environment variables.")

return create_engine(f"postgresql+psycopg2://{solver_db_url}")
return create_engine(f"postgresql+psycopg2://{db_url}")


def check_db_connection(connection: Engine, chain_name: Optional[str] = None) -> Engine:
def check_db_connection(connection: Engine, db_type: str) -> Engine:
"""
Check if the database connection is still active. If not, create a new one.
"""
Expand All @@ -67,9 +62,5 @@ def check_db_connection(connection: Engine, chain_name: Optional[str] = None) ->
conn.execute(text("SELECT 1"))
except OperationalError:
# if connection is closed, create new one
connection = (
create_backend_db_connection(chain_name)
if chain_name
else create_solver_slippage_db_connection()
)
connection = create_db_connection(db_type)
return connection
92 changes: 52 additions & 40 deletions src/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from src.config import (
CHAIN_SLEEP_TIME,
NODE_URL,
create_backend_db_connection,
create_solver_slippage_db_connection,
create_db_connection,
check_db_connection,
logger,
)
Expand All @@ -34,10 +33,12 @@ def get_finalized_block_number(web3: Web3) -> int:


def fetch_tx_data(
backend_db_connection: Engine, chain_name: str, start_block: int, end_block: int
backend_db_connection: Engine, start_block: int, end_block: int
) -> List[Tuple[str, int, int]]:
"""Fetch transaction hashes beginning from start_block to end_block."""
backend_db_connection = check_db_connection(backend_db_connection, chain_name)

backend_db_connection = check_db_connection(backend_db_connection, "backend")

query = f"""
SELECT tx_hash, auction_id, block_number
FROM settlements
Expand All @@ -57,22 +58,25 @@ def fetch_tx_data(


def record_exists(
solver_slippage_db_engine: Engine,
solver_slippage_connection: Engine,
tx_hash_bytes: bytes,
token_address_bytes: bytes,
) -> bool:
"""
Check if a record with the given (tx_hash, token_address) already exists in the database.
"""
solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine)
solver_slippage_connection = check_db_connection(
solver_slippage_connection, "solver_slippage"
)

query = text(
"""
SELECT 1 FROM raw_token_imbalances
WHERE tx_hash = :tx_hash AND token_address = :token_address
"""
)
try:
with solver_slippage_db_engine.connect() as connection:
with solver_slippage_connection.connect() as connection:
result = connection.execute(
query, {"tx_hash": tx_hash_bytes, "token_address": token_address_bytes}
)
Expand All @@ -85,7 +89,7 @@ def record_exists(

def write_token_imbalances_to_db(
chain_name: str,
solver_slippage_db_engine: Engine,
solver_slippage_connection: Engine,
auction_id: int,
block_number: int,
tx_hash: str,
Expand All @@ -95,18 +99,23 @@ def write_token_imbalances_to_db(
"""
Write token imbalances to the database if the (tx_hash, token_address) combination does not already exist.
"""
solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine)
solver_slippage_connection = check_db_connection(
solver_slippage_connection, "solver_slippage"
)

tx_hash_bytes = bytes.fromhex(tx_hash[2:])
token_address_bytes = bytes.fromhex(token_address[2:])
if not record_exists(solver_slippage_db_engine, tx_hash_bytes, token_address_bytes):
if not record_exists(
solver_slippage_connection, tx_hash_bytes, token_address_bytes
):
insert_sql = text(
"""
INSERT INTO raw_token_imbalances (auction_id, chain_name, block_number, tx_hash, token_address, imbalance)
VALUES (:auction_id, :chain_name, :block_number, :tx_hash, :token_address, :imbalance)
"""
)
try:
with solver_slippage_db_engine.connect() as connection:
with solver_slippage_connection.connect() as connection:
connection.execute(
insert_sql,
{
Expand All @@ -131,15 +140,17 @@ def write_token_imbalances_to_db(


def get_start_block(
chain_name: str, solver_slippage_db_engine: Engine, web3: Web3
chain_name: str, solver_slippage_connection: Engine, web3: Web3
) -> int:
"""
Retrieve the most recent block already present in raw_token_imbalances table,
delete entries for that block, and return this block number as start_block.
If no entries are present, fallback to get_finalized_block_number().
"""
try:
solver_slippage_db_engine = check_db_connection(solver_slippage_db_engine)
solver_slippage_connection = check_db_connection(
solver_slippage_connection, "solver_slippage"
)

query_max_block = text(
"""
Expand All @@ -148,7 +159,7 @@ def get_start_block(
"""
)

with solver_slippage_db_engine.connect() as connection:
with solver_slippage_connection.connect() as connection:
result = connection.execute(query_max_block, {"chain_name": chain_name})
row = result.fetchone()
max_block = (
Expand Down Expand Up @@ -176,7 +187,7 @@ def get_start_block(
"Successfully deleted entries for block number: %s", max_block
)
except Exception as e:
logger.debug(
logger.error(
"Failed to delete entries for block number %s: %s", max_block, e
)

Expand All @@ -192,8 +203,8 @@ def process_transactions(chain_name: str) -> None:
"""
web3 = get_web3_instance()
rt = RawTokenImbalances(web3, chain_name)
backend_db_connection = create_backend_db_connection(chain_name)
solver_slippage_db_connection = create_solver_slippage_db_connection()
backend_db_connection = create_db_connection("backend")
solver_slippage_db_connection = create_db_connection("solver_slippage")
start_block = get_start_block(chain_name, solver_slippage_db_connection, web3)
previous_block = start_block
unprocessed_txs: List[Tuple[str, int, int]] = []
Expand All @@ -202,46 +213,47 @@ def process_transactions(chain_name: str) -> None:
while True:
try:
latest_block = get_finalized_block_number(web3)
new_txs = fetch_tx_data(
backend_db_connection, chain_name, previous_block, latest_block
)
# add any unprocessed txs for processing, then clear list of unprocessed
new_txs = fetch_tx_data(backend_db_connection, previous_block, latest_block)
# Add any unprocessed txs for processing, then clear list of unprocessed
all_txs = new_txs + unprocessed_txs
unprocessed_txs.clear()

for tx, auction_id, block_number in all_txs:
logger.info("Processing transaction on %s: %s", chain_name, tx)
try:
imbalances = rt.compute_imbalances(tx)
# append imbalances to a single log message
log_message = [f"Token Imbalances on {chain_name} for tx {tx}:"]
for token_address, imbalance in imbalances.items():
# ignore tokens that have null imbalances
if imbalance != 0:
write_token_imbalances_to_db(
chain_name,
solver_slippage_db_connection,
auction_id,
block_number,
tx,
token_address,
imbalance,
)
log_message.append(
f"Token: {token_address}, Imbalance: {imbalance}"
)
logger.info("\n".join(log_message))
# Append imbalances to a single log message
if imbalances is not None:
log_message = [f"Token Imbalances on {chain_name} for tx {tx}:"]
for token_address, imbalance in imbalances.items():
# Ignore tokens that have null imbalances
if imbalance != 0:
write_token_imbalances_to_db(
chain_name,
solver_slippage_db_connection,
auction_id,
block_number,
tx,
token_address,
imbalance,
)
log_message.append(
f"Token: {token_address}, Imbalance: {imbalance}"
)
logger.info("\n".join(log_message))
else:
raise ValueError("Imbalances computation returned None.")
except ValueError as e:
logger.error("ValueError: %s", e)
unprocessed_txs.append((tx, auction_id, block_number))

previous_block = latest_block + 1
except ConnectionError as e:
logger.error(
"Connection error processing transactions on %s: %s", chain_name, e
)
except Exception as e:
logger.error("Error processing transactions on %s: %s", chain_name, e)

if CHAIN_SLEEP_TIME is not None:
time.sleep(CHAIN_SLEEP_TIME)

Expand Down
62 changes: 37 additions & 25 deletions src/imbalances_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
9. update_sdai_imbalance() is called in each iteration and only completes if there is an SDAI
transfer involved which has special handling for its events.
"""

from typing import Dict, List, Optional, Tuple

from web3 import Web3
Expand Down Expand Up @@ -291,35 +292,43 @@ def update_sdai_imbalance(
if event["address"] == SDAI_TOKEN_ADDRESS:
self.process_sdai_event(event, imbalances, is_deposit=False)

def compute_imbalances(self, tx_hash: str) -> Dict[str, int]:
"""Compute token imbalances for a given transaction hash."""
tx_receipt = self.get_transaction_receipt(tx_hash)
if tx_receipt is None:
raise ValueError(
f"Transaction hash {tx_hash} not found on chain {self.chain_name}."
)
# find trace and actions from trace to track native ETH events
traces = self.get_transaction_trace(tx_hash)
native_eth_imbalance = None
actions = []
if traces is not None:
def compute_imbalances(self, tx_hash: str) -> Optional[Dict[str, int]]:
try:
tx_receipt = self.get_transaction_receipt(tx_hash)
if not tx_receipt:
logger.error("No transaction receipt found for %s", tx_hash)
return None

traces = self.get_transaction_trace(tx_hash)
if traces is None:
logger.error(
"Error fetching transaction trace for %s. Marking transaction as unprocessed.",
tx_hash,
)
return None

events = self.extract_events(tx_receipt)
imbalances = self.calculate_imbalances(events, SETTLEMENT_CONTRACT_ADDRESS)

native_eth_imbalance = None
actions = []
actions = self.extract_actions(traces, SETTLEMENT_CONTRACT_ADDRESS)
native_eth_imbalance = self.calculate_native_eth_imbalance(
actions, SETTLEMENT_CONTRACT_ADDRESS
)

events = self.extract_events(tx_receipt)
imbalances = self.calculate_imbalances(events, SETTLEMENT_CONTRACT_ADDRESS)

if actions:
self.update_weth_imbalance(
events, actions, imbalances, SETTLEMENT_CONTRACT_ADDRESS
)
self.update_native_eth_imbalance(imbalances, native_eth_imbalance)
if actions:
self.update_weth_imbalance(
events, actions, imbalances, SETTLEMENT_CONTRACT_ADDRESS
)
self.update_native_eth_imbalance(imbalances, native_eth_imbalance)

self.update_sdai_imbalance(events, imbalances)
self.update_sdai_imbalance(events, imbalances)
return imbalances

return imbalances
except Exception as e:
logger.error("Error computing imbalances for %s: %s", tx_hash, e)
return None


def main() -> None:
Expand All @@ -329,9 +338,12 @@ def main() -> None:
rt = RawTokenImbalances(web3, chain_name)
try:
imbalances = rt.compute_imbalances(tx_hash)
logger.info(f"Token Imbalances on {chain_name}:")
for token_address, imbalance in imbalances.items():
logger.info(f"Token: {token_address}, Imbalance: {imbalance}")
if imbalances is not None:
logger.info(f"Token Imbalances on {chain_name}:")
for token_address, imbalance in imbalances.items():
logger.info(f"Token: {token_address}, Imbalance: {imbalance}")
else:
raise ValueError("Imbalances computation returned None.")
except ValueError as e:
logger.error(e)

Expand Down
Loading