Skip to content

Commit f944fd5

Browse files
author
wzhouad
committed
minor fix
1 parent efc05c1 commit f944fd5

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
flags.DEFINE_integer("num_threads", 4, "Number of threads in input pipeline")
8484
flags.DEFINE_boolean("use_cudnn", True, "Whether to use cudnn (only for GPU)")
8585
flags.DEFINE_boolean("is_bucket", False, "Whether to use bucketing")
86-
flags.DEFINE_list("bucket_range", [0, 401, 40], "range of bucket")
86+
flags.DEFINE_list("bucket_range", [40, 361, 40], "range of bucket")
8787

8888
flags.DEFINE_integer("batch_size", 64, "Batch size")
8989
flags.DEFINE_integer("num_steps", 60000, "Number of steps")

util.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import tensorflow as tf
2+
import numpy as np
23
import re
34
from collections import Counter
45
import string
@@ -46,10 +47,10 @@ def get_batch_dataset(record_file, parser, config):
4647
def key_func(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, y1, y2, qa_id):
4748
c_len = tf.reduce_sum(
4849
tf.cast(tf.cast(context_idxs, tf.bool), tf.int32))
49-
buckets_min = buckets[:-1]
50-
buckets_max = buckets[1:]
51-
conditions_c = tf.logical_and(tf.less_equal(
52-
buckets_min, c_len), tf.less(c_len, buckets_max))
50+
buckets_min = [np.iinfo(np.int32).min] + buckets
51+
buckets_max = buckets + [np.iinfo(np.int32).max]
52+
conditions_c = tf.logical_and(
53+
tf.less(buckets_min, c_len), tf.less_equal(c_len, buckets_max))
5354
bucket_id = tf.reduce_min(tf.where(conditions_c))
5455
return bucket_id
5556

0 commit comments

Comments
 (0)