diff --git a/mesoscope_gui/gui.py b/mesoscope_gui/gui.py index adc5ba5..3b71c77 100644 --- a/mesoscope_gui/gui.py +++ b/mesoscope_gui/gui.py @@ -1,3 +1,7 @@ +# ------------------------------------------------------------------------------------------------- +# Imports +# ------------------------------------------------------------------------------------------------- + import sys import os import json @@ -12,38 +16,57 @@ import numpy as np +# ------------------------------------------------------------------------------------------------- +# Constants +# ------------------------------------------------------------------------------------------------- + opacity_effect = QGraphicsOpacityEffect() opacity_effect.setOpacity(0.5) RADIUS = 20 +# ------------------------------------------------------------------------------------------------- +# Util functions +# ------------------------------------------------------------------------------------------------- + def set_widget_opaque(widget, is_opaque): widget.setGraphicsEffect(opacity_effect if not is_opaque else None) +# ------------------------------------------------------------------------------------------------- +# Mescoscope GUI +# ------------------------------------------------------------------------------------------------- + class MesoscopeGUI(QMainWindow): - def __init__(self): - super().__init__() + # Initialization + # ------------------------------------------------------------------------------------------------- + + def __init__(self): self.current_folder_idx = 0 self.folder_paths = [] self.stack_count = 0 self.current_stack_idx = 0 - self.pixmap = None - self.init_ui() - self.points = [{} for _ in range(3)] # stack_idx, coords + super().__init__() + self.init_ui() + self._init_point_widgets() + + def _init_point_widgets(self): self.points_widgets = [] for point_idx in range(3): color = ['red', 'green', 'blue'][point_idx] x = y = -100 - self.add_point_widget(x, y, color, point_idx) + self._add_point_widget(x, y, color, point_idx) + + # UI + # ------------------------------------------------------------------------------------------------- def init_ui(self): menu_bar = self.menuBar() @@ -69,7 +92,7 @@ def init_ui(self): self.image_label = QLabel() self.image_label.installEventFilter(self) # for mouse wheel scroll self.image_label.setScaledContents(True) - self.image_label.resizeEvent = lambda event: self.update_margins() + self.image_label.resizeEvent = self.on_resized self.image_label.mousePressEvent = self.add_point_at_click self.image_label.setMinimumSize(1, 1) self.image_layout.addWidget(self.image_label) @@ -98,6 +121,33 @@ def init_ui(self): self.resize(800, 600) self.show() + def _add_point_widget(self, x, y, color, point_idx): + point_label = QLabel(self.image_label) + r = 20 + point_label.setFixedSize(r, r) + + pixmap = QPixmap(r, r) + pixmap.fill(Qt.transparent) + painter = QPainter(pixmap) + painter.setBrush(QColor(color)) + painter.setPen(Qt.NoPen) + painter.drawEllipse(0, 0, r, r) + painter.end() + + point_label.setPixmap(pixmap) + point_label.setAttribute(Qt.WA_TransparentForMouseEvents, False) + point_label.move(x - r // 2, y - r // 2) + point_label.show() + + point_label.mousePressEvent = lambda event: self.start_drag(event, point_label) + point_label.mouseMoveEvent = lambda event: self.drag_point(event, point_label) + point_label.mouseReleaseEvent = lambda event: self.end_drag(event, point_label, point_idx) + + self.points_widgets.append(point_label) + + # Folder opening + # ------------------------------------------------------------------------------------------------- + def open_dialog(self): dialog = QFileDialog() dialog.setFileMode(QFileDialog.DirectoryOnly) @@ -133,6 +183,10 @@ def select_folder(self, folder_index): self.current_folder_idx = folder_index self.update_folder() + def navigate(self, direction): + self.current_folder_idx += direction + self.folder_list.setCurrentRow(self.current_folder_idx) + def load_image_stack(self): folder = self.folder_paths[self.current_folder_idx] stack_file = next(f for f in os.listdir(folder) if f.startswith("referenceImage.stack") and f.endswith(".tif")) @@ -153,6 +207,69 @@ def load_image_stack(self): self.scrollbar.setValue(self.stack_count // 2) self.load_points() + # Coordinate transforms + # ------------------------------------------------------------------------------------------------- + + def to_relative(self, x, y): + label_width, label_height = self.image_label.width(), self.image_label.height() + pixmap = self.image_label.pixmap() + + if not pixmap: + return None, None + + # Get the aspect ratio for the pixmap and label + pixmap_width, pixmap_height = pixmap.width(), pixmap.height() + label_aspect_ratio = label_width / label_height + pixmap_aspect_ratio = pixmap_width / pixmap_height + + # Calculate scaled dimensions of the pixmap to fit within the label, preserving aspect ratio + if label_aspect_ratio > pixmap_aspect_ratio: + scaled_width = int(label_height * pixmap_aspect_ratio) + scaled_height = label_height + else: + scaled_width = label_width + scaled_height = int(label_width / pixmap_aspect_ratio) + + # Calculate margins on each side + x_margin = (label_width - scaled_width) // 2 + y_margin = (label_height - scaled_height) // 2 + + # Convert from pixel coordinates to relative [0,1] coordinates + xr = (x - x_margin) / scaled_width + yr = (y - y_margin) / scaled_height + + return max(0, min(1, xr)), max(0, min(1, yr)) # Clamp values to [0, 1] + + def to_absolute(self, xr, yr): + label_width, label_height = self.image_label.width(), self.image_label.height() + pixmap = self.image_label.pixmap() + + if not pixmap: + return None, None + + # Get the aspect ratio for the pixmap and label + pixmap_width, pixmap_height = pixmap.width(), pixmap.height() + label_aspect_ratio = label_width / label_height + pixmap_aspect_ratio = pixmap_width / pixmap_height + + # Calculate scaled dimensions of the pixmap to fit within the label, preserving aspect ratio + if label_aspect_ratio > pixmap_aspect_ratio: + scaled_width = int(label_height * pixmap_aspect_ratio) + scaled_height = label_height + else: + scaled_width = label_width + scaled_height = int(label_width / pixmap_aspect_ratio) + + # Calculate margins on each side + x_margin = (label_width - scaled_width) // 2 + y_margin = (label_height - scaled_height) // 2 + + # Convert from relative [0,1] coordinates back to pixel coordinates + x = x_margin + int(xr * scaled_width) + y = y_margin + int(yr * scaled_height) + + return x, y + def update_margins(self): pixmap = self.pixmap if not pixmap: @@ -171,26 +288,6 @@ def update_margins(self): m = int(m) self.image_label.setContentsMargins(0, m, 0, m) - def to_relative(self, x, y): - pixmap = self.pixmap - if not pixmap: - return - size = self.image_label.size() - w, h = size.width(), size.height() - pw = pixmap.width() - ph = pixmap.height() - - ih = self.image_stack.shape[1] - iw = self.image_stack.shape[2] - - if (w * ph > h * pw): - m = (w - (pw * h / ph)) / 2 - return ((x-m)/w, y/h) - - else: - m = (h - (ph * w / pw)) / 2 - return ((x)/w, (y-m)/h) - def update_image(self): self.current_stack_idx = self.scrollbar.value() if self.current_stack_idx >= self.stack_count: @@ -199,49 +296,29 @@ def update_image(self): qimg = QImage(img.data, img.shape[1], img.shape[0], img.strides[0], QImage.Format_Grayscale8) self.pixmap = QPixmap.fromImage(qimg) - # pixmap = QPixmap.fromImage(qimg).scaled( - # self.image_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation) - - # painter = QPainter(self.pixmap) - # for p in self.points.get(self.current_stack_idx, []): - # opacity = 1.0 if p['stack_idx'] == self.current_stack_idx else 0.5 - # painter.setOpacity(opacity) - # color = [QColor('red'), QColor('green'), QColor('blue')][p['point_idx']] - # painter.setPen(color) - # painter.setBrush(color) - # painter.drawEllipse(p['coords'][0] - 5, p['coords'][1] - 5, 10, 10) - # painter.end() self.image_label.setPixmap(self.pixmap) - self.update_margins() - def navigate(self, direction): - self.current_folder_idx += direction - self.folder_list.setCurrentRow(self.current_folder_idx) + # Adding points + # ------------------------------------------------------------------------------------------------- - def add_point_widget(self, x, y, color, point_idx): - point_label = QLabel(self.image_label) - r = 20 - point_label.setFixedSize(r, r) - - pixmap = QPixmap(r, r) - pixmap.fill(Qt.transparent) - painter = QPainter(pixmap) - painter.setBrush(QColor(color)) - painter.setPen(Qt.NoPen) - painter.drawEllipse(0, 0, r, r) - painter.end() - - point_label.setPixmap(pixmap) - point_label.setAttribute(Qt.WA_TransparentForMouseEvents, False) - point_label.move(x - r // 2, y - r // 2) - point_label.show() + def set_point_position(self, point_idx, xr, yr, stack_idx): + self.points[point_idx]['coords'] = (xr, yr) + self.points[point_idx]['stack_idx'] = stack_idx + self.update_point_position(point_idx) - point_label.mousePressEvent = lambda event: self.start_drag(event, point_label) - point_label.mouseMoveEvent = lambda event: self.drag_point(event, point_label) - point_label.mouseReleaseEvent = lambda event: self.end_drag(event, point_label, point_idx) + def update_point_position(self, point_idx): + xr, yr = self.points[point_idx].get('coords', (None, None)) + if xr is None: + return + x, y = self.to_absolute(xr, yr) + self.points_widgets[point_idx].move(x - RADIUS // 2, y - RADIUS // 2) - self.points_widgets.append(point_label) + def clear_points(self): + for widget, point in zip(self.points_widgets, self.points): + widget.move(-100, -100) + point['coords'] = None + point['stack_idx'] = None def add_point_at_click(self, event): x, y = event.pos().x(), event.pos().y() @@ -249,13 +326,11 @@ def add_point_at_click(self, event): if point_idx is None: return assert 0 <= point_idx and point_idx < 3 + xr, yr = self.to_relative(x, y) + self.set_point_position(point_idx, xr, yr, self.current_stack_idx) - r = RADIUS - self.points_widgets[point_idx].move(x - r // 2, y - r // 2) - - self.points[point_idx]['coords'] = self.to_relative(x, y) - self.points[point_idx]['stack_idx'] = self.current_stack_idx - self.save_points() + # Points drag and drop + # ------------------------------------------------------------------------------------------------- def start_drag(self, event, point_label): self.drag_offset = event.pos() @@ -270,20 +345,10 @@ def end_drag(self, event, point_label, point_idx): x, y = point_label.x() + r // 2, point_label.y() + r // 2 self.points[point_idx]['coords'] = self.to_relative(x, y) - self.save_points() + # self.save_points() - def clear_points(self): - for widget in self.points_widgets: - widget.move(-100, -100) - - def eventFilter(self, obj, event): - if obj == self.image_label and event.type() == event.Wheel: - delta = event.angleDelta().y() // 120 - new_value = self.scrollbar.value() - delta - new_value = max(0, min(self.scrollbar.maximum(), new_value)) - self.scrollbar.setValue(new_value) - return True - return super().eventFilter(obj, event) + # Points file + # ------------------------------------------------------------------------------------------------- @property def points_file(self): @@ -308,6 +373,23 @@ def save_points(self): with open(points_file, 'w') as f: json.dump({'points': self.points}, f) + # Event handling + # ------------------------------------------------------------------------------------------------- + + def on_resized(self, ev): + self.update_margins() + for point_idx in range(3): + self.update_point_position(point_idx) + + def eventFilter(self, obj, event): + if obj == self.image_label and event.type() == event.Wheel: + delta = event.angleDelta().y() // 120 + new_value = self.scrollbar.value() - delta + new_value = max(0, min(self.scrollbar.maximum(), new_value)) + self.scrollbar.setValue(new_value) + return True + return super().eventFilter(obj, event) + if __name__ == '__main__': app = QApplication(sys.argv)