Merge pull request #7 from TranslatorSRI/fix_memory_usage
Fix memory usage
maximusunc authored Jul 27, 2023
2 parents 34a60b5 + 02bda1b commit 5c58857
Showing 11 changed files with 101 additions and 62 deletions.
12 changes: 8 additions & 4 deletions app/clinical_evidence/compute_clinical_evidence.py
@@ -1,9 +1,12 @@
"""Clinical Evidence Scoring."""
import json
import logging
import numpy as np
import redis


def compute_clinical_evidence(
result: dict, message, logger: logging.Logger, clinical_evidence_edges: dict
result: dict, message, logger: logging.Logger, db_conn: redis.Redis
):
"""Given a result, compute the clinical evidence score,
@@ -23,8 +26,9 @@ def compute_clinical_evidence(
logger.error("malformed TRAPI")
continue
clinical_edge_id = f"{kg_edge['subject']}_{kg_edge['object']}"
if kg_edge and clinical_edge_id in clinical_evidence_edges:
found_edges.extend(clinical_evidence_edges[clinical_edge_id])
kg_edge = db_conn.get(clinical_edge_id)
if kg_edge is not None:
found_edges.extend(json.loads(kg_edge))

# Compute the clinical evidence score given all clinical kp edges
# Score is computed by:
@@ -44,4 +48,4 @@ def compute_clinical_evidence(
     )
     if total_weights > 0:
         clinical_evidence_score /= total_weights
-    return clinical_evidence_score
+    return (1 / (1 + np.exp(-np.abs(clinical_evidence_score))) - 0.5) * 2
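The new return line squashes the raw score into the unit interval: 2 * (sigmoid(|x|) - 0.5) maps any non-negative score onto [0, 1). Judging from the test fixtures further down, the raw clinical_evidence_score appears to be a mean of the edge log odds ratios weighted by total_sample_size; the following standalone sketch (not part of the commit, values copied from the updated test) reproduces both the old and the new expected test values:

import numpy as np

# Cached edges for "UMLS:C0021641_MONDO:0005015" (see tests/clinical_response.py below)
edges = [
    {"log_odds_ratio": 1.5, "total_sample_size": 100},
    {"log_odds_ratio": 0.2, "total_sample_size": 10000},
]

# Sample-size-weighted mean of the log odds ratios
total_weights = sum(edge["total_sample_size"] for edge in edges)
raw_score = sum(edge["log_odds_ratio"] * edge["total_sample_size"] for edge in edges) / total_weights
print(raw_score)  # 0.2128712871287129 -- the old test expectation

# Squashing step introduced by this commit
squashed = (1 / (1 + np.exp(-np.abs(raw_score))) - 0.5) * 2
print(squashed)  # ~0.10603553615150196 -- the new test expectation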
29 changes: 22 additions & 7 deletions app/ordering_components.py
@@ -1,6 +1,12 @@
"""Compute scores for each result in the given message."""
import os
import redis
from tqdm import tqdm

from .clinical_evidence.compute_clinical_evidence import compute_clinical_evidence

REDIS_PSWD = os.getenv("REDIS_PSWD", "supersecretpassword")


def get_confidence(result, message, logger):
"""
@@ -26,23 +32,32 @@ def get_confidence(result, message, logger):
     return score_sum
 
 
-def get_clinical_evidence(result, message, logger, clinical_evidence_edges: dict):
-    return compute_clinical_evidence(result, message, logger, clinical_evidence_edges)
+def get_clinical_evidence(result, message, logger, db_conn):
+    return compute_clinical_evidence(result, message, logger, db_conn)
 
 
 def get_novelty(result, message, logger):
     # TODO get novelty from novelty package
     return 0
 
 
-def get_ordering_components(message, logger, clinical_evidence_edges: dict):
+def get_ordering_components(message, logger):
     logger.debug(f"Computing scores for {len(message['results'])} results")
-    for result in message.get("results") or []:
+    db_conn = redis.Redis(
+        host="0.0.0.0",
+        port=6379,
+        password=REDIS_PSWD,
+    )
+    for result_index, result in enumerate(tqdm(message.get("results") or [])):
+        clinical_evidence_score = get_clinical_evidence(
+            result,
+            message,
+            logger,
+            db_conn,
+        )
         result["ordering_components"] = {
             "confidence": get_confidence(result, message, logger),
-            "clinical_evidence": get_clinical_evidence(
-                result, message, logger, clinical_evidence_edges
-            ),
+            "clinical_evidence": clinical_evidence_score,
             "novelty": 0,
         }
        if result["ordering_components"]["clinical_evidence"] == 0:
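The refactored get_ordering_components opens one Redis connection per call, and compute_clinical_evidence then issues a GET per edge, keyed by "{subject}_{object}". A minimal sketch of the key/value convention this implies (connection details mirror ordering_components.py, the stored record shape comes from tests/clinical_response.py; the standalone script itself is illustrative, not part of the commit):

import json
import redis

# Connection parameters copied from ordering_components.py defaults
db_conn = redis.Redis(host="0.0.0.0", port=6379, password="supersecretpassword")

# Each key is "<subject_curie>_<object_curie>"; each value is a JSON-encoded list
# of clinical edge records carrying log_odds_ratio and total_sample_size.
db_conn.set(
    "UMLS:C0021641_MONDO:0005015",
    json.dumps(
        [
            {"log_odds_ratio": 1.5, "total_sample_size": 100},
            {"log_odds_ratio": 0.2, "total_sample_size": 10000},
        ]
    ),
)

cached = db_conn.get("UMLS:C0021641_MONDO:0005015")  # returns None on a cache miss
edges = json.loads(cached) if cached is not None else []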
46 changes: 9 additions & 37 deletions app/server.py
@@ -1,9 +1,6 @@
from io import BytesIO
import json
import logging
import traceback
import os
import zipfile

from fastapi import Body, BackgroundTasks
from fastapi.responses import JSONResponse
@@ -23,7 +20,7 @@

 openapi_args = dict(
     title="Answer Appraiser",
-    version="0.2.0",
+    version="0.2.1",
     terms_of_service="",
     translator_component="Utility",
     translator_teams=["Standards Reference Implementation Team"],
@@ -62,35 +59,6 @@
     allow_headers=["*"],
 )
 
-clinical_evidence_edges = {}
-
-
-@APP.on_event("startup")
-def load_clinical_evidence_edges():
-    """Load in precomputed clinical evidence edges.
-    This file is very large, ~12GB, so will take some time to load.
-    """
-    global clinical_evidence_edges
-    LOGGER.info("Loading clinical evidence edges...")
-    clinical_evidence_edges_url = os.getenv(
-        "CLINICAL_EVIDENCE_EDGES_URL",
-        "https://stars.renci.org/var/answer_appraiser/edges_merged.zip",
-    )
-    response = httpx.get(clinical_evidence_edges_url)
-    response.raise_for_status()
-    LOGGER.info("Downloaded edges. Unzipping...")
-    buffer = BytesIO(response.content)
-    with zipfile.ZipFile(buffer, "r") as zip_ref:
-        edge_file = zip_ref.namelist()[0]
-        with zip_ref.open(edge_file) as file:
-            content = file.read()
-            decoded_content = content.decode("utf-8")
-
-    clinical_evidence_edges = json.loads(decoded_content)
-    LOGGER.info("Edges loaded!")
-
-
 EXAMPLE = {
     "message": {
         "query_graph": {
@@ -153,9 +121,10 @@ def load_clinical_evidence_edges():

 async def async_appraise(message, callback, logger: logging.Logger):
     try:
-        get_ordering_components(message, logger, clinical_evidence_edges)
+        get_ordering_components(message, logger)
     except Exception:
         logger.error(f"Something went wrong while appraising: {traceback.format_exc()}")
+    logger.info("Done appraising")
     try:
         logger.info(f"Posting to callback {callback}")
         async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=600.0)) as client:
@@ -173,8 +142,9 @@ async def get_appraisal(
"""Appraise Answers"""
qid = str(uuid4())[:8]
query_dict = query.dict()
log_level = query_dict.get("log_level") or "WARNING"
log_level = query_dict.get("log_level") or "INFO"
logger = get_logger(qid, log_level)
logger.info("Starting async appraisal")
message = query_dict["message"]
if not message.get("results"):
logger.warning("No results given.")
@@ -198,16 +168,18 @@ async def get_appraisal(
 async def sync_get_appraisal(query: Query = Body(..., example=EXAMPLE)):
     qid = str(uuid4())[:8]
     query_dict = query.dict()
-    log_level = query_dict.get("log_level") or "WARNING"
+    log_level = query_dict.get("log_level") or "INFO"
     logger = get_logger(qid, log_level)
+    logger.info("Starting sync appraisal")
     message = query_dict["message"]
     if not message.get("results"):
         return JSONResponse(
             content={"status": "Rejected", "description": "No Results.", "job_id": qid},
             status_code=400,
         )
     try:
-        get_ordering_components(message, logger, clinical_evidence_edges)
+        get_ordering_components(message, logger)
     except Exception:
         logger.error(f"Something went wrong while appraising: {traceback.format_exc()}")
+    logger.info("Done appraising")
     return Response(message=message)
3 changes: 3 additions & 0 deletions redis/Dockerfile
@@ -0,0 +1,3 @@
FROM redis
COPY redis.conf /usr/local/etc/redis/redis.conf
CMD [ "redis-server", "/usr/local/etc/redis/redis.conf" ]
22 changes: 22 additions & 0 deletions redis/redis.conf
@@ -0,0 +1,22 @@
requirepass supersecretpassword
port 6380
maxmemory 20gb
# evict least frequently used keys when memory cap is hit
maxmemory-policy volatile-lfu
loglevel notice
# If we want to log to a file
logfile /data/answer_appraiser_cache.log
save 3600 1
stop-writes-on-bgsave-error no
dbfilename answer_appraiser_cache.rdb

# enable larger entry writes
proto-max-bulk-len 1000mb

# clean up any memory leaks
activedefrag yes

# only allow connections a certain percentage of total memory
# maxmemory-clients 10%
# close idle clients after 60 seconds
timeout 60
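This commit replaces the ~12 GB in-memory edge dictionary that server.py used to download at startup with the Redis cache configured above, but the diff does not include the step that populates Redis. A purely hypothetical loader, assuming the same edges_merged.zip layout the removed startup hook parsed (a single JSON file mapping "subject_object" keys to lists of edge records), could look like this:

"""Hypothetical one-off cache loader -- not part of this commit."""
import json
import os
import zipfile

import redis

REDIS_PSWD = os.getenv("REDIS_PSWD", "supersecretpassword")
db_conn = redis.Redis(host="0.0.0.0", port=6379, password=REDIS_PSWD)

# Assumes edges_merged.zip has already been downloaded next to this script
with zipfile.ZipFile("edges_merged.zip", "r") as zip_ref:
    edge_file = zip_ref.namelist()[0]
    with zip_ref.open(edge_file) as file:
        clinical_evidence_edges = json.load(file)

# Write each edge list under its own key so the service can fetch edges one at a time
pipe = db_conn.pipeline()
for edge_id, edge_list in clinical_evidence_edges.items():
    pipe.set(edge_id, json.dumps(edge_list))
    if len(pipe) >= 10_000:
        pipe.execute()
pipe.execute()

A one-off loader can afford to hold the whole JSON in memory; the point of the commit is that the running service no longer has to.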
2 changes: 2 additions & 0 deletions requirements-lock.txt
@@ -10,12 +10,14 @@ httpcore==0.17.2
 httptools==0.2.0
 httpx==0.24.1
 idna==3.4
+numpy==1.25.1
 pydantic==1.10.9
 python-dotenv==1.0.0
 PyYAML==6.0
 reasoner-pydantic==4.0.8
 sniffio==1.3.0
 starlette==0.17.1
+tqdm==4.65.0
 typing_extensions==4.6.3
 uvicorn==0.13.3
 uvloop==0.17.0
2 changes: 2 additions & 0 deletions requirements-test-lock.txt
@@ -4,12 +4,14 @@ chardet==4.0.0
 coverage==7.2.7
 idna==2.10
 iniconfig==2.0.0
+fakeredis==2.17.0
 packaging==23.1
 pluggy==0.13.1
 py==1.11.0
 pytest==6.2.2
 pytest-asyncio==0.16.0
 pytest-cov==2.11.1
 requests==2.25.1
+sortedcontainers==2.4.0
 toml==0.10.2
 urllib3==1.26.16
1 change: 1 addition & 0 deletions requirements-test.txt
@@ -1,3 +1,4 @@
+fakeredis==2.17.0
 pytest==6.2.2
 pytest-cov==2.11.1
 requests==2.25.1
3 changes: 3 additions & 0 deletions requirements.txt
@@ -1,4 +1,7 @@
fastapi==0.75.0
gunicorn==20.1.0
httpx==0.24.1
numpy==1.25.1
reasoner-pydantic==4.0.8
tqdm==4.65.0
uvicorn==0.13.3
26 changes: 26 additions & 0 deletions tests/clinical_response.py
@@ -1,3 +1,29 @@
"""Mock Redis."""
import fakeredis
import json


def redisMock():
redis = fakeredis.FakeRedis()
redis.set(
"UMLS:C0021641_MONDO:0005015",
json.dumps(
[
{
"log_odds_ratio": 1.5,
"total_sample_size": 100,
},
{
"log_odds_ratio": 0.2,
"total_sample_size": 10000,
},
]
),
)
# set up mock function
return redis


response = {
"query_graph": {
"nodes": {
17 changes: 3 additions & 14 deletions tests/test_clinical_evidence.py
@@ -2,7 +2,7 @@
 import logging
 
 from app.clinical_evidence.compute_clinical_evidence import compute_clinical_evidence
-from tests.clinical_response import response
+from tests.clinical_response import response, redisMock
 
 logger = logging.getLogger(__name__)
 
@@ -13,17 +13,6 @@ def test_clinical_evidence():
response["results"][0],
response,
logger,
{
"UMLS:C0021641_MONDO:0005015": [
{
"log_odds_ratio": 1.5,
"total_sample_size": 100,
},
{
"log_odds_ratio": 0.2,
"total_sample_size": 10000,
},
]
},
redisMock(),
)
assert score == 0.2128712871287129
assert score == 0.10603553615150196
