-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
55 lines (42 loc) · 1.64 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import re
from glob import glob
import matplotlib.pyplot as plt
# Evaluation script: reads the bass and middle loss values and the epoch
# number encoded in each test-piece path, sums the losses per epoch, prints
# the epochs ranked by summed loss and plots summed loss against epoch.

# Summed loss per epoch. Keys are epoch numbers stored as ints so that sorting
# the keys yields numeric order (string keys would sort '10' before '2' and
# scramble the plotted curves).
loss_per_epoch_bass = {}
loss_per_epoch_middle = {}

# Each path must contain "bass=<number> middle=<number> e=<digits>".
expr = re.compile(r'bass=(.+) middle=(.+) e=(\d+)')

# NOTE(review): glob() is called without recursive=True, so '**' behaves like
# a plain '*' and only files exactly one directory below the root are matched.
# Confirm whether deeper nesting is expected before enabling recursion.
file_paths = glob('./all_test_pieces_plain/**/*.midi')

for fp in file_paths:
    match = expr.search(fp)
    if match is None:
        # Fail with the offending path instead of an opaque IndexError.
        raise ValueError(
            'Cannot parse bass/middle loss and epoch from path: {!r}'.format(fp))
    bass, middle, epoch = match.groups()
    epoch = int(epoch)
    # Both dicts are always updated together, so they share the same keys.
    loss_per_epoch_bass[epoch] = loss_per_epoch_bass.get(epoch, 0) + float(bass)
    loss_per_epoch_middle[epoch] = (
        loss_per_epoch_middle.get(epoch, 0) + float(middle))


def _report_part(title, loss_per_epoch, label):
    """Print epochs ranked by summed loss and plot summed loss per epoch.

    Args:
        title: Heading printed before the ranking.
        loss_per_epoch: Mapping of epoch number (int) to summed loss.
        label: Legend label for the plotted curve.

    Returns:
        A tuple ``(epochs, losses)`` where ``epochs`` is in ascending numeric
        order and ``losses[i]`` is the summed loss of ``epochs[i]``.
    """
    print(title)
    # Highest summed loss first.
    for epoch in sorted(loss_per_epoch, key=loss_per_epoch.get, reverse=True):
        print(epoch, loss_per_epoch[epoch])

    epochs = sorted(loss_per_epoch)
    losses = [loss_per_epoch[epoch] for epoch in epochs]
    plt.plot(epochs, losses, label=label)
    return epochs, losses


# Bass
epochs_bass, losses_bass = _report_part('Bass part:', loss_per_epoch_bass, 'bass')
print()

# Middle
epochs_middle, losses_middle = _report_part(
    'Middle parts:', loss_per_epoch_middle, 'middle')

# Both: the two loss lists are aligned because both dicts have identical keys
# and both are ordered by ascending epoch.
losses_all = [bass_loss + middle_loss
              for bass_loss, middle_loss in zip(losses_bass, losses_middle)]
plt.plot(epochs_bass, losses_all, label='all')

plt.legend(loc='upper left')
plt.show()