Skip to content

Commit

Permalink
testset generation: bug fixes (#185)
Browse files Browse the repository at this point in the history
Fixes 

- [x] issues with multi-context question generation  
- [x] Error in doc filtering
  • Loading branch information
shahules786 authored Oct 16, 2023
1 parent 6787a5c commit 0984435
Showing 1 changed file with 49 additions and 42 deletions.
91 changes: 49 additions & 42 deletions src/ragas/testset/testset_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import re
import typing as t
import warnings
from collections import defaultdict, namedtuple
from dataclasses import dataclass

Expand Down Expand Up @@ -33,10 +35,10 @@
)

DEFAULT_TEST_DISTRIBUTION = {
"simple": 0.5,
"simple": 0.4,
"reasoning": 0.2,
"multi_context": 0.2,
"conditional": 0.1,
"conditional": 0.2,
}

question_deep_map = {
Expand Down Expand Up @@ -106,7 +108,7 @@ def __init__(
critic_llm: BaseLLM | BaseChatModel,
embeddings_model: Embeddings,
testset_distribution: t.Optional[t.Dict[str, float]] = None,
chat_qa: float = 0.3,
chat_qa: float = 0.0,
chunk_size: int = 1024,
seed: int = 42,
) -> None:
Expand Down Expand Up @@ -135,7 +137,7 @@ def from_default(
openai_generator_llm: str = "gpt-3.5-turbo-16k",
openai_filter_llm: str = "gpt-4",
chat_qa: float = 0.3,
chunk_size: int = 1024,
chunk_size: int = 512,
):
generator_llm = ChatOpenAI(model=openai_generator_llm)
critic_llm = ChatOpenAI(model=openai_filter_llm)
Expand Down Expand Up @@ -173,14 +175,12 @@ def _filter_context(self, context: str) -> bool:
prompt = ChatPromptTemplate.from_messages([human_prompt])
results = generate(prompts=[prompt], llm=self.critic_llm)
output = results.generations[0][0].text.strip()
score = eval(output)
if not isinstance(score, float | int):
index = output.lower().find("score:")
if index != -1:
index += len("score:")
score = eval(output[index:])
else:
score = 0.0
pattern = r"^[\d.]+$"
if not re.match(pattern, output):
score = 0.0
else:
score = eval(output)

return score >= self.threshold

def _seed_question(self, context: str) -> str:
Expand Down Expand Up @@ -241,22 +241,30 @@ def _generate_context(self, question: str, text_chunk: str) -> t.List[str]:
for qstn in question.split("\n")
]

def _remove_index(self, available_indices: list, node_idx: list) -> t.List:
def _remove_nodes(self, available_indices: list, node_idx: list) -> t.List:
for idx in node_idx:
available_indices.remove(idx)
return available_indices

def _generate_doc_node_map(
def _generate_doc_nodes_map(
self, documenet_nodes: t.List[BaseNode]
) -> t.Dict[str, list]:
doc_nodeidx = defaultdict(list)
for idx, node in enumerate(documenet_nodes):
doc_nodeidx[node.id_].append(idx)

return doc_nodeidx

def _get_neighbour_node(self, idx: int, node_indices: list) -> t.List[int]:
return [idx - 1, idx] if idx == node_indices[-1] else [idx, idx + 1]
) -> t.Dict[str, BaseNode]:
doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list[BaseNode])
for node in documenet_nodes:
if node.ref_doc_id:
doc_nodes_map[node.ref_doc_id].append(node)

return doc_nodes_map # type: ignore

def _get_neighbour_node(
self, node: BaseNode, related_nodes: list[BaseNode]
) -> t.List[BaseNode]:
if len(related_nodes) < 2:
warnings.warn("No neighbors exists")
return [node]
idx = related_nodes.index(node)
ids = [idx - 1, idx] if idx == (len(related_nodes) - 1) else [idx, idx + 1]
return [related_nodes[idx] for idx in ids]

def _embed_nodes(self, nodes: t.List[BaseNode]) -> t.Dict[str, t.List[float]]:
embeddings = {}
Expand All @@ -275,38 +283,38 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
document_nodes: t.List[BaseNode] = node_parser.get_nodes_from_documents(
documents=documents
)

# maximum 1 seed question per node
if test_size > len(document_nodes):
raise ValueError(
"""Maximum possible number of samples exceeded,
reduce test_size or add more documents"""
)

available_indices = np.arange(0, len(document_nodes)).tolist()
doc_nodeidx = self._generate_doc_node_map(document_nodes)
available_nodes = document_nodes
doc_nodes_map = self._generate_doc_nodes_map(document_nodes)
count_neighbours = sum(len(val) > 1 for _, val in doc_nodes_map.items())
if count_neighbours < len(documents) // 2:
warnings.warn("Most documents are too short")

count = 0
samples = []

pbar = tqdm(total=test_size)
while count < test_size and available_indices != []:
while count < test_size and available_nodes != []:
evolve_type = self._get_evolve_type()
node_idx = self.rng.choice(available_indices, size=1)[0]
available_indices = self._remove_index(available_indices, [node_idx])
curr_node = self.rng.choice(available_nodes, size=1)[0]
available_nodes = self._remove_nodes(available_nodes, [curr_node])

neighbor_nodes = doc_nodeidx[
document_nodes[node_idx].node_id # type: ignore
]
neighbor_nodes = doc_nodes_map[curr_node.source_node.node_id]

# Append multiple nodes randomly to remove chunking bias
size = self.rng.integers(1, 3)
node_indices = (
self._get_neighbour_node(node_idx, neighbor_nodes)
nodes = (
self._get_neighbour_node(curr_node, neighbor_nodes)
if size > 1 and evolve_type != "multi_context"
else [node_idx]
else [curr_node]
)

nodes = [document_nodes[node_idx] for node_idx in node_indices]
text_chunk = " ".join([node.get_content() for node in nodes])
score = self._filter_context(text_chunk)
if not score:
Expand All @@ -316,14 +324,13 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
if evolve_type == "multi_context":
# Find most similar chunk in same document
node_embedding = self._embed_nodes([nodes[-1]])
neighbor_nodes = self._remove_index(neighbor_nodes, node_indices)
neighbor_emb = self._embed_nodes(
[document_nodes[idx][0] for idx in neighbor_nodes]
)
neighbor_nodes = self._remove_nodes(neighbor_nodes, nodes)
neighbor_emb = self._embed_nodes(neighbor_nodes)

_, indices = get_top_k_embeddings(
list(node_embedding.values())[0],
list(neighbor_emb.values()),
similarity_cutoff=self.threshold,
similarity_cutoff=self.threshold / 10,
)
if indices:
best_neighbor = neighbor_nodes[indices[0]]
Expand All @@ -332,7 +339,7 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
context1=text_chunk,
context2=best_neighbor.get_content(),
)
text_chunk = "\n".join([text_chunk, best_neighbor.get_context()])
text_chunk = "\n".join([text_chunk, best_neighbor.get_content()])
else:
continue

Expand Down

0 comments on commit 0984435

Please sign in to comment.