neurox/data/extraction/transformers_extractor.py (50 changes: 30 additions & 20 deletions)
@@ -127,6 +127,7 @@ def extract_sentence_representations(
     dtype="float32",
     include_special_tokens=False,
     tokenization_counts={},
+    verbose=False
 ):
     """
     Get representations for a single sentence
@@ -245,15 +246,16 @@ def extract_sentence_representations(
     ]
     all_hidden_states = np.array(all_hidden_states, dtype=dtype)
 
-    print('Sentence : "%s"' % (sentence))
-    print("Original (%03d): %s" % (len(original_tokens), original_tokens))
-    print(
-        "Tokenized (%03d): %s"
-        % (
-            len(tokenizer.convert_ids_to_tokens(ids)),
-            tokenizer.convert_ids_to_tokens(ids),
+    if verbose:
+        print('Sentence : "%s"' % (sentence))
+        print("Original (%03d): %s" % (len(original_tokens), original_tokens))
+        print(
+            "Tokenized (%03d): %s"
+            % (
+                len(tokenizer.convert_ids_to_tokens(ids)),
+                tokenizer.convert_ids_to_tokens(ids),
+            )
         )
-    )
 
     assert all_hidden_states.shape[1] == len(ids)
 
@@ -276,13 +278,14 @@ def extract_sentence_representations(
         special_token_ids = []
 
     assert all_hidden_states.shape[1] == len(filtered_ids)
-    print(
-        "Filtered (%03d): %s"
-        % (
-            len(tokenizer.convert_ids_to_tokens(filtered_ids)),
-            tokenizer.convert_ids_to_tokens(filtered_ids),
+    if verbose:
+        print(
+            "Filtered (%03d): %s"
+            % (
+                len(tokenizer.convert_ids_to_tokens(filtered_ids)),
+                tokenizer.convert_ids_to_tokens(filtered_ids),
+            )
         )
-    )
 
     # Get actual tokens for filtered ids in order to do subword
     # aggregation
@@ -397,15 +400,17 @@ def extract_sentence_representations(
             last_special_token_pointer += 1
         counter += 1
 
-    print("Detokenized (%03d): %s" % (len(detokenized), detokenized))
-    print("Counter: %d" % (counter))
+    if verbose:
+        print("Detokenized (%03d): %s" % (len(detokenized), detokenized))
+        print("Counter: %d" % (counter))
 
     if inputs_truncated:
         print("WARNING: Input truncated because of length, skipping check")
     else:
         assert counter == len(filtered_ids)
         assert len(detokenized) == len(original_tokens) + len(special_token_ids)
-    print("===================================================================")
+    if verbose:
+        print("===================================================================")
     return final_hidden_states, detokenized
 
 
@@ -422,6 +427,7 @@ def extract_representations(
     filter_layers=None,
     dtype="float32",
     include_special_tokens=False,
+    verbose=False
 ):
     """
     Extract representations for an entire corpus and save them to disk
@@ -502,6 +508,8 @@ def corpus_generator(input_corpus_path):
     print("Extracting representations from model")
     tokenization_counts = {} # Cache for tokenizer rules
     for sentence_idx, sentence in enumerate(corpus_generator(input_corpus)):
+        if sentence_idx % 100 == 0:
+            print(f"Reached Sentence {sentence_idx}")
         hidden_states, extracted_words = extract_sentence_representations(
             sentence,
             model,
@@ -512,12 +520,15 @@ def corpus_generator(input_corpus_path):
             dtype=dtype,
             include_special_tokens=include_special_tokens,
             tokenization_counts=tokenization_counts,
+            verbose=verbose
         )
 
-        print("Hidden states: ", hidden_states.shape)
-        print("# Extracted words: ", len(extracted_words))
+        if verbose:
+            print("Hidden states: ", hidden_states.shape)
+            print("# Extracted words: ", len(extracted_words))
 
         writer.write_activations(sentence_idx, extracted_words, hidden_states)
+    print("[INFO] DONE WITH THE ACTIVATION EXTRACTION...")
 
     writer.close()
 
@@ -558,7 +569,6 @@ def main():
         action="store_true",
         help="Include special tokens like [CLS] and [SEP] in the extracted representations",
     )
-
    ActivationsWriter.add_writer_options(parser)
 
    args = parser.parse_args()
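For reference, a minimal usage sketch of the new verbose flag from Python. The positional arguments (model name, input corpus path, output file) follow the usual NeuroX extraction call, but the specific values below are illustrative assumptions and are not part of this diff:

from neurox.data.extraction.transformers_extractor import extract_representations

# With verbose=False (the default introduced in this change), the per-sentence
# tokenization dumps stay off; only the progress line every 100 sentences and
# the final "[INFO] DONE WITH THE ACTIVATION EXTRACTION..." message are printed.
# The model name and file paths here are hypothetical examples.
extract_representations(
    "bert-base-uncased",
    "sentences.txt",
    "activations.json",
    verbose=False,
)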