In RoPE, there's a small matmul, which is currently sent to aten. For example, this is the first part of the Mistral forward RoPE module:
Inputs:
T0_g___bfloat[bS0{1}, iS1{4096}, iS2{4096}]
T1_g___bfloat[bS3{1}, iS4{4096}, iS5{1024}]
T2_g___bfloat[bS6{1}, iS7{4096}, iS8{1024}]
T3_g___bfloat[iS9{64}]
T4_g_int64_t[bS10{1}, iS11{4096}]
Outputs:
T29_g___bfloat[bS99{1}, bS100{1}, iS101{4096}, iS102{128}]
T31_g___bfloat[bS107{1}, bS108{1}, iS109{4096}, iS110{128}]
T51_g___bfloat[bS191{1}, iS192{32}, iS193{4096}, iS194{128}]
T76_g___bfloat[bS299{1}, iS306{32}rf, iS302{4096}, iS303{128}]
T81_g___bfloat[bS327{1}, iS334{32}rf, iS330{4096}, iS331{128}]
%kernel_math {
T5_l___bfloat[bS12{1}, iS13{4096}, iS16{32}rf, iS17{128}rf] = view( T0_g___bfloat[bS0{1}, iS1{4096}, iS2{4096}] )
T6_l___bfloat[bS18{1}, iS20{32}, iS19{4096}, iS21{128}]
= Set.Permute( T5_l___bfloat[bS12{1}, iS13{4096}, iS16{32}rf, iS17{128}rf], cache_op=Streaming )
T34_l_float[bS119{1}, iS120{32}, iS121{4096}, iS122{128}]
= __bfloat2float(T6_l___bfloat[bS18{1}, iS20{32}, iS19{4096}, iS21{128}]);
T11_l___bfloat[bS42{1}, iS43{64}, bS44{1}]
= broadcast( T3_g___bfloat[iS9{64}] )
T12_l___bfloat[bS45{1}, iS46{64}, bS47{1}]
= Set( T11_l___bfloat[bS42{1}, iS43{64}, bS44{1}], cache_op=Streaming )
T13_l_float[bS48{1}, iS49{64}, bS50{1}]
= __bfloat2float(T12_l___bfloat[bS45{1}, iS46{64}, bS47{1}]);
T14_l_float[bS51{1}, iS52{64}, bS53{1}]
= Set( T13_l_float[bS48{1}, iS49{64}, bS50{1}], cache_op=Streaming )
T15_l_float[bS54{1}, iS55{64}, bS56{1}]
= Set( T14_l_float[bS51{1}, iS52{64}, bS53{1}], cache_op=Streaming )
T16_l_int64_t[bS57{1}, bS58{1}, iS59{4096}]
= broadcast( T4_g_int64_t[bS10{1}, iS11{4096}] )
T17_l_int64_t[bS60{1}, bS61{1}, iS62{4096}]
= Set( T16_l_int64_t[bS57{1}, bS58{1}, iS59{4096}], cache_op=Streaming )
T18_l_float[bS63{1}, bS64{1}, iS65{4096}]
= (float)(T17_l_int64_t[bS60{1}, bS61{1}, iS62{4096}]);
T19_l_float[bS66{1}, iS67{64}, iS68{4096}]
= matmul(T15_l_float[bS54{1}, iS55{64}, bS56{1}],
T18_l_float[bS63{1}, bS64{1}, iS65{4096}])
T20_l_float[bS69{1}, iS71{4096}, iS70{64}]
= Set.Permute( T19_l_float[bS66{1}, iS67{64}, iS68{4096}], cache_op=Streaming )
T21_l_float[bS72{1}, iS73{4096}, iS75{128}rf]
= pad( T20_l_float[bS69{1}, iS71{4096}, iS70{64}], {0, 0, 0, 0, 0, 64} )
i85 = 0 + 64;
T22_l_float[bS76{1}, iS77{4096}, iS79{( ( 0 + 64 ) + 64 )}rf]
= pad( T20_l_float[bS69{1}, iS71{4096}, iS70{64}], {0, 0, 0, 0, i85, 0} )
T23_l_float[bS80{1}, iS81{4096}, iS82{128}]
= cat( T21_l_float[bS72{1}, iS73{4096}, iS75{128}rf], T22_l_float[bS76{1}, iS77{4096}, iS79{( ( 0 + 64 ) + 64 )}rf], 2 )
T24_l_float[bS83{1}, iS84{4096}, iS85{128}]
= cosf(T23_l_float[bS80{1}, iS81{4096}, iS82{128}]);
T26_l___bfloat[bS89{1}, iS90{4096}, iS91{128}]
= __float2bfloat(T24_l_float[bS83{1}, iS84{4096}, iS85{128}]);
The matmul op producing T19 becomes a segmentation boundary: that op, and only that op, is handled by aten, while the preceding and following sections are handled by the other schedulers. This would make sense if the matmul were compute-heavy, but that is unlikely here since the dimensions are quite small: both contracted extents (bS56 and bS64) are broadcasts of size 1, so the op amounts to an outer product of a 64-element vector with a 4096-element row.
T19_l_float[bS66{1}, iS67{64}, iS68{4096}]
= matmul(T15_l_float[bS54{1}, iS55{64}, bS56{1}],
T18_l_float[bS63{1}, bS64{1}, iS65{4096}])
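For illustration, here is a minimal PyTorch sketch of that op with the shapes from the dump (stand-in tensors, not the actual nvFuser code):

```python
import torch

# Stand-ins for T15 ([1, 64, 1]) and T18 ([1, 1, 4096]) above.
t15 = torch.randn(1, 64, 1)
t18 = torch.randn(1, 1, 4096)

# The contracted extent is 1, so this "matmul" performs only
# 64 * 4096 = 262,144 scalar multiplies -- an outer product,
# not a GEMM worth dispatching to a matmul kernel.
t19 = torch.matmul(t15, t18)
print(t19.shape)  # torch.Size([1, 64, 4096])
```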
This could be translated to just a sequence of pointwise ops:
T15_b = broadcast(T15, {false, false, false, true});
T18_b = broadcast(T18, {false, true, false, false});
T19 = squeeze(mul(T15_b, T18_b), -2);
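As a sanity check, the same rewrite in PyTorch terms (`unsqueeze` standing in for nvFuser's `broadcast`; an illustrative sketch, not the actual fusion code):

```python
import torch

t15 = torch.randn(1, 64, 1)
t18 = torch.randn(1, 1, 4096)

ref = torch.matmul(t15, t18)       # [1, 64, 4096]

# Pointwise formulation: give T15 a trailing broadcast dim and T18 a
# broadcast dim in front of its row axis, multiply elementwise, then
# squeeze out the size-1 contraction axis (dim -2).
t15_b = t15.unsqueeze(-1)          # [1, 64, 1, 1]
t18_b = t18.unsqueeze(1)           # [1, 1, 1, 4096]
t19 = (t15_b * t18_b).squeeze(-2)  # [1, 64, 4096]

assert torch.allclose(ref, t19)
```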
Combined with #3645, the above section of the forward module would likely be fused into a single kernel with no segmentation.