Commit b70845a
fix load checkpoint issue when load_state_dict with assign=True (#3369)
Summary:
Pull Request resolved: #3369
# Issue Summary
The PositionWeightedModuleCollection class in TorchRec maintains two separate references to position weight parameters: position_weights (ParameterDict) and position_weights_dict (regular dict). When checkpoints are loaded using load_state_dict(..., assign=True), the position_weights_dict becomes desynchronized and continues pointing to stale parameter tensors.
# Impact
- Silent correctness failures during model evaluation and inference
- Training instability when loading checkpoints for recurring training
- Production issues in model serving pipelines
- Inconsistent results between fresh model initialization and checkpoint loading
# Root Cause Analysis
1. Initialization: During module construction, position_weights_dict[key] = self.position_weights[key] creates references to parameter tensors (line 190)
2. Checkpoint Loading: When load_state_dict(..., assign=True) is called, PyTorch replaces the actual parameter tensors with new ones from the checkpoint
3. Stale References: The position_weights_dict continues pointing to the old parameter tensors that are no longer part of the model
4. Silent Failure: The get_weights_list() function uses position_weights_dict (line 212), causing the model to use incorrect weights without any error
Reviewed By: dyerinoon
Differential Revision: D81749871
fbshipit-source-id: 81f6c4dddef42b377598511396b3210669e4db361 parent 82aaaa8 commit b70845a
File tree
2 files changed
+54
-1
lines changed- torchrec/modules
- tests
2 files changed
+54
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
13 | | - | |
| 13 | + | |
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
17 | 17 | | |
| 18 | + | |
18 | 19 | | |
19 | 20 | | |
20 | 21 | | |
| |||
232 | 233 | | |
233 | 234 | | |
234 | 235 | | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
209 | 209 | | |
210 | 210 | | |
211 | 211 | | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
0 commit comments