Skip to content

Commit

Permalink
Merge pull request #208 from dstansby/fix-slicing
Browse files Browse the repository at this point in the history
Improve slicing experience
  • Loading branch information
dstansby authored Aug 25, 2023
2 parents d7e88a9 + 9eb43ee commit 3a8261a
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 39 deletions.
8 changes: 8 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,17 @@ Changelog
=========
1.0.3
-----
Changes
~~~~~~~
- The slice widget is now limited to slicing along the x/y dimensions. Support
for slicing along z has been removed for now to make the code simpler.
- The slice widget now uses a slider to select the slice value.

Bug fixes
~~~~~~~~~
- Fixed creating 1D slices of 2D images.
- Removed the limitation that only the first 99 indices could be sliced using
the slice widget.

1.0.2
-----
Expand Down
87 changes: 48 additions & 39 deletions src/napari_matplotlib/slice.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, List, Optional, Tuple

import matplotlib.ticker as mticker
import napari
import numpy as np
import numpy.typing as npt
from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox, QWidget
from qtpy.QtCore import Qt
from qtpy.QtWidgets import (
QComboBox,
QLabel,
QSlider,
QVBoxLayout,
QWidget,
)

from .base import SingleAxesWidget
from .util import Interval

__all__ = ["SliceWidget"]

_dims_sel = ["x", "y"]


class SliceWidget(SingleAxesWidget):
"""
Expand All @@ -30,28 +35,46 @@ def __init__(
# Setup figure/axes
super().__init__(napari_viewer, parent=parent)

button_layout = QHBoxLayout()
self.layout().addLayout(button_layout)

self.dim_selector = QComboBox()
self.dim_selector.addItems(["x", "y"])

self.slice_selector = QSlider(orientation=Qt.Orientation.Horizontal)

# Create widget layout
button_layout = QVBoxLayout()
button_layout.addWidget(QLabel("Slice axis:"))
button_layout.addWidget(self.dim_selector)
self.dim_selector.addItems(["x", "y", "z"])

self.slice_selectors = {}
for d in _dims_sel:
self.slice_selectors[d] = QSpinBox()
button_layout.addWidget(QLabel(f"{d}:"))
button_layout.addWidget(self.slice_selectors[d])
button_layout.addWidget(self.slice_selector)
self.layout().addLayout(button_layout)

# Setup callbacks
# Re-draw when any of the combon/spin boxes are updated
# Re-draw when any of the combo/slider is updated
self.dim_selector.currentTextChanged.connect(self._draw)
for d in _dims_sel:
self.slice_selectors[d].textChanged.connect(self._draw)
self.slice_selector.valueChanged.connect(self._draw)

self._update_layers(None)

def on_update_layers(self) -> None:
"""
Called when layer selection is updated.
"""
if not len(self.layers):
return
if self.current_dim_name == "x":
max = self._layer.data.shape[-2]
elif self.current_dim_name == "y":
max = self._layer.data.shape[-1]
else:
raise RuntimeError("dim name must be x or y")
self.slice_selector.setRange(0, max - 1)

@property
def _slice_width(self) -> int:
"""
Width of the slice being plotted.
"""
return self._layer.data.shape[self.current_dim_index]

@property
def _layer(self) -> napari.layers.Layer:
"""
Expand All @@ -73,7 +96,7 @@ def current_dim_index(self) -> int:
"""
# Note the reversed list because in napari the z-axis is the first
# numpy axis
return self._dim_names[::-1].index(self.current_dim_name)
return self._dim_names.index(self.current_dim_name)

@property
def _dim_names(self) -> List[str]:
Expand All @@ -82,45 +105,31 @@ def _dim_names(self) -> List[str]:
dimensionality of the currently selected data.
"""
if self._layer.data.ndim == 2:
return ["x", "y"]
return ["y", "x"]
elif self._layer.data.ndim == 3:
return ["x", "y", "z"]
return ["z", "y", "x"]
else:
raise RuntimeError("Don't know how to handle ndim != 2 or 3")

@property
def _selector_values(self) -> Dict[str, int]:
"""
Values of the slice selectors.
Mapping from dimension name to value.
"""
return {d: self.slice_selectors[d].value() for d in _dims_sel}

def _get_xy(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any]]:
"""
Get data for plotting.
"""
dim_index = self.current_dim_index
if self._layer.data.ndim == 2:
dim_index -= 1
x = np.arange(self._layer.data.shape[dim_index])

vals = self._selector_values
vals.update({"z": self.current_z})
val = self.slice_selector.value()

slices = []
for dim_name in self._dim_names:
if dim_name == self.current_dim_name:
# Select all data along this axis
slices.append(slice(None))
elif dim_name == "z":
# Only select the currently viewed z-index
slices.append(slice(self.current_z, self.current_z + 1))
else:
# Select specific index
val = vals[dim_name]
slices.append(slice(val, val + 1))

# Reverse since z is the first axis in napari
slices = slices[::-1]
x = np.arange(self._slice_width)
y = self._layer.data[tuple(slices)].ravel()

return x, y
Expand Down
24 changes: 24 additions & 0 deletions src/napari_matplotlib/tests/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,27 @@ def test_slice_2D(make_napari_viewer, astronaut_data):
# Need to return a copy, as original figure is too eagerley garbage
# collected by the widget
return deepcopy(fig)


def test_slice_axes(make_napari_viewer, astronaut_data):
viewer = make_napari_viewer()
viewer.theme = "light"

# Take first RGB channel
data = astronaut_data[0][:256, :, 0]
# Shape:
# x: 0 > 512
# y: 0 > 256
assert data.ndim == 2, data.shape
# Make sure data isn't square for later tests
assert data.shape[0] != data.shape[1]
viewer.add_image(data)

widget = SliceWidget(viewer)
assert widget._dim_names == ["y", "x"]
assert widget.current_dim_name == "x"
assert widget.slice_selector.value() == 0
assert widget.slice_selector.minimum() == 0
assert widget.slice_selector.maximum() == data.shape[0] - 1
# x/y are flipped in napari
assert widget._slice_width == data.shape[1]

0 comments on commit 3a8261a

Please sign in to comment.