Skip to content

Commit

Permalink
allow "watch" to take a stop_event (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Apr 1, 2022
1 parent 93ee394 commit 90a149d
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 22 deletions.
11 changes: 5 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,19 @@ impl RustNotify {
})
}

pub fn watch(&self, py: Python, debounce_ms: u64, step_ms: u64, cancel_event: PyObject) -> PyResult<PyObject> {
let event_not_none = !cancel_event.is_none(py);
pub fn watch(&self, py: Python, debounce_ms: u64, step_ms: u64, stop_event: PyObject) -> PyResult<PyObject> {
let event_not_none = !stop_event.is_none(py);

let mut max_time: Option<SystemTime> = None;
let step_time = Duration::from_millis(step_ms);
let mut last_size: usize = 0;
let none: Option<bool> = None;
loop {
py.allow_threads(|| sleep(step_time));
match py.check_signals() {
Ok(_) => (),
Err(_) => {
self.clear();
return Ok(none.to_object(py));
return Ok("signalled".to_object(py));
}
};

Expand All @@ -130,9 +129,9 @@ impl RustNotify {
return Err(WatchfilesRustInternalError::new_err(error.clone()));
}

if event_not_none && cancel_event.getattr(py, "is_set")?.call0(py)?.is_true(py)? {
if event_not_none && stop_event.getattr(py, "is_set")?.call0(py)?.is_true(py)? {
self.clear();
return Ok(none.to_object(py));
return Ok("stopped".to_object(py));
}

let size = self.changes.lock().unwrap().len();
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def watch(self, debounce_ms: int, step_ms: int, cancel_event):
try:
change = next(self.iter_changes)
except StopIteration:
return None
return 'signalled'
else:
self.watch_count += 1
return change
Expand Down
13 changes: 12 additions & 1 deletion tests/test_watch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import threading
from contextlib import contextmanager
from pathlib import Path
from time import sleep
Expand All @@ -23,6 +24,16 @@ def test_watch(tmp_path: Path, write_soon):
assert changes == {(Change.added, str((tmp_path / 'foo.txt')))}


def test_wait_stop_event(tmp_path: Path, write_soon):
sleep(0.1)
write_soon(tmp_path / 'foo.txt')

stop_event = threading.Event()
for changes in watch(tmp_path, watch_filter=None, stop_event=stop_event):
assert changes == {(Change.added, str((tmp_path / 'foo.txt')))}
stop_event.set()


async def test_awatch(tmp_path: Path, write_soon):
sleep(0.1)
write_soon(tmp_path / 'foo.txt')
Expand All @@ -31,7 +42,7 @@ async def test_awatch(tmp_path: Path, write_soon):
break


async def test_await_stop(tmp_path: Path, write_soon):
async def test_await_stop_event(tmp_path: Path, write_soon):
sleep(0.1)
write_soon(tmp_path / 'foo.txt')
stop_event = anyio.Event()
Expand Down
13 changes: 7 additions & 6 deletions watchfiles/_rust_notify.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Protocol, Set, Tuple
from typing import List, Literal, Optional, Protocol, Set, Tuple, Union

__all__ = 'RustNotify', 'WatchfilesRustInternalError'

Expand All @@ -25,8 +25,8 @@ class RustNotify:
self,
debounce_ms: int,
step_ms: int,
cancel_event: Optional[AbstractEvent],
) -> Optional[Set[Tuple[int, str]]]:
stop_event: Optional[AbstractEvent],
) -> Union[Literal['signalled', 'stopped'], Set[Tuple[int, str]]]:
"""
Watch for changes and return a set of `(event_type, path)` tuples.
Expand All @@ -40,11 +40,12 @@ class RustNotify:
debounce_ms: maximum time in milliseconds to group changes over before returning.
step_ms: time to wait for new changes in milliseconds, if no changes are detected
in this time, and at least one change has been detected, the changes are yielded.
cancel_event: event to check on every iteration to see if this function should return early.
stop_event: event to check on every iteration to see if this function should return early.
Returns:
A set of `(event_type, path)` tuples,
the event types are ints which match [`Change`][watchfiles.Change].
Either a set of `(event_type, path)` tuples
(the event types are ints which match [`Change`][watchfiles.Change]),
`'signalled'` if a signal was received, or `'stopped'` if the `stop_event` was set.
"""

class WatchfilesRustInternalError(RuntimeError):
Expand Down
27 changes: 19 additions & 8 deletions watchfiles/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,23 @@ def raw_str(self) -> str:

if TYPE_CHECKING:
import asyncio
from typing import Protocol

import trio

AnyEvent = Union[anyio.Event, asyncio.Event, trio.Event]

class AbstractEvent(Protocol):
def is_set(self) -> bool:
...


def watch(
*paths: Union[Path, str],
watch_filter: Optional[Callable[['Change', str], bool]] = DefaultFilter(),
debounce: int = 1_600,
step: int = 50,
stop_event: Optional['AbstractEvent'] = None,
debug: bool = False,
raise_interrupt: bool = True,
) -> Generator[Set[FileChange], None, None]:
Expand All @@ -69,6 +75,8 @@ def watch(
debounce: maximum time in milliseconds to group changes over before yielding them.
step: time to wait for new changes in milliseconds, if no changes are detected in this time, and
at least one change has been detected, the changes are yielded.
stop_event: event to stop watching, if this is set, the generator will stop yielding changes,
this can be anything with an `is_set()` method which returns a bool, e.g. `threading.Event()`.
debug: whether to print information about all filesystem changes in rust to stdout.
raise_interrupt: whether to re-raise `KeyboardInterrupt`s, or suppress the error and just stop iterating.
Expand All @@ -84,18 +92,20 @@ def watch(
"""
watcher = RustNotify([str(p) for p in paths], debug)
while True:
raw_changes = watcher.watch(debounce, step, None)
if raw_changes is None:
raw_changes = watcher.watch(debounce, step, stop_event)
if raw_changes == 'signalled':
if raise_interrupt:
raise KeyboardInterrupt
else:
logger.warning('KeyboardInterrupt caught, stopping watch')
return

changes = _prep_changes(raw_changes, watch_filter)
if changes:
_log_changes(changes)
yield changes
elif raw_changes == 'stopped':
return
else:
changes = _prep_changes(raw_changes, watch_filter)
if changes:
_log_changes(changes)
yield changes


async def awatch(
Expand Down Expand Up @@ -186,7 +196,8 @@ async def signal_handler() -> None:
raw_changes = await anyio.to_thread.run_sync(watcher.watch, debounce, step, stop_event_)
tg.cancel_scope.cancel()

if raw_changes is None:
# cover both cases here although in theory the watch thread should never get a signal
if raw_changes == 'stopped' or raw_changes == 'signalled':
if interrupted:
if raise_interrupt:
raise KeyboardInterrupt
Expand Down

0 comments on commit 90a149d

Please sign in to comment.