forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
common_subexpression_elimination.cpp
130 lines (107 loc) · 3.85 KB
/
common_subexpression_elimination.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/node_hashing.h>
#include <torch/csrc/jit/jit_log.h>
#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
namespace torch {
namespace jit {
namespace {
struct CommonSubexpressionEliminator {
CommonSubexpressionEliminator(std::shared_ptr<Graph> graph)
: graph_(std::move(graph)) {}
bool run(std::function<Node*(Node*)> parent_lookup_fn) {
return run(graph_->block(), std::move(parent_lookup_fn));
}
// The function implements common subexpression elimination.
// Since the nodes are visited in topological order, one pass is enough.
// returns true if CSE made changes to a graph
bool run(Block* block, std::function<Node*(Node*)> parent_lookup_fn) {
std::unordered_set<Node*, HashNode, EqualNode> subexprs;
bool changed = false;
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
auto node = *it;
if (node->kind() == prim::profile) {
GRAPH_DEBUG(
"Profiled nodes shouldn't be CSE'ed there's a separate pass that does dedup and merging:\n",
*node);
continue;
}
if (node->hasSideEffects()) {
GRAPH_DEBUG("Node was skipped due to side effects:\n", *node);
continue;
}
if (node->isNondeterministic()) {
GRAPH_DEBUG("Node was skipped due to its non determinism:\n", *node);
continue;
}
if (!node->blocks().empty()) {
// Traverse sub-blocks.
for (auto block : node->blocks()) {
changed |= run(block, [&](Node* n) {
auto existing = subexprs.find(n);
if (existing != subexprs.end()) {
return *existing;
}
return parent_lookup_fn(n);
});
}
continue;
}
if (getOrCreateAliasDb().hasWriters(node)) {
GRAPH_DEBUG("Node was skipped due to alias analysis result:\n", *node);
// Do NOT have enough information to do CSE on these nodes.
continue;
}
// Check for CSE opportunities in the parent block.
auto parent_lookup = parent_lookup_fn(node);
auto g_out = node->owningGraph()->outputs();
if (parent_lookup != nullptr) {
if (!getOrCreateAliasDb().safeToChangeAliasingRelationship(
node->outputs(), parent_lookup->outputs())) {
continue;
}
GRAPH_UPDATE("Replacing\n", *node, "with\n", *parent_lookup);
changed = true;
node->replaceAllUsesWith(parent_lookup);
it.destroyCurrent();
continue;
}
// Check whether the same subexpression already exists.
auto subit = subexprs.insert(node);
if (!subit.second) {
// Subexpression exists, replace the uses of node, and destroy it.
auto existing = *subit.first;
// don't introduce new aliasing among graph outputs
if (getOrCreateAliasDb().mayContainAlias(
node->outputs(), node->owningGraph()->outputs()) &&
getOrCreateAliasDb().mayContainAlias(existing->outputs(), g_out)) {
continue;
}
GRAPH_UPDATE("Replacing\n", *node, "with\n", *existing);
changed = true;
node->replaceAllUsesWith(existing);
// Destroy the node.
it.destroyCurrent();
}
}
return changed;
}
AliasDb& getOrCreateAliasDb() {
if (!alias_db_) {
alias_db_ = std::make_unique<AliasDb>(graph_);
}
return *alias_db_;
}
private:
std::unique_ptr<AliasDb> alias_db_;
std::shared_ptr<Graph> graph_;
};
} // namespace
bool EliminateCommonSubexpression(const std::shared_ptr<Graph>& graph) {
GRAPH_DUMP("Before CSE", graph);
CommonSubexpressionEliminator cse(graph);
return cse.run([](Node*) { return nullptr; });
}
} // namespace jit
} // namespace torch