
Commit b70845a

TroyGarden authored and facebook-github-bot committed
fix load checkpoint issue when load_state_dict with assign=True (#3369)
Summary:
Pull Request resolved: #3369

# Issue Summary

The PositionWeightedModuleCollection class in TorchRec maintains two separate references to the position weight parameters: position_weights (a ParameterDict) and position_weights_dict (a regular dict). When a checkpoint is loaded with load_state_dict(..., assign=True), position_weights_dict becomes desynchronized and continues pointing to stale parameter tensors.

# Impact

- Silent correctness failures during model evaluation and inference
- Training instability when loading checkpoints for recurring training
- Production issues in model serving pipelines
- Inconsistent results between fresh model initialization and checkpoint loading

# Root Cause Analysis

1. Initialization: during module construction, position_weights_dict[key] = self.position_weights[key] creates references to the parameter tensors (line 190).
2. Checkpoint loading: when load_state_dict(..., assign=True) is called, PyTorch replaces the actual parameter tensors with new ones from the checkpoint.
3. Stale references: position_weights_dict continues pointing to the old parameter tensors, which are no longer part of the model.
4. Silent failure: get_weights_list() reads from position_weights_dict (line 212), so the model uses incorrect weights without raising any error.

Reviewed By: dyerinoon

Differential Revision: D81749871

fbshipit-source-id: 81f6c4dddef42b377598511396b3210669e4db36
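To make the failure mode concrete, here is a minimal self-contained sketch of the same mechanism outside TorchRec (Toy and its attributes are illustrative stand-ins, not the actual module): a plain-dict alias of a parameter goes stale once assign=True swaps the parameter object.

import torch
from torch import nn


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weights = nn.ParameterDict({"f": nn.Parameter(torch.ones(3))})
        # Plain-dict alias of the parameter, analogous to position_weights_dict.
        self.weights_dict = {"f": self.weights["f"]}


m = Toy()
# assign=True replaces the module's parameter objects with the checkpoint
# tensors instead of copying values into the existing parameters.
m.load_state_dict({"weights.f": torch.tensor([2.0, 3.0, 4.0])}, assign=True)

print(m.weights["f"] is m.weights_dict["f"])  # False: the alias is now stale
print(m.weights_dict["f"].data)               # still tensor([1., 1., 1.])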
1 parent 82aaaa8 · commit b70845a

2 files changed: 54 additions & 1 deletion

torchrec/modules/feature_processor_.py

Lines changed: 14 additions & 1 deletion
@@ -10,11 +10,12 @@
 #!/usr/bin/env python3
 
 import abc
-from typing import Dict, List, Optional
+from typing import Dict, List, Mapping, Optional
 
 import torch
 
 from torch import nn
+from torch.nn.modules.module import _IncompatibleKeys
 
 from torchrec.pt2.checks import is_non_strict_exporting
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
@@ -232,3 +233,15 @@ def _apply(self, *args, **kwargs) -> nn.Module:
             self.position_weights_dict[k] = param
 
         return self
+
+    def load_state_dict(
+        self,
+        state_dict: Mapping[str, torch.Tensor],
+        strict: bool = True,
+        assign: bool = False,
+    ) -> _IncompatibleKeys:
+        result = super().load_state_dict(state_dict, strict, assign)
+        # Re-sync after loading
+        for k, param in self.position_weights.items():
+            self.position_weights_dict[k] = param
+        return result
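A design note: an overridden load_state_dict runs only when called on this module directly; PyTorch's recursive load from a parent module goes through _load_from_state_dict and the module's post-hooks instead. An alternative sketch (not what this commit ships) would register the same re-sync as a load_state_dict post-hook, which also fires for parent-driven loads; _resync and pwmc are illustrative names.

import torch
from torch import nn
from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection


def _resync(module: nn.Module, incompatible_keys) -> None:
    # Repoint the plain dict at whatever parameter objects the module now holds.
    for k, param in module.position_weights.items():
        module.position_weights_dict[k] = param


pwmc = PositionWeightedModuleCollection({"feature1": 3, "feature2": 3})
pwmc.register_load_state_dict_post_hook(_resync)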

torchrec/modules/tests/test_feature_processor_.py

Lines changed: 40 additions & 0 deletions
@@ -209,3 +209,43 @@ def test_to(self) -> None:
         self.assertTrue(
             all(param.is_meta for param in pwmc.position_weights_dict.values())
         )
+
+    def test_load_state_dict(self) -> None:
+        values = torch.tensor([10, 11, 12, 20, 21, 22])
+        lengths = torch.tensor([3, 3])
+        kjt = KeyedJaggedTensor(
+            keys=["feature1", "feature2"], values=values, lengths=lengths
+        )
+
+        # Step 1: Create module and observe initial state
+        max_feature_lengths = {"feature1": 3, "feature2": 3}
+        module = PositionWeightedModuleCollection(max_feature_lengths)
+
+        # Before checkpoint loading, position_weights_dict is an element-wise reference of position_weights
+        for f in ["feature1", "feature2"]:
+            self.assertIs(
+                module.position_weights[f],
+                module.position_weights_dict[f],
+            )
+
+        output = module(kjt)
+        expected = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+        self.assertListEqual(output.weights().tolist(), expected)
+
+        # Step 2: Simulate checkpoint loading with assign=True
+        checkpoint = {
+            "position_weights.feature1": torch.tensor([2.0, 3.0, 4.0]),
+            "position_weights.feature2": torch.tensor([5.0, 6.0, 7.0]),
+        }
+        module.load_state_dict(checkpoint, strict=False, assign=True)
+
+        # After checkpoint loading, position_weights_dict is an element-wise reference of position_weights
+        for f in ["feature1", "feature2"]:
+            self.assertIs(
+                module.position_weights[f],
+                module.position_weights_dict[f],
+            )
+
+        output = module(kjt)
+        expected = [2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
+        self.assertListEqual(output.weights().tolist(), expected)
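The test pins down the module-internal invariant; the same hazard applies to any caller that caches parameter references across a load. Continuing from the test's module and checkpoint above (a usage sketch, not part of the diff):

# A handle captured before load_state_dict(..., assign=True) keeps pointing at
# the pre-load tensor, so re-fetch parameters after loading instead of caching.
w_before = module.position_weights["feature1"]
module.load_state_dict(checkpoint, strict=False, assign=True)
w_after = module.position_weights["feature1"]
assert w_before is not w_after  # the old handle is no longer a model parameter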
