Markdown table and formula extraction support #85

Merged
merged 13 commits on Sep 20, 2024
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,7 +1,7 @@
FROM pytorch/pytorch:2.4.0-cuda11.8-cudnn9-runtime

RUN apt-get update
RUN apt-get install --fix-missing -y -q --no-install-recommends libgomp1 ffmpeg libsm6 libxext6 pdftohtml git ninja-build g++ qpdf
RUN apt-get install --fix-missing -y -q --no-install-recommends libgomp1 ffmpeg libsm6 libxext6 pdftohtml git ninja-build g++ qpdf pandoc

RUN mkdir -p /app/src
RUN mkdir -p /app/models
22 changes: 22 additions & 0 deletions README.md
@@ -159,6 +159,28 @@ we process them after sorting all segments with content. To determine their read
using distance as a criterion.


### Extracting Tables and Formulas

Our service can extract tables and formulas in different output formats.

By default, the "text" property of formula segments contains the formula in LaTeX format.

Tables can also be extracted in formats like "markdown", "latex", or "html", but this is not enabled by default.
To extract tables in one of these formats, set the "extraction_format" parameter. Some example usages are shown below:

```
curl -X POST -F 'file=@/PATH/TO/PDF/pdf_name.pdf' localhost:5060 -F "extraction_format=latex"
```
```
curl -X POST -F 'file=@/PATH/TO/PDF/pdf_name.pdf' localhost:5060/fast -F "extraction_format=markdown"
```

Be aware that this additional extraction step can make processing take considerably longer, especially if the document contains a large number of tables.

(For table extraction we use [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy),
and for formula extraction we use [RapidLaTeXOCR](https://github.com/RapidAI/RapidLaTeXOCR).)
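
The same requests can be made from Python. Below is a minimal sketch using the `requests` library; it assumes the service is exposed on localhost:5060 as in the curl examples above, and that table and formula segments report their type as "Table" and "Formula".

```
import requests

# Hypothetical input path; any PDF containing tables or formulas works.
pdf_path = "/PATH/TO/PDF/pdf_name.pdf"

with open(pdf_path, "rb") as pdf_file:
    response = requests.post(
        "http://localhost:5060",
        files={"file": pdf_file},
        data={"extraction_format": "markdown"},  # or "latex" / "html"
    )

# Each segment is a dictionary; table and formula segments carry the
# extracted content in their "text" property (assuming the segment types
# are reported as "Table" and "Formula").
for segment in response.json():
    if segment.get("type") in ("Table", "Formula"):
        print(segment["text"])
```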


## Benchmarks

### Performance
5 changes: 4 additions & 1 deletion requirements.txt
@@ -19,4 +19,7 @@ lightgbm==4.5.0
huggingface_hub==0.24.3
setuptools==72.1.0
roman==4.2
hydra-core==1.3.2
hydra-core==1.3.2
pypandoc==1.13
rapid-latex-ocr==0.0.7
git+https://github.com/UniModal4Reasoning/StructEqTable-Deploy.git@fd06078bfa9364849eb39330c075dd63cbed73ff
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@
    name=PROJECT_NAME,
    packages=["pdf_tokens_type_trainer", "pdf_features", "pdf_token_type_labels", "fast_trainer"],
    package_dir={"": "src"},
    version="0.8",
    version="0.9",
    url="https://github.com/huridocs/pdf-document-layout-analysis",
    author="HURIDOCS",
    description="This tool is for PDF document layout analysis",
12 changes: 5 additions & 7 deletions src/app.py
@@ -3,7 +3,6 @@
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import PlainTextResponse
from starlette.concurrency import run_in_threadpool

from catch_exceptions import catch_exceptions
from configuration import service_logger
from pdf_layout_analysis.get_xml import get_xml
@@ -30,19 +29,18 @@ async def error():

@app.post("/")
@catch_exceptions
async def run(file: UploadFile = File(...), fast: bool = Form(False)):
async def run(file: UploadFile = File(...), fast: bool = Form(False), extraction_format: str = Form("")):
    if fast:
        return await run_in_threadpool(analyze_pdf_fast, file.file.read(), "")

    return await run_in_threadpool(analyze_pdf, file.file.read(), "")
        return await run_in_threadpool(analyze_pdf_fast, file.file.read(), "", extraction_format)
    return await run_in_threadpool(analyze_pdf, file.file.read(), "", extraction_format)


@app.post("/save_xml/{xml_file_name}")
@catch_exceptions
async def analyze_and_save_xml(file: UploadFile = File(...), xml_file_name: str | None = None, fast: bool = Form(False)):
    if fast:
        return await run_in_threadpool(analyze_pdf_fast, file.file.read(), xml_file_name)
    return await run_in_threadpool(analyze_pdf, file.file.read(), xml_file_name)
        return await run_in_threadpool(analyze_pdf_fast, file.file.read(), xml_file_name, "")
    return await run_in_threadpool(analyze_pdf, file.file.read(), xml_file_name, "")


@app.get("/get_xml/{xml_file_name}", response_class=PlainTextResponse)
35 changes: 35 additions & 0 deletions src/extraction_formats/extract_formula_formats.py
@@ -0,0 +1,35 @@
import io
from PIL.Image import Image
from rapid_latex_ocr import LatexOCR
from data_model.PdfImages import PdfImages
from fast_trainer.PdfSegment import PdfSegment
from pdf_token_type_labels.TokenType import TokenType


def get_latex_format(model: LatexOCR, formula_image: Image):
    buffer = io.BytesIO()
    formula_image.save(buffer, format="jpeg")
    image_bytes = buffer.getvalue()
    result, elapsed_time = model(image_bytes)
    return result


def extract_formula_format(pdf_images: PdfImages, predicted_segments: list[PdfSegment]):
    formula_segments = [
        (index, segment) for index, segment in enumerate(predicted_segments) if segment.segment_type == TokenType.FORMULA
    ]
    if not formula_segments:
        return

    model = LatexOCR()

    for index, formula_segment in formula_segments:
        page_image: Image = pdf_images.pdf_images[formula_segment.page_number - 1]
        left, top = formula_segment.bounding_box.left, formula_segment.bounding_box.top
        width, height = formula_segment.bounding_box.width, formula_segment.bounding_box.height
        formula_image = page_image.crop((left, top, left + width, top + height))
        try:
            extracted_formula = get_latex_format(model, formula_image)
        except RuntimeError:
            continue
        predicted_segments[index].text_content = extracted_formula
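
For context, here is a standalone sketch of the RapidLaTeXOCR call used in `get_latex_format` above; it is only an illustration under stated assumptions ("formula.png" is a hypothetical input path).

```
from rapid_latex_ocr import LatexOCR

# Instantiating LatexOCR downloads the recognition model weights on first use.
model = LatexOCR()

# "formula.png" is a hypothetical path to a cropped formula image.
with open("formula.png", "rb") as image_file:
    image_bytes = image_file.read()

# The model returns the recognized LaTeX string and the inference time,
# mirroring the call in get_latex_format above.
latex_string, elapsed_time = model(image_bytes)
print(latex_string)
```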
79 changes: 79 additions & 0 deletions src/extraction_formats/extract_table_formats.py
@@ -0,0 +1,79 @@
import time
from typing import Optional

import torch
from PIL import Image
from struct_eqtable import build_model

from configuration import service_logger
from data_model.PdfImages import PdfImages
from fast_trainer.PdfSegment import PdfSegment
from pdf_token_type_labels.TokenType import TokenType


def get_table_format(
    model,
    raw_image: Image = None,
    image_path: str = "",
    max_waiting_time: int = 1000,
    extraction_format: str = "latex",
) -> str:
    from pypandoc import convert_text

    if not raw_image:
        raw_image = Image.open(image_path)

    start_time = time.time()
    with torch.no_grad():
        output = model(raw_image)

    cost_time = time.time() - start_time

    if cost_time >= max_waiting_time:
        warn_log = (
            f"The table extraction model inference time exceeds the maximum waiting time {max_waiting_time} seconds.\n"
            "Please increase the maximum waiting time or model may not support the type of input table image"
        )
        service_logger.info(warn_log)

    for i, latex_code in enumerate(output):
        for tgt_fmt in [extraction_format]:
            tgt_code = convert_text(latex_code, tgt_fmt, format="latex") if tgt_fmt != "latex" else latex_code
            return tgt_code


def get_model():

    ckpt_path: str = "U4R/StructTable-base"
    max_new_tokens: int = 2048
    max_waiting_time: int = 1000
    use_cpu: bool = False
    tensorrt_path: Optional[str] = None
    model = build_model(ckpt_path, max_new_tokens=max_new_tokens, max_time=max_waiting_time, tensorrt_path=tensorrt_path)
    if not use_cpu and tensorrt_path is None:
        try:
            model = model.cuda()
        except RuntimeError:
            pass
    return model


def extract_table_format(pdf_images: PdfImages, predicted_segments: list[PdfSegment], extraction_format: str):
    table_segments = [
        (index, segment) for index, segment in enumerate(predicted_segments) if segment.segment_type == TokenType.TABLE
    ]
    if not table_segments:
        return

    model = get_model()

    for index, table_segment in table_segments:
        page_image: Image = pdf_images.pdf_images[table_segment.page_number - 1]
        left, top = table_segment.bounding_box.left, table_segment.bounding_box.top
        width, height = table_segment.bounding_box.width, table_segment.bounding_box.height
        table_image = page_image.crop((left, top, left + width, top + height))
        try:
            extracted_table = get_table_format(model, raw_image=table_image, extraction_format=extraction_format)
        except RuntimeError:
            continue
        predicted_segments[index].text_content = extracted_table
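
For context, the format conversion in `get_table_format` relies on pandoc (added to the Dockerfile) through pypandoc. Below is a minimal sketch of that step in isolation, using a made-up LaTeX table in place of StructEqTable's output.

```
from pypandoc import convert_text

# A made-up LaTeX table standing in for the model output.
latex_table = r"""
\begin{tabular}{|l|l|}
\hline
Column 1 & Column 2 \\ \hline
Data 1A & Data 1B \\ \hline
\end{tabular}
"""

# Convert the LaTeX source to the requested target format, as get_table_format does.
markdown_table = convert_text(latex_table, "markdown", format="latex")
print(markdown_table)
```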
8 changes: 7 additions & 1 deletion src/pdf_layout_analysis/run_pdf_layout_analysis.py
@@ -5,6 +5,8 @@
from typing import AnyStr
from data_model.SegmentBox import SegmentBox
from ditod.VGTTrainer import VGTTrainer
from extraction_formats.extract_formula_formats import extract_formula_format
from extraction_formats.extract_table_formats import extract_table_format
from vgt.get_json_annotations import get_annotations
from vgt.get_model_configuration import get_model_configuration
from vgt.get_most_probable_pdf_segments import get_most_probable_pdf_segments
@@ -49,7 +51,7 @@ def predict_doclaynet():
    VGTTrainer.test(configuration, model)


def analyze_pdf(file: AnyStr, xml_file_name: str) -> list[dict]:
def analyze_pdf(file: AnyStr, xml_file_name: str, extraction_format: str = "") -> list[dict]:
    pdf_path = pdf_content_to_pdf_path(file)
    service_logger.info(f"Creating PDF images")
    pdf_images_list: list[PdfImages] = [PdfImages.from_pdf_path(pdf_path, "", xml_file_name)]
@@ -59,6 +61,10 @@ def analyze_pdf(file: AnyStr, xml_file_name: str) -> list[dict]:
    remove_files()
    predicted_segments = get_most_probable_pdf_segments("doclaynet", pdf_images_list, False)
    predicted_segments = get_reading_orders(pdf_images_list, predicted_segments)
    extract_formula_format(pdf_images_list[0], predicted_segments)
    if extraction_format:
        extract_table_format(pdf_images_list[0], predicted_segments, extraction_format)

    return [
        SegmentBox.from_pdf_segment(pdf_segment, pdf_images_list[0].pdf_features.pages).to_dict()
        for pdf_segment in predicted_segments
12 changes: 9 additions & 3 deletions src/pdf_layout_analysis/run_pdf_layout_analysis_fast.py
@@ -3,9 +3,11 @@
from pathlib import Path
from typing import AnyStr

from data_model.PdfImages import PdfImages
from extraction_formats.extract_formula_formats import extract_formula_format
from extraction_formats.extract_table_formats import extract_table_format
from fast_trainer.ParagraphExtractorTrainer import ParagraphExtractorTrainer
from fast_trainer.model_configuration import MODEL_CONFIGURATION as PARAGRAPH_EXTRACTION_CONFIGURATION
from pdf_features.PdfFeatures import PdfFeatures
from pdf_layout_analysis.run_pdf_layout_analysis import pdf_content_to_pdf_path
from pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer
from pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration
@@ -14,7 +16,7 @@
from data_model.SegmentBox import SegmentBox


def analyze_pdf_fast(file: AnyStr, xml_file_name: str = "") -> list[dict]:
def analyze_pdf_fast(file: AnyStr, xml_file_name: str = "", extraction_format: str = "") -> list[dict]:
    pdf_path = pdf_content_to_pdf_path(file)
    service_logger.info("Creating Paragraph Tokens [fast]")

@@ -23,9 +25,13 @@ def analyze_pdf_fast(file: AnyStr, xml_file_name: str = "") -> list[dict]:
    if xml_path and not xml_path.parent.exists():
        os.makedirs(xml_path.parent, exist_ok=True)

    pdf_features = PdfFeatures.from_pdf_path(pdf_path, str(xml_path) if xml_path else None)
    pdf_images = PdfImages.from_pdf_path(pdf_path, str(xml_path) if xml_path else None)
    pdf_features = pdf_images.pdf_features
    token_type_trainer = TokenTypeTrainer([pdf_features], ModelConfiguration())
    token_type_trainer.set_token_types(join(ROOT_PATH, "models", "token_type_lightgbm.model"))
    trainer = ParagraphExtractorTrainer(pdfs_features=[pdf_features], model_configuration=PARAGRAPH_EXTRACTION_CONFIGURATION)
    segments = trainer.get_pdf_segments(join(ROOT_PATH, "models", "paragraph_extraction_lightgbm.model"))
    extract_formula_format(pdf_images, segments)
    if extraction_format:
        extract_table_format(pdf_images, segments, extraction_format)
    return [SegmentBox.from_pdf_segment(pdf_segment, pdf_features.pages).to_dict() for pdf_segment in segments]
28 changes: 28 additions & 0 deletions src/test_end_to_end.py
@@ -203,3 +203,31 @@ def test_text_extraction_fast(self):
        self.assertEqual(response_json.split()[0], "Document")
        self.assertEqual(response_json.split()[1], "Big")
        self.assertEqual(response_json.split()[-1], "TEXT")

    def test_table_extraction(self):
        with open(f"{ROOT_PATH}/test_pdfs/table.pdf", "rb") as stream:
            files = {"file": stream}
            data = {"extraction_format": "markdown"}

            response = requests.post(f"{self.service_url}", files=files, data=data)

            response_json = response.json()
            table_text = response_json[0]["text"]
            self.assertEqual(response.status_code, 200)
            self.assertIn("**Column 1**", table_text.split("\n")[0])
            self.assertIn("**Column 2**", table_text.split("\n")[0])
            self.assertIn("Data 1A", table_text.split("\n")[2])
            self.assertIn("Data 2B", table_text.split("\n")[3])

    def test_formula_extraction(self):
        with open(f"{ROOT_PATH}/test_pdfs/formula.pdf", "rb") as stream:
            files = {"file": stream}
            data = {"extraction_format": "latex"}

            response = requests.post(f"{self.service_url}", files=files, data=data)

            response_json = response.json()
            formula_text = response_json[1]["text"]
            self.assertEqual(response.status_code, 200)
            self.assertIn("E_{_{v r i o r}}", formula_text)
            self.assertIn("-\\ \\Theta||", formula_text)
Binary file added test_pdfs/formula.pdf
Binary file not shown.
Binary file added test_pdfs/table.pdf
Binary file not shown.