@@ -9,7 +9,7 @@ use burn_fusion::stream::{Context, ContextOwned};
9
9
/// operation.
10
10
pub enum TuneContext < ' a , R : JitRuntime > {
11
11
Original ( & ' a mut Context < ' a , JitFusionHandle < R > > ) ,
12
- Fork ( ContextOwned < JitFusionHandle < R > > ) ,
12
+ Fork ( Box < ContextOwned < JitFusionHandle < R > > > ) ,
13
13
}
14
14
15
15
/// Fusion input wrapper containing the context and the optimization.
@@ -35,7 +35,7 @@ pub struct TuneInput<R: JitRuntime, O> {
35
35
/// the best kernel to use, which can be async.
36
36
enum UnsafeTuneContext < R : JitRuntime > {
37
37
Original ( * mut Context < ' static , JitFusionHandle < R > > ) ,
38
- Fork ( ContextOwned < JitFusionHandle < R > > ) ,
38
+ Fork ( Box < ContextOwned < JitFusionHandle < R > > > ) ,
39
39
}
40
40
41
41
unsafe impl < R : JitRuntime > Send for UnsafeTuneContext < R > { }
@@ -67,15 +67,19 @@ impl<R: JitRuntime, O> TuneInput<R, O> {
67
67
68
68
impl < R : JitRuntime > UnsafeTuneContext < R > {
69
69
fn new ( context : & mut Context < ' _ , JitFusionHandle < R > > ) -> Self {
70
- Self :: Original ( core:: ptr:: from_mut ( context) as * mut Context < ' static , JitFusionHandle < R > > )
70
+ let ptr = core:: ptr:: from_mut ( context) ;
71
+
72
+ // It is necessary for the lifetime.
73
+ #[ allow( clippy:: unnecessary_cast) ]
74
+ Self :: Original ( ptr as * mut Context < ' static , _ > )
71
75
}
72
76
73
77
fn get ( & self ) -> TuneContext < ' static , R > {
74
78
match self {
75
79
UnsafeTuneContext :: Original ( ptr) => {
76
80
TuneContext :: Original ( unsafe { ptr. as_mut ( ) . unwrap ( ) } )
77
81
}
78
- UnsafeTuneContext :: Fork ( context) => TuneContext :: Fork ( context. fork ( ) ) ,
82
+ UnsafeTuneContext :: Fork ( context) => TuneContext :: Fork ( Box :: new ( context. fork ( ) ) ) ,
79
83
}
80
84
}
81
85
}
@@ -84,7 +88,7 @@ impl<R: JitRuntime, O> Clone for TuneInput<R, O> {
84
88
fn clone ( & self ) -> Self {
85
89
Self {
86
90
context : self . context . clone ( ) ,
87
- optimization : self . optimization . clone ( ) ,
91
+ optimization : self . optimization ,
88
92
}
89
93
}
90
94
}
@@ -99,6 +103,6 @@ impl<R: JitRuntime> Clone for UnsafeTuneContext<R> {
99
103
}
100
104
UnsafeTuneContext :: Fork ( context) => context. fork ( ) ,
101
105
} ;
102
- UnsafeTuneContext :: Fork ( context)
106
+ UnsafeTuneContext :: Fork ( Box :: new ( context) )
103
107
}
104
108
}
0 commit comments