diff --git a/Cargo.lock b/Cargo.lock index 58191ae8bc..213f8d100d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6338,6 +6338,27 @@ dependencies = [ "serde_json", ] +[[package]] +name = "safetensors-import" +version = "0.17.0" +dependencies = [ + "burn", + "burn-import", + "stmodel", +] + +[[package]] +name = "safetensors-tests" +version = "0.17.0" +dependencies = [ + "burn", + "burn-autodiff", + "burn-import", + "burn-ndarray", + "float-cmp", + "serde", +] + [[package]] name = "same-file" version = "1.0.6" @@ -6845,6 +6866,14 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "stmodel" +version = "0.6.0" +dependencies = [ + "burn", + "burn-import", +] + [[package]] name = "streaming-decompression" version = "0.1.2" diff --git a/Cargo.toml b/Cargo.toml index 22ed0b2644..09c810871e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "crates/*", "crates/burn-import/pytorch-tests", "crates/burn-import/onnx-tests", + "crates/burn-import/safetensors-tests", "examples/*", "examples/pytorch-import/model", "xtask", diff --git a/burn-book/src/import/pytorch-model.md b/burn-book/src/import/pytorch-model.md index 1f584cdc9f..b6467dcf8e 100644 --- a/burn-book/src/import/pytorch-model.md +++ b/burn-book/src/import/pytorch-model.md @@ -3,8 +3,8 @@ ## Introduction Whether you've trained your model in PyTorch or you want to use a pre-trained model from PyTorch, -you can import them into Burn. Burn supports importing PyTorch model weights with `.pt` file -extension. Compared to ONNX models, `.pt` files only contain the weights of the model, so you will +you can import them into Burn. Burn supports importing PyTorch model weights with `.pt` and `.safetensors` file +extension. Compared to ONNX models, `.pt` and `.safetensors` files only contain the weights of the model, so you will need to reconstruct the model architecture in Burn. Here in this document we will show the full workflow of exporting a PyTorch model and importing it. diff --git a/crates/burn-import/Cargo.toml b/crates/burn-import/Cargo.toml index ee7c1c559e..076528bc4a 100644 --- a/crates/burn-import/Cargo.toml +++ b/crates/burn-import/Cargo.toml @@ -15,12 +15,15 @@ version.workspace = true default-run = "onnx2burn" [features] -default = ["onnx", "pytorch"] +default = ["onnx", "pytorch", "safetensors"] onnx = [] pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"] +safetensors = ["burn/record-item-custom-serde", "thiserror", "zip"] [dependencies] -burn = { path = "../burn", version = "0.17.0", default-features = false, features = ["std"]} +burn = { path = "../burn", version = "0.17.0", default-features = false, features = [ + "std", +] } burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false } onnx-ir = { path = "../onnx-ir", version = "0.17.0" } candle-core = { workspace = true } diff --git a/crates/burn-import/safetensors-tests/Cargo.toml b/crates/burn-import/safetensors-tests/Cargo.toml new file mode 100644 index 0000000000..cdf8927a63 --- /dev/null +++ b/crates/burn-import/safetensors-tests/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "safetensors-tests" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dev-dependencies] +burn = { path = "../../burn" } +burn-ndarray = { path = "../../burn-ndarray" } +burn-autodiff = { path = "../../burn-autodiff" } +serde = { workspace = true } +float-cmp = { workspace = true } +burn-import = { path = "../", features = ["safetensors"] } + + +[build-dependencies] +burn-import = { path = "../", features = ["safetensors"] } diff --git a/crates/burn-import/safetensors-tests/src/lib.rs b/crates/burn-import/safetensors-tests/src/lib.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/crates/burn-import/safetensors-tests/src/lib.rs @@ -0,0 +1 @@ + diff --git a/crates/burn-import/safetensors-tests/tests/batch_norm/batch_norm2d.safetensors b/crates/burn-import/safetensors-tests/tests/batch_norm/batch_norm2d.safetensors new file mode 100644 index 0000000000..1fb5bddaa7 Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/batch_norm/batch_norm2d.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/batch_norm/export_weights.py b/crates/burn-import/safetensors-tests/tests/batch_norm/export_weights.py new file mode 100755 index 0000000000..4f82ab4201 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/batch_norm/export_weights.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.norm1 = nn.BatchNorm2d(5) + + def forward(self, x): + x = self.norm1(x) + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + # Condition batch norm (each forward will affect the running stats) + x1 = torch.ones(1, 5, 2, 2) - 0.5 + _ = model(x1) + model.eval() # Set to eval mode to freeze running stats + # Save the model to safetensors after the first forward + save_file(model.state_dict(), "batch_norm2d.safetensors") + + x2 = torch.ones(1, 5, 2, 2) - 0.3 + print("Input shape: {}", x2.shape) + output = model(x2) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/batch_norm/mod.rs b/crates/burn-import/safetensors-tests/tests/batch_norm/mod.rs new file mode 100644 index 0000000000..0d8319a680 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/batch_norm/mod.rs @@ -0,0 +1,60 @@ +use burn::{ + module::Module, + nn::{BatchNorm, BatchNormConfig}, + tensor::{backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + norm1: BatchNorm, +} + +impl Net { + pub fn new(device: &B::Device) -> Self { + Self { + norm1: BatchNormConfig::new(4).init(device), + } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + self.norm1.forward(x) + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, Recorder}; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + #[test] + fn batch_norm2d() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/batch_norm/batch_norm2d.safetensors".into(), &device) + .expect("Should decode state successfully"); + + let model = Net::::new(&device).load_record(record); + + let input = Tensor::::ones([1, 5, 2, 2], &device) - 0.3; + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [[0.68515635, 0.68515635], [0.68515635, 0.68515635]], + [[0.68515635, 0.68515635], [0.68515635, 0.68515635]], + [[0.68515635, 0.68515635], [0.68515635, 0.68515635]], + [[0.68515635, 0.68515635], [0.68515635, 0.68515635]], + [[0.68515635, 0.68515635], [0.68515635, 0.68515635]], + ]], + &device, + ); + + output.to_data().assert_approx_eq(&expected.to_data(), 5); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/boolean/boolean.safetensors b/crates/burn-import/safetensors-tests/tests/boolean/boolean.safetensors new file mode 100644 index 0000000000..daaa4ef899 Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/boolean/boolean.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/boolean/export_weights.py b/crates/burn-import/safetensors-tests/tests/boolean/export_weights.py new file mode 100755 index 0000000000..ad859eb6b9 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/boolean/export_weights.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + buffer = torch.tensor([True, False, True]) + self.register_buffer("buffer", buffer, persistent=True) + + def forward(self, x): + x = self.buffer + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "boolean.safetensors") + + input = torch.ones(3, 3) + print("Input shape: {}", input.shape) + print("Input: {}", input) + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/boolean/mod.rs b/crates/burn-import/safetensors-tests/tests/boolean/mod.rs new file mode 100644 index 0000000000..d203dad20f --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/boolean/mod.rs @@ -0,0 +1,58 @@ +use burn::{ + module::{Module, Param}, + tensor::{backend::Backend, Bool, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + buffer: Param>, +} + +impl Net { + /// Create a new model from the given record. + pub fn new_with(record: NetRecord) -> Self { + Self { + buffer: record.buffer, + } + } + + /// Forward pass of the model. + pub fn forward(&self, _x: Tensor) -> Tensor { + self.buffer.val() + } +} + +#[cfg(test)] +mod tests { + + use burn::{ + record::{FullPrecisionSettings, Recorder}, + tensor::TensorData, + }; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + type Backend = burn_ndarray::NdArray; + + #[test] + #[ignore = "It appears loading boolean tensors are not supported yet"] + // Error skipping: Msg("unsupported storage type BoolStorage") + fn boolean() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/boolean/boolean.safetensors".into(), &device) + .expect("Should decode state successfully"); + + let model = Net::::new_with(record); + + let input = Tensor::::ones([3, 3], &device); + + let output = model.forward(input); + + let expected = + Tensor::::from_bool(TensorData::from([true, false, true]), &device); + + assert_eq!(output.to_data(), expected.to_data()); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/buffer/buffer.safetensors b/crates/burn-import/safetensors-tests/tests/buffer/buffer.safetensors new file mode 100644 index 0000000000..8b4b51b4ab Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/buffer/buffer.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/buffer/export_weights.py b/crates/burn-import/safetensors-tests/tests/buffer/export_weights.py new file mode 100755 index 0000000000..32148e92ea --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/buffer/export_weights.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + buffer = torch.ones(3, 3) + self.register_buffer("buffer", buffer, persistent=True) + + def forward(self, x): + x = self.buffer + x + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "buffer.safetensors") + + input = torch.ones(3, 3) + print("Input shape: {}", input.shape) + print("Input: {}", input) + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/buffer/mod.rs b/crates/burn-import/safetensors-tests/tests/buffer/mod.rs new file mode 100644 index 0000000000..b561ce6f1c --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/buffer/mod.rs @@ -0,0 +1,51 @@ +use burn::{ + module::{Module, Param}, + tensor::{backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + buffer: Param>, +} + +impl Net { + /// Create a new model from the given record. + pub fn new_with(record: NetRecord) -> Self { + Self { + buffer: record.buffer, + } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + self.buffer.val() + x + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, Recorder}; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + #[test] + fn buffer() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/buffer/buffer.safetensors".into(), &device) + .expect("Should decode state successfully"); + + let model = Net::::new_with(record); + + let input = Tensor::::ones([3, 3], &device); + + let output = model.forward(input); + + let expected = Tensor::::ones([3, 3], &device) * 2.0; + + output.to_data().assert_approx_eq(&expected.to_data(), 3); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/complex_nested/complex_nested.safetensors b/crates/burn-import/safetensors-tests/tests/complex_nested/complex_nested.safetensors new file mode 100644 index 0000000000..faba8d2722 Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/complex_nested/complex_nested.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/complex_nested/export_weights.py b/crates/burn-import/safetensors-tests/tests/complex_nested/export_weights.py new file mode 100755 index 0000000000..c306f1d072 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/complex_nested/export_weights.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file # Add this import + + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size): + super(ConvBlock, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size) + self.norm = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + + self.conv_blocks = nn.Sequential( + ConvBlock(2, 4, (3, 2)), + ConvBlock(4, 6, (3, 2)), + ) + self.norm1 = nn.BatchNorm2d(6) + + self.fc1 = nn.Linear(120, 12) + self.fc2 = nn.Linear(12, 10) + + def forward(self, x): + x = self.conv_blocks(x) + x = self.norm1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.log_softmax(x, dim=1) + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(2) + + model = Net().to(torch.device("cpu")) + + # Condition the model (batch norm requires a forward pass to compute the mean and variance) + x1 = torch.ones(1, 2, 9, 6) - 0.1 + x2 = torch.ones(1, 2, 9, 6) - 0.3 + output = model(x1) + output = model(x2) + model.eval() # set to eval mode + + save_file(model.state_dict(), "complex_nested.safetensors") # Replace torch.save + + # feed test data + x = torch.ones(1, 2, 9, 6) - 0.5 + output = model(x) + print("Input shape: {}", x.shape) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/complex_nested/mod.rs b/crates/burn-import/safetensors-tests/tests/complex_nested/mod.rs new file mode 100644 index 0000000000..db659faeaf --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/complex_nested/mod.rs @@ -0,0 +1,159 @@ +use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder}; + +use burn::{ + module::Module, + nn::{ + conv::{Conv2d, Conv2dConfig}, + BatchNorm, BatchNormConfig, Linear, LinearConfig, + }, + tensor::{ + activation::{log_softmax, relu}, + backend::Backend, + Tensor, + }, +}; +use burn_autodiff::Autodiff; + +#[derive(Module, Debug)] +pub struct ConvBlock { + conv: Conv2d, + norm: BatchNorm, +} + +#[derive(Module, Debug)] +pub struct Net { + conv_blocks: Vec>, + norm1: BatchNorm, + fc1: Linear, + fc2: Linear, +} + +impl Net { + pub fn init(device: &B::Device) -> Self { + let conv_blocks = vec![ + ConvBlock { + conv: Conv2dConfig::new([2, 4], [3, 2]).init(device), + norm: BatchNormConfig::new(2).init(device), + }, + ConvBlock { + conv: Conv2dConfig::new([4, 6], [3, 2]).init(device), + norm: BatchNormConfig::new(4).init(device), + }, + ]; + let norm1 = BatchNormConfig::new(6).init(device); + let fc1 = LinearConfig::new(120, 12).init(device); + let fc2 = LinearConfig::new(12, 10).init(device); + + Self { + conv_blocks, + norm1, + fc1, + fc2, + } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + let x = self.conv_blocks[0].forward(x); + let x = self.conv_blocks[1].forward(x); + let x = self.norm1.forward(x); + let x = x.reshape([0, -1]); + let x = self.fc1.forward(x); + let x = relu(x); + let x = self.fc2.forward(x); + + log_softmax(x, 1) + } +} + +impl ConvBlock { + pub fn forward(&self, x: Tensor) -> Tensor { + let x = self.conv.forward(x); + + self.norm.forward(x) + } +} + +/// Model with extra fields to test loading of records (e.g. from a different model). +#[derive(Module, Debug)] +pub struct PartialWithExtraNet { + conv1: ConvBlock, + extra_field: bool, // This field is not present in the pytorch model +} + +type TestBackend = burn_ndarray::NdArray; + +fn model_test(record: NetRecord, precision: usize) { + let device = Default::default(); + let model = Net::::init(&device).load_record(record); + + let input = Tensor::::ones([1, 2, 9, 6], &device) - 0.5; + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + -2.306_613, + -2.058_945_4, + -2.298_372_7, + -2.358_294, + -2.296_395_5, + -2.416_090_5, + -2.107_669, + -2.428_420_8, + -2.526_469, + -2.319_918_6, + ]], + &device, + ); + + output + .to_data() + .assert_approx_eq(&expected.to_data(), precision); +} + +#[cfg(test)] +mod tests { + use super::*; + use burn_import::safetensors::SafeTensorsFileRecorder; + + #[test] + fn full_record() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load( + "tests/complex_nested/complex_nested.safetensors".into(), + &device, + ) + .expect("Should decode state successfully"); + + model_test(record, 8); + } + + #[test] + fn full_record_autodiff() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load( + "tests/complex_nested/complex_nested.safetensors".into(), + &device, + ) + .expect("Should decode state successfully"); + + let device = Default::default(); + let _model = Net::>::init(&device).load_record(record); + } + + #[test] + fn half_record() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load( + "tests/complex_nested/complex_nested.safetensors".into(), + &device, + ) + .expect("Should decode state successfully"); + + model_test(record, 4); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/conv1d/conv1d.safetensors b/crates/burn-import/safetensors-tests/tests/conv1d/conv1d.safetensors new file mode 100644 index 0000000000..99cdde91af Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/conv1d/conv1d.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/conv1d/export_weights.py b/crates/burn-import/safetensors-tests/tests/conv1d/export_weights.py new file mode 100755 index 0000000000..be21da569b --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/conv1d/export_weights.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv1 = nn.Conv1d(2, 2, 2) + self.conv2 = nn.Conv1d(2, 2, 2, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "conv1d.safetensors") + + input = torch.rand(1, 2, 6) + print("Input shape: {}", input.shape) + print("Input: {}", input) + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/conv1d/mod.rs b/crates/burn-import/safetensors-tests/tests/conv1d/mod.rs new file mode 100644 index 0000000000..ac69c913ce --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/conv1d/mod.rs @@ -0,0 +1,94 @@ +use burn::{ + module::Module, + nn::conv::{Conv1d, Conv1dConfig}, + tensor::{backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + conv1: Conv1d, + conv2: Conv1d, +} + +impl Net { + /// Create a new model from the given record. + pub fn init(device: &B::Device) -> Self { + let conv1 = Conv1dConfig::new(2, 2, 2).init(device); + let conv2 = Conv1dConfig::new(2, 2, 2).with_bias(false).init(device); + + Self { conv1, conv2 } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + let x = self.conv1.forward(x); + + self.conv2.forward(x) + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder}; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + fn conv1d(record: NetRecord, precision: usize) { + let device = Default::default(); + + let model = Net::::init(&device).load_record(record); + + let input = Tensor::::from_data( + [[ + [ + 0.93708336, 0.65559506, 0.31379688, 0.19801933, 0.41619217, 0.28432965, + ], + [ + 0.33977574, + 0.523_940_8, + 0.798_063_9, + 0.77176833, + 0.01122457, + 0.80996025, + ], + ]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [0.02987457, 0.03134188, 0.04234261, -0.02437721], + [-0.03788019, -0.02972012, -0.00806090, -0.01981254], + ]], + &device, + ); + + output + .to_data() + .assert_approx_eq(&expected.to_data(), precision); + } + + #[test] + fn conv1d_full_precision() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/conv1d/conv1d.safetensors".into(), &device) + .expect("Should decode state successfully"); + + conv1d(record, 7); + } + + #[test] + fn conv1d_half_precision() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/conv1d/conv1d.safetensors".into(), &device) + .expect("Should decode state successfully"); + + conv1d(record, 4); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/conv2d/conv2d.safetensors b/crates/burn-import/safetensors-tests/tests/conv2d/conv2d.safetensors new file mode 100644 index 0000000000..182f4553ef Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/conv2d/conv2d.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/conv2d/export_weights.py b/crates/burn-import/safetensors-tests/tests/conv2d/export_weights.py new file mode 100755 index 0000000000..d394caa03e --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/conv2d/export_weights.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv1 = nn.Conv2d(2, 2, (2, 2)) + self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "conv2d.safetensors") + + input = torch.rand(1, 2, 5, 5) + print("Input shape: {}", input.shape) + print("Input: {}", input) + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/conv2d/mod.rs b/crates/burn-import/safetensors-tests/tests/conv2d/mod.rs new file mode 100644 index 0000000000..961b2e051e --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/conv2d/mod.rs @@ -0,0 +1,132 @@ +use burn::{ + module::Module, + nn::conv::{Conv2d, Conv2dConfig}, + tensor::{backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + conv1: Conv2d, + conv2: Conv2d, +} + +impl Net { + /// Create a new model from the given record. + pub fn init(device: &B::Device) -> Self { + let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device); + let conv2 = Conv2dConfig::new([2, 2], [2, 2]) + .with_bias(false) + .init(device); + + Self { conv1, conv2 } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + let x = self.conv1.forward(x); + + self.conv2.forward(x) + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder}; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + fn conv2d(record: NetRecord, precision: usize) { + let device = Default::default(); + + let model = Net::::init(&device).load_record(record); + + let input = Tensor::::from_data( + [[ + [ + [ + 0.024_595_8, + 0.25883394, + 0.93905586, + 0.416_715_5, + 0.713_979_7, + ], + [0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8], + [0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4], + [0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136], + [ + 0.99802476, + 0.900_794_2, + 0.476_588_2, + 0.16625845, + 0.804_481_1, + ], + ], + [ + [ + 0.65517855, + 0.17679012, + 0.824_772_3, + 0.803_550_9, + 0.943_447_5, + ], + [0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086], + [0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497], + [0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397], + [ + 0.751_675_7, + 0.148_438_4, + 0.12274551, + 0.530_407_2, + 0.414_796_4, + ], + ], + ]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [ + [-0.02502128, 0.00250649, 0.04841233], + [0.04589614, -0.00296854, 0.01991477], + [0.02920526, 0.059_497_3, 0.04326791], + ], + [ + [-0.04825336, 0.080_190_9, -0.02375088], + [0.02885434, 0.09638263, -0.07460806], + [0.02004079, 0.06244051, 0.035_887_1], + ], + ]], + &device, + ); + + output + .to_data() + .assert_approx_eq(&expected.to_data(), precision); + } + + #[test] + fn conv2d_full_precision() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/conv2d/conv2d.safetensors".into(), &device) + .expect("Should decode state successfully"); + + conv2d(record, 7); + } + + #[test] + fn conv2d_half_precision() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/conv2d/conv2d.safetensors".into(), &device) + .expect("Should decode state successfully"); + + conv2d(record, 4); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/conv_transpose1d/conv_transpose1d.safetensors b/crates/burn-import/safetensors-tests/tests/conv_transpose1d/conv_transpose1d.safetensors new file mode 100644 index 0000000000..99cdde91af Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/conv_transpose1d/conv_transpose1d.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/conv_transpose1d/export_weights.py b/crates/burn-import/safetensors-tests/tests/conv_transpose1d/export_weights.py new file mode 100755 index 0000000000..efaca57843 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/conv_transpose1d/export_weights.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file # Add this import + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv1 = nn.ConvTranspose1d(2, 2, 2) + self.conv2 = nn.ConvTranspose1d(2, 2, 2, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "conv_transpose1d.safetensors") # Replace torch.save + + input = torch.rand(1, 2, 2) + print("Input shape: {}", input.shape) + print("Input: {}", input) + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/conv_transpose1d/mod.rs b/crates/burn-import/safetensors-tests/tests/conv_transpose1d/mod.rs new file mode 100644 index 0000000000..4dbcfdb020 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/conv_transpose1d/mod.rs @@ -0,0 +1,88 @@ +use burn::{ + module::Module, + nn::conv::{ConvTranspose1d, ConvTranspose1dConfig}, + tensor::{backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + conv1: ConvTranspose1d, + conv2: ConvTranspose1d, +} + +impl Net { + /// Create a new model from the given record. + pub fn init(device: &B::Device) -> Self { + let conv1 = ConvTranspose1dConfig::new([2, 2], 2).init(device); + let conv2 = ConvTranspose1dConfig::new([2, 2], 2).init(device); + + Self { conv1, conv2 } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + let x = self.conv1.forward(x); + + self.conv2.forward(x) + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder}; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + fn conv_transpose1d(record: NetRecord, precision: usize) { + let device = Default::default(); + + let model = Net::::init(&device).load_record(record); + + let input = Tensor::::from_data( + [[[0.93708336, 0.65559506], [0.31379688, 0.19801933]]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [0.02935525, 0.01119324, -0.01356167, -0.00682688], + [0.01644749, -0.01429807, 0.00083987, 0.00279229], + ]], + &device, + ); + + output + .to_data() + .assert_approx_eq(&expected.to_data(), precision); + } + + #[test] + fn conv_transpose1d_full() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load( + "tests/conv_transpose1d/conv_transpose1d.safetensors".into(), + &device, + ) + .expect("Should decode state successfully"); + + conv_transpose1d(record, 8); + } + #[test] + fn conv_transpose1d_half() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load( + "tests/conv_transpose1d/conv_transpose1d.safetensors".into(), + &device, + ) + .expect("Should decode state successfully"); + + conv_transpose1d(record, 4); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/conv_transpose2d/conv_transpose2d.safetensors b/crates/burn-import/safetensors-tests/tests/conv_transpose2d/conv_transpose2d.safetensors new file mode 100644 index 0000000000..182f4553ef Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/conv_transpose2d/conv_transpose2d.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/conv_transpose2d/export_weights.py b/crates/burn-import/safetensors-tests/tests/conv_transpose2d/export_weights.py new file mode 100755 index 0000000000..397f69b95a --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/conv_transpose2d/export_weights.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file # Add this import + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv1 = nn.ConvTranspose2d(2, 2, (2, 2)) + self.conv2 = nn.ConvTranspose2d(2, 2, (2, 2), bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "conv_transpose2d.safetensors") # Replace torch.save + + input = torch.rand(1, 2, 2, 2) + print("Input shape: {}", input.shape) + print("Input: {}", input) + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/conv_transpose2d/mod.rs b/crates/burn-import/safetensors-tests/tests/conv_transpose2d/mod.rs new file mode 100644 index 0000000000..ff702a793a --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/conv_transpose2d/mod.rs @@ -0,0 +1,101 @@ +use burn::{ + module::Module, + nn::conv::{ConvTranspose2d, ConvTranspose2dConfig}, + tensor::{backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + conv1: ConvTranspose2d, + conv2: ConvTranspose2d, +} + +impl Net { + /// Create a new model from the given record. + pub fn init(device: &B::Device) -> Self { + let conv1 = ConvTranspose2dConfig::new([2, 2], [2, 2]).init(device); + let conv2 = ConvTranspose2dConfig::new([2, 2], [2, 2]).init(device); + + Self { conv1, conv2 } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + let x = self.conv1.forward(x); + + self.conv2.forward(x) + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder}; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + fn conv_transpose2d(record: NetRecord, precision: usize) { + let device = Default::default(); + + let model = Net::::init(&device).load_record(record); + + let input = Tensor::::from_data( + [[ + [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]], + [[0.713_979_7, 0.267_644_3], [0.990_609, 0.28845078]], + ]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [ + [0.04547675, 0.01879685, -0.01636661, 0.00310803], + [0.02090115, 0.01192738, -0.048_240_2, 0.02252235], + [0.03249975, -0.00460748, 0.05003899, 0.04029131], + [0.02185687, -0.10226749, -0.06508022, -0.01267705], + ], + [ + [0.00277598, -0.00513832, -0.059_048_3, 0.00567626], + [-0.03149522, -0.195_757_4, 0.03474613, 0.01997269], + [-0.10096474, 0.00679589, 0.041_919_7, -0.02464108], + [-0.03174751, 0.02963913, -0.02703723, -0.01860938], + ], + ]], + &device, + ); + + output + .to_data() + .assert_approx_eq(&expected.to_data(), precision); + } + + #[test] + fn conv_transpose2d_full() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load( + "tests/conv_transpose2d/conv_transpose2d.safetensors".into(), + &device, + ) + .expect("Should decode state successfully"); + + conv_transpose2d(record, 7); + } + #[test] + fn conv_transpose2d_half() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load( + "tests/conv_transpose2d/conv_transpose2d.safetensors".into(), + &device, + ) + .expect("Should decode state successfully"); + + conv_transpose2d(record, 4); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/embedding/embedding.safetensors b/crates/burn-import/safetensors-tests/tests/embedding/embedding.safetensors new file mode 100644 index 0000000000..b46bece4ae Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/embedding/embedding.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/embedding/export_weights.py b/crates/burn-import/safetensors-tests/tests/embedding/export_weights.py new file mode 100755 index 0000000000..b8779fba69 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/embedding/export_weights.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.embed = nn.Embedding(10, 3) + + def forward(self, x): + x = self.embed(x) + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "embedding.safetensors") + + input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]) + print("Input shape: {}", input.shape) + print("Input: {}", input) + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/embedding/mod.rs b/crates/burn-import/safetensors-tests/tests/embedding/mod.rs new file mode 100644 index 0000000000..7cd6856c38 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/embedding/mod.rs @@ -0,0 +1,84 @@ +use burn::{ + module::Module, + nn::{Embedding, EmbeddingConfig}, + tensor::{backend::Backend, Int, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + embed: Embedding, +} + +impl Net { + /// Create a new model. + pub fn init(device: &B::Device) -> Self { + let embed = EmbeddingConfig::new(10, 3).init(device); + Self { embed } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + self.embed.forward(x) + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder}; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + fn embedding(record: NetRecord, precision: usize) { + let device = Default::default(); + + let model = Net::::init(&device).load_record(record); + + let input = Tensor::::from_data([[1, 2, 4, 5], [4, 3, 2, 9]], &device); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [ + [ + [-1.609_484_9, -0.10016718, -0.609_188_9], + [-0.97977227, -1.609_096_3, -0.712_144_6], + [-0.22227049, 1.687_113_4, -0.32062083], + [-0.29934573, 1.879_345_7, -0.07213178], + ], + [ + [-0.22227049, 1.687_113_4, -0.32062083], + [0.303_722, -0.777_314_3, -0.25145486], + [-0.97977227, -1.609_096_3, -0.712_144_6], + [-0.02878714, 2.357_111, -1.037_338_7], + ], + ], + &device, + ); + + output + .to_data() + .assert_approx_eq(&expected.to_data(), precision); + } + + #[test] + fn embedding_full_precision() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/embedding/embedding.safetensors".into(), &device) + .expect("Should decode state successfully"); + + embedding(record, 3); + } + + #[test] + fn embedding_half_precision() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/embedding/embedding.safetensors".into(), &device) + .expect("Should decode state successfully"); + + embedding(record, 3); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/enum_module/enum_depthwise_false.safetensors b/crates/burn-import/safetensors-tests/tests/enum_module/enum_depthwise_false.safetensors new file mode 100644 index 0000000000..31a8f03f8d Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/enum_module/enum_depthwise_false.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/enum_module/enum_depthwise_true.safetensors b/crates/burn-import/safetensors-tests/tests/enum_module/enum_depthwise_true.safetensors new file mode 100644 index 0000000000..1e70d6963a Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/enum_module/enum_depthwise_true.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/enum_module/export_weights.py b/crates/burn-import/safetensors-tests/tests/enum_module/export_weights.py new file mode 100755 index 0000000000..4d00577d09 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/enum_module/export_weights.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +import torch +from torch import nn, Tensor +from safetensors.torch import save_file + + +class DwsConv(nn.Module): + """Depthwise separable convolution.""" + + def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None: + super().__init__() + # Depthwise conv + self.dconv = nn.Conv2d( + in_channels, in_channels, kernel_size, groups=in_channels + ) + # Pointwise conv + self.pconv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=1) + + def forward(self, x: Tensor) -> Tensor: + x = self.dconv(x) + return self.pconv(x) + + +class Model(nn.Module): + def __init__(self, depthwise: bool = False) -> None: + super().__init__() + self.conv = DwsConv(2, 2, 3) if depthwise else nn.Conv2d(2, 2, 3) + + def forward(self, x: Tensor) -> Tensor: + return self.conv(x) + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "enum_depthwise_false.safetensors") + + input = torch.rand(1, 2, 5, 5) + + print("Depthwise is False") + print("Input shape: {}", input.shape) + print("Input: {}", input) + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + print("Depthwise is True") + model = Model(depthwise=True).to(torch.device("cpu")) + save_file(model.state_dict(), "enum_depthwise_true.safetensors") + + print("Input shape: {}", input.shape) + print("Input: {}", input) + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/enum_module/mod.rs b/crates/burn-import/safetensors-tests/tests/enum_module/mod.rs new file mode 100644 index 0000000000..5a5db0f6f2 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/enum_module/mod.rs @@ -0,0 +1,198 @@ +use burn::{ + module::Module, + nn::conv::{Conv2d, Conv2dConfig}, + tensor::{backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub enum Conv { + DwsConv(DwsConv), + Conv(Conv2d), +} + +#[derive(Module, Debug)] +pub struct DwsConv { + dconv: Conv2d, + pconv: Conv2d, +} + +#[derive(Module, Debug)] +pub struct Net { + conv: Conv, +} + +impl Net { + /// Create a new model from the given record. + pub fn new_with(record: NetRecord) -> Self { + let device = Default::default(); + + let conv = match record.conv { + ConvRecord::DwsConv(dws_conv) => { + let dconv = Conv2dConfig::new([2, 2], [3, 3]) + .with_groups(2) + .init(&device) + .load_record(dws_conv.dconv); + let pconv = Conv2dConfig::new([2, 2], [1, 1]) + .with_groups(1) + .init(&device) + .load_record(dws_conv.pconv); + Conv::DwsConv(DwsConv { dconv, pconv }) + } + ConvRecord::Conv(conv) => { + let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]); + Conv::Conv(conv2d_config.init(&device).load_record(conv)) + } + }; + Net { conv } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + match &self.conv { + Conv::DwsConv(dws_conv) => { + let x = dws_conv.dconv.forward(x); + dws_conv.pconv.forward(x) + } + Conv::Conv(conv) => conv.forward(x), + } + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, Recorder}; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + #[test] + fn depthwise_false() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load( + "tests/enum_module/enum_depthwise_false.safetensors".into(), + &device, + ) + .expect("Should decode state successfully"); + + let model = Net::::new_with(record); + let input = Tensor::::from_data( + [[ + [ + [0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4], + [0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235], + [0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317], + [0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845], + [ + 0.804_481_1, + 0.65517855, + 0.17679012, + 0.824_772_3, + 0.803_550_9, + ], + ], + [ + [0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874], + [0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7], + [0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537], + [ + 0.03694397, + 0.751_675_7, + 0.148_438_4, + 0.12274551, + 0.530_407_2, + ], + [0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4], + ], + ]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [ + [0.35449377, -0.02832414, 0.490_976_1], + [0.29709217, 0.332_586_3, 0.30594018], + [0.18101373, 0.30932188, 0.30558896], + ], + [ + [-0.17683622, -0.13244139, -0.05608707], + [0.23467252, -0.07038684, 0.255_044_1], + [-0.241_931_3, -0.20476191, -0.14468731], + ], + ]], + &device, + ); + + output.to_data().assert_approx_eq(&expected.to_data(), 7); + } + + #[test] + fn depthwise_true() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load( + "tests/enum_module/enum_depthwise_true.safetensors".into(), + &device, + ) + .expect("Should decode state successfully"); + + let model = Net::::new_with(record); + + let input = Tensor::::from_data( + [[ + [ + [0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4], + [0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235], + [0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317], + [0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845], + [ + 0.804_481_1, + 0.65517855, + 0.17679012, + 0.824_772_3, + 0.803_550_9, + ], + ], + [ + [0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874], + [0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7], + [0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537], + [ + 0.03694397, + 0.751_675_7, + 0.148_438_4, + 0.12274551, + 0.530_407_2, + ], + [0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4], + ], + ]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [ + [0.77874625, 0.859_017_6, 0.834_283_5], + [0.773_056_4, 0.73817325, 0.78292674], + [0.710_775_2, 0.747_187_2, 0.733_264_4], + ], + [ + [-0.44891885, -0.49027523, -0.394_170_7], + [-0.43836114, -0.33961445, -0.387_311_5], + [-0.581_134_3, -0.34197026, -0.535_035_7], + ], + ]], + &device, + ); + + output.to_data().assert_approx_eq(&expected.to_data(), 7); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/group_norm/export_weights.py b/crates/burn-import/safetensors-tests/tests/group_norm/export_weights.py new file mode 100755 index 0000000000..01a69dc1c4 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/group_norm/export_weights.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.norm1 = nn.GroupNorm(2, 6) + + def forward(self, x): + x = self.norm1(x) + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "group_norm.safetensors") + + x2 = torch.rand(1, 6, 2, 2) + print("Input shape: {}", x2.shape) + print("Input: {}", x2) + output = model(x2) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/group_norm/group_norm.safetensors b/crates/burn-import/safetensors-tests/tests/group_norm/group_norm.safetensors new file mode 100644 index 0000000000..f3d4e76524 Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/group_norm/group_norm.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/group_norm/mod.rs b/crates/burn-import/safetensors-tests/tests/group_norm/mod.rs new file mode 100644 index 0000000000..9c2e4f53ac --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/group_norm/mod.rs @@ -0,0 +1,88 @@ +use burn::{ + module::Module, + nn::{GroupNorm, GroupNormConfig}, + tensor::{backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + norm1: GroupNorm, +} + +impl Net { + /// Create a new model from the given record. + pub fn init(device: &B::Device) -> Self { + let norm1 = GroupNormConfig::new(2, 6).init(device); + Self { norm1 } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + self.norm1.forward(x) + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder}; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + fn group_norm(record: NetRecord, precision: usize) { + let device = Default::default(); + + let model = Net::::init(&device).load_record(record); + + let input = Tensor::::from_data( + [[ + [[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]], + [[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]], + [[0.569_508_5, 0.43877792], [0.63868046, 0.524_665_9]], + [[0.682_614_1, 0.305_149_5], [0.46354562, 0.45498633]], + [[0.572_472, 0.498_002_6], [0.93708336, 0.65559506]], + [[0.31379688, 0.19801933], [0.41619217, 0.28432965]], + ]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [[1.042_578_5, -1.122_016_7], [-0.56195974, 0.938_733_6]], + [[-2.253_500_7, 1.233_672_9], [-0.588_804_1, 1.027_827_3]], + [[0.19124532, -0.40036356], [0.504_276_5, -0.01168585]], + [[1.013_829_2, -0.891_984_6], [-0.09224463, -0.13546038]], + [[0.45772314, 0.08172822], [2.298_641_4, 0.877_410_4]], + [[-0.84832406, -1.432_883_4], [-0.331_331_5, -0.997_103_7]], + ]], + &device, + ); + + output + .to_data() + .assert_approx_eq(&expected.to_data(), precision); + } + + #[test] + fn group_norm_full() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/group_norm/group_norm.safetensors".into(), &device) + .expect("Should decode state successfully"); + + group_norm(record, 3); + } + + #[test] + fn group_norm_half() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/group_norm/group_norm.safetensors".into(), &device) + .expect("Should decode state successfully"); + + group_norm(record, 3); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/integer/export_weights.py b/crates/burn-import/safetensors-tests/tests/integer/export_weights.py new file mode 100755 index 0000000000..91ca706915 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/integer/export_weights.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + buffer = torch.tensor([1, 2, 3]) + self.register_buffer("buffer", buffer, persistent=True) + + def forward(self, x): + x = self.buffer + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "integer.safetensors") + + input = torch.ones(3, 3) + print("Input shape: {}", input.shape) + print("Input: {}", input) + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/integer/integer.safetensors b/crates/burn-import/safetensors-tests/tests/integer/integer.safetensors new file mode 100644 index 0000000000..9696860133 Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/integer/integer.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/integer/mod.rs b/crates/burn-import/safetensors-tests/tests/integer/mod.rs new file mode 100644 index 0000000000..010bd587a6 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/integer/mod.rs @@ -0,0 +1,69 @@ +use burn::{ + module::{Module, Param}, + tensor::{backend::Backend, Int, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + buffer: Param>, +} + +impl Net { + /// Create a new model from the given record. + pub fn new_with(record: NetRecord) -> Self { + Self { + buffer: record.buffer, + } + } + + /// Forward pass of the model. + pub fn forward(&self, _x: Tensor) -> Tensor { + self.buffer.val() + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + use burn::{ + record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder}, + tensor::TensorData, + }; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + fn integer(record: NetRecord, _precision: usize) { + let device = Default::default(); + + let model = Net::::new_with(record); + + let input = Tensor::::ones([3, 3], &device); + + let output = model.forward(input); + + let expected = Tensor::::from_data(TensorData::from([1, 2, 3]), &device); + + assert_eq!(output.to_data(), expected.to_data()); + } + + #[test] + fn integer_full_precision() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/integer/integer.safetensors".into(), &device) + .expect("Should decode state successfully"); + + integer(record, 0); + } + + #[test] + fn integer_half_precision() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/integer/integer.safetensors".into(), &device) + .expect("Should decode state successfully"); + + integer(record, 0); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/key_remap/export_weights.py b/crates/burn-import/safetensors-tests/tests/key_remap/export_weights.py new file mode 100755 index 0000000000..36264d20fe --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/key_remap/export_weights.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class ConvModule(nn.Module): + def __init__(self): + super(ConvModule, self).__init__() + self.conv1 = nn.Conv2d(2, 2, (2, 2)) + self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv = ConvModule() + + def forward(self, x): + x = self.conv(x) + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "key_remap.safetensors") + + input = torch.rand(1, 2, 5, 5) + print("Input shape: {}", input.shape) + print("Input: {}", input) + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/key_remap/key_remap.safetensors b/crates/burn-import/safetensors-tests/tests/key_remap/key_remap.safetensors new file mode 100644 index 0000000000..b8d253580e Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/key_remap/key_remap.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/key_remap/mod.rs b/crates/burn-import/safetensors-tests/tests/key_remap/mod.rs new file mode 100644 index 0000000000..b39c7e4c22 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/key_remap/mod.rs @@ -0,0 +1,117 @@ +use burn::{ + module::Module, + nn::conv::{Conv2d, Conv2dConfig}, + tensor::{backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + conv1: Conv2d, + conv2: Conv2d, +} + +impl Net { + /// Create a new model. + pub fn init(device: &B::Device) -> Self { + let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device); + let conv2 = Conv2dConfig::new([2, 2], [2, 2]) + .with_bias(false) + .init(device); + Self { conv1, conv2 } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + let x = self.conv1.forward(x); + + self.conv2.forward(x) + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, Recorder}; + use burn_import::safetensors::{LoadArgs, SafeTensorsFileRecorder}; + + use super::*; + + #[test] + fn key_remap() { + let device = Default::default(); + let load_args = LoadArgs::new("tests/key_remap/key_remap.safetensors".into()) + .with_key_remap("conv\\.(.*)", "$1") // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1" + .with_debug_print(); + + let record = SafeTensorsFileRecorder::::default() + .load(load_args, &device) + .expect("Should decode state successfully"); + + let model = Net::::init(&device).load_record(record); + + let input = Tensor::::from_data( + [[ + [ + [ + 0.024_595_8, + 0.25883394, + 0.93905586, + 0.416_715_5, + 0.713_979_7, + ], + [0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8], + [0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4], + [0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136], + [ + 0.99802476, + 0.900_794_2, + 0.476_588_2, + 0.16625845, + 0.804_481_1, + ], + ], + [ + [ + 0.65517855, + 0.17679012, + 0.824_772_3, + 0.803_550_9, + 0.943_447_5, + ], + [0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086], + [0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497], + [0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397], + [ + 0.751_675_7, + 0.148_438_4, + 0.12274551, + 0.530_407_2, + 0.414_796_4, + ], + ], + ]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [ + [-0.02502128, 0.00250649, 0.04841233], + [0.04589614, -0.00296854, 0.01991477], + [0.02920526, 0.059_497_3, 0.04326791], + ], + [ + [-0.04825336, 0.080_190_9, -0.02375088], + [0.02885434, 0.09638263, -0.07460806], + [0.02004079, 0.06244051, 0.035_887_1], + ], + ]], + &device, + ); + + output.to_data().assert_approx_eq(&expected.to_data(), 7); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/key_remap_chained/export_weights.py b/crates/burn-import/safetensors-tests/tests/key_remap_chained/export_weights.py new file mode 100755 index 0000000000..f5902cbe0c --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/key_remap_chained/export_weights.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + +import torch +from torch import nn, Tensor +from safetensors.torch import save_file + + +class ConvBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x: Tensor) -> Tensor: + return self.block(x) + + +class Model(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 6, 3, bias=False) + self.bn = nn.BatchNorm2d(6) + self.layer = nn.Sequential(ConvBlock(6, 6), ConvBlock(6, 6)) + + def forward(self, x: Tensor) -> Tensor: + x = self.conv(x) + x = self.bn(x) + x = self.layer(x) + + return x + + +def main(): + torch.set_printoptions(precision=8) + torch.manual_seed(42) + + model = Model() + + input = torch.rand(1, 3, 4, 4) + model(input) # condition batch norm + model.eval() + + with torch.no_grad(): + print(f"Input shape: {input.shape}") + print("Input type: {}", input.dtype) + print(f"Input: {input}") + output = model(input) + + print(f"Output: {output}") + print(f"Output Shape: {output.shape}") + + save_file(model.state_dict(), "key_remap_chained.safetensors") + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/key_remap_chained/key_remap_chained.safetensors b/crates/burn-import/safetensors-tests/tests/key_remap_chained/key_remap_chained.safetensors new file mode 100644 index 0000000000..c3cfe9d82d Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/key_remap_chained/key_remap_chained.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/key_remap_chained/mod.rs b/crates/burn-import/safetensors-tests/tests/key_remap_chained/mod.rs new file mode 100644 index 0000000000..7b3f9c5743 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/key_remap_chained/mod.rs @@ -0,0 +1,184 @@ +use std::marker::PhantomData; + +use burn::{ + module::Module, + nn::{ + conv::{Conv2d, Conv2dConfig}, + BatchNorm, BatchNormConfig, + }, + tensor::{backend::Backend, Device, Tensor}, +}; + +/// Some module that implements a specific method so it can be used in a sequential block. +pub trait ForwardModule { + fn forward(&self, input: Tensor) -> Tensor; +} + +/// Conv2d + BatchNorm block. +#[derive(Module, Debug)] +pub struct ConvBlock { + conv: Conv2d, + bn: BatchNorm, +} + +impl ForwardModule for ConvBlock { + fn forward(&self, input: Tensor) -> Tensor { + let out = self.conv.forward(input); + self.bn.forward(out) + } +} + +impl ConvBlock { + pub fn new(in_channels: usize, out_channels: usize, device: &Device) -> Self { + let conv = Conv2dConfig::new([in_channels, out_channels], [1, 1]) + .with_bias(false) + .init(device); + let bn = BatchNormConfig::new(out_channels).init(device); + + Self { conv, bn } + } +} + +/// Collection of sequential blocks. +#[derive(Module, Debug)] +pub struct ModuleBlock { + blocks: Vec, + _backend: PhantomData, +} + +impl> ModuleBlock { + pub fn forward(&self, input: Tensor) -> Tensor { + let mut out = input; + for block in &self.blocks { + out = block.forward(out); + } + out + } +} + +impl ModuleBlock> { + pub fn new(device: &Device) -> Self { + let blocks = vec![ConvBlock::new(6, 6, device), ConvBlock::new(6, 6, device)]; + + Self { + blocks, + _backend: PhantomData, + } + } +} + +#[derive(Module, Debug)] +pub struct Model { + conv: Conv2d, + bn: BatchNorm, + layer: ModuleBlock, +} + +impl Model> { + pub fn new(device: &Device) -> Self { + let conv = Conv2dConfig::new([3, 6], [3, 3]) + .with_bias(false) + .init(device); + let bn = BatchNormConfig::new(6).init(device); + + let layer = ModuleBlock::new(device); + + Self { conv, bn, layer } + } + + pub fn forward(&self, input: Tensor) -> Tensor { + let out = self.conv.forward(input); + let out = self.bn.forward(out); + self.layer.forward(out) + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, Recorder}; + use burn_import::safetensors::{LoadArgs, SafeTensorsFileRecorder}; + + use super::*; + + #[test] + #[should_panic] + fn key_remap_chained_missing_pattern() { + // Loading record should fail due to missing pattern to map the layer.blocks + let device = Default::default(); + let load_args = + LoadArgs::new("tests/key_remap_chained/key_remap_chained.safetensors".into()) + // Map *.block.0.* -> *.conv.* + .with_key_remap("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2") + // Map *.block.1.* -> *.bn.* + .with_key_remap("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2"); + + let record = SafeTensorsFileRecorder::::default() + .load(load_args, &device) + .expect("Should decode state successfully"); + + let model: Model = Model::new(&device); + + model.load_record(record); + } + + #[test] + fn key_remap_chained() { + let device = Default::default(); + let load_args = + LoadArgs::new("tests/key_remap_chained/key_remap_chained.safetensors".into()) + // Map *.block.0.* -> *.conv.* + .with_key_remap("(.+)\\.block\\.0\\.(.+)", "$1.conv.$2") + // Map *.block.1.* -> *.bn.* + .with_key_remap("(.+)\\.block\\.1\\.(.+)", "$1.bn.$2") + // Map layer.[i].* -> layer.blocks.[i].* + .with_key_remap("layer\\.([0-9])\\.(.+)", "layer.blocks.$1.$2"); + + let record = SafeTensorsFileRecorder::::default() + .load(load_args, &device) + .expect("Should decode state successfully"); + + let model: Model = Model::new(&device); + + let model = model.load_record(record); + + let input = Tensor::::from_data( + [[ + [ + [0.76193494, 0.626_546_1, 0.49510366, 0.11974698], + [0.07161391, 0.03232569, 0.704_681, 0.254_516], + [0.399_373_7, 0.21224737, 0.40888822, 0.14808255], + [0.17329216, 0.665_855_4, 0.351_401_8, 0.808_671_6], + ], + [ + [0.33959562, 0.13321638, 0.41178054, 0.257_626_3], + [0.347_029_2, 0.02400219, 0.77974546, 0.15189773], + [0.75130886, 0.726_892_1, 0.85721636, 0.11647397], + [0.859_598_4, 0.263_624_2, 0.685_534_6, 0.96955734], + ], + [ + [0.42948407, 0.49613327, 0.38488472, 0.08250773], + [0.73995143, 0.00364107, 0.81039995, 0.87411255], + [0.972_853_2, 0.38206023, 0.08917904, 0.61241513], + [0.77621365, 0.00234562, 0.38650817, 0.20027226], + ], + ]], + &device, + ); + let expected = Tensor::::from_data( + [[ + [[0.198_967_1, 0.17847246], [0.06883702, 0.20012866]], + [[0.17582723, 0.11344293], [0.05444185, 0.13307181]], + [[0.192_229_5, 0.20391327], [0.06150475, 0.22688155]], + [[0.00230906, -0.02177845], [0.01129148, 0.00925517]], + [[0.14751078, 0.14433631], [0.05498439, 0.29049855]], + [[0.16868964, 0.133_269_3], [0.06917118, 0.35094324]], + ]], + &device, + ); + + let output = model.forward(input); + output.to_data().assert_approx_eq(&expected.to_data(), 7); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/layer_norm/export_weights.py b/crates/burn-import/safetensors-tests/tests/layer_norm/export_weights.py new file mode 100755 index 0000000000..d27ee1d7b9 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/layer_norm/export_weights.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.norm1 = nn.LayerNorm(2) + + def forward(self, x): + x = self.norm1(x) + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "layer_norm.safetensors") + + x2 = torch.rand(1, 2, 2, 2) + print("Input shape: {}", x2.shape) + print("Input: {}", x2) + output = model(x2) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/layer_norm/layer_norm.safetensors b/crates/burn-import/safetensors-tests/tests/layer_norm/layer_norm.safetensors new file mode 100644 index 0000000000..d89760a17a Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/layer_norm/layer_norm.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/layer_norm/mod.rs b/crates/burn-import/safetensors-tests/tests/layer_norm/mod.rs new file mode 100644 index 0000000000..35df8d8708 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/layer_norm/mod.rs @@ -0,0 +1,79 @@ +use burn::{ + module::Module, + nn::{LayerNorm, LayerNormConfig}, + tensor::{backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + norm1: LayerNorm, +} + +impl Net { + /// Create a new model. + pub fn init(device: &B::Device) -> Self { + let norm1 = LayerNormConfig::new(4).init(device); + Self { norm1 } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + self.norm1.forward(x) + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder}; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + fn layer_norm(record: NetRecord, precision: usize) { + let device = Default::default(); + + let model = Net::::init(&device).load_record(record); + + let input = Tensor::::from_data( + [[ + [[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]], + [[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]], + ]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [[0.99991274, -0.999_912_5], [-0.999_818_3, 0.999_818_3]], + [[-0.999_966_2, 0.99996626], [-0.99984336, 0.99984336]], + ]], + &device, + ); + + output + .to_data() + .assert_approx_eq(&expected.to_data(), precision); + } + + #[test] + fn layer_norm_full() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/layer_norm/layer_norm.safetensors".into(), &device) + .expect("Should decode state successfully"); + layer_norm(record, 3); + } + + #[test] + fn layer_norm_half() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/layer_norm/layer_norm.safetensors".into(), &device) + .expect("Should decode state successfully"); + layer_norm(record, 3); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/linear/export_weights.py b/crates/burn-import/safetensors-tests/tests/linear/export_weights.py new file mode 100755 index 0000000000..e53df2190b --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/linear/export_weights.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.fc1 = nn.Linear(2, 3) + self.fc2 = nn.Linear(3, 4, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) # Add relu so that PyTorch optimizer does not combine fc1 and fc2 + x = self.fc2(x) + + return x + + +class ModelWithBias(nn.Module): + def __init__(self): + super(ModelWithBias, self).__init__() + self.fc1 = nn.Linear(2, 3) + + def forward(self, x): + x = self.fc1(x) + + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + model_with_bias = ModelWithBias().to(torch.device("cpu")) + + save_file(model.state_dict(), "linear.safetensors") + save_file(model_with_bias.state_dict(), "linear_with_bias.safetensors") + + input = torch.rand(1, 2, 2, 2) + print("Input shape: {}", input.shape) + print("Input: {}", input) + + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + print("Model with bias") + output = model_with_bias(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/linear/linear.safetensors b/crates/burn-import/safetensors-tests/tests/linear/linear.safetensors new file mode 100644 index 0000000000..2be695e4e9 Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/linear/linear.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/linear/linear_with_bias.safetensors b/crates/burn-import/safetensors-tests/tests/linear/linear_with_bias.safetensors new file mode 100644 index 0000000000..f5075a4ce3 Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/linear/linear_with_bias.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/linear/mod.rs b/crates/burn-import/safetensors-tests/tests/linear/mod.rs new file mode 100644 index 0000000000..fb8ca7fc65 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/linear/mod.rs @@ -0,0 +1,149 @@ +use burn::{ + module::Module, + nn::{Linear, LinearConfig, Relu}, + tensor::{backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + fc1: Linear, + fc2: Linear, + relu: Relu, +} + +impl Net { + /// Create a new model. + pub fn init(device: &B::Device) -> Self { + let fc1 = LinearConfig::new(2, 3).init(device); + let fc2 = LinearConfig::new(3, 4).init(device); + let relu = Relu; + + Self { fc1, fc2, relu } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + let x = self.fc1.forward(x); + let x = self.relu.forward(x); + + self.fc2.forward(x) + } +} + +#[derive(Module, Debug)] +struct NetWithBias { + fc1: Linear, +} + +impl NetWithBias { + /// Create a new model. + pub fn init(device: &B::Device) -> Self { + let fc1 = LinearConfig::new(2, 3).init(device); + + Self { fc1 } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + self.fc1.forward(x) + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, HalfPrecisionSettings, Recorder}; + + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + fn linear_test(record: NetRecord, precision: usize) { + let device = Default::default(); + let model = Net::::init(&device).load_record(record); + + let input = Tensor::::from_data( + [[ + [[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]], + [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]], + ]], + &device, + ); + + let output = model.forward(input); + let expected = Tensor::::from_data( + [[ + [ + [0.09778349, -0.13756673, 0.04962806, 0.08856435], + [0.03163241, -0.02848549, 0.01437942, 0.11905234], + ], + [ + [0.07628226, -0.10757702, 0.03656857, 0.03824598], + [0.05443089, -0.06904714, 0.02744314, 0.09997337], + ], + ]], + &device, + ); + output + .to_data() + .assert_approx_eq(&expected.to_data(), precision); + } + + #[test] + fn linear_full_precision() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/linear/linear.safetensors".into(), &device) + .expect("Should decode state successfully"); + + linear_test(record, 7); + } + + #[test] + fn linear_half_precision() { + let device = Default::default(); + let record = SafeTensorsFileRecorder::::default() + .load("tests/linear/linear.safetensors".into(), &device) + .expect("Should decode state successfully"); + + linear_test(record, 4); + } + + #[test] + fn linear_with_bias() { + let device = Default::default(); + + let record = SafeTensorsFileRecorder::::default() + .load("tests/linear/linear_with_bias.safetensors".into(), &device) + .expect("Should decode state successfully"); + + let model = NetWithBias::::init(&device).load_record(record); + + let input = Tensor::::from_data( + [[ + [[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]], + [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]], + ]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [ + [-0.00432095, -1.107_101_2, 0.870_691_4], + [0.024_595_5, -0.954_462_9, 0.48518157], + ], + [ + [0.34315687, -0.757_384_2, 0.548_288], + [-0.06608963, -1.072_072_7, 0.645_800_5], + ], + ]], + &device, + ); + + output.to_data().assert_approx_eq(&expected.to_data(), 6); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/missing_module_field/export_weights.py b/crates/burn-import/safetensors-tests/tests/missing_module_field/export_weights.py new file mode 100755 index 0000000000..d4c6fa81e0 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/missing_module_field/export_weights.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.conv1 = nn.Conv2d(2, 2, (2, 2)) + + def forward(self, x): + x = self.conv1(x) + return x + + +def main(): + torch.set_printoptions(precision=8) + torch.manual_seed(1) + model = Model().to(torch.device("cpu")) + save_file(model.state_dict(), "missing_module_field.safetensors") + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/missing_module_field/missing_module_field.safetensors b/crates/burn-import/safetensors-tests/tests/missing_module_field/missing_module_field.safetensors new file mode 100644 index 0000000000..d061ac4100 Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/missing_module_field/missing_module_field.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/missing_module_field/mod.rs b/crates/burn-import/safetensors-tests/tests/missing_module_field/mod.rs new file mode 100644 index 0000000000..5b983a7b12 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/missing_module_field/mod.rs @@ -0,0 +1,31 @@ +use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend}; + +#[derive(Module, Debug)] +pub struct Net { + do_not_exist_in_pt: Conv2d, +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, Recorder}; + use burn_import::safetensors::SafeTensorsFileRecorder; + + use super::*; + + #[test] + #[should_panic( + expected = "Missing source values for the 'do_not_exist_in_pt' field of type 'Conv2dRecordItem'. Please verify the source data and ensure the field name is correct" + )] + fn should_fail_if_struct_field_is_missing() { + let device = Default::default(); + let _record: NetRecord = + SafeTensorsFileRecorder::::default() + .load( + "tests/missing_module_field/missing_module_field.safetensors".into(), + &device, + ) + .expect("Should decode state successfully"); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/non_contiguous_indexes/export_weights.py b/crates/burn-import/safetensors-tests/tests/non_contiguous_indexes/export_weights.py new file mode 100755 index 0000000000..cb512ebab1 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/non_contiguous_indexes/export_weights.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from safetensors.torch import save_file + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + num_layers = 5 # Number of repeated convolutional layers + + # Create a list to store the layers + layers = [] + for _ in range(num_layers): + layers.append(nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=True)) + layers.append(nn.ReLU(inplace=True)) + + # Use nn.Sequential to create a single module from the layers + self.fc = nn.Sequential(*layers) + + def forward(self, x): + x = self.fc(x) + return x + + +def main(): + + torch.set_printoptions(precision=8) + torch.manual_seed(1) + + model = Model().to(torch.device("cpu")) + + save_file(model.state_dict(), "non_contiguous_indexes.safetensors") + + input = torch.rand(1, 2, 5, 5) + print("Input shape: {}", input.shape) + print("Input: {}", input) + output = model(input) + print("Output: {}", output) + print("Output Shape: {}", output.shape) + + +if __name__ == "__main__": + main() diff --git a/crates/burn-import/safetensors-tests/tests/non_contiguous_indexes/mod.rs b/crates/burn-import/safetensors-tests/tests/non_contiguous_indexes/mod.rs new file mode 100644 index 0000000000..517cf126b9 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/non_contiguous_indexes/mod.rs @@ -0,0 +1,105 @@ +use burn::{ + module::Module, + nn::{ + conv::{Conv2d, Conv2dConfig}, + PaddingConfig2d, + }, + tensor::{activation::relu, backend::Backend, Tensor}, +}; + +#[derive(Module, Debug)] +pub struct Net { + fc: Vec>, +} + +impl Net { + /// Create a new model from the given record. + pub fn new_with(record: NetRecord) -> Self { + let device = Default::default(); + let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]).with_padding(PaddingConfig2d::Same); + let mut fc = vec![]; + for fc_record in record.fc.into_iter() { + fc.push(conv2d_config.init(&device).load_record(fc_record)); + } + Net { fc } + } + + /// Forward pass of the model. + pub fn forward(&self, x: Tensor) -> Tensor { + self.fc.iter().fold(x, |x_i, conv| relu(conv.forward(x_i))) + } +} + +#[cfg(test)] +mod tests { + type Backend = burn_ndarray::NdArray; + + use burn::record::{FullPrecisionSettings, Recorder}; + use burn_import::safetensors::{LoadArgs, SafeTensorsFileRecorder}; + + use super::*; + + #[test] + fn key_remap() { + let device = Default::default(); + let load_args = + LoadArgs::new("tests/non_contiguous_indexes/non_contiguous_indexes.safetensors".into()) + .with_debug_print(); + + let record = SafeTensorsFileRecorder::::default() + .load(load_args, &device) + .expect("Should decode state successfully"); + + let model = Net::::new_with(record); + + let input = Tensor::::from_data( + [[ + [ + [ + 0.67890584, + 0.307_537_2, + 0.265_156_2, + 0.528_318_8, + 0.86194897, + ], + [0.14828813, 0.73480314, 0.821_220_7, 0.989_098_6, 0.15003455], + [0.62109494, 0.13028657, 0.926_875_1, 0.30604684, 0.80117637], + [0.514_885_7, 0.46105868, 0.484_046_1, 0.58499724, 0.73569804], + [0.58018994, 0.65252745, 0.05023766, 0.864_268_7, 0.935_932], + ], + [ + [0.913_302_9, 0.869_611_3, 0.139_184_3, 0.314_65, 0.94086266], + [0.11917073, 0.953_610_6, 0.10675198, 0.14779574, 0.744_439], + [0.14075547, 0.38544965, 0.863_745_9, 0.89604443, 0.97287786], + [0.39854127, 0.11136961, 0.99230546, 0.39348692, 0.29428244], + [0.621_886_9, 0.15033776, 0.828_640_1, 0.81336635, 0.10325938], + ], + ]], + &device, + ); + + let output = model.forward(input); + + let expected = Tensor::::from_data( + [[ + [ + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], + [0.04485746, 0.03582812, 0.03432692, 0.02892298, 0.013_844_3], + ], + [ + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], + [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000], + ], + ]], + &device, + ); + + output.to_data().assert_approx_eq(&expected.to_data(), 7); + } +} diff --git a/crates/burn-import/safetensors-tests/tests/non_contiguous_indexes/non_contiguous_indexes.safetensors b/crates/burn-import/safetensors-tests/tests/non_contiguous_indexes/non_contiguous_indexes.safetensors new file mode 100644 index 0000000000..6c30dac506 Binary files /dev/null and b/crates/burn-import/safetensors-tests/tests/non_contiguous_indexes/non_contiguous_indexes.safetensors differ diff --git a/crates/burn-import/safetensors-tests/tests/test_mod.rs b/crates/burn-import/safetensors-tests/tests/test_mod.rs new file mode 100644 index 0000000000..ec2f21b7a1 --- /dev/null +++ b/crates/burn-import/safetensors-tests/tests/test_mod.rs @@ -0,0 +1,18 @@ +mod batch_norm; +mod boolean; +mod buffer; +mod complex_nested; +mod conv1d; +mod conv2d; +mod conv_transpose1d; +mod conv_transpose2d; +mod embedding; +mod enum_module; +mod group_norm; +mod integer; +mod key_remap; +mod key_remap_chained; +mod layer_norm; +mod linear; +mod missing_module_field; +mod non_contiguous_indexes; diff --git a/crates/burn-import/src/lib.rs b/crates/burn-import/src/lib.rs index 362cf4b0c6..b9a84cb3d3 100644 --- a/crates/burn-import/src/lib.rs +++ b/crates/burn-import/src/lib.rs @@ -30,5 +30,9 @@ pub mod burn; #[cfg(feature = "pytorch")] pub mod pytorch; +/// The Safetensors module for recorder. +#[cfg(feature = "safetensors")] +pub mod safetensors; + mod formatter; pub use formatter::*; diff --git a/crates/burn-import/src/safetensors/adapter.rs b/crates/burn-import/src/safetensors/adapter.rs new file mode 100644 index 0000000000..9a3de85dca --- /dev/null +++ b/crates/burn-import/src/safetensors/adapter.rs @@ -0,0 +1,104 @@ +use burn::{ + module::Param, + record::{PrecisionSettings, Record}, + tensor::{backend::Backend, Tensor}, +}; + +use burn::record::serde::{ + adapter::{BurnModuleAdapter, DefaultAdapter}, + data::NestedValue, + ser::Serializer, +}; + +use serde::Serialize; + +/// A PyTorch adapter for the Burn module used during deserialization. +/// +/// Not all Burn module correspond to a PyTorch module. Therefore, +/// we need to adapt the Burn module to a PyTorch module. We implement +/// only those that differ. +pub struct PyTorchAdapter { + _precision_settings: std::marker::PhantomData<(PS, B)>, +} + +impl BurnModuleAdapter for PyTorchAdapter { + fn adapt_linear(data: NestedValue) -> NestedValue { + // Get the current module in the form of map. + let mut map = data.as_map().expect("Failed to get map from NestedValue"); + + // Get/remove the weight parameter. + let weight = map + .remove("weight") + .expect("Failed to find 'weight' key in map"); + + // Convert the weight parameter to a tensor (use default device, since it's quick operation). + let weight: Param> = weight + .try_into_record::<_, PS, DefaultAdapter, B>(&B::Device::default()) + .expect("Failed to deserialize weight"); + + // Do not capture transpose op when using autodiff backend + let weight = weight.set_require_grad(false); + // Transpose the weight tensor. + let weight_transposed = Param::from_tensor(weight.val().transpose()); + + // Insert the transposed weight tensor back into the map. + map.insert( + "weight".to_owned(), + serialize::(weight_transposed), + ); + + // Return the modified map. + NestedValue::Map(map) + } + + fn adapt_group_norm(data: NestedValue) -> NestedValue { + rename_weight_bias(data) + } + + fn adapt_batch_norm(data: NestedValue) -> NestedValue { + rename_weight_bias(data) + } + + fn adapt_layer_norm(data: NestedValue) -> NestedValue { + rename_weight_bias(data) + } +} + +/// Helper function to serialize a param tensor. +fn serialize(val: Param>) -> NestedValue +where + B: Backend, + PS: PrecisionSettings, +{ + let serializer = Serializer::new(); + + val.into_item::() + .serialize(serializer) + .expect("Failed to serialize the item") +} + +/// Helper function to rename the weight and bias parameters to gamma and beta. +/// +/// This is needed because PyTorch uses different names for the normalizer parameter +/// than Burn. Burn uses gamma and beta, while PyTorch uses weight and bias. +fn rename_weight_bias(data: NestedValue) -> NestedValue { + // Get the current module in the form of map. + let mut map = data.as_map().expect("Failed to get map from NestedValue"); + + // Rename the weight parameter to gamma. + let weight = map + .remove("weight") + .expect("Failed to find 'weight' key in map"); + + map.insert("gamma".to_owned(), weight); + + // Rename the bias parameter to beta. + let bias = map + .remove("bias") + .expect("Failed to find 'bias' key in map"); + + map.insert("beta".to_owned(), bias); + + // Return the modified map. + NestedValue::Map(map) +} diff --git a/crates/burn-import/src/safetensors/config.rs b/crates/burn-import/src/safetensors/config.rs new file mode 100644 index 0000000000..c766faae13 --- /dev/null +++ b/crates/burn-import/src/safetensors/config.rs @@ -0,0 +1,105 @@ +use std::collections::HashMap; +use std::fs::File; +use std::io::BufReader; +use std::path::Path; + +use super::error::Error; + +use burn::record::serde::{adapter::DefaultAdapter, data::NestedValue, de::Deserializer}; +use candle_core::pickle::{Object, Stack}; +use serde::de::DeserializeOwned; +use zip::ZipArchive; + +/// Extracts data from a `.safetensors` file, specifically looking for "data.pkl". +/// +/// # Arguments +/// * `file_path` - The path to the `.safetensors` file. +/// * `key` - Optional key to retrieve specific data from the pth file. +/// +/// # Returns +/// +/// The nested value that can be deserialized into a specific type. +fn read_st_info>(file_path: P) -> Result { + let mut zip = ZipArchive::new(BufReader::new(File::open(file_path)?))?; + + // We cannot use `zip.by_name` here because we need to find data.pkl in a sub-directory. + let data_pkl_path = (0..zip.len()).find_map(|i| { + let file = zip.by_index(i).ok()?; // Use ok() to convert Result to Option + if file.name().ends_with("data.pkl") { + Some(file.name().to_string()) + } else { + None + } + }); + + let data_pkl_path = + data_pkl_path.ok_or_else(|| Error::Other("data.pkl not found in archive".to_string()))?; + + let reader = zip.by_name(&data_pkl_path)?; + let mut reader = BufReader::new(reader); + let mut stack = Stack::empty(); + stack.read_loop(&mut reader)?; + let obj = stack.finalize()?; + + // Convert the PyTorch object to a nested value recursively + to_nested_value(obj) +} + +/// Convert a PyTorch object to a nested value recursively. +/// +/// # Arguments +/// * `obj` - The PyTorch object to convert. +/// +/// # Returns +/// The nested value. +fn to_nested_value(obj: Object) -> Result { + match obj { + Object::Bool(v) => Ok(NestedValue::Bool(v)), + Object::Int(v) => Ok(NestedValue::I32(v)), + Object::Float(v) => Ok(NestedValue::F64(v)), + Object::Unicode(v) => Ok(NestedValue::String(v)), + Object::List(v) => { + let list = v + .into_iter() + .map(to_nested_value) + .collect::, _>>()?; + Ok(NestedValue::Vec(list)) + } + Object::Dict(key_values) => { + let map = key_values + .into_iter() + .filter_map(|(name, value)| { + if let Object::Unicode(name) = name { + let nested_value = to_nested_value(value).ok()?; + Some((name, nested_value)) + } else { + None // Skip non-unicode names + } + }) + .collect::>(); + Ok(NestedValue::Map(map)) + } + _ => Err(Error::Other("Unsupported value type".into())), + } +} + +/// Deserialize config values from a `.safetensors` file. +/// +/// # Arguments +/// +/// * `path` - The path to the `.safetensors` file. +pub fn config_from_file(path: P) -> Result +where + D: DeserializeOwned, + P: AsRef, +{ + // Read the nested value from the file + let nested_value = read_st_info(path)?; + + // Create a deserializer with PyTorch adapter and nested value + let deserializer = Deserializer::::new(nested_value, true); + + // Deserialize the nested value into a target type + let value = D::deserialize(deserializer)?; + Ok(value) +} diff --git a/crates/burn-import/src/safetensors/error.rs b/crates/burn-import/src/safetensors/error.rs new file mode 100644 index 0000000000..c3568453e0 --- /dev/null +++ b/crates/burn-import/src/safetensors/error.rs @@ -0,0 +1,28 @@ +use burn::record::{serde::error, RecorderError}; +use zip::result::ZipError; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Serde error: {0}")] + Serde(#[from] error::Error), + + #[error("Candle SafeTensors error: {0}")] + CandleSafeTensors(#[from] candle_core::Error), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("Zip error: {0}")] + Zip(#[from] ZipError), + + // Add other kinds of errors as needed + #[error("other error: {0}")] + Other(String), +} + +// Implement From trait for Error to RecorderError +impl From for RecorderError { + fn from(error: Error) -> Self { + RecorderError::DeserializeError(error.to_string()) + } +} diff --git a/crates/burn-import/src/safetensors/mod.rs b/crates/burn-import/src/safetensors/mod.rs new file mode 100644 index 0000000000..50da7bdbca --- /dev/null +++ b/crates/burn-import/src/safetensors/mod.rs @@ -0,0 +1,7 @@ +mod adapter; +mod config; +mod error; +mod reader; +mod recorder; +pub use config::config_from_file; +pub use recorder::{LoadArgs, SafeTensorsFileRecorder}; diff --git a/crates/burn-import/src/safetensors/reader.rs b/crates/burn-import/src/safetensors/reader.rs new file mode 100644 index 0000000000..c3adc444c3 --- /dev/null +++ b/crates/burn-import/src/safetensors/reader.rs @@ -0,0 +1,171 @@ +use core::ops::Deref; +use std::collections::HashMap; +use std::path::Path; + +use super::{adapter::PyTorchAdapter, error::Error}; + +use burn::{ + module::ParamId, + record::PrecisionSettings, + tensor::{Element, ElementConversion, TensorData}, +}; +use burn::{ + record::serde::{ + data::{remap, unflatten, NestedValue, Serializable}, + de::Deserializer, + error, + ser::Serializer, + }, + tensor::backend::Backend, +}; + +use candle_core::{safetensors, Device, WithDType}; +use half::{bf16, f16}; +use regex::Regex; +use serde::{de::DeserializeOwned, Serialize}; + +/// Deserializes a PyTorch file. +/// +/// # Arguments +/// +/// * `path` - A string slice that holds the path of the file to read. +/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string. +/// * `top_level_key` - An optional top-level key to load state_dict from a dictionary. +pub fn from_file( + path: &Path, + key_remap: Vec<(Regex, String)>, + debug: bool, +) -> Result +where + D: DeserializeOwned, + PS: PrecisionSettings, + B: Backend, +{ + // Read the safetensors file and return a vector of Candle tensors + let tensors: HashMap = safetensors::load(path, &Device::Cpu)? + .into_iter() + .map(|(key, tensor)| (key, CandleTensor(tensor))) + .collect(); + + // Remap the keys (replace the keys in the map with the new keys) + let (tensors, remapped_keys) = remap(tensors, key_remap); + + // Print the remapped keys if debug is enabled + if debug { + let mut remapped_keys = remapped_keys; + remapped_keys.sort(); + println!("Debug information of keys and tensor shapes:\n---"); + for (new_key, old_key) in remapped_keys { + if old_key != new_key { + println!("Original Key: {old_key}"); + println!("Remapped Key: {new_key}"); + } else { + println!("Key: {}", new_key); + } + + let shape = tensors[&new_key].shape(); + let dtype = tensors[&new_key].dtype(); + println!("Shape: {shape:?}"); + println!("Dtype: {dtype:?}"); + println!("---"); + } + } + + // Convert the vector of Candle tensors to a nested value data structure + let nested_value = unflatten::(tensors)?; + + // Create a deserializer with PyTorch adapter and nested value + let deserializer = Deserializer::>::new(nested_value, true); + + // Deserialize the nested value into a record type + let value = D::deserialize(deserializer)?; + Ok(value) +} + +/// Serializes a candle tensor. +/// +/// Tensors are wrapped in a `Param` struct (learnable parameters) and serialized as a `TensorData` struct. +/// +/// Values are serialized as `FloatElem` or `IntElem` depending on the precision settings. +impl Serializable for CandleTensor { + fn serialize(&self, serializer: Serializer) -> Result + where + PS: PrecisionSettings, + { + let shape = self.shape().clone().into_dims(); + let flatten = CandleTensor(self.flatten_all().expect("Failed to flatten the tensor")); + let param_id = ParamId::new(); + + match self.dtype() { + candle_core::DType::U8 => { + serialize_data::(flatten, shape, param_id, serializer) + } + candle_core::DType::U32 => { + serialize_data::(flatten, shape, param_id, serializer) + } + candle_core::DType::I64 => { + serialize_data::(flatten, shape, param_id, serializer) + } + candle_core::DType::BF16 => { + serialize_data::(flatten, shape, param_id, serializer) + } + candle_core::DType::F16 => { + serialize_data::(flatten, shape, param_id, serializer) + } + candle_core::DType::F32 => { + serialize_data::(flatten, shape, param_id, serializer) + } + candle_core::DType::F64 => { + serialize_data::(flatten, shape, param_id, serializer) + } + } + } +} + +/// Helper function to serialize a candle tensor data. +fn serialize_data( + tensor: CandleTensor, + shape: Vec, + param_id: ParamId, + serializer: Serializer, +) -> Result +where + E: Element + Serialize, + T: WithDType + ElementConversion, +{ + let data: Vec = tensor + .to_vec1::() + .map_err(|err| error::Error::Other(format!("Candle to vec1 error: {err}")))? + .into_iter() + .map(ElementConversion::elem) + .collect(); + + let data = TensorData::new(data, shape.clone()); + let (dtype, bytes) = (data.dtype, data.into_bytes()); + + // Manually serialize the tensor instead of using the `ParamSerde` struct, such as: + // ParamSerde::new(param_id, TensorData::new(data, shape)).serialize(serializer) + // Because serializer copies individual elements of TensorData `value` into a new Vec, + // which is not necessary and inefficient. + let mut tensor_data: HashMap = HashMap::new(); + tensor_data.insert("bytes".into(), NestedValue::Bytes(bytes)); + tensor_data.insert("shape".into(), shape.serialize(serializer.clone())?); + tensor_data.insert("dtype".into(), dtype.serialize(serializer)?); + + let mut param: HashMap = HashMap::new(); + param.insert("id".into(), NestedValue::String(param_id.serialize())); + param.insert("param".into(), NestedValue::Map(tensor_data)); + + Ok(NestedValue::Map(param)) +} + +/// New type struct for Candle tensors because we need to implement the `Serializable` trait for it. +struct CandleTensor(candle_core::Tensor); + +impl Deref for CandleTensor { + type Target = candle_core::Tensor; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/crates/burn-import/src/safetensors/recorder.rs b/crates/burn-import/src/safetensors/recorder.rs new file mode 100644 index 0000000000..e18d7f9631 --- /dev/null +++ b/crates/burn-import/src/safetensors/recorder.rs @@ -0,0 +1,146 @@ +use core::marker::PhantomData; +use std::path::PathBuf; + +use burn::{ + record::{PrecisionSettings, Record, Recorder, RecorderError}, + tensor::backend::Backend, +}; + +use regex::Regex; +use serde::{de::DeserializeOwned, Serialize}; + +use super::reader::from_file; + +/// A recorder that loads HuggingFace SafeTensors files (`.safetensors`) into Burn modules. +/// +/// LoadArgs can be used to remap keys or file path. +/// See [LoadArgs](struct.LoadArgs.html) for more information. +/// +#[derive(new, Debug, Default, Clone)] +pub struct SafeTensorsFileRecorder { + _settings: PhantomData, +} + +impl Recorder for SafeTensorsFileRecorder { + type Settings = PS; + type RecordArgs = PathBuf; + type RecordOutput = (); + type LoadArgs = LoadArgs; + + fn save_item( + &self, + _item: I, + _file: Self::RecordArgs, + ) -> Result<(), RecorderError> { + unimplemented!("save_item not implemented for SafeTensorsFileRecorder") + } + + fn load_item(&self, _file: Self::LoadArgs) -> Result { + unimplemented!("load_item not implemented for SafeTensorsFileRecorder") + } + + fn load>( + &self, + args: Self::LoadArgs, + device: &B::Device, + ) -> Result { + let item = + from_file::, B>(&args.file, args.key_remap, args.debug)?; + Ok(R::from_item(item, device)) + } +} + +/// Arguments for loading a SafeTensors file. +/// +/// # Fields +/// +/// * `file` - The path to the file to load. +/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string. +/// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) +/// for more information. +/// +/// # Notes +/// +/// +/// +/// # Examples +/// +/// ```text +/// use burn_import::pytorch::{LoadArgs, SafeTensorsFileRecorder")}; +/// use burn::record::FullPrecisionSettings; +/// use burn::record::Recorder; +/// +/// let args = LoadArgs::new("tests/key_remap/key_remap.pt".into()) +/// .with_key_remap("conv\\.(.*)", "$1"); // // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1" +/// +/// let record = SafeTensorsFileRecorder")::::default() +/// .load(args) +/// .expect("Should decode state successfully"); +/// ``` +#[derive(Debug, Clone)] +pub struct LoadArgs { + /// The path to the file to load. + pub file: PathBuf, + + /// A list of key remappings. + pub key_remap: Vec<(Regex, String)>, + + /// Whether to print debug information. + pub debug: bool, +} + +impl LoadArgs { + /// Creates a new `LoadArgs` instance. + /// + /// # Arguments + /// + /// * `file` - The path to the file to load. + pub fn new(file: PathBuf) -> Self { + Self { + file, + key_remap: Vec::new(), + debug: false, + } + } + + /// Sets key remapping. + /// + /// # Arguments + /// + /// * `pattern` - The Regex pattern to be replaced. + /// * `replacement` - The pattern to replace with. + /// + /// See [Regex](https://docs.rs/regex/1.5.4/regex/#syntax) for the pattern syntax and + /// [Replacement](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) for the + /// replacement syntax. + pub fn with_key_remap(mut self, pattern: &str, replacement: &str) -> Self { + let regex = Regex::new(pattern).expect("Valid regex"); + + self.key_remap.push((regex, replacement.into())); + self + } + + /// Sets printing debug information on. + pub fn with_debug_print(mut self) -> Self { + self.debug = true; + self + } +} + +impl From for LoadArgs { + fn from(val: PathBuf) -> Self { + LoadArgs::new(val) + } +} + +impl From for LoadArgs { + fn from(val: String) -> Self { + LoadArgs::new(val.into()) + } +} + +impl From<&str> for LoadArgs { + fn from(val: &str) -> Self { + LoadArgs::new(val.into()) + } +} diff --git a/examples/pytorch-import/pytorch/mnist.py b/examples/pytorch-import/pytorch/mnist.py index 08e36825fc..38f22f6f05 100755 --- a/examples/pytorch-import/pytorch/mnist.py +++ b/examples/pytorch-import/pytorch/mnist.py @@ -5,6 +5,7 @@ from __future__ import print_function import argparse +from safetensors.torch import save_file import torch import torch.nn as nn import torch.nn.functional as F @@ -53,9 +54,15 @@ def train(args, model, device, train_loader, optimizer, epoch): loss.backward() optimizer.step() if batch_idx % args.log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss.item())) + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) if args.dry_run: break @@ -69,45 +76,98 @@ def test(model, device, test_loader): data, target = data.to(device), target.to(device) output = model(data) # sum up batch loss - test_loss += F.nll_loss(output, target, reduction='sum').item() + test_loss += F.nll_loss(output, target, reduction="sum").item() # get the index of the max log-probability pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, + correct, + len(test_loader.dataset), + 100.0 * correct / len(test_loader.dataset), + ) + ) def main(): # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=8, metavar='N', - help='number of epochs to train (default: 14)') - parser.add_argument('--lr', type=float, default=1.0, metavar='LR', - help='learning rate (default: 1.0)') - parser.add_argument('--gamma', type=float, default=0.7, metavar='M', - help='Learning rate step gamma (default: 0.7)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--no-mps', action='store_true', default=False, - help='disables macOS GPU training') - parser.add_argument('--dry-run', action='store_true', default=False, - help='quickly check a single pass') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') - parser.add_argument('--save-model', action='store_true', default=True, - help='For Saving the current Model') - parser.add_argument('--export-onnx', action='store_true', default=False, - help='For Saving the current Model in ONNX format') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=8, + metavar="N", + help="number of epochs to train (default: 14)", + ) + parser.add_argument( + "--lr", + type=float, + default=1.0, + metavar="LR", + help="learning rate (default: 1.0)", + ) + parser.add_argument( + "--gamma", + type=float, + default=0.7, + metavar="M", + help="Learning rate step gamma (default: 0.7)", + ) + parser.add_argument( + "--no-cuda", action="store_true", default=False, help="disables CUDA training" + ) + parser.add_argument( + "--no-mps", + action="store_true", + default=False, + help="disables macOS GPU training", + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="quickly check a single pass", + ) + parser.add_argument( + "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" + ) + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument( + "--save-model", + action="store_true", + default=True, + help="For Saving the current Model", + ) + parser.add_argument( + "--export-onnx", + action="store_true", + default=False, + help="For Saving the current Model in ONNX format", + ) args = parser.parse_args() use_cuda = not args.no_cuda and torch.cuda.is_available() use_mps = not args.no_mps and torch.backends.mps.is_available() @@ -118,26 +178,24 @@ def main(): device = torch.device("cuda") elif use_mps: device = torch.device("mps") + print("using MPS") else: device = torch.device("cpu") - train_kwargs = {'batch_size': args.batch_size} - test_kwargs = {'batch_size': args.test_batch_size} + train_kwargs = {"batch_size": args.batch_size} + test_kwargs = {"batch_size": args.test_batch_size} if use_cuda: - cuda_kwargs = {'num_workers': 1, - 'pin_memory': True, - 'shuffle': True} + cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} train_kwargs.update(cuda_kwargs) test_kwargs.update(cuda_kwargs) - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - dataset1 = datasets.MNIST('/tmp/mnist-data', train=True, download=True, - transform=transform) - dataset2 = datasets.MNIST('/tmp/mnist-data', train=False, - transform=transform) + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + dataset1 = datasets.MNIST( + "/tmp/mnist-data", train=True, download=True, transform=transform + ) + dataset2 = datasets.MNIST("/tmp/mnist-data", train=False, transform=transform) train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) @@ -152,12 +210,14 @@ def main(): if args.save_model: torch.save(model.state_dict(), "mnist.pt") + save_file(model.state_dict(), "mnist.safetensors") if args.export_onnx: dummy_input = torch.randn(1, 1, 28, 28, device=device) - torch.onnx.export(model, dummy_input, "mnist.onnx", - verbose=True, opset_version=16) + torch.onnx.export( + model, dummy_input, "mnist.onnx", verbose=True, opset_version=16 + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/pytorch-import/pytorch/mnist.safetensors b/examples/pytorch-import/pytorch/mnist.safetensors new file mode 100644 index 0000000000..95814d4536 Binary files /dev/null and b/examples/pytorch-import/pytorch/mnist.safetensors differ diff --git a/examples/safetensors-import/Cargo.toml b/examples/safetensors-import/Cargo.toml new file mode 100644 index 0000000000..5eae87a7aa --- /dev/null +++ b/examples/safetensors-import/Cargo.toml @@ -0,0 +1,24 @@ +[package] +authors = ["Dilshod Tadjibaev (@antimora)"] +edition = "2021" +license = "MIT OR Apache-2.0" +name = "safetensors-import" +publish = false +version = "0.17.0" + +[dependencies] +burn = { path = "../../crates/burn", features = [ + "ndarray", + "dataset", + "vision", +] } + +stmodel = { path = "./stmodel" } + + +[build-dependencies] +stmodel = { path = "./stmodel" } +burn = { path = "../../crates/burn", features = ["ndarray"] } +burn-import = { path = "../../crates/burn-import", features = [ + "safetensors", +], default-features = false } diff --git a/examples/safetensors-import/README.md b/examples/safetensors-import/README.md new file mode 100644 index 0000000000..b348dcb203 --- /dev/null +++ b/examples/safetensors-import/README.md @@ -0,0 +1,29 @@ +# Import PyTorch Weights + +This crate provides a simple example for importing PyTorch generated weights to Burn. + +The `.pt` file is converted into a Burn consumable file (message pack format) using `burn-import`. +The conversation is done in the `build.rs` file. + +The model is separated into a sub-crate because `build.rs` needs for conversion and build.rs cannot +import modules for the same crate. + +## Usage + +```bash +cargo run -- 15 +``` + +Output: + +```bash +Finished dev [unoptimized + debuginfo] target(s) in 0.13s + Running `burn/target/debug/onnx-inference 15` + +Image index: 15 +Success! +Predicted: 5 +Actual: 5 +See the image online, click the link below: +https://huggingface.co/datasets/ylecun/mnist/viewer/mnist/test?row=15 +``` diff --git a/examples/safetensors-import/build.rs b/examples/safetensors-import/build.rs new file mode 100644 index 0000000000..063a77eb17 --- /dev/null +++ b/examples/safetensors-import/build.rs @@ -0,0 +1,34 @@ +/// This build script does the following: +/// 1. Loads PyTorch weights into a model record. +/// 2. Saves the model record to a file using the `NamedMpkFileRecorder`. +use std::path::Path; + +use burn::{ + backend::NdArray, + record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder}, +}; +use burn_import::safetensors::SafeTensorsFileRecorder; + +// Basic backend type (not used directly here). +type B = NdArray; + +fn main() { + let device = Default::default(); + + // Load PyTorch weights into a model record. + let record: stmodel::ModelRecord = + SafeTensorsFileRecorder::::default() + .load("pytorch/mnist.safetensors".into(), &device) + .expect("Failed to decode state"); + + // Save the model record to a file. + let recorder = NamedMpkFileRecorder::::default(); + + // Save into the OUT_DIR directory so that the model can be loaded by the + let out_dir = std::env::var("OUT_DIR").unwrap(); + let file_path = Path::new(&out_dir).join("model/mnist"); + + recorder + .record(record, file_path) + .expect("Failed to save model record"); +} diff --git a/examples/safetensors-import/pytorch/mnist.py b/examples/safetensors-import/pytorch/mnist.py new file mode 100755 index 0000000000..ec197ccf4d --- /dev/null +++ b/examples/safetensors-import/pytorch/mnist.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 + +# Originally copied and modified from: https://github.com/pytorch/examples/blob/main/mnist/main.py +# under the following license: BSD-3-Clause license + +from __future__ import print_function +import argparse +from safetensors.torch import save_file +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 8, 3) + self.conv2 = nn.Conv2d(8, 16, 3) + self.conv3 = nn.Conv2d(16, 24, 3) + self.norm1 = nn.BatchNorm2d(24) + self.dropout1 = nn.Dropout(0.3) + self.fc1 = nn.Linear(24 * 22 * 22, 32) + self.fc2 = nn.Linear(32, 10) + self.norm2 = nn.BatchNorm1d(10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = self.conv3(x) + x = F.relu(x) + x = self.norm1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout1(x) + x = self.fc2(x) + x = self.norm2(x) + output = F.log_softmax(x, dim=1) + return output + + +def train(args, model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0: + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + if args.dry_run: + break + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + # sum up batch loss + test_loss += F.nll_loss(output, target, reduction="sum").item() + # get the index of the max log-probability + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, + correct, + len(test_loader.dataset), + 100.0 * correct / len(test_loader.dataset), + ) + ) + + +def main(): + # Training settings + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=8, + metavar="N", + help="number of epochs to train (default: 14)", + ) + parser.add_argument( + "--lr", + type=float, + default=1.0, + metavar="LR", + help="learning rate (default: 1.0)", + ) + parser.add_argument( + "--gamma", + type=float, + default=0.7, + metavar="M", + help="Learning rate step gamma (default: 0.7)", + ) + parser.add_argument( + "--no-cuda", action="store_true", default=False, help="disables CUDA training" + ) + parser.add_argument( + "--no-mps", + action="store_true", + default=False, + help="disables macOS GPU training", + ) + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="quickly check a single pass", + ) + parser.add_argument( + "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" + ) + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument( + "--save-model", + action="store_true", + default=True, + help="For Saving the current Model in Safetensors format", + ) + parser.add_argument( + "--export-onnx", + action="store_true", + default=False, + help="For Saving the current Model in ONNX format", + ) + args = parser.parse_args() + use_cuda = not args.no_cuda and torch.cuda.is_available() + use_mps = not args.no_mps and torch.backends.mps.is_available() + + torch.manual_seed(args.seed) + + if use_cuda: + device = torch.device("cuda") + elif use_mps: + device = torch.device("mps") + print("using MPS") + else: + device = torch.device("cpu") + + train_kwargs = {"batch_size": args.batch_size} + test_kwargs = {"batch_size": args.test_batch_size} + if use_cuda: + cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + dataset1 = datasets.MNIST( + "/tmp/mnist-data", train=True, download=True, transform=transform + ) + dataset2 = datasets.MNIST("/tmp/mnist-data", train=False, transform=transform) + train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) + test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) + + model = Net().to(device) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + for epoch in range(1, args.epochs + 1): + train(args, model, device, train_loader, optimizer, epoch) + test(model, device, test_loader) + scheduler.step() + + if args.save_model: + save_file(model.state_dict(), "mnist.safetensors") + + if args.export_onnx: + dummy_input = torch.randn(1, 1, 28, 28, device=device) + torch.onnx.export( + model, dummy_input, "mnist.onnx", verbose=True, opset_version=16 + ) + + +if __name__ == "__main__": + main() diff --git a/examples/safetensors-import/pytorch/mnist.safetensors b/examples/safetensors-import/pytorch/mnist.safetensors new file mode 100644 index 0000000000..95814d4536 Binary files /dev/null and b/examples/safetensors-import/pytorch/mnist.safetensors differ diff --git a/examples/safetensors-import/src/main.rs b/examples/safetensors-import/src/main.rs new file mode 100644 index 0000000000..a09fafca1e --- /dev/null +++ b/examples/safetensors-import/src/main.rs @@ -0,0 +1,71 @@ +use std::env::args; +use std::path::Path; + +use burn::{ + backend::ndarray::NdArray, + data::dataset::{vision::MnistDataset, Dataset}, + module::Module, + record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder}, + tensor::Tensor, +}; + +use stmodel::Model; + +const IMAGE_INX: usize = 42; // <- Change this to test a different image + +// Build output direct that contains converted model weight file path +const OUT_DIR: &str = concat!(env!("OUT_DIR"), "/model/mnist"); + +fn main() { + // Get image index argument (first) from command line + + let image_index = if let Some(image_index) = args().nth(1) { + println!("Image index: {}", image_index); + image_index + .parse::() + .expect("Failed to parse image index") + } else { + println!("No image index provided; Using default image index: {IMAGE_INX}"); + IMAGE_INX + }; + + assert!(image_index < 10000, "Image index must be less than 10000"); + + type Backend = NdArray; + let device = Default::default(); + + // Load the model record from converted PyTorch file by the build script + let record = NamedMpkFileRecorder::::default() + .load(Path::new(OUT_DIR).into(), &device) + .expect("Failed to decode state"); + + // Create a new model and load the state + let model: Model = Model::init(&device).load_record(record); + + // Load the MNIST dataset and get an item + let dataset = MnistDataset::test(); + let item = dataset.get(image_index).unwrap(); + + // Create a tensor from the image data + let image_data = item.image.iter().copied().flatten().collect::>(); + let mut input = + Tensor::::from_floats(image_data.as_slice(), &device).reshape([1, 1, 28, 28]); + + // Normalize the input + input = ((input / 255) - 0.1307) / 0.3081; + + // Run the model on the input + let output = model.forward(input); + + // Get the index of the maximum value + let arg_max = output.argmax(1).into_scalar() as u8; + + // Check if the index matches the label + assert!(arg_max == item.label); + + println!("Success!"); + println!("Predicted: {}", arg_max); + println!("Actual: {}", item.label); + println!("See the image online, click the link below:"); + println!("https://huggingface.co/datasets/ylecun/mnist/viewer/mnist/test?row={image_index}"); +} diff --git a/examples/safetensors-import/stmodel/Cargo.toml b/examples/safetensors-import/stmodel/Cargo.toml new file mode 100644 index 0000000000..e3ad37395f --- /dev/null +++ b/examples/safetensors-import/stmodel/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "stmodel" +version = "0.6.0" +edition = "2021" + +[dependencies] +burn = { path = "../../../crates/burn" } +burn-import = { path = "../../../crates/burn-import", features = [ + "safetensors", +], default-features = false } diff --git a/examples/safetensors-import/stmodel/src/lib.rs b/examples/safetensors-import/stmodel/src/lib.rs new file mode 100644 index 0000000000..405caf2cdf --- /dev/null +++ b/examples/safetensors-import/stmodel/src/lib.rs @@ -0,0 +1,77 @@ +use std::env; +use std::path::Path; + +use burn::{ + nn::{ + conv::{Conv2d, Conv2dConfig}, + BatchNorm, BatchNormConfig, Linear, LinearConfig, + }, + prelude::*, + record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder}, + tensor::activation::{log_softmax, relu}, +}; + +#[derive(Module, Debug)] +pub struct Model { + conv1: Conv2d, + conv2: Conv2d, + conv3: Conv2d, + norm1: BatchNorm, + fc1: Linear, + fc2: Linear, + norm2: BatchNorm, + phantom: core::marker::PhantomData, +} + +impl Default for Model { + fn default() -> Self { + let device = B::Device::default(); + let out_dir = env::var_os("OUT_DIR").unwrap(); + let file_path = Path::new(&out_dir).join("model/mnist"); + + let record = NamedMpkFileRecorder::::default() + .load(file_path, &device) + .expect("Failed to decode state"); + + Self::init(&device).load_record(record) + } +} + +impl Model { + pub fn init(device: &B::Device) -> Self { + let conv1 = Conv2dConfig::new([1, 8], [3, 3]).init(device); + let conv2 = Conv2dConfig::new([8, 16], [3, 3]).init(device); + let conv3 = Conv2dConfig::new([16, 24], [3, 3]).init(device); + let norm1 = BatchNormConfig::new(24).init(device); + let fc1 = LinearConfig::new(11616, 32).init(device); + let fc2 = LinearConfig::new(32, 10).init(device); + let norm2 = BatchNormConfig::new(10).init(device); + + Self { + conv1, + conv2, + conv3, + norm1, + fc1, + fc2, + norm2, + phantom: core::marker::PhantomData, + } + } + + pub fn forward(&self, input1: Tensor) -> Tensor { + let conv1_out1 = self.conv1.forward(input1); + let relu1_out1 = relu(conv1_out1); + let conv2_out1 = self.conv2.forward(relu1_out1); + let relu2_out1 = relu(conv2_out1); + let conv3_out1 = self.conv3.forward(relu2_out1); + let relu3_out1 = relu(conv3_out1); + let norm1_out1 = self.norm1.forward(relu3_out1); + let flatten1_out1 = norm1_out1.flatten(1, 3); + let fc1_out1 = self.fc1.forward(flatten1_out1); + let relu4_out1 = relu(fc1_out1); + let fc2_out1 = self.fc2.forward(relu4_out1); + let norm2_out1 = self.norm2.forward(fc2_out1); + log_softmax(norm2_out1, 1) + } +}