
Commit 8da6e60

innerproduct pack4 arm neon
1 parent ea9c114 commit 8da6e60

File tree

2 files changed: +356 −0 lines changed


src/layer/arm/innerproduct_arm.cpp

Lines changed: 348 additions & 0 deletions
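The new create_pipeline() below pre-interleaves the fully connected weights so that forward() can load whole NEON vectors without strided gathers. For the pack4 case the inch-outch weight matrix is rewritten into 16-float tiles (4 input channels x 4 output channels). A minimal scalar sketch of that index mapping follows; the names W, packed and pack4_weights_reference are illustrative and do not appear in the diff.

    // Scalar sketch of the pack4 weight layout built in create_pipeline() below.
    // W is the original outch x inch weight matrix (row-major); packed holds
    // (outch/4) * (inch/4) tiles of 16 floats each, input-major inside a tile.
    static void pack4_weights_reference(const float* W, float* packed, int inch, int outch)
    {
        for (int q = 0; q < outch; q += 4)        // output group (the "4b" axis)
            for (int p = 0; p < inch; p += 4)     // input group (the "4a" axis)
                for (int a = 0; a < 4; a++)       // input offset inside the tile
                    for (int b = 0; b < 4; b++)   // output offset inside the tile
                        packed[((q / 4) * (inch / 4) + p / 4) * 16 + a * 4 + b] = W[(q + b) * inch + (p + a)];
    }
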
@@ -16,12 +16,138 @@
#if __ARM_NEON
#include <arm_neon.h>
#include "neon_mathfun.h"
#endif // __ARM_NEON

namespace ncnn {

DEFINE_LAYER_CREATOR(InnerProduct_arm)

int InnerProduct_arm::create_pipeline(const Option& opt)
{
    int num_input = weight_data_size / num_output;

    if (opt.use_packing_layout)
    {

    // pack4
    if (num_input % 4 == 0 && num_output % 4 == 0)
    {
        // src = inch-outch
        // dst = 4a-4b-inch/4a-outch/4b
        {
            Mat weight_data_r2 = weight_data.reshape(num_input, num_output);

            weight_data_pack4.create(num_input/4, num_output/4, (size_t)4*16, 16);

            for (int q=0; q+3<num_output; q+=4)
            {
                const float* k0 = weight_data_r2.row(q);
                const float* k1 = weight_data_r2.row(q+1);
                const float* k2 = weight_data_r2.row(q+2);
                const float* k3 = weight_data_r2.row(q+3);

                float* g00 = weight_data_pack4.row(q/4);

                for (int p=0; p+3<num_input; p+=4)
                {
                    g00[0] = k0[0];
                    g00[1] = k1[0];
                    g00[2] = k2[0];
                    g00[3] = k3[0];

                    g00[4] = k0[1];
                    g00[5] = k1[1];
                    g00[6] = k2[1];
                    g00[7] = k3[1];

                    g00[8] = k0[2];
                    g00[9] = k1[2];
                    g00[10] = k2[2];
                    g00[11] = k3[2];

                    g00[12] = k0[3];
                    g00[13] = k1[3];
                    g00[14] = k2[3];
                    g00[15] = k3[3];

                    k0 += 4;
                    k1 += 4;
                    k2 += 4;
                    k3 += 4;
                    g00 += 16;
                }
            }
        }
    }
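
    // e.g. the first 16-float tile written above for one output group q:
    //   { k0[0], k1[0], k2[0], k3[0],  k0[1], k1[1], k2[1], k3[1],
    //     k0[2], k1[2], k2[2], k3[2],  k0[3], k1[3], k2[3], k3[3] }
    // i.e. the four output rows are interleaved for every input channel of the tile.
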
    // pack1to4
    if (num_input % 4 != 0 && num_output % 4 == 0)
    {
        // src = inch-outch
        // dst = 4b-inch-outch/4b
        {
            Mat weight_data_r2 = weight_data.reshape(num_input, num_output);

            weight_data_pack1to4.create(num_input, num_output/4, (size_t)4*4, 4);

            for (int q=0; q+3<num_output; q+=4)
            {
                const float* k0 = weight_data_r2.row(q);
                const float* k1 = weight_data_r2.row(q+1);
                const float* k2 = weight_data_r2.row(q+2);
                const float* k3 = weight_data_r2.row(q+3);

                float* g00 = weight_data_pack1to4.row(q/4);

                for (int p=0; p<num_input; p++)
                {
                    g00[0] = k0[p];
                    g00[1] = k1[p];
                    g00[2] = k2[p];
                    g00[3] = k3[p];

                    g00 += 4;
                }
            }
        }
    }
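
    // e.g. the stream written above for one output group q:
    //   { k0[0], k1[0], k2[0], k3[0],  k0[1], k1[1], k2[1], k3[1], ... }
    // the same output interleaving as pack4, but one input channel at a time,
    // since num_input is not a multiple of 4 on this path.
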
    // pack4to1
    if (num_input % 4 == 0 && num_output % 4 != 0)
    {
        // src = inch-outch
        // dst = 4a-inch/4a-outch
        {
            Mat weight_data_r2 = weight_data.reshape(num_input, num_output);

            weight_data_pack4to1.create(num_input/4, num_output, (size_t)4*4, 4);

            for (int q=0; q<num_output; q++)
            {
                const float* k0 = weight_data_r2.row(q);

                float* g00 = weight_data_pack4to1.row(q);

                for (int p=0; p+3<num_input; p+=4)
                {
                    g00[0] = k0[0];
                    g00[1] = k0[1];
                    g00[2] = k0[2];
                    g00[3] = k0[3];

                    k0 += 4;
                    g00 += 4;
                }
            }
        }
    }
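
    // pack4to1 keeps each output row in its original input order; the loop above is
    // effectively a contiguous copy of row q, four floats per step, so the forward()
    // pack4to1 branch can read it with vld1q_f32 against a pack4 input blob.
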
    } // opt.use_packing_layout

    return 0;
}

int InnerProduct_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    if (use_int8_inference)

@@ -34,8 +160,230 @@ int InnerProduct_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
    int h = bottom_blob.h;
    int channels = bottom_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int packing = bottom_blob.packing;
    int size = w * h;

    if (opt.use_packing_layout)
    {

    int num_input = bottom_blob.w;

    int out_packing = num_output % 4 == 0 ? 4 : 1;
    size_t out_elemsize = elemsize / packing * out_packing;

    top_blob.create(num_output / out_packing, out_elemsize, out_packing, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    if (packing == 4 && out_packing == 4)
    {
        // num_output
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p=0; p<num_output / out_packing; p++)
        {
            const float* w = (const float*)weight_data_pack4 + num_input * p * 16;
            const float* m = bottom_blob;

            float32x4_t _sum = vdupq_n_f32(0.f);

            if (bias_term)
            {
                _sum = vld1q_f32(((const float*)bias_data) + p * 4);
            }

            // num_input
            for (int i = 0; i < num_input; i++)
            {
                float32x4_t _val = vld1q_f32( m );

                float32x4_t _w0 = vld1q_f32( w );
                float32x4_t _w1 = vld1q_f32( w + 4 );
                float32x4_t _w2 = vld1q_f32( w + 8 );
                float32x4_t _w3 = vld1q_f32( w + 12 );

#if __aarch64__
                _sum = vmlaq_laneq_f32(_sum, _w0, _val, 0);
                _sum = vmlaq_laneq_f32(_sum, _w1, _val, 1);
                _sum = vmlaq_laneq_f32(_sum, _w2, _val, 2);
                _sum = vmlaq_laneq_f32(_sum, _w3, _val, 3);
#else
                _sum = vmlaq_lane_f32(_sum, _w0, vget_low_f32(_val), 0);
                _sum = vmlaq_lane_f32(_sum, _w1, vget_low_f32(_val), 1);
                _sum = vmlaq_lane_f32(_sum, _w2, vget_high_f32(_val), 0);
                _sum = vmlaq_lane_f32(_sum, _w3, vget_high_f32(_val), 1);
#endif

                w += 16;
                m += 4;
            }
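
            // Scalar view of the tile product above, per iteration i:
            //   for (int a = 0; a < 4; a++)      // input lane of _val
            //       for (int b = 0; b < 4; b++)  // output lane of _sum
            //           sum[b] += val[a] * tile[a*4 + b];
            // so lane b of _sum accumulates the dot product for output channel p*4 + b.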

            if (activation_type == 1)
            {
                float32x4_t _zero = vdupq_n_f32(0.f);
                _sum = vmaxq_f32(_sum, _zero);
            }
            else if (activation_type == 2)
            {
                float32x4_t _zero = vdupq_n_f32(0.f);
                float32x4_t _slope = vdupq_n_f32(activation_params[0]);
                uint32x4_t _lemask = vcleq_f32(_sum, _zero);
                float32x4_t _ps = vmulq_f32(_sum, _slope);
                _sum = vbslq_f32(_lemask, _ps, _sum);
            }
            else if (activation_type == 3)
            {
                float32x4_t _min = vdupq_n_f32(activation_params[0]);
                float32x4_t _max = vdupq_n_f32(activation_params[1]);
                _sum = vmaxq_f32(_sum, _min);
                _sum = vminq_f32(_sum, _max);
            }
            else if (activation_type == 4)
            {
                float32x4_t _one = vdupq_n_f32(1.f);
                _sum = vnegq_f32(_sum);
                _sum = exp_ps(_sum);
                _sum = vaddq_f32(_sum, _one);
                float32x4_t _outp = vrecpeq_f32(_sum);
                _outp = vmulq_f32(vrecpsq_f32(_sum, _outp), _outp);
                // _outp = vmulq_f32(vrecpsq_f32(_sum, _outp), _outp);
                _sum = _outp;
            }

            float* outptr = top_blob;
            vst1q_f32(outptr + p * 4, _sum);
        }

        return 0;
    }

    if (packing == 1 && out_packing == 4)
    {
        // num_output
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p=0; p<num_output / out_packing; p++)
        {
            const float* w = (const float*)weight_data_pack1to4 + num_input * p * 4;
            const float* m = bottom_blob;

            float32x4_t _sum = vdupq_n_f32(0.f);

            if (bias_term)
            {
                _sum = vld1q_f32(((const float*)bias_data) + p * 4);
            }

            // num_input
            for (int i = 0; i < num_input; i++)
            {
                float32x4_t _val = vdupq_n_f32( m[i] );
                float32x4_t _w = vld1q_f32( w );
                _sum = vmlaq_f32(_sum, _val, _w);

                w += 4;
            }

            if (activation_type == 1)
            {
                float32x4_t _zero = vdupq_n_f32(0.f);
                _sum = vmaxq_f32(_sum, _zero);
            }
            else if (activation_type == 2)
            {
                float32x4_t _zero = vdupq_n_f32(0.f);
                float32x4_t _slope = vdupq_n_f32(activation_params[0]);
                uint32x4_t _lemask = vcleq_f32(_sum, _zero);
                float32x4_t _ps = vmulq_f32(_sum, _slope);
                _sum = vbslq_f32(_lemask, _ps, _sum);
            }
            else if (activation_type == 3)
            {
                float32x4_t _min = vdupq_n_f32(activation_params[0]);
                float32x4_t _max = vdupq_n_f32(activation_params[1]);
                _sum = vmaxq_f32(_sum, _min);
                _sum = vminq_f32(_sum, _max);
            }
            else if (activation_type == 4)
            {
                float32x4_t _one = vdupq_n_f32(1.f);
                _sum = vnegq_f32(_sum);
                _sum = exp_ps(_sum);
                _sum = vaddq_f32(_sum, _one);
                float32x4_t _outp = vrecpeq_f32(_sum);
                _outp = vmulq_f32(vrecpsq_f32(_sum, _outp), _outp);
                // _outp = vmulq_f32(vrecpsq_f32(_sum, _outp), _outp);
                _sum = _outp;
            }

            float* outptr = top_blob;
            vst1q_f32(outptr + p * 4, _sum);
        }

        return 0;
    }
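
    // In the pack1to4 branch each scalar input m[i] is broadcast with vdupq_n_f32 and
    // multiplied against one 4-float weight column, so a single vmlaq_f32 updates the
    // four output channels p*4 .. p*4+3 of the current group at once.
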
    if (packing == 4 && out_packing == 1)
    {
        // num_output
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p=0; p<num_output; p++)
        {
            const float* w = (const float*)weight_data_pack4to1 + num_input * p * 4;
            const float* m = bottom_blob;

            float sum = 0.f;

            if (bias_term)
                sum = bias_data[p];

            // num_input
            for (int i = 0; i < num_input; i++)
            {
                float32x4_t _val = vld1q_f32( m );
                float32x4_t _w = vld1q_f32( w );
                float32x4_t _s4 = vmulq_f32(_val, _w);
#if __aarch64__
                sum += vaddvq_f32(_s4); // dot
#else
                float32x2_t _ss = vadd_f32(vget_low_f32(_s4), vget_high_f32(_s4));
                _ss = vpadd_f32(_ss, _ss);
                sum += vget_lane_f32(_ss, 0);
#endif

                w += 4;
                m += 4;
            }

            if (activation_type == 1)
            {
                sum = std::max(sum, 0.f);
            }
            else if (activation_type == 2)
            {
                float slope = activation_params[0];
                sum = sum > 0.f ? sum : sum * slope;
            }
            else if (activation_type == 3)
            {
                float min = activation_params[0];
                float max = activation_params[1];
                if (sum < min)
                    sum = min;
                if (sum > max)
                    sum = max;
            }
            else if (activation_type == 4)
            {
                sum = 1.f / (1.f + exp(-sum));
            }

            top_blob[p] = sum;
        }

        return 0;
    }

    } // opt.use_packing_layout

    top_blob.create(num_output, elemsize, opt.blob_allocator);
    if (top_blob.empty())
        return -100;
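
The packed branches above dispatch on the input/output packing pair and fall through to the pre-existing unpacked path after the "} // opt.use_packing_layout" brace when no branch matches:

    // packing == 4, out_packing == 4  ->  weight_data_pack4     (16-float tiles, 4 in x 4 out)
    // packing == 1, out_packing == 4  ->  weight_data_pack1to4  (one 4-float column per input)
    // packing == 4, out_packing == 1  ->  weight_data_pack4to1  (4-wide dot-product chunks)
    // anything else                   ->  unpacked fallback below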

src/layer/arm/innerproduct_arm.h

Lines changed: 8 additions & 0 deletions
@@ -22,7 +22,15 @@ namespace ncnn {
class InnerProduct_arm : virtual public InnerProduct
{
public:
    virtual int create_pipeline(const Option& opt);

    virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;

public:
    // pack4
    Mat weight_data_pack4;
    Mat weight_data_pack1to4;
    Mat weight_data_pack4to1;
};

} // namespace ncnn
