Commit 497b5fa

feat(backend): update mlx-lm version to 0.28.4 (#322)
1 parent ec29ef9 commit 497b5fa

9 files changed: +31, -423 lines changed

pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -43,19 +43,19 @@ parallax = "parallax.cli:main"
 
 mac = [
     "torch==2.8.0",
-    "mlx-lm==0.28.0",
+    "mlx-lm==0.28.4",
     "mlx==0.30.0",
 ]
 
 gpu = [
     "sglang[all]==0.5.5",
-    "mlx-lm==0.28.0",
+    "mlx-lm==0.28.4",
     "mlx[cpu]==0.30.0",
 ]
 
 vllm = [
     "vllm==0.11.0",
-    "mlx-lm==0.28.0",
+    "mlx-lm==0.28.4",
     "mlx[cpu]==0.30.0",
 ]
 
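
The pyproject.toml change only moves the mlx-lm pin from 0.28.0 to 0.28.4 in each of the three extras (mac, gpu, vllm). As a minimal sketch (not part of the commit), the resolved version can be checked after reinstalling one of the extras, e.g. pip install -e ".[mac]", using only the standard library:

    # Hypothetical sanity check that the environment picked up the new pin.
    from importlib.metadata import version

    assert version("mlx-lm") == "0.28.4"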

src/parallax/models/deepseek_v32.py

Lines changed: 12 additions & 271 deletions
@@ -1,81 +1,18 @@
 # Copyright © 2025 Apple Inc.
-import math
-from dataclasses import dataclass
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional, Tuple
 
 import mlx.core as mx
-import mlx.nn as nn
-from mlx_lm.models.base import BaseModelArgs, scaled_dot_product_attention
-from mlx_lm.models.rope_utils import initialize_rope
-from mlx_lm.models.switch_layers import SwitchGLU
+from mlx_lm.models.base import scaled_dot_product_attention
+from mlx_lm.models.deepseek_v32 import DeepseekV32Attention as MLXDeepseekV32Attention
+from mlx_lm.models.deepseek_v32 import DeepseekV32DecoderLayer as MLXDeepseekV32Block
+from mlx_lm.models.deepseek_v32 import Indexer as MLXDeepseekV32Indexer
+from mlx_lm.models.deepseek_v32 import ModelArgs
 
 from parallax.metal.indexer.kernel import q_dot_k, store_indexer_cache
 from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache
 
 
-@dataclass
-class ModelArgs(BaseModelArgs):
-    model_type: str = "deepseek_v32"
-    vocab_size: int = 102400
-    hidden_size: int = 4096
-    index_n_heads: int = 64
-    index_head_dim: int = 128
-    index_topk: int = 2048
-    intermediate_size: int = 11008
-    moe_intermediate_size: int = 1407
-    num_hidden_layers: int = 30
-    num_attention_heads: int = 32
-    num_key_value_heads: int = 32
-    n_shared_experts: Optional[int] = None
-    n_routed_experts: Optional[int] = None
-    routed_scaling_factor: float = 1.0
-    kv_lora_rank: int = 512
-    q_lora_rank: int = 1536
-    qk_rope_head_dim: int = 64
-    v_head_dim: int = 128
-    qk_nope_head_dim: int = 128
-    topk_method: str = "noaux_tc"
-    scoring_func: str = "sigmoid"
-    norm_topk_prob: bool = True
-    n_group: int = 1
-    topk_group: int = 1
-    num_experts_per_tok: int = 1
-    moe_layer_freq: int = 1
-    first_k_dense_replace: int = 0
-    max_position_embeddings: int = 2048
-    rms_norm_eps: float = 1e-6
-    rope_theta: float = 10000.0
-    rope_scaling: Dict = None
-    attention_bias: bool = False
-
-
-class Indexer(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.dim = args.hidden_size
-        self.n_heads = args.index_n_heads
-        self.head_dim = args.index_head_dim
-        self.rope_head_dim = args.qk_rope_head_dim
-        self.index_topk = args.index_topk
-        self.q_lora_rank = args.q_lora_rank
-        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(self.dim, self.head_dim, bias=False)
-        self.k_norm = nn.LayerNorm(self.head_dim)
-        self.weights_proj = nn.Linear(self.dim, self.n_heads, bias=False)
-        self.softmax_scale = self.head_dim**-0.5
-        self.rope = nn.RoPE(
-            dims=self.rope_head_dim,
-            base=args.rope_theta,
-            traditional=False,  # Non-interleaved
-        )
-        self.rope = initialize_rope(
-            dims=args.qk_rope_head_dim,
-            base=args.rope_theta,
-            traditional=False,
-            max_position_embeddings=args.max_position_embeddings,
-            scaling_config=args.rope_scaling,
-        )
-
+class ParallaxDeepSeekV32Indexer(MLXDeepseekV32Indexer):
     def __call__(
         self,
         x: mx.array,
@@ -168,207 +105,11 @@ def __call__(
         return mx.argpartition(scores, kth=-self.index_topk, axis=-1)[..., -self.index_topk :]
 
 
-class DeepseekV32Attention(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.max_position_embeddings = config.max_position_embeddings
-        self.rope_theta = config.rope_theta
-        self.q_lora_rank = config.q_lora_rank
-        self.qk_rope_head_dim = config.qk_rope_head_dim
-        self.kv_lora_rank = config.kv_lora_rank
-        self.v_head_dim = config.v_head_dim
-        self.qk_nope_head_dim = config.qk_nope_head_dim
-        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
-
-        self.scale = self.q_head_dim**-0.5
-
-        if self.q_lora_rank is None:
-            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
-        else:
-            self.q_a_proj = nn.Linear(
-                self.hidden_size, self.q_lora_rank, bias=config.attention_bias
-            )
-            self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank, eps=1e-6)
-            self.q_b_proj = nn.Linear(
-                self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
-            )
-
-        self.kv_a_proj_with_mqa = nn.Linear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=config.attention_bias,
-        )
-        self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank, eps=1e-6)
-        self.kv_b_proj = nn.Linear(
-            self.kv_lora_rank,
-            self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
-            bias=False,
-        )
-
-        self.o_proj = nn.Linear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=config.attention_bias,
-        )
-
-        if self.config.rope_scaling is not None:
-            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
-            if mscale_all_dim:
-                scaling_factor = self.config.rope_scaling["factor"]
-                if scaling_factor > 1:
-                    s = 0.1 * mscale_all_dim * math.log(scaling_factor) + 1.0
-                    self.scale = self.scale * s * s
-
-        self.indexer = Indexer(config)
-        self.rope = initialize_rope(
-            dims=self.qk_rope_head_dim,
-            base=self.rope_theta,
-            traditional=True,
-            max_position_embeddings=self.max_position_embeddings,
-            scaling_config=self.config.rope_scaling,
-        )
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[Any] = None,
-    ) -> mx.array:
-        pass
-
-
-class DeepseekV32MLP(nn.Module):
-    def __init__(self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
-        self.intermediate_size = (
-            config.intermediate_size if intermediate_size is None else intermediate_size
-        )
-
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-
-    def __call__(self, x):
-        down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
-
-
-@mx.compile
-def group_expert_select(
-    gates,
-    e_score_correction_bias,
-    top_k,
-    n_group,
-    topk_group,
-    routed_scaling_factor,
-    norm_topk_prob,
-):
-
-    scores = mx.sigmoid(gates.astype(mx.float32))
-    orig_scores = scores
-    scores = scores + e_score_correction_bias
-    if n_group > 1:
-        scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
-        group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
-        k = n_group - topk_group
-        group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
-        scores = mx.put_along_axis(scores, mx.stop_gradient(group_idx), mx.array(0.0), axis=-2)
-        scores = mx.flatten(scores, -2, -1)
-
-    k = top_k
-    inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
-    scores = mx.take_along_axis(orig_scores, inds, axis=-1)
-    if top_k > 1 and norm_topk_prob:
-        denominator = scores.sum(axis=-1, keepdims=True)
-        scores = scores / denominator
-    scores = scores * routed_scaling_factor
-
-    return inds, scores
-
-
-class MoEGate(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.config = config
-        self.top_k = config.num_experts_per_tok
-        self.norm_topk_prob = config.norm_topk_prob
-        self.n_routed_experts = config.n_routed_experts
-        self.routed_scaling_factor = config.routed_scaling_factor
-        self.n_group = config.n_group
-        self.topk_group = config.topk_group
-        self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
-        self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
-        assert config.topk_method == "noaux_tc", "Unsupported topk method."
-
-    def __call__(self, x):
-        return group_expert_select(
-            x @ self.weight.T,
-            self.e_score_correction_bias,
-            self.top_k,
-            self.n_group,
-            self.topk_group,
-            self.routed_scaling_factor,
-            self.norm_topk_prob,
-        )
-
-
-class DeepseekV32MoE(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.config = config
-        self.num_experts_per_tok = config.num_experts_per_tok
-        self.switch_mlp = SwitchGLU(
-            config.hidden_size,
-            config.moe_intermediate_size,
-            config.n_routed_experts,
-        )
-
-        self.gate = MoEGate(config)
-        if config.n_shared_experts is not None:
-            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV32MLP(config=config, intermediate_size=intermediate_size)
+class ParallaxDeepSeekV32Attention(MLXDeepseekV32Attention):
 
-    def __call__(self, x):
-        inds, scores = self.gate(x)
-        y = self.switch_mlp(x, inds)
-        y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype)
-        if self.config.n_shared_experts is not None:
-            y = y + self.shared_experts(x)
-
-        return y
-
-
-class DeepseekV32DecoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs, layer_idx: int):
-        super().__init__()
-        self.self_attn = DeepseekV32Attention(config)
-        self.mlp = (
-            DeepseekV32MoE(config)
-            if (
-                config.n_routed_experts is not None
-                and layer_idx >= config.first_k_dense_replace
-                and layer_idx % config.moe_layer_freq == 0
-            )
-            else DeepseekV32MLP(config)
-        )
-        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[Any] = None,
-    ) -> mx.array:
-        pass
-
-
-class ParallaxDeepSeekV32Attention(DeepseekV32Attention):
+    def __init__(self, args: ModelArgs):
+        super().__init__(args)
+        self.indexer = ParallaxDeepSeekV32Indexer(args)
 
     def __call__(
         self,
@@ -481,7 +222,7 @@ def __call__(
         return self.o_proj(output)
 
 
-class ParallaxDeepSeekV32Block(DeepseekV32DecoderLayer):
+class ParallaxDeepSeekV32Block(MLXDeepseekV32Block):
     def __init__(self, args: ModelArgs, layer_idx: int):
         super().__init__(args, layer_idx=layer_idx)
         self.self_attn = ParallaxDeepSeekV32Attention(args)
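
Net effect of the deepseek_v32.py diff: the locally duplicated ModelArgs, Indexer, attention, MLP, gating, MoE, and decoder-layer code is deleted, and the Parallax classes now subclass the DeepSeek V3.2 implementation bundled with the pinned mlx-lm, keeping only the overrides that wire in the Parallax Metal kernels. A condensed sketch of the resulting class layout (method bodies elided; the retained __call__ overrides are unchanged by this commit and are not shown):

    from mlx_lm.models.deepseek_v32 import DeepseekV32Attention as MLXDeepseekV32Attention
    from mlx_lm.models.deepseek_v32 import DeepseekV32DecoderLayer as MLXDeepseekV32Block
    from mlx_lm.models.deepseek_v32 import Indexer as MLXDeepseekV32Indexer
    from mlx_lm.models.deepseek_v32 import ModelArgs


    class ParallaxDeepSeekV32Indexer(MLXDeepseekV32Indexer):
        # Reuses the upstream indexer modules; the overridden __call__ kept in the
        # file is where the Metal indexer kernels imported at the top of the module
        # (q_dot_k, store_indexer_cache) are presumably used.
        ...


    class ParallaxDeepSeekV32Attention(MLXDeepseekV32Attention):
        def __init__(self, args: ModelArgs):
            super().__init__(args)
            # Swap in the Parallax indexer; projections, RoPE, and scaling now come
            # from the upstream mlx-lm attention instead of a local copy.
            self.indexer = ParallaxDeepSeekV32Indexer(args)

        # The overridden __call__ kept in the file ends in self.o_proj(output) and
        # presumably uses the paged_attention / reshape_and_cache kernels imported
        # at the top of the module.


    class ParallaxDeepSeekV32Block(MLXDeepseekV32Block):
        def __init__(self, args: ModelArgs, layer_idx: int):
            super().__init__(args, layer_idx=layer_idx)
            # Replace the upstream attention with the Parallax variant.
            self.self_attn = ParallaxDeepSeekV32Attention(args)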
