Skip to content

Commit

Permalink
BREAKING CHANGE: moving shelephant.search.cwd -> `shelephant.path.c…
Browse files Browse the repository at this point in the history
…wd` and `shelephant.search.tempdir` -> `shelephant.path.tempdir` (#192)
  • Loading branch information
tdegeus authored Nov 6, 2023
1 parent 41f61b4 commit 2859239
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 48 deletions.
4 changes: 2 additions & 2 deletions shelephant/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from . import dataset
from . import local
from . import output
from . import path as mypathlib
from . import rsync
from . import scp
from . import search
from . import ssh
from . import yaml
from ._version import version
Expand Down Expand Up @@ -525,7 +525,7 @@ def shelephant_rm(args: list[str], paths: list[str] = None):
if source.ssh is None:
return local.remove(source.hostpath, files, progress=not args.quiet)

with ssh.tempdir(source.ssh) as remote, search.tempdir():
with ssh.tempdir(source.ssh) as remote, mypathlib.tempdir():
files = [str(source.root / i) for i in files]
pathlib.Path("remove.txt").write_text("\n".join(files))
shutil.copy(pathlib.Path(__file__).parent / "_remove.py", "script.py")
Expand Down
37 changes: 19 additions & 18 deletions shelephant/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from . import cli
from . import compute_hash
from . import output
from . import path as mypathlib
from . import rsync
from . import scp
from . import search
Expand Down Expand Up @@ -512,7 +513,7 @@ def _read_impl(self, verbose: bool):
if self.ssh is None:
return self._overwrite_dataset_from_dict(yaml.read(self._absroot / self.dump))

with search.tempdir():
with mypathlib.tempdir():
scp.copy(self.hostpath, ".", [self.dump], progress=False)
return self._overwrite_dataset_from_dict(yaml.read(self.dump))

Expand All @@ -524,7 +525,7 @@ def _read_impl(self, verbose: bool):

# search on SSH remote host for files (the sha256/size/mtime of 'new' files is set to None)
cache_dir = ssh._shelephant_cachdir(self.ssh, self.python)
with ssh._cachedir(self.ssh, cache_dir) as remote, search.tempdir():
with ssh._cachedir(self.ssh, cache_dir) as remote, mypathlib.tempdir():
shutil.copy(pathlib.Path(__file__).parent / "search.py", "script.py")
with open("settings.json", "w") as f:
json.dump(self.search, f)
Expand Down Expand Up @@ -602,7 +603,7 @@ def _get_info(self, paths: list[pathlib.Path], sha256: bool, progress: bool, ver
)

cache_dir = ssh._shelephant_cachdir(self.ssh, self.python)
with ssh._cachedir(self.ssh, cache_dir) as remote, search.tempdir():
with ssh._cachedir(self.ssh, cache_dir) as remote, mypathlib.tempdir():
files = [str(self.root / i) for i in paths]
pathlib.Path("files.txt").write_text("\n".join(files))
pathlib.Path("sha256.txt").write_text("")
Expand Down Expand Up @@ -936,7 +937,7 @@ def _create_symlink_data(
if name == "here":
return

with search.cwd(sdir):
with mypathlib.cwd(sdir):
mylink = pathlib.Path(f"data/{name}")
if mylink.is_symlink():
mylink.unlink()
Expand Down Expand Up @@ -1036,7 +1037,7 @@ def add(args: list[str]):
if not args.mount.is_absolute():
args.mount = pathlib.Path(os.path.relpath(args.mount.absolute(), sdir / "storage"))

with search.cwd(sdir):
with mypathlib.cwd(sdir):
loc = Location(root=args.root, ssh=args.ssh, mount=args.mount, prefix=args.prefix)
s = []
d = {}
Expand Down Expand Up @@ -1197,14 +1198,14 @@ def update(args: list[str]):
assert lock != "here"
args.name = [lock]

with search.cwd(sdir):
with mypathlib.cwd(sdir):
# read existing symlinks

if lock is None:
symlinks = yaml.read("symlinks.yaml", [])
symlinks = {pathlib.Path(i["path"]): pathlib.Path(i["storage"]) for i in symlinks}
if args.clean:
with search.cwd(base):
with mypathlib.cwd(base):
for path in list(symlinks.keys()):
if not path.is_symlink():
symlinks.pop(path)
Expand Down Expand Up @@ -1321,7 +1322,7 @@ def update(args: list[str]):
if symlink not in symlinks:
add_links.append(symlink)

with search.cwd(sdir / ".."):
with mypathlib.cwd(sdir / ".."):
for f in rm_links:
if f.is_symlink():
f.unlink()
Expand Down Expand Up @@ -1430,7 +1431,7 @@ def cp(args: list[str]):
args.path = args.path if args.path != [pathlib.Path(".")] else []
paths = [os.path.relpath(path, base) for path in args.path]

with search.cwd(sdir):
with mypathlib.cwd(sdir):
opts = [f"storage/{args.source}.yaml", f"storage/{args.destination}.yaml"]
opts += ["--colors", args.colors]
opts += ["--mode", args.mode]
Expand Down Expand Up @@ -1503,7 +1504,7 @@ def mv(args: list[str]):
base = sdir.parent
paths = [os.path.relpath(path, base) for path in args.path]

with search.cwd(sdir):
with mypathlib.cwd(sdir):
dest = Location.from_yaml(f"storage/{args.destination}.yaml")
assert dest.ssh is None, "Cannot move to remote location."
opts = [f"storage/{args.source}.yaml", str(dest._absroot)]
Expand All @@ -1514,7 +1515,7 @@ def mv(args: list[str]):
cli.shelephant_mv(opts, paths)

if not args.dry_run:
with search.cwd(sdir):
with mypathlib.cwd(sdir):
f = f"storage/{args.source}.yaml"
Location.from_yaml(f).remove(paths).overwrite_yaml(f)
update(["--quiet", "--force", args.destination] + list(map(str, args.path)))
Expand Down Expand Up @@ -1570,15 +1571,15 @@ def rm(args: list[str]):
base = sdir.parent
paths = [os.path.relpath(path, base) for path in args.path]

with search.cwd(sdir):
with mypathlib.cwd(sdir):
opts = [f"storage/{args.source}.yaml"]
opts += ["--force"] if args.force else []
opts += ["--quiet"] if args.quiet else []
opts += ["--dry-run"] if args.dry_run else []
cli.shelephant_rm(opts, paths)

if not args.dry_run:
with search.cwd(sdir):
with mypathlib.cwd(sdir):
f = f"storage/{args.source}.yaml"
Location.from_yaml(f).remove(paths).overwrite_yaml(f)
update([])
Expand Down Expand Up @@ -1629,7 +1630,7 @@ def pwd(args: list[str]):
cwd = pathlib.Path.cwd()
post = os.path.relpath(cwd, sdir / "..")

with search.cwd(sdir):
with mypathlib.cwd(sdir):
f = f"storage/{args.source}.yaml"
loc = Location.from_yaml(f)
prefix = loc.prefix
Expand Down Expand Up @@ -1688,7 +1689,7 @@ def diff(args: list[str]):
sdir = _search_upwards_dir(".shelephant")
assert sdir is not None, "Not in a shelephant dataset"

with search.cwd(sdir):
with mypathlib.cwd(sdir):
storage = yaml.read(sdir / "storage.yaml")
assert args.source in storage, f"Unknown storage location {args.source}"
assert args.dest in storage, f"Unknown storage location {args.dest}"
Expand Down Expand Up @@ -1800,7 +1801,7 @@ def status(args: list[str]):
if args.in_use == "none":
args.in_use = na

with search.cwd(sdir):
with mypathlib.cwd(sdir):
symlinks = np.sort([i["path"] for i in yaml.read("symlinks.yaml", [])])
storage = yaml.read(sdir / "storage.yaml")
storage.remove("here")
Expand Down Expand Up @@ -2129,7 +2130,7 @@ def gitignore(args: list[str]):
else:
ignore = ""

with search.cwd(sdir):
with mypathlib.cwd(sdir):
symlinks = [i["path"] for i in yaml.read("symlinks.yaml", [])]

ignore += "\n# <shelephant>\n" + "\n".join(symlinks) + "\n# </shelephant>\n"
Expand All @@ -2142,5 +2143,5 @@ def git(args: list[str]):
:param args: Command-line arguments (should be all strings).
"""
with search.cwd(_search_upwards_dir(".shelephant")):
with mypathlib.cwd(_search_upwards_dir(".shelephant")):
print(exec_cmd(f"git {' '.join(args)}"))
39 changes: 39 additions & 0 deletions shelephant/path.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import pathlib
import tempfile
from collections import defaultdict
from contextlib import contextmanager

import click

Expand Down Expand Up @@ -95,3 +98,39 @@ def makedirs(dirnames: list[str], force: bool = False):

for dirname in dirnames:
os.makedirs(dirname)


@contextmanager
def tempdir():
"""
Set the cwd to a temporary directory::
with tempdir("foo"):
# Do something in foo
"""

origin = pathlib.Path().absolute()
with tempfile.TemporaryDirectory() as dirname:
try:
os.chdir(dirname)
yield pathlib.Path(dirname)
finally:
os.chdir(origin)


@contextmanager
def cwd(dirname: pathlib.Path):
"""
Set the cwd to a specified directory::
with cwd("foo"):
# Do something in foo
:param dirname: The directory to change to.
"""
origin = pathlib.Path().absolute()
try:
os.chdir(dirname)
yield dirname
finally:
os.chdir(origin)
27 changes: 5 additions & 22 deletions shelephant/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,17 @@
import re
import subprocess
import sys
import tempfile
from contextlib import contextmanager


# note: a local function is needed
# remote read copies this file as it is (and not "path.py")
@contextmanager
def tempdir():
"""
Set the cwd to a temporary directory::
with tempdir("foo"):
# Do something in foo
"""

origin = pathlib.Path().absolute()
with tempfile.TemporaryDirectory() as dirname:
try:
os.chdir(dirname)
yield
finally:
os.chdir(origin)


@contextmanager
def cwd(dirname: pathlib.Path):
def _cwd(dirname: pathlib.Path):
"""
Set the cwd to a specified directory::
with cwd("foo"):
with _cwd("foo"):
# Do something in foo
:param dirname: The directory to change to.
Expand Down Expand Up @@ -129,7 +112,7 @@ def search(*settings: dict, root: pathlib.Path = pathlib.Path(".")) -> list[path
:param root: The root directory to search in.
:return: A list of paths.
"""
with cwd(root):
with _cwd(root):
ret = []
for setting in settings:
if "rglob" in setting:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from shelephant.cli import shelephant_mv
from shelephant.cli import shelephant_parse
from shelephant.cli import shelephant_rm
from shelephant.search import cwd
from shelephant.search import tempdir
from shelephant.path import cwd
from shelephant.path import tempdir

has_ssh = shelephant.ssh.has_keys_set("localhost")
has_rsync = shutil.which("rsync") is not None
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from shelephant._tests import create_dummy_files
from shelephant.cli import f_dump
from shelephant.cli import shelephant_dump
from shelephant.search import cwd
from shelephant.search import tempdir
from shelephant.path import cwd
from shelephant.path import tempdir

has_ssh = shelephant.ssh.has_keys_set("localhost")

Expand Down
4 changes: 2 additions & 2 deletions tests/test_local_rsync_scp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import shelephant
from shelephant._tests import create_dummy_files
from shelephant.search import cwd
from shelephant.search import tempdir
from shelephant.path import cwd
from shelephant.path import tempdir

has_rsync = shutil.which("rsync") is not None

Expand Down

0 comments on commit 2859239

Please sign in to comment.