From 41e9e54b204566425908c1a437d7c1c798578ad3 Mon Sep 17 00:00:00 2001 From: Shahul ES Date: Tue, 12 Dec 2023 20:29:37 +0530 Subject: [PATCH] fix: handle edge cases in prompt processing (#374) --- src/ragas/metrics/_answer_correctness.py | 24 +++++++++++++++--------- src/ragas/metrics/_faithfulness.py | 2 +- src/ragas/utils.py | 2 +- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/ragas/metrics/_answer_correctness.py b/src/ragas/metrics/_answer_correctness.py index 76a6ebeb0..6a959d46d 100644 --- a/src/ragas/metrics/_answer_correctness.py +++ b/src/ragas/metrics/_answer_correctness.py @@ -119,15 +119,21 @@ def _score_batch( f1_score = [] for prediction in outputs: prediction = json_loader.safe_load(prediction[0].text, self.llm) - prediction = [ - item.get(key_map[k], np.nan) - for item in prediction - for k in key_map.keys() - ] - tp, fp, fn = [ - len(item) if isinstance(item, list) else np.nan for item in prediction - ] - score = tp / (tp + 0.5 * (fp + fn)) + prediction = prediction if isinstance(prediction, list) else [] + if prediction: + prediction = [ + item.get(key_map[k], np.nan) + for item in prediction + for k in key_map.keys() + ] + tp, fp, fn = [ + len(item) if isinstance(item, list) else np.nan + for item in prediction + ] + score = tp / (tp + 0.5 * (fp + fn)) + else: + score = np.nan + f1_score.append(score) similarity_scores = self.answer_similarity._score_batch(dataset) # type: ignore diff --git a/src/ragas/metrics/_faithfulness.py b/src/ragas/metrics/_faithfulness.py index 0b709a834..3cd5c1e7c 100644 --- a/src/ragas/metrics/_faithfulness.py +++ b/src/ragas/metrics/_faithfulness.py @@ -173,7 +173,7 @@ def _score_batch( scores = [] for output in outputs: output = json_loader.safe_load(output[0].text, self.llm) - output = output if output else [] + output = output if isinstance(output, list) else [] faithful_statements = sum( verdict_score_map.get(dict.get("verdict", "").lower(), np.nan) for dict in output diff --git a/src/ragas/utils.py b/src/ragas/utils.py index edbae6328..9e0849bac 100644 --- a/src/ragas/utils.py +++ b/src/ragas/utils.py @@ -111,7 +111,7 @@ def _fix_to_json( callbacks: t.Optional[CallbackManager] = None, callback_group_name: str = "batch", ): - # TODO (executor) + # TODO (executor) with trace_as_chain_group( callback_group_name, callback_manager=callbacks ) as batch_group: