Skip to content

Commit 82e9171

Browse files
MollySophiagithub-actions[bot]
authored andcommitted
apply code-format changes
1 parent 48054cd commit 82e9171

File tree

2 files changed

+58
-49
lines changed

2 files changed

+58
-49
lines changed

src/layer/arm/amx_usability.h

+45-39
Original file line numberDiff line numberDiff line change
@@ -16,35 +16,41 @@
1616
#define AMX_USABILITY_H
1717

1818
// From https://github.com/corsix/amx/blob/main/aarch64.h
19-
#define AMX_NOP_OP_IMM5(op, imm5) \
20-
__asm("nop\nnop\nnop\n.word (0x201000 + (%0 << 5) + %1)" : : "i"(op), "i"(imm5) : "memory")
21-
22-
#define AMX_OP_GPR(op, gpr) \
23-
__asm(".word (0x201000 + (%0 << 5) + 0%1 - ((0%1 >> 4) * 6))" : : "i"(op), "r"((uint64_t)(gpr)) : "memory")
24-
25-
#define AMX_LDX(gpr) AMX_OP_GPR( 0, gpr)
26-
#define AMX_LDY(gpr) AMX_OP_GPR( 1, gpr)
27-
#define AMX_STX(gpr) AMX_OP_GPR( 2, gpr)
28-
#define AMX_STY(gpr) AMX_OP_GPR( 3, gpr)
29-
#define AMX_LDZ(gpr) AMX_OP_GPR( 4, gpr)
30-
#define AMX_STZ(gpr) AMX_OP_GPR( 5, gpr)
31-
#define AMX_LDZI(gpr) AMX_OP_GPR( 6, gpr)
32-
#define AMX_STZI(gpr) AMX_OP_GPR( 7, gpr)
33-
#define AMX_EXTRX(gpr) AMX_OP_GPR( 8, gpr)
34-
#define AMX_EXTRY(gpr) AMX_OP_GPR( 9, gpr)
35-
#define AMX_FMA64(gpr) AMX_OP_GPR(10, gpr)
36-
#define AMX_FMS64(gpr) AMX_OP_GPR(11, gpr)
37-
#define AMX_FMA32(gpr) AMX_OP_GPR(12, gpr)
38-
#define AMX_FMS32(gpr) AMX_OP_GPR(13, gpr)
39-
#define AMX_MAC16(gpr) AMX_OP_GPR(14, gpr)
40-
#define AMX_FMA16(gpr) AMX_OP_GPR(15, gpr)
41-
#define AMX_FMS16(gpr) AMX_OP_GPR(16, gpr)
42-
#define AMX_VECINT(gpr) AMX_OP_GPR(18, gpr)
43-
#define AMX_VECFP(gpr) AMX_OP_GPR(19, gpr)
44-
#define AMX_MATINT(gpr) AMX_OP_GPR(20, gpr)
45-
#define AMX_MATFP(gpr) AMX_OP_GPR(21, gpr)
46-
#define AMX_GENLUT(gpr) AMX_OP_GPR(22, gpr)
47-
#define PTR_ROW_FLAGS(ptr, row, flags) (((uint64_t)&*(ptr)) + (((uint64_t)((row) + (flags) * 64)) << 56))
19+
#define AMX_NOP_OP_IMM5(op, imm5) \
20+
__asm("nop\nnop\nnop\n.word (0x201000 + (%0 << 5) + %1)" \
21+
: \
22+
: "i"(op), "i"(imm5) \
23+
: "memory")
24+
25+
#define AMX_OP_GPR(op, gpr) \
26+
__asm(".word (0x201000 + (%0 << 5) + 0%1 - ((0%1 >> 4) * 6))" \
27+
: \
28+
: "i"(op), "r"((uint64_t)(gpr)) \
29+
: "memory")
30+
31+
#define AMX_LDX(gpr) AMX_OP_GPR(0, gpr)
32+
#define AMX_LDY(gpr) AMX_OP_GPR(1, gpr)
33+
#define AMX_STX(gpr) AMX_OP_GPR(2, gpr)
34+
#define AMX_STY(gpr) AMX_OP_GPR(3, gpr)
35+
#define AMX_LDZ(gpr) AMX_OP_GPR(4, gpr)
36+
#define AMX_STZ(gpr) AMX_OP_GPR(5, gpr)
37+
#define AMX_LDZI(gpr) AMX_OP_GPR(6, gpr)
38+
#define AMX_STZI(gpr) AMX_OP_GPR(7, gpr)
39+
#define AMX_EXTRX(gpr) AMX_OP_GPR(8, gpr)
40+
#define AMX_EXTRY(gpr) AMX_OP_GPR(9, gpr)
41+
#define AMX_FMA64(gpr) AMX_OP_GPR(10, gpr)
42+
#define AMX_FMS64(gpr) AMX_OP_GPR(11, gpr)
43+
#define AMX_FMA32(gpr) AMX_OP_GPR(12, gpr)
44+
#define AMX_FMS32(gpr) AMX_OP_GPR(13, gpr)
45+
#define AMX_MAC16(gpr) AMX_OP_GPR(14, gpr)
46+
#define AMX_FMA16(gpr) AMX_OP_GPR(15, gpr)
47+
#define AMX_FMS16(gpr) AMX_OP_GPR(16, gpr)
48+
#define AMX_VECINT(gpr) AMX_OP_GPR(18, gpr)
49+
#define AMX_VECFP(gpr) AMX_OP_GPR(19, gpr)
50+
#define AMX_MATINT(gpr) AMX_OP_GPR(20, gpr)
51+
#define AMX_MATFP(gpr) AMX_OP_GPR(21, gpr)
52+
#define AMX_GENLUT(gpr) AMX_OP_GPR(22, gpr)
53+
#define PTR_ROW_FLAGS(ptr, row, flags) (((uint64_t) & *(ptr)) + (((uint64_t)((row) + (flags)*64)) << 56))
4854
void amx_set()
4955
{
5056
AMX_NOP_OP_IMM5(17, 0);
@@ -55,51 +61,51 @@ void amx_clr()
5561
AMX_NOP_OP_IMM5(17, 1);
5662
}
5763

