Commit 292ea26

add cn ofa

1 parent e87ee62

File tree: 9 files changed (+42354 −16 lines)


data/mm_data/caption_dataset.py

Lines changed: 6 additions & 1 deletion

@@ -113,6 +113,11 @@ def __init__(
             transforms.Normalize(mean=mean, std=std),
         ])
 
+        if type(bpe).__name__ == 'GPT2BPE':
+            self.prompt = " what does the image describe?"
+        elif type(bpe).__name__ == 'BertBPE':
+            self.prompt = "图片描述了什么内容?"
+
     def __getitem__(self, index):
         uniq_id, image, caption = self.dataset[index]
 
@@ -128,7 +133,7 @@ def __getitem__(self, index):
         caption = ' '.join(caption.strip().split())
         caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
         tgt_caption = '&&'.join(caption_list)
-        src_item = self.encode_text(" what does the image describe?")
+        src_item = self.encode_text(self.prompt)
         tgt_item = self.encode_text(" {}".format(tgt_caption))
 
         src_item = torch.cat([self.bos_item, src_item, self.eos_item])
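For reference, the Chinese prompt "图片描述了什么内容?" translates to "What does the image describe?", the same question as the English prompt. A minimal sketch of the dispatch above, with hypothetical stub classes standing in for fairseq's tokenizers:

class GPT2BPE: pass   # stand-in for fairseq's GPT-2 BPE tokenizer
class BertBPE: pass   # stand-in for the Chinese BERT tokenizer

def pick_caption_prompt(bpe):
    # Select the English or Chinese captioning prompt from the tokenizer's class name.
    if type(bpe).__name__ == 'GPT2BPE':
        return " what does the image describe?"
    elif type(bpe).__name__ == 'BertBPE':
        return "图片描述了什么内容?"  # "What does the image describe?"
    raise ValueError("unsupported BPE: {}".format(type(bpe).__name__))

assert pick_caption_prompt(GPT2BPE()) == " what does the image describe?"
assert pick_caption_prompt(BertBPE()) == "图片描述了什么内容?"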

data/mm_data/refcoco_dataset.py

Lines changed: 6 additions & 1 deletion

@@ -118,6 +118,11 @@ def __init__(
             T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
         ])
 
+        if type(bpe).__name__ == 'GPT2BPE':
+            self.prompt = ' which region does the text " {} " describe?'
+        elif type(bpe).__name__ == 'BertBPE':
+            self.prompt = '这段文字" {} "描述的是哪个区域?'
+
     def __getitem__(self, index):
         uniq_id, base64_str, text, region_coord = self.dataset[index]
 
@@ -139,7 +144,7 @@ def __getitem__(self, index):
         quant_y1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][3] * (self.num_bins - 1)).round()))
         region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
         src_caption = self.pre_caption(text, self.max_src_length)
-        src_item = self.encode_text(' which region does the text " {} " describe?'.format(src_caption))
+        src_item = self.encode_text(self.prompt.format(src_caption))
         tgt_item = self.encode_text(region_coord, use_bpe=False)
 
         src_item = torch.cat([self.bos_item, src_item, self.eos_item])
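The Chinese prompt translates to 'Which region does the text " {} " describe?', mirroring the English one. The context lines above also show how region coordinates become <bin_k> location tokens; a minimal sketch of that quantization, assuming box coordinates already normalized to [0, 1]:

def quantize_coord(coord, num_bins=1000):
    # Map a normalized coordinate in [0, 1] to one of num_bins discrete location tokens.
    return "<bin_{}>".format(int(round(coord * (num_bins - 1))))

# Hypothetical normalized box (x0, y0, x1, y1):
box = (0.12, 0.30, 0.58, 0.91)
region_coord = " ".join(quantize_coord(c) for c in box)
print(region_coord)  # <bin_120> <bin_300> <bin_579> <bin_909>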

data/nlg_data/summary_dataset.py

Lines changed: 7 additions & 2 deletions

@@ -81,6 +81,11 @@ def __init__(
         self.num_bins = num_bins
         self.noise_ratio = noise_ratio
 
+        if type(bpe).__name__ == 'GPT2BPE':
+            self.prompt = ' what is the summary of article " {} "?'
+        elif type(bpe).__name__ == 'BertBPE':
+            self.prompt = "{} 请用一个句子简单总结上文:"
+
     def __getitem__(self, index):
         source, target = self.dataset[index]
         target_str = target.lower()
@@ -91,10 +96,10 @@ def __getitem__(self, index):
         target = target.replace('<unk>', 'unk')
 
         src_item = self.encode_text(
-            ' what is the summary of article " {} "?'.format(source),
+            self.prompt.format(source),
             length=self.max_src_length
         )
-        tgt_item = self.encode_text(' {}'.format(target))
+        tgt_item = self.encode_text('{}'.format(target))
         noise_tgt_item = self.add_noise_to_tgt(tgt_item.clone(), self.noise_ratio)
 
         src_item = torch.cat([self.bos_item, src_item, self.eos_item])
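The Chinese template translates to "{} Please summarize the above text in one sentence:" and, unlike the English template, places the article before the instruction. The target encoding also drops its leading space, presumably because the Chinese BERT vocabulary has no space-prefixed tokens. A small sketch of how each template fills in, with a hypothetical article string:

en_prompt = ' what is the summary of article " {} "?'
cn_prompt = "{} 请用一个句子简单总结上文:"  # "{} Please summarize the above text in one sentence:"

article = "an example article body"  # hypothetical input
print(en_prompt.format(article))  # ' what is the summary of article " an example article body "?'
print(cn_prompt.format(article))  # 'an example article body 请用一个句子简单总结上文:'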

data/ofa_dataset.py

Lines changed: 4 additions & 4 deletions

@@ -42,7 +42,7 @@ def encode_text(self, text, length=None, append_bos=False, append_eos=False, use
             s = torch.cat([s, self.eos_item])
         return s
 
-    def pre_question(self, question, max_ques_words):
+    def pre_question(self, question, max_ques_words=None):
         question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
 
         question = re.sub(
@@ -55,12 +55,12 @@ def pre_question(self, question, max_ques_words):
 
         # truncate question
         question_words = question.split(' ')
-        if len(question_words) > max_ques_words:
+        if max_ques_words is not None and len(question_words) > max_ques_words:
             question = ' '.join(question_words[:max_ques_words])
 
         return question
 
-    def pre_caption(self, caption, max_words):
+    def pre_caption(self, caption, max_words=None):
         caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
 
         caption = re.sub(
@@ -73,7 +73,7 @@ def pre_caption(self, caption, max_words):
 
         # truncate caption
         caption_words = caption.split(' ')
-        if len(caption_words) > max_words:
+        if max_words is not None and len(caption_words) > max_words:
             caption = ' '.join(caption_words[:max_words])
 
         return caption
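Making the word caps optional matters for Chinese input: truncation splits on spaces, so unsegmented Chinese text counts as a single "word" and a word limit is meaningless. A minimal sketch of the guarded truncation above:

def truncate_words(text, max_words=None):
    # Mirrors the guarded truncation: skip entirely when no limit is given.
    words = text.split(' ')
    if max_words is not None and len(words) > max_words:
        text = ' '.join(words[:max_words])
    return text

print(truncate_words("a cat sitting on a mat", max_words=3))  # 'a cat sitting'
print(truncate_words("一只坐在垫子上的猫"))  # unchanged: no limit, and no spaces to split on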

models/ofa/unify_transformer.py

Lines changed: 18 additions & 1 deletion

@@ -225,6 +225,8 @@ def add_args(parser):
                             help='freeze decoder token embedding')
         parser.add_argument('--add-type-embedding', action='store_true',
                             help='add source/region/patch type embedding')
+        parser.add_argument('--interpolate-position', action='store_true',
+                            help='interpolate image position embeddings when inputs exceed the pretraining resolution')
 
         parser.add_argument('--resnet-type', choices=['resnet50', 'resnet101', 'resnet152'],
                             help='resnet type')
@@ -498,6 +500,9 @@ def __init__(self, args, dictionary, embed_tokens):
             [Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)]
         )
 
+        self.patch_image_size = args.patch_image_size
+        self.orig_patch_image_size = args.orig_patch_image_size
+
         self.register_buffer("token_rp_bucket", token_rp_bucket)
         self.register_buffer("image_rp_bucket", image_rp_bucket)
         self.entangle_position_embedding = args.entangle_position_embedding
@@ -560,7 +565,19 @@ def get_patch_images_info(self, patch_images, sample_patch_num, device):
             image_num_patches = sample_patch_num
             image_padding_mask = image_padding_mask.gather(1, patch_orders)
             image_position_ids = image_position_ids.gather(1, patch_orders)
-        image_pos_embed = self.embed_image_positions(image_position_ids)
+        orig_num_patches = (self.orig_patch_image_size // 16) ** 2
+        orig_hw = self.orig_patch_image_size // 16
+        if getattr(self.args, "interpolate_position", False) and image_num_patches > orig_num_patches:
+            old_image_position_ids = torch.arange(orig_hw).unsqueeze(0).expand(orig_hw, orig_hw) + \
+                                     torch.arange(orig_hw).unsqueeze(1) * self.args.image_bucket_size + 1
+            old_image_position_ids = old_image_position_ids.to(device)
+            old_image_pos_embed = self.embed_image_positions(old_image_position_ids)
+            old_image_pos_embed = old_image_pos_embed.reshape(1, orig_hw, orig_hw, -1).permute(0, 3, 1, 2)
+            image_pos_embed = F.interpolate(old_image_pos_embed, size=(h, w), mode='bilinear')
+            image_pos_embed = image_pos_embed.permute(0, 2, 3, 1).reshape(1, image_num_patches, -1)
+            image_pos_embed = image_pos_embed.expand(patch_images.size(0), -1, -1)
+        else:
+            image_pos_embed = self.embed_image_positions(image_position_ids)
 
         return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
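The new interpolate-position branch resizes the learned patch-position embeddings when the inference resolution exceeds the pretraining resolution: with the config defaults, a 480×480 input yields a 30×30 grid of stride-16 patches versus the 16×16 grid of a 256×256 input. A self-contained sketch of the reshape, interpolate, and flatten round trip, with made-up tensor sizes:

import torch
import torch.nn.functional as F

dim, orig_hw, h, w = 8, 16, 30, 30           # hypothetical embedding dim and grid sizes
old = torch.randn(1, orig_hw, orig_hw, dim)  # stand-in for looked-up position embeddings
old = old.permute(0, 3, 1, 2)                # -> (1, dim, orig_hw, orig_hw) for F.interpolate
new = F.interpolate(old, size=(h, w), mode='bilinear')
new = new.permute(0, 2, 3, 1).reshape(1, h * w, dim)  # back to one embedding per patch
print(new.shape)  # torch.Size([1, 900, 8])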

Lines changed: 34 additions & 0 deletions (new file)

@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+
+# The port for communication. Note that if you want to run multiple tasks on the same machine,
+# you need to specify different port numbers.
+export MASTER_PORT=6081
+export CUDA_VISIBLE_DEVICES=7
+export GPUS_PER_NODE=1
+
+user_dir=../../ofa_module
+bpe_dir=../../utils/BERT_CN_dict
+selected_cols=0,3,1,2
+
+data=../../dataset/refcoco_cn_data/refcoco+_test_sample.tsv
+path=../../checkpoints/refcocoplus_cn_large.pt
+result_path=../../results/refcoco
+split='refcoco_val'
+python3 ../../evaluate.py \
+    ${data} \
+    --path=${path} \
+    --user-dir=${user_dir} \
+    --task=refcoco \
+    --batch-size=16 \
+    --log-format=simple --log-interval=10 \
+    --seed=7 \
+    --gen-subset=${split} \
+    --results-path=${result_path} \
+    --beam=5 \
+    --min-len=4 \
+    --max-len-a=0 \
+    --max-len-b=4 \
+    --no-repeat-ngram-size=3 \
+    --fp16 \
+    --num-workers=0 \
+    --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
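After shell expansion, --model-overrides receives a dict literal whose fields override the corresponding values stored in the checkpoint; with the variables defined above it resolves to:

{"data":"../../dataset/refcoco_cn_data/refcoco+_test_sample.tsv","bpe_dir":"../../utils/BERT_CN_dict","selected_cols":"0,3,1,2"}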

tasks/ofa_task.py

Lines changed: 23 additions & 7 deletions

@@ -34,6 +34,10 @@ class OFAConfig(FairseqDataclass):
         default=None,
         metadata={"help": "selected cols"},
     )
+    bpe: Optional[str] = field(
+        default='gpt2',
+        metadata={"help": "which bpe to use"},
+    )
     bpe_dir: Optional[str] = field(
         default=None,
         metadata={"help": "bpe dir"},
@@ -57,6 +61,9 @@ class OFAConfig(FairseqDataclass):
     patch_image_size: int = field(
         default=480, metadata={"help": "patch image size"}
     )
+    orig_patch_image_size: int = field(
+        default=256, metadata={"help": "original patch image size"}
+    )
     num_bins: int = field(
         default=1000, metadata={"help": "number of quantization bins"}
     )
@@ -151,13 +158,22 @@ def get_batch_iterator(
 
     def build_model(self, cfg: FairseqDataclass):
         model = super().build_model(cfg)
-        bpe_dict = {
-            "_name": "gpt2",
-            "gpt2_encoder_json": os.path.join(self.cfg.bpe_dir, "encoder.json"),
-            "gpt2_vocab_bpe": os.path.join(self.cfg.bpe_dir, "vocab.bpe")
-        }
-        bpe_dict = DictConfig(bpe_dict)
-        self.bpe = self.build_bpe(bpe_dict)
+        if self.cfg.bpe == 'bert':
+            bpe_dict = {
+                "_name": "bert",
+                "bpe_vocab_file": os.path.join(self.cfg.bpe_dir, "vocab.txt"),
+                "bpe_cased": False
+            }
+            bpe_dict = DictConfig(bpe_dict)
+            self.bpe = self.build_bpe(bpe_dict)
+        else:
+            bpe_dict = {
+                "_name": "gpt2",
+                "gpt2_encoder_json": os.path.join(self.cfg.bpe_dir, "encoder.json"),
+                "gpt2_vocab_bpe": os.path.join(self.cfg.bpe_dir, "vocab.bpe")
+            }
+            bpe_dict = DictConfig(bpe_dict)
+            self.bpe = self.build_bpe(bpe_dict)
         return model
 
     def build_generator(
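A standalone sketch of the dispatch added in build_model, assuming omegaconf provides DictConfig (as in the existing code) and the vocab files named above:

import os
from omegaconf import DictConfig

def make_bpe_config(bpe, bpe_dir):
    # Build the tokenizer config that build_bpe() receives, per the branch above.
    if bpe == 'bert':
        return DictConfig({
            "_name": "bert",
            "bpe_vocab_file": os.path.join(bpe_dir, "vocab.txt"),
            "bpe_cased": False,
        })
    return DictConfig({
        "_name": "gpt2",
        "gpt2_encoder_json": os.path.join(bpe_dir, "encoder.json"),
        "gpt2_vocab_bpe": os.path.join(bpe_dir, "vocab.bpe"),
    })

print(make_bpe_config('bert', '../../utils/BERT_CN_dict')["bpe_vocab_file"])
# ../../utils/BERT_CN_dict/vocab.txt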
