From 75db22900612ec4f01b1243d107895397a924e06 Mon Sep 17 00:00:00 2001 From: Jithin James Date: Mon, 23 Oct 2023 15:13:07 +0530 Subject: [PATCH] feat: llamaIndex llm support (#205) Added support for LlamaIndex `ServiceContext` and `BaseLLM`. Helps directly use llamaIndex LLMs with Ragas --- docs/howtos/customisations/llms.ipynb | 8 +- requirements/docs.txt | 1 + src/ragas/llms/__init__.py | 4 + src/ragas/llms/base.py | 160 +++++++++++++++++++++++++ src/ragas/llms/llamaindex.py | 47 ++++++++ src/ragas/metrics/base.py | 11 +- src/ragas/metrics/critique.py | 6 +- src/ragas/testset/testset_generator.py | 2 +- 8 files changed, 222 insertions(+), 17 deletions(-) create mode 100644 src/ragas/llms/__init__.py create mode 100644 src/ragas/llms/base.py create mode 100644 src/ragas/llms/llamaindex.py diff --git a/docs/howtos/customisations/llms.ipynb b/docs/howtos/customisations/llms.ipynb index dea0ccd22..6c9ad5167 100644 --- a/docs/howtos/customisations/llms.ipynb +++ b/docs/howtos/customisations/llms.ipynb @@ -179,8 +179,8 @@ "from ragas import evaluate\n", "\n", "result = evaluate(\n", - " fiqa_eval[\"baseline\"].select(range(5)), # showing only 5 for demonstration \n", - " metrics=[faithfulness]\n", + " fiqa_eval[\"baseline\"].select(range(5)), # showing only 5 for demonstration\n", + " metrics=[faithfulness],\n", ")\n", "\n", "result" @@ -301,8 +301,8 @@ "from ragas import evaluate\n", "\n", "result = evaluate(\n", - " fiqa_eval[\"baseline\"].select(range(5)), # showing only 5 for demonstration \n", - " metrics=[faithfulness]\n", + " fiqa_eval[\"baseline\"].select(range(5)), # showing only 5 for demonstration\n", + " metrics=[faithfulness],\n", ")\n", "\n", "result" diff --git a/requirements/docs.txt b/requirements/docs.txt index 4e93c630d..923fba65d 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -4,3 +4,4 @@ myst-parser[linkify] sphinx_design astroid<3 myst-nb +llama_index diff --git a/src/ragas/llms/__init__.py b/src/ragas/llms/__init__.py new file mode 100644 index 000000000..43094b7d2 --- /dev/null +++ b/src/ragas/llms/__init__.py @@ -0,0 +1,4 @@ +from ragas.llms.base import BaseRagasLLM, LangchainLLM, llm_factory +from ragas.llms.llamaindex import LlamaIndexLLM + +__all__ = ["BaseRagasLLM", "LangchainLLM", "LlamaIndexLLM", "llm_factory"] diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py new file mode 100644 index 000000000..7a154c6f4 --- /dev/null +++ b/src/ragas/llms/base.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import os +import typing as t +from abc import ABC, abstractmethod + +from langchain.chat_models import AzureChatOpenAI, ChatOpenAI +from langchain.chat_models.base import BaseChatModel +from langchain.llms import AzureOpenAI, OpenAI +from langchain.llms.base import BaseLLM +from langchain.schema import LLMResult + +from ragas.async_utils import run_async_tasks + +if t.TYPE_CHECKING: + from langchain.callbacks.base import Callbacks + from langchain.prompts import ChatPromptTemplate + + +def isOpenAI(llm: BaseLLM | BaseChatModel) -> bool: + return isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI) + + +# have to specify it twice for runtime and static checks +MULTIPLE_COMPLETION_SUPPORTED = [OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI] +MultipleCompletionSupportedLLM = t.Union[ + OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI +] + + +class BaseRagasLLM(ABC): + """ + BaseLLM is the base class for all LLMs. 
+    It provides a consistent interface over LLMs from frameworks such as LangChain,
+    LlamaIndex and LiteLLM, and handles multiple completions even when the underlying
+    LLM does not support them.
+
+    It currently takes in ChatPromptTemplates and returns LLMResults, which are
+    LangChain primitives.
+    """
+
+    # supports multiple completions for the given prompt
+    n_completions_supported: bool = False
+
+    @property
+    @abstractmethod
+    def llm(self):
+        ...
+
+    @abstractmethod
+    def generate(
+        self,
+        prompts: list[ChatPromptTemplate],
+        n: int = 1,
+        temperature: float = 0,
+        callbacks: t.Optional[Callbacks] = None,
+    ) -> LLMResult:
+        ...
+
+
+class LangchainLLM(BaseRagasLLM):
+    n_completions_supported: bool = True
+
+    def __init__(self, llm: BaseLLM | BaseChatModel):
+        self.langchain_llm = llm
+
+    @property
+    def llm(self):
+        return self.langchain_llm
+
+    @staticmethod
+    def llm_supports_completions(llm):
+        for llm_type in MULTIPLE_COMPLETION_SUPPORTED:
+            if isinstance(llm, llm_type):
+                return True
+        return False
+
+    def generate_multiple_completions(
+        self,
+        prompts: list[ChatPromptTemplate],
+        n: int = 1,
+        callbacks: t.Optional[Callbacks] = None,
+    ) -> LLMResult:
+        self.langchain_llm = t.cast(MultipleCompletionSupportedLLM, self.langchain_llm)
+        old_n = self.langchain_llm.n
+        self.langchain_llm.n = n
+
+        if isinstance(self.llm, BaseLLM):
+            ps = [p.format() for p in prompts]
+            result = self.llm.generate(ps, callbacks=callbacks)
+        else:  # if BaseChatModel
+            ps = [p.format_messages() for p in prompts]
+            result = self.llm.generate(ps, callbacks=callbacks)
+        self.llm.n = old_n
+
+        return result
+
+    async def generate_completions(
+        self,
+        prompts: list[ChatPromptTemplate],
+        callbacks: t.Optional[Callbacks] = None,
+    ) -> LLMResult:
+        if isinstance(self.llm, BaseLLM):
+            ps = [p.format() for p in prompts]
+            result = await self.llm.agenerate(ps, callbacks=callbacks)
+        else:  # if BaseChatModel
+            ps = [p.format_messages() for p in prompts]
+            result = await self.llm.agenerate(ps, callbacks=callbacks)
+
+        return result
+
+    def generate(
+        self,
+        prompts: list[ChatPromptTemplate],
+        n: int = 1,
+        temperature: float = 0,
+        callbacks: t.Optional[Callbacks] = None,
+    ) -> LLMResult:
+        # set temperature to 0.2 for multiple completions
+        temperature = 0.2 if n > 1 else 0
+        self.llm.temperature = temperature
+
+        if self.llm_supports_completions(self.llm):
+            return self.generate_multiple_completions(prompts, n, callbacks)
+        else:  # call generate_completions n times to mimic multiple completions
+            list_llmresults = run_async_tasks(
+                [self.generate_completions(prompts, callbacks) for _ in range(n)]
+            )
+
+            # fill results as if the LLM supported multiple completions
+            generations = []
+            for i in range(len(prompts)):
+                completions = []
+                for result in list_llmresults:
+                    completions.append(result.generations[i][0])
+                generations.append(completions)
+
+            # compute total token usage by adding individual token usage
+            llm_output = list_llmresults[0].llm_output
+            if "token_usage" in llm_output:
+                sum_prompt_tokens = 0
+                sum_completion_tokens = 0
+                sum_total_tokens = 0
+                for result in list_llmresults:
+                    token_usage = result.llm_output["token_usage"]
+                    sum_prompt_tokens += token_usage["prompt_tokens"]
+                    sum_completion_tokens += token_usage["completion_tokens"]
+                    sum_total_tokens += token_usage["total_tokens"]
+
+                llm_output["token_usage"] = {
+                    "prompt_tokens": sum_prompt_tokens,
+                    "completion_tokens": sum_completion_tokens,
+                    "total_tokens": sum_total_tokens,
+                }
+
+            return LLMResult(generations=generations, llm_output=llm_output)
+
+
+def llm_factory() -> LangchainLLM:
+    oai_key = os.getenv("OPENAI_API_KEY", "no-key")
+    openai_llm = ChatOpenAI(openai_api_key=oai_key)
+    return LangchainLLM(llm=openai_llm)
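Usage sketch (not part of the diff): how the `LangchainLLM` wrapper above is meant to be attached to a metric; the `gpt-4` model name is only an illustrative choice.

```python
from langchain.chat_models import ChatOpenAI

from ragas.llms import LangchainLLM
from ragas.metrics import faithfulness

# any LangChain LLM / chat model can be wrapped; the wrapper adds the
# multiple-completion fallback implemented in generate() above
faithfulness.llm = LangchainLLM(llm=ChatOpenAI(model_name="gpt-4"))
```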
diff --git a/src/ragas/llms/llamaindex.py b/src/ragas/llms/llamaindex.py
new file mode 100644
index 000000000..2a42d9efd
--- /dev/null
+++ b/src/ragas/llms/llamaindex.py
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+import typing as t
+
+from langchain.schema.output import Generation, LLMResult
+from llama_index.llms.base import LLM as LiLLM
+
+from ragas.async_utils import run_async_tasks
+from ragas.llms.base import BaseRagasLLM
+
+if t.TYPE_CHECKING:
+    from langchain.callbacks.base import Callbacks
+    from langchain.prompts import ChatPromptTemplate
+
+
+class LlamaIndexLLM(BaseRagasLLM):
+    def __init__(self, llm: LiLLM) -> None:
+        self.llama_index_llm = llm
+
+    @property
+    def llm(self) -> LiLLM:
+        return self.llama_index_llm
+
+    def generate(
+        self,
+        prompts: list[ChatPromptTemplate],
+        n: int = 1,
+        temperature: float = 0,
+        callbacks: t.Optional[Callbacks] = None,
+    ) -> LLMResult:
+        # set temperature to 0.2 for multiple completions
+        temperature = 0.2 if n > 1 else 0
+        self.llm.temperature = temperature
+
+        # get task coroutines
+        tasks = []
+        for p in prompts:
+            tasks.extend([self.llm.acomplete(p.format()) for _ in range(n)])
+
+        # process results to LLMResult
+        # token usage is not included for now
+        results = run_async_tasks(tasks)
+        results2D = [results[i : i + n] for i in range(0, len(results), n)]
+        generations = [
+            [Generation(text=r.text) for r in result] for result in results2D
+        ]
+        return LLMResult(generations=generations)
diff --git a/src/ragas/metrics/base.py b/src/ragas/metrics/base.py
index 274a289ff..d36e43d90 100644
--- a/src/ragas/metrics/base.py
+++ b/src/ragas/metrics/base.py
@@ -6,7 +6,6 @@
 """
 from __future__ import annotations
 
-import os
 import typing as t
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
@@ -20,7 +19,7 @@
 from tqdm import tqdm
 
 from ragas.exceptions import OpenAIKeyNotFound
-from ragas.metrics.llms import LangchainLLM
+from ragas.llms import LangchainLLM, llm_factory
 
 if t.TYPE_CHECKING:
     from langchain.callbacks.base import Callbacks
@@ -109,15 +108,9 @@ def get_batches(self, dataset_size: int) -> list[range]:
         return make_batches(dataset_size, self.batch_size)
 
 
-def _llm_factory() -> LangchainLLM:
-    oai_key = os.getenv("OPENAI_API_KEY", "no-key")
-    openai_llm = ChatOpenAI(openai_api_key=oai_key)
-    return LangchainLLM(llm=openai_llm)
-
-
 @dataclass
 class MetricWithLLM(Metric):
-    llm: LangchainLLM = field(default_factory=_llm_factory)
+    llm: LangchainLLM = field(default_factory=llm_factory)
 
     def init_model(self):
         if isinstance(self.llm, ChatOpenAI) or isinstance(self.llm, OpenAI):
diff --git a/src/ragas/metrics/critique.py b/src/ragas/metrics/critique.py
index 15eb77c7e..86ad10b5a 100644
--- a/src/ragas/metrics/critique.py
+++ b/src/ragas/metrics/critique.py
@@ -8,8 +8,8 @@
 from langchain.callbacks.manager import CallbackManager, trace_as_chain_group
 from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 
-from ragas.metrics.base import EvaluationMode, MetricWithLLM, _llm_factory
-from ragas.metrics.llms import LangchainLLM
+from ragas.llms import LangchainLLM
+from ragas.metrics.base import EvaluationMode, MetricWithLLM, llm_factory
 
 CRITIQUE_PROMPT = HumanMessagePromptTemplate.from_template(
     """Given a input and submission. Evaluate the submission only using the given criteria. 
@@ -56,7 +56,7 @@ class AspectCritique(MetricWithLLM): strictness: int = field(default=1, repr=False) batch_size: int = field(default=15, repr=False) llm: LangchainLLM = field( - default_factory=_llm_factory, + default_factory=llm_factory, repr=False, ) diff --git a/src/ragas/testset/testset_generator.py b/src/ragas/testset/testset_generator.py index ea0f4a06f..77a7a0d08 100644 --- a/src/ragas/testset/testset_generator.py +++ b/src/ragas/testset/testset_generator.py @@ -20,7 +20,7 @@ from numpy.random import default_rng from tqdm import tqdm -from ragas.metrics.llms import LangchainLLM +from ragas.llms import LangchainLLM from ragas.testset.prompts import ( ANSWER_FORMULATE, COMPRESS_QUESTION,
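For reference, a minimal end-to-end sketch of how the new `LlamaIndexLLM` wrapper can be used (not part of the diff; assumes LlamaIndex's `OpenAI` LLM class and the FiQA dataset from the docs notebook — exact `llama_index` import paths vary by version):

```python
from datasets import load_dataset
from llama_index.llms import OpenAI as LiOpenAI

from ragas import evaluate
from ragas.llms import LlamaIndexLLM
from ragas.metrics import faithfulness

# wrap a LlamaIndex LLM so it satisfies the BaseRagasLLM interface used by metrics
faithfulness.llm = LlamaIndexLLM(llm=LiOpenAI(model="gpt-3.5-turbo"))

# evaluation dataset used in docs/howtos/customisations/llms.ipynb
fiqa_eval = load_dataset("explodinggradients/fiqa", "ragas_eval")

result = evaluate(
    fiqa_eval["baseline"].select(range(5)),  # showing only 5 for demonstration
    metrics=[faithfulness],
)
print(result)
```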