-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_utils.py
82 lines (70 loc) · 2.73 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import json
import random
import codecs
import numpy as np
class PrePareQaData(object):
    """Iterator that yields padded id batches from a question/answer dataset.

    The dataset is a text file holding one JSON object per line, each with a
    ``"Q"`` and an ``"A"`` entry.  Every element obtained by iterating those
    entries (characters, if they are strings) is mapped to an integer id
    through the vocabulary file, and each batch is right-padded with ``PAD``
    to the length of its longest sequence.

    Both data files are looked up relative to the directory of this module:
    ``data/chinese_vocab.txt`` and ``data/trainset.txt`` / ``data/testset.txt``.
    """

    # Id used for tokens that are missing from the vocabulary.  The original
    # code notes that ``<UNKNOWN>`` sits at index 3 of the vocabulary file.
    _UNK_ID = 3

    def __init__(self, conf, mode):
        """Load the vocabulary and the (shuffled) dataset for ``mode``.

        Args:
            conf: configuration object; only ``conf.batch_size`` is read here.
            mode: ``"train"`` or ``"test"``; selects which dataset file to read.

        Raises:
            ValueError: if ``mode`` is neither ``"train"`` nor ``"test"``.
            FileNotFoundError: if the selected dataset file does not exist.
        """
        self._mode = mode
        self._config = conf
        # Set before any loading so the attribute exists even if a loader
        # is later changed to depend on it.
        self.PAD = 0
        self._currPath = os.path.dirname(__file__)
        self._vocabDict = self.__load_chinese_vocab()
        self._sourceData = self.__read_dataset()

    def __load_chinese_vocab(self):
        """Return a dict mapping each stripped vocabulary line to its line index."""
        vocab_path = os.path.join(self._currPath, "data/chinese_vocab.txt")
        vocab = {}
        with codecs.open(vocab_path, "r", "utf8") as f:
            for index, line in enumerate(f):
                vocab[line.strip()] = index
        return vocab

    def __read_dataset(self):
        """Read every line of the dataset for the current mode, shuffle, and
        return an iterator over the lines."""
        if self._mode == "train":
            dataset_path = os.path.join(self._currPath, "data/trainset.txt")
        elif self._mode == "test":
            dataset_path = os.path.join(self._currPath, "data/testset.txt")
        else:
            raise ValueError("mode must be in [train/test]")

        if not os.path.exists(dataset_path):
            raise FileNotFoundError("path [{}] not exists".format(dataset_path))

        with codecs.open(dataset_path, "r", "utf8") as fp:
            dataset = fp.readlines()
        random.shuffle(dataset)
        return iter(dataset)

    def __word_to_id(self, dialogue):
        """Map every element of ``dialogue`` to its vocabulary id.

        Elements that are not in the vocabulary get ``_UNK_ID``.
        """
        return [self._vocabDict.get(char, self._UNK_ID) for char in dialogue]

    def __parse_dialogue(self, dialogue_lst):
        """Decode JSON lines into two parallel lists of id sequences.

        Returns:
            ``(encoder_ids, decoder_ids)`` built from the ``"Q"`` and ``"A"``
            entries of each line, in input order.
        """
        encoder_id_lst = []
        decoder_id_lst = []
        for raw in dialogue_lst:
            record = json.loads(raw)
            encoder_id_lst.append(self.__word_to_id(record["Q"]))
            decoder_id_lst.append(self.__word_to_id(record["A"]))
        return encoder_id_lst, decoder_id_lst

    def __padding_coder_id(self, coder_id_lst):
        """Right-pad every sequence in place with ``PAD`` to the longest length.

        Returns the same list object that was passed in.
        """
        max_len = max(len(item) for item in coder_id_lst)
        for coder_id in coder_id_lst:
            coder_id.extend([self.PAD] * (max_len - len(coder_id)))
        return coder_id_lst

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def __next__(self):
        """Return the next batch as ``(encoder_input, decoder_target)`` arrays.

        Up to ``batch_size`` non-blank lines are consumed.  The final batch
        may be smaller; once no lines are left ``StopIteration`` is raised.
        """
        dialogue_lst = []
        try:
            while len(dialogue_lst) < self._config.batch_size:
                # Lines keep their trailing newline, so a blank line is
                # "\n" (truthy).  Strip first, otherwise it slips through
                # the emptiness check and crashes json.loads.
                cur = next(self._sourceData).strip()
                if cur:
                    dialogue_lst.append(cur)
        except StopIteration:
            # Source exhausted: emit the partial batch if there is one,
            # otherwise signal the end of iteration.
            if not dialogue_lst:
                raise

        encoder_input, decoder_target = self.__parse_dialogue(dialogue_lst)
        encoder_input = self.__padding_coder_id(encoder_input)
        decoder_target = self.__padding_coder_id(decoder_target)
        return np.array(encoder_input), np.array(decoder_target)