feat: llamaIndex llm support (#205)
Added support for LlamaIndex `ServiceContext` and `BaseLLM`.

This lets you use LlamaIndex LLMs directly with Ragas.
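For context, a minimal sketch of how the new wrapper can be used end to end. The dataset, model name, and LlamaIndex import path are illustrative (they follow the docs notebook and the LlamaIndex API of the time, not this commit):

from datasets import load_dataset
from llama_index.llms import OpenAI  # LlamaIndex's OpenAI LLM (assumed import path)

from ragas import evaluate
from ragas.llms import LlamaIndexLLM
from ragas.metrics import faithfulness

# wrap a LlamaIndex LLM so a Ragas metric can drive it
faithfulness.llm = LlamaIndexLLM(llm=OpenAI(model="gpt-3.5-turbo"))

fiqa_eval = load_dataset("explodinggradients/fiqa", "ragas_eval")
result = evaluate(
    fiqa_eval["baseline"].select(range(5)),  # showing only 5 for demonstration
    metrics=[faithfulness],
)
print(result)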
jjmachan authored Oct 23, 2023
1 parent c2a64d5 commit 75db229
Showing 8 changed files with 222 additions and 17 deletions.
8 changes: 4 additions & 4 deletions docs/howtos/customisations/llms.ipynb
@@ -179,8 +179,8 @@
"from ragas import evaluate\n",
"\n",
"result = evaluate(\n",
" fiqa_eval[\"baseline\"].select(range(5)), # showing only 5 for demonstration \n",
" metrics=[faithfulness]\n",
" fiqa_eval[\"baseline\"].select(range(5)), # showing only 5 for demonstration\n",
" metrics=[faithfulness],\n",
")\n",
"\n",
"result"
@@ -301,8 +301,8 @@
"from ragas import evaluate\n",
"\n",
"result = evaluate(\n",
" fiqa_eval[\"baseline\"].select(range(5)), # showing only 5 for demonstration \n",
" metrics=[faithfulness]\n",
" fiqa_eval[\"baseline\"].select(range(5)), # showing only 5 for demonstration\n",
" metrics=[faithfulness],\n",
")\n",
"\n",
"result"
1 change: 1 addition & 0 deletions requirements/docs.txt
@@ -4,3 +4,4 @@ myst-parser[linkify]
sphinx_design
astroid<3
myst-nb
+llama_index
4 changes: 4 additions & 0 deletions src/ragas/llms/__init__.py
@@ -0,0 +1,4 @@
from ragas.llms.base import BaseRagasLLM, LangchainLLM, llm_factory
from ragas.llms.llamaindex import LlamaIndexLLM

__all__ = ["BaseRagasLLM", "LangchainLLM", "LlamaIndexLLM", "llm_factory"]
160 changes: 160 additions & 0 deletions src/ragas/llms/base.py
@@ -0,0 +1,160 @@
from __future__ import annotations

import os
import typing as t
from abc import ABC, abstractmethod

from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.llms import AzureOpenAI, OpenAI
from langchain.llms.base import BaseLLM
from langchain.schema import LLMResult

from ragas.async_utils import run_async_tasks

if t.TYPE_CHECKING:
from langchain.callbacks.base import Callbacks
from langchain.prompts import ChatPromptTemplate


def isOpenAI(llm: BaseLLM | BaseChatModel) -> bool:
return isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI)


# have to specify it twice for runtime and static checks
MULTIPLE_COMPLETION_SUPPORTED = [OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI]
MultipleCompletionSupportedLLM = t.Union[
OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI
]


class BaseRagasLLM(ABC):
"""
    BaseRagasLLM is the base class for all LLM wrappers. It provides a consistent
    interface for classes that interact with LLM frameworks like LangChain, LlamaIndex,
    and LiteLLM, and it handles multiple completions even if the underlying LLM does
    not support them. It currently takes in ChatPromptTemplates and returns LLMResults,
    which are LangChain primitives.
"""

    # supports multiple completions for the given prompt
n_completions_supported: bool = False

@property
@abstractmethod
def llm(self):
...

@abstractmethod
def generate(
self,
prompts: list[str],
n: int = 1,
temperature: float = 0,
callbacks: t.Optional[Callbacks] = None,
) -> list[list[str]]:
...


class LangchainLLM(BaseRagasLLM):
n_completions_supported: bool = True

def __init__(self, llm: BaseLLM | BaseChatModel):
self.langchain_llm = llm

@property
def llm(self):
return self.langchain_llm

@staticmethod
def llm_supports_completions(llm):
for llm_type in MULTIPLE_COMPLETION_SUPPORTED:
if isinstance(llm, llm_type):
return True

def generate_multiple_completions(
self,
prompts: list[ChatPromptTemplate],
n: int = 1,
callbacks: t.Optional[Callbacks] = None,
) -> LLMResult:
self.langchain_llm = t.cast(MultipleCompletionSupportedLLM, self.langchain_llm)
old_n = self.langchain_llm.n
self.langchain_llm.n = n

if isinstance(self.llm, BaseLLM):
ps = [p.format() for p in prompts]
result = self.llm.generate(ps, callbacks=callbacks)
else: # if BaseChatModel
ps = [p.format_messages() for p in prompts]
result = self.llm.generate(ps, callbacks=callbacks)
self.llm.n = old_n

return result

async def generate_completions(
self,
prompts: list[ChatPromptTemplate],
callbacks: t.Optional[Callbacks] = None,
) -> LLMResult:
if isinstance(self.llm, BaseLLM):
ps = [p.format() for p in prompts]
result = await self.llm.agenerate(ps, callbacks=callbacks)
else: # if BaseChatModel
ps = [p.format_messages() for p in prompts]
result = await self.llm.agenerate(ps, callbacks=callbacks)

return result

def generate(
self,
prompts: list[ChatPromptTemplate],
n: int = 1,
temperature: float = 0,
callbacks: t.Optional[Callbacks] = None,
) -> LLMResult:
# set temperature to 0.2 for multiple completions
temperature = 0.2 if n > 1 else 0
self.llm.temperature = temperature

if self.llm_supports_completions(self.llm):
return self.generate_multiple_completions(prompts, n, callbacks)
else: # call generate_completions n times to mimic multiple completions
list_llmresults = run_async_tasks(
[self.generate_completions(prompts, callbacks) for _ in range(n)]
)

# fill results as if the LLM supported multiple completions
generations = []
for i in range(len(prompts)):
completions = []
for result in list_llmresults:
completions.append(result.generations[i][0])
generations.append(completions)

# compute total token usage by adding individual token usage
llm_output = list_llmresults[0].llm_output
if "token_usage" in llm_output:
sum_prompt_tokens = 0
sum_completion_tokens = 0
sum_total_tokens = 0
for result in list_llmresults:
token_usage = result.llm_output["token_usage"]
sum_prompt_tokens += token_usage["prompt_tokens"]
sum_completion_tokens += token_usage["completion_tokens"]
sum_total_tokens += token_usage["total_tokens"]

llm_output["token_usage"] = {
"prompt_tokens": sum_prompt_tokens,
"completion_tokens": sum_completion_tokens,
"sum_total_tokens": sum_total_tokens,
}

return LLMResult(generations=generations, llm_output=llm_output)


def llm_factory() -> LangchainLLM:
oai_key = os.getenv("OPENAI_API_KEY", "no-key")
openai_llm = ChatOpenAI(openai_api_key=oai_key)
return LangchainLLM(llm=openai_llm)
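A short sketch of how the wrapper and factory above might be used; the model name is illustrative:

from langchain.chat_models import ChatOpenAI

from ragas.llms import LangchainLLM, llm_factory
from ragas.metrics import faithfulness

# default: llm_factory() wraps ChatOpenAI configured from OPENAI_API_KEY
faithfulness.llm = llm_factory()

# or wrap any LangChain LLM / chat model explicitly; when the model does not
# support the `n` parameter, multiple completions are emulated with repeated
# async calls, as in generate() above
gpt4 = ChatOpenAI(model_name="gpt-4")
faithfulness.llm = LangchainLLM(llm=gpt4)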
47 changes: 47 additions & 0 deletions src/ragas/llms/llamaindex.py
@@ -0,0 +1,47 @@
from __future__ import annotations

import typing as t

from langchain.schema.output import Generation, LLMResult
from llama_index.llms.base import LLM as LiLLM

from ragas.async_utils import run_async_tasks
from ragas.llms.base import BaseRagasLLM

if t.TYPE_CHECKING:
from langchain.callbacks.base import Callbacks
from langchain.prompts import ChatPromptTemplate


class LlamaIndexLLM(BaseRagasLLM):
def __init__(self, llm: LiLLM) -> None:
self.llama_index_llm = llm

@property
def llm(self) -> LiLLM:
return self.llama_index_llm

def generate(
self,
prompts: list[ChatPromptTemplate],
n: int = 1,
temperature: float = 0,
callbacks: t.Optional[Callbacks] = None,
) -> LLMResult:
# set temperature to 0.2 for multiple completions
temperature = 0.2 if n > 1 else 0
self.llm.temperature = temperature

# get task coroutines
tasks = []
for p in prompts:
tasks.extend([self.llm.acomplete(p.format()) for _ in range(n)])

# process results to LLMResult
        # token usage is not included for now
results = run_async_tasks(tasks)
results2D = [results[i : i + n] for i in range(0, len(results), n)]
generations = [
[Generation(text=r.text) for r in result] for result in results2D
]
return LLMResult(generations=generations)
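A rough sketch of calling the adapter directly. generate takes LangChain ChatPromptTemplates and returns a LangChain LLMResult; the prompt and model are illustrative:

from langchain.prompts import ChatPromptTemplate
from llama_index.llms import OpenAI  # assumed import path for a LlamaIndex LLM

from ragas.llms import LlamaIndexLLM

llm = LlamaIndexLLM(llm=OpenAI(model="gpt-3.5-turbo"))
prompt = ChatPromptTemplate.from_template(
    "Summarise in one line: Ragas helps evaluate RAG pipelines."
)

# n=2 issues two async acomplete() calls for the prompt and regroups the results
result = llm.generate([prompt], n=2)
print([g.text for g in result.generations[0]])  # two completions for the one prompt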
11 changes: 2 additions & 9 deletions src/ragas/metrics/base.py
@@ -6,7 +6,6 @@
"""
from __future__ import annotations

-import os
import typing as t
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
@@ -20,7 +19,7 @@
from tqdm import tqdm

from ragas.exceptions import OpenAIKeyNotFound
-from ragas.metrics.llms import LangchainLLM
+from ragas.llms import LangchainLLM, llm_factory

if t.TYPE_CHECKING:
from langchain.callbacks.base import Callbacks
@@ -109,15 +108,9 @@ def get_batches(self, dataset_size: int) -> list[range]:
return make_batches(dataset_size, self.batch_size)


-def _llm_factory() -> LangchainLLM:
-    oai_key = os.getenv("OPENAI_API_KEY", "no-key")
-    openai_llm = ChatOpenAI(openai_api_key=oai_key)
-    return LangchainLLM(llm=openai_llm)


@dataclass
class MetricWithLLM(Metric):
-    llm: LangchainLLM = field(default_factory=_llm_factory)
+    llm: LangchainLLM = field(default_factory=llm_factory)

def init_model(self):
if isinstance(self.llm, ChatOpenAI) or isinstance(self.llm, OpenAI):
6 changes: 3 additions & 3 deletions src/ragas/metrics/critique.py
@@ -8,8 +8,8 @@
from langchain.callbacks.manager import CallbackManager, trace_as_chain_group
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate

-from ragas.metrics.base import EvaluationMode, MetricWithLLM, _llm_factory
-from ragas.metrics.llms import LangchainLLM
+from ragas.llms import LangchainLLM
+from ragas.metrics.base import EvaluationMode, MetricWithLLM, llm_factory

CRITIQUE_PROMPT = HumanMessagePromptTemplate.from_template(
"""Given a input and submission. Evaluate the submission only using the given criteria.
@@ -56,7 +56,7 @@ class AspectCritique(MetricWithLLM):
strictness: int = field(default=1, repr=False)
batch_size: int = field(default=15, repr=False)
llm: LangchainLLM = field(
-        default_factory=_llm_factory,
+        default_factory=llm_factory,
repr=False,
)

2 changes: 1 addition & 1 deletion src/ragas/testset/testset_generator.py
@@ -20,7 +20,7 @@
from numpy.random import default_rng
from tqdm import tqdm

-from ragas.metrics.llms import LangchainLLM
+from ragas.llms import LangchainLLM
from ragas.testset.prompts import (
ANSWER_FORMULATE,
COMPRESS_QUESTION,
