Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PDF file uploads as context for LLM queries #3638

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions fastchat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
IMAGE_MODERATION_MSG = (
"$MODERATION$ YOUR IMAGE VIOLATES OUR CONTENT MODERATION GUIDELINES."
)
PDF_MODERATION_MSG = "$MODERATION$ YOUR PDF VIOLATES OUR CONTENT MODERATION GUIDELINES."
MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES."
CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
Expand All @@ -39,6 +40,9 @@
)
# Maximum conversation turns
CONVERSATION_TURN_LIMIT = 50
# Maximum PDF Page Limit
PDF_PAGE_LIMIT = 50
PDF_LIMIT_MSG = f"YOU HAVE REACHED THE MAXIMUM PDF PAGE LIMIT ({PDF_PAGE_LIMIT} PAGES). PLEASE UPLOAD A SMALLER DOCUMENT."
# Session expiration time
SESSION_EXPIRATION_TIME = 3600
# The output dir of log files
Expand Down
165 changes: 160 additions & 5 deletions fastchat/serve/gradio_block_arena_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,60 @@ def get_vqa_sample():
return (res, path)


def is_image(file_path):
    """Return True if the file at ``file_path`` starts with a known image
    magic number (JPEG, PNG, GIF, BMP, ICO, TIFF, or WebP).

    Unreadable or missing files are treated as "not an image" rather than
    raising, so callers can pass arbitrary upload paths safely.
    """
    # Magic-number prefixes for common raster formats.
    magic_numbers = {
        b"\xff\xd8\xff": "JPEG",
        b"\x89PNG\r\n\x1a\n": "PNG",
        b"GIF87a": "GIF",
        b"GIF89a": "GIF",
        b"BM": "BMP",
        b"\x00\x00\x01\x00": "ICO",
        b"\x49\x49\x2a\x00": "TIFF",
        b"\x4d\x4d\x00\x2a": "TIFF",
        # WebP is handled separately below: the first four bytes are "RIFF",
        # and bytes 8-12 must be "WEBP".
    }

    try:
        with open(file_path, "rb") as f:
            header = f.read(16)  # 16 bytes covers the longest check (WebP).

        # Check for WebP (RIFF container whose form type is WEBP).
        if header.startswith(b"RIFF") and header[8:12] == b"WEBP":
            return True

        # All other formats: simple prefix match against the table.
        return any(header.startswith(magic) for magic in magic_numbers)
    except Exception as e:
        # Best-effort probe: any failure (missing file, permission, bad
        # path type) just means we cannot treat the file as an image.
        print(f"Error reading file: {e}")
        return False


def is_pdf(file_path):
    """Return True when the file begins with the ``%PDF-`` magic marker."""
    try:
        with open(file_path, "rb") as stream:
            # The PDF spec requires every file to open with "%PDF-".
            return stream.read(5) == b"%PDF-"
    except Exception as err:
        print(f"Error: {err}")
        return False


def set_visible_image(textbox):
    """Decide whether the image-preview column should be shown for the
    current contents of the multimodal textbox."""
    files = textbox["files"]
    if not files:
        return invisible_image_column
    if len(files) > 1:
        # More than one attachment: warn the user, but still reveal the
        # preview column.
        gr.Warning(
            "We only support single image conversations. Please start a new round if you would like to chat using this image."
        )
        return visible_image_column
    if is_image(files[0]):
        return visible_image_column
    # A single non-image attachment (e.g. a PDF) keeps the column hidden.
    return invisible_image_column


def set_invisible_image():
Expand Down Expand Up @@ -166,6 +210,108 @@ def report_csam_image(state, image):
pass


def wrap_pdfchat_query(query, document):
    """Embed the parsed PDF text and the user's question into a single
    prompt string, with the document wrapped in a collapsible block."""
    # TODO: Considering redesign the context format.
    parts = [
        "Answer the user query given the context.\n",
        "[QUERY CONTEXT]\n",
        "<details>\n",
        "<summary>Expand context details</summary>\n\n",
        f"{document}\n\n",
        "</details>",
        f"\n\n[USER QUERY]\n\n{query}",
    ]
    return "".join(parts)


# Number of attempts before giving up on LlamaParse.
LLAMA_PARSE_MAX_RETRY = 2
# Languages passed to Tesseract OCR. Tesseract identifies languages by
# ISO 639-2/T three-letter codes (plus script variants like chi_sim),
# joined with "+" so one OCR pass can recognize all of them.
TESSERACT_SUPPORTED_LANGS = "+".join(
    [
        "eng",  # English — Tesseract's code is "eng", not "en"
        "chi_tra",  # Chinese (traditional)
        "chi_sim",  # Chinese (simplified)
        "rus",  # Russian
        "spa",  # Spanish
        "jpn",  # Japanese
        "kor",  # Korean
        "fra",  # French
        "deu",  # German
        "vie",  # Vietnamese
    ]
)
# Maps the language name reported by polyglot to LlamaParse's language code.
LLAMAPARSE_SUPPORTED_LANGS = {
    "English": "en",
    "Chinese": "ch_sim",
    "Russian": "ru",
    "Spanish": "es",
    "Japanese": "ja",
    "Korean": "ko",
    "French": "fr",
    "German": "de",
    "Vietnamese": "vi",
}


