Add support for PDF file uploads as context for LLM queries #3638

Open · wants to merge 22 commits into base: main · Changes from 1 commit
90 changes: 85 additions & 5 deletions fastchat/serve/gradio_block_arena_vision_anony.py
@@ -4,12 +4,16 @@
"""

import json
import subprocess
import time

import gradio as gr
import numpy as np
from typing import Union

import os
import PyPDF2

from fastchat.constants import (
    TEXT_MODERATION_MSG,
    IMAGE_MODERATION_MSG,
@@ -242,6 +246,56 @@ def clear_history(request: gr.Request):
+ [""]
)

def extract_text_from_pdf(pdf_file_path):
    """Extract text from a PDF file."""
    try:
        with open(pdf_file_path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            pdf_text = ""
            for page in reader.pages:
                pdf_text += page.extract_text()
            return pdf_text
    except Exception as e:
        logger.error(f"Failed to extract text from PDF: {e}")
        return None
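As a quick illustration, a minimal usage sketch of the helper above (the path is a hypothetical example; assumes PyPDF2 is installed and `logger` is configured as elsewhere in this module):

```python
# Hypothetical usage: extract text from a local PDF, falling back to "" on failure.
document_text = extract_text_from_pdf("example.pdf")  # "example.pdf" is a placeholder
if document_text is None:
    document_text = ""  # the helper already logged the extraction error
```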

def llama_parse(pdf_path):
    # NOTE: 'LLAMA KEY' is a placeholder; the real key should come from the
    # deployment environment rather than being hardcoded in source.
    os.environ['LLAMA_CLOUD_API_KEY'] = 'LLAMA KEY'

    output_dir = 'outputs'
    os.makedirs(output_dir, exist_ok=True)

    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
    markdown_file_path = os.path.join(output_dir, f'{pdf_name}.md')

    command = [
        'llama-parse',
        pdf_path,
        '--result-type', 'markdown',
        '--output-file', markdown_file_path
    ]

    subprocess.run(command, check=True)

    with open(markdown_file_path, 'r', encoding='utf-8') as file:
        markdown_content = file.read()

    return markdown_content
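The function above shells out to the `llama-parse` CLI via `subprocess` and round-trips through a markdown file on disk. For reference, a minimal in-process sketch using the `llama_parse` Python package's documented `LlamaParse` interface; this is an assumption about the SDK, not part of the diff, and it expects `LLAMA_CLOUD_API_KEY` to already be set in the environment:

```python
from llama_parse import LlamaParse

def llama_parse_sdk(pdf_path):
    """Parse a PDF to markdown without a subprocess round-trip (sketch)."""
    parser = LlamaParse(result_type="markdown")  # reads LLAMA_CLOUD_API_KEY from the env
    documents = parser.load_data(pdf_path)  # list of Document objects, one .text each
    return "\n\n".join(doc.text for doc in documents)
```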

def wrap_query_context(user_query, query_context):
    # TODO: refactor to split up user query and query context.
    # lines = input.split("\n\n[USER QUERY]", 1)
    # user_query = lines[1].strip()
    # query_context = lines[0][len('[QUERY CONTEXT]\n\n'): ]
    reformatted_query_context = (
        f"[QUERY CONTEXT]\n"
        f"<details>\n"
        f"<summary>Expand context details</summary>\n\n"
        f"{query_context}\n\n"
        f"</details>"
    )
    markdown = reformatted_query_context + f"\n\n[USER QUERY]\n\n{user_query}"
    return markdown
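The commented-out lines in the TODO sketch the inverse operation. One way that round-trip could look, assuming the exact delimiters `wrap_query_context` emits (`unwrap_query_context` is a hypothetical helper, not part of this diff):

```python
def unwrap_query_context(wrapped):
    """Inverse of wrap_query_context: recover (user_query, query_context)."""
    context_part, user_query = wrapped.split("\n\n[USER QUERY]\n\n", 1)
    prefix = "[QUERY CONTEXT]\n<details>\n<summary>Expand context details</summary>\n\n"
    suffix = "\n\n</details>"
    query_context = context_part[len(prefix):-len(suffix)]
    return user_query, query_context
```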

def add_text(
    state0,
@@ -253,10 +307,14 @@ def add_text(
    request: gr.Request,
):
    if isinstance(chat_input, dict):
-       text, images = chat_input["text"], chat_input["files"]
+       text, files = chat_input["text"], chat_input["files"]
    else:
        text = chat_input
-       images = []
+       files = []

    images = []

    # Guard against text-only submissions, where files is empty.
    file_extension = os.path.splitext(files[0])[1].lower() if files else ""

    ip = get_ip(request)
    logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
@@ -267,7 +325,7 @@
    if states[0] is None:
        assert states[1] is None

-       if len(images) > 0:
+       if len(files) > 0 and file_extension != ".pdf":
            model_left, model_right = get_battle_pair(
                context.all_vision_models,
                VISION_BATTLE_TARGETS,
@@ -363,6 +421,27 @@
    for i in range(num_sides):
        if "deluxe" in states[i].model_name:
            hint_msg = SLOW_MODEL_MSG

    if file_extension == ".pdf":
        document_text = llama_parse(files[0])
        post_processed_text = f"""
The following is the content of a document:

{document_text}

Based on this document, answer the following question:

{text}
"""

        post_processed_text = wrap_query_context(text, post_processed_text)

        text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
        for i in range(num_sides):
            states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
            states[i].conv.append_message(states[i].conv.roles[1], None)
            states[i].skip_next = False
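Worth noting: the hard cut-off above is applied to `text` only after `post_processed_text` has already been assembled, so the parsed document content itself never passes through `BLIND_MODE_INPUT_CHAR_LEN_LIMIT`. A sketch of capping the document before it is wrapped (reusing the same limit here is an assumption):

```python
# Hypothetical guard: truncate the parsed document before it is embedded,
# so the combined prompt also respects the blind-mode character limit.
document_text = document_text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT]
```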

    return (
        states
        + [x.to_gradio_chatbot() for x in states]
@@ -471,10 +550,10 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
    )

    multimodal_textbox = gr.MultimodalTextbox(
-       file_types=["image"],
+       file_types=["file"],
        show_label=False,
        container=True,
-       placeholder="Enter your prompt or add image here",
+       placeholder="Enter your prompt or add a PDF file here",
        elem_id="input_box",
        scale=3,
    )
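Since the handler only special-cases `.pdf`, the upload filter could be narrowed rather than opened to every file type. A sketch, assuming Gradio's convention of mixing the `"image"` shorthand with extension strings in `file_types`:

```python
multimodal_textbox = gr.MultimodalTextbox(
    file_types=["image", ".pdf"],  # hypothetical: accept only images and PDFs
    show_label=False,
    container=True,
    placeholder="Enter your prompt or add an image or PDF here",
    elem_id="input_box",
    scale=3,
)
```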
@@ -483,6 +562,7 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
    )

    with gr.Row() as button_row:
        random_btn = gr.Button(value="🔮 Random Image", interactive=True)
        if random_questions:
            global vqa_samples
            with open(random_questions, "r") as f: