Skip to content

Commit

Permalink
Merge pull request #8 from Zipstack/enhance-report-details
Browse files Browse the repository at this point in the history
Added total cost and tokens for embedings and extraction LLM in detailed report
  • Loading branch information
ritwik-g authored Nov 15, 2024
2 parents d05e5c4 + 58790de commit 0687d6e
Showing 1 changed file with 86 additions and 9 deletions.
95 changes: 86 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sqlite3
import sys
import time
import textwrap
from dataclasses import dataclass
from datetime import datetime
from functools import partial
Expand Down Expand Up @@ -51,6 +52,10 @@ def init_db():
time_taken REAL,
status_code INTEGER,
status_api_endpoint TEXT,
total_embedding_cost REAL,
total_embedding_tokens INTEGER,
total_llm_cost REAL,
total_llm_tokens INTEGER,
updated_at TEXT,
created_at TEXT
)"""
Expand Down Expand Up @@ -97,6 +102,15 @@ def update_db(
status_code,
status_api_endpoint,
):

total_embedding_cost = None
total_embedding_tokens = None
total_llm_cost = None
total_llm_tokens = None

if result is not None:
total_embedding_cost, total_llm_cost, total_embedding_tokens, total_llm_tokens = calculate_cost_and_tokens(result)

conn = sqlite3.connect(DB_NAME)
conn.set_trace_callback(
lambda x: (
Expand All @@ -109,16 +123,20 @@ def update_db(
now = datetime.now().isoformat()
c.execute(
"""
INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, updated_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
""",
INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, updated_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
""",
(
file_name,
execution_status,
json.dumps(result),
time_taken,
status_code,
status_api_endpoint,
total_embedding_cost,
total_embedding_tokens,
total_llm_cost,
total_llm_tokens,
now,
file_name,
now,
Expand All @@ -127,10 +145,57 @@ def update_db(
conn.commit()
conn.close()

# Calculate total cost and tokens for detailed report
def calculate_cost_and_tokens(result):

total_embedding_cost = None
total_embedding_tokens = None
total_llm_cost = None
total_llm_tokens = None

# Extract 'extraction_result' from the result
extraction_result = result.get("extraction_result", [])

if not extraction_result:
return total_embedding_cost, total_llm_cost, total_embedding_tokens, total_llm_tokens

extraction_data = extraction_result[0].get("result", "")

# If extraction_data is a string, attempt to parse it as JSON
if isinstance(extraction_data, str):
try:
extraction_data = json.loads(extraction_data) if extraction_data else {}
except json.JSONDecodeError:
logger.warning("Failed to decode JSON for extraction data; defaulting to empty dictionary.")
extraction_data = {}


metadata = extraction_data.get("metadata", None)
embedding_llm = metadata.get("embedding") if metadata else None
extraction_llm = metadata.get("extraction_llm") if metadata else None

#Process embedding costs and tokens if embedding_llm list exists and is not empty
if embedding_llm:
total_embedding_cost = 0.0
total_embedding_tokens = 0
for item in embedding_llm:
total_embedding_cost += float(item.get("cost_in_dollars", "0"))
total_embedding_tokens += item.get("embedding_tokens", 0)

#Process embedding costs and tokens if extraction_llm list exists and is not empty
if extraction_llm:
total_llm_cost = 0.0
total_llm_tokens = 0
for item in extraction_llm:
total_llm_cost += float(item.get("cost_in_dollars", "0"))
total_llm_tokens += item.get("total_tokens", 0)

return total_embedding_cost, total_llm_cost, total_embedding_tokens, total_llm_tokens


# Print final summary with count of each status and average time using a single SQL query
def print_summary():
conn = sqlite3.connect("file_processing.db")
conn = sqlite3.connect(DB_NAME)
c = conn.cursor()

# Fetch count and average time for each status
Expand All @@ -153,13 +218,13 @@ def print_summary():


def print_report():
conn = sqlite3.connect("file_processing.db")
conn = sqlite3.connect(DB_NAME)
c = conn.cursor()

# Fetch count and average time for each status
# Fetch required fields, including total_cost and total_tokens
c.execute(
"""
SELECT file_name, execution_status, time_taken
SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens
FROM file_status
"""
)
Expand All @@ -170,8 +235,20 @@ def print_report():
print("\nDetailed Report:")
if report_data:
# Tabulate the data with column headers
headers = ["File Name", "Execution Status", "Time Elapsed (seconds)"]
print(tabulate(report_data, headers=headers, tablefmt="pretty"))
headers = ["File Name", "Execution Status", "Time Elapsed (seconds)", "Total Embedding Cost", "Total Embedding Tokens", "Total LLM Cost", "Total LLM Tokens"]

# Wrap text in each column to a specific width (e.g., 30 characters for file names and 20 for others) and return None if the value is NULL
formatted_data = []
for row in report_data:
formatted_row = [
"None" if cell is None else
textwrap.fill(str(cell), width=30) if isinstance(cell, str) else
f"{cell:.8f}" if isinstance(cell, float) else cell
for cell in row
]
formatted_data.append(formatted_row)

print(tabulate(formatted_data, headers=headers, tablefmt="pretty"))
else:
print("No records found in the database.")

Expand Down

0 comments on commit 0687d6e

Please sign in to comment.