Add support for PDF file uploads as context for LLM queries #3638

Status: Open. Wants to merge 42 commits into base: main. The diff shown below reflects changes from 3 of the 42 commits.
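At a high level, the change parses an uploaded PDF into text and prepends it to the user's question as a collapsible context block before the prompt reaches the models. A minimal sketch of that flow, reusing the `pdf_parse` and `wrap_query_context` helpers added in the diff below; the import path and glue code here are illustrative, not part of the PR:

```python
# Illustrative only: pdf_parse() and wrap_query_context() are the helpers this
# PR adds to fastchat/serve/gradio_block_arena_vision_anony.py (see diff below).
from fastchat.serve.gradio_block_arena_vision_anony import (
    pdf_parse,
    wrap_query_context,
)


def build_pdf_prompt(pdf_path: str, user_question: str) -> str:
    """Hypothetical glue showing how the PR combines a PDF with a user query."""
    document_text = pdf_parse(pdf_path)  # PDF -> markdown text via LlamaParse
    context = (
        "The following is the content of a document:\n\n"
        f"{document_text}\n\n"
        "Based on this document, answer the following question:\n\n"
        f"{user_question}\n"
    )
    # Wrap the document in a collapsible [QUERY CONTEXT] section and append
    # the plain [USER QUERY] section, mirroring add_text() in the diff.
    return wrap_query_context(user_question, context)
```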
Commits (42)
3857e22
added pdf context support
andrewwan0131 Dec 7, 2024
06d056b
These changes are in response to PR comments
andrewwan0131 Dec 8, 2024
cc66890
These changes are in response to PR comments
andrewwan0131 Dec 8, 2024
85767e5
Changed file detection to magic numbers and removed unnecessary libra…
andrewwan0131 Dec 8, 2024
c49344f
switch to using environment variable
CodingWithTim Dec 26, 2024
afbf7e6
new architecture and bug fixes
CodingWithTim Dec 30, 2024
8527b02
fix format
CodingWithTim Dec 30, 2024
8da825b
improve UI and efficiency
CodingWithTim Dec 30, 2024
b59cea8
fix formatting
CodingWithTim Dec 30, 2024
d5efd2c
fix first page only parsing issue
CodingWithTim Dec 30, 2024
f2905b9
fix first page only parsing issue
CodingWithTim Dec 30, 2024
e7ab73f
additional improvements
CodingWithTim Dec 30, 2024
f1c6185
add multilingual support
CodingWithTim Jan 2, 2025
f7e92e1
support google cloud storage
CodingWithTim Jan 4, 2025
0daef32
fix format
CodingWithTim Jan 4, 2025
32c6724
add pdf maximum page limit
CodingWithTim Jan 4, 2025
2cb0937
remove language detection
CodingWithTim Jan 5, 2025
e4c0f3b
fix format
CodingWithTim Jan 5, 2025
5c52665
support multimodal pdfchat and switch to marker pdf
CodingWithTim Jan 6, 2025
61284e0
switch to package implementation of 'is_image'
CodingWithTim Jan 6, 2025
586a2f6
flexible state variable for pdf_id
CodingWithTim Jan 6, 2025
2ea729c
fix error
CodingWithTim Jan 6, 2025
f2c4d64
Marker API Implemented + Updated Llama code if ever needed
Jan 31, 2025
cf9b408
Content Moderation implemented + couple of latency improvements
Feb 11, 2025
2bf158c
fixed bug where text moderation wasn't being flagged
andrewwan0131 Feb 12, 2025
06110d2
fixed bug where text moderation wasn't being flagged
andrewwan0131 Feb 12, 2025
34c7a8e
added image_resize functionality for image moderation
yixin-huang1 Feb 13, 2025
0955a76
minor text fix
andrewwan0131 Feb 13, 2025
2878d3c
fixed formatting
andrewwan0131 Feb 14, 2025
701f7c5
applied black formatting py3.10
andrewwan0131 Feb 14, 2025
ab2443c
fixed black version
andrewwan0131 Feb 14, 2025
89743ad
Revert "minor text fix"
yixin-huang1 Feb 16, 2025
63cb555
revert + fix some PR issues
yixin-huang1 Feb 16, 2025
bb24800
revert the spacing changes
yixin-huang1 Feb 16, 2025
4902d71
Update setup_pdfchat.sh
yixin-huang1 Feb 16, 2025
3528919
fix formatting
CodingWithTim Feb 19, 2025
bac54f0
FIX: add missing pdf_id
CodingWithTim Feb 19, 2025
1c6f911
FIX: add_text logic and make cleaner
CodingWithTim Feb 19, 2025
85ed193
FIX: small bug with push to gc
CodingWithTim Feb 19, 2025
0919376
Updated PDF Character Length Limit
Feb 20, 2025
a4b3b6b
Fixed bug with character limits
Feb 21, 2025
5a314de
Fixed bug with character limits
Feb 21, 2025
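Commit 85767e5 mentions switching file-type detection to magic numbers rather than extensions; that implementation is not part of the 3-commit diff below. A minimal sketch of what such a check can look like (a hypothetical helper, not the PR's code), relying on the standard `%PDF-` signature at the start of PDF files:

```python
def is_pdf(file_path: str) -> bool:
    """Detect a PDF by its magic number instead of trusting the file extension."""
    with open(file_path, "rb") as f:
        return f.read(5) == b"%PDF-"
```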
118 changes: 107 additions & 11 deletions fastchat/serve/gradio_block_arena_vision_anony.py
@@ -4,12 +4,16 @@
"""

import json
import subprocess
import time

import gradio as gr
import numpy as np
from typing import Union

import os
import PyPDF2

from fastchat.constants import (
    TEXT_MODERATION_MSG,
    IMAGE_MODERATION_MSG,
@@ -242,6 +246,71 @@ def clear_history(request: gr.Request):
+ [""]
)

def extract_text_from_pdf(pdf_file_path):
    """Extract text from a PDF file."""
    try:
        with open(pdf_file_path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            pdf_text = ""
            for page in reader.pages:
                # extract_text() may return None for pages with no extractable text
                pdf_text += page.extract_text() or ""
            return pdf_text
    except Exception as e:
        logger.error(f"Failed to extract text from PDF: {e}")
        return None

import os
import nest_asyncio
from llama_parse import LlamaParse

nest_asyncio.apply() # Ensure compatibility with async environments

def pdf_parse(pdf_path):
    # Placeholder API key; this can also be configured via the environment
    api_key = "LLAMA API"

    # Initialize the LlamaParse object
    parser = LlamaParse(
        api_key=api_key,
        result_type="markdown",  # Output in Markdown format
        num_workers=4,  # Number of API calls for batch processing
        verbose=True,  # Print detailed logs
        language="en",  # Set language (default is English)
    )

    # Prepare an output directory and file name (currently unused: the parsed
    # markdown is returned directly rather than written to disk)
    output_dir = "outputs"
    os.makedirs(output_dir, exist_ok=True)

    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
    markdown_file_path = os.path.join(output_dir, f"{pdf_name}.md")

    # Load and parse the PDF
    extra_info = {"file_name": pdf_name}

    with open(pdf_path, "rb") as pdf_file:
        # Pass the file object and extra info for parsing
        documents = parser.load_data(pdf_file, extra_info=extra_info)

    # Take the parsed markdown from the first returned document
    markdown_content = documents[0].text if documents else ""

    return markdown_content

def wrap_query_context(user_query, query_context):
    # TODO: refactor to split up user query and query context.
    # lines = input.split("\n\n[USER QUERY]", 1)
    # user_query = lines[1].strip()
    # query_context = lines[0][len('[QUERY CONTEXT]\n\n'): ]
    reformatted_query_context = (
        f"[QUERY CONTEXT]\n"
        f"<details>\n"
        f"<summary>Expand context details</summary>\n\n"
        f"{query_context}\n\n"
        f"</details>"
    )
    markdown = reformatted_query_context + f"\n\n[USER QUERY]\n\n{user_query}"
    return markdown

def add_text(
state0,
@@ -253,10 +322,14 @@ def add_text(
request: gr.Request,
):
if isinstance(chat_input, dict):
text, images = chat_input["text"], chat_input["files"]
text, files = chat_input["text"], chat_input["files"]
else:
text = chat_input
images = []
files = []

images = []

file_extension = os.path.splitext(files[0])[1].lower() if files else ""  # guard against text-only input with no files

ip = get_ip(request)
logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
@@ -267,7 +340,7 @@
if states[0] is None:
assert states[1] is None

if len(images) > 0:
if len(files) > 0 and file_extension != ".pdf":
model_left, model_right = get_battle_pair(
context.all_vision_models,
VISION_BATTLE_TARGETS,
@@ -350,7 +423,8 @@ def add_text(
+ [""]
)

text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
if file_extension != ".pdf":
text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
for i in range(num_sides):
post_processed_text = _prepare_text_with_image(
states[i], text, images, csam_flag=csam_flag
@@ -363,6 +437,27 @@
for i in range(num_sides):
if "deluxe" in states[i].model_name:
hint_msg = SLOW_MODEL_MSG

if file_extension == ".pdf":
document_text = pdf_parse(files[0])
post_processed_text = f"""
The following is the content of a document:

{document_text}

Based on this document, answer the following question:

{text}
"""

post_processed_text = wrap_query_context(text, post_processed_text)

# text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
for i in range(num_sides):
states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
states[i].conv.append_message(states[i].conv.roles[1], None)
states[i].skip_next = False

return (
states
+ [x.to_gradio_chatbot() for x in states]
@@ -471,10 +566,10 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
)

multimodal_textbox = gr.MultimodalTextbox(
file_types=["image"],
file_types=["file"],
show_label=False,
container=True,
placeholder="Enter your prompt or add image here",
placeholder="Enter your prompt here. You can also upload image or PDF file",
elem_id="input_box",
scale=3,
)
@@ -483,11 +578,12 @@
)

with gr.Row() as button_row:
if random_questions:
global vqa_samples
with open(random_questions, "r") as f:
vqa_samples = json.load(f)
random_btn = gr.Button(value="🔮 Random Image", interactive=True)
random_btn = gr.Button(value="🔮 Random Image", interactive=True)
# if random_questions:
# global vqa_samples
# with open(random_questions, "r") as f:
# vqa_samples = json.load(f)
# random_btn = gr.Button(value="🔮 Random Image", interactive=True)
clear_btn = gr.Button(value="🎲 New Round", interactive=False)
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
share_btn = gr.Button(value="📷 Share")
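Note that the 3-commit snapshot above still hard-codes a placeholder API key inside `pdf_parse()`; commit c49344f ("switch to using environment variable") later moves it out of the source. A hedged sketch of that pattern, where the variable name `LLAMA_CLOUD_API_KEY` is an assumption rather than something confirmed by this diff:

```python
import os

from llama_parse import LlamaParse

# Assumed variable name; adjust to whatever the deployment actually exports.
api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
if not api_key:
    raise RuntimeError("LLAMA_CLOUD_API_KEY is not set")

parser = LlamaParse(
    api_key=api_key,
    result_type="markdown",  # same options used by pdf_parse() in the diff
    num_workers=4,
    verbose=True,
    language="en",
)
```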