Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modern lstm #2752

Merged
merged 15 commits into from
Feb 3, 2025
934 changes: 787 additions & 147 deletions Cargo.lock

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions examples/modern-lstm/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
[package]
name = "lstm"
version = "0.1.0"
edition = "2021"

[features]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]

[dependencies]
burn = { version="0.16.0", features=["train"] }

# Random number generator
rand = { version="0.8.5" }
rand_distr = { version="0.4.3" }

# Serialization
serde = { version="1.0.210", features=["std", "derive"] }

# Organise the results in dataframe
polars = { version="0.44.1" }

# Command line parser
clap = { version="4.5.21", features=["derive"] }
wangjiawen2013 marked this conversation as resolved.
Show resolved Hide resolved
22 changes: 22 additions & 0 deletions examples/modern-lstm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Advanced LSTM Implementation with Burn
A sophisticated implementation of Long Short-Term Memory (LSTM) networks in Burn, featuring state-of-the-art architectural enhancements and optimizations. This implementation includes bidirectional processing capabilities and advanced regularization techniques, making it suitable for both research and production environments. More details can be found at the [PyTorch implementation](https://github.com/shiv08/Advanced-LSTM-Implementation-with-PyTorch).
laggui marked this conversation as resolved.
Show resolved Hide resolved

`LstmNetwork` is the top-level module with bidirectional support and output projection. It can support multiple LSTM variants by setting appropriate `bidirectional` and `num_layers`:
* LSTM: `num_layers = 1` and `bidirectional = false`
* Stacked LSTM: `num_layers > 1` and `bidirectional = false`
* Bidirectional LSTM: `num_layers = 1` and `bidirectional = true`
* Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true`

This implementation is complementary to Burn's official LSTM, users can choose either one depends on the project's specific needs.

## Example Usage

### Training
```sh
cargo run --release --features ndarray -- train --artifact-dir /home/wangjw/data/work/projects/lstm/output --num-epochs 30 --batch-size 32 --num-workers 2 --lr 0.001
```

### Inference
```sh
cargo run --release --features ndarray -- infer --artifact-dir /home/wangjw/data/work/projects/lstm/output
laggui marked this conversation as resolved.
Show resolved Hide resolved
```
42 changes: 42 additions & 0 deletions examples/modern-lstm/src/cli.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use clap::{Parser, Subcommand};

/// A CLI for long short-term memory network
#[derive(Parser, Debug)]
#[command(version, author, about, long_about=None)]
pub struct Cli {
#[command(subcommand)]
pub command: Commands,
}

#[derive(Subcommand, Debug)]
pub enum Commands {
/// Train a model
Train {
/// Path to save trained model
#[arg(long)]
artifact_dir: String,

/// Number of epochs of training
#[arg(long, default_value = "200")]
num_epochs: usize,

/// Size of the batches
#[arg(long, default_value = "64")]
batch_size: usize,

/// Number of cpu threads to use during batch generation
#[arg(long, default_value = "8")]
num_workers: usize,

/// Learning rate
#[arg(long, default_value = "0.00005")]
lr: f64,
},

/// Inference
Infer {
/// Path to the trained model
#[arg(long)]
artifact_dir: String,
},
}
110 changes: 110 additions & 0 deletions examples/modern-lstm/src/dataset.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
use burn::{
data::{
dataloader::batcher::Batcher,
dataset::{Dataset, InMemDataset},
},
prelude::*,
};
use rand::Rng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};

// Dataset parameters
pub const NUM_SEQUENCES: usize = 1000;
pub const SEQ_LENGTH: usize = 10;
pub const NOISE_LEVEL: f32 = 0.1;
pub const RANDOM_SEED: u64 = 5;

// Generate a sequence where each number is the sum of previous two numbers plus noise
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SequenceDatasetItem {
pub sequence: Vec<f32>,
pub target: f32,
}

impl SequenceDatasetItem {
pub fn new(seq_length: usize, noise_level: f32) -> Self {
// Start with two random numbers between 0 and 1
let mut seq = vec![rand::thread_rng().gen(), rand::thread_rng().gen()];

// Generate sequence
for _i in 0..seq_length {
// Next number is sum of previous two plus noise
let normal = Normal::new(0.0, noise_level).unwrap();
let next_val =
seq[seq.len() - 2] + seq[seq.len() - 1] + normal.sample(&mut rand::thread_rng());
seq.push(next_val);
}

Self {
// Convert to sequence and target
sequence: seq[0..seq.len() - 1].to_vec(), // All but last
target: seq[seq.len() - 1], // Last value
}
}
}

// Custom Dataset for Sequence Data
pub struct SequenceDataset {
dataset: InMemDataset<SequenceDatasetItem>,
}

impl SequenceDataset {
pub fn new(num_sequences: usize, seq_length: usize, noise_level: f32) -> Self {
let mut items = vec![];
for _i in 0..num_sequences {
items.push(SequenceDatasetItem::new(seq_length, noise_level));
}
let dataset = InMemDataset::new(items);

Self { dataset }
}
}

impl Dataset<SequenceDatasetItem> for SequenceDataset {
fn get(&self, index: usize) -> Option<SequenceDatasetItem> {
self.dataset.get(index)
}

fn len(&self) -> usize {
self.dataset.len()
}
}

#[derive(Clone, Debug)]
pub struct SequenceBatcher<B: Backend> {
device: B::Device,
}

#[derive(Clone, Debug)]
pub struct SequenceBatch<B: Backend> {
pub sequences: Tensor<B, 3>, // [batch_size, seq_length, input_size]
pub targets: Tensor<B, 2>, // [batch_size, 1]
}

impl<B: Backend> SequenceBatcher<B> {
pub fn new(device: B::Device) -> Self {
Self { device }
}
}

impl<B: Backend> Batcher<SequenceDatasetItem, SequenceBatch<B>> for SequenceBatcher<B> {
fn batch(&self, items: Vec<SequenceDatasetItem>) -> SequenceBatch<B> {
let mut sequences: Vec<Tensor<B, 2>> = Vec::new();

for item in items.iter() {
let seq_tensor = Tensor::<B, 1>::from_floats(item.sequence.as_slice(), &self.device);
// Add feature dimension, the input_size is 1 implicitly. We can change the input_size here with some operations
sequences.push(seq_tensor.unsqueeze_dims(&[-1]));
}
let sequences = Tensor::stack(sequences, 0);

let targets = items
.iter()
.map(|item| Tensor::<B, 1>::from_floats([item.target], &self.device))
.collect();
let targets = Tensor::stack(targets, 0);

SequenceBatch { sequences, targets }
}
}
45 changes: 45 additions & 0 deletions examples/modern-lstm/src/inference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use crate::{
dataset::{
SequenceBatcher, SequenceDataset, SequenceDatasetItem, NOISE_LEVEL, NUM_SEQUENCES,
SEQ_LENGTH,
},
model::LstmNetwork,
training::TrainingConfig,
};
use burn::{
data::{dataloader::batcher::Batcher, dataset::Dataset},
prelude::*,
record::{CompactRecorder, Recorder},
};
use polars::prelude::*;

pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device) {
// Loading model
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
.expect("Config should exist for the model; run train first");
let record = CompactRecorder::new()
.load(format!("{artifact_dir}/model").into(), &device)
.expect("Trained model should exist; run train first");

let model: LstmNetwork<B> = config.model.init(&device).load_record(record);

let dataset = SequenceDataset::new(NUM_SEQUENCES / 5, SEQ_LENGTH, NOISE_LEVEL);
let items: Vec<SequenceDatasetItem> = dataset.iter().collect();

let batcher = SequenceBatcher::new(device);
// Put all items in one batch
let batch = batcher.batch(items);
let predicted = model.forward(batch.sequences, None);
let targets = batch.targets;

let predicted = predicted.squeeze::<1>(1).into_data();
let expected = targets.squeeze::<1>(1).into_data();

// Display the predicted vs expected values
let results = df![
"predicted" => &predicted.to_vec::<f32>().unwrap(),
"expected" => &expected.to_vec::<f32>().unwrap(),
]
.unwrap();
println!("{}", &results.head(Some(10)));
}
5 changes: 5 additions & 0 deletions examples/modern-lstm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pub mod cli;
pub mod dataset;
pub mod inference;
pub mod model;
pub mod training;
Loading
Loading