
[QST] Grouped GEMM on A10 GPUs #2053

Closed
MinghaoYan opened this issue Jan 22, 2025 · 2 comments
Comments

@MinghaoYan

What is your question?
I am trying to run the CUTLASS grouped GEMM kernel on A10 GPUs. The code runs smoothly on A100 GPUs; the following is the setup I am using.

using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
  // A operand.
  ::cutlass::bfloat16_t,
  GroupedGemmInputLayout<false>,
  ::cutlass::ComplexTransform::kNone,
  4, //GroupedGemmConfig::kAlignmentA,
  // B operand.
  ::cutlass::bfloat16_t,
  GroupedGemmInputLayout<false>,
  ::cutlass::ComplexTransform::kNone,
  4, //GroupedGemmConfig::kAlignmentB,
  // C operand.
  ::cutlass::bfloat16_t,
  ::cutlass::layout::RowMajor,
  float,
  ::cutlass::arch::OpClassTensorOp,
  ::cutlass::arch::Sm80,
  GroupedGemmConfig::ThreadblockShape,
  GroupedGemmConfig::WarpShape,
  GroupedGemmConfig::InstructionShape,
  GroupedGemmConfig::EpilogueOutputOp,
  ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
  GroupedGemmConfig::kStages>::GemmKernel;

However, I'm running into issues deploying it on A10 GPUs. If I change the ArchTag to Sm86, I run into a variety of compilation errors about incomplete types, as described in #609.

After further digging, I found a suggestion in #1181 that compiling for SM80 should be sufficient. My code does compile with the ArchTag set to Sm80; however, every matrix multiplication setup fails at runtime, because my threadblock-count check

int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts);

always returns 0.

Are there any suggestions or examples on how to properly use the GroupedGemm kernel on A10 GPUs?

@jackkosaian
Contributor

It's likely that your kernel requires more shared memory than is available on the A10 (the A100 has considerably more shared memory per SM than the A10). Try reducing the ThreadblockShape.

@MinghaoYan
Author

Thank you, I tuned it a bit and got it to work.
