Commit
async model downloading fixes #30
AlexCheema committed Jul 20, 2024
1 parent e49924e commit a4cc667
Showing 4 changed files with 37 additions and 51 deletions.
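Both backends apply the same pattern: a blocking download call (snapshot_download for MLX, tinygrad's fetch) is bound with functools.partial and handed to the event loop's default thread-pool executor, so a model download no longer blocks other coroutines. A minimal, self-contained sketch of that pattern follows; the function and names are illustrative stand-ins, not code from this repository.

```python
import asyncio
import time
from functools import partial
from typing import Optional

def blocking_download(repo_id: str, revision: Optional[str] = None) -> str:
    """Stand-in for a blocking call such as huggingface_hub.snapshot_download."""
    time.sleep(1)  # simulates network I/O that would otherwise stall the event loop
    return f"/models/{repo_id}@{revision or 'main'}"

async def download_async(*args, **kwargs) -> str:
    # Bind the arguments up front, then run the blocking call in the default executor.
    func = partial(blocking_download, *args, **kwargs)
    return await asyncio.get_event_loop().run_in_executor(None, func)

async def main():
    path = await download_async("example/repo", revision="main")
    print(path)

asyncio.run(main())
```

The diff calls asyncio.get_event_loop() from inside a coroutine, where it resolves to the running loop; asyncio.get_running_loop() is the equivalent modern spelling.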
28 changes: 1 addition & 27 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -6,32 +6,6 @@
from ..shard import Shard
from typing import Optional

class MLXFixedShardInferenceEngine(InferenceEngine):
def __init__(self, model_path: str, shard: Shard):
self.shard = shard
model_shard, self.tokenizer = load_shard(model_path, shard)
self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)

async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
if shard != self.shard:
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")

output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id

async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, str, bool):
if shard != self.shard:
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")

output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id

async def reset_shard(self, shard: Shard):
if shard != self.shard:
raise ValueError(f"Shard mismatch: {shard} != {self.shard}")

self.stateful_sharded_model.reset()

class MLXDynamicShardInferenceEngine(InferenceEngine):
def __init__(self):
self.shard = None
@@ -54,6 +28,6 @@ async def ensure_shard(self, shard: Shard):
if self.shard == shard:
return

model_shard, self.tokenizer = load_shard(shard.model_id, shard)
model_shard, self.tokenizer = await load_shard(shard.model_id, shard)
self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
self.shard = shard
14 changes: 10 additions & 4 deletions exo/inference/mlx/sharded_utils.py
@@ -4,6 +4,8 @@
import importlib
import json
import logging
import asyncio
from functools import partial
from pathlib import Path
from typing import Optional, Tuple

@@ -151,7 +153,11 @@ def class_predicate(p, m):
model.eval()
return model

def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
async def snapshot_download_async(*args, **kwargs):
func = partial(snapshot_download, *args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)

async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
"""
Ensures the model is available locally. If the path does not exist locally,
it is downloaded from the Hugging Face Hub.
@@ -167,7 +173,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
if not model_path.exists():
try:
model_path = Path(
snapshot_download(
await snapshot_download_async(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
@@ -191,7 +197,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
return model_path


def load_shard(
async def load_shard(
path_or_hf_repo: str,
shard: Shard,
tokenizer_config={},
@@ -220,7 +226,7 @@ def load_shard(
FileNotFoundError: If config file or safetensors are not found.
ValueError: If model class or args class are not found.
"""
model_path = get_model_path(path_or_hf_repo)
model_path = await get_model_path(path_or_hf_repo)

model = load_model_shard(model_path, shard, lazy, model_config)
if adapter_path is not None:
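Because get_model_path and load_shard are now coroutines, every caller has to await them, which is what the ensure_shard change in sharded_inference_engine.py above does. A sketch of a caller after this change; the Shard constructor fields shown here are assumptions for illustration, not taken from the diff.

```python
import asyncio
from exo.inference.shard import Shard
from exo.inference.mlx.sharded_utils import load_shard

async def warm_up(shard: Shard):
    # load_shard now downloads the snapshot (if needed) and loads it without blocking the event loop.
    model_shard, tokenizer = await load_shard(shard.model_id, shard)
    return model_shard, tokenizer

# Field names below are assumed; check exo.inference.shard.Shard for the actual definition.
shard = Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit",
              start_layer=0, end_layer=31, n_layers=32)
asyncio.run(warm_up(shard))
```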
20 changes: 10 additions & 10 deletions exo/inference/test_inference_engine.py
@@ -24,15 +24,15 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
assert np.array_equal(resp_full, resp2)
assert np.array_equal(next_resp_full, resp4)

asyncio.run(test_inference_engine(
MLXDynamicShardInferenceEngine(),
MLXDynamicShardInferenceEngine(),
"mlx-community/Meta-Llama-3-8B-Instruct-4bit",
))

# TODO: Waiting on https://github.com/tinygrad/tinygrad/issues/5549
# asyncio.run(test_inference_engine(
# TinygradDynamicShardInferenceEngine(),
# TinygradDynamicShardInferenceEngine(),
# "llama3-8b-sfr",
# MLXDynamicShardInferenceEngine(),
# MLXDynamicShardInferenceEngine(),
# "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
# ))

# TODO: Waiting on https://github.com/tinygrad/tinygrad/issues/5549
asyncio.run(test_inference_engine(
TinygradDynamicShardInferenceEngine(),
TinygradDynamicShardInferenceEngine(),
"llama3-8b-sfr",
))
26 changes: 16 additions & 10 deletions exo/inference/tinygrad/inference.py
@@ -1,16 +1,18 @@

import asyncio
from functools import partial
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Union
import json, argparse, random, time
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
from tinygrad.helpers import DEBUG, tqdm, _cache_dir
from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
from exo.inference.shard import Shard
from exo.inference.inference_engine import InferenceEngine
import numpy as np
import os

MODEL_PARAMS = {
"8B": {
@@ -58,6 +60,11 @@ def encode(self, text, allow_special=False):
return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())

# **** helper functions ****
async def fetch_async(url: str, name: Optional[Union[Path, str]] = None, subdir: Optional[str] = None,
allow_caching=not os.getenv("DISABLE_HTTP_CACHE")) -> Path:
func = partial(fetch, url, name, subdir, allow_caching)
return await asyncio.get_event_loop().run_in_executor(None, func)

def concat_weights(models, device=None):
def convert(name) -> Tensor:
disk_tensors: List[Tensor] = [model[name] for model in models]
@@ -176,16 +183,15 @@ async def ensure_shard(self, shard: Shard):
if Path(model_path / "model.safetensors.index.json").exists():
model = model_path
else:
from tinygrad.helpers import fetch

if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
if shard.model_id.lower().find("llama3-8b-sfr") != -1:
fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
await fetch_async("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
model = await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
size = "8B"
elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
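On the tinygrad side, fetch_async wraps tinygrad's blocking fetch in the same executor pattern, and the commit awaits each weight file in sequence. As an illustration only, not something this commit does, the same wrapper would also let the four safetensors shards download concurrently, reusing the fetch signature shown in the diff:

```python
import asyncio
import os
from functools import partial
from pathlib import Path
from typing import Optional, Union

from tinygrad.helpers import fetch

async def fetch_async(url: str, name: Optional[Union[Path, str]] = None, subdir: Optional[str] = None,
                      allow_caching=not os.getenv("DISABLE_HTTP_CACHE")) -> Path:
    # Same wrapper as in the diff: offload tinygrad's blocking fetch to the default executor.
    func = partial(fetch, url, name, subdir, allow_caching)
    return await asyncio.get_event_loop().run_in_executor(None, func)

async def download_sfr_weights(model_id: str) -> list:
    base = "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main"
    names = [f"model-0000{i}-of-00004.safetensors" for i in range(1, 5)]
    # Each fetch runs in a worker thread, so the downloads can overlap; this assumes
    # concurrent fetch calls to distinct files are safe, which the commit does not rely on.
    return await asyncio.gather(*(fetch_async(f"{base}/{n}", n, subdir=model_id) for n in names))
```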
