[common] Add support for cuBLASLt GEMM for GroupedTensor #2502
base: main
Conversation
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints

Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
/te-ci L0
Greptile Summary
This PR adds a new nvte_grouped_gemm API for grouped GEMM via cuBLASLt.
Key Changes
Issues Found
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant nvte_grouped_gemm
participant Validation
participant OperandSelection
participant SetupKernel as setup_grouped_gemm_kernel
participant cuBLASLt
User->>nvte_grouped_gemm: Call with A, B, C, D, alpha, beta
nvte_grouped_gemm->>nvte_grouped_gemm: Check SM >= 90 (Hopper)
nvte_grouped_gemm->>Validation: validate_grouped_gemm_inputs()
Validation-->>nvte_grouped_gemm: OK
nvte_grouped_gemm->>nvte_grouped_gemm: If C is NULL, set C = D
nvte_grouped_gemm->>OperandSelection: select_grouped_operand(A, transa)
OperandSelection->>OperandSelection: Check FP8 TN layout constraints
OperandSelection->>OperandSelection: Choose row-wise vs column-wise data
OperandSelection-->>nvte_grouped_gemm: A_sel (dptr, dtype, trans)
nvte_grouped_gemm->>OperandSelection: select_grouped_operand(B, transb)
OperandSelection-->>nvte_grouped_gemm: B_sel (dptr, dtype, trans)
nvte_grouped_gemm->>nvte_grouped_gemm: Allocate workspace buffers
nvte_grouped_gemm->>SetupKernel: Launch kernel on GPU
SetupKernel->>SetupKernel: For each tensor i:<br/>- Compute pointers from base + offset<br/>- Extract M[i], N[i], K[i] from dims<br/>- Setup alpha_ptrs[i], beta_ptrs[i]
SetupKernel-->>nvte_grouped_gemm: Pointer arrays ready
nvte_grouped_gemm->>cuBLASLt: cublasLtGroupedMatrixLayoutInit()
nvte_grouped_gemm->>cuBLASLt: cublasLtMatmulDescInit()
nvte_grouped_gemm->>cuBLASLt: Set FP8 scale pointers (if FP8)
nvte_grouped_gemm->>cuBLASLt: cublasLtMatmulAlgoGetHeuristic()
cuBLASLt-->>nvte_grouped_gemm: Selected algorithm
nvte_grouped_gemm->>cuBLASLt: cublasLtMatmul(A_ptrs, B_ptrs, C_ptrs, D_ptrs)
cuBLASLt->>cuBLASLt: Execute grouped GEMM:<br/>D[i] = alpha[i]*op(A[i])@op(B[i]) + beta[i]*C[i]
cuBLASLt-->>nvte_grouped_gemm: Results in D
nvte_grouped_gemm-->>User: Return
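The setup step in the diagram can be pictured with a short CUDA sketch. Everything below is illustrative: the struct fields and kernel signature are assumptions modeled on the diagram (base pointers plus byte offsets, per-matrix M/N/K, per-matrix alpha/beta), not the PR's actual GroupedGemmSetupWorkspace layout or setup_grouped_gemm_kernel signature.

```cuda
// Illustrative only: field names and signature are assumptions, not the PR's actual code.
#include <cstdint>

struct GroupedGemmSetupArgs {
  const char *a_base, *b_base, *c_base;   // contiguous NVTEGroupedTensor buffers
  char *d_base;
  const int64_t *a_offsets, *b_offsets, *c_offsets, *d_offsets;  // byte offsets per matrix
  const int64_t *m, *n, *k;                // per-matrix dims from the grouped tensor
  const float *alpha, *beta;               // per-matrix scalars
  // Output arrays consumed by cuBLASLt grouped matmul:
  const void **a_ptrs, **b_ptrs, **c_ptrs;
  void **d_ptrs;
  const float **alpha_ptrs, **beta_ptrs;
  int *m_out, *n_out, *k_out;
};

__global__ void setup_grouped_gemm_kernel_sketch(GroupedGemmSetupArgs args, int num_tensors) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= num_tensors) return;
  // Compute pointers from base + offset
  args.a_ptrs[i] = args.a_base + args.a_offsets[i];
  args.b_ptrs[i] = args.b_base + args.b_offsets[i];
  args.c_ptrs[i] = args.c_base + args.c_offsets[i];
  args.d_ptrs[i] = args.d_base + args.d_offsets[i];
  // Extract M[i], N[i], K[i] from the grouped-tensor dims
  args.m_out[i] = static_cast<int>(args.m[i]);
  args.n_out[i] = static_cast<int>(args.n[i]);
  args.k_out[i] = static_cast<int>(args.k[i]);
  // Point alpha_ptrs[i]/beta_ptrs[i] at the per-matrix scalars
  args.alpha_ptrs[i] = &args.alpha[i];
  args.beta_ptrs[i] = &args.beta[i];
}
```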
Additional Comments (4)
- tests/cpp/operator/test_grouped_gemm.cu, line 75 (logic): missing columnwise_data in move assignment
- tests/cpp/operator/test_grouped_gemm.cu, lines 336-337 (logic): kSameFirst and kSameLast test cases use identical shape tuples. Should kSameFirst and kSameLast have different shape patterns to properly test the respective scenarios?
- tests/cpp/operator/test_grouped_gemm.cu, lines 367-378 (logic): missing case for InputCase::kFP8Delayed
- transformer_engine/common/include/transformer_engine/gemm.h, lines 265-266 (style): The avg_m, avg_n, avg_k parameters are not documented in the function comment. What do these average dimension parameters represent and how should they be computed?
4 files reviewed, 4 comments
- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for FP8 on Hopper column-wise fallback
- Support NULL C tensor when beta=0 (uses D as placeholder)
- Remove unused get_scale_inv() from test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single element tensors only

Signed-off-by: Piotr Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Additional Comments (1)
- transformer_engine/common/include/transformer_engine/gemm.h, line 266 (syntax): Documentation incorrectly states alpha[i] and beta[i]. The implementation uses a single alpha and beta value for all matrices in the group (batch stride is 1 in init_matmul_desc at cublaslt_gemm.cu:1404), not per-matrix scaling.
4 files reviewed, 1 comment
- Change alpha/beta from single values to per-matrix arrays
- Validate alpha/beta have exactly num_tensors elements
- Update kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation

Signed-off-by: Piotr Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Piotr Gadzinski <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
/te-ci
NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m,
const int64_t *avg_n, const int64_t *avg_k);
The average sizes seem like advanced configs that would be better to leave out of the top-level API. Can we move them inside NVTEMatmulConfig?
It was a suggestion of @ptrendx - a customer may potentially want to use them if they know something more about the shapes. For example, if there are multiple tensors with K dimension D and one tensor with K dimension equal to 1, then telling cuBLAS that the average dim is D may result in better performance.
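For illustration, a caller with that kind of shape knowledge could compute a hint along these lines. The helper name is hypothetical, and how the hint is ultimately passed (the avg_* pointers shown above, or a config field if they are moved there) is still under discussion in this thread.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: pick an average-K hint for the grouped GEMM.
// A caller who knows that most matrices share K = D (with a few outliers of K = 1)
// might prefer to pass D itself rather than the arithmetic mean computed here.
int64_t compute_avg_k_hint(const std::vector<int64_t>& k_dims) {
  if (k_dims.empty()) return 0;
  int64_t sum = 0;
  for (int64_t k : k_dims) sum += k;
  return sum / static_cast<int64_t>(k_dims.size());
}
```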
#ifndef TRANSFORMER_ENGINE_GEMM_H_
#define TRANSFORMER_ENGINE_GEMM_H_

#include <cstdint>
For compatibility with C:
- #include <cstdint>
+ #include <stdint.h>
Alternatively, if we hide the average dims inside NVTEMatmulConfig, there's no reason to expose int64_t in this header.
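A purely hypothetical sketch of that alternative, assuming the hints become plain fields attached to the config object instead of extra function parameters; none of these names exist in Transformer Engine.

```cpp
#include <stdint.h>

/* Hypothetical sketch only: illustrates moving the average-dim hints into the
 * matmul config rather than the nvte_grouped_gemm signature. These names are
 * invented for illustration and are not part of the PR. */
typedef struct HypotheticalGroupedGemmHints {
  int64_t avg_m;  /* 0 means "no hint" */
  int64_t avg_n;
  int64_t avg_k;
} HypotheticalGroupedGemmHints;

/* The caller would attach the hints to the existing NVTEMatmulConfig handle
 * (e.g. via a new attribute), keeping the top-level GEMM signature unchanged
 * and avoiding int64_t in the public header. */
```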
done
Signed-off-by: Pawel Gadzinski <[email protected]>
Additional Comments (2)
- transformer_engine/common/include/transformer_engine/gemm.h, line 256 (syntax): Documentation mentions a non-existent config parameter - this parameter is not in the function signature on lines 276-280.
- tests/cpp/operator/test_grouped_gemm.cu, line 141 (logic): Alignment calculation appears incorrect - multiplying 16 bytes by 8 bits gives 128 bits, not the intended 16-byte alignment (see the sketch below).
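A minimal sketch of the usual 16-byte padding arithmetic, assuming the test pads buffer sizes in bytes; this is not the test's actual code, it only illustrates that the alignment is a byte quantity, so no bits-per-byte factor of 8 should appear.

```cpp
#include <cstddef>

// Round a byte size up to the next multiple of 16 bytes.
constexpr size_t align16(size_t bytes) {
  return (bytes + 15) / 16 * 16;
}

// Example: pad an element count so the buffer for a given dtype stays 16-byte aligned.
constexpr size_t padded_elems(size_t elems, size_t elem_size_bytes) {
  return align16(elems * elem_size_bytes) / elem_size_bytes;
}
```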
6 files reviewed, 2 comments
Signed-off-by: Pawel Gadzinski <[email protected]>
/te-ci L0
for more information, see https://pre-commit.ci
Additional Comments (4)
- transformer_engine/common/include/transformer_engine/gemm.h, lines 260-279 (syntax): Missing config parameter in implementation. The function signature includes an NVTEMatmulConfig config parameter in the declaration (line 276), but the actual implementation in cublaslt_grouped_gemm.cu:498 doesn't accept this parameter. The call site in tests (lines 449-462 of test_grouped_gemm.cu) also omits it, passing 0 instead. Either add the parameter to the implementation or remove it from the declaration.
- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 592 (logic): cuBLAS version mismatch in error message. The error message mentions "cuBLAS 13.2+" but the compile-time guard checks for CUBLAS_VERSION >= 130100 (line 29), which is cuBLAS 13.1. Also, the header comment references cuBLAS 13.1 vs 13.2 inconsistently.
- transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, lines 472-473 (logic): Potential null pointer dereference when C is NULL. When C is NULL and use_null_c=true, C->data.dptr and C->dtype() will cause a segfault. The code sets inputC = outputD when C is NULL (line 525), but this happens after launch_grouped_gemm_setup is called, where C is still NULL (see the sketch after this list).
- tests/cpp/operator/test_grouped_gemm.cu, lines 95-102 (style): Workspace size calculation doesn't match implementation. The test calculates 4 * ptr_bytes + 3 * int_bytes + 2 * ptr_bytes (6 pointer arrays total), but the implementation in GroupedGemmSetupWorkspace::from_buffers expects exactly 6 pointer arrays + 3 int arrays. The calculation is correct but the comment formatting makes it unclear. Consider: 6 * ptr_bytes + 3 * int_bytes.
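A sketch of the reordering suggested in the NULL-C comment above, using placeholder names (the handle type, helper, and call are illustrative stand-ins, not the PR's actual declarations): resolve the C placeholder before anything dereferences it.

```cpp
// Sketch only: illustrates resolving the NULL-C placeholder *before* the setup
// kernel is launched, so the kernel never dereferences a NULL C.
using GroupedTensorHandle = void*;  // stand-in for the real grouped-tensor handle type

void launch_setup_with_null_c_guard(GroupedTensorHandle A, GroupedTensorHandle B,
                                    GroupedTensorHandle C, GroupedTensorHandle D) {
  // D stands in for C when C is NULL (valid because beta == 0 in that case).
  GroupedTensorHandle inputC = (C != nullptr) ? C : D;
  // launch_grouped_gemm_setup(A, B, inputC, D, ...);  // hypothetical call; C already resolved
  (void)A; (void)B; (void)inputC;
}
```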
7 files reviewed, 4 comments
Description
Adds an nvte_grouped_gemm API using cuBLASLt grouped matmul for batched GEMM on tensors with varying shapes. A GPU kernel (setup_grouped_gemm_kernel) converts the NVTEGroupedTensor format (contiguous buffer + offsets) to what cuBLAS requires (pointer arrays + per-matrix M/N/K).
New API
Computes D = alpha * op(A) @ op(B) + beta * C for groups of matrices with potentially different shapes; a hedged usage sketch is shown below.
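The exact parameter list of nvte_grouped_gemm is still being adjusted in this PR (see the config and avg_m/avg_n/avg_k discussion above), so the call itself is shown only as a comment; the host-side setup around it is illustrative.

```cpp
#include <cstdint>
#include <vector>

int main() {
  // Three GEMMs with different shapes, as the grouped API is meant to handle.
  std::vector<int64_t> M = {128, 256, 64};
  std::vector<int64_t> N = {512, 512, 512};
  std::vector<int64_t> K = {1024, 1024, 1};

  // Per-matrix scaling factors: alpha/beta must have exactly num_tensors elements.
  std::vector<float> alpha(M.size(), 1.0f);
  std::vector<float> beta(M.size(), 0.0f);  // beta == 0 allows passing a NULL C

  // The A/B/C/D operands would be packed into NVTEGroupedTensor buffers
  // (one contiguous allocation per operand plus per-matrix offsets), and then:
  //
  //   nvte_grouped_gemm(A, transa, B, transb, C, D, alpha_t, beta_t,
  //                     workspace, stream, /*avg hints*/ ...);
  //
  // The call above is illustrative pseudocode: the final signature is still under
  // review in this PR, so consult gemm.h once the PR is merged.
  return 0;
}
```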
Type of change
Changes
- GroupedGemmSetupWorkspace struct for cuBLAS workspace layout
- test_grouped_gemm.cu comparing against nvte_multi_tensor_gemm (FP8/BF16, various shapes and transpose layouts)

Checklist: