diff --git a/mmocr/datasets/ser_dataset.py b/mmocr/datasets/ser_dataset.py
index 4144a958f..0bf1bbbfb 100644
--- a/mmocr/datasets/ser_dataset.py
+++ b/mmocr/datasets/ser_dataset.py
@@ -46,39 +46,46 @@ def load_data_list(self) -> List[dict]:
         data_list = super().load_data_list()
 
         # split text to several slices because of over-length
-        input_ids, bboxes, labels = [], [], []
-        segment_ids, position_ids = [], []
-        image_path = []
+        split_text_data_list = []
         for i in range(len(data_list)):
             start = 0
             cur_iter = 0
             while start < len(data_list[i]['input_ids']):
                 end = min(start + 510, len(data_list[i]['input_ids']))
-
-                input_ids.append([self.tokenizer.cls_token_id] +
-                                 data_list[i]['input_ids'][start:end] +
-                                 [self.tokenizer.sep_token_id])
-                bboxes.append([[0, 0, 0, 0]] +
-                              data_list[i]['bboxes'][start:end] +
-                              [[1000, 1000, 1000, 1000]])
-                labels.append([-100] + data_list[i]['labels'][start:end] +
-                              [-100])
-
-                cur_segment_ids = self.get_segment_ids(bboxes[-1])
-                cur_position_ids = self.get_position_ids(cur_segment_ids)
-                segment_ids.append(cur_segment_ids)
-                position_ids.append(cur_position_ids)
-                image_path.append(
-                    os.path.join(self.data_root, data_list[i]['img_path']))
+                # get input_ids
+                input_ids = [self.tokenizer.cls_token_id] + \
+                    data_list[i]['input_ids'][start:end] + \
+                    [self.tokenizer.sep_token_id]
+                # get bboxes
+                bboxes = [[0, 0, 0, 0]] + \
+                    data_list[i]['bboxes'][start:end] + \
+                    [[1000, 1000, 1000, 1000]]
+                # get labels
+                labels = [-100] + data_list[i]['labels'][start:end] + [-100]
+                # get segment_ids
+                segment_ids = self.get_segment_ids(bboxes)
+                # get position_ids
+                position_ids = self.get_position_ids(segment_ids)
+                # get img_path
+                img_path = os.path.join(self.data_root,
+                                        data_list[i]['img_path'])
+                # get attention_mask
+                attention_mask = [1] * len(input_ids)
+
+                data_info = {}
+                data_info['input_ids'] = input_ids
+                data_info['bboxes'] = bboxes
+                data_info['labels'] = labels
+                data_info['segment_ids'] = segment_ids
+                data_info['position_ids'] = position_ids
+                data_info['img_path'] = img_path
+                data_info['attention_mask'] = attention_mask
+                split_text_data_list.append(data_info)
 
                 start = end
                 cur_iter += 1
 
-        assert len(input_ids) == len(bboxes) == len(labels) == len(
-            segment_ids) == len(position_ids)
-        assert len(segment_ids) == len(image_path)
-
-        return data_list
+        return split_text_data_list
 
     def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
         instances = raw_data_info['instances']
diff --git a/projects/LayoutLMv3/test.py b/projects/LayoutLMv3/test.py
index 14170c39f..96df72b1a 100644
--- a/projects/LayoutLMv3/test.py
+++ b/projects/LayoutLMv3/test.py
@@ -1,4 +1,5 @@
 from mmengine.config import Config
+from mmengine.registry import init_default_scope
 
 from mmocr.registry import DATASETS
 
@@ -6,9 +7,17 @@
     cfg_path = '/Users/wangnu/Documents/GitHub/mmocr/projects/' \
        'LayoutLMv3/configs/layoutlmv3_xfund_zh.py'
     cfg = Config.fromfile(cfg_path)
+    init_default_scope(cfg.get('default_scope', 'mmocr'))
     dataset_cfg = cfg.train_dataset
     dataset_cfg['tokenizer'] = \
        '/Users/wangnu/Documents/GitHub/mmocr/data/layoutlmv3-base-chinese'
+
+    train_pipeline = [
+        dict(type='LoadImageFromFile', color_type='color'),
+        dict(type='Resize', scale=(224, 224))
+    ]
+    dataset_cfg['pipeline'] = train_pipeline
     ds = DATASETS.build(dataset_cfg)
-    print(ds[0])
+    data = ds[0]
+    print('hi')
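
Note on the change to `load_data_list`: each over-length document is now split into windows of at most 510 body tokens, each wrapped with CLS/SEP (and matching bbox/label padding) and emitted as its own `data_info` dict instead of being appended to parallel lists. The following is a minimal standalone sketch of that windowing technique only, not the MMOCR implementation; `CLS_ID`, `SEP_ID`, `split_into_windows` and the fake 1200-token sequence are placeholders chosen for illustration.

```python
# Sketch of the sliding-window split used in load_data_list (assumed
# placeholder ids; real code takes them from the HuggingFace tokenizer).
CLS_ID, SEP_ID = 101, 102   # placeholder special-token ids
MAX_BODY_LEN = 510          # 512 minus the two special tokens


def split_into_windows(token_ids, labels):
    """Yield (input_ids, labels) windows of at most 512 tokens each."""
    start = 0
    while start < len(token_ids):
        end = min(start + MAX_BODY_LEN, len(token_ids))
        # wrap the slice with CLS/SEP and ignore-index labels (-100)
        yield ([CLS_ID] + token_ids[start:end] + [SEP_ID],
               [-100] + labels[start:end] + [-100])
        start = end


if __name__ == '__main__':
    ids = list(range(1200))        # fake over-length document
    labs = [0] * len(ids)
    for inp, lab in split_into_windows(ids, labs):
        print(len(inp), len(lab))  # 512 512, 512 512, 182 182
```

In the patched dataset the same per-window idea also produces `bboxes`, `segment_ids`, `position_ids` and `attention_mask`, so each slice becomes one self-contained sample for the pipeline rather than an index into parallel lists.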