From 7423ee78f950e7c428a919e3fe8910fd5f884e98 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 15 Sep 2022 13:59:14 +0800 Subject: [PATCH] webgpu: Expand DepthwiseConv2DVec4Program to support any stride (#6820) * webgpu: Optimize stride2x2 * Support any workPerThread * Move some checkings out of the loop * Add tests * Change constraint to stride <= 2 * Use fma builtin * Address comments Co-authored-by: Linchenn <40653845+Linchenn@users.noreply.github.com> --- .../src/depthwise_conv2d_vec4_webgpu.ts | 56 ++++--- .../src/kernels/DepthwiseConv2dNative.ts | 2 +- .../src/kernels/FusedDepthwiseConv2D.ts | 2 +- tfjs-core/src/ops/depthwise_conv2d_test.ts | 153 +++++++++++++++++- 4 files changed, 185 insertions(+), 28 deletions(-) diff --git a/tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts b/tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts index 8f1962e0489..744ba59a052 100644 --- a/tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts +++ b/tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts @@ -28,6 +28,7 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram { variableNames = ['x', 'W']; uniforms = 'pad : vec2, inDims : vec2,'; workGroupSize: [number, number, number] = [4, 4, 4]; + workPerThread = 4; convInfo: backend_util.Conv2DInfo; addBias: boolean; activation: backend_util.Activation; @@ -40,7 +41,8 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram { this.outputShape = convInfo.outShape; this.dispatchLayout = {x: [3], y: [2], z: [0, 1]}; this.dispatch = computeDispatch( - this.dispatchLayout, this.outputShape, this.workGroupSize, [4, 4, 1]); + this.dispatchLayout, this.outputShape, this.workGroupSize, + [4, this.workPerThread, 1]); util.assert( convInfo.dataFormat === 'channelsLast', @@ -58,57 +60,61 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram { this.activation = activation; this.hasPreluActivation = hasPreluActivation; - this.shaderKey = `depthwiseVec4_${activation}_${ - this.convInfo.filterHeight}_${this.convInfo.filterWidth}`; + this.shaderKey = + `depthwiseVec4_${activation}_${this.convInfo.filterHeight}_${ + this.convInfo.filterWidth}_${this.convInfo.strideHeight}_${ + this.convInfo.strideWidth}_${this.workPerThread}`; } getUserCode(): string { - // Here 4 is the work per thread in X dimension. - const xNumber = 4 + this.convInfo.filterWidth - 1; + const xNumber = (this.workPerThread - 1) * this.convInfo.strideWidth + + this.convInfo.filterWidth; + const userCode = ` ${activationFnSnippet(this.activation, this.hasPreluActivation, true, 4)} fn readX(batch : i32, row : i32, col : i32, channel : i32) -> vec4 { var value = vec4(0.0); - if (row >=0 && row < uniforms.inDims[0] && col >=0 && col < uniforms.inDims[1]) - { + if (col >=0 && col < uniforms.inDims[1]) { value = getX(batch, row, col, channel); } return value; } + + const strideHeight = ${this.convInfo.strideHeight}; + const strideWidth = ${this.convInfo.strideWidth}; ${getWorkGroupSizeString()} fn _start(@builtin(global_invocation_id) globalId: vec3) { let batch = i32(globalId.z) / uniforms.outShape[1]; let r = i32(globalId.z) % uniforms.outShape[1]; - let c = i32(globalId.y) * 4; + let c = i32(globalId.y) * ${this.workPerThread}; let d1 = i32(globalId.x) * 4; - let xRCCorner = vec2(r, c) - uniforms.pad; + let xRCCorner = vec2(r, c) * vec2(strideHeight, strideWidth) - uniforms.pad; let xRCorner = xRCCorner.x; let xCCorner = xRCCorner.y; var xVals : array, ${xNumber}>; - var dotProd : array, 4>; - dotProd[0] = vec4(0.0); - dotProd[1] = vec4(0.0); - dotProd[2] = vec4(0.0); - dotProd[3] = vec4(0.0); + var dotProd : array, ${this.workPerThread}>; + for (var i = 0; i < ${this.workPerThread}; i++) { + dotProd[i] = vec4(0.0); + } // Use constant instead of uniform can give better performance. for (var wR = 0; wR < ${this.convInfo.filterHeight}; wR = wR + 1) { let xR = xRCorner + wR; - for (var i = 0; i < ${xNumber}; i++) - { - xVals[i] = readX(batch, xR, xCCorner + i, d1); - } - for (var wC = 0; wC < ${this.convInfo.filterWidth}; wC = wC + 1) { - let wValue = getW(wR, wC, d1, 0); - dotProd[0] = dotProd[0] + xVals[0 + wC] * wValue; - dotProd[1] = dotProd[1] + xVals[1 + wC] * wValue; - dotProd[2] = dotProd[2] + xVals[2 + wC] * wValue; - dotProd[3] = dotProd[3] + xVals[3 + wC] * wValue; + if (xR >=0 && xR < uniforms.inDims[0]) { + for (var i = 0; i < ${xNumber}; i++) { + xVals[i] = readX(batch, xR, xCCorner + i, d1); + } + for (var wC = 0; wC < ${this.convInfo.filterWidth}; wC = wC + 1) { + let wValue = getW(wR, wC, d1, 0); + for (var i = 0; i < ${this.workPerThread}; i++) { + dotProd[i] = fma(xVals[i * strideWidth + wC], wValue, dotProd[i]); + } + } } } - for (var i = 0; i < 4; i = i + 1) { + for (var i = 0; i < ${this.workPerThread}; i = i + 1) { let coords = vec4(batch, r, c + i, d1); if (coordsInBounds4D(coords, uniforms.outShape)) { var value = dotProd[i]; diff --git a/tfjs-backend-webgpu/src/kernels/DepthwiseConv2dNative.ts b/tfjs-backend-webgpu/src/kernels/DepthwiseConv2dNative.ts index 2dd2bdad23f..e952351560c 100644 --- a/tfjs-backend-webgpu/src/kernels/DepthwiseConv2dNative.ts +++ b/tfjs-backend-webgpu/src/kernels/DepthwiseConv2dNative.ts @@ -56,7 +56,7 @@ export function depthwiseConv2dNative(args: { convInfo.outShape, convInfo.filterHeight, convInfo.filterWidth); } else if ( isChannelsLast && convInfo.inHeight > 4 && convInfo.inWidth > 4 && - convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && + convInfo.strideWidth <= 2 && convInfo.inChannels === convInfo.outChannels && convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && convInfo.inChannels % 4 === 0) { diff --git a/tfjs-backend-webgpu/src/kernels/FusedDepthwiseConv2D.ts b/tfjs-backend-webgpu/src/kernels/FusedDepthwiseConv2D.ts index eb6b0093551..a7545542e19 100644 --- a/tfjs-backend-webgpu/src/kernels/FusedDepthwiseConv2D.ts +++ b/tfjs-backend-webgpu/src/kernels/FusedDepthwiseConv2D.ts @@ -65,7 +65,7 @@ export function fusedDepthwiseConv2D(args: { let program: DepthwiseConv2DProgram|DepthwiseConv2DVec4Program; if (convInfo.inHeight > 4 && convInfo.inWidth > 4 && - convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && + convInfo.strideWidth <= 2 && convInfo.inChannels === convInfo.outChannels && convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && convInfo.inChannels % 4 === 0) { diff --git a/tfjs-core/src/ops/depthwise_conv2d_test.ts b/tfjs-core/src/ops/depthwise_conv2d_test.ts index 7ebf5c33692..e6ccc7b0e2a 100644 --- a/tfjs-core/src/ops/depthwise_conv2d_test.ts +++ b/tfjs-core/src/ops/depthwise_conv2d_test.ts @@ -338,7 +338,7 @@ describeWithFlags('depthwiseConv2D', ALL_ENVS, () => { expectArraysClose(await result.data(), await expectedResult.data()); }); - it('input=1x5x5x1,f=3,s=1,d=2,p=same,chMul=1', async () => { + it('input=1x5x5x1,f=3,s=1,d=2,p=valid,chMul=1', async () => { const fSize = 3; const pad = 'valid'; const stride = 1; @@ -370,6 +370,157 @@ describeWithFlags('depthwiseConv2D', ALL_ENVS, () => { expectArraysClose(await result.data(), expected); }); + it('input=1x5x5x4,f=3,s=1,d=1,p=same,chMul=1', async () => { + const fSize = 3; + const pad = 'same'; + const stride = 1; + const chMul = 1; + const inDepth = 4; + + const x = tf.tensor4d( + [ + 0.149194, 0.089009, 0.654891, 0.083324, 0.537043, 0.644331, + 0.563037, 0.211859, 0.633501, 0.186427, 0.777034, 0.50001, + 0.607341, 0.95303, 0.696479, 0.050387, 0.62045, 0.728049, + 0.028043, 0.437009, 0.712881, 0.741935, 0.974474, 0.621102, + 0.171411, 0.675707, 0.758567, 0.413529, 0.963967, 0.217291, + 0.101335, 0.804231, 0.329673, 0.924503, 0.728742, 0.180217, + 0.210459, 0.133869, 0.650827, 0.047613, 0.554795, 0.653365, + 0.442196, 0.261945, 0.0528113, 0.656698, 0.127345, 0.610039, + 0.169131, 0.458647, 0.0988288, 0.966109, 0.0421747, 0.82035, + 0.274711, 0.359377, 0.512113, 0.689682, 0.941571, 0.31961, + 0.743826, 0.858147, 0.984766, 0.926973, 0.579597, 0.444104, + 0.505969, 0.241437, 0.937999, 0.0957074, 0.773611, 0.46023, + 0.469379, 0.363789, 0.269745, 0.486136, 0.894215, 0.794299, + 0.724615, 0.261945, 0.0528113, 0.656698, 0.127345, 0.610039, + 0.169131, 0.458647, 0.0988288, 0.966109, 0.0421747, 0.82035, + 0.274711, 0.359377, 0.512113, 0.689682, 0.941571, 0.31961, + 0.743826, 0.858147, 0.984766, 0.926973 + ], + [1, 5, 5, inDepth]); + const w = tf.tensor4d( + [ + 0.6511372, 0.8699447, 0.6511372, 0.8699447, 0.267792, 0.9981787, + 0.267792, 0.9981787, 0.4913572, 0.3321196, 0.4913572, 0.3321196, + 0.5286497, 0.4241803, 0.5286497, 0.4241803, 0.0175446, 0.8365464, + 0.0175446, 0.8365464, 0.1768399, 0.2874831, 0.1768399, 0.2874831, + 0.0933998, 0.5764548, 0.0933998, 0.5764548, 0.0661623, 0.8850273, + 0.0661623, 0.8850273, 0.8700929, 0.205422, 0.8700929, 0.205422 + ], + [fSize, fSize, inDepth, chMul], + ); + + const result = tf.depthwiseConv2d(x, w, stride, pad); + expect(result.shape).toEqual([1, 5, 5, 4]); + const expected = [ + 0.29389750957489014, 1.055132269859314, 0.8355544209480286, + 0.7652503848075867, 1.116986632347107, 1.7007107734680176, + 0.7228718996047974, 1.2455471754074097, 0.7690584063529968, + 1.4749835729599, 1.1460752487182617, 1.5098011493682861, + 0.7502411007881165, 2.056602716445923, 1.0519171953201294, + 1.012758731842041, 0.37667199969291687, 1.6647151708602905, + 0.4798099994659424, 0.532977283000946, 0.4293096363544464, + 1.8309053182601929, 0.7433272004127502, 1.1491419076919556, + 1.3050479888916016, 2.7769954204559326, 1.6411027908325195, + 2.1799824237823486, 1.0364032983779907, 2.7503039836883545, + 1.7060394287109375, 2.880652904510498, 1.8967751264572144, + 3.3914175033569336, 1.734355092048645, 2.076633930206299, + 0.7774094939231873, 3.1432321071624756, 0.9456352591514587, + 1.0863502025604248, 0.8477171659469604, 2.5510711669921875, + 1.169355869293213, 2.0218098163604736, 2.23183274269104, + 3.257829189300537, 1.939490556716919, 2.96195650100708, + 1.0946838855743408, 2.4252827167510986, 1.329919695854187, + 3.0390005111694336, 1.8967963457107544, 2.775693416595459, + 1.5250799655914307, 2.4470155239105225, 0.40530526638031006, + 2.775503158569336, 0.8836789727210999, 1.1361782550811768, + 0.4407186806201935, 2.3912413120269775, 0.38215696811676025, + 2.047299861907959, 1.080580234527588, 3.09224534034729, + 1.2943278551101685, 3.1656715869903564, 0.9704407453536987, + 2.8066811561584473, 1.419780969619751, 3.1822099685668945, + 1.720312237739563, 3.279745578765869, 2.0871992111206055, + 2.6629819869995117, 0.5254714488983154, 3.3779194355010986, + 0.73943030834198, 2.0616414546966553, 0.5148154497146606, + 1.6852912902832031, 0.5320349931716919, 1.7935365438461304, + 1.1387810707092285, 2.119696617126465, 1.2744661569595337, + 2.3705403804779053, 1.0399315357208252, 1.6817822456359863, + 0.8927359580993652, 1.6332063674926758, 1.3386595249176025, + 1.8818190097808838, 1.267898440361023, 1.6589205265045166, + 0.8288722038269043, 2.119757890701294, 0.8847255706787109, + 1.5954076051712036 + ]; + expectArraysClose(await result.data(), expected); + }); + + it('input=1x5x5x4,f=5,s=2,d=1,p=same,chMul=1', async () => { + const fSize = 5; + const pad = 'same'; + const stride = 2; + const chMul = 1; + const inDepth = 4; + + const x = tf.tensor4d( + [ + 0.149194, 0.089009, 0.654891, 0.083324, 0.537043, 0.644331, + 0.563037, 0.211859, 0.633501, 0.186427, 0.777034, 0.50001, + 0.607341, 0.95303, 0.696479, 0.050387, 0.62045, 0.728049, + 0.028043, 0.437009, 0.712881, 0.741935, 0.974474, 0.621102, + 0.171411, 0.675707, 0.758567, 0.413529, 0.963967, 0.217291, + 0.101335, 0.804231, 0.329673, 0.924503, 0.728742, 0.180217, + 0.210459, 0.133869, 0.650827, 0.047613, 0.554795, 0.653365, + 0.442196, 0.261945, 0.0528113, 0.656698, 0.127345, 0.610039, + 0.169131, 0.458647, 0.0988288, 0.966109, 0.0421747, 0.82035, + 0.274711, 0.359377, 0.512113, 0.689682, 0.941571, 0.31961, + 0.743826, 0.858147, 0.984766, 0.926973, 0.579597, 0.444104, + 0.505969, 0.241437, 0.937999, 0.0957074, 0.773611, 0.46023, + 0.469379, 0.363789, 0.269745, 0.486136, 0.894215, 0.794299, + 0.724615, 0.261945, 0.0528113, 0.656698, 0.127345, 0.610039, + 0.169131, 0.458647, 0.0988288, 0.966109, 0.0421747, 0.82035, + 0.274711, 0.359377, 0.512113, 0.689682, 0.941571, 0.31961, + 0.743826, 0.858147, 0.984766, 0.926973 + ], + [1, 5, 5, inDepth]); + const w = tf.tensor4d( + [ + 0.6511372, 0.8699447, 0.6511372, 0.8699447, 0.267792, 0.9981787, + 0.267792, 0.9981787, 0.4913572, 0.3321196, 0.4913572, 0.3321196, + 0.5286497, 0.4241803, 0.5286497, 0.4241803, 0.0175446, 0.8365464, + 0.0175446, 0.8365464, 0.1768399, 0.2874831, 0.1768399, 0.2874831, + 0.0933998, 0.5764548, 0.0933998, 0.5764548, 0.0661623, 0.8850273, + 0.0661623, 0.8850273, 0.8700929, 0.205422, 0.8700929, 0.205422, + 0.149194, 0.089009, 0.654891, 0.083324, 0.537043, 0.644331, + 0.563037, 0.211859, 0.633501, 0.186427, 0.777034, 0.50001, + 0.607341, 0.95303, 0.696479, 0.050387, 0.62045, 0.728049, + 0.028043, 0.437009, 0.712881, 0.741935, 0.974474, 0.621102, + 0.171411, 0.125386, 0.975199, 0.640437, 0.281895, 0.990968, + 0.347208, 0.889702, 0.180695, 0.691992, 0.347154, 0.386692, + 0.327191, 0.483784, 0.591807, 0.24263, 0.95182, 0.174353, + 0.592136, 0.623469, 0.988244, 0.660731, 0.946534, 0.0801365, + 0.864889, 0.874602, 0.240347, 0.906352, 0.478657, 0.825918, + 0.380769, 0.184705, 0.238241, 0.201907, 0.294087, 0.181165, + 0.191303, 0.7225, 0.430064, 0.900622 + ], + [fSize, fSize, inDepth, chMul], + ); + + const result = tf.depthwiseConv2d(x, w, stride, pad); + expect(result.shape).toEqual([1, 3, 3, 4]); + const expected = [ + 2.2883458137512207, 2.5740344524383545, 2.3246560096740723, + 2.27826189994812, 3.0600292682647705, 5.021538734436035, + 4.432307720184326, 2.6976213455200195, 1.8467353582382202, + 3.617821216583252, 2.0940940380096436, 1.3091316223144531, + 2.4892354011535645, 4.767732620239258, 3.126866579055786, + 3.4326541423797607, 4.181705474853516, 8.082467079162598, + 6.922453880310059, 5.922790050506592, 2.819075345993042, + 5.9510369300842285, 3.7211103439331055, 2.7263708114624023, + 1.164026141166687, 3.3068809509277344, 1.6575196981430054, + 2.738445997238159, 2.288442850112915, 5.463253021240234, + 2.840029239654541, 3.8579823970794678, 1.440760612487793, + 3.862100839614868, 2.3826799392700195, 2.2323575019836426 + ]; + expectArraysClose(await result.data(), expected); + }); + it('input=1x5x5x1,f=3,s=1,d=2,p=explicit,chMul=1', async () => { const fSize = 3; const pad =