-
Notifications
You must be signed in to change notification settings - Fork 235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
调用api.py流式代码谁能分享一下啊 #186
Comments
有人做过吗,咋调用百川的流式API接口啊 `>>> import torch
` 或者这个能改成流式的代码吗 |
# coding=utf-8
# Implements API for ChatGLM2-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:8000/docs for documents.
import time
import torch
import uvicorn
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
import json
import logging
from transformers.generation.utils import GenerationConfig
@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
yield
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
logging.basicConfig(level=logging.WARN, filemode='a',filename='./log_stream_answer.log',format='%(asctime)s - %(levelname)s: %(message)s',datefmt='%Y-%m-%d %H:%M:%S')
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
def torch_gc():
if torch.cuda.is_available():
with torch.cuda.device(CUDA_DEVICE):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system"]
content: str
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = None
top_p: Optional[float] = None
max_length: Optional[int] = None
stream: Optional[bool] = False
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length"]
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]]
class ChatCompletionResponse(BaseModel):
model: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.get("/v1/models", response_model=ModelList)
async def list_models():
global model_args
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer
if request.messages[-1].role != "user":
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system":
query = prev_messages.pop(0).content + query
messages_input = []
for msg_obj in request.messages:
messages_input.append({"role": msg_obj.role, "content": msg_obj.content})
logging.warning("---------------------------------------------------")
logging.warning(messages_input)
history = []
if request.stream:
#logging.info("query: "+str(query))
generate = predict(query, history, request.model,messages_input)
torch_gc()
return EventSourceResponse(generate, media_type="text/event-stream")
#response, _ = model.chat(tokenizer, query, history=history)
response = model.chat(tokenizer, messages_input, stream=False)
logging.warning(response)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response),
finish_reason="stop"
)
torch_gc()
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
async def predict(query: str, history: List[List[str]], model_id: str,messages: List[dict]):
global model, tokenizer
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
#yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield json.dumps(chunk.dict(exclude_unset=True), ensure_ascii=False)
current_length = 0
#for new_response, _ in model.stream_chat(tokenizer, query, history):
for new_response in model.chat(tokenizer, messages, stream=True):
if len(new_response) == current_length:
continue
new_text = new_response[current_length:]
current_length = len(new_response)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=new_text),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
#yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield json.dumps(chunk.dict(exclude_unset=True), ensure_ascii=False)
logging.warning(new_response)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
#yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
yield json.dumps(chunk.dict(exclude_unset=True), ensure_ascii=False)
yield '[DONE]'
if __name__ == "__main__":
model_path = '/home/checkpoints/Baichuan2-13B-Chat-4bits'
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
model.generation_config = GenerationConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(
model_path,
use_fast=False,
trust_remote_code=True
)
uvicorn.run(app, host='0.0.0.0', port=7842, workers=1) |
如题,流式代码接口,感谢!!! 百川13B模型
The text was updated successfully, but these errors were encountered: