@@ -230,8 +230,8 @@ Which produces the following weights structure (viewed in
230
230
You can use the ` PyTorchFileRecorder ` to change the attribute names and the order of the attributes
231
231
by specifying a regular expression (See
232
232
[ regex::Regex::replace] ( https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace ) and
233
- [ try it online] ( https://rregex.dev/?version=1.10&method=replace ) ) to
234
- match the attribute name and a replacement string in ` LoadArgs ` :
233
+ [ try it online] ( https://rregex.dev/?version=1.10&method=replace ) ) to match the attribute name and a
234
+ replacement string in ` LoadArgs ` :
235
235
236
236
``` rust
237
237
let device = Default :: default ();
@@ -246,6 +246,46 @@ let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
246
246
let model = Net :: <Backend >:: new_with (record );
247
247
```
248
248
249
+ ### Printing the source model keys and tensor information
250
+
251
+ If you are unsure about the keys in the source model, you can print them using the following code:
252
+
253
+ ``` rust
254
+ let device = Default :: default ();
255
+ let load_args = LoadArgs :: new (" tests/key_remap/key_remap.pt" . into ())
256
+ // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
257
+ . with_key_remap (" conv\ \ .(.*)" , " $1" )
258
+ . with_debug_print (); // Print the keys and remapped keys
259
+
260
+ let record = PyTorchFileRecorder :: <FullPrecisionSettings >:: default ()
261
+ . load (load_args , & device )
262
+ . expect (" Should decode state successfully" );
263
+
264
+ let model = Net :: <Backend >:: new_with (record );
265
+ ```
266
+
267
+ Here is an example of the output:
268
+
269
+ ``` text
270
+ Debug information of keys and tensor shapes:
271
+ ---
272
+ Original Key: conv.conv1.bias
273
+ Remapped Key: conv1.bias
274
+ Shape: [2]
275
+ Dtype: F32
276
+ ---
277
+ Original Key: conv.conv1.weight
278
+ Remapped Key: conv1.weight
279
+ Shape: [2, 2, 2, 2]
280
+ Dtype: F32
281
+ ---
282
+ Original Key: conv.conv2.weight
283
+ Remapped Key: conv2.weight
284
+ Shape: [2, 2, 2, 2]
285
+ Dtype: F32
286
+ ---
287
+ ```
288
+
249
289
### Loading the model weights to a partial model
250
290
251
291
` PyTorchFileRecorder ` enables selective weight loading into partial models. For instance, in a model
@@ -254,11 +294,12 @@ defining the encoder in Burn, allowing the loading of its weights while excludin
254
294
255
295
### Specifying the top-level key for state_dict
256
296
257
- Sometimes the [ ` state_dict ` ] ( https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict )
297
+ Sometimes the
298
+ [ ` state_dict ` ] ( https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict )
258
299
is nested under a top-level key along with other metadata as in a
259
300
[ general checkpoint] ( https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training ) .
260
- For example, the ` state_dict ` of the whisper model is nested under ` model_state_dict ` key.
261
- In this case, you can specify the top-level key in ` LoadArgs ` :
301
+ For example, the ` state_dict ` of the whisper model is nested under ` model_state_dict ` key. In this
302
+ case, you can specify the top-level key in ` LoadArgs ` :
262
303
263
304
``` rust
264
305
let device = Default :: default ();
0 commit comments