Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PDF file uploads as context for LLM queries #3638

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit — hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions fastchat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
IMAGE_MODERATION_MSG = (
"$MODERATION$ YOUR IMAGE VIOLATES OUR CONTENT MODERATION GUIDELINES."
)
PDF_MODERATION_MSG = "$MODERATION$ YOUR PDF VIOLATES OUR CONTENT MODERATION GUIDELINES."
MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES."
CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
Expand All @@ -39,6 +40,9 @@
)
# Maximum conversation turns
CONVERSATION_TURN_LIMIT = 50
# Maximum PDF Page Limit
PDF_PAGE_LIMIT = 50
PDF_LIMIT_MSG = f"YOU HAVE REACHED THE MAXIMUM PDF PAGE LIMIT ({PDF_PAGE_LIMIT} PAGES). PLEASE UPLOAD A SMALLER DOCUMENT."
# Session expiration time
SESSION_EXPIRATION_TIME = 3600
# The output dir of log files
Expand Down
33 changes: 27 additions & 6 deletions fastchat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,18 +362,39 @@ def update_last_message(self, message: str):
def to_gradio_chatbot(self):
"""Convert the conversation to gradio chatbot format."""
from fastchat.serve.vision.image import ImageFormat
import re

ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) is tuple:
msg, images = msg
image = images[0] # Only one image on gradio at one time
if image.image_format == ImageFormat.URL:
img_str = f'<img src="{image.url}" alt="user upload image" />'
elif image.image_format == ImageFormat.BYTES:
img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
msg = img_str + msg.replace("<image>\n", "").strip()

pattern = re.compile("!\[\]\(_page_\d_Figure_\d\.jpeg\)")
embed_locations = pattern.findall(msg)

pdfchat = False
for i, embed_str in enumerate(embed_locations):
if i >= len(images):
break

image = images[i]
msg = msg.replace(
embed_str,
f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="document image" />',
)
pdfchat = True

if not pdfchat:
# vision arena only supports one image on gradio at one time
image = images[0]
if image.image_format == ImageFormat.URL:
img_str = (
f'<img src="{image.url}" alt="user upload image" />'
)
elif image.image_format == ImageFormat.BYTES:
img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
msg = img_str + msg.replace("<image>\n", "").strip()

ret.append([msg, None])
else:
Expand Down
169 changes: 163 additions & 6 deletions fastchat/serve/gradio_block_arena_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,29 @@ def get_vqa_sample():
return (res, path)


def is_pdf(file_path):
    """Return True if the file at *file_path* starts with the PDF magic bytes.

    Checks the leading ``%PDF-`` signature instead of trusting the file
    extension, so renamed non-PDF files are rejected.

    Args:
        file_path: Path to the file to inspect.

    Returns:
        True when the file begins with ``%PDF-``; False when it does not or
        when the file cannot be read.
    """
    try:
        with open(file_path, "rb") as file:
            # Every valid PDF begins with the 5-byte signature "%PDF-".
            return file.read(5) == b"%PDF-"
    except OSError as e:
        # Narrowed from a bare Exception: only I/O failures (missing file,
        # permission denied, ...) mean "not a readable PDF"; programming
        # errors should still propagate.
        print(f"Error: {e}")
        return False


def set_visible_image(textbox):
    """Decide whether the image-preview column should be shown for an upload.

    Shows the column for a single image upload; warns (but still shows it)
    when more than one file is attached; hides it otherwise.
    """
    import filetype

    uploads = textbox["files"]
    if not uploads:
        # Nothing attached — keep the preview hidden.
        return invisible_image_column
    if len(uploads) > 1:
        # Multi-file uploads are unsupported; surface a warning to the user.
        gr.Warning(
            "We only support single image or document conversations. Please start a new round if you would like to chat using this image or document."
        )
        return visible_image_column
    if filetype.is_image(uploads[0]):
        return visible_image_column
    return invisible_image_column


def set_invisible_image():
Expand Down Expand Up @@ -166,6 +179,113 @@ def report_csam_image(state, image):
pass


def wrap_pdfchat_query(query, document):
    """Wrap a user query together with parsed PDF content for an LLM prompt.

    The document text is placed in a collapsible HTML ``<details>`` section
    (rendered by gradio's markdown view) followed by the raw user query, so
    the model sees both while the UI stays compact.

    Args:
        query: The user's question.
        document: The parsed (markdown) text of the uploaded document.

    Returns:
        A single prompt string combining context and query.
    """
    # Dead commented-out alternative format removed; see git history if the
    # plain "document + question" framing is ever needed again.
    return (
        f"Answer the user query given the context.\n"
        f"[QUERY CONTEXT]\n"
        f"<details>\n"
        f"<summary>Expand context details</summary>\n\n"
        f"{document}\n\n"
        f"</details>"
        f"\n\n[USER QUERY]\n\n{query}"
    )


# LLAMA_PARSE_MAX_RETRY = 2
# LLAMAPARSE_SUPPORTED_LANGS = {
# "English": "en",
# "Chinese": "ch_sim",
# "Russian": "ru",
# "Spanish": "es",
# "Japanese": "ja",
# "Korean": "ko",
# "French": "fr",
# "German": "de",
# "Vietnamese": "vi",
# }


