Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Support axes arg for field mode #12655

Merged
merged 5 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/changes/devel/12655.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added support for passing ``axes`` to :func:`mne.viz.plot_head_positions` when
``mode='field'``, by `Eric Larson`_.
19 changes: 15 additions & 4 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def plot_head_positions(
mode="traces",
cmap="viridis",
direction="z",
*,
show=True,
destination=None,
info=None,
Expand Down Expand Up @@ -169,9 +170,11 @@ def plot_head_positions(

.. versionadded:: 0.16
axes : array-like, shape (3, 2)
The matplotlib axes to use. Only used for ``mode == 'traces'``.
The matplotlib axes to use.

.. versionadded:: 0.16
.. versionchanged:: 1.8
Added support for making use of this argument when ``mode="field"``.

Returns
-------
Expand All @@ -193,7 +196,9 @@ def plot_head_positions(

if not isinstance(pos, (list, tuple)):
pos = [pos]
pos = list(pos) # make our own mutable copy
for ii, p in enumerate(pos):
_validate_type(p, np.ndarray, f"pos[{ii}]")
p = np.array(p, float)
if p.ndim != 2 or p.shape[1] != 10:
raise ValueError(
Expand Down Expand Up @@ -315,9 +320,15 @@ def plot_head_positions(
from mpl_toolkits.mplot3d import Axes3D # noqa: F401, analysis:ignore
from mpl_toolkits.mplot3d.art3d import Line3DCollection

fig, ax = plt.subplots(
1, subplot_kw=dict(projection="3d"), layout="constrained"
)
_validate_type(axes, (Axes3D, None), "ax", extra="when mode='field'")
if axes is None:
_, ax = plt.subplots(
1, subplot_kw=dict(projection="3d"), layout="constrained"
)
else:
ax = axes
fig = ax.get_figure()
del axes

# First plot the trajectory as a colormap:
# http://matplotlib.org/examples/pylab_examples/multicolored_line.html
Expand Down
21 changes: 14 additions & 7 deletions mne/viz/tests/test_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,20 @@ def test_plot_head_positions():
pos = np.random.RandomState(0).randn(4, 10)
pos[:, 0] = np.arange(len(pos))
destination = (0.0, 0.0, 0.04)
with _record_warnings(): # old MPL will cause a warning
plot_head_positions(pos)
plot_head_positions(pos, mode="field", info=info, destination=destination)
plot_head_positions([pos, pos]) # list support
pytest.raises(ValueError, plot_head_positions, ["pos"])
pytest.raises(ValueError, plot_head_positions, pos[:, :9])
pytest.raises(ValueError, plot_head_positions, pos, "foo")
plot_head_positions(pos)
plot_head_positions(pos, mode="field", info=info, destination=destination)
plot_head_positions([pos, pos]) # list support
fig, ax = plt.subplots()
with pytest.raises(TypeError, match="instance of Axes3D"):
plot_head_positions(pos, mode="field", info=info, axes=ax)
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
plot_head_positions(pos, mode="field", info=info, axes=ax)
with pytest.raises(TypeError, match="must be an instance of ndarray"):
plot_head_positions(["pos"])
larsoner marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(ValueError, match="must be dim"):
plot_head_positions(pos[:, :9])
with pytest.raises(ValueError, match="Allowed values"):
plot_head_positions(pos, "foo")
with pytest.raises(ValueError, match="shape"):
plot_head_positions(pos, axes=1.0)

Expand Down
2 changes: 1 addition & 1 deletion tools/install_pre_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ python -m pip install $STD_ARGS vtk
python -c "import vtk"

echo "PyVista"
python -m pip install $STD_ARGS "git+https://github.com/adeak/pyvista.git@fix_numpy_2" # pyvista/pyvista
python -m pip install $STD_ARGS "git+https://github.com/pyvista/pyvista"

echo "picard"
python -m pip install $STD_ARGS git+https://github.com/pierreablin/picard
Expand Down
Loading