
Almost-exact graph recognizes equivalence in the split-split pattern #5986

Draft
wujingyue wants to merge 6 commits into main from wjy/split

Conversation

@wujingyue
Collaborator

@wujingyue wujingyue commented Feb 19, 2026

A spin-off from #4404

For #3987

@wujingyue
Collaborator Author

!test

@github-actions

github-actions bot commented Feb 19, 2026

Review updated until commit 42937d3

Description

  • Add mapSplitOfSplit function to recognize equivalence in the split-split pattern

  • Modify getUses to return empty groups instead of error when no uses found

  • Add test cases for split-reshape patterns with different extent scenarios

  • Simplify test code by removing unnecessary unique_ptr usage

Changes walkthrough

Relevant files

Enhancement: csrc/id_model/id_model.cpp
Add split-split pattern mapping functionality

  • Remove unnecessary includes for trivial_broadcast and val_graph_visitor
  • Add mapSplitOfSplit function to handle the split-split equivalence pattern
  • Integrate mapSplitOfSplit into the buildAlmostExactGraph workflow
  • +57/-2

Error handling: csrc/val_graph.cpp
Improve getUses error handling

  • Modify getUses to return empty ExprGroups instead of throwing an error
  • Add a static empty_expr_groups for cases with no uses
  • +5/-4

