Skip to content

Commit 4213b4f

Browse files
authored
fix moe stage2 when prebuild aot (#1018)
1 parent f03268a commit 4213b4f

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ MoeKernel moe_dispatch(std::string &kernelName, int block_m, int inter_dim, at::
3838
}
3939
else
4040
{
41-
return moe_stage2_heuristic_dispatch(block_m, inter_dim, x_dtype, w_dtype, y_dtype, act_op, quant_type, mul_routed_weight);
41+
return moe_stage2_heuristic_dispatch(block_m, inter_dim, x_dtype, w_dtype, y_dtype, 0, quant_type, mul_routed_weight);
4242
}
4343
}
4444

csrc/ck_gemm_moe_2stages_codegen/gen_instances.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
"""
6060

6161
heuristic_dispatch_end = """
62+
TORCH_CHECK(
63+
false,
64+
"Unsupported kernel config for moe heuristic dispatch");
6265
}}
6366
6467
"""
@@ -227,7 +230,6 @@
227230
if (dtype_checker<{A0DataType}>{{}}(x_dtype)
228231
&& dtype_checker<{B0DataType}>{{}}(w_dtype)
229232
&& dtype_checker<{EDataType}>{{}}(y_dtype)
230-
&& {ActOP} == act_op
231233
&& {MulRoutedWeight} == mul_routed_weight_stage
232234
&& {Quant} == quant)
233235
{{
@@ -261,7 +263,6 @@
261263
if (dtype_checker<{A0DataType}>{{}}(x_dtype)
262264
&& dtype_checker<{B0DataType}>{{}}(w_dtype)
263265
&& dtype_checker<{EDataType}>{{}}(y_dtype)
264-
&& {ActOP} == act_op
265266
&& {MulRoutedWeight} == mul_routed_weight_stage
266267
&& {Quant} == quant)
267268
{{
@@ -295,7 +296,6 @@
295296
if (dtype_checker<{A0DataType}>{{}}(x_dtype)
296297
&& dtype_checker<{B0DataType}>{{}}(w_dtype)
297298
&& dtype_checker<{EDataType}>{{}}(y_dtype)
298-
&& {ActOP} == act_op
299299
&& {MulRoutedWeight} == mul_routed_weight_stage
300300
&& {Quant} == quant)
301301
{{
@@ -331,7 +331,6 @@
331331
if (dtype_checker<{A0DataType}>{{}}(x_dtype)
332332
&& dtype_checker<{B0DataType}>{{}}(w_dtype)
333333
&& dtype_checker<{EDataType}>{{}}(y_dtype)
334-
&& {ActOP} == act_op
335334
&& {MulRoutedWeight} == mul_routed_weight_stage
336335
&& {Quant} == quant)
337336
{{
@@ -389,7 +388,6 @@
389388
if (dtype_checker<{A0DataType}>{{}}(x_dtype)
390389
&& dtype_checker<{B0DataType}>{{}}(w_dtype)
391390
&& dtype_checker<{EDataType}>{{}}(y_dtype)
392-
&& {ActOP} == act_op
393391
&& {MulRoutedWeight} == mul_routed_weight_stage
394392
&& {Quant} == quant)
395393
{{

0 commit comments

Comments
 (0)