Update README + fix main changes
laggui committed Feb 3, 2025
1 parent 7ecc9c4 commit 9bdc8e9
Showing 5 changed files with 50 additions and 43 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion examples/modern-lstm/Cargo.toml
@@ -11,7 +11,7 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
cuda-jit = ["burn/cuda-jit"]
cuda = ["burn/cuda"]

[dependencies]
burn = { path = "../../crates/burn", features=["train"] }
38 changes: 21 additions & 17 deletions examples/modern-lstm/README.md
@@ -1,42 +1,46 @@
# Advanced LSTM Implementation with Burn
A sophisticated implementation of Long Short-Term Memory (LSTM) networks in Burn, featuring state-of-the-art architectural enhancements and optimizations. This implementation includes bidirectional processing capabilities and advanced regularization techniques. More details can be found at the [PyTorch implementation](https://github.com/shiv08/Advanced-LSTM-Implementation-with-PyTorch).

`LstmNetwork` is the top-level module with bidirectional support and output projection. It supports multiple LSTM variants through the `bidirectional` and `num_layers` settings:
* LSTM: `num_layers = 1` and `bidirectional = false`
* Stacked LSTM: `num_layers > 1` and `bidirectional = false`
* Bidirectional LSTM: `num_layers = 1` and `bidirectional = true`
* Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true`
A more advanced implementation of Long Short-Term Memory (LSTM) networks in Burn with combined
weight matrices for the input and hidden states, based on the
[PyTorch implementation](https://github.com/shiv08/Advanced-LSTM-Implementation-with-PyTorch).
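
Concretely, combining the weight matrices is taken here to mean computing the pre-activations of all
four gates (input, forget, cell candidate, output) with one matrix product for the input and one for
the hidden state, then splitting the result. A minimal sketch of a single step, with illustrative
names and shapes rather than the example's actual module API:

```rust
use burn::tensor::{activation, backend::Backend, Tensor};

// Minimal sketch of one LSTM step with combined weight matrices.
// Names and shapes are illustrative, not the example's actual API.
fn lstm_step<B: Backend>(
    x: Tensor<B, 2>,    // input at this timestep: [batch, input_size]
    h: Tensor<B, 2>,    // previous hidden state:  [batch, hidden_size]
    c: Tensor<B, 2>,    // previous cell state:    [batch, hidden_size]
    w_ih: Tensor<B, 2>, // combined input weights:  [input_size, 4 * hidden_size]
    w_hh: Tensor<B, 2>, // combined hidden weights: [hidden_size, 4 * hidden_size]
    bias: Tensor<B, 2>, // combined bias: [1, 4 * hidden_size], broadcast over the batch
) -> (Tensor<B, 2>, Tensor<B, 2>) {
    // One matmul per source yields the pre-activations of all four gates at once.
    let gates = x.matmul(w_ih) + h.matmul(w_hh) + bias;
    // Split along the feature dimension into input, forget, cell-candidate and output gates.
    let parts = gates.chunk(4, 1);
    let i = activation::sigmoid(parts[0].clone());
    let f = activation::sigmoid(parts[1].clone());
    let g = parts[2].clone().tanh();
    let o = activation::sigmoid(parts[3].clone());
    let c_next = f * c + i * g;
    let h_next = o * c_next.clone().tanh();
    (h_next, c_next)
}
```

Compared with four separate projections per source, this keeps each timestep to two matrix products,
which is the main point of the combined layout.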

This implementation is complementary to Burn's official LSTM; users can choose either one depending on the project's specific needs.
`LstmNetwork` is the top-level module with bidirectional and regularization support. The LSTM
variants differ by `bidirectional` and `num_layers` settings:

## Usage
- LSTM: `num_layers = 1` and `bidirectional = false`
- Stacked LSTM: `num_layers > 1` and `bidirectional = false`
- Bidirectional LSTM: `num_layers = 1` and `bidirectional = true`
- Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true`

This implementation is complementary to Burn's official LSTM; users can choose either one depending
on the project's specific needs.
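
As a rough illustration of how `num_layers` and `bidirectional` select a variant, here is a stand-in
config built with Burn's `Config` derive; the type and field names are assumptions for the sketch,
not necessarily those of the example's real `LstmNetworkConfig`:

```rust
use burn::config::Config;

// Stand-in config used only to illustrate the variant selection described above.
#[derive(Config)]
pub struct VariantConfig {
    #[config(default = 1)]
    pub num_layers: usize,
    #[config(default = false)]
    pub bidirectional: bool,
}

fn main() {
    let _lstm = VariantConfig::new(); // LSTM: defaults (1 layer, unidirectional)
    let stacked = VariantConfig::new().with_num_layers(2); // Stacked LSTM
    let bidir = VariantConfig::new().with_bidirectional(true); // Bidirectional LSTM
    let bidir_stacked = VariantConfig::new() // Bidirectional stacked LSTM
        .with_num_layers(2)
        .with_bidirectional(true);

    assert_eq!(stacked.num_layers, 2);
    assert!(bidir.bidirectional && bidir_stacked.bidirectional);
}
```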

## Usage

### Training

```sh
# Cuda backend
cargo run --example train --release --features cuda-jit
cargo run --example lstm-train --release --features cuda

# Wgpu backend
cargo run --example train --release --features wgpu
cargo run --example lstm-train --release --features wgpu

# Tch GPU backend
export TORCH_CUDA_VERSION=cu121 # Set the cuda version
cargo run --example train --release --features tch-gpu
cargo run --example lstm-train --release --features tch-gpu

# Tch CPU backend
cargo run --example train --release --features tch-cpu
cargo run --example lstm-train --release --features tch-cpu

# NdArray backend (CPU)
cargo run --example train --release --features ndarray
cargo run --example train --release --features ndarray-blas-openblas
cargo run --example train --release --features ndarray-blas-netlib
cargo run --example lstm-train --release --features ndarray
cargo run --example lstm-train --release --features ndarray-blas-openblas
cargo run --example lstm-train --release --features ndarray-blas-netlib
```


### Inference

```sh
cargo run --example infer --release --features cuda-jit
cargo run --example lstm-infer --release --features cuda
```
@@ -11,9 +11,7 @@ pub fn launch<B: Backend>(device: B::Device) {
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::{
ndarray::{NdArray, NdArrayDevice}
};
use burn::backend::ndarray::{NdArray, NdArrayDevice};

use crate::launch;

@@ -24,9 +22,7 @@ mod ndarray {

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice}
};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

use crate::launch;

@@ -42,9 +38,7 @@ mod tch_gpu {

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice}
};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

use crate::launch;

@@ -63,13 +57,13 @@ mod wgpu {
}
}

#[cfg(feature = "cuda-jit")]
mod cuda_jit {
#[cfg(feature = "cuda")]
mod cuda {
use crate::launch;
use burn::backend::CudaJit;
use burn::backend::Cuda;

pub fn run() {
launch::<CudaJit>(Default::default());
launch::<Cuda>(Default::default());
}
}

@@ -87,6 +81,6 @@ fn main() {
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
#[cfg(feature = "cuda-jit")]
cuda_jit::run();
#[cfg(feature = "cuda")]
cuda::run();
}
@@ -1,15 +1,13 @@
use burn::{
grad_clipping::GradientClippingConfig,
optim::AdamConfig,
tensor::backend::AutodiffBackend
grad_clipping::GradientClippingConfig, optim::AdamConfig, tensor::backend::AutodiffBackend,
};
use modern_lstm::{model::LstmNetworkConfig, training::TrainingConfig};

pub fn launch<B: AutodiffBackend>(device: B::Device) {
let config = TrainingConfig::new(
LstmNetworkConfig::new(),
// Gradient clipping via optimizer config
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(1.0)))
AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))),
);

modern_lstm::training::train::<B>("/tmp/modern-lstm", config, device);
@@ -77,13 +75,13 @@ mod wgpu {
}
}

#[cfg(feature = "cuda-jit")]
mod cuda_jit {
#[cfg(feature = "cuda")]
mod cuda {
use crate::launch;
use burn::backend::{cuda_jit::CudaDevice, Autodiff, CudaJit};
use burn::backend::{cuda::CudaDevice, Autodiff, Cuda};

pub fn run() {
launch::<Autodiff<CudaJit>>(CudaDevice::default());
launch::<Autodiff<Cuda>>(CudaDevice::default());
}
}

@@ -101,6 +99,6 @@ fn main() {
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
#[cfg(feature = "cuda-jit")]
cuda_jit::run();
#[cfg(feature = "cuda")]
cuda::run();
}
