-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathmain.py
67 lines (54 loc) · 1.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
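Train a Crepe character-level convolutional text classifier on the AG data
returned by data_helpers.load_ag_data().

To run on a GPU, install tensorflow-gpu (v1.6), then run: python main.py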
"""
from __future__ import print_function
from __future__ import division
import json
import py_crepe
import numpy as np
import data_helpers
# Seed NumPy's global RNG so NumPy-driven randomness repeats across runs.
np.random.seed(123)

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# NOTE(review): `subset` is not referenced anywhere else in this script.
subset = None

# Persistence: when `save` is True, the model architecture (JSON) and the
# trained weights (HDF5) are written to the two paths below.
save = False
model_name_path = 'params/crepe_model.json'
model_weights_path = 'params/crepe_model_weights.h5'

# Fixed input length handed to data_helpers.encode_data: per the original
# author's note, longer texts are chopped and shorter ones padded.
maxlen = 1014

# ---------------------------------------------------------------------------
# Model hyperparameters (passed to py_crepe.create_model)
# ---------------------------------------------------------------------------
nb_filter = 256                           # filters per convolutional layer
dense_outputs = 1024                      # units in the dense layer(s)
filter_kernels = [7, 7, 3, 3, 3, 3]       # kernel size of each conv layer, in order
cat_output = 4                            # output units == number of classes

# ---------------------------------------------------------------------------
# Training parameters (passed to model.fit)
# ---------------------------------------------------------------------------
batch_size, nb_epoch = 80, 20
print('Loading data...')
# NOTE(review): per the original author, x is a list of sentences and y the
# category index of each example — confirm against data_helpers.load_ag_data.
(xt, yt), (x_test, y_test) = data_helpers.load_ag_data()

print('Creating vocab...')
vocab, reverse_vocab, vocab_size, alphabet = data_helpers.create_vocab_set()

print('Build model...')
model = py_crepe.create_model(filter_kernels, dense_outputs, maxlen, vocab_size,
                              nb_filter, cat_output)

# Encode the raw text of both splits into fixed-length (maxlen) inputs using
# the character vocabulary built above.
xt = data_helpers.encode_data(xt, maxlen, vocab)
x_test = data_helpers.encode_data(x_test, maxlen, vocab)

print('Chars vocab: {}'.format(alphabet))
print('Chars vocab size: {}'.format(vocab_size))
print('X_train.shape: {}'.format(xt.shape))
model.summary()

print('Fit model...')
# NOTE(review): the held-out test split is also used as validation data, so
# the per-epoch "validation" metrics are really test-set metrics.
model.fit(xt, yt,
          validation_data=(x_test, y_test), batch_size=batch_size,
          epochs=nb_epoch, shuffle=True)

if save:
    import os

    print('Saving model params...')
    # Make sure the output directories exist. Otherwise open() /
    # save_weights() fail only after the whole training run has finished.
    # (No exist_ok= so this stays Python-2 compatible, matching the
    # __future__ imports.)
    for out_path in (model_name_path, model_weights_path):
        out_dir = os.path.dirname(out_path)
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    json_string = model.to_json()
    with open(model_name_path, 'w') as f:
        # NOTE: to_json() already returns JSON text. json.dump therefore stores
        # it as a JSON *string literal*, so readers must json.load() the file
        # before passing the result to keras.models.model_from_json(). This
        # format is kept as-is for compatibility with existing files/loaders.
        json.dump(json_string, f)
    model.save_weights(model_weights_path)