Initialise all metrics (#21)
* init all metrics

* change metric imports
shahules786 authored May 14, 2023
1 parent bc9d645 commit 8ddf9fe
Showing 5 changed files with 43 additions and 44 deletions.
30 changes: 19 additions & 11 deletions ragas/metrics/__init__.py
@@ -1,17 +1,25 @@
 from ragas.metrics.base import Evaluation, Metric
-from ragas.metrics.factual import EntailmentScore
-from ragas.metrics.similarity import SBERTScore
-from ragas.metrics.simple import BLUE, EditDistance, EditRatio, Rouge1, Rouge2, RougeL
+from ragas.metrics.factual import entailment_score, q_square
+from ragas.metrics.similarity import bert_score
+from ragas.metrics.simple import (
+    bleu_score,
+    edit_distance,
+    edit_ratio,
+    rouge1,
+    rouge2,
+    rougeL,
+)
 
 __all__ = [
     "Evaluation",
     "Metric",
-    "EntailmentScore",
-    "SBERTScore",
-    "BLUE",
-    "EditDistance",
-    "EditRatio",
-    "RougeL",
-    "Rouge1",
-    "Rouge2",
+    "entailment_score",
+    "bert_score",
+    "q_square",
+    "bleu_score",
+    "edit_distance",
+    "edit_ratio",
+    "rouge1",
+    "rouge2",
+    "rougeL",
 ]
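
A migration sketch for downstream code (hypothetical caller, not part of this commit): the package now exports ready-made snake_case metric instances instead of the old class-style names, and the rename from `BLUE` to `bleu_score` also fixes the earlier misspelling of BLEU.

    # Old exported names (before this commit):
    #     from ragas.metrics import Rouge1, BLUE, EditDistance
    # New exported names are pre-built metric instances:
    from ragas.metrics import bleu_score, edit_distance, rouge1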
6 changes: 3 additions & 3 deletions ragas/metrics/factual.py
@@ -218,7 +218,7 @@ def __post_init__(
 
     @property
     def name(self):
-        return "Q^2"
+        return "Qsquare"
 
     @property
     def is_batchable(self):
@@ -340,5 +340,5 @@ def score(self, ground_truth: list[str], generated_text: list[str], **kwargs):
         return scores
 
 
-ENTScore = EntailmentScore()
-Q2Score = Qsquare()
+entailment_score = EntailmentScore()
+q_square = Qsquare()
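
A smoke-test sketch (hypothetical, not from the repo): the module-level instances created above can be called directly, with the `score(ground_truth, generated_text, **kwargs)` signature visible in the second hunk. Scoring will load the underlying models on first use, so this assumes network access and the default weights.

    from ragas.metrics import q_square

    # Parallel lists of references and generated texts, per the
    # score() signature shown in the hunk above.
    refs = ["Paris is the capital of France."]
    gens = ["The capital of France is Paris."]
    print(q_square.score(refs, gens))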
10 changes: 5 additions & 5 deletions ragas/metrics/similarity.py
@@ -12,12 +12,12 @@
 if t.TYPE_CHECKING:
     from torch import Tensor
 
-SBERT_METRIC = t.Literal["cosine", "euclidean"]
+BERT_METRIC = t.Literal["cosine", "euclidean"]
 
 
 @dataclass
-class SBERTScore(Metric):
-    similarity_metric: t.Literal[SBERT_METRIC] = "cosine"
+class BERTScore(Metric):
+    similarity_metric: t.Literal[BERT_METRIC] = "cosine"
     model_path: str = "all-MiniLM-L6-v2"
     batch_size: int = 1000
@@ -28,7 +28,7 @@ def __post_init__(self):
     def name(
         self,
     ):
-        return f"SBERT_{self.similarity_metric}"
+        return f"BERTScore_{self.similarity_metric}"
 
     @property
     def is_batchable(self):
@@ -64,4 +64,4 @@ def score(
         return score
 
 
-__all__ = ["SBERTScore"]
+bert_score = BERTScore()
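
A configuration sketch (hypothetical, based only on the dataclass fields visible above): besides the default `bert_score` instance, a caller can construct a variant; note that building one likely loads the sentence-transformer named in `model_path`.

    from ragas.metrics.similarity import BERTScore

    # "euclidean" is the other allowed BERT_METRIC literal; batch_size
    # is a dataclass field shown in the hunk above.
    euclidean_bert = BERTScore(similarity_metric="euclidean", batch_size=500)
    print(euclidean_bert.name)  # -> "BERTScore_euclidean"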
12 changes: 6 additions & 6 deletions ragas/metrics/simple.py
@@ -91,9 +91,9 @@ def score(self, ground_truth: t.List[str], generated_text: t.List[str]):
         return score
 
 
-Rouge1 = ROUGE("rouge1")
-Rouge2 = ROUGE("rouge2")
-RougeL = ROUGE("rougeL")
-BLUE = BLEUScore()
-EditDistance = EditScore("distance")
-EditRatio = EditScore("ratio")
+rouge1 = ROUGE("rouge1")
+rouge2 = ROUGE("rouge2")
+rougeL = ROUGE("rougeL")
+bleu_score = BLEUScore()
+edit_distance = EditScore("distance")
+edit_ratio = EditScore("ratio")
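
A quick usage sketch (hypothetical), assuming ragas at this commit is installed with its ROUGE and edit-distance dependencies; the `score` signature is the one in the hunk header above.

    from ragas.metrics import edit_ratio, rouge1

    refs = ["the cat sat on the mat"]
    gens = ["a cat sat on a mat"]
    # Both metrics take parallel lists of references and generated texts.
    print(rouge1.score(refs, gens))
    print(edit_ratio.score(refs, gens))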
29 changes: 10 additions & 19 deletions tests/benchmarks/benchmark.py
@@ -1,36 +1,27 @@
 import typing as t
 
-from datasets import Dataset, load_dataset
+from datasets import Dataset, arrow_dataset, load_dataset
 from torch.cuda import is_available
 from tqdm import tqdm
 from utils import print_table, timeit
 
-from ragas.metrics import (
-    EditDistance,
-    EditRatio,
-    EntailmentScore,
-    Evaluation,
-    Rouge1,
-    Rouge2,
-    RougeL,
-    SBERTScore,
-)
+from ragas.metrics import Evaluation, edit_distance, edit_ratio, rouge1, rouge2, rougeL
 
 DEVICE = "cuda" if is_available() else "cpu"
 BATCHES = [0, 1]
-# init metrics
-sbert_score = SBERTScore(similarity_metric="cosine")
-entail = EntailmentScore(max_length=512, device=DEVICE)
 
 METRICS = {
-    "Rouge1": Rouge1,
-    "Rouge2": Rouge2,
-    "RougeL": RougeL,
-    "EditRatio": EditRatio,
-    "EditDistance": EditDistance,
+    "Rouge1": rouge1,
+    "Rouge2": rouge2,
+    "RougeL": rougeL,
+    "EditRatio": edit_ratio,
+    "EditDistance": edit_distance,
     # "SBERTScore": sbert_score,
     # "EntailmentScore": entail,
 }
 DS = load_dataset("explodinggradients/eli5-test", split="test_eli5")
+assert isinstance(DS, arrow_dataset.Dataset), "Not an arrow_dataset"
 DS = DS.select(range(100))
 
 
 def setup() -> t.Iterator[tuple[str, Evaluation, Dataset]]:
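
The added assert is a standard type-narrowing pattern: `load_dataset` is annotated to return a union of dataset types (`Dataset`, `DatasetDict`, `IterableDataset`, `IterableDatasetDict`), so asserting the concrete Arrow-backed `Dataset` lets the following `.select(...)` call type-check. A standalone sketch of the same pattern (hypothetical script, not from the repo):

    from datasets import arrow_dataset, load_dataset

    ds = load_dataset("explodinggradients/eli5-test", split="test_eli5")
    # Narrow the union returned by load_dataset so .select() type-checks.
    assert isinstance(ds, arrow_dataset.Dataset), "Not an arrow_dataset"
    print(ds.select(range(5)))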
