
Commit d5d0310

Authored by shahules786 and Jithin James
Added Q-square metric (#16)
* some fixes
* add property to rouge
* update quickstart
* add max_length
* add max_length
* add qa-qg paradigm
* fix model pipeline
* add filter candidates
* add candidate cleaning
* compute scores from qa-qj
* remove json saving
* change score for neutral
* resolve merge conflicts
* resolve merge conflicts
* initialize metric instances
* black reformatting
* format quickstart
* fix imports
* add spacy
* merge from main
* fix type checks
* rmv test notebook

Co-authored-by: Jithin James <[email protected]>
1 parent af9a433 commit d5d0310
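
In short, this commit adds a Q² (Qsquare) metric to belar/metrics/factual.py, following the QA-QG paradigm mentioned in the commit message: answer candidates (named entities plus sampled noun chunks) are extracted from each ground-truth text with spaCy, a T5 question-generation model writes a question for each candidate, a QA model answers those questions against both the ground truth and the generated text, and mismatching answer pairs are scored with the existing EntailmentScore NLI model (entailment = 1, neutral = 0.5, contradiction = 0). The per-text Q² score is the mean of these per-question scores.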

File tree: 5 files changed, +257 -12 lines changed

belar/metrics/__init__.py (+2 -1)

@@ -1,7 +1,8 @@
 from belar.metrics.base import Evaluation, Metric
 from belar.metrics.factual import EntailmentScore
 from belar.metrics.similarity import SBERTScore
-from belar.metrics.simple import BLUE, EditDistance, EditRatio, Rouge1, Rouge2, RougeL
+from belar.metrics.simple import (BLUE, EditDistance, EditRatio, Rouge1,
+                                  Rouge2, RougeL)

 __all__ = [
     "Evaluation",

belar/metrics/factual.py (+251 -1)

@@ -1,16 +1,37 @@
 from __future__ import annotations

+import json
+import re
+import string
 import typing as t
 from dataclasses import dataclass

-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+import numpy as np
+import spacy
+import transformers
+from transformers import (AutoConfig, AutoModelForSequenceClassification,
+                          AutoTokenizer, PreTrainedModel)

 from belar.metrics import Metric
 from belar.utils import device_check

 if t.TYPE_CHECKING:
     from torch import device as Device

+from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+    MODEL_WITH_LM_HEAD_MAPPING_NAMES)
+
+MODEL_MAPPINGS_NAMES = [
+    MODEL_WITH_LM_HEAD_MAPPING_NAMES,
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+]
+
+DEVICES = ["cpu", "cuda"]
+SPACY_MODEL = "en_core_web_sm"
+LABEL2SCORE = {"entailment": 1, "contradiction": 0, "neutral": 0.5}
+EPS = 1e-8
+

 @dataclass
 class EntailmentScore(Metric):
@@ -47,6 +68,20 @@ def name(self):
     def is_batchable(self):
         return True

+    def infer(self, ground_truth: str, generated_text: str):
+        encodings = self.tokenizer(
+            ground_truth,
+            generated_text,
+            truncation=True,
+            return_tensors="pt",
+            max_length=self.max_length,
+            padding="max_length",
+        )
+        label2id = {value.lower(): key for key, value in self.id2label.items()}
+        output = self.model(**encodings)
+        pred = output.logits.softmax(axis=-1).detach().cpu().squeeze()
+        return {label: pred[id].item() for label, id in label2id.items()}
+
     def batch_infer(self, inputs: dict):
         predictions = []
         input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
@@ -87,3 +122,218 @@ def score(
         score = self.batch_infer(encodings)

         return score
+
+
+class QAGQ:
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        model_name_or_path: str,
+        device: t.Literal["cpu", "cuda"] | Device = "cpu",
+    ):
+        self.model = model.from_pretrained(model_name_or_path)
+        self.model.eval()  # type: ignore
+        self.device = device_check(device)
+        self.model.to(self.device)  # type: ignore
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+
+    @classmethod
+    def from_pretrained(cls, model_name_or_path):
+        config = AutoConfig.from_pretrained(model_name_or_path)
+        model_mappings = [
+            arch for model_type in MODEL_MAPPINGS_NAMES for arch in model_type.values()
+        ]
+        architecture = np.intersect1d(model_mappings, config.architectures)
+        if len(architecture) == 0:
+            raise ValueError("Model doesn't support QA or LM architecture")
+        model = getattr(transformers, architecture[0])
+        return cls(model, model_name_or_path)
+
+    def batch_generate_question(self, answers: list[str], context: str, **kwargs):
+        input_texts = [
+            "answer: %s context: %s </s>" % (ans, context) for ans in answers
+        ]
+        max_length = kwargs.pop("input_max_length", 512)
+        encodings = self.tokenizer(
+            input_texts,
+            padding="max_length",
+            truncation=True,
+            max_length=max_length,
+            return_tensors="pt",
+        )
+        encodings = {k: v.to(self.device) for k, v in encodings.items()}
+        outputs = self.model.generate(**encodings, **kwargs)  # type: ignore
+        outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        return [output.replace("question:", "").strip() for output in outputs]
+
+    def batch_generate_answers(self, questions: list[str], context: str, **kwargs):
+        max_length = kwargs.pop("input_max_length", 512)
+        encodings = self.tokenizer(
+            questions,
+            [context] * len(questions),
+            padding="max_length",
+            truncation=True,
+            max_length=max_length,
+            return_tensors="pt",
+        )
+        encodings = {
+            k: v.view(-1, max_length).to(self.device) for k, v in encodings.items()
+        }
+        poss_ans_starts, poss_ans_ends = self.model(
+            **encodings, return_dict=False
+        )  # type: ignore
+        best_start = poss_ans_starts.argmax(1)
+        best_ends = poss_ans_ends.argmax(1)
+        answers = [
+            encodings["input_ids"][i][start : end + 1]
+            for i, (start, end) in enumerate(zip(best_start, best_ends))
+        ]
+        answers = self.tokenizer.batch_decode(answers)
+        return answers
+
+
+@dataclass
+class Qsquare(Metric):
+    qa_model_name: str = "consciousAI/question-answering-roberta-base-s"
+    qg_model_name: str = "mrm8488/t5-base-finetuned-question-generation-ap"
+    device: t.Literal["cpu", "cuda"] = "cpu"
+    max_answers: int = 10
+    crosscheck_candidates: bool = True
+    load_single = False
+    batch_size: int = 4
+    include_nouns: bool = True
+    save_results: bool = False
+
+    def __post_init__(
+        self,
+    ):
+        self.nlp = spacy.load(SPACY_MODEL)
+        self.qa = QAGQ.from_pretrained(self.qa_model_name)
+        self.qg = QAGQ.from_pretrained(self.qg_model_name)
+
+    @property
+    def name(self):
+        return "Q^2"
+
+    @property
+    def is_batchable(self):
+        return True
+
+    def generate_candidates(self, text: str):
+        text = text.strip()
+        nouns = [
+            i.text.lower()
+            for i in self.nlp(text).noun_chunks
+            if i.text.lower() not in self.nlp.Defaults.stop_words
+        ]
+        entities = set([ent.text.lower() for ent in self.nlp(text).ents])
+        num_nouns = max(0, self.max_answers - len(entities))
+        nouns = list(np.setdiff1d(nouns, list(entities)))
+        if nouns and self.include_nouns:
+            nouns = np.random.choice(nouns, size=num_nouns).tolist()
+        else:
+            nouns = []

+        return list(entities.union(set(nouns)))
+
+    def generate_questions(self, candidates: list[str], context: str, **kwargs):
+        questions = []
+        for idx in range(0, len(candidates), self.batch_size):
+            batch_questions = self.qg.batch_generate_question(
+                candidates[idx : idx + self.batch_size], context, **kwargs
+            )
+            questions.extend(
+                [qstn if qstn.endswith("?") else f"{qstn}?" for qstn in batch_questions]
+            )
+        assert len(questions) == len(candidates), "Missing question for some candidates"
+        return questions
+
+    def generate_answers(self, questions: list[str], context: str):
+        answers = []
+        for idx in range(0, len(questions), self.batch_size):
+            batch_answers = self.qa.batch_generate_answers(
+                questions[idx : idx + self.batch_size], context
+            )
+            answers.extend(batch_answers)
+        assert len(answers) == len(questions), "Missing answers for some questions"
+        return answers
+
+    def filter_candidates(
+        self, questions: list[str], candidates: list[str], gen_answers: list[str]
+    ):
+        final_questions = []
+        final_candidates = []
+        for qstn, ans1, ans2 in zip(questions, candidates, gen_answers):
+            if self.clean_candidate(ans1) == self.clean_candidate(ans2):
+                final_candidates.append(ans1)
+                final_questions.append(qstn)
+
+        return final_questions, final_candidates
+
+    def clean_candidate(self, text):
+        text = text.strip().lower()
+        text = text.translate(str.maketrans("", "", string.punctuation))
+        text = re.sub(r"\b(a|an|the|in|our)\b", " ", text)
+
+        return text
+
+    def score_candidates(self, ques_ans_dict: dict):
+        nli = EntailmentScore()
+        for qas in ques_ans_dict.values():
+            for item in qas:
+                item["answer"] = self.clean_candidate(item["answer"])
+                item["predicted_answer"] = self.clean_candidate(
+                    item["predicted_answer"]
+                )
+                if item["answer"] == item["predicted_answer"]:
+                    item.update({"score": 1})
+                else:
+                    qstn = item.get("question")
+                    score_dict = nli.infer(
+                        f'{qstn}{item.get("answer")}',
+                        f'{qstn}{item.get("predicted_answer")}',
+                    )
+                    label = max(zip(score_dict.values(), score_dict.keys()))[1]
+                    item.update({"score": LABEL2SCORE[label]})
+
+        return ques_ans_dict
+
+    def score(self, ground_truth: list[str], generated_text: list[str], **kwargs):
+        gnd_qans = {}
+        ans_candidates = [self.generate_candidates(text) for text in ground_truth]
+        for i, (candidates, context) in enumerate(zip(ans_candidates, ground_truth)):
+            questions = self.generate_questions(candidates, context, **kwargs)
+            gen_answers = self.generate_answers(questions, context)
+            if self.crosscheck_candidates:
+                questions, candidates = self.filter_candidates(
+                    questions, candidates, gen_answers
+                )
+            gnd_qans[i] = [
+                {"question": qstn, "answer": ans}
+                for qstn, ans in zip(questions, candidates)
+            ]
+
+        for i, gen_text in enumerate(generated_text):
+            questions = [item["question"] for item in gnd_qans[i]]
+            gen_answers = self.generate_answers(questions, gen_text)
+            _ = [
+                item.update({"predicted_answer": ans})
+                for item, ans in zip(gnd_qans[i], gen_answers)
+            ]
+
+        del self.qa
+        del self.qg
+
+        gnd_qans = self.score_candidates(gnd_qans)
+
+        if self.save_results:
+            with open("qa-qj-intermediate.json", "w") as file:
+                json.dump(gnd_qans, file, indent=4)
+
+        scores = [[dic["score"] for dic in item] for item in gnd_qans.values()]
+        scores = [sum(sublist) / (len(sublist) + EPS) for sublist in scores]
+        return scores
+
+
+ENTScore = EntailmentScore()
+Q2Score = Qsquare()
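
For orientation, here is a minimal usage sketch of the new metric; it is not part of the commit. It assumes the two Hugging Face models named above can be downloaded, that the spaCy pipeline en_core_web_sm is installed locally, and the example texts are invented for illustration.

# Hypothetical usage sketch, not part of this commit.
# Assumes the default QA/QG models are downloadable and that the spaCy
# pipeline "en_core_web_sm" is installed.
from belar.metrics.factual import EntailmentScore, Qsquare

ground_truth = ["Albert Einstein was born in Ulm, Germany, in 1879."]
generated = ["Einstein was born in 1879 in Ulm."]

# device and max_answers are dataclass fields of Qsquare.
q2 = Qsquare(device="cpu", max_answers=5)
print(q2.score(ground_truth, generated))  # one averaged score per ground-truth text

# EntailmentScore.infer, also added in this commit, scores a single text pair.
nli = EntailmentScore()
print(nli.infer(ground_truth[0], generated[0]))  # e.g. {"entailment": ..., "neutral": ..., "contradiction": ...}

Note that the module-level ENTScore and Q2Score instances at the bottom of factual.py mean all three models are loaded as soon as the module is imported, and that Qsquare.score deletes self.qa and self.qg, so an instance can only score one batch.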

belar/utils.py (+1)

@@ -1,4 +1,5 @@
 from __future__ import annotations
+
 import typing as t
 from warnings import warn

pyproject.toml (+1)

@@ -8,6 +8,7 @@ dependencies = [
     "sentence-transformers",
     "nltk",
     "datasets",
+    "spacy",
 ]
 dynamic = ["version", "readme"]
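
Adding spacy as a dependency does not install the en_core_web_sm pipeline that Qsquare.__post_init__ loads via SPACY_MODEL. One way a user might fetch it (an assumption about setup, not something this commit does) is:

# Hypothetical one-time setup, not part of this commit: download the spaCy
# pipeline that SPACY_MODEL points to before instantiating Qsquare.
import spacy

spacy.cli.download("en_core_web_sm")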

tests/benchmarks/benchmark.py (+2 -10)

@@ -5,16 +5,8 @@
 from tqdm import tqdm
 from utils import print_table, timeit

-from belar.metrics import (
-    EditDistance,
-    EditRatio,
-    EntailmentScore,
-    Evaluation,
-    Rouge1,
-    Rouge2,
-    RougeL,
-    SBERTScore,
-)
+from belar.metrics import (EditDistance, EditRatio, EntailmentScore,
+                           Evaluation, Rouge1, Rouge2, RougeL, SBERTScore)

 DEVICE = "cuda" if is_available() else "cpu"
 BATCHES = [0, 1]
