-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathaddition_reader.py
106 lines (83 loc) · 3.13 KB
/
addition_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import tensorflow as tf
import random
import numpy as np
from .registry import register
def generate_batch(params):
    """Generate one batch of running-sum addition samples as nested lists.

    Each sample is a sequence of ``random.randint(1, params.max_difficulty)``
    random non-negative integers, each with up to ``params.num_digits``
    digits. The label at each step is the sum of all numbers up to and
    including that step.

    Numbers are stored most-significant digit first and right-padded with
    -1. A number whose generated digits are all zero is stored as all -1.
    Time steps past a sample's sequence length are filled entirely with -1
    (features and labels) and 0 (mask).

    Args:
        params: object exposing ``batch_size``, ``num_digits`` and
            ``max_difficulty`` attributes.

    Returns:
        Tuple ``(features, sequence_lengths, target_mask, labels)`` of plain
        Python lists:
            features: batch_size x max_difficulty x num_digits digits.
            sequence_lengths: batch_size ints, real steps per sample.
            target_mask: batch_size x max_difficulty x (num_digits + 1);
                1 where the label digit is real, 0 where it is padding.
            labels: batch_size x max_difficulty x (num_digits + 1) digits.

    Raises:
        ValueError: if a running sum needs more than ``num_digits + 1``
            digits (possible when ``max_difficulty > 10``). Without this
            check, the label rows would be ragged.
    """
    label_width = params.num_digits + 1
    features = []
    labels = []
    sequence_lengths = []
    target_mask = []
    for _ in range(params.batch_size):
        feature = []
        label = []
        mask = []
        total = 0
        sequence_length = random.randint(1, params.max_difficulty)
        sequence_lengths.append(sequence_length)
        for _ in range(sequence_length):
            # generate one number with a random number of digits
            value = [random.randint(0, 9)
                     for _ in range(random.randint(1, params.num_digits))]
            # drop leading zeros; an all-zero number becomes empty
            while value and value[0] == 0:
                value.pop(0)
            if value:
                total += int(''.join(map(str, value)))
            total_list = [int(c) for c in str(total)]
            if len(total_list) > label_width:
                raise ValueError(
                    "running sum {} needs {} digits but labels hold only {} "
                    "(num_digits + 1); keep max_difficulty <= 10".format(
                        total, len(total_list), label_width))
            # right-pad the number and the running sum with -1
            value += [-1] * (params.num_digits - len(value))
            total_list += [-1] * (label_width - len(total_list))
            # 1 for real digits of the sum, 0 for the -1 padding
            mask.append([0 if d == -1 else 1 for d in total_list])
            feature.append(value)
            label.append(total_list)
        # pad features, labels and mask out to max_difficulty steps
        pad = params.max_difficulty - sequence_length
        feature.extend([-1] * params.num_digits for _ in range(pad))
        label.extend([-1] * label_width for _ in range(pad))
        mask.extend([0] * label_width for _ in range(pad))
        assert len(feature) == len(label) == len(mask)
        features.append(feature)
        labels.append(label)
        target_mask.append(mask)
    return features, sequence_lengths, target_mask, labels
@register("addition")
def input_fn(data_sources, params, training):
    """Return an input function that yields synthetic addition batches.

    Args:
        data_sources: Not used by this input function.
        params: Hyperparameter object; reads ``batch_size``,
            ``max_difficulty``, ``num_digits`` and ``num_classes``.
        training: Not used by this input function.

    Returns:
        A zero-argument callable that builds the ``(features, labels)``
        tensors described in its own docstring.
    """

    def _get_batch():
        # tf.py_func requires every returned array's dtype to equal the
        # declared Tout exactly. Converting plain Python int lists relies on
        # numpy's default integer type, which is int32 on some platforms
        # (e.g. Windows with numpy < 2), so cast to int64 explicitly.
        return tuple(
            np.asarray(part, dtype=np.int64) for part in generate_batch(params))

    def _input_fn():
        """Generate ``params.batch_size`` random addition samples.

        Digits are one-hot encoded with depth 10. The -1 padding produced by
        ``generate_batch`` maps to all-zero vectors under ``tf.one_hot``.

        Returns:
            features: dict with
                "inputs": float32, shape (batch_size, max_difficulty,
                    num_digits * 10): one-hot digits of each summand.
                "seq_length" / "difficulty": int64, shape (batch_size,):
                    number of real summands in each sample.
                "target_mask": int64, shape (batch_size, max_difficulty,
                    num_digits + 1): 1 for real target digits, 0 for
                    padding.
            y: float32, shape (batch_size, max_difficulty, num_classes):
                one-hot digits of the running sum up to each step. The
                reshape assumes num_classes == (num_digits + 1) * 10.
        """
        x, seq_length, target_mask, y = tf.py_func(
            _get_batch, [], [tf.int64, tf.int64, tf.int64, tf.int64])
        x = tf.reshape(
            tf.one_hot(x, depth=10),
            shape=(params.batch_size, params.max_difficulty,
                   params.num_digits * 10))
        y = tf.reshape(
            tf.one_hot(y, depth=10),
            shape=(params.batch_size, params.max_difficulty,
                   params.num_classes))
        # py_func outputs have unknown static shape; pin them for the graph
        seq_length.set_shape(shape=(params.batch_size,))
        target_mask.set_shape(
            shape=(params.batch_size, params.max_difficulty,
                   params.num_digits + 1))
        return {
            "inputs": x,
            "seq_length": seq_length,
            "difficulty": seq_length,
            "target_mask": target_mask,
        }, y

    return _input_fn