Tests: tests/cpp/test_id_model.cpp
Add split-reshape tests and simplify test code

  • Remove unused includes for fstream and graphviz
  • Simplify Fusion creation in multiple tests using direct instantiation
  • Add SplitingReshape test for basic split-reshape equivalence
  • Add SplitingReshape_DifferentExtents test for the extent mismatch case
  • +55/-26

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Error Handling Regression

    The removal of NVF_ERROR in getUses() method may hide legitimate bugs. Previously, if a val_group was expected to have uses but didn't exist in unique_uses_, it would throw an error. Now it silently returns an empty ExprGroups. This could mask programming errors where a val_group is incorrectly expected to have uses.

    const ExprGroups& ValGraph::getUses(const ValGroup& val_group) const {
      NVF_ERROR(val_group, "Nullptr not allowed");
    
      static ExprGroups empty_expr_groups;
      const auto it = unique_uses_.find(val_group);
      if (it == unique_uses_.end()) {
        return empty_expr_groups;
      }
      return it->second;
    }
    Algorithm Correctness

    The mapSplitOfSplit function implements a complex pattern matching algorithm for split-split patterns. The algorithm should be carefully reviewed to ensure it correctly handles all edge cases, particularly around the conditions for mapping outermost_grand and outer' IDs. The logic around extent comparison and the requirement that outer and inner must not be conflated needs validation.

    void mapSplitOfSplit(ValGraph& graph) {
      // The following is a subpattern of
      // https://github.com/NVIDIA/Fuser/blob/main/doc/reading/iterdomain.md#2-properties-of-iterdomain-transformations
      //
      // outer, _ = split(root)
      // outermost_grand, _ = split(outer)
      // outer', _ = split(root)
      //
      // If outermost_grand and outer' have the same extent, map them.
      std::vector<std::pair<Val*, Val*>> ids_to_map;
      for (const ValGroup& root : graph.disjointValSets().disjointSets()) {
        const ExprGroups& uses_of_root = graph.getUses(root);
        std::vector<ValGroup> outermost_grands;
        for (const ExprGroup& use_of_root : uses_of_root) {
          auto* split0 = dynamic_cast<Split*>(use_of_root->front());
          if (split0 == nullptr) {
            continue;
          }
          // Only follow the outer output of the first split; outer and inner
          // must not be conflated.
          const ValGroup& outer = graph.toGroup(split0->outer());
          for (const ExprGroup& use_of_outer : graph.getUses(outer)) {
            auto* split1 = dynamic_cast<Split*>(use_of_outer->front());
            if (split1 == nullptr) {
              continue;
            }
            const ValGroup& outermost_grand = graph.toGroup(split1->outer());
            outermost_grands.push_back(outermost_grand);
          }
        }
    
        for (const ValGroup& outermost_grand : outermost_grands) {
          Val* extent_of_grand =
              outermost_grand->front()->as<IterDomain>()->extent();
    
          for (const ExprGroup& use_of_root : uses_of_root) {
            auto* split = dynamic_cast<Split*>(use_of_root->front());
            if (split == nullptr) {
              continue;
            }
    
            const ValGroup& outer = graph.toGroup(split->outer());
            if (outer->front()->as<IterDomain>()->extent()->sameAs(
                    extent_of_grand)) {
              ids_to_map.emplace_back(outermost_grand->front(), outer->front());
            }
          }
        }
      }
    
      for (const auto& [id1, id2] : ids_to_map) {
        graph.mapVals(id1, id2);
      }
    }

    Test failures

    • (Medium, 34) NVFuser TMA load & inner-reduction tests hitting internal assertions (validator_utils.cpp, indexing.cpp) across multiple runners

      Failing tests (GB200 and H100 runners; source links in internal logs):

      TMASimpleLdstTest.Load/1D_128B___half
      TMASimpleLdstTest.Load/1D_128B_float
      TMASimpleLdstTest.Load/1D_32B___half
      TMASimpleLdstTest.Load/1D_32B_float
      TMASimpleLdstTest.Load/1D_64B___half
      TMASimpleLdstTest.Load/1D_64B_float
      TmaInnerReductionManualTest.Basic/ndim_2_inner_size_1048576
      TmaInnerReductionManualTest.Basic/ndim_2_inner_size_131072
      TmaInnerReductionManualTest.Basic/ndim_2_inner_size_524288
      TmaInnerReductionManualTest.Basic/ndim_2_inner_size_65536
      ... with 7 more test failures omitted. Check internal logs.

    @wujingyue
    Collaborator Author

    !test

    @wujingyue wujingyue changed the title [IdModel] almost-exact graph recognizes (partially) equivalence in the split-split pattern [IdModel] almost-exact graph recognizes equivalence in the split-split pattern Feb 19, 2026
    @wujingyue wujingyue changed the title [IdModel] almost-exact graph recognizes equivalence in the split-split pattern Almost-exact graph recognizes equivalence in the split-split pattern Feb 19, 2026
    @wujingyue wujingyue requested a review from naoyam February 19, 2026 22:28
    @wujingyue
    Collaborator Author

    @naoyam while I'm cleaning things up and verifying tests, do you think it's moving in the right direction?

    @wujingyue
    Collaborator Author

    !test

    @wujingyue wujingyue requested a review from naoyam February 20, 2026 01:10
    @naoyam
    Collaborator

    naoyam commented Feb 20, 2026

    Looks good overall.

    @wujingyue
    Collaborator Author

    I'm running into some interesting test failures. One of them is a validation error:

    [ RUN      ] TMASimpleLdstTest.Load/1D_128B___half
    ...
    Validation error in output 0 on line 524 in file /opt/pytorch/nvfuser/tests/cpp/test_memory.cpp.
      Detected max abs error of: 7.34375
        absolute tolerance was set to 0.00390625
        and relative tolerance set to 0.0078125
    

    The symptom is around this TensorView

    T2_g___half[iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512}]
      logical domain: (iS33{( (( (( getMetaData(T0) )).logical_size ))[0] )})
      contiguity: t
        Split: iS33{( (( (( getMetaData(T0) )).logical_size ))[0] )} by factor 64 -> iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iS16{64}
        Split: iS16{64} by factor 64 -> iS17{1}, iS18{64}
        Split: iS18{64} by factor 8 -> iS23{8}, iS24{8}
        Split: iS17{1} by factor 8 -> iS19{1}, iS20{8}
        Split: iS20{8} by factor 1 -> iS21{8}, iS22{1}
        Xor(2D): iS21{8} , iS23{8} -> iS25{8} , iS26{8}
        Merge: iS19{1} and iS25{8} -> iS27{8}
        Merge: iS27{8} and iS22{1} -> iS28{8}
        Merge: iS28{8} and iS26{8} -> iS29{64}
        Merge: iS29{64} and iS24{8} -> ithreadIdx.x30{512}
      loop domain: (iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512})
    

    The new code maps iS19{1} and iS17{1}.

    This is mathematically correct because these two IterDomains can share the same index -- the index is 0 all the time. However, codegen doesn't seem to like the mapping.

    Before I throw more if-elses at it, what's the right contract so people can DbC? cc @naoyam

    @naoyam
    Collaborator

    naoyam commented Feb 20, 2026

    Can you show the diff of generated codes? I'm guessing something isn't working around predicates.

    @wujingyue
    Collaborator Author

    TMASimpleLdstTest.Load/1D_128B___half

    git fetch origin wjy/split
    git checkout wjy/split
    _bn && bin/test_nvfuser --gtest_filter=TMASimpleLdstTest.Load/1D_128B___half
    

    cc @naoyam

    @wujingyue
    Collaborator Author

    As @naoyam requested:

    [ RUN      ] TMASimpleLdstTest.Load/1D_128B___half
    Inputs:
      T0_g___half[iS31{( (( (( getMetaData(T0) )).logical_size ))[0] )}]
    Outputs:
      T2_g___half[iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512}]
    
    %kernel {
    T1_s___half[iblockIdx.x3{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iB7{1}, iB13{8}, iB10{1}, iB14{8}, iB12{8}]
       = CpAsyncBulkTensorTile( T0_g___half[iS31{( (( (( getMetaData(T0) )).logical_size ))[0] )}] )
    T2_g___half[iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512}]
       = Set( T1_s___half[iblockIdx.x3{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iB7{1}, iB13{8}, iB10{1}, iB14{8}, iB12{8}], cache_op=Streaming )
    
    TransformPrinter :
    T0_g___half[iS31{( (( (( getMetaData(T0) )).logical_size ))[0] )}]
      logical domain: (iS31{( (( (( getMetaData(T0) )).logical_size ))[0] )})
      contiguity: t
      loop domain: (iS31{( (( (( getMetaData(T0) )).logical_size ))[0] )})
    T1_s___half[iblockIdx.x3{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iB7{1}, iB13{8}, iB10{1}, iB14{8}, iB12{8}]
      logical domain: (iS32{( (( (( getMetaData(T0) )).logical_size ))[0] )})
      allocation domain: (iblockIdx.x3{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iB7{1}, iB13{8}, iB10{1}, iB14{8}, iB12{8})
      contiguity: t t t t t t
        Split: iS32{( (( (( getMetaData(T0) )).logical_size ))[0] )} by factor 64 -> iblockIdx.x3{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iS4{64}
        Split: iS4{64} by factor 64 -> iS5{1}, iS6{64}
        Split: iS5{1} by factor 8 -> iB7{1}, iS8{8}
        Split: iS6{64} by factor 8 -> iS11{8}, iB12{8}
        Split: iS8{8} by factor 1 -> iS9{8}, iB10{1}
        Xor(2D): iS9{8} , iS11{8} -> iB13{8} , iB14{8}
      loop domain: (iblockIdx.x3{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iB7{1}, iB13{8}, iB10{1}, iB14{8}, iB12{8})
    T2_g___half[iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512}]
      logical domain: (iS33{( (( (( getMetaData(T0) )).logical_size ))[0] )})
      contiguity: t
        Split: iS33{( (( (( getMetaData(T0) )).logical_size ))[0] )} by factor 64 -> iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iS16{64}
        Split: iS16{64} by factor 64 -> iS17{1}, iS18{64}
        Split: iS18{64} by factor 8 -> iS23{8}, iS24{8}
        Split: iS17{1} by factor 8 -> iS19{1}, iS20{8}
        Split: iS20{8} by factor 1 -> iS21{8}, iS22{1}
        Xor(2D): iS21{8} , iS23{8} -> iS25{8} , iS26{8}
        Merge: iS19{1} and iS25{8} -> iS27{8}
        Merge: iS27{8} and iS22{1} -> iS28{8}
        Merge: iS28{8} and iS26{8} -> iS29{64}
        Merge: iS29{64} and iS24{8} -> ithreadIdx.x30{512}
      loop domain: (iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512})
    } // %kernel
    iS19{1} <==> iS17{1}
    
    ======= Codegen output for kernel: nvfuser_none_f0_c0_r0_g0 =======
    
    // Codegen generated code
    __global__ void nvfuser_none_f0_c0_r0_g0(Tensor<__half, 1, 1> T0, const __grid_constant__ TensorMap var0, Tensor<__half, 1, 1> T2) {
      alignas(128) extern __shared__ char array[];
      const unsigned smem_offset = 0;
      const TensorMap* ptr1;
      ptr1 = &var0;
      nvfuser_index_t i2;
      i2 = 64 * ((nvfuser_index_t)blockIdx.x);
      Array<int, 1, 1> a3;
      a3 = Array<int, 1, 1>{__to_int32(i2)};
      nvfuser_index_t i4;
      i4 = ((8 * ((((nvfuser_index_t)threadIdx.x) / 64) ^ ((((nvfuser_index_t)threadIdx.x) / 8) % 8))) + (((nvfuser_index_t)threadIdx.x) % 8)) + i2;
      __half* T1 = reinterpret_cast<__half*>(array + smem_offset + 0);
      uint64_t* T3 = reinterpret_cast<uint64_t*>(array + smem_offset + 1024);
      mbarrier::init(toSmem(T3), 1U);
      __syncthreads();
      if ((Hopper::electSync(4294967295U) && (((nvfuser_index_t)threadIdx.x) < 32ULL))) {
        uint64_t i5;
        i5 = mbarrier::arriveExpectTX(toSmem(T3), 128U);
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<1>{ ptr1, a3, toSmem(T3) }), toSmem(T1));
        mbarrier::wait(toSmem(T3), i5);
      }
      __syncthreads();
      mbarrier::inval(toSmem(T3));
      if (((i4 >= 0) && (i4 < T0.logical_size[0LL]))) {
        T2[i4]
           = T1[((nvfuser_index_t)threadIdx.x)];
      }
    }
    
    ======================================
    
    unknown file: Failure
    C++ exception with description " INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/validator_utils.cpp:505, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues.
    Expected aten_output_in_common_dtype.allclose( fusion_output_in_common_dtype, tolerance_values.second, tolerance_values.first, true) .
    
    Validation error in output 0 on line 524 in file /opt/pytorch/nvfuser/tests/cpp/test_memory.cpp.
      Detected max abs error of: 8.17969
        absolute tolerance was set to 0.00390625
        and relative tolerance set to 0.0078125
    

    @naoyam
    Collaborator

    naoyam commented Feb 20, 2026

    I'd like to see the diff result comparing the generated kernels. Please run the test with NVFUSER_DUMP=cuda_to_file to save the code to a file and run them through the diff command.

    @wujingyue
    Collaborator Author

    NVFUSER_DUMP=cuda_kernel without this PR:

    __global__ void nvfuser_none_f0_c0_r0_g0(Tensor<__half, 1, 1> T0, const __grid_constant__ TensorMap var0, Tensor<__half, 1, 1> T2) {
      alignas(128) extern __shared__ char array[];
      const unsigned smem_offset = 0;
      const TensorMap* ptr1;
      ptr1 = &var0;
      nvfuser_index_t i2;
      i2 = 64 * ((nvfuser_index_t)blockIdx.x);
      Array<int, 1, 1> a3;
      a3 = Array<int, 1, 1>{__to_int32(i2)};
      nvfuser_index_t i4;
      i4 = ((8 * ((((nvfuser_index_t)threadIdx.x) / 64) ^ ((((nvfuser_index_t)threadIdx.x) / 8) % 8))) + (((nvfuser_index_t)threadIdx.x) % 8)) + i2;
      __half* T1 = reinterpret_cast<__half*>(array + smem_offset + 0);
      uint64_t* T3 = reinterpret_cast<uint64_t*>(array + smem_offset + 1024);
      mbarrier::init(toSmem(T3), 1U);
      __syncthreads();
      if ((Hopper::electSync(4294967295U) && (((nvfuser_index_t)threadIdx.x) < 32ULL))) {
        uint64_t i5;
        i5 = mbarrier::arriveExpectTX(toSmem(T3), 128U);
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<1>{ ptr1, a3, toSmem(T3) }), toSmem(T1));
        mbarrier::wait(toSmem(T3), i5);
      }
      __syncthreads();
      mbarrier::inval(toSmem(T3));
      if ((((((nvfuser_index_t)threadIdx.x) < 64) && (i4 >= 0)) && (i4 < T0.logical_size[0LL]))) {
        T2[i4]
           = T1[((nvfuser_index_t)threadIdx.x)];
      }
    }
    

    I think you are right about predication:

    $ diff -ruN /tmp/old_kernel.txt /tmp/new_kernel.txt
    --- /tmp/old_kernel.txt 2026-02-20 14:49:44.624896719 -0800
    +++ /tmp/new_kernel.txt 2026-02-20 14:50:04.110394445 -0800
    @@ -21,7 +21,7 @@
       }
       __syncthreads();
       mbarrier::inval(toSmem(T3));
    -  if ((((((nvfuser_index_t)threadIdx.x) < 64) && (i4 >= 0)) && (i4 < T0.logical_size[0LL]))) {
    +  if (((i4 >= 0) && (i4 < T0.logical_size[0LL]))) {
         T2[i4]
            = T1[((nvfuser_index_t)threadIdx.x)];
       }
    

    cc @naoyam

    @wujingyue
    Collaborator Author

    wujingyue commented Feb 24, 2026

    Copying messages from @naoyam for https://abseil.io/resources/swe-book/html/ch03.html


    I looked into the issue. The issue happens due to the predication for non-divisible splits.

    https://github.com/NVIDIA/Fuser/blob/main/csrc/id_model/indexing.cpp#L778

    IIRC, Xiang had some writeup.

    https://github.com/NVIDIA/Fuser/blob/main/doc/reading/divisibility-of-split.md

    In this case, T2 has a non-divisible split with iS17:

    T2_g___half[iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, ithreadIdx.x30{512}]
      logical domain: (iS33{( (( (( getMetaData(T0) )).logical_size ))[0] )})
      contiguity: t
        Split: iS33{( (( (( getMetaData(T0) )).logical_size ))[0] )} by factor 64 -> iblockIdx.x15{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, iS16{64}
        Split: iS16{64} by factor 64 -> iS17{1}, iS18{64}
        Split: iS18{64} by factor 8 -> iS23{8}, iS24{8}
        Split: iS17{1} by factor 8 -> iS19{1}, iS20{8}
        Split: iS20{8} by factor 1 -> iS21{8}, iS22{1}
        Xor(2D): iS21{8} , iS23{8} -> iS25{8} , iS26{8}
        Merge: iS19{1} and iS25{8} -> iS27{8}
        Merge: iS27{8} and iS22{1} -> iS28{8}
        Merge: iS28{8} and iS26{8} -> iS29{64}
        Merge: iS29{64} and iS24{8} -> ithreadIdx.x30{512}
    

    iS17 is split by 8, which effectively expands the domain by a factor of 8, and so we would need to make sure indexing would not go beyond the original extent of iS17, which is just 1.

    getNonDivisibleIdsToPredicate used here returns iS17 in this case.

    In main, this line creates this predicate: ( ( ( threadIdx.x / 8 ) / 8 ) < 1 )

    Now, the PR adds another mapping: iS17 and iS19 . When we do the traversal, iS19 simply gets assigned with index value of zero. That is because of Merge: iS19{1} and iS25{8} -> iS27{8}. Here, iS25 and iS27 are mapped as part of the almost-exact mappings, so we simply forward the assigned index of iS27 to iS25, and for iS19, I think we simply assign zero (I need to confirm this). Since iS19 gets zero, so does iS17.

    This results in the non-divisible split predicate of 0 < 1, instead of ( ( ( threadIdx.x / 8 ) / 8 ) < 1 ) . As a result, since 0 < 1 is always true, the resulting code doesn't get any predicate for the non-divisible split.

    The almost-exact mapping is used for indexing traversal, so its mapping needs to take indexing equality into consideration. Even if two iter domains have the same extent, it doesn't automatically mean they should use the same index. In this case, for the purpose of indexing, I'd question if iS17 and iS19 should be mapped.
