add support for safetensors in pytorch reader #2721
base: main
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff            @@
##             main    #2721     +/-  ##
=========================================
+ Coverage   83.60%   83.70%   +0.10%
=========================================
  Files         819      847      +28
  Lines      106600   109951    +3351
=========================================
+ Hits        89124    92036    +2912
- Misses      17476    17915     +439

☔ View full report in Codecov by Sentry.
Thanks for the addition 🙏
Looks pretty good overall, just some minor comments.
IMO maybe we can have something like pub mod safetensors; under a new feature gate in crate burn-import, so users could build exactly what they need.
I think this is a good point, and it also builds the scaffolding for potentially rewriting the reader to remove the Candle dependency.
I agree that the format is not strongly related to PyTorch, but I think most models available in safetensors format are PyTorch models 😅 Unless you mean supporting the safetensors format as another recorder to load and save modules. In that case, I'm not sure it would be a meaningful addition.
IMO maybe we can have something like

pub mod safetensors;

under a new feature gate in crate burn-import, so users could build exactly what they need, since safetensors does not seem to be a format strongly related to PyTorch.
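As a minimal sketch of that gating (the feature name and module path here are assumptions, not a final design):

```rust
// Hypothetical Cargo.toml entry for burn-import:
//
// [features]
// safetensors = []
//
// Then in src/lib.rs, compile the module only when the feature is enabled,
// so users who don't need safetensors support don't pay for it:
#[cfg(feature = "safetensors")]
pub mod safetensors;
```

Downstream crates would then opt in with something like burn-import = { features = ["safetensors"] }.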
Yes, I strongly agree with this opinion. My intention was that we have separate support for different formats. We can reuse common code in burn-import, but the entry point should be different. We will need a separate module and feature for GGUF files as well (see #1187), and it would follow the same pattern.

PyTorch's pt format and Safetensors are unrelated. We should not mix them up in code or documentation. It just happens that we're using Candle's reader, but it could be different.

P.S. Thanks for taking up this problem! It's been an often requested feature.
I suggest creating a dedicated SafeTensorsFileRecorder. Additionally, I propose providing configurable options (via LoadArgs). Lastly, we should replicate the PyTorch import tests to ensure comprehensive coverage. Over time, we can expand these tests to include safetensors files exported from TensorFlow. One more thing: we should introduce a new feature flag, safetensors.
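A rough sketch of how such a recorder might be used. SafeTensorsFileRecorder and its LoadArgs do not exist yet; the names and method signatures below are assumptions modeled on burn-import's existing PyTorchFileRecorder API:

```rust
// Hypothetical usage sketch only. SafeTensorsFileRecorder, LoadArgs, Device,
// and ModelRecord stand in for the eventual types; the key-remapping option
// mirrors what the existing PyTorch loader already offers.
use burn::record::FullPrecisionSettings;

fn load_record(device: &Device) -> ModelRecord {
    let args = LoadArgs::new("model.safetensors".into())
        // Optionally remap source tensor names to the Burn module's fields,
        // as the PyTorch loader already supports.
        .with_key_remap("fc\\.(\\d+)", "layer$1");
    SafeTensorsFileRecorder::<FullPrecisionSettings>::default()
        .load(args, device)
        .expect("failed to load safetensors file")
}
```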
Hi all, I went through and essentially copied over the implementation of the PyTorch recorder to create the safetensors recorder. It's a lot of new files that are essentially copied code with small adjustments. I think this gives a good base for removing the Candle dependency and adding further safetensors support in the future.
Thanks for making the changes.
That is a lot of code duplication 😅 I don't think "replicate" was meant this way haha.
We should re-use the existing PyTorchAdapter when both the pytorch and safetensors features are enabled. We definitely don't want to copy its implementation over. And by default we can use the existing DefaultAdapter in burn::record::serde::adapter to simply load the safetensors file as-is.

For the tests, we also don't need to copy all the Python scripts. We can just have the existing scripts under pytorch-tests save in both pickle and safetensors formats. And since the current tests added come from PyTorch, we can add the additional safetensors tests to the existing PyTorch tests under the safetensors feature flag guard. We would need the safetensors recorder to use the PyTorch adapter anyway, so these tests can live under pytorch-tests (with the addition of the safetensors feature flag). If we want to add standalone tests for models saved in safetensors that don't require any other transformations, we could expand that into a safetensors-tests crate.
Thank you for taking this up. This feature will help the many who requested it.

This is in the right direction, but we should try to reduce code duplication. I would prioritize de-duplicating the Rust code first. We can leave the examples and tests duplicated for now (because de-duplicating them will take time). It's up to @laggui to allow it.
I also suggest creating a new section in the book specifically for SafeTensors. You can describe the transformations for model modules there.
One last thing: don't forget to update LoadArgs for SafeTensorsFileRecorder, per my previous comment.
Agreed on the bold part.
Pull Request Template
Checklist
- The run-checks all script has been executed.

Related Issues/PRs
#626
Changes
Simple addition to the already implemented reader.rs, supporting the safetensors format using Candle with CPU device import.
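For reference, loading a safetensors file with Candle on the CPU looks roughly like the sketch below. The helper function name is made up for illustration, and the actual plumbing in reader.rs differs:

```rust
use std::collections::HashMap;
use candle_core::{Device, Result, Tensor};

/// Load every tensor in a safetensors file onto the CPU.
/// (Illustrative sketch of the Candle call; not the PR's exact code.)
fn load_safetensors(path: &str) -> Result<HashMap<String, Tensor>> {
    // candle_core::safetensors::load reads the whole file and returns the
    // tensors keyed by name, placed on the given device.
    candle_core::safetensors::load(path, &Device::Cpu)
}
```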
Testing
In the examples/pytorch-import directory, there is a mnist.safetensors file that is imported successfully.