|
| 1 | +import logging |
| 2 | +import random |
| 3 | +import typing as t |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from langchain_core.callbacks import Callbacks |
| 7 | +from pydantic import BaseModel |
| 8 | + |
| 9 | +from ragas.executor import run_async_batch |
| 10 | +from ragas.llms.base import BaseRagasLLM |
| 11 | +from ragas.prompt import PydanticPrompt, StringIO |
| 12 | +from ragas.testset.graph import KnowledgeGraph, Node |
| 13 | + |
| 14 | +logger = logging.getLogger(__name__) |
| 15 | + |
| 16 | + |
| 17 | +def default_filter(node: Node) -> bool: |
| 18 | + if ( |
| 19 | + node.type.name == "DOCUMENT" |
| 20 | + and node.properties.get("summary_embedding") is not None |
| 21 | + ): |
| 22 | + return random.random() < 0.25 |
| 23 | + else: |
| 24 | + return False |
| 25 | + |
| 26 | + |
| 27 | +class Persona(BaseModel): |
| 28 | + name: str |
| 29 | + role_description: str |
| 30 | + |
| 31 | + |
| 32 | +class PersonaGenerationPrompt(PydanticPrompt[StringIO, Persona]): |
| 33 | + instruction: str = ( |
| 34 | + "Using the provided summary, generate a single persona who would likely " |
| 35 | + "interact with or benefit from the content. Include a unique name and a " |
| 36 | + "concise role description of who they are." |
| 37 | + ) |
| 38 | + input_model: t.Type[StringIO] = StringIO |
| 39 | + output_model: t.Type[Persona] = Persona |
| 40 | + examples: t.List[t.Tuple[StringIO, Persona]] = [ |
| 41 | + ( |
| 42 | + StringIO( |
| 43 | + text="Guide to Digital Marketing explains strategies for engaging audiences across various online platforms." |
| 44 | + ), |
| 45 | + Persona( |
| 46 | + name="Digital Marketing Specialist", |
| 47 | + role_description="Focuses on engaging audiences and growing the brand online.", |
| 48 | + ), |
| 49 | + ) |
| 50 | + ] |
| 51 | + |
| 52 | + |
| 53 | +class PersonaList(BaseModel): |
| 54 | + personas: t.List[Persona] |
| 55 | + |
| 56 | + def __getitem__(self, key: str) -> Persona: |
| 57 | + for persona in self.personas: |
| 58 | + if persona.name == key: |
| 59 | + return persona |
| 60 | + raise KeyError(f"No persona found with name '{key}'") |
| 61 | + |
| 62 | + |
| 63 | +def generate_personas_from_kg( |
| 64 | + kg: KnowledgeGraph, |
| 65 | + llm: BaseRagasLLM, |
| 66 | + persona_generation_prompt: PersonaGenerationPrompt = PersonaGenerationPrompt(), |
| 67 | + num_personas: int = 3, |
| 68 | + filter_fn: t.Callable[[Node], bool] = default_filter, |
| 69 | + callbacks: Callbacks = [], |
| 70 | +) -> t.List[Persona]: |
| 71 | + """ |
| 72 | + Generate personas from a knowledge graph based on cluster of similar document summaries. |
| 73 | +
|
| 74 | + parameters: |
| 75 | + kg: KnowledgeGraph |
| 76 | + The knowledge graph to generate personas from. |
| 77 | + llm: BaseRagasLLM |
| 78 | + The LLM to use for generating the persona. |
| 79 | + persona_generation_prompt: PersonaGenerationPrompt |
| 80 | + The prompt to use for generating the persona. |
| 81 | + num_personas: int |
| 82 | + The maximum number of personas to generate. |
| 83 | + filter_fn: Callable[[Node], bool] |
| 84 | + A function to filter nodes in the knowledge graph. |
| 85 | + callbacks: Callbacks |
| 86 | + The callbacks to use for the generation process. |
| 87 | +
|
| 88 | +
|
| 89 | + returns: |
| 90 | + t.List[Persona] |
| 91 | + The list of generated personas. |
| 92 | + """ |
| 93 | + |
| 94 | + nodes = [node for node in kg.nodes if filter_fn(node)] |
| 95 | + summaries = [node.properties.get("summary") for node in nodes] |
| 96 | + summaries = [summary for summary in summaries if isinstance(summary, str)] |
| 97 | + |
| 98 | + embeddings = [] |
| 99 | + for node in nodes: |
| 100 | + embeddings.append(node.properties.get("summary_embedding")) |
| 101 | + |
| 102 | + embeddings = np.array(embeddings) |
| 103 | + cosine_similarities = np.dot(embeddings, embeddings.T) |
| 104 | + |
| 105 | + groups = [] |
| 106 | + visited = set() |
| 107 | + threshold = 0.75 |
| 108 | + |
| 109 | + for i, _ in enumerate(summaries): |
| 110 | + if i in visited: |
| 111 | + continue |
| 112 | + group = [i] |
| 113 | + visited.add(i) |
| 114 | + for j in range(i + 1, len(summaries)): |
| 115 | + if cosine_similarities[i, j] > threshold: |
| 116 | + group.append(j) |
| 117 | + visited.add(j) |
| 118 | + groups.append(group) |
| 119 | + |
| 120 | + top_summaries = [] |
| 121 | + for group in groups: |
| 122 | + representative_summary = max([summaries[i] for i in group], key=len) |
| 123 | + top_summaries.append(representative_summary) |
| 124 | + |
| 125 | + if len(top_summaries) <= num_personas: |
| 126 | + top_summaries.extend( |
| 127 | + np.random.choice(top_summaries, num_personas - len(top_summaries)) |
| 128 | + ) |
| 129 | + |
| 130 | + # use run_async_batch to generate personas in parallel |
| 131 | + kwargs_list = [ |
| 132 | + { |
| 133 | + "llm": llm, |
| 134 | + "data": StringIO(text=summary), |
| 135 | + "callbacks": callbacks, |
| 136 | + "temperature": 1.0, |
| 137 | + } |
| 138 | + for summary in top_summaries[:num_personas] |
| 139 | + ] |
| 140 | + persona_list = run_async_batch( |
| 141 | + desc="Generating personas", |
| 142 | + func=persona_generation_prompt.generate, |
| 143 | + kwargs_list=kwargs_list, |
| 144 | + ) |
| 145 | + |
| 146 | + return persona_list |
0 commit comments