From a75e3b4ea52903d2daec08173dff7488b76526be Mon Sep 17 00:00:00 2001 From: nihui Date: Sat, 13 Apr 2024 21:55:33 +0800 Subject: [PATCH] fix instruction extension dispatch --- src/layer/arm/convolution_im2col_gemm_int8.h | 2696 +---------------- src/layer/arm/convolution_packed_int8.h | 12 +- src/layer/arm/innerproduct_fp16s.h | 16 +- src/layer/arm/innerproduct_gemm_fp16s.h | 8 +- src/layer/x86/cast_bf16.h | 6 +- src/layer/x86/convolution_3x3_winograd_int8.h | 32 +- src/layer/x86/convolution_im2col_gemm_int8.h | 24 +- src/layer/x86/convolution_packed_int8.h | 23 +- 8 files changed, 135 insertions(+), 2682 deletions(-) diff --git a/src/layer/arm/convolution_im2col_gemm_int8.h b/src/layer/arm/convolution_im2col_gemm_int8.h index 63a1df9c32b..171809fb31b 100644 --- a/src/layer/arm/convolution_im2col_gemm_int8.h +++ b/src/layer/arm/convolution_im2col_gemm_int8.h @@ -12,17 +12,15 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) #if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 void convolution_im2col_gemm_transform_kernel_int8_i8mm(const Mat& kernel, Mat& AT, int inch, int outch, int kernel_w, int kernel_h, const Option& opt); void convolution_im2col_gemm_int8_i8mm(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int nT, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 void convolution_im2col_gemm_transform_kernel_int8_asimddp(const Mat& kernel, Mat& AT, int inch, int outch, int kernel_w, int kernel_h, const Option& opt); void convolution_im2col_gemm_int8_asimddp(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int nT, const Option& opt); #endif -#endif static void convolution_im2col_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) { @@ -8230,14 +8228,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo } } -template -#if __ARM_FEATURE_MATMUL_INT8 -void convolution_im2col_input_tile_int8_i8mm(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) -#elif __ARM_FEATURE_DOTPROD -void convolution_im2col_input_tile_int8_asimddp(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) -#else // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD -void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) -#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD +static inline void convolution_im2col_input_tile_int8_impl(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h) { const int w = bottom_blob.w; // const int channels = bottom_blob.c; @@ -10747,6 +10738,18 @@ void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, i } } +template +#if __ARM_FEATURE_MATMUL_INT8 +void convolution_im2col_input_tile_int8_i8mm(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#elif __ARM_FEATURE_DOTPROD +void convolution_im2col_input_tile_int8_asimddp(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#else // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD +void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD +{ + convolution_im2col_input_tile_int8_impl(bottom_blob, B, j, max_jj, k, max_kk, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h); +} + #if __ARM_FEATURE_MATMUL_INT8 template void convolution_im2col_input_tile_int8_i8mm<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); template void convolution_im2col_input_tile_int8_i8mm<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); @@ -10850,2622 +10853,113 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i return; } - const int w = bottom_blob.w; - // const int channels = bottom_blob.c; - const int elempack = bottom_blob.elempack; + convolution_im2col_input_tile_int8_impl(bottom_blob, B, j, max_jj, k, max_kk, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h); +} - const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; - const int outw = (w - kernel_extent_w) / stride_w + 1; +static void convolution_im2col_gemm_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, int kernel_w, int kernel_h, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + convolution_im2col_gemm_transform_kernel_int8_i8mm(kernel, AT, inch, outch, kernel_w, kernel_h, opt); + return; + } +#endif - // j max_jj outw*outh split w and h +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + convolution_im2col_gemm_transform_kernel_int8_asimddp(kernel, AT, inch, outch, kernel_w, kernel_h, opt); + return; + } +#endif - // k max_kk pa*maxk*(inch/pa) split inch + // NCNN_LOGE("convolution_im2col_gemm_transform_kernel"); + const int maxk = kernel_w * kernel_h; - // k/max_kk shall be multiple of maxk + const int M = outch; + const int K = inch * maxk; - const int maxk = kernel_w * kernel_h; + int TILE_M, TILE_N, TILE_K; + convolution_im2col_gemm_get_optimal_tile_mnk_int8(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); - signed char* pp = B; + const int nn_M = (M + TILE_M - 1) / TILE_M; - int jj = 0; + int elempack = 1; #if __ARM_NEON -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) + if (opt.use_packing_layout) { - int dy0 = (j + jj) / outw * stride_h; - int dy1 = (j + jj + 1) / outw * stride_h; - int dy2 = (j + jj + 2) / outw * stride_h; - int dy3 = (j + jj + 3) / outw * stride_h; - int dy4 = (j + jj + 4) / outw * stride_h; - int dy5 = (j + jj + 5) / outw * stride_h; - int dy6 = (j + jj + 6) / outw * stride_h; - int dy7 = (j + jj + 7) / outw * stride_h; - int dx0 = (j + jj) % outw * stride_w; - int dx1 = (j + jj + 1) % outw * stride_w; - int dx2 = (j + jj + 2) % outw * stride_w; - int dx3 = (j + jj + 3) % outw * stride_w; - int dx4 = (j + jj + 4) % outw * stride_w; - int dx5 = (j + jj + 5) % outw * stride_w; - int dx6 = (j + jj + 6) % outw * stride_w; - int dx7 = (j + jj + 7) % outw * stride_w; + elempack = inch % 8 == 0 ? 8 : 1; + } +#endif // __ARM_NEON - if (dy0 == dy7) - { - int kk = 0; - if (elempack == 1) - { -#if __ARM_FEATURE_DOTPROD -#if __ARM_FEATURE_MATMUL_INT8 - for (; kk + 7 < max_kk; kk += 8) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int p2 = (k + kk + 2) / maxk; - int p3 = (k + kk + 3) / maxk; - int p4 = (k + kk + 4) / maxk; - int p5 = (k + kk + 5) / maxk; - int p6 = (k + kk + 6) / maxk; - int p7 = (k + kk + 7) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int uv2 = (k + kk + 2) % maxk; - int uv3 = (k + kk + 3) % maxk; - int uv4 = (k + kk + 4) % maxk; - int uv5 = (k + kk + 5) % maxk; - int uv6 = (k + kk + 6) % maxk; - int uv7 = (k + kk + 7) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int u2 = uv2 / kernel_w; - int u3 = uv3 / kernel_w; - int u4 = uv4 / kernel_w; - int u5 = uv5 / kernel_w; - int u6 = uv6 / kernel_w; - int u7 = uv7 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - int v2 = uv2 % kernel_w; - int v3 = uv3 % kernel_w; - int v4 = uv4 % kernel_w; - int v5 = uv5 % kernel_w; - int v6 = uv6 % kernel_w; - int v7 = uv7 % kernel_w; + // maxk-inch-outch to pa-maxk-inch/pa-outch + Mat A_data; + if (maxk == 1) + { + A_data = kernel.reshape(maxk * inch, outch); + } + else + { + Mat weight_data_r2 = kernel.reshape(maxk, inch, outch); - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - const Mat img2 = bottom_blob.channel(p2); - const Mat img3 = bottom_blob.channel(p3); - const Mat img4 = bottom_blob.channel(p4); - const Mat img5 = bottom_blob.channel(p5); - const Mat img6 = bottom_blob.channel(p6); - const Mat img7 = bottom_blob.channel(p7); + A_data.create(maxk * inch, outch, (size_t)1u, 1); - int x00 = dx0 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; + for (int q = 0; q < outch; q += 1) + { + signed char* g00 = A_data.row(q); - int x10 = dx0 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; + for (int p = 0; p + (elempack - 1) < inch; p += elempack) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < elempack; i++) + { + const signed char* k00 = weight_data_r2.channel(q).row(p + i); + g00[0] = k00[k]; + g00++; + } + } + } + } + } - int x20 = dx0 + dilation_w * v2; - int y20 = dy0 + dilation_h * u2; + AT.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, (size_t)1u, 1); - int x30 = dx0 + dilation_w * v3; - int y30 = dy0 + dilation_h * u3; + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; - int x40 = dx0 + dilation_w * v4; - int y40 = dy0 + dilation_h * u4; + const int max_ii = std::min((M - i), TILE_M); - int x50 = dx0 + dilation_w * v5; - int y50 = dy0 + dilation_h * u5; + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); - int x60 = dx0 + dilation_w * v6; - int y60 = dy0 + dilation_h * u6; + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); - int x70 = dx0 + dilation_w * v7; - int y70 = dy0 + dilation_h * u7; + convolution_im2col_pack_A_tile_int8(A_data, AT_tile, i, max_ii, k, max_kk); + } + } +} - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - const signed char* sptr2 = img2.row(y20) + x20; - const signed char* sptr3 = img3.row(y30) + x30; - const signed char* sptr4 = img4.row(y40) + x40; - const signed char* sptr5 = img5.row(y50) + x50; - const signed char* sptr6 = img6.row(y60) + x60; - const signed char* sptr7 = img7.row(y70) + x70; +static void convolution_im2col_gemm_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int nT, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + convolution_im2col_gemm_int8_i8mm(bottom_blob, top_blob, AT, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, nT, opt); + return; + } +#endif - if (stride_w == 1) - { - int8x8_t _r0 = vld1_s8(sptr0); - int8x8_t _r1 = vld1_s8(sptr1); - int8x8_t _r2 = vld1_s8(sptr2); - int8x8_t _r3 = vld1_s8(sptr3); - int8x8_t _r4 = vld1_s8(sptr4); - int8x8_t _r5 = vld1_s8(sptr5); - int8x8_t _r6 = vld1_s8(sptr6); - int8x8_t _r7 = vld1_s8(sptr7); - // save as transpose8x8 - int8x8x2_t _r01 = vzip_s8(_r0, _r1); - int8x8x2_t _r23 = vzip_s8(_r2, _r3); - int8x8x2_t _r45 = vzip_s8(_r4, _r5); - int8x8x2_t _r67 = vzip_s8(_r6, _r7); - int16x8x4_t _r0246; - _r0246.val[0] = vreinterpretq_s16_s8(vcombine_s8(_r01.val[0], _r01.val[1])); - _r0246.val[1] = vreinterpretq_s16_s8(vcombine_s8(_r23.val[0], _r23.val[1])); - _r0246.val[2] = vreinterpretq_s16_s8(vcombine_s8(_r45.val[0], _r45.val[1])); - _r0246.val[3] = vreinterpretq_s16_s8(vcombine_s8(_r67.val[0], _r67.val[1])); - vst4q_s16((short*)pp, _r0246); - pp += 64; - } - else if (stride_w == 2) - { - int8x16_t _r0 = vld1q_s8(sptr0); - int8x16_t _r1 = vld1q_s8(sptr1); - int8x16_t _r2 = vld1q_s8(sptr2); - int8x16_t _r3 = vld1q_s8(sptr3); - int8x16_t _r4 = vld1q_s8(sptr4); - int8x16_t _r5 = vld1q_s8(sptr5); - int8x16_t _r6 = vld1q_s8(sptr6); - int8x16_t _r7 = vld1q_s8(sptr7); - int8x16_t _r01 = vtrnq_s8(_r0, _r1).val[0]; - int8x16_t _r23 = vtrnq_s8(_r2, _r3).val[0]; - int8x16_t _r45 = vtrnq_s8(_r4, _r5).val[0]; - int8x16_t _r67 = vtrnq_s8(_r6, _r7).val[0]; - int16x8x4_t _r0123; - _r0123.val[0] = vreinterpretq_s16_s8(_r01); - _r0123.val[1] = vreinterpretq_s16_s8(_r23); - _r0123.val[2] = vreinterpretq_s16_s8(_r45); - _r0123.val[3] = vreinterpretq_s16_s8(_r67); - vst4q_s16((short*)pp, _r0123); - pp += 64; - } - else - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp[8] = sptr0[stride_w]; - pp[9] = sptr1[stride_w]; - pp[10] = sptr2[stride_w]; - pp[11] = sptr3[stride_w]; - pp[12] = sptr4[stride_w]; - pp[13] = sptr5[stride_w]; - pp[14] = sptr6[stride_w]; - pp[15] = sptr7[stride_w]; - pp[16] = sptr0[stride_w * 2]; - pp[17] = sptr1[stride_w * 2]; - pp[18] = sptr2[stride_w * 2]; - pp[19] = sptr3[stride_w * 2]; - pp[20] = sptr4[stride_w * 2]; - pp[21] = sptr5[stride_w * 2]; - pp[22] = sptr6[stride_w * 2]; - pp[23] = sptr7[stride_w * 2]; - pp[24] = sptr0[stride_w * 3]; - pp[25] = sptr1[stride_w * 3]; - pp[26] = sptr2[stride_w * 3]; - pp[27] = sptr3[stride_w * 3]; - pp[28] = sptr4[stride_w * 3]; - pp[29] = sptr5[stride_w * 3]; - pp[30] = sptr6[stride_w * 3]; - pp[31] = sptr7[stride_w * 3]; - pp[32] = sptr0[stride_w * 4]; - pp[33] = sptr1[stride_w * 4]; - pp[34] = sptr2[stride_w * 4]; - pp[35] = sptr3[stride_w * 4]; - pp[36] = sptr4[stride_w * 4]; - pp[37] = sptr5[stride_w * 4]; - pp[38] = sptr6[stride_w * 4]; - pp[39] = sptr7[stride_w * 4]; - pp[40] = sptr0[stride_w * 5]; - pp[41] = sptr1[stride_w * 5]; - pp[42] = sptr2[stride_w * 5]; - pp[43] = sptr3[stride_w * 5]; - pp[44] = sptr4[stride_w * 5]; - pp[45] = sptr5[stride_w * 5]; - pp[46] = sptr6[stride_w * 5]; - pp[47] = sptr7[stride_w * 5]; - pp[48] = sptr0[stride_w * 6]; - pp[49] = sptr1[stride_w * 6]; - pp[50] = sptr2[stride_w * 6]; - pp[51] = sptr3[stride_w * 6]; - pp[52] = sptr4[stride_w * 6]; - pp[53] = sptr5[stride_w * 6]; - pp[54] = sptr6[stride_w * 6]; - pp[55] = sptr7[stride_w * 6]; - pp[56] = sptr0[stride_w * 7]; - pp[57] = sptr1[stride_w * 7]; - pp[58] = sptr2[stride_w * 7]; - pp[59] = sptr3[stride_w * 7]; - pp[60] = sptr4[stride_w * 7]; - pp[61] = sptr5[stride_w * 7]; - pp[62] = sptr6[stride_w * 7]; - pp[63] = sptr7[stride_w * 7]; - pp += 64; - } - } -#endif // __ARM_FEATURE_MATMUL_INT8 - for (; kk + 3 < max_kk; kk += 4) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int p2 = (k + kk + 2) / maxk; - int p3 = (k + kk + 3) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int uv2 = (k + kk + 2) % maxk; - int uv3 = (k + kk + 3) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int u2 = uv2 / kernel_w; - int u3 = uv3 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - int v2 = uv2 % kernel_w; - int v3 = uv3 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - const Mat img2 = bottom_blob.channel(p2); - const Mat img3 = bottom_blob.channel(p3); - - int x00 = dx0 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - - int x10 = dx0 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - - int x20 = dx0 + dilation_w * v2; - int y20 = dy0 + dilation_h * u2; - - int x30 = dx0 + dilation_w * v3; - int y30 = dy0 + dilation_h * u3; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - const signed char* sptr2 = img2.row(y20) + x20; - const signed char* sptr3 = img3.row(y30) + x30; - - if (stride_w == 1) - { - int8x8x4_t _r01; - _r01.val[0] = vld1_s8(sptr0); - _r01.val[1] = vld1_s8(sptr1); - _r01.val[2] = vld1_s8(sptr2); - _r01.val[3] = vld1_s8(sptr3); - vst4_s8(pp, _r01); - pp += 32; - } - else if (stride_w == 2) - { - int8x16_t _r0 = vld1q_s8(sptr0); - int8x16_t _r1 = vld1q_s8(sptr1); - int8x16_t _r2 = vld1q_s8(sptr2); - int8x16_t _r3 = vld1q_s8(sptr3); - int8x16_t _r01 = vtrnq_s8(_r0, _r1).val[0]; - int8x16_t _r23 = vtrnq_s8(_r2, _r3).val[0]; - int16x8x2_t _r0123; - _r0123.val[0] = vreinterpretq_s16_s8(_r01); - _r0123.val[1] = vreinterpretq_s16_s8(_r23); - vst2q_s16((short*)pp, _r0123); - pp += 32; - } - else - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr0[stride_w]; - pp[5] = sptr1[stride_w]; - pp[6] = sptr2[stride_w]; - pp[7] = sptr3[stride_w]; - pp[8] = sptr0[stride_w * 2]; - pp[9] = sptr1[stride_w * 2]; - pp[10] = sptr2[stride_w * 2]; - pp[11] = sptr3[stride_w * 2]; - pp[12] = sptr0[stride_w * 3]; - pp[13] = sptr1[stride_w * 3]; - pp[14] = sptr2[stride_w * 3]; - pp[15] = sptr3[stride_w * 3]; - pp[16] = sptr0[stride_w * 4]; - pp[17] = sptr1[stride_w * 4]; - pp[18] = sptr2[stride_w * 4]; - pp[19] = sptr3[stride_w * 4]; - pp[20] = sptr0[stride_w * 5]; - pp[21] = sptr1[stride_w * 5]; - pp[22] = sptr2[stride_w * 5]; - pp[23] = sptr3[stride_w * 5]; - pp[24] = sptr0[stride_w * 6]; - pp[25] = sptr1[stride_w * 6]; - pp[26] = sptr2[stride_w * 6]; - pp[27] = sptr3[stride_w * 6]; - pp[28] = sptr0[stride_w * 7]; - pp[29] = sptr1[stride_w * 7]; - pp[30] = sptr2[stride_w * 7]; - pp[31] = sptr3[stride_w * 7]; - pp += 32; - } - } -#endif // __ARM_FEATURE_DOTPROD - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = dx0 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - - int x10 = dx0 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - - if (stride_w == 1) - { - int8x8x2_t _r01; - _r01.val[0] = vld1_s8(sptr0); - _r01.val[1] = vld1_s8(sptr1); - vst2_s8(pp, _r01); - pp += 16; - } - else if (stride_w == 2) - { - int8x16_t _r0 = vld1q_s8(sptr0); - int8x16_t _r1 = vld1q_s8(sptr1); - int8x16_t _r01 = vtrnq_s8(_r0, _r1).val[0]; - vst1q_s8(pp, _r01); - pp += 16; - } - else - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr0[stride_w]; - pp[3] = sptr1[stride_w]; - pp[4] = sptr0[stride_w * 2]; - pp[5] = sptr1[stride_w * 2]; - pp[6] = sptr0[stride_w * 3]; - pp[7] = sptr1[stride_w * 3]; - pp[8] = sptr0[stride_w * 4]; - pp[9] = sptr1[stride_w * 4]; - pp[10] = sptr0[stride_w * 5]; - pp[11] = sptr1[stride_w * 5]; - pp[12] = sptr0[stride_w * 6]; - pp[13] = sptr1[stride_w * 6]; - pp[14] = sptr0[stride_w * 7]; - pp[15] = sptr1[stride_w * 7]; - pp += 16; - } - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = dx0 + dilation_w * v; - int y0 = dy0 + dilation_h * u; - - const signed char* sptr = img.row(y0) + x0 * elempack; - - if (elempack == 8) - { -#if __ARM_FEATURE_MATMUL_INT8 - int8x8_t _r0 = vld1_s8(sptr); - int8x8_t _r1 = vld1_s8(sptr + stride_w * 8); - int8x8_t _r2 = vld1_s8(sptr + stride_w * 16); - int8x8_t _r3 = vld1_s8(sptr + stride_w * 24); - int8x8_t _r4 = vld1_s8(sptr + stride_w * 32); - int8x8_t _r5 = vld1_s8(sptr + stride_w * 40); - int8x8_t _r6 = vld1_s8(sptr + stride_w * 48); - int8x8_t _r7 = vld1_s8(sptr + stride_w * 56); - vst1_s8(pp, _r0); - vst1_s8(pp + 8, _r1); - vst1_s8(pp + 16, _r2); - vst1_s8(pp + 24, _r3); - vst1_s8(pp + 32, _r4); - vst1_s8(pp + 40, _r5); - vst1_s8(pp + 48, _r6); - vst1_s8(pp + 56, _r7); - pp += 64; -#elif __ARM_FEATURE_DOTPROD - int32x2_t _r0 = vreinterpret_s32_s8(vld1_s8(sptr)); - int32x2_t _r1 = vreinterpret_s32_s8(vld1_s8(sptr + stride_w * 8)); - int32x2_t _r2 = vreinterpret_s32_s8(vld1_s8(sptr + stride_w * 16)); - int32x2_t _r3 = vreinterpret_s32_s8(vld1_s8(sptr + stride_w * 24)); - int32x2_t _r4 = vreinterpret_s32_s8(vld1_s8(sptr + stride_w * 32)); - int32x2_t _r5 = vreinterpret_s32_s8(vld1_s8(sptr + stride_w * 40)); - int32x2_t _r6 = vreinterpret_s32_s8(vld1_s8(sptr + stride_w * 48)); - int32x2_t _r7 = vreinterpret_s32_s8(vld1_s8(sptr + stride_w * 56)); - int32x2x2_t _r01 = vzip_s32(_r0, _r1); - int32x2x2_t _r23 = vzip_s32(_r2, _r3); - int32x2x2_t _r45 = vzip_s32(_r4, _r5); - int32x2x2_t _r67 = vzip_s32(_r6, _r7); - vst1_s32((int*)pp, _r01.val[0]); - vst1_s32((int*)(pp + 8), _r23.val[0]); - vst1_s32((int*)(pp + 16), _r45.val[0]); - vst1_s32((int*)(pp + 24), _r67.val[0]); - vst1_s32((int*)(pp + 32), _r01.val[1]); - vst1_s32((int*)(pp + 40), _r23.val[1]); - vst1_s32((int*)(pp + 48), _r45.val[1]); - vst1_s32((int*)(pp + 56), _r67.val[1]); - pp += 64; -#else // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD - int16x4_t _r0 = vreinterpret_s16_s8(vld1_s8(sptr)); - int16x4_t _r1 = vreinterpret_s16_s8(vld1_s8(sptr + stride_w * 8)); - int16x4_t _r2 = vreinterpret_s16_s8(vld1_s8(sptr + stride_w * 16)); - int16x4_t _r3 = vreinterpret_s16_s8(vld1_s8(sptr + stride_w * 24)); - int16x4_t _r4 = vreinterpret_s16_s8(vld1_s8(sptr + stride_w * 32)); - int16x4_t _r5 = vreinterpret_s16_s8(vld1_s8(sptr + stride_w * 40)); - int16x4_t _r6 = vreinterpret_s16_s8(vld1_s8(sptr + stride_w * 48)); - int16x4_t _r7 = vreinterpret_s16_s8(vld1_s8(sptr + stride_w * 56)); - int16x4x2_t _r01 = vzip_s16(_r0, _r1); - int16x4x2_t _r23 = vzip_s16(_r2, _r3); - int16x4x2_t _r45 = vzip_s16(_r4, _r5); - int16x4x2_t _r67 = vzip_s16(_r6, _r7); - int32x4x4_t _r0123; - _r0123.val[0] = vreinterpretq_s32_s16(vcombine_s16(_r01.val[0], _r01.val[1])); - _r0123.val[1] = vreinterpretq_s32_s16(vcombine_s16(_r23.val[0], _r23.val[1])); - _r0123.val[2] = vreinterpretq_s32_s16(vcombine_s16(_r45.val[0], _r45.val[1])); - _r0123.val[3] = vreinterpretq_s32_s16(vcombine_s16(_r67.val[0], _r67.val[1])); - vst4q_s32((int*)pp, _r0123); - pp += 64; -#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD - } - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp[2] = sptr[stride_w * 2]; - pp[3] = sptr[stride_w * 3]; - pp[4] = sptr[stride_w * 4]; - pp[5] = sptr[stride_w * 5]; - pp[6] = sptr[stride_w * 6]; - pp[7] = sptr[stride_w * 7]; - pp += 8; - } - } - } - else - { - int kk = 0; - if (elempack == 1) - { -#if __ARM_FEATURE_DOTPROD -#if __ARM_FEATURE_MATMUL_INT8 - for (; kk + 7 < max_kk; kk += 8) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int p2 = (k + kk + 2) / maxk; - int p3 = (k + kk + 3) / maxk; - int p4 = (k + kk + 4) / maxk; - int p5 = (k + kk + 5) / maxk; - int p6 = (k + kk + 6) / maxk; - int p7 = (k + kk + 7) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int uv2 = (k + kk + 2) % maxk; - int uv3 = (k + kk + 3) % maxk; - int uv4 = (k + kk + 4) % maxk; - int uv5 = (k + kk + 5) % maxk; - int uv6 = (k + kk + 6) % maxk; - int uv7 = (k + kk + 7) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int u2 = uv2 / kernel_w; - int u3 = uv3 / kernel_w; - int u4 = uv4 / kernel_w; - int u5 = uv5 / kernel_w; - int u6 = uv6 / kernel_w; - int u7 = uv7 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - int v2 = uv2 % kernel_w; - int v3 = uv3 % kernel_w; - int v4 = uv4 % kernel_w; - int v5 = uv5 % kernel_w; - int v6 = uv6 % kernel_w; - int v7 = uv7 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - const Mat img2 = bottom_blob.channel(p2); - const Mat img3 = bottom_blob.channel(p3); - const Mat img4 = bottom_blob.channel(p4); - const Mat img5 = bottom_blob.channel(p5); - const Mat img6 = bottom_blob.channel(p6); - const Mat img7 = bottom_blob.channel(p7); - - int x00 = dx0 + dilation_w * v0; - int x01 = dx1 + dilation_w * v0; - int x02 = dx2 + dilation_w * v0; - int x03 = dx3 + dilation_w * v0; - int x04 = dx4 + dilation_w * v0; - int x05 = dx5 + dilation_w * v0; - int x06 = dx6 + dilation_w * v0; - int x07 = dx7 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - int y01 = dy1 + dilation_h * u0; - int y02 = dy2 + dilation_h * u0; - int y03 = dy3 + dilation_h * u0; - int y04 = dy4 + dilation_h * u0; - int y05 = dy5 + dilation_h * u0; - int y06 = dy6 + dilation_h * u0; - int y07 = dy7 + dilation_h * u0; - - int x10 = dx0 + dilation_w * v1; - int x11 = dx1 + dilation_w * v1; - int x12 = dx2 + dilation_w * v1; - int x13 = dx3 + dilation_w * v1; - int x14 = dx4 + dilation_w * v1; - int x15 = dx5 + dilation_w * v1; - int x16 = dx6 + dilation_w * v1; - int x17 = dx7 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - int y11 = dy1 + dilation_h * u1; - int y12 = dy2 + dilation_h * u1; - int y13 = dy3 + dilation_h * u1; - int y14 = dy4 + dilation_h * u1; - int y15 = dy5 + dilation_h * u1; - int y16 = dy6 + dilation_h * u1; - int y17 = dy7 + dilation_h * u1; - - int x20 = dx0 + dilation_w * v2; - int x21 = dx1 + dilation_w * v2; - int x22 = dx2 + dilation_w * v2; - int x23 = dx3 + dilation_w * v2; - int x24 = dx4 + dilation_w * v2; - int x25 = dx5 + dilation_w * v2; - int x26 = dx6 + dilation_w * v2; - int x27 = dx7 + dilation_w * v2; - int y20 = dy0 + dilation_h * u2; - int y21 = dy1 + dilation_h * u2; - int y22 = dy2 + dilation_h * u2; - int y23 = dy3 + dilation_h * u2; - int y24 = dy4 + dilation_h * u2; - int y25 = dy5 + dilation_h * u2; - int y26 = dy6 + dilation_h * u2; - int y27 = dy7 + dilation_h * u2; - - int x30 = dx0 + dilation_w * v3; - int x31 = dx1 + dilation_w * v3; - int x32 = dx2 + dilation_w * v3; - int x33 = dx3 + dilation_w * v3; - int x34 = dx4 + dilation_w * v3; - int x35 = dx5 + dilation_w * v3; - int x36 = dx6 + dilation_w * v3; - int x37 = dx7 + dilation_w * v3; - int y30 = dy0 + dilation_h * u3; - int y31 = dy1 + dilation_h * u3; - int y32 = dy2 + dilation_h * u3; - int y33 = dy3 + dilation_h * u3; - int y34 = dy4 + dilation_h * u3; - int y35 = dy5 + dilation_h * u3; - int y36 = dy6 + dilation_h * u3; - int y37 = dy7 + dilation_h * u3; - - int x40 = dx0 + dilation_w * v4; - int x41 = dx1 + dilation_w * v4; - int x42 = dx2 + dilation_w * v4; - int x43 = dx3 + dilation_w * v4; - int x44 = dx4 + dilation_w * v4; - int x45 = dx5 + dilation_w * v4; - int x46 = dx6 + dilation_w * v4; - int x47 = dx7 + dilation_w * v4; - int y40 = dy0 + dilation_h * u4; - int y41 = dy1 + dilation_h * u4; - int y42 = dy2 + dilation_h * u4; - int y43 = dy3 + dilation_h * u4; - int y44 = dy4 + dilation_h * u4; - int y45 = dy5 + dilation_h * u4; - int y46 = dy6 + dilation_h * u4; - int y47 = dy7 + dilation_h * u4; - - int x50 = dx0 + dilation_w * v5; - int x51 = dx1 + dilation_w * v5; - int x52 = dx2 + dilation_w * v5; - int x53 = dx3 + dilation_w * v5; - int x54 = dx4 + dilation_w * v5; - int x55 = dx5 + dilation_w * v5; - int x56 = dx6 + dilation_w * v5; - int x57 = dx7 + dilation_w * v5; - int y50 = dy0 + dilation_h * u5; - int y51 = dy1 + dilation_h * u5; - int y52 = dy2 + dilation_h * u5; - int y53 = dy3 + dilation_h * u5; - int y54 = dy4 + dilation_h * u5; - int y55 = dy5 + dilation_h * u5; - int y56 = dy6 + dilation_h * u5; - int y57 = dy7 + dilation_h * u5; - - int x60 = dx0 + dilation_w * v6; - int x61 = dx1 + dilation_w * v6; - int x62 = dx2 + dilation_w * v6; - int x63 = dx3 + dilation_w * v6; - int x64 = dx4 + dilation_w * v6; - int x65 = dx5 + dilation_w * v6; - int x66 = dx6 + dilation_w * v6; - int x67 = dx7 + dilation_w * v6; - int y60 = dy0 + dilation_h * u6; - int y61 = dy1 + dilation_h * u6; - int y62 = dy2 + dilation_h * u6; - int y63 = dy3 + dilation_h * u6; - int y64 = dy4 + dilation_h * u6; - int y65 = dy5 + dilation_h * u6; - int y66 = dy6 + dilation_h * u6; - int y67 = dy7 + dilation_h * u6; - - int x70 = dx0 + dilation_w * v7; - int x71 = dx1 + dilation_w * v7; - int x72 = dx2 + dilation_w * v7; - int x73 = dx3 + dilation_w * v7; - int x74 = dx4 + dilation_w * v7; - int x75 = dx5 + dilation_w * v7; - int x76 = dx6 + dilation_w * v7; - int x77 = dx7 + dilation_w * v7; - int y70 = dy0 + dilation_h * u7; - int y71 = dy1 + dilation_h * u7; - int y72 = dy2 + dilation_h * u7; - int y73 = dy3 + dilation_h * u7; - int y74 = dy4 + dilation_h * u7; - int y75 = dy5 + dilation_h * u7; - int y76 = dy6 + dilation_h * u7; - int y77 = dy7 + dilation_h * u7; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = img0.row(y06) + x06; - const signed char* sptr07 = img0.row(y07) + x07; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - const signed char* sptr14 = img1.row(y14) + x14; - const signed char* sptr15 = img1.row(y15) + x15; - const signed char* sptr16 = img1.row(y16) + x16; - const signed char* sptr17 = img1.row(y17) + x17; - - const signed char* sptr20 = img2.row(y20) + x20; - const signed char* sptr21 = img2.row(y21) + x21; - const signed char* sptr22 = img2.row(y22) + x22; - const signed char* sptr23 = img2.row(y23) + x23; - const signed char* sptr24 = img2.row(y24) + x24; - const signed char* sptr25 = img2.row(y25) + x25; - const signed char* sptr26 = img2.row(y26) + x26; - const signed char* sptr27 = img2.row(y27) + x27; - - const signed char* sptr30 = img3.row(y30) + x30; - const signed char* sptr31 = img3.row(y31) + x31; - const signed char* sptr32 = img3.row(y32) + x32; - const signed char* sptr33 = img3.row(y33) + x33; - const signed char* sptr34 = img3.row(y34) + x34; - const signed char* sptr35 = img3.row(y35) + x35; - const signed char* sptr36 = img3.row(y36) + x36; - const signed char* sptr37 = img3.row(y37) + x37; - - const signed char* sptr40 = img4.row(y40) + x40; - const signed char* sptr41 = img4.row(y41) + x41; - const signed char* sptr42 = img4.row(y42) + x42; - const signed char* sptr43 = img4.row(y43) + x43; - const signed char* sptr44 = img4.row(y44) + x44; - const signed char* sptr45 = img4.row(y45) + x45; - const signed char* sptr46 = img4.row(y46) + x46; - const signed char* sptr47 = img4.row(y47) + x47; - - const signed char* sptr50 = img5.row(y50) + x50; - const signed char* sptr51 = img5.row(y51) + x51; - const signed char* sptr52 = img5.row(y52) + x52; - const signed char* sptr53 = img5.row(y53) + x53; - const signed char* sptr54 = img5.row(y54) + x54; - const signed char* sptr55 = img5.row(y55) + x55; - const signed char* sptr56 = img5.row(y56) + x56; - const signed char* sptr57 = img5.row(y57) + x57; - - const signed char* sptr60 = img6.row(y60) + x60; - const signed char* sptr61 = img6.row(y61) + x61; - const signed char* sptr62 = img6.row(y62) + x62; - const signed char* sptr63 = img6.row(y63) + x63; - const signed char* sptr64 = img6.row(y64) + x64; - const signed char* sptr65 = img6.row(y65) + x65; - const signed char* sptr66 = img6.row(y66) + x66; - const signed char* sptr67 = img6.row(y67) + x67; - - const signed char* sptr70 = img7.row(y70) + x70; - const signed char* sptr71 = img7.row(y71) + x71; - const signed char* sptr72 = img7.row(y72) + x72; - const signed char* sptr73 = img7.row(y73) + x73; - const signed char* sptr74 = img7.row(y74) + x74; - const signed char* sptr75 = img7.row(y75) + x75; - const signed char* sptr76 = img7.row(y76) + x76; - const signed char* sptr77 = img7.row(y77) + x77; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr20[0]; - pp[3] = sptr30[0]; - pp[4] = sptr40[0]; - pp[5] = sptr50[0]; - pp[6] = sptr60[0]; - pp[7] = sptr70[0]; - pp[8] = sptr01[0]; - pp[9] = sptr11[0]; - pp[10] = sptr21[0]; - pp[11] = sptr31[0]; - pp[12] = sptr41[0]; - pp[13] = sptr51[0]; - pp[14] = sptr61[0]; - pp[15] = sptr71[0]; - pp[16] = sptr02[0]; - pp[17] = sptr12[0]; - pp[18] = sptr22[0]; - pp[19] = sptr32[0]; - pp[20] = sptr42[0]; - pp[21] = sptr52[0]; - pp[22] = sptr62[0]; - pp[23] = sptr72[0]; - pp[24] = sptr03[0]; - pp[25] = sptr13[0]; - pp[26] = sptr23[0]; - pp[27] = sptr33[0]; - pp[28] = sptr43[0]; - pp[29] = sptr53[0]; - pp[30] = sptr63[0]; - pp[31] = sptr73[0]; - pp[32] = sptr04[0]; - pp[33] = sptr14[0]; - pp[34] = sptr24[0]; - pp[35] = sptr34[0]; - pp[36] = sptr44[0]; - pp[37] = sptr54[0]; - pp[38] = sptr64[0]; - pp[39] = sptr74[0]; - pp[40] = sptr05[0]; - pp[41] = sptr15[0]; - pp[42] = sptr25[0]; - pp[43] = sptr35[0]; - pp[44] = sptr45[0]; - pp[45] = sptr55[0]; - pp[46] = sptr65[0]; - pp[47] = sptr75[0]; - pp[48] = sptr06[0]; - pp[49] = sptr16[0]; - pp[50] = sptr26[0]; - pp[51] = sptr36[0]; - pp[52] = sptr46[0]; - pp[53] = sptr56[0]; - pp[54] = sptr66[0]; - pp[55] = sptr76[0]; - pp[56] = sptr07[0]; - pp[57] = sptr17[0]; - pp[58] = sptr27[0]; - pp[59] = sptr37[0]; - pp[60] = sptr47[0]; - pp[61] = sptr57[0]; - pp[62] = sptr67[0]; - pp[63] = sptr77[0]; - pp += 64; - } -#endif // __ARM_FEATURE_MATMUL_INT8 - for (; kk + 3 < max_kk; kk += 4) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int p2 = (k + kk + 2) / maxk; - int p3 = (k + kk + 3) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int uv2 = (k + kk + 2) % maxk; - int uv3 = (k + kk + 3) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int u2 = uv2 / kernel_w; - int u3 = uv3 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - int v2 = uv2 % kernel_w; - int v3 = uv3 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - const Mat img2 = bottom_blob.channel(p2); - const Mat img3 = bottom_blob.channel(p3); - - int x00 = dx0 + dilation_w * v0; - int x01 = dx1 + dilation_w * v0; - int x02 = dx2 + dilation_w * v0; - int x03 = dx3 + dilation_w * v0; - int x04 = dx4 + dilation_w * v0; - int x05 = dx5 + dilation_w * v0; - int x06 = dx6 + dilation_w * v0; - int x07 = dx7 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - int y01 = dy1 + dilation_h * u0; - int y02 = dy2 + dilation_h * u0; - int y03 = dy3 + dilation_h * u0; - int y04 = dy4 + dilation_h * u0; - int y05 = dy5 + dilation_h * u0; - int y06 = dy6 + dilation_h * u0; - int y07 = dy7 + dilation_h * u0; - - int x10 = dx0 + dilation_w * v1; - int x11 = dx1 + dilation_w * v1; - int x12 = dx2 + dilation_w * v1; - int x13 = dx3 + dilation_w * v1; - int x14 = dx4 + dilation_w * v1; - int x15 = dx5 + dilation_w * v1; - int x16 = dx6 + dilation_w * v1; - int x17 = dx7 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - int y11 = dy1 + dilation_h * u1; - int y12 = dy2 + dilation_h * u1; - int y13 = dy3 + dilation_h * u1; - int y14 = dy4 + dilation_h * u1; - int y15 = dy5 + dilation_h * u1; - int y16 = dy6 + dilation_h * u1; - int y17 = dy7 + dilation_h * u1; - - int x20 = dx0 + dilation_w * v2; - int x21 = dx1 + dilation_w * v2; - int x22 = dx2 + dilation_w * v2; - int x23 = dx3 + dilation_w * v2; - int x24 = dx4 + dilation_w * v2; - int x25 = dx5 + dilation_w * v2; - int x26 = dx6 + dilation_w * v2; - int x27 = dx7 + dilation_w * v2; - int y20 = dy0 + dilation_h * u2; - int y21 = dy1 + dilation_h * u2; - int y22 = dy2 + dilation_h * u2; - int y23 = dy3 + dilation_h * u2; - int y24 = dy4 + dilation_h * u2; - int y25 = dy5 + dilation_h * u2; - int y26 = dy6 + dilation_h * u2; - int y27 = dy7 + dilation_h * u2; - - int x30 = dx0 + dilation_w * v3; - int x31 = dx1 + dilation_w * v3; - int x32 = dx2 + dilation_w * v3; - int x33 = dx3 + dilation_w * v3; - int x34 = dx4 + dilation_w * v3; - int x35 = dx5 + dilation_w * v3; - int x36 = dx6 + dilation_w * v3; - int x37 = dx7 + dilation_w * v3; - int y30 = dy0 + dilation_h * u3; - int y31 = dy1 + dilation_h * u3; - int y32 = dy2 + dilation_h * u3; - int y33 = dy3 + dilation_h * u3; - int y34 = dy4 + dilation_h * u3; - int y35 = dy5 + dilation_h * u3; - int y36 = dy6 + dilation_h * u3; - int y37 = dy7 + dilation_h * u3; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = img0.row(y06) + x06; - const signed char* sptr07 = img0.row(y07) + x07; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - const signed char* sptr14 = img1.row(y14) + x14; - const signed char* sptr15 = img1.row(y15) + x15; - const signed char* sptr16 = img1.row(y16) + x16; - const signed char* sptr17 = img1.row(y17) + x17; - - const signed char* sptr20 = img2.row(y20) + x20; - const signed char* sptr21 = img2.row(y21) + x21; - const signed char* sptr22 = img2.row(y22) + x22; - const signed char* sptr23 = img2.row(y23) + x23; - const signed char* sptr24 = img2.row(y24) + x24; - const signed char* sptr25 = img2.row(y25) + x25; - const signed char* sptr26 = img2.row(y26) + x26; - const signed char* sptr27 = img2.row(y27) + x27; - - const signed char* sptr30 = img3.row(y30) + x30; - const signed char* sptr31 = img3.row(y31) + x31; - const signed char* sptr32 = img3.row(y32) + x32; - const signed char* sptr33 = img3.row(y33) + x33; - const signed char* sptr34 = img3.row(y34) + x34; - const signed char* sptr35 = img3.row(y35) + x35; - const signed char* sptr36 = img3.row(y36) + x36; - const signed char* sptr37 = img3.row(y37) + x37; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr20[0]; - pp[3] = sptr30[0]; - pp[4] = sptr01[0]; - pp[5] = sptr11[0]; - pp[6] = sptr21[0]; - pp[7] = sptr31[0]; - pp[8] = sptr02[0]; - pp[9] = sptr12[0]; - pp[10] = sptr22[0]; - pp[11] = sptr32[0]; - pp[12] = sptr03[0]; - pp[13] = sptr13[0]; - pp[14] = sptr23[0]; - pp[15] = sptr33[0]; - pp[16] = sptr04[0]; - pp[17] = sptr14[0]; - pp[18] = sptr24[0]; - pp[19] = sptr34[0]; - pp[20] = sptr05[0]; - pp[21] = sptr15[0]; - pp[22] = sptr25[0]; - pp[23] = sptr35[0]; - pp[24] = sptr06[0]; - pp[25] = sptr16[0]; - pp[26] = sptr26[0]; - pp[27] = sptr36[0]; - pp[28] = sptr07[0]; - pp[29] = sptr17[0]; - pp[30] = sptr27[0]; - pp[31] = sptr37[0]; - pp += 32; - } -#endif // __ARM_FEATURE_DOTPROD - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = dx0 + dilation_w * v0; - int x01 = dx1 + dilation_w * v0; - int x02 = dx2 + dilation_w * v0; - int x03 = dx3 + dilation_w * v0; - int x04 = dx4 + dilation_w * v0; - int x05 = dx5 + dilation_w * v0; - int x06 = dx6 + dilation_w * v0; - int x07 = dx7 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - int y01 = dy1 + dilation_h * u0; - int y02 = dy2 + dilation_h * u0; - int y03 = dy3 + dilation_h * u0; - int y04 = dy4 + dilation_h * u0; - int y05 = dy5 + dilation_h * u0; - int y06 = dy6 + dilation_h * u0; - int y07 = dy7 + dilation_h * u0; - - int x10 = dx0 + dilation_w * v1; - int x11 = dx1 + dilation_w * v1; - int x12 = dx2 + dilation_w * v1; - int x13 = dx3 + dilation_w * v1; - int x14 = dx4 + dilation_w * v1; - int x15 = dx5 + dilation_w * v1; - int x16 = dx6 + dilation_w * v1; - int x17 = dx7 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - int y11 = dy1 + dilation_h * u1; - int y12 = dy2 + dilation_h * u1; - int y13 = dy3 + dilation_h * u1; - int y14 = dy4 + dilation_h * u1; - int y15 = dy5 + dilation_h * u1; - int y16 = dy6 + dilation_h * u1; - int y17 = dy7 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = img0.row(y06) + x06; - const signed char* sptr07 = img0.row(y07) + x07; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - const signed char* sptr14 = img1.row(y14) + x14; - const signed char* sptr15 = img1.row(y15) + x15; - const signed char* sptr16 = img1.row(y16) + x16; - const signed char* sptr17 = img1.row(y17) + x17; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp[8] = sptr04[0]; - pp[9] = sptr14[0]; - pp[10] = sptr05[0]; - pp[11] = sptr15[0]; - pp[12] = sptr06[0]; - pp[13] = sptr16[0]; - pp[14] = sptr07[0]; - pp[15] = sptr17[0]; - pp += 16; - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = dx0 + dilation_w * v; - int x1 = dx1 + dilation_w * v; - int x2 = dx2 + dilation_w * v; - int x3 = dx3 + dilation_w * v; - int x4 = dx4 + dilation_w * v; - int x5 = dx5 + dilation_w * v; - int x6 = dx6 + dilation_w * v; - int x7 = dx7 + dilation_w * v; - int y0 = dy0 + dilation_h * u; - int y1 = dy1 + dilation_h * u; - int y2 = dy2 + dilation_h * u; - int y3 = dy3 + dilation_h * u; - int y4 = dy4 + dilation_h * u; - int y5 = dy5 + dilation_h * u; - int y6 = dy6 + dilation_h * u; - int y7 = dy7 + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; - const signed char* sptr4 = img.row(y4) + x4 * elempack; - const signed char* sptr5 = img.row(y5) + x5 * elempack; - const signed char* sptr6 = img.row(y6) + x6 * elempack; - const signed char* sptr7 = img.row(y7) + x7 * elempack; - - if (elempack == 8) - { -#if __ARM_FEATURE_MATMUL_INT8 - int8x8_t _r0 = vld1_s8(sptr0); - int8x8_t _r1 = vld1_s8(sptr1); - int8x8_t _r2 = vld1_s8(sptr2); - int8x8_t _r3 = vld1_s8(sptr3); - int8x8_t _r4 = vld1_s8(sptr4); - int8x8_t _r5 = vld1_s8(sptr5); - int8x8_t _r6 = vld1_s8(sptr6); - int8x8_t _r7 = vld1_s8(sptr7); - vst1_s8(pp, _r0); - vst1_s8(pp + 8, _r1); - vst1_s8(pp + 16, _r2); - vst1_s8(pp + 24, _r3); - vst1_s8(pp + 32, _r4); - vst1_s8(pp + 40, _r5); - vst1_s8(pp + 48, _r6); - vst1_s8(pp + 56, _r7); - pp += 64; -#elif __ARM_FEATURE_DOTPROD - int32x2_t _r0 = vreinterpret_s32_s8(vld1_s8(sptr0)); - int32x2_t _r1 = vreinterpret_s32_s8(vld1_s8(sptr1)); - int32x2_t _r2 = vreinterpret_s32_s8(vld1_s8(sptr2)); - int32x2_t _r3 = vreinterpret_s32_s8(vld1_s8(sptr3)); - int32x2_t _r4 = vreinterpret_s32_s8(vld1_s8(sptr4)); - int32x2_t _r5 = vreinterpret_s32_s8(vld1_s8(sptr5)); - int32x2_t _r6 = vreinterpret_s32_s8(vld1_s8(sptr6)); - int32x2_t _r7 = vreinterpret_s32_s8(vld1_s8(sptr7)); - int32x2x2_t _r01 = vzip_s32(_r0, _r1); - int32x2x2_t _r23 = vzip_s32(_r2, _r3); - int32x2x2_t _r45 = vzip_s32(_r4, _r5); - int32x2x2_t _r67 = vzip_s32(_r6, _r7); - vst1_s32((int*)pp, _r01.val[0]); - vst1_s32((int*)(pp + 8), _r23.val[0]); - vst1_s32((int*)(pp + 16), _r45.val[0]); - vst1_s32((int*)(pp + 24), _r67.val[0]); - vst1_s32((int*)(pp + 32), _r01.val[1]); - vst1_s32((int*)(pp + 40), _r23.val[1]); - vst1_s32((int*)(pp + 48), _r45.val[1]); - vst1_s32((int*)(pp + 56), _r67.val[1]); - pp += 64; -#else // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD - int16x4_t _r0 = vreinterpret_s16_s8(vld1_s8(sptr0)); - int16x4_t _r1 = vreinterpret_s16_s8(vld1_s8(sptr1)); - int16x4_t _r2 = vreinterpret_s16_s8(vld1_s8(sptr2)); - int16x4_t _r3 = vreinterpret_s16_s8(vld1_s8(sptr3)); - int16x4_t _r4 = vreinterpret_s16_s8(vld1_s8(sptr4)); - int16x4_t _r5 = vreinterpret_s16_s8(vld1_s8(sptr5)); - int16x4_t _r6 = vreinterpret_s16_s8(vld1_s8(sptr6)); - int16x4_t _r7 = vreinterpret_s16_s8(vld1_s8(sptr7)); - int16x4x2_t _r01 = vzip_s16(_r0, _r1); - int16x4x2_t _r23 = vzip_s16(_r2, _r3); - int16x4x2_t _r45 = vzip_s16(_r4, _r5); - int16x4x2_t _r67 = vzip_s16(_r6, _r7); - int32x4x4_t _r0123; - _r0123.val[0] = vreinterpretq_s32_s16(vcombine_s16(_r01.val[0], _r01.val[1])); - _r0123.val[1] = vreinterpretq_s32_s16(vcombine_s16(_r23.val[0], _r23.val[1])); - _r0123.val[2] = vreinterpretq_s32_s16(vcombine_s16(_r45.val[0], _r45.val[1])); - _r0123.val[3] = vreinterpretq_s32_s16(vcombine_s16(_r67.val[0], _r67.val[1])); - vst4q_s32((int*)pp, _r0123); - pp += 64; -#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD - } - if (elempack == 1) - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp += 8; - } - } - } - } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int dy0 = (j + jj) / outw * stride_h; - int dy1 = (j + jj + 1) / outw * stride_h; - int dy2 = (j + jj + 2) / outw * stride_h; - int dy3 = (j + jj + 3) / outw * stride_h; - int dx0 = (j + jj) % outw * stride_w; - int dx1 = (j + jj + 1) % outw * stride_w; - int dx2 = (j + jj + 2) % outw * stride_w; - int dx3 = (j + jj + 3) % outw * stride_w; - - if (dy0 == dy3) - { - int kk = 0; - if (elempack == 1) - { -#if __ARM_FEATURE_DOTPROD -#if __ARM_FEATURE_MATMUL_INT8 - for (; kk + 7 < max_kk; kk += 8) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int p2 = (k + kk + 2) / maxk; - int p3 = (k + kk + 3) / maxk; - int p4 = (k + kk + 4) / maxk; - int p5 = (k + kk + 5) / maxk; - int p6 = (k + kk + 6) / maxk; - int p7 = (k + kk + 7) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int uv2 = (k + kk + 2) % maxk; - int uv3 = (k + kk + 3) % maxk; - int uv4 = (k + kk + 4) % maxk; - int uv5 = (k + kk + 5) % maxk; - int uv6 = (k + kk + 6) % maxk; - int uv7 = (k + kk + 7) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int u2 = uv2 / kernel_w; - int u3 = uv3 / kernel_w; - int u4 = uv4 / kernel_w; - int u5 = uv5 / kernel_w; - int u6 = uv6 / kernel_w; - int u7 = uv7 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - int v2 = uv2 % kernel_w; - int v3 = uv3 % kernel_w; - int v4 = uv4 % kernel_w; - int v5 = uv5 % kernel_w; - int v6 = uv6 % kernel_w; - int v7 = uv7 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - const Mat img2 = bottom_blob.channel(p2); - const Mat img3 = bottom_blob.channel(p3); - const Mat img4 = bottom_blob.channel(p4); - const Mat img5 = bottom_blob.channel(p5); - const Mat img6 = bottom_blob.channel(p6); - const Mat img7 = bottom_blob.channel(p7); - - int x00 = dx0 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - - int x10 = dx0 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - - int x20 = dx0 + dilation_w * v2; - int y20 = dy0 + dilation_h * u2; - - int x30 = dx0 + dilation_w * v3; - int y30 = dy0 + dilation_h * u3; - - int x40 = dx0 + dilation_w * v4; - int y40 = dy0 + dilation_h * u4; - - int x50 = dx0 + dilation_w * v5; - int y50 = dy0 + dilation_h * u5; - - int x60 = dx0 + dilation_w * v6; - int y60 = dy0 + dilation_h * u6; - - int x70 = dx0 + dilation_w * v7; - int y70 = dy0 + dilation_h * u7; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - const signed char* sptr2 = img2.row(y20) + x20; - const signed char* sptr3 = img3.row(y30) + x30; - const signed char* sptr4 = img4.row(y40) + x40; - const signed char* sptr5 = img5.row(y50) + x50; - const signed char* sptr6 = img6.row(y60) + x60; - const signed char* sptr7 = img7.row(y70) + x70; - - if (stride_w == 1) - { - int8x8_t _r0 = vld1_s8(sptr0); - int8x8_t _r1 = vld1_s8(sptr1); - int8x8_t _r2 = vld1_s8(sptr2); - int8x8_t _r3 = vld1_s8(sptr3); - int8x8_t _r4 = vld1_s8(sptr4); - int8x8_t _r5 = vld1_s8(sptr5); - int8x8_t _r6 = vld1_s8(sptr6); - int8x8_t _r7 = vld1_s8(sptr7); - int16x4x4_t _r0123; - _r0123.val[0] = vreinterpret_s16_s8(vzip_s8(_r0, _r1).val[0]); - _r0123.val[1] = vreinterpret_s16_s8(vzip_s8(_r2, _r3).val[0]); - _r0123.val[2] = vreinterpret_s16_s8(vzip_s8(_r4, _r5).val[0]); - _r0123.val[3] = vreinterpret_s16_s8(vzip_s8(_r6, _r7).val[0]); - vst4_s16((short*)pp, _r0123); - pp += 32; - } - else if (stride_w == 2) - { - int8x8_t _r0 = vld1_s8(sptr0); - int8x8_t _r1 = vld1_s8(sptr1); - int8x8_t _r2 = vld1_s8(sptr2); - int8x8_t _r3 = vld1_s8(sptr3); - int8x8_t _r4 = vld1_s8(sptr4); - int8x8_t _r5 = vld1_s8(sptr5); - int8x8_t _r6 = vld1_s8(sptr6); - int8x8_t _r7 = vld1_s8(sptr7); - int8x8_t _r01 = vtrn_s8(_r0, _r1).val[0]; - int8x8_t _r23 = vtrn_s8(_r2, _r3).val[0]; - int8x8_t _r45 = vtrn_s8(_r4, _r5).val[0]; - int8x8_t _r67 = vtrn_s8(_r6, _r7).val[0]; - int16x4x4_t _r0123; - _r0123.val[0] = vreinterpret_s16_s8(_r01); - _r0123.val[1] = vreinterpret_s16_s8(_r23); - _r0123.val[2] = vreinterpret_s16_s8(_r45); - _r0123.val[3] = vreinterpret_s16_s8(_r67); - vst4_s16((short*)pp, _r0123); - pp += 32; - } - else - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp[8] = sptr0[stride_w]; - pp[9] = sptr1[stride_w]; - pp[10] = sptr2[stride_w]; - pp[11] = sptr3[stride_w]; - pp[12] = sptr4[stride_w]; - pp[13] = sptr5[stride_w]; - pp[14] = sptr6[stride_w]; - pp[15] = sptr7[stride_w]; - pp[16] = sptr0[stride_w * 2]; - pp[17] = sptr1[stride_w * 2]; - pp[18] = sptr2[stride_w * 2]; - pp[19] = sptr3[stride_w * 2]; - pp[20] = sptr4[stride_w * 2]; - pp[21] = sptr5[stride_w * 2]; - pp[22] = sptr6[stride_w * 2]; - pp[23] = sptr7[stride_w * 2]; - pp[24] = sptr0[stride_w * 3]; - pp[25] = sptr1[stride_w * 3]; - pp[26] = sptr2[stride_w * 3]; - pp[27] = sptr3[stride_w * 3]; - pp[28] = sptr4[stride_w * 3]; - pp[29] = sptr5[stride_w * 3]; - pp[30] = sptr6[stride_w * 3]; - pp[31] = sptr7[stride_w * 3]; - pp += 32; - } - } -#endif // __ARM_FEATURE_MATMUL_INT8 - for (; kk + 3 < max_kk; kk += 4) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int p2 = (k + kk + 2) / maxk; - int p3 = (k + kk + 3) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int uv2 = (k + kk + 2) % maxk; - int uv3 = (k + kk + 3) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int u2 = uv2 / kernel_w; - int u3 = uv3 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - int v2 = uv2 % kernel_w; - int v3 = uv3 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - const Mat img2 = bottom_blob.channel(p2); - const Mat img3 = bottom_blob.channel(p3); - - int x00 = dx0 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - - int x10 = dx0 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - - int x20 = dx0 + dilation_w * v2; - int y20 = dy0 + dilation_h * u2; - - int x30 = dx0 + dilation_w * v3; - int y30 = dy0 + dilation_h * u3; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - const signed char* sptr2 = img2.row(y20) + x20; - const signed char* sptr3 = img3.row(y30) + x30; - - if (stride_w == 1) - { - int8x8_t _r0 = vld1_s8(sptr0); - int8x8_t _r1 = vld1_s8(sptr1); - int8x8_t _r2 = vld1_s8(sptr2); - int8x8_t _r3 = vld1_s8(sptr3); - int16x4x2_t _r01; - _r01.val[0] = vreinterpret_s16_s8(vzip_s8(_r0, _r1).val[0]); - _r01.val[1] = vreinterpret_s16_s8(vzip_s8(_r2, _r3).val[0]); - vst2_s16((short*)pp, _r01); - pp += 16; - } - else if (stride_w == 2) - { - int8x8_t _r0 = vld1_s8(sptr0); - int8x8_t _r1 = vld1_s8(sptr1); - int8x8_t _r2 = vld1_s8(sptr2); - int8x8_t _r3 = vld1_s8(sptr3); - int8x8_t _r01 = vtrn_s8(_r0, _r1).val[0]; - int8x8_t _r23 = vtrn_s8(_r2, _r3).val[0]; - int16x4x2_t _r0123; - _r0123.val[0] = vreinterpret_s16_s8(_r01); - _r0123.val[1] = vreinterpret_s16_s8(_r23); - vst2_s16((short*)pp, _r0123); - pp += 16; - } - else - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr0[stride_w]; - pp[5] = sptr1[stride_w]; - pp[6] = sptr2[stride_w]; - pp[7] = sptr3[stride_w]; - pp[8] = sptr0[stride_w * 2]; - pp[9] = sptr1[stride_w * 2]; - pp[10] = sptr2[stride_w * 2]; - pp[11] = sptr3[stride_w * 2]; - pp[12] = sptr0[stride_w * 3]; - pp[13] = sptr1[stride_w * 3]; - pp[14] = sptr2[stride_w * 3]; - pp[15] = sptr3[stride_w * 3]; - pp += 16; - } - } -#endif // __ARM_FEATURE_DOTPROD - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = dx0 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - - int x10 = dx0 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - - if (stride_w == 1) - { - int8x8_t _r0 = vld1_s8(sptr0); - int8x8_t _r1 = vld1_s8(sptr1); - int8x8_t _r01 = vzip_s8(_r0, _r1).val[0]; - vst1_s8(pp, _r01); - pp += 8; - } - else if (stride_w == 2) - { - int8x8_t _r0 = vld1_s8(sptr0); - int8x8_t _r1 = vld1_s8(sptr1); - int8x8_t _r01 = vtrn_s8(_r0, _r1).val[0]; - vst1_s8(pp, _r01); - pp += 8; - } - else - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr0[stride_w]; - pp[3] = sptr1[stride_w]; - pp[4] = sptr0[stride_w * 2]; - pp[5] = sptr1[stride_w * 2]; - pp[6] = sptr0[stride_w * 3]; - pp[7] = sptr1[stride_w * 3]; - pp += 8; - } - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = dx0 + dilation_w * v; - int y0 = dy0 + dilation_h * u; - - const signed char* sptr = img.row(y0) + x0 * elempack; - - if (elempack == 8) - { -#if __ARM_FEATURE_MATMUL_INT8 - int8x8_t _r0 = vld1_s8(sptr); - int8x8_t _r1 = vld1_s8(sptr + stride_w * 8); - int8x8_t _r2 = vld1_s8(sptr + stride_w * 16); - int8x8_t _r3 = vld1_s8(sptr + stride_w * 24); - vst1_s8(pp, _r0); - vst1_s8(pp + 8, _r1); - vst1_s8(pp + 16, _r2); - vst1_s8(pp + 24, _r3); - pp += 32; -#elif __ARM_FEATURE_DOTPROD - int32x2x4_t _r0123; - _r0123.val[0] = vreinterpret_s32_s8(vld1_s8(sptr)); - _r0123.val[1] = vreinterpret_s32_s8(vld1_s8(sptr + stride_w * 8)); - _r0123.val[2] = vreinterpret_s32_s8(vld1_s8(sptr + stride_w * 16)); - _r0123.val[3] = vreinterpret_s32_s8(vld1_s8(sptr + stride_w * 24)); - vst4_s32((int*)pp, _r0123); - pp += 32; -#else // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD - int16x4x4_t _r0123; - _r0123.val[0] = vreinterpret_s16_s8(vld1_s8(sptr)); - _r0123.val[1] = vreinterpret_s16_s8(vld1_s8(sptr + stride_w * 8)); - _r0123.val[2] = vreinterpret_s16_s8(vld1_s8(sptr + stride_w * 16)); - _r0123.val[3] = vreinterpret_s16_s8(vld1_s8(sptr + stride_w * 24)); - vst4_s16((short*)pp, _r0123); - pp += 32; -#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD - } - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp[2] = sptr[stride_w * 2]; - pp[3] = sptr[stride_w * 3]; - pp += 4; - } - } - } - else - { - int kk = 0; - if (elempack == 1) - { -#if __ARM_FEATURE_DOTPROD -#if __ARM_FEATURE_MATMUL_INT8 - for (; kk + 7 < max_kk; kk += 8) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int p2 = (k + kk + 2) / maxk; - int p3 = (k + kk + 3) / maxk; - int p4 = (k + kk + 4) / maxk; - int p5 = (k + kk + 5) / maxk; - int p6 = (k + kk + 6) / maxk; - int p7 = (k + kk + 7) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int uv2 = (k + kk + 2) % maxk; - int uv3 = (k + kk + 3) % maxk; - int uv4 = (k + kk + 4) % maxk; - int uv5 = (k + kk + 5) % maxk; - int uv6 = (k + kk + 6) % maxk; - int uv7 = (k + kk + 7) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int u2 = uv2 / kernel_w; - int u3 = uv3 / kernel_w; - int u4 = uv4 / kernel_w; - int u5 = uv5 / kernel_w; - int u6 = uv6 / kernel_w; - int u7 = uv7 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - int v2 = uv2 % kernel_w; - int v3 = uv3 % kernel_w; - int v4 = uv4 % kernel_w; - int v5 = uv5 % kernel_w; - int v6 = uv6 % kernel_w; - int v7 = uv7 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - const Mat img2 = bottom_blob.channel(p2); - const Mat img3 = bottom_blob.channel(p3); - const Mat img4 = bottom_blob.channel(p4); - const Mat img5 = bottom_blob.channel(p5); - const Mat img6 = bottom_blob.channel(p6); - const Mat img7 = bottom_blob.channel(p7); - - int x00 = dx0 + dilation_w * v0; - int x01 = dx1 + dilation_w * v0; - int x02 = dx2 + dilation_w * v0; - int x03 = dx3 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - int y01 = dy1 + dilation_h * u0; - int y02 = dy2 + dilation_h * u0; - int y03 = dy3 + dilation_h * u0; - - int x10 = dx0 + dilation_w * v1; - int x11 = dx1 + dilation_w * v1; - int x12 = dx2 + dilation_w * v1; - int x13 = dx3 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - int y11 = dy1 + dilation_h * u1; - int y12 = dy2 + dilation_h * u1; - int y13 = dy3 + dilation_h * u1; - - int x20 = dx0 + dilation_w * v2; - int x21 = dx1 + dilation_w * v2; - int x22 = dx2 + dilation_w * v2; - int x23 = dx3 + dilation_w * v2; - int y20 = dy0 + dilation_h * u2; - int y21 = dy1 + dilation_h * u2; - int y22 = dy2 + dilation_h * u2; - int y23 = dy3 + dilation_h * u2; - - int x30 = dx0 + dilation_w * v3; - int x31 = dx1 + dilation_w * v3; - int x32 = dx2 + dilation_w * v3; - int x33 = dx3 + dilation_w * v3; - int y30 = dy0 + dilation_h * u3; - int y31 = dy1 + dilation_h * u3; - int y32 = dy2 + dilation_h * u3; - int y33 = dy3 + dilation_h * u3; - - int x40 = dx0 + dilation_w * v4; - int x41 = dx1 + dilation_w * v4; - int x42 = dx2 + dilation_w * v4; - int x43 = dx3 + dilation_w * v4; - int y40 = dy0 + dilation_h * u4; - int y41 = dy1 + dilation_h * u4; - int y42 = dy2 + dilation_h * u4; - int y43 = dy3 + dilation_h * u4; - - int x50 = dx0 + dilation_w * v5; - int x51 = dx1 + dilation_w * v5; - int x52 = dx2 + dilation_w * v5; - int x53 = dx3 + dilation_w * v5; - int y50 = dy0 + dilation_h * u5; - int y51 = dy1 + dilation_h * u5; - int y52 = dy2 + dilation_h * u5; - int y53 = dy3 + dilation_h * u5; - - int x60 = dx0 + dilation_w * v6; - int x61 = dx1 + dilation_w * v6; - int x62 = dx2 + dilation_w * v6; - int x63 = dx3 + dilation_w * v6; - int y60 = dy0 + dilation_h * u6; - int y61 = dy1 + dilation_h * u6; - int y62 = dy2 + dilation_h * u6; - int y63 = dy3 + dilation_h * u6; - - int x70 = dx0 + dilation_w * v7; - int x71 = dx1 + dilation_w * v7; - int x72 = dx2 + dilation_w * v7; - int x73 = dx3 + dilation_w * v7; - int y70 = dy0 + dilation_h * u7; - int y71 = dy1 + dilation_h * u7; - int y72 = dy2 + dilation_h * u7; - int y73 = dy3 + dilation_h * u7; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - - const signed char* sptr20 = img2.row(y20) + x20; - const signed char* sptr21 = img2.row(y21) + x21; - const signed char* sptr22 = img2.row(y22) + x22; - const signed char* sptr23 = img2.row(y23) + x23; - - const signed char* sptr30 = img3.row(y30) + x30; - const signed char* sptr31 = img3.row(y31) + x31; - const signed char* sptr32 = img3.row(y32) + x32; - const signed char* sptr33 = img3.row(y33) + x33; - - const signed char* sptr40 = img4.row(y40) + x40; - const signed char* sptr41 = img4.row(y41) + x41; - const signed char* sptr42 = img4.row(y42) + x42; - const signed char* sptr43 = img4.row(y43) + x43; - - const signed char* sptr50 = img5.row(y50) + x50; - const signed char* sptr51 = img5.row(y51) + x51; - const signed char* sptr52 = img5.row(y52) + x52; - const signed char* sptr53 = img5.row(y53) + x53; - - const signed char* sptr60 = img6.row(y60) + x60; - const signed char* sptr61 = img6.row(y61) + x61; - const signed char* sptr62 = img6.row(y62) + x62; - const signed char* sptr63 = img6.row(y63) + x63; - - const signed char* sptr70 = img7.row(y70) + x70; - const signed char* sptr71 = img7.row(y71) + x71; - const signed char* sptr72 = img7.row(y72) + x72; - const signed char* sptr73 = img7.row(y73) + x73; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr20[0]; - pp[3] = sptr30[0]; - pp[4] = sptr40[0]; - pp[5] = sptr50[0]; - pp[6] = sptr60[0]; - pp[7] = sptr70[0]; - pp[8] = sptr01[0]; - pp[9] = sptr11[0]; - pp[10] = sptr21[0]; - pp[11] = sptr31[0]; - pp[12] = sptr41[0]; - pp[13] = sptr51[0]; - pp[14] = sptr61[0]; - pp[15] = sptr71[0]; - pp[16] = sptr02[0]; - pp[17] = sptr12[0]; - pp[18] = sptr22[0]; - pp[19] = sptr32[0]; - pp[20] = sptr42[0]; - pp[21] = sptr52[0]; - pp[22] = sptr62[0]; - pp[23] = sptr72[0]; - pp[24] = sptr03[0]; - pp[25] = sptr13[0]; - pp[26] = sptr23[0]; - pp[27] = sptr33[0]; - pp[28] = sptr43[0]; - pp[29] = sptr53[0]; - pp[30] = sptr63[0]; - pp[31] = sptr73[0]; - pp += 32; - } -#endif // __ARM_FEATURE_MATMUL_INT8 - for (; kk + 3 < max_kk; kk += 4) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int p2 = (k + kk + 2) / maxk; - int p3 = (k + kk + 3) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int uv2 = (k + kk + 2) % maxk; - int uv3 = (k + kk + 3) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int u2 = uv2 / kernel_w; - int u3 = uv3 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - int v2 = uv2 % kernel_w; - int v3 = uv3 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - const Mat img2 = bottom_blob.channel(p2); - const Mat img3 = bottom_blob.channel(p3); - - int x00 = dx0 + dilation_w * v0; - int x01 = dx1 + dilation_w * v0; - int x02 = dx2 + dilation_w * v0; - int x03 = dx3 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - int y01 = dy1 + dilation_h * u0; - int y02 = dy2 + dilation_h * u0; - int y03 = dy3 + dilation_h * u0; - - int x10 = dx0 + dilation_w * v1; - int x11 = dx1 + dilation_w * v1; - int x12 = dx2 + dilation_w * v1; - int x13 = dx3 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - int y11 = dy1 + dilation_h * u1; - int y12 = dy2 + dilation_h * u1; - int y13 = dy3 + dilation_h * u1; - - int x20 = dx0 + dilation_w * v2; - int x21 = dx1 + dilation_w * v2; - int x22 = dx2 + dilation_w * v2; - int x23 = dx3 + dilation_w * v2; - int y20 = dy0 + dilation_h * u2; - int y21 = dy1 + dilation_h * u2; - int y22 = dy2 + dilation_h * u2; - int y23 = dy3 + dilation_h * u2; - - int x30 = dx0 + dilation_w * v3; - int x31 = dx1 + dilation_w * v3; - int x32 = dx2 + dilation_w * v3; - int x33 = dx3 + dilation_w * v3; - int y30 = dy0 + dilation_h * u3; - int y31 = dy1 + dilation_h * u3; - int y32 = dy2 + dilation_h * u3; - int y33 = dy3 + dilation_h * u3; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - - const signed char* sptr20 = img2.row(y20) + x20; - const signed char* sptr21 = img2.row(y21) + x21; - const signed char* sptr22 = img2.row(y22) + x22; - const signed char* sptr23 = img2.row(y23) + x23; - - const signed char* sptr30 = img3.row(y30) + x30; - const signed char* sptr31 = img3.row(y31) + x31; - const signed char* sptr32 = img3.row(y32) + x32; - const signed char* sptr33 = img3.row(y33) + x33; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr20[0]; - pp[3] = sptr30[0]; - pp[4] = sptr01[0]; - pp[5] = sptr11[0]; - pp[6] = sptr21[0]; - pp[7] = sptr31[0]; - pp[8] = sptr02[0]; - pp[9] = sptr12[0]; - pp[10] = sptr22[0]; - pp[11] = sptr32[0]; - pp[12] = sptr03[0]; - pp[13] = sptr13[0]; - pp[14] = sptr23[0]; - pp[15] = sptr33[0]; - pp += 16; - } -#endif // __ARM_FEATURE_DOTPROD - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = dx0 + dilation_w * v0; - int x01 = dx1 + dilation_w * v0; - int x02 = dx2 + dilation_w * v0; - int x03 = dx3 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - int y01 = dy1 + dilation_h * u0; - int y02 = dy2 + dilation_h * u0; - int y03 = dy3 + dilation_h * u0; - - int x10 = dx0 + dilation_w * v1; - int x11 = dx1 + dilation_w * v1; - int x12 = dx2 + dilation_w * v1; - int x13 = dx3 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - int y11 = dy1 + dilation_h * u1; - int y12 = dy2 + dilation_h * u1; - int y13 = dy3 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp += 8; - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = dx0 + dilation_w * v; - int x1 = dx1 + dilation_w * v; - int x2 = dx2 + dilation_w * v; - int x3 = dx3 + dilation_w * v; - int y0 = dy0 + dilation_h * u; - int y1 = dy1 + dilation_h * u; - int y2 = dy2 + dilation_h * u; - int y3 = dy3 + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; - - if (elempack == 8) - { -#if __ARM_FEATURE_MATMUL_INT8 - int8x8_t _r0 = vld1_s8(sptr0); - int8x8_t _r1 = vld1_s8(sptr1); - int8x8_t _r2 = vld1_s8(sptr2); - int8x8_t _r3 = vld1_s8(sptr3); - vst1_s8(pp, _r0); - vst1_s8(pp + 8, _r1); - vst1_s8(pp + 16, _r2); - vst1_s8(pp + 24, _r3); - pp += 32; -#elif __ARM_FEATURE_DOTPROD - int32x2x4_t _r0123; - _r0123.val[0] = vreinterpret_s32_s8(vld1_s8(sptr0)); - _r0123.val[1] = vreinterpret_s32_s8(vld1_s8(sptr1)); - _r0123.val[2] = vreinterpret_s32_s8(vld1_s8(sptr2)); - _r0123.val[3] = vreinterpret_s32_s8(vld1_s8(sptr3)); - vst4_s32((int*)pp, _r0123); - pp += 32; -#else // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD - int16x4x4_t _r0123; - _r0123.val[0] = vreinterpret_s16_s8(vld1_s8(sptr0)); - _r0123.val[1] = vreinterpret_s16_s8(vld1_s8(sptr1)); - _r0123.val[2] = vreinterpret_s16_s8(vld1_s8(sptr2)); - _r0123.val[3] = vreinterpret_s16_s8(vld1_s8(sptr3)); - vst4_s16((short*)pp, _r0123); - pp += 32; -#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD - } - if (elempack == 1) - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp += 4; - } - } - } - } -#endif // __ARM_NEON - for (; jj + 1 < max_jj; jj += 2) - { - int dy0 = (j + jj) / outw * stride_h; - int dy1 = (j + jj + 1) / outw * stride_h; - int dx0 = (j + jj) % outw * stride_w; - int dx1 = (j + jj + 1) % outw * stride_w; - - if (dy0 == dy1) - { - int kk = 0; -#if __ARM_NEON - if (elempack == 1) - { -#if __ARM_FEATURE_DOTPROD -#if __ARM_FEATURE_MATMUL_INT8 - for (; kk + 7 < max_kk; kk += 8) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int p2 = (k + kk + 2) / maxk; - int p3 = (k + kk + 3) / maxk; - int p4 = (k + kk + 4) / maxk; - int p5 = (k + kk + 5) / maxk; - int p6 = (k + kk + 6) / maxk; - int p7 = (k + kk + 7) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int uv2 = (k + kk + 2) % maxk; - int uv3 = (k + kk + 3) % maxk; - int uv4 = (k + kk + 4) % maxk; - int uv5 = (k + kk + 5) % maxk; - int uv6 = (k + kk + 6) % maxk; - int uv7 = (k + kk + 7) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int u2 = uv2 / kernel_w; - int u3 = uv3 / kernel_w; - int u4 = uv4 / kernel_w; - int u5 = uv5 / kernel_w; - int u6 = uv6 / kernel_w; - int u7 = uv7 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - int v2 = uv2 % kernel_w; - int v3 = uv3 % kernel_w; - int v4 = uv4 % kernel_w; - int v5 = uv5 % kernel_w; - int v6 = uv6 % kernel_w; - int v7 = uv7 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - const Mat img2 = bottom_blob.channel(p2); - const Mat img3 = bottom_blob.channel(p3); - const Mat img4 = bottom_blob.channel(p4); - const Mat img5 = bottom_blob.channel(p5); - const Mat img6 = bottom_blob.channel(p6); - const Mat img7 = bottom_blob.channel(p7); - - int x00 = dx0 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - int x10 = dx0 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - - int x20 = dx0 + dilation_w * v2; - int y20 = dy0 + dilation_h * u2; - int x30 = dx0 + dilation_w * v3; - int y30 = dy0 + dilation_h * u3; - - int x40 = dx0 + dilation_w * v4; - int y40 = dy0 + dilation_h * u4; - int x50 = dx0 + dilation_w * v5; - int y50 = dy0 + dilation_h * u5; - - int x60 = dx0 + dilation_w * v6; - int y60 = dy0 + dilation_h * u6; - int x70 = dx0 + dilation_w * v7; - int y70 = dy0 + dilation_h * u7; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - const signed char* sptr2 = img2.row(y20) + x20; - const signed char* sptr3 = img3.row(y30) + x30; - - const signed char* sptr4 = img4.row(y40) + x40; - const signed char* sptr5 = img5.row(y50) + x50; - const signed char* sptr6 = img6.row(y60) + x60; - const signed char* sptr7 = img7.row(y70) + x70; - - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp[8] = sptr0[stride_w]; - pp[9] = sptr1[stride_w]; - pp[10] = sptr2[stride_w]; - pp[11] = sptr3[stride_w]; - pp[12] = sptr4[stride_w]; - pp[13] = sptr5[stride_w]; - pp[14] = sptr6[stride_w]; - pp[15] = sptr7[stride_w]; - pp += 16; - } -#endif // __ARM_FEATURE_MATMUL_INT8 - for (; kk + 3 < max_kk; kk += 4) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int p2 = (k + kk + 2) / maxk; - int p3 = (k + kk + 3) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int uv2 = (k + kk + 2) % maxk; - int uv3 = (k + kk + 3) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int u2 = uv2 / kernel_w; - int u3 = uv3 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - int v2 = uv2 % kernel_w; - int v3 = uv3 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - const Mat img2 = bottom_blob.channel(p2); - const Mat img3 = bottom_blob.channel(p3); - - int x00 = dx0 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - int x10 = dx0 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - int x20 = dx0 + dilation_w * v2; - int y20 = dy0 + dilation_h * u2; - int x30 = dx0 + dilation_w * v3; - int y30 = dy0 + dilation_h * u3; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - const signed char* sptr2 = img2.row(y20) + x20; - const signed char* sptr3 = img3.row(y30) + x30; - - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr0[stride_w]; - pp[5] = sptr1[stride_w]; - pp[6] = sptr2[stride_w]; - pp[7] = sptr3[stride_w]; - pp += 8; - } -#endif // __ARM_FEATURE_DOTPROD - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = dx0 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - int x10 = dx0 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr0[stride_w]; - pp[3] = sptr1[stride_w]; - pp += 4; - } - } -#endif // __ARM_NEON - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = dx0 + dilation_w * v; - int y0 = dy0 + dilation_h * u; - - const signed char* sptr = img.row(y0) + x0 * elempack; - -#if __ARM_NEON - if (elempack == 8) - { -#if __ARM_FEATURE_MATMUL_INT8 - int8x8_t _r0 = vld1_s8(sptr); - int8x8_t _r1 = vld1_s8(sptr + stride_w * 8); - vst1_s8(pp, _r0); - vst1_s8(pp + 8, _r1); - pp += 16; -#elif __ARM_FEATURE_DOTPROD - int32x2x2_t _r01; - _r01.val[0] = vreinterpret_s32_s8(vld1_s8(sptr)); - _r01.val[1] = vreinterpret_s32_s8(vld1_s8(sptr + stride_w * 8)); - vst2_s32((int*)pp, _r01); - pp += 16; -#else // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD - int16x4x2_t _r01; - _r01.val[0] = vreinterpret_s16_s8(vld1_s8(sptr)); - _r01.val[1] = vreinterpret_s16_s8(vld1_s8(sptr + stride_w * 8)); - vst2_s16((short*)pp, _r01); - pp += 16; -#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD - } -#endif // __ARM_NEON - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp += 2; - } - } - } - else - { - int kk = 0; -#if __ARM_NEON - if (elempack == 1) - { -#if __ARM_FEATURE_DOTPROD -#if __ARM_FEATURE_MATMUL_INT8 - for (; kk + 7 < max_kk; kk += 8) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int p2 = (k + kk + 2) / maxk; - int p3 = (k + kk + 3) / maxk; - int p4 = (k + kk + 4) / maxk; - int p5 = (k + kk + 5) / maxk; - int p6 = (k + kk + 6) / maxk; - int p7 = (k + kk + 7) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int uv2 = (k + kk + 2) % maxk; - int uv3 = (k + kk + 3) % maxk; - int uv4 = (k + kk + 4) % maxk; - int uv5 = (k + kk + 5) % maxk; - int uv6 = (k + kk + 6) % maxk; - int uv7 = (k + kk + 7) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int u2 = uv2 / kernel_w; - int u3 = uv3 / kernel_w; - int u4 = uv4 / kernel_w; - int u5 = uv5 / kernel_w; - int u6 = uv6 / kernel_w; - int u7 = uv7 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - int v2 = uv2 % kernel_w; - int v3 = uv3 % kernel_w; - int v4 = uv4 % kernel_w; - int v5 = uv5 % kernel_w; - int v6 = uv6 % kernel_w; - int v7 = uv7 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - const Mat img2 = bottom_blob.channel(p2); - const Mat img3 = bottom_blob.channel(p3); - const Mat img4 = bottom_blob.channel(p4); - const Mat img5 = bottom_blob.channel(p5); - const Mat img6 = bottom_blob.channel(p6); - const Mat img7 = bottom_blob.channel(p7); - - int x00 = dx0 + dilation_w * v0; - int x01 = dx1 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - int y01 = dy1 + dilation_h * u0; - int x10 = dx0 + dilation_w * v1; - int x11 = dx1 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - int y11 = dy1 + dilation_h * u1; - - int x20 = dx0 + dilation_w * v2; - int x21 = dx1 + dilation_w * v2; - int y20 = dy0 + dilation_h * u2; - int y21 = dy1 + dilation_h * u2; - int x30 = dx0 + dilation_w * v3; - int x31 = dx1 + dilation_w * v3; - int y30 = dy0 + dilation_h * u3; - int y31 = dy1 + dilation_h * u3; - - int x40 = dx0 + dilation_w * v4; - int x41 = dx1 + dilation_w * v4; - int y40 = dy0 + dilation_h * u4; - int y41 = dy1 + dilation_h * u4; - int x50 = dx0 + dilation_w * v5; - int x51 = dx1 + dilation_w * v5; - int y50 = dy0 + dilation_h * u5; - int y51 = dy1 + dilation_h * u5; - - int x60 = dx0 + dilation_w * v6; - int x61 = dx1 + dilation_w * v6; - int y60 = dy0 + dilation_h * u6; - int y61 = dy1 + dilation_h * u6; - int x70 = dx0 + dilation_w * v7; - int x71 = dx1 + dilation_w * v7; - int y70 = dy0 + dilation_h * u7; - int y71 = dy1 + dilation_h * u7; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr20 = img2.row(y20) + x20; - const signed char* sptr21 = img2.row(y21) + x21; - const signed char* sptr30 = img3.row(y30) + x30; - const signed char* sptr31 = img3.row(y31) + x31; - - const signed char* sptr40 = img4.row(y40) + x40; - const signed char* sptr41 = img4.row(y41) + x41; - const signed char* sptr50 = img5.row(y50) + x50; - const signed char* sptr51 = img5.row(y51) + x51; - const signed char* sptr60 = img6.row(y60) + x60; - const signed char* sptr61 = img6.row(y61) + x61; - const signed char* sptr70 = img7.row(y70) + x70; - const signed char* sptr71 = img7.row(y71) + x71; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr20[0]; - pp[3] = sptr30[0]; - pp[4] = sptr40[0]; - pp[5] = sptr50[0]; - pp[6] = sptr60[0]; - pp[7] = sptr70[0]; - pp[8] = sptr01[0]; - pp[9] = sptr11[0]; - pp[10] = sptr21[0]; - pp[11] = sptr31[0]; - pp[12] = sptr41[0]; - pp[13] = sptr51[0]; - pp[14] = sptr61[0]; - pp[15] = sptr71[0]; - pp += 16; - } -#endif // __ARM_FEATURE_MATMUL_INT8 - for (; kk + 3 < max_kk; kk += 4) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int p2 = (k + kk + 2) / maxk; - int p3 = (k + kk + 3) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int uv2 = (k + kk + 2) % maxk; - int uv3 = (k + kk + 3) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int u2 = uv2 / kernel_w; - int u3 = uv3 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - int v2 = uv2 % kernel_w; - int v3 = uv3 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - const Mat img2 = bottom_blob.channel(p2); - const Mat img3 = bottom_blob.channel(p3); - - int x00 = dx0 + dilation_w * v0; - int x01 = dx1 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - int y01 = dy1 + dilation_h * u0; - int x10 = dx0 + dilation_w * v1; - int x11 = dx1 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - int y11 = dy1 + dilation_h * u1; - int x20 = dx0 + dilation_w * v2; - int x21 = dx1 + dilation_w * v2; - int y20 = dy0 + dilation_h * u2; - int y21 = dy1 + dilation_h * u2; - int x30 = dx0 + dilation_w * v3; - int x31 = dx1 + dilation_w * v3; - int y30 = dy0 + dilation_h * u3; - int y31 = dy1 + dilation_h * u3; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr20 = img2.row(y20) + x20; - const signed char* sptr21 = img2.row(y21) + x21; - const signed char* sptr30 = img3.row(y30) + x30; - const signed char* sptr31 = img3.row(y31) + x31; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr20[0]; - pp[3] = sptr30[0]; - pp[4] = sptr01[0]; - pp[5] = sptr11[0]; - pp[6] = sptr21[0]; - pp[7] = sptr31[0]; - pp += 8; - } -#endif // __ARM_FEATURE_DOTPROD - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = dx0 + dilation_w * v0; - int x01 = dx1 + dilation_w * v0; - int y00 = dy0 + dilation_h * u0; - int y01 = dy1 + dilation_h * u0; - int x10 = dx0 + dilation_w * v1; - int x11 = dx1 + dilation_w * v1; - int y10 = dy0 + dilation_h * u1; - int y11 = dy1 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp += 4; - } - } -#endif // __ARM_NEON - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = dx0 + dilation_w * v; - int x1 = dx1 + dilation_w * v; - int y0 = dy0 + dilation_h * u; - int y1 = dy1 + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - -#if __ARM_NEON - if (elempack == 8) - { -#if __ARM_FEATURE_MATMUL_INT8 - int8x8_t _r0 = vld1_s8(sptr0); - int8x8_t _r1 = vld1_s8(sptr1); - vst1_s8(pp, _r0); - vst1_s8(pp + 8, _r1); - pp += 16; -#elif __ARM_FEATURE_DOTPROD - int32x2x2_t _r01; - _r01.val[0] = vreinterpret_s32_s8(vld1_s8(sptr0)); - _r01.val[1] = vreinterpret_s32_s8(vld1_s8(sptr1)); - vst2_s32((int*)pp, _r01); - pp += 16; -#else // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD - int16x4x2_t _r01; - _r01.val[0] = vreinterpret_s16_s8(vld1_s8(sptr0)); - _r01.val[1] = vreinterpret_s16_s8(vld1_s8(sptr1)); - vst2_s16((short*)pp, _r01); - pp += 16; -#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD - } -#endif // __ARM_NEON - if (elempack == 1) - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp += 2; - } - } - } - } - for (; jj < max_jj; jj++) - { - int dy = (j + jj) / outw * stride_h; - int dx = (j + jj) % outw * stride_w; - - int kk = 0; - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x = dx + dilation_w * v; - int y = dy + dilation_h * u; - - const signed char* sptr = img.row(y) + x * elempack; - -#if __ARM_NEON - if (elempack == 8) - { - vst1_s8(pp, vld1_s8(sptr)); - pp += 8; - } -#endif // __ARM_NEON - if (elempack == 1) - { - pp[0] = sptr[0]; - pp += 1; - } - } - } -} - -static void convolution_im2col_gemm_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, int kernel_w, int kernel_h, const Option& opt) -{ -#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) -#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 - if (ncnn::cpu_support_arm_i8mm()) - { - convolution_im2col_gemm_transform_kernel_int8_i8mm(kernel, AT, inch, outch, kernel_w, kernel_h, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD - if (ncnn::cpu_support_arm_asimddp()) - { - convolution_im2col_gemm_transform_kernel_int8_asimddp(kernel, AT, inch, outch, kernel_w, kernel_h, opt); - return; - } -#endif -#endif - - // NCNN_LOGE("convolution_im2col_gemm_transform_kernel"); - const int maxk = kernel_w * kernel_h; - - const int M = outch; - const int K = inch * maxk; - - int TILE_M, TILE_N, TILE_K; - convolution_im2col_gemm_get_optimal_tile_mnk_int8(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); - - const int nn_M = (M + TILE_M - 1) / TILE_M; - - int elempack = 1; -#if __ARM_NEON - if (opt.use_packing_layout) - { - elempack = inch % 8 == 0 ? 8 : 1; - } -#endif // __ARM_NEON - - // maxk-inch-outch to pa-maxk-inch/pa-outch - Mat A_data; - if (maxk == 1) - { - A_data = kernel.reshape(maxk * inch, outch); - } - else - { - Mat weight_data_r2 = kernel.reshape(maxk, inch, outch); - - A_data.create(maxk * inch, outch, (size_t)1u, 1); - - for (int q = 0; q < outch; q += 1) - { - signed char* g00 = A_data.row(q); - - for (int p = 0; p + (elempack - 1) < inch; p += elempack) - { - for (int k = 0; k < maxk; k++) - { - for (int i = 0; i < elempack; i++) - { - const signed char* k00 = weight_data_r2.channel(q).row(p + i); - g00[0] = k00[k]; - g00++; - } - } - } - } - } - - AT.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, (size_t)1u, 1); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int ppj = 0; ppj < nn_M; ppj++) - { - const int i = ppj * TILE_M; - - const int max_ii = std::min((M - i), TILE_M); - - for (int k = 0; k < K; k += TILE_K) - { - const int max_kk = std::min((K - k), TILE_K); - - Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); - - convolution_im2col_pack_A_tile_int8(A_data, AT_tile, i, max_ii, k, max_kk); - } - } -} - -static void convolution_im2col_gemm_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int nT, const Option& opt) -{ -#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) -#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 - if (ncnn::cpu_support_arm_i8mm()) - { - convolution_im2col_gemm_int8_i8mm(bottom_blob, top_blob, AT, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, nT, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 if (ncnn::cpu_support_arm_asimddp()) { convolution_im2col_gemm_int8_asimddp(bottom_blob, top_blob, AT, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, nT, opt); return; } -#endif #endif const int maxk = kernel_w * kernel_h; diff --git a/src/layer/arm/convolution_packed_int8.h b/src/layer/arm/convolution_packed_int8.h index 19342a0dedb..897b19ab3cf 100644 --- a/src/layer/arm/convolution_packed_int8.h +++ b/src/layer/arm/convolution_packed_int8.h @@ -12,21 +12,18 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) #if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 void convolution_transform_kernel_packed_int8_i8mm(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h); void convolution_packed_int8_i8mm(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 void convolution_transform_kernel_packed_int8_asimddp(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h); void convolution_packed_int8_asimddp(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt); #endif -#endif static void convolution_transform_kernel_packed_int8(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) { -#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) #if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 if (ncnn::cpu_support_arm_i8mm()) { @@ -35,13 +32,12 @@ static void convolution_transform_kernel_packed_int8(const Mat& kernel, Mat& ker } #endif -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 if (ncnn::cpu_support_arm_asimddp()) { convolution_transform_kernel_packed_int8_asimddp(kernel, kernel_tm, inch, outch, kernel_w, kernel_h); return; } -#endif #endif const int maxk = kernel_w * kernel_h; @@ -531,7 +527,6 @@ static void convolution_transform_kernel_packed_int8(const Mat& kernel, Mat& ker static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) { -#if !(__ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD) #if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 if (ncnn::cpu_support_arm_i8mm()) { @@ -540,13 +535,12 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const } #endif -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 if (ncnn::cpu_support_arm_asimddp()) { convolution_packed_int8_asimddp(bottom_blob, top_blob, weight_data_tm, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); return; } -#endif #endif const int w = bottom_blob.w; diff --git a/src/layer/arm/innerproduct_fp16s.h b/src/layer/arm/innerproduct_fp16s.h index d4fa9950c3c..d8c74d56a8e 100644 --- a/src/layer/arm/innerproduct_fp16s.h +++ b/src/layer/arm/innerproduct_fp16s.h @@ -12,23 +12,20 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#if !(__ARM_FEATURE_FP16_FML || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) #if NCNN_RUNTIME_CPU && NCNN_ARM82FP16FML && __aarch64__ && !__ARM_FEATURE_FP16_FML void innerproduct_pack4_fp16s_neon_asimdfhm(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_fp16, const Mat& bias_data, int activation_type, const Mat& activation_params, const Option& opt); void innerproduct_fp16s_neon_asimdfhm(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_fp16, const Mat& bias_data, int activation_type, const Mat& activation_params, const Option& opt); void innerproduct_transform_kernel_fp16s_neon_asimdfhm(const Mat& weight_data, Mat& weight_data_tm, int num_input, int num_output, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC && !__ARM_FEATURE_FP16_FML void innerproduct_pack4_fp16s_neon_asimdhp(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_fp16, const Mat& bias_data, int activation_type, const Mat& activation_params, const Option& opt); void innerproduct_fp16s_neon_asimdhp(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_fp16, const Mat& bias_data, int activation_type, const Mat& activation_params, const Option& opt); void innerproduct_transform_kernel_fp16s_neon_asimdhp(const Mat& weight_data, Mat& weight_data_tm, int num_input, int num_output, const Option& opt); #endif -#endif static void innerproduct_pack4_fp16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_fp16, const Mat& bias_data, int activation_type, const Mat& activation_params, const Option& opt) { -#if !(__ARM_FEATURE_FP16_FML || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) #if NCNN_RUNTIME_CPU && NCNN_ARM82FP16FML && __aarch64__ && !__ARM_FEATURE_FP16_FML if (ncnn::cpu_support_arm_asimdfhm()) { @@ -37,13 +34,12 @@ static void innerproduct_pack4_fp16s_neon(const Mat& bottom_blob, Mat& top_blob, } #endif -#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC && !__ARM_FEATURE_FP16_FML if (ncnn::cpu_support_arm_asimdhp()) { innerproduct_pack4_fp16s_neon_asimdhp(bottom_blob, top_blob, weight_data_fp16, bias_data, activation_type, activation_params, opt); return; } -#endif #endif const int num_input = bottom_blob.w * bottom_blob.elempack; @@ -294,7 +290,6 @@ static void innerproduct_pack4_fp16s_neon(const Mat& bottom_blob, Mat& top_blob, static void innerproduct_fp16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_fp16, const Mat& bias_data, int activation_type, const Mat& activation_params, const Option& opt) { -#if !(__ARM_FEATURE_FP16_FML || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) #if NCNN_RUNTIME_CPU && NCNN_ARM82FP16FML && __aarch64__ && !__ARM_FEATURE_FP16_FML if (ncnn::cpu_support_arm_asimdfhm()) { @@ -303,13 +298,12 @@ static void innerproduct_fp16s_neon(const Mat& bottom_blob, Mat& top_blob, const } #endif -#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC && !__ARM_FEATURE_FP16_FML if (ncnn::cpu_support_arm_asimdhp()) { innerproduct_fp16s_neon_asimdhp(bottom_blob, top_blob, weight_data_fp16, bias_data, activation_type, activation_params, opt); return; } -#endif #endif const int num_input = bottom_blob.w * bottom_blob.elempack; @@ -516,7 +510,6 @@ static void innerproduct_fp16s_neon(const Mat& bottom_blob, Mat& top_blob, const static void innerproduct_transform_kernel_fp16s_neon(const Mat& weight_data, Mat& weight_data_tm, int num_input, int num_output, const Option& opt) { -#if !(__ARM_FEATURE_FP16_FML || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) #if NCNN_RUNTIME_CPU && NCNN_ARM82FP16FML && __aarch64__ && !__ARM_FEATURE_FP16_FML if (ncnn::cpu_support_arm_asimdfhm()) { @@ -525,13 +518,12 @@ static void innerproduct_transform_kernel_fp16s_neon(const Mat& weight_data, Mat } #endif -#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC && !__ARM_FEATURE_FP16_FML if (ncnn::cpu_support_arm_asimdhp()) { innerproduct_transform_kernel_fp16s_neon_asimdhp(weight_data, weight_data_tm, num_input, num_output, opt); return; } -#endif #endif int out_elempack = 1; diff --git a/src/layer/arm/innerproduct_gemm_fp16s.h b/src/layer/arm/innerproduct_gemm_fp16s.h index 1724cfdf144..c3b36e0b933 100644 --- a/src/layer/arm/innerproduct_gemm_fp16s.h +++ b/src/layer/arm/innerproduct_gemm_fp16s.h @@ -12,19 +12,16 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#if !(__ARM_FEATURE_FP16_FML || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) #if NCNN_RUNTIME_CPU && NCNN_ARM82FP16FML && __aarch64__ && !__ARM_FEATURE_FP16_FML void innerproduct_gemm_fp16s_neon_asimdfhm(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_fp16, const Mat& bias_data, int activation_type, const Mat& activation_params, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC && !__ARM_FEATURE_FP16_FML void innerproduct_gemm_fp16s_neon_asimdhp(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_fp16, const Mat& bias_data, int activation_type, const Mat& activation_params, const Option& opt); #endif -#endif static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_fp16, const Mat& bias_data, int activation_type, const Mat& activation_params, const Option& opt) { -#if !(__ARM_FEATURE_FP16_FML || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) #if NCNN_RUNTIME_CPU && NCNN_ARM82FP16FML && __aarch64__ && !__ARM_FEATURE_FP16_FML if (ncnn::cpu_support_arm_asimdfhm()) { @@ -33,13 +30,12 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob, } #endif -#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if NCNN_RUNTIME_CPU && NCNN_ARM82 && __aarch64__ && !__ARM_FEATURE_FP16_VECTOR_ARITHMETIC && !__ARM_FEATURE_FP16_FML if (ncnn::cpu_support_arm_asimdhp()) { innerproduct_gemm_fp16s_neon_asimdhp(bottom_blob, top_blob, weight_data_fp16, bias_data, activation_type, activation_params, opt); return; } -#endif #endif const int num_input = bottom_blob.w; diff --git a/src/layer/x86/cast_bf16.h b/src/layer/x86/cast_bf16.h index 15939b926f0..5f6abfa5a9b 100644 --- a/src/layer/x86/cast_bf16.h +++ b/src/layer/x86/cast_bf16.h @@ -17,7 +17,7 @@ void cast_fp32_to_bf16_sse_avx512bf16(const Mat& bottom_blob, Mat& top_blob, con void cast_bf16_to_fp32_sse_avx512bf16(const Mat& bottom_blob, Mat& top_blob, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVX512BF16__ void cast_fp32_to_bf16_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt); void cast_bf16_to_fp32_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt); #endif @@ -32,7 +32,7 @@ static void cast_fp32_to_bf16_sse(const Mat& bottom_blob, Mat& top_blob, const O } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVX512BF16__ if (ncnn::cpu_support_x86_avx2()) { cast_fp32_to_bf16_sse_avx2(bottom_blob, top_blob, opt); @@ -104,7 +104,7 @@ static void cast_bf16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVX512BF16__ if (ncnn::cpu_support_x86_avx2()) { cast_bf16_to_fp32_sse_avx2(bottom_blob, top_blob, opt); diff --git a/src/layer/x86/convolution_3x3_winograd_int8.h b/src/layer/x86/convolution_3x3_winograd_int8.h index 94ea79d4540..bca8c754776 100644 --- a/src/layer/x86/convolution_3x3_winograd_int8.h +++ b/src/layer/x86/convolution_3x3_winograd_int8.h @@ -12,29 +12,27 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) #if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ void conv3x3s1_winograd23_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); void conv3x3s1_winograd43_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ void conv3x3s1_winograd23_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); void conv3x3s1_winograd43_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ void conv3x3s1_winograd23_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt); void conv3x3s1_winograd23_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); void conv3x3s1_winograd43_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt); void conv3x3s1_winograd43_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ void conv3x3s1_winograd23_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); void conv3x3s1_winograd43_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); #endif -#endif static void pack_A_tile_int8(const Mat& A, Mat& AT, int batch, int max_ii, int max_kk) { @@ -3430,14 +3428,12 @@ static inline void conv3x3s1_winograd23_transform_kernel_tile_int8(const Mat& ke static void conv3x3s1_winograd23_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) { -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { conv3x3s1_winograd23_transform_kernel_int8_avx2(kernel, AT, inch, outch, opt); return; } -#endif #endif const int M = outch; @@ -4430,7 +4426,6 @@ static inline void conv3x3s1_winograd23_transform_output_tile_int8(const Mat& to static void conv3x3s1_winograd23_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) #if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx512_vnni()) { @@ -4439,7 +4434,7 @@ static void conv3x3s1_winograd23_int8(const Mat& bottom_blob, Mat& top_blob, con } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx_vnni()) { conv3x3s1_winograd23_int8_avxvnni(bottom_blob, top_blob, AT, nT, opt); @@ -4447,7 +4442,7 @@ static void conv3x3s1_winograd23_int8(const Mat& bottom_blob, Mat& top_blob, con } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { conv3x3s1_winograd23_int8_avx2(bottom_blob, top_blob, AT, nT, opt); @@ -4455,13 +4450,12 @@ static void conv3x3s1_winograd23_int8(const Mat& bottom_blob, Mat& top_blob, con } #endif -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_xop()) { conv3x3s1_winograd23_int8_xop(bottom_blob, top_blob, AT, nT, opt); return; } -#endif #endif int outw = top_blob.w; @@ -4642,14 +4636,12 @@ static inline void conv3x3s1_winograd43_transform_kernel_tile_int8(const Mat& ke static void conv3x3s1_winograd43_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) { -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { conv3x3s1_winograd43_transform_kernel_int8_avx2(kernel, AT, inch, outch, opt); return; } -#endif #endif const int M = outch; @@ -6260,7 +6252,6 @@ static inline void conv3x3s1_winograd43_transform_output_tile_int8(const Mat& to static void conv3x3s1_winograd43_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) #if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx512_vnni()) { @@ -6269,7 +6260,7 @@ static void conv3x3s1_winograd43_int8(const Mat& bottom_blob, Mat& top_blob, con } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx_vnni()) { conv3x3s1_winograd43_int8_avxvnni(bottom_blob, top_blob, AT, nT, opt); @@ -6277,7 +6268,7 @@ static void conv3x3s1_winograd43_int8(const Mat& bottom_blob, Mat& top_blob, con } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { conv3x3s1_winograd43_int8_avx2(bottom_blob, top_blob, AT, nT, opt); @@ -6285,13 +6276,12 @@ static void conv3x3s1_winograd43_int8(const Mat& bottom_blob, Mat& top_blob, con } #endif -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_xop()) { conv3x3s1_winograd43_int8_xop(bottom_blob, top_blob, AT, nT, opt); return; } -#endif #endif int outw = top_blob.w; diff --git a/src/layer/x86/convolution_im2col_gemm_int8.h b/src/layer/x86/convolution_im2col_gemm_int8.h index 351987abaab..1ea276eeba1 100644 --- a/src/layer/x86/convolution_im2col_gemm_int8.h +++ b/src/layer/x86/convolution_im2col_gemm_int8.h @@ -12,24 +12,22 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) #if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ void convolution_im2col_gemm_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int nT, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ void convolution_im2col_gemm_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int nT, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ void convolution_im2col_gemm_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, int kernel_w, int kernel_h, const Option& opt); void convolution_im2col_gemm_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int nT, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ void convolution_im2col_gemm_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int nT, const Option& opt); #endif -#endif static void convolution_im2col_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) { @@ -7476,14 +7474,12 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i static void convolution_im2col_gemm_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, int kernel_w, int kernel_h, const Option& opt) { -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { convolution_im2col_gemm_transform_kernel_int8_avx2(kernel, AT, inch, outch, kernel_w, kernel_h, opt); return; } -#endif #endif // NCNN_LOGE("convolution_im2col_gemm_transform_kernel"); @@ -7558,24 +7554,23 @@ static void convolution_im2col_gemm_transform_kernel_int8(const Mat& kernel, Mat static void convolution_im2col_gemm_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int nT, const Option& opt) { -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) #if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512vnni()) + if (ncnn::cpu_support_x86_avx512_vnni()) { convolution_im2col_gemm_int8_avx512vnni(bottom_blob, top_blob, AT, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, nT, opt); return; } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avxvnni()) +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) { convolution_im2col_gemm_int8_avxvnni(bottom_blob, top_blob, AT, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, nT, opt); return; } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { convolution_im2col_gemm_int8_avx2(bottom_blob, top_blob, AT, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, nT, opt); @@ -7583,13 +7578,12 @@ static void convolution_im2col_gemm_int8(const Mat& bottom_blob, Mat& top_blob, } #endif -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_xop()) { convolution_im2col_gemm_int8_xop(bottom_blob, top_blob, AT, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, nT, opt); return; } -#endif #endif const int maxk = kernel_w * kernel_h; diff --git a/src/layer/x86/convolution_packed_int8.h b/src/layer/x86/convolution_packed_int8.h index 46c03f0ca9b..8a1659565f5 100644 --- a/src/layer/x86/convolution_packed_int8.h +++ b/src/layer/x86/convolution_packed_int8.h @@ -12,31 +12,26 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ -void convolution_transform_kernel_packed_int8_avx2(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h); -#endif - -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) #if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ void convolution_packed_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ void convolution_packed_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +void convolution_transform_kernel_packed_int8_avx2(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h); void convolution_packed_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ void convolution_packed_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt); #endif -#endif static void convolution_transform_kernel_packed_int8(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) { -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { convolution_transform_kernel_packed_int8_avx2(kernel, kernel_tm, inch, outch, kernel_w, kernel_h); @@ -880,7 +875,6 @@ static void convolution_transform_kernel_packed_int8(const Mat& kernel, Mat& ker static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) { -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) #if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx512_vnni()) { @@ -889,7 +883,7 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx_vnni()) { convolution_packed_int8_avxvnni(bottom_blob, top_blob, weight_data_tm, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); @@ -897,7 +891,7 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { convolution_packed_int8_avx2(bottom_blob, top_blob, weight_data_tm, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); @@ -905,13 +899,12 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const } #endif -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_xop()) { convolution_packed_int8_xop(bottom_blob, top_blob, weight_data_tm, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); return; } -#endif #endif const int w = bottom_blob.w;