From a80a3f7066688d97f0a6a7941a3edfc90a941756 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Thu, 5 Dec 2024 12:39:08 -0500 Subject: [PATCH] Add a new `instructlab.sdg.taxonomy_to_samples` API Take a first pass at separating out the data preprocessing steps from generation by adding a new top-level API (and temporary CLI) to invoke preprocessing but not generation. Signed-off-by: Ben Browning --- src/instructlab/sdg/__init__.py | 2 + .../sdg/cli/taxonomy_to_samples.py | 82 +++++++++ src/instructlab/sdg/generate_data.py | 156 ++++++++---------- src/instructlab/sdg/taxonomy.py | 134 +++++++++++++++ src/instructlab/sdg/utils/json.py | 6 + src/instructlab/sdg/utils/logging.py | 22 +++ src/instructlab/sdg/utils/taxonomy.py | 28 +++- tests/test_generate_data.py | 41 +---- tests/test_taxonomy.py | 19 ++- 9 files changed, 356 insertions(+), 134 deletions(-) create mode 100644 src/instructlab/sdg/cli/taxonomy_to_samples.py create mode 100644 src/instructlab/sdg/taxonomy.py create mode 100644 src/instructlab/sdg/utils/logging.py diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py index 490df8e4..a3576662 100644 --- a/src/instructlab/sdg/__init__.py +++ b/src/instructlab/sdg/__init__.py @@ -29,6 +29,7 @@ "FULL_PIPELINES_PACKAGE", "SIMPLE_PIPELINES_PACKAGE", "generate_data", + "taxonomy_to_samples", ) # Local @@ -61,5 +62,6 @@ PipelineContext, ) from .registry import BlockRegistry, PromptRegistry +from .taxonomy import taxonomy_to_samples from .utils import GenerateException from .utils.taxonomy import TaxonomyReadingException diff --git a/src/instructlab/sdg/cli/taxonomy_to_samples.py b/src/instructlab/sdg/cli/taxonomy_to_samples.py new file mode 100644 index 00000000..112f764b --- /dev/null +++ b/src/instructlab/sdg/cli/taxonomy_to_samples.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import os + +# First Party +from instructlab.sdg.taxonomy import ( + DEFAULT_CHUNK_WORD_COUNT, + DEFAULT_SERVER_CTX_SIZE, + DEFAULT_TAXONOMY_BASE, + taxonomy_to_samples, +) +from instructlab.sdg.utils.logging import setup_logger + +if __name__ == "__main__": + # Standard + import argparse + + parser = argparse.ArgumentParser( + description="Turn a taxonomy into json samples suitable for use as input to data generate pipelines" + ) + + # Required args + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Directory to write the processed dataset samples into", + ) + parser.add_argument( + "--taxonomy-path", + type=str, + required=True, + help="Path to your InstructLab taxonomy", + ) + + # Optional args + parser.add_argument( + "--chunk-word-count", + type=int, + default=DEFAULT_CHUNK_WORD_COUNT, + help="Number of words per document chunk", + ) + parser.add_argument( + "--log-level", + type=str, + default=os.getenv("LOG_LEVEL", "INFO"), + help="Logging level", + ) + parser.add_argument( + "--server-ctx-size", + type=int, + default=DEFAULT_SERVER_CTX_SIZE, + help="The maximum number of tokens the inference server can handle.", + ) + parser.add_argument( + "--taxonomy-base", + type=str, + default=DEFAULT_TAXONOMY_BASE, + help="Taxonomy based used to determine what has changed - defaults to 'empty' which means consider all the taxonomy files as changed and process all of them", + ) + parser.add_argument( + "--yaml-rules", + type=str, + default=None, + help="Path to custom rules file for YAML linting", + ) + + args = parser.parse_args() + setup_logger(args.log_level) + taxonomy_to_samples( + args.taxonomy_path, + args.output_dir, + chunk_word_count=args.chunk_word_count, + server_ctx_size=args.server_ctx_size, + taxonomy_base=args.taxonomy_base, + yaml_rules=args.yaml_rules, + ) + +""" +python -m instructlab.sdg.cli.taxonomy_to_samples --taxonomy-path /path/to/my/taxonomy --output-dir /path/to/my/output +""" diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index 533db868..31643457 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -13,9 +13,9 @@ # Third Party # instructlab - All of these need to go away (other than sdg) - issue #6 +from datasets import Dataset from xdg_base_dirs import xdg_data_dirs, xdg_data_home import openai -import yaml # First Party from instructlab.sdg.blocks.llmblock import DEFAULT_MAX_NUM_TOKENS @@ -27,12 +27,9 @@ Pipeline, PipelineContext, ) +from instructlab.sdg.taxonomy import taxonomy_to_samples from instructlab.sdg.utils import GenerateException, models -from instructlab.sdg.utils.json import jldump -from instructlab.sdg.utils.taxonomy import ( - leaf_node_to_samples, - read_taxonomy_leaf_nodes, -) +from instructlab.sdg.utils.json import jldump, jlload logger = logging.getLogger(__name__) @@ -115,20 +112,21 @@ def _gen_train_data( def _knowledge_seed_example_to_test_data(seed_example, system_prompt): res = [] - for qna in seed_example["questions_and_answers"]: - user = qna["question"] + "\n" + seed_example["context"] + for i in range(3): + idx = i + 1 + user = seed_example[f"icl_query_{idx}"] + "\n" + seed_example["icl_document"] res.append( { "system": system_prompt, "user": _unescape(user), - "assistant": _unescape(qna["answer"]), + "assistant": _unescape(seed_example[f"icl_response_{idx}"]), } ) return res def _gen_test_data( - leaf_nodes, + seed_examples, output_file_test, system_prompt, ): @@ -137,30 +135,29 @@ def _gen_test_data( in instructlab/instructlab. """ test_data = [] - for _, leaf_node in leaf_nodes.items(): - for seed_example in leaf_node: - if "questions_and_answers" in seed_example: - test_data.extend( - _knowledge_seed_example_to_test_data(seed_example, system_prompt) - ) - continue + for seed_example in seed_examples: + if "icl_query_1" in seed_example: + test_data.extend( + _knowledge_seed_example_to_test_data(seed_example, system_prompt) + ) + continue - # skill seed example + # skill seed example - user = seed_example["instruction"] # question + user = seed_example["seed_question"] # question - if len(seed_example["input"]) > 0: - user += "\n" + seed_example["input"] # context + if seed_example["leaf_node_type"] == "grounded_skill": + user += "\n" + seed_example["seed_context"] # context - test_data.append( - { - "system": system_prompt, - "user": _unescape(user), - "assistant": _unescape(seed_example["output"]), # answer - } - ) + test_data.append( + { + "system": system_prompt, + "user": _unescape(user), + "assistant": _unescape(seed_example["seed_response"]), # answer + } + ) - jldump(test_data, output_file_test) + jldump(test_data, output_file_test) def _check_pipeline_dir(pipeline): @@ -208,23 +205,6 @@ def _sdg_init(ctx, pipeline): data_dirs = [os.path.join(xdg_data_home(), "instructlab", "sdg")] data_dirs.extend(os.path.join(dir, "instructlab", "sdg") for dir in xdg_data_dirs()) - docling_model_path = None - sdg_models_path = docling_model_path - for d in data_dirs: - if os.path.exists(os.path.join(d, "models")): - sdg_models_path = os.path.join(d, "models") - break - - if sdg_models_path is not None: - try: - with open( - os.path.join(sdg_models_path, "config.yaml"), "r", encoding="utf-8" - ) as file: - config = yaml.safe_load(file) - docling_model_path = config["models"][0]["path"] - except (FileNotFoundError, NotADirectoryError, PermissionError) as e: - logger.warning(f"unable to read docling models path from config.yaml {e}") - for d in data_dirs: pipeline_path = os.path.join(d, "pipelines", pipeline) if os.path.exists(pipeline_path): @@ -256,7 +236,6 @@ def load_pipeline(yaml_basename): load_pipeline("knowledge.yaml"), load_pipeline("freeform_skills.yaml"), load_pipeline("grounded_skills.yaml"), - docling_model_path, ) @@ -326,28 +305,32 @@ def generate_data( if batch_size is None: batch_size = 0 - if not os.path.exists(output_dir): - os.mkdir(output_dir) - - if not (taxonomy and os.path.exists(taxonomy)): - raise GenerateException(f"Error: taxonomy ({taxonomy}) does not exist.") - + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True) date_suffix = datetime.now().replace(microsecond=0).isoformat().replace(":", "_") - document_output_dir = Path(output_dir) / f"documents-{date_suffix}" - - leaf_nodes = read_taxonomy_leaf_nodes( - taxonomy, taxonomy_base, yaml_rules, document_output_dir + preprocessed_output_dir = output_dir.joinpath(f"preprocessed_{date_suffix}") + + # This writes samples to disk in our output_dir and returns the + # list of files created + sample_files = taxonomy_to_samples( + taxonomy, + preprocessed_output_dir, + chunk_word_count=chunk_word_count, + server_ctx_size=server_ctx_size, + taxonomy_base=taxonomy_base, + yaml_rules=yaml_rules, ) - if not leaf_nodes: - raise GenerateException("Error: No new leaf nodes found in the taxonomy.") name = Path(model_name).stem # Just in case it is a file path output_file_messages = f"messages_{name}_{date_suffix}.jsonl" output_file_test = f"test_{name}_{date_suffix}.jsonl" output_file_train = f"train_{name}_{date_suffix}.jsonl" + all_samples = [] + for sample_file in sample_files: + all_samples.extend(jlload(sample_file)) _gen_test_data( - leaf_nodes, + all_samples, os.path.join(output_dir, output_file_test), system_prompt, ) @@ -368,8 +351,8 @@ def generate_data( max_num_tokens=max_num_tokens, ) - knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe, docling_model_path = ( - _sdg_init(ctx, pipeline) + knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init( + ctx, pipeline ) # Make sure checkpointing is disabled (we don't want this pipeline to load checkpoints from the main pipeline) @@ -390,39 +373,34 @@ def generate_data( ) generated_data = [] - empty_sdg_leaf_nodes = [] - for leaf_node in leaf_nodes.values(): - is_knowledge = False - leaf_node_path = leaf_node[0]["taxonomy_path"].replace("->", "_") - samples = leaf_node_to_samples( - leaf_node, - taxonomy, - server_ctx_size, - chunk_word_count, - document_output_dir, - model_name, - docling_model_path=docling_model_path, - ) - + empty_input_sample_files = [] + for sample_file in sample_files: + logger.debug("Generating data from input sample file: %s", sample_file) + samples = jlload(sample_file) if not samples: - raise GenerateException("Error: No samples found in leaf node.") - - if "document" in samples.column_names: + raise GenerateException( + "Error: No samples found in input file {sample_file}" + ) + # For now we assume every sample in the file is the same type + first_sample = samples[0] + leaf_node_path = first_sample["leaf_node_path"] + leaf_node_type = first_sample["leaf_node_type"] + is_knowledge = False + if leaf_node_type == "knowledge": pipe = knowledge_pipe is_knowledge = True - - elif "seed_context" in samples.column_names: + elif leaf_node_type == "grounded_skill": pipe = grounded_skills_pipe - else: pipe = freeform_skills_pipe - logger.debug("Samples: %s", samples) + samples_ds = Dataset.from_list(samples) + logger.debug("Samples: %s", samples_ds) - new_generated_data = pipe.generate(samples, leaf_node_path) + new_generated_data = pipe.generate(samples_ds, leaf_node_path) if len(new_generated_data) == 0: - empty_sdg_leaf_nodes.append(leaf_node_path) - logger.warning("Empty dataset for qna node: %s", leaf_node_path) + empty_input_sample_files.append(sample_file) + logger.warning("Empty generated dataset for sample file: %s", sample_file) continue generated_data.append(new_generated_data) @@ -457,9 +435,9 @@ def generate_data( generate_duration = time.time() - generate_start logger.info(f"Generation took {generate_duration:.2f}s") - if len(empty_sdg_leaf_nodes) > 0: + if len(empty_input_sample_files) > 0: logger.warning( - "Leaf nodes with empty sdg output: {}".format( - " ".join(empty_sdg_leaf_nodes) + "Input sample files with empty sdg output: {}".format( + " ".join(empty_input_sample_files) ) ) diff --git a/src/instructlab/sdg/taxonomy.py b/src/instructlab/sdg/taxonomy.py new file mode 100644 index 00000000..bc017fa5 --- /dev/null +++ b/src/instructlab/sdg/taxonomy.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=duplicate-code + +# Standard +from pathlib import Path +from typing import Optional +import logging +import os + +# Third Party +from xdg_base_dirs import xdg_data_dirs, xdg_data_home +import yaml + +# First Party +from instructlab.sdg.utils import GenerateException +from instructlab.sdg.utils.json import jldump +from instructlab.sdg.utils.taxonomy import ( + leaf_node_to_samples, + read_taxonomy_leaf_nodes, +) + +logger = logging.getLogger(__name__) + +DEFAULT_CHUNK_WORD_COUNT = 1000 +DEFAULT_TAXONOMY_BASE = "empty" +DEFAULT_SERVER_CTX_SIZE = 4096 + + +def _locate_docling_models(): + # Search for the models in User and Site data directories + data_dirs = [os.path.join(xdg_data_home(), "instructlab", "sdg")] + data_dirs.extend(os.path.join(dir, "instructlab", "sdg") for dir in xdg_data_dirs()) + + docling_model_path = None + sdg_models_path = docling_model_path + for d in data_dirs: + if os.path.exists(os.path.join(d, "models")): + sdg_models_path = os.path.join(d, "models") + break + + if sdg_models_path is not None: + try: + with open( + os.path.join(sdg_models_path, "config.yaml"), "r", encoding="utf-8" + ) as file: + config = yaml.safe_load(file) + docling_model_path = config["models"][0]["path"] + except (FileNotFoundError, NotADirectoryError, PermissionError) as e: + logger.warning(f"unable to read docling models path from config.yaml {e}") + + return docling_model_path + + +def taxonomy_to_samples( + taxonomy_path, + output_dir, + chunk_word_count=DEFAULT_CHUNK_WORD_COUNT, # TODO: Remove chunk_word_count param + server_ctx_size=DEFAULT_SERVER_CTX_SIZE, # TODO: Remove server_ctx_size param + taxonomy_base=DEFAULT_TAXONOMY_BASE, + teacher_model_path: Optional[str] = None, + yaml_rules: Optional[str] = None, +): + """ + Preprocess a taxonomy into input samples suitable for use with + data generation pipelines. This does the following steps: + + - Determine changed leaf nodes in the taxonomy + - Retrieve knowledge documents for changed taxonomy leaf nodes + - Convert any non-markdown knowledge documents to markdown + - Write the Docling json and markdown outputs from this conversion to + disk for other processes to consume if needed. + - Chunk the converted knowledge documents to the desired chunk sizes. + - Turn the qna.yaml and knowledge documents into samples in the format + expected by the `simple` and `full` data generation pipelines shipped + in SDG. + - Write these samples to disk, with one file per taxonomy leaf node. + + Args: + taxonomy_path: The path to the taxonomy + output_dir: Where to write the samples create for use with data generation + chunk_word_count: The target number of words per document chunk + server_ctx_size: The maximum number of tokens the inference server used + during data generation can handle + taxonomy_base: Determines how we calculate what has changed. This should + be a git reference or the special value of 'empty' which + means assume the entire taxonomy has changed. + teacher_model_path: Path to the teacher model on disk, which we'll use to + load its tokenizer for use with document chunking. + yaml_rules: Path to a custom YAML rules file for YAML linting. + + Returns: + List[str]: The list of output sample files written to disk. + + """ + logging.info("Converting taxonomy to samples") + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True) + output_files = [] + + if not (taxonomy_path and os.path.exists(taxonomy_path)): + raise GenerateException(f"Error: taxonomy ({taxonomy_path}) does not exist.") + + document_output_dir = output_dir.joinpath("documents") + docling_model_path = _locate_docling_models() + + leaf_nodes = read_taxonomy_leaf_nodes( + taxonomy_path, taxonomy_base, yaml_rules, document_output_dir + ) + if not leaf_nodes: + raise GenerateException("Error: No new leaf nodes found in the taxonomy.") + + for leaf_node in leaf_nodes.values(): + leaf_node_path = leaf_node[0]["taxonomy_path"].replace("->", "_") + samples = leaf_node_to_samples( + leaf_node, + taxonomy_path, + server_ctx_size, + chunk_word_count, + document_output_dir, + teacher_model_path, + docling_model_path=docling_model_path, + ) + + if not samples: + raise GenerateException("Error: No samples found in leaf node.") + + logger.debug("Samples: %s", samples) + + output_file = output_dir.joinpath(f"{leaf_node_path}.jsonl") + jldump(samples, output_file) + output_files.append(str(output_file)) + + logger.info("Taxonomy converted to samples and written to %s", output_dir) + return output_files diff --git a/src/instructlab/sdg/utils/json.py b/src/instructlab/sdg/utils/json.py index 041d817b..1ec0b70a 100644 --- a/src/instructlab/sdg/utils/json.py +++ b/src/instructlab/sdg/utils/json.py @@ -60,3 +60,9 @@ def jldump(data: Iterable[Any], out: str | io.IOBase) -> None: for entry in data: json.dump(entry, outfile, ensure_ascii=False) outfile.write("\n") + + +def jlload(f, mode="r"): + """Load a .jsonl file into a list of dictionaries.""" + with _make_r_io_base(f, mode) as f_: + return [json.loads(l) for l in f_.read().splitlines()] diff --git a/src/instructlab/sdg/utils/logging.py b/src/instructlab/sdg/utils/logging.py new file mode 100644 index 00000000..c6236f49 --- /dev/null +++ b/src/instructlab/sdg/utils/logging.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import logging + +# Third Party +from rich.logging import RichHandler + + +def setup_logger(level="DEBUG"): + """ + Setup a logger - ONLY to be used when running CLI commands in + SDG directly. DO NOT call this from regular library code, and only + call it from __main__ entrypoints in the instructlab.sdg.cli + package + """ + logging.basicConfig( + level=level, + format="%(message)s", + datefmt="[%X]", + handlers=[RichHandler()], + ) diff --git a/src/instructlab/sdg/utils/taxonomy.py b/src/instructlab/sdg/utils/taxonomy.py index 24592fe1..c8c1faf6 100644 --- a/src/instructlab/sdg/utils/taxonomy.py +++ b/src/instructlab/sdg/utils/taxonomy.py @@ -400,6 +400,7 @@ def map_chunks_to_icls(chunks: List, leaf_node: Dict) -> Dataset: "icl_document": icl_.get("context", ""), "document_outline": icl_.get("document_outline", ""), "domain": domain, + "leaf_node_type": "knowledge", } qna_pairs = icl_.get("questions_and_answers", []) @@ -413,7 +414,7 @@ def map_chunks_to_icls(chunks: List, leaf_node: Dict) -> Dataset: chunked_dataset.append(record) - return Dataset.from_list(chunked_dataset) + return chunked_dataset def _knowledge_leaf_node_to_samples( @@ -447,12 +448,23 @@ def _skill_leaf_node_to_samples(leaf_node): for i in range(len(leaf_node)): samples.append({}) samples[-1]["task_description"] = leaf_node[i]["task_description"] + sample_type = "freeform_skill" if leaf_node[i].get("input"): + sample_type = "grounded_skill" samples[-1]["seed_context"] = leaf_node[i]["input"] samples[-1]["seed_question"] = leaf_node[i]["instruction"] samples[-1]["seed_response"] = leaf_node[i]["output"] + samples[-1]["leaf_node_type"] = sample_type - return Dataset.from_list(samples) + return samples + + +def _enrich_metadata(samples, leaf_node): + leaf_node_path = leaf_node[0]["taxonomy_path"].replace("->", "_") + for i, sample in enumerate(samples): + sample["leaf_node_path"] = leaf_node_path + samples[i] = sample + return samples def leaf_node_to_samples( @@ -464,10 +476,9 @@ def leaf_node_to_samples( model_name, docling_model_path=None, ): - if not leaf_node: - return [] - if leaf_node[0].get("documents"): - return _knowledge_leaf_node_to_samples( + samples = [] + if leaf_node and leaf_node[0].get("documents"): + samples = _knowledge_leaf_node_to_samples( leaf_node, taxonomy_path, server_ctx_size, @@ -476,4 +487,7 @@ def leaf_node_to_samples( model_name, docling_model_path, ) - return _skill_leaf_node_to_samples(leaf_node) + elif leaf_node: + samples = _skill_leaf_node_to_samples(leaf_node) + samples = _enrich_metadata(samples, leaf_node) + return Dataset.from_list(samples) diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py index c5415636..a38a76e5 100644 --- a/tests/test_generate_data.py +++ b/tests/test_generate_data.py @@ -85,7 +85,7 @@ def validate_messages_dataset(dataset_file_name, expected_samples): def validate_skill_leaf_node_dataset(dataset_file_name): ds = load_dataset("json", data_files=dataset_file_name, split="train") - assert len(ds.features) == 7 + assert len(ds.features) == 9 features = [ "task_description", "seed_context", @@ -93,6 +93,8 @@ def validate_skill_leaf_node_dataset(dataset_file_name): "seed_response", "output", "id", + "leaf_node_path", + "leaf_node_type", ] for feature in features: assert feature in ds.features @@ -512,7 +514,8 @@ def test_generate(self): ) mocked_logger.warning.assert_called() assert re.search( - "empty sdg output: knowledge_new", mocked_logger.warning.call_args.args[0] + "empty sdg output: .+knowledge_new.jsonl", + mocked_logger.warning.call_args.args[0], ) def teardown(self) -> None: @@ -559,37 +562,3 @@ def test_context_init_batch_size_optional(): batch_num_workers=32, ) assert ctx.batch_size == 20 - - -def test_sdg_init_docling_path_config_found(testdata_path): - with patch.dict(os.environ): - os.environ["XDG_DATA_HOME"] = str(testdata_path.joinpath("mock_xdg_data_dir")) - ctx = _context_init( - None, - "mixtral", - "foo.bar", - 1, - "/checkpoint/dir", - 1, - batch_size=20, - batch_num_workers=32, - ) - _, _, _, docling_model_path = _sdg_init(ctx, "full") - assert docling_model_path == "/mock/docling-models" - - -def test_sdg_init_docling_path_config_not_found(testdata_path): - with patch.dict(os.environ): - os.environ["XDG_DATA_HOME"] = str(testdata_path.joinpath("nonexistent_dir")) - ctx = _context_init( - None, - "mixtral", - "foo.bar", - 1, - "/checkpoint/dir", - 1, - batch_size=20, - batch_num_workers=32, - ) - _, _, _, docling_model_path = _sdg_init(ctx, "full") - assert docling_model_path is None diff --git a/tests/test_taxonomy.py b/tests/test_taxonomy.py index 0828e187..bedacaaf 100644 --- a/tests/test_taxonomy.py +++ b/tests/test_taxonomy.py @@ -2,14 +2,15 @@ # Standard from typing import Any, Dict, Union +from unittest.mock import patch import os -import pathlib # Third Party import pytest import yaml # First Party +from instructlab.sdg.taxonomy import _locate_docling_models from instructlab.sdg.utils import taxonomy TEST_SEED_EXAMPLE = "Can you help me debug this failing unit test?" @@ -22,7 +23,7 @@ def load_test_skills(skills_file_path) -> Union[Dict[str, Any], None]: return yaml.safe_load(skills_file) -class TestTaxonomy: +class TestUtilsTaxonomy: """Test taxonomy in instructlab.sdg.utils.taxonomy.""" @pytest.fixture(autouse=True) @@ -85,3 +86,17 @@ def test_read_taxonomy_leaf_nodes( ): seed_example_exists = True assert seed_example_exists is True + + +def test_locate_docling_models_config_found(testdata_path): + with patch.dict(os.environ): + os.environ["XDG_DATA_HOME"] = str(testdata_path.joinpath("mock_xdg_data_dir")) + docling_model_path = _locate_docling_models() + assert docling_model_path == "/mock/docling-models" + + +def test_locate_docling_models_config_not_found(testdata_path): + with patch.dict(os.environ): + os.environ["XDG_DATA_HOME"] = str(testdata_path.joinpath("nonexistent_dir")) + docling_model_path = _locate_docling_models() + assert docling_model_path is None