
add tma transpose auto scheduler #5982

Open

liqiangxl wants to merge 21 commits into main from llu/transpose_tma_auto2
Conversation

@liqiangxl
Collaborator

@liqiangxl liqiangxl commented Feb 19, 2026

The heuristics are a basic version; current performance numbers are in this doc.

Base automatically changed from llu/transpose_tma_auto to main February 19, 2026 17:09
@github-actions

github-actions bot commented Feb 19, 2026

Review updated until commit 14b113a

Description

  • Implement TMA transpose auto scheduler with comprehensive scheduling logic

  • Add TMA transpose enable option and fallback mechanism to non-TMA scheduler

  • Add new TMA-specific parameters for chunks_per_thread and elements_per_chunk

  • Add comprehensive test suite for TMA transpose functionality

Changes walkthrough

Relevant files

Enhancement

transpose_tma.cpp (csrc/scheduler/transpose_tma.cpp, +234/-4): Implement TMA transpose scheduling algorithm

  • Implement getTransposeHeuristics with TMA-specific tile sizing and chunking logic
  • Implement scheduleTranspose with TMA tiling, shared memory swizzling, and register scheduling
  • Add TMA load/store operations with proper memory type promotion
  • Include comprehensive transform propagation and parallelization strategies

options.cpp (csrc/options.cpp, +2/-1): Add TMA transpose enable option

  • Replace std::sort with std::ranges::sort for modern C++ usage
  • Add tma_transpose option to the enable-options map

transpose.cpp (csrc/scheduler/transpose.cpp, +5/-3): Add TMA enable check with fallback logic

  • Add a conditional check for the TMA transpose enable option before TMA scheduling
  • Maintain fallback to the non-TMA scheduler when TMA is not applicable

options.h (csrc/options.h, +11/-8): Add TMA transpose option and improve enum types

  • Add TmaTranspose to the EnableOption enum
  • Change enum underlying types to std::uint8_t for type safety
  • Optimize the Options copy constructor with lambda initialization

transpose_heuristic.h (csrc/scheduler/transpose_heuristic.h, +19/-0): Extend TransposeParams with TMA-specific parameters

  • Add use_tma_store, chunks_per_thread, and elements_per_chunk parameters
  • Update equality comparison and hashing to include the new TMA parameters
  • Add TMA-specific information to the parameter string representation

Tests

test_transpose.cpp (tests/cpp/test_transpose.cpp, +55/-0): Add comprehensive TMA transpose test suite

  • Add a TmaTransposeTestP test class with parameterized testing
  • Test various data types, transpose dimensions, and inner-dimension sizes
  • Include corner-case handling for small inner dimensions

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Parameter Validation

The heuristic calculations make several assumptions about tensor dimensions and TMA hardware constraints. Reviewers should validate that the tile-size calculation (kTmaSwizzleBytes / max_input_dtype_size) does not produce invalid tile sizes, and that the computed chunks_per_thread stays within the valid hardware range [1, 8].

    tparams->tile_size2 = kTmaSwizzleBytes / max_input_dtype_size;
    // [Tunable] tile1 is the inner most dim of the output tvs
    tparams->tile_size1 =
        (n_input == 1) ? tparams->tile_size2 * 2 : tparams->tile_size2;
    // [Tunable] In 128-bytes swizzled tma load, inner most dim is split into 8
    // chunks each with 16 bytes. Each thread may handle multiple chunks along
    // the inner most dim, range is [1, 8]
    // bdimx = tile_size1 * 8 / chunks_per_thread
    const int64_t target_bdimx = (n_input == 1) ? 256 : 128;
    tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
    tparams->elements_per_chunk = kBytesPerChunk / max_input_dtype_size;
Error Handling Robustness

