From f0cedf8f555230128acb4abb3445e492e66c20ae Mon Sep 17 00:00:00 2001 From: Charan6924 Date: Sat, 3 Jan 2026 15:37:00 -0800 Subject: [PATCH 1/3] feature : Added graph retriever demo with NIFSTD seeding code --- backend/comparison.py | 43 +++++++++++++++++++ backend/graph_retrieval.py | 84 ++++++++++++++++++++++++++++++++++++++ backend/retrieval.py | 1 - backend/seed_graph.py | 37 +++++++++++++++++ pyproject.toml | 1 + 5 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 backend/comparison.py create mode 100644 backend/graph_retrieval.py create mode 100644 backend/seed_graph.py diff --git a/backend/comparison.py b/backend/comparison.py new file mode 100644 index 0000000..121496e --- /dev/null +++ b/backend/comparison.py @@ -0,0 +1,43 @@ +from graph_retrieval import GraphRetriever + +def mock_vector_search(query): + """ + Simulates what their current system returns. + Vector search looks for keywords. If I search 'Ataxia', + it finds the disease but misses the brain region connection if not explicitly stated. + """ + if "ataxia" in query.lower(): + return [ + "Ataxia is a degenerative disease of the nervous system.", + "Symptoms include lack of voluntary coordination of muscle movements." + ] + return [] + +def main(): + + query = "What brain region does Ataxia affect?" + print(f"User Query: '{query}'\n") + + # (Vector Only) + print(f"Vector Search Results:") + vector_results = mock_vector_search(query) + for res in vector_results: + print(f" - {res}") + + # (Graph Retrieval) + print(f"Graph Context Retrieval:") + gr = GraphRetriever() + results = gr.search_ontology("Ataxia") + + if not results: + print("No graph results found. (Did you run seed_graph.py?)") + else: + for item in results: + print(f" - Entity: {item.name}") + print(f" - Definition: {item.definition}") + print(f" - Relationships: {item.relationships}") + + gr.close() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/backend/graph_retrieval.py b/backend/graph_retrieval.py new file mode 100644 index 0000000..2929d63 --- /dev/null +++ b/backend/graph_retrieval.py @@ -0,0 +1,84 @@ +import os +import logging +from dataclasses import dataclass +from typing import List, Dict, Any, Optional +from neo4j import GraphDatabase + +logger = logging.getLogger("graph_retrieval") +logger.setLevel(logging.INFO) +if not logger.handlers: + _h = logging.StreamHandler() + _h.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")) + logger.addHandler(_h) + +@dataclass +class GraphItem: + """Standardized output similar to RetrievedItem in retrieval.py""" + name: str + definition: str + relationships: List[str] + source: str + +class GraphRetriever: + """ + Neo4j-backed retriever for NIFSTD ontology terms. + + Environment variables: + - NEO4J_URI + - NEO4J_USER + - NEO4J_PASSWORD + """ + def __init__(self): + self.uri = os.getenv("NEO4J_URI", "bolt://localhost:7687") + self.user = os.getenv("NEO4J_USER", "neo4j") + self.password = os.getenv("NEO4J_PASSWORD", "password") + + try: + self.driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password)) + self.driver.verify_connectivity() + logger.info(f"Connected to Neo4j at {self.uri}") + except Exception as e: + logger.error(f"Failed to connect to Neo4j: {e}") + self.driver = None + + def close(self): + if self.driver: + self.driver.close() + + def search_ontology(self, term: str) -> List[GraphItem]: + """ + Finds a term in the graph and returns its definition + parent/child links. + """ + if not self.driver: + return [] + + query = """ + MATCH (n) + WHERE toLower(n.label) CONTAINS toLower($term) + OPTIONAL MATCH (n)-[r]->(related) + RETURN n.label as name, n.definition as def, collect(type(r) + " -> " + related.label) as rels + LIMIT 5 + """ + + results = [] + try: + with self.driver.session() as session: + records = session.run(query, term=term) + for record in records: + results.append(GraphItem( + name=record["name"], + definition=record.get("def", "No definition found"), + relationships=record["rels"], + source="NIFSTD Ontology" + )) + except Exception as e: + logger.error(f"Graph search failed: {e}") + + return results + +if __name__ == "__main__": + gr = GraphRetriever() + items = gr.search_ontology("hippocampus") + for item in items: + print(f"Found: {item.name} | Rels: {item.relationships}") + gr.close() \ No newline at end of file diff --git a/backend/retrieval.py b/backend/retrieval.py index 866d6c9..3dbd2f1 100644 --- a/backend/retrieval.py +++ b/backend/retrieval.py @@ -11,7 +11,6 @@ logger = logging.getLogger("retrieval") logger.setLevel(logging.INFO) - if not logger.handlers: _h = logging.StreamHandler() _h.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")) diff --git a/backend/seed_graph.py b/backend/seed_graph.py new file mode 100644 index 0000000..b6f54a1 --- /dev/null +++ b/backend/seed_graph.py @@ -0,0 +1,37 @@ +import os +from neo4j import GraphDatabase + +URI = "bolt://localhost:7687" +AUTH = ("neo4j", "password") +def seed_data(): + driver = GraphDatabase.driver(URI, auth=AUTH) + + with driver.session() as session: + print("1. Clearing old data...") + session.run("MATCH (n) DETACH DELETE n") + + print("2. Seeding new NIFSTD ontology data...") + query = """ + CREATE (brain:NIFSTD_Term {label: "Brain", id: "UBERON:0000955", definition: "The central organ of the nervous system."}) + CREATE (hind:NIFSTD_Term {label: "Hindbrain", id: "UBERON:0002028", definition: "The posterior part of the brain."}) + CREATE (cere:NIFSTD_Term {label: "Cerebellum", id: "UBERON:0002037", definition: "Region of the brain that plays an important role in motor control."}) + + // Create Relationships + CREATE (cere)-[:PART_OF]->(hind) + CREATE (hind)-[:PART_OF]->(brain) + + // Create a Disease that links to the specific region + CREATE (atak:Disease {label: "Ataxia", definition: "A degenerative disease of the nervous system."}) + CREATE (atak)-[:AFFECTS]->(cere) + + RETURN count(brain) as nodes_created + """ + + result = session.run(query) + record = result.single() + print(f"Database seeded! Nodes created: {record['nodes_created'] if record else 0}") + + driver.close() + +if __name__ == "__main__": + seed_data() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 67d020c..fc99fff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "langchain>=0.3.27", "langgraph>=0.6.4", "matplotlib>=3.10.3", + "neo4j>=6.0.3", "pandas>=2.3.1", "requests>=2.32.4", "scikit-learn>=1.7.0", From b81b3f9c69796c6e28c514f23ff8cf73540d4502 Mon Sep 17 00:00:00 2001 From: Charan6924 Date: Sun, 4 Jan 2026 18:38:08 -0800 Subject: [PATCH 2/3] added unit tests --- tests/__init__.py | 0 tests/test_neo4j_tool.py | 45 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/test_neo4j_tool.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_neo4j_tool.py b/tests/test_neo4j_tool.py new file mode 100644 index 0000000..0174d1b --- /dev/null +++ b/tests/test_neo4j_tool.py @@ -0,0 +1,45 @@ +import pytest +from unittest.mock import MagicMock, patch +import os +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + +from backend.neo4j_search_tool import GraphRetriever + +@pytest.fixture +def mock_driver(): + """Fakes the Neo4j driver so we don't need a real DB.""" + with patch("backend.neo4j_search_tool.GraphDatabase.driver") as mock_dt: + mock_instance = MagicMock() + mock_dt.return_value = mock_instance + yield mock_instance + +def test_initialization(): + """Verifies the tool loads environment variables correctly.""" + with patch.dict(os.environ, {"NEO4J_URI": "bolt://test:7687"}): + tool = GraphRetriever() + assert tool.uri == "bolt://test:7687" + +def test_search_execution(mock_driver): + """Verifies the search method sends the correct Cypher query.""" + # 1. Setup the fake return data + mock_session = MagicMock() + mock_driver.session.return_value.__enter__.return_value = mock_session + + # Fake a Neo4j record + mock_record = MagicMock() + mock_record.data.return_value = {"name": "Hippocampus", "definition": "Brain region"} + mock_session.run.return_value = [mock_record] + + # 2. Run the tool + tool = GraphRetriever() + tool.connect() + results = tool.search("Hippocampus") + + # 3. Verify it worked + assert len(results) == 1 + assert results[0]["name"] == "Hippocampus" + + # Verify the code actually called the DB + mock_session.run.assert_called_once() \ No newline at end of file From 63ce1f14bb63874455bbb69dbea0a7dafaefa696 Mon Sep 17 00:00:00 2001 From: Charan6924 Date: Sun, 4 Jan 2026 18:46:29 -0800 Subject: [PATCH 3/3] added unit tests --- backend/comparison.py | 43 ------------------- ...raph_retrieval.py => neo4j_search_tool.py} | 0 backend/seed_graph.py | 37 ---------------- pyproject.toml | 1 + tests/test_neo4j_tool.py | 25 ++++------- 5 files changed, 10 insertions(+), 96 deletions(-) delete mode 100644 backend/comparison.py rename backend/{graph_retrieval.py => neo4j_search_tool.py} (100%) delete mode 100644 backend/seed_graph.py diff --git a/backend/comparison.py b/backend/comparison.py deleted file mode 100644 index 121496e..0000000 --- a/backend/comparison.py +++ /dev/null @@ -1,43 +0,0 @@ -from graph_retrieval import GraphRetriever - -def mock_vector_search(query): - """ - Simulates what their current system returns. - Vector search looks for keywords. If I search 'Ataxia', - it finds the disease but misses the brain region connection if not explicitly stated. - """ - if "ataxia" in query.lower(): - return [ - "Ataxia is a degenerative disease of the nervous system.", - "Symptoms include lack of voluntary coordination of muscle movements." - ] - return [] - -def main(): - - query = "What brain region does Ataxia affect?" - print(f"User Query: '{query}'\n") - - # (Vector Only) - print(f"Vector Search Results:") - vector_results = mock_vector_search(query) - for res in vector_results: - print(f" - {res}") - - # (Graph Retrieval) - print(f"Graph Context Retrieval:") - gr = GraphRetriever() - results = gr.search_ontology("Ataxia") - - if not results: - print("No graph results found. (Did you run seed_graph.py?)") - else: - for item in results: - print(f" - Entity: {item.name}") - print(f" - Definition: {item.definition}") - print(f" - Relationships: {item.relationships}") - - gr.close() - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/backend/graph_retrieval.py b/backend/neo4j_search_tool.py similarity index 100% rename from backend/graph_retrieval.py rename to backend/neo4j_search_tool.py diff --git a/backend/seed_graph.py b/backend/seed_graph.py deleted file mode 100644 index b6f54a1..0000000 --- a/backend/seed_graph.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -from neo4j import GraphDatabase - -URI = "bolt://localhost:7687" -AUTH = ("neo4j", "password") -def seed_data(): - driver = GraphDatabase.driver(URI, auth=AUTH) - - with driver.session() as session: - print("1. Clearing old data...") - session.run("MATCH (n) DETACH DELETE n") - - print("2. Seeding new NIFSTD ontology data...") - query = """ - CREATE (brain:NIFSTD_Term {label: "Brain", id: "UBERON:0000955", definition: "The central organ of the nervous system."}) - CREATE (hind:NIFSTD_Term {label: "Hindbrain", id: "UBERON:0002028", definition: "The posterior part of the brain."}) - CREATE (cere:NIFSTD_Term {label: "Cerebellum", id: "UBERON:0002037", definition: "Region of the brain that plays an important role in motor control."}) - - // Create Relationships - CREATE (cere)-[:PART_OF]->(hind) - CREATE (hind)-[:PART_OF]->(brain) - - // Create a Disease that links to the specific region - CREATE (atak:Disease {label: "Ataxia", definition: "A degenerative disease of the nervous system."}) - CREATE (atak)-[:AFFECTS]->(cere) - - RETURN count(brain) as nodes_created - """ - - result = session.run(query) - record = result.single() - print(f"Database seeded! Nodes created: {record['nodes_created'] if record else 0}") - - driver.close() - -if __name__ == "__main__": - seed_data() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index fc99fff..8eca26d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "matplotlib>=3.10.3", "neo4j>=6.0.3", "pandas>=2.3.1", + "pytest>=9.0.2", "requests>=2.32.4", "scikit-learn>=1.7.0", "sentence-transformers>=3.0.0", diff --git a/tests/test_neo4j_tool.py b/tests/test_neo4j_tool.py index 0174d1b..1bf2b79 100644 --- a/tests/test_neo4j_tool.py +++ b/tests/test_neo4j_tool.py @@ -4,7 +4,6 @@ import sys sys.path.append(os.path.join(os.path.dirname(__file__), "..")) - from backend.neo4j_search_tool import GraphRetriever @pytest.fixture @@ -21,25 +20,19 @@ def test_initialization(): tool = GraphRetriever() assert tool.uri == "bolt://test:7687" -def test_search_execution(mock_driver): - """Verifies the search method sends the correct Cypher query.""" - # 1. Setup the fake return data +def test_search_ontology(mock_driver): + """Verifies the search method returns GraphItem objects.""" mock_session = MagicMock() mock_driver.session.return_value.__enter__.return_value = mock_session - - # Fake a Neo4j record mock_record = MagicMock() - mock_record.data.return_value = {"name": "Hippocampus", "definition": "Brain region"} + + data = {"name": "Hippocampus", "def": "Brain region", "rels": []} + mock_record.__getitem__.side_effect = data.__getitem__ + mock_record.get.side_effect = data.get mock_session.run.return_value = [mock_record] - # 2. Run the tool tool = GraphRetriever() - tool.connect() - results = tool.search("Hippocampus") - - # 3. Verify it worked - assert len(results) == 1 - assert results[0]["name"] == "Hippocampus" - - # Verify the code actually called the DB + results = tool.search_ontology("Hippocampus") + assert results[0].name == "Hippocampus" + assert results[0].definition == "Brain region" mock_session.run.assert_called_once() \ No newline at end of file