58-
void amx_ldx(bool pair, unsigned int x_row, const void * ptr)
64+
void amx_ldx(bool pair, unsigned int x_row, const void* ptr)
5965
{
6066
if (x_row >= 8)
6167
return;
6268

6369
uint64_t oprand = (uint64_t)ptr + ((uint64_t)x_row << 56);
6470
if (pair)
6571
oprand |= 1ULL << 62;
66-
72+
6773
AMX_LDX(oprand);
6874
}
6975

70-
void amx_ldy(bool pair, unsigned int y_row, const void * ptr)
76+
void amx_ldy(bool pair, unsigned int y_row, const void* ptr)
7177
{
7278
if (y_row >= 8)
7379
return;
7480

7581
uint64_t oprand = (uint64_t)ptr + ((uint64_t)y_row << 56);
7682
if (pair)
7783
oprand |= 1ULL << 62;
78-
84+
7985
AMX_LDY(oprand);
8086
}
8187

82-
void amx_ldz(bool pair, unsigned int z_row, const void * ptr)
88+
void amx_ldz(bool pair, unsigned int z_row, const void* ptr)
8389
{
8490
if (z_row >= 64)
8591
return;
8692

8793
uint64_t oprand = (uint64_t)ptr + ((uint64_t)z_row << 56);
8894
if (pair)
8995
oprand |= 1ULL << 62;
90-
96+
9197
AMX_LDZ(oprand);
9298
}
9399

94-
void amx_stz(bool pair, unsigned int z_row, const void * ptr)
100+
void amx_stz(bool pair, unsigned int z_row, const void* ptr)
95101
{
96102
if (z_row >= 64)
97103
return;
98104

99105
uint64_t oprand = (uint64_t)ptr + ((uint64_t)z_row << 56);
100106
if (pair)
101107
oprand |= 1ULL << 62;
102-
108+
103109
AMX_STZ(oprand);
104110
}
105111

