Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 34 additions & 46 deletions onnxruntime/core/providers/webgpu/math/gemm_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,68 +123,56 @@ void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) {
} // namespace

// NOTE(review): this span is a rendered diff, not compilable code. Removed and
// added lines of MatMulReadFnSource appear interleaved without +/- markers
// (the diff header reports 34 additions and 46 deletions for this file).
//
// Old shape (appears to be the removed lines): one function taking `a` and `b`
// that writes two hard-coded shader functions, "mm_readA" and "mm_readB".
// New shape (appears to be the added lines): one function that writes a single
// read function whose name, input variable, index-variable prefix and bounds
// expressions are all supplied by the caller.
//
// In both shapes the text streamed into shader.AdditionalImplementation()
// declares a function of (batch, row, colIn[, batch_indices]) that starts from
// a zero value, performs the load only inside a row/col bounds check, and
// returns the value.
void MatMulReadFnSource(ShaderHelper& shader,
// Presumably the old parameter list: both inputs are passed together.
const ShaderVariableHelper& a,
const ShaderVariableHelper& b,
// Presumably the new parameter list: emitted function name, the single input,
// and the prefix used for its "<name>_indices" / "<name>_indices_t" identifiers.
std::string_view function_name,
const ShaderVariableHelper& input,
const std::string& input_name,
const ShaderIndicesHelper* batch_dims,
// Presumably the tail of the old signature and the first line of the old body.
bool transA,
bool transB) {
const int a_components = a.NumComponents();
// Presumably the tail of the new signature: `rows` and `components_per_row`
// are spliced verbatim into the emitted bounds check below.
std::string_view rows,
std::string_view components_per_row,
bool transpose) {
const int components = input.NumComponents();
// Element type name passed to MakeScalarOrVectorType together with the
// component count to form the emitted return type.
const std::string data_type = "output_element_t";
std::string type_string = MakeScalarOrVectorType(a_components, data_type);
const std::string type_string = MakeScalarOrVectorType(components, data_type);

shader.AdditionalImplementation()
// Function header: one variant hard-codes "mm_readA", the other splices in
// function_name.
<< "fn mm_readA(batch: i32, row: i32, colIn: i32 "
<< "fn " << function_name << "(batch: i32, row: i32, colIn: i32 "
// The extra batch_indices parameter is emitted only when batch_dims is non-null.
<< (batch_dims
? ", batch_indices: batch_dims_indices_t"
: "")
<< ") -> " << type_string << " {\n"
<< " var value = " << type_string << "(0);\n"
<< " let col = colIn * " << a_components << ";\n";
// Bounds check for the "a" input, written against fixed uniform names; the
// two extents swap places depending on transA.
if (transA) {
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n";
// NOTE(review): this variant of the return-type line ends in "{\n " (one extra
// trailing space) where the variant above ends in "{\n"; this looks like a
// whitespace-only difference in the emitted text.
<< ") -> " << type_string << " {\n "
<< " var value = " << type_string << "(0);\n"
<< " let col = colIn * " << components << ";\n";
// Parameterized bounds check: with `transpose` set, row is compared against
// components_per_row and col against rows; otherwise the two are swapped.
if (transpose) {
shader.AdditionalImplementation() << " if(row < i32(" << components_per_row << ") && col < i32(" << rows << ")) {\n";
} else {
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n";
shader.AdditionalImplementation() << " if(row < i32(" << rows << ") && col < i32(" << components_per_row << ")) {\n";
}
shader.AdditionalImplementation() << " var a_indices: a_indices_t;\n";

// NOTE(review): inside this guard batch_dims is already non-null, so the
// "batch_dims ? ... : 0" fallback in the call can never be taken here.
if (batch_dims) {
shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices ");
}
// The indices at positions Rank() - 2 and Rank() - 1 are set from row and
// colIn (the unscaled column), then the value is loaded through them.
shader.AdditionalImplementation() << " " << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n"
<< " " << a.IndicesSet("a_indices", a.Rank() - 1, "u32(colIn)") << "\n"
<< " value = " << a.GetByIndices("a_indices") << ";\n"
<< " }\n"
<< " return value;\n"
<< "}\n\n";

// Add the mm_readB function
const int b_components = b.NumComponents();
type_string = MakeScalarOrVectorType(b_components, data_type);
shader.AdditionalImplementation()
<< "fn mm_readB(batch: i32, row: i32, colIn: i32 "
<< (batch_dims
? ", batch_indices: batch_dims_indices_t"
: "")
<< ") -> " << type_string << " {\n"
<< " var value = " << type_string << "(0);\n"
<< " let col = colIn * " << b_components << ";\n";
// Name of the local indices variable in the emitted text:
// "<input_name>_indices", declared with type "<input_name>_indices_t".
const std::string input_indices = input_name + "_indices";
shader.AdditionalImplementation() << " var " << input_indices << ": " << input_name << "_indices_t" << ";\n";

// Bounds check for the "b" input, written against fixed uniform names; the
// two extents swap places depending on transB.
if (transB) {
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n";
} else {
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n";
// Batch-index conversion for the parameterized input, emitted only when
// batch_dims is non-null (the ternary fallback to 0 is again unreachable here).
if (batch_dims) {
shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices(input_name, input, input.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices ") << "\n";
}

// NOTE(review): in these mm_readB lines the batch-index conversion is streamed
// without an `if (batch_dims)` guard, unlike the mm_readA lines above; whether
// the helper yields empty text for a rank of 0 is not visible here.
shader.AdditionalImplementation() << " var b_indices: b_indices_t;\n"
<< ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, "batch_indices")
<< " " << b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n"
<< " " << b.IndicesSet("b_indices", b.Rank() - 1, "u32(colIn)") << "\n"
<< " value = " << b.GetByIndices("b_indices") << ";\n"
<< " }\n"
<< " return value;\n"
// Parameterized tail: set the indices at positions Rank() - 2 and Rank() - 1
// from row and colIn, load the value, close the bounds check and return.
shader.AdditionalImplementation() << input.IndicesSet(input_indices, input.Rank() - 2, "u32(row)") << "\n"
<< input.IndicesSet(input_indices, input.Rank() - 1, "u32(colIn)") << "\n"
<< " value = " << input.GetByIndices(input_indices) << ";\n"
<< " }\n"
<< " return value;\n"
<< "}\n\n";
}

// Convenience overload that emits both matrix read functions.
//
// Forwards to the single-input overload once per operand:
//   - "mm_readA" for `a` (identifier prefix "a"), bounded by
//     uniforms.dim_a_outer x uniforms.dim_inner, swapped when transA is set;
//   - "mm_readB" for `b` (identifier prefix "b"), bounded by
//     uniforms.dim_inner x uniforms.dim_b_outer, swapped when transB is set.
// `batch_dims` is passed through unchanged and may be null.
void MatMulReadFnSource(ShaderHelper& shader,
                        const ShaderVariableHelper& a,
                        const ShaderVariableHelper& b,
                        const ShaderIndicesHelper* batch_dims,
                        bool transA,
                        bool transB) {
  // Uniform expressions spliced verbatim into the emitted bounds checks.
  constexpr std::string_view kDimAOuter = "uniforms.dim_a_outer";
  constexpr std::string_view kDimInner = "uniforms.dim_inner";
  constexpr std::string_view kDimBOuter = "uniforms.dim_b_outer";

  // A operand: rows = dim_a_outer, components per row = dim_inner.
  MatMulReadFnSource(shader, "mm_readA", a, "a", batch_dims, kDimAOuter, kDimInner, transA);

  // B operand: rows = dim_inner, components per row = dim_b_outer.
  MatMulReadFnSource(shader, "mm_readB", b, "b", batch_dims, kDimInner, kDimBOuter, transB);
}

void MatMulWriteFnSource(ShaderHelper& shader,
const ShaderVariableHelper& output,
const ShaderVariableHelper* bias,
Expand Down
Loading