[Common] Comm+GEMM overlap API updated to support cuBlasMp backend (incl. framework API) #2443
base: main
Conversation
Force-pushed from 908bbc2 to 69cf235
`#define NVTE_COMM_OVERLAP_MAX_STREAMS 3`
`/* \brief Check if TE is built with cuBlasMp.`
nit: cuBLASMp
`void ub_barrier(ExtComm comm);`
`int64_t get_nccl_comm_ptr(std::string comm_name) {`
`  NVTE_CHECK(backend_is_nccl, "Cannot get nccComm_t ptr if backend is not NCCL.");`
This error message could be more descriptive - e.g. something like "chosen backend for the communication-computation overlap (cuBLASMp) requires NCCL communicator, but the passed ProcessGroup uses a different backend."
Force-pushed from 4596411 to b4ad546
Greptile Summary

This PR extends the Comm+GEMM overlap API to support cuBlasMp as an alternative backend for overlapping collective communication operations (all-gather, reduce-scatter) with GEMM computations. The implementation adds new constructor paths that accept NCCL communicator pointers and delegate the overlap orchestration to cuBlasMp, while maintaining backward compatibility with the existing userbuffers-based implementation. Key changes:
Observations:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User as User Code
participant PyModule as PyTorch/JAX Module
participant Helper as CommOverlapHelper
participant Core as CommOverlapCore
participant CBMp as cuBlasMp Context
participant NCCL as NCCL Comm
Note over User,NCCL: Initialization Phase
User->>PyModule: initialize_ub(with_cublasmp=True)
PyModule->>Helper: create CommOverlapHelper
Helper->>NCCL: bootstrap NCCL communicator
alt cuBlasMp Backend
PyModule->>Helper: get_nccl_comm_ptr("intra")
Helper-->>PyModule: nccl_comm_ptr
PyModule->>Core: CommOverlapCore(nccl_comm_ptr, tp_rank, tp_size)
Core->>CBMp: nvte_comm_gemm_ctx_create()
CBMp-->>Core: _cublasmp_ctx
Note over Core: Set _with_cublasmp=true<br/>_algo_type based on atomic/p2p
else Traditional Backend
PyModule->>Core: CommOverlapCore(buffer_shape, dtype, helper, ...)
Core->>Core: initialize userbuffers
Note over Core: Create streams, events<br/>allocate counter buffer
end
Note over User,NCCL: Execution Phase (All-Gather + GEMM)
User->>PyModule: forward/backward pass
PyModule->>Core: atomic_gemm_overlap_ag(A, B, D, ...)
alt cuBlasMp Path
Core->>Core: cublasmp_ag_gemm()
Core->>CBMp: nvte_all_gather_gemm(_cublasmp_ctx, ...)
Note over CBMp: cuBlasMp handles:<br/>- Input copy to buffer<br/>- All-gather collective<br/>- GEMM computation<br/>- All orchestrated internally
CBMp-->>Core: result in D
else Traditional Path
Core->>Core: split_overlap_ag()
Note over Core: Manual orchestration:<br/>1. Copy B to userbuffer<br/>2. Split GEMM across streams<br/>3. Launch P2P/multicast AG<br/>4. Sync and accumulate
Core-->>Core: result in D
end
Core-->>PyModule: completed
PyModule-->>User: output tensor
Note over User,NCCL: Execution Phase (GEMM + Reduce-Scatter)
User->>PyModule: forward/backward pass
PyModule->>Core: atomic_gemm_overlap_rs(A, B, D, ...)
alt cuBlasMp Path
Core->>Core: cublasmp_gemm_rs()
Core->>CBMp: nvte_gemm_reduce_scatter(_cublasmp_ctx, ...)
Note over CBMp: cuBlasMp handles:<br/>- GEMM computation<br/>- Reduce-scatter collective<br/>- Output to final buffer
CBMp-->>Core: result in D
else Traditional Path
Core->>Core: split_overlap_rs()
Note over Core: Manual orchestration:<br/>1. Split GEMM computation<br/>2. Launch reduce-scatter<br/>3. Sync streams
Core-->>Core: result in D
end
Core-->>PyModule: completed
PyModule-->>User: output tensor
Note over User,NCCL: Cleanup Phase
User->>PyModule: module cleanup
PyModule->>Core: ~CommOverlapCore()
alt cuBlasMp Backend
Core->>CBMp: nvte_comm_gemm_ctx_destroy()
else Traditional Backend
Core->>Core: destroy streams, events
Core->>NCCL: destroy_communicator()
end
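For orientation, the sketch below shows how the cuBlasMp path in this diagram might be driven from PyTorch. The `with_cublasmp=True` flag is taken from the diagram; the module path (`te.module.base.initialize_ub`), the remaining argument names, and the buffer shape mirror the pre-existing userbuffers API and are assumptions about this PR rather than its exact interface.

```python
# Hypothetical usage sketch (launch with torchrun, one process per GPU).
# Assumes TE was built with cuBlasMp support; only with_cublasmp comes from the
# diagram above, everything else follows the existing initialize_ub API.
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
tp_size = dist.get_world_size()

# Route the comm+GEMM overlap through cuBlasMp instead of userbuffers. Per the
# diagram, the helper hands its NCCL communicator pointer (get_nccl_comm_ptr)
# to CommOverlapCore, which creates a context via nvte_comm_gemm_ctx_create().
te.module.base.initialize_ub(
    [4096, 8192],              # [sequence * batch, hidden] buffer shape (illustrative)
    tp_size,
    use_fp8=False,
    dtype=torch.bfloat16,
    bootstrap_backend="nccl",  # cuBlasMp needs an NCCL communicator
    with_cublasmp=True,        # new flag from this PR, per the sequence diagram
)

# TE layers configured with ub_overlap_* options now dispatch their AG+GEMM and
# GEMM+RS overlaps to nvte_all_gather_gemm / nvte_gemm_reduce_scatter.

te.module.base.destroy_ub()  # tears the context down (nvte_comm_gemm_ctx_destroy)
```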
Additional Comments (8)
- transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 335 (link) logic: Variable shadowing bug: `k` is assigned `k * _tp_size`, where `k` appears on both sides. Should be `k = k_local * _tp_size`.
- transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, line 135 (link) logic: Invalid reinterpret_cast: cannot cast an `int*` (pointer) to `int` (value). Should be `reinterpret_cast<void**>(&handler._device_barrier)`.
- transformer_engine/pytorch/csrc/extensions.h, line 517 (link) syntax: Stray character `a` that will cause a compilation failure.
- transformer_engine/pytorch/csrc/extensions.h, lines 537-540 (link) logic: Constructor parameter mismatch: the `CommOverlapBase` constructor expects `(nccl_comm_ptr, tp_rank, tp_size, ...)` but is called with `(nccl_comm_ptr, tp_size, tp_rank, ...)`. The order of `tp_rank` and `tp_size` is swapped.
- transformer_engine/pytorch/csrc/extensions.h, lines 563-566 (link) logic: Constructor parameter mismatch: the `CommOverlapP2PBase` constructor expects `(nccl_comm_ptr, tp_rank, tp_size, ...)` but is called with `(nccl_comm_ptr, tp_size, tp_rank, ...)`. The order of `tp_rank` and `tp_size` is swapped.
- transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, lines 196-199 (link) logic: Constructor parameter mismatch: the `CommOverlapP2PBase` constructor expects `(nccl_comm_ptr, tp_rank, tp_size, ...)` but is called with `(nccl_comm_ptr, tp_size, tp_domain_id, ...)`. Should use `tp_rank` instead of `tp_domain_id`. Should this use tp_rank (local_device_id_within_tp_domain) or node_id (tp_domain_id)?
- tests/pytorch/distributed/run_gemm_with_overlap.py, lines 416-418 (link) style: Unconditional initialization: `local_kernel2_t_shape` is only used when `ub_obj2 is not None`, but it is always initialized here, creating an unused variable in most cases (a sketch of the suggested guard follows this comment list). Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
- tests/pytorch/distributed/run_gemm_with_overlap.py, line 490 (link) style: Incorrect initialization: `ref2_g` should only be assigned when `ub_obj2 is not None`, but it is unconditionally assigned `(0,)`, which doesn't match the expected tensor type.
18 files reviewed, 8 comments
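For the two style comments above, here is a small self-contained sketch of the guarded initialization being suggested. The names `ub_obj2`, `local_kernel2_t_shape`, and `ref2_g` come from the review; the shape and dtype are placeholders rather than values from the test script.

```python
# Sketch of the conditional setup suggested for run_gemm_with_overlap.py.
# ub_obj2 stands in for the test's optional second overlap object.
import torch

ub_obj2 = None  # would be a second CommOverlap object when a second GEMM is exercised

local_kernel2_t_shape = None
ref2_g = None
if ub_obj2 is not None:
    local_kernel2_t_shape = (128, 256)  # placeholder shape
    ref2_g = torch.zeros(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda")
```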
Force-pushed from 147036f to c5471f8
…rk extensions Signed-off-by: Alp Dener <[email protected]>
…entirely Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
Force-pushed from c5471f8 to d79bf21
Additional Comments (3)
- transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 87 (link) logic: the return value from `nvte_comm_gemm_ctx_create()` is not assigned to `_cublasmp_ctx`, causing a null pointer when the destructor calls `nvte_comm_gemm_ctx_destroy(_cublasmp_ctx)` on line 173.
- transformer_engine/jax/cpp_extensions/gemm.py, line 819 (link) style: parameter order changed: `collective_op` moved after `transpose_batch_sequence`, `sequence_dim`, and `is_outer`. Verify this matches the calling convention and doesn't break compatibility. Was this parameter reordering intentional, and have all call sites been verified?
- transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, lines 89-91 (link) logic: the `_tp_size` field is not initialized in the cuBlasMp constructor, but it is used in the `cublasmp_ag_gemm()` and `cublasmp_gemm_rs()` methods (lines 321 and 332). Add initialization.
18 files reviewed, 3 comments
Signed-off-by: Alp Dener <[email protected]>
Force-pushed from 364b416 to ee517d3
Additional Comments (6)
- transformer_engine/pytorch/module/base.py, lines 415-417 (link) logic: Parameter order is incorrect: the C++ constructor signature is `(helper, tp_rank, tp_size, ...)` but Python is passing `(helper, tp_size, local_rank, ...)`. This swaps `tp_rank` and `tp_size`, causing incorrect initialization (a short sketch of the intended order follows this comment list).
- transformer_engine/pytorch/module/base.py, lines 387-389 (link) logic: Parameter order is incorrect: the C++ constructor signature is `(helper, tp_rank, tp_size, ...)` but Python is passing `(helper, tp_size, local_rank, ...)`. This swaps `tp_rank` and `tp_size`, causing incorrect initialization.
- tests/pytorch/distributed/run_gemm_with_overlap.py, lines 340-344 (link) logic: Parameter order is incorrect: the C++ signature is `(helper, tp_rank, tp_size, ...)` but the call passes `(helper, tp_size, tp_rank, ...)`. Swap the second and third parameters.
- tests/pytorch/distributed/run_gemm_with_overlap.py, lines 355-359 (link) logic: Parameter order is incorrect: the C++ signature is `(helper, tp_rank, tp_size, ...)` but the call passes `(helper, tp_size, tp_rank, ...)`. Swap the second and third parameters.
- tests/pytorch/distributed/run_gemm_with_overlap.py, line 383 (link) logic: Parameter order is incorrect: the C++ signature is `(helper, tp_rank, tp_size, ...)` but the call passes `(helper, tp_size, tp_rank, ...)`. Swap the second and third parameters.
- tests/pytorch/distributed/run_gemm_with_overlap.py, line 394 (link) logic: Parameter order is incorrect: the C++ signature is `(helper, tp_rank, tp_size, ...)` but the call passes `(helper, tp_size, tp_rank, ...)`. Swap the second and third parameters.
19 files reviewed, 6 comments
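All six comments point at the same positional swap. The toy class below is a stand-in, not the real `CommOverlap` binding; it only illustrates why the swap slips through positionally and how passing `tp_rank`/`tp_size` by keyword at the Python call sites would surface it immediately.

```python
# Toy stand-in for the overlap constructor (NOT the real TE binding), showing
# the (tp_rank, tp_size) swap flagged above and a keyword-argument guard.
class ToyCommOverlap:
    def __init__(self, helper, tp_rank: int, tp_size: int):
        # A valid rank must lie inside the tensor-parallel group.
        if not 0 <= tp_rank < tp_size:
            raise ValueError(f"tp_rank={tp_rank} is not in [0, tp_size={tp_size})")
        self.helper, self.tp_rank, self.tp_size = helper, tp_rank, tp_size


helper, tp_rank, tp_size = object(), 3, 8

ToyCommOverlap(helper, tp_rank=tp_rank, tp_size=tp_size)  # keyword form is order-proof

try:
    ToyCommOverlap(helper, tp_size, tp_rank)  # the positional swap the review flags
except ValueError as err:
    print("caught swapped arguments:", err)
```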
Signed-off-by: Alp Dener <[email protected]>
Force-pushed from 5cb8204 to 51b64fb
for more information, see https://pre-commit.ci
`target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include)`
`find_library(CUBLASMP_LIB`
`- NAMES cublasmp libcublasmp`
`+ NAMES cublasmp libcublasmp.so.0`
Do we need to be this specific?
`)`
`assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}"`
`global collective_gemm_with_cublasmp`
`collective_gemm_with_cublasmp = use_cublasmp`
Hmm, so that means that one either uses cublasMp for everything or for nothing?
Description
This PR adds support for the NVTE cuBlasMp bindings in the Comm+GEMM overlap API.
Type of change
Checklist: