-
Notifications
You must be signed in to change notification settings - Fork 73
/
Copy pathtrain_caption.py
208 lines (176 loc) · 10.3 KB
/
train_caption.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE
import argparse
import numpy as np
import random
import time
import functools
import json
import torch
import os

try:
    import ruamel_yaml as yaml
except ModuleNotFoundError:
    import ruamel.yaml as yaml

from accelerate import Accelerator, FullyShardedDataParallelPlugin
from model.prismer_caption import PrismerCaption
from model.modules.utils import interpolate_pos_embed
from dataset import create_dataset, create_loader
from utils import *
from tqdm import tqdm
# Command-line interface of the caption fine-tuning / evaluation script.
parser = argparse.ArgumentParser()
# Accepted on the command line but not read anywhere else in this file.
parser.add_argument('--mode', default='')
parser.add_argument('--port', default='')
# YAML config path; --target_dataset selects the top-level section used from it.
parser.add_argument('--config', default='configs/caption.yaml')
# Resume from logging/caption_<exp_name> instead of the pretrain weights.
parser.add_argument('--from_checkpoint', action='store_true')
# Skip training and only run the final evaluation pass.
parser.add_argument('--evaluate', action='store_true')
parser.add_argument('--target_dataset', type=str, default='coco')
# FSDP sharding modes (ZeRO 2 / ZeRO 3); --shard_grad_op wins if both are given.
parser.add_argument('--shard_grad_op', action='store_true')
parser.add_argument('--full_shard', action='store_true')
# Suffix of the logging/ directories used for loading and saving.
parser.add_argument('--exp_name', type=str, default='')
# Passed straight to Accelerator(mixed_precision=...).
parser.add_argument('--mixed_precision', type=str, default='fp16')
parser.add_argument('--seed', type=int, default=42)
args = parser.parse_args()
# Load the section of the YAML config for the selected dataset. The file is closed
# deterministically (the original left the handle from open() dangling).
# NOTE(review): yaml.Loader is ruamel's full, non-safe loader -- fine for trusted local
# config files, but do not point --config at untrusted input.
with open(args.config, 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)[args.target_dataset]

# Seed every RNG the script touches (torch, numpy, Python's random).
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# Train/test splits for the caption task, with loaders sized from the config.
train_dataset, test_dataset = create_dataset('caption', config)
train_loader = create_loader(train_dataset, batch_size=config['batch_size_train'], num_workers=8, train=True)
test_loader = create_loader(test_dataset, batch_size=config['batch_size_test'], num_workers=8, train=False)
model = PrismerCaption(config)
tokenizer = model.tokenizer

if args.shard_grad_op or args.full_shard:  # Model sharding with FSDP (ZeRO 2 or ZeRO 3)
    from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, ShardingStrategy, StateDictType

    # The FSDP param/reduce/buffer dtype follows --mixed_precision. bf16 is used when bf16 is
    # requested; otherwise fp16 is kept. Always using fp16 silently fought a bf16 request.
    fsdp_dtype = torch.bfloat16 if args.mixed_precision == 'bf16' else torch.float16
    # Settings shared by both sharding strategies.
    fsdp_kwargs = dict(
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        mixed_precision_policy=MixedPrecision(param_dtype=fsdp_dtype,
                                              reduce_dtype=fsdp_dtype,
                                              buffer_dtype=fsdp_dtype),
        state_dict_type=StateDictType.FULL_STATE_DICT,
        ignored_modules=model.ignored_modules,
    )

    if args.shard_grad_op:  # ZeRO 2 (ShardingStrategy.SHARD_GRAD_OP); takes precedence over --full_shard
        fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
                                                     **fsdp_kwargs)
    else:  # ZeRO 3 (ShardingStrategy.FULL_SHARD), wrapping each listed transformer layer type separately
        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
        from model.modules.vit import ResidualAttentionBlock
        from model.modules.resampler import PerceiverAttentionBlock
        from model.modules.roberta import RobertaLayer

        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                ResidualAttentionBlock,
                PerceiverAttentionBlock,
                RobertaLayer,
            },
        )
        fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.FULL_SHARD,
                                                     auto_wrap_policy=auto_wrap_policy,
                                                     **fsdp_kwargs)

    accelerator = Accelerator(mixed_precision=args.mixed_precision, fsdp_plugin=fsdp_plugin)
    # The sharded model is prepared now; the optimizer and loaders are prepared later.
    model = accelerator.prepare(model)
