from __future__ import annotations

+import json
+import re
+import string
import typing as t
from dataclasses import dataclass

-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+import numpy as np
+import spacy
+import transformers
+from transformers import (AutoConfig, AutoModelForSequenceClassification,
+                          AutoTokenizer, PreTrainedModel)

from belar.metrics import Metric
from belar.utils import device_check

if t.TYPE_CHECKING:
    from torch import device as Device

+from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+    MODEL_WITH_LM_HEAD_MAPPING_NAMES)
+
+MODEL_MAPPINGS_NAMES = [
+    MODEL_WITH_LM_HEAD_MAPPING_NAMES,
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+]
+
+DEVICES = ["cpu", "cuda"]
+SPACY_MODEL = "en_core_web_sm"
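+# Partial-credit map from NLI verdicts to Q^2 answer scores; EPS guards the
+# per-sample average in Qsquare.score against empty QA lists.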
+LABEL2SCORE = {"entailment": 1, "contradiction": 0, "neutral": 0.5}
+EPS = 1e-8
+

@dataclass
class EntailmentScore(Metric):

@@ -47,6 +68,20 @@ def name(self):
    def is_batchable(self):
        return True

+    def infer(self, ground_truth: str, generated_text: str):
+        encodings = self.tokenizer(
+            ground_truth,
+            generated_text,
+            truncation=True,
+            return_tensors="pt",
+            max_length=self.max_length,
+            padding="max_length",
+        )
+        # Move inputs to the model's device before the forward pass.
+        encodings = {k: v.to(self.model.device) for k, v in encodings.items()}
+        label2id = {value.lower(): key for key, value in self.id2label.items()}
+        output = self.model(**encodings)
+        pred = output.logits.softmax(dim=-1).detach().cpu().squeeze()
+        return {label: pred[idx].item() for label, idx in label2id.items()}
+
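+    # Illustrative usage (a sketch; assumes the dataclass defaults load an NLI
+    # model whose id2label covers entailment/neutral/contradiction):
+    #   scorer = EntailmentScore()
+    #   scorer.infer("The cat sat on the mat.", "A cat was sitting on a mat.")
+    #   -> {"entailment": ..., "neutral": ..., "contradiction": ...}
+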
    def batch_infer(self, inputs: dict):
        predictions = []
        input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]

@@ -87,3 +122,218 @@ def score(
        score = self.batch_infer(encodings)

        return score
+
+
+class QAGQ:
+    def __init__(
+        self,
+        model: type[PreTrainedModel],
+        model_name_or_path: str,
+        device: t.Literal["cpu", "cuda"] | Device = "cpu",
+    ):
+        self.model = model.from_pretrained(model_name_or_path)
+        self.model.eval()  # type: ignore
+        self.device = device_check(device)
+        self.model.to(self.device)  # type: ignore
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+
+    @classmethod
+    def from_pretrained(cls, model_name_or_path):
+        config = AutoConfig.from_pretrained(model_name_or_path)
+        model_mappings = [
+            arch for model_type in MODEL_MAPPINGS_NAMES for arch in model_type.values()
+        ]
+        architecture = np.intersect1d(model_mappings, config.architectures)
+        if len(architecture) == 0:
+            raise ValueError("Model doesn't support QA or LM architecture")
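+        # config.architectures lists concrete class names (e.g.
+        # "T5ForConditionalGeneration"); resolve the first supported one
+        # from the transformers namespace.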
+        model = getattr(transformers, architecture[0])
+        return cls(model, model_name_or_path)
+
+    def batch_generate_question(self, answers: list[str], context: str, **kwargs):
+        input_texts = [
+            "answer: %s context: %s </s>" % (ans, context) for ans in answers
+        ]
+        max_length = kwargs.pop("input_max_length", 512)
+        encodings = self.tokenizer(
+            input_texts,
+            padding="max_length",
+            truncation=True,
+            max_length=max_length,
+            return_tensors="pt",
+        )
+        encodings = {k: v.to(self.device) for k, v in encodings.items()}
+        outputs = self.model.generate(**encodings, **kwargs)  # type: ignore
+        outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        return [output.replace("question:", "").strip() for output in outputs]
+
+    def batch_generate_answers(self, questions: list[str], context: str, **kwargs):
+        max_length = kwargs.pop("input_max_length", 512)
+        encodings = self.tokenizer(
+            questions,
+            [context] * len(questions),
+            padding="max_length",
+            truncation=True,
+            max_length=max_length,
+            return_tensors="pt",
+        )
+        encodings = {
+            k: v.view(-1, max_length).to(self.device) for k, v in encodings.items()
+        }
+        # Extractive QA head returns start/end logits over the input tokens.
+        poss_ans_starts, poss_ans_ends = self.model(
+            **encodings, return_dict=False
+        )  # type: ignore
+        best_starts = poss_ans_starts.argmax(1)
+        best_ends = poss_ans_ends.argmax(1)
+        answers = [
+            encodings["input_ids"][i][start : end + 1]
+            for i, (start, end) in enumerate(zip(best_starts, best_ends))
+        ]
+        answers = self.tokenizer.batch_decode(answers, skip_special_tokens=True)
+        return answers
+
+
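+# Illustrative QAGQ usage (a sketch; the checkpoints named are the Qsquare
+# defaults below):
+#   qg = QAGQ.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
+#   questions = qg.batch_generate_question(["paris"], "Paris is the capital of France.")
+#   qa = QAGQ.from_pretrained("consciousAI/question-answering-roberta-base-s")
+#   answers = qa.batch_generate_answers(questions, "Paris is the capital of France.")
+
+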
+@dataclass
+class Qsquare(Metric):
+    qa_model_name: str = "consciousAI/question-answering-roberta-base-s"
+    qg_model_name: str = "mrm8488/t5-base-finetuned-question-generation-ap"
+    device: t.Literal["cpu", "cuda"] = "cpu"
+    max_answers: int = 10
+    crosscheck_candidates: bool = True
+    # Annotated so it is an actual dataclass field, not a shared class attribute.
+    load_single: bool = False
+    batch_size: int = 4
+    include_nouns: bool = True
+    save_results: bool = False
+
+    def __post_init__(self):
+        self.nlp = spacy.load(SPACY_MODEL)
+        self.qa = QAGQ.from_pretrained(self.qa_model_name)
+        self.qg = QAGQ.from_pretrained(self.qg_model_name)
+
+    @property
+    def name(self):
+        return "Q^2"
+
+    @property
+    def is_batchable(self):
+        return True
+
+    def generate_candidates(self, text: str):
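+        # Candidate answers = named entities plus, optionally, noun chunks
+        # sampled to top the pool up to `max_answers`.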
+        text = text.strip()
+        nouns = [
+            i.text.lower()
+            for i in self.nlp(text).noun_chunks
+            if i.text.lower() not in self.nlp.Defaults.stop_words
+        ]
+        entities = {ent.text.lower() for ent in self.nlp(text).ents}
+        num_nouns = max(0, self.max_answers - len(entities))
+        nouns = list(np.setdiff1d(nouns, list(entities)))
+        if nouns and self.include_nouns:
+            # Sample without replacement so candidates are never duplicated.
+            nouns = np.random.choice(
+                nouns, size=min(num_nouns, len(nouns)), replace=False
+            ).tolist()
+        else:
+            nouns = []
+
+        return list(entities.union(set(nouns)))
+
+    def generate_questions(self, candidates: list[str], context: str, **kwargs):
+        questions = []
+        for idx in range(0, len(candidates), self.batch_size):
+            batch_questions = self.qg.batch_generate_question(
+                candidates[idx : idx + self.batch_size], context, **kwargs
+            )
+            questions.extend(
+                [qstn if qstn.endswith("?") else f"{qstn}?" for qstn in batch_questions]
+            )
+        assert len(questions) == len(candidates), "Missing question for some candidates"
+        return questions
+
+    def generate_answers(self, questions: list[str], context: str):
+        answers = []
+        for idx in range(0, len(questions), self.batch_size):
+            batch_answers = self.qa.batch_generate_answers(
+                questions[idx : idx + self.batch_size], context
+            )
+            answers.extend(batch_answers)
+        assert len(answers) == len(questions), "Missing answers for some questions"
+        return answers
+
+    def filter_candidates(
+        self, questions: list[str], candidates: list[str], gen_answers: list[str]
+    ):
+        # Keep only candidates the QA model can recover from the ground truth
+        # itself (round-trip consistency check).
+        final_questions = []
+        final_candidates = []
+        for qstn, ans1, ans2 in zip(questions, candidates, gen_answers):
+            if self.clean_candidate(ans1) == self.clean_candidate(ans2):
+                final_candidates.append(ans1)
+                final_questions.append(qstn)
+
+        return final_questions, final_candidates
+
+    def clean_candidate(self, text):
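+        # Normalization example: "The Eiffel Tower!" -> "eiffel tower"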
+        text = text.strip().lower()
+        text = text.translate(str.maketrans("", "", string.punctuation))
+        text = re.sub(r"\b(a|an|the|in|our)\b", " ", text)
+        # Collapse the whitespace left behind so string comparisons are stable.
+        text = re.sub(r"\s+", " ", text).strip()
+
+        return text
+
+    def score_candidates(self, ques_ans_dict: dict):
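+        # Q^2 scoring: exact match after cleaning earns 1; otherwise an NLI
+        # model compares "question + answer" statements and LABEL2SCORE maps
+        # the verdict (entailment=1, neutral=0.5, contradiction=0).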
+        nli = EntailmentScore()
+        for qas in ques_ans_dict.values():
+            for item in qas:
+                item["answer"] = self.clean_candidate(item["answer"])
+                item["predicted_answer"] = self.clean_candidate(
+                    item["predicted_answer"]
+                )
+                if item["answer"] == item["predicted_answer"]:
+                    item.update({"score": 1})
+                else:
+                    qstn = item.get("question")
+                    score_dict = nli.infer(
+                        f'{qstn} {item.get("answer")}',
+                        f'{qstn} {item.get("predicted_answer")}',
+                    )
+                    label = max(score_dict, key=score_dict.get)
+                    item.update({"score": LABEL2SCORE[label]})
+
+        return ques_ans_dict
+
+    def score(self, ground_truth: list[str], generated_text: list[str], **kwargs):
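+        # Pipeline: candidates from each ground truth -> questions -> answers
+        # generated from both texts -> NLI-scored agreement, averaged per sample.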
+        gnd_qans = {}
+        ans_candidates = [self.generate_candidates(text) for text in ground_truth]
+        for i, (candidates, context) in enumerate(zip(ans_candidates, ground_truth)):
+            questions = self.generate_questions(candidates, context, **kwargs)
+            gen_answers = self.generate_answers(questions, context)
+            if self.crosscheck_candidates:
+                questions, candidates = self.filter_candidates(
+                    questions, candidates, gen_answers
+                )
+            gnd_qans[i] = [
+                {"question": qstn, "answer": ans}
+                for qstn, ans in zip(questions, candidates)
+            ]
+
+        for i, gen_text in enumerate(generated_text):
+            questions = [item["question"] for item in gnd_qans[i]]
+            gen_answers = self.generate_answers(questions, gen_text)
+            for item, ans in zip(gnd_qans[i], gen_answers):
+                item.update({"predicted_answer": ans})
+
+        # Free the QA/QG models before the NLI model loads; note this makes
+        # the instance single-use.
+        del self.qa
+        del self.qg
+
+        gnd_qans = self.score_candidates(gnd_qans)
+
+        if self.save_results:
+            with open("qa-qj-intermediate.json", "w") as file:
+                json.dump(gnd_qans, file, indent=4)
+
+        scores = [[dic["score"] for dic in item] for item in gnd_qans.values()]
+        scores = [sum(sublist) / (len(sublist) + EPS) for sublist in scores]
+        return scores
+
+
+ENTScore = EntailmentScore()
+Q2Score = Qsquare()
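+
+# Illustrative end-to-end call (a sketch; the strings are stand-ins):
+#   scores = Q2Score.score(
+#       ground_truth=["Paris is the capital of France."],
+#       generated_text=["France's capital is Paris."],
+#   )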