Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion backend/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import asyncio
from enum import Enum
from typing import Dict, List, Optional, TypedDict, Any

import io
from PIL import Image
import pytesseract
from langgraph.graph import StateGraph, END

from ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search
Expand Down Expand Up @@ -493,7 +495,49 @@ def reset_session(self, session_id: str):
self.chat_history.pop(session_id, None)
self.session_memory.pop(session_id, None)

async def extract_from_image(self, image_bytes: bytes, mime_type: str) -> str:
"""
Processes image bytes using Pytesseract and uses Gemini to
clean/format the results into a valid neuroscience query.
"""
try:
# 1. Convert bytes to an Image object
img = Image.open(io.BytesIO(image_bytes))

# 2. Run OCR in a background thread (to keep the app responsive)
raw_text = await asyncio.to_thread(pytesseract.image_to_string, img)

if not raw_text.strip():
return "The image appears to be empty or unreadable."

# 3. Use Gemini to "clean" the messy OCR text
# (OCR often results in typos or weird characters in scientific papers)
client = _get_genai_client()
clean_prompt = (
"Extract ONLY the scientific search terms from this OCR text. "
"Return the terms as a comma-separated list. "
"Do NOT include explanations, introductions, or extra text.\n\n"
f"OCR Text: {raw_text}"
)

cfg = genai_types.GenerateContentConfig(
temperature=0.1,
max_output_tokens=256
)

resp = client.models.generate_content(
model=FLASH_LITE_MODEL,
contents=[clean_prompt],
config=cfg
)
raw_output = (resp.text or raw_text).strip()
clean_text = raw_output.replace("**", "")
clean_text = "\n".join([line.lstrip("-* ").strip() for line in clean_text.splitlines()])
return clean_text

except Exception as e:
print(f"OCR Error: {e}")
return f"Error extracting text: {str(e)}"
async def handle_chat(self, session_id: str, query: str, reset: bool = False) -> str:
try:
if reset:
Expand Down
27 changes: 26 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from datetime import datetime

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
Expand Down Expand Up @@ -108,6 +108,31 @@ async def health():
"timestamp": datetime.utcnow().isoformat(),
}

@app.post("/api/ocr", tags=["Chat"])
async def ocr_endpoint(file: UploadFile = File(...)):
"""
Receives an image, extracts neuroscience-related text using Gemini,
and returns it to be used as a chat query.
"""
try:
# 1. Read the uploaded image bytes
image_bytes = await file.read()

# 2. Use the assistant to process the image
extracted_text = await assistant.extract_from_image(
image_bytes,
mime_type=file.content_type
)

return {"extracted_text": extracted_text}

except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to process image: {str(e)}"
)



@app.post("/api/chat", response_model=ChatResponse, tags=["Chat"])
async def chat_endpoint(msg: ChatMessage):
Expand Down
113 changes: 90 additions & 23 deletions frontend/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,48 @@ Try asking me something like:
});
};

const handleImageUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
const file = e.target.files?.[0];
if (!file) return;

// Show a temporary "reading" state
setIsLoading(true);

const formData = new FormData();
formData.append('file', file);

try {
const response = await fetch('/api/ocr', {
method: 'POST',
body: formData,
});

if (!response.ok) throw new Error('OCR failed');

const data = await response.json();

// Put the extracted text into the input box for the user
setInputValue(data.extracted_text);
} catch (error) {
console.error("Upload error:", error);
const errorMessage: Message = {
id: Date.now().toString(),
type: 'error',
content: 'Failed to extract text from the image. Please try a clearer screenshot.',
timestamp: new Date()
};
setMessages(prev => [...prev, errorMessage]);
} finally {
setIsLoading(false);
// Reset the file input so the same file can be uploaded again if needed
e.target.value = '';
}
};





const sendMessage = async () => {
if (!inputValue.trim() || isLoading) return;

Expand Down Expand Up @@ -188,29 +230,54 @@ Try asking me something like:
{/* Input Area */}
<footer className="input-section">
<div className="input-container">
<div className="input-wrapper">
<input
type="text"
className="message-input"
placeholder="Ask about neuroscience datasets, brain imaging data, or research topics..."
value={inputValue}
onChange={(e) => setInputValue(e.target.value)}
onKeyPress={handleKeyPress}
disabled={isLoading}
/>
<button
className={`send-button ${isLoading || !inputValue.trim() ? 'disabled' : ''}`}
type="button"
onClick={sendMessage}
disabled={isLoading || !inputValue.trim()}
>
{isLoading ? (
<i className="fas fa-spinner fa-spin"></i>
) : (
<i className="fas fa-paper-plane"></i>
)}
</button>
</div>
<div className="input-wrapper" style={{ alignItems: 'flex-end' }}>
<input
type="file" id="image-upload" accept="image/*" hidden
onChange={handleImageUpload} disabled={isLoading}
/>

<label htmlFor="image-upload" className={`action-btn upload-btn ${isLoading ? 'disabled' : ''}`}>
<i className="fas fa-paperclip"></i>
</label>

{/* Dynamic Textarea */}
<textarea
className="message-input"
placeholder="Type or upload an image..."
value={inputValue}
rows={1}
onChange={(e) => {
setInputValue(e.target.value);
// Reset height to calculate correctly
e.target.style.height = 'inherit';
// Set new height based on scrollHeight, capped at 150px
e.target.style.height = `${Math.min(e.target.scrollHeight, 150)}px`;
}}
onKeyDown={(e) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendMessage();
// Reset height after sending
(e.target as HTMLTextAreaElement).style.height = 'inherit';
}
}}
style={{
resize: 'none',
overflowY: inputValue.split('\n').length > 5 ? 'auto' : 'hidden',
minHeight: '44px',
maxHeight: '150px'
}}
disabled={isLoading}
/>

<button
className={`send-button ${isLoading || !inputValue.trim() ? 'disabled' : ''}`}
onClick={sendMessage}
disabled={isLoading || !inputValue.trim()}
>
{isLoading ? <i className="fas fa-spinner fa-spin"></i> : <i className="fas fa-paper-plane"></i>}
</button>
</div>
<div className="input-footer">
<i className="fas fa-info-circle"></i>
<span>Powered by INCF KnowledgeSpace API - Neuroscience datasets</span>
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ dependencies = [
"langgraph>=0.6.4",
"matplotlib>=3.10.3",
"pandas>=2.3.1",
"pillow>=12.1.0",
"pytesseract>=0.3.13",
"python-multipart>=0.0.21",
"requests>=2.32.4",
"scikit-learn>=1.7.0",
"sentence-transformers>=3.0.0",
Expand Down