Skip to content

Commit

Permalink
stdout, no threads
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhagarwal03 committed Jul 10, 2024
1 parent c12ab0c commit dc64020
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 51 deletions.
24 changes: 12 additions & 12 deletions .env.sample
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
# .env.sample

# URLs for DB connection
ETHEREUM_DB_URL=
GNOSIS_DB_URL=
# DB connection
DB_URL=

# URLs for Node provider connection
ETHEREUM_NODE_URL=
GNOSIS_NODE_URL=
# Node provider connection
NODE_URL=

# add credentials for connecting to solver slippage DB based on this format
SOLVER_SLIPPAGE_DB_URL=postgresql://username:password@hostname:port/database
# connecting to Solver Slippage DB
SOLVER_SLIPPAGE_DB_URL=

# configure chain sleep time
ETHEREUM_SLEEP_TIME=
GNOSIS_SLEEP_TIME=
# configure chain sleep time, e.g. CHAIN_SLEEP_TIME=60
CHAIN_SLEEP_TIME=

# add chain name, e.g. CHAIN_NAME=Ethereum
CHAIN_NAME=

# optional
INFURA_KEY=infura_key_here
INFURA_KEY=
10 changes: 5 additions & 5 deletions src/balanceof_imbalances.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from web3 import Web3
from web3.types import TxReceipt, HexStr
from eth_typing import ChecksumAddress
from typing import Dict, Optional, Set, Any
from src.config import ETHEREUM_NODE_URL
from typing import Dict, Optional, Set
from src.config import NODE_URL
from src.constants import SETTLEMENT_CONTRACT_ADDRESS, NATIVE_ETH_TOKEN_ADDRESS
from contracts.erc20_abi import erc20_abi

# conducting sanity test only for ethereum mainnet transactions


class BalanceOfImbalances:
def __init__(self, ETHEREUM_NODE_URL: str):
self.web3 = Web3(Web3.HTTPProvider(ETHEREUM_NODE_URL))
def __init__(self, NODE_URL: str):
self.web3 = Web3(Web3.HTTPProvider(NODE_URL))

