diff --git a/src/data_model/PdfImages.py b/src/data_model/PdfImages.py index 3b36bfc..5013ae3 100644 --- a/src/data_model/PdfImages.py +++ b/src/data_model/PdfImages.py @@ -38,12 +38,12 @@ def remove_images(): @staticmethod def from_pdf_path(pdf_path: str | Path, pdf_name: str = "", xml_file_name: str = ""): - xml_path = Path(join(XMLS_PATH, xml_file_name)) if xml_file_name else None + xml_path = None if not xml_file_name else Path(XMLS_PATH, xml_file_name) if xml_path and not xml_path.parent.exists(): os.makedirs(xml_path.parent, exist_ok=True) - pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path, str(xml_path) if xml_path else None) + pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path, xml_path) if pdf_name: pdf_features.file_name = pdf_name diff --git a/src/pdf_features/PdfFeatures.py b/src/pdf_features/PdfFeatures.py index 523c2e9..2625304 100644 --- a/src/pdf_features/PdfFeatures.py +++ b/src/pdf_features/PdfFeatures.py @@ -107,9 +107,9 @@ def is_pdf_encrypted(pdf_path): return False if "File is not encrypted" in result.stdout else True @staticmethod - def from_pdf_path(pdf_path, xml_path: str = None): + def from_pdf_path(pdf_path, xml_path: str | Path = None): remove_xml = False if xml_path else True - xml_path = xml_path if xml_path else join(tempfile.gettempdir(), "pdf_etree.xml") + xml_path = str(xml_path) if xml_path else join(tempfile.gettempdir(), "pdf_etree.xml") if PdfFeatures.is_pdf_encrypted(pdf_path): subprocess.run(["qpdf", "--decrypt", "--replace-input", pdf_path]) diff --git a/src/pdf_layout_analysis/get_xml.py b/src/pdf_layout_analysis/get_xml.py index 78ed55f..313ee2a 100644 --- a/src/pdf_layout_analysis/get_xml.py +++ b/src/pdf_layout_analysis/get_xml.py @@ -1,12 +1,11 @@ import os -from os.path import join from pathlib import Path from configuration import XMLS_PATH def get_xml(xml_file_name: str) -> str: - xml_file_path = Path(join(XMLS_PATH, xml_file_name)) + xml_file_path = Path(XMLS_PATH, xml_file_name) with open(xml_file_path, mode="r") as file: content = file.read() diff --git a/src/pdf_layout_analysis/run_pdf_layout_analysis_fast.py b/src/pdf_layout_analysis/run_pdf_layout_analysis_fast.py index a15caa7..a7879b3 100644 --- a/src/pdf_layout_analysis/run_pdf_layout_analysis_fast.py +++ b/src/pdf_layout_analysis/run_pdf_layout_analysis_fast.py @@ -1,6 +1,4 @@ -import os from os.path import join -from pathlib import Path from typing import AnyStr from data_model.PdfImages import PdfImages @@ -12,7 +10,7 @@ from pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer from pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration -from configuration import ROOT_PATH, service_logger, XMLS_PATH +from configuration import ROOT_PATH, service_logger from data_model.SegmentBox import SegmentBox @@ -20,18 +18,19 @@ def analyze_pdf_fast(file: AnyStr, xml_file_name: str = "", extraction_format: s pdf_path = pdf_content_to_pdf_path(file) service_logger.info("Creating Paragraph Tokens [fast]") - xml_path = Path(join(XMLS_PATH, xml_file_name)) if xml_file_name else None + pdf_images = PdfImages.from_pdf_path(pdf_path=pdf_path, pdf_name="", xml_file_name=xml_file_name) - if xml_path and not xml_path.parent.exists(): - os.makedirs(xml_path.parent, exist_ok=True) - - pdf_images = PdfImages.from_pdf_path(pdf_path, str(xml_path) if xml_path else None) - pdf_features = pdf_images.pdf_features - token_type_trainer = TokenTypeTrainer([pdf_features], ModelConfiguration()) + token_type_trainer = TokenTypeTrainer([pdf_images.pdf_features], ModelConfiguration()) token_type_trainer.set_token_types(join(ROOT_PATH, "models", "token_type_lightgbm.model")) - trainer = ParagraphExtractorTrainer(pdfs_features=[pdf_features], model_configuration=PARAGRAPH_EXTRACTION_CONFIGURATION) + + trainer = ParagraphExtractorTrainer( + pdfs_features=[pdf_images.pdf_features], model_configuration=PARAGRAPH_EXTRACTION_CONFIGURATION + ) segments = trainer.get_pdf_segments(join(ROOT_PATH, "models", "paragraph_extraction_lightgbm.model")) + extract_formula_format(pdf_images, segments) if extraction_format: extract_table_format(pdf_images, segments, extraction_format) - return [SegmentBox.from_pdf_segment(pdf_segment, pdf_features.pages).to_dict() for pdf_segment in segments] + + pdf_images.remove_images() + return [SegmentBox.from_pdf_segment(pdf_segment, pdf_images.pdf_features.pages).to_dict() for pdf_segment in segments] diff --git a/src/test_end_to_end.py b/src/test_end_to_end.py index 95e011a..0728b4a 100644 --- a/src/test_end_to_end.py +++ b/src/test_end_to_end.py @@ -1,3 +1,5 @@ +from time import sleep + import requests from unittest import TestCase from configuration import ROOT_PATH @@ -116,6 +118,28 @@ def test_regular_pdf_fast(self): self.assertEqual(842, results_dict[0]["page_height"]) self.assertEqual("Section header", results_dict[0]["type"]) + def test_save_xml_fast(self): + xml_name = "test_fast.xml" + with open(f"{ROOT_PATH}/test_pdfs/regular.pdf", "rb") as stream: + files = {"file": stream} + data = {"fast": "True"} + requests.post(f"{self.service_url}/save_xml/{xml_name}", files=files, data=data) + + result_xml = requests.get(f"{self.service_url}/get_xml/{xml_name}") + self.assertEqual(200, result_xml.status_code) + self.assertIsNotNone(result_xml.text) + + def test_save_xml(self): + xml_name = "test.xml" + with open(f"{ROOT_PATH}/test_pdfs/regular.pdf", "rb") as stream: + files = {"file": stream} + data = {"fast": "False"} + requests.post(f"{self.service_url}/save_xml/{xml_name}", files=files, data=data) + + result_xml = requests.get(f"{self.service_url}/get_xml/{xml_name}") + self.assertEqual(200, result_xml.status_code) + self.assertIsNotNone(result_xml.text) + def test_korean(self): with open(f"{ROOT_PATH}/test_pdfs/korean.pdf", "rb") as stream: files = {"file": stream}