Skip to content

Commit 249a790

Browse files
authored
vulkan: further optimize q5_k mul_mat_vec (#10479)
1 parent 71a6498 commit 249a790

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp

+28-24
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,6 @@ void main() {
3434
const uint q_offset = 32*v_im + l0;
3535
const uint y_offset = 64*v_im + l0;
3636

37-
const uint8_t hm1 = uint8_t(1 << (2*v_im));
38-
const uint8_t hm2 = uint8_t(hm1 << 4);
39-
4037
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
4138

4239
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
@@ -71,6 +68,18 @@ void main() {
7168
uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
7269
uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
7370

71+
uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
72+
73+
uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
74+
uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
75+
uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0;
76+
uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
77+
78+
qs0_16_u32_lo4 += qs0_16_lo4_offset16;
79+
qs0_16_u32_hi4 += qs0_16_hi4_offset16;
80+
qs64_80_u32_lo4 += qs64_80_lo4_offset16;
81+
qs64_80_u32_hi4 += qs64_80_hi4_offset16;
82+
7483
uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
7584
uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
7685
uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
@@ -102,31 +111,26 @@ void main() {
102111
B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
103112
B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
104113

105-
uint32_t qh0 = data_a_packed16[ib0 + i].qh[l0 / 2];
106-
uint32_t qh1 = qh0 >> 8;
107-
uint32_t qh16 = data_a_packed16[ib0 + i].qh[l0 / 2 + 8];
108-
uint32_t qh17 = qh16 >> 8;
109-
110114
const FLOAT_TYPE sx =
111-
fma(FLOAT_TYPE(by10.x), (q4_0 + (((qh0 & hm1) != 0) ? 16 : 0)),
112-
fma(FLOAT_TYPE(by10.y), (q4_1 + (((qh1 & hm1) != 0) ? 16 : 0)),
113-
fma(FLOAT_TYPE(by116.x), (q4_2 + (((qh16 & hm1) != 0) ? 16 : 0)),
114-
FLOAT_TYPE(by116.y) * (q4_3 + (((qh17 & hm1) != 0) ? 16 : 0)))));
115+
fma(FLOAT_TYPE(by10.x), q4_0,
116+
fma(FLOAT_TYPE(by10.y), q4_1,
117+
fma(FLOAT_TYPE(by116.x), q4_2,
118+
FLOAT_TYPE(by116.y) * q4_3)));
115119
const FLOAT_TYPE sy =
116-
fma(FLOAT_TYPE(by132.x), (q4_4 + (((qh0 & (hm1 << 1)) != 0) ? 16 : 0)),
117-
fma(FLOAT_TYPE(by132.y), (q4_5 + (((qh1 & (hm1 << 1)) != 0) ? 16 : 0)),
118-
fma(FLOAT_TYPE(by148.x), (q4_6 + (((qh16 & (hm1 << 1)) != 0) ? 16 : 0)),
119-
FLOAT_TYPE(by148.y) * (q4_7 + (((qh17 & (hm1 << 1)) != 0) ? 16 : 0)))));
120+
fma(FLOAT_TYPE(by132.x), q4_4,
121+
fma(FLOAT_TYPE(by132.y), q4_5,
122+
fma(FLOAT_TYPE(by148.x), q4_6,
123+
FLOAT_TYPE(by148.y) * q4_7)));
120124
const FLOAT_TYPE sz =
121-
fma(FLOAT_TYPE(by20.x), (q4_8 + (((qh0 & hm2) != 0) ? 16 : 0)),
122-
fma(FLOAT_TYPE(by20.y), (q4_9 + (((qh1 & hm2) != 0) ? 16 : 0)),
123-
fma(FLOAT_TYPE(by216.x), (q4_10 + (((qh16 & hm2) != 0) ? 16 : 0)),
124-
FLOAT_TYPE(by216.y) * (q4_11 + (((qh17 & hm2) != 0) ? 16 : 0)))));
125+
fma(FLOAT_TYPE(by20.x), q4_8,
126+
fma(FLOAT_TYPE(by20.y), q4_9,
127+
fma(FLOAT_TYPE(by216.x), q4_10,
128+
FLOAT_TYPE(by216.y) * q4_11)));
125129
const FLOAT_TYPE sw =
126-
fma(FLOAT_TYPE(by232.x), (q4_12 + (((qh0 & (hm2 << 1)) != 0) ? 16 : 0)),
127-
fma(FLOAT_TYPE(by232.y), (q4_13 + (((qh1 & (hm2 << 1)) != 0) ? 16 : 0)),
128-
fma(FLOAT_TYPE(by248.x), (q4_14 + (((qh16 & (hm2 << 1)) != 0) ? 16 : 0)),
129-
FLOAT_TYPE(by248.y) * (q4_15 + (((qh17 & (hm2 << 1)) != 0) ? 16 : 0)))));
130+
fma(FLOAT_TYPE(by232.x), q4_12,
131+
fma(FLOAT_TYPE(by232.y), q4_13,
132+
fma(FLOAT_TYPE(by248.x), q4_14,
133+
FLOAT_TYPE(by248.y) * q4_15)));
130134
const FLOAT_TYPE smin =
131135
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
132136
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,

0 commit comments

Comments
 (0)