Skip to content

Commit

Permalink
Merge pull request #11 from Zipstack/csv-report
Browse files Browse the repository at this point in the history
Feat!: Enhance Reporting with CSV Export and Add Error Message to Reports
  • Loading branch information
ritwik-g authored Nov 26, 2024
2 parents 582b0d1 + cc8d846 commit 073a375
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 10 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ The script uses a local SQLite database (`file_processing.db`) with the followin
- `total_embedding_tokens` (INTEGER): Total tokens used for embeddings.
- `total_llm_cost` (REAL): Total cost incurred for LLM operations.
- `total_llm_tokens` (INTEGER): Total tokens used for LLM operations.
- `error_message` (TEXT): Details of errors if `execution_status` is `ERROR`; otherwise NULL.
- `updated_at` (TEXT): Last updated timestamp.
- `created_at` (TEXT): Creation timestamp.

Expand Down Expand Up @@ -66,6 +67,7 @@ This will display detailed usage information.
- `--print_report`: Print a detailed report of all processed files at the end.
- `--exclude_metadata`: Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file.
- `--no_verify`: Disable SSL certificate verification. (By default, SSL verification is enabled.)
- `--csv_report`: Path to export the detailed report as a CSV file.

## Usage Examples

Expand Down
95 changes: 85 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import time
import textwrap
import csv
from dataclasses import dataclass
from datetime import datetime
from functools import partial
Expand Down Expand Up @@ -35,6 +36,7 @@ class Arguments:
skip_unprocessed: bool = False
log_level: str = "INFO"
print_report: bool = False
csv_report: str = ""
include_metadata: bool = True
verify: bool = True

