From de6da38313480e326016312f693e0214201494d5 Mon Sep 17 00:00:00 2001
From: ikka
Date: Tue, 10 Dec 2024 08:40:01 +0530
Subject: [PATCH] fix: prompt naming related issues (#1743)

---
 src/ragas/callbacks.py          |  2 +-
 src/ragas/optimizers/genetic.py | 10 +---------
 src/ragas/prompt/mixin.py       | 16 ++++++++++++----
 3 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/src/ragas/callbacks.py b/src/ragas/callbacks.py
index ae5ea8b2e..cbfb92b95 100644
--- a/src/ragas/callbacks.py
+++ b/src/ragas/callbacks.py
@@ -164,7 +164,7 @@ def parse_run_traces(
             prompt_trace = traces[prompt_uuid]
             prompt_traces[f"{prompt_trace.name}"] = {
                 "input": prompt_trace.inputs.get("data", {}),
-                "output": prompt_trace.outputs.get("output", {}),
+                "output": prompt_trace.outputs.get("output", {})[0],
             }
         metric_traces[f"{metric_trace.name}"] = prompt_traces
         parased_traces.append(metric_traces)
diff --git a/src/ragas/optimizers/genetic.py b/src/ragas/optimizers/genetic.py
index fd8e4e9f6..65ddc7f3c 100644
--- a/src/ragas/optimizers/genetic.py
+++ b/src/ragas/optimizers/genetic.py
@@ -514,7 +514,7 @@ def dict_to_str(dict: t.Dict[str, t.Any]) -> str:
                             exclude_none=True
                         )
                     ),
-                    output=traces[idx][prompt_name]["output"][0].model_dump(
+                    output=traces[idx][prompt_name]["output"].model_dump(
                         exclude_none=True
                     ),
                     expected_output=dataset[idx]["prompts"][prompt_name][
@@ -586,14 +586,6 @@ def evaluate_candidate(
             _run_id=run_id,
             _pbar=parent_pbar,
         )
-        # remap the traces to the original prompt names
-        remap_traces = {val.name: key for key, val in self.metric.get_prompts().items()}
-        for trace in results.traces:
-            for key in remap_traces:
-                if key in trace[self.metric.name]:
-                    trace[self.metric.name][remap_traces[key]] = trace[
-                        self.metric.name
-                    ].pop(key)
         return results
 
     def evaluate_fitness(
diff --git a/src/ragas/prompt/mixin.py b/src/ragas/prompt/mixin.py
index 79e551298..66ba13740 100644
--- a/src/ragas/prompt/mixin.py
+++ b/src/ragas/prompt/mixin.py
@@ -21,14 +21,21 @@ class PromptMixin:
     eg: [BaseSynthesizer][ragas.testset.synthesizers.base.BaseSynthesizer], [MetricWithLLM][ragas.metrics.base.MetricWithLLM]
     """
 
+    def _get_prompts(self) -> t.Dict[str, PydanticPrompt]:
+
+        prompts = {}
+        for key, value in inspect.getmembers(self):
+            if isinstance(value, PydanticPrompt):
+                prompts.update({key: value})
+        return prompts
+
     def get_prompts(self) -> t.Dict[str, PydanticPrompt]:
         """
         Returns a dictionary of prompts for the class.
         """
         prompts = {}
-        for name, value in inspect.getmembers(self):
-            if isinstance(value, PydanticPrompt):
-                prompts.update({name: value})
+        for _, value in self._get_prompts().items():
+            prompts.update({value.name: value})
         return prompts
 
     def set_prompts(self, **prompts):
@@ -41,6 +48,7 @@ def set_prompts(self, **prompts):
             If the prompt is not an instance of `PydanticPrompt`.
         """
         available_prompts = self.get_prompts()
+        name_to_var = {v.name: k for k, v in self._get_prompts().items()}
         for key, value in prompts.items():
             if key not in available_prompts:
                 raise ValueError(
@@ -50,7 +58,7 @@ def set_prompts(self, **prompts):
                 raise ValueError(
                     f"Prompt with name '{key}' must be an instance of 'ragas.prompt.PydanticPrompt'"
                 )
-            setattr(self, key, value)
+            setattr(self, name_to_var[key], value)
 
     async def adapt_prompts(
         self, language: str, llm: BaseRagasLLM, adapt_instruction: bool = False
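
A minimal usage sketch of the prompt-naming behaviour the mixin.py hunks introduce: get_prompts() now keys its result by each prompt's .name, and set_prompts() maps that public name back to the underlying attribute via the new _get_prompts() helper before calling setattr. The sketch is illustrative only: QAInput, QAOutput, QAPrompt, MyMetric, and the answer_prompt attribute are invented for this example, and the PydanticPrompt class attributes (instruction, input_model, output_model) and import paths are assumed from ragas 0.2-era releases where this file lives under ragas.prompt.

# Hypothetical sketch; QAPrompt and MyMetric are not part of the ragas codebase.
from pydantic import BaseModel

from ragas.prompt import PydanticPrompt
from ragas.prompt.mixin import PromptMixin


class QAInput(BaseModel):
    question: str


class QAOutput(BaseModel):
    answer: str


class QAPrompt(PydanticPrompt[QAInput, QAOutput]):
    instruction = "Answer the question."
    input_model = QAInput
    output_model = QAOutput


class MyMetric(PromptMixin):
    def __init__(self):
        # The attribute name ("answer_prompt") and the prompt's own .name
        # (derived from the class name, e.g. "qa_prompt") are allowed to differ.
        self.answer_prompt = QAPrompt()


metric = MyMetric()

# get_prompts() is keyed by each prompt's .name, not by the attribute name.
prompts = metric.get_prompts()
prompt_name = next(iter(prompts))

# set_prompts() accepts the same .name key and resolves it back to the
# attribute via the internal _get_prompts() lookup before calling setattr.
replacement = QAPrompt()
metric.set_prompts(**{prompt_name: replacement})
assert metric.answer_prompt is replacement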