From 602ed78ff1beae82dc05498b617fc3ead5296cba Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Wed, 6 Nov 2024 09:25:58 +0000 Subject: [PATCH] Rename new function to set_enum_based_on_type This function only support setting the enum field, and it doesn't appear likely that we will want to set other fields in future, so make the name more specific. Additionally, add a docstring, and refactor the function to use the early-return pattern now adding other metadata is ruled out. --- src/hera/shared/_type_util.py | 17 +++++++++++------ src/hera/workflows/script.py | 4 ++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index cfb9c51e..f1c27f40 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -83,11 +83,16 @@ def get_workflow_annotation(annotation: Any) -> Optional[Union[Artifact, Paramet return metadata[0] -def add_metadata_from_type(parameter: Parameter, annotation: Any) -> None: - if not parameter.enum: - type_ = unwrap_annotation(annotation) - if get_origin(type_) is Literal: - parameter.enum = list(get_args(type_)) +def set_enum_based_on_type(parameter: Parameter, annotation: Any) -> None: + """Sets the enum field of a Parameter based on its type annotation. + + Currently, only supports Literals. + """ + if parameter.enum: + return + type_ = unwrap_annotation(annotation) + if get_origin(type_) is Literal: + parameter.enum = list(get_args(type_)) def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Parameter, Artifact]: @@ -107,7 +112,7 @@ def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Par io.name = io.name or python_name if isinstance(io, Parameter): - add_metadata_from_type(io, annotation) + set_enum_based_on_type(io, annotation) return io diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index 0974e488..2065c438 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -48,11 +48,11 @@ ) from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator from hera.shared._type_util import ( - add_metadata_from_type, construct_io_from_annotation, get_workflow_annotation, is_subscripted, origin_type_issupertype, + set_enum_based_on_type, ) from hera.shared.serialization import serialize from hera.workflows._context import _context @@ -380,7 +380,7 @@ def _get_parameters_from_callable(source: Callable) -> List[Parameter]: default = MISSING param = Parameter(name=p.name, default=default) - add_metadata_from_type(param, p.annotation) + set_enum_based_on_type(param, p.annotation) parameters.append(param) return parameters