Skip to content

Commit

Permalink
feat: sorting vector
Browse files Browse the repository at this point in the history
- vector of predictions formatted and sorted during prediction
- `forward_pass` allows to retrieve embedding vectors
  • Loading branch information
valentynbez committed Apr 1, 2024
1 parent 438749e commit 348fd76
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 42 deletions.
13 changes: 0 additions & 13 deletions mDeepFRI/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,19 +339,6 @@ def predict_protein_function(

output_buffer.close()

# sort predictions by protein name and score
with open(output_file_name, "r", encoding="utf-8") as f:
reader = csv.reader(f, delimiter="\t")
header = next(reader)
# row[0] - protein name
# row[2] - DeepFRI score
rows = sorted(reader, key=lambda row: (str(row[0]), -float(row[2])))

with open(output_file_name, "w", encoding="utf-8") as f:
writer = csv.writer(f, delimiter="\t")
writer.writerow(header)
writer.writerows(rows)

if remove_intermediate:
for db in deepfri_dbs:
remove_intermediate_files([db.sequence_db, db.mmseqs_db])
Expand Down
74 changes: 45 additions & 29 deletions mDeepFRI/predict.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ cimport numpy as np
import cython

np.import_array()

import operator

import onnxruntime as rt


Expand Down Expand Up @@ -48,7 +51,6 @@ cdef class Predictor(object):
Class for loading trained models and computing GO/EC predictions and class activation maps (CAMs).
"""

cdef public bint gcn
cdef public str model_path
cdef public int threads
cdef public dict prot2goterms
Expand All @@ -59,7 +61,6 @@ cdef class Predictor(object):
cdef public session
cdef public np.ndarray Y_hat
cdef public dict data
cdef public list test_prot_list

def __init__(self, model_path: str, threads: int = 0):
self.model_path = model_path
Expand All @@ -86,35 +87,13 @@ cdef class Predictor(object):
self.goterms = np.asarray(metadata['goterms'])
self.thresh = 0.1 * np.ones(len(self.goterms))

def predict_function(
self,
seqres: str,
cmap = None,
chain: str = "",
):
"""
Computes GO/EC predictions for a single protein chain from sequence and contact map.
Args:
seqres (str): protein sequence
cmap (np.array): contact map
chain (str): chain ID
Returns:
None
"""
def forward_pass(self, seqres: str, cmap = None):

cdef np.ndarray A
cdef np.ndarray prediction
cdef np.ndarray y
cdef list output_rows = []
cdef str go_term
cdef float score
cdef str annotation

self.Y_hat = np.zeros((1, len(self.goterms)), dtype=float)
self.data = {}
self.test_prot_list = [chain]
# self.data = {}
S = seq2onehot(seqres)
S = S.reshape(1, *S.shape)
inputDetails = self.session.get_inputs()
Expand All @@ -126,21 +105,58 @@ cdef class Predictor(object):
inputDetails[0].name: A.astype(np.float32),
inputDetails[1].name: S.astype(np.float32)
})[0]
self.data[chain] = [[A, S], seqres]
# self.data[chain] = [[A, S], seqres]

# if no cmap use CNN with 1 input
else:
prediction = self.session.run(
None, {inputDetails[0].name: S.astype(np.float32)})[0]
self.data[chain] = [[S], seqres]
# self.data[chain] = [[S], seqres]

y = prediction[:, :, 0].reshape(-1)
self.Y_hat[0] = y

return y

def format_predictions(self, y, chain: str = ""):

cdef list output_rows = []
cdef str go_term
cdef float score
cdef str annotation

go_idx = np.where(y >= self.thresh)[0]
for idx in go_idx:
go_term = self.goterms[idx].item()
score = float(y[idx])
annotation = self.gonames[idx].item()
output_rows.append([chain, go_term, score, annotation])

# Sort output_rows based on score in descending order
output_rows.sort(key=operator.itemgetter(2), reverse=True)

return output_rows


def predict_function(
self,
seqres: str,
cmap = None,
chain: str = ""
):
"""
Computes GO/EC predictions for a single protein chain from sequence and contact map.
Args:
seqres (str): protein sequence.
cmap (np.array): contact map.
chain (str): protein ID.
Returns:
list: list of GO/EC predictions.
"""

y = self.forward_pass(seqres, cmap)
output_rows = self.format_predictions(y, chain)

return output_rows

0 comments on commit 348fd76

Please sign in to comment.