@@ -176,7 +176,7 @@ VBuffer kernelNCHW_OCHW_repack_O4C4HWi4o4(
176176}
177177
178178VBuffer bufferFromOptionalHostData (
179- c10::optional<float *> data,
179+ c10::optional<const float *> data,
180180 const uint32_t size) {
181181 const auto sizeAligned =
182182 ROUND_UP (size, context ().limits ().minStorageBufferOffsetAlignment );
@@ -202,17 +202,15 @@ uint32_t conv2d_biasBufferSize(uint32_t oc) {
202202void conv2d_depthwise (
203203 VulkanTensor& output,
204204 const VulkanTensor& input,
205- const float * weight,
206- const c10::optional< float *> bias ,
207- const Conv2DParams params,
205+ const VulkanTensor& weight,
206+ const VBuffer& biasBuffer ,
207+ const Conv2DParams& params,
208208 c10::optional<float > output_min,
209209 c10::optional<float > output_max) {
210210 TORCH_INTERNAL_ASSERT (params.G == params.C );
211211 auto osizes = output.sizes ();
212212 TORCH_INTERNAL_ASSERT (osizes[2 ] == params.OH );
213213 TORCH_INTERNAL_ASSERT (osizes[3 ] == params.OW );
214- auto biasBuffer =
215- bufferFromOptionalHostData (bias, conv2d_biasBufferSize (params.OC ));
216214 struct ConstBlock {
217215 int32_t padding[2 ];
218216 int32_t kernelSize[2 ];
@@ -234,9 +232,6 @@ void conv2d_depthwise(
234232 output_max ? *output_max : std::numeric_limits<float >::infinity ()};
235233 VBuffer constBuffer = makeUniformConstBuffer ((void *)&cb, sizeof (cb));
236234
237- VulkanTensor kernel{{params.OC , params.KH , params.KW }};
238- kernel.set_data_from_host (weight);
239-
240235 VkDescriptorSetLayout descriptorSetLayout{};
241236 VkDescriptorPool descriptorPool{};
242237 VkDescriptorSet descriptorSet{};
@@ -256,7 +251,7 @@ void conv2d_depthwise(
256251
257252 output.image ()->bindStorageImage (descriptorSet, 0 );
258253 input.image ()->bindShaderRead (descriptorSet, 1 );
259- kernel .image ()->bindShaderRead (descriptorSet, 2 );
254+ weight .image ()->bindShaderRead (descriptorSet, 2 );
260255 biasBuffer.bind (descriptorSet, 3 );
261256 constBuffer.bind (descriptorSet, 4 );
262257
@@ -269,7 +264,7 @@ void conv2d_depthwise(
269264 auto commandBuffer = computeUnit.commandBuffer ();
270265 output.image ()->addImageMemoryBarrierToGeneral (commandBuffer);
271266 input.image ()->addImageMemoryBarrierToShaderRead (commandBuffer);
272- kernel .image ()->addImageMemoryBarrierToShaderRead (commandBuffer);
267+ weight .image ()->addImageMemoryBarrierToShaderRead (commandBuffer);
273268 computeUnit.dispatchCommandBuffer (
274269 params.OW , params.OH , params.OC_4 , workGroupSize);
275270 computeUnit.endCommandBuffer ();
@@ -279,6 +274,44 @@ void conv2d_depthwise(
279274 vkDestroyDescriptorSetLayout (device, descriptorSetLayout, nullptr );
280275}
281276
277+ void conv2d_depthwise (
278+ VulkanTensor& output,
279+ const VulkanTensor& input,
280+ const VulkanTensor& weight,
281+ const c10::optional<const float *> bias,
282+ const Conv2DParams params,
283+ c10::optional<float > output_min,
284+ c10::optional<float > output_max) {
285+ conv2d_depthwise (
286+ output,
287+ input,
288+ weight,
289+ bufferFromOptionalHostData (bias, conv2d_biasBufferSize (params.OC )),
290+ params,
291+ output_min,
292+ output_max);
293+ }
294+
295+ void conv2d_depthwise (
296+ VulkanTensor& output,
297+ const VulkanTensor& input,
298+ const float * weight,
299+ const c10::optional<const float *> bias,
300+ const Conv2DParams params,
301+ c10::optional<float > output_min,
302+ c10::optional<float > output_max) {
303+ VulkanTensor weightTensor{{params.OC , params.KH , params.KW }};
304+ weightTensor.set_data_from_host (weight);
305+ conv2d_depthwise (
306+ output,
307+ input,
308+ weightTensor,
309+ bufferFromOptionalHostData (bias, conv2d_biasBufferSize (params.OC )),
310+ params,
311+ output_min,
312+ output_max);
313+ }
314+
282315ImageSizes conv2d_prepack_weights_image_sizes (
283316 int64_t OC,
284317 int64_t C,
@@ -463,7 +496,7 @@ void conv2d(
463496 VulkanTensor& output,
464497 const VulkanTensor& input,
465498 const VImage& kernelImage,
466- const c10::optional<float *> bias,
499+ const c10::optional<const float *> bias,
467500 const Conv2DParams& params,
468501 c10::optional<float > output_min,
469502 c10::optional<float > output_max) {
@@ -483,10 +516,22 @@ void conv2d(
483516 VulkanTensor& output,
484517 const VulkanTensor& input,
485518 const VulkanTensor& weight_prepacked,
486- c10::optional<float *> bias,
519+ c10::optional<const float *> bias,
487520 const Conv2DParams params,
488521 c10::optional<float > output_min,
489522 c10::optional<float > output_max) {
523+ if (params.G > 1 ) {
524+ conv2d_depthwise (
525+ output,
526+ input,
527+ weight_prepacked,
528+ bufferFromOptionalHostData (bias, conv2d_biasBufferSize (params.OC )),
529+ params,
530+ output_min,
531+ output_max);
532+ return ;
533+ }
534+
490535 conv2d (
491536 output,
492537 input,
@@ -505,6 +550,18 @@ void conv2d(
505550 const Conv2DParams params,
506551 c10::optional<float > output_min,
507552 c10::optional<float > output_max) {
553+ if (params.G > 1 ) {
554+ conv2d_depthwise (
555+ output,
556+ input,
557+ weight_prepacked,
558+ *(bias.buffer ()),
559+ params,
560+ output_min,
561+ output_max);
562+ return ;
563+ }
564+
508565 conv2d (
509566 output,
510567 input,
@@ -519,7 +576,7 @@ void conv2d(
519576 VulkanTensor& output,
520577 const VulkanTensor& input,
521578 const float * weight,
522- const c10::optional<float *> bias,
579+ const c10::optional<const float *> bias,
523580 const Conv2DParams params,
524581 c10::optional<float > output_min,
525582 c10::optional<float > output_max) {
0 commit comments