Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Literal types in script runner #1249

Merged
merged 7 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions src/hera/shared/_type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Any,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
Expand Down Expand Up @@ -82,6 +83,13 @@
return metadata[0]


def add_metadata_from_type(parameter: Parameter, annotation: Any) -> None:
alicederyn marked this conversation as resolved.
Show resolved Hide resolved
if not parameter.enum:
type_ = unwrap_annotation(annotation)

Check warning on line 88 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L88

Added line #L88 was not covered by tests
if get_origin(type_) is Literal:
parameter.enum = list(get_args(type_))

Check warning on line 90 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L90

Added line #L90 was not covered by tests


def construct_io_from_annotation(python_name: str, annotation: Any) -> Union[Parameter, Artifact]:
"""Constructs a Parameter or Artifact object based on annotations.

Expand All @@ -91,13 +99,19 @@
For a function parameter, python_name should be the parameter name.
For a Pydantic Input or Output class, python_name should be the field name.
"""
if annotation := get_workflow_annotation(annotation):
if workflow_annotation := get_workflow_annotation(annotation):
# Copy so as to not modify the fields themselves
annotation_copy = annotation.copy()
annotation_copy.name = annotation.name or python_name
return annotation_copy
io = workflow_annotation.copy()

Check warning on line 104 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L104

Added line #L104 was not covered by tests
else:
io = Parameter()

Check warning on line 106 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L106

Added line #L106 was not covered by tests

io.name = io.name or python_name

Check warning on line 108 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L108

Added line #L108 was not covered by tests
if isinstance(io, Parameter) and not io.enum:
type_ = unwrap_annotation(annotation)

Check warning on line 110 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L110

Added line #L110 was not covered by tests
if get_origin(type_) is Literal:
io.enum = list(get_args(type_))

Check warning on line 112 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L112

Added line #L112 was not covered by tests
alicederyn marked this conversation as resolved.
Show resolved Hide resolved

return Parameter(name=python_name)
return io

Check warning on line 114 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L114

Added line #L114 was not covered by tests


def get_unsubscripted_type(t: Any) -> Any:
Expand All @@ -120,6 +134,8 @@
origin_type = get_unsubscripted_type(unwrapped_type)
if origin_type is Union or origin_type is UnionType:
return all(origin_type_issubtype(arg, type_) for arg in get_args(unwrapped_type))
if origin_type is Literal:
return all(isinstance(value, type_) for value in get_args(unwrapped_type))

Check warning on line 138 in src/hera/shared/_type_util.py

View check run for this annotation

Codecov / codecov/patch

src/hera/shared/_type_util.py#L138

Added line #L138 was not covered by tests
return isinstance(origin_type, type) and issubclass(origin_type, type_)


