diff --git a/backend/neo4j_search_tool.py b/backend/neo4j_search_tool.py
new file mode 100644
index 0000000..2929d63
--- /dev/null
+++ b/backend/neo4j_search_tool.py
@@ -0,0 +1,120 @@
+import os
+import logging
+from dataclasses import dataclass
+from typing import List, Dict, Any, Optional
+from neo4j import GraphDatabase
+
+logger = logging.getLogger("graph_retrieval")
+logger.setLevel(logging.INFO)
+if not logger.handlers:
+    _h = logging.StreamHandler()
+    _h.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
+    logger.addHandler(_h)
+
+# Placeholder used when a matched node has no `definition` property.
+_NO_DEFINITION = "No definition found"
+
+@dataclass
+class GraphItem:
+    """Standardized output similar to RetrievedItem in retrieval.py"""
+    name: str
+    definition: str
+    relationships: List[str]
+    source: str
+
+class GraphRetriever:
+    """
+    Neo4j-backed retriever for NIFSTD ontology terms.
+
+    Environment variables (with the defaults used when unset):
+    - NEO4J_URI       bolt://localhost:7687
+    - NEO4J_USER      neo4j
+    - NEO4J_PASSWORD  password
+
+    If the connection cannot be verified, the error is logged, ``driver``
+    stays ``None`` and ``search_ontology`` returns an empty list.
+    """
+    def __init__(self):
+        self.uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
+        self.user = os.getenv("NEO4J_USER", "neo4j")
+        self.password = os.getenv("NEO4J_PASSWORD", "password")
+        self.driver = None
+
+        driver = None
+        try:
+            driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
+            driver.verify_connectivity()
+        except Exception as e:
+            logger.error("Failed to connect to Neo4j: %s", e)
+            # The driver object may exist even though connectivity failed;
+            # close it so it is not leaked.
+            if driver is not None:
+                try:
+                    driver.close()
+                except Exception as close_err:
+                    logger.error("Failed to close Neo4j driver: %s", close_err)
+            return
+
+        self.driver = driver
+        logger.info("Connected to Neo4j at %s", self.uri)
+
+    def close(self):
+        """Close the driver if one is open. Safe to call more than once."""
+        if self.driver:
+            self.driver.close()
+            self.driver = None
+
+    def search_ontology(self, term: str) -> List[GraphItem]:
+        """
+        Find nodes whose ``label`` contains ``term`` (case-insensitive).
+
+        Returns at most 5 items, each carrying the node's definition and
+        its outgoing relationships formatted as "TYPE -> label". Returns
+        an empty list when there is no driver or ``term`` is blank. If the
+        query fails, the error is logged and the items collected so far
+        are returned.
+        """
+        if not self.driver:
+            return []
+
+        term = (term or "").strip()
+        if not term:
+            # CONTAINS "" is true for every node that has a label, so a
+            # blank term would return up to 5 arbitrary nodes.
+            return []
+
+        query = """
+        MATCH (n)
+        WHERE toLower(n.label) CONTAINS toLower($term)
+        OPTIONAL MATCH (n)-[r]->(related)
+        RETURN n.label as name, n.definition as def, collect(type(r) + " -> " + related.label) as rels
+        LIMIT 5
+        """
+
+        results = []
+        try:
+            with self.driver.session() as session:
+                records = session.run(query, term=term)
+                for record in records:
+                    results.append(GraphItem(
+                        name=record["name"],
+                        # The "def" column is always present but is null when
+                        # the node has no definition, so .get()'s default
+                        # would never be used.
+                        definition=record["def"] or _NO_DEFINITION,
+                        relationships=record["rels"],
+                        source="NIFSTD Ontology"
+                    ))
+        except Exception as e:
+            logger.error("Graph search failed: %s", e)
+
+        return results
+
+if __name__ == "__main__":
+    gr = GraphRetriever()
+    try:
+        items = gr.search_ontology("hippocampus")
+        for item in items:
+            print(f"Found: {item.name} | Rels: {item.relationships}")
+    finally:
+        gr.close()
diff --git a/backend/retrieval.py b/backend/retrieval.py
index 866d6c9..3dbd2f1 100644
--- a/backend/retrieval.py
+++ b/backend/retrieval.py
@@ -11,7 +11,6 @@ logger = logging.getLogger("retrieval")
 logger.setLevel(logging.INFO)
 
 
-
 if not logger.handlers:
     _h = logging.StreamHandler()
     _h.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
diff --git a/pyproject.toml b/pyproject.toml
index 67d020c..8eca26d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,9 @@ dependencies = [
     "langchain>=0.3.27",
     "langgraph>=0.6.4",
     "matplotlib>=3.10.3",
+    "neo4j>=6.0.3",
     "pandas>=2.3.1",
+    "pytest>=9.0.2",
     "requests>=2.32.4",
     "scikit-learn>=1.7.0",
     "sentence-transformers>=3.0.0",
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_neo4j_tool.py b/tests/test_neo4j_tool.py
new file mode 100644
index 0000000..1bf2b79
--- /dev/null
+++ b/tests/test_neo4j_tool.py
@@ -0,0 +1,62 @@
+import pytest
+from unittest.mock import MagicMock, patch
+import os
+import sys
+
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+from backend.neo4j_search_tool import GraphRetriever
+
+@pytest.fixture
+def mock_driver():
+    """Fakes the Neo4j driver so we don't need a real DB."""
+    with patch("backend.neo4j_search_tool.GraphDatabase.driver") as mock_dt:
+        mock_instance = MagicMock()
+        mock_dt.return_value = mock_instance
+        yield mock_instance
+
+def _stub_rows(driver, rows):
+    """Makes driver.session() yield a session whose run() returns rows."""
+    session = MagicMock()
+    driver.session.return_value.__enter__.return_value = session
+    session.run.return_value = rows
+    return session
+
+def test_initialization(mock_driver):
+    """Verifies the tool loads environment variables correctly."""
+    with patch.dict(os.environ, {"NEO4J_URI": "bolt://test:7687"}):
+        tool = GraphRetriever()
+    assert tool.uri == "bolt://test:7687"
+    assert tool.driver is mock_driver
+
+def test_connection_failure_closes_driver(mock_driver):
+    """A driver that fails verification is closed and discarded."""
+    mock_driver.verify_connectivity.side_effect = RuntimeError("down")
+    tool = GraphRetriever()
+    assert tool.driver is None
+    mock_driver.close.assert_called_once()
+    assert tool.search_ontology("Hippocampus") == []
+
+def test_search_ontology(mock_driver):
+    """Verifies the search method returns GraphItem objects."""
+    row = {"name": "Hippocampus", "def": "Brain region", "rels": []}
+    mock_session = _stub_rows(mock_driver, [row])
+
+    tool = GraphRetriever()
+    results = tool.search_ontology("Hippocampus")
+    assert results[0].name == "Hippocampus"
+    assert results[0].definition == "Brain region"
+    mock_session.run.assert_called_once()
+
+def test_search_ontology_missing_definition(mock_driver):
+    """A null definition column falls back to the placeholder text."""
+    row = {"name": "Hippocampus", "def": None, "rels": []}
+    _stub_rows(mock_driver, [row])
+
+    results = GraphRetriever().search_ontology("Hippocampus")
+    assert results[0].definition == "No definition found"
+
+def test_search_ontology_blank_term(mock_driver):
+    """A blank term short-circuits without querying the database."""
+    tool = GraphRetriever()
+    assert tool.search_ontology("   ") == []
+    mock_driver.session.assert_not_called()