[QST] why the implementation of f16xs8 mixed gemm is different between TRT-LLM and native cutlass mixed gemm example? #2022

Open
danielhua23 opened this issue Jan 5, 2025 · 4 comments


@danielhua23

What is your question?
Dear cutlass team,

Let's consider sm80 and f16s8. The example of the f16s8 TN mixed GEMM shown here differs from the TRT-LLM implementation; specifically, to my knowledge, the TRT-LLM one adds the dequantization scale, but the CUTLASS one does not. My questions are:

  1. Does the TRT-LLM approach of adding the dequantization scale give better performance or accuracy than the native CUTLASS one in LLM linear-layer cases?
  2. From here, I see the TRT-LLM one seems to load operand B (s8) using LDS rather than LDSM, but I can't find an f16s8 LDS specialization in MmaTensorOpMultiplicandTileIterator, only an LDS specialization for TF32, which confuses me about the "LDS". Am I missing something?

Thanks for your time!

cc @manishucsd @alexsamardzic @hwu36

@alexsamardzic
Contributor

alexsamardzic commented Jan 5, 2025

A general answer to your question would be: different approaches are possible, with different trade-offs. I'm not familiar with the code base you pointed to, but I would assume it was written before CUTLASS added support for (some) mixed data-type GEMMs. Nowadays, the CUTLASS way to apply quantization scale factors is through EVT (see here for an EVT example).
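To make the scale-factor placement concrete, here is a plain C++ reference (not CUTLASS, TRT-LLM, or EVT code, just the arithmetic): with a per-column dequantization scale, scaling the s8 operand B element-wise inside the mainloop and scaling the accumulator in the epilogue (which is where an EVT node would act) give the same result.

```cpp
// Plain C++ reference, not CUTLASS code: shows that applying a per-column
// dequantization scale to the s8 operand B inside the GEMM loop gives the
// same result as scaling the accumulator in the epilogue.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int M = 2, N = 3, K = 4;
    std::vector<float>  A = {1, 2, 3, 4,  5, 6, 7, 8};                     // MxK (stands in for f16)
    std::vector<int8_t> B = {1, -2, 3,  4, 5, -6,  7, 8, 9,  -1, 0, 2};    // KxN, s8
    std::vector<float>  scale = {0.5f, 0.25f, 2.0f};                       // per-column dequant scale

    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float mainloop = 0.f, epilogue = 0.f;
            for (int k = 0; k < K; ++k) {
                mainloop += A[i * K + k] * (float(B[k * N + j]) * scale[j]); // scale B during dequant
                epilogue += A[i * K + k] *  float(B[k * N + j]);             // unscaled accumulation
            }
            epilogue *= scale[j];                                            // scale once, in the epilogue
            assert(std::fabs(mainloop - epilogue) < 1e-4f);
        }
    printf("mainloop and epilogue scaling match\n");
    return 0;
}
```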

As far as performance/accuracy is concerned, it depends on the context. For example, mixed data-type GEMM on Ampere-generation GPUs requires re-arranging the elements of the tensor with the smaller data type. CUTLASS does this during each GEMM operation, but when mixed data-type GEMM is used in a context such as LLM inference, with a model whose weights are quantized, there are implementations (like Marlin) that expect users to do this re-arrangement up front, along with the weight quantization. When such a weight tensor is used repeatedly as a mixed data-type GEMM operand during inference, this gives a slight performance advantage over CUTLASS.
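As an illustration of the offline re-arrangement idea, here is a small sketch. The permutation below (`tc_offset`) is a placeholder invented for this example, not Marlin's or CUTLASS's actual interleaved layout; the point is only that the rearrangement is paid once, at quantization time, rather than in every GEMM.

```cpp
// Sketch of offline weight re-arrangement (the Marlin-style approach):
// the permutation is applied once when the weights are quantized, so the
// GEMM kernel can load tiles directly in the order its fragments expect.
#include <cstdint>
#include <vector>

// Hypothetical mapping from canonical (k, n) to a packed storage offset;
// here we simply interleave groups of 4 along K, purely for illustration.
size_t tc_offset(int k, int n, int K, int N) {
    int group = k / 4, within = k % 4;
    return (size_t)group * (N * 4) + (size_t)n * 4 + within;
}

std::vector<int8_t> prepack_weights(const std::vector<int8_t>& B, int K, int N) {
    std::vector<int8_t> packed(B.size());
    for (int k = 0; k < K; ++k)
        for (int n = 0; n < N; ++n)
            packed[tc_offset(k, n, K, N)] = B[(size_t)k * N + n];
    return packed;  // done once, amortized across all inference-time GEMMs
}

int main() {
    const int K = 8, N = 4;
    std::vector<int8_t> B(K * N);
    for (int i = 0; i < K * N; ++i) B[i] = (int8_t)i;
    std::vector<int8_t> packed = prepack_weights(B, K, N);  // offline step
    return 0;
}
```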

As far as your second question is concerned, I would assume that you also need to check here.

@danielhua23
Author

@alexsamardzic thanks for your helpful response. I want to confirm: does "mixed data-type GEMM on Ampere-generation GPUs requires re-arranging the elements of the tensor with the smaller data type; CUTLASS does this during each GEMM operation" refer to FragmentShuffler, with no other rearrangement done in CUTLASS?

@manishucsd
Contributor

manishucsd commented Jan 6, 2025

The fundamental way to think about this, irrespective of Ampere or Hopper, is that Tensor Cores need a specific thread-data arrangement depending on the input data types. There are various ways to achieve it:

  1. Pre-processing on weights: a swizzled layout in GMEM that is not row/col major (rearrangement in GMEM)
  • Instruction sequence: LDSM -> Tensor Cores
  • The rearrangement in GMEM is done so that the layout after LDSM is Tensor Core conformant.
  • No extra cost at inference time.
  2. Canonical row/col major layout in GMEM (no rearrangement)
  • Instruction sequence: LDS -> Tensor Cores
  • Uses narrower SMEM loads to attain the operand layouts needed for Tensor Cores.
  • Slower loads, but no shuffling needed in either GMEM or registers.
  • The narrow bit-width loads ensure the Tensor Core layout on operands in registers.
  3. Canonical row/col major layout in GMEM (no rearrangement)
  • Instruction sequence: LDSM -> FragmentShuffler -> Tensor Cores
  • No rearrangement in GMEM, still using the wider bit-width LDSM, but shuffling is needed in registers (see the sketch after this list).
  • Layout conformance is achieved using FragmentShuffler warp-level register-register shuffles.
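For intuition on option 3, here is a minimal sketch of a warp-level register-register exchange. This is not the actual CUTLASS FragmentShuffler (the real f16 x s8 lane/element mapping is more involved); the lane-pair swap below is only a stand-in for the idea that values brought in by a wide LDSM can be moved to the lanes the Tensor Core fragment layout expects without touching memory again.

```cpp
// Minimal CUDA sketch of a register-register fragment exchange across a warp.
// Each lane swaps its 32-bit value with its partner lane (lane ^ 1); the real
// FragmentShuffler pattern is different, but the mechanism is the same.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void fragment_shuffle_sketch(const uint32_t* in, uint32_t* out) {
    int lane = threadIdx.x & 31;
    uint32_t frag = in[lane];                            // value this lane holds after the load
    out[lane] = __shfl_xor_sync(0xffffffffu, frag, 1);   // exchange with partner lane, no SMEM traffic
}

int main() {
    uint32_t h_in[32], h_out[32];
    for (int i = 0; i < 32; ++i) h_in[i] = i;
    uint32_t *d_in, *d_out;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    fragment_shuffle_sketch<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("lane 0 now holds %u, lane 1 now holds %u\n", h_out[0], h_out[1]);  // 1 and 0
    return 0;
}
```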

@danielhua23
Author

danielhua23 commented Jan 6, 2025

Thanks for your detailed information @manishucsd, which is very useful to me. One question remains: Marlin seems to implement the mixed GEMM by preprocessing the weights ahead of time, i.e. the first way you mentioned above, but when I was checking Marlin's code, it didn't use LDSM for the 4-bit operand B, only LDS, stating that LDSM does not support 4-bit. It only uses LDSM for the higher-bit operand A. Is this a potential optimization point for Marlin?
