Skip to content

Commit 1239d9b

Browse files
[Breaking] Make Tensor, Module, Optimizer !Sync + Refactor Autodiff (#1575)
1 parent ce898ff commit 1239d9b

File tree

51 files changed

+1049
-678
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1049
-678
lines changed

crates/burn-autodiff/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-autodiff"
1111
version.workspace = true
1212

1313
[features]
14-
default = []
14+
default = ["std"]
1515
export_tests = ["burn-tensor-testgen"]
16+
std = []
1617

1718
[dependencies]
1819
burn-common = { path = "../burn-common", version = "0.13.0" }

crates/burn-autodiff/src/backend.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::{
22
checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
33
grads::Gradients,
4-
graph::backward::backward,
4+
runtime::AutodiffClient,
55
tensor::AutodiffTensor,
66
AutodiffBridge,
77
};
@@ -53,7 +53,9 @@ impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
5353
type Gradients = Gradients;
5454

5555
fn backward<const D: usize>(tensor: AutodiffTensor<B, D>) -> Gradients {
56-
backward(tensor)
56+
let client = tensor.node.client.clone();
57+
58+
AutodiffClient::backward(&client, tensor)
5759
}
5860

5961
fn grad<const D: usize>(
@@ -83,7 +85,7 @@ impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
8385
grad: B::FloatTensorPrimitive<D>,
8486
) {
8587
grads.remove(tensor);
86-
grads.register::<B, D>(tensor.node.clone(), grad);
88+
grads.register::<B, D>(tensor.node.id, grad);
8789
}
8890

8991
fn int_inner<const D: usize>(

crates/burn-autodiff/src/bridge.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ where
3939
_bridge: PhantomData<Bridge>,
4040
}
4141

42-
#[derive(new, Debug)]
42+
#[derive(new, Debug, Clone)]
4343
struct RetroIntoTarget<B: Backend, Bridge: BackendBridge<B>, const D: usize> {
4444
tensor_id: NodeID,
4545
_backend: PhantomData<B>,
@@ -84,9 +84,9 @@ where
8484
_backend: PhantomData,
8585
_bridge: PhantomData,
8686
}
87-
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
87+
.prepare::<C>([tensor.node.clone()])
8888
.memory_bound()
89-
.retro_forward(RetroIntoTarget::<B, Bridge, D>::new(tensor.node.id.clone()))
89+
.retro_forward(RetroIntoTarget::<B, Bridge, D>::new(tensor.node.id))
9090
.parents([&tensor])
9191
.stateless(Bridge::into_target(tensor.primitive, None))
9292
}
@@ -101,7 +101,7 @@ where
101101
_bridge: PhantomData<Bridge>,
102102
}
103103

104-
#[derive(new, Debug)]
104+
#[derive(new, Debug, Clone)]
105105
struct RetroFromTarget<B: Backend, Bridge: BackendBridge<B>, const D: usize> {
106106
tensor_id: NodeID,
107107
_backend: PhantomData<B>,
@@ -146,9 +146,9 @@ where
146146
_backend: PhantomData,
147147
_bridge: PhantomData,
148148
}
149-
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
149+
.prepare::<C>([tensor.node.clone()])
150150
.memory_bound()
151-
.retro_forward(RetroFromTarget::<B, Bridge, D>::new(tensor.node.id.clone()))
151+
.retro_forward(RetroFromTarget::<B, Bridge, D>::new(tensor.node.id))
152152
.parents([&tensor])
153153
.stateless(Bridge::from_target(tensor.primitive, None))
154154
}

crates/burn-autodiff/src/checkpoint/base.rs

+9-13
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
1-
use std::collections::HashMap;
2-
3-
use crate::graph::{NodeID, NodeRef};
4-
51
use super::{
62
retro_forward::RetroForwards,
73
state::{BackwardStates, State},
84
};
5+
use crate::graph::NodeID;
6+
use std::collections::HashMap;
97

