
Commit 8053001

Fix LayerNorm normalization. (#2186)
Fixes #2185.
1 parent c29ed43 commit 8053001
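For context, LayerNorm adds epsilon to the variance inside the square root before normalizing:

    $$ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta $$

Before this commit, the forward pass computed (x − E[x]) / (√Var[x] + ε), adding epsilon after the square root. With a tiny epsilon (the usual default is 1e-5) the two expressions are numerically close, which is how the bug went unnoticed; a larger epsilon makes the discrepancy visible, so the new regression test below uses ε = 1e-1.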

File tree: 1 file changed, +22 −1 lines changed


crates/burn-core/src/nn/norm/layer.rs (+22 −1)
@@ -67,7 +67,7 @@ impl<B: Backend> LayerNorm<B> {
     pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
         let (var, mean) = input.clone().var_mean_bias(D - 1);
 
-        let input_normalized = input.sub(mean).div(var.sqrt().add_scalar(self.epsilon));
+        let input_normalized = input.sub(mean).div(var.add_scalar(self.epsilon).sqrt());
 
         input_normalized
             .mul(self.gamma.val().unsqueeze())
@@ -122,6 +122,27 @@ mod tests {
         output.to_data().assert_approx_eq(&expected, 3);
     }
 
+    #[test]
+    fn layer_norm_forward_large_epsilon() {
+        let device = Default::default();
+        let module = LayerNormConfig::new(10)
+            .with_epsilon(1e-1)
+            .init::<TestBackend>(&device);
+        let input = Tensor::<TestBackend, 2>::from_data(
+            TensorData::from([[
+                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
+            ]]),
+            &device,
+        );
+
+        let output = module.forward(input);
+
+        let expected = TensorData::from([[
+            -0.4863, -1.9180, 1.5766, -0.7295, -0.6305, 0.8358, 0.0449, 1.0828, -0.2548, 0.4790,
+        ]]);
+        output.to_data().assert_approx_eq(&expected, 3);
+    }
+
     #[cfg(feature = "std")]
     #[test]
     fn layer_norm_backward() {
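To see the numerical difference the fix makes, here is a minimal standalone sketch in plain Rust (no Burn dependency). It recomputes the new test by hand, assuming a freshly initialized module where gamma = 1 and beta = 0 so the affine step drops out; the eps = 1e-1 value and the input vector mirror layer_norm_forward_large_epsilon.

// Contrast the two epsilon placements on the test vector above.
fn main() {
    let x: [f32; 10] = [
        -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,
    ];
    let eps = 1e-1_f32;

    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    // Biased variance (divide by N), matching `var_mean_bias`.
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;

    // Fixed behavior: epsilon added to the variance before the square root.
    let fixed: Vec<f32> = x.iter().map(|v| (v - mean) / (var + eps).sqrt()).collect();
    // Pre-commit behavior: epsilon added after the square root.
    let buggy: Vec<f32> = x.iter().map(|v| (v - mean) / (var.sqrt() + eps)).collect();

    println!("fixed: {fixed:.4?}"); // matches the expected values in the new test
    println!("buggy: {buggy:.4?}"); // e.g. first element ≈ -0.4652 instead of -0.4863
}

With the default eps = 1e-5 the two formulas agree to several decimal places, which is why the existing default-epsilon test kept passing; the large-epsilon test is what pins the correct placement down.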
