Make Jepa loader more flexible #945
Conversation
src/fairseq2/models/jepa/loader.py
Outdated
if "encoder" not in state_dict: | ||
raise ValueError(f"`encoder` not found in state dict (available key: {state_dict.keys()})") | ||
|
||
return state_dict["encoder"] |
Although I understood the PR description, I am not sure I understand the change here. Is there any reason for not handling this check in convert_jepa_checkpoint? I mean, instead of

checkpoint = checkpoint["encoder"]

having:

checkpoint = checkpoint.get("encoder")
if checkpoint is None:
    raise ValueError(...)

What is the benefit of having this check in a tensor_loader?
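Spelled out, that alternative would sit inside the converter roughly like this (an illustrative sketch only; the error message wording and the JepaConfig import path are assumptions on my part):

from typing import Any

from fairseq2.models.jepa.factory import JepaConfig  # import path is an assumption


def convert_jepa_checkpoint(
    checkpoint: dict[str, Any], config: JepaConfig
) -> dict[str, Any]:
    # Select the encoder weights here rather than in a custom TensorLoader.
    encoder_checkpoint = checkpoint.get("encoder")
    if encoder_checkpoint is None:
        raise ValueError(
            f"`encoder` not found in checkpoint (available keys: {list(checkpoint.keys())})"
        )
    # ... the existing conversion of the encoder parameters would follow here.
    ...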
I had two thoughts when making this change, both of them opinionated:
- We should narrow the scope of the convert_jepa_checkpoint function to converting only the parameters related to the JEPA model; how we get to these parameters is handled separately (in a TensorLoader).
- With this, we do not have to list all possible checkpoint keys ("encoder", "target_encoder") and define their priority in convert_jepa_checkpoint. This lets us inject pretrained encoders from other "exotic" checkpoints (for example, jepa-llava, where the encoder is stored in vision_tower); see the sketch after this comment.
The drawback of this approach, though, is that we have to write a custom TensorLoader for each checkpoint, so it is a matter of opinion here...
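For concreteness, a custom loader along those lines could look like the sketch below (my own rough illustration, not the actual fairseq2 TensorLoader API; the function name and the use of plain torch.load are assumptions):

from typing import Any

import torch


def load_jepa_llava_encoder_tensors(path: str) -> dict[str, Any]:
    # Load the full jepa-llava checkpoint, then hand only the encoder sub-dict
    # to the standard JEPA converter, keeping key selection out of the converter.
    checkpoint = torch.load(path, map_location="cpu")
    if "vision_tower" not in checkpoint:
        raise ValueError(
            f"`vision_tower` not found in checkpoint (available keys: {list(checkpoint.keys())})"
        )
    return checkpoint["vision_tower"]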
How about doing something like:
# Handles different variants of JEPA checkpoints and delegates the actual conversion
# to the standard converter.
def convert_jepa_checkpoint(
    checkpoint: dict[str, Any], config: JepaConfig
) -> dict[str, Any]:
    if "vision_tower" in checkpoint:
        return convert_jepa_encoder_checkpoint(checkpoint["vision_tower"], config)

    if "target_encoder" in checkpoint:
        return convert_jepa_encoder_checkpoint(checkpoint["target_encoder"], config)

    if "encoder" in checkpoint:
        return convert_jepa_encoder_checkpoint(checkpoint["encoder"], config)

    raise ValueError("`encoder` not found in checkpoint.")


def convert_jepa_encoder_checkpoint(
    checkpoint: dict[str, Any], config: JepaConfig
) -> dict[str, Any]:
    # Contains the current implementation.
    ...
My worry with the TensorLoader approach is that we leak state-dict handling logic into tensor loading. Essentially, we want to "pre-process" the checkpoint before passing it to the converter, so a wrapper function might do the job as well. Let me know what you think.
Please make sure to run mypy/flake8/isort/black and fix any formatting errors before merging. Otherwise, looks good to me. Thanks!
What does this PR do? Please describe:
In the JEPA frozen-evaluation scenario, the pretrained encoder is loaded from an external pretrained checkpoint. Because the pretrained checkpoints can be updated separately from the attentive-pooling checkpoints, we need to make sure that the weights used together with the corresponding attentive pooling are saved in a dedicated place and will not be overridden in subsequent training epochs.
In JEPA, this is done via a special checkpoint key, "target_encoder", which is saved alongside "encoder" in a pretrained checkpoint and is reserved for evaluation reproducibility. References: JEPA evals configs (example here, loader here).
This PR updates the loader to allow loading a JEPA model for different purposes (training, frozen evaluation).
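To make the two loading modes concrete, the difference boils down to which top-level key of the pretrained checkpoint is used (a minimal sketch; load_jepa_encoder_state is a hypothetical helper, not part of this PR, and plain torch.load is assumed):

import torch


def load_jepa_encoder_state(path: str, frozen_eval: bool = False) -> dict:
    checkpoint = torch.load(path, map_location="cpu")
    # For frozen evaluation, prefer the snapshot reserved for reproducibility.
    key = "target_encoder" if frozen_eval else "encoder"
    if key not in checkpoint:
        raise ValueError(
            f"`{key}` not found in checkpoint (available keys: {list(checkpoint.keys())})"
        )
    return checkpoint[key]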
** Fixes #{issue number}
Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.
Check list: