diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index 6684099a..691af297 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -883,7 +883,11 @@ def transform_script_template_post_build( script_env.append(EnvVar(name="hera__script_pydantic_io", value="")) if script_env: - script.env = script_env + if not script.env: + # If user did not set any env vars themselves then we need to initialise the list + script.env = [] + + script.env.extend(script_env) return script diff --git a/tests/test_unit/test_script.py b/tests/test_unit/test_script.py index b3e745ce..d62bc73a 100644 --- a/tests/test_unit/test_script.py +++ b/tests/test_unit/test_script.py @@ -1,12 +1,26 @@ from pathlib import Path -from typing import Annotated, Dict, Optional, Union +from typing import Annotated, Dict, List, Optional, Union, cast import pytest -from hera.workflows import Output, Workflow, script +from hera.shared._global_config import _GlobalConfig +from hera.workflows._mixins import EnvT from hera.workflows.artifact import Artifact +from hera.workflows.env import Env +from hera.workflows.io import Output +from hera.workflows.models import ( + EnvVar as ModelEnvVar, + ScriptTemplate, + Workflow as ModelWorkflow, +) from hera.workflows.parameter import Parameter -from hera.workflows.script import _get_inputs_from_callable, _get_outputs_from_return_annotation +from hera.workflows.script import ( + RunnerScriptConstructor, + _get_inputs_from_callable, + _get_outputs_from_return_annotation, + script, +) +from hera.workflows.workflow import Workflow def test_get_inputs_from_callable_simple_params(): @@ -243,3 +257,128 @@ def invalid_script(a_str: str = "123") -> Annotated[str, Artifact(name="a_str"), with pytest.raises(ValueError, match="Annotation metadata cannot contain more than one Artifact/Parameter."): _get_outputs_from_return_annotation(invalid_script, None) + + +class TestRunnerScriptEnv: + @staticmethod + def build_workflow(script_env, constructor=None) -> ModelWorkflow: + constructor = constructor if constructor is not None else RunnerScriptConstructor() + + @script(constructor=constructor, env=script_env) + def my_script(): + pass + + with Workflow(name="test") as w: + my_script() + + return cast(ModelWorkflow, w.build()) + + @pytest.mark.parametrize( + "user_env,expected_env", + ( + [ + None, + None, + ], + [ + Env(name="my_env_var", value=42), + [ModelEnvVar(name="my_env_var", value=42)], + ], + ), + ) + def test_runner_script_no_added_env_vars(self, user_env: EnvT, expected_env): + built_workflow = self.build_workflow(user_env) + + script_template = cast(ScriptTemplate, built_workflow.spec.templates[0].script) + assert script_template is not None + assert script_template.env == expected_env + + @pytest.mark.parametrize( + "user_env,expected_env", + ( + [ + None, + [ + ModelEnvVar(name="hera__outputs_directory", value="/my/tmp/dir"), + ], + ], + [ + Env(name="my_env_var", value=42), + [ + ModelEnvVar(name="my_env_var", value=42), + ModelEnvVar(name="hera__outputs_directory", value="/my/tmp/dir"), + ], + ], + ), + ) + def test_runner_script_output_dir_env_var(self, user_env: EnvT, expected_env: Optional[List[ModelEnvVar]]): + # GIVEN + constructor = RunnerScriptConstructor(outputs_directory="/my/tmp/dir") + + built_workflow = self.build_workflow(user_env, constructor) + + script_template = cast(ScriptTemplate, built_workflow.spec.templates[0].script) + assert script_template is not None + assert script_template.env == expected_env + + @pytest.mark.parametrize( + "user_env,expected_env", + ( + [ + None, + [ + ModelEnvVar(name="hera__pydantic_mode", value="1"), + ], + ], + [ + Env(name="my_env_var", value=42), + [ + ModelEnvVar(name="my_env_var", value=42), + ModelEnvVar(name="hera__pydantic_mode", value="1"), + ], + ], + ), + ) + def test_runner_script_pydantic_mode_env_var(self, user_env: EnvT, expected_env: Optional[List[ModelEnvVar]]): + # GIVEN + constructor = RunnerScriptConstructor(pydantic_mode=1) + + built_workflow = self.build_workflow(user_env, constructor) + + script_template = cast(ScriptTemplate, built_workflow.spec.templates[0].script) + assert script_template is not None + assert script_template.env == expected_env + + @pytest.mark.parametrize( + "user_env,expected_env", + ( + [ + None, + [ + ModelEnvVar(name="hera__script_pydantic_io", value=""), + ], + ], + [ + Env(name="my_env_var", value=42), + [ + ModelEnvVar(name="my_env_var", value=42), + ModelEnvVar(name="hera__script_pydantic_io", value=""), + ], + ], + ), + ) + def test_runner_script_pydantic_io_env_var( + self, + global_config_fixture: _GlobalConfig, + user_env: EnvT, + expected_env: Optional[List[ModelEnvVar]], + ): + # GIVEN + global_config_fixture.experimental_features["script_pydantic_io"] = True + constructor = RunnerScriptConstructor() + + built_workflow = self.build_workflow(user_env, constructor) + + script_template = cast(ScriptTemplate, built_workflow.spec.templates[0].script) + assert script_template is not None + assert script_template.env == expected_env