Skip to content

Commit

Permalink
Fix runner script env var construction (#1289)
Browse files Browse the repository at this point in the history
* If user set their own env var we did not keep it

---------

Signed-off-by: Elliot Gunton <[email protected]>
  • Loading branch information
elliotgunton authored Dec 9, 2024
1 parent 25ac420 commit 9a8f253
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 4 deletions.
6 changes: 5 additions & 1 deletion src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,11 @@ def transform_script_template_post_build(
script_env.append(EnvVar(name="hera__script_pydantic_io", value=""))

if script_env:
script.env = script_env
if not script.env:
# If user did not set any env vars themselves then we need to initialise the list
script.env = []

script.env.extend(script_env)

return script

Expand Down
145 changes: 142 additions & 3 deletions tests/test_unit/test_script.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
from pathlib import Path
from typing import Annotated, Dict, Optional, Union
from typing import Annotated, Dict, List, Optional, Union, cast

import pytest

from hera.workflows import Output, Workflow, script
from hera.shared._global_config import _GlobalConfig
from hera.workflows._mixins import EnvT
from hera.workflows.artifact import Artifact
from hera.workflows.env import Env
from hera.workflows.io import Output
from hera.workflows.models import (
EnvVar as ModelEnvVar,
ScriptTemplate,
Workflow as ModelWorkflow,
)
from hera.workflows.parameter import Parameter
from hera.workflows.script import _get_inputs_from_callable, _get_outputs_from_return_annotation
from hera.workflows.script import (
RunnerScriptConstructor,
_get_inputs_from_callable,
_get_outputs_from_return_annotation,
script,
)
from hera.workflows.workflow import Workflow


def test_get_inputs_from_callable_simple_params():
Expand Down Expand Up @@ -243,3 +257,128 @@ def invalid_script(a_str: str = "123") -> Annotated[str, Artifact(name="a_str"),

with pytest.raises(ValueError, match="Annotation metadata cannot contain more than one Artifact/Parameter."):
_get_outputs_from_return_annotation(invalid_script, None)


class TestRunnerScriptEnv:
@staticmethod
def build_workflow(script_env, constructor=None) -> ModelWorkflow:
constructor = constructor if constructor is not None else RunnerScriptConstructor()

@script(constructor=constructor, env=script_env)
def my_script():
pass

with Workflow(name="test") as w:
my_script()

return cast(ModelWorkflow, w.build())

@pytest.mark.parametrize(
"user_env,expected_env",
(
[
None,
None,
],
[
Env(name="my_env_var", value=42),
[ModelEnvVar(name="my_env_var", value=42)],
],
),
)
def test_runner_script_no_added_env_vars(self, user_env: EnvT, expected_env):
built_workflow = self.build_workflow(user_env)

script_template = cast(ScriptTemplate, built_workflow.spec.templates[0].script)
assert script_template is not None
assert script_template.env == expected_env

@pytest.mark.parametrize(
"user_env,expected_env",
(
[
None,
[
ModelEnvVar(name="hera__outputs_directory", value="/my/tmp/dir"),
],
],
[
Env(name="my_env_var", value=42),
[
ModelEnvVar(name="my_env_var", value=42),
ModelEnvVar(name="hera__outputs_directory", value="/my/tmp/dir"),
],
],
),
)
def test_runner_script_output_dir_env_var(self, user_env: EnvT, expected_env: Optional[List[ModelEnvVar]]):
# GIVEN
constructor = RunnerScriptConstructor(outputs_directory="/my/tmp/dir")

built_workflow = self.build_workflow(user_env, constructor)

script_template = cast(ScriptTemplate, built_workflow.spec.templates[0].script)
assert script_template is not None
assert script_template.env == expected_env

@pytest.mark.parametrize(
"user_env,expected_env",
(
[
None,
[
ModelEnvVar(name="hera__pydantic_mode", value="1"),
],
],
[
Env(name="my_env_var", value=42),
[
ModelEnvVar(name="my_env_var", value=42),
ModelEnvVar(name="hera__pydantic_mode", value="1"),
],
],
),
)
def test_runner_script_pydantic_mode_env_var(self, user_env: EnvT, expected_env: Optional[List[ModelEnvVar]]):
# GIVEN
constructor = RunnerScriptConstructor(pydantic_mode=1)

built_workflow = self.build_workflow(user_env, constructor)

script_template = cast(ScriptTemplate, built_workflow.spec.templates[0].script)
assert script_template is not None
assert script_template.env == expected_env

@pytest.mark.parametrize(
"user_env,expected_env",
(
[
None,
[
ModelEnvVar(name="hera__script_pydantic_io", value=""),
],
],
[
Env(name="my_env_var", value=42),
[
ModelEnvVar(name="my_env_var", value=42),
ModelEnvVar(name="hera__script_pydantic_io", value=""),
],
],
),
)
def test_runner_script_pydantic_io_env_var(
self,
global_config_fixture: _GlobalConfig,
user_env: EnvT,
expected_env: Optional[List[ModelEnvVar]],
):
# GIVEN
global_config_fixture.experimental_features["script_pydantic_io"] = True
constructor = RunnerScriptConstructor()

built_workflow = self.build_workflow(user_env, constructor)

script_template = cast(ScriptTemplate, built_workflow.spec.templates[0].script)
assert script_template is not None
assert script_template.env == expected_env

0 comments on commit 9a8f253

Please sign in to comment.