|
| 1 | +import os |
| 2 | +from typing import Dict, List, Optional |
| 3 | + |
| 4 | +import mlx.core as mx |
| 5 | + |
| 6 | +_KERNELS: Dict[str, object] = {} |
| 7 | + |
| 8 | + |
| 9 | +def _get_metal_source(filename): |
| 10 | + path = os.path.join(os.path.dirname(__file__), filename) |
| 11 | + with open(path, "r") as f: |
| 12 | + return f.read() |
| 13 | + |
| 14 | + |
| 15 | +def _type_to_string(dtype: mx.Dtype) -> str: |
| 16 | + if dtype == mx.float32: |
| 17 | + return "float" |
| 18 | + elif dtype == mx.float16: |
| 19 | + return "half" |
| 20 | + elif dtype == mx.bfloat16: |
| 21 | + return "bfloat16_t" |
| 22 | + else: |
| 23 | + raise ValueError(f"Unsupported dtype: {dtype}") |
| 24 | + |
| 25 | + |
| 26 | +def _get_kernel( |
| 27 | + name: str, |
| 28 | + filename: str, |
| 29 | + input_names: List[str], |
| 30 | + output_names: List[str], |
| 31 | + dtype: mx.Dtype = mx.float32, |
| 32 | +): |
| 33 | + type_str = _type_to_string(dtype) |
| 34 | + kernel_key = f"{name}_{type_str}" |
| 35 | + |
| 36 | + if kernel_key not in _KERNELS: |
| 37 | + source = _get_metal_source(filename) |
| 38 | + source = source.replace("{{T}}", type_str) |
| 39 | + |
| 40 | + header = """ |
| 41 | +#include <metal_stdlib> |
| 42 | +using namespace metal; |
| 43 | +""" |
| 44 | + _KERNELS[kernel_key] = mx.fast.metal_kernel( |
| 45 | + name=name, |
| 46 | + input_names=input_names, |
| 47 | + output_names=output_names, |
| 48 | + source=source, |
| 49 | + header=header, |
| 50 | + ) |
| 51 | + return _KERNELS[kernel_key] |
| 52 | + |
| 53 | + |
| 54 | +def store_indexer_cache( |
| 55 | + key: mx.array, |
| 56 | + key_cache: mx.array, |
| 57 | + block_tables: mx.array, |
| 58 | + context_lengths: mx.array, |
| 59 | + block_size: int, |
| 60 | + layer_idx: int, |
| 61 | + slot_mapping: Optional[mx.array] = None, |
| 62 | +): |
| 63 | + dtype = key.dtype |
| 64 | + # key: (batch, target_len, num_heads, head_dim) or flattened |
| 65 | + |
| 66 | + if slot_mapping is None: |
| 67 | + # Decode Mode |
| 68 | + batch_size = key.shape[0] |
| 69 | + if key.ndim == 4: |
| 70 | + # (batch, 1, num_kv_heads, head_dim) -> (batch, num_kv_heads, head_dim) |
| 71 | + if key.shape[1] == 1: |
| 72 | + key = key.squeeze(1) |
| 73 | + elif key.shape[2] == 1: |
| 74 | + # Fallback for old layout (batch, num_kv_heads, 1, head_dim) |
| 75 | + key = key.squeeze(2) |
| 76 | + |
| 77 | + num_heads = key.shape[1] |
| 78 | + head_dim = key.shape[2] |
| 79 | + |
| 80 | + # Compute slot_mapping internally |
| 81 | + indices = context_lengths - 1 |
| 82 | + block_indices_in_table = indices // block_size |
| 83 | + offsets = indices % block_size |
| 84 | + batch_indices = mx.arange(batch_size) |
| 85 | + physical_block_numbers = block_tables[batch_indices, block_indices_in_table] |
| 86 | + slot_mapping = physical_block_numbers.astype(mx.int32) * block_size + offsets.astype( |
| 87 | + mx.int32 |
| 88 | + ) |
| 89 | + |
| 90 | + num_tokens = batch_size |
| 91 | + else: |
| 92 | + # Prefill Mode |
| 93 | + if key.ndim == 4: |
| 94 | + B, T, H, D = key.shape |
| 95 | + key = key.reshape(B * T, H, D) |
| 96 | + |
| 97 | + num_tokens = key.shape[0] |
| 98 | + num_heads = key.shape[1] |
| 99 | + head_dim = key.shape[2] |
| 100 | + |
| 101 | + num_layers = key_cache.shape[0] |
| 102 | + num_blocks = key_cache.shape[1] |
| 103 | + |
| 104 | + key_stride = num_heads * head_dim |
| 105 | + |
| 106 | + def mk_int(val): |
| 107 | + return mx.array(val, dtype=mx.int32) |
| 108 | + |
| 109 | + inputs = [ |
| 110 | + key, |
| 111 | + key_cache, |
| 112 | + slot_mapping, |
| 113 | + mk_int(key_stride), |
| 114 | + mk_int(num_heads), |
| 115 | + mk_int(head_dim), |
| 116 | + mk_int(block_size), |
| 117 | + mk_int(layer_idx), |
| 118 | + mk_int(num_layers), |
| 119 | + mk_int(num_blocks), |
| 120 | + ] |
| 121 | + |
| 122 | + input_names = [ |
| 123 | + "key", |
| 124 | + "key_cache", |
| 125 | + "slot_mapping", |
| 126 | + "key_stride", |
| 127 | + "num_heads", |
| 128 | + "head_dim", |
| 129 | + "block_size", |
| 130 | + "layer_idx", |
| 131 | + "num_layers", |
| 132 | + "num_blocks", |
| 133 | + ] |
| 134 | + |
| 135 | + kernel = _get_kernel( |
| 136 | + name="store_key_kernel", |
| 137 | + filename="store_key.metal", |
| 138 | + input_names=input_names, |
| 139 | + output_names=["dummy_out"], |
| 140 | + dtype=dtype, |
| 141 | + ) |
| 142 | + |
| 143 | + grid = (num_heads * head_dim, num_tokens, 1) |
| 144 | + thread_group = (min(1024, num_heads * head_dim), 1, 1) |
| 145 | + |
| 146 | + outputs = kernel( |
| 147 | + inputs=inputs, |
| 148 | + grid=grid, |
| 149 | + threadgroup=thread_group, |
| 150 | + output_shapes=[(num_tokens, num_heads * head_dim)], # Dummy output |
| 151 | + output_dtypes=[mx.float32], |
| 152 | + verbose=False, |
| 153 | + ) |
| 154 | + mx.eval(outputs) |
| 155 | + |
| 156 | + |
| 157 | +def q_dot_k( |
| 158 | + q: mx.array, # (num_heads, head_dim) |
| 159 | + key_cache: mx.array, # (L, B, H, BS, D) |
| 160 | + block_table: mx.array, # (max_blocks) |
| 161 | + context_length: mx.array, # scalar |
| 162 | + block_size: int, |
| 163 | + layer_idx: int, |
| 164 | +) -> mx.array: |
| 165 | + |
| 166 | + if q.ndim > 2: |
| 167 | + q = q.squeeze() # Ensure (H, D) |
| 168 | + |
| 169 | + num_heads = q.shape[0] |
| 170 | + head_dim = q.shape[1] |
| 171 | + |
| 172 | + num_layers = key_cache.shape[0] |
| 173 | + num_total_blocks = key_cache.shape[1] |
| 174 | + max_blocks = block_table.shape[0] |
| 175 | + |
| 176 | + ctx_len = int(context_length.item()) |
| 177 | + |
| 178 | + def mk_int(val): |
| 179 | + return mx.array(val, dtype=mx.int32) |
| 180 | + |
| 181 | + inputs = [ |
| 182 | + q, |
| 183 | + key_cache, |
| 184 | + block_table, |
| 185 | + mk_int(ctx_len), |
| 186 | + mk_int(block_size), |
| 187 | + mk_int(num_heads), |
| 188 | + mk_int(head_dim), |
| 189 | + mk_int(layer_idx), |
| 190 | + mk_int(num_layers), |
| 191 | + mk_int(num_total_blocks), |
| 192 | + mk_int(max_blocks), |
| 193 | + ] |
| 194 | + |
| 195 | + input_names = [ |
| 196 | + "q", |
| 197 | + "key_cache", |
| 198 | + "block_table", |
| 199 | + "context_len", |
| 200 | + "block_size", |
| 201 | + "num_heads", |
| 202 | + "head_dim", |
| 203 | + "layer_idx", |
| 204 | + "num_layers", |
| 205 | + "num_total_blocks", |
| 206 | + "max_blocks", |
| 207 | + ] |
| 208 | + |
| 209 | + kernel = _get_kernel( |
| 210 | + name="q_dot_k_kernel", |
| 211 | + filename="q_dot_k.metal", |
| 212 | + input_names=input_names, |
| 213 | + output_names=["output"], |
| 214 | + dtype=q.dtype, |
| 215 | + ) |
| 216 | + |
| 217 | + # Grid: (block_size, num_heads, 1) |
| 218 | + grid = (block_size, num_heads, 1) |
| 219 | + thread_group = (min(1024, block_size), 1, 1) |
| 220 | + |
| 221 | + outputs = kernel( |
| 222 | + inputs=inputs, |
| 223 | + grid=grid, |
| 224 | + threadgroup=thread_group, |
| 225 | + output_shapes=[(num_heads, ctx_len)], |
| 226 | + output_dtypes=[mx.float32], # Score is float32 |
| 227 | + verbose=False, |
| 228 | + ) |
| 229 | + |
| 230 | + return outputs[0] |
0 commit comments