Skip to content

Commit bbf4d37

Browse files
feat: export(block=True) will block next cell execution until export finishes
1 parent bc30d46 commit bbf4d37

File tree

2 files changed

+225
-42
lines changed

2 files changed

+225
-42
lines changed

jdaviz/async_utils.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""Utility function for working with asyncio and widgets."""
2+
import asyncio
3+
import time
4+
import logging
5+
import threading
6+
7+
from typing import Callable
8+
import ipywidgets
9+
from asyncio import Queue
10+
import queue
11+
import solara
12+
import solara.lab
13+
14+
logger = logging.getLogger(__name__)
15+
16+
# instead of a single global variable, this makes it kernel
17+
# scoped in solara
18+
_serial_task_queue = solara.lab.computed(Queue)
19+
20+
21+
# analogous to asyncio.create_task, except, they will be called in sequence
22+
def create_serial_task(coro):
23+
serial_task_run_task.get() # ensure task runner is running
24+
_serial_task_queue.value.put_nowait(coro)
25+
26+
27+
async def serial_task_run():
28+
logger.debug("serial_task_run: starting")
29+
while True:
30+
try:
31+
logger.debug("serial_task_run: getting task from queue")
32+
task = await _serial_task_queue.value.get()
33+
logger.debug("serial_task_run: got task from queue, running it")
34+
await task
35+
except Exception:
36+
logger.exception("Task failed")
37+
finally:
38+
_serial_task_queue.value.task_done()
39+
40+
41+
@solara.lab.computed
42+
def serial_task_run_task():
43+
event_loop_queue = queue.Queue()
44+
45+
def runner():
46+
try:
47+
event_loop = asyncio.new_event_loop()
48+
event_loop_queue.put(event_loop)
49+
50+
async def tick():
51+
while True:
52+
await asyncio.sleep(0.1)
53+
print(".", end="", flush=True)
54+
55+
# not sure why we need this
56+
event_loop.run_until_complete(tick())
57+
# event_loop.run_until_complete(serial_task_run())
58+
# but this doesn't work?
59+
# event_loop.run_forever() doesn't work?
60+
except Exception:
61+
logger.exception("Task running thread failed")
62+
63+
thread = threading.Thread(target=runner, daemon=True)
64+
thread.start()
65+
event_loop = event_loop_queue.get()
66+
event_loop.create_task(serial_task_run())
67+
logger.debug(f"serial_task_run_task: created event loop: {event_loop}")
68+
return event_loop
69+
70+
71+
def wait_for_change(widget: ipywidgets.Widget, trait_name: str):
72+
# based on https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Asynchronous.html
73+
future = asyncio.Future()
74+
75+
def getvalue(change):
76+
logger.debug(
77+
f"got new value for {trait_name} for widget {widget} new value {change.new}"
78+
)
79+
future.set_result(change.new)
80+
widget.unobserve(getvalue, trait_name)
81+
82+
logger.debug(f"observing {trait_name} for widget {type(widget)}")
83+
widget.observe(getvalue, trait_name)
84+
return future
85+
86+
87+
async def queue_screenshot_async(widget, method, *args, **kwargs):
88+
logger.debug("queue_screenshot_async: starting")
89+
# queue with just 1 item, the callback data
90+
data_queue = Queue()
91+
92+
def callback_wrapper(data):
93+
logger.debug(
94+
"queue_screenshot_async: callback_wrapper: putting result in queue"
95+
)
96+
try:
97+
data_queue.put_nowait(data)
98+
except Exception:
99+
logger.exception("Failed to put data in queue")
100+
101+
# calls Figure.get_svg_data or something similar
102+
method(callback_wrapper, *args, **kwargs)
103+
return await data_queue.get()
104+
105+
106+
def queue_screenshot_sync(
107+
widget, method, callback, on_timeout, timeout=5, *args, **kwargs
108+
):
109+
# queue with just 1 item, the callback data
110+
data_queue = Queue()
111+
112+
def callback_wrapper(data):
113+
data_queue.put_nowait(data)
114+
115+
async def execute():
116+
# calls Figure.get_svg_data or something similar
117+
method(callback_wrapper, *args, **kwargs)
118+
try:
119+
data = await asyncio.wait_for(data_queue.get(), timeout=timeout)
120+
except TimeoutError:
121+
on_timeout()
122+
callback(data)
123+
124+
create_serial_task(execute())
125+
126+
127+
def run_kernel_events_blocking_until(
128+
condition: Callable[[], bool], timeout: float = 5, sleep: float = 0.1
129+
):
130+
"""Executes kernel events while the condition is true or the timeout is reached.
131+
132+
Used to block in the notebook while we wait for a widget result.
133+
"""
134+
try:
135+
from jupyter_ui_poll import ui_events
136+
except ModuleNotFoundError:
137+
raise ModuleNotFoundError(
138+
"jupyter_ui_poll is not installed. Please install it with `pip install jupyter_ui_poll`"
139+
)
140+
141+
start_time = time.time()
142+
with ui_events() as poll:
143+
while condition():
144+
poll(10) # process up to 10 UI events per iteration
145+
time.sleep(sleep)
146+
if time.time() - start_time > timeout:
147+
raise TimeoutError(
148+
f"Timeout waiting for condition to be true after {timeout} seconds"
149+
)

jdaviz/configs/default/plugins/export/export.py

Lines changed: 76 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
import logging
12
import os
23
import time
34
from pathlib import Path
45
import threading
56

7+
import asyncio
8+
import ipywidgets as widgets
9+
10+
from jdaviz.async_utils import (create_serial_task, queue_screenshot_async,
11+
run_kernel_events_blocking_until,
12+
serial_task_run_task, wait_for_change)
613
from astropy import units as u
714
from astropy.nddata import CCDData
815
from glue.core.message import SubsetCreateMessage, SubsetDeleteMessage, SubsetUpdateMessage
@@ -31,6 +38,7 @@
3138
HAS_OPENCV = True
3239

3340
__all__ = ['Export']
41+
logger = logging.getLogger(__name__)
3442

3543

3644
@tray_registry('export', label="Export",
@@ -118,6 +126,7 @@ class Export(PluginTemplateMixin, ViewerSelectMixin, SubsetSelectMixin,
118126
# This is a temporary measure to allow server-installations to disable saving server-side until
119127
# saving client-side is supported for all exports.
120128
serverside_enabled = Bool(True).tag(sync=True)
129+
_busy_doing_export = Bool(False).tag(sync=True)
121130

122131
def __init__(self, *args, **kwargs):
123132
super().__init__(*args, **kwargs)
@@ -443,7 +452,7 @@ def _normalize_filename(self, filename=None, filetype=None, overwrite=False, def
443452

444453
@with_spinner()
445454
def export(self, filename=None, show_dialog=None, overwrite=False,
446-
raise_error_for_overwrite=True):
455+
raise_error_for_overwrite=True, block=False):
447456
"""
448457
Export selected item(s)
449458
@@ -462,6 +471,11 @@ def export(self, filename=None, show_dialog=None, overwrite=False,
462471
If `True`, raise exception when ``overwrite=False`` but
463472
output file already exists. Otherwise, a message will be sent
464473
to application snackbar instead.
474+
475+
block : bool
476+
If `True`, block until the export is complete, this is useful in
477+
a notebook context to ensure the export is complete before the
478+
next export is started.
465479
"""
466480
if self.multiselect:
467481
raise NotImplementedError("batch export not yet supported")
@@ -508,7 +522,7 @@ def export(self, filename=None, show_dialog=None, overwrite=False,
508522
else:
509523
self.save_figure(viewer, filename, filetype, show_dialog=show_dialog,
510524
width=f"{self.image_width}px" if self.image_custom_size else None,
511-
height=f"{self.image_height}px" if self.image_custom_size else None) # noqa
525+
height=f"{self.image_height}px" if self.image_custom_size else None, block=block) # noqa
512526

513527
# restore marks to their original state
514528
for restore, mark in zip(restores, viewer.figure.marks):
@@ -534,7 +548,7 @@ def export(self, filename=None, show_dialog=None, overwrite=False,
534548
raise FileExistsError(f"{filename} exists but overwrite={overwrite}")
535549
return
536550

537-
self.save_figure(plot, filename, filetype, show_dialog=show_dialog)
551+
self.save_figure(plot, filename, filetype, show_dialog=show_dialog, block=block)
538552

539553
elif len(self.plugin_table.selected):
540554
filetype = self.plugin_table_format.selected
@@ -609,17 +623,33 @@ def vue_overwrite_from_ui(self, *args, **kwargs):
609623
self.overwrite_warn = False
610624

611625
def save_figure(self, viewer, filename=None, filetype="png", show_dialog=False,
612-
width=None, height=None):
626+
width=None, height=None, block=True):
613627
if filename is None:
614628
filename = self.filename_default
615629

616-
# viewers in plugins will have viewer.app, other viewers have viewer.jdaviz_app
617-
if hasattr(viewer, 'jdaviz_app'):
618-
app = viewer.jdaviz_app
619-
else:
620-
app = viewer.app
630+
if self._busy_doing_export:
631+
raise ValueError("Saving figure is still in progress. Use `save_figure(..., block=True)` or export(..., block=True) to make sure the previous export is complete") # noqa
632+
self._busy_doing_export = True
621633

622-
def on_img_received(data):
634+
async def save_figure_task():
635+
try:
636+
await self._save_figure_async(viewer, filename, filetype, show_dialog, width, height)
637+
finally:
638+
self._busy_doing_export = False
639+
if block:
640+
event_loop = serial_task_run_task.get()
641+
logger.warning(f"event loop: {event_loop}, now creating task")
642+
event_loop.create_task(save_figure_task())
643+
run_kernel_events_blocking_until(lambda: not self._busy_doing_export)
644+
else:
645+
task = asyncio.create_task(save_figure_task())
646+
create_serial_task(task)
647+
return task
648+
649+
async def _save_figure_async(self, viewer, filename, filetype, show_dialog, width, height):
650+
# Things become a bit more easy to reason about using async/await instead of callbacks
651+
# So this internal method uses async/await instead of callbacks.
652+
def save_to_file(data):
623653
try:
624654
with filename.open(mode='bw') as f:
625655
f.write(data)
@@ -632,17 +662,15 @@ def on_img_received(data):
632662
f"{self.viewer.selected} exported to {str(filename)}",
633663
sender=self, color="success"))
634664

635-
def get_png(figure):
636-
if figure._upload_png_callback is not None:
637-
raise ValueError("previous png export is still in progress. Wait to complete before making another call to save_figure") # noqa: E501 # pragma: no cover
638-
639-
figure.get_png_data(on_img_received)
665+
# viewers in plugins will have viewer.app, other viewers have viewer.jdaviz_app
666+
if hasattr(viewer, 'jdaviz_app'):
667+
app = viewer.jdaviz_app
668+
else:
669+
app = viewer.app
640670

641671
if (width is not None or height is not None):
642672
assert width is not None and height is not None, \
643673
"Both width and height must be provided"
644-
import ipywidgets as widgets
645-
from typing import Callable
646674

647675
def _show_hidden(widget: widgets.Widget, width: str, height: str):
648676
import ipyvuetify as v
@@ -658,42 +686,48 @@ def _show_hidden(widget: widgets.Widget, width: str, height: str):
658686
# TODO: we might want to remove it from the DOM
659687
app.invisible_children = [*app.invisible_children, wrapper_widget]
660688

661-
def _widget_after_first_display(widget: widgets.Widget, callback: Callable):
689+
def _widget_after_first_display(widget: widgets.Widget):
662690
if widget._view_count is None:
663691
widget._view_count = 0
664-
called_callback = False
665-
666-
def view_count_changed(change):
667-
nonlocal called_callback
668-
if change["new"] == 1 and not called_callback:
669-
called_callback = True
670-
callback()
671-
widget.observe(view_count_changed, "_view_count")
692+
logger.debug(f"waiting for view count to change for widget {type(widget)}")
693+
return wait_for_change(widget, "_view_count")
672694

673695
cloned_viewer = viewer._clone_viewer_outside_app()
674696
# make sure we will the size of our container which defines the
675697
# size of the figure
676698
cloned_viewer.figure.layout.width = "100%"
677699
cloned_viewer.figure.layout.height = "100%"
678700

679-
def on_figure_displayed():
680-
# we need a bit of a delay to ensure the figure is fully displayed
681-
# maybe this can be fixed on the bqplot side in the future
682-
def wait_in_other_thread():
683-
import time
684-
time.sleep(0.2)
685-
get_png(cloned_viewer.figure)
686-
# wait in other thread to avoid blocking the main thread (widgets can update)
687-
threading.Thread(target=wait_in_other_thread).start()
688-
_widget_after_first_display(cloned_viewer.figure, on_figure_displayed)
701+
logger.debug("calling _widget_after_first_display for widget")
702+
display_future = _widget_after_first_display(cloned_viewer.figure)
703+
logger.debug(f"calling _show_hidden for widget {display_future}")
689704
_show_hidden(cloned_viewer.figure, width, height)
690-
elif filetype == 'png':
691-
# NOTE: get_png already check if _upload_png_callback is not None
692-
get_png(viewer.figure)
693-
elif filetype == 'svg':
705+
logger.debug("waiting for display future")
706+
await display_future
707+
logger.debug("display future done")
708+
await asyncio.sleep(0.2)
709+
logger.debug("sleeping done")
710+
if cloned_viewer.figure._upload_png_callback is not None:
711+
raise ValueError("previous svg export is still in progress. Wait to complete before making another call to save_figure")
712+
if cloned_viewer.figure._upload_svg_callback is not None:
713+
raise ValueError("previous svg export is still in progress. Wait to complete before making another call to save_figure")
714+
logger.debug("queueing screenshot")
715+
get_image_data_method = cloned_viewer.figure.get_svg_data if filetype == 'svg' else cloned_viewer.figure.get_png_data
716+
data = await queue_screenshot_async(cloned_viewer.figure, get_image_data_method)
717+
logger.debug("got data, saving to file {filename}")
718+
save_to_file(data)
719+
logger.debug("saved to file {filename}")
720+
elif filetype in ['png', 'svg']:
721+
if viewer.figure._upload_png_callback is not None:
722+
raise ValueError("previous png export is still in progress. Wait to complete before making another call to save_figure")
694723
if viewer.figure._upload_svg_callback is not None:
695-
raise ValueError("previous svg export is still in progress. Wait to complete before making another call to save_figure") # noqa
696-
viewer.figure.get_svg_data(on_img_received)
724+
raise ValueError("previous svg export is still in progress. Wait to complete before making another call to save_figure")
725+
get_image_data_method = viewer.figure.get_svg_data if filetype == 'svg' else viewer.figure.get_png_data
726+
logger.debug("queueing screenshot")
727+
data = await queue_screenshot_async(viewer.figure, get_image_data_method)
728+
logger.debug("got data, saving to file {filename}")
729+
save_to_file(data)
730+
logger.debug("saved to file {filename}")
697731
else:
698732
raise ValueError(f"Unsupported filetype={filetype} for save_figure")
699733

0 commit comments

Comments
 (0)