-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
37 lines (29 loc) · 856 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import sys
import torch
from model import Model
from Preprocessing import *
def get_data():
    """Placeholder for the project's data-loading step (per its name).

    Not implemented yet: it performs no work and returns ``None``.
    """
    return None
def tune_model(model, lr: list, epochs: list):
    """Placeholder for hyperparameter tuning of *model*.

    Presumably intended to try the learning rates in ``lr`` and the epoch
    counts in ``epochs`` — TODO confirm. Not implemented yet: the arguments
    are ignored and ``None`` is returned.
    """
    return None
def _prompt_gene(prompt):
    """Read one gene name from stdin, stripping surrounding whitespace."""
    # TODO: check that the gene name is valid for encode_pair.
    return input(prompt).strip()


def predict(path):
    """Interactively predict whether two genes form a synthetic lethal pair.

    Loads saved weights into a fresh ``Model``, prompts on stdin for two gene
    names, encodes them with ``encode_pair`` and prints the model's verdict.

    params:
        path: filepath of the .pth file produced for the trained model; the
            object ``torch.load`` returns is passed to
            ``Model.load_state_dict``.
    """
    model = Model()
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a
    # CPU-only machine; load_state_dict then copies the tensors into the
    # freshly built model's parameters.
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints (consider weights_only=True on recent PyTorch versions).
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Model is presumably a torch.nn.Module (it has load_state_dict); put
    # dropout/batch-norm layers, if any, into inference mode.
    model.eval()

    g1 = _prompt_gene("Please enter the first gene: ")
    g2 = _prompt_gene("Please enter the second gene: ")

    # encode_pair presumably comes from the Preprocessing star import.
    enc = encode_pair(g1, g2)
    with torch.no_grad():  # inference only: no autograd graph needed
        pred = model.predict_one(enc)

    verdict = "form" if pred == 1 else "do NOT form"
    print(f"The model predicts that {g1} and {g2} {verdict} a synthetic lethal pair.")
def make_visuals(model):
    """Placeholder for producing visualisations of *model* (per its name).

    Not implemented yet: *model* is ignored and ``None`` is returned.
    """
    return None
if __name__ == "__main__":
    # Split off the program name; the remaining command-line arguments are
    # collected into a list but not consumed yet.
    # TODO parse args
    _prog, *args = sys.argv