Skip to content

Commit

Permalink
feat: 优化 Chat 的载入及为其创建一个简易 server (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
TogetsuDo authored Oct 30, 2024
1 parent ca88e29 commit f119a35
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 18 deletions.
9 changes: 9 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,18 @@ COMMAND_START=["","/"]
#请参考 https://zhuanlan.zhihu.com/p/618011122 配置 strategy
#CHAT_STRATEGY=cuda fp16

# 是否使用本地 api
#chat_use_local_server = False

#chat api超时时间,机子差可以设长一点
#chat_server_timeout = 15

#chat api重试次数
#chat_server_retry = 3

# tts 功能相关配置

# 声码器,可选值:pwgan_aishell3、wavernn_csmsc
#TTS_VOCODER=pwgan_aishell3


12 changes: 10 additions & 2 deletions docs/AIDeployment.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,17 @@ AI 功能均对设备硬件要求较高,且配置操作更加复杂一些。
python -m pip install tokenizers rwkv
```

3. `src/plugins/chat/prompt.py` 里的起手咒语 `INIT_PROMPT` 有兴趣可以试着改改
3. (可选)在 `.env` 里配置是否启用 chat server,由独立进程加载聊天模型。默认不启用,由 Pallas-Bot 直接加载聊天模型

```bash
python src/plugins/chat/server.py
```

`src/plugins/chat/server.py`中的端口可以自行修改,默认为 5000,保证与 `src/plugins/chat/__init__.py` 中一致即可。也可以自行部署 gunicorn 等生产服务器。

4. `src/plugins/chat/prompt.py` 里的起手咒语 `INIT_PROMPT` 有兴趣可以试着改改

4. `src/plugins/chat/model.py` 里的 `STRATEGY` 可以按上游仓库的 [说明](https://github.com/BlinkDL/ChatRWKV/tree/main#%E4%B8%AD%E6%96%87%E6%A8%A1%E5%9E%8B) 改改,能省点显存啥的
5. `src/plugins/chat/model.py` 里的 `STRATEGY` 可以按上游仓库的 [说明](https://github.com/BlinkDL/ChatRWKV/tree/main#%E4%B8%AD%E6%96%87%E6%A8%A1%E5%9E%8B) 改改,能省点显存啥的

## 酒后语音说话(TTS)

Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ pydantic~=1.10.0
pymongo~=4.3.3
jieba~=0.42.1
pypinyin~=0.49.0

# chat
httpx~=0.27.0
6 changes: 6 additions & 0 deletions src/common/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ class PluginConfig(BaseModel, extra=Extra.ignore):
tts_vocoder: str = 'pwgan_aishell3'
# chat 模型的strategy
chat_strategy: str = ''
# chat 是否使用本地api
chat_use_local_server: bool = False
# chat api超时时间
chat_server_timeout: int = 15
# chat api重试次数
chat_server_retry: int = 3


try:
Expand Down
73 changes: 60 additions & 13 deletions src/plugins/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
from asyncer import asyncify
from nonebot.adapters.onebot.v11 import MessageSegment, permission, GroupMessageEvent
from nonebot.adapters import Bot, Event
from nonebot.rule import Rule
from nonebot.typing import T_State
from nonebot import on_message, logger
import httpx

from src.common.config import BotConfig, GroupConfig, plugin_config

Expand All @@ -21,34 +23,70 @@
raise error

TTS_MIN_LENGTH = 10
CHAT_API_URL = 'http://127.0.0.1:5000/chat'
USE_API = plugin_config.chat_use_local_server
TIMEOUT = plugin_config.chat_server_timeout
MAX_RETRIES = plugin_config.chat_server_retry
RETRY_BACKOFF_FACTOR = 1  # base delay (seconds) between retries; doubled each attempt

# Shared async HTTP client; the transport-level `retries` covers connection
# failures, while make_api_request() below retries HTTP-level errors.
client = httpx.AsyncClient(
    timeout=httpx.Timeout(timeout=TIMEOUT),
    transport=httpx.AsyncHTTPTransport(retries=MAX_RETRIES)
)

try:
chat = Chat(plugin_config.chat_strategy)
except Exception as error:
logger.error('Chat model init error: ', error)
raise error

if USE_API:
    # The model is served by a separate process (src/plugins/chat/server.py);
    # nothing to load locally. The original wrapped this assignment in a
    # try/except that could never fire.
    chat = None
else:
    try:
        chat = Chat(plugin_config.chat_strategy)
    except Exception as error:
        # Embed the error in the message: passing it as a positional argument
        # to logger.error() left it out of the log line entirely.
        logger.error(f'Chat model init error: {error}')
        # Bare raise preserves the original traceback.
        raise

@BotConfig.handle_sober_up
def on_sober_up(bot_id, group_id, drunkenness) -> None:
    """Clear the bot's chat session for a group once it sobers up.

    This callback is synchronous, so it must not use the module-level
    AsyncClient: the original called `client.delete(...)` here, which only
    created an un-awaited coroutine and then crashed on `raise_for_status()`.
    """
    session = f'{bot_id}_{group_id}'
    logger.info(f'bot [{bot_id}] sober up in group [{group_id}], clear session [{session}]')
    if USE_API:
        # The del_session route is mounted at the server root, not under
        # /chat — derive the base URL from CHAT_API_URL.
        del_url = CHAT_API_URL.rsplit('/', 1)[0] + '/del_session'
        try:
            # One-shot synchronous request; safe without a running event loop.
            response = httpx.delete(del_url, params={'session': session}, timeout=TIMEOUT)
            response.raise_for_status()
        except httpx.HTTPError as error:
            # Best-effort cleanup: log and continue rather than propagate.
            logger.error(f'Failed to delete session [{session}]: {error}')
    elif chat is not None:
        chat.del_session(session)

def is_drunk(bot: Bot, event: Event, state: T_State) -> int:
    """Rule predicate: non-zero while this bot is 'drunk' in the event's group."""
    return BotConfig(event.self_id, event.group_id).drunkenness()


# Group-message responder that only fires while the bot is drunk (see
# is_drunk); block=True stops lower-priority handlers from also replying.
drunk_msg = on_message(
    rule=Rule(is_drunk),
    priority=13,
    block=True,
    permission=permission.GROUP,
)

async def make_api_request(url, method, json_data=None, params=None):
    """POST or DELETE to the chat server with retries and exponential backoff.

    Args:
        url: target endpoint.
        method: 'POST' (sends *json_data* as the JSON body) or 'DELETE'
            (sends *params* as query parameters).
        json_data: JSON payload for POST requests.
        params: query parameters for DELETE requests.

    Returns:
        The httpx.Response on success, or None once all attempts fail.

    Raises:
        ValueError: for an unsupported HTTP method. (Previously an unknown
            method left `response` unbound and crashed with UnboundLocalError.)
    """
    if method not in ('POST', 'DELETE'):
        raise ValueError(f'Unsupported HTTP method: {method}')
    for attempt in range(MAX_RETRIES + 1):
        try:
            if method == 'POST':
                response = await client.post(url, json=json_data)
            else:
                response = await client.delete(url, params=params)
            response.raise_for_status()
            return response
        except httpx.HTTPError as error:
            logger.error(f'Request failed (attempt {attempt + 1}): {error}')
            if attempt < MAX_RETRIES:
                # Exponential backoff: 1s, 2s, 4s, ...
                await asyncio.sleep(RETRY_BACKOFF_FACTOR * (2 ** attempt))
    return None

@drunk_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
Expand All @@ -71,7 +109,16 @@ async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
text = text[:50]
if not text:
return
ans = await asyncify(chat.chat)(session, text)

if USE_API:
response = await make_api_request(CHAT_API_URL, 'POST', json_data={'session': session, 'text': text, 'token_count': 50})
if response:
ans = response.json().get('response', '')
else:
return
else:
ans = await asyncify(chat.chat)(session, text)

logger.info(f'session [{session}]: {text} -> {ans}')

if TTS_AVAIABLE and len(ans) >= TTS_MIN_LENGTH:
Expand All @@ -80,4 +127,4 @@ async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
await drunk_msg.send(voice)

config.reset_cooldown(cd_key)
await drunk_msg.finish(ans)
await drunk_msg.finish(ans)
23 changes: 20 additions & 3 deletions src/plugins/chat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
from threading import Lock
from copy import deepcopy
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
import os
import time
import torch
import threading

cuda = torch.cuda.is_available()
os.environ['RWKV_JIT_ON'] = '1'
Expand Down Expand Up @@ -34,6 +37,17 @@ def __init__(self, strategy=DEFAULT_STRATEGY, model_dir=DEFAULT_MODEL_DIR) -> No
raise Exception(f'Chat model not found in {self.MODEL_DIR}')
if not self.TOKEN_PATH.exists():
raise Exception(f'Chat token not found in {self.TOKEN_PATH}')

self.pipeline = None
self.args = None
self.all_state = defaultdict(lambda: None)
self.all_occurrence = {}
self.chat_locker = Lock()
self.executor = ThreadPoolExecutor(max_workers=10)

threading.Thread(target=self._load_model).start()

def _load_model(self):
model = RWKV(model=str(self.MODEL_PATH), strategy=self.STRATEGY)
self.pipeline = PIPELINE(model, str(self.TOKEN_PATH))
self.args = PIPELINE_ARGS(
Expand All @@ -49,11 +63,14 @@ def __init__(self, strategy=DEFAULT_STRATEGY, model_dir=DEFAULT_MODEL_DIR) -> No
INIT_STATE = deepcopy(self.pipeline.generate(
INIT_PROMPT, token_count=200, args=self.args)[1])
self.all_state = defaultdict(lambda: deepcopy(INIT_STATE))
self.all_occurrence = {}

self.chat_locker = Lock()

def chat(self, session: str, text: str, token_count: int = 50) -> str:
    """Generate a reply for *session*, blocking until the model is ready.

    The model is loaded by a background thread started in __init__ (see
    _load_model), so this busy-waits until the pipeline appears.
    NOTE(review): if _load_model raises, self.pipeline stays None and this
    spins forever — consider a threading.Event with a timeout.
    """
    while self.pipeline is None:
        time.sleep(0.1)
    # Hand generation to the pool; _chat holds the lock, so requests are
    # still serialized — the pool mainly decouples callers from each other.
    future = self.executor.submit(self._chat, session, text, token_count)
    return future.result()

def _chat(self, session: str, text: str, token_count: int = 50) -> str:
with self.chat_locker:
state = self.all_state[session]
ctx = CHAT_FORMAT.format(text)
Expand Down
102 changes: 102 additions & 0 deletions src/plugins/chat/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from flask import Flask, request, jsonify
from pathlib import Path
from threading import Lock
from copy import deepcopy
from collections import defaultdict
import os
import torch


app = Flask(__name__)

cuda = torch.cuda.is_available()
os.environ['RWKV_JIT_ON'] = '1'
# 这个要配个 ninja 啥的环境,能大幅提高推理速度,有需要可以自己弄下(仅支持 cuda 显卡)
os.environ["RWKV_CUDA_ON"] = '0'

from rwkv.model import RWKV
import prompt
import pipeline

DEFAULT_STRATEGY = 'cuda fp16' if cuda else 'cpu fp32'
# server.py lives at src/plugins/chat/, so four .parent hops reach the repo root.
API_DIR = Path(__file__).resolve().parent.parent.parent.parent
DEFAULT_MODEL_DIR = API_DIR / 'resource' / 'chat' / 'models'
# Startup diagnostics: show which directory is searched and what is in it.
print(f"DEFAULT_MODEL_DIR: {DEFAULT_MODEL_DIR}")
print("Files in directory:")
# Guard the listing: iterdir() raises FileNotFoundError on a missing
# directory, which would mask the clearer "model not found" error raised
# later by Chat.__init__.
if DEFAULT_MODEL_DIR.is_dir():
    for f in DEFAULT_MODEL_DIR.iterdir():
        print(f)
class Chat:
    """RWKV chat-model wrapper for the standalone Flask server.

    NOTE(review): this largely duplicates src/plugins/chat/model.py's Chat
    (minus the background loading) — consider sharing one implementation.
    """

    def __init__(self, strategy=DEFAULT_STRATEGY, model_dir=DEFAULT_MODEL_DIR) -> None:
        """Locate the model and tokenizer under *model_dir* and load the pipeline.

        Raises:
            Exception: if no .pth model file or the tokenizer file is found.
        """
        self.STRATEGY = strategy if strategy else DEFAULT_STRATEGY
        self.MODEL_DIR = model_dir
        self.MODEL_EXT = '.pth'
        self.MODEL_PATH = None
        self.TOKEN_PATH = self.MODEL_DIR / '20B_tokenizer.json'
        # Take the first *.pth file found; RWKV expects the path without extension.
        for f in self.MODEL_DIR.glob('*'):
            if f.suffix != self.MODEL_EXT:
                continue
            self.MODEL_PATH = f.with_suffix('')
            break
        if not self.MODEL_PATH:
            raise Exception(f'Chat model not found in {self.MODEL_DIR}')
        if not self.TOKEN_PATH.exists():
            raise Exception(f'Chat token not found in {self.TOKEN_PATH}')
        model = RWKV(model=str(self.MODEL_PATH), strategy=self.STRATEGY)
        self.pipeline = pipeline.PIPELINE(model, str(self.TOKEN_PATH))
        self.args = pipeline.PIPELINE_ARGS(
            temperature=1.0,
            top_p=0.7,
            alpha_frequency=0.25,
            alpha_presence=0.25,
            token_ban=[0],  # ban the generation of some tokens
            token_stop=[],  # stop generation whenever you see any token here
            ends=('\n'),
            ends_if_too_long=("。", "!", "?", "\n"))

        # Prime the model with the initial prompt once; every new session
        # starts from a deep copy of this state.
        INIT_STATE = deepcopy(self.pipeline.generate(
            prompt.INIT_PROMPT, token_count=200, args=self.args)[1])
        self.all_state = defaultdict(lambda: deepcopy(INIT_STATE))
        self.all_occurrence = {}

        # Serializes chat/del_session across Flask's threaded request handling;
        # generation reads and writes the shared per-session state above.
        self.chat_locker = Lock()

    def chat(self, session: str, text: str, token_count: int = 50) -> str:
        """Generate a reply for *session* and persist its updated state."""
        with self.chat_locker:
            state = self.all_state[session]
            ctx = prompt.CHAT_FORMAT.format(text)
            occurrence = self.all_occurrence.get(session, {})

            out, state, occurrence = self.pipeline.generate(
                ctx, token_count=token_count, args=self.args, state=state, occurrence=occurrence)

            self.all_state[session] = deepcopy(state)
            self.all_occurrence[session] = occurrence
            return out.strip()

    def del_session(self, session: str):
        """Drop all cached state for *session*; no-op if the session is unknown."""
        with self.chat_locker:
            if session in self.all_state:
                del self.all_state[session]
            if session in self.all_occurrence:
                del self.all_occurrence[session]

# NOTE(review): strategy is hard-coded to 'cpu fp32' here — the CHAT_STRATEGY
# setting from .env is not consulted by the standalone server; confirm intended.
chat_instance = Chat('cpu fp32')

@app.route('/chat', methods=['POST'])
def chat():
    """Generate a chat reply.

    Expects a JSON body: {"session": str, "text": str, "token_count": int};
    all keys optional. Returns {"response": str}.
    """
    # get_json(silent=True) returns None instead of aborting with 415/400
    # when the body is missing or not valid JSON.
    data = request.get_json(silent=True) or {}
    session = data.get('session', 'main')
    text = data.get('text', '')
    token_count = data.get('token_count', 50)
    response = chat_instance.chat(session, text, token_count)
    return jsonify({'response': response})

@app.route('/del_session', methods=['DELETE'])
def del_session():
    """Delete a session's cached state.

    The client (src/plugins/chat/__init__.py) sends the session as a query
    parameter, not a JSON body — the original read request.json, which aborts
    on a body-less DELETE. Read the query string first, then fall back to an
    optional JSON body for compatibility.
    """
    data = request.get_json(silent=True) or {}
    session = request.args.get('session') or data.get('session', 'main')
    chat_instance.del_session(session)
    return jsonify({'status': 'success'})

if __name__ == "__main__":
    # Flask's dev server; for production deploy behind gunicorn or similar
    # (see docs/AIDeployment.md). Keep the port in sync with CHAT_API_URL.
    app.run(host='0.0.0.0', port=5000)

0 comments on commit f119a35

Please sign in to comment.