From 32cfa897c409656bafab4cdaea2cf10250bff7eb Mon Sep 17 00:00:00 2001 From: ckl117 Date: Mon, 11 Nov 2024 19:02:44 +0800 Subject: [PATCH] fp8_gemm_sm90 --- .../49_collective_builder.cu | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu index 1e820ddb47..a63a4e0f0b 100644 --- a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu +++ b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu @@ -171,12 +171,12 @@ struct Options { return; } - cmd.get_cmd_line_argument("m", m, 2048); - cmd.get_cmd_line_argument("n", n, 2048); - cmd.get_cmd_line_argument("k", k, 2048); + cmd.get_cmd_line_argument("m", m, 128); + cmd.get_cmd_line_argument("n", n, 512); + cmd.get_cmd_line_argument("k", k, 256); cmd.get_cmd_line_argument("l", l, 1); cmd.get_cmd_line_argument("alpha", alpha, 1.f); - cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("beta", beta, 1.f); } /// Prints the usage statement. @@ -206,7 +206,7 @@ bool initialize_block( cutlass::DeviceAllocation& block, uint64_t seed=2023) { - Element scope_max, scope_min; + double scope_max, scope_min; int bits_input = cutlass::sizeof_bits::value; if (bits_input == 1) { @@ -221,7 +221,7 @@ bool initialize_block( } cutlass::reference::device::BlockFillRandomUniform( - block.get(), block.size(), seed, scope_max, scope_min, 0); + block.get(), block.size(), seed, Element(scope_max), Element(scope_min), 0); return true; } @@ -263,11 +263,11 @@ struct ExampleRunner { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using LayoutD = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; - using ElementA = cutlass::half_t; - using ElementB = cutlass::half_t; + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; using ElementC = cutlass::half_t; using ElementD = cutlass::half_t; using ElementAccumulator = float; @@ -374,7 +374,7 @@ struct ExampleRunner { cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K})); cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N})); - cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({1, N})); cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N})); cutlass::reference::device::GemmComplex( @@ -391,7 +391,7 @@ struct ExampleRunner { L, // batch_count M * K, // batch_stride_A K * N, // batch_stride_B - M * N, // batch_stride_C + 1 * N, // batch_stride_C M * N // batch_stride_D ); @@ -415,12 +415,12 @@ struct ExampleRunner { stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(1, N, L)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); block_A.reset(M * K * L); block_B.reset(K * N * L); - block_C.reset(M * N * L); + block_C.reset(1 * N * L); block_D.reset(M * N * L); block_ref_D.reset(M * N * L);