def get_token_balance(
self,
Expand Down Expand Up @@ -124,7 +124,7 @@ def compute_imbalances(self, tx_hash: HexStr) -> Dict[ChecksumAddress, int]:

def main():
tx_hash = input("Enter transaction hash: ")
bo = BalanceOfImbalances(ETHEREUM_NODE_URL)
bo = BalanceOfImbalances(NODE_URL)
imbalances = bo.compute_imbalances(tx_hash)
print("Token Imbalances:")
for token_address, imbalance in imbalances.items():
Expand Down
21 changes: 9 additions & 12 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@


load_dotenv()
ETHEREUM_NODE_URL = os.getenv("ETHEREUM_NODE_URL")
GNOSIS_NODE_URL = os.getenv("GNOSIS_NODE_URL")

CHAIN_RPC_ENDPOINTS = {"Ethereum": ETHEREUM_NODE_URL, "Gnosis": GNOSIS_NODE_URL}
NODE_URL = os.getenv("NODE_URL")

logger = get_logger("raw_token_imbalances")

# Utilized by imbalances_script for computing for single tx hash
CHAIN_RPC_ENDPOINTS = {
"Ethereum": os.getenv("ETHEREUM_NODE_URL"),
"Gnosis": os.getenv("GNOSIS_NODE_URL"),
}


def get_env_int(var_name: str) -> int:
"""
Expand All @@ -30,18 +33,12 @@ def get_env_int(var_name: str) -> int:
raise ValueError(f"Environment variable {var_name} must be a int.")


CHAIN_SLEEP_TIMES = {
"Ethereum": get_env_int("ETHEREUM_SLEEP_TIME"),
"Gnosis": get_env_int("GNOSIS_SLEEP_TIME"),
}
CHAIN_SLEEP_TIME = get_env_int("CHAIN_SLEEP_TIME")


def create_backend_db_connection(chain_name: str) -> Engine:
"""function that creates a connection to the CoW db."""
if chain_name == "Ethereum":
read_db_url = os.getenv("ETHEREUM_DB_URL")
elif chain_name == "Gnosis":
read_db_url = os.getenv("GNOSIS_DB_URL")
read_db_url = os.getenv("DB_URL")

if not read_db_url:
raise ValueError(f"No database URL found for chain: {chain_name}")
Expand Down
34 changes: 14 additions & 20 deletions src/daemon.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,29 @@
"""
Running this daemon computes raw imbalances for finalized blocks by calling imbalances_script.py.
"""

import os
import time
from typing import List, Tuple
from threading import Thread
import pandas as pd
from web3 import Web3
from sqlalchemy import text
from sqlalchemy.engine import Engine
from src.imbalances_script import RawTokenImbalances
from src.config import (
CHAIN_RPC_ENDPOINTS,
CHAIN_SLEEP_TIMES,
CHAIN_SLEEP_TIME,
NODE_URL,
create_backend_db_connection,
create_solver_slippage_db_connection,
check_db_connection,
logger,
)


def get_web3_instance(chain_name: str) -> Web3:
def get_web3_instance() -> Web3:
"""
returns a Web3 instance for the given blockchain via chain name.
"""
return Web3(Web3.HTTPProvider(CHAIN_RPC_ENDPOINTS[chain_name]))
return Web3(Web3.HTTPProvider(NODE_URL))


def get_finalized_block_number(web3: Web3) -> int:
Expand Down Expand Up @@ -191,9 +190,8 @@ def process_transactions(chain_name: str) -> None:
"""
Process transactions to compute imbalances for a given blockchain via chain name.
"""
web3 = get_web3_instance(chain_name)
web3 = get_web3_instance()
rt = RawTokenImbalances(web3, chain_name)
sleep_time = CHAIN_SLEEP_TIMES.get(chain_name)
backend_db_connection = create_backend_db_connection(chain_name)
solver_slippage_db_connection = create_solver_slippage_db_connection()
start_block = get_start_block(chain_name, solver_slippage_db_connection, web3)
Expand Down Expand Up @@ -244,23 +242,19 @@ def process_transactions(chain_name: str) -> None:
)
except Exception as e:
logger.error("Error processing transactions on %s: %s", chain_name, e)
if sleep_time is not None:
time.sleep(sleep_time)
if CHAIN_SLEEP_TIME is not None:
time.sleep(CHAIN_SLEEP_TIME)


def main() -> None:
"""
Main function to start the daemon threads for each blockchain.
Main function to start the daemon for a blockchain.
"""
threads = []

for chain_name in CHAIN_RPC_ENDPOINTS.keys():
thread = Thread(target=process_transactions, args=(chain_name,), daemon=True)
thread.start()
threads.append(thread)

for thread in threads:
thread.join()
chain_name = os.getenv("CHAIN_NAME")
if chain_name is None:
logger.error("CHAIN_NAME environment variable is not set.")
return
process_transactions(chain_name)


if __name__ == "__main__":
Expand Down
28 changes: 26 additions & 2 deletions src/helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This file contains some auxiliary functions
"""
from __future__ import annotations
import sys
import logging
from typing import Optional

Expand All @@ -10,13 +11,36 @@ def get_logger(filename: Optional[str] = None) -> logging.Logger:
"""
get_logger() returns a logger object that can write to a file, terminal or only file if needed.
"""
logging.basicConfig(format="%(levelname)s - %(message)s")
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Clear any existing handlers to avoid duplicate logs
if logger.hasHandlers():
logger.handlers.clear()

# Create formatter
formatter = logging.Formatter("%(levelname)s - %(message)s")

# Handler for stdout (INFO and lower)
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.INFO)
stdout_handler.setFormatter(formatter)

# ERROR and above logs will not be logged to stdout
stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)

# Handler for stderr (ERROR and higher)
stderr_handler = logging.StreamHandler(sys.stderr)
stderr_handler.setLevel(logging.ERROR)
stderr_handler.setFormatter(formatter)

# Add handlers to the logger
logger.addHandler(stdout_handler)
logger.addHandler(stderr_handler)

if filename:
file_handler = logging.FileHandler(filename + ".log", mode="w")
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(levelname)s - %(message)s")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

Expand Down

0 comments on commit dc64020

Please sign in to comment.