diff --git a/src/layer/arm/amx_usability.h b/src/layer/arm/amx_usability.h
index d2bdc267530..d6783351000 100644
--- a/src/layer/arm/amx_usability.h
+++ b/src/layer/arm/amx_usability.h
@@ -16,35 +16,41 @@ #define AMX_USABILITY_H
 // From https://github.com/corsix/amx/blob/main/aarch64.h
-#define AMX_NOP_OP_IMM5(op, imm5) \
-    __asm("nop\nnop\nnop\n.word (0x201000 + (%0 << 5) + %1)" : : "i"(op), "i"(imm5) : "memory")
-
-#define AMX_OP_GPR(op, gpr) \
-    __asm(".word (0x201000 + (%0 << 5) + 0%1 - ((0%1 >> 4) * 6))" : : "i"(op), "r"((uint64_t)(gpr)) : "memory")
-
-#define AMX_LDX(gpr) AMX_OP_GPR( 0, gpr)
-#define AMX_LDY(gpr) AMX_OP_GPR( 1, gpr)
-#define AMX_STX(gpr) AMX_OP_GPR( 2, gpr)
-#define AMX_STY(gpr) AMX_OP_GPR( 3, gpr)
-#define AMX_LDZ(gpr) AMX_OP_GPR( 4, gpr)
-#define AMX_STZ(gpr) AMX_OP_GPR( 5, gpr)
-#define AMX_LDZI(gpr) AMX_OP_GPR( 6, gpr)
-#define AMX_STZI(gpr) AMX_OP_GPR( 7, gpr)
-#define AMX_EXTRX(gpr) AMX_OP_GPR( 8, gpr)
-#define AMX_EXTRY(gpr) AMX_OP_GPR( 9, gpr)
-#define AMX_FMA64(gpr) AMX_OP_GPR(10, gpr)
-#define AMX_FMS64(gpr) AMX_OP_GPR(11, gpr)
-#define AMX_FMA32(gpr) AMX_OP_GPR(12, gpr)
-#define AMX_FMS32(gpr) AMX_OP_GPR(13, gpr)
-#define AMX_MAC16(gpr) AMX_OP_GPR(14, gpr)
-#define AMX_FMA16(gpr) AMX_OP_GPR(15, gpr)
-#define AMX_FMS16(gpr) AMX_OP_GPR(16, gpr)
-#define AMX_VECINT(gpr) AMX_OP_GPR(18, gpr)
-#define AMX_VECFP(gpr) AMX_OP_GPR(19, gpr)
-#define AMX_MATINT(gpr) AMX_OP_GPR(20, gpr)
-#define AMX_MATFP(gpr) AMX_OP_GPR(21, gpr)
-#define AMX_GENLUT(gpr) AMX_OP_GPR(22, gpr)
-#define PTR_ROW_FLAGS(ptr, row, flags) (((uint64_t)&*(ptr)) + (((uint64_t)((row) + (flags) * 64)) << 56))
+#define AMX_NOP_OP_IMM5(op, imm5) \
+    __asm("nop\nnop\nnop\n.word (0x201000 + (%0 << 5) + %1)" \
+        : \
+        : "i"(op), "i"(imm5) \
+        : "memory")
+
+#define AMX_OP_GPR(op, gpr) \
+    __asm(".word (0x201000 + (%0 << 5) + 0%1 - ((0%1 >> 4) * 6))" \
+        : \
+        : "i"(op), "r"((uint64_t)(gpr)) \
+        : "memory")
+
+#define AMX_LDX(gpr) AMX_OP_GPR(0, gpr)
+#define AMX_LDY(gpr) AMX_OP_GPR(1, gpr)
+#define AMX_STX(gpr) AMX_OP_GPR(2, gpr)
+#define AMX_STY(gpr) AMX_OP_GPR(3, gpr)
+#define AMX_LDZ(gpr) AMX_OP_GPR(4, gpr)
+#define AMX_STZ(gpr) AMX_OP_GPR(5, gpr)
+#define AMX_LDZI(gpr) AMX_OP_GPR(6, gpr)
+#define AMX_STZI(gpr) AMX_OP_GPR(7, gpr)
+#define AMX_EXTRX(gpr) AMX_OP_GPR(8, gpr)
+#define AMX_EXTRY(gpr) AMX_OP_GPR(9, gpr)
+#define AMX_FMA64(gpr) AMX_OP_GPR(10, gpr)
+#define AMX_FMS64(gpr) AMX_OP_GPR(11, gpr)
+#define AMX_FMA32(gpr) AMX_OP_GPR(12, gpr)
+#define AMX_FMS32(gpr) AMX_OP_GPR(13, gpr)
+#define AMX_MAC16(gpr) AMX_OP_GPR(14, gpr)
+#define AMX_FMA16(gpr) AMX_OP_GPR(15, gpr)
+#define AMX_FMS16(gpr) AMX_OP_GPR(16, gpr)
+#define AMX_VECINT(gpr) AMX_OP_GPR(18, gpr)
+#define AMX_VECFP(gpr) AMX_OP_GPR(19, gpr)
+#define AMX_MATINT(gpr) AMX_OP_GPR(20, gpr)
+#define AMX_MATFP(gpr) AMX_OP_GPR(21, gpr)
+#define AMX_GENLUT(gpr) AMX_OP_GPR(22, gpr)
+#define PTR_ROW_FLAGS(ptr, row, flags) (((uint64_t) & *(ptr)) + (((uint64_t)((row) + (flags)*64)) << 56))
 
 void amx_set()
 {
     AMX_NOP_OP_IMM5(17, 0);
@@ -55,7 +61,7 @@ void amx_clr()
     AMX_NOP_OP_IMM5(17, 1);
 }
 
-void amx_ldx(bool pair, unsigned int x_row, const void * ptr)
+void amx_ldx(bool pair, unsigned int x_row, const void* ptr)
 {
     if (x_row >= 8) return;
@@ -63,11 +69,11 @@ void amx_ldx(bool pair, unsigned int x_row, const void * ptr)
     uint64_t oprand = (uint64_t)ptr + ((uint64_t)x_row << 56);
     if (pair)
         oprand |= 1ULL << 62;
-    
+
     AMX_LDX(oprand);
 }
 
-void amx_ldy(bool pair, unsigned int y_row, const void * ptr)
+void amx_ldy(bool pair, unsigned int y_row, const void* ptr)
 {
     if (y_row >= 8) return;
@@ -75,11 +81,11 @@ void amx_ldy(bool pair, unsigned int y_row, const void * ptr)
     uint64_t oprand = (uint64_t)ptr + ((uint64_t)y_row << 56);
     if (pair)
         oprand |= 1ULL << 62;
-    
+
     AMX_LDY(oprand);
 }
 
-void amx_ldz(bool pair, unsigned int z_row, const void * ptr)
+void amx_ldz(bool pair, unsigned int z_row, const void* ptr)
 {
     if (z_row >= 64) return;
@@ -87,11 +93,11 @@ void amx_ldz(bool pair, unsigned int z_row, const void * ptr)
     uint64_t oprand = (uint64_t)ptr + ((uint64_t)z_row << 56);
     if (pair)
         oprand |= 1ULL << 62;
-    
+
     AMX_LDZ(oprand);
 }
 
-void amx_stz(bool pair, unsigned int z_row, const void * ptr)
+void amx_stz(bool pair, unsigned int z_row, const void* ptr)
 {
     if (z_row >= 64) return;
@@ -99,7 +105,7 @@ void amx_stz(bool pair, unsigned int z_row, const void * ptr)
     uint64_t oprand = (uint64_t)ptr + ((uint64_t)z_row << 56);
     if (pair)
         oprand |= 1ULL << 62;
-    
+
     AMX_STZ(oprand);
 }
@@ -116,7 +122,7 @@ void amx_fma16_masked(bool vector, unsigned int x_offset, unsigned int y_offset,
     oprand |= ((uint64_t)y_mode & 0x3) << 37;
     oprand |= ((uint64_t)x_mask & 0x1F) << 41;
     oprand |= ((uint64_t)x_mode & 0x3) << 46;
-    
+
     AMX_FMA16(oprand);
 }
@@ -138,7 +144,7 @@ void amx_fma32_masked(bool vector, unsigned int x_offset, unsigned int y_offset,
     oprand |= ((uint64_t)y_mode & 0x3) << 37;
     oprand |= ((uint64_t)x_mask & 0x1F) << 41;
     oprand |= ((uint64_t)x_mode & 0x3) << 46;
-    
+
     AMX_FMA32(oprand);
 }
diff --git a/src/layer/arm/convolution_im2col_gemm_fp16s.h b/src/layer/arm/convolution_im2col_gemm_fp16s.h
index 8907b934949..9c8141dc2dc 100644
--- a/src/layer/arm/convolution_im2col_gemm_fp16s.h
+++ b/src/layer/arm/convolution_im2col_gemm_fp16s.h
@@ -3056,20 +3056,20 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c
         if (pC)
         {
             for (int r = 0; r < 12; r++)
-                amx_ldz(false, 2*r, pC);
+                amx_ldz(false, 2 * r, pC);
         }
         else
         {
             __fp16 sums[16];
             memset(sums, 0, 16 * sizeof(__fp16));
             for (int r = 0; r < 12; r++)
-                amx_ldz(false, 2*r, sums);
+                amx_ldz(false, 2 * r, sums);
         }
     }
     else
     {
         for (int r = 0; r < 12; r++)
-            amx_ldz(false, 2*r, outptr + 8 * r);
+            amx_ldz(false, 2 * r, outptr + 8 * r);
     }
 
     int kk = 0;
@@ -3088,8 +3088,9 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c
         if (out_elempack == 8)
         {
             __fp16 tmp[96 + 24];
-            for (int r = 0; r < 12; r++) {
-                amx_stz(false, 2*r, tmp + r * 8);
+            for (int r = 0; r < 12; r++)
+            {
+                amx_stz(false, 2 * r, tmp + r * 8);
             }
             memcpy(outptr0, tmp, 96 * sizeof(__fp16));
             outptr0 += 96;
@@ -3097,8 +3098,9 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c
         if (out_elempack == 4)
        {
             __fp16 tmp[32];
-            for (int r = 0; r < 12; r++) {
-                amx_stz(false, 2*r, tmp);
+            for (int r = 0; r < 12; r++)
+            {
+                amx_stz(false, 2 * r, tmp);
                 float16x8_t _tmp = vld1q_f16(tmp);
                 vst1_f16(outptr0 + 4 * r, vget_low_f16(_tmp));
                 vst1_f16(outptr0 + out_hstep * 4 + 4 * r, vget_high_f16(_tmp));
@@ -3167,8 +3169,9 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c
         else
         {
             __fp16 tmp[32];
-            for (int r = 0; r < 12; r++) {
-                amx_stz(false, 2*r, tmp);
+            for (int r = 0; r < 12; r++)
+            {
+                amx_stz(false, 2 * r, tmp);
                 memcpy(outptr0 + 8 * r, tmp, 8 * sizeof(__fp16));
             }
         }
@@ -4915,7 +4918,7 @@ static int convolution_im2col_gemm_fp16sa(const Mat& bottom_blob, Mat& top_blob,
                 bool k_end = k + TILE_K >= K;
 
 #if __aarch64__ && NCNN_APPLE_AMX
-// #if 0
+                // #if 0
                 if (amx_supported)
                 {
                     convolution_gemm_transB_packed_tile_fp16sa_amx(AT_tile, BT_tile, bias, topT_tile, top_blob, i, max_ii, j, max_jj, k, max_kk, k_end);
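
Note: the sketch below shows how the helpers patched above compose into a single fp16 outer-product accumulate. It is illustrative only — the function name, the buffer arguments, the include path, and the raw AMX_FMA16(0) operand (matrix-mode multiply-accumulate with x/y/z offsets of zero, per the corsix notes the header cites) are assumptions, not part of this patch.

// Hypothetical usage sketch for amx_usability.h, assuming an Apple silicon
// build with __aarch64__ && NCNN_APPLE_AMX and the header on the include path.
#include "amx_usability.h"

// a and b each point to 32 __fp16 values (one 64-byte AMX row);
// out receives 32 __fp16 values.
void amx_outer_product_row0_sketch(const __fp16* a, const __fp16* b, __fp16* out)
{
    amx_set(); // enable AMX; per the corsix notes this starts from zeroed state

    amx_ldx(false, 0, a); // load 64 bytes into x row 0
    amx_ldy(false, 0, b); // load 64 bytes into y row 0

    // Operand 0 selects matrix mode with all offsets zero: z accumulates
    // x[i] * y[j], landing in the even z rows (2 * j) — the same layout the
    // convolution code above addresses with amx_ldz/amx_stz(false, 2 * r, ...).
    AMX_FMA16(0);

    amx_stz(false, 0, out); // z row 0 now holds a[i] * b[0] for i = 0..31

    amx_clr(); // release AMX state when done
}

The row index travelling in the top byte of the operand (row << 56, pair flag at bit 62) is the same encoding PTR_ROW_FLAGS builds — flags * 64 shifted by 56 lands on bit 62 — which is why the helpers can take a plain pointer and fold the row bits in themselves.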