Docling models path (backport #362) #392

Closed · wants to merge 12 commits
24 changes: 22 additions & 2 deletions src/instructlab/sdg/generate_data.py
@@ -15,6 +15,7 @@
# instructlab - All of these need to go away (other than sdg) - issue #6
from xdg_base_dirs import xdg_data_dirs, xdg_data_home
import openai
import yaml

# First Party
# pylint: disable=ungrouped-imports
@@ -220,6 +221,23 @@ def _sdg_init(ctx, pipeline):
    data_dirs = [os.path.join(xdg_data_home(), "instructlab", "sdg")]
    data_dirs.extend(os.path.join(dir, "instructlab", "sdg") for dir in xdg_data_dirs())

    docling_model_path = None
    sdg_models_path = docling_model_path
    for d in data_dirs:
        if os.path.exists(os.path.join(d, "models")):
            sdg_models_path = os.path.join(d, "models")
            break

    if sdg_models_path is not None:
        try:
            with open(
                os.path.join(sdg_models_path, "config.yaml"), "r", encoding="utf-8"
            ) as file:
                config = yaml.safe_load(file)
                docling_model_path = config["models"][0]["path"]
        except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
            logger.warning(f"unable to read docling models path from config.yaml {e}")

    for d in data_dirs:
        pipeline_path = os.path.join(d, "pipelines", pipeline)
        if os.path.exists(pipeline_path):
@@ -251,6 +269,7 @@ def load_pipeline(yaml_basename):
load_pipeline("knowledge.yaml"),
load_pipeline("freeform_skills.yaml"),
load_pipeline("grounded_skills.yaml"),
docling_model_path,
)
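Review note: in effect, the new lookup in _sdg_init scans the XDG data dirs for the first models/ directory, reads its config.yaml, and takes models[0].path as the docling artifacts location (falling back to None, which later triggers a download). A minimal standalone sketch of that behavior; the function name and example path are illustrative, not part of this diff:

# Sketch only: restates the lookup added to _sdg_init above.
import os

import yaml


def find_docling_model_path(data_dirs):
    """Return models[0].path from the first <dir>/models/config.yaml found, else None."""
    for d in data_dirs:
        models_dir = os.path.join(d, "models")
        if not os.path.exists(models_dir):
            continue
        try:
            with open(os.path.join(models_dir, "config.yaml"), encoding="utf-8") as f:
                return yaml.safe_load(f)["models"][0]["path"]
        except (FileNotFoundError, NotADirectoryError, PermissionError):
            return None  # same outcome as the logged warning above
    return None


# e.g. find_docling_model_path([os.path.expanduser("~/.local/share/instructlab/sdg")])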


@@ -363,8 +382,8 @@ def generate_data(
max_num_tokens=max_num_tokens,
)

-    knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init(
-        ctx, pipeline
+    knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe, docling_model_path = (
+        _sdg_init(ctx, pipeline)
     )

# Make sure checkpointing is disabled (we don't want this pipeline to load checkpoints from the main pipeline)
@@ -392,6 +411,7 @@ def generate_data(
chunk_word_count,
document_output_dir,
model_name,
docling_model_path=docling_model_path,
)

if not samples:
27 changes: 27 additions & 0 deletions src/instructlab/sdg/utils/chunkers.py
@@ -90,6 +90,7 @@
server_ctx_size=4096,
chunk_word_count=1024,
tokenizer_model_name: str | None = None,
docling_model_path: str | None = None,
):
"""Insantiate the appropriate chunker for the provided document

@@ -145,6 +146,7 @@
output_dir,
chunk_word_count,
tokenizer_model_name,
docling_model_path=docling_model_path,
)

@staticmethod
@@ -219,6 +221,7 @@
output_dir: Path,
chunk_word_count: int,
tokenizer_model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
docling_model_path=None,
):
self.document_paths = document_paths
self.filepaths = filepaths
@@ -231,6 +234,7 @@
)

self.tokenizer = self.create_tokenizer(tokenizer_model_name)
self.docling_model_path = docling_model_path

def chunk_documents(self) -> List:
"""Semantically chunk PDF documents.
@@ -247,6 +251,8 @@
        if self.document_paths == []:
            return []

-        model_artifacts_path = StandardPdfPipeline.download_models_hf()
-        pipeline_options = PdfPipelineOptions(
-            artifacts_path=model_artifacts_path,
@@ -256,6 +262,27 @@
-        if ocr_options is not None:
-            pipeline_options.do_ocr = True
-            pipeline_options.ocr_options = ocr_options
+        if self.docling_model_path is None:
+            logger.info("Docling models not found on disk, downloading models...")
+            self.docling_model_path = StandardPdfPipeline.download_models_hf()
+        else:
+            logger.info("Found the docling models")
+
+        pipeline_options = PdfPipelineOptions(artifacts_path=self.docling_model_path)
+
+        # Keep OCR models on the CPU instead of GPU
+        pipeline_options.ocr_options.use_gpu = False

Check failure on line 276 in src/instructlab/sdg/utils/chunkers.py (GitHub Actions / pylint): E0001: Parsing failed: 'invalid decimal literal (instructlab.sdg.utils.chunkers, line 276)' (syntax-error). Intermediate commits on this branch left unresolved merge-conflict markers in this hunk, which is what pylint tripped on; the hunk above shows them resolved in favor of the rebased version.
converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
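Review note: with the constructor change above, a caller can pin the docling artifacts directory instead of relying on the HuggingFace download. A usage sketch; leaf_node comes from taxonomy parsing (see taxonomy.py below) and is not defined here, and argument values are illustrative:

# Sketch only: exercises the docling_model_path plumbing added above.
from pathlib import Path

from instructlab.sdg.utils.chunkers import DocumentChunker

chunker = DocumentChunker(
    leaf_node=leaf_node,  # from taxonomy parsing; not defined in this sketch
    output_dir=Path("/tmp/document-output"),
    server_ctx_size=4096,
    chunk_word_count=1024,
    tokenizer_model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
    docling_model_path="/mock/docling-models",  # None triggers download_models_hf()
)
chunks = chunker.chunk_documents()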
4 changes: 4 additions & 0 deletions src/instructlab/sdg/utils/taxonomy.py
@@ -24,7 +24,7 @@
import yaml

# Local
from .chunkers import DocumentChunker

Check failure on line 27 in src/instructlab/sdg/utils/taxonomy.py (GitHub Actions / pylint): E0001: Cannot import 'chunkers' due to 'invalid decimal literal (instructlab.sdg.utils.chunkers, line 276)' (syntax-error).

# Initialize the pdf parser
PDFParser = pdf_parser_v1()
@@ -416,6 +416,7 @@
chunk_word_count,
document_output_dir,
model_name,
docling_model_path=None,
):
chunker = DocumentChunker(
leaf_node=leaf_node,
@@ -424,6 +425,7 @@
server_ctx_size=server_ctx_size,
chunk_word_count=chunk_word_count,
tokenizer_model_name=model_name,
docling_model_path=docling_model_path,
)
chunks = chunker.chunk_documents()

@@ -453,6 +455,7 @@
chunk_word_count,
document_output_dir,
model_name,
docling_model_path=None,
):
if not leaf_node:
return []
@@ -464,5 +467,6 @@
chunk_word_count,
document_output_dir,
model_name,
docling_model_path,
)
return _skill_leaf_node_to_samples(leaf_node)
10 changes: 10 additions & 0 deletions tests/conftest.py
@@ -6,6 +6,8 @@

# Standard
from unittest import mock
import pathlib
import typing

# Third Party
from datasets import Dataset
@@ -17,6 +19,14 @@
# Local
from .taxonomy import MockTaxonomy

TESTS_PATH = pathlib.Path(__file__).parent.absolute()


@pytest.fixture
def testdata_path() -> typing.Generator[pathlib.Path, None, None]:
    """Path to local test data directory"""
    yield TESTS_PATH / "testdata"


def get_ctx(**kwargs) -> PipelineContext:
    kwargs.setdefault("client", mock.MagicMock())
36 changes: 35 additions & 1 deletion tests/test_generate_data.py
@@ -20,7 +20,7 @@
import yaml

# First Party
-from instructlab.sdg.generate_data import _context_init, generate_data
+from instructlab.sdg.generate_data import _context_init, _sdg_init, generate_data
from instructlab.sdg.llmblock import LLMBlock
from instructlab.sdg.pipeline import PipelineContext

@@ -548,3 +548,37 @@ def test_context_init_batch_size_optional():
        batch_num_workers=32,
    )
    assert ctx.batch_size == 20


def test_sdg_init_docling_path_config_found(testdata_path):
    with patch.dict(os.environ):
        os.environ["XDG_DATA_HOME"] = str(testdata_path.joinpath("mock_xdg_data_dir"))
        ctx = _context_init(
            None,
            "mixtral",
            "foo.bar",
            1,
            "/checkpoint/dir",
            1,
            batch_size=20,
            batch_num_workers=32,
        )
        _, _, _, docling_model_path = _sdg_init(ctx, "full")
        assert docling_model_path == "/mock/docling-models"


def test_sdg_init_docling_path_config_not_found(testdata_path):
    with patch.dict(os.environ):
        os.environ["XDG_DATA_HOME"] = str(testdata_path.joinpath("nonexistent_dir"))
        ctx = _context_init(
            None,
            "mixtral",
            "foo.bar",
            1,
            "/checkpoint/dir",
            1,
            batch_size=20,
            batch_num_workers=32,
        )
        _, _, _, docling_model_path = _sdg_init(ctx, "full")
        assert docling_model_path is None
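Review note: for test_sdg_init_docling_path_config_found to pass, the mocked XDG data home has to contain the models/config.yaml that _sdg_init probes, so the new fixture shown below presumably lives at tests/testdata/mock_xdg_data_dir/instructlab/sdg/models/config.yaml.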
@@ -0,0 +1,4 @@
models:
  - path: /mock/docling-models
    source: https://huggingface.co/ds4sd/docling-models
    revision: main