# Importing ONNX Models in Burn

## Table of Contents

1. [Introduction](#introduction)
2. [Why Import Models?](#why-import-models)
3. [Understanding ONNX](#understanding-onnx)
4. [Burn's ONNX Support](#burns-onnx-support)
5. [Step-by-Step Guide](#step-by-step-guide)
6. [Advanced Configuration](#advanced-configuration)
7. [Loading and Using Models](#loading-and-using-models)
8. [Troubleshooting](#troubleshooting)
9. [Examples and Resources](#examples-and-resources)
10. [Conclusion](#conclusion)

## Introduction

As the field of deep learning continues to evolve, the need for interoperability between different
frameworks becomes increasingly important. Burn, a modern deep learning framework in Rust,
recognizes this need and provides robust support for importing models from other popular frameworks.
This section focuses on importing
[ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html) models into Burn,
enabling you to leverage pre-trained models and seamlessly integrate them into your Rust-based deep
learning projects.

## Why Import Models?

Importing pre-trained models offers several advantages:

1. **Time-saving**: Avoid the need to train models from scratch, which can be time-consuming and
   resource-intensive.
2. **Access to state-of-the-art architectures**: Utilize cutting-edge models developed by
   researchers and industry leaders.
3. **Transfer learning**: Fine-tune imported models for your specific tasks, benefiting from
   knowledge transfer.
4. **Consistency across frameworks**: Ensure consistent performance when moving from one framework
   to another.

## Understanding ONNX

ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning models.
Key features include:

- **Framework agnostic**: ONNX provides a common format that works across various deep learning
  frameworks.
- **Comprehensive representation**: It captures both the model architecture and trained weights.
- **Wide support**: Many popular frameworks like PyTorch, TensorFlow, and scikit-learn support ONNX
  export.

By using ONNX, you can easily move models between different frameworks and deployment environments.

## Burn's ONNX Support

Burn takes a unique approach to ONNX import, offering several advantages:

1. **Native Rust code generation**: ONNX models are translated into Rust source code, allowing for
   deep integration with Burn's ecosystem.
2. **Compile-time optimization**: The generated Rust code can be optimized by the Rust compiler,
   potentially improving performance.
3. **No runtime dependency**: Unlike some solutions that require an ONNX runtime, Burn's approach
   eliminates this dependency.
4. **Trainability**: Imported models can be further trained or fine-tuned using Burn.
5. **Portability**: The generated Rust code can be compiled for various targets, including
   WebAssembly and embedded devices.
6. **Any Burn Backend**: The imported models can be used with any of Burn's backends, as the
   sketch below illustrates.
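
Because the generated model code is generic over the backend, switching hardware is a type-level
change. Below is a minimal sketch of this flexibility; it assumes Burn's `wgpu` feature is enabled
and that `my_model` has been generated as shown in the step-by-step guide that follows:

```rust
use burn::backend::wgpu::{Wgpu, WgpuDevice};
use model::my_model::Model;

fn run_on_gpu() {
    // The same generated model, now parameterized by the WGPU backend.
    let device = WgpuDevice::default();
    let model: Model<Wgpu> = Model::new(&device);
    // Load weights and call `model.forward(...)` as usual.
}
```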

## Step-by-Step Guide

Let's walk through the process of importing an ONNX model into a Burn project:

### Step 1: Update `build.rs`

First, add the `burn-import` crate to your `Cargo.toml`:

```toml
[build-dependencies]
burn-import = "0.14.0"
```

Then, in your `build.rs` file:

```rust
use burn_import::onnx::ModelGen;

fn main() {
    // Generate Rust code from the ONNX model file
    ModelGen::new()
        .input("src/model/my_model.onnx")
        .out_dir("model/")
        .run_from_script();
}
```

This script uses `ModelGen` to generate Rust code from your ONNX model during the build process.
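
If your project contains more than one ONNX model, `ModelGen` can simply be run once per file. A
minimal sketch with hypothetical file names, generating both models into the same output directory:

```rust
use burn_import::onnx::ModelGen;

fn main() {
    // One invocation per ONNX file; each generated Rust file is named after
    // its input (e.g. `encoder.rs`, `decoder.rs`).
    for onnx in ["src/model/encoder.onnx", "src/model/decoder.onnx"] {
        ModelGen::new()
            .input(onnx)
            .out_dir("model/")
            .run_from_script();
    }
}
```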

### Step 2: Modify `mod.rs`

In your `src/model/mod.rs` file, include the generated code:

```rust
pub mod my_model {
    include!(concat!(env!("OUT_DIR"), "/model/my_model.rs"));
}
```

This makes the generated model code available in your project, provided the `model` module itself
is declared (e.g. `mod model;` in `main.rs`).

### Step 3: Use the Imported Model

Now you can use the imported model in your Rust code:

```rust
mod model;

use burn::tensor;
use burn_ndarray::{NdArray, NdArrayDevice};
use model::my_model::Model;

fn main() {
    let device = NdArrayDevice::default();

    // Create a model instance and load the weights generated into the target
    // directory, using the default device (see more loading options in the
    // "Loading and Using Models" section below).
    let model: Model<NdArray<f32>> = Model::default();

    // Create an input tensor (replace with your actual input)
    let input = tensor::Tensor::<NdArray<f32>, 4>::zeros([1, 3, 224, 224], &device);

    // Perform inference
    let output = model.forward(input);

    println!("Model output: {:?}", output);
}
```

## Advanced Configuration

The `ModelGen` struct offers several configuration options:

```rust
ModelGen::new()
    .input("path/to/model.onnx")
    .out_dir("model/")
    .record_type(RecordType::NamedMpk)
    .half_precision(false)
    .embed_states(false)
    .run_from_script();
```

- `record_type`: Specifies the format for storing weights (Bincode, NamedMpk, NamedMpkGz, or
  PrettyJson).
- `half_precision`: Use half-precision (f16) for weights to reduce model size.
- `embed_states`: Embed model weights directly in the generated Rust code. Note: This requires
  record type `Bincode`.
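
For example, to build a self-contained binary with the weights compiled into the executable,
combine `embed_states` with the `Bincode` record type. A sketch varying the options above:

```rust
ModelGen::new()
    .input("path/to/model.onnx")
    .out_dir("model/")
    .record_type(RecordType::Bincode) // embed_states requires Bincode
    .embed_states(true)
    .run_from_script();
```

The embedded weights can then be loaded with `Model::from_embedded()`, as shown in the next
section.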

## Loading and Using Models

Depending on your configuration, you can load models in different ways:

```rust
// Create a new model instance with the given device. Weights are initialized
// randomly and lazily; you can load weights via `load_record` afterwards.
let model = Model::<Backend>::new(&device);

// Load weights from a file (the weights file is produced in the build output
// directory; point at it there or copy it elsewhere). The file type must
// match the record type specified in `ModelGen`.
let model = Model::<Backend>::from_file("path/to/weights", &device);

// Load from embedded weights (if `embed_states` was true)
let model = Model::<Backend>::from_embedded();

// Load from the output directory onto the default device (useful for testing)
let model = Model::<Backend>::default();
```
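
If you create the model with `new` and want to attach weights afterwards, you can load a record
explicitly and pass it to `load_record`. A minimal sketch, assuming the weights were saved with the
`NamedMpk` record type at full precision:

```rust
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};

// Load the record produced by the build script, then attach it to the model.
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
let record = recorder
    .load("path/to/weights".into(), &device)
    .expect("Failed to load model record");
let model = Model::<Backend>::new(&device).load_record(record);
```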

## Troubleshooting

Here are some common issues and their solutions:

1. **Unsupported ONNX operator**: If you encounter an error about an unsupported operator, check the
   [list of supported ONNX operators](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).
   You may need to simplify your model or wait for support to be added.

2. **Build errors**: Ensure that your `burn-import` version matches your Burn version. Also, check
   that the ONNX file path in `build.rs` is correct.

3. **Runtime errors**: If you get errors when running your model, double-check that your input
   tensors match the expected shape and data type of your model.

4. **Performance issues**: If your imported model is slower than expected, try using the
   `half_precision` option to reduce memory usage, or experiment with different `record_type`
   options.

5. **Artifact files**: You can view the generated Rust code and weights files in the `OUT_DIR`
   directory specified in `build.rs` (usually `target/debug/build/<project>/out`).

## Examples and Resources

For more detailed examples, check out:

1. [MNIST Inference Example](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference)
2. [SqueezeNet Image Classification](https://github.com/tracel-ai/models/tree/main/squeezenet-burn)

These examples demonstrate real-world usage of ONNX import in Burn projects.

## Conclusion

Importing ONNX models into Burn opens up a world of possibilities, allowing you to leverage
pre-trained models from other frameworks while taking advantage of Burn's performance and Rust's
safety features. By following this guide, you should be able to seamlessly integrate ONNX models
into your Burn projects, whether for inference, fine-tuning, or as a starting point for further
development.

Remember that the `burn-import` crate is actively developed, with ongoing work to support more ONNX
operators and improve performance. Stay tuned to the Burn repository for updates and new features!

---

> 🚨 **Note**: The `burn-import` crate is in active development. For the most up-to-date information
> on supported ONNX operators, please refer to the
> [official documentation](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).