
Merge remote-tracking branch 'origin/jjsjann123/rope_benchmark' into HEAD
jjsjann123 committed Dec 21, 2024
2 parents 1806368 + 93d28ed commit 9f55f22
Showing 1 changed file with 185 additions and 2 deletions: tests/cpp/test_matmul.cpp
@@ -4006,8 +4006,8 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler) {
constexpr int64_t M = 2048, N = 2048, K = 8192;
const auto dtype = DataType::Half;

-  auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
-  auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
+  auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M
+  auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N
fusion.addInput(tv0);
fusion.addInput(tv1);

@@ -4039,6 +4039,189 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler) {
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
mparams.use_smem_epilogue = true;
mparams.cluster_dims = {2, 1, 1};
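// (The {2, 1, 1} above presumably requests 2-CTA thread-block clusters, the
// Hopper feature that lets paired CTAs cooperate on operand loads.)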
mparams.promote_prologue_smem_reuse = true;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

std::vector<c10::IValue> inputs = {a_ref, b_ref};

KernelExecutor ke;
ke.compile(&fusion, inputs);
EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty());
auto cg_outputs = ke.run(inputs);
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel()));

// Relax tolerance for larger sum due to large K
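// (Each output element is a length-K dot product accumulated in fp32 and
// cast to half; worst-case rounding error grows roughly linearly in K, so
// both atol and rtol are scaled by K.)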
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}
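// The three tests below repeat the same schedule for the remaining operand
// layouts. Presumed naming, from the input orderings: "HSH" = Half inputs,
// Single-precision (fp32) accumulation, Half output; the two-letter suffix
// gives the A/B layouts (e.g. TN: A is laid out [M, K], B is [N, K]).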

TEST_F(HopperMatmulTest, HSH_TN_UseScheduler) {
Fusion fusion;
FusionGuard fg(&fusion);

constexpr int64_t M = 2048, N = 2048, K = 8192;
const auto dtype = DataType::Half;

auto tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype); // M, K
auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // N, K
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = fusedMultiplySum(tv0, tv1, {-1});

auto tv3 = castOp(DataType::Half, tv2);
fusion.addOutput(tv3);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a_ref = at::randn({M, 1, K}, options);
auto b_ref = at::randn({1, N, K}, options);
auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze().t()).to(at::kHalf);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 16);
gemm_tile.warp_tile = GemmTile(64, 256, 16);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
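// (The three entries above are presumably the supported vectorization widths
// for operand A, operand B, and the epilogue, respectively.)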
mparams.mma_macro = MmaMacro::Hopper_64_256_16;
mparams.tile_sizes = gemm_tile;
mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
mparams.use_smem_epilogue = true;
mparams.cluster_dims = {2, 1, 1};
mparams.promote_prologue_smem_reuse = true;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

std::vector<c10::IValue> inputs = {a_ref, b_ref};

KernelExecutor ke;
ke.compile(&fusion, inputs);
EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty());
auto cg_outputs = ke.run(inputs);
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel()));

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

TEST_F(HopperMatmulTest, HSH_NN_UseScheduler) {
Fusion fusion;
FusionGuard fg(&fusion);

constexpr int64_t M = 2048, N = 2048, K = 8192;
const auto dtype = DataType::Half;

auto tv0 = makeContigConcreteTensor({1, -1, -1}, dtype); // K, M
auto tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype); // N, K
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = fusedMultiplySum(tv0, tv1, {1});

// Reorder the accumulator as [M, N, K]
// [N, K, M] -> [M, N, K]
tv2->reorder({{-1, -3}});
tv2->commitLeafToLogical();
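// (reorder() maps old axis positions to new ones, so {{-1, -3}} moves the
// last axis to third-from-last; commitLeafToLogical() then, as the name
// suggests, re-commits this loop ordering as the tensor's logical domain.)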

auto tv3 = castOp(DataType::Half, tv2);
fusion.addOutput(tv3);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a_ref = at::randn({1, K, M}, options);
auto b_ref = at::randn({N, K, 1}, options);
auto out_ref =
at::matmul(a_ref.squeeze().t(), b_ref.squeeze().t()).to(at::kHalf);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 16);
gemm_tile.warp_tile = GemmTile(64, 256, 16);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
mparams.mma_macro = MmaMacro::Hopper_64_256_16;
mparams.tile_sizes = gemm_tile;
mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
mparams.use_smem_epilogue = true;
mparams.cluster_dims = {2, 1, 1};
mparams.promote_prologue_smem_reuse = true;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

std::vector<c10::IValue> inputs = {a_ref, b_ref};

KernelExecutor ke;
ke.compile(&fusion, inputs);
EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty());
auto cg_outputs = ke.run(inputs);
ASSERT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel()));

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

TEST_F(HopperMatmulTest, HSH_TT_UseScheduler) {
Fusion fusion;
FusionGuard fg(&fusion);

constexpr int64_t M = 2048, N = 2048, K = 8192;
const auto dtype = DataType::Half;

auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // M, K
auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // K, N
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = fusedMultiplySum(tv0, tv1, {1});

// Reorder the accumulator as [M, N, K]
// [M, K, N] -> [M, N, K]
tv2->reorder({{-2, -1}});
tv2->commitLeafToLogical();

auto tv3 = castOp(DataType::Half, tv2);
fusion.addOutput(tv3);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a_ref = at::randn({M, K, 1}, options);
auto b_ref = at::randn({1, K, N}, options);
auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze()).to(at::kHalf);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 16);
gemm_tile.warp_tile = GemmTile(64, 256, 16);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
mparams.mma_macro = MmaMacro::Hopper_64_256_16;
mparams.tile_sizes = gemm_tile;
mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
mparams.use_smem_epilogue = true;
mparams.cluster_dims = {2, 1, 1};
