@@ -225,6 +225,8 @@ def add_args(parser):
225225 help = 'freeze decoder token embedding' )
226226 parser .add_argument ('--add-type-embedding' , action = 'store_true' ,
227227 help = 'add source/region/patch type embedding' )
228+ parser .add_argument ('--interpolate-position' , action = 'store_true' ,
229+ help = 'interpolate position' )
228230
229231 parser .add_argument ('--resnet-type' , choices = ['resnet50' , 'resnet101' , 'resnet152' ],
230232 help = 'resnet type' )
@@ -498,6 +500,9 @@ def __init__(self, args, dictionary, embed_tokens):
498500 [Embedding (image_num_rel_dis , self .num_attention_heads , zero_init = True ) for _ in range (args .encoder_layers )]
499501 )
500502
503+ self .patch_image_size = args .patch_image_size
504+ self .orig_patch_image_size = args .orig_patch_image_size
505+
501506 self .register_buffer ("token_rp_bucket" , token_rp_bucket )
502507 self .register_buffer ("image_rp_bucket" , image_rp_bucket )
503508 self .entangle_position_embedding = args .entangle_position_embedding
@@ -560,7 +565,19 @@ def get_patch_images_info(self, patch_images, sample_patch_num, device):
560565 image_num_patches = sample_patch_num
561566 image_padding_mask = image_padding_mask .gather (1 , patch_orders )
562567 image_position_ids = image_position_ids .gather (1 , patch_orders )
563- image_pos_embed = self .embed_image_positions (image_position_ids )
568+ orig_num_patches = (self .orig_patch_image_size // 16 ) ** 2
569+ orig_hw = self .orig_patch_image_size // 16
570+ if getattr (self .args , "interpolate_position" , False ) and image_num_patches > orig_num_patches :
571+ old_image_position_ids = torch .arange (orig_hw ).unsqueeze (0 ).expand (orig_hw , orig_hw ) + \
572+ torch .arange (orig_hw ).unsqueeze (1 ) * self .args .image_bucket_size + 1
573+ old_image_position_ids = old_image_position_ids .to (device )
574+ old_image_pos_embed = self .embed_image_positions (old_image_position_ids )
575+ old_image_pos_embed = old_image_pos_embed .reshape (1 , orig_hw , orig_hw , - 1 ).permute (0 , 3 , 1 , 2 )
576+ image_pos_embed = F .interpolate (old_image_pos_embed , size = (h , w ), mode = 'bilinear' )
577+ image_pos_embed = image_pos_embed .permute (0 , 2 , 3 , 1 ).reshape (1 , image_num_patches , - 1 )
578+ image_pos_embed = image_pos_embed .expand (patch_images .size (0 ), - 1 , - 1 )
579+ else :
580+ image_pos_embed = self .embed_image_positions (image_position_ids )
564581
565582 return image_embed , image_num_patches , image_padding_mask , image_position_ids , image_pos_embed
566583
0 commit comments