From fe4fda30f9db9b4a392261c443bcf8160d48f4d8 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Mon, 12 Aug 2024 17:07:25 +0200 Subject: [PATCH] implement plane viewing and clean up --- setup.cfg | 6 +- src/napari_lumen_segmentation/__init__.py | 6 +- .../_custom_table_widget.py | 112 +- .../_distance_widget.py | 258 ++-- .../_histogram_widget.py | 67 +- .../_layer_dropdown.py | 49 +- src/napari_lumen_segmentation/_plot_widget.py | 149 ++- .../_skeleton_widget.py | 140 ++- src/napari_lumen_segmentation/_widget.py | 1072 +++++++++++------ .../napari_multiple_view_widget.py | 438 ------- 10 files changed, 1195 insertions(+), 1102 deletions(-) delete mode 100644 src/napari_lumen_segmentation/napari_multiple_view_widget.py diff --git a/setup.cfg b/setup.cfg index 62d34dd..a085fb3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,6 @@ project_urls = packages = find: install_requires = napari - numpy scikit-image napari-skimage-regionprops dask_image @@ -41,8 +40,11 @@ install_requires = diplib localthickness skan + numpy==1.26.4 + napari-plane-sliders -python_requires = >=3.8 +dependency_links = https://github.com/AnniekStok/napari-plane-sliders/tarball/main#egg=napari-plane-sliders +python_requires = >=3.10 include_package_data = True package_dir = =src diff --git a/src/napari_lumen_segmentation/__init__.py b/src/napari_lumen_segmentation/__init__.py index 6c88587..4b3f12f 100644 --- a/src/napari_lumen_segmentation/__init__.py +++ b/src/napari_lumen_segmentation/__init__.py @@ -1,10 +1,10 @@ __version__ = "0.0.1" -from ._widget import AnnotateLabelsND from ._custom_table_widget import ColoredTableWidget, TableWidget +from ._distance_widget import DistanceWidget from ._histogram_widget import HistWidget from ._layer_dropdown import LayerDropdown from ._skeleton_widget import SkeletonWidget -from ._distance_widget import DistanceWidget +from ._widget import AnnotateLabelsND __all__ = ( "AnnotateLabelsND", @@ -13,5 +13,5 @@ "TableWidget", "HistWidget", "LayerDropdown", - "DistanceWidget" + "DistanceWidget", ) diff --git a/src/napari_lumen_segmentation/_custom_table_widget.py b/src/napari_lumen_segmentation/_custom_table_widget.py index c1322a3..0d470f8 100644 --- a/src/napari_lumen_segmentation/_custom_table_widget.py +++ b/src/napari_lumen_segmentation/_custom_table_widget.py @@ -1,25 +1,29 @@ - import napari - import pandas as pd -from pandas import DataFrame - +from matplotlib.colors import ListedColormap, to_rgb from napari_skimage_regionprops import TableWidget -from matplotlib.colors import to_rgb, ListedColormap - -from qtpy.QtWidgets import QTableWidget, QHBoxLayout, QTableWidgetItem, QWidget, QGridLayout, QPushButton, QFileDialog - +from pandas import DataFrame from qtpy.QtGui import QColor +from qtpy.QtWidgets import ( + QFileDialog, + QGridLayout, + QHBoxLayout, + QPushButton, + QTableWidget, + QTableWidgetItem, + QWidget, +) + class ColoredTableWidget(TableWidget): - """Customized table widget based on the napari_skimage_regionprops TableWidget - - """ + """Customized table widget based on the napari_skimage_regionprops TableWidget""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.ascending = False # for choosing whether to sort ascending or descending + self.ascending = ( + False # for choosing whether to sort ascending or descending + ) # Reconnect the clicked signal to your custom method. self._view.clicked.connect(self._clicked_table) @@ -27,20 +31,23 @@ def __init__(self, *args, **kwargs): # Connect to single click in the header to sort the table. self._view.horizontalHeader().sectionClicked.connect(self._sort_table) - def _set_label_colors_to_rows(self) -> None: """Apply the colors of the napari label image to the table""" for i in range(self._view.rowCount()): - label = self._table['label'][i] + label = self._table["label"][i] label_color = to_rgb(self._layer.get_color(label)) - scaled_color = (int(label_color[0] * 255), int(label_color[1] * 255), int(label_color[2] * 255)) + scaled_color = ( + int(label_color[0] * 255), + int(label_color[1] * 255), + int(label_color[2] * 255), + ) for j in range(self._view.columnCount()): self._view.item(i, j).setBackground(QColor(*scaled_color)) - + def _clicked_table(self): """Also set show_selected_label to True and jump to the corresponding stack position""" - + super()._clicked_table() self._layer.show_selected_label = True @@ -49,28 +56,34 @@ def _clicked_table(self): current_step = self._viewer.dims.current_step if len(current_step) == 4: new_step = (current_step[0], z, current_step[2], current_step[3]) - elif len(current_step) == 3: + elif len(current_step) == 3: new_step = (z, current_step[1], current_step[2]) - else: + else: new_step = current_step self._viewer.dims.current_step = new_step - + def _sort_table(self): """Sorts the table in ascending or descending order""" selected_column = list(self._table.keys())[self._view.currentColumn()] - df = pd.DataFrame(self._table).sort_values(by=selected_column, ascending=self.ascending) + df = pd.DataFrame(self._table).sort_values( + by=selected_column, ascending=self.ascending + ) self.ascending = not self.ascending - self.set_content(df.to_dict(orient='list')) + self.set_content(df.to_dict(orient="list")) self._set_label_colors_to_rows() + class TableWidget(QWidget): """ The table widget represents a table inside napari. Tables are just views on `properties` of `layers`. """ - def __init__(self, props=pd.DataFrame(), viewer: "napari.Viewer" = None ): + + def __init__( + self, props: pd.DataFrame | None, viewer: "napari.Viewer" = None + ): super().__init__() self._viewer = viewer @@ -82,7 +95,10 @@ def __init__(self, props=pd.DataFrame(), viewer: "napari.Viewer" = None ): self.ascending = False self._view.horizontalHeader().sectionClicked.connect(self._sort_table) - self.props = props.to_dict(orient='list') + if props is None: + self.props = pd.DataFrame().to_dict(orient="list") + else: + self.props = props.to_dict(orient="list") self.set_content(self.props) copy_button = QPushButton("Copy to clipboard") @@ -102,12 +118,16 @@ def __init__(self, props=pd.DataFrame(), viewer: "napari.Viewer" = None ): action_widget.layout().setContentsMargins(0, 0, 0, 0) def _save_clicked(self, event=None, filename=None): - if filename is None: filename, _ = QFileDialog.getSaveFileName(self, "Save as csv...", ".", "*.csv") + if filename is None: + filename, _ = QFileDialog.getSaveFileName( + self, "Save as csv...", ".", "*.csv" + ) DataFrame(self._table).to_csv(filename) - def _copy_clicked(self): DataFrame(self._table).to_clipboard() + def _copy_clicked(self): + DataFrame(self._table).to_clipboard() - def set_content(self, table : dict): + def set_content(self, table: dict): """ Overwrites the content of the table with the content of a given dictionary. """ @@ -134,35 +154,39 @@ def get_content(self) -> dict: Returns the current content of the table """ return self._table - + def _sort_table(self): """Sorts the table in ascending or descending order""" selected_column = list(self._table.keys())[self._view.currentColumn()] - df = pd.DataFrame(self._table).sort_values(by=selected_column, ascending=self.ascending) + df = pd.DataFrame(self._table).sort_values( + by=selected_column, ascending=self.ascending + ) self.ascending = not self.ascending - self.set_content(df.to_dict(orient='list')) + self.set_content(df.to_dict(orient="list")) if self.sort_by is not None: - self._recolor(self.sort_by, self.colormap) - - def _recolor(self, by:str, cmap:ListedColormap): - """Assign colors to the table based on given column and colormap """ + self._recolor(self.sort_by, self.colormap) + + def _recolor(self, by: str, cmap: ListedColormap): + """Assign colors to the table based on given column and colormap""" default_color = self.palette().color(self.backgroundRole()) - if by is None: - for i in range(self._view.rowCount()): + if by is None: + for i in range(self._view.rowCount()): for j in range(self._view.columnCount()): self._view.item(i, j).setBackground(default_color) - else: - for i in range(self._view.rowCount()): - id = self._table[by][i] - color = to_rgb(cmap.colors[id]) - scaled_color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255)) + else: + for i in range(self._view.rowCount()): + label = self._table[by][i] + color = to_rgb(cmap.colors[label]) + scaled_color = ( + int(color[0] * 255), + int(color[1] * 255), + int(color[2] * 255), + ) for j in range(self._view.columnCount()): self._view.item(i, j).setBackground(QColor(*scaled_color)) - + self.sort_by = by self.colormap = cmap - - \ No newline at end of file diff --git a/src/napari_lumen_segmentation/_distance_widget.py b/src/napari_lumen_segmentation/_distance_widget.py index faa8c49..3e57275 100644 --- a/src/napari_lumen_segmentation/_distance_widget.py +++ b/src/napari_lumen_segmentation/_distance_widget.py @@ -1,25 +1,32 @@ -import napari -import pandas as pd -import dask.array as da -import diplib as dip -import localthickness as lt -import numpy as np +from itertools import combinations +import dask.array as da +import diplib as dip +import localthickness as lt import matplotlib.pyplot as plt - -from napari.layers import Labels, Points -from qtpy.QtWidgets import QScrollArea, QGroupBox, QMessageBox, QLabel, QHBoxLayout, QVBoxLayout, QPushButton -from PyQt5.QtGui import QColor -from itertools import combinations - -from ._custom_table_widget import TableWidget -from ._histogram_widget import HistWidget -from ._layer_dropdown import LayerDropdown +import napari +import numpy as np +import pandas as pd +from napari.layers import Labels, Points +from PyQt5.QtGui import QColor +from qtpy.QtWidgets import ( + QGroupBox, + QHBoxLayout, + QLabel, + QMessageBox, + QPushButton, + QScrollArea, + QVBoxLayout, +) + +from ._custom_table_widget import TableWidget +from ._histogram_widget import HistWidget +from ._layer_dropdown import LayerDropdown class DistanceWidget(QScrollArea): - def __init__(self, viewer:napari.Viewer, labels: Labels): + def __init__(self, viewer: napari.Viewer, labels: Labels): super().__init__() self.viewer = viewer self.labels = labels @@ -27,8 +34,8 @@ def __init__(self, viewer:napari.Viewer, labels: Labels): distance_analysis_layout = QVBoxLayout() ### Add button to calculate local thickness map - local_thickness_box = QGroupBox("Compute local thickness map") - self.thickness_btn = QPushButton('Calculate local thickness') + local_thickness_box = QGroupBox("Compute local thickness map") + self.thickness_btn = QPushButton("Calculate local thickness") self.thickness_btn.clicked.connect(self._calculate_local_thickness) thickness_layout = QVBoxLayout() thickness_layout.addWidget(self.thickness_btn) @@ -36,31 +43,45 @@ def __init__(self, viewer:napari.Viewer, labels: Labels): distance_analysis_layout.addWidget(local_thickness_box) ### Add widget for histogram - self.hist_widget_box = HistWidget('Histogram', viewer) + self.hist_widget_box = HistWidget("Histogram", viewer) self.hist_widget_box.setMaximumHeight(400) distance_analysis_layout.addWidget(self.hist_widget_box) ### geodesic distance from points - geodesic_distmap_box = QGroupBox('Compute geodesic distance map') + geodesic_distmap_box = QGroupBox("Compute geodesic distance map") geodesic_distmap_box_layout = QVBoxLayout() geodesic_distmap_mask_layout = QHBoxLayout() - geodesic_distmap_mask_layout.addWidget(QLabel('Mask image')) - self.geodesic_distmap_mask_dropdown = LayerDropdown(self.viewer, (Labels)) - self.geodesic_distmap_mask_dropdown.layer_changed.connect(self._update_geodesic_distmap_mask) - geodesic_distmap_mask_layout.addWidget(self.geodesic_distmap_mask_dropdown) + geodesic_distmap_mask_layout.addWidget(QLabel("Mask image")) + self.geodesic_distmap_mask_dropdown = LayerDropdown( + self.viewer, (Labels) + ) + self.geodesic_distmap_mask_dropdown.layer_changed.connect( + self._update_geodesic_distmap_mask + ) + geodesic_distmap_mask_layout.addWidget( + self.geodesic_distmap_mask_dropdown + ) geodesic_distmap_marker_layer_layout = QHBoxLayout() - geodesic_distmap_marker_layer_layout.addWidget(QLabel('Marker points')) - self.geodesic_distmap_marker_layer_dropdown = LayerDropdown(self.viewer, (Labels, Points)) - self.geodesic_distmap_marker_layer_dropdown.layer_changed.connect(self._update_geodesic_distmap_marker_layer) - geodesic_distmap_marker_layer_layout.addWidget(self.geodesic_distmap_marker_layer_dropdown) - - geodesic_distmap_btn = QPushButton('Run') + geodesic_distmap_marker_layer_layout.addWidget(QLabel("Marker points")) + self.geodesic_distmap_marker_layer_dropdown = LayerDropdown( + self.viewer, (Labels, Points) + ) + self.geodesic_distmap_marker_layer_dropdown.layer_changed.connect( + self._update_geodesic_distmap_marker_layer + ) + geodesic_distmap_marker_layer_layout.addWidget( + self.geodesic_distmap_marker_layer_dropdown + ) + + geodesic_distmap_btn = QPushButton("Run") geodesic_distmap_btn.clicked.connect(self._calculate_geodesic_distance) - + geodesic_distmap_box_layout.addLayout(geodesic_distmap_mask_layout) - geodesic_distmap_box_layout.addLayout(geodesic_distmap_marker_layer_layout) + geodesic_distmap_box_layout.addLayout( + geodesic_distmap_marker_layer_layout + ) geodesic_distmap_box_layout.addWidget(geodesic_distmap_btn) geodesic_distmap_box.setLayout(geodesic_distmap_box_layout) @@ -68,67 +89,97 @@ def __init__(self, viewer:napari.Viewer, labels: Labels): ### euclidean and geodesic distance measurements distance_measurement_box = QGroupBox("Distance measurements") - self.table_widget = TableWidget(props = pd.DataFrame()) + self.table_widget = TableWidget(props=pd.DataFrame()) distance_measurement_layout = QHBoxLayout() distance_measurement_layout.addWidget(self.table_widget) distance_measurement_box.setLayout(distance_measurement_layout) distance_analysis_layout.addWidget(distance_measurement_box) # Set main layout - self.setLayout(distance_analysis_layout) + self.setLayout(distance_analysis_layout) self.setWidgetResizable(True) - def _update_geodesic_distmap_mask(self, selected_layer:str) -> None: + def _update_geodesic_distmap_mask(self, selected_layer: str) -> None: """Set the mask layer for geodesic distance map calculation""" - if selected_layer == '': + if selected_layer == "": self.geodesic_distmap_mask_layer = None else: - self.geodesic_distmap_mask_layer = self.viewer.layers[selected_layer] + self.geodesic_distmap_mask_layer = self.viewer.layers[ + selected_layer + ] self.geodesic_distmap_mask_dropdown.setCurrentText(selected_layer) - def _update_geodesic_distmap_marker_layer(self, selected_layer:str) -> None: + def _update_geodesic_distmap_marker_layer( + self, selected_layer: str + ) -> None: """Set the marker layer for geodesic distance map calculation""" - if selected_layer == '': + if selected_layer == "": self.geodesic_distmap_marker_layer = None else: - self.geodesic_distmap_marker_layer = self.viewer.layers[selected_layer] - self.geodesic_distmap_marker_layer_dropdown.setCurrentText(selected_layer) - - def _calculate_geodesic_distance(self): + self.geodesic_distmap_marker_layer = self.viewer.layers[ + selected_layer + ] + self.geodesic_distmap_marker_layer_dropdown.setCurrentText( + selected_layer + ) + + def _calculate_geodesic_distance(self): """Run geodesic distance map computation""" - if type(self.geodesic_distmap_mask_layer) == da.core.Array == da.core.Array: - msg = QMessageBox() - msg.setWindowTitle("Please convert to an in memory array") - msg.setText("Please convert to an in memory array") - msg.setIcon(QMessageBox.Information) - msg.setStandardButtons(QMessageBox.Ok) - msg.exec_() - return False - + if isinstance(self.geodesic_distmap_mask_layer, da.core.Array): + msg = QMessageBox() + msg.setWindowTitle("Please convert to an in memory array") + msg.setText("Please convert to an in memory array") + msg.setIcon(QMessageBox.Information) + msg.setStandardButtons(QMessageBox.Ok) + msg.exec_() + return False + if isinstance(self.geodesic_distmap_marker_layer, Labels): - marker = self.geodesic_distmap_marker_layer.data.copy() > 0 - self.viewer.add_image(np.array(dip.GeodesicDistanceTransform(~marker, self.geodesic_distmap_mask_layer.data > 0), dtype = np.float32), colormap = 'magma') + marker = self.geodesic_distmap_marker_layer.data.copy() > 0 + self.viewer.add_image( + np.array( + dip.GeodesicDistanceTransform( + ~marker, self.geodesic_distmap_mask_layer.data > 0 + ), + dtype=np.float32, + ), + colormap="magma", + ) elif isinstance(self.geodesic_distmap_marker_layer, Points): - if len(self.geodesic_distmap_marker_layer.data) == 1: + if len(self.geodesic_distmap_marker_layer.data) == 1: mask2 = self.geodesic_distmap_mask_layer.data.copy() > 0 - mask2[tuple(self.geodesic_distmap_marker_layer.data[0].astype(int))] = False - self.viewer.add_image(np.array(dip.GeodesicDistanceTransform(mask2, self.geodesic_distmap_mask_layer.data > 0), dtype = np.float32), colormap = 'magma') - - elif len(self.geodesic_distmap_marker_layer.data) > 1: + mask2[ + tuple( + self.geodesic_distmap_marker_layer.data[0].astype(int) + ) + ] = False + self.viewer.add_image( + np.array( + dip.GeodesicDistanceTransform( + mask2, self.geodesic_distmap_mask_layer.data > 0 + ), + dtype=np.float32, + ), + colormap="magma", + ) + + elif len(self.geodesic_distmap_marker_layer.data) > 1: measurements = pd.DataFrame() point_ids = {} unique_id_counter = -1 - colormap = plt.get_cmap('tab10') - - for point1, point2 in combinations(self.geodesic_distmap_marker_layer.data, 2): - + colormap = plt.get_cmap("tab10") + + for point1, point2 in combinations( + self.geodesic_distmap_marker_layer.data, 2 + ): + # Calculate unique IDs for point1 and point2 if tuple(point1) not in point_ids: unique_id_counter += 1 @@ -143,60 +194,91 @@ def _calculate_geodesic_distance(self): # calculate the geodesic distance mask2 = self.geodesic_distmap_mask_layer.data.copy() > 0 mask2[tuple(point2.astype(int))] = False - dist_map = np.array(dip.GeodesicDistanceTransform(mask2, self.geodesic_distmap_mask_layer.data > 0), dtype = np.float32) + dist_map = np.array( + dip.GeodesicDistanceTransform( + mask2, self.geodesic_distmap_mask_layer.data > 0 + ), + dtype=np.float32, + ) geodesic_dist = dist_map[tuple(point1.astype(int))] # Get unique colors for point1 and point2 from tab10 colormap point1_color = colormap(point_ids[tuple(point1)] % 10) point2_color = colormap(point_ids[tuple(point2)] % 10) - + # Create a dictionary to store the measurements for this pair of points measurement_dict = { - 'point1.ID': point_ids[tuple(point1)], - 'point2.ID': point_ids[tuple(point2)], - 'point1.x': point1[0], - 'point1.y': point1[1], - 'point1.z': point1[2], - 'point2.x': point2[0], - 'point2.y': point2[1], - 'point2.z': point2[2], - 'point1.color': point1_color[:3], # Use only RGB components - 'point2.color': point2_color[:3], # Use only RGB components - 'euclidean_dist': euclidean_dist, - 'geodesic_dist': geodesic_dist + "point1.ID": point_ids[tuple(point1)], + "point2.ID": point_ids[tuple(point2)], + "point1.x": point1[0], + "point1.y": point1[1], + "point1.z": point1[2], + "point2.x": point2[0], + "point2.y": point2[1], + "point2.z": point2[2], + "point1.color": point1_color[ + :3 + ], # Use only RGB components + "point2.color": point2_color[ + :3 + ], # Use only RGB components + "euclidean_dist": euclidean_dist, + "geodesic_dist": geodesic_dist, } - measurements = pd.concat([measurements, pd.DataFrame([measurement_dict])]) - - self.table_widget.set_content(measurements.to_dict(orient = 'list')) + measurements = pd.concat( + [measurements, pd.DataFrame([measurement_dict])] + ) + + self.table_widget.set_content( + measurements.to_dict(orient="list") + ) # Iterate over all rows in the QTableWidget i = 0 for index, row in measurements.iterrows(): - point1_color = row['point1.color'] - point2_color = row['point2.color'] + point1_color = row["point1.color"] + point2_color = row["point2.color"] point1_cols = [0, 2, 3, 4, 8] point2_cols = [1, 5, 6, 7, 9] # Set background color for point1 cells (columns 0, 1, and 2) for j in point1_cols: - self.table_widget._view.item(i, j).setBackground(QColor(int(point1_color[0] * 255), int(point1_color[1] * 255), int(point1_color[2] * 255))) + self.table_widget._view.item(i, j).setBackground( + QColor( + int(point1_color[0] * 255), + int(point1_color[1] * 255), + int(point1_color[2] * 255), + ) + ) # Set background color for point2 cells (columns 3, 4, and 5) for j in point2_cols: - self.table_widget._view.item(i, j).setBackground(QColor(int(point2_color[0] * 255), int(point2_color[1] * 255), int(point2_color[2] * 255))) + self.table_widget._view.item(i, j).setBackground( + QColor( + int(point2_color[0] * 255), + int(point2_color[1] * 255), + int(point2_color[2] * 255), + ) + ) i += 1 - # also set colormap to the points - colors = [colormap(i) for i in range(len(self.geodesic_distmap_marker_layer.data))] + # also set colormap to the points + colors = [ + colormap(i) + for i in range( + len(self.geodesic_distmap_marker_layer.data) + ) + ] self.geodesic_distmap_marker_layer.edge_color = colors self.geodesic_distmap_marker_layer.face_color = colors - def _calculate_local_thickness(self) -> None: """Calculates local thickness of label image and adds the image to the viewer""" - self.viewer.add_image(lt.local_thickness(self.labels.data), colormap = 'magma') \ No newline at end of file + self.viewer.add_image( + lt.local_thickness(self.labels.data), colormap="magma" + ) diff --git a/src/napari_lumen_segmentation/_histogram_widget.py b/src/napari_lumen_segmentation/_histogram_widget.py index dee2200..750c792 100644 --- a/src/napari_lumen_segmentation/_histogram_widget.py +++ b/src/napari_lumen_segmentation/_histogram_widget.py @@ -1,18 +1,28 @@ import os -from pathlib import Path -import napari.layers +from pathlib import Path -import matplotlib.pyplot as plt -import numpy as np +import matplotlib.pyplot as plt +import napari.layers +import numpy as np +from matplotlib.backends.backend_qt5agg import ( + FigureCanvas, + NavigationToolbar2QT, +) +from napari.layers import Image +from qtpy.QtGui import QIcon +from qtpy.QtWidgets import ( + QGroupBox, + QHBoxLayout, + QPushButton, + QVBoxLayout, + QWidget, +) -from matplotlib.backends.backend_qt5agg import FigureCanvas, NavigationToolbar2QT -from qtpy.QtWidgets import QHBoxLayout, QVBoxLayout, QWidget, QPushButton, QGroupBox -from qtpy.QtGui import QIcon -from napari.layers import Image -from ._layer_dropdown import LayerDropdown +from ._layer_dropdown import LayerDropdown ICON_ROOT = Path(__file__).parent / "icons" + class HistWidget(QGroupBox): """Customized plotting widget class. @@ -32,16 +42,16 @@ def __init__(self, title, viewer: napari.Viewer): # Specify plot customizations. self.fig.patch.set_facecolor("#262930") - self.ax.tick_params(colors='white') + self.ax.tick_params(colors="white") self.ax.set_facecolor("#262930") - self.ax.xaxis.label.set_color('white') - self.ax.yaxis.label.set_color('white') + self.ax.xaxis.label.set_color("white") + self.ax.yaxis.label.set_color("white") self.ax.spines["bottom"].set_color("white") self.ax.spines["top"].set_color("white") self.ax.spines["right"].set_color("white") self.ax.spines["left"].set_color("white") for action_name in self.toolbar._actions: - action=self.toolbar._actions[action_name] + action = self.toolbar._actions[action_name] icon_path = os.path.join(ICON_ROOT, action_name + ".png") action.setIcon(QIcon(icon_path)) @@ -49,9 +59,9 @@ def __init__(self, title, viewer: napari.Viewer): x_axis_layout = QHBoxLayout() self.image_dropdown = LayerDropdown(self.viewer, (Image)) x_axis_layout.addWidget(self.image_dropdown) - apply_btn = QPushButton('Show Histogram') + apply_btn = QPushButton("Show Histogram") x_axis_layout.addWidget(apply_btn) - apply_btn.clicked.connect(self._update_plot) + apply_btn.clicked.connect(self._update_plot) dropdown_layout = QVBoxLayout() dropdown_layout.addLayout(x_axis_layout) @@ -63,27 +73,28 @@ def __init__(self, title, viewer: napari.Viewer): plotting_layout.addWidget(dropdown_widget) plotting_layout.addWidget(self.toolbar) plotting_layout.addWidget(self.plot_canvas) - self.setLayout(plotting_layout) - + self.setLayout(plotting_layout) + def _update_plot(self) -> None: - """Update the histogram - """ + """Update the histogram""" image_layer = self.viewer.layers[self.image_dropdown.currentText()] intensity_values = image_layer.data.flatten() - intensity_values = intensity_values[np.isfinite(intensity_values) & (intensity_values != 0)] + intensity_values = intensity_values[ + np.isfinite(intensity_values) & (intensity_values != 0) + ] # Clear data points, and reset the axis scaling for artist in self.ax.lines + self.ax.collections: artist.remove() self.ax.clear() - self.ax.set_xlabel('Intensity') - self.ax.set_ylabel('Count') + self.ax.set_xlabel("Intensity") + self.ax.set_ylabel("Count") self.ax.relim() # Recalculate limits for the current data self.ax.autoscale_view() # Update the view to include the new limits - self.ax.xaxis.label.set_color('white') - self.ax.yaxis.label.set_color('white') - - self.ax.hist(intensity_values, bins = 255, color='turquoise', alpha=0.7) - - self.plot_canvas.draw() \ No newline at end of file + self.ax.xaxis.label.set_color("white") + self.ax.yaxis.label.set_color("white") + + self.ax.hist(intensity_values, bins=255, color="turquoise", alpha=0.7) + + self.plot_canvas.draw() diff --git a/src/napari_lumen_segmentation/_layer_dropdown.py b/src/napari_lumen_segmentation/_layer_dropdown.py index 9a7cf3c..5dd0f02 100644 --- a/src/napari_lumen_segmentation/_layer_dropdown.py +++ b/src/napari_lumen_segmentation/_layer_dropdown.py @@ -1,23 +1,27 @@ +from typing import Tuple + import napari -from qtpy.QtWidgets import QComboBox -from PyQt5.QtCore import pyqtSignal -from typing import Tuple +from PyQt5.QtCore import pyqtSignal +from qtpy.QtWidgets import QComboBox + class LayerDropdown(QComboBox): - """QComboBox widget with functions for updating the selected layer and to update the list of options when the list of layers is modified. - - """ - - layer_changed = pyqtSignal(str) # Define a signal to emit the selected layer name + """QComboBox widget with functions for updating the selected layer and to update the list of options when the list of layers is modified.""" - def __init__(self, viewer:napari.Viewer, layer_type:Tuple): + layer_changed = pyqtSignal( + str + ) # Define a signal to emit the selected layer name + + def __init__(self, viewer: napari.Viewer, layer_type: Tuple): super().__init__() self.viewer = viewer self.layer_type = layer_type self.viewer.layers.events.inserted.connect(self._on_insert) self.viewer.layers.events.changed.connect(self._update_dropdown) self.viewer.layers.events.removed.connect(self._update_dropdown) - self.viewer.layers.selection.events.changed.connect(self._on_selection_changed) + self.viewer.layers.selection.events.changed.connect( + self._on_selection_changed + ) self.currentIndexChanged.connect(self._emit_layer_changed) self._update_dropdown() @@ -25,32 +29,41 @@ def _on_insert(self, event) -> None: """Update dropdown and make new layer responsive to name changes""" layer = event.value + @layer.events.name.connect def _on_rename(name_event): self._update_dropdown() + self._update_dropdown() - + def _on_selection_changed(self) -> None: - """Request signal emission if the user changes the layer selection.""" + """Request signal emission if the user changes the layer selection.""" - if len(self.viewer.layers.selection) == 1: # Only consider single layer selection + if ( + len(self.viewer.layers.selection) == 1 + ): # Only consider single layer selection selected_layer = self.viewer.layers.selection.active if isinstance(selected_layer, self.layer_type): self.setCurrentText(selected_layer.name) - self._emit_layer_changed() + self._emit_layer_changed() def _update_dropdown(self) -> None: """Update the list of options in the dropdown menu whenever the list of layers is changed""" selected_layer = self.currentText() self.clear() - layers = [layer for layer in self.viewer.layers if isinstance(layer, self.layer_type) and not layer.name == "label options"] + layers = [ + layer + for layer in self.viewer.layers + if isinstance(layer, self.layer_type) + and layer.name != "label options" + ] items = [] for layer in layers: self.addItem(layer.name) items.append(layer.name) - - # In case the currently selected layer is one of the available items, set it again to the current value of the dropdown. + + # In case the currently selected layer is one of the available items, set it again to the current value of the dropdown. if selected_layer in items: self.setCurrentText(selected_layer) @@ -58,4 +71,4 @@ def _emit_layer_changed(self) -> None: """Emit a signal holding the currently selected layer""" selected_layer = self.currentText() - self.layer_changed.emit(selected_layer) \ No newline at end of file + self.layer_changed.emit(selected_layer) diff --git a/src/napari_lumen_segmentation/_plot_widget.py b/src/napari_lumen_segmentation/_plot_widget.py index 1ffa21d..aaab3ce 100644 --- a/src/napari_lumen_segmentation/_plot_widget.py +++ b/src/napari_lumen_segmentation/_plot_widget.py @@ -1,28 +1,29 @@ import os -from pathlib import Path -import napari.layers - -import matplotlib.pyplot as plt -import pandas as pd -import numpy as np - -from matplotlib.backends.backend_qt5agg import FigureCanvas, NavigationToolbar2QT -from qtpy.QtWidgets import QHBoxLayout, QVBoxLayout, QWidget, QComboBox, QLabel -from qtpy.QtGui import QIcon +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib.backends.backend_qt5agg import ( + FigureCanvas, + NavigationToolbar2QT, +) +from qtpy.QtGui import QIcon +from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QVBoxLayout, QWidget ICON_ROOT = Path(__file__).parent / "icons" + class PlotWidget(QWidget): - """Customized plotting widget class. - """ + """Customized plotting widget class.""" def __init__(self, props: pd.DataFrame): super().__init__() self.props = props self.label_colormap = None - self.categorical_cmap = 'tab10' # may be overwritten by parent - self.continuous_cmap = 'summer' # may be overwritten by parent + self.categorical_cmap = "tab10" # may be overwritten by parent + self.continuous_cmap = "summer" # may be overwritten by parent # Main plot. self.fig = plt.figure() @@ -32,30 +33,34 @@ def __init__(self, props: pd.DataFrame): # Specify plot customizations. self.fig.patch.set_facecolor("#262930") - self.ax.tick_params(colors='white') + self.ax.tick_params(colors="white") self.ax.set_facecolor("#262930") - self.ax.xaxis.label.set_color('white') - self.ax.yaxis.label.set_color('white') + self.ax.xaxis.label.set_color("white") + self.ax.yaxis.label.set_color("white") self.ax.spines["bottom"].set_color("white") self.ax.spines["top"].set_color("white") self.ax.spines["right"].set_color("white") self.ax.spines["left"].set_color("white") for action_name in self.toolbar._actions: - action=self.toolbar._actions[action_name] + action = self.toolbar._actions[action_name] icon_path = os.path.join(ICON_ROOT, action_name + ".png") action.setIcon(QIcon(icon_path)) # Create a dropdown window for selecting what to plot on the axes. x_axis_layout = QHBoxLayout() self.x_combo = QComboBox() - self.x_combo.addItems([item for item in self.props.columns if item != 'index']) - x_axis_layout.addWidget(QLabel('x-axis')) + self.x_combo.addItems( + [item for item in self.props.columns if item != "index"] + ) + x_axis_layout.addWidget(QLabel("x-axis")) x_axis_layout.addWidget(self.x_combo) y_axis_layout = QHBoxLayout() self.y_combo = QComboBox() - self.y_combo.addItems([item for item in self.props.columns if item != 'index']) - y_axis_layout.addWidget(QLabel('y-axis')) + self.y_combo.addItems( + [item for item in self.props.columns if item != "index"] + ) + y_axis_layout.addWidget(QLabel("y-axis")) y_axis_layout.addWidget(self.y_combo) self.x_combo.currentIndexChanged.connect(self._update_plot) @@ -63,9 +68,11 @@ def __init__(self, props: pd.DataFrame): color_group_layout = QHBoxLayout() self.group_combo = QComboBox() - self.group_combo.addItems([item for item in self.props.columns if item != 'index']) + self.group_combo.addItems( + [item for item in self.props.columns if item != "index"] + ) self.group_combo.currentIndexChanged.connect(self._update_plot) - color_group_layout.addWidget(QLabel('Group color')) + color_group_layout.addWidget(QLabel("Group color")) color_group_layout.addWidget(self.group_combo) dropdown_layout = QVBoxLayout() @@ -80,28 +87,34 @@ def __init__(self, props: pd.DataFrame): plotting_layout.addWidget(dropdown_widget) plotting_layout.addWidget(self.toolbar) plotting_layout.addWidget(self.plot_canvas) - self.setLayout(plotting_layout) - - def _update_properties(self, filtered_measurements): + self.setLayout(plotting_layout) + + def _update_properties(self, filtered_measurements): """Update the properties and regenerate the plot""" - self.props = filtered_measurements[filtered_measurements['selected']] + self.props = filtered_measurements[filtered_measurements["selected"]] self._update_plot() - def _update_dropdowns(self) -> None: + def _update_dropdowns(self) -> None: """Update the options in the dropdown menus""" self.x_combo.blockSignals(True) self.y_combo.blockSignals(True) self.group_combo.blockSignals(True) - + self.x_combo.clear() self.y_combo.clear() self.group_combo.clear() - self.x_combo.addItems([item for item in self.props.columns if item != 'index']) + self.x_combo.addItems( + [item for item in self.props.columns if item != "index"] + ) self.x_combo.setCurrentIndex(0) - self.y_combo.addItems([item for item in self.props.columns if item != 'index']) + self.y_combo.addItems( + [item for item in self.props.columns if item != "index"] + ) self.y_combo.setCurrentIndex(1) - self.group_combo.addItems([item for item in self.props.columns if item != 'index']) + self.group_combo.addItems( + [item for item in self.props.columns if item != "index"] + ) self.group_combo.setCurrentIndex(0) self.x_combo.blockSignals(False) @@ -111,15 +124,13 @@ def _update_dropdowns(self) -> None: self._update_plot() def _update_plot(self) -> None: - """Update the plot by plotting the features selected by the user. - - """ + """Update the plot by plotting the features selected by the user.""" if not self.props.empty: x_axis_property = self.x_combo.currentText() y_axis_property = self.y_combo.currentText() group = self.group_combo.currentText() - + # Clear data points, and reset the axis scaling for artist in self.ax.lines + self.ax.collections: artist.remove() @@ -127,26 +138,64 @@ def _update_plot(self) -> None: self.ax.set_ylabel(y_axis_property) self.ax.relim() # Recalculate limits for the current data self.ax.autoscale_view() # Update the view to include the new limits - + if group == "label": # plot using label colors if self.label_colormap is not None: - self.ax.scatter(self.props[x_axis_property], self.props[y_axis_property], c=self.props[group], cmap=self.label_colormap, s = 10) - else: - self.ax.scatter(self.props[x_axis_property], self.props[y_axis_property], c=self.props[group], cmap=self.categorical_cmap, s = 10) + self.ax.scatter( + self.props[x_axis_property], + self.props[y_axis_property], + c=self.props[group], + cmap=self.label_colormap, + s=10, + ) + else: + self.ax.scatter( + self.props[x_axis_property], + self.props[y_axis_property], + c=self.props[group], + cmap=self.categorical_cmap, + s=10, + ) else: # Plot data points on a custom categorical or continuous colormap. - if self.props[group].dtype == 'object' or np.issubdtype(self.props[group].dtype, np.integer): + if self.props[group].dtype == "object" or np.issubdtype( + self.props[group].dtype, np.integer + ): unique_categories = np.unique(self.props[group]) cmap_colors = self.categorical_cmap.colors if len(unique_categories) <= len(cmap_colors): - category_to_color = {category: cmap_colors[i] for i, category in enumerate(unique_categories)} - colors = [category_to_color[category] for category in self.props[group]] - self.ax.scatter(self.props[x_axis_property], self.props[y_axis_property], c=colors, cmap=self.categorical_cmap, s = 10) - else: - self.ax.scatter(self.props[x_axis_property], self.props[y_axis_property], c=self.props[group], cmap=self.categorical_cmap, s = 10) + category_to_color = { + category: cmap_colors[i] + for i, category in enumerate(unique_categories) + } + colors = [ + category_to_color[category] + for category in self.props[group] + ] + self.ax.scatter( + self.props[x_axis_property], + self.props[y_axis_property], + c=colors, + cmap=self.categorical_cmap, + s=10, + ) + else: + self.ax.scatter( + self.props[x_axis_property], + self.props[y_axis_property], + c=self.props[group], + cmap=self.categorical_cmap, + s=10, + ) else: - self.ax.scatter(self.props[x_axis_property], self.props[y_axis_property], c=self.props[group], cmap=self.continuous_cmap, s = 10) - - self.plot_canvas.draw() \ No newline at end of file + self.ax.scatter( + self.props[x_axis_property], + self.props[y_axis_property], + c=self.props[group], + cmap=self.continuous_cmap, + s=10, + ) + + self.plot_canvas.draw() diff --git a/src/napari_lumen_segmentation/_skeleton_widget.py b/src/napari_lumen_segmentation/_skeleton_widget.py index c4ac37d..0e146c4 100644 --- a/src/napari_lumen_segmentation/_skeleton_widget.py +++ b/src/napari_lumen_segmentation/_skeleton_widget.py @@ -1,20 +1,27 @@ -import skan -import napari -import pandas as pd -import numpy as np -import matplotlib.pyplot as plt import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import napari +import numpy as np +import pandas as pd +import skan +from napari.layers import Labels +from qtpy.QtWidgets import ( + QComboBox, + QGroupBox, + QPushButton, + QScrollArea, + QVBoxLayout, + QWidget, +) +from skimage import morphology + +from ._custom_table_widget import TableWidget +from ._plot_widget import PlotWidget -from qtpy.QtWidgets import QVBoxLayout, QWidget, QComboBox, QGroupBox, QPushButton, QScrollArea -from skimage import morphology -from napari.layers import Labels - -from ._custom_table_widget import TableWidget -from ._plot_widget import PlotWidget class SkeletonWidget(QScrollArea): - def __init__(self, viewer:napari.Viewer, labels: Labels): + def __init__(self, viewer: napari.Viewer, labels: Labels): super().__init__() self.viewer = viewer self.labels = labels @@ -23,110 +30,117 @@ def __init__(self, viewer:napari.Viewer, labels: Labels): self.analysis_layout = QVBoxLayout() ## Add box for skeleton analysis - skeleton_box = QGroupBox('Skeleton Analysis') + skeleton_box = QGroupBox("Skeleton Analysis") self.skeleton_box_layout = QVBoxLayout() - self.skeleton_btn = QPushButton('Create skeleton') + self.skeleton_btn = QPushButton("Create skeleton") self.skeleton_btn.clicked.connect(self._skeletonize) self.skeleton_box_layout.addWidget(self.skeleton_btn) self.skeleton_visualization_dropdown = QComboBox() - self.skeleton_visualization_dropdown.addItem('Skeleton') - self.skeleton_visualization_dropdown.addItem('Path') - self.skeleton_visualization_dropdown.addItem('Branch Length') - self.skeleton_visualization_dropdown.currentIndexChanged.connect(self._update_skeleton_visualization) - - self.skeleton_box_layout.addWidget(self.skeleton_visualization_dropdown) + self.skeleton_visualization_dropdown.addItem("Skeleton") + self.skeleton_visualization_dropdown.addItem("Path") + self.skeleton_visualization_dropdown.addItem("Branch Length") + self.skeleton_visualization_dropdown.currentIndexChanged.connect( + self._update_skeleton_visualization + ) + + self.skeleton_box_layout.addWidget( + self.skeleton_visualization_dropdown + ) skeleton_box.setLayout(self.skeleton_box_layout) - + self.analysis_layout.addWidget(skeleton_box) - self.table_widget = TableWidget(props = pd.DataFrame()) - self.plot_widget = PlotWidget(props = pd.DataFrame()) - + self.table_widget = TableWidget(props=pd.DataFrame()) + self.plot_widget = PlotWidget(props=pd.DataFrame()) + self.analysis_layout.addWidget(self.table_widget) self.analysis_layout.addWidget(self.plot_widget) self.analysis_widgets = QWidget() self.analysis_widgets.setLayout(self.analysis_layout) - + self.setWidget(self.analysis_widgets) self.setWidgetResizable(True) - def _skeletonize(self) -> None: + def _skeletonize(self) -> None: """Create skeleton from label image""" skel = morphology.skeletonize(self.labels.data) degree_image = skan.csr.make_degree_image(skel) - - self.viewer.add_labels(degree_image, name = 'Connectivity') + self.viewer.add_labels(degree_image, name="Connectivity") skeleton = skan.Skeleton(skel) all_paths = [ - skeleton.path_coordinates(i) - for i in range(skeleton.n_paths) + skeleton.path_coordinates(i) for i in range(skeleton.n_paths) ] self.paths_table = skan.summarize(skeleton) - self.paths_table['path-id'] = np.arange(skeleton.n_paths) - + self.paths_table["path-id"] = np.arange(skeleton.n_paths) + # Create a randomized colormap - colormaps = ['tab20', 'tab20b', 'tab20c'] + colormaps = ["tab20", "tab20b", "tab20c"] color_cycle = [] for cmap_name in colormaps: cmap = plt.get_cmap(cmap_name) colors = [mcolors.to_hex(cmap(i)) for i in range(20)] color_cycle.extend(colors) np.random.shuffle(color_cycle) - n_colors = len(np.unique(self.paths_table['path-id'])) + n_colors = len(np.unique(self.paths_table["path-id"])) repetitions = (n_colors + len(color_cycle) - 1) // len(color_cycle) extended_color_cycle = color_cycle * repetitions - self.random_cmap = mcolors.ListedColormap(extended_color_cycle[:n_colors]) - - # define shapes layer - self.skeleton = self.viewer.add_shapes(all_paths, - shape_type='path', + self.random_cmap = mcolors.ListedColormap( + extended_color_cycle[:n_colors] + ) + + # define shapes layer + self.skeleton = self.viewer.add_shapes( + all_paths, + shape_type="path", properties=self.paths_table, edge_width=0.5, - edge_color='skeleton-id', - face_color='skeleton-id', - edge_colormap='viridis', - face_colormap='viridis', - edge_color_cycle = self.random_cmap.colors, - face_color_cycle = self.random_cmap.colors - ) + edge_color="skeleton-id", + face_color="skeleton-id", + edge_colormap="viridis", + face_colormap="viridis", + edge_color_cycle=self.random_cmap.colors, + face_color_cycle=self.random_cmap.colors, + ) # update table widget - self.table_widget.set_content(self.paths_table.to_dict(orient = 'list')) - self.table_widget._recolor(by="skeleton-id", cmap = self.random_cmap) + self.table_widget.set_content(self.paths_table.to_dict(orient="list")) + self.table_widget._recolor(by="skeleton-id", cmap=self.random_cmap) # update plot widget self.plot_widget.props = self.paths_table self.plot_widget.categorical_cmap = self.random_cmap - self.plot_widget.continuous_cmap = 'viridis' - self.plot_widget._update_dropdowns() - + self.plot_widget.continuous_cmap = "viridis" + self.plot_widget._update_dropdowns() + def _update_skeleton_visualization(self) -> None: """Update the coloring of the skeleton layer""" if self.skeleton_visualization_dropdown.currentText() == "Path": - ids = self.skeleton.properties['path-id'] + ids = self.skeleton.properties["path-id"] colors = [self.random_cmap.colors[c] for c in ids] self.skeleton.edge_color = colors - self.skeleton.face_color = colors - self.table_widget._recolor(by='path-id', cmap = self.random_cmap) + self.skeleton.face_color = colors + self.table_widget._recolor(by="path-id", cmap=self.random_cmap) if self.skeleton_visualization_dropdown.currentText() == "Skeleton": - ids = self.skeleton.properties['skeleton-id'] + ids = self.skeleton.properties["skeleton-id"] colors = [self.random_cmap.colors[c] for c in ids] self.skeleton.edge_color = colors - self.skeleton.face_color = colors - self.table_widget._recolor(by="skeleton-id", cmap = self.random_cmap) - if self.skeleton_visualization_dropdown.currentText() == "Branch Length": - self.skeleton.edge_color = 'branch-distance' - self.skeleton.face_color = 'branch-distance' - self.table_widget._recolor(by=None, cmap = self.random_cmap) - + self.skeleton.face_color = colors + self.table_widget._recolor(by="skeleton-id", cmap=self.random_cmap) + if ( + self.skeleton_visualization_dropdown.currentText() + == "Branch Length" + ): + self.skeleton.edge_color = "branch-distance" + self.skeleton.face_color = "branch-distance" + self.table_widget._recolor(by=None, cmap=self.random_cmap) + self.skeleton.refresh() - diff --git a/src/napari_lumen_segmentation/_widget.py b/src/napari_lumen_segmentation/_widget.py index 8d5ed0e..91ce255 100644 --- a/src/napari_lumen_segmentation/_widget.py +++ b/src/napari_lumen_segmentation/_widget.py @@ -2,50 +2,64 @@ Napari plugin widget for editing N-dimensional label data """ +import functools import os import shutil -import functools + +import dask.array as da import napari +import numpy as np +import pandas as pd import tifffile +from matplotlib.colors import ListedColormap, to_rgb +from napari.layers import Image, Labels +from napari_plane_sliders._plane_slider_widget import PlaneSliderWidget +from qtpy.QtWidgets import ( + QComboBox, + QDoubleSpinBox, + QFileDialog, + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QMessageBox, + QPushButton, + QScrollArea, + QSpinBox, + QTabWidget, + QVBoxLayout, + QWidget, +) +from scipy import ndimage +from scipy.ndimage import binary_erosion +from skimage import measure +from skimage.io import imread +from skimage.segmentation import ( + expand_labels, + inverse_gaussian_gradient, + morphological_geodesic_active_contour, +) + +from ._custom_table_widget import ColoredTableWidget, TableWidget +from ._distance_widget import DistanceWidget +from ._layer_dropdown import LayerDropdown +from ._plot_widget import PlotWidget +from ._skeleton_widget import SkeletonWidget + -import numpy as np -import dask.array as da -import pandas as pd - -from scipy.ndimage import binary_erosion -from scipy import ndimage -from dask_image.imread import imread -from napari.layers import Image, Labels -from skimage import measure -from skimage.io import imread -from skimage.segmentation import expand_labels, morphological_geodesic_active_contour, inverse_gaussian_gradient -from qtpy.QtWidgets import QScrollArea, QDoubleSpinBox, QGroupBox, QMessageBox, QLabel, QHBoxLayout, QVBoxLayout, QPushButton, QWidget, QFileDialog, QLineEdit, QSpinBox, QComboBox, QTabWidget, QTableWidget -from qtpy.QtCore import * -from matplotlib.colors import to_rgb, ListedColormap - -from ._custom_table_widget import ColoredTableWidget, TableWidget -from .napari_multiple_view_widget import CrossWidget, MultipleViewerWidget -from ._plot_widget import PlotWidget -from ._skeleton_widget import SkeletonWidget -from ._distance_widget import DistanceWidget -from ._layer_dropdown import LayerDropdown - class AnnotateLabelsND(QWidget): - """Widget for manual correction of label data, for example to prepare ground truth data for training a segmentation model - - """ - - def __init__(self, viewer: 'napari.viewer.Viewer') -> None: + """Widget for manual correction of label data, for example to prepare ground truth data for training a segmentation model""" + + def __init__(self, viewer: "napari.viewer.Viewer") -> None: super().__init__() self.viewer = viewer - self.viewer.layers.clear() # ensure viewer is clean self.labels = None self.skeleton = None self.label_table = None self.skeleton_table = None - self.distance_table = None - self.csv_table = None + self.distance_table = None + self.csv_table = None self.points = None self.outputdir = None @@ -54,14 +68,17 @@ def __init__(self, viewer: 'napari.viewer.Viewer') -> None: self.option_labels = None self.csv_path = None self.label_properties = None - self.plot_widget = PlotWidget(props = pd.DataFrame()) - self.table_widget = TableWidget(props = pd.DataFrame()) - self.label_table_widget = ColoredTableWidget(napari.layers.Labels(np.zeros((10, 10), dtype = np.uint8)), self.viewer) - self.label_plot_widget = PlotWidget(props = pd.DataFrame()) - + self.plot_widget = PlotWidget(props=pd.DataFrame()) + self.table_widget = TableWidget(props=pd.DataFrame()) + self.label_table_widget = ColoredTableWidget( + napari.layers.Labels(np.zeros((10, 10), dtype=np.uint8)), + self.viewer, + ) + self.label_plot_widget = PlotWidget(props=pd.DataFrame()) + ### specify output directory outputbox_layout = QHBoxLayout() - self.outputdirbtn = QPushButton('Select output directory') + self.outputdirbtn = QPushButton("Select output directory") self.output_path = QLineEdit() outputbox_layout.addWidget(self.outputdirbtn) outputbox_layout.addWidget(self.output_path) @@ -74,43 +91,50 @@ def __init__(self, viewer: 'napari.viewer.Viewer') -> None: self.segmentation_layout.addWidget(self.label_dropdown) ### Add option to convert dask array to in-memory array - self.convert_to_array_btn = QPushButton('Convert to in-memory array') - self.convert_to_array_btn.setEnabled(self.labels != None and type(self.labels.data) == da.core.Array) + self.convert_to_array_btn = QPushButton("Convert to in-memory array") + self.convert_to_array_btn.setEnabled( + self.labels is not None + and isinstance(self.labels.data, da.core.Array) + ) self.convert_to_array_btn.clicked.connect(self._convert_to_array) self.segmentation_layout.addWidget(self.convert_to_array_btn) ### Add widget for adding overview table - self.table_btn = QPushButton('Show table') + self.table_btn = QPushButton("Show table") self.table_btn.clicked.connect(self._create_summary_table) - self.table_btn.clicked.connect(lambda: self.tab_widget.setCurrentIndex(2)) + self.table_btn.clicked.connect( + lambda: self.tab_widget.setCurrentIndex(2) + ) if self.labels is not None: self.table_btn.setEnabled(True) self.segmentation_layout.addWidget(self.table_btn) ## Add save labels widget - self.save_btn = QPushButton('Save labels') + self.save_btn = QPushButton("Save labels") self.save_btn.clicked.connect(self._save_labels) self.segmentation_layout.addWidget(self.save_btn) ## Add button to clear all layers - self.clear_btn = QPushButton('Clear all layers') + self.clear_btn = QPushButton("Clear all layers") self.clear_btn.clicked.connect(self._clear_layers) self.segmentation_layout.addWidget(self.clear_btn) ## Add button to run connected component analysis - self.convert_to_labels_btn = QPushButton('Run connected components labeling') + self.convert_to_labels_btn = QPushButton( + "Run connected components labeling" + ) self.convert_to_labels_btn.clicked.connect(self._convert_to_labels) self.segmentation_layout.addWidget(self.convert_to_labels_btn) - + ### Add widget for size filtering - filterbox = QGroupBox('Filter objects by size') + filterbox = QGroupBox("Filter objects by size") filter_layout = QVBoxLayout() label_size = QLabel("Min size threshold (voxels)") threshold_size_layout = QHBoxLayout() self.min_size_field = QSpinBox() self.min_size_field.setMaximum(1000000) - self.delete_btn = QPushButton('Delete') + self.delete_btn = QPushButton("Delete") threshold_size_layout.addWidget(self.min_size_field) threshold_size_layout.addWidget(self.delete_btn) @@ -125,7 +149,7 @@ def __init__(self, viewer: 'napari.viewer.Viewer') -> None: self.setLayout(self.segmentation_layout) ### Add widget for eroding/dilating labels - dil_erode_box = QGroupBox('Erode/dilate labels') + dil_erode_box = QGroupBox("Erode/dilate labels") dil_erode_box_layout = QVBoxLayout() radius_layout = QHBoxLayout() @@ -147,8 +171,8 @@ def __init__(self, viewer: 'napari.viewer.Viewer') -> None: iterations_layout.addWidget(self.iterations) shrink_dilate_buttons_layout = QHBoxLayout() - self.erode_btn = QPushButton('Erode') - self.dilate_btn = QPushButton('Dilate') + self.erode_btn = QPushButton("Erode") + self.dilate_btn = QPushButton("Dilate") self.erode_btn.clicked.connect(self._erode_labels) self.dilate_btn.clicked.connect(self._dilate_labels) shrink_dilate_buttons_layout.addWidget(self.erode_btn) @@ -166,21 +190,25 @@ def __init__(self, viewer: 'napari.viewer.Viewer') -> None: self.segmentation_layout.addWidget(dil_erode_box) ### Threshold image - threshold_box = QGroupBox('Threshold') + threshold_box = QGroupBox("Threshold") threshold_box_layout = QVBoxLayout() - self.threshold_layer_dropdown = LayerDropdown(self.viewer, (Image, Labels)) - self.threshold_layer_dropdown.layer_changed.connect(self._update_threshold_layer) + self.threshold_layer_dropdown = LayerDropdown( + self.viewer, (Image, Labels) + ) + self.threshold_layer_dropdown.layer_changed.connect( + self._update_threshold_layer + ) threshold_box_layout.addWidget(self.threshold_layer_dropdown) min_threshold_layout = QHBoxLayout() - min_threshold_layout.addWidget(QLabel('Min value')) + min_threshold_layout.addWidget(QLabel("Min value")) self.min_threshold = QSpinBox() self.min_threshold.setMaximum(65535) min_threshold_layout.addWidget(self.min_threshold) max_threshold_layout = QHBoxLayout() - max_threshold_layout.addWidget(QLabel('Max value')) + max_threshold_layout.addWidget(QLabel("Max value")) self.max_threshold = QSpinBox() self.max_threshold.setMaximum(65535) self.max_threshold.setValue(65535) @@ -188,25 +216,25 @@ def __init__(self, viewer: 'napari.viewer.Viewer') -> None: threshold_box_layout.addLayout(min_threshold_layout) threshold_box_layout.addLayout(max_threshold_layout) - threshold_btn = QPushButton('Run') + threshold_btn = QPushButton("Run") threshold_btn.clicked.connect(self._threshold) threshold_box_layout.addWidget(threshold_btn) - + threshold_box.setLayout(threshold_box_layout) self.segmentation_layout.addWidget(threshold_box) ### Add one image to another - image_calc_box = QGroupBox('Image calculator') + image_calc_box = QGroupBox("Image calculator") image_calc_box_layout = QVBoxLayout() image1_layout = QHBoxLayout() - image1_layout.addWidget(QLabel('Label image 1')) + image1_layout.addWidget(QLabel("Label image 1")) self.image1_dropdown = LayerDropdown(self.viewer, (Image, Labels)) self.image1_dropdown.layer_changed.connect(self._update_image1) image1_layout.addWidget(self.image1_dropdown) image2_layout = QHBoxLayout() - image2_layout.addWidget(QLabel('Label image 2')) + image2_layout.addWidget(QLabel("Label image 2")) self.image2_dropdown = LayerDropdown(self.viewer, (Image, Labels)) self.image2_dropdown.layer_changed.connect(self._update_image2) image2_layout.addWidget(self.image2_dropdown) @@ -222,11 +250,11 @@ def __init__(self, viewer: 'napari.viewer.Viewer') -> None: self.operation.addItem("Divide") self.operation.addItem("AND") self.operation.addItem("OR") - operation_layout.addWidget(QLabel('Operation')) + operation_layout.addWidget(QLabel("Operation")) operation_layout.addWidget(self.operation) image_calc_box_layout.addLayout(operation_layout) - add_images_btn = QPushButton('Run') + add_images_btn = QPushButton("Run") add_images_btn.clicked.connect(self._calculate_images) image_calc_box_layout.addWidget(add_images_btn) @@ -234,64 +262,68 @@ def __init__(self, viewer: 'napari.viewer.Viewer') -> None: self.segmentation_layout.addWidget(image_calc_box) ### Compute inverse gaussian gradient - inv_gauss_box = QGroupBox('Inverse Gaussian Gradient') + inv_gauss_box = QGroupBox("Inverse Gaussian Gradient") inv_gauss_box_layout = QVBoxLayout() inv_gauss_input_layout = QHBoxLayout() - inv_gauss_input_layout.addWidget(QLabel('Input image')) + inv_gauss_input_layout.addWidget(QLabel("Input image")) self.inv_gauss_input_dropdown = LayerDropdown(self.viewer, (Image)) - self.inv_gauss_input_dropdown.layer_changed.connect(self._update_inv_gauss_input) + self.inv_gauss_input_dropdown.layer_changed.connect( + self._update_inv_gauss_input + ) inv_gauss_input_layout.addWidget(self.inv_gauss_input_dropdown) inv_gauss_sigma_layout = QHBoxLayout() - inv_gauss_sigma_layout.addWidget(QLabel('Sigma')) + inv_gauss_sigma_layout.addWidget(QLabel("Sigma")) self.inv_gauss_sigma_spin = QSpinBox() self.inv_gauss_sigma_spin.setMinimum(1) self.inv_gauss_sigma_spin.setMaximum(50) inv_gauss_sigma_layout.addWidget(self.inv_gauss_sigma_spin) - inv_gauss_btn = QPushButton('Run') + inv_gauss_btn = QPushButton("Run") inv_gauss_btn.clicked.connect(self._calculate_inv_gauss) - + inv_gauss_box_layout.addLayout(inv_gauss_input_layout) inv_gauss_box_layout.addLayout(inv_gauss_sigma_layout) inv_gauss_box_layout.addWidget(inv_gauss_btn) inv_gauss_box.setLayout(inv_gauss_box_layout) self.segmentation_layout.addWidget(inv_gauss_box) - + ### Morphological geodesic active contour - active_contour_box = QGroupBox('Morphological Geodesic Active Contour') + active_contour_box = QGroupBox("Morphological Geodesic Active Contour") active_contour_box_layout = QVBoxLayout() inv_gauss_layout = QHBoxLayout() - inv_gauss_layout.addWidget(QLabel('Edges map')) + inv_gauss_layout.addWidget(QLabel("Edges map")) self.inv_gauss_dropdown = LayerDropdown(self.viewer, (Image)) self.inv_gauss_dropdown.layer_changed.connect(self._update_inv_gauss) inv_gauss_layout.addWidget(self.inv_gauss_dropdown) seeds_layout = QHBoxLayout() - seeds_layout.addWidget(QLabel('Label seeds')) + seeds_layout.addWidget(QLabel("Label seeds")) self.seeds_dropdown = LayerDropdown(self.viewer, (Labels)) self.seeds_dropdown.layer_changed.connect(self._update_seeds) seeds_layout.addWidget(self.seeds_dropdown) num_iter_layout = QHBoxLayout() - num_iter_layout.addWidget(QLabel('Number of iterations')) + num_iter_layout.addWidget(QLabel("Number of iterations")) self.num_iter_spin = QSpinBox() self.num_iter_spin.setMinimum(1) self.num_iter_spin.setMaximum(5000) num_iter_layout.addWidget(self.num_iter_spin) balloon_layout = QHBoxLayout() - balloon_layout.addWidget(QLabel('Balloon')) + balloon_layout.addWidget(QLabel("Balloon")) self.balloon = QDoubleSpinBox() self.balloon.setMinimum(-10) self.balloon.setMaximum(10) balloon_layout.addWidget(self.balloon) - calc_active_contour_btn = QPushButton('Run') - calc_active_contour_btn.clicked.connect(self._morphological_active_contour) + calc_active_contour_btn = QPushButton("Run") + calc_active_contour_btn.clicked.connect( + self._morphological_active_contour + ) active_contour_box_layout.addLayout(inv_gauss_layout) active_contour_box_layout.addLayout(seeds_layout) @@ -302,25 +334,11 @@ def __init__(self, viewer: 'napari.viewer.Viewer') -> None: active_contour_box.setLayout(active_contour_box_layout) self.segmentation_layout.addWidget(active_contour_box) - ### add the button to show the cross in multiview - cross_box = QGroupBox('Add cross to multiview') - cross_box_layout = QHBoxLayout() - self.cross = CrossWidget(self.viewer) - self.cross.setChecked(False) - self.cross.layer = None - cross_box_layout.addWidget(self.cross) - cross_box.setLayout(cross_box_layout) - self.segmentation_layout.addWidget(cross_box) - ### combine into tab widget - - ## Add multiview widget - self.multi_view_table_widget = QWidget() - self.multi_view_table_layout = QHBoxLayout() - self.multiview_widget = MultipleViewerWidget(self.viewer) - self.multi_view_table_layout.addWidget(self.multiview_widget) - self.multi_view_table_widget.setLayout(self.multi_view_table_layout) - self.tab_widget.addTab(self.multi_view_table_widget, "Orthogonal Views") + + ## add plane viewing widget + plane_widget = PlaneSliderWidget(self.viewer) + self.tab_widget.addTab(plane_widget, "Plane Viewing") ## add combined segmentation widgets self.segmentation_widgets = QWidget() @@ -332,178 +350,209 @@ def __init__(self, viewer: 'napari.viewer.Viewer') -> None: ## add tab with label plots self.label_plotting_widgets = QWidget() - self.label_plotting_widgets_layout = QVBoxLayout() + self.label_plotting_widgets_layout = QVBoxLayout() self.label_plotting_widgets_layout.addWidget(self.label_table_widget) self.label_plotting_widgets_layout.addWidget(self.label_plot_widget) - self.label_plotting_widgets.setLayout(self.label_plotting_widgets_layout) + self.label_plotting_widgets.setLayout( + self.label_plotting_widgets_layout + ) self.tab_widget.addTab(self.label_plotting_widgets, "Label Plots") ## add skeleton analysis widgets - self.skeleton_widget = SkeletonWidget(viewer = self.viewer, labels = self.labels) + self.skeleton_widget = SkeletonWidget( + viewer=self.viewer, labels=self.labels + ) self.tab_widget.addTab(self.skeleton_widget, "Skeleton Analysis") - self.distance_widget = DistanceWidget(viewer = self.viewer, labels = self.labels) + self.distance_widget = DistanceWidget( + viewer=self.viewer, labels=self.labels + ) self.tab_widget.addTab(self.distance_widget, "Distance Analysis") - + # Add the tab widget to the main layout self.main_layout = QVBoxLayout() self.main_layout.addWidget(self.tab_widget) self.setLayout(self.main_layout) - def _switch_table_content(self) -> None: + def _switch_table_content(self) -> None: """Set the content of the table widget depending on the choice in the table dropdown""" if self.table_dropdown.currentText() == "CSV": # switch to skeleton table - self.table_widget.set_content(self.csv_table.to_dict(orient = 'list')) - self.plot_widget.props = self.csv_table + self.table_widget.set_content( + self.csv_table.to_dict(orient="list") + ) + self.plot_widget.props = self.csv_table self.plot_widget._update_dropdowns() if self.table_dropdown.currentText() == "Skeleton": # switch to skeleton table - self.table_widget.set_content(self.skeleton_table.to_dict(orient = 'list')) + self.table_widget.set_content( + self.skeleton_table.to_dict(orient="list") + ) self.plot_widget.props = self.skeleton_table self.plot_widget._update_dropdowns() if self.table_dropdown.currentText() == "Distances": # switch to distance measurements table - self.table_widget.set_content(self.distance_table.to_dict(orient = 'list')) + self.table_widget.set_content( + self.distance_table.to_dict(orient="list") + ) self.plot_widget.props = self.distance_table self.plot_widget._update_dropdowns() - def _update_table_dropdown(self) -> None: + def _update_table_dropdown(self) -> None: """Update options in the table dropdown for plotting""" - for label, table_option in zip(["CSV", "Skeleton", "Distances"], [self.csv_table, self.skeleton_table, self.distance_table]): - if table_option is not None: + for label, table_option in zip( + ["CSV", "Skeleton", "Distances"], + [self.csv_table, self.skeleton_table, self.distance_table], + ): + if table_option is not None: label_exists = False for index in range(self.table_dropdown.count()): if self.table_dropdown.itemText(index) == label: label_exists = True break - if not label_exists: + if not label_exists: self.table_dropdown.addItem(label) - + def _choose_csv_path(self) -> None: options = QFileDialog.Options() - path, _ = QFileDialog.getOpenFileName(self, "Open .csv file", "", "CSV Files (*.csv);;All Files (*)", options=options) + path, _ = QFileDialog.getOpenFileName( + self, + "Open .csv file", + "", + "CSV Files (*.csv);;All Files (*)", + options=options, + ) if path: self.csv_path_edit.setText(path) self.csv_path = str(path) - - def _update_label_props_path(self) -> None: + + def _update_label_props_path(self) -> None: self.csv_path = str(self.csv_path_edit.text()) - def _set_csv_table(self) -> None: + def _set_csv_table(self) -> None: if self.csv_path is not None and os.path.exists(self.csv_path): self.csv_table = pd.read_csv(self.csv_path) - self.table_widget.set_content(self.csv_table.to_dict(orient='list')) - self.plot_widget.props = self.csv_table + self.table_widget.set_content( + self.csv_table.to_dict(orient="list") + ) + self.plot_widget.props = self.csv_table self.plot_widget.label_colormap = None - self.plot_widget._update_dropdowns() - self._update_table_dropdown() - self.table_dropdown.setCurrentText("CSV") + self.plot_widget._update_dropdowns() + self._update_table_dropdown() + self.table_dropdown.setCurrentText("CSV") else: - print('no csv file selected') + print("no csv file selected") def _on_get_output_dir(self) -> None: """Show a dialog window to let the user pick the output directory.""" - - path = QFileDialog.getExistingDirectory(self, 'Select Output Folder') + + path = QFileDialog.getExistingDirectory(self, "Select Output Folder") if path: self.output_path.setText(path) self.outputdir = str(self.output_path.text()) - def _convert_to_labels(self) -> None: + def _convert_to_labels(self) -> None: """Convert to labels image""" - - if self.labels is not None: - self.labels = self.viewer.add_labels(measure.label(self.labels.data)) + + if self.labels is not None: + self.labels = self.viewer.add_labels( + measure.label(self.labels.data) + ) self._update_labels(self.labels.name) - def _update_labels(self, selected_layer:str) -> None: + def _update_labels(self, selected_layer: str) -> None: """Update the layer that is set to be the 'labels' layer that is being edited.""" - if selected_layer == '': + if selected_layer == "": self.labels = None else: self.labels = self.viewer.layers[selected_layer] self.label_dropdown.setCurrentText(selected_layer) - self.convert_to_array_btn.setEnabled(type(self.labels.data) == da.core.Array) + self.convert_to_array_btn.setEnabled( + isinstance(self.labels.data, da.core.Array) + ) self.skeleton_widget.labels = self.labels self.distance_widget.labels = self.labels - - def _update_image1(self, selected_layer:str) -> None: + + def _update_image1(self, selected_layer: str) -> None: """Update the layer that is set to be the 'source labels' layer for copying labels from.""" - if selected_layer == '': + if selected_layer == "": self.image1_layer = None else: self.image1_layer = self.viewer.layers[selected_layer] self.image1_dropdown.setCurrentText(selected_layer) - def _update_image2(self, selected_layer:str) -> None: + def _update_image2(self, selected_layer: str) -> None: """Update the layer that is set to be the 'source labels' layer for copying labels from.""" - if selected_layer == '': + if selected_layer == "": self.image2_layer = None else: self.image2_layer = self.viewer.layers[selected_layer] self.image2_dropdown.setCurrentText(selected_layer) - - def _update_threshold_layer(self, selected_layer:str) -> None: + + def _update_threshold_layer(self, selected_layer: str) -> None: """Update the layer that is set to be the 'source labels' layer for copying labels from.""" - if selected_layer == '': + if selected_layer == "": self.threshold_layer = None else: self.threshold_layer = self.viewer.layers[selected_layer] self.threshold_layer_dropdown.setCurrentText(selected_layer) - - def _update_inv_gauss_input(self, selected_layer:str) -> None: + + def _update_inv_gauss_input(self, selected_layer: str) -> None: """Update the layer that is set to be the 'source labels' layer for copying labels from.""" - if selected_layer == '': + if selected_layer == "": self.inv_gauss_input_layer = None else: self.inv_gauss_input_layer = self.viewer.layers[selected_layer] self.inv_gauss_input_dropdown.setCurrentText(selected_layer) - def _update_inv_gauss(self, selected_layer:str) -> None: + def _update_inv_gauss(self, selected_layer: str) -> None: """Update the layer that is set to be the 'source labels' layer for copying labels from.""" - if selected_layer == '': + if selected_layer == "": self.inv_gauss_layer = None else: self.inv_gauss_layer = self.viewer.layers[selected_layer] self.inv_gauss_dropdown.setCurrentText(selected_layer) - - def _update_seeds(self, selected_layer:str) -> None: + + def _update_seeds(self, selected_layer: str) -> None: """Update the layer that is set to be the 'source labels' layer for copying labels from.""" - if selected_layer == '': + if selected_layer == "": self.seeds_layer = None else: self.seeds_layer = self.viewer.layers[selected_layer] self.seeds_dropdown.setCurrentText(selected_layer) - - def _convert_to_array(self) -> None: + + def _convert_to_array(self) -> None: """Convert from dask array to in-memory array. This is necessary for manual editing using the label tools (brush, eraser, fill bucket).""" - - if type(self.labels.data) == da.core.Array: + + if isinstance(self.labels.data, da.core.Array): stack = [] for i in range(self.labels.data.shape[0]): current_stack = self.labels.data[i].compute() stack.append(current_stack) - self.labels.data = np.stack(stack, axis = 0) - + self.labels.data = np.stack(stack, axis=0) + def _create_summary_table(self) -> None: """Create table displaying the sizes of the different labels in the current stack""" - if type(self.labels.data) == da.core.Array: + if isinstance(self.labels.data, da.core.Array): tp = self.viewer.dims.current_step[0] - current_stack = self.labels.data[tp].compute() # Compute the current stack - self.label_table = measure.regionprops_table(current_stack, properties = ['label', 'area', 'centroid']) + current_stack = self.labels.data[ + tp + ].compute() # Compute the current stack + self.label_table = measure.regionprops_table( + current_stack, properties=["label", "area", "centroid"] + ) if hasattr(self.labels, "properties"): self.labels.properties = self.label_table if hasattr(self.labels, "features"): @@ -512,42 +561,57 @@ def _create_summary_table(self) -> None: else: if len(self.labels.data.shape) == 4: tp = self.viewer.dims.current_step[0] - self.label_table = measure.regionprops_table(self.labels.data[tp], properties = ['label', 'area', 'centroid']) + self.label_table = measure.regionprops_table( + self.labels.data[tp], + properties=["label", "area", "centroid"], + ) if hasattr(self.labels, "properties"): self.labels.properties = self.label_table if hasattr(self.labels, "features"): self.labels.features = self.label_table - - elif len(self.labels.data.shape) == 3: - self.label_table = measure.regionprops_table(self.labels.data, properties = ['label', 'area', 'centroid']) + + elif len(self.labels.data.shape) == 3: + self.label_table = measure.regionprops_table( + self.labels.data, properties=["label", "area", "centroid"] + ) if hasattr(self.labels, "properties"): self.labels.properties = self.label_table if hasattr(self.labels, "features"): self.labels.features = self.label_table - else: - print('input should be a 3D or 4D array') + else: + print("input should be a 3D or 4D array") self.label_table = None if self.label_table_widget is not None: self.label_table_widget.hide() if self.viewer is not None: - self.label_table_widget = ColoredTableWidget(self.labels, self.viewer) + self.label_table_widget = ColoredTableWidget( + self.labels, self.viewer + ) self.label_table_widget._set_label_colors_to_rows() self.label_table_widget.setMinimumWidth(500) - self.label_plotting_widgets_layout.addWidget(self.label_table_widget) - + self.label_plotting_widgets_layout.addWidget( + self.label_table_widget + ) + # update the plot widget and set label colors - self.label_plot_widget.props = pd.DataFrame.from_dict(self.label_table) - unique_labels = self.label_plot_widget.props['label'].unique() - label_colors = [to_rgb(self.labels.get_color(label)) for label in unique_labels] - self.label_plot_widget.label_colormap = ListedColormap(label_colors) + self.label_plot_widget.props = pd.DataFrame.from_dict( + self.label_table + ) + unique_labels = self.label_plot_widget.props["label"].unique() + label_colors = [ + to_rgb(self.labels.get_color(label)) for label in unique_labels + ] + self.label_plot_widget.label_colormap = ListedColormap( + label_colors + ) self.label_plot_widget._update_dropdowns() - + def _save_labels(self) -> None: """Save the currently active labels layer. If it consists of multiple timepoints, they are written to multiple 3D stacks.""" - if type(self.labels.data) == da.core.Array: + if isinstance(self.labels.data, da.core.Array): if self.outputdir is None: msg = QMessageBox() @@ -557,39 +621,66 @@ def _save_labels(self) -> None: msg.setStandardButtons(QMessageBox.Ok) msg.exec_() return False - + else: - outputdir = os.path.join(self.outputdir, (self.labels.name + "_finalresult")) + outputdir = os.path.join( + self.outputdir, (self.labels.name + "_finalresult") + ) if os.path.exists(outputdir): shutil.rmtree(outputdir) os.mkdir(outputdir) - for i in range(self.labels.data.shape[0]): # Loop over the first dimension - current_stack = self.labels.data[i].compute() # Compute the current stack - tifffile.imwrite(os.path.join(outputdir, (self.labels.name + '_TP' + str(i).zfill(4) + '.tif')), np.array(current_stack, dtype = 'uint16')) + for i in range( + self.labels.data.shape[0] + ): # Loop over the first dimension + current_stack = self.labels.data[ + i + ].compute() # Compute the current stack + tifffile.imwrite( + os.path.join( + outputdir, + ( + self.labels.name + + "_TP" + + str(i).zfill(4) + + ".tif" + ), + ), + np.array(current_stack, dtype="uint16"), + ) return True elif len(self.labels.data.shape) == 4: filename, _ = QFileDialog.getSaveFileName( - caption='Save Labels', - directory='', - filter='TIFF files (*.tif *.tiff)') + caption="Save Labels", + directory="", + filter="TIFF files (*.tif *.tiff)", + ) for i in range(self.labels.data.shape[0]): labels_data = self.labels.data[i].astype(np.uint16) - tifffile.imwrite((filename.split('.tif')[0] + '_TP' + str(i).zfill(4) + '.tif'), labels_data) - - elif len(self.labels.data.shape) == 3: + tifffile.imwrite( + ( + filename.split(".tif")[0] + + "_TP" + + str(i).zfill(4) + + ".tif" + ), + labels_data, + ) + + elif len(self.labels.data.shape) == 3: filename, _ = QFileDialog.getSaveFileName( - caption='Save Labels', - directory='', - filter='TIFF files (*.tif *.tiff)') + caption="Save Labels", + directory="", + filter="TIFF files (*.tif *.tiff)", + ) if filename: labels_data = self.labels.data.astype(np.uint16) tifffile.imwrite(filename, labels_data) - - else: - print('labels should be a 3D or 4D array') + + else: + print("labels should be a 3D or 4D array") def _clear_layers(self) -> None: """Clear all the layers in the viewer""" @@ -601,124 +692,192 @@ def _clear_layers(self) -> None: def _keep_objects(self) -> None: """Keep only the labels that are selected by the points layer.""" - if type(self.labels.data) == da.core.Array: + if isinstance(self.labels.data, da.core.Array): tps = np.unique([int(p[0]) for p in self.points.data]) for tp in tps: - labels_to_keep = [] + labels_to_keep = [] points = [p for p in self.points.data if p[0] == tp] - current_stack = self.labels.data[tp].compute() # Compute the current stack + current_stack = self.labels.data[ + tp + ].compute() # Compute the current stack for p in points: - labels_to_keep.append(current_stack[int(p[1]), int(p[2]), int(p[3])]) - mask = functools.reduce(np.logical_or, (current_stack==val for val in labels_to_keep)) + labels_to_keep.append( + current_stack[int(p[1]), int(p[2]), int(p[3])] + ) + mask = functools.reduce( + np.logical_or, + (current_stack == val for val in labels_to_keep), + ) filtered = np.where(mask, current_stack, 0) self.labels.data[tp] = filtered - self.labels.data = self.labels.data # to trigger viewer update + self.labels.data = self.labels.data # to trigger viewer update - else: + else: if len(self.points.data[0]) == 4: tps = np.unique([int(p[0]) for p in self.points.data]) for tp in tps: labels_to_keep = [] points = [p for p in self.points.data if p[0] == tp] for p in points: - labels_to_keep.append(self.labels.data[tp, int(p[1]), int(p[2]), int(p[3])]) - mask = functools.reduce(np.logical_or, (self.labels.data[tp]==val for val in labels_to_keep)) + labels_to_keep.append( + self.labels.data[ + tp, int(p[1]), int(p[2]), int(p[3]) + ] + ) + mask = functools.reduce( + np.logical_or, + ( + self.labels.data[tp] == val + for val in labels_to_keep + ), + ) filtered = np.where(mask, self.labels.data[tp], 0) self.labels.data[tp] = filtered - self.labels.data = self.labels.data # to trigger viewer update + self.labels.data = self.labels.data # to trigger viewer update - else: + else: labels_to_keep = [] for p in self.points.data: if len(p) == 2: - labels_to_keep.append(self.labels.data[int(p[0]), int(p[1])]) + labels_to_keep.append( + self.labels.data[int(p[0]), int(p[1])] + ) elif len(p) == 3: - labels_to_keep.append(self.labels.data[int(p[0]), int(p[1]), int(p[2])]) - - mask = functools.reduce(np.logical_or, (self.labels.data==val for val in labels_to_keep)) + labels_to_keep.append( + self.labels.data[int(p[0]), int(p[1]), int(p[2])] + ) + + mask = functools.reduce( + np.logical_or, + (self.labels.data == val for val in labels_to_keep), + ) filtered = np.where(mask, self.labels.data, 0) - - self.labels = self.viewer.add_labels(filtered, name = self.labels.name + '_points_kept') - self._update_labels(self.labels.name) - + self.labels = self.viewer.add_labels( + filtered, name=self.labels.name + "_points_kept" + ) + self._update_labels(self.labels.name) def _add_option_layer(self): """Add a new labels layer that contains different alternative segmentations as channels, and add a function to select and copy these cells through shift-clicking""" - path = QFileDialog.getExistingDirectory(self, 'Select Label Image Parent Folder') + path = QFileDialog.getExistingDirectory( + self, "Select Label Image Parent Folder" + ) if path: - label_dirs = sorted([d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]) + label_dirs = sorted( + [ + d + for d in os.listdir(path) + if os.path.isdir(os.path.join(path, d)) + ] + ) label_stacks = [] for d in label_dirs: # n dirs indicates number of channels - label_files = sorted([f for f in os.listdir(os.path.join(path, d)) if '.tif' in f]) + label_files = sorted( + [ + f + for f in os.listdir(os.path.join(path, d)) + if ".tif" in f + ] + ) label_imgs = [] for f in label_files: # n label_files indicates n time points img = imread(os.path.join(path, d, f)) label_imgs.append(img) - + if len(label_imgs) > 1: - label_stack = np.stack(label_imgs, axis = 0) - label_stacks.append(label_stack) + label_stack = np.stack(label_imgs, axis=0) + label_stacks.append(label_stack) else: label_stacks.append(img) - + if len(label_stacks) > 1: - self.option_labels = np.stack(label_stacks, axis = 0) + self.option_labels = np.stack(label_stacks, axis=0) elif len(label_stacks) == 1: self.option_labels = label_stacks[0] - + n_channels = len(label_dirs) n_timepoints = len(label_files) - if len(img.shape) == 3: + if len(img.shape) == 3: n_slices = img.shape[0] elif len(img.shape) == 2: n_slices = 1 - - self.option_labels = self.option_labels.reshape(n_channels, n_timepoints, n_slices, img.shape[-2], img.shape[-1]) - self.option_labels = self.viewer.add_labels(self.option_labels, name = 'label options') + + self.option_labels = self.option_labels.reshape( + n_channels, + n_timepoints, + n_slices, + img.shape[-2], + img.shape[-1], + ) + self.option_labels = self.viewer.add_labels( + self.option_labels, name="label options" + ) viewer = self.viewer + @viewer.mouse_drag_callbacks.append def cell_copied(viewer, event): - if event.type == "mouse_press" and 'Shift' in event.modifiers and viewer.layers.selection.active == self.option_labels: + if ( + event.type == "mouse_press" + and "Shift" in event.modifiers + and viewer.layers.selection.active == self.option_labels + ): coords = self.option_labels.world_to_data(event.position) coords = [int(c) for c in coords] selected_label = self.option_labels.get_value(coords) - mask = self.option_labels.data[coords[0], coords[1], :, :, :] == selected_label + mask = ( + self.option_labels.data[coords[0], coords[1], :, :, :] + == selected_label + ) - if type(self.labels.data) == da.core.Array: + if isinstance(self.labels.data, da.core.Array): target_stack = self.labels.data[coords[-4]].compute() - orig_label = target_stack[coords[-3], coords[-2], coords[-1]] - if orig_label != 0: - target_stack[target_stack == orig_label] = 0 + orig_label = target_stack[ + coords[-3], coords[-2], coords[-1] + ] + if orig_label != 0: + target_stack[target_stack == orig_label] = 0 target_stack[mask] = np.max(target_stack) + 1 self.labels.data[coords[-4]] = target_stack self.labels.data = self.labels.data - - else: + + else: if len(self.labels.data.shape) == 3: - orig_label = self.labels.data[coords[-3], coords[-2], coords[-1]] + orig_label = self.labels.data[ + coords[-3], coords[-2], coords[-1] + ] if orig_label != 0: - self.labels.data[self.labels.data == orig_label] = 0 # set the original label to zero + self.labels.data[ + self.labels.data == orig_label + ] = 0 # set the original label to zero self.labels.data[mask] = np.max(self.labels.data) + 1 self.labels.data = self.labels.data elif len(self.labels.data.shape) == 4: - orig_label = self.labels.data[coords[-4], coords[-3], coords[-2], coords[-1]] + orig_label = self.labels.data[ + coords[-4], coords[-3], coords[-2], coords[-1] + ] - if orig_label != 0: - self.labels.data[coords[-4]][self.labels.data[coords[-4]] == orig_label] = 0 # set the original label to zero - self.labels.data[coords[-4]][mask] = np.max(self.labels.data) + 1 + if orig_label != 0: + self.labels.data[coords[-4]][ + self.labels.data[coords[-4]] == orig_label + ] = 0 # set the original label to zero + self.labels.data[coords[-4]][mask] = ( + np.max(self.labels.data) + 1 + ) self.labels.data = self.labels.data - + elif len(self.labels.data.shape) == 5: msg_box = QMessageBox() msg_box.setIcon(QMessageBox.Question) - msg_box.setText("Copy-pasting in 5 dimensions is not implemented, do you want to convert the labels layer to 5 dimensions (tzyx)?") + msg_box.setText( + "Copy-pasting in 5 dimensions is not implemented, do you want to convert the labels layer to 5 dimensions (tzyx)?" + ) msg_box.setWindowTitle("Convert to 4 dimensions?") yes_button = msg_box.addButton(QMessageBox.Yes) @@ -729,14 +888,16 @@ def cell_copied(viewer, event): if msg_box.clickedButton() == yes_button: self.labels.data = self.labels.data[0] elif msg_box.clickedButton() == no_button: - return False + return False else: - print('copy-pasting in more than 5 dimensions is not supported') - + print( + "copy-pasting in more than 5 dimensions is not supported" + ) + def _delete_small_objects(self) -> None: """Delete small objects in the selected layer""" - if type(self.labels.data) == da.core.Array: + if isinstance(self.labels.data, da.core.Array): if self.outputdir is None: msg = QMessageBox() msg.setWindowTitle("No output directory selected") @@ -747,56 +908,112 @@ def _delete_small_objects(self) -> None: return False else: - outputdir = os.path.join(self.outputdir, (self.labels.name + "_sizefiltered")) + outputdir = os.path.join( + self.outputdir, (self.labels.name + "_sizefiltered") + ) if os.path.exists(outputdir): shutil.rmtree(outputdir) os.mkdir(outputdir) - for i in range(self.labels.data.shape[0]): # Loop over the first dimension - current_stack = self.labels.data[i].compute() # Compute the current stack + for i in range( + self.labels.data.shape[0] + ): # Loop over the first dimension + current_stack = self.labels.data[ + i + ].compute() # Compute the current stack # measure the sizes in pixels of the labels in slice using skimage.regionprops props = measure.regionprops(current_stack) - filtered_labels = [p.label for p in props if p.area > self.min_size_field.value()] - mask = functools.reduce(np.logical_or, (current_stack==val for val in filtered_labels)) + filtered_labels = [ + p.label + for p in props + if p.area > self.min_size_field.value() + ] + mask = functools.reduce( + np.logical_or, + (current_stack == val for val in filtered_labels), + ) filtered = np.where(mask, current_stack, 0) - tifffile.imwrite(os.path.join(outputdir, (self.labels.name + '_sizefiltered_TP' + str(i).zfill(4) + '.tif')), np.array(filtered, dtype = 'uint16')) - - file_list = [os.path.join(outputdir, fname) for fname in os.listdir(outputdir) if fname.endswith('.tif')] - self.labels = self.viewer.add_labels(da.stack([imread(fname) for fname in sorted(file_list)]), name = self.labels.name + '_sizefiltered') + tifffile.imwrite( + os.path.join( + outputdir, + ( + self.labels.name + + "_sizefiltered_TP" + + str(i).zfill(4) + + ".tif" + ), + ), + np.array(filtered, dtype="uint16"), + ) + + file_list = [ + os.path.join(outputdir, fname) + for fname in os.listdir(outputdir) + if fname.endswith(".tif") + ] + self.labels = self.viewer.add_labels( + da.stack([imread(fname) for fname in sorted(file_list)]), + name=self.labels.name + "_sizefiltered", + ) self._update_labels(self.labels.name) - + else: - # Image data is a normal array and can be directly edited. - if len(self.labels.data.shape) == 4: + # Image data is a normal array and can be directly edited. + if len(self.labels.data.shape) == 4: stack = [] for i in range(self.labels.data.shape[0]): props = measure.regionprops(self.labels.data[i]) - filtered_labels = [p.label for p in props if p.area > self.min_size_field.value()] - mask = functools.reduce(np.logical_or, (self.labels.data[i]==val for val in filtered_labels)) + filtered_labels = [ + p.label + for p in props + if p.area > self.min_size_field.value() + ] + mask = functools.reduce( + np.logical_or, + ( + self.labels.data[i] == val + for val in filtered_labels + ), + ) filtered = np.where(mask, self.labels.data[i], 0) stack.append(filtered) - self.labels = self.viewer.add_labels(np.stack(stack, axis = 0), name = self.labels.name + '_sizefiltered') + self.labels = self.viewer.add_labels( + np.stack(stack, axis=0), + name=self.labels.name + "_sizefiltered", + ) self._update_labels(self.labels.name) - + elif len(self.labels.data.shape) == 3: props = measure.regionprops(self.labels.data) - filtered_labels = [p.label for p in props if p.area > self.min_size_field.value()] - mask = functools.reduce(np.logical_or, (self.labels.data==val for val in filtered_labels)) - self.labels = self.viewer.add_labels(np.where(mask, self.labels.data, 0), name = self.labels.name + '_sizefiltered') + filtered_labels = [ + p.label + for p in props + if p.area > self.min_size_field.value() + ] + mask = functools.reduce( + np.logical_or, + (self.labels.data == val for val in filtered_labels), + ) + self.labels = self.viewer.add_labels( + np.where(mask, self.labels.data, 0), + name=self.labels.name + "_sizefiltered", + ) self._update_labels(self.labels.name) else: - print('input should be 3D or 4D array') + print("input should be 3D or 4D array") def _erode_labels(self): """Shrink oversized labels through erosion""" diam = self.structuring_element_diameter.value() iterations = self.iterations.value() - structuring_element = np.ones((diam, diam, diam), dtype=bool) # Define a 3x3x3 structuring element for 3D erosion + structuring_element = np.ones( + (diam, diam, diam), dtype=bool + ) # Define a 3x3x3 structuring element for 3D erosion - if type(self.labels.data) == da.core.Array: + if isinstance(self.labels.data, da.core.Array): if self.outputdir is None: msg = QMessageBox() msg.setWindowTitle("No output directory selected") @@ -805,44 +1022,85 @@ def _erode_labels(self): msg.setStandardButtons(QMessageBox.Ok) msg.exec_() return False - + else: - outputdir = os.path.join(self.outputdir, (self.labels.name + "_eroded")) + outputdir = os.path.join( + self.outputdir, (self.labels.name + "_eroded") + ) if os.path.exists(outputdir): shutil.rmtree(outputdir) os.mkdir(outputdir) - for i in range(self.labels.data.shape[0]): # Loop over the first dimension - current_stack = self.labels.data[i].compute() # Compute the current stack + for i in range( + self.labels.data.shape[0] + ): # Loop over the first dimension + current_stack = self.labels.data[ + i + ].compute() # Compute the current stack mask = current_stack > 0 filled_mask = ndimage.binary_fill_holes(mask) - eroded_mask = binary_erosion(filled_mask, structure=structuring_element, iterations=iterations) + eroded_mask = binary_erosion( + filled_mask, + structure=structuring_element, + iterations=iterations, + ) eroded = np.where(eroded_mask, current_stack, 0) - tifffile.imwrite(os.path.join(outputdir, (self.labels.name + '_eroded_TP' + str(i).zfill(4) + '.tif')), np.array(eroded, dtype = 'uint16')) - - file_list = [os.path.join(outputdir, fname) for fname in os.listdir(outputdir) if fname.endswith('.tif')] - self.labels = self.viewer.add_labels(da.stack([imread(fname) for fname in sorted(file_list)]), name = self.labels.name + '_eroded') + tifffile.imwrite( + os.path.join( + outputdir, + ( + self.labels.name + + "_eroded_TP" + + str(i).zfill(4) + + ".tif" + ), + ), + np.array(eroded, dtype="uint16"), + ) + + file_list = [ + os.path.join(outputdir, fname) + for fname in os.listdir(outputdir) + if fname.endswith(".tif") + ] + self.labels = self.viewer.add_labels( + da.stack([imread(fname) for fname in sorted(file_list)]), + name=self.labels.name + "_eroded", + ) self._update_labels(self.labels.name) return True else: - if len(self.labels.data.shape) == 4: + if len(self.labels.data.shape) == 4: stack = [] for i in range(self.labels.data.shape[0]): mask = self.labels.data[i] > 0 filled_mask = ndimage.binary_fill_holes(mask) - eroded_mask = binary_erosion(filled_mask, structure=structuring_element, iterations=iterations) + eroded_mask = binary_erosion( + filled_mask, + structure=structuring_element, + iterations=iterations, + ) stack.append(np.where(eroded_mask, self.labels.data[i], 0)) - self.labels = self.viewer.add_labels(np.stack(stack, axis = 0), name = self.labels.name + '_eroded') - self._update_labels(self.labels.name) - elif len(self.labels.data.shape) == 3: + self.labels = self.viewer.add_labels( + np.stack(stack, axis=0), name=self.labels.name + "_eroded" + ) + self._update_labels(self.labels.name) + elif len(self.labels.data.shape) == 3: mask = self.labels.data > 0 filled_mask = ndimage.binary_fill_holes(mask) - eroded_mask = binary_erosion(filled_mask, structure=structuring_element, iterations=iterations) - self.labels = self.viewer.add_labels(np.where(eroded_mask, self.labels.data, 0), name = self.labels.name + '_eroded') + eroded_mask = binary_erosion( + filled_mask, + structure=structuring_element, + iterations=iterations, + ) + self.labels = self.viewer.add_labels( + np.where(eroded_mask, self.labels.data, 0), + name=self.labels.name + "_eroded", + ) self._update_labels(self.labels.name) - else: - print('4D or 3D array required!') + else: + print("4D or 3D array required!") def _dilate_labels(self): """Dilate labels in the selected layer.""" @@ -850,7 +1108,7 @@ def _dilate_labels(self): diam = self.structuring_element_diameter.value() iterations = self.iterations.value() - if type(self.labels.data) == da.core.Array: + if isinstance(self.labels.data, da.core.Array): if self.outputdir is None: msg = QMessageBox() msg.setWindowTitle("No output directory selected") @@ -859,127 +1117,205 @@ def _dilate_labels(self): msg.setStandardButtons(QMessageBox.Ok) msg.exec_() return False - + else: - outputdir = os.path.join(self.outputdir, (self.labels.name + "_dilated")) + outputdir = os.path.join( + self.outputdir, (self.labels.name + "_dilated") + ) if os.path.exists(outputdir): shutil.rmtree(outputdir) os.mkdir(outputdir) - for i in range(self.labels.data.shape[0]): # Loop over the first dimension - expanded_labels = self.labels.data[i].compute() # Compute the current stack - for j in range(iterations): - expanded_labels = expand_labels(expanded_labels, distance = diam) - tifffile.imwrite(os.path.join(outputdir, (self.labels.name + '_dilated_TP' + str(i).zfill(4) + '.tif')), np.array(expanded_labels, dtype = 'uint16')) - - file_list = [os.path.join(outputdir, fname) for fname in os.listdir(outputdir) if fname.endswith('.tif')] - self.labels = self.viewer.add_labels(da.stack([imread(fname) for fname in sorted(file_list)]), name = self.labels.name + '_dilated') + for i in range( + self.labels.data.shape[0] + ): # Loop over the first dimension + expanded_labels = self.labels.data[ + i + ].compute() # Compute the current stack + for _j in range(iterations): + expanded_labels = expand_labels( + expanded_labels, distance=diam + ) + tifffile.imwrite( + os.path.join( + outputdir, + ( + self.labels.name + + "_dilated_TP" + + str(i).zfill(4) + + ".tif" + ), + ), + np.array(expanded_labels, dtype="uint16"), + ) + + file_list = [ + os.path.join(outputdir, fname) + for fname in os.listdir(outputdir) + if fname.endswith(".tif") + ] + self.labels = self.viewer.add_labels( + da.stack([imread(fname) for fname in sorted(file_list)]), + name=self.labels.name + "_dilated", + ) self._update_labels(self.labels.name) return True - else: - if len(self.labels.data.shape) == 4: + else: + if len(self.labels.data.shape) == 4: stack = [] for i in range(self.labels.data.shape[0]): expanded_labels = self.labels.data[i] - for i in range(iterations): - expanded_labels = expand_labels(expanded_labels, distance = diam) + for _i in range(iterations): + expanded_labels = expand_labels( + expanded_labels, distance=diam + ) stack.append(expanded_labels) - self.labels = self.viewer.add_labels(np.stack(stack, axis = 0), name = self.labels.name + '_dilated') - self._update_labels(self.labels.name) + self.labels = self.viewer.add_labels( + np.stack(stack, axis=0), name=self.labels.name + "_dilated" + ) + self._update_labels(self.labels.name) elif len(self.labels.data.shape) == 3: expanded_labels = self.labels.data - for i in range(iterations): - expanded_labels = expand_labels(expanded_labels, distance = diam) - - self.labels = self.viewer.add_labels(expanded_labels, name = self.labels.name + '_dilated') + for _i in range(iterations): + expanded_labels = expand_labels( + expanded_labels, distance=diam + ) + + self.labels = self.viewer.add_labels( + expanded_labels, name=self.labels.name + "_dilated" + ) self._update_labels(self.labels.name) - else: - print('input should be a 3D or 4D stack') - - def _threshold(self) -> None: - """Threshold the selected label or intensity image""" + else: + print("input should be a 3D or 4D stack") - if type(self.threshold_layer.data) == da.core.Array: - msg = QMessageBox() - msg.setWindowTitle("Thresholding not yet implemented for dask arrays") - msg.setText("Thresholding not yet implemented for dask arrays") - msg.setIcon(QMessageBox.Information) - msg.setStandardButtons(QMessageBox.Ok) - msg.exec_() - return False + def _threshold(self) -> None: + """Threshold the selected label or intensity image""" - thresholded = (self.threshold_layer.data >= int(self.min_threshold.value())) & (self.threshold_layer.data <= int(self.max_threshold.value())) - self.viewer.add_labels(thresholded, name = self.threshold_layer.name + "_thresholded") + if isinstance(self.threshold_layer.data, da.core.Array): + msg = QMessageBox() + msg.setWindowTitle( + "Thresholding not yet implemented for dask arrays" + ) + msg.setText("Thresholding not yet implemented for dask arrays") + msg.setIcon(QMessageBox.Information) + msg.setStandardButtons(QMessageBox.Ok) + msg.exec_() + return False + + thresholded = ( + self.threshold_layer.data >= int(self.min_threshold.value()) + ) & (self.threshold_layer.data <= int(self.max_threshold.value())) + self.viewer.add_labels( + thresholded, name=self.threshold_layer.name + "_thresholded" + ) def _calculate_images(self) -> None: """Add label image 2 to label image 1""" - if type(self.image1_layer) == da.core.Array or type(self.image2_layer) == da.core.Array: - msg = QMessageBox() - msg.setWindowTitle("Cannot yet run image calculator on dask arrays") - msg.setText("Cannot yet run image calculator on dask arrays") - msg.setIcon(QMessageBox.Information) - msg.setStandardButtons(QMessageBox.Ok) - msg.exec_() - return False - if self.image1_layer.data.shape != self.image2_layer.data.shape: - msg = QMessageBox() - msg.setWindowTitle("Images must have the same shape") - msg.setText("Images must have the same shape") - msg.setIcon(QMessageBox.Information) - msg.setStandardButtons(QMessageBox.Ok) - msg.exec_() - return False - + if isinstance(self.image1_layer, da.core.Array) or isinstance( + self.image2_layer, da.core.Array + ): + msg = QMessageBox() + msg.setWindowTitle( + "Cannot yet run image calculator on dask arrays" + ) + msg.setText("Cannot yet run image calculator on dask arrays") + msg.setIcon(QMessageBox.Information) + msg.setStandardButtons(QMessageBox.Ok) + msg.exec_() + return False + if self.image1_layer.data.shape != self.image2_layer.data.shape: + msg = QMessageBox() + msg.setWindowTitle("Images must have the same shape") + msg.setText("Images must have the same shape") + msg.setIcon(QMessageBox.Information) + msg.setStandardButtons(QMessageBox.Ok) + msg.exec_() + return False + if self.operation.currentText() == "Add": - self.viewer.add_image(np.add(self.image1_layer.data, self.image2_layer.data)) + self.viewer.add_image( + np.add(self.image1_layer.data, self.image2_layer.data) + ) if self.operation.currentText() == "Subtract": - self.viewer.add_image(np.subtract(self.image1_layer.data, self.image2_layer.data)) + self.viewer.add_image( + np.subtract(self.image1_layer.data, self.image2_layer.data) + ) if self.operation.currentText() == "Multiply": - self.viewer.add_image(np.multiply(self.image1_layer.data, self.image2_layer.data)) + self.viewer.add_image( + np.multiply(self.image1_layer.data, self.image2_layer.data) + ) if self.operation.currentText() == "Divide": - self.viewer.add_image(np.divide(self.image1_layer.data, self.image2_layer.data, out=np.zeros_like(self.image1_layer.data, dtype=float), where=self.image2_layer.data!=0)) + self.viewer.add_image( + np.divide( + self.image1_layer.data, + self.image2_layer.data, + out=np.zeros_like(self.image1_layer.data, dtype=float), + where=self.image2_layer.data != 0, + ) + ) if self.operation.currentText() == "AND": - self.viewer.add_labels(np.logical_and(self.image1_layer.data != 0, self.image2_layer.data != 0).astype(int)) + self.viewer.add_labels( + np.logical_and( + self.image1_layer.data != 0, self.image2_layer.data != 0 + ).astype(int) + ) if self.operation.currentText() == "OR": - self.viewer.add_labels(np.logical_or(self.image1_layer.data != 0, self.image2_layer.data != 0).astype(int)) + self.viewer.add_labels( + np.logical_or( + self.image1_layer.data != 0, self.image2_layer.data != 0 + ).astype(int) + ) - def _calculate_inv_gauss(self) -> None: + def _calculate_inv_gauss(self) -> None: """Calculate inverse gaussian gradient""" - if type(self.inv_gauss_input_layer) == da.core.Array or type(self.seeds_layer) == da.core.Array: - msg = QMessageBox() - msg.setWindowTitle("Please convert to an in memory array") - msg.setText("Please convert to an in memory array") - msg.setIcon(QMessageBox.Information) - msg.setStandardButtons(QMessageBox.Ok) - msg.exec_() - return False - - self.viewer.add_image(inverse_gaussian_gradient(np.array(self.inv_gauss_input_layer.data, dtype = np.float32), sigma = self.inv_gauss_sigma_spin.value())) + if isinstance(self.inv_gauss_input_layer, da.core.Array) or isinstance( + self.seeds_layer, da.core.Array + ): + msg = QMessageBox() + msg.setWindowTitle("Please convert to an in memory array") + msg.setText("Please convert to an in memory array") + msg.setIcon(QMessageBox.Information) + msg.setStandardButtons(QMessageBox.Ok) + msg.exec_() + return False + + self.viewer.add_image( + inverse_gaussian_gradient( + np.array(self.inv_gauss_input_layer.data, dtype=np.float32), + sigma=self.inv_gauss_sigma_spin.value(), + ) + ) def _morphological_active_contour(self) -> None: """Run morphological active contour algorithm""" - if type(self.inv_gauss_layer) == da.core.Array or type(self.seeds_layer) == da.core.Array: - msg = QMessageBox() - msg.setWindowTitle("Please convert to an in memory array") - msg.setText("Please convert to an in memory array") - msg.setIcon(QMessageBox.Information) - msg.setStandardButtons(QMessageBox.Ok) - msg.exec_() - return False - if self.inv_gauss_layer.data.shape != self.seeds_layer.data.shape: - msg = QMessageBox() - msg.setWindowTitle("Images must have the same shape") - msg.setText("Images must have the same shape") - msg.setIcon(QMessageBox.Information) - msg.setStandardButtons(QMessageBox.Ok) - msg.exec_() - return False - - self.viewer.add_labels(morphological_geodesic_active_contour(self.inv_gauss_layer.data, init_level_set = self.seeds_layer.data, num_iter = self.num_iter_spin.value(), balloon = self.balloon.value())) - - - + if isinstance(self.inv_gauss_input_layer, da.core.Array) or isinstance( + self.seeds_layer, da.core.Array + ): + msg = QMessageBox() + msg.setWindowTitle("Please convert to an in memory array") + msg.setText("Please convert to an in memory array") + msg.setIcon(QMessageBox.Information) + msg.setStandardButtons(QMessageBox.Ok) + msg.exec_() + return False + if self.inv_gauss_layer.data.shape != self.seeds_layer.data.shape: + msg = QMessageBox() + msg.setWindowTitle("Images must have the same shape") + msg.setText("Images must have the same shape") + msg.setIcon(QMessageBox.Information) + msg.setStandardButtons(QMessageBox.Ok) + msg.exec_() + return False + + self.viewer.add_labels( + morphological_geodesic_active_contour( + self.inv_gauss_layer.data, + init_level_set=self.seeds_layer.data, + num_iter=self.num_iter_spin.value(), + balloon=self.balloon.value(), + ) + ) diff --git a/src/napari_lumen_segmentation/napari_multiple_view_widget.py b/src/napari_lumen_segmentation/napari_multiple_view_widget.py deleted file mode 100644 index 94a7fb5..0000000 --- a/src/napari_lumen_segmentation/napari_multiple_view_widget.py +++ /dev/null @@ -1,438 +0,0 @@ -""" -This is an example on how to have more than one viewer in the same napari window. -Additional viewers state will be synchronized with the main viewer. -Switching to 3D display will only impact the main viewer. - -This example also contain option to enable cross that will be moved to the -current dims point (`viewer.dims.point`). -""" - -from copy import deepcopy - -import numpy as np -from qtpy.QtCore import Qt -from qtpy.QtWidgets import ( - QCheckBox, - QDoubleSpinBox, - QPushButton, - QSplitter, - QTabWidget, - QVBoxLayout, - QWidget, -) -from superqt.utils import qthrottled -from packaging.version import parse as parse_version - -import napari -from napari.components.layerlist import Extent -from napari.components.viewer_model import ViewerModel -from napari.layers import Image, Labels, Layer, Vectors -from napari.qt import QtViewer -from napari.utils.action_manager import action_manager -from napari.utils.events.event import WarningEmitter -from napari.utils.notifications import show_info - -NAPARI_GE_4_16 = parse_version(napari.__version__) > parse_version("0.4.16") - - -def copy_layer_le_4_16(layer: Layer, name: str = ""): - res_layer = deepcopy(layer) - # this deepcopy is not optimal for labels and images layers - if isinstance(layer, (Image, Labels)): - res_layer.data = layer.data - - res_layer.metadata["viewer_name"] = name - - res_layer.events.disconnect() - res_layer.events.source = res_layer - for emitter in res_layer.events.emitters.values(): - emitter.disconnect() - emitter.source = res_layer - return res_layer - - -def copy_layer(layer: Layer, name: str = ""): - if NAPARI_GE_4_16: - return copy_layer_le_4_16(layer, name) - - res_layer = Layer.create(*layer.as_layer_data_tuple()) - res_layer.metadata["viewer_name"] = name - return res_layer - - -def get_property_names(layer: Layer): - klass = layer.__class__ - res = [] - for event_name, event_emitter in layer.events.emitters.items(): - if isinstance(event_emitter, WarningEmitter): - continue - if event_name in ("thumbnail", "name"): - continue - if ( - isinstance(getattr(klass, event_name, None), property) - and getattr(klass, event_name).fset is not None - ): - res.append(event_name) - return res - - -def center_cross_on_mouse( - viewer_model: napari.components.viewer_model.ViewerModel, -): - """move the cross to the mouse position""" - - if not getattr(viewer_model, "mouse_over_canvas", True): - # There is no way for napari 0.4.15 to check if mouse is over sending canvas. - show_info( - "Mouse is not over the canvas. You may need to click on the canvas." - ) - return - - viewer_model.dims.current_step = tuple( - np.round( - [ - max(min_, min(p, max_)) / step - for p, (min_, max_, step) in zip( - viewer_model.cursor.position, viewer_model.dims.range - ) - ] - ).astype(int) - ) - - -action_manager.register_action( - name='napari:move_point', - command=center_cross_on_mouse, - description='Move dims point to mouse position', - keymapprovider=ViewerModel, -) - -action_manager.bind_shortcut('napari:move_point', 'C') - - -class own_partial: - """ - Workaround for deepcopy not copying partial functions - (Qt widgets are not serializable) - """ - - def __init__(self, func, *args, **kwargs): - self.func = func - self.args = args - self.kwargs = kwargs - - def __call__(self, *args, **kwargs): - return self.func(*(self.args + args), **{**self.kwargs, **kwargs}) - - def __deepcopy__(self, memodict={}): - return own_partial( - self.func, - *deepcopy(self.args, memodict), - **deepcopy(self.kwargs, memodict), - ) - - -class QtViewerWrap(QtViewer): - def __init__(self, main_viewer, *args, **kwargs): - super().__init__(*args, **kwargs) - self.main_viewer = main_viewer - - def _qt_open( - self, - filenames: list, - stack: bool, - plugin: str = None, - layer_type: str = None, - **kwargs, - ): - """for drag and drop open files""" - self.main_viewer.window._qt_viewer._qt_open( - filenames, stack, plugin, layer_type, **kwargs - ) - - -class CrossWidget(QCheckBox): - """ - Widget to control the cross layer. because of the performance reason - the cross update is throttled - """ - - def __init__(self, viewer: napari.Viewer): - super().__init__("Add cross layer") - self.viewer = viewer - self.setChecked(False) - self.stateChanged.connect(self._update_cross_visibility) - self.layer = None - self.viewer.dims.events.order.connect(self.update_cross) - self.viewer.dims.events.ndim.connect(self._update_ndim) - self.viewer.dims.events.current_step.connect(self.update_cross) - self._extent = None - - self._update_extent() - self.viewer.dims.events.connect(self._update_extent) - - @qthrottled(leading=False) - def _update_extent(self): - """ - Calculate the extent of the data. - - Ignores the cross layer itself in calculating the extent. - """ - if NAPARI_GE_4_16: - layers = [ - layer - for layer in self.viewer.layers - if layer is not self.layer - ] - self._extent = self.viewer.layers.get_extent(layers) - else: - extent_list = [ - layer.extent - for layer in self.viewer.layers - if layer is not self.layer - ] - self._extent = Extent( - data=None, - world=self.viewer.layers._get_extent_world(extent_list), - step=self.viewer.layers._get_step_size(extent_list), - ) - self.update_cross() - - def _update_ndim(self, event): - if self.layer in self.viewer.layers: - self.viewer.layers.remove(self.layer) - self.layer = Vectors(name=".cross", ndim=event.value) - self.layer.edge_width = 1 - self.update_cross() - - def _update_cross_visibility(self, state): - if state: - self.viewer.layers.append(self.layer) - else: - self.viewer.layers.remove(self.layer) - self.update_cross() - - def update_cross(self): - if self.layer not in self.viewer.layers: - self.setChecked(False) - return - - point = self.viewer.dims.current_step - vec = [] - for i, (lower, upper) in enumerate(self._extent.world.T): - if (upper - lower) / self._extent.step[i] == 1: - continue - point1 = list(point) - point1[i] = (lower + self._extent.step[i] / 2) / self._extent.step[ - i - ] - point2 = [0 for _ in point] - point2[i] = (upper - lower) / self._extent.step[i] - vec.append((point1, point2)) - if np.any(self.layer.scale != self._extent.step): - self.layer.scale = self._extent.step - self.layer.data = vec - - -class MultipleViewerWidget(QSplitter): - """The main widget of the example.""" - - def __init__(self, viewer: napari.Viewer): - super().__init__() - self.viewer = viewer - self.viewer_model1 = ViewerModel(title="model1") - self.viewer_model2 = ViewerModel(title="model2") - self._block = False - self.qt_viewer1 = QtViewerWrap(viewer, self.viewer_model1) - self.qt_viewer2 = QtViewerWrap(viewer, self.viewer_model2) - viewer_splitter = QSplitter() - viewer_splitter.setOrientation(Qt.Vertical) - viewer_splitter.addWidget(self.qt_viewer1) - viewer_splitter.addWidget(self.qt_viewer2) - viewer_splitter.setContentsMargins(0, 0, 0, 0) - - self.addWidget(viewer_splitter) - - self.viewer.layers.events.inserted.connect(self._layer_added) - self.viewer.layers.events.removed.connect(self._layer_removed) - self.viewer.layers.events.moved.connect(self._layer_moved) - self.viewer.layers.selection.events.active.connect( - self._layer_selection_changed - ) - self.viewer.dims.events.current_step.connect(self._point_update) - self.viewer_model1.dims.events.current_step.connect(self._point_update) - self.viewer_model2.dims.events.current_step.connect(self._point_update) - self.viewer.dims.events.order.connect(self._order_update) - self.viewer.events.reset_view.connect(self._reset_view) - self.viewer_model1.events.status.connect(self._status_update) - self.viewer_model2.events.status.connect(self._status_update) - - def _status_update(self, event): - self.viewer.status = event.value - - def _reset_view(self): - self.viewer_model1.reset_view() - self.viewer_model2.reset_view() - - def _reset_layers(self): - print(self.viewer_model1.layers) - self.viewer_model1.layers.clear() - self.viewer_model2.layers.clear() - print('cleared both layers') - - def _layer_selection_changed(self, event): - """ - update of current active layer - """ - if self._block: - return - - if event.value is None: - self.viewer_model1.layers.selection.active = None - self.viewer_model2.layers.selection.active = None - return - - self.viewer_model1.layers.selection.active = self.viewer_model1.layers[ - event.value.name - ] - self.viewer_model2.layers.selection.active = self.viewer_model2.layers[ - event.value.name - ] - - def _point_update(self, event): - - try: - for model in [self.viewer, self.viewer_model1, self.viewer_model2]: - if model.dims is event.source: - continue - model.dims.current_step = event.value - except: - 'Layer was already removed! This error likely occurs because two actions are called at the same time.' - - def _order_update(self): - order = list(self.viewer.dims.order) - if len(order) <= 2: - self.viewer_model1.dims.order = order - self.viewer_model2.dims.order = order - return - - order[-3:] = order[-2], order[-3], order[-1] - self.viewer_model1.dims.order = order - order = list(self.viewer.dims.order) - order[-3:] = order[-1], order[-2], order[-3] - self.viewer_model2.dims.order = order - - def _layer_added(self, event): - """add layer to additional viewers and connect all required events""" - self.viewer_model1.layers.insert( - event.index, copy_layer(event.value, "model1") - ) - self.viewer_model2.layers.insert( - event.index, copy_layer(event.value, "model2") - ) - for name in get_property_names(event.value): - getattr(event.value.events, name).connect( - own_partial(self._property_sync, name) - ) - - if isinstance(event.value, Labels): - event.value.events.set_data.connect(self._set_data_refresh) - self.viewer_model1.layers[ - event.value.name - ].events.set_data.connect(self._set_data_refresh) - self.viewer_model2.layers[ - event.value.name - ].events.set_data.connect(self._set_data_refresh) - if event.value.name != ".cross": - self.viewer_model1.layers[event.value.name].events.data.connect( - self._sync_data - ) - self.viewer_model2.layers[event.value.name].events.data.connect( - self._sync_data - ) - - event.value.events.name.connect(self._sync_name) - - self._order_update() - - def _sync_name(self, event): - """sync name of layers""" - index = self.viewer.layers.index(event.source) - self.viewer_model1.layers[index].name = event.source.name - self.viewer_model2.layers[index].name = event.source.name - - def _sync_data(self, event): - """sync data modification from additional viewers""" - if self._block: - return - for model in [self.viewer, self.viewer_model1, self.viewer_model2]: - layer = model.layers[event.source.name] - if layer is event.source: - continue - try: - self._block = True - layer.data = event.source.data - finally: - self._block = False - - def _set_data_refresh(self, event): - """ - synchronize data refresh between layers - """ - if self._block: - return - for model in [self.viewer, self.viewer_model1, self.viewer_model2]: - layer = model.layers[event.source.name] - if layer is event.source: - continue - try: - self._block = True - layer.refresh() - finally: - self._block = False - - def _layer_removed(self, event): - """remove layer in all viewers""" - layer_name = event.value.name - self.viewer_model1.layers.pop(layer_name) - self.viewer_model2.layers.pop(layer_name) - - def _layer_moved(self, event): - """update order of layers""" - dest_index = ( - event.new_index - if event.new_index < event.index - else event.new_index + 1 - ) - self.viewer_model1.layers.move(event.index, dest_index) - self.viewer_model2.layers.move(event.index, dest_index) - - def _property_sync(self, name, event): - """Sync layers properties (except the name)""" - if event.source not in self.viewer.layers: - return - try: - self._block = True - setattr( - self.viewer_model1.layers[event.source.name], - name, - getattr(event.source, name), - ) - setattr( - self.viewer_model2.layers[event.source.name], - name, - getattr(event.source, name), - ) - finally: - self._block = False - - -if __name__ == "__main__": - view = napari.Viewer() - dock_widget = MultipleViewerWidget(view) - cross = CrossWidget(view) - - view.window.add_dock_widget(dock_widget, name="Sample") - view.window.add_dock_widget(cross, name="Cross", area="right") - - napari.run()