Skip to content

Commit 552d653

Browse files
feat: export(block=True) will block next cell execution until export finishes
1 parent ca5d942 commit 552d653

File tree

6 files changed

+311
-64
lines changed

6 files changed

+311
-64
lines changed

jdaviz/async_utils.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""Utility function for working with asyncio and widgets."""
2+
import asyncio
3+
import time
4+
import logging
5+
import threading
6+
7+
from jupyter_ui_poll import ui_events
8+
9+
from typing import Callable
10+
import ipywidgets
11+
from asyncio import Queue
12+
import queue
13+
import solara
14+
import solara.lab
15+
16+
logger = logging.getLogger(__name__)
17+
18+
# instead of a single global variable, this makes it kernel
19+
# scoped in solara
20+
_serial_task_queue = solara.lab.computed(Queue)
21+
22+
23+
# analogous to asyncio.create_task, except, they will be called in sequence
24+
def create_serial_task(coro):
25+
serial_task_run_task.get() # ensure task runner is running
26+
_serial_task_queue.value.put_nowait(coro)
27+
28+
29+
async def serial_task_run():
30+
logger.debug("serial_task_run: starting")
31+
while True:
32+
try:
33+
logger.debug("serial_task_run: getting task from queue")
34+
task = await _serial_task_queue.value.get()
35+
logger.debug("serial_task_run: got task from queue, running it")
36+
await task
37+
except Exception:
38+
logger.exception("Task failed")
39+
finally:
40+
_serial_task_queue.value.task_done()
41+
42+
43+
@solara.lab.computed
44+
def serial_task_run_task():
45+
event_loop_queue = queue.Queue()
46+
47+
def runner():
48+
try:
49+
event_loop = asyncio.new_event_loop()
50+
event_loop_queue.put(event_loop)
51+
52+
async def tick():
53+
while True:
54+
await asyncio.sleep(0.1)
55+
56+
# not sure why we need this
57+
event_loop.run_until_complete(tick())
58+
# event_loop.run_until_complete(serial_task_run())
59+
# but this doesn't work?
60+
# event_loop.run_forever() doesn't work?
61+
except Exception:
62+
logger.exception("Task running thread failed")
63+
64+
thread = threading.Thread(target=runner, daemon=True)
65+
thread.start()
66+
event_loop = event_loop_queue.get()
67+
event_loop.create_task(serial_task_run())
68+
logger.debug(f"serial_task_run_task: created event loop: {event_loop}")
69+
return event_loop
70+
71+
72+
def wait_for_change(widget: ipywidgets.Widget, trait_name: str):
73+
# based on https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Asynchronous.html
74+
future = asyncio.Future()
75+
76+
def getvalue(change):
77+
logger.debug(
78+
f"got new value for {trait_name} for widget {widget} new value {change.new}"
79+
)
80+
future.set_result(change.new)
81+
widget.unobserve(getvalue, trait_name)
82+
83+
logger.debug(f"observing {trait_name} for widget {type(widget)}")
84+
widget.observe(getvalue, trait_name)
85+
return future
86+
87+
88+
async def queue_screenshot_async(widget, method, *args, **kwargs):
89+
logger.debug("queue_screenshot_async: starting")
90+
# queue with just 1 item, the callback data
91+
data_queue = Queue()
92+
93+
def callback_wrapper(data):
94+
logger.debug(
95+
"queue_screenshot_async: callback_wrapper: putting result in queue"
96+
)
97+
try:
98+
data_queue.put_nowait(data)
99+
except Exception:
100+
logger.exception("Failed to put data in queue")
101+
102+
# calls Figure.get_svg_data or something similar
103+
method(callback_wrapper, *args, **kwargs)
104+
return await data_queue.get()
105+
106+
107+
def queue_screenshot_sync(
108+
widget, method, callback, on_timeout, timeout=5, *args, **kwargs
109+
):
110+
# queue with just 1 item, the callback data
111+
data_queue = Queue()
112+
113+
def callback_wrapper(data):
114+
data_queue.put_nowait(data)
115+
116+
async def execute():
117+
# calls Figure.get_svg_data or something similar
118+
method(callback_wrapper, *args, **kwargs)
119+
try:
120+
data = await asyncio.wait_for(data_queue.get(), timeout=timeout)
121+
except TimeoutError:
122+
on_timeout()
123+
callback(data)
124+
125+
create_serial_task(execute())
126+
127+
128+
def run_kernel_events_blocking_until(
129+
condition: Callable[[], bool], timeout: float = 5, sleep: float = 0.1
130+
):
131+
"""Executes kernel events while the condition is true or the timeout is reached.
132+
133+
Used to block in the notebook while we wait for a widget result.
134+
"""
135+
start_time = time.time()
136+
with ui_events() as poll:
137+
while condition():
138+
poll(10) # process up to 10 UI events per iteration
139+
time.sleep(sleep)
140+
if time.time() - start_time > timeout:
141+
raise TimeoutError(
142+
f"Timeout waiting for condition to be true after {timeout} seconds"
143+
)

jdaviz/configs/default/plugins/export/export.py

Lines changed: 89 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
import logging
12
import os
23
import time
34
from pathlib import Path
45
import threading
56

7+
import asyncio
8+
import ipywidgets as widgets
9+
10+
from jdaviz.async_utils import (create_serial_task, queue_screenshot_async,
11+
run_kernel_events_blocking_until,
12+
serial_task_run_task, wait_for_change)
613
from astropy import units as u
714
from astropy.nddata import CCDData
815
from glue.core.message import SubsetCreateMessage, SubsetDeleteMessage, SubsetUpdateMessage
@@ -31,6 +38,7 @@
3138
HAS_OPENCV = True
3239

3340
__all__ = ['Export']
41+
logger = logging.getLogger(__name__)
3442

3543

3644
@tray_registry('export', label="Export",
@@ -118,6 +126,7 @@ class Export(PluginTemplateMixin, ViewerSelectMixin, SubsetSelectMixin,
118126
# This is a temporary measure to allow server-installations to disable saving server-side until
119127
# saving client-side is supported for all exports.
120128
serverside_enabled = Bool(True).tag(sync=True)
129+
_busy_doing_export = Bool(False).tag(sync=True)
121130

122131
def __init__(self, *args, **kwargs):
123132
super().__init__(*args, **kwargs)
@@ -443,7 +452,7 @@ def _normalize_filename(self, filename=None, filetype=None, overwrite=False, def
443452

444453
@with_spinner()
445454
def export(self, filename=None, show_dialog=None, overwrite=False,
446-
raise_error_for_overwrite=True):
455+
raise_error_for_overwrite=True, block=True):
447456
"""
448457
Export selected item(s)
449458
@@ -462,6 +471,11 @@ def export(self, filename=None, show_dialog=None, overwrite=False,
462471
If `True`, raise exception when ``overwrite=False`` but
463472
output file already exists. Otherwise, a message will be sent
464473
to application snackbar instead.
474+
475+
block : bool
476+
If `True`, block until the export is complete, this is useful in
477+
a notebook context to ensure the export is complete before the
478+
next export is started.
465479
"""
466480
if self.multiselect:
467481
raise NotImplementedError("batch export not yet supported")
@@ -508,7 +522,7 @@ def export(self, filename=None, show_dialog=None, overwrite=False,
508522
else:
509523
self.save_figure(viewer, filename, filetype, show_dialog=show_dialog,
510524
width=f"{self.image_width}px" if self.image_custom_size else None,
511-
height=f"{self.image_height}px" if self.image_custom_size else None) # noqa
525+
height=f"{self.image_height}px" if self.image_custom_size else None, block=block) # noqa
512526

513527
# restore marks to their original state
514528
for restore, mark in zip(restores, viewer.figure.marks):
@@ -534,7 +548,7 @@ def export(self, filename=None, show_dialog=None, overwrite=False,
534548
raise FileExistsError(f"{filename} exists but overwrite={overwrite}")
535549
return
536550

537-
self.save_figure(plot, filename, filetype, show_dialog=show_dialog)
551+
self.save_figure(plot, filename, filetype, show_dialog=show_dialog, block=block)
538552

539553
elif len(self.plugin_table.selected):
540554
filetype = self.plugin_table_format.selected
@@ -620,17 +634,40 @@ def vue_overwrite_from_ui(self, *args, **kwargs):
620634
self.overwrite_warn = False
621635

622636
def save_figure(self, viewer, filename=None, filetype="png", show_dialog=False,
623-
width=None, height=None):
637+
width=None, height=None, block=True):
624638
if filename is None:
625639
filename = self.filename_default
626640

627-
# viewers in plugins will have viewer.app, other viewers have viewer.jdaviz_app
628-
if hasattr(viewer, 'jdaviz_app'):
629-
app = viewer.jdaviz_app
630-
else:
631-
app = viewer.app
641+
if self._busy_doing_export:
642+
raise ValueError("Saving figure is still in progress. Use ` export(..., block=True)` to make sure the previous export is complete") # noqa
643+
self._busy_doing_export = True
644+
self._last_error = None
632645

633-
def on_img_received(data):
646+
async def save_figure_task():
647+
try:
648+
await self._save_figure_async(viewer, filename, filetype, show_dialog, width,
649+
height)
650+
except BaseException as e:
651+
logger.error(f"Error saving figure: {e}")
652+
self._last_error = e
653+
finally:
654+
self._busy_doing_export = False
655+
if block:
656+
event_loop = serial_task_run_task.get()
657+
logger.warning(f"event loop: {event_loop}, now creating task")
658+
event_loop.create_task(save_figure_task())
659+
run_kernel_events_blocking_until(lambda: self._busy_doing_export)
660+
if self._last_error is not None:
661+
raise self._last_error
662+
else:
663+
task = asyncio.create_task(save_figure_task())
664+
create_serial_task(task)
665+
return task
666+
667+
async def _save_figure_async(self, viewer, filename, filetype, show_dialog, width, height):
668+
# Things become a bit more easy to reason about using async/await instead of callbacks
669+
# So this internal method uses async/await instead of callbacks.
670+
def save_to_file(data):
634671
try:
635672
with filename.open(mode='bw') as f:
636673
f.write(data)
@@ -643,17 +680,15 @@ def on_img_received(data):
643680
f"{self.viewer.selected} exported to {str(filename)}",
644681
sender=self, color="success"))
645682

646-
def get_png(figure):
647-
if figure._upload_png_callback is not None:
648-
raise ValueError("previous png export is still in progress. Wait to complete before making another call to save_figure") # noqa: E501 # pragma: no cover
649-
650-
figure.get_png_data(on_img_received)
683+
# viewers in plugins will have viewer.app, other viewers have viewer.jdaviz_app
684+
if hasattr(viewer, 'jdaviz_app'):
685+
app = viewer.jdaviz_app
686+
else:
687+
app = viewer.app
651688

652689
if (width is not None or height is not None):
653690
assert width is not None and height is not None, \
654691
"Both width and height must be provided"
655-
import ipywidgets as widgets
656-
from typing import Callable
657692

658693
def _show_hidden(widget: widgets.Widget, width: str, height: str):
659694
import ipyvuetify as v
@@ -669,42 +704,54 @@ def _show_hidden(widget: widgets.Widget, width: str, height: str):
669704
# TODO: we might want to remove it from the DOM
670705
app.invisible_children = [*app.invisible_children, wrapper_widget]
671706

672-
def _widget_after_first_display(widget: widgets.Widget, callback: Callable):
707+
def _widget_after_first_display(widget: widgets.Widget):
673708
if widget._view_count is None:
674709
widget._view_count = 0
675-
called_callback = False
676-
677-
def view_count_changed(change):
678-
nonlocal called_callback
679-
if change["new"] == 1 and not called_callback:
680-
called_callback = True
681-
callback()
682-
widget.observe(view_count_changed, "_view_count")
710+
logger.debug(f"waiting for view count to change for widget {type(widget)}")
711+
return wait_for_change(widget, "_view_count")
683712

684713
cloned_viewer = viewer._clone_viewer_outside_app()
685714
# make sure we will the size of our container which defines the
686715
# size of the figure
687716
cloned_viewer.figure.layout.width = "100%"
688717
cloned_viewer.figure.layout.height = "100%"
689718

690-
def on_figure_displayed():
691-
# we need a bit of a delay to ensure the figure is fully displayed
692-
# maybe this can be fixed on the bqplot side in the future
693-
def wait_in_other_thread():
694-
import time
695-
time.sleep(0.2)
696-
get_png(cloned_viewer.figure)
697-
# wait in other thread to avoid blocking the main thread (widgets can update)
698-
threading.Thread(target=wait_in_other_thread).start()
699-
_widget_after_first_display(cloned_viewer.figure, on_figure_displayed)
719+
logger.debug("calling _widget_after_first_display for widget")
720+
display_future = _widget_after_first_display(cloned_viewer.figure)
721+
logger.debug(f"calling _show_hidden for widget {display_future}")
700722
_show_hidden(cloned_viewer.figure, width, height)
701-
elif filetype == 'png':
702-
# NOTE: get_png already check if _upload_png_callback is not None
703-
get_png(viewer.figure)
704-
elif filetype == 'svg':
723+
logger.debug("waiting for display future")
724+
await display_future
725+
logger.debug("display future done")
726+
await asyncio.sleep(0.2)
727+
logger.debug("sleeping done")
728+
if cloned_viewer.figure._upload_png_callback is not None:
729+
raise ValueError("previous svg export is still in progress. Wait to complete "
730+
"before making another call to save_figure")
731+
if cloned_viewer.figure._upload_svg_callback is not None:
732+
raise ValueError("previous svg export is still in progress. Wait to complete "
733+
"before making another call to save_figure")
734+
logger.debug("queueing screenshot")
735+
get_image_data_method = cloned_viewer.figure.get_svg_data if filetype == 'svg' else \
736+
cloned_viewer.figure.get_png_data
737+
data = await queue_screenshot_async(cloned_viewer.figure, get_image_data_method)
738+
logger.debug("got data, saving to file {filename}")
739+
save_to_file(data)
740+
logger.debug("saved to file {filename}")
741+
elif filetype in ['png', 'svg']:
742+
if viewer.figure._upload_png_callback is not None:
743+
raise ValueError("previous png export is still in progress. Wait to complete "
744+
"before making another call to save_figure")
705745
if viewer.figure._upload_svg_callback is not None:
706-
raise ValueError("previous svg export is still in progress. Wait to complete before making another call to save_figure") # noqa
707-
viewer.figure.get_svg_data(on_img_received)
746+
raise ValueError("previous svg export is still in progress. Wait to complete "
747+
"before making another call to save_figure")
748+
get_image_data_method = viewer.figure.get_svg_data if filetype == 'svg' else \
749+
viewer.figure.get_png_data
750+
logger.debug("queueing screenshot")
751+
data = await queue_screenshot_async(viewer.figure, get_image_data_method)
752+
logger.debug("got data, saving to file {filename}")
753+
save_to_file(data)
754+
logger.debug("saved to file {filename}")
708755
else:
709756
raise ValueError(f"Unsupported filetype={filetype} for save_figure")
710757

0 commit comments

Comments
 (0)