@@ -66,5 +66,111 @@ void propagateResizeToInputs(Expr* resize_tensor_op) {
66
66
}
67
67
}
68
68
69
+ std::unordered_map<TensorView*, ValGroups> getNonExclusiveResizeInfo (
70
+ const std::vector<Expr*>& ordered_resize_tensor_ops,
71
+ const ValGraph& exact_graph) {
72
+ NVF_ERROR (!ordered_resize_tensor_ops.empty ());
73
+ Fusion* fusion = ordered_resize_tensor_ops[0 ]->fusion ();
74
+
75
+ std::unordered_map<TensorView*, ValGroups> non_exclusive_resizes;
76
+
77
+ std::unordered_set<Val*> inputs{
78
+ fusion->inputs ().begin (), fusion->inputs ().end ()};
79
+
80
+ auto get_root_to_logical_resizes =
81
+ [&exact_graph](TensorView* tv) -> ValGroups {
82
+ // This should be only used for outputs of resize-based ops,
83
+ // so it should always have a root domain.
84
+ NVF_ERROR (tv->hasRoot ());
85
+ auto out_tv_root_to_logical_exprs = DependencyCheck::getAllExprsBetween (
86
+ {tv->getRootDomain ().begin (), tv->getRootDomain ().end ()},
87
+ {tv->getLogicalDomain ().begin (), tv->getLogicalDomain ().end ()});
88
+ ValGroups resize_inp_ids;
89
+ for (auto resize :
90
+ ir_utils::filterByType<Resize>(out_tv_root_to_logical_exprs)) {
91
+ resize_inp_ids.pushBack (exact_graph.toGroup (resize->in ()));
92
+ }
93
+ return resize_inp_ids;
94
+ };
95
+
96
+ // Traverse the ops in a topological order
97
+ for (Expr* resize_tensor_op : ordered_resize_tensor_ops) {
98
+ auto inp_tv = dynamic_cast <TensorView*>(resize_tensor_op->inputs ().at (0 ));
99
+ auto out_tv = dynamic_cast <TensorView*>(resize_tensor_op->outputs ().at (0 ));
100
+
101
+ ValGroups resize_inp_ids = get_root_to_logical_resizes (out_tv);
102
+ NVF_ERROR (!resize_inp_ids.empty ());
103
+
104
+ auto dep_vals =
105
+ DependencyCheck::getAllValsBetween (inputs, std::vector<Val*>{inp_tv});
106
+
107
+ // For each tensor that inp_tv depends on, check if the resize op
108
+ // is considered non-exclusive with respect to the tensor. That
109
+ // is, if propagation of the resize may result in externally
110
+ // visible changes through the tensor, the resize is considered
111
+ // non-exclusive.
112
+ for (auto dep_tv : ir_utils::filterByType<TensorView>(dep_vals)) {
113
+ bool maybe_non_exclusive = dep_tv->isFusionOutput ();
114
+
115
+ if (!maybe_non_exclusive) {
116
+ // If a dependent tv has a consumer that inp_tv does not
117
+ // depend on, propagation of resize would escape to outputs,
118
+ // which needs to be avoided.
119
+ for (auto consumer_tv : ir_utils::consumerTvsOf (dep_tv)) {
120
+ // We are interested in if resized IDs are used by other tensors
121
+ // than out_tv
122
+ if (consumer_tv != out_tv &&
123
+ std::find (dep_vals.begin (), dep_vals.end (), consumer_tv) ==
124
+ dep_vals.end ()) {
125
+ maybe_non_exclusive = true ;
126
+ break ;
127
+ }
128
+ }
129
+ }
130
+
131
+ if (!maybe_non_exclusive) {
132
+ continue ;
133
+ }
134
+
135
+ // dep_tv potentially is either a fusion output or it has a
136
+ // consumer outside of the dependency set to the resized
137
+ // tensor. Propagating the resize to dep_tv should be
138
+ // avoided. However, if the dep_tv iter domain that corresponds
139
+ // to the resized ID is a broadcast or there's no such ID, it
140
+ // should still be safe to consider the resize op exclusive as
141
+ // there's no iter domain to resize. For a concrete example, see
142
+ // ResizeSchedulerTest.PropagateMultipleSlicesToInputs4.
143
+ const auto inp_tv_logical_groups =
144
+ exact_graph.toGroups (inp_tv->getLogicalDomain ());
145
+ const auto dep_tv_logical_groups =
146
+ exact_graph.toGroups (dep_tv->getLogicalDomain ());
147
+ auto vals_between = getValsBetween<ValGraphBFS>(
148
+ {inp_tv_logical_groups.begin (), inp_tv_logical_groups.end ()},
149
+ {dep_tv_logical_groups.begin (), dep_tv_logical_groups.end ()},
150
+ exact_graph);
151
+
152
+ for (const ValGroup& resize_inp_id : resize_inp_ids) {
153
+ if (std::find (
154
+ vals_between.begin (), vals_between.end (), resize_inp_id) ==
155
+ vals_between.end ()) {
156
+ // This resize can be ignored as there's no corresponding ID
157
+ // in the dep tv
158
+ continue ;
159
+ }
160
+
161
+ // This resize input ID is not exclusively used
162
+ non_exclusive_resizes[inp_tv].pushBack (resize_inp_id);
163
+ }
164
+ }
165
+
166
+ // Analysis of exclusiveness until in_tv is done. Following
167
+ // resize-based tensor ops do not need to check the same section
168
+ // of the fusion and can start from out_tv.
169
+ inputs.insert (out_tv);
170
+ }
171
+
172
+ return non_exclusive_resizes;
173
+ }
174
+
69
175
} // namespace scheduler_tools
70
176
} // namespace nvfuser
0 commit comments