Enable BFloat16 in stmatrix #3633

Merged: 3 commits, Dec 23, 2024
8 changes: 4 additions & 4 deletions csrc/device_lower/pass/index.cpp
@@ -1686,8 +1686,8 @@ Val* hardCodedIndexGenerationForStMatrix(
   Val* out_index = nullptr;
 
   NVF_ERROR(
-      ldst->out()->dtype() == DataType::Half,
-      "we only support half type in stmatrix");
+      dataTypeSize(ldst->out()->dtype()) == 2,
+      "we only support 16-bit types in stmatrix");
 
   NVF_ERROR(ldst->out()->isA<TensorView>());
   TensorView* out_tv = ldst->out()->as<TensorView>();
@@ -1959,8 +1959,8 @@ Val* hardCodedIndexGenerationForStMatrixSwizzle(
       "size not currently supported for stmatrix");
 
   NVF_ERROR(
-      ldst->out()->dtype() == DataType::Half,
-      "we only support half type in stmatrix");
+      dataTypeSize(ldst->out()->dtype()) == 2,
+      "we only support 16-bit types in stmatrix");
 
   NVF_ERROR(ldst->out()->isA<TensorView>());
   TensorView* out_tv = ldst->out()->as<TensorView>();
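The two hunks above relax the same guard in both index-generation paths. A size check suffices because the PTX stmatrix instruction is typed .b16: it moves raw 16-bit lanes and never interprets them as fp16 or bf16. A minimal CUDA C++ sketch of that property (the helper name and the x1 variant are illustrative, not nvFuser code):

    #include <cuda_bf16.h>
    #include <cuda_fp16.h>
    #include <cstdint>

    // Both 16-bit float formats occupy the same two bytes in registers and
    // shared memory, which is all the relaxed checks above now require.
    static_assert(
        sizeof(__half) == 2 && sizeof(__nv_bfloat16) == 2,
        "stmatrix handles any 16-bit element");

    // Warp-level store of one 8x8 tile of 16-bit elements to shared memory.
    // 'smem' is a shared-space address (e.g. from __cvta_generic_to_shared);
    // 'v' packs two 16-bit lanes. Note the .b16 operand type: the
    // instruction moves raw bits, so fp16 vs. bf16 never enters the ISA.
    __device__ inline void stmatrix_m8n8_x1(uint32_t smem, uint32_t v) {
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
      asm volatile("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1};"
                   :
                   : "r"(smem), "r"(v));
    #endif
    }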
6 changes: 3 additions & 3 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -521,9 +521,9 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
 
       // Set LoadStoreOp
       // TODO: extend support when mma is not cast to half
-      NVF_ERROR(
-          dc->dtype() == DataType::Half,
-          "We support smem_epilogue on hopper only when the output of mma is cast to half");
+      NVF_CHECK(
+          dataTypeSize(dc->dtype()) == 2,
+          "We support use_smem_epilogue on Hopper only when the output is 16-bit");
 
       d_smem->definition()->as<LoadStoreOp>()->setOpType(
           LoadStoreOpType::StMatrix);
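Besides the relaxed dtype test, this hunk also swaps NVF_ERROR for NVF_CHECK. Assuming nvFuser follows the convention of PyTorch's TORCH_INTERNAL_ASSERT / TORCH_CHECK pair (an assumption worth verifying against nvFuser's exception helpers), the difference is attribution: an ERROR means the compiler broke an internal invariant, while a CHECK rejects a user-reachable configuration, here requesting use_smem_epilogue with a non-16-bit output. A stand-in sketch of that split:

    #include <stdexcept>

    // Stand-ins only; the real NVF_ERROR / NVF_CHECK are nvFuser macros.
    struct InternalError : std::logic_error {  // a bug inside the compiler
      using std::logic_error::logic_error;
    };
    struct UnsupportedInput : std::runtime_error {  // valid but unsupported
      using std::runtime_error::runtime_error;
    };

    #define SKETCH_ERROR(cond, msg) \
      do {                          \
        if (!(cond))                \
          throw InternalError(msg); \
      } while (0)

    #define SKETCH_CHECK(cond, msg)    \
      do {                             \
        if (!(cond))                   \
          throw UnsupportedInput(msg); \
      } while (0)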
5 changes: 3 additions & 2 deletions csrc/scheduler/mma_utils.cpp
@@ -1311,8 +1311,9 @@ void scheduleStMatrixForMmaOutput(
       ((tile_m == 16 && tile_n == 16) || (tile_m == 16 && tile_n == 8)),
       "We only support 16x16 and 16x16 stmatrix now");
 
-  NVF_ERROR(
-      tv->dtype() == DataType::Half, "we only support half type in stmatrix");
+  NVF_CHECK(
+      dataTypeSize(tv->dtype()) == 2,
+      "we only support 16-bit types in stmatrix");
 
   // [M, N] -> [128(TIDx), N/8 , 2 , 2]
   auto s =
14 changes: 9 additions & 5 deletions tests/cpp/test_memory.cpp
@@ -2811,7 +2811,7 @@ TEST_P(LdMatrixTest, Regular) {
 
 // We get shapes M and N from MmaMacrao. The vector of ints are
 // the tile_m and tile_n factors (8x8, 16x8 and 16x16).
-using StMatrixTestParams = std::tuple<MmaMacro, std::vector<int>>;
+using StMatrixTestParams = std::tuple<MmaMacro, std::vector<int>, DataType>;
 
 class StMatrixTest : public NVFuserFixtureParamTest<StMatrixTestParams> {
  protected:
@@ -2829,6 +2829,7 @@ TEST_P(StMatrixTest, Regular) {
 
   auto macro = std::get<0>(GetParam());
   auto tile_sizes = std::get<1>(GetParam());
+  auto dtype = std::get<2>(GetParam());
   auto sizeM = getM(macro);
   auto sizeN = getN(macro);
   int64_t tile_m = tile_sizes.at(0);
@@ -2843,7 +2844,7 @@ TEST_P(StMatrixTest, Regular) {
   fusion.manage("st_matrix_m", sizeM);
   fusion.manage("st_matrix_n", sizeN);
 
-  auto tv0 = makeContigConcreteTensor({sizeM, sizeN}, DataType::Half);
+  auto tv0 = makeContigConcreteTensor({sizeM, sizeN}, dtype);
   fusion.addInput(tv0);
   // tv0 (global) -> tv1 (registers)
   auto tv1 = set(tv0);
@@ -2871,7 +2872,8 @@ TEST_P(StMatrixTest, Regular) {
   tv3->split(0, 32);
   tv3->axis(1)->parallelize(ParallelType::TIDx);
 
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  auto options =
+      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
   auto t0 = at::randn({sizeM, sizeN}, options);
 
   KernelExecutor ke;
@@ -2886,13 +2888,14 @@ std::string testNameStMatrixTest(
   std::ostringstream os;
   auto macro = std::get<0>(info.param);
   auto tile_sizes = std::get<1>(info.param);
+  auto dtype = std::get<2>(info.param);
   auto sizeM = getM(macro);
   auto sizeN = getN(macro);
   auto tile_m = tile_sizes.at(0);
   auto tile_n = tile_sizes.at(1);
 
   os << "m_" << sizeM << "_n_" << sizeN << "_tile_m_" << tile_m << "_tile_n_"
-     << tile_n;
+     << tile_n << "_" << mma_utils::dtypeToChar(dtype);
   return os.str();
 }

@@ -2904,7 +2907,8 @@ INSTANTIATE_TEST_SUITE_P(
         testing::Values(
             // tile_m, tile_n
            std::vector<int>{16, 8},
-            std::vector<int>{16, 16})),
+            std::vector<int>{16, 16}),
+        testing::Values(DataType::Half, DataType::BFloat16)),
     testNameStMatrixTest);
 
 TEST_P(LdMatrixTest, Transpose) {
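The test_memory.cpp changes thread the new DataType axis through the parameter tuple, the test body, the name generator, and the instantiation. A self-contained gtest sketch of that Combine-plus-name-generator pattern, using illustrative stand-in types rather than nvFuser's:

    #include <gtest/gtest.h>

    #include <string>
    #include <tuple>

    enum class Dtype { Half, BFloat16 };

    using Params = std::tuple<int, Dtype>;  // tile size x element type

    class TileDtypeTest : public testing::TestWithParam<Params> {};

    TEST_P(TileDtypeTest, ElementSize) {
      auto [tile, dtype] = GetParam();
      // Every dtype on this axis is 16-bit, mirroring the stmatrix checks.
      int bytes = (dtype == Dtype::Half || dtype == Dtype::BFloat16) ? 2 : 4;
      EXPECT_EQ(bytes, 2);
      EXPECT_GT(tile, 0);
    }

    INSTANTIATE_TEST_SUITE_P(
        All,
        TileDtypeTest,
        testing::Combine(
            testing::Values(8, 16),
            testing::Values(Dtype::Half, Dtype::BFloat16)),
        // Encode every axis in the generated name, as testNameStMatrixTest
        // does, so single instantiations stay --gtest_filter-able.
        [](const testing::TestParamInfo<Params>& info) {
          auto [tile, dtype] = info.param;
          return "tile_" + std::to_string(tile) +
              (dtype == Dtype::Half ? "_H" : "_BF");
        });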
11 changes: 6 additions & 5 deletions tests/cpp/test_mma.cpp
@@ -545,11 +545,12 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) {
 
   auto cg_outputs = ke.run({inputs.first, inputs.second});
   auto tref = atMatmul(
-      inputs.first.squeeze().to(at::kFloat),
-      inputs.second.squeeze().to(at::kFloat),
-      layout);
+                  inputs.first.squeeze().to(at::kFloat),
+                  inputs.second.squeeze().to(at::kFloat),
+                  layout)
+                  .to(data_type_to_aten(dtype));
 
-  EXPECT_TRUE(at::allclose(cg_outputs[0], tref.to(at::kHalf), 1e-1, 1e-1));
+  EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-1, 1e-1));
 }

std::string testNameHopperRS(
Expand All @@ -569,7 +570,7 @@ INSTANTIATE_TEST_SUITE_P(
HopperRSStmatrix,
testing::Combine(
kAllHopperMacros,
testing::Values(DataType::Half),
testing::Values(DataType::Half, DataType::BFloat16),
testing::Values(MmaLayout::TN, MmaLayout::TT),
kAllSmemSwizzleModes,
testing::Values(
Expand Down
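The reference-tensor change in this file follows the usual recipe for validating low-precision matmuls: build the reference with fp32 accumulation, then cast it down to the dtype under test so both sides of the comparison are rounded identically (at::allclose also requires matching dtypes). A standalone ATen sketch of the pattern, with an illustrative helper name:

    #include <ATen/ATen.h>

    // True if 'out' (fp16 or bf16) matches an fp32-accumulated reference
    // after that reference is rounded to the same dtype. Loose tolerances
    // reflect the ~3 (fp16) and ~2 (bf16) decimal digits these formats keep.
    bool matchesLowPrecisionMatmul(
        const at::Tensor& a,
        const at::Tensor& b,
        const at::Tensor& out,
        at::ScalarType dtype) {
      auto ref = at::matmul(a.to(at::kFloat), b.to(at::kFloat)).to(dtype);
      return at::allclose(out, ref, /*rtol=*/1e-1, /*atol=*/1e-1);
    }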