From e1273254a20b3009a81cdafdd6e2d796effd737a Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 3 Feb 2025 10:13:39 -0500 Subject: [PATCH] Cleanup --- .../burn-core/src/nn/rnn/gate_controller.rs | 9 +------- crates/burn-core/src/nn/rnn/lstm.rs | 5 ---- .../burn-core/src/nn/transformer/decoder.rs | 1 - .../burn-fusion/src/stream/execution/base.rs | 6 ----- crates/burn-jit/src/fusion/matmul/builder.rs | 1 - .../burn-jit/src/fusion/on_write/builder.rs | 23 ++++++++----------- crates/burn-jit/src/fusion/on_write/io.rs | 8 +++---- crates/burn-jit/src/fusion/on_write/kernel.rs | 4 +++- 8 files changed, 17 insertions(+), 40 deletions(-) diff --git a/crates/burn-core/src/nn/rnn/gate_controller.rs b/crates/burn-core/src/nn/rnn/gate_controller.rs index 4d5868a936..f20f7fa2a1 100644 --- a/crates/burn-core/src/nn/rnn/gate_controller.rs +++ b/crates/burn-core/src/nn/rnn/gate_controller.rs @@ -57,14 +57,7 @@ impl GateController { /// H = hidden state /// b = bias terms pub fn gate_product(&self, input: Tensor, hidden: Tensor) -> Tensor { - println!("{input}"); - let temp = self.input_transform.forward(input); - let temp2 = self.hidden_transform.forward(hidden); - println!("1: {temp}"); - println!("2: {temp2}"); - let temp = temp + temp2; - // panic!("3: {temp}"); - temp + self.input_transform.forward(input) + self.hidden_transform.forward(hidden) } /// Used to initialize a gate controller with known weight layers, diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index 9968a48a0a..9a7c23399b 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -721,26 +721,21 @@ mod tests { output_with_init_state .to_data() .assert_approx_eq(&expected_output_with_init_state, 3); - println!("1"); output_without_init_state .to_data() .assert_approx_eq(&expected_output_without_init_state, 3); - println!("2"); state_with_init_state .hidden .to_data() .assert_approx_eq(&expected_hn_with_init_state, 3); - println!("3"); state_with_init_state .cell .to_data() .assert_approx_eq(&expected_cn_with_init_state, 3); - println!("4"); state_without_init_state .hidden .to_data() .assert_approx_eq(&expected_hn_without_init_state, 3); - println!("5"); state_without_init_state .cell .to_data() diff --git a/crates/burn-core/src/nn/transformer/decoder.rs b/crates/burn-core/src/nn/transformer/decoder.rs index 4c509bee86..a3d7e49158 100644 --- a/crates/burn-core/src/nn/transformer/decoder.rs +++ b/crates/burn-core/src/nn/transformer/decoder.rs @@ -458,7 +458,6 @@ mod tests { use burn_tensor::Device; use super::*; - use crate::tensor::Distribution; use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; #[test] diff --git a/crates/burn-fusion/src/stream/execution/base.rs b/crates/burn-fusion/src/stream/execution/base.rs index ef31748a37..149b2d095d 100644 --- a/crates/burn-fusion/src/stream/execution/base.rs +++ b/crates/burn-fusion/src/stream/execution/base.rs @@ -49,12 +49,6 @@ impl OperationQueue { let mut context = self.converter.context(handles); optimization.execute(&mut context); - log::info!("====== MATMUL ======"); - for op in &self.global[0..num_drained] { - log::info!("{op:?}") - } - log::info!("====== END ======"); - self.drain_queue(num_drained, handles); self.operations.drain(0..num_drained); } diff --git a/crates/burn-jit/src/fusion/matmul/builder.rs b/crates/burn-jit/src/fusion/matmul/builder.rs index fb0ebfb48d..11e38626e3 100644 --- a/crates/burn-jit/src/fusion/matmul/builder.rs +++ b/crates/burn-jit/src/fusion/matmul/builder.rs @@ -47,7 +47,6 @@ impl OptimizationBuilder> for MatmulBuilder } if self.matmul.is_none() { - log::info!("New matmul fusion"); if let OperationDescription::Float(_, FloatOperationDescription::Matmul(op)) = operation { let lhs = self.builder.input_unhandled(&op.lhs); diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 717e80ff2d..bbfca0f3d2 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -71,7 +71,6 @@ impl TryFuseBuilder { impl OptimizationBuilder for FuseOnWriteBuilder { fn register(&mut self, op: &OperationDescription) { - log::info!("Register {op:?}"); if let OptimizationStatus::Closed = self.status { return; } @@ -107,14 +106,12 @@ impl OptimizationBuilder for FuseOnWriteBuilder { return; } } - // TODO - // - // OperationDescription::BaseBool(ops) => { - // if !self.register_base(ops) { - // self.status = OptimizationStatus::Closed; - // return; - // } - // } + OperationDescription::BaseBool(ops) => { + if !self.register_base(ops) { + self.status = OptimizationStatus::Closed; + return; + } + } _ => { self.status = OptimizationStatus::Closed; return; @@ -542,12 +539,10 @@ impl FuseOnWriteBuilder { return false; } - if updated != self.current_output_shape { - if updated != out.shape { - return false; - } - self.current_output_shape.clone_from_slice(&out.shape); + if updated != out.shape { + return false; } + self.current_output_shape.clone_from_slice(&out.shape); true } diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index 24f2b94aa3..cacb2decc5 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -57,7 +57,7 @@ pub fn read( config, comptime![Some(shape)], ), - _ => comptime![panic![]], + _ => comptime![panic!("Only input can be reshaped")], }, } } @@ -90,7 +90,7 @@ pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: Arg) -> u32 { let offset = comptime![inputs.s_u32.len() - pos - 1]; *inputs.s_u32.index(offset) } - _ => comptime![panic!["Not a scalar shape"]], + _ => comptime![panic!("Not a scalar shape")], } } @@ -773,8 +773,8 @@ fn index_offset_with_layout( ) -> u32 { match comptime![shape.clone()] { Some(shape) => { - let index_reshaped = reshaped_index(inputs, layout, index, rank, shape); - reshaped_index_to_original_index(tensor, index_reshaped, rank) + let index = reshaped_index(inputs, layout, index, rank, shape); + reshaped_index_to_original_index(tensor, index, rank) } None => { let offset_ref = index * layout.line_size(); diff --git a/crates/burn-jit/src/fusion/on_write/kernel.rs b/crates/burn-jit/src/fusion/on_write/kernel.rs index fa5abbfba9..cf79a0fb6b 100644 --- a/crates/burn-jit/src/fusion/on_write/kernel.rs +++ b/crates/burn-jit/src/fusion/on_write/kernel.rs @@ -404,7 +404,9 @@ pub fn fuse_on_write( ElemwisePrecision::U8 => { equal::(inputs, outputs, &mut locals, write_pos, op, config) } - _ => comptime![panic!("Unsupported precision {op:?}")], + ElemwisePrecision::Bool => { + equal::(inputs, outputs, &mut locals, write_pos, op, config) + } }, ElemwiseOp::Greater(op) => match op.lhs.precision() { ElemwisePrecision::F32 => {