# preprocess_midi.py
# Forked from gwinndr/MusicTransformer-Pytorch.
import argparse
import os
import pickle
import json
import third_party.midi_processor.processor as midi_processor
# File name of the Maestro index; prep_midi looks for it directly inside the
# Maestro root folder (os.path.join(maestro_root, JSON_FILE)).
JSON_FILE = "maestro-v2.0.0.json"
# prep_midi
def prep_midi(maestro_root, output_dir):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Pre-processes the maestro dataset, putting processed midi data (train, val, test) into the
    given output folder.

    Every piece listed in the Maestro JSON index (JSON_FILE inside maestro_root) is encoded
    with midi_processor.encode_midi and pickled to
    <output_dir>/{train,val,test}/<midi file name>.pickle according to its "split" field.

    Parameters:
        maestro_root: folder containing JSON_FILE and the MIDI files it references.
        output_dir: folder receiving the train/val/test subfolders (created if missing).

    Returns:
        True on success. False, after printing an ERROR message, if the JSON index is missing
        or a piece has an unrecognized split type; in the latter case pickles written for
        earlier pieces are left in place.
    ----------
    """
    train_dir = os.path.join(output_dir, "train")
    val_dir = os.path.join(output_dir, "val")
    test_dir = os.path.join(output_dir, "test")
    for split_dir in (train_dir, val_dir, test_dir):
        os.makedirs(split_dir, exist_ok=True)

    maestro_json_file = os.path.join(maestro_root, JSON_FILE)
    if not os.path.isfile(maestro_json_file):
        print("ERROR: Could not find file:", maestro_json_file)
        return False

    # Close the index promptly, and decode it as UTF-8 (the JSON standard encoding)
    # rather than whatever the platform default happens to be.
    with open(maestro_json_file, "r", encoding="utf-8") as json_stream:
        maestro_json = json.load(json_stream)

    print("Found", len(maestro_json), "pieces")
    print("Preprocessing...")

    # Split names as they appear in the index -> output folder for that split.
    split_dirs = {"train": train_dir, "validation": val_dir, "test": test_dir}
    split_counts = dict.fromkeys(split_dirs, 0)

    total_count = 0
    for piece in maestro_json:
        mid = os.path.join(maestro_root, piece["midi_filename"])
        split_type = piece["split"]

        split_dir = split_dirs.get(split_type)
        if split_dir is None:
            print("ERROR: Unrecognized split type:", split_type)
            return False
        split_counts[split_type] += 1

        # basename is portable across path separators, unlike a hard-coded split("/").
        o_file = os.path.join(split_dir, os.path.basename(mid) + ".pickle")

        prepped = midi_processor.encode_midi(mid)
        # The context manager guarantees the output file is closed even if pickling fails.
        with open(o_file, "wb") as o_stream:
            pickle.dump(prepped, o_stream)

        total_count += 1
        if total_count % 50 == 0:
            print(total_count, "/", len(maestro_json))

    print("Num Train:", split_counts["train"])
    print("Num Val:", split_counts["validation"])
    print("Num Test:", split_counts["test"])
    return True
# parse_args
def parse_args():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Builds the preprocess_midi command-line interface and parses the process arguments.

    Returns an argparse.Namespace with:
        maestro_root: the positional path to the Maestro dataset root.
        output_dir: the -output_dir value, "./dataset/e_piano" when omitted.
    ----------
    """
    # (option strings, keyword options) for each argument, in registration order.
    arg_specs = (
        (
            ("maestro_root",),
            {"type": str, "help": "Root folder for the Maestro dataset"},
        ),
        (
            ("-output_dir",),
            {
                "type": str,
                "default": "./dataset/e_piano",
                "help": "Output folder to put the preprocessed midi into",
            },
        ),
    )

    cli = argparse.ArgumentParser()
    for flags, options in arg_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
# main
def main():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Entry point. Preprocesses maestro and saves the encoded midi to the output folder given
    on the command line.

    If prep_midi reports failure (it prints the specific ERROR message itself), exits with
    status 1 instead of printing "Done!".
    ----------
    """
    args = parse_args()
    maestro_root = args.maestro_root
    output_dir = args.output_dir

    print("Preprocessing midi files and saving to", output_dir)
    succeeded = prep_midi(maestro_root, output_dir)
    if not succeeded:
        # Don't announce success after an ERROR; signal failure to the calling shell.
        raise SystemExit(1)

    print("Done!")
    print("")
# Run preprocessing only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()