63 changes: 56 additions & 7 deletions backend/main.py
@@ -2,15 +2,16 @@
import os
import time
import asyncio
import json
from typing import Optional, Dict, Any
from datetime import datetime

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
import uvicorn

from agents import NeuroscienceAssistant

@@ -54,9 +55,8 @@ class ChatResponse(BaseModel):
# Lightweight health helpers

def _vector_check_sync() -> bool:
    try:
        from retrieval import Retriever  # local import to avoid import penalty on startup
        r = Retriever()
        return bool(getattr(r, "is_enabled", False))
    except Exception:
@@ -109,6 +109,57 @@ async def health():
}


@app.get("/api/chat/stream", tags=["Chat"])
async def chat_stream_endpoint(query: str, session_id: str = "default"):
    """
    Stream chat responses using Server-Sent Events (SSE).
    Tokens appear in real time as the LLM generates them.

    This provides a better user experience by showing responses
    as they are generated, similar to ChatGPT.
    """
    async def generate():
        try:
            # Send initial connection message
            yield f"data: {json.dumps({'type': 'start', 'message': 'Connected'})}\n\n"

            # Get the full response from the assistant
            response_text = await assistant.handle_chat(
                session_id=session_id,
                query=query,
                reset=False,
            )

            # Stream the response in chunks (word-by-word simulation)
            words = response_text.split(' ')
            chunk_size = 3  # send 3 words at a time for smooth streaming

            for i in range(0, len(words), chunk_size):
                chunk = ' '.join(words[i:i + chunk_size])
                if i > 0:
                    chunk = ' ' + chunk  # restore the space consumed by split()
                yield f"data: {json.dumps({'type': 'token', 'content': chunk})}\n\n"
                await asyncio.sleep(0.03)  # 30 ms delay for a smooth streaming effect

            # Send completion message
            yield f"data: {json.dumps({'type': 'done', 'message': 'Complete'})}\n\n"

        except asyncio.TimeoutError:
            yield f"data: {json.dumps({'type': 'error', 'message': 'Request timed out. Please try a simpler query.'})}\n\n"
        except Exception as e:
            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # tell nginx not to buffer the event stream
        },
    )
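For reference, each `yield` above reaches the client as one SSE event: a `data: ` line carrying a JSON payload, followed by a blank line. A typical exchange would look like this (the token text here is a made-up example, not output from this code):

data: {"type": "start", "message": "Connected"}

data: {"type": "token", "content": "The hippocampus is"}

data: {"type": "token", "content": " central to memory"}

data: {"type": "done", "message": "Complete"}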


@app.post("/api/chat", response_model=ChatResponse, tags=["Chat"])
async def chat_endpoint(msg: ChatMessage):
try:
@@ -138,8 +189,6 @@ async def chat_endpoint(msg: ChatMessage):
        )


@app.post("/api/session/reset", tags=["Chat"])
async def reset_session(payload: Dict[str, str]):
    sid = (payload or {}).get("session_id") or "default"
@@ -154,7 +203,7 @@ async def reset_session(payload: Dict[str, str]):
"main:app",
host=os.getenv("HOST", "0.0.0.0"),
port=int(os.getenv("PORT", "8000")),
reload=True,
reload=True,
log_level="info",
proxy_headers=True,
)
)
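A minimal way to exercise the new streaming endpoint from Python, as a sketch: it assumes the server is running on localhost:8000 and uses `httpx` as the streaming HTTP client (the library choice and the sample query are assumptions, not part of this PR).

import json
import httpx  # assumed client library; any streaming HTTP client works

def consume_stream(query: str, session_id: str = "default") -> None:
    """Print tokens from /api/chat/stream as they arrive."""
    params = {"query": query, "session_id": session_id}
    with httpx.stream("GET", "http://localhost:8000/api/chat/stream",
                      params=params, timeout=None) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue  # skip the blank separator lines between events
            event = json.loads(line[len("data: "):])
            if event["type"] == "token":
                print(event["content"], end="", flush=True)
            elif event["type"] == "done":
                print()
                break
            elif event["type"] == "error":
                raise RuntimeError(event["message"])

if __name__ == "__main__":
    consume_stream("What does the hippocampus do?")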