What
Add a Hopper TMA (cp.async.bulk) copy kernel in csrc/multidevice/tma_copy.cu and validate it across three source/destination memory types:
- local GMEM
- peer symmetric memory: TMA can write from local shared memory to remote global memory
- NVLS multicast pointers: using the multicast pointer as the destination of the TMA request broadcasts data to the whole NVL domain in one shot at line rate. Note, however, that this is not officially supported according to the CUDA documentation.
These behaviors are demonstrated through three unit tests in tests/cpp/test_multidevice_tma.cpp. The tests reuse the SymmetricTensor abstraction for VMM allocation, IPC handle exchange, and multicast setup, keeping the test bodies focused on the TMA transfer itself.
Why
The CUDA backend for multi-device communication (csrc/multidevice/cuda_p2p.cpp) currently uses SM-based copies (regular threads load/store or multimem.st) and copy-engine copies (cudaMemcpyAsync / cudaMemcpyBatchAsync). TMA offers a third transport option that is GPU-initiated, lightweight (single-thread issue), fully asynchronous, and frees SM resources for overlapping compute. This transport is leveraged by DeepEP for intra-node MoE dispatch. This PR validates that TMA works correctly on the memory types used by nvFuser's multi-device infrastructure.
This lays the groundwork for a follow-up PR that integrates TMA as a transport option for P2P and multicast communications alongside the existing SM-based copies and copy-engine transports.
How
The kernel is implemented in csrc/multidevice/tma_copy.cu. It is a single-warp kernel where thread 0 performs a two-phase TMA transfer through shared memory (GMEM(src) --[TMA load]--> SMEM --[TMA store]--> GMEM(dst)), using mbarrier for async completion tracking. TMA is a GMEM-SMEM engine — there is no GMEM-to-GMEM variant, so shared memory staging is inherent to the hardware.
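The two-phase flow can be sketched as follows. This is a minimal illustration, NOT the PR's actual tma_copy.cu: the kernel name, staging-buffer size, and single-thread launch are assumptions, and error handling is elided. It assumes SM90+ and that num_bytes is a multiple of 16 and fits in the 4 KB staging buffer.

```cuda
#include <cstdint>

extern "C" __global__ void tma_copy_1d_sketch(
    void* dst, const void* src, uint32_t num_bytes) {
  __shared__ alignas(128) char stage[4096]; // SMEM staging buffer (assumed size)
  __shared__ alignas(8) uint64_t mbar;      // mbarrier tracking the load phase

  if (threadIdx.x != 0) {
    return;
  }
  const uint32_t smem_stage =
      static_cast<uint32_t>(__cvta_generic_to_shared(stage));
  const uint32_t smem_mbar =
      static_cast<uint32_t>(__cvta_generic_to_shared(&mbar));

  // Initialize the mbarrier with an arrival count of 1 (this thread only).
  asm volatile("mbarrier.init.shared::cta.b64 [%0], 1;" ::"r"(smem_mbar));
  asm volatile("fence.mbarrier_init.release.cluster;");

  // Load phase: declare the expected async byte count, then issue the TMA load
  // GMEM(src) -> SMEM, with completion reported to the mbarrier.
  asm volatile(
      "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;" ::"r"(smem_mbar),
      "r"(num_bytes));
  asm volatile(
      "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes"
      " [%0], [%1], %2, [%3];" ::"r"(smem_stage),
      "l"(reinterpret_cast<uint64_t>(src)), "r"(num_bytes), "r"(smem_mbar)
      : "memory");

  // Spin until the mbarrier phase flips, i.e. the load landed in SMEM.
  uint32_t done = 0;
  while (done == 0) {
    asm volatile(
        "{.reg .pred p;"
        " mbarrier.try_wait.parity.shared::cta.b64 p, [%1], 0;"
        " selp.u32 %0, 1, 0, p;}"
        : "=r"(done)
        : "r"(smem_mbar));
  }

  // Store phase: SMEM -> GMEM(dst), tracked through a bulk async group.
  asm volatile(
      "cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"(
          reinterpret_cast<uint64_t>(dst)),
      "r"(smem_stage), "r"(num_bytes)
      : "memory");
  asm volatile("cp.async.bulk.commit_group;");
  asm volatile("cp.async.bulk.wait_group.read 0;" ::: "memory");
  asm volatile("mbarrier.inval.shared::cta.b64 [%0];" ::"r"(smem_mbar));
}
```

A kernel like this would be launched with a single thread (e.g. `tma_copy_1d_sketch<<<1, 1>>>(dst, src, bytes)`); the actual PR uses a single warp with thread 0 issuing the transfer.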
The kernel is compiled at runtime via NVRTC (the same pattern as the existing alltoallv.cu and multicast.cu kernels in cuda_p2p.cpp, and the other kernels in runtime/) and stringified at build time through the existing NVFUSER_RUNTIME_FILES pipeline.
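For reference, the NVRTC runtime-compilation pattern looks roughly like the sketch below. This is a hedged illustration under assumptions: the function name, entry-point name, and compile flags are hypothetical, and all error handling is elided; the real code paths in cuda_p2p.cpp differ.

```cuda
#include <cuda.h>
#include <nvrtc.h>
#include <vector>

// Hypothetical helper: compile a stringified kernel source to a CUfunction.
CUfunction compileTmaCopy(const char* kernel_src) {
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, kernel_src, "tma_copy.cu", 0, nullptr, nullptr);
  const char* opts[] = {"--gpu-architecture=sm_90", "--std=c++17"};
  nvrtcCompileProgram(prog, 2, opts); // error handling elided

  // Retrieve the compiled CUBIN and load it through the driver API.
  size_t cubin_size;
  nvrtcGetCUBINSize(prog, &cubin_size);
  std::vector<char> cubin(cubin_size);
  nvrtcGetCUBIN(prog, cubin.data());
  nvrtcDestroyProgram(&prog);

  CUmodule module;
  CUfunction func;
  cuModuleLoadData(&module, cubin.data());
  cuModuleGetFunction(&func, module, "tma_copy_1d");
  return func;
}
```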
The kernel assumes num_bytes > 0 and divisible by 16, but performs no runtime validation. The tests satisfy these constraints, but the kernel itself could be more defensive against invalid inputs to prevent undefined behavior.
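A cheap host-side guard could reject invalid sizes before launch. The sketch below is hypothetical: kMaxTmaBytes is an assumed staging-buffer size, and the function name is illustrative, not from the PR.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Assumed SMEM staging capacity; the real kernel's buffer size may differ.
constexpr int64_t kMaxTmaBytes = 4096;

// Pre-launch guard mirroring the kernel's assumptions: num_bytes must be
// positive, a multiple of 16 (cp.async.bulk requirement), and fit in SMEM.
inline void validateTmaCopySize(int64_t num_bytes) {
  if (num_bytes <= 0 || num_bytes % 16 != 0 || num_bytes > kMaxTmaBytes) {
    throw std::invalid_argument(
        "TMA copy size must be in (0, " + std::to_string(kMaxTmaBytes) +
        "] and divisible by 16, got " + std::to_string(num_bytes));
  }
}
```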
The kernel relies heavily on inline PTX assembly for TMA operations. While functionally correct, inline PTX lacks compile-time safety; additional validation or an abstraction layer would reduce the risk of subtle issues with register allocation or instruction encoding.
The multicast test is gated on CUDA_VERSION >= 13000 (CUDA 13.0), which limits testing coverage on older toolkits. Consider whether multicast functionality can be validated another way, or whether this limitation is acceptable for the current scope.
#if (CUDA_VERSION >= 13000)
// Verify TMA 1D bulk copy writing TO an NVLS multicast pointer.
// Root uses TMA to write data to the MC pointer, which broadcasts
// via NVLS hardware. All ranks then verify the data arrived by
// reading from their local UC view with a normal copy.
TEST_F(TmaTest, TmaMulticastWrite) {
if (communicator_->size() == 1) {
GTEST_SKIP() << "Skipping test for single device";
}
const int64_t rank = communicator_->deviceId();
const int64_t local_rank = communicator_->local_rank();
int major;
NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute(
&major, cudaDevAttrComputeCapabilityMajor, local_rank));
if (major < 9) {
GTEST_SKIP() << "Requires Hopper (SM90+)";
}
int is_multicast_supported;
NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute(
&is_multicast_supported,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED,
local_rank));
if (is_multicast_supported == 0) {
GTEST_SKIP() << "Device does not support Multicast Objects; skipping.";
}
constexpr int64_t kNumElems = 524288; // 2 MB / sizeof(int32_t)
constexpr int64_t root = 0;
// cp.async.bulk transfer size is limited by shared memory,
// so we broadcast a 4 KB slice via TMA.
constexpr int kTmaBytes = 4096;
static_assert(kTmaBytes % 16 == 0);
constexpr int kTmaElems = kTmaBytes / sizeof(int32_t);
at::Tensor local =
SymmetricTensor::allocate({kNumElems}, at::kInt, communicator_->device());
local.zero_();
SymmetricTensor sym(local);
sym.setupMulticast(root, "tma_mcast");
auto opts = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, local_rank);
// Root: TMA-write source data to MC pointer (NVLS broadcasts it)
if (rank == root) {
at::Tensor src = at::arange(kTmaElems, opts);
launchTmaCopy1D(sym.multicastPtr(), src.data_ptr(), kTmaBytes);
NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize());
}
communicator_->barrier();
// All ranks: verify data arrived via normal read of local UC tensor
at::Tensor readback = sym.localTensor().slice(0, 0, kTmaElems).clone();
at::Tensor expected = at::arange(kTmaElems, opts);
EXPECT_TRUE(readback.equal(expected))
<< "Rank " << rank << " did not receive multicast data written by TMA";
}
#endif // CUDA_VERSION >= 13000
Added Hopper TMA (cp.async.bulk) copy kernel for P2P and multicast communications. The kernel implements a two-phase transfer (GMEM→SMEM→GMEM) driven by a single thread with mbarrier synchronization, validated across local memory, peer VMM-mapped memory, and NVLS multicast pointers.
Kernel implementation in csrc/multidevice/tma_copy.cu uses inline PTX assembly for TMA operations with proper mbarrier synchronization between load and store phases
Tests validate three critical memory types: local device memory, inter-device P2P via VMM, and NVLS multicast writes
Build system correctly integrates kernel stringification via existing NVFUSER_RUNTIME_FILES pipeline
Code follows established patterns from existing runtime kernels (alltoallv.cu, multicast.cu)
Confidence Score: 5/5
This PR is safe to merge with minimal risk
The implementation is well-structured with comprehensive tests, proper synchronization primitives, thorough documentation, and follows established project patterns. The TMA kernel uses correct PTX assembly with proper mbarrier synchronization. Tests validate all target memory types with appropriate feature detection and graceful skips. No logic errors, security issues, or architectural concerns identified.
No files require special attention
Important Files Changed
Filename
Overview
CMakeLists.txt
Added test file registration, build dependencies, and runtime file for TMA kernel stringification
csrc/multidevice/tma_copy.cu
Implements single-warp TMA kernel with two-phase GMEM->SMEM->GMEM transfer using mbarrier synchronization
tests/cpp/test_multidevice_tma.cpp
Comprehensive tests validating TMA across local, peer P2P, and NVLS multicast memory types with NVRTC compilation
Sequence Diagram
sequenceDiagram
participant T0 as Thread 0
participant SMEM as Shared Memory
participant MBAR as MBarrier
participant GMEM_SRC as GMEM(src)
participant GMEM_DST as GMEM(dst)
T0->>MBAR: mbarrier.init(arrival_count=1)
T0->>T0: fence.mbarrier_init.release.cluster
T0->>T0: __syncwarp()
Note over T0,GMEM_SRC: TMA Load Phase
T0->>MBAR: mbarrier.arrive.expect_tx(num_bytes)
T0->>GMEM_SRC: cp.async.bulk (TMA load)
GMEM_SRC-->>SMEM: async data transfer
SMEM->>MBAR: complete_tx notification
T0->>MBAR: mbarrier.try_wait.parity (spin until done)
MBAR-->>T0: phase flip (load complete)
Note over T0,GMEM_DST: TMA Store Phase
T0->>SMEM: cp.async.bulk.global.shared
SMEM-->>GMEM_DST: async data transfer
T0->>T0: cp.async.bulk.commit_group
T0->>T0: cp.async.bulk.wait_group.read 0
T0->>MBAR: mbarrier.inval
Would you recommend using those codegen building blocks, or adding some, so it's easier for nvFuser to generate fused comm/gemm kernels in the future?