Skip to content

Commit

Permalink
Allow matmul heuristic plugin to set cluster dimensions (#3634)
Browse files Browse the repository at this point in the history
This allows heuristic plugins to set cluster dimensions. By default, the
cluster dims are set to {1, 1, 1}, which disables this feature, so a
plugin that does not explicitly handle cluster dims will simply not
make use of it. This default is necessary because setting the cluster
dims to an invalid value can cause a launch failure. This PR also prints
the cluster dims in `MatmulParams::toString`.
  • Loading branch information
jacobhinkle authored Dec 23, 2024
1 parent 99fb12b commit 02ffc83
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
2 changes: 2 additions & 0 deletions csrc/scheduler/matmul_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ class MatmulParams : public HeuristicParams {
: "column-major")
<< "\n"
<< "Grid swizzle factor: " << grid_swizzle_factor << "\n"
<< "Cluster dimensions: " << std::get<0>(cluster_dims) << " "
<< std::get<1>(cluster_dims) << " " << std::get<2>(cluster_dims) << "\n"
<< "Use shared memory epilogue: " << use_smem_epilogue << "\n"
<< "Promote re-use of prologue shared memory: "
<< promote_prologue_smem_reuse << "\n"
Expand Down
6 changes: 6 additions & 0 deletions csrc/scheduler/matmul_heuristic_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ void copyParamsToConfig(KernelConfig* config, const MatmulParams* mparams) {
setConfigTile(config->cta_tile, mparams->tile_sizes.cta_tile);
setConfigTile(config->warp_tile, mparams->tile_sizes.warp_tile);
setConfigTile(config->instruction_tile, getMmaOpShape(mparams->mma_macro));
config->cluster_dims[0] = std::get<0>(mparams->cluster_dims);
config->cluster_dims[1] = std::get<1>(mparams->cluster_dims);
config->cluster_dims[2] = std::get<2>(mparams->cluster_dims);
config->splitk_factor = mparams->splitk_factor;
config->grid_swizzle_factor = mparams->grid_swizzle_factor;
config->cta_order =
Expand All @@ -163,6 +166,9 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) {
};
setGemmTile(mparams->tile_sizes.cta_tile, config->cta_tile);
setGemmTile(mparams->tile_sizes.warp_tile, config->warp_tile);
std::get<0>(mparams->cluster_dims) = config->cluster_dims[0];
std::get<1>(mparams->cluster_dims) = config->cluster_dims[1];
std::get<2>(mparams->cluster_dims) = config->cluster_dims[2];
mparams->circular_buffer_options.smem_circular_buffer_stage =
config->load_stages;
mparams->circular_buffer_options.smem_circular_buffer_prefetch_gap =
Expand Down
1 change: 1 addition & 0 deletions csrc/scheduler/matmul_heuristic_plugin_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ struct KernelConfig {
Tile cta_tile = {128, 128, 32};
Tile warp_tile = {64, 64, 32};
Tile instruction_tile = {16, 16, 16};
Tile cluster_dims = {1, 1, 1};
uint16_t splitk_factor = 1;
uint8_t load_stages = 2;
// The circular buffering prefetch distance will be set to
Expand Down

0 comments on commit 02ffc83

Please sign in to comment.