diff --git a/examples/modern-lstm/src/model.rs b/examples/modern-lstm/src/model.rs
index ebf6b3bc08..268de59a0b 100644
--- a/examples/modern-lstm/src/model.rs
+++ b/examples/modern-lstm/src/model.rs
@@ -52,6 +52,7 @@ impl LstmCellConfig {
     // Initialize parameters using best practices:
     // 1. Orthogonal initialization for better gradient flow (here we use Xavier because of the lack of Orthogonal in burn)
     // 2. Initialize forget gate bias to 1.0 to prevent forgetting at start of training
+    #[allow(clippy::single_range_in_vec_init)]
     pub fn init<B: Backend>(&self, device: &B::Device) -> LstmCell<B> {
         let initializer = Initializer::XavierNormal { gain: 1.0 };
         let init_bias = Tensor::<B, 1>::ones([self.hidden_size], device);
@@ -352,12 +353,10 @@ impl<B: Backend> LstmNetwork<B> {
         // Apply dropout before final layer
         output = self.dropout.forward(output);
         // Use final timestep output for prediction
-        let final_output = self.fc.forward(
+        self.fc.forward(
             output
                 .slice([None, Some((seq_length - 1, seq_length)), None])
                 .squeeze::<2>(1),
-        );
-
-        final_output
+        )
     }
 }
diff --git a/examples/modern-lstm/src/training.rs b/examples/modern-lstm/src/training.rs
index a74babd509..9f6af81328 100644
--- a/examples/modern-lstm/src/training.rs
+++ b/examples/modern-lstm/src/training.rs
@@ -82,7 +82,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
     let mut valid_loss = 0.0;

     // Implement our training loop
-    for (_iteration, batch) in dataloader_train.iter().enumerate() {
+    for batch in dataloader_train.iter() {
         let output = model.forward(batch.sequences, None);
         let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean);
         train_loss += loss.clone().into_scalar().elem::<f32>() * batch.targets.dims()[0] as f32;
@@ -103,7 +103,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
     let valid_model = model.valid();

     // Implement our validation loop
-    for (_iteration, batch) in dataloader_valid.iter().enumerate() {
+    for batch in dataloader_valid.iter() {
         let output = valid_model.forward(batch.sequences, None);
         let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean);
         valid_loss += loss.clone().into_scalar().elem::<f32>() * batch.targets.dims()[0] as f32;
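
Note (not part of the diff): the three changes line up with clippy lints — `single_range_in_vec_init` (explicitly allowed on `init`), `let_and_return` (the `final_output` binding removed from `forward`), and `unused_enumerate_index` (the `_iteration` loops dropped in training.rs). Below is a minimal, standalone sketch of what each lint flags; the tensor-slicing rationale for the `#[allow]` is an assumption about `init`'s body, which this hunk does not show.

```rust
fn main() {
    // clippy::single_range_in_vec_init: an array literal containing a single
    // range is easy to misread as an N-element initializer, so clippy flags
    // it even when a range argument is exactly what a slicing API expects;
    // presumably why `init` carries the `#[allow]`.
    let ranges = [0..4];
    println!("{ranges:?}");

    println!("{}", doubled(21));

    // clippy::unused_enumerate_index: `.enumerate()` whose index binding is
    // ignored (like `_iteration`) adds nothing; iterate over items directly.
    for item in ["a", "b"] {
        println!("{item}");
    }
}

// clippy::let_and_return: binding a value only to return it on the next line;
// the tail expression alone is the idiomatic form, as `forward` now does.
fn doubled(x: i32) -> i32 {
    x * 2 // previously the flagged pattern: `let y = x * 2; y`
}
```

Allowing the lint at the function level, rather than rewriting the slice call it fires on, appears to be a deliberate choice to keep the range-based slicing argument as-is.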