The TMA scheduling involves complex memory-layout manipulations. Reviewers should ensure proper error handling for cases where TMA operations might fail or be invalid, and verify that the fallback to the non-TMA scheduler works correctly for all edge cases.

    void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {
      FusionGuard fg(fusion);
    
      // Make sure we don't have global memory set on intermediate tensors from
      // fusion segmentation
      scheduler_utils::clearMemorySpace(fusion);
    
      // maybe has_reduction for scheduling should be done on a per output tensor
      // basis.
      NVF_ERROR(
          !ir_utils::hasAnyReductionOps(fusion),
          "This scheduler only handles pointwise ops.");
    
      // Cache inputs
      auto cached_inputs = scheduler_utils::cacheInputs(fusion, true);
    
      // Cache and fork outputs
      auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true);
    
      scheduler_utils::prepareForMemoryTypePromotion(fusion);
    
      // always use TMA load for inputs
      int64_t max_input_dims = 0;
      TensorView* input_ref = nullptr;
      std::vector<TensorView*> tma_load_tvs;
      for (auto [cached_input, input_idx] : cached_inputs) {
        if (auto load_op = dynamic_cast<LoadStoreOp*>(cached_input->definition())) {
          load_op->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile);
          cached_input->setMemoryType(MemoryType::Shared);
          tma_load_tvs.push_back(cached_input);
        }
        // find the input with the most logical dimensions
        if (scheduler_utils::nLogicalDims(cached_input) > max_input_dims) {
          max_input_dims = scheduler_utils::nLogicalDims(cached_input);
          input_ref = cached_input;
        }
      }
      NVF_ERROR(!tma_load_tvs.empty());
    
    Performance Validation

    The PR mentions performance data in an external document, but the actual performance characteristics and any performance regressions should be validated. The current heuristics appear basic and may need tuning for optimal performance across different tensor shapes and data types.

    std::unique_ptr<TransposeParams> getTransposeHeuristics(
        Fusion* fusion,
        SchedulerRuntimeInfo& runtime_info,
        HeuristicDataCache* data_cache) {
      auto tparams = std::make_unique<TransposeParams>();
      tparams->tag = "TMA Transpose heuristics";
      tparams->cparams.index_type = runtime_info.getIndexType();
      tparams->use_tma_load = true;
      tparams->use_tma_store = false;
    
      int64_t max_input_dtype_size = 1;
      int64_t n_input = 0;
      for (auto inp : ir_utils::filterByType<TensorView>(fusion->inputs())) {
        max_input_dtype_size = std::max(
            max_input_dtype_size,
            dataTypeSizeByte(valueOrError(inp->getDataType())));
        n_input++;
      }
      // input layout: [I1, I2] -> [tile1, tile2]
      // output layout: [I2, I1] -> [tile2, tile1]
      // tile2 is the inner most dim of the input tvs, it must equals to tma swizzle
      // bytes.
      tparams->tile_size2 = kTmaSwizzleBytes / max_input_dtype_size;
      // [Tunable] tile1 is the inner most dim of the output tvs
      tparams->tile_size1 =
          (n_input == 1) ? tparams->tile_size2 * 2 : tparams->tile_size2;
      // [Tunable] In 128-bytes swizzled tma load, inner most dim is split into 8
      // chunks each with 16 bytes. Each thread may handle multiple chunks along
      // the inner most dim, range is [1, 8]
      // bdimx = tile_size1 * 8 / chunks_per_thread
      const int64_t target_bdimx = (n_input == 1) ? 256 : 128;
      tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
      tparams->elements_per_chunk = kBytesPerChunk / max_input_dtype_size;
    
      if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
        debug() << "\n===== TMA Transpose Stats ========\n"
                << "inputs: " << ir_utils::toString(fusion->inputs()) << "\n"
                << "outputs: " << ir_utils::toString(fusion->outputs()) << "\n"
                << "tile_size1: " << tparams->tile_size1 << "\n"
                << "tile_size2: " << tparams->tile_size2 << "\n"
                << "chunks_per_thread: " << tparams->chunks_per_thread << "\n"
                << "elements_per_chunk: " << tparams->elements_per_chunk << "\n"
                << "\n";
      }
      return tparams;
    }

    @liqiangxl liqiangxl marked this pull request as ready for review February 19, 2026 17:15
    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl liqiangxl requested a review from rdspring1 February 19, 2026 17:16
    @greptile-apps
    Contributor

    greptile-apps bot commented Feb 19, 2026

    Greptile Summary

    This PR implements a TMA (Tensor Memory Accelerator) based transpose auto-scheduler for NVIDIA GPUs. The implementation adds a new TmaTranspose enable option that gates the TMA transpose scheduling path, falling back to the existing non-TMA scheduler if disabled or if TMA heuristics return null.

    The core implementation includes:

    • Heuristic computation in getTransposeHeuristics() that calculates tile sizes based on data type size and TMA swizzle requirements (128 bytes)
    • A comprehensive scheduling algorithm in scheduleTranspose() that handles TMA load operations, applies XOR swizzling to shared memory, and sets up per-thread access patterns
    • New parameters in TransposeParams: use_tma_store, chunks_per_thread, and elements_per_chunk
    • Additional infrastructure changes including adding std::uint8_t as underlying type for option enums and modernizing code with std::ranges::sort

    Previous review threads have identified several edge case handling issues in the heuristics calculation that should be addressed.

    Confidence Score: 3/5

    • This PR introduces significant new functionality with edge cases flagged in previous reviews
    • The implementation is comprehensive with good test coverage, but previous review threads identified several edge case issues (division by zero risks, potential zero tile sizes) that could cause runtime failures. The core logic appears sound, but the heuristics need validation checks.
    • Pay close attention to csrc/scheduler/transpose_tma.cpp - the heuristics calculation needs validation to handle edge cases where computed values could be zero

    Important Files Changed

    Filename Overview
    csrc/options.cpp Modernized to use std::ranges::sort and added TmaTranspose option
    csrc/options.h Added uint8_t underlying type to enums, cstdint include, TmaTranspose option, and improved Options copy constructor
    csrc/scheduler/transpose.cpp Added EnableOption::TmaTranspose gate before calling TMA heuristics
    csrc/scheduler/transpose_heuristic.h Added use_tma_store, chunks_per_thread, elements_per_chunk fields with proper equality/hashing
    csrc/scheduler/transpose_tma.cpp Complete TMA transpose scheduler implementation with heuristics and scheduling logic - previous review comments indicate several edge case issues
    tests/cpp/test_transpose.cpp Added comprehensive TMA transpose test suite with multiple data types and dimensions

    Flowchart

    %%{init: {'theme': 'neutral'}}%%
    flowchart TD
        A[TransposeScheduler::computeHeuristics] --> B{TmaTranspose enabled?}
        B -->|Yes| C[transpose::tma::getTransposeHeuristics]
        B -->|No| D[transpose::non_tma::getTransposeHeuristics]
        C --> E{tparams != nullptr?}
        E -->|No| D
        E -->|Yes| F[Return TMA params]
        D --> F
        F --> G[TransposeScheduler::schedule]
        G --> H[transpose::tma::scheduleTranspose]
        
        C --> I[Calculate max_input_dtype_size]
        I --> J[Compute tile_size2 = 128 / dtype_size]
        J --> K[Compute tile_size1 based on n_input]
        K --> L[Compute chunks_per_thread = tile_size1 * 8 / target_bdimx]
        L --> M[Compute elements_per_chunk = 16 / dtype_size]
        
        H --> N[Cache inputs/outputs]
        N --> O[Setup TMA load ops]
        O --> P[TMA tiling with tile_size1/tile_size2]
        P --> Q[Apply XOR swizzle to input smem]
        Q --> R[Split by elements_per_chunk]
        R --> S[Split by chunks_per_thread]
        S --> T[Parallelize with TIDx and Unroll]
    

    Last reviewed commit: 14b113a

