-
Notifications
You must be signed in to change notification settings - Fork 188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor HF loader and add poolingMethod #954
base: mainline
Are you sure you want to change the base?
Conversation
…into li/add-pooling-hf
from marqo.s2_inference.types import * | ||
from marqo.core.inference.image_download import (_is_image, format_and_load_CLIP_images, | ||
format_and_load_CLIP_image) | ||
from marqo.core.inference.inference_models.abstract_embedding_model import AbstractEmbeddingModel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should be consistent when renaming classes/directories. I notice there's a new directory called core/inference/inference_models
. Maybe it should be core/inference/embedding_models
to keep consistency if we're referring to the same objects.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This goes for any other reference to inference models vs. embeddings models
|
||
|
||
class PoolingMethod(str, Enum): | ||
Mean = "mean" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add max and attention pooling as options as well? This can be a future feature.
return PoolingMethod.Mean | ||
|
||
if not isinstance(content, dict): | ||
logger.warn(f"Could not infer pooling method from the model {name}. Defaulting to mean pooling.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code snippet:
logger.warn(f"Could not infer pooling method from the model {name}. Defaulting to mean pooling.")
return PoolingMethod.Mean
Is repeated a lot. It could be split into a function, or put at the bottom of this function and triggered with a boolean flag
CLS = "cls" | ||
|
||
|
||
class HuggingFaceModelProperties(MarqoBaseModel): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's probably worth making a ModelProperties
class and having this class subclass from it. We may need an OpenClipModelProperties
in the future.
if not (self.model_properties.name or self.model_properties.url or self.model_properties.model_location): | ||
raise InvalidModelPropertiesError( | ||
f"Invalid model properties for the 'hf' model. " | ||
f"You do not have the necessary information to load the model. " | ||
f"Check {marqo_docs.bring_your_own_model()} for more information." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit. This logic is covered in the next section, can be removed. You also have a validator in the HuggingFaceModelProperties class to ensure this.
self._load_necessary_components() | ||
self._check_loaded_components() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to separate these two method?
sentence = [sentence] | ||
|
||
if self._model is None: | ||
self.load() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need concurrency control here?
@staticmethod | ||
def extract_huggingface_archive(path: str) -> str: | ||
''' | ||
This function takes the path as input. The path can must be a string that can be: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit. The path can must be a string that can be
-> The path is a string that can be
with tarfile.open(path, 'r') as tar_ref: | ||
tar_ref.extractall(new_dir) | ||
# return the path to the new directory | ||
return new_dir |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to keep the extracted files after the model is loaded?
@staticmethod | ||
def _average_pool_func(model_output, attention_mask): | ||
"""A pooling function that averages the hidden states of the model.""" | ||
last_hidden = model_output.last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0) | ||
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] | ||
|
||
@staticmethod | ||
def _cls_pool_func(model_output, attention_mask=None): | ||
"""A pooling function that extracts the CLS token from the model.""" | ||
return model_output[0][:, 0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these pooling methods common across models? Will we support more pooling method in the future? If so, consider extract them out to a class hierarchy? (does not need to change now)
try: | ||
file_path = hf_hub_download(repo_id, file_name, cache_dir=ModelCache.hf_cache_path) | ||
except HfHubHTTPError: | ||
logger.warn(f"Could not infer pooling method from the model {name}. Defaulting to mean pooling.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit. logger.warning. (logger.warn is deprecated)
class TestHuggingFaceModel(unittest.TestCase): | ||
"""Test initializing the HuggingFaceModel with valid properties.""" | ||
|
||
E5_BASE_V2_MODEL_EMBEDDINGS = np.squeeze( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit. consider having these fixed embeddings in a separate json file, it's easier to read and maintain.
What kind of change does this PR introduce? (Bug fix, feature, docs update, ...)
feature
What is the current behavior? (You can also link to an open issue here)
poolingMethod
" to the model properties of the HF loader so users can specify the poolingMethodno
running
Related Python client changes (link commit/PR here)
Related documentation changes (link commit/PR here)
Other information:
Please check if the PR fulfills these requirements