Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions backend/neo4j_search_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
import logging
from dataclasses import dataclass
from typing import List, Dict, Any, Optional
from neo4j import GraphDatabase

# Module logger for the graph retriever. A stream handler is attached only
# when the logger has none of its own, so re-importing the module does not
# stack duplicate handlers.
logger = logging.getLogger("graph_retrieval")
logger.setLevel(logging.INFO)
if not logger.handlers:
    _h = logging.StreamHandler()
    _h.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
    )
    logger.addHandler(_h)

@dataclass
class GraphItem:
    """A single ontology hit, shaped like RetrievedItem in retrieval.py."""

    # Value returned by the query under the "name" column.
    name: str
    # Value returned by the query under the "def" column.
    definition: str
    # Outgoing links rendered as "<relationship type> -> <related label>".
    relationships: List[str]
    # Human-readable name of the knowledge source the item came from.
    source: str

class GraphRetriever:
    """
    Neo4j-backed retriever for NIFSTD ontology terms.

    Environment variables (with the defaults used when unset):
    - NEO4J_URI      ("bolt://localhost:7687")
    - NEO4J_USER     ("neo4j")
    - NEO4J_PASSWORD ("password")

    Construction never raises on connection problems: the failure is logged,
    ``driver`` is left as ``None`` and every search returns an empty list.
    """

    def __init__(self):
        self.uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
        self.user = os.getenv("NEO4J_USER", "neo4j")
        self.password = os.getenv("NEO4J_PASSWORD", "password")
        self.driver = None

        driver = None
        try:
            driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
            driver.verify_connectivity()
        except Exception as e:
            logger.error("Failed to connect to Neo4j: %s", e)
            if driver is not None:
                # The driver object was created but the server is unreachable:
                # release it instead of abandoning it unclosed.
                try:
                    driver.close()
                except Exception as close_err:
                    logger.error("Failed to close Neo4j driver: %s", close_err)
            return

        self.driver = driver
        logger.info("Connected to Neo4j at %s", self.uri)

    def close(self):
        """Close the driver if one is open. Safe to call more than once."""
        # Detach first so the instance never keeps a reference to a closed
        # driver, even if close() itself raises.
        driver, self.driver = self.driver, None
        if driver is not None:
            driver.close()

    def search_ontology(self, term: str) -> List[GraphItem]:
        """
        Finds a term in the graph and returns its definition + parent/child links.

        Args:
            term: Text matched case-insensitively as a substring of node labels.

        Returns:
            Up to five GraphItem objects. An empty list is returned when there
            is no connection, when ``term`` is blank or not a string, or when
            the query fails before yielding any record (the error is logged).
        """
        if not self.driver:
            return []
        # A blank term would make `CONTAINS ""` match every labelled node and
        # return five arbitrary hits; treat it as "nothing to search for".
        if not isinstance(term, str) or not term.strip():
            return []

        query = """
        MATCH (n)
        WHERE toLower(n.label) CONTAINS toLower($term)
        OPTIONAL MATCH (n)-[r]->(related)
        RETURN n.label as name, n.definition as def, collect(type(r) + " -> " + related.label) as rels
        LIMIT 5
        """

        results = []
        try:
            with self.driver.session() as session:
                for record in session.run(query, term=term):
                    results.append(GraphItem(
                        name=record["name"],
                        # The "def" column is always present, so a node without
                        # a definition arrives as None and a .get() default
                        # would never apply; fall back explicitly instead.
                        definition=record.get("def") or "No definition found",
                        relationships=record["rels"],
                        source="NIFSTD Ontology",
                    ))
        except Exception as e:
            # Best-effort lookup: log and return whatever was collected so far.
            logger.error("Graph search failed: %s", e)

        return results

if __name__ == "__main__":
    # Manual smoke test: look up one term and print what came back.
    gr = GraphRetriever()
    try:
        for item in gr.search_ontology("hippocampus"):
            print(f"Found: {item.name} | Rels: {item.relationships}")
    finally:
        # Release the connection even if searching or printing fails.
        gr.close()
1 change: 0 additions & 1 deletion backend/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

logger = logging.getLogger("retrieval")
logger.setLevel(logging.INFO)

if not logger.handlers:
_h = logging.StreamHandler()
_h.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ dependencies = [
"langchain>=0.3.27",
"langgraph>=0.6.4",
"matplotlib>=3.10.3",
"neo4j>=6.0.3",
"pandas>=2.3.1",
"pytest>=9.0.2",
"requests>=2.32.4",
"scikit-learn>=1.7.0",
"sentence-transformers>=3.0.0",
Expand Down
Empty file added tests/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions tests/test_neo4j_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
from unittest.mock import MagicMock, patch
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from backend.neo4j_search_tool import GraphRetriever

@pytest.fixture
def mock_driver():
    """Replace GraphDatabase.driver with a stand-in so no real DB is needed."""
    fake_driver = MagicMock()
    with patch(
        "backend.neo4j_search_tool.GraphDatabase.driver",
        return_value=fake_driver,
    ):
        yield fake_driver

def test_initialization(mock_driver):
    """Verifies the tool loads environment variables correctly.

    The mock_driver fixture is requested so that constructing the retriever
    does not attempt a real network connection to the fake URI.
    """
    with patch.dict(os.environ, {"NEO4J_URI": "bolt://test:7687"}):
        tool = GraphRetriever()
    assert tool.uri == "bolt://test:7687"
    assert tool.driver is mock_driver

def test_search_ontology(mock_driver):
    """Checks that a returned record is converted into a GraphItem."""
    # One fake result row; the record stand-in answers both row[key] and
    # row.get(key) from this dict.
    row = {"name": "Hippocampus", "def": "Brain region", "rels": []}
    fake_record = MagicMock()
    fake_record.__getitem__.side_effect = row.__getitem__
    fake_record.get.side_effect = row.get

    # Wire the record into the session produced by `with driver.session()`.
    fake_session = MagicMock()
    fake_session.run.return_value = [fake_record]
    mock_driver.session.return_value.__enter__.return_value = fake_session

    hits = GraphRetriever().search_ontology("Hippocampus")

    first = hits[0]
    assert first.name == "Hippocampus"
    assert first.definition == "Brain region"
    fake_session.run.assert_called_once()