diff --git a/csrc/device_lower/analysis/tma.cpp b/csrc/device_lower/analysis/tma.cpp
index eb2463b3923..a13f7525664 100644
--- a/csrc/device_lower/analysis/tma.cpp
+++ b/csrc/device_lower/analysis/tma.cpp
@@ -800,16 +800,22 @@ class DomainMerger {
   std::unordered_set<ValGroup>& bulk_groups_;
   std::unordered_set<ValGroup>& nonbulk_groups_;
   std::list<TMADim>& dim_info_;
+  MmaInputSmemSwizzle swizzle_;
+  int64_t item_size_bytes_;
 
  public:
   DomainMerger(
       std::list> raw_tma_domain,
       std::unordered_set<ValGroup>& bulk_groups,
       std::unordered_set<ValGroup>& nonbulk_groups,
-      std::list<TMADim>& dim_info)
+      std::list<TMADim>& dim_info,
+      MmaInputSmemSwizzle swizzle,
+      int64_t item_size_bytes)
       : bulk_groups_(bulk_groups),
         nonbulk_groups_(nonbulk_groups),
-        dim_info_(dim_info) {
+        dim_info_(dim_info),
+        swizzle_(swizzle),
+        item_size_bytes_(item_size_bytes) {
     ValGraph& id_graph = GpuLower::current()->tensorIndexer().traversalGraph();
     contiguity_and_stride_.reserve(raw_tma_domain.size());
     for (auto& item : raw_tma_domain) {
@@ -868,6 +874,49 @@ class DomainMerger {
     return C;
   }
 
+  bool shouldMerge(int64_t i) {
+    auto type0 = type(i);
+    auto type1 = type(i + 1);
+
+    bool may_increasing_box_size = (type0 == CB && type1 == CB);
+    if (!may_increasing_box_size) {
+      return true;
+    }
+
+    auto extent0 = (*this)[i]->front()->as<IterDomain>()->extent();
+    auto extent1 = (*this)[i + 1]->front()->as<IterDomain>()->extent();
+    Val* merged_extent = SimplifyingIrBuilder::mulExpr(extent0, extent1);
+
+    bool merging_innermost = ((int64_t)size() == i + 2);
+
+    // If merging makes the size of a dimension larger than 256, we should not
+    // merge.
+    constexpr int64_t largest_dim_size =
+        256; // Dimension size must be <= 256 as limited by hardware.
+    Val* too_large_after_merge = SimplifyingIrBuilder::gtExpr(
+        merged_extent, IrBuilder::create<Val>(largest_dim_size));
+    if (simplifyExpr(too_large_after_merge)->isTrue()) {
+      return false;
+    }
+
+    // If merging makes the inner size larger than the swizzle size,
+    // we should not merge
+    if (merging_innermost && swizzle_ != MmaInputSmemSwizzle::None) {
+      const int64_t swizzle_size =
+          getBytesFromSwizzle(swizzle_) / item_size_bytes_;
+      Val* merging_makes_gt_swizzle_size = SimplifyingIrBuilder::gtExpr(
+          merged_extent, IrBuilder::create<Val>(swizzle_size));
+      if (simplifyExpr(merging_makes_gt_swizzle_size)->isTrue()) {
+        return false;
+      }
+    }
+
+    // Because the shape is dynamic, we don't know if we should merge or
+    // not. For this case, we always assume merging is better than not
+    // merging.
+    return true;
+  }
+
   void merge(int64_t i) {
     auto type0 = type(i);
     auto type1 = type(i + 1);
@@ -941,9 +990,15 @@ std::vector<TMADim> run(
     std::unordered_set<ValGroup>& bulk_groups,
     std::unordered_set<ValGroup>& nonbulk_groups,
     std::list<TMADim>& dim_info,
-    int64_t item_size_bytes) {
+    int64_t item_size_bytes,
+    MmaInputSmemSwizzle swizzle) {
   DomainMerger tma_domain(
-      std::move(raw_tma_domain), bulk_groups, nonbulk_groups, dim_info);
+      std::move(raw_tma_domain),
+      bulk_groups,
+      nonbulk_groups,
+      dim_info,
+      swizzle,
+      item_size_bytes);
   // merge contiguous C groups and CB groups
   for (int64_t i = 0; i < (int64_t)tma_domain.size() - 1; i++) {
     if (!tma_domain.contiguity(i)) {
@@ -951,8 +1006,10 @@ std::vector<TMADim> run(
     }
     if ((tma_domain.type(i) == C && tma_domain.type(i + 1) == C) ||
         (tma_domain.type(i) == CB && tma_domain.type(i + 1) == CB)) {
-      tma_domain.merge(i);
-      i--;
+      if (tma_domain.shouldMerge(i)) {
+        tma_domain.merge(i);
+        i--;
+      }
     }
   }
   // merge contiguous C with SB/CB
@@ -962,8 +1019,10 @@ std::vector<TMADim> run(
     }
     if (tma_domain.type(i) == C &&
        (tma_domain.type(i + 1) == SB || tma_domain.type(i + 1) == CB)) {
-      tma_domain.merge(i);
-      i--;
+      if (tma_domain.shouldMerge(i)) {
+        tma_domain.merge(i);
+        i--;
+      }
     }
   }
 
@@ -1056,6 +1115,9 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) {
       "(this is always the case for nvFuser now)",
       ", the first element of elementStrides must be one.");
 
+  MmaInputSmemSwizzle swizzle = getSwizzleFromBytes(
+      getCpAsyncBulkTensorSwizzleSize(smem_tv) * core_matrix_width_bytes);
+
   // Handle "defining box by compositing" by collapsing some dimensions in the
   // raw TMA domain to get the final TMA domain.
   auto final_tma_domain = collapse_tma_domain::run(
@@ -1063,12 +1125,9 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) {
       bulk_groups,
       nonbulk_groups,
       inferred_dims,
-      dataTypeSize(gmem_tv->dtype()));
-  return TMAInfo(
-      std::move(final_tma_domain),
-      getSwizzleFromBytes(
-          getCpAsyncBulkTensorSwizzleSize(smem_tv) * core_matrix_width_bytes),
-      gmem_tv);
+      dataTypeSize(gmem_tv->dtype()),
+      swizzle);
+  return TMAInfo(std::move(final_tma_domain), swizzle, gmem_tv);
 }
 
 } // namespace
diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp
index 46e66c0ba9f..bd138a37045 100644
--- a/tests/cpp/test_memory.cpp
+++ b/tests/cpp/test_memory.cpp
@@ -881,7 +881,6 @@ TEST_F(TMAIndexingTest, DefineBoxByCompositing2) {
   }
   // Parallelize the tile axes
   tv1->axis(1)->parallelize(ParallelType::Bulk);
-  // tv2->axis(1)->parallelize(ParallelType::TIDx);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
   auto t0 = at::randn({32, 4, 2, 8, 8, 8, 2, 8, 4}, options);
@@ -895,6 +894,54 @@ TEST_F(TMAIndexingTest, DefineBoxByCompositing2) {
   testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
 }
 
+TEST_F(TMAIndexingTest, DefineBoxByCompositingShouldNotMerge) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigConcreteTensor({2, 256, 2, 32});
+  fusion.addInput(tv0);
+  auto tv1 = set(tv0);
+  auto tv2 = set(tv1);
+  fusion.addOutput(tv2);
+
+  tv1->setMemoryType(MemoryType::Shared);
+  tv1->definition()->as<LoadStoreOp>()->setOpType(
+      LoadStoreOpType::CpAsyncBulkTensorTile);
+
+  // Use 1 thread and a single instruction to load the entire tensor to smem
+  for (auto id : tv1->getLoopDomain()) {
+    id->parallelize(ParallelType::Bulk);
+  }
+
+  // Then use 32 threads to dump results out
+  tv2->axis(3)->parallelize(ParallelType::TIDx);
+
+  // Schedule the allocation domain of tv1 to use 128B swizzle
+  AbstractTensor alloc1(tv1->getLoopDomain());
+  alloc1.merge(0);
+  alloc1.merge(0);
+  // [1024, 32]
+  alloc1.split(1, 4);
+  alloc1.split(0, 8);
+  // [128, 8, 8, 4]
+  alloc1.swizzle(SwizzleType::XOR, 1, 2);
+  tv1->setAllocationDomain(alloc1.as<IterDomain*>(), true);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({2, 256, 2, 32}, options);
+  KernelExecutor ke;
+  ke.compile(&fusion, {t0}, {}, matmul_cparams);
+
+  // Because merging dims will violate hardware requirement, we do not merge
+  // dims.
+  EXPECT_EQ(TMADimChecker::getDim(ke.kernel()), 4);
+
+  EXPECT_TRUE(PredicatedChecker::isPredicated(tv1, ke.kernel()));
+
+  auto cg_outputs = ke.run({t0});
+  testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+}
+
 TEST_F(TMAIndexingTest, DefineBoxByRotation1) {
   Fusion fusion;
   FusionGuard fg(&fusion);