Skip to content

Commit db383a9

Browse files
authored
[WebGPU EP] Reduce duplicated code in MatMulReadFnSource() (#27151)
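### Description Previously, `MatMulReadFnSource()` used duplicated code to read data from its two inputs `a` and `b`. This patch adds another overload of `MatMulReadFnSource()` that generates the read function for a single input; the existing two-input version now calls that overload once for `a` and once for `b`, which removes the duplicated code and makes the single-input helper available for further use.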
1 parent 30ad350 commit db383a9

File tree

1 file changed

+34
-46
lines changed

1 file changed

+34
-46
lines changed

onnxruntime/core/providers/webgpu/math/gemm_utils.cc

Lines changed: 34 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -123,68 +123,56 @@ void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) {
123123
} // namespace
124124

125125
void MatMulReadFnSource(ShaderHelper& shader,
126-
const ShaderVariableHelper& a,
127-
const ShaderVariableHelper& b,
126+
std::string_view function_name,
127+
const ShaderVariableHelper& input,
128+
const std::string& input_name,
128129
const ShaderIndicesHelper* batch_dims,
129-
bool transA,
130-
bool transB) {
131-
const int a_components = a.NumComponents();
130+
std::string_view rows,
131+
std::string_view components_per_row,
132+
bool transpose) {
133+
const int components = input.NumComponents();
132134
const std::string data_type = "output_element_t";
133-
std::string type_string = MakeScalarOrVectorType(a_components, data_type);
135+
const std::string type_string = MakeScalarOrVectorType(components, data_type);
134136

135137
shader.AdditionalImplementation()
136-
<< "fn mm_readA(batch: i32, row: i32, colIn: i32 "
138+
<< "fn " << function_name << "(batch: i32, row: i32, colIn: i32 "
137139
<< (batch_dims
138140
? ", batch_indices: batch_dims_indices_t"
139141
: "")
140-
<< ") -> " << type_string << " {\n"
141-
<< " var value = " << type_string << "(0);\n"
142-
<< " let col = colIn * " << a_components << ";\n";
143-
if (transA) {
144-
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n";
142+
<< ") -> " << type_string << " {\n "
143+
<< " var value = " << type_string << "(0);\n"
144+
<< " let col = colIn * " << components << ";\n";
145+
if (transpose) {
146+
shader.AdditionalImplementation() << " if(row < i32(" << components_per_row << ") && col < i32(" << rows << ")) {\n";
145147
} else {
146-
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n";
148+
shader.AdditionalImplementation() << " if(row < i32(" << rows << ") && col < i32(" << components_per_row << ")) {\n";
147149
}
148-
shader.AdditionalImplementation() << " var a_indices: a_indices_t;\n";
149150

150-
if (batch_dims) {
151-
shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices ");
152-
}
153-
shader.AdditionalImplementation() << " " << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n"
154-
<< " " << a.IndicesSet("a_indices", a.Rank() - 1, "u32(colIn)") << "\n"
155-
<< " value = " << a.GetByIndices("a_indices") << ";\n"
156-
<< " }\n"
157-
<< " return value;\n"
158-
<< "}\n\n";
159-
160-
// Add the mm_readB function
161-
const int b_components = b.NumComponents();
162-
type_string = MakeScalarOrVectorType(b_components, data_type);
163-
shader.AdditionalImplementation()
164-
<< "fn mm_readB(batch: i32, row: i32, colIn: i32 "
165-
<< (batch_dims
166-
? ", batch_indices: batch_dims_indices_t"
167-
: "")
168-
<< ") -> " << type_string << " {\n"
169-
<< " var value = " << type_string << "(0);\n"
170-
<< " let col = colIn * " << b_components << ";\n";
151+
const std::string input_indices = input_name + "_indices";
152+
shader.AdditionalImplementation() << " var " << input_indices << ": " << input_name << "_indices_t" << ";\n";
171153

172-
if (transB) {
173-
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n";
174-
} else {
175-
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n";
154+
if (batch_dims) {
155+
shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices(input_name, input, input.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices ") << "\n";
176156
}
177157

178-
shader.AdditionalImplementation() << " var b_indices: b_indices_t;\n"
179-
<< ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, "batch_indices")
180-
<< " " << b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n"
181-
<< " " << b.IndicesSet("b_indices", b.Rank() - 1, "u32(colIn)") << "\n"
182-
<< " value = " << b.GetByIndices("b_indices") << ";\n"
183-
<< " }\n"
184-
<< " return value;\n"
158+
shader.AdditionalImplementation() << input.IndicesSet(input_indices, input.Rank() - 2, "u32(row)") << "\n"
159+
<< input.IndicesSet(input_indices, input.Rank() - 1, "u32(colIn)") << "\n"
160+
<< " value = " << input.GetByIndices(input_indices) << ";\n"
161+
<< " }\n"
162+
<< " return value;\n"
185163
<< "}\n\n";
186164
}
187165

166+
void MatMulReadFnSource(ShaderHelper& shader,
167+
const ShaderVariableHelper& a,
168+
const ShaderVariableHelper& b,
169+
const ShaderIndicesHelper* batch_dims,
170+
bool transA,
171+
bool transB) {
172+
MatMulReadFnSource(shader, "mm_readA", a, "a", batch_dims, "uniforms.dim_a_outer", "uniforms.dim_inner", transA);
173+
MatMulReadFnSource(shader, "mm_readB", b, "b", batch_dims, "uniforms.dim_inner", "uniforms.dim_b_outer", transB);
174+
}
175+
188176
void MatMulWriteFnSource(ShaderHelper& shader,
189177
const ShaderVariableHelper& output,
190178
const ShaderVariableHelper* bias,

0 commit comments

Comments
 (0)