Skip to content

Conversation

@pggPL
Copy link
Collaborator

@pggPL pggPL commented Dec 10, 2025

Description

Adds nvte_grouped_gemm API using cuBLASLt grouped matmul for batched GEMM on tensors with varying shapes. A GPU kernel (setup_grouped_gemm_kernel) converts NVTEGroupedTensor format (contiguous buffer + offsets) to cuBLAS requirements (pointer arrays + per-matrix M/N/K).

New API

void nvte_grouped_gemm(int transa, int transb, 
                       const NVTETensor alpha, 
                       const NVTEGroupedTensor A,
                       const NVTEGroupedTensor B, 
                       const NVTETensor beta, 
                       const NVTEGroupedTensor C,
                       NVTEGroupedTensor D, 
                       NVTETensor workspace_setup, 
                       NVTETensor workspace_cublas,
                       NVTEMatmulConfig config, 
                       cudaStream_t stream, 
                       const int64_t *avg_m,
                       const int64_t *avg_n, 
                       const int64_t *avg_k);

Computes D = alpha * op(A) @ op(B) + beta * C for groups of matrices with potentially different shapes.

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • GPU setup kernel computing pointers/dims from grouped tensor metadata
  • FP8 support with scale_inv handling and TN layout selection on Hopper
  • GroupedGemmSetupWorkspace struct for cuBLAS workspace layout
  • Tests in test_grouped_gemm.cu comparing against nvte_multi_tensor_gemm (FP8/BF16, various shapes and transpose layouts)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <[email protected]>
@pggPL pggPL changed the title [common] Add support for cublasLt GEMM for GroupedTensor [common] Add support for cuBLASLt GEMM for GroupedTensor Dec 10, 2025
pre-commit-ci bot and others added 3 commits December 10, 2025 14:32
- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints

Signed-off-by: Pawel Gadzinski <[email protected]>
@ptrendx ptrendx added the MoE label Dec 10, 2025
@ptrendx ptrendx linked an issue Dec 10, 2025 that may be closed by this pull request
pggPL and others added 2 commits December 10, 2025 22:34
@pggPL
Copy link
Collaborator Author

pggPL commented Dec 10, 2025

/te-ci L0

@pggPL pggPL marked this pull request as ready for review December 10, 2025 21:43
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Dec 10, 2025

Greptile Summary

This PR adds a new nvte_grouped_gemm API that enables batched matrix multiplication on collections of matrices with varying shapes using cuBLASLt's grouped GEMM functionality.

Key Changes

  • Implemented nvte_grouped_gemm API in cublaslt_grouped_gemm.cu with support for FP8 and BF16 data types
  • Added GPU setup kernel to convert NVTEGroupedTensor format (contiguous buffer + offsets) to cuBLASLt requirements (pointer arrays + per-matrix M/N/K)
  • FP8 support includes proper scale_inv handling and TN layout enforcement on Hopper GPUs
  • Comprehensive test suite validates against existing nvte_multi_tensor_gemm across multiple transpose layouts, shape configurations, and data types
  • Requires cuBLAS 13.1+ and Hopper (SM90+) architecture

Issues Found

  • cuBLAS version inconsistency: Compile guard checks for 13.1 (CUBLAS_VERSION >= 130100) but error messages and test guards reference 13.2 (CUBLAS_VERSION >= 130200)
  • Potential null dereference: launch_grouped_gemm_setup function dereferences C pointer before null check is applied (though this is mitigated by caller assigning D when C is null)
  • Minor style issues in workspace size calculation comments

Confidence Score: 4/5

  • This PR is generally safe to merge with minor version consistency fixes needed
  • The implementation is well-structured with comprehensive tests validating correctness against existing multi_tensor_gemm. The main concerns are: (1) cuBLAS version mismatch between compile guard (13.1) and documentation/tests (13.2), and (2) potential null pointer dereference in setup function that's currently mitigated by caller logic. Core algorithm and memory management appear sound.
  • transformer_engine/common/gemm/cublaslt_grouped_gemm.cu requires version consistency fixes in error messages and guards

