@@ -22,7 +22,25 @@ def _top_k():
2222 )
2323
2424
25- def sample_sequence (* , hparams , length , start_token = None , batch_size = None , context = None , temperature = 1 , top_k = 0 ):
25+ def top_p_logits (logits , p ):
26+ """Nucleus sampling"""
27+ batch , _ = logits .shape .as_list ()
28+ sorted_logits = tf .sort (logits , direction = 'DESCENDING' , axis = - 1 )
29+ cumulative_probs = tf .cumsum (tf .nn .softmax (sorted_logits , axis = - 1 ), axis = - 1 )
30+ indices = tf .stack ([
31+ tf .range (0 , batch ),
32+ # number of indices to include
33+ tf .maximum (tf .reduce_sum (tf .cast (cumulative_probs <= p , tf .int32 ), axis = - 1 ) - 1 , 0 ),
34+ ], axis = - 1 )
35+ min_values = tf .gather_nd (sorted_logits , indices )
36+ return tf .where (
37+ logits < min_values ,
38+ tf .ones_like (logits ) * - 1e10 ,
39+ logits ,
40+ )
41+
42+
43+ def sample_sequence (* , hparams , length , start_token = None , batch_size = None , context = None , temperature = 1 , top_k = 0 , top_p = 1 ):
2644 if start_token is None :
2745 assert context is not None , 'Specify exactly one of start_token and context!'
2846 else :
@@ -45,6 +63,7 @@ def body(past, prev, output):
4563 next_outputs = step (hparams , prev , past = past )
4664 logits = next_outputs ['logits' ][:, - 1 , :] / tf .to_float (temperature )
4765 logits = top_k_logits (logits , k = top_k )
66+ logits = top_p_logits (logits , p = top_p )
4867 samples = tf .multinomial (logits , num_samples = 1 , output_dtype = tf .int32 )
4968 return [
5069 next_outputs ['presents' ] if past is None else tf .concat ([past , next_outputs ['presents' ]], axis = - 2 ),
0 commit comments