Skip to content

Commit

Permalink
Refine evaluation interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
KaQuMiQ authored Aug 9, 2024
1 parent b400631 commit 0d2cad2
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 64 deletions.
42 changes: 32 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ build-backend = "setuptools.build_meta"

[project]
name = "draive"
version = "0.24.1"
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
version = "0.25.0"
readme = "README.md"
maintainers = [
{ name = "Kacper Kaliński", email = "[email protected]" },
Expand All @@ -20,21 +21,42 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Application Frameworks",
]
license = { file = "LICENSE" }
dependencies = ["numpy~=1.26"]
dependencies = [
"numpy~=1.26",
]

[project.urls]
Homepage = "https://miquido.com"
Repository = "https://github.com/miquido/draive.git"

[project.optional-dependencies]
sentencepiece = ["sentencepiece~=0.2"]
fastembed = ["fastembed~=0.3.0"]
openai = ["openai~=1.32", "tiktoken~=0.7"]
anthropic = ["anthropic~=0.29.0"]
mistral = ["httpx~=0.27", "draive[sentencepiece]"]
gemini = ["httpx~=0.27", "draive[sentencepiece]"]
ollama = ["httpx~=0.27"]
mistralrs = ["mistralrs~=0.1.19"]
sentencepiece = [
"sentencepiece~=0.2",
]
fastembed = [
"fastembed~=0.3.0",
]
openai = [
"openai~=1.32",
"tiktoken~=0.7",
]
anthropic = [
"anthropic~=0.29.0",
]
mistral = [
"httpx~=0.27",
"draive[sentencepiece]",
]
gemini = [
"httpx~=0.27",
"draive[sentencepiece]",
]
ollama = [
"httpx~=0.27",
]
mistralrs = [
"mistralrs~=0.1.19",
]

