Skip to content

Commit

Permalink
support safetensors format, for it's own crate
Browse files Browse the repository at this point in the history
  • Loading branch information
wandbrandon committed Jan 28, 2025
1 parent 6a0330e commit dbd40ce
Show file tree
Hide file tree
Showing 79 changed files with 3,752 additions and 20 deletions.
29 changes: 29 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ members = [
"crates/*",
"crates/burn-import/pytorch-tests",
"crates/burn-import/onnx-tests",
"crates/burn-import/safetensors-tests",
"examples/*",
"examples/pytorch-import/model",
"xtask",
Expand Down
7 changes: 5 additions & 2 deletions crates/burn-import/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ version.workspace = true
default-run = "onnx2burn"

[features]
default = ["onnx", "pytorch"]
default = ["onnx", "pytorch", "safetensors"]
onnx = []
pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"]
safetensors = ["burn/record-item-custom-serde", "thiserror", "zip"]

[dependencies]
burn = { path = "../burn", version = "0.17.0", default-features = false, features = ["std"]}
burn = { path = "../burn", version = "0.17.0", default-features = false, features = [
"std",
] }
burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false }
onnx-ir = { path = "../onnx-ir", version = "0.17.0" }
candle-core = { workspace = true }
Expand Down
17 changes: 17 additions & 0 deletions crates/burn-import/safetensors-tests/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "safetensors-tests"
version.workspace = true
edition.workspace = true
license.workspace = true

[dev-dependencies]
burn = { path = "../../burn" }
burn-ndarray = { path = "../../burn-ndarray" }
burn-autodiff = { path = "../../burn-autodiff" }
serde = { workspace = true }
float-cmp = { workspace = true }
burn-import = { path = "../", features = ["safetensors"] }


[build-dependencies]
burn-import = { path = "../", features = ["safetensors"] }
1 change: 1 addition & 0 deletions crates/burn-import/safetensors-tests/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.norm1 = nn.BatchNorm2d(5)

def forward(self, x):
x = self.norm1(x)
return x


def main():

torch.set_printoptions(precision=8)
torch.manual_seed(1)

model = Model().to(torch.device("cpu"))

# Condition batch norm (each forward will affect the running stats)
x1 = torch.ones(1, 5, 2, 2) - 0.5
_ = model(x1)
model.eval() # Set to eval mode to freeze running stats
# Save the model to safetensors after the first forward
save_file(model.state_dict(), "batch_norm2d.safetensors")

x2 = torch.ones(1, 5, 2, 2) - 0.3
print("Input shape: {}", x2.shape)
output = model(x2)
print("Output: {}", output)
print("Output Shape: {}", output.shape)


if __name__ == "__main__":
main()
60 changes: 60 additions & 0 deletions crates/burn-import/safetensors-tests/tests/batch_norm/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use burn::{
module::Module,
nn::{BatchNorm, BatchNormConfig},
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
norm1: BatchNorm<B, 2>,
}

impl<B: Backend> Net<B> {
pub fn new(device: &B::Device) -> Self {
Self {
norm1: BatchNormConfig::new(4).init(device),
}
}

/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.norm1.forward(x)
}
}

#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;

use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::safetensors::SafeTensorsFileRecorder;

use super::*;

#[test]
fn batch_norm2d() {
let device = Default::default();
let record = SafeTensorsFileRecorder::<FullPrecisionSettings>::default()
.load("tests/batch_norm/batch_norm2d.safetensors".into(), &device)
.expect("Should decode state successfully");

let model = Net::<Backend>::new(&device).load_record(record);

let input = Tensor::<Backend, 4>::ones([1, 5, 2, 2], &device) - 0.3;

let output = model.forward(input);

let expected = Tensor::<Backend, 4>::from_data(
[[
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]],
]],
&device,
);

output.to_data().assert_approx_eq(&expected.to_data(), 5);
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
buffer = torch.tensor([True, False, True])
self.register_buffer("buffer", buffer, persistent=True)

def forward(self, x):
x = self.buffer
return x


def main():

torch.set_printoptions(precision=8)
torch.manual_seed(1)

model = Model().to(torch.device("cpu"))

save_file(model.state_dict(), "boolean.safetensors")

input = torch.ones(3, 3)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)


if __name__ == "__main__":
main()
58 changes: 58 additions & 0 deletions crates/burn-import/safetensors-tests/tests/boolean/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use burn::{
module::{Module, Param},
tensor::{backend::Backend, Bool, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
buffer: Param<Tensor<B, 1, Bool>>,
}

impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
Self {
buffer: record.buffer,
}
}

/// Forward pass of the model.
pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Bool> {
self.buffer.val()
}
}

#[cfg(test)]
mod tests {

use burn::{
record::{FullPrecisionSettings, Recorder},
tensor::TensorData,
};
use burn_import::safetensors::SafeTensorsFileRecorder;

use super::*;

type Backend = burn_ndarray::NdArray<f32>;

#[test]
#[ignore = "It appears loading boolean tensors are not supported yet"]
// Error skipping: Msg("unsupported storage type BoolStorage")
fn boolean() {
let device = Default::default();
let record = SafeTensorsFileRecorder::<FullPrecisionSettings>::default()
.load("tests/boolean/boolean.safetensors".into(), &device)
.expect("Should decode state successfully");

let model = Net::<Backend>::new_with(record);

let input = Tensor::<Backend, 2>::ones([3, 3], &device);

let output = model.forward(input);

let expected =
Tensor::<Backend, 1, Bool>::from_bool(TensorData::from([true, false, true]), &device);

assert_eq!(output.to_data(), expected.to_data());
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
buffer = torch.ones(3, 3)
self.register_buffer("buffer", buffer, persistent=True)

def forward(self, x):
x = self.buffer + x
return x


def main():

torch.set_printoptions(precision=8)
torch.manual_seed(1)

model = Model().to(torch.device("cpu"))

save_file(model.state_dict(), "buffer.safetensors")

input = torch.ones(3, 3)
print("Input shape: {}", input.shape)
print("Input: {}", input)
output = model(input)
print("Output: {}", output)
print("Output Shape: {}", output.shape)


if __name__ == "__main__":
main()
51 changes: 51 additions & 0 deletions crates/burn-import/safetensors-tests/tests/buffer/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use burn::{
module::{Module, Param},
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
buffer: Param<Tensor<B, 2>>,
}

impl<B: Backend> Net<B> {
/// Create a new model from the given record.
pub fn new_with(record: NetRecord<B>) -> Self {
Self {
buffer: record.buffer,
}
}

/// Forward pass of the model.
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
self.buffer.val() + x
}
}

#[cfg(test)]
mod tests {
type Backend = burn_ndarray::NdArray<f32>;

use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::safetensors::SafeTensorsFileRecorder;

use super::*;

#[test]
fn buffer() {
let device = Default::default();
let record = SafeTensorsFileRecorder::<FullPrecisionSettings>::default()
.load("tests/buffer/buffer.safetensors".into(), &device)
.expect("Should decode state successfully");

let model = Net::<Backend>::new_with(record);

let input = Tensor::<Backend, 2>::ones([3, 3], &device);

let output = model.forward(input);

let expected = Tensor::<Backend, 2>::ones([3, 3], &device) * 2.0;

output.to_data().assert_approx_eq(&expected.to_data(), 3);
}
}
Binary file not shown.
Loading

0 comments on commit dbd40ce

Please sign in to comment.