Skip to content

Commit e7602a2

Browse files
authored
Fix ARM gemm int8 scales/descales offset: index by the absolute row `i + ii` instead of the tile-local `ii` (Tencent#5750)
1 parent c1f9e95 commit e7602a2

File tree

5 files changed: 93 additions and 94 deletions.

src/layer/arm/gemm_arm.cpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4404,7 +4404,6 @@ int Gemm_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
44044404
if (int8_scale_term)
44054405
{
44064406
return forward_int8(bottom_blobs, top_blobs, opt);
4407-
// return Gemm::forward_int8(bottom_blobs, top_blobs, opt);
44084407
}
44094408
#endif
44104409

src/layer/arm/gemm_int8.h

Lines changed: 30 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1724,8 +1724,8 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s
17241724

17251725
const float v127_B_scale = 127.f * B_scale;
17261726

1727-
float* ps = scales;
1728-
float* pods = out_descales;
1727+
float* ps = (float*)scales + i;
1728+
float* pods = (float*)out_descales + i;
17291729

17301730
#if __ARM_NEON
17311731
if (elempack == 4)
@@ -1897,8 +1897,8 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i
18971897
{
18981898
const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack;
18991899

1900-
float32x4_t _scale0 = vld1q_f32((const float*)scales + ii);
1901-
float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4);
1900+
float32x4_t _scale0 = vld1q_f32((const float*)scales + i + ii);
1901+
float32x4_t _scale1 = vld1q_f32((const float*)scales + i + ii + 4);
19021902

19031903
if (elempack == 4)
19041904
{
@@ -2314,7 +2314,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i
23142314
{
23152315
const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack;
23162316

2317-
float32x4_t _scale = vld1q_f32((const float*)scales + ii);
2317+
float32x4_t _scale = vld1q_f32((const float*)scales + i + ii);
23182318

23192319
if (elempack == 4)
23202320
{
@@ -2592,8 +2592,8 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i
25922592
{
25932593
const float* p0 = (const float*)A + (i + ii) * A_hstep + k;
25942594

2595-
const float scale0 = scales[ii];
2596-
const float scale1 = scales[ii + 1];
2595+
const float scale0 = scales[i + ii];
2596+
const float scale1 = scales[i + ii + 1];
25972597

25982598
// if (elempack == 1)
25992599
{
@@ -2680,7 +2680,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i
26802680
{
26812681
const float* p0 = (const float*)A + (i + ii) * A_hstep + k;
26822682

2683-
const float scale = scales[ii];
2683+
const float scale = scales[i + ii];
26842684

26852685
// if (elempack == 1)
26862686
{
@@ -2750,8 +2750,8 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
27502750
#endif
27512751
#endif
27522752

2753-
float* ps = scales;
2754-
float* pods = out_descales;
2753+
float* ps = (float*)scales + i;
2754+
float* pods = (float*)out_descales + i;
27552755

27562756
#if __ARM_NEON
27572757
if (elempack == 4)
@@ -3055,8 +3055,8 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int
30553055
{
30563056
const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack;
30573057

3058-
float32x4_t _scale0 = vld1q_f32((const float*)scales + ii);
3059-
float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4);
3058+
float32x4_t _scale0 = vld1q_f32((const float*)scales + i + ii);
3059+
float32x4_t _scale1 = vld1q_f32((const float*)scales + i + ii + 4);
30603060

30613061
if (elempack == 4)
30623062
{
@@ -3396,7 +3396,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int
33963396
{
33973397
const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack;
33983398

3399-
float32x4_t _scale = vld1q_f32((const float*)scales + ii);
3399+
float32x4_t _scale = vld1q_f32((const float*)scales + i + ii);
34003400

34013401
if (elempack == 4)
34023402
{
@@ -3622,8 +3622,8 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int
36223622
{
36233623
const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack;
36243624

3625-
const float scale0 = scales[ii];
3626-
const float scale1 = scales[ii + 1];
3625+
const float scale0 = scales[i + ii];
3626+
const float scale1 = scales[i + ii + 1];
36273627

36283628
#if __ARM_NEON
36293629
float32x4_t _scale0 = vdupq_n_f32(scale0);
@@ -3805,7 +3805,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int
38053805
{
38063806
const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack;
38073807

3808-
const float scale = scales[ii];
3808+
const float scale = scales[i + ii];
38093809

38103810
#if __ARM_NEON
38113811
float32x4_t _scale = vdupq_n_f32(scale);
@@ -5646,8 +5646,8 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
56465646
{
56475647
float* p0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack;
56485648

5649-
float32x4_t _descale0 = vld1q_f32((const float*)descales + ii);
5650-
float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4);
5649+
float32x4_t _descale0 = vld1q_f32((const float*)descales + i + ii);
5650+
float32x4_t _descale1 = vld1q_f32((const float*)descales + i + ii + 4);
56515651

56525652
float32x4_t _c0;
56535653
float32x4_t _c1;
@@ -6593,7 +6593,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
65936593
{
65946594
float* p0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack;
65956595

6596-
float32x4_t _descale = vld1q_f32((const float*)descales + ii);
6596+
float32x4_t _descale = vld1q_f32((const float*)descales + i + ii);
65976597

65986598
float32x4_t _c0;
65996599
if (pC)
@@ -7181,10 +7181,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
71817181
// out_elempack == 1
71827182
float* p0 = (float*)top_blob + (i + ii) * out_hstep + j;
71837183

7184-
const float descale0 = descales[ii];
7185-
const float descale1 = descales[ii + 1];
7184+
const float descale0 = descales[i + ii];
7185+
const float descale1 = descales[i + ii + 1];
71867186
#if __ARM_NEON
7187-
float32x2_t _descale = vld1_f32((const float*)descales + ii);
7187+
float32x2_t _descale = vld1_f32((const float*)descales + i + ii);
71887188
#endif
71897189

71907190
float c0;
@@ -7467,7 +7467,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
74677467
// out_elempack == 1
74687468
float* p0 = (float*)top_blob + (i + ii) * out_hstep + j;
74697469

7470-
const float descale = descales[ii];
7470+
const float descale = descales[i + ii];
74717471
#if __ARM_NEON
74727472
float32x4_t _descale = vdupq_n_f32(descale);
74737473
#endif
@@ -7726,8 +7726,8 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
77267726
{
77277727
float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack;
77287728

7729-
float32x4_t _descale0 = vld1q_f32((const float*)descales + ii);
7730-
float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4);
7729+
float32x4_t _descale0 = vld1q_f32((const float*)descales + i + ii);
7730+
float32x4_t _descale1 = vld1q_f32((const float*)descales + i + ii + 4);
77317731

77327732
float32x4_t _c0;
77337733
float32x4_t _c1;
@@ -8673,7 +8673,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
86738673
{
86748674
float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack;
86758675

8676-
float32x4_t _descale = vld1q_f32((const float*)descales + ii);
8676+
float32x4_t _descale = vld1q_f32((const float*)descales + i + ii);
86778677

86788678
float32x4_t _c0;
86798679
if (pC)
@@ -9237,10 +9237,10 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
92379237
{
92389238
float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack;
92399239

9240-
const float descale0 = descales[ii];
9241-
const float descale1 = descales[ii + 1];
9240+
const float descale0 = descales[i + ii];
9241+
const float descale1 = descales[i + ii + 1];
92429242
#if __ARM_NEON
9243-
float32x2_t _descale01 = vld1_f32((const float*)descales + ii);
9243+
float32x2_t _descale01 = vld1_f32((const float*)descales + i + ii);
92449244
#endif
92459245

92469246
float c0;
@@ -9556,7 +9556,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
95569556
{
95579557
float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack;
95589558

9559-
const float descale = descales[ii];
9559+
const float descale = descales[i + ii];
95609560
#if __ARM_NEON
95619561
float32x4_t _descale = vdupq_n_f32(descale);
95629562
#endif

src/layer/arm/gemm_int8_bf16s.h

Lines changed: 30 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -38,8 +38,8 @@ static void compute_A_tile_bf16_int8_scales(const Mat& A, Mat& scales, float B_s
3838

3939
const float v127_B_scale = 127.f * B_scale;
4040

41-
float* ps = scales;
42-
float* pods = out_descales;
41+
float* ps = (float*)scales + i;
42+
float* pods = (float*)out_descales + i;
4343

4444
#if __ARM_NEON
4545
if (elempack == 4)
@@ -217,8 +217,8 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i
217217
{
218218
const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack;
219219

220-
float32x4_t _scale0 = vld1q_f32((const float*)scales + ii);
221-
float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4);
220+
float32x4_t _scale0 = vld1q_f32((const float*)scales + i + ii);
221+
float32x4_t _scale1 = vld1q_f32((const float*)scales + i + ii + 4);
222222

223223
if (elempack == 4)
224224
{
@@ -665,7 +665,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i
665665
{
666666
const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack;
667667

668-
float32x4_t _scale = vld1q_f32((const float*)scales + ii);
668+
float32x4_t _scale = vld1q_f32((const float*)scales + i + ii);
669669

670670
if (elempack == 4)
671671
{
@@ -958,8 +958,8 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i
958958
{
959959
const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k;
960960

961-
const float scale0 = scales[ii];
962-
const float scale1 = scales[ii + 1];
961+
const float scale0 = scales[i + ii];
962+
const float scale1 = scales[i + ii + 1];
963963

964964
// if (elempack == 1)
965965
{
@@ -1048,7 +1048,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i
10481048
{
10491049
const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k;
10501050

1051-
const float scale = scales[ii];
1051+
const float scale = scales[i + ii];
10521052

10531053
// if (elempack == 1)
10541054
{
@@ -1121,8 +1121,8 @@ static void transpose_compute_A_tile_bf16_int8_scales(const Mat& A, Mat& scales,
11211121
#endif
11221122
#endif
11231123

1124-
float* ps = scales;
1125-
float* pods = out_descales;
1124+
float* ps = (float*)scales + i;
1125+
float* pods = (float*)out_descales + i;
11261126

11271127
#if __ARM_NEON
11281128
if (elempack == 4)
@@ -1362,8 +1362,8 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int
13621362
{
13631363
const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack;
13641364

1365-
float32x4_t _scale0 = vld1q_f32((const float*)scales + ii);
1366-
float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4);
1365+
float32x4_t _scale0 = vld1q_f32((const float*)scales + i + ii);
1366+
float32x4_t _scale1 = vld1q_f32((const float*)scales + i + ii + 4);
13671367

13681368
if (elempack == 4)
13691369
{
@@ -1731,7 +1731,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int
17311731
{
17321732
const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack;
17331733

1734-
float32x4_t _scale = vld1q_f32((const float*)scales + ii);
1734+
float32x4_t _scale = vld1q_f32((const float*)scales + i + ii);
17351735

17361736
if (elempack == 4)
17371737
{
@@ -1963,8 +1963,8 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int
19631963
{
19641964
const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack;
19651965

1966-
const float scale0 = scales[ii];
1967-
const float scale1 = scales[ii + 1];
1966+
const float scale0 = scales[i + ii];
1967+
const float scale1 = scales[i + ii + 1];
19681968

19691969
#if __ARM_NEON
19701970
float32x4_t _scale0 = vdupq_n_f32(scale0);
@@ -2187,7 +2187,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int
21872187
{
21882188
const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack;
21892189

2190-
const float scale = scales[ii];
2190+
const float scale = scales[i + ii];
21912191

21922192
#if __ARM_NEON
21932193
float32x4_t _scale = vdupq_n_f32(scale);
@@ -4169,8 +4169,8 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
41694169
{
41704170
unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack;
41714171

4172-
float32x4_t _descale0 = vld1q_f32((const float*)descales + ii);
4173-
float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4);
4172+
float32x4_t _descale0 = vld1q_f32((const float*)descales + i + ii);
4173+
float32x4_t _descale1 = vld1q_f32((const float*)descales + i + ii + 4);
41744174

41754175
float32x4_t _c0;
41764176
float32x4_t _c1;
@@ -5189,7 +5189,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
51895189
{
51905190
unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack;
51915191

5192-
float32x4_t _descale = vld1q_f32((const float*)descales + ii);
5192+
float32x4_t _descale = vld1q_f32((const float*)descales + i + ii);
51935193

51945194
float32x4_t _c0;
51955195
if (pC)
@@ -5794,10 +5794,10 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
57945794
// out_elempack == 1
57955795
unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j;
57965796

5797-
const float descale0 = descales[ii];
5798-
const float descale1 = descales[ii + 1];
5797+
const float descale0 = descales[i + ii];
5798+
const float descale1 = descales[i + ii + 1];
57995799
#if __ARM_NEON
5800-
float32x2_t _descale = vld1_f32((const float*)descales + ii);
5800+
float32x2_t _descale = vld1_f32((const float*)descales + i + ii);
58015801
#endif
58025802

58035803
float c0;
@@ -6097,7 +6097,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
60976097
// out_elempack == 1
60986098
unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j;
60996099

6100-
const float descale = descales[ii];
6100+
const float descale = descales[i + ii];
61016101
#if __ARM_NEON
61026102
float32x4_t _descale = vdupq_n_f32(descale);
61036103
#endif
@@ -6359,8 +6359,8 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
63596359
{
63606360
unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack;
63616361

6362-
float32x4_t _descale0 = vld1q_f32((const float*)descales + ii);
6363-
float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4);
6362+
float32x4_t _descale0 = vld1q_f32((const float*)descales + i + ii);
6363+
float32x4_t _descale1 = vld1q_f32((const float*)descales + i + ii + 4);
63646364

63656365
float32x4_t _c0;
63666366
float32x4_t _c1;
@@ -7318,7 +7318,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
73187318
{
73197319
unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack;
73207320

7321-
float32x4_t _descale = vld1q_f32((const float*)descales + ii);
7321+
float32x4_t _descale = vld1q_f32((const float*)descales + i + ii);
73227322

73237323
float32x4_t _c0;
73247324
if (pC)
@@ -7902,10 +7902,10 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
79027902
{
79037903
unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack;
79047904

7905-
const float descale0 = descales[ii];
7906-
const float descale1 = descales[ii + 1];
7905+
const float descale0 = descales[i + ii];
7906+
const float descale1 = descales[i + ii + 1];
79077907
#if __ARM_NEON
7908-
float32x2_t _descale01 = vld1_f32((const float*)descales + ii);
7908+
float32x2_t _descale01 = vld1_f32((const float*)descales + i + ii);
79097909
#endif
79107910

79117911
float c0;
@@ -8250,7 +8250,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
82508250
{
82518251
unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack;
82528252

8253-
const float descale = descales[ii];
8253+
const float descale = descales[i + ii];
82548254
#if __ARM_NEON
82558255
float32x4_t _descale = vdupq_n_f32(descale);
82568256
#endif

0 commit comments

Comments
 (0)