Skip to content

Commit

Permalink
Revert "fat stack: tear out scheduled sampling and replace with full …
Browse files Browse the repository at this point in the history
…sampling"

This reverts commit 350a059.
  • Loading branch information
hans committed Jul 2, 2016
1 parent 0981f34 commit 6beee1a
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions python/spinn/fat_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@ def _make_inputs(self):
self.X = self.X or T.imatrix("X")
self.transitions = self.transitions or T.imatrix("transitions")

def _step(self, transitions_t, stack_t, buffer_cur_t, tracking_hidden,
attention_hidden, buffer, ground_truth_transitions_visible,
premise_stack_tops, projected_stack_tops):
def _step(self, transitions_t, ss_mask_gen_matrix_t, stack_t, buffer_cur_t,
tracking_hidden, attention_hidden, buffer,
ground_truth_transitions_visible, premise_stack_tops, projected_stack_tops):
"""TODO document"""
batch_size, _ = self.X.shape

Expand Down Expand Up @@ -299,8 +299,25 @@ def _step(self, transitions_t, stack_t, buffer_cur_t, tracking_hidden,
logits_use_cell=self._predict_use_cell,
name="prediction_and_tracking")

# HACK: Sample from action multinomial
mask = ss_mask_gen.multinomial(pvals=actions_t).nonzero()[1]
if self.train_with_predicted_transitions:
# Model 2 case.
if self.interpolate:
# Only use ground truth transitions if they are marked as visible to the model.
effective_ss_mask_gen_matrix_t = ss_mask_gen_matrix_t * ground_truth_transitions_visible
# Interpolate between truth and prediction using bernoulli RVs
# generated prior to the step.
mask = (transitions_t * effective_ss_mask_gen_matrix_t
+ actions_t.argmax(axis=1) * (1 - effective_ss_mask_gen_matrix_t))
else:
# Use predicted actions to build a mask.
mask = actions_t.argmax(axis=1)
elif self._predict_transitions:
# Use transitions provided from external parser when not masked out
mask = (transitions_t * ground_truth_transitions_visible
+ actions_t.argmax(axis=1) * (1 - ground_truth_transitions_visible))
else:
# Model 0 case.
mask = transitions_t

# Now update the stack: first precompute reduce results.
if self.model_dim != self.stack_dim:
Expand Down Expand Up @@ -425,6 +442,18 @@ def _make_scan(self):

# Prepare data to scan over.
sequences = [transitions]
if self.interpolate:
# Generate Bernoulli RVs to simulate scheduled sampling
# if the interpolate flag is on.
ss_mask_gen_matrix = self.ss_mask_gen.binomial(
transitions.shape, p=self.ss_prob)
# Take in the RV sequence as input.
sequences.append(ss_mask_gen_matrix)
else:
# Take in the RV sequqnce as a dummy output. This is
# done to avaid defining another step function.
outputs_info = [DUMMY] + outputs_info

non_sequences = [buffer_t, self.ground_truth_transitions_visible]

if self.use_attention != "None" and self.is_hypothesis:
Expand Down

0 comments on commit 6beee1a

Please sign in to comment.