fix the problem of sigmoid gradient generating NaN (#1140)
* use sigmoid derivative formulas

* add test

* fix test error

* move sigmoid to tensor/ops/activation.rs

* use full precision in the default implementation

* rename the param of `sigmoid_backward`
wcshds authored Jan 16, 2024
1 parent b99726f commit a5bdf38
Showing 6 changed files with 103 additions and 1 deletion.
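For context, the identity the fix relies on (the commit message does not spell it out) is the closed-form sigmoid derivative:

    \sigma(x) = \frac{1}{1 + e^{-x}}, \qquad \sigma'(x) = \sigma(x)\,\bigl(1 - \sigma(x)\bigr) \in [0, 0.25]

For an input such as x = -90 in f32, e^{-x} = e^{90} overflows to infinity, so a gradient path that evaluates e^{-x} directly (as the gradient of the previous log_sigmoid(tensor).exp() composition effectively did) ends up mixing infinities and zeros and yields NaN, whereas \sigma(x)(1 - \sigma(x)) simply underflows to 0.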
23 changes: 23 additions & 0 deletions burn-autodiff/src/ops/activation.rs
@@ -54,4 +54,27 @@ impl<B: Backend> ActivationOps<Autodiff<B>> for Autodiff<B> {
            OpsKind::UnTracked(prep) => prep.finish(output),
        }
    }

    fn sigmoid<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
        #[derive(Debug)]
        struct Sigmoid;

        impl<B: Backend, const D: usize> Backward<B, D, 1> for Sigmoid {
            type State = B::TensorPrimitive<D>;

            fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
                unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                    B::sigmoid_backward(ops.state, grad)
                });
            }
        }

        match Sigmoid.prepare([tensor.node], [tensor.graph]).stateful() {
            OpsKind::Tracked(prep) => {
                let output = B::sigmoid(tensor.primitive);
                prep.finish(output.clone(), output)
            }
            OpsKind::UnTracked(prep) => prep.finish(B::sigmoid(tensor.primitive)),
        }
    }
}
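In the Tracked branch above, the forward output y = \sigma(x) is saved as the op's state, so the backward closure can apply the chain rule purely in terms of that saved value; this is a restatement of what B::sigmoid_backward computes, not additional commit content:

    \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x} = \mathrm{grad} \cdot y\,(1 - y)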
2 changes: 2 additions & 0 deletions burn-autodiff/src/tests/mod.rs
@@ -38,6 +38,7 @@ mod recip;
mod relu;
mod reshape;
mod select;
mod sigmoid;
mod sin;
mod slice;
mod softmax;
@@ -103,6 +104,7 @@ macro_rules! testgen_all {
        burn_autodiff::testgen_ad_abs!();
        burn_autodiff::testgen_ad_sub!();
        burn_autodiff::testgen_ad_tanh!();
        burn_autodiff::testgen_ad_sigmoid!();
        burn_autodiff::testgen_ad_transpose!();
    };
}
33 changes: 33 additions & 0 deletions burn-autodiff/src/tests/sigmoid.rs
@@ -0,0 +1,33 @@
#[burn_tensor_testgen::testgen(ad_sigmoid)]
mod tests {
    use super::*;
    use burn_tensor::{activation, Data};

    #[test]
    fn should_diff_sigmoid() {
        let data = Data::<f32, 1>::from([0.8762]);

        let device = Default::default();
        let tensor_1 = TestAutodiffTensor::from_data(data, &device).require_grad();
        let tensor_2 = activation::sigmoid(tensor_1.clone());
        let grads = tensor_2.backward();

        let grad = tensor_1.grad(&grads).unwrap();

        grad.to_data().assert_approx_eq(&Data::from([0.207549]), 4);
    }

    #[test]
    fn small_neg_val_should_not_cause_grad_overflow() {
        let data = Data::<f32, 1>::from([-90.0]);

        let device = Default::default();
        let tensor_1 = TestAutodiffTensor::from_data(data, &device).require_grad();
        let tensor_2 = activation::sigmoid(tensor_1.clone());
        let grads = tensor_2.backward();

        let grad = tensor_1.grad(&grads).unwrap();

        grad.to_data().assert_approx_eq(&Data::from([0.0]), 4);
    }
}
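A quick hand check of the expected values, assuming the closed-form derivative:

    \sigma(0.8762) \approx 0.70603, \qquad 0.70603 \times (1 - 0.70603) \approx 0.20755
    \sigma(-90) \approx 0 \ \text{(underflows in f32)}, \qquad 0 \times (1 - 0) = 0

so the first test expects the analytic gradient (0.207549 to four decimal places), and the second expects exactly 0 rather than NaN.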
4 changes: 4 additions & 0 deletions burn-tch/src/ops/activation.rs
@@ -22,4 +22,8 @@ impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {

        TchTensor::from_existing(tensor, storage)
    }

    fn sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
    }
}
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/activation/base.rs
@@ -78,7 +78,7 @@ pub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize)

/// Applies the sigmoid function.
pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
-    log_sigmoid(tensor).exp()
+    Tensor::from_primitive(B::sigmoid(tensor.primitive))
}

/// Applies the log sigmoid function.
40 changes: 40 additions & 0 deletions burn-tensor/src/tensor/ops/activation.rs
@@ -1,3 +1,4 @@
use crate::tensor::ops::tensor::TensorOps;
use crate::{backend::Backend, ElementConversion};
use core::f64::consts::SQRT_2;

@@ -102,4 +103,43 @@ pub trait ActivationOps<B: Backend> {

        B::mul(y, grad)
    }

    /// Applies the Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
        let tensor_full = B::to_full_precision(&tensor);
        let tensor_tmp = B::FullPrecisionBackend::exp(B::FullPrecisionBackend::neg(
            B::FullPrecisionBackend::log(B::FullPrecisionBackend::add_scalar(
                B::FullPrecisionBackend::exp(B::FullPrecisionBackend::neg(tensor_full)),
                1.0.elem(),
            )),
        ));

        B::from_full_precision(tensor_tmp)
    }

    /// Applies the Sigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor of the sigmoid function.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid_backward<const D: usize>(
        output: FloatTensor<B, D>,
        grad: FloatTensor<B, D>,
    ) -> FloatTensor<B, D> {
        let value = B::mul(output.clone(), B::add_scalar(B::neg(output), 1.0.elem()));
        B::mul(value, grad)
    }
}
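As a minimal standalone sketch (plain Rust with std only, not Burn code, and not the exact expression the old autodiff graph produced), the following contrasts a direct evaluation of the derivative with the saved-output formulation used by sigmoid_backward, at x = -90:

fn main() {
    let x: f32 = -90.0;

    // Direct evaluation of the derivative e^{-x} / (1 + e^{-x})^2:
    // e^{90} overflows f32 to infinity, and inf / inf = NaN.
    let e = (-x).exp();
    let naive_grad = e / ((1.0 + e) * (1.0 + e));

    // Formulation used by the patch: express the derivative through the
    // forward output, sigmoid(x) * (1 - sigmoid(x)).
    // exp(90) is still infinite, but ln(inf) = inf and exp(-inf) = 0,
    // so the output underflows to 0 instead of turning into NaN.
    let output = (-(1.0 + (-x).exp()).ln()).exp();
    let stable_grad = output * (1.0 - output);

    assert!(naive_grad.is_nan());
    assert_eq!(stable_grad, 0.0);

    println!("naive: {naive_grad}, stable: {stable_grad}");
}

The two expressions agree (up to rounding) for moderate inputs; they only diverge once e^{-x} overflows.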
