Commit 85d72b4
Merge pull request #7 from Haoye17/main
change training code
2 parents 5d32ea6 + 9cab457 · commit 85d72b4

File tree: 7 files changed, +242 −92 lines changed

muffin/data/data_processors.py

Lines changed: 4 additions & 5 deletions

@@ -205,15 +205,14 @@ def unimmchat_processor(img_b64_buffer, text_b64, origin_dataset, origin_split,
     raise NotImplemented


-@register_data_processor('RLHF-V-Hall_v0')
+@register_data_processor('RLHF-V-Dataset')
 def dpo_cvpr_ncrp_vqa_processor(*args, **kwargs):
     return dpo_preference_processor(*args, **kwargs)

-
-@register_data_path('RLHF-V-Hall_v0')
+@register_data_path('RLHF-V-Dataset')
 def dpo_cvpr_ncrp_vqa_path():
-    data_dir = pathlib.Path(__file__).parent.resolve() / '../../data/RLHF-V-Hall_v0'
-    return gather_data_files_by_glob(data_dir, pattern='*dpo_with_rlhf-v-sft_logp_train-1401.tsv')
+    data_dir = pathlib.Path(__file__).parent.resolve() / '../../data/RLHF-V-Dataset'
+    return gather_data_files_by_glob(data_dir, pattern='RLHF-V-Dataset_withlogp-1401.tsv')


 def dpo_preference_processor(img_b64_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path,
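This hunk is a pure rename: the dataset is registered as 'RLHF-V-Dataset' instead of 'RLHF-V-Hall_v0', and the glob pattern is updated to match the renamed file. Both decorators must use the same key, since the processor and the data path are presumably resolved by dataset name. A minimal sketch of how such a name-keyed registry typically works; the decorator internals below are an assumption, only the decorator names and the dataset key come from this diff:

```python
# Hypothetical registry internals, for illustration only; the real
# implementations live elsewhere in muffin/data/data_processors.py.
DATA_PROCESSORS = {}
DATA_PATHS = {}

def register_data_processor(name):
    def decorator(fn):
        DATA_PROCESSORS[name] = fn   # looked up later by dataset name
        return fn
    return decorator

def register_data_path(name):
    def decorator(fn):
        DATA_PATHS[name] = fn
        return fn
    return decorator

# After this commit, lookups succeed only under the new key:
# DATA_PROCESSORS['RLHF-V-Dataset'] -> dpo_cvpr_ncrp_vqa_processor
# DATA_PATHS['RLHF-V-Dataset']      -> dpo_cvpr_ncrp_vqa_path
```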

muffin/data/datasets.py

Lines changed: 73 additions & 1 deletion

@@ -1,15 +1,87 @@
 import io
+import os
 import json
+import torch
 import numpy
 import base64
-
+import pandas as pd
 import os.path as op
+import datasets as hf_datasets
 import torch.utils.data as torch_data

 from PIL import Image
 from typing import List, Iterator
 from muffin.data.tsv_file import TSVFile
 from muffin.data.data_processors import register_data_processor
+from muffin.eval.muffin_inference_logp import inference_logp
+
+
+def bytes_to_PIL_image(img_buffer):
+    img_io = io.BytesIO(img_buffer)
+    img_io.seek(0)
+    image = Image.open(img_io).convert('RGB')
+    return image
+
+
+def read_jsonl(file_path):
+    with open(file_path, "r", encoding="utf-8") as file:
+        return [json.loads(line) for line in file]
+
+
+class RLHFVDataset(torch_data.Dataset):
+    def __init__(self, data_dir: str, ref_name: str, reference_model=None,
+                 tokenizer=None, image_token_len=None, img_processor=None, use_im_start_end=True):
+        super().__init__()
+
+        self.data_path = f'{data_dir}/{ref_name}_with_logp.parquet'
+
+        if not op.exists(self.data_path):
+            os.makedirs(data_dir, exist_ok=True)
+
+            assert reference_model is not None, "`reference_model` is mandatory when logps do not exist."
+
+            hf_data = hf_datasets.load_dataset("HaoyeZhang/RLHF-V-Dataset")['train'].cast_column("image", hf_datasets.Image(decode=False))
+
+            inference_logp(reference_model, tokenizer, hf_data, self.data_path,
+                           image_token_len, img_processor, use_im_start_end)
+
+            torch.distributed.barrier()
+
+            self.data = pd.read_parquet(self.data_path)
+
+            # print(f'{torch.distributed.get_rank()} data len: {len(self.data)}')
+        else:
+            self.data = pd.read_parquet(self.data_path)

+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        sample = self.data.iloc[index]
+        text = json.loads(sample['text'])
+        question = {'from': 'human', 'value': f"<image>\n{text['question']}"}
+        chosen = {'from': 'gpt', 'value': text['chosen']}
+        rejected = {'from': 'gpt', 'value': text['rejected']}
+
+        image = bytes_to_PIL_image(sample['image']['bytes'])
+
+        metainfo = {
+            "origin_dataset": sample['origin_dataset'],
+            "origin_split": sample['origin_split'],
+            "origin_idx": sample['idx'],
+            "image_id": sample['image_path'],
+        }
+
+        data_dict = {
+            'image': image,
+            "question": question,
+            "chosen": chosen,
+            "rejected": rejected,
+            "idx": sample['idx'],
+            "metainfo": metainfo
+        }
+
+        (data_dict['ref_win_logp'], data_dict['ref_win_avg_logp'], data_dict['ref_win_per_token_logp'],
+         data_dict['ref_rej_logp'], data_dict['ref_rej_avg_logp'], data_dict['ref_rej_per_token_logp']) = text['logps']
+
+        return data_dict


 class MultimodalQADataset(torch_data.Dataset):
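The new RLHFVDataset replaces the TSV pipeline: on first use it pulls HaoyeZhang/RLHF-V-Dataset from the Hugging Face Hub, computes reference-model log probabilities once via inference_logp, and caches everything to a parquet file; later runs only read the cache. A usage sketch under the assumption that torch.distributed is already initialized (the constructor calls torch.distributed.barrier() on the cache-building path) and that ref_model, tokenizer, image_token_len, and img_processor are built elsewhere in the training script; 'ref13b' is a hypothetical tag:

```python
from muffin.data.datasets import RLHFVDataset

# First run: './data/RLHF-V-Dataset/ref13b_with_logp.parquet' does not exist,
# so the constructor downloads the HF dataset, runs inference_logp with the
# reference model, and writes the parquet cache.
train_set = RLHFVDataset(data_dir='./data/RLHF-V-Dataset',
                         ref_name='ref13b',
                         reference_model=ref_model,   # required only on the first run
                         tokenizer=tokenizer,
                         image_token_len=image_token_len,
                         img_processor=img_processor)

sample = train_set[0]
# Each item bundles the conversation turns with the cached reference logps.
print(sample['question'], sample['ref_win_logp'], sample['ref_rej_logp'])
```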

muffin/eval/muffin_inference_logp.py

Lines changed: 114 additions & 60 deletions

@@ -1,19 +1,50 @@
-import os
+import io
 import json
 import tqdm
+import copy
 import torch
-import base64
-import argparse
+import itertools
+import pandas as pd
 import torch.utils.data as torch_data
+import PIL.Image as PIL_image

-from typing import List
 from functools import partial

-from muffin.eval.muffin_vqa import init_muffin
 from muffin.train.train_utils import encode_multimodal_preference_sample, SFT_collator_fn
-from muffin.data.datasets import SingleDataSourceDataset
-from muffin.data.tsv_file_op import multimodal_img_tsv_writer_prev
-from muffin.data.tsv_file import TSVFile
+
+
+def bytes_to_PIL_image(img_buffer):
+    img_io = io.BytesIO(img_buffer)
+    img_io.seek(0)
+    image = PIL_image.open(img_io).convert('RGB')
+    return image
+
+
+class InferenceSampler(torch.utils.data.sampler.Sampler):
+
+    def __init__(self, size):
+        self._size = int(size)
+        assert size > 0
+        self._rank = torch.distributed.get_rank()
+        self._world_size = torch.distributed.get_world_size()
+        self._local_indices = self._get_local_indices(size, self._world_size,
+                                                      self._rank)
+
+    @staticmethod
+    def _get_local_indices(total_size, world_size, rank):
+        shard_size = total_size // world_size
+        left = total_size % world_size
+        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
+
+        begin = sum(shard_sizes[:rank])
+        end = min(sum(shard_sizes[:rank + 1]), total_size)
+        return range(begin, end)
+
+    def __iter__(self):
+        yield from self._local_indices
+
+    def __len__(self):
+        return len(self._local_indices)


 def get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, return_per_token_logp=False, return_all=False) -> torch.FloatTensor:
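InferenceSampler shards the dataset across ranks into contiguous, near-equal slices, handing the remainder samples to the lowest ranks. A standalone illustration of the arithmetic in _get_local_indices (no process group needed for this sketch):

```python
def get_local_indices(total_size, world_size, rank):
    # Same arithmetic as InferenceSampler._get_local_indices above.
    shard_size = total_size // world_size
    left = total_size % world_size
    shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
    begin = sum(shard_sizes[:rank])
    end = min(sum(shard_sizes[:rank + 1]), total_size)
    return range(begin, end)

# 10 samples over 4 ranks -> shard sizes [3, 3, 2, 2]:
for rank in range(4):
    print(rank, list(get_local_indices(10, 4, rank)))
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7]
# 3 [8, 9]
```

Because each rank's slice is contiguous, concatenating per-rank results in rank order restores the original dataset order, which the gather step in inference_logp below relies on.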
@@ -52,16 +83,13 @@ def get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, return_

 class PreferenceInferenceDataset(torch_data.Dataset):
     def __init__(self,
-                 data_dir,
+                 data,
                  tokenizer,
-                 tsv_filenames: List[str],
                  image_token_len,
                  img_processor,
-                 use_im_start_end):
-        if 'DPO_preference_llava' in data_dir or 'llavarlhf' in tsv_filenames[0]:
-            self.data = SingleDataSourceDataset('dpo_preference_llava_7b_v1_preference_hallonly' ,data_dir, tsv_filenames)
-        else:
-            self.data = SingleDataSourceDataset('RLHF-V-Hall_v0' ,data_dir, tsv_filenames)
+                 use_im_start_end=True):
+
+        self.data = data

         self.mm_cfg = {
             'image_processor': img_processor,
@@ -73,7 +101,29 @@ def __init__(self,

     def __getitem__(self, index):
         sample = self.data[index]
-        rej_data_dict, win_data_dict = encode_multimodal_preference_sample(sample, self.tokenizer, self.mm_cfg)
+        metainfo = {
+            "origin_dataset": sample['origin_dataset'],
+            "origin_split": json.loads(sample['origin_split']),
+            "origin_idx": sample['idx'],
+            "image_id": sample['image_path'],
+        }
+
+        text = json.loads(sample['text'])
+        question = {'from': 'human', 'value': f"<image>\n{text['question']}"}
+        chosen = {'from': 'gpt', 'value': text['chosen']}
+        rejected = {'from': 'gpt', 'value': text['rejected']}
+
+        image = bytes_to_PIL_image(sample['image']['bytes'])
+
+        formated_sample = {
+            'image': image,
+            "question": question,
+            "chosen": chosen,
+            "rejected": rejected,
+            "idx": sample['idx'],
+            "metainfo": metainfo
+        }
+
+        rej_data_dict, win_data_dict = encode_multimodal_preference_sample(formated_sample, self.tokenizer, self.mm_cfg)
         return rej_data_dict, win_data_dict

     def __len__(self):
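__getitem__ now builds the preference sample inline from a Hugging Face row instead of delegating to the TSV pipeline: each row stores its conversation as a JSON string under 'text'. A sketch of the decoding step with an invented record (the field values are illustrative, not taken from the dataset):

```python
import json

# Invented row payload in the shape __getitem__ expects.
row_text = json.dumps({
    "question": "What is the man holding?",
    "chosen": "The man is holding a red umbrella.",
    "rejected": "The man is holding a blue camera.",
})

text = json.loads(row_text)
question = {'from': 'human', 'value': f"<image>\n{text['question']}"}
chosen = {'from': 'gpt', 'value': text['chosen']}       # preferred answer
rejected = {'from': 'gpt', 'value': text['rejected']}   # dispreferred answer
```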
@@ -125,17 +175,7 @@ def preference_collator_fn(instances, pad_token_id):
     return batch


-def get_multimodal_sample_logps(model, tokenizer, data_dir, tsv_files, image_token_len, img_processor, use_im_start_end):
-    dataset = PreferenceInferenceDataset(data_dir=data_dir,
-                                         tokenizer=tokenizer,
-                                         tsv_filenames=tsv_files,
-                                         image_token_len=image_token_len,
-                                         img_processor=img_processor,
-                                         use_im_start_end=use_im_start_end)
-    collate_fn = partial(preference_collator_fn, pad_token_id=tokenizer.pad_token_id)
-    dataloader = torch_data.DataLoader(dataset, batch_size=1, collate_fn=collate_fn,
-                                       num_workers=5, shuffle=False)
-
+def get_multimodal_sample_logps(model, dataloader):
     win_logp_list = []
     rej_logp_list = []
@@ -180,51 +220,65 @@ def get_multimodal_sample_logps(model, tokenizer, data_dir, tsv_files, image_tok
     return win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list


-def write_logp_to_preference_tsv(tsv_filename, out_tsv_filename, logps, overwrite_logps=False):
-    origin_data = TSVFile(tsv_filename)
-
+def write_logp_to_preference_parquet(origin_data, cache_file, logps, overwrite_logps=False):
     out_data = []
-    for line, logp_data in tqdm.tqdm(zip(origin_data, logps)):
-        text_b64 = line[2]
-        text = base64.b64decode(text_b64).decode('utf-8')
-        preference_data = json.loads(text)
-        if len(preference_data) == 4:
+
+    for index in range(len(origin_data)):
+        line = origin_data[index]
+        logp_data = logps[index]
+
+        new_line = copy.deepcopy(line)
+
+        text = json.loads(new_line['text'])
+
+        if 'logps' in text.keys():
             assert overwrite_logps, 'Found existing logp data, pass overwrite_logps=True to force overwritting'
-            preference_data[3] = logp_data
+            text['logps'] = logp_data
+            new_line['text'] = json.dumps(text)
+
         else:
-            assert len(preference_data) == 3, f'Undefined data structure, expecting [Q, Win, Rej], got {text}'
-            preference_data.append(logp_data)
+            assert list(text.keys()) == ['question', 'chosen', 'rejected'], f'Undefined data structure, expecting [Q, Win, Rej], got {text.keys()}'
+            text['logps'] = logp_data
+            new_line['text'] = json.dumps(text)

-        line[2] = base64.b64encode(json.dumps(preference_data).encode('utf-8')).decode('utf-8')
-        out_data.append(line)
+        out_data.append(new_line)

-    multimodal_img_tsv_writer_prev(out_data, out_tsv_filename)
+    df = pd.DataFrame(out_data)

-def inference_logp(args):
-    model, img_processor, image_token_len, tokenizer = init_muffin(args.model_name)
-    use_im_start_end = True
+    if torch.distributed.get_rank() == 0:
+        df.to_parquet(cache_file)

-    tsv_files = [args.tsv_file]
+    torch.distributed.barrier()

-    for tsv_filename in tsv_files:
-        win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list = get_multimodal_sample_logps(model, tokenizer, args.data_dir, [tsv_filename], image_token_len, img_processor, use_im_start_end)
-        logps = list(zip(win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list))
+    return df
+
+def inference_logp(model, tokenizer, hf_data, cache_file, image_token_len, img_processor, use_im_start_end):
+    model = model.to(dtype=torch.bfloat16, device='cuda')
+    dataset = PreferenceInferenceDataset(tokenizer=tokenizer,
+                                         data = hf_data,
+                                         image_token_len=image_token_len,
+                                         img_processor=img_processor,
+                                         use_im_start_end=use_im_start_end)
+    collate_fn = partial(preference_collator_fn, pad_token_id=tokenizer.pad_token_id)
+    dataloader = torch_data.DataLoader(dataset, batch_size=1, collate_fn=collate_fn,
+                                       num_workers=5, shuffle=False, sampler=InferenceSampler(len(dataset)))

-        tsv_filepath = os.path.join(args.data_dir, tsv_filename)
+    outputs = get_multimodal_sample_logps(model, dataloader) # win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list

-        save_name = '-'.join(tsv_filename.split('-')[:-1])
-        save_name = save_name + '_' + args.logp_file
+    world_size = torch.distributed.get_world_size()
+    merged_outputs = [[None for _ in range(world_size)] for i in range(len(outputs))]
+    for i in range(len(outputs)):
+        torch.distributed.all_gather_object(merged_outputs[i], outputs[i])
+        merged_outputs[i] = [_ for _ in itertools.chain.from_iterable(merged_outputs[i])]

-        write_logp_to_preference_tsv(tsv_filepath, f'{args.data_dir}/{save_name}', logps, overwrite_logps=True)
+    win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list \
+        = merged_outputs

-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model-name", type=str, default="RLHF-V_v0-SFT-13B")
-    parser.add_argument("--data-dir", type=str)
-    parser.add_argument("--tsv-file", type=str)
-    parser.add_argument("--logp-file", type=str, default="dpo_with_rlhf-v-sft_logp_train")
-    args = parser.parse_args()
+    logps = list(zip(win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list))

-    inference_logp(args)
+    df = write_logp_to_preference_parquet(dataset.data, cache_file, logps, overwrite_logps=False)
+
+    torch.distributed.barrier()
+
+    del model
+    return df
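inference_logp now runs the whole pipeline in-process instead of as a CLI script: each rank scores its InferenceSampler shard, the six per-rank logp lists are reassembled with all_gather_object, and only rank 0 writes the parquet cache, with barriers keeping the other ranks from reading a half-written file. A reduced sketch of the gather-and-flatten step, assuming an initialized process group (outputs stands for the per-rank result lists):

```python
import itertools
import torch

def merge_sharded_lists(outputs):
    # Gather each per-rank list from every rank, then flatten in rank order.
    # Because InferenceSampler shards are contiguous, this restores the
    # original dataset order.
    world_size = torch.distributed.get_world_size()
    merged = [[None for _ in range(world_size)] for _ in outputs]
    for i, partial in enumerate(outputs):
        torch.distributed.all_gather_object(merged[i], partial)
        merged[i] = list(itertools.chain.from_iterable(merged[i]))
    return merged
```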
