Skip to content

Commit

Permalink
Fixing bug and simplifying implementation (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
tdegeus authored Oct 13, 2023
1 parent 9ba7fe4 commit fa3e650
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 43 deletions.
87 changes: 46 additions & 41 deletions shelephant/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Location:
* :py:attr:`Location.root`: The base directory.
* :py:attr:`Location.ssh` (optional): ``[user@]host``
* :py:attr:`Location.hostpath`.
* :py:attr:`Location.prefix` (optional): Prefix to add to all paths.
* :py:attr:`Location.python` (optional): The python executable on the ``ssh`` host.
* :py:attr:`Location.dump` (optional): Location of "dump" file -- file with list of files.
Expand Down Expand Up @@ -84,14 +85,22 @@ def __init__(
:param files: List of files.
"""
self.root = pathlib.Path(root)
self._mount = mount is not None
self._absroot = self.root.absolute() if not mount else mount.absolute()
if mount is None:
self.mount = None
self._absroot_is_mount = False
self._absroot = self.root.absolute()
else:
self.mount = pathlib.Path(mount)
self._absroot_is_mount = True
self._absroot = self.mount.absolute()
self.prefix = pathlib.Path(prefix) if prefix is not None else None
self.ssh = ssh
self.python = "python3"
self.dump = None
self.search = None

assert not self._absroot_is_mount or self.ssh is not None, "needs ssh to use mount"

if ssh is not None:
assert self.root.is_absolute(), "root must be absolute path when using ssh"

Expand Down Expand Up @@ -308,10 +317,12 @@ def from_yaml(cls, path: str | pathlib.Path):
root=data.get("root", path.parent),
ssh=data.get("ssh", None),
prefix=data.get("prefix", None),
mount=data.get("mount", None),
)
ret._mount = "mount" in data
assert not ret._mount or ret.ssh is not None, "ssh must be set when using mount"
ret._absroot = data.get("mount", _force_absolute_path(path.parent, ret.root))
if not ret._absroot_is_mount:
ret._absroot = _force_absolute_path(path.parent, ret.root)
else:
ret._absroot = _force_absolute_path(path.parent, ret.mount)
ret.dump = data.get("dump", None)
ret.search = data.get("search", None)
ret._overwrite_dataset_from_dict(data.get("files", []))
Expand Down Expand Up @@ -349,6 +360,9 @@ def asdict(self) -> dict:
if self.ssh is not None:
ret["ssh"] = self.ssh

if self.mount is not None:
ret["mount"] = str(self.mount)

if self.dump is not None:
ret["dump"] = str(self.dump)

Expand Down Expand Up @@ -453,6 +467,16 @@ def isavailable(self, mount: bool = False) -> bool:
return self._absroot.is_dir()
return ssh.is_dir(self.ssh, self.root)

def is_mounted(self) -> bool:
"""
Check if a location is a local directory, or if a remote directory is mounted.
:return: True if mounted.
"""
if self.ssh is not None and not self._absroot_is_mount:
return False
return self._absroot.is_dir()

def remove(self, paths: list[str]):
"""
Remove files from list of files.
Expand Down Expand Up @@ -891,54 +915,35 @@ def lock(args: list[str]):
def _create_symlink_data(
sdir: pathlib.Path,
name: str,
root: str,
ssh: str = None,
mount: str = None,
remove: bool = False,
loc: Location,
):
"""
Create or refresh symlink in ``.shelephant/data/<name>``.
:param sdir: Path to ``.shelephant`` directory.
:param name: Name of the storage location.
:param root: Root of the storage location.
:param ssh: SSH host of the storage location.
:param mount: Mount of the storage location.
:param remove: Remove existing symlink.
:param loc: Location.
:param refresh: Remove existing symlink.
"""
if name == "here":
return

with search.cwd(sdir):
if remove:
if (sdir / "data" / name).is_symlink():
(sdir / "data" / name).unlink()

if root.is_absolute() and not ssh:
pathlib.Path(f"data/{name}").symlink_to(root)
elif ssh is not None and mount is not None:
pathlib.Path(f"data/{name}").symlink_to(mount)
elif not ssh:
pathlib.Path(f"data/{name}").symlink_to(root)
mylink = pathlib.Path(f"data/{name}")
if mylink.is_symlink():
mylink.unlink()

if loc.is_mounted():
mylink.symlink_to(loc._absroot)
else:
pathlib.Path(f"data/{name}").symlink_to(pathlib.Path("..") / "unavailable")
mylink.symlink_to(pathlib.Path("..") / "unavailable")

storage = yaml.read("storage.yaml")
if name not in storage:
storage.append(name)
yaml.overwrite("storage.yaml", storage)


def _auto_symlink_data(sdir: pathlib.Path, name: str, remove: bool = False):
"""
Call :py:func:`_create_symlink_data` with data read from ``.shelephant/storage/{name}.yaml``.
:param sdir: Path to ``.shelephant`` directory.
:param name: Name of the storage location.
:param remove: Remove existing symlink.
"""
with search.cwd(sdir):
loc = Location.from_yaml(pathlib.Path("storage") / f"{name}.yaml")
_create_symlink_data(sdir, name, loc.root, loc.ssh, loc._mount, remove)


def _add_parser():
"""
Return parser for :py:func:`shelephant add`.
Expand Down Expand Up @@ -1039,9 +1044,7 @@ def add(args: list[str]):
loc.search = s

loc.overwrite_yaml(f"storage/{args.name}.yaml")

if args.name != "here":
_create_symlink_data(sdir, args.name, args.root, args.ssh, args.mount)
_create_symlink_data(sdir, args.name, Location.from_yaml(f"storage/{args.name}.yaml"))

opts = [args.name]
if args.shallow:
Expand Down Expand Up @@ -1176,7 +1179,9 @@ def update(args: list[str]):
assert lock is None, "cannot update all locations from storage location"
assert args.name != "here" or paths is None, "cannot specify paths for 'here'"
if args.base_link:
_auto_symlink_data(sdir, args.name, remove=True)
_create_symlink_data(
sdir, args.name, Location.from_yaml(sdir / "storage" / f"{args.name}.yaml")
)
assert args.name in yaml.read(sdir / "storage.yaml"), f"'{args.name}' is not a location"
args.name = [args.name]

Expand Down
7 changes: 5 additions & 2 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,12 @@ def test_basic_ssh_mount(self):
with tempdir(), shelephant.ssh.tempdir("localhost") as source2:
dataset = pathlib.Path("dataset")
source1 = pathlib.Path("source1")
source2 /= "lh"
mount = pathlib.Path("mount")

dataset.mkdir()
source1.mkdir()
source2.mkdir()
mount.symlink_to(source2)

with cwd(source1):
Expand All @@ -469,7 +471,7 @@ def test_basic_ssh_mount(self):
"--ssh",
"localhost",
"--mount",
str(mount),
"../mount",
"--rglob",
"*.txt",
"-q",
Expand All @@ -495,7 +497,8 @@ def test_basic_ssh_mount(self):
self.assertEqual(pathlib.Path(f).readlink().parent.name, "source1")
for f in ["e.txt", "f.txt"]:
self.assertEqual(pathlib.Path(f).readlink().parent.name, "source2")
self.assertEqual(pathlib.Path(os.path.realpath(f)).parent.name, "mount")
self.assertEqual(pathlib.Path(f).readlink().parent.readlink().name, "mount")
self.assertEqual(pathlib.Path(os.path.realpath(f)).parent.name, "lh")

with cwd(dataset):
self.assertRaises(AssertionError, shelephant.dataset.add, ["source2", "foo", "-q"])
Expand Down

0 comments on commit fa3e650

Please sign in to comment.