1 | 1 | # Copyright © 2025 Apple Inc. |
2 | | -import math |
3 | | -from dataclasses import dataclass |
4 | | -from typing import Any, Dict, Optional, Tuple |
| 2 | +from typing import Any, Optional, Tuple |
5 | 3 |
6 | 4 | import mlx.core as mx |
7 | | -import mlx.nn as nn |
8 | | -from mlx_lm.models.base import BaseModelArgs, scaled_dot_product_attention |
9 | | -from mlx_lm.models.rope_utils import initialize_rope |
10 | | -from mlx_lm.models.switch_layers import SwitchGLU |
| 5 | +from mlx_lm.models.base import scaled_dot_product_attention |
| 6 | +from mlx_lm.models.deepseek_v32 import DeepseekV32Attention as MLXDeepseekV32Attention |
| 7 | +from mlx_lm.models.deepseek_v32 import DeepseekV32DecoderLayer as MLXDeepseekV32Block |
| 8 | +from mlx_lm.models.deepseek_v32 import Indexer as MLXDeepseekV32Indexer |
| 9 | +from mlx_lm.models.deepseek_v32 import ModelArgs |
11 | 10 |
12 | 11 | from parallax.metal.indexer.kernel import q_dot_k, store_indexer_cache |
13 | 12 | from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache |
14 | 13 |
15 | 14 |
16 | | -@dataclass |
17 | | -class ModelArgs(BaseModelArgs): |
18 | | - model_type: str = "deepseek_v32" |
19 | | - vocab_size: int = 102400 |
20 | | - hidden_size: int = 4096 |
21 | | - index_n_heads: int = 64 |
22 | | - index_head_dim: int = 128 |
23 | | - index_topk: int = 2048 |
24 | | - intermediate_size: int = 11008 |
25 | | - moe_intermediate_size: int = 1407 |
26 | | - num_hidden_layers: int = 30 |
27 | | - num_attention_heads: int = 32 |
28 | | - num_key_value_heads: int = 32 |
29 | | - n_shared_experts: Optional[int] = None |
30 | | - n_routed_experts: Optional[int] = None |
31 | | - routed_scaling_factor: float = 1.0 |
32 | | - kv_lora_rank: int = 512 |
33 | | - q_lora_rank: int = 1536 |
34 | | - qk_rope_head_dim: int = 64 |
35 | | - v_head_dim: int = 128 |
36 | | - qk_nope_head_dim: int = 128 |
37 | | - topk_method: str = "noaux_tc" |
38 | | - scoring_func: str = "sigmoid" |
39 | | - norm_topk_prob: bool = True |
40 | | - n_group: int = 1 |
41 | | - topk_group: int = 1 |
42 | | - num_experts_per_tok: int = 1 |
43 | | - moe_layer_freq: int = 1 |
44 | | - first_k_dense_replace: int = 0 |
45 | | - max_position_embeddings: int = 2048 |
46 | | - rms_norm_eps: float = 1e-6 |
47 | | - rope_theta: float = 10000.0 |
48 | | - rope_scaling: Dict = None |
49 | | - attention_bias: bool = False |
50 | | - |
51 | | - |
52 | | -class Indexer(nn.Module): |
53 | | - def __init__(self, args: ModelArgs): |
54 | | - super().__init__() |
55 | | - self.dim = args.hidden_size |
56 | | - self.n_heads = args.index_n_heads |
57 | | - self.head_dim = args.index_head_dim |
58 | | - self.rope_head_dim = args.qk_rope_head_dim |
59 | | - self.index_topk = args.index_topk |
60 | | - self.q_lora_rank = args.q_lora_rank |
61 | | - self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.head_dim, bias=False) |
62 | | - self.wk = nn.Linear(self.dim, self.head_dim, bias=False) |
63 | | - self.k_norm = nn.LayerNorm(self.head_dim) |
64 | | - self.weights_proj = nn.Linear(self.dim, self.n_heads, bias=False) |
65 | | - self.softmax_scale = self.head_dim**-0.5 |
66 | | - self.rope = nn.RoPE( |
67 | | - dims=self.rope_head_dim, |
68 | | - base=args.rope_theta, |
69 | | - traditional=False, # Non-interleaved |
70 | | - ) |
71 | | - self.rope = initialize_rope( |
72 | | - dims=args.qk_rope_head_dim, |
73 | | - base=args.rope_theta, |
74 | | - traditional=False, |
75 | | - max_position_embeddings=args.max_position_embeddings, |
76 | | - scaling_config=args.rope_scaling, |
77 | | - ) |
78 | | - |
| 15 | +class ParallaxDeepSeekV32Indexer(MLXDeepseekV32Indexer): |
79 | 16 | def __call__( |
80 | 17 | self, |
81 | 18 | x: mx.array, |
@@ -168,207 +105,11 @@ def __call__( |
168 | 105 | return mx.argpartition(scores, kth=-self.index_topk, axis=-1)[..., -self.index_topk :] |
169 | 106 |
170 | 107 |
171 | | -class DeepseekV32Attention(nn.Module): |
172 | | - def __init__(self, config: ModelArgs): |
173 | | - super().__init__() |
174 | | - self.config = config |
175 | | - self.hidden_size = config.hidden_size |
176 | | - self.num_heads = config.num_attention_heads |
177 | | - self.max_position_embeddings = config.max_position_embeddings |
178 | | - self.rope_theta = config.rope_theta |
179 | | - self.q_lora_rank = config.q_lora_rank |
180 | | - self.qk_rope_head_dim = config.qk_rope_head_dim |
181 | | - self.kv_lora_rank = config.kv_lora_rank |
182 | | - self.v_head_dim = config.v_head_dim |
183 | | - self.qk_nope_head_dim = config.qk_nope_head_dim |
184 | | - self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim |
185 | | - |
186 | | - self.scale = self.q_head_dim**-0.5 |
187 | | - |
188 | | - if self.q_lora_rank is None: |
189 | | - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) |
190 | | - else: |
191 | | - self.q_a_proj = nn.Linear( |
192 | | - self.hidden_size, self.q_lora_rank, bias=config.attention_bias |
193 | | - ) |
194 | | - self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank, eps=1e-6) |
195 | | - self.q_b_proj = nn.Linear( |
196 | | - self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False |
197 | | - ) |
198 | | - |
199 | | - self.kv_a_proj_with_mqa = nn.Linear( |
200 | | - self.hidden_size, |
201 | | - self.kv_lora_rank + self.qk_rope_head_dim, |
202 | | - bias=config.attention_bias, |
203 | | - ) |
204 | | - self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank, eps=1e-6) |
205 | | - self.kv_b_proj = nn.Linear( |
206 | | - self.kv_lora_rank, |
207 | | - self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), |
208 | | - bias=False, |
209 | | - ) |
210 | | - |
211 | | - self.o_proj = nn.Linear( |
212 | | - self.num_heads * self.v_head_dim, |
213 | | - self.hidden_size, |
214 | | - bias=config.attention_bias, |
215 | | - ) |
216 | | - |
217 | | - if self.config.rope_scaling is not None: |
218 | | - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) |
219 | | - if mscale_all_dim: |
220 | | - scaling_factor = self.config.rope_scaling["factor"] |
221 | | - if scaling_factor > 1: |
222 | | - s = 0.1 * mscale_all_dim * math.log(scaling_factor) + 1.0 |
223 | | - self.scale = self.scale * s * s |
224 | | - |
225 | | - self.indexer = Indexer(config) |
226 | | - self.rope = initialize_rope( |
227 | | - dims=self.qk_rope_head_dim, |
228 | | - base=self.rope_theta, |
229 | | - traditional=True, |
230 | | - max_position_embeddings=self.max_position_embeddings, |
231 | | - scaling_config=self.config.rope_scaling, |
232 | | - ) |
233 | | - |
234 | | - def __call__( |
235 | | - self, |
236 | | - x: mx.array, |
237 | | - mask: Optional[mx.array] = None, |
238 | | - cache: Optional[Any] = None, |
239 | | - ) -> mx.array: |
240 | | - pass |
241 | | - |
242 | | - |
243 | | -class DeepseekV32MLP(nn.Module): |
244 | | - def __init__(self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None): |
245 | | - super().__init__() |
246 | | - self.config = config |
247 | | - self.hidden_size = config.hidden_size if hidden_size is None else hidden_size |
248 | | - self.intermediate_size = ( |
249 | | - config.intermediate_size if intermediate_size is None else intermediate_size |
250 | | - ) |
251 | | - |
252 | | - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
253 | | - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
254 | | - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
255 | | - |
256 | | - def __call__(self, x): |
257 | | - down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) |
258 | | - return down_proj |
259 | | - |
260 | | - |
261 | | -@mx.compile |
262 | | -def group_expert_select( |
263 | | - gates, |
264 | | - e_score_correction_bias, |
265 | | - top_k, |
266 | | - n_group, |
267 | | - topk_group, |
268 | | - routed_scaling_factor, |
269 | | - norm_topk_prob, |
270 | | -): |
271 | | - |
272 | | - scores = mx.sigmoid(gates.astype(mx.float32)) |
273 | | - orig_scores = scores |
274 | | - scores = scores + e_score_correction_bias |
275 | | - if n_group > 1: |
276 | | - scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1)) |
277 | | - group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True) |
278 | | - k = n_group - topk_group |
279 | | - group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :] |
280 | | - scores = mx.put_along_axis(scores, mx.stop_gradient(group_idx), mx.array(0.0), axis=-2) |
281 | | - scores = mx.flatten(scores, -2, -1) |
282 | | - |
283 | | - k = top_k |
284 | | - inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] |
285 | | - scores = mx.take_along_axis(orig_scores, inds, axis=-1) |
286 | | - if top_k > 1 and norm_topk_prob: |
287 | | - denominator = scores.sum(axis=-1, keepdims=True) |
288 | | - scores = scores / denominator |
289 | | - scores = scores * routed_scaling_factor |
290 | | - |
291 | | - return inds, scores |
292 | | - |
293 | | - |
294 | | -class MoEGate(nn.Module): |
295 | | - def __init__(self, config: ModelArgs): |
296 | | - super().__init__() |
297 | | - self.config = config |
298 | | - self.top_k = config.num_experts_per_tok |
299 | | - self.norm_topk_prob = config.norm_topk_prob |
300 | | - self.n_routed_experts = config.n_routed_experts |
301 | | - self.routed_scaling_factor = config.routed_scaling_factor |
302 | | - self.n_group = config.n_group |
303 | | - self.topk_group = config.topk_group |
304 | | - self.weight = mx.zeros((self.n_routed_experts, config.hidden_size)) |
305 | | - self.e_score_correction_bias = mx.zeros((self.n_routed_experts,)) |
306 | | - assert config.topk_method == "noaux_tc", "Unsupported topk method." |
307 | | - |
308 | | - def __call__(self, x): |
309 | | - return group_expert_select( |
310 | | - x @ self.weight.T, |
311 | | - self.e_score_correction_bias, |
312 | | - self.top_k, |
313 | | - self.n_group, |
314 | | - self.topk_group, |
315 | | - self.routed_scaling_factor, |
316 | | - self.norm_topk_prob, |
317 | | - ) |
318 | | - |
319 | | - |
320 | | -class DeepseekV32MoE(nn.Module): |
321 | | - def __init__(self, config: ModelArgs): |
322 | | - super().__init__() |
323 | | - self.config = config |
324 | | - self.num_experts_per_tok = config.num_experts_per_tok |
325 | | - self.switch_mlp = SwitchGLU( |
326 | | - config.hidden_size, |
327 | | - config.moe_intermediate_size, |
328 | | - config.n_routed_experts, |
329 | | - ) |
330 | | - |
331 | | - self.gate = MoEGate(config) |
332 | | - if config.n_shared_experts is not None: |
333 | | - intermediate_size = config.moe_intermediate_size * config.n_shared_experts |
334 | | - self.shared_experts = DeepseekV32MLP(config=config, intermediate_size=intermediate_size) |
| 108 | +class ParallaxDeepSeekV32Attention(MLXDeepseekV32Attention): |
335 | 109 |
336 | | - def __call__(self, x): |
337 | | - inds, scores = self.gate(x) |
338 | | - y = self.switch_mlp(x, inds) |
339 | | - y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype) |
340 | | - if self.config.n_shared_experts is not None: |
341 | | - y = y + self.shared_experts(x) |
342 | | - |
343 | | - return y |
344 | | - |
345 | | - |
346 | | -class DeepseekV32DecoderLayer(nn.Module): |
347 | | - def __init__(self, config: ModelArgs, layer_idx: int): |
348 | | - super().__init__() |
349 | | - self.self_attn = DeepseekV32Attention(config) |
350 | | - self.mlp = ( |
351 | | - DeepseekV32MoE(config) |
352 | | - if ( |
353 | | - config.n_routed_experts is not None |
354 | | - and layer_idx >= config.first_k_dense_replace |
355 | | - and layer_idx % config.moe_layer_freq == 0 |
356 | | - ) |
357 | | - else DeepseekV32MLP(config) |
358 | | - ) |
359 | | - self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
360 | | - self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
361 | | - |
362 | | - def __call__( |
363 | | - self, |
364 | | - x: mx.array, |
365 | | - mask: Optional[mx.array] = None, |
366 | | - cache: Optional[Any] = None, |
367 | | - ) -> mx.array: |
368 | | - pass |
369 | | - |
370 | | - |
371 | | -class ParallaxDeepSeekV32Attention(DeepseekV32Attention): |
| 110 | + def __init__(self, args: ModelArgs): |
| 111 | + super().__init__(args) |
| 112 | + self.indexer = ParallaxDeepSeekV32Indexer(args) |
372 | 113 |
373 | 114 | def __call__( |
374 | 115 | self, |
@@ -481,7 +222,7 @@ def __call__( |
481 | 222 | return self.o_proj(output) |
482 | 223 |
483 | 224 |
484 | | -class ParallaxDeepSeekV32Block(DeepseekV32DecoderLayer): |
| 225 | +class ParallaxDeepSeekV32Block(MLXDeepseekV32Block): |
485 | 226 | def __init__(self, args: ModelArgs, layer_idx: int): |
486 | 227 | super().__init__(args, layer_idx=layer_idx) |
487 | 228 | self.self_attn = ParallaxDeepSeekV32Attention(args) |
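For context (this sketch is not part of the commit): the surviving indexer line above, `mx.argpartition(scores, kth=-self.index_topk, axis=-1)[..., -self.index_topk:]`, picks the `index_topk` highest-scoring key positions without a full sort. A minimal, self-contained illustration on toy scores:

    import mlx.core as mx

    # argpartition with kth=-topk moves the topk largest scores into the
    # last topk slots (in no guaranteed order); slicing those slots
    # yields the indices of the topk largest entries.
    scores = mx.array([0.1, 0.9, 0.3, 0.7, 0.2])
    topk = 2
    idx = mx.argpartition(scores, kth=-topk, axis=-1)[..., -topk:]
    print(idx)  # indices of the two largest scores, e.g. array([3, 1])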
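The broader move in this commit is to stop duplicating the DeepSeek V3.2 modules and instead inherit mlx_lm's implementations for configuration and parameters, overriding only the compute path (`__call__`, plus swapping in the Parallax indexer) so the Metal paged-attention and indexer kernels take over. A minimal sketch of that subclass-override pattern; `UpstreamAttention` and `BackendAttention` are illustrative stand-ins, not real mlx_lm or parallax classes:

    import mlx.core as mx
    import mlx.nn as nn

    class UpstreamAttention(nn.Module):
        # Stand-in for MLXDeepseekV32Attention: owns the parameters.
        def __init__(self, dims: int):
            super().__init__()
            self.o_proj = nn.Linear(dims, dims, bias=False)

        def __call__(self, x: mx.array) -> mx.array:
            return self.o_proj(x)  # stock compute path

    class BackendAttention(UpstreamAttention):
        # Stand-in for ParallaxDeepSeekV32Attention: reuses the parent's
        # parameters and replaces only the forward pass.
        def __call__(self, x: mx.array) -> mx.array:
            # A backend-specific kernel (e.g. paged attention) would run here.
            return self.o_proj(x)

    x = mx.zeros((1, 4, 8))
    print(BackendAttention(8)(x).shape)  # (1, 4, 8)

Because the parameters live on the parent class, checkpoints trained against the upstream module load unchanged into the backend subclass.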