
Commit e0575ea ("wip")
1 parent 5fd9ab3 commit e0575ea

3 files changed, +223 −145 lines changed

src/layer/arm/lstm_arm_vfpv4.cpp (new file, +30)

@@ -0,0 +1,30 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "cpu.h"
+#include "mat.h"
+#include "layer.h"
+#include "arm_activation.h"
+#include "arm_usability.h"
+
+namespace ncnn {
+
+#include "lstm_int8.h"
+
+void lstm_int8_gate_output_vfpv4(const Mat& gates, const Mat& weight_hr, Mat& hidden_state, Mat& tmp_hidden_state, Mat& cell_state, Mat& top_blob, int ti, int elemtype, const Option& opt)
+{
+    lstm_int8_gate_output(gates, weight_hr, hidden_state, tmp_hidden_state, cell_state, top_blob, ti, elemtype, opt);
+}
+
+} // namespace ncnn
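
Note: the wrapper above is intentionally trivial. This translation unit exists so the shared kernel in lstm_int8.h is compiled a second time with VFPv4 (and therefore FP16 convert/store instructions) enabled, while lstm_int8_gate_output in the header picks the right copy at runtime via ncnn::cpu_support_arm_vfpv4(). A minimal, self-contained sketch of that dispatch pattern, with hypothetical names rather than ncnn's actual build setup:

// dispatch_sketch.cpp - single-file illustration of runtime ISA dispatch (hypothetical names)
#include <stdio.h>

static bool cpu_supports_vfpv4()
{
    // Stand-in for ncnn::cpu_support_arm_vfpv4(), which probes the CPU at runtime.
    return true;
}

// In ncnn this forwarder lives in its own .cpp (lstm_arm_vfpv4.cpp) that the build
// compiles with the VFPv4-enabling flags, so the compiler may emit fp16 conversions there.
static void gate_output_vfpv4(float x)
{
    printf("vfpv4 path: %f\n", x);
}

// Shared implementation (conceptually lstm_int8.h): try the extension build, else fall back.
static void gate_output(float x)
{
    if (cpu_supports_vfpv4())
    {
        gate_output_vfpv4(x);
        return;
    }
    printf("generic path: %f\n", x);
}

int main()
{
    gate_output(1.5f);
    return 0;
}

The point of the pattern is that only the per-CPU wrapper is built with the extra instruction-set flags; everything else stays at the project's baseline, and the choice is made once per call at runtime.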

src/layer/arm/lstm_int8.h (+192 −144)

@@ -17,6 +17,10 @@ void lstm_transform_weight_int8_asimddp(const Mat& weight_xc, const Mat& weight_
 void lstm_int8_asimddp(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt);
 #endif
 
+#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
+void lstm_int8_gate_output_vfpv4(const Mat& gates, const Mat& weight_hr, Mat& hidden_state, Mat& tmp_hidden_state, Mat& cell_state, Mat& top_blob, int ti, int elemtype, const Option& opt);
+#endif
+
 static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt)
 {
     // TODO dispatch for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -181,6 +185,193 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
     }
 }
 
+static void lstm_int8_gate_output(const Mat& gates, const Mat& weight_hr, Mat& hidden_state, Mat& tmp_hidden_state, Mat& cell_state, Mat& top_blob, int ti, int elemtype, const Option& opt)
+{
+#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
+    if (ncnn::cpu_support_arm_vfpv4())
+    {
+        lstm_int8_gate_output_vfpv4(gates, weight_hr, hidden_state, tmp_hidden_state, cell_state, top_blob, ti, elemtype, opt);
+        return;
+    }
+#endif
+
+    const int num_output = top_blob.w;
+    const int hidden_size = cell_state.w;
+
+    // lstm unit
+    // sigmoid(I)
+    // sigmoid(F)
+    // sigmoid(O)
+    // tanh(G)
+    // c_t := f_t .* c_{t-1} + i_t .* g_t
+    // h_t := o_t .* tanh[c_t]
+    float* output_data = top_blob.row(ti);
+
+    float* cell_ptr = cell_state;
+    float* hidden_ptr = hidden_state;
+    float* tmp_hidden_ptr = tmp_hidden_state;
+
+    int remain_hidden_size_start = 0;
+#if __ARM_NEON
+    int nn_hidden_size = hidden_size >> 2;
+    remain_hidden_size_start = nn_hidden_size << 2;
+
+    #pragma omp parallel for num_threads(opt.num_threads)
+    for (int qq = 0; qq < nn_hidden_size; qq++)
+    {
+        int q = qq * 4;
+
+        const float* gates_data = gates.row(q);
+
+        float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data);
+
+        float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]);
+        float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]);
+        float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]);
+        float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]);
+
+        float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G));
+        float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2));
+
+        vst1q_f32(cell_ptr + q, _cell2);
+
+        if (num_output == hidden_size)
+        {
+            vst1q_f32(hidden_ptr + q, _lstm_H);
+
+            if (elemtype == 1)
+            {
+                // fp32
+                vst1q_f32(output_data + q, _lstm_H);
+            }
+            if (elemtype == 2)
+            {
+                // fp16
+                unsigned short* outptr = (unsigned short*)output_data + q;
+#if (__ARM_FP & 2)
+#if NCNN_GNU_INLINE_ASM
+#if __aarch64__
+                asm volatile(
+                    "fcvtn  v0.4h, %2.4s    \n"
+                    "st1    {v0.4h}, [%0]   \n"
+                    : "=r"(outptr) // %0
+                    : "0"(outptr),
+                    "w"(_lstm_H)
+                    : "memory", "v0");
+#else  // __aarch64__
+                asm volatile(
+                    "vcvt.f16.f32 d0, %q2   \n"
+                    "vst1.u16   {d0}, [%0]  \n"
+                    : "=r"(outptr) // %0
+                    : "0"(outptr),
+                    "w"(_lstm_H)
+                    : "memory", "q0");
+#endif // __aarch64__
+#else  // NCNN_GNU_INLINE_ASM
+                vst1_u16(outptr, (uint16x4_t)vcvt_f16_f32(_lstm_H));
+#endif // NCNN_GNU_INLINE_ASM
+#else
+                outptr[0] = float32_to_float16(hidden_ptr[q]);
+                outptr[1] = float32_to_float16(hidden_ptr[q + 1]);
+                outptr[2] = float32_to_float16(hidden_ptr[q + 2]);
+                outptr[3] = float32_to_float16(hidden_ptr[q + 3]);
+#endif // (__ARM_FP & 2)
+            }
+            if (elemtype == 4)
+            {
+                // bf16
+                vst1_u16((unsigned short*)output_data + q, float2bfloat(_lstm_H));
+            }
+        }
+        else
+        {
+            vst1q_f32(tmp_hidden_ptr + q, _lstm_H);
+        }
+    }
+#endif // __ARM_NEON
+    #pragma omp parallel for num_threads(opt.num_threads)
+    for (int q = remain_hidden_size_start; q < hidden_size; q++)
+    {
+        const float* gates_data = gates.row(q);
+
+        float I = gates_data[0];
+        float F = gates_data[1];
+        float O = gates_data[2];
+        float G = gates_data[3];
+
+        I = 1.f / (1.f + expf(-I));
+        F = 1.f / (1.f + expf(-F));
+        O = 1.f / (1.f + expf(-O));
+        G = tanhf(G);
+
+        float cell2 = F * cell_ptr[q] + I * G;
+        float H = O * tanhf(cell2);
+
+        cell_ptr[q] = cell2;
+        if (num_output == hidden_size)
+        {
+            hidden_ptr[q] = H;
+
+            if (elemtype == 1)
+            {
+                output_data[q] = H;
+            }
+            if (elemtype == 2)
+            {
+                ((unsigned short*)output_data)[q] = float32_to_float16(H);
+            }
+            if (elemtype == 4)
+            {
+                ((unsigned short*)output_data)[q] = float32_to_bfloat16(H);
+            }
+        }
+        else
+        {
+            tmp_hidden_ptr[q] = H;
+        }
+    }
+
+    if (num_output != hidden_size)
+    {
+        // int nn_num_output = num_output >> 2;
+        // int remain_num_output_start = nn_num_output << 2;
+        // #pragma omp parallel for num_threads(opt.num_threads)
+        // for (int qq = 0; qq < nn_num_output; qq++)
+        // {
+        //     int q = qq * 4;
+        //
+        // }
+        int remain_num_output_start = 0;
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int q = remain_num_output_start; q < num_output; q++)
+        {
+            const float* hr = weight_hr.row(q);
+            const float* tmp_hidden_ptr = tmp_hidden_state;
+
+            float H = 0;
+            for (int i = 0; i < hidden_size; i++)
+            {
+                H += tmp_hidden_ptr[i] * hr[i];
+            }
+
+            hidden_ptr[q] = H;
+
+            if (elemtype == 1)
+            {
+                output_data[q] = H;
+            }
+            if (elemtype == 2)
+            {
+                ((unsigned short*)output_data)[q] = float32_to_float16(H);
+            }
+            if (elemtype == 4)
+            {
+                ((unsigned short*)output_data)[q] = float32_to_bfloat16(H);
+            }
+        }
+    }
+}
+
 static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
 {
     // TODO dispatch for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
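
For reference, the comment block inside lstm_int8_gate_output abbreviates the cell update. Written out in standard LSTM notation (with \sigma the logistic sigmoid, \odot elementwise multiplication, and I, F, O, G the pre-activation gate values loaded from gates), the per-timestep math the kernel implements is:

\begin{aligned}
i_t &= \sigma(I), \quad f_t = \sigma(F), \quad o_t = \sigma(O), \quad g_t = \tanh(G) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}

When num_output != hidden_size the model also carries a projection: the kernel stores h_t into tmp_hidden_state and then computes hidden_state[q] = \sum_i weight_hr(q, i) \cdot h_t[i], i.e. the h'_t = W_{hr} h_t step of a projected LSTM.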
@@ -476,149 +667,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
 #endif // __ARM_NEON
         }
 
-        // lstm unit
-        // sigmoid(I)
-        // sigmoid(F)
-        // sigmoid(O)
-        // tanh(G)
-        // c_t := f_t .* c_{t-1} + i_t .* g_t
-        // h_t := o_t .* tanh[c_t]
-        float* output_data = top_blob.row(ti);
-
-        float* cell_ptr = cell_state;
-        float* hidden_ptr = hidden_state;
-        float* tmp_hidden_ptr = tmp_hidden_state;
-
-        int remain_hidden_size_start = 0;
-#if __ARM_NEON
-        int nn_hidden_size = hidden_size >> 2;
-        remain_hidden_size_start = nn_hidden_size << 2;
-
-        #pragma omp parallel for num_threads(opt.num_threads)
-        for (int qq = 0; qq < nn_hidden_size; qq++)
-        {
-            int q = qq * 4;
-
-            const float* gates_data = gates.row(q);
-
-            float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data);
-
-            float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]);
-            float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]);
-            float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]);
-            float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]);
-
-            float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G));
-            float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2));
-
-            vst1q_f32(cell_ptr + q, _cell2);
-
-            if (num_output == hidden_size)
-            {
-                vst1q_f32(hidden_ptr + q, _lstm_H);
-
-                if (elemtype == 1)
-                {
-                    // fp32
-                    vst1q_f32(output_data + q, _lstm_H);
-                }
-                if (elemtype == 2)
-                {
-                    // fp16
-                    vst1_u16((unsigned short*)output_data + q, (uint16x4_t)vcvt_f16_f32(_lstm_H));
-                }
-                if (elemtype == 4)
-                {
-                    // bf16
-                    vst1_u16((unsigned short*)output_data + q, float2bfloat(_lstm_H));
-                }
-            }
-            else
-            {
-                vst1q_f32(tmp_hidden_ptr + q, _lstm_H);
-            }
-        }
-#endif // __ARM_NEON
-        #pragma omp parallel for num_threads(opt.num_threads)
-        for (int q = remain_hidden_size_start; q < hidden_size; q++)
-        {
-            const float* gates_data = gates.row(q);
-
-            float I = gates_data[0];
-            float F = gates_data[1];
-            float O = gates_data[2];
-            float G = gates_data[3];
-
-            I = 1.f / (1.f + expf(-I));
-            F = 1.f / (1.f + expf(-F));
-            O = 1.f / (1.f + expf(-O));
-            G = tanhf(G);
-
-            float cell2 = F * cell_ptr[q] + I * G;
-            float H = O * tanhf(cell2);
-
-            cell_ptr[q] = cell2;
-            if (num_output == hidden_size)
-            {
-                hidden_ptr[q] = H;
-
-                if (elemtype == 1)
-                {
-                    output_data[q] = H;
-                }
-                if (elemtype == 2)
-                {
-                    ((unsigned short*)output_data)[q] = float32_to_float16(H);
-                }
-                if (elemtype == 4)
-                {
-                    ((unsigned short*)output_data)[q] = float32_to_bfloat16(H);
-                }
-            }
-            else
-            {
-                tmp_hidden_ptr[q] = H;
-            }
-        }
-
-        if (num_output != hidden_size)
-        {
-            // int nn_num_output = num_output >> 2;
-            // int remain_num_output_start = nn_num_output << 2;
-            // #pragma omp parallel for num_threads(opt.num_threads)
-            // for (int qq = 0; qq < nn_num_output; qq++)
-            // {
-            //     int q = qq * 4;
-            //
-            // }
-            int remain_num_output_start = 0;
-            #pragma omp parallel for num_threads(opt.num_threads)
-            for (int q = remain_num_output_start; q < num_output; q++)
-            {
-                const float* hr = weight_hr.row(q);
-                const float* tmp_hidden_ptr = tmp_hidden_state;
-
-                float H = 0;
-                for (int i = 0; i < hidden_size; i++)
-                {
-                    H += tmp_hidden_ptr[i] * hr[i];
-                }
-
-                hidden_ptr[q] = H;
-
-                if (elemtype == 1)
-                {
-                    output_data[q] = H;
-                }
-                if (elemtype == 2)
-                {
-                    ((unsigned short*)output_data)[q] = float32_to_float16(H);
-                }
-                if (elemtype == 4)
-                {
-                    ((unsigned short*)output_data)[q] = float32_to_bfloat16(H);
-                }
-            }
-        }
+        lstm_int8_gate_output(gates, weight_hr, hidden_state, tmp_hidden_state, cell_state, top_blob, ti, elemtype, opt);
     }
 }
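
Throughout the gate-output code, elemtype selects the storage format of top_blob: 1 stores fp32, 2 stores fp16 (via float32_to_float16 or the fcvtn/vcvt paths), and 4 stores bf16 (via float2bfloat / float32_to_bfloat16). For the bf16 case, the conversion amounts to keeping the upper 16 bits of the IEEE-754 float; a minimal scalar sketch of that idea follows (truncation only; ncnn's own helpers are authoritative and may round):

// bf16_sketch.cpp - illustrative fp32 <-> bf16 round trip, not ncnn's implementation
#include <stdint.h>
#include <stdio.h>
#include <string.h>

static uint16_t f32_to_bf16_truncate(float x)
{
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits)); // type-pun without undefined behavior
    return (uint16_t)(bits >> 16);   // bf16 keeps fp32's sign, exponent and top 7 mantissa bits
}

static float bf16_to_f32(uint16_t h)
{
    uint32_t bits = (uint32_t)h << 16; // low 16 mantissa bits become zero
    float x;
    memcpy(&x, &bits, sizeof(x));
    return x;
}

int main()
{
    float v = 1.2345f;
    uint16_t b = f32_to_bf16_truncate(v);
    printf("%f -> 0x%04x -> %f\n", v, b, bf16_to_f32(b));
    return 0;
}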

src/layer/arm/rnn_int8.h (+1 −1)

@@ -257,7 +257,7 @@ static void rnn_int8_gate_output(const Mat& gates, Mat& hidden_state, Mat& top_b
                 asm volatile(
                     "fcvtn  v0.4h, %2.4s    \n"
                     "st1    {v0.4h}, [%0]   \n"
-                    : "=r"(_rnn_H) // %0
+                    : "=r"(outptr) // %0
                     : "0"(outptr),
                     "w"(_rnn_H)
                     : "memory", "v0");
