Skip to content

Commit 39e05f1

Browse files
committed
feat: Added reverse SSH tunnel and more
1 parent f82e678 commit 39e05f1

File tree

5 files changed

+279
-44
lines changed

5 files changed

+279
-44
lines changed

example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44

55
def main():
6-
img = (np.random.rand(512, 512) * 4).astype("uint8")
7-
send([img, img])
6+
img = (np.random.rand(128, 128))
7+
send(img)
88

99

1010
if __name__ == "__main__":

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ requires-python = ">=3.10"
3131
dependencies = [
3232
"numpy",
3333
"pyzmq",
34+
"medvol",
3435
]
3536

3637
[project.optional-dependencies]
@@ -53,6 +54,9 @@ dev = [
5354
[project.entry-points."napari.manifest"]
5455
napari-stream = "napari_stream:napari.yaml"
5556

57+
[project.scripts]
58+
napari_stream = "napari_stream.sender:main"
59+
5660
[build-system]
5761
requires = ["setuptools>=42.0.0", "setuptools_scm"]
5862
build-backend = "setuptools.build_meta"

src/napari_stream/_receiver_widget.py

Lines changed: 188 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22
import traceback
3-
from typing import Optional
3+
import shutil
4+
import subprocess
5+
from typing import Optional, Tuple
46
import numpy as np
57
from qtpy.QtCore import QThread
68
from qtpy.QtGui import QGuiApplication
@@ -12,9 +14,12 @@
1214
QLabel,
1315
QHBoxLayout,
1416
QCheckBox,
17+
QRadioButton,
18+
QButtonGroup,
1519
)
1620

17-
from ._listener import ZMQImageListener, bind_endpoint_for_public, default_endpoint
21+
from ._listener import ZMQImageListener, bind_endpoint_for_public
22+
from ._utils import DEFAULT_TCP_PORT, default_endpoint
1823

1924
try:
2025
from napari.types import ImageData
@@ -32,32 +37,52 @@ def __init__(self, napari_viewer: Viewer):
3237
self._thread: Optional[QThread] = None
3338
self._worker: Optional[ZMQImageListener] = None
3439

35-
self._last_auto_endpoint = default_endpoint()
36-
self.endpoint_edit = QLineEdit(self._last_auto_endpoint)
40+
# Mode-specific endpoints
41+
self._endpoint_local = f"tcp://127.0.0.1:{DEFAULT_TCP_PORT}"
42+
self._endpoint_private = default_endpoint(public=True)
43+
self._endpoint_tunnel = ""
44+
45+
self._tunnel_proc: Optional[subprocess.Popen] = None
46+
self._last_mode = "local"
47+
48+
self.endpoint_edit = QLineEdit(self._endpoint_local)
3749
self.status_label = QLabel("Idle")
38-
self.public_access = QCheckBox("Enable public Access")
50+
self.mode_local = QRadioButton("Local only")
51+
self.mode_private = QRadioButton("Private network")
52+
self.mode_tunnel = QRadioButton("Private network (Reverse SSH Tunnel)")
53+
self.mode_local.setChecked(True)
54+
self.mode_group = QButtonGroup(self)
55+
for btn in (self.mode_local, self.mode_private, self.mode_tunnel):
56+
self.mode_group.addButton(btn)
57+
3958
self.autocontrast = QCheckBox("Auto-contrast on new images")
4059
self.autocontrast.setChecked(True)
60+
self.ignore_affine = QCheckBox("Ignore affine")
4161

4262
self.btn_run = QPushButton("Start")
4363
self.btn_copy = QPushButton("Copy Endpoint")
4464

4565
top = QVBoxLayout(self)
46-
row = QHBoxLayout()
47-
row.addWidget(QLabel("Endpoint:"))
48-
row.addWidget(self.endpoint_edit)
49-
row.addWidget(self.btn_copy)
50-
top.addLayout(row)
51-
top.addWidget(self.public_access)
66+
top.addWidget(QLabel("Endpoint:"))
67+
top.addWidget(self.endpoint_edit)
68+
top.addWidget(self.btn_copy)
69+
top.addWidget(self.mode_local)
70+
top.addWidget(self.mode_private)
71+
top.addWidget(self.mode_tunnel)
5272
top.addWidget(self.autocontrast)
73+
top.addWidget(self.ignore_affine)
5374
top.addWidget(self.status_label)
5475
row2 = QHBoxLayout()
5576
row2.addWidget(self.btn_run)
5677
top.addLayout(row2)
5778

5879
self.btn_run.clicked.connect(self._on_toggle_clicked)
5980
self.btn_copy.clicked.connect(self._copy_endpoint)
60-
self.public_access.stateChanged.connect(self._on_public_toggled)
81+
self.mode_group.buttonClicked.connect(self._on_mode_changed)
82+
self.destroyed.connect(lambda *_: self._stop_tunnel())
83+
app = QGuiApplication.instance()
84+
if app is not None:
85+
app.aboutToQuit.connect(lambda *_: self._stop_tunnel())
6186

