|
1 | 1 | #version 450
|
| 2 | +#extension GL_EXT_shader_explicit_arithmetic_types : require |
2 | 3 |
|
3 | 4 | #include "mul_mat_vec_base.comp"
|
4 | 5 |
|
@@ -32,38 +33,67 @@ void main() {
|
32 | 33 | const uint s_offset = 8*v_im;
|
33 | 34 | const uint y_offset = 128*v_im + l0;
|
34 | 35 |
|
35 |
| - tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp |
| 36 | + FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp |
36 | 37 |
|
37 | 38 | [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
38 | 39 | const uint y_idx = i * QUANT_K + y_offset;
|
39 | 40 |
|
40 |
| - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); |
41 |
| - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); |
| 41 | + f16vec2 d = data_a[ib0 + i].d; |
| 42 | + const FLOAT_TYPE dall = d.x; |
| 43 | + const FLOAT_TYPE dmin = d.y; |
| 44 | + |
| 45 | + B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0]; |
| 46 | + B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8]; |
| 47 | + B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16]; |
| 48 | + B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24]; |
| 49 | + B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32]; |
| 50 | + B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40]; |
| 51 | + B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48]; |
| 52 | + B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56]; |
| 53 | + |
| 54 | + uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0]; |
| 55 | + uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1]; |
| 56 | + |
| 57 | + uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F; |
| 58 | + uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F; |
| 59 | + uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F; |
| 60 | + uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F; |
| 61 | + |
| 62 | + uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32)); |
| 63 | + uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32)); |
| 64 | + uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32)); |
| 65 | + uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32)); |
| 66 | + |
| 67 | + uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0]; |
| 68 | + uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; |
| 69 | + uvec2 qs0 = uvec2(unpack8(qs0_u16)); |
| 70 | + uvec2 qs16 = uvec2(unpack8(qs16_u16)); |
42 | 71 |
|
43 | 72 | FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
|
44 | 73 | FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
|
45 |
| - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { |
46 |
| - sum1 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3), |
47 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3), |
48 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3), |
49 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3), |
50 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3), |
51 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3), |
52 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3), |
53 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3), sum1)))))))); |
54 |
| - sum2 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF), |
55 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF), |
56 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF), |
57 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF), |
58 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF), |
59 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF), |
60 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF), |
61 |
| - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF), sum2)))))))); |
| 74 | + [[unroll]] for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { |
| 75 | + sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), |
| 76 | + fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), |
| 77 | + fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), |
| 78 | + fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3), |
| 79 | + fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3), |
| 80 | + fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3), |
| 81 | + fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3), |
| 82 | + fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); |
| 83 | + sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]), |
| 84 | + fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]), |
| 85 | + fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]), |
| 86 | + fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]), |
| 87 | + fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]), |
| 88 | + fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]), |
| 89 | + fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]), |
| 90 | + fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2)))))))); |
62 | 91 | }
|
63 |
| - const uint tmp_idx = 16 * ix + tid; |
64 |
| - tmp[tmp_idx] = fma(dall, sum1, fma(-dmin, sum2, tmp[tmp_idx])); |
| 92 | + temp = fma(dall, sum1, fma(-dmin, sum2, temp)); |
65 | 93 | }
|
66 | 94 |
|
| 95 | + tmp[gl_LocalInvocationID.x] = temp; |
| 96 | + |
67 | 97 | // sum up partial sums and write back result
|
68 | 98 | barrier();
|
69 | 99 | [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
|
|
0 commit comments