def detect_language_from_doc(pdf_file_path):
    """Detect the language(s) of a PDF by OCR-ing its first page.

    Returns a list of language names (e.g. ``["English"]``) in polyglot's
    confidence order, with the "un" (undetermined) placeholder removed.
    """
    from pdf2image import convert_from_path
    from polyglot.detect import Detector

    import pytesseract  # Google's open-source OCR tool

    # Tesseract needs TESSDATA_PREFIX to find its trained-data files.
    # The original `assert os.environ[...]` raised a bare KeyError when the
    # variable was absent; validate explicitly so the message always shows.
    if not os.environ.get("TESSDATA_PREFIX"):
        raise EnvironmentError(
            "Make sure to specify location of train data for Tesseract."
        )

    # Render only the first page: language detection doesn't need more, and
    # rasterizing every page is the expensive part.
    images = convert_from_path(pdf_file_path, first_page=1, last_page=1)

    extracted_text = pytesseract.image_to_string(
        images[0], lang=TESSERACT_SUPPORTED_LANGS
    )

    languages = Detector(extracted_text, quiet=True)
    # "un" is polyglot's code for undetermined text.
    return [lang.name for lang in languages.languages if lang.name != "un"]


def parse_pdf(file_path):
    """Parse ``file_path`` into markdown text via the LlamaParse API.

    Returns the concatenated page texts, each prefixed with its 1-based page
    number. Raises RuntimeError if no documents are produced after
    LLAMA_PARSE_MAX_RETRY attempts, and KeyError if the detected language is
    not in LLAMAPARSE_SUPPORTED_LANGS.
    """
    from llama_parse import LlamaParse

    if "LLAMA_CLOUD_API_KEY" not in os.environ:
        raise RuntimeError("Make sure to specify LlamaParse API key.")

    # Use the most likely detected language for the parser.
    doc_lang = detect_language_from_doc(file_path)
    doc_lang = LLAMAPARSE_SUPPORTED_LANGS[doc_lang[0]]

    documents = []
    for _ in range(LLAMA_PARSE_MAX_RETRY):
        documents = LlamaParse(
            result_type="markdown",
            verbose=True,
            language=doc_lang,
            accurate_mode=True,
        ).load_data(file_path)
        if documents:
            break
    else:
        # The original code left `documents` unbound here (NameError) when
        # every retry returned nothing; fail with an explicit error instead.
        raise RuntimeError(
            f"LlamaParse returned no documents after "
            f"{LLAMA_PARSE_MAX_RETRY} attempts."
        )

    return "\n".join(
        f"Page {i+1}:\n{doc.text}\n" for i, doc in enumerate(documents)
    )


def _prepare_text_with_image(state, text, images, csam_flag):
if len(images) > 0:
if len(state.conv.get_images()) > 0:
Expand All @@ -177,6 +323,15 @@ def _prepare_text_with_image(state, text, images, csam_flag):
return text


def _prepare_text_with_pdf(text, pdfs):
if len(pdfs) > 0:
document_content = parse_pdf(pdfs[0])
print("Document processed")
text = wrap_pdfchat_query(text, document_content)

return text


# NOTE(chris): take multiple images later on
def convert_images_to_conversation_format(images):
import base64
Expand Down
73 changes: 66 additions & 7 deletions fastchat/serve/gradio_block_arena_vision_anony.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from fastchat.constants import (
TEXT_MODERATION_MSG,
IMAGE_MODERATION_MSG,
PDF_MODERATION_MSG,
PDF_LIMIT_MSG,
PDF_PAGE_LIMIT,
MODERATION_MSG,
CONVERSATION_LIMIT_MSG,
SLOW_MODEL_MSG,
Expand Down Expand Up @@ -63,17 +66,23 @@
moderate_input,
enable_multimodal,
_prepare_text_with_image,
_prepare_text_with_pdf,
convert_images_to_conversation_format,
invisible_text,
visible_text,
disable_multimodal,
is_image,
is_pdf,
)
from fastchat.serve.gradio_global_state import Context
from fastchat.serve.remote_logger import get_remote_logger
from fastchat.utils import (
build_logger,
moderation_filter,
image_moderation_filter,
get_pdf_num_page,
upload_pdf_file_to_gcs,
hash_pdf,
)

logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
Expand All @@ -86,15 +95,19 @@

# TODO(chris): fix sampling weights
VISION_SAMPLING_WEIGHTS = {}
PDFCHAT_SAMPLING_WEIGHTS = {}

