1
- use std:: { any:: Any , collections:: HashMap , sync:: Arc } ;
2
-
3
- use burn_tensor:: backend:: Backend ;
4
-
5
1
use crate :: {
6
- graph:: { ComputingProperty , NodeID , NodeRef , NodeSteps } ,
2
+ graph:: { ComputingProperty , NodeID , NodeSteps } ,
7
3
tensor:: AutodiffTensor ,
8
4
} ;
5
+ use burn_tensor:: backend:: Backend ;
6
+ use std:: { any:: Any , collections:: HashMap , sync:: Arc } ;
9
7
10
8
use super :: {
11
9
base:: { Checkpointer , NodeTree } ,
@@ -20,31 +18,34 @@ pub enum CheckpointingAction {
20
18
/// The node's already computed output should be saved
21
19
Computed {
22
20
/// The node
23
- node_ref : NodeRef ,
21
+ node_id : NodeID ,
24
22
/// The node's output
25
- state_content : Box < dyn Any + Send + Sync > ,
23
+ state_content : Box < dyn Any + Send > ,
26
24
} ,
27
25
/// The node should recompute itself when asked
28
26
Recompute {
29
27
/// The node
30
- node_ref : NodeRef ,
28
+ node_id : NodeID ,
31
29
/// How the node should recompute itself
32
30
retro_forward : Arc < dyn RetroForward > ,
33
31
} ,
34
32
}
35
33
34
+ // TODO: Remove that when proper client server.
35
+ unsafe impl Send for CheckpointingAction { }
36
+
36
37
impl CheckpointingAction {
37
38
/// Utilitary function to access the id of the node of the checkpointing action
38
39
pub fn id ( & self ) -> NodeID {
39
40
match self {
40
41
CheckpointingAction :: Computed {
41
- node_ref,
42
+ node_id : node_ref,
42
43
state_content : _,
43
- } => node_ref. id . clone ( ) ,
44
+ } => * node_ref,
44
45
CheckpointingAction :: Recompute {
45
- node_ref,
46
+ node_id : node_ref,
46
47
retro_forward : _,
47
- } => node_ref. id . clone ( ) ,
48
+ } => * node_ref,
48
49
}
49
50
}
50
51
}
@@ -83,13 +84,13 @@ impl CheckpointerBuilder {
83
84
match & tensor. node . properties {
84
85
ComputingProperty :: ComputeBound | ComputingProperty :: Ambiguous => {
85
86
action_list. push ( CheckpointingAction :: Computed {
86
- node_ref : tensor. node . clone ( ) ,
87
+ node_id : tensor. node . id ,
87
88
state_content : Box :: new ( tensor. primitive . clone ( ) ) ,
88
89
} )
89
90
}
90
91
ComputingProperty :: MemoryBound { retro_forward } => {
91
92
action_list. push ( CheckpointingAction :: Recompute {
92
- node_ref : tensor. node . clone ( ) ,
93
+ node_id : tensor. node . id ,
93
94
retro_forward : retro_forward. clone ( ) ,
94
95
} )
95
96
}
@@ -105,10 +106,6 @@ impl CheckpointerBuilder {
105
106
}
106
107
}
107
108
108
- pub ( crate ) fn len ( & self ) -> usize {
109
- self . explicit_actions . len ( ) + self . backup_actions . len ( )
110
- }
111
-
112
109
pub ( crate ) fn build ( self , graph : & NodeSteps ) -> Checkpointer {
113
110
let node_tree = self . make_tree ( graph) ;
114
111
let mut backward_states_map = HashMap :: new ( ) ;
@@ -143,11 +140,11 @@ impl CheckpointerBuilder {
143
140
{
144
141
match action {
145
142
CheckpointingAction :: Computed {
146
- node_ref,
143
+ node_id : node_ref,
147
144
state_content : _,
148
- } => stop_nodes. push ( node_ref. id . clone ( ) ) ,
145
+ } => stop_nodes. push ( * node_ref) ,
149
146
CheckpointingAction :: Recompute {
150
- node_ref : _,
147
+ node_id : _,
151
148
retro_forward : _,
152
149
} => { }
153
150
}
@@ -165,10 +162,10 @@ impl CheckpointerBuilder {
165
162
for action in self . explicit_actions . iter ( ) {
166
163
match action {
167
164
CheckpointingAction :: Computed {
168
- node_ref,
165
+ node_id : node_ref,
169
166
state_content : _,
170
167
} => {
171
- let id = node_ref. id . clone ( ) ;
168
+ let id = * node_ref;
172
169
match n_required_map. remove ( & id) {
173
170
Some ( n) => {
174
171
n_required_map. insert ( id, n + 1 ) ;
@@ -179,10 +176,10 @@ impl CheckpointerBuilder {
179
176
} ;
180
177
}
181
178
CheckpointingAction :: Recompute {
182
- node_ref,
179
+ node_id : node_ref,
183
180
retro_forward : _,
184
181
} => {
185
- let id = node_ref. id . clone ( ) ;
182
+ let id = * node_ref;
186
183
Self :: update_n_required_of_parents (
187
184
id,
188
185
& mut n_required_map,
@@ -229,13 +226,13 @@ impl CheckpointerBuilder {
229
226
230
227
match action {
231
228
CheckpointingAction :: Computed {
232
- node_ref : _,
229
+ node_id : _,
233
230
state_content,
234
231
} => {
235
232
self . checkpoint_compute ( backward_states_map, node_id, state_content, n_required)
236
233
}
237
234
CheckpointingAction :: Recompute {
238
- node_ref : _,
235
+ node_id : _,
239
236
retro_forward,
240
237
} => self . checkpoint_lazy (
241
238
backward_states_map,
@@ -251,7 +248,7 @@ impl CheckpointerBuilder {
251
248
fn make_tree ( & self , graph : & NodeSteps ) -> NodeTree {
252
249
let mut tree = HashMap :: default ( ) ;
253
250
for ( id, step) in graph {
254
- tree. insert ( id . clone ( ) , step. node ( ) ) ;
251
+ tree. insert ( * id , step. parents ( ) ) ;
255
252
}
256
253
NodeTree :: new ( tree)
257
254
}
@@ -267,7 +264,7 @@ impl CheckpointerBuilder {
267
264
n_required_map. insert ( id, n + 1 ) ;
268
265
}
269
266
None => {
270
- n_required_map. insert ( id. clone ( ) , 1 ) ;
267
+ n_required_map. insert ( id, 1 ) ;
271
268
if !stop_nodes. contains ( & id) {
272
269
if let Some ( parents) = node_tree. parents ( & id) {
273
270
for p in parents {
@@ -288,7 +285,7 @@ impl CheckpointerBuilder {
288
285
& self ,
289
286
backward_states_map : & mut HashMap < NodeID , State > ,
290
287
node_id : NodeID ,
291
- state_content : Box < dyn Any + Send + Sync > ,
288
+ state_content : Box < dyn Any + Send > ,
292
289
n_required : usize ,
293
290
) {
294
291
backward_states_map. insert (
@@ -308,7 +305,7 @@ impl CheckpointerBuilder {
308
305
retro_forward : Arc < dyn RetroForward > ,
309
306
n_required : usize ,
310
307
) {
311
- retro_forward_map. insert ( node_id. clone ( ) , retro_forward) ;
312
- backward_states_map. insert ( node_id. clone ( ) , State :: Recompute { n_required } ) ;
308
+ retro_forward_map. insert ( node_id, retro_forward) ;
309
+ backward_states_map. insert ( node_id, State :: Recompute { n_required } ) ;
313
310
}
314
311
}
0 commit comments