 --trace=cuda,nvtx,cudnn,cublas \
 python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4
 
+# Example for jagged input benchmark to simulate unbalanced token splits
+python benchmarks/linear/benchmark_grouped_linear.py --recipe nvfp4 --jagged-input "15296,8960,14656,14784,11712,7936,14080,10880"
+
+# Example of targeting a single kernel with NCU, e.g. the fused Hadamard amax kernel for the NVFP4 recipe
+ncu -f -o ./benchmarks/linear/ncu_b200_numgemm_8_nvfp4_rht_amax \
+--set=full \
+--kernel-name "GroupHadamardAmaxTmaKernel" \
+-s 5 -c 5 \
+python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4
+
4858"""
4959
5060RECIPES = {
@@ -163,7 +173,7 @@ def benchmark_linear(
     return timing_ms
 
 
-def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
+def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits_provided=None):
     data = []
     assert not use_bias, "Bias is not supported for GroupedLinear benchmark"
 
@@ -173,12 +183,13 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
         x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
         ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
         assert m % num_gemms == 0
-        m_splits = [m // num_gemms] * num_gemms
+        m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided
         # Bias is not supported for GroupedLinear benchmark
         bias = None
 
         # Run the benchmark
         print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
+        print(f"m_splits: {m_splits}")
 
         grouped_fwd_bwd_timing_ms = benchmark_linear(
             x,
@@ -235,8 +246,35 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
     default="bf16",
     help="Recipe to use, options are fp8_sub_channel, mxfp8, bf16, or all",
 )
+# add an argument for the jagged input
+# example: [15296, 8960, 14656, 14784, 11712, 7936, 14080, 10880] => sums up to 98304
+parser.add_argument(
+    "--jagged-input",
+    type=str,
+    default=None,
+    help="Comma-separated jagged input splits, e.g. 15296,8960,14656,14784,11712,7936,14080,10880",
+)
+parser.add_argument(
+    "--hidden-dim",
+    type=int,
+    default=7168,
+    help="Hidden dimension to use, default is 7168",
+)
+parser.add_argument(
+    "--output-dim",
+    type=int,
+    default=2048,
+    help="Output dimension to use, default is 2048",
+)
 args = parser.parse_args()
 
+jagged_input_splits = None
+if args.jagged_input is not None:
+    jagged_input_splits = [int(x) for x in args.jagged_input.split(",")]
+    print(f"Jagged input splits: {jagged_input_splits}")
+    print(f"Jagged input splits sum: {sum(jagged_input_splits)}")
+    print(f"Jagged input splits num_gemms: {len(jagged_input_splits)}")
+
 use_bias = False
 # Set the MKN values to benchmark
 # Deepseek V3 EP64, SEQ_LEN=8192, topK8
@@ -256,11 +294,28 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
 # 4 or 8 local experts per rank
 num_gemms_list = [4, 8]
 
+if jagged_input_splits is not None:
+    num_gemms_list = [len(jagged_input_splits)]
+
+token_dim_list = [16384, 32768, 65536, 98304]
+hidden_dim_list = [7168]
+output_dim_list = [2048]
+
+# override the default targets to benchmark if specified
+if jagged_input_splits is not None:
+    token_dim_list = [sum(jagged_input_splits)]
+
+if args.hidden_dim is not None:
+    hidden_dim_list = [args.hidden_dim]
+
+if args.output_dim is not None:
+    output_dim_list = [args.output_dim]
+
 # MKN for group linear
 mkns = []
-for m in [65536]:
-    for k in [7168]:
-        for n in [2048]:
+for m in token_dim_list:
+    for k in hidden_dim_list:
+        for n in output_dim_list:
             mkns.append((m, k, n))
 
 # default recipes to run if not specified
@@ -272,14 +327,20 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
     recipe_list = [args.recipe]
 
 if args.profile:
-    mkns = [(8192 * 8, 7168, 2048)]
+    num_gemms_list = [8]
+    hidden_dim_to_profile = 7168 if args.hidden_dim is None else args.hidden_dim
+    output_dim_to_profile = 2048 if args.output_dim is None else args.output_dim
+    token_dim_to_profile = 8192 * 8
+    if jagged_input_splits is not None:
+        num_gemms_list = [len(jagged_input_splits)]
+        token_dim_to_profile = sum(jagged_input_splits)
+    mkns = [(token_dim_to_profile, hidden_dim_to_profile, output_dim_to_profile)]
     # in profile mode, only run one recipe specified in args.recipe
     assert args.recipe != "all", (
         "In profile mode, only one recipe can be specified, please specify the recipe as"
         " fp8_sub_channel, mxfp8, nvfp4, or bf16"
     )
     recipe_list = [args.recipe]
-    num_gemms_list = [8]
     torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
 
 # Initialize a dataframe to store the results
@@ -310,6 +371,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
             recipe_name,
             use_bias,
             num_gemms=num_gemms,
+            m_splits_provided=jagged_input_splits,
         )
         df_linears = pd.concat([df_linears, df])
 
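For reference (not part of the diff): a minimal standalone sketch of how the new --jagged-input string is consumed, mirroring the parsing added above. The parse_jagged_splits helper name is illustrative only and is not defined by the benchmark.

# Illustrative sketch; mirrors the --jagged-input handling in the diff above.
def parse_jagged_splits(jagged_input: str) -> list[int]:
    # Parse a comma-separated list of per-GEMM token counts, e.g. "15296,8960,...".
    splits = [int(x) for x in jagged_input.split(",")]
    assert all(s > 0 for s in splits), "each split must be a positive token count"
    return splits

splits = parse_jagged_splits("15296,8960,14656,14784,11712,7936,14080,10880")
print(len(splits))  # 8 splits -> num_gemms_list = [8]
print(sum(splits))  # 98304 total tokens -> token_dim_list = [98304]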