-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathplot_scores.py
40 lines (32 loc) · 1.23 KB
/
plot_scores.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import argparse
import os
import matplotlib.pyplot as plt
import pandas as pd
def main():
    """Plot score curves from one or more ``scores.txt`` files and save a PNG.

    Command-line options:
        --title TITLE   Optional plot title; also appended to the output
                        file name.
        --file PATH     Path to a tab-separated ``scores.txt`` file, or to a
                        directory containing one. Repeatable.
        --label LABEL   Legend label for the matching ``--file``. Repeatable;
                        must be given exactly once per ``--file``. The label
                        is also the name of the column plotted on the y axis.

    Each file must contain a ``steps`` column (x axis) and a column whose
    name equals its label (y axis). The figure is written to
    ``<first --file value><title>.png``.

    Invalid arguments or unreadable inputs are reported through
    ``ArgumentParser.error``, which prints usage and exits with status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--title', type=str, default='')
    parser.add_argument('--file', action='append', dest='files',
                        default=[], type=str,
                        help='specify paths of scores.txt')
    parser.add_argument('--label', action='append', dest='labels',
                        default=[], type=str,
                        help='specify labels for scores.txt files')
    args = parser.parse_args()

    # Validate with parser.error instead of assert: asserts are removed under
    # `python -O`, and parser.error gives the user a usage message.
    if not args.files:
        parser.error('at least one --file is required')
    if len(args.labels) != len(args.files):
        parser.error(
            'expected one --label per --file '
            '(got {} files and {} labels)'.format(
                len(args.files), len(args.labels)))

    # Use an explicit figure rather than pyplot's implicit "current figure",
    # so repeated calls in one process don't draw onto the same axes.
    fig, ax = plt.subplots()
    for fpath, label in zip(args.files, args.labels):
        if os.path.isdir(fpath):
            fpath = os.path.join(fpath, 'scores.txt')
        if not os.path.isfile(fpath):
            parser.error('scores file not found: {}'.format(fpath))
        scores = pd.read_csv(fpath, delimiter='\t')
        # The label doubles as the y-axis column name, so check that both
        # columns exist up front instead of failing with a bare KeyError.
        missing = [col for col in ('steps', label)
                   if col not in scores.columns]
        if missing:
            parser.error('{} is missing column(s) {}; available: {}'.format(
                fpath, missing, list(scores.columns)))
        ax.plot(scores['steps'], scores[label], label=label)

    ax.set_xlabel('steps')
    ax.set_ylabel('score')
    ax.legend(loc='best')
    if args.title:
        ax.set_title(args.title)

    fig_fname = args.files[0] + args.title + '.png'
    fig.savefig(fig_fname)
    # Release the figure so matplotlib doesn't keep it alive after saving.
    plt.close(fig)
    print('Saved a figure as {}'.format(fig_fname))
if __name__ == "__main__":
    # Only run the CLI when executed as a script, not when imported.
    main()