Skip to content

Commit

Permalink
Fix autodiff memory management graph cleaning (#1602)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Apr 11, 2024
1 parent 0cbe9a9 commit 07a61a1
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/burn-autodiff/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.13.0", opt

derive-new = { workspace = true }
spin = { workspace = true }
log = { workspace = true }

[dev-dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.13.0", default-features = false, features = [
Expand Down
7 changes: 7 additions & 0 deletions crates/burn-autodiff/src/runtime/memory_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ impl GraphMemoryManagement {

for node_id in graph.into_iter() {
func(&node_id);
self.graphs.remove(&GraphId::new(*node_id));
}
}

Expand Down Expand Up @@ -258,6 +259,9 @@ mod tests {
assert!(node_ids.contains(&node_1));
assert!(node_ids.contains(&node_2));

assert_eq!(graph_mm.graphs.len(), 0);
assert_eq!(graph_mm.owned.len(), 0);

// Same but with free(node_2);
graph_mm.register(node_1.clone(), vec![]);
graph_mm.register(node_2.clone(), vec![*node_1]);
Expand All @@ -267,5 +271,8 @@ mod tests {

assert!(node_ids.contains(&node_1));
assert!(node_ids.contains(&node_2));

assert_eq!(graph_mm.graphs.len(), 0);
assert_eq!(graph_mm.owned.len(), 0);
}
}
5 changes: 5 additions & 0 deletions crates/burn-autodiff/src/runtime/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
grads::Gradients,
graph::{traversal::BreadthFirstSearch, StepBoxed},
runtime::memory_management::GraphId,
tensor::NodeRefCount,
NodeID,
};
Expand Down Expand Up @@ -63,6 +64,10 @@ impl AutodiffServer {
.collect::<Vec<_>>();

BreadthFirstSearch.traverse(root, root_step, &mut self.steps, |id, step| {
// This node is consumed by the tape, so its graph must also be removed from
// memory management.
self.memory_management.free_graph(GraphId::new(id), |_| {});

let order = step.order();
if order == 0 {
return;
Expand Down

0 comments on commit 07a61a1

Please sign in to comment.