[Multidevice] Tma bulk copy p2p runtime examples #6011

Open

samnordmann wants to merge 5 commits into main from tma_p2p

Conversation

@samnordmann
Collaborator

@samnordmann samnordmann commented Feb 25, 2026

What

Add a Hopper TMA (cp.async.bulk) copy kernel in csrc/multidevice/tma_copy.cu and validate it across three memory source/destination types:

  • local GMEM
  • peer symmetric memory: TMA can write from local shared memory to a peer's global memory.
  • NVLS multicast pointers: using the multicast pointer as the destination of the TMA store, data can be broadcast to the whole NVL domain in one shot at line rate. Note, however, that this is not officially supported according to the CUDA documentation.

These behaviors are demonstrated through three unit tests in tests/cpp/test_multidevice_tma.cpp. The tests reuse the SymmetricTensor abstraction for VMM allocation, IPC handle exchange, and multicast setup, keeping the test bodies focused on the TMA transfer itself.

Why

The CUDA backend for multi-device communication (csrc/multidevice/cuda_p2p.cpp) currently uses SM-based copies (regular threads load/store or multimem.st) and copy-engine copies (cudaMemcpyAsync / cudaMemcpyBatchAsync). TMA offers a third transport option that is GPU-initiated, lightweight (single-thread issue), fully asynchronous, and frees SM resources for overlapping compute. This transport is leveraged by DeepEP for intra-node MoE dispatch. This PR validates that TMA works correctly on the memory types used by nvFuser's multi-device infrastructure.

This lays the groundwork for a follow-up PR that integrates TMA as a transport option for P2P and multicast communications alongside the existing SM-based copies and copy-engine transports.

How

  • The kernel is implemented in csrc/multidevice/tma_copy.cu. It is a single-warp kernel in which thread 0 performs a two-phase TMA transfer through shared memory (GMEM(src) --[TMA load]--> SMEM --[TMA store]--> GMEM(dst)), using an mbarrier for async completion tracking. TMA is a GMEM-SMEM engine; there is no GMEM-to-GMEM variant, so shared-memory staging is inherent to the hardware.
  • The kernel is compiled at runtime via NVRTC (same pattern as the existing alltoallv.cu, multicast.cu kernels in cuda_p2p.cpp, and other kernels in runtime/) and stringified at build time through the existing NVFUSER_RUNTIME_FILES pipeline.
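For context, the runtime-compilation path described above can be sketched at the host level roughly as follows. This is a hedged sketch, not the PR's actual code: the helper name `launchTmaCopy1D` appears in the tests, but this body, the `kTmaCopySource` placeholder, the architecture flag, and the shared-memory padding constant are all assumptions.

```cuda
// Minimal host-side sketch of NVRTC compile-and-launch for a runtime
// kernel string (illustrative; assumes an initialized CUDA driver
// context, and error handling is elided).
#include <cuda.h>
#include <nvrtc.h>
#include <string>

void launchTmaCopy1D(CUdeviceptr dst, CUdeviceptr src, int num_bytes) {
  static const char* kTmaCopySource = "..."; // stringified tma_copy.cu

  // Compile the kernel string to PTX at runtime.
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, kTmaCopySource, "tma_copy.cu", 0, nullptr, nullptr);
  const char* opts[] = {"--gpu-architecture=compute_90"};
  nvrtcCompileProgram(prog, 1, opts);
  size_t ptx_size = 0;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::string ptx(ptx_size, '\0');
  nvrtcGetPTX(prog, &ptx[0]);
  nvrtcDestroyProgram(&prog);

  // Load the PTX and fetch the kernel entry point.
  CUmodule mod;
  CUfunction fn;
  cuModuleLoadData(&mod, ptx.c_str());
  cuModuleGetFunction(&fn, mod, "tma_copy_1d");

  // Dynamic shared memory: staging buffer plus room for the 8-byte
  // mbarrier (the exact padding amount here is an assumption).
  unsigned smem_bytes = static_cast<unsigned>(num_bytes) + 16;
  cuFuncSetAttribute(
      fn, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_bytes);

  // Single warp; thread 0 drives both TMA phases.
  void* args[] = {&dst, &src, &num_bytes};
  cuLaunchKernel(fn, 1, 1, 1, 32, 1, 1, smem_bytes, /*stream=*/nullptr,
                 args, /*extra=*/nullptr);
}
```

The same pattern (NVRTC to PTX, then driver-API module load and launch) is what the existing alltoallv.cu and multicast.cu kernels use, per the description above.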

@github-actions

github-actions bot commented Feb 25, 2026

Review updated until commit ae0c760

Description

  • Add Hopper TMA copy kernel using cp.async.bulk for GMEM->SMEM->GMEM transfers

  • Implement three test scenarios: local GMEM, peer symmetric memory, and NVLS multicast

  • Use NVRTC runtime compilation with dynamic shared memory and mbarrier synchronization

  • Integrate with existing SymmetricTensor infrastructure for VMM and multicast setup

Changes walkthrough

Relevant files

Tests
  tests/cpp/test_multidevice_tma.cpp: TMA copy kernel tests for multidevice scenarios
  • Add comprehensive TMA copy tests with NVRTC runtime compilation
  • Test local GMEM copy, peer device memory, and NVLS multicast scenarios
  • Include SM90+ capability checks and proper error handling
  • Reuse SymmetricTensor abstraction for VMM and multicast setup
  • +271/-0

Enhancement
  csrc/multidevice/tma_copy.cu: Hopper TMA bulk copy kernel implementation
  • Implement single-warp TMA kernel with thread 0 driving transfers
  • Use two-phase GMEM->SMEM->GMEM copy with mbarrier synchronization
  • Handle dynamic shared memory allocation and alignment
  • Include inline PTX assembly for TMA operations and barriers
  • +101/-0

Configuration changes
  CMakeLists.txt: Build configuration for TMA tests and resources
  • Add test_multidevice_tma.cpp to multidevice test sources
  • Include nvfuser_rt_tma_copy dependency for runtime resources
  • Add binary directory include path for generated headers
  • +5/-0

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Kernel Robustness

The kernel assumes num_bytes is positive and divisible by 16, but performs no runtime validation. While the tests only pass valid sizes, the kernel itself could be more defensive against invalid inputs to prevent undefined behavior.

    extern "C" __global__ void __launch_bounds__(32, 1) tma_copy_1d(
        void* __restrict__ dst,
        const void* __restrict__ src,
        int num_bytes) {
      extern __shared__ __align__(128) unsigned char smem[];
    
      unsigned long long* mbar =
          reinterpret_cast<unsigned long long*>(smem + num_bytes);
      unsigned int smem_addr =
          static_cast<unsigned int>(__cvta_generic_to_shared(smem));
      unsigned int mbar_addr =
          static_cast<unsigned int>(__cvta_generic_to_shared(mbar));
    
      if (threadIdx.x == 0) {
        asm volatile(
            "mbarrier.init.shared::cta.b64 [%0], %1;" ::"r"(mbar_addr), "r"(1));
        asm volatile("fence.mbarrier_init.release.cluster;" :::);
      }
      __syncwarp();
    
      if (threadIdx.x == 0) {
        // Announce expected transaction bytes on the mbarrier
        asm volatile(
            "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;" ::"r"(
                mbar_addr),
            "r"(num_bytes));
    
        // TMA Load: GMEM -> SMEM (async, completed via mbarrier)
        asm volatile(
            "cp.async.bulk.shared::cluster.global"
            ".mbarrier::complete_tx::bytes"
            " [%0], [%1], %2, [%3];\n" ::"r"(smem_addr),
            "l"(src),
            "r"(num_bytes),
            "r"(mbar_addr)
            : "memory");
    
        // Block until the mbarrier phase flips (TMA load completed)
        asm volatile(
            "{\n"
            ".reg .pred P1;\n"
            "TMA_COPY_WAIT_LOAD:\n"
            "mbarrier.try_wait.parity.shared::cta.b64"
            " P1, [%0], %1;\n"
            "@P1 bra TMA_COPY_LOAD_DONE;\n"
            "bra TMA_COPY_WAIT_LOAD;\n"
            "TMA_COPY_LOAD_DONE:\n"
            "}" ::"r"(mbar_addr),
            "r"(0));
    
        // TMA Store: SMEM -> GMEM
        asm volatile(
            "cp.async.bulk.global.shared::cta.bulk_group"
            " [%0], [%1], %2;\n" ::"l"(dst),
            "r"(smem_addr),
            "r"(num_bytes)
            : "memory");
        asm volatile("cp.async.bulk.commit_group;");
        asm volatile("cp.async.bulk.wait_group.read 0;" ::: "memory");
    
        asm volatile("mbarrier.inval.shared::cta.b64 [%0];" ::"r"(mbar_addr));
      }
    }
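One way to address the robustness concern above (a hedged sketch, not code from this PR) is an early-out guard at the top of the kernel, since cp.async.bulk requires the transfer size to be a non-zero multiple of 16 bytes:

```cuda
// Possible defensive prologue (illustrative only): bail out on sizes the
// cp.async.bulk path cannot handle, instead of issuing an invalid transfer.
// The name tma_copy_1d_checked is hypothetical.
extern "C" __global__ void __launch_bounds__(32, 1) tma_copy_1d_checked(
    void* __restrict__ dst,
    const void* __restrict__ src,
    int num_bytes) {
  if (num_bytes <= 0 || (num_bytes % 16) != 0) {
    return; // or __trap() to surface the invalid input loudly
  }
  // ... same two-phase TMA body as tma_copy_1d above ...
}
```

Whether to trap or silently return on invalid input is a policy choice; a trap makes misuse visible in testing, while an early return keeps production behavior benign.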
Inline Assembly Safety

The kernel relies heavily on inline PTX assembly for the TMA operations. While functionally correct, this approach lacks compile-time safety and could benefit from additional validation or abstraction to prevent issues with register constraints or instruction encoding.
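One lightweight mitigation (a sketch of the kind of abstraction this comment suggests, not code from the PR) is to wrap each PTX fragment in a named `__device__` helper, so operand constraints live next to the instruction they belong to and call sites read as ordinary function calls:

```cuda
// Illustrative wrappers around the mbarrier PTX used by the kernel.
// Names and exact signatures are assumptions.
__device__ __forceinline__ void mbarrier_init(
    unsigned mbar_addr, unsigned arrival_count) {
  asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" ::"r"(mbar_addr),
               "r"(arrival_count));
}

__device__ __forceinline__ void mbarrier_arrive_expect_tx(
    unsigned mbar_addr, unsigned tx_bytes) {
  asm volatile(
      "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;" ::"r"(mbar_addr),
      "r"(tx_bytes));
}

__device__ __forceinline__ void mbarrier_wait_parity(
    unsigned mbar_addr, unsigned phase) {
  // Spin until the barrier phase flips; the predicate result is
  // materialized into a register so the loop can live in C++.
  unsigned done = 0;
  while (!done) {
    asm volatile(
        "{.reg .pred P1;\n"
        " mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2;\n"
        " selp.u32 %0, 1, 0, P1;}\n"
        : "=r"(done)
        : "r"(mbar_addr), "r"(phase));
  }
}
```

With helpers like these, the kernel body shrinks to a sequence of named operations, and a constraint mistake has to be fixed in exactly one place.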

Test Coverage Limitations

The multicast test is gated on CUDA_VERSION >= 13000, which may limit coverage on older toolkits. Consider whether there are alternative ways to validate multicast functionality, or whether this limitation is acceptable for the current scope.

    #if (CUDA_VERSION >= 13000)
    
    // Verify TMA 1D bulk copy writing TO an NVLS multicast pointer.
    // Root uses TMA to write data to the MC pointer, which broadcasts
    // via NVLS hardware. All ranks then verify the data arrived by
    // reading from their local UC view with a normal copy.
    TEST_F(TmaTest, TmaMulticastWrite) {
      if (communicator_->size() == 1) {
        GTEST_SKIP() << "Skipping test for single device";
      }
    
      const int64_t rank = communicator_->deviceId();
      const int64_t local_rank = communicator_->local_rank();
    
      int major;
      NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceGetAttribute(
          &major, cudaDevAttrComputeCapabilityMajor, local_rank));
      if (major < 9) {
        GTEST_SKIP() << "Requires Hopper (SM90+)";
      }
    
      int is_multicast_supported;
      NVFUSER_CUDA_SAFE_CALL(cuDeviceGetAttribute(
          &is_multicast_supported,
          CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED,
          local_rank));
      if (is_multicast_supported == 0) {
        GTEST_SKIP() << "Device does not support Multicast Objects; skipping.";
      }
    
      constexpr int64_t kNumElems = 524288; // 2 MB / sizeof(int32_t)
      constexpr int64_t root = 0;
    
      // cp.async.bulk transfer size is limited by shared memory,
      // so we broadcast a 4 KB slice via TMA.
      constexpr int kTmaBytes = 4096;
      static_assert(kTmaBytes % 16 == 0);
      constexpr int kTmaElems = kTmaBytes / sizeof(int32_t);
    
      at::Tensor local =
          SymmetricTensor::allocate({kNumElems}, at::kInt, communicator_->device());
      local.zero_();
      SymmetricTensor sym(local);
      sym.setupMulticast(root, "tma_mcast");
    
      auto opts = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, local_rank);
    
      // Root: TMA-write source data to MC pointer (NVLS broadcasts it)
      if (rank == root) {
        at::Tensor src = at::arange(kTmaElems, opts);
        launchTmaCopy1D(sym.multicastPtr(), src.data_ptr(), kTmaBytes);
        NVFUSER_CUDA_RT_SAFE_CALL(cudaDeviceSynchronize());
      }
    
      communicator_->barrier();
    
      // All ranks: verify data arrived via normal read of local UC tensor
      at::Tensor readback = sym.localTensor().slice(0, 0, kTmaElems).clone();
      at::Tensor expected = at::arange(kTmaElems, opts);
      EXPECT_TRUE(readback.equal(expected))
          << "Rank " << rank << " did not receive multicast data written by TMA";
    }
    
    #endif // CUDA_VERSION >= 13000

    @greptile-apps
    Contributor

    greptile-apps bot commented Feb 25, 2026

    Greptile Summary

    Added Hopper TMA (cp.async.bulk) copy kernel for P2P and multicast communications. The kernel implements a two-phase transfer (GMEM→SMEM→GMEM) driven by a single thread with mbarrier synchronization, validated across local memory, peer VMM-mapped memory, and NVLS multicast pointers.

    • Kernel implementation in csrc/multidevice/tma_copy.cu uses inline PTX assembly for TMA operations with proper mbarrier synchronization between load and store phases
    • Tests validate three critical memory types: local device memory, inter-device P2P via VMM, and NVLS multicast writes
    • Build system correctly integrates kernel stringification via existing NVFUSER_RUNTIME_FILES pipeline
    • Code follows established patterns from existing runtime kernels (alltoallv.cu, multicast.cu)

    Confidence Score: 5/5

    • This PR is safe to merge with minimal risk
    • The implementation is well-structured with comprehensive tests, proper synchronization primitives, thorough documentation, and follows established project patterns. The TMA kernel uses correct PTX assembly with proper mbarrier synchronization. Tests validate all target memory types with appropriate feature detection and graceful skips. No logic errors, security issues, or architectural concerns identified.
    • No files require special attention

Important Files Changed

CMakeLists.txt: Added test file registration, build dependencies, and runtime file for TMA kernel stringification
csrc/multidevice/tma_copy.cu: Implements single-warp TMA kernel with two-phase GMEM->SMEM->GMEM transfer using mbarrier synchronization
tests/cpp/test_multidevice_tma.cpp: Comprehensive tests validating TMA across local, peer P2P, and NVLS multicast memory types with NVRTC compilation

    Sequence Diagram

    sequenceDiagram
        participant T0 as Thread 0
        participant SMEM as Shared Memory
        participant MBAR as MBarrier
        participant GMEM_SRC as GMEM(src)
        participant GMEM_DST as GMEM(dst)
        
        T0->>MBAR: mbarrier.init(arrival_count=1)
        T0->>T0: fence.mbarrier_init.release.cluster
        T0->>T0: __syncwarp()
        
        Note over T0,GMEM_SRC: TMA Load Phase
        T0->>MBAR: mbarrier.arrive.expect_tx(num_bytes)
        T0->>GMEM_SRC: cp.async.bulk (TMA load)
        GMEM_SRC-->>SMEM: async data transfer
        SMEM->>MBAR: complete_tx notification
        
        T0->>MBAR: mbarrier.try_wait.parity (spin until done)
        MBAR-->>T0: phase flip (load complete)
        
        Note over T0,GMEM_DST: TMA Store Phase
        T0->>SMEM: cp.async.bulk.global.shared
        SMEM-->>GMEM_DST: async data transfer
        T0->>T0: cp.async.bulk.commit_group
        T0->>T0: cp.async.bulk.wait_group.read 0
        
        T0->>MBAR: mbarrier.inval
    

    Last reviewed commit: ae0c760

@greptile-apps greptile-apps bot left a comment

3 files reviewed, no comments


    @samnordmann
    Collaborator Author

    !test

    @wujingyue wujingyue requested a review from naoyam February 25, 2026 19:57
    Collaborator

    @wujingyue wujingyue left a comment


    @naoyam I noticed

    Fuser/runtime/memory.cu

    Lines 86 to 96 in 005f7e3

    // References:
    //
    // TMA:
    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
    // https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/copy_sm90_tma.hpp
    //
    // Tensor map:
    // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
    // 1D TMA load:
    // https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/copy_sm90_tma.hpp#L1400
    for codegen. Would you recommend using those building blocks, or adding new ones, so that it's easier for nvFuser to generate fused comm/gemm kernels in the future?
