modern lstm
wangjiawen2013 committed Jan 28, 2025
1 parent 29c383b commit f0cddd5
Showing 10 changed files with 1,673 additions and 147 deletions.
934 changes: 787 additions & 147 deletions Cargo.lock

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions examples/modern-lstm/Cargo.toml
@@ -0,0 +1,29 @@
[package]
name = "lstm"
version = "0.1.0"
edition = "2021"

[features]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]

[dependencies]
burn = { version="0.16.0", features=["train"] }

# Random number generator
rand = { version="0.8.5" }
rand_distr = { version="0.4.3" }

# Serialization
serde = { version="1.0.210", features=["std", "derive"] }

# Organise the results in a dataframe
polars = { version="0.44.1" }

# Command line parser
clap = { version="4.5.21", features=["derive"] }
22 changes: 22 additions & 0 deletions examples/modern-lstm/README.md
@@ -0,0 +1,22 @@
# Advanced LSTM Implementation with Burn
A sophisticated implementation of Long Short-Term Memory (LSTM) networks in Burn, featuring state-of-the-art architectural enhancements and optimizations. This implementation includes bidirectional processing capabilities and advanced regularization techniques, making it suitable for both research and production environments. More details can be found at the [PyTorch implementation](https://github.com/shiv08/Advanced-LSTM-Implementation-with-PyTorch).

`LstmNetwork` is the top-level module, with bidirectional support and output projection. It supports multiple LSTM variants through the `bidirectional` and `num_layers` settings:
* LSTM: `num_layers = 1` and `bidirectional = false`
* Stacked LSTM: `num_layers > 1` and `bidirectional = false`
* Bidirectional LSTM: `num_layers = 1` and `bidirectional = true`
* Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true`

This implementation complements Burn's official LSTM; users can choose either one depending on the project's specific needs. A rough configuration sketch for these variants follows below.
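
As an illustration only, the snippet below configures the stacked bidirectional variant. The concrete config type and its fields live in `src/model.rs` (not reproduced above), so `LstmNetworkConfig` and the argument names/order here are assumptions rather than the crate's verified API:

```rust
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use lstm::model::{LstmNetwork, LstmNetworkConfig}; // `LstmNetworkConfig` is a hypothetical name

fn main() {
    type B = NdArray;
    let device = NdArrayDevice::Cpu;

    // Illustrative values only; check `src/model.rs` for the real config fields.
    let config = LstmNetworkConfig::new(
        1,    // input_size: one feature per time step
        32,   // hidden_size
        1,    // output_size
        2,    // num_layers > 1  -> stacked
        true, // bidirectional = true -> bidirectional stacked LSTM
    );
    let model: LstmNetwork<B> = config.init(&device);

    // `model.forward(sequences, None)` then maps [batch, seq_len, 1] -> [batch, 1],
    // matching the shapes produced by `SequenceBatcher` and consumed by `infer`.
    let _ = model;
}
```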

## Example Usage

### Training
```sh
cargo run --release --features ndarray -- train --artifact-dir ./output --num-epochs 30 --batch-size 32 --num-workers 2 --lr 0.001
```

### Inference
```sh
cargo run --release --features ndarray -- infer --artifact-dir ./output
```
42 changes: 42 additions & 0 deletions examples/modern-lstm/src/cli.rs
@@ -0,0 +1,42 @@
use clap::{Parser, Subcommand};

/// A CLI for a long short-term memory network
#[derive(Parser, Debug)]
#[command(version, author, about, long_about=None)]
pub struct Cli {
#[command(subcommand)]
pub command: Commands,
}

#[derive(Subcommand, Debug)]
pub enum Commands {
/// Train a model
Train {
/// Path to save trained model
#[arg(long)]
artifact_dir: String,

/// Number of epochs of training
#[arg(long, default_value="200")]
num_epochs: usize,

/// Size of the batches
#[arg(long, default_value="64")]
batch_size: usize,

/// Number of CPU threads to use during batch generation
#[arg(long, default_value="8")]
num_workers: usize,

/// Learning rate
#[arg(long, default_value="0.00005")]
lr: f64,
},

/// Inference
Infer {
/// Path to the trained model
#[arg(long)]
artifact_dir: String,
}
}
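
`main.rs` is not part of this excerpt. A minimal sketch of how the `Cli`/`Commands` types above would typically be consumed is shown below; the calls into `training`/`inference` are assumptions about the rest of the commit, not code shown here:

```rust
use clap::Parser;
use lstm::cli::{Cli, Commands};

fn main() {
    let cli = Cli::parse();
    match cli.command {
        Commands::Train { artifact_dir, num_epochs, batch_size, num_workers, lr } => {
            // A real main would dispatch to something like
            // `lstm::training::train::<MyBackend>(...)` here (assumed entry point).
            println!("train -> {artifact_dir}, {num_epochs} epochs, batch {batch_size}, {num_workers} workers, lr {lr}");
        }
        Commands::Infer { artifact_dir } => {
            // Likewise, `lstm::inference::infer::<MyBackend>(&artifact_dir, device)`.
            println!("infer -> {artifact_dir}");
        }
    }
}
```
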
111 changes: 111 additions & 0 deletions examples/modern-lstm/src/dataset.rs
@@ -0,0 +1,111 @@
use burn::{
data::{
dataloader::batcher::Batcher,
dataset::{Dataset, InMemDataset},
},
prelude::*,
};
use rand::Rng;
use rand_distr::{Normal, Distribution};
use serde::{Deserialize, Serialize};

// Dataset parameters
pub const NUM_SEQUENCES: usize = 1000;
pub const SEQ_LENGTH: usize = 10;
pub const NOISE_LEVEL: f32 = 0.1;
pub const RANDOM_SEED: u64 = 5;

// Generate a sequence where each number is the sum of the previous two numbers plus noise
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SequenceDatasetItem {
pub sequence: Vec<f32>,
pub target: f32,
}

impl SequenceDatasetItem {
pub fn new(seq_length: usize, noise_level: f32) -> Self {
// Start with two random numbers between 0 and 1
let mut seq = vec![rand::thread_rng().gen(), rand::thread_rng().gen()];

// Generate sequence
for _i in 0..seq_length {
// Next number is sum of previous two plus noise
let normal = Normal::new(0.0, noise_level).unwrap();
let next_val = seq[seq.len()-2] + seq[seq.len()-1] + normal.sample(&mut rand::thread_rng());
seq.push(next_val);
}

Self {
// Convert to sequence and target
sequence: seq[0..seq.len()-1].to_vec(), // All but last
target: seq[seq.len()-1], // Last value
}
}
}

// Custom Dataset for Sequence Data
pub struct SequenceDataset {
dataset: InMemDataset<SequenceDatasetItem>,
}

impl SequenceDataset {
pub fn new(num_sequences: usize, seq_length: usize, noise_level: f32) -> Self {
let mut items = vec![];
for _i in 0..num_sequences {
items.push(SequenceDatasetItem::new(seq_length, noise_level));
}
let dataset = InMemDataset::new(items);

Self { dataset }
}
}

impl Dataset<SequenceDatasetItem> for SequenceDataset {
fn get(&self, index: usize) -> Option<SequenceDatasetItem> {
self.dataset.get(index)
}

fn len(&self) -> usize {
self.dataset.len()
}
}

#[derive(Clone, Debug)]
pub struct SequenceBatcher<B: Backend> {
device: B::Device,
}

#[derive(Clone, Debug)]
pub struct SequenceBatch<B: Backend> {
pub sequences: Tensor<B, 3>, // [batch_size, seq_length, input_size]
pub targets: Tensor<B, 2>, // [batch_size, 1]
}

impl<B: Backend> SequenceBatcher<B> {
pub fn new(device: B::Device) -> Self {
Self {
device,
}
}
}

impl<B: Backend> Batcher<SequenceDatasetItem, SequenceBatch<B>> for SequenceBatcher<B> {
fn batch(&self, items: Vec<SequenceDatasetItem>) -> SequenceBatch<B> {
let mut sequences: Vec<Tensor<B, 2>> = Vec::new();

for item in items.iter() {
let seq_tensor = Tensor::<B, 1>::from_floats(item.sequence.as_slice(), &self.device);
// Add a feature dimension; the input_size is implicitly 1. It could be changed here with additional tensor ops
sequences.push(seq_tensor.unsqueeze_dims(&[-1]));
}
let sequences = Tensor::stack(sequences, 0);

let targets = items
.iter()
.map(|item| Tensor::<B, 1>::from_floats([item.target], &self.device))
.collect();
let targets = Tensor::stack(targets, 0);

SequenceBatch { sequences, targets }
}
}
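
The training code that consumes this batcher appears later in the commit and is not reproduced in this excerpt. Assuming a conventional Burn training setup, `SequenceDataset` and `SequenceBatcher` would be wired into `DataLoaderBuilder` roughly as in the sketch below (the batch size and worker count are placeholders):

```rust
use burn::data::dataloader::DataLoaderBuilder;
use burn::prelude::Backend;
use lstm::dataset::{
    SequenceBatcher, SequenceDataset, NOISE_LEVEL, NUM_SEQUENCES, RANDOM_SEED, SEQ_LENGTH,
};

fn build_dataloader<B: Backend>(device: B::Device) {
    let batcher = SequenceBatcher::<B>::new(device);
    let dataloader = DataLoaderBuilder::new(batcher)
        .batch_size(32)
        .shuffle(RANDOM_SEED)
        .num_workers(2)
        .build(SequenceDataset::new(NUM_SEQUENCES, SEQ_LENGTH, NOISE_LEVEL));

    for batch in dataloader.iter() {
        // batch.sequences: [batch_size, seq_length, input_size]
        // batch.targets:   [batch_size, 1]
        let _ = (batch.sequences, batch.targets);
    }
}
```
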
46 changes: 46 additions & 0 deletions examples/modern-lstm/src/inference.rs
@@ -0,0 +1,46 @@
use crate::{
dataset::{
SequenceBatcher,
SequenceDataset,
SequenceDatasetItem,
NUM_SEQUENCES, SEQ_LENGTH, NOISE_LEVEL,
},
training::TrainingConfig,
model::LstmNetwork,
};
use polars::prelude::*;
use burn::{
data::{dataset::Dataset, dataloader::batcher::Batcher},
prelude::*,
record::{CompactRecorder, Recorder},
};

pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device) {
// Load the trained model and its configuration
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
.expect("Config should exist for the model; run train first");
let record = CompactRecorder::new()
.load(format!("{artifact_dir}/model").into(), &device)
.expect("Trained model should exist; run train first");

let model: LstmNetwork<B> = config.model.init(&device).load_record(record);

let dataset = SequenceDataset::new(NUM_SEQUENCES / 5, SEQ_LENGTH, NOISE_LEVEL);
let items: Vec<SequenceDatasetItem> = dataset.iter().collect();

let batcher = SequenceBatcher::new(device);
// Put all items in one batch
let batch = batcher.batch(items);
let predicted = model.forward(batch.sequences, None);
let targets = batch.targets;

let predicted = predicted.squeeze::<1>(1).into_data();
let expected = targets.squeeze::<1>(1).into_data();

// Display the predicted vs expected values
let results = df![
"predicted" => &predicted.to_vec::<f32>().unwrap(),
"expected" => &expected.to_vec::<f32>().unwrap(),
].unwrap();
println!("{}", &results.head(Some(10)));
}
5 changes: 5 additions & 0 deletions examples/modern-lstm/src/lib.rs
@@ -0,0 +1,5 @@
pub mod dataset;
pub mod model;
pub mod training;
pub mod inference;
pub mod cli;