@@ -34,9 +34,6 @@ void main() {
34
34
const uint q_offset = 32*v_im + l0;
35
35
const uint y_offset = 64*v_im + l0;
36
36
37
- const uint8_t hm1 = uint8_t(1 << (2*v_im));
38
- const uint8_t hm2 = uint8_t(hm1 << 4);
39
-
40
37
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
41
38
42
39
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
@@ -71,6 +68,18 @@ void main() {
71
68
uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
72
69
uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
73
70
71
+ uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
72
+
73
+ uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
74
+ uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
75
+ uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0;
76
+ uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
77
+
78
+ qs0_16_u32_lo4 += qs0_16_lo4_offset16;
79
+ qs0_16_u32_hi4 += qs0_16_hi4_offset16;
80
+ qs64_80_u32_lo4 += qs64_80_lo4_offset16;
81
+ qs64_80_u32_hi4 += qs64_80_hi4_offset16;
82
+
74
83
uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
75
84
uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
76
85
uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
@@ -102,31 +111,26 @@ void main() {
102
111
B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
103
112
B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
104
113
105
- uint32_t qh0 = data_a_packed16[ib0 + i].qh[l0 / 2];
106
- uint32_t qh1 = qh0 >> 8;
107
- uint32_t qh16 = data_a_packed16[ib0 + i].qh[l0 / 2 + 8];
108
- uint32_t qh17 = qh16 >> 8;
109
-
110
114
const FLOAT_TYPE sx =
111
- fma(FLOAT_TYPE(by10.x), ( q4_0 + (((qh0 & hm1) != 0) ? 16 : 0)) ,
112
- fma(FLOAT_TYPE(by10.y), ( q4_1 + (((qh1 & hm1) != 0) ? 16 : 0)) ,
113
- fma(FLOAT_TYPE(by116.x), ( q4_2 + (((qh16 & hm1) != 0) ? 16 : 0)) ,
114
- FLOAT_TYPE(by116.y) * ( q4_3 + (((qh17 & hm1) != 0) ? 16 : 0)) )));
115
+ fma(FLOAT_TYPE(by10.x), q4_0,
116
+ fma(FLOAT_TYPE(by10.y), q4_1,
117
+ fma(FLOAT_TYPE(by116.x), q4_2,
118
+ FLOAT_TYPE(by116.y) * q4_3)));
115
119
const FLOAT_TYPE sy =
116
- fma(FLOAT_TYPE(by132.x), ( q4_4 + (((qh0 & (hm1 << 1)) != 0) ? 16 : 0)) ,
117
- fma(FLOAT_TYPE(by132.y), ( q4_5 + (((qh1 & (hm1 << 1)) != 0) ? 16 : 0)) ,
118
- fma(FLOAT_TYPE(by148.x), ( q4_6 + (((qh16 & (hm1 << 1)) != 0) ? 16 : 0)) ,
119
- FLOAT_TYPE(by148.y) * ( q4_7 + (((qh17 & (hm1 << 1)) != 0) ? 16 : 0)) )));
120
+ fma(FLOAT_TYPE(by132.x), q4_4,
121
+ fma(FLOAT_TYPE(by132.y), q4_5,
122
+ fma(FLOAT_TYPE(by148.x), q4_6,
123
+ FLOAT_TYPE(by148.y) * q4_7)));
120
124
const FLOAT_TYPE sz =
121
- fma(FLOAT_TYPE(by20.x), ( q4_8 + (((qh0 & hm2) != 0) ? 16 : 0)) ,
122
- fma(FLOAT_TYPE(by20.y), ( q4_9 + (((qh1 & hm2) != 0) ? 16 : 0)) ,
123
- fma(FLOAT_TYPE(by216.x), ( q4_10 + (((qh16 & hm2) != 0) ? 16 : 0)) ,
124
- FLOAT_TYPE(by216.y) * ( q4_11 + (((qh17 & hm2) != 0) ? 16 : 0)) )));
125
+ fma(FLOAT_TYPE(by20.x), q4_8,
126
+ fma(FLOAT_TYPE(by20.y), q4_9,
127
+ fma(FLOAT_TYPE(by216.x), q4_10,
128
+ FLOAT_TYPE(by216.y) * q4_11)));
125
129
const FLOAT_TYPE sw =
126
- fma(FLOAT_TYPE(by232.x), ( q4_12 + (((qh0 & (hm2 << 1)) != 0) ? 16 : 0)) ,
127
- fma(FLOAT_TYPE(by232.y), ( q4_13 + (((qh1 & (hm2 << 1)) != 0) ? 16 : 0)) ,
128
- fma(FLOAT_TYPE(by248.x), ( q4_14 + (((qh16 & (hm2 << 1)) != 0) ? 16 : 0)) ,
129
- FLOAT_TYPE(by248.y) * ( q4_15 + (((qh17 & (hm2 << 1)) != 0) ? 16 : 0)) )));
130
+ fma(FLOAT_TYPE(by232.x), q4_12,
131
+ fma(FLOAT_TYPE(by232.y), q4_13,
132
+ fma(FLOAT_TYPE(by248.x), q4_14,
133
+ FLOAT_TYPE(by248.y) * q4_15)));
130
134
const FLOAT_TYPE smin =
131
135
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
132
136
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
0 commit comments