Important Files Changed

Filename Overview
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Core grouped GEMM implementation with cuBLASLt integration. Found version mismatch in error messages and potential null pointer issue in setup function.
tests/cpp/operator/test_grouped_gemm.cu Comprehensive test suite covering FP8/BF16, transpose variations, shape cases, and NULL C support. Validates against existing multi_tensor_gemm.
transformer_engine/common/include/transformer_engine/gemm.h API declaration with comprehensive documentation. Found signature issue with unused config parameter reference.

Sequence Diagram

sequenceDiagram
    participant User
    participant nvte_grouped_gemm
    participant Validation
    participant OperandSelection
    participant SetupKernel as setup_grouped_gemm_kernel
    participant cuBLASLt
    
    User->>nvte_grouped_gemm: Call with A, B, C, D, alpha, beta
    nvte_grouped_gemm->>nvte_grouped_gemm: Check SM >= 90 (Hopper)
    nvte_grouped_gemm->>Validation: validate_grouped_gemm_inputs()
    Validation-->>nvte_grouped_gemm: OK
    nvte_grouped_gemm->>nvte_grouped_gemm: If C is NULL, set C = D
    
    nvte_grouped_gemm->>OperandSelection: select_grouped_operand(A, transa)
    OperandSelection->>OperandSelection: Check FP8 TN layout constraints
    OperandSelection->>OperandSelection: Choose row-wise vs column-wise data
    OperandSelection-->>nvte_grouped_gemm: A_sel (dptr, dtype, trans)
    
    nvte_grouped_gemm->>OperandSelection: select_grouped_operand(B, transb)
    OperandSelection-->>nvte_grouped_gemm: B_sel (dptr, dtype, trans)
    
    nvte_grouped_gemm->>nvte_grouped_gemm: Allocate workspace buffers
    nvte_grouped_gemm->>SetupKernel: Launch kernel on GPU
    SetupKernel->>SetupKernel: For each tensor i:<br/>- Compute pointers from base + offset<br/>- Extract M[i], N[i], K[i] from dims<br/>- Setup alpha_ptrs[i], beta_ptrs[i]
    SetupKernel-->>nvte_grouped_gemm: Pointer arrays ready
    
    nvte_grouped_gemm->>cuBLASLt: cublasLtGroupedMatrixLayoutInit()
    nvte_grouped_gemm->>cuBLASLt: cublasLtMatmulDescInit()
    nvte_grouped_gemm->>cuBLASLt: Set FP8 scale pointers (if FP8)
    nvte_grouped_gemm->>cuBLASLt: cublasLtMatmulAlgoGetHeuristic()
    cuBLASLt-->>nvte_grouped_gemm: Selected algorithm
    
    nvte_grouped_gemm->>cuBLASLt: cublasLtMatmul(A_ptrs, B_ptrs, C_ptrs, D_ptrs)
    cuBLASLt->>cuBLASLt: Execute grouped GEMM:<br/>D[i] = alpha[i]*op(A[i])@op(B[i]) + beta[i]*C[i]
    cuBLASLt-->>nvte_grouped_gemm: Results in D
    nvte_grouped_gemm-->>User: Return
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (4)

  1. tests/cpp/operator/test_grouped_gemm.cu, line 75 (link)

    logic: missing columnwise_data in move assignment

  2. tests/cpp/operator/test_grouped_gemm.cu, line 336-337 (link)

    logic: kSameFirst and kSameLast test cases use identical shape tuples. Should kSameFirst and kSameLast have different shape patterns to properly test the respective scenarios?

  3. tests/cpp/operator/test_grouped_gemm.cu, line 367-378 (link)

    logic: missing case for InputCase::kFP8Delayed

  4. transformer_engine/common/include/transformer_engine/gemm.h, line 265-266 (link)

    style: The avg_m, avg_n, avg_k parameters are not documented in the function comment

    What do these average dimension parameters represent and how should they be computed?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

4 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

