diff --git a/exo/inference/mlx/sharded_inference_engine.py b/exo/inference/mlx/sharded_inference_engine.py
index 88815a1b0..2ab83eeef 100644
--- a/exo/inference/mlx/sharded_inference_engine.py
+++ b/exo/inference/mlx/sharded_inference_engine.py
@@ -6,32 +6,6 @@
 from ..shard import Shard
 from typing import Optional
 
-class MLXFixedShardInferenceEngine(InferenceEngine):
-    def __init__(self, model_path: str, shard: Shard):
-        self.shard = shard
-        model_shard, self.tokenizer = load_shard(model_path, shard)
-        self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
-
-    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-        if shard != self.shard:
-            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
-
-        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
-        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
-
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> (np.ndarray, str, bool):
-        if shard != self.shard:
-            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
-
-        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
-        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
-
-    async def reset_shard(self, shard: Shard):
-        if shard != self.shard:
-            raise ValueError(f"Shard mismatch: {shard} != {self.shard}")
-
-        self.stateful_sharded_model.reset()
-
 class MLXDynamicShardInferenceEngine(InferenceEngine):
     def __init__(self):
         self.shard = None
@@ -54,6 +28,6 @@ async def ensure_shard(self, shard: Shard):
         if self.shard == shard:
             return
 
-        model_shard, self.tokenizer = load_shard(shard.model_id, shard)
+        model_shard, self.tokenizer = await load_shard(shard.model_id, shard)
         self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
         self.shard = shard
diff --git a/exo/inference/mlx/sharded_utils.py b/exo/inference/mlx/sharded_utils.py
index 69890d4b9..11a16e182 100644
--- a/exo/inference/mlx/sharded_utils.py
+++ b/exo/inference/mlx/sharded_utils.py
@@ -4,6 +4,8 @@
 import importlib
 import json
 import logging
+import asyncio
+from functools import partial
 from pathlib import Path
 from typing import Optional, Tuple
 
@@ -151,7 +153,11 @@ def class_predicate(p, m):
     model.eval()
     return model
 
-def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
+async def snapshot_download_async(*args, **kwargs):
+    func = partial(snapshot_download, *args, **kwargs)
+    return await asyncio.get_event_loop().run_in_executor(None, func)
+
+async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
     """
     Ensures the model is available locally. If the path does not exist locally,
     it is downloaded from the Hugging Face Hub.
@@ -167,7 +173,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
     if not model_path.exists():
         try:
             model_path = Path(
-                snapshot_download(
+                await snapshot_download_async(
                     repo_id=path_or_hf_repo,
                     revision=revision,
                     allow_patterns=[
@@ -191,7 +197,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
 
     return model_path
 
-def load_shard(
+async def load_shard(
     path_or_hf_repo: str,
     shard: Shard,
     tokenizer_config={},
@@ -220,7 +226,7 @@ def load_shard(
         FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
     """
-    model_path = get_model_path(path_or_hf_repo)
+    model_path = await get_model_path(path_or_hf_repo)
     model = load_model_shard(model_path, shard, lazy, model_config)
     if adapter_path is not None:
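The sharded_utils.py change above uses a standard asyncio pattern: bind the blocking call's arguments with `functools.partial` and hand the result to the event loop's default executor, so the Hugging Face snapshot download runs on a worker thread while the loop stays free to serve other coroutines. A minimal, self-contained sketch of that pattern (the `slow_download` function below is a stand-in for `snapshot_download`, not code from this diff):

```python
import asyncio
import time
from functools import partial

def slow_download(repo_id: str, revision: str = "main") -> str:
    # Stand-in for a blocking call such as huggingface_hub.snapshot_download.
    time.sleep(1)
    return f"/fake/cache/{repo_id}@{revision}"

async def slow_download_async(*args, **kwargs) -> str:
    # run_in_executor only forwards positional arguments, so bind everything with partial first.
    func = partial(slow_download, *args, **kwargs)
    # None selects the default ThreadPoolExecutor; the event loop keeps running meanwhile.
    return await asyncio.get_event_loop().run_in_executor(None, func)

async def main():
    path = await slow_download_async("some-org/some-model", revision="main")
    print(path)

asyncio.run(main())
```

On Python 3.9+, `asyncio.to_thread(slow_download, ...)` is an equivalent one-liner, and inside a coroutine `asyncio.get_running_loop()` is generally preferred over `asyncio.get_event_loop()`; the diff keeps `get_event_loop()` in both of its wrappers.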
""" - model_path = get_model_path(path_or_hf_repo) + model_path = await get_model_path(path_or_hf_repo) model = load_model_shard(model_path, shard, lazy, model_config) if adapter_path is not None: diff --git a/exo/inference/test_inference_engine.py b/exo/inference/test_inference_engine.py index 2cdba8b6c..b4f898a93 100644 --- a/exo/inference/test_inference_engine.py +++ b/exo/inference/test_inference_engine.py @@ -24,15 +24,15 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e assert np.array_equal(resp_full, resp2) assert np.array_equal(next_resp_full, resp4) -asyncio.run(test_inference_engine( - MLXDynamicShardInferenceEngine(), - MLXDynamicShardInferenceEngine(), - "mlx-community/Meta-Llama-3-8B-Instruct-4bit", -)) - -# TODO: Waiting on https://github.com/tinygrad/tinygrad/issues/5549 # asyncio.run(test_inference_engine( -# TinygradDynamicShardInferenceEngine(), -# TinygradDynamicShardInferenceEngine(), -# "llama3-8b-sfr", +# MLXDynamicShardInferenceEngine(), +# MLXDynamicShardInferenceEngine(), +# "mlx-community/Meta-Llama-3-8B-Instruct-4bit", # )) + +# TODO: Waiting on https://github.com/tinygrad/tinygrad/issues/5549 +asyncio.run(test_inference_engine( + TinygradDynamicShardInferenceEngine(), + TinygradDynamicShardInferenceEngine(), + "llama3-8b-sfr", +)) diff --git a/exo/inference/tinygrad/inference.py b/exo/inference/tinygrad/inference.py index 5d1b14859..4cf1d1b11 100644 --- a/exo/inference/tinygrad/inference.py +++ b/exo/inference/tinygrad/inference.py @@ -1,16 +1,18 @@ - +import asyncio +from functools import partial from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Union import json, argparse, random, time import tiktoken from tiktoken.load import load_tiktoken_bpe from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16 from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters -from tinygrad.helpers import DEBUG, tqdm, _cache_dir +from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine import numpy as np +import os MODEL_PARAMS = { "8B": { @@ -58,6 +60,11 @@ def encode(self, text, allow_special=False): return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set()) # **** helper functions **** +async def fetch_async(url: str, name: Optional[Union[Path, str]] = None, subdir: Optional[str] = None, + allow_caching=not os.getenv("DISABLE_HTTP_CACHE")) -> Path: + func = partial(fetch, url, name, subdir, allow_caching) + return await asyncio.get_event_loop().run_in_executor(None, func) + def concat_weights(models, device=None): def convert(name) -> Tensor: disk_tensors: List[Tensor] = [model[name] for model in models] @@ -176,16 +183,15 @@ async def ensure_shard(self, shard: Shard): if Path(model_path / "model.safetensors.index.json").exists(): model = model_path else: - from tinygrad.helpers import fetch if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...") if shard.model_id.lower().find("llama3-8b-sfr") != -1: - fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id) - fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", 
"model-00001-of-00004.safetensors", subdir=shard.model_id) - fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id) - fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id) - fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id) - model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id) + await fetch_async("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id) + await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id) + await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id) + await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id) + await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id) + model = await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id) size = "8B" elif shard.model_id.lower().find("llama3-70b-sfr") != -1: raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")