Add support for PDF file uploads as context for LLM queries #3638

Open · wants to merge 22 commits into base: main
130 changes: 119 additions & 11 deletions fastchat/serve/gradio_block_arena_vision_anony.py
@@ -10,6 +10,11 @@
import numpy as np
from typing import Union

import os
import PyPDF2
import nest_asyncio
from llama_parse import LlamaParse

from fastchat.constants import (
TEXT_MODERATION_MSG,
IMAGE_MODERATION_MSG,
@@ -242,6 +247,79 @@ def clear_history(request: gr.Request):
+ [""]
)

def is_pdf(file_path):
try:
with open(file_path, 'rb') as file:
header = file.read(5) # Read the first 5 bytes
return header == b'%PDF-'
except Exception as e:
print(f"Error: {e}")
return False

def is_image(file_path):
magic_numbers = {
b'\xff\xd8\xff': 'JPEG',
b'\x89PNG\r\n\x1a\n': 'PNG',
b'GIF87a': 'GIF',
b'GIF89a': 'GIF',
b'BM': 'BMP',
b'\x00\x00\x01\x00': 'ICO',
b'\x49\x49\x2a\x00': 'TIFF',
b'\x4d\x4d\x00\x2a': 'TIFF',
b'RIFF': 'WebP',  # RIFF container; also matches WAV/AVI, so this check is loose
}
try:
with open(file_path, 'rb') as file:
header = file.read(8) # Read the first 8 bytes
for magic in magic_numbers:
if header.startswith(magic):
return True
return False
except Exception as e:
print(f"Error reading file: {e}")
return False
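
# Reviewer note: a minimal usage sketch for the two sniffers above (the
# paths are hypothetical, not part of this PR). Both helpers swallow I/O
# errors and return False, so they are safe to call on arbitrary paths:
#
#     is_pdf("/tmp/report.pdf")    # True  -> header starts with b'%PDF-'
#     is_image("/tmp/photo.png")   # True  -> header starts with the PNG magic
#     is_image("/tmp/report.pdf")  # False -> no image magic number matches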

nest_asyncio.apply() # Ensure compatibility with async environments

def pdf_parse(pdf_path):
# Read the LlamaParse API key from the environment; raises KeyError if
# LLAMA_PARSE_KEY is not set
api_key = os.environ["LLAMA_PARSE_KEY"]

# Initialize the LlamaParse object
parser = LlamaParse(
api_key=api_key,
result_type="markdown", # Output in Markdown format
num_workers=4, # Number of parallel workers for batched API calls
verbose=True, # Print detailed logs
language="en" # Set language (default is English)
)

pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
extra_info = {"file_name": pdf_name}

with open(pdf_path, "rb") as pdf_file:
# Pass the file object and extra info for parsing
documents = parser.load_data(pdf_file, extra_info=extra_info)

# Extract the parsed content as Markdown text (first document only)
markdown_content = documents[0].text if documents else ""

return markdown_content
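
# Reviewer note: a minimal sketch of how pdf_parse would be exercised,
# assuming LLAMA_PARSE_KEY is exported and a local "sample.pdf" exists
# (both hypothetical). Only the text of the first parsed document is
# returned; any further documents in the result are currently dropped:
#
#     # export LLAMA_PARSE_KEY=<your-key> before running
#     markdown = pdf_parse("sample.pdf")
#     print(markdown[:200])  # first 200 characters of the Markdown output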

def wrap_query_context(user_query, query_context):
#TODO: refactor to split up user query and query context.
# lines = input.split("\n\n[USER QUERY]", 1)
# user_query = lines[1].strip()
# query_context = lines[0][len('[QUERY CONTEXT]\n\n'): ]
reformatted_query_context = (
f"[QUERY CONTEXT]\n"
f"<details>\n"
f"<summary>Expand context details</summary>\n\n"
f"{query_context}\n\n"
f"</details>"
)
markdown = reformatted_query_context + f"\n\n[USER QUERY]\n\n{user_query}"
return markdown
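
# Reviewer note: an example of the prompt this helper produces
# (illustrative strings, not from this PR):
#
#     wrap_query_context("What is the total?", "Invoice: 3 items, $42.00")
#
# returns the Markdown:
#
#     [QUERY CONTEXT]
#     <details>
#     <summary>Expand context details</summary>
#
#     Invoice: 3 items, $42.00
#
#     </details>
#
#     [USER QUERY]
#
#     What is the total?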

def add_text(
state0,
@@ -253,10 +331,18 @@ def add_text(
request: gr.Request,
):
if isinstance(chat_input, dict):
text, images = chat_input["text"], chat_input["files"]
text, files = chat_input["text"], chat_input["files"]
else:
text = chat_input
images = []
files = []

images = []

# Currently supports at most one PDF or one image per message.
# if files and is_pdf(files[0]):
#     pdfs = files
if files and is_image(files[0]):
images = files

ip = get_ip(request)
logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
@@ -267,7 +353,7 @@
if states[0] is None:
assert states[1] is None

if len(images) > 0:
if len(files) > 0 and is_image(files[0]):
model_left, model_right = get_battle_pair(
context.all_vision_models,
VISION_BATTLE_TARGETS,
@@ -350,7 +436,8 @@ def add_text(
+ [""]
)

text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
if files and is_image(files[0]):
text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
for i in range(num_sides):
post_processed_text = _prepare_text_with_image(
states[i], text, images, csam_flag=csam_flag
@@ -363,6 +450,26 @@
for i in range(num_sides):
if "deluxe" in states[i].model_name:
hint_msg = SLOW_MODEL_MSG

if files and is_pdf(files[0]):
document_text = pdf_parse(files[0])
prompt_text = f"""
The following is the content of a document:

{document_text}

Based on this document, answer the following question:

{text}
"""
post_processed_text = wrap_query_context(text, prompt_text)

# text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
for i in range(num_sides):
states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
states[i].conv.append_message(states[i].conv.roles[1], None)
states[i].skip_next = False

return (
states
+ [x.to_gradio_chatbot() for x in states]
@@ -471,10 +578,10 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
)

multimodal_textbox = gr.MultimodalTextbox(
file_types=["image"],
file_types=["file"],
show_label=False,
container=True,
placeholder="Enter your prompt or add image here",
placeholder="Enter your prompt here. You can also upload image or PDF file",
elem_id="input_box",
scale=3,
)
@@ -483,11 +590,12 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
)

with gr.Row() as button_row:
if random_questions:
global vqa_samples
with open(random_questions, "r") as f:
vqa_samples = json.load(f)
random_btn = gr.Button(value="🔮 Random Image", interactive=True)
random_btn = gr.Button(value="🔮 Random Image", interactive=True)
# if random_questions:
# global vqa_samples
# with open(random_questions, "r") as f:
# vqa_samples = json.load(f)
# random_btn = gr.Button(value="🔮 Random Image", interactive=True)
clear_btn = gr.Button(value="🎲 New Round", interactive=False)
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
share_btn = gr.Button(value="📷 Share")