else:
    accelerator = Accelerator(mixed_precision=args.mixed_precision)
# Reload saved states
if args.from_checkpoint:
    # Resume this experiment's caption fine-tuning. The epoch counter continues after the
    # last saved epoch when epoch.pt exists, otherwise training restarts at epoch 0.
    state_dict = torch.load(f'logging/caption_{args.exp_name}/pytorch_model.bin', map_location='cpu')
    start_epoch = (torch.load(f'logging/caption_{args.exp_name}/epoch.pt')[0] + 1
                   if os.path.exists(f'logging/caption_{args.exp_name}/epoch.pt')
                   else 0)
    model.load_state_dict(state_dict)
    accelerator.print(f'Start re-training from checkpoint with Epoch {start_epoch}')
else:
    # Start from the pre-trained weights. The expert encoder's positional embedding is
    # resized with interpolate_pos_embed to this model's embedding length before loading.
    state_dict = torch.load(f'logging/pretrain_{args.exp_name}/pytorch_model.bin', map_location='cpu')
    state_dict['expert_encoder.positional_embedding'] = interpolate_pos_embed(
        state_dict['expert_encoder.positional_embedding'],
        len(model.expert_encoder.positional_embedding),
    )
    model.load_state_dict(state_dict)
    start_epoch = 0
# AdamW over the trainable parameters only (parameters with requires_grad=False are skipped).
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = torch.optim.AdamW(params=trainable_params,
                              lr=config['init_lr'],
                              weight_decay=config['weight_decay'])
del trainable_params
if not (args.shard_grad_op or args.full_shard):
    # Unsharded path: the model has not been through accelerator.prepare yet.
    model, optimizer, train_loader, test_loader = accelerator.prepare(model, optimizer, train_loader, test_loader)
else:
    # Sharded paths already called accelerator.prepare(model) when FSDP was configured.
    optimizer, train_loader, test_loader = accelerator.prepare(optimizer, train_loader, test_loader)
start_time = time.time()


def _caption_text(caption):
    """Return one generated caption as plain text.

    In distributed runs the captions were re-tokenised into padded id tensors (so they
    can be gathered) and are decoded back here. Outside distributed runs the model output
    is used unchanged. The distributed branch passes that same output to ``tokenizer(...)``,
    so it appears to be text, and ``tokenizer.decode`` would fail on a str. A string is
    therefore returned as-is. TODO confirm against PrismerCaption's generate output.
    """
    if isinstance(caption, str):
        return caption
    return tokenizer.decode(caption, skip_special_tokens=True)


def _caption_record(data_id, caption):
    """Build the result entry for one test sample, or None for an unsupported dataset."""
    text = caption.capitalize() + '.'
    if args.target_dataset == 'coco':
        # The image id is the last '_'-separated field of the file stem. splitext drops the
        # extension, whereas str.strip('.jpg') strips a set of characters, not a suffix.
        file_name = test_loader.dataset.data_list[data_id]['image'].split('/')[-1]
        stem = os.path.splitext(file_name)[0]
        return {"image_id": int(stem.split('_')[-1]), "caption": text}
    if args.target_dataset == 'nocaps':
        return {"image_id": test_loader.dataset.data_list[data_id]['img_id'], "caption": text}
    return None


