Skip to content

Commit 02e7a46

Browse files
feat: automatic persona generation (#1618)
Adds persona to diversify test set generation ```python from ragas.testset.persona import PersonaList persona_list = await PersonaList.from_kg(llm=generator_llm, kg=kg) ``` --------- Co-authored-by: Jithin James <[email protected]>
1 parent 3a6fab9 commit 02e7a46

File tree

4 files changed

+199
-1
lines changed

4 files changed

+199
-1
lines changed

src/ragas/testset/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
from ragas.testset.synthesizers.generate import TestsetGenerator
22
from ragas.testset.synthesizers.testset_schema import Testset, TestsetSample
33

4-
__all__ = ["TestsetGenerator", "Testset", "TestsetSample"]
4+
__all__ = [
5+
"TestsetGenerator",
6+
"Testset",
7+
"TestsetSample",
8+
]

src/ragas/testset/persona.py

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import logging
2+
import random
3+
import typing as t
4+
5+
import numpy as np
6+
from langchain_core.callbacks import Callbacks
7+
from pydantic import BaseModel
8+
9+
from ragas.executor import run_async_batch
10+
from ragas.llms.base import BaseRagasLLM
11+
from ragas.prompt import PydanticPrompt, StringIO
12+
from ragas.testset.graph import KnowledgeGraph, Node
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
def default_filter(node: Node) -> bool:
    """Randomly keep roughly 25% of document nodes that have a summary embedding.

    A node is a candidate only if it is a DOCUMENT node whose
    ``summary_embedding`` property is present; candidates are then sampled
    with probability 0.25.
    """
    is_candidate = (
        node.type.name == "DOCUMENT"
        and node.properties.get("summary_embedding") is not None
    )
    # Short-circuit: random.random() is only drawn for candidate nodes,
    # exactly as in a plain if/else formulation.
    return is_candidate and random.random() < 0.25
25+
26+
27+
class Persona(BaseModel):
    """A synthetic user profile used to diversify test-set generation."""

    # Display name of the persona; PersonaList.__getitem__ uses it as the lookup key.
    name: str
    # Concise description of who the persona is and what they do.
    role_description: str
30+
31+
32+
class PersonaGenerationPrompt(PydanticPrompt[StringIO, Persona]):
    """Prompt that turns a document summary (``StringIO``) into a single ``Persona``."""

    instruction: str = (
        "Using the provided summary, generate a single persona who would likely "
        "interact with or benefit from the content. Include a unique name and a "
        "concise role description of who they are."
    )
    input_model: t.Type[StringIO] = StringIO
    output_model: t.Type[Persona] = Persona
    # One-shot example pairing an input summary with the expected persona output.
    examples: t.List[t.Tuple[StringIO, Persona]] = [
        (
            StringIO(
                text="Guide to Digital Marketing explains strategies for engaging audiences across various online platforms."
            ),
            Persona(
                name="Digital Marketing Specialist",
                role_description="Focuses on engaging audiences and growing the brand online.",
            ),
        )
    ]
51+
52+
53+
class PersonaList(BaseModel):
    """Container of generated personas, addressable by persona name."""

    personas: t.List[Persona]

    def __getitem__(self, key: str) -> Persona:
        """Return the persona whose ``name`` equals *key*; raise ``KeyError`` otherwise."""
        match = next((p for p in self.personas if p.name == key), None)
        if match is None:
            raise KeyError(f"No persona found with name '{key}'")
        return match
61+
62+
63+
def generate_personas_from_kg(
    kg: KnowledgeGraph,
    llm: BaseRagasLLM,
    persona_generation_prompt: PersonaGenerationPrompt = PersonaGenerationPrompt(),
    num_personas: int = 3,
    filter_fn: t.Callable[[Node], bool] = default_filter,
    callbacks: Callbacks = [],
) -> t.List[Persona]:
    """
    Generate personas from a knowledge graph based on clusters of similar
    document summaries.

    Parameters
    ----------
    kg : KnowledgeGraph
        The knowledge graph to generate personas from.
    llm : BaseRagasLLM
        The LLM to use for generating the personas.
    persona_generation_prompt : PersonaGenerationPrompt
        The prompt to use for generating each persona.
    num_personas : int
        The number of personas to generate.
    filter_fn : Callable[[Node], bool]
        A function to select candidate nodes in the knowledge graph.
    callbacks : Callbacks
        The callbacks to use for the generation process (read-only here;
        forwarded to the prompt calls).

    Returns
    -------
    t.List[Persona]
        The list of generated personas. Empty if no node passes the filter.
    """
    # Keep only nodes that pass the filter AND carry both a string summary and
    # an embedding, so that summaries[i] and embeddings[i] stay aligned.
    # (Filtering the two lists independently could pair a summary with another
    # node's embedding.)
    candidate_nodes = [
        node
        for node in kg.nodes
        if filter_fn(node)
        and isinstance(node.properties.get("summary"), str)
        and node.properties.get("summary_embedding") is not None
    ]
    if not candidate_nodes:
        logger.warning(
            "No nodes with both a summary and a summary_embedding passed the "
            "filter; returning no personas."
        )
        return []

    summaries = [node.properties["summary"] for node in candidate_nodes]
    embeddings = np.array(
        [node.properties["summary_embedding"] for node in candidate_nodes]
    )

    # Row-normalize so the dot product is a true cosine similarity; guard
    # against zero-norm embeddings to avoid division by zero.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    unit = embeddings / norms
    cosine_similarities = np.dot(unit, unit.T)

    # Greedy single-pass clustering: each unvisited summary seeds a group and
    # absorbs every later, still-unvisited summary above the threshold. The
    # `j not in visited` check prevents a summary from joining two groups.
    threshold = 0.75
    groups: t.List[t.List[int]] = []
    visited: t.Set[int] = set()
    for i in range(len(summaries)):
        if i in visited:
            continue
        group = [i]
        visited.add(i)
        for j in range(i + 1, len(summaries)):
            if j not in visited and cosine_similarities[i, j] > threshold:
                group.append(j)
                visited.add(j)
        groups.append(group)

    # Use the longest summary of each group as its representative.
    top_summaries = [max((summaries[i] for i in group), key=len) for group in groups]

    # If there are fewer clusters than requested personas, pad by re-sampling
    # (with replacement) from the representatives we do have.
    if len(top_summaries) < num_personas:
        top_summaries.extend(
            np.random.choice(top_summaries, num_personas - len(top_summaries))
        )

    # Fan out the persona-generation prompt calls in parallel.
    kwargs_list = [
        {
            "llm": llm,
            "data": StringIO(text=summary),
            "callbacks": callbacks,
            "temperature": 1.0,
        }
        for summary in top_summaries[:num_personas]
    ]
    persona_list = run_async_batch(
        desc="Generating personas",
        func=persona_generation_prompt.generate,
        kwargs_list=kwargs_list,
    )

    return persona_list

src/ragas/testset/transforms/extractors/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
NERExtractor,
66
SummaryExtractor,
77
TitleExtractor,
8+
TopicDescriptionExtractor,
89
)
910
from .regex_based import emails_extractor, links_extractor, markdown_headings_extractor
1011