# TODO(chris): Find battle targets that make sense
VISION_BATTLE_TARGETS = {}
PDFCHAT_BATTLE_TARGETS = {}

# TODO(chris): Fill out models that require sampling boost
VISION_SAMPLING_BOOST_MODELS = []
PDFCHAT_SAMPLING_BOOST_MODELS = []

# outage models won't be sampled.
VISION_OUTAGE_MODELS = []
PDFCHAT_OUTAGE_MODELS = []


def get_vqa_sample():
Expand Down Expand Up @@ -253,16 +266,19 @@ def add_text(
request: gr.Request,
):
if isinstance(chat_input, dict):
text, images = chat_input["text"], chat_input["files"]
text, files = chat_input["text"], chat_input["files"]
else:
text = chat_input
images = []
files = []

ip = get_ip(request)
logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
states = [state0, state1]
model_selectors = [model_selector0, model_selector1]

images = [file for file in files if is_image(file)]
pdfs = [file for file in files if is_pdf(file)]

# Init states if necessary
if states[0] is None:
assert states[1] is None
Expand All @@ -279,6 +295,26 @@ def add_text(
State(model_left, is_vision=True),
State(model_right, is_vision=True),
]
elif len(pdfs) > 0:
model_left, model_right = get_battle_pair(
context.all_pdfchat_models,
PDFCHAT_BATTLE_TARGETS,
PDFCHAT_OUTAGE_MODELS,
PDFCHAT_SAMPLING_WEIGHTS,
PDFCHAT_SAMPLING_BOOST_MODELS,
)

# Save an unique id for mapping conversation back to the file on google cloud.
unique_id = hash_pdf(pdfs[0])

states = [
State(model_left, is_vision=False, pdf_id=unique_id),
State(model_right, is_vision=False, pdf_id=unique_id),
]
upload_pdf_file_to_gcs(
pdf_file_path=pdfs[0],
filename=unique_id,
)
else:
model_left, model_right = get_battle_pair(
context.all_text_models,
Expand All @@ -287,7 +323,6 @@ def add_text(
SAMPLING_WEIGHTS,
SAMPLING_BOOST_MODELS,
)

states = [
State(model_left, is_vision=False),
State(model_right, is_vision=False),
Expand All @@ -307,10 +342,30 @@ def add_text(
+ [""]
)

if len(pdfs) > 0 and get_pdf_num_page(pdfs[0]) > PDF_PAGE_LIMIT:
logger.info(f"pdf page limit exceeded. ip: {ip}. text: {text}")
for i in range(num_sides):
states[i].skip_next = True
return (
states
+ [x.to_gradio_chatbot() for x in states]
+ [
{
"text": PDF_LIMIT_MSG
+ " PLEASE CLICK 🎲 NEW ROUND TO START A NEW CONVERSATION."
},
"",
no_change_btn,
]
+ [no_change_btn] * 7
+ [""]
)

model_list = [states[i].model_name for i in range(num_sides)]

images = convert_images_to_conversation_format(images)

# TODO: add PDF moderator
text, image_flagged, csam_flag = moderate_input(
state0, text, text, model_list, images, ip
)
Expand All @@ -323,11 +378,12 @@ def add_text(
return (
states
+ [x.to_gradio_chatbot() for x in states]
+ [{"text": CONVERSATION_LIMIT_MSG}, "", no_change_btn]
+ [
{"text": CONVERSATION_LIMIT_MSG},
"",
no_change_btn,
]
* 7
+ [no_change_btn] * 7
+ [""]
)

Expand All @@ -351,10 +407,13 @@ def add_text(
)

text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
post_processed_text = _prepare_text_with_pdf(text, pdfs)

for i in range(num_sides):
post_processed_text = _prepare_text_with_image(
states[i], text, images, csam_flag=csam_flag
states[i], post_processed_text, images, csam_flag=csam_flag
)

states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
states[i].conv.append_message(states[i].conv.roles[1], None)
states[i].skip_next = False
Expand Down Expand Up @@ -471,7 +530,7 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
)

multimodal_textbox = gr.MultimodalTextbox(
file_types=["image"],
file_types=["image", ".pdf"],
CodingWithTim marked this conversation as resolved.
Show resolved Hide resolved
show_label=False,
container=True,
placeholder="Enter your prompt or add image here",
Expand Down
2 changes: 2 additions & 0 deletions fastchat/serve/gradio_global_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ class Context:
all_text_models: List[str] = field(default_factory=list)
vision_models: List[str] = field(default_factory=list)
all_vision_models: List[str] = field(default_factory=list)
pdfchat_models: List[str] = field(default_factory=list)
all_pdfchat_models: List[str] = field(default_factory=list)
models: List[str] = field(default_factory=list)
all_models: List[str] = field(default_factory=list)
Loading
Loading