From deded7096486f1b7616e8d11f7e9d6eac731bdd0 Mon Sep 17 00:00:00 2001 From: Jithin James Date: Tue, 6 Feb 2024 22:19:44 -0800 Subject: [PATCH] fix: raise `ExceptionInRunner` if executor faces any issues (#569) --- src/ragas/testset/docstore.py | 4 ++++ src/ragas/testset/evolutions.py | 7 ++++++- src/ragas/testset/extractor.py | 2 +- src/ragas/testset/generator.py | 8 ++++++-- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/ragas/testset/docstore.py b/src/ragas/testset/docstore.py index 6842d3d0f..d2d1c140c 100644 --- a/src/ragas/testset/docstore.py +++ b/src/ragas/testset/docstore.py @@ -15,6 +15,7 @@ from langchain_core.pydantic_v1 import Field from ragas.embeddings.base import BaseRagasEmbeddings +from ragas.exceptions import ExceptionInRunner from ragas.executor import Executor from ragas.run_config import RunConfig from ragas.testset.utils import rng @@ -245,6 +246,9 @@ def add_nodes( result_idx += 1 results = executor.results() + if results == []: + raise ExceptionInRunner() + for i, n in enumerate(nodes): if i in nodes_to_embed.keys(): n.embedding = results[nodes_to_embed[i]] diff --git a/src/ragas/testset/evolutions.py b/src/ragas/testset/evolutions.py index 5d41a4912..09802a188 100644 --- a/src/ragas/testset/evolutions.py +++ b/src/ragas/testset/evolutions.py @@ -412,7 +412,12 @@ async def _aevolve( ) # find a similar node and generate a question based on both - similar_node = self.docstore.get_similar(current_nodes.root_node)[0] + similar_node = self.docstore.get_similar(current_nodes.root_node) + if similar_node == []: + # retry + current_nodes = self.se._get_more_adjacent_nodes(current_nodes) + return await self.aretry_evolve(current_tries, current_nodes) + prompt = self.multi_context_question_prompt.format( question=simple_question, context1=current_nodes.root_node.page_content, diff --git a/src/ragas/testset/extractor.py b/src/ragas/testset/extractor.py index fcbe3e54c..77c586c2e 100644 --- a/src/ragas/testset/extractor.py +++ b/src/ragas/testset/extractor.py @@ -39,7 +39,7 @@ def save(self, cache_dir: t.Optional[str] = None) -> None: @dataclass -class keyphraseExtractor(Extractor): +class KeyphraseExtractor(Extractor): keyphrase_extraction_prompt: Prompt = field( default_factory=lambda: keyphrase_extraction_prompt ) diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index 81464a147..57c8cbfcf 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -12,6 +12,7 @@ from ragas._analytics import TesetGenerationEvent, track from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper +from ragas.exceptions import ExceptionInRunner from ragas.executor import Executor from ragas.llms import BaseRagasLLM, LangchainLLMWrapper from ragas.run_config import RunConfig @@ -25,7 +26,7 @@ reasoning, simple, ) -from ragas.testset.extractor import keyphraseExtractor +from ragas.testset.extractor import KeyphraseExtractor from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter from ragas.utils import check_if_sum_is_close, is_nan @@ -83,7 +84,7 @@ def with_openai( embeddings_model = LangchainEmbeddingsWrapper( OpenAIEmbeddings(model=embeddings) ) - keyphrase_extractor = keyphraseExtractor(llm=generator_llm_model) + keyphrase_extractor = KeyphraseExtractor(llm=generator_llm_model) if docstore is None: from langchain.text_splitter import TokenTextSplitter @@ -241,6 +242,9 @@ def generate( try: test_data_rows = exec.results() + if test_data_rows == []: + raise ExceptionInRunner() + except ValueError as e: raise e # make sure to ignore any NaNs that might have been returned