diff --git a/.gitignore b/.gitignore index b92e35f0..226e0d4d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ .idea/* .DS_Store -.vscode \ No newline at end of file +.vscode +__pycache__ +logs/ +vqa_logs/ \ No newline at end of file diff --git a/README.md b/README.md index dedeb439..42856b40 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,7 @@ We list the parameters and pretrained checkpoints of OFAs below. For finetuned c

# Results -Below we demonstrate the results of OFAs on cross-modal understanding and generation. +Below we demonstrate the results of OFAs on cross-modal understanding and generation. You can find more results of the MuE model in [MuE](https://arxiv.org/abs/2211.11152). @@ -254,6 +254,9 @@ We provide procedures to reproduce our results of image captioning on our paper cd run_scripts/caption nohup sh train_caption_stage1.sh > train_stage1.out & # stage 1, train with cross-entropy loss nohup sh train_caption_stage2.sh > train_stage2.out & # stage 2, load the best ckpt of stage1 and train with CIDEr optimization +# If you need to finetune the MuE model, use the following script +nohup sh train_caption_stage1_base_MuE.sh > train_stage1.out & +# Stage 2 uses the same script as above
@@ -263,6 +266,9 @@ nohup sh train_caption_stage2.sh > train_stage2.out & # stage 2, load the best

 cd run_scripts/caption ; sh evaluate_caption.sh  # inference & evaluate
+# If you want to evaluate your MuE model
+sh evaluate_caption_base_MuE.sh
+# You can adjust img_thres, txt_thres, and decoder_thres to trade off accuracy against inference speed.
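The three thresholds gate a per-layer similarity check inside the MuE encoder and decoder (see the unify_transformer.py changes further down): after each layer, the new hidden state is compared with the previous one, and the stack exits once they are nearly identical. Below is a minimal sketch of that check; `should_exit` and `layers` are illustrative names, not identifiers from this repository.

```python
import torch
import torch.nn.functional as F

def should_exit(current: torch.Tensor, previous: torch.Tensor, thres: float) -> bool:
    """Exit once two consecutive layer outputs are almost identical.

    Mirrors the MuE check: flatten both states, L2-normalize them, and
    compare their cosine similarity against a threshold.
    """
    sim = torch.cosine_similarity(
        F.normalize(current.contiguous().view(1, -1)),
        F.normalize(previous.contiguous().view(1, -1)),
    )
    return sim.item() > thres

# Illustrative usage: stop stacking layers once the representation saturates.
layers = [torch.nn.Linear(8, 8) for _ in range(6)]
x = torch.randn(2, 8)
previous = x
for idx, layer in enumerate(layers):
    x = layer(x)
    if should_exit(x, previous, thres=0.9):
        break
    previous = x
print(f"ran {idx + 1} layers")
```

Lower thresholds make exiting easier (faster but potentially less accurate); the default of 1.0 effectively disables early exiting.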
 
@@ -429,6 +435,8 @@ We provide steps for you to reproduce our results in visual entailment. See the
 cd run_scripts/snli_ve
 nohup sh train_snli_ve.sh > train_snli_ve.out &  # finetune for snli_ve
+# If you need to finetune the MuE model, use the following script
+nohup sh train_snli_ve_base_MuE.sh > train_snli_ve_MuE.out &
 
@@ -438,6 +446,9 @@ nohup sh train_snli_ve.sh > train_snli_ve.out & # finetune for snli_ve

 cd run_scripts/snli_ve ; sh evaluate_snli_ve.sh dev  # specify 'dev' or 'test'
+# If you want to evaluate your MuE model
+sh evaluate_snli_ve_base_MuE.sh
+# You can adjust img_thres, txt_thres, and decoder_thres to trade off accuracy against inference speed.
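Note that during generation the decoder threshold set here is not applied verbatim at every step: the sequence_generator.py change further down keeps it at 1.0 for the first third of the decoding steps and only then switches to a slowly decaying value. A standalone sketch of that schedule (the function name is illustrative):

```python
import math

def decoder_threshold(step: int, max_len: int, decoder_thres: float) -> float:
    """Per-step decoder exit threshold, mirroring models/sequence_generator.py:
    no early exiting during the first third of the sequence, then a blend of
    the configured threshold with a slowly decaying term."""
    if step < (max_len + 1) / 3:
        return 1.0
    return 0.9 * decoder_thres + 0.1 * math.exp(-(step + 1) / (max_len + 1))

# Illustrative usage for a 16-step generation with --decoder_thres=0.7
for step in range(16):
    print(step, round(decoder_threshold(step, max_len=15, decoder_thres=0.7), 4))
```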
 
@@ -600,5 +611,20 @@ Please cite our paper if you find it helpful :) volume = {abs/2202.03052}, year = {2022} } + +@article{tang2022you, + title={You Need Multiple Exiting: Dynamic Early Exiting for Accelerating Unified Vision Language Model}, + author={Tang, Shengkun and + Wang, Yaqing and + Kong, Zhenglun and + Zhang, Tianchi and + Li, Yao and + Ding, Caiwen and + Wang, Yanzhi and + Liang, Yi and + Xu, Dongkuan}, + journal={arXiv preprint arXiv:2211.11152}, + year={2022} +} ```
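The MuE_Task_Loss criterion added in the criterions/label_smoothed_cross_entropy.py hunk below supervises every decoder exit rather than only the final layer: it computes the label-smoothed loss on each entry of inner_out_states and sums the results. A minimal sketch of that aggregation, with plain NLL standing in for the repository's label_smoothed_nll_loss and illustrative tensor shapes:

```python
import torch
import torch.nn.functional as F

def mue_task_loss(inner_out_states, target, padding_idx=1):
    """Sum a per-exit loss over the logits produced at every decoder exit.

    `inner_out_states` is a list of (batch, tgt_len, vocab) logit tensors,
    one per decoder layer that has an exit head; padded positions are dropped
    before the loss is computed, as in the criterion below.
    """
    loss_all = 0.0
    for logits in inner_out_states:
        lprobs = F.log_softmax(logits, dim=-1).view(-1, logits.size(-1))
        flat_target = target.view(-1)
        keep = flat_target != padding_idx
        loss_all = loss_all + F.nll_loss(lprobs[keep], flat_target[keep], reduction="sum")
    return loss_all

# Illustrative usage with random logits from three hypothetical exits.
vocab, bsz, tgt_len = 50, 2, 5
exits = [torch.randn(bsz, tgt_len, vocab) for _ in range(3)]
target = torch.randint(2, vocab, (bsz, tgt_len))
print(mue_task_loss(exits, target))
```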

diff --git a/criterions/label_smoothed_cross_entropy.py b/criterions/label_smoothed_cross_entropy.py index 65175677..b1850c69 100644 --- a/criterions/label_smoothed_cross_entropy.py +++ b/criterions/label_smoothed_cross_entropy.py @@ -217,9 +217,12 @@ def forward(self, model, sample, update_num=0, reduce=True): def get_lprobs_and_target(self, model, net_output, sample): conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1 constraint_masks = None + # some weird bug will occur without this operation and following out-place operation. + # This operation doesn't change logic. + net_output = list(net_output) if "constraint_masks" in sample and sample["constraint_masks"] is not None: constraint_masks = sample["constraint_masks"] - net_output[0].masked_fill_(~constraint_masks, -math.inf) + net_output[0] = net_output[0].masked_fill_(~constraint_masks, -math.inf) if self.constraint_start is not None and self.constraint_end is not None: net_output[0][:, :, 4:self.constraint_start] = -math.inf net_output[0][:, :, self.constraint_end:] = -math.inf @@ -341,3 +344,65 @@ def logging_outputs_can_be_summed() -> bool: to True will improves distributed training speed. """ return True + +@register_criterion( + "MuE_Task_Loss", dataclass=AdjustLabelSmoothedCrossEntropyCriterionConfig +) +class MuE_Task_Loss(AdjustLabelSmoothedCrossEntropyCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=0, + ignore_eos=False, + report_accuracy=False, + drop_worst_ratio=0, + drop_worst_after=0, + use_rdrop=False, + reg_alpha=1.0, + sample_patch_num=196, + constraint_range=None + ): + super().__init__( + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + ignore_eos, + report_accuracy, + drop_worst_ratio, + drop_worst_after, + use_rdrop, + reg_alpha, + sample_patch_num, + constraint_range) + + def compute_loss(self, model, net_output, sample, update_num, reduce=True): + loss_all = 0.0 + nll_loss_all = 0.0 + ntokens = 0 + print("using MuE Task loss") + for state in net_output[1]["inner_out_states"]: + lprobs, target, constraint_masks = self.get_lprobs_and_target(model, [state], sample) + if constraint_masks is not None: + constraint_masks = constraint_masks[target != self.padding_idx] + lprobs = lprobs[target != self.padding_idx] + target = target[target != self.padding_idx] + loss, nll_loss, ntokens = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + update_num, + reduce=reduce, + drop_worst_ratio=self.drop_worst_ratio, + drop_worst_after=self.drop_worst_after, + use_rdrop=self.use_rdrop, + reg_alpha=self.reg_alpha, + constraint_masks=constraint_masks, + constraint_start=self.constraint_start, + constraint_end=self.constraint_end + ) + loss_all += loss + nll_loss_all += nll_loss + return loss_all, nll_loss_all, ntokens diff --git a/evaluate.py b/evaluate.py index e41d8a30..b460ce8d 100644 --- a/evaluate.py +++ b/evaluate.py @@ -160,8 +160,8 @@ def main(cfg: DictConfig, **kwargs): score_sum += sum([s[0] for s in scores]) score_cnt += sum([s[1] for s in scores]) else: - score_sum += sum(scores) if scores is not None else 0 - score_cnt += len(scores) if scores is not None else 0 + score_sum += sum(scores) + score_cnt += len(scores) progress.log({"sentences": sample["nsentences"]}) @@ -173,10 +173,17 @@ def cli_main(): parser.add_argument("--ema-eval", action='store_true', help="Use EMA weights to make evaluation.") parser.add_argument("--beam-search-vqa-eval", action='store_true', help="Use beam search for vqa 
evaluation (faster inference speed but sub-optimal result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.") parser.add_argument("--zero-shot", action='store_true') + parser.add_argument('--img_thres', type=float, metavar='D', default=1.0, + help='image theshold for early exiting model') + parser.add_argument('--txt_thres', type=float, metavar='D', default=1.0, + help='text theshold for early exiting model') + parser.add_argument('--decoder_thres', type=float, metavar='D', default=1.0, + help='decoder theshold for early exiting model') args = options.parse_args_and_arch(parser) cfg = convert_namespace_to_omegaconf(args) distributed_utils.call_main( - cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval, zero_shot=args.zero_shot + cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval, zero_shot=args.zero_shot, + img_thres=args.img_thres, txt_thres=args.txt_thres, decoder_thres=args.decoder_thres, is_train=False ) diff --git a/fairseq/fairseq/tasks/fairseq_task.py b/fairseq/fairseq/tasks/fairseq_task.py index d671f17c..096a6d47 100644 --- a/fairseq/fairseq/tasks/fairseq_task.py +++ b/fairseq/fairseq/tasks/fairseq_task.py @@ -511,11 +511,11 @@ def build_dataset_for_inference( raise NotImplementedError def inference_step( - self, generator, models, sample, prefix_tokens=None, constraints=None + self, generator, models, sample, prefix_tokens=None, constraints=None, **kwargs ): with torch.no_grad(): return generator.generate( - models, sample, prefix_tokens=prefix_tokens, constraints=constraints + models, sample, prefix_tokens=prefix_tokens, constraints=constraints, **kwargs ) def begin_epoch(self, epoch, model): diff --git a/models/ofa/unify_multihead_attention.py b/models/ofa/unify_multihead_attention.py index ee6a1fad..7aff4566 100644 --- a/models/ofa/unify_multihead_attention.py +++ b/models/ofa/unify_multihead_attention.py @@ -265,7 +265,8 @@ def forward( .view(-1, bsz * self.num_heads, self.head_dim) .transpose(0, 1) ) - + + saved_state_new = {} if saved_state is not None: # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) if "prev_key" in saved_state: @@ -298,13 +299,15 @@ def forward( src_len=k.size(1), static_kv=static_kv, ) - - saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) - saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) - saved_state["prev_key_padding_mask"] = key_padding_mask + # There are reference bugs if change saved_state directly. + # This causes error during inference in early exiting models (MuE) + # However, this has no influence on original OFA models. 
+ saved_state_new["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state_new["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state_new["prev_key_padding_mask"] = key_padding_mask # In this branch incremental_state is never None assert incremental_state is not None - incremental_state = self._set_input_buffer(incremental_state, saved_state) + incremental_state = self._set_input_buffer(incremental_state, saved_state_new) assert k is not None assert k.size(1) == src_len diff --git a/models/ofa/unify_transformer.py b/models/ofa/unify_transformer.py index 3f2d04d2..43079ebb 100644 --- a/models/ofa/unify_transformer.py +++ b/models/ofa/unify_transformer.py @@ -178,6 +178,9 @@ def add_args(parser): parser.add_argument('--freeze-encoder', action='store_true', help='freeze the parameters in the encoder') + parser.add_argument('--train_mue', action='store_true', + help='use early exiting in model') + parser.add_argument('--adapter', action='store_true', help='use adapter in the model') @@ -424,16 +427,27 @@ def build_embedding(cls, args, dictionary, embed_dim, path=None): @classmethod def build_encoder(cls, args, src_dict, embed_tokens): - return TransformerEncoder(args, src_dict, embed_tokens) + if args.train_mue: + return TransformerEncoder_MuE(args, src_dict, embed_tokens) + else: + return TransformerEncoder(args, src_dict, embed_tokens) @classmethod def build_decoder(cls, args, tgt_dict, embed_tokens): - return TransformerDecoder( - args, - tgt_dict, - embed_tokens, - no_encoder_attn=getattr(args, "no_cross_attention", False), - ) + if args.train_mue: + return TransformerDecoder_MuE( + args, + tgt_dict, + embed_tokens, + no_encoder_attn=getattr(args, "no_cross_attention", False), + ) + else: + return TransformerDecoder( + args, + tgt_dict, + embed_tokens, + no_encoder_attn=getattr(args, "no_cross_attention", False), + ) # TorchScript doesn't support optional arguments with variable length (**kwargs). # Current workaround is to add union of all arguments in child classes. @@ -769,7 +783,8 @@ def forward( code_masks: Optional[torch.Tensor] = None, return_all_hiddens: bool = False, token_embeddings: Optional[torch.Tensor] = None, - sample_patch_num: Optional[int] = None + sample_patch_num: Optional[int] = None, + **kwargs ): """ Args: @@ -1320,6 +1335,7 @@ def forward( alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, return_all_hiddens: bool = False, + **kwargs ): """ Args: @@ -1649,6 +1665,578 @@ def upgrade_state_dict_named(self, state_dict, name): return state_dict +class TransformerEncoder_MuE(TransformerEncoder): + """ + MuE encoder model from `"You Need Multiple Exiting: Dynamic Early Exiting for + Accelerating Unified Vision Language Model" (Tang, et al, 2022) + `_. + + Transformer encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`TransformerEncoderLayer`. 
+ + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): encoding dictionary + embed_tokens (torch.nn.Embedding): input embedding + """ + + def __init__(self, args, dictionary, embed_tokens): + self.args = args + super().__init__(args, dictionary, embed_tokens) + + def forward_embedding( + self, + src_tokens, + image_embed: Optional[torch.Tensor] = None, + image_embed_2: Optional[torch.Tensor] = None, + token_embedding: Optional[torch.Tensor] = None, + pos_embed: Optional[torch.Tensor] = None, + image_pos_embed: Optional[torch.Tensor] = None, + image_pos_embed_2: Optional[torch.Tensor] = None + ): + # embed tokens and positions + if token_embedding is None: + token_embedding = self.embed_tokens(src_tokens) + x = embed = self.embed_scale * token_embedding + if self.entangle_position_embedding and pos_embed is not None: + x += pos_embed + if self.type_embedding is not None: + x += self.type_embedding(src_tokens.new_zeros(x.size()[:2])) + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + x = self.dropout_module(x) + if self.quant_noise is not None: + x = self.quant_noise(x) + + # embed raw images + if image_embed is not None: + image_embed = self.image_proj(image_embed) + image_x = image_embed = self.embed_scale * image_embed + if self.entangle_position_embedding and image_pos_embed is not None: + image_x += image_pos_embed + if self.type_embedding is not None: + image_x += self.type_embedding(src_tokens.new_ones(image_x.size()[:2])) + if self.patch_layernorm_embedding is not None: + image_x = self.patch_layernorm_embedding(image_x) + image_x = self.dropout_module(image_x) + if self.quant_noise is not None: + image_x = self.quant_noise(image_x) + # split image and text and send into encoder respectively. + x = [image_x, x] + embed = torch.cat([image_embed, embed], dim=1) + + if image_embed_2 is not None: + assert self.type_embedding is not None + image_embed_2 = self.image_proj(image_embed_2) + image_x_2 = image_embed_2 = self.embed_scale * image_embed_2 + if self.entangle_position_embedding and image_pos_embed_2 is not None: + image_x_2 += image_pos_embed_2 + if self.type_embedding is not None: + image_x_2 += self.type_embedding(src_tokens.new_full(image_x_2.size()[:2], fill_value=2)) + if self.patch_layernorm_embedding is not None: + image_x_2 = self.patch_layernorm_embedding(image_x_2) + image_x_2 = self.dropout_module(image_x_2) + if self.quant_noise is not None: + image_x_2 = self.quant_noise(image_x_2) + x = torch.cat([image_x_2, x], dim=1) + embed = torch.cat([image_embed_2, embed], dim=1) + + return x, embed + + def forward( + self, + src_tokens, + src_lengths, + is_train: bool = True, + patch_images: Optional[torch.Tensor] = None, + patch_images_2: Optional[torch.Tensor] = None, + patch_masks: Optional[torch.Tensor] = None, + code_masks: Optional[torch.Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + sample_patch_num: Optional[int] = None, + **kwargs + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). 
+ token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + """ + return self.forward_scriptable(src_tokens, + src_lengths, + is_train, + patch_images, + patch_images_2, + patch_masks, + return_all_hiddens, + token_embeddings, + sample_patch_num, + **kwargs) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def forward_scriptable( + self, + src_tokens, + src_lengths, + is_train: bool = True, + patch_images: Optional[torch.Tensor] = None, + patch_images_2: Optional[torch.Tensor] = None, + patch_masks: Optional[torch.Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + sample_patch_num: Optional[int] = None, + **kwargs + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. 
+ """ + image_embed = None + image_embed_2 = None + image_pos_embed = None + image_pos_embed_2 = None + if patch_images is not None: + image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \ + self.get_patch_images_info(patch_images, sample_patch_num, src_tokens.device) + image_padding_mask[~patch_masks] = True + if patch_images_2 is not None: + image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \ + self.get_patch_images_info(patch_images_2, sample_patch_num, src_tokens.device) + image_padding_mask_2[~patch_masks] = True + + encoder_padding_mask = src_tokens.eq(self.padding_idx) + if patch_images is not None: + encoder_padding_mask_cat = torch.cat([image_padding_mask, encoder_padding_mask], dim=1) + if patch_images_2 is not None: + encoder_padding_mask = torch.cat([image_padding_mask_2, encoder_padding_mask], dim=1) + has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any()) + + pos_embed = self.embed_positions(utils.new_arange(src_tokens)) + embed, encoder_embedding = self.forward_embedding( + src_tokens, image_embed, image_embed_2, token_embeddings, + pos_embed, image_pos_embed, image_pos_embed_2 + ) + x = embed[1] + image_x = embed[0] + + # account for padding while computing the representation + if has_pads: + x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) + image_x = image_x * (1 - image_padding_mask.unsqueeze(-1).type_as(image_x)) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + image_x = image_x.transpose(0, 1) + + pos_embed = self.pos_ln(pos_embed) + if patch_images is not None: + image_pos_embed = self.image_pos_ln(image_pos_embed) + pos_embed_cat = torch.cat([image_pos_embed, pos_embed], dim=1) + if patch_images_2 is not None: + image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) + pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) + + + pos_q = self.pos_q_linear(pos_embed).view( + x.size(1), x.size(0), self.num_attention_heads, -1 + ).transpose(1, 2) * self.pos_scaling + pos_k = self.pos_k_linear(pos_embed).view( + x.size(1), x.size(0), self.num_attention_heads, -1 + ).transpose(1, 2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + + pos_q = self.pos_q_linear(image_pos_embed).view( + image_x.size(1), image_x.size(0), self.num_attention_heads, -1 + ).transpose(1, 2) * self.pos_scaling + pos_k = self.pos_k_linear(image_pos_embed).view( + image_x.size(1), image_x.size(0), self.num_attention_heads, -1 + ).transpose(1, 2) + image_abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + + encoder_states = [] + encoder_states_img = [] + return_all_hiddens = True + if return_all_hiddens: + encoder_states.append(x) + encoder_states_img.append(image_x) + + # encoder layers + for idx, layer in enumerate(self.layers): + self_attn_bias = abs_pos_bias.clone() + self_attn_bias += self.get_rel_pos_bias(src_tokens, idx) + self_attn_bias = self_attn_bias.reshape(-1, x.size(0), x.size(0)) + + x = layer( + x, encoder_padding_mask=encoder_padding_mask if has_pads else None, self_attn_bias=self_attn_bias, + ) + if not is_train: + similarity = torch.cosine_similarity(F.normalize(x.clone().contiguous().view(1, -1)), + F.normalize(encoder_states[-1].clone().contiguous().view(1, -1))) + if similarity > kwargs["txt_thres"]: + break + if return_all_hiddens: + assert encoder_states is not None + encoder_states.append(x) + idx_y = idx + for idx, layer in enumerate(self.layers): + self_attn_bias = image_abs_pos_bias.clone() + if patch_images_2 is not None: + 
self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \ + self.get_image_rel_pos_bias(image_position_ids_2, idx) + self_attn_bias[:, :, image_num_patches_2:image_num_patches_2+image_num_patches, image_num_patches_2:image_num_patches_2+image_num_patches] += \ + self.get_image_rel_pos_bias(image_position_ids, idx) + elif patch_images is not None: + self_attn_bias += \ + self.get_image_rel_pos_bias(image_position_ids, idx) + self_attn_bias = image_abs_pos_bias.reshape(-1, image_x.size(0), image_x.size(0)) + + image_x = layer( + image_x, encoder_padding_mask=image_padding_mask if has_pads else None, self_attn_bias=self_attn_bias + ) + if not is_train: + similarity = torch.cosine_similarity(F.normalize(image_x.clone().contiguous().view(1, -1)), + F.normalize(encoder_states_img[-1].clone().contiguous().view(1, -1))) + if similarity > kwargs["img_thres"]: + break + if return_all_hiddens: + assert encoder_states_img is not None + encoder_states_img.append(image_x) + + x = torch.cat([image_x, x]) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `forward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask_cat], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + "position_embeddings": [pos_embed_cat], # B x T x C + "exit_layer": [idx+1, idx_y+1] + } + + +class TransformerDecoder_MuE(TransformerDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + ): + self.args = args + super().__init__(args, dictionary, embed_tokens, no_encoder_attn, output_projection,) + + def forward( + self, + prev_output_tokens, + is_train: bool = True, + code_masks: Optional[torch.Tensor] = None, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + **kwargs + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention, should be of size T x B x C + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). 
+ + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + x, extra = self.extract_features( + prev_output_tokens, + code_masks=code_masks, + encoder_out=encoder_out, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + is_train=is_train, + **kwargs + ) + + if not features_only: + x = self.output_layer(x) + return x, extra + + def extract_features( + self, + prev_output_tokens, + code_masks: Optional[torch.Tensor], + encoder_out: Optional[Dict[str, List[Tensor]]], + is_train: bool = True, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + **kwargs + ): + return self.extract_features_scriptable( + prev_output_tokens, + code_masks, + encoder_out, + is_train, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + **kwargs + ) + + """ + A scriptable subclass of this class has an extract_features method and calls + super().extract_features, but super() is not supported in torchscript. A copy of + this function is made to be used in the subclass instead. + """ + + def extract_features_scriptable( + self, + prev_output_tokens, + code_masks: Optional[torch.Tensor], + encoder_out: Optional[Dict[str, List[Tensor]]], + is_train: bool = True, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + **kwargs + ): + """ + Similar to *forward* but only return features. + + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). 
+ + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + bs, slen = prev_output_tokens.size() + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + enc: Optional[Tensor] = None + padding_mask: Optional[Tensor] = None + if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: + enc = encoder_out["encoder_out"][0] + assert ( + enc.size()[1] == bs + ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}" + if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: + padding_mask = encoder_out["encoder_padding_mask"][0] + + bsz, tgt_len = prev_output_tokens.shape + token_position_idx = utils.new_arange(prev_output_tokens) + tgt_pos_embed = self.embed_positions(token_position_idx) + if code_masks is not None and torch.any(code_masks): + image_position_idx = self.image_position_idx[:prev_output_tokens.size(1)].unsqueeze(0).expand(bsz, tgt_len) + tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[code_masks] + + # self attn position bias + self_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=False) + + if code_masks is not None and torch.any(code_masks): + self_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=True) + self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks] + # cross attn position bias + src_pos_embed = encoder_out['position_embeddings'][0] + cross_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed) + if code_masks is not None and torch.any(code_masks): + cross_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True) + cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[code_masks] + cross_abs_pos_bias = cross_abs_pos_bias.reshape(-1, *cross_abs_pos_bias.size()[-2:]) + + all_prev_output_tokens = prev_output_tokens.clone() + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :] + tgt_pos_embed = tgt_pos_embed[:, -1:, :] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if self.entangle_position_embedding is not None and not self.args.disable_entangle: + x += tgt_pos_embed + + if self.layernorm_embedding is not None: + if code_masks is None or not code_masks.any() or not getattr(self, "code_layernorm_embedding", False): + x = self.layernorm_embedding(x) + elif code_masks is not None and code_masks.all(): + x = self.code_layernorm_embedding(x) + else: + x[~code_masks] = self.layernorm_embedding(x[~code_masks]) + x[code_masks] = self.code_layernorm_embedding(x[code_masks]) + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + self_attn_padding_mask: Optional[Tensor] = None + if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + + # decoder layers + attn: Optional[Tensor] = None + inner_states: List[Optional[Tensor]] = [x] + inner_out_states: List[Optional[Tensor]] = [] + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + 
+ self_attn_bias = self_abs_pos_bias.clone() + if code_masks is None or not code_masks.any(): + self_attn_bias += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) + elif code_masks is not None and code_masks.all(): + self_attn_bias += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) + else: + self_attn_bias[~code_masks] += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) + self_attn_bias[code_masks] += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) + self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:]) + if incremental_state is not None: + self_attn_bias = self_attn_bias[:, -1:, :] + + x, layer_attn, saved_states = layer( + x, + enc, + padding_mask, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + self_attn_bias=self_attn_bias, + cross_attn_bias=cross_abs_pos_bias + ) + if not is_train: + similarity = torch.cosine_similarity(F.normalize(x.clone().contiguous().view(1, -1)), + F.normalize(inner_states[-1].clone().contiguous().view(1, -1))) + if similarity > kwargs["decoder_thres"]: + if saved_states is not None: + # Copy the state at early exited layer to skipped layers. + # Naive and simple operation but useful. Better methods could be explored. + for i in range(idx + 1, len(self.layers)): + incremental_state = self.layers[i].self_attn._set_input_buffer(incremental_state, saved_states[0]) + incremental_state = self.layers[i].encoder_attn._set_input_buffer(incremental_state, saved_states[1]) + break + inner_states.append(x) + inner_out_states.append(self.output_layer(self.layer_norm(x).transpose(0, 1))) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": [attn], "inner_states": inner_states, "exit_layer": idx + 1, "inner_out_states": inner_out_states} + def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False): m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) @@ -1733,3 +2321,4 @@ def base_architecture(args): args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + args.train_mue = getattr(args, "train_mue", False) diff --git a/models/sequence_generator.py b/models/sequence_generator.py index 13d5bdf7..4560bc1b 100644 --- a/models/sequence_generator.py +++ b/models/sequence_generator.py @@ -213,6 +213,7 @@ def _generate( prefix_tokens: Optional[Tensor] = None, constraints: Optional[Tensor] = None, bos_token: Optional[int] = None, + **kwargs ): model = EnsembleModel(models) incremental_states = torch.jit.annotate( @@ -223,7 +224,15 @@ def _generate( ], ) net_input = sample["net_input"] - + if "img_thres" in kwargs: + net_input["img_thres"] = kwargs["img_thres"] + if "txt_thres" in kwargs: + net_input["txt_thres"] = kwargs["txt_thres"] + if "is_train" in kwargs: + net_input["is_train"] = kwargs["is_train"] + 
if "decoder_thres" in kwargs: + net_input["decoder_thres"] = kwargs["decoder_thres"] + if "src_tokens" in net_input: src_tokens = net_input["src_tokens"] # length of the source text being the character length except EndOfSentence and pad @@ -333,6 +342,16 @@ def _generate( original_batch_idxs = torch.arange(0, bsz).type_as(tokens) for step in range(max_len + 1): # one extra step for EOS marker + + # Decay early exiting theshold for decoder layers + # small modification comparing with original paper. + # Details can be found in https://arxiv.org/abs/2211.11152 + if step < (max_len + 1) / 3: + decoder_thes = 1.0 + elif "decoder_thres" in kwargs: + decoder_thes = 0.9 * kwargs["decoder_thres"] + \ + 0.1 * math.exp(-1 * (step + 1) / (max_len + 1)) + # reorder decoder internal states based on the prev choice of beams if reorder_state is not None: if batch_idxs is not None: @@ -359,7 +378,8 @@ def _generate( constraint_end=self.constraint_end, gen_code=self.gen_code, zero_shot=self.zero_shot, - prefix_tokens=prefix_tokens + prefix_tokens=prefix_tokens, + decoder_thes=decoder_thes ) if self.lm_model is not None: @@ -818,7 +838,8 @@ def forward_decoder( constraint_end=None, gen_code=False, zero_shot=False, - prefix_tokens=None + prefix_tokens=None, + decoder_thes=1.0 ): log_probs = [] avg_attn: Optional[Tensor] = None @@ -834,6 +855,8 @@ def forward_decoder( code_masks=code_mask, encoder_out=encoder_out, incremental_state=incremental_states[i], + is_train=False, + decoder_thres=decoder_thes ) else: if hasattr(model, "decoder"): diff --git a/run_scripts/caption/evaluate_caption_base_MuE.sh b/run_scripts/caption/evaluate_caption_base_MuE.sh new file mode 100755 index 00000000..dca52733 --- /dev/null +++ b/run_scripts/caption/evaluate_caption_base_MuE.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# The port for communication. Note that if you want to run multiple tasks on the same machine, +# you need to specify different port numbers. +export MASTER_PORT=1091 + +user_dir=../../ofa_module +bpe_dir=../../utils/BPE + +data=/data/tsk/caption_data/caption_test.tsv +path=/home/sht22008/tsk/projects/OFA/run_scripts/caption/stage1_checkpoints/{5,}_{0.06,}_{6000,}/checkpoint.best_cider_0.3240.pt +result_path=../../results/caption +selected_cols=1,4,2 +split='test' + +CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=${MASTER_PORT} ../../evaluate.py \ + ${data} \ + --path=${path} \ + --user-dir=${user_dir} \ + --task=caption \ + --batch-size=1 \ + --log-format=simple --log-interval=10 \ + --seed=7 \ + --gen-subset=${split} \ + --results-path=${result_path} \ + --beam=5 \ + --max-len-b=16 \ + --no-repeat-ngram-size=3 \ + --fp16 \ + --num-workers=0 \ + --img_thres=0.99\ + --txt_thres=0.7\ + --decoder_thres=0.9\ + --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"eval_cider\":False,\"selected_cols\":\"${selected_cols}\"}" + +python coco_eval.py ../../results/caption/test_predict.json ../../dataset/caption_data/test_caption_coco_format.json diff --git a/run_scripts/caption/train_caption_stage1_base_MuE.sh b/run_scripts/caption/train_caption_stage1_base_MuE.sh new file mode 100644 index 00000000..1d05d6c6 --- /dev/null +++ b/run_scripts/caption/train_caption_stage1_base_MuE.sh @@ -0,0 +1,109 @@ +#!/usr/bin/env + +# The port for communication. Note that if you want to run multiple tasks on the same machine, +# you need to specify different port numbers. 
+export MASTER_PORT=1061 + +log_dir=./stage1_logs +save_dir=./stage1_checkpoints +mkdir -p $log_dir $save_dir + +bpe_dir=../../utils/BPE +user_dir=../../ofa_module + +data_dir=../../dataset/caption_data +data=${data_dir}/caption_stage1_train.tsv,${data_dir}/caption_val.tsv +restore_file=../../checkpoints/ofa_base.pt +selected_cols=0,4,2 + +task=caption +arch=ofa_base +criterion=MuE_Task_Loss +label_smoothing=0.1 +lr=1e-5 +max_epoch=5 +warmup_ratio=0.06 +batch_size=2 +update_freq=4 +resnet_drop_path_rate=0.0 +encoder_drop_path_rate=0.1 +decoder_drop_path_rate=0.1 +dropout=0.1 +attention_dropout=0.0 +max_src_length=80 +max_tgt_length=20 +num_bins=1000 +patch_image_size=480 +eval_cider_cached=${data_dir}/cider_cached_tokens/coco-valid-words.p +drop_worst_ratio=0.2 + +for max_epoch in {5,}; do + echo "max_epoch "${max_epoch} + for warmup_ratio in {0.06,}; do + echo "warmup_ratio "${warmup_ratio} + for drop_worst_after in {6000,}; do + echo "drop_worst_after "${drop_worst_after} + + log_file=${log_dir}/${max_epoch}"_"${warmup_ratio}"_"${drop_worst_after}".log" + save_path=${save_dir}/${max_epoch}"_"${warmup_ratio}"_"${drop_worst_after} + mkdir -p $save_path + + CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=${MASTER_PORT} ../../train.py \ + $data \ + --selected-cols=${selected_cols} \ + --bpe-dir=${bpe_dir} \ + --user-dir=${user_dir} \ + --restore-file=${restore_file} \ + --reset-optimizer --reset-dataloader --reset-meters \ + --save-dir=${save_path} \ + --task=${task} \ + --arch=${arch} \ + --criterion=${criterion} \ + --label-smoothing=${label_smoothing} \ + --batch-size=${batch_size} \ + --update-freq=${update_freq} \ + --encoder-normalize-before \ + --decoder-normalize-before \ + --share-decoder-input-output-embed \ + --share-all-embeddings \ + --layernorm-embedding \ + --patch-layernorm-embedding \ + --code-layernorm-embedding \ + --resnet-drop-path-rate=${resnet_drop_path_rate} \ + --encoder-drop-path-rate=${encoder_drop_path_rate} \ + --decoder-drop-path-rate=${decoder_drop_path_rate} \ + --dropout=${dropout} \ + --attention-dropout=${attention_dropout} \ + --weight-decay=0.01 --optimizer=adam --adam-betas="(0.9,0.999)" --adam-eps=1e-08 --clip-norm=1.0 \ + --lr-scheduler=polynomial_decay --lr=${lr} \ + --max-epoch=${max_epoch} --warmup-ratio=${warmup_ratio} \ + --log-format=simple --log-interval=10 \ + --fixed-validation-seed=7 \ + --no-epoch-checkpoints --keep-best-checkpoints=1 \ + --save-interval=1 --validate-interval=1 \ + --save-interval-updates=500 --validate-interval-updates=500 \ + --eval-cider \ + --eval-cider-cached-tokens=${eval_cider_cached} \ + --eval-args='{"beam":5,"max_len_b":16,"no_repeat_ngram_size":3}' \ + --best-checkpoint-metric=cider --maximize-best-checkpoint-metric \ + --max-src-length=${max_src_length} \ + --max-tgt-length=${max_tgt_length} \ + --find-unused-parameters \ + --freeze-encoder-embedding \ + --freeze-decoder-embedding \ + --add-type-embedding \ + --scale-attn \ + --scale-fc \ + --scale-heads \ + --disable-entangle \ + --num-bins=${num_bins} \ + --patch-image-size=${patch_image_size} \ + --drop-worst-ratio=${drop_worst_ratio} \ + --drop-worst-after=6000 \ + --fp16 \ + --fp16-scale-window=512 \ + --train_mue\ + --num-workers=0 > ${log_file} 2>&1 + done + done +done \ No newline at end of file diff --git a/run_scripts/snli_ve/evaluate_snli_ve_base_MuE.sh b/run_scripts/snli_ve/evaluate_snli_ve_base_MuE.sh new file mode 100755 index 00000000..5b9631c1 --- /dev/null +++ 
b/run_scripts/snli_ve/evaluate_snli_ve_base_MuE.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +# The port for communication. Note that if you want to run multiple tasks on the same machine, +# you need to specify different port numbers. +export MASTER_PORT=7091 + +user_dir=../../ofa_module +bpe_dir=../../utils/BPE + +# dev or test +split=test + +data=../../dataset/snli_ve_data/snli_ve_${split}.tsv +path=../../checkpoints/snli_ve_base_best.pt +result_path=../../results/snli_ve +selected_cols=0,2,3,4,5 + +CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=${MASTER_PORT} ../../evaluate.py \ + ${data} \ + --path=${path} \ + --user-dir=${user_dir} \ + --task=snli_ve \ + --batch-size=1 \ + --log-format=simple --log-interval=10 \ + --seed=7 \ + --gen-subset=${split} \ + --results-path=${result_path} \ + --fp16 \ + --num-workers=0 \ + --img_thres=0.7\ + --txt_thres=0.99\ + --decoder_thres=0.7\ + --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}" \ No newline at end of file diff --git a/run_scripts/snli_ve/train_snli_ve_base_MuE.sh b/run_scripts/snli_ve/train_snli_ve_base_MuE.sh new file mode 100644 index 00000000..fad48e01 --- /dev/null +++ b/run_scripts/snli_ve/train_snli_ve_base_MuE.sh @@ -0,0 +1,100 @@ +#!/usr/bin/env + +# The port for communication. Note that if you want to run multiple tasks on the same machine, +# you need to specify different port numbers. +export MASTER_PORT=7061 + +log_dir=./logs +save_dir=./checkpoints +mkdir -p $log_dir $save_dir + +bpe_dir=../../utils/BPE +user_dir=../../ofa_module + +data_dir=../../dataset/snli_ve_data +data=${data_dir}/snli_ve_train.tsv,${data_dir}/snli_ve_dev.tsv +restore_file=../../checkpoints/ofa_base.pt +selected_cols=0,2,3,4,5 + +task=snli_ve +arch=ofa_base +criterion=MuE_Task_Loss +label_smoothing=0.0 +lr=3e-5 +max_epoch=5 +warmup_ratio=0.06 +batch_size=4 +update_freq=8 +resnet_drop_path_rate=0.0 +encoder_drop_path_rate=0.1 +decoder_drop_path_rate=0.1 +dropout=0.1 +attention_dropout=0.0 +max_src_length=80 +max_tgt_length=20 +num_bins=1000 +patch_image_size=480 +prompt_type="prev_output" + +for max_epoch in {5,}; do + echo "max_epoch "${max_epoch} + for lr in {5e-5,}; do + echo "lr "${lr} + + log_file=${log_dir}/${max_epoch}"_"${lr}".log" + save_path=${save_dir}/${max_epoch}"_"${lr} + mkdir -p $save_path + + CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=${MASTER_PORT} ../../train.py \ + $data \ + --selected-cols=${selected_cols} \ + --bpe-dir=${bpe_dir} \ + --user-dir=${user_dir} \ + --restore-file=${restore_file} \ + --reset-optimizer --reset-dataloader --reset-meters \ + --save-dir=${save_path} \ + --task=${task} \ + --arch=${arch} \ + --criterion=${criterion} \ + --label-smoothing=${label_smoothing} \ + --batch-size=${batch_size} \ + --update-freq=${update_freq} \ + --encoder-normalize-before \ + --decoder-normalize-before \ + --share-decoder-input-output-embed \ + --share-all-embeddings \ + --layernorm-embedding \ + --patch-layernorm-embedding \ + --code-layernorm-embedding \ + --resnet-drop-path-rate=${resnet_drop_path_rate} \ + --encoder-drop-path-rate=${encoder_drop_path_rate} \ + --decoder-drop-path-rate=${decoder_drop_path_rate} \ + --dropout=${dropout} \ + --attention-dropout=${attention_dropout} \ + --weight-decay=0.01 --optimizer=adam --adam-betas="(0.9,0.999)" --adam-eps=1e-08 --clip-norm=1.0 \ + --lr-scheduler=polynomial_decay --lr=${lr} \ + --max-epoch=${max_epoch} 
--warmup-ratio=${warmup_ratio} \ + --log-format=simple --log-interval=10 \ + --fixed-validation-seed=7 \ + --keep-best-checkpoints=1 \ + --save-interval=1 --validate-interval=1 \ + --save-interval-updates=500 --validate-interval-updates=500 \ + --best-checkpoint-metric=snli_score --maximize-best-checkpoint-metric \ + --max-src-length=${max_src_length} \ + --max-tgt-length=${max_tgt_length} \ + --find-unused-parameters \ + --add-type-embedding \ + --scale-attn \ + --scale-fc \ + --scale-heads \ + --disable-entangle \ + --num-bins=${num_bins} \ + --patch-image-size=${patch_image_size} \ + --prompt-type=${prompt_type} \ + --add-caption \ + --fp16 \ + --fp16-scale-window=512 \ + --train_mue\ + --num-workers=0 > ${log_file} 2>&1 + done +done \ No newline at end of file diff --git a/tasks/mm_tasks/caption.py b/tasks/mm_tasks/caption.py index f37c91be..8b67094f 100644 --- a/tasks/mm_tasks/caption.py +++ b/tasks/mm_tasks/caption.py @@ -227,7 +227,7 @@ def decode(toks, escape_unk=False): s = self.bpe.decode(s) return s - gen_out = self.inference_step(generator, [model], sample) + gen_out = self.inference_step(generator, [model], sample, is_train=True) hyps, refs = [], [] transtab = str.maketrans({key: None for key in string.punctuation}) for i in range(len(gen_out)): diff --git a/utils/eval_utils.py b/utils/eval_utils.py index a8bd9f7c..6b2e3c7e 100644 --- a/utils/eval_utils.py +++ b/utils/eval_utils.py @@ -47,7 +47,7 @@ def _calculate_error_rate(hyps, refs): def eval_caption(task, generator, models, sample, **kwargs): transtab = str.maketrans({key: None for key in string.punctuation}) - hypos = task.inference_step(generator, models, sample) + hypos = task.inference_step(generator, models, sample, **kwargs) results = [] for i, sample_id in enumerate(sample["id"].tolist()): detok_hypo_str = decode_fn(hypos[i][0]["tokens"], task.tgt_dict, task.bpe, generator) @@ -56,7 +56,7 @@ def eval_caption(task, generator, models, sample, **kwargs): def eval_caption_cn(task, generator, models, sample, **kwargs): - hypos = task.inference_step(generator, models, sample) + hypos = task.inference_step(generator, models, sample, **kwargs) results = [] for i, sample_id in enumerate(sample["id"].tolist()): detok_hypo_str = decode_fn( @@ -72,7 +72,7 @@ def eval_caption_cn(task, generator, models, sample, **kwargs): def eval_ocr(task, generator, models, sample, **kwargs): - gen_out = task.inference_step(generator, models, sample) + gen_out = task.inference_step(generator, models, sample, **kwargs) hyps, refs, results = [], [], [] for i, sample_id in enumerate(sample["id"].tolist()): decode_tokens = decode_fn(gen_out[i][0]["tokens"], task.tgt_dict, task.bpe, generator).strip() @@ -102,7 +102,7 @@ def eval_ocr(task, generator, models, sample, **kwargs): def eval_vqa_gen(task, generator, models, sample, **kwargs): if kwargs['beam_search_vqa_eval']: - hypos = task.inference_step(generator, models, sample, prefix_tokens=sample['prefix_tokens']) + hypos = task.inference_step(generator, models, sample, prefix_tokens=sample['prefix_tokens'], **kwargs) results = [] for i, sample_id in enumerate(sample["id"].tolist()): prefix_len = sample['prefix_tokens'][i].ne(1).sum().item() @@ -185,7 +185,7 @@ def _calculate_ap_score(hyps, refs, thresh=0.5): ious = area_interacts / (area_predictions + area_targets - area_interacts + 1e-6) return ((ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)).float() - gen_out = task.inference_step(generator, models, sample) + gen_out = task.inference_step(generator, models, sample, **kwargs) hyps = 
[] for i in range(len(gen_out)): hyps.append(gen_out[i][0]["tokens"][:-1] - len(task.src_dict) + task.cfg.num_bins) @@ -208,7 +208,8 @@ def eval_snli_ve(task, generator, models, sample, **kwargs): sample["net_input"]["src_tokens"], src_lengths=sample["net_input"]["src_lengths"], patch_images=sample["net_input"]["patch_images"], - patch_masks=sample["net_input"]["patch_masks"] + patch_masks=sample["net_input"]["patch_masks"], + **kwargs ) device = sample["net_input"]["src_tokens"].device eos_item = torch.tensor([task.src_dict.eos()]) @@ -246,7 +247,7 @@ def eval_snli_ve(task, generator, models, sample, **kwargs): encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0) ] - decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out) + decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out, **kwargs) decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf) lprobs = models[0].get_normalized_probs(decoder_out, log_probs=True) scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
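Two of the smaller fixes above address in-place mutation: the criterion now converts net_output to a list and applies an out-of-place masked_fill, and unify_multihead_attention.py writes cached keys/values into a fresh saved_state_new dict instead of mutating saved_state. A minimal sketch of the masked_fill pattern, using placeholder tensors rather than the repository's actual inputs:

```python
import math
import torch

# Placeholder logits and mask; only the pattern matters here.
logits = torch.randn(2, 4, 6)
net_output = (logits, {"inner_out_states": []})

# Tuples do not support item assignment, so convert to a list first, then use
# the out-of-place masked_fill so the original tensor stays untouched for any
# other consumer of the same object.
net_output = list(net_output)
constraint_masks = torch.rand(2, 4, 6) > 0.5
net_output[0] = net_output[0].masked_fill(~constraint_masks, -math.inf)

assert not torch.isinf(logits).any()  # the original logits were not modified
```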