Expand Down
46 changes: 16 additions & 30 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator
from hera.shared._type_util import (
add_metadata_from_type,
construct_io_from_annotation,
get_workflow_annotation,
is_subscripted,
Expand Down Expand Up @@ -379,6 +380,7 @@
default = MISSING

param = Parameter(name=p.name, default=default)
add_metadata_from_type(param, p.annotation)

Check warning on line 383 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L383

Added line #L383 was not covered by tests
sambhav marked this conversation as resolved.
Show resolved Hide resolved
parameters.append(param)

return parameters
Expand Down Expand Up @@ -495,22 +497,18 @@

artifacts.extend(input_class._get_artifacts(add_missing_path=True))

elif param_or_artifact := get_workflow_annotation(func_param.annotation):
if param_or_artifact.output:
else:
io = construct_io_from_annotation(func_param.name, func_param.annotation)

Check warning on line 501 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L501

Added line #L501 was not covered by tests
if io.output:
continue

# Create a new object so we don't modify the Workflow itself
new_object = param_or_artifact.copy()
if not new_object.name:
new_object.name = func_param.name

if isinstance(new_object, Artifact):
if new_object.path is None:
new_object.path = new_object._get_default_inputs_path()
if isinstance(io, Artifact):
if io.path is None:
io.path = io._get_default_inputs_path()

Check warning on line 507 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L507

Added line #L507 was not covered by tests

artifacts.append(new_object)
elif isinstance(new_object, Parameter):
if new_object.default is not None:
artifacts.append(io)

Check warning on line 509 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L509

Added line #L509 was not covered by tests
elif isinstance(io, Parameter):
if io.default is not None:
# TODO: in 5.18 remove the flag check and `warn`, and raise the ValueError directly (minus "flag" text)
warnings.warn(
"Using the default field for Parameters in Annotations is deprecated since v5.16"
Expand All @@ -524,27 +522,15 @@
)
if func_param.default != inspect.Parameter.empty:
# TODO: remove this check in 5.18:
if new_object.default is not None:
if io.default is not None:
raise ValueError(
"default cannot be set via both the function parameter default and the Parameter's default"
)
new_object.default = serialize(func_param.default)
parameters.append(new_object)
else:
if (
func_param.default != inspect.Parameter.empty
and func_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
):
default = func_param.default
else:
default = MISSING

if origin_type_issupertype(func_param.annotation, NoneType) and (
default is MISSING or default is not None
):
raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.")
io.default = serialize(func_param.default)

Check warning on line 529 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L529

Added line #L529 was not covered by tests

parameters.append(Parameter(name=func_param.name, default=default))
if origin_type_issupertype(func_param.annotation, NoneType) and io.default != "null":
raise ValueError(f"Optional parameter '{func_param.name}' must have a default value of None.")
parameters.append(io)

Check warning on line 533 in src/hera/workflows/script.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/script.py#L532-L533

Added lines #L532 - L533 were not covered by tests

return parameters, artifacts

Expand Down
18 changes: 18 additions & 0 deletions tests/script_annotations/annotated_literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Annotated, Literal

from hera.shared import global_config
from hera.workflows import Parameter, Steps, Workflow, script

global_config.experimental_features["script_annotations"] = True


@script(constructor="runner")
def literal_str(
my_str: Annotated[Literal["foo", "bar"], Parameter(name="my-str")],
) -> Annotated[Literal[1, 2], Parameter(name="index")]:
return {"foo": 1, "bar": 2}[my_str]


with Workflow(name="my-workflow", entrypoint="steps") as w:
with Steps(name="steps"):
literal_str()
13 changes: 13 additions & 0 deletions tests/script_annotations/literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Literal

from hera.workflows import Steps, Workflow, script


@script(constructor="runner")
def literal_str(my_str: Literal["foo", "bar"]) -> Literal[1, 2]:
return {"foo": 1, "bar": 2}[my_str]


with Workflow(name="my-workflow", entrypoint="steps") as w:
with Steps(name="steps"):
literal_str()
24 changes: 24 additions & 0 deletions tests/script_annotations/pydantic_io_literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Literal

from hera.shared import global_config
from hera.workflows import Input, Output, Steps, Workflow, script

global_config.experimental_features["script_pydantic_io"] = True


class ExampleInput(Input):
my_str: Literal["foo", "bar"]


class ExampleOutput(Output):
index: Literal[1, 2]


@script(constructor="runner")
def literal_str(input: ExampleInput) -> ExampleOutput:
return ExampleOutput(index={"foo": 1, "bar": 2}[input.my_str])


with Workflow(name="my-workflow", entrypoint="steps") as w:
with Steps(name="steps"):
literal_str()
12 changes: 11 additions & 1 deletion tests/script_runner/parameter_inputs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, List, Union
from typing import Any, List, Literal, Union

try:
from typing import Annotated
Expand Down Expand Up @@ -53,6 +53,11 @@ def annotated_basic_types_with_other_metadata(
return Output(output=[Input(a=a_but_kebab, b=b_but_kebab)])


@script()
def annotated_str_literal(my_literal: Annotated[Literal["1", "2"], Parameter(name="str-literal")]) -> str:
return f"type given: {type(my_literal).__name__}"


@script()
def annotated_object(annotated_input_value: Annotated[Input, Parameter(name="input-value")]) -> Output:
return Output(output=[annotated_input_value])
Expand Down Expand Up @@ -81,6 +86,11 @@ def str_or_int_parameter(my_str_or_int: Union[int, str]) -> str:
return f"type given: {type(my_str_or_int).__name__}"


@script()
def str_literal(my_literal: Literal["1", "2"]) -> str:
return f"type given: {type(my_literal).__name__}"


@script()
def str_parameter_expects_jsonstr_dict(my_json_str: str) -> dict:
return json.loads(my_json_str)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@
"type given: int",
id="str-or-int-given-int",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_literal",
[{"name": "my_literal", "value": "1"}],
"type given: str",
id="str-literal",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
Expand All @@ -89,6 +95,12 @@
[{"my": "dict"}],
id="str-json-param-as-list",
),
pytest.param(
"tests.script_runner.parameter_inputs:annotated_str_literal",
[{"name": "my_literal", "value": "1"}],
"type given: str",
id="annotated-str-literal",
),
pytest.param(
"tests.script_runner.parameter_inputs:annotated_str_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
Expand Down
32 changes: 32 additions & 0 deletions tests/test_script_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,35 @@ def test_script_with_param(global_config_fixture, module_name):
}
]
assert consume_task["withParam"] == "{{tasks.generate.outputs.parameters.some-values}}"


