Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,14 @@ Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {

ORT_ENFORCE(tile_m_ == 16 || tile_m_ == 32, "tile_m must be 16 or 32.");
ORT_ENFORCE(tile_n_ == 64, "tile_n must be 64.");
ORT_ENFORCE(vec_size_ == 1 || vec_size_ == 4, "vec_size must be 4 or 1.");

return WGSL_TEMPLATE_APPLY(shader, "nn/im2col_matmul.wgsl.template",
WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_),
WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_),
WGSL_TEMPLATE_PARAMETER(use_subgroup, use_subgroup_),
WGSL_TEMPLATE_PARAMETER(vec_size, vec_size_),
WGSL_TEMPLATE_VARIABLE(output, output),
WGSL_TEMPLATE_VARIABLE(src, src),
WGSL_TEMPLATE_VARIABLE(weight, weight));
Expand Down Expand Up @@ -145,7 +147,8 @@ Status ApplyIm2ColMatMulProgram(ComputeContext& context,
// The subgroup size must be greater than or equal to `tile_m` for `use_subgroup` to be enabled safely.
// If that condition cannot be verified, the feature must stay disabled.
const bool use_subgroup = false;
Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, use_subgroup};
const uint32_t vec_size = channel_input % 4 == 0 ? 4 : 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about extending it to support 1, 2, or 4? For example: `const uint32_t vec_size = GetMaxComponents(channel_input);`

Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, vec_size, use_subgroup};
im2col_mm_program.SetWorkgroupSize(workgroup_size);

const uint32_t M_tiles = CeilDiv(im2col_m, tile_m);
Expand All @@ -154,10 +157,10 @@ Status ApplyIm2ColMatMulProgram(ComputeContext& context,

im2col_mm_program.AddInput({src,
ProgramTensorMetadataDependency::TypeAndRank,
4});
static_cast<int>(vec_size)});
im2col_mm_program.AddInput({&ohwi_weight,
ProgramTensorMetadataDependency::TypeAndRank,
4});
static_cast<int>(vec_size)});
if (has_bias) {
im2col_mm_program.AddInput({bias,
ProgramTensorMetadataDependency::TypeAndRank});
Expand All @@ -181,7 +184,7 @@ Status ApplyIm2ColMatMulProgram(ComputeContext& context,
{dilations},
{pads},
{strides}});
im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, use_subgroup);
im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, vec_size, use_subgroup);

return context.RunProgram(im2col_mm_program);
}
Expand Down Expand Up @@ -212,12 +215,6 @@ bool CanApplyIm2ColMatMulProgram(ComputeContextBase& context,
return false;
}

// TODO: Support channel input vec1
const uint32_t channel_input = onnxruntime::narrow<uint32_t>(weight_shape[1]);
if (channel_input % 4 != 0) {
return false;
}

