* Add q_into_data and q_reshape
* Fix tch quantize f16 and q_into_data
* Convert to actual dtype/kind in dequantize
* Add module quantization and q_from_data
* Fix clippy
* Add documentation
* Handle deserialize data conversion
* Fix typo
* Add calibration tests
* Fix clippy precision
* Add QTensorOps require_grad methods to avoid dequantizing
* Add Dequantize mapper docs
* Remove dead code
Showing 22 changed files with 618 additions and 31 deletions.
@@ -0,0 +1,122 @@
# Quantization (Beta)

Quantization techniques perform computations and store tensors in lower precision data types like
8-bit integer instead of floating point precision. There are multiple approaches to quantize a deep
learning model, categorized as:

- Post-training quantization (PTQ)
- Quantization aware training (QAT)

In post-training quantization, the model is trained in floating point precision and later converted
to the lower precision data type.

There are two types of post-training quantization:

1. Static quantization: quantizes the weights and activations of the model. Quantizing the
   activations statically requires calibration (i.e., recording the activation values on
   representative data to compute the optimal quantization parameters).
2. Dynamic quantization: quantizes the weights ahead of time (like static quantization), but the
   activations are quantized dynamically at runtime.

Sometimes post-training quantization is not able to achieve acceptable task accuracy. This is where
quantization aware training comes into play, as it models the effects of quantization during
training. Quantization errors are thus modeled in the forward and backward passes using fake
quantization modules, which helps the model learn representations that are more robust to the
reduction in precision.

<div class="warning">

Quantization support in Burn is currently in active development.

It supports the following modes on some backends:

- Static per-tensor quantization to signed 8-bit integer (`i8`)

No integer operations are currently supported, which means tensors are dequantized to perform the
operations in floating point precision.

</div>

## Module Quantization

Quantizing the weights of your model after training is quite simple. We have access to the weight
tensors and can collect their statistics, such as the min and max value when using
`MinMaxCalibration`, to compute the quantization parameters.

```rust, ignore
# use burn::quantization::{MinMaxCalibration, QuantizationScheme, QuantizationType, Quantizer};
#
// Quantization config
let mut quantizer = Quantizer {
    calibration: MinMaxCalibration {
        scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
    },
};

// Quantize the weights
let model = model.quantize_weights(&mut quantizer);
```

> Given that all operations are currently performed in floating point precision, it might be wise to
> dequantize the module parameters before inference. This allows us to save disk space by storing
> the model in reduced precision while preserving the inference speed.
>
> This can easily be implemented with a `ModuleMapper`.
>
> ```rust, ignore
> # use burn::module::{ModuleMapper, ParamId};
> # use burn::tensor::{backend::Backend, Tensor};
> #
> /// Module mapper used to dequantize the model params being loaded.
> pub struct Dequantize {}
>
> impl<B: Backend> ModuleMapper<B> for Dequantize {
>     fn map_float<const D: usize>(
>         &mut self,
>         _id: &ParamId,
>         tensor: Tensor<B, D>,
>     ) -> Tensor<B, D> {
>         tensor.dequantize()
>     }
> }
>
> // Load saved quantized model in floating point precision
> model = model
>     .load_file(file_path, recorder, &device)
>     .expect("Should be able to load the quantized model weights")
>     .map(&mut Dequantize {});
> ```

### Calibration

Calibration is the step during quantization where the range of all floating-point tensors is
computed. This is pretty straightforward for weights since the actual range is known at
_quantization-time_ (weights are static), but activations require more attention.

To compute the quantization parameters, Burn supports the following `Calibration` methods.

| Method              | Description                                                                       |
| :------------------ | :-------------------------------------------------------------------------------- |
| `MinMaxCalibration` | Computes the quantization range mapping based on the running min and max values.  |
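
Additional calibration methods can be built on top of the `Calibration` trait introduced in this
commit. Below is a minimal sketch of a calibration that ignores the observed tensor and quantizes
with a fixed, user-provided range; the import paths and the `FixedRangeCalibration` name are
assumptions made for illustration and are not part of Burn.

```rust, ignore
# use burn::quantization::Calibration;
# use burn::tensor::{backend::Backend, QuantizationStrategy, SymmetricQuantization, Tensor};
#
/// Hypothetical calibration that uses a fixed, user-provided range instead of
/// the observed min/max values.
pub struct FixedRangeCalibration {
    pub min: f32,
    pub max: f32,
}

impl Calibration for FixedRangeCalibration {
    fn configure<B: Backend, const D: usize>(
        &self,
        _tensor: &Tensor<B, D>,
    ) -> QuantizationStrategy {
        // Always map the fixed range symmetrically onto 8-bit signed integers.
        QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::new(self.min, self.max))
    }
}
```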

### Quantization Scheme

A quantization scheme defines the quantized type, quantization granularity and range mapping
technique.

Burn currently supports the following `QuantizationType` variants.

| Type    | Description                        |
| :------ | :--------------------------------- |
| `QInt8` | 8-bit signed integer quantization. |

Quantization parameters are defined based on the range of values to represent and can typically be
calculated for the layer's entire weight tensor with per-tensor quantization or separately for each
channel with per-channel quantization (commonly used with CNNs).

Burn currently supports the following `QuantizationScheme` variants.

| Variant              | Description                                                                                                     |
| :------------------- | :--------------------------------------------------------------------------------------------------------------- |
| `PerTensorAffine`    | Computes the quantization parameters for the whole tensor and applies an affine range mapping with zero point.  |
| `PerTensorSymmetric` | Computes the quantization parameters for the whole tensor and applies a scale range mapping centered around 0.  |
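
To make the difference between the two schemes concrete, here is a rough sketch of how the
per-tensor parameters can be derived for the sample tensor `[-1.8, -1.0, 0.0, 0.5]` used in the
calibration tests added in this commit. The exact formulas and rounding conventions are assumptions
based on standard 8-bit quantization, not a verbatim copy of Burn's internals.

```rust, ignore
// Range observed by MinMaxCalibration on [-1.8, -1.0, 0.0, 0.5].
let (min, max) = (-1.8f32, 0.5f32);

// Per-tensor affine (asymmetric): the full [min, max] range is mapped onto
// i8 values in [-128, 127] using a scale and a zero point (offset).
let scale_affine = (max - min) / 255.0; // ~0.009_019_608
let offset = (-128.0 - min / scale_affine).round(); // 72.0

// Per-tensor symmetric: the range is centered around 0, so only a scale is
// needed and the zero point is implicitly 0.
let scale_symmetric = min.abs().max(max.abs()) / 127.0; // ~0.014_173_228
```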
@@ -0,0 +1,80 @@
use burn_tensor::{
    backend::Backend, AffineQuantization, ElementConversion, Quantization, QuantizationStrategy,
    SymmetricQuantization, Tensor,
};

use super::{QuantizationScheme, QuantizationType};

/// Calibration method used to compute the quantization range mapping.
pub trait Calibration {
    /// Configure the quantization strategy.
    fn configure<B: Backend, const D: usize>(&self, tensor: &Tensor<B, D>) -> QuantizationStrategy;
}

/// Computes the quantization range mapping based on the running min and max values.
pub struct MinMaxCalibration {
    /// Quantization scheme to be used.
    pub scheme: QuantizationScheme,
}

impl Calibration for MinMaxCalibration {
    fn configure<B: Backend, const D: usize>(&self, tensor: &Tensor<B, D>) -> QuantizationStrategy {
        let min = tensor.clone().min().into_scalar().elem::<f32>();
        let max = tensor.clone().max().into_scalar().elem::<f32>();

        match &self.scheme {
            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
                QuantizationType::QInt8 => {
                    QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::new(min, max))
                }
            },
            QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
                QuantizationType::QInt8 => QuantizationStrategy::PerTensorSymmetricInt8(
                    SymmetricQuantization::new(min, max),
                ),
            },
        }
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::TestBackend;

    #[test]
    fn min_max_calibration_per_tensor_affine_int8() {
        let device = <TestBackend as Backend>::Device::default();
        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
        let calibration = MinMaxCalibration {
            scheme: QuantizationScheme::PerTensorAffine(QuantizationType::QInt8),
        };

        let strategy = calibration.configure(&tensor);

        if let QuantizationStrategy::PerTensorAffineInt8(q) = strategy {
            assert_eq!(q.scale, 0.009_019_608);
            assert_eq!(q.offset, 72);
        } else {
            panic!("Wrong quantization strategy");
        }
    }

    #[test]
    fn min_max_calibration_per_tensor_symmetric_int8() {
        let device = <TestBackend as Backend>::Device::default();
        let tensor = Tensor::<TestBackend, 1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);
        let calibration = MinMaxCalibration {
            scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
        };

        let strategy = calibration.configure(&tensor);

        if let QuantizationStrategy::PerTensorSymmetricInt8(q) = strategy {
            assert_eq!(q.scale, 0.014_173_228);
        } else {
            panic!("Wrong quantization strategy");
        }
    }
}
@@ -0,0 +1,7 @@
mod calibration;
mod quantize;
mod scheme;

pub use calibration::*;
pub use quantize::*;
pub use scheme::*;
@@ -0,0 +1,18 @@
use burn_tensor::{backend::Backend, Tensor};

use crate::module::{ModuleMapper, ParamId};

use super::Calibration;

/// Describes how to quantize a module.
pub struct Quantizer<C: Calibration> {
    /// The calibration method used in quantization.
    pub calibration: C,
}

impl<B: Backend, C: Calibration> ModuleMapper<B> for Quantizer<C> {
    fn map_float<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
        let strategy = self.calibration.configure(&tensor);
        tensor.quantize(strategy)
    }
}
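
Because `Quantizer` implements `ModuleMapper`, it can be applied to a module with `Module::map`,
mirroring the `Dequantize` mapper shown in the documentation above; the
`model.quantize_weights(&mut quantizer)` call presumably wraps such a mapping. A minimal usage
sketch follows, where the surrounding imports and the concrete `model` are assumptions for
illustration.

```rust, ignore
// Hedged sketch: apply the Quantizer mapper directly via Module::map.
let mut quantizer = Quantizer {
    calibration: MinMaxCalibration {
        scheme: QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8),
    },
};
let model = model.map(&mut quantizer);
```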
@@ -0,0 +1,17 @@
/// Quantization data type.
pub enum QuantizationType {
    /// 8-bit signed integer.
    QInt8,
}

/// Quantization scheme.
pub enum QuantizationScheme {
    /// Per-tensor affine/asymmetric quantization.
    PerTensorAffine(QuantizationType),
    /// Per-tensor symmetric quantization.
    PerTensorSymmetric(QuantizationType),
    // /// Per-channel affine/asymmetric quantization.
    // PerChannelAffine,
    // /// Per-channel symmetric quantization.
    // PerChannelSymmetric,
}