diff --git a/neurox/data/extraction/transformers_extractor.py b/neurox/data/extraction/transformers_extractor.py
index 4c19272..238c867 100644
--- a/neurox/data/extraction/transformers_extractor.py
+++ b/neurox/data/extraction/transformers_extractor.py
@@ -127,6 +127,7 @@ def extract_sentence_representations(
     dtype="float32",
     include_special_tokens=False,
     tokenization_counts={},
+    verbose=False
 ):
     """
     Get representations for a single sentence
     """
@@ -245,15 +246,16 @@ def extract_sentence_representations(
     ]
     all_hidden_states = np.array(all_hidden_states, dtype=dtype)
 
-    print('Sentence : "%s"' % (sentence))
-    print("Original (%03d): %s" % (len(original_tokens), original_tokens))
-    print(
-        "Tokenized (%03d): %s"
-        % (
-            len(tokenizer.convert_ids_to_tokens(ids)),
-            tokenizer.convert_ids_to_tokens(ids),
+    if verbose:
+        print('Sentence : "%s"' % (sentence))
+        print("Original (%03d): %s" % (len(original_tokens), original_tokens))
+        print(
+            "Tokenized (%03d): %s"
+            % (
+                len(tokenizer.convert_ids_to_tokens(ids)),
+                tokenizer.convert_ids_to_tokens(ids),
+            )
         )
-    )
 
     assert all_hidden_states.shape[1] == len(ids)
 
@@ -276,13 +278,14 @@ def extract_sentence_representations(
         special_token_ids = []
 
     assert all_hidden_states.shape[1] == len(filtered_ids)
-    print(
-        "Filtered (%03d): %s"
-        % (
-            len(tokenizer.convert_ids_to_tokens(filtered_ids)),
-            tokenizer.convert_ids_to_tokens(filtered_ids),
+    if verbose:
+        print(
+            "Filtered (%03d): %s"
+            % (
+                len(tokenizer.convert_ids_to_tokens(filtered_ids)),
+                tokenizer.convert_ids_to_tokens(filtered_ids),
+            )
         )
-    )
 
     # Get actual tokens for filtered ids in order to do subword
     # aggregation
@@ -397,15 +400,17 @@ def extract_sentence_representations(
                 last_special_token_pointer += 1
            counter += 1
 
-    print("Detokenized (%03d): %s" % (len(detokenized), detokenized))
-    print("Counter: %d" % (counter))
+    if verbose:
+        print("Detokenized (%03d): %s" % (len(detokenized), detokenized))
+        print("Counter: %d" % (counter))
 
     if inputs_truncated:
         print("WARNING: Input truncated because of length, skipping check")
     else:
         assert counter == len(filtered_ids)
     assert len(detokenized) == len(original_tokens) + len(special_token_ids)
-    print("===================================================================")
+    if verbose:
+        print("===================================================================")
 
     return final_hidden_states, detokenized
 
@@ -422,6 +427,7 @@ def extract_representations(
     filter_layers=None,
     dtype="float32",
     include_special_tokens=False,
+    verbose=False
 ):
     """
     Extract representations for an entire corpus and save them to disk
@@ -502,6 +508,8 @@ def corpus_generator(input_corpus_path):
     print("Extracting representations from model")
     tokenization_counts = {}  # Cache for tokenizer rules
     for sentence_idx, sentence in enumerate(corpus_generator(input_corpus)):
+        if sentence_idx % 100 == 0:
+            print(f"Reached Sentence {sentence_idx}")
         hidden_states, extracted_words = extract_sentence_representations(
             sentence,
             model,
@@ -512,12 +520,15 @@ def corpus_generator(input_corpus_path):
             dtype=dtype,
             include_special_tokens=include_special_tokens,
             tokenization_counts=tokenization_counts,
+            verbose=verbose
         )
 
-        print("Hidden states: ", hidden_states.shape)
-        print("# Extracted words: ", len(extracted_words))
+        if verbose:
+            print("Hidden states: ", hidden_states.shape)
+            print("# Extracted words: ", len(extracted_words))
 
         writer.write_activations(sentence_idx, extracted_words, hidden_states)
 
+    print("[INFO] DONE WITH THE ACTIVATION EXTRACTION...")
     writer.close()
 
@@ -558,7 +569,6 @@ def main():
         action="store_true",
         help="Include special tokens like [CLS] and [SEP] in the extracted representations",
     )
-    ActivationsWriter.add_writer_options(parser)
 
     args = parser.parse_args()
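
Taken together, the patch threads a `verbose` flag through both extraction functions and replaces the per-sentence console dump with coarse progress logging. The sketch below illustrates how the flag might be exercised from Python once the patch is applied; it is only an illustration, and the positional arguments to `extract_representations` (model description, input corpus, output file) are assumptions based on NeuroX's usual calling convention, not something this diff defines.

```python
# Hypothetical usage sketch -- only `verbose` comes from this patch; the
# positional arguments below are assumptions, not taken from the diff.
from neurox.data.extraction import transformers_extractor

# Default (quiet) run: per-sentence debug dumps are suppressed, leaving the
# "Reached Sentence N" progress line every 100 sentences and the final
# "[INFO] DONE WITH THE ACTIVATION EXTRACTION..." message.
transformers_extractor.extract_representations(
    "bert-base-uncased",   # model description (assumed argument)
    "sentences.txt",       # input corpus, one sentence per line (assumed)
    "activations.json",    # output file for the activations writer (assumed)
    verbose=False,
)

# Debug run: re-enables the per-sentence tokenization/detokenization output
# that the patch gates behind `verbose`.
transformers_extractor.extract_representations(
    "bert-base-uncased",
    "sentences.txt",
    "activations_debug.json",
    verbose=True,
)
```

With the default `verbose=False`, long corpora therefore produce only lightweight progress output instead of a full per-sentence tokenization dump.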