6287
def _on_toggle_clicked(self):
6388
if self._is_running():
@@ -68,10 +93,30 @@ def _on_toggle_clicked(self):
6893
def _on_start(self):
6994
if self._is_running():
7095
return
71-
endpoint = self.endpoint_edit.text().strip()
72-
bind_endpoint = self._resolve_endpoint_for_worker(endpoint)
96+
mode = self._current_mode()
97+
user_entry = self.endpoint_edit.text().strip()
98+
endpoint = user_entry or self._endpoint_for_mode(mode)
99+
if mode == "local":
100+
endpoint = self._ensure_local_endpoint(endpoint, mode="local")
101+
elif mode == "private":
102+
endpoint = self._ensure_private_endpoint(endpoint)
103+
elif mode == "tunnel":
104+
ssh_target, ssh_port = self._parse_ssh_target(user_entry)
105+
if not ssh_target:
106+
self.status_label.setText("Enter SSH target for reverse tunnel.")
107+
return
108+
endpoint = self._ensure_local_endpoint(
109+
self._endpoint_tunnel or f"tcp://127.0.0.1:{DEFAULT_TCP_PORT}",
110+
mode="tunnel",
111+
update_field=False,
112+
)
113+
self._store_endpoint_for_mode("tunnel", user_entry)
114+
else:
115+
self.status_label.setText("Unknown mode.")
116+
return
117+
73118
self._thread = QThread()
74-
self._worker = ZMQImageListener(bind_endpoint)
119+
self._worker = ZMQImageListener(endpoint)
75120
self._worker.moveToThread(self._thread)
76121

77122
self._thread.started.connect(self._worker.start)
@@ -83,7 +128,12 @@ def _on_start(self):
83128
self.btn_run.setEnabled(True)
84129
self._thread.start()
85130

131+
if mode == "tunnel":
132+
if not self._start_reverse_tunnel(ssh_target, ssh_port, endpoint):
133+
self._on_stop()
134+
86135
def _on_stop(self):
136+
self._stop_tunnel()
87137
if self._worker is not None:
88138
self._worker.stop()
89139
if self._thread is not None:
@@ -94,30 +144,132 @@ def _on_stop(self):
94144
self.btn_run.setText("Start")
95145
self.btn_run.setEnabled(True)
96146

97-
def _on_public_toggled(self, checked: int):
98-
is_public = bool(checked)
99-
suggested = default_endpoint(public=is_public)
100-
current = self.endpoint_edit.text().strip()
101-
if current == self._last_auto_endpoint:
102-
self.endpoint_edit.setText(suggested)
103-
self._last_auto_endpoint = suggested
104-
if self._is_running():
105-
self._restart_listener()
106-
107147
def _copy_endpoint(self):
108148
endpoint = self.endpoint_edit.text().strip()
109149
QGuiApplication.clipboard().setText(endpoint)
110150