return true;
}

Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/im2col_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {
// Constructs the "Im2ColMatMul" program and stores the shader configuration
// (bias presence, tile sizes, vector width, subgroup usage) in the matching members.
Im2ColMatMulProgram(bool has_bias,
uint32_t tile_m,
uint32_t tile_n,
uint32_t vec_size,
bool use_subgroup) : Program("Im2ColMatMul"),
has_bias_(has_bias),
tile_m_(tile_m),
tile_n_(tile_n),
vec_size_(vec_size),
use_subgroup_(use_subgroup) {}

Status GenerateShaderCode(ShaderHelper& shader) const override;
Expand Down Expand Up @@ -71,6 +73,7 @@ class Im2ColMatMulProgram final : public Program<Im2ColMatMulProgram> {

uint32_t tile_m_;
uint32_t tile_n_;
uint32_t vec_size_;
bool use_subgroup_;
};

Expand Down
39 changes: 28 additions & 11 deletions onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,26 @@
#param tile_m
#param tile_n
#param use_subgroup
#param vec_size

#use .getByOffset .setByOffset

// im2col access for src: [N, H_i, W_i, C_i / vec_size] (NHWC, vec_size channels packed per element)
// Conceptual Matrix Shape: N * (H_o * W_o) x (K_h * K_w * C_i / vec_size)
fn load_src(batch : u32, m : u32, k_packed_idx : u32) -> src_value_t {
if (batch >= uniforms.batch || m >= uniforms.im2col_m || k_packed_idx * 4 >= uniforms.im2col_k) {
if (batch >= uniforms.batch || m >= uniforms.im2col_m || k_packed_idx * vec_size >= uniforms.im2col_k) {
return src_value_t();
}

let channel_i_v4 = uniforms.channel_i / 4;
let channel_i_vec = uniforms.channel_i / vec_size;

// 1. Decompose M index (H_o * W_o) into (h_idx, w_idx)
let h_idx = m / uniforms.output_w; // Output H index (H_o)
let w_idx = m % uniforms.output_w; // Output W index (W_o)

// 2. Decompose K index into (k_h, k_w, c_i_v4_idx)
let c_i_v4_idx = k_packed_idx % channel_i_v4;
let k_h_w_idx = k_packed_idx / channel_i_v4;
let c_i_v4_idx = k_packed_idx % channel_i_vec;
let k_h_w_idx = k_packed_idx / channel_i_vec;
let k_h = k_h_w_idx / uniforms.kernel_w; // Kernel Row
let k_w = k_h_w_idx % uniforms.kernel_w; // Kernel Column

Expand All @@ -33,7 +34,7 @@ fn load_src(batch : u32, m : u32, k_packed_idx : u32) -> src_value_t {

// 4. Calculate the coordinate in the original input tensor
let src_h_coord : i32 = i32(src_h_coord_padded) - i32(uniforms.pads.x);
let src_w_coord : i32 = i32(src_w_coord_padded) - i32(uniforms.pads.z);
let src_w_coord : i32 = i32(src_w_coord_padded) - i32(uniforms.pads.y);

// 5. Check for padding/out-of-bounds
if (src_h_coord < 0 || src_h_coord >= i32(uniforms.src_h) ||
Expand All @@ -42,17 +43,17 @@ fn load_src(batch : u32, m : u32, k_packed_idx : u32) -> src_value_t {
}

// 6. Calculate final NHWC/vec4 index
let src_idx = batch * uniforms.src_h * uniforms.src_w * channel_i_v4 +
u32(src_h_coord) * uniforms.src_w * channel_i_v4 +
u32(src_w_coord) * channel_i_v4 +
let src_idx = batch * uniforms.src_h * uniforms.src_w * channel_i_vec +
u32(src_h_coord) * uniforms.src_w * channel_i_vec +
u32(src_w_coord) * channel_i_vec +
c_i_v4_idx;
return src.getByOffset(src_idx);
}

// weight shape: [Co, K_h, K_w, C_i / vec_size] (CoHWCi, vec_size channels packed per element)
fn load_weight(n : u32, k_packed_idx : u32) -> weight_value_t {
if (n < uniforms.im2col_n && k_packed_idx < uniforms.im2col_k / 4) {
let weight_idx = n * uniforms.im2col_k / 4 +
if (n < uniforms.im2col_n && k_packed_idx < uniforms.im2col_k / vec_size) {
let weight_idx = n * uniforms.im2col_k / vec_size +
k_packed_idx;
return weight.getByOffset(weight_idx);
}
Expand Down Expand Up @@ -80,7 +81,7 @@ fn write_output(batch : u32, m : u32, n : u32, value : output_element_t) {

const TILE_M_SIZE : u32 = tile_m;
const TILE_N_SIZE : u32 = tile_n;
const TILE_K_VEC_SIZE : u32 = 4;
const TILE_K_VEC_SIZE : u32 = 16 / vec_size;

var<workgroup> src_tile : array<array<src_value_t, TILE_M_SIZE>, TILE_K_VEC_SIZE>;
var<workgroup> weight_tile : array<array<weight_value_t, TILE_N_SIZE>, TILE_K_VEC_SIZE>;
Expand All @@ -92,20 +93,32 @@ $MAIN {

var results : array<output_element_t, TILE_M_SIZE>;
for (var k_idx = 0u; k_idx < uniforms.K_tiles; k_idx++) {
#if vec_size != 4
for (var src_m = 0u; src_m < TILE_M_SIZE; src_m += 4u) {
let load_src_m = src_m + local_idx / 16;
let load_src_k = local_idx % 16;
#else
for (var src_m = 0u; src_m < TILE_M_SIZE; src_m += 16u) {
// Loads a 16x4 vec of src into the workgroup memory.
let load_src_m = src_m + local_idx / 4;
let load_src_k = local_idx % 4;
#endif

src_tile[load_src_k][load_src_m] = load_src(batch,
m_global_base + load_src_m,
k_idx * TILE_K_VEC_SIZE + load_src_k);
}

#if vec_size != 4
for (var weight_n = 0u; weight_n < TILE_N_SIZE; weight_n += 4u) {
let load_weight_n = weight_n + local_idx / 16;
let load_weight_k = local_idx % 16;
#else
for (var weight_n = 0u; weight_n < TILE_N_SIZE; weight_n += 16u) {
// Loads a 16x4 vec of weight into the workgroup memory.
let load_weight_n = weight_n + local_idx / 4;
let load_weight_k = local_idx % 4;
#endif

weight_tile[load_weight_k][load_weight_n] = load_weight(n_global_base + load_weight_n,
k_idx * TILE_K_VEC_SIZE + load_weight_k);
Expand All @@ -121,7 +134,11 @@ $MAIN {
}
#else
for (var m_idx = 0u; m_idx < TILE_M_SIZE; m_idx++) {
#if vec_size != 4
results[m_idx] += output_element_t(weight_data * src_tile[inner_k_idx][m_idx]);
#else
results[m_idx] += output_element_t(dot(weight_data, src_tile[inner_k_idx][m_idx]));
#endif
}
#endif
}
Expand Down