diff --git a/src/ragas/testset/testset_generator.py b/src/ragas/testset/testset_generator.py index 16f58a39c..85c9ce211 100644 --- a/src/ragas/testset/testset_generator.py +++ b/src/ragas/testset/testset_generator.py @@ -14,9 +14,10 @@ from langchain.embeddings.base import Embeddings from langchain.llms.base import BaseLLM from langchain.prompts import ChatPromptTemplate +from langchain.schema.document import Document as LangchainDocument from llama_index.indices.query.embedding_utils import get_top_k_embeddings from llama_index.node_parser.simple import SimpleNodeParser -from llama_index.readers.schema import Document +from llama_index.readers.schema import Document as LlamaindexDocument from llama_index.schema import BaseNode from numpy.random import default_rng from tqdm import tqdm @@ -276,11 +277,26 @@ def _embed_nodes(self, nodes: t.List[BaseNode]) -> t.Dict[str, t.List[float]]: return embeddings - def generate(self, documents: t.List[Document], test_size: int) -> TestDataset: + def generate( + self, + documents: list[LlamaindexDocument] | list[LangchainDocument], + test_size: int, + ) -> TestDataset: + if isinstance(documents[0], LangchainDocument): + # cast to LangchainDocument since it's the only case here + documents = t.cast(list[LangchainDocument], documents) + documents = [ + LlamaindexDocument.from_langchain_format(doc) for doc in documents + ] + elif not isinstance(documents[0], LlamaindexDocument): + raise ValueError( + "Testset Generatation only supports LlamaindexDocuments or LangchainDocuments" # noqa + ) # Convert documents into nodes node_parser = SimpleNodeParser.from_defaults( chunk_size=self.chunk_size, chunk_overlap=0, include_metadata=True ) + documents = t.cast(list[LlamaindexDocument], documents) document_nodes: t.List[BaseNode] = node_parser.get_nodes_from_documents( documents=documents ) diff --git a/src/ragas/testset/utils.py b/src/ragas/testset/utils.py index e0660bc65..bb97e5ac4 100644 --- a/src/ragas/testset/utils.py +++ 
b/src/ragas/testset/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import re import warnings