Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

n/a #6848

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

n/a #6848

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tfx/dsl/component/experimental/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def _extract_func_args(
if name in arg_defaults and arg_defaults[name] is not None:
raise ValueError('beam Pipeline parameter does not allow default ',
'value other than None.')
elif arg_format == utils.ArgFormats.XFLOW_CONTEXT:
continue
else:
raise ValueError('Unknown argument format: %r' % (arg_format,))
return result
Expand Down
3 changes: 3 additions & 0 deletions tfx/dsl/component/experimental/function_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from tfx.dsl.component.experimental import annotations
from tfx.dsl.component.experimental import json_compat
from tfx.dsl.component.experimental import utils
from tfx.dsl.component.experimental import xflow_context
from tfx.types import artifact
from tfx.types import standard_artifacts
import typing_extensions
Expand Down Expand Up @@ -254,6 +255,8 @@ def _parse_signature(
'`InputArtifact[ArtifactType]` or `OutputArtifact[ArtifactType]` '
'typehint annotations.' % (arg, func)
)
elif arg_typehint == xflow_context.XflowContext:
arg_formats[arg] = utils.ArgFormats.XFLOW_CONTEXT
else:
raise ValueError(
'Unknown type hint annotation for argument %r on function %r'
Expand Down
2 changes: 2 additions & 0 deletions tfx/dsl/component/experimental/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class ArgFormats(enum.Enum):
BEAM_PARAMETER = 5
LIST_INPUT_ARTIFACTS = 6
PREVIOUS_OUTPUT_ARTIFACTS = 7
# Used for context param passed to Xflow components.
XFLOW_CONTEXT = 8


def assert_is_functype(func: Any) -> None:
Expand Down
23 changes: 23 additions & 0 deletions tfx/dsl/component/experimental/xflow_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Context for Xflow components."""

import abc


class XflowContext(metaclass=abc.ABCMeta):
  """Base type for the context object used by Xflow components.

  The class currently declares no fields or methods of its own.
  """

  # TODO(b/348507146): Add user-facing fields to XflowContext.
64 changes: 41 additions & 23 deletions tfx/orchestration/portable/python_executor_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sys
from typing import Optional, cast

from tfx import types
from tfx.dsl.components.base import base_executor
from tfx.dsl.io import fileio
from tfx.orchestration.portable import base_executor_operator
Expand All @@ -31,6 +32,39 @@
_STATEFUL_WORKING_DIR = 'stateful_working_dir'


def hydrate_value_artifacts(input_artifacts: dict[str, list[types.Artifact]]):
  """Loads the value of every ValueArtifact found among the given inputs.

  Args:
    input_artifacts: Mapping from input key to the list of artifacts stored
      under that key. Artifacts that are not ValueArtifact instances are left
      untouched.
  """
  for artifacts in input_artifacts.values():
    for candidate in artifacts:
      if not isinstance(candidate, ValueArtifact):
        continue
      # read() pulls the artifact's value into memory.
      candidate.read()


def construct_executor_output(
    execution_info: data_types.ExecutionInfo,
    output_dict: dict[str, list[types.Artifact]],
) -> execution_result_pb2.ExecutorOutput:
  """Constructs the final executor output.

  If a file exists at `execution_info.execution_output_uri`, its bytes are
  parsed as a serialized ExecutorOutput and returned. Otherwise a new
  ExecutorOutput is populated from `output_dict` and
  `execution_info.exec_properties`.

  Args:
    execution_info: Execution info providing `execution_output_uri` and
      `exec_properties`.
    output_dict: Output artifacts keyed by output name; only used when no
      executor output file exists.

  Returns:
    The ExecutorOutput read from the output file, or one built from the given
    output artifacts and execution properties.
  """
  if fileio.exists(execution_info.execution_output_uri):
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with fileio.open(execution_info.execution_output_uri, 'rb') as f:
      return execution_result_pb2.ExecutorOutput.FromString(f.read())

  # Old style TFX executor doesn't return executor_output, but modify
  # output_dict and exec_properties in place. For backward compatibility,
  # we use their executor_output and exec_properties to construct
  # ExecutorOutput.
  result = execution_result_pb2.ExecutorOutput()
  outputs_utils.populate_output_artifact(result, output_dict)
  outputs_utils.populate_exec_properties(
      result, execution_info.exec_properties
  )
  return result


def run_with_executor(
execution_info: data_types.ExecutionInfo,
executor: base_executor.BaseExecutor
Expand All @@ -44,31 +78,15 @@ def run_with_executor(
Returns:
The output from executor.
"""
for _, artifact_list in execution_info.input_dict.items():
for artifact in artifact_list:
if isinstance(artifact, ValueArtifact):
# Read ValueArtifact into memory.
artifact.read()
hydrate_value_artifacts(execution_info.input_dict)

output_dict = copy.deepcopy(execution_info.output_dict)
result = executor.Do(execution_info.input_dict, output_dict,
execution_info.exec_properties)
if not result:
# If result is not returned from the Do function, then try to
# read from the executor_output_uri.
if fileio.exists(execution_info.execution_output_uri):
result = execution_result_pb2.ExecutorOutput.FromString(
fileio.open(execution_info.execution_output_uri, 'rb').read())
else:
# Old style TFX executor doesn't return executor_output, but modify
# output_dict and exec_properties in place. For backward compatibility,
# we use their executor_output and exec_properties to construct
# ExecutorOutput.
result = execution_result_pb2.ExecutorOutput()
outputs_utils.populate_output_artifact(result, output_dict)
outputs_utils.populate_exec_properties(result,
execution_info.exec_properties)
return result
result = executor.Do(
execution_info.input_dict, output_dict, execution_info.exec_properties
)
if result:
return result
return construct_executor_output(execution_info, output_dict)


class PythonExecutorOperator(base_executor_operator.BaseExecutorOperator):
Expand Down