Commit dbd40ce (1 parent: 6a0330e)
support safetensors format, in its own crate

Showing 79 changed files with 3,752 additions and 20 deletions.
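As a quick orientation before the file-by-file diff, the pattern this new test crate exercises is: decode a .safetensors file with SafeTensorsFileRecorder, then apply the resulting record to an initialized module. Below is a minimal sketch assembled from the tests in this commit (it reuses the Net module defined in the batch_norm test further down, so it is illustrative rather than standalone):

use burn::module::Module;
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::safetensors::SafeTensorsFileRecorder;

type Backend = burn_ndarray::NdArray<f32>;

fn load_batch_norm() -> Net<Backend> {
    let device = Default::default();
    // Decode the safetensors file into a Burn record...
    let record = SafeTensorsFileRecorder::<FullPrecisionSettings>::default()
        .load("tests/batch_norm/batch_norm2d.safetensors".into(), &device)
        .expect("Should decode state successfully");
    // ...then load it into a freshly initialized module (Net is defined in the batch_norm test below).
    Net::<Backend>::new(&device).load_record(record)
}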
17 changes: 17 additions & 0 deletions
crates/burn-import/safetensors-tests/Cargo.toml
@@ -0,0 +1,17 @@
[package]
name = "safetensors-tests"
version.workspace = true
edition.workspace = true
license.workspace = true

[dev-dependencies]
burn = { path = "../../burn" }
burn-ndarray = { path = "../../burn-ndarray" }
burn-autodiff = { path = "../../burn-autodiff" }
serde = { workspace = true }
float-cmp = { workspace = true }
burn-import = { path = "../", features = ["safetensors"] }


[build-dependencies]
burn-import = { path = "../", features = ["safetensors"] }
Binary file added (+448 Bytes, not shown):
crates/burn-import/safetensors-tests/tests/batch_norm/batch_norm2d.safetensors
41 changes: 41 additions & 0 deletions
crates/burn-import/safetensors-tests/tests/batch_norm/export_weights.py
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.norm1 = nn.BatchNorm2d(5)

    def forward(self, x):
        x = self.norm1(x)
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    # Condition batch norm (each forward will affect the running stats)
    x1 = torch.ones(1, 5, 2, 2) - 0.5
    _ = model(x1)
    model.eval()  # Set to eval mode to freeze running stats
    # Save the model to safetensors after the first forward
    save_file(model.state_dict(), "batch_norm2d.safetensors")

    x2 = torch.ones(1, 5, 2, 2) - 0.3
    print("Input shape: {}", x2.shape)
    output = model(x2)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == "__main__":
    main()
60 changes: 60 additions & 0 deletions
crates/burn-import/safetensors-tests/tests/batch_norm/mod.rs
@@ -0,0 +1,60 @@ | ||
use burn::{ | ||
module::Module, | ||
nn::{BatchNorm, BatchNormConfig}, | ||
tensor::{backend::Backend, Tensor}, | ||
}; | ||
|
||
#[derive(Module, Debug)] | ||
pub struct Net<B: Backend> { | ||
norm1: BatchNorm<B, 2>, | ||
} | ||
|
||
impl<B: Backend> Net<B> { | ||
pub fn new(device: &B::Device) -> Self { | ||
Self { | ||
norm1: BatchNormConfig::new(4).init(device), | ||
} | ||
} | ||
|
||
/// Forward pass of the model. | ||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> { | ||
self.norm1.forward(x) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
type Backend = burn_ndarray::NdArray<f32>; | ||
|
||
use burn::record::{FullPrecisionSettings, Recorder}; | ||
use burn_import::safetensors::SafeTensorsFileRecorder; | ||
|
||
use super::*; | ||
|
||
#[test] | ||
fn batch_norm2d() { | ||
let device = Default::default(); | ||
let record = SafeTensorsFileRecorder::<FullPrecisionSettings>::default() | ||
.load("tests/batch_norm/batch_norm2d.safetensors".into(), &device) | ||
.expect("Should decode state successfully"); | ||
|
||
let model = Net::<Backend>::new(&device).load_record(record); | ||
|
||
let input = Tensor::<Backend, 4>::ones([1, 5, 2, 2], &device) - 0.3; | ||
|
||
let output = model.forward(input); | ||
|
||
let expected = Tensor::<Backend, 4>::from_data( | ||
[[ | ||
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]], | ||
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]], | ||
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]], | ||
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]], | ||
[[0.68515635, 0.68515635], [0.68515635, 0.68515635]], | ||
]], | ||
&device, | ||
); | ||
|
||
output.to_data().assert_approx_eq(&expected.to_data(), 5); | ||
} | ||
} |
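For reference, the expected value 0.68515635 in the test above follows from batch-norm inference with the running statistics frozen by the export script, assuming PyTorch's BatchNorm2d defaults (momentum 0.1, eps 1e-5, weight 1, bias 0). One training-mode forward on the constant input x1 = 0.5 moves the running mean to 0.1 · 0.5 = 0.05 and the running variance to 0.9 · 1 + 0.1 · 0 = 0.9, so for the eval-mode input x2 = 0.7:

y = \frac{x_2 - \hat{\mu}}{\sqrt{\hat{\sigma}^2 + \epsilon}} = \frac{0.7 - 0.05}{\sqrt{0.9 + 10^{-5}}} \approx 0.68515635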
Binary file added (+75 Bytes, not shown):
crates/burn-import/safetensors-tests/tests/boolean/boolean.safetensors
38 changes: 38 additions & 0 deletions
crates/burn-import/safetensors-tests/tests/boolean/export_weights.py
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        buffer = torch.tensor([True, False, True])
        self.register_buffer("buffer", buffer, persistent=True)

    def forward(self, x):
        x = self.buffer
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    save_file(model.state_dict(), "boolean.safetensors")

    input = torch.ones(3, 3)
    print("Input shape: {}", input.shape)
    print("Input: {}", input)
    output = model(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == "__main__":
    main()
58 changes: 58 additions & 0 deletions
crates/burn-import/safetensors-tests/tests/boolean/mod.rs
@@ -0,0 +1,58 @@
use burn::{
    module::{Module, Param},
    tensor::{backend::Backend, Bool, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    buffer: Param<Tensor<B, 1, Bool>>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        Self {
            buffer: record.buffer,
        }
    }

    /// Forward pass of the model.
    pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Bool> {
        self.buffer.val()
    }
}

#[cfg(test)]
mod tests {

    use burn::{
        record::{FullPrecisionSettings, Recorder},
        tensor::TensorData,
    };
    use burn_import::safetensors::SafeTensorsFileRecorder;

    use super::*;

    type Backend = burn_ndarray::NdArray<f32>;

    #[test]
    #[ignore = "It appears loading boolean tensors are not supported yet"]
    // Error skipping: Msg("unsupported storage type BoolStorage")
    fn boolean() {
        let device = Default::default();
        let record = SafeTensorsFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/boolean/boolean.safetensors".into(), &device)
            .expect("Should decode state successfully");

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 2>::ones([3, 3], &device);

        let output = model.forward(input);

        let expected =
            Tensor::<Backend, 1, Bool>::from_bool(TensorData::from([true, false, true]), &device);

        assert_eq!(output.to_data(), expected.to_data());
    }
}
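The boolean test above is ignored because loading bool tensors is not supported yet (the error quoted in the source is Msg("unsupported storage type BoolStorage")). Until that lands, one possible workaround is to export the flags as integers on the Python side and rebuild a bool tensor after loading. The sketch below is illustrative only and not part of this commit; the int-buffer export and the equal_elem conversion are assumptions.

use burn::tensor::{backend::Backend, Bool, Int, Tensor};

/// Hypothetical helper: rebuild a bool mask from an integer buffer loaded via safetensors.
fn mask_from_ints<B: Backend>(ints: Tensor<B, 1, Int>) -> Tensor<B, 1, Bool> {
    // Values equal to 1 become `true`; everything else becomes `false`.
    ints.equal_elem(1)
}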
Binary file added (+108 Bytes, not shown):
crates/burn-import/safetensors-tests/tests/buffer/buffer.safetensors
38 changes: 38 additions & 0 deletions
crates/burn-import/safetensors-tests/tests/buffer/export_weights.py
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import save_file


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        buffer = torch.ones(3, 3)
        self.register_buffer("buffer", buffer, persistent=True)

    def forward(self, x):
        x = self.buffer + x
        return x


def main():

    torch.set_printoptions(precision=8)
    torch.manual_seed(1)

    model = Model().to(torch.device("cpu"))

    save_file(model.state_dict(), "buffer.safetensors")

    input = torch.ones(3, 3)
    print("Input shape: {}", input.shape)
    print("Input: {}", input)
    output = model(input)
    print("Output: {}", output)
    print("Output Shape: {}", output.shape)


if __name__ == "__main__":
    main()
51 changes: 51 additions & 0 deletions
crates/burn-import/safetensors-tests/tests/buffer/mod.rs
@@ -0,0 +1,51 @@
use burn::{
    module::{Module, Param},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    buffer: Param<Tensor<B, 2>>,
}

impl<B: Backend> Net<B> {
    /// Create a new model from the given record.
    pub fn new_with(record: NetRecord<B>) -> Self {
        Self {
            buffer: record.buffer,
        }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        self.buffer.val() + x
    }
}

#[cfg(test)]
mod tests {
    type Backend = burn_ndarray::NdArray<f32>;

    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::safetensors::SafeTensorsFileRecorder;

    use super::*;

    #[test]
    fn buffer() {
        let device = Default::default();
        let record = SafeTensorsFileRecorder::<FullPrecisionSettings>::default()
            .load("tests/buffer/buffer.safetensors".into(), &device)
            .expect("Should decode state successfully");

        let model = Net::<Backend>::new_with(record);

        let input = Tensor::<Backend, 2>::ones([3, 3], &device);

        let output = model.forward(input);

        let expected = Tensor::<Backend, 2>::ones([3, 3], &device) * 2.0;

        output.to_data().assert_approx_eq(&expected.to_data(), 3);
    }
}
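A natural follow-up once weights have been imported is to re-save the model in Burn's native record format so the .safetensors file and the Python toolchain are not needed at runtime. This is not part of this commit; the sketch below assumes burn's existing NamedMpkFileRecorder and Module::save_file APIs and reuses the Net module from the buffer test above.

use burn::module::Module;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};

type Backend = burn_ndarray::NdArray<f32>;

/// Sketch: persist a model imported via SafeTensorsFileRecorder in Burn's native format.
fn save_native(model: Net<Backend>) {
    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
    model
        .save_file("buffer_model", &recorder)
        .expect("Should save the imported model");
}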
Binary file added (+9.03 KB, not shown):
crates/burn-import/safetensors-tests/tests/complex_nested/complex_nested.safetensors