Skip to content

Commit e22a695

Browse files
committed
wip
1 parent cd0c719 commit e22a695

File tree

1 file changed

+82
-36
lines changed

1 file changed

+82
-36
lines changed

src/layer/arm/gru_arm_asimdhp.cpp

+82-36
Original file line numberDiff line numberDiff line change
@@ -1148,27 +1148,40 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
11481148
int i = 0;
11491149
for (; i + 3 < size; i += 4)
11501150
{
1151-
#if 0 //NCNN_GNU_INLINE_ASM
1151+
#if NCNN_GNU_INLINE_ASM
11521152
asm volatile(
1153+
"ld1 {v6.16b, v7.16b}, [%1], #32 \n"
11531154
"ld1 {v4.4h}, [%0], #8 \n"
1154-
"ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
1155+
"sxtl v0.8h, v6.8b \n"
1156+
"sxtl2 v1.8h, v6.16b \n"
1157+
"sxtl v2.8h, v7.8b \n"
1158+
"sxtl2 v3.8h, v7.16b \n"
1159+
"scvtf v0.8h, v0.8h \n"
1160+
"scvtf v1.8h, v1.8h \n"
1161+
"scvtf v2.8h, v2.8h \n"
1162+
"scvtf v3.8h, v3.8h \n"
1163+
"fmul v0.8h, v0.8h, %12.8h \n"
1164+
"fmul v1.8h, v1.8h, %12.8h \n"
1165+
"fmul v2.8h, v2.8h, %12.8h \n"
1166+
"fmul v3.8h, v3.8h, %12.8h \n"
11551167
"fmla %2.8h, v0.8h, v4.h[0] \n"
11561168
"fmla %3.8h, v1.8h, v4.h[1] \n"
11571169
"fmla %4.8h, v2.8h, v4.h[2] \n"
11581170
"fmla %5.8h, v3.8h, v4.h[3] \n"
11591171
: "=r"(x),
1160-
"=r"(weight_xc_RUN),
1172+
"=r"(weight_xc_int8_RUN),
11611173
"=w"(_RU),
11621174
"=w"(_sum1),
11631175
"=w"(_sum2),
11641176
"=w"(_sum3)
11651177
: "0"(x),
1166-
"1"(weight_xc_RUN),
1178+
"1"(weight_xc_int8_RUN),
11671179
"2"(_RU),
11681180
"3"(_sum1),
11691181
"4"(_sum2),
1170-
"5"(_sum3)
1171-
: "memory", "v0", "v1", "v2", "v3", "v4");
1182+
"5"(_sum3),
1183+
"w"(_descale_xc_RU)
1184+
: "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7");
11721185
#else // NCNN_GNU_INLINE_ASM
11731186
float16x4_t _x = vld1_f16(x);
11741187

@@ -1207,28 +1220,41 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
12071220
i = 0;
12081221
for (; i + 3 < num_output; i += 4)
12091222
{
1210-
#if 0 //NCNN_GNU_INLINE_ASM
1223+
#if NCNN_GNU_INLINE_ASM
12111224
asm volatile(
1225+
"ld1 {v6.8h, v7.8h}, [%1], #32 \n"
12121226
"ld1 {v4.4s}, [%0], #16 \n"
1213-
"ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
1227+
"sxtl v0.8h, v6.8b \n"
1228+
"sxtl2 v1.8h, v6.16b \n"
1229+
"sxtl v2.8h, v7.8b \n"
1230+
"sxtl2 v3.8h, v7.16b \n"
1231+
"scvtf v0.8h, v0.8h \n"
1232+
"scvtf v1.8h, v1.8h \n"
1233+
"scvtf v2.8h, v2.8h \n"
1234+
"scvtf v3.8h, v3.8h \n"
12141235
"fcvtn v4.4h, v4.4s \n"
1236+
"fmul v0.8h, v0.8h, %12.8h \n"
1237+
"fmul v1.8h, v1.8h, %12.8h \n"
1238+
"fmul v2.8h, v2.8h, %12.8h \n"
1239+
"fmul v3.8h, v3.8h, %12.8h \n"
12151240
"fmla %2.8h, v0.8h, v4.h[0] \n"
12161241
"fmla %3.8h, v1.8h, v4.h[1] \n"
12171242
"fmla %4.8h, v2.8h, v4.h[2] \n"
12181243
"fmla %5.8h, v3.8h, v4.h[3] \n"
12191244
: "=r"(hidden_ptr),
1220-
"=r"(weight_hc_RUN),
1245+
"=r"(weight_hc_int8_RUN),
12211246
"=w"(_RU),
12221247
"=w"(_sum1),
12231248
"=w"(_sum2),
12241249
"=w"(_sum3)
12251250
: "0"(hidden_ptr),
1226-
"1"(weight_hc_RUN),
1251+
"1"(weight_hc_int8_RUN),
12271252
"2"(_RU),
12281253
"3"(_sum1),
12291254
"4"(_sum2),
1230-
"5"(_sum3)
1231-
: "memory", "v0", "v1", "v2", "v3", "v4");
1255+
"5"(_sum3),
1256+
"w"(_descale_hc_RU)
1257+
: "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7");
12321258
#else // NCNN_GNU_INLINE_ASM
12331259
float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr));
12341260