@greptile-apps greptile-apps bot left a comment

6 files reviewed, 4 comments

@greptile-apps greptile-apps bot left a comment

6 files reviewed, 2 comments

    @greptile-apps
    Contributor

    greptile-apps bot commented Feb 20, 2026

    Additional Comments (2)

    csrc/scheduler/transpose_heuristic.h
    New fields use_tma_store, chunks_per_thread, and elements_per_chunk missing from equality check

      bool sameAs(const HeuristicParams* other_base) const override {
        auto other = dynamic_cast<const TransposeParams*>(other_base);
        if (other == nullptr) {
          return false;
        }
        bool attr_equal = other->cparams == cparams &&
            other->use_tma_load == use_tma_load &&
            other->use_tma_store == use_tma_store &&
            other->chunks_per_thread == chunks_per_thread &&
            other->elements_per_chunk == elements_per_chunk &&
            other->split_before_tiling == split_before_tiling &&
            other->dims_merged_with_1 == dims_merged_with_1 &&
            other->dims_merged_with_2 == dims_merged_with_2 &&
            other->vectorize_factor1 == vectorize_factor1 &&
            other->vectorize_factor2 == vectorize_factor2 &&
            other->tile_size1 == tile_size1 && other->tile_size2 == tile_size2;
        return attr_equal;
      }
    

    csrc/scheduler/transpose_heuristic.h
    New fields use_tma_store, chunks_per_thread, and elements_per_chunk missing from hash calculation

      size_t hash() const override {
        return c10::get_hash(
            use_tma_load,
            use_tma_store,
            chunks_per_thread,
            elements_per_chunk,
            split_before_tiling,
            dims_merged_with_1,
            dims_merged_with_2,
            vectorize_factor1,
            vectorize_factor2,
            tile_size1,
            tile_size2);
      }
    

    liqiangxl and others added 7 commits February 20, 2026 07:01
    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
    @liqiangxl
    Collaborator Author

    !test

