
Commit 5a2184c

Add support for 32B and 64B swizzles to hopper matmul scheduler (#3544)
This PR adds support for 32B and 64B swizzles to StMatrix indexing and to the hopper matmul scheduler.

### Key Index Change

The number of distinct swizzle rows is the number of bytes in the swizzle divided by the megabank size (16B). The number of times the swizzle pattern is repeated to fill the core (8, 8) matrix is the number of swizzle rows (8) divided by the number of distinct rows.

```cpp
MmaInputSmemSwizzle swizzle = getSwizzle(out_tv);
int64_t swizzle_bytes = getBytesFromSwizzle(swizzle);
constexpr int64_t megabank_size_bytes = 16;
const int64_t distinct_swizzle_row_size = swizzle_bytes / megabank_size_bytes;

int row = ...;
int col = ...;

constexpr int64_t swizzle_row_size = 8;
const int64_t swizzle_row_repetitions = swizzle_row_size / distinct_swizzle_row_size;
int64_t row_in_swizzle_pattern = (row % swizzle_row_size) / swizzle_row_repetitions;
int64_t swizzle_col = col ^ row_in_swizzle_pattern;
```

### Testing Changes

* Added `mma_macro` as a testing value.
* Created a separate test suite called `Swizzle/HopperMatmulSchedulerTest` to test `32B`, `64B`, and `128B` swizzles.
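As a quick sanity check of the arithmetic above, here is a small standalone sketch (sample values for `row` and `col`, 2-byte half elements assumed; not part of the PR) that evaluates the formula for the three supported swizzle widths:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t megabank_size_bytes = 16;
  constexpr int64_t swizzle_row_size = 8;
  // Sample coordinate inside a core (8, 8) matrix; any values in range work.
  const int64_t row = 5;
  const int64_t col = 3;
  for (int64_t swizzle_bytes : {32, 64, 128}) {
    const int64_t distinct_rows = swizzle_bytes / megabank_size_bytes;
    const int64_t repetitions = swizzle_row_size / distinct_rows;
    const int64_t row_in_pattern = (row % swizzle_row_size) / repetitions;
    const int64_t swizzle_col = col ^ row_in_pattern;
    std::cout << swizzle_bytes << "B: distinct_rows=" << distinct_rows
              << " repetitions=" << repetitions
              << " swizzle_col=" << swizzle_col << "\n";
  }
  return 0;
}
```

For 128B this reduces to the existing behavior (all 8 rows distinct, no repetition), while 64B and 32B repeat each distinct row 2 and 4 times respectively.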
1 parent 07be8b9 commit 5a2184c

5 files changed (+144, -42 lines)

csrc/device_lower/pass/index.cpp (+45, -21)
```diff
@@ -1559,15 +1559,15 @@ void IndexLowering::handleCpAsyncBulkStore(const LoadStoreOp* ldst) {
 }
 
 static DataType getMmaInputAType(MmaMacro macro) {
-  int warp_group_size = isHopper(macro) ? 128 : 32;
-  int size = getM(macro) * getK(macro) / warp_group_size /
-      2 /* halves per 32bit register */;
+  int64_t warp_group_size = isHopper(macro) ? 128L : 32L;
+  int64_t size = getM(macro) * getK(macro) / warp_group_size /
+      2L /* halves per 32bit register */;
   return ArrayType{std::make_shared<DataType>(DataType::UInt32), (size_t)size};
 }
 
 static DataType getMmaInputBType(MmaMacro macro) {
-  int size = getN(macro) * getK(macro) / 32 /* threads per warp */ /
-      2 /* halves per 32bit register */;
+  int64_t size = getN(macro) * getK(macro) / 32L /* threads per warp */ /
+      2L /* halves per 32bit register */;
   return ArrayType{std::make_shared<DataType>(DataType::UInt32), (size_t)size};
 }
 
@@ -1842,8 +1842,8 @@ Val* hardCodedIndexGenerationForStMatrix(
 // To account for the threadIdx.y, we have to add it to the offset:
 // offset_from_tdy = threadIdx.y * tma_m * tma_n * 2 (half)
 //
-// Now, lets apply stmatrix tile to the TMA Box.
-// [NO(2), MO(4), MI(16), NIO(4), NII(16)].
+// Now, lets apply stmatrix tile (16, 16) to the TMA Box [NO(2), M(64), NI(64)].
+// [NO(2), MO(4), MI(16), NIO(4), NII(16)].
 //
 // A warp group of 128 threads contains four warps. StMatrix is a warp-level
 // operation, so four StMatrix operations can be issued simultaneously by the
@@ -1865,6 +1865,7 @@ Val* hardCodedIndexGenerationForStMatrix(
 // domain is scheduled as [NO(2), M(64), NI(64)]. Therefore, we must store the
 // data in shared memory in [M(64), NI(64)] contiguous tiles.
 //
+// NOTE: This offset is skipped if for-loop is trivial
 // To account for the outer_index, we have to add it to the offset:
 // offset_from_outer_index = outer_index * tma_m * NI(64) * 2 (half)
 //
@@ -1928,8 +1929,13 @@ Val* hardCodedIndexGenerationForStMatrix(
 // with the 8 rows of the matrix to avoid bank conflicts. This swizzle pattern
 // is repeated along the rows of the TMA box.
 //
+// The number of distinct swizzle rows is number of bytes for swizzle divided by
+// size of megabank (16B). The number of times a swizzle pattern is repeated to
+// fill core (8, 8) matrix is number of swizzle rows (8) divided by number of
+// distinct rows.
+//
 // Swizzle column
-// row_in_swizzle_pattern = row % swizzle_row_size(8)
+// row_in_swizzle_pattern = (row % swizzle_row_size(8)) / swizzle_repetitions
 // swizzle_col = column XOR row_in_swizzle_pattern
 //
 // Calculate Tile Offset
@@ -1939,7 +1945,7 @@ Val* hardCodedIndexGenerationForStMatrix(
 //
 // Get shared memory offset
 // smem_offset = offset_from_tdy + offset_from_outer_index + tile_offset
-Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
+Val* hardCodedIndexGenerationForStMatrixSwizzle(
     const LoadStoreOp* ldst,
    ForLoop* loop,
    const int64_t stsm_m_tile,
@@ -1958,16 +1964,19 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
 
   NVF_ERROR(ldst->out()->isA<TensorView>());
   TensorView* out_tv = ldst->out()->as<TensorView>();
-  NVF_ERROR(getSwizzle(out_tv) == MmaInputSmemSwizzle::B128);
+  MmaInputSmemSwizzle swizzle = getSwizzle(out_tv);
+  int64_t swizzle_bytes = getBytesFromSwizzle(swizzle);
 
   // Constants
   constexpr int64_t dtype_size = 2;
   constexpr int64_t warp_size = 32;
   constexpr int64_t swizzle_row_size = 8;
   constexpr int64_t stsm_column_size = 8;
-  constexpr int64_t swizzle_n_tile = 64;
+  constexpr int64_t megabank_size_bytes = 16;
 
   // Derived constants
+  const int64_t swizzle_n_tile = swizzle_bytes / dtype_size;
+  const int64_t distinct_swizzle_row_size = swizzle_bytes / megabank_size_bytes;
   constexpr int64_t stsm_column_stride = stsm_column_size * dtype_size;
   const int64_t swizzle_n_iter = swizzle_n_tile / stsm_n_tile;
   const int64_t swizzle_n_tile_stride = swizzle_n_tile * dtype_size;
@@ -2000,8 +2009,6 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
   Val* warp_id = SimplifyingIrBuilder::divExpr(TDX, warp_size_val);
   Val* lane_id = SimplifyingIrBuilder::modExpr(TDX, warp_size_val);
 
-  Val* outer_index =
-      SimplifyingIrBuilder::divExpr(loop->index(), swizzle_n_iter_val);
   Val* inner_index =
       SimplifyingIrBuilder::modExpr(loop->index(), swizzle_n_iter_val);
 
@@ -2021,6 +2028,17 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
   // Swizzle Column
   Val* row_in_swizzle_pattern =
       SimplifyingIrBuilder::modExpr(row, swizzle_row_size_val);
+
+  // The swizzle pattern is repeated to fill (8, 8) matrix for 64B and 32B
+  // swizzles. swizzle_row_iter is the number of repetitions to fill 8 rows
+  // with distict swizzle rows.
+  const int64_t swizzle_row_iter = swizzle_row_size / distinct_swizzle_row_size;
+  if (swizzle_row_iter > 1) {
+    Val* swizzle_row_iter_val =
+        IrBuilder::create<Val>(swizzle_row_iter, DataType::Index);
+    row_in_swizzle_pattern = SimplifyingIrBuilder::divExpr(
+        row_in_swizzle_pattern, swizzle_row_iter_val);
+  }
   Val* swizzle_col = bitwise_xor(col, row_in_swizzle_pattern);
 
   // Calculate Tile Offset
@@ -2031,16 +2049,22 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
   Val* offset = SimplifyingIrBuilder::addExpr(row_offset, col_offset);
 
   // Calculate Tile offset
-  Val* tile_offset = IrBuilder::mulExpr(outer_index, tile_stride_val);
+  // Skip tile offset if loop is trivial.
+  if (!loop->stop()->isOneInt()) {
+    Val* outer_index =
+        SimplifyingIrBuilder::divExpr(loop->index(), swizzle_n_iter_val);
+    Val* tile_offset =
+        SimplifyingIrBuilder::mulExpr(outer_index, tile_stride_val);
+    offset = SimplifyingIrBuilder::addExpr(tile_offset, offset);
+  }
 
   // Calculate TDY offset
-  Val* tdy_offset = IrBuilder::mulExpr(TDY, tdy_stride_val);
+  Val* tdy_offset = SimplifyingIrBuilder::mulExpr(TDY, tdy_stride_val);
+  offset = SimplifyingIrBuilder::addExpr(tdy_offset, offset);
 
   // Create shared memory TensorIndex
   Val* out_index = SimplifyingIrBuilder::addExpr(
-      IrBuilder::baseAddressExpr(ir_utils::getTvOutput(ldst)),
-      SimplifyingIrBuilder::addExpr(
-          tdy_offset, SimplifyingIrBuilder::addExpr(tile_offset, offset)));
+      IrBuilder::baseAddressExpr(ir_utils::getTvOutput(ldst)), offset);
   Val* out = IrBuilder::create<kir::TensorIndex>(
       dynamic_cast<TensorView*>(ldst->out()), out_index);
   return out;
@@ -2092,11 +2116,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
             ldst, for_loops_[0], m_tile, n_tile, m, n);
        break;
      case MmaInputSmemSwizzle::B128:
-        out = hardCodedIndexGenerationForStMatrix128BSwizzle(
+      case MmaInputSmemSwizzle::B64:
+      case MmaInputSmemSwizzle::B32:
+        out = hardCodedIndexGenerationForStMatrixSwizzle(
            ldst, for_loops_[0], m_tile, n_tile, m, n);
        break;
-      case MmaInputSmemSwizzle::B32:
-      case MmaInputSmemSwizzle::B64:
      default:
        NVF_ERROR("Unsupported Swizzle Type for StMatrix");
    }
```
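For reference, a minimal sketch (not from the PR; assumes 2-byte elements, matching the `dtype_size` constant in this function) of how the two new derived constants vary with the swizzle width:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t dtype_size = 2; // half precision
  constexpr int64_t megabank_size_bytes = 16;
  for (int64_t swizzle_bytes : {32, 64, 128}) {
    // Elements covered by one swizzle period along N.
    const int64_t swizzle_n_tile = swizzle_bytes / dtype_size;
    // Distinct rows in the swizzle pattern before it repeats.
    const int64_t distinct_swizzle_row_size = swizzle_bytes / megabank_size_bytes;
    std::cout << swizzle_bytes << "B swizzle: swizzle_n_tile=" << swizzle_n_tile
              << ", distinct rows=" << distinct_swizzle_row_size << "\n";
  }
  return 0;
}
```

The 128B case reproduces the previously hard-coded `swizzle_n_tile = 64`.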

csrc/scheduler/hopper_multi_matmul.cpp (+7, -11)
```diff
@@ -1027,13 +1027,6 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
   const int64_t tma_m = getM(params_->mma_macro);
   const int64_t tma_n = getN(params_->mma_macro);
 
-  NVF_ERROR(
-      tma_n >= 64,
-      "Scheduler only supports 128B swizzle that requires N dimension of MMA ",
-      "macro to be >= 64, but received ",
-      tma_n,
-      ".");
-
   fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
   fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
   fusion_->manage("st_matrix_m", tma_m);
@@ -1084,12 +1077,14 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
        dc->setAllocationDomain(s.as<IterDomain*>(), true);
      }
 
+      MmaInputSmemSwizzle swizzle = tmaSwizzleSharedMemory(d_smem);
+
      // Schedule shared memory cache; Output from StMatrix
      scheduleStMatrixForMmaOutput(
-          d_smem, stmatrix_tile_m, stmatrix_tile_n, tma_m, tma_n);
+          d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n, tma_m, tma_n);
 
      // Schedule global memory output; Output from TMA Store
-      scheduleTMAStoreForMmaOutput(d, tma_m, tma_n);
+      scheduleTMAStoreForMmaOutput(d, swizzle, tma_m, tma_n);
    }
  }
 }
@@ -1247,6 +1242,7 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() {
 
 void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput(
     TensorView* tv,
+    MmaInputSmemSwizzle swizzle,
     int64_t tile_m,
     int64_t tile_n,
     int64_t tma_m,
@@ -1263,7 +1259,7 @@ void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput(
       mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain());
 
   // Create tma store allocation domain with swizzle
-  scheduleTMAStoreForMmaOutput(tv, tma_m, tma_n);
+  scheduleTMAStoreForMmaOutput(tv, swizzle, tma_m, tma_n);
 
   tv->setLoopDomain(s.as<IterDomain*>());
 
@@ -1290,6 +1286,7 @@ void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput(
 
 void HopperMultipleMatmulScheduler::scheduleTMAStoreForMmaOutput(
     TensorView* tv,
+    MmaInputSmemSwizzle swizzle,
     int64_t m,
     int64_t n) {
   // [M(m), N(n)] -> [MO(1), MI(m), NO(1), NI(n)]
@@ -1301,7 +1298,6 @@ void HopperMultipleMatmulScheduler::scheduleTMAStoreForMmaOutput(
   // [BDX, BDY, TDY, MO(1), NO(1), MI, NI]
   // skip the first 5 iterDomains
   int64_t num_ids_to_skip = 5;
-  MmaInputSmemSwizzle swizzle = MmaInputSmemSwizzle::B128;
 
   NVF_ERROR(num_ids_to_skip >= 0);
   if (swizzle == MmaInputSmemSwizzle::None) {
```

csrc/scheduler/hopper_multi_matmul.h (+6, -1)
```diff
@@ -182,14 +182,19 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
   //! registers to shared memory.
   void scheduleStMatrixForMmaOutput(
       TensorView* tv,
+      MmaInputSmemSwizzle swizzle,
       int64_t tile_m,
       int64_t tile_n,
       int64_t tma_m,
       int64_t tma_n);
 
   //! Schedules the copy operation of output of a Mma op which resided in the
   //! shared memory to global memory.
-  void scheduleTMAStoreForMmaOutput(TensorView* tv, int64_t m, int64_t n);
+  void scheduleTMAStoreForMmaOutput(
+      TensorView* tv,
+      MmaInputSmemSwizzle swizzle,
+      int64_t m,
+      int64_t n);
 
   // Map TensorView's iterDomain to its ValGroup.
   // Then, find the MatmulDimRole for the ValGroup.
```

tests/cpp/test_matmul_scheduler.cpp (+52, -9)
```diff
@@ -3119,32 +3119,51 @@ using HopperMatmulSchedulerTestParams = std::tuple<
     bool, // b_k_inner
     int64_t, // M
     int64_t, // N
-    int64_t // K
-    >;
+    int64_t, // K
+    MmaMacro>;
 
 std::string hopperTestName(
     const testing::TestParamInfo<HopperMatmulSchedulerTestParams>& info) {
   std::ostringstream os;
   bool use_smem_epilogue;
   bool a_k_inner, b_k_inner;
   int64_t M, N, K;
-  std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K) = info.param;
+  MmaMacro mma_macro;
+  std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) =
+      info.param;
   os << (a_k_inner ? "K" : "M");
   os << (b_k_inner ? "K" : "N");
   os << "_" << M << "_" << N << "_" << K;
+  os << "_MmaMacro_" << mma_macro_to_str_map.at(mma_macro);
   if (use_smem_epilogue) {
     os << "_tma_store";
   }
   return os.str();
 }
 
+std::string hopperTestNameSwizzle(
+    const testing::TestParamInfo<HopperMatmulSchedulerTestParams>& info) {
+  std::unordered_map<MmaMacro, std::string> mma_macro_to_swizzle_str_map = {
+      {MmaMacro::Hopper_64_256_16, "128BSwizzle"},
+      {MmaMacro::Hopper_64_128_16, "128BSwizzle"},
+      {MmaMacro::Hopper_64_64_16, "128BSwizzle"},
+      {MmaMacro::Hopper_64_32_16, "64BSwizzle"},
+      {MmaMacro::Hopper_64_16_16, "32BSwizzle"}};
+  MmaMacro mma_macro = std::get<6>(info.param);
+  std::ostringstream os;
+  os << hopperTestName(info);
+  os << "_" << mma_macro_to_swizzle_str_map.at(mma_macro);
+  return os.str();
+}
+
 class HopperMatmulSchedulerTest
     : public NVFuserFixtureParamTest<HopperMatmulSchedulerTestParams> {
  protected:
   void SetUp() {
     NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0);
 
-    std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K) = GetParam();
+    std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) =
+        GetParam();
 
     if (a_k_inner) {
       layout = b_k_inner ? MmaLayout::TN : MmaLayout::TT;
@@ -3159,14 +3178,17 @@ class HopperMatmulSchedulerTest
     // Create custom Matmul Params
     MatMulTileOptions gemm_tile;
     // TODO cta tile is a multiple of mma macro for hopper.
-    gemm_tile.cta_tile = GemmTile(128, 256, 16);
+    // Default cta_tile configuration is 2-CTA.
+    gemm_tile.cta_tile =
+        GemmTile(2 * getM(mma_macro), getN(mma_macro), getK(mma_macro));
 
     // TODO warp tile is (macroM, macroN, macroK) for hopper.
-    gemm_tile.warp_tile = GemmTile(64, 128, 16);
+    gemm_tile.warp_tile =
+        GemmTile(getM(mma_macro), getN(mma_macro), getK(mma_macro));
 
     mparams.supported_vec_size = {8, 8, 4};
 
-    mparams.mma_macro = MmaMacro::Hopper_64_128_16;
+    mparams.mma_macro = mma_macro;
 
     mparams.use_smem_epilogue = use_smem_epilogue;
 
@@ -3203,6 +3225,7 @@ class HopperMatmulSchedulerTest
   bool use_smem_epilogue;
   bool a_k_inner, b_k_inner;
   int64_t M, N, K;
+  MmaMacro mma_macro;
   std::unique_ptr<Fusion> fusion_up;
   Fusion* fusion;
   std::unique_ptr<FusionGuard> fusion_guard;
@@ -3275,16 +3298,36 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) {
 }
 
 INSTANTIATE_TEST_SUITE_P(
-    ,
+    General,
     HopperMatmulSchedulerTest,
     testing::Combine(
         testing::Bool(), // use_smem_epilogue
         testing::Bool(), // a_k_inner
        testing::Bool(), // b_k_inner
        testing::Values(512), // M
        testing::Values(256), // N
-        testing::Values(64) // K
+        testing::Values(64), // K
+        testing::Values(MmaMacro::Hopper_64_128_16) // mma_macros
    ),
    hopperTestName);
 
+INSTANTIATE_TEST_SUITE_P(
+    Swizzle,
+    HopperMatmulSchedulerTest,
+    testing::Combine(
+        testing::Values(true), // use_smem_epilogue
+        testing::Bool(), // a_k_inner
+        testing::Bool(), // b_k_inner
+        testing::Values(512), // M
+        testing::Values(256), // N
+        testing::Values(64), // K
+        testing::Values(
+            MmaMacro::Hopper_64_256_16,
+            MmaMacro::Hopper_64_128_16,
+            MmaMacro::Hopper_64_64_16,
+            MmaMacro::Hopper_64_32_16,
+            MmaMacro::Hopper_64_16_16) // mma_macros
+    ),
+    hopperTestNameSwizzle);
+
 } // namespace nvfuser
```
