-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
128 lines (111 loc) · 5.29 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import os
import os.path as osp
import warnings
warnings.filterwarnings("ignore")
import mmcv
import torch
from mmcv import DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)
from mogen.apis import multi_gpu_test, single_gpu_test
from mogen.datasets import build_dataloader, build_dataset
from mogen.models import build_architecture
def str2bool(value):
    """Convert a command-line string such as ``'false'`` or ``'1'`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty string -- including ``'False'`` and ``'0'`` -- becomes ``True``.
    This converter parses the text explicitly instead.

    Args:
        value: The raw option text, or an already-parsed ``bool``.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value (true/false), got {value!r}')


def parse_args():
    """Parse the command-line options of the mogen evaluation script.

    Side effect: if the ``LOCAL_RANK`` environment variable is not already set,
    it is set from ``--local_rank`` / ``--local-rank``.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description='mogen evaluation')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--work-dir',
                        help='the dir to save evaluation results')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--out', help='output result file')
    parser.add_argument('--gpu_collect',
                        action='store_true',
                        help='whether to use gpu to collect results')
    parser.add_argument('--tmpdir', help='tmp dir for writing some results')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch', 'slurm', 'mpi'],
                        default='none',
                        help='job launcher')
    # Newer torch.distributed launchers pass ``--local-rank``; accept both
    # spellings. ``dest`` stays ``local_rank`` (taken from the first long flag).
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument('--device',
                        choices=['cpu', 'cuda'],
                        default='cuda',
                        help='device used for testing')
    parser.add_argument('--repaint',
                        action='store_true',
                        help='whether to use repaint for a long sequence')
    parser.add_argument('--overlap_len',
                        type=int,
                        default=0,
                        help='Fix the initial N frames for this clip')
    parser.add_argument('--fix_very_first',
                        action='store_true',
                        help='Fix the very first {overlap_len} frames for this '
                        'video to be the same as GT')
    parser.add_argument('--same_overlap_noisy',
                        action="store_true",
                        help='During the outpainting process, use the same '
                        'overlapping noisyGT')
    parser.add_argument('--no_resample',
                        action="store_true",
                        help='Do not use resample during inpainting based '
                        'sampling')
    parser.add_argument("--timestep_respacing",
                        type=str,
                        default='ddim1000',
                        help="Set ddim steps 'ddim{STEP}'")
    parser.add_argument('--jump_n_sample',
                        type=int,
                        default=5,
                        help='hyperparameter for resampling')
    parser.add_argument('--jump_length',
                        type=int,
                        default=3,
                        help='hyperparameter for resampling')
    # ``type=bool`` would turn ``--addBlend False`` into True; parse explicitly.
    parser.add_argument('--addBlend',
                        type=str2bool,
                        default=True,
                        help='Blend in the overlapping region at the last two '
                        'denoise steps')
    parser.add_argument('--no_repaint',
                        action="store_true",
                        help='Do not perform repaint during long-form '
                        'generation')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def main():
    """Evaluate a mogen model on the test split described by a config file.

    Steps:
      1. Load the config from ``args.config`` and merge any ``--cfg-options``.
      2. Initialise the distributed environment unless ``--launcher none``.
      3. Build the test dataset and its dataloader.
      4. Build the model, passing the parsed CLI args as ``cfg.model['opt']``.
      5. Load ``args.checkpoint``, unless the model's ``inference_type`` is
         ``'gt'``.
      6. Run single- or multi-GPU inference.
      7. On rank 0, evaluate the outputs into the work dir, print each metric,
         and optionally dump the metrics to ``--out``.
    """
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.data.test.test_mode = True

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.get('dist_params', {}))

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    print('eval_dataset', len(dataset))
    # the extra round_up data will be removed during gpu/cpu collect
    data_loader = build_dataloader(dataset,
                                   samples_per_gpu=cfg.data.samples_per_gpu,
                                   workers_per_gpu=cfg.data.workers_per_gpu,
                                   dist=distributed,
                                   shuffle=False,
                                   round_up=False)

    # build the model and load checkpoint
    cfg.model['opt'] = args
    model = build_architecture(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    # 'gt' inference needs no trained weights, so the checkpoint is skipped.
    if cfg.model.get('inference_type') != 'gt':
        load_checkpoint(model, args.checkpoint, map_location='cpu')

    if not distributed:
        if args.device == 'cpu':
            model = model.cpu()
        else:
            model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    # --work-dir is optional on the command line; osp.abspath(None) would
    # raise TypeError, so fall back to the config's work_dir and then to
    # ./work_dirs/<config name>. args itself is left untouched because the
    # model holds a reference to it via cfg.model['opt'].
    work_dir = args.work_dir
    if work_dir is None:
        work_dir = cfg.get('work_dir', None)
    if work_dir is None:
        work_dir = osp.join('./work_dirs',
                            osp.splitext(osp.basename(args.config))[0])

    rank, _ = get_dist_info()
    if rank == 0:
        mmcv.mkdir_or_exist(osp.abspath(work_dir))
        results = dataset.evaluate(outputs, work_dir)
        for k, v in results.items():
            try:
                value_text = f'{v:.4f}'
            except (TypeError, ValueError):
                # Non-numeric metric values (e.g. strings or dicts) do not
                # accept a float format spec; print them verbatim instead.
                value_text = str(v)
            print(f'\n{k} : {value_text}')
        if args.out:
            print(f'\nwriting results to {args.out}')
            mmcv.dump(results, args.out)
if __name__ == "__main__":
    # Start the evaluation only when run as a script, not when imported.
    main()