-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
55 lines (38 loc) · 1.32 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Train script."""
import argparse
import random
import warnings
from pathlib import Path

import numpy as np
import tensorflow as tf

from diffwave import DiffWaveConfig, HDF5Dataset, TFDiffWave, Trainer
def main() -> None:
    """Parse CLI arguments and run DiffWave training.

    Positional arguments:
        name: run name; logs for this run are written under ``<log_path>/<name>``.
        data_path: dataset location, handed to ``HDF5Dataset``.

    Options:
        --log_path: parent directory for per-run log folders (default ``logs``).

    Training alternates ``trainer.train()`` and ``trainer.validate()`` until
    ``trainer.step`` reaches ``config.max_steps``, then runs ``trainer.test()``
    once.
    """
    parser = argparse.ArgumentParser(description="Model Training")
    parser.add_argument("name", help="name for run")
    parser.add_argument("data_path", type=Path, help="path to dataset")
    # argparse applies `type` to string defaults, so "logs" becomes Path("logs").
    parser.add_argument(
        "--log_path",
        type=Path,
        help="parent directory for run logs",
        default="logs",
    )
    args = parser.parse_args()

    initialize(1234)

    config = DiffWaveConfig()
    model = TFDiffWave(config)
    dataset = HDF5Dataset(args.data_path, config)
    dataset.load()

    # parents=True: the log root (e.g. the default "logs") may not exist yet;
    # without it mkdir raises FileNotFoundError instead of creating it.
    log_path = args.log_path / args.name
    log_path.mkdir(parents=True, exist_ok=True)

    trainer = Trainer(config=config, model=model, dataset=dataset, log_path=log_path)

    # NOTE(review): relies on train() advancing trainer.step; otherwise this
    # loop never terminates — confirm against Trainer.
    while trainer.step < config.max_steps:
        trainer.train()
        trainer.validate()

    trainer.test()
def initialize(seed: int = 1234) -> None:
    """Enable GPU memory growth and seed the Python, TensorFlow and NumPy RNGs.

    Args:
        seed: value passed to ``random.seed``, ``tf.random.set_seed`` and
            ``np.random.seed``.

    Memory growth can only be configured before TensorFlow initializes its
    GPUs; TensorFlow raises ``RuntimeError`` otherwise. That case is treated
    as best-effort: a warning is emitted for the affected device and the
    function continues with the remaining devices and with seeding.
    """
    for device in tf.config.list_physical_devices("GPU"):
        try:
            tf.config.experimental.set_memory_growth(device, True)
        except RuntimeError as err:
            # Best-effort, but surface the failure instead of silently passing.
            warnings.warn(
                f"Could not enable memory growth for {device.name}: {err}",
                RuntimeWarning,
                stacklevel=2,
            )

    random.seed(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()