Skip to content

Commit 545444c

Browse files
authored
PyTorchFileRecord print debug option (#1425)
* Add debug print option to PyTorchFileRecorder * Updated documentation and improved print output * Improve print wording * Updated per PR feedback
1 parent b429cc3 commit 545444c

File tree

5 files changed

+98
-15
lines changed

5 files changed

+98
-15
lines changed

burn-book/src/import/pytorch-model.md

+46-5
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,8 @@ Which produces the following weights structure (viewed in
230230
You can use the `PyTorchFileRecorder` to change the attribute names and the order of the attributes
231231
by specifying a regular expression (See
232232
[regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) and
233-
[try it online](https://rregex.dev/?version=1.10&method=replace)) to
234-
match the attribute name and a replacement string in `LoadArgs`:
233+
[try it online](https://rregex.dev/?version=1.10&method=replace)) to match the attribute name and a
234+
replacement string in `LoadArgs`:
235235

236236
```rust
237237
let device = Default::default();
@@ -246,6 +246,46 @@ let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
246246
let model = Net::<Backend>::new_with(record);
247247
```
248248

249+
### Printing the source model keys and tensor information
250+
251+
If you are unsure about the keys in the source model, you can print them using the following code:
252+
253+
```rust
254+
let device = Default::default();
255+
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
256+
// Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
257+
.with_key_remap("conv\\.(.*)", "$1")
258+
.with_debug_print(); // Print the keys and remapped keys
259+
260+
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
261+
.load(load_args, &device)
262+
.expect("Should decode state successfully");
263+
264+
let model = Net::<Backend>::new_with(record);
265+
```
266+
267+
Here is an example of the output:
268+
269+
```text
270+
Debug information of keys and tensor shapes:
271+
---
272+
Original Key: conv.conv1.bias
273+
Remapped Key: conv1.bias
274+
Shape: [2]
275+
Dtype: F32
276+
---
277+
Original Key: conv.conv1.weight
278+
Remapped Key: conv1.weight
279+
Shape: [2, 2, 2, 2]
280+
Dtype: F32
281+
---
282+
Original Key: conv.conv2.weight
283+
Remapped Key: conv2.weight
284+
Shape: [2, 2, 2, 2]
285+
Dtype: F32
286+
---
287+
```
288+
249289
### Loading the model weights to a partial model
250290

251291
`PyTorchFileRecorder` enables selective weight loading into partial models. For instance, in a model
@@ -254,11 +294,12 @@ defining the encoder in Burn, allowing the loading of its weights while excludin
254294

255295
### Specifying the top-level key for state_dict
256296

257-
Sometimes the [`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
297+
Sometimes the
298+
[`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
258299
is nested under a top-level key along with other metadata as in a
259300
[general checkpoint](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training).
260-
For example, the `state_dict` of the whisper model is nested under `model_state_dict` key.
261-
In this case, you can specify the top-level key in `LoadArgs`:
301+
For example, the `state_dict` of the whisper model is nested under the `model_state_dict` key. In this
302+
case, you can specify the top-level key in `LoadArgs`:
262303

263304
```rust
264305
let device = Default::default();

crates/burn-core/src/record/serde/data.rs

+13-5
Original file line numberDiff line numberDiff line change
@@ -182,19 +182,25 @@ impl NestedValue {
182182
/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.
183183
/// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
184184
/// for more information.
185-
///
186185
/// # Returns
187186
///
188-
/// A map of tensors with the remapped keys.
187+
/// A map of tensors with the remapped keys and
188+
/// a vector of tuples containing the remapped and original key names.
189189
pub fn remap<T>(
190190
mut tensors: HashMap<String, T>,
191191
key_remap: Vec<(Regex, String)>,
192-
) -> HashMap<String, T> {
192+
) -> (HashMap<String, T>, Vec<(String, String)>) {
193193
if key_remap.is_empty() {
194-
return tensors;
194+
let remapped_names = tensors
195+
.keys()
196+
.cloned()
197+
.map(|s| (s.clone(), s)) // Name is the same as the remapped name
198+
.collect();
199+
return (tensors, remapped_names);
195200
}
196201

197202
let mut remapped = HashMap::new();
203+
let mut remapped_names = Vec::new();
198204

199205
for (name, tensor) in tensors.drain() {
200206
let mut new_name = name.clone();
@@ -205,10 +211,12 @@ pub fn remap<T>(
205211
.to_string();
206212
}
207213
}
214+
215+
remapped_names.push((new_name.clone(), name));
208216
remapped.insert(new_name, tensor);
209217
}
210218

211-
remapped
219+
(remapped, remapped_names)
212220
}
213221

214222
/// Helper function to insert a value into a nested map/vector of tensors.

crates/burn-import/pytorch-tests/tests/key_remap/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ mod tests {
4141
fn key_remap() {
4242
let device = Default::default();
4343
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
44-
.with_key_remap("conv\\.(.*)", "$1"); // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
44+
.with_key_remap("conv\\.(.*)", "$1") // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
45+
.with_debug_print();
4546

4647
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
4748
.load(load_args, &device)

crates/burn-import/src/pytorch/reader.rs

+23-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ pub fn from_file<PS, D, B>(
3535
path: &Path,
3636
key_remap: Vec<(Regex, String)>,
3737
top_level_key: Option<&str>,
38+
debug: bool,
3839
) -> Result<D, Error>
3940
where
4041
D: DeserializeOwned,
@@ -48,7 +49,28 @@ where
4849
.collect();
4950

5051
// Remap the keys (replace the keys in the map with the new keys)
51-
let tensors = remap(tensors, key_remap);
52+
let (tensors, remapped_keys) = remap(tensors, key_remap);
53+
54+
// Print the remapped keys if debug is enabled
55+
if debug {
56+
let mut remapped_keys = remapped_keys;
57+
remapped_keys.sort();
58+
println!("Debug information of keys and tensor shapes:\n---");
59+
for (new_key, old_key) in remapped_keys {
60+
if old_key != new_key {
61+
println!("Original Key: {old_key}");
62+
println!("Remapped Key: {new_key}");
63+
} else {
64+
println!("Key: {}", new_key);
65+
}
66+
67+
let shape = tensors[&new_key].shape();
68+
let dtype = tensors[&new_key].dtype();
69+
println!("Shape: {shape:?}");
70+
println!("Dtype: {dtype:?}");
71+
println!("---");
72+
}
73+
}
5274

5375
// Convert the vector of Candle tensors to a nested value data structure
5476
let nested_value = unflatten::<PS, _>(tensors)?;

crates/burn-import/src/pytorch/recorder.rs

+14-3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS>
4848
&args.file,
4949
args.key_remap,
5050
args.top_level_key.as_deref(), // Convert Option<String> to Option<&str>
51+
args.debug,
5152
)?;
5253
Ok(R::from_item(item, device))
5354
}
@@ -92,10 +93,13 @@ pub struct LoadArgs {
9293
/// Top-level key to load state_dict from the file.
9394
/// Sometimes the state_dict is nested under a top-level key in a dict.
9495
pub top_level_key: Option<String>,
96+
97+
/// Whether to print debug information.
98+
pub debug: bool,
9599
}
96100

97101
impl LoadArgs {
98-
/// Create a new `LoadArgs` instance.
102+
/// Creates a new `LoadArgs` instance.
99103
///
100104
/// # Arguments
101105
///
@@ -105,10 +109,11 @@ impl LoadArgs {
105109
file,
106110
key_remap: Vec::new(),
107111
top_level_key: None,
112+
debug: false,
108113
}
109114
}
110115

111-
/// Set key remapping.
116+
/// Sets key remapping.
112117
///
113118
/// # Arguments
114119
///
@@ -125,7 +130,7 @@ impl LoadArgs {
125130
self
126131
}
127132

128-
/// Set top-level key to load state_dict from the file.
133+
/// Sets the top-level key to load state_dict from the file.
129134
/// Sometimes the state_dict is nested under a top-level key in a dict.
130135
///
131136
/// # Arguments
@@ -135,6 +140,12 @@ impl LoadArgs {
135140
self.top_level_key = Some(key.into());
136141
self
137142
}
143+
144+
/// Enables printing of debug information.
145+
pub fn with_debug_print(mut self) -> Self {
146+
self.debug = true;
147+
self
148+
}
138149
}
139150

140151
impl From<PathBuf> for LoadArgs {

0 commit comments

Comments
 (0)