diff --git a/ragas/metrics/__init__.py b/ragas/metrics/__init__.py
index e2518e77f..47705a73b 100644
--- a/ragas/metrics/__init__.py
+++ b/ragas/metrics/__init__.py
@@ -1,17 +1,25 @@
 from ragas.metrics.base import Evaluation, Metric
-from ragas.metrics.factual import EntailmentScore
-from ragas.metrics.similarity import SBERTScore
-from ragas.metrics.simple import BLUE, EditDistance, EditRatio, Rouge1, Rouge2, RougeL
+from ragas.metrics.factual import entailment_score, q_square
+from ragas.metrics.similarity import bert_score
+from ragas.metrics.simple import (
+    bleu_score,
+    edit_distance,
+    edit_ratio,
+    rouge1,
+    rouge2,
+    rougeL,
+)
 
 __all__ = [
     "Evaluation",
     "Metric",
-    "EntailmentScore",
-    "SBERTScore",
-    "BLUE",
-    "EditDistance",
-    "EditRatio",
-    "RougeL",
-    "Rouge1",
-    "Rouge2",
+    "entailment_score",
+    "bert_score",
+    "q_square",
+    "bleu_score",
+    "edit_distance",
+    "edit_ratio",
+    "rouge1",
+    "rouge2",
+    "rougeL",
 ]
diff --git a/ragas/metrics/factual.py b/ragas/metrics/factual.py
index 3928f5b25..eb944b523 100644
--- a/ragas/metrics/factual.py
+++ b/ragas/metrics/factual.py
@@ -218,7 +218,7 @@ def __post_init__(
 
     @property
     def name(self):
-        return "Q^2"
+        return "Qsquare"
 
     @property
     def is_batchable(self):
@@ -340,5 +340,5 @@ def score(self, ground_truth: list[str], generated_text: list[str], **kwargs):
         return scores
 
 
-ENTScore = EntailmentScore()
-Q2Score = Qsquare()
+entailment_score = EntailmentScore()
+q_square = Qsquare()
diff --git a/ragas/metrics/similarity.py b/ragas/metrics/similarity.py
index f4435ce04..2a62b2a43 100644
--- a/ragas/metrics/similarity.py
+++ b/ragas/metrics/similarity.py
@@ -12,12 +12,12 @@
 if t.TYPE_CHECKING:
     from torch import Tensor
 
-SBERT_METRIC = t.Literal["cosine", "euclidean"]
+BERT_METRIC = t.Literal["cosine", "euclidean"]
 
 
 @dataclass
-class SBERTScore(Metric):
-    similarity_metric: t.Literal[SBERT_METRIC] = "cosine"
+class BERTScore(Metric):
+    similarity_metric: t.Literal[BERT_METRIC] = "cosine"
     model_path: str = "all-MiniLM-L6-v2"
     batch_size: int = 1000
 
@@ -28,7 +28,7 @@ def __post_init__(self):
     def name(
         self,
     ):
-        return f"SBERT_{self.similarity_metric}"
+        return f"BERTScore_{self.similarity_metric}"
 
     @property
     def is_batchable(self):
@@ -64,4 +64,4 @@ def score(
         return score
 
 
-__all__ = ["SBERTScore"]
+bert_score = BERTScore()
diff --git a/ragas/metrics/simple.py b/ragas/metrics/simple.py
index 81fdad42f..2ac8ee8b9 100644
--- a/ragas/metrics/simple.py
+++ b/ragas/metrics/simple.py
@@ -91,9 +91,9 @@ def score(self, ground_truth: t.List[str], generated_text: t.List[str]):
         return score
 
 
-Rouge1 = ROUGE("rouge1")
-Rouge2 = ROUGE("rouge2")
-RougeL = ROUGE("rougeL")
-BLUE = BLEUScore()
-EditDistance = EditScore("distance")
-EditRatio = EditScore("ratio")
+rouge1 = ROUGE("rouge1")
+rouge2 = ROUGE("rouge2")
+rougeL = ROUGE("rougeL")
+bleu_score = BLEUScore()
+edit_distance = EditScore("distance")
+edit_ratio = EditScore("ratio")
diff --git a/tests/benchmarks/benchmark.py b/tests/benchmarks/benchmark.py
index dc8812f5c..14b9f0d7e 100644
--- a/tests/benchmarks/benchmark.py
+++ b/tests/benchmarks/benchmark.py
@@ -1,36 +1,27 @@
 import typing as t
 
-from datasets import Dataset, load_dataset
+from datasets import Dataset, arrow_dataset, load_dataset
 from torch.cuda import is_available
 from tqdm import tqdm
 from utils import print_table, timeit
 
-from ragas.metrics import (
-    EditDistance,
-    EditRatio,
-    EntailmentScore,
-    Evaluation,
-    Rouge1,
-    Rouge2,
-    RougeL,
-    SBERTScore,
-)
+from ragas.metrics import Evaluation, edit_distance, edit_ratio, rouge1, rouge2, rougeL
 
 DEVICE = "cuda" if is_available() else "cpu"
 BATCHES = [0, 1]
-# init metrics
-sbert_score = SBERTScore(similarity_metric="cosine")
-entail = EntailmentScore(max_length=512, device=DEVICE)
+
 METRICS = {
-    "Rouge1": Rouge1,
-    "Rouge2": Rouge2,
-    "RougeL": RougeL,
-    "EditRatio": EditRatio,
-    "EditDistance": EditDistance,
+    "Rouge1": rouge1,
+    "Rouge2": rouge2,
+    "RougeL": rougeL,
+    "EditRatio": edit_ratio,
+    "EditDistance": edit_distance,
     # "SBERTScore": sbert_score,
     # "EntailmentScore": entail,
 }
 
 DS = load_dataset("explodinggradients/eli5-test", split="test_eli5")
+assert isinstance(DS, arrow_dataset.Dataset), "Not an arrow_dataset"
+DS = DS.select(range(100))
 
 def setup() -> t.Iterator[tuple[str, Evaluation, Dataset]]:
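Usage sketch (reviewer note, not part of the diff): after this change, ragas.metrics exports ready-made metric instances (rouge1, bleu_score, edit_ratio, ...) rather than classes. The call shape below is taken from the score(ground_truth, generated_text) signature visible in the simple.py hunk; the sample strings and the act of printing the raw return value are assumptions for illustration only.

    # Hypothetical usage of the renamed instance-style API; sample data is made up.
    from ragas.metrics import edit_ratio, rouge1

    ground_truth = ["the cat sat on the mat"]
    generated_text = ["a cat sat on a mat"]

    # Metric.score takes parallel lists of references and predictions
    # (see the `def score(...)` context lines in the simple.py hunk above).
    print(rouge1.score(ground_truth, generated_text))
    print(edit_ratio.score(ground_truth, generated_text))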