From 28721693630eea194b2db9712de1f113090df3de Mon Sep 17 00:00:00 2001
From: wangjiawen2013
Date: Tue, 28 Jan 2025 15:22:19 +0800
Subject: [PATCH] formatting

---
 examples/modern-lstm/src/dataset.rs  |  2 +-
 examples/modern-lstm/src/main.rs     |  2 +-
 examples/modern-lstm/src/model.rs    | 28 +++++++++++++++-------------
 examples/modern-lstm/src/training.rs |  4 ++--
 4 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/examples/modern-lstm/src/dataset.rs b/examples/modern-lstm/src/dataset.rs
index d9c61b7e5d..5820e3a73b 100644
--- a/examples/modern-lstm/src/dataset.rs
+++ b/examples/modern-lstm/src/dataset.rs
@@ -32,7 +32,7 @@ impl SequenceDatasetItem {
             // Next number is sum of previous two plus noise
             let normal = Normal::new(0.0, noise_level).unwrap();
             let next_val =
-                seq[seq.len()-2] + seq[seq.len()-1] + normal.sample(&mut rand::thread_rng());
+                seq[seq.len() - 2] + seq[seq.len() - 1] + normal.sample(&mut rand::thread_rng());
             seq.push(next_val);
         }
 
diff --git a/examples/modern-lstm/src/main.rs b/examples/modern-lstm/src/main.rs
index 9f8e9c048c..f37afbb321 100644
--- a/examples/modern-lstm/src/main.rs
+++ b/examples/modern-lstm/src/main.rs
@@ -136,7 +136,7 @@ mod tch_gpu {
                     lr,
                 );
                 train::<Autodiff<LibTorch>>(&artifact_dir, config, device);
-            },
+            }
             Commands::Infer { artifact_dir } => {
                 infer::<LibTorch>(&artifact_dir, device);
             }
diff --git a/examples/modern-lstm/src/model.rs b/examples/modern-lstm/src/model.rs
index 7617c01a12..3afbf797da 100644
--- a/examples/modern-lstm/src/model.rs
+++ b/examples/modern-lstm/src/model.rs
@@ -30,7 +30,7 @@ pub struct LstmCell<B: Backend> {
     // weight_hh layer uses combined weights for [i_t, f_t, g_t, o_t] for hidden state h_{t-1}
     pub weight_ih: Linear<B>,
     pub weight_hh: Linear<B>,
-    // Layer Normalization for better training stability. Don't use BatchNorm because the input distribution is always changing for LSTM. 
+    // Layer Normalization for better training stability. Don't use BatchNorm because the input distribution is always changing for LSTM.
     pub norm_x: LayerNorm<B>, // Normalize gate pre-activations
     pub norm_h: LayerNorm<B>, // Normalize hidden state
     pub norm_c: LayerNorm<B>, // Normalize cell state
@@ -55,7 +55,7 @@ impl LstmCellConfig {
     pub fn init<B: Backend>(&self, device: &B::Device) -> LstmCell<B> {
         let initializer = Initializer::XavierNormal { gain: 1.0 };
         let init_bias = Tensor::<B, 1>::ones([self.hidden_size], &device);
-        
+
         let mut weight_ih = LinearConfig::new(self.input_size, 4 * self.hidden_size)
             .with_initializer(initializer.clone())
             .init(device);
@@ -65,7 +65,7 @@ impl LstmCellConfig {
             .clone()
             .unwrap()
             .val()
-            .slice_assign([self.hidden_size..2*self.hidden_size], init_bias.clone());
+            .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias.clone());
         weight_ih.bias = weight_ih.bias.map(|p| p.map(|_t| bias));
 
         let mut weight_hh = LinearConfig::new(self.hidden_size, 4 * self.hidden_size)
@@ -76,7 +76,7 @@ impl LstmCellConfig {
             .clone()
             .unwrap()
             .val()
-            .slice_assign([self.hidden_size..2*self.hidden_size], init_bias);
+            .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias);
         weight_hh.bias = weight_hh.bias.map(|p| p.map(|_t| bias));
 
         LstmCell {
@@ -173,18 +173,20 @@ impl StackedLstmConfig {
                         LstmCellConfig::new(self.input_size, self.hidden_size, self.dropout)
                             .init(device),
                     );
-                } else { // No dropout on last layer
+                } else {
+                    // No dropout on last layer
                     layers.push(
                         LstmCellConfig::new(self.input_size, self.hidden_size, 0.0).init(device),
                     );
                 }
             } else {
-                if i < self.num_layers -1 {
+                if i < self.num_layers - 1 {
                     layers.push(
                         LstmCellConfig::new(self.hidden_size, self.hidden_size, self.dropout)
                             .init(device),
                     );
-                } else { // No dropout on last layer
+                } else {
+                    // No dropout on last layer
                     layers.push(
                         LstmCellConfig::new(self.hidden_size, self.hidden_size, 0.0).init(device),
                     );
@@ -239,7 +241,7 @@ impl<B: Backend> StackedLstm<B> {
             }
             layer_outputs.push(input_t);
         }
-        
+
         // Stack output along sequence dimension
         let output = Tensor::stack(layer_outputs, 1);
 
@@ -299,11 +301,11 @@ impl LstmNetworkConfig {
                 self.dropout,
             )
             .init(device);
-            (Some(lstm), 2*self.hidden_size)
+            (Some(lstm), 2 * self.hidden_size)
         } else {
             (None, self.hidden_size)
         };
-        
+
         let fc = LinearConfig::new(hidden_size, self.output_size).init(device);
         let dropout = DropoutConfig::new(self.dropout).init();
 
@@ -335,7 +337,7 @@ impl<B: Backend> LstmNetwork<B> {
         let seq_length = x.dims()[1] as i64;
         // Forward direction
         let (mut output, _states) = self.stacked_lstm.forward(x.clone(), states);
-        
+
         output = match &self.reverse_lstm {
             Some(reverse_lstm) => {
                 //Process sequence in reverse direction
@@ -354,10 +356,10 @@ impl<B: Backend> LstmNetwork<B> {
         // Use final timestep output for prediction
         let final_output = self.fc.forward(
             output
-                .slice([None, Some((seq_length-1, seq_length)), None])
+                .slice([None, Some((seq_length - 1, seq_length)), None])
                 .squeeze::<2>(1),
         );
 
         final_output
     }
-}
\ No newline at end of file
+}
diff --git a/examples/modern-lstm/src/training.rs b/examples/modern-lstm/src/training.rs
index 199b48d1a6..beb2e875b6 100644
--- a/examples/modern-lstm/src/training.rs
+++ b/examples/modern-lstm/src/training.rs
@@ -53,7 +53,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
         .shuffle(RANDOM_SEED)
         .num_workers(config.num_workers)
         .build(SequenceDataset::new(NUM_SEQUENCES, SEQ_LENGTH, NOISE_LEVEL));
-    
+
     let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(RANDOM_SEED)
@@ -82,7 +82,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
             let output = model.forward(batch.sequences, None);
             let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean);
             train_loss += loss.clone().into_scalar().elem::<f32>() * batch.targets.dims()[0] as f32;
-            
+
             // Gradients for the current backward pass
             let grads = loss.backward();
             // Gradients linked to each parameter of the model