Skip to content

Commit

Permalink
fix: prompt naming related issues (#1743)
Browse files Browse the repository at this point in the history
  • Loading branch information
shahules786 authored Dec 10, 2024
1 parent a2a2cef commit de6da38
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/ragas/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def parse_run_traces(
prompt_trace = traces[prompt_uuid]
prompt_traces[f"{prompt_trace.name}"] = {
"input": prompt_trace.inputs.get("data", {}),
"output": prompt_trace.outputs.get("output", {}),
"output": prompt_trace.outputs.get("output", {})[0],
}
metric_traces[f"{metric_trace.name}"] = prompt_traces
parased_traces.append(metric_traces)
Expand Down
10 changes: 1 addition & 9 deletions src/ragas/optimizers/genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def dict_to_str(dict: t.Dict[str, t.Any]) -> str:
exclude_none=True
)
),
output=traces[idx][prompt_name]["output"][0].model_dump(
output=traces[idx][prompt_name]["output"].model_dump(
exclude_none=True
),
expected_output=dataset[idx]["prompts"][prompt_name][
Expand Down Expand Up @@ -586,14 +586,6 @@ def evaluate_candidate(
_run_id=run_id,
_pbar=parent_pbar,
)
# remap the traces to the original prompt names
remap_traces = {val.name: key for key, val in self.metric.get_prompts().items()}
for trace in results.traces:
for key in remap_traces:
if key in trace[self.metric.name]:
trace[self.metric.name][remap_traces[key]] = trace[
self.metric.name
].pop(key)
return results

def evaluate_fitness(
Expand Down
16 changes: 12 additions & 4 deletions src/ragas/prompt/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,21 @@ class PromptMixin:
eg: [BaseSynthesizer][ragas.testset.synthesizers.base.BaseSynthesizer], [MetricWithLLM][ragas.metrics.base.MetricWithLLM]
"""

def _get_prompts(self) -> t.Dict[str, PydanticPrompt]:

prompts = {}
for key, value in inspect.getmembers(self):
if isinstance(value, PydanticPrompt):
prompts.update({key: value})
return prompts

def get_prompts(self) -> t.Dict[str, PydanticPrompt]:
"""
Returns a dictionary of prompts for the class.
"""
prompts = {}
for name, value in inspect.getmembers(self):
if isinstance(value, PydanticPrompt):
prompts.update({name: value})
for _, value in self._get_prompts().items():
prompts.update({value.name: value})
return prompts

def set_prompts(self, **prompts):
Expand All @@ -41,6 +48,7 @@ def set_prompts(self, **prompts):
If the prompt is not an instance of `PydanticPrompt`.
"""
available_prompts = self.get_prompts()
name_to_var = {v.name: k for k, v in self._get_prompts().items()}
for key, value in prompts.items():
if key not in available_prompts:
raise ValueError(
Expand All @@ -50,7 +58,7 @@ def set_prompts(self, **prompts):
raise ValueError(
f"Prompt with name '{key}' must be an instance of 'ragas.prompt.PydanticPrompt'"
)
setattr(self, key, value)
setattr(self, name_to_var[key], value)

async def adapt_prompts(
self, language: str, llm: BaseRagasLLM, adapt_instruction: bool = False
Expand Down

0 comments on commit de6da38

Please sign in to comment.