Skip to content

Commit

Permalink
Fix clippy
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Feb 3, 2025
1 parent 9bdc8e9 commit af75564
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
7 changes: 3 additions & 4 deletions examples/modern-lstm/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ impl LstmCellConfig {
// Initialize parameters using best practices:
// 1. Orthogonal initialization for better gradient flow (here we use Xavier because of the lack of Orthogonal in burn)
// 2. Initialize forget gate bias to 1.0 to prevent forgetting at start of training
#[allow(clippy::single_range_in_vec_init)]
pub fn init<B: Backend>(&self, device: &B::Device) -> LstmCell<B> {
let initializer = Initializer::XavierNormal { gain: 1.0 };
let init_bias = Tensor::<B, 1>::ones([self.hidden_size], device);
Expand Down Expand Up @@ -352,12 +353,10 @@ impl<B: Backend> LstmNetwork<B> {
// Apply dropout before final layer
output = self.dropout.forward(output);
// Use final timestep output for prediction
let final_output = self.fc.forward(
self.fc.forward(
output
.slice([None, Some((seq_length - 1, seq_length)), None])
.squeeze::<2>(1),
);

final_output
)
}
}
4 changes: 2 additions & 2 deletions examples/modern-lstm/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
let mut valid_loss = 0.0;

// Implement our training loop
for (_iteration, batch) in dataloader_train.iter().enumerate() {
for batch in dataloader_train.iter() {
let output = model.forward(batch.sequences, None);
let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean);
train_loss += loss.clone().into_scalar().elem::<f32>() * batch.targets.dims()[0] as f32;
Expand All @@ -103,7 +103,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
let valid_model = model.valid();

// Implement our validation loop
for (_iteration, batch) in dataloader_valid.iter().enumerate() {
for batch in dataloader_valid.iter() {
let output = valid_model.forward(batch.sequences, None);
let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean);
valid_loss += loss.clone().into_scalar().elem::<f32>() * batch.targets.dims()[0] as f32;
Expand Down

0 comments on commit af75564

Please sign in to comment.