def _generate_captions():
    """Caption the whole test set.

    Returns the result records on the main process and [] on the others. Only the main
    process writes and evaluates them.
    """
    model.eval()
    records = []
    with torch.no_grad():
        for experts, data_ids in tqdm(test_loader):
            captions = model(experts, train=False, prefix=config['prefix'])
            if accelerator.use_distributed:
                # Fixed-length id tensors so captions from every process can be gathered.
                # truncation=True keeps every row at max_length so the batch stacks.
                captions = tokenizer(captions, max_length=30, padding='max_length', truncation=True,
                                     return_tensors='pt').input_ids
                captions = captions.to(experts['rgb'].device)
                data_ids, captions = accelerator.gather_for_metrics((data_ids, captions))
            if not accelerator.is_main_process:
                continue
            for data_id, caption in zip(data_ids, captions):
                record = _caption_record(data_id, _caption_text(caption))
                if record is not None:
                    records.append(record)
    return records


def _dump_results(records):
    """Write *records* to the results JSON and return the COCO evaluation (None for other datasets)."""
    with open(f'/results/caption_results_{args.exp_name}_{args.target_dataset}.json', 'w') as f:
        json.dump(records, f)
    if args.target_dataset == 'coco':
        return coco_caption_eval(f'{config["data_path"]}/coco_karpathy_test_gt.json', records)
    return None


if not args.evaluate:
    ckpt_dir = f'logging/caption_{args.exp_name}'
    for epoch in range(start_epoch, config['max_epoch']):
        train_loss = 0
        num_train_elems = 0
        model.train()
        for i, (experts, caption) in enumerate(tqdm(train_loader)):
            cosine_lr_schedule(optimizer, epoch * len(train_loader) + i, config['max_epoch'] * len(train_loader),
                               config['init_lr'], config['min_lr'])
            loss = model(experts, caption, prefix=config['prefix'])
            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
            train_loss += loss.item()
            num_train_elems += 1
        train_loss /= max(num_train_elems, 1)  # guard against an empty train loader

        result = _generate_captions()
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            coco_eval = _dump_results(result)
            if coco_eval is not None:
                # On a fresh run this directory does not exist until the first save_state.
                os.makedirs(ckpt_dir, exist_ok=True)
                # temp_cider.pt hands this epoch's score to every process (read after the barrier).
                # float() turns the score into a plain Python float before saving, so it reloads
                # under torch.load(weights_only=True) even if the evaluator returns a numpy scalar.
                torch.save([float(coco_eval.eval['CIDEr'])], f'{ckpt_dir}/temp_cider.pt')
        accelerator.wait_for_everyone()
        accelerator.print(f"Epoch {epoch:03d} | loss: {train_loss:.4f} || Time: {(time.time() - start_time):.4f}")

        if args.target_dataset == 'coco':
            best_path = f'{ckpt_dir}/cider.pt'
            # Without a stored best score, treat the best as -inf so the first evaluated epoch is
            # checkpointed. Seeding cider.pt with the epoch's own score meant it was never saved.
            cider = torch.load(best_path)[0] if os.path.isfile(best_path) else float('-inf')
            curr_cider = torch.load(f'{ckpt_dir}/temp_cider.pt')[0]
            improved = cider < curr_cider
        else:
            # Only COCO is scored above, so there is no metric to select on. Keep the latest
            # epoch rather than loading score files that were never written.
            curr_cider = None
            improved = True
        # All processes must finish reading the score files before rank 0 rewrites cider.pt,
        # otherwise ranks could disagree about whether to call save_state.
        accelerator.wait_for_everyone()
        if improved:
            accelerator.save_state(ckpt_dir)
            accelerator.save([epoch], f'{ckpt_dir}/epoch.pt')
            if curr_cider is not None:
                accelerator.save([curr_cider], f'{ckpt_dir}/cider.pt')

# Final evaluation pass with the model's current weights.
# NOTE(review): after training these are the last epoch's weights, not necessarily the best checkpoint.
result = _generate_captions()
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    _dump_results(result)