Skip to content

Commit

Permalink
Add option to symlink from remote dir in packager (#122)
Browse files Browse the repository at this point in the history
* Add option to symlink from remote dir in packager

Signed-off-by: Hemil Desai <[email protected]>

* Save tunnels for experiment

Signed-off-by: Hemil Desai <[email protected]>

* Fix

Signed-off-by: Hemil Desai <[email protected]>

* Mount base remote dir for symlinks

Signed-off-by: Hemil Desai <[email protected]>

* fix

Signed-off-by: Hemil Desai <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
  • Loading branch information
hemildesai authored Dec 18, 2024
1 parent 283c0a1 commit b4e2258
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 79 deletions.
47 changes: 42 additions & 5 deletions src/nemo_run/core/execution/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,14 @@
from nemo_run.core.packaging.base import Packager
from nemo_run.core.packaging.git import GitArchivePackager
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.core.tunnel.callback import Callback
from nemo_run.core.tunnel.client import LocalTunnel, SSHConfigFile, SSHTunnel, Tunnel
from nemo_run.core.tunnel.client import (
Callback,
LocalTunnel,
PackagingJob,
SSHConfigFile,
SSHTunnel,
Tunnel,
)
from nemo_run.core.tunnel.server import TunnelMetadata, server_dir
from nemo_run.devspace.base import DevSpace

Expand Down Expand Up @@ -388,7 +394,7 @@ def __post_init__(self):
self.wait_time_for_group_job = 0

def info(self) -> str:
return f"{self.__class__.__qualname__} on {self.tunnel._key}"
return f"{self.__class__.__qualname__} on {self.tunnel.key}"

def alloc(self, job_name="interactive"):
self.job_name = f"{self.job_name_prefix}{job_name}"
Expand Down Expand Up @@ -537,13 +543,39 @@ def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
return filenames

def package(self, packager: Packager, job_name: str):
if job_name in self.tunnel.packaging_jobs:
if job_name in self.tunnel.packaging_jobs and not packager.symlink_from_remote_dir:
logger.info(
f"Packaging for job {job_name} in tunnel {self.tunnel} already done. Skipping subsequent packagings.\n"
"This may cause issues if you have multiple tasks with the same name but different packagers, as only the first packager will be used."
)
return

if packager.symlink_from_remote_dir:
logger.info(
f"Packager {packager} is configured to symlink from remote dir. Skipping packaging."
)
if type(packager) is Packager:
self.tunnel.packaging_jobs[job_name] = PackagingJob(symlink=False)
return

self.tunnel.packaging_jobs[job_name] = PackagingJob(
symlink=True,
src_path=packager.symlink_from_remote_dir,
dst_path=os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"),
)

# Tunnel job dir is the directory of the experiment id, so the base job dir is two levels up
base_remote_dir = str(Path(self.tunnel.job_dir).parent.parent)
base_remote_mount = f"{base_remote_dir}:{base_remote_dir}"
if base_remote_mount not in self.container_mounts:
self.container_mounts.append(f"{base_remote_dir}:{base_remote_dir}")

for req in self.resource_group:
if base_remote_mount not in req.container_mounts:
req.container_mounts.append(base_remote_mount)

return

assert self.experiment_id, "Executor not assigned to an experiment."
if isinstance(packager, GitArchivePackager):
output = subprocess.run(
Expand Down Expand Up @@ -573,7 +605,12 @@ def package(self, packager: Packager, job_name: str):
f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True
)

self.tunnel.packaging_jobs.add(job_name)
self.tunnel.packaging_jobs[job_name] = PackagingJob(
symlink=False,
dst_path=None
if type(packager) is Packager
else os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"),
)

def parse_deps(self) -> list[str]:
"""
Expand Down
6 changes: 5 additions & 1 deletion src/nemo_run/core/packaging/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from dataclasses import dataclass
from pathlib import Path

from typing import Optional

from nemo_run.config import ConfigurableMixin

Expand Down Expand Up @@ -45,6 +45,10 @@ class Packager(ConfigurableMixin):
#: Uses component or executor specific debug flags if set to True.
debug: bool = False

#: Symlinks the package from the provided remote dir.
#: Only applicable when using SlurmExecutor at the moment.
symlink_from_remote_dir: Optional[str] = None

def package(self, path: Path, job_dir: str, name: str) -> str: ...

def setup(self):
Expand Down
45 changes: 0 additions & 45 deletions src/nemo_run/core/tunnel/callback.py

This file was deleted.

53 changes: 41 additions & 12 deletions src/nemo_run/core/tunnel/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,17 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional
from typing import Callable, Optional

import paramiko
import paramiko.ssh_exception
from fabric import Config, Connection
from invoke.context import Context
from invoke.runners import Result as RunResult

from nemo_run.config import NEMORUN_HOME
from nemo_run.config import NEMORUN_HOME, ConfigurableMixin
from nemo_run.core.frontend.console.api import CONSOLE

if TYPE_CHECKING:
from nemo_run.core.tunnel.callback import Callback

logger: logging.Logger = logging.getLogger(__name__)
TUNNEL_DIR = ".tunnels"
TUNNEL_FILE_SUBPATH = os.path.join(NEMORUN_HOME, TUNNEL_DIR)
Expand All @@ -58,18 +55,24 @@ def authentication_handler(title, instructions, prompt_list):


@dataclass(kw_only=True)
class Tunnel(ABC):
class PackagingJob(ConfigurableMixin):
symlink: bool = False
src_path: Optional[str] = None
dst_path: Optional[str] = None

def symlink_cmd(self):
    """Return the ``ln -s`` shell command linking ``dst_path`` to ``src_path``.

    Both paths are interpolated exactly as stored on the instance; no quoting
    or validation is applied here.
    """
    return "ln -s {} {}".format(self.src_path, self.dst_path)


@dataclass(kw_only=True)
class Tunnel(ABC, ConfigurableMixin):
job_dir: str
host: str
user: str
packaging_jobs: dict[str, PackagingJob] = field(default_factory=dict)

def __post_init__(self):
self._key = f"{self.user}@{self.host}"
self._packaging_jobs = set()

@property
def packaging_jobs(self):
return self._packaging_jobs
self.key = f"{self.user}@{self.host}"

def _set_job_dir(self, experiment_id: str): ...

Expand Down Expand Up @@ -377,3 +380,29 @@ def remove_entry(self, name: str):
file.writelines(lines)

print(f"Removed SSH config entry for {host}.")


class Callback:
    """Hook interface for tunnel setup and keep_alive loop events.

    ``setup`` records the tunnel on the instance; every other hook does
    nothing in this base class.
    """

    def setup(self, tunnel: "Tunnel"):
        """Called when the tunnel is setup.

        Args:
            tunnel: The tunnel being set up; stored as ``self.tunnel``.
        """
        self.tunnel = tunnel

    def on_start(self):
        """Called when the keep_alive loop starts."""

    def on_interval(self):
        """Called at each interval during the keep_alive loop."""

    def on_stop(self):
        """Called when the keep_alive loop stops."""

    def on_error(self, error: Exception):
        """Called when an error occurs during the keep_alive loop.

        Args:
            error (Exception): The exception that was raised.
        """
2 changes: 1 addition & 1 deletion src/nemo_run/devspace/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import fiddle as fdl

from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.core.tunnel.callback import Callback
from nemo_run.core.tunnel.client import Callback

if TYPE_CHECKING:
from nemo_run.core.execution.base import Executor
Expand Down
33 changes: 32 additions & 1 deletion src/nemo_run/run/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class Experiment(ConfigurableMixin):
_VERSION_FILE = "_VERSION"
_TASK_FILE = "_TASKS"
_DONE_FILE = "_DONE"
_TUNNELS_FILE = "_TUNNELS"
_current_experiment_token: Optional[contextvars.Token]

@classmethod
Expand Down Expand Up @@ -221,6 +222,12 @@ def _from_config(cls: Type["Experiment"], exp_dir: str) -> "Experiment":

exp: "Experiment" = fdl.build(cfg)
exp._jobs = exp._load_jobs()
try:
exp.tunnels = exp._load_tunnels()
except Exception as e:
exp.console.log(
f"Exception {e} loading tunnels for experiment {id}, will continue without loading tunnels."
)

return exp

Expand Down Expand Up @@ -327,6 +334,20 @@ def _save_config(self):
with open(os.path.join(self._exp_dir, self.__class__._VERSION_FILE), "w+") as f:
f.write(f"{run.__version__}\n")

def _save_tunnels(self):
    """Dump the experiment's tunnels to the ``_TUNNELS_FILE`` in the experiment dir.

    Each entry of ``self.tunnels`` is converted with ``to_config()`` and passed
    through ``ZlibJSONSerializer.serialize``; the resulting mapping (same keys)
    is written as JSON.
    """
    serializer = ZlibJSONSerializer()

    payload = {}
    for key, tunnel in self.tunnels.items():
        payload[key] = serializer.serialize(tunnel.to_config())

    # Resolve the file name through the class so subclass overrides are honoured.
    tunnels_path = os.path.join(self._exp_dir, self.__class__._TUNNELS_FILE)
    with open(tunnels_path, "w+") as f:
        json.dump(payload, f)

def _load_tunnels(self) -> dict[str, Tunnel]:
    """Rebuild the tunnels mapping from the ``_TUNNELS_FILE`` in the experiment dir.

    Returns:
        A dict with the same keys as the stored JSON object, where each value is
        the result of ``fdl.build`` on the deserialized entry.
    """
    tunnels_path = os.path.join(self._exp_dir, self.__class__._TUNNELS_FILE)
    with open(tunnels_path) as f:
        stored = json.load(f)

    serializer = ZlibJSONSerializer()
    tunnels = {}
    for key, blob in stored.items():
        tunnels[key] = fdl.build(serializer.deserialize(blob))
    return tunnels

def _save_jobs(self):
serialized_jobs = list(map(lambda job: job.serialize(), self.jobs))
with open(os.path.join(self._exp_dir, self.__class__._TASK_FILE), "w+") as f:
Expand Down Expand Up @@ -645,9 +666,19 @@ def run(
for tunnel in self.tunnels.values():
if isinstance(tunnel, SSHTunnel):
tunnel.connect()
assert tunnel.session, f"SSH tunnel {tunnel._key} failed to connect."
assert tunnel.session, f"SSH tunnel {tunnel.key} failed to connect."
rsync(tunnel.session, self._exp_dir, os.path.dirname(tunnel.job_dir))

symlink_cmds = []
for packaging_job in tunnel.packaging_jobs.values():
if packaging_job.symlink:
symlink_cmds.append(packaging_job.symlink_cmd())

if symlink_cmds:
tunnel.run(" && ".join(symlink_cmds))

self._save_tunnels()

return self._run_dag(detach=detach, tail_logs=tail_logs, executors=executors)

def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):
Expand Down
5 changes: 0 additions & 5 deletions src/nemo_run/run/torchx_backend/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,6 @@ def package(

args.append(fn_or_script_filename)
else:
args += [
"-p",
_serialize(executor.packager.to_config()),
]

args.append(_serialize(fn_or_script))

role_args = default_cmd + args
Expand Down
10 changes: 5 additions & 5 deletions src/nemo_run/run/torchx_backend/schedulers/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ def _initialize_tunnel(self, tunnel: SSHTunnel | LocalTunnel):
return

experiment = run_experiment._current_experiment.get(None)
if experiment and tunnel._key in experiment.tunnels:
self.tunnel = experiment.tunnels[tunnel._key]
if experiment and tunnel.key in experiment.tunnels:
self.tunnel = experiment.tunnels[tunnel.key]
return

self.tunnel = tunnel

if experiment:
experiment.tunnels[tunnel._key] = self.tunnel
experiment.tunnels[tunnel.key] = self.tunnel

def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # type: ignore
assert isinstance(cfg, SlurmExecutor), f"{cfg.__class__} not supported for slurm scheduler."
Expand All @@ -96,6 +96,8 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
partition = executor.partition
assert partition is None or isinstance(partition, str), "partition must be str"

executor.package(packager=executor.packager, job_name=Path(job_dir).name)

srun_cmds: list[list[str]] = []
jobs = []
envs = {}
Expand Down Expand Up @@ -137,8 +139,6 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
with open(path, "w") as f:
f.write(script)

executor.package(packager=executor.packager, job_name=Path(job_dir).name)

return AppDryRunInfo(req, repr)

def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str: # type: ignore
Expand Down
2 changes: 0 additions & 2 deletions test/run/torchx_backend/test_packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def test_package_partial(mock_executor):
"nemo_run.core.runners.fdl_runner",
"-n",
"test",
"-p",
"eJzdVE1Lw0AQ_StlL60goelRrKCePAgevJUSNtlJunazG_ajGEr_uzvbRNPYtBU9eQnZYea9N59bopWy5Ga0JbauwP8QDTm5HpE11PiqaLamBegkJjtvVekbZNbsA1wlwNs7wV8oVd3glIo5EUyp48JyadAqaRlsASMgcwsl4i4W5EkyeJ9w_M6nV-jOIHUFWS4RDyxl1FLvKp0QGMp4Zn-pAyEOZfS4_jTFY3l8JjL7B4lgOKOpHw-r6Qa08QPU-v2gUzlnTECUGJ1FmZI5L7p6HlqS15bjuZXSG6h7a_UEw-bjXCZKJ5kwYxysk-wSSpVoJz21hmi_CFwWUUoNdHW8NCtCdr4cB2RUF64EaRN89hkP96xdpmEMS4vTEM0aDCOsuLFK1-dBZh5kaGz-tk_YqM6JuXQwOq3psz3uLcMTEG5JrwYCaGDYUOHQkFNhIEizq4BAbvFO3kXNITpRnsNynlmE3REOj45mj5MpJ7_d2rh78OLu0YgvWby4X_AYydCTK4mKp9E08sI-AMy28Wg=",
"eJzdlEtPwzAMgP_KlMuGhKp1RwRIcOOAxIHbNEVuk21haVKlzrRq2n8nDuto9wSxE5eqduLPdvxYM2ctsrvemmFdyvDDnJyy2x5byJok4Yui5iAET9kmqG32IXOsvix8qWXQt6y_MWW9BRVWeB1VmVcalalIa6CIusiIZIWyIO54zF6MkKuBou_D8IauA5vc9roHaTzI2GRCTiSCAIRgb7zWxBMqxz8GR4hubHu-rpr3sTx2iYz-QSJkLiALPYMOltJV0vHm3i8qNVVCaJnwyuVJbs1UzdrxPDdO3hsfr00oe132hOgGZPbQnxpuHc911aemOusdrcvnK55BvpBGJCgr5GUQYKZMJ5Dd5LBN7N2WO3AzX0iDnMR9n935a2bsNANhdh6xHYTThLmqQlb1ZcgoQE41znUrFfu-tXp-2htGFpY7b464ewOHCvSZLoC9F9ASIn4J2pMiDf8l4DxasnvanI9J2EwHL5tdAI2OgTICnXzdbjUuTNLmCD_QSR04ufXmYIOn7Y25E0Zb4eLkpgf1SskbXVXWUMjDZJiEyD4BcFMKTQ==",
]

Expand Down
8 changes: 6 additions & 2 deletions test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from dataclasses import dataclass
from unittest.mock import Mock

import nemo_run as run
import pytest

import nemo_run as run
from nemo_run.api import dryrun_fn


Expand Down Expand Up @@ -117,7 +118,10 @@ def test_dryrun_fn_with_executor(self, capsys, configured_fn):

captured = capsys.readouterr()
assert "Dry run for task test.test_api:some_fn" in captured.out
assert "LocalExecutor(packager=Packager(debug=False)" in captured.out
assert (
"LocalExecutor(packager=Packager(debug=False, symlink_from_remote_dir=None)"
in captured.out
)

def test_dryrun_fn_with_build(self, mocker, configured_fn):
build_mock = Mock()
Expand Down

0 comments on commit b4e2258

Please sign in to comment.