From 97993c5d5d7d9abb7619e6eb3198bbe1b9b217a7 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Tue, 15 Oct 2024 10:56:11 +0100 Subject: [PATCH 1/6] Validate IO-annotated optional string defaults Combine the two non-Pydantic paths of _get_inputs_from_callable into one using construct_io_from_annotation. This fixes a bug where the default was being verified as None for optional strings only if there was no IO annotation. Signed-off-by: Alice Purcell --- src/hera/workflows/script.py | 44 +++++++++++----------------------- tests/test_unit/test_script.py | 9 +++++++ 2 files changed, 23 insertions(+), 30 deletions(-) diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index e9ade1fc5..faac62962 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -495,22 +495,18 @@ class will be used as inputs, rather than the class itself. artifacts.extend(input_class._get_artifacts(add_missing_path=True)) - elif param_or_artifact := get_workflow_annotation(func_param.annotation): - if param_or_artifact.output: + else: + io = construct_io_from_annotation(func_param.name, func_param.annotation) + if io.output: continue - # Create a new object so we don't modify the Workflow itself - new_object = param_or_artifact.copy() - if not new_object.name: - new_object.name = func_param.name - - if isinstance(new_object, Artifact): - if new_object.path is None: - new_object.path = new_object._get_default_inputs_path() + if isinstance(io, Artifact): + if io.path is None: + io.path = io._get_default_inputs_path() - artifacts.append(new_object) - elif isinstance(new_object, Parameter): - if new_object.default is not None: + artifacts.append(io) + elif isinstance(io, Parameter): + if io.default is not None: # TODO: in 5.18 remove the flag check and `warn`, and raise the ValueError directly (minus "flag" text) warnings.warn( "Using the default field for Parameters in Annotations is deprecated since v5.16" @@ -524,27 +520,15 @@ class will be used as inputs, rather than the class itself. ) if func_param.default != inspect.Parameter.empty: # TODO: remove this check in 5.18: - if new_object.default is not None: + if io.default is not None: raise ValueError( "default cannot be set via both the function parameter default and the Parameter's default" ) - new_object.default = serialize(func_param.default) - parameters.append(new_object) - else: - if ( - func_param.default != inspect.Parameter.empty - and func_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - default = func_param.default - else: - default = MISSING - - if origin_type_issupertype(func_param.annotation, NoneType) and ( - default is MISSING or default is not None - ): - raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.") + io.default = serialize(func_param.default) - parameters.append(Parameter(name=func_param.name, default=default)) + if origin_type_issupertype(func_param.annotation, NoneType) and io.default != "null": + raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.") + parameters.append(io) return parameters, artifacts diff --git a/tests/test_unit/test_script.py b/tests/test_unit/test_script.py index 46bc5c32d..b3e745ceb 100644 --- a/tests/test_unit/test_script.py +++ b/tests/test_unit/test_script.py @@ -182,6 +182,15 @@ def unknown_annotations_ignored(my_optional_string: Optional[str] = "123") -> st _get_inputs_from_callable(unknown_annotations_ignored) +def test_invalid_script_when_optional_parameter_does_not_have_default_value_6(): + @script() + def unknown_annotations_ignored(my_optional_string: Annotated[Optional[str], Parameter(name="my-string")]) -> str: + return "Got: {}".format(my_optional_string) + + with pytest.raises(ValueError, match="Optional parameter 'my_optional_string' must have a default value of None."): + _get_inputs_from_callable(unknown_annotations_ignored) + + def test_invalid_script_when_multiple_input_workflow_annotations_are_given(): @script() def invalid_script(a_str: Annotated[str, Artifact(name="a_str"), Parameter(name="a_str")] = "123") -> str: From fb27fa4c6387027dbf247b187ffd5dd802252fdd Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Fri, 11 Oct 2024 11:28:06 +0100 Subject: [PATCH 2/6] Support Literals in origin_type_issubclass Signed-off-by: Alice Purcell --- src/hera/shared/_type_util.py | 3 +++ tests/script_runner/parameter_inputs.py | 12 +++++++++++- tests/test_runner.py | 12 ++++++++++++ tests/test_unit/test_shared_type_utils.py | 6 +++++- 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index fb36f5428..a19f57ee7 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -6,6 +6,7 @@ Any, Iterable, List, + Literal, Optional, Tuple, Type, @@ -120,6 +121,8 @@ def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]]) origin_type = get_unsubscripted_type(unwrapped_type) if origin_type is Union or origin_type is UnionType: return all(origin_type_issubtype(arg, type_) for arg in get_args(unwrapped_type)) + if origin_type is Literal: + return all(isinstance(value, type_) for value in get_args(unwrapped_type)) return isinstance(origin_type, type) and issubclass(origin_type, type_) diff --git a/tests/script_runner/parameter_inputs.py b/tests/script_runner/parameter_inputs.py index 43cff08b3..e954767e5 100644 --- a/tests/script_runner/parameter_inputs.py +++ b/tests/script_runner/parameter_inputs.py @@ -1,5 +1,5 @@ import json -from typing import Any, List, Union +from typing import Any, List, Literal, Union try: from typing import Annotated @@ -53,6 +53,11 @@ def annotated_basic_types_with_other_metadata( return Output(output=[Input(a=a_but_kebab, b=b_but_kebab)]) +@script() +def annotated_str_literal(my_literal: Annotated[Literal["1", "2"], Parameter(name="str-literal")]) -> str: + return f"type given: {type(my_literal).__name__}" + + @script() def annotated_object(annotated_input_value: Annotated[Input, Parameter(name="input-value")]) -> Output: return Output(output=[annotated_input_value]) @@ -81,6 +86,11 @@ def str_or_int_parameter(my_str_or_int: Union[int, str]) -> str: return f"type given: {type(my_str_or_int).__name__}" +@script() +def str_literal(my_literal: Literal["1", "2"]) -> str: + return f"type given: {type(my_literal).__name__}" + + @script() def str_parameter_expects_jsonstr_dict(my_json_str: str) -> dict: return json.loads(my_json_str) diff --git a/tests/test_runner.py b/tests/test_runner.py index f61546695..83a60b436 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -77,6 +77,12 @@ "type given: int", id="str-or-int-given-int", ), + pytest.param( + "tests.script_runner.parameter_inputs:str_literal", + [{"name": "my_literal", "value": "1"}], + "type given: str", + id="str-literal", + ), pytest.param( "tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict", [{"name": "my_json_str", "value": json.dumps({"my": "dict"})}], @@ -89,6 +95,12 @@ [{"my": "dict"}], id="str-json-param-as-list", ), + pytest.param( + "tests.script_runner.parameter_inputs:annotated_str_literal", + [{"name": "my_literal", "value": "1"}], + "type given: str", + id="annotated-str-literal", + ), pytest.param( "tests.script_runner.parameter_inputs:annotated_str_parameter_expects_jsonstr_dict", [{"name": "my_json_str", "value": json.dumps({"my": "dict"})}], diff --git a/tests/test_unit/test_shared_type_utils.py b/tests/test_unit/test_shared_type_utils.py index e777e0b45..1b31907e1 100644 --- a/tests/test_unit/test_shared_type_utils.py +++ b/tests/test_unit/test_shared_type_utils.py @@ -1,5 +1,5 @@ import sys -from typing import List, NoReturn, Optional, Union +from typing import List, Literal, NoReturn, Optional, Union if sys.version_info >= (3, 9): from typing import Annotated @@ -151,6 +151,10 @@ def test_get_unsubscripted_type(annotation, expected): pytest.param(Annotated[Optional[str], "foo"], (str, NoneType), True, id="annotated-optional"), pytest.param(str, (str, NoneType), True, id="str-is-subtype-of-optional-str"), pytest.param(Union[int, str], (str, NoneType), False, id="union-int-str-not-subtype-of-optional-str"), + pytest.param(Literal["foo", "bar"], (str, NoneType), True, id="literal-str-is-subtype-of-optional-str"), + pytest.param(Literal["foo", None], (str, NoneType), True, id="literal-none-is-subtype-of-optional-str"), + pytest.param(Literal[1, 2], (str, NoneType), False, id="literal-int-not-subtype-of-optional-str"), + pytest.param(Literal[1, "foo"], (str, NoneType), False, id="mixed-literal-not-subtype-of-optional-str"), ], ) def test_origin_type_issubtype(annotation, target, expected): From 3631764275ad4234df90d12efad8cc287ca6f8da Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Mon, 28 Oct 2024 09:45:27 +0000 Subject: [PATCH 3/6] Copy Literals to input Parameter enum field Signed-off-by: Alice Purcell --- src/hera/shared/_type_util.py | 16 ++++++---- .../script_annotations/annotated_literals.py | 18 +++++++++++ tests/script_annotations/literals.py | 13 ++++++++ .../pydantic_io_literals.py | 24 +++++++++++++++ tests/test_script_annotations.py | 30 +++++++++++++++++++ 5 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 tests/script_annotations/annotated_literals.py create mode 100644 tests/script_annotations/literals.py create mode 100644 tests/script_annotations/pydantic_io_literals.py diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index a19f57ee7..7a1ff10f9 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -92,13 +92,19 @@ def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Par For a function parameter, python_name should be the parameter name. For a Pydantic Input or Output class, python_name should be the field name. """ - if annotation := get_workflow_annotation(annotation): + if workflow_annotation := get_workflow_annotation(annotation): # Copy so as to not modify the fields themselves - annotation_copy = annotation.copy() - annotation_copy.name = annotation.name or python_name - return annotation_copy + io = workflow_annotation.copy() + else: + io = Parameter() - return Parameter(name=python_name) + io.name = io.name or python_name + if isinstance(io, Parameter) and not io.enum: + type_ = unwrap_annotation(annotation) + if get_origin(type_) is Literal: + io.enum = list(get_args(type_)) + + return io def get_unsubscripted_type(t: Any) -> Any: diff --git a/tests/script_annotations/annotated_literals.py b/tests/script_annotations/annotated_literals.py new file mode 100644 index 000000000..d15787fc7 --- /dev/null +++ b/tests/script_annotations/annotated_literals.py @@ -0,0 +1,18 @@ +from typing import Annotated, Literal + +from hera.shared import global_config +from hera.workflows import Parameter, Steps, Workflow, script + +global_config.experimental_features["script_annotations"] = True + + +@script(constructor="runner") +def literal_str( + my_str: Annotated[Literal["foo", "bar"], Parameter(name="my-str")], +) -> Annotated[Literal[1, 2], Parameter(name="index")]: + return {"foo": 1, "bar": 2}[my_str] + + +with Workflow(name="my-workflow", entrypoint="steps") as w: + with Steps(name="steps"): + literal_str() diff --git a/tests/script_annotations/literals.py b/tests/script_annotations/literals.py new file mode 100644 index 000000000..c1756f399 --- /dev/null +++ b/tests/script_annotations/literals.py @@ -0,0 +1,13 @@ +from typing import Literal + +from hera.workflows import Steps, Workflow, script + + +@script(constructor="runner") +def literal_str(my_str: Literal["foo", "bar"]) -> Literal[1, 2]: + return {"foo": 1, "bar": 2}[my_str] + + +with Workflow(name="my-workflow", entrypoint="steps") as w: + with Steps(name="steps"): + literal_str() diff --git a/tests/script_annotations/pydantic_io_literals.py b/tests/script_annotations/pydantic_io_literals.py new file mode 100644 index 000000000..057a8f65c --- /dev/null +++ b/tests/script_annotations/pydantic_io_literals.py @@ -0,0 +1,24 @@ +from typing import Literal + +from hera.shared import global_config +from hera.workflows import Input, Output, Steps, Workflow, script + +global_config.experimental_features["script_pydantic_io"] = True + + +class ExampleInput(Input): + my_str: Literal["foo", "bar"] + + +class ExampleOutput(Output): + index: Literal[1, 2] + + +@script(constructor="runner") +def literal_str(input: ExampleInput) -> ExampleOutput: + return ExampleOutput(index={"foo": 1, "bar": 2}[input.my_str]) + + +with Workflow(name="my-workflow", entrypoint="steps") as w: + with Steps(name="steps"): + literal_str() diff --git a/tests/test_script_annotations.py b/tests/test_script_annotations.py index 314269b0a..30a871437 100644 --- a/tests/test_script_annotations.py +++ b/tests/test_script_annotations.py @@ -476,3 +476,33 @@ def test_script_with_param(global_config_fixture, module_name): } ] assert consume_task["withParam"] == "{{tasks.generate.outputs.parameters.some-values}}" + + +@pytest.mark.parametrize( + ("module_name", "input_name"), + [ + pytest.param("tests.script_annotations.literals", "my_str", id="bare-type-annotation"), + pytest.param("tests.script_annotations.annotated_literals", "my-str", id="annotated"), + pytest.param("tests.script_annotations.pydantic_io_literals", "my_str", id="pydantic-io"), + ], +) +def test_script_literals(global_config_fixture, module_name, input_name): + """Test that Literals work correctly as direct type annotations.""" + # GIVEN + global_config_fixture.experimental_features["script_annotations"] = True + + # Force a reload of the test module, as the runner performs "importlib.import_module", which + # may fetch a cached version + module = importlib.import_module(module_name) + importlib.reload(module) + workflow: Workflow = importlib.import_module(module.__name__).w + + # WHEN + workflow_dict = workflow.to_dict() + assert workflow == Workflow.from_dict(workflow_dict) + assert workflow == Workflow.from_yaml(workflow.to_yaml()) + + # THEN + (literal_str,) = (t for t in workflow_dict["spec"]["templates"] if t["name"] == "literal-str") + + assert literal_str["inputs"]["parameters"] == [{"name": input_name, "enum": ["foo", "bar"]}] From 2831dfd76f79b0269ef8f02c14000ce55c61618f Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Mon, 28 Oct 2024 11:11:48 +0000 Subject: [PATCH 4/6] Also copy Literals outside of experimental feature Signed-off-by: Alice Purcell --- src/hera/shared/_type_util.py | 7 +++++++ src/hera/workflows/script.py | 2 ++ tests/test_script_annotations.py | 6 ++++-- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index 7a1ff10f9..3dc7138ca 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -83,6 +83,13 @@ def get_workflow_annotation(annotation: Any) -> Optional[Union[Artifact, Paramet return metadata[0] +def add_metadata_from_type(parameter: Parameter, annotation: Any) -> None: + if not parameter.enum: + type_ = unwrap_annotation(annotation) + if get_origin(type_) is Literal: + parameter.enum = list(get_args(type_)) + + def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Parameter, Artifact]: """Constructs a Parameter or Artifact object based on annotations. diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index faac62962..0974e4885 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -48,6 +48,7 @@ ) from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator from hera.shared._type_util import ( + add_metadata_from_type, construct_io_from_annotation, get_workflow_annotation, is_subscripted, @@ -379,6 +380,7 @@ def _get_parameters_from_callable(source: Callable) -> List[Parameter]: default = MISSING param = Parameter(name=p.name, default=default) + add_metadata_from_type(param, p.annotation) parameters.append(param) return parameters diff --git a/tests/test_script_annotations.py b/tests/test_script_annotations.py index 30a871437..93fe49e36 100644 --- a/tests/test_script_annotations.py +++ b/tests/test_script_annotations.py @@ -486,10 +486,12 @@ def test_script_with_param(global_config_fixture, module_name): pytest.param("tests.script_annotations.pydantic_io_literals", "my_str", id="pydantic-io"), ], ) -def test_script_literals(global_config_fixture, module_name, input_name): +@pytest.mark.parametrize("experimental_feature", ["", "script_annotations", "script_pydantic_io"]) +def test_script_literals(global_config_fixture, module_name, input_name, experimental_feature): """Test that Literals work correctly as direct type annotations.""" # GIVEN - global_config_fixture.experimental_features["script_annotations"] = True + if experimental_feature: + global_config_fixture.experimental_features[experimental_feature] = True # Force a reload of the test module, as the runner performs "importlib.import_module", which # may fetch a cached version From fa6a37fac76fe4ce341757a22115e77ba7a56197 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Wed, 6 Nov 2024 09:20:35 +0000 Subject: [PATCH 5/6] Reuse add_metadata_from_type Signed-off-by: Alice Purcell --- src/hera/shared/_type_util.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index 3dc7138ca..cfb9c51ef 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -106,10 +106,8 @@ def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Par io = Parameter() io.name = io.name or python_name - if isinstance(io, Parameter) and not io.enum: - type_ = unwrap_annotation(annotation) - if get_origin(type_) is Literal: - io.enum = list(get_args(type_)) + if isinstance(io, Parameter): + add_metadata_from_type(io, annotation) return io From dbb4a6609ae3fc1feb338ddadb06ba629f49bdb8 Mon Sep 17 00:00:00 2001 From: Alice Purcell Date: Wed, 6 Nov 2024 09:25:58 +0000 Subject: [PATCH 6/6] Rename new function to set_enum_based_on_type This function only support setting the enum field, and it doesn't appear likely that we will want to set other fields in future, so make the name more specific. Additionally, add a docstring, and refactor the function to use the early-return pattern now adding other metadata is ruled out. Signed-off-by: Alice Purcell --- src/hera/shared/_type_util.py | 17 +++++++++++------ src/hera/workflows/script.py | 4 ++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index cfb9c51ef..f1c27f405 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -83,11 +83,16 @@ def get_workflow_annotation(annotation: Any) -> Optional[Union[Artifact, Paramet return metadata[0] -def add_metadata_from_type(parameter: Parameter, annotation: Any) -> None: - if not parameter.enum: - type_ = unwrap_annotation(annotation) - if get_origin(type_) is Literal: - parameter.enum = list(get_args(type_)) +def set_enum_based_on_type(parameter: Parameter, annotation: Any) -> None: + """Sets the enum field of a Parameter based on its type annotation. + + Currently, only supports Literals. + """ + if parameter.enum: + return + type_ = unwrap_annotation(annotation) + if get_origin(type_) is Literal: + parameter.enum = list(get_args(type_)) def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Parameter, Artifact]: @@ -107,7 +112,7 @@ def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Par io.name = io.name or python_name if isinstance(io, Parameter): - add_metadata_from_type(io, annotation) + set_enum_based_on_type(io, annotation) return io diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index 0974e4885..2065c438c 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -48,11 +48,11 @@ ) from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator from hera.shared._type_util import ( - add_metadata_from_type, construct_io_from_annotation, get_workflow_annotation, is_subscripted, origin_type_issupertype, + set_enum_based_on_type, ) from hera.shared.serialization import serialize from hera.workflows._context import _context @@ -380,7 +380,7 @@ def _get_parameters_from_callable(source: Callable) -> List[Parameter]: default = MISSING param = Parameter(name=p.name, default=default) - add_metadata_from_type(param, p.annotation) + set_enum_based_on_type(param, p.annotation) parameters.append(param) return parameters