Commit 29f70cf

fixes: handle long context extraction (#1680)

shahules786 authored Nov 19, 2024
1 parent c729d08 commit 29f70cf
Showing 3 changed files with 67 additions and 24 deletions.
1 change: 1 addition & 0 deletions docs/extra/components/choose_generator_llm.md
@@ -16,6 +16,7 @@

```python
from ragas.llms import LangchainLLMWrapper
from ragas.embeddings import LangchainEmbeddingsWrapper
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
generator_llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o"))
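# Further down, this doc presumably uses the newly imported embeddings
# wrapper roughly like this (a hedged sketch; the variable name is an
# assumption, not part of this diff):
generator_embeddings = LangchainEmbeddingsWrapper(OpenAIEmbeddings())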
20 changes: 20 additions & 0 deletions src/ragas/testset/transforms/base.py
@@ -3,10 +3,15 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

import tiktoken
from tiktoken.core import Encoding

from ragas.llms import BaseRagasLLM, llm_factory
from ragas.prompt import PromptMixin
from ragas.testset.graph import KnowledgeGraph, Node, Relationship

DEFAULT_TOKENIZER = tiktoken.get_encoding("o200k_base")

logger = logging.getLogger(__name__)


@@ -188,6 +193,21 @@ async def apply_extract(node: Node):
class LLMBasedExtractor(Extractor, PromptMixin):
llm: BaseRagasLLM = field(default_factory=llm_factory)
merge_if_possible: bool = True
max_token_limit: int = 32000
tokenizer: Encoding = DEFAULT_TOKENIZER

def split_text_by_token_limit(self, text, max_token_limit):

# Tokenize the entire input string
tokens = self.tokenizer.encode(text)

# Split tokens into chunks of max_token_limit or less
chunks = []
for i in range(0, len(tokens), max_token_limit):
chunk_tokens = tokens[i : i + max_token_limit]
chunks.append(self.tokenizer.decode(chunk_tokens))

return chunks


class Splitter(BaseGraphTransformation):
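
To see what the new helper does in isolation, here is a minimal standalone sketch of the same token-window chunking, assuming only that `tiktoken` is installed; the sample text and printout are illustrative, not part of the commit:

```python
import tiktoken

# Same encoding the commit registers as DEFAULT_TOKENIZER.
tokenizer = tiktoken.get_encoding("o200k_base")

def split_text_by_token_limit(text: str, max_token_limit: int) -> list[str]:
    # Tokenize once, slice the token stream into windows of at most
    # max_token_limit tokens, and decode each window back to text.
    tokens = tokenizer.encode(text)
    return [
        tokenizer.decode(tokens[i : i + max_token_limit])
        for i in range(0, len(tokens), max_token_limit)
    ]

chunks = split_text_by_token_limit("long document " * 50_000, 32_000)
# Each chunk should re-encode to roughly max_token_limit tokens or
# fewer; the final chunk holds the remainder.
print(len(chunks), max(len(tokenizer.encode(c)) for c in chunks))
```

One caveat worth knowing: decoding a token slice and re-encoding it is not guaranteed to reproduce the identical token count, so the limit is effectively approximate at chunk boundaries.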
70 changes: 46 additions & 24 deletions src/ragas/testset/transforms/extractors/llm_based.py
@@ -114,7 +114,9 @@ class HeadlinesExtractorPrompt(PydanticPrompt[TextWithExtractionLimit, Headlines
"Introduction",
"Main Concepts",
"Detailed Analysis",
"Subsection: Specialized Techniques"
"Future Directions",
"Conclusion",
],
),
),
@@ -174,14 +176,15 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
node_text = node.get_property("page_content")
if node_text is None:
return self.property_name, None
result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
result = await self.prompt.generate(self.llm, data=StringIO(text=chunks[0]))
return self.property_name, result.text


@dataclass
class KeyphrasesExtractor(LLMBasedExtractor):
"""
Extracts top 5 keyphrases from the given text.
Extracts top keyphrases from the given text.
Attributes
----------
@@ -199,10 +202,15 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
node_text = node.get_property("page_content")
if node_text is None:
return self.property_name, None
result = await self.prompt.generate(
self.llm, data=TextWithExtractionLimit(text=node_text, max_num=self.max_num)
)
return self.property_name, result.keyphrases
chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
keyphrases = []
for chunk in chunks:
result = await self.prompt.generate(
self.llm, data=TextWithExtractionLimit(text=chunk, max_num=self.max_num)
)
keyphrases.extend(result.keyphrases)
return self.property_name, keyphrases
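
As a usage illustration of the new aggregation path, this hedged sketch builds a document node by hand and runs the extractor on it; the sample text, the explicit `max_token_limit`, and driving the coroutine with `asyncio.run` are assumptions, not part of the diff:

```python
import asyncio

from ragas.testset.graph import Node, NodeType
from ragas.testset.transforms.extractors.llm_based import KeyphrasesExtractor

# With no llm argument the dataclass falls back to llm_factory(),
# which expects a configured OpenAI API key.
extractor = KeyphrasesExtractor(max_token_limit=32_000)

node = Node(
    type=NodeType.DOCUMENT,
    properties={"page_content": "a very long document ..."},
)

# Text over max_token_limit is now split into chunks, and the keyphrases
# from every chunk are merged into a single list.
name, keyphrases = asyncio.run(extractor.extract(node))
print(name, keyphrases)
```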



@dataclass
@@ -225,7 +233,8 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
node_text = node.get_property("page_content")
if node_text is None:
return self.property_name, None
result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
result = await self.prompt.generate(self.llm, data=StringIO(text=chunks[0]))
return self.property_name, result.text


@@ -250,12 +259,15 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
node_text = node.get_property("page_content")
if node_text is None:
return self.property_name, None
result = await self.prompt.generate(
self.llm, data=TextWithExtractionLimit(text=node_text, max_num=self.max_num)
)
if result is None:
return self.property_name, None
return self.property_name, result.headlines
chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
headlines = []
for chunk in chunks:
result = await self.prompt.generate(
self.llm, data=TextWithExtractionLimit(text=chunk, max_num=self.max_num)
)
if result:
headlines.extend(result.headlines)
return self.property_name, headlines


@dataclass
@@ -279,11 +291,15 @@ async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]:
node_text = node.get_property("page_content")
if node_text is None:
return self.property_name, []
result = await self.prompt.generate(
self.llm,
data=TextWithExtractionLimit(text=node_text, max_num=self.max_num_entities),
)
return self.property_name, result.entities
chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
entities = []
for chunk in chunks:
result = await self.prompt.generate(
self.llm,
data=TextWithExtractionLimit(text=chunk, max_num=self.max_num_entities),
)
entities.extend(result.entities)
return self.property_name, entities


class TopicDescription(BaseModel):
@@ -328,7 +344,8 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
node_text = node.get_property("page_content")
if node_text is None:
return self.property_name, None
result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
result = await self.prompt.generate(self.llm, data=StringIO(text=chunks[0]))
return self.property_name, result.description


@@ -383,8 +400,13 @@ async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]:
node_text = node.get_property("page_content")
if node_text is None:
return self.property_name, []
result = await self.prompt.generate(
self.llm,
data=TextWithExtractionLimit(text=node_text, max_num=self.max_num_themes),
)
return self.property_name, result.output
chunks = self.split_text_by_token_limit(node_text, self.max_token_limit)
themes = []
for chunk in chunks:
result = await self.prompt.generate(
self.llm,
data=TextWithExtractionLimit(text=chunk, max_num=self.max_num_themes),
)
themes.extend(result.output)

return self.property_name, themes
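A design note on the two patterns visible in this diff: extractors whose prompt returns a single value (the hunks ending in `result.text` or `result.description`) run the prompt only on the first chunk, while extractors that return lists (keyphrases, headlines, entities, themes) run the prompt on every chunk and concatenate the results.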
