pydantic_evals/pydantic_evals/dataset.py (100 changes: 83 additions & 17 deletions)
@@ -45,7 +45,7 @@
from .evaluators.spec import EvaluatorSpec
from .otel import SpanTree
from .otel._context_subtree import context_subtree
from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure
from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure, ReportCaseMultiRun

if TYPE_CHECKING:
from pydantic_ai.retries import RetryConfig
@@ -264,6 +264,7 @@ async def evaluate(
retry_task: RetryConfig | None = None,
retry_evaluators: RetryConfig | None = None,
*,
runs: int = 1,
task_name: str | None = None,
metadata: dict[str, Any] | None = None,
) -> EvaluationReport[InputsT, OutputT, MetadataT]:
@@ -282,6 +283,7 @@
progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
retry_task: Optional retry configuration for the task execution.
retry_evaluators: Optional retry configuration for evaluator execution.
runs: The number of times to run each case. Defaults to 1.
task_name: Optional override to the name of the task being executed, otherwise the name of the task
function will be used.
metadata: Optional dict of experiment metadata.
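
With the new keyword-only `runs` argument, each case is executed multiple times within a single experiment. A minimal usage sketch, assuming the public `pydantic_evals` API (`Case`, `Dataset`, `report.print()`); the task and case below are illustrative, not part of this diff:

```python
import asyncio

from pydantic_evals import Case, Dataset

async def answer(question: str) -> str:
    # Stand-in task; a real task would call a model or other service.
    return 'Paris' if 'capital of France' in question else 'unknown'

dataset = Dataset(
    cases=[Case(name='capital', inputs='What is the capital of France?', expected_output='Paris')],
)

async def main() -> None:
    # Each case is run three times; successful runs are aggregated into one entry per case.
    report = await dataset.evaluate(answer, runs=3)
    report.print()

asyncio.run(main())
```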
@@ -291,7 +293,7 @@
"""
task_name = task_name or get_unwrapped_function_name(task)
name = name or task_name
total_cases = len(self.cases)
total_cases = len(self.cases) * runs
progress_bar = Progress() if progress else None

limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()
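
The `limiter` assignment above relies on `contextlib.AsyncExitStack()` being a valid no-op async context manager, so the later `async with limiter:` works whether or not a concurrency cap was given. A small sketch of that pattern in isolation; the function and parameter names here are illustrative, not part of this diff:

```python
import anyio
from contextlib import AsyncExitStack

async def gather_limited(funcs, max_concurrency: int | None = None) -> list:
    # A real semaphore when a limit is set; otherwise an empty AsyncExitStack,
    # which enters and exits without doing anything.
    limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()
    results: list = []

    async def worker(fn) -> None:
        async with limiter:  # acquires the semaphore, or is a no-op
            results.append(await fn())

    async with anyio.create_task_group() as tg:
        for fn in funcs:
            tg.start_soon(worker, fn)
    return results  # note: results arrive in completion order, not submission order
```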
@@ -306,40 +308,101 @@
task_name=task_name,
dataset_name=self.name,
n_cases=len(self.cases),
runs=runs,
**extra_attributes,
) as eval_span,
progress_bar or nullcontext(),
):
task_id = progress_bar.add_task(f'Evaluating {task_name}', total=total_cases) if progress_bar else None

async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
async with limiter:
result = await _run_task_and_evaluators(
task, case, report_case_name, self.evaluators, retry_task, retry_evaluators
async def _handle_case(
case: Case[InputsT, OutputT, MetadataT], report_case_name: str
) -> list[
ReportCase[InputsT, OutputT, MetadataT]
| ReportCaseMultiRun[InputsT, OutputT, MetadataT]
| ReportCaseFailure[InputsT, OutputT, MetadataT]
]:
# If we are running multiple times, create a parent span for the case
cm = logfire_span('case: {name}', name=report_case_name) if runs > 1 else nullcontext()

results: list[
ReportCase[InputsT, OutputT, MetadataT]
| ReportCaseMultiRun[InputsT, OutputT, MetadataT]
| ReportCaseFailure[InputsT, OutputT, MetadataT]
] = []
with cm as span:
trace_id = None
span_id = None
if span and span.context:
trace_id = f'{span.context.trace_id:032x}'
span_id = f'{span.context.span_id:016x}'

for i in range(runs):
run_name = f'{report_case_name} (run {i + 1})' if runs > 1 else report_case_name
async with limiter:
result = await _run_task_and_evaluators(
task, case, run_name, self.evaluators, retry_task, retry_evaluators
)
results.append(result)
if progress_bar and task_id is not None: # pragma: no branch
progress_bar.update(task_id, advance=1)

if runs == 1:
return results

# Separate successes and failures
successes: list[ReportCase] = [r for r in results if isinstance(r, ReportCase)]
failures: list[ReportCaseFailure] = [r for r in results if isinstance(r, ReportCaseFailure)]

output_results: list[
ReportCase[InputsT, OutputT, MetadataT]
| ReportCaseMultiRun[InputsT, OutputT, MetadataT]
| ReportCaseFailure[InputsT, OutputT, MetadataT]
] = []
output_results.extend(failures)

if successes:
# Aggregate successes into a MultiRun
aggregate = ReportCaseAggregate.average(successes, name=report_case_name)
output_results.append(
ReportCaseMultiRun(
name=report_case_name,
inputs=case.inputs,
metadata=case.metadata,
expected_output=case.expected_output,
runs=successes,
aggregate=aggregate,
trace_id=trace_id,
span_id=span_id,
)
)
if progress_bar and task_id is not None: # pragma: no branch
progress_bar.update(task_id, advance=1)
return result
return output_results

if (context := eval_span.context) is None: # pragma: no cover
trace_id = None
span_id = None
else:
trace_id = f'{context.trace_id:032x}'
span_id = f'{context.span_id:016x}'
cases_and_failures = await task_group_gather(

# task_group_gather returns a list of results from each _handle_case call
nested_results = await task_group_gather(
[
lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
for i, case in enumerate(self.cases, 1)
]
)
cases: list[ReportCase] = []
failures: list[ReportCaseFailure] = []
for item in cases_and_failures:
if isinstance(item, ReportCase):
cases.append(item)
else:
failures.append(item)

# Flatten results
cases: list[ReportCase[InputsT, OutputT, MetadataT] | ReportCaseMultiRun[InputsT, OutputT, MetadataT]] = []
failures: list[ReportCaseFailure[InputsT, OutputT, MetadataT]] = []
for group_result in nested_results:
for item in group_result:
if isinstance(item, (ReportCase, ReportCaseMultiRun)):
cases.append(item)
else:
failures.append(item)

report = EvaluationReport(
name=name,
cases=cases,
@@ -367,6 +430,7 @@ def evaluate_sync(
retry_task: RetryConfig | None = None,
retry_evaluators: RetryConfig | None = None,
*,
runs: int = 1,
task_name: str | None = None,
metadata: dict[str, Any] | None = None,
) -> EvaluationReport[InputsT, OutputT, MetadataT]:
@@ -384,6 +448,7 @@ def evaluate_sync(
progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
retry_task: Optional retry configuration for the task execution.
retry_evaluators: Optional retry configuration for evaluator execution.
runs: The number of times to run each case. Defaults to 1.
task_name: Optional override to the name of the task being executed, otherwise the name of the task
function will be used.
metadata: Optional dict of experiment metadata.
@@ -399,6 +464,7 @@
progress=progress,
retry_task=retry_task,
retry_evaluators=retry_evaluators,
runs=runs,
task_name=task_name,
metadata=metadata,
)
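
After this change, `report.cases` can contain a mix of plain `ReportCase` entries and aggregated `ReportCaseMultiRun` entries (the latter only when `runs > 1` and at least one run of a case succeeded; failures stay separate). A hedged sketch of how a caller might inspect the result, using only the attribute names visible in this diff (`runs`, `aggregate`, `name`) and the same toy task as the earlier sketch:

```python
from pydantic_evals import Case, Dataset
from pydantic_evals.reporting import ReportCase, ReportCaseMultiRun

async def answer(question: str) -> str:
    # Stand-in task, as in the earlier sketch.
    return 'Paris' if 'capital of France' in question else 'unknown'

dataset = Dataset(
    cases=[Case(name='capital', inputs='What is the capital of France?', expected_output='Paris')],
)

report = dataset.evaluate_sync(answer, runs=3)

for entry in report.cases:
    if isinstance(entry, ReportCaseMultiRun):
        # One aggregated entry per case whose runs produced at least one ReportCase.
        print(entry.name, f'{len(entry.runs)} successful runs', entry.aggregate)
    else:
        # A plain ReportCase, exactly as before (always the shape when runs == 1).
        print(entry.name, 'single run')
```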