Feature request: Schedule a small matmul op as a reduction (or pointwise) op #3646

naoyam (Collaborator) opened this issue Dec 25, 2024 · 1 comment

naoyam commented Dec 25, 2024

In RoPE, there's a small matmul that is currently sent to ATen. For example, this is the first part of the Mistral forward RoPE module:

Inputs:
  T0_g___bfloat[bS0{1}, iS1{4096}, iS2{4096}]
  T1_g___bfloat[bS3{1}, iS4{4096}, iS5{1024}]
  T2_g___bfloat[bS6{1}, iS7{4096}, iS8{1024}]
  T3_g___bfloat[iS9{64}]
  T4_g_int64_t[bS10{1}, iS11{4096}]
Outputs:
  T29_g___bfloat[bS99{1}, bS100{1}, iS101{4096}, iS102{128}]
  T31_g___bfloat[bS107{1}, bS108{1}, iS109{4096}, iS110{128}]
  T51_g___bfloat[bS191{1}, iS192{32}, iS193{4096}, iS194{128}]
  T76_g___bfloat[bS299{1}, iS306{32}rf, iS302{4096}, iS303{128}]
  T81_g___bfloat[bS327{1}, iS334{32}rf, iS330{4096}, iS331{128}]

%kernel_math {
T5_l___bfloat[bS12{1}, iS13{4096}, iS16{32}rf, iS17{128}rf] = view( T0_g___bfloat[bS0{1}, iS1{4096}, iS2{4096}] )
T6_l___bfloat[bS18{1}, iS20{32}, iS19{4096}, iS21{128}]
   = Set.Permute( T5_l___bfloat[bS12{1}, iS13{4096}, iS16{32}rf, iS17{128}rf], cache_op=Streaming )
T34_l_float[bS119{1}, iS120{32}, iS121{4096}, iS122{128}]
   = __bfloat2float(T6_l___bfloat[bS18{1}, iS20{32}, iS19{4096}, iS21{128}]);
T11_l___bfloat[bS42{1}, iS43{64}, bS44{1}]
   = broadcast( T3_g___bfloat[iS9{64}] )
T12_l___bfloat[bS45{1}, iS46{64}, bS47{1}]
   = Set( T11_l___bfloat[bS42{1}, iS43{64}, bS44{1}], cache_op=Streaming )
T13_l_float[bS48{1}, iS49{64}, bS50{1}]
   = __bfloat2float(T12_l___bfloat[bS45{1}, iS46{64}, bS47{1}]);
T14_l_float[bS51{1}, iS52{64}, bS53{1}]
   = Set( T13_l_float[bS48{1}, iS49{64}, bS50{1}], cache_op=Streaming )
T15_l_float[bS54{1}, iS55{64}, bS56{1}]
   = Set( T14_l_float[bS51{1}, iS52{64}, bS53{1}], cache_op=Streaming )
T16_l_int64_t[bS57{1}, bS58{1}, iS59{4096}]
   = broadcast( T4_g_int64_t[bS10{1}, iS11{4096}] )
T17_l_int64_t[bS60{1}, bS61{1}, iS62{4096}]
   = Set( T16_l_int64_t[bS57{1}, bS58{1}, iS59{4096}], cache_op=Streaming )
T18_l_float[bS63{1}, bS64{1}, iS65{4096}]
   = (float)(T17_l_int64_t[bS60{1}, bS61{1}, iS62{4096}]);
T19_l_float[bS66{1}, iS67{64}, iS68{4096}]
   = matmul(T15_l_float[bS54{1}, iS55{64}, bS56{1}],
            T18_l_float[bS63{1}, bS64{1}, iS65{4096}])
T20_l_float[bS69{1}, iS71{4096}, iS70{64}]
   = Set.Permute( T19_l_float[bS66{1}, iS67{64}, iS68{4096}], cache_op=Streaming )
T21_l_float[bS72{1}, iS73{4096}, iS75{128}rf]
   = pad( T20_l_float[bS69{1}, iS71{4096}, iS70{64}], {0, 0, 0, 0, 0, 64} )
i85 = 0 + 64;
T22_l_float[bS76{1}, iS77{4096}, iS79{( ( 0 + 64 ) + 64 )}rf]
   = pad( T20_l_float[bS69{1}, iS71{4096}, iS70{64}], {0, 0, 0, 0, i85, 0} )
T23_l_float[bS80{1}, iS81{4096}, iS82{128}]
   = cat( T21_l_float[bS72{1}, iS73{4096}, iS75{128}rf], T22_l_float[bS76{1}, iS77{4096}, iS79{( ( 0 + 64 ) + 64 )}rf], 2 )
T24_l_float[bS83{1}, iS84{4096}, iS85{128}]
   = cosf(T23_l_float[bS80{1}, iS81{4096}, iS82{128}]);
T26_l___bfloat[bS89{1}, iS90{4096}, iS91{128}]
   = __float2bfloat(T24_l_float[bS83{1}, iS84{4096}, iS85{128}]);

The matmul op producing T19 becomes a segmentation boundary, since that op, and only that op, is handled by ATen, while the pre- and post-sections are handled by the other schedulers. This would make sense if the matmul were compute-heavy, but in this particular case it is unlikely to be, as the dimensions are quite small.

T19_l_float[bS66{1}, iS67{64}, iS68{4096}]
   = matmul(T15_l_float[bS54{1}, iS55{64}, bS56{1}],
            T18_l_float[bS63{1}, bS64{1}, iS65{4096}])

This could be translated to just a sequence of pointwise ops:

T15_b = broadcast(T15, {false, false, false, true});   // [1, 64, 1]   -> [1, 64, 1, 1]
T18_b = broadcast(T18, {false, true, false, false});    // [1, 1, 4096] -> [1, 1, 1, 4096]
T19 = squeeze(mul(T15_b, T18_b), -2);                    // [1, 64, 1, 4096] -> [1, 64, 4096]

Combined with #3645, the above section of the forward module would likely be fused into a single kernel with no segmentation.
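
As a sanity check on the rewrite (not part of the proposal itself), here is a minimal PyTorch sketch showing that the K=1 matmul and the broadcast-multiply-squeeze sequence compute the same thing; the shapes are taken from the fusion above, and the tensor names are purely illustrative:

import torch

# Shapes taken from the fusion: T15 is [1, 64, 1], T18 is [1, 1, 4096].
t15 = torch.randn(1, 64, 1)
t18 = torch.randn(1, 1, 4096)

# Original op: a batched matmul whose K extent is 1.
ref = torch.matmul(t15, t18)                # [1, 64, 4096]

# Proposed pointwise form: broadcast both operands to [1, 64, 1, 4096],
# multiply elementwise, and squeeze the size-1 K dimension.
t15_b = t15.unsqueeze(-1)                   # [1, 64, 1, 1]
t18_b = t18.unsqueeze(1)                    # [1, 1, 1, 4096]
out = (t15_b * t18_b).squeeze(-2)           # [1, 64, 4096]

assert torch.allclose(ref, out)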

naoyam added the rope label Dec 25, 2024
@jacobhinkle (Collaborator) commented:

@Priya2698, I think it's time we start handling these K=1 cases in the matmul op (and similar for linear) as we discussed a while back. What do you think?
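
For reference, the K=1 linear case mentioned above reduces the same way; a minimal PyTorch sketch with purely illustrative shapes (this is not the nvFuser translation itself, just the numerics):

import torch
import torch.nn.functional as F

# Illustrative shapes: a linear layer with in_features == 1.
x = torch.randn(4096, 1)        # [M, K] with K = 1
w = torch.randn(128, 1)         # [N, K] with K = 1
b = torch.randn(128)

ref = F.linear(x, w, b)         # x @ w.T + b -> [4096, 128]

# Pointwise equivalent: a broadcast multiply plus bias, no reduction needed.
out = x * w.t() + b             # [4096, 1] * [1, 128] + [128] -> [4096, 128]

assert torch.allclose(ref, out)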
