diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index 7e8ee6dc4d7..c2b42dcf20b 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -216,6 +216,8 @@ class MatmulParams : public HeuristicParams { : "column-major") << "\n" << "Grid swizzle factor: " << grid_swizzle_factor << "\n" + << "Cluster dimensions: " << std::get<0>(cluster_dims) << " " + << std::get<1>(cluster_dims) << " " << std::get<2>(cluster_dims) << "\n" << "Use shared memory epilogue: " << use_smem_epilogue << "\n" << "Promote re-use of prologue shared memory: " << promote_prologue_smem_reuse << "\n" diff --git a/csrc/scheduler/matmul_heuristic_plugin.cpp b/csrc/scheduler/matmul_heuristic_plugin.cpp index 01333727841..ef0954f2185 100644 --- a/csrc/scheduler/matmul_heuristic_plugin.cpp +++ b/csrc/scheduler/matmul_heuristic_plugin.cpp @@ -141,6 +141,9 @@ void copyParamsToConfig(KernelConfig* config, const MatmulParams* mparams) { setConfigTile(config->cta_tile, mparams->tile_sizes.cta_tile); setConfigTile(config->warp_tile, mparams->tile_sizes.warp_tile); setConfigTile(config->instruction_tile, getMmaOpShape(mparams->mma_macro)); + config->cluster_dims[0] = std::get<0>(mparams->cluster_dims); + config->cluster_dims[1] = std::get<1>(mparams->cluster_dims); + config->cluster_dims[2] = std::get<2>(mparams->cluster_dims); config->splitk_factor = mparams->splitk_factor; config->grid_swizzle_factor = mparams->grid_swizzle_factor; config->cta_order = @@ -163,6 +166,9 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) { }; setGemmTile(mparams->tile_sizes.cta_tile, config->cta_tile); setGemmTile(mparams->tile_sizes.warp_tile, config->warp_tile); + std::get<0>(mparams->cluster_dims) = config->cluster_dims[0]; + std::get<1>(mparams->cluster_dims) = config->cluster_dims[1]; + std::get<2>(mparams->cluster_dims) = config->cluster_dims[2]; mparams->circular_buffer_options.smem_circular_buffer_stage = config->load_stages; mparams->circular_buffer_options.smem_circular_buffer_prefetch_gap = diff --git a/csrc/scheduler/matmul_heuristic_plugin_api.h b/csrc/scheduler/matmul_heuristic_plugin_api.h index 207da96e9a8..1cd028b6a0a 100644 --- a/csrc/scheduler/matmul_heuristic_plugin_api.h +++ b/csrc/scheduler/matmul_heuristic_plugin_api.h @@ -72,6 +72,7 @@ struct KernelConfig { Tile cta_tile = {128, 128, 32}; Tile warp_tile = {64, 64, 32}; Tile instruction_tile = {16, 16, 16}; + Tile cluster_dims = {1, 1, 1}; uint16_t splitk_factor = 1; uint8_t load_stages = 2; // The circular buffering prefetch distance will be set to