From 803a442141e26ee1b9a90d08b31f51d2f1ae931d Mon Sep 17 00:00:00 2001 From: Varshith Date: Fri, 26 Jul 2024 06:00:29 +0530 Subject: [PATCH 01/16] init --- .gitignore | 1 + exo/inference/mlx/models/sharded_llava.py | 595 ++++++++++++++++++++++ exo/inference/mlx/sharded_utils.py | 60 ++- exo/inference/mlx/test_sharded_llava.py | 0 4 files changed, 655 insertions(+), 1 deletion(-) create mode 100644 exo/inference/mlx/models/sharded_llava.py create mode 100644 exo/inference/mlx/test_sharded_llava.py diff --git a/.gitignore b/.gitignore index c5d644a87..b97ff07aa 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ .venv test_weights.npz .exo_used_ports +.idea # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/exo/inference/mlx/models/sharded_llava.py b/exo/inference/mlx/models/sharded_llava.py new file mode 100644 index 000000000..a90dd79e2 --- /dev/null +++ b/exo/inference/mlx/models/sharded_llava.py @@ -0,0 +1,595 @@ +# Copyright © 2024 Apple Inc. + +import math +import glob +import inspect +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Dict, Union, Tuple + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from huggingface_hub import snapshot_download + + +@dataclass +class VisionConfig: + model_type: str + num_hidden_layers: int = 24 + hidden_size: int = 1024 + intermediate_size: int = 4096 + num_attention_heads: int = 16 + image_size: int = 336 + patch_size: int = 14 + projection_dim: int = 768 + vocab_size: int = 32000 + num_channels: int = 3 + layer_norm_eps: float = 1e-5 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class VisionAttention(nn.Module): + def __init__( + self, + dims: int, + num_heads: int, + query_input_dims: Optional[int] = None, + key_input_dims: Optional[int] = None, + value_input_dims: Optional[int] = None, + value_dims: Optional[int] = None, + value_output_dims: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + + if (dims % num_heads) != 0: + raise ValueError( + "The input feature dimensions should be divisible by the " + f"number of heads ({dims} % {num_heads}) != 0" + ) + + query_input_dims = query_input_dims or dims + key_input_dims = key_input_dims or dims + value_input_dims = value_input_dims or key_input_dims + value_dims = value_dims or dims + value_output_dims = value_output_dims or dims + + self.num_heads = num_heads + self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) + self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) + self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) + self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) + + def __call__(self, queries, keys, values, mask=None): + queries = self.q_proj(queries) + keys = self.k_proj(keys) + values = self.v_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys + if mask is not None: + scores = scores + mask.astype(scores.dtype) + scores = mx.softmax(scores, axis=-1) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(values_hat) + + +class VisionMLP(nn.Module): + 
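# The MLP block used in every CLIP vision encoder layer: hidden_size -> intermediate_size with fast-approximation GELU, then back to hidden_size.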
def __init__(self, config: VisionConfig): + super().__init__() + self.activation_fn = nn.GELU(approx="fast") + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def __call__(self, x: mx.array) -> mx.array: + x = self.activation_fn(self.fc1(x)) + x = self.fc2(x) + return x + + +class VisionEncoderLayer(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = VisionAttention( + config.hidden_size, config.num_attention_heads, bias=True + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = VisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + y = self.layer_norm1(x) + y = self.self_attn(y, y, y, mask) + x = x + y + y = self.layer_norm2(x) + y = self.mlp(y) + return x + y + + +class VisionEncoder(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)] + + +class VisionEmbeddings(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = mx.zeros((config.hidden_size,)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def __call__(self, x: mx.array) -> mx.array: + batch_size = x.shape[0] + patch_embeddings = self.patch_embedding(x) + patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) + embed_dim = patch_embeddings.shape[-1] + cls_embeddings = mx.broadcast_to( + self.class_embedding, (batch_size, 1, embed_dim) + ) + embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) + embeddings += self.position_embedding.weight + return embeddings + + +class ClipVisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embeddings = VisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(config.hidden_size) + self.encoder = VisionEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size) + + def __call__( + self, + x: mx.array, + output_hidden_states: Optional[bool] = None, + ) -> mx.array: + x = self.embeddings(x) + x = self.pre_layrnorm(x) + + encoder_states = (x,) if output_hidden_states else None + + for l in self.encoder.layers: + x = l(x, mask=None) + if output_hidden_states: + encoder_states = encoder_states + (x,) + + pooler_output = self.post_layernorm(x[:, 0, :]) + return pooler_output, x, encoder_states + + +class VisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + + self.model_type = config.model_type + if self.model_type != "clip_vision_model": + raise ValueError(f"Unsupported model type: {self.model_type}") + + self.vision_model = ClipVisionModel(config) + + def __call__( + self, x: mx.array, output_hidden_states: Optional[bool] = None + ) -> mx.array: + return self.vision_model(x, output_hidden_states) + + @staticmethod + def 
sanitize(weights): + sanitized_weights = {} + for k, v in weights.items(): + if "position_ids" in k: + # Remove unused position_ids + continue + elif "patch_embedding.weight" in k: + # PyTorch conv2d weight tensors have shape: + # [out_channels, in_channels, kH, KW] + # MLX conv2d expects the weight be of shape: + # [out_channels, kH, KW, in_channels] + sanitized_weights[k] = v.transpose(0, 2, 3, 1) + else: + sanitized_weights[k] = v + + return sanitized_weights + +@dataclass +class TextConfig: + model_type: str + hidden_size: int = 4096 + num_hidden_layers: int = 32 + intermediate_size: int = 11008 + num_attention_heads: int = 32 + rms_norm_eps: float = 1e-6 + vocab_size: int = 32000 + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["type"] != "linear": + raise ValueError("rope_scaling 'type' currently only supports 'linear'") + + +class TextAttention(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + + dim = config.hidden_size + self.n_heads = n_heads = config.num_attention_heads + self.n_kv_heads = n_kv_heads = config.num_key_value_heads + + self.repeats = n_heads // n_kv_heads + + head_dim = config.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + rope_scale = ( + 1 / config.rope_scaling["factor"] + if config.rope_scaling is not None + and config.rope_scaling["type"] == "linear" + else 1 + ) + self.rope = nn.RoPE( + head_dim, + traditional=config.rope_traditional, + base=config.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + + +class TextMLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, 
hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.self_attn = TextAttention(config) + self.mlp = TextMLP(config.hidden_size, config.intermediate_size) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.config = config + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + + +class Llama(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = [ + TransformerBlock(config=config) for _ in range(config.num_hidden_layers) + ] + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + inputs_embeds=None, + ): + # for passing merged input embeddings + if inputs_embeds is None: + h = self.embed_tokens(inputs) + else: + h = inputs_embeds + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.norm(h), cache + + +class LanguageModel(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.model_type = config.model_type + if self.model_type != "llama": + raise ValueError( + f"Model type {self.model_type} not supported. 
Currently only 'llama' is supported" + ) + self.model = Llama(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + inputs_embeds=None, + ): + out, cache = self.model(inputs, cache, inputs_embeds) + return self.lm_head(out), cache + + @staticmethod + def sanitize(weights): + # Remove unused precomputed rotary freqs + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } + + +@dataclass +class LlaVAConfig: + text_config: TextConfig + vision_config: VisionConfig + ignore_index: int = -100 + image_token_index: int = 32000 + vision_feature_select_strategy: str = "default" + vision_feature_layer: int = -2 + vocab_size: int = 32000 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlaVAConfig): + super().__init__() + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, config.text_config.hidden_size, bias=True + ) + self.gelu = nn.GELU() + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=True + ) + + def __call__(self, x: mx.array) -> mx.array: + x = self.linear_1(x) + x = self.gelu(x) + x = self.linear_2(x) + return x + + +class LlavaModel(nn.Module): + def __init__(self, config: LlaVAConfig): + self.config = config + self.vision_tower = VisionModel(config.vision_config) + self.language_model = LanguageModel(config.text_config) + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vision_feature_layer = config.vision_feature_layer + self.vision_feature_select_strategy = config.vision_feature_select_strategy + + def get_input_embeddings( + self, + input_ids: Optional[mx.array] = None, + pixel_values: Optional[mx.array] = None, + ): + if pixel_values is None: + return self.language_model(input_ids) + + # Get the input embeddings from the language model + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + + # Get the ouptut hidden states from the vision model + *_, hidden_states = self.vision_tower( + pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True + ) + + # Select the hidden states from the desired layer + selected_image_feature = hidden_states[self.vision_feature_layer] + + if self.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + "Unexpected feature selection strategy: " + f"{self.vision_feature_select_strategy}" + ) + + # Pass image features through the multi-modal projector + image_features = self.multi_modal_projector(selected_image_feature) + + # Insert special image tokens in the input_ids + final_inputs_embeds = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids + ) + return final_inputs_embeds + + def _merge_input_ids_with_image_features( + self, image_features, inputs_embeds, input_ids + ): + image_token_index = self.config.image_token_index + num_images, num_image_patches, embed_dim = image_features.shape + + # Positions of tokens in input_ids, assuming batch size is 1 + image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + + if len(image_positions) != num_images: + raise ValueError( + f"The number of image tokens 
({len(image_positions)}) does not " + f" match the number of image inputs ({num_images})." + ) + + text_segments = [] + start_idx = 0 + + for position in image_positions: + text_segments.append(inputs_embeds[:, start_idx:position]) + start_idx = position + 1 + + image_embeddings = mx.split(image_features, image_features.shape[0]) + final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] + final_embeddings += [inputs_embeds[:, start_idx:]] + + # Create a final embedding of shape + # (1, num_image_patches*num_images + sequence_len, embed_dim) + return mx.concatenate(final_embeddings, axis=1) + + def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): + input_embddings = self.get_input_embeddings(input_ids, pixel_values) + logits, cache = self.language_model( + input_ids, cache=cache, inputs_embeds=input_embddings + ) + return logits, cache + + @staticmethod + def from_pretrained(path_or_hf_repo: str): + path = Path(path_or_hf_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + ], + ) + ) + + with open(path / "config.json", "r") as f: + model_config = json.load(f) + + model_config = LlaVAConfig.from_dict(model_config) + + model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) + model_config.text_config = TextConfig.from_dict(model_config.text_config) + + model = LlavaModel(model_config) + weight_files = glob.glob(str(path / "*.safetensors")) + if not weight_files: + raise FileNotFoundError(f"No safetensors found in {path}") + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf)) + + weights = VisionModel.sanitize(weights) + weights = LanguageModel.sanitize(weights) + + model.load_weights(list(weights.items())) + return model diff --git a/exo/inference/mlx/sharded_utils.py b/exo/inference/mlx/sharded_utils.py index dd59a52ab..0a00000f6 100644 --- a/exo/inference/mlx/sharded_utils.py +++ b/exo/inference/mlx/sharded_utils.py @@ -13,11 +13,13 @@ import mlx.nn as nn from huggingface_hub import snapshot_download from huggingface_hub.utils._errors import RepositoryNotFoundError +from transformers import AutoProcessor from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper from mlx_lm.tuner.utils import apply_lora_layers from ..shard import Shard +from exo.inference.mlx.models.sharded_llava import LlavaModel, LlaVAConfig, VisionConfig, VisionModel, TextConfig, LanguageModel class ModelNotFoundError(Exception): def __init__(self, message): @@ -228,4 +230,60 @@ async def load_shard( model.eval() tokenizer = load_tokenizer(model_path, tokenizer_config) - return model, tokenizer \ No newline at end of file + return model, tokenizer + + +async def load_shard_llava( + path_or_hf_repo: str, + shard: Shard, + tokenizer_config={}, + model_config={}, + adapter_path: Optional[str] = None, + lazy: bool = False, +) -> Tuple[nn.Module, TokenizerWrapper]: + """ + Load the model and tokenizer from a given path or a huggingface repository. + + Args: + path_or_hf_repo (Path): The path or the huggingface repository to load the model from. + tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer. + Defaults to an empty dictionary. + model_config(dict, optional): Configuration parameters specifically for the model. + Defaults to an empty dictionary. + adapter_path (str, optional): Path to the LoRA adapters. 
If provided, applies LoRA layers + to the model. Default: ``None``. + lazy (bool): If False eval the model parameters to make sure they are + loaded in memory before returning, otherwise they will be loaded + when needed. Default: ``False`` + Returns: + Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer. + + Raises: + FileNotFoundError: If config file or safetensors are not found. + ValueError: If model class or args class are not found. + """ + model_path = await get_model_path(path_or_hf_repo) + processor = AutoProcessor.from_pretrained(model_path) + + with open(model_path / "config.json", "r") as f: + model_config = json.load(f) + + model_config = LlaVAConfig.from_dict(model_config) + + model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) + model_config.text_config = TextConfig.from_dict(model_config.text_config) + + model = LlavaModel(model_config) + weight_files = glob.glob(str(model_path / "*.safetensors")) + if not weight_files: + raise FileNotFoundError(f"No safetensors found in {model_path}") + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf)) + + weights = VisionModel.sanitize(weights) + weights = LanguageModel.sanitize(weights) + + model.load_weights(list(weights.items())) + return model, processor \ No newline at end of file diff --git a/exo/inference/mlx/test_sharded_llava.py b/exo/inference/mlx/test_sharded_llava.py new file mode 100644 index 000000000..e69de29bb From 7cbf6a35bd065b0cd2e8bb5b219f08c8c0e5367f Mon Sep 17 00:00:00 2001 From: Varshith Date: Fri, 26 Jul 2024 19:12:42 +0530 Subject: [PATCH 02/16] working test --- exo/inference/mlx/models/sharded_llava.py | 49 ++++++++++-------- exo/inference/mlx/sharded_model.py | 13 +++-- exo/inference/mlx/test_sharded_llava.py | 62 +++++++++++++++++++++++ 3 files changed, 99 insertions(+), 25 deletions(-) diff --git a/exo/inference/mlx/models/sharded_llava.py b/exo/inference/mlx/models/sharded_llava.py index a90dd79e2..883351f0e 100644 --- a/exo/inference/mlx/models/sharded_llava.py +++ b/exo/inference/mlx/models/sharded_llava.py @@ -10,6 +10,7 @@ import mlx.core as mx import mlx.nn as nn +from mlx_lm.models.base import KVCache import numpy as np from huggingface_hub import snapshot_download @@ -236,7 +237,8 @@ class TextConfig: num_attention_heads: int = 32 rms_norm_eps: float = 1e-6 vocab_size: int = 32000 - num_key_value_heads: int = None + n_kv_heads: int = None + head_dim: Optional[int] = None rope_theta: float = 10000 rope_traditional: bool = False rope_scaling: Optional[Dict[str, Union[float, str]]] = None @@ -252,8 +254,11 @@ def from_dict(cls, params): ) def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads + if self.n_kv_heads is None: + self.n_kv_heads = self.num_attention_heads + + if self.head_dim is None: + self.head_dim = self.hidden_size // self.num_attention_heads if self.rope_scaling: required_keys = {"factor", "type"} @@ -270,7 +275,7 @@ def __init__(self, config: TextConfig): dim = config.hidden_size self.n_heads = n_heads = config.num_attention_heads - self.n_kv_heads = n_kv_heads = config.num_key_value_heads + self.n_kv_heads = n_kv_heads = config.n_kv_heads self.repeats = n_heads // n_kv_heads @@ -299,7 +304,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: B, L, D = x.shape @@ -311,11 +316,9 @@ def __call__( values = values.reshape(B, L, 
self.n_kv_heads, -1).transpose(0, 2, 1, 3) if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) else: queries = self.rope(queries) keys = self.rope(keys) @@ -324,7 +327,7 @@ def __call__( queries, keys, values, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) + return self.o_proj(output) class TextMLP(nn.Module): @@ -355,13 +358,13 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[KVCache] = None, ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r r = self.mlp(self.post_attention_layernorm(h)) out = h + r - return out, cache + return out class Llama(nn.Module): @@ -370,6 +373,8 @@ def __init__(self, config: TextConfig): self.config = config self.vocab_size = config.vocab_size self.num_hidden_layers = config.num_hidden_layers + self.n_kv_heads = config.n_kv_heads + self.head_dim = config.head_dim assert self.vocab_size > 0 self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = [ @@ -397,10 +402,11 @@ def __call__( if cache is None: cache = [None] * len(self.layers) - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) - return self.norm(h), cache + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + return self.norm(h) class LanguageModel(nn.Module): @@ -420,8 +426,8 @@ def __call__( cache=None, inputs_embeds=None, ): - out, cache = self.model(inputs, cache, inputs_embeds) - return self.lm_head(out), cache + out = self.model(inputs, cache, inputs_embeds) + return self.lm_head(out) @staticmethod def sanitize(weights): @@ -435,6 +441,7 @@ def sanitize(weights): class LlaVAConfig: text_config: TextConfig vision_config: VisionConfig + model_type: str = "llava" ignore_index: int = -100 image_token_index: int = 32000 vision_feature_select_strategy: str = "default" @@ -549,10 +556,10 @@ def _merge_input_ids_with_image_features( def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): input_embddings = self.get_input_embeddings(input_ids, pixel_values) - logits, cache = self.language_model( + logits = self.language_model( input_ids, cache=cache, inputs_embeds=input_embddings ) - return logits, cache + return logits @staticmethod def from_pretrained(path_or_hf_repo: str): diff --git a/exo/inference/mlx/sharded_model.py b/exo/inference/mlx/sharded_model.py index 43e26d171..dbf0b80e7 100644 --- a/exo/inference/mlx/sharded_model.py +++ b/exo/inference/mlx/sharded_model.py @@ -57,9 +57,14 @@ def __call__( return self.step(x, temp, top_p, logit_bias) def reset(self): + if hasattr(self.model.config, "vision_config"): + model = self.model.language_model.model + else: + model = self.model + kv_heads = ( - [self.model.n_kv_heads] * len(self.model.layers) - if isinstance(self.model.n_kv_heads, int) - else self.model.n_kv_heads + [model.n_kv_heads] * len(model.layers) + if isinstance(model.n_kv_heads, int) + else model.n_kv_heads ) - self.cache = [KVCache(self.model.head_dim, n) for n 
in kv_heads] + self.cache = [KVCache(model.head_dim, n) for n in kv_heads] diff --git a/exo/inference/mlx/test_sharded_llava.py b/exo/inference/mlx/test_sharded_llava.py index e69de29bb..fdd160c01 100644 --- a/exo/inference/mlx/test_sharded_llava.py +++ b/exo/inference/mlx/test_sharded_llava.py @@ -0,0 +1,62 @@ +import torch +import codecs +import asyncio +import requests +from PIL import Image +from io import BytesIO + +import mlx.core as mx +from mlx_lm.models.base import KVCache + +from exo.inference.mlx.sharded_model import StatefulShardedModel +from exo.inference.mlx.sharded_utils import load_shard_llava +from exo.inference.shard import Shard + +def sample(logits, temperature=0.0): + if temperature == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temperature)) +def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature): + kv_heads = ( + [model.language_model.model.n_kv_heads] * len(model.language_model.model.layers) + if isinstance(model.language_model.model.n_kv_heads, int) + else model.language_model.model.n_kv_heads + ) + cache = [KVCache(model.language_model.model.head_dim, n) for n in kv_heads] + logits = model(input_ids, pixel_values, cache=cache) + logits = logits[:, -1, :] + y = sample(logits, temperature=temperature) + tokens = [y.item()] + + for n in range(max_tokens - 1): + logits = model.language_model(y[None], cache=cache) + logits = logits[:, -1, :] + y = sample(logits, temperature) + token = y.item() + if token == processor.tokenizer.eos_token_id: + break + tokens.append(token) + + return processor.tokenizer.decode(tokens) + +shard_full = Shard("llava", 0, 31, 32) + +full_model_shard, full_processor = asyncio.run(load_shard_llava("llava-hf/llava-1.5-7b-hf", shard=shard_full)) + +full = StatefulShardedModel(shard_full, full_model_shard) + +PROMPT = "USER: \nWhat are these?\nASSISTANT:" +IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg" +response = requests.get(IMAGE_FILE) +img = Image.open(BytesIO(response.content)) +prompt = codecs.decode(PROMPT, "unicode_escape") +inputs = full_processor(prompt, img, return_tensors="np") +pixel_values = mx.array(inputs["pixel_values"]) +input_ids = mx.array(inputs["input_ids"]) + +print(prompt) +generated_text = generate_text( + input_ids, pixel_values, full_model_shard, full_processor, 10, 0 +) +print(generated_text) \ No newline at end of file From 9d2616b9cfe3222fff64a0855b10cc73e22cbe33 Mon Sep 17 00:00:00 2001 From: Varshith Date: Sun, 28 Jul 2024 00:30:34 +0530 Subject: [PATCH 03/16] shareded inference --- .gitignore | 1 + exo/inference/mlx/models/sharded_llava.py | 219 ++++++++++++---------- exo/inference/mlx/sharded_model.py | 24 ++- exo/inference/mlx/sharded_utils.py | 73 ++------ exo/inference/mlx/test_sharded_llava.py | 69 +++---- exo/inference/shard.py | 3 + 6 files changed, 181 insertions(+), 208 deletions(-) diff --git a/.gitignore b/.gitignore index b97ff07aa..13e309d3b 100644 --- a/.gitignore +++ b/.gitignore @@ -83,6 +83,7 @@ target/ # Jupyter Notebook .ipynb_checkpoints +Untitled.ipynb # IPython profile_default/ diff --git a/exo/inference/mlx/models/sharded_llava.py b/exo/inference/mlx/models/sharded_llava.py index 883351f0e..9923846c3 100644 --- a/exo/inference/mlx/models/sharded_llava.py +++ b/exo/inference/mlx/models/sharded_llava.py @@ -1,18 +1,15 @@ # Copyright © 2024 Apple Inc. 
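# Patch 03 makes this module shard-aware: a ModelArgs wrapper carries exo's Shard, the vision tower is only built on the first shard,
# only the first/last shards run the token embedding and the final norm/lm_head, and the standalone from_pretrained loader is dropped
# in favour of the generic load_shard path in sharded_utils.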
import math -import glob import inspect -import json -from dataclasses import dataclass -from pathlib import Path -from typing import Optional, Dict, Union, Tuple +from dataclasses import dataclass, field +from typing import Optional, Dict, Union import mlx.core as mx import mlx.nn as nn -from mlx_lm.models.base import KVCache +from mlx_lm.models.base import BaseModelArgs, KVCache +from exo.inference.shard import Shard import numpy as np -from huggingface_hub import snapshot_download @dataclass @@ -42,15 +39,15 @@ def from_dict(cls, params): class VisionAttention(nn.Module): def __init__( - self, - dims: int, - num_heads: int, - query_input_dims: Optional[int] = None, - key_input_dims: Optional[int] = None, - value_input_dims: Optional[int] = None, - value_dims: Optional[int] = None, - value_output_dims: Optional[int] = None, - bias: bool = False, + self, + dims: int, + num_heads: int, + query_input_dims: Optional[int] = None, + key_input_dims: Optional[int] = None, + value_input_dims: Optional[int] = None, + value_dims: Optional[int] = None, + value_output_dims: Optional[int] = None, + bias: bool = False, ): super().__init__() @@ -206,7 +203,7 @@ def __init__(self, config: VisionConfig): self.vision_model = ClipVisionModel(config) def __call__( - self, x: mx.array, output_hidden_states: Optional[bool] = None + self, x: mx.array, output_hidden_states: Optional[bool] = None ) -> mx.array: return self.vision_model(x, output_hidden_states) @@ -228,6 +225,7 @@ def sanitize(weights): return sanitized_weights + @dataclass class TextConfig: model_type: str @@ -235,10 +233,10 @@ class TextConfig: num_hidden_layers: int = 32 intermediate_size: int = 11008 num_attention_heads: int = 32 + head_dim: int = None rms_norm_eps: float = 1e-6 vocab_size: int = 32000 - n_kv_heads: int = None - head_dim: Optional[int] = None + num_key_value_heads: int = None rope_theta: float = 10000 rope_traditional: bool = False rope_scaling: Optional[Dict[str, Union[float, str]]] = None @@ -254,12 +252,15 @@ def from_dict(cls, params): ) def __post_init__(self): - if self.n_kv_heads is None: - self.n_kv_heads = self.num_attention_heads + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads if self.head_dim is None: self.head_dim = self.hidden_size // self.num_attention_heads + if self.model_type is None: + self.model_type = "llama" + if self.rope_scaling: required_keys = {"factor", "type"} if not all(key in self.rope_scaling for key in required_keys): @@ -275,12 +276,12 @@ def __init__(self, config: TextConfig): dim = config.hidden_size self.n_heads = n_heads = config.num_attention_heads - self.n_kv_heads = n_kv_heads = config.n_kv_heads + self.n_kv_heads = n_kv_heads = config.num_key_value_heads self.repeats = n_heads // n_kv_heads head_dim = config.hidden_size // n_heads - self.scale = head_dim**-0.5 + self.scale = head_dim ** -0.5 self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) @@ -290,7 +291,7 @@ def __init__(self, config: TextConfig): rope_scale = ( 1 / config.rope_scaling["factor"] if config.rope_scaling is not None - and config.rope_scaling["type"] == "linear" + and config.rope_scaling["type"] == "linear" else 1 ) self.rope = nn.RoPE( @@ -301,10 +302,10 @@ def __init__(self, config: TextConfig): ) def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, ) -> 
mx.array: B, L, D = x.shape @@ -355,10 +356,10 @@ def __init__(self, config: TextConfig): self.config = config def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -368,12 +369,15 @@ def __call__( class Llama(nn.Module): - def __init__(self, config: TextConfig): + def __init__(self, config: TextConfig, is_first_layer, is_last_layer): super().__init__() self.config = config + self.is_first_layer = is_first_layer + self.is_last_layer = is_last_layer self.vocab_size = config.vocab_size + self.model_type = config.model_type self.num_hidden_layers = config.num_hidden_layers - self.n_kv_heads = config.n_kv_heads + self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.head_dim assert self.vocab_size > 0 self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) @@ -383,14 +387,17 @@ def __init__(self, config: TextConfig): self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def __call__( - self, - inputs: mx.array, - cache=None, - inputs_embeds=None, + self, + inputs: mx.array, + cache=None, + inputs_embeds=None, ): # for passing merged input embeddings if inputs_embeds is None: - h = self.embed_tokens(inputs) + if self.is_first_layer: + h = self.embed_tokens(inputs) + else: + h = inputs else: h = inputs_embeds @@ -406,18 +413,20 @@ def __call__( for layer, c in zip(self.layers, cache): h = layer(h, mask, c) - return self.norm(h) - + if self.is_last_layer: + h = self.norm(h) + return h class LanguageModel(nn.Module): - def __init__(self, config: TextConfig): + def __init__(self, config: TextConfig, is_first_layer, is_last_layer): super().__init__() self.model_type = config.model_type if self.model_type != "llama": raise ValueError( f"Model type {self.model_type} not supported. 
Currently only 'llama' is supported" ) - self.model = Llama(config) + self.is_last_layer = is_last_layer + self.model = Llama(config, is_first_layer, is_last_layer) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def __call__( @@ -427,7 +436,9 @@ def __call__( inputs_embeds=None, ): out = self.model(inputs, cache, inputs_embeds) - return self.lm_head(out) + if self.is_last_layer: + out = self.lm_head(out) + return out @staticmethod def sanitize(weights): @@ -436,11 +447,10 @@ def sanitize(weights): k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k } - @dataclass -class LlaVAConfig: +class LlaVAConfig(BaseModelArgs): text_config: TextConfig - vision_config: VisionConfig + vision_config: VisionConfig = None model_type: str = "llava" ignore_index: int = -100 image_token_index: int = 32000 @@ -450,13 +460,32 @@ class LlaVAConfig: @classmethod def from_dict(cls, params): - return cls( - **{ - k: v - for k, v in params.items() - if k in inspect.signature(cls).parameters - } - ) + updated_params = {} + class_params = inspect.signature(cls).parameters + for k, v in params.items(): + if k in class_params: + if k in ["text_config", "vision_config"]: + v = class_params[k].annotation.from_dict(v) + updated_params.update({k: v}) + + return cls(**updated_params) + + +@dataclass +class ModelArgs(LlaVAConfig): + shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + + def __post_init__(self): + if isinstance(self.shard, dict): + self.shard = Shard(**self.shard) + + if not isinstance(self.shard, Shard): + raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") + + if not self.shard.is_first_layer(): + self.vision_config = None + + self.text_config.num_hidden_layers = self.shard.get_layer_count() class LlavaMultiModalProjector(nn.Module): @@ -477,19 +506,22 @@ def __call__(self, x: mx.array) -> mx.array: return x -class LlavaModel(nn.Module): - def __init__(self, config: LlaVAConfig): +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() self.config = config - self.vision_tower = VisionModel(config.vision_config) - self.language_model = LanguageModel(config.text_config) - self.multi_modal_projector = LlavaMultiModalProjector(config) - self.vision_feature_layer = config.vision_feature_layer - self.vision_feature_select_strategy = config.vision_feature_select_strategy + self.model_type = config.model_type + if config.vision_config: + self.vision_tower = VisionModel(config.vision_config) + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vision_feature_layer = config.vision_feature_layer + self.vision_feature_select_strategy = config.vision_feature_select_strategy + self.language_model = LanguageModel(config.text_config, config.shard.is_first_layer(), config.shard.is_last_layer()) def get_input_embeddings( - self, - input_ids: Optional[mx.array] = None, - pixel_values: Optional[mx.array] = None, + self, + input_ids: Optional[mx.array] = None, + pixel_values: Optional[mx.array] = None, ): if pixel_values is None: return self.language_model(input_ids) @@ -525,7 +557,7 @@ def get_input_embeddings( return final_inputs_embeds def _merge_input_ids_with_image_features( - self, image_features, inputs_embeds, input_ids + self, image_features, inputs_embeds, input_ids ): image_token_index = self.config.image_token_index num_images, num_image_patches, embed_dim = image_features.shape @@ -554,49 +586,32 @@ def _merge_input_ids_with_image_features( # (1, 
num_image_patches*num_images + sequence_len, embed_dim) return mx.concatenate(final_embeddings, axis=1) - def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): - input_embddings = self.get_input_embeddings(input_ids, pixel_values) + def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None): + input_embddings = None + if pixel_values is not None: + input_embddings = self.get_input_embeddings(input_ids, pixel_values) logits = self.language_model( input_ids, cache=cache, inputs_embeds=input_embddings ) return logits - @staticmethod - def from_pretrained(path_or_hf_repo: str): - path = Path(path_or_hf_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=[ - "*.json", - "*.safetensors", - "*.py", - "tokenizer.model", - "*.tiktoken", - ], - ) - ) - - with open(path / "config.json", "r") as f: - model_config = json.load(f) - - model_config = LlaVAConfig.from_dict(model_config) + def sanitize(self, weights): + if self.config.vision_config: + weights = self.vision_tower.sanitize(weights) + weights = self.language_model.sanitize(weights) - model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) - model_config.text_config = TextConfig.from_dict(model_config.text_config) + return weights - model = LlavaModel(model_config) - weight_files = glob.glob(str(path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {path}") + @property + def layers(self): + return self.language_model.model.layers - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = VisionModel.sanitize(weights) - weights = LanguageModel.sanitize(weights) + @property + def head_dim(self): + return ( + self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads + ) - model.load_weights(list(weights.items())) - return model + @property + def n_kv_heads(self): + return self.language_model.model.num_key_value_heads diff --git a/exo/inference/mlx/sharded_model.py b/exo/inference/mlx/sharded_model.py index dbf0b80e7..fff463a50 100644 --- a/exo/inference/mlx/sharded_model.py +++ b/exo/inference/mlx/sharded_model.py @@ -15,7 +15,8 @@ def __init__(self, shard: Shard, model: nn.Module): def step( self, - x, + y, + pixel_values=None, temp: float = 0.0, top_p: float = 1.0, logit_bias: Optional[Dict[int, float]] = None, @@ -36,9 +37,11 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: return token - y = x - - output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache) + # TODO : revert hacky fix + if pixel_values is None: + output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache) + else: + output = self.model(y, pixel_values=pixel_values, cache=self.cache) if self.shard.is_last_layer(): logits = output[:, -1, :] @@ -57,14 +60,9 @@ def __call__( return self.step(x, temp, top_p, logit_bias) def reset(self): - if hasattr(self.model.config, "vision_config"): - model = self.model.language_model.model - else: - model = self.model - kv_heads = ( - [model.n_kv_heads] * len(model.layers) - if isinstance(model.n_kv_heads, int) - else model.n_kv_heads + [self.model.n_kv_heads] * len(self.model.layers) + if isinstance(self.model.n_kv_heads, int) + else self.model.n_kv_heads ) - self.cache = [KVCache(model.head_dim, n) for n in kv_heads] + self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads] diff --git 
a/exo/inference/mlx/sharded_utils.py b/exo/inference/mlx/sharded_utils.py index 0a00000f6..0bc8efe63 100644 --- a/exo/inference/mlx/sharded_utils.py +++ b/exo/inference/mlx/sharded_utils.py @@ -19,7 +19,6 @@ from mlx_lm.tuner.utils import apply_lora_layers from ..shard import Shard -from exo.inference.mlx.models.sharded_llava import LlavaModel, LlaVAConfig, VisionConfig, VisionModel, TextConfig, LanguageModel class ModelNotFoundError(Exception): def __init__(self, message): @@ -29,6 +28,7 @@ def __init__(self, message): MODEL_REMAPPING = { "sharded_mistral": "sharded_llama", # mistral is compatible with llama "sharded_phi-msft": "sharded_phixtral", + "sharded_llava": "sharded_llava" } def _get_classes(config: dict): @@ -113,6 +113,7 @@ def load_model_shard( for wf in weight_files: weights_dict = mx.load(wf) all_weights_keys.update(weights_dict.keys()) + weights.update({k: v for k, v in weights_dict.items() if not k.startswith("language_model.model.layers.") or shard.start_layer <= int(k.split('.')[3]) <= shard.end_layer}) weights.update({k: v for k, v in weights_dict.items() if not k.startswith("model.layers.") or shard.start_layer <= int(k.split('.')[2]) <= shard.end_layer}) model_class, model_args_class = _get_classes(config=config) @@ -137,6 +138,11 @@ def load_model_shard( if shard.start_layer <= layer_num <= shard.end_layer: new_key = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:]) filtered_weights[new_key] = v + elif k.startswith("language_model.model.layers."): + layer_num = int(k.split('.')[3]) + if shard.start_layer <= layer_num <= shard.end_layer: + new_key = f"language_model.model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[4:]) + filtered_weights[new_key] = v else: filtered_weights[k] = v weights = filtered_weights @@ -228,62 +234,11 @@ async def load_shard( if adapter_path is not None: model = apply_lora_layers(model, adapter_path) model.eval() - tokenizer = load_tokenizer(model_path, tokenizer_config) - - return model, tokenizer - - -async def load_shard_llava( - path_or_hf_repo: str, - shard: Shard, - tokenizer_config={}, - model_config={}, - adapter_path: Optional[str] = None, - lazy: bool = False, -) -> Tuple[nn.Module, TokenizerWrapper]: - """ - Load the model and tokenizer from a given path or a huggingface repository. - - Args: - path_or_hf_repo (Path): The path or the huggingface repository to load the model from. - tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer. - Defaults to an empty dictionary. - model_config(dict, optional): Configuration parameters specifically for the model. - Defaults to an empty dictionary. - adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers - to the model. Default: ``None``. - lazy (bool): If False eval the model parameters to make sure they are - loaded in memory before returning, otherwise they will be loaded - when needed. Default: ``False`` - Returns: - Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer. - - Raises: - FileNotFoundError: If config file or safetensors are not found. - ValueError: If model class or args class are not found. 
- """ - model_path = await get_model_path(path_or_hf_repo) - processor = AutoProcessor.from_pretrained(model_path) - - with open(model_path / "config.json", "r") as f: - model_config = json.load(f) - - model_config = LlaVAConfig.from_dict(model_config) - - model_config.vision_config = VisionConfig.from_dict(model_config.vision_config) - model_config.text_config = TextConfig.from_dict(model_config.text_config) - - model = LlavaModel(model_config) - weight_files = glob.glob(str(model_path / "*.safetensors")) - if not weight_files: - raise FileNotFoundError(f"No safetensors found in {model_path}") - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf)) - - weights = VisionModel.sanitize(weights) - weights = LanguageModel.sanitize(weights) - model.load_weights(list(weights.items())) - return model, processor \ No newline at end of file + # TODO: figure out a better way + if "llama" in str(model_path): + tokenizer = load_tokenizer(model_path, tokenizer_config) + return model, tokenizer + elif "llava" in str(model_path): + processor = AutoProcessor.from_pretrained(model_path) + return model, processor \ No newline at end of file diff --git a/exo/inference/mlx/test_sharded_llava.py b/exo/inference/mlx/test_sharded_llava.py index fdd160c01..14a66aa27 100644 --- a/exo/inference/mlx/test_sharded_llava.py +++ b/exo/inference/mlx/test_sharded_llava.py @@ -9,42 +9,20 @@ from mlx_lm.models.base import KVCache from exo.inference.mlx.sharded_model import StatefulShardedModel -from exo.inference.mlx.sharded_utils import load_shard_llava +from exo.inference.mlx.sharded_utils import load_shard from exo.inference.shard import Shard -def sample(logits, temperature=0.0): - if temperature == 0: - return mx.argmax(logits, axis=-1) - else: - return mx.random.categorical(logits * (1 / temperature)) -def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature): - kv_heads = ( - [model.language_model.model.n_kv_heads] * len(model.language_model.model.layers) - if isinstance(model.language_model.model.n_kv_heads, int) - else model.language_model.model.n_kv_heads - ) - cache = [KVCache(model.language_model.model.head_dim, n) for n in kv_heads] - logits = model(input_ids, pixel_values, cache=cache) - logits = logits[:, -1, :] - y = sample(logits, temperature=temperature) - tokens = [y.item()] - - for n in range(max_tokens - 1): - logits = model.language_model(y[None], cache=cache) - logits = logits[:, -1, :] - y = sample(logits, temperature) - token = y.item() - if token == processor.tokenizer.eos_token_id: - break - tokens.append(token) - - return processor.tokenizer.decode(tokens) - shard_full = Shard("llava", 0, 31, 32) +shard1 = Shard("llava", 0, 12, 32) +shard2 = Shard("llava", 13, 31, 32) -full_model_shard, full_processor = asyncio.run(load_shard_llava("llava-hf/llava-1.5-7b-hf", shard=shard_full)) +full_model_shard, full_processor = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard_full)) +model_shard1, processor1 = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard1)) +model_shard2, processor2 = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard2)) full = StatefulShardedModel(shard_full, full_model_shard) +m1 = StatefulShardedModel(shard1, model_shard1) +m2 = StatefulShardedModel(shard2, model_shard2) PROMPT = "USER: \nWhat are these?\nASSISTANT:" IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg" @@ -56,7 +34,30 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera input_ids = 
mx.array(inputs["input_ids"]) print(prompt) -generated_text = generate_text( - input_ids, pixel_values, full_model_shard, full_processor, 10, 0 -) -print(generated_text) \ No newline at end of file +y = full.step(input_ids, pixel_values, temp=0) +full_generated_tokens = [y.item()] + +for _ in range(13): + y = full.step(y, temp=0) + full_generated_tokens.append(y.item()) + +full_response = full_processor.tokenizer.decode(full_generated_tokens) +print("full response:", full_response) + +inputs = processor1(prompt, img, return_tensors="np") +pixel_values = mx.array(inputs["pixel_values"]) +input_ids = mx.array(inputs["input_ids"]) + +y = m1.step(input_ids, pixel_values, temp=0) +y = m2.step(y, temp=0) +full_generated_tokens = [y.item()] + +for _ in range(13): + y = m1.step(y, temp=0) + y = m2.step(y, temp=0) + full_generated_tokens.append(y.item()) + +sharded_response = processor2.tokenizer.decode(full_generated_tokens) +print("sharded response:", sharded_response) + +assert full_response == sharded_response \ No newline at end of file diff --git a/exo/inference/shard.py b/exo/inference/shard.py index 79c8005e6..91df12936 100644 --- a/exo/inference/shard.py +++ b/exo/inference/shard.py @@ -13,6 +13,9 @@ def is_first_layer(self) -> bool: def is_last_layer(self) -> bool: return self.end_layer == self.n_layers - 1 + def get_layer_count(self) -> int: + return self.end_layer - self.start_layer + 1 + def to_dict(self) -> dict: return { "model_id": self.model_id, From 54993995dc426ff699ef5c2ada5af5c3a3e605c4 Mon Sep 17 00:00:00 2001 From: Varshith Date: Sun, 28 Jul 2024 00:44:41 +0530 Subject: [PATCH 04/16] conflicts --- exo/inference/mlx/sharded_model.py | 19 ++++++++++++------- exo/inference/mlx/test_sharded_llava.py | 12 ++++++------ 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/exo/inference/mlx/sharded_model.py b/exo/inference/mlx/sharded_model.py index fff463a50..b59cc3c37 100644 --- a/exo/inference/mlx/sharded_model.py +++ b/exo/inference/mlx/sharded_model.py @@ -11,11 +11,12 @@ class StatefulShardedModel: def __init__(self, shard: Shard, model: nn.Module): self.shard = shard self.model = model - self.reset() + self.request_cache: Dict[str, Tuple[str, KVCache]] = {} def step( self, - y, + request_id: str, + x, pixel_values=None, temp: float = 0.0, top_p: float = 1.0, @@ -37,11 +38,15 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: return token - # TODO : revert hacky fix + y = x + + if request_id not in self.request_cache: + self.init_cache(request_id) + if pixel_values is None: - output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache) + output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.request_cache[request_id]) else: - output = self.model(y, pixel_values=pixel_values, cache=self.cache) + output = self.model(y, pixel_values=pixel_values, cache=self.request_cache[request_id]) if self.shard.is_last_layer(): logits = output[:, -1, :] @@ -59,10 +64,10 @@ def __call__( ) -> Generator[Tuple[mx.array, mx.array], None, None]: return self.step(x, temp, top_p, logit_bias) - def reset(self): + def init_cache(self, request_id: str): kv_heads = ( [self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads ) - self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads] + self.request_cache[request_id] = [KVCache(self.model.head_dim, n) for n in kv_heads] \ No newline at end of file diff --git a/exo/inference/mlx/test_sharded_llava.py 
b/exo/inference/mlx/test_sharded_llava.py index 14a66aa27..1674f4252 100644 --- a/exo/inference/mlx/test_sharded_llava.py +++ b/exo/inference/mlx/test_sharded_llava.py @@ -34,11 +34,11 @@ input_ids = mx.array(inputs["input_ids"]) print(prompt) -y = full.step(input_ids, pixel_values, temp=0) +y = full.step("full", input_ids, pixel_values, temp=0) full_generated_tokens = [y.item()] for _ in range(13): - y = full.step(y, temp=0) + y = full.step("full", y, temp=0) full_generated_tokens.append(y.item()) full_response = full_processor.tokenizer.decode(full_generated_tokens) @@ -48,13 +48,13 @@ pixel_values = mx.array(inputs["pixel_values"]) input_ids = mx.array(inputs["input_ids"]) -y = m1.step(input_ids, pixel_values, temp=0) -y = m2.step(y, temp=0) +y = m1.step("shard", input_ids, pixel_values, temp=0) +y = m2.step("shard", y, temp=0) full_generated_tokens = [y.item()] for _ in range(13): - y = m1.step(y, temp=0) - y = m2.step(y, temp=0) + y = m1.step("shard", y, temp=0) + y = m2.step("shard", y, temp=0) full_generated_tokens.append(y.item()) sharded_response = processor2.tokenizer.decode(full_generated_tokens) From 2849128d6acad1c139d1c8963bf054ec3f16316a Mon Sep 17 00:00:00 2001 From: Varshith Date: Sun, 28 Jul 2024 01:03:34 +0530 Subject: [PATCH 05/16] processor load --- exo/inference/mlx/sharded_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/exo/inference/mlx/sharded_utils.py b/exo/inference/mlx/sharded_utils.py index 0bc8efe63..2964d5448 100644 --- a/exo/inference/mlx/sharded_utils.py +++ b/exo/inference/mlx/sharded_utils.py @@ -235,10 +235,10 @@ async def load_shard( model = apply_lora_layers(model, adapter_path) model.eval() - # TODO: figure out a better way - if "llama" in str(model_path): - tokenizer = load_tokenizer(model_path, tokenizer_config) - return model, tokenizer - elif "llava" in str(model_path): + # TODO: figure out a generic solution + if model.model_type == "llava": processor = AutoProcessor.from_pretrained(model_path) - return model, processor \ No newline at end of file + return model, processor + else: + tokenizer = load_tokenizer(model_path, tokenizer_config) + return model, tokenizer \ No newline at end of file From 833e7f3396602091b869425ff8106a7c7e52a2b0 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sat, 27 Jul 2024 20:19:55 -0700 Subject: [PATCH 06/16] rename sharded_llava -> llava to match new convention --- README.md | 4 ++-- exo/inference/mlx/models/{sharded_llava.py => llava.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename exo/inference/mlx/models/{sharded_llava.py => llava.py} (100%) diff --git a/README.md b/README.md index 334e9bfc8..d4284e199 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Forget expensive NVIDIA GPUs, unify your existing devices into one powerful GPU:

 Update: Exo Supports Llama 3.1
 Now the default models, run 8B, 70B and 405B parameter models on your own devices
-See the code  [banner link, old model path]
+See the code  [banner link, renamed model path]

## Get Involved @@ -40,7 +40,7 @@ We also welcome contributions from the community. We have a list of bounties in ### Wide Model Support -exo supports LLaMA ([MLX](exo/inference/mlx/models/sharded_llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)) and other popular models. +exo supports LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)) and other popular models. ### Dynamic Model Partitioning diff --git a/exo/inference/mlx/models/sharded_llava.py b/exo/inference/mlx/models/llava.py similarity index 100% rename from exo/inference/mlx/models/sharded_llava.py rename to exo/inference/mlx/models/llava.py From 2aa1e24ea900c53212a8f4ce3e7db61b2a32ad47 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sat, 27 Jul 2024 20:22:56 -0700 Subject: [PATCH 07/16] remove unused torch import --- exo/inference/mlx/test_sharded_llava.py | 1 - 1 file changed, 1 deletion(-) diff --git a/exo/inference/mlx/test_sharded_llava.py b/exo/inference/mlx/test_sharded_llava.py index 1674f4252..5e2e13adf 100644 --- a/exo/inference/mlx/test_sharded_llava.py +++ b/exo/inference/mlx/test_sharded_llava.py @@ -1,4 +1,3 @@ -import torch import codecs import asyncio import requests From b44b917151b12f4a087d2ad1b8cb1df3f88271b0 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sat, 27 Jul 2024 20:23:03 -0700 Subject: [PATCH 08/16] add pillow as testing dependency --- setup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.py b/setup.py index 146f0ca1b..79fef2e83 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,9 @@ "pylint==3.2.6", "ruff==0.5.5", "mypy==1.11.0", + ], + "testing": [ + "pillow==10.4.0" ] } From 2fb961fccd9c34f51bce44bd134b76a88a489fdb Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sat, 27 Jul 2024 21:04:13 -0700 Subject: [PATCH 09/16] stick to same convention as new llama --- exo/inference/mlx/models/llava.py | 34 ++++++++++++++++-------------- exo/inference/mlx/sharded_utils.py | 2 +- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/exo/inference/mlx/models/llava.py b/exo/inference/mlx/models/llava.py index 9923846c3..d439002b5 100644 --- a/exo/inference/mlx/models/llava.py +++ b/exo/inference/mlx/models/llava.py @@ -9,6 +9,7 @@ import mlx.nn as nn from mlx_lm.models.base import BaseModelArgs, KVCache from exo.inference.shard import Shard +from .base import IdentityBlock import numpy as np @@ -369,11 +370,10 @@ def __call__( class Llama(nn.Module): - def __init__(self, config: TextConfig, is_first_layer, is_last_layer): + def __init__(self, config: TextConfig, shard: Shard): super().__init__() self.config = config - self.is_first_layer = is_first_layer - self.is_last_layer = is_last_layer + self.shard = shard self.vocab_size = config.vocab_size self.model_type = config.model_type self.num_hidden_layers = config.num_hidden_layers @@ -381,10 +381,14 @@ def __init__(self, config: TextConfig, is_first_layer, is_last_layer): self.head_dim = config.head_dim assert self.vocab_size > 0 self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = [ - TransformerBlock(config=config) for _ in range(config.num_hidden_layers) - ] - self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layers = [] + for i in range(self.num_hidden_layers): + if self.shard.start_layer <= i <= self.shard.end_layer: + self.layers.append(TransformerBlock(config=config)) + else: + self.layers.append(IdentityBlock()) + if self.shard.is_last_layer(): + self.norm = nn.RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) def __call__( self, @@ -394,7 +398,7 @@ def __call__( ): # for passing merged input embeddings if inputs_embeds is None: - if self.is_first_layer: + if self.shard.is_first_layer(): h = self.embed_tokens(inputs) else: h = inputs @@ -413,20 +417,20 @@ def __call__( for layer, c in zip(self.layers, cache): h = layer(h, mask, c) - if self.is_last_layer: + if self.shard.is_last_layer(): h = self.norm(h) return h class LanguageModel(nn.Module): - def __init__(self, config: TextConfig, is_first_layer, is_last_layer): + def __init__(self, config: TextConfig, shard: Shard): super().__init__() self.model_type = config.model_type if self.model_type != "llama": raise ValueError( f"Model type {self.model_type} not supported. Currently only 'llama' is supported" ) - self.is_last_layer = is_last_layer - self.model = Llama(config, is_first_layer, is_last_layer) + self.shard = shard + self.model = Llama(config, shard) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def __call__( @@ -436,7 +440,7 @@ def __call__( inputs_embeds=None, ): out = self.model(inputs, cache, inputs_embeds) - if self.is_last_layer: + if self.shard.is_last_layer(): out = self.lm_head(out) return out @@ -485,8 +489,6 @@ def __post_init__(self): if not self.shard.is_first_layer(): self.vision_config = None - self.text_config.num_hidden_layers = self.shard.get_layer_count() - class LlavaMultiModalProjector(nn.Module): def __init__(self, config: LlaVAConfig): @@ -516,7 +518,7 @@ def __init__(self, config: ModelArgs): self.multi_modal_projector = LlavaMultiModalProjector(config) self.vision_feature_layer = config.vision_feature_layer self.vision_feature_select_strategy = config.vision_feature_select_strategy - self.language_model = LanguageModel(config.text_config, config.shard.is_first_layer(), config.shard.is_last_layer()) + self.language_model = LanguageModel(config.text_config, config.shard) def get_input_embeddings( self, diff --git a/exo/inference/mlx/sharded_utils.py b/exo/inference/mlx/sharded_utils.py index 592eae32b..e5429c64c 100644 --- a/exo/inference/mlx/sharded_utils.py +++ b/exo/inference/mlx/sharded_utils.py @@ -129,7 +129,7 @@ def load_model_shard( class_predicate=None, ) - model.load_weights(list(weights.items())) + model.load_weights(list(weights.items()), strict=True) if not lazy: mx.eval(model.parameters()) From 33cbacf513957fd24e9e874092f06975d0779460 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Sat, 27 Jul 2024 21:31:36 -0700 Subject: [PATCH 10/16] fix llava sanitize --- exo/inference/mlx/models/llava.py | 38 +++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/exo/inference/mlx/models/llava.py b/exo/inference/mlx/models/llava.py index d439002b5..0745b0653 100644 --- a/exo/inference/mlx/models/llava.py +++ b/exo/inference/mlx/models/llava.py @@ -208,8 +208,7 @@ def __call__( ) -> mx.array: return self.vision_model(x, output_hidden_states) - @staticmethod - def sanitize(weights): + def sanitize(self, weights): sanitized_weights = {} for k, v in weights.items(): if "position_ids" in k: @@ -380,7 +379,8 @@ def __init__(self, config: TextConfig, shard: Shard): self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.head_dim assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + if self.shard.is_first_layer(): + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = [] for i in range(self.num_hidden_layers): if 
self.shard.start_layer <= i <= self.shard.end_layer: @@ -431,7 +431,8 @@ def __init__(self, config: TextConfig, shard: Shard): ) self.shard = shard self.model = Llama(config, shard) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + if self.shard.is_last_layer(): + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def __call__( self, @@ -444,12 +445,24 @@ def __call__( out = self.lm_head(out) return out - @staticmethod - def sanitize(weights): - # Remove unused precomputed rotary freqs - return { - k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k - } + def sanitize(self, weights): + shard_state_dict = {} + for key, value in weights.items(): + if "self_attn.rotary_emb.inv_freq" in key: + continue + + if key.startswith('language_model.model.layers.'): + layer_num = int(key.split('.')[3]) + if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer: + continue + if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'): + continue + elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')): + continue + + shard_state_dict[key] = value + + return shard_state_dict @dataclass class LlaVAConfig(BaseModelArgs): @@ -599,9 +612,10 @@ def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=Non def sanitize(self, weights): if self.config.vision_config: - weights = self.vision_tower.sanitize(weights) + weights = self.vision_tower.sanitize(weights) + else: + weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))} weights = self.language_model.sanitize(weights) - return weights @property From acc94b50c73c7f65658461a6c6b3f44994e0fab7 Mon Sep 17 00:00:00 2001 From: Varshith Date: Sun, 28 Jul 2024 16:12:21 +0530 Subject: [PATCH 11/16] chatgpt api integration --- exo/api/chatgpt_api.py | 43 ++++++++++--- exo/inference/mlx/sharded_inference_engine.py | 13 +++- exo/inference/mlx/sharded_utils.py | 14 ++++ exo/inference/mlx/test_sharded_llava.py | 8 ++- exo/networking/grpc/grpc_peer_handle.py | 3 +- exo/networking/grpc/grpc_server.py | 5 +- exo/networking/grpc/node_service.proto | 5 +- exo/networking/grpc/node_service_pb2.py | 64 +++++++++---------- exo/networking/peer_handle.py | 2 +- exo/orchestration/node.py | 2 +- exo/orchestration/standard_node.py | 21 +++--- setup.py | 2 +- 12 files changed, 119 insertions(+), 63 deletions(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 9795fa1a3..fff4a63da 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -3,7 +3,7 @@ import asyncio import json from pathlib import Path -from transformers import AutoTokenizer +from transformers import AutoTokenizer, AutoProcessor from typing import List, Literal, Union, Dict from aiohttp import web import aiohttp_cors @@ -42,11 +42,15 @@ "deepseek-coder-v2-lite": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27), }, + ### llava + "llava-1.5-7b-hf": { + "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32), + }, } class Message: - def __init__(self, role: str, content: str): + def __init__(self, role: str, content: Union[str, list]): self.role = role self.content = content @@ -68,6 +72,18 @@ def 
resolve_tinygrad_tokenizer(model_id: str): async def resolve_tokenizer(model_id: str): + try: + if DEBUG >= 2: print(f"Trying to AutoProcessor for {model_id}") + processor = AutoProcessor.from_pretrained(model_id) + processor.eos_token_id = processor.tokenizer.eos_token_id + processor.encode = processor.tokenizer.encode + return processor + except Exception as e: + if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}") + import traceback + + if DEBUG >= 2: print(traceback.format_exc()) + try: if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}") return AutoTokenizer.from_pretrained(model_id) @@ -138,7 +154,18 @@ def generate_completion( def build_prompt(tokenizer, messages: List[Message]): - return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + image_str = None + for message in messages: + if not isinstance(message.content, list): + continue + + for content in message.content: + if content.get("type", None) == "image": + image_str = content.get("image", None) + break + + return prompt, image_str def parse_message(data: dict): @@ -195,7 +222,7 @@ async def handle_post_chat_token_encode(self, request): shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname) messages = [parse_message(msg) for msg in data.get("messages", [])] tokenizer = await resolve_tokenizer(shard.model_id) - return web.json_response({"length": len(build_prompt(tokenizer, messages))}) + return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])}) async def handle_post_chat_completions(self, request): data = await request.json() @@ -219,13 +246,13 @@ async def handle_post_chat_completions(self, request): tokenizer = await resolve_tokenizer(shard.model_id) if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}") - prompt = build_prompt(tokenizer, chat_request.messages) + prompt, image_str = build_prompt(tokenizer, chat_request.messages) callback_id = f"chatgpt-api-wait-response-{request_id}" callback = self.node.on_token.register(callback_id) - if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}") + if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}") try: - await self.node.process_prompt(shard, prompt, request_id=request_id) + await self.node.process_prompt(shard, prompt, image_str, request_id=request_id) except Exception as e: if DEBUG >= 2: import traceback @@ -294,7 +321,7 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool): ) finish_reason = "length" - eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id + eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}") if tokens[-1] == eos_token_id: tokens = tokens[:-1] diff --git a/exo/inference/mlx/sharded_inference_engine.py b/exo/inference/mlx/sharded_inference_engine.py index b5104c729..9726123bf 100644 --- a/exo/inference/mlx/sharded_inference_engine.py +++ b/exo/inference/mlx/sharded_inference_engine.py @@ -2,7 +2,7 @@ import mlx.core as mx from ..inference_engine import InferenceEngine from .sharded_model import StatefulShardedModel -from .sharded_utils import 
load_shard +from .sharded_utils import load_shard, get_image_from_str from ..shard import Shard from typing import Optional @@ -11,9 +11,16 @@ class MLXDynamicShardInferenceEngine(InferenceEngine): def __init__(self): self.shard = None - async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): + async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): await self.ensure_shard(shard) - output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt)))) + if image_str: + image = get_image_from_str(image_str) + inputs = self.tokenizer(prompt, image, return_tensors="np") + pixel_values = mx.array(inputs["pixel_values"]) + input_ids = mx.array(inputs["input_ids"]) + output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, input_ids, pixel_values)) + else: + output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt)))) return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): diff --git a/exo/inference/mlx/sharded_utils.py b/exo/inference/mlx/sharded_utils.py index e5429c64c..296d552f9 100644 --- a/exo/inference/mlx/sharded_utils.py +++ b/exo/inference/mlx/sharded_utils.py @@ -8,6 +8,9 @@ from functools import partial from pathlib import Path from typing import Optional, Tuple +import requests +from PIL import Image +from io import BytesIO import mlx.core as mx import mlx.nn as nn @@ -222,7 +225,18 @@ async def load_shard( # TODO: figure out a generic solution if model.model_type == "llava": processor = AutoProcessor.from_pretrained(model_path) + processor.eos_token_id = processor.tokenizer.eos_token_id + processor.encode = processor.tokenizer.encode return model, processor else: tokenizer = load_tokenizer(model_path, tokenizer_config) return model, tokenizer + +def get_image_from_str(image_str: str): + if image_str.startswith("http"): + response = requests.get(image_str, timeout=10) + image = Image.open(BytesIO(response.content)).convert("RGB") + else: + imgdata = base64.b64decode(image_str) + image = Image.open(io.BytesIO(imgdata)) + return image diff --git a/exo/inference/mlx/test_sharded_llava.py b/exo/inference/mlx/test_sharded_llava.py index 5e2e13adf..f2e75c485 100644 --- a/exo/inference/mlx/test_sharded_llava.py +++ b/exo/inference/mlx/test_sharded_llava.py @@ -15,9 +15,11 @@ shard1 = Shard("llava", 0, 12, 32) shard2 = Shard("llava", 13, 31, 32) -full_model_shard, full_processor = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard_full)) -model_shard1, processor1 = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard1)) -model_shard2, processor2 = asyncio.run(load_shard("llava-hf/llava-1.5-7b-hf", shard=shard2)) +model_path = "llava-hf/llava-1.5-7b-hf" + +full_model_shard, full_processor = asyncio.run(load_shard(model_path, shard=shard_full)) +model_shard1, processor1 = asyncio.run(load_shard(model_path, shard=shard1)) +model_shard2, processor2 = asyncio.run(load_shard(model_path, shard=shard2)) full = StatefulShardedModel(shard_full, full_model_shard) m1 = StatefulShardedModel(shard1, model_shard1) diff --git 
a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py index 78009d49a..e3e98dceb 100644 --- a/exo/networking/grpc/grpc_peer_handle.py +++ b/exo/networking/grpc/grpc_peer_handle.py @@ -39,9 +39,10 @@ async def disconnect(self): self.channel = None self.stub = None - async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: + async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: request = node_service_pb2.PromptRequest( prompt=prompt, + image_str=image_str, shard=node_service_pb2.Shard( model_id=shard.model_id, start_layer=shard.start_layer, diff --git a/exo/networking/grpc/grpc_server.py b/exo/networking/grpc/grpc_server.py index 13d4bf6fb..5992ce1f4 100644 --- a/exo/networking/grpc/grpc_server.py +++ b/exo/networking/grpc/grpc_server.py @@ -45,9 +45,10 @@ async def SendPrompt(self, request, context): n_layers=request.shard.n_layers, ) prompt = request.prompt + image_str = request.image_str request_id = request.request_id - result = await self.node.process_prompt(shard, prompt, request_id) - if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}") + result = await self.node.process_prompt(shard, prompt, image_str, request_id) + if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {image=} {request_id=} result: {result}") tensor_data = result.tobytes() if result is not None else None return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor() diff --git a/exo/networking/grpc/node_service.proto b/exo/networking/grpc/node_service.proto index d76430caf..6fcee3515 100644 --- a/exo/networking/grpc/node_service.proto +++ b/exo/networking/grpc/node_service.proto @@ -21,8 +21,9 @@ message Shard { message PromptRequest { Shard shard = 1; string prompt = 2; - optional string request_id = 3; - optional string inference_state = 4; + optional string image_str = 3; + optional string request_id = 4; + optional string inference_state = 5; } message TensorRequest { diff --git a/exo/networking/grpc/node_service_pb2.py b/exo/networking/grpc/node_service_pb2.py index 66e516c7b..62ffb1bb2 100644 --- a/exo/networking/grpc/node_service_pb2.py +++ b/exo/networking/grpc/node_service_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\x9d\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xb3\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 
\x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x07\n\x05\x45mpty2\xde\x03\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xc3\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x16\n\timage_str\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x05 \x01(\tH\x02\x88\x01\x01\x42\x0c\n\n_image_strB\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xb3\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 
\x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x07\n\x05\x45mpty2\xde\x03\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -28,35 +28,35 @@ _globals['_SHARD']._serialized_start=36 _globals['_SHARD']._serialized_end=119 _globals['_PROMPTREQUEST']._serialized_start=122 - _globals['_PROMPTREQUEST']._serialized_end=279 - _globals['_TENSORREQUEST']._serialized_start=282 - _globals['_TENSORREQUEST']._serialized_end=461 - _globals['_GETINFERENCERESULTREQUEST']._serialized_start=463 - _globals['_GETINFERENCERESULTREQUEST']._serialized_end=510 - _globals['_INFERENCERESULT']._serialized_start=512 - _globals['_INFERENCERESULT']._serialized_end=604 - _globals['_TENSOR']._serialized_start=606 - _globals['_TENSOR']._serialized_end=665 - _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=667 - _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=727 - _globals['_TOPOLOGY']._serialized_start=730 - _globals['_TOPOLOGY']._serialized_end=1000 - _globals['_TOPOLOGY_NODESENTRY']._serialized_start=851 - _globals['_TOPOLOGY_NODESENTRY']._serialized_end=929 - _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=931 - _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1000 - _globals['_PEERS']._serialized_start=1002 - _globals['_PEERS']._serialized_end=1027 - _globals['_DEVICEFLOPS']._serialized_start=1029 - _globals['_DEVICEFLOPS']._serialized_end=1084 - _globals['_DEVICECAPABILITIES']._serialized_start=1086 - _globals['_DEVICECAPABILITIES']._serialized_end=1193 - _globals['_SENDRESULTREQUEST']._serialized_start=1195 - _globals['_SENDRESULTREQUEST']._serialized_end=1271 - _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1273 - _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1334 - _globals['_EMPTY']._serialized_start=1336 - 
_globals['_EMPTY']._serialized_end=1343 - _globals['_NODESERVICE']._serialized_start=1346 - _globals['_NODESERVICE']._serialized_end=1824 + _globals['_PROMPTREQUEST']._serialized_end=317 + _globals['_TENSORREQUEST']._serialized_start=320 + _globals['_TENSORREQUEST']._serialized_end=499 + _globals['_GETINFERENCERESULTREQUEST']._serialized_start=501 + _globals['_GETINFERENCERESULTREQUEST']._serialized_end=548 + _globals['_INFERENCERESULT']._serialized_start=550 + _globals['_INFERENCERESULT']._serialized_end=642 + _globals['_TENSOR']._serialized_start=644 + _globals['_TENSOR']._serialized_end=703 + _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=705 + _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=765 + _globals['_TOPOLOGY']._serialized_start=768 + _globals['_TOPOLOGY']._serialized_end=1038 + _globals['_TOPOLOGY_NODESENTRY']._serialized_start=889 + _globals['_TOPOLOGY_NODESENTRY']._serialized_end=967 + _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=969 + _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1038 + _globals['_PEERS']._serialized_start=1040 + _globals['_PEERS']._serialized_end=1065 + _globals['_DEVICEFLOPS']._serialized_start=1067 + _globals['_DEVICEFLOPS']._serialized_end=1122 + _globals['_DEVICECAPABILITIES']._serialized_start=1124 + _globals['_DEVICECAPABILITIES']._serialized_end=1231 + _globals['_SENDRESULTREQUEST']._serialized_start=1233 + _globals['_SENDRESULTREQUEST']._serialized_end=1309 + _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1311 + _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1372 + _globals['_EMPTY']._serialized_start=1374 + _globals['_EMPTY']._serialized_end=1381 + _globals['_NODESERVICE']._serialized_start=1384 + _globals['_NODESERVICE']._serialized_end=1862 # @@protoc_insertion_point(module_scope) diff --git a/exo/networking/peer_handle.py b/exo/networking/peer_handle.py index 27ff42013..cf232d006 100644 --- a/exo/networking/peer_handle.py +++ b/exo/networking/peer_handle.py @@ -28,7 +28,7 @@ async def disconnect(self) -> None: pass @abstractmethod - async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: + async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: pass @abstractmethod diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py index 93b777d48..60b729748 100644 --- a/exo/orchestration/node.py +++ b/exo/orchestration/node.py @@ -16,7 +16,7 @@ async def stop(self) -> None: pass @abstractmethod - async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: + async def process_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: pass @abstractmethod diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py index 23e3e0323..55efdebc0 100644 --- a/exo/orchestration/standard_node.py +++ b/exo/orchestration/standard_node.py @@ -69,7 +69,7 @@ async def stop(self) -> None: await self.discovery.stop() await self.server.stop() - async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: + async def process_prompt(self, base_shard: Shard, prompt: str, 
image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: shard = self.get_current_shard(base_shard) asyncio.create_task( self.broadcast_opaque_status( @@ -82,6 +82,7 @@ async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optio "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "prompt": prompt, + "image_str": image_str, "inference_state": inference_state, "request_id": request_id, } @@ -89,7 +90,7 @@ async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optio ) ) start_time = time.perf_counter_ns() - resp = await self._process_prompt(base_shard, prompt, request_id, inference_state) + resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state) end_time = time.perf_counter_ns() elapsed_time_ns = end_time - start_time asyncio.create_task( @@ -103,6 +104,7 @@ async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optio "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "prompt": prompt, + "image_str": image_str, "inference_state": inference_state, "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, @@ -113,20 +115,20 @@ async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optio ) return resp - async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: + async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: if request_id is None: request_id = str(uuid.uuid4()) if request_id not in self.buffered_token_output: self.buffered_token_output[request_id] = ([], False) shard = self.get_current_shard(base_shard) - if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}") + if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}") if shard.start_layer != 0: - if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}") - await self.forward_to_next_shard(shard, prompt, request_id) + if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}") + await self.forward_to_next_shard(shard, prompt, request_id, image_str) return - result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state) + result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state) is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens if is_finished: self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True) @@ -234,6 +236,7 @@ async def forward_to_next_shard( base_shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str, + image_str: Optional[str] = None, inference_state: Optional[str] = None, ) -> None: if not self.partitioning_strategy: @@ -255,7 +258,7 @@ async def forward_to_next_shard( if isinstance(tensor_or_prompt, np.ndarray): await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state) else: - await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state) + await 
self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state) return target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None) @@ -267,7 +270,7 @@ async def forward_to_next_shard( if isinstance(tensor_or_prompt, np.ndarray): await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state) else: - await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state) + await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state) def get_current_shard(self, base_shard: Shard) -> Shard: partitions = self.partitioning_strategy.partition(self.topology) diff --git a/setup.py b/setup.py index 79fef2e83..53f84e195 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ "tiktoken==0.7.0", "tokenizers==0.19.1", "tqdm==4.66.4", - "transformers==4.41.2", + "transformers==4.43.3", "uuid==1.30", "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@a9f5a764dc640a5e5cbaaeeee21df7c8ca37da38", ] From 8d3d3df1dd902aca9be3da9a02f4d1e3ba9c4a6b Mon Sep 17 00:00:00 2001 From: Varshith Date: Sun, 28 Jul 2024 16:14:27 +0530 Subject: [PATCH 12/16] update readme --- README.md | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d4284e199..2054fd224 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ The native way to access models running on exo is using the exo library with pee exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:8000 -For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Example with curl: +For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Example with curls: ```sh curl http://localhost:8000/v1/chat/completions \ @@ -123,6 +123,30 @@ curl http://localhost:8000/v1/chat/completions \ }' ``` +```sh +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "llava-1.5-7b-hf", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are these?" + }, + { + "type": "image", + "image": "http://images.cocodataset.org/val2017/000000039769.jpg" + } + ] + } + ], + "temperature": 0.0 + }' +``` + ## Debugging Enable debug logs with the DEBUG environment variable (0-9). 
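The curl example added above translates directly into a plain Python client. The sketch below is illustrative only: it assumes an exo node is already serving the ChatGPT-compatible API on localhost:8000 and that the response follows the usual `choices[0].message.content` shape; the helper name is made up for the example.

```python
# Minimal sketch of the llava request above, standard library only.
# Assumes an exo node is serving the ChatGPT-compatible API on localhost:8000.
import json
import urllib.request

def ask_llava(image_url: str, question: str) -> str:
    payload = {
        "model": "llava-1.5-7b-hf",
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": question},
                {"type": "image", "image": image_url},
            ],
        }],
        "temperature": 0.0,
    }
    req = urllib.request.Request(
        "http://localhost:8000/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)
    # Assumes the standard ChatGPT-style response shape.
    return body["choices"][0]["message"]["content"]

if __name__ == "__main__":
    print(ask_llava("http://images.cocodataset.org/val2017/000000039769.jpg", "What are these?"))
```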
From 78db451d7eb40b12374ff16c7ebb9db0d1e333b6 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Tue, 30 Jul 2024 14:27:45 +0100 Subject: [PATCH 13/16] add pillow to main dependencies --- setup.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 76619596c..dca7f3f3f 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ "huggingface-hub==0.23.4", "Jinja2==3.1.4", "numpy==2.0.0", + "pillow==10.4.0", "prometheus-client==0.20.0", "protobuf==5.27.1", "psutil==6.0.0", @@ -42,9 +43,6 @@ "ruff==0.5.5", "mypy==1.11.0", ], - "testing": [ - "pillow==10.4.0" - ] } setup( From e68d06f4ef1780acb3c9e69160846ad8c9df1d76 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Tue, 30 Jul 2024 14:51:22 +0100 Subject: [PATCH 14/16] move model-selector styles to index.css --- tinychat/examples/tinychat/index.css | 19 +++++++++++++++++++ tinychat/examples/tinychat/index.html | 21 --------------------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/tinychat/examples/tinychat/index.css b/tinychat/examples/tinychat/index.css index 8f0908fb9..84b04b86e 100644 --- a/tinychat/examples/tinychat/index.css +++ b/tinychat/examples/tinychat/index.css @@ -291,3 +291,22 @@ p { .monospace { font-family: monospace; } + +.model-selector { + display: flex; + justify-content: center; + padding: 20px 0; +} +.model-selector select { + padding: 10px 20px; + font-size: 16px; + border: 1px solid #ccc; + border-radius: 5px; + background-color: #f8f8f8; + cursor: pointer; +} +.model-selector select:focus { + outline: none; + border-color: #007bff; + box-shadow: 0 0 0 2px rgba(0,123,255,.25); +} \ No newline at end of file diff --git a/tinychat/examples/tinychat/index.html b/tinychat/examples/tinychat/index.html index b95a26c12..e45c7f477 100644 --- a/tinychat/examples/tinychat/index.html +++ b/tinychat/examples/tinychat/index.html @@ -30,27 +30,6 @@ - - From 0d45a855fbb907c228ad894243260877b9fe18f6 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Tue, 30 Jul 2024 20:01:18 +0100 Subject: [PATCH 15/16] increase max request size to send raw images, make image download from url async, use chatgpt-compatible convention for images --- README.md | 6 ++- exo/api/chatgpt_api.py | 40 ++++++++++++++++--- exo/inference/mlx/sharded_inference_engine.py | 2 +- exo/inference/mlx/sharded_utils.py | 34 ++++++++++++---- 4 files changed, 66 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 2054fd224..639a80be0 100644 --- a/README.md +++ b/README.md @@ -137,8 +137,10 @@ curl http://localhost:8000/v1/chat/completions \ "text": "What are these?" 
}, { - "type": "image", - "image": "http://images.cocodataset.org/val2017/000000039769.jpg" + "type": "image_url", + "image_url": { + "url": "http://images.cocodataset.org/val2017/000000039769.jpg" + } } ] } diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index fff4a63da..413b7f46c 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -153,14 +153,45 @@ def generate_completion( return completion -def build_prompt(tokenizer, messages: List[Message]): - prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) +def remap_messages(messages: List[Message]) -> List[Message]: + remapped_messages = [] + last_image = None + for message in messages: + remapped_content = [] + for content in message.content: + if isinstance(content, dict): + if content.get("type") in ["image_url", "image"]: + image_url = content.get("image_url", {}).get("url") or content.get("image") + if image_url: + last_image = {"type": "image", "image": image_url} + remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"}) + else: + remapped_content.append(content) + else: + remapped_content.append({"type": "text", "text": content}) + remapped_messages.append(Message(role=message.role, content=remapped_content)) + + if last_image: + # Replace the last image placeholder with the actual image content + for message in reversed(remapped_messages): + for i, content in enumerate(message.content): + if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]": + message.content[i] = last_image + return remapped_messages + + return remapped_messages + +def build_prompt(tokenizer, _messages: List[Message]): + messages = remap_messages(_messages) + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) image_str = None for message in messages: if not isinstance(message.content, list): continue for content in message.content: + # note: we only support one image at a time right now. Multiple is possible. 
See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41 + # follows the convention in https://platform.openai.com/docs/guides/vision if content.get("type", None) == "image": image_str = content.get("image", None) break @@ -187,7 +218,7 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout self.node = node self.inference_engine_classname = inference_engine_classname self.response_timeout_secs = response_timeout_secs - self.app = web.Application() + self.app = web.Application(client_max_size=100 * 1024 * 1024) # 100MB to support image upload self.prev_token_lens: Dict[str, int] = {} self.stream_tasks: Dict[str, asyncio.Task] = {} cors = aiohttp_cors.setup(self.app) @@ -214,7 +245,6 @@ async def middleware(request): return middleware async def handle_root(self, request): - print(f"Handling root request from {request.remote}") return web.FileResponse(self.static_dir / "index.html") async def handle_post_chat_token_encode(self, request): @@ -279,7 +309,7 @@ async def stream_result(request_id: str, tokens: List[int], is_finished: bool): self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens)) new_tokens = tokens[prev_last_tokens_len:] finish_reason = None - eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id + eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None) if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id: new_tokens = new_tokens[:-1] if is_finished: diff --git a/exo/inference/mlx/sharded_inference_engine.py b/exo/inference/mlx/sharded_inference_engine.py index 9726123bf..7fa16aded 100644 --- a/exo/inference/mlx/sharded_inference_engine.py +++ b/exo/inference/mlx/sharded_inference_engine.py @@ -14,7 +14,7 @@ def __init__(self): async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): await self.ensure_shard(shard) if image_str: - image = get_image_from_str(image_str) + image = await get_image_from_str(image_str) inputs = self.tokenizer(prompt, image, return_tensors="np") pixel_values = mx.array(inputs["pixel_values"]) input_ids = mx.array(inputs["input_ids"]) diff --git a/exo/inference/mlx/sharded_utils.py b/exo/inference/mlx/sharded_utils.py index 296d552f9..d6f2e2c78 100644 --- a/exo/inference/mlx/sharded_utils.py +++ b/exo/inference/mlx/sharded_utils.py @@ -5,13 +5,16 @@ import json import logging import asyncio +import aiohttp from functools import partial from pathlib import Path from typing import Optional, Tuple import requests from PIL import Image from io import BytesIO +import base64 +from exo import DEBUG import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download @@ -232,11 +235,26 @@ async def load_shard( tokenizer = load_tokenizer(model_path, tokenizer_config) return model, tokenizer -def get_image_from_str(image_str: str): - if image_str.startswith("http"): - response = requests.get(image_str, timeout=10) - image = Image.open(BytesIO(response.content)).convert("RGB") - else: - imgdata = base64.b64decode(image_str) - image = Image.open(io.BytesIO(imgdata)) - return image +async def get_image_from_str(_image_str: str): + image_str = _image_str.strip() + + if 
image_str.startswith("http"): + async with aiohttp.ClientSession() as session: + async with session.get(image_str, timeout=10) as response: + content = await response.read() + return Image.open(BytesIO(content)).convert("RGB") + elif image_str.startswith("data:image/"): + # Extract the image format and base64 data + format_prefix, base64_data = image_str.split(";base64,") + image_format = format_prefix.split("/")[1].lower() + if DEBUG >= 2: print(f"{image_str=} {image_format=}") + imgdata = base64.b64decode(base64_data) + img = Image.open(BytesIO(imgdata)) + + # Convert to RGB if not already + if img.mode != "RGB": + img = img.convert("RGB") + + return img + else: + raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.") From af1c7ce327cf0f25fa597d68751b0157d1136287 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Tue, 30 Jul 2024 20:01:35 +0100 Subject: [PATCH 16/16] add support for image upload to tinychat for vision models --- tinychat/examples/tinychat/index.css | 70 ++++++++++++++++++++++++++ tinychat/examples/tinychat/index.html | 13 ++++- tinychat/examples/tinychat/index.js | 71 +++++++++++++++++++++++++-- 3 files changed, 150 insertions(+), 4 deletions(-) diff --git a/tinychat/examples/tinychat/index.css b/tinychat/examples/tinychat/index.css index 84b04b86e..b97546835 100644 --- a/tinychat/examples/tinychat/index.css +++ b/tinychat/examples/tinychat/index.css @@ -309,4 +309,74 @@ p { outline: none; border-color: #007bff; box-shadow: 0 0 0 2px rgba(0,123,255,.25); +} + +/* Image upload button styles */ +.image-input-button { + background-color: var(--secondary-color); + color: var(--foreground-color); + border: none; + border-radius: 50%; + width: 40px; + height: 40px; + font-size: 18px; + cursor: pointer; + transition: all 0.3s ease; + display: flex; + align-items: center; + justify-content: center; + margin-right: 10px; +} + +.image-input-button:hover { + background-color: var(--secondary-color-transparent); + transform: scale(1.1); +} + +.image-input-button:focus { + outline: none; + box-shadow: 0 0 0 3px rgba(var(--secondary-color-rgb), 0.5); +} + +.image-input-button i { + transition: all 0.3s ease; +} + +.image-input-button:hover i { + transform: scale(1.2); +} + +/* Hidden file input styles */ +#image-upload { + display: none; +} + +.image-preview-container { + position: relative; + display: inline-block; + margin-right: 10px; +} + +.image-preview { + max-width: 100px; + max-height: 100px; + object-fit: cover; + border-radius: 5px; +} + +.remove-image-button { + position: absolute; + top: -5px; + right: -5px; + background-color: rgba(255, 255, 255, 0.8); + border: none; + border-radius: 50%; + padding: 2px 5px; + cursor: pointer; +} + +.message > p > img { + max-width: 100%; + max-height: 100%; + object-fit: contain; } \ No newline at end of file diff --git a/tinychat/examples/tinychat/index.html b/tinychat/examples/tinychat/index.html index e45c7f477..ea79704a1 100644 --- a/tinychat/examples/tinychat/index.html +++ b/tinychat/examples/tinychat/index.html @@ -44,6 +44,7 @@ +
+ + +
+ Uploaded Image + +
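With the async `get_image_from_str` from patch 15, the server accepts either a plain HTTP(S) URL or a base64 `data:image/...` URI, which is what the tinychat image upload ultimately sends (hence the 100 MB `client_max_size`). The snippet below is a small test-style sketch of the data-URI branch; it assumes the repository's MLX dependencies are installed, and the image file name is illustrative.

```python
# Minimal sketch: exercise the data-URI branch of get_image_from_str
# (exo/inference/mlx/sharded_utils.py). The sample file name is illustrative.
import asyncio
import base64

from exo.inference.mlx.sharded_utils import get_image_from_str

async def main():
    with open("cats.jpg", "rb") as f:
        data_uri = "data:image/jpeg;base64," + base64.b64encode(f.read()).decode("utf-8")
    # Returns a PIL image converted to RGB, same as the HTTP-URL branch.
    image = await get_image_from_str(data_uri)
    print(image.size, image.mode)

asyncio.run(main())
```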