111-
def _resolve_endpoint_for_worker(self, endpoint: str) -> str:
112-
if not self.public_access.isChecked():
113-
return endpoint
114-
if endpoint.startswith("tcp://"):
115-
return bind_endpoint_for_public(endpoint)
116-
# If the field holds a non-TCP endpoint, fall back to a sensible TCP default.
117-
fallback = default_endpoint(public=True)
118-
self.endpoint_edit.setText(fallback)
119-
self._last_auto_endpoint = fallback
120-
return bind_endpoint_for_public(fallback)
151+
def _on_mode_changed(self, *_):
152+
# Stop if running to allow reconfiguration
153+
if self._is_running():
154+
self._on_stop()
155+
# Persist current text to the previous mode slot
156+
prev_mode = getattr(self, "_last_mode", None)
157+
if prev_mode:
158+
self._store_endpoint_for_mode(prev_mode, self.endpoint_edit.text().strip())
159+
mode = self._current_mode()
160+
self.endpoint_edit.setPlaceholderText("user@remote or ssh-alias[#port]" if mode == "tunnel" else "")
161+
self.endpoint_edit.setText(self._endpoint_for_mode(mode))
162+
self._last_mode = mode
163+
self.status_label.setText("Idle")
164+
165+
def _current_mode(self) -> str:
166+
if self.mode_private.isChecked():
167+
return "private"
168+
if self.mode_tunnel.isChecked():
169+
return "tunnel"
170+
return "local"
171+
172+
def _endpoint_for_mode(self, mode: str) -> str:
173+
if mode == "private":
174+
return self._endpoint_private
175+
if mode == "tunnel":
176+
# Show blank to prompt SSH target entry; fall back to stored value if present
177+
return "" if not self._endpoint_tunnel else self._endpoint_tunnel
178+
return self._endpoint_local
179+
180+
def _store_endpoint_for_mode(self, mode: str, endpoint: str) -> None:
181+
if mode == "private":
182+
self._endpoint_private = endpoint or self._endpoint_private
183+
elif mode == "tunnel":
184+
self._endpoint_tunnel = endpoint or self._endpoint_tunnel
185+
else:
186+
self._endpoint_local = endpoint or self._endpoint_local
187+
188+
def _ensure_local_endpoint(self, endpoint: str, *, mode: str = "local", update_field: bool = True) -> str:
189+
port = self._extract_port(endpoint) or DEFAULT_TCP_PORT
190+
value = f"tcp://127.0.0.1:{port}"
191+
if update_field:
192+
self.endpoint_edit.setText(value)
193+
self._store_endpoint_for_mode(mode, value)
194+
return value
195+
196+
def _ensure_private_endpoint(self, endpoint: str) -> str:
197+
# Prefer a shareable IP, then bind on all interfaces.
198+
if not endpoint.startswith("tcp://"):
199+
endpoint = self._endpoint_private or default_endpoint(public=True)
200+
if endpoint.count(":") < 2:
201+
# Missing port: append default
202+
endpoint = endpoint.rstrip("/") + f":{DEFAULT_TCP_PORT}"
203+
if not endpoint.startswith("tcp://"):
204+
endpoint = "tcp://" + endpoint
205+
self._endpoint_private = endpoint
206+
self.endpoint_edit.setText(endpoint)
207+
return bind_endpoint_for_public(endpoint)
208+
209+
def _extract_port(self, endpoint: str) -> Optional[int]:
210+
if not endpoint:
211+
return None
212+
try:
213+
_, port_str = endpoint.rsplit(":", 1)
214+
return int(port_str)
215+
except Exception:
216+
return None
217+
218+
def _parse_ssh_target(self, raw: str) -> Tuple[str, int]:
219+
raw = (raw or "").strip()
220+
if not raw:
221+
return "", 22
222+
if "#" in raw:
223+
target, port_part = raw.rsplit("#", 1)
224+
try:
225+
port = int(port_part)
226+
except Exception:
227+
port = 22
228+
return target or "", port
229+
return raw, 22
230+
231+
def _start_reverse_tunnel(self, target: str, ssh_port: int, endpoint: str) -> bool:
232+
self._stop_tunnel()
233+
cmd = shutil.which("autossh") or shutil.which("ssh")
234+
if cmd is None:
235+
self.status_label.setText("autossh/ssh not found in PATH.")
236+
return False
237+
local_port = self._extract_port(endpoint) or DEFAULT_TCP_PORT
238+
remote_port = local_port
239+
args = [cmd]
240+
if cmd.endswith("autossh"):
241+
args += ["-M", "0"]
242+
args += ["-N", "-R", f"{remote_port}:localhost:{local_port}"]
243+
if ssh_port != 22:
244+
args += ["-p", str(ssh_port)]
245+
args.append(target)
246+
try:
247+
self._tunnel_proc = subprocess.Popen(
248+
args,
249+
stdout=subprocess.DEVNULL,
250+
stderr=subprocess.DEVNULL,
251+
start_new_session=False,
252+
)
253+
self.status_label.setText(f"Reverse tunnel via {target} (port {ssh_port})")
254+
return True
255+
except Exception as exc: # noqa: BLE001
256+
self._tunnel_proc = None
257+
self.status_label.setText(f"Failed to start tunnel: {exc}")
258+
return False
259+
260+
def _stop_tunnel(self, *_):
261+
if self._tunnel_proc is None:
262+
return
263+
try:
264+
self._tunnel_proc.terminate()
265+
self._tunnel_proc.wait(timeout=2)
266+
except Exception:
267+
try:
268+
self._tunnel_proc.kill()
269+
except Exception:
270+
pass
271+
finally:
272+
self._tunnel_proc = None
121273

122274
def _is_running(self) -> bool:
123275
return self._thread is not None and self._thread.isRunning()
@@ -134,7 +286,7 @@ def _on_received(self, arr: np.ndarray, meta: dict):
134286
viewer_kwargs = {}
135287

136288
# Affine: accept any square >= 2x2 (2x2, 3x3, 4x4, ...)
137-
if "affine" in meta:
289+
if "affine" in meta and not self.ignore_affine.isChecked():
138290
try:
139291
A = np.asarray(meta["affine"], dtype=float)
140292
if A.ndim == 2 and A.shape[0] == A.shape[1] and A.shape[0] >= 2:

src/napari_stream/_utils.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from __future__ import annotations
2-
import os
32
import socket
43

54
DEFAULT_TCP_PORT = 5556
@@ -8,10 +7,8 @@
87
def default_endpoint(public: bool = False) -> str:
98
if public:
109
return f"tcp://{_preferred_ip()}:{DEFAULT_TCP_PORT}"
11-
if os.name == "nt": # Windows: prefer TCP
12-
return f"tcp://127.0.0.1:{DEFAULT_TCP_PORT}"
13-
# Unix: fast local IPC
14-
return "ipc:///tmp/napari_stream.sock"
10+
# TCP everywhere for simplicity
11+
return f"tcp://127.0.0.1:{DEFAULT_TCP_PORT}"
1512

1613

1714
def _preferred_ip() -> str:
@@ -23,4 +20,4 @@ def _preferred_ip() -> str:
2320
return addr
2421
except Exception:
2522
pass
26-
return "127.0.0.1"
23+
return "127.0.0.1"

0 commit comments

Comments
 (0)