Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefer tesserocr over easyocr, if available (backport #369) #391

Merged
merged 2 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
click>=8.1.7,<9.0.0
datasets>=2.18.0,<3.0.0
docling>=2.4.2,<3.0.0
docling[tesserocr]>=2.4.2,<3.0.0
GitPython>=3.1.42,<4.0.0
httpx>=0.25.0,<1.0.0
instructlab-schema>=0.4.0
Expand Down
71 changes: 61 additions & 10 deletions src/instructlab/sdg/utils/chunkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
from datasets import Dataset
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import (
ConversionStatus,
DocumentConverter,
PdfFormatOption,
from docling.datamodel.pipeline_options import (
EasyOcrOptions,
OcrOptions,
PdfPipelineOptions,
TesseractOcrOptions,
)
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
from tabulate import tabulate
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)
_DEFAULT_CHUNK_OVERLAP = 100
Expand All @@ -35,6 +33,38 @@ def _num_chars_from_tokens(num_tokens) -> int:
return int(num_tokens * 4) # 1 token ~ 4 English character


def resolve_ocr_options() -> OcrOptions:
    """Pick the OCR engine to use when converting PDF documents.

    Preference order: tesserocr first, then easyocr (pinned to CPU).
    Returns ``None`` when neither backend can be imported, in which case
    the caller should disable OCR entirely.
    """
    # Preferred engine: tesserocr.
    try:
        # pylint: disable=import-outside-toplevel
        # Third Party
        from docling.models.tesseract_ocr_model import TesseractOcrModel

        tesseract_options = TesseractOcrOptions()
        # Instantiating the model verifies the backend actually loads.
        _ = TesseractOcrModel(True, tesseract_options)
        return tesseract_options
    except ImportError:
        # tesserocr is not installed; fall through to the next engine.
        pass

    # Fallback engine: easyocr.
    try:
        easyocr_options = EasyOcrOptions()
        # Keep easyocr models on the CPU instead of GPU
        easyocr_options.use_gpu = False
        # triggers torch loading, import lazily
        # pylint: disable=import-outside-toplevel
        # Third Party
        from docling.models.easyocr_model import EasyOcrModel

        _ = EasyOcrModel(True, easyocr_options)
        return easyocr_options
    except ImportError:
        # no easyocr either, so don't use any OCR
        logger.error(
            "Failed to load Tesseract and EasyOCR - disabling optical character recognition in PDF documents"
        )
        return None


class FileTypes(Enum):
MD = ".md"
PDF = ".pdf"
Expand Down Expand Up @@ -208,13 +238,24 @@ def chunk_documents(self) -> List:
Returns:
List: a list of chunks from the documents
"""
# triggers torch loading, import lazily
# pylint: disable=import-outside-toplevel
# Third Party
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline

if self.document_paths == []:
return []

model_artifacts_path = StandardPdfPipeline.download_models_hf()
pipeline_options = PdfPipelineOptions(artifacts_path=model_artifacts_path)
# Keep OCR models on the CPU instead of GPU
pipeline_options.ocr_options.use_gpu = False
pipeline_options = PdfPipelineOptions(
artifacts_path=model_artifacts_path,
do_ocr=False,
)
ocr_options = resolve_ocr_options()
if ocr_options is not None:
pipeline_options.do_ocr = True
pipeline_options.ocr_options = ocr_options
converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
Expand Down Expand Up @@ -309,6 +350,11 @@ def create_tokenizer(self, model_name: str):
Returns:
AutoTokenizer: The tokenizer instance.
"""
# import lazily to not load transformers at top level
# pylint: disable=import-outside-toplevel
# Third Party
from transformers import AutoTokenizer

try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Successfully loaded tokenizer from: {model_name}")
Expand Down Expand Up @@ -540,6 +586,11 @@ def export_documents(self, converted_docs: Iterable[ConversionResult]):
Returns:
Path: path to directory with docling json artifacts
"""
# triggers torch loading, import lazily
# pylint: disable=import-outside-toplevel
# Third Party
from docling.document_converter import ConversionStatus

docling_artifacts_path = self.output_dir / "docling-artifacts"
docling_artifacts_path.mkdir(parents=True, exist_ok=True)

Expand Down
14 changes: 14 additions & 0 deletions tests/functional/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Standard
import pathlib
import typing

# Third Party
import pytest

TESTS_PATH = pathlib.Path(__file__).parent.parent.absolute()


@pytest.fixture
def testdata_path() -> typing.Generator[pathlib.Path, None, None]:
    """Yield the absolute path of the shared test data directory."""
    data_dir = TESTS_PATH / "testdata"
    yield data_dir
11 changes: 11 additions & 0 deletions tests/functional/test_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
import pathlib
import subprocess
import sys


def test_sdg_imports(testdata_path: pathlib.Path):
script = testdata_path / "leanimports.py"
subprocess.check_call([sys.executable, str(script)], text=True)
52 changes: 52 additions & 0 deletions tests/test_chunkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

# Standard
from pathlib import Path
from unittest.mock import MagicMock, patch
import tempfile

# Third Party
from docling.datamodel.pipeline_options import EasyOcrOptions, TesseractOcrOptions
import pytest

# First Party
Expand All @@ -13,6 +15,7 @@
DocumentChunker,
FileTypes,
TextSplitChunker,
resolve_ocr_options,
)

# Local
Expand Down Expand Up @@ -86,3 +89,52 @@ def test_chunker_factory_empty_filetype(documents_dir):
output_dir=temp_dir,
tokenizer_model_name="instructlab/merlinite-7b-lab",
)


def test_resolve_ocr_options_is_not_none():
    """Some OCR backend must resolve on the machine running the tests;
    a None result would mean neither tesserocr nor easyocr is available."""
    assert resolve_ocr_options() is not None


@patch("docling.models.tesseract_ocr_model.TesseractOcrModel")
def test_resolve_ocr_options_prefers_tessserocr(mock_tesseract):
    """When the tesserocr model loads without error, it must be the
    selected OCR backend."""
    mock_tesseract.return_value = MagicMock()
    resolved = resolve_ocr_options()
    assert isinstance(resolved, TesseractOcrOptions)


@patch("docling.models.tesseract_ocr_model.TesseractOcrModel")
def test_resolve_ocr_options_falls_back_to_easyocr(mock_tesseract):
    """When loading the tesserocr model raises ImportError, the resolver
    must fall back to easyocr."""
    mock_tesseract.side_effect = ImportError("mock import error")
    resolved = resolve_ocr_options()
    assert isinstance(resolved, EasyOcrOptions)


@patch("docling.models.tesseract_ocr_model.TesseractOcrModel")
@patch("docling.models.easyocr_model.EasyOcrModel")
@patch("logging.Logger.error")
def test_resolve_ocr_options_none_found_logs_error(
    mock_logger, mock_easyocr, mock_tesseract
):
    """With both backends failing to import, the resolver must return
    None and log an error so users know PDF OCR is disabled."""
    failure = ImportError("mock import error")
    mock_tesseract.side_effect = failure
    mock_easyocr.side_effect = failure
    assert resolve_ocr_options() is None
    mock_logger.assert_called()
15 changes: 15 additions & 0 deletions tests/testdata/leanimports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Helper for test_sdg_imports

Fails (via ModuleNotFoundError) if importing the SDG entry point pulls
in any heavyweight ML dependency. Must run in a fresh interpreter (the
test launches it via subprocess) so sys.modules starts clean.
"""

# Standard
import sys

# block slow imports
# Poison sys.modules BEFORE the real import below: a None entry makes a
# subsequent `import <name>` fail immediately instead of loading it.
for unwanted in ["deepspeed", "llama_cpp", "torch", "transformers", "vllm"]:
    # importlib raises ModuleNotFound when sys.modules value is None.
    # Sanity check: none of these may already be loaded, or the trap is moot.
    assert unwanted not in sys.modules
    sys.modules[unwanted] = None  # type: ignore[assignment]

# First Party
# This will trigger errors if any of the import chain tries to load
# the unwanted modules
from instructlab.sdg.generate_data import generate_data
Loading