@@ -1282,43 +1308,54 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
12821308

12831309
float16x4_t _descale_xc_N = vld1_f16(weight_xc_int8_descales_RUN + 8);
12841310
float16x4_t _descale_hc_N = vld1_f16(weight_hc_int8_descales_RUN + 8);
1311+
float16x8_t _descale_xc_NN = vcombine_f16(_descale_xc_N, _descale_xc_N);
1312+
float16x8_t _descale_hc_NN = vcombine_f16(_descale_hc_N, _descale_hc_N);
12851313

12861314
i = 0;
12871315
for (; i + 3 < num_output; i += 4)
12881316
{
1289-
#if 0 //NCNN_GNU_INLINE_ASM
1317+
#if NCNN_GNU_INLINE_ASM
12901318
asm volatile(
1319+
"ld1 {v5.16b}, [%1], #16 \n"
12911320
"ld1 {v4.4s}, [%0], #16 \n"
1292-
"ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
1321+
"sxtl v0.8h, v5.8b \n"
1322+
"sxtl2 v2.8h, v5.16b \n"
1323+
"scvtf v0.8h, v0.8h \n"
1324+
"scvtf v2.8h, v2.8h \n"
12931325
"fcvtn v4.4h, v4.4s \n"
1326+
"fmul v0.8h, v0.8h, %12.8h \n"
1327+
"fmul v2.8h, v2.8h, %12.8h \n"
1328+
"mov v1.d[0], v0.d[1] \n"
1329+
"mov v3.d[0], v2.d[1] \n"
12941330
"fmla %2.4h, v0.4h, v4.h[0] \n"
12951331
"fmla %3.4h, v1.4h, v4.h[1] \n"
12961332
"fmla %4.4h, v2.4h, v4.h[2] \n"
12971333
"fmla %5.4h, v3.4h, v4.h[3] \n"
12981334
: "=r"(hidden_ptr),
1299-
"=r"(weight_hc_RUN),
1335+
"=r"(weight_hc_int8_RUN),
13001336
"=w"(_gru_N),
13011337
"=w"(_sum4),
13021338
"=w"(_sum5),
13031339
"=w"(_sum6)
13041340
: "0"(hidden_ptr),
1305-
"1"(weight_hc_RUN),
1341+
"1"(weight_hc_int8_RUN),
13061342
"2"(_gru_N),
13071343
"3"(_sum4),
13081344
"4"(_sum5),
1309-
"5"(_sum6)
1310-
: "memory", "v0", "v1", "v2", "v3", "v4");
1345+
"5"(_sum6),
1346+
"w"(_descale_hc_NN)
1347+
: "memory", "v0", "v1", "v2", "v3", "v4", "v5");
13111348
#else // NCNN_GNU_INLINE_ASM
13121349
float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr));
13131350

13141351
int8x16_t _weight_hc_N0123 = vld1q_s8(weight_hc_int8_RUN);
1315-
float16x8_t _weight_hc_N01 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_N0123)));
1316-
float16x8_t _weight_hc_N23 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_N0123)));
1352+
float16x8_t _weight_hc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_N0123))), _descale_hc_NN);
1353+
float16x8_t _weight_hc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_N0123))), _descale_hc_NN);
13171354

