-
Notifications
You must be signed in to change notification settings - Fork 78
insertDeallocate inspects inner scopes
#6007
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f2ac6bb
678f6ba
fd49660
a92f4db
e496482
8219fe4
0cd2f48
aec9e85
a338a1c
80a701c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -12,64 +12,69 @@ | |||||
| #include <functional> | ||||||
| #include <iterator> | ||||||
| #include <list> | ||||||
| #include <ranges> | ||||||
| #include <stack> | ||||||
| #include <unordered_map> | ||||||
| #include <unordered_set> | ||||||
| #include <vector> | ||||||
|
|
||||||
| #include "fusion.h" | ||||||
| #include "host_ir/ir.h" | ||||||
| #include "ir/builder.h" | ||||||
| #include "ir/utils.h" | ||||||
|
|
||||||
| namespace nvfuser::hir { | ||||||
|
|
||||||
| namespace { | ||||||
|
|
||||||
| class DominatorTree { | ||||||
| class Node { | ||||||
| public: | ||||||
| class Node { | ||||||
| public: | ||||||
| Node(Scope* scope, Scope::Iterator iterator) | ||||||
| : scope_(scope), iterator_(iterator) {} | ||||||
| Node(const Node& other) = delete; | ||||||
| Node(Node&& other) = delete; | ||||||
| Node& operator=(const Node& other) = delete; | ||||||
| Node& operator=(Node&& other) = delete; | ||||||
|
|
||||||
| const std::vector<Node*>& children() const { | ||||||
| return children_; | ||||||
| } | ||||||
| Node(Scope* scope, Expr* expr, const Node* parent) | ||||||
| : scope_(scope), | ||||||
| expr_(expr), | ||||||
| parent_(parent), | ||||||
| depth_(parent ? parent->depth() + 1 : 0) {} | ||||||
|
|
||||||
| Scope* scope() const { | ||||||
| return scope_; | ||||||
| } | ||||||
|
|
||||||
| void addChild(Node* child) { | ||||||
| children_.push_back(child); | ||||||
| } | ||||||
| Expr* getExpr() const { | ||||||
| return expr_; | ||||||
| } | ||||||
|
|
||||||
| Scope* scope() const { | ||||||
| return scope_; | ||||||
| } | ||||||
| const Node* parent() const { | ||||||
| return parent_; | ||||||
| } | ||||||
|
|
||||||
| Scope::Iterator iterator() const { | ||||||
| return iterator_; | ||||||
| } | ||||||
| int depth() const { | ||||||
| return depth_; | ||||||
| } | ||||||
|
|
||||||
| Expr* getExpr() const { | ||||||
| return *iterator_; | ||||||
| } | ||||||
| const std::vector<Node*>& children() const { | ||||||
| return children_; | ||||||
| } | ||||||
|
|
||||||
| private: | ||||||
| // Consider putting `scope` and `iterator` into a separate Mutator class. | ||||||
| // They are only needed when the user wants to modify the host IR. | ||||||
| Scope* scope_; | ||||||
| Scope::Iterator iterator_; | ||||||
| void addChild(Node* child) { | ||||||
| children_.push_back(child); | ||||||
| } | ||||||
|
|
||||||
| std::vector<Node*> children_; | ||||||
| }; | ||||||
| private: | ||||||
| Scope* scope_; | ||||||
| Expr* expr_; | ||||||
| const Node* parent_; | ||||||
| int depth_; | ||||||
| std::vector<Node*> children_; | ||||||
| }; | ||||||
|
|
||||||
| explicit DominatorTree(hir::HostIrContainer& hic) : hic_(hic) { | ||||||
| build(hic_.topLevel(), /*parent=*/nullptr); | ||||||
| class DominatorTree { | ||||||
| public: | ||||||
| explicit DominatorTree(hir::HostIrContainer& hic) : hic_(&hic) { | ||||||
| build(hic_->topLevel(), /*parent=*/nullptr); | ||||||
| } | ||||||
|
|
||||||
| const Node* getRoot() const { | ||||||
| const auto& top_level_exprs = hic_.topLevelExprs(); | ||||||
| const auto& top_level_exprs = hic_->topLevelExprs(); | ||||||
| NVF_ERROR(!top_level_exprs.empty()); | ||||||
| Expr* root = top_level_exprs.front(); | ||||||
| return &nodes_.at(root); | ||||||
|
|
@@ -105,10 +110,8 @@ class DominatorTree { | |||||
|
|
||||||
| private: | ||||||
| void build(Scope& scope, Node* parent) { | ||||||
| for (auto scope_it = scope.exprs().begin(); scope_it != scope.exprs().end(); | ||||||
| ++scope_it) { | ||||||
| Expr* e = *scope_it; | ||||||
| auto [node_it, inserted] = nodes_.try_emplace(e, &scope, scope_it); | ||||||
| for (Expr* e : scope.exprs()) { | ||||||
| auto [node_it, inserted] = nodes_.try_emplace(e, &scope, e, parent); | ||||||
| NVF_ERROR(inserted); | ||||||
| Node& node = node_it->second; | ||||||
| if (parent != nullptr) { | ||||||
|
|
@@ -131,7 +134,77 @@ class DominatorTree { | |||||
| } | ||||||
| } | ||||||
|
|
||||||
| hir::HostIrContainer& hic_; | ||||||
| hir::HostIrContainer* hic_; | ||||||
| std::unordered_map<const Expr*, Node> nodes_; | ||||||
| }; | ||||||
|
|
||||||
| // Post-dominator tree: node A post-dominates B if every path from B to exit | ||||||
| // goes through A. Built by traversing from exit toward entry. | ||||||
| class PostDominatorTree { | ||||||
| public: | ||||||
| explicit PostDominatorTree( | ||||||
| hir::HostIrContainer& hic, | ||||||
| std::unordered_map<TensorView*, const Node*>& lca) | ||||||
| : hic_(&hic) { | ||||||
| build(hic_->topLevel(), /*scope_exit_successor=*/nullptr, lca); | ||||||
| } | ||||||
|
|
||||||
| const Node* getNode(Expr* expr) const { | ||||||
| auto it = nodes_.find(expr); | ||||||
| return it != nodes_.end() ? &it->second : nullptr; | ||||||
| } | ||||||
|
|
||||||
| private: | ||||||
| void build( | ||||||
| Scope& scope, | ||||||
| Node* parent, | ||||||
| std::unordered_map<TensorView*, const Node*>& lca) { | ||||||
| for (Expr* e : scope.exprs() | std::views::reverse) { | ||||||
| auto [node_it, inserted] = nodes_.try_emplace(e, &scope, e, parent); | ||||||
| NVF_ERROR(inserted); | ||||||
| Node& node = node_it->second; | ||||||
|
|
||||||
| if (auto* alloc = dynamic_cast<kir::Allocate*>(e)) { | ||||||
| TensorView* tv = alloc->buffer()->as<TensorView>(); | ||||||
| lca[tv] = findLCA(lca[tv], &node); | ||||||
| } | ||||||
| for (auto* in : ir_utils::filterByType<TensorView>(e->inputs())) { | ||||||
| lca[in] = findLCA(lca[in], &node); | ||||||
| } | ||||||
|
|
||||||
| if (auto* loop = dynamic_cast<hir::ForLoop*>(e)) { | ||||||
| build(loop->body(), &node, lca); | ||||||
| } | ||||||
| if (auto* ite = dynamic_cast<kir::IfThenElse*>(e)) { | ||||||
| build(ite->thenBody(), &node, lca); | ||||||
| build(ite->elseBody(), &node, lca); | ||||||
| } | ||||||
|
|
||||||
| parent = &node; | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| const Node* findLCA(const Node* a, const Node* b) const { | ||||||
| if (a == nullptr) { | ||||||
| return b; | ||||||
| } | ||||||
| if (b == nullptr) { | ||||||
| return a; | ||||||
| } | ||||||
| while (a->depth() > b->depth()) { | ||||||
| a = a->parent(); | ||||||
| } | ||||||
| while (b->depth() > a->depth()) { | ||||||
| b = b->parent(); | ||||||
| } | ||||||
| while (a != b) { | ||||||
| a = a->parent(); | ||||||
| b = b->parent(); | ||||||
| } | ||||||
| return a; | ||||||
| } | ||||||
|
|
||||||
| hir::HostIrContainer* hic_; | ||||||
| std::unordered_map<const Expr*, Node> nodes_; | ||||||
| }; | ||||||
|
|
||||||
|
|
@@ -159,7 +232,7 @@ void insertAllocations(hir::HostIrContainer& hic) { | |||||
|
|
||||||
| dom_tree.depthFirstTraverse( | ||||||
| /*pre_fn=*/ | ||||||
| [&](const DominatorTree::Node* node) { | ||||||
| [&](const Node* node) { | ||||||
| Expr* e = node->getExpr(); | ||||||
| // If `e`'s output needs preallocation but isn't defined, insert an | ||||||
| // allocation right before `e`. | ||||||
|
|
@@ -171,64 +244,60 @@ void insertAllocations(hir::HostIrContainer& hic) { | |||||
| if (needsOutputPreallocation(e)) { | ||||||
| auto* allocate = | ||||||
| IrBuilder::create<kir::Allocate>(out, out->getMemoryType()); | ||||||
| node->scope()->insert(node->iterator(), allocate); | ||||||
| node->scope()->insert_before(node->getExpr(), allocate); | ||||||
| } | ||||||
|
|
||||||
| defined.insert(out); | ||||||
| } | ||||||
| }, | ||||||
| /*post_fn=*/ | ||||||
| [&](const DominatorTree::Node* node) { | ||||||
| [&](const Node* node) { | ||||||
| Expr* e = node->getExpr(); | ||||||
| for (auto* out : ir_utils::filterByType<TensorView>(e->outputs())) { | ||||||
| defined.erase(out); | ||||||
| } | ||||||
| }); | ||||||
| } | ||||||
|
|
||||||
| bool needsDeallocation(TensorView* tv) { | ||||||
| if (tv->isFusionInput()) { | ||||||
| return false; | ||||||
| } | ||||||
| if (tv->isFusionOutput()) { | ||||||
| return false; | ||||||
| } | ||||||
| if (tv->definition()->isA<ShardByStream>()) { | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add null check before dereferencing
Suggested change
|
||||||
| return false; | ||||||
| } | ||||||
| const AliasInfo& alias_info = tv->container()->getOutputAlias(tv); | ||||||
| if (alias_info.type == AllocationType::ReuseBuffer) { | ||||||
| return false; | ||||||
| } | ||||||
| return true; | ||||||
| } | ||||||
|
|
||||||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is not complete. Ops like view do not always allocate new tensors. A few things make that analysis tricky:
One solution can be to explicitly allocate expr eval outputs where needed like we do for matmul/linear. Then, we only deallocate tvs that are allocated. The previous version did not make any distinction for view-like ops, so the functionality does not regress. What do you think @wujingyue
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, I'm missing some context. Can you remind me why this PR needs to change how we decide what needs deallocation? I understood the motivation of looking into loops, but I'm missing some connections otherwise.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
This PR does not need to necessarily change this. But we do need to decide what needs deallocation since not all ops allocate new tensors. I initially started with deallocating only explicitly "allocated" tensorviews. However, that breaks the
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Got it. This is actually the old behavior. It didn't trigger this problem because ShardByStream is never top-level.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. HostIrEvaluator handles deallocation by removing the tensor from the underlying hash table. It doesn't always free the memory. What problems did you run into with ShardByStream exactly? I can try it myself tomorrow. Not on a computer right now.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Correct. I did not run into any errors with existing tests since For simplicity, if you prefer, I can remove the additional conditions from this PR, and we can discuss that in a separate PR. |
||||||
| void insertDeallocations(hir::HostIrContainer& hic) { | ||||||
| const std::list<Expr*>& top_level_exprs = hic.topLevelExprs(); | ||||||
| std::for_each(top_level_exprs.begin(), top_level_exprs.end(), [](Expr* expr) { | ||||||
| std::ranges::for_each(top_level_exprs, [](Expr* expr) { | ||||||
| NVF_ERROR( | ||||||
| !expr->isA<hir::Deallocate>(), | ||||||
| "Expected hostir container to not have deallocate, but found one " | ||||||
| "anyways: ", | ||||||
| expr); | ||||||
| }); | ||||||
|
|
||||||
| // For each input in every expression in the container, find the position of | ||||||
| // its last use and insert a deallocate directly after, except for fusion | ||||||
| // inputs and outputs. | ||||||
| std::unordered_set<TensorView*> last_use_found; | ||||||
| for (auto insertion_point = top_level_exprs.end(); | ||||||
| insertion_point != top_level_exprs.begin();) { | ||||||
| auto prev = std::prev(insertion_point); | ||||||
| Expr* e = *prev; | ||||||
|
|
||||||
| // Only tensors need to be allocated. | ||||||
| for (auto* in : ir_utils::filterByType<TensorView>(e->inputs())) { | ||||||
| // Fusion inputs are managed by the caller. | ||||||
| if (in->isFusionInput()) { | ||||||
| continue; | ||||||
| } | ||||||
| std::unordered_map<TensorView*, const Node*> lca; | ||||||
| PostDominatorTree post_dom_tree(hic, lca); | ||||||
|
|
||||||
| // Fusion outputs need to be kept alive for the caller. | ||||||
| if (in->isFusionOutput()) { | ||||||
| continue; | ||||||
| } | ||||||
|
|
||||||
| // Skip if `e` is not the last use. | ||||||
| if (!last_use_found.insert(in).second) { | ||||||
| continue; | ||||||
| } | ||||||
|
|
||||||
| auto* deallocate = IrBuilder::create<hir::Deallocate>(in); | ||||||
| hic.insertExprBefore(insertion_point, deallocate); | ||||||
| // Insert deallocate at LCA for each TV that needs deallocation. | ||||||
| for (const auto& [tv, lca_node] : lca) { | ||||||
| if (!needsDeallocation(tv)) { | ||||||
| continue; | ||||||
| } | ||||||
|
|
||||||
| // Don't `--insertion_point;` because we'd like to skip newly inserted | ||||||
| // deallocations. | ||||||
| insertion_point = prev; | ||||||
| NVF_ERROR( | ||||||
| lca_node != nullptr, "Could not find post-dominator for tensor ", tv); | ||||||
| auto* deallocate = IrBuilder::create<hir::Deallocate>(tv); | ||||||
| lca_node->scope()->insert_after(lca_node->getExpr(), deallocate); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: create different nodes in each tree to avoid overloading too much.