From 3121b7c47a7332269e6b5248b2493af06eb52852 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 21 Dec 2024 22:19:50 -0800 Subject: [PATCH 1/3] repro --- tests/cpp/test_indexing.cpp | 108 ++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp index d9436979ba5..bf51326578e 100644 --- a/tests/cpp/test_indexing.cpp +++ b/tests/cpp/test_indexing.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -347,6 +348,13 @@ class PredicateIndexValidator : public kir::IrVisitor { auto out_ti = expr->output(0)->as(); + // This is just an initialization expr, likely by zero. Only the + // actual expr will be validted. + if (out_ti->view()->definition()->input(0)->isA() && + expr->input(0)->isScalar()) { + return; + } + NVF_ERROR(!scope_exprs_.empty()); auto inline_ite = dynamic_cast(scope_exprs_.back()); NVF_ERROR( @@ -354,6 +362,7 @@ class PredicateIndexValidator : public kir::IrVisitor { "No inline predicate detected: ", expr->toString()); + std::cerr << expr->toString(); validateInlinePredicate(out_ti, inline_ite->predicate()->value()); // If there's an other IfThenElse in the scope stack, validate the @@ -5390,6 +5399,105 @@ TEST_F(IndexingTest, ResizeRotation) { testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); } +TEST_F(PredicateIndexingTest, VectorizedResizeRotation) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int64_t i0 = 32; + + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto zero = fusion.zeroVal(); + + // concrete shapes to avoid dynamic Fusion + auto tv0 = makeContigConcreteTensor({i0}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + // left half + auto tv2 = slice(tv1, {{zero, IrBuilder::create(i0 / 2)}}); + + auto tv3 = set(tv0); + // right half + auto tv4 = slice( + tv3, {{IrBuilder::create(i0 / 2), IrBuilder::create(i0)}}); + + // Rotation + auto tv5 = cat({tv4, tv2}, 0); + + auto tv6 = add(tv0, tv5); + + fusion.addOutput(tv6); + + for (Expr* expr : fusion.exprs()) { + if (expr->isOneOf()) { + scheduler_tools::propagateResizeToInputs(expr); + } + } + + for (auto tv : fusion.allTvs()) { + if (tv->isFusionInput()) { + continue; + } + + tv->split(0, 4); + } + + tv1->axis(-1)->parallelize(ParallelType::Vectorize); + + inlineMost(); + + struct GetReference : AbstractGetReference { + GetReference(const TensorIndexer& indexer, const IdModel& id_model) + : AbstractGetReference(indexer, id_model) {} + + Val* getInlinePredicate(TensorView* tv) const override { + if (tv->name() != 1) { + return nullptr; + } + + if (for_loops_.back()->iter_domain()->getParallelType() != + ParallelType::Vectorize) { + return nullptr; + } + + std::vector loop_indices = getLoopIndices(tv, indexer_, for_loops_); + + Val* zero = tv->fusion()->zeroVal(); + + auto second_resize = dynamic_cast( + tv->axis(0)->definition()->input(0)->definition()); + EXPECT_NE(second_resize, nullptr); + + auto start_idx = addExpr( + IrBuilder::addExpr( + mulExpr(loop_indices.at(0), tv->axis(1)->extent()), zero), + IrBuilder::negExpr(second_resize->leftExpand())); + auto stop_idx = addExpr( + IrBuilder::addExpr( + mulExpr(loop_indices.at(0), tv->axis(1)->extent()), createInt(3)), + IrBuilder::negExpr(second_resize->leftExpand())); + + return andExpr( + geExpr(start_idx, tv->fusion()->zeroVal()), + ltExpr(stop_idx, tv->getLogicalDomain().at(0)->extent())); + } + }; + + PredicateIndexValidator::validate(&fusion, false); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({i0}, options); + std::vector inputs{t0}; + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + // Repro of issue #3505. The indexing WAR for resize triggered an // assertion due to loop promotion. TEST_F(IndexingTest, Issue3505Repro1) { From 3be09ad287af4ab5bc366dc5573e9f8c88bbcd45 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 21 Dec 2024 22:36:35 -0800 Subject: [PATCH 2/3] bug fix --- csrc/id_model/indexing_traversal.cpp | 57 +++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/csrc/id_model/indexing_traversal.cpp b/csrc/id_model/indexing_traversal.cpp index c2a6127a861..d83bfd97c14 100644 --- a/csrc/id_model/indexing_traversal.cpp +++ b/csrc/id_model/indexing_traversal.cpp @@ -44,6 +44,27 @@ IndexingTraversal::IndexingTraversal( } resize_paths_.insert(resize); } + + // A unique expr path should be always allowed + for (const auto& expr_g : graph.disjointExprSets().disjointSets()) { + auto resize = dynamic_cast(expr_g->front()); + if (resize == nullptr) { + continue; + } + + auto input_groups = graph.inputGroups(expr_g); + auto output_groups = graph.outputGroups(expr_g); + if (input_groups.size() != 1 || output_groups.size() != 1) { + continue; + } + + if (graph.getUses(input_groups[0]).size() != 1 || + graph.getDefinitions(output_groups[0]).size() != 1) { + continue; + } + + resize_paths_.insert(resize); + } } std::optional IndexingTraversal:: @@ -65,18 +86,26 @@ std::optional IndexingTraversal:: /*build_graphs=*/false); // Gather all resize exprs for each of the inputs and outputs - std::unordered_map> tv_resize_map; - for (auto inp : ir_utils::filterByType(expr->inputs())) { - for (auto expr : inp->domain()->allExprs()) { + std::unordered_map> tv_resize_map; + for (auto inp : expr->inputs()) { + auto inp_tv = ir_utils::getTv(inp); + if (inp_tv == nullptr) { + continue; + } + for (auto expr : inp_tv->domain()->allExprs()) { if (auto resize = dynamic_cast(expr)) { - tv_resize_map[inp].push_back(resize); + tv_resize_map[inp_tv].push_back(resize); } } } - for (auto out : ir_utils::filterByType(expr->outputs())) { - for (auto expr : out->domain()->allExprs()) { + for (auto out : expr->outputs()) { + auto out_tv = ir_utils::getTv(out); + if (out_tv == nullptr) { + continue; + } + for (auto expr : out_tv->domain()->allExprs()) { if (auto resize = dynamic_cast(expr)) { - tv_resize_map[out].push_back(resize); + tv_resize_map[out_tv].push_back(resize); } } } @@ -149,9 +178,17 @@ std::optional IndexingTraversal:: }; bool single_id_resized_multiple_times = false; - for (auto out : ir_utils::filterByType(expr->outputs())) { - for (auto inp : ir_utils::filterByType(expr->inputs())) { - if (isSingleIdResizedMultipleTimes(inp, out)) { + for (auto out : expr->outputs()) { + auto out_tv = ir_utils::getTv(out); + if (out_tv == nullptr) { + continue; + } + for (auto inp : expr->inputs()) { + auto inp_tv = ir_utils::getTv(inp); + if (inp_tv == nullptr) { + continue; + } + if (isSingleIdResizedMultipleTimes(inp_tv, out_tv)) { single_id_resized_multiple_times = true; break; } From 6e6ca27ed2e8207d4f889d368d8aca1e7bc4b0a1 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 21 Dec 2024 22:37:56 -0800 Subject: [PATCH 3/3] cleanup --- tests/cpp/test_indexing.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp index bf51326578e..3691babd5b0 100644 --- a/tests/cpp/test_indexing.cpp +++ b/tests/cpp/test_indexing.cpp @@ -362,7 +362,6 @@ class PredicateIndexValidator : public kir::IrVisitor { "No inline predicate detected: ", expr->toString()); - std::cerr << expr->toString(); validateInlinePredicate(out_ti, inline_ite->predicate()->value()); // If there's an other IfThenElse in the scope stack, validate the