Skip to content

Commit

Permalink
fix state dict hook for early fusion models (#2317)
Browse files: browse the repository at this point in the history
Co-authored-by: JessicaZhong <[email protected]>
  • Loading branch information
acisseJZhong and jessicazhongeee authored Jan 30, 2025
1 parent d3b39cf commit be4ff50
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions torchtune/modules/model_fusion/_early_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,11 @@ def _state_dict_hook(module, state_dict, prefix, *args, **kwargs):
[!Note] This update changes the order of the OrderedDict
"""
for n, p in module.tok_embeddings.named_parameters():
state_dict[f"{prefix}decoder.tok_embeddings.{n}"] = p
del state_dict[f"{prefix}tok_embeddings.{n}"]
orig_key = f"{prefix}tok_embeddings.{n}"
if orig_key in state_dict:
# preserve the original tensor with its requires_grad state
state_dict[f"{prefix}decoder.tok_embeddings.{n}"] = state_dict[orig_key]
del state_dict[orig_key]

@staticmethod
def _load_state_dict_hook(module, state_dict, prefix, *args, **kwargs):
Expand Down

0 comments on commit be4ff50

Please sign in to comment.