modern lstm (#2752)
* modern lstm

* format

* formatting

* formatting

* formatting

* formatting

* fix a typo

* Update examples/modern-lstm/Cargo.toml

Co-authored-by: Guillaume Lagrange <[email protected]>

* use generic backend

* remove Cargo.lock

* use backend for inference

* update readme

* Update README + fix main changes

* Fix clippy

---------

Co-authored-by: Guillaume Lagrange <[email protected]>
wangjiawen2013 and laggui authored Feb 3, 2025
1 parent 9f00320 commit e2fa935
Showing 10 changed files with 926 additions and 0 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

27 changes: 27 additions & 0 deletions examples/modern-lstm/Cargo.toml
@@ -0,0 +1,27 @@
[package]
name = "modern-lstm"
version = "0.1.0"
edition = "2021"

[features]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
cuda = ["burn/cuda"]

[dependencies]
burn = { path = "../../crates/burn", features = ["train"] }

# Random number generator
rand = { workspace = true }
rand_distr = { workspace = true }

# Serialization
serde = { workspace = true, features = ["std", "derive"] }

# Organise the results in a dataframe
polars = { workspace = true }
46 changes: 46 additions & 0 deletions examples/modern-lstm/README.md
@@ -0,0 +1,46 @@
# Advanced LSTM Implementation with Burn

A more advanced implementation of Long Short-Term Memory (LSTM) networks in Burn with combined
weight matrices for the input and hidden states, based on the
[PyTorch implementation](https://github.com/shiv08/Advanced-LSTM-Implementation-with-PyTorch).

`LstmNetwork` is the top-level module with bidirectional and regularization support. The LSTM
variants differ by `bidirectional` and `num_layers` settings:

- LSTM: `num_layers = 1` and `bidirectional = false`
- Stacked LSTM: `num_layers > 1` and `bidirectional = false`
- Bidirectional LSTM: `num_layers = 1` and `bidirectional = true`
- Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true`

This implementation complements Burn's official LSTM; users can choose either one depending on
their project's specific needs.
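
As a rough sketch of how a variant might be selected (assuming `LstmNetworkConfig` derives Burn's
`Config`, which generates `with_*` builder methods; the actual field names live in `src/model.rs`
and are not shown here):

```rust
use modern_lstm::model::LstmNetworkConfig;

// Hypothetical sketch: `with_num_layers` and `with_bidirectional` are assumed
// builder names, not confirmed by this diff.
fn bidirectional_stacked() -> LstmNetworkConfig {
    LstmNetworkConfig::new()
        .with_num_layers(2) // num_layers > 1 => stacked
        .with_bidirectional(true) // bidirectional = true
}
```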

## Usage

### Training

```sh
# Cuda backend
cargo run --example lstm-train --release --features cuda

# Wgpu backend
cargo run --example lstm-train --release --features wgpu

# Tch GPU backend
export TORCH_CUDA_VERSION=cu121 # Set the cuda version
cargo run --example lstm-train --release --features tch-gpu

# Tch CPU backend
cargo run --example lstm-train --release --features tch-cpu

# NdArray backend (CPU)
cargo run --example lstm-train --release --features ndarray
cargo run --example lstm-train --release --features ndarray-blas-openblas
cargo run --example lstm-train --release --features ndarray-blas-netlib
```

### Inference

Inference loads the model artifacts that training saves under `/tmp/modern-lstm`, so run a
training example first.

```sh
cargo run --example lstm-infer --release --features cuda
```
86 changes: 86 additions & 0 deletions examples/modern-lstm/examples/lstm-infer.rs
@@ -0,0 +1,86 @@
use burn::tensor::backend::Backend;

pub fn launch<B: Backend>(device: B::Device) {
    modern_lstm::inference::infer::<B>("/tmp/modern-lstm", device);
}

#[cfg(any(
    feature = "ndarray",
    feature = "ndarray-blas-netlib",
    feature = "ndarray-blas-openblas",
    feature = "ndarray-blas-accelerate",
))]
mod ndarray {
    use burn::backend::ndarray::{NdArray, NdArrayDevice};

    use crate::launch;

    pub fn run() {
        launch::<NdArray>(NdArrayDevice::Cpu);
    }
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use burn::backend::libtorch::{LibTorch, LibTorchDevice};

    use crate::launch;

    pub fn run() {
        #[cfg(not(target_os = "macos"))]
        let device = LibTorchDevice::Cuda(0);
        #[cfg(target_os = "macos")]
        let device = LibTorchDevice::Mps;

        launch::<LibTorch>(device);
    }
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
    use burn::backend::libtorch::{LibTorch, LibTorchDevice};

    use crate::launch;

    pub fn run() {
        launch::<LibTorch>(LibTorchDevice::Cpu);
    }
}

#[cfg(feature = "wgpu")]
mod wgpu {
    use crate::launch;
    use burn::backend::wgpu::Wgpu;

    pub fn run() {
        launch::<Wgpu>(Default::default());
    }
}

#[cfg(feature = "cuda")]
mod cuda {
    use crate::launch;
    use burn::backend::Cuda;

    pub fn run() {
        launch::<Cuda>(Default::default());
    }
}

fn main() {
    #[cfg(any(
        feature = "ndarray",
        feature = "ndarray-blas-netlib",
        feature = "ndarray-blas-openblas",
        feature = "ndarray-blas-accelerate",
    ))]
    ndarray::run();
    #[cfg(feature = "tch-gpu")]
    tch_gpu::run();
    #[cfg(feature = "tch-cpu")]
    tch_cpu::run();
    #[cfg(feature = "wgpu")]
    wgpu::run();
    #[cfg(feature = "cuda")]
    cuda::run();
}
104 changes: 104 additions & 0 deletions examples/modern-lstm/examples/lstm-train.rs
@@ -0,0 +1,104 @@
use burn::{
    grad_clipping::GradientClippingConfig, optim::AdamConfig, tensor::backend::AutodiffBackend,
};
use modern_lstm::{model::LstmNetworkConfig, training::TrainingConfig};

pub fn launch<B: AutodiffBackend>(device: B::Device) {
    let config = TrainingConfig::new(
        LstmNetworkConfig::new(),
        // Gradient clipping via optimizer config
        AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))),
    );

    modern_lstm::training::train::<B>("/tmp/modern-lstm", config, device);
}

#[cfg(any(
    feature = "ndarray",
    feature = "ndarray-blas-netlib",
    feature = "ndarray-blas-openblas",
    feature = "ndarray-blas-accelerate",
))]
mod ndarray {
    use burn::backend::{
        ndarray::{NdArray, NdArrayDevice},
        Autodiff,
    };

    use crate::launch;

    pub fn run() {
        launch::<Autodiff<NdArray>>(NdArrayDevice::Cpu);
    }
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use burn::backend::{
        libtorch::{LibTorch, LibTorchDevice},
        Autodiff,
    };

    use crate::launch;

    pub fn run() {
        #[cfg(not(target_os = "macos"))]
        let device = LibTorchDevice::Cuda(0);
        #[cfg(target_os = "macos")]
        let device = LibTorchDevice::Mps;

        launch::<Autodiff<LibTorch>>(device);
    }
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
    use burn::backend::{
        libtorch::{LibTorch, LibTorchDevice},
        Autodiff,
    };

    use crate::launch;

    pub fn run() {
        launch::<Autodiff<LibTorch>>(LibTorchDevice::Cpu);
    }
}

#[cfg(feature = "wgpu")]
mod wgpu {
    use crate::launch;
    use burn::backend::{wgpu::Wgpu, Autodiff};

    pub fn run() {
        launch::<Autodiff<Wgpu>>(Default::default());
    }
}

#[cfg(feature = "cuda")]
mod cuda {
    use crate::launch;
    use burn::backend::{cuda::CudaDevice, Autodiff, Cuda};

    pub fn run() {
        launch::<Autodiff<Cuda>>(CudaDevice::default());
    }
}

fn main() {
    #[cfg(any(
        feature = "ndarray",
        feature = "ndarray-blas-netlib",
        feature = "ndarray-blas-openblas",
        feature = "ndarray-blas-accelerate",
    ))]
    ndarray::run();
    #[cfg(feature = "tch-gpu")]
    tch_gpu::run();
    #[cfg(feature = "tch-cpu")]
    tch_cpu::run();
    #[cfg(feature = "wgpu")]
    wgpu::run();
    #[cfg(feature = "cuda")]
    cuda::run();
}
110 changes: 110 additions & 0 deletions examples/modern-lstm/src/dataset.rs
@@ -0,0 +1,110 @@
use burn::{
    data::{
        dataloader::batcher::Batcher,
        dataset::{Dataset, InMemDataset},
    },
    prelude::*,
};
use rand::Rng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};

// Dataset parameters
pub const NUM_SEQUENCES: usize = 1000;
pub const SEQ_LENGTH: usize = 10;
pub const NOISE_LEVEL: f32 = 0.1;
pub const RANDOM_SEED: u64 = 5;

// A generated sequence where each number is the sum of the previous two numbers plus noise
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SequenceDatasetItem {
    pub sequence: Vec<f32>,
    pub target: f32,
}

impl SequenceDatasetItem {
    pub fn new(seq_length: usize, noise_level: f32) -> Self {
        // Start with two random numbers between 0 and 1
        let mut seq = vec![rand::thread_rng().gen(), rand::thread_rng().gen()];

        // Generate the sequence
        for _ in 0..seq_length {
            // The next number is the sum of the previous two plus noise
            let normal = Normal::new(0.0, noise_level).unwrap();
            let next_val =
                seq[seq.len() - 2] + seq[seq.len() - 1] + normal.sample(&mut rand::thread_rng());
            seq.push(next_val);
        }

        Self {
            // Split into input sequence and target
            sequence: seq[0..seq.len() - 1].to_vec(), // All but the last value
            target: seq[seq.len() - 1],               // The last value
        }
    }
}

// Custom dataset for sequence data
pub struct SequenceDataset {
    dataset: InMemDataset<SequenceDatasetItem>,
}

impl SequenceDataset {
    pub fn new(num_sequences: usize, seq_length: usize, noise_level: f32) -> Self {
        let mut items = vec![];
        for _ in 0..num_sequences {
            items.push(SequenceDatasetItem::new(seq_length, noise_level));
        }
        let dataset = InMemDataset::new(items);

        Self { dataset }
    }
}

impl Dataset<SequenceDatasetItem> for SequenceDataset {
    fn get(&self, index: usize) -> Option<SequenceDatasetItem> {
        self.dataset.get(index)
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}

#[derive(Clone, Debug)]
pub struct SequenceBatcher<B: Backend> {
    device: B::Device,
}

#[derive(Clone, Debug)]
pub struct SequenceBatch<B: Backend> {
    pub sequences: Tensor<B, 3>, // [batch_size, seq_length, input_size]
    pub targets: Tensor<B, 2>,   // [batch_size, 1]
}

impl<B: Backend> SequenceBatcher<B> {
    pub fn new(device: B::Device) -> Self {
        Self { device }
    }
}

impl<B: Backend> Batcher<SequenceDatasetItem, SequenceBatch<B>> for SequenceBatcher<B> {
    fn batch(&self, items: Vec<SequenceDatasetItem>) -> SequenceBatch<B> {
        let mut sequences: Vec<Tensor<B, 2>> = Vec::new();

        for item in items.iter() {
            let seq_tensor = Tensor::<B, 1>::from_floats(item.sequence.as_slice(), &self.device);
            // Add a feature dimension; the input_size is implicitly 1 here and
            // could be changed with additional tensor operations.
            sequences.push(seq_tensor.unsqueeze_dims(&[-1]));
        }
        let sequences = Tensor::stack(sequences, 0);

        let targets = items
            .iter()
            .map(|item| Tensor::<B, 1>::from_floats([item.target], &self.device))
            .collect();
        let targets = Tensor::stack(targets, 0);

        SequenceBatch { sequences, targets }
    }
}
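
For context, here is a minimal sketch of how this dataset and batcher could be wired into a Burn
`DataLoader` (the batch size is illustrative, and the `DataLoaderBuilder` usage follows Burn's
usual pattern rather than code from this diff):

```rust
use burn::data::dataloader::DataLoaderBuilder;
use burn::tensor::backend::Backend;
use modern_lstm::dataset::{
    SequenceBatcher, SequenceDataset, NOISE_LEVEL, NUM_SEQUENCES, RANDOM_SEED, SEQ_LENGTH,
};

// Minimal sketch, assuming Burn's standard DataLoaderBuilder API.
fn build_loader<B: Backend>(device: B::Device) {
    let batcher = SequenceBatcher::<B>::new(device);
    let dataloader = DataLoaderBuilder::new(batcher)
        .batch_size(32) // illustrative value
        .shuffle(RANDOM_SEED)
        .build(SequenceDataset::new(NUM_SEQUENCES, SEQ_LENGTH, NOISE_LEVEL));

    for batch in dataloader.iter() {
        // batch.sequences: [batch_size, seq_length, 1]
        // batch.targets:   [batch_size, 1]
        let _ = (batch.sequences, batch.targets);
    }
}
```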