From ee3094755e8194f1579adaad66570bc2014d5d74 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Thu, 9 Jan 2025 01:39:06 -0800
Subject: [PATCH 1/3] Introduce repeat and RepeatOp

Almost the same semantics as PyTorch repeat. Previously, repeat was
only partially supported, as a translation from a repeat pattern that
uses concat, which had a bug when repeating broadcast IDs. This PR
fixes the issue by handling broadcast IDs separately using a new IR
node, RepeatOp, which represents the repetition of broadcast IDs.
---
 csrc/device_lower/pass/fusion_simplifier.cpp | 10 +++
 csrc/device_lower/utils.cpp                  |  1 +
 csrc/dispatch.h                              |  1 +
 csrc/ir/internal_nodes.h                     | 28 +++++++
 csrc/ir/nodes.cpp                            | 68 +++++++++++++++
 csrc/logical_domain_map.h                    |  4 +
 csrc/ops/alias.cpp                           | 84 +++++++++++++++++++
 csrc/ops/alias.h                             |  5 ++
 .../translate_repeat_to_expand.cpp           | 48 ++++------
 tests/cpp/test_gpu3.cpp                      | 77 +++++++++++++++++
 tests/cpp/test_preseg_passes.cpp             | 48 +++++++++++
 11 files changed, 344 insertions(+), 30 deletions(-)

diff --git a/csrc/device_lower/pass/fusion_simplifier.cpp b/csrc/device_lower/pass/fusion_simplifier.cpp
index ed870a586da..03c92914695 100644
--- a/csrc/device_lower/pass/fusion_simplifier.cpp
+++ b/csrc/device_lower/pass/fusion_simplifier.cpp
@@ -56,6 +56,16 @@ class LoadStoreOpInserter : private kir::ExprMutator {
         container, LoadStoreOpType::Set, out, in));
   }
 
+  void handle(RepeatOp* op) final {
+    auto out = op->out();
+    auto in = op->in();
+    auto container = out->container();
+    registerReplaceAndPropagate(
+        op,
+        IrBuilder::createInContainer<LoadStoreOp>(
+            container, LoadStoreOpType::Set, out, in));
+  }
+
   void handle(ViewOp* vop) final {
     auto out = vop->out();
     auto in = vop->in();
diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp
index 35b825d5348..66baab289db 100644
--- a/csrc/device_lower/utils.cpp
+++ b/csrc/device_lower/utils.cpp
@@ -162,6 +162,7 @@ bool isTvOp(const Expr* expr) {
           BroadcastOp,
           SqueezeOp,
           ExpandOp,
+          RepeatOp,
           ViewAsScalar,
           ViewOp,
           PadOp,
diff --git a/csrc/dispatch.h b/csrc/dispatch.h
index 7681aa878a1..997b32ec7d6 100644
--- a/csrc/dispatch.h
+++ b/csrc/dispatch.h
@@ -96,6 +96,7 @@ class Val;
   f(BroadcastOp); \
   f(SqueezeOp);   \
   f(ExpandOp);    \
+  f(RepeatOp);    \
   f(ViewAsScalar); \
   f(ViewOp);      \
   f(CatOp);       \
diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index df9b0bf50c9..762f206cf3b 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -1527,6 +1527,34 @@ class ExpandOp : public Expr {
       const std::vector<PolymorphicValue>& inputs) const override;
 };
 
+class RepeatOp : public Expr {
+ public:
+  using Expr::Expr;
+
+  RepeatOp(IrBuilderPasskey, TensorView* out, TensorView* in);
+
+  NVFUSER_DECLARE_CLONE_AND_CREATE
+
+  const char* getOpString() const override {
+    return "RepeatOp";
+  }
+
+  std::string toString(int indent_size = 0) const override;
+  std::string toInlineString(int indent_size = 0) const override;
+
+  TensorView* out() const {
+    return output(0)->as<TensorView>();
+  }
+
+  TensorView* in() const {
+    return input(0)->as<TensorView>();
+  }
+
+  std::vector<PolymorphicValue> evaluate(
+      const ExpressionEvaluator& ee,
+      const std::vector<PolymorphicValue>& inputs) const override;
+};
+
 class ViewAsScalar : public Expr {
  public:
   using Expr::Expr;
diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp
index 5f0528991b2..af5990f91fe 100644
--- a/csrc/ir/nodes.cpp
+++ b/csrc/ir/nodes.cpp
@@ -2135,6 +2135,74 @@ std::vector<PolymorphicValue> ExpandOp::evaluate(
 
 NVFUSER_DEFINE_CLONE_AND_CREATE(ExpandOp)
 
+RepeatOp::RepeatOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in)
+    : Expr(passkey) {
+  auto in_domain =
+      TensorDomain::noReductions(in->getLogicalDomain());
+  const auto& out_domain = out->getLogicalDomain();
+
+  NVF_ERROR(in_domain.size() == out_domain.size());
+
+  bool repetition_found = false;
+  for (const auto i : c10::irange(in_domain.size())) {
+    if (in_domain.at(i)->isBroadcast() && !out_domain.at(i)->isBroadcast()) {
+      NVF_ERROR(!in_domain.at(i)->hasExpandedExtent());
+      NVF_ERROR(in_domain.at(i)->extent()->isOneInt());
+      repetition_found = true;
+    }
+  }
+
+  NVF_ERROR(
+      repetition_found,
+      "No repetition dim found: ",
+      out->toString(),
+      ", ",
+      in->toString());
+
+  addOutput(out);
+  addInput(in);
+}
+
+std::string RepeatOp::toString(int indent_size) const {
+  std::stringstream ss;
+  indent(ss, indent_size) << out()->toString() << " = repeat( "
+                          << in()->toString() << " )\n";
+  return ss.str();
+}
+
+std::string RepeatOp::toInlineString(int indent_size) const {
+  NVF_CHECK(false, "Tensor op can not be printed inline");
+}
+
+std::vector<PolymorphicValue> RepeatOp::evaluate(
+    const ExpressionEvaluator& ee,
+    const std::vector<PolymorphicValue>& inputs) const {
+  NVF_ERROR(
+      inputs.size() == 1,
+      "ConcretizeOp expects exactly 1 input, but received ",
+      inputs.size());
+  auto tensor = inputs.at(0).as<at::Tensor>();
+  std::vector<int64_t> sizes;
+  sizes.reserve(out()->getLogicalDomain().size());
+  const auto c2p =
+      PairwiseLogicalDomainMap(in(), out()).mapConsumerToProducer();
+  for (const auto i : c10::irange(out()->getLogicalDomain().size())) {
+    auto out_id = out()->getLogicalDomain().at(i);
+    auto inp_id = c2p.at(out_id);
+    auto out_extent = ee.evaluate(out_id->extent()).as<int64_t>();
+    auto inp_extent = ee.evaluate(inp_id->extent()).as<int64_t>();
+    NVF_ERROR(
+        out_extent == inp_extent || out_extent % inp_extent == 0,
+        "Invalid input and output extents: ",
+        inp_extent,
+        ", ",
+        out_extent);
+    sizes.push_back(out_extent / inp_extent);
+  }
+  return {tensor.repeat(sizes)};
+}
+
+NVFUSER_DEFINE_CLONE_AND_CREATE(RepeatOp)
+
 ViewAsScalar::ViewAsScalar(
     IrBuilderPasskey passkey,
     Val* out,
diff --git a/csrc/logical_domain_map.h b/csrc/logical_domain_map.h
index d00e070df28..0a2d9d3ca01 100644
--- a/csrc/logical_domain_map.h
+++ b/csrc/logical_domain_map.h
@@ -504,6 +504,10 @@ class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor {
     mapPointwiseLikeOp(op);
   }
 
+  void handle(RepeatOp* op) override {
+    mapPointwiseLikeOp(op);
+  }
+
   void handle(PadOp* op) override {
     // For compute-at, padded id should be mapped
     mapPointwiseLikeOp(op);
diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp
index bc541623310..695cc4d4356 100644
--- a/csrc/ops/alias.cpp
+++ b/csrc/ops/alias.cpp
@@ -1124,4 +1124,88 @@ TensorView* expand_as(TensorView* inp, TensorView* other) {
   return out_tensor;
 }
 
+TensorView* repeat(TensorView* inp_tv, std::vector<int64_t> repeat_times) {
+  const auto ndims =
+      TensorDomain::noReductions(inp_tv->getLogicalDomain()).size();
+
+  // Handle repetitions of non-broadcast IDs first. Each ID is
+  // individually repeated by:
+  //
+  // Step 1. Insert a broadcast ID immediately outside of the
+  // repeated ID
+  // Step 2. Expand the broadcast ID by the repetition factor
+  // Step 3. Flatten the expanded ID and the repeated ID
+
+  bool has_repetition_of_broadcast = false;
+  auto intermediate_tv = inp_tv;
+  for (const auto i : c10::irange(ndims)) {
+    if (repeat_times.at(i) == 1) {
+      continue;
+    }
+
+    auto inp_id = intermediate_tv->getLogicalDomain().at(i);
+
+    // Broadcast is handled after this
+    if (inp_id->isBroadcast()) {
+      has_repetition_of_broadcast = true;
+      continue;
+    }
+
+    // Step 1: Insert a broadcast ID
+    std::vector<bool> bcast_flags(ndims + 1, false);
+    bcast_flags.at(i) = true;
+    auto broadcast_tv = broadcast(intermediate_tv, bcast_flags);
+
+    // Step 2: Expand the broadcast ID for the repetition factor
+    std::vector<Val*> expanded_sizes(
+        bcast_flags.size(), IrBuilder::create<Val>(-1L));
+    expanded_sizes.at(i) = IrBuilder::create<Val>(repeat_times.at(i));
+    auto expanded_tv = expand(broadcast_tv, expanded_sizes);
+
+    // Step 3: Reshape to merge the broadcast ID and the repeated ID
+    intermediate_tv = flatten(expanded_tv, (int64_t)i, (int64_t)i + 1);
+  }
+
+  if (!has_repetition_of_broadcast) {
+    return intermediate_tv;
+  }
+
+  // Repeat broadcast IDs. The expand approach doesn't work here, as
+  // reshape would just squeeze the repeated IDs and thus there would
+  // be no merge. The expanded IDs would simply remain expanded
+  // broadcast IDs. To concretize them, use RepeatOp.
+  std::vector<IterDomain*> new_domain;
+  new_domain.reserve(ndims);
+  std::vector<std::optional<bool>> new_contiguity;
+  new_contiguity.reserve(ndims);
+
+  for (const auto i : c10::irange(ndims)) {
+    auto inp_id = intermediate_tv->getLogicalDomain().at(i);
+    IterDomain* new_id = nullptr;
+
+    if (repeat_times.at(i) > 1 && inp_id->isBroadcast()) {
+      new_id = IterDomainBuilder(inp_id)
+                   .extent(IrBuilder::create<Val>(
+                       repeat_times.at(i), DataType::Index))
+                   .iter_type(IterType::Iteration)
+                   .build();
+    } else {
+      new_id = inp_id->cloneWithoutRFactor();
+    }
+
+    new_domain.push_back(new_id);
+    new_contiguity.push_back(
+        new_id->isBroadcast() ? std::optional<bool>(std::nullopt)
+                              : std::optional<bool>(true));
+  }
+
+  auto out_tv = IrBuilder::create<TensorView>(
+      IrBuilder::create<TensorDomain>(new_domain, new_contiguity),
+      inp_tv->dtype());
+
+  IrBuilder::create<RepeatOp>(out_tv, intermediate_tv);
+
+  return out_tv;
+}
+
 } // namespace nvfuser
diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h
index 1b6d443c38c..b8b52f5d6be 100644
--- a/csrc/ops/alias.h
+++ b/csrc/ops/alias.h
@@ -182,4 +182,9 @@ NVF_API TensorView* expand(
 // non broadcasted iter domain, inp will be expanded to other's size.
 NVF_API TensorView* expand_as(TensorView* inp, TensorView* other);
 
+// Repeat each dimension a given number of times. The repeat_times
+// parameter must have the same number of elements as the
+// dimensionality of the input tensor (excluding reduction IDs).
+NVF_API TensorView* repeat(TensorView* inp, std::vector<int64_t> repeat_times);
+
 } // namespace nvfuser
diff --git a/csrc/preseg_passes/translate_repeat_to_expand.cpp b/csrc/preseg_passes/translate_repeat_to_expand.cpp
index 382dcb85f52..761adbbfe4a 100644
--- a/csrc/preseg_passes/translate_repeat_to_expand.cpp
+++ b/csrc/preseg_passes/translate_repeat_to_expand.cpp
@@ -124,13 +124,11 @@ class RepeatToExpandTranslator {
     }
   }
 
-  // For each detected repetition:
-  //
-  // Step 1. Insert a broadcast ID immediately outside of the
-  // repeated ID
-  // Step 2. Expand the broadcast ID by the repetition factor
-  // Step 3. Flatten the expanded ID and the repeated ID
+  // For each detected repetition, replace the output with a repeat
+  // output.
   void translate() {
+    FusionGuard fg(fusion_);
+
     const auto exprs = fusion_->exprs();
     // Apply the translation in a reverse topological order. Since the
     // output of the repetition is replaced, the use exprs of the
@@ -145,36 +143,26 @@ class RepeatToExpandTranslator {
 
       const auto& info = repeat_info_map_it->second;
 
-      if (info.cat_inp_tvs.size() < 2) {
+      const auto num_repetitions = (int64_t)info.cat_inp_tvs.size();
+
+      if (num_repetitions < 2) {
         continue;
       }
 
-      auto original_out_tv = expr->output(0)->as<TensorView>();
-
-      // Step 1
-      auto inp_domain =
+      const auto inp_domain =
           TensorDomain::noReductions(info.input_tv->getLogicalDomain());
-      std::vector<bool> bcast_flags(inp_domain.size() + 1, false);
-      auto repeated_id_offset = std::distance(
-          inp_domain.begin(),
-          std::find(inp_domain.begin(), inp_domain.end(), info.repeated_id));
-      bcast_flags.at(repeated_id_offset) = true;
-      auto broadcast_tv = broadcast(info.input_tv, bcast_flags);
-      NVF_ERROR((size_t)broadcast_tv->nDims() == inp_domain.size() + 1);
-
-      // Step 2
-      std::vector<Val*> expanded_sizes(
-          bcast_flags.size(), IrBuilder::create<Val>(-1L));
-      expanded_sizes.at(repeated_id_offset) =
-          IrBuilder::create<Val>((int64_t)info.cat_inp_tvs.size());
-      auto expanded_tv = expand(broadcast_tv, expanded_sizes);
-
-      // Step 3
-      auto flattened_tv =
-          flatten(expanded_tv, repeated_id_offset, repeated_id_offset + 1);
+
+      std::vector<int64_t> repeated_times(inp_domain.size(), 1);
+      auto repeated_id_it =
+          std::find(inp_domain.begin(), inp_domain.end(), info.repeated_id);
+      NVF_ERROR(repeated_id_it != inp_domain.end());
+      auto repeated_dim = std::distance(inp_domain.begin(), repeated_id_it);
+      repeated_times.at(repeated_dim) = num_repetitions;
+
+      TensorView* replacement_tv = repeat(info.input_tv, repeated_times);
 
       ir_utils::replaceValInAllExprInputsAndFusionOutputs(
-          original_out_tv, flattened_tv);
+          expr->output(0), replacement_tv);
     }
   }
diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp
index 76d45f6de4c..ce891b7ae22 100644
--- a/tests/cpp/test_gpu3.cpp
+++ b/tests/cpp/test_gpu3.cpp
@@ -9272,6 +9272,83 @@ TEST_F(NVFuserTest, AllIdsMultipleDependencies) {
   }
 }
 
+// Repeating a broadcast ID. RepeatOp should be used.
+TEST_F(NVFuserTest, Repeat1) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  auto& fusion = *fusion_ptr;
+  FusionGuard fg(fusion_ptr.get());
+
+  auto tv0 = makeConcreteTensor({10});
+  fusion.addInput(tv0);
+
+  auto tv1 = broadcast(tv0, {false, true});
+  auto tv2 = repeat(tv1, {1L, 2L});
+  fusion.addOutput(tv2);
+
+  EXPECT_TRUE(tv2->definition()->isA<RepeatOp>());
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({10}, options);
+  std::vector<c10::IValue> inputs({t0});
+
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs(inputs);
+  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+}
+
+// Repeating a non-broadcast ID. Should be translated to broadcast +
+// expand + reshape.
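+// For example, an input of shape [10] repeated by a factor of 2 is
+// broadcast to [1, 10], expanded to [2, 10], and finally reshaped to
+// [20].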
+TEST_F(NVFuserTest, Repeat2) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  auto& fusion = *fusion_ptr;
+  FusionGuard fg(fusion_ptr.get());
+
+  auto tv0 = makeConcreteTensor({10});
+  fusion.addInput(tv0);
+
+  auto tv1 = repeat(tv0, {2L});
+  fusion.addOutput(tv1);
+
+  ASSERT_TRUE(tv1->definition()->isA<ViewOp>());
+  ASSERT_TRUE(tv1->definition()->input(0)->definition()->isA<ExpandOp>());
+  ASSERT_TRUE(tv1->definition()
+                  ->input(0)
+                  ->definition()
+                  ->input(0)
+                  ->definition()
+                  ->isA<BroadcastOp>());
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({10}, options);
+  std::vector<c10::IValue> inputs({t0});
+
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs(inputs);
+  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+}
+
+// Repeating a mix of broadcast and non-broadcast IDs
+TEST_F(NVFuserTest, Repeat3) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  auto& fusion = *fusion_ptr;
+  FusionGuard fg(fusion_ptr.get());
+
+  std::vector<int64_t> shape{2, 1, 3, 1};
+  auto tv0 = makeConcreteTensor(shape);
+  fusion.addInput(tv0);
+
+  auto tv1 = repeat(tv0, {2L, 2L, 2L, 2L});
+  fusion.addOutput(tv1);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn(shape, options);
+  std::vector<c10::IValue> inputs({t0});
+
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs(inputs);
+  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+}
+
 // Test file size should be up to 10K LoC. Create a new file for more tests.
 } // namespace nvfuser
diff --git a/tests/cpp/test_preseg_passes.cpp b/tests/cpp/test_preseg_passes.cpp
index 4661d6e5599..33de732297a 100644
--- a/tests/cpp/test_preseg_passes.cpp
+++ b/tests/cpp/test_preseg_passes.cpp
@@ -982,4 +982,52 @@ TEST_F(PresegTest, TranslateRepeatToExpand5) {
   EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise);
 }
 
+// Repeating a broadcast ID. Repro of
+// https://github.com/NVIDIA/Fuser/issues/3682.
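+// Concatenating a tensor with itself along a broadcast dimension
+// repeats the broadcast ID, which the previous expand-based
+// translation mishandled.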
+TEST_F(PresegTest, TranslateRepeatToExpand6) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigConcreteTensor({32, 1});
+  fusion.addInput(tv0);
+
+  auto tv1 = cat({tv0, tv0}, -1);
+  fusion.addOutput(tv1);
+
+  {
+    // Make sure pad and cat no longer exist
+    Fusion fusion_copy = fusion;
+    OptimizationPass<TranslateRepeatToExpand>::runPass(&fusion_copy);
+    auto new_exprs = fusion_copy.exprs();
+    EXPECT_EQ(
+        std::find_if(
+            new_exprs.begin(),
+            new_exprs.end(),
+            [](Expr* new_expr) { return new_expr->isOneOf<PadOp, CatOp>(); }),
+        new_exprs.end());
+    // RepeatOp should be used
+    EXPECT_NE(
+        std::find_if(
+            new_exprs.begin(),
+            new_exprs.end(),
+            [](Expr* new_expr) { return new_expr->isA<RepeatOp>(); }),
+        new_exprs.end());
+  }
+
+  auto options = at::TensorOptions().device(at::kCUDA, 0);
+  auto t0 = at::randn({32, 1}, options);
+  std::vector<c10::IValue> inputs = {t0};
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs(inputs);
+  testValidate(executor_cache.fusion(), outputs, inputs, __LINE__, __FILE__);
+
+  // Should be scheduled as a pointwise kernel
+  FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
+  EXPECT_FALSE(runtime->isSegmented());
+  const auto& heuristic_param =
+      runtime->schedulerHeuristics()->heuristicsList().front();
+  EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise);
+}
+
 } // namespace nvfuser::preseg_passes

From 4f57eebf15f431390aafd662d632f01f989e85e7 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Thu, 9 Jan 2025 11:40:43 -0800
Subject: [PATCH 2/3] comments

---
 csrc/id_model/predicate_indexing.cpp | 2 +-
 csrc/ir/internal_nodes.h             | 8 ++++++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/csrc/id_model/predicate_indexing.cpp b/csrc/id_model/predicate_indexing.cpp
index 15edb366357..1b7733bedaf 100644
--- a/csrc/id_model/predicate_indexing.cpp
+++ b/csrc/id_model/predicate_indexing.cpp
@@ -26,7 +26,7 @@ std::vector<IterDomain*> getPredicateDomains(
           : consumer_tv->getLogicalDomain();
 
   // Broadcast domains should not need to be predicated. Note that
-  // unlike indexing for TensorIndex, reduction doamins do need to be
+  // unlike indexing for TensorIndex, reduction domains do need to be
   // indexed to guard the access to the producer tensor
   predicate_domains.erase(
       std::remove_if(
diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index 762f206cf3b..6aebcb3c457 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -1527,10 +1527,18 @@ class ExpandOp : public Expr {
       const std::vector<PolymorphicValue>& inputs) const override;
 };
 
+// Represents a repetition of broadcast IDs. Repetitions of
+// non-broadcast IDs are represented using the broadcast, expand and
+// reshape pattern. See the repeat op implementation in ops/alias.cpp
+// as well as the TranslateRepeatToExpand preseg pass.
 class RepeatOp : public Expr {
  public:
   using Expr::Expr;
 
+  // in: Input tensor that has broadcast logical IDs.
+  // out: Output tensor where some of the input broadcast logical IDs
+  // are converted to concrete IDs. Their extents represent the
+  // repetition factor of each ID.
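+  //
+  // For example, an input with logical domain [i0, b1] and an output
+  // with logical domain [i0, i1], where i1 has extent 2, represents a
+  // twofold repetition of the broadcast ID b1.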
   RepeatOp(IrBuilderPasskey, TensorView* out, TensorView* in);
 
   NVFUSER_DECLARE_CLONE_AND_CREATE

From 3e6f480cd73405f90ae6e53d9661ad6ac85bdecc Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Thu, 9 Jan 2025 16:05:06 -0800
Subject: [PATCH 3/3] PR feedback

---
 csrc/ir/nodes.cpp                  | 30 +++++++++++++------
 csrc/ops/alias.cpp                 |  8 ++++-
 csrc/ops/alias.h                   |  4 ++-
 .../translate_repeat_to_expand.cpp |  4 ---
 tests/cpp/test_gpu3.cpp            |  6 ++--
 tests/cpp/test_preseg_passes.cpp   |  9 +++---
 6 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp
index af5990f91fe..3087fe4e34e 100644
--- a/csrc/ir/nodes.cpp
+++ b/csrc/ir/nodes.cpp
@@ -2142,6 +2142,15 @@ RepeatOp::RepeatOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in)
 
   NVF_ERROR(in_domain.size() == out_domain.size());
 
+  NVF_ERROR(
+      std::none_of(
+          out->getLogicalDomain().begin(),
+          out->getLogicalDomain().end(),
+          [](IterDomain* out_logical_id) {
+            return out_logical_id->isReduction();
+          }),
+      "Output should not have reduction IDs.");
+
   bool repetition_found = false;
   for (const auto i : c10::irange(in_domain.size())) {
     if (in_domain.at(i)->isBroadcast() && !out_domain.at(i)->isBroadcast()) {
@@ -2178,11 +2187,11 @@ std::vector<PolymorphicValue> RepeatOp::evaluate(
     const std::vector<PolymorphicValue>& inputs) const {
   NVF_ERROR(
       inputs.size() == 1,
-      "ConcretizeOp expects exactly 1 input, but received ",
+      "RepeatOp expects exactly 1 input, but received ",
       inputs.size());
   auto tensor = inputs.at(0).as<at::Tensor>();
-  std::vector<int64_t> sizes;
-  sizes.reserve(out()->getLogicalDomain().size());
+  std::vector<int64_t> multipliers;
+  multipliers.reserve(out()->getLogicalDomain().size());
   const auto c2p =
       PairwiseLogicalDomainMap(in(), out()).mapConsumerToProducer();
   for (const auto i : c10::irange(out()->getLogicalDomain().size())) {
@@ -2191,14 +2200,17 @@ std::vector<PolymorphicValue> RepeatOp::evaluate(
     auto out_extent = ee.evaluate(out_id->extent()).as<int64_t>();
     auto inp_extent = ee.evaluate(inp_id->extent()).as<int64_t>();
     NVF_ERROR(
-        out_extent == inp_extent || out_extent % inp_extent == 0,
-        "Invalid input and output extents: ",
+        out_extent % inp_extent == 0,
+        "For dimension ",
+        i,
+        ", the output extent (",
+        out_extent,
+        ") should be a multiple of the input extent (",
         inp_extent,
-        ", ",
-        out_extent);
-    sizes.push_back(out_extent / inp_extent);
+        ").");
+    multipliers.push_back(out_extent / inp_extent);
   }
-  return {tensor.repeat(sizes)};
+  return {tensor.repeat(multipliers)};
 }
 
 NVFUSER_DEFINE_CLONE_AND_CREATE(RepeatOp)
diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp
index 695cc4d4356..5729fed5b3f 100644
--- a/csrc/ops/alias.cpp
+++ b/csrc/ops/alias.cpp
@@ -1124,7 +1124,9 @@ TensorView* expand_as(TensorView* inp, TensorView* other) {
   return out_tensor;
 }
 
-TensorView* repeat(TensorView* inp_tv, std::vector<int64_t> repeat_times) {
+TensorView* repeat(
+    TensorView* inp_tv,
+    const std::vector<int64_t>& repeat_times) {
   const auto ndims =
       TensorDomain::noReductions(inp_tv->getLogicalDomain()).size();
 
@@ -1135,6 +1137,10 @@ TensorView* repeat(TensorView* inp_tv, std::vector<int64_t> repeat_times) {
   // Handle repetitions of non-broadcast IDs first. Each ID is
   // individually repeated by:
   //
   // Step 1. Insert a broadcast ID immediately outside of the
   // repeated ID
   // Step 2. Expand the broadcast ID by the repetition factor
   // Step 3. Flatten the expanded ID and the repeated ID
+  //
+  // Note that it's also possible to repeat multiple non-broadcast IDs
+  // at once by inserting and expanding broadcast IDs with a single
+  // BroadcastOp and a single ExpandOp.
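+  //
+  // For example, repeating both dimensions of a [i0, i1] tensor by
+  // {2, 3} could use a single broadcast to [1, i0, 1, i1], a single
+  // expand to [2, i0, 3, i1] and a final reshape to [2*i0, 3*i1].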
bool has_repetition_of_broadcast = false; auto intermediate_tv = inp_tv; diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index b8b52f5d6be..8a896dba1d6 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -185,6 +185,8 @@ NVF_API TensorView* expand_as(TensorView* inp, TensorView* other); // Repeat each dimension for a given time. The repeat_times parameter // must have the same number of elements as the dimensionality of the // input tensor (excluding reduction IDs). -NVF_API TensorView* repeat(TensorView* inp, std::vector repeat_times); +NVF_API TensorView* repeat( + TensorView* inp, + const std::vector& repeat_times); } // namespace nvfuser diff --git a/csrc/preseg_passes/translate_repeat_to_expand.cpp b/csrc/preseg_passes/translate_repeat_to_expand.cpp index 761adbbfe4a..19c1274bf83 100644 --- a/csrc/preseg_passes/translate_repeat_to_expand.cpp +++ b/csrc/preseg_passes/translate_repeat_to_expand.cpp @@ -145,10 +145,6 @@ class RepeatToExpandTranslator { const auto num_repetitions = (int64_t)info.cat_inp_tvs.size(); - if (num_repetitions < 2) { - continue; - } - const auto inp_domain = TensorDomain::noReductions(info.input_tv->getLogicalDomain()); diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index ce891b7ae22..d6e715f1fe8 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -9273,7 +9273,7 @@ TEST_F(NVFuserTest, AllIdsMultipleDependencies) { } // Repeating a broadcast ID. RepeatOp should be used. -TEST_F(NVFuserTest, Repeat1) { +TEST_F(NVFuserTest, RepeatBroadcast) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -9298,7 +9298,7 @@ TEST_F(NVFuserTest, Repeat1) { // Repeating a non-broadcast ID. Should be translated to broadcast + // expand + reshape. -TEST_F(NVFuserTest, Repeat2) { +TEST_F(NVFuserTest, RepeatNonBroadcast) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -9328,7 +9328,7 @@ TEST_F(NVFuserTest, Repeat2) { } // Repeating a mix of broadcast and non-broadcast IDs -TEST_F(NVFuserTest, Repeat3) { +TEST_F(NVFuserTest, RepeatBroadcastAndNonBroadcast) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); diff --git a/tests/cpp/test_preseg_passes.cpp b/tests/cpp/test_preseg_passes.cpp index 33de732297a..f3462108496 100644 --- a/tests/cpp/test_preseg_passes.cpp +++ b/tests/cpp/test_preseg_passes.cpp @@ -26,6 +26,8 @@ namespace nvfuser::preseg_passes { +using testing::ElementsAre; + using PresegTest = NVFuserTest; TEST_F(PresegTest, FusionTestOptimizationPassFlag) { @@ -1024,10 +1026,9 @@ TEST_F(PresegTest, TranslateRepeatToExpand6) { // Should be scheduled as a pointwise kernel FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); - EXPECT_FALSE(runtime->isSegmented()); - const auto& heuristic_param = - runtime->schedulerHeuristics()->heuristicsList().front(); - EXPECT_EQ(heuristic_param->scheduler_type, SchedulerType::PointWise); + EXPECT_THAT( + runtime->fusionSegments()->groups(), + ElementsAre(HeuristicIs(SchedulerType::PointWise))); } } // namespace nvfuser::preseg_passes