-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain.py
executable file
·46 lines (36 loc) · 1.16 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import sys
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__))))
import argparse
from trainer import Trainer_Epoch, Trainer_Episode
from utils import read_config
def main(config):
    """Build the trainer selected by ``config["type"]`` and run it.

    Args:
        config: Configuration object indexed by string keys. Its ``"type"``
            entry is compared case-insensitively against ``"epoch"`` and
            ``"episode"``. The whole object is passed unchanged to the
            chosen trainer's constructor.

    Raises:
        KeyError: If ``config["type"]`` is neither ``"epoch"`` nor
            ``"episode"`` (in any letter case).
    """
    # Look up and normalise once so both comparisons see the same value.
    trainer_type = config["type"].lower()

    if trainer_type == "epoch":
        trainer = Trainer_Epoch(config)
    elif trainer_type == "episode":
        trainer = Trainer_Episode(config)
    else:
        # Same exception class as before so existing handlers keep working,
        # but report the rejected value and the accepted alternatives.
        raise KeyError(
            "unsupported trainer type {!r}; expected 'epoch' or 'episode'".format(
                config["type"]
            )
        )

    trainer.train()
if __name__ == "__main__":

    def _parse_flag(text):
        """Return True only when *text* reads "true", ignoring letter case."""
        return str(text).lower() == "true"

    # Command-line interface: config path, optional checkpoint to resume
    # from, and a flag controlling whether only the model is resumed.
    cli = argparse.ArgumentParser(description="")
    cli.add_argument(
        "--config",
        type=str,
        default="config/base_epoch.yml",
        help="path to config file",
    )
    cli.add_argument(
        "--resume",
        type=str,
        default="",
        help="path to model_pretrained.pth file",
    )
    cli.add_argument(
        "--only_model",
        type=_parse_flag,
        default=False,
        help="only resume model",
    )
    options = cli.parse_args()

    # Load the config file, then add the command-line resume settings to it
    # before handing everything to main().
    config = read_config(options.config)
    config.update({"resume": options.resume, "only_model": options.only_model})
    main(config)