diff --git a/src/layer/arm/gru_int8.h b/src/layer/arm/gru_int8.h
index 78696de401e..8f9051ab728 100644
--- a/src/layer/arm/gru_int8.h
+++ b/src/layer/arm/gru_int8.h
@@ -148,6 +148,33 @@ static void gru_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc
             kptr += 32;
         }
+#else
+        for (; i + 7 < size; i += 8)
+        {
+            int8x8_t _w0 = vld1_s8(weight_xc_R_0 + i);
+            int8x8_t _w1 = vld1_s8(weight_xc_R_1 + i);
+            int8x8_t _w2 = vld1_s8(weight_xc_R_2 + i);
+            int8x8_t _w3 = vld1_s8(weight_xc_R_3 + i);
+            int8x8_t _w4 = vld1_s8(weight_xc_U_0 + i);
+            int8x8_t _w5 = vld1_s8(weight_xc_U_1 + i);
+            int8x8_t _w6 = vld1_s8(weight_xc_U_2 + i);
+            int8x8_t _w7 = vld1_s8(weight_xc_U_3 + i);
+
+            int32x2x2_t _t0 = vtrn_s32(vreinterpret_s32_s8(_w0), vreinterpret_s32_s8(_w4));
+            int32x2x2_t _t1 = vtrn_s32(vreinterpret_s32_s8(_w1), vreinterpret_s32_s8(_w5));
+            int32x2x2_t _t2 = vtrn_s32(vreinterpret_s32_s8(_w2), vreinterpret_s32_s8(_w6));
+            int32x2x2_t _t3 = vtrn_s32(vreinterpret_s32_s8(_w3), vreinterpret_s32_s8(_w7));
+
+            int32x4x4_t _w;
+            _w.val[0] = vcombine_s32(_t0.val[0], _t0.val[1]);
+            _w.val[1] = vcombine_s32(_t1.val[0], _t1.val[1]);
+            _w.val[2] = vcombine_s32(_t2.val[0], _t2.val[1]);
+            _w.val[3] = vcombine_s32(_t3.val[0], _t3.val[1]);
+
+            vst4q_s32((int*)kptr, _w);
+
+            kptr += 64;
+        }
 #endif // __ARM_FEATURE_DOTPROD
         for (; i + 1 < size; i += 2)
         {
@@ -223,6 +250,33 @@ static void gru_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc
             kptr += 32;
         }
+#else
+        for (; i + 7 < num_output; i += 8)
+        {
+            int8x8_t _w0 = vld1_s8(weight_hc_R_0 + i);
+            int8x8_t _w1 = vld1_s8(weight_hc_R_1 + i);
+            int8x8_t _w2 = vld1_s8(weight_hc_R_2 + i);
+            int8x8_t _w3 = vld1_s8(weight_hc_R_3 + i);
+            int8x8_t _w4 = vld1_s8(weight_hc_U_0 + i);
+            int8x8_t _w5 = vld1_s8(weight_hc_U_1 + i);
+            int8x8_t _w6 = vld1_s8(weight_hc_U_2 + i);
+            int8x8_t _w7 = vld1_s8(weight_hc_U_3 + i);
+
+            int32x2x2_t _t0 = vtrn_s32(vreinterpret_s32_s8(_w0), vreinterpret_s32_s8(_w4));
+            int32x2x2_t _t1 = vtrn_s32(vreinterpret_s32_s8(_w1), vreinterpret_s32_s8(_w5));
+            int32x2x2_t _t2 = vtrn_s32(vreinterpret_s32_s8(_w2), vreinterpret_s32_s8(_w6));
+            int32x2x2_t _t3 = vtrn_s32(vreinterpret_s32_s8(_w3), vreinterpret_s32_s8(_w7));
+
+            int32x4x4_t _w;
+            _w.val[0] = vcombine_s32(_t0.val[0], _t0.val[1]);
+            _w.val[1] = vcombine_s32(_t1.val[0], _t1.val[1]);
+            _w.val[2] = vcombine_s32(_t2.val[0], _t2.val[1]);
+            _w.val[3] = vcombine_s32(_t3.val[0], _t3.val[1]);
+
+            vst4q_s32((int*)kptr, _w);
+
+            kptr += 64;
+        }
 #endif // __ARM_FEATURE_DOTPROD
         for (; i + 1 < num_output; i += 2)
         {
@@ -282,6 +336,15 @@ static void gru_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc
             kptr += 16;
         }
+#else
+        for (; i + 7 < num_output; i += 8)
+        {
+            vst1_s8(kptr, vld1_s8(weight_hc_N_0 + i));
+            vst1_s8(kptr + 8, vld1_s8(weight_hc_N_1 + i));
+            vst1_s8(kptr + 16, vld1_s8(weight_hc_N_2 + i));
+            vst1_s8(kptr + 24, vld1_s8(weight_hc_N_3 + i));
+            kptr += 32;
+        }
 #endif // __ARM_FEATURE_DOTPROD
         for (; i + 1 < num_output; i += 2)
         {
@@ -329,6 +392,15 @@ static void gru_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc
             kptr += 16;
         }
+#else
+        for (; i + 7 < size; i += 8)
+        {
+            vst1_s8(kptr, vld1_s8(weight_xc_N_0 + i));
+            vst1_s8(kptr + 8, vld1_s8(weight_xc_N_1 + i));
+            vst1_s8(kptr + 16, vld1_s8(weight_xc_N_2 + i));
+            vst1_s8(kptr + 24, vld1_s8(weight_xc_N_3 + i));
+            kptr += 32;
+        }
 #endif // __ARM_FEATURE_DOTPROD
         for (; i + 1 < size; i += 2)
         {
@@ -666,6 +738,71 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
             }
             _gru_Rx0 = vaddq_s32(_gru_Rx0, _sum1);
             _gru_Ux0 = vaddq_s32(_gru_Ux0, _sum2);
+#else
+            int32x4_t _sum0 = vdupq_n_s32(0);
+            int32x4_t _sum1 = vdupq_n_s32(0);
+            int32x4_t _sum2 = vdupq_n_s32(0);
+            int32x4_t _sum3 = vdupq_n_s32(0);
+            for (; i + 7 < size; i += 8)
+            {
+#if NCNN_GNU_INLINE_ASM && !__aarch64__
+                const signed char* xptr = x + i;
+
+                asm volatile(
+                    "vldm %1!, {d0-d7} \n"
+                    "vld1.s8 {d16}, [%0] \n"
+                    "vdup.32 d17, d16[0] \n"
+                    "vdup.32 d16, d16[1] \n"
+                    "vmull.s8 q4, d0, d17 \n"
+                    "vmull.s8 q5, d1, d17 \n"
+                    "vmull.s8 q6, d2, d17 \n"
+                    "vmull.s8 q7, d3, d17 \n"
+                    "vmlal.s8 q4, d4, d16 \n"
+                    "vmlal.s8 q5, d5, d16 \n"
+                    "vmlal.s8 q6, d6, d16 \n"
+                    "vmlal.s8 q7, d7, d16 \n"
+                    "vpadal.s16 %q2, q4 \n"
+                    "vpadal.s16 %q3, q5 \n"
+                    "vpadal.s16 %q4, q6 \n"
+                    "vpadal.s16 %q5, q7 \n"
+                    : "=r"(xptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3)
+                    : "0"(xptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3)
+                    : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"
+                );
+#else
+                int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i));
+                int8x8_t _xi0 = vreinterpret_s8_s32(vdup_lane_s32(_xi01, 0));
+                int8x8_t _xi1 = vreinterpret_s8_s32(vdup_lane_s32(_xi01, 1));
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+                int8x16_t _w2 = vld1q_s8(kptr + 32);
+                int8x16_t _w3 = vld1q_s8(kptr + 48);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _xi0);
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _xi0);
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _xi0);
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _xi0);
+                _s0 = vmlal_s8(_s0, vget_low_s8(_w2), _xi1);
+                _s1 = vmlal_s8(_s1, vget_high_s8(_w2), _xi1);
+                _s2 = vmlal_s8(_s2, vget_low_s8(_w3), _xi1);
+                _s3 = vmlal_s8(_s3, vget_high_s8(_w3), _xi1);
+
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 64;
+#endif
+            }
+            {
+                int32x2_t _s0 = vpadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0));
+                int32x2_t _s1 = vpadd_s32(vget_low_s32(_sum1), vget_high_s32(_sum1));
+                int32x2_t _s2 = vpadd_s32(vget_low_s32(_sum2), vget_high_s32(_sum2));
+                int32x2_t _s3 = vpadd_s32(vget_low_s32(_sum3), vget_high_s32(_sum3));
+                _gru_Rx0 = vaddq_s32(_gru_Rx0, vcombine_s32(_s0, _s1));
+                _gru_Ux0 = vaddq_s32(_gru_Ux0, vcombine_s32(_s2, _s3));
+            }
 #endif // __ARM_FEATURE_DOTPROD
             for (; i + 3 < size; i += 4)
             {
@@ -740,6 +877,71 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
             }
             _gru_Rh0 = vaddq_s32(_gru_Rh0, _sum1);
             _gru_Uh0 = vaddq_s32(_gru_Uh0, _sum2);
+#else
+            _sum0 = vdupq_n_s32(0);
+            _sum1 = vdupq_n_s32(0);
+            _sum2 = vdupq_n_s32(0);
+            _sum3 = vdupq_n_s32(0);
+            for (; i + 7 < num_output; i += 8)
+            {
+#if NCNN_GNU_INLINE_ASM && !__aarch64__
+                const signed char* hsptr = hs + i;
+
+                asm volatile(
+                    "vldm %1!, {d0-d7} \n"
+                    "vld1.s8 {d16}, [%0] \n"
+                    "vdup.32 d17, d16[0] \n"
+                    "vdup.32 d16, d16[1] \n"
+                    "vmull.s8 q4, d0, d17 \n"
+                    "vmull.s8 q5, d1, d17 \n"
+                    "vmull.s8 q6, d2, d17 \n"
+                    "vmull.s8 q7, d3, d17 \n"
+                    "vmlal.s8 q4, d4, d16 \n"
+                    "vmlal.s8 q5, d5, d16 \n"
+                    "vmlal.s8 q6, d6, d16 \n"
+                    "vmlal.s8 q7, d7, d16 \n"
+                    "vpadal.s16 %q2, q4 \n"
+                    "vpadal.s16 %q3, q5 \n"
+                    "vpadal.s16 %q4, q6 \n"
+                    "vpadal.s16 %q5, q7 \n"
+                    : "=r"(hsptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3)
+                    : "0"(hsptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3)
+                    : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"
+                );
+#else
+                int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i));
+                int8x8_t _h_cont0 = vreinterpret_s8_s32(vdup_lane_s32(_h_cont01, 0));
+                int8x8_t _h_cont1 = vreinterpret_s8_s32(vdup_lane_s32(_h_cont01, 1));
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+                int8x16_t _w2 = vld1q_s8(kptr + 32);
+                int8x16_t _w3 = vld1q_s8(kptr + 48);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _h_cont0);
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _h_cont0);
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _h_cont0);
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _h_cont0);
+                _s0 = vmlal_s8(_s0, vget_low_s8(_w2), _h_cont1);
+                _s1 = vmlal_s8(_s1, vget_high_s8(_w2), _h_cont1);
+                _s2 = vmlal_s8(_s2, vget_low_s8(_w3), _h_cont1);
+                _s3 = vmlal_s8(_s3, vget_high_s8(_w3), _h_cont1);
+
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 64;
+#endif
+            }
+            {
+                int32x2_t _s0 = vpadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0));
+                int32x2_t _s1 = vpadd_s32(vget_low_s32(_sum1), vget_high_s32(_sum1));
+                int32x2_t _s2 = vpadd_s32(vget_low_s32(_sum2), vget_high_s32(_sum2));
+                int32x2_t _s3 = vpadd_s32(vget_low_s32(_sum3), vget_high_s32(_sum3));
+                _gru_Rh0 = vaddq_s32(_gru_Rh0, vcombine_s32(_s0, _s1));
+                _gru_Uh0 = vaddq_s32(_gru_Uh0, vcombine_s32(_s2, _s3));
+            }
 #endif // __ARM_FEATURE_DOTPROD
             for (; i + 3 < num_output; i += 4)
             {
@@ -832,6 +1034,87 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
                 kptr += 32;
             }
             _gru_Nh0 = vaddq_s32(_gru_Nh0, _sum1);
+#else
+            _sum0 = vdupq_n_s32(0);
+            _sum1 = vdupq_n_s32(0);
+            _sum2 = vdupq_n_s32(0);
+            _sum3 = vdupq_n_s32(0);
+            for (; i + 15 < num_output; i += 16)
+            {
+#if NCNN_GNU_INLINE_ASM && !__aarch64__
+                const signed char* hsptr = hs + i;
+
+                asm volatile(
+                    "vldm %1!, {d0-d7} \n"
+                    "vld1.s8 {d16-d17}, [%0] \n"
+                    "vmull.s8 q4, d0, d16 \n"
+                    "vmull.s8 q5, d1, d16 \n"
+                    "vmull.s8 q6, d2, d16 \n"
+                    "vmull.s8 q7, d3, d16 \n"
+                    "vmlal.s8 q4, d4, d17 \n"
+                    "vmlal.s8 q5, d5, d17 \n"
+                    "vmlal.s8 q6, d6, d17 \n"
+                    "vmlal.s8 q7, d7, d17 \n"
+                    "vpadal.s16 %q2, q4 \n"
+                    "vpadal.s16 %q3, q5 \n"
+                    "vpadal.s16 %q4, q6 \n"
+                    "vpadal.s16 %q5, q7 \n"
+                    : "=r"(hsptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3)
+                    : "0"(hsptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3)
+                    : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"
+                );
+#else
+                int8x16_t _h_cont = vld1q_s8(hs + i);
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+                int8x16_t _w2 = vld1q_s8(kptr + 32);
+                int8x16_t _w3 = vld1q_s8(kptr + 48);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), vget_low_s8(_h_cont));
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), vget_low_s8(_h_cont));
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), vget_low_s8(_h_cont));
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), vget_low_s8(_h_cont));
+                _s0 = vmlal_s8(_s0, vget_low_s8(_w2), vget_high_s8(_h_cont));
+                _s1 = vmlal_s8(_s1, vget_high_s8(_w2), vget_high_s8(_h_cont));
+                _s2 = vmlal_s8(_s2, vget_low_s8(_w3), vget_high_s8(_h_cont));
+                _s3 = vmlal_s8(_s3, vget_high_s8(_w3), vget_high_s8(_h_cont));
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 64;
+#endif
+            }
+            for (; i + 7 < num_output; i += 8)
+            {
+                int8x8_t _h_cont = vld1_s8(hs + i);
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _h_cont);
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _h_cont);
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _h_cont);
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _h_cont);
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 32;
+            }
+            {
+                int32x4x2_t _tmp0 = vzipq_s32(_sum0, _sum1);
+                int32x4x2_t _tmp1 = vzipq_s32(_sum2, _sum3);
+                _sum0 = vcombine_s32(vget_low_s32(_tmp0.val[0]), vget_low_s32(_tmp1.val[0]));
+                _sum1 = vcombine_s32(vget_high_s32(_tmp0.val[0]), vget_high_s32(_tmp1.val[0]));
+                _sum2 = vcombine_s32(vget_low_s32(_tmp0.val[1]), vget_low_s32(_tmp1.val[1]));
+                _sum3 = vcombine_s32(vget_high_s32(_tmp0.val[1]), vget_high_s32(_tmp1.val[1]));
+            }
+            _gru_Nh0 = vaddq_s32(_gru_Nh0, _sum0);
+            _gru_Nh0 = vaddq_s32(_gru_Nh0, _sum1);
+            _gru_Nh0 = vaddq_s32(_gru_Nh0, _sum2);
+            _gru_Nh0 = vaddq_s32(_gru_Nh0, _sum3);
 #endif // __ARM_FEATURE_DOTPROD
             for (; i + 3 < num_output; i += 4)
             {
@@ -888,6 +1171,87 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
                 kptr += 32;
             }
             _gru_Nx0 = vaddq_s32(_gru_Nx0, _sum1);
+#else
+            _sum0 = vdupq_n_s32(0);
+            _sum1 = vdupq_n_s32(0);
+            _sum2 = vdupq_n_s32(0);
+            _sum3 = vdupq_n_s32(0);
+            for (; i + 15 < size; i += 16)
+            {
+#if NCNN_GNU_INLINE_ASM && !__aarch64__
+                const signed char* xptr = x + i;
+
+                asm volatile(
+                    "vldm %1!, {d0-d7} \n"
+                    "vld1.s8 {d16-d17}, [%0] \n"
+                    "vmull.s8 q4, d0, d16 \n"
+                    "vmull.s8 q5, d1, d16 \n"
+                    "vmull.s8 q6, d2, d16 \n"
+                    "vmull.s8 q7, d3, d16 \n"
+                    "vmlal.s8 q4, d4, d17 \n"
+                    "vmlal.s8 q5, d5, d17 \n"
+                    "vmlal.s8 q6, d6, d17 \n"
+                    "vmlal.s8 q7, d7, d17 \n"
+                    "vpadal.s16 %q2, q4 \n"
+                    "vpadal.s16 %q3, q5 \n"
+                    "vpadal.s16 %q4, q6 \n"
+                    "vpadal.s16 %q5, q7 \n"
+                    : "=r"(xptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3)
+                    : "0"(xptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3)
+                    : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"
+                );
+#else
+                int8x16_t _xi = vld1q_s8(x + i);
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+                int8x16_t _w2 = vld1q_s8(kptr + 32);
+                int8x16_t _w3 = vld1q_s8(kptr + 48);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), vget_low_s8(_xi));
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), vget_low_s8(_xi));
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), vget_low_s8(_xi));
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), vget_low_s8(_xi));
+                _s0 = vmlal_s8(_s0, vget_low_s8(_w2), vget_high_s8(_xi));
+                _s1 = vmlal_s8(_s1, vget_high_s8(_w2), vget_high_s8(_xi));
+                _s2 = vmlal_s8(_s2, vget_low_s8(_w3), vget_high_s8(_xi));
+                _s3 = vmlal_s8(_s3, vget_high_s8(_w3), vget_high_s8(_xi));
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 64;
+#endif
+            }
+            for (; i + 7 < size; i += 8)
+            {
+                int8x8_t _xi = vld1_s8(x + i);
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _xi);
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _xi);
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _xi);
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _xi);
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 32;
+            }
+            {
+                int32x4x2_t _tmp0 = vzipq_s32(_sum0, _sum1);
+                int32x4x2_t _tmp1 = vzipq_s32(_sum2, _sum3);
+                _sum0 = vcombine_s32(vget_low_s32(_tmp0.val[0]), vget_low_s32(_tmp1.val[0]));
+                _sum1 = vcombine_s32(vget_high_s32(_tmp0.val[0]), vget_high_s32(_tmp1.val[0]));
+                _sum2 = vcombine_s32(vget_low_s32(_tmp0.val[1]), vget_low_s32(_tmp1.val[1]));
+                _sum3 = vcombine_s32(vget_high_s32(_tmp0.val[1]), vget_high_s32(_tmp1.val[1]));
+            }
+            _gru_Nx0 = vaddq_s32(_gru_Nx0, _sum0);
+            _gru_Nx0 = vaddq_s32(_gru_Nx0, _sum1);
+            _gru_Nx0 = vaddq_s32(_gru_Nx0, _sum2);
+            _gru_Nx0 = vaddq_s32(_gru_Nx0, _sum3);
 #endif // __ARM_FEATURE_DOTPROD
             for (; i + 3 < size; i += 4)
             {
diff --git a/src/layer/arm/lstm_int8.h b/src/layer/arm/lstm_int8.h
index b58ebfde398..cd66fb5580a 100644
--- a/src/layer/arm/lstm_int8.h
+++ b/src/layer/arm/lstm_int8.h
@@ -104,6 +104,15 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             kptr[8 + 7] = weight_xc_G[i + 3];
             kptr += 16;
         }
+#else
+        for (; i + 7 < size; i += 8)
+        {
+            vst1_s8(kptr, vld1_s8(weight_xc_I + i));
+            vst1_s8(kptr + 8, vld1_s8(weight_xc_F + i));
+            vst1_s8(kptr + 16, vld1_s8(weight_xc_O + i));
+            vst1_s8(kptr + 24, vld1_s8(weight_xc_G + i));
+            kptr += 32;
+        }
 #endif // __ARM_FEATURE_DOTPROD
         for (; i + 1 < size; i += 2)
         {
@@ -150,6 +159,15 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
             kptr[8 + 7] = weight_hc_G[i + 3];
             kptr += 16;
         }
+#else
+        for (; i + 7 < num_output; i += 8)
+        {
+            vst1_s8(kptr, vld1_s8(weight_hc_I + i));
+            vst1_s8(kptr + 8, vld1_s8(weight_hc_F + i));
+            vst1_s8(kptr + 16, vld1_s8(weight_hc_O + i));
+            vst1_s8(kptr + 24, vld1_s8(weight_hc_G + i));
+            kptr += 32;
+        }
 #endif // __ARM_FEATURE_DOTPROD
         for (; i + 1 < num_output; i += 2)
         {
@@ -484,6 +502,87 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum1);
             _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum2);
             _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum3);
+#else
+            int32x4_t _sum0 = vdupq_n_s32(0);
+            int32x4_t _sum1 = vdupq_n_s32(0);
+            int32x4_t _sum2 = vdupq_n_s32(0);
+            int32x4_t _sum3 = vdupq_n_s32(0);
+            for (; i + 15 < size; i += 16)
+            {
+#if NCNN_GNU_INLINE_ASM && !__aarch64__
+                const signed char* xptr = x + i;
+
+                asm volatile(
+                    "vldm %1!, {d0-d7} \n"
+                    "vld1.s8 {d16-d17}, [%0] \n"
+                    "vmull.s8 q4, d0, d16 \n"
+                    "vmull.s8 q5, d1, d16 \n"
+                    "vmull.s8 q6, d2, d16 \n"
+                    "vmull.s8 q7, d3, d16 \n"
+                    "vmlal.s8 q4, d4, d17 \n"
+                    "vmlal.s8 q5, d5, d17 \n"
+                    "vmlal.s8 q6, d6, d17 \n"
+                    "vmlal.s8 q7, d7, d17 \n"
+                    "vpadal.s16 %q2, q4 \n"
+                    "vpadal.s16 %q3, q5 \n"
+                    "vpadal.s16 %q4, q6 \n"
+                    "vpadal.s16 %q5, q7 \n"
+                    : "=r"(xptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3)
+                    : "0"(xptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3)
+                    : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"
+                );
+#else
+                int8x16_t _xi = vld1q_s8(x + i);
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+                int8x16_t _w2 = vld1q_s8(kptr + 32);
+                int8x16_t _w3 = vld1q_s8(kptr + 48);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), vget_low_s8(_xi));
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), vget_low_s8(_xi));
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), vget_low_s8(_xi));
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), vget_low_s8(_xi));
+                _s0 = vmlal_s8(_s0, vget_low_s8(_w2), vget_high_s8(_xi));
+                _s1 = vmlal_s8(_s1, vget_high_s8(_w2), vget_high_s8(_xi));
+                _s2 = vmlal_s8(_s2, vget_low_s8(_w3), vget_high_s8(_xi));
+                _s3 = vmlal_s8(_s3, vget_high_s8(_w3), vget_high_s8(_xi));
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 64;
+#endif
+            }
+            for (; i + 7 < size; i += 8)
+            {
+                int8x8_t _xi = vld1_s8(x + i);
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _xi);
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _xi);
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _xi);
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _xi);
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 32;
+            }
+            {
+                int32x4x2_t _tmp0 = vzipq_s32(_sum0, _sum1);
+                int32x4x2_t _tmp1 = vzipq_s32(_sum2, _sum3);
+                _sum0 = vcombine_s32(vget_low_s32(_tmp0.val[0]), vget_low_s32(_tmp1.val[0]));
+                _sum1 = vcombine_s32(vget_high_s32(_tmp0.val[0]), vget_high_s32(_tmp1.val[0]));
+                _sum2 = vcombine_s32(vget_low_s32(_tmp0.val[1]), vget_low_s32(_tmp1.val[1]));
+                _sum3 = vcombine_s32(vget_high_s32(_tmp0.val[1]), vget_high_s32(_tmp1.val[1]));
+            }
+            _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum0);
+            _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum1);
+            _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum2);
+            _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum3);
 #endif // __ARM_FEATURE_DOTPROD
             for (; i + 3 < size; i += 4)
             {
@@ -558,6 +657,87 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum1);
             _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum2);
             _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum3);
+#else
+            _sum0 = vdupq_n_s32(0);
+            _sum1 = vdupq_n_s32(0);
+            _sum2 = vdupq_n_s32(0);
+            _sum3 = vdupq_n_s32(0);
+            for (; i + 15 < num_output; i += 16)
+            {
+#if NCNN_GNU_INLINE_ASM && !__aarch64__
+                const signed char* hsptr = hs + i;
+
+                asm volatile(
+                    "vldm %1!, {d0-d7} \n"
+                    "vld1.s8 {d16-d17}, [%0] \n"
+                    "vmull.s8 q4, d0, d16 \n"
+                    "vmull.s8 q5, d1, d16 \n"
+                    "vmull.s8 q6, d2, d16 \n"
+                    "vmull.s8 q7, d3, d16 \n"
+                    "vmlal.s8 q4, d4, d17 \n"
+                    "vmlal.s8 q5, d5, d17 \n"
+                    "vmlal.s8 q6, d6, d17 \n"
+                    "vmlal.s8 q7, d7, d17 \n"
+                    "vpadal.s16 %q2, q4 \n"
+                    "vpadal.s16 %q3, q5 \n"
+                    "vpadal.s16 %q4, q6 \n"
+                    "vpadal.s16 %q5, q7 \n"
+                    : "=r"(hsptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3)
+                    : "0"(hsptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3)
+                    : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"
+                );
+#else
+                int8x16_t _h_cont = vld1q_s8(hs + i);
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+                int8x16_t _w2 = vld1q_s8(kptr + 32);
+                int8x16_t _w3 = vld1q_s8(kptr + 48);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), vget_low_s8(_h_cont));
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), vget_low_s8(_h_cont));
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), vget_low_s8(_h_cont));
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), vget_low_s8(_h_cont));
+                _s0 = vmlal_s8(_s0, vget_low_s8(_w2), vget_high_s8(_h_cont));
+                _s1 = vmlal_s8(_s1, vget_high_s8(_w2), vget_high_s8(_h_cont));
+                _s2 = vmlal_s8(_s2, vget_low_s8(_w3), vget_high_s8(_h_cont));
+                _s3 = vmlal_s8(_s3, vget_high_s8(_w3), vget_high_s8(_h_cont));
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 64;
+#endif
+            }
+            for (; i + 7 < num_output; i += 8)
+            {
+                int8x8_t _h_cont = vld1_s8(hs + i);
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _h_cont);
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _h_cont);
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _h_cont);
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _h_cont);
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 32;
+            }
+            {
+                int32x4x2_t _tmp0 = vzipq_s32(_sum0, _sum1);
+                int32x4x2_t _tmp1 = vzipq_s32(_sum2, _sum3);
+                _sum0 = vcombine_s32(vget_low_s32(_tmp0.val[0]), vget_low_s32(_tmp1.val[0]));
+                _sum1 = vcombine_s32(vget_high_s32(_tmp0.val[0]), vget_high_s32(_tmp1.val[0]));
+                _sum2 = vcombine_s32(vget_low_s32(_tmp0.val[1]), vget_low_s32(_tmp1.val[1]));
+                _sum3 = vcombine_s32(vget_high_s32(_tmp0.val[1]), vget_high_s32(_tmp1.val[1]));
+            }
+            _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum0);
+            _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum1);
+            _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum2);
+            _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum3);
 #endif // __ARM_FEATURE_DOTPROD
             for (; i + 3 < num_output; i += 4)
             {
diff --git a/src/layer/arm/rnn_int8.h b/src/layer/arm/rnn_int8.h
index 0dbf849fb85..93b73d426d0 100644
--- a/src/layer/arm/rnn_int8.h
+++ b/src/layer/arm/rnn_int8.h
@@ -94,6 +94,15 @@ static void rnn_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc
             kptr += 16;
         }
+#else
+        for (; i + 7 < size; i += 8)
+        {
+            vst1_s8(kptr, vld1_s8(weight_xc_0 + i));
+            vst1_s8(kptr + 8, vld1_s8(weight_xc_1 + i));
+            vst1_s8(kptr + 16, vld1_s8(weight_xc_2 + i));
+            vst1_s8(kptr + 24, vld1_s8(weight_xc_3 + i));
+            kptr += 32;
+        }
 #endif // __ARM_FEATURE_DOTPROD
         for (; i + 1 < size; i += 2)
        {
@@ -141,6 +150,15 @@ static void rnn_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc
             kptr += 16;
         }
+#else
+        for (; i + 7 < num_output; i += 8)
+        {
+            vst1_s8(kptr, vld1_s8(weight_hc_0 + i));
+            vst1_s8(kptr + 8, vld1_s8(weight_hc_1 + i));
+            vst1_s8(kptr + 16, vld1_s8(weight_hc_2 + i));
+            vst1_s8(kptr + 24, vld1_s8(weight_hc_3 + i));
+            kptr += 32;
+        }
 #endif // __ARM_FEATURE_DOTPROD
         for (; i + 1 < num_output; i += 2)
         {
@@ -417,6 +435,87 @@ static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
             _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum1);
             _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum2);
             _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum3);
+#else
+            int32x4_t _sum0 = vdupq_n_s32(0);
+            int32x4_t _sum1 = vdupq_n_s32(0);
+            int32x4_t _sum2 = vdupq_n_s32(0);
+            int32x4_t _sum3 = vdupq_n_s32(0);
+            for (; i + 15 < size; i += 16)
+            {
+#if NCNN_GNU_INLINE_ASM && !__aarch64__
+                const signed char* xptr = x + i;
+
+                asm volatile(
+                    "vldm %1!, {d0-d7} \n"
+                    "vld1.s8 {d16-d17}, [%0] \n"
+                    "vmull.s8 q4, d0, d16 \n"
+                    "vmull.s8 q5, d1, d16 \n"
+                    "vmull.s8 q6, d2, d16 \n"
+                    "vmull.s8 q7, d3, d16 \n"
+                    "vmlal.s8 q4, d4, d17 \n"
+                    "vmlal.s8 q5, d5, d17 \n"
+                    "vmlal.s8 q6, d6, d17 \n"
+                    "vmlal.s8 q7, d7, d17 \n"
+                    "vpadal.s16 %q2, q4 \n"
+                    "vpadal.s16 %q3, q5 \n"
+                    "vpadal.s16 %q4, q6 \n"
+                    "vpadal.s16 %q5, q7 \n"
+                    : "=r"(xptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3)
+                    : "0"(xptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3)
+                    : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"
+                );
+#else
+                int8x16_t _xi = vld1q_s8(x + i);
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+                int8x16_t _w2 = vld1q_s8(kptr + 32);
+                int8x16_t _w3 = vld1q_s8(kptr + 48);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), vget_low_s8(_xi));
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), vget_low_s8(_xi));
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), vget_low_s8(_xi));
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), vget_low_s8(_xi));
+                _s0 = vmlal_s8(_s0, vget_low_s8(_w2), vget_high_s8(_xi));
+                _s1 = vmlal_s8(_s1, vget_high_s8(_w2), vget_high_s8(_xi));
+                _s2 = vmlal_s8(_s2, vget_low_s8(_w3), vget_high_s8(_xi));
+                _s3 = vmlal_s8(_s3, vget_high_s8(_w3), vget_high_s8(_xi));
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 64;
+#endif
+            }
+            for (; i + 7 < size; i += 8)
+            {
+                int8x8_t _xi = vld1_s8(x + i);
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _xi);
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _xi);
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _xi);
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _xi);
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 32;
+            }
+            {
+                int32x4x2_t _tmp0 = vzipq_s32(_sum0, _sum1);
+                int32x4x2_t _tmp1 = vzipq_s32(_sum2, _sum3);
+                _sum0 = vcombine_s32(vget_low_s32(_tmp0.val[0]), vget_low_s32(_tmp1.val[0]));
+                _sum1 = vcombine_s32(vget_high_s32(_tmp0.val[0]), vget_high_s32(_tmp1.val[0]));
+                _sum2 = vcombine_s32(vget_low_s32(_tmp0.val[1]), vget_low_s32(_tmp1.val[1]));
+                _sum3 = vcombine_s32(vget_high_s32(_tmp0.val[1]), vget_high_s32(_tmp1.val[1]));
+            }
+            _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum0);
+            _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum1);
+            _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum2);
+            _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum3);
 #endif // __ARM_FEATURE_DOTPROD
             for (; i + 3 < size; i += 4)
             {
@@ -491,6 +590,87 @@ static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
             _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum1);
             _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum2);
             _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum3);
+#else
+            _sum0 = vdupq_n_s32(0);
+            _sum1 = vdupq_n_s32(0);
+            _sum2 = vdupq_n_s32(0);
+            _sum3 = vdupq_n_s32(0);
+            for (; i + 15 < num_output; i += 16)
+            {
+#if NCNN_GNU_INLINE_ASM && !__aarch64__
+                const signed char* hsptr = hs + i;
+
+                asm volatile(
+                    "vldm %1!, {d0-d7} \n"
+                    "vld1.s8 {d16-d17}, [%0] \n"
+                    "vmull.s8 q4, d0, d16 \n"
+                    "vmull.s8 q5, d1, d16 \n"
+                    "vmull.s8 q6, d2, d16 \n"
+                    "vmull.s8 q7, d3, d16 \n"
+                    "vmlal.s8 q4, d4, d17 \n"
+                    "vmlal.s8 q5, d5, d17 \n"
+                    "vmlal.s8 q6, d6, d17 \n"
+                    "vmlal.s8 q7, d7, d17 \n"
+                    "vpadal.s16 %q2, q4 \n"
+                    "vpadal.s16 %q3, q5 \n"
+                    "vpadal.s16 %q4, q6 \n"
+                    "vpadal.s16 %q5, q7 \n"
+                    : "=r"(hsptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3)
+                    : "0"(hsptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3)
+                    : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"
+                );
+#else
+                int8x16_t _h_cont = vld1q_s8(hs + i);
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+                int8x16_t _w2 = vld1q_s8(kptr + 32);
+                int8x16_t _w3 = vld1q_s8(kptr + 48);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), vget_low_s8(_h_cont));
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), vget_low_s8(_h_cont));
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), vget_low_s8(_h_cont));
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), vget_low_s8(_h_cont));
+                _s0 = vmlal_s8(_s0, vget_low_s8(_w2), vget_high_s8(_h_cont));
+                _s1 = vmlal_s8(_s1, vget_high_s8(_w2), vget_high_s8(_h_cont));
+                _s2 = vmlal_s8(_s2, vget_low_s8(_w3), vget_high_s8(_h_cont));
+                _s3 = vmlal_s8(_s3, vget_high_s8(_w3), vget_high_s8(_h_cont));
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 64;
+#endif
+            }
+            for (; i + 7 < num_output; i += 8)
+            {
+                int8x8_t _h_cont = vld1_s8(hs + i);
+                int8x16_t _w0 = vld1q_s8(kptr);
+                int8x16_t _w1 = vld1q_s8(kptr + 16);
+
+                int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _h_cont);
+                int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _h_cont);
+                int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _h_cont);
+                int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _h_cont);
+                _sum0 = vpadalq_s16(_sum0, _s0);
+                _sum1 = vpadalq_s16(_sum1, _s1);
+                _sum2 = vpadalq_s16(_sum2, _s2);
+                _sum3 = vpadalq_s16(_sum3, _s3);
+
+                kptr += 32;
+            }
+            {
+                int32x4x2_t _tmp0 = vzipq_s32(_sum0, _sum1);
+                int32x4x2_t _tmp1 = vzipq_s32(_sum2, _sum3);
+                _sum0 = vcombine_s32(vget_low_s32(_tmp0.val[0]), vget_low_s32(_tmp1.val[0]));
+                _sum1 = vcombine_s32(vget_high_s32(_tmp0.val[0]), vget_high_s32(_tmp1.val[0]));
+                _sum2 = vcombine_s32(vget_low_s32(_tmp0.val[1]), vget_low_s32(_tmp1.val[1]));
+                _sum3 = vcombine_s32(vget_high_s32(_tmp0.val[1]), vget_high_s32(_tmp1.val[1]));
+            }
+            _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum0);
+            _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum1);
+            _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum2);
+            _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum3);
 #endif // __ARM_FEATURE_DOTPROD
             for (; i + 3 < num_output; i += 4)
             {
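For context, here is a minimal standalone sketch (not part of the patch) of the non-dotprod accumulation pattern these kernels fall back to when SDOT is unavailable: widen the int8 products with `vmull_s8`, then fold the 16-bit products into 32-bit accumulators with `vpadalq_s16`, mirroring the `vmull.s8`/`vmlal.s8` + `vpadal.s16` sequences in the inline assembly above. The function and variable names are illustrative only, not taken from ncnn.

```cpp
#include <arm_neon.h>
#include <stdint.h>

// Hypothetical helper: int8 dot product without the dot-product extension.
// vmull_s8 widens to 16-bit products; vpadalq_s16 pairwise-adds them into
// 32-bit lanes so the accumulator cannot overflow for realistic lengths.
static int32_t dot_s8(const int8_t* a, const int8_t* b, int n)
{
    int32x4_t _sum = vdupq_n_s32(0);
    int i = 0;
    for (; i + 15 < n; i += 16)
    {
        int8x16_t _a = vld1q_s8(a + i);
        int8x16_t _b = vld1q_s8(b + i);
        int16x8_t _p0 = vmull_s8(vget_low_s8(_a), vget_low_s8(_b));
        int16x8_t _p1 = vmull_s8(vget_high_s8(_a), vget_high_s8(_b));
        _sum = vpadalq_s16(_sum, _p0);
        _sum = vpadalq_s16(_sum, _p1);
    }
    // horizontal sum of the vector accumulator, then a scalar tail
    int32_t sum = vgetq_lane_s32(_sum, 0) + vgetq_lane_s32(_sum, 1)
                + vgetq_lane_s32(_sum, 2) + vgetq_lane_s32(_sum, 3);
    for (; i < n; i++)
        sum += (int32_t)a[i] * (int32_t)b[i];
    return sum;
}
```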
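The zip-and-recombine epilogue that follows each 16/8-wide loop amounts to taking the horizontal sum of each of the four int32 accumulators (one per gate or output lane) and packing the four results into a single vector before adding it to the gate accumulator. A hedged sketch of that reduction in isolation, with names of my own choosing:

```cpp
#include <arm_neon.h>

// Reduce four int32x4 accumulators to one vector of their horizontal sums:
// result = { hsum(a), hsum(b), hsum(c), hsum(d) }.
// Same vzipq_s32/vcombine_s32 shuffle the patch performs before the final adds.
static int32x4_t hsum4_s32(int32x4_t a, int32x4_t b, int32x4_t c, int32x4_t d)
{
    int32x4x2_t t0 = vzipq_s32(a, b); // [a0 b0 a1 b1], [a2 b2 a3 b3]
    int32x4x2_t t1 = vzipq_s32(c, d); // [c0 d0 c1 d1], [c2 d2 c3 d3]
    int32x4_t r0 = vcombine_s32(vget_low_s32(t0.val[0]), vget_low_s32(t1.val[0]));   // [a0 b0 c0 d0]
    int32x4_t r1 = vcombine_s32(vget_high_s32(t0.val[0]), vget_high_s32(t1.val[0])); // [a1 b1 c1 d1]
    int32x4_t r2 = vcombine_s32(vget_low_s32(t0.val[1]), vget_low_s32(t1.val[1]));   // [a2 b2 c2 d2]
    int32x4_t r3 = vcombine_s32(vget_high_s32(t0.val[1]), vget_high_s32(t1.val[1])); // [a3 b3 c3 d3]
    return vaddq_s32(vaddq_s32(r0, r1), vaddq_s32(r2, r3));
}
```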