-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate.py
64 lines (59 loc) · 1.99 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import json
import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import LSTM, Dropout, Dense, Activation, Embedding
#from keras.utils.vis_utils import plot_model
# JSON vocabulary file; it is loaded as a character -> integer index mapping.
cij = 'cti.json'
# Directory holding the saved model weight files (kept with a trailing slash).
model_weights_directory = "Model_Weights1/"
def generateSequence(initial_index, seq_length):
    """Sample a character sequence from the trained LSTM and trim it.

    Loads the character -> index vocabulary from ``cij``, rebuilds the network
    with ``buildModel`` and loads ``Weights_90.h5`` from
    ``model_weights_directory``. Starting from ``initial_index`` it feeds the
    most recently produced index back into the model ``seq_length`` times,
    each time drawing the next index at random from the model's softmax
    output.

    Args:
        initial_index: vocabulary index of the seed character.
        seq_length: number of characters to sample after the seed.

    Returns:
        The decoded text after its first newline, cut just after the first
        newline of the first blank line (two consecutive newlines). Returns
        an empty string if the decoded text contains no newline.

    Note:
        Sampled indices are drawn from ``range(len(vocabulary))``. This
        assumes the vocabulary indices are exactly 0..len-1 (TODO confirm
        against how cti.json is produced).
    """
    with open(cij, encoding="utf-8") as f:
        char_to_index = json.load(f)
    index_to_char = {i: ch for ch, i in char_to_index.items()}
    vocab_size = len(index_to_char)

    model = buildModel(vocab_size)
    # summary() prints on its own and returns None; wrapping it in print()
    # only emitted an extra "None" line.
    model.summary()
    model.load_weights(os.path.join(model_weights_directory, "Weights_90.h5"))

    sequence_index = [initial_index]
    for _ in range(seq_length):
        # The model takes a (1, 1) batch: one sequence of one step.
        batch = np.array([[sequence_index[-1]]], dtype=np.float64)
        prediction = model.predict_on_batch(batch)
        # The softmax output is float32. Its sum can drift from 1 by rounding,
        # which makes np.random.choice reject it, so renormalise in float64.
        probs = np.asarray(prediction, dtype=np.float64).ravel()
        probs /= probs.sum()
        sequence_index.append(int(np.random.choice(vocab_size, p=probs)))

    decoded = "".join(index_to_char[c] for c in sequence_index)
    return _trim_to_first_block(decoded)


def _trim_to_first_block(text):
    """Trim decoded model output to its first blank-line-delimited chunk.

    Everything up to and including the first newline is discarded. The rest
    is kept up to and including the first newline of the first pair of
    consecutive newlines. If no such pair exists, the whole remainder is
    returned (the previous index-based loop raised IndexError when the
    remainder ended in a lone newline). Text with no newline at all yields
    an empty string.
    """
    _, _, remainder = text.partition("\n")
    end = remainder.find("\n\n")
    if end == -1:
        return remainder
    # Keep the first newline of the pair, drop the second.
    return remainder[:end + 1]
def buildModel(unique_chars):
    """Assemble the stateful character-level LSTM network (not compiled here).

    Layer stack:
      * an Embedding (unique_chars -> 512) with a fixed batch input shape of
        (1, 1): one sequence, one time step;
      * three stateful 256-unit LSTMs, each followed by Dropout(0.2);
      * a Dense layer with unique_chars outputs and a softmax Activation.

    Args:
        unique_chars: vocabulary size, used as both the embedding input
            dimension and the number of output units.

    Returns:
        The keras ``Sequential`` model.
    """
    stack = [
        Embedding(input_dim=unique_chars, output_dim=512, batch_input_shape=(1, 1)),
    ]
    # The first two LSTMs return full sequences so the next LSTM still gets a
    # time axis; the last one returns only its final output.
    for keep_sequence in (True, True, False):
        stack.append(LSTM(256, return_sequences=keep_sequence, stateful=True))
        stack.append(Dropout(0.2))
    stack.append(Dense(unique_chars))
    stack.append(Activation("softmax"))
    return Sequential(stack)
if __name__ == "__main__":
    # Seed with vocabulary index 8 (which character that is depends on
    # cti.json; TODO confirm) and sample 550 more characters.
    generated = generateSequence(8, 550)
    print("Music Generated : \n")
    print(generated)