Refine evaluation interfaces

KaQuMiQ authored Aug 19, 2024
1 parent 5288d77 commit 5717680
Showing 31 changed files with 1,111 additions and 1,130 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -18,7 +18,7 @@ ifndef INSTALL_OPTIONS
endif

ifndef UV_VERSION
-UV_VERSION := 0.2.25
+UV_VERSION := 0.2.37
endif

.PHONY: install venv sync lock update format lint test release
18 changes: 10 additions & 8 deletions constraints
@@ -57,7 +57,9 @@ idna==3.7
iniconfig==2.0.0
    # via pytest
jiter==0.5.0
-    # via anthropic
+    # via
+    #   anthropic
+    #   openai
loguru==0.7.2
    # via fastembed
markdown-it-py==3.0.0
@@ -80,9 +82,9 @@ numpy==1.26.4
    #   onnxruntime
onnx==1.16.2
    # via fastembed
-onnxruntime==1.18.1
+onnxruntime==1.19.0
    # via fastembed
-openai==1.38.0
+openai==1.41.0
    # via draive (pyproject.toml)
packaging==24.1
    # via
@@ -107,7 +109,7 @@ pydantic-core==2.20.1
    # via pydantic
pygments==2.18.0
    # via rich
-pyright==1.1.374
+pyright==1.1.375
    # via draive (pyproject.toml)
pystemmer==2.2.0.1
    # via fastembed
@@ -120,7 +122,7 @@ pytest-asyncio==0.23.8
    # via draive (pyproject.toml)
pytest-cov==4.1.0
    # via draive (pyproject.toml)
-pyyaml==6.0.1
+pyyaml==6.0.2
    # via
    #   bandit
    #   huggingface-hub
@@ -133,7 +135,7 @@ requests==2.32.3
    #   tiktoken
rich==13.7.1
    # via bandit
-ruff==0.5.6
+ruff==0.5.7
    # via draive (pyproject.toml)
sentencepiece==0.2.0
    # via draive (pyproject.toml)
@@ -147,11 +149,11 @@ snowballstemmer==2.2.0
    # via fastembed
stevedore==5.2.0
    # via bandit
-sympy==1.13.1
+sympy==1.13.2
    # via onnxruntime
tiktoken==0.7.0
    # via draive (pyproject.toml)
-tokenizers==0.19.1
+tokenizers==0.20.0
    # via
    #   anthropic
    #   fastembed
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "draive"
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
version = "0.25.0"
version = "0.26.0"
readme = "README.md"
maintainers = [
    { name = "Kacper Kaliński", email = "[email protected]" },
6 changes: 5 additions & 1 deletion src/draive/__init__.py
@@ -148,6 +148,8 @@
    MultimodalContent,
    MultimodalContentConvertible,
    MultimodalContentElement,
+    MultimodalContentPlaceholder,
+    MultimodalTemplate,
    RateLimitError,
    TextContent,
    VideoBase64Content,
@@ -284,10 +286,12 @@
"ModelGeneration",
"ModelGenerator",
"ModelGeneratorDecoder",
"Multimodal",
"MultimodalContent",
"MultimodalContentConvertible",
"MultimodalContentElement",
"Multimodal",
"MultimodalContentPlaceholder",
"MultimodalTemplate",
"noop",
"not_missing",
"ParameterDefaultFactory",
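Both new symbols land in the package-root exports alongside the existing multimodal types ("Multimodal" also moves to its sorted position in __all__). A minimal import sketch using only names confirmed by the export lists above; the template API itself is not part of this diff:

    from draive import (
        Multimodal,
        MultimodalContentPlaceholder,
        MultimodalTemplate,
    )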
13 changes: 9 additions & 4 deletions src/draive/evaluation/__init__.py
@@ -1,36 +1,39 @@
from draive.evaluation.evaluator import (
    Evaluator,
+    EvaluatorDefinition,
    EvaluatorResult,
    PreparedEvaluator,
    evaluator,
)
from draive.evaluation.scenario import (
    EvaluationScenarioResult,
    PreparedScenarioEvaluator,
    ScenarioEvaluator,
    ScenarioEvaluatorDefinition,
    ScenarioEvaluatorResult,
    evaluation_scenario,
)
-from draive.evaluation.score import Evaluation, EvaluationScore
+from draive.evaluation.score import EvaluationScore
from draive.evaluation.suite import (
    EvaluationCaseResult,
    EvaluationSuite,
    EvaluationSuiteCase,
    EvaluationSuiteCaseResult,
    EvaluationSuiteDefinition,
    EvaluationSuiteStorage,
+    SuiteEvaluatorCaseResult,
+    SuiteEvaluatorResult,
    evaluation_suite,
)

__all__ = [
    "evaluation_scenario",
    "evaluation_suite",
-    "Evaluation",
+    "EvaluatorDefinition",
    "EvaluationCaseResult",
    "EvaluationScenarioResult",
    "EvaluationScore",
    "EvaluationSuite",
    "EvaluationSuiteCase",
    "EvaluationSuiteCaseResult",
    "EvaluationSuiteDefinition",
    "EvaluationSuiteStorage",
    "evaluator",
@@ -41,4 +44,6 @@
"ScenarioEvaluator",
"ScenarioEvaluatorDefinition",
"ScenarioEvaluatorResult",
"SuiteEvaluatorCaseResult",
"SuiteEvaluatorResult",
]
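Net change to the public surface: `Evaluation` (the callable type previously re-exported from draive.evaluation.score) is replaced by `EvaluatorDefinition`, and suite runs gain dedicated result types. A minimal import sketch using only names confirmed by the new export list:

    from draive.evaluation import (
        EvaluatorDefinition,
        SuiteEvaluatorCaseResult,
        SuiteEvaluatorResult,
    )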
126 changes: 98 additions & 28 deletions src/draive/evaluation/evaluator.py
@@ -1,15 +1,17 @@
from collections.abc import Callable
from typing import Protocol, Self, cast, final, overload, runtime_checkable

-from draive.evaluation.score import Evaluation, EvaluationScore
+from draive.evaluation.score import EvaluationScore
from draive.parameters import DataModel, Field, ParameterPath
from draive.scope import ctx
from draive.utils import freeze

__all__ = [
    "evaluator",
    "Evaluator",
    "EvaluatorResult",
    "PreparedEvaluator",
+    "EvaluatorDefinition",
]

+
@@ -23,12 +25,63 @@ class EvaluatorResult(DataModel):
    threshold: float = Field(
        description="Score threshold required to pass evaluation",
    )
+    meta: dict[str, str | float | int | bool | None] | None = Field(
+        description="Additional evaluation metadata",
+        default=None,
+    )

    @property
    def passed(self) -> bool:
        return self.score.value >= self.threshold


+class EvaluationResult(DataModel):
+    @classmethod
+    async def of(
+        cls,
+        score: EvaluationScore | float | bool,
+        /,
+        meta: dict[str, str | float | int | bool | None] | None = None,
+    ) -> Self:
+        evaluation_score: EvaluationScore
+        match score:
+            case EvaluationScore() as score:
+                evaluation_score = score
+
+            case float() as value:
+                evaluation_score = EvaluationScore(value=value)
+
+            case passed:
+                evaluation_score = EvaluationScore(value=1.0 if passed else 0.0)
+
+        return cls(
+            score=evaluation_score,
+            meta=meta,
+        )
+
+    score: EvaluationScore = Field(
+        description="Evaluation score",
+    )
+    meta: dict[str, str | float | int | bool | None] | None = Field(
+        description="Additional evaluation metadata",
+        default=None,
+    )
+
+
+@runtime_checkable
+class EvaluatorDefinition[Value, **Args](Protocol):
+    @property
+    def __name__(self) -> str: ...
+
+    async def __call__(
+        self,
+        value: Value,
+        /,
+        *args: Args.args,
+        **kwargs: Args.kwargs,
+    ) -> EvaluationResult | EvaluationScore | float | bool: ...
+
+
@runtime_checkable
class PreparedEvaluator[Value](Protocol):
    async def __call__(
@@ -43,14 +96,14 @@ class Evaluator[Value, **Args]:
    def __init__(
        self,
        name: str,
-        evaluation: Evaluation[Value, Args],
+        definition: EvaluatorDefinition[Value, Args],
        threshold: float | None,
    ) -> None:
        assert (  # nosec: B101
            threshold is None or 0 <= threshold <= 1
        ), "Evaluation threshold has to be between 0 and 1"

-        self._evaluation: Evaluation[Value, Args] = evaluation
+        self._definition: EvaluatorDefinition[Value, Args] = definition
        self.name: str = name
        self.threshold: float = threshold or 1

@@ -62,7 +115,7 @@ def with_threshold(
    ) -> Self:
        return self.__class__(
            name=self.name,
-            evaluation=self._evaluation,
+            definition=self._definition,
            threshold=threshold,
        )

@@ -102,16 +155,16 @@ async def evaluation(
            value: Mapped,
            *args: Args.args,
            **kwargs: Args.kwargs,
-        ) -> EvaluationScore | float | bool:
-            return await self._evaluation(
+        ) -> EvaluationResult | EvaluationScore | float | bool:
+            return await self._definition(
                mapper(value),
                *args,
                **kwargs,
            )

        return Evaluator[Mapped, Args](
            name=self.name,
-            evaluation=evaluation,
+            definition=evaluation,
            threshold=self.threshold,
        )

@@ -123,34 +176,51 @@ async def __call__(
        **kwargs: Args.kwargs,
    ) -> EvaluatorResult:
        evaluation_score: EvaluationScore
-        match await self._evaluation(
-            value,
-            *args,
-            **kwargs,
-        ):
-            case float() as score_value:
-                evaluation_score = EvaluationScore(value=score_value)
+        evaluation_meta: dict[str, str | float | int | bool | None] | None
+        try:
+            match await self._definition(
+                value,
+                *args,
+                **kwargs,
+            ):
+                case EvaluationResult() as result:
+                    evaluation_score = result.score
+                    evaluation_meta = result.meta

-            case bool() as score_bool:
-                evaluation_score = EvaluationScore(value=1 if score_bool else 0)
+                case EvaluationScore() as score:
+                    evaluation_score = score
+                    evaluation_meta = None

-            case EvaluationScore() as score:
-                evaluation_score = score
+                case float() as score_value:
+                    evaluation_score = EvaluationScore(value=score_value)
+                    evaluation_meta = None

+                case passed:
+                    evaluation_score = EvaluationScore(value=1 if passed else 0)
+                    evaluation_meta = None
+
-            # for whatever reason pyright wants int to be handled...
-            case int() as score_int:
-                evaluation_score = EvaluationScore(value=float(score_int))
+        except Exception as exc:
+            ctx.log_error(
+                f"Evaluator `{self.name}` failed, using `0` score fallback result",
+                exception=exc,
+            )
+            evaluation_score = EvaluationScore(
+                value=0,
+                comment="Evaluation failed",
+            )
+            evaluation_meta = {"exception": str(exc)}

        return EvaluatorResult(
            evaluator=self.name,
            score=evaluation_score,
            threshold=self.threshold,
+            meta=evaluation_meta,
        )


@overload
def evaluator[Value, **Args](
-    evaluation: Evaluation[Value, Args] | None = None,
+    definition: EvaluatorDefinition[Value, Args] | None = None,
    /,
) -> Evaluator[Value, Args]: ...

@@ -161,29 +231,29 @@ def evaluator[Value, **Args](
    name: str | None = None,
    threshold: float | None = None,
) -> Callable[
-    [Evaluation[Value, Args]],
+    [EvaluatorDefinition[Value, Args]],
    Evaluator[Value, Args],
]: ...


def evaluator[Value, **Args](
-    evaluation: Evaluation[Value, Args] | None = None,
+    evaluation: EvaluatorDefinition[Value, Args] | None = None,
    *,
    name: str | None = None,
    threshold: float | None = None,
) -> (
    Callable[
-        [Evaluation[Value, Args]],
+        [EvaluatorDefinition[Value, Args]],
        Evaluator[Value, Args],
    ]
    | Evaluator[Value, Args]
):
    def wrap(
-        evaluation: Evaluation[Value, Args],
+        definition: EvaluatorDefinition[Value, Args],
    ) -> Evaluator[Value, Args]:
        return Evaluator(
-            name=name or evaluation.__name__,
-            evaluation=evaluation,
+            name=name or definition.__name__,
+            definition=definition,
            threshold=threshold,
        )

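Taken together, the evaluator.py changes rename the evaluator callable type from `Evaluation` to `EvaluatorDefinition`, let a definition return the new `EvaluationResult` (score plus metadata), surface that metadata as `EvaluatorResult.meta`, and turn exceptions raised during evaluation into a logged zero-score fallback. A usage sketch against the new interface; the `keyword_presence` evaluator and its inputs are illustrative, not taken from the codebase:

    from draive.evaluation import evaluator

    @evaluator(name="keyword_presence", threshold=0.75)
    async def keyword_presence(value: str, /, keywords: list[str]) -> float:
        # A bare float is wrapped into an EvaluationScore by Evaluator.__call__.
        if not keywords:
            return 0.0
        matched = sum(1 for keyword in keywords if keyword.lower() in value.lower())
        return matched / len(keywords)

    # result = await keyword_presence("draive simplifies LLM apps", keywords=["draive", "llm"])
    # result.passed is True when result.score.value >= result.threshold (0.75 here);
    # if the definition raised instead, result would carry a zero score and
    # {"exception": ...} in its meta.

To attach metadata directly, a definition can return an `EvaluationResult`. Note that `EvaluationResult` is defined in draive.evaluation.evaluator but does not appear in either __all__ shown above, so this sketch imports it from the defining module; its `of` constructor is declared async and must be awaited:

    from draive.evaluation import evaluator
    from draive.evaluation.evaluator import EvaluationResult

    @evaluator(threshold=0.5)  # name defaults to the definition's __name__
    async def length_fit(value: str, /) -> EvaluationResult:
        score = min(len(value) / 100.0, 1.0)  # illustrative scoring rule
        return await EvaluationResult.of(
            score,
            meta={"length": len(value)},  # surfaced as EvaluatorResult.meta
        )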