-
Notifications
You must be signed in to change notification settings - Fork 3
/
layers.py
38 lines (27 loc) · 1.5 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import tensorflow as tf
def avg_pooling_embedding(embedding, features, params):
    """
    Average pooling of context embeddings, computed one sample at a time.

    :param embedding: (vocab_size, emb_size) embedding matrix
    :param features: (batch, 2*window_size) int indices, padded with params['pad_index']
    :param params: dict with 'batch_size' and 'pad_index' entries
    :return: (batch, emb_size) tensor — per-sample mean of the non-padding token embeddings
    """
    pooled = []
    for row in tf.unstack(features, params['batch_size']):
        # Drop padding positions before the lookup so they cannot affect the mean.
        keep = tf.not_equal(row, params['pad_index'])
        real_tokens = tf.boolean_mask(row, keep, axis=0)           # (real_size,)
        token_emb = tf.nn.embedding_lookup(embedding, real_tokens)  # (real_size, emb_size)
        pooled.append(tf.reduce_mean(token_emb, axis=0))            # (emb_size,)
    # NOTE(review): op name keeps the original (misspelled) identifier so any
    # graph/checkpoint references elsewhere remain valid.
    return tf.stack(pooled, name='input_central_embeddinng')        # batch * emb_size
def avg_pooling_embedding_v2(embedding, features, params):
    """
    Average pooling of context embeddings, ignoring padding positions.

    Unlike avg_pooling_embedding, this looks up embeddings for the padding
    index as well and then zeroes them out with a mask, so it runs as a
    single vectorized lookup instead of a per-sample Python loop.

    :param embedding: (vocab_size, emb_size) embedding matrix
    :param features: (batch, padded_size) int indices, padded with params['pad_index']
    :param params: dict with a 'pad_index' entry
    :return: (batch, emb_size) tensor — per-sample mean of the non-padding token embeddings
    """
    input_embedding = tf.nn.embedding_lookup(embedding, features)  # batch * padded_size * emb_size
    # 1.0 for real tokens, 0.0 for padding positions.
    weight = tf.expand_dims(
        tf.cast(tf.not_equal(features, params['pad_index']), tf.float32),
        axis=2)  # batch * padded_size * 1
    # BUG FIX: the original used reduce_mean over the padded axis, which divides
    # the masked sum by padded_size instead of by the number of real tokens, so
    # padding diluted the average (and disagreed with avg_pooling_embedding).
    # Divide the masked sum by the non-pad count instead.
    summed = tf.reduce_sum(tf.multiply(weight, input_embedding), axis=1)  # batch * emb_size
    # Clip the divisor to >= 1 so an all-padding row yields zeros, not NaN.
    n_valid = tf.maximum(tf.reduce_sum(weight, axis=1), 1.0)  # batch * 1
    return summed / n_valid