-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathtest.py
28 lines (20 loc) · 1.17 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from core.rl_solve_debug import CaptioningSolver
from core.rl_model import CaptionGenerator
from core.utils import load_coco_data
def main():
    """Run the captioning solver's test pass on the 'debug' split.

    Loads the 'debug' split from ``./data``, builds a ``CaptionGenerator``
    from that split's ``word_to_idx`` mapping, wraps it in a
    ``CaptioningSolver`` and finally calls ``solver.test`` on the same data.
    Takes no arguments and returns nothing; all settings are hard-coded.
    """
    # Only the 'debug' split is loaded; it is reused for every data argument
    # below.
    debug_data = load_coco_data(data_path='./data', split='debug')

    # NOTE: loading a separate 'val' split (and post-processing it with
    # second_process(10, 16, ...)) used to sit here but is disabled, so no
    # dedicated validation data is loaded by this script.

    generator = CaptionGenerator(
        debug_data['word_to_idx'],
        dim_feature=[196, 512],
        dim_embed=256,
        dim_hidden=256,
        n_time_step=10,
        prev2out=True,
        ctx2out=True,
        alpha_c=1.0,
        selector=True,
        dropout=True,
    )

    # Keyword settings handed to the solver, gathered in one place.
    solver_options = dict(
        n_epochs=50,
        batch_size=128,
        update_rule='adam',
        learning_rate=0.001,
        print_every=500,
        save_every=1,
        image_path='./image/',
        pretrained_model=None,
        model_path='./model/lstm/',
        test_model='./model/lstm/model-5',
        print_bleu=True,
        log_path='./log/',
    )

    # The same split fills both positional data slots (presumably the
    # training and validation sets -- confirm against CaptioningSolver).
    solver = CaptioningSolver(generator, debug_data, debug_data, **solver_options)

    # The data is the 'debug' split even though the call is labelled 'val'.
    solver.test(debug_data, split='val')
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()