Skip to content

Commit 38fea54

Browse files
Fix clippy
1 parent 9bc529f commit 38fea54

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

crates/burn-jit/src/fusion/matmul/tune.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ pub fn fused_matmul_autotune<R: JitRuntime, BT: BoolElement>(
2929
) {
3030
static TUNER: LocalTuner<FusedMatmulAutotuneKey, JitTuneId> = local_tuner!();
3131

32-
let tunables = TunableSet::new(create_key::<R, BT>, input_gen::<R, BT>)
32+
let tunables = TunableSet::new(create_key::<R>, input_gen::<R>)
3333
.with_tunable(tune_standard_fused::<R, BT>)
3434
.with_tunable(tune_specialized_fused::<R, BT>)
3535
.with_tunable(tune_pipelined_fused::<R, BT>)
@@ -43,7 +43,7 @@ pub fn fused_matmul_autotune<R: JitRuntime, BT: BoolElement>(
4343
);
4444
}
4545

46-
pub(crate) fn create_key<R: JitRuntime, BT: BoolElement>(
46+
pub(crate) fn create_key<R: JitRuntime>(
4747
input: &TuneInput<R, MatmulOptimization<R>>,
4848
) -> FusedMatmulAutotuneKey {
4949
let opt = input.optimization();
@@ -64,7 +64,7 @@ pub(crate) fn create_key<R: JitRuntime, BT: BoolElement>(
6464
FusedMatmulAutotuneKey::new(key, opt.len)
6565
}
6666

67-
fn input_gen<R: JitRuntime, BT: BoolElement>(
67+
fn input_gen<R: JitRuntime>(
6868
_key: &FusedMatmulAutotuneKey,
6969
input: &TuneInput<R, MatmulOptimization<R>>,
7070
) -> TuneInput<R, MatmulOptimization<R>> {

crates/burn-jit/src/fusion/tune.rs

+10-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use burn_fusion::stream::{Context, ContextOwned};
99
/// operation.
1010
pub enum TuneContext<'a, R: JitRuntime> {
1111
Original(&'a mut Context<'a, JitFusionHandle<R>>),
12-
Fork(ContextOwned<JitFusionHandle<R>>),
12+
Fork(Box<ContextOwned<JitFusionHandle<R>>>),
1313
}
1414

1515
/// Fusion input wrapper containing the context and the optimization.
@@ -35,7 +35,7 @@ pub struct TuneInput<R: JitRuntime, O> {
3535
/// the best kernel to use, which can be async.
3636
enum UnsafeTuneContext<R: JitRuntime> {
3737
Original(*mut Context<'static, JitFusionHandle<R>>),
38-
Fork(ContextOwned<JitFusionHandle<R>>),
38+
Fork(Box<ContextOwned<JitFusionHandle<R>>>),
3939
}
4040

4141
unsafe impl<R: JitRuntime> Send for UnsafeTuneContext<R> {}
@@ -67,15 +67,19 @@ impl<R: JitRuntime, O> TuneInput<R, O> {
6767

6868
impl<R: JitRuntime> UnsafeTuneContext<R> {
6969
fn new(context: &mut Context<'_, JitFusionHandle<R>>) -> Self {
70-
Self::Original(core::ptr::from_mut(context) as *mut Context<'static, JitFusionHandle<R>>)
70+
let ptr = core::ptr::from_mut(context);
71+
72+
// It is necessary for the lifetime.
73+
#[allow(clippy::unnecessary_cast)]
74+
Self::Original(ptr as *mut Context<'static, _>)
7175
}
7276

7377
fn get(&self) -> TuneContext<'static, R> {
7478
match self {
7579
UnsafeTuneContext::Original(ptr) => {
7680
TuneContext::Original(unsafe { ptr.as_mut().unwrap() })
7781
}
78-
UnsafeTuneContext::Fork(context) => TuneContext::Fork(context.fork()),
82+
UnsafeTuneContext::Fork(context) => TuneContext::Fork(Box::new(context.fork())),
7983
}
8084
}
8185
}
@@ -84,7 +88,7 @@ impl<R: JitRuntime, O> Clone for TuneInput<R, O> {
8488
fn clone(&self) -> Self {
8589
Self {
8690
context: self.context.clone(),
87-
optimization: self.optimization.clone(),
91+
optimization: self.optimization,
8892
}
8993
}
9094
}
@@ -99,6 +103,6 @@ impl<R: JitRuntime> Clone for UnsafeTuneContext<R> {
99103
}
100104
UnsafeTuneContext::Fork(context) => context.fork(),
101105
};
102-
UnsafeTuneContext::Fork(context)
106+
UnsafeTuneContext::Fork(Box::new(context))
103107
}
104108
}

0 commit comments

Comments
 (0)