From ddc5d76b0d2f479434de7c836d8957c1cb6b0ccc Mon Sep 17 00:00:00 2001 From: Jithin James Date: Mon, 15 May 2023 04:43:59 +0530 Subject: [PATCH] fix: some fixes and readme (#26) * added info about data * fix some errors * fix emojis * fix blue * char * quickstart * download spacy if not found * finish quickstart * fix linting issues * update badges --------- Co-authored-by: Jithin James --- README.md | 27 +- examples/quickstart.ipynb | 830 ++++++++-------------------------- pyproject.toml | 1 + ragas/metrics/factual.py | 25 +- tests/benchmarks/benchmark.py | 15 +- 5 files changed, 223 insertions(+), 675 deletions(-) diff --git a/README.md b/README.md index aefb63a65..f4d8abdc7 100644 --- a/README.md +++ b/README.md @@ -7,28 +7,29 @@

   [badge row: GitHub release · Build · License · Open In Colab · Downloads; the badge link markup was stripped during extraction]

-  Installation |
-  Quick Example |
-  Hugging Face
+  Installation |
+  Quick Example |
+  Metrics List |
+  Hugging Face

@@ -36,7 +37,7 @@ ragas is a framework that helps you evaluate your Retrieval Augmented Generation
 ragas provides you with tools based on the latest research for evaluating LLM-generated text, to give you insights about your RAG pipeline.
 ragas can be integrated with your CI/CD to provide continuous checks to ensure performance.
 
-## Installation 🛡
+## 🛡 Installation
 
 ```bash
 pip install ragas
@@ -47,7 +48,7 @@ git clone https://github.com/explodinggradients/ragas && cd ragas
 pip install -e .
 ```
 
-## Quickstart 🔥
+## 🔥 Quickstart
 
 This is a small example program you can run to see ragas in action!
 
 ```python
@@ -74,11 +75,13 @@ e = Evaluation(
 results = e.eval(ds["ground_truth"], ds["generated_text"])
 print(results)
 ```
-If you want a more in-depth explanation of core components, check out our quick-start notebook
+If you want a more in-depth explanation of the core components, check out our [quick-start notebook](./examples/quickstart.ipynb)
 
 ## 🧰 Metrics
 
 ### ✏️ Character based
+Character based metrics focus on analyzing text at the character level.
+
 - **Levenshtein distance** is the number of single-character edits (insertions, deletions, substitutions) required to change your generated text into the ground truth text.
 - **Levenshtein ratio** is obtained by dividing the Levenshtein distance by the total number of characters in the generated text and the ground truth.
 
 These metrics are suitable when working with short, precise texts.
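To make the two character-level definitions above concrete, here is a small self-contained sketch in plain Python. The function names `levenshtein` and `levenshtein_ratio` are illustrative, not part of the ragas API; ragas ships these as `edit_distance` and `edit_ratio`, and the sketch only shows what is being computed.

```python
def levenshtein(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions and
    substitutions needed to turn `a` into `b` (classic DP, two rows)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # delete ca
                curr[j - 1] + 1,           # insert cb
                prev[j - 1] + (ca != cb),  # substitute ca -> cb
            ))
        prev = curr
    return prev[-1]


def levenshtein_ratio(generated: str, ground_truth: str) -> float:
    """Distance divided by the combined character count, following the
    definition in the bullets above."""
    total = len(generated) + len(ground_truth)
    return levenshtein(generated, ground_truth) / total if total else 0.0


print(levenshtein("kitten", "sitting"))        # 3
print(levenshtein_ratio("kitten", "sitting"))  # 3 / 13 ≈ 0.2308
```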
@@ -92,7 +95,7 @@ N-gram based metrics as name indicates uses n-grams for comparing generated answ
 
 - **BLEU** (BiLingual Evaluation Understudy)
 
-It measures precision by comparing  clipped n-grams in generated text to ground truth text. These matches do not consider the ordering of words.
+ It measures precision by comparing clipped n-grams in the generated text to the ground truth text. These matches do not consider the ordering of words.
 
 ### 🪄 Model Based
 
diff --git a/examples/quickstart.ipynb b/examples/quickstart.ipynb
index 3799e9190..2f7ccdc38 100644
--- a/examples/quickstart.ipynb
+++ b/examples/quickstart.ipynb
@@ -22,24 +22,27 @@
   },
   {
    "cell_type": "markdown",
-   "id": "40258397",
+   "id": "5af47053",
    "metadata": {},
    "source": [
-    "## Load the Data\n",
+    "### load your data\n",
     "\n",
-    "For this quickstart we are going to be using a dataset that we prepared from [eli5](https://huggingface.co/datasets/eli5) dataset with the models response. \n",
+    "For this quickstart we are going to be using a dataset that we prepared from the [eli5](https://huggingface.co/datasets/eli5) dataset, with the model's responses. The dataset is available on [huggingface](https://huggingface.co/datasets/explodinggradients/eli5-test).\n",
     "\n",
-    "prompt: str\n",
-    "context: str\n",
-    "references: list[str]\n",
-    "ground_truth: list[str]\n",
-    "generated_text: str"
+    "The dataset has the following format\n",
+    "| column name    | type      | description                                                                      |\n",
+    "|----------------|-----------|----------------------------------------------------------------------------------|\n",
+    "| prompt         | str       | the prompt/question to answer                                                    |\n",
+    "| context        | str       | context string that has any relevant priors the LLM needs to answer the question |\n",
+    "| references     | list[str] | reference documents the LLM can use to respond to the prompt                     |\n",
+    "| ground_truth   | list[str] | accepted answers given by human annotators                                       |\n",
+    "| generated_text | str       | the generated output from the LLM                                                |"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 2,
-   "id": "0b5d4d41",
+   "id": "2bc9fb9d",
    "metadata": {},
    "outputs": [
     {
@@ -70,50 +73,50 @@
     "ds"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "1e9c0687",
+   "metadata": {},
+   "source": [
+    "### choose the metrics\n",
+    "\n",
+    "ragas provides you with a wide range of metrics to evaluate the generated answers, based on the latest research. You can see the entire list [here](https://github.com/explodinggradients/ragas#metrics). For this quickstart we will be using 3 metrics, one from each type we support.\n",
+    "1. `edit_ratio` - obtained by dividing the Levenshtein distance by the total number of characters in the generated text and the ground truth.\n",
+    "2. `bleu_score` - measures precision by comparing clipped n-grams in the generated text to the ground truth text.\n",
+    "3. `bert_score` - measures the similarity between the ground truth answers and the generated text using SBERT vector embeddings."
+   ]
+  },
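As a side note on `bleu_score`: NLTK's BLEU implementation is what emits the zero-overlap warnings captured in the eval cell's stderr further down. Here is a minimal standalone sketch of clipped n-gram precision using NLTK directly; the example sentences and the smoothing choice are illustrative, and this is not necessarily how ragas invokes it internally.

```python
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

reference = "the cat sat on the mat".split()
hypothesis = "the cat is on the mat".split()

# Default BLEU: geometric mean of clipped 1- to 4-gram precisions.
# Short texts often have zero higher-order overlaps, which is exactly
# what the UserWarnings in the notebook's stderr are about.
plain = sentence_bleu([reference], hypothesis)

# A SmoothingFunction avoids the hard zero when an n-gram order has no overlap.
smoothed = sentence_bleu(
    [reference], hypothesis, smoothing_function=SmoothingFunction().method1
)

print(plain, smoothed)
```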
  {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 5,
    "id": "0b5abd7d",
    "metadata": {},
    "outputs": [],
    "source": [
-    "from belar.metrics import (\n",
-    "    Rouge1,\n",
-    "    Evaluation,\n",
-    "    Rouge2,\n",
-    "    RougeL,\n",
-    "    SBERTScore,\n",
-    "    EntailmentScore,\n",
-    "    EditRatio,\n",
-    "    EditDistance,\n",
-    ")"
+    "from ragas.metrics import edit_ratio, bleu_score, bert_score"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "1d95d887",
+   "metadata": {},
+   "source": [
+    "Now we can initialize the `Evaluation` object. This will load your metrics and data and run the evaluation for you."
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": 28,
+   "execution_count": 7,
    "id": "a77c805d",
    "metadata": {
     "scrolled": true
    },
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.\n",
-      "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.\n",
-      "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.\n",
-      "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
-    "sbert_score = SBERTScore(similarity_metric=\"cosine\")\n",
-    "entail = EntailmentScore(max_length=512)\n",
+    "from ragas.metrics import Evaluation\n",
     "\n",
     "e = Evaluation(\n",
-    "    metrics=[Rouge1, Rouge2, RougeL, sbert_score, EditDistance, EditRatio, entail],\n",
+    "    metrics=[bert_score, edit_ratio, bleu_score],\n",
     "    batched=False,\n",
     "    batch_size=30,\n",
     ")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": 18,
    "id": "e879f51b",
    "metadata": {},
    "outputs": [
     {
@@ -138,682 +141,207 @@
     },
     "metadata": {},
     "output_type": "display_data"
-    }
-   ],
-   "source": [
-    "r = e.eval(ds[\"ground_truth\"], ds[\"generated_text\"])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 20,
-   "id": "f64c1915",
-   "metadata": {},
-   "outputs": [
+    },
     {
-     "data": {
-      "text/plain": [
-       "{'rouge1_score': 0.27777314683149845, 'rouge2_score': 0.05593454553750915, 'rougeL_score': 0.16365190027294899, 'SBERT_cosine_score': 0.37552570906095206, 'edit_distance_score': 735.114, 'edit_ratio_score': 0.41482407945510713}"
-      ]
-     },
-     "execution_count": 20,
-     "metadata": {},
-     "output_type": "execute_result"
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/jjmachan/miniconda3/envs/bench/lib/python3.10/site-packages/nltk/translate/bleu_score.py:552: UserWarning: \n",
+      "The hypothesis contains 0 counts of 2-gram overlaps.\n",
+      "Therefore the BLEU score evaluates to 0, independently of\n",
+      "how many N-gram overlaps of lower order it contains.\n",
+      "Consider using lower n-gram order or use SmoothingFunction()\n",
+      "  warnings.warn(_msg)\n",
+      "/home/jjmachan/miniconda3/envs/bench/lib/python3.10/site-packages/nltk/translate/bleu_score.py:552: UserWarning: \n",
+      "The hypothesis contains 0 counts of 3-gram overlaps.\n",
+      "Therefore the BLEU score evaluates to 0, independently of\n",
+      "how many N-gram overlaps of lower order it contains.\n",
+      "Consider using lower n-gram order or use SmoothingFunction()\n",
+      "  warnings.warn(_msg)\n",
+      "/home/jjmachan/miniconda3/envs/bench/lib/python3.10/site-packages/nltk/translate/bleu_score.py:552: UserWarning: \n",
+      "The hypothesis contains 0 counts of 4-gram overlaps.\n",
+      "Therefore the BLEU score evaluates to 0, independently of\n",
+      "how many N-gram overlaps of lower order it contains.\n",
+      "Consider using lower n-gram order or use SmoothingFunction()\n",
+      "  warnings.warn(_msg)\n"
+     ]
     }
    ],
    "source": [
-    "r"
+    "# run it with .eval()\n",
+    "result = e.eval(ds[\"ground_truth\"], ds[\"generated_text\"])"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": 21,
-   "id": "7c812dfe",
+   "cell_type": "markdown",
+   "id": "31fbe76c",
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "0.27777314683149845"
-      ]
-     },
-     "execution_count": 21,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
    "source": [
-    "r[\"rouge1_score\"]"
+    "### analysing results\n",
+    "\n",
+    "The returned `Result` object is used to analyse the results."
    ]
   },
  {
    "cell_type": "code",
-   "execution_count": 22,
-   "id": "4c8c51b1",
+   "execution_count": 28,
+   "id": "474c0aad",
    "metadata": {},
    "outputs": [
     {
      "data": {
+      "text/html": [
+       "<pre>{'BERTScore_cosine': 0.37552570906095206, 'edit_ratio': 0.41482407945510713, 'BLEU': 0.010848577619569451}\n",
+       "</pre>\n"
+      ],
       "text/plain": [
        "\u001b[1m{\u001b[0m\u001b[32m'BERTScore_cosine'\u001b[0m: \u001b[1;36m0.37552570906095206\u001b[0m, \u001b[32m'edit_ratio'\u001b[0m: \u001b[1;36m0.41482407945510713\u001b[0m, \u001b[32m'BLEU'\u001b[0m: \u001b[1;36m0.010848577619569451\u001b[0m\u001b[1m}\u001b[0m\n"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     }
    ],
    "source": [
     "from rich.pretty import pprint\n",
     "\n",
     "pprint(result)"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "eb07bbec",
    "metadata": {},
    "source": [
     "You can access individual metric results via `result['<metric_name>']`. It also has a `.describe()` function to show the distribution of the results, and you can access the individual scores from the `.scores` attribute."
    ]
   },
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
BERTScore_cosineedit_ratioBLEU
mean0.3755260.4148241.084858e-02
25%0.2123390.3998763.489775e-155
50%0.3326970.4291874.318061e-79
75%0.5326420.4495091.525948e-05
min0.0070170.1021824.029193e-232
max0.9106800.5729171.506915e-01
std0.2075590.0580722.343307e-02
\n", + "
" + ], "text/plain": [ - "['ground_truth', 'generated_text', 'SBERT_cosine_score']" + " BERTScore_cosine edit_ratio BLEU\n", + "mean 0.375526 0.414824 1.084858e-02\n", + "25% 0.212339 0.399876 3.489775e-155\n", + "50% 0.332697 0.429187 4.318061e-79\n", + "75% 0.532642 0.449509 1.525948e-05\n", + "min 0.007017 0.102182 4.029193e-232\n", + "max 0.910680 0.572917 1.506915e-01\n", + "std 0.207559 0.058072 2.343307e-02" ] }, - "execution_count": 5, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "ds_eval.column_names" + "from pandas import DataFrame\n", + "\n", + "# view with pandas\n", + "df = DataFrame(result.describe())\n", + "df" ] }, { "cell_type": "code", - "execution_count": 6, - "id": "f8a58fa8", + "execution_count": 29, + "id": "421c60ab", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[0.3033774197101593,\n", - " 0.016349632292985916,\n", - " 0.4478442072868347,\n", - " 0.1860141158103943,\n", - " 0.03600190579891205,\n", - " 0.6023079752922058,\n", - " 0.289838969707489,\n", - " 0.08502114564180374,\n", - " 0.17191164195537567,\n", - " 0.3593299984931946,\n", - " 0.1715232878923416,\n", - " 0.3805505037307739,\n", - " 0.5519564151763916,\n", - " 0.2677731215953827,\n", - " 0.6183438301086426,\n", - " 0.10611602663993835,\n", - " 0.19605034589767456,\n", - " 0.08165217190980911,\n", - " 0.29304254055023193,\n", - " 0.35943326354026794,\n", - " 0.38164564967155457,\n", - " 0.03771442547440529,\n", - " 0.11554502695798874,\n", - " 0.47948333621025085,\n", - " 0.23276342451572418,\n", - " 0.4236215353012085,\n", - " 0.1943129450082779,\n", - " 0.1942053735256195,\n", - " 0.12668733298778534,\n", - " 0.2597537338733673,\n", - " 0.33301281929016113,\n", - " 0.3094521462917328,\n", - " 0.3279588520526886,\n", - " 0.32722654938697815,\n", - " 0.38284799456596375,\n", - " 0.2851578891277313,\n", - " 0.23893719911575317,\n", - " 0.6166086196899414,\n", - " 0.2423057109117508,\n", - " 0.7267876267433167,\n", - " 0.08813111484050751,\n", - " 0.48606470227241516,\n", - " 0.6568448543548584,\n", - " 0.1358499825000763,\n", - " 0.4515664577484131,\n", - " 0.23441915214061737,\n", - " 0.4741160571575165,\n", - " 0.18968994915485382,\n", - " 0.382995069026947,\n", - " 0.7173715233802795,\n", - " 0.7269276976585388,\n", - " 0.2834068834781647,\n", - " 0.2564486265182495,\n", - " 0.9106802940368652,\n", - " 0.3905271291732788,\n", - " 0.1269465684890747,\n", - " 0.09796524047851562,\n", - " 0.6954237222671509,\n", - " 0.49959367513656616,\n", - " 0.3481505811214447,\n", - " 0.2524052858352661,\n", - " 0.20396579802036285,\n", - " 0.4261414706707001,\n", - " 0.35149627923965454,\n", - " 0.060562025755643845,\n", - " 0.29626941680908203,\n", - " 0.33264321088790894,\n", - " 0.32353609800338745,\n", - " 0.0929298847913742,\n", - " 0.694779634475708,\n", - " 0.42692476511001587,\n", - " 0.6740735769271851,\n", - " 0.26791706681251526,\n", - " 0.30361559987068176,\n", - " 0.6142315864562988,\n", - " 0.8581538200378418,\n", - " 0.1934203803539276,\n", - " 0.17560303211212158,\n", - " 0.39025163650512695,\n", - " 0.2257130742073059,\n", - " 0.10104137659072876,\n", - " 0.5671371221542358,\n", - " 0.2376122921705246,\n", - " 0.7245509624481201,\n", - " 0.33550819754600525,\n", - " 0.16170960664749146,\n", - " 0.3289082944393158,\n", - " 0.21686506271362305,\n", - " 0.5573591589927673,\n", - " 0.39316579699516296,\n", - " 0.3452097177505493,\n", - " 0.7620242238044739,\n", - " 0.612403154373169,\n", - " 0.20761919021606445,\n", - " 0.3436463177204132,\n", 
- " 0.35804855823516846,\n", - " 0.5422661304473877,\n", - " 0.2482432872056961,\n", - " 0.24608035385608673,\n", - " 0.43996310234069824,\n", - " 0.7638659477233887,\n", - " 0.4832608997821808,\n", - " 0.3723938763141632,\n", - " 0.16313855350017548,\n", - " 0.17755097150802612,\n", - " 0.7125013470649719,\n", - " 0.21019332110881805,\n", - " 0.2878414988517761,\n", - " 0.7330911755561829,\n", - " 0.5391034483909607,\n", - " 0.3856879770755768,\n", - " 0.21089066565036774,\n", - " 0.21917514503002167,\n", - " 0.5970359444618225,\n", - " 0.10427114367485046,\n", - " 0.5017147660255432,\n", - " 0.32604700326919556,\n", - " 0.26022183895111084,\n", - " 0.2217114269733429,\n", - " 0.5664410591125488,\n", - " 0.6097017526626587,\n", - " 0.6790091395378113,\n", - " 0.6737412810325623,\n", - " 0.3198738396167755,\n", - " 0.3233138620853424,\n", - " 0.27815982699394226,\n", - " 0.5739132165908813,\n", - " 0.8073441982269287,\n", - " 0.393609881401062,\n", - " 0.34070584177970886,\n", - " 0.1426166594028473,\n", - " 0.3649061918258667,\n", - " 0.21035610139369965,\n", - " 0.15468955039978027,\n", - " 0.15301679074764252,\n", - " 0.3864727020263672,\n", - " 0.3432256877422333,\n", - " 0.27995312213897705,\n", - " 0.45306405425071716,\n", - " 0.152155339717865,\n", - " 0.5590802431106567,\n", - " 0.14337098598480225,\n", - " 0.5684935450553894,\n", - " 0.06331620365381241,\n", - " 0.7308592200279236,\n", - " 0.3433731496334076,\n", - " 0.49904948472976685,\n", - " 0.24472254514694214,\n", - " 0.17057321965694427,\n", - " 0.17359305918216705,\n", - " 0.1405472606420517,\n", - " 0.21779431402683258,\n", - " 0.6882146596908569,\n", - " 0.39259153604507446,\n", - " 0.5250310301780701,\n", - " 0.29845374822616577,\n", - " 0.6535312533378601,\n", - " 0.3323957920074463,\n", - " 0.6179606318473816,\n", - " 0.6263958215713501,\n", - " 0.25900962948799133,\n", - " 0.35419002175331116,\n", - " 0.33175551891326904,\n", - " 0.1691923886537552,\n", - " 0.6974550485610962,\n", - " 0.5213074088096619,\n", - " 0.032654277980327606,\n", - " 0.34367528557777405,\n", - " 0.405593603849411,\n", - " 0.08452585339546204,\n", - " 0.11424578726291656,\n", - " 0.6650150418281555,\n", - " 0.2742277681827545,\n", - " 0.28393787145614624,\n", - " 0.29564306139945984,\n", - " 0.5309538245201111,\n", - " 0.022119097411632538,\n", - " 0.5228688716888428,\n", - " 0.6862163543701172,\n", - " 0.4796127676963806,\n", - " 0.331642746925354,\n", - " 0.469801127910614,\n", - " 0.2787094712257385,\n", - " 0.15432526171207428,\n", - " 0.13090954720973969,\n", - " 0.5296900272369385,\n", - " 0.5006809830665588,\n", - " 0.31476107239723206,\n", - " 0.6327821612358093,\n", - " 0.27751827239990234,\n", - " 0.08453290164470673,\n", - " 0.152990460395813,\n", - " 0.2828467786312103,\n", - " 0.21192562580108643,\n", - " 0.23361067473888397,\n", - " 0.1100977212190628,\n", - " 0.729167640209198,\n", - " 0.25679513812065125,\n", - " 0.29639971256256104,\n", - " 0.19549258053302765,\n", - " 0.01892801746726036,\n", - " 0.7945613265037537,\n", - " 0.7499642372131348,\n", - " 0.15835057199001312,\n", - " 0.6000410914421082,\n", - " 0.38472887873649597,\n", - " 0.27581414580345154,\n", - " 0.6135129332542419,\n", - " 0.30333641171455383,\n", - " 0.6530413627624512,\n", - " 0.32561489939689636,\n", - " 0.6843974590301514,\n", - " 0.7383497953414917,\n", - " 0.1791287064552307,\n", - " 0.15797390043735504,\n", - " 0.1897229701280594,\n", - " 0.34278005361557007,\n", - " 0.523197591304779,\n", - " 0.2993963062763214,\n", - " 0.24305762350559235,\n", - " 
0.2124125361442566,\n", - " 0.23200078308582306,\n", - " 0.5277230739593506,\n", - " 0.3923065960407257,\n", - " 0.2338612824678421,\n", - " 0.6605720520019531,\n", - " 0.4534214735031128,\n", - " 0.7204974889755249,\n", - " 0.4256589412689209,\n", - " 0.1377628594636917,\n", - " 0.1862977296113968,\n", - " 0.6173402070999146,\n", - " 0.2129381150007248,\n", - " 0.18199223279953003,\n", - " 0.4077472388744354,\n", - " 0.5461190938949585,\n", - " 0.7703336477279663,\n", - " 0.7089384198188782,\n", - " 0.12397469580173492,\n", - " 0.3445894420146942,\n", - " 0.29747506976127625,\n", - " 0.12937960028648376,\n", - " 0.6808912754058838,\n", - " 0.44350528717041016,\n", - " 0.0622265450656414,\n", - " 0.800916314125061,\n", - " 0.196528360247612,\n", - " 0.40886160731315613,\n", - " 0.5457544326782227,\n", - " 0.7547292113304138,\n", - " 0.17570790648460388,\n", - " 0.33092451095581055,\n", - " 0.3909622132778168,\n", - " 0.1750270575284958,\n", - " 0.21135497093200684,\n", - " 0.2844017744064331,\n", - " 0.6711058616638184,\n", - " 0.7111238241195679,\n", - " 0.39750146865844727,\n", - " 0.3603275716304779,\n", - " 0.20594996213912964,\n", - " 0.26992928981781006,\n", - " 0.32206788659095764,\n", - " 0.5537823438644409,\n", - " 0.6196168065071106,\n", - " 0.17448124289512634,\n", - " 0.8145052194595337,\n", - " 0.13209058344364166,\n", - " 0.6009707450866699,\n", - " 0.1729992777109146,\n", - " 0.605941891670227,\n", - " 0.16112592816352844,\n", - " 0.7443314790725708,\n", - " 0.27183473110198975,\n", - " 0.6732509732246399,\n", - " 0.34409621357917786,\n", - " 0.6225290894508362,\n", - " 0.7111546397209167,\n", - " 0.25248128175735474,\n", - " 0.25385937094688416,\n", - " 0.4553792476654053,\n", - " 0.007017173804342747,\n", - " 0.5378240942955017,\n", - " 0.6920719146728516,\n", - " 0.5118893980979919,\n", - " 0.7575575113296509,\n", - " 0.053049687296152115,\n", - " 0.34726738929748535,\n", - " 0.625588595867157,\n", - " 0.2684467136859894,\n", - " 0.21171455085277557,\n", - " 0.16874279081821442,\n", - " 0.6806609034538269,\n", - " 0.6409006118774414,\n", - " 0.43180617690086365,\n", - " 0.36487576365470886,\n", - " 0.25573742389678955,\n", - " 0.8596686124801636,\n", - " 0.7924257516860962,\n", - " 0.2288934737443924,\n", - " 0.37159034609794617,\n", - " 0.21388927102088928,\n", - " 0.7443233132362366,\n", - " 0.1677546203136444,\n", - " 0.590474009513855,\n", - " 0.2609856426715851,\n", - " 0.2530490458011627,\n", - " 0.26618924736976624,\n", - " 0.25583404302597046,\n", - " 0.20902562141418457,\n", - " 0.5943877696990967,\n", - " 0.07199332863092422,\n", - " 0.44120875000953674,\n", - " 0.3591962456703186,\n", - " 0.6544501781463623,\n", - " 0.12697549164295197,\n", - " 0.3532907962799072,\n", - " 0.4480339288711548,\n", - " 0.7042593359947205,\n", - " 0.11615218967199326,\n", - " 0.6357651948928833,\n", - " 0.24792085587978363,\n", - " 0.3313771188259125,\n", - " 0.5221624970436096,\n", - " 0.35108593106269836,\n", - " 0.135896697640419,\n", - " 0.15817011892795563,\n", - " 0.8391244411468506,\n", - " 0.2277119904756546,\n", - " 0.04543468356132507,\n", - " 0.25068429112434387,\n", - " 0.1133192926645279,\n", - " 0.28534117341041565,\n", - " 0.8111948370933533,\n", - " 0.3385901153087616,\n", - " 0.49840831756591797,\n", - " 0.4116763174533844,\n", - " 0.16915757954120636,\n", - " 0.3262860178947449,\n", - " 0.10765945911407471,\n", - " 0.1261938512325287,\n", - " 0.3500753939151764,\n", - " 0.2676033079624176,\n", - " 0.6120821833610535,\n", - " 0.62961345911026,\n", - " 
0.27265217900276184,\n", - " 0.7611227035522461,\n", - " 0.2189398556947708,\n", - " 0.271114706993103,\n", - " 0.7538965940475464,\n", - " 0.1766694337129593,\n", - " 0.26010769605636597,\n", - " 0.14162400364875793,\n", - " 0.15965068340301514,\n", - " 0.30319979786872864,\n", - " 0.23467262089252472,\n", - " 0.7990760207176208,\n", - " 0.3484833538532257,\n", - " 0.3364700973033905,\n", - " 0.36943286657333374,\n", - " 0.37875810265541077,\n", - " 0.5377050042152405,\n", - " 0.2255283147096634,\n", - " 0.6214497089385986,\n", - " 0.572303295135498,\n", - " 0.5672966241836548,\n", - " 0.4602000117301941,\n", - " 0.6925125122070312,\n", - " 0.19061176478862762,\n", - " 0.750962495803833,\n", - " 0.057794470340013504,\n", - " 0.22833339869976044,\n", - " 0.12149019539356232,\n", - " 0.5187497735023499,\n", - " 0.43326133489608765,\n", - " 0.7459068298339844,\n", - " 0.28757017850875854,\n", - " 0.060881346464157104,\n", - " 0.19995999336242676,\n", - " 0.2332974374294281,\n", - " 0.5807837843894958,\n", - " 0.4985215663909912,\n", - " 0.2317824810743332,\n", - " 0.20419657230377197,\n", - " 0.2929933965206146,\n", - " 0.22726529836654663,\n", - " 0.36383742094039917,\n", - " 0.26542332768440247,\n", - " 0.33275106549263,\n", - " 0.1817902773618698,\n", - " 0.019586173817515373,\n", - " 0.6501842737197876,\n", - " 0.5130109786987305,\n", - " 0.04855664074420929,\n", - " 0.327665239572525,\n", - " 0.33484169840812683,\n", - " 0.18408897519111633,\n", - " 0.8089461326599121,\n", - " 0.2609926760196686,\n", - " 0.35048383474349976,\n", - " 0.3380715847015381,\n", - " 0.19198913872241974,\n", - " 0.47304245829582214,\n", - " 0.13059648871421814,\n", - " 0.388828307390213,\n", - " 0.6691229939460754,\n", - " 0.1510116457939148,\n", - " 0.20976220071315765,\n", - " 0.4316028952598572,\n", - " 0.5592595934867859,\n", - " 0.4931623339653015,\n", - " 0.40056753158569336,\n", - " 0.1390654295682907,\n", - " 0.7112942337989807,\n", - " 0.30744668841362,\n", - " 0.2824617028236389,\n", - " 0.29495444893836975,\n", - " 0.8129028081893921,\n", - " 0.04778153821825981,\n", - " 0.3677351772785187,\n", - " 0.38807204365730286,\n", - " 0.23143930733203888,\n", - " 0.3730814754962921,\n", - " 0.3903065323829651,\n", - " 0.10604582726955414,\n", - " 0.375832736492157,\n", - " 0.32024890184402466,\n", - " 0.3080943822860718,\n", - " 0.6008120775222778,\n", - " 0.8878772258758545,\n", - " 0.4099455773830414,\n", - " 0.4919497072696686,\n", - " 0.21881842613220215,\n", - " 0.7104718089103699,\n", - " 0.40945085883140564,\n", - " 0.7066667675971985,\n", - " 0.3884510099887848,\n", - " 0.29029491543769836,\n", - " 0.48201748728752136,\n", - " 0.645422637462616,\n", - " 0.46089968085289,\n", - " 0.26423460245132446,\n", - " 0.3575299084186554,\n", - " 0.12025940418243408,\n", - " 0.3637012839317322,\n", - " 0.5629667043685913,\n", - " 0.21808886528015137,\n", - " 0.20087826251983643,\n", - " 0.19176578521728516,\n", - " 0.521368145942688,\n", - " 0.4651867747306824,\n", - " 0.2771470844745636,\n", - " 0.15467087924480438,\n", - " 0.06321043521165848,\n", - " 0.727550208568573,\n", - " 0.6326872706413269,\n", - " 0.2524058222770691,\n", - " 0.40928635001182556,\n", - " 0.2859940230846405,\n", - " 0.24548542499542236,\n", - " 0.25654155015945435,\n", - " 0.1554943472146988,\n", - " 0.2810353636741638,\n", - " 0.39291778206825256,\n", - " 0.7448244094848633,\n", - " 0.36232057213783264,\n", - " 0.2249329537153244,\n", - " 0.5934489369392395,\n", - " 0.36474189162254333,\n", - " 0.16170084476470947,\n", - " 
0.2098686695098877,\n", - " 0.3690999746322632,\n", - " 0.6965110898017883,\n", - " 0.21211691200733185,\n", - " 0.6880887150764465,\n", - " 0.7315702438354492,\n", - " 0.2110704928636551,\n", - " 0.8123224973678589,\n", - " 0.7990055680274963,\n", - " 0.14683164656162262,\n", - " 0.25454556941986084,\n", - " 0.11940312385559082,\n", - " 0.2454526573419571,\n", - " 0.5912683010101318,\n", - " 0.4947351813316345,\n", - " 0.4511561691761017,\n", - " 0.13149523735046387,\n", - " 0.1972067654132843,\n", - " 0.3593907356262207,\n", - " 0.5928145051002502,\n", - " 0.25529202818870544,\n", - " 0.2567807137966156,\n", - " 0.20362421870231628,\n", - " 0.30127424001693726,\n", - " 0.6847882270812988,\n", - " 0.6155568957328796,\n", - " 0.2527444660663605,\n", - " 0.4813864529132843,\n", - " 0.3825063407421112,\n", - " 0.3193434178829193]" + "Dataset({\n", + " features: ['BERTScore_cosine', 'edit_ratio', 'BLEU'],\n", + " num_rows: 500\n", + "})" ] }, - "execution_count": 6, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "ds_eval[\"SBERT_cosine_score\"]" + "result.scores" ] - }, - { - "cell_type": "markdown", - "id": "3893e1c7", - "metadata": {}, - "source": [] } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index 2232962c5..1ef38b55f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "nltk", "datasets", "spacy<4.0.0,>=3.0.0", + "protobuf<=3.20.0", ] dynamic = ["version", "readme"] diff --git a/ragas/metrics/factual.py b/ragas/metrics/factual.py index 82cbfb79e..d9ab22909 100644 --- a/ragas/metrics/factual.py +++ b/ragas/metrics/factual.py @@ -5,10 +5,13 @@ import string import typing as t from dataclasses import dataclass +from warnings import warn import numpy as np import spacy +import torch import transformers +from spacy.cli.download import download as spacy_download from transformers import ( AutoConfig, AutoModelForSequenceClassification, @@ -16,7 +19,6 @@ PreTrainedModel, ) -from ragas.exceptions import RagasException from ragas.metrics import Metric from ragas.utils import device_check @@ -213,13 +215,18 @@ class Qsquare(Metric): def __post_init__(self): self.qa = QAGQ.from_pretrained(self.qa_model_name) self.qg = QAGQ.from_pretrained(self.qg_model_name) + self.nli = EntailmentScore() try: self.nlp = spacy.load(SPACY_MODEL) except OSError: - raise RagasException( + warn( f"Spacy model [{SPACY_MODEL}] not found. Please run " - "`python -m spacy download {SPACY_MODEL}` to install it." + f"`python -m spacy download {SPACY_MODEL}` to install it." ) + # logger.warning(f"Spacy models '{spacy_model_name}' not found." 
+ # " Downloading and installing.") + spacy_download(SPACY_MODEL) + self.nlp = spacy.load(SPACY_MODEL) @property def name(self): @@ -288,7 +295,6 @@ def clean_candidate(self, text): return text def score_candidates(self, ques_ans_dict: dict): - nli = EntailmentScore() for qas in ques_ans_dict.values(): for item in qas: item["answer"] = self.clean_candidate(item["answer"]) @@ -299,7 +305,7 @@ def score_candidates(self, ques_ans_dict: dict): item.update({"score": 1}) else: qstn = item.get("question") - score_dict = nli.infer( + score_dict = self.nli.infer( f'{qstn}{item.get("answer")}', f'{qstn}{item.get("predicted_answer")}', ) @@ -331,8 +337,8 @@ def score(self, ground_truth: list[str], generated_text: list[str], **kwargs): for item, ans in zip(gnd_qans[i], gen_answers) ] - del self.qa - del self.qg + # del self.qa + # del self.qg gnd_qans = self.score_candidates(gnd_qans) @@ -345,5 +351,6 @@ def score(self, ground_truth: list[str], generated_text: list[str], **kwargs): return scores -entailment_score = EntailmentScore() -q_square = Qsquare() +device = "cuda" if torch.cuda.is_available() else "cpu" +entailment_score = EntailmentScore(device=device) +q_square = Qsquare(device=device) diff --git a/tests/benchmarks/benchmark.py b/tests/benchmarks/benchmark.py index 14b9f0d7e..5bb413fc0 100644 --- a/tests/benchmarks/benchmark.py +++ b/tests/benchmarks/benchmark.py @@ -5,7 +5,15 @@ from tqdm import tqdm from utils import print_table, timeit -from ragas.metrics import Evaluation, edit_distance, edit_ratio, rouge1, rouge2, rougeL +from ragas.metrics import ( + Evaluation, + edit_distance, + edit_ratio, + q_square, + rouge1, + rouge2, + rougeL, +) DEVICE = "cuda" if is_available() else "cpu" BATCHES = [0, 1] @@ -16,8 +24,9 @@ "RougeL": rougeL, "EditRatio": edit_ratio, "EditDistance": edit_distance, - # "SBERTScore": sbert_score, - # "EntailmentScore": entail, + # "SBERTScore": bert_score, + # "EntailmentScore": entailment_score, + "Qsquare": q_square, } DS = load_dataset("explodinggradients/eli5-test", split="test_eli5") assert isinstance(DS, arrow_dataset.Dataset), "Not an arrow_dataset"
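
For reference, the download-on-missing behaviour this patch adds to `Qsquare.__post_init__` can be exercised in isolation. A minimal sketch of the same pattern follows; the pipeline name `en_core_web_sm` is an example, while `factual.py` uses whatever its `SPACY_MODEL` constant is set to.

```python
import spacy
from spacy.cli.download import download as spacy_download

SPACY_MODEL = "en_core_web_sm"  # example name; factual.py defines its own


def load_spacy_pipeline(name: str = SPACY_MODEL):
    """Load a spaCy pipeline, downloading it first if it is not installed."""
    try:
        return spacy.load(name)
    except OSError:
        # Same fallback the patch adds: download the missing model, then retry.
        spacy_download(name)
        return spacy.load(name)


nlp = load_spacy_pipeline()
```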