
Commit 53eb3ec

Implement Huber loss (#1444)
* Implement Huber loss: instead of using a sign or abs function, use clamping to compute the loss outside the bounds. This is better for the autodiff backend.
* Mention Huber loss in the book.
* Unify naming of residuals in comments.
1 parent 7a98b2f commit 53eb3ec
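
As a scalar illustration of the clamping trick described in the commit message (a minimal standalone sketch, not part of the commit): for residuals outside the bound, clamp(r, -d, d) equals d * sign(r), so multiplying it by r yields d * |r| without ever calling sign or abs.

// Standalone sketch: the clamp-based formulation of the linear branch
// agrees with the piecewise textbook definition whenever |r| > d.
fn linear_branch_via_clamp(r: f64, d: f64) -> f64 {
    // clamp(r, -d, d) == d * sign(r) outside the bound
    r.clamp(-d, d) * r - 0.5 * d * d
}

// Piecewise reference for |r| > d, written with abs() as in the usual definition.
fn linear_branch_reference(r: f64, d: f64) -> f64 {
    0.5 * d * d + d * (r.abs() - d)
}

fn main() {
    let d = 0.5;
    for r in [-2.0, -0.75, 0.6, 3.0] {
        // All sample residuals satisfy |r| > d, i.e. they take the linear branch.
        assert!((linear_branch_via_clamp(r, d) - linear_branch_reference(r, d)).abs() < 1e-12);
    }
    println!("clamp-based formulation matches the piecewise definition");
}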

File tree (3 files changed, +181 −0 lines):

  burn-book/src/building-blocks/module.md
  crates/burn-core/src/nn/loss/huber.rs
  crates/burn-core/src/nn/loss/mod.rs

burn-book/src/building-blocks/module.md (+1 line)

@@ -162,3 +162,4 @@ Burn comes with built-in modules that you can use to build your own modules.
  | ------------------ | --------------------- |
  | `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
  | `MseLoss`          | `nn.MSELoss`          |
+ | `HuberLoss`        | `nn.HuberLoss`        |
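
For illustration, a minimal downstream usage sketch of the new module. The calls mirror the tests in huber.rs below; the NdArray backend, its export as burn::backend::NdArray, and its default device are assumptions for this sketch, not part of the commit.

use burn::backend::NdArray; // assumption: any backend would do here
use burn::nn::loss::{HuberLossConfig, Reduction};
use burn::tensor::{Data, Tensor};

fn main() {
    let device = Default::default();
    let predictions =
        Tensor::<NdArray, 1>::from_data(Data::from([-2.0, -0.5, 0.0, 0.3, 1.0]), &device);
    let targets = Tensor::<NdArray, 1>::from_data(Data::from([0.0; 5]), &device);

    // delta = 0.5: quadratic for residuals with |r| <= 0.5, linear beyond the bound.
    let huber = HuberLossConfig::new(0.5).init(&device);
    let loss = huber.forward(predictions, targets, Reduction::Mean);
    println!("mean Huber loss: {}", loss.into_scalar()); // ~0.284
}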

crates/burn-core/src/nn/loss/huber.rs (new file, +178 lines)

use crate as burn;

use crate::{config::Config, module::Module};
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;
use core::marker::PhantomData;

use super::Reduction;

/// Configuration to create a [Huber loss](HuberLoss).
#[derive(Config, Debug)]
pub struct HuberLossConfig {
    /// The bound where the Huber loss function changes from quadratic to linear behaviour.
    pub delta: f32,
}

impl HuberLossConfig {
    /// Initialize [Huber loss](HuberLoss).
    pub fn init<B: Backend>(&self, device: &B::Device) -> HuberLoss<B> {
        // The device is not needed as of now, but we might want to prepare some data
        // on it later, and it's consistent with the other loss functions.
        let _ = device;
        self.assertions();
        HuberLoss {
            delta: self.delta,
            lin_bias: self.delta * self.delta * 0.5,
            _backend: PhantomData,
        }
    }

    fn assertions(&self) {
        assert!(
            self.delta >= 0., // This comparison also rejects NaN.
            "Delta for Huber loss must be a non-negative number."
        );
    }
}

/// Calculate the Huber loss between the inputs and the target.
///
/// The loss for each element of the residuals `r = targets - predictions` is given by
///
/// ```text
/// L(r) = 0.5 * r^2                  if |r| <= d
/// L(r) = 0.5 * d^2 + d * (|r| - d)  if |r| >  d
/// ```
///
/// where `d` is the configured `delta`. In particular, this is equal to the
/// [L2 Loss](super::MseLoss) for residuals with magnitude smaller than `delta`,
/// but behaves linearly instead of quadratically for large residuals.
///
/// This loss function is less sensitive to outliers than the mean squared error loss.
///
/// See also: <https://en.wikipedia.org/wiki/Huber_loss>
#[derive(Module, Debug)]
pub struct HuberLoss<B: Backend> {
    delta: f32,
    lin_bias: f32, // delta * delta * 0.5 precomputed
    _backend: PhantomData<B>,
}

impl<B: Backend> HuberLoss<B> {
    /// Compute the loss element-wise for the predictions and targets, then reduce
    /// to a single loss value.
    ///
    /// `Reduction::Auto` behaves as `Reduction::Mean`.
    ///
    /// # Shapes
    ///
    /// - predictions: \[...dims\]
    /// - targets: \[...dims\]
    /// - output: \[1\]
    pub fn forward<const D: usize>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward_no_reduction(predictions, targets);
        match reduction {
            Reduction::Mean | Reduction::Auto => loss.mean(),
            Reduction::Sum => loss.sum(),
        }
    }
    /// Compute the loss element-wise for the predictions and targets.
    ///
    /// # Shapes
    ///
    /// - predictions: [...dims]
    /// - targets: [...dims]
    /// - output: [...dims]
    pub fn forward_no_reduction<const D: usize>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let residuals = targets - predictions;
        self.forward_residuals(residuals)
    }
    /// Compute the loss element-wise for the given residuals.
    ///
    /// # Shapes
    ///
    /// - residuals: [...dims]
    /// - output: [...dims]
    pub fn forward_residuals<const D: usize>(&self, residuals: Tensor<B, D>) -> Tensor<B, D> {
        let is_large = residuals.clone().abs().greater_elem(self.delta);
        // We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the
        // `sign()` function, in general, suffers from a jump at 0.
        // Instead the following tensor implements `delta * sign(r)` for values outside
        // the bound:
        let softsign = residuals.clone().clamp(-self.delta, self.delta);

        // 0.5 * d^2 + d * (|r| - d) =
        // d * |r| - 0.5 * d^2
        // Moreover |r| = sign(r) * r
        let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);

        let inside = residuals.powf_scalar(2.).mul_scalar(0.5);
        inside.mask_where(is_large, outside)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn_tensor::Data;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;

    #[test]
    fn test_huber_loss() {
        let predict = Data::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = Data::from([0., 0., 0., 0., 0.]);

        let device = Default::default();

        let predict = TestTensor::<1>::from_data(predict, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let huber = HuberLossConfig::new(0.5).init(&device);

        let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
        let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = huber.forward_no_reduction(predict, targets);

        loss_no_reduction
            .into_data()
            .assert_approx_eq(&Data::from([0.875, 0.125, 0., 0.045, 0.375]), 7);
        loss.into_data().assert_approx_eq(&Data::from([0.284]), 7);
        loss_sum
            .into_data()
            .assert_approx_eq(&Data::from([1.42]), 7);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_huber_ad_loss() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;

        let predict = Data::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = Data::from([0., 0., 0., 0., 0.]);

        let device = Default::default();
        let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();
        let targets = TestAutodiffTensor::from_data(targets, &device);

        let loss = HuberLossConfig::new(0.5).init(&device);
        let loss = loss.forward_no_reduction(predict.clone(), targets);

        let grads = loss.backward();
        let grads_predict = predict.grad(&grads).unwrap();

        grads_predict
            .to_data()
            .assert_approx_eq(&Data::from([-0.5, -0.5, 0., 0.3, 0.5]), 3);
    }
}
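
A quick scalar cross-check of the gradient values expected in test_huber_ad_loss (a standalone sketch, no burn dependency): with r = targets - predictions, the derivative of the loss with respect to a prediction is -dL/dr, which in both branches works out to clamp(prediction - target, -delta, delta).

// Scalar gradient of the Huber loss with respect to the prediction:
// inside the bound, dL/dr = r, so -dL/dr = prediction - target;
// outside, dL/dr = delta * sign(r), so -dL/dr = delta * sign(prediction - target).
// Both cases collapse to a single clamp.
fn huber_grad(prediction: f32, target: f32, delta: f32) -> f32 {
    (prediction - target).clamp(-delta, delta)
}

fn main() {
    let (delta, target) = (0.5, 0.0);
    let grads: Vec<f32> = [-2.0_f32, -0.5, 0.0, 0.3, 1.0]
        .iter()
        .map(|&p| huber_grad(p, target, delta))
        .collect();
    assert_eq!(grads, vec![-0.5, -0.5, 0.0, 0.3, 0.5]); // matches the test expectations
    println!("{grads:?}");
}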

crates/burn-core/src/nn/loss/mod.rs (+2 lines)

@@ -1,9 +1,11 @@
  mod binary_cross_entropy;
  mod cross_entropy;
+ mod huber;
  mod mse;
  mod reduction;

  pub use binary_cross_entropy::*;
  pub use cross_entropy::*;
+ pub use huber::*;
  pub use mse::*;
  pub use reduction::*;
