diff --git a/src/ragas/testset/testset_generator.py b/src/ragas/testset/testset_generator.py
index d568e7500..7a1d1b00b 100644
--- a/src/ragas/testset/testset_generator.py
+++ b/src/ragas/testset/testset_generator.py
@@ -1,4 +1,6 @@
+import re
 import typing as t
+import warnings
 from collections import defaultdict, namedtuple
 from dataclasses import dataclass
 
@@ -33,10 +35,10 @@
 )
 
 DEFAULT_TEST_DISTRIBUTION = {
-    "simple": 0.5,
+    "simple": 0.4,
     "reasoning": 0.2,
     "multi_context": 0.2,
-    "conditional": 0.1,
+    "conditional": 0.2,
 }
 
 question_deep_map = {
@@ -106,7 +108,7 @@ def __init__(
         critic_llm: BaseLLM | BaseChatModel,
         embeddings_model: Embeddings,
         testset_distribution: t.Optional[t.Dict[str, float]] = None,
-        chat_qa: float = 0.3,
+        chat_qa: float = 0.0,
         chunk_size: int = 1024,
         seed: int = 42,
     ) -> None:
@@ -135,7 +137,7 @@ def from_default(
        openai_generator_llm: str = "gpt-3.5-turbo-16k",
        openai_filter_llm: str = "gpt-4",
        chat_qa: float = 0.3,
-       chunk_size: int = 1024,
+       chunk_size: int = 512,
     ):
        generator_llm = ChatOpenAI(model=openai_generator_llm)
        critic_llm = ChatOpenAI(model=openai_filter_llm)
@@ -173,14 +175,12 @@ def _filter_context(self, context: str) -> bool:
         prompt = ChatPromptTemplate.from_messages([human_prompt])
         results = generate(prompts=[prompt], llm=self.critic_llm)
         output = results.generations[0][0].text.strip()
-        score = eval(output)
-        if not isinstance(score, float | int):
-            index = output.lower().find("score:")
-            if index != -1:
-                index += len("score:")
-                score = eval(output[index:])
-            else:
-                score = 0.0
+        pattern = r"^[\d.]+$"
+        if not re.match(pattern, output):
+            score = 0.0
+        else:
+            score = eval(output)
+
         return score >= self.threshold
 
     def _seed_question(self, context: str) -> str:
@@ -241,22 +241,30 @@ def _generate_context(self, question: str, text_chunk: str) -> t.List[str]:
             for qstn in question.split("\n")
         ]
 
-    def _remove_index(self, available_indices: list, node_idx: list) -> t.List:
+    def _remove_nodes(self, available_indices: list, node_idx: list) -> t.List:
         for idx in node_idx:
             available_indices.remove(idx)
         return available_indices
 
-    def _generate_doc_node_map(
+    def _generate_doc_nodes_map(
         self, documenet_nodes: t.List[BaseNode]
-    ) -> t.Dict[str, list]:
-        doc_nodeidx = defaultdict(list)
-        for idx, node in enumerate(documenet_nodes):
-            doc_nodeidx[node.id_].append(idx)
-
-        return doc_nodeidx
-
-    def _get_neighbour_node(self, idx: int, node_indices: list) -> t.List[int]:
-        return [idx - 1, idx] if idx == node_indices[-1] else [idx, idx + 1]
+    ) -> t.Dict[str, BaseNode]:
+        doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list[BaseNode])
+        for node in documenet_nodes:
+            if node.ref_doc_id:
+                doc_nodes_map[node.ref_doc_id].append(node)
+
+        return doc_nodes_map  # type: ignore
+
+    def _get_neighbour_node(
+        self, node: BaseNode, related_nodes: list[BaseNode]
+    ) -> t.List[BaseNode]:
+        if len(related_nodes) < 2:
+            warnings.warn("No neighbors exists")
+            return [node]
+        idx = related_nodes.index(node)
+        ids = [idx - 1, idx] if idx == (len(related_nodes) - 1) else [idx, idx + 1]
+        return [related_nodes[idx] for idx in ids]
 
     def _embed_nodes(self, nodes: t.List[BaseNode]) -> t.Dict[str, t.List[float]]:
         embeddings = {}
@@ -275,7 +283,6 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
         document_nodes: t.List[BaseNode] = node_parser.get_nodes_from_documents(
             documents=documents
         )
-
         # maximum 1 seed question per node
         if test_size > len(document_nodes):
             raise ValueError(
@@ -283,30 +290,31 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
                 reduce test_size or add more documents"""
             )
 
-        available_indices = np.arange(0, len(document_nodes)).tolist()
-        doc_nodeidx = self._generate_doc_node_map(document_nodes)
+        available_nodes = document_nodes
+        doc_nodes_map = self._generate_doc_nodes_map(document_nodes)
+        count_neighbours = sum(len(val) > 1 for _, val in doc_nodes_map.items())
+        if count_neighbours < len(documents) // 2:
+            warnings.warn("Most documents are too short")
+
         count = 0
         samples = []
 
         pbar = tqdm(total=test_size)
-        while count < test_size and available_indices != []:
+        while count < test_size and available_nodes != []:
             evolve_type = self._get_evolve_type()
-            node_idx = self.rng.choice(available_indices, size=1)[0]
-            available_indices = self._remove_index(available_indices, [node_idx])
+            curr_node = self.rng.choice(available_nodes, size=1)[0]
+            available_nodes = self._remove_nodes(available_nodes, [curr_node])
 
-            neighbor_nodes = doc_nodeidx[
-                document_nodes[node_idx].node_id  # type: ignore
-            ]
+            neighbor_nodes = doc_nodes_map[curr_node.source_node.node_id]
 
             # Append multiple nodes randomly to remove chunking bias
             size = self.rng.integers(1, 3)
-            node_indices = (
-                self._get_neighbour_node(node_idx, neighbor_nodes)
+            nodes = (
+                self._get_neighbour_node(curr_node, neighbor_nodes)
                 if size > 1 and evolve_type != "multi_context"
-                else [node_idx]
+                else [curr_node]
             )
 
-            nodes = [document_nodes[node_idx] for node_idx in node_indices]
             text_chunk = " ".join([node.get_content() for node in nodes])
             score = self._filter_context(text_chunk)
             if not score:
@@ -316,14 +324,13 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
             if evolve_type == "multi_context":
                 # Find most similar chunk in same document
                 node_embedding = self._embed_nodes([nodes[-1]])
-                neighbor_nodes = self._remove_index(neighbor_nodes, node_indices)
-                neighbor_emb = self._embed_nodes(
-                    [document_nodes[idx][0] for idx in neighbor_nodes]
-                )
+                neighbor_nodes = self._remove_nodes(neighbor_nodes, nodes)
+                neighbor_emb = self._embed_nodes(neighbor_nodes)
+
                 _, indices = get_top_k_embeddings(
                     list(node_embedding.values())[0],
                     list(neighbor_emb.values()),
-                    similarity_cutoff=self.threshold,
+                    similarity_cutoff=self.threshold / 10,
                 )
                 if indices:
                     best_neighbor = neighbor_nodes[indices[0]]
@@ -332,7 +339,7 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
                         context1=text_chunk,
                         context2=best_neighbor.get_content(),
                     )
-                    text_chunk = "\n".join([text_chunk, best_neighbor.get_context()])
+                    text_chunk = "\n".join([text_chunk, best_neighbor.get_content()])
                 else:
                     continue
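
For reference, a minimal standalone sketch of the score-parsing behaviour that the patched _filter_context hunk above introduces. The helper name parse_critic_score is hypothetical and not part of the patch; it only mirrors the regex-gated parsing shown in the diff.

import re

def parse_critic_score(output: str) -> float:
    # Hypothetical helper mirroring the patched logic: the critic LLM's raw
    # reply is only treated as a score when it is purely digits and dots;
    # anything else (prose, "Score: 8.5", etc.) falls back to 0.0.
    output = output.strip()
    if not re.match(r"^[\d.]+$", output):
        return 0.0
    return float(output)  # the patch itself keeps eval(), guarded by the regex

assert parse_critic_score("8.5") == 8.5
assert parse_critic_score("Score: 8.5") == 0.0     # no longer eval'd as in the old code
assert parse_critic_score("irrelevant context") == 0.0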
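Likewise, a rough sketch of the neighbour selection that the reworked _get_neighbour_node performs on node objects, written here against a plain list of strings so it runs without llama-index installed. The name pick_with_neighbour is illustrative only and does not appear in the patch.

import typing as t
import warnings

def pick_with_neighbour(item: t.Any, siblings: t.List[t.Any]) -> t.List[t.Any]:
    # Same selection rule as the patched _get_neighbour_node: return the item
    # together with the chunk that follows it, or the one before it when the
    # item is the last chunk of its document; single-chunk documents only
    # yield the item itself (with a warning, as in the patch).
    if len(siblings) < 2:
        warnings.warn("No neighbors exist")
        return [item]
    idx = siblings.index(item)
    ids = [idx - 1, idx] if idx == len(siblings) - 1 else [idx, idx + 1]
    return [siblings[i] for i in ids]

chunks = ["chunk-0", "chunk-1", "chunk-2"]
assert pick_with_neighbour("chunk-0", chunks) == ["chunk-0", "chunk-1"]
assert pick_with_neighbour("chunk-2", chunks) == ["chunk-1", "chunk-2"]  # last chunk looks left
assert pick_with_neighbour("chunk-0", ["chunk-0"]) == ["chunk-0"]        # single-chunk document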