diff --git a/backend/main.py b/backend/main.py
index 9021161..3f12431 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -2,15 +2,16 @@
 import os
 import time
 import asyncio
+import json
 from typing import Optional, Dict, Any
 from datetime import datetime
 
 from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
 import uvicorn
-import json
 
 from agents import NeuroscienceAssistant
 
@@ -54,9 +55,9 @@ class ChatResponse(BaseModel):
 
 
 # Lightweight health helpers
 def _vector_check_sync() -> bool:
-    try:
-        from retrieval import Retriever  # local import to avoid import penalty on startup
+    from retrieval import Retriever
+    try:
         r = Retriever()
         return bool(getattr(r, "is_enabled", False))
     except Exception:
@@ -109,6 +110,57 @@ async def health():
     }
 
 
+@app.get("/api/chat/stream", tags=["Chat"])
+async def chat_stream_endpoint(query: str, session_id: str = "default"):
+    """
+    Stream chat responses using Server-Sent Events (SSE).
+
+    The full reply is generated first, then flushed to the client in
+    small word chunks to simulate token-by-token streaming, giving a
+    ChatGPT-style progressive rendering without true LLM streaming.
+    """
+    async def generate():
+        try:
+            # Send initial connection message
+            yield f"data: {json.dumps({'type': 'start', 'message': 'Connected'})}\n\n"
+
+            # Get the full response from the assistant
+            response_text = await assistant.handle_chat(
+                session_id=session_id,
+                query=query,
+                reset=False,
+            )
+
+            # Stream the response in chunks (word-by-word simulation)
+            words = response_text.split(' ')
+            chunk_size = 3  # send 3 words at a time for smooth streaming
+
+            for i in range(0, len(words), chunk_size):
+                chunk = ' '.join(words[i:i + chunk_size])
+                if i > 0:
+                    chunk = ' ' + chunk  # restore the space dropped between chunks
+                yield f"data: {json.dumps({'type': 'token', 'content': chunk})}\n\n"
+                await asyncio.sleep(0.03)  # 30 ms delay for streaming effect
+
+            # Send completion message
+            yield f"data: {json.dumps({'type': 'done', 'message': 'Complete'})}\n\n"
+
+        except asyncio.TimeoutError:
+            yield f"data: {json.dumps({'type': 'error', 'message': 'Request timed out. Please try a simpler query.'})}\n\n"
+        except Exception as e:
+            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
+
+    return StreamingResponse(
+        generate(),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "X-Accel-Buffering": "no",  # disable nginx proxy buffering
+        },
+    )
+
+
 @app.post("/api/chat", response_model=ChatResponse, tags=["Chat"])
 async def chat_endpoint(msg: ChatMessage):
     try:
@@ -138,8 +190,6 @@ async def chat_endpoint(msg: ChatMessage):
         )
 
 
-
-
 @app.post("/api/session/reset", tags=["Chat"])
 async def reset_session(payload: Dict[str, str]):
     sid = (payload or {}).get("session_id") or "default"
@@ -154,7 +204,9 @@
 if __name__ == "__main__":
     uvicorn.run(
         "main:app",
         host=os.getenv("HOST", "0.0.0.0"),
         port=int(os.getenv("PORT", "8000")),
-        reload=True,
+        reload=True,
+        log_level="info",
+        proxy_headers=True,
     )
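
Note: below is a minimal client sketch for consuming the new /api/chat/stream endpoint, not part of this diff. It assumes the API is reachable on localhost:8000 and that the third-party `requests` package is installed; the helper name `stream_chat` and all values are illustrative.

import json

import requests


def stream_chat(query: str, session_id: str = "default") -> str:
    """Read the SSE feed, print tokens as they arrive, return the full reply."""
    url = "http://localhost:8000/api/chat/stream"
    params = {"query": query, "session_id": session_id}
    parts = []
    # stream=True keeps the connection open; timeout is (connect, read) seconds
    with requests.get(url, params=params, stream=True, timeout=(5, 300)) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip the blank separator lines between SSE events
            event = json.loads(line[len("data: "):])
            if event["type"] == "token":
                print(event["content"], end="", flush=True)
                parts.append(event["content"])
            elif event["type"] == "error":
                raise RuntimeError(event["message"])
            elif event["type"] == "done":
                break  # server signalled completion
    return "".join(parts)


if __name__ == "__main__":
    stream_chat("What does the hippocampus do?")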
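
For a quick smoke test without any client code, curl can dump the raw event frames; -N disables curl's output buffering so each "data:" line appears as the server sends it (URL and query value are illustrative):

curl -N "http://localhost:8000/api/chat/stream?query=hello"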