Feat/llamaindex: adding llamaindex (#999)
fixes: #557 


it's been long πŸ™‚
jjmachan authored May 30, 2024
1 parent e2c57b1 commit 0319c19
Showing 12 changed files with 718 additions and 200 deletions.
627 changes: 444 additions & 183 deletions docs/howtos/integrations/llamaindex.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/ragas/embeddings/__init__.py
Expand Up @@ -2,12 +2,14 @@
BaseRagasEmbeddings,
HuggingfaceEmbeddings,
LangchainEmbeddingsWrapper,
LlamaIndexEmbeddingsWrapper,
embedding_factory,
)

__all__ = [
"HuggingfaceEmbeddings",
"BaseRagasEmbeddings",
"LangchainEmbeddingsWrapper",
"LlamaIndexEmbeddingsWrapper",
"embedding_factory",
]
25 changes: 25 additions & 0 deletions src/ragas/embeddings/base.py
Expand Up @@ -13,6 +13,9 @@

from ragas.run_config import RunConfig, add_async_retry, add_retry

if t.TYPE_CHECKING:
from llama_index.core.base.embeddings.base import BaseEmbedding

DEFAULT_MODEL_NAME = "BAAI/bge-small-en-v1.5"


Expand Down Expand Up @@ -153,6 +156,28 @@ def predict(self, texts: List[List[str]]) -> List[List[float]]:
return predictions.tolist()


class LlamaIndexEmbeddingsWrapper(BaseRagasEmbeddings):
def __init__(
self, embeddings: BaseEmbedding, run_config: t.Optional[RunConfig] = None
):
self.embeddings = embeddings
if run_config is None:
run_config = RunConfig()
self.set_run_config(run_config)

def embed_query(self, text: str) -> t.List[float]:
return self.embeddings.get_query_embedding(text)

def embed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]:
return self.embeddings.get_text_embedding_batch(texts)

async def aembed_query(self, text: str) -> t.List[float]:
return await self.embeddings.aget_query_embedding(text)

async def aembed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]:
return await self.embeddings.aget_text_embedding_batch(texts)


def embedding_factory(
model: str = "text-embedding-ada-002", run_config: t.Optional[RunConfig] = None
) -> BaseRagasEmbeddings:
Expand Down
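For reference, a minimal usage sketch of the new wrapper (hypothetical, not part of the diff above; assumes the llama-index OpenAI embedding package is installed, and the model name is illustrative):

# Sketch only: wrap a LlamaIndex embedding so ragas metrics can call it.
# Assumes llama-index-embeddings-openai is installed; the model name is illustrative.
from llama_index.embeddings.openai import OpenAIEmbedding

from ragas.embeddings import LlamaIndexEmbeddingsWrapper

embedder = LlamaIndexEmbeddingsWrapper(OpenAIEmbedding(model="text-embedding-3-small"))
query_vector = embedder.embed_query("What is ragas?")  # single query embedding
doc_vectors = embedder.embed_documents(["doc one", "doc two"])  # batched text embeddings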
1 change: 0 additions & 1 deletion src/ragas/executor.py
Expand Up @@ -17,7 +17,6 @@


def runner_exception_hook(args: threading.ExceptHookArgs):
print(args)
raise args.exc_type


Expand Down
113 changes: 113 additions & 0 deletions src/ragas/integrations/llama_index.py
@@ -0,0 +1,113 @@
from __future__ import annotations

import logging
import typing as t
from copy import deepcopy
from uuid import uuid4

from datasets import Dataset

from ragas.embeddings import LlamaIndexEmbeddingsWrapper
from ragas.evaluation import evaluate as ragas_evaluate
from ragas.exceptions import ExceptionInRunner
from ragas.executor import Executor
from ragas.llms import LlamaIndexLLMWrapper
from ragas.validation import EVALMODE_TO_COLUMNS, validate_evaluation_modes

if t.TYPE_CHECKING:
from llama_index.core.base.embeddings.base import (
BaseEmbedding as LlamaIndexEmbeddings,
)
from llama_index.core.base.llms.base import BaseLLM as LlamaindexLLM

from ragas.evaluation import Result
from ragas.metrics.base import Metric


logger = logging.getLogger(__name__)


def validate_dataset(dataset: dict, metrics: list[Metric]):
    # relax EVALMODE_TO_COLUMNS for this use case: "answer" and "contexts"
    # are produced by the query engine, so they need not be in the dataset.
    # deepcopy and mutate the copy so the global mapping stays untouched
    evalmode_to_columns_llamaindex = deepcopy(EVALMODE_TO_COLUMNS)
    for mode in evalmode_to_columns_llamaindex:
        if "answer" in evalmode_to_columns_llamaindex[mode]:
            evalmode_to_columns_llamaindex[mode].remove("answer")
        if "contexts" in evalmode_to_columns_llamaindex[mode]:
            evalmode_to_columns_llamaindex[mode].remove("contexts")

    hf_dataset = Dataset.from_dict(dataset)
    validate_evaluation_modes(hf_dataset, metrics, evalmode_to_columns_llamaindex)


def evaluate(
query_engine,
dataset: dict,
metrics: list[Metric],
llm: t.Optional[LlamaindexLLM] = None,
embeddings: t.Optional[LlamaIndexEmbeddings] = None,
raise_exceptions: bool = True,
column_map: t.Optional[t.Dict[str, str]] = None,
) -> Result:
column_map = column_map or {}

# wrap llms and embeddings
li_llm = None
if llm is not None:
li_llm = LlamaIndexLLMWrapper(llm)
li_embeddings = None
if embeddings is not None:
li_embeddings = LlamaIndexEmbeddingsWrapper(embeddings)

# validate and transform dataset
if dataset is None:
raise ValueError("Provide dataset!")

exec = Executor(
desc="Running Query Engine",
keep_progress_bar=True,
raise_exceptions=raise_exceptions,
)

# get query
queries = dataset["question"]
for i, q in enumerate(queries):
exec.submit(query_engine.aquery, q, name=f"query-{i}")

answers: t.List[str] = []
contexts: t.List[t.List[str]] = []
try:
results = exec.results()
if results == []:
raise ExceptionInRunner()
except Exception as e:
raise e
else:
for r in results:
answers.append(r.response)
contexts.append([n.node.text for n in r.source_nodes])

# create HF dataset
hf_dataset = Dataset.from_dict(
{
"question": queries,
"contexts": contexts,
"answer": answers,
}
)
if "ground_truth" in dataset:
hf_dataset = hf_dataset.add_column(
name="ground_truth",
column=dataset["ground_truth"],
new_fingerprint=str(uuid4()),
)

results = ragas_evaluate(
dataset=hf_dataset,
metrics=metrics,
llm=li_llm,
embeddings=li_embeddings,
raise_exceptions=raise_exceptions,
)

return results
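For reference, a hedged end-to-end sketch of calling this new evaluate entry point (not part of the diff above; the corpus, question, and ground truth are placeholders, and when llm and embeddings are omitted ragas falls back to its default OpenAI-backed models):

# Sketch only: evaluating a LlamaIndex query engine with ragas metrics.
# The document, question, and ground truth below are placeholders.
from llama_index.core import Document, VectorStoreIndex

from ragas.integrations.llama_index import evaluate
from ragas.metrics import answer_relevancy, faithfulness

index = VectorStoreIndex.from_documents(
    [Document(text="Ragas is a framework for evaluating RAG pipelines.")]
)
query_engine = index.as_query_engine()

dataset = {
    "question": ["What is ragas used for?"],
    "ground_truth": ["Evaluating RAG pipelines."],
}

# answers and contexts are produced by the query engine itself
result = evaluate(
    query_engine=query_engine,
    dataset=dataset,
    metrics=[faithfulness, answer_relevancy],
)
print(result)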
8 changes: 7 additions & 1 deletion src/ragas/llms/__init__.py
@@ -1,7 +1,13 @@
-from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, llm_factory
+from ragas.llms.base import (
+    BaseRagasLLM,
+    LangchainLLMWrapper,
+    LlamaIndexLLMWrapper,
+    llm_factory,
+)
 
 __all__ = [
     "BaseRagasLLM",
     "LangchainLLMWrapper",
+    "LlamaIndexLLMWrapper",
     "llm_factory",
 ]
75 changes: 74 additions & 1 deletion src/ragas/llms/base.py
Expand Up @@ -10,7 +10,7 @@
from langchain_community.chat_models.vertexai import ChatVertexAI
from langchain_community.llms import VertexAI
from langchain_core.language_models import BaseLanguageModel
from langchain_core.outputs import LLMResult
from langchain_core.outputs import Generation, LLMResult
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain_openai.llms import AzureOpenAI, OpenAI
from langchain_openai.llms.base import BaseOpenAI
Expand All @@ -19,6 +19,7 @@

if t.TYPE_CHECKING:
from langchain_core.callbacks import Callbacks
from llama_index.core.base.llms.base import BaseLLM

from ragas.llms.prompt import PromptValue

Expand Down Expand Up @@ -203,6 +204,78 @@ def set_run_config(self, run_config: RunConfig):
self.run_config.exception_types = RateLimitError


class LlamaIndexLLMWrapper(BaseRagasLLM):
"""
A Adaptor for LlamaIndex LLMs
"""

def __init__(
self,
llm: BaseLLM,
run_config: t.Optional[RunConfig] = None,
):
self.llm = llm

self._signature = ""
if type(self.llm).__name__.lower() == "bedrock":
self._signature = "bedrock"
if run_config is None:
run_config = RunConfig()
self.set_run_config(run_config)

def check_args(
self,
n: int,
temperature: float,
stop: t.Optional[t.List[str]],
callbacks: Callbacks,
) -> dict[str, t.Any]:
if n != 1:
logger.warning("n values greater than 1 not support for LlamaIndex LLMs")
if temperature != 1e-8:
logger.info("temperature kwarg passed to LlamaIndex LLM")
if stop is not None:
logger.info("stop kwarg passed to LlamaIndex LLM")
if callbacks is not None:
logger.info(
"callbacks not supported for LlamaIndex LLMs, ignoring callbacks"
)
if self._signature == "bedrock":
return {"temperature": temperature}
else:
return {
"n": n,
"temperature": temperature,
"stop": stop,
}

def generate_text(
self,
prompt: PromptValue,
n: int = 1,
temperature: float = 1e-8,
stop: t.Optional[t.List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
kwargs = self.check_args(n, temperature, stop, callbacks)
li_response = self.llm.complete(prompt.to_string(), **kwargs)

return LLMResult(generations=[[Generation(text=li_response.text)]])

async def agenerate_text(
self,
prompt: PromptValue,
n: int = 1,
temperature: float = 1e-8,
stop: t.Optional[t.List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
kwargs = self.check_args(n, temperature, stop, callbacks)
li_response = await self.llm.acomplete(prompt.to_string(), **kwargs)

return LLMResult(generations=[[Generation(text=li_response.text)]])


def llm_factory(
model: str = "gpt-3.5-turbo", run_config: t.Optional[RunConfig] = None
) -> BaseRagasLLM:
Expand Down
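For reference, a small hypothetical sketch of the new LLM wrapper (not part of the diff above; assumes llama-index-llms-openai is installed, and the model name is illustrative). The wrapped object satisfies BaseRagasLLM and can be passed wherever ragas expects one, e.g. the llm argument of ragas.evaluate:

# Sketch only: adapting a LlamaIndex LLM to the ragas LLM interface.
# Assumes llama-index-llms-openai is installed; the model name is illustrative.
from llama_index.llms.openai import OpenAI as LlamaIndexOpenAI

from ragas.llms import LlamaIndexLLMWrapper

ragas_llm = LlamaIndexLLMWrapper(LlamaIndexOpenAI(model="gpt-3.5-turbo"))
# ragas_llm can now back any ragas metric or evaluation run.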
6 changes: 3 additions & 3 deletions src/ragas/llms/prompt.py
Expand Up @@ -44,12 +44,12 @@ class Prompt(BaseModel):
language (str): The language of the prompt (default: "english").
"""

name: str
name: str = ""
instruction: str
output_format_instruction: str = ""
examples: t.List[Example] = []
input_keys: t.List[str]
output_key: str
input_keys: t.List[str] = [""]
output_key: str = ""
output_type: t.Literal["json", "str"] = "json"
language: str = "english"

Expand Down
1 change: 0 additions & 1 deletion src/ragas/metrics/_context_relevancy.py
Expand Up @@ -60,7 +60,6 @@ def _compute_score(self, response: str, row: t.Dict) -> float:
if response.lower() != "insufficient information."
else []
)
# print(len(indices))
if len(context_sents) == 0:
return 0
else:
Expand Down
52 changes: 45 additions & 7 deletions src/ragas/testset/generator.py
Expand Up @@ -7,16 +7,18 @@

import pandas as pd
from datasets import Dataset
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings

from ragas._analytics import TestsetGenerationEvent, track
from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
from ragas.embeddings.base import (
BaseRagasEmbeddings,
LangchainEmbeddingsWrapper,
LlamaIndexEmbeddingsWrapper,
)
from ragas.exceptions import ExceptionInRunner
from ragas.executor import Executor
from ragas.llms import BaseRagasLLM, LangchainLLMWrapper
from ragas.llms import BaseRagasLLM, LangchainLLMWrapper, LlamaIndexLLMWrapper
from ragas.run_config import RunConfig
from ragas.testset.docstore import Document, DocumentStore, InMemoryDocumentStore
from ragas.testset.evolutions import (
Expand All @@ -34,6 +36,12 @@

if t.TYPE_CHECKING:
from langchain_core.documents import Document as LCDocument
from langchain_core.embeddings import Embeddings as LangchainEmbeddings
from langchain_core.language_models import BaseLanguageModel as LangchainLLM
from llama_index.core.base.embeddings.base import (
BaseEmbedding as LlamaIndexEmbeddings,
)
from llama_index.core.base.llms.base import BaseLLM as LlamaindexLLM
from llama_index.core.schema import Document as LlamaindexDocument

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -75,9 +83,9 @@ class TestsetGenerator:
@classmethod
def from_langchain(
cls,
generator_llm: BaseLanguageModel,
critic_llm: BaseLanguageModel,
embeddings: Embeddings,
generator_llm: LangchainLLM,
critic_llm: LangchainLLM,
embeddings: LangchainEmbeddings,
docstore: t.Optional[DocumentStore] = None,
run_config: t.Optional[RunConfig] = None,
chunk_size: int = 1024,
Expand All @@ -104,6 +112,36 @@ def from_langchain(
docstore=docstore,
)

@classmethod
def from_llama_index(
cls,
generator_llm: LlamaindexLLM,
critic_llm: LlamaindexLLM,
embeddings: LlamaIndexEmbeddings,
docstore: t.Optional[DocumentStore] = None,
run_config: t.Optional[RunConfig] = None,
) -> "TestsetGenerator":
generator_llm_model = LlamaIndexLLMWrapper(generator_llm)
critic_llm_model = LlamaIndexLLMWrapper(critic_llm)
embeddings_model = LlamaIndexEmbeddingsWrapper(embeddings)
keyphrase_extractor = KeyphraseExtractor(llm=generator_llm_model)
if docstore is None:
from langchain.text_splitter import TokenTextSplitter

splitter = TokenTextSplitter(chunk_size=1024, chunk_overlap=0)
docstore = InMemoryDocumentStore(
splitter=splitter,
embeddings=embeddings_model,
extractor=keyphrase_extractor,
run_config=run_config,
)
return cls(
generator_llm=generator_llm_model,
critic_llm=critic_llm_model,
embeddings=embeddings_model,
docstore=docstore,
)

@classmethod
@deprecated("0.1.4", removal="0.2.0", alternative="from_langchain")
def with_openai(
Expand Down
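For reference, a hedged sketch of the new constructor (not part of the diff above; model names are illustrative, and generate_with_llamaindex_docs is the pre-existing generation method on TestsetGenerator, not shown in this diff):

# Sketch only: building a TestsetGenerator directly from LlamaIndex components.
# Model names are illustrative; "documents" is a list of llama_index Documents.
from llama_index.core import Document
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

from ragas.testset.generator import TestsetGenerator

generator = TestsetGenerator.from_llama_index(
    generator_llm=OpenAI(model="gpt-3.5-turbo"),
    critic_llm=OpenAI(model="gpt-4"),
    embeddings=OpenAIEmbedding(),
)

documents = [Document(text="Ragas is a framework for evaluating RAG pipelines.")]
# generate_with_llamaindex_docs is the existing entry point for llama_index
# documents (assumed here; it lives elsewhere in this file).
testset = generator.generate_with_llamaindex_docs(documents, test_size=5)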