
add support for safetensors in pytorch reader #2721

Open · wants to merge 3 commits into main
Conversation

wandbrandon

Pull Request Template

Checklist

  • Confirmed that the `run-checks all` script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#626

Changes

A simple addition to the already implemented reader.rs, supporting the safetensors format by loading with candle on the CPU device.
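For context, the on-disk layout the reader consumes is small enough to sketch. The following is an illustrative stdlib-only example of the safetensors container format (the PR itself delegates parsing to candle); the function name `split_safetensors` and the hand-built buffer are assumptions for demonstration:

```rust
// A .safetensors file begins with an 8-byte little-endian u64 giving the
// byte length of a JSON header, followed by that header, then raw tensor data.
// `split_safetensors` is a hypothetical helper illustrating that layout.
fn split_safetensors(buf: &[u8]) -> Option<(&str, &[u8])> {
    let len_bytes: [u8; 8] = buf.get(..8)?.try_into().ok()?;
    let header_len = u64::from_le_bytes(len_bytes) as usize;
    let header = std::str::from_utf8(buf.get(8..8 + header_len)?).ok()?;
    let data = buf.get(8 + header_len..)?;
    Some((header, data))
}

fn main() {
    // Hand-built buffer mimicking a file with a tiny JSON header and 4 data bytes.
    let header = br#"{"w":{"dtype":"F32","shape":[1],"data_offsets":[0,4]}}"#;
    let mut buf = (header.len() as u64).to_le_bytes().to_vec();
    buf.extend_from_slice(header);
    buf.extend_from_slice(&1.0f32.to_le_bytes());
    let (json, data) = split_safetensors(&buf).unwrap();
    println!("header: {} bytes, data: {} bytes", json.len(), data.len());
}
```

In the real reader, candle handles this parsing and returns named tensors ready for conversion into Burn records.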

Testing

In the examples/pytorch-import directory, there is a mnist.safetensors file that is successfully imported.


codecov bot commented Jan 20, 2025

Codecov Report

Attention: Patch coverage is 93.25768% with 101 lines in your changes missing coverage. Please review.

Project coverage is 83.70%. Comparing base (140ea75) to head (dbd40ce).
Report is 12 commits behind head on main.

Files with missing lines Patch % Lines
crates/burn-import/src/safetensors/config.rs 0.00% 52 Missing ⚠️
...burn-import/safetensors-tests/tests/boolean/mod.rs 0.00% 26 Missing ⚠️
crates/burn-import/src/safetensors/recorder.rs 65.90% 15 Missing ⚠️
crates/burn-import/src/safetensors/reader.rs 94.73% 5 Missing ⚠️
...port/safetensors-tests/tests/complex_nested/mod.rs 99.02% 1 Missing ⚠️
...-import/safetensors-tests/tests/enum_module/mod.rs 99.34% 1 Missing ⚠️
...afetensors-tests/tests/missing_module_field/mod.rs 90.90% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2721      +/-   ##
==========================================
+ Coverage   83.60%   83.70%   +0.10%     
==========================================
  Files         819      847      +28     
  Lines      106600   109951    +3351     
==========================================
+ Hits        89124    92036    +2912     
- Misses      17476    17915     +439     


Member

@laggui laggui left a comment


Thanks for the addition 🙏

Looks pretty good overall, just some minor comments.

crates/burn-import/src/pytorch/reader.rs (review thread resolved)
examples/pytorch-import/build.rs (review thread resolved)
@Nikaidou-Shinku
Contributor

IMO maybe we can have something like pub mod safetensors; under a new feature gate in crate burn-import so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.
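The feature-gate idea could look roughly like the following sketch (module and feature names are assumptions, not the actual crate layout):

```rust
// In crates/burn-import/src/lib.rs: expose each format behind its own feature
// so users build only what they need.
#[cfg(feature = "pytorch")]
pub mod pytorch;

#[cfg(feature = "safetensors")]
pub mod safetensors;

// And in crates/burn-import/Cargo.toml, a matching feature entry, e.g.:
// [features]
// safetensors = ["dep:candle-core"]
```

With this shape, `cargo build --features safetensors` pulls in only the safetensors path and its candle dependency.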

@wandbrandon
Author

IMO maybe we can have something like pub mod safetensors; under a new feature gate in crate burn-import so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.

I think this is a good point, and it also builds the scaffolding for potentially rewriting it to remove the Candle dependency.

@laggui
Member

laggui commented Jan 21, 2025

IMO maybe we can have something like pub mod safetensors; under a new feature gate in crate burn-import so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.

I agree that the format is not strongly related to pytorch, but I think most models available in safetensor format are pytorch models 😅

Unless you mean supporting the safetensor format as another recorder to load and save modules. In that case, I'm not sure this is a meaningful addition.

@antimora antimora self-requested a review January 27, 2025 17:48
Collaborator

@antimora antimora left a comment


IMO maybe we can have something like pub mod safetensors; under a new feature gate in crate burn-import so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.

Yes, I strongly agree with this opinion. My intention was that we have separate support for different formats. We can reuse common code in burn-import, but the entry point should be different. We need a separate module and feature for GGUF files, see #1187, and it would follow the same pattern.

PyTorch's .pt format and Safetensors are unrelated. We should not mix them up in code or documentation. It just happens that we're using Candle's reader, but it could be different.

P.S. Thanks for taking up this problem! It's an often-requested feature.

@antimora
Collaborator

I suggest creating a dedicated SafeTensorFileRecorder to handle SafeTensor files independently from PyTorch's .pt files. This approach ensures a clear separation between different file formats and supports framework-specific transformations during the import process.

Additionally, I propose providing configurable options (via LoadArgs, similar to PyTorch's recorder) within the recorder to specify the appropriate transformation adapter. By default, this could use the PyTorchAdapter but allow customization for other frameworks, such as TensorFlow. This design enhances flexibility and decouples the handling of different tensor file formats. Moreover, it might be beneficial to support passing a user-defined implementation of BurnModuleAdapter when needed.
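The adapter-dispatch idea above can be sketched as a self-contained example. This is a hedged illustration of the proposed design, not the actual burn-import API: all names (`BurnModuleAdapter`, `LoadArgs`, `PyTorchAdapter`, `DefaultAdapter`, `load_with`) are hypothetical stand-ins, and tensors are reduced to `Vec<f32>` for brevity:

```rust
use std::collections::BTreeMap;

// Hypothetical adapter trait: renames source tensor keys into the
// target module's naming scheme during import.
trait BurnModuleAdapter {
    fn adapt_key(&self, key: &str) -> String;
}

// Pass-through: load the safetensors file as-is.
struct DefaultAdapter;
impl BurnModuleAdapter for DefaultAdapter {
    fn adapt_key(&self, key: &str) -> String {
        key.to_string()
    }
}

// PyTorch-style transformation, e.g. stripping a "model." prefix.
struct PyTorchAdapter;
impl BurnModuleAdapter for PyTorchAdapter {
    fn adapt_key(&self, key: &str) -> String {
        key.strip_prefix("model.").unwrap_or(key).to_string()
    }
}

// Configurable load options; defaults to the PyTorch adapter, but a
// user-defined BurnModuleAdapter implementation can be swapped in.
struct LoadArgs {
    adapter: Box<dyn BurnModuleAdapter>,
}

impl Default for LoadArgs {
    fn default() -> Self {
        Self { adapter: Box::new(PyTorchAdapter) }
    }
}

fn load_with(args: &LoadArgs, raw: BTreeMap<String, Vec<f32>>) -> BTreeMap<String, Vec<f32>> {
    raw.into_iter()
        .map(|(k, v)| (args.adapter.adapt_key(&k), v))
        .collect()
}

fn main() {
    let mut raw = BTreeMap::new();
    raw.insert("model.linear.weight".to_string(), vec![1.0]);
    let loaded = load_with(&LoadArgs::default(), raw);
    assert!(loaded.contains_key("linear.weight"));
    println!("ok");
}
```

The key design point is that the recorder stays format-specific while the transformation is pluggable, so a TensorFlow adapter could later be added without touching the safetensors reading path.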

Lastly, we should replicate PyTorch import tests to ensure comprehensive coverage. Over time, we can expand these tests to include SafeTensor files exported from TensorFlow.

One more thing: we should introduce a new feature flag, safetensors.

@antimora added the feature label Jan 27, 2025
@wandbrandon
Author

wandbrandon commented Jan 28, 2025

Hi all, I went through and essentially copied over the implementation of the pytorch recorder to create the safetensors recorder. It's a lot of new files that are essentially copied code with minor adjustments. I think this gives a good base for eventually removing the candle dependency and adding further safetensors support.

Member

@laggui laggui left a comment


Thanks for making the changes.

That is a lot of code duplication 😅 I don't think "replicate" was meant this way hahah

We should re-use the existing PyTorchAdapter when both pytorch and safetensors features are enabled. Definitely don't want to copy its implementation over. And by default we can use the existing DefaultAdapter in burn::record::serde::adapter to simply load the safetensor file as is.

For the tests, we also don't need to copy all the python scripts. We can just have the existing scripts under pytorch-tests save both in pickle and safetensor formats. And since the current tests added come from pytorch, we can add the additional safetensor tests to the existing pytorch tests under the safetensors feature flag guard. We would need the safetensor recorder to use the pytorch adapter anyway, so these tests can live under the pytorch-tests (with the addition of the safetensors feature flag). If we want to add standalone tests for models saved in safetensors that don't require any other transformations we could expand that to have safetensors-tests.
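The test-sharing suggestion above could be gated roughly like this (a sketch only; module and fixture names are assumptions, not the actual test layout):

```rust
// In pytorch-tests: the existing Python fixture scripts would save both
// .pt and .safetensors files, and the safetensors variants of the tests
// only compile when the feature is enabled.
#[cfg(feature = "safetensors")]
mod safetensors {
    // e.g. load the same fixture through the safetensors recorder
    // (with the PyTorchAdapter) and compare against the pickle-based import.
}
```

Standalone safetensors-tests would then be reserved for files that need no PyTorch-specific transformation.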

Collaborator

@antimora antimora left a comment


Thank you for taking this up. This feature will help many that requested this.

This is in the right direction, but we should try reducing code duplication. I would prioritize Rust code de-duplication first. We can leave the example and tests duplicated for now (because it will take time). It's up to @laggui to allow it.

I also suggest creating a new section under the book specifically for SafeTensors. You can talk about the transformation for model modules.

The last thing, don't forget to update LoadArgs for SafeTensorsFileRecorder per my previous comment.

@laggui
Member

laggui commented Jan 29, 2025

I would prioritize Rust code de-duplication first. We can leave the example and test duplicated for now (because it will take time).

Agreed for the bold part.

Labels
feature The feature request

4 participants