import json
import os

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from model import Model
from util import create_batch, convert_tokens, evaluate
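

# Trains the reading-comprehension model: builds train/dev graphs over the
# TFRecord pipelines, logs loss summaries, evaluates periodically, and halves
# the learning rate when the dev loss stops improving (a simple patience schedule).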
def train(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
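
    # Build the input pipelines and two copies of the model that share weights:
    # one wired to the training batches, one (non-trainable) to the dev batches.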
    print("Building model...")
    train_batch = create_batch(config.train_record_file, config)
    dev_batch = create_batch(config.dev_record_file, config)
    with tf.variable_scope("model"):
        model_train = Model(config, train_batch, word_mat, char_mat)
        tf.get_variable_scope().reuse_variables()
        model_dev = Model(config, dev_batch, word_mat,
                          char_mat, trainable=False)

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True

    loss_save = 100.0
    patience = 0
    lr = config.init_lr
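
    # Start the session and the queue-runner threads that feed the TFRecord
    # pipelines, then put the model in training mode with the initial learning rate.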
    with tf.Session(config=sess_config) as sess:
        writer = tf.summary.FileWriter(config.log_dir)
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        saver = tf.train.Saver()
        sess.run(tf.assign(model_train.is_train,
                           tf.constant(True, dtype=tf.bool)))
        sess.run(tf.assign(model_train.lr, tf.constant(lr, dtype=tf.float32)))
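
        # Main training loop: one optimisation step per iteration, a loss summary
        # every `config.period` steps, and a full evaluation plus checkpoint every
        # `config.checkpoint` steps.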
        for _ in tqdm(range(1, config.num_steps + 1)):
            global_step = sess.run(model_train.global_step) + 1
            loss, train_op = sess.run([model_train.loss, model_train.train_op])
            if global_step % config.period == 0:
                loss_sum = tf.Summary(value=[tf.Summary.Value(
                    tag="model/loss", simple_value=loss)])
                writer.add_summary(loss_sum, global_step)
            if global_step % config.checkpoint == 0:
                sess.run(tf.assign(model_train.is_train,
                                   tf.constant(False, dtype=tf.bool)))
                _, summ = evaluate_batch(
                    model_train, config.val_num_batches, train_eval_file, sess, "train")
                for s in summ:
                    writer.add_summary(s, global_step)
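
                # Evaluate on a sample of dev batches; halve the learning rate when
                # the dev loss has not improved for `config.patience` checkpoints.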
                metrics, summ = evaluate_batch(
                    model_dev, config.val_num_batches, dev_eval_file, sess, "dev")
                sess.run(tf.assign(model_train.is_train,
                                   tf.constant(True, dtype=tf.bool)))

                dev_loss = metrics["loss"]
                if dev_loss < loss_save:
                    loss_save = dev_loss
                    patience = 0
                else:
                    patience += 1
                if patience >= config.patience:
                    lr /= 2.0
                    loss_save = dev_loss
                    patience = 0
                    sess.run(tf.assign(model_train.lr,
                                       tf.constant(lr, dtype=tf.float32)))
                for s in summ:
                    writer.add_summary(s, global_step)
                writer.flush()
                filename = os.path.join(
                    config.save_dir, "model_{}.ckpt".format(global_step))
                saver.save(sess, filename)
        coord.request_stop()
        coord.join(threads)
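

# Restores the most recent checkpoint and reports Exact Match / F1 on the test
# set, decoding answer spans from the predicted start/end pointers.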
def test(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.test_eval_file, "r") as fh:
        eval_file = json.load(fh)
    with open(config.test_meta, "r") as fh:
        meta = json.load(fh)

    total = meta["total"]

    print("Loading model...")
    test_batch = create_batch(config.test_record_file, config, test=True)
    with tf.variable_scope("model"):
        model = Model(config, test_batch, word_mat, char_mat, trainable=False)

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True
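
    # Restore the latest checkpoint and run the model over the whole test set,
    # collecting the predicted answer span for every question id.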
    with tf.Session(config=sess_config) as sess:
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(config.save_dir))
        sess.run(tf.assign(model.is_train, tf.constant(False, dtype=tf.bool)))
        losses = []
        answer_dict = {}
        for _ in tqdm(range(total // config.batch_size)):
            qa_id, loss, yp1, yp2 = sess.run(
                [model.qa_id, model.loss, model.yp1, model.yp2])
            answer_dict.update(convert_tokens(
                eval_file, qa_id.tolist(), yp1.tolist(), yp2.tolist()))
            losses.append(loss)
        coord.request_stop()
        coord.join(threads)
        loss = np.mean(losses)
        metrics = evaluate(eval_file, answer_dict)
        print("Exact Match: {}, F1: {}, loss: {}".format(
            metrics['exact_match'], metrics['f1'], loss))
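

# Runs the given model for `num_batches` batches, converts the predicted spans
# back to answer strings, and returns the metrics plus TensorBoard summaries
# for loss, F1, and exact match under the given `data_type` tag.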
def evaluate_batch(model, num_batches, eval_file, sess, data_type):
    answer_dict = {}
    losses = []
    for _ in tqdm(range(1, num_batches + 1)):
        qa_id, loss, yp1, yp2 = sess.run(
            [model.qa_id, model.loss, model.yp1, model.yp2])
        answer_dict.update(convert_tokens(
            eval_file, qa_id.tolist(), yp1.tolist(), yp2.tolist()))
        losses.append(loss)
    loss = np.mean(losses)
    metrics = evaluate(eval_file, answer_dict)
    metrics["loss"] = loss
    loss_sum = tf.Summary(value=[tf.Summary.Value(
        tag="{}/loss".format(data_type), simple_value=metrics["loss"])])
    f1_sum = tf.Summary(value=[tf.Summary.Value(
        tag="{}/f1".format(data_type), simple_value=metrics["f1"])])
    em_sum = tf.Summary(value=[tf.Summary.Value(
        tag="{}/em".format(data_type), simple_value=metrics["exact_match"])])
    return metrics, [loss_sum, f1_sum, em_sum]
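

# The command-line entry point is not part of this file. A minimal sketch of how
# train() and test() might be dispatched is given below; the `config` import and
# its `mode` attribute are assumptions for illustration, not the project's actual
# flag handling.
if __name__ == "__main__":
    from config import flags  # hypothetical import; the real project may differ
    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        raise ValueError("Unknown mode: {}".format(config.mode))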