@pytest.mark.parametrize(
("module_name", "input_name"),
[
pytest.param("tests.script_annotations.literals", "my_str", id="bare-type-annotation"),
pytest.param("tests.script_annotations.annotated_literals", "my-str", id="annotated"),
pytest.param("tests.script_annotations.pydantic_io_literals", "my_str", id="pydantic-io"),
],
)
@pytest.mark.parametrize("experimental_feature", ["", "script_annotations", "script_pydantic_io"])
def test_script_literals(global_config_fixture, module_name, input_name, experimental_feature):
"""Test that Literals work correctly as direct type annotations."""
# GIVEN
if experimental_feature:
global_config_fixture.experimental_features[experimental_feature] = True

# Force a reload of the test module, as the runner performs "importlib.import_module", which
# may fetch a cached version
module = importlib.import_module(module_name)
importlib.reload(module)
workflow: Workflow = importlib.import_module(module.__name__).w

# WHEN
workflow_dict = workflow.to_dict()
assert workflow == Workflow.from_dict(workflow_dict)
assert workflow == Workflow.from_yaml(workflow.to_yaml())

# THEN
(literal_str,) = (t for t in workflow_dict["spec"]["templates"] if t["name"] == "literal-str")

assert literal_str["inputs"]["parameters"] == [{"name": input_name, "enum": ["foo", "bar"]}]
9 changes: 9 additions & 0 deletions tests/test_unit/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,15 @@ def unknown_annotations_ignored(my_optional_string: Optional[str] = "123") -> st
_get_inputs_from_callable(unknown_annotations_ignored)


def test_invalid_script_when_optional_parameter_does_not_have_default_value_6():
@script()
def unknown_annotations_ignored(my_optional_string: Annotated[Optional[str], Parameter(name="my-string")]) -> str:
return "Got: {}".format(my_optional_string)

with pytest.raises(ValueError, match="Optional parameter 'my_optional_string' must have a default value of None."):
_get_inputs_from_callable(unknown_annotations_ignored)


def test_invalid_script_when_multiple_input_workflow_annotations_are_given():
@script()
def invalid_script(a_str: Annotated[str, Artifact(name="a_str"), Parameter(name="a_str")] = "123") -> str:
Expand Down
6 changes: 5 additions & 1 deletion tests/test_unit/test_shared_type_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import List, NoReturn, Optional, Union
from typing import List, Literal, NoReturn, Optional, Union

if sys.version_info >= (3, 9):
from typing import Annotated
Expand Down Expand Up @@ -151,6 +151,10 @@ def test_get_unsubscripted_type(annotation, expected):
pytest.param(Annotated[Optional[str], "foo"], (str, NoneType), True, id="annotated-optional"),
pytest.param(str, (str, NoneType), True, id="str-is-subtype-of-optional-str"),
pytest.param(Union[int, str], (str, NoneType), False, id="union-int-str-not-subtype-of-optional-str"),
pytest.param(Literal["foo", "bar"], (str, NoneType), True, id="literal-str-is-subtype-of-optional-str"),
pytest.param(Literal["foo", None], (str, NoneType), True, id="literal-none-is-subtype-of-optional-str"),
pytest.param(Literal[1, 2], (str, NoneType), False, id="literal-int-not-subtype-of-optional-str"),
pytest.param(Literal[1, "foo"], (str, NoneType), False, id="mixed-literal-not-subtype-of-optional-str"),
],
)
def test_origin_type_issubtype(annotation, target, expected):
Expand Down
Loading