From afa5b7e9eba297f453fb726f16c114755e861335 Mon Sep 17 00:00:00 2001
From: Andrew Yates
Date: Mon, 1 Feb 2021 17:06:48 +0100
Subject: [PATCH] Fixes for Pytorch CEDR LR and hgf output (#126)

* Fix Pytorch CEDR with BERT weights
* Fix Pytorch LR schedule after iters/steps change
* Enable TF similarity matrix padding
* Add smart_open dependency
* Change spacy dependency to <3.0
---
 capreolus/reranker/CEDRKNRM.py |  6 +++++-
 capreolus/reranker/common.py   | 11 ++++++-----
 capreolus/trainer/pytorch.py   |  6 ++++--
 docs/reproduction/CEDR-KNRM.md |  7 ++++---
 environment.yml                |  3 ++-
 requirements.txt               |  3 ++-
 setup.py                       |  3 ++-
 7 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/capreolus/reranker/CEDRKNRM.py b/capreolus/reranker/CEDRKNRM.py
index ccec7ef77..503c96b01 100644
--- a/capreolus/reranker/CEDRKNRM.py
+++ b/capreolus/reranker/CEDRKNRM.py
@@ -139,7 +139,11 @@ def forward(self, bert_input, bert_mask, bert_segments):
         bert_segments = bert_segments.view((batch_size * self.num_passages, self.maxseqlen))
 
         # get BERT embeddings (including CLS) for each passage
-        bert_output, all_layer_output = self.bert(bert_input, attention_mask=bert_mask, token_type_ids=bert_segments)
+        # TODO switch to hgf's ModelOutput after bumping transformers version
+        outputs = self.bert(bert_input, attention_mask=bert_mask, token_type_ids=bert_segments)
+        if self.config["pretrained"].startswith("bert-"):
+            outputs = (outputs[0], outputs[2])
+        bert_output, all_layer_output = outputs
 
         # average CLS embeddings to create the CLS feature
         cls = bert_output[:, 0, :]
diff --git a/capreolus/reranker/common.py b/capreolus/reranker/common.py
index 299d9de43..ab436668c 100644
--- a/capreolus/reranker/common.py
+++ b/capreolus/reranker/common.py
@@ -66,14 +66,13 @@ def new_similarity_matrix_tf(query_embed, doc_embed, query_tok, doc_tok, padding
     batch, qlen, dims = query_embed.shape
     doclen = doc_embed.shape[1]
 
-    # TODO apply mask for use in stuff other than KNRM
     query_embed = tf.reshape(tf.nn.l2_normalize(query_embed, axis=-1), [batch, qlen, 1, dims])
-    # query_padding = tf.reshape(tf.cast(query_tok != padding, query_embed.dtype), [batch, qlen, 1, 1])
-    # query_embed = query_embed * query_padding
+    query_padding = tf.reshape(tf.cast(query_tok != padding, query_embed.dtype), [batch, qlen, 1, 1])
+    query_embed = query_embed * query_padding
 
     doc_embed = tf.reshape(tf.nn.l2_normalize(doc_embed, axis=-1), [batch, 1, doclen, dims])
-    # doc_padding = tf.reshape(tf.cast(doc_tok != padding, doc_embed.dtype), [batch, 1, doclen, 1])
-    # doc_embed = doc_embed * doc_padding
+    doc_padding = tf.reshape(tf.cast(doc_tok != padding, doc_embed.dtype), [batch, 1, doclen, 1])
+    doc_embed = doc_embed * doc_padding
 
     simmat = tf.reduce_sum(query_embed * doc_embed, axis=-1, keepdims=True)
     return simmat
@@ -142,6 +141,8 @@ def forward(self, query_tok, doc_tok):
         return simmat
 
 
+# TODO replace this with newer ONIR version?
+# https://github.com/Georgetown-IR-Lab/OpenNIR/blob/ca14dfa5e7cfef3fbbb35efbb4e7df0f1fbde590/onir/modules/interaction_matrix.py#L27
 class StackedSimilarityMatrix(torch.nn.Module):
     # based on SimmatModule from https://github.com/Georgetown-IR-Lab/cedr/blob/master/modeling_util.py
     # which is copyright (c) 2019 Georgetown Information Retrieval Lab, MIT license
diff --git a/capreolus/trainer/pytorch.py b/capreolus/trainer/pytorch.py
index 76e375231..406d2c7f4 100644
--- a/capreolus/trainer/pytorch.py
+++ b/capreolus/trainer/pytorch.py
@@ -186,8 +186,10 @@ def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output
         self.amp_train_autocast = contextlib.nullcontext
         self.scaler = None
 
-        # REF-TODO how to handle interactions between fastforward and schedule?
-        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, self.lr_multiplier)
+        # REF-TODO how to handle interactions between fastforward and schedule? --> just save its state
+        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+            self.optimizer, lambda epoch: self.lr_multiplier(step=epoch * self.n_batch_per_iter)
+        )
 
         if self.config["softmaxloss"]:
             self.loss = pair_softmax_loss
diff --git a/docs/reproduction/CEDR-KNRM.md b/docs/reproduction/CEDR-KNRM.md
index c63ed0d94..4c2f5b7f5 100644
--- a/docs/reproduction/CEDR-KNRM.md
+++ b/docs/reproduction/CEDR-KNRM.md
@@ -1,11 +1,11 @@
 # Capreolus: Reranking robust04 with CEDR-KNRM
 
 This page contains instructions for running CEDR-KNRM on the robust04 benchmark.
-[*CEDR: Contextualized Embeddings for Document Ranking*](https://arxiv.org/pdf/1904.07094.pd)f
+[*CEDR: Contextualized Embeddings for Document Ranking*](https://arxiv.org/pdf/1904.07094.pdf).
 Sean MacAvaney, Andrew Yates, Arman Cohan, and Nazli Goharian. SIGIR 2019.
 
 ## Setup
 
-Install Capreolus v0.2.6 or later. See the [installation guide](https://capreolus.ai/en/latest/installation.html) for help installing a release. To install from GitHub, see the [PARADE reproduction guide](https://github.com/capreolus-ir/capreolus/blob/master/docs/reproduction/PARADE.md).
+Install Capreolus v0.2.6 or later. See the [installation guide](https://capreolus.ai/en/latest/installation.html) for help installing a release. To install from GitHub, see the [PARADE guide](https://github.com/capreolus-ir/capreolus/blob/master/docs/reproduction/PARADE.md).
 
 ## Running CEDR-KNRM
@@ -38,7 +38,8 @@ When using a less powerful GPU or disabling mixed precision (`reranker.trainer.a
 3. Each command will take a few hours on a single V100 GPU. Per-fold metrics are displayed after each fold completes.
 4. When the final fold completes, cross-validated metrics are also displayed.
 
-Note that the Tensorflow implementation has only been tested on TPUs.
+Note that the Tensorflow implementation has primarily been tested on TPUs.
+
 
 ## Running BERT-KNRM, VanillaBERT, and other model variants
 
diff --git a/environment.yml b/environment.yml
index 569fb7681..41269fb3c 100644
--- a/environment.yml
+++ b/environment.yml
@@ -4,7 +4,7 @@ channels:
 dependencies:
   - python=3.7
   - pandas
-  - spacy
+  - spacy<3.0
   - numpy
   - scipy
   - matplotlib
@@ -48,3 +48,4 @@ dependencies:
   - xxhash
   - annoy
   - fasteners
+  - smart_open
diff --git a/requirements.txt b/requirements.txt
index 798f67b1c..84fa74caa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -31,7 +31,8 @@ Pillow
 beautifulsoup4
 lxml
 scispacy
-spacy
+smart_open
+spacy<3.0
 pandas
 https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.2.4/en_core_sci_lg-0.2.4.tar.gz
 # deps that the pymagnitude package isn't pulling in:
diff --git a/setup.py b/setup.py
index 7f096f0c2..8cdba752b 100644
--- a/setup.py
+++ b/setup.py
@@ -78,7 +78,8 @@ def get_version(rel_path):
         "beautifulsoup4",
         "lxml",
         "scispacy",
-        "spacy",
+        "smart_open",
+        "spacy<3.0",
         "pandas",
     ],
     classifiers=["Programming Language :: Python :: 3", "Operating System :: OS Independent"],
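
The LambdaLR change in capreolus/trainer/pytorch.py is easier to see in isolation. The sketch below is illustrative only: `n_batch_per_iter`, the warmup length, and the linear-warmup multiplier are assumed stand-ins for the trainer's real `lr_multiplier` and configuration. It shows why the lambda scales the scheduler's index by the number of batches per iteration: the scheduler is stepped once per iteration, while the multiplier is defined in units of optimizer steps.

```python
# Illustrative sketch only (not Capreolus code): n_batch_per_iter, warmup_steps,
# and the linear-warmup multiplier are assumed values chosen for the example.
import torch

n_batch_per_iter = 512  # optimizer steps per training iteration (assumed)
warmup_steps = 1000     # warmup length in optimizer steps (assumed)

def lr_multiplier(step):
    # linear warmup to 1.0, then constant (illustrative schedule)
    return min(1.0, (step + 1) / warmup_steps)

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# LambdaLR passes the number of scheduler.step() calls (its "epoch" index) to the
# lambda; since the scheduler advances once per iteration rather than once per batch,
# the index is converted to a global step count before computing the multiplier.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda epoch: lr_multiplier(step=epoch * n_batch_per_iter)
)

for iteration in range(3):
    # ... train for n_batch_per_iter batches here ...
    optimizer.step()
    scheduler.step()  # called once per iteration, not once per batch
    print(iteration, scheduler.get_last_lr())
```

Because the multiplier is computed from a step count derived only from the scheduler's own counter, saving and restoring the scheduler state alongside the iteration counter is enough to resume the schedule consistently.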