@@ -123,68 +123,56 @@ void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) {
123123} // namespace
124124
// MatMulReadFnSource (generic, single-input form).
//
// NOTE(review): this span is a diff hunk. Lines marked '-' are the pre-change body,
// which emitted both mm_readA and mm_readB from one function; lines marked '+' are
// the post-change body, which emits ONE reader function per call. The comments below
// describe the post-change ('+' and unmarked context) lines.
//
// Appends to shader.AdditionalImplementation() the text of a shader function
//   fn <function_name>(batch: i32, row: i32, colIn: i32[, batch_indices: batch_dims_indices_t]) -> <type_string>
// where <type_string> is MakeScalarOrVectorType(input.NumComponents(), "output_element_t"-style data_type).
//
// Parameters (post-change signature):
//   shader             - receives the generated text via AdditionalImplementation().
//   function_name      - name given to the emitted function.
//   input              - variable helper used for NumComponents(), Rank(), IndicesSet() and GetByIndices().
//   input_name         - prefix used to form "<input_name>_indices" and "<input_name>_indices_t".
//   batch_dims         - optional; when non-null the emitted function takes an extra batch_indices
//                        parameter and the batch-index conversion snippet is emitted.
//   rows               - expression text used as the row bound when transpose is false.
//   components_per_row - expression text used as the column bound when transpose is false.
//   transpose          - when true, the two bounds above are swapped in the emitted guard.
125125void MatMulReadFnSource (ShaderHelper& shader,
126- const ShaderVariableHelper& a,
127- const ShaderVariableHelper& b,
126+ std::string_view function_name,
127+ const ShaderVariableHelper& input,
128+ const std::string& input_name,
128129 const ShaderIndicesHelper* batch_dims,
129- bool transA,
130- bool transB) {
131- const int a_components = a.NumComponents ();
130+ std::string_view rows,
131+ std::string_view components_per_row,
132+ bool transpose) {
// The emitted return type and the multiplier for `col` both come from the input's component count.
133+ const int components = input.NumComponents ();
132134 const std::string data_type = " output_element_t" ;
133- std::string type_string = MakeScalarOrVectorType (a_components , data_type);
135+ const std::string type_string = MakeScalarOrVectorType (components , data_type);
134136
// Emit the function header, a zero-initialized `value`, and `col = colIn * components`.
// `col` is used only by the bounds guard below; the actual index write uses `colIn`.
135137 shader.AdditionalImplementation ()
136- << " fn mm_readA (batch: i32, row: i32, colIn: i32 "
138+ << " fn " << function_name << " (batch: i32, row: i32, colIn: i32 "
137139 << (batch_dims
138140 ? " , batch_indices: batch_dims_indices_t"
139141 : " " )
140- << " ) -> " << type_string << " {\n "
141- << " var value = " << type_string << " (0);\n "
142- << " let col = colIn * " << a_components << " ;\n " ;
143- if (transA ) {
144- shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_inner ) && col < i32(uniforms.dim_a_outer )) {\n " ;
142+ << " ) -> " << type_string << " {\n "
143+ << " var value = " << type_string << " (0);\n "
144+ << " let col = colIn * " << components << " ;\n " ;
// Bounds guard: the emitted `if` compares row/col against (rows, components_per_row),
// or against (components_per_row, rows) when transpose is set. When the guard fails the
// emitted function returns the zero-initialized `value`.
145+ if (transpose ) {
146+ shader.AdditionalImplementation () << " if(row < i32(" << components_per_row << " ) && col < i32(" << rows << " )) {\n " ;
145147 } else {
146- shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_a_outer ) && col < i32(uniforms.dim_inner )) {\n " ;
148+ shader.AdditionalImplementation () << " if(row < i32(" << rows << " ) && col < i32(" << components_per_row << " )) {\n " ;
147149 }
148- shader.AdditionalImplementation () << " var a_indices: a_indices_t;\n " ;
149150
150- if (batch_dims) {
151- shader.AdditionalImplementation () << ConvertOutputBatchIndicesToInputBatchIndices (" a" , a, a.Rank () - 2 , batch_dims ? batch_dims->Rank () : 0 , " batch_indices " );
152- }
153- shader.AdditionalImplementation () << " " << a.IndicesSet (" a_indices" , a.Rank () - 2 , " u32(row)" ) << " \n "
154- << " " << a.IndicesSet (" a_indices" , a.Rank () - 1 , " u32(colIn)" ) << " \n "
155- << " value = " << a.GetByIndices (" a_indices" ) << " ;\n "
156- << " }\n "
157- << " return value;\n "
158- << " }\n\n " ;
159-
160- // Add the mm_readB function
161- const int b_components = b.NumComponents ();
162- type_string = MakeScalarOrVectorType (b_components, data_type);
163- shader.AdditionalImplementation ()
164- << " fn mm_readB(batch: i32, row: i32, colIn: i32 "
165- << (batch_dims
166- ? " , batch_indices: batch_dims_indices_t"
167- : " " )
168- << " ) -> " << type_string << " {\n "
169- << " var value = " << type_string << " (0);\n "
170- << " let col = colIn * " << b_components << " ;\n " ;
// Declare the per-input indices variable inside the guarded branch; its name and type
// name are both derived from input_name.
151+ const std::string input_indices = input_name + " _indices" ;
152+ shader.AdditionalImplementation () << " var " << input_indices << " : " << input_name << " _indices_t" << " ;\n " ;
171153
172- if (transB) {
173- shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n " ;
174- } else {
175- shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n " ;
// Batch-index conversion is emitted only when batch_dims is non-null.
// NOTE(review): inside this `if` the `batch_dims ? ... : 0` ternary always takes its first arm.
// NOTE(review): the removed mm_readB path emitted this conversion unconditionally (passing 0
// when batch_dims was null); the post-change code skips it in that case - confirm the helper
// produced nothing useful for a null batch_dims.
154+ if (batch_dims) {
155+ shader.AdditionalImplementation () << ConvertOutputBatchIndicesToInputBatchIndices (input_name, input, input.Rank () - 2 , batch_dims ? batch_dims->Rank () : 0 , " batch_indices " ) << " \n " ;
176156 }
177157
178- shader.AdditionalImplementation () << " var b_indices: b_indices_t;\n "
179- << ConvertOutputBatchIndicesToInputBatchIndices (" b" , b, b.Rank () - 2 , batch_dims ? batch_dims->Rank () : 0 , " batch_indices" )
180- << " " << b.IndicesSet (" b_indices" , b.Rank () - 2 , " u32(row)" ) << " \n "
181- << " " << b.IndicesSet (" b_indices" , b.Rank () - 1 , " u32(colIn)" ) << " \n "
182- << " value = " << b.GetByIndices (" b_indices" ) << " ;\n "
183- << " }\n "
184- << " return value;\n "
// Fill the last two index slots (Rank()-2 <- u32(row), Rank()-1 <- u32(colIn)), load the
// element through GetByIndices, then close the guard and return `value`.
158+ shader.AdditionalImplementation () << input.IndicesSet (input_indices, input.Rank () - 2 , " u32(row)" ) << " \n "
159+ << input.IndicesSet (input_indices, input.Rank () - 1 , " u32(colIn)" ) << " \n "
160+ << " value = " << input.GetByIndices (input_indices) << " ;\n "
161+ << " }\n "
162+ << " return value;\n "
185163 << " }\n\n " ;
186164}
187165
// MatMulReadFnSource (two-input overload; every line here is a '+' line of the diff hunk).
//
// Emits both reader functions by calling the single-input MatMulReadFnSource overload twice
// on the same shader:
//   - "mm_readA" reads from `a` (input name "a"), with row bound uniforms.dim_a_outer,
//     column bound uniforms.dim_inner, and transposition controlled by transA;
//   - "mm_readB" reads from `b` (input name "b"), with row bound uniforms.dim_inner,
//     column bound uniforms.dim_b_outer, and transposition controlled by transB.
// batch_dims is forwarded unchanged to both calls and may be null (the callee checks it).
166+ void MatMulReadFnSource (ShaderHelper& shader,
167+ const ShaderVariableHelper& a,
168+ const ShaderVariableHelper& b,
169+ const ShaderIndicesHelper* batch_dims,
170+ bool transA,
171+ bool transB) {
// mm_readA is emitted first, then mm_readB, so they are appended to the shader in that order.
172+ MatMulReadFnSource (shader, " mm_readA" , a, " a" , batch_dims, " uniforms.dim_a_outer" , " uniforms.dim_inner" , transA);
173+ MatMulReadFnSource (shader, " mm_readB" , b, " b" , batch_dims, " uniforms.dim_inner" , " uniforms.dim_b_outer" , transB);
174+ }
175+
188176void MatMulWriteFnSource (ShaderHelper& shader,
189177 const ShaderVariableHelper& output,
190178 const ShaderVariableHelper* bias,
0 commit comments