108
#[derive(new, Debug)]
119
/// Links a [NodeID] to the [NodeID]s of its parents in the autodiff graph
1210
pub(crate) struct NodeTree {
13-
map: HashMap<NodeID, NodeRef>,
11+
map: HashMap<NodeID, Vec<NodeID>>,
1412
}
1513

1614
impl NodeTree {
1715
/// Gives the parents of the node in the autodiff graph
1816
pub(crate) fn parents(&self, node_id: &NodeID) -> Option<Vec<NodeID>> {
19-
self.map.get(node_id).map(|node| node.parents.clone())
17+
self.map.get(node_id).cloned()
2018
}
2119
}
2220

@@ -33,14 +31,12 @@ impl Checkpointer {
3331
/// or give their pre-computed tensors.
3432
pub fn retrieve_node_output<T>(&mut self, node_id: NodeID) -> T
3533
where
36-
T: Clone + Send + Sync + 'static,
34+
T: Clone + Send + 'static,
3735
{
38-
self.topological_sort(node_id.clone())
39-
.into_iter()
40-
.for_each(|node| {
41-
self.retro_forwards
42-
.execute_retro_forward(node, &mut self.backward_states)
43-
});
36+
self.topological_sort(node_id).into_iter().for_each(|node| {
37+
self.retro_forwards
38+
.execute_retro_forward(node, &mut self.backward_states)
39+
});
4440

4541
self.backward_states.get_state::<T>(&node_id)
4642
}

crates/burn-autodiff/src/checkpoint/builder.rs

+29-32
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
use std::{any::Any, collections::HashMap, sync::Arc};
2-
3-
use burn_tensor::backend::Backend;
4-
51
use crate::{
6-
graph::{ComputingProperty, NodeID, NodeRef, NodeSteps},
2+
graph::{ComputingProperty, NodeID, NodeSteps},
73
tensor::AutodiffTensor,
84
};
5+
use burn_tensor::backend::Backend;
6+
use std::{any::Any, collections::HashMap, sync::Arc};
97

108
use super::{
119
base::{Checkpointer, NodeTree},
@@ -20,31 +18,34 @@ pub enum CheckpointingAction {
2018
/// The node's already computed output should be saved
2119
Computed {
2220
/// The node
23-
node_ref: NodeRef,
21+
node_id: NodeID,
2422
/// The node's output
25-
state_content: Box<dyn Any + Send + Sync>,
23+
state_content: Box<dyn Any + Send>,
2624
},
2725
/// The node should recompute itself when asked
2826
Recompute {
2927
/// The node
30-
node_ref: NodeRef,
28+
node_id: NodeID,
3129
/// How the node should recompute itself
3230
retro_forward: Arc<dyn RetroForward>,
3331
},
3432
}
3533

34+
// TODO: Remove this once a proper client/server implementation is in place.
35+
unsafe impl Send for CheckpointingAction {}
36+
3637
impl CheckpointingAction {
3738
/// Utility function to access the id of the node of the checkpointing action
3839
pub fn id(&self) -> NodeID {
3940
match self {
4041
CheckpointingAction::Computed {
41-
node_ref,
42+
node_id: node_ref,
4243
state_content: _,
43-
} => node_ref.id.clone(),
44+
} => *node_ref,
4445
CheckpointingAction::Recompute {
45-
node_ref,
46+
node_id: node_ref,
4647
retro_forward: _,
47-
} => node_ref.id.clone(),
48+
} => *node_ref,
4849
}
4950
}
5051
}
@@ -83,13 +84,13 @@ impl CheckpointerBuilder {
8384
match &tensor.node.properties {
8485
ComputingProperty::ComputeBound | ComputingProperty::Ambiguous => {
8586
action_list.push(CheckpointingAction::Computed {
86-
node_ref: tensor.node.clone(),
87+
node_id: tensor.node.id,
8788
state_content: Box::new(tensor.primitive.clone()),
8889
})
8990
}
9091
ComputingProperty::MemoryBound { retro_forward } => {
9192
action_list.push(CheckpointingAction::Recompute {
92-
node_ref: tensor.node.clone(),
93+
node_id: tensor.node.id,
9394
retro_forward: retro_forward.clone(),
9495
})
9596
}
@@ -105,10 +106,6 @@ impl CheckpointerBuilder {
105106
}
106107
}
107108

108-
pub(crate) fn len(&self) -> usize {
109-
self.explicit_actions.len() + self.backup_actions.len()
110-
}
111-
112109
pub(crate) fn build(self, graph: &NodeSteps) -> Checkpointer {
113110
let node_tree = self.make_tree(graph);
114111
let mut backward_states_map = HashMap::new();
@@ -143,11 +140,11 @@ impl CheckpointerBuilder {
143140
{
144141
match action {
145142
CheckpointingAction::Computed {
146-
node_ref,
143+
node_id: node_ref,
147144
state_content: _,
148-
} => stop_nodes.push(node_ref.id.clone()),
145+
} => stop_nodes.push(*node_ref),
149146
CheckpointingAction::Recompute {
150-
node_ref: _,
147+
node_id: _,
151148
retro_forward: _,
152149
} => {}
153150
}
@@ -165,10 +162,10 @@ impl CheckpointerBuilder {
165162
for action in self.explicit_actions.iter() {
166163
match action {
167164
CheckpointingAction::Computed {
168-
node_ref,
165+
node_id: node_ref,
169166
state_content: _,
170167
} => {
171-
let id = node_ref.id.clone();
168+
let id = *node_ref;
172169
match n_required_map.remove(&id) {
173170
Some(n) => {
174171
n_required_map.insert(id, n + 1);
@@ -179,10 +176,10 @@ impl CheckpointerBuilder {
179176
};
180177
}
181178
CheckpointingAction::Recompute {
182-
node_ref,
179+
node_id: node_ref,
183180
retro_forward: _,
184181
} => {
185-
let id = node_ref.id.clone();
182+
let id = *node_ref;
186183
Self::update_n_required_of_parents(
187184
id,
188185
&mut n_required_map,
@@ -229,13 +226,13 @@ impl CheckpointerBuilder {
229226

230227
match action {
231228
CheckpointingAction::Computed {
232-
node_ref: _,
229+
node_id: _,
233230
state_content,
234231
} => {
235232
self.checkpoint_compute(backward_states_map, node_id, state_content, n_required)
236233
}
237234
CheckpointingAction::Recompute {
238-
node_ref: _,
235+
node_id: _,
239236
retro_forward,
240237
} => self.checkpoint_lazy(
241238
backward_states_map,
@@ -251,7 +248,7 @@ impl CheckpointerBuilder {
251248
fn make_tree(&self, graph: &NodeSteps) -> NodeTree {
252249
let mut tree = HashMap::default();
253250
for (id, step) in graph {
254-
tree.insert(id.clone(), step.node());
251+
tree.insert(*id, step.parents());
255252
}
256253
NodeTree::new(tree)
257254
}
@@ -267,7 +264,7 @@ impl CheckpointerBuilder {
267264
n_required_map.insert(id, n + 1);
268265
}
269266
None => {
270-
n_required_map.insert(id.clone(), 1);
267+
n_required_map.insert(id, 1);
271268
if !stop_nodes.contains(&id) {
272269
if let Some(parents) = node_tree.parents(&id) {
273270
for p in parents {
@@ -288,7 +285,7 @@ impl CheckpointerBuilder {
288285
&self,
289286
backward_states_map: &mut HashMap<NodeID, State>,
290287
node_id: NodeID,
291-
state_content: Box<dyn Any + Send + Sync>,
288+
state_content: Box<dyn Any + Send>,
292289
n_required: usize,
293290
) {
294291
backward_states_map.insert(
@@ -308,7 +305,7 @@ impl CheckpointerBuilder {
308305
retro_forward: Arc<dyn RetroForward>,
309306
n_required: usize,
310307
) {
311-
retro_forward_map.insert(node_id.clone(), retro_forward);
312-
backward_states_map.insert(node_id.clone(), State::Recompute { n_required });
308+
retro_forward_map.insert(node_id, retro_forward);
309+
backward_states_map.insert(node_id, State::Recompute { n_required });
313310
}
314311
}

crates/burn-autodiff/src/checkpoint/retro_forward.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use super::state::{BackwardStates, State};
77
/// Definition of the forward function of a node, called during retropropagation only.
88
/// This is different from the normal forward function because it reads and writes from
99
/// the [InnerStates] map instead of having a clear function signature.
10-
pub trait RetroForward: Debug + Send + Sync + 'static {
10+
pub trait RetroForward: Debug + Send + 'static {
1111
fn forward(&self, states: &mut BackwardStates, out_node: NodeID);
1212
}
1313

@@ -31,7 +31,7 @@ impl RetroForwards {
3131
{
3232
// Retro forwards are always used only once because afterwards their state is computed
3333
let retro_forward = self.map.remove(&node_id).unwrap();
34-
retro_forward.forward(backward_states, node_id.clone());
34+
retro_forward.forward(backward_states, node_id);
3535
}
3636
}
3737

@@ -48,7 +48,7 @@ macro_rules! retro_unary_scalar {
4848
$name:ident,
4949
$ops:expr
5050
) => {
51-
#[derive(new, Debug)]
51+
#[derive(new, Debug, Clone)]
5252
struct $name<B: Backend, const D: usize> {
5353
lhs_id: NodeID,
5454
rhs: FloatElem<B>,
@@ -72,7 +72,7 @@ macro_rules! retro_unary {
7272
$name:ident,
7373
$ops:expr
7474
) => {
75-
#[derive(new, Debug)]
75+
#[derive(new, Debug, Clone)]
7676
struct $name<B: Backend, const D: usize> {
7777
input_id: NodeID,
7878
_backend: PhantomData<B>,
@@ -95,7 +95,7 @@ macro_rules! retro_binary {
9595
$name:ident,
9696
$ops:expr
9797
) => {
98-
#[derive(new, Debug)]
98+
#[derive(new, Debug, Clone)]
9999
struct $name<B: Backend, const D: usize> {
100100
lhs_id: NodeID,
101101
rhs_id: NodeID,

crates/burn-autodiff/src/checkpoint/state.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{any::Any, collections::HashMap};
33
use crate::graph::NodeID;
44

55
/// In order to accept arbitrary node output in the same hashmap, we need to upcast them to any.
6-
pub(crate) type StateContent = Box<dyn Any + Send + Sync>;
6+
pub(crate) type StateContent = Box<dyn Any + Send>;
77

88
#[derive(Debug)]
99
/// The state contained at one node. Encapsulates the node output if precomputed,
@@ -71,7 +71,7 @@ impl BackwardStates {
7171
/// This function always gives ownership of the output, but will clone it if needed for further uses.
7272
pub(crate) fn get_state<T>(&mut self, node_id: &NodeID) -> T
7373
where
74-
T: Clone + Send + Sync + 'static,
74+
T: Clone + Send + 'static,
7575
{
7676
// Fetch the state and decrement its number of required
7777
let state = self.map.remove(node_id).unwrap();
@@ -97,7 +97,7 @@ impl BackwardStates {
9797
.unwrap()
9898
.clone();
9999

100-
self.insert_state(node_id.clone(), new_stored_state);
100+
self.insert_state(*node_id, new_stored_state);
101101

102102
downcasted
103103
} else {
@@ -119,7 +119,7 @@ impl BackwardStates {
119119

120120
pub(crate) fn save<T>(&mut self, node_id: NodeID, saved_output: T)
121121
where
122-
T: Clone + Send + Sync + 'static,
122+
T: Clone + Send + 'static,
123123
{
124124
let n_required = self.get_state_ref(&node_id).unwrap().n_required();
125125
self.insert_state(

0 commit comments

Comments
 (0)