# def parse_pdf(file_path):
# from llama_parse import LlamaParse

# assert (
# "LLAMA_CLOUD_API_KEY" in os.environ
# ), "Make sure to specify LlamaParse API key."

# for _ in range(LLAMA_PARSE_MAX_RETRY):
# try:
# documents = LlamaParse(
# result_type="markdown",
# verbose=True,
# languages=list(LLAMAPARSE_SUPPORTED_LANGS.values()),
# accurate_mode=True,
# ).load_data(file_path)
# assert len(documents) > 0
# break
# except AssertionError as e:
# continue

# output = "\n".join(
# [f"Page {i+1}:\n{doc.text}\n" for i, doc in enumerate(documents)]
# )

# return output


# Number of attempts for a single PDF-parsing call before giving up.
PDFPARSE_MAX_RETRY = 2
# Human-readable language name -> ISO 639-1 code for the languages the PDF
# parser is asked to handle.
PDFPARSE_SUPPORTED_LANGS = {
    "English": "en",
    "Chinese": "zh",
    "Russian": "ru",
    "Spanish": "es",
    "Japanese": "ja",
    "Korean": "ko",
    "French": "fr",
    "German": "de",
    "Vietnamese": "vi",
}
# Configuration handed to marker's ConfigParser: emit markdown and enable all
# supported languages (comma-separated codes).
MARKER_PDFPARSE_CONFIG = {
    "output_format": "markdown",
    "languages": ",".join(PDFPARSE_SUPPORTED_LANGS.values()),
}


def parse_pdf(file_path):
    """Parse a PDF into markdown text plus a list of embedded PIL images.

    Uses the third-party ``marker`` converter configured by
    MARKER_PDFPARSE_CONFIG. Retries up to PDFPARSE_MAX_RETRY times when an
    AssertionError is raised; other exceptions propagate to the caller.

    Args:
        file_path: Path to the PDF file on disk.

    Returns:
        ``(markdown_text, images)`` on success, or ``(None, None)`` when
        every attempt fails — callers must handle the None case.
    """
    from marker.config.parser import ConfigParser
    from marker.models import create_model_dict
    from marker.converters.pdf import PdfConverter

    output_md, output_images = None, None
    for _ in range(PDFPARSE_MAX_RETRY):
        try:
            config_parser = ConfigParser(MARKER_PDFPARSE_CONFIG)

            converter = PdfConverter(
                config=config_parser.generate_config_dict(),
                artifact_dict=create_model_dict(),
                processor_list=config_parser.get_processors(),
                renderer=config_parser.get_renderer(),
            )
            rendered = converter(file_path)
            output_md = rendered.markdown
            output_images = list(rendered.images.values())
            break
        except AssertionError:
            # Transient parser failure — retry with a fresh converter.
            # NOTE(review): only AssertionError is retried; confirm marker
            # actually signals transient failures this way.
            continue

    return output_md, output_images


def _prepare_text_with_image(state, text, images, csam_flag):
if len(images) > 0:
if len(state.conv.get_images()) > 0:
Expand All @@ -177,6 +297,29 @@ def _prepare_text_with_image(state, text, images, csam_flag):
return text


# def _prepare_text_with_pdf(text, pdfs):
# if len(pdfs) > 0:
# document_content = parse_pdf(pdfs[0])
# print("Document processed")
# text = wrap_pdfchat_query(text, document_content)

# return text


def _prepare_text_with_pdf(text, pdfs):
    """Attach parsed-PDF context (and any extracted images) to the user text.

    Returns the original text when no PDF is attached; otherwise the wrapped
    query string, or a ``(wrapped_text, images)`` tuple when the document
    also yields images.
    """
    if not pdfs:
        return text

    parsed_text, extracted_imgs = parse_pdf(pdfs[0])
    print("Document processed")
    wrapped_text = wrap_pdfchat_query(text, parsed_text)

    conv_imgs = convert_pdf_images_to_conversation_format(extracted_imgs)

    # NOTE(review): callers must cope with both return shapes (str vs tuple).
    return (wrapped_text, conv_imgs) if conv_imgs else wrapped_text


# NOTE(chris): take multiple images later on
def convert_images_to_conversation_format(images):
import base64
Expand All @@ -191,6 +334,20 @@ def convert_images_to_conversation_format(images):
return conv_images


def convert_pdf_images_to_conversation_format(images):
    """Convert PIL images extracted from a PDF into conversation Image objects.

    Args:
        images: Iterable of PIL image objects produced by the PDF parser.

    Returns:
        A list of conversation-format image dicts (empty when *images* is
        empty).
    """
    # Endpoint accepts 5 MB; base64 encoding inflates payloads by ~1.5x, so
    # cap the raw image size accordingly.
    MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB = 5 / 1.5
    # The PDF parser yields PIL image objects (not file paths), so each one
    # is passed through via ``pil_img`` with an empty URL. The redundant
    # emptiness guard was dropped: iterating an empty list is already a no-op.
    return [
        Image(url="").to_conversation_format(
            MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB, pil_img=img
        )
        for img in images
    ]


def moderate_input(state, text, all_conv_text, model_list, images, ip):
text_flagged = moderation_filter(all_conv_text, model_list)
# flagged = moderation_filter(text, [state.model_name])
Expand Down
Loading
Loading