@@ -116,7 +122,7 @@ void amx_fma16_masked(bool vector, unsigned int x_offset, unsigned int y_offset,
116122
oprand |= ((uint64_t)y_mode & 0x3) << 37;
117123
oprand |= ((uint64_t)x_mask & 0x1F) << 41;
118124
oprand |= ((uint64_t)x_mode & 0x3) << 46;
119-
125+
120126
AMX_FMA16(oprand);
121127
}
122128

@@ -138,7 +144,7 @@ void amx_fma32_masked(bool vector, unsigned int x_offset, unsigned int y_offset,
138144
oprand |= ((uint64_t)y_mode & 0x3) << 37;
139145
oprand |= ((uint64_t)x_mask & 0x1F) << 41;
140146
oprand |= ((uint64_t)x_mode & 0x3) << 46;
141-
147+
142148
AMX_FMA32(oprand);
143149
}
144150

src/layer/arm/convolution_im2col_gemm_fp16s.h

+13-10
Original file line numberDiff line numberDiff line change
@@ -3056,20 +3056,20 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c
30563056
if (pC)
30573057
{
30583058
for (int r = 0; r < 12; r++)
3059-
amx_ldz(false, 2*r, pC);
3059+
amx_ldz(false, 2 * r, pC);
30603060
}
30613061
else
30623062
{
30633063
__fp16 sums[16];
30643064
memset(sums, 0, 16 * sizeof(__fp16));
30653065
for (int r = 0; r < 12; r++)
3066-
amx_ldz(false, 2*r, sums);
3066+
amx_ldz(false, 2 * r, sums);
30673067
}
30683068
}
30693069
else
30703070
{
30713071
for (int r = 0; r < 12; r++)
3072-
amx_ldz(false, 2*r, outptr + 8 * r);
3072+
amx_ldz(false, 2 * r, outptr + 8 * r);
30733073
}
30743074

30753075
int kk = 0;
@@ -3088,17 +3088,19 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c
30883088
if (out_elempack == 8)
30893089
{
30903090
__fp16 tmp[96 + 24];
3091-
for (int r = 0; r < 12; r++) {
3092-
amx_stz(false, 2*r, tmp + r * 8);
3091+
for (int r = 0; r < 12; r++)
3092+
{
3093+
amx_stz(false, 2 * r, tmp + r * 8);
30933094
}
30943095
memcpy(outptr0, tmp, 96 * sizeof(__fp16));
30953096
outptr0 += 96;
30963097
}
30973098
if (out_elempack == 4)
30983099
{
30993100
__fp16 tmp[32];
3100-
for (int r = 0; r < 12; r++) {
3101-
amx_stz(false, 2*r, tmp);
3101+
for (int r = 0; r < 12; r++)
3102+
{
3103+
amx_stz(false, 2 * r, tmp);
31023104
float16x8_t _tmp = vld1q_f16(tmp);
31033105
vst1_f16(outptr0 + 4 * r, vget_low_f16(_tmp));
31043106
vst1_f16(outptr0 + out_hstep * 4 + 4 * r, vget_high_f16(_tmp));
@@ -3167,8 +3169,9 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c
31673169
else
31683170
{
31693171
__fp16 tmp[32];
3170-
for (int r = 0; r < 12; r++) {
3171-
amx_stz(false, 2*r, tmp);
3172+
for (int r = 0; r < 12; r++)
3173+
{
3174+
amx_stz(false, 2 * r, tmp);
31723175
memcpy(outptr0 + 8 * r, tmp, 8 * sizeof(__fp16));
31733176
}
31743177
}
@@ -4915,7 +4918,7 @@ static int convolution_im2col_gemm_fp16sa(const Mat& bottom_blob, Mat& top_blob,
49154918
bool k_end = k + TILE_K >= K;
49164919

49174920
#if __aarch64__ && NCNN_APPLE_AMX
4918-
// #if 0
4921+
// #if 0
49194922
if (amx_supported)
49204923
{
49214924
convolution_gemm_transB_packed_tile_fp16sa_amx(AT_tile, BT_tile, bias, topT_tile, top_blob, i, max_ii, j, max_jj, k, max_kk, k_end);

0 commit comments

Comments
 (0)