Commit bb13729

Improve ONNX import book section (#2059)
* Improve ONNX importing section
* Update onnx-model.md
1 parent 62a30e9 commit bb13729

1 file changed: +159 -78 lines changed

burn-book/src/import/onnx-model.md
# Importing ONNX Models in Burn

## Table of Contents

1. [Introduction](#introduction)
2. [Why Import Models?](#why-import-models)
3. [Understanding ONNX](#understanding-onnx)
4. [Burn's ONNX Support](#burns-onnx-support)
5. [Step-by-Step Guide](#step-by-step-guide)
6. [Advanced Configuration](#advanced-configuration)
7. [Loading and Using Models](#loading-and-using-models)
8. [Troubleshooting](#troubleshooting)
9. [Examples and Resources](#examples-and-resources)
10. [Conclusion](#conclusion)

## Introduction

As the field of deep learning continues to evolve, the need for interoperability between different
frameworks becomes increasingly important. Burn, a modern deep learning framework in Rust,
recognizes this need and provides robust support for importing models from other popular
frameworks. This section focuses on importing
[ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html) models into Burn,
enabling you to leverage pre-trained models and seamlessly integrate them into your Rust-based
deep learning projects.

## Why Import Models?

Importing pre-trained models offers several advantages:

1. **Time-saving**: Avoid the need to train models from scratch, which can be time-consuming and
   resource-intensive.
2. **Access to state-of-the-art architectures**: Utilize cutting-edge models developed by
   researchers and industry leaders.
3. **Transfer learning**: Fine-tune imported models for your specific tasks, benefiting from
   knowledge transfer.
4. **Consistency across frameworks**: Ensure consistent performance when moving from one framework
   to another.

## Understanding ONNX

ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning
models. Key features include:

- **Framework agnostic**: ONNX provides a common format that works across various deep learning
  frameworks.
- **Comprehensive representation**: It captures both the model architecture and trained weights.
- **Wide support**: Many popular frameworks like PyTorch, TensorFlow, and scikit-learn support
  ONNX export.

By using ONNX, you can easily move models between different frameworks and deployment
environments.
42-
## Burn's ONNX Support: Importing Made Easy
52+
## Burn's ONNX Support
4353

44-
Understanding the important role that ONNX plays in the contemporary deep learning landscape, Burn
45-
simplifies the process of importing ONNX models via an intuitive API designed to mesh well with
46-
Burn's ecosystem.
54+
Burn takes a unique approach to ONNX import, offering several advantages:
4755

48-
Burn's solution is to translate ONNX files into Rust source code as well as Burn-compatible weights.
49-
This transformation is carried out through the burn-import crate's code generator during build time,
50-
providing advantages for both executing and further training ONNX models.
56+
1. **Native Rust code generation**: ONNX models are translated into Rust source code, allowing for
57+
deep integration with Burn's ecosystem.
58+
2. **Compile-time optimization**: The generated Rust code can be optimized by the Rust compiler,
59+
potentially improving performance.
60+
3. **No runtime dependency**: Unlike some solutions that require an ONNX runtime, Burn's approach
61+
eliminates this dependency.
62+
4. **Trainability**: Imported models can be further trained or fine-tuned using Burn.
63+
5. **Portability**: The generated Rust code can be compiled for various targets, including
64+
WebAssembly and embedded devices.
65+
6. **Any Burn Backend**: The imported models can be used with any of Burn's backends.
5166

## Step-by-Step Guide

Let's walk through the process of importing an ONNX model into a Burn project.

### Step 1: Update `build.rs`

First, add the `burn-import` crate to your `Cargo.toml`:

```toml
[build-dependencies]
burn-import = "0.14.0"
```

Then, in your `build.rs` file:

```rust
use burn_import::onnx::ModelGen;

fn main() {
    // Generate Rust code from the ONNX model file at build time
    ModelGen::new()
        .input("src/model/my_model.onnx")
        .out_dir("model/")
        .run_from_script();
}
```

This script uses `ModelGen` to generate Rust code from your ONNX model during the build process.
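By default, Cargo reruns `build.rs` whenever any file in the package changes, so the model may be
regenerated more often than necessary. If that becomes noticeable, you can scope reruns to the
model file itself. A minimal sketch, reusing the example path from above:

```rust
use burn_import::onnx::ModelGen;

fn main() {
    // Only rerun this build script when the ONNX file itself changes.
    println!("cargo:rerun-if-changed=src/model/my_model.onnx");

    ModelGen::new()
        .input("src/model/my_model.onnx")
        .out_dir("model/")
        .run_from_script();
}
```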

### Step 2: Modify `mod.rs`

In your `src/model/mod.rs` file, include the generated code:

```rust
pub mod my_model {
    include!(concat!(env!("OUT_DIR"), "/model/my_model.rs"));
}
```

This makes the generated model code available in your project.
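If you prefer a shorter path at call sites, you can also re-export the generated type from the
module. An optional sketch (nothing in the generated code requires this):

```rust
// src/model/mod.rs
pub mod my_model {
    include!(concat!(env!("OUT_DIR"), "/model/my_model.rs"));
}

// Optional: lets callers write `model::Model` instead of `model::my_model::Model`.
pub use my_model::Model;
```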

### Step 3: Use the Imported Model

Now you can use the imported model in your Rust code:

```rust
mod model;

use burn::tensor;
use burn_ndarray::{NdArray, NdArrayDevice};
use model::my_model::Model;

fn main() {
    let device = NdArrayDevice::default();

    // Create a model instance and load its weights from the build output directory
    // onto the default device (see more load options in "Loading and Using Models" below).
    let model: Model<NdArray<f32>> = Model::default();

    // Create an input tensor (replace with your actual input)
    let input = tensor::Tensor::<NdArray<f32>, 4>::zeros([1, 3, 224, 224], &device);

    // Perform inference
    let output = model.forward(input);

    println!("Model output: {:?}", output);
}
```
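Before wiring the model into a larger application, a quick smoke test that only checks the output
shape can catch wiring mistakes early. A sketch under stated assumptions: the hypothetical
`my_model` above takes a `[1, 3, 224, 224]` input and returns a rank-2 tensor with 1000 classes;
adjust both to your model:

```rust
#[cfg(test)]
mod tests {
    use burn::tensor::Tensor;
    use burn_ndarray::{NdArray, NdArrayDevice};

    use crate::model::my_model::Model;

    #[test]
    fn forward_produces_expected_shape() {
        let device = NdArrayDevice::default();
        // Loads the weights that were generated into the build output directory.
        let model: Model<NdArray<f32>> = Model::default();

        // A zero-filled input is enough to exercise the graph and verify shapes.
        let input = Tensor::<NdArray<f32>, 4>::zeros([1, 3, 224, 224], &device);
        let output = model.forward(input);

        // Hypothetical classifier head with 1000 classes; change to match your model.
        assert_eq!(output.dims(), [1, 1000]);
    }
}
```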

## Advanced Configuration

The `ModelGen` struct offers several configuration options:

```rust
ModelGen::new()
    .input("path/to/model.onnx")
    .out_dir("model/")
    .record_type(RecordType::NamedMpk)
    .half_precision(false)
    .embed_states(false)
    .run_from_script();
```
- `record_type`: Specifies the format for storing weights (Bincode, NamedMpk, NamedMpkGz, or
  PrettyJson).
- `half_precision`: Use half-precision (f16) weights to reduce model size.
- `embed_states`: Embed model weights directly in the generated Rust code. Note: this requires the
  `Bincode` record type (see the sketch after this list).
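As an example of combining these options, a self-contained binary for an embedded or WebAssembly
target might embed half-precision weights directly in the code. A sketch, using the same
hypothetical model path as earlier:

```rust
// In build.rs: embed f16 weights directly in the generated Rust code.
// `embed_states(true)` requires the Bincode record type, as noted above.
ModelGen::new()
    .input("src/model/my_model.onnx")
    .out_dir("model/")
    .record_type(RecordType::Bincode)
    .half_precision(true)
    .embed_states(true)
    .run_from_script();
```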
## Loading and Using Models

Depending on your configuration, you can load models in different ways:

```rust
// Create a new model instance on the given device. Weights are initialized randomly
// and lazily; you can load trained weights via `load_record` afterwards.
let model = Model::<Backend>::new(&device);

// Load weights from a file (the weights file must be in the target output directory,
// or copied from there). The file type must match the record type specified in `ModelGen`.
let model = Model::<Backend>::from_file("path/to/weights", &device);

// Load from embedded weights (if `embed_states` was true)
let model = Model::<Backend>::from_embedded();

// Load from the output directory onto the default device (useful for testing)
let model = Model::<Backend>::default();
```
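If you want `new` plus explicitly loaded weights, you can go through Burn's record API. A sketch,
assuming the `NamedMpk` record type from the configuration above, full-precision weights, and a
placeholder path:

```rust
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};

// Load the record from a weights file, then attach it to a freshly created model.
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
let record = recorder
    .load("path/to/weights".into(), &device)
    .expect("Failed to load model record");
let model = Model::<Backend>::new(&device).load_record(record);
```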
## Troubleshooting

Here are some common issues and their solutions:

1. **Unsupported ONNX operator**: If you encounter an error about an unsupported operator, check
   the
   [list of supported ONNX operators](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).
   You may need to simplify your model or wait for support to be added.

2. **Build errors**: Ensure that your `burn-import` version matches your Burn version. Also, check
   that the ONNX file path in `build.rs` is correct.

3. **Runtime errors**: If you get errors when running your model, double-check that your input
   tensors match the expected shape and data type of your model.

4. **Performance issues**: If your imported model is slower than expected, try using the
   `half_precision` option to reduce memory usage, or experiment with different `record_type`
   options.

5. **Artifact files**: You can inspect the generated Rust code and weights files in the `OUT_DIR`
   directory specified in `build.rs` (usually `target/debug/build/<project>/out`).
## Examples and Resources

For more detailed examples, check out:

1. [MNIST Inference Example](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference)
2. [SqueezeNet Image Classification](https://github.com/tracel-ai/models/tree/main/squeezenet-burn)

These examples demonstrate real-world usage of ONNX import in Burn projects.
## Conclusion

Importing ONNX models into Burn opens up a world of possibilities, allowing you to leverage
pre-trained models from other frameworks while taking advantage of Burn's performance and Rust's
safety features. By following this guide, you should be able to seamlessly integrate ONNX models
into your Burn projects, whether for inference, fine-tuning, or as a starting point for further
development.

Remember that the `burn-import` crate is actively developed, with ongoing work to support more
ONNX operators and improve performance. Stay tuned to the Burn repository for updates and new
features!

---

> 🚨**Note**: The `burn-import` crate is in active development. For the most up-to-date information
> on supported ONNX operators, please refer to the
> [official documentation](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).
