
modern lstm #2752

Merged · 15 commits · Feb 3, 2025
11 changes: 11 additions & 0 deletions Cargo.lock


27 changes: 27 additions & 0 deletions examples/modern-lstm/Cargo.toml
@@ -0,0 +1,27 @@
[package]
name = "modern-lstm"
version = "0.1.0"
edition = "2021"

[features]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
cuda = ["burn/cuda"]

[dependencies]
burn = { path = "../../crates/burn", features = ["train"] }

# Random number generator
rand = { workspace = true }
rand_distr = { workspace = true }

# Serialization
serde = { workspace = true, features = ["std", "derive"] }

# Organise the results in a dataframe
polars = { workspace = true }
46 changes: 46 additions & 0 deletions examples/modern-lstm/README.md
@@ -0,0 +1,46 @@
# Advanced LSTM Implementation with Burn

A more advanced implementation of Long Short-Term Memory (LSTM) networks in Burn with combined
weight matrices for the input and hidden states, based on the
[PyTorch implementation](https://github.com/shiv08/Advanced-LSTM-Implementation-with-PyTorch).
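
For reference, "combined weight matrices" means the four gate pre-activations are produced by one
fused input projection and one fused hidden projection, as in the standard fused-gate LSTM
formulation (shown here for orientation, not copied verbatim from this implementation):

$$
\begin{bmatrix} i_t \\ f_t \\ g_t \\ o_t \end{bmatrix}
= W_{ih}\,x_t + b_{ih} + W_{hh}\,h_{t-1} + b_{hh}
$$

$$
c_t = \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t), \qquad
h_t = \sigma(o_t) \odot \tanh(c_t)
$$

where $W_{ih} \in \mathbb{R}^{4H \times D}$ and $W_{hh} \in \mathbb{R}^{4H \times H}$ for hidden
size $H$ and input size $D$.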

`LstmNetwork` is the top-level module with bidirectional and regularization support. The LSTM
variants differ by the `bidirectional` and `num_layers` settings (see the configuration sketch
after this list):

- LSTM: `num_layers = 1` and `bidirectional = false`
- Stacked LSTM: `num_layers > 1` and `bidirectional = false`
- Bidirectional LSTM: `num_layers = 1` and `bidirectional = true`
- Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true`
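
As a sketch, the variants map onto the config roughly as follows. The `with_*` setter names are
assumptions based on the settings listed above (Burn's `#[derive(Config)]` generates a
`with_<field>` builder per field), not verbatim from this PR:

```rust
use modern_lstm::model::LstmNetworkConfig;

fn main() {
    // Hypothetical setters; exact names depend on the config's field names.
    let _lstm = LstmNetworkConfig::new(); // num_layers = 1, bidirectional = false
    let _stacked = LstmNetworkConfig::new().with_num_layers(2);
    let _bidirectional = LstmNetworkConfig::new().with_bidirectional(true);
    let _bi_stacked = LstmNetworkConfig::new()
        .with_num_layers(2)
        .with_bidirectional(true);
}
```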

This implementation complements Burn's official LSTM; users can choose either one depending on
the project's specific needs.

## Usage

### Training

```sh
# CUDA backend
cargo run --example lstm-train --release --features cuda

# Wgpu backend
cargo run --example lstm-train --release --features wgpu

# Tch GPU backend
export TORCH_CUDA_VERSION=cu121 # Set the CUDA version
cargo run --example lstm-train --release --features tch-gpu

# Tch CPU backend
cargo run --example lstm-train --release --features tch-cpu

# NdArray backend (CPU)
cargo run --example lstm-train --release --features ndarray
cargo run --example lstm-train --release --features ndarray-blas-openblas
cargo run --example lstm-train --release --features ndarray-blas-netlib
```

### Inference

```sh
cargo run --example lstm-infer --release --features cuda
```
86 changes: 86 additions & 0 deletions examples/modern-lstm/examples/lstm-infer.rs
@@ -0,0 +1,86 @@
use burn::tensor::backend::Backend;

pub fn launch<B: Backend>(device: B::Device) {
    modern_lstm::inference::infer::<B>("/tmp/modern-lstm", device);
}

#[cfg(any(
    feature = "ndarray",
    feature = "ndarray-blas-netlib",
    feature = "ndarray-blas-openblas",
    feature = "ndarray-blas-accelerate",
))]
mod ndarray {
    use burn::backend::ndarray::{NdArray, NdArrayDevice};

    use crate::launch;

    pub fn run() {
        launch::<NdArray>(NdArrayDevice::Cpu);
    }
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use burn::backend::libtorch::{LibTorch, LibTorchDevice};

    use crate::launch;

    pub fn run() {
        #[cfg(not(target_os = "macos"))]
        let device = LibTorchDevice::Cuda(0);
        #[cfg(target_os = "macos")]
        let device = LibTorchDevice::Mps;

        launch::<LibTorch>(device);
    }
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
    use burn::backend::libtorch::{LibTorch, LibTorchDevice};

    use crate::launch;

    pub fn run() {
        launch::<LibTorch>(LibTorchDevice::Cpu);
    }
}

#[cfg(feature = "wgpu")]
mod wgpu {
    use crate::launch;
    use burn::backend::wgpu::Wgpu;

    pub fn run() {
        launch::<Wgpu>(Default::default());
    }
}

#[cfg(feature = "cuda")]
mod cuda {
    use crate::launch;
    use burn::backend::Cuda;

    pub fn run() {
        launch::<Cuda>(Default::default());
    }
}

fn main() {
    #[cfg(any(
        feature = "ndarray",
        feature = "ndarray-blas-netlib",
        feature = "ndarray-blas-openblas",
        feature = "ndarray-blas-accelerate",
    ))]
    ndarray::run();
    #[cfg(feature = "tch-gpu")]
    tch_gpu::run();
    #[cfg(feature = "tch-cpu")]
    tch_cpu::run();
    #[cfg(feature = "wgpu")]
    wgpu::run();
    #[cfg(feature = "cuda")]
    cuda::run();
}
104 changes: 104 additions & 0 deletions examples/modern-lstm/examples/lstm-train.rs
@@ -0,0 +1,104 @@
use burn::{
    grad_clipping::GradientClippingConfig, optim::AdamConfig, tensor::backend::AutodiffBackend,
};
use modern_lstm::{model::LstmNetworkConfig, training::TrainingConfig};

pub fn launch<B: AutodiffBackend>(device: B::Device) {
    let config = TrainingConfig::new(
        LstmNetworkConfig::new(),
        // Gradient clipping via optimizer config
        AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))),
    );

    modern_lstm::training::train::<B>("/tmp/modern-lstm", config, device);
}

#[cfg(any(
    feature = "ndarray",
    feature = "ndarray-blas-netlib",
    feature = "ndarray-blas-openblas",
    feature = "ndarray-blas-accelerate",
))]
mod ndarray {
    use burn::backend::{
        ndarray::{NdArray, NdArrayDevice},
        Autodiff,
    };

    use crate::launch;

    pub fn run() {
        launch::<Autodiff<NdArray>>(NdArrayDevice::Cpu);
    }
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use burn::backend::{
        libtorch::{LibTorch, LibTorchDevice},
        Autodiff,
    };

    use crate::launch;

    pub fn run() {
        #[cfg(not(target_os = "macos"))]
        let device = LibTorchDevice::Cuda(0);
        #[cfg(target_os = "macos")]
        let device = LibTorchDevice::Mps;

        launch::<Autodiff<LibTorch>>(device);
    }
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
    use burn::backend::{
        libtorch::{LibTorch, LibTorchDevice},
        Autodiff,
    };

    use crate::launch;

    pub fn run() {
        launch::<Autodiff<LibTorch>>(LibTorchDevice::Cpu);
    }
}

#[cfg(feature = "wgpu")]
mod wgpu {
    use crate::launch;
    use burn::backend::{wgpu::Wgpu, Autodiff};

    pub fn run() {
        launch::<Autodiff<Wgpu>>(Default::default());
    }
}

#[cfg(feature = "cuda")]
mod cuda {
    use crate::launch;
    use burn::backend::{cuda::CudaDevice, Autodiff, Cuda};

    pub fn run() {
        launch::<Autodiff<Cuda>>(CudaDevice::default());
    }
}

fn main() {
    #[cfg(any(
        feature = "ndarray",
        feature = "ndarray-blas-netlib",
        feature = "ndarray-blas-openblas",
        feature = "ndarray-blas-accelerate",
    ))]
    ndarray::run();
    #[cfg(feature = "tch-gpu")]
    tch_gpu::run();
    #[cfg(feature = "tch-cpu")]
    tch_cpu::run();
    #[cfg(feature = "wgpu")]
    wgpu::run();
    #[cfg(feature = "cuda")]
    cuda::run();
}
110 changes: 110 additions & 0 deletions examples/modern-lstm/src/dataset.rs
@@ -0,0 +1,110 @@
use burn::{
    data::{
        dataloader::batcher::Batcher,
        dataset::{Dataset, InMemDataset},
    },
    prelude::*,
};
use rand::Rng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};

// Dataset parameters
pub const NUM_SEQUENCES: usize = 1000;
pub const SEQ_LENGTH: usize = 10;
pub const NOISE_LEVEL: f32 = 0.1;
pub const RANDOM_SEED: u64 = 5;

// A sequence where each number is the sum of the previous two numbers plus noise
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SequenceDatasetItem {
    pub sequence: Vec<f32>,
    pub target: f32,
}

impl SequenceDatasetItem {
    pub fn new(seq_length: usize, noise_level: f32) -> Self {
        // Start with two random numbers between 0 and 1
        let mut seq = vec![rand::thread_rng().gen(), rand::thread_rng().gen()];

        // Generate the sequence
        for _i in 0..seq_length {
            // The next number is the sum of the previous two plus noise
            let normal = Normal::new(0.0, noise_level).unwrap();
            let next_val =
                seq[seq.len() - 2] + seq[seq.len() - 1] + normal.sample(&mut rand::thread_rng());
            seq.push(next_val);
        }

        Self {
            // Split into input sequence and target
            sequence: seq[0..seq.len() - 1].to_vec(), // All but the last value
            target: seq[seq.len() - 1],               // The last value
        }
    }
}

// Custom dataset for sequence data
pub struct SequenceDataset {
    dataset: InMemDataset<SequenceDatasetItem>,
}

impl SequenceDataset {
    pub fn new(num_sequences: usize, seq_length: usize, noise_level: f32) -> Self {
        let mut items = vec![];
        for _i in 0..num_sequences {
            items.push(SequenceDatasetItem::new(seq_length, noise_level));
        }
        let dataset = InMemDataset::new(items);

        Self { dataset }
    }
}

impl Dataset<SequenceDatasetItem> for SequenceDataset {
    fn get(&self, index: usize) -> Option<SequenceDatasetItem> {
        self.dataset.get(index)
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}

#[derive(Clone, Debug)]
pub struct SequenceBatcher<B: Backend> {
    device: B::Device,
}

#[derive(Clone, Debug)]
pub struct SequenceBatch<B: Backend> {
    pub sequences: Tensor<B, 3>, // [batch_size, seq_length, input_size]
    pub targets: Tensor<B, 2>,   // [batch_size, 1]
}

impl<B: Backend> SequenceBatcher<B> {
    pub fn new(device: B::Device) -> Self {
        Self { device }
    }
}

impl<B: Backend> Batcher<SequenceDatasetItem, SequenceBatch<B>> for SequenceBatcher<B> {
    fn batch(&self, items: Vec<SequenceDatasetItem>) -> SequenceBatch<B> {
        let mut sequences: Vec<Tensor<B, 2>> = Vec::new();

        for item in items.iter() {
            let seq_tensor = Tensor::<B, 1>::from_floats(item.sequence.as_slice(), &self.device);
            // Add a feature dimension; the input_size is implicitly 1.
            // Reshape here to use a different input_size.
            sequences.push(seq_tensor.unsqueeze_dims(&[-1]));
        }
        let sequences = Tensor::stack(sequences, 0);

        let targets = items
            .iter()
            .map(|item| Tensor::<B, 1>::from_floats([item.target], &self.device))
            .collect();
        let targets = Tensor::stack(targets, 0);

        SequenceBatch { sequences, targets }
    }
}
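
For context, a minimal sketch of how `SequenceDataset` and `SequenceBatcher` might be wired into
Burn's data loader. The `DataLoaderBuilder` calls follow Burn's usual `data::dataloader` API, but
exact builder methods may vary across versions; the `NdArray` backend is assumed for brevity:

```rust
use burn::backend::NdArray;
use burn::data::dataloader::DataLoaderBuilder;
use modern_lstm::dataset::{
    SequenceBatcher, SequenceDataset, NOISE_LEVEL, NUM_SEQUENCES, RANDOM_SEED, SEQ_LENGTH,
};

fn main() {
    let device = Default::default();
    let batcher = SequenceBatcher::<NdArray>::new(device);
    let dataset = SequenceDataset::new(NUM_SEQUENCES, SEQ_LENGTH, NOISE_LEVEL);

    // Yields batches of [batch_size, seq_length, 1] sequences
    // and [batch_size, 1] targets.
    let dataloader = DataLoaderBuilder::new(batcher)
        .batch_size(32)
        .shuffle(RANDOM_SEED)
        .build(dataset);

    for batch in dataloader.iter() {
        let _ = (batch.sequences, batch.targets);
    }
}
```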