@@ -315,28 +315,28 @@ def test_fp8_gemm_nt(self) -> None:
315315 f"{ m = } , { n = } , { k = } , { kernel_opt } , { accumulate = } , { out_dtype = } , "
316316 f"{ diff :.5f} , alias={ test_alias } "
317317 )
318- a , b , c , d , ref_d = generate_normal (
319- m , n , k , accumulate , out_dtype , use_ue8m0 = use_ue8m0
320- )
321-
322- # Test launch overhead
323- launch_start_t = time .time_ns ()
324- fp8_gemm_nt (a , b , d , c = c , disable_ue8m0_cast = disable_ue8m0_cast )
325- launch_end_t = time .time_ns ()
326- torch .cuda .synchronize ()
327-
328- # noinspection PyShadowingNames
329- def test_func ():
330- fp8_gemm_nt (a , b , d , c = c , disable_ue8m0_cast = disable_ue8m0_cast )
331-
332- t = bench_kineto (test_func , "fp8_gemm" , suppress_kineto_output = True )
333- print (
334- f" > Perf (m={ m :5} , n={ n :5} , k={ k :5} , { kernel_opt } , { out_opt } , { acc_opt } ): "
335- f"launch { (launch_end_t - launch_start_t ) / 1e3 :4.0f} us | { t * 1e6 :4.0f} us | "
336- f"{ 2 * m * n * k / t / 1e12 :4.0f} TFLOPS | "
337- f"{ (count_bytes (a , b , d ) + count_bytes (c ) * int (accumulate )) / 1e9 / t :4.0f} GB/s" ,
338- flush = True ,
339- )
318+ # a, b, c, d, ref_d = generate_normal(
319+ # m, n, k, accumulate, out_dtype, use_ue8m0=use_ue8m0
320+ # )
321+
322+ # # Test launch overhead
323+ # launch_start_t = time.time_ns()
324+ # fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
325+ # launch_end_t = time.time_ns()
326+ # torch.cuda.synchronize()
327+
328+ # # noinspection PyShadowingNames
329+ # def test_func():
330+ # fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
331+
332+ # t = bench_kineto(test_func, "fp8_gemm", suppress_kineto_output=True)
333+ # print(
334+ # f" > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, {out_opt}, {acc_opt}): "
335+ # f"launch {(launch_end_t - launch_start_t) / 1e3:4.0f} us | {t * 1e6:4.0f} us | "
336+ # f"{2 * m * n * k / t / 1e12:4.0f} TFLOPS | "
337+ # f"{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s",
338+ # flush=True,
339+ # )
340340 print (flush = True )
341341
342342 def test_m_grouped_fp8_gemm_nt_contiguous (self ) -> None :
@@ -367,24 +367,24 @@ def test_m_grouped_fp8_gemm_nt_contiguous(self) -> None:
 assert (
 diff < 0.001
 ), f"{m=}, {n=}, {k=}, {kernel_opt}, {diff:.5f}, alias={test_alias}"
-m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(
-num_groups, expected_m_per_group, n, k, use_ue8m0=use_ue8m0
-)
-
-# noinspection PyShadowingNames
-def test_func():
-m_grouped_fp8_gemm_nt_contiguous(
-a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast
-)
-
-t = bench_kineto(test_func, "fp8_gemm", suppress_kineto_output=True)
-print(
-f" > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, {kernel_opt}): "
-f"{t * 1e6:4.0f} us | "
-f"{2 * m * n * k / t / 1e12:4.0f} TFLOPS | "
-f"{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s",
-flush=True,
-)
+# m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(
+# num_groups, expected_m_per_group, n, k, use_ue8m0=use_ue8m0
+# )
+
+# # noinspection PyShadowingNames
+# def test_func():
+# m_grouped_fp8_gemm_nt_contiguous(
+# a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast
+# )
+
+# t = bench_kineto(test_func, "fp8_gemm", suppress_kineto_output=True)
+# print(
+# f" > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, {kernel_opt}): "
+# f"{t * 1e6:4.0f} us | "
+# f"{2 * m * n * k / t / 1e12:4.0f} TFLOPS | "
+# f"{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s",
+# flush=True,
+# )
 print(flush=True)

 def test_m_grouped_fp8_gemm_nt_masked(self) -> None:
