Sliding window bug #20

Open
hankcs opened this issue Jul 8, 2020 · 0 comments

Hi, there seems to be a bug in the calculation of final_window_start:

# Next, select indices of the sequence such that it will result in embeddings representing the original
# sentence. To capture maximal context, the indices will be the middle part of each embedded window
# sub-sequence (plus any leftover start and final edge windows), e.g.,
#  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
# "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
# with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
# and final windows with indices [0, 1] and [14, 15] respectively.
# Find the stride as half the max pieces, ignoring the special start and end tokens
# Calculate an offset to extract the centermost embeddings of each window
stride = (self.max_pieces - self.start_tokens - self.end_tokens) // 2
stride_offset = stride // 2 + self.start_tokens
first_window = list(range(stride_offset))
max_context_windows = [i for i in range(full_seq_len)
                       if stride_offset - 1 < i % self.max_pieces < stride_offset + stride]
final_window_start = full_seq_len - (full_seq_len % self.max_pieces) + stride_offset + stride
final_window = list(range(final_window_start, full_seq_len))
select_indices = first_window + max_context_windows + final_window

On the test case from your comment, final_window_start evaluates to 21, which is greater than full_seq_len (16), so final_window ends up empty:

full_seq_len = 16
max_pieces = 8
start_tokens = 1
end_tokens = 1

# Next, select indices of the sequence such that it will result in embeddings representing the original
# sentence. To capture maximal context, the indices will be the middle part of each embedded window
# sub-sequence (plus any leftover start and final edge windows), e.g.,
#  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
# "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
# with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
# and final windows with indices [0, 1] and [14, 15] respectively.

# Find the stride as half the max pieces, ignoring the special start and end tokens
# Calculate an offset to extract the centermost embeddings of each window
stride = (max_pieces - start_tokens - end_tokens) // 2
stride_offset = stride // 2 + start_tokens

first_window = list(range(stride_offset))

max_context_windows = [i for i in range(full_seq_len)
                       if stride_offset - 1 < i % max_pieces < stride_offset + stride]

final_window_start = full_seq_len - (full_seq_len % max_pieces) + stride_offset + stride
final_window = list(range(final_window_start, full_seq_len))

select_indices = first_window + max_context_windows + final_window
print(select_indices)

The output is [0, 1, 2, 3, 4, 10, 11, 12]; the final window [14, 15] promised in the comment is missing.
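
The root cause seems to be that full_seq_len % max_pieces is 0 whenever the sequence fills its last window exactly, so nothing is subtracted and final_window_start lands past the end of the sequence. Here is a minimal sketch of one possible fix, assuming the intent is to look back a full window in that case. The function name select_indices_sketch and the lookback variable are my own, not part of the existing code, and this may not be the fix you prefer:

```python
def select_indices_sketch(full_seq_len, max_pieces, start_tokens=1, end_tokens=1):
    """Hypothetical standalone version of the index selection, with one change."""
    # Same stride and offset computation as in the original snippet.
    stride = (max_pieces - start_tokens - end_tokens) // 2
    stride_offset = stride // 2 + start_tokens

    first_window = list(range(stride_offset))

    max_context_windows = [i for i in range(full_seq_len)
                           if stride_offset - 1 < i % max_pieces < stride_offset + stride]

    # Assumption: when the remainder is 0 the last window is a full one,
    # so look back max_pieces positions instead of 0.
    lookback = full_seq_len % max_pieces or max_pieces

    final_window_start = full_seq_len - lookback + stride_offset + stride
    final_window = list(range(final_window_start, full_seq_len))

    return first_window + max_context_windows + final_window


print(select_indices_sketch(16, 8))
# [0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15]
```

Note that with this change the final window on the test case is [13, 14, 15] rather than the [14, 15] stated in the comment, so either the comment or the offset may need adjusting depending on which behaviour is intended.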
