add rms norm layer (#1527)
ashdtu authored Mar 25, 2024
1 parent ea72638 commit a77979e
Showing 3 changed files with 97 additions and 0 deletions.
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/module.md
@@ -112,6 +112,7 @@ Burn comes with built-in modules that you can use to build your own modules.
| `LayerNorm` | `nn.LayerNorm` |
| `GroupNorm` | `nn.GroupNorm` |
| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. |
| `RmsNorm` | _No direct equivalent_ |
| `Dropout` | `nn.Dropout` |
| `Gelu` | `nn.GELU` |
| `Prelu` | `nn.PReLU` |
2 changes: 2 additions & 0 deletions crates/burn-core/src/nn/norm/mod.rs
@@ -2,8 +2,10 @@ mod batch;
mod group;
mod instance;
mod layer;
mod rms;

pub use batch::*;
pub use group::*;
pub use instance::*;
pub use layer::*;
pub use rms::*;
94 changes: 94 additions & 0 deletions crates/burn-core/src/nn/norm/rms.rs
@@ -0,0 +1,94 @@
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Configuration to create a [RMS Norm](RmsNorm) layer.
#[derive(Config)]
pub struct RmsNormConfig {
    /// The size of the input features.
    d_model: usize,
    /// A value required for numerical stability. Default: 1e-5
    #[config(default = 1e-5)]
    epsilon: f64,
}

impl RmsNormConfig {
    /// Initialize a new [RMS Norm](RmsNorm) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
        assert!(self.epsilon > 0.0, "epsilon must be positive.");

        let gamma = Tensor::ones([self.d_model], device);

        RmsNorm {
            gamma: Param::from(gamma),
            epsilon: self.epsilon,
        }
    }

    /// Initialize a new [RMS Norm](RmsNorm) module with a [record](RmsNormRecord).
    pub fn init_with<B: Backend>(&self, record: RmsNormRecord<B>) -> RmsNorm<B> {
        RmsNorm {
            gamma: record.gamma,
            epsilon: self.epsilon,
        }
    }
}

/// Applies RMS Normalization over an input tensor along the last dimension.
///
/// `Y = X / sqrt(mean(X^2) + eps) * gamma`
///
/// where `eps` is a small value to avoid division by zero.
#[derive(Module, Debug)]
pub struct RmsNorm<B: Backend> {
    /// The learnable parameter to scale the normalized tensor
    gamma: Param<Tensor<B, 1>>,
    /// A value required for numerical stability
    epsilon: f64,
}

impl<B: Backend> RmsNorm<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any, d_model]`
    /// - output: `[..., any, d_model]`
    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        // Calculate the root-mean-square norm of the input tensor along the last dimension
        let rms = (x.clone().powf_scalar(2.0).mean_dim(D - 1) + self.epsilon).sqrt();
        (x / rms) * self.gamma.val().unsqueeze()
    }
}
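// Worked check for the formula above, using the first row of the test
// input below: for x = [0., 1., 2.], mean(x^2) = (0 + 1 + 4) / 3 = 5/3,
// sqrt(5/3 + 1e-5) ≈ 1.2910, and x / 1.2910 ≈ [0.0000, 0.7746, 1.5492],
// which is the first row of the expected output (gamma is ones at init,
// so it leaves the normalized values unchanged).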

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn_tensor::Data;

    #[test]
    fn rms_norm_forward() {
        let device = Default::default();
        let module = RmsNormConfig::new(3)
            .with_epsilon(1e-5)
            .init::<TestBackend>(&device);

        let input = Tensor::arange(0..9, &device).float().reshape([3, 3]);

        let output = module.forward(input);

        output.to_data().assert_approx_eq(
            &Data::from([
                [0.0000, 0.7746, 1.5492],
                [0.7348, 0.9798, 1.2247],
                [0.8514, 0.9933, 1.1352],
            ]),
            4,
        );
    }
}
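
A minimal usage sketch (not part of this commit): assuming the layer is re-exported as `burn::nn::RmsNorm` alongside the other normalization layers, it can be built from its config and applied to any tensor whose last dimension equals `d_model`, reusing `d_model = 3` from the test above.

use burn::nn::{RmsNorm, RmsNormConfig};
use burn::tensor::{backend::Backend, Tensor};

// Build the layer from its config (epsilon keeps the 1e-5 default) and
// normalize a `[2, 3]` tensor; any rank works as long as the last
// dimension equals `d_model`.
fn normalize<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
    let norm: RmsNorm<B> = RmsNormConfig::new(3).init(device);
    let x = Tensor::<B, 2>::ones([2, 3], device);
    norm.forward(x)
}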
