Commit ff268af

squashed
Signed-off-by: Zhongbo Zhu <[email protected]>
1 parent f8cb598 commit ff268af

File tree

16 files changed: +2896 -242 lines


benchmarks/linear/benchmark_grouped_linear.py

Lines changed: 69 additions & 7 deletions
@@ -45,6 +45,16 @@
     --trace=cuda,nvtx,cudnn,cublas \
     python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4
 
+# Example for jagged input benchmark to simulate unbalanced token splits
+python benchmarks/linear/benchmark_grouped_linear.py --recipe nvfp4 --jagged-input "15296,8960,14656,14784,11712,7936,14080,10880"
+
+# Example to look at a single kernel target with NCU, like the fused Hadamard amax kernel for the NVFP4 recipe
+ncu -f -o ./benchmarks/linear/ncu_b200_numgemm_8_nvfp4_rht_amax \
+    --set=full \
+    --kernel-name "GroupHadamardAmaxTmaKernel" \
+    -s 5 -c 5 \
+    python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4
+
 """
 
 RECIPES = {
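
The --jagged-input example above encodes one token count per local expert. A quick standalone check (not part of this commit) that the example splits add up to the 98304-token batch used elsewhere in the benchmark, i.e. eight experts averaging 12288 tokens each:

    # Verify the example --jagged-input splits (illustration only).
    splits = [15296, 8960, 14656, 14784, 11712, 7936, 14080, 10880]
    assert sum(splits) == 98304        # total token dim across all experts
    assert len(splits) == 8            # one entry per local expert (num_gemms)
    print(sum(splits) // len(splits))  # 12288 tokens per expert on average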
@@ -163,7 +173,7 @@ def benchmark_linear(
     return timing_ms
 
 
-def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
+def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4, m_splits_provided=None):
     data = []
     assert not use_bias, "Bias is not supported for GroupedLinear benchmark"
@@ -173,12 +183,13 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
         x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
         ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
         assert m % num_gemms == 0
-        m_splits = [m // num_gemms] * num_gemms
+        m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided
         # Bias is not supported for GroupedLinear benchmark
         bias = None
 
         # Run the benchmark
         print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
+        print(f"m_splits: {m_splits}")
 
         grouped_fwd_bwd_timing_ms = benchmark_linear(
             x,
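
Here m_splits describes how the (m, k) activation is carved along the token dimension, one contiguous slice per expert, whether the splits are balanced or jagged. A minimal sketch of what the split list means (illustration only, using a small hidden dim in place of the real 7168):

    import torch

    m, num_gemms = 98304, 8
    k = 64  # small stand-in for the real hidden dim of 7168
    x = torch.randn(m, k, dtype=torch.bfloat16)

    balanced = [m // num_gemms] * num_gemms
    jagged = [15296, 8960, 14656, 14784, 11712, 7936, 14080, 10880]
    for m_splits in (balanced, jagged):
        chunks = torch.split(x, m_splits, dim=0)  # one (m_i, k) slice per expert
        assert sum(c.shape[0] for c in chunks) == m
        print([c.shape[0] for c in chunks])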
@@ -235,8 +246,35 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
         default="bf16",
         help="Recipe to use, options are fp8_sub_channel, mxfp8, bf16, or all",
     )
+    # add an argument for the jagged input
+    # example: [15296, 8960, 14656, 14784, 11712, 7936, 14080, 10880] => sums up to 98304
+    parser.add_argument(
+        "--jagged-input",
+        type=str,
+        default=None,
+        help="Jagged input to use, example: [15296, 8960, 14656, 14784, 11712, 7936, 14080, 10880]",
+    )
+    parser.add_argument(
+        "--hidden-dim",
+        type=int,
+        default=7168,
+        help="Hidden dimension to use, default is 7168",
+    )
+    parser.add_argument(
+        "--output-dim",
+        type=int,
+        default=2048,
+        help="Output dimension to use, default is 2048",
+    )
     args = parser.parse_args()
 
+    jagged_input_splits = None
+    if args.jagged_input is not None:
+        jagged_input_splits = [int(x) for x in args.jagged_input.split(",")]
+        print(f"Jagged input splits: {jagged_input_splits}")
+        print(f"Jagged input splits sum: {sum(jagged_input_splits)}")
+        print(f"Jagged input splits num_gemms: {len(jagged_input_splits)}")
+
     use_bias = False
     # Set the MKN values to benchmark
     # Deepseek V3 EP64, SEQ_LEN=8192, topK8
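
To exercise this flag with imbalance patterns other than the hand-picked example, a split string can also be sampled. A hypothetical helper (not in this commit) that draws random per-expert counts summing exactly to a fixed token budget:

    import torch

    def random_jagged_splits(total_tokens: int, num_experts: int, seed: int = 0) -> str:
        """Sample per-expert token counts that sum exactly to total_tokens (illustration only)."""
        g = torch.Generator().manual_seed(seed)
        weights = torch.rand(num_experts, generator=g)
        counts = (weights / weights.sum() * total_tokens).long()
        counts[-1] += total_tokens - counts.sum()  # absorb rounding error in the last expert
        return ",".join(str(int(c)) for c in counts)

    # The result can be passed to --jagged-input.
    print(random_jagged_splits(98304, 8))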
@@ -256,11 +294,28 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
     # 4 or 8 local experts per rank
     num_gemms_list = [4, 8]
 
+    if jagged_input_splits is not None:
+        num_gemms_list = [len(jagged_input_splits)]
+
+    token_dim_list = [16384, 32768, 65536, 98304]
+    hidden_dim_list = [7168]
+    output_dim_list = [2048]
+
+    # override the default targets to benchmark if specified
+    if jagged_input_splits is not None:
+        token_dim_list = [sum(jagged_input_splits)]
+
+    if args.hidden_dim is not None:
+        hidden_dim_list = [args.hidden_dim]
+
+    if args.output_dim is not None:
+        output_dim_list = [args.output_dim]
+
     # MKN for group linear
     mkns = []
-    for m in [65536]:
-        for k in [7168]:
-            for n in [2048]:
+    for m in token_dim_list:
+        for k in hidden_dim_list:
+            for n in output_dim_list:
                 mkns.append((m, k, n))
 
     # default recipes to run if not specified
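
With the defaults above, the nested loops sweep four token dims against a single hidden/output dim. The same cartesian product written standalone (equivalent form, illustration only):

    from itertools import product

    token_dim_list = [16384, 32768, 65536, 98304]
    hidden_dim_list = [7168]
    output_dim_list = [2048]

    # Same (m, k, n) tuples the nested loops above produce.
    mkns = list(product(token_dim_list, hidden_dim_list, output_dim_list))
    print(mkns)  # [(16384, 7168, 2048), ..., (98304, 7168, 2048)]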
@@ -272,14 +327,20 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
         recipe_list = [args.recipe]
 
     if args.profile:
-        mkns = [(8192 * 8, 7168, 2048)]
+        num_gemms_list = [8]
+        hidden_dim_to_profile = 7168 if args.hidden_dim is None else args.hidden_dim
+        output_dim_to_profile = 2048 if args.output_dim is None else args.output_dim
+        token_dim_to_profile = 8192 * 8
+        if jagged_input_splits is not None:
+            num_gemms_list = [len(jagged_input_splits)]
+            token_dim_to_profile = sum(jagged_input_splits)
+        mkns = [(token_dim_to_profile, hidden_dim_to_profile, output_dim_to_profile)]
         # in profile mode, only run one recipe specified in args.recipe
         assert args.recipe != "all", (
             "In profile mode, only one recipe can be specified, please specify the recipe as"
             " fp8_sub_channel, mxfp8, nvfp4, or bf16"
         )
         recipe_list = [args.recipe]
-        num_gemms_list = [8]
         torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
 
     # Initialize a dataframe to store the results
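
Calling emit_nvtx(...).__enter__() keeps NVTX annotation on for the remainder of the process, which suits a profile-and-exit script. The scoped equivalent, if the annotation should end with the benchmark loop (sketch only; run_benchmarks is a hypothetical stand-in):

    import torch

    # NVTX annotation ends when the block exits.
    with torch.autograd.profiler.emit_nvtx(record_shapes=True):
        run_benchmarks()  # hypothetical stand-in for the benchmark loop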
@@ -310,6 +371,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
                 recipe_name,
                 use_bias,
                 num_gemms=num_gemms,
+                m_splits_provided=jagged_input_splits,
             )
             df_linears = pd.concat([df_linears, df])
