feat: Added support to configure DB path, result in CSV report #13

Merged 2 commits on Dec 12, 2024

5 changes: 4 additions & 1 deletion .gitignore
@@ -1,2 +1,5 @@
 *.db
-.venv/
+*.csv
+.mypy_cache/
+.venv/
+.python-version
3 changes: 2 additions & 1 deletion README.md
@@ -59,6 +59,8 @@ This will display detailed usage information.
 - `-t`, `--api_timeout`: Timeout (in seconds) for API requests (default: 10).
 - `-i`, `--poll_interval`: Interval (in seconds) between API status polls (default: 5).
 - `-p`, `--parallel_call_count`: Number of parallel API calls (default: 10).
+- `--csv_report`: Path to export the detailed report as a CSV file.
+- `--db_path`: Path where the SQLite DB file is stored (default: './file_processing.db').
 - `--retry_failed`: Retry processing of failed files.
 - `--retry_pending`: Retry processing of pending files by making new requests.
 - `--skip_pending`: Skip processing of pending files.
@@ -67,7 +69,6 @@ This will display detailed usage information.
 - `--print_report`: Print a detailed report of all processed files at the end.
 - `--exclude_metadata`: Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file.
 - `--no_verify`: Disable SSL certificate verification. (By default, SSL verification is enabled.)
-- `--csv_report`: Path to export the detailed report as a CSV file.
 
 ## Usage Examples
 
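With the two new options documented above, a typical first run might look like the following sketch (the endpoint URL, API key, and paths are placeholders, not values taken from this PR):

    python main.py \
      -e https://unstract.example.com/deployment/api/org/my_api/ \
      -k YOUR_API_KEY \
      -f ./input_files \
      --db_path ./runs/file_processing.db \
      --csv_report ./runs/report.csv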
103 changes: 56 additions & 47 deletions main.py
@@ -16,8 +16,6 @@
 from tqdm import tqdm
 from unstract.api_deployments.client import APIDeploymentsClient
 
-DB_NAME = "file_processing.db"
-global_arguments = None
 logger = logging.getLogger(__name__)
 
 
@@ -29,6 +27,7 @@ class Arguments:
     api_timeout: int = 10
     poll_interval: int = 5
     input_folder_path: str = ""
+    db_path: str = ""
     parallel_call_count: int = 5
     retry_failed: bool = False
     retry_pending: bool = False
@@ -42,8 +41,8 @@ class Arguments:
 
 
 # Initialize SQLite DB
-def init_db():
-    conn = sqlite3.connect(DB_NAME)
+def init_db(args: Arguments):
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
 
     # Create the table if it doesn't exist
@@ -89,7 +88,7 @@ def init_db():
 
 # Check if the file is already processed
 def skip_file_processing(file_name, args: Arguments):
-    conn = sqlite3.connect(DB_NAME)
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
     c.execute(
         "SELECT execution_status FROM file_status WHERE file_name = ?", (file_name,)
@@ -124,6 +123,7 @@ def update_db(
     time_taken,
     status_code,
     status_api_endpoint,
+    args: Arguments
 ):
 
     total_embedding_cost = None
@@ -138,7 +138,7 @@
     if execution_status == "ERROR":
         error_message = extract_error_message(result)
 
-    conn = sqlite3.connect(DB_NAME)
+    conn = sqlite3.connect(args.db_path)
     conn.set_trace_callback(
         lambda x: (
             logger.debug(f"[{file_name}] Executing statement: {x}")
@@ -232,8 +232,8 @@ def extract_error_message(result):
     return result.get("error", "No error message found")
 
 # Print final summary with count of each status and average time using a single SQL query
-def print_summary():
-    conn = sqlite3.connect(DB_NAME)
+def print_summary(args: Arguments):
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
 
     # Fetch count and average time for each status
@@ -255,8 +255,8 @@
         print(f"Status '{status}': {count}")
 
 
-def print_report():
-    conn = sqlite3.connect(DB_NAME)
+def print_report(args: Arguments):
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
 
     # Fetch required fields, including total_cost and total_tokens
@@ -318,36 +318,36 @@ def print_report():
 
     print("\nNote: For more detailed error messages, use the CSV report argument.")
 
-def export_report_to_csv(output_path):
-    conn = sqlite3.connect(DB_NAME)
+def export_report_to_csv(args: Arguments):
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
 
     c.execute(
         """
-        SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
+        SELECT file_name, execution_status, result, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
         FROM file_status
         """
     )
     report_data = c.fetchall()
     conn.close()
 
     if not report_data:
-        print("No data available to export.")
+        print("No data available to export as CSV.")
         return
 
     # Define the headers
     headers = [
-        "File Name", "Execution Status", "Time Elapsed (seconds)",
+        "File Name", "Execution Status", "Result", "Time Elapsed (seconds)",
         "Total Embedding Cost", "Total Embedding Tokens",
         "Total LLM Cost", "Total LLM Tokens", "Error Message"
     ]
 
     try:
-        with open(output_path, 'w', newline='') as csvfile:
+        with open(args.csv_report, 'w', newline='') as csvfile:
             writer = csv.writer(csvfile)
             writer.writerow(headers) # Write headers
             writer.writerows(report_data) # Write data rows
-        print(f"CSV successfully exported to {output_path}")
+        print(f"CSV successfully exported to '{args.csv_report}'")
     except Exception as e:
         print(f"Error exporting to CSV: {e}")
 
@@ -357,7 +357,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
     status_endpoint = None
 
     # If retry_pending is True, check if the status API endpoint is available
-    conn = sqlite3.connect(DB_NAME)
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
     c.execute(
         "SELECT status_api_endpoint FROM file_status WHERE file_name = ? AND execution_status NOT IN ('COMPLETED', 'ERROR')",
@@ -382,7 +382,7 @@
 
     # Fresh API call to process the file
     execution_status = "STARTING"
-    update_db(file_path, execution_status, None, None, None, None)
+    update_db(file_path, execution_status, None, None, None, None, args=args)
     response = client.structure_file(file_paths=[file_path])
     logger.debug(f"[{file_path}] Response of initial API call: {response}")
     status_endpoint = response.get(
@@ -397,6 +397,7 @@
         None,
         status_code,
         status_endpoint,
+        args=args
     )
     return status_endpoint, execution_status, response
 
@@ -436,7 +437,7 @@ def process_file(
         execution_status = response.get("execution_status")
         status_code = response.get("status_code") # Default to 200 if not provided
         update_db(
-            file_path, execution_status, None, None, status_code, status_endpoint
+            file_path, execution_status, None, None, status_code, status_endpoint, args=args
         )
 
         result = response
@@ -456,7 +457,7 @@
     end_time = time.time()
     time_taken = round(end_time - start_time, 2)
     update_db(
-        file_path, execution_status, result, time_taken, status_code, status_endpoint
+        file_path, execution_status, result, time_taken, status_code, status_endpoint, args=args
     )
     logger.info(f"[{file_path}]: Processing completed: {execution_status}")
 
@@ -501,14 +502,14 @@ def load_folder(args: Arguments):
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Process files using the API.")
+    parser = argparse.ArgumentParser(description="Process files using Unstract's API deployment")
     parser.add_argument(
         "-e",
         "--api_endpoint",
         dest="api_endpoint",
         type=str,
         required=True,
-        help="API Endpoint to use for processing the files.",
+        help="API Endpoint to use for processing the files",
     )
     parser.add_argument(
         "-k",
@@ -524,55 +525,68 @@
         dest="api_timeout",
         type=int,
         default=10,
-        help="Time in seconds to wait before switching to async mode.",
+        help="Time in seconds to wait before switching to async mode (default: 10)",
     )
     parser.add_argument(
         "-i",
         "--poll_interval",
         dest="poll_interval",
         type=int,
         default=5,
-        help="Time in seconds the process will sleep between polls in async mode.",
+        help="Time in seconds the process will sleep between polls in async mode (default: 5)",
     )
     parser.add_argument(
         "-f",
         "--input_folder_path",
         dest="input_folder_path",
         type=str,
         required=True,
-        help="Path where the files to process are present.",
+        help="Path where the files to process are present",
     )
     parser.add_argument(
         "-p",
         "--parallel_call_count",
         dest="parallel_call_count",
         type=int,
         default=5,
-        help="Number of calls to be made in parallel.",
+        help="Number of calls to be made in parallel (default: 5)",
     )
+    parser.add_argument(
+        "--db_path",
+        dest="db_path",
+        type=str,
+        default="file_processing.db",
+        help="Path where the SQLite DB file is stored (default: './file_processing.db')",
+    )
+    parser.add_argument(
+        '--csv_report',
+        dest="csv_report",
+        type=str,
+        help='Path to export the detailed report as a CSV file',
+    )
     parser.add_argument(
         "--retry_failed",
         dest="retry_failed",
         action="store_true",
-        help="Retry processing of failed files.",
+        help="Retry processing of failed files (default: False)",
     )
     parser.add_argument(
         "--retry_pending",
         dest="retry_pending",
         action="store_true",
-        help="Retry processing of pending files as new request (Without this it will try to fetch the results using status API).",
+        help="Retry processing of pending files as a new request (without this, results are fetched using the status API) (default: False)",
     )
     parser.add_argument(
         "--skip_pending",
         dest="skip_pending",
         action="store_true",
-        help="Skip processing of pending files (Over rides --retry-pending).",
+        help="Skip processing of pending files (overrides --retry_pending) (default: False)",
     )
     parser.add_argument(
         "--skip_unprocessed",
         dest="skip_unprocessed",
         action="store_true",
-        help="Skip unprocessed files while retry processing of failed files.",
+        help="Skip unprocessed files while retrying processing of failed files (default: False)",
     )
     parser.add_argument(
         "--log_level",
@@ -586,52 +600,47 @@
         "--print_report",
         dest="print_report",
         action="store_true",
-        help="Print a detailed report of all file processed.",
+        help="Print a detailed report of all files processed (default: False)",
     )
 
     parser.add_argument(
         "--exclude_metadata",
         dest="include_metadata",
         action="store_false",
-        help="Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file.",
+        help="Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file (default: False)",
     )
 
     parser.add_argument(
         "--no_verify",
         dest="verify",
         action="store_false",
-        help="Disable SSL certificate verification.",
-    )
-
-    parser.add_argument(
-        '--csv_report',
-        dest="csv_report",
-        type=str,
-        help='Path to export the detailed report as a CSV file',
+        help="Disable SSL certificate verification (default: False)",
     )
 
     args = Arguments(**vars(parser.parse_args()))
 
     ch = logging.StreamHandler(sys.stdout)
     ch.setLevel(args.log_level)
     formatter = logging.Formatter(
         "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
     )
     ch.setFormatter(formatter)
     logging.basicConfig(level=args.log_level, handlers=[ch])
 
     logger.warning(f"Running with params: {args}")
 
-    init_db()  # Initialize DB
+    init_db(args=args)  # Initialize DB
 
     load_folder(args=args)
 
-    print_summary()  # Print summary at the end
+    print_summary(args=args)  # Print summary at the end
     if args.print_report:
-        print_report()
+        print_report(args=args)
         logger.warning(
             "Elapsed time calculation of a file which was resumed"
             " from pending state will not be correct"
         )
 
     if args.csv_report:
-        export_report_to_csv(args.csv_report)
+        export_report_to_csv(args=args)
 
 
 if __name__ == "__main__":
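Because the status DB location is now an argument, a follow-up run can point at the same DB to retry failed files and write a fresh CSV. Again a sketch, with the same placeholder endpoint, key, and paths as above:

    python main.py \
      -e https://unstract.example.com/deployment/api/org/my_api/ \
      -k YOUR_API_KEY \
      -f ./input_files \
      --retry_failed \
      --db_path ./runs/file_processing.db \
      --csv_report ./runs/retry_report.csv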