@@ -1148,27 +1148,40 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
         int i = 0;
         for (; i + 3 < size; i += 4)
         {
-#if 0 // NCNN_GNU_INLINE_ASM
+#if NCNN_GNU_INLINE_ASM
             asm volatile(
+                "ld1    {v6.16b, v7.16b}, [%1], #32 \n"
                 "ld1    {v4.4h}, [%0], #8       \n"
-                "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
+                "sxtl   v0.8h, v6.8b            \n"
+                "sxtl2  v1.8h, v6.16b           \n"
+                "sxtl   v2.8h, v7.8b            \n"
+                "sxtl2  v3.8h, v7.16b           \n"
+                "scvtf  v0.8h, v0.8h            \n"
+                "scvtf  v1.8h, v1.8h            \n"
+                "scvtf  v2.8h, v2.8h            \n"
+                "scvtf  v3.8h, v3.8h            \n"
+                "fmul   v0.8h, v0.8h, %12.8h    \n"
+                "fmul   v1.8h, v1.8h, %12.8h    \n"
+                "fmul   v2.8h, v2.8h, %12.8h    \n"
+                "fmul   v3.8h, v3.8h, %12.8h    \n"

                 "fmla   %2.8h, v0.8h, v4.h[0]   \n"
                 "fmla   %3.8h, v1.8h, v4.h[1]   \n"
                 "fmla   %4.8h, v2.8h, v4.h[2]   \n"
                 "fmla   %5.8h, v3.8h, v4.h[3]   \n"
                 : "=r"(x),
-                "=r"(weight_xc_RUN),
+                "=r"(weight_xc_int8_RUN),
                 "=w"(_RU),
                 "=w"(_sum1),
                 "=w"(_sum2),
                 "=w"(_sum3)
                 : "0"(x),
-                "1"(weight_xc_RUN),
+                "1"(weight_xc_int8_RUN),
                 "2"(_RU),
                 "3"(_sum1),
                 "4"(_sum2),
-                "5"(_sum3)
-                : "memory", "v0", "v1", "v2", "v3", "v4");
+                "5"(_sum3),
+                "w"(_descale_xc_RU)
+                : "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7");
 #else // NCNN_GNU_INLINE_ASM
             float16x4_t _x = vld1_f16(x);
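The new asm path above dequantizes inline: 32 int8 weights are loaded, sign-extended (sxtl/sxtl2), converted to fp16 (scvtf), and scaled by the descale vector bound to operand %12 before feeding the unchanged fmla chain. A minimal standalone sketch of that load-and-dequantize pattern, mirroring the intrinsics in the #else branch (the helper name and scaffolding are illustrative and not part of this patch; it assumes an ARMv8.2-A FP16 toolchain, e.g. -march=armv8.2-a+fp16):

#include <arm_neon.h>

// Dequantize 16 int8 weights into two float16x8_t vectors with a shared
// per-lane descale, matching the ld1 + sxtl/sxtl2 + scvtf + fmul sequence.
static inline void dequant16_s8_to_f16(const int8_t* w, float16x8_t descale,
                                       float16x8_t* lo, float16x8_t* hi)
{
    int8x16_t q = vld1q_s8(w);                     // ld1 {v.16b}
    int16x8_t s_lo = vmovl_s8(vget_low_s8(q));     // sxtl
    int16x8_t s_hi = vmovl_s8(vget_high_s8(q));    // sxtl2
    *lo = vmulq_f16(vcvtq_f16_s16(s_lo), descale); // scvtf + fmul
    *hi = vmulq_f16(vcvtq_f16_s16(s_hi), descale); // scvtf + fmul
}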
@@ -1207,28 +1220,41 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
         i = 0;
         for (; i + 3 < num_output; i += 4)
         {
-#if 0 // NCNN_GNU_INLINE_ASM
+#if NCNN_GNU_INLINE_ASM
             asm volatile(
+                "ld1    {v6.8h, v7.8h}, [%1], #32 \n"
                 "ld1    {v4.4s}, [%0], #16      \n"
-                "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
+                "sxtl   v0.8h, v6.8b            \n"
+                "sxtl2  v1.8h, v6.16b           \n"
+                "sxtl   v2.8h, v7.8b            \n"
+                "sxtl2  v3.8h, v7.16b           \n"
+                "scvtf  v0.8h, v0.8h            \n"
+                "scvtf  v1.8h, v1.8h            \n"
+                "scvtf  v2.8h, v2.8h            \n"
+                "scvtf  v3.8h, v3.8h            \n"
                 "fcvtn  v4.4h, v4.4s            \n"
+                "fmul   v0.8h, v0.8h, %12.8h    \n"
+                "fmul   v1.8h, v1.8h, %12.8h    \n"
+                "fmul   v2.8h, v2.8h, %12.8h    \n"
+                "fmul   v3.8h, v3.8h, %12.8h    \n"

                 "fmla   %2.8h, v0.8h, v4.h[0]   \n"
                 "fmla   %3.8h, v1.8h, v4.h[1]   \n"
                 "fmla   %4.8h, v2.8h, v4.h[2]   \n"
                 "fmla   %5.8h, v3.8h, v4.h[3]   \n"
                 : "=r"(hidden_ptr),
-                "=r"(weight_hc_RUN),
+                "=r"(weight_hc_int8_RUN),
                 "=w"(_RU),
                 "=w"(_sum1),
                 "=w"(_sum2),
                 "=w"(_sum3)
                 : "0"(hidden_ptr),
-                "1"(weight_hc_RUN),
+                "1"(weight_hc_int8_RUN),
                 "2"(_RU),
                 "3"(_sum1),
                 "4"(_sum2),
-                "5"(_sum3)
-                : "memory", "v0", "v1", "v2", "v3", "v4");
+                "5"(_sum3),
+                "w"(_descale_hc_RU)
+                : "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7");
 #else // NCNN_GNU_INLINE_ASM
             float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr));
@@ -1282,43 +1308,54 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c

         float16x4_t _descale_xc_N = vld1_f16(weight_xc_int8_descales_RUN + 8);
         float16x4_t _descale_hc_N = vld1_f16(weight_hc_int8_descales_RUN + 8);
+        float16x8_t _descale_xc_NN = vcombine_f16(_descale_xc_N, _descale_xc_N);
+        float16x8_t _descale_hc_NN = vcombine_f16(_descale_hc_N, _descale_hc_N);

         i = 0;
         for (; i + 3 < num_output; i += 4)
         {
-#if 0 // NCNN_GNU_INLINE_ASM
+#if NCNN_GNU_INLINE_ASM
             asm volatile(
+                "ld1    {v5.16b}, [%1], #16     \n"
                 "ld1    {v4.4s}, [%0], #16      \n"
-                "ld1    {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
+                "sxtl   v0.8h, v5.8b            \n"
+                "sxtl2  v2.8h, v5.16b           \n"
+                "scvtf  v0.8h, v0.8h            \n"
+                "scvtf  v2.8h, v2.8h            \n"
                 "fcvtn  v4.4h, v4.4s            \n"
+                "fmul   v0.8h, v0.8h, %12.8h    \n"
+                "fmul   v2.8h, v2.8h, %12.8h    \n"
+                "mov    v1.d[0], v0.d[1]        \n"
+                "mov    v3.d[0], v2.d[1]        \n"

                 "fmla   %2.4h, v0.4h, v4.h[0]   \n"
                 "fmla   %3.4h, v1.4h, v4.h[1]   \n"
                 "fmla   %4.4h, v2.4h, v4.h[2]   \n"
                 "fmla   %5.4h, v3.4h, v4.h[3]   \n"
                 : "=r"(hidden_ptr),
-                "=r"(weight_hc_RUN),
+                "=r"(weight_hc_int8_RUN),
                 "=w"(_gru_N),
                 "=w"(_sum4),
                 "=w"(_sum5),
                 "=w"(_sum6)
                 : "0"(hidden_ptr),
-                "1"(weight_hc_RUN),
+                "1"(weight_hc_int8_RUN),
                 "2"(_gru_N),
                 "3"(_sum4),
                 "4"(_sum5),
-                "5"(_sum6)
-                : "memory", "v0", "v1", "v2", "v3", "v4");
+                "5"(_sum6),
+                "w"(_descale_hc_NN)
+                : "memory", "v0", "v1", "v2", "v3", "v4", "v5");
 #else // NCNN_GNU_INLINE_ASM
             float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr));

             int8x16_t _weight_hc_N0123 = vld1q_s8(weight_hc_int8_RUN);
-            float16x8_t _weight_hc_N01 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_N0123)));
-            float16x8_t _weight_hc_N23 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_N0123)));
+            float16x8_t _weight_hc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_N0123))), _descale_hc_NN);
+            float16x8_t _weight_hc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_N0123))), _descale_hc_NN);

-            float16x4_t _w0 = vmul_f16(vget_low_f16(_weight_hc_N01), _descale_hc_N);
-            float16x4_t _w1 = vmul_f16(vget_high_f16(_weight_hc_N01), _descale_hc_N);
-            float16x4_t _w2 = vmul_f16(vget_low_f16(_weight_hc_N23), _descale_hc_N);
-            float16x4_t _w3 = vmul_f16(vget_high_f16(_weight_hc_N23), _descale_hc_N);
+            float16x4_t _w0 = vget_low_f16(_weight_hc_N01);
+            float16x4_t _w1 = vget_high_f16(_weight_hc_N01);
+            float16x4_t _w2 = vget_low_f16(_weight_hc_N23);
+            float16x4_t _w3 = vget_high_f16(_weight_hc_N23);

             _gru_N = vfma_lane_f16(_gru_N, _w0, _h_cont, 0);
             _sum4 = vfma_lane_f16(_sum4, _w1, _h_cont, 1);
@@ -1352,38 +1389,47 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
         i = 0;
         for (; i + 3 < size; i += 4)
         {
-#if 0 // NCNN_GNU_INLINE_ASM
+#if NCNN_GNU_INLINE_ASM
             asm volatile(
+                "ld1    {v5.16b}, [%1], #16     \n"
                 "ld1    {v4.4h}, [%0], #8       \n"
-                "ld1    {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
+                "sxtl   v0.8h, v5.8b            \n"
+                "sxtl2  v2.8h, v5.16b           \n"
+                "scvtf  v0.8h, v0.8h            \n"
+                "scvtf  v2.8h, v2.8h            \n"
+                "fmul   v0.8h, v0.8h, %12.8h    \n"
+                "fmul   v2.8h, v2.8h, %12.8h    \n"
+                "mov    v1.d[0], v0.d[1]        \n"
+                "mov    v3.d[0], v2.d[1]        \n"

                 "fmla   %2.4h, v0.4h, v4.h[0]   \n"
                 "fmla   %3.4h, v1.4h, v4.h[1]   \n"
                 "fmla   %4.4h, v2.4h, v4.h[2]   \n"
                 "fmla   %5.4h, v3.4h, v4.h[3]   \n"
                 : "=r"(x),
-                "=r"(weight_xc_RUN),
+                "=r"(weight_xc_int8_RUN),
                 "=w"(_gru_N),
                 "=w"(_sum4),
                 "=w"(_sum5),
                 "=w"(_sum6)
                 : "0"(x),
-                "1"(weight_xc_RUN),
+                "1"(weight_xc_int8_RUN),
                 "2"(_gru_N),
                 "3"(_sum4),
                 "4"(_sum5),
-                "5"(_sum6)
-                : "memory", "v0", "v1", "v2", "v3", "v4");
+                "5"(_sum6),
+                "w"(_descale_xc_NN)
+                : "memory", "v0", "v1", "v2", "v3", "v4", "v5");
 #else // NCNN_GNU_INLINE_ASM
             float16x4_t _x = vld1_f16(x);

             int8x16_t _weight_xc_N0123 = vld1q_s8(weight_xc_int8_RUN);
-            float16x8_t _weight_xc_N01 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_N0123)));
-            float16x8_t _weight_xc_N23 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_N0123)));
+            float16x8_t _weight_xc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_N0123))), _descale_xc_NN);
+            float16x8_t _weight_xc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_N0123))), _descale_xc_NN);

-            float16x4_t _w0 = vmul_f16(vget_low_f16(_weight_xc_N01), _descale_xc_N);
-            float16x4_t _w1 = vmul_f16(vget_high_f16(_weight_xc_N01), _descale_xc_N);
-            float16x4_t _w2 = vmul_f16(vget_low_f16(_weight_xc_N23), _descale_xc_N);
-            float16x4_t _w3 = vmul_f16(vget_high_f16(_weight_xc_N23), _descale_xc_N);
+            float16x4_t _w0 = vget_low_f16(_weight_xc_N01);
+            float16x4_t _w1 = vget_high_f16(_weight_xc_N01);
+            float16x4_t _w2 = vget_low_f16(_weight_xc_N23);
+            float16x4_t _w3 = vget_high_f16(_weight_xc_N23);

             _gru_N = vfma_lane_f16(_gru_N, _w0, _x, 0);
             _sum4 = vfma_lane_f16(_sum4, _w1, _x, 1);
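In scalar terms, each unrolled step of the 8-lane loops above computes acc[j][k] += (fp16)w[j*8+k] * descale[k] * x[j] for input lane j and output lane k, with the four partial sums reduced after the loop. A hypothetical scalar reference of one tile (names are illustrative; __fp16 assumes an AArch64 compiler):

// Hypothetical scalar reference for one 4x8 tile: four input lanes against
// 32 int8 weights, dequantized per output lane; the asm keeps the four
// partial sums in separate vector accumulators (_RU/_sum1/_sum2/_sum3).
static void gemv_tile_ref(const __fp16 x[4], const signed char w[32],
                          const __fp16 descale[8], __fp16 acc[4][8])
{
    for (int j = 0; j < 4; j++) // one fmla per input lane
    {
        for (int k = 0; k < 8; k++) // 8 fp16 lanes per vector
        {
            acc[j][k] += (__fp16)w[j * 8 + k] * descale[k] * x[j];
        }
    }
}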