
Conversation


@denera denera commented Dec 2, 2025

Description

This PR adds support for the NVTE cuBlasMp bindings in the Comm+GEMM overlap API.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera self-assigned this Dec 2, 2025
@denera denera force-pushed the common/tp-overlap-cublasmp branch 2 times, most recently from 908bbc2 to 69cf235 on December 2, 2025 at 20:12

#define NVTE_COMM_OVERLAP_MAX_STREAMS 3

/* \brief Check if TE is built with cuBlasMp.
Member

nit: cuBLASMp

void ub_barrier(ExtComm comm);

int64_t get_nccl_comm_ptr(std::string comm_name) {
NVTE_CHECK(backend_is_nccl, "Cannot get ncclComm_t ptr if backend is not NCCL.");
Member

This error message could be more descriptive - e.g. something like "chosen backend for the communication-computation overlap (cuBLASMp) requires NCCL communicator, but the passed ProcessGroup uses a different backend."
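
A possible rewording along those lines (a sketch only: it reuses the NVTE_CHECK macro and the backend_is_nccl flag from the snippet above, and the exact message text is just a suggestion):

    // Hypothetical, more descriptive error text for the non-NCCL case.
    NVTE_CHECK(backend_is_nccl,
               "Comm+GEMM overlap backend (cuBLASMp) requires an NCCL communicator, "
               "but the given ProcessGroup uses a different backend.");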

@denera denera force-pushed the common/tp-overlap-cublasmp branch from 4596411 to b4ad546 on December 16, 2025 at 19:04
@denera denera marked this pull request as ready for review December 16, 2025 22:58

greptile-apps bot commented Dec 16, 2025

Greptile Summary

This PR extends the Comm+GEMM overlap API to support cuBlasMp as an alternative backend for overlapping collective communication operations (all-gather, reduce-scatter) with GEMM computations. The implementation adds new constructor paths that accept NCCL communicator pointers and delegate the overlap orchestration to cuBlasMp, while maintaining backward compatibility with the existing userbuffers-based implementation.
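
The backend selection reduces to a branch at the top of each overlap method. Below is a minimal standalone sketch of that control flow, not the real implementation: the member and method names (_with_cublasmp, cublasmp_gemm_rs) follow the key changes listed next, while the helper names, signatures, and bodies are placeholders.

    // Standalone sketch of the dispatch between the cuBLASMp and userbuffers backends.
    // Names mirror this review; signatures and bodies are illustrative only.
    struct CommOverlapSketch {
      bool _with_cublasmp = false;  // set by the new cuBLASMp constructor path

      void split_overlap_rs() {
        if (_with_cublasmp) {
          cublasmp_gemm_rs();     // cuBLASMp performs the GEMM and reduce-scatter internally
          return;
        }
        userbuffers_gemm_rs();    // existing path: split GEMM across streams, launch RS, sync
      }

      void cublasmp_gemm_rs() { /* would call nvte_gemm_reduce_scatter(_cublasmp_ctx, ...) */ }
      void userbuffers_gemm_rs() { /* existing manual orchestration */ }
    };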

Key changes:

  • Added new constructors in CommOverlapCore, CommOverlapBase, and CommOverlapP2PBase that initialize with cuBlasMp context instead of userbuffers
  • Implemented conditional execution paths: when _with_cublasmp is true, operations like atomic_gemm_overlap_ag() and split_overlap_rs() delegate to cuBlasMp via cublasmp_ag_gemm() and cublasmp_gemm_rs()
  • Extended PyTorch and JAX APIs with with_cublasmp/use_cublasmp flags that control which backend is used
  • Added comprehensive test coverage by parameterizing existing tests with use_cublasmp flag, effectively doubling the test matrix
  • Updated CMake build to search for libcublasmp.so.0 specifically

Observations:

  • The implementation properly guards cuBlasMp-only operations (e.g., bulk overlap raises an error with cuBlasMp, buffer filling is skipped)
  • Workspace allocation logic correctly recognizes that cuBlasMp manages its own workspaces
  • Destructor properly handles cleanup for both backends
  • The PR includes fixes from recent commits addressing function argument ordering and linting issues

Confidence Score: 4/5

  • This PR is safe to merge with minor attention needed for runtime validation of cuBlasMp-specific code paths
  • The implementation demonstrates solid software engineering with proper abstraction, backward compatibility, and comprehensive test coverage. The conditional execution paths are well-guarded with runtime checks. However, the score is 4 (not 5) because: (1) the cuBlasMp backend code paths can only be tested in environments where TE is compiled with NVTE_WITH_CUBLASMP=1, limiting pre-merge validation; (2) there's moderate architectural complexity in maintaining two parallel execution paths; (3) the PR includes several follow-up fix commits suggesting iterative refinement was needed
  • Pay close attention to transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp for the conditional execution logic, and transformer_engine/pytorch/module/base.py for the constructor instantiation patterns with cuBlasMp

Important Files Changed

  • transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp: Added cuBlasMp backend support with new constructor and conditional execution paths for comm+GEMM overlap operations
  • transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h: Added new constructors for cuBlasMp support, cuBlasMp context member variables, and helper methods for AG/RS GEMM operations
  • transformer_engine/pytorch/csrc/extensions.h: Added get_nccl_comm_ptr() method and new constructors for CommOverlap/CommOverlapP2P to support cuBlasMp initialization
  • transformer_engine/pytorch/module/base.py: Added with_cublasmp parameter and conditional instantiation of CommOverlap objects with cuBlasMp constructors; skips buffer filling when using cuBlasMp
  • transformer_engine/jax/csrc/extensions/cgemm_helper.cpp: Added use_cublasmp parameter to initialization and executor creation; creates CommOverlapP2PBase with cuBlasMp constructor when enabled
  • tests/pytorch/distributed/test_comm_gemm_overlap.py: Added use_cublasmp parameter to all test cases, effectively doubling test coverage to verify both backends

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant PyModule as PyTorch/JAX Module
    participant Helper as CommOverlapHelper
    participant Core as CommOverlapCore
    participant CBMp as cuBlasMp Context
    participant NCCL as NCCL Comm
    
    Note over User,NCCL: Initialization Phase
    User->>PyModule: initialize_ub(with_cublasmp=True)
    PyModule->>Helper: create CommOverlapHelper
    Helper->>NCCL: bootstrap NCCL communicator
    
    alt cuBlasMp Backend
        PyModule->>Helper: get_nccl_comm_ptr("intra")
        Helper-->>PyModule: nccl_comm_ptr
        PyModule->>Core: CommOverlapCore(nccl_comm_ptr, tp_rank, tp_size)
        Core->>CBMp: nvte_comm_gemm_ctx_create()
        CBMp-->>Core: _cublasmp_ctx
        Note over Core: Set _with_cublasmp=true<br/>_algo_type based on atomic/p2p
    else Traditional Backend
        PyModule->>Core: CommOverlapCore(buffer_shape, dtype, helper, ...)
        Core->>Core: initialize userbuffers
        Note over Core: Create streams, events<br/>allocate counter buffer
    end
    
    Note over User,NCCL: Execution Phase (All-Gather + GEMM)
    User->>PyModule: forward/backward pass
    PyModule->>Core: atomic_gemm_overlap_ag(A, B, D, ...)
    
    alt cuBlasMp Path
        Core->>Core: cublasmp_ag_gemm()
        Core->>CBMp: nvte_all_gather_gemm(_cublasmp_ctx, ...)
        Note over CBMp: cuBlasMp handles:<br/>- Input copy to buffer<br/>- All-gather collective<br/>- GEMM computation<br/>- All orchestrated internally
        CBMp-->>Core: result in D
    else Traditional Path
        Core->>Core: split_overlap_ag()
        Note over Core: Manual orchestration:<br/>1. Copy B to userbuffer<br/>2. Split GEMM across streams<br/>3. Launch P2P/multicast AG<br/>4. Sync and accumulate
        Core-->>Core: result in D
    end
    
    Core-->>PyModule: completed
    PyModule-->>User: output tensor
    
    Note over User,NCCL: Execution Phase (GEMM + Reduce-Scatter)
    User->>PyModule: forward/backward pass
    PyModule->>Core: atomic_gemm_overlap_rs(A, B, D, ...)
    
    alt cuBlasMp Path
        Core->>Core: cublasmp_gemm_rs()
        Core->>CBMp: nvte_gemm_reduce_scatter(_cublasmp_ctx, ...)
        Note over CBMp: cuBlasMp handles:<br/>- GEMM computation<br/>- Reduce-scatter collective<br/>- Output to final buffer
        CBMp-->>Core: result in D
    else Traditional Path
        Core->>Core: split_overlap_rs()
        Note over Core: Manual orchestration:<br/>1. Split GEMM computation<br/>2. Launch reduce-scatter<br/>3. Sync streams
        Core-->>Core: result in D
    end
    
    Core-->>PyModule: completed
    PyModule-->>User: output tensor
    
    Note over User,NCCL: Cleanup Phase
    User->>PyModule: module cleanup
    PyModule->>Core: ~CommOverlapCore()
    
    alt cuBlasMp Backend
        Core->>CBMp: nvte_comm_gemm_ctx_destroy()
    else Traditional Backend
        Core->>Core: destroy streams, events
        Core->>NCCL: destroy_communicator()
    end


@greptile-apps greptile-apps bot left a comment


Additional Comments (8)

  1. transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 335 (link)

    logic: Variable shadowing bug: k is assigned k * _tp_size where k appears on both sides. Should be k = k_local * _tp_size.

  2. transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, line 135 (link)

    logic: Invalid reinterpret_cast: cannot cast an int* (pointer) to int (value). Should be reinterpret_cast<void**>(&handler._device_barrier).

  3. transformer_engine/pytorch/csrc/extensions.h, line 517 (link)

    syntax: Stray character 'a' that will cause a compilation failure.

  4. transformer_engine/pytorch/csrc/extensions.h, line 537-540 (link)

    logic: Constructor parameter mismatch: CommOverlapBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but is called with (nccl_comm_ptr, tp_size, tp_rank, ...); the order of tp_rank and tp_size is swapped (a short illustration follows this comment list).

  5. transformer_engine/pytorch/csrc/extensions.h, line 563-566 (link)

    logic: Constructor parameter mismatch: CommOverlapP2PBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_rank, ...). Order of tp_rank and tp_size is swapped.

  6. transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, line 196-199 (link)

    logic: Constructor parameter mismatch: CommOverlapP2PBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_domain_id, ...). Should use tp_rank instead of tp_domain_id.

    Should this use tp_rank (local_device_id_within_tp_domain) or node_id (tp_domain_id)?

  7. tests/pytorch/distributed/run_gemm_with_overlap.py, line 416-418 (link)

    style: Unconditional initialization: local_kernel2_t_shape is only used when ub_obj2 is not None, but it's always initialized here. This creates an unused variable in most cases.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  8. tests/pytorch/distributed/run_gemm_with_overlap.py, line 490 (link)

    style: Incorrect initialization: ref2_g should only be assigned when ub_obj2 is not None, but it's unconditionally assigned to (0,) which doesn't match the expected tensor type.
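
For comments 4-6, the hazard is that tp_rank and tp_size are both plain integers, so a transposed call compiles cleanly and only misbehaves at run time. A small self-contained illustration (comm_overlap_init is a hypothetical stand-in for the real constructors):

    #include <cstdio>

    // Stand-in for the CommOverlapBase/CommOverlapP2PBase constructors: (comm, tp_rank, tp_size).
    void comm_overlap_init(long nccl_comm_ptr, int tp_rank, int tp_size) {
      (void)nccl_comm_ptr;  // placeholder for the NCCL communicator pointer
      std::printf("initializing rank %d of %d\n", tp_rank, tp_size);
    }

    int main() {
      int tp_rank = 3, tp_size = 8;
      comm_overlap_init(0, tp_size, tp_rank);  // transposed: "initializing rank 8 of 3"
      comm_overlap_init(0, tp_rank, tp_size);  // correct:    "initializing rank 3 of 8"
      return 0;
    }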

18 files reviewed, 8 comments


@denera denera force-pushed the common/tp-overlap-cublasmp branch from 147036f to c5471f8 on December 17, 2025 at 02:15
@denera denera force-pushed the common/tp-overlap-cublasmp branch from c5471f8 to d79bf21 on December 17, 2025 at 02:16

@greptile-apps greptile-apps bot left a comment


Additional Comments (3)

  1. transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 87 (link)

    logic: The return value from nvte_comm_gemm_ctx_create() is not assigned to _cublasmp_ctx, so the destructor's call to nvte_comm_gemm_ctx_destroy(_cublasmp_ctx) on line 173 operates on a null pointer (see the constructor sketch after this list).

  2. transformer_engine/jax/cpp_extensions/gemm.py, line 819 (link)

    style: parameter order changed - collective_op moved after transpose_batch_sequence, sequence_dim, is_outer. Verify this matches the calling convention and doesn't break compatibility. Was this parameter reordering intentional, and have all call sites been verified?

  3. transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 89-91 (link)

    logic: The _tp_size field is not initialized in the cuBlasMp constructor, but it is used in the cublasmp_ag_gemm() and cublasmp_gemm_rs() methods (lines 321, 332). Add the missing initialization.
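
Both fixes amount to having the cuBLASMp constructor store what later code relies on. A rough sketch under the names used in these comments (_cublasmp_ctx, _tp_size, _with_cublasmp, nvte_comm_gemm_ctx_create); the parameter list and the create call's arguments are abbreviated, not the real signatures:

    // Sketch only: keep the created context and the TP geometry so that
    // cublasmp_ag_gemm()/cublasmp_gemm_rs() and the destructor can use them.
    CommOverlapCore::CommOverlapCore(int64_t nccl_comm_ptr, int tp_rank, int tp_size /*, ... */) {
      _tp_size = tp_size;  // previously left uninitialized
      _with_cublasmp = true;
      _cublasmp_ctx = nvte_comm_gemm_ctx_create(/* nccl comm, tp_rank, tp_size, ... */);
      // ~CommOverlapCore() can then safely call nvte_comm_gemm_ctx_destroy(_cublasmp_ctx).
    }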

18 files reviewed, 3 comments


Signed-off-by: Alp Dener <[email protected]>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from 364b416 to ee517d3 on December 17, 2025 at 02:50

@greptile-apps greptile-apps bot left a comment


Additional Comments (6)

  1. transformer_engine/pytorch/module/base.py, line 415-417 (link)

    logic: Parameter order is incorrect - the C++ constructor signature is (helper, tp_rank, tp_size, ...) but Python is passing (helper, tp_size, local_rank, ...). This swaps tp_rank and tp_size, causing incorrect initialization.

  2. transformer_engine/pytorch/module/base.py, line 387-389 (link)

    logic: Parameter order is incorrect - the C++ constructor signature is (helper, tp_rank, tp_size, ...) but Python is passing (helper, tp_size, local_rank, ...). This swaps tp_rank and tp_size, causing incorrect initialization.

  3. tests/pytorch/distributed/run_gemm_with_overlap.py, line 340-344 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

  4. tests/pytorch/distributed/run_gemm_with_overlap.py, line 355-359 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

  5. tests/pytorch/distributed/run_gemm_with_overlap.py, line 383 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

  6. tests/pytorch/distributed/run_gemm_with_overlap.py, line 394 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

19 files reviewed, 6 comments


@denera denera force-pushed the common/tp-overlap-cublasmp branch from 5cb8204 to 51b64fb on December 17, 2025 at 03:36
target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include)
find_library(CUBLASMP_LIB
-  NAMES cublasmp libcublasmp
+  NAMES cublasmp libcublasmp.so.0
Member

Do we need to be this specific?

)
assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}"
global collective_gemm_with_cublasmp
collective_gemm_with_cublasmp = use_cublasmp
Member

Hmm, so that means that one either uses cublasMp for everything or for nothing?
