-
Notifications
You must be signed in to change notification settings - Fork 2
/
generate.py
36 lines (28 loc) · 1.08 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
from termcolor import cprint
from src.MidiGenerator import MidiGenerator
from src import Args
from src.Args import Parser, ArgType
def main(args):
    """
    Entry point: load a saved model and generate MIDI output from it.

    :param args: parsed command-line namespace produced by ``src.Args``.
        Reads ``pc``, ``gpu``, ``load``, ``length``, ``nb_seeds``,
        ``images`` and ``no_duration``.
    """
    if not args.pc:
        # Restrict the GPUs visible to this process when not running on the
        # local machine. Set before the model is loaded, presumably so it
        # applies before any GPU initialisation.
        # os.environ only accepts str values, so coerce (e.g. an int id).
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    # Only the saved model id is handed to MidiGenerator; no dataset path is
    # consumed by this function.
    my_model = MidiGenerator.with_model(id=args.load)  # Load the model
    my_model.generate_from_data(
        length=args.length,
        nb_seeds=args.nb_seeds,
        save_images=args.images,
        no_duration=args.no_duration,
    )
    cprint('---------- Done ----------', 'grey', 'on_green')
if __name__ == '__main__':
    # Parse the generation-mode command line, post-process the parsed
    # namespace, then hand it to main().
    cli_args = Args.preprocess.generate(
        Parser(argtype=ArgType.Generate).parse_args()
    )
    main(cli_args)