Commit efc05c1

Author: wzhouad
Commit message: fix bucket
1 parent: cdf702a

2 files changed: 11 additions (+), 5 deletions (−)

config.py

Lines changed: 5 additions & 3 deletions
@@ -83,7 +83,7 @@
 flags.DEFINE_integer("num_threads", 4, "Number of threads in input pipeline")
 flags.DEFINE_boolean("use_cudnn", True, "Whether to use cudnn (only for GPU)")
 flags.DEFINE_boolean("is_bucket", False, "Whether to use bucketing")
-flags.DEFINE_list("bucket_range", [0, 400, 40], "range of bucket")
+flags.DEFINE_list("bucket_range", [0, 401, 40], "range of bucket")
 
 flags.DEFINE_integer("batch_size", 64, "Batch size")
 flags.DEFINE_integer("num_steps", 60000, "Number of steps")
@@ -101,8 +101,10 @@
 # Extensions (Uncomment corresponding line in download.sh to download the required data)
 glove_char_file = os.path.join(
     home, "data", "glove", "glove.840B.300d-char.txt")
-flags.DEFINE_string("glove_char_file", glove_char_file, "Glove character embedding")
-flags.DEFINE_boolean("pretrained_char", False, "Whether to use pretrained char embedding")
+flags.DEFINE_string("glove_char_file", glove_char_file,
+                    "Glove character embedding")
+flags.DEFINE_boolean("pretrained_char", False,
+                     "Whether to use pretrained char embedding")
 
 fasttext_file = os.path.join(home, "data", "fasttext", "wiki-news-300d-1M.vec")
 flags.DEFINE_string("fasttext_file", fasttext_file, "Fasttext word embedding")

util.py

Lines changed: 6 additions & 2 deletions
@@ -46,8 +46,12 @@ def get_batch_dataset(record_file, parser, config):
         def key_func(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, y1, y2, qa_id):
             c_len = tf.reduce_sum(
                 tf.cast(tf.cast(context_idxs, tf.bool), tf.int32))
-            t = tf.clip_by_value(buckets, 0, c_len)
-            return tf.argmax(t)
+            buckets_min = buckets[:-1]
+            buckets_max = buckets[1:]
+            conditions_c = tf.logical_and(tf.less_equal(
+                buckets_min, c_len), tf.less(c_len, buckets_max))
+            bucket_id = tf.reduce_min(tf.where(conditions_c))
+            return bucket_id
 
         def reduce_func(key, elements):
             return elements.batch(config.batch_size)
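
The new key_func pairs consecutive bucket boundaries into [min, max) intervals and returns the index of the interval containing the context length. A rough NumPy sketch of the same logic (the function name, the standalone setting, and the arange call are illustrative assumptions, not part of the commit):

import numpy as np

def bucket_id_for_length(c_len, bucket_range=(0, 401, 40)):
    # Boundaries [0, 40, ..., 400]; consecutive pairs form [min, max) intervals.
    buckets = np.arange(*bucket_range)
    buckets_min, buckets_max = buckets[:-1], buckets[1:]
    # Mirrors conditions_c: True where buckets_min <= c_len < buckets_max.
    conditions = (buckets_min <= c_len) & (c_len < buckets_max)
    # Mirrors tf.reduce_min(tf.where(...)): index of the first matching interval.
    return int(np.where(conditions)[0].min())

print(bucket_id_for_length(37))   # 0 -> interval [0, 40)
print(bucket_id_for_length(399))  # 9 -> interval [360, 400)

The returned id serves as the grouping key, presumably consumed together with reduce_func by a grouped-batching transform such as tf.data's group_by_window, though that usage sits outside this diff.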
