diff --git a/.gitignore b/.gitignore
index b92e35f0..226e0d4d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
.idea/*
.DS_Store
-.vscode
\ No newline at end of file
+.vscode
+__pycache__
+logs/
+vqa_logs/
\ No newline at end of file
diff --git a/README.md b/README.md
index dedeb439..42856b40 100644
--- a/README.md
+++ b/README.md
@@ -105,7 +105,7 @@ We list the parameters and pretrained checkpoints of OFAs below. For finetuned c
# Results
-Below we demonstrate the results of OFAs on cross-modal understanding and generation.
+Below we demonstrate the results of OFAs on cross-modal understanding and generation. More results for the MuE model can be found in the [MuE paper](https://arxiv.org/abs/2211.11152).
@@ -254,6 +254,9 @@ We provide procedures to reproduce our results of image captioning on our paper
cd run_scripts/caption
nohup sh train_caption_stage1.sh > train_stage1.out & # stage 1, train with cross-entropy loss
nohup sh train_caption_stage2.sh > train_stage2.out & # stage 2, load the best ckpt of stage1 and train with CIDEr optimization
+# To finetune the MuE model instead, use the following script
+nohup sh train_caption_stage1_base_MuE.sh > train_stage1.out &
+# Stage 2 uses the same script as above (train_caption_stage2.sh)
@@ -263,6 +266,9 @@ nohup sh train_caption_stage2.sh > train_stage2.out & # stage 2, load the best
cd run_scripts/caption ; sh evaluate_caption.sh # inference & evaluate
+# To evaluate your MuE model, run
+sh evaluate_caption_base_MuE.sh
+# You can adjust img_thres, txt_thres, and decoder_thres to trade off accuracy against inference speed.
@@ -429,6 +435,8 @@ We provide steps for you to reproduce our results in visual entailment. See the
cd run_scripts/snli_ve
nohup sh train_snli_ve.sh > train_snli_ve.out & # finetune for snli_ve
+# To finetune the MuE model instead, use the following script
+nohup sh train_snli_ve_base_MuE.sh > train_snli_ve_MuE.out &
@@ -438,6 +446,9 @@ nohup sh train_snli_ve.sh > train_snli_ve.out & # finetune for snli_ve
cd run_scripts/snli_ve ; sh evaluate_snli_ve.sh dev # specify 'dev' or 'test'
+# To evaluate your MuE model, run
+sh evaluate_snli_ve_base_MuE.sh
+# You can adjust img_thres, txt_thres, and decoder_thres to trade off accuracy against inference speed.
@@ -600,5 +611,20 @@ Please cite our paper if you find it helpful :)
volume = {abs/2202.03052},
year = {2022}
}
+
+@article{tang2022you,
+ title={You Need Multiple Exiting: Dynamic Early Exiting for Accelerating Unified Vision Language Model},
+ author={Tang, Shengkun and
+ Wang, Yaqing and
+ Kong, Zhenglun and
+ Zhang, Tianchi and
+ Li, Yao and
+ Ding, Caiwen and
+ Wang, Yanzhi and
+ Liang, Yi and
+ Xu, Dongkuan},
+ journal={arXiv preprint arXiv:2211.11152},
+ year={2022}
+}
```
diff --git a/criterions/label_smoothed_cross_entropy.py b/criterions/label_smoothed_cross_entropy.py
index 65175677..b1850c69 100644
--- a/criterions/label_smoothed_cross_entropy.py
+++ b/criterions/label_smoothed_cross_entropy.py
@@ -217,9 +217,12 @@ def forward(self, model, sample, update_num=0, reduce=True):
def get_lprobs_and_target(self, model, net_output, sample):
conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
constraint_masks = None
+        # A subtle in-place-modification bug occurs unless net_output is first converted to a
+        # list and masked_fill is applied out of place below. This does not change the logic.
+        net_output = list(net_output)
if "constraint_masks" in sample and sample["constraint_masks"] is not None:
constraint_masks = sample["constraint_masks"]
- net_output[0].masked_fill_(~constraint_masks, -math.inf)
+ net_output[0] = net_output[0].masked_fill_(~constraint_masks, -math.inf)
if self.constraint_start is not None and self.constraint_end is not None:
net_output[0][:, :, 4:self.constraint_start] = -math.inf
net_output[0][:, :, self.constraint_end:] = -math.inf
@@ -341,3 +344,65 @@ def logging_outputs_can_be_summed() -> bool:
to True will improves distributed training speed.
"""
return True
+
+@register_criterion(
+ "MuE_Task_Loss", dataclass=AdjustLabelSmoothedCrossEntropyCriterionConfig
+)
+class MuE_Task_Loss(AdjustLabelSmoothedCrossEntropyCriterion):
+ def __init__(
+ self,
+ task,
+ sentence_avg,
+ label_smoothing,
+ ignore_prefix_size=0,
+ ignore_eos=False,
+ report_accuracy=False,
+ drop_worst_ratio=0,
+ drop_worst_after=0,
+ use_rdrop=False,
+ reg_alpha=1.0,
+ sample_patch_num=196,
+ constraint_range=None
+ ):
+ super().__init__(
+ task,
+ sentence_avg,
+ label_smoothing,
+ ignore_prefix_size,
+ ignore_eos,
+ report_accuracy,
+ drop_worst_ratio,
+ drop_worst_after,
+ use_rdrop,
+ reg_alpha,
+ sample_patch_num,
+ constraint_range)
+
+ def compute_loss(self, model, net_output, sample, update_num, reduce=True):
+ loss_all = 0.0
+ nll_loss_all = 0.0
+ ntokens = 0
+ print("using MuE Task loss")
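+        # Apply the label-smoothed cross-entropy to every intermediate decoder output
+        # ("inner_out_states"), so each potential exit layer is supervised during MuE finetuning.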
+ for state in net_output[1]["inner_out_states"]:
+ lprobs, target, constraint_masks = self.get_lprobs_and_target(model, [state], sample)
+ if constraint_masks is not None:
+ constraint_masks = constraint_masks[target != self.padding_idx]
+ lprobs = lprobs[target != self.padding_idx]
+ target = target[target != self.padding_idx]
+ loss, nll_loss, ntokens = label_smoothed_nll_loss(
+ lprobs,
+ target,
+ self.eps,
+ update_num,
+ reduce=reduce,
+ drop_worst_ratio=self.drop_worst_ratio,
+ drop_worst_after=self.drop_worst_after,
+ use_rdrop=self.use_rdrop,
+ reg_alpha=self.reg_alpha,
+ constraint_masks=constraint_masks,
+ constraint_start=self.constraint_start,
+ constraint_end=self.constraint_end
+ )
+ loss_all += loss
+ nll_loss_all += nll_loss
+ return loss_all, nll_loss_all, ntokens
diff --git a/evaluate.py b/evaluate.py
index e41d8a30..b460ce8d 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -160,8 +160,8 @@ def main(cfg: DictConfig, **kwargs):
score_sum += sum([s[0] for s in scores])
score_cnt += sum([s[1] for s in scores])
else:
- score_sum += sum(scores) if scores is not None else 0
- score_cnt += len(scores) if scores is not None else 0
+ score_sum += sum(scores)
+ score_cnt += len(scores)
progress.log({"sentences": sample["nsentences"]})
@@ -173,10 +173,17 @@ def cli_main():
parser.add_argument("--ema-eval", action='store_true', help="Use EMA weights to make evaluation.")
parser.add_argument("--beam-search-vqa-eval", action='store_true', help="Use beam search for vqa evaluation (faster inference speed but sub-optimal result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.")
parser.add_argument("--zero-shot", action='store_true')
+    parser.add_argument('--img_thres', type=float, metavar='D', default=1.0,
+                        help='image threshold for the early-exiting model')
+    parser.add_argument('--txt_thres', type=float, metavar='D', default=1.0,
+                        help='text threshold for the early-exiting model')
+    parser.add_argument('--decoder_thres', type=float, metavar='D', default=1.0,
+                        help='decoder threshold for the early-exiting model')
args = options.parse_args_and_arch(parser)
cfg = convert_namespace_to_omegaconf(args)
distributed_utils.call_main(
- cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval, zero_shot=args.zero_shot
+ cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval, zero_shot=args.zero_shot,
+ img_thres=args.img_thres, txt_thres=args.txt_thres, decoder_thres=args.decoder_thres, is_train=False
)
diff --git a/fairseq/fairseq/tasks/fairseq_task.py b/fairseq/fairseq/tasks/fairseq_task.py
index d671f17c..096a6d47 100644
--- a/fairseq/fairseq/tasks/fairseq_task.py
+++ b/fairseq/fairseq/tasks/fairseq_task.py
@@ -511,11 +511,11 @@ def build_dataset_for_inference(
raise NotImplementedError
def inference_step(
- self, generator, models, sample, prefix_tokens=None, constraints=None
+ self, generator, models, sample, prefix_tokens=None, constraints=None, **kwargs
):
with torch.no_grad():
return generator.generate(
- models, sample, prefix_tokens=prefix_tokens, constraints=constraints
+ models, sample, prefix_tokens=prefix_tokens, constraints=constraints, **kwargs
)
def begin_epoch(self, epoch, model):
diff --git a/models/ofa/unify_multihead_attention.py b/models/ofa/unify_multihead_attention.py
index ee6a1fad..7aff4566 100644
--- a/models/ofa/unify_multihead_attention.py
+++ b/models/ofa/unify_multihead_attention.py
@@ -265,7 +265,8 @@ def forward(
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
-
+
+ saved_state_new = {}
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
@@ -298,13 +299,15 @@ def forward(
src_len=k.size(1),
static_kv=static_kv,
)
-
- saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
- saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
- saved_state["prev_key_padding_mask"] = key_padding_mask
+            # Mutating saved_state in place introduces reference bugs that break inference
+            # with early-exiting (MuE) models, so the updated state is written to a new dict.
+            # This has no effect on the original OFA models.
+ saved_state_new["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state_new["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state_new["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
- incremental_state = self._set_input_buffer(incremental_state, saved_state)
+ incremental_state = self._set_input_buffer(incremental_state, saved_state_new)
assert k is not None
assert k.size(1) == src_len
diff --git a/models/ofa/unify_transformer.py b/models/ofa/unify_transformer.py
index 3f2d04d2..43079ebb 100644
--- a/models/ofa/unify_transformer.py
+++ b/models/ofa/unify_transformer.py
@@ -178,6 +178,9 @@ def add_args(parser):
parser.add_argument('--freeze-encoder', action='store_true',
help='freeze the parameters in the encoder')
+        parser.add_argument('--train_mue', action='store_true',
+                            help='use the MuE early-exiting encoder and decoder')
+
parser.add_argument('--adapter', action='store_true',
help='use adapter in the model')
@@ -424,16 +427,27 @@ def build_embedding(cls, args, dictionary, embed_dim, path=None):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
- return TransformerEncoder(args, src_dict, embed_tokens)
+ if args.train_mue:
+ return TransformerEncoder_MuE(args, src_dict, embed_tokens)
+ else:
+ return TransformerEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
- return TransformerDecoder(
- args,
- tgt_dict,
- embed_tokens,
- no_encoder_attn=getattr(args, "no_cross_attention", False),
- )
+ if args.train_mue:
+ return TransformerDecoder_MuE(
+ args,
+ tgt_dict,
+ embed_tokens,
+ no_encoder_attn=getattr(args, "no_cross_attention", False),
+ )
+ else:
+ return TransformerDecoder(
+ args,
+ tgt_dict,
+ embed_tokens,
+ no_encoder_attn=getattr(args, "no_cross_attention", False),
+ )
# TorchScript doesn't support optional arguments with variable length (**kwargs).
# Current workaround is to add union of all arguments in child classes.
@@ -769,7 +783,8 @@ def forward(
code_masks: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
- sample_patch_num: Optional[int] = None
+ sample_patch_num: Optional[int] = None,
+ **kwargs
):
"""
Args:
@@ -1320,6 +1335,7 @@ def forward(
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
+ **kwargs
):
"""
Args:
@@ -1649,6 +1665,578 @@ def upgrade_state_dict_named(self, state_dict, name):
return state_dict
+class TransformerEncoder_MuE(TransformerEncoder):
+ """
+    MuE encoder model from `"You Need Multiple Exiting: Dynamic Early Exiting for
+    Accelerating Unified Vision Language Model" (Tang et al., 2022)
+    <https://arxiv.org/abs/2211.11152>`_.
+
+ Transformer encoder consisting of *args.encoder_layers* layers. Each layer
+ is a :class:`TransformerEncoderLayer`.
+
+ Args:
+ args (argparse.Namespace): parsed command-line arguments
+ dictionary (~fairseq.data.Dictionary): encoding dictionary
+ embed_tokens (torch.nn.Embedding): input embedding
+ """
+
+ def __init__(self, args, dictionary, embed_tokens):
+ self.args = args
+ super().__init__(args, dictionary, embed_tokens)
+
+ def forward_embedding(
+ self,
+ src_tokens,
+ image_embed: Optional[torch.Tensor] = None,
+ image_embed_2: Optional[torch.Tensor] = None,
+ token_embedding: Optional[torch.Tensor] = None,
+ pos_embed: Optional[torch.Tensor] = None,
+ image_pos_embed: Optional[torch.Tensor] = None,
+ image_pos_embed_2: Optional[torch.Tensor] = None
+ ):
+ # embed tokens and positions
+ if token_embedding is None:
+ token_embedding = self.embed_tokens(src_tokens)
+ x = embed = self.embed_scale * token_embedding
+ if self.entangle_position_embedding and pos_embed is not None:
+ x += pos_embed
+ if self.type_embedding is not None:
+ x += self.type_embedding(src_tokens.new_zeros(x.size()[:2]))
+ if self.layernorm_embedding is not None:
+ x = self.layernorm_embedding(x)
+ x = self.dropout_module(x)
+ if self.quant_noise is not None:
+ x = self.quant_noise(x)
+
+ # embed raw images
+ if image_embed is not None:
+ image_embed = self.image_proj(image_embed)
+ image_x = image_embed = self.embed_scale * image_embed
+ if self.entangle_position_embedding and image_pos_embed is not None:
+ image_x += image_pos_embed
+ if self.type_embedding is not None:
+ image_x += self.type_embedding(src_tokens.new_ones(image_x.size()[:2]))
+ if self.patch_layernorm_embedding is not None:
+ image_x = self.patch_layernorm_embedding(image_x)
+ image_x = self.dropout_module(image_x)
+ if self.quant_noise is not None:
+ image_x = self.quant_noise(image_x)
+            # Split the image and text embeddings so each stream goes through the encoder layers separately.
+ x = [image_x, x]
+ embed = torch.cat([image_embed, embed], dim=1)
+
+ if image_embed_2 is not None:
+ assert self.type_embedding is not None
+ image_embed_2 = self.image_proj(image_embed_2)
+ image_x_2 = image_embed_2 = self.embed_scale * image_embed_2
+ if self.entangle_position_embedding and image_pos_embed_2 is not None:
+ image_x_2 += image_pos_embed_2
+ if self.type_embedding is not None:
+ image_x_2 += self.type_embedding(src_tokens.new_full(image_x_2.size()[:2], fill_value=2))
+ if self.patch_layernorm_embedding is not None:
+ image_x_2 = self.patch_layernorm_embedding(image_x_2)
+ image_x_2 = self.dropout_module(image_x_2)
+ if self.quant_noise is not None:
+ image_x_2 = self.quant_noise(image_x_2)
+ x = torch.cat([image_x_2, x], dim=1)
+ embed = torch.cat([image_embed_2, embed], dim=1)
+
+ return x, embed
+
+ def forward(
+ self,
+ src_tokens,
+ src_lengths,
+ is_train: bool = True,
+ patch_images: Optional[torch.Tensor] = None,
+ patch_images_2: Optional[torch.Tensor] = None,
+ patch_masks: Optional[torch.Tensor] = None,
+ code_masks: Optional[torch.Tensor] = None,
+ return_all_hiddens: bool = False,
+ token_embeddings: Optional[torch.Tensor] = None,
+ sample_patch_num: Optional[int] = None,
+ **kwargs
+ ):
+ """
+ Args:
+ src_tokens (LongTensor): tokens in the source language of shape
+ `(batch, src_len)`
+ src_lengths (torch.LongTensor): lengths of each source sentence of
+ shape `(batch)`
+ return_all_hiddens (bool, optional): also return all of the
+ intermediate hidden states (default: False).
+ token_embeddings (torch.Tensor, optional): precomputed embeddings
+ default `None` will recompute embeddings
+
+ Returns:
+ dict:
+ - **encoder_out** (Tensor): the last encoder layer's output of
+ shape `(src_len, batch, embed_dim)`
+ - **encoder_padding_mask** (ByteTensor): the positions of
+ padding elements of shape `(batch, src_len)`
+ - **encoder_embedding** (Tensor): the (scaled) embedding lookup
+ of shape `(batch, src_len, embed_dim)`
+ - **encoder_states** (List[Tensor]): all intermediate
+ hidden states of shape `(src_len, batch, embed_dim)`.
+ Only populated if *return_all_hiddens* is True.
+ """
+ return self.forward_scriptable(src_tokens,
+ src_lengths,
+ is_train,
+ patch_images,
+ patch_images_2,
+ patch_masks,
+ return_all_hiddens,
+ token_embeddings,
+ sample_patch_num,
+ **kwargs)
+
+ # TorchScript doesn't support super() method so that the scriptable Subclass
+ # can't access the base class model in Torchscript.
+ # Current workaround is to add a helper function with different name and
+ # call the helper function from scriptable Subclass.
+ def forward_scriptable(
+ self,
+ src_tokens,
+ src_lengths,
+ is_train: bool = True,
+ patch_images: Optional[torch.Tensor] = None,
+ patch_images_2: Optional[torch.Tensor] = None,
+ patch_masks: Optional[torch.Tensor] = None,
+ return_all_hiddens: bool = False,
+ token_embeddings: Optional[torch.Tensor] = None,
+ sample_patch_num: Optional[int] = None,
+ **kwargs
+ ):
+ """
+ Args:
+ src_tokens (LongTensor): tokens in the source language of shape
+ `(batch, src_len)`
+ src_lengths (torch.LongTensor): lengths of each source sentence of
+ shape `(batch)`
+ return_all_hiddens (bool, optional): also return all of the
+ intermediate hidden states (default: False).
+ token_embeddings (torch.Tensor, optional): precomputed embeddings
+ default `None` will recompute embeddings
+
+ Returns:
+ dict:
+ - **encoder_out** (Tensor): the last encoder layer's output of
+ shape `(src_len, batch, embed_dim)`
+ - **encoder_padding_mask** (ByteTensor): the positions of
+ padding elements of shape `(batch, src_len)`
+ - **encoder_embedding** (Tensor): the (scaled) embedding lookup
+ of shape `(batch, src_len, embed_dim)`
+ - **encoder_states** (List[Tensor]): all intermediate
+ hidden states of shape `(src_len, batch, embed_dim)`.
+ Only populated if *return_all_hiddens* is True.
+ """
+ image_embed = None
+ image_embed_2 = None
+ image_pos_embed = None
+ image_pos_embed_2 = None
+ if patch_images is not None:
+ image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
+ self.get_patch_images_info(patch_images, sample_patch_num, src_tokens.device)
+ image_padding_mask[~patch_masks] = True
+ if patch_images_2 is not None:
+ image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \
+ self.get_patch_images_info(patch_images_2, sample_patch_num, src_tokens.device)
+ image_padding_mask_2[~patch_masks] = True
+
+ encoder_padding_mask = src_tokens.eq(self.padding_idx)
+ if patch_images is not None:
+ encoder_padding_mask_cat = torch.cat([image_padding_mask, encoder_padding_mask], dim=1)
+ if patch_images_2 is not None:
+ encoder_padding_mask = torch.cat([image_padding_mask_2, encoder_padding_mask], dim=1)
+ has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any())
+
+ pos_embed = self.embed_positions(utils.new_arange(src_tokens))
+ embed, encoder_embedding = self.forward_embedding(
+ src_tokens, image_embed, image_embed_2, token_embeddings,
+ pos_embed, image_pos_embed, image_pos_embed_2
+ )
+ x = embed[1]
+ image_x = embed[0]
+
+ # account for padding while computing the representation
+ if has_pads:
+ x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
+ image_x = image_x * (1 - image_padding_mask.unsqueeze(-1).type_as(image_x))
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+ image_x = image_x.transpose(0, 1)
+
+ pos_embed = self.pos_ln(pos_embed)
+ if patch_images is not None:
+ image_pos_embed = self.image_pos_ln(image_pos_embed)
+ pos_embed_cat = torch.cat([image_pos_embed, pos_embed], dim=1)
+ if patch_images_2 is not None:
+ image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
+ pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
+
+
+ pos_q = self.pos_q_linear(pos_embed).view(
+ x.size(1), x.size(0), self.num_attention_heads, -1
+ ).transpose(1, 2) * self.pos_scaling
+ pos_k = self.pos_k_linear(pos_embed).view(
+ x.size(1), x.size(0), self.num_attention_heads, -1
+ ).transpose(1, 2)
+ abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
+
+ pos_q = self.pos_q_linear(image_pos_embed).view(
+ image_x.size(1), image_x.size(0), self.num_attention_heads, -1
+ ).transpose(1, 2) * self.pos_scaling
+ pos_k = self.pos_k_linear(image_pos_embed).view(
+ image_x.size(1), image_x.size(0), self.num_attention_heads, -1
+ ).transpose(1, 2)
+ image_abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
+
+ encoder_states = []
+ encoder_states_img = []
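+        # Hidden states are always collected because the early-exit checks below
+        # compare each layer's output with the previous layer's output.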
+ return_all_hiddens = True
+ if return_all_hiddens:
+ encoder_states.append(x)
+ encoder_states_img.append(image_x)
+
+ # encoder layers
+ for idx, layer in enumerate(self.layers):
+ self_attn_bias = abs_pos_bias.clone()
+ self_attn_bias += self.get_rel_pos_bias(src_tokens, idx)
+ self_attn_bias = self_attn_bias.reshape(-1, x.size(0), x.size(0))
+
+ x = layer(
+ x, encoder_padding_mask=encoder_padding_mask if has_pads else None, self_attn_bias=self_attn_bias,
+ )
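+            # Early exit at inference time: if this layer's output is nearly identical
+            # (cosine similarity above txt_thres) to the previous layer's, the remaining
+            # text layers are skipped. The view(1, -1) flattening assumes batch size 1,
+            # as in the provided evaluation scripts (--batch-size=1).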
+ if not is_train:
+ similarity = torch.cosine_similarity(F.normalize(x.clone().contiguous().view(1, -1)),
+ F.normalize(encoder_states[-1].clone().contiguous().view(1, -1)))
+ if similarity > kwargs["txt_thres"]:
+ break
+ if return_all_hiddens:
+ assert encoder_states is not None
+ encoder_states.append(x)
+        idx_y = idx  # index of the last text layer that was executed
+ for idx, layer in enumerate(self.layers):
+ self_attn_bias = image_abs_pos_bias.clone()
+ if patch_images_2 is not None:
+ self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \
+ self.get_image_rel_pos_bias(image_position_ids_2, idx)
+ self_attn_bias[:, :, image_num_patches_2:image_num_patches_2+image_num_patches, image_num_patches_2:image_num_patches_2+image_num_patches] += \
+ self.get_image_rel_pos_bias(image_position_ids, idx)
+ elif patch_images is not None:
+ self_attn_bias += \
+ self.get_image_rel_pos_bias(image_position_ids, idx)
+            self_attn_bias = self_attn_bias.reshape(-1, image_x.size(0), image_x.size(0))
+
+ image_x = layer(
+ image_x, encoder_padding_mask=image_padding_mask if has_pads else None, self_attn_bias=self_attn_bias
+ )
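+            # Analogous saturation-based early exit for the image branch, controlled by img_thres.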
+ if not is_train:
+ similarity = torch.cosine_similarity(F.normalize(image_x.clone().contiguous().view(1, -1)),
+ F.normalize(encoder_states_img[-1].clone().contiguous().view(1, -1)))
+ if similarity > kwargs["img_thres"]:
+ break
+ if return_all_hiddens:
+ assert encoder_states_img is not None
+ encoder_states_img.append(image_x)
+
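+        # Re-concatenate the image and text streams along the sequence dimension so the
+        # decoder sees a single encoder output, as in the original OFA encoder.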
+ x = torch.cat([image_x, x])
+
+ if self.layer_norm is not None:
+ x = self.layer_norm(x)
+
+ # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
+ # `forward` so we use a dictionary instead.
+ # TorchScript does not support mixed values so the values are all lists.
+ # The empty list is equivalent to None.
+ return {
+ "encoder_out": [x], # T x B x C
+ "encoder_padding_mask": [encoder_padding_mask_cat], # B x T
+ "encoder_embedding": [], # B x T x C
+ "encoder_states": encoder_states, # List[T x B x C]
+ "src_tokens": [],
+ "src_lengths": [],
+ "position_embeddings": [pos_embed_cat], # B x T x C
+            "exit_layer": [idx+1, idx_y+1]  # [image layers executed, text layers executed]
+ }
+
+
+class TransformerDecoder_MuE(TransformerDecoder):
+ """
+    MuE decoder from `"You Need Multiple Exiting: Dynamic Early Exiting for
+    Accelerating Unified Vision Language Model" (Tang et al., 2022)
+    <https://arxiv.org/abs/2211.11152>`_.
+
+    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
+    is a :class:`TransformerDecoderLayer`.
+
+ Args:
+ args (argparse.Namespace): parsed command-line arguments
+ dictionary (~fairseq.data.Dictionary): decoding dictionary
+ embed_tokens (torch.nn.Embedding): output embedding
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
+ (default: False).
+ """
+
+ def __init__(
+ self,
+ args,
+ dictionary,
+ embed_tokens,
+ no_encoder_attn=False,
+ output_projection=None,
+ ):
+ self.args = args
+ super().__init__(args, dictionary, embed_tokens, no_encoder_attn, output_projection,)
+
+ def forward(
+ self,
+ prev_output_tokens,
+ is_train: bool = True,
+ code_masks: Optional[torch.Tensor] = None,
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ features_only: bool = False,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ src_lengths: Optional[Any] = None,
+ return_all_hiddens: bool = False,
+ **kwargs
+ ):
+ """
+ Args:
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
+ `(batch, tgt_len)`, for teacher forcing
+ encoder_out (optional): output from the encoder, used for
+ encoder-side attention, should be of size T x B x C
+ incremental_state (dict): dictionary used for storing state during
+ :ref:`Incremental decoding`
+ features_only (bool, optional): only return features without
+ applying output layer (default: False).
+ full_context_alignment (bool, optional): don't apply
+ auto-regressive mask to self-attention (default: False).
+
+ Returns:
+ tuple:
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
+ - a dictionary with any model-specific outputs
+ """
+
+ x, extra = self.extract_features(
+ prev_output_tokens,
+ code_masks=code_masks,
+ encoder_out=encoder_out,
+ incremental_state=incremental_state,
+ full_context_alignment=full_context_alignment,
+ alignment_layer=alignment_layer,
+ alignment_heads=alignment_heads,
+ is_train=is_train,
+ **kwargs
+ )
+
+ if not features_only:
+ x = self.output_layer(x)
+ return x, extra
+
+ def extract_features(
+ self,
+ prev_output_tokens,
+ code_masks: Optional[torch.Tensor],
+ encoder_out: Optional[Dict[str, List[Tensor]]],
+ is_train: bool = True,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ **kwargs
+ ):
+ return self.extract_features_scriptable(
+ prev_output_tokens,
+ code_masks,
+ encoder_out,
+ is_train,
+ incremental_state,
+ full_context_alignment,
+ alignment_layer,
+ alignment_heads,
+ **kwargs
+ )
+
+ """
+ A scriptable subclass of this class has an extract_features method and calls
+ super().extract_features, but super() is not supported in torchscript. A copy of
+ this function is made to be used in the subclass instead.
+ """
+
+ def extract_features_scriptable(
+ self,
+ prev_output_tokens,
+ code_masks: Optional[torch.Tensor],
+ encoder_out: Optional[Dict[str, List[Tensor]]],
+ is_train: bool = True,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ **kwargs
+ ):
+ """
+ Similar to *forward* but only return features.
+
+ Includes several features from "Jointly Learning to Align and
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
+
+ Args:
+ full_context_alignment (bool, optional): don't apply
+ auto-regressive mask to self-attention (default: False).
+ alignment_layer (int, optional): return mean alignment over
+ heads at this layer (default: last layer).
+ alignment_heads (int, optional): only average alignment over
+ this many heads (default: all heads).
+
+ Returns:
+ tuple:
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+ - a dictionary with any model-specific outputs
+ """
+ bs, slen = prev_output_tokens.size()
+ if alignment_layer is None:
+ alignment_layer = self.num_layers - 1
+
+ enc: Optional[Tensor] = None
+ padding_mask: Optional[Tensor] = None
+ if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
+ enc = encoder_out["encoder_out"][0]
+ assert (
+ enc.size()[1] == bs
+ ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
+ if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
+ padding_mask = encoder_out["encoder_padding_mask"][0]
+
+ bsz, tgt_len = prev_output_tokens.shape
+ token_position_idx = utils.new_arange(prev_output_tokens)
+ tgt_pos_embed = self.embed_positions(token_position_idx)
+ if code_masks is not None and torch.any(code_masks):
+ image_position_idx = self.image_position_idx[:prev_output_tokens.size(1)].unsqueeze(0).expand(bsz, tgt_len)
+ tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[code_masks]
+
+ # self attn position bias
+ self_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=False)
+
+ if code_masks is not None and torch.any(code_masks):
+ self_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=True)
+ self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks]
+ # cross attn position bias
+ src_pos_embed = encoder_out['position_embeddings'][0]
+ cross_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed)
+ if code_masks is not None and torch.any(code_masks):
+ cross_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True)
+ cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[code_masks]
+ cross_abs_pos_bias = cross_abs_pos_bias.reshape(-1, *cross_abs_pos_bias.size()[-2:])
+
+ all_prev_output_tokens = prev_output_tokens.clone()
+ if incremental_state is not None:
+ prev_output_tokens = prev_output_tokens[:, -1:]
+ cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :]
+ tgt_pos_embed = tgt_pos_embed[:, -1:, :]
+
+ # embed tokens and positions
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
+
+ if self.quant_noise is not None:
+ x = self.quant_noise(x)
+
+ if self.project_in_dim is not None:
+ x = self.project_in_dim(x)
+
+ if self.entangle_position_embedding is not None and not self.args.disable_entangle:
+ x += tgt_pos_embed
+
+ if self.layernorm_embedding is not None:
+ if code_masks is None or not code_masks.any() or not getattr(self, "code_layernorm_embedding", False):
+ x = self.layernorm_embedding(x)
+ elif code_masks is not None and code_masks.all():
+ x = self.code_layernorm_embedding(x)
+ else:
+ x[~code_masks] = self.layernorm_embedding(x[~code_masks])
+ x[code_masks] = self.code_layernorm_embedding(x[code_masks])
+
+ x = self.dropout_module(x)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ self_attn_padding_mask: Optional[Tensor] = None
+ if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
+ self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
+
+ # decoder layers
+ attn: Optional[Tensor] = None
+ inner_states: List[Optional[Tensor]] = [x]
+ inner_out_states: List[Optional[Tensor]] = []
+ for idx, layer in enumerate(self.layers):
+ if incremental_state is None and not full_context_alignment:
+ self_attn_mask = self.buffered_future_mask(x)
+ else:
+ self_attn_mask = None
+
+ self_attn_bias = self_abs_pos_bias.clone()
+ if code_masks is None or not code_masks.any():
+ self_attn_bias += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
+ elif code_masks is not None and code_masks.all():
+ self_attn_bias += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
+ else:
+ self_attn_bias[~code_masks] += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
+ self_attn_bias[code_masks] += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
+ self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:])
+ if incremental_state is not None:
+ self_attn_bias = self_attn_bias[:, -1:, :]
+
+ x, layer_attn, saved_states = layer(
+ x,
+ enc,
+ padding_mask,
+ incremental_state,
+ self_attn_mask=self_attn_mask,
+ self_attn_padding_mask=self_attn_padding_mask,
+ need_attn=bool((idx == alignment_layer)),
+ need_head_weights=bool((idx == alignment_layer)),
+ self_attn_bias=self_attn_bias,
+ cross_attn_bias=cross_abs_pos_bias
+ )
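+            # Decoder-side early exit at inference: stop stacking layers once the hidden
+            # states saturate (cosine similarity above decoder_thres).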
+ if not is_train:
+ similarity = torch.cosine_similarity(F.normalize(x.clone().contiguous().view(1, -1)),
+ F.normalize(inner_states[-1].clone().contiguous().view(1, -1)))
+ if similarity > kwargs["decoder_thres"]:
+ if saved_states is not None:
+                        # Copy the exit layer's cached state to all skipped layers so incremental
+                        # decoding stays consistent. Simple but effective; better methods could be explored.
+ for i in range(idx + 1, len(self.layers)):
+ incremental_state = self.layers[i].self_attn._set_input_buffer(incremental_state, saved_states[0])
+ incremental_state = self.layers[i].encoder_attn._set_input_buffer(incremental_state, saved_states[1])
+ break
+ inner_states.append(x)
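+            # Project each layer's normalized hidden states to the vocabulary so that
+            # MuE_Task_Loss can supervise every potential exit layer during training.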
+ inner_out_states.append(self.output_layer(self.layer_norm(x).transpose(0, 1)))
+ if layer_attn is not None and idx == alignment_layer:
+ attn = layer_attn.float().to(x)
+
+ if attn is not None:
+ if alignment_heads is not None:
+ attn = attn[:alignment_heads]
+
+ # average probabilities over heads
+ attn = attn.mean(dim=0)
+
+ if self.layer_norm is not None:
+ x = self.layer_norm(x)
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+
+ if self.project_out_dim is not None:
+ x = self.project_out_dim(x)
+
+ return x, {"attn": [attn], "inner_states": inner_states, "exit_layer": idx + 1, "inner_out_states": inner_out_states}
+
def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
@@ -1733,3 +2321,4 @@ def base_architecture(args):
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
+ args.train_mue = getattr(args, "train_mue", False)
diff --git a/models/sequence_generator.py b/models/sequence_generator.py
index 13d5bdf7..4560bc1b 100644
--- a/models/sequence_generator.py
+++ b/models/sequence_generator.py
@@ -213,6 +213,7 @@ def _generate(
prefix_tokens: Optional[Tensor] = None,
constraints: Optional[Tensor] = None,
bos_token: Optional[int] = None,
+ **kwargs
):
model = EnsembleModel(models)
incremental_states = torch.jit.annotate(
@@ -223,7 +224,15 @@ def _generate(
],
)
net_input = sample["net_input"]
-
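+        # Forward the MuE early-exit thresholds and the is_train flag through net_input
+        # so they reach the MuE encoder's forward via **kwargs.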
+ if "img_thres" in kwargs:
+ net_input["img_thres"] = kwargs["img_thres"]
+ if "txt_thres" in kwargs:
+ net_input["txt_thres"] = kwargs["txt_thres"]
+ if "is_train" in kwargs:
+ net_input["is_train"] = kwargs["is_train"]
+ if "decoder_thres" in kwargs:
+ net_input["decoder_thres"] = kwargs["decoder_thres"]
+
if "src_tokens" in net_input:
src_tokens = net_input["src_tokens"]
# length of the source text being the character length except EndOfSentence and pad
@@ -333,6 +342,16 @@ def _generate(
original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
for step in range(max_len + 1): # one extra step for EOS marker
+
+            # Decay the early-exit threshold for the decoder layers over decoding steps,
+            # a small modification compared with the original paper.
+            # Details can be found in https://arxiv.org/abs/2211.11152
+            if step < (max_len + 1) / 3 or "decoder_thres" not in kwargs:
+                decoder_thres = 1.0
+            else:
+                decoder_thres = 0.9 * kwargs["decoder_thres"] + \
+                    0.1 * math.exp(-1 * (step + 1) / (max_len + 1))
+
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
if batch_idxs is not None:
@@ -359,7 +378,8 @@ def _generate(
constraint_end=self.constraint_end,
gen_code=self.gen_code,
zero_shot=self.zero_shot,
- prefix_tokens=prefix_tokens
+ prefix_tokens=prefix_tokens,
+                    decoder_thres=decoder_thres
)
if self.lm_model is not None:
@@ -818,7 +838,8 @@ def forward_decoder(
constraint_end=None,
gen_code=False,
zero_shot=False,
- prefix_tokens=None
+ prefix_tokens=None,
+        decoder_thres=1.0
):
log_probs = []
avg_attn: Optional[Tensor] = None
@@ -834,6 +855,8 @@ def forward_decoder(
code_masks=code_mask,
encoder_out=encoder_out,
incremental_state=incremental_states[i],
+ is_train=False,
+                    decoder_thres=decoder_thres
)
else:
if hasattr(model, "decoder"):
diff --git a/run_scripts/caption/evaluate_caption_base_MuE.sh b/run_scripts/caption/evaluate_caption_base_MuE.sh
new file mode 100755
index 00000000..dca52733
--- /dev/null
+++ b/run_scripts/caption/evaluate_caption_base_MuE.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+
+# The port for communication. Note that if you want to run multiple tasks on the same machine,
+# you need to specify different port numbers.
+export MASTER_PORT=1091
+
+user_dir=../../ofa_module
+bpe_dir=../../utils/BPE
+
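+# Replace these with the paths to your own caption test set and finetuned MuE checkpoint.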
+data=/data/tsk/caption_data/caption_test.tsv
+path=/home/sht22008/tsk/projects/OFA/run_scripts/caption/stage1_checkpoints/{5,}_{0.06,}_{6000,}/checkpoint.best_cider_0.3240.pt
+result_path=../../results/caption
+selected_cols=1,4,2
+split='test'
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=${MASTER_PORT} ../../evaluate.py \
+ ${data} \
+ --path=${path} \
+ --user-dir=${user_dir} \
+ --task=caption \
+ --batch-size=1 \
+ --log-format=simple --log-interval=10 \
+ --seed=7 \
+ --gen-subset=${split} \
+ --results-path=${result_path} \
+ --beam=5 \
+ --max-len-b=16 \
+ --no-repeat-ngram-size=3 \
+ --fp16 \
+ --num-workers=0 \
+    --img_thres=0.99 \
+    --txt_thres=0.7 \
+    --decoder_thres=0.9 \
+ --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"eval_cider\":False,\"selected_cols\":\"${selected_cols}\"}"
+
+python coco_eval.py ../../results/caption/test_predict.json ../../dataset/caption_data/test_caption_coco_format.json
diff --git a/run_scripts/caption/train_caption_stage1_base_MuE.sh b/run_scripts/caption/train_caption_stage1_base_MuE.sh
new file mode 100644
index 00000000..1d05d6c6
--- /dev/null
+++ b/run_scripts/caption/train_caption_stage1_base_MuE.sh
@@ -0,0 +1,109 @@
+#!/usr/bin/env bash
+
+# The port for communication. Note that if you want to run multiple tasks on the same machine,
+# you need to specify different port numbers.
+export MASTER_PORT=1061
+
+log_dir=./stage1_logs
+save_dir=./stage1_checkpoints
+mkdir -p $log_dir $save_dir
+
+bpe_dir=../../utils/BPE
+user_dir=../../ofa_module
+
+data_dir=../../dataset/caption_data
+data=${data_dir}/caption_stage1_train.tsv,${data_dir}/caption_val.tsv
+restore_file=../../checkpoints/ofa_base.pt
+selected_cols=0,4,2
+
+task=caption
+arch=ofa_base
+criterion=MuE_Task_Loss
+label_smoothing=0.1
+lr=1e-5
+max_epoch=5
+warmup_ratio=0.06
+batch_size=2
+update_freq=4
+resnet_drop_path_rate=0.0
+encoder_drop_path_rate=0.1
+decoder_drop_path_rate=0.1
+dropout=0.1
+attention_dropout=0.0
+max_src_length=80
+max_tgt_length=20
+num_bins=1000
+patch_image_size=480
+eval_cider_cached=${data_dir}/cider_cached_tokens/coco-valid-words.p
+drop_worst_ratio=0.2
+
+for max_epoch in {5,}; do
+ echo "max_epoch "${max_epoch}
+ for warmup_ratio in {0.06,}; do
+ echo "warmup_ratio "${warmup_ratio}
+ for drop_worst_after in {6000,}; do
+ echo "drop_worst_after "${drop_worst_after}
+
+ log_file=${log_dir}/${max_epoch}"_"${warmup_ratio}"_"${drop_worst_after}".log"
+ save_path=${save_dir}/${max_epoch}"_"${warmup_ratio}"_"${drop_worst_after}
+ mkdir -p $save_path
+
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=${MASTER_PORT} ../../train.py \
+ $data \
+ --selected-cols=${selected_cols} \
+ --bpe-dir=${bpe_dir} \
+ --user-dir=${user_dir} \
+ --restore-file=${restore_file} \
+ --reset-optimizer --reset-dataloader --reset-meters \
+ --save-dir=${save_path} \
+ --task=${task} \
+ --arch=${arch} \
+ --criterion=${criterion} \
+ --label-smoothing=${label_smoothing} \
+ --batch-size=${batch_size} \
+ --update-freq=${update_freq} \
+ --encoder-normalize-before \
+ --decoder-normalize-before \
+ --share-decoder-input-output-embed \
+ --share-all-embeddings \
+ --layernorm-embedding \
+ --patch-layernorm-embedding \
+ --code-layernorm-embedding \
+ --resnet-drop-path-rate=${resnet_drop_path_rate} \
+ --encoder-drop-path-rate=${encoder_drop_path_rate} \
+ --decoder-drop-path-rate=${decoder_drop_path_rate} \
+ --dropout=${dropout} \
+ --attention-dropout=${attention_dropout} \
+ --weight-decay=0.01 --optimizer=adam --adam-betas="(0.9,0.999)" --adam-eps=1e-08 --clip-norm=1.0 \
+ --lr-scheduler=polynomial_decay --lr=${lr} \
+ --max-epoch=${max_epoch} --warmup-ratio=${warmup_ratio} \
+ --log-format=simple --log-interval=10 \
+ --fixed-validation-seed=7 \
+ --no-epoch-checkpoints --keep-best-checkpoints=1 \
+ --save-interval=1 --validate-interval=1 \
+ --save-interval-updates=500 --validate-interval-updates=500 \
+ --eval-cider \
+ --eval-cider-cached-tokens=${eval_cider_cached} \
+ --eval-args='{"beam":5,"max_len_b":16,"no_repeat_ngram_size":3}' \
+ --best-checkpoint-metric=cider --maximize-best-checkpoint-metric \
+ --max-src-length=${max_src_length} \
+ --max-tgt-length=${max_tgt_length} \
+ --find-unused-parameters \
+ --freeze-encoder-embedding \
+ --freeze-decoder-embedding \
+ --add-type-embedding \
+ --scale-attn \
+ --scale-fc \
+ --scale-heads \
+ --disable-entangle \
+ --num-bins=${num_bins} \
+ --patch-image-size=${patch_image_size} \
+ --drop-worst-ratio=${drop_worst_ratio} \
+ --drop-worst-after=6000 \
+ --fp16 \
+ --fp16-scale-window=512 \
+      --train_mue \
+ --num-workers=0 > ${log_file} 2>&1
+ done
+ done
+done
\ No newline at end of file
diff --git a/run_scripts/snli_ve/evaluate_snli_ve_base_MuE.sh b/run_scripts/snli_ve/evaluate_snli_ve_base_MuE.sh
new file mode 100755
index 00000000..5b9631c1
--- /dev/null
+++ b/run_scripts/snli_ve/evaluate_snli_ve_base_MuE.sh
@@ -0,0 +1,33 @@
+#!/usr/bin/env bash
+
+# The port for communication. Note that if you want to run multiple tasks on the same machine,
+# you need to specify different port numbers.
+export MASTER_PORT=7091
+
+user_dir=../../ofa_module
+bpe_dir=../../utils/BPE
+
+# dev or test
+split=test
+
+data=../../dataset/snli_ve_data/snli_ve_${split}.tsv
+path=../../checkpoints/snli_ve_base_best.pt
+result_path=../../results/snli_ve
+selected_cols=0,2,3,4,5
+
+CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=${MASTER_PORT} ../../evaluate.py \
+ ${data} \
+ --path=${path} \
+ --user-dir=${user_dir} \
+ --task=snli_ve \
+ --batch-size=1 \
+ --log-format=simple --log-interval=10 \
+ --seed=7 \
+ --gen-subset=${split} \
+ --results-path=${result_path} \
+ --fp16 \
+ --num-workers=0 \
+    --img_thres=0.7 \
+    --txt_thres=0.99 \
+    --decoder_thres=0.7 \
+ --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
\ No newline at end of file
diff --git a/run_scripts/snli_ve/train_snli_ve_base_MuE.sh b/run_scripts/snli_ve/train_snli_ve_base_MuE.sh
new file mode 100644
index 00000000..fad48e01
--- /dev/null
+++ b/run_scripts/snli_ve/train_snli_ve_base_MuE.sh
@@ -0,0 +1,100 @@
+#!/usr/bin/env bash
+
+# The port for communication. Note that if you want to run multiple tasks on the same machine,
+# you need to specify different port numbers.
+export MASTER_PORT=7061
+
+log_dir=./logs
+save_dir=./checkpoints
+mkdir -p $log_dir $save_dir
+
+bpe_dir=../../utils/BPE
+user_dir=../../ofa_module
+
+data_dir=../../dataset/snli_ve_data
+data=${data_dir}/snli_ve_train.tsv,${data_dir}/snli_ve_dev.tsv
+restore_file=../../checkpoints/ofa_base.pt
+selected_cols=0,2,3,4,5
+
+task=snli_ve
+arch=ofa_base
+criterion=MuE_Task_Loss
+label_smoothing=0.0
+lr=3e-5
+max_epoch=5
+warmup_ratio=0.06
+batch_size=4
+update_freq=8
+resnet_drop_path_rate=0.0
+encoder_drop_path_rate=0.1
+decoder_drop_path_rate=0.1
+dropout=0.1
+attention_dropout=0.0
+max_src_length=80
+max_tgt_length=20
+num_bins=1000
+patch_image_size=480
+prompt_type="prev_output"
+
+for max_epoch in {5,}; do
+ echo "max_epoch "${max_epoch}
+ for lr in {5e-5,}; do
+ echo "lr "${lr}
+
+ log_file=${log_dir}/${max_epoch}"_"${lr}".log"
+ save_path=${save_dir}/${max_epoch}"_"${lr}
+ mkdir -p $save_path
+
+ CUDA_VISIBLE_DEVICES=0,1 python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=${MASTER_PORT} ../../train.py \
+ $data \
+ --selected-cols=${selected_cols} \
+ --bpe-dir=${bpe_dir} \
+ --user-dir=${user_dir} \
+ --restore-file=${restore_file} \
+ --reset-optimizer --reset-dataloader --reset-meters \
+ --save-dir=${save_path} \
+ --task=${task} \
+ --arch=${arch} \
+ --criterion=${criterion} \
+ --label-smoothing=${label_smoothing} \
+ --batch-size=${batch_size} \
+ --update-freq=${update_freq} \
+ --encoder-normalize-before \
+ --decoder-normalize-before \
+ --share-decoder-input-output-embed \
+ --share-all-embeddings \
+ --layernorm-embedding \
+ --patch-layernorm-embedding \
+ --code-layernorm-embedding \
+ --resnet-drop-path-rate=${resnet_drop_path_rate} \
+ --encoder-drop-path-rate=${encoder_drop_path_rate} \
+ --decoder-drop-path-rate=${decoder_drop_path_rate} \
+ --dropout=${dropout} \
+ --attention-dropout=${attention_dropout} \
+ --weight-decay=0.01 --optimizer=adam --adam-betas="(0.9,0.999)" --adam-eps=1e-08 --clip-norm=1.0 \
+ --lr-scheduler=polynomial_decay --lr=${lr} \
+ --max-epoch=${max_epoch} --warmup-ratio=${warmup_ratio} \
+ --log-format=simple --log-interval=10 \
+ --fixed-validation-seed=7 \
+ --keep-best-checkpoints=1 \
+ --save-interval=1 --validate-interval=1 \
+ --save-interval-updates=500 --validate-interval-updates=500 \
+ --best-checkpoint-metric=snli_score --maximize-best-checkpoint-metric \
+ --max-src-length=${max_src_length} \
+ --max-tgt-length=${max_tgt_length} \
+ --find-unused-parameters \
+ --add-type-embedding \
+ --scale-attn \
+ --scale-fc \
+ --scale-heads \
+ --disable-entangle \
+ --num-bins=${num_bins} \
+ --patch-image-size=${patch_image_size} \
+ --prompt-type=${prompt_type} \
+ --add-caption \
+ --fp16 \
+ --fp16-scale-window=512 \
+      --train_mue \
+ --num-workers=0 > ${log_file} 2>&1
+ done
+done
\ No newline at end of file
diff --git a/tasks/mm_tasks/caption.py b/tasks/mm_tasks/caption.py
index f37c91be..8b67094f 100644
--- a/tasks/mm_tasks/caption.py
+++ b/tasks/mm_tasks/caption.py
@@ -227,7 +227,7 @@ def decode(toks, escape_unk=False):
s = self.bpe.decode(s)
return s
- gen_out = self.inference_step(generator, [model], sample)
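+        # Pass is_train=True so early exiting is disabled and validation during
+        # finetuning scores the full model.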
+ gen_out = self.inference_step(generator, [model], sample, is_train=True)
hyps, refs = [], []
transtab = str.maketrans({key: None for key in string.punctuation})
for i in range(len(gen_out)):
diff --git a/utils/eval_utils.py b/utils/eval_utils.py
index a8bd9f7c..6b2e3c7e 100644
--- a/utils/eval_utils.py
+++ b/utils/eval_utils.py
@@ -47,7 +47,7 @@ def _calculate_error_rate(hyps, refs):
def eval_caption(task, generator, models, sample, **kwargs):
transtab = str.maketrans({key: None for key in string.punctuation})
- hypos = task.inference_step(generator, models, sample)
+ hypos = task.inference_step(generator, models, sample, **kwargs)
results = []
for i, sample_id in enumerate(sample["id"].tolist()):
detok_hypo_str = decode_fn(hypos[i][0]["tokens"], task.tgt_dict, task.bpe, generator)
@@ -56,7 +56,7 @@ def eval_caption(task, generator, models, sample, **kwargs):
def eval_caption_cn(task, generator, models, sample, **kwargs):
- hypos = task.inference_step(generator, models, sample)
+ hypos = task.inference_step(generator, models, sample, **kwargs)
results = []
for i, sample_id in enumerate(sample["id"].tolist()):
detok_hypo_str = decode_fn(
@@ -72,7 +72,7 @@ def eval_caption_cn(task, generator, models, sample, **kwargs):
def eval_ocr(task, generator, models, sample, **kwargs):
- gen_out = task.inference_step(generator, models, sample)
+ gen_out = task.inference_step(generator, models, sample, **kwargs)
hyps, refs, results = [], [], []
for i, sample_id in enumerate(sample["id"].tolist()):
decode_tokens = decode_fn(gen_out[i][0]["tokens"], task.tgt_dict, task.bpe, generator).strip()
@@ -102,7 +102,7 @@ def eval_ocr(task, generator, models, sample, **kwargs):
def eval_vqa_gen(task, generator, models, sample, **kwargs):
if kwargs['beam_search_vqa_eval']:
- hypos = task.inference_step(generator, models, sample, prefix_tokens=sample['prefix_tokens'])
+ hypos = task.inference_step(generator, models, sample, prefix_tokens=sample['prefix_tokens'], **kwargs)
results = []
for i, sample_id in enumerate(sample["id"].tolist()):
prefix_len = sample['prefix_tokens'][i].ne(1).sum().item()
@@ -185,7 +185,7 @@ def _calculate_ap_score(hyps, refs, thresh=0.5):
ious = area_interacts / (area_predictions + area_targets - area_interacts + 1e-6)
return ((ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)).float()
- gen_out = task.inference_step(generator, models, sample)
+ gen_out = task.inference_step(generator, models, sample, **kwargs)
hyps = []
for i in range(len(gen_out)):
hyps.append(gen_out[i][0]["tokens"][:-1] - len(task.src_dict) + task.cfg.num_bins)
@@ -208,7 +208,8 @@ def eval_snli_ve(task, generator, models, sample, **kwargs):
sample["net_input"]["src_tokens"],
src_lengths=sample["net_input"]["src_lengths"],
patch_images=sample["net_input"]["patch_images"],
- patch_masks=sample["net_input"]["patch_masks"]
+ patch_masks=sample["net_input"]["patch_masks"],
+ **kwargs
)
device = sample["net_input"]["src_tokens"].device
eos_item = torch.tensor([task.src_dict.eos()])
@@ -246,7 +247,7 @@ def eval_snli_ve(task, generator, models, sample, **kwargs):
encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0)
]
- decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out)
+ decoder_out = models[0].decoder(valid_prev_output, encoder_out=new_encoder_out, **kwargs)
decoder_out[0].masked_fill_(~valid_constraint_masks, -math.inf)
lprobs = models[0].get_normalized_probs(decoder_out, log_probs=True)
scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)