-
Notifications
You must be signed in to change notification settings - Fork 1
/
main_separator.py
34 lines (29 loc) · 1.51 KB
/
main_separator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import yaml
import argparse
from pprint import pprint as print # noqa
from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict
from src.Separator.train import main
# Options that do not appear in conf.yml are declared directly on this parser.
# After parsing into the hierarchical dictionary, such an option `key` is
# reachable as dic['main_args'][key].
# train.py uses every GPU it can see by default; the `id` option in run.sh
# restricts which GPUs are visible to it.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--exp_dir',
    default='Code/Separator/exp',
    help='Full path to save best validation model',
)
if __name__ == '__main__':
    # Open the default config file conf.yml as a dictionary from which the
    # rest of the parser is built. Each top-level key in the dictionary
    # defined by the YAML file creates a group in the parser.
    # The encoding is given explicitly so that reading the config does not
    # depend on the platform's locale default (e.g. cp1252 on Windows); YAML
    # documents are UTF-8 by specification.
    # NOTE(review): the path is relative to the current working directory, so
    # the script presumably must be launched from the repository root
    # (as the `src.Separator.train` import also suggests) — confirm.
    with open('src/Separator/conf.yml', encoding='utf-8') as f:
        def_conf = yaml.safe_load(f)
    parser = prepare_parser_from_dict(def_conf, parser=parser)
    # Arguments are then parsed into a hierarchical dictionary (instead of
    # flat, as returned by argparse) to facilitate calls to the different
    # asteroid methods (see in main).
    # plain_args is the direct output of parser.parse_args() and contains all
    # the attributes in a non-hierarchical structure. It is kept for
    # convenience but is not used below.
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    # `print` is pprint.pprint in this module (aliased at import), so the
    # nested configuration dictionary is pretty-printed before training.
    print(arg_dic)
    main(arg_dic)