diff --git a/app/admin/account_checking.py b/app/admin/account_checking.py index d8f8b68c..c579cf09 100644 --- a/app/admin/account_checking.py +++ b/app/admin/account_checking.py @@ -43,9 +43,10 @@ async def check_account_balance_consistency( This verifies that the total balance in each account matches the sum of all transactions for that account, properly accounting for credits and debits. - To ensure consistency during system operation, this function uses the maximum updated_at - timestamp from all accounts to limit transaction queries, ensuring that only transactions - created before or at the same time as the account snapshot are considered. + To ensure consistency during system operation, this function processes accounts in batches + and uses the maximum updated_at timestamp from each batch to limit transaction queries, + ensuring that only transactions created before or at the same time as the account snapshot + are considered. Args: session: Database session @@ -54,80 +55,115 @@ async def check_account_balance_consistency( List of checking results """ results = [] + batch_size = 1000 # Process 100 accounts at a time + offset = 0 + total_processed = 0 + + while True: + # Get accounts in batches using SQL pagination + query = ( + select(CreditAccountTable) + .order_by(CreditAccountTable.id) + .offset(offset) + .limit(batch_size) + ) + accounts_result = await session.execute(query) + batch_accounts = [ + CreditAccount.model_validate(acc) for acc in accounts_result.scalars().all() + ] - # Get all accounts - query = select(CreditAccountTable) - accounts_result = await session.execute(query) - accounts = [ - CreditAccount.model_validate(acc) for acc in accounts_result.scalars().all() - ] - - # Find the maximum updated_at timestamp from all accounts - # This represents the point in time when we took the snapshot of account balances - max_updated_at = ( - max([account.updated_at for account in accounts]) if accounts else None - ) - - if not max_updated_at: - return results + # If no more accounts to process, break the loop + if not batch_accounts: + break - for account in accounts: - # Sleep for 10ms to reduce database load - await asyncio.sleep(0.01) + # Update counters + batch_count = len(batch_accounts) + total_processed += batch_count + logger.info( + f"Processing account balance batch: {offset // batch_size + 1}, accounts: {batch_count}" + ) - # Calculate the total balance across all credit types - total_balance = account.free_credits + account.reward_credits + account.credits - - # Calculate the expected balance from all transactions, regardless of credit type - # Only include transactions created before or at the same time as the account snapshot - query = text(""" - SELECT - SUM(CASE WHEN credit_debit = 'credit' THEN change_amount ELSE 0 END) as credits, - SUM(CASE WHEN credit_debit = 'debit' THEN change_amount ELSE 0 END) as debits - FROM credit_transactions - WHERE account_id = :account_id - AND created_at <= :max_updated_at - """) - - tx_result = await session.execute( - query, {"account_id": account.id, "max_updated_at": max_updated_at} + # Find the maximum updated_at timestamp for this batch of accounts + # This represents the point in time when we took the snapshot of this batch of account balances + batch_max_updated_at = ( + max([account.updated_at for account in batch_accounts]) + if batch_accounts + else None ) - tx_data = tx_result.fetchone() - credits = tx_data.credits or Decimal("0") - debits = tx_data.debits or Decimal("0") - expected_balance = credits - debits + if not batch_max_updated_at: + offset += batch_size + continue - # Compare total balances - is_consistent = total_balance == expected_balance + # Process each account in the batch + for account in batch_accounts: + # Sleep for 10ms to reduce database load + await asyncio.sleep(0.01) - result = AccountCheckingResult( - check_type="account_total_balance", - status=is_consistent, - details={ - "account_id": account.id, - "owner_type": account.owner_type, - "owner_id": account.owner_id, - "current_total_balance": float(total_balance), - "free_credits": float(account.free_credits), - "reward_credits": float(account.reward_credits), - "credits": float(account.credits), - "expected_balance": float(expected_balance), - "total_credits": float(credits), - "total_debits": float(debits), - "difference": float(total_balance - expected_balance), - "max_updated_at": max_updated_at.isoformat() - if max_updated_at - else None, - }, - ) - results.append(result) + # Calculate the total balance across all credit types + total_balance = ( + account.free_credits + account.reward_credits + account.credits + ) - if not is_consistent: - logger.warning( - f"Account total balance inconsistency detected: {account.id} ({account.owner_type}:{account.owner_id}) " - f"Current total: {total_balance}, Expected: {expected_balance}" + # Calculate the expected balance from all transactions, regardless of credit type + # Only include transactions created before or at the same time as the account snapshot + query = text(""" + SELECT + SUM(CASE WHEN credit_debit = 'credit' THEN change_amount ELSE 0 END) as credits, + SUM(CASE WHEN credit_debit = 'debit' THEN change_amount ELSE 0 END) as debits + FROM credit_transactions + WHERE account_id = :account_id + AND created_at <= :max_updated_at + """) + + tx_result = await session.execute( + query, + {"account_id": account.id, "max_updated_at": batch_max_updated_at}, ) + tx_data = tx_result.fetchone() + + credits = tx_data.credits or Decimal("0") + debits = tx_data.debits or Decimal("0") + expected_balance = credits - debits + + # Compare total balances + is_consistent = total_balance == expected_balance + + result = AccountCheckingResult( + check_type="account_total_balance", + status=is_consistent, + details={ + "account_id": account.id, + "owner_type": account.owner_type, + "owner_id": account.owner_id, + "current_total_balance": float(total_balance), + "free_credits": float(account.free_credits), + "reward_credits": float(account.reward_credits), + "credits": float(account.credits), + "expected_balance": float(expected_balance), + "total_credits": float(credits), + "total_debits": float(debits), + "difference": float(total_balance - expected_balance), + "max_updated_at": batch_max_updated_at.isoformat() + if batch_max_updated_at + else None, + "batch": offset // batch_size + 1, + }, + ) + results.append(result) + + if not is_consistent: + logger.warning( + f"Account total balance inconsistency detected: {account.id} ({account.owner_type}:{account.owner_id}) " + f"Current total: {total_balance}, Expected: {expected_balance}" + ) + + # Move to the next batch + offset += batch_size + + logger.info( + f"Completed account balance consistency check: processed {total_processed} accounts in {offset // batch_size} batches" + ) return results @@ -138,6 +174,7 @@ async def check_transaction_balance( """Check if all credit events have balanced transactions. For each credit event, the sum of all credit transactions should equal the sum of all debit transactions. + Events are processed in batches to prevent memory overflow issues. Args: session: Database session @@ -146,62 +183,92 @@ async def check_transaction_balance( List of checking results """ results = [] + batch_size = 1000 # Process 1000 events at a time + offset = 0 + total_processed = 0 - # Get all events from the last 3 days (limit to recent events for performance) + # Time window for events (last 3 days for performance) three_days_ago = datetime.now(timezone.utc) - timedelta(days=3) - query = select(CreditEventTable).where( - CreditEventTable.created_at >= three_days_ago - ) - events_result = await session.execute(query) - events = [ - CreditEvent.model_validate(event) for event in events_result.scalars().all() - ] - for event in events: - # Sleep for 10ms to reduce database load - await asyncio.sleep(0.01) - - # Get all transactions for this event - tx_query = select(CreditTransactionTable).where( - CreditTransactionTable.event_id == event.id + while True: + # Get events in batches using SQL pagination + query = ( + select(CreditEventTable) + .where(CreditEventTable.created_at >= three_days_ago) + .order_by(CreditEventTable.id) + .offset(offset) + .limit(batch_size) ) - tx_result = await session.execute(tx_query) - transactions = [ - CreditTransaction.model_validate(tx) for tx in tx_result.scalars().all() + events_result = await session.execute(query) + batch_events = [ + CreditEvent.model_validate(event) for event in events_result.scalars().all() ] - # Calculate credit and debit sums - credit_sum = sum( - tx.change_amount for tx in transactions if tx.credit_debit == "credit" - ) - debit_sum = sum( - tx.change_amount for tx in transactions if tx.credit_debit == "debit" + # If no more events to process, break the loop + if not batch_events: + break + + # Update counters + batch_count = len(batch_events) + total_processed += batch_count + logger.info( + f"Processing transaction balance batch: {offset // batch_size + 1}, events: {batch_count}" ) - # Check if they balance - is_balanced = credit_sum == debit_sum + # Process each event in the batch + for event in batch_events: + # Sleep for 10ms to reduce database load + await asyncio.sleep(0.01) - result = AccountCheckingResult( - check_type="transaction_balance", - status=is_balanced, - details={ - "event_id": event.id, - "event_type": event.event_type, - "credit_sum": float(credit_sum), - "debit_sum": float(debit_sum), - "difference": float(credit_sum - debit_sum), - "created_at": event.created_at.isoformat() - if event.created_at - else None, - }, - ) - results.append(result) + # Get all transactions for this event + tx_query = select(CreditTransactionTable).where( + CreditTransactionTable.event_id == event.id + ) + tx_result = await session.execute(tx_query) + transactions = [ + CreditTransaction.model_validate(tx) for tx in tx_result.scalars().all() + ] + + # Calculate credit and debit sums + credit_sum = sum( + tx.change_amount for tx in transactions if tx.credit_debit == "credit" + ) + debit_sum = sum( + tx.change_amount for tx in transactions if tx.credit_debit == "debit" + ) - if not is_balanced: - logger.warning( - f"Transaction imbalance detected for event {event.id} ({event.event_type}). " - f"Credit: {credit_sum}, Debit: {debit_sum}" + # Check if they balance + is_balanced = credit_sum == debit_sum + + result = AccountCheckingResult( + check_type="transaction_balance", + status=is_balanced, + details={ + "event_id": event.id, + "event_type": event.event_type, + "credit_sum": float(credit_sum), + "debit_sum": float(debit_sum), + "difference": float(credit_sum - debit_sum), + "created_at": event.created_at.isoformat() + if event.created_at + else None, + "batch": offset // batch_size + 1, + }, ) + results.append(result) + + if not is_balanced: + logger.warning( + f"Transaction imbalance detected for event {event.id} ({event.event_type}). " + f"Credit: {credit_sum}, Debit: {debit_sum}" + ) + + # Move to the next batch + offset += batch_size + + logger.info( + f"Completed transaction balance check: processed {total_processed} events in {offset // batch_size} batches" + ) return results @@ -452,18 +519,19 @@ async def check_transaction_total_balance( return [result] -async def run_all_checks() -> Dict[str, List[AccountCheckingResult]]: - """Run all account checking procedures and return results. +async def run_quick_checks() -> Dict[str, List[AccountCheckingResult]]: + """Run quick account checking procedures and return results. + + These checks are designed to be fast and can be run frequently. Returns: Dictionary mapping check names to their results """ - logger.info("Starting account checking procedures") + logger.info("Starting quick account checking procedures") results = {} async with get_session() as session: - # Run all checks - results["account_balance"] = await check_account_balance_consistency(session) + # Run quick checks results["transaction_balance"] = await check_transaction_balance(session) results["orphaned_transactions"] = await check_orphaned_transactions(session) results["orphaned_events"] = await check_orphaned_events(session) @@ -488,10 +556,92 @@ async def run_all_checks() -> Dict[str, List[AccountCheckingResult]]: logger.info(f"{check_name}: All {len(check_results)} checks passed") if all_passed: - logger.info("All account checks passed successfully") + logger.info("All quick account checks passed successfully") + else: + logger.warning( + f"Quick account checking summary: {failed_count} checks failed - see logs for details" + ) + + # Send summary to Slack + from utils.slack_alert import send_slack_message + + # Create a summary message with color based on status + total_checks = sum(len(check_results) for check_results in results.values()) + + if all_passed: + color = "good" # Green color + title = "✅ Quick Account Checking Completed Successfully" + text = f"All {total_checks} quick account checks passed successfully." + notify = "" # No notification needed for success + else: + color = "danger" # Red color + title = "❌ Quick Account Checking Found Issues" + text = f"Quick account checking found {failed_count} issues out of {total_checks} checks." + notify = " " # Notify channel for failures + + # Create attachments with check details + attachments = [{"color": color, "title": title, "text": text, "fields": []}] + + # Add fields for each check type + for check_name, check_results in results.items(): + check_failed_count = sum(1 for result in check_results if not result.status) + check_status = ( + "✅ Passed" + if check_failed_count == 0 + else f"❌ Failed ({check_failed_count} issues)" + ) + + attachments[0]["fields"].append( + { + "title": check_name.replace("_", " ").title(), + "value": check_status, + "short": True, + } + ) + + # Send the message + send_slack_message( + message=f"{notify}Quick Account Checking Results", attachments=attachments + ) + + return results + + +async def run_slow_checks() -> Dict[str, List[AccountCheckingResult]]: + """Run slow account checking procedures and return results. + + These checks are more resource-intensive and should be run less frequently. + + Returns: + Dictionary mapping check names to their results + """ + logger.info("Starting slow account checking procedures") + + results = {} + async with get_session() as session: + # Run slow checks + results["account_balance"] = await check_account_balance_consistency(session) + + # Log summary + all_passed = True + failed_count = 0 + for check_name, check_results in results.items(): + check_failed_count = sum(1 for result in check_results if not result.status) + failed_count += check_failed_count + + if check_failed_count > 0: + logger.warning( + f"{check_name}: {check_failed_count} of {len(check_results)} checks failed" + ) + all_passed = False + else: + logger.info(f"{check_name}: All {len(check_results)} checks passed") + + if all_passed: + logger.info("All slow account checks passed successfully") else: logger.warning( - f"Account checking summary: {failed_count} checks failed - see logs for details" + f"Slow account checking summary: {failed_count} checks failed - see logs for details" ) # Send summary to Slack @@ -502,13 +652,13 @@ async def run_all_checks() -> Dict[str, List[AccountCheckingResult]]: if all_passed: color = "good" # Green color - title = "✅ Account Checking Completed Successfully" - text = f"All {total_checks} account checks passed successfully." + title = "✅ Slow Account Checking Completed Successfully" + text = f"All {total_checks} slow account checks passed successfully." notify = "" # No notification needed for success else: color = "danger" # Red color - title = "❌ Account Checking Found Issues" - text = f"Account checking found {failed_count} issues out of {total_checks} checks." + title = "❌ Slow Account Checking Found Issues" + text = f"Slow account checking found {failed_count} issues out of {total_checks} checks." notify = " " # Notify channel for failures # Create attachments with check details @@ -533,17 +683,17 @@ async def run_all_checks() -> Dict[str, List[AccountCheckingResult]]: # Send the message send_slack_message( - message=f"{notify}Account Checking Results", attachments=attachments + message=f"{notify}Slow Account Checking Results", attachments=attachments ) return results async def main(): - await init_db(**config.db) """Main entry point for running account checks.""" + await init_db(**config.db) logger.info("Starting account checking procedures") - results = await run_all_checks() + results = await run_quick_checks() logger.info("Completed account checking procedures") return results diff --git a/app/admin/api.py b/app/admin/api.py index 03757f59..33411937 100644 --- a/app/admin/api.py +++ b/app/admin/api.py @@ -556,6 +556,9 @@ async def override_agent( if subject: agent.owner = subject + if not agent.owner: + raise HTTPException(status_code=500, detail="Owner is required") + # Update agent latest_agent = await agent.override(agent_id) diff --git a/app/admin/scheduler.py b/app/admin/scheduler.py index 63c7256d..45d6bcff 100644 --- a/app/admin/scheduler.py +++ b/app/admin/scheduler.py @@ -8,7 +8,6 @@ from apscheduler.triggers.cron import CronTrigger from sqlalchemy import update -from app.admin.account_checking import run_all_checks from app.config.config import config from app.core.credit import refill_all_free_credits from app.services.twitter.oauth2_refresh import refresh_expiring_tokens @@ -42,20 +41,6 @@ async def reset_monthly_quotas(): await session.commit() -async def run_account_checks(): - """Run all account consistency checks and send results to Slack. - - This checks account balances, transactions, and other credit-related consistency - issues and reports the results to the configured Slack channel. - """ - logger.info("Running scheduled account consistency checks") - try: - await run_all_checks() - logger.info("Completed account consistency checks") - except Exception as e: - logger.error(f"Error running account consistency checks: {e}") - - def create_scheduler(): """Create and configure the APScheduler with all periodic tasks.""" # Job Store @@ -107,17 +92,6 @@ def create_scheduler(): replace_existing=True, ) - # Run account consistency checks every 2 hours at the top of the hour - scheduler.add_job( - run_account_checks, - trigger=CronTrigger( - hour="*/2", minute="0", timezone="UTC" - ), # Run every 2 hours at :00 - id="account_consistency_checks", - name="Account Consistency Checks", - replace_existing=True, - ) - return scheduler diff --git a/app/checker.py b/app/checker.py new file mode 100644 index 00000000..767babab --- /dev/null +++ b/app/checker.py @@ -0,0 +1,149 @@ +"""Checker for periodic read-only validation tasks. + +This module runs a separate scheduler for account checks and other validation +tasks that only require read-only database access. +""" + +import asyncio +import logging +import signal +import sys + +import sentry_sdk +from apscheduler.jobstores.redis import RedisJobStore +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.cron import CronTrigger + +from app.admin.account_checking import run_quick_checks, run_slow_checks +from app.config.config import config +from models.db import init_db +from models.redis import init_redis + +logger = logging.getLogger(__name__) + +if config.sentry_dsn: + sentry_sdk.init( + dsn=config.sentry_dsn, + sample_rate=config.sentry_sample_rate, + traces_sample_rate=config.sentry_traces_sample_rate, + profiles_sample_rate=config.sentry_profiles_sample_rate, + environment=config.env, + release=config.release, + server_name="intent-checker", + ) + + +async def run_quick_account_checks(): + """Run quick account consistency checks and send results to Slack. + + This runs the faster checks for account balances, transactions, and other credit-related consistency + issues and reports the results to the configured Slack channel. + """ + logger.info("Running scheduled quick account consistency checks") + try: + await run_quick_checks() + logger.info("Completed quick account consistency checks") + except Exception as e: + logger.error(f"Error running quick account consistency checks: {e}") + + +async def run_slow_account_checks(): + """Run slow account consistency checks and send results to Slack. + + This runs the more resource-intensive checks for account balances, transactions, + and other credit-related consistency issues and reports the results to the configured Slack channel. + """ + logger.info("Running scheduled slow account consistency checks") + try: + await run_slow_checks() + logger.info("Completed slow account consistency checks") + except Exception as e: + logger.error(f"Error running slow account consistency checks: {e}") + + +def create_checker(): + """Create and configure the AsyncIOScheduler for validation checks.""" + # Job Store + jobstores = {} + if config.redis_host: + jobstores["default"] = RedisJobStore( + host=config.redis_host, + port=config.redis_port, + jobs_key="intentkit:checker:jobs", + run_times_key="intentkit:checker:run_times", + ) + logger.info(f"checker using redis store: {config.redis_host}") + + scheduler = AsyncIOScheduler(jobstores=jobstores) + + # Run quick account consistency checks every 2 hours at the top of the hour + scheduler.add_job( + run_quick_account_checks, + trigger=CronTrigger( + hour="*/2", minute="10", timezone="UTC" + ), # Run every 2 hours at :10 + id="quick_account_checks", + name="Quick Account Consistency Checks", + replace_existing=True, + ) + + # Run slow account consistency checks once a day at midnight UTC + scheduler.add_job( + run_slow_account_checks, + trigger=CronTrigger( + hour="0,12", minute="0", timezone="UTC" + ), # Run at midnight UTC + id="slow_account_checks", + name="Slow Account Consistency Checks", + replace_existing=True, + ) + + return scheduler + + +def start_checker(): + """Create, configure and start the checker scheduler.""" + scheduler = create_checker() + scheduler.start() + return scheduler + + +if __name__ == "__main__": + + async def main(): + # Initialize database + await init_db(**config.db) + + # Initialize Redis if configured + if config.redis_host: + await init_redis( + host=config.redis_host, + port=config.redis_port, + ) + + # Initialize checker + scheduler = create_checker() + + # Signal handler for graceful shutdown + def signal_handler(signum, frame): + logger.info("Received termination signal. Shutting down gracefully...") + scheduler.shutdown() + sys.exit(0) + + # Register signal handlers + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + logger.info("Starting checker process...") + scheduler.start() + # Keep the main thread running + while True: + await asyncio.sleep(1) + except Exception as e: + logger.error(f"Error in checker process: {e}") + scheduler.shutdown() + sys.exit(1) + + # Run the async main function + asyncio.run(main())