Skip to content

Conversation

alihassanijr
Copy link
Contributor

Some examples build their own kernel layer (FMHA, DistGEMM), and they must guard them according to the supported features.

Issue: #2559

Will close: #2559 #2558

CC @hwu36

Some examples build their own kernel layer (FMHA, DistGEMM), and they
must guard them according to the supported features.

Issue: NVIDIA#2559
The cmake config already guards targets by checking if we're building for SM100A
(although unsure whether it's an exact match or not).

For safety, it's best to just have the arch-specific MMA guards in
place.
@Flamefire
Copy link

I can confirm that applying this to 4.1.0 solves it as well as #2558 and fixes #2559

@@ -507,6 +507,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {


CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) {
#if ! defined(__CUDA_ARCH_FEAT_SM100_ALL)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alihassanijr , this only covers 100a but not 100f. You could take a look at launch control header file for the 100f macro.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't using the family macros break builds with CTK 12.8 and earlier?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both CUTLASS_ARCH_MMA_SM100F_SUPPORTED and CUTLASS_ARCH_MMA_SM100F_ENABLED are conditioned on CTK >= 12.9, so Sm100 users with CTK 12.8 will wind up with empty kernels.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. so if < 12.9, 100a; else 100a || 100f

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or is CUDA_ARCH_FAMILY(1000) just a false in 12.8?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right; defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) should work -- they're already conditioned on the correct CTK compiler version.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] 88_hopper_fmha_fp8 example fails to compile on some CUDA archs
3 participants