feat: Add support for stopping prompt response generation (#1484)
* Implement trash-can button that stops a response mid-generation (flag handshake sketched below)

* Prevent generation entirely when the delete request arrives before any tokens are generated

* Remove unused imports

* Standardize Redis clients to use decode_responses=False

* Refactor deletePrompt to use consts

* Fix code formatting

* Change delete_prompt to delete_or_stop_prompt with POST method

* Refactor to use JSON

* Update chat.py

---------

Co-authored-by: Juan Calderon-Perez <[email protected]>
HubertYGuan and gaby authored Oct 10, 2024
1 parent 5c8d77a commit 3d407be
Showing 2 changed files with 44 additions and 11 deletions.
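The mechanism behind these changes is a two-flag handshake in Redis: the delete endpoint sets `stop_generation:{chat_id}` (with a 10-second expiry) to request cancellation, and the streaming side sets `has_generated:{chat_id}` once the first token is out, so the endpoint can distinguish stopping an in-flight response from preventing one that has not started. A minimal sketch of the endpoint side, using the key names and TTL from the diff below; the standalone function is illustrative, not the actual route handler:

```python
from redis import Redis

# decode_responses=False (the library default) means get() returns bytes or None
client = Redis(host="localhost", port=6379, decode_responses=False)

def request_stop(chat_id: str) -> str:
    """What delete_or_stop_prompt does when idx points past the stored history."""
    # Ask any in-flight generator to stop; the key self-expires after 10s
    # so a missed signal cannot block future prompts.
    client.set(f"stop_generation:{chat_id}", "1", ex=10)
    if client.get(f"has_generated:{chat_id}"):
        # Tokens are already streaming: the generation loop will see the
        # flag on its next iteration and break out.
        client.delete(f"has_generated:{chat_id}")
        return "Stopping response generation"
    # No tokens yet: stream_ask_a_question checks the flag before starting
    # the generator and closes the stream immediately.
    return "Preventing response generation"
```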
api/src/serge/routers/chat.py: 29 additions & 7 deletions
@@ -157,16 +157,26 @@ async def get_chat_history(chat_id: str, u: User = Depends(get_current_active_user)):
     return messages_to_dict(history.messages)


-@chat_router.delete("/{chat_id}/prompt")
-async def delete_prompt(chat_id: str, idx: int, u: User = Depends(get_current_active_user)):
+@chat_router.post("/{chat_id}/prompt")
+async def delete_or_stop_prompt(chat_id: str, idx: int, u: User = Depends(get_current_active_user)):
     if idx < 0:
         raise ValueError("Index cannot be negative")

     if chat_id not in [x.chat_id for x in u.chats]:
         raise unauth_error

     history = RedisChatMessageHistory(chat_id)
+    client = Redis(host="localhost", port=6379, decode_responses=False)

     if idx >= len(history.messages):
-        logger.error("Unable to delete message, chat in progress")
-        raise HTTPException(status_code=202, detail="Unable to delete message, chat in progress")
+        client.set(f"stop_generation:{chat_id}", "1", ex=10)
+        if client.get(f"has_generated:{chat_id}"):
+            client.delete(f"has_generated:{chat_id}")
+            logger.info("Stopping response generation")
+            return {"message": "Stopping response generation"}
+        else:
+            logger.info("Preventing response generation")
+            return {"message": "Preventing response generation"}

     messages = history.messages.copy()[:idx]
     history.clear()
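For reference, a sketch of how a client would exercise the new route. The host, port, and `/api/chat` prefix are assumptions (the prefix is taken from the Svelte fetch call further down), and authentication is assumed to live in the session's cookies:

```python
import requests

def delete_or_stop(session: requests.Session, chat_id: str, idx: int) -> str:
    # Hypothetical base URL; the frontend proxies the API under /api/chat.
    resp = session.post(
        f"http://localhost:8008/api/chat/{chat_id}/prompt",
        params={"idx": idx},
    )
    resp.raise_for_status()
    # A 200 now covers three outcomes: message deleted, generation stopped,
    # or generation prevented; the JSON "message" field disambiguates.
    # The fallback string mirrors the frontend's default toast, not a
    # confirmed backend payload.
    return resp.json().get("message", "Response deleted successfully")
```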
@@ -231,7 +241,7 @@ async def stream_ask_a_question(chat_id: str, prompt: str, u: User = Depends(get_current_active_user)):

     logger.debug("creating Llama client")
     try:
-        client = Llama(
+        llama_client = Llama(
             model_path=f"/usr/src/app/weights/{chat.params.model_path}.bin",
             n_ctx=len(chat.params.init_prompt) + chat.params.n_ctx,
             n_gpu_layers=chat.params.n_gpu_layers,
@@ -244,11 +254,16 @@ async def stream_ask_a_question(chat_id: str, prompt: str, u: User = Depends(get_current_active_user)):
         history.append(SystemMessage(content=error))
         return {"event": "error"}

+    # Following logic triggers if deleting before any tokens are generated
+    if client.get(f"stop_generation:{chat_id}"):
+        client.delete(f"stop_generation:{chat_id}")
+        return {"event": "close"}
+
     def event_generator():
         full_answer = ""
         error = None
         try:
-            for output in client(
+            for output in llama_client(
                 prompt,
                 stream=True,
                 temperature=chat.params.temperature,
@@ -257,6 +272,12 @@ def event_generator():
                 repeat_penalty=chat.params.repeat_penalty,
                 max_tokens=chat.params.max_tokens,
             ):
+                if client.get(f"stop_generation:{chat_id}"):
+                    logger.info("Generation stopped by user")
+                    client.delete(f"stop_generation:{chat_id}")
+                    break
+                elif not client.get(f"has_generated:{chat_id}"):
+                    client.set(f"has_generated:{chat_id}", "1")
                 txt = output["choices"][0]["text"]
                 full_answer += txt
                 yield {"event": "message", "data": txt}
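One detail behind the flag checks above, per the "Standardize Redis clients to use decode_responses=False" bullet: with `decode_responses=False`, redis-py returns raw bytes, so the truthiness tests work because `b"1"` is truthy and a missing key comes back as `None`. A quick illustration:

```python
from redis import Redis

client = Redis(host="localhost", port=6379, decode_responses=False)
client.set("stop_generation:demo", "1", ex=10)

value = client.get("stop_generation:demo")
assert value == b"1"                      # bytes, not "1", with decode_responses=False
assert bool(value)                        # truthy, so `if client.get(...)` fires
assert client.get("no_such_key") is None  # absent key is falsy: no stop requested
```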
@@ -269,9 +290,10 @@ def event_generator():
             logger.error(error)
             yield ({"event": "error"})
         finally:
+            client.delete(f"has_generated:{chat_id}")
             if error:
                 history.append(SystemMessage(content=error))
-            else:
+            elif full_answer:
                 logger.info(full_answer)
                 ai_message = AIMessage(content=full_answer)
                 history.append(message=ai_message)
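Taken together, these hunks give `has_generated:{chat_id}` a clear lifecycle: set on the first token, consumed by the stop request, and unconditionally cleared in `finally` so the next prompt starts clean, while `elif full_answer:` keeps an empty `AIMessage` out of the history when generation is stopped before any token arrives. A condensed sketch of that lifecycle; the helper name and token source are illustrative, not the real `event_generator`:

```python
from redis import Redis

client = Redis(host="localhost", port=6379, decode_responses=False)

def consume_tokens(chat_id: str, tokens) -> str:
    """Condensed view of event_generator's flag handling; not the real code."""
    full_answer = ""
    try:
        for txt in tokens:
            if client.get(f"stop_generation:{chat_id}"):
                client.delete(f"stop_generation:{chat_id}")  # consume the stop request
                break
            if not client.get(f"has_generated:{chat_id}"):
                client.set(f"has_generated:{chat_id}", "1")  # mark: first token is out
            full_answer += txt
    finally:
        # Always reset the marker so the next prompt starts from a clean slate.
        client.delete(f"has_generated:{chat_id}")
    # The caller persists only a non-empty answer (the `elif full_answer:` above),
    # so a stop before the first token leaves the history untouched.
    return full_answer
```

For example, `consume_tokens("demo", iter(["Hello", " world"]))` returns "Hello world" unless a stop flag appears mid-stream.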
web/src/routes/chat/[id]/+page.svelte: 15 additions & 4 deletions
@@ -144,16 +144,27 @@
     await goto("/chat/" + newData);
   }

+  const STOPPING_RESPONSE = "Stopping response generation";
+  const PREVENTING_RESPONSE = "Preventing response generation";
   async function deletePrompt(chatID: string, idx: number) {
     const response = await fetch(
       `/api/chat/${chatID}/prompt?idx=${idx.toString()}`,
-      { method: "DELETE" },
+      { method: "POST" },
     );
     if (response.status === 200) {
+      const responseData = await response.json();
+      switch (responseData.message) {
+        case STOPPING_RESPONSE:
+          showToast(STOPPING_RESPONSE);
+          return;
+        case PREVENTING_RESPONSE:
+          showToast(PREVENTING_RESPONSE);
+          break;
+        default:
+          showToast("Response deleted successfully");
+      }
       await invalidate("/api/chat/" + $page.params.id);
-    } else if (response.status === 202) {
-      showToast("Chat in progress!");
     } else if (response.status === 401) {
       window.location.href = "/";
     } else {
