Skip to content

Commit

Permalink
changing bolt connection ping to async, removing http driver
Browse files Browse the repository at this point in the history
  • Loading branch information
EvanDietzMorris committed Jan 29, 2025
1 parent 07311e5 commit 8b14ab1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 231 deletions.
204 changes: 14 additions & 190 deletions PLATER/services/util/graph_adapter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import base64
import traceback
import httpx
import time
import neo4j
import asyncio

import neo4j.exceptions
from neo4j import unit_of_work
Expand Down Expand Up @@ -34,15 +32,20 @@ def __init__(self,
self.sync_neo4j_driver = None
self._supports_apoc = None

async def connect_to_neo4j(self):
logger.debug('PINGING NEO4J')
self.ping()
logger.debug('CHECKING IF NEO4J SUPPORTS APOC')
self.check_apoc_support()
logger.debug(f'SUPPORTS APOC : {self._supports_apoc}')
async def connect_to_neo4j(self, retries=0):
self.neo4j_driver = neo4j.AsyncGraphDatabase.driver(self.graph_db_uri,
auth=self.database_auth,
**{'telemetry_disabled': True})
try:
await self.neo4j_driver.verify_connectivity()
except Exception as e: # currently the driver says it raises Exception, not something more specific
await self.neo4j_driver.close()
if retries <= 3:
await asyncio.sleep(5)
logger.error(f'Could not establish connection to neo4j, trying again... retry {retries + 1}')
await self.connect_to_neo4j(retries + 1)
else:
raise neo4j.exceptions.ServiceUnavailable('Connection to Neo4j could not be established.')

@staticmethod
@unit_of_work(timeout=NEO4J_QUERY_TIMEOUT)
Expand Down Expand Up @@ -148,27 +151,8 @@ def run_sync(self,
self.sync_neo4j_driver.close()
self.sync_neo4j_driver = None

def ping(self, counter: int = 1, max_retries: int = 3):
if not self.sync_neo4j_driver:
self.sync_neo4j_driver = neo4j.GraphDatabase.driver(self.graph_db_uri, auth=self.database_auth)
try:
self.sync_neo4j_driver.verify_connectivity()
return True
except neo4j.exceptions.AuthError as e:
raise e
except Exception as e:
if counter > max_retries:
logger.error(f'Waited too long for Neo4j initialization... giving up..')
raise neo4j.exceptions.ServiceUnavailable('Connection to Neo4j could not be established.')
logger.info(f'Pinging Neo4j failed, trying again... {repr(e)}')
time.sleep(10)
return self.ping(counter + 1)
finally:
self.sync_neo4j_driver.close()
self.sync_neo4j_driver = None

def check_apoc_support(self):
apoc_version_query = 'call apoc.help("meta")'
apoc_version_query = 'call apoc.version()'
if self._supports_apoc is None:
try:
self.run_sync(apoc_version_query)
Expand Down Expand Up @@ -221,164 +205,6 @@ def convert_http_response_to_dict(response: dict) -> list:
return array


class Neo4jHTTPDriver:
def __init__(self, host: str, port: str, auth: tuple, scheme: str = 'http'):
self._host = host
# NOTE - here "neo4j" refers to the database name, for the community edition there's only one "neo4j"
self._neo4j_transaction_endpoint = "/db/neo4j/tx/commit"
self._scheme = scheme
self._full_transaction_path = f"{self._scheme}://{self._host}:{port}{self._neo4j_transaction_endpoint}"
self._port = port
self._supports_apoc = None
self._header = {
'Accept': 'application/json; charset=UTF-8',
'Content-Type': 'application/json',
'Authorization': 'Basic %s' % base64.b64encode(f"{auth[0]}:{auth[1]}".encode('utf-8')).decode('utf-8')
}

async def connect_to_neo4j(self):
# ping and raise error if neo4j doesn't respond.
logger.debug('PINGING NEO4J')
self.ping()
logger.debug('CHECKING IF NEO4J SUPPORTS APOC')
self.check_apoc_support()
logger.debug(f'SUPPORTS APOC : {self._supports_apoc}')

async def post_request_json(self, payload):
async with httpx.AsyncClient(timeout=NEO4J_QUERY_TIMEOUT) as session:
response = await session.post(self._full_transaction_path, json=payload, headers=self._header)
if response.status_code != 200:
logger.error(f"[x] Problem contacting Neo4j server {self._host}:{self._port} -- {response.status_code}")
txt = response.text
logger.debug(f"[x] Server responded with {txt}")
try:
return response.json()
except:
return txt
else:
return response.json()

def ping(self):
"""
Pings the neo4j backend.
:return:
"""
ping_url = f"{self._scheme}://{self._host}:{self._port}/"
# if we can't contact neo4j, we should exit.
try:
now = time.time()
response = httpx.get(ping_url, headers=self._header)
later = time.time()
time_taken = later - now
logger.debug(f'Contacting neo4j took {time_taken} seconds.')
if time_taken > 5: # greater than 5 seconds it's not healthy
logger.warn(f"Contacting neo4j took more than 5 seconds ({time_taken}). Neo4j might be stressed.")
if response.status_code != 200:
raise Exception(f'server returned {response.status_code}')
except Exception as e:
logger.error(f"Error contacting Neo4j @ {ping_url} -- Exception raised -- {e}")
logger.debug(traceback.print_exc())
raise RuntimeError('Connection to Neo4j could not be established.')

async def run(self,
query,
return_errors=False,
convert_to_dict=False,
convert_to_trapi=False,
qgraph=None):
"""
Runs a neo4j query async.
:param return_errors: returns errors as values instead of raising an exception
:param query: Cypher query.
:param convert_to_dict: whether to convert the neo4j results into a dict
:param convert_to_trapi: whether to convert the neo4j results into a TRAPI dict
:param qgraph: the TRAPI qgraph
:return: result of query.
:rtype: dict
"""
# make the statement dictionary
payload = {
"statements": [
{
"statement": f"{query}"
}
]
}

response = await self.post_request_json(payload)
errors = response.get('errors')
if errors:
logger.error(f'Neo4j returned `{errors}` for cypher {query}.')
if return_errors:
return response
raise RuntimeWarning(f'Error running cypher {query}. {errors}')
if convert_to_trapi:
raise NotImplementedError(f'TRAPI queries are currently not supported over the HTTP protocol.')
if convert_to_dict:
response = convert_http_response_to_dict(response)
return response

def run_sync(self, query, convert_to_dict=False):
"""
Runs a neo4j query. Can cause the async loop to block.
:param query:
:param convert_to_dict:
:return:
"""
payload = {
"statements": [
{
"statement": f"{query}"
}
]
}
response = httpx.post(
self._full_transaction_path,
headers=self._header,
timeout=NEO4J_QUERY_TIMEOUT,
json=payload).json()
errors = response.get('errors')
if errors:
logger.error(f'Neo4j returned `{errors}` for cypher {query}.')
raise RuntimeWarning(f'Error running cypher {query}.')
if convert_to_dict:
response = convert_http_response_to_dict(response)
return response

def convert_to_dict(self, response: dict) -> list:
"""
Converts a neo4j result to a structured result.
:param response: neo4j http raw result.
:type response: dict
:return: reformatted dict
:rtype: dict
"""
results = response.get('results')
array = []
if results:
for result in results:
cols = result.get('columns')
if cols:
data_items = result.get('data')
for item in data_items:
new_row = {}
row = item.get('row')
for col_name, col_value in zip(cols, row):
new_row[col_name] = col_value
array.append(new_row)
return array

def check_apoc_support(self):
apoc_version_query = 'call apoc.help("meta")'
if self._supports_apoc is None:
try:
self.run_sync(apoc_version_query)
self._supports_apoc = True
except:
self._supports_apoc = False
return self._supports_apoc


class GraphInterface:
"""
Singleton class for interfacing with the graph.
Expand All @@ -387,9 +213,7 @@ class GraphInterface:
class _GraphInterface:
def __init__(self, host, port, auth, protocol='bolt'):
self.protocol = protocol
if protocol == 'http':
self.driver = Neo4jHTTPDriver(host=host, port=port, auth=auth)
elif protocol == 'bolt':
if protocol == 'bolt':
self.driver = Neo4jBoltDriver(host=host, port=port, auth=auth)
else:
raise Exception(f'Unsupported graph interface protocol: {protocol}')
Expand Down
79 changes: 38 additions & 41 deletions PLATER/tests/test_graph_adapter.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import json
from collections import defaultdict
from PLATER.services.util.graph_adapter import Neo4jHTTPDriver, GraphInterface, convert_http_response_to_dict
from PLATER.services.util.graph_adapter import GraphInterface, convert_http_response_to_dict
from pytest_httpx import HTTPXMock
import pytest
import os

# TODO - improve these tests
# They are a mess, partially because Neo4jHTTPDriver connect_to_neo4j calls ping() and check_apoc_support() in init
# which means both of those calls need to be mocked every time either is called. More importantly, we should have tests
# for the bolt protocol driver. It's questionable how helpful these mocked tests are for testing the driver(s) anyway.

# TODO - implement these tests for Bolt

"""
@pytest.mark.asyncio
async def test_neo4j_http_driver_ping_success(httpx_mock: HTTPXMock):
httpx_mock.add_response(url="http://localhost:7474/", method="GET", status_code=200)
Expand Down Expand Up @@ -107,6 +104,39 @@ async def test_neo4j_http_driver_apoc(httpx_mock: HTTPXMock):
driver = Neo4jHTTPDriver(host='localhost', port='7474', auth=('neo4j', 'somepass'))
await driver.connect_to_neo4j()
assert driver.check_apoc_support() is False
@pytest.mark.asyncio
async def test_graph_interface_get_schema(httpx_mock: HTTPXMock):
query_schema = "MATCH (a)-[x]->(b) RETURN DISTINCT labels(a) as source_labels, type(x) as predicate, labels(b) as target_labels"
httpx_mock.add_response(url="http://localhost:7474/", method="GET", status_code=200)
httpx_mock.add_response(url="http://localhost:7474/db/neo4j/tx/commit", method="POST", status_code=200)
gi = GraphInterface('localhost', '7474', auth=('neo4j', ''), protocol='bolt')
await gi.connect_to_neo4j()
with open(os.path.join(os.path.dirname(__file__), 'data', 'schema_cypher_response.json'))as f:
get_schema_response_json = json.load(f)
httpx_mock.add_response(url="http://localhost:7474/db/neo4j/tx/commit", method="POST", status_code=200,
match_content=json.dumps({
"statements": [
{
"statement": f"{query_schema}"
}
]
}).encode('utf-8'), json=get_schema_response_json)
# lets pretend we already have summary
# gi.instance.summary = True
schema = gi.get_schema()
expected = defaultdict(lambda: defaultdict(set))
expected['biolink:Disease']['biolink:Disease'] = {'biolink:has_phenotype'}
expected['biolink:PhenotypicFeature']['biolink:Disease'] = set()
expected['biolink:Disease']['biolink:PhenotypicFeature'] = {'biolink:has_phenotype'}
assert schema == expected
GraphInterface.instance = None
"""

@pytest.mark.asyncio
async def test_driver_convert_to_dict():
Expand Down Expand Up @@ -134,7 +164,7 @@ async def test_driver_convert_to_dict():

@pytest.mark.asyncio
async def test_graph_interface_biolink_leaves(httpx_mock: HTTPXMock):
gi = GraphInterface('localhost', '7474', auth=('neo4j', ''), protocol='http')
gi = GraphInterface('localhost', '7474', auth=('neo4j', ''), protocol='bolt')
list_1 = [
"biolink:SmallMolecule",
"biolink:MolecularEntity",
Expand All @@ -153,7 +183,7 @@ async def test_graph_interface_biolink_leaves(httpx_mock: HTTPXMock):

@pytest.mark.asyncio
async def test_graph_interface_predicate_inverse(httpx_mock: HTTPXMock):
gi = GraphInterface('localhost', '7474', auth=('neo4j', ''), protocol='http')
gi = GraphInterface('localhost', '7474', auth=('neo4j', ''), protocol='bolt')
non_exist_predicate = "biolink:some_predicate"
assert gi.invert_predicate(non_exist_predicate) is None
symmetric_predicate = "biolink:related_to"
Expand All @@ -163,36 +193,3 @@ async def test_graph_interface_predicate_inverse(httpx_mock: HTTPXMock):
predicate_no_inverse_and_not_symmetric = "biolink:has_part"
assert gi.invert_predicate(predicate_no_inverse_and_not_symmetric) is None
GraphInterface.instance = None

@pytest.mark.asyncio
async def test_graph_interface_get_schema(httpx_mock: HTTPXMock):
query_schema = """
MATCH (a)-[x]->(b)
RETURN DISTINCT labels(a) as source_labels, type(x) as predicate, labels(b) as target_labels
"""
httpx_mock.add_response(url="http://localhost:7474/", method="GET", status_code=200)
httpx_mock.add_response(url="http://localhost:7474/db/neo4j/tx/commit", method="POST", status_code=200)

gi = GraphInterface('localhost', '7474', auth=('neo4j', ''), protocol='http')
await gi.connect_to_neo4j()
with open(os.path.join(os.path.dirname(__file__), 'data', 'schema_cypher_response.json'))as f:
get_schema_response_json = json.load(f)
httpx_mock.add_response(url="http://localhost:7474/db/neo4j/tx/commit", method="POST", status_code=200,
match_content=json.dumps({
"statements": [
{
"statement": f"{query_schema}"
}
]
}).encode('utf-8'), json=get_schema_response_json)

# lets pretend we already have summary
# gi.instance.summary = True
schema = gi.get_schema()
expected = defaultdict(lambda: defaultdict(set))
expected['biolink:Disease']['biolink:Disease'] = {'biolink:has_phenotype'}
expected['biolink:PhenotypicFeature']['biolink:Disease'] = set()
expected['biolink:Disease']['biolink:PhenotypicFeature'] = {'biolink:has_phenotype'}

assert schema == expected
GraphInterface.instance = None

0 comments on commit 8b14ab1

Please sign in to comment.