diff --git a/src/hera/shared/_type_util.py b/src/hera/shared/_type_util.py index fb36f542..f1c27f40 100644 --- a/src/hera/shared/_type_util.py +++ b/src/hera/shared/_type_util.py @@ -6,6 +6,7 @@ Any, Iterable, List, + Literal, Optional, Tuple, Type, @@ -82,6 +83,18 @@ def get_workflow_annotation(annotation: Any) -> Optional[Union[Artifact, Paramet return metadata[0] +def set_enum_based_on_type(parameter: Parameter, annotation: Any) -> None: + """Sets the enum field of a Parameter based on its type annotation. + + Currently, only supports Literals. + """ + if parameter.enum: + return + type_ = unwrap_annotation(annotation) + if get_origin(type_) is Literal: + parameter.enum = list(get_args(type_)) + + def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Parameter, Artifact]: """Constructs a Parameter or Artifact object based on annotations. @@ -91,13 +104,17 @@ def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Par For a function parameter, python_name should be the parameter name. For a Pydantic Input or Output class, python_name should be the field name. """ - if annotation := get_workflow_annotation(annotation): + if workflow_annotation := get_workflow_annotation(annotation): # Copy so as to not modify the fields themselves - annotation_copy = annotation.copy() - annotation_copy.name = annotation.name or python_name - return annotation_copy + io = workflow_annotation.copy() + else: + io = Parameter() + + io.name = io.name or python_name + if isinstance(io, Parameter): + set_enum_based_on_type(io, annotation) - return Parameter(name=python_name) + return io def get_unsubscripted_type(t: Any) -> Any: @@ -120,6 +137,8 @@ def origin_type_issubtype(annotation: Any, type_: Union[type, Tuple[type, ...]]) origin_type = get_unsubscripted_type(unwrapped_type) if origin_type is Union or origin_type is UnionType: return all(origin_type_issubtype(arg, type_) for arg in get_args(unwrapped_type)) + if origin_type is Literal: + return all(isinstance(value, type_) for value in get_args(unwrapped_type)) return isinstance(origin_type, type) and issubclass(origin_type, type_) diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index e9ade1fc..2065c438 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -52,6 +52,7 @@ get_workflow_annotation, is_subscripted, origin_type_issupertype, + set_enum_based_on_type, ) from hera.shared.serialization import serialize from hera.workflows._context import _context @@ -379,6 +380,7 @@ def _get_parameters_from_callable(source: Callable) -> List[Parameter]: default = MISSING param = Parameter(name=p.name, default=default) + set_enum_based_on_type(param, p.annotation) parameters.append(param) return parameters @@ -495,22 +497,18 @@ class will be used as inputs, rather than the class itself. artifacts.extend(input_class._get_artifacts(add_missing_path=True)) - elif param_or_artifact := get_workflow_annotation(func_param.annotation): - if param_or_artifact.output: + else: + io = construct_io_from_annotation(func_param.name, func_param.annotation) + if io.output: continue - # Create a new object so we don't modify the Workflow itself - new_object = param_or_artifact.copy() - if not new_object.name: - new_object.name = func_param.name - - if isinstance(new_object, Artifact): - if new_object.path is None: - new_object.path = new_object._get_default_inputs_path() + if isinstance(io, Artifact): + if io.path is None: + io.path = io._get_default_inputs_path() - artifacts.append(new_object) - elif isinstance(new_object, Parameter): - if new_object.default is not None: + artifacts.append(io) + elif isinstance(io, Parameter): + if io.default is not None: # TODO: in 5.18 remove the flag check and `warn`, and raise the ValueError directly (minus "flag" text) warnings.warn( "Using the default field for Parameters in Annotations is deprecated since v5.16" @@ -524,27 +522,15 @@ class will be used as inputs, rather than the class itself. ) if func_param.default != inspect.Parameter.empty: # TODO: remove this check in 5.18: - if new_object.default is not None: + if io.default is not None: raise ValueError( "default cannot be set via both the function parameter default and the Parameter's default" ) - new_object.default = serialize(func_param.default) - parameters.append(new_object) - else: - if ( - func_param.default != inspect.Parameter.empty - and func_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - default = func_param.default - else: - default = MISSING - - if origin_type_issupertype(func_param.annotation, NoneType) and ( - default is MISSING or default is not None - ): - raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.") + io.default = serialize(func_param.default) - parameters.append(Parameter(name=func_param.name, default=default)) + if origin_type_issupertype(func_param.annotation, NoneType) and io.default != "null": + raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.") + parameters.append(io) return parameters, artifacts diff --git a/tests/script_annotations/annotated_literals.py b/tests/script_annotations/annotated_literals.py new file mode 100644 index 00000000..d15787fc --- /dev/null +++ b/tests/script_annotations/annotated_literals.py @@ -0,0 +1,18 @@ +from typing import Annotated, Literal + +from hera.shared import global_config +from hera.workflows import Parameter, Steps, Workflow, script + +global_config.experimental_features["script_annotations"] = True + + +@script(constructor="runner") +def literal_str( + my_str: Annotated[Literal["foo", "bar"], Parameter(name="my-str")], +) -> Annotated[Literal[1, 2], Parameter(name="index")]: + return {"foo": 1, "bar": 2}[my_str] + + +with Workflow(name="my-workflow", entrypoint="steps") as w: + with Steps(name="steps"): + literal_str() diff --git a/tests/script_annotations/literals.py b/tests/script_annotations/literals.py new file mode 100644 index 00000000..c1756f39 --- /dev/null +++ b/tests/script_annotations/literals.py @@ -0,0 +1,13 @@ +from typing import Literal + +from hera.workflows import Steps, Workflow, script + + +@script(constructor="runner") +def literal_str(my_str: Literal["foo", "bar"]) -> Literal[1, 2]: + return {"foo": 1, "bar": 2}[my_str] + + +with Workflow(name="my-workflow", entrypoint="steps") as w: + with Steps(name="steps"): + literal_str() diff --git a/tests/script_annotations/pydantic_io_literals.py b/tests/script_annotations/pydantic_io_literals.py new file mode 100644 index 00000000..057a8f65 --- /dev/null +++ b/tests/script_annotations/pydantic_io_literals.py @@ -0,0 +1,24 @@ +from typing import Literal + +from hera.shared import global_config +from hera.workflows import Input, Output, Steps, Workflow, script + +global_config.experimental_features["script_pydantic_io"] = True + + +class ExampleInput(Input): + my_str: Literal["foo", "bar"] + + +class ExampleOutput(Output): + index: Literal[1, 2] + + +@script(constructor="runner") +def literal_str(input: ExampleInput) -> ExampleOutput: + return ExampleOutput(index={"foo": 1, "bar": 2}[input.my_str]) + + +with Workflow(name="my-workflow", entrypoint="steps") as w: + with Steps(name="steps"): + literal_str() diff --git a/tests/script_runner/parameter_inputs.py b/tests/script_runner/parameter_inputs.py index 43cff08b..e954767e 100644 --- a/tests/script_runner/parameter_inputs.py +++ b/tests/script_runner/parameter_inputs.py @@ -1,5 +1,5 @@ import json -from typing import Any, List, Union +from typing import Any, List, Literal, Union try: from typing import Annotated @@ -53,6 +53,11 @@ def annotated_basic_types_with_other_metadata( return Output(output=[Input(a=a_but_kebab, b=b_but_kebab)]) +@script() +def annotated_str_literal(my_literal: Annotated[Literal["1", "2"], Parameter(name="str-literal")]) -> str: + return f"type given: {type(my_literal).__name__}" + + @script() def annotated_object(annotated_input_value: Annotated[Input, Parameter(name="input-value")]) -> Output: return Output(output=[annotated_input_value]) @@ -81,6 +86,11 @@ def str_or_int_parameter(my_str_or_int: Union[int, str]) -> str: return f"type given: {type(my_str_or_int).__name__}" +@script() +def str_literal(my_literal: Literal["1", "2"]) -> str: + return f"type given: {type(my_literal).__name__}" + + @script() def str_parameter_expects_jsonstr_dict(my_json_str: str) -> dict: return json.loads(my_json_str) diff --git a/tests/test_runner.py b/tests/test_runner.py index f6154669..83a60b43 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -77,6 +77,12 @@ "type given: int", id="str-or-int-given-int", ), + pytest.param( + "tests.script_runner.parameter_inputs:str_literal", + [{"name": "my_literal", "value": "1"}], + "type given: str", + id="str-literal", + ), pytest.param( "tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict", [{"name": "my_json_str", "value": json.dumps({"my": "dict"})}], @@ -89,6 +95,12 @@ [{"my": "dict"}], id="str-json-param-as-list", ), + pytest.param( + "tests.script_runner.parameter_inputs:annotated_str_literal", + [{"name": "my_literal", "value": "1"}], + "type given: str", + id="annotated-str-literal", + ), pytest.param( "tests.script_runner.parameter_inputs:annotated_str_parameter_expects_jsonstr_dict", [{"name": "my_json_str", "value": json.dumps({"my": "dict"})}], diff --git a/tests/test_script_annotations.py b/tests/test_script_annotations.py index 314269b0..93fe49e3 100644 --- a/tests/test_script_annotations.py +++ b/tests/test_script_annotations.py @@ -476,3 +476,35 @@ def test_script_with_param(global_config_fixture, module_name): } ] assert consume_task["withParam"] == "{{tasks.generate.outputs.parameters.some-values}}" + + +@pytest.mark.parametrize( + ("module_name", "input_name"), + [ + pytest.param("tests.script_annotations.literals", "my_str", id="bare-type-annotation"), + pytest.param("tests.script_annotations.annotated_literals", "my-str", id="annotated"), + pytest.param("tests.script_annotations.pydantic_io_literals", "my_str", id="pydantic-io"), + ], +) +@pytest.mark.parametrize("experimental_feature", ["", "script_annotations", "script_pydantic_io"]) +def test_script_literals(global_config_fixture, module_name, input_name, experimental_feature): + """Test that Literals work correctly as direct type annotations.""" + # GIVEN + if experimental_feature: + global_config_fixture.experimental_features[experimental_feature] = True + + # Force a reload of the test module, as the runner performs "importlib.import_module", which + # may fetch a cached version + module = importlib.import_module(module_name) + importlib.reload(module) + workflow: Workflow = importlib.import_module(module.__name__).w + + # WHEN + workflow_dict = workflow.to_dict() + assert workflow == Workflow.from_dict(workflow_dict) + assert workflow == Workflow.from_yaml(workflow.to_yaml()) + + # THEN + (literal_str,) = (t for t in workflow_dict["spec"]["templates"] if t["name"] == "literal-str") + + assert literal_str["inputs"]["parameters"] == [{"name": input_name, "enum": ["foo", "bar"]}] diff --git a/tests/test_unit/test_script.py b/tests/test_unit/test_script.py index 46bc5c32..b3e745ce 100644 --- a/tests/test_unit/test_script.py +++ b/tests/test_unit/test_script.py @@ -182,6 +182,15 @@ def unknown_annotations_ignored(my_optional_string: Optional[str] = "123") -> st _get_inputs_from_callable(unknown_annotations_ignored) +def test_invalid_script_when_optional_parameter_does_not_have_default_value_6(): + @script() + def unknown_annotations_ignored(my_optional_string: Annotated[Optional[str], Parameter(name="my-string")]) -> str: + return "Got: {}".format(my_optional_string) + + with pytest.raises(ValueError, match="Optional parameter 'my_optional_string' must have a default value of None."): + _get_inputs_from_callable(unknown_annotations_ignored) + + def test_invalid_script_when_multiple_input_workflow_annotations_are_given(): @script() def invalid_script(a_str: Annotated[str, Artifact(name="a_str"), Parameter(name="a_str")] = "123") -> str: diff --git a/tests/test_unit/test_shared_type_utils.py b/tests/test_unit/test_shared_type_utils.py index e777e0b4..1b31907e 100644 --- a/tests/test_unit/test_shared_type_utils.py +++ b/tests/test_unit/test_shared_type_utils.py @@ -1,5 +1,5 @@ import sys -from typing import List, NoReturn, Optional, Union +from typing import List, Literal, NoReturn, Optional, Union if sys.version_info >= (3, 9): from typing import Annotated @@ -151,6 +151,10 @@ def test_get_unsubscripted_type(annotation, expected): pytest.param(Annotated[Optional[str], "foo"], (str, NoneType), True, id="annotated-optional"), pytest.param(str, (str, NoneType), True, id="str-is-subtype-of-optional-str"), pytest.param(Union[int, str], (str, NoneType), False, id="union-int-str-not-subtype-of-optional-str"), + pytest.param(Literal["foo", "bar"], (str, NoneType), True, id="literal-str-is-subtype-of-optional-str"), + pytest.param(Literal["foo", None], (str, NoneType), True, id="literal-none-is-subtype-of-optional-str"), + pytest.param(Literal[1, 2], (str, NoneType), False, id="literal-int-not-subtype-of-optional-str"), + pytest.param(Literal[1, "foo"], (str, NoneType), False, id="mixed-literal-not-subtype-of-optional-str"), ], ) def test_origin_type_issubtype(annotation, target, expected):