redo register sharing PR-3972 #3993

Open · wants to merge 35 commits into base: main
Conversation

@liqiangxl (Collaborator) commented Feb 28, 2025

redo #3972

  1. Add a hardcoded MIN_BLOCKS_PER_SM = 1 to ensure register sharing is not ignored by the compiler (see the sketch after this list).
  2. Revise the number of padded threads for warp specialization with register sharing so that both the loading branch and the computation branch have 128*N threads.
  3. Added checks and revised tests to ensure that the requested register count in the loading branch is lower than the initial count, that the count in the computation branch is higher than the initial count, and that the requested counts are feasible.
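
For illustration, item 1 roughly corresponds to emitting launch bounds with a minimum-blocks hint of 1, so that an occupancy-driven register cap does not silently override setmaxnreg. A minimal CUDA sketch (kernel name and thread count are hypothetical, not nvFuser's generated values):

  // Hypothetical kernel; 640 = 512 compute + 128 load threads, as in the
  // generated code sample below. MIN_BLOCKS_PER_SM = 1 keeps the compiler from
  // lowering the per-thread register budget to hit a higher occupancy target,
  // which would make setmaxnreg-based register sharing a no-op.
  __global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/640, /*MIN_BLOCKS_PER_SM=*/1)
  exampleWarpSpecializedKernel(const float* in, float* out) {
    // ... warp-specialized body, see the generated code sample below ...
  }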

ref for setmaxnreg

Generated code sample:

  // bdimx = 32, bdimy = 20
  if (threadIdx.y >= 16) {
    // last 128 threads are executing this setmaxnreg
    decreaseRegisters<64>();
    // select 1 thread from the last warp to do the TMA load
    if (threadIdx.y == 19) {
      if (Hopper::electSync(4294967295U)) {
        // TMA Load
      }
    }
    return;
  } else {
    // first 512 threads are executing this setmaxnreg
    increaseRegisters<104>();
    // computations
  }
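
The decreaseRegisters / increaseRegisters helpers in the sample correspond to the PTX setmaxnreg instruction referenced above. A hedged sketch of what such helpers could look like (names mirror the sample; the exact nvFuser runtime implementation may differ):

  // setmaxnreg takes a compile-time register count that must be a multiple of
  // 8 in [24, 256], and must be executed uniformly by all threads of a warp
  // group (128 threads), which is why the loading branch is padded to 128*N
  // threads.
  template <uint32_t RegCount>
  __device__ __forceinline__ void decreaseRegisters() {
    asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(RegCount));
  }

  template <uint32_t RegCount>
  __device__ __forceinline__ void increaseRegisters() {
    asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(RegCount));
  }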

@liqiangxl changed the title from "redo register sharing PR-https://github.com/NVIDIA/Fuser/pull/3972" to "redo register sharing PR-3972" on Feb 28, 2025

github-actions bot commented Feb 28, 2025

Review updated until commit df88c18

Description

  • Revised warp specialization logic to support register sharing.

  • Added checks and adjustments for padded threads in warp specialization.

  • Updated tests to validate register sharing functionality.

  • Enhanced ParallelDimensionMap to handle register sharing with warp specialization.


Changes walkthrough 📝

Relevant files (all under Enhancement):

csrc/device_lower/pass/circular_buffer.cpp: Update warp dispatch for register sharing (+20/-7)
  • Updated warp dispatch predicate to handle padded threads.
  • Adjusted logic for warp specialization with register sharing.

csrc/parallel_dimension_map.cpp: Enhance warp specialization with register sharing (+114/-22)
  • Added logic to handle register sharing with warp specialization.
  • Updated mappings for warp specialization to include padding.
  • Added methods to get padded values for warp specialization.

csrc/predicate_compute.cpp: Update predicate compute for register sharing (+59/-22)
  • Updated createElectSyncPredicate to handle register sharing.
  • Adjusted logic for multiple-expression elect sync to support register sharing.

tests/cpp/test_circular_buffering.cpp: Update and add tests for register sharing (+174/-66)
  • Updated tests to skip register sharing on unsupported architectures.
  • Added new tests to validate register sharing functionality.
  • Revised constants and logic to accommodate register sharing.

csrc/parallel_dimension_map.h: Update parallel dimension map for register sharing (+5/-1)
  • Added method to get warp specialization padded value.
  • Updated method to adjust mappings for warp specialization.

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The new predicate logic for warp dispatch might not handle all edge cases correctly, especially when warp_specialization_pad is not 1.

    // Create warp_dispatch_ite, the predicate is either
    // Tid == bdim - 1 or Tid >= bdim - padded
    int64_t warp_specialization_pad =
        GpuLower::current()
            ->parallelDimensionMap()
            .getWarpSpecializationPaddedVal(warp_specialize_on);
    kir::Predicate* predicate_val = nullptr;
    Val* raw =
        GpuLower::current()->parallelDimensionMap().get(warp_specialize_on);
    Val* raw_minus_pad = SimplifyingIrBuilder::subExpr(
        raw, IrBuilder::create<Val>(warp_specialization_pad, DataType::Index));
    if (warp_specialization_pad == 1) {
      predicate_val = IrBuilder::create<kir::Predicate>(IrBuilder::eqExpr(
          NamedScalar::getParallelIndex(warp_specialize_on), raw_minus_pad));
    } else {
      predicate_val = IrBuilder::create<kir::Predicate>(IrBuilder::geExpr(
          NamedScalar::getParallelIndex(warp_specialize_on), raw_minus_pad));
    }
    kir::IfThenElse* warp_dispatch_ite =
        IrBuilder::create<kir::IfThenElse>(predicate_val);
    
    // Set default value
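
    For the bdimx = 32, bdimy = 20 sample at the top of this PR, the pad on TIDy is 128 / 32 = 4, so the non-trivial branch of this predicate becomes threadIdx.y >= 20 - 4, matching the generated code above. A hedged device-side sketch of the two predicate shapes (isLoadBranch, bdimy, and pad are illustrative names, not nvFuser codegen output):

      __device__ bool isLoadBranch(int bdimy, int pad) {
        // pad == 1: plain warp specialization, only the last row of threads loads.
        // pad  > 1: register sharing, the last `pad` rows form a whole warp group
        //           so every thread of the group executes the same setmaxnreg.
        return (pad == 1) ? (threadIdx.y == bdimy - 1) : (threadIdx.y >= bdimy - pad);
      }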
    Performance Concern

    The padding logic for warp specialization could lead to inefficient thread usage if the block dimensions are not multiples of 128.

    NVF_ERROR(
        ws_with_register_sharing_.size() <= 1,
        "Warp specialization with register sharing is only supported on one parallel type.");
    // shortcut for case without register sharing
    if (ws_with_register_sharing_.empty()) {
      for (auto pt : warp_specialized_types_) {
        auto dim_it = dim_map_.find(pt);
        if (dim_it == dim_map_.end()) {
          dim_map_[pt] = IrBuilder::create<Val>(2, DataType::Index);
        } else {
          // Intentionally not using SimplifyingIrBuilder::addExpr here so that
          // we still have access to the pointer to the original IR node.
          // We need the pointer to the original IR node because we want
          // getRawCompute to be callable in an environment without FusionGuard,
          // that is, when the IR container is read-only. In such an environment,
          // we can't create new IR nodes for (x - 1). By using
          // IrBuilder::addExpr, we can always create IR nodes like addExpr(x, 1),
          // and SimplifyingIrBuilder::addExpr in getRawCompute will be able to
          // simplify and find the x when we do addExpr(addExpr(x, 1) - 1).
          dim_map_[pt] = IrBuilder::addExpr(
              dim_it->second, dim_it->second->fusion()->oneVal());
        }
        exact_types_.erase(pt);
      }
      return;
    }
    // For register sharing, require 128 contiguous threads to call the same
    // setreg instruction.
    // Not used: 1, Const: n, Dynamic: -1
    auto getThreadsCountInDim = [&](ParallelType pt) {
      if (!dim_map_.contains(pt)) {
        return (int64_t)1;
      }
      if (dim_map_.at(pt)->isConstScalar()) {
        return dim_map_.at(pt)->value().as<int64_t>();
      }
      // Return -1 for dynamic dimensions, this disables register sharing on
      // dynamic dimensions since we can't guarantee the number of threads is
      // divisible by 128. We may allow this in the future and delegate this
      // check to a point where the launch parameters are known.
      return (int64_t)-1;
    };
    // Warp specialization with register sharing on parallel type pt
    // index = TIDx + TIDy * bdimx + TIDz * bdimx * bdimy
    auto pt = *ws_with_register_sharing_.begin();
    auto dim_it = dim_map_.find(pt);
    int64_t pad_n_threads = 0;
    int64_t after_pad = 0;
    
    // switch is not used to avoid explicitly handling all parallel types
    if (pt == ParallelType::TIDx) {
      // If on TIDx, pad by 128
      pad_n_threads = 128;
      after_pad = getThreadsCountInDim(pt) + pad_n_threads;
      NVF_ERROR(
          after_pad % 128 == 0,
          "Illegal register sharing on TIDx, bdimx = ",
          after_pad);
    } else if (pt == ParallelType::TIDy) {
      // If on TIDy, pad by 128 / bdimx
      int64_t bdimx = getThreadsCountInDim(ParallelType::TIDx);
      pad_n_threads = scheduler_utils::safeDiv(128, bdimx);
      after_pad = getThreadsCountInDim(pt) + pad_n_threads;
      NVF_ERROR(
          (after_pad * bdimx) % 128 == 0,
          "Illegal register sharing on TIDy, bdimx = ",
          bdimx,
          ", bdimy = ",
          after_pad);
    } else if (pt == ParallelType::TIDz) {
      // If on TIDz, pad by 128 / (bdimx * bdimy)
      int64_t bdimx = getThreadsCountInDim(ParallelType::TIDx);
      int64_t bdimy = getThreadsCountInDim(ParallelType::TIDy);
      pad_n_threads = scheduler_utils::safeDiv(128, bdimx * bdimy);
      after_pad = getThreadsCountInDim(pt) + pad_n_threads;
      NVF_ERROR(
          (after_pad * bdimx * bdimy) % 128 == 0,
          "Illegal register sharing on TIDz, bdimx = ",
          bdimx,
          ", bdimy = ",
          bdimy,
          ", bdimz = ",
          after_pad);
    } else {
      NVF_THROW("Unsupported parallel type for register sharing: ", pt);
    }
    
    // Apply the pad
    warp_specialization_padded_vals_[pt] = pad_n_threads;
    auto off_set = IrBuilder::create<Val>(pad_n_threads, DataType::Index);
    auto current_val = dim_it == dim_map_.end()
        ? IrBuilder::create<Val>(1, DataType::Index)
        : dim_it->second;
    dim_map_[pt] = IrBuilder::addExpr(current_val, off_set);
    exact_types_.erase(pt);
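
    To make the padding rule concrete for the CTA shapes used by the new RegisterSharingCtaShapes test below, here is a small standalone sketch of the arithmetic (safeDivAtLeastOne is a hypothetical stand-in for scheduler_utils::safeDiv, assumed to clamp the quotient to at least 1):

      #include <algorithm>
      #include <cstdint>
      #include <cstdio>

      // Returns a / b, clamped to at least 1 (stand-in for safeDiv).
      int64_t safeDivAtLeastOne(int64_t a, int64_t b) {
        return std::max<int64_t>(a / b, 1);
      }

      int main() {
        const int64_t shapes[][3] = {{32, 4, 2}, {128, 2, 1}, {256, 1, 1}};
        for (const auto& s : shapes) {
          const int64_t x = s[0], y = s[1], z = s[2];
          // Pad on TIDx: 128; on TIDy: 128/bdimx; on TIDz: 128/(bdimx*bdimy).
          std::printf(
              "cta(%lld,%lld,%lld): pad_x=128 pad_y=%lld pad_z=%lld\n",
              (long long)x, (long long)y, (long long)z,
              (long long)safeDivAtLeastOne(128, x),
              (long long)safeDivAtLeastOne(128, x * y));
        }
        return 0;
      }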
    Test Coverage

    The new tests for register sharing should be expanded to cover more CTA shapes and parallel types to ensure robustness.

      bool testEnablesRegisterSharing() {
        return std::holds_alternative<WarpSpecialized>(circular_buffer_type) &&
            std::get<WarpSpecialized>(circular_buffer_type)
                .num_registers.has_value();
      }
    
      bool testEnablesRegisterSharingTIDx() {
        return testEnablesRegisterSharing() &&
            std::get<WarpSpecialized>(circular_buffer_type).on ==
            ParallelType::TIDx;
      }
    
      bool testEnablesRegisterSharingTIDy() {
        return testEnablesRegisterSharing() &&
            std::get<WarpSpecialized>(circular_buffer_type).on ==
            ParallelType::TIDy;
      }
    
      // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk
      // the memory range [srcMem, srcMem + size - 1] must not overflow the source
      // memory space. Otherwise, the behavior is undefined.
      bool tma1dSrcAddressOverflow(int64_t bulk_inner_dim) {
        return tensor_inner_dim % bulk_inner_dim != 0 &&
            tma_load_type == LoadStoreOpType::CpAsyncBulk;
      }
    
      template <typename data_type>
      void compare(int64_t tensor_dim, at::Tensor result, at::Tensor reference) {
        at::Tensor reference_cpu_data = reference.cpu();
        at::Tensor result_cpu_data = result.cpu();
    
        auto reference_cpu = reference_cpu_data.accessor<data_type, 1>();
        auto result_cpu = result_cpu_data.accessor<data_type, 1>();
    
        constexpr double abs_tolerance = 1e-3;
        constexpr double rel_tolerance = 1e-3;
        for (int64_t pos = 0; pos < tensor_dim; ++pos) {
          double tolerance =
              abs_tolerance + rel_tolerance * fabs((double)reference_cpu[pos]);
          if (fabs((double)result_cpu[pos] - (double)reference_cpu[pos]) >
              tolerance) {
            std::cout << "[" << pos << "] - result: " << result_cpu[pos]
                      << " | reference: " << reference_cpu[pos] << std::endl;
          }
        }
      }
    
      template <typename data_type>
      void compare(
          int64_t tensor_outer_dim,
          int64_t tensor_inner_dim,
          at::Tensor result,
          at::Tensor reference) {
        at::Tensor reference_cpu_data = reference.cpu();
        at::Tensor result_cpu_data = result.cpu();
    
        auto reference_cpu = reference_cpu_data.accessor<data_type, 2>();
        auto result_cpu = result_cpu_data.accessor<data_type, 2>();
    
        constexpr double abs_tolerance = 1e-3;
        constexpr double rel_tolerance = 1e-3;
        for (int64_t out_pos = 0; out_pos < tensor_outer_dim; ++out_pos) {
          for (int64_t in_pos = 0; in_pos < tensor_inner_dim; ++in_pos) {
            double tolerance = abs_tolerance +
                rel_tolerance * fabs((double)reference_cpu[out_pos][in_pos]);
            if (fabs(
                    (double)reference_cpu[out_pos][in_pos] -
                    (double)result_cpu[out_pos][in_pos]) > tolerance) {
              std::cout << "[" << out_pos << ", " << in_pos
                        << "] - result: " << result_cpu[out_pos][in_pos]
                        << " | ref: " << reference_cpu[out_pos][in_pos]
                        << std::endl;
            }
          }
        }
      }
    };
    
    TEST_F(NVFuserTest, ElectSyncCompatibility) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
    
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      TensorView* input = makeContigTensor(3);
      fusion->addInput(input);
      TensorView* output = set(input);
      fusion->addOutput(output);
    
      TensorView* smem_cache =
          input->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
      smem_cache->setMemoryType(MemoryType::Shared);
    
      // For TMA load, both the shared memory layout and the loop nest and
      // parallelization of TMA are specified by the consumer: smem_cache
    
      // Step 1: define TMA domain
      // Because we want to treat the entire tensor as 1D, we define the TMA
      // domain as [I0*I1*I2]
      smem_cache->merge(0);
      smem_cache->merge(0);
      // Note that the TMA domain only exists in people's minds; there is no
      // need to set anything here.
    
      // Step 2: define box
      smem_cache->split(0, 256);
      // [I0*I1*I2/256, 256]
      // partitioned IterDomain: I0*I1*I2
      // coordinate IterDomain: I0*I1*I2/256
      // box IterDomain: 256
    
      // Step 3: define tile
      // We use dense tile here, so tile == box. Nothing to do here.
    
      // Step 4: schedule the shared memory tensor
      // By default, the allocation domain is the logical domain, which is already
      // in good shape for this case.
    
      constexpr int64_t number_of_stages = 2;
      // Step 5: schedule the consumer tensor
      smem_cache->split(0, 4);
      // [I0*I1*I2/256/4, 4, 256]
      smem_cache->split(0, number_of_stages);
      // [I0*I1*I2/256/4/2, 2, 4, 256]
    
      // [BIDx, 2, TIDx, Bulk]
      smem_cache->axis(0)->parallelize(ParallelType::BIDx);
      smem_cache->axis(2)->parallelize(ParallelType::TIDx);
      smem_cache->axis(3)->parallelize(ParallelType::Bulk);
    
      // Schedule the smem->gmem part
      output->merge(0);
      output->merge(0);
      output->split(0, 256);
      output->split(0, 4);
      output->split(0, number_of_stages);
      output->axis(0)->parallelize(ParallelType::BIDx);
      output->axis(3)->parallelize(ParallelType::TIDx);
    
      inlineAllAt(output, /*pos=*/2);
      smem_cache->circularBuffer(number_of_stages);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      std::vector<int64_t> shape(3, 300);
      auto t0 = at::randn(shape, options);
    
      // IterDomain 2 for the TMA load is parallelized with TIDx, so we generate
      // (threadIdx.x < 4) predicate. This thread predicate is incompatible with
      // circular buffering because we generate an ElectSync predicate that uses
      // a single thread.
      KernelExecutor ke;
      try {
        ke.compile(fusion.get(), {t0});
      } catch (const std::exception& e) {
        const char* reference =
            R"(This thread-parallelized TensorView T2_s_float[ iblockIdx.x15{( ceilDiv(( ceilDiv(( ceilDiv(( ( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ) * ( (( (( getMetaData(T0) )).logical_size ))[2] ) ), 256) ), 4) ), 2) )}, iS16{2}, ithreadIdx.x14{4}, iB12{256} ] ca_pos( 2 ) is incorrectly contained within a If-Then-Else with the ElectSync predicate.)";
        const char* str_match_pointer = strstr(e.what(), reference);
        ASSERT_TRUE(str_match_pointer != nullptr);
      }
    }
    
    TEST_P(TmaCircularBufferingTest, SingleDim) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
    
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      TensorView* tv0 = makeContigTensor(1);
      fusion->addInput(tv0);
    
      TensorView* tv1 = exp(tv0);
      fusion->addOutput(tv1);
    
      TensorView* tv2 = tv0->cacheAfter(tma_load_type);
      tv2->setMemoryType(MemoryType::Shared);
    
      TensorView* reference = tv1;
    
      // Constants
      constexpr size_t bulk_inner_dim = 256;
      if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
        GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
        return;
      }
    
      // [M] -> [M/bid, bid]
      reference->split(-1, bulk_inner_dim);
    
      // Propagate Transformations
      TransformPropagatorWithCheck propagator(reference);
      MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator);
    
      // Set inlineAt before applying circular buffer
      inlineAllAt(tv1, /*pos=*/1);
    
      // Parallelization
      tv2->axis(-1)->parallelize(ParallelType::Bulk);
      tv1->axis(-1)->parallelize(ParallelType::TIDx);
    
      // Circular Buffer with TMA loads
      tv2->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn({tensor_inner_dim}, options);
      at::Tensor t1 = at::exp(t0);
    
      KernelExecutor ke;
      ke.compile(fusion.get(), {t0});
    
      auto cg_outputs = ke.run({t0});
      compare<float>(tensor_inner_dim, cg_outputs[0].as<at::Tensor>(), t1);
      testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__);
    }
    
    TEST_P(TmaCircularBufferingTest, SingleDimUnroll) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
    
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      TensorView* tv0 = makeContigTensor(1);
      fusion->addInput(tv0);
    
      TensorView* tv1 = exp(tv0);
      fusion->addOutput(tv1);
    
      TensorView* tv2 = tv0->cacheAfter(tma_load_type);
      tv2->setMemoryType(MemoryType::Shared);
    
      TensorView* reference = tv1;
    
      // Constants
      constexpr size_t unroll_dim = 4;
      constexpr size_t bulk_inner_dim = 256;
      if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
        GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
        return;
      }
      // [M] -> [M/bid, bid]
      reference->split(-1, bulk_inner_dim);
      // [M/bid, bid] -> [M/bid/unroll, unroll, bid]
      reference->split(0, unroll_dim);
    
      // Propagate Transformations
      TransformPropagatorWithCheck propagator(reference);
      MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator);
    
      // Set ComputeAt position
      inlineAllAt(tv1, /*pos=*/1);
    
      // Apply Unroll
      tv1->axis(1)->parallelize(ParallelType::Unroll);
      tv1->axis(-1)->parallelize(ParallelType::TIDx);
    
      // Circular Buffer with TMA loads
      tv2->axis(-1)->parallelize(ParallelType::Bulk);
      tv2->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn({tensor_inner_dim}, options);
      at::Tensor t1 = at::exp(t0);
    
      KernelExecutor ke;
      ke.compile(fusion.get(), {t0});
    
      int64_t axis_extent =
          ceilDiv(ceilDiv(tensor_inner_dim, bulk_inner_dim), unroll_dim);
      if (axis_extent < number_of_stages) {
        ASSERT_ANY_THROW(ke.run({t0}));
        return;
      }
    
      auto cg_outputs = ke.run({t0});
      compare<float>(tensor_inner_dim, cg_outputs[0].as<at::Tensor>(), t1);
      testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__);
    }
    
    TEST_P(TmaCircularBufferingTest, SingleDimUnswitch) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
    
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      TensorView* tv0 = makeContigTensor(1);
      fusion->addInput(tv0);
    
      TensorView* tv1 = exp(tv0);
      fusion->addOutput(tv1);
    
      TensorView* tv2 = tv0->cacheAfter(tma_load_type);
      tv2->setMemoryType(MemoryType::Shared);
    
      TensorView* reference = tv1;
    
      // Constants
      constexpr size_t unroll_dim = 4;
      constexpr size_t bulk_inner_dim = 256;
      if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
        GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
        return;
      }
      // [M] -> [M/bid, bid]
      reference->split(-1, bulk_inner_dim);
      // [M/bid, bid] -> [M/bid/unroll, unroll, bid]
      reference->split(0, unroll_dim);
    
      // Propagate Transformations
      TransformPropagatorWithCheck propagator(reference);
      MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator);
    
      // Set ComputeAt position
      inlineAllAt(tv1, /*pos=*/1);
    
      // Apply Unswitch
      tv1->axis(1)->parallelize(ParallelType::Unswitch);
      tv1->axis(-1)->parallelize(ParallelType::TIDx);
    
      // Circular Buffer with TMA loads
      tv2->axis(-1)->parallelize(ParallelType::Bulk);
      tv2->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn({tensor_inner_dim}, options);
      at::Tensor t1 = at::exp(t0);
    
      KernelExecutor ke;
      ke.compile(fusion.get(), {t0});
    
      int64_t axis_extent =
          ceilDiv(ceilDiv(tensor_inner_dim, bulk_inner_dim), unroll_dim);
      if (axis_extent < number_of_stages) {
        ASSERT_ANY_THROW(ke.run({t0}));
        return;
      }
    
      auto cg_outputs = ke.run({t0});
      compare<float>(tensor_inner_dim, cg_outputs[0].as<at::Tensor>(), t1);
      testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__);
    }
    
    TEST_P(TmaCircularBufferingTest, MultiDim) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
      if (tma_load_type == LoadStoreOpType::CpAsyncBulk) {
        GTEST_SKIP() << "LoadStoreOpType::CpAsyncBulk only supports 1D TMA";
        return;
      }
    
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      TensorView* tv0 = makeContigTensor(2);
      fusion->addInput(tv0);
    
      TensorView* tv1 = exp(tv0);
      fusion->addOutput(tv1);
    
      TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
      tv2->setMemoryType(MemoryType::Shared);
    
      TensorView* reference = tv1;
    
      // Constants
      constexpr int64_t tma_outer_dim = 4;
      constexpr int64_t tma_inner_dim = 256;
    
      // [M, N] -> [M, N/bid, bid]
      reference->split(-1, tma_inner_dim);
      // [M, N/bid, bid] -> [M/bod, bod, N/bid, bid]
      reference->split(0, tma_outer_dim);
      // [M/bod, bod, N/bid, bid] -> [M/bod, N/bid, bod, bid]
      reference->reorder({{-2, -3}});
    
      // Propagate TMA transform
      TransformPropagatorWithCheck propagator(reference);
      MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator);
    
      // Apply inlineAt for TMA cache
      inlineAllAt(tv1, /*pos=*/2);
    
      // Merge TMA tile and Parallelize
      // [M/bod, N/bid, bod, bid] -> [M/bod, N/bid, bod * bid]
      reference->merge(-2, -1);
      // [M/bod, N/bid, bod * bid] -> [M/bod, N/bid, (bod * bid) / 256, 256]
      reference->split(-1, 256);
    
      // Parallelize
      reference->axis(0)->parallelize(ParallelType::BIDx);
      reference->axis(-1)->parallelize(ParallelType::TIDx);
    
      // Circular Buffer with TMA loads
      tv2->axis(0)->parallelize(ParallelType::BIDx);
      tv2->axis(-1)->parallelize(ParallelType::Bulk);
      tv2->axis(-2)->parallelize(ParallelType::Bulk);
      tv2->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::ones({tensor_outer_dim, tensor_inner_dim}, options);
      at::Tensor t1 = at::exp(t0);
    
      KernelExecutor ke;
      ke.compile(fusion.get(), {t0});
    
      auto cg_outputs = ke.run({t0});
      compare<float>(
          tensor_outer_dim, tensor_inner_dim, cg_outputs[0].as<at::Tensor>(), t1);
      testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__);
    }
    
    TEST_P(TmaCircularBufferingTest, Pointwise) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      TensorView* tv0 = makeContigTensor(2);
      TensorView* tv1 = makeContigTensor(2);
      fusion->addInput(tv0);
      fusion->addInput(tv1);
    
      TensorView* tv2 = add(tv0, tv1);
      fusion->addOutput(tv2);
    
      // Use TMA to load TV0 into shared memory
      TensorView* tv3 = tv0->cacheAfter(tma_load_type);
      tv3->setMemoryType(MemoryType::Shared);
    
      // __syncthreads() is missing if mixing TMA and non-TMA loading with circular buffering
      TensorView* tv4 = tv1->cacheAfter(tma_load_type);
      tv4->setMemoryType(MemoryType::Shared);
    
      TensorView* reference = tv2;
    
      // Constants
      constexpr int64_t bulk_inner_dim = 256;
      if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
        GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
        return;
      }
      // [M, N] -> [M, N/bid, bid]
      reference->split(-1, bulk_inner_dim);
    
      TransformPropagatorWithCheck propagator(reference);
      MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator);
    
      // Set computeAt position
      inlineAllAt(tv2, /*pos=*/2);
    
      // Circular Buffer with TMA loads
      tv3->axis(0)->parallelize(ParallelType::BIDx);
      tv3->axis(2)->parallelize(ParallelType::Bulk);
      tv3->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
    
      tv4->axis(0)->parallelize(ParallelType::BIDx);
      tv4->axis(2)->parallelize(ParallelType::Bulk);
      tv4->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
    
      // Split reference to parallelize TMA tile
      reference->split(-1, bulk_inner_dim);
      reference->axis(0)->parallelize(ParallelType::BIDx);
      reference->axis(-1)->parallelize(ParallelType::TIDx);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
      at::Tensor t1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
      at::Tensor t2 = t0 + t1;
    
      KernelExecutor ke;
      ke.compile(fusion.get(), {t0, t1});
    
      auto cg_outputs = ke.run({t0, t1});
      compare<float>(
          tensor_outer_dim, tensor_inner_dim, cg_outputs[0].as<at::Tensor>(), t2);
      testValidate(fusion.get(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
    }
    
    TEST_P(TmaCircularBufferingTest, PointwiseCpAsync) {
      GTEST_SKIP()
          << "Needs shared memory predicate, but current needsSharedMemoryPredicate() returns false";
    
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      TensorView* tv0 = makeContigTensor(2);
      TensorView* tv1 = makeContigTensor(2);
      fusion->addInput(tv0);
      fusion->addInput(tv1);
    
      TensorView* tv2 = add(tv0, tv1);
      fusion->addOutput(tv2);
    
      // Use TMA to load TV0 into shared memory
      TensorView* tv3 = tv0->cacheAfter(tma_load_type);
      tv3->setMemoryType(MemoryType::Shared);
    
      // Load TV1 into shared memory
      TensorView* tv4 = tv1->cacheAfter(LoadStoreOpType::CpAsync);
      tv4->setMemoryType(MemoryType::Shared);
    
      TensorView* reference = tv2;
    
      // Constants
      constexpr int64_t bulk_inner_dim = 256;
      if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
        GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
        return;
      }
      // [M, N] -> [M, N/bid, bid]
      reference->split(-1, bulk_inner_dim);
    
      TransformPropagatorWithCheck propagator(reference);
      MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator);
    
      // Set computeAt position
      inlineAllAt(tv2, /*pos=*/2);
    
      // Circular Buffer with TMA loads
      tv3->axis(0)->parallelize(ParallelType::BIDx);
      tv3->axis(2)->parallelize(ParallelType::Bulk);
      tv3->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
    
      // Circular Buffer with set operation
      tv4->axis(0)->parallelize(ParallelType::BIDx);
      // TODO Disable circular buffering for CpAsync
      // Circular buffering handles cpAsync sync logic separate from cloner logic.
      // tv4->circularBuffer(number_of_stages, prefetch_distance);
    
      // Split reference to parallelize TMA tile
      reference->split(-1, bulk_inner_dim);
      reference->axis(0)->parallelize(ParallelType::BIDx);
      reference->axis(-1)->parallelize(ParallelType::TIDx);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
      at::Tensor t1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
      at::Tensor t2 = t0 + t1;
    
      KernelExecutor ke;
      ke.compile(fusion.get(), {t0, t1});
    
      auto cg_outputs = ke.run({t0, t1});
      compare<float>(
          tensor_outer_dim, tensor_inner_dim, cg_outputs[0].as<at::Tensor>(), t2);
      testValidate(fusion.get(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__);
    }
    
    TEST_P(TmaCircularBufferingTest, InnerReduction) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      TensorView* tv0 = makeContigTensor(2);
      fusion->addInput(tv0);
    
      TensorView* tv1 = sum(tv0, {-1});
      fusion->addOutput(tv1);
    
      TensorView* tv2 = tv0->cacheAfter(tma_load_type);
      tv2->setMemoryType(MemoryType::Shared);
    
      TensorView* reference = tv1;
    
      constexpr int64_t examples_per_cta = 4;
      constexpr int64_t bulk_inner_dim = 256;
    
      if (tma1dSrcAddressOverflow(bulk_inner_dim)) {
        GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
        return;
      }
    
      // [M, N] -> [M/epc, epc, N]
      reference->split(0, examples_per_cta);
      // [M/epc, epc, N] -> [M/epc, epc, N/bid, bid]
      reference->split(-1, bulk_inner_dim);
    
      TransformPropagatorWithCheck propagator(reference);
      MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator);
    
      // [M/epc, epc, N/bid, bid] -> [M/epc, epc, N]
      reference->merge(-2, -1);
      // [M/epc, epc, N] -> [M/epc, epc, N/tdx, tdx]
      constexpr int64_t tdx = 256;
      reference->split(-1, tdx);
    
      // Parallelize
      reference->axis(0)->parallelize(ParallelType::BIDx);
    
      // Use block reduce.
      reference->axis(-1)->parallelize(ParallelType::TIDx);
    
      // InlineMost automatically handles vectorize and tma dimensions
      inlineMost();
    
      // Circular Buffer with TMA loads
      tv2->axis(0)->parallelize(ParallelType::BIDx);
      tv2->axis(-1)->parallelize(ParallelType::Bulk);
      tv2->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
      at::Tensor t1 = sum(t0, {-1});
    
      KernelExecutor ke;
      ke.compile(fusion.get(), {t0});
    
      auto cg_outputs = ke.run({t0});
      compare<float>(tensor_outer_dim, cg_outputs[0].as<at::Tensor>(), t1);
      testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__);
    }
    
    TEST_P(TmaCircularBufferingTest, OuterReduction) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
    
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      TensorView* tv0 = makeContigTensor(2);
      fusion->addInput(tv0);
    
      TensorView* tv1 = sum(tv0, {0});
      fusion->addOutput(tv1);
    
      TensorView* tv2 = tv0->cacheAfter(tma_load_type);
      tv2->setMemoryType(MemoryType::Shared);
    
      TensorView* reference = tv1;
    
      constexpr int64_t tile_size = 256;
      if (tma1dSrcAddressOverflow(tile_size)) {
        GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
        return;
      }
    
      // [M, N] -> [M, N/bid, bid]
      reference->split(1, tile_size);
      // [M, N/bid, bid] -> [N/bid, M, bid]
      reference->reorder({{1, 0}});
    
      TransformPropagatorWithCheck propagator(reference);
      MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator);
    
      // Parallelize
      reference->axis(0)->parallelize(ParallelType::BIDx);
      reference->axis(2)->parallelize(ParallelType::TIDx);
      tv2->axis(0)->parallelize(ParallelType::BIDx);
      tv2->axis(2)->parallelize(ParallelType::Bulk);
    
      inlineMost();
    
      // Circular Buffer with TMA loads
      tv2->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
      at::Tensor t1 = sum(t0, {0});
    
      KernelExecutor ke;
      ke.compile(fusion.get(), {t0});
    
      auto cg_outputs = ke.run({t0});
      compare<float>(tensor_inner_dim, cg_outputs[0].as<at::Tensor>(), t1);
      // Please note that serial reduction has a larger error than parallel
      // reduction. This is the nature of the algorithm, not a bug in the implementation.
      EXPECT_EQ(at::allclose(cg_outputs[0].as<at::Tensor>(), t1, 1e-3, 1e-3), true);
    }
    
    TEST_P(TmaCircularBufferingTest, Persistent) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
      if (testEnablesRegisterSharing()) {
        GTEST_SKIP() << "Bdimx is dynamic, register Sharing is disabled";
        return;
      }
    
      constexpr at::ScalarType dtype = at::ScalarType::Float;
      constexpr int64_t correction = 0;
      constexpr int64_t reduction_axis = 1;
      constexpr bool keepdim = true;
    
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      TensorView* x = makeContigTensor(2, aten_to_data_type(dtype));
      fusion->addInput(x);
    
      // Algorithm:
      // x_norm = (x - x_mean) / sqrt(x_var)
      Val* num_elem = x->getLoopDomain().at(reduction_axis)->extent();
    
      TensorView* sum_x = sum(x, {reduction_axis}, /*keepdim=*/false);
      TensorView* mean_x = div(sum_x, num_elem);
      TensorView* bcast_mean = broadcast(mean_x, {false, true});
    
      TensorView* x_mean_sub = sub(x, bcast_mean);
      TensorView* x_mean_sub_sq = mul(x_mean_sub, x_mean_sub);
      TensorView* sum_x_mean_sub_sq =
          sum(x_mean_sub_sq, {reduction_axis}, /*keepdim=*/false);
      TensorView* var_x = div(sum_x_mean_sub_sq, num_elem);
      TensorView* bcast_var = broadcast(var_x, {false, true});
    
      TensorView* x_norm = div(sub(x, bcast_mean), sqrt(bcast_var));
      fusion->addOutput(x_norm);
    
      // Load input from global to shared memory
      TensorView* x_cache_smem = x->cacheAfter(tma_load_type);
      x_cache_smem->setMemoryType(MemoryType::Shared);
    
      // Load input from shared memory to registers
      x_cache_smem->cacheAfter();
    
      // Store results in registers
      x_norm->cacheBefore();
    
      std::vector<TensorView*> reduction_tvs =
          scheduler_utils::getReductionTvs(fusion.get());
    
      TensorView* reference_tv = x_norm;
    
      // boxDim array must be non-zero and less than or equal to 256
      constexpr int64_t width = 256;
      constexpr int64_t vectorize = 4;
      int64_t elem_per_compute_thread = tensor_inner_dim / width / vectorize;
      constexpr int64_t examples_per_cta = 4;
      constexpr int64_t tile_size = 256;
      if (tma1dSrcAddressOverflow(tile_size)) {
        GTEST_SKIP() << "cp.async.bulk doesn't allow src address overflow!";
        return;
      }
      // Since multi-dim CpAsyncBulk has a size limit of 256 per dimension,
      // we require multiple TMA operations to load the entire example in shared
      // memory for pointwise kernel.
      //
      // Define TMA Box
      // logical domain: [I1, I2]
      x_cache_smem->split(0, examples_per_cta);
      // split: [I0 / 4, 4, I2]
      x_cache_smem->split(-1, tile_size);
      // split: [I0/4, 4, I2/256, 256]
    
      // Schedule reference_tv
      //   logical domain: [I1, I2]
      //         split: [I1, I2/V (width / tdx), V]
      reference_tv->split(-1, vectorize);
      //         split: [I1, EPCT, I2/V/EPCT (tdx), V]
      reference_tv->split(-2, elem_per_compute_thread, /*inner_split=*/false);
      //         split: [I1, EPCT, I2/V/EPCT (tdx), U, V]
      reference_tv->split(-2, 1);
      //         reorder: [I1, I2/V/EPCT (tdx), EPCT, U, V]
      reference_tv->reorder({{-4, -3}, {-3, -4}});
      //         reorder: [I1/EPC, EPC, I2/V/EPCT (tdx), EPCT, U, V]
      reference_tv->split(0, examples_per_cta);
    
      TransformPropagator propagator(reference_tv);
      std::vector<TensorView*> all_tvs_except_cache =
          ir_utils::allTvsExcept(fusion.get(), {x_cache_smem});
      SetSelector selector(
          {all_tvs_except_cache.begin(), all_tvs_except_cache.end()});
      MaxLogicalDomainInfoSpanningTree(reference_tv, &selector)
          .traverse(&propagator);
    
      std::vector<TensorView*> rfactor_tvs;
      rfactor_tvs.reserve(reduction_tvs.size());
      std::transform(
          reduction_tvs.begin(),
          reduction_tvs.end(),
          std::back_inserter(rfactor_tvs),
          [](TensorView* tv) { return tv->rFactor({-3, -2, -1}); });
    
      // Define Parallelization Schema
      reference_tv->axis(0)->parallelize(ParallelType::BIDx);
      reference_tv->axis(2)->parallelize(ParallelType::TIDx);
      reference_tv->axis(-2)->parallelize(ParallelType::Unroll);
      scheduler_utils::parallelizeAllLike(reference_tv);
    
      // Vectorize Cache
      reference_tv->axis(-1)->parallelize(ParallelType::Vectorize);
    
      // InlineMost automatically handles vectorize and tma dimensions
      inlineMost();
    
      // Handle TMA Tensor
      // Apply circular buffer after computeAt
      x_cache_smem->axis(-1)->parallelize(ParallelType::Bulk);
      if (examples_per_cta > 1) {
        x_cache_smem->circularBuffer(
            number_of_stages, prefetch_distance, circular_buffer_type);
      }
    
      auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
      at::Tensor at_tv0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
      at::Tensor at_tv1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
    
      // Compile with KernelExecutor directly to avoid scheduling
      KernelExecutor ke;
      ke.compile(fusion.get(), {at_tv0});
      auto cg_outputs = ke.run({at_tv0});
    
      std::tuple<at::Tensor, at::Tensor> at_var_mean =
          at::var_mean(at_tv0, {-1}, correction, keepdim);
      at::Tensor at_var = std::get<0>(at_var_mean);
      at::Tensor at_mean = std::get<1>(at_var_mean);
      at::Tensor at_output = (at_tv0 - at_mean) / sqrt(at_var);
    
      testValidate(
          fusion.get(), cg_outputs, {at_tv0}, {at_output}, __LINE__, __FILE__);
    }
    
    TEST_P(TmaCircularBufferingTest, Matmul) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
    
      if (testEnablesRegisterSharingTIDx()) {
        GTEST_SKIP()
            << "Register Sharing with TIDx used for both computation and load, requires TIDx to be a multiple of 128.";
        return;
      }
    
      // There are 512 compute threads and 128 load threads
      // register at entry: 96
      // register for compute: 96 + 32 / 4 = 104
      // register for loading: 96 - 32 = 64
      if (testEnablesRegisterSharingTIDy()) {
        circular_buffer_type =
            WarpSpecialized(ParallelType::TIDy, std::make_pair(64, 104));
      }
    
      if (tma_load_type == LoadStoreOpType::CpAsyncBulk) {
        GTEST_SKIP() << "LoadStoreOpType::CpAsyncBulk only supports 1D TMA";
        return;
      }
    
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      // Algorithm
      TensorView* tv0 = makeContigTensor(2); // (M, K)
      TensorView* tv1 = makeContigTensor(2); // (K, N)
      fusion->addInput(tv0);
      fusion->addInput(tv1);
    
      TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
      TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
      TensorView* tv4 = mul(tv2, tv3); // M, K, N
      TensorView* tv5 = sum(tv4, {1}); // M, R, N
      fusion->addOutput(tv5);
    
      // CpAsyncBulk Store
      TensorView* tv6 = tv5->cacheBefore(LoadStoreOpType::CpAsyncBulkTensorTile);
      tv6->setMemoryType(MemoryType::Shared);
    
      // For register circular buffering
      TensorView* tv0_cache_local = tv0->cacheAfter();
      TensorView* tv1_cache_local = tv1->cacheAfter();
    
      // For shared memory circular buffering
      TensorView* tv0_cache_smem =
          tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
      TensorView* tv1_cache_smem =
          tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
      tv0_cache_smem->setMemoryType(MemoryType::Shared);
      tv1_cache_smem->setMemoryType(MemoryType::Shared);
    
      constexpr int64_t BSX = 64;
      constexpr int64_t TSX = 32;
      constexpr int64_t TSY = 16;
    
      // Step 0: [M, K, N]
      // Step 1: [M, K, N/BSX, BSX]
      tv6->split(-1, BSX);
    
      // Step 2: [M, K, N/BSX, BSX/TSX, TSX]
      tv6->split(-1, TSX);
    
      // Step 3: [M, K/BSX, BSX, N/BSX, BSX/TSX, TSX]
      tv6->split(1, BSX);
    
      // Step 4: [M/BSX, BSX, K/BSX, BSX, N/BSX, BSX/TSX, TSX]
      tv6->split(0, BSX);
    
      // Step 5:[M/BSX, BSX/TSY, TSY, K/BSX, BSX, N/BSX, BSX/TSX, TSX]
      tv6->split(1, TSY);
    
      // Step 6: [M/BSX, N/BSX, K/BSX, BSX/TSY, BSX/TSX, TSY, TSX, BSX]
      tv6->reorder(
          {{4, 7}, {7, 6}, {6, 4}, {2, 5}, {1, 3}, {3, 2}, {5, 1}, {0, 0}});
    
      // Step 7a: [M/BSX, N/BSX, K/BSX, BSX/TSY, BSX/TSX, TSY, TSX, BSX (reduce)]
      // Step 7b: [M/BSX, N/BSX, K/BSX (reduce), BSX/TSY, BSX/TSX, TSY, TSX]
      TensorView* tv6_rf = tv6->rFactor({-1});
    
      TransformPropagatorWithCheck propagator(tv6_rf);
      MaxLogicalDomainInfoSpanningTree(tv6_rf).traverse(&propagator);
    
      // Parallelize
      tv5->axis(0)->parallelize(ParallelType::BIDx);
      tv5->axis(1)->parallelize(ParallelType::BIDy);
      tv5->axis(-2)->parallelize(ParallelType::TIDy);
      tv5->axis(-1)->parallelize(ParallelType::TIDx);
    
      scheduler_utils::parallelizeAllLike(tv5);
    
      // (BSX/TSX * TSX * BSX) = 1024 floats = 4096 bytes * (number of buffers)
      tv0_cache_smem->axis(-3)->parallelize(ParallelType::Bulk);
      tv0_cache_smem->axis(-2)->parallelize(ParallelType::Bulk);
      tv0_cache_smem->axis(-1)->parallelize(ParallelType::Bulk);
    
      // (BSX/TSY * TSY * BSX) = 1024 floats = 4096 bytes * (number of buffers)
      tv1_cache_smem->axis(-3)->parallelize(ParallelType::Bulk);
      tv1_cache_smem->axis(-2)->parallelize(ParallelType::Bulk);
      tv1_cache_smem->axis(-1)->parallelize(ParallelType::Bulk);
    
      // Apply ParallelType::Bulk to global output tensor.
      tv5->axis(-4)->parallelize(ParallelType::Bulk);
      tv5->axis(-3)->parallelize(ParallelType::Bulk);
      tv5->axis(-2)->parallelize(ParallelType::Bulk);
      tv5->axis(-1)->parallelize(ParallelType::Bulk);
    
      // IterDomain: [M/BSX, N/BSX, K/BSX, BSX/TSY, BSX/TSX, TSY, TSX, BSX]
      // Parallelization: BDX, BDY, K/BSX ||, BSX/TSY, BSX/TSX, TSY, TSX, TDX]
      // 4 non-parallelized for-loops
      inlineMost();
    
      // Apply circular buffering after setting computeAt position
      tv0_cache_local->circularBuffer(number_of_stages, prefetch_distance);
      tv1_cache_local->circularBuffer(number_of_stages, prefetch_distance);
    
      tv0_cache_smem->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
      tv1_cache_smem->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
    
      constexpr int64_t K = 1024;
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn({tensor_outer_dim, K}, options);
      at::Tensor t1 = at::randn({K, tensor_inner_dim}, options);
      at::Tensor aten_output =
          (t0.unsqueeze(/*dim=*/-1) * t1.unsqueeze(/*dim=*/0)).sum(/*dim=*/1);
    
      KernelExecutor ke;
      ke.compile(fusion.get(), {t0, t1});
    
      auto cg_outputs = ke.run({t0, t1});
      compare<float>(
          tensor_outer_dim,
          tensor_inner_dim,
          cg_outputs[0].as<at::Tensor>(),
          aten_output);
      testValidate(
          fusion.get(), cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__);
    }
    
    TEST_P(TmaCircularBufferingTest, MatmulWithBroadcastedInput) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
    
      if (testEnablesRegisterSharingTIDx()) {
        GTEST_SKIP()
            << "Register Sharing with TIDx used for both computation and load, requires TIDx to be a multiple of 128.";
        return;
      }
    
      // There are 512 compute threads and 128 load threads
      // register at entry: 96
      // register for compute: 96 + 32 / 4 = 104
      // register for loading: 96 - 32 = 64
      if (testEnablesRegisterSharingTIDy()) {
        circular_buffer_type =
            WarpSpecialized(ParallelType::TIDy, std::make_pair(64, 104));
      }
    
      if (tma_load_type == LoadStoreOpType::CpAsyncBulk) {
        GTEST_SKIP() << "LoadStoreOpType::CpAsyncBulk only supports 1D TMA";
        return;
      }
    
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      // Algorithm
      TensorView* tv0 = makeContigConcreteTensor({-1, -1, 1}); // (M, K, B)
      TensorView* tv1 = makeContigConcreteTensor({1, -1, -1}); // (B, K, N)
      fusion->addInput(tv0);
      fusion->addInput(tv1);
    
      TensorView* tv2 = mul(tv0, tv1); // M, K, N
      TensorView* tv3 = sum(tv2, {1}); // M, R, N
      fusion->addOutput(tv3);
    
      // CpAsyncBulk Store
      TensorView* tv4 = tv3->cacheBefore(LoadStoreOpType::CpAsyncBulkTensorTile);
      tv4->setMemoryType(MemoryType::Shared);
    
      // For register circular buffering
      TensorView* tv0_cache_local = tv0->cacheAfter();
      TensorView* tv1_cache_local = tv1->cacheAfter();
    
      // For shared memory circular buffering
      TensorView* tv0_cache_smem =
          tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
      TensorView* tv1_cache_smem =
          tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
      tv0_cache_smem->setMemoryType(MemoryType::Shared);
      tv1_cache_smem->setMemoryType(MemoryType::Shared);
    
      constexpr int64_t BSX = 64;
      constexpr int64_t TSX = 32;
      constexpr int64_t TSY = 16;
    
      // Step 0: [M, K, N]
      // Step 1: [M, K, N/BSX, BSX]
      tv4->split(-1, BSX);
    
      // Step 2: [M, K, N/BSX, BSX/TSX, TSX]
      tv4->split(-1, TSX);
    
      // Step 3: [M, K/BSX, BSX, N/BSX, BSX/TSX, TSX]
      tv4->split(1, BSX);
    
      // Step 4: [M/BSX, BSX, K/BSX, BSX, N/BSX, BSX/TSX, TSX]
      tv4->split(0, BSX);
    
      // Step 5:[M/BSX, BSX/TSY, TSY, K/BSX, BSX, N/BSX, BSX/TSX, TSX]
      tv4->split(1, TSY);
    
      // Step 6: [M/BSX, N/BSX, K/BSX, BSX/TSY, BSX/TSX, TSY, TSX, BSX]
      tv4->reorder(
          {{4, 7}, {7, 6}, {6, 4}, {2, 5}, {1, 3}, {3, 2}, {5, 1}, {0, 0}});
    
      // Step 7a: [M/BSX, N/BSX, K/BSX, BSX/TSY, BSX/TSX, TSY, TSX, BSX (reduce)]
      // Step 7b: [M/BSX, N/BSX, K/BSX (reduce), BSX/TSY, BSX/TSX, TSY, TSX]
      TensorView* tv4_rf = tv4->rFactor({-1});
    
      TransformPropagatorWithCheck propagator(tv4_rf);
      MaxLogicalDomainInfoSpanningTree(tv4_rf).traverse(&propagator);
    
      // Parallelize
      tv3->axis(0)->parallelize(ParallelType::BIDx);
      tv3->axis(1)->parallelize(ParallelType::BIDy);
      tv3->axis(-2)->parallelize(ParallelType::TIDy);
      tv3->axis(-1)->parallelize(ParallelType::TIDx);
    
      scheduler_utils::parallelizeAllLike(tv3);
    
      // (BSX/TSX * TSX * BSX) = 1024 floats = 4096 bytes * (number of buffers)
      tv0_cache_smem->axis(-5)->parallelize(ParallelType::Bulk);
      tv0_cache_smem->axis(-3)->parallelize(ParallelType::Bulk);
      tv0_cache_smem->axis(-1)->parallelize(ParallelType::Bulk);
    
      // (BSX/TSY * TSY * BSX) = 1024 floats = 4096 bytes * (number of buffers)
      tv1_cache_smem->axis(-4)->parallelize(ParallelType::Bulk);
      tv1_cache_smem->axis(-2)->parallelize(ParallelType::Bulk);
      tv1_cache_smem->axis(-1)->parallelize(ParallelType::Bulk);
    
      // Apply ParallelType::Bulk to global output tensor.
      tv3->axis(-4)->parallelize(ParallelType::Bulk);
      tv3->axis(-3)->parallelize(ParallelType::Bulk);
      tv3->axis(-2)->parallelize(ParallelType::Bulk);
      tv3->axis(-1)->parallelize(ParallelType::Bulk);
    
      // IterDomain: [M/BSX, N/BSX, K/BSX, BSX/TSY, BSX/TSX, TSY, TSX, BSX]
      // Parallelization: BDX, BDY, K/BSX ||, BSX/TSY, BSX/TSX, TSY, TSX, TDX]
      // 4 non-parallelized for-loops
      inlineMost();
    
      // Apply circular buffering after setting computeAt position
      tv0_cache_local->circularBuffer(number_of_stages, prefetch_distance);
      tv1_cache_local->circularBuffer(number_of_stages, prefetch_distance);
    
      tv0_cache_smem->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
      tv1_cache_smem->circularBuffer(
          number_of_stages, prefetch_distance, circular_buffer_type);
    
      constexpr int64_t K = 1024;
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn({tensor_outer_dim, K, 1}, options);
      at::Tensor t1 = at::randn({1, K, tensor_inner_dim}, options);
      at::Tensor aten_output = (t0 * t1).sum(/*dim=*/1);
    
      KernelExecutor ke;
      ke.compile(fusion.get(), {t0, t1});
    
      auto cg_outputs = ke.run({t0, t1});
      compare<float>(
          tensor_outer_dim,
          tensor_inner_dim,
          cg_outputs[0].as<at::Tensor>(),
          aten_output);
      testValidate(
          fusion.get(), cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__);
    }
    
    auto tmaCircularBufferingParams() {
      // A very simple PRNG:
      // https://en.wikipedia.org/wiki/Lehmer_random_number_generator
      uint32_t lcg_parkmiller = 1;
      const std::vector<CircularBufferType> all_types{
          Pipelined(false),
          Pipelined(true),
          WarpSpecialized(ParallelType::TIDx),
          WarpSpecialized(ParallelType::TIDy),
          WarpSpecialized(ParallelType::TIDx, std::make_pair(64, 168)),
          WarpSpecialized(ParallelType::TIDy, std::make_pair(64, 168))};
      const std::vector<LoadStoreOpType> tma_types{
          LoadStoreOpType::CpAsyncBulk, LoadStoreOpType::CpAsyncBulkTensorTile};
      std::vector<TmaCircularBufferingParams> values;
      for (int64_t i : {2, 4}) {
        for (int64_t j : c10::irange(-i, i)) {
          for (int64_t m : {128, 500, 1024}) {
            for (int64_t n : {1024, 2048}) {
              for (auto tma_load_type : tma_types) {
                values.emplace_back(
                    i,
                    j,
                    m,
                    n,
                    all_types[lcg_parkmiller % all_types.size()],
                    tma_load_type);
                lcg_parkmiller = (uint64_t)lcg_parkmiller * 48271 % 0x7fffffff;
              }
            }
          }
        }
      }
      return testing::ValuesIn(values);
    }
    
    std::string tmaName(
        const testing::TestParamInfo<TmaCircularBufferingParams>& info) {
      auto prefetch_distance = std::get<1>(info.param);
      std::string prefetch_distance_str;
      if (prefetch_distance < 0) {
        prefetch_distance_str = "neg" + std::to_string(-prefetch_distance);
      } else {
        prefetch_distance_str = std::to_string(prefetch_distance);
      }
      std::stringstream ss;
      ss << "stage_" << std::to_string(std::get<0>(info.param)) << "_prefetch_"
         << prefetch_distance_str << "_M_"
         << std::to_string(std::get<2>(info.param)) << "_N_"
         << std::to_string(std::get<3>(info.param)) << "_"
         << std::get<4>(info.param) << "_" << std::get<5>(info.param);
      return ss.str();
    }
    
    INSTANTIATE_TEST_SUITE_P(
        Hopper,
        TmaCircularBufferingTest,
        tmaCircularBufferingParams(),
        tmaName);
    
    using RegisterSharingTestParams = std::tuple<dim3, ParallelType>;
    using TmaRegisterSharingTest =
        NVFuserFixtureParamTest<RegisterSharingTestParams>;
    TEST_P(TmaRegisterSharingTest, RegisterSharingCtaShapes) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
      int64_t gdimx = 2;
      auto [bdim, ws_pt] = GetParam();
      int64_t bdimx = bdim.x, bdimy = bdim.y, bdimz = bdim.z;
      int64_t n_computation_threads = bdimx * bdimy * bdimz;
      std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      auto tv0 = makeContigTensor(2);
      fusion->addInput(tv0);
    
      auto tv1 = set(tv0);
      tv1->setMemoryType(MemoryType::Shared);
      tv1->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::CpAsyncBulk);
      auto tv2 = mul(tv1, tv1);
      fusion->addOutput(tv2);
    
      // [I1, I2] -> [gdimx, I1/gdimx, I2/bdimx/bdimy, bdimy, bdimx]
      tv2->split(0, gdimx, false);
      tv2->split(2, bdimx);
      tv2->split(2, bdimy);
      tv2->axis(-1)->parallelize(ParallelType::TIDx);
      tv2->axis(-2)->parallelize(ParallelType::TIDy);
      tv2->axis(-3)->parallelize(ParallelType::TIDz);
      tv2->axis(0)->parallelize(ParallelType::BIDx);
    
      // [I1, I2] -> [gdimx, I1/gdimx, I2]
      tv1->split(0, gdimx, false);
      tv1->axis(0)->parallelize(ParallelType::BIDx);
      tv1->axis(2)->parallelize(ParallelType::Bulk);
    
      // Set inlineAt before applying circular buffer
      inlineAllAt(tv1, /*pos=*/2);
    
      // Warp specialization with register sharing requires that all threads in
      // the same warp group execute the same register adjustment instruction.
      // So the number of padded threads for the TMA loading branch depends on
      // the CTA shape and the warp specialization dimension.
      // index = TIDx + TIDy * bdimx + TIDz * bdimx * bdimy
      // total = bdimx * bdimy * bdimz
      // Pad on x: bdimx += 128
      // Pad on y: bdimy += 128/bdimx
      // Pad on z: bdimz += 128/(bdimx * bdimy)
      auto get_tma_branch_threads = [&](ParallelType ws_pt) {
        if (ws_pt == ParallelType::TIDx) {
          return (int64_t)128 * bdimy * bdimz;
        } else if (ws_pt == ParallelType::TIDy) {
          return scheduler_utils::safeDiv(128, bdimx) * bdimx * bdimz;
        } else if (ws_pt == ParallelType::TIDz) {
          return scheduler_utils::safeDiv(128, bdimx * bdimy) * bdimx * bdimy;
        } else {
          NVF_THROW("TMA register sharing only supports TIDx, TIDy, and TIDz");
        }
      };
      // Adjust register usage: assuming each computation thread increases its
      // register usage by 8, each TMA branch thread should reduce its usage by
      // 8 * n_computation_threads / n_tma_branch_threads.
      int64_t n_tma_branch_threads = get_tma_branch_threads(ws_pt);
      int64_t n_total_threads = n_computation_threads + n_tma_branch_threads;
      int64_t initial_reg_count = getRegPerThreadGivenThreadsPerSM(n_total_threads);
      EXPECT_TRUE(initial_reg_count % 8 == 0 || initial_reg_count == 255);
      int64_t compute_reg_count = initial_reg_count + 8;
      int64_t tma_reg_count =
          initial_reg_count - (n_computation_threads / n_tma_branch_threads) * 8;
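      // Worked example (hypothetical register numbers): cta = (128, 2, 1)
      // specialized on TIDy gives n_tma_branch_threads = 128 and
      // n_total_threads = 384; if getRegPerThreadGivenThreadsPerSM(384)
      // returned 168, then compute_reg_count = 176 and
      // tma_reg_count = 168 - (256 / 128) * 8 = 152.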
      CircularBufferType circular_buffer_type =
          WarpSpecialized(ws_pt, std::make_pair(tma_reg_count, compute_reg_count));
      int64_t n_stages = 2;
      tv1->circularBuffer(n_stages, 1, circular_buffer_type);
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn({n_stages * gdimx, n_computation_threads}, options);
      at::Tensor t1 = t0 * t0;
      KernelExecutor ke;
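      // Compilation is expected to fail with "Illegal register sharing on
      // TIDx" only for the cases matched by the condition in the catch block
      // below (in this parameterization, warp specialization on TIDx with
      // bdimx != 128); every other combination should compile and run.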
      try {
        ke.compile(fusion.get(), {t0});
      } catch (const std::exception& e) {
        const char* reference = R"(Illegal register sharing on TIDx)";
        if ((bdimx % 128 || 128 % bdimx) && ws_pt == ParallelType::TIDx) {
          const char* str_match_pointer = strstr(e.what(), reference);
          ASSERT_TRUE(str_match_pointer != nullptr);
          return;
        }
      }
      auto cg_outputs = ke.run({t0});
      auto lparams = ke.lastLaunchParams();
      EXPECT_EQ(lparams.nThreads(), n_total_threads);
      testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__);
    }
    INSTANTIATE_TEST_SUITE_P(
        Hopper,
        TmaRegisterSharingTest,
        ::testing::Combine(
            ::testing::Values(dim3(32, 4, 2), dim3(128, 2, 1), dim3(256, 1, 1)),
            ::testing::Values(
                ParallelType::TIDx,
                ParallelType::TIDy,
                ParallelType::TIDz)),
        [](const testing::TestParamInfo<RegisterSharingTestParams>& info) {
          std::stringstream ss;
          ss << "cta_" << std::get<0>(info.param).x;
          ss << "_" << std::get<0>(info.param).y;
          ss << "_" << std::get<0>(info.param).z;
          ss << "_pt_" << std::get<1>(info.param);
          return sanitizeTestName(ss.str());
        });
    
    } // namespace nvfuser

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test

    @rdspring1 rdspring1 requested review from zasdfgbnm and rdspring1 March 7, 2025 17:49
    Collaborator

    @rdspring1 rdspring1 left a comment


    Revise number of padded threads for warp specialization with register sharing to ensure both loading branch and computation branch has 128*N threads.

    The test coverage only supports a direct multiple of 128 * N threads. Probably should have had an assertion for this. This PR should expand the test coverage to handle that support.

    Is there a direct use for supporting all combinations of CTA shapes that are a multiple of 128?

    warp_dispatch_ite->thenBody().push_back(load_loop);

    // Nest load loop inside the warp dispatch if-then-else
    if (warp_specilization_pad > 1) {
    Collaborator


    IIUC, this nested IfThenElse should be merged with the ElectSync predicate logic at https://github.com/NVIDIA/Fuser/blob/main/csrc/predicate_compute.cpp#L652-L663.

    // select 1 thread from the last warp to do TMA load
        if (Hopper::electSync(4294967295U) && threadIdx.y == 19) {
    

    Collaborator Author

    @liqiangxl liqiangxl Mar 10, 2025


    We can do that by adding the following two changes:
    (1) We don't need to nest the load loop inside the warp dispatch if-then-else; basically, remove the changes at

    if (warp_specialization_pad > 1) {

    (2) Revise createMultipleExpressionElectSync to add an extra predicate.

      for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) {
        if (!pdim_map.has(pt)) {
          continue;
        }
        if (load_warp_on != pt) {
          conditional = SimplifyingIrBuilder::logicalAndExpr(
              conditional,
              IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero));
        } else {
          Val* raw =
              GpuLower::current()->parallelDimensionMap().get(load_warp_on);
          conditional = SimplifyingIrBuilder::logicalAndExpr(
              conditional,
              IrBuilder::eqExpr(
                  NamedScalar::getParallelIndex(pt),
                  IrBuilder::subExpr(
                      raw, IrBuilder::create<Val>(1, DataType::Index))));
        }
      }
    

    Then, the generated code is changed from:

        if (threadIdx.y == 19) {
            Grid-Stride For-loop{
                if (Hopper::electSync(4294967295U)) {
                    // TMA Load
                }
          }
        }
    

    to

            Grid-Stride For-loop{
                if (Hopper::electSync(4294967295U) && threadIdx.y == 19) {
                    // TMA Load
                }
            }
    

    What's the benefit of moving threadIdx.y == 19 to the inside of the ForLoop? Warp divergence is not an issue, since bdimx = 32/42/128. Is it due to better loop handling, or to keep consistent with other electSync usage? For example, we have

      bool b18;
      b18 = ((nvfuser_index_t)threadIdx.x) < 32ULL;
      bool b19;
      b19 = ((nvfuser_index_t)threadIdx.y) == 0ULL;
      #pragma unroll
      for(nvfuser_index_t i22 = 0; i22 < 2; ++i22) {
        if (((Hopper::electSync(4294967295U) && b18) && b19)) {
          mbarrier::init(toSmem((&T12[i22])), 2U);
        }
      }
    

    instead of

    if(b19){
      #pragma unroll
      for(nvfuser_index_t i22 = 0; i22 < 2; ++i22) {
        if (((Hopper::electSync(4294967295U) && b18))) {
          mbarrier::init(toSmem((&T12[i22])), 2U);
        }
      }
    }
    
    

    Collaborator


    What's the benefit of moving threadIdx.y == 19 to the inside of the ForLoop?

    It isn't a CUDA kernel benefit, but a NvFuser lowering refactor.

    IfThenElse nodes are inserted in the UnrollPass without an actual predicate. IfThenElse nodes are also added in the CircularBufferPass because of warp specialization and to handle mbarriers and TMA operations; i.e., these IfThenElse nodes do not guard against OOB memory access, but control how the CTA executes these instructions. The predicate itself is then generated during the generateConditionalFromPredicate pass.

    csrc/codegen.cpp Outdated
    kernel_->hasManaged("increased_register_count"),
    "Decreased and increased register count must be set for register sharing warp specialization");

    int64_t decreased_reg_count =
    Collaborator


    Can this check occur when compiling the fusion like a vectorization check? We can have multiple circular buffered loops. You can look up the register count through the kernel summary.

          int64_t prefetch = kernel_->summary()
                                 .circular_buffer_info
                                 .getCircularBufferOptionsFor(loop->iter_domain())
                                 .prefetch;
    

    Collaborator Author


    Moved from fusion managed data to the kernel summary.

    liqiangxl and others added 3 commits March 10, 2025 11:19
    Co-authored-by: Ryan Spring <rspring@nvidia.com>
    Co-authored-by: Ryan Spring <rspring@nvidia.com>
    Co-authored-by: Ryan Spring <rspring@nvidia.com>
    @liqiangxl
    Collaborator Author

    Revise number of padded threads for warp specialization with register sharing to ensure both loading branch and computation branch has 128*N threads.

    The test coverage only supports a direct multiple of 128 * N threads. Probably should have had an assertion for this. This PR should expand the test coverage to handle that support.

    Is there a direct use for supporting all combinations of CTA shapes that are a multiple of 128?

    Currently, only matmul uses warp specialization with register sharing; the test case TmaCircularBufferingTest, Matmul uses 32 x 16 computation threads. Without the extension, it fails with either 33 x 16 or 32 x 17. For persistent kernels, we may want more flexibility in the block shape, e.g. when the computation uses bdimx = 256, bdimy = 1, we can pad to bdimx = 256 + 128, bdimy = 1 instead of bdimx = 256, bdimy = 1 + 1.
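
    For a rough sense of that trade-off (my arithmetic, assuming a 128-thread warp group is reserved for the loading branch):

      // computation: bdimx = 256, bdimy = 1  -> 256 compute threads
      // pad on x: (256 + 128) * 1 = 384 total threads, 128 loading threads
      // pad on y:  256 * (1 + 1) = 512 total threads, 256 loading threads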

    @liqiangxl liqiangxl marked this pull request as ready for review March 10, 2025 18:50
    @rdspring1
    Collaborator

    My main objection is that I don't believe this PR has enough test coverage to enable all the CTA shape combinations.

    Test ideas:

    1. bdimx = 256 + 128, bdimy = 1
    2. A subset of 3D CTA shapes and all warp specialization parallel types.

    For example:
    What happens if (bdimx = 16, bdimy = 4, bdimz = 8)?

    • If WarpSpecializedOn(ParallelType::TIDx), then pad factor is 4.
    • If WarpSpecializedOn(ParallelType::TIDy), then pad factor is 1.
    • If WarpSpecializedOn(ParallelType::TIDz), then pad factor is 2.

    @rdspring1
    Collaborator

    rdspring1 commented Mar 11, 2025

    If this is easier, you can break out feature 1 (hard-code MIN_BLOCKS_PER_SM = 1) and feature 3 (register count checks) into a separate PR and merge that quickly. Then feature 2 (padding the warp specialization branch) can be in another PR.

    liqiangxl added a commit that referenced this pull request Mar 12, 2025
    …4059)
    
    Separated from #3993
    1. add launch bound
    2. disable tests with illegal params to avoid undefined behavior (hangs
    for some tests)
    
    ---------
    
    Co-authored-by: Ryan Spring <rspring@nvidia.com>
    liqiangxl added a commit that referenced this pull request Mar 14, 2025
    …ng (#4061)
    
    Separated from #3993
    1. Ensure the requested register count in the loading branch is lower
    than the initial count and in the computing branch is higher than the
    initial count.
    
    ---------
    
    Co-authored-by: Ryan Spring <rspring@nvidia.com>
    @liqiangxl liqiangxl marked this pull request as draft March 17, 2025 17:00
    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    Major changes after previous review:
    (1) Added test RegisterSharingCtaShapes for different CTA shapes on different dimensions.
    For example, cta = [32, 4, 2]

    • warp specialize on x dim: illegal, since we can't split the compute threads and loading threads into separate warp groups after padding.

    • warp specialize on y dim: [32, 4 + 4, 2], the loading branch is

      b9 = ((nvfuser_index_t)threadIdx.z) == 0ULL;
      if ((((nvfuser_index_t)threadIdx.y) >= 4)) {
        decreaseRegisters<120>();
        for(nvfuser_index_t i14 = 0; i14 < i0; ++i14) {
          if (((Hopper::electSync(4294967295U) && (((nvfuser_index_t)threadIdx.y) == 7)) && b9)) {
          }
        }
      }
    
    • warp specialize on z dim: [32, 4, 2 + 1], the loading branch is
      b8 = ((nvfuser_index_t)threadIdx.y) == 0ULL;
      b10 = ((nvfuser_index_t)threadIdx.z) == (((nvfuser_index_t)blockDim.z) + -1);
      if (b10) {
        decreaseRegisters<152>();
        for(nvfuser_index_t i15 = 0; i15 < i0; ++i15) {
          if (((Hopper::electSync(4294967295U) && b8) && b10)) {
          }
        }
      }
    

    For example, cta = [128, 2, 1]

    • warp specialize on x dim: [128 + 128, 2, 1], the loading branch is
      b9 = ((nvfuser_index_t)threadIdx.y) == 0ULL;
      b10 = ((nvfuser_index_t)threadIdx.z) == 0ULL;
      if ((((nvfuser_index_t)threadIdx.x) >= 128)) {
        decreaseRegisters<120>();
        for(nvfuser_index_t i15 = 0; i15 < i0; ++i15) {
          if ((((Hopper::electSync(4294967295U) && (((nvfuser_index_t)threadIdx.x) >= (((nvfuser_index_t)blockDim.x) + -32))) && b9) && b10)) {
          }
        }
      }
    
    • warp specialize on y dim: [128, 2 + 1, 1], the loading branch is
      b8 = ((nvfuser_index_t)threadIdx.x) < 32ULL;
      b10 = ((nvfuser_index_t)threadIdx.z) == 0ULL;
      b11 = ((nvfuser_index_t)threadIdx.y) == 2;
      if (b11) {
        decreaseRegisters<152>();
        for(nvfuser_index_t i16 = 0; i16 < i0; ++i16) {
          if ((((Hopper::electSync(4294967295U) && b8) && b11) && b10)) {
          }
        }
      }
    
    • warp specialize on z dim: [128, 2, 1 + 1], the loading branch is
      b8 = ((nvfuser_index_t)threadIdx.x) < 32ULL;
      b9 = ((nvfuser_index_t)threadIdx.y) == 0ULL;
      b11 = ((nvfuser_index_t)threadIdx.z) == (((nvfuser_index_t)blockDim.z) + -1);
      if (b11) {
        decreaseRegisters<120>();
        for(nvfuser_index_t i16 = 0; i16 < i0; ++i16) {
          if ((((Hopper::electSync(4294967295U) && b8) && b9) && b11)) {
          }
        }
      }
    

    (2) Refactored adjustMappingsForWarpSpecialization to simplify the logic of computing padded threads for warp specialization with register sharing.

    on X dim: pad 128
    on Y dim: pad safeDiv(128,  bdimx)
    on Z dim: pad safeDiv(128,  bdimx * bdimy)
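
    A minimal standalone sketch of that pad rule (illustrative only, not nvFuser source; safe_div below is a stand-in for scheduler_utils::safeDiv, assumed to behave like integer division clamped to at least 1):

      #include <cstdint>
      #include <cstdio>

      constexpr int64_t kWarpGroupThreads = 128;

      // Stand-in for scheduler_utils::safeDiv: integer division, clamped to >= 1.
      int64_t safe_div(int64_t a, int64_t b) {
        int64_t q = a / b;
        return q < 1 ? 1 : q;
      }

      // How much the warp-specialized dimension grows, given the CTA shape.
      int64_t warp_specialization_pad(char ws_dim, int64_t bdimx, int64_t bdimy) {
        if (ws_dim == 'x') {
          return kWarpGroupThreads;                        // bdimx += 128
        }
        if (ws_dim == 'y') {
          return safe_div(kWarpGroupThreads, bdimx);       // bdimy += 128 / bdimx
        }
        return safe_div(kWarpGroupThreads, bdimx * bdimy); // bdimz += 128 / (bdimx * bdimy)
      }

      int main() {
        // cta = [32, 4, 2]: prints 128, 4, 1
        std::printf("%lld %lld %lld\n",
                    (long long)warp_specialization_pad('x', 32, 4),
                    (long long)warp_specialization_pad('y', 32, 4),
                    (long long)warp_specialization_pad('z', 32, 4));
        // cta = [128, 2, 1]: prints 128, 1, 1
        std::printf("%lld %lld %lld\n",
                    (long long)warp_specialization_pad('x', 128, 2),
                    (long long)warp_specialization_pad('y', 128, 2),
                    (long long)warp_specialization_pad('z', 128, 2));
        return 0;
      }

    Plugging in the CTA shapes above: [32, 4, 2] pads by 4 on y ([32, 8, 2]) and by 1 on z ([32, 4, 3]); [128, 2, 1] pads by 128 on x ([256, 2, 1]), by 1 on y ([128, 3, 1]), and by 1 on z ([128, 2, 2]), matching the generated loading branches shown above.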
    

    @liqiangxl liqiangxl marked this pull request as ready for review March 19, 2025 17:32
    @liqiangxl liqiangxl requested a review from rdspring1 March 19, 2025 17:32
    @rdspring1
    Collaborator

    !test
