From e25a464132eb2faebf73c46b45856d4389c44853 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 24 Dec 2024 09:08:53 -0800 Subject: [PATCH 1/5] Translate the repetition pattern with expand and reshape --- CMakeLists.txt | 1 + csrc/preseg_passes/pre_segmenter.cpp | 4 + .../translate_repeat_to_expand.cpp | 227 +++++++++++++ .../translate_repeat_to_expand.h | 25 ++ csrc/preseg_passes/translate_repetition.h | 25 ++ tests/cpp/test_preseg_passes.cpp | 303 ++++++++++++++++++ 6 files changed, 585 insertions(+) create mode 100644 csrc/preseg_passes/translate_repeat_to_expand.cpp create mode 100644 csrc/preseg_passes/translate_repeat_to_expand.h create mode 100644 csrc/preseg_passes/translate_repetition.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 9cfe58b540e..465c0ab5bd2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -200,6 +200,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp + ${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp ${NVFUSER_SRCS_DIR}/rng.cpp ${NVFUSER_SRCS_DIR}/runtime/allocations.cpp ${NVFUSER_SRCS_DIR}/runtime/executor.cpp diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp index 2ad82f9dd20..b4943f1c91e 100644 --- a/csrc/preseg_passes/pre_segmenter.cpp +++ b/csrc/preseg_passes/pre_segmenter.cpp @@ -25,6 +25,7 @@ #include #include #include +#include namespace nvfuser::preseg_passes { @@ -45,6 +46,9 @@ namespace nvfuser::preseg_passes { // Replace TensorViews with zero extent. Outputs and inputs may still be empty OptimizationPass::runPass(fusion); + // This pass should be placed before ConsecutiveCastPass as more + // consecutive cast ops may be exposed by this pass + OptimizationPass::runPass(fusion); // removes consecutive cast operations OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); diff --git a/csrc/preseg_passes/translate_repeat_to_expand.cpp b/csrc/preseg_passes/translate_repeat_to_expand.cpp new file mode 100644 index 00000000000..26f716c40f4 --- /dev/null +++ b/csrc/preseg_passes/translate_repeat_to_expand.cpp @@ -0,0 +1,227 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include +#include +#include + +#include +#include + +namespace nvfuser::preseg_passes { + +namespace { + +struct RepetitionInfo { + TensorView* input_tv = nullptr; + IterDomain* repeated_id = nullptr; + std::vector cat_inp_tvs; + TensorView* output_tv = nullptr; +}; + +class RepeatToExpandTranslator { + public: + RepeatToExpandTranslator(Fusion* fusion) : fusion_(fusion) {} + + void run() { + inspect(); + translate(); + } + + // TODO: Consider translating the addition-based concat to an actual + // CatOp. By doing so, this pass would just need to find concat ops. 
+ void inspect() { + const auto exprs = fusion_->exprs(); + + auto get_cat_inp = [](TensorView* tv) -> TensorView* { + if (tv->uses().size() != 1) { + return nullptr; + } + + // Skip cast + if (auto uop = dynamic_cast(tv->uses().at(0)); + uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Cast) { + tv = uop->output(0)->as(); + + if (tv->uses().size() != 1) { + return nullptr; + } + } + + if (tv->uses().size() != 1) { + return nullptr; + } + + auto use_expr = tv->uses().at(0); + if (use_expr->isA() || + (use_expr->isA() && + use_expr->as()->getBinaryOpType() == BinaryOpType::Add)) { + return tv; + } else { + return nullptr; + } + }; + + for (auto pad : ir_utils::filterByType(exprs)) { + auto repeat_inp = pad->input(0)->as(); + auto pad_out = pad->output(0)->as(); + + // There must be just one logical ID expanded by this pad op + IterDomain* out_padded_root_id = nullptr; + bool multiple_resizes_found = false; + for (const auto i : c10::irange(pad_out->getLogicalDomain().size())) { + auto out_logical_id = pad_out->getLogicalDomain().at(i); + auto resize = dynamic_cast(out_logical_id->definition()); + if (resize == nullptr) { + continue; + } + if (out_padded_root_id != nullptr) { + // Multiple IDs are resized. Not supported. + multiple_resizes_found = true; + break; + } + out_padded_root_id = resize->in(); + } + + if (multiple_resizes_found) { + break; + } + + auto inp_padded_id = PairwiseLogicalDomainMap(repeat_inp, pad_out) + .mapConsumerToProducer() + .at(out_padded_root_id); + + auto cat_inp = get_cat_inp(pad_out); + if (cat_inp == nullptr) { + continue; + } + + // Note that this can be a CatOp or an addition + auto cat_op = cat_inp->uses().at(0); + + if (auto it = repeat_info_map.find(cat_op); it == repeat_info_map.end()) { + RepetitionInfo info; + info.input_tv = repeat_inp; + info.repeated_id = inp_padded_id; + info.cat_inp_tvs.push_back(cat_inp); + repeat_info_map.emplace(cat_op, info); + } else { + auto& info = repeat_info_map.at(cat_op); + if (info.input_tv != repeat_inp || info.repeated_id != inp_padded_id) { + // Invalid + repeat_info_map.erase(cat_op); + continue; + } + info.cat_inp_tvs.push_back(cat_inp); + } + } + + // Remove invalid entries + for (auto it = repeat_info_map.begin(); it != repeat_info_map.end();) { + Expr* concatenating_expr = it->first; + const RepetitionInfo& info = it->second; + // Make sure all inputs to concatenating_expr are detected + if (concatenating_expr->inputs().size() != info.cat_inp_tvs.size()) { + // Invalid + it = repeat_info_map.erase(it); + continue; + } + ++it; + } + } + + void translate() { + const auto exprs = fusion_->exprs(); + // Apply the translation in a reverse topological order. Since the + // output of the repetition is replaced, the use exprs of the + // output are replaced too, which may invalidate the inspected + // info invalid. + for (auto exprs_it = exprs.rbegin(); exprs_it != exprs.rend(); ++exprs_it) { + Expr* expr = *exprs_it; + auto repeat_info_map_it = repeat_info_map.find(expr); + if (repeat_info_map_it == repeat_info_map.end()) { + continue; + } + + const auto& info = repeat_info_map_it->second; + + if (info.cat_inp_tvs.size() < 2) { + continue; + } + + // Step 1. Insert a broadcast ID immediately outside of the + // repeated ID + // Step 2. Expand the broadcast ID by the repetition factor + // Step 3. Flatten the expanded ID and the repeated ID + // Step 4. Cast the flattened tensor if necessary. 
If the + // concatenation is done by addition and the inputs are fp16, + // there must be casting to fp32 before the addition. + + // Step 1 + auto inp_domain = + TensorDomain::noReductions(info.input_tv->getLogicalDomain()); + std::vector bcast_flags(inp_domain.size() + 1, false); + auto repeated_id_offset = std::distance( + inp_domain.begin(), + std::find(inp_domain.begin(), inp_domain.end(), info.repeated_id)); + bcast_flags.at(repeated_id_offset) = true; + auto broadcast_tv = broadcast(info.input_tv, bcast_flags); + NVF_ERROR(broadcast_tv->nDims() == inp_domain.size() + 1); + + // Step 2 + std::vector expanded_sizes( + bcast_flags.size(), IrBuilder::create(-1L)); + expanded_sizes.at(repeated_id_offset) = IrBuilder::create(2L); + auto expanded_tv = expand(broadcast_tv, expanded_sizes); + + // Step 3 + auto flattened_tv = + flatten(expanded_tv, repeated_id_offset, repeated_id_offset + 1); + + // Step 4 + TensorView* new_out_tv = nullptr; + auto origin_out_tv = expr->output(0)->as(); + + if (info.input_tv->dtype() != origin_out_tv->dtype()) { + // Input should be either Half or BFloat16 + NVF_ERROR( + info.input_tv->dtype() == DataType::Half || + info.input_tv->dtype() == DataType::BFloat16, + "Unexpected input type: ", + info.input_tv->toString()); + // Output should be either Float + NVF_ERROR( + origin_out_tv->dtype() == DataType::Float, + "Unexpected output type: ", + origin_out_tv->toString()); + new_out_tv = castOp(DataType::Float, flattened_tv); + } else { + new_out_tv = flattened_tv; + } + + ir_utils::replaceValInAllExprInputsAndFusionOutputs( + origin_out_tv, new_out_tv); + } + } + + private: + Fusion* fusion_ = nullptr; + // Map of concatenating expr to its infoi + std::unordered_map repeat_info_map; +}; + +} // namespace + +void TranslateRepeatToExpand::runPass(Fusion* fusion) { + FusionGuard fg(fusion); + RepeatToExpandTranslator translator(fusion); + translator.run(); +} + +} // namespace nvfuser::preseg_passes diff --git a/csrc/preseg_passes/translate_repeat_to_expand.h b/csrc/preseg_passes/translate_repeat_to_expand.h new file mode 100644 index 00000000000..de1879e9b65 --- /dev/null +++ b/csrc/preseg_passes/translate_repeat_to_expand.h @@ -0,0 +1,25 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser::preseg_passes { + +class TranslateRepeatToExpand : public OptimizationPass { + friend class OptimizationPass; + + protected: + static void runPass(Fusion* fusion); + static std::string name() { + return "TranslateRepeatToExpand"; + } +}; + +} // namespace nvfuser::preseg_passes diff --git a/csrc/preseg_passes/translate_repetition.h b/csrc/preseg_passes/translate_repetition.h new file mode 100644 index 00000000000..de1879e9b65 --- /dev/null +++ b/csrc/preseg_passes/translate_repetition.h @@ -0,0 +1,25 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser::preseg_passes { + +class TranslateRepeatToExpand : public OptimizationPass { + friend class OptimizationPass; + + protected: + static void runPass(Fusion* fusion); + static std::string name() { + return "TranslateRepeatToExpand"; + } +}; + +} // namespace nvfuser::preseg_passes diff --git a/tests/cpp/test_preseg_passes.cpp b/tests/cpp/test_preseg_passes.cpp index 33fb1b635ba..566373bfa89 100644 --- a/tests/cpp/test_preseg_passes.cpp +++ b/tests/cpp/test_preseg_passes.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -770,4 +771,306 @@ TEST_F(PresegTest, DisjointSetsOfExtentsConcreteSymbolic) { testValidate( executor_cache.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__); } + +// Trivial repeat pattern +TEST_F(PresegTest, TranslateRepeatToExpand1) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({32}); + fusion.addInput(tv0); + + auto tv1 = cat({tv0, tv0}, -1); + fusion.addOutput(tv1); + + { + // Make sure pad and cat no longer exist + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({32}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + // Should be scheduled as a pointwise kernel + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); +} + +// Consecutive repetitions with the same IDs +TEST_F(PresegTest, TranslateRepeatToExpand2) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({32}); + fusion.addInput(tv0); + + auto tv1 = cat({tv0, tv0}, -1); + auto tv2 = cat({tv1, tv1}, -1); + + fusion.addOutput(tv2); + + { + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({32}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + // Should be scheduled as a pointwise kernel + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); +} + +// Consecutive repetitions with different IDs +TEST_F(PresegTest, TranslateRepeatToExpand3) { + auto fusion_ptr = std::make_unique(); + 
Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({4, 8}); + fusion.addInput(tv0); + + auto tv1 = cat({tv0, tv0}, 1); + auto tv2 = cat({tv1, tv1}, 0); + + fusion.addOutput(tv2); + + { + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({4, 8}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + // Should be scheduled as a pointwise kernel + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); +} + +// Repeat the same ID of the same tensor multiple times. While the +// repetitions are the same, there's nothing to allow the output IDs +// to be mapped, so the translated fusion will be segmented. This is a +// downside compared to the original fusion, where all IDs are +// connected, so it's relatively straightforward to fuse them together +// without segmentation. +TEST_F(PresegTest, TranslateRepeatToExpand4) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({4, 8}); + fusion.addInput(tv0); + + // Consecutive repetitions with different IDs + auto tv1 = cat({tv0, tv0}, 1); + auto tv2 = cat({tv0, tv0}, 1); + + fusion.addOutput(tv1); + fusion.addOutput(tv2); + + { + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({4, 8}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + // Should be segmented to two pointwise kernels + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + const auto& heuristic_list = runtime->schedulerHeuristics()->heuristicsList(); + ASSERT_EQ(heuristic_list.size(), 2); + EXPECT_EQ(heuristic_list.at(0)->scheduler_type, SchedulerType::PointWise); + EXPECT_EQ(heuristic_list.at(1)->scheduler_type, SchedulerType::PointWise); +} + +// Trivial pattern with addition +TEST_F(PresegTest, TranslateRepeatToExpand5) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({32}); + fusion.addInput(tv0); + + auto tv1 = pad(tv0, {fusion.zeroVal(), tv0->axis(-1)->extent()}); + auto tv2 = pad(tv0, {tv0->axis(-1)->extent(), fusion.zeroVal()}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + fusion.printMath(); + { + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + 
new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({32}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + // Should be scheduled as a pointwise kernel + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); +} + +// Addition-based concatenation followed using BFloat16 +TEST_F(PresegTest, TranslateRepeatToExpand6) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({32}, DataType::BFloat16); + fusion.addInput(tv0); + + auto tv1 = pad(tv0, {fusion.zeroVal(), tv0->axis(-1)->extent()}); + auto tv2 = pad(tv0, {tv0->axis(-1)->extent(), fusion.zeroVal()}); + auto tv3 = add(castOp(DataType::Float, tv1), castOp(DataType::Float, tv2)); + auto tv4 = castOp(DataType::BFloat16, tv3); + fusion.addOutput(tv4); + + fusion.printMath(); + { + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto t0 = at::randn({32}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + // Should be scheduled as a pointwise kernel + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); +} + +// Addition-based concatenation followed by a reduction using BFloat16 +// inputs +TEST_F(PresegTest, TranslateRepeatToExpand7) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({32}, DataType::BFloat16); + fusion.addInput(tv0); + + auto tv1 = pad(tv0, {fusion.zeroVal(), tv0->axis(-1)->extent()}); + auto tv2 = pad(tv0, {tv0->axis(-1)->extent(), fusion.zeroVal()}); + auto tv3 = add(tv1, tv2); + auto tv4 = sum(tv3, {0}); + auto tv5 = castOp(DataType::BFloat16, tv4); + fusion.addOutput(tv5); + + fusion.printMath(); + { + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + fusion_copy.printMath(); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto t0 = at::randn({32}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), 
outputs, inputs, __LINE__, __FILE__); + + // Should be scheduled as a reduction kernel + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Reduction); +} + } // namespace nvfuser::preseg_passes From 330a62fc74b0cdbab52eb00059e8893dfc65de34 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 24 Dec 2024 18:56:11 -0800 Subject: [PATCH 2/5] cleanup --- .../translate_repeat_to_expand.cpp | 141 +++++++++++------- .../translate_repeat_to_expand.h | 20 ++- 2 files changed, 104 insertions(+), 57 deletions(-) diff --git a/csrc/preseg_passes/translate_repeat_to_expand.cpp b/csrc/preseg_passes/translate_repeat_to_expand.cpp index 26f716c40f4..fd0c7c19e90 100644 --- a/csrc/preseg_passes/translate_repeat_to_expand.cpp +++ b/csrc/preseg_passes/translate_repeat_to_expand.cpp @@ -19,12 +19,22 @@ namespace nvfuser::preseg_passes { namespace { struct RepetitionInfo { + // Input tensor that is repeated TensorView* input_tv = nullptr; + // Repeated logical ID of the input tensor IterDomain* repeated_id = nullptr; + // Tensors fed into the concat op std::vector cat_inp_tvs; - TensorView* output_tv = nullptr; }; +// Translation algorithm overview: +// +// Step 1: Inspection. Traverses the given fusion and looks for a +// sequence of ops that correspond to a repeatition. See +// RepeatToExpandTranslator::inspect() for more details. +// +// Step 2: Apply the translation in a reverse topologial order. See +// RepeatToExpandTranslator::translate() for more details. class RepeatToExpandTranslator { public: RepeatToExpandTranslator(Fusion* fusion) : fusion_(fusion) {} @@ -34,45 +44,22 @@ class RepeatToExpandTranslator { translate(); } - // TODO: Consider translating the addition-based concat to an actual + private: + // Traverse through the fusion and gather all patterns of a pad + // followed by a concat. If a single concat op has multiple pad + // inputs that resize the same iter domain of the same input tensor, + // that must correspond to a repetition. + // + // NOTE: Consider translating the addition-based concat to an actual // CatOp. By doing so, this pass would just need to find concat ops. 
void inspect() { const auto exprs = fusion_->exprs(); - auto get_cat_inp = [](TensorView* tv) -> TensorView* { - if (tv->uses().size() != 1) { - return nullptr; - } - - // Skip cast - if (auto uop = dynamic_cast(tv->uses().at(0)); - uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Cast) { - tv = uop->output(0)->as(); - - if (tv->uses().size() != 1) { - return nullptr; - } - } - - if (tv->uses().size() != 1) { - return nullptr; - } - - auto use_expr = tv->uses().at(0); - if (use_expr->isA() || - (use_expr->isA() && - use_expr->as()->getBinaryOpType() == BinaryOpType::Add)) { - return tv; - } else { - return nullptr; - } - }; - for (auto pad : ir_utils::filterByType(exprs)) { - auto repeat_inp = pad->input(0)->as(); + auto pad_inp = pad->input(0)->as(); auto pad_out = pad->output(0)->as(); - // There must be just one logical ID expanded by this pad op + // Not supported if there are multiple expanded logical IDs IterDomain* out_padded_root_id = nullptr; bool multiple_resizes_found = false; for (const auto i : c10::irange(pad_out->getLogicalDomain().size())) { @@ -89,15 +76,18 @@ class RepeatToExpandTranslator { out_padded_root_id = resize->in(); } - if (multiple_resizes_found) { + if (multiple_resizes_found || out_padded_root_id == nullptr) { + // Unsupported pattern break; } - auto inp_padded_id = PairwiseLogicalDomainMap(repeat_inp, pad_out) + auto inp_padded_id = PairwiseLogicalDomainMap(pad_inp, pad_out) .mapConsumerToProducer() .at(out_padded_root_id); - auto cat_inp = get_cat_inp(pad_out); + // The padded tensor must be used by a concat or an addition + auto cat_inp = getMaybeValidConcatInp(pad_out); + if (cat_inp == nullptr) { continue; } @@ -105,17 +95,21 @@ class RepeatToExpandTranslator { // Note that this can be a CatOp or an addition auto cat_op = cat_inp->uses().at(0); - if (auto it = repeat_info_map.find(cat_op); it == repeat_info_map.end()) { + // If other inputs to the same concat op are already found, make + // sure this path from the pad op is compatible with the known + // ops. + if (auto it = repeat_info_map_.find(cat_op); + it == repeat_info_map_.end()) { RepetitionInfo info; - info.input_tv = repeat_inp; + info.input_tv = pad_inp; info.repeated_id = inp_padded_id; info.cat_inp_tvs.push_back(cat_inp); - repeat_info_map.emplace(cat_op, info); + repeat_info_map_.emplace(cat_op, info); } else { - auto& info = repeat_info_map.at(cat_op); - if (info.input_tv != repeat_inp || info.repeated_id != inp_padded_id) { + auto& info = repeat_info_map_.at(cat_op); + if (info.input_tv != pad_inp || info.repeated_id != inp_padded_id) { // Invalid - repeat_info_map.erase(cat_op); + repeat_info_map_.erase(cat_op); continue; } info.cat_inp_tvs.push_back(cat_inp); @@ -123,19 +117,62 @@ class RepeatToExpandTranslator { } // Remove invalid entries - for (auto it = repeat_info_map.begin(); it != repeat_info_map.end();) { + for (auto it = repeat_info_map_.begin(); it != repeat_info_map_.end();) { Expr* concatenating_expr = it->first; const RepetitionInfo& info = it->second; // Make sure all inputs to concatenating_expr are detected if (concatenating_expr->inputs().size() != info.cat_inp_tvs.size()) { // Invalid - it = repeat_info_map.erase(it); + it = repeat_info_map_.erase(it); continue; } ++it; } } + // For an output of a pad, finds the corresponding tensor that is an + // input to a concat. For this translation to work, there must not + // be other uses than the immediate concat, except type casting, + // which may be used when repating fp16 tensors with an + // addition. 
nullptr is returned if no valid tensor is found. + TensorView* getMaybeValidConcatInp(TensorView* pad_out_tv) { + if (pad_out_tv->uses().size() != 1) { + return nullptr; + } + + // Skip cast + if (auto uop = dynamic_cast(pad_out_tv->uses().at(0)); + uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Cast) { + pad_out_tv = uop->output(0)->as(); + + if (pad_out_tv->uses().size() != 1) { + return nullptr; + } + } + + if (pad_out_tv->uses().size() != 1) { + return nullptr; + } + + auto use_expr = pad_out_tv->uses().at(0); + if (use_expr->isA() || + (use_expr->isA() && + use_expr->as()->getBinaryOpType() == BinaryOpType::Add)) { + return pad_out_tv; + } else { + return nullptr; + } + } + + // For each detected repetition: + // + // Step 1. Insert a broadcast ID immediately outside of the + // repeated ID + // Step 2. Expand the broadcast ID by the repetition factor + // Step 3. Flatten the expanded ID and the repeated ID + // Step 4. Cast the flattened tensor if necessary. If the + // concatenation is done by addition and the inputs are fp16, + // there must be casting to fp32 before the addition. void translate() { const auto exprs = fusion_->exprs(); // Apply the translation in a reverse topological order. Since the @@ -144,8 +181,8 @@ class RepeatToExpandTranslator { // info invalid. for (auto exprs_it = exprs.rbegin(); exprs_it != exprs.rend(); ++exprs_it) { Expr* expr = *exprs_it; - auto repeat_info_map_it = repeat_info_map.find(expr); - if (repeat_info_map_it == repeat_info_map.end()) { + auto repeat_info_map_it = repeat_info_map_.find(expr); + if (repeat_info_map_it == repeat_info_map_.end()) { continue; } @@ -155,14 +192,6 @@ class RepeatToExpandTranslator { continue; } - // Step 1. Insert a broadcast ID immediately outside of the - // repeated ID - // Step 2. Expand the broadcast ID by the repetition factor - // Step 3. Flatten the expanded ID and the repeated ID - // Step 4. Cast the flattened tensor if necessary. If the - // concatenation is done by addition and the inputs are fp16, - // there must be casting to fp32 before the addition. - // Step 1 auto inp_domain = TensorDomain::noReductions(info.input_tv->getLogicalDomain()); @@ -212,8 +241,8 @@ class RepeatToExpandTranslator { private: Fusion* fusion_ = nullptr; - // Map of concatenating expr to its infoi - std::unordered_map repeat_info_map; + // Map of concat exprs to their info about repetition + std::unordered_map repeat_info_map_; }; } // namespace diff --git a/csrc/preseg_passes/translate_repeat_to_expand.h b/csrc/preseg_passes/translate_repeat_to_expand.h index de1879e9b65..9b876bb7813 100644 --- a/csrc/preseg_passes/translate_repeat_to_expand.h +++ b/csrc/preseg_passes/translate_repeat_to_expand.h @@ -12,7 +12,25 @@ namespace nvfuser::preseg_passes { -class TranslateRepeatToExpand : public OptimizationPass { +// Translate concat-based repetitions to expand and reshape ops. +// +// For example, given the following fusion: +// +// t0 = [i0]; +// t1 = cat({t0, t0}, -1); +// +// It will be translated to: +// +// t0 = [i0] +// t2 = broadcast(t0, {true, false}); +// t3 = expand(t2, {2, i0}); +// t4 = reshape(t3, {2 * i0}); +// +// And all uses of t1 will be replaced by t4. This pattern commonly +// appears in RoPE, e.g., +// https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L136. 
+class TranslateRepeatToExpand + : public OptimizationPass { friend class OptimizationPass; protected: From 4c1abda909fbf2d85786b13f3d1566a3490db567 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 24 Dec 2024 20:09:49 -0800 Subject: [PATCH 3/5] Drop support of addition-based concat as it isn't immediately necessary --- .../translate_repeat_to_expand.cpp | 78 ++--------- .../translate_repeat_to_expand.h | 11 ++ tests/cpp/test_preseg_passes.cpp | 131 ------------------ 3 files changed, 19 insertions(+), 201 deletions(-) diff --git a/csrc/preseg_passes/translate_repeat_to_expand.cpp b/csrc/preseg_passes/translate_repeat_to_expand.cpp index fd0c7c19e90..893abe7cebc 100644 --- a/csrc/preseg_passes/translate_repeat_to_expand.cpp +++ b/csrc/preseg_passes/translate_repeat_to_expand.cpp @@ -49,9 +49,6 @@ class RepeatToExpandTranslator { // followed by a concat. If a single concat op has multiple pad // inputs that resize the same iter domain of the same input tensor, // that must correspond to a repetition. - // - // NOTE: Consider translating the addition-based concat to an actual - // CatOp. By doing so, this pass would just need to find concat ops. void inspect() { const auto exprs = fusion_->exprs(); @@ -85,15 +82,12 @@ class RepeatToExpandTranslator { .mapConsumerToProducer() .at(out_padded_root_id); - // The padded tensor must be used by a concat or an addition - auto cat_inp = getMaybeValidConcatInp(pad_out); - - if (cat_inp == nullptr) { + // The padded tensor must be immediately used by a concat only + if (pad_out->uses().size() != 1 || !pad_out->uses().at(0)->isA()) { continue; } - // Note that this can be a CatOp or an addition - auto cat_op = cat_inp->uses().at(0); + auto cat_op = pad_out->uses().at(0); // If other inputs to the same concat op are already found, make // sure this path from the pad op is compatible with the known @@ -103,7 +97,7 @@ class RepeatToExpandTranslator { RepetitionInfo info; info.input_tv = pad_inp; info.repeated_id = inp_padded_id; - info.cat_inp_tvs.push_back(cat_inp); + info.cat_inp_tvs.push_back(pad_out); repeat_info_map_.emplace(cat_op, info); } else { auto& info = repeat_info_map_.at(cat_op); @@ -112,7 +106,7 @@ class RepeatToExpandTranslator { repeat_info_map_.erase(cat_op); continue; } - info.cat_inp_tvs.push_back(cat_inp); + info.cat_inp_tvs.push_back(pad_out); } } @@ -130,49 +124,12 @@ class RepeatToExpandTranslator { } } - // For an output of a pad, finds the corresponding tensor that is an - // input to a concat. For this translation to work, there must not - // be other uses than the immediate concat, except type casting, - // which may be used when repating fp16 tensors with an - // addition. nullptr is returned if no valid tensor is found. - TensorView* getMaybeValidConcatInp(TensorView* pad_out_tv) { - if (pad_out_tv->uses().size() != 1) { - return nullptr; - } - - // Skip cast - if (auto uop = dynamic_cast(pad_out_tv->uses().at(0)); - uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Cast) { - pad_out_tv = uop->output(0)->as(); - - if (pad_out_tv->uses().size() != 1) { - return nullptr; - } - } - - if (pad_out_tv->uses().size() != 1) { - return nullptr; - } - - auto use_expr = pad_out_tv->uses().at(0); - if (use_expr->isA() || - (use_expr->isA() && - use_expr->as()->getBinaryOpType() == BinaryOpType::Add)) { - return pad_out_tv; - } else { - return nullptr; - } - } - // For each detected repetition: // // Step 1. Insert a broadcast ID immediately outside of the // repeated ID // Step 2. 
Expand the broadcast ID by the repetition factor // Step 3. Flatten the expanded ID and the repeated ID - // Step 4. Cast the flattened tensor if necessary. If the - // concatenation is done by addition and the inputs are fp16, - // there must be casting to fp32 before the addition. void translate() { const auto exprs = fusion_->exprs(); // Apply the translation in a reverse topological order. Since the @@ -192,6 +149,8 @@ class RepeatToExpandTranslator { continue; } + auto original_out_tv = expr->output(0)->as(); + // Step 1 auto inp_domain = TensorDomain::noReductions(info.input_tv->getLogicalDomain()); @@ -213,29 +172,8 @@ class RepeatToExpandTranslator { auto flattened_tv = flatten(expanded_tv, repeated_id_offset, repeated_id_offset + 1); - // Step 4 - TensorView* new_out_tv = nullptr; - auto origin_out_tv = expr->output(0)->as(); - - if (info.input_tv->dtype() != origin_out_tv->dtype()) { - // Input should be either Half or BFloat16 - NVF_ERROR( - info.input_tv->dtype() == DataType::Half || - info.input_tv->dtype() == DataType::BFloat16, - "Unexpected input type: ", - info.input_tv->toString()); - // Output should be either Float - NVF_ERROR( - origin_out_tv->dtype() == DataType::Float, - "Unexpected output type: ", - origin_out_tv->toString()); - new_out_tv = castOp(DataType::Float, flattened_tv); - } else { - new_out_tv = flattened_tv; - } - ir_utils::replaceValInAllExprInputsAndFusionOutputs( - origin_out_tv, new_out_tv); + original_out_tv, flattened_tv); } } diff --git a/csrc/preseg_passes/translate_repeat_to_expand.h b/csrc/preseg_passes/translate_repeat_to_expand.h index 9b876bb7813..bc5d3ed5b35 100644 --- a/csrc/preseg_passes/translate_repeat_to_expand.h +++ b/csrc/preseg_passes/translate_repeat_to_expand.h @@ -29,6 +29,17 @@ namespace nvfuser::preseg_passes { // And all uses of t1 will be replaced by t4. This pattern commonly // appears in RoPE, e.g., // https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L136. +// While the resize scheduler should be able to handle these patterns for +// pointwise-only segments, it is currently limited to only pointwise +// fusions only. This translation should promote larger fusions +// as it is not specific to any surrounding ops. +// +// Note that there's a potential downside compared to handling cat ops +// directly. Since insertion of broadcast IDs is not represented as +// Fusion IR expressions, a fusion may have more disconnected ID +// graphs after the translation, which may cause a segmentation that +// could be avoided with the original fusion. See +// PresegTest.TranslateRepeatToExpand4 for a concrete example. 
class TranslateRepeatToExpand : public OptimizationPass { friend class OptimizationPass; diff --git a/tests/cpp/test_preseg_passes.cpp b/tests/cpp/test_preseg_passes.cpp index 566373bfa89..74f2d02cb90 100644 --- a/tests/cpp/test_preseg_passes.cpp +++ b/tests/cpp/test_preseg_passes.cpp @@ -942,135 +942,4 @@ TEST_F(PresegTest, TranslateRepeatToExpand4) { EXPECT_EQ(heuristic_list.at(1)->scheduler_type, SchedulerType::PointWise); } -// Trivial pattern with addition -TEST_F(PresegTest, TranslateRepeatToExpand5) { - auto fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr; - FusionGuard fg(&fusion); - - auto tv0 = makeContigConcreteTensor({32}); - fusion.addInput(tv0); - - auto tv1 = pad(tv0, {fusion.zeroVal(), tv0->axis(-1)->extent()}); - auto tv2 = pad(tv0, {tv0->axis(-1)->extent(), fusion.zeroVal()}); - auto tv3 = add(tv1, tv2); - fusion.addOutput(tv3); - - fusion.printMath(); - { - Fusion fusion_copy = fusion; - OptimizationPass::runPass(&fusion_copy); - auto new_exprs = fusion_copy.exprs(); - EXPECT_EQ( - std::find_if( - new_exprs.begin(), - new_exprs.end(), - [](Expr* new_expr) { return new_expr->isOneOf(); }), - new_exprs.end()); - } - - auto options = at::TensorOptions().device(at::kCUDA, 0); - auto t0 = at::randn({32}, options); - std::vector inputs = {t0}; - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - auto outputs = executor_cache.runFusionWithInputs(inputs); - testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); - - // Should be scheduled as a pointwise kernel - FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); - EXPECT_FALSE(runtime->isSegmented()); - const auto& heuristic_param = - runtime->schedulerHeuristics()->heuristicsList().front(); - EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); -} - -// Addition-based concatenation followed using BFloat16 -TEST_F(PresegTest, TranslateRepeatToExpand6) { - auto fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr; - FusionGuard fg(&fusion); - - auto tv0 = makeContigConcreteTensor({32}, DataType::BFloat16); - fusion.addInput(tv0); - - auto tv1 = pad(tv0, {fusion.zeroVal(), tv0->axis(-1)->extent()}); - auto tv2 = pad(tv0, {tv0->axis(-1)->extent(), fusion.zeroVal()}); - auto tv3 = add(castOp(DataType::Float, tv1), castOp(DataType::Float, tv2)); - auto tv4 = castOp(DataType::BFloat16, tv3); - fusion.addOutput(tv4); - - fusion.printMath(); - { - Fusion fusion_copy = fusion; - OptimizationPass::runPass(&fusion_copy); - auto new_exprs = fusion_copy.exprs(); - EXPECT_EQ( - std::find_if( - new_exprs.begin(), - new_exprs.end(), - [](Expr* new_expr) { return new_expr->isOneOf(); }), - new_exprs.end()); - } - - auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); - auto t0 = at::randn({32}, options); - std::vector inputs = {t0}; - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - auto outputs = executor_cache.runFusionWithInputs(inputs); - testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); - - // Should be scheduled as a pointwise kernel - FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); - EXPECT_FALSE(runtime->isSegmented()); - const auto& heuristic_param = - runtime->schedulerHeuristics()->heuristicsList().front(); - EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); -} - -// Addition-based concatenation followed by a reduction using BFloat16 -// inputs -TEST_F(PresegTest, TranslateRepeatToExpand7) { - auto fusion_ptr = 
std::make_unique(); - Fusion& fusion = *fusion_ptr; - FusionGuard fg(&fusion); - - auto tv0 = makeContigConcreteTensor({32}, DataType::BFloat16); - fusion.addInput(tv0); - - auto tv1 = pad(tv0, {fusion.zeroVal(), tv0->axis(-1)->extent()}); - auto tv2 = pad(tv0, {tv0->axis(-1)->extent(), fusion.zeroVal()}); - auto tv3 = add(tv1, tv2); - auto tv4 = sum(tv3, {0}); - auto tv5 = castOp(DataType::BFloat16, tv4); - fusion.addOutput(tv5); - - fusion.printMath(); - { - Fusion fusion_copy = fusion; - OptimizationPass::runPass(&fusion_copy); - fusion_copy.printMath(); - auto new_exprs = fusion_copy.exprs(); - EXPECT_EQ( - std::find_if( - new_exprs.begin(), - new_exprs.end(), - [](Expr* new_expr) { return new_expr->isOneOf(); }), - new_exprs.end()); - } - - auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); - auto t0 = at::randn({32}, options); - std::vector inputs = {t0}; - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - auto outputs = executor_cache.runFusionWithInputs(inputs); - testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); - - // Should be scheduled as a reduction kernel - FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); - EXPECT_FALSE(runtime->isSegmented()); - const auto& heuristic_param = - runtime->schedulerHeuristics()->heuristicsList().front(); - EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::Reduction); -} - } // namespace nvfuser::preseg_passes From baebd7b1ffde13ed33acc09e6fc9b4a8610d271c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 24 Dec 2024 20:19:12 -0800 Subject: [PATCH 4/5] remove --- csrc/preseg_passes/translate_repetition.h | 25 ----------------------- 1 file changed, 25 deletions(-) delete mode 100644 csrc/preseg_passes/translate_repetition.h diff --git a/csrc/preseg_passes/translate_repetition.h b/csrc/preseg_passes/translate_repetition.h deleted file mode 100644 index de1879e9b65..00000000000 --- a/csrc/preseg_passes/translate_repetition.h +++ /dev/null @@ -1,25 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include -#include - -namespace nvfuser::preseg_passes { - -class TranslateRepeatToExpand : public OptimizationPass { - friend class OptimizationPass; - - protected: - static void runPass(Fusion* fusion); - static std::string name() { - return "TranslateRepeatToExpand"; - } -}; - -} // namespace nvfuser::preseg_passes From 498c8daba2768dc23823f7fd5460fa46ec79c7bd Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 30 Dec 2024 15:55:48 -0800 Subject: [PATCH 5/5] fix --- .../translate_repeat_to_expand.cpp | 3 +- tests/cpp/test_preseg_passes.cpp | 42 ++++++++++++++++++- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/csrc/preseg_passes/translate_repeat_to_expand.cpp b/csrc/preseg_passes/translate_repeat_to_expand.cpp index 893abe7cebc..71b0e0b061a 100644 --- a/csrc/preseg_passes/translate_repeat_to_expand.cpp +++ b/csrc/preseg_passes/translate_repeat_to_expand.cpp @@ -165,7 +165,8 @@ class RepeatToExpandTranslator { // Step 2 std::vector expanded_sizes( bcast_flags.size(), IrBuilder::create(-1L)); - expanded_sizes.at(repeated_id_offset) = IrBuilder::create(2L); + expanded_sizes.at(repeated_id_offset) = + IrBuilder::create((int64_t)info.cat_inp_tvs.size()); auto expanded_tv = expand(broadcast_tv, expanded_sizes); // Step 3 diff --git a/tests/cpp/test_preseg_passes.cpp b/tests/cpp/test_preseg_passes.cpp index 74f2d02cb90..4661d6e5599 100644 --- a/tests/cpp/test_preseg_passes.cpp +++ b/tests/cpp/test_preseg_passes.cpp @@ -908,7 +908,7 @@ TEST_F(PresegTest, TranslateRepeatToExpand4) { auto tv0 = makeContigConcreteTensor({4, 8}); fusion.addInput(tv0); - // Consecutive repetitions with different IDs + // Consecutive repetitions with the same IDs auto tv1 = cat({tv0, tv0}, 1); auto tv2 = cat({tv0, tv0}, 1); @@ -942,4 +942,44 @@ TEST_F(PresegTest, TranslateRepeatToExpand4) { EXPECT_EQ(heuristic_list.at(1)->scheduler_type, SchedulerType::PointWise); } +// Repeating more than two times +TEST_F(PresegTest, TranslateRepeatToExpand5) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({32}); + fusion.addInput(tv0); + + auto tv1 = cat({tv0, tv0, tv0, tv0}, -1); + fusion.addOutput(tv1); + + { + // Make sure pad and cat no longer exist + Fusion fusion_copy = fusion; + OptimizationPass::runPass(&fusion_copy); + auto new_exprs = fusion_copy.exprs(); + EXPECT_EQ( + std::find_if( + new_exprs.begin(), + new_exprs.end(), + [](Expr* new_expr) { return new_expr->isOneOf(); }), + new_exprs.end()); + } + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto t0 = at::randn({32}, options); + std::vector inputs = {t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__); + + // Should be scheduled as a pointwise kernel + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_FALSE(runtime->isSegmented()); + const auto& heuristic_param = + runtime->schedulerHeuristics()->heuristicsList().front(); + EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); +} + } // namespace nvfuser::preseg_passes
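
Illustration (not part of the patches above): for the trivial repeat exercised by
PresegTest.TranslateRepeatToExpand1, the fusion that this pass produces can be
written out by hand with the same C++ API the tests use. This is only a sketch
of the Step 1-3 sequence in translate(); template arguments such as Fusion and
Val are filled in here as assumptions, and the exact overloads may differ.

  // Original: tv_out = cat({tv0, tv0}, -1) on a 1-D input of extent 32.
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  auto tv0 = makeContigConcreteTensor({32});
  fusion.addInput(tv0);

  // Step 1: insert a broadcast ID immediately outside the repeated ID -> [b, 32]
  auto tv1 = broadcast(tv0, {true, false});
  // Step 2: expand the broadcast ID by the repetition factor (-1 keeps the
  // other extent unchanged, mirroring expanded_sizes in translate()) -> [2, 32]
  auto tv2 = expand(
      tv1, {IrBuilder::create<Val>(2L), IrBuilder::create<Val>(-1L)});
  // Step 3: flatten the expanded ID and the repeated ID -> [64]
  auto tv3 = flatten(tv2, 0, 1);
  fusion.addOutput(tv3);

The resulting fusion contains no PadOp or CatOp, which is what the tests above
assert by checking that std::find_if over the translated fusion's exprs() finds
no such op.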