Skip to content

Commit

Permalink
Merge pull request #172 from rayosborn/fix-mpl-import
Browse files Browse the repository at this point in the history
Miscellaneous improvements
  • Loading branch information
rayosborn authored Jun 20, 2022
2 parents f0406c7 + 5b14b6d commit 4a808ff
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/nexusformat/nexus/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def plot(self, data_group, fmt=None, xmin=None, xmax=None,
if colorbar:
cb = plt.colorbar(im)
if cmap == 'tab10':
from matplotlib import mpl_version
from matplotlib import __version__ as mpl_version
from pkg_resources import parse_version as pv
cmin, cmax = im.get_clim()
if cmax - cmin <= 9:
Expand Down
28 changes: 26 additions & 2 deletions src/nexusformat/nexus/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@
'nxgetlock', 'nxsetlock', 'nxgetlockexpiry', 'nxsetlockexpiry',
'nxgetmaxsize', 'nxsetmaxsize', 'nxgetmemory', 'nxsetmemory',
'nxgetrecursive', 'nxsetrecursive',
'nxclasses', 'nxload', 'nxsave', 'nxduplicate', 'nxdir',
'nxclasses', 'nxload', 'nxopen', 'nxsave', 'nxduplicate', 'nxdir',
'nxconsolidate', 'nxdemo', 'nxversion']

import numbers
Expand Down Expand Up @@ -2015,6 +2015,13 @@ def __bool__(self):
def __contains__(self, key):
return False

def __lt__(self, other):
"""Define ordering of NeXus objects using their names."""
if not isinstance(other, NXobject):
return False
else:
return self.nxname < other.nxname

def _setattrs(self, attrs):
for k, v in attrs.items():
self._attrs[k] = v
Expand Down Expand Up @@ -5473,6 +5480,23 @@ def __repr__(self):
else:
return f"NXroot('{self.nxname}')"

def __enter__(self):
"""Open a NeXus file for multiple operations.
Returns
-------
NXroot
Current NXroot instance.
"""
if self.nxfile:
self.nxfile.__enter__()
return self

def __exit__(self, *args):
"""Close the NeXus file."""
if self.nxfile:
self.nxfile.__exit__()

def reload(self):
"""Reload the NeXus file from disk."""
if self.nxfilemode:
Expand Down Expand Up @@ -7299,7 +7323,7 @@ def load(filename, mode='r', recursive=None, **kwargs):
return root


nxload = load
nxload = nxopen = load


def save(filename, group, mode='w', **kwargs):
Expand Down
24 changes: 23 additions & 1 deletion tests/test_files.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os

import pytest
from nexusformat.nexus.tree import NXdata, NXentry, NXFile, NXroot, nxload
from nexusformat.nexus.tree import (NXdata, NXentry, NXFile, NXroot, nxload,
nxopen)


def test_file_creation(tmpdir):
Expand Down Expand Up @@ -58,3 +59,24 @@ def test_file_recursion(tmpdir, field1, field2, recursive):
assert "entry/data/f2" in w2
assert "signal" in w2["entry/data"].attrs
assert "axes" in w2["entry/data"].attrs


def test_file_context_manager(tmpdir, field1, field2):

filename = os.path.join(tmpdir, "file.nxs")

with nxopen(filename, "w") as w1:
w1["entry"] = NXentry()
w1["entry/data"] = NXdata(field1, field2)
assert w1.nxfilename == filename
assert w1.nxfilemode == "rw"

assert os.path.exists(filename)

w2 = nxopen(filename)
assert w2.nxfilename == filename
assert w2.nxfilemode == "r"
assert "entry/data/f1" in w2
assert "entry/data/f2" in w2
assert "signal" in w2["entry/data"].attrs
assert "axes" in w2["entry/data"].attrs

0 comments on commit 4a808ff

Please sign in to comment.