main.py
import json
import os
import time
from typing import Any, Dict, List

from huggingface_hub import login
from pydantic import BaseModel
from vllm import AsyncLLMEngine, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs

# Authenticate with the Hugging Face Hub so the gated Llama weights can be downloaded.
login(token=os.environ.get("HF_TOKEN"))

engine_args = AsyncEngineArgs(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    gpu_memory_utilization=0.9,  # fraction of GPU memory vLLM is allowed to use
    max_model_len=8192,  # cap the context length so the KV cache fits in memory
)
engine = AsyncLLMEngine.from_engine_args(engine_args)


class Message(BaseModel):
    role: str
    content: str


class ChatCompletionResponse(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: List[Dict[str, Any]]


async def run(
    messages: List[Message],
    model: str,
    run_id: str,
    stream: bool = True,
    temperature: float = 0.8,
    top_p: float = 0.95,
):
    # Flatten the chat history into a single prompt string.
    prompt = " ".join(f"{msg.role}: {msg.content}" for msg in messages)
    sampling_params = SamplingParams(temperature=temperature, top_p=top_p)
    results_generator = engine.generate(prompt, sampling_params, run_id)

    previous_text = ""
    full_text = ""  # Collect all generated text here
    async for output in results_generator:
        completion = output.outputs[0]
        new_text = completion.text[len(previous_text):]
        previous_text = completion.text
        full_text += new_text  # Append new text to full_text
        response = ChatCompletionResponse(
            id=run_id,
            object="chat.completion",
            created=int(time.time()),
            model=model,
            choices=[
                {
                    "text": new_text,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": completion.finish_reason or "stop",
                }
            ],
        )
        yield json.dumps(response.model_dump())
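

# --- Usage sketch (illustrative, not part of the deployment) ---
# A minimal, hypothetical way to consume the `run` generator locally with asyncio.
# The example messages and the "local-test" run_id are assumptions made for
# demonstration; in production the serving platform would call `run` directly.
async def _demo() -> None:
    messages = [
        Message(role="system", content="You are a helpful assistant."),
        Message(role="user", content="Explain vLLM in one sentence."),
    ]
    async for chunk in run(
        messages,
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        run_id="local-test",
    ):
        piece = json.loads(chunk)
        print(piece["choices"][0]["text"], end="", flush=True)


if __name__ == "__main__":
    import asyncio

    asyncio.run(_demo())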