Skip to content

Commit 9c8fba7

Browse files
committed
make explicit arg for pip args
Signed-off-by: Kevin <[email protected]>
1 parent 6eeb1fb commit 9c8fba7

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

sdk/python/kubeflow/training/api/training_client.py

+2
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ def create_job(
354354
env_vars: Optional[
355355
Union[Dict[str, str], List[Union[models.V1EnvVar, models.V1EnvVar]]]
356356
] = None,
357+
pip_args: Optional[List[str]] = None,
357358
):
358359
"""Create the Training Job.
359360
Job can be created using one of the following options:
@@ -486,6 +487,7 @@ def create_job(
486487
train_func_parameters=parameters,
487488
packages_to_install=packages_to_install,
488489
pip_index_url=pip_index_url,
490+
pip_args=pip_args,
489491
)
490492

491493
# Get Training Container template.

sdk/python/kubeflow/training/utils/utils.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -110,21 +110,21 @@ def has_condition(conditions: List[models.V1JobCondition], condition_type: str)
110110

111111

112112
def get_script_for_python_packages(
113-
packages_to_install: List[str], pip_index_url: str
113+
packages_to_install: List[str], pip_index_url: str, pip_args: Optional[List[str]]
114114
) -> str:
115115
"""
116116
Get init script to install Python packages from the given pip index URL.
117117
"""
118118
packages_str = " ".join([str(package) for package in packages_to_install])
119-
119+
pip_args_str = " ".join(pip_args) if pip_args is not None else ""
120120
script_for_python_packages = textwrap.dedent(
121121
f"""
122122
if ! [ -x "$(command -v pip)" ]; then
123123
python -m ensurepip || python -m ensurepip --user || apt-get install python-pip
124124
fi
125125
126126
PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet \
127-
--no-warn-script-location --index-url {pip_index_url} {packages_str}
127+
--no-warn-script-location --index-url {pip_index_url} {pip_args_str} {packages_str}
128128
"""
129129
)
130130

@@ -137,6 +137,7 @@ def get_command_using_train_func(
137137
train_func_parameters: Optional[Dict[str, Any]] = None,
138138
packages_to_install: Optional[List[str]] = None,
139139
pip_index_url: str = constants.DEFAULT_PIP_INDEX_URL,
140+
pip_args: Optional[List[str]] = None
140141
) -> Tuple[List[str], List[str]]:
141142
"""
142143
Get container args and command from the given training function and parameters.
@@ -180,7 +181,7 @@ def get_command_using_train_func(
180181
# Install Python packages if that is required.
181182
if packages_to_install is not None:
182183
exec_script = (
183-
get_script_for_python_packages(packages_to_install, pip_index_url)
184+
get_script_for_python_packages(packages_to_install, pip_index_url, pip_args)
184185
+ exec_script
185186
)
186187

0 commit comments

Comments
 (0)