feat: langchain documents support for TestsetGenerator (#201)

Usage

```py
from langchain.document_loaders import PubMedLoader
from ragas.testset import TestsetGenerator

loader = PubMedLoader("liver", load_max_docs=10)
docs = loader.load()
len(docs)
# 10

testsetgenerator = TestsetGenerator.from_default()
test_size = 10
testset = testsetgenerator.generate(docs, test_size=test_size)
test_df = testset.to_pandas()
test_df.head()
```
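The same generator continues to accept llama_index documents, since `generate()` now takes either document type. A minimal sketch of that path, assuming the pre-0.10 llama_index layout where `SimpleDirectoryReader` is a top-level export and a local `./data` directory of text files exists:

```py
from llama_index import SimpleDirectoryReader

from ragas.testset import TestsetGenerator

# Load llama_index Documents from a local folder (hypothetical ./data path)
docs = SimpleDirectoryReader("./data").load_data()

# generate() accepts either llama_index or langchain document lists
testsetgenerator = TestsetGenerator.from_default()
testset = testsetgenerator.generate(docs, test_size=5)
test_df = testset.to_pandas()
test_df.head()
```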
jjmachan authored Oct 19, 2023
1 parent 2d42d69 commit 793bf9a
Showing 2 changed files with 20 additions and 2 deletions.
src/ragas/testset/testset_generator.py (18 additions, 2 deletions)

```diff
@@ -14,9 +14,10 @@
 from langchain.embeddings.base import Embeddings
 from langchain.llms.base import BaseLLM
 from langchain.prompts import ChatPromptTemplate
+from langchain.schema.document import Document as LangchainDocument
 from llama_index.indices.query.embedding_utils import get_top_k_embeddings
 from llama_index.node_parser.simple import SimpleNodeParser
-from llama_index.readers.schema import Document
+from llama_index.readers.schema import Document as LlamaindexDocument
 from llama_index.schema import BaseNode
 from numpy.random import default_rng
 from tqdm import tqdm
@@ -276,11 +277,26 @@ def _embed_nodes(self, nodes: t.List[BaseNode]) -> t.Dict[str, t.List[float]]:
 
         return embeddings
 
-    def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
+    def generate(
+        self,
+        documents: list[LlamaindexDocument] | list[LangchainDocument],
+        test_size: int,
+    ) -> TestDataset:
+        if isinstance(documents[0], LangchainDocument):
+            # cast to LangchainDocument since its the only case here
+            documents = t.cast(list[LangchainDocument], documents)
+            documents = [
+                LlamaindexDocument.from_langchain_format(doc) for doc in documents
+            ]
+        elif not isinstance(documents[0], LlamaindexDocument):
+            raise ValueError(
+                "Testset Generatation only supports LlamaindexDocuments or LangchainDocuments"  # noqa
+            )
         # Convert documents into nodes
         node_parser = SimpleNodeParser.from_defaults(
             chunk_size=self.chunk_size, chunk_overlap=0, include_metadata=True
         )
+        documents = t.cast(list[LlamaindexDocument], documents)
         document_nodes: t.List[BaseNode] = node_parser.get_nodes_from_documents(
             documents=documents
         )
```
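The core of the change is the branch at the top of `generate()`: if the first document is a langchain `Document`, the whole list is converted to llama_index documents with `from_langchain_format` before node parsing; otherwise the input must already be llama_index documents. Only the first element is inspected, so the list is assumed to be homogeneous. A minimal sketch of the conversion on its own, using the same import aliases as the diff above (the sample content is made up):

```py
from langchain.schema.document import Document as LangchainDocument
from llama_index.readers.schema import Document as LlamaindexDocument

# A langchain document with sample content and metadata
lc_doc = LangchainDocument(
    page_content="The liver can regenerate lost tissue.",
    metadata={"source": "pubmed"},
)

# The same conversion generate() applies to each langchain document
li_doc = LlamaindexDocument.from_langchain_format(lc_doc)

print(li_doc.text)      # The liver can regenerate lost tissue.
print(li_doc.metadata)  # {'source': 'pubmed'}
```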
src/ragas/testset/utils.py (2 additions, 0 deletions)

```diff
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 import re
 import warnings
```
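The `from __future__ import annotations` line (PEP 563) makes the interpreter store annotations as strings instead of evaluating them, so modules can use `list[...]` and `X | Y` in annotations (the style the new `generate()` signature uses) on Python versions before 3.10. A small illustration of the effect, independent of ragas:

```py
from __future__ import annotations


# Without the future import, `list[str] | None` in this annotation would raise
# a TypeError at import time on Python 3.9 and earlier; with it, the annotation
# is never evaluated at runtime and the module imports cleanly.
def first_or_none(items: list[str] | None) -> str | None:
    return items[0] if items else None


print(first_or_none(["liver", "kidney"]))  # liver
print(first_or_none(None))                 # None
```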
