
Commit ec29ef9

feat(model): support deepseek v3.2 (#316)
1 parent f3b0b07 commit ec29ef9

28 files changed: +1486, -82 lines

scripts/download_model_shard.sh

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
#!/bin/bash

# Example usage of the download_shard.py script

# Default values
MODEL_REPO=${1:-"Qwen/Qwen2.5-7B-Instruct"}
START_LAYER=${2:-0}
END_LAYER=${3:-10}
OUTPUT_DIR=${4} # Optional, defaults to empty/unset

# Get the directory where this script is located
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"

echo "========================================================"
echo "Downloading shard for model: $MODEL_REPO"
echo "Layers: [$START_LAYER, $END_LAYER)"
if [ -z "$OUTPUT_DIR" ]; then
    echo "Output Directory: Default Hugging Face Cache"
    OUTPUT_ARG=""
else
    echo "Output Directory: $OUTPUT_DIR"
    OUTPUT_ARG="--output-dir $OUTPUT_DIR"
fi
echo "========================================================"

# Ensure PYTHONPATH includes src
export PYTHONPATH="${PROJECT_ROOT}/src:${PYTHONPATH}"

python "${SCRIPT_DIR}/download_shard.py" \
    --model-repo "$MODEL_REPO" \
    --start-layer "$START_LAYER" \
    --end-layer "$END_LAYER" \
    $OUTPUT_ARG

if [ $? -eq 0 ]; then
    echo "========================================================"
    echo "Download completed successfully."
    echo "========================================================"
else
    echo "========================================================"
    echo "Download failed."
    echo "========================================================"
    exit 1
fi

scripts/download_shard.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
import argparse
import os
import sys
from pathlib import Path

# Add src to sys.path to allow importing parallax modules
# Assuming script is in scripts/ directory, so src is at ../src
current_dir = Path(__file__).resolve().parent
src_dir = current_dir.parent / "src"
sys.path.append(str(src_dir))

try:
    from parallax.utils.selective_download import selective_model_download
    from parallax_utils.logging_config import get_logger, set_log_level
except ImportError:
    print(
        f"Error: Could not import parallax modules. Please ensure 'src' directory is in PYTHONPATH or script is located in 'scripts/'. Added path: {src_dir}"
    )
    sys.exit(1)

logger = get_logger("download_shard")


def main():
    parser = argparse.ArgumentParser(
        description="Download specific layers of a model from Hugging Face Hub."
    )
    parser.add_argument(
        "--model-repo", type=str, required=True, help="Hugging Face model repository ID"
    )
    parser.add_argument(
        "--start-layer", type=int, required=True, help="Start layer index (inclusive)"
    )
    parser.add_argument("--end-layer", type=int, required=True, help="End layer index (exclusive)")
    parser.add_argument(
        "--output-dir",
        type=str,
        required=False,
        default=None,
        help="Local directory to save the model. If not provided, uses default Hugging Face cache.",
    )
    parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")

    args = parser.parse_args()
    set_log_level(args.log_level)

    # Convert output_dir to absolute path if provided
    if args.output_dir:
        output_dir = os.path.abspath(args.output_dir)
        logger.info(
            f"Downloading model {args.model_repo} layers [{args.start_layer}, {args.end_layer}) to {output_dir}"
        )
    else:
        output_dir = None
        logger.info(
            f"Downloading model {args.model_repo} layers [{args.start_layer}, {args.end_layer}) to default Hugging Face cache"
        )

    try:
        # Note: selective_model_download uses 'cache_dir' argument which is usually passed to hf_hub_download.
        # hf_hub_download uses cache_dir as the base for its cache structure (models--owner--repo/...).
        # If the user wants to download DIRECTLY to output_dir without the HF cache structure,
        # selective_download might need adjustment or we accept the HF cache structure.
        # Based on selective_download.py implementation:
        # It calls snapshot_download(..., cache_dir=cache_dir) or hf_hub_download(..., cache_dir=cache_dir).
        # So it will create the standard HF cache structure inside output_dir.

        model_path = selective_model_download(
            repo_id=args.model_repo,
            start_layer=args.start_layer,
            end_layer=args.end_layer,
            cache_dir=output_dir,
        )
        logger.info(f"Successfully downloaded/verified model shard. Cache location: {model_path}")
    except Exception as e:
        logger.error(f"Failed to download model shard: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
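
For reference, a minimal sketch of driving the same download programmatically, assuming src/ is already on PYTHONPATH; the repository ID, layer range, and cache location below are illustrative, not library defaults:

# Sketch: call selective_model_download directly instead of going through the CLI.
# Assumes src/ is importable; repo ID, layer range, and cache_dir are illustrative.
from parallax.utils.selective_download import selective_model_download

model_path = selective_model_download(
    repo_id="Qwen/Qwen2.5-7B-Instruct",  # same default as download_model_shard.sh
    start_layer=0,                       # inclusive
    end_layer=10,                        # exclusive
    cache_dir=None,                      # None -> standard Hugging Face cache layout
)
print(f"Shard files cached under: {model_path}")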

src/backend/server/static_config.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@
     "deepseek-ai/DeepSeek-V3": "mlx-community/DeepSeek-V3-4bit",
     "deepseek-ai/DeepSeek-V2.5-1210": "mlx-community/DeepSeek-V2.5-1210-4bit",
     "deepseek-ai/DeepSeek-R1": "mlx-community/DeepSeek-R1-4bit",
+    "deepseek-ai/DeepSeek-V3.2": "mlx-community/DeepSeek-V3.2-4bit",
     # Qwen 2.5 Series
     "Qwen/Qwen2.5-0.5B-Instruct": "Qwen/Qwen2.5-0.5B-Instruct",
     "Qwen/Qwen2.5-1.5B-Instruct": "Qwen/Qwen2.5-1.5B-Instruct",

src/parallax/launch.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
     display_parallax_join(args.model_path)
     check_latest_release()

-    config = fetch_model_from_hf(args.model_path)
+    config = fetch_model_from_hf(args.model_path, local_files_only=args.use_hfcache)
     # only launch http server on head node
     if args.start_layer == 0:
         http_server_process = launch_http_server(args)
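
The added local_files_only argument lets a node whose weights are already in the local Hugging Face cache resolve the model config without network access when --use-hfcache is set. A minimal sketch of how such a flag is commonly threaded through to huggingface_hub; fetch_model_from_hf's actual implementation may differ:

# Sketch only: illustrates local_files_only propagation to huggingface_hub;
# this is not parallax's fetch_model_from_hf implementation.
from huggingface_hub import snapshot_download
from transformers import AutoConfig

def fetch_config_sketch(model_path: str, local_files_only: bool = False):
    # With local_files_only=True, snapshot_download resolves entirely from the
    # local cache and raises instead of reaching the network.
    local_dir = snapshot_download(
        repo_id=model_path,
        allow_patterns=["*.json"],  # config/tokenizer JSON only, illustrative
        local_files_only=local_files_only,
    )
    return AutoConfig.from_pretrained(local_dir)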
Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
import os
from typing import Dict, List, Optional

import mlx.core as mx

_KERNELS: Dict[str, object] = {}


def _get_metal_source(filename):
    path = os.path.join(os.path.dirname(__file__), filename)
    with open(path, "r") as f:
        return f.read()


def _type_to_string(dtype: mx.Dtype) -> str:
    if dtype == mx.float32:
        return "float"
    elif dtype == mx.float16:
        return "half"
    elif dtype == mx.bfloat16:
        return "bfloat16_t"
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")


def _get_kernel(
    name: str,
    filename: str,
    input_names: List[str],
    output_names: List[str],
    dtype: mx.Dtype = mx.float32,
):
    type_str = _type_to_string(dtype)
    kernel_key = f"{name}_{type_str}"

    if kernel_key not in _KERNELS:
        source = _get_metal_source(filename)
        source = source.replace("{{T}}", type_str)

        header = """
        #include <metal_stdlib>
        using namespace metal;
        """
        _KERNELS[kernel_key] = mx.fast.metal_kernel(
            name=name,
            input_names=input_names,
            output_names=output_names,
            source=source,
            header=header,
        )
    return _KERNELS[kernel_key]


def store_indexer_cache(
    key: mx.array,
    key_cache: mx.array,
    block_tables: mx.array,
    context_lengths: mx.array,
    block_size: int,
    layer_idx: int,
    slot_mapping: Optional[mx.array] = None,
):
    dtype = key.dtype
    # key: (batch, target_len, num_heads, head_dim) or flattened

    if slot_mapping is None:
        # Decode Mode
        batch_size = key.shape[0]
        if key.ndim == 4:
            # (batch, 1, num_kv_heads, head_dim) -> (batch, num_kv_heads, head_dim)
            if key.shape[1] == 1:
                key = key.squeeze(1)
            elif key.shape[2] == 1:
                # Fallback for old layout (batch, num_kv_heads, 1, head_dim)
                key = key.squeeze(2)

        num_heads = key.shape[1]
        head_dim = key.shape[2]

        # Compute slot_mapping internally
        indices = context_lengths - 1
        block_indices_in_table = indices // block_size
        offsets = indices % block_size
        batch_indices = mx.arange(batch_size)
        physical_block_numbers = block_tables[batch_indices, block_indices_in_table]
        slot_mapping = physical_block_numbers.astype(mx.int32) * block_size + offsets.astype(
            mx.int32
        )

        num_tokens = batch_size
    else:
        # Prefill Mode
        if key.ndim == 4:
            B, T, H, D = key.shape
            key = key.reshape(B * T, H, D)

        num_tokens = key.shape[0]
        num_heads = key.shape[1]
        head_dim = key.shape[2]

    num_layers = key_cache.shape[0]
    num_blocks = key_cache.shape[1]

    key_stride = num_heads * head_dim

    def mk_int(val):
        return mx.array(val, dtype=mx.int32)

    inputs = [
        key,
        key_cache,
        slot_mapping,
        mk_int(key_stride),
        mk_int(num_heads),
        mk_int(head_dim),
        mk_int(block_size),
        mk_int(layer_idx),
        mk_int(num_layers),
        mk_int(num_blocks),
    ]

    input_names = [
        "key",
        "key_cache",
        "slot_mapping",
        "key_stride",
        "num_heads",
        "head_dim",
        "block_size",
        "layer_idx",
        "num_layers",
        "num_blocks",
    ]

    kernel = _get_kernel(
        name="store_key_kernel",
        filename="store_key.metal",
        input_names=input_names,
        output_names=["dummy_out"],
        dtype=dtype,
    )

    grid = (num_heads * head_dim, num_tokens, 1)
    thread_group = (min(1024, num_heads * head_dim), 1, 1)

    outputs = kernel(
        inputs=inputs,
        grid=grid,
        threadgroup=thread_group,
        output_shapes=[(num_tokens, num_heads * head_dim)],  # Dummy output
        output_dtypes=[mx.float32],
        verbose=False,
    )
    mx.eval(outputs)


def q_dot_k(
    q: mx.array,  # (num_heads, head_dim)
    key_cache: mx.array,  # (L, B, H, BS, D)
    block_table: mx.array,  # (max_blocks)
    context_length: mx.array,  # scalar
    block_size: int,
    layer_idx: int,
) -> mx.array:

    if q.ndim > 2:
        q = q.squeeze()  # Ensure (H, D)

    num_heads = q.shape[0]
    head_dim = q.shape[1]

    num_layers = key_cache.shape[0]
    num_total_blocks = key_cache.shape[1]
    max_blocks = block_table.shape[0]

    ctx_len = int(context_length.item())

    def mk_int(val):
        return mx.array(val, dtype=mx.int32)

    inputs = [
        q,
        key_cache,
        block_table,
        mk_int(ctx_len),
        mk_int(block_size),
        mk_int(num_heads),
        mk_int(head_dim),
        mk_int(layer_idx),
        mk_int(num_layers),
        mk_int(num_total_blocks),
        mk_int(max_blocks),
    ]

    input_names = [
        "q",
        "key_cache",
        "block_table",
        "context_len",
        "block_size",
        "num_heads",
        "head_dim",
        "layer_idx",
        "num_layers",
        "num_total_blocks",
        "max_blocks",
    ]

    kernel = _get_kernel(
        name="q_dot_k_kernel",
        filename="q_dot_k.metal",
        input_names=input_names,
        output_names=["output"],
        dtype=q.dtype,
    )

    # Grid: (block_size, num_heads, 1)
    grid = (block_size, num_heads, 1)
    thread_group = (min(1024, block_size), 1, 1)

    outputs = kernel(
        inputs=inputs,
        grid=grid,
        threadgroup=thread_group,
        output_shapes=[(num_heads, ctx_len)],
        output_dtypes=[mx.float32],  # Score is float32
        verbose=False,
    )

    return outputs[0]
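
To make the decode-path slot computation in store_indexer_cache concrete, a small worked example with illustrative numbers; the block table contents are made up:

# Worked example of the slot computation used in decode mode above.
# All numbers are illustrative.
block_size = 16
context_length = 35                           # tokens already written for this sequence
index = context_length - 1                    # last token index -> 34
block_index_in_table = index // block_size    # 2 (third logical block)
offset = index % block_size                   # 2 (position inside that block)
block_table = [3, 9, 7, 0]                    # logical -> physical block mapping (made up)
physical_block = block_table[block_index_in_table]   # 7
slot = physical_block * block_size + offset           # 7 * 16 + 2 = 114
print(slot)  # 114: flat row in the paged key cache where this token's key lands

q_dot_k then reads the same paged layout back for a single sequence, returning a (num_heads, ctx_len) float32 score matrix.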
