Skip to content

Commit

Permalink
ENH: Add round-trip channel name saving
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Dec 3, 2024
1 parent a1a05ae commit bd2c646
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 18 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for saving and loading channel names from FIF in :meth:`mne.channels.DigMontage.save` and :meth:`mne.channels.read_dig_fif`, by `Eric Larson`_.
29 changes: 24 additions & 5 deletions mne/_fiff/_digitization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .constants import FIFF, _coord_frame_named
from .tag import read_tag
from .tree import dir_tree_find
from .write import start_and_end_file, write_dig_points
from .write import _safe_name_list, start_and_end_file, write_dig_points

_dig_kind_dict = {
"cardinal": FIFF.FIFFV_POINT_CARDINAL,
Expand Down Expand Up @@ -162,10 +162,11 @@ def __eq__(self, other): # noqa: D105
return np.allclose(self["r"], other["r"])


def _read_dig_fif(fid, meas_info):
def _read_dig_fif(fid, meas_info, *, return_ch_names=False):
"""Read digitizer data from a FIFF file."""
isotrak = dir_tree_find(meas_info, FIFF.FIFFB_ISOTRAK)
dig = None
ch_names = None
if len(isotrak) == 0:
logger.info("Isotrak not found")
elif len(isotrak) > 1:
Expand All @@ -183,13 +184,21 @@ def _read_dig_fif(fid, meas_info):
elif kind == FIFF.FIFF_MNE_COORD_FRAME:
tag = read_tag(fid, pos)
coord_frame = _coord_frame_named.get(int(tag.data.item()))
elif kind == FIFF.FIFF_MNE_CH_NAME_LIST:
tag = read_tag(fid, pos)
ch_names = _safe_name_list(tag.data, "read", "ch_names")
for d in dig:
d["coord_frame"] = coord_frame
return _format_dig_points(dig)
out = _format_dig_points(dig)
if return_ch_names:
out = (out, ch_names)
return out


@verbose
def write_dig(fname, pts, coord_frame=None, *, overwrite=False, verbose=None):
def write_dig(
fname, pts, coord_frame=None, *, ch_names=None, overwrite=False, verbose=None
):
"""Write digitization data to a FIF file.
Parameters
Expand All @@ -203,6 +212,10 @@ def write_dig(fname, pts, coord_frame=None, *, overwrite=False, verbose=None):
If all the points have the same coordinate frame, specify the type
here. Can be None (default) if the points could have varying
coordinate frames.
ch_names : list of str | None
Channel names associated with the digitization points, if available.
.. versionadded:: 1.9
%(overwrite)s
.. versionadded:: 1.0
Expand All @@ -222,9 +235,15 @@ def write_dig(fname, pts, coord_frame=None, *, overwrite=False, verbose=None):
"Points have coord_frame entries that are incompatible with "
f"coord_frame={coord_frame}: {tuple(bad_frames)}."
)
_validate_type(ch_names, (None, list, tuple), "ch_names")
if ch_names is not None:
for ci, ch_name in enumerate(ch_names):
_validate_type(ch_name, str, f"ch_names[{ci}]")

with start_and_end_file(fname) as fid:
write_dig_points(fid, pts, block=True, coord_frame=coord_frame)
write_dig_points(
fid, pts, block=True, coord_frame=coord_frame, ch_names=ch_names
)


_cardinal_ident_mapping = {
Expand Down
6 changes: 5 additions & 1 deletion mne/_fiff/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def write_ch_info(fid, ch):
fid.write(b"\0" * (16 - len(ch_name)))


def write_dig_points(fid, dig, block=False, coord_frame=None):
def write_dig_points(fid, dig, block=False, coord_frame=None, *, ch_names=None):
"""Write a set of digitizer data points into a fif file."""
if dig is not None:
data_size = 5 * 4
Expand All @@ -406,6 +406,10 @@ def write_dig_points(fid, dig, block=False, coord_frame=None):
fid.write(np.array(d["kind"], ">i4").tobytes())
fid.write(np.array(d["ident"], ">i4").tobytes())
fid.write(np.array(d["r"][:3], ">f4").tobytes())
if ch_names is not None:
write_name_list_sanitized(
fid, FIFF.FIFF_MNE_CH_NAME_LIST, ch_names, "ch_names"
)
if block:
end_block(fid, FIFF.FIFFB_ISOTRAK)

Expand Down
29 changes: 23 additions & 6 deletions mne/channels/montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,20 @@ def save(self, fname, *, overwrite=False, verbose=None):
The filename to use. Should end in .fif or .fif.gz.
%(overwrite)s
%(verbose)s
See Also
--------
mne.channels.read_dig_fif
Notes
-----
.. versionchanged:: 1.9
Added support for saving the associated channel names.
"""
coord_frame = _check_get_coord_frame(self.dig)
write_dig(fname, self.dig, coord_frame, overwrite=overwrite)
write_dig(
fname, self.dig, coord_frame, overwrite=overwrite, ch_names=self.ch_names
)

def __iadd__(self, other):
"""Add two DigMontages in place.
Expand Down Expand Up @@ -835,17 +846,23 @@ def read_dig_fif(fname):
read_dig_hpts
read_dig_localite
make_dig_montage
Notes
-----
.. versionchanged:: 1.9
Added support for saving the associated channel names.
"""
fname = _check_fname(fname, overwrite="read", must_exist=True)
# Load the dig data
f, tree = fiff_open(fname)[:2]
with f as fid:
dig = _read_dig_fif(fid, tree)
dig, ch_names = _read_dig_fif(fid, tree, return_ch_names=True)

ch_names = []
for d in dig:
if d["kind"] == FIFF.FIFFV_POINT_EEG:
ch_names.append(f"EEG{d['ident']:03d}")
if ch_names is None: # backward compat from when we didn't save the names
ch_names = []
for d in dig:
if d["kind"] == FIFF.FIFFV_POINT_EEG:
ch_names.append(f"EEG{d['ident']:03d}")

montage = DigMontage(dig=dig, ch_names=ch_names)
return montage
Expand Down
30 changes: 24 additions & 6 deletions mne/channels/tests/test_montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
assert_equal,
)

