use crate as burn;

use crate::{config::Config, module::Module};
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;
use core::marker::PhantomData;

use super::Reduction;

/// Configuration to create a [Huber loss](HuberLoss).
#[derive(Config, Debug)]
pub struct HuberLossConfig {
    /// The bound where the Huber loss function changes from quadratic to linear behaviour.
    pub delta: f32,
}

impl HuberLossConfig {
    /// Initialize [Huber loss](HuberLoss).
    pub fn init<B: Backend>(&self, device: &B::Device) -> HuberLoss<B> {
        // The device is not needed as of now, but we might want to prepare some
        // data on it later, and it's consistent with the other loss functions.
        let _ = device;
        self.assertions();
        HuberLoss {
            delta: self.delta,
            lin_bias: self.delta * self.delta * 0.5,
            _backend: PhantomData,
        }
    }

    fn assertions(&self) {
        assert!(
            self.delta >= 0., // This also rejects NaN, which fails the comparison.
            "Delta for Huber loss must be a non-negative number."
        );
    }
}

/// Calculate the Huber loss between the inputs and the target.
///
/// The loss for each element of the residuals `r = targets - predictions` is given by
///
/// ```text
/// L(r) = 0.5 * r^2                  if |r| <= d
/// L(r) = 0.5 * d^2 + d * (|r| - d)  if |r| > d
/// ```
///
/// where `d` is the configured `delta`. In particular, this is equal to the
/// [L2 Loss](super::MseLoss) for residuals with magnitude smaller than `delta`,
/// but behaves linearly instead of quadratically for large residuals.
///
/// This loss function is less sensitive to outliers than the mean squared error loss.
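///
/// For example, with `delta = 1` a residual of `0.5` falls in the quadratic
/// branch and yields a loss of `0.5 * 0.5^2 = 0.125`, while a residual of `2`
/// falls in the linear branch and yields `0.5 + 1 * (2 - 1) = 1.5`.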
///
/// See also: <https://en.wikipedia.org/wiki/Huber_loss>
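///
/// # Example
///
/// A minimal usage sketch, generic over the backend. The import paths are an
/// assumption based on the crate's usual `burn::nn::loss` and `burn::tensor`
/// re-exports:
///
/// ```ignore
/// use burn::nn::loss::{HuberLossConfig, Reduction};
/// use burn::tensor::{backend::Backend, Data, Tensor};
///
/// fn huber_example<B: Backend>(device: &B::Device) -> Tensor<B, 1> {
///     // Quadratic for residuals with |r| <= 0.5, linear beyond that bound.
///     let huber = HuberLossConfig::new(0.5).init::<B>(device);
///     let predictions = Tensor::<B, 1>::from_data(Data::from([1.0, 2.0, 3.0]), device);
///     let targets = Tensor::<B, 1>::from_data(Data::from([1.5, 2.0, 5.0]), device);
///     huber.forward(predictions, targets, Reduction::Mean)
/// }
/// ```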
#[derive(Module, Debug)]
pub struct HuberLoss<B: Backend> {
    delta: f32,
    lin_bias: f32, // delta * delta * 0.5 precomputed
    _backend: PhantomData<B>,
}

impl<B: Backend> HuberLoss<B> {
    /// Compute the loss element-wise for the predictions and targets, then reduce
    /// to a single loss value.
    ///
    /// `Reduction::Auto` behaves as `Reduction::Mean`.
    ///
    /// # Shapes
    ///
    /// - predictions: \[...dims\]
    /// - targets: \[...dims\]
    /// - output: \[1\]
    pub fn forward<const D: usize>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let loss = self.forward_no_reduction(predictions, targets);
        match reduction {
            Reduction::Mean | Reduction::Auto => loss.mean(),
            Reduction::Sum => loss.sum(),
        }
    }

    /// Compute the loss element-wise for the predictions and targets.
    ///
    /// # Shapes
    ///
    /// - predictions: \[...dims\]
    /// - targets: \[...dims\]
    /// - output: \[...dims\]
    pub fn forward_no_reduction<const D: usize>(
        &self,
        predictions: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let residuals = targets - predictions;
        self.forward_residuals(residuals)
    }

    /// Compute the loss element-wise for the given residuals.
    ///
    /// # Shapes
    ///
    /// - residuals: \[...dims\]
    /// - output: \[...dims\]
    pub fn forward_residuals<const D: usize>(&self, residuals: Tensor<B, D>) -> Tensor<B, D> {
        let is_large = residuals.clone().abs().greater_elem(self.delta);
        // We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the
        // `sign()` function, in general, suffers from a jump at 0. Instead, the
        // following tensor implements `delta * sign(r)` for values outside the bound:
        let softsign = residuals.clone().clamp(-self.delta, self.delta);

        // For large residuals the loss is
        //   0.5 * d^2 + d * (|r| - d) = d * |r| - 0.5 * d^2,
        // and since |r| = sign(r) * r, the first term equals `softsign * r`
        // wherever `abs(r) > delta`:
        let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);

        let inside = residuals.powf_scalar(2.).mul_scalar(0.5);
        inside.mask_where(is_large, outside)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn_tensor::Data;
    type TestTensor<const D: usize> = Tensor<TestBackend, D>;

    #[test]
    fn test_huber_loss() {
        let predict = Data::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = Data::from([0., 0., 0., 0., 0.]);

        let device = Default::default();

        let predict = TestTensor::<1>::from_data(predict, &device);
        let targets = TestTensor::<1>::from_data(targets, &device);

        let huber = HuberLossConfig::new(0.5).init(&device);

        let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);
        let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);
        let loss_no_reduction = huber.forward_no_reduction(predict, targets);
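
        // With delta = 0.5 the residuals are [2, 0.5, 0, -0.3, -1]; the first and
        // last fall in the linear branch, e.g. 0.5 * 0.5^2 + 0.5 * (2 - 0.5) = 0.875.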
        loss_no_reduction
            .into_data()
            .assert_approx_eq(&Data::from([0.875, 0.125, 0., 0.045, 0.375]), 7);
        loss.into_data().assert_approx_eq(&Data::from([0.284]), 7);
        loss_sum
            .into_data()
            .assert_approx_eq(&Data::from([1.42]), 7);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_huber_ad_loss() {
        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;

        let predict = Data::from([-2., -0.5, 0., 0.3, 1.]);
        let targets = Data::from([0., 0., 0., 0., 0.]);

        let device = Default::default();
        let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();
        let targets = TestAutodiffTensor::from_data(targets, &device);

        let loss = HuberLossConfig::new(0.5).init(&device);
        let loss = loss.forward_no_reduction(predict.clone(), targets);

        let grads = loss.backward();
        let grads_predict = predict.grad(&grads).unwrap();
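
        // dL/dp = -f'(t - p): -r inside the bound, -delta * sign(r) outside, so
        // the residuals [2, 0.5, 0, -0.3, -1] give [-0.5, -0.5, 0, 0.3, 0.5].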
        grads_predict
            .to_data()
            .assert_approx_eq(&Data::from([-0.5, -0.5, 0., 0.3, 0.5]), 3);
    }
}