data_utils.py
import os
import json
import torch
from functools import partial
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    """Wraps the raw JSON records as (token list, label id) pairs."""

    def __init__(self, raw_data, label_dict, tokenizer, model_name, method):
        # Methods other than plain cross-entropy ('ce') and supervised
        # contrastive ('scl') prepend the label names to every input as extra tokens.
        label_list = list(label_dict.keys()) if method not in ['ce', 'scl'] else []
        sep_token = ['[SEP]'] if model_name == 'bert' else ['</s>']
        dataset = list()
        for data in raw_data:
            tokens = data['text'].lower().split(' ')
            label_id = label_dict[data['label']]
            dataset.append((label_list + sep_token + tokens, label_id))
        self._dataset = dataset

    def __getitem__(self, index):
        return self._dataset[index]

    def __len__(self):
        return len(self._dataset)
def my_collate(batch, tokenizer, method, num_classes):
    tokens, label_ids = map(list, zip(*batch))
    # The inputs are already whitespace-split, so let the tokenizer align sub-words.
    text_ids = tokenizer(tokens,
                         padding=True,
                         truncation=True,
                         max_length=256,
                         is_split_into_words=True,
                         add_special_tokens=True,
                         return_tensors='pt')
    if method not in ['ce', 'scl']:
        # Zero the positions of the first `num_classes` tokens and renumber the
        # remaining tokens from 0, so the prepended label tokens do not shift
        # the positional encoding of the actual sentence.
        positions = torch.zeros_like(text_ids['input_ids'])
        positions[:, num_classes:] = torch.arange(0, text_ids['input_ids'].size(1) - num_classes)
        text_ids['position_ids'] = positions
    return text_ids, torch.tensor(label_ids)
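# Illustration (added comment, not in the original file): with num_classes = 2
# and a batch padded to length 6, the branch above yields position ids
#     [0, 0, 0, 1, 2, 3]
# per row: the first num_classes columns stay at 0 and the remaining columns
# count up from 0.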
def load_data(dataset, data_dir, tokenizer, train_batch_size, test_batch_size, model_name, method, workers):
    # Each dataset ships as a train/test pair of JSON files plus a fixed label-to-id mapping.
    if dataset == 'sst2':
        train_data = json.load(open(os.path.join(data_dir, 'SST2_Train.json'), 'r', encoding='utf-8'))
        test_data = json.load(open(os.path.join(data_dir, 'SST2_Test.json'), 'r', encoding='utf-8'))
        label_dict = {'positive': 0, 'negative': 1}
    elif dataset == 'trec':
        train_data = json.load(open(os.path.join(data_dir, 'TREC_Train.json'), 'r', encoding='utf-8'))
        test_data = json.load(open(os.path.join(data_dir, 'TREC_Test.json'), 'r', encoding='utf-8'))
        label_dict = {'description': 0, 'entity': 1, 'abbreviation': 2, 'human': 3, 'location': 4, 'numeric': 5}
    elif dataset == 'cr':
        train_data = json.load(open(os.path.join(data_dir, 'CR_Train.json'), 'r', encoding='utf-8'))
        test_data = json.load(open(os.path.join(data_dir, 'CR_Test.json'), 'r', encoding='utf-8'))
        label_dict = {'positive': 0, 'negative': 1}
    elif dataset == 'subj':
        train_data = json.load(open(os.path.join(data_dir, 'SUBJ_Train.json'), 'r', encoding='utf-8'))
        test_data = json.load(open(os.path.join(data_dir, 'SUBJ_Test.json'), 'r', encoding='utf-8'))
        label_dict = {'subjective': 0, 'objective': 1}
    elif dataset == 'pc':
        train_data = json.load(open(os.path.join(data_dir, 'procon_Train.json'), 'r', encoding='utf-8'))
        test_data = json.load(open(os.path.join(data_dir, 'procon_Test.json'), 'r', encoding='utf-8'))
        label_dict = {'positive': 0, 'negative': 1}
    else:
        raise ValueError('unknown dataset')
    trainset = MyDataset(train_data, label_dict, tokenizer, model_name, method)
    testset = MyDataset(test_data, label_dict, tokenizer, model_name, method)
    # Bind the tokenizer-dependent arguments so the DataLoader can call the collate function.
    collate_fn = partial(my_collate, tokenizer=tokenizer, method=method, num_classes=len(label_dict))
    train_dataloader = DataLoader(trainset, train_batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn, pin_memory=True)
    test_dataloader = DataLoader(testset, test_batch_size, shuffle=False, num_workers=workers, collate_fn=collate_fn, pin_memory=True)
    return train_dataloader, test_dataloader
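

# Minimal usage sketch (added, not part of the original file). It assumes a
# HuggingFace `transformers` tokenizer and that the SST-2 JSON files live in
# './data'; 'dualcl' is only an illustrative method name (any value outside
# ['ce', 'scl'] enables the label-token prompt path), and the batch sizes and
# worker count are arbitrary.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('bert-base-uncased')
    train_loader, test_loader = load_data(dataset='sst2', data_dir='./data',
                                          tokenizer=tok, train_batch_size=32,
                                          test_batch_size=64, model_name='bert',
                                          method='dualcl', workers=2)
    inputs, labels = next(iter(train_loader))
    print(inputs['input_ids'].shape, labels.shape)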