class PromptSession:
  """Record of one prompt sent through the API.

  Stored in ChatGPTAPI.prompts (a PrefixDict keyed by prompt text) so a
  later request whose prompt extends an earlier one can reuse that
  request's id and skip re-sending the shared prefix.
  """

  def __init__(self, request_id: str, timestamp: int, prompt: str):
    # Id of the originating request; reused by later prefix-matching requests.
    self.request_id = request_id
    # Unix time (whole seconds) when the session was recorded.
    self.timestamp = timestamp
    # The full prompt text as built for the model.
    self.prompt = prompt
Supported models for this engine: {supported_models}"}, status=400, ) - request_id = str(uuid.uuid4()) tokenizer = await resolve_tokenizer(shard.model_id) if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}") prompt, image_str = build_prompt(tokenizer, chat_request.messages) + request_id = None + match = self.prompts.find_longest_prefix(prompt) + if match: + if DEBUG >= 2: + print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}") + request_id = match[1].request_id + self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt)) + # remove the matching prefix from the prompt + prompt = prompt[len(match[1].prompt):] + else: + request_id = str(uuid.uuid4()) + self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt)) + callback_id = f"chatgpt-api-wait-response-{request_id}" callback = self.node.on_token.register(callback_id) diff --git a/exo/helpers.py b/exo/helpers.py index 713e40717..45b5d95c4 100644 --- a/exo/helpers.py +++ b/exo/helpers.py @@ -1,6 +1,7 @@ import os import asyncio -from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple +from typing import Any, Callable, TypeVar, Optional, Dict, Generic, Tuple, List +from collections import defaultdict import socket import random import platform @@ -97,8 +98,6 @@ def terminal_link(uri, label=None): T = TypeVar("T") K = TypeVar("K") - - class AsyncCallback(Generic[T]): def __init__(self) -> None: self.condition: asyncio.Condition = asyncio.Condition() @@ -147,3 +146,23 @@ def trigger(self, name: K, *args: T) -> None: def trigger_all(self, *args: T) -> None: for callback in self.callbacks.values(): callback.set(*args) + + +K = TypeVar('K', bound=str) +V = TypeVar('V') +class PrefixDict(Generic[K, V]): + def __init__(self): + self.items: Dict[K, V] = {} + + def add(self, key: K, value: V) -> None: + self.items[key] = value + + def 
def find_prefix(self, argument: str):
  """Return every (key, value) pair whose key is a prefix of *argument*.

  Linear scan over all stored items; order follows insertion order of
  self.items. An empty-string key matches any argument.
  """
  matches = []
  for key, value in self.items.items():
    if argument.startswith(key):
      matches.append((key, value))
  return matches

def find_longest_prefix(self, argument: str):
  """Return the (key, value) pair with the longest key that prefixes
  *argument*, or None when no stored key is a prefix of it.
  """
  candidates = self.find_prefix(argument)
  if not candidates:
    return None
  # Longest key wins; ties resolve to the first such candidate max() sees.
  return max(candidates, key=lambda pair: len(pair[0]))