1318-
float16x4_t _w0 = vmul_f16(vget_low_s16(_weight_hc_N01), _descale_hc_N);
1319-
float16x4_t _w1 = vmul_f16(vget_high_f16(_weight_hc_N01), _descale_hc_N);
1320-
float16x4_t _w2 = vmul_f16(vget_low_f16(_weight_hc_N23), _descale_hc_N);
1321-
float16x4_t _w3 = vmul_f16(vget_high_f16(_weight_hc_N23), _descale_hc_N);
1355+
float16x4_t _w0 = vget_low_f16(_weight_hc_N01);
1356+
float16x4_t _w1 = vget_high_f16(_weight_hc_N01);
1357+
float16x4_t _w2 = vget_low_f16(_weight_hc_N23);
1358+
float16x4_t _w3 = vget_high_f16(_weight_hc_N23);
13221359

13231360
_gru_N = vfma_lane_f16(_gru_N, _w0, _h_cont, 0);
13241361
_sum4 = vfma_lane_f16(_sum4, _w1, _h_cont, 1);
@@ -1352,38 +1389,47 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
13521389
i = 0;
13531390
for (; i + 3 < size; i += 4)
13541391
{
1355-
#if 0 //NCNN_GNU_INLINE_ASM
1392+
#if NCNN_GNU_INLINE_ASM
13561393
asm volatile(
1394+
"ld1 {v5.16b}, [%1], #16 \n"
13571395
"ld1 {v4.4h}, [%0], #8 \n"
1358-
"ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
1396+
"sxtl v0.8h, v5.8b \n"
1397+
"sxtl2 v2.8h, v5.16b \n"
1398+
"scvtf v0.8h, v0.8h \n"
1399+
"scvtf v2.8h, v2.8h \n"
1400+
"fmul v0.8h, v0.8h, %12.8h \n"
1401+
"fmul v2.8h, v2.8h, %12.8h \n"
1402+
"mov v1.d[0], v0.d[1] \n"
1403+
"mov v3.d[0], v2.d[1] \n"
13591404
"fmla %2.4h, v0.4h, v4.h[0] \n"
13601405
"fmla %3.4h, v1.4h, v4.h[1] \n"
13611406
"fmla %4.4h, v2.4h, v4.h[2] \n"
13621407
"fmla %5.4h, v3.4h, v4.h[3] \n"
13631408
: "=r"(x),
1364-
"=r"(weight_xc_RUN),
1409+
"=r"(weight_xc_int8_RUN),
13651410
"=w"(_gru_N),
13661411
"=w"(_sum4),
13671412
"=w"(_sum5),
13681413
"=w"(_sum6)
13691414
: "0"(x),
1370-
"1"(weight_xc_RUN),
1415+
"1"(weight_xc_int8_RUN),
13711416
"2"(_gru_N),
13721417
"3"(_sum4),
13731418
"4"(_sum5),
1374-
"5"(_sum6)
1375-
: "memory", "v0", "v1", "v2", "v3", "v4");
1419+
"5"(_sum6),
1420+
"w"(_descale_xc_NN)
1421+
: "memory", "v0", "v1", "v2", "v3", "v4", "v5");
13761422
#else // NCNN_GNU_INLINE_ASM
13771423
float16x4_t _x = vld1_f16(x);
13781424

13791425
int8x16_t _weight_xc_N0123 = vld1q_s8(weight_xc_int8_RUN);
1380-
float16x8_t _weight_xc_N01 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_N0123)));
1381-
float16x8_t _weight_xc_N23 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_N0123)));
1426+
float16x8_t _weight_xc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_N0123))), _descale_xc_NN);
1427+
float16x8_t _weight_xc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_N0123))), _descale_xc_NN);
13821428

1383-
float16x4_t _w0 = vmul_f16(vget_low_s16(_weight_xc_N01), _descale_xc_N);
1384-
float16x4_t _w1 = vmul_f16(vget_high_f16(_weight_xc_N01), _descale_xc_N);
1385-
float16x4_t _w2 = vmul_f16(vget_low_f16(_weight_xc_N23), _descale_xc_N);
1386-
float16x4_t _w3 = vmul_f16(vget_high_f16(_weight_xc_N23), _descale_xc_N);
1429+
float16x4_t _w0 = vget_low_f16(_weight_xc_N01);
1430+
float16x4_t _w1 = vget_high_f16(_weight_xc_N01);
1431+
float16x4_t _w2 = vget_low_f16(_weight_xc_N23);
1432+
float16x4_t _w3 = vget_high_f16(_weight_xc_N23);
13871433

13881434
_gru_N = vfma_lane_f16(_gru_N, _w0, _x, 0);
13891435
_sum4 = vfma_lane_f16(_sum4, _w1, _x, 1);

0 commit comments

Comments
 (0)