@ptrendx ptrendx requested a review from timmoon10 December 10, 2025 22:35
pggPL and others added 2 commits December 11, 2025 11:56
- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for FP8 on Hopper column-wise fallback
- Support NULL C tensor when beta=0 (uses D as placeholder)
- Remove unused get_scale_inv() from test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single element tensors only

Signed-off-by: Piotr Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. transformer_engine/common/include/transformer_engine/gemm.h, line 266 (link)

    syntax: Documentation incorrectly states alpha[i] and beta[i]. The implementation uses a single alpha and beta value for all matrices in the group (batch stride is 1 in init_matmul_desc at cublaslt_gemm.cu:1404), not per-matrix scaling.

4 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

pggPL and others added 3 commits December 11, 2025 12:16
- Change alpha/beta from single values to per-matrix arrays
- Validate alpha/beta have exactly num_tensors elements
- Update kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation

Signed-off-by: Piotr Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Piotr Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
@pggPL
Copy link
Collaborator Author

pggPL commented Dec 11, 2025

/te-ci

greptile-apps[bot]

This comment was marked as resolved.

Comment on lines 273 to 274
NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m,
const int64_t *avg_n, const int64_t *avg_k);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The average sizes seem like advanced configs that would be better to leave out of the top-level API. Can we move them inside NVTEMatmulConfig?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was suggestion of @ptrendx - customer may potentially want to use them if they know something more about the shapes. For example if there are multiple tensors of with k dimension D and one tensor with K dimension equal to 1, then it is potentially true that telling cublas that avg dim = D will result in better performance.

#ifndef TRANSFORMER_ENGINE_GEMM_H_
#define TRANSFORMER_ENGINE_GEMM_H_

#include <cstdint>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For compatibility with C:

Suggested change
#include <cstdint>
#include <stdint.h>

Alternatively, if we hide the average dims inside NVTEMatmulConfig, there's no reason to expose int64_t in this header.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (2)

  1. transformer_engine/common/include/transformer_engine/gemm.h, line 256 (link)

    syntax: Documentation mentions non-existent config parameter - this parameter is not in the function signature on line 276-280

  2. tests/cpp/operator/test_grouped_gemm.cu, line 141 (link)

    logic: Alignment calculation appears incorrect - multiplying 16 bytes by 8 bits gives 128 bits, not the intended 16-byte alignment

6 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Pawel Gadzinski <[email protected]>
@pggPL
Copy link
Collaborator Author

pggPL commented Dec 19, 2025

/te-ci L0

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (4)

  1. transformer_engine/common/include/transformer_engine/gemm.h, line 260-279 (link)

    syntax: Missing config parameter in implementation

    The function signature includes an NVTEMatmulConfig config parameter in the declaration (line 276), but the actual implementation in cublaslt_grouped_gemm.cu:498 doesn't accept this parameter. The call site in tests (line 449-462 of test_grouped_gemm.cu) also omits it, passing 0 instead.

    Either add the parameter to the implementation or remove it from the declaration.

  2. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 592 (link)

    logic: cuBLAS version mismatch in error message

    Error message mentions "cuBLAS 13.2+" but the compile-time guard checks for CUBLAS_VERSION >= 130100 (line 29), which is cuBLAS 13.1. Also, header comment references cuBLAS 13.1 vs 13.2 inconsistently.

  3. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 472-473 (link)

    logic: Potential null pointer dereference when C is NULL

    When C is NULL and use_null_c=true, C->data.dptr and C->dtype() will cause a segfault. The code sets inputC = outputD when C is NULL (line 525), but this happens after launch_grouped_gemm_setup is called, where C is still NULL.

  4. tests/cpp/operator/test_grouped_gemm.cu, line 95-102 (link)

    style: Workspace size calculation doesn't match implementation

    Test calculates 4 * ptr_bytes + 3 * int_bytes + 2 * ptr_bytes (6 pointer arrays total), but the implementation in GroupedGemmSetupWorkspace::from_buffers expects exactly 6 pointer arrays + 3 int arrays. The calculation is correct but the comment formatting makes it unclear. Consider: 6 * ptr_bytes + 3 * int_bytes.

7 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

GroupedGemm: FP8 per-tensor via cuBLAS

3 participants