diff --git a/src/download_models.py b/src/download_models.py
index ba713e1..c30305f 100644
--- a/src/download_models.py
+++ b/src/download_models.py
@@ -1,6 +1,8 @@
 import math
+from fileinput import filename
 from os import makedirs
 from os.path import join, exists
+from pathlib import Path
 from urllib.request import urlretrieve
 
 from huggingface_hub import snapshot_download, hf_hub_download
@@ -33,23 +35,20 @@ def download_embedding_model():
     snapshot_download(repo_id="microsoft/layoutlm-base-uncased", local_dir=model_path, local_dir_use_symlinks=False)
 
 
+def download_from_hf_hub(path: Path):
+    if path.exists():
+        return
+
+    file_name = path.name
+    makedirs(path.parent, exist_ok=True)
+    repo_id = "HURIDOCS/pdf-document-layout-analysis"
+    hf_hub_download(repo_id=repo_id, filename=file_name, local_dir=path.parent, local_dir_use_symlinks=False)
+
+
 def download_lightgbm_models():
-    tokens_type_model_path = join(MODELS_PATH, "token_type_lightgbm.model")
-    paragraph_extraction_model_path = join(MODELS_PATH, "paragraph_extraction_lightgbm.model")
-    if not exists(tokens_type_model_path):
-        hf_hub_download(
-            repo_id="HURIDOCS/pdf-document-layout-analysis",
-            filename="token_type_lightgbm.model",
-            local_dir=str(MODELS_PATH),
-            local_dir_use_symlinks=False,
-        )
-    if not exists(paragraph_extraction_model_path):
-        hf_hub_download(
-            repo_id="HURIDOCS/pdf-document-layout-analysis",
-            filename="paragraph_extraction_lightgbm.model",
-            local_dir=str(MODELS_PATH),
-            local_dir_use_symlinks=False,
-        )
+    download_from_hf_hub(Path(MODELS_PATH, "token_type_lightgbm.model"))
+    download_from_hf_hub(Path(MODELS_PATH, "paragraph_extraction_lightgbm.model"))
+    download_from_hf_hub(Path(MODELS_PATH, "config.json"))
 
 
 def download_models(model_name: str):