diff --git a/docs/extra/components/choose_generator_llm.md b/docs/extra/components/choose_generator_llm.md
index e971dd8bf..504739444 100644
--- a/docs/extra/components/choose_generator_llm.md
+++ b/docs/extra/components/choose_generator_llm.md
@@ -16,6 +16,7 @@
 
 ```python
 from ragas.llms import LangchainLLMWrapper
+from ragas.embeddings import LangchainEmbeddingsWrapper
 from langchain_openai import ChatOpenAI
 from langchain_openai import OpenAIEmbeddings
 generator_llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o"))
diff --git a/src/ragas/testset/transforms/base.py b/src/ragas/testset/transforms/base.py
index 3c1892c81..49945e482 100644
--- a/src/ragas/testset/transforms/base.py
+++ b/src/ragas/testset/transforms/base.py
@@ -3,10 +3,15 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 
+import tiktoken
+from tiktoken.core import Encoding
+
 from ragas.llms import BaseRagasLLM, llm_factory
 from ragas.prompt import PromptMixin
 from ragas.testset.graph import KnowledgeGraph, Node, Relationship
 
+DEFAULT_TOKENIZER = tiktoken.get_encoding("o200k_base")
+
 logger = logging.getLogger(__name__)
 
 
@@ -188,6 +193,21 @@ async def apply_extract(node: Node):
 class LLMBasedExtractor(Extractor, PromptMixin):
     llm: BaseRagasLLM = field(default_factory=llm_factory)
     merge_if_possible: bool = True
+    max_token_limit: int = 32000
+    tokenizer: Encoding = DEFAULT_TOKENIZER
+
+    def split_text_by_token_limit(self, text: str, max_token_limit: int) -> t.List[str]:
+
+        # Tokenize the entire input string
+        tokens = self.tokenizer.encode(text)
+
+        # Split tokens into chunks of max_token_limit or fewer
+        chunks = []
+        for i in range(0, len(tokens), max_token_limit):
+            chunk_tokens = tokens[i : i + max_token_limit]
+            chunks.append(self.tokenizer.decode(chunk_tokens))
+
+        return chunks
 
 
 class Splitter(BaseGraphTransformation):
diff --git a/src/ragas/testset/transforms/extractors/llm_based.py b/src/ragas/testset/transforms/extractors/llm_based.py
index 04616daa1..21ec066a2 100644
--- a/src/ragas/testset/transforms/extractors/llm_based.py
+++ b/src/ragas/testset/transforms/extractors/llm_based.py
@@ -114,7 +114,9 @@ class HeadlinesExtractorPrompt(PydanticPrompt[TextWithExtractionLimit, Headlines
                     "Introduction",
                     "Main Concepts",
                     "Detailed Analysis",
+                    "Subsection: Specialized Techniques",
                     "Future Directions",
+                    "Conclusion",
                 ],
             ),
         ),
@@ -174,14 +176,15 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, None
-        result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        result = await self.prompt.generate(self.llm, data=StringIO(text=chunks[0]))
         return self.property_name, result.text
 
 
 @dataclass
 class KeyphrasesExtractor(LLMBasedExtractor):
     """
-    Extracts top 5 keyphrases from the given text.
+    Extracts top keyphrases from the given text.
 
     Attributes
     ----------
@@ -199,10 +202,15 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, None
-        result = await self.prompt.generate(
-            self.llm, data=TextWithExtractionLimit(text=node_text, max_num=self.max_num)
-        )
-        return self.property_name, result.keyphrases
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        keyphrases = []
+        for chunk in chunks:
+            result = await self.prompt.generate(
+                self.llm, data=TextWithExtractionLimit(text=chunk, max_num=self.max_num)
+            )
+            keyphrases.extend(result.keyphrases)
+        return self.property_name, keyphrases
+
 
 
 @dataclass
@@ -225,7 +233,8 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, None
-        result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        result = await self.prompt.generate(self.llm, data=StringIO(text=chunks[0]))
         return self.property_name, result.text
 
 
@@ -250,12 +259,15 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, None
-        result = await self.prompt.generate(
-            self.llm, data=TextWithExtractionLimit(text=node_text, max_num=self.max_num)
-        )
-        if result is None:
-            return self.property_name, None
-        return self.property_name, result.headlines
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        headlines = []
+        for chunk in chunks:
+            result = await self.prompt.generate(
+                self.llm, data=TextWithExtractionLimit(text=chunk, max_num=self.max_num)
+            )
+            if result:
+                headlines.extend(result.headlines)
+        return self.property_name, headlines
 
 
 @dataclass
@@ -279,11 +291,15 @@ async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, []
-        result = await self.prompt.generate(
-            self.llm,
-            data=TextWithExtractionLimit(text=node_text, max_num=self.max_num_entities),
-        )
-        return self.property_name, result.entities
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        entities = []
+        for chunk in chunks:
+            result = await self.prompt.generate(
+                self.llm,
+                data=TextWithExtractionLimit(text=chunk, max_num=self.max_num_entities),
+            )
+            entities.extend(result.entities)
+        return self.property_name, entities
 
 
 class TopicDescription(BaseModel):
@@ -328,7 +344,8 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, None
-        result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        result = await self.prompt.generate(self.llm, data=StringIO(text=chunks[0]))
         return self.property_name, result.description
 
 
@@ -383,8 +400,13 @@ async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]:
         node_text = node.get_property("page_content")
         if node_text is None:
             return self.property_name, []
-        result = await self.prompt.generate(
-            self.llm,
-            data=TextWithExtractionLimit(text=node_text, max_num=self.max_num_themes),
-        )
-        return self.property_name, result.output
+        chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
+        themes = []
+        for chunk in chunks:
+            result = await self.prompt.generate(
+                self.llm,
+                data=TextWithExtractionLimit(text=chunk, max_num=self.max_num_themes),
+            )
+            themes.extend(result.output)
+
+        return self.property_name, themes
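
Aside for reviewers: the chunking round-trip that `split_text_by_token_limit` performs (encode, slice, decode) can be sanity-checked outside the pipeline. Below is a minimal standalone sketch of the same logic under that assumption; `sample_text` and the deliberately tiny 5-token limit are illustrative inputs, not part of this change.

```python
import tiktoken
from typing import List

# Same encoding the patch pins as DEFAULT_TOKENIZER in transforms/base.py.
tokenizer = tiktoken.get_encoding("o200k_base")


def split_text_by_token_limit(text: str, max_token_limit: int) -> List[str]:
    # Encode once, slice the token ids, decode each slice back to text.
    tokens = tokenizer.encode(text)
    return [
        tokenizer.decode(tokens[i : i + max_token_limit])
        for i in range(0, len(tokens), max_token_limit)
    ]


# Hypothetical input; a 5-token limit is unrealistically small and is used
# here only to force multiple chunks.
sample_text = "Extractors now split long node content into chunks before prompting."
for chunk in split_text_by_token_limit(sample_text, max_token_limit=5):
    print(len(tokenizer.encode(chunk)), repr(chunk))
```

Note that re-encoding a decoded chunk is not guaranteed to produce exactly `max_token_limit` tokens at chunk boundaries, which is harmless here since the default limit (32000) sits well below the context windows of the models targeted.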