Skip to content

Commit

Permalink
Merge pull request #101 from TranslatorSRI/performance
Browse files Browse the repository at this point in the history
Performance upgrade
  • Loading branch information
EvanDietzMorris authored Dec 20, 2023
2 parents b1f5511 + 9cd6a7d commit 3c2bdea
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 16 deletions.
61 changes: 46 additions & 15 deletions PLATER/services/app_trapi_1_4.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""FastAPI app."""
from fastapi import Body, Depends, FastAPI, Response
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.responses import ORJSONResponse
from typing import Dict
from pydantic import ValidationError
import orjson

from reasoner_transpiler.exceptions import InvalidPredicateError
from PLATER.models.shared import ReasonerRequest, MetaKnowledgeGraph, SRITestData
Expand Down Expand Up @@ -42,34 +43,46 @@ def get_meta_kg_response(graph_metadata_reader: GraphMetadata):
TRAPI_QUERY_EXAMPLE = graph_metadata_reader.get_example_qgraph()


async def get_meta_knowledge_graph() -> JSONResponse:
async def get_meta_knowledge_graph() -> ORJSONResponse:
"""Handle /meta_knowledge_graph."""
if META_KG_RESPONSE:
# we are intentionally returning a JSONResponse directly,
# we are intentionally returning a ORJSONResponse directly,
# we already validated with pydantic above and the content won't change
return JSONResponse(status_code=200,
content=META_KG_RESPONSE,
media_type="application/json")
return ORJSONResponse(status_code=200,
content=META_KG_RESPONSE,
media_type="application/json")
else:
return JSONResponse(status_code=500,
media_type="application/json",
content={"description": "MetaKnowledgeGraph failed validation - "
"please notify maintainers."})
return ORJSONResponse(status_code=500,
media_type="application/json",
content={"description": "MetaKnowledgeGraph failed validation - "
"please notify maintainers."})


async def get_sri_testing_data():
"""Handle /sri_testing_data."""
return SRI_TEST_DATA


def orjson_default(obj):
if isinstance(obj, set):
return list(obj)
raise TypeError


class CustomORJSONResponse(Response):
def render(self, content: dict) -> bytes:
return orjson.dumps(content,
default=orjson_default)


async def reasoner_api(
response: Response,
request: ReasonerRequest = Body(
...,
example=TRAPI_QUERY_EXAMPLE,
),
graph_interface: GraphInterface = Depends(get_graph_interface),
):
) -> CustomORJSONResponse:
"""Handle TRAPI request."""
request_json = request.dict(by_alias=True)
# default workflow
Expand All @@ -81,7 +94,7 @@ async def reasoner_api(
response_message = await question.answer(graph_interface)
request_json.update({'message': response_message, 'workflow': workflow})
except InvalidPredicateError as e:
return JSONResponse(status_code=400, content={"description": str(e)})
return CustomORJSONResponse(status_code=400, content={"description": str(e)})
elif 'overlay_connect_knodes' in workflows:
overlay = Overlay(graph_interface=graph_interface)
response_message = await overlay.connect_k_nodes(request_json['message'])
Expand All @@ -90,7 +103,10 @@ async def reasoner_api(
overlay = Overlay(graph_interface=graph_interface)
response_message = await overlay.annotate_node(request_json['message'])
request_json.update({'message': response_message, 'workflow': workflow})
return request_json

# we are intentionally returning a CustomORJSONResponse and not a pydantic model for performance reasons
json_response = CustomORJSONResponse(content=request_json, media_type="application/json")
return json_response


APP_TRAPI_1_4.add_api_route(
Expand All @@ -115,12 +131,27 @@ async def reasoner_api(
tags=["trapi"]
)

set_up_profiling = False
if set_up_profiling:
from fastapi import Request
from fastapi.responses import HTMLResponse
from pyinstrument import Profiler
from pyinstrument.renderers import SpeedscopeRenderer
@APP_TRAPI_1_4.middleware("http")
async def profile_request(request: Request, call_next):
profiler = Profiler(interval=.5, async_mode="enabled")
profiler.start()
await call_next(request)
profiler.stop()
return HTMLResponse(profiler.output(renderer=SpeedscopeRenderer()))


APP_TRAPI_1_4.add_api_route(
"/query",
reasoner_api,
methods=["POST"],
response_model=ReasonerRequest,
responses={400: {"model": Dict}},
response_model=None,
responses={400: {"model": Dict}, 200: {"model": ReasonerRequest}},
summary="Query reasoner via one of several inputs.",
description="",
tags=["trapi"]
Expand Down
2 changes: 1 addition & 1 deletion PLATER/services/util/question.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def transform_attributes(self, trapi_message, graph_interface: GraphInterface):
for node_binding_list in r["node_bindings"].values():
for node_binding in node_binding_list:
query_id = node_binding.pop('query_id', None)
if query_id != node_binding['id']:
if query_id != node_binding['id'] and query_id is not None:
node_binding['query_id'] = query_id
# add resource id
for analyses in r["analyses"]:
Expand Down

0 comments on commit 3c2bdea

Please sign in to comment.