From ff449fccfb7d594c582071694b00f4c2192c5465 Mon Sep 17 00:00:00 2001
From: Jithin James
Date: Wed, 24 Jan 2024 19:39:33 -0800
Subject: [PATCH] fix: answer_correctness embedding (#513)

---
 src/ragas/metrics/_answer_correctness.py | 2 ++
 tests/benchmarks/benchmark_eval.py       | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/ragas/metrics/_answer_correctness.py b/src/ragas/metrics/_answer_correctness.py
index 88cdf4895..da7f95736 100644
--- a/src/ragas/metrics/_answer_correctness.py
+++ b/src/ragas/metrics/_answer_correctness.py
@@ -104,6 +104,8 @@ def __post_init__(self: t.Self):
         if not all([w >= 0 for w in self.weights]):
             raise ValueError("Weights must be non-negative")
 
+    def init_model(self):
+        super().init_model()
         if self.answer_similarity is None and self.weights[1] != 0:
             self.answer_similarity = AnswerSimilarity(
                 llm=self.llm, batch_size=self.batch_size, embeddings=self.embeddings
diff --git a/tests/benchmarks/benchmark_eval.py b/tests/benchmarks/benchmark_eval.py
index 2b78d482d..ae91c9937 100644
--- a/tests/benchmarks/benchmark_eval.py
+++ b/tests/benchmarks/benchmark_eval.py
@@ -19,7 +19,7 @@
 # data
 ds = load_dataset("explodinggradients/amnesty_qa", "english")
 assert isinstance(ds, DatasetDict)
-eval_dataset = ds["train"]
+eval_dataset = ds["eval"]
 
 # metrics
 metrics = [
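
Note on the fix (a minimal sketch, not part of the patch, assuming the ragas
0.1-era API in which evaluate() attaches the LLM and embeddings to each metric
and then calls init_model()): constructing AnswerSimilarity inside init_model()
rather than __post_init__ means it picks up the embeddings supplied at
evaluation time, instead of the unset value present when the metric dataclass
was first built.

    from datasets import load_dataset
    from ragas import evaluate
    from ragas.metrics import answer_correctness

    # Load the benchmark split referenced above; "eval" is the split this
    # patch corrects the benchmark to use.
    ds = load_dataset("explodinggradients/amnesty_qa", "english")["eval"]

    # evaluate() wires its llm/embeddings into answer_correctness and then
    # calls init_model(), so AnswerSimilarity is created with working
    # embeddings rather than the None it previously saw in __post_init__.
    result = evaluate(ds, metrics=[answer_correctness])
    print(result)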