-
Notifications
You must be signed in to change notification settings - Fork 7
/
ink_training_eager_embedding.py
58 lines (44 loc) · 1.45 KB
/
ink_training_eager_embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""Main training script."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import glob
import tensorflow as tf
from absl import app
from common.constants import Constants as C
from common.export_code import export_code
from smartink.source.training_eager import TrainingEngine
from smartink.config.config_embedding import define_flags
from smartink.config.config_embedding import get_config
from smartink.config.config_embedding import build_embedding_model
from smartink.config.config_embedding import build_dataset
FLAGS = define_flags()

# Let TensorFlow grow GPU memory on demand instead of reserving all device
# memory up front. The original code indexed [0] into the device list, which
# raised IndexError on CPU-only hosts (making its `if gpu:` check dead code),
# and only configured the first GPU. TensorFlow expects memory growth to be
# set consistently across GPUs, so apply it to every visible GPU; on a host
# with no GPUs the loop simply does nothing.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
  try:
    tf.config.experimental.set_memory_growth(gpu, True)
  except RuntimeError as e:
    # Raised when the device has already been initialized; memory growth
    # can no longer be changed, so report it and continue.
    print(e)
def main(argv):
  """Build the datasets and embedding model from FLAGS, then run training.

  First archives every ``*.py`` file found recursively under the current
  working directory into ``code.zip`` inside ``config.experiment.model_dir``.
  Next builds the training and validation datasets and the embedding model,
  all with ``C.RUN_EAGER``. Finally hands them to a ``TrainingEngine``
  (no test data, debug disabled) and starts it.

  Args:
    argv: Remaining command-line arguments supplied by ``absl.app.run``;
      ignored.
  """
  del argv  # Not used; configuration is read from FLAGS instead.
  config = get_config(FLAGS)

  # Snapshot the source tree alongside the experiment outputs.
  source_files = glob.glob('**/*.py', recursive=True)
  archive_path = os.path.join(config.experiment.model_dir, 'code.zip')
  export_code(source_files, archive_path)

  # Data pipelines: training split first, then validation split.
  run_mode = C.RUN_EAGER
  train_split = build_dataset(config, run_mode, C.DATA_TRAIN)
  valid_split = build_dataset(config, run_mode, C.DATA_VALID)

  # Model, built after the datasets as before.
  embedding_model = build_embedding_model(config, run_mode)

  engine = TrainingEngine(
      config=config,
      model=embedding_model,
      train_data=train_split,
      valid_data=valid_split,
      test_data=None,
      debug=False)
  engine.run()
if __name__ == "__main__":
  # Delegate to absl so command-line flags are parsed before main() runs.
  app.run(main)