16
16
#define AMX_USABILITY_H
17
17
18
18
// From https://github.com/corsix/amx/blob/main/aarch64.h
19
- #define AMX_NOP_OP_IMM5 (op , imm5 ) \
20
- __asm("nop\nnop\nnop\n.word (0x201000 + (%0 << 5) + %1)" : : "i"(op), "i"(imm5) : "memory")
21
-
22
- #define AMX_OP_GPR (op , gpr ) \
23
- __asm(".word (0x201000 + (%0 << 5) + 0%1 - ((0%1 >> 4) * 6))" : : "i"(op), "r"((uint64_t)(gpr)) : "memory")
24
-
25
- #define AMX_LDX (gpr ) AMX_OP_GPR( 0, gpr)
26
- #define AMX_LDY (gpr ) AMX_OP_GPR( 1, gpr)
27
- #define AMX_STX (gpr ) AMX_OP_GPR( 2, gpr)
28
- #define AMX_STY (gpr ) AMX_OP_GPR( 3, gpr)
29
- #define AMX_LDZ (gpr ) AMX_OP_GPR( 4, gpr)
30
- #define AMX_STZ (gpr ) AMX_OP_GPR( 5, gpr)
31
- #define AMX_LDZI (gpr ) AMX_OP_GPR( 6, gpr)
32
- #define AMX_STZI (gpr ) AMX_OP_GPR( 7, gpr)
33
- #define AMX_EXTRX (gpr ) AMX_OP_GPR( 8, gpr)
34
- #define AMX_EXTRY (gpr ) AMX_OP_GPR( 9, gpr)
35
- #define AMX_FMA64 (gpr ) AMX_OP_GPR(10, gpr)
36
- #define AMX_FMS64 (gpr ) AMX_OP_GPR(11, gpr)
37
- #define AMX_FMA32 (gpr ) AMX_OP_GPR(12, gpr)
38
- #define AMX_FMS32 (gpr ) AMX_OP_GPR(13, gpr)
39
- #define AMX_MAC16 (gpr ) AMX_OP_GPR(14, gpr)
40
- #define AMX_FMA16 (gpr ) AMX_OP_GPR(15, gpr)
41
- #define AMX_FMS16 (gpr ) AMX_OP_GPR(16, gpr)
42
- #define AMX_VECINT (gpr ) AMX_OP_GPR(18, gpr)
43
- #define AMX_VECFP (gpr ) AMX_OP_GPR(19, gpr)
44
- #define AMX_MATINT (gpr ) AMX_OP_GPR(20, gpr)
45
- #define AMX_MATFP (gpr ) AMX_OP_GPR(21, gpr)
46
- #define AMX_GENLUT (gpr ) AMX_OP_GPR(22, gpr)
47
- #define PTR_ROW_FLAGS (ptr , row , flags ) (((uint64_t)&*(ptr)) + (((uint64_t)((row) + (flags) * 64)) << 56))
19
+ #define AMX_NOP_OP_IMM5 (op , imm5 ) \
20
+ __asm("nop\nnop\nnop\n.word (0x201000 + (%0 << 5) + %1)" \
21
+ : \
22
+ : "i"(op), "i"(imm5) \
23
+ : "memory")
24
+
25
+ #define AMX_OP_GPR (op , gpr ) \
26
+ __asm(".word (0x201000 + (%0 << 5) + 0%1 - ((0%1 >> 4) * 6))" \
27
+ : \
28
+ : "i"(op), "r"((uint64_t)(gpr)) \
29
+ : "memory")
30
+
31
+ #define AMX_LDX (gpr ) AMX_OP_GPR(0, gpr)
32
+ #define AMX_LDY (gpr ) AMX_OP_GPR(1, gpr)
33
+ #define AMX_STX (gpr ) AMX_OP_GPR(2, gpr)
34
+ #define AMX_STY (gpr ) AMX_OP_GPR(3, gpr)
35
+ #define AMX_LDZ (gpr ) AMX_OP_GPR(4, gpr)
36
+ #define AMX_STZ (gpr ) AMX_OP_GPR(5, gpr)
37
+ #define AMX_LDZI (gpr ) AMX_OP_GPR(6, gpr)
38
+ #define AMX_STZI (gpr ) AMX_OP_GPR(7, gpr)
39
+ #define AMX_EXTRX (gpr ) AMX_OP_GPR(8, gpr)
40
+ #define AMX_EXTRY (gpr ) AMX_OP_GPR(9, gpr)
41
+ #define AMX_FMA64 (gpr ) AMX_OP_GPR(10, gpr)
42
+ #define AMX_FMS64 (gpr ) AMX_OP_GPR(11, gpr)
43
+ #define AMX_FMA32 (gpr ) AMX_OP_GPR(12, gpr)
44
+ #define AMX_FMS32 (gpr ) AMX_OP_GPR(13, gpr)
45
+ #define AMX_MAC16 (gpr ) AMX_OP_GPR(14, gpr)
46
+ #define AMX_FMA16 (gpr ) AMX_OP_GPR(15, gpr)
47
+ #define AMX_FMS16 (gpr ) AMX_OP_GPR(16, gpr)
48
+ #define AMX_VECINT (gpr ) AMX_OP_GPR(18, gpr)
49
+ #define AMX_VECFP (gpr ) AMX_OP_GPR(19, gpr)
50
+ #define AMX_MATINT (gpr ) AMX_OP_GPR(20, gpr)
51
+ #define AMX_MATFP (gpr ) AMX_OP_GPR(21, gpr)
52
+ #define AMX_GENLUT (gpr ) AMX_OP_GPR(22, gpr)
53
+ #define PTR_ROW_FLAGS (ptr , row , flags ) (((uint64_t) & *(ptr)) + (((uint64_t)((row) + (flags)*64)) << 56))
48
54
void amx_set ()
49
55
{
50
56
AMX_NOP_OP_IMM5 (17 , 0 );
@@ -55,51 +61,51 @@ void amx_clr()
55
61
AMX_NOP_OP_IMM5 (17 , 1 );
56
62
}
57
63
58
- void amx_ldx (bool pair , unsigned int x_row , const void * ptr )
64
+ void amx_ldx (bool pair , unsigned int x_row , const void * ptr )
59
65
{
60
66
if (x_row >= 8 )
61
67
return ;
62
68
63
69
uint64_t oprand = (uint64_t )ptr + ((uint64_t )x_row << 56 );
64
70
if (pair )
65
71
oprand |= 1ULL << 62 ;
66
-
72
+
67
73
AMX_LDX (oprand );
68
74
}
69
75
70
- void amx_ldy (bool pair , unsigned int y_row , const void * ptr )
76
+ void amx_ldy (bool pair , unsigned int y_row , const void * ptr )
71
77
{
72
78
if (y_row >= 8 )
73
79
return ;
74
80
75
81
uint64_t oprand = (uint64_t )ptr + ((uint64_t )y_row << 56 );
76
82
if (pair )
77
83
oprand |= 1ULL << 62 ;
78
-
84
+
79
85
AMX_LDY (oprand );
80
86
}
81
87
82
- void amx_ldz (bool pair , unsigned int z_row , const void * ptr )
88
+ void amx_ldz (bool pair , unsigned int z_row , const void * ptr )
83
89
{
84
90
if (z_row >= 64 )
85
91
return ;
86
92
87
93
uint64_t oprand = (uint64_t )ptr + ((uint64_t )z_row << 56 );
88
94
if (pair )
89
95
oprand |= 1ULL << 62 ;
90
-
96
+
91
97
AMX_LDZ (oprand );
92
98
}
93
99
94
- void amx_stz (bool pair , unsigned int z_row , const void * ptr )
100
+ void amx_stz (bool pair , unsigned int z_row , const void * ptr )
95
101
{
96
102
if (z_row >= 64 )
97
103
return ;
98
104
99
105
uint64_t oprand = (uint64_t )ptr + ((uint64_t )z_row << 56 );
100
106
if (pair )
101
107
oprand |= 1ULL << 62 ;
102
-
108
+
103
109
AMX_STZ (oprand );
104
110
}
105
111
@@ -116,7 +122,7 @@ void amx_fma16_masked(bool vector, unsigned int x_offset, unsigned int y_offset,
116
122
oprand |= ((uint64_t )y_mode & 0x3 ) << 37 ;
117
123
oprand |= ((uint64_t )x_mask & 0x1F ) << 41 ;
118
124
oprand |= ((uint64_t )x_mode & 0x3 ) << 46 ;
119
-
125
+
120
126
AMX_FMA16 (oprand );
121
127
}
122
128
@@ -138,7 +144,7 @@ void amx_fma32_masked(bool vector, unsigned int x_offset, unsigned int y_offset,
138
144
oprand |= ((uint64_t )y_mode & 0x3 ) << 37 ;
139
145
oprand |= ((uint64_t )x_mask & 0x1F ) << 41 ;
140
146
oprand |= ((uint64_t )x_mode & 0x3 ) << 46 ;
141
-
147
+
142
148
AMX_FMA32 (oprand );
143
149
}
144
150
0 commit comments