-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_conv_med.py
43 lines (34 loc) · 1.21 KB
/
train_conv_med.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import sys
import torch
from src.compare_against_gurobi import compare_against_gurobi
from src.inputs.conv_med_img67 import gurobi_results, solver_inputs
from src.solve import solve
from src.training.TrainingConfig import TrainingConfig
from src.utils import seed_everything, set_abs_path_to
# Train/tighten bounds for the `conv_med_img67` verification problem, then
# compare the resulting bounds against precomputed Gurobi results.

# Resolve paths relative to this script's directory. `abspath` guards against
# `__file__` being a bare relative filename (possible on Python < 3.9 when the
# script is launched from its own directory), where `dirname` would return "".
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
get_abs_path = set_abs_path_to(CURRENT_DIR)
CONFIG_FILE_PATH = get_abs_path("default_training_config.yaml")

# Seed RNGs for reproducibility (presumably Python/NumPy/torch -- see
# src.utils.seed_everything to confirm which generators are covered).
seed_everything(0)

# Load training config from YAML file.
training_config = TrainingConfig.from_yaml_file(CONFIG_FILE_PATH)
# Override the YAML value for this experiment. NOTE(review): presumably the
# number of epochs between adversarial (falsification) checks -- confirm in
# TrainingConfig.
training_config.num_epoch_adv_check = 2

# Prefer the GPU, but fall back to CPU so the script can still run on machines
# without CUDA instead of failing on the first tensor placed on "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

is_falsified, new_L_list, new_U_list, solver = solve(
    solver_inputs,
    device=device,
    return_solver=True,
    training_config=training_config,
)

# A falsified problem has no bounds worth comparing; exit successfully.
if is_falsified:
    print("Verification problem is falsified.")
    sys.exit(0)

unstable_masks = solver.sequential.unstable_masks

# The solver returns NumPy arrays here (converted with torch.from_numpy);
# compare them against the initial bounds and the Gurobi reference results.
compare_against_gurobi(
    new_L_list=[torch.from_numpy(x) for x in new_L_list],
    new_U_list=[torch.from_numpy(x) for x in new_U_list],
    unstable_masks=unstable_masks,
    initial_L_list=solver_inputs.L_list,
    initial_U_list=solver_inputs.U_list,
    gurobi_results=gurobi_results,
    # NOTE(review): presumably the tolerance below which bound differences are
    # treated as zero -- confirm in compare_against_gurobi.
    cutoff_threshold=1e-5,
)