
Commit 48e620f

Merge branch 'main' into fp8_enable_on_sm89
2 parents 44e949b + 49b0862

6 files changed: +546, -49 lines changed

csrc/scheduler/resize.cpp

+10, -31

@@ -71,40 +71,19 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
   IdModel id_model(fusion, /*build_graphs=*/false);
   const auto& broadcast_graph = id_model.buildBroadcastGraph();
 
-  // For now, only a single resize op is allowed to exist.
   auto resize_based_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);
-  if (resize_based_tensor_ops.size() != 1) {
-    scheduler_debug_utils::canScheduleRejectReason(
-        schedulerType(), "Only a single resize op is allowed.");
-    return false;
-  }
 
-  auto resize_out_tv =
-      resize_based_tensor_ops.at(0)->output(0)->as<TensorView>();
-
-  auto all_dep_vals = DependencyCheck::getAllValsBetween(
-      {fusion->inputs().begin(), fusion->inputs().end()}, {resize_out_tv});
-  for (auto tv : ir_utils::filterByType<TensorView>(all_dep_vals)) {
-    if (tv == resize_out_tv) {
-      continue;
-    }
-    if (tv->isFusionOutput()) {
-      scheduler_debug_utils::canScheduleRejectReason(
-          schedulerType(),
-          "Dependency to fusion output not allowed: ",
-          tv->toString());
-      return false;
-    }
-    for (auto consumer_of_tv : ir_utils::consumerTvsOf(tv)) {
-      if (std::find(all_dep_vals.begin(), all_dep_vals.end(), consumer_of_tv) ==
-          all_dep_vals.end()) {
-        scheduler_debug_utils::canScheduleRejectReason(
-            schedulerType(),
-            "Resize inputs must be exclusively consumed by resize: ",
-            consumer_of_tv->toString());
-        return false;
-      }
+  if (auto non_exclusive_resizes = scheduler_tools::getNonExclusiveResizeInfo(
+          resize_based_tensor_ops, id_model.idGraph(IdMappingMode::EXACT));
+      !non_exclusive_resizes.empty()) {
+    std::stringstream msg;
+    msg << "Propagation of resizes would affect fusion outputs.";
+    for (const auto& [tv, resize_ids] : non_exclusive_resizes) {
+      msg << " Resize input tv: " << tv->toString()
+          << ", resize input ID groups: " << nvfuser::toString(resize_ids);
     }
+    scheduler_debug_utils::canScheduleRejectReason(schedulerType(), msg.str());
+    return false;
   }
 
   // Slicing of or to a broadcast ID is not allowed yet.
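Note on the change above: the scheduler previously bailed out whenever the fusion contained more than one SliceOp/PadOp, whereas it now only rejects fusions in which some resize is non-exclusive. Below is a hypothetical sketch (not part of this commit) of a fusion with a chain of two slices that the old single-resize restriction rejected but that passes this particular check; makeSymbolicTensor, add, and the slice(tv, starts, stops) overload are assumed from nvFuser's test utilities and csrc/ops.

// Hypothetical example, not part of this commit. Other compile-time checks
// may still apply; this only illustrates the exclusivity condition.
void buildChainedSliceFusion(Fusion& fusion) {
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeSymbolicTensor(1); // assumed test helper
  fusion.addInput(tv0);

  TensorView* tv1 = add(tv0, IrBuilder::create<Val>(1.0));
  // Two resize-based ops in a chain; nothing else consumes tv1 or tv2, so
  // neither slice is non-exclusive and the new check does not reject.
  TensorView* tv2 = slice(tv1, /*starts=*/{1}, /*stops=*/{10}); // assumed overload
  TensorView* tv3 = slice(tv2, /*starts=*/{2}, /*stops=*/{5});
  fusion.addOutput(tv3);

  // The old check rejected this with "Only a single resize op is allowed."
}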

csrc/scheduler/tools/resize_utils.cpp

+106

@@ -66,5 +66,111 @@ void propagateResizeToInputs(Expr* resize_tensor_op) {
   }
 }
 
+std::unordered_map<TensorView*, ValGroups> getNonExclusiveResizeInfo(
+    const std::vector<Expr*>& ordered_resize_tensor_ops,
+    const ValGraph& exact_graph) {
+  NVF_ERROR(!ordered_resize_tensor_ops.empty());
+  Fusion* fusion = ordered_resize_tensor_ops[0]->fusion();
+
+  std::unordered_map<TensorView*, ValGroups> non_exclusive_resizes;
+
+  std::unordered_set<Val*> inputs{
+      fusion->inputs().begin(), fusion->inputs().end()};
+
+  auto get_root_to_logical_resizes =
+      [&exact_graph](TensorView* tv) -> ValGroups {
+    // This should be only used for outputs of resize-based ops,
+    // so it should always have a root domain.
+    NVF_ERROR(tv->hasRoot());
+    auto out_tv_root_to_logical_exprs = DependencyCheck::getAllExprsBetween(
+        {tv->getRootDomain().begin(), tv->getRootDomain().end()},
+        {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
+    ValGroups resize_inp_ids;
+    for (auto resize :
+         ir_utils::filterByType<Resize>(out_tv_root_to_logical_exprs)) {
+      resize_inp_ids.pushBack(exact_graph.toGroup(resize->in()));
+    }
+    return resize_inp_ids;
+  };
+
+  // Traverse the ops in a topological order
+  for (Expr* resize_tensor_op : ordered_resize_tensor_ops) {
+    auto inp_tv = dynamic_cast<TensorView*>(resize_tensor_op->inputs().at(0));
+    auto out_tv = dynamic_cast<TensorView*>(resize_tensor_op->outputs().at(0));
+
+    ValGroups resize_inp_ids = get_root_to_logical_resizes(out_tv);
+    NVF_ERROR(!resize_inp_ids.empty());
+
+    auto dep_vals =
+        DependencyCheck::getAllValsBetween(inputs, std::vector<Val*>{inp_tv});
+
+    // For each tensor that inp_tv depends on, check if the resize op
+    // is considered non-exclusive with respect to the tensor. That
+    // is, if propagation of the resize may result in externally
+    // visible changes through the tensor, the resize is considered
+    // non-exclusive.
+    for (auto dep_tv : ir_utils::filterByType<TensorView>(dep_vals)) {
+      bool maybe_non_exclusive = dep_tv->isFusionOutput();
+
+      if (!maybe_non_exclusive) {
+        // If a dependent tv has a consumer that inp_tv does not
+        // depend on, propagation of resize would escape to outputs,
+        // which needs to be avoided.
+        for (auto consumer_tv : ir_utils::consumerTvsOf(dep_tv)) {
+          // We are interested in if resized IDs are used by other tensors
+          // than out_tv
+          if (consumer_tv != out_tv &&
+              std::find(dep_vals.begin(), dep_vals.end(), consumer_tv) ==
+                  dep_vals.end()) {
+            maybe_non_exclusive = true;
+            break;
+          }
+        }
+      }
+
+      if (!maybe_non_exclusive) {
+        continue;
+      }
+
+      // dep_tv potentially is either a fusion output or it has a
+      // consumer outside of the dependency set to the resized
+      // tensor. Propagating the resize to dep_tv should be
+      // avoided. However, if the dep_tv iter domain that corresponds
+      // to the resized ID is a broadcast or there's no such ID, it
+      // should still be safe to consider the resize op exclusive as
+      // there's no iter domain to resize. For a concrete example, see
+      // ResizeSchedulerTest.PropagateMultipleSlicesToInputs4.
+      const auto inp_tv_logical_groups =
+          exact_graph.toGroups(inp_tv->getLogicalDomain());
+      const auto dep_tv_logical_groups =
+          exact_graph.toGroups(dep_tv->getLogicalDomain());
+      auto vals_between = getValsBetween<ValGraphBFS>(
+          {inp_tv_logical_groups.begin(), inp_tv_logical_groups.end()},
+          {dep_tv_logical_groups.begin(), dep_tv_logical_groups.end()},
+          exact_graph);
+
+      for (const ValGroup& resize_inp_id : resize_inp_ids) {
+        if (std::find(
+                vals_between.begin(), vals_between.end(), resize_inp_id) ==
+            vals_between.end()) {
+          // This resize can be ignored as there's no corresponding ID
+          // in the dep tv
+          continue;
+        }
+
+        // This resize input ID is not exclusively used
+        non_exclusive_resizes[inp_tv].pushBack(resize_inp_id);
+      }
+    }
+
+    // Analysis of exclusiveness until in_tv is done. Following
+    // resize-based tensor ops do not need to check the same section
+    // of the fusion and can start from out_tv.
+    inputs.insert(out_tv);
+  }
+
+  return non_exclusive_resizes;
+}
+
 } // namespace scheduler_tools
 } // namespace nvfuser

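A usage sketch for the new helper (hypothetical, not part of this commit): it mirrors the caller in resize.cpp above and builds the first non-exclusive example documented in resize_utils.h below. The makeSymbolicTensor test helper, the slice(tv, starts, stops) overload, and IdModel::buildExactGraph() are assumptions about nvFuser's API.

// Hypothetical sketch of calling getNonExclusiveResizeInfo directly.
void exampleNonExclusiveSlice() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1); // assumed test helper
  fusion.addInput(tv0);
  auto tv1 = add(tv0, IrBuilder::create<Val>(1.0));
  auto tv2 = slice(tv1, /*starts=*/{1}, /*stops=*/{10}); // resize-based op
  auto tv3 = add(tv1, IrBuilder::create<Val>(1.0)); // second consumer of tv1
  fusion.addOutput(tv2);
  fusion.addOutput(tv3);

  IdModel id_model(&fusion, /*build_graphs=*/false);
  id_model.buildExactGraph(); // assumed builder for the EXACT graph

  auto resize_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(&fusion);
  auto info = scheduler_tools::getNonExclusiveResizeInfo(
      resize_ops, id_model.idGraph(IdMappingMode::EXACT));

  // Expected: one entry keyed by tv1 (the slice input) whose ValGroups hold
  // the exact group of the sliced ID, because propagating the resize to tv1
  // would also change tv3, a fusion output.
  NVF_ERROR(info.count(tv1) == 1);
}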
csrc/scheduler/tools/resize_utils.h

+79

@@ -7,9 +7,12 @@
 // clang-format on
 #pragma once
 
+#include <val_graph.h>
+
 namespace nvfuser {
 
 class Expr;
+class TensorView;
 
 namespace scheduler_tools {
 
@@ -19,5 +22,81 @@
 // fusion inputs are skipped as their loop domains don't matter.
 void propagateResizeToInputs(Expr* resize_op);
 
+// Given a topologically ordered list of resize-based tensor ops such
+// as slice and pad, check if they can be propagated to fusion inputs
+// exclusively without causing any visible side effect. For example,
+// if a tensor is sliced and also is used to produce an output without
+// the slicing, the slice is considered non-exclusive as the slice
+// input has the other visible consumer. Propagating the resize of the
+// slice to the slice input is invalid since the output computed from
+// the slice input depends on the full iteration space.
+//
+// For example, consider the following case:
+//
+// t0 = makeSymbolicTensor(1)
+// fusion.addInput(t0)
+// t1 = t0 + 1
+// t2 = t1[1:10]
+// t3 = t1 + 1
+// fusion.addOutput(t2)
+// fusion.addOutput(t3)
+//
+// In this case, propagating the resize op of the slice would alter t1,
+// which would in turn affect t3, which is a fusion output. Since the
+// change would be visible due to the change of t3, this resize op is
+// considered non-exclusive.
+//
+// Consider a slightly different case as shown below:
+//
+// t0 = makeSymbolicTensor(1)
+// fusion.addInput(t0)
+// t1 = t0[1:10]
+// t2 = t0 + 1
+// fusion.addOutput(t1)
+// fusion.addOutput(t2)
+//
+// Note that the slice is directly done with the fusion input. Since
+// we do not propagate resize ops to fusion inputs, this can be
+// considered exclusive. However, this is also considered
+// non-exclusive since the actual scheduling inserts a cache after t0,
+// which can cause a visible side effect if the resize is propagated.
+//
+// Another non-exclusiveness comes from dependent fusion outputs. For
+// example, if a slice input depends on a fusion output, propagation
+// would alter the fusion output. Consider a case like:
+//
+// t0 = makeSymbolicTensor(1)
+// fusion.addInput(t0)
+// t1 = t0 + 1
+// t2 = t1[1:10] // slice
+// fusion.addOutput(t1)
+// fusion.addOutput(t2)
+//
+// If the resize op for the slice is propagated to t1, only the
+// section of [1:10] would be computed. Since that would change a
+// fusion output, the resize op is considered non-exclusive.
+//
+// When there's a chain of resize-based ops, for example:
+//
+// t0 = makeSymbolicTensor(1)
+// fusion.addInput(t0)
+// t1 = t0 + 1
+// t2 = t1[1:10]
+// t3 = t2[2:5]
+// t4 = t1 + 1
+// fusion.addOutput(t3)
+// fusion.addOutput(t4)
+//
+// We do not consider the second slice as non-exclusive as
+// long as the first slice is considered non-exclusive. This will be
+// important when resolving the non-exclusiveness by replication.
+//
+// The function returns a map from tensors that are input to
+// non-exclusive ops to their resize input ID groups. This map will be
+// used to resolve the non-exclusiveness by replication.
+std::unordered_map<TensorView*, ValGroups> getNonExclusiveResizeInfo(
+    const std::vector<Expr*>& ordered_resize_tensor_ops,
+    const ValGraph& exact_graph);
+
 } // namespace scheduler_tools
 } // namespace nvfuser
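To make the second scenario above (slicing a fusion input directly) concrete: during scheduling an input is read through a cache tensor, so the slice effectively consumes the cache, which also feeds the other branch. The sketch below is hypothetical and not part of this commit; cacheAfter() is the standard nvFuser TensorView API, and the other helpers are assumed as in the earlier examples.

// Hypothetical sketch, not part of this commit.
void sliceOfFusionInputIsStillNonExclusive(Fusion& fusion) {
  FusionGuard fg(&fusion);

  auto t0 = makeSymbolicTensor(1); // assumed test helper
  fusion.addInput(t0);
  auto t1 = slice(t0, /*starts=*/{1}, /*stops=*/{10}); // slice of the input
  auto t2 = add(t0, IrBuilder::create<Val>(1.0)); // second consumer of t0
  fusion.addOutput(t1);
  fusion.addOutput(t2);

  // Scheduling inserts a cache between t0 and its consumers; both the slice
  // and the add then read the cache tensor instead of t0.
  TensorView* t0_cache = t0->cacheAfter();

  // Propagating the slice's resize to t0_cache would shrink the extent seen
  // by t2, a fusion output, so the slice is reported as non-exclusive even
  // though it consumes the fusion input directly.
  (void)t0_cache;
}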

tests/cpp/test_gpu3.cpp

-2

@@ -9249,8 +9249,6 @@ TEST_F(NVFuserTest, AllIdsMultipleDependencies) {
   tv1->split(0, 4);
   tv1->split(0, 8);
 
-  fusion.print();
-
   auto all_ids = tv1->domain()->allIDs();
 
   auto split2 = tv1->axis(0)->definition()->as<Split>();

tests/cpp/test_matmul_scheduler.cpp

+3, -3

@@ -1060,7 +1060,7 @@ TEST_F(MatmulSchedulerTest, FusedMultiplySumOnly) {
 // for Ampere with strict ref check, hence single layout check
 TEST_F(MatmulSchedulerTest, BasicMatmulStrictCheckTT) {
   // TODO: Make these tests work with Hopper as well as Ampere
-  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 8, 9);
+  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
 
   const int M = 128, N = 256, K = 512;
   const auto layout = MmaLayout::TT;
@@ -2481,7 +2481,7 @@ class MatmulSchedulerPluginTest : public NVFuserTest {
 
 // Test that our fake plugin works to override the default heuristic
 TEST_F(MatmulSchedulerPluginTest, BasicMatmul) {
-  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 8, 9);
+  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
   const int M = 128, N = 256, K = 512;
   const auto layout = MmaLayout::TT;
   auto fusion = std::make_unique<Fusion>();
@@ -3156,7 +3156,7 @@ INSTANTIATE_TEST_SUITE_P(
 #undef NVFUSER_TEST_CUDA_ARCH_GUARD
 
 TEST_F(MatmulSchedulerTest, OperandOrderIssue2434) {
-  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 8, 9);
+  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
   int M = 32, N = 64, K = 128;
 
   std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
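The three guard changes above widen the tested architecture range so these matmul tests also run on SM 8.9 (Ada), in line with the fp8_enable_on_sm89 branch this merge feeds. The sketch below only illustrates how such a range guard typically behaves, assuming an exclusive upper bound (consistent with (8, 0, 8, 9) previously skipping 8.9 while (8, 0, 9, 0) includes it and still skips Hopper); the real NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD macro lives in nvFuser's test utilities and may differ.

// Illustrative stand-in only; not the actual nvFuser macro.
#include <cuda_runtime.h>

// True when device 0's compute capability lies in
// [lo_major.lo_minor, hi_major.hi_minor), i.e. the upper bound is exclusive.
bool deviceInArchRange(int lo_major, int lo_minor, int hi_major, int hi_minor) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
    return false;
  }
  const int cc = prop.major * 10 + prop.minor;
  return cc >= lo_major * 10 + lo_minor && cc < hi_major * 10 + hi_minor;
}

// Under this reading, (8, 0, 8, 9) covers 8.0 up to but not including 8.9,
// while (8, 0, 9, 0) covers 8.0 through 8.9 and still excludes 9.0 (Hopper).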
