Skip to content

Commit

Permalink
Merge pull request #4 from cowprotocol/add-logging
Browse files Browse the repository at this point in the history
logging and writing to db
  • Loading branch information
harisang authored Jul 10, 2024
2 parents f26e083 + dc64020 commit bd08e86
Show file tree
Hide file tree
Showing 10 changed files with 418 additions and 120 deletions.
21 changes: 14 additions & 7 deletions .env.sample
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
# .env.sample

# URLs for DB connection
ETHEREUM_DB_URL=
GNOSIS_DB_URL=
# DB connection
DB_URL=

# URLs for Node provider connection
ETHEREUM_NODE_URL=
GNOSIS_NODE_URL=
# Node provider connection
NODE_URL=

# connecting to Solver Slippage DB
SOLVER_SLIPPAGE_DB_URL=

# configure chain sleep time, e.g. CHAIN_SLEEP_TIME=60
CHAIN_SLEEP_TIME=

# add chain name, e.g. CHAIN_NAME=Ethereum
CHAIN_NAME=

# optional
INFURA_KEY=infura_key_here
INFURA_KEY=
3 changes: 3 additions & 0 deletions contracts/erc20_abi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
ERC20 ABI contract
"""
erc20_abi = [
{
"constant": True,
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ black==23.3.0
mypy==1.4.1
pylint==3.2.5
pytest==7.4.0
setuptools
setuptools
pandas-stubs
types-psycopg2
64 changes: 36 additions & 28 deletions src/balanceof_imbalances.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,23 @@
# mypy: disable-error-code="call-overload, arg-type, operator"
import sys
import os

# for debugging purposes
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from web3 import Web3
from web3.types import TxReceipt
from web3.types import TxReceipt, HexStr
from eth_typing import ChecksumAddress
from typing import Dict, Optional, Set
from src.config import ETHEREUM_NODE_URL
from src.config import NODE_URL
from src.constants import SETTLEMENT_CONTRACT_ADDRESS, NATIVE_ETH_TOKEN_ADDRESS
from contracts.erc20_abi import erc20_abi

# conducting sanity test only for ethereum mainnet transactions


class BalanceOfImbalances:
def __init__(self, ETHEREUM_NODE_URL: str):
self.web3 = Web3(Web3.HTTPProvider(ETHEREUM_NODE_URL))
def __init__(self, NODE_URL: str):
self.web3 = Web3(Web3.HTTPProvider(NODE_URL))

def get_token_balance(
self, token_address: str, account: str, block_identifier: int
self,
token_address: ChecksumAddress,
account: ChecksumAddress,
block_identifier: int,
) -> Optional[int]:
"""Retrieve the ERC-20 token balance of an account at a given block."""
token_contract = self.web3.eth.contract(address=token_address, abi=erc20_abi)
Expand All @@ -33,28 +29,30 @@ def get_token_balance(
print(f"Error fetching balance for token {token_address}: {e}")
return None

def get_eth_balance(self, account: str, block_identifier: int) -> Optional[int]:
def get_eth_balance(
self, account: ChecksumAddress, block_identifier: int
) -> Optional[int]:
"""Get the ETH balance for a given account and block number."""
try:
return self.web3.eth.get_balance(account, block_identifier=block_identifier)
except Exception as e:
print(f"Error fetching ETH balance: {e}")
return None

def extract_token_addresses(self, tx_receipt: Dict) -> Set[str]:
def extract_token_addresses(self, tx_receipt: TxReceipt) -> Set[ChecksumAddress]:
"""Extract unique token addresses from 'Transfer' events in a transaction receipt."""
token_addresses = set()
token_addresses: Set[ChecksumAddress] = set()
transfer_topics = {
self.web3.keccak(text="Transfer(address,address,uint256)").hex(),
self.web3.keccak(text="ERC20Transfer(address,address,uint256)").hex(),
self.web3.keccak(text="Withdrawal(address,uint256)").hex(),
}
for log in tx_receipt["logs"]:
if log["topics"][0].hex() in transfer_topics:
token_addresses.add(log["address"])
token_addresses.add(self.web3.to_checksum_address(log["address"]))
return token_addresses

def get_transaction_receipt(self, tx_hash: str) -> Optional[TxReceipt]:
def get_transaction_receipt(self, tx_hash: HexStr) -> Optional[TxReceipt]:
"""Fetch the transaction receipt for the given hash."""
try:
return self.web3.eth.get_transaction_receipt(tx_hash)
Expand All @@ -66,35 +64,45 @@ def get_balances(
self, token_addresses: Set[ChecksumAddress], block_number: int
) -> Dict[ChecksumAddress, Optional[int]]:
"""Get balances for all tokens at the given block number."""
balances = {}
balances[NATIVE_ETH_TOKEN_ADDRESS] = self.get_eth_balance(
SETTLEMENT_CONTRACT_ADDRESS, block_number
balances: Dict[ChecksumAddress, Optional[int]] = {}
balances[
self.web3.to_checksum_address(NATIVE_ETH_TOKEN_ADDRESS)
] = self.get_eth_balance(
self.web3.to_checksum_address(SETTLEMENT_CONTRACT_ADDRESS), block_number
)

for token_address in token_addresses:
balances[token_address] = self.get_token_balance(
token_address, SETTLEMENT_CONTRACT_ADDRESS, block_number
token_address,
self.web3.to_checksum_address(SETTLEMENT_CONTRACT_ADDRESS),
block_number,
)

return balances

def calculate_imbalances(
self,
prev_balances: Dict[str, Optional[int]],
final_balances: Dict[str, Optional[int]],
) -> Dict[str, int]:
prev_balances: Dict[ChecksumAddress, Optional[int]],
final_balances: Dict[ChecksumAddress, Optional[int]],
) -> Dict[ChecksumAddress, int]:
"""Calculate imbalances between previous and final balances."""
imbalances = {}
imbalances: Dict[ChecksumAddress, int] = {}
for token_address in prev_balances:
if (
prev_balances[token_address] is not None
and final_balances[token_address] is not None
):
imbalance = final_balances[token_address] - prev_balances[token_address]
# need to ensure prev_balance and final_balance contain values
# to prevent subtraction from None
prev_balance = prev_balances[token_address]
assert prev_balance is not None
final_balance = final_balances[token_address]
assert final_balance is not None
imbalance = final_balance - prev_balance
imbalances[token_address] = imbalance
return imbalances

def compute_imbalances(self, tx_hash: str) -> Dict[str, int]:
def compute_imbalances(self, tx_hash: HexStr) -> Dict[ChecksumAddress, int]:
"""Compute token imbalances before and after a transaction."""
tx_receipt = self.get_transaction_receipt(tx_hash)
if tx_receipt is None:
Expand All @@ -116,7 +124,7 @@ def compute_imbalances(self, tx_hash: str) -> Dict[str, int]:

def main():
tx_hash = input("Enter transaction hash: ")
bo = BalanceOfImbalances(ETHEREUM_NODE_URL)
bo = BalanceOfImbalances(NODE_URL)
imbalances = bo.compute_imbalances(tx_hash)
print("Token Imbalances:")
for token_address, imbalance in imbalances.items():
Expand Down
73 changes: 69 additions & 4 deletions src/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,75 @@
import os
from typing import Optional
from sqlalchemy import text
from sqlalchemy.exc import OperationalError
from sqlalchemy import create_engine, Engine
from dotenv import load_dotenv
from src.helper_functions import get_logger


load_dotenv()
ETHEREUM_NODE_URL = os.getenv("ETHEREUM_NODE_URL")
GNOSIS_NODE_URL = os.getenv("GNOSIS_NODE_URL")
NODE_URL = os.getenv("NODE_URL")

logger = get_logger("raw_token_imbalances")

# Utilized by imbalances_script for computing for single tx hash
CHAIN_RPC_ENDPOINTS = {
"Ethereum": os.getenv("ETHEREUM_NODE_URL"),
"Gnosis": os.getenv("GNOSIS_NODE_URL"),
}


def get_env_int(var_name: str) -> int:
"""
Function for safe conversion to int (prevents None -> int conversion issues raised by mypy)
Retrieve environment variable and convert to int. Raise an error if not set.
"""
value = os.getenv(var_name)
if value is None:
raise ValueError(f"Environment variable {var_name} is not set.")
try:
return int(value)
except ValueError:
raise ValueError(f"Environment variable {var_name} must be a int.")


CHAIN_SLEEP_TIME = get_env_int("CHAIN_SLEEP_TIME")


def create_backend_db_connection(chain_name: str) -> Engine:
"""function that creates a connection to the CoW db."""
read_db_url = os.getenv("DB_URL")

if not read_db_url:
raise ValueError(f"No database URL found for chain: {chain_name}")

return create_engine(f"postgresql+psycopg2://{read_db_url}")


def create_solver_slippage_db_connection() -> Engine:
"""function that creates a connection to the CoW db."""
solver_db_url = os.getenv("SOLVER_SLIPPAGE_DB_URL")
if not solver_db_url:
raise ValueError(
"Solver slippage database URL not found in environment variables."
)

return create_engine(f"postgresql+psycopg2://{solver_db_url}")

CHAIN_RPC_ENDPOINTS = {"Ethereum": ETHEREUM_NODE_URL, "Gnosis": GNOSIS_NODE_URL}

CHAIN_SLEEP_TIMES = {"Ethereum": 60, "Gnosis": 120}
def check_db_connection(connection: Engine, chain_name: Optional[str] = None) -> Engine:
"""
Check if the database connection is still active. If not, create a new one.
"""
try:
if connection:
with connection.connect() as conn: # Use connection.connect() to get a Connection object
conn.execute(text("SELECT 1"))
except OperationalError:
# if connection is closed, create new one
connection = (
create_backend_db_connection(chain_name)
if chain_name
else create_solver_slippage_db_connection()
)
return connection
1 change: 1 addition & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
""" Constants used for the token imbalances project """
from web3 import Web3

SETTLEMENT_CONTRACT_ADDRESS = Web3.to_checksum_address(
Expand Down
Loading

0 comments on commit bd08e86

Please sign in to comment.