Skip to content

Commit

Permalink
Add typed ParameterPath as a replacement for KeyPath
Browse files Browse the repository at this point in the history
  • Loading branch information
KaQuMiQ committed May 7, 2024
1 parent c2a1ecf commit b3b64b0
Show file tree
Hide file tree
Showing 8 changed files with 435 additions and 132 deletions.
3 changes: 2 additions & 1 deletion src/draive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
openai_lmm_completion,
openai_tokenize_text,
)
from draive.parameters import Argument, Field
from draive.parameters import Argument, Field, ParameterPath
from draive.scope import (
ScopeDependencies,
ScopeDependency,
Expand Down Expand Up @@ -219,6 +219,7 @@
"OpenAIEmbeddingConfig",
"OpenAIException",
"OpenAIImageGenerationConfig",
"ParameterPath",
"ReadOnlyMemory",
"ScopeDependencies",
"ScopeDependency",
Expand Down
7 changes: 0 additions & 7 deletions src/draive/keypaths/__init__.py

This file was deleted.

26 changes: 0 additions & 26 deletions src/draive/keypaths/component.py

This file was deleted.

90 changes: 0 additions & 90 deletions src/draive/keypaths/keypath.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/draive/parameters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from draive.parameters.data import Field, ParametrizedData
from draive.parameters.definition import ParametersDefinition
from draive.parameters.function import Argument, Function
from draive.parameters.path import ParameterPath
from draive.parameters.specification import ParametersSpecification, ToolSpecification
from draive.parameters.tool import ParametrizedTool

Expand All @@ -14,4 +15,5 @@
"ParametrizedData",
"ParametrizedData",
"ParametrizedTool",
"ParameterPath",
]
37 changes: 29 additions & 8 deletions src/draive/parameters/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from dataclasses import fields as dataclass_fields
from typing import Any, Self, cast, dataclass_transform, overload
from typing import Any, ClassVar, Self, cast, dataclass_transform, overload

from draive.keypaths import KeyPath
from draive.parameters.definition import ParameterDefinition, ParametersDefinition
from draive.parameters.missing import MISSING_PARAMETER, MissingParameter
from draive.parameters.path import ParameterPath
from draive.parameters.specification import ParameterSpecification
from draive.parameters.validation import parameter_validator

Expand Down Expand Up @@ -203,17 +203,36 @@ def _field_parameter(


class ParametrizedData(metaclass=ParametrizedDataMeta):
_: ClassVar[Self]

def __init_subclass__(cls) -> None:
super().__init_subclass__()
cls._: Self = cast(
Self,
ParameterPath(cls, cls), # type: ignore
)

@classmethod
def key_path(
def path(
cls,
path: str | None = None,
/,
) -> KeyPath[Self]:
return KeyPath(
cls,
path=path,
) -> Self:
return cast(
Self,
ParameterPath(cls, cls), # type: ignore
)

@classmethod
def path_cast[Parameter](
cls,
path: Parameter,
/,
) -> ParameterPath[Self, Parameter]:
assert isinstance( # nosec: B101
path, ParameterPath
), "Prepare parameter path by using Self._.path.to.property"
return cast(ParameterPath[Self, Parameter], path)

@classmethod
def validated(
cls,
Expand All @@ -229,8 +248,10 @@ def validator(
) -> Self:
if isinstance(value, cls):
return value

elif isinstance(value, dict):
return cls.validated(**value)

else:
raise TypeError("Invalid value %s", value)

Expand Down
Loading

0 comments on commit b3b64b0

Please sign in to comment.