Skip to content

Commit

Permalink
Remove timm from being a hard dependency (huggingface#414)
Browse files Browse the repository at this point in the history
* Fix ov timm model loading

* increase tolerance

* change timm from hard to soft dependency

* add missing import
  • Loading branch information
echarlaix authored Aug 28, 2023
1 parent 3497d61 commit f39a84d
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 27 deletions.
12 changes: 10 additions & 2 deletions optimum/intel/openvino/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@

from optimum.exporters import TasksManager

from ..utils.import_utils import is_timm_available
from .modeling_base import OVBaseModel
from .modeling_timm import TimmConfig, TimmForImageClassification, TimmOnnxConfig, is_timm_ov_dir
from .utils import _is_timm_ov_dir


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -532,8 +533,15 @@ def from_pretrained(
**kwargs,
):
# Fix the mismatch between timm_config and huggingface_config
local_timm_model = is_timm_ov_dir(model_id)
local_timm_model = _is_timm_ov_dir(model_id)
if local_timm_model or (not os.path.isdir(model_id) and model_info(model_id).library_name == "timm"):
if not is_timm_available():
raise ImportError(
"To load a timm model, timm needs to be installed. Please install it with `pip install timm`."
)

from .modeling_timm import TimmConfig, TimmForImageClassification, TimmOnnxConfig

config = TimmConfig.from_pretrained(model_id, **kwargs)
# If locally saved timm model, directly load
if local_timm_model:
Expand Down
27 changes: 3 additions & 24 deletions optimum/intel/openvino/modeling_timm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import json
import os
from collections import OrderedDict
from glob import glob
from typing import Dict, List, Optional, Union

import numpy as np
import timm
import torch
from huggingface_hub import model_info
from packaging import version
from timm.layers.config import set_fused_attn
from timm.models._hub import load_model_config_from_hf
Expand All @@ -31,28 +28,10 @@
from optimum.exporters.onnx.config import VisionOnnxConfig
from optimum.utils import NormalizedVisionConfig


set_fused_attn(False, False)
from .utils import _is_timm_ov_dir


def is_timm_ov_dir(model_dir):
config_file = None
has_xml = False
has_bin = False
if os.path.isdir(model_dir):
for filename in glob(os.path.join(model_dir, "*")):
if filename.endswith(".xml"):
has_xml = True
if filename.endswith(".bin"):
has_bin = True
if filename.endswith("config.json"):
config_file = filename
if config_file and has_xml and has_bin:
with open(config_file) as conf:
hf_hub_id = json.load(conf).get("hf_hub_id", None)
if hf_hub_id and model_info(hf_hub_id).library_name == "timm":
return True
return False
set_fused_attn(False, False)


class TimmConfig(PretrainedConfig):
Expand All @@ -69,7 +48,7 @@ def from_pretrained(
revision: str = "main",
**kwargs,
) -> "PretrainedConfig":
if is_timm_ov_dir(pretrained_model_name_or_path):
if _is_timm_ov_dir(pretrained_model_name_or_path):
config_path = os.path.join(pretrained_model_name_or_path, "config.json")
return cls.from_json_file(config_path)

Expand Down
25 changes: 25 additions & 0 deletions optimum/intel/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
# limitations under the License.


import json
import os
from glob import glob

import numpy as np
from huggingface_hub import model_info
from openvino.runtime import Type
from transformers.onnx.utils import ParameterFormat, compute_serialized_parameters_size

Expand Down Expand Up @@ -95,3 +100,23 @@ def use_external_data_format(num_parameters: int) -> bool:
"""

return compute_serialized_parameters_size(num_parameters, ParameterFormat.Float) >= EXTERNAL_DATA_FORMAT_SIZE_LIMIT


def _is_timm_ov_dir(model_dir):
config_file = None
has_xml = False
has_bin = False
if os.path.isdir(model_dir):
for filename in glob(os.path.join(model_dir, "*")):
if filename.endswith(".xml"):
has_xml = True
if filename.endswith(".bin"):
has_bin = True
if filename.endswith("config.json"):
config_file = filename
if config_file and has_xml and has_bin:
with open(config_file) as conf:
hf_hub_id = json.load(conf).get("hf_hub_id", None)
if hf_hub_id and model_info(hf_hub_id).library_name == "timm":
return True
return False
14 changes: 14 additions & 0 deletions optimum/intel/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
except importlib_metadata.PackageNotFoundError:
_diffusers_available = False


_safetensors_version = "N/A"
_safetensors_available = importlib.util.find_spec("safetensors") is not None
if _safetensors_available:
Expand All @@ -100,6 +101,15 @@
_safetensors_available = False


_timm_available = importlib.util.find_spec("timm") is not None
_timm_version = "N/A"
if _timm_available:
try:
_timm_version = importlib_metadata.version("timm")
except importlib_metadata.PackageNotFoundError:
_timm_available = False


def is_transformers_available():
return _transformers_available

Expand Down Expand Up @@ -128,6 +138,10 @@ def is_safetensors_available():
return _safetensors_available


def is_timm_available():
return _timm_available


# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
"""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"datasets>=1.4.0",
"sentencepiece",
"scipy",
"timm",
"accelerate", # transformers 4.29 require accelerate for PyTorch
]

Expand All @@ -31,6 +30,7 @@
"sacremoses",
"torchaudio",
"rjieba",
"timm",
]

QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"]
Expand Down

0 comments on commit f39a84d

Please sign in to comment.