-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
executable file
·105 lines (97 loc) · 3.83 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# --- Imports -----------------------------------------------------------------
# Standard library
import os
import sys
from os.path import join

# Make the project root importable before pulling in project-local modules.
sys.path.append("/scratch/GIT/BikeML")

# Third-party
import torch
import torchvision
from torch import nn
from torchvision import transforms
from torchvision.transforms import ColorJitter, RandomGrayscale, RandomHorizontalFlip, ToTensor

# Project-local
from baseline.BaselineModel_1a import BaselineModel_1a
from baseline.BaselineModel_1b import BaselineModel_1b
# NOTE(review): the original imported AddGaussianNoise twice on one line;
# the duplicate has been dropped.
from dataloaders.dataloader import AddGaussianNoise, BikeDataLoader
from dataloaders.dataloader import Resize, SquareCrop, SquarePad
from loss_functions import SupervisedContrastiveLoss, SupervisedCosineContrastiveLoss
# Settings shared by every experiment configuration below.
base_config = {
    "epochs": 100,              # total number of training epochs
    "lr": 0.01,                 # learning rate
    "weight_decay": 0.000001,   # weight-decay (L2) coefficient
    "image_dim": 224,           # square input image size, used by the transforms below
    "starting_epoch": 99,       # epoch index to start from (presumably a checkpoint resume -- confirm)
    "number_of_figures": 16,    # how many figures to produce
    "half_precision": False,    # fp16 training toggle
    "train_backbone": True,     # whether the backbone is trained
    "clear_redis": False,       # whether to clear the redis store first
    "viz_attention": True,      # attention-visualisation toggle
}
# Keyword arguments forwarded to BikeDataLoader.  Commented-out values are
# alternative settings used previously and kept for easy switching.
# (A dead, fully commented-out get_color_distortion() helper that was embedded
# inside this literal has been removed; the colour-distortion transforms live
# in the per-experiment configs instead.)
dataloader_params = dict(
    data_set_size = 500000,
    # data_set_size = 1000,
    data_splits = {"train": 55/60, "val": 5/60, "test": 5/60},  # fractions sum to 1
    normalize = True,
    balance = 0.5,          # presumably the positive/negative sampling balance -- TODO confirm
    num_workers = 32,
    prefetch_factor = 1,
    batch_size = 1024,
    # Deterministic preprocessing: square-crop, resize to the model input
    # resolution, then convert to a tensor.
    transforms = torchvision.transforms.Compose([
        SquareCrop((base_config["image_dim"], base_config["image_dim"])),
        Resize((base_config["image_dim"], base_config["image_dim"])),
        ToTensor(),
    ]),
    # root = "/scratch/datasets/raw/",
    root = "/data_raid/raw/",
    # root = "/scratch/datasets/detr_filtered/",
    shuffle = True,
    memory = True,
    half = False,
    pin_memory = True,
)
# Experiment 1a: BaselineModel_1a trained with a BCE criterion.
Baseline_1a_config = {
    "model": BaselineModel_1a,
    "exp_name": 'testing ROC',
    "dataloader": BikeDataLoader,
    "criterion": nn.BCELoss,
    "project_path": "./baseline_1a",
    # (batch, channels, height, width) of one input batch.
    "input_shape": (
        dataloader_params["batch_size"],
        3,
        base_config["image_dim"],
        base_config["image_dim"],
    ),
    "mlp_layers": 4,
    "eval": True,
    # Reuse the dataloader's deterministic transform pipeline.
    "tiny_transforms": dataloader_params["transforms"],
    "viz_attention": False,  # overrides the base_config setting
}
# Experiment 1b: BaselineModel_1b trained with the supervised cosine
# contrastive criterion (margin 0.9).
Baseline_1b_config = {
    "exp_name": 'Margin tuning: 0.9',
    "model": BaselineModel_1b,
    "dataloader": BikeDataLoader,
    "criterion": SupervisedCosineContrastiveLoss,
    "project_path": "./baseline_1b",
    # (batch, channels, height, width) of one input batch.
    "input_shape": (
        dataloader_params["batch_size"],
        3,
        base_config["image_dim"],
        base_config["image_dim"],
    ),
    "mlp_layers": 4,
    "embedding_dimension": 128,
    "margin": 0.9,
    # Augmented pipeline: crop, random flip, colour jitter, random greyscale.
    "transforms": torchvision.transforms.Compose([
        SquareCrop((base_config["image_dim"], base_config["image_dim"])),
        RandomHorizontalFlip(p=0.5),
        ColorJitter(0.8, 0.8, 0.8, 0.2),
        RandomGrayscale(p=0.2),
        ToTensor(),
    ]),
    # Same colour augmentations but without the horizontal flip.
    "tiny_transforms": torchvision.transforms.Compose([
        SquareCrop((base_config["image_dim"], base_config["image_dim"])),
        ColorJitter(0.8, 0.8, 0.8, 0.2),
        RandomGrayscale(p=0.2),
        ToTensor(),
    ]),
}
# Assemble the final hyperparameter dict for the experiment being run.
# Copy base_config first: the original assigned `hyperparameters = base_config`,
# so the item-assignment and update() below mutated base_config itself.
hyperparameters = dict(base_config)
hyperparameters["dataloader_params"] = dataloader_params
# Select the experiment; its keys override the base settings.
hyperparameters.update(Baseline_1a_config)