Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve slicing experience #208

Merged
merged 3 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,17 @@ Changelog
=========
1.0.3
-----
Changes
~~~~~~~
- The slice widget is now limited to slicing along the x/y dimensions. Support
for slicing along z has been removed for now to make the code simpler.
- The slice widget now uses a slider to select the slice value.

Bug fixes
~~~~~~~~~
- Fixed creating 1D slices of 2D images.
- Removed the limitation that only the first 99 indices could be sliced using
the slice widget.

1.0.2
-----
Expand Down
87 changes: 48 additions & 39 deletions src/napari_matplotlib/slice.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, List, Optional, Tuple

import matplotlib.ticker as mticker
import napari
import numpy as np
import numpy.typing as npt
from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox, QWidget
from qtpy.QtCore import Qt
from qtpy.QtWidgets import (
QComboBox,
QLabel,
QSlider,
QVBoxLayout,
QWidget,
)

from .base import SingleAxesWidget
from .util import Interval

__all__ = ["SliceWidget"]

_dims_sel = ["x", "y"]


class SliceWidget(SingleAxesWidget):
"""
Expand All @@ -30,28 +35,46 @@ def __init__(
# Setup figure/axes
super().__init__(napari_viewer, parent=parent)

button_layout = QHBoxLayout()
self.layout().addLayout(button_layout)

self.dim_selector = QComboBox()
self.dim_selector.addItems(["x", "y"])

self.slice_selector = QSlider(orientation=Qt.Orientation.Horizontal)

# Create widget layout
button_layout = QVBoxLayout()
button_layout.addWidget(QLabel("Slice axis:"))
button_layout.addWidget(self.dim_selector)
self.dim_selector.addItems(["x", "y", "z"])

self.slice_selectors = {}
for d in _dims_sel:
self.slice_selectors[d] = QSpinBox()
button_layout.addWidget(QLabel(f"{d}:"))
button_layout.addWidget(self.slice_selectors[d])
button_layout.addWidget(self.slice_selector)
self.layout().addLayout(button_layout)

# Setup callbacks
# Re-draw when any of the combon/spin boxes are updated
# Re-draw when any of the combo/slider is updated
self.dim_selector.currentTextChanged.connect(self._draw)
for d in _dims_sel:
self.slice_selectors[d].textChanged.connect(self._draw)
self.slice_selector.valueChanged.connect(self._draw)

self._update_layers(None)

def on_update_layers(self) -> None:
"""
Called when layer selection is updated.
"""
if not len(self.layers):
return
if self.current_dim_name == "x":
max = self._layer.data.shape[-2]
elif self.current_dim_name == "y":
max = self._layer.data.shape[-1]
else:
raise RuntimeError("dim name must be x or y")
self.slice_selector.setRange(0, max - 1)

@property
def _slice_width(self) -> int:
"""
Width of the slice being plotted.
"""
return self._layer.data.shape[self.current_dim_index]

@property
def _layer(self) -> napari.layers.Layer:
"""
Expand All @@ -73,7 +96,7 @@ def current_dim_index(self) -> int:
"""
# Note the reversed list because in napari the z-axis is the first
# numpy axis
return self._dim_names[::-1].index(self.current_dim_name)
return self._dim_names.index(self.current_dim_name)

@property
def _dim_names(self) -> List[str]:
Expand All @@ -82,45 +105,31 @@ def _dim_names(self) -> List[str]:
dimensionality of the currently selected data.
"""
if self._layer.data.ndim == 2:
return ["x", "y"]
return ["y", "x"]
elif self._layer.data.ndim == 3:
return ["x", "y", "z"]
return ["z", "y", "x"]
else:
raise RuntimeError("Don't know how to handle ndim != 2 or 3")

@property
def _selector_values(self) -> Dict[str, int]:
"""
Values of the slice selectors.

Mapping from dimension name to value.
"""
return {d: self.slice_selectors[d].value() for d in _dims_sel}

def _get_xy(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any]]:
"""
Get data for plotting.
"""
dim_index = self.current_dim_index
if self._layer.data.ndim == 2:
dim_index -= 1
x = np.arange(self._layer.data.shape[dim_index])

vals = self._selector_values
vals.update({"z": self.current_z})
val = self.slice_selector.value()

slices = []
for dim_name in self._dim_names:
if dim_name == self.current_dim_name:
# Select all data along this axis
slices.append(slice(None))
elif dim_name == "z":
# Only select the currently viewed z-index
slices.append(slice(self.current_z, self.current_z + 1))
else:
# Select specific index
val = vals[dim_name]
slices.append(slice(val, val + 1))

# Reverse since z is the first axis in napari
slices = slices[::-1]
x = np.arange(self._slice_width)
y = self._layer.data[tuple(slices)].ravel()

return x, y
Expand Down
24 changes: 24 additions & 0 deletions src/napari_matplotlib/tests/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,27 @@ def test_slice_2D(make_napari_viewer, astronaut_data):
# Need to return a copy, as original figure is too eagerley garbage
# collected by the widget
return deepcopy(fig)


def test_slice_axes(make_napari_viewer, astronaut_data):
viewer = make_napari_viewer()
viewer.theme = "light"

# Take first RGB channel
data = astronaut_data[0][:256, :, 0]
# Shape:
# x: 0 > 512
# y: 0 > 256
assert data.ndim == 2, data.shape
# Make sure data isn't square for later tests
assert data.shape[0] != data.shape[1]
viewer.add_image(data)

widget = SliceWidget(viewer)
assert widget._dim_names == ["y", "x"]
assert widget.current_dim_name == "x"
assert widget.slice_selector.value() == 0
assert widget.slice_selector.minimum() == 0
assert widget.slice_selector.maximum() == data.shape[0] - 1
# x/y are flipped in napari
assert widget._slice_width == data.shape[1]