import mne.channels.montage
from mne import (
__file__ as _mne_file,
)
Expand Down Expand Up @@ -56,6 +57,7 @@
_BUILTIN_STANDARD_MONTAGES,
_check_get_coord_frame,
transform_to_head,
write_dig,
)
from mne.coreg import get_mni_fiducials
from mne.datasets import testing
Expand Down Expand Up @@ -1074,7 +1076,7 @@ def _ensure_fid_not_nan(info, ch_pos):


@testing.requires_testing_data
def test_fif_dig_montage(tmp_path):
def test_fif_dig_montage(tmp_path, monkeypatch):
"""Test FIF dig montage support."""
dig_montage = read_dig_fif(fif_dig_montage_fname)

Expand Down Expand Up @@ -1119,16 +1121,32 @@ def test_fif_dig_montage(tmp_path):
# Roundtrip of non-FIF start
montage = make_dig_montage(hsp=read_polhemus_fastscan(hsp), hpi=read_mrk(hpi))
elp_points = read_polhemus_fastscan(elp)
ch_pos = {f"EEG{k:03d}": pos for k, pos in enumerate(elp_points[8:], 1)}
montage += make_dig_montage(
ch_pos = {f"ECoG{k:03d}": pos for k, pos in enumerate(elp_points[3:], 1)}
assert len(elp_points) == 8 # there are only 8 but pretend the last are ECoG
other = make_dig_montage(
nasion=elp_points[0], lpa=elp_points[1], rpa=elp_points[2], ch_pos=ch_pos
)
assert other.ch_names[0].startswith("ECoG")
montage += other
assert montage.ch_names[0].startswith("ECoG")
_check_roundtrip(montage, fname_temp, "unknown")
montage = transform_to_head(montage)
_check_roundtrip(montage, fname_temp)
montage.dig[0]["coord_frame"] = FIFF.FIFFV_COORD_UNKNOWN
with pytest.raises(RuntimeError, match="Only a single coordinate"):
montage.save(fname_temp)
montage.dig[0]["coord_frame"] = FIFF.FIFFV_COORD_HEAD

# Check that old-style files can be read, too, using EEG001 etc.
def write_dig_no_ch_names(*args, **kwargs):
kwargs["ch_names"] = None
return write_dig(*args, **kwargs)

monkeypatch.setattr(mne.channels.montage, "write_dig", write_dig_no_ch_names)
montage.save(fname_temp, overwrite=True)
montage_read = read_dig_fif(fname_temp)
default_ch_names = [f"EEG{ii:03d}" for ii in range(1, 6)]
assert montage_read.ch_names == default_ch_names


@testing.requires_testing_data
Expand Down Expand Up @@ -1495,15 +1513,15 @@ def test_montage_positions_similar(fname, montage, n_eeg, n_good, bads):
assert_array_less(0, ang) # but not equal


# XXX: this does not check ch_names + it cannot work because of write_dig
def _check_roundtrip(montage, fname, coord_frame="head"):
"""Check roundtrip writing."""
montage.save(fname, overwrite=True)
montage_read = read_dig_fif(fname=fname)

assert_equal(repr(montage), repr(montage_read))
assert_equal(_check_get_coord_frame(montage_read.dig), coord_frame)
assert repr(montage) == repr(montage_read)
assert _check_get_coord_frame(montage_read.dig) == coord_frame
assert_dig_allclose(montage, montage_read)
assert montage.ch_names == montage_read.ch_names


def test_digmontage_constructor_errors():
Expand Down

0 comments on commit bd2c646

Please sign in to comment.