-
Notifications
You must be signed in to change notification settings - Fork 42
/
run.py
157 lines (140 loc) · 4.07 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# -*- coding: utf-8 -*-
#
# @File: run.py
# @Author: Haozhe Xie
# @Date: 2023-04-05 21:27:22
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2023-05-22 10:45:47
# @Email: [email protected]
import argparse
import cv2
import importlib
import logging
import torch
import os
import sys
import core.vqgan
import core.sampler
import core.gancraft
import utils.distributed
from pprint import pprint
from datetime import datetime
# Fix deadlock in DataLoader: a thread count of 0 makes OpenCV disable its own
# threading and run cv2 functions sequentially.
# NOTE(review): presumably the deadlock comes from OpenCV's internal thread
# pool interacting with forked DataLoader worker processes — confirm.
cv2.setNumThreads(0)
def get_args_from_command_line():
    """Declare this script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``exp_name``, ``cfg_file``,
        ``network``, ``gpus``, ``ckpt``, ``run_id``, ``test`` and
        ``local_rank``.
    """
    # Each entry is (option flags, add_argument keyword arguments); the list
    # order is the order the options are registered and shown in --help.
    # Defaults that depend on the environment (current time, LOCAL_RANK) are
    # evaluated here, i.e. on every call.
    option_specs = [
        (
            ("-e", "--exp"),
            dict(
                dest="exp_name",
                help="The name of the experiment",
                default=str(datetime.now()),
                type=str,
            ),
        ),
        (
            ("-c", "--cfg"),
            dict(
                dest="cfg_file",
                help="Path to the config.py file",
                default="config.py",
                type=str,
            ),
        ),
        (
            ("-n", "--network"),
            dict(
                dest="network",
                help="The network name to train or test. ['VQGAN', 'Sampler', 'GANCraft']",
                default=None,
                type=str,
            ),
        ),
        (
            ("-g", "--gpus"),
            dict(
                dest="gpus",
                help="The GPU device to use (e.g., 0,1,2,3).",
                default=None,
                type=str,
            ),
        ),
        (
            ("-p", "--ckpt"),
            dict(
                dest="ckpt",
                help="Initialize the network from a pretrained model.",
                default=None,
            ),
        ),
        (
            ("-r", "--run"),
            dict(
                dest="run_id",
                help="The unique run ID for WandB",
                default=None,
                type=str,
            ),
        ),
        (
            ("--test",),
            dict(dest="test", help="Test the network.", action="store_true"),
        ),
        (
            ("--local_rank",),
            dict(
                type=int,
                help="The rank ID of the GPU. Automatically assigned by torch.distributed.",
                # A string taken from the environment is passed through type=int by argparse.
                default=os.getenv("LOCAL_RANK", 0),
            ),
        ),
    ]

    cli = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def _load_config(cfg_file):
    """Execute the Python config file ``cfg_file`` and return the ``__C`` it defines.

    The file runs with this module's globals visible, as the previous bare
    ``exec`` inside ``main`` did, but into a dedicated locals dict from which
    ``__C`` is read back. Reading the result through ``locals()`` after a bare
    ``exec`` in a function no longer works on Python 3.13+ (PEP 667: each
    ``locals()`` call returns a fresh snapshot).

    Raises:
        KeyError: if the config file does not define ``__C``.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(cfg_file, "rb") as f:
        source = f.read()
    cfg_namespace = {}
    # NOTE(review): exec runs arbitrary code; cfg_file must be a trusted local file.
    exec(compile(source, cfg_file, "exec"), globals(), cfg_namespace)
    return cfg_namespace["__C"]


def main():
    """Parse CLI arguments, load the config, then train or test the selected network.

    Raises:
        Exception: if ``--run`` is given without ``--ckpt``, or if
            ``cfg.CONST.NETWORK`` is not one of the supported networks.

    In test mode, the process exits with status 2 when no checkpoint is
    configured or the configured checkpoint path does not exist.
    """
    # Get args from command line
    args = get_args_from_command_line()
    # Read the experimental config
    cfg = _load_config(args.cfg_file)

    # Parse runtime arguments: CLI values override the config file
    if args.gpus is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    # --exp always has a timestamp default, so this always overrides EXP_NAME
    if args.exp_name is not None:
        cfg.CONST.EXP_NAME = args.exp_name
    if args.network is not None:
        cfg.CONST.NETWORK = args.network
    if args.ckpt is not None:
        cfg.CONST.CKPT = args.ckpt
    if args.run_id is not None:
        cfg.WANDB.RUN_ID = args.run_id
    if args.run_id is not None and args.ckpt is None:
        raise Exception("No checkpoints")

    # Print the current config (only on local rank 0)
    local_rank = args.local_rank
    if local_rank == 0:
        pprint(cfg)

    # Initialize the DDP environment (training only)
    if torch.cuda.is_available() and not args.test:
        utils.distributed.set_affinity(local_rank)
        utils.distributed.init_dist(local_rank)

    if args.test:
        # A missing or None CKPT must yield the friendly error below rather than
        # a TypeError from os.path.exists(None).
        ckpt = cfg.CONST.CKPT if "CKPT" in cfg.CONST else None
        if ckpt is None or not os.path.exists(ckpt):
            logging.error("Please specify the file path of checkpoint.")
            sys.exit(2)

    # Each supported network module provides train(cfg) and test(cfg).
    networks = {
        "VQGAN": core.vqgan,
        "Sampler": core.sampler,
        "GANCraft": core.gancraft,
    }
    network = networks.get(cfg.CONST.NETWORK)
    if network is None:
        raise Exception("Unknown network: %s" % cfg.CONST.NETWORK)

    # Start train/test processes
    if args.test:
        network.test(cfg)
    else:
        network.train(cfg)
if __name__ == "__main__":
    # logging.basicConfig() does nothing once the root logger already has
    # handlers (e.g. attached while importing libraries); reloading the
    # logging module first gives basicConfig() a clean root logger.
    # References: https://stackoverflow.com/a/53553516/1841143
    importlib.reload(logging)
    log_format = "[%(levelname)s] %(asctime)s %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    main()