
Feature request: Schedule a small matmul op as a reduction (or pointwise) op #3646

@naoyam

Description


In RoPE, there's a small matmul, which is currently sent to aten. For example, this is the first part of the Mistral forward RoPE module:

Inputs:
  T0_g___bfloat[bS0{1}, iS1{4096}, iS2{4096}]
  T1_g___bfloat[bS3{1}, iS4{4096}, iS5{1024}]
  T2_g___bfloat[bS6{1}, iS7{4096}, iS8{1024}]
  T3_g___bfloat[iS9{64}]
  T4_g_int64_t[bS10{1}, iS11{4096}]
Outputs:
  T29_g___bfloat[bS99{1}, bS100{1}, iS101{4096}, iS102{128}]
  T31_g___bfloat[bS107{1}, bS108{1}, iS109{4096}, iS110{128}]
  T51_g___bfloat[bS191{1}, iS192{32}, iS193{4096}, iS194{128}]
  T76_g___bfloat[bS299{1}, iS306{32}rf, iS302{4096}, iS303{128}]
  T81_g___bfloat[bS327{1}, iS334{32}rf, iS330{4096}, iS331{128}]

%kernel_math {
T5_l___bfloat[bS12{1}, iS13{4096}, iS16{32}rf, iS17{128}rf] = view( T0_g___bfloat[bS0{1}, iS1{4096}, iS2{4096}] )
T6_l___bfloat[bS18{1}, iS20{32}, iS19{4096}, iS21{128}]
   = Set.Permute( T5_l___bfloat[bS12{1}, iS13{4096}, iS16{32}rf, iS17{128}rf], cache_op=Streaming )
T34_l_float[bS119{1}, iS120{32}, iS121{4096}, iS122{128}]
   = __bfloat2float(T6_l___bfloat[bS18{1}, iS20{32}, iS19{4096}, iS21{128}]);
T11_l___bfloat[bS42{1}, iS43{64}, bS44{1}]
   = broadcast( T3_g___bfloat[iS9{64}] )
T12_l___bfloat[bS45{1}, iS46{64}, bS47{1}]
   = Set( T11_l___bfloat[bS42{1}, iS43{64}, bS44{1}], cache_op=Streaming )
T13_l_float[bS48{1}, iS49{64}, bS50{1}]
   = __bfloat2float(T12_l___bfloat[bS45{1}, iS46{64}, bS47{1}]);
T14_l_float[bS51{1}, iS52{64}, bS53{1}]
   = Set( T13_l_float[bS48{1}, iS49{64}, bS50{1}], cache_op=Streaming )
T15_l_float[bS54{1}, iS55{64}, bS56{1}]
   = Set( T14_l_float[bS51{1}, iS52{64}, bS53{1}], cache_op=Streaming )
T16_l_int64_t[bS57{1}, bS58{1}, iS59{4096}]
   = broadcast( T4_g_int64_t[bS10{1}, iS11{4096}] )
T17_l_int64_t[bS60{1}, bS61{1}, iS62{4096}]
   = Set( T16_l_int64_t[bS57{1}, bS58{1}, iS59{4096}], cache_op=Streaming )
T18_l_float[bS63{1}, bS64{1}, iS65{4096}]
   = (float)(T17_l_int64_t[bS60{1}, bS61{1}, iS62{4096}]);
T19_l_float[bS66{1}, iS67{64}, iS68{4096}]
   = matmul(T15_l_float[bS54{1}, iS55{64}, bS56{1}],
            T18_l_float[bS63{1}, bS64{1}, iS65{4096}])
T20_l_float[bS69{1}, iS71{4096}, iS70{64}]
   = Set.Permute( T19_l_float[bS66{1}, iS67{64}, iS68{4096}], cache_op=Streaming )
T21_l_float[bS72{1}, iS73{4096}, iS75{128}rf]
   = pad( T20_l_float[bS69{1}, iS71{4096}, iS70{64}], {0, 0, 0, 0, 0, 64} )
i85 = 0 + 64;
T22_l_float[bS76{1}, iS77{4096}, iS79{( ( 0 + 64 ) + 64 )}rf]
   = pad( T20_l_float[bS69{1}, iS71{4096}, iS70{64}], {0, 0, 0, 0, i85, 0} )
T23_l_float[bS80{1}, iS81{4096}, iS82{128}]
   = cat( T21_l_float[bS72{1}, iS73{4096}, iS75{128}rf], T22_l_float[bS76{1}, iS77{4096}, iS79{( ( 0 + 64 ) + 64 )}rf], 2 )
T24_l_float[bS83{1}, iS84{4096}, iS85{128}]
   = cosf(T23_l_float[bS80{1}, iS81{4096}, iS82{128}]);
T26_l___bfloat[bS89{1}, iS90{4096}, iS91{128}]
   = __float2bfloat(T24_l_float[bS83{1}, iS84{4096}, iS85{128}]);

The matmul op producing T19 becomes a segmentation boundary: the op, and only that op, is handled by aten, while the preceding and following sections are handled by the other schedulers. While this would make sense if the matmul were compute-heavy, that is unlikely here since the dimensions are quite small and the contracted dimension has size 1.

T19_l_float[bS66{1}, iS67{64}, iS68{4096}]
   = matmul(T15_l_float[bS54{1}, iS55{64}, bS56{1}],
            T18_l_float[bS63{1}, bS64{1}, iS65{4096}])
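With the broadcast dims in mind, this is a `[1, 64, 1] @ [1, 1, 4096]` matmul, i.e. K = 1, so it is effectively an outer product with no reduction work at all. A NumPy sketch of the shapes involved (tensor names mirror the IR above but are otherwise illustrative):

```python
import numpy as np

# Shapes from the fusion, keeping the leading broadcast batch dim:
# T15: [1, 64, 1], T18: [1, 1, 4096]
t15 = np.random.rand(1, 64, 1).astype(np.float32)
t18 = np.random.rand(1, 1, 4096).astype(np.float32)

# The matmul contracts over a dimension of size 1 (K = 1),
# so it is just an outer product producing [1, 64, 4096].
t19 = np.matmul(t15, t18)
assert t19.shape == (1, 64, 4096)

# With K = 1, the matmul coincides with a plain broadcast multiply.
assert np.allclose(t19, t15 * t18)
```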

This could be translated to just a sequence of pointwise ops:

T15_b = broadcast(T15, {false, false, false, true});
T18_b = broadcast(T18, {false, true, false, false});
T19 = squeeze(mul(T15_b, T18_b), -2);
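In NumPy terms (a sketch; the broadcast flags above mark which axes are inserted), the rewrite is a broadcast multiply followed by squeezing the size-1 contraction axis, and it reproduces the matmul result exactly:

```python
import numpy as np

t15 = np.random.rand(1, 64, 1).astype(np.float32)    # [1, 64, 1]
t18 = np.random.rand(1, 1, 4096).astype(np.float32)  # [1, 1, 4096]

# Reference: the matmul being replaced.
ref = np.matmul(t15, t18)                            # [1, 64, 4096]

# Pointwise translation: insert a broadcast axis in each operand
# (matching broadcast(T15, {false, false, false, true}) and
#  broadcast(T18, {false, true, false, false})), multiply
# elementwise, then squeeze the size-1 contraction axis.
t15_b = t15[:, :, :, np.newaxis]                     # [1, 64, 1, 1]
t18_b = t18[:, np.newaxis, :, :]                     # [1, 1, 1, 4096]
out = np.squeeze(t15_b * t18_b, axis=-2)             # [1, 64, 4096]

assert np.allclose(out, ref)
```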

Combined with #3645, the above section of the forward module would likely be fused into a single kernel with no segmentation.