dev = [
"draive[sentencepiece]",
Expand Down
2 changes: 1 addition & 1 deletion src/draive/anthropic/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ async def _completion( # noqa: PLR0913, PLR0912, C901
else:
raise AnthropicException("Invalid Anthropic completion", completion)

case "end_turn":
case "end_turn" | "stop_sequence":
if (tool_calls := tool_calls) and (tools := tools):
ctx.record(ResultTrace.of(tool_calls))
return LMMToolRequests(
Expand Down
47 changes: 15 additions & 32 deletions src/draive/evaluation/scenario.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from asyncio import gather
from collections.abc import Callable, Sequence
from typing import Protocol, overload, runtime_checkable

from draive.evaluation.evaluator import EvaluatorResult, PreparedEvaluator
from draive.evaluation.evaluator import EvaluatorResult
from draive.parameters import DataModel, Field
from draive.types import frozenlist
from draive.utils import freeze
Expand Down Expand Up @@ -33,6 +32,7 @@ class PreparedScenarioEvaluator[Value](Protocol):
async def __call__(
self,
value: Value,
/,
) -> ScenarioEvaluatorResult: ...


Expand All @@ -41,11 +41,13 @@ class ScenarioEvaluatorDefinition[Value, **Args](Protocol):
@property
def __name__(self) -> str: ...

def __call__(
async def __call__(
self,
value: Value,
/,
*args: Args.args,
**kwargs: Args.kwargs,
) -> Sequence[PreparedEvaluator[Value]] | PreparedEvaluator[Value]: ...
) -> Sequence[EvaluatorResult]: ...


class ScenarioEvaluator[Value, **Args]:
Expand All @@ -64,25 +66,13 @@ def prepared(
*args: Args.args,
**kwargs: Args.kwargs,
) -> PreparedScenarioEvaluator[Value]:
prepared_evaluators: Sequence[PreparedEvaluator[Value]]
match self._definition(*args, **kwargs):
case [*evaluators]:
prepared_evaluators = evaluators

case evaluator:
prepared_evaluators = (evaluator,)

async def evaluate(
value: Value,
) -> ScenarioEvaluatorResult:
return ScenarioEvaluatorResult(
name=self.name,
evaluations=tuple(
await gather(
*[evaluator(value) for evaluator in prepared_evaluators],
return_exceptions=False,
),
),
return await self(
value,
*args,
**kwargs,
)

return evaluate
Expand All @@ -94,21 +84,14 @@ async def __call__(
*args: Args.args,
**kwargs: Args.kwargs,
) -> ScenarioEvaluatorResult:
prepared_evaluators: Sequence[PreparedEvaluator[Value]]
match self._definition(*args, **kwargs):
case [*evaluators]:
prepared_evaluators = evaluators

case evaluator:
prepared_evaluators = (evaluator,)

return ScenarioEvaluatorResult(
name=self.name,
evaluations=tuple(
await gather(
*[evaluator(value) for evaluator in prepared_evaluators],
return_exceptions=False,
),
await self._definition(
value,
*args,
**kwargs,
)
),
)

Expand Down
74 changes: 53 additions & 21 deletions src/draive/evaluation/suite.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from asyncio import Lock, gather
from collections.abc import Callable
from pathlib import Path
from typing import Protocol, overload, runtime_checkable
from typing import Protocol, Self, overload, runtime_checkable
from uuid import UUID, uuid4

from draive.evaluation.scenario import ScenarioEvaluatorResult
from draive.evaluation.evaluator import EvaluatorResult, PreparedEvaluator
from draive.evaluation.scenario import PreparedScenarioEvaluator, ScenarioEvaluatorResult
from draive.parameters import DataModel, Field
from draive.scope import ctx
from draive.types import frozenlist
Expand Down Expand Up @@ -33,7 +34,7 @@ class EvaluationSuiteCaseResult[CaseParameters: DataModel, Value: DataModel | st
value: Value = Field(
description="Evaluated value",
)
results: frozenlist[ScenarioEvaluatorResult] = Field(
results: frozenlist[ScenarioEvaluatorResult | EvaluatorResult] = Field(
description="Evaluation results",
)

Expand All @@ -43,10 +44,40 @@ def passed(self) -> bool:


class EvaluationCaseResult[Value: DataModel | str](DataModel):
@classmethod
def of(
cls,
results: ScenarioEvaluatorResult | EvaluatorResult,
*_results: ScenarioEvaluatorResult | EvaluatorResult,
value: Value,
) -> Self:
return cls(
value=value,
results=(results, *_results),
)

@classmethod
async def evaluating(
cls,
value: Value,
/,
evaluators: PreparedScenarioEvaluator[Value] | PreparedEvaluator[Value],
*_evaluators: PreparedScenarioEvaluator[Value] | PreparedEvaluator[Value],
) -> Self:
return cls(
value=value,
results=tuple(
await gather(
*[evaluator(value) for evaluator in [evaluators, *_evaluators]],
return_exceptions=False,
),
),
)

value: Value = Field(
description="Evaluated value",
)
results: frozenlist[ScenarioEvaluatorResult] = Field(
results: frozenlist[ScenarioEvaluatorResult | EvaluatorResult] = Field(
description="Evaluation results",
)

Expand All @@ -55,7 +86,7 @@ class EvaluationCaseResult[Value: DataModel | str](DataModel):
class EvaluationSuiteDefinition[CaseParameters: DataModel, Value: DataModel | str](Protocol):
async def __call__(
self,
evaluation_case: CaseParameters,
parameters: CaseParameters,
) -> EvaluationCaseResult[Value]: ...


Expand Down Expand Up @@ -89,33 +120,36 @@ def __init__(
@overload
async def __call__(
self,
parameters: CaseParameters | UUID | None,
/,
*,
evaluated_case: CaseParameters | UUID | None,
reload: bool = False,
) -> EvaluationSuiteCaseResult[CaseParameters, Value]: ...

@overload
async def __call__(
self,
/,
*,
reload: bool = False,
) -> list[EvaluationSuiteCaseResult[CaseParameters, Value]]: ...

async def __call__(
self,
parameters: CaseParameters | UUID | None = None,
/,
*,
evaluated_case: CaseParameters | UUID | None = None,
reload: bool = False,
) -> (
list[EvaluationSuiteCaseResult[CaseParameters, Value]]
| EvaluationSuiteCaseResult[CaseParameters, Value]
):
async with self._lock:
match evaluated_case:
match parameters:
case None:
return await gather(
*[
self._evaluate(evaluated_case=case)
self._evaluate(case=case)
for case in (await self._data(reload=reload)).cases
],
return_exceptions=False,
Expand All @@ -130,42 +164,40 @@ async def __call__(
iter([case for case in available_cases if case.identifier == identifier]),
None,
):
return await self._evaluate(evaluated_case=evaluation_case)
return await self._evaluate(case=evaluation_case)

else:
raise ValueError(f"Evaluation case with ID {identifier} does not exist.")

case case_parameters:
return await self._evaluate(
evaluated_case=EvaluationSuiteCase[CaseParameters](
case=EvaluationSuiteCase[CaseParameters](
parameters=case_parameters,
)
)

async def _evaluate(
self,
*,
evaluated_case: EvaluationSuiteCase[CaseParameters],
case: EvaluationSuiteCase[CaseParameters],
) -> EvaluationSuiteCaseResult[CaseParameters, Value]:
case_result: EvaluationCaseResult[Value] = await self._definition(
evaluation_case=evaluated_case.parameters
)
result: EvaluationCaseResult[Value] = await self._definition(parameters=case.parameters)

return EvaluationSuiteCaseResult[CaseParameters, Value](
case=evaluated_case,
value=case_result.value,
results=case_result.results,
case=case,
value=result.value,
results=result.results,
)

async def _data(
self,
reload: bool = False,
) -> EvaluationSuiteData[CaseParameters]:
if (data := self._data_cache) and not reload:
return data
if reload or self._data_cache is None:
self._data_cache = await self._storage.load()
return self._data_cache

else:
self._data_cache = await self._storage.load()
return self._data_cache

async def cases(
Expand Down

0 comments on commit 0d2cad2

Please sign in to comment.