Skip to content

Commit

Permalink
webgpu: Expand DepthwiseConv2DVec4Program to support any stride (#6820)
Browse files Browse the repository at this point in the history
* webgpu: Optimize stride2x2

* Support any workPerThread

* Move some checkings out of the loop

* Add tests

* Change constraint to stride <= 2

* Use fma builtin

* Address comments

Co-authored-by: Linchenn <[email protected]>
  • Loading branch information
qjia7 and Linchenn authored Sep 15, 2022
1 parent 32d2db4 commit 7423ee7
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 28 deletions.
56 changes: 31 additions & 25 deletions tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram {
variableNames = ['x', 'W'];
uniforms = 'pad : vec2<i32>, inDims : vec2<i32>,';
workGroupSize: [number, number, number] = [4, 4, 4];
workPerThread = 4;
convInfo: backend_util.Conv2DInfo;
addBias: boolean;
activation: backend_util.Activation;
Expand All @@ -40,7 +41,8 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram {
this.outputShape = convInfo.outShape;
this.dispatchLayout = {x: [3], y: [2], z: [0, 1]};
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workGroupSize, [4, 4, 1]);
this.dispatchLayout, this.outputShape, this.workGroupSize,
[4, this.workPerThread, 1]);

util.assert(
convInfo.dataFormat === 'channelsLast',
Expand All @@ -58,57 +60,61 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram {
this.activation = activation;
this.hasPreluActivation = hasPreluActivation;

this.shaderKey = `depthwiseVec4_${activation}_${
this.convInfo.filterHeight}_${this.convInfo.filterWidth}`;
this.shaderKey =
`depthwiseVec4_${activation}_${this.convInfo.filterHeight}_${
this.convInfo.filterWidth}_${this.convInfo.strideHeight}_${
this.convInfo.strideWidth}_${this.workPerThread}`;
}

getUserCode(): string {
// Here 4 is the work per thread in X dimension.
const xNumber = 4 + this.convInfo.filterWidth - 1;
const xNumber = (this.workPerThread - 1) * this.convInfo.strideWidth +
this.convInfo.filterWidth;

const userCode = `
${activationFnSnippet(this.activation, this.hasPreluActivation, true, 4)}
fn readX(batch : i32, row : i32, col : i32, channel : i32) -> vec4<f32> {
var value = vec4<f32>(0.0);
if (row >=0 && row < uniforms.inDims[0] && col >=0 && col < uniforms.inDims[1])
{
if (col >=0 && col < uniforms.inDims[1]) {
value = getX(batch, row, col, channel);
}
return value;
}
const strideHeight = ${this.convInfo.strideHeight};
const strideWidth = ${this.convInfo.strideWidth};
${getWorkGroupSizeString()}
fn _start(@builtin(global_invocation_id) globalId: vec3<u32>) {
let batch = i32(globalId.z) / uniforms.outShape[1];
let r = i32(globalId.z) % uniforms.outShape[1];
let c = i32(globalId.y) * 4;
let c = i32(globalId.y) * ${this.workPerThread};
let d1 = i32(globalId.x) * 4;
let xRCCorner = vec2<i32>(r, c) - uniforms.pad;
let xRCCorner = vec2<i32>(r, c) * vec2<i32>(strideHeight, strideWidth) - uniforms.pad;
let xRCorner = xRCCorner.x;
let xCCorner = xRCCorner.y;
var xVals : array<vec4<f32>, ${xNumber}>;
var dotProd : array<vec4<f32>, 4>;
dotProd[0] = vec4<f32>(0.0);
dotProd[1] = vec4<f32>(0.0);
dotProd[2] = vec4<f32>(0.0);
dotProd[3] = vec4<f32>(0.0);
var dotProd : array<vec4<f32>, ${this.workPerThread}>;
for (var i = 0; i < ${this.workPerThread}; i++) {
dotProd[i] = vec4<f32>(0.0);
}
// Use constant instead of uniform can give better performance.
for (var wR = 0; wR < ${this.convInfo.filterHeight}; wR = wR + 1) {
let xR = xRCorner + wR;
for (var i = 0; i < ${xNumber}; i++)
{
xVals[i] = readX(batch, xR, xCCorner + i, d1);
}
for (var wC = 0; wC < ${this.convInfo.filterWidth}; wC = wC + 1) {
let wValue = getW(wR, wC, d1, 0);
dotProd[0] = dotProd[0] + xVals[0 + wC] * wValue;
dotProd[1] = dotProd[1] + xVals[1 + wC] * wValue;
dotProd[2] = dotProd[2] + xVals[2 + wC] * wValue;
dotProd[3] = dotProd[3] + xVals[3 + wC] * wValue;
if (xR >=0 && xR < uniforms.inDims[0]) {
for (var i = 0; i < ${xNumber}; i++) {
xVals[i] = readX(batch, xR, xCCorner + i, d1);
}
for (var wC = 0; wC < ${this.convInfo.filterWidth}; wC = wC + 1) {
let wValue = getW(wR, wC, d1, 0);
for (var i = 0; i < ${this.workPerThread}; i++) {
dotProd[i] = fma(xVals[i * strideWidth + wC], wValue, dotProd[i]);
}
}
}
}
for (var i = 0; i < 4; i = i + 1) {
for (var i = 0; i < ${this.workPerThread}; i = i + 1) {
let coords = vec4<i32>(batch, r, c + i, d1);
if (coordsInBounds4D(coords, uniforms.outShape)) {
var value = dotProd[i];
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/kernels/DepthwiseConv2dNative.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ export function depthwiseConv2dNative(args: {
convInfo.outShape, convInfo.filterHeight, convInfo.filterWidth);
} else if (
isChannelsLast && convInfo.inHeight > 4 && convInfo.inWidth > 4 &&
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
convInfo.strideWidth <= 2 &&
convInfo.inChannels === convInfo.outChannels &&
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
convInfo.inChannels % 4 === 0) {
Expand Down
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/kernels/FusedDepthwiseConv2D.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export function fusedDepthwiseConv2D(args: {

let program: DepthwiseConv2DProgram|DepthwiseConv2DVec4Program;
if (convInfo.inHeight > 4 && convInfo.inWidth > 4 &&
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
convInfo.strideWidth <= 2 &&
convInfo.inChannels === convInfo.outChannels &&
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
convInfo.inChannels % 4 === 0) {
Expand Down
153 changes: 152 additions & 1 deletion tfjs-core/src/ops/depthwise_conv2d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ describeWithFlags('depthwiseConv2D', ALL_ENVS, () => {
expectArraysClose(await result.data(), await expectedResult.data());
});

it('input=1x5x5x1,f=3,s=1,d=2,p=same,chMul=1', async () => {
it('input=1x5x5x1,f=3,s=1,d=2,p=valid,chMul=1', async () => {
const fSize = 3;
const pad = 'valid';
const stride = 1;
Expand Down Expand Up @@ -370,6 +370,157 @@ describeWithFlags('depthwiseConv2D', ALL_ENVS, () => {
expectArraysClose(await result.data(), expected);
});

it('input=1x5x5x4,f=3,s=1,d=1,p=same,chMul=1', async () => {
const fSize = 3;
const pad = 'same';
const stride = 1;
const chMul = 1;
const inDepth = 4;

const x = tf.tensor4d(
[
0.149194, 0.089009, 0.654891, 0.083324, 0.537043, 0.644331,
0.563037, 0.211859, 0.633501, 0.186427, 0.777034, 0.50001,
0.607341, 0.95303, 0.696479, 0.050387, 0.62045, 0.728049,
0.028043, 0.437009, 0.712881, 0.741935, 0.974474, 0.621102,
0.171411, 0.675707, 0.758567, 0.413529, 0.963967, 0.217291,
0.101335, 0.804231, 0.329673, 0.924503, 0.728742, 0.180217,
0.210459, 0.133869, 0.650827, 0.047613, 0.554795, 0.653365,
0.442196, 0.261945, 0.0528113, 0.656698, 0.127345, 0.610039,
0.169131, 0.458647, 0.0988288, 0.966109, 0.0421747, 0.82035,
0.274711, 0.359377, 0.512113, 0.689682, 0.941571, 0.31961,
0.743826, 0.858147, 0.984766, 0.926973, 0.579597, 0.444104,
0.505969, 0.241437, 0.937999, 0.0957074, 0.773611, 0.46023,
0.469379, 0.363789, 0.269745, 0.486136, 0.894215, 0.794299,
0.724615, 0.261945, 0.0528113, 0.656698, 0.127345, 0.610039,
0.169131, 0.458647, 0.0988288, 0.966109, 0.0421747, 0.82035,
0.274711, 0.359377, 0.512113, 0.689682, 0.941571, 0.31961,
0.743826, 0.858147, 0.984766, 0.926973
],
[1, 5, 5, inDepth]);
const w = tf.tensor4d(
[
0.6511372, 0.8699447, 0.6511372, 0.8699447, 0.267792, 0.9981787,
0.267792, 0.9981787, 0.4913572, 0.3321196, 0.4913572, 0.3321196,
0.5286497, 0.4241803, 0.5286497, 0.4241803, 0.0175446, 0.8365464,
0.0175446, 0.8365464, 0.1768399, 0.2874831, 0.1768399, 0.2874831,
0.0933998, 0.5764548, 0.0933998, 0.5764548, 0.0661623, 0.8850273,
0.0661623, 0.8850273, 0.8700929, 0.205422, 0.8700929, 0.205422
],
[fSize, fSize, inDepth, chMul],
);

const result = tf.depthwiseConv2d(x, w, stride, pad);
expect(result.shape).toEqual([1, 5, 5, 4]);
const expected = [
0.29389750957489014, 1.055132269859314, 0.8355544209480286,
0.7652503848075867, 1.116986632347107, 1.7007107734680176,
0.7228718996047974, 1.2455471754074097, 0.7690584063529968,
1.4749835729599, 1.1460752487182617, 1.5098011493682861,
0.7502411007881165, 2.056602716445923, 1.0519171953201294,
1.012758731842041, 0.37667199969291687, 1.6647151708602905,
0.4798099994659424, 0.532977283000946, 0.4293096363544464,
1.8309053182601929, 0.7433272004127502, 1.1491419076919556,
1.3050479888916016, 2.7769954204559326, 1.6411027908325195,
2.1799824237823486, 1.0364032983779907, 2.7503039836883545,
1.7060394287109375, 2.880652904510498, 1.8967751264572144,
3.3914175033569336, 1.734355092048645, 2.076633930206299,
0.7774094939231873, 3.1432321071624756, 0.9456352591514587,
1.0863502025604248, 0.8477171659469604, 2.5510711669921875,
1.169355869293213, 2.0218098163604736, 2.23183274269104,
3.257829189300537, 1.939490556716919, 2.96195650100708,
1.0946838855743408, 2.4252827167510986, 1.329919695854187,
3.0390005111694336, 1.8967963457107544, 2.775693416595459,
1.5250799655914307, 2.4470155239105225, 0.40530526638031006,
2.775503158569336, 0.8836789727210999, 1.1361782550811768,
0.4407186806201935, 2.3912413120269775, 0.38215696811676025,
2.047299861907959, 1.080580234527588, 3.09224534034729,
1.2943278551101685, 3.1656715869903564, 0.9704407453536987,
2.8066811561584473, 1.419780969619751, 3.1822099685668945,
1.720312237739563, 3.279745578765869, 2.0871992111206055,
2.6629819869995117, 0.5254714488983154, 3.3779194355010986,
0.73943030834198, 2.0616414546966553, 0.5148154497146606,
1.6852912902832031, 0.5320349931716919, 1.7935365438461304,
1.1387810707092285, 2.119696617126465, 1.2744661569595337,
2.3705403804779053, 1.0399315357208252, 1.6817822456359863,
0.8927359580993652, 1.6332063674926758, 1.3386595249176025,
1.8818190097808838, 1.267898440361023, 1.6589205265045166,
0.8288722038269043, 2.119757890701294, 0.8847255706787109,
1.5954076051712036
];
expectArraysClose(await result.data(), expected);
});

it('input=1x5x5x4,f=5,s=2,d=1,p=same,chMul=1', async () => {
const fSize = 5;
const pad = 'same';
const stride = 2;
const chMul = 1;
const inDepth = 4;

const x = tf.tensor4d(
[
0.149194, 0.089009, 0.654891, 0.083324, 0.537043, 0.644331,
0.563037, 0.211859, 0.633501, 0.186427, 0.777034, 0.50001,
0.607341, 0.95303, 0.696479, 0.050387, 0.62045, 0.728049,
0.028043, 0.437009, 0.712881, 0.741935, 0.974474, 0.621102,
0.171411, 0.675707, 0.758567, 0.413529, 0.963967, 0.217291,
0.101335, 0.804231, 0.329673, 0.924503, 0.728742, 0.180217,
0.210459, 0.133869, 0.650827, 0.047613, 0.554795, 0.653365,
0.442196, 0.261945, 0.0528113, 0.656698, 0.127345, 0.610039,
0.169131, 0.458647, 0.0988288, 0.966109, 0.0421747, 0.82035,
0.274711, 0.359377, 0.512113, 0.689682, 0.941571, 0.31961,
0.743826, 0.858147, 0.984766, 0.926973, 0.579597, 0.444104,
0.505969, 0.241437, 0.937999, 0.0957074, 0.773611, 0.46023,
0.469379, 0.363789, 0.269745, 0.486136, 0.894215, 0.794299,
0.724615, 0.261945, 0.0528113, 0.656698, 0.127345, 0.610039,
0.169131, 0.458647, 0.0988288, 0.966109, 0.0421747, 0.82035,
0.274711, 0.359377, 0.512113, 0.689682, 0.941571, 0.31961,
0.743826, 0.858147, 0.984766, 0.926973
],
[1, 5, 5, inDepth]);
const w = tf.tensor4d(
[
0.6511372, 0.8699447, 0.6511372, 0.8699447, 0.267792, 0.9981787,
0.267792, 0.9981787, 0.4913572, 0.3321196, 0.4913572, 0.3321196,
0.5286497, 0.4241803, 0.5286497, 0.4241803, 0.0175446, 0.8365464,
0.0175446, 0.8365464, 0.1768399, 0.2874831, 0.1768399, 0.2874831,
0.0933998, 0.5764548, 0.0933998, 0.5764548, 0.0661623, 0.8850273,
0.0661623, 0.8850273, 0.8700929, 0.205422, 0.8700929, 0.205422,
0.149194, 0.089009, 0.654891, 0.083324, 0.537043, 0.644331,
0.563037, 0.211859, 0.633501, 0.186427, 0.777034, 0.50001,
0.607341, 0.95303, 0.696479, 0.050387, 0.62045, 0.728049,
0.028043, 0.437009, 0.712881, 0.741935, 0.974474, 0.621102,
0.171411, 0.125386, 0.975199, 0.640437, 0.281895, 0.990968,
0.347208, 0.889702, 0.180695, 0.691992, 0.347154, 0.386692,
0.327191, 0.483784, 0.591807, 0.24263, 0.95182, 0.174353,
0.592136, 0.623469, 0.988244, 0.660731, 0.946534, 0.0801365,
0.864889, 0.874602, 0.240347, 0.906352, 0.478657, 0.825918,
0.380769, 0.184705, 0.238241, 0.201907, 0.294087, 0.181165,
0.191303, 0.7225, 0.430064, 0.900622
],
[fSize, fSize, inDepth, chMul],
);

const result = tf.depthwiseConv2d(x, w, stride, pad);
expect(result.shape).toEqual([1, 3, 3, 4]);
const expected = [
2.2883458137512207, 2.5740344524383545, 2.3246560096740723,
2.27826189994812, 3.0600292682647705, 5.021538734436035,
4.432307720184326, 2.6976213455200195, 1.8467353582382202,
3.617821216583252, 2.0940940380096436, 1.3091316223144531,
2.4892354011535645, 4.767732620239258, 3.126866579055786,
3.4326541423797607, 4.181705474853516, 8.082467079162598,
6.922453880310059, 5.922790050506592, 2.819075345993042,
5.9510369300842285, 3.7211103439331055, 2.7263708114624023,
1.164026141166687, 3.3068809509277344, 1.6575196981430054,
2.738445997238159, 2.288442850112915, 5.463253021240234,
2.840029239654541, 3.8579823970794678, 1.440760612487793,
3.862100839614868, 2.3826799392700195, 2.2323575019836426
];
expectArraysClose(await result.data(), expected);
});

it('input=1x5x5x1,f=3,s=1,d=2,p=explicit,chMul=1', async () => {
const fSize = 3;
const pad =
Expand Down

0 comments on commit 7423ee7

Please sign in to comment.