diff --git a/src/nemo_run/core/execution/slurm.py b/src/nemo_run/core/execution/slurm.py index afbf051..cbd5ad3 100644 --- a/src/nemo_run/core/execution/slurm.py +++ b/src/nemo_run/core/execution/slurm.py @@ -42,8 +42,14 @@ from nemo_run.core.packaging.base import Packager from nemo_run.core.packaging.git import GitArchivePackager from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer -from nemo_run.core.tunnel.callback import Callback -from nemo_run.core.tunnel.client import LocalTunnel, SSHConfigFile, SSHTunnel, Tunnel +from nemo_run.core.tunnel.client import ( + Callback, + LocalTunnel, + PackagingJob, + SSHConfigFile, + SSHTunnel, + Tunnel, +) from nemo_run.core.tunnel.server import TunnelMetadata, server_dir from nemo_run.devspace.base import DevSpace @@ -388,7 +394,7 @@ def __post_init__(self): self.wait_time_for_group_job = 0 def info(self) -> str: - return f"{self.__class__.__qualname__} on {self.tunnel._key}" + return f"{self.__class__.__qualname__} on {self.tunnel.key}" def alloc(self, job_name="interactive"): self.job_name = f"{self.job_name_prefix}{job_name}" @@ -537,13 +543,39 @@ def package_configs(self, *cfgs: tuple[str, str]) -> list[str]: return filenames def package(self, packager: Packager, job_name: str): - if job_name in self.tunnel.packaging_jobs: + if job_name in self.tunnel.packaging_jobs and not packager.symlink_from_remote_dir: logger.info( f"Packaging for job {job_name} in tunnel {self.tunnel} already done. Skipping subsequent packagings.\n" "This may cause issues if you have multiple tasks with the same name but different packagers, as only the first packager will be used." ) return + if packager.symlink_from_remote_dir: + logger.info( + f"Packager {packager} is configured to symlink from remote dir. Skipping packaging." + ) + if type(packager) is Packager: + self.tunnel.packaging_jobs[job_name] = PackagingJob(symlink=False) + return + + self.tunnel.packaging_jobs[job_name] = PackagingJob( + symlink=True, + src_path=packager.symlink_from_remote_dir, + dst_path=os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"), + ) + + # Tunnel job dir is the directory of the experiment id, so the base job dir is two levels up + base_remote_dir = str(Path(self.tunnel.job_dir).parent.parent) + base_remote_mount = f"{base_remote_dir}:{base_remote_dir}" + if base_remote_mount not in self.container_mounts: + self.container_mounts.append(f"{base_remote_dir}:{base_remote_dir}") + + for req in self.resource_group: + if base_remote_mount not in req.container_mounts: + req.container_mounts.append(base_remote_mount) + + return + assert self.experiment_id, "Executor not assigned to an experiment." if isinstance(packager, GitArchivePackager): output = subprocess.run( @@ -573,7 +605,12 @@ def package(self, packager: Packager, job_name: str): f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True ) - self.tunnel.packaging_jobs.add(job_name) + self.tunnel.packaging_jobs[job_name] = PackagingJob( + symlink=False, + dst_path=None + if type(packager) is Packager + else os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"), + ) def parse_deps(self) -> list[str]: """ diff --git a/src/nemo_run/core/packaging/base.py b/src/nemo_run/core/packaging/base.py index 93d176d..95bd25d 100644 --- a/src/nemo_run/core/packaging/base.py +++ b/src/nemo_run/core/packaging/base.py @@ -16,7 +16,7 @@ import logging from dataclasses import dataclass from pathlib import Path - +from typing import Optional from nemo_run.config import ConfigurableMixin @@ -45,6 +45,10 @@ class Packager(ConfigurableMixin): #: Uses component or executor specific debug flags if set to True. debug: bool = False + #: Symlinks the package from the provided remote dir. + #: Only applicable when using SlurmExecutor at the moment. + symlink_from_remote_dir: Optional[str] = None + def package(self, path: Path, job_dir: str, name: str) -> str: ... def setup(self): diff --git a/src/nemo_run/core/tunnel/callback.py b/src/nemo_run/core/tunnel/callback.py deleted file mode 100644 index bc6949f..0000000 --- a/src/nemo_run/core/tunnel/callback.py +++ /dev/null @@ -1,45 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from nemo_run.core.tunnel.client import Tunnel - - -class Callback: - def setup(self, tunnel: "Tunnel"): - """Called when the tunnel is setup.""" - self.tunnel = tunnel - - def on_start(self): - """Called when the keep_alive loop starts.""" - pass - - def on_interval(self): - """Called at each interval during the keep_alive loop.""" - pass - - def on_stop(self): - """Called when the keep_alive loop stops.""" - pass - - def on_error(self, error: Exception): - """Called when an error occurs during the keep_alive loop. - - Args: - error (Exception): The exception that was raised. - """ - pass diff --git a/src/nemo_run/core/tunnel/client.py b/src/nemo_run/core/tunnel/client.py index b2d9be8..7abb31c 100644 --- a/src/nemo_run/core/tunnel/client.py +++ b/src/nemo_run/core/tunnel/client.py @@ -24,7 +24,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Callable, Optional +from typing import Callable, Optional import paramiko import paramiko.ssh_exception @@ -32,12 +32,9 @@ from invoke.context import Context from invoke.runners import Result as RunResult -from nemo_run.config import NEMORUN_HOME +from nemo_run.config import NEMORUN_HOME, ConfigurableMixin from nemo_run.core.frontend.console.api import CONSOLE -if TYPE_CHECKING: - from nemo_run.core.tunnel.callback import Callback - logger: logging.Logger = logging.getLogger(__name__) TUNNEL_DIR = ".tunnels" TUNNEL_FILE_SUBPATH = os.path.join(NEMORUN_HOME, TUNNEL_DIR) @@ -58,18 +55,24 @@ def authentication_handler(title, instructions, prompt_list): @dataclass(kw_only=True) -class Tunnel(ABC): +class PackagingJob(ConfigurableMixin): + symlink: bool = False + src_path: Optional[str] = None + dst_path: Optional[str] = None + + def symlink_cmd(self): + return f"ln -s {self.src_path} {self.dst_path}" + + +@dataclass(kw_only=True) +class Tunnel(ABC, ConfigurableMixin): job_dir: str host: str user: str + packaging_jobs: dict[str, PackagingJob] = field(default_factory=dict) def __post_init__(self): - self._key = f"{self.user}@{self.host}" - self._packaging_jobs = set() - - @property - def packaging_jobs(self): - return self._packaging_jobs + self.key = f"{self.user}@{self.host}" def _set_job_dir(self, experiment_id: str): ... @@ -377,3 +380,29 @@ def remove_entry(self, name: str): file.writelines(lines) print(f"Removed SSH config entry for {host}.") + + +class Callback: + def setup(self, tunnel: "Tunnel"): + """Called when the tunnel is setup.""" + self.tunnel = tunnel + + def on_start(self): + """Called when the keep_alive loop starts.""" + pass + + def on_interval(self): + """Called at each interval during the keep_alive loop.""" + pass + + def on_stop(self): + """Called when the keep_alive loop stops.""" + pass + + def on_error(self, error: Exception): + """Called when an error occurs during the keep_alive loop. + + Args: + error (Exception): The exception that was raised. + """ + pass diff --git a/src/nemo_run/devspace/base.py b/src/nemo_run/devspace/base.py index 03f436a..bccc99c 100644 --- a/src/nemo_run/devspace/base.py +++ b/src/nemo_run/devspace/base.py @@ -19,7 +19,7 @@ import fiddle as fdl from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer -from nemo_run.core.tunnel.callback import Callback +from nemo_run.core.tunnel.client import Callback if TYPE_CHECKING: from nemo_run.core.execution.base import Executor diff --git a/src/nemo_run/run/experiment.py b/src/nemo_run/run/experiment.py index 8805013..896f4db 100644 --- a/src/nemo_run/run/experiment.py +++ b/src/nemo_run/run/experiment.py @@ -190,6 +190,7 @@ class Experiment(ConfigurableMixin): _VERSION_FILE = "_VERSION" _TASK_FILE = "_TASKS" _DONE_FILE = "_DONE" + _TUNNELS_FILE = "_TUNNELS" _current_experiment_token: Optional[contextvars.Token] @classmethod @@ -221,6 +222,12 @@ def _from_config(cls: Type["Experiment"], exp_dir: str) -> "Experiment": exp: "Experiment" = fdl.build(cfg) exp._jobs = exp._load_jobs() + try: + exp.tunnels = exp._load_tunnels() + except Exception as e: + exp.console.log( + f"Exception {e} loading tunnels for experiment {id}, will continue without loading tunnels." + ) return exp @@ -327,6 +334,20 @@ def _save_config(self): with open(os.path.join(self._exp_dir, self.__class__._VERSION_FILE), "w+") as f: f.write(f"{run.__version__}\n") + def _save_tunnels(self): + serializer = ZlibJSONSerializer() + serialized_tunnels = { + k: serializer.serialize(v.to_config()) for k, v in self.tunnels.items() + } + with open(os.path.join(self._exp_dir, self.__class__._TUNNELS_FILE), "w+") as f: + json.dump(serialized_tunnels, f) + + def _load_tunnels(self) -> dict[str, Tunnel]: + with open(os.path.join(self._exp_dir, self.__class__._TUNNELS_FILE)) as f: + serialized_tunnels = json.load(f) + serializer = ZlibJSONSerializer() + return {k: fdl.build(serializer.deserialize(v)) for k, v in serialized_tunnels.items()} + def _save_jobs(self): serialized_jobs = list(map(lambda job: job.serialize(), self.jobs)) with open(os.path.join(self._exp_dir, self.__class__._TASK_FILE), "w+") as f: @@ -645,9 +666,19 @@ def run( for tunnel in self.tunnels.values(): if isinstance(tunnel, SSHTunnel): tunnel.connect() - assert tunnel.session, f"SSH tunnel {tunnel._key} failed to connect." + assert tunnel.session, f"SSH tunnel {tunnel.key} failed to connect." rsync(tunnel.session, self._exp_dir, os.path.dirname(tunnel.job_dir)) + symlink_cmds = [] + for packaging_job in tunnel.packaging_jobs.values(): + if packaging_job.symlink: + symlink_cmds.append(packaging_job.symlink_cmd()) + + if symlink_cmds: + tunnel.run(" && ".join(symlink_cmds)) + + self._save_tunnels() + return self._run_dag(detach=detach, tail_logs=tail_logs, executors=executors) def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]): diff --git a/src/nemo_run/run/torchx_backend/packaging.py b/src/nemo_run/run/torchx_backend/packaging.py index 5709bcc..3a6dd14 100644 --- a/src/nemo_run/run/torchx_backend/packaging.py +++ b/src/nemo_run/run/torchx_backend/packaging.py @@ -82,11 +82,6 @@ def package( args.append(fn_or_script_filename) else: - args += [ - "-p", - _serialize(executor.packager.to_config()), - ] - args.append(_serialize(fn_or_script)) role_args = default_cmd + args diff --git a/src/nemo_run/run/torchx_backend/schedulers/slurm.py b/src/nemo_run/run/torchx_backend/schedulers/slurm.py index 7274cfb..c0eafb5 100644 --- a/src/nemo_run/run/torchx_backend/schedulers/slurm.py +++ b/src/nemo_run/run/torchx_backend/schedulers/slurm.py @@ -72,14 +72,14 @@ def _initialize_tunnel(self, tunnel: SSHTunnel | LocalTunnel): return experiment = run_experiment._current_experiment.get(None) - if experiment and tunnel._key in experiment.tunnels: - self.tunnel = experiment.tunnels[tunnel._key] + if experiment and tunnel.key in experiment.tunnels: + self.tunnel = experiment.tunnels[tunnel.key] return self.tunnel = tunnel if experiment: - experiment.tunnels[tunnel._key] = self.tunnel + experiment.tunnels[tunnel.key] = self.tunnel def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # type: ignore assert isinstance(cfg, SlurmExecutor), f"{cfg.__class__} not supported for slurm scheduler." @@ -96,6 +96,8 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t partition = executor.partition assert partition is None or isinstance(partition, str), "partition must be str" + executor.package(packager=executor.packager, job_name=Path(job_dir).name) + srun_cmds: list[list[str]] = [] jobs = [] envs = {} @@ -137,8 +139,6 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t with open(path, "w") as f: f.write(script) - executor.package(packager=executor.packager, job_name=Path(job_dir).name) - return AppDryRunInfo(req, repr) def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str: # type: ignore diff --git a/test/run/torchx_backend/test_packaging.py b/test/run/torchx_backend/test_packaging.py index 9b97838..2c61925 100644 --- a/test/run/torchx_backend/test_packaging.py +++ b/test/run/torchx_backend/test_packaging.py @@ -71,8 +71,6 @@ def test_package_partial(mock_executor): "nemo_run.core.runners.fdl_runner", "-n", "test", - "-p", - "eJzdVE1Lw0AQ_StlL60goelRrKCePAgevJUSNtlJunazG_ajGEr_uzvbRNPYtBU9eQnZYea9N59bopWy5Ga0JbauwP8QDTm5HpE11PiqaLamBegkJjtvVekbZNbsA1wlwNs7wV8oVd3glIo5EUyp48JyadAqaRlsASMgcwsl4i4W5EkyeJ9w_M6nV-jOIHUFWS4RDyxl1FLvKp0QGMp4Zn-pAyEOZfS4_jTFY3l8JjL7B4lgOKOpHw-r6Qa08QPU-v2gUzlnTECUGJ1FmZI5L7p6HlqS15bjuZXSG6h7a_UEw-bjXCZKJ5kwYxysk-wSSpVoJz21hmi_CFwWUUoNdHW8NCtCdr4cB2RUF64EaRN89hkP96xdpmEMS4vTEM0aDCOsuLFK1-dBZh5kaGz-tk_YqM6JuXQwOq3psz3uLcMTEG5JrwYCaGDYUOHQkFNhIEizq4BAbvFO3kXNITpRnsNynlmE3REOj45mj5MpJ7_d2rh78OLu0YgvWby4X_AYydCTK4mKp9E08sI-AMy28Wg=", "eJzdlEtPwzAMgP_KlMuGhKp1RwRIcOOAxIHbNEVuk21haVKlzrRq2n8nDuto9wSxE5eqduLPdvxYM2ctsrvemmFdyvDDnJyy2x5byJok4Yui5iAET9kmqG32IXOsvix8qWXQt6y_MWW9BRVWeB1VmVcalalIa6CIusiIZIWyIO54zF6MkKuBou_D8IauA5vc9roHaTzI2GRCTiSCAIRgb7zWxBMqxz8GR4hubHu-rpr3sTx2iYz-QSJkLiALPYMOltJV0vHm3i8qNVVCaJnwyuVJbs1UzdrxPDdO3hsfr00oe132hOgGZPbQnxpuHc911aemOusdrcvnK55BvpBGJCgr5GUQYKZMJ5Dd5LBN7N2WO3AzX0iDnMR9n935a2bsNANhdh6xHYTThLmqQlb1ZcgoQE41znUrFfu-tXp-2htGFpY7b464ewOHCvSZLoC9F9ASIn4J2pMiDf8l4DxasnvanI9J2EwHL5tdAI2OgTICnXzdbjUuTNLmCD_QSR04ufXmYIOn7Y25E0Zb4eLkpgf1SskbXVXWUMjDZJiEyD4BcFMKTQ==", ] diff --git a/test/test_api.py b/test/test_api.py index 8026ce8..d565829 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -16,8 +16,9 @@ from dataclasses import dataclass from unittest.mock import Mock -import nemo_run as run import pytest + +import nemo_run as run from nemo_run.api import dryrun_fn @@ -117,7 +118,10 @@ def test_dryrun_fn_with_executor(self, capsys, configured_fn): captured = capsys.readouterr() assert "Dry run for task test.test_api:some_fn" in captured.out - assert "LocalExecutor(packager=Packager(debug=False)" in captured.out + assert ( + "LocalExecutor(packager=Packager(debug=False, symlink_from_remote_dir=None)" + in captured.out + ) def test_dryrun_fn_with_build(self, mocker, configured_fn): build_mock = Mock()