Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor HF loader and add poolingMethod #954

Open
wants to merge 80 commits into
base: mainline
Choose a base branch
from

Conversation

wanliAlex
Copy link
Collaborator

@wanliAlex wanliAlex commented Sep 5, 2024

  • What kind of change does this PR introduce? (Bug fix, feature, docs update, ...)
    feature

  • What is the current behavior? (You can also link to an open issue here)

  • The HF loader class is outdated.
  • We can't specify the pooling method
  • Some document links are outdated
  • What is the new behavior (if this is a feature change)?
  • We refactor the HF class code
  • Add a new field "poolingMethod" to the model properties of the HF loader so users can specify the poolingMethod
  • Fix some documents links
  • Does this PR introduce a breaking change? (What changes might users need to make in their application due to this PR?)

no

  • Have unit tests been run against this PR? (Has there also been any additional testing?)

running

  • Related Python client changes (link commit/PR here)

  • Related documentation changes (link commit/PR here)

  • Other information:

  • Please check if the PR fulfills these requirements

  • The commit message follows our guidelines
  • Tests for the changes have been added (for bug fixes/features)
  • Docs have been added / updated (for bug fixes / features)

farshidz
farshidz previously approved these changes Oct 22, 2024
from marqo.s2_inference.types import *
from marqo.core.inference.image_download import (_is_image, format_and_load_CLIP_images,
format_and_load_CLIP_image)
from marqo.core.inference.inference_models.abstract_embedding_model import AbstractEmbeddingModel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should be consistent when renaming classes/directories. I notice there's a new directory called core/inference/inference_models. Maybe it should be core/inference/embedding_models to keep consistency if we're referring to the same objects.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This goes for any other reference to inference models vs. embeddings models



class PoolingMethod(str, Enum):
Mean = "mean"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add max and attention pooling as options as well? This can be a future feature.

return PoolingMethod.Mean

if not isinstance(content, dict):
logger.warn(f"Could not infer pooling method from the model {name}. Defaulting to mean pooling.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code snippet:

logger.warn(f"Could not infer pooling method from the model {name}. Defaulting to mean pooling.")
            return PoolingMethod.Mean

Is repeated a lot. It could be split into a function, or put at the bottom of this function and triggered with a boolean flag

CLS = "cls"


class HuggingFaceModelProperties(MarqoBaseModel):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably worth making a ModelProperties class and having this class subclass from it. We may need an OpenClipModelProperties in the future.

Comment on lines +57 to +62
if not (self.model_properties.name or self.model_properties.url or self.model_properties.model_location):
raise InvalidModelPropertiesError(
f"Invalid model properties for the 'hf' model. "
f"You do not have the necessary information to load the model. "
f"Check {marqo_docs.bring_your_own_model()} for more information."
)
Copy link
Collaborator

@papa99do papa99do Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit. This logic is covered in the next section, can be removed. You also have a validator in the HuggingFaceModelProperties class to ensure this.

Comment on lines 34 to 35
self._load_necessary_components()
self._check_loaded_components()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to separate these two method?

sentence = [sentence]

if self._model is None:
self.load()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need concurrency control here?

@staticmethod
def extract_huggingface_archive(path: str) -> str:
'''
This function takes the path as input. The path can must be a string that can be:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit. The path can must be a string that can be -> The path is a string that can be

with tarfile.open(path, 'r') as tar_ref:
tar_ref.extractall(new_dir)
# return the path to the new directory
return new_dir
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to keep the extracted files after the model is loaded?

Comment on lines +182 to +191
@staticmethod
def _average_pool_func(model_output, attention_mask):
"""A pooling function that averages the hidden states of the model."""
last_hidden = model_output.last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

@staticmethod
def _cls_pool_func(model_output, attention_mask=None):
"""A pooling function that extracts the CLS token from the model."""
return model_output[0][:, 0]
Copy link
Collaborator

@papa99do papa99do Oct 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these pooling methods common across models? Will we support more pooling method in the future? If so, consider extract them out to a class hierarchy? (does not need to change now)

try:
file_path = hf_hub_download(repo_id, file_name, cache_dir=ModelCache.hf_cache_path)
except HfHubHTTPError:
logger.warn(f"Could not infer pooling method from the model {name}. Defaulting to mean pooling.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit. logger.warning. (logger.warn is deprecated)

class TestHuggingFaceModel(unittest.TestCase):
"""Test initializing the HuggingFaceModel with valid properties."""

E5_BASE_V2_MODEL_EMBEDDINGS = np.squeeze(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit. consider having these fixed embeddings in a separate json file, it's easier to read and maintain.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants