feat: few shot example optimizer (#1739)
optimize with few-shot examples

```py
from ragas.metrics import AspectCritic
from ragas.llms import llm_factory

# define metric
llm = llm_factory("gpt-4o")
metric = AspectCritic(
    name="answer_correctness",
    definition="Given the user_input, reference and response. Is the response correct compared with the reference",
    llm=llm,
)

# optimize with annotation
from ragas.config import DemonstrationConfig
demonstration_config = DemonstrationConfig()
metric.train(
    "alignment_sample.json",
    demonstration_config=demonstration_config,
)
```
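The new `DemonstrationConfig` also accepts the embeddings model used for similarity-based example selection, and instruction optimization can be requested in the same `train()` call. A sketch of that fuller usage, reusing `llm` and `metric` from the block above and assuming `embedding_factory()` is available from `ragas.embeddings` (any `BaseRagasEmbeddings` instance works) and that `alignment_sample.json` holds the exported annotations:

```py
from ragas.config import DemonstrationConfig, InstructionConfig
from ragas.embeddings import embedding_factory

# embeddings are used to pick the top_k most similar accepted examples
demonstration_config = DemonstrationConfig(
    embedding=embedding_factory(),  # assumed helper; any BaseRagasEmbeddings works
    top_k=3,
    threshold=0.7,
    technique="similarity",
)

# optionally also optimize the prompt instruction with the same annotations
instruction_config = InstructionConfig(llm=llm)

metric.train(
    "alignment_sample.json",
    demonstration_config=demonstration_config,
    instruction_config=instruction_config,
)
```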
jjmachan authored Dec 9, 2024
1 parent 9f5cccc commit d432ed0
Showing 9 changed files with 334 additions and 57 deletions.
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -91,3 +91,9 @@ addopts = "-n 0"
asyncio_default_fixture_loop_scope = "function"
[pytest]
testpaths = ["tests"]

[dependency-groups]
dev = [
"arize-phoenix>=6.1.0",
"openinference-instrumentation-langchain>=0.1.29",
]
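These are dev-only additions for tracing during development. A minimal import check, assuming the `dev` dependency group above has been installed with a PEP 735-aware tool; the names below come from those two packages:

```py
# assumes `arize-phoenix` and `openinference-instrumentation-langchain` from the
# dev group above are installed
import phoenix as px
from openinference.instrumentation.langchain import LangChainInstrumentor

LangChainInstrumentor().instrument()  # emit OpenInference spans for LangChain calls
print(px.__version__)
```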
22 changes: 17 additions & 5 deletions src/ragas/config.py
@@ -1,27 +1,39 @@
from __future__ import annotations

import typing as t

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator

from ragas.embeddings import BaseRagasEmbeddings
from ragas.llms import BaseRagasLLM
from ragas.embeddings.base import BaseRagasEmbeddings
from ragas.llms.base import BaseRagasLLM
from ragas.losses import Loss
from ragas.optimizers import GeneticOptimizer, Optimizer

DEFAULT_OPTIMIZER_CONFIG = {"max_steps": 100}


class DemonstrationConfig(BaseModel):
    embedding: t.Any  # this has to be of type Any because BaseRagasEmbedding is an ABC
    enabled: bool = True
    top_k: int = 3
    threshold: float = 0.7
    technique: t.Literal["random", "similarity"] = "similarity"
    embedding: t.Optional[BaseRagasEmbeddings] = None

    @field_validator("embedding")
    def validate_embedding(cls, v):
        if not isinstance(v, BaseRagasEmbeddings):
            raise ValueError("embedding must be an instance of BaseRagasEmbeddings")
        return v


class InstructionConfig(BaseModel):
    llm: BaseRagasLLM
    enabled: bool = True
    loss: t.Optional[Loss] = None
    optimizer: Optimizer = GeneticOptimizer()
    optimizer_config: t.Dict[str, t.Any] = Field(
        default_factory=lambda: DEFAULT_OPTIMIZER_CONFIG
    )
    llm: t.Optional[BaseRagasLLM] = None


InstructionConfig.model_rebuild()
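A small sketch of how these two config models behave; `llm_factory` is the existing ragas helper used in the commit message, the model name is a placeholder, and an API key is assumed to be configured. The rejected-embedding case shows that only `BaseRagasEmbeddings` instances pass validation:

```py
from pydantic import ValidationError

from ragas.config import DemonstrationConfig, InstructionConfig
from ragas.llms import llm_factory

# anything that is not a BaseRagasEmbeddings instance fails validation
try:
    DemonstrationConfig(embedding="text-embedding-3-small")
except ValidationError as err:
    print(err)

# instruction optimization defaults to GeneticOptimizer with {"max_steps": 100}
inst_cfg = InstructionConfig(llm=llm_factory("gpt-4o"))
print(inst_cfg.optimizer_config)
```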
12 changes: 11 additions & 1 deletion src/ragas/dataset_schema.py
@@ -554,7 +554,7 @@ class PromptAnnotation(BaseModel):
    prompt_input: t.Dict[str, t.Any]
    prompt_output: t.Dict[str, t.Any]
    is_accepted: bool
    edited_output: t.Union[t.Dict[str, t.Any], None]
    edited_output: t.Optional[t.Dict[str, t.Any]] = None

    def __getitem__(self, key):
        return getattr(self, key)
@@ -801,3 +801,13 @@ def stratified_batches(
            all_batches.append(batch)

        return all_batches

    def get_prompt_annotations(self) -> t.Dict[str, t.List[PromptAnnotation]]:
        """
        Get all the prompt annotations for each prompt as a list.
        """
        prompt_annotations = defaultdict(list)
        for sample in self.samples:
            for prompt_name, prompt_annotation in sample.prompts.items():
                prompt_annotations[prompt_name].append(prompt_annotation)
        return prompt_annotations
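A quick sketch of consuming the new accessor; the file name and metric name are the ones from the commit message and stand in for whatever annotations you export:

```py
from ragas.dataset_schema import MetricAnnotation

dataset = MetricAnnotation.from_json(
    "alignment_sample.json", metric_name="answer_correctness"
)

# maps each prompt name to the list of PromptAnnotation objects collected for it
prompt_annotations = dataset["answer_correctness"].get_prompt_annotations()
for prompt_name, annotations in prompt_annotations.items():
    accepted = [a for a in annotations if a.is_accepted]
    edited = [a for a in accepted if a.edited_output is not None]
    print(f"{prompt_name}: {len(accepted)} accepted, {len(edited)} with edited outputs")
```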
152 changes: 118 additions & 34 deletions src/ragas/metrics/base.py
@@ -8,14 +8,16 @@
from dataclasses import dataclass, field
from enum import Enum

from pydantic import ValidationError
from pysbd import Segmenter
from tqdm import tqdm

from ragas._analytics import EvaluationEvent, _analytics_batcher
from ragas.callbacks import ChainType, new_group
from ragas.dataset_schema import MetricAnnotation, MultiTurnSample, SingleTurnSample
from ragas.executor import is_event_loop_running
from ragas.losses import BinaryMetricLoss, MSELoss
from ragas.prompt import PromptMixin
from ragas.prompt import FewShotPydanticPrompt, PromptMixin
from ragas.run_config import RunConfig
from ragas.utils import (
    RAGAS_SUPPORTED_LANGUAGE_CODES,
@@ -230,48 +232,30 @@ def init(self, run_config: RunConfig):
            )
        self.llm.set_run_config(run_config)

    def train(
    def _optimize_instruction(
        self,
        path: str,
        demonstration_config: t.Optional[DemonstrationConfig] = None,
        instruction_config: t.Optional[InstructionConfig] = None,
        callbacks: t.Optional[Callbacks] = None,
        run_config: t.Optional[RunConfig] = None,
        batch_size: t.Optional[int] = None,
        with_debugging_logs=False,
        raise_exceptions: bool = True,
    ) -> None:

        if not path.endswith(".json"):
            raise ValueError("Train data must be in json format")

        if instruction_config is None:
            from ragas.config import InstructionConfig

            instruction_config = InstructionConfig()

        if demonstration_config is None:
            from ragas.config import DemonstrationConfig

            demonstration_config = DemonstrationConfig()

        dataset = MetricAnnotation.from_json(path, metric_name=self.name)

        optimizer = instruction_config.optimizer
        llm = instruction_config.llm or self.llm
        if llm is None:
        instruction_config: InstructionConfig,
        dataset: MetricAnnotation,
        callbacks: Callbacks,
        run_config: RunConfig,
        batch_size: t.Optional[int],
        with_debugging_logs: bool,
        raise_exceptions: bool,
    ):
        if self.llm is None:
            raise ValueError(
                f"Metric '{self.name}' has no valid LLM provided (self.llm is None). Please instantiate the metric with an LLM to run."  # noqa
            )
        optimizer = instruction_config.optimizer
        if optimizer.llm is None:
            optimizer.llm = llm
            optimizer.llm = instruction_config.llm

        # figure out the loss function
        if instruction_config.loss is None:
            if self.output_type is None:
                raise ValueError(
                    f"Output type for metric '{self.name}' is not defined. Please set the output type in the metric or in the instruction config."
                )

            if self.output_type.name == MetricOutputType.BINARY.name:
                loss_fun = BinaryMetricLoss()
            elif (
@@ -286,8 +270,8 @@ def train(
        else:
            loss_fun = instruction_config.loss

        # Optimize the prompts
        optimizer.metric = self

        optimizer_config = instruction_config.optimizer_config or {}
        optimized_prompts = optimizer.optimize(
            dataset[self.name],
@@ -299,11 +283,111 @@
            with_debugging_logs=with_debugging_logs,
            raise_exceptions=raise_exceptions,
        )

        # replace the instruction in the metric with the optimized instruction
        prompts = self.get_prompts()
        for key, val in optimized_prompts.items():
            prompts[key].instruction = val
        self.set_prompts(**prompts)
        return

    def _optimize_demonstration(
        self, demonstration_config: DemonstrationConfig, dataset: MetricAnnotation
    ):
        # get the prompt annotations for this metric
        prompt_annotations = dataset[self.name].get_prompt_annotations()
        prompts = self.get_prompts()
        for prompt_name, prompt_annotation_list in prompt_annotations.items():
            # create a new FewShotPydanticPrompt with these annotations
            if prompt_name not in prompts:
                raise ValueError(
                    f"Prompt '{prompt_name}' not found in metric '{self.name}'. Please check the prompt names in the annotation dataset."
                )
            pydantic_prompt = prompts[prompt_name]
            input_model, output_model = (
                pydantic_prompt.input_model,
                pydantic_prompt.output_model,
            )
            # convert annotations into examples
            input_examples, output_examples = [], []
            for i, prompt_annotation in enumerate(prompt_annotation_list):
                try:
                    # skip if the prompt is not accepted
                    if not prompt_annotation.is_accepted:
                        continue
                    input_examples.append(
                        input_model.model_validate(prompt_annotation.prompt_input)
                    )
                    # use the edited output if it is provided
                    if prompt_annotation.edited_output is not None:
                        output_examples.append(
                            output_model.model_validate(prompt_annotation.edited_output)
                        )
                    else:
                        output_examples.append(
                            output_model.model_validate(prompt_annotation.prompt_output)
                        )
                except ValidationError as e:
                    logger.warning(
                        f"Skipping prompt '{prompt_name}' example {i} because of validation error: {e}"
                    )
                    continue
            embedding_model = demonstration_config.embedding
            few_shot_prompt = FewShotPydanticPrompt.from_pydantic_prompt(
                pydantic_prompt=pydantic_prompt,
                embeddings=embedding_model,
            )

            # add the top k examples to the few shot prompt
            few_shot_prompt.top_k_for_examples = demonstration_config.top_k
            few_shot_prompt.threshold_for_examples = demonstration_config.threshold

            # add examples to the few shot prompt
            for input_example, output_example in tqdm(
                zip(input_examples, output_examples),
                total=len(input_examples),
                desc=f"Few-shot examples [{prompt_name}]",
            ):
                few_shot_prompt.add_example(input_example, output_example)
            prompts[prompt_name] = few_shot_prompt
        self.set_prompts(**prompts)

    def train(
        self,
        path: str,
        demonstration_config: t.Optional[DemonstrationConfig] = None,
        instruction_config: t.Optional[InstructionConfig] = None,
        callbacks: t.Optional[Callbacks] = None,
        run_config: t.Optional[RunConfig] = None,
        batch_size: t.Optional[int] = None,
        with_debugging_logs=False,
        raise_exceptions: bool = True,
    ) -> None:
        run_config = run_config or RunConfig()
        callbacks = callbacks or []

        # load the dataset from path
        if not path.endswith(".json"):
            raise ValueError("Train data must be in json format")
        dataset = MetricAnnotation.from_json(path, metric_name=self.name)

        # only optimize the instruction if instruction_config is provided
        if instruction_config is not None:
            self._optimize_instruction(
                instruction_config=instruction_config,
                dataset=dataset,
                callbacks=callbacks,
                run_config=run_config,
                batch_size=batch_size,
                with_debugging_logs=with_debugging_logs,
                raise_exceptions=raise_exceptions,
            )

        # if demonstration_config is provided, optimize the demonstrations
        if demonstration_config is not None:
            self._optimize_demonstration(
                demonstration_config=demonstration_config,
                dataset=dataset,
            )


@dataclass
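The demonstration pass in `_optimize_demonstration` reduces to a per-annotation selection rule: skip anything that was not accepted, and prefer the human-edited output over the original when one exists. A plain-Python illustration of that rule, operating on dicts shaped like `PromptAnnotation` (illustrative only, not the library code itself):

```py
from typing import Any, Dict, List, Tuple

def select_examples(
    annotations: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """Mirror of the selection rule used when building few-shot examples."""
    examples = []
    for ann in annotations:
        if not ann["is_accepted"]:  # skip rejected annotations
            continue
        output = ann["edited_output"] or ann["prompt_output"]  # prefer human edits
        examples.append((ann["prompt_input"], output))
    return examples

annotations = [
    {"prompt_input": {"q": "2+2?"}, "prompt_output": {"a": "5"},
     "edited_output": {"a": "4"}, "is_accepted": True},
    {"prompt_input": {"q": "3+3?"}, "prompt_output": {"a": "6"},
     "edited_output": None, "is_accepted": False},
]
print(select_examples(annotations))  # only the first pair, with the edited answer
```

The selected pairs are then added as examples to a `FewShotPydanticPrompt`, which carries the annotator's embeddings along with `top_k` and `threshold` so that only the most relevant examples are used.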
