diff --git a/src/nexusformat/nexus/tree.py b/src/nexusformat/nexus/tree.py index 4e4d322..d8b8cc4 100644 --- a/src/nexusformat/nexus/tree.py +++ b/src/nexusformat/nexus/tree.py @@ -297,6 +297,11 @@ def is_text(value): return False +def natural_sort(key): + """Sort numbers according to their value, not their first character""" + return [int(t) if t.isdigit() else t for t in re.split(r'(\d+)', key)] + + class NeXusError(Exception): """NeXus Error""" pass @@ -336,6 +341,7 @@ def __init__(self, name, mode=None, **kwds): """ Creates an h5py File object for reading and writing. """ + name = os.path.abspath(name) if mode == 'w4' or mode == 'wx': raise NeXusError('Only HDF5 files supported') elif mode == 'w' or mode == 'w-' or mode == 'w5': @@ -344,12 +350,14 @@ def __init__(self, name, mode=None, **kwds): self._file = h5.File(name, mode, **kwds) self._mode = 'rw' else: - if mode == 'rw': + if mode == 'rw' or mode == 'r+': + self._mode = 'rw' mode = 'r+' + else: + self._mode = 'r' self._file = h5.File(name, mode, **kwds) - self.mode = mode self._filename = self._file.filename - self._path = '' + self._path = '/' def __repr__(self): return '' % (os.path.basename(self._filename), @@ -357,19 +365,19 @@ def __repr__(self): def __getitem__(self, key): """Returns an object from the NeXus file.""" - return self._file.get(key) + return self.file.get(key) def __setitem__(self, key, value): """Sets an object value in the NeXus file.""" - self._file[key] = value + self.file[key] = value def __delitem__(self, name): """ Delete an item from a group. """ - del self._file[name] + del self.file[name] def __contains__(self, key): """Implements 'k in d' test""" - return self._file.__contains__(key) + return self.file.__contains__(key) def __enter__(self): return self.open() @@ -378,10 +386,10 @@ def __exit__(self, *args): self.close() def get(self, *args, **kwds): - return self._file.get(*args, **kwds) + return self.file.get(*args, **kwds) def copy(self, *args, **kwds): - self._file.copy(*args, **kwds) + self.file.copy(*args, **kwds) def open(self, **kwds): if not self._file.id: @@ -396,6 +404,12 @@ def close(self): if self._file.id: self._file.close() + def isopen(self): + if self._file.id: + return True + else: + return False + def readfile(self): """ Reads the NeXus file structure from the file and returns a tree of @@ -409,7 +423,7 @@ def readfile(self): root = self._readgroup('root') root._group = None root._file = self - root._filename = self.filename + root._filename = self._filename root._mode = self._mode = _mode return root @@ -437,12 +451,18 @@ def _readnxclass(self, attrs): def _readlink(self): if self._isexternal(): - link = self.get(self.nxpath, getlink=True) - return link.path, link.filename - elif 'target' in self.attrs and self.attrs['target'] != self.nxpath: - return self.attrs['target'], None + _link = self.get(self.nxpath, getlink=True) + _target, _filename = _link.path, _link.filename + elif 'target' in self.attrs: + _target = self.attrs['target'] + _filename = self.get(self.nxpath).file.filename + if _filename == self.filename: + _filename = None + if _target == self.nxpath: + _target = None else: - return None, None + _target, _filename = None, None + return _target, _filename def _readchildren(self): children = {} @@ -469,8 +489,8 @@ def _readgroup(self, name): else: nxclass = 'NXgroup' children = self._readchildren() - if self.nxpath != '/' and self._islink(): - _target, _filename = self._readlink() + _target, _filename = self._readlink() + if self.nxpath != '/' and _target is not None: group = NXlinkgroup(nxclass=nxclass, name=name, attrs=attrs, entries=children, target=_target, file=_filename) @@ -488,8 +508,8 @@ def _readdata(self, name): """ # Finally some data, but don't read it if it is big # Instead record the location, type and size - if self._islink(): - _target, _filename = self._readlink() + _target, _filename = self._readlink() + if _target is not None: if _filename is not None: try: value, shape, dtype, attrs = self.readvalues() @@ -595,12 +615,12 @@ def _writedata(self, data): with _file as f: f.copy(_path, self[self.nxparent], self.nxpath) else: - self._file.copy(_path, self[self.nxparent], self.nxpath) + self.file.copy(_path, self[self.nxparent], self.nxpath) data._uncopied_data = None elif data._memfile: data._memfile.copy('data', self[self.nxparent], self.nxpath) data._memfile = None - elif data.nxfilemode and data.nxfile.filename != self.filename: + elif data.nxfile and data.nxfile.filename != self.filename: data.nxfile.copy(data.nxpath, self[self.nxparent]) elif data.dtype is not None: if data.nxname not in self[self.nxparent]: @@ -640,6 +660,8 @@ def _writelinks(self, links): # link sources to targets for path, target in links: if path != target and path not in self['/'] and target in self['/']: + if 'target' not in self[target].attrs: + self[target].attrs['target'] = target self[path] = self[target] def readitem(self): @@ -687,12 +709,12 @@ def copyfile(self, input_file): def _rootattrs(self): from datetime import datetime - self._file.attrs['file_name'] = self.filename - self._file.attrs['file_time'] = datetime.now().isoformat() - self._file.attrs['HDF5_Version'] = h5.version.hdf5_version - self._file.attrs['h5py_version'] = h5.version.version + self.file.attrs['file_name'] = self.filename + self.file.attrs['file_time'] = datetime.now().isoformat() + self.file.attrs['HDF5_Version'] = h5.version.hdf5_version + self.file.attrs['h5py_version'] = h5.version.version from .. import __version__ - self._file.attrs['nexusformat_version'] = __version__ + self.file.attrs['nexusformat_version'] = __version__ def update(self, item): self.nxpath = item.nxpath @@ -713,14 +735,7 @@ def update(self, item): self.nxpath = item.nxpath def rename(self, old_path, new_path): - self._file['/'].move(old_path, new_path) - - def _islink(self): - _target, _ = self._readlink() - if _target is not None: - return True - else: - return False + self.file['/'].move(old_path, new_path) def _isexternal(self): try: @@ -732,10 +747,12 @@ def _isexternal(self): @property def filename(self): """File name on disk""" - return self._file.filename + return self.file.filename @property def file(self): + if not self._file.id: + self.open() return self._file @property @@ -746,8 +763,12 @@ def mode(self): def mode(self, mode): if mode == 'rw' or mode == 'r+': self._mode = 'rw' + if self.file.id and self.file.mode == 'r': + self.close() else: self._mode = 'r' + if self.file.id and self.file.mode == 'r+': + self.close() @property def attrs(self): @@ -928,7 +949,7 @@ class AttrDict(dict): def __init__(self, parent=None, attrs={}): super(AttrDict, self).__init__() - self.parent = parent + self._parent = parent self._setattrs(attrs) def _setattrs(self, attrs): @@ -942,26 +963,26 @@ def __setitem__(self, key, value): if value is None: return if isinstance(value, NXattr): - super(AttrDict, self).__setitem__(key, value) + super(AttrDict, self).__setitem__(text(key), value) else: - super(AttrDict, self).__setitem__(key, NXattr(value)) - try: - if self.parent.nxfilemode == 'rw': - with self.parent.nxfile as f: - f.update(self, self.parent.nxpath) - except Exception: - pass + super(AttrDict, self).__setitem__(text(key), NXattr(value)) + if self._parent.nxfilemode == 'rw': + with self._parent.nxfile as f: + f.update(self) def __delitem__(self, key): super(AttrDict, self).__delitem__(key) try: - if self.parent.nxfilemode == 'rw': - with self.parent.nxfile as f: - f.nxpath = self.parent.nxpath + if self._parent.nxfilemode == 'rw': + with self._parent.nxfile as f: + f.nxpath = self._parent.nxpath del f[f.nxpath].attrs[key] except Exception: pass + @property + def nxpath(self): + return self._parent.nxpath class NXattr(object): @@ -1260,7 +1281,7 @@ def rename(self, name): with self.nxfile as f: f.rename(path, self.nxpath) - def save(self, filename=None, mode='w'): + def save(self, filename=None, mode='w-'): """ Saves the NeXus object to a data file. @@ -1307,11 +1328,19 @@ def save(self, filename=None, mode='w'): elif self.nxclass == "NXentry": root = NXroot(self) else: - root = NXroot(NXentry(self)) - nx_file = NXFile(filename, mode) - nx_file.writefile(root) - root = nx_file.readfile() - nx_file.close() + root = NXroot(NXentry(self)) + if mode != 'w': + write_mode = 'w-' + else: + write_mode = 'w' + with NXFile(filename, write_mode) as f: + f.writefile(root) + if mode == 'w' or mode == 'w-': + root._mode = 'rw' + else: + root._mode = mode + root.nxfile = filename + self.set_changed() return root else: raise NeXusError("No output file specified") @@ -1383,11 +1412,12 @@ def nxname(self, value): def nxgroup(self): return self._group - def _getpath(self): - if self.nxgroup is None: - return "" - elif self.nxclass == 'NXroot': + @property + def nxpath(self): + if self.nxclass == 'NXroot': return "/" + elif self.nxgroup is None: + return "" elif isinstance(self.nxgroup, NXroot): return "/" + self.nxname else: @@ -1397,10 +1427,6 @@ def _getpath(self): else: return self.nxname - @property - def nxpath(self): - return self._getpath() - @property def nxroot(self): if self._group is None or isinstance(self, NXroot): @@ -1422,10 +1448,10 @@ def nxentry(self): @property def nxfile(self): if self._file: - return self._file.open() + return self._file _root = self.nxroot if _root._file: - return _root._file.open() + return _root._file elif _root._filename: return NXFile(_root._filename, _root._mode) else: @@ -1433,10 +1459,12 @@ def nxfile(self): @property def nxfilename(self): - try: - return self.nxfile[self.nxpath].file.filename - except Exception: - return '' + if self._filename is not None: + return os.path.abspath(self._filename) + elif self._group is not None: + return self._group.nxfilename + else: + return None @property def nxfilemode(self): @@ -2548,7 +2576,8 @@ def shape(self, value): @property def compression(self): if self.nxfilemode: - self._compression = self.nxfile[self.nxpath].compression + with self.nxfile as f: + self._compression = f[self.nxpath].compression elif self._memfile: self._compression = self._memfile['data'].compression return self._compression @@ -2564,7 +2593,8 @@ def compression(self, value): @property def fillvalue(self): if self.nxfilemode: - self._fillvalue = self.nxfile[self.nxpath].fillvalue + with self.nxfile as f: + self._fillvalue = f[self.nxpath].fillvalue elif self._memfile: self._fillvalue = self._memfile['data'].fillvalue return self._fillvalue @@ -2580,7 +2610,8 @@ def fillvalue(self, value): @property def chunks(self): if self.nxfilemode: - self._chunks = self.nxfile[self.nxpath].chunks + with self.nxfile as f: + self._chunks = f[self.nxpath].chunks elif self._memfile: self._chunks = self._memfile['data'].chunks return self._chunks @@ -2598,7 +2629,8 @@ def chunks(self, value): @property def maxshape(self): if self.nxfilemode: - self._maxshape = self.nxfile[self.nxpath].maxshape + with self.nxfile as f: + self._maxshape = f[self.nxpath].maxshape elif self._memfile: self._maxshape = self._memfile['data'].maxshape return self._maxshape @@ -2621,14 +2653,8 @@ def size(self): @property def safe_attrs(self): - _attrs = copy(self.attrs) - if 'target' in _attrs: - del _attrs['target'] - if 'signal' in _attrs: - del _attrs['signal'] - if 'axes' in _attrs: - del _attrs['axes'] - return _attrs + return {key: self.attrs[key] for key in self.attrs + if (key != 'target' and key != 'signal' and key != 'axes')} @property def reversed(self): @@ -3234,19 +3260,31 @@ def has_key(self, name): """ return self.entries.has_key(name) + def component(self, nxclass): + """ + Finds all child objects that have a particular class. + """ + return [self.entries[i] for i in sorted(self.entries, key=natural_sort) + if self.entries[i].nxclass==nxclass] + def insert(self, value, name='unknown'): """ Adds an attribute to the group. - If it is not a valid NeXus object (NXfield or NXgroup), the attribute - is converted to an NXfield. + If it is not a valid NeXus object, the attribute is converted to an + NXfield. If the object is an internal link within an externally linked + file, the linked object in the external file is copied. """ if isinstance(value, NXobject): if name == 'unknown': name = value.nxname if name in self.entries: raise NeXusError("'%s' already exists in group" % name) - self[name] = value + if (isinstance(value, NXlink) and + value.nxfilename != self.nxfilename): + self[name] = value.nxlink + else: + self[name] = value else: if name in self.entries: raise NeXusError("'%s' already exists in group" % name) @@ -3421,12 +3459,6 @@ def implot(self, **opts): except AttributeError: raise NeXusError("Data cannot be plotted") - def component(self, nxclass): - """ - Finds all child objects that have a particular class. - """ - return [E for _name,E in self.items() if E.nxclass==nxclass] - def signals(self): """ Returns a dictionary of NXfield's containing signal data. @@ -3507,13 +3539,14 @@ def __init__(self, target=None, file=None, name=None, group=None): self._attrs = AttrDict(self) self._entries = {} if isinstance(target, NXobject): - if isinstance(target, NXlink): + if file is not None: + raise NeXusError( + "Use the NXgroup makelink function for external links") + elif isinstance(target, NXlink): raise NeXusError("Cannot link to another NXlink object") if name is None: self._name = target.nxname self._target = target.nxpath - if file is None: - self._filename = target.nxfilename if isinstance(target, NXfield): self.__class__ = NXlinkfield elif isinstance(target, NXgroup): @@ -3559,9 +3592,10 @@ def _str_tree(self, indent=0, attrs=False, recursive=False): return self._str_name(indent=indent) def update(self): - if (self._filename and os.path.exists(self._filename) and - self.nxroot.nxfilemode == 'rw'): - with NXFile(self.nxroot.nxfilename) as f: + root = self.nxroot + filename, mode = root.nxfilename, root.nxfilemode + if (filename is not None and os.path.exists(filename) and mode == 'rw'): + with NXFile(filename) as f: f.update(self) item = f.readitem() if isinstance(item, NXfield): @@ -3591,11 +3625,17 @@ def nxlink(self): return self @property - def nxfilemode(self): + def nxfilename(self): if self._filename is not None: - return 'r' + if os.path.isabs(self._filename): + return self._filename + else: + return os.path.abspath(os.path.join( + os.path.dirname(self.nxroot.nxfilename), self._filename)) + elif self._group is not None: + return self._group.nxfilename else: - return self.nxlink.nxfilemode + return None class NXlinkfield(NXlink, NXfield): @@ -3630,7 +3670,7 @@ def _str_tree(self, indent=0, attrs=False, recursive=False): @property def nxlink(self): try: - if self._filename is not None: + if self.nxfilename != self.nxroot.nxfilename: return self else: return self.nxroot[self._target] @@ -3733,11 +3773,13 @@ def lock(self): """Make the tree readonly""" if self._filename: self._mode = self._file.mode = 'r' + self.set_changed() def unlock(self): """Make the tree modifiable""" if self._filename: self._mode = self._file.mode = 'rw' + self.set_changed() def backup(self, filename=None, dir=None): """Backup the NeXus file. @@ -3803,23 +3845,26 @@ def plottable_data(self): @property def nxfile(self): if self._file: - return self._file.open() + return self._file elif self._filename: return NXFile(self._filename, self._mode) else: return None @nxfile.setter - def nxfile(self, value): - if os.path.exists(value): - self._filename = value - self._file = NXFile(value, self._mode) - root = self._file.readfile() - self._entries = root.entries - self._attrs = root.attrs + def nxfile(self, filename): + if os.path.exists(filename): + self._filename = os.path.abspath(filename) + with NXFile(self._filename, 'r') as f: + root = f.readfile() + self._entries = root._entries + for entry in self._entries: + self._entries[entry]._group = self + self._attrs._setattrs(root.attrs) + self._file = NXFile(self._filename, self._mode) self.set_changed() else: - raise NeXusError("'%s' does not exist") + raise NeXusError("'%s' does not exist" % os.path.abspath(filename)) @property def nxbackup(self): @@ -4694,7 +4739,6 @@ def save(filename, group, mode='w'): else: tree = NXroot(NXentry(group)) with NXFile(filename, mode) as f: - f = NXFile(filename, mode) f.writefile(tree) f.close()