@greptile-apps greptile-apps bot left a comment

6 files reviewed, 3 comments

    // the inner most dim, range is [1, 8]
    // bdimx = tile_size1 * 8 / chunks_per_thread
    const int64_t target_bdimx = (n_input == 1) ? 256 : 128;
    tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
Contributor
    potential division by zero if target_bdimx is zero (though unlikely given the constants)

    @greptile-apps
    Contributor

    greptile-apps bot commented Feb 20, 2026

    Additional Comments (2)

    csrc/scheduler/transpose_heuristic.h
    missing new fields in hash function: use_tma_store, chunks_per_thread, elements_per_chunk

      size_t hash() const override {
        return c10::get_hash(
            use_tma_load,
            use_tma_store,
            chunks_per_thread,
            elements_per_chunk,
            split_before_tiling,
            dims_merged_with_1,
            dims_merged_with_2,
            vectorize_factor1,
            vectorize_factor2,
            tile_size1,
            tile_size2);
      }
    

    csrc/scheduler/transpose_heuristic.h
    missing new fields in equality check: use_tma_store, chunks_per_thread, elements_per_chunk

      bool sameAs(const HeuristicParams* other_base) const override {
        auto other = dynamic_cast<const TransposeParams*>(other_base);
        if (other == nullptr) {
          return false;
        }
        bool attr_equal = other->cparams == cparams &&
            other->use_tma_load == use_tma_load &&
            other->use_tma_store == use_tma_store &&
            other->chunks_per_thread == chunks_per_thread &&
            other->elements_per_chunk == elements_per_chunk &&
            other->split_before_tiling == split_before_tiling &&
            other->dims_merged_with_1 == dims_merged_with_1 &&
            other->dims_merged_with_2 == dims_merged_with_2 &&
            other->vectorize_factor1 == vectorize_factor1 &&
            other->vectorize_factor2 == vectorize_factor2 &&
            other->tile_size1 == tile_size1 && other->tile_size2 == tile_size2;
        return attr_equal;
      }
    

@greptile-apps greptile-apps bot left a comment

6 files reviewed, 2 comments

    // the inner most dim, range is [1, 8]
    // bdimx = tile_size1 * 8 / chunks_per_thread
    const int64_t target_bdimx = (n_input == 1) ? 256 : 128;
    tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