Expand All @@ -58,6 +60,7 @@ def init_db():
total_embedding_tokens INTEGER,
total_llm_cost REAL,
total_llm_tokens INTEGER,
error_message TEXT,
updated_at TEXT,
created_at TEXT
)"""
Expand All @@ -73,6 +76,7 @@ def init_db():
"total_embedding_tokens": "INTEGER",
"total_llm_cost": "REAL",
"total_llm_tokens": "INTEGER",
"error_message": "TEXT",
}

# Add missing columns
Expand Down Expand Up @@ -126,10 +130,14 @@ def update_db(
total_embedding_tokens = None
total_llm_cost = None
total_llm_tokens = None
error_message = None

if result is not None:
total_embedding_cost, total_llm_cost, total_embedding_tokens, total_llm_tokens = calculate_cost_and_tokens(result)

if execution_status == "ERROR":
error_message = extract_error_message(result)

conn = sqlite3.connect(DB_NAME)
conn.set_trace_callback(
lambda x: (
Expand All @@ -142,8 +150,8 @@ def update_db(
now = datetime.now().isoformat()
c.execute(
"""
INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, updated_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
INSERT OR REPLACE INTO file_status (file_name, execution_status, result, time_taken, status_code, status_api_endpoint, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message, updated_at, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, COALESCE((SELECT created_at FROM file_status WHERE file_name = ?), ?))
""",
(
file_name,
Expand All @@ -156,6 +164,7 @@ def update_db(
total_embedding_tokens,
total_llm_cost,
total_llm_tokens,
error_message,
now,
file_name,
now,
Expand Down Expand Up @@ -211,6 +220,17 @@ def calculate_cost_and_tokens(result):

return total_embedding_cost, total_llm_cost, total_embedding_tokens, total_llm_tokens

# Extract error message from the result JSON
def extract_error_message(result):
    """Return the first error message found in an execution result.

    Args:
        result: The execution result as a JSON document (``str``/``bytes``)
            or as an already-parsed ``dict``. ``None``, empty, or malformed
            input is tolerated and yields the fallback message.

    Returns:
        The first truthy ``"error"`` value among the dict entries of the
        ``"extraction_result"`` list; otherwise the top-level ``"error"``
        value when that key is present; otherwise the string
        ``"No error message found"``.
    """
    fallback = "No error message found"

    if isinstance(result, dict):
        result_data = result
    else:
        try:
            result_data = json.loads(result)
        except (TypeError, ValueError):
            # None, a non-string object, or malformed JSON (JSONDecodeError
            # is a ValueError): there is nothing to inspect.
            return fallback

    # Valid JSON may still be a list/number/string, which has no .get().
    if not isinstance(result_data, dict):
        return fallback

    # Check for error in extraction_result
    extraction_result = result_data.get("extraction_result", [])
    if isinstance(extraction_result, list):
        for item in extraction_result:
            # Skip non-dict entries: indexing them with "error" would raise.
            if isinstance(item, dict) and item.get("error"):
                return item["error"]

    # Fallback to the parent error
    return result_data.get("error", fallback)

# Print final summary with count of each status and average time using a single SQL query
def print_summary():
Expand Down Expand Up @@ -243,7 +263,7 @@ def print_report():
# Fetch required fields, including total_cost and total_tokens
c.execute(
"""
SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens
SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
FROM file_status
"""
)
Expand All @@ -254,23 +274,69 @@ def print_report():
print("\nDetailed Report:")
if report_data:
# Tabulate the data with column headers
headers = ["File Name", "Execution Status", "Time Elapsed (seconds)", "Total Embedding Cost", "Total Embedding Tokens", "Total LLM Cost", "Total LLM Tokens"]
headers = [
textwrap.fill(header, width=20)
for header in [
"File Name",
"Execution Status",
"Time Elapsed (seconds)",
"Total Embedding Cost",
"Total Embedding Tokens",
"Total LLM Cost",
"Total LLM Tokens",
"Error Message"
]
]


# Wrap text in each column to a specific width (e.g., 30 characters for file names and 20 for others) and return None if the value is NULL
formatted_data = []
# Wrap text in each column to a specific width (e.g., 30 characters for file names and 20 for others) and return None if the value is NULL
for row in report_data:
formatted_row = [
"None" if cell is None else
textwrap.fill(str(cell), width=30) if isinstance(cell, str) else
f"{cell:.8f}" if isinstance(cell, float) else cell
for cell in row
]
formatted_data.append(formatted_row)
textwrap.fill(str(cell), width=30) if isinstance(cell, str) else
cell if idx == 2 else f"{cell:.8f}" if isinstance(cell, float) else cell
for idx, cell in enumerate(row)
]
formatted_data.append(formatted_row)

print(tabulate(formatted_data, headers=headers, tablefmt="pretty"))
else:
print("No records found in the database.")

def export_report_to_csv(output_path, db_name=None):
    """Export the per-file report from the status database to a CSV file.

    Reads every row of the ``file_status`` table and writes it, preceded by
    a header row, to ``output_path``. Prints a message and writes nothing
    when the table has no rows. File-writing failures are reported on
    stdout rather than raised.

    Args:
        output_path: Destination path of the CSV file (overwritten if it
            already exists).
        db_name: Path of the SQLite database to read. Defaults to the
            module-level ``DB_NAME`` when omitted.
    """
    conn = sqlite3.connect(DB_NAME if db_name is None else db_name)
    try:
        report_data = conn.execute(
            """
            SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
            FROM file_status
            """
        ).fetchall()
    finally:
        # Release the connection even if the query fails (e.g. missing table).
        conn.close()

    if not report_data:
        print("No data available to export.")
        return

    # Define the headers
    headers = [
        "File Name", "Execution Status", "Time Elapsed (seconds)",
        "Total Embedding Cost", "Total Embedding Tokens",
        "Total LLM Cost", "Total LLM Tokens", "Error Message",
    ]

    try:
        # newline="" lets the csv module control line endings; an explicit
        # UTF-8 encoding keeps non-ASCII error messages writable regardless
        # of the platform's locale encoding.
        with open(output_path, "w", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(headers)  # Write headers
            writer.writerows(report_data)  # Write data rows
    except (OSError, ValueError, csv.Error) as e:
        print(f"Error exporting to CSV: {e}")
    else:
        print(f"CSV successfully exported to {output_path}")


def get_status_endpoint(file_path, client, args: Arguments):
"""Returns status_endpoint, status and response (if available)"""
Expand Down Expand Up @@ -523,6 +589,12 @@ def main():
help="Disable SSL certificate verification.",
)

parser.add_argument(
'--csv_report',
dest="csv_report",
type=str,
help='Path to export the detailed report as a CSV file',
)

args = Arguments(**vars(parser.parse_args()))

Expand All @@ -543,6 +615,9 @@ def main():
"Elapsed time calculation of a file which was resumed"
" from pending state will not be correct"
)

if args.csv_report:
export_report_to_csv(args.csv_report)


if __name__ == "__main__":
Expand Down

0 comments on commit 073a375

Please sign in to comment.