diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md
index bf7c74ddac..ad9a76ae7b 100644
--- a/burn-book/src/building-blocks/tensor.md
+++ b/burn-book/src/building-blocks/tensor.md
@@ -330,6 +330,7 @@ strategies.
 | `activation::sigmoid(tensor)`        | `nn.functional.sigmoid(tensor)`        |
 | `activation::silu(tensor)`           | `nn.functional.silu(tensor)`           |
 | `activation::softmax(tensor, dim)`   | `nn.functional.softmax(tensor, dim)`   |
+| `activation::softmin(tensor, dim)`   | `nn.functional.softmin(tensor, dim)`   |
 | `activation::softplus(tensor, beta)` | `nn.functional.softplus(tensor, beta)` |
 | `activation::tanh(tensor)`           | `nn.functional.tanh(tensor)`           |
diff --git a/crates/burn-tensor/src/tensor/activation/base.rs b/crates/burn-tensor/src/tensor/activation/base.rs
index 0539647c43..cc5990d375 100644
--- a/crates/burn-tensor/src/tensor/activation/base.rs
+++ b/crates/burn-tensor/src/tensor/activation/base.rs
@@ -78,6 +78,19 @@ pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
     tensor.div(tensor_tmp)
 }
 
+/// Applies the softmin function on the input tensor along the given dimension.
+///
+/// `softmin(x_i) = exp(-x_i) / sum_j(exp(-x_j))`
+///
+/// # Notes
+///
+/// The dimension argument `dim` specifies the dimension along which the function will be computed.
+/// It must be in the range `0` to `D-1`.
+pub fn softmin<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
+    check!(TensorCheck::dim_ops::<D>("softmin", dim));
+    softmax(tensor.neg(), dim)
+}
+
 /// Applies the softplus function
 ///
 /// `softplus(x_i) = log(1 + exp(\beta x_i)) / \beta`
diff --git a/crates/burn-tensor/src/tests/activation/mod.rs b/crates/burn-tensor/src/tests/activation/mod.rs
index e28377fe2f..022b4b2ced 100644
--- a/crates/burn-tensor/src/tests/activation/mod.rs
+++ b/crates/burn-tensor/src/tests/activation/mod.rs
@@ -8,5 +8,6 @@ pub(crate) mod relu;
 pub(crate) mod sigmoid;
 pub(crate) mod silu;
 pub(crate) mod softmax;
+pub(crate) mod softmin;
 pub(crate) mod softplus;
 pub(crate) mod tanh_activation;
diff --git a/crates/burn-tensor/src/tests/activation/softmin.rs b/crates/burn-tensor/src/tests/activation/softmin.rs
new file mode 100644
index 0000000000..62a1401889
--- /dev/null
+++ b/crates/burn-tensor/src/tests/activation/softmin.rs
@@ -0,0 +1,15 @@
+#[burn_tensor_testgen::testgen(softmin)]
+mod tests {
+    use super::*;
+    use burn_tensor::{activation, Tensor, TensorData};
+
+    #[test]
+    fn test_softmin_d2() {
+        let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);
+
+        let output = activation::softmin(tensor, 1);
+        let expected = TensorData::from([[9.9753e-01, 2.4726e-03], [1.1254e-07, 1.0000e+00]]);
+
+        output.into_data().assert_approx_eq(&expected, 4);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs
index e9ee65c3ce..e5a0acc80e 100644
--- a/crates/burn-tensor/src/tests/mod.rs
+++ b/crates/burn-tensor/src/tests/mod.rs
@@ -15,6 +15,7 @@ macro_rules! testgen_all {
         burn_tensor::testgen_relu!();
         burn_tensor::testgen_leaky_relu!();
         burn_tensor::testgen_softmax!();
+        burn_tensor::testgen_softmin!();
         burn_tensor::testgen_softplus!();
         burn_tensor::testgen_sigmoid!();
         burn_tensor::testgen_log_sigmoid!();
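
The implementation delegates to `softmax` on the negated input, so `softmin` inherits softmax's max-subtraction trick and stays numerically stable for large-magnitude inputs (see the `[13.0, -3.0]` row in the test, where the minimum `-3.0` takes essentially all the weight). Below is a minimal usage sketch of the new function, generic over any `Backend`; the helper name `softmin_rows` and the input values are illustrative, not part of this diff:

```rust
use burn_tensor::activation;
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

/// Applies softmin along dim 1 of a 2-D tensor: each row sums to 1,
/// and smaller inputs receive larger weights (softmin(x) = softmax(-x)).
fn softmin_rows<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
    let tensor = Tensor::<B, 2>::from_floats([[1.0, 7.0], [13.0, -3.0]], device);
    activation::softmin(tensor, 1)
    // ≈ [[9.9753e-1, 2.4726e-3], [1.1254e-7, 1.0000e+0]]
}
```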