From 2a00cb3b33be5e38060697f205118ac39cecbc3d Mon Sep 17 00:00:00 2001 From: Aakanksha Duggal Date: Thu, 14 Nov 2024 11:05:50 -0500 Subject: [PATCH] Update the way docling_model_path is passed to generate_data Signed-off-by: Aakanksha Duggal --- src/instructlab/sdg/generate_data.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py index bc5d81b2..583da5d3 100644 --- a/src/instructlab/sdg/generate_data.py +++ b/src/instructlab/sdg/generate_data.py @@ -228,17 +228,17 @@ def _sdg_init(ctx, pipeline): sdg_models_path = os.path.join(d, "models") break - if sdg_models_path is not None: - try: - with open( - os.path.join(sdg_models_path, "config.yaml"), "r", encoding="utf-8" - ) as file: - config = yaml.safe_load(file) - docling_model_path = config["models"][0]["path"] - except (FileNotFoundError, NotADirectoryError, PermissionError) as e: - logger.warning( - f"unable to read docling models path from config.yaml {e}" - ) + if sdg_models_path is not None: + try: + with open( + os.path.join(sdg_models_path, "config.yaml"), "r", encoding="utf-8" + ) as file: + config = yaml.safe_load(file) + docling_model_path = config["models"][0]["path"] + except (FileNotFoundError, NotADirectoryError, PermissionError) as e: + logger.warning( + f"unable to read docling models path from config.yaml {e}" + ) for d in data_dirs: pipeline_path = os.path.join(d, "pipelines", pipeline) @@ -271,6 +271,7 @@ def load_pipeline(yaml_basename): load_pipeline("knowledge.yaml"), load_pipeline("freeform_skills.yaml"), load_pipeline("grounded_skills.yaml"), + docling_model_path ) @@ -315,7 +316,6 @@ def generate_data( batch_size: Optional[int] = None, checkpoint_dir: Optional[str] = None, max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS, - docling_model_path: Optional[str] = None, ) -> None: """Generate data for training and testing a model. @@ -384,7 +384,7 @@ def generate_data( max_num_tokens=max_num_tokens, ) - knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init( + knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe, docling_model_path = _sdg_init( ctx, pipeline )