diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9cfe58b540e..465c0ab5bd2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -200,6 +200,7 @@ list(APPEND NVFUSER_SRCS
   ${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp
   ${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp
   ${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp
+  ${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp
   ${NVFUSER_SRCS_DIR}/rng.cpp
   ${NVFUSER_SRCS_DIR}/runtime/allocations.cpp
   ${NVFUSER_SRCS_DIR}/runtime/executor.cpp
diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp
index 2ad82f9dd20..b4943f1c91e 100644
--- a/csrc/preseg_passes/pre_segmenter.cpp
+++ b/csrc/preseg_passes/pre_segmenter.cpp
@@ -25,6 +25,7 @@
 #include
 #include
 #include
+#include <preseg_passes/translate_repeat_to_expand.h>
 
 namespace nvfuser::preseg_passes {
 
@@ -45,6 +46,9 @@ namespace nvfuser::preseg_passes {
   // Replace TensorViews with zero extent. Outputs and inputs may still be empty
   OptimizationPass<RemoveEmptyPass>::runPass(fusion);
+  // This pass should be placed before ConsecutiveCastPass as more
+  // consecutive cast ops may be exposed by this pass
+  OptimizationPass<TranslateRepeatToExpand>::runPass(fusion);
   // removes consecutive cast operations
   OptimizationPass<ConsecutiveCastPass>::runPass(fusion);
   OptimizationPass::runPass(fusion);
diff --git a/csrc/preseg_passes/translate_repeat_to_expand.cpp b/csrc/preseg_passes/translate_repeat_to_expand.cpp
new file mode 100644
index 00000000000..71b0e0b061a
--- /dev/null
+++ b/csrc/preseg_passes/translate_repeat_to_expand.cpp
@@ -0,0 +1,195 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#include <preseg_passes/translate_repeat_to_expand.h>
+
+#include
+#include
+#include
+
+#include
+#include
+
+namespace nvfuser::preseg_passes {
+
+namespace {
+
+struct RepetitionInfo {
+  // Input tensor that is repeated
+  TensorView* input_tv = nullptr;
+  // Repeated logical ID of the input tensor
+  IterDomain* repeated_id = nullptr;
+  // Tensors fed into the concat op
+  std::vector<TensorView*> cat_inp_tvs;
+};
+
+// Translation algorithm overview:
+//
+// Step 1: Inspection. Traverses the given fusion and looks for a
+// sequence of ops that corresponds to a repetition. See
+// RepeatToExpandTranslator::inspect() for more details.
+//
+// Step 2: Apply the translation in a reverse topological order. See
+// RepeatToExpandTranslator::translate() for more details.
+class RepeatToExpandTranslator {
+ public:
+  RepeatToExpandTranslator(Fusion* fusion) : fusion_(fusion) {}
+
+  void run() {
+    inspect();
+    translate();
+  }
+
+ private:
+  // Traverse the fusion and gather all patterns of a pad followed by
+  // a concat. If a single concat op has multiple pad inputs that
+  // resize the same iter domain of the same input tensor, that must
+  // correspond to a repetition.
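+  //
+  // Note that a cat op is represented in Fusion IR as pad ops
+  // followed by a concat op, so the example in
+  // translate_repeat_to_expand.h, cat({t0, t0}, -1) with t0 = [i0],
+  // appears here roughly as:
+  //
+  //   t1 = pad(t0, {0, i0});  // resized to [2*i0], extended at the right
+  //   t2 = pad(t0, {i0, 0});  // resized to [2*i0], extended at the left
+  //   t3 = CatOp(t1, t2);
+  //
+  // Both pads resize the same logical ID of the same input tensor,
+  // which is what identifies the repetition.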
+  void inspect() {
+    const auto exprs = fusion_->exprs();
+
+    for (auto pad : ir_utils::filterByType<PadOp>(exprs)) {
+      auto pad_inp = pad->input(0)->as<TensorView>();
+      auto pad_out = pad->output(0)->as<TensorView>();
+
+      // Not supported if there are multiple resized logical IDs
+      IterDomain* out_padded_root_id = nullptr;
+      bool multiple_resizes_found = false;
+      for (const auto i : c10::irange(pad_out->getLogicalDomain().size())) {
+        auto out_logical_id = pad_out->getLogicalDomain().at(i);
+        auto resize = dynamic_cast<Resize*>(out_logical_id->definition());
+        if (resize == nullptr) {
+          continue;
+        }
+        if (out_padded_root_id != nullptr) {
+          // Multiple IDs are resized. Not supported.
+          multiple_resizes_found = true;
+          break;
+        }
+        out_padded_root_id = resize->in();
+      }
+
+      if (multiple_resizes_found || out_padded_root_id == nullptr) {
+        // Unsupported pattern. Skip this pad and keep scanning the rest.
+        continue;
+      }
+
+      auto inp_padded_id = PairwiseLogicalDomainMap(pad_inp, pad_out)
+                               .mapConsumerToProducer()
+                               .at(out_padded_root_id);
+
+      // The padded tensor must be immediately used by a concat only
+      if (pad_out->uses().size() != 1 ||
+          !pad_out->uses().at(0)->isA<CatOp>()) {
+        continue;
+      }
+
+      auto cat_op = pad_out->uses().at(0);
+
+      // If other inputs to the same concat op are already found, make
+      // sure this path from the pad op is compatible with the known
+      // ops.
+      if (auto it = repeat_info_map_.find(cat_op);
+          it == repeat_info_map_.end()) {
+        RepetitionInfo info;
+        info.input_tv = pad_inp;
+        info.repeated_id = inp_padded_id;
+        info.cat_inp_tvs.push_back(pad_out);
+        repeat_info_map_.emplace(cat_op, info);
+      } else {
+        auto& info = repeat_info_map_.at(cat_op);
+        if (info.input_tv != pad_inp || info.repeated_id != inp_padded_id) {
+          // Invalid
+          repeat_info_map_.erase(cat_op);
+          continue;
+        }
+        info.cat_inp_tvs.push_back(pad_out);
+      }
+    }
+
+    // Remove invalid entries
+    for (auto it = repeat_info_map_.begin(); it != repeat_info_map_.end();) {
+      Expr* concatenating_expr = it->first;
+      const RepetitionInfo& info = it->second;
+      // Make sure all inputs to concatenating_expr are detected
+      if (concatenating_expr->inputs().size() != info.cat_inp_tvs.size()) {
+        // Invalid
+        it = repeat_info_map_.erase(it);
+        continue;
+      }
+      ++it;
+    }
+  }
+
+  // For each detected repetition:
+  //
+  // Step 1. Insert a broadcast ID immediately outside of the
+  // repeated ID
+  // Step 2. Expand the broadcast ID by the repetition factor
+  // Step 3. Flatten the expanded ID and the repeated ID
+  void translate() {
+    const auto exprs = fusion_->exprs();
+    // Apply the translation in a reverse topological order. Since the
+    // output of the repetition is replaced, the use exprs of the
+    // output are replaced too, which may make the inspected info
+    // invalid.
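+    //
+    // As a concrete illustration of the three steps, following the
+    // example in translate_repeat_to_expand.h, cat({t0, t0}, -1) with
+    // t0 = [i0] (so repeated_id_offset == 0) becomes:
+    //
+    //   Step 1: broadcast(t0, {true, false}) -> [b1, i0]
+    //   Step 2: expand(t2, {2, -1})          -> [2, i0]
+    //   Step 3: flatten(t3, 0, 1)            -> [2*i0]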
+    for (auto exprs_it = exprs.rbegin(); exprs_it != exprs.rend();
+         ++exprs_it) {
+      Expr* expr = *exprs_it;
+      auto repeat_info_map_it = repeat_info_map_.find(expr);
+      if (repeat_info_map_it == repeat_info_map_.end()) {
+        continue;
+      }
+
+      const auto& info = repeat_info_map_it->second;
+
+      if (info.cat_inp_tvs.size() < 2) {
+        continue;
+      }
+
+      auto original_out_tv = expr->output(0)->as<TensorView>();
+
+      // Step 1
+      auto inp_domain =
+          TensorDomain::noReductions(info.input_tv->getLogicalDomain());
+      std::vector<bool> bcast_flags(inp_domain.size() + 1, false);
+      auto repeated_id_offset = std::distance(
+          inp_domain.begin(),
+          std::find(inp_domain.begin(), inp_domain.end(), info.repeated_id));
+      bcast_flags.at(repeated_id_offset) = true;
+      auto broadcast_tv = broadcast(info.input_tv, bcast_flags);
+      NVF_ERROR(broadcast_tv->nDims() == inp_domain.size() + 1);
+
+      // Step 2
+      std::vector<Val*> expanded_sizes(
+          bcast_flags.size(), IrBuilder::create<Val>(-1L));
+      expanded_sizes.at(repeated_id_offset) =
+          IrBuilder::create<Val>((int64_t)info.cat_inp_tvs.size());
+      auto expanded_tv = expand(broadcast_tv, expanded_sizes);
+
+      // Step 3
+      auto flattened_tv =
+          flatten(expanded_tv, repeated_id_offset, repeated_id_offset + 1);
+
+      ir_utils::replaceValInAllExprInputsAndFusionOutputs(
+          original_out_tv, flattened_tv);
+    }
+  }
+
+ private:
+  Fusion* fusion_ = nullptr;
+  // Map of concat exprs to their info about repetition
+  std::unordered_map<Expr*, RepetitionInfo> repeat_info_map_;
+};
+
+} // namespace
+
+void TranslateRepeatToExpand::runPass(Fusion* fusion) {
+  FusionGuard fg(fusion);
+  RepeatToExpandTranslator translator(fusion);
+  translator.run();
+}
+
+} // namespace nvfuser::preseg_passes
diff --git a/csrc/preseg_passes/translate_repeat_to_expand.h b/csrc/preseg_passes/translate_repeat_to_expand.h
new file mode 100644
index 00000000000..bc5d3ed5b35
--- /dev/null
+++ b/csrc/preseg_passes/translate_repeat_to_expand.h
@@ -0,0 +1,54 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include <fusion.h>
+#include <preseg_passes/optimization_pass.h>
+
+namespace nvfuser::preseg_passes {
+
+// Translate concat-based repetitions to expand and reshape ops.
+//
+// For example, given the following fusion:
+//
+//   t0 = [i0];
+//   t1 = cat({t0, t0}, -1);
+//
+// It will be translated to:
+//
+//   t0 = [i0];
+//   t2 = broadcast(t0, {true, false});
+//   t3 = expand(t2, {2, i0});
+//   t4 = reshape(t3, {2 * i0});
+//
+// And all uses of t1 will be replaced by t4. This pattern commonly
+// appears in RoPE, e.g.,
+// https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L136.
+// While the resize scheduler should be able to handle these patterns
+// for pointwise-only segments, it is currently limited to pointwise
+// fusions only. This translation should promote larger fusions as it
+// is not specific to any surrounding ops.
+//
+// Note that there's a potential downside compared to handling cat ops
+// directly. Since the insertion of broadcast IDs is not represented
+// as Fusion IR expressions, a fusion may have more disconnected ID
+// graphs after the translation, which may cause a segmentation that
+// could be avoided with the original fusion. See
+// PresegTest.TranslateRepeatToExpand4 for a concrete example.
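+//
+// The pass normally runs as part of the pre-segmenter, but it can also
+// be applied directly, as the tests in tests/cpp/test_preseg_passes.cpp
+// do. A minimal sketch (using the test helper makeContigConcreteTensor):
+//
+//   Fusion fusion;
+//   FusionGuard fg(&fusion);
+//   auto tv0 = makeContigConcreteTensor({32});
+//   fusion.addInput(tv0);
+//   fusion.addOutput(cat({tv0, tv0}, -1));
+//   // Replaces the pad + cat pair with broadcast + expand + reshape
+//   OptimizationPass<TranslateRepeatToExpand>::runPass(&fusion);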
+class TranslateRepeatToExpand
+    : public OptimizationPass<TranslateRepeatToExpand> {
+  friend class OptimizationPass<TranslateRepeatToExpand>;
+
+ protected:
+  static void runPass(Fusion* fusion);
+  static std::string name() {
+    return "TranslateRepeatToExpand";
+  }
+};
+
+} // namespace nvfuser::preseg_passes
diff --git a/tests/cpp/test_preseg_passes.cpp b/tests/cpp/test_preseg_passes.cpp
index 33fb1b635ba..4661d6e5599 100644
--- a/tests/cpp/test_preseg_passes.cpp
+++ b/tests/cpp/test_preseg_passes.cpp
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include <preseg_passes/translate_repeat_to_expand.h>
 #include
 #include
@@ -770,4 +771,215 @@ TEST_F(PresegTest, DisjointSetsOfExtentsConcreteSymbolic) {
   testValidate(
       executor_cache.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__);
 }
+
+// Trivial repeat pattern
+TEST_F(PresegTest, TranslateRepeatToExpand1) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigConcreteTensor({32});
+  fusion.addInput(tv0);
+
+  auto tv1 = cat({tv0, tv0}, -1);
+  fusion.addOutput(tv1);
+
+  {
+    // Make sure pad and cat no longer exist
+    Fusion fusion_copy = fusion;
+    OptimizationPass<TranslateRepeatToExpand>::runPass(&fusion_copy);
+    auto new_exprs = fusion_copy.exprs();
+    EXPECT_EQ(
+        std::find_if(
+            new_exprs.begin(),
+            new_exprs.end(),
+            [](Expr* new_expr) { return new_expr->isOneOf<PadOp, CatOp>(); }),
+        new_exprs.end());
+  }
+
+  auto options = at::TensorOptions().device(at::kCUDA, 0);
+  auto t0 = at::randn({32}, options);
+  std::vector<c10::IValue> inputs = {t0};
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs(inputs);
+  testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
+
+  // Should be scheduled as a pointwise kernel
+  FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
+  EXPECT_FALSE(runtime->isSegmented());
+  const auto& heuristic_param =
+      runtime->schedulerHeuristics()->heuristicsList().front();
+  EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise);
+}
+
+// Consecutive repetitions with the same IDs
+TEST_F(PresegTest, TranslateRepeatToExpand2) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigConcreteTensor({32});
+  fusion.addInput(tv0);
+
+  auto tv1 = cat({tv0, tv0}, -1);
+  auto tv2 = cat({tv1, tv1}, -1);
+
+  fusion.addOutput(tv2);
+
+  {
+    Fusion fusion_copy = fusion;
+    OptimizationPass<TranslateRepeatToExpand>::runPass(&fusion_copy);
+    auto new_exprs = fusion_copy.exprs();
+    EXPECT_EQ(
+        std::find_if(
+            new_exprs.begin(),
+            new_exprs.end(),
+            [](Expr* new_expr) { return new_expr->isOneOf<PadOp, CatOp>(); }),
+        new_exprs.end());
+  }
+
+  auto options = at::TensorOptions().device(at::kCUDA, 0);
+  auto t0 = at::randn({32}, options);
+  std::vector<c10::IValue> inputs = {t0};
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs(inputs);
+  testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
+
+  // Should be scheduled as a pointwise kernel
+  FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
+  EXPECT_FALSE(runtime->isSegmented());
+  const auto& heuristic_param =
+      runtime->schedulerHeuristics()->heuristicsList().front();
+  EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise);
+}
+
+// Consecutive repetitions with different IDs
+TEST_F(PresegTest, TranslateRepeatToExpand3) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigConcreteTensor({4, 8});
+  fusion.addInput(tv0);
+
+  auto tv1 = cat({tv0, tv0}, 1);
+  auto tv2 = cat({tv1, tv1}, 0);
+
+  fusion.addOutput(tv2);
+
+  {
+    Fusion fusion_copy = fusion;
+    OptimizationPass<TranslateRepeatToExpand>::runPass(&fusion_copy);
+    auto new_exprs = fusion_copy.exprs();
+    EXPECT_EQ(
+        std::find_if(
+            new_exprs.begin(),
+            new_exprs.end(),
+            [](Expr* new_expr) { return new_expr->isOneOf<PadOp, CatOp>(); }),
+        new_exprs.end());
+  }
+
+  auto options = at::TensorOptions().device(at::kCUDA, 0);
+  auto t0 = at::randn({4, 8}, options);
+  std::vector<c10::IValue> inputs = {t0};
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs(inputs);
+  testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
+
+  // Should be scheduled as a pointwise kernel
+  FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
+  EXPECT_FALSE(runtime->isSegmented());
+  const auto& heuristic_param =
+      runtime->schedulerHeuristics()->heuristicsList().front();
+  EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise);
+}
+
+// Repeat the same ID of the same tensor multiple times. While the
+// repetitions are the same, there's nothing to allow the output IDs
+// to be mapped, so the translated fusion will be segmented. This is a
+// downside compared to the original fusion, where all IDs are
+// connected, so it's relatively straightforward to fuse them together
+// without segmentation.
+TEST_F(PresegTest, TranslateRepeatToExpand4) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigConcreteTensor({4, 8});
+  fusion.addInput(tv0);
+
+  // Two independent repetitions of the same ID of the same tensor
+  auto tv1 = cat({tv0, tv0}, 1);
+  auto tv2 = cat({tv0, tv0}, 1);
+
+  fusion.addOutput(tv1);
+  fusion.addOutput(tv2);
+
+  {
+    Fusion fusion_copy = fusion;
+    OptimizationPass<TranslateRepeatToExpand>::runPass(&fusion_copy);
+    auto new_exprs = fusion_copy.exprs();
+    EXPECT_EQ(
+        std::find_if(
+            new_exprs.begin(),
+            new_exprs.end(),
+            [](Expr* new_expr) { return new_expr->isOneOf<PadOp, CatOp>(); }),
+        new_exprs.end());
+  }
+
+  auto options = at::TensorOptions().device(at::kCUDA, 0);
+  auto t0 = at::randn({4, 8}, options);
+  std::vector<c10::IValue> inputs = {t0};
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs(inputs);
+  testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
+
+  // Should be segmented into two pointwise kernels
+  FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
+  const auto& heuristic_list = runtime->schedulerHeuristics()->heuristicsList();
+  ASSERT_EQ(heuristic_list.size(), 2);
+  EXPECT_EQ(heuristic_list.at(0)->scheduler_type, SchedulerType::PointWise);
+  EXPECT_EQ(heuristic_list.at(1)->scheduler_type, SchedulerType::PointWise);
+}
+
+// Repeating more than two times
+TEST_F(PresegTest, TranslateRepeatToExpand5) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigConcreteTensor({32});
+  fusion.addInput(tv0);
+
+  auto tv1 = cat({tv0, tv0, tv0, tv0}, -1);
+  fusion.addOutput(tv1);
+
+  {
+    // Make sure pad and cat no longer exist
+    Fusion fusion_copy = fusion;
+    OptimizationPass<TranslateRepeatToExpand>::runPass(&fusion_copy);
+    auto new_exprs = fusion_copy.exprs();
+    EXPECT_EQ(
+        std::find_if(
+            new_exprs.begin(),
+            new_exprs.end(),
+            [](Expr* new_expr) { return new_expr->isOneOf<PadOp, CatOp>(); }),
+        new_exprs.end());
+  }
+
+  auto options = at::TensorOptions().device(at::kCUDA, 0);
+  auto t0 = at::randn({32}, options);
+  std::vector<c10::IValue> inputs = {t0};
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs(inputs);
+  testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
+
+  // Should be scheduled as a pointwise kernel
+  FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
+  EXPECT_FALSE(runtime->isSegmented());
+  const auto& heuristic_param =
+      runtime->schedulerHeuristics()->heuristicsList().front();
+  EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise);
+}
+
 } // namespace nvfuser::preseg_passes
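
Note: after the translation, the fusion in TranslateRepeatToExpand1 is equivalent to building the expand-based form directly. A sketch following the example in translate_repeat_to_expand.h, with extents hard-coded for the {32} input (not part of this diff):

  auto tv0 = makeContigConcreteTensor({32});
  fusion.addInput(tv0);
  auto tv2 = broadcast(tv0, {true, false}); // [b1, 32]
  auto tv3 = expand(
      tv2, {IrBuilder::create<Val>(2L), IrBuilder::create<Val>(32L)}); // [2, 32]
  auto tv4 = reshape(tv3, {2, 32}, {64}); // [64]
  fusion.addOutput(tv4);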