Commit f0cddd5 (1 parent: 29c383b)

Showing 10 changed files with 1,673 additions and 147 deletions.
Large diffs are not rendered by default; only 6 of the 10 changed files appear below.
Cargo.toml (new file, +29 lines):

```toml
[package]
name = "lstm"
version = "0.1.0"
edition = "2021"

[features]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]

[dependencies]
burn = { version = "0.16.0", features = ["train"] }

# Random number generation
rand = { version = "0.8.5" }
rand_distr = { version = "0.4.3" }

# Serialization
serde = { version = "1.0.210", features = ["std", "derive"] }

# Organise the results in a dataframe
polars = { version = "0.44.1" }

# Command line parser
clap = { version = "4.5.21", features = ["derive"] }
```
README.md (new file, +22 lines):

# Advanced LSTM Implementation with Burn

A sophisticated implementation of Long Short-Term Memory (LSTM) networks in Burn, featuring state-of-the-art architectural enhancements and optimizations. This implementation includes bidirectional processing capabilities and advanced regularization techniques, making it suitable for both research and production environments. More details can be found in the [PyTorch implementation](https://github.com/shiv08/Advanced-LSTM-Implementation-with-PyTorch).

`LstmNetwork` is the top-level module, with bidirectional support and output projection. It supports multiple LSTM variants through appropriate settings of `bidirectional` and `num_layers` (see the sketch below):

* LSTM: `num_layers = 1` and `bidirectional = false`
* Stacked LSTM: `num_layers > 1` and `bidirectional = false`
* Bidirectional LSTM: `num_layers = 1` and `bidirectional = true`
* Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true`

This implementation complements Burn's official LSTM; users can choose either one depending on the project's specific needs.
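For a concrete feel, here is a hypothetical configuration sketch. `model.rs` is among the diffs not rendered in this commit view, so the config type and field names below are illustrative assumptions rather than the actual API:

```rust
// Hypothetical sketch only: model.rs is not rendered in this diff, so
// `LstmNetworkConfig` and its fields are assumed names, not the real API.
let config = LstmNetworkConfig {
    input_size: 1,
    hidden_size: 32,
    output_size: 1,
    num_layers: 2,       // > 1 gives the stacked variant
    bidirectional: true, // true gives the bidirectional variant
};
let model: LstmNetwork<B> = config.init(&device);
```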
## Example Usage

### Training

```sh
cargo run --release --features ndarray -- train --artifact-dir /home/wangjw/data/work/projects/lstm/output --num-epochs 30 --batch-size 32 --num-workers 2 --lr 0.001
```

### Inference

```sh
cargo run --release --features ndarray -- infer --artifact-dir /home/wangjw/data/work/projects/lstm/output
```
cli.rs (new file, +42 lines):

```rust
use clap::{Parser, Subcommand};

/// A CLI for a long short-term memory network
#[derive(Parser, Debug)]
#[command(version, author, about, long_about = None)]
pub struct Cli {
    #[command(subcommand)]
    pub command: Commands,
}

#[derive(Subcommand, Debug)]
pub enum Commands {
    /// Train a model
    Train {
        /// Path to save the trained model
        #[arg(long)]
        artifact_dir: String,

        /// Number of training epochs
        #[arg(long, default_value = "200")]
        num_epochs: usize,

        /// Size of the batches
        #[arg(long, default_value = "64")]
        batch_size: usize,

        /// Number of CPU threads to use during batch generation
        #[arg(long, default_value = "8")]
        num_workers: usize,

        /// Learning rate
        #[arg(long, default_value = "0.00005")]
        lr: f64,
    },

    /// Run inference with a trained model
    Infer {
        /// Path to the trained model
        #[arg(long)]
        artifact_dir: String,
    },
}
```
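The binary's `main.rs` is not among the rendered diffs. As a minimal sketch of how this CLI could be wired up, with `println!` bodies standing in for the real calls into the `training` and `inference` modules:

```rust
// Hypothetical wiring sketch; the commit's actual main.rs is not rendered here.
use clap::Parser;
use lstm::cli::{Cli, Commands};

fn main() {
    match Cli::parse().command {
        Commands::Train { artifact_dir, num_epochs, batch_size, num_workers, lr } => {
            // The real binary would hand these off to the training module.
            println!("train -> {artifact_dir}: {num_epochs} epochs, batch size {batch_size}, {num_workers} workers, lr {lr}");
        }
        Commands::Infer { artifact_dir } => {
            // The real binary would pick a backend and call inference::infer.
            println!("infer -> {artifact_dir}");
        }
    }
}
```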
dataset.rs (new file, +111 lines):

```rust
use burn::{
    data::{
        dataloader::batcher::Batcher,
        dataset::{Dataset, InMemDataset},
    },
    prelude::*,
};
use rand::Rng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};

// Dataset parameters
pub const NUM_SEQUENCES: usize = 1000;
pub const SEQ_LENGTH: usize = 10;
pub const NOISE_LEVEL: f32 = 0.1;
pub const RANDOM_SEED: u64 = 5;

// Generate a sequence where each number is the sum of the previous two numbers plus noise
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SequenceDatasetItem {
    pub sequence: Vec<f32>,
    pub target: f32,
}

impl SequenceDatasetItem {
    pub fn new(seq_length: usize, noise_level: f32) -> Self {
        // Start with two random numbers between 0 and 1
        let mut seq = vec![rand::thread_rng().gen(), rand::thread_rng().gen()];

        // Generate the sequence
        for _i in 0..seq_length {
            // The next number is the sum of the previous two plus noise
            let normal = Normal::new(0.0, noise_level).unwrap();
            let next_val =
                seq[seq.len() - 2] + seq[seq.len() - 1] + normal.sample(&mut rand::thread_rng());
            seq.push(next_val);
        }

        Self {
            // Split into input sequence and target
            sequence: seq[0..seq.len() - 1].to_vec(), // All but the last value
            target: seq[seq.len() - 1],               // The last value
        }
    }
}

// Custom dataset for sequence data
pub struct SequenceDataset {
    dataset: InMemDataset<SequenceDatasetItem>,
}

impl SequenceDataset {
    pub fn new(num_sequences: usize, seq_length: usize, noise_level: f32) -> Self {
        let mut items = vec![];
        for _i in 0..num_sequences {
            items.push(SequenceDatasetItem::new(seq_length, noise_level));
        }
        let dataset = InMemDataset::new(items);

        Self { dataset }
    }
}

impl Dataset<SequenceDatasetItem> for SequenceDataset {
    fn get(&self, index: usize) -> Option<SequenceDatasetItem> {
        self.dataset.get(index)
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}

#[derive(Clone, Debug)]
pub struct SequenceBatcher<B: Backend> {
    device: B::Device,
}

#[derive(Clone, Debug)]
pub struct SequenceBatch<B: Backend> {
    pub sequences: Tensor<B, 3>, // [batch_size, seq_length, input_size]
    pub targets: Tensor<B, 2>,   // [batch_size, 1]
}

impl<B: Backend> SequenceBatcher<B> {
    pub fn new(device: B::Device) -> Self {
        Self { device }
    }
}

impl<B: Backend> Batcher<SequenceDatasetItem, SequenceBatch<B>> for SequenceBatcher<B> {
    fn batch(&self, items: Vec<SequenceDatasetItem>) -> SequenceBatch<B> {
        let mut sequences: Vec<Tensor<B, 2>> = Vec::new();

        for item in items.iter() {
            let seq_tensor = Tensor::<B, 1>::from_floats(item.sequence.as_slice(), &self.device);
            // Add a feature dimension; the input_size is implicitly 1.
            // The input_size could be changed here with further tensor operations.
            sequences.push(seq_tensor.unsqueeze_dims(&[-1]));
        }
        let sequences = Tensor::stack(sequences, 0);

        let targets = items
            .iter()
            .map(|item| Tensor::<B, 1>::from_floats([item.target], &self.device))
            .collect();
        let targets = Tensor::stack(targets, 0);

        SequenceBatch { sequences, targets }
    }
}
```
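A usage sketch for the dataset and batcher, assuming the `ndarray` feature is enabled. Note that generation yields `SEQ_LENGTH + 1` input timesteps per item: two random seeds plus `SEQ_LENGTH` generated values, minus the final value held out as the target:

```rust
// Usage sketch, assuming the ndarray backend feature is enabled.
use burn::backend::NdArray;
use burn::data::{dataloader::batcher::Batcher, dataset::Dataset};
use lstm::dataset::{SequenceBatcher, SequenceDataset, NOISE_LEVEL, SEQ_LENGTH};

fn main() {
    let dataset = SequenceDataset::new(4, SEQ_LENGTH, NOISE_LEVEL);
    let items: Vec<_> = dataset.iter().collect();

    let batcher = SequenceBatcher::<NdArray>::new(Default::default());
    let batch = batcher.batch(items);

    // sequences: [4, SEQ_LENGTH + 1, 1], targets: [4, 1]
    println!("{:?} {:?}", batch.sequences.dims(), batch.targets.dims());
}
```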
inference.rs (new file, +46 lines):

```rust
use crate::{
    dataset::{
        SequenceBatcher, SequenceDataset, SequenceDatasetItem, NOISE_LEVEL, NUM_SEQUENCES,
        SEQ_LENGTH,
    },
    model::LstmNetwork,
    training::TrainingConfig,
};
use burn::{
    data::{dataloader::batcher::Batcher, dataset::Dataset},
    prelude::*,
    record::{CompactRecorder, Recorder},
};
use polars::prelude::*;

pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device) {
    // Load the model
    let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
        .expect("Config should exist for the model; run train first");
    let record = CompactRecorder::new()
        .load(format!("{artifact_dir}/model").into(), &device)
        .expect("Trained model should exist; run train first");

    let model: LstmNetwork<B> = config.model.init(&device).load_record(record);

    let dataset = SequenceDataset::new(NUM_SEQUENCES / 5, SEQ_LENGTH, NOISE_LEVEL);
    let items: Vec<SequenceDatasetItem> = dataset.iter().collect();

    let batcher = SequenceBatcher::new(device);
    // Put all items in one batch
    let batch = batcher.batch(items);
    let predicted = model.forward(batch.sequences, None);
    let targets = batch.targets;

    let predicted = predicted.squeeze::<1>(1).into_data();
    let expected = targets.squeeze::<1>(1).into_data();

    // Display the predicted vs expected values
    let results = df![
        "predicted" => &predicted.to_vec::<f32>().unwrap(),
        "expected" => &expected.to_vec::<f32>().unwrap(),
    ]
    .unwrap();
    println!("{}", &results.head(Some(10)));
}
```
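And a minimal invocation sketch, assuming the `ndarray` feature and a model already trained into the artifact directory (the path here is illustrative):

```rust
// Invocation sketch; assumes `train` already wrote a model to this directory.
use burn::backend::NdArray;

fn main() {
    lstm::inference::infer::<NdArray>("./output", Default::default()); // path is illustrative
}
```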
lib.rs (new file, +5 lines):

```rust
pub mod dataset;
pub mod model;
pub mod training;
pub mod inference;
pub mod cli;
```