-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathmain.py
60 lines (46 loc) · 1.79 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# -*- coding: utf-8 -*-
import torch
from torch.backends import cudnn
import os
import random
from option import get_option
from trainer import Trainer
from utils import save_option
import data_loader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data as data
def backend_setting(option):
    """Prepare the output directory, RNG seeds and device flags for a run.

    Mutates ``option`` in place:

    * creates ``os.path.join(option.save_dir, option.exp_name)`` if missing;
    * picks a random seed in [1, 10000] when ``option.random_seed`` is None;
    * seeds Python's ``random`` module and torch's CPU RNG, plus all CUDA
      devices when CUDA is used;
    * forces ``option.cuda`` to False when CUDA is requested but unavailable;
    * sets ``option.is_train`` to True when ``option.train_baseline`` is truthy.

    Args:
        option: Attribute-style run options. Reads ``save_dir``, ``exp_name``,
            ``random_seed``, ``cuda``, ``cudnn_benchmark`` and
            ``train_baseline``; may write ``random_seed``, ``cuda`` and
            ``is_train``.

    Raises:
        FileExistsError: if the log-directory path exists but is not a
            directory.
    """
    log_dir = os.path.join(option.save_dir, option.exp_name)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)

    if option.random_seed is None:
        option.random_seed = random.randint(1, 10000)
    # Seed Python's RNG too, not just torch's, so that code relying on
    # ``random`` is reproducible from the recorded seed.
    random.seed(option.random_seed)
    torch.manual_seed(option.random_seed)

    cuda_available = torch.cuda.is_available()
    if cuda_available and not option.cuda:
        print('WARNING: GPU is available, but not use it')
    if option.cuda and not cuda_available:
        # Previously this fallback happened silently; make it visible.
        print('WARNING: CUDA requested but not available, falling back to CPU')
        option.cuda = False

    if option.cuda:
        # Number GPUs by PCI bus id so device indices match nvidia-smi.
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        # NOTE(review): device selection via CUDA_VISIBLE_DEVICES was disabled
        # in the original:
        # os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in option.gpu_ids])
        torch.cuda.manual_seed_all(option.random_seed)
        cudnn.benchmark = option.cudnn_benchmark

    if option.train_baseline:
        option.is_train = True
def main():
    """Entry point: parse options, set up the backend, then train or validate.

    Builds a single ``torch.utils.data.DataLoader`` (shuffled, in both modes)
    over ``data_loader.WholeDataLoader(option)``. If ``option.is_train`` is
    truthy, the options are saved and ``trainer.train`` is run on that
    loader; otherwise ``trainer._validate`` is run on the same loader.
    """
    option = get_option()
    backend_setting(option)
    trainer = Trainer(option)

    # WholeDataLoader is passed to DataLoader as its dataset.
    dataset = data_loader.WholeDataLoader(option)
    loader_kwargs = dict(
        batch_size=option.batch_size,
        shuffle=True,
        num_workers=option.num_workers,
    )
    trainval_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

    if not option.is_train:
        trainer._validate(trainval_loader)
        return

    save_option(option)
    trainer.train(trainval_loader)
if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not when imported.
    main()