-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: run.py
48 lines (38 loc) · 1.75 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import yaml
import logging
import argparse
import numpy as np
import torch
import train_language_model
if __name__ == '__main__':
# load settings from config file
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, help="Path to a config file.")
_args = parser.parse_args()
with open(_args.config, "r") as f:
cfg = yaml.safe_load(f)
# create loggers
performance_logger = logging.getLogger(cfg['experiment_name']+cfg['model']+'_perf_logger')
performance_logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(message)s','%d-%m-%Y %H:%M')
perf_file_handler = logging.FileHandler(os.path.join(cfg['log_dir'],cfg['model'],cfg['experiment_name']+'_performance'))
perf_file_handler.setLevel(logging.INFO)
perf_file_handler.setFormatter(formatter)
performance_logger.addHandler(perf_file_handler)
generated_text_logger = logging.getLogger(cfg['experiment_name']+cfg['model']+'_gen_logger')
generated_text_logger.setLevel(logging.INFO)
formatter2 = logging.Formatter('%(message)s')
gen_file_handler = logging.FileHandler(os.path.join(cfg['log_dir'],cfg['model'],cfg['experiment_name']+'_generated_texts'))
gen_file_handler.setLevel(logging.INFO)
gen_file_handler.setFormatter(formatter2)
generated_text_logger.addHandler(gen_file_handler)
# print settings
performance_logger.info(f'cfg: {cfg}')
# setting device on GPU if available, else CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
performance_logger.info(f'Using device: {device}')
if cfg['model']=='k2t':
train_language_model.train(cfg, device, performance_logger, generated_text_logger)
else:
raise NotImplementedError()