@@ -18,4 +19,5 @@
1819
"HeadlinesExtractor",
1920
"EmbeddingExtractor",
2021
"NERExtractor",
22+
"TopicDescriptionExtractor",
2123
]

src/ragas/testset/transforms/extractors/llm_based.py

+46
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,49 @@ async def extract(self, node: Node) -> t.Tuple[str, t.Dict[str, t.List[str]]]:
260260
return self.property_name, {}
261261
result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
262262
return self.property_name, result.entities.model_dump()
263+
264+
265+
class TopicDescription(BaseModel):
    """Structured output model for ``TopicDescriptionPrompt``."""

    # Concise description of the main topic(s) found in the analyzed text.
    description: str
267+
268+
269+
class TopicDescriptionPrompt(PydanticPrompt[StringIO, TopicDescription]):
    """Prompt that summarizes the main topic(s) of a text into a ``TopicDescription``."""

    instruction: str = (
        "Provide a concise description of the main topic(s) discussed in the following text."
    )
    input_model: t.Type[StringIO] = StringIO
    output_model: t.Type[TopicDescription] = TopicDescription
    # One-shot example pairing an input text with the expected description.
    examples: t.List[t.Tuple[StringIO, TopicDescription]] = [
        (
            StringIO(
                text="Quantum Computing\n\nQuantum computing leverages the principles of quantum mechanics to perform complex computations more efficiently than classical computers. It has the potential to revolutionize fields like cryptography, material science, and optimization problems by solving tasks that are currently intractable for classical systems."
            ),
            TopicDescription(
                description="An introduction to quantum computing and its potential to outperform classical computers in complex computations, impacting areas such as cryptography and material science."
            ),
        )
    ]
285+
286+
287+
@dataclass
class TopicDescriptionExtractor(LLMBasedExtractor):
    """
    LLM-based extractor that produces a concise description of the main
    topic(s) discussed in a node's page content.

    Attributes
    ----------
    property_name : str
        The name of the property to extract.
    prompt : TopicDescriptionPrompt
        The prompt used for extraction.
    """

    property_name: str = "topic_description"
    prompt: TopicDescriptionPrompt = TopicDescriptionPrompt()

    async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
        # Nothing to analyze when the node carries no page content.
        text = node.get_property("page_content")
        if text is None:
            return self.property_name, None
        response = await self.prompt.generate(self.llm, data=StringIO(text=text))
        return self.property_name, response.description

0 commit comments

Comments
 (0)