Commit 2872169

formatting
wangjiawen2013 committed Jan 28, 2025
1 parent 31477f2 commit 2872169
Showing 4 changed files with 19 additions and 17 deletions.
examples/modern-lstm/src/dataset.rs (2 changes: 1 addition & 1 deletion)

@@ -32,7 +32,7 @@ impl SequenceDatasetItem
// Next number is sum of previous two plus noise
let normal = Normal::new(0.0, noise_level).unwrap();
let next_val =
- seq[seq.len()-2] + seq[seq.len()-1] + normal.sample(&mut rand::thread_rng());
+ seq[seq.len() - 2] + seq[seq.len() - 1] + normal.sample(&mut rand::thread_rng());
seq.push(next_val);
}

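As context for the line reformatted above: the dataset builds noisy Fibonacci-style sequences in which each value is the sum of the previous two plus Gaussian noise. A minimal standalone sketch of that generation step, using the same rand/rand_distr calls that appear in the hunk (the generate_sequence helper and the random starting values are assumptions for illustration, not the example's actual dataset code):

use rand::Rng;
use rand_distr::{Distribution, Normal};

// Each next value is the sum of the previous two plus Gaussian noise,
// mirroring the line changed in the hunk above.
fn generate_sequence(seq_length: usize, noise_level: f32) -> Vec<f32> {
    let mut rng = rand::thread_rng();
    let normal = Normal::new(0.0, noise_level).unwrap();
    // Two random starting values in [0, 1); an illustrative choice only.
    let mut seq: Vec<f32> = vec![rng.gen(), rng.gen()];
    while seq.len() < seq_length {
        let next_val = seq[seq.len() - 2] + seq[seq.len() - 1] + normal.sample(&mut rng);
        seq.push(next_val);
    }
    seq
}

fn main() {
    println!("{:?}", generate_sequence(10, 0.1));
}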
examples/modern-lstm/src/main.rs (2 changes: 1 addition & 1 deletion)

@@ -136,7 +136,7 @@ mod tch_gpu {
lr,
);
train::<Autodiff<LibTorch>>(&artifact_dir, config, device);
- },
+ }
Commands::Infer { artifact_dir } => {
infer::<LibTorch>(&artifact_dir, device);
}
examples/modern-lstm/src/model.rs (28 changes: 15 additions & 13 deletions)

@@ -30,7 +30,7 @@ pub struct LstmCell<B: Backend> {
// weight_hh layer uses combined weights for [i_t, f_t, g_t, o_t] for hidden state h_{t-1}
pub weight_ih: Linear<B>,
pub weight_hh: Linear<B>,
- // Layer Normalization for better training stability. Don't use BatchNorm because the input distribution is always changing for LSTM.
+ // Layer Normalization for better training stability. Don't use BatchNorm because the input distribution is always changing for LSTM.
pub norm_x: LayerNorm<B>, // Normalize gate pre-activations
pub norm_h: LayerNorm<B>, // Normalize hidden state
pub norm_c: LayerNorm<B>, // Normalize cell state
@@ -55,7 +55,7 @@ impl LstmCellConfig
pub fn init<B: Backend>(&self, device: &B::Device) -> LstmCell<B> {
let initializer = Initializer::XavierNormal { gain: 1.0 };
let init_bias = Tensor::<B, 1>::ones([self.hidden_size], &device);

let mut weight_ih = LinearConfig::new(self.input_size, 4 * self.hidden_size)
.with_initializer(initializer.clone())
.init(device);
@@ -65,7 +65,7 @@
.clone()
.unwrap()
.val()
- .slice_assign([self.hidden_size..2*self.hidden_size], init_bias.clone());
+ .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias.clone());
weight_ih.bias = weight_ih.bias.map(|p| p.map(|_t| bias));

let mut weight_hh = LinearConfig::new(self.hidden_size, 4 * self.hidden_size)
@@ -76,7 +76,7 @@
.clone()
.unwrap()
.val()
- .slice_assign([self.hidden_size..2*self.hidden_size], init_bias);
+ .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias);
weight_hh.bias = weight_hh.bias.map(|p| p.map(|_t| bias));

LstmCell {
@@ -173,18 +173,20 @@ impl StackedLstmConfig
LstmCellConfig::new(self.input_size, self.hidden_size, self.dropout)
.init(device),
);
- } else { // No dropout on last layer
+ } else {
+ // No dropout on last layer
layers.push(
LstmCellConfig::new(self.input_size, self.hidden_size, 0.0).init(device),
);
}
} else {
- if i < self.num_layers -1 {
+ if i < self.num_layers - 1 {
layers.push(
LstmCellConfig::new(self.hidden_size, self.hidden_size, self.dropout)
.init(device),
);
- } else { // No dropout on last layer
+ } else {
+ // No dropout on last layer
layers.push(
LstmCellConfig::new(self.hidden_size, self.hidden_size, 0.0).init(device),
);
@@ -239,7 +241,7 @@ impl<B: Backend> StackedLstm<B> {
}
layer_outputs.push(input_t);
}

// Stack output along sequence dimension
let output = Tensor::stack(layer_outputs, 1);

@@ -299,11 +301,11 @@ impl LstmNetworkConfig
self.dropout,
)
.init(device);
- (Some(lstm), 2*self.hidden_size)
+ (Some(lstm), 2 * self.hidden_size)
} else {
(None, self.hidden_size)
};

let fc = LinearConfig::new(hidden_size, self.output_size).init(device);
let dropout = DropoutConfig::new(self.dropout).init();

@@ -335,7 +337,7 @@ impl<B: Backend> LstmNetwork<B> {
let seq_length = x.dims()[1] as i64;
// Forward direction
let (mut output, _states) = self.stacked_lstm.forward(x.clone(), states);

output = match &self.reverse_lstm {
Some(reverse_lstm) => {
//Process sequence in reverse direction
@@ -354,10 +356,10 @@
// Use final timestep output for prediction
let final_output = self.fc.forward(
output
- .slice([None, Some((seq_length-1, seq_length)), None])
+ .slice([None, Some((seq_length - 1, seq_length)), None])
.squeeze::<2>(1),
);

final_output
}
}
}
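A note on the slice_assign lines reformatted above: the combined bias covers the four gates [i_t, f_t, g_t, o_t] in one vector of length 4 * hidden_size, and the hidden_size..2 * hidden_size range is the forget-gate slice that gets initialized to ones (the usual motivation is to keep the forget gate open early in training). A plain-Rust sketch of that idea using a Vec<f32> instead of a burn tensor (forget_gate_bias is a hypothetical helper, for illustration only):

// Build a combined gate bias [i_t, f_t, g_t, o_t] with only the
// forget-gate quarter set to 1.0, mirroring the slice_assign above.
fn forget_gate_bias(hidden_size: usize) -> Vec<f32> {
    let mut bias = vec![0.0_f32; 4 * hidden_size];
    for b in &mut bias[hidden_size..2 * hidden_size] {
        *b = 1.0; // bias the forget gate toward keeping the cell state
    }
    bias
}

fn main() {
    assert_eq!(
        forget_gate_bias(2),
        vec![0.0_f32, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
    );
}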
examples/modern-lstm/src/training.rs (4 changes: 2 additions & 2 deletions)

@@ -53,7 +53,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
.shuffle(RANDOM_SEED)
.num_workers(config.num_workers)
.build(SequenceDataset::new(NUM_SEQUENCES, SEQ_LENGTH, NOISE_LEVEL));

let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
.shuffle(RANDOM_SEED)
@@ -82,7 +82,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
let output = model.forward(batch.sequences, None);
let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean);
train_loss += loss.clone().into_scalar().elem::<f32>() * batch.targets.dims()[0] as f32;

// Gradients for the current backward pass
let grads = loss.backward();
// Gradients linked to each parameter of the model
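On the train_loss line shown above: each batch contributes its mean MSE weighted by the batch size, so dividing the accumulated sum by the total number of samples yields a per-sample epoch loss (that division is outside this hunk and is assumed here). A small plain-Rust sketch of the bookkeeping with made-up numbers:

fn main() {
    let batch_mean_losses = [0.50_f32, 0.40, 0.30]; // per-batch MSE means (made up)
    let batch_sizes = [32_usize, 32, 16]; // the last batch may be smaller

    let mut train_loss = 0.0_f32;
    let mut num_samples = 0_usize;
    for (&loss, &n) in batch_mean_losses.iter().zip(batch_sizes.iter()) {
        train_loss += loss * n as f32; // same weighting as the train_loss line above
        num_samples += n;
    }
    let epoch_loss = train_loss / num_samples as f32;
    println!("epoch mean loss = {epoch_loss}");
}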
