Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Feb 3, 2025
1 parent f0c82e2 commit e127325
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 40 deletions.
9 changes: 1 addition & 8 deletions crates/burn-core/src/nn/rnn/gate_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,7 @@ impl<B: Backend> GateController<B> {
/// H = hidden state
/// b = bias terms
pub fn gate_product(&self, input: Tensor<B, 2>, hidden: Tensor<B, 2>) -> Tensor<B, 2> {
    // Sum of the two affine transforms: W_x · X + W_h · H (+ biases, applied
    // inside each Linear's forward). Debug println!/temp scaffolding removed.
    self.input_transform.forward(input) + self.hidden_transform.forward(hidden)
}

/// Used to initialize a gate controller with known weight layers,
Expand Down
5 changes: 0 additions & 5 deletions crates/burn-core/src/nn/rnn/lstm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -721,26 +721,21 @@ mod tests {
output_with_init_state
.to_data()
.assert_approx_eq(&expected_output_with_init_state, 3);
println!("1");
output_without_init_state
.to_data()
.assert_approx_eq(&expected_output_without_init_state, 3);
println!("2");
state_with_init_state
.hidden
.to_data()
.assert_approx_eq(&expected_hn_with_init_state, 3);
println!("3");
state_with_init_state
.cell
.to_data()
.assert_approx_eq(&expected_cn_with_init_state, 3);
println!("4");
state_without_init_state
.hidden
.to_data()
.assert_approx_eq(&expected_hn_without_init_state, 3);
println!("5");
state_without_init_state
.cell
.to_data()
Expand Down
1 change: 0 additions & 1 deletion crates/burn-core/src/nn/transformer/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,6 @@ mod tests {
use burn_tensor::Device;

use super::*;
use crate::tensor::Distribution;
use crate::{nn::attention::generate_autoregressive_mask, TestBackend};

#[test]
Expand Down
6 changes: 0 additions & 6 deletions crates/burn-fusion/src/stream/execution/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,6 @@ impl<R: FusionRuntime> OperationQueue<R> {
let mut context = self.converter.context(handles);
optimization.execute(&mut context);

log::info!("====== MATMUL ======");
for op in &self.global[0..num_drained] {
log::info!("{op:?}")
}
log::info!("====== END ======");

self.drain_queue(num_drained, handles);
self.operations.drain(0..num_drained);
}
Expand Down
1 change: 0 additions & 1 deletion crates/burn-jit/src/fusion/matmul/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ impl<R: JitRuntime> OptimizationBuilder<JitOptimization<R>> for MatmulBuilder<R>
}

if self.matmul.is_none() {
log::info!("New matmul fusion");
if let OperationDescription::Float(_, FloatOperationDescription::Matmul(op)) = operation
{
let lhs = self.builder.input_unhandled(&op.lhs);
Expand Down
23 changes: 9 additions & 14 deletions crates/burn-jit/src/fusion/on_write/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ impl TryFuseBuilder {

impl OptimizationBuilder<FuseOnWriteTrace> for FuseOnWriteBuilder {
fn register(&mut self, op: &OperationDescription) {
log::info!("Register {op:?}");
if let OptimizationStatus::Closed = self.status {
return;
}
Expand Down Expand Up @@ -107,14 +106,12 @@ impl OptimizationBuilder<FuseOnWriteTrace> for FuseOnWriteBuilder {
return;
}
}
// TODO
//
// OperationDescription::BaseBool(ops) => {
// if !self.register_base(ops) {
// self.status = OptimizationStatus::Closed;
// return;
// }
// }
OperationDescription::BaseBool(ops) => {
if !self.register_base(ops) {
self.status = OptimizationStatus::Closed;
return;
}
}
_ => {
self.status = OptimizationStatus::Closed;
return;
Expand Down Expand Up @@ -542,12 +539,10 @@ impl FuseOnWriteBuilder {
return false;
}

if updated != self.current_output_shape {
if updated != out.shape {
return false;
}
self.current_output_shape.clone_from_slice(&out.shape);
if updated != out.shape {
return false;
}
self.current_output_shape.clone_from_slice(&out.shape);

true
}
Expand Down
8 changes: 4 additions & 4 deletions crates/burn-jit/src/fusion/on_write/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub fn read<C: CubePrimitive>(
config,
comptime![Some(shape)],
),
_ => comptime![panic![]],
_ => comptime![panic!("Only input can be reshaped")],
},
}
}
Expand Down Expand Up @@ -90,7 +90,7 @@ pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: Arg) -> u32 {
let offset = comptime![inputs.s_u32.len() - pos - 1];
*inputs.s_u32.index(offset)
}
_ => comptime![panic!["Not a scalar shape"]],
_ => comptime![panic!("Not a scalar shape")],
}
}

Expand Down Expand Up @@ -773,8 +773,8 @@ fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
) -> u32 {
match comptime![shape.clone()] {
Some(shape) => {
let index_reshaped = reshaped_index(inputs, layout, index, rank, shape);
reshaped_index_to_original_index(tensor, index_reshaped, rank)
let index = reshaped_index(inputs, layout, index, rank, shape);
reshaped_index_to_original_index(tensor, index, rank)
}
None => {
let offset_ref = index * layout.line_size();
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-jit/src/fusion/on_write/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,9 @@ pub fn fuse_on_write<E: CubePrimitive>(
ElemwisePrecision::U8 => {
equal::<u8>(inputs, outputs, &mut locals, write_pos, op, config)
}
_ => comptime![panic!("Unsupported precision {op:?}")],
ElemwisePrecision::Bool => {
equal::<bool>(inputs, outputs, &mut locals, write_pos, op, config)
}
},
ElemwiseOp::Greater(op) => match op.lhs.precision() {
ElemwisePrecision::F32 => {
Expand Down

0 comments on commit e127325

Please sign in to comment.