
Commit 88b4420

feature(norm): Add GroupNorm (#963)

* Add GroupNorm
* Fix implementation and add tests
* Address PR comments
* Fix formatting
* Update burn book

1 parent 4711db0 commit 88b4420

File tree

burn-book/src/building-blocks/module.md
burn-core/src/nn/norm/group.rs
burn-core/src/nn/norm/mod.rs

3 files changed: +255 −0 lines changed

burn-book/src/building-blocks/module.md (+1)

```diff
@@ -111,6 +111,7 @@ Burn comes with built-in modules that you can use to build your own modules.
 | ----------- | --------------------------------------- |
 | `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. |
 | `LayerNorm` | `nn.LayerNorm` |
+| `GroupNorm` | `nn.GroupNorm` |
 | `Dropout` | `nn.Dropout` |
 | `GELU` | `nn.GELU` |
 | `Linear` | `nn.Linear` |
```
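The new table row maps Burn's `GroupNorm` onto PyTorch's `nn.GroupNorm`. As a rough sketch of that correspondence (the `burn::nn` import path is assumed from the crate's usual re-exports, and the builder methods follow the `#[derive(Config)]` pattern used in the file below), the counterpart of `torch.nn.GroupNorm(num_groups=2, num_channels=6)` would look like:

```rust
use burn::nn::{GroupNorm, GroupNormConfig}; // assumed re-export path
use burn::tensor::backend::Backend;

/// Sketch: Burn counterpart of `torch.nn.GroupNorm(num_groups=2, num_channels=6)`.
fn make_group_norm<B: Backend>() -> GroupNorm<B> {
    GroupNormConfig::new(2, 6) // (num_groups, num_channels); channels must divide evenly into groups
        .with_epsilon(1e-5)    // default value, shown explicitly
        .with_affine(true)     // learnable per-channel gamma/beta (also the default)
        .init()
}
```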

burn-core/src/nn/norm/group.rs (new file, +252)

```rust
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Configuration to create a [GroupNorm](GroupNorm) layer.
#[derive(Config)]
pub struct GroupNormConfig {
    /// The number of groups to separate the channels into
    num_groups: usize,
    /// The number of channels expected in the input
    num_channels: usize,
    /// A value required for numerical stability. Default: 1e-5
    #[config(default = 1e-5)]
    epsilon: f64,
    /// A boolean value that when set to `true`, this module has learnable
    /// per-channel affine parameters initialized to ones (for weights)
    /// and zeros (for biases). Default: `true`
    #[config(default = true)]
    affine: bool,
}

/// Applies Group Normalization over a mini-batch of inputs.
///
/// `Y = groupnorm(X) * γ + β`
#[derive(Module, Debug)]
pub struct GroupNorm<B: Backend> {
    num_groups: usize,
    num_channels: usize,
    gamma: Option<Param<Tensor<B, 1>>>,
    beta: Option<Param<Tensor<B, 1>>>,
    epsilon: f64,
    affine: bool,
}

impl GroupNormConfig {
    /// Initialize a new [group norm](GroupNorm) module.
    pub fn init<B: Backend>(&self) -> GroupNorm<B> {
        assert_eq!(
            self.num_channels % self.num_groups,
            0,
            "The number of channels must be divisible by the number of groups"
        );

        let (gamma, beta) = if self.affine {
            let gamma = Tensor::ones([self.num_channels]).into();
            let beta = Tensor::zeros([self.num_channels]).into();

            (Some(gamma), Some(beta))
        } else {
            (None, None)
        };

        GroupNorm {
            num_groups: self.num_groups,
            num_channels: self.num_channels,
            gamma,
            beta,
            epsilon: self.epsilon,
            affine: self.affine,
        }
    }

    /// Initialize a new [group norm](GroupNorm) module with a [record](GroupNormRecord).
    pub fn init_with<B: Backend>(&self, record: GroupNormRecord<B>) -> GroupNorm<B> {
        GroupNorm {
            num_groups: self.num_groups,
            num_channels: self.num_channels,
            gamma: record.gamma,
            beta: record.beta,
            epsilon: self.epsilon,
            affine: self.affine,
        }
    }
}

impl<B: Backend> GroupNorm<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any, d_model]`
    /// - output: `[..., any, d_model]`
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let shape = input.shape();
        if shape.num_elements() <= 2 {
            panic!(
                "input rank for GroupNorm should be at least 3, but got {}",
                shape.num_elements()
            );
        }

        let batch_size = shape.dims[0];
        let num_channels = shape.dims[1];

        if num_channels != self.num_channels {
            panic!(
                "expected {} channels but got {}",
                self.num_channels, num_channels
            );
        }

        let hidden_size =
            shape.dims[2..].iter().product::<usize>() * num_channels / self.num_groups;
        let input = input.reshape([batch_size, self.num_groups, hidden_size]);

        let mean = input.clone().sum_dim(2) / hidden_size as f64;
        let var = input.clone().sqrt().sum_dim(2) / hidden_size as f64;
        let input_normalized = input.sub(mean).div(var.sqrt().add_scalar(self.epsilon));

        if self.affine {
            let mut affine_shape = [1; D];
            affine_shape[1] = num_channels;

            input_normalized
                .reshape(shape)
                .mul(self.gamma.clone().unwrap().val().reshape(affine_shape))
                .add(self.beta.clone().unwrap().val().reshape(affine_shape))
        } else {
            input_normalized.reshape(shape)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn_tensor::Data;

    #[test]
    fn group_norm_forward_affine_false() {
        let module = GroupNormConfig::new(2, 6)
            .with_affine(false)
            .init::<TestBackend>();

        assert!(module.gamma.is_none());
        assert!(module.beta.is_none());

        let input = Tensor::from_data(Data::from([
            [
                [-0.3034f32, 0.2726, -0.9659],
                [-1.1845, -1.3236, 0.0172],
                [1.9507, 1.2554, -0.8625],
                [1.0682, 0.3604, 0.3985],
                [-0.4957, -0.4461, -0.9721],
                [1.5157, -0.1546, -0.5596],
            ],
            [
                [-1.6698, -0.4040, -0.7927],
                [0.3736, -0.0975, -0.1351],
                [-0.9461, 0.5461, -0.6334],
                [-1.0919, -0.1158, 0.1213],
                [-0.9535, 0.1281, 0.4372],
                [-0.2845, 0.3488, 0.5641],
            ],
        ]));

        let output = module.forward(input);

        output.to_data().assert_approx_eq(
            &Data::from([
                [
                    [-0.1653, 0.3748, -0.7866],
                    [-0.9916, -1.1220, 0.1353],
                    [1.9485, 1.2965, -0.6896],
                    [1.2769, 0.3628, 0.4120],
                    [-0.7427, -0.6786, -1.3578],
                    [1.8547, -0.3022, -0.8252],
                ],
                [
                    [-1.9342, 0.0211, -0.5793],
                    [1.2223, 0.4945, 0.4365],
                    [-0.8163, 1.4887, -0.3333],
                    [-1.7960, -0.0392, 0.3875],
                    [-1.5469, 0.3998, 0.9561],
                    [-0.3428, 0.7970, 1.1845],
                ],
            ]),
            3,
        );
    }

    #[test]
    fn group_norm_forward_affine_true() {
        let module = GroupNormConfig::new(3, 6)
            .with_affine(true)
            .init::<TestBackend>();

        module
            .gamma
            .as_ref()
            .expect("Gamma is None")
            .val()
            .to_data()
            .assert_approx_eq(&Data::ones([6].into()), 3);

        module
            .beta
            .as_ref()
            .expect("beta is None")
            .val()
            .to_data()
            .assert_approx_eq(&Data::zeros([6]), 3);

        let input = Tensor::from_data(Data::from([
            [
                [-0.3034f32, 0.2726, -0.9659],
                [-1.1845, -1.3236, 0.0172],
                [1.9507, 1.2554, -0.8625],
                [1.0682, 0.3604, 0.3985],
                [-0.4957, -0.4461, -0.9721],
                [1.5157, -0.1546, -0.5596],
            ],
            [
                [-1.6698, -0.4040, -0.7927],
                [0.3736, -0.0975, -0.1351],
                [-0.9461, 0.5461, -0.6334],
                [-1.0919, -0.1158, 0.1213],
                [-0.9535, 0.1281, 0.4372],
                [-0.2845, 0.3488, 0.5641],
            ],
        ]));

        let output = module.forward(input);

        output.to_data().assert_approx_eq(
            &Data::from([
                [
                    [0.4560, 1.4014, -0.6313],
                    [-0.9901, -1.2184, 0.9822],
                    [1.4254, 0.6360, -1.7682],
                    [0.4235, -0.3800, -0.3367],
                    [-0.3890, -0.3268, -0.9862],
                    [2.1325, 0.0386, -0.4691],
                ],
                [
                    [-1.8797, 0.0777, -0.5234],
                    [1.2802, 0.5517, 0.4935],
                    [-1.0102, 1.5327, -0.4773],
                    [-1.2587, 0.4047, 0.8088],
                    [-1.9074, 0.1691, 0.7625],
                    [-0.6230, 0.5928, 1.0061],
                ],
            ]),
            3,
        );
    }
}
```
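For orientation, here is a minimal usage sketch of the module above (the `burn::nn`/`burn::tensor` import paths are assumed from the crate's usual re-exports; the shape follows the tests' `[batch, channels, length]` layout). With 6 channels in 2 groups and length 3, the forward pass views the input as `[batch, num_groups, hidden_size]` with `hidden_size = 3 * 6 / 2 = 9`, computes the per-group statistics, then reshapes back and applies the per-channel `gamma`/`beta`:

```rust
use burn::nn::GroupNormConfig;                // assumed re-export path
use burn::tensor::{backend::Backend, Tensor};

/// Sketch: normalize a `[batch, channels, length]` tensor with 6 channels in 2 groups.
fn group_normalize<B: Backend>(input: Tensor<B, 3>) -> Tensor<B, 3> {
    // Panics if the channel dimension of `input` is not 6.
    let norm = GroupNormConfig::new(2, 6).init::<B>();
    norm.forward(input) // output keeps the input shape
}
```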

burn-core/src/nn/norm/mod.rs (+2)

```diff
@@ -1,5 +1,7 @@
 mod batch;
+mod group;
 mod layer;
 
 pub use batch::*;
+pub use group::*;
 pub use layer::*;
```
