Skip to content

Commit

Permalink
Update the way docling_model_path is passed to generate_data
Browse files Browse the repository at this point in the history
Signed-off-by: Aakanksha Duggal <[email protected]>
  • Loading branch information
aakankshaduggal committed Nov 14, 2024
1 parent 16c6f45 commit 2a00cb3
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,17 @@ def _sdg_init(ctx, pipeline):
sdg_models_path = os.path.join(d, "models")
break

if sdg_models_path is not None:
try:
with open(
os.path.join(sdg_models_path, "config.yaml"), "r", encoding="utf-8"
) as file:
config = yaml.safe_load(file)
docling_model_path = config["models"][0]["path"]
except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
logger.warning(
f"unable to read docling models path from config.yaml {e}"
)
if sdg_models_path is not None:
try:
with open(
os.path.join(sdg_models_path, "config.yaml"), "r", encoding="utf-8"
) as file:
config = yaml.safe_load(file)
docling_model_path = config["models"][0]["path"]
except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
logger.warning(
f"unable to read docling models path from config.yaml {e}"
)

for d in data_dirs:
pipeline_path = os.path.join(d, "pipelines", pipeline)
Expand Down Expand Up @@ -271,6 +271,7 @@ def load_pipeline(yaml_basename):
load_pipeline("knowledge.yaml"),
load_pipeline("freeform_skills.yaml"),
load_pipeline("grounded_skills.yaml"),
docling_model_path
)


Expand Down Expand Up @@ -315,7 +316,6 @@ def generate_data(
batch_size: Optional[int] = None,
checkpoint_dir: Optional[str] = None,
max_num_tokens: Optional[int] = DEFAULT_MAX_NUM_TOKENS,
docling_model_path: Optional[str] = None,
) -> None:
"""Generate data for training and testing a model.
Expand Down Expand Up @@ -384,7 +384,7 @@ def generate_data(
max_num_tokens=max_num_tokens,
)

knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init(
knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe, docling_model_path = _sdg_init(
ctx, pipeline
)

Expand Down

0 comments on commit 2a00cb3

Please sign in to comment.