@@ -17,6 +17,10 @@ void lstm_transform_weight_int8_asimddp(const Mat& weight_xc, const Mat& weight_
17
17
void lstm_int8_asimddp (const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt);
18
18
#endif
19
19
20
+ #if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
21
+ void lstm_int8_gate_output_vfpv4 (const Mat& gates, const Mat& weight_hr, Mat& hidden_state, Mat& tmp_hidden_state, Mat& cell_state, Mat& top_blob, int ti, int elemtype, const Option& opt);
22
+ #endif
23
+
20
24
static void lstm_transform_weight_int8 (const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt)
21
25
{
22
26
// TODO dispatch for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -181,6 +185,193 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
181
185
}
182
186
}
183
187
188
+ static void lstm_int8_gate_output (const Mat& gates, const Mat& weight_hr, Mat& hidden_state, Mat& tmp_hidden_state, Mat& cell_state, Mat& top_blob, int ti, int elemtype, const Option& opt)
189
+ {
190
+ #if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
191
+ if (ncnn::cpu_support_arm_vfpv4 ())
192
+ {
193
+ lstm_int8_gate_output_vfpv4 (gates, weight_hr, hidden_state, tmp_hidden_state, cell_state, top_blob, ti, elemtype, opt);
194
+ return ;
195
+ }
196
+ #endif
197
+
198
+ const int num_output = top_blob.w ;
199
+ const int hidden_size = cell_state.w ;
200
+
201
+ // lstm unit
202
+ // sigmoid(I)
203
+ // sigmoid(F)
204
+ // sigmoid(O)
205
+ // tanh(G)
206
+ // c_t := f_t .* c_{t-1} + i_t .* g_t
207
+ // h_t := o_t .* tanh[c_t]
208
+ float * output_data = top_blob.row (ti);
209
+
210
+ float * cell_ptr = cell_state;
211
+ float * hidden_ptr = hidden_state;
212
+ float * tmp_hidden_ptr = tmp_hidden_state;
213
+
214
+ int remain_hidden_size_start = 0 ;
215
+ #if __ARM_NEON
216
+ int nn_hidden_size = hidden_size >> 2 ;
217
+ remain_hidden_size_start = nn_hidden_size << 2 ;
218
+
219
+ #pragma omp parallel for num_threads(opt.num_threads)
220
+ for (int qq = 0 ; qq < nn_hidden_size; qq++)
221
+ {
222
+ int q = qq * 4 ;
223
+
224
+ const float * gates_data = gates.row (q);
225
+
226
+ float32x4x4_t _IFOG_4x4 = vld4q_f32 (gates_data);
227
+
228
+ float32x4_t _lstm_I = sigmoid_ps (_IFOG_4x4.val [0 ]);
229
+ float32x4_t _lstm_F = sigmoid_ps (_IFOG_4x4.val [1 ]);
230
+ float32x4_t _lstm_O = sigmoid_ps (_IFOG_4x4.val [2 ]);
231
+ float32x4_t _lstm_G = tanh_ps (_IFOG_4x4.val [3 ]);
232
+
233
+ float32x4_t _cell2 = vaddq_f32 (vmulq_f32 (_lstm_F, vld1q_f32 (cell_ptr + q)), vmulq_f32 (_lstm_I, _lstm_G));
234
+ float32x4_t _lstm_H = vmulq_f32 (_lstm_O, tanh_ps (_cell2));
235
+
236
+ vst1q_f32 (cell_ptr + q, _cell2);
237
+
238
+ if (num_output == hidden_size)
239
+ {
240
+ vst1q_f32 (hidden_ptr + q, _lstm_H);
241
+
242
+ if (elemtype == 1 )
243
+ {
244
+ // fp32
245
+ vst1q_f32 (output_data + q, _lstm_H);
246
+ }
247
+ if (elemtype == 2 )
248
+ {
249
+ // fp16
250
+ unsigned short * outptr = (unsigned short *)output_data + q;
251
+ #if (__ARM_FP & 2)
252
+ #if NCNN_GNU_INLINE_ASM
253
+ #if __aarch64__
254
+ asm volatile (
255
+ " fcvtn v0.4h, %2.4s \n "
256
+ " st1 {v0.4h}, [%0] \n "
257
+ : " =r" (outptr) // %0
258
+ : " 0" (outptr),
259
+ " w" (_lstm_H)
260
+ : " memory" , " v0" );
261
+ #else // __aarch64__
262
+ asm volatile (
263
+ " vcvt.f16.f32 d0, %q2 \n "
264
+ " vst1.u16 {d0}, [%0] \n "
265
+ : " =r" (outptr) // %0
266
+ : " 0" (outptr),
267
+ " w" (_lstm_H)
268
+ : " memory" , " q0" );
269
+ #endif // __aarch64__
270
+ #else // NCNN_GNU_INLINE_ASM
271
+ vst1_u16 (outptr, (uint16x4_t )vcvt_f16_f32 (_lstm_H));
272
+ #endif // NCNN_GNU_INLINE_ASM
273
+ #else
274
+ outptr[q] = float32_to_float16 (hidden_ptr[q]);
275
+ outptr[q + 1 ] = float32_to_float16 (hidden_ptr[q + 1 ]);
276
+ outptr[q + 2 ] = float32_to_float16 (hidden_ptr[q + 2 ]);
277
+ outptr[q + 3 ] = float32_to_float16 (hidden_ptr[q + 3 ]);
278
+ #endif // (__ARM_FP & 2)
279
+ }
280
+ if (elemtype == 4 )
281
+ {
282
+ // bf16
283
+ vst1_u16 ((unsigned short *)output_data + q, float2bfloat (_lstm_H));
284
+ }
285
+ }
286
+ else
287
+ {
288
+ vst1q_f32 (tmp_hidden_ptr + q, _lstm_H);
289
+ }
290
+ }
291
+ #endif // __ARM_NEON
292
+ #pragma omp parallel for num_threads(opt.num_threads)
293
+ for (int q = remain_hidden_size_start; q < hidden_size; q++)
294
+ {
295
+ const float * gates_data = gates.row (q);
296
+
297
+ float I = gates_data[0 ];
298
+ float F = gates_data[1 ];
299
+ float O = gates_data[2 ];
300
+ float G = gates_data[3 ];
301
+
302
+ I = 1 .f / (1 .f + expf (-I));
303
+ F = 1 .f / (1 .f + expf (-F));
304
+ O = 1 .f / (1 .f + expf (-O));
305
+ G = tanhf (G);
306
+
307
+ float cell2 = F * cell_ptr[q] + I * G;
308
+ float H = O * tanhf (cell2);
309
+
310
+ cell_ptr[q] = cell2;
311
+ if (num_output == hidden_size)
312
+ {
313
+ hidden_ptr[q] = H;
314
+
315
+ if (elemtype == 1 )
316
+ {
317
+ output_data[q] = H;
318
+ }
319
+ if (elemtype == 2 )
320
+ {
321
+ ((unsigned short *)output_data)[q] = float32_to_float16 (H);
322
+ }
323
+ if (elemtype == 4 )
324
+ {
325
+ ((unsigned short *)output_data)[q] = float32_to_bfloat16 (H);
326
+ }
327
+ }
328
+ else
329
+ {
330
+ tmp_hidden_ptr[q] = H;
331
+ }
332
+ }
333
+
334
+ if (num_output != hidden_size)
335
+ {
336
+ // int nn_num_output = num_output >> 2;
337
+ // int remain_num_output_start = nn_num_output << 2;
338
+ // #pragma omp parallel for num_threads(opt.num_threads)
339
+ // for (int qq = 0; qq < nn_num_output; qq++)
340
+ // {
341
+ // int q = qq * 4;
342
+ //
343
+ // }
344
+ int remain_num_output_start = 0 ;
345
+ #pragma omp parallel for num_threads(opt.num_threads)
346
+ for (int q = remain_num_output_start; q < num_output; q++)
347
+ {
348
+ const float * hr = weight_hr.row (q);
349
+ const float * tmp_hidden_ptr = tmp_hidden_state;
350
+
351
+ float H = 0 ;
352
+ for (int i = 0 ; i < hidden_size; i++)
353
+ {
354
+ H += tmp_hidden_ptr[i] * hr[i];
355
+ }
356
+
357
+ hidden_ptr[q] = H;
358
+
359
+ if (elemtype == 1 )
360
+ {
361
+ output_data[q] = H;
362
+ }
363
+ if (elemtype == 2 )
364
+ {
365
+ ((unsigned short *)output_data)[q] = float32_to_float16 (H);
366
+ }
367
+ if (elemtype == 4 )
368
+ {
369
+ ((unsigned short *)output_data)[q] = float32_to_bfloat16 (H);
370
+ }
371
+ }
372
+ }
373
+ }
374
+
184
375
static void lstm_int8 (const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
185
376
{
186
377
// TODO dispatch for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -476,149 +667,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
476
667
#endif // __ARM_NEON
477
668
}
478
669
479
- // lstm unit
480
- // sigmoid(I)
481
- // sigmoid(F)
482
- // sigmoid(O)
483
- // tanh(G)
484
- // c_t := f_t .* c_{t-1} + i_t .* g_t
485
- // h_t := o_t .* tanh[c_t]
486
- float * output_data = top_blob.row (ti);
487
-
488
- float * cell_ptr = cell_state;
489
- float * hidden_ptr = hidden_state;
490
- float * tmp_hidden_ptr = tmp_hidden_state;
491
-
492
- int remain_hidden_size_start = 0 ;
493
- #if __ARM_NEON
494
- int nn_hidden_size = hidden_size >> 2 ;
495
- remain_hidden_size_start = nn_hidden_size << 2 ;
496
-
497
- #pragma omp parallel for num_threads(opt.num_threads)
498
- for (int qq = 0 ; qq < nn_hidden_size; qq++)
499
- {
500
- int q = qq * 4 ;
501
-
502
- const float * gates_data = gates.row (q);
503
-
504
- float32x4x4_t _IFOG_4x4 = vld4q_f32 (gates_data);
505
-
506
- float32x4_t _lstm_I = sigmoid_ps (_IFOG_4x4.val [0 ]);
507
- float32x4_t _lstm_F = sigmoid_ps (_IFOG_4x4.val [1 ]);
508
- float32x4_t _lstm_O = sigmoid_ps (_IFOG_4x4.val [2 ]);
509
- float32x4_t _lstm_G = tanh_ps (_IFOG_4x4.val [3 ]);
510
-
511
- float32x4_t _cell2 = vaddq_f32 (vmulq_f32 (_lstm_F, vld1q_f32 (cell_ptr + q)), vmulq_f32 (_lstm_I, _lstm_G));
512
- float32x4_t _lstm_H = vmulq_f32 (_lstm_O, tanh_ps (_cell2));
513
-
514
- vst1q_f32 (cell_ptr + q, _cell2);
515
-
516
- if (num_output == hidden_size)
517
- {
518
- vst1q_f32 (hidden_ptr + q, _lstm_H);
519
-
520
- if (elemtype == 1 )
521
- {
522
- // fp32
523
- vst1q_f32 (output_data + q, _lstm_H);
524
- }
525
- if (elemtype == 2 )
526
- {
527
- // fp16
528
- vst1_u16 ((unsigned short *)output_data + q, (uint16x4_t )vcvt_f16_f32 (_lstm_H));
529
- }
530
- if (elemtype == 4 )
531
- {
532
- // bf16
533
- vst1_u16 ((unsigned short *)output_data + q, float2bfloat (_lstm_H));
534
- }
535
- }
536
- else
537
- {
538
- vst1q_f32 (tmp_hidden_ptr + q, _lstm_H);
539
- }
540
- }
541
- #endif // __ARM_NEON
542
- #pragma omp parallel for num_threads(opt.num_threads)
543
- for (int q = remain_hidden_size_start; q < hidden_size; q++)
544
- {
545
- const float * gates_data = gates.row (q);
546
-
547
- float I = gates_data[0 ];
548
- float F = gates_data[1 ];
549
- float O = gates_data[2 ];
550
- float G = gates_data[3 ];
551
-
552
- I = 1 .f / (1 .f + expf (-I));
553
- F = 1 .f / (1 .f + expf (-F));
554
- O = 1 .f / (1 .f + expf (-O));
555
- G = tanhf (G);
556
-
557
- float cell2 = F * cell_ptr[q] + I * G;
558
- float H = O * tanhf (cell2);
559
-
560
- cell_ptr[q] = cell2;
561
- if (num_output == hidden_size)
562
- {
563
- hidden_ptr[q] = H;
564
-
565
- if (elemtype == 1 )
566
- {
567
- output_data[q] = H;
568
- }
569
- if (elemtype == 2 )
570
- {
571
- ((unsigned short *)output_data)[q] = float32_to_float16 (H);
572
- }
573
- if (elemtype == 4 )
574
- {
575
- ((unsigned short *)output_data)[q] = float32_to_bfloat16 (H);
576
- }
577
- }
578
- else
579
- {
580
- tmp_hidden_ptr[q] = H;
581
- }
582
- }
583
-
584
- if (num_output != hidden_size)
585
- {
586
- // int nn_num_output = num_output >> 2;
587
- // int remain_num_output_start = nn_num_output << 2;
588
- // #pragma omp parallel for num_threads(opt.num_threads)
589
- // for (int qq = 0; qq < nn_num_output; qq++)
590
- // {
591
- // int q = qq * 4;
592
- //
593
- // }
594
- int remain_num_output_start = 0 ;
595
- #pragma omp parallel for num_threads(opt.num_threads)
596
- for (int q = remain_num_output_start; q < num_output; q++)
597
- {
598
- const float * hr = weight_hr.row (q);
599
- const float * tmp_hidden_ptr = tmp_hidden_state;
600
-
601
- float H = 0 ;
602
- for (int i = 0 ; i < hidden_size; i++)
603
- {
604
- H += tmp_hidden_ptr[i] * hr[i];
605
- }
606
-
607
- hidden_ptr[q] = H;
608
-
609
- if (elemtype == 1 )
610
- {
611
- output_data[q] = H;
612
- }
613
- if (elemtype == 2 )
614
- {
615
- ((unsigned short *)output_data)[q] = float32_to_float16 (H);
616
- }
617
- if (elemtype == 4 )
618
- {
619
- ((unsigned short *)output_data)[q] = float32_to_bfloat16 (H);
620
- }
621
- }
622
- }
670
+ lstm_int8_gate_output (gates, weight_hr, hidden_state, tmp_hidden_state, cell_state, top_blob, ti, elemtype, opt);
623
671
}
624
672
}
0 commit comments