
Make Jepa loader more flexible #945

Merged: 12 commits merged into main on Jan 3, 2025

Conversation

@antoine-tran (Contributor) commented Jan 2, 2025:

What does this PR do? Please describe:
In the JEPA frozen-evaluation scenario, the pretrained encoder is loaded from an external pretrained checkpoint. Because the pretrained checkpoints can be updated independently of the attentive-pooling checkpoints, we need to make sure that the weights used in combination with the corresponding attentive pooling are saved in a dedicated place and are not overridden in subsequent training epochs.

In JEPA, this is done via a special checkpoint key, "target_encoder", which is saved alongside "encoder" in a pretrained checkpoint and is reserved for evaluation reproducibility. References: JEPA evals configs (example here, loader here).
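
As a rough sketch of that mechanism (the helper name, checkpoint layout, and file path below are illustrative assumptions, not the actual fairseq2 code), a frozen-evaluation loader could prefer the reserved "target_encoder" weights and fall back to "encoder" otherwise:

from typing import Any

import torch

# Hypothetical helper for illustration only: prefer the reserved evaluation copy
# of the encoder weights during frozen evaluation, otherwise fall back to the
# regular encoder weights.
def select_encoder_state(checkpoint: dict[str, Any], frozen_eval: bool) -> dict[str, Any]:
    if frozen_eval and "target_encoder" in checkpoint:
        # Saved alongside "encoder" and not overridden by subsequent training epochs.
        return checkpoint["target_encoder"]

    if "encoder" in checkpoint:
        return checkpoint["encoder"]

    raise ValueError(f"No encoder weights found (available keys: {list(checkpoint.keys())})")

# Usage sketch (placeholder path):
checkpoint = torch.load("jepa_pretrained.pt", map_location="cpu")
encoder_state = select_encoder_state(checkpoint, frozen_eval=True)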

This PR updates the loader to allow loading a JEPA model for different purposes (training, frozen evaluation).

Fixes #{issue number}

Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.

Check list:

  • Was the content of this PR discussed and approved via a GitHub issue? (no need for typos or documentation improvements)
  • Did you read the contributor guideline?
  • Did you make sure that your PR does only one thing instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (no need for typos, documentation, or minor internal changes)

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 2, 2025
if "encoder" not in state_dict:
raise ValueError(f"`encoder` not found in state dict (available key: {state_dict.keys()})")

return state_dict["encoder"]
Contributor commented:

Although I understand the PR description, I am not sure I understand the change here. Is there a reason for not handling this check in convert_jepa_checkpoint? I mean, instead of

checkpoint = checkpoint["encoder"]

having:

checkpoint = checkpoint.get("encoder")
if checkpoint is None:
  raise ValueError(...)

What is the benefit of having this check in a tensor_loader?

@antoine-tran (Contributor, Author) replied Jan 2, 2025:
I have two thoughts behind this change, though both are opinionated:

  • We should narrow the scope of the convert_jepa_checkpoint function to converting only the parameters related to the JEPA model. How we get to these parameters is handled separately (in a TensorLoader).
  • With this, we do not have to list all possible checkpoint keys ("encoder", "target_encoder") and define their priority in convert_jepa_checkpoint. This lets us inject pretrained encoders from other "exotic" checkpoints (for example, jepa-llava, where the encoder is stored in vision_tower; see the sketch below).

The drawback of this approach is that we have to write a custom TensorLoader for each checkpoint, so it is a matter of opinion here...
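
For illustration of that trade-off, a per-checkpoint loader for the jepa-llava case could be a thin wrapper that pulls out the "vision_tower" sub-dict before handing the weights to the converter. The sketch below uses plain torch.load and hypothetical names; it is not the actual fairseq2 TensorLoader interface.

from pathlib import Path
from typing import Any

import torch

# Hypothetical per-checkpoint loader (illustrative only): extract the encoder
# weights stored under "vision_tower" so that convert_jepa_checkpoint only ever
# sees plain JEPA encoder parameters.
def load_jepa_llava_tensors(path: Path) -> dict[str, Any]:
    checkpoint = torch.load(path, map_location="cpu")

    if "vision_tower" not in checkpoint:
        raise ValueError(
            f"`vision_tower` not found in state dict (available keys: {list(checkpoint.keys())})"
        )

    return checkpoint["vision_tower"]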

Contributor replied:

How about doing something like:

# Handles different variants of JEPA checkpoints and delegates the actual conversion
# to the standard converter.
def convert_jepa_checkpoint(
    checkpoint: dict[str, Any], config: JepaConfig
) -> dict[str, Any]:
    if "vision_tower" in checkpoint:
        return convert_jepa_encoder_checkpoint(checkpoint["vision_tower"], config)

    if "target_encoder" in checkpoint:
        return convert_jepa_encoder_checkpoint(checkpoint["target_encoder"], config)

    if "encoder" in checkpoint:
        return convert_jepa_encoder_checkpoint(checkpoint["encoder"], config)

    raise ValueError("encoder not found.")


def convert_jepa_encoder_checkpoint(
    checkpoint: dict[str, Any], config: JepaConfig
) -> dict[str, Any]:
    # Contains the current implementation.
    ...

My worry with the TensorLoader approach is that we leak state-dict handling logic into tensor loading. Essentially we want to "pre-process" the checkpoint before passing it to the converter, so a wrapper function might do the job as well. Let me know what you think.

@cbalioglu (Contributor) left a comment:

Please make sure to run mypy/flake8/isort/black to fix the format errors before merging. Otherwise, looks good to me! Thanks!

Tuan Tran and others added 3 commits January 3, 2025 10:00
@antoine-tran antoine-tran merged commit b4b29e5 into main Jan 3, 2025
15 checks passed
@antoine-tran antoine-tran deleted the tuan/fix_jepa_loader branch January 3, 2025 19:48