forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_subgraph_rewriter.cpp
130 lines (108 loc) · 3.45 KB
/
test_subgraph_rewriter.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/testing/file_check.h>
namespace torch {
namespace jit {
using namespace testing;
// Checks that SubgraphRewriter::runOnGraph applies a registered rewrite
// (c::ccc -> d::ddd) only when the supplied match filter(s) accept the match.
// In the input graph, the %b operand of c::ccc is prim::Constant[value=1].
// - A lone "is constant" filter accepts it.
// - The filters {is constant, is one} accept it.
// - The filters {is constant, is two} reject it, so the graph stays unchanged.
void testFilterMatch() {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%0):
%a = a::aaa(%0)
%b : int = prim::Constant[value=1]()
%c = c::ccc(%a, %b)
return (%c))IR",
      graph.get());

  std::string pattern = R"IR(
graph(%a, %b):
%c = c::ccc(%a, %b)
return (%c))IR";

  // Each filter is handed the match and the pattern's name -> Value* map, so
  // it can find the graph value that was bound to the pattern's %b.
  auto b_is_constant = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_node = match_vmap.at(vmap.at("b"))->node();
    return b_node->kind() == prim::Constant;
  };
  auto b_is_one = [](const Match& match,
                     const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    // toIValue yields an empty optional when the value is not a constant.
    auto b_val = toIValue(match_vmap.at(vmap.at("b")));
    return b_val && b_val->isInt() && b_val->toInt() == 1;
  };
  auto b_is_two = [](const Match& match,
                     const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_val = toIValue(match_vmap.at(vmap.at("b")));
    return b_val && b_val->isInt() && b_val->toInt() == 2;
  };

  std::string replacement = R"IR(
graph(%a, %b):
%d = d::ddd(%a, %b)
return (%d))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);

  // Presence and absence are checked in separate FileCheck runs. A chained
  // check(X)->check_not(Y) only rejects Y after the match of X; a standalone
  // check_not scans the whole graph.

  // b is constant, so the match will succeed
  {
    auto g = graph->copy();
    rewriter.runOnGraph(g, b_is_constant);
    FileCheck().check("d::ddd")->run(*g);
    FileCheck().check_not("c::ccc")->run(*g);
  }
  // b is constant and the value is one, the match will succeed
  {
    auto g = graph->copy();
    rewriter.runOnGraph(g, {b_is_constant, b_is_one});
    FileCheck().check("d::ddd")->run(*g);
    FileCheck().check_not("c::ccc")->run(*g);
  }
  // b is constant but the value is not two, the match will fail
  {
    auto g = graph->copy();
    rewriter.runOnGraph(g, {b_is_constant, b_is_two});
    FileCheck().check("c::ccc")->run(*g);
    FileCheck().check_not("d::ddd")->run(*g);
  }
}
// Checks that a match filter can veto a structurally valid match. The pattern
// c::ccc(%a, %b) matches the graph, but the filter only accepts matches whose
// %b is produced by a prim::Assign node. Here %b comes from prim::Constant,
// so the rewrite to d::ddd must be skipped and the graph left unchanged.
void testFilterNoMatch() {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
graph(%0):
%a = a::aaa(%0)
%b = prim::Constant[value=1]()
%c = c::ccc(%a, %b)
return (%c))IR",
      graph.get());

  std::string pattern = R"IR(
graph(%a, %b):
%c = c::ccc(%a, %b)
return (%c))IR";

  // The rewriter passes the pattern's name -> Value* map as `vmap`, used here
  // to find the graph value bound to the pattern's %b.
  auto filter = [](const Match& match,
                   const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto b_node = match_vmap.at(vmap.at("b"))->node();
    // b_node is not prim::Assign, so this won't match and we'll skip the
    // rewrite
    return b_node->kind() == prim::Assign;
  };

  std::string replacement = R"IR(
graph(%a, %b):
%d = d::ddd(%a, %b)
return (%d))IR";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  rewriter.runOnGraph(graph, filter);

  // Separate runs so that the absence of d::ddd is checked over the whole
  // graph, not only after the c::ccc match.
  FileCheck().check("c::ccc")->run(*graph);
  FileCheck().check_not("d::ddd")->run(*graph);
}
// Entry point for the subgraph-rewriter filter tests. Runs each case in
// sequence: filter acceptance first, then filter rejection.
void testSubgraphRewriter() {
  void (*const cases[])() = {testFilterMatch, testFilterNoMatch};
  for (auto* run_case : cases) {
    run_case();
  }
}
} // namespace jit
} // namespace torch