-import os
+import io
 import json
 import tqdm
+import copy
 import torch
-import base64
-import argparse
+import itertools
+import pandas as pd
 import torch.utils.data as torch_data
+import PIL.Image as PIL_image

-from typing import List
 from functools import partial

-from muffin.eval.muffin_vqa import init_muffin
 from muffin.train.train_utils import encode_multimodal_preference_sample, SFT_collator_fn
-from muffin.data.datasets import SingleDataSourceDataset
-from muffin.data.tsv_file_op import multimodal_img_tsv_writer_prev
-from muffin.data.tsv_file import TSVFile
+
+
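+# Decode a raw image byte buffer (the 'bytes' field of a dataset image column) into an RGB PIL image.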
+def bytes_to_PIL_image(img_buffer):
+    img_io = io.BytesIO(img_buffer)
+    img_io.seek(0)
+    image = PIL_image.open(img_io).convert('RGB')
+    return image
+
+
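+# Shards the dataset into contiguous, non-overlapping index ranges, one per
+# distributed rank, so each process scores a disjoint slice during inference.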
+class InferenceSampler(torch.utils.data.sampler.Sampler):
+
+    def __init__(self, size):
+        self._size = int(size)
+        assert size > 0
+        self._rank = torch.distributed.get_rank()
+        self._world_size = torch.distributed.get_world_size()
+        self._local_indices = self._get_local_indices(size, self._world_size,
+                                                       self._rank)
+
+    @staticmethod
+    def _get_local_indices(total_size, world_size, rank):
+        shard_size = total_size // world_size
+        left = total_size % world_size
+        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
+
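+        # The first `left` ranks take one extra sample; each rank then reads the
+        # contiguous slice [begin, end) of the dataset.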
+        begin = sum(shard_sizes[:rank])
+        end = min(sum(shard_sizes[:rank + 1]), total_size)
+        return range(begin, end)
+
+    def __iter__(self):
+        yield from self._local_indices
+
+    def __len__(self):
+        return len(self._local_indices)


 def get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, return_per_token_logp=False, return_all=False) -> torch.FloatTensor:
@@ -52,16 +83,13 @@ def get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, return_

 class PreferenceInferenceDataset(torch_data.Dataset):
     def __init__(self,
-                 data_dir,
+                 data,
                  tokenizer,
-                 tsv_filenames: List[str],
                  image_token_len,
                  img_processor,
-                 use_im_start_end):
-        if 'DPO_preference_llava' in data_dir or 'llavarlhf' in tsv_filenames[0]:
-            self.data = SingleDataSourceDataset('dpo_preference_llava_7b_v1_preference_hallonly', data_dir, tsv_filenames)
-        else:
-            self.data = SingleDataSourceDataset('RLHF-V-Hall_v0', data_dir, tsv_filenames)
+                 use_im_start_end=True):
+
+        self.data = data

         self.mm_cfg = {
             'image_processor': img_processor,
@@ -73,7 +101,29 @@ def __init__(self,

     def __getitem__(self, index):
         sample = self.data[index]
-        rej_data_dict, win_data_dict = encode_multimodal_preference_sample(sample, self.tokenizer, self.mm_cfg)
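+        # Repack the raw dataset row (JSON 'text' field plus raw image bytes) into the
+        # conversation-style sample expected by encode_multimodal_preference_sample.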
+        metainfo = {
+            "origin_dataset": sample['origin_dataset'],
+            "origin_split": json.loads(sample['origin_split']),
+            "origin_idx": sample['idx'],
+            "image_id": sample['image_path'],
+        }
+
+        text = json.loads(sample['text'])
+        question = {'from': 'human', 'value': f"<image>\n{text['question']}"}
+        chosen = {'from': 'gpt', 'value': text['chosen']}
+        rejected = {'from': 'gpt', 'value': text['rejected']}
+
+        image = bytes_to_PIL_image(sample['image']['bytes'])
+
+        formated_sample = {
+            'image': image,
+            "question": question,
+            "chosen": chosen,
+            "rejected": rejected,
+            "idx": sample['idx'],
+            "metainfo": metainfo
+        }
+        rej_data_dict, win_data_dict = encode_multimodal_preference_sample(formated_sample, self.tokenizer, self.mm_cfg)
         return rej_data_dict, win_data_dict

     def __len__(self):
@@ -125,17 +175,7 @@ def preference_collator_fn(instances, pad_token_id):
     return batch


-def get_multimodal_sample_logps(model, tokenizer, data_dir, tsv_files, image_token_len, img_processor, use_im_start_end):
-    dataset = PreferenceInferenceDataset(data_dir=data_dir,
-                                         tokenizer=tokenizer,
-                                         tsv_filenames=tsv_files,
-                                         image_token_len=image_token_len,
-                                         img_processor=img_processor,
-                                         use_im_start_end=use_im_start_end)
-    collate_fn = partial(preference_collator_fn, pad_token_id=tokenizer.pad_token_id)
-    dataloader = torch_data.DataLoader(dataset, batch_size=1, collate_fn=collate_fn,
-                                       num_workers=5, shuffle=False)
-
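+# Runs the model over the dataloader and collects log-probabilities (summed, averaged,
+# and per-token) for both the chosen (win) and rejected response of every sample.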
+def get_multimodal_sample_logps(model, dataloader):
     win_logp_list = []
     rej_logp_list = []

@@ -180,51 +220,65 @@ def get_multimodal_sample_logps(model, tokenizer, data_dir, tsv_files, image_tok
     return win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list


-def write_logp_to_preference_tsv(tsv_filename, out_tsv_filename, logps, overwrite_logps=False):
-    origin_data = TSVFile(tsv_filename)
-
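+# Attaches the computed logp tuple to each sample's 'text' JSON and caches the
+# result as a parquet file (written by rank 0 only).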
+def write_logp_to_preference_parquet(origin_data, cache_file, logps, overwrite_logps=False):
     out_data = []
-    for line, logp_data in tqdm.tqdm(zip(origin_data, logps)):
-        text_b64 = line[2]
-        text = base64.b64decode(text_b64).decode('utf-8')
-        preference_data = json.loads(text)
-        if len(preference_data) == 4:
+
+    for index in range(len(origin_data)):
+        line = origin_data[index]
+        logp_data = logps[index]
+
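+        # Work on a copy so the rows of the original dataset are left unmodified.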
+        new_line = copy.deepcopy(line)
+
+        text = json.loads(new_line['text'])
+
+        if 'logps' in text.keys():
             assert overwrite_logps, 'Found existing logp data, pass overwrite_logps=True to force overwriting'
-            preference_data[3] = logp_data
+            text['logps'] = logp_data
+            new_line['text'] = json.dumps(text)
+
         else:
-            assert len(preference_data) == 3, f'Undefined data structure, expecting [Q, Win, Rej], got {text}'
-            preference_data.append(logp_data)
+            assert list(text.keys()) == ['question', 'chosen', 'rejected'], f'Undefined data structure, expecting [Q, Win, Rej], got {text.keys()}'
+            text['logps'] = logp_data
+            new_line['text'] = json.dumps(text)

-        line[2] = base64.b64encode(json.dumps(preference_data).encode('utf-8')).decode('utf-8')
-        out_data.append(line)
+        out_data.append(new_line)

-    multimodal_img_tsv_writer_prev(out_data, out_tsv_filename)
+    df = pd.DataFrame(out_data)

-def inference_logp(args):
-    model, img_processor, image_token_len, tokenizer = init_muffin(args.model_name)
-    use_im_start_end = True
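+    # Only rank 0 writes the parquet cache; all ranks wait at the barrier below.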
+    if torch.distributed.get_rank() == 0:
+        df.to_parquet(cache_file)

-    tsv_files = [args.tsv_file]
+    torch.distributed.barrier()

-    for tsv_filename in tsv_files:
-        win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list = get_multimodal_sample_logps(model, tokenizer, args.data_dir, [tsv_filename], image_token_len, img_processor, use_im_start_end)
-        logps = list(zip(win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list))
+    return df
+
+def inference_logp(model, tokenizer, hf_data, cache_file, image_token_len, img_processor, use_im_start_end):
+    model = model.to(dtype=torch.bfloat16, device='cuda')
+    dataset = PreferenceInferenceDataset(tokenizer=tokenizer,
+                                         data=hf_data,
+                                         image_token_len=image_token_len,
+                                         img_processor=img_processor,
+                                         use_im_start_end=use_im_start_end)
+    collate_fn = partial(preference_collator_fn, pad_token_id=tokenizer.pad_token_id)
+    dataloader = torch_data.DataLoader(dataset, batch_size=1, collate_fn=collate_fn,
+                                       num_workers=5, shuffle=False, sampler=InferenceSampler(len(dataset)))

+    outputs = get_multimodal_sample_logps(model, dataloader)  # win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list

-        tsv_filepath = os.path.join(args.data_dir, tsv_filename)
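+    # Gather every rank's result lists and flatten them back into full-dataset order.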
+    world_size = torch.distributed.get_world_size()
+    merged_outputs = [[None for _ in range(world_size)] for i in range(len(outputs))]
+    for i in range(len(outputs)):
+        torch.distributed.all_gather_object(merged_outputs[i], outputs[i])
+        merged_outputs[i] = [_ for _ in itertools.chain.from_iterable(merged_outputs[i])]

-        save_name = '-'.join(tsv_filename.split('-')[:-1])
-        save_name = save_name + '_' + args.logp_file
+    win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list \
+        = merged_outputs

-        write_logp_to_preference_tsv(tsv_filepath, f'{args.data_dir}/{save_name}', logps, overwrite_logps=True)
+    logps = list(zip(win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list))

+    df = write_logp_to_preference_parquet(dataset.data, cache_file, logps, overwrite_logps=False)

-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model-name", type=str, default="RLHF-V_v0-SFT-13B")
-    parser.add_argument("--data-dir", type=str)
-    parser.add_argument("--tsv-file", type=str)
-    parser.add_argument("--logp-file", type=str, default="dpo_with_rlhf-v-sft_logp_train")
-    args = parser.parse_args()
+    torch.distributed.barrier()

-    inference_logp(args)
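+    # Release the reference to the GPU model before handing back the dataframe.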
+    del model
+    return df