diff --git a/crates/burn-core/src/nn/mod.rs b/crates/burn-core/src/nn/mod.rs
index 56ed615b95..80fb02c8be 100644
--- a/crates/burn-core/src/nn/mod.rs
+++ b/crates/burn-core/src/nn/mod.rs
@@ -29,6 +29,7 @@ mod prelu;
 mod relu;
 mod rnn;
 mod rope_encoding;
+mod sigmoid;
 mod swiglu;
 mod tanh;
 mod unfold;
@@ -46,6 +47,7 @@ pub use prelu::*;
 pub use relu::*;
 pub use rnn::*;
 pub use rope_encoding::*;
+pub use sigmoid::*;
 pub use swiglu::*;
 pub use tanh::*;
 pub use unfold::*;
diff --git a/crates/burn-core/src/nn/sigmoid.rs b/crates/burn-core/src/nn/sigmoid.rs
new file mode 100644
index 0000000000..68db4b07c2
--- /dev/null
+++ b/crates/burn-core/src/nn/sigmoid.rs
@@ -0,0 +1,38 @@
+use crate as burn;
+
+use crate::module::Module;
+use crate::tensor::backend::Backend;
+use crate::tensor::Tensor;
+
+/// Applies the sigmoid function element-wise
+/// See also [sigmoid](burn::tensor::activation::sigmoid)
+#[derive(Module, Clone, Debug, Default)]
+pub struct Sigmoid;
+
+impl Sigmoid {
+    /// Create the module.
+    pub fn new() -> Self {
+        Self {}
+    }
+    /// Applies the forward pass on the input tensor.
+    ///
+    /// # Shapes
+    ///
+    /// - input: `[..., any]`
+    /// - output: `[..., any]`
+    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
+        crate::tensor::activation::sigmoid(input)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn display() {
+        let layer = Sigmoid::new();
+
+        assert_eq!(alloc::format!("{}", layer), "Sigmoid");
+    }
+}
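For context, a minimal usage sketch of the new `Sigmoid` module (not part of the patch): it assumes the `burn` facade crate with the `ndarray` feature enabled, so `burn::backend::NdArray` and the re-export `burn::nn::Sigmoid` are available, and the `from_floats` signature shown matches recent Burn releases.

```rust
// Usage sketch (assumption: `burn` facade crate with the `ndarray` feature).
use burn::backend::NdArray;
use burn::nn::Sigmoid;
use burn::tensor::Tensor;

fn main() {
    // Any `Backend` implementation works; NdArray is an assumed example here.
    type B = NdArray;
    let device = Default::default();

    // The module is stateless, so `new()` takes no configuration.
    let sigmoid = Sigmoid::new();

    // `forward` is generic over backend and rank: the output keeps the input shape.
    let input = Tensor::<B, 2>::from_floats([[0.0, 1.0], [-2.0, 3.0]], &device);
    let output = sigmoid.forward(input);

    // Values are mapped element-wise into (0, 1).
    println!("{output}");
}
```

Because `Sigmoid` holds no state, it can be constructed inline inside a model's `forward`, or stored as a field so it appears in the module's `Display` output, which is what the `display` test above checks.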