Contributor

    if tile_size1 is 0 (which could happen if max_input_dtype_size > 128), chunks_per_thread will be 0, causing issues with split at line 217

Suggested change

NVF_CHECK(tparams->tile_size1 > 0, "tile_size1 must be positive");
tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;

    Comment on lines +134 to +136
if (max_output_dims == 0 && max_input_dims == 0) {
  return;
}
Contributor

    consider adding validation or logging when both max_output_dims and max_input_dims are 0 before early return


    @liqiangxl liqiangxl removed the request for review from rdspring1 February 20, 2026 15:38
    @liqiangxl liqiangxl marked this pull request as draft February 20, 2026 15:38
    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test


    //! These can be set through the `NVFUSER_DUMP` environment variable
    //!
    enum class DebugDumpOption {
    enum class DebugDumpOption : std::uint8_t {
Collaborator Author
    For clangtidy

    std::back_inserter(option_values),
    [](const auto& kv) { return kv.first; });
    std::sort(option_values.begin(), option_values.end());
    std::ranges::sort(option_values);
Collaborator Author
    clangtidy

    options_ = other.options_;
    }
    Options(const Options& other)
    : options_([&other]() {
Collaborator Author
    clangtidy

    @liqiangxl liqiangxl marked this pull request as ready for review February 24, 2026 14:07
    @liqiangxl liqiangxl requested a review from rdspring1 February 24, 2026 14:08
@greptile-apps greptile-apps bot left a comment

6 files reviewed, 4 comments

    Comment on lines +56 to +57
    tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
    tparams->elements_per_chunk = kBytesPerChunk / max_input_dtype_size;
Contributor
    if chunks_per_thread is 0 (can happen when tile_size1 * 8 < target_bdimx, e.g., when data type is large), the split at line 218 will fail

Suggested change

tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
if (tparams->chunks_per_thread == 0) {
  return nullptr; // Fall back to non-TMA scheduler
}
tparams->elements_per_chunk = kBytesPerChunk / max_input_dtype_size;

    // bdimx = tile_size1 * 8 / chunks_per_thread
    const int64_t target_bdimx = (n_input == 1) ? 256 : 128;
    tparams->chunks_per_thread = tparams->tile_size1 * 8 / target_bdimx;
    tparams->elements_per_chunk = kBytesPerChunk / max_input_dtype_size;
Contributor
    if elements_per_chunk is 0 (can happen when max_input_dtype_size > 16, e.g., for complex or very large data types), the split at line 216 will fail

Suggested change

tparams->elements_per_chunk = kBytesPerChunk / max_input_dtype_size;
if (tparams->elements_per_chunk == 0) {
  return nullptr; // Fall back to non-TMA scheduler
}

    Comment on lines +37 to +42
for (auto inp : ir_utils::filterByType<TensorView>(fusion->inputs())) {
  max_input_dtype_size = std::max(
      max_input_dtype_size,
      dataTypeSizeByte(valueOrError(inp->getDataType())));
  n_input++;
}
Contributor
    if no inputs exist (n_input == 0), the heuristics are still computed and returned, potentially leading to incorrect scheduling

Suggested change

for (auto inp : ir_utils::filterByType<TensorView>(fusion->inputs())) {
  max_input_dtype_size = std::max(
      max_input_dtype_size,
      dataTypeSizeByte(valueOrError(inp->getDataType())));
  n_input++;
}
if (n_input == 0) {
  return nullptr; // No inputs to transpose
}

    // [Tunable] In 128-bytes swizzled tma load, inner most dim is split into 8
    // chunks each with 16 bytes. Each thread may handle multiple chunks along
    // the inner most dim, range is [1, 8]
    // bdimx = tile_size1 * 8 / chunks_per_thread
Contributor
    comment formula is backwards - should be chunks_per_thread = tile_size1 * 8 / bdimx not bdimx = tile_size1 * 8 / chunks_per_thread

    Suggested change
    // bdimx = tile_size1 * 8 / chunks_per_thread
    // chunks_per_thread = tile_size1 * 8 / target_bdimx
