diff --git a/src/fairseq2/models/jepa/loader.py b/src/fairseq2/models/jepa/loader.py index 27cb9f189..a38c613e7 100644 --- a/src/fairseq2/models/jepa/loader.py +++ b/src/fairseq2/models/jepa/loader.py @@ -26,8 +26,21 @@ def convert_jepa_checkpoint( checkpoint: dict[str, Any], config: JepaConfig ) -> dict[str, Any]: - checkpoint = checkpoint["encoder"] + # We have a shared checkpoint, used for other use cases (frozen evaluation,..) + if "target_encoder" in checkpoint: + return convert_jepa_encoder_checkpoint( + checkpoint["target_encoder"], config=config + ) + if "encoder" in checkpoint: + return convert_jepa_encoder_checkpoint(checkpoint["encoder"], config=config) + + raise ValueError(f"encoder not found (available keys: {checkpoint.keys()})") + + +def convert_jepa_encoder_checkpoint( + checkpoint: dict[str, Any], config: JepaConfig +) -> dict[str, Any]: del checkpoint["module.backbone.pos_embed"] new_checkpoint = {}