Skip to content

Commit 8ec4cf9

Browse files
committed
sequence example
1 parent 73b3aba commit 8ec4cf9

File tree

2 files changed

+68
-5
lines changed

2 files changed

+68
-5
lines changed

copy_memory.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from datetime import datetime
2+
from pathlib import Path
3+
4+
import numpy as np
5+
import tensorflow as tf
6+
from tensorflow.keras import losses, metrics, optimizers
7+
from tensorflow.keras.callbacks import TensorBoard
8+
9+
import tcn
10+
11+
12+
def load_dataset(batch_size, T):
13+
x_train, y_train = generate_copy_sequence(batch_size, T)
14+
x_test, y_test = generate_copy_sequence(batch_size, T)
15+
16+
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(1000)
17+
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(1000)
18+
return train_dataset, test_dataset
19+
20+
21+
def generate_copy_sequence(batch_size, sequence_length):
22+
x = np.zeros((batch_size, sequence_length))
23+
copy_sequence = np.random.randint(0, 8, (batch_size, 10))
24+
x[:, :10] = copy_sequence
25+
x[:, -11:] = 9
26+
27+
y = np.zeros_like(x)
28+
y[:, -10:] = copy_sequence
29+
return x, y
30+
31+
32+
def train():
33+
depth = 6
34+
filters = 25
35+
block_filters = [filters] * depth
36+
sequence_length = 601
37+
38+
train_dataset, test_dataset = load_dataset(30000, sequence_length)
39+
40+
model = tcn.build_model(sequence_length=sequence_length,
41+
channels=1,
42+
num_classes=10,
43+
filters=block_filters,
44+
kernel_size=8,
45+
return_sequence=True)
46+
47+
model.compile(optimizer=optimizers.RMSprop(lr=5e-4, clipnorm=1.),
48+
metrics=[metrics.SparseCategoricalAccuracy()],
49+
loss=losses.SparseCategoricalCrossentropy())
50+
51+
print(model.summary())
52+
53+
model.fit(train_dataset.batch(32),
54+
validation_data=test_dataset.batch(32),
55+
callbacks=[TensorBoard(str(Path("logs") / datetime.now().strftime("%Y-%m-%dT%H-%M_%S")))],
56+
epochs=10)
57+
58+
59+
if __name__ == '__main__':
60+
train()

tcn.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,12 @@ def call(self, inputs, training=None, **kwargs):
6666

6767
class TCN(layers.Layer):
6868

69-
def __init__(self, filters, kernel_size, **kwargs):
69+
def __init__(self, filters, kernel_size, return_sequence=False, **kwargs):
7070
super(TCN, self).__init__(**kwargs)
7171
self.blocks = []
7272
self.depth = len(filters)
7373
self.kernel_size = kernel_size
74+
self.return_sequence = return_sequence
7475

7576
for i in range(self.depth):
7677
dilation_size = 2 ** i
@@ -81,24 +82,26 @@ def __init__(self, filters, kernel_size, **kwargs):
8182
name=f"residual_block_{i}")
8283
)
8384

84-
self.slice_layer = layers.Lambda(lambda tt: tt[:, -1, :])
85+
if not self.return_sequence:
86+
self.slice_layer = layers.Lambda(lambda tt: tt[:, -1, :])
8587

8688
def call(self, inputs, training=None, **kwargs):
8789
x = inputs
8890
for block in self.blocks:
8991
x = block(x)
9092

91-
x = self.slice_layer(x)
93+
if not self.return_sequence:
94+
x = self.slice_layer(x)
9295
return x
9396

9497
@property
9598
def receptive_field_size(self):
9699
return 1 + 2 * (self.kernel_size - 1) * (2 ** self.depth - 1)
97100

98101

99-
def build_model(sequence_length, channels, filters, num_classes, kernel_size):
102+
def build_model(sequence_length, channels, filters, num_classes, kernel_size, return_sequence=False):
100103
inputs = Input(shape=(sequence_length, channels), name="inputs")
101-
tcn_block = TCN(filters, kernel_size)
104+
tcn_block = TCN(filters, kernel_size, return_sequence)
102105
x = tcn_block(inputs)
103106

104107
outputs = layers.Dense(num_classes,

0 commit comments

Comments
 (0)