Skip to content

Commit

Permalink
Fix get xml
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-piles committed Oct 4, 2024
1 parent b9e6a56 commit 4b21828
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/data_model/PdfImages.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def remove_images():

@staticmethod
def from_pdf_path(pdf_path: str | Path, pdf_name: str = "", xml_file_name: str = ""):
xml_path = Path(join(XMLS_PATH, xml_file_name)) if xml_file_name else None
xml_path = None if not xml_file_name else Path(XMLS_PATH, xml_file_name)

if xml_path and not xml_path.parent.exists():
os.makedirs(xml_path.parent, exist_ok=True)

pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path, str(xml_path) if xml_path else None)
pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path, xml_path)

if pdf_name:
pdf_features.file_name = pdf_name
Expand Down
4 changes: 2 additions & 2 deletions src/pdf_features/PdfFeatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def is_pdf_encrypted(pdf_path):
return False if "File is not encrypted" in result.stdout else True

@staticmethod
def from_pdf_path(pdf_path, xml_path: str = None):
def from_pdf_path(pdf_path, xml_path: str | Path = None):
remove_xml = False if xml_path else True
xml_path = xml_path if xml_path else join(tempfile.gettempdir(), "pdf_etree.xml")
xml_path = str(xml_path) if xml_path else join(tempfile.gettempdir(), "pdf_etree.xml")

if PdfFeatures.is_pdf_encrypted(pdf_path):
subprocess.run(["qpdf", "--decrypt", "--replace-input", pdf_path])
Expand Down
3 changes: 1 addition & 2 deletions src/pdf_layout_analysis/get_xml.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
from os.path import join
from pathlib import Path

from configuration import XMLS_PATH


def get_xml(xml_file_name: str) -> str:
xml_file_path = Path(join(XMLS_PATH, xml_file_name))
xml_file_path = Path(XMLS_PATH, xml_file_name)

with open(xml_file_path, mode="r") as file:
content = file.read()
Expand Down
23 changes: 11 additions & 12 deletions src/pdf_layout_analysis/run_pdf_layout_analysis_fast.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
from os.path import join
from pathlib import Path
from typing import AnyStr

from data_model.PdfImages import PdfImages
Expand All @@ -12,26 +10,27 @@
from pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer
from pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration

from configuration import ROOT_PATH, service_logger, XMLS_PATH
from configuration import ROOT_PATH, service_logger
from data_model.SegmentBox import SegmentBox


def analyze_pdf_fast(file: AnyStr, xml_file_name: str = "", extraction_format: str = "") -> list[dict]:
pdf_path = pdf_content_to_pdf_path(file)
service_logger.info("Creating Paragraph Tokens [fast]")

xml_path = Path(join(XMLS_PATH, xml_file_name)) if xml_file_name else None
pdf_images = PdfImages.from_pdf_path(pdf_path=pdf_path, pdf_name="", xml_file_name=xml_file_name)

if xml_path and not xml_path.parent.exists():
os.makedirs(xml_path.parent, exist_ok=True)

pdf_images = PdfImages.from_pdf_path(pdf_path, str(xml_path) if xml_path else None)
pdf_features = pdf_images.pdf_features
token_type_trainer = TokenTypeTrainer([pdf_features], ModelConfiguration())
token_type_trainer = TokenTypeTrainer([pdf_images.pdf_features], ModelConfiguration())
token_type_trainer.set_token_types(join(ROOT_PATH, "models", "token_type_lightgbm.model"))
trainer = ParagraphExtractorTrainer(pdfs_features=[pdf_features], model_configuration=PARAGRAPH_EXTRACTION_CONFIGURATION)

trainer = ParagraphExtractorTrainer(
pdfs_features=[pdf_images.pdf_features], model_configuration=PARAGRAPH_EXTRACTION_CONFIGURATION
)
segments = trainer.get_pdf_segments(join(ROOT_PATH, "models", "paragraph_extraction_lightgbm.model"))

extract_formula_format(pdf_images, segments)
if extraction_format:
extract_table_format(pdf_images, segments, extraction_format)
return [SegmentBox.from_pdf_segment(pdf_segment, pdf_features.pages).to_dict() for pdf_segment in segments]

pdf_images.remove_images()
return [SegmentBox.from_pdf_segment(pdf_segment, pdf_images.pdf_features.pages).to_dict() for pdf_segment in segments]
24 changes: 24 additions & 0 deletions src/test_end_to_end.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from time import sleep

import requests
from unittest import TestCase
from configuration import ROOT_PATH
Expand Down Expand Up @@ -116,6 +118,28 @@ def test_regular_pdf_fast(self):
self.assertEqual(842, results_dict[0]["page_height"])
self.assertEqual("Section header", results_dict[0]["type"])

def test_save_xml_fast(self):
xml_name = "test_fast.xml"
with open(f"{ROOT_PATH}/test_pdfs/regular.pdf", "rb") as stream:
files = {"file": stream}
data = {"fast": "True"}
requests.post(f"{self.service_url}/save_xml/{xml_name}", files=files, data=data)

result_xml = requests.get(f"{self.service_url}/get_xml/{xml_name}")
self.assertEqual(200, result_xml.status_code)
self.assertIsNotNone(result_xml.text)

def test_save_xml(self):
xml_name = "test.xml"
with open(f"{ROOT_PATH}/test_pdfs/regular.pdf", "rb") as stream:
files = {"file": stream}
data = {"fast": "False"}
requests.post(f"{self.service_url}/save_xml/{xml_name}", files=files, data=data)

result_xml = requests.get(f"{self.service_url}/get_xml/{xml_name}")
self.assertEqual(200, result_xml.status_code)
self.assertIsNotNone(result_xml.text)

def test_korean(self):
with open(f"{ROOT_PATH}/test_pdfs/korean.pdf", "rb") as stream:
files = {"file": stream}
Expand Down

0 comments on commit 4b21828

Please sign in to comment.