-
Notifications
You must be signed in to change notification settings - Fork 0
/
Trainer.py
22 lines (16 loc) · 722 Bytes
/
Trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import os
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common import env_checker
class TrainAndLoggingCallback(BaseCallback):
    """Periodically save the model being trained to ``save_path``.

    Every ``check_freq`` calls to the callback (``self.n_calls``, which is
    maintained by ``BaseCallback``), the model is saved as
    ``<save_path>/bestModel<n_calls>``. Despite the file name, this is a
    periodic checkpoint rather than a best-so-far model: no evaluation is
    performed before saving.

    Args:
        check_freq: Save interval, measured in callback calls. Must be >= 1.
        save_path: Directory for the checkpoints. It is created in
            ``_init_callback`` if missing. When ``None``, nothing is saved.
        verbose: Verbosity level, forwarded unchanged to ``BaseCallback``.

    Raises:
        ValueError: If ``check_freq`` is less than 1.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        # A zero interval would only fail later, mid-training, with a
        # ZeroDivisionError from ``n_calls % 0``. Reject non-positive
        # intervals up front instead.
        if check_freq < 1:
            raise ValueError(f"check_freq must be >= 1, got {check_freq!r}")
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the checkpoint directory, if a save path was given."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save a checkpoint every ``check_freq`` calls.

        Returns:
            Always ``True``, so this callback never requests that training
            stop.
        """
        # _init_callback already treats save_path=None as "don't save".
        # Honour that here too; otherwise os.path.join(None, ...) would
        # raise TypeError on the first checkpoint step.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, f"bestModel{self.n_calls}")
            self.model.save(model_path)
        return True