@@ -424,32 +424,32 @@ def test_m_grouped_fp8_gemm_nt_masked(self) -> None:
 diff < 0.001
 ), f"{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}"

-# Construct full cases
-a, b, masked_m, d, ref_d = generate_m_grouped_masked(
-num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0
-)
-
-# noinspection PyShadowingNames
-def test_func():
-m_grouped_fp8_gemm_nt_masked(
-a,
-b,
-d,
-masked_m,
-expected_m_per_group,
-disable_ue8m0_cast=disable_ue8m0_cast,
-)
-
-# Test performance with fixed shapes
-valid_m = masked_m.sum().item()
-t = bench_kineto(test_func, "fp8_gemm", suppress_kineto_output=True)
-print(
-f" > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): "
-f"{t * 1e6:4.0f} us | "
-f"{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | "
-f"{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s",
-flush=True,
-)
+# # Construct full cases
+# a, b, masked_m, d, ref_d = generate_m_grouped_masked(
+# num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0
+# )
+
+# # noinspection PyShadowingNames
+# def test_func():
+# m_grouped_fp8_gemm_nt_masked(
+# a,
+# b,
+# d,
+# masked_m,
+# expected_m_per_group,
+# disable_ue8m0_cast=disable_ue8m0_cast,
+# )
+
+# # Test performance with fixed shapes
+# valid_m = masked_m.sum().item()
+# t = bench_kineto(test_func, "fp8_gemm", suppress_kineto_output=True)
+# print(
+# f" > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): "
+# f"{t * 1e6:4.0f} us | "
+# f"{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | "
+# f"{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s",
+# flush=True,
+# )
 print(flush=True)

 def test_bf16_gemm_nt(self) -> None:
@@ -470,32 +470,32 @@ def test_bf16_gemm_nt(self) -> None:
470470 f"{ m = } , { n = } , { k = } , { accumulate = } , { out_dtype = } , "
471471 f"{ diff :.5f} , alias={ test_alias } "
472472 )
473- a , b , c , d , ref_d = generate_normal (
474- m , n , k , accumulate , out_dtype , use_bf16 = True
475- )
476-
477- cublas_t = 0
478- t = bench_kineto (
479- lambda : bf16_gemm_nt (a , b , d , c = c ),
480- "bf16_gemm" ,
481- suppress_kineto_output = True ,
482- )
483- if accumulate == 0 and out_dtype == torch .bfloat16 :
484- # noinspection PyBroadException
485- try :
486- cublas_t = bench_kineto (
487- lambda : a @ b .T , "nvjet" , suppress_kineto_output = True
488- )
489- except Exception :
490- pass
491- print (
492- f" > Perf (m={ m :5} , n={ n :5} , k={ k :5} , { out_opt } , { acc_opt } ): "
493- f"{ t * 1e6 :4.0f} us | "
494- f"{ 2 * m * n * k / t / 1e12 :4.0f} TFLOPS | "
495- f"{ (count_bytes (a , b , d ) + count_bytes (c ) * int (accumulate )) / 1e9 / t :4.0f} GB/s | "
496- f"{ cublas_t / t :.2f} x cuBLAS" ,
497- flush = True ,
498- )
473+ # a, b, c, d, ref_d = generate_normal(
474+ # m, n, k, accumulate, out_dtype, use_bf16=True
475+ # )
476+
477+ # cublas_t = 0
478+ # t = bench_kineto(
479+ # lambda: bf16_gemm_nt(a, b, d, c=c),
480+ # "bf16_gemm",
481+ # suppress_kineto_output=True,
482+ # )
483+ # if accumulate == 0 and out_dtype == torch.bfloat16:
484+ # # noinspection PyBroadException
485+ # try:
486+ # cublas_t = bench_kineto(
487+ # lambda: a @ b.T, "nvjet", suppress_kineto_output=True
488+ # )
489+ # except Exception:
490+ # pass
491+ # print(
492+ # f" > Perf (m={m:5}, n={n:5}, k={k:5}, {out_opt}, {acc_opt}): "
493+ # f"{t * 1e6:4.0f} us | "
494+ # f"{2 * m * n * k / t / 1e12:4.0f} TFLOPS | "
495+ # f"{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | "
496+ # f"{cublas_t / t:.2f}x cuBLAS",
497+ # flush=True,
498+ # )
499499 print (flush = True )
500500
501501 def test_m_grouped_bf16_gemm_nt_contiguous (self ) -> None :
@@ -514,22 +514,22 @@ def test_m_grouped_bf16_gemm_nt_contiguous(self) -> None:
 d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d)
 diff = calc_diff(d, ref_d)
 assert diff < 0.001, f"{m=}, {n=}, {k=}, {diff:.5f}, alias={test_alias}"
-m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(
-num_groups, expected_m_per_group, n, k, use_bf16=True
-)
-
-# noinspection PyShadowingNames
-def test_func():
-m_grouped_bf16_gemm_nt_contiguous(a, b, d, m_indices)
-
-t = bench_kineto(test_func, "bf16_gemm", suppress_kineto_output=True)
-print(
-f" > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}): "
-f"{t * 1e6:4.0f} us | "
-f"{2 * m * n * k / t / 1e12:4.0f} TFLOPS | "
-f"{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s",
-flush=True,
-)
+# m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(
+# num_groups, expected_m_per_group, n, k, use_bf16=True
+# )
+
+# # noinspection PyShadowingNames
+# def test_func():
+# m_grouped_bf16_gemm_nt_contiguous(a, b, d, m_indices)
+
+# t = bench_kineto(test_func, "bf16_gemm", suppress_kineto_output=True)
+# print(
+# f" > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}): "
+# f"{t * 1e6:4.0f} us | "
+# f"{2 * m * n * k / t / 1e12:4.0f} TFLOPS | "
+# f"{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s",
+# flush=True,
+# )
 print(flush=True)

 def test_m_grouped_bf16_gemm_nt_masked(self) -> None:
@@ -558,25 +558,25 @@ def test_m_grouped_bf16_gemm_nt_masked(self) -> None:
 diff < 0.001
 ), f"{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}"

-# Construct full cases
-a, b, masked_m, d, ref_d = generate_m_grouped_masked(
-num_groups, max_m, expected_m_per_group, n, k, use_bf16=True
-)
-
-# noinspection PyShadowingNames
-def test_func():
-m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group)
-
-# Test performance with fixed shapes
-valid_m = masked_m.sum().item()
-t = bench_kineto(test_func, "bf16_gemm", suppress_kineto_output=True)
-print(
-f" > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): "
-f"{t * 1e6:4.0f} us | "
-f"{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | "
-f"{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s",
-flush=True,
-)
+# # Construct full cases
+# a, b, masked_m, d, ref_d = generate_m_grouped_masked(
+# num_groups, max_m, expected_m_per_group, n, k, use_bf16=True
+# )
+
+# # noinspection PyShadowingNames
+# def test_func():
+# m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group)
+
+# # Test performance with fixed shapes
+# valid_m = masked_m.sum().item()
+# t = bench_kineto(test_func, "bf16_gemm", suppress_kineto_output=True)
+# print(
+# f" > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): "
+# f"{t * 1e6:4.0f} us | "
+# f"{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | "
+# f"{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s",
+# flush=True,
+# )
 print(flush=True)

