From 3268d798347391aed713bc36403bd2a8dac052f7 Mon Sep 17 00:00:00 2001
From: Ryan Spring
Date: Fri, 20 Dec 2024 11:57:09 -0800
Subject: [PATCH] Add TN, NN, TT UseScheduler Matmul tests (#3628)

This PR adds the remaining TN, NN, and TT layouts with customizable
Hopper matmul scheduler configurations for M=2048, N=2048, K=8192.
---
 tests/cpp/test_matmul.cpp | 187 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 185 insertions(+), 2 deletions(-)

diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp
index d040aa9e6d2..9e9395c5e18 100644
--- a/tests/cpp/test_matmul.cpp
+++ b/tests/cpp/test_matmul.cpp
@@ -4006,8 +4006,8 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler) {
   constexpr int64_t M = 2048, N = 2048, K = 8192;
   const auto dtype = DataType::Half;
 
-  auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
-  auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
+  auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M
+  auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N
   fusion.addInput(tv0);
   fusion.addInput(tv1);
 
@@ -4039,6 +4039,189 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler) {
   mparams.circular_buffer_options.circular_buffer_smem_write = true;
   mparams.circular_buffer_options.circular_buffer_smem_read = false;
   mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
+  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
+  mparams.splitk_factor = 1;
+  mparams.use_smem_epilogue = true;
+  mparams.cluster_dims = {2, 1, 1};
+  mparams.promote_prologue_smem_reuse = true;
+
+  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
+      ->schedule(&fusion, &mparams);
+
+  std::vector<c10::IValue> inputs = {a_ref, b_ref};
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty());
+  auto cg_outputs = ke.run(inputs);
+  ASSERT_FALSE(
+      PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel()));
+
+  // Relax tolerance for larger sum due to large K
+  EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
+}
+
+TEST_F(HopperMatmulTest, HSH_TN_UseScheduler) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  constexpr int64_t M = 2048, N = 2048, K = 8192;
+  const auto dtype = DataType::Half;
+
+  auto tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype); // M, K
+  auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // N, K
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
+
+  auto tv3 = castOp(DataType::Half, tv2);
+  fusion.addOutput(tv3);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+  auto a_ref = at::randn({M, 1, K}, options);
+  auto b_ref = at::randn({1, N, K}, options);
+  auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze().t()).to(at::kHalf);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 256, 16);
+  gemm_tile.warp_tile = GemmTile(64, 256, 16);
+
+  MatmulParams mparams;
+  mparams.supported_vec_size = {8, 8, 8};
+  mparams.mma_macro = MmaMacro::Hopper_64_256_16;
+  mparams.tile_sizes = gemm_tile;
+  mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
+  mparams.async_gmem_load_operands = true;
+  mparams.circular_buffer_options.circular_buffer_smem_write = true;
+  mparams.circular_buffer_options.circular_buffer_smem_read = false;
+  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
+  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
+  mparams.splitk_factor = 1;
+  mparams.use_smem_epilogue = true;
+  mparams.cluster_dims = {2, 1, 1};
+  mparams.promote_prologue_smem_reuse = true;
+
+  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
+      ->schedule(&fusion, &mparams);
+
+  std::vector<c10::IValue> inputs = {a_ref, b_ref};
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty());
+  auto cg_outputs = ke.run(inputs);
+  ASSERT_FALSE(
+      PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel()));
+
+  // Relax tolerance for larger sum due to large K
+  EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
+}
+
+TEST_F(HopperMatmulTest, HSH_NN_UseScheduler) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  constexpr int64_t M = 2048, N = 2048, K = 8192;
+  const auto dtype = DataType::Half;
+
+  auto tv0 = makeContigConcreteTensor({1, -1, -1}, dtype); // K, M
+  auto tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype); // N, K
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  auto tv2 = fusedMultiplySum(tv0, tv1, {1});
+
+  // Reorder the accumulator as [M, N, K]
+  // [N, K, M] -> [M, N, K]
+  tv2->reorder({{-1, -3}});
+  tv2->commitLeafToLogical();
+
+  auto tv3 = castOp(DataType::Half, tv2);
+  fusion.addOutput(tv3);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+  auto a_ref = at::randn({1, K, M}, options);
+  auto b_ref = at::randn({N, K, 1}, options);
+  auto out_ref =
+      at::matmul(a_ref.squeeze().t(), b_ref.squeeze().t()).to(at::kHalf);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 256, 16);
+  gemm_tile.warp_tile = GemmTile(64, 256, 16);
+
+  MatmulParams mparams;
+  mparams.supported_vec_size = {8, 8, 8};
+  mparams.mma_macro = MmaMacro::Hopper_64_256_16;
+  mparams.tile_sizes = gemm_tile;
+  mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
+  mparams.async_gmem_load_operands = true;
+  mparams.circular_buffer_options.circular_buffer_smem_write = true;
+  mparams.circular_buffer_options.circular_buffer_smem_read = false;
+  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
+  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
+  mparams.splitk_factor = 1;
+  mparams.use_smem_epilogue = true;
+  mparams.cluster_dims = {2, 1, 1};
+  mparams.promote_prologue_smem_reuse = true;
+
+  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
+      ->schedule(&fusion, &mparams);
+
+  std::vector<c10::IValue> inputs = {a_ref, b_ref};
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty());
+  auto cg_outputs = ke.run(inputs);
+  ASSERT_FALSE(
+      PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel()));
+
+  // Relax tolerance for larger sum due to large K
+  EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
+}
+
+TEST_F(HopperMatmulTest, HSH_TT_UseScheduler) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  constexpr int64_t M = 2048, N = 2048, K = 8192;
+  const auto dtype = DataType::Half;
+
+  auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // M, K
+  auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // K, N
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  auto tv2 = fusedMultiplySum(tv0, tv1, {1});
+
+  // Reorder the accumulator as [M, N, K]
+  // [M, K, N] -> [M, N, K]
+  tv2->reorder({{-2, -1}});
+  tv2->commitLeafToLogical();
+
+  auto tv3 = castOp(DataType::Half, tv2);
+  fusion.addOutput(tv3);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+  auto a_ref = at::randn({M, K, 1}, options);
+  auto b_ref = at::randn({1, K, N}, options);
+  auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze()).to(at::kHalf);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 256, 16);
+  gemm_tile.warp_tile = GemmTile(64, 256, 16);
+
+  MatmulParams mparams;
+  mparams.supported_vec_size = {8, 8, 8};
+  mparams.mma_macro = MmaMacro::Hopper_64_256_16;
+  mparams.tile_sizes = gemm_tile;
+  mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
+  mparams.async_gmem_load_operands = true;
+  mparams.circular_buffer_options.circular_buffer_smem_write = true;
+  mparams.circular_buffer_options.circular_buffer_smem_read = false;
+  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
+  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
   mparams.splitk_factor = 1;
   mparams.use_smem_epilogue = true;
   mparams.cluster_dims = {2, 1, 1};