# -*- coding=utf-8 -*-
import tensorflow as tf
from utils import add_layer_summary


def gradient_clipping(optimizer, cost, _lower, _upper):
    """
    Apply element-wise gradient clipping to [_lower, _upper] before the update.
    """
    gradients, variables = zip(*optimizer.compute_gradients(cost))
    # Clip only the non-None gradients, but keep the None entries in place so
    # the gradient list stays aligned with the variable list below.
    clip_grad = [tf.clip_by_value(grad, _lower, _upper) if grad is not None else grad
                 for grad in gradients]
    train_op = optimizer.apply_gradients(zip(clip_grad, variables),
                                         global_step=tf.train.get_global_step())
    return train_op


def get_train_op(optimizer, loss, params):
    """
    Build the train op, optionally with gradient clipping controlled by params.
    """
    if params.get('clip_gradient', False):
        train_op = gradient_clipping(optimizer, loss,
                                     params['lower_gradient'],
                                     params['upper_gradient'])
    else:
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return train_op


def get_learning_rate(params):
    """
    Select the learning rate strategy: exponential decay, noam warmup,
    or a constant rate, depending on params.
    """
    ## TODO: support other decay methods
    if params.get('rate_decay', False):
        lr = exponential_decay(params)
    elif params.get('warmup', False):
        lr = noam(params)
    else:
        lr = params['learning_rate']
    tf.summary.scalar('learning_rate', lr)
    return lr


def exponential_decay(params):
    PARAMS = ['learning_rate', 'decay_rate', 'decay_step']
    assert all([i in params.keys() for i in PARAMS]), \
        '{} are needed for exponential decay'.format(','.join(PARAMS))
    # tf.train.exponential_decay expects decay_steps before decay_rate.
    lr = tf.train.exponential_decay(params['learning_rate'],
                                    tf.train.get_global_step(),
                                    params['decay_step'],
                                    params['decay_rate'])
    return lr


def noam(params):
    PARAMS = ['emb_size', 'warmup_steps', 'dtype']
    assert all([i in params.keys() for i in PARAMS]), \
        '{} are needed for noam'.format(','.join(PARAMS))
    # Noam schedule from "Attention Is All You Need":
    # lr = emb_size ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    step = tf.cast(tf.train.get_global_step(), params['dtype'])
    lr = params['emb_size'] ** -0.5 * tf.minimum(step ** -0.5,
                                                 step * params['warmup_steps'] ** -1.5)
    return lr
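

if __name__ == '__main__':
    # Usage sketch (not part of the original training pipeline): wire the
    # learning rate schedule and the train op together on a toy least-squares
    # loss to show how the helpers are meant to be combined. The params values,
    # the toy variable, and the GradientDescentOptimizer choice are
    # illustrative assumptions only.
    params = {'learning_rate': 0.01,
              'clip_gradient': True,
              'lower_gradient': -5.0,
              'upper_gradient': 5.0,
              'warmup': True,
              'emb_size': 128,
              'warmup_steps': 4000,
              'dtype': tf.float32}
    # The helpers rely on tf.train.get_global_step(), so create it first.
    tf.train.get_or_create_global_step()
    w = tf.get_variable('w', shape=[1], initializer=tf.zeros_initializer())
    loss = tf.reduce_mean(tf.square(w - 1.0))
    lr = get_learning_rate(params)
    optimizer = tf.train.GradientDescentOptimizer(lr)
    train_op = get_train_op(optimizer, loss, params)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(5):
            _, loss_val = sess.run([train_op, loss])
        print('loss after 5 steps: {:.4f}'.format(loss_val))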