Skip to content

Commit

Permalink
Merge pull request #8 from rayosborn/fix-projections
Browse files Browse the repository at this point in the history
Fixes problems caused by different type-checking behavior in Python 2 and 3.
  • Loading branch information
rayosborn committed Feb 16, 2016
2 parents c593a64 + 9824044 commit ad26a32
Showing 1 changed file with 39 additions and 34 deletions.
73 changes: 39 additions & 34 deletions src/nexusformat/nexus/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,11 @@
from __future__ import (absolute_import, division, print_function)
import six

from copy import copy, deepcopy
import numbers
import os
import re
import sys
from copy import copy, deepcopy

import numpy as np
import h5py as h5
Expand Down Expand Up @@ -375,7 +376,7 @@ def __enter__(self):
return self.open()

def __exit__(self, *args):
self._file.close()
self.close()

def get(self, *args, **kwds):
return self._file.get(*args, **kwds)
Expand Down Expand Up @@ -501,7 +502,7 @@ def _writeattrs(self, attrs):
If no group or data object is open, the file attributes are returned.
"""
for name, value in attrs.iteritems():
for name, value in attrs.items():
self[self.nxpath].attrs[name] = value.nxdata

def _writedata(self, data):
Expand Down Expand Up @@ -556,9 +557,8 @@ def _writedata(self, data):
maxshape=data._maxshape,
fillvalue = data._fillvalue)
try:
value = data.nxdata
if value is not None:
self[self.nxpath][()] = value
if data._value is not None:
self[self.nxpath][()] = data._value
except NeXusError:
pass
self._writeattrs(data.attrs)
Expand Down Expand Up @@ -1214,7 +1214,7 @@ def set_unchanged(self, recursive=False):
self._changed = False

def _getclass(self):
return self._class
return text(self._class)

def _setclass(self, class_):
class_ = globals()[text(class_)]
Expand All @@ -1224,13 +1224,13 @@ def _setclass(self, class_):
self.update()

def _getname(self):
return self._name
return text(self._name)

def _setname(self, value):
if self.nxgroup:
self.nxgroup._entries[value] = self.nxgroup._entries[self._name]
del self.nxgroup._entries[self._name]
self._name = str(value)
self._name = text(value)
self.set_changed()

def _getgroup(self):
Expand Down Expand Up @@ -1562,7 +1562,7 @@ def __init__(self, value=None, name='field', dtype=None, shape=(), group=None,
self._dtype = np.dtype(dtype)
except Exception:
raise NeXusError("Invalid data type: %s" % dtype)
if isinstance(shape, int):
if isinstance(shape, numbers.Integral):
shape = [shape]
self._shape = tuple(shape)
# Append extra keywords to the attribute list
Expand Down Expand Up @@ -1659,7 +1659,7 @@ def __getitem__(self, idx):
real-space slicing should only be used on monotonically increasing (or
decreasing) one-dimensional arrays.
"""
idx = convert_index(idx,self)
idx = convert_index(idx, self)
if len(self) == 1:
result = self
elif self._value is None:
Expand Down Expand Up @@ -1923,7 +1923,7 @@ def index(self, value, max=False):
idx = idx - 1
except IndexError:
pass
return np.clip(idx, 0, len(self.nxdata)-1)
return int(np.clip(idx, 0, len(self.nxdata)-1))

def __array__(self):
"""
Expand Down Expand Up @@ -1978,7 +1978,8 @@ def __eq__(self, other):
Returns true if the values of the NXfield are the same.
"""
if isinstance(other, NXfield):
if isinstance(self.nxdata, np.ndarray) and isinstance(other.nxdata, np.ndarray):
if isinstance(self.nxdata, np.ndarray) and \
isinstance(other.nxdata, np.ndarray):
return all(self.nxdata == other.nxdata)
else:
return self.nxdata == other.nxdata
Expand All @@ -1990,7 +1991,8 @@ def __ne__(self, other):
Returns true if the values of the NXfield are not the same.
"""
if isinstance(other, NXfield):
if isinstance(self.nxdata, np.ndarray) and isinstance(other.nxdata, np.ndarray):
if isinstance(self.nxdata, np.ndarray) and \
isinstance(other.nxdata, np.ndarray):
return any(self.nxdata != other.nxdata)
else:
return self.nxdata != other.nxdata
Expand Down Expand Up @@ -2225,13 +2227,13 @@ def __unicode__(self):
return u""

def _str_value(self,indent=0):
v = str(self)
v = text(self)
if '\n' in v:
v = '\n'.join([(" "*indent)+s for s in v.split('\n')])
return v

def _str_tree(self, indent=0, attrs=False, recursive=False):
dims = 'x'.join([str(n) for n in self.shape])
dims = 'x'.join([text(n) for n in self.shape])
s = text(self)
if self.dtype == string_dtype:
s = repr(s)
Expand Down Expand Up @@ -2318,9 +2320,9 @@ def _title(self):
parent = self.nxgroup
if parent:
if 'title' in parent:
return str(parent.title)
return text(parent.title)
elif parent.nxgroup and 'title' in parent.nxgroup:
return str(parent.nxgroup.title)
return text(parent.nxgroup.title)
else:
if self.nxroot.nxname != '' and self.nxroot.nxname != 'root':
return (self.nxroot.nxname + '/' + self.nxpath.lstrip('/')).rstrip('/')
Expand Down Expand Up @@ -2377,7 +2379,7 @@ def _setdtype(self, value):
self._value = np.asarray(self._value, dtype=self._dtype)

def _getshape(self):
return self._shape
return tuple([int(i) for i in self._shape])

def _setshape(self, value):
if self.nxfilemode == 'r':
Expand Down Expand Up @@ -3091,15 +3093,15 @@ def sum(self, axis=None):
if axis is None:
return self.nxsignal.sum()
else:
if isinstance(axis, int):
if isinstance(axis, numbers.Integral):
axis = [axis]
axis = tuple(axis)
signal = NXfield(self.nxsignal.sum(axis), name=self.nxsignal.nxname,
attrs=self.nxsignal.safe_attrs)
axes = self.nxaxes
averages = []
for ax in axis:
summedaxis = axes.pop(ax)
summedaxis = deepcopy(axes.pop(ax))
summedaxis.minimum = summedaxis.nxdata[0]
summedaxis.maximum = summedaxis.nxdata[-1]
averages.append(NXfield(
Expand Down Expand Up @@ -3228,6 +3230,7 @@ def _getentries(self):
nxtitle = property(_title, "Property: Group title")
entries = property(_getentries,doc="Property: NeXus objects within group")


class NXlink(NXobject):

"""
Expand Down Expand Up @@ -3761,7 +3764,7 @@ def __setitem__(self, idx, value):
if is_text(idx):
NXgroup.__setitem__(self, idx, value)
elif self.nxsignal is not None:
if isinstance(idx, int) or isinstance(idx, slice):
if isinstance(idx, numbers.Integral) or isinstance(idx, slice):
axes = self.nxaxes
idx = convert_index(idx, axes[0])
self.nxsignal[idx] = value
Expand Down Expand Up @@ -3914,7 +3917,7 @@ def project(self, axes, limits):
This assumes that the data is at least two-dimensional.
"""
if not isinstance(axes, list):
if not isinstance(axes, list) and not isinstance(axes, tuple):
axes = [axes]
if len(limits) < len(self.nxsignal.shape):
raise NeXusError("Too few limits specified")
Expand All @@ -3926,7 +3929,7 @@ def project(self, axes, limits):
result = self[idx]
idx, slab_axes = list(idx), list(projection_axes)
for slab_axis in slab_axes:
if isinstance(idx[slab_axis], int):
if isinstance(idx[slab_axis], numbers.Integral):
idx.pop(slab_axis)
projection_axes.pop(projection_axes.index(slab_axis))
for i in range(len(projection_axes)):
Expand All @@ -3942,7 +3945,7 @@ def project(self, axes, limits):
return result

def slab(self, idx):
if (isinstance(idx, int) or isinstance(idx, float) or isinstance(idx, slice)):
if (isinstance(idx, numbers.Real) or isinstance(idx, slice)):
idx = [idx]
signal = self.nxsignal
axes = self.nxaxes
Expand Down Expand Up @@ -4018,7 +4021,7 @@ def plot(self, fmt='', xmin=None, xmax=None, ymin=None, ymax=None,
if plotview is None:
raise ImportError
except ImportError:
from nexusformat.nexus.plot import plotview
from .plot import plotview

# Check there is a plottable signal
if self.nxsignal is None:
Expand Down Expand Up @@ -4062,8 +4065,8 @@ def _signal(self):
if 'signal' in self.attrs and self.attrs['signal'] in self:
return self[self.attrs['signal']]
for obj in self.values():
if 'signal' in obj.attrs and str(obj.signal) == '1':
if isinstance(self[obj.nxname],NXlink):
if 'signal' in obj.attrs and text(obj.signal) == '1':
if isinstance(self[obj.nxname], NXlink):
return self[obj.nxname].nxlink
else:
return self[obj.nxname]
Expand Down Expand Up @@ -4115,7 +4118,7 @@ def _set_axes(self, axes):
The argument should be a list of valid NXfields within the group.
"""
if not isinstance(axes, list):
if not isinstance(axes, list) and not isinstance(axes, tuple):
axes = [axes]
for axis in axes:
if axis.nxname not in self:
Expand Down Expand Up @@ -4254,16 +4257,17 @@ def convert_index(idx, axis):
if len(axis) == 1:
idx = 0
elif isinstance(idx, slice) and \
(idx.start is None or isinstance(idx.start, int)) and \
(idx.stop is None or isinstance(idx.stop, int)):
(idx.start is None or isinstance(idx.start, numbers.Integral)) and \
(idx.stop is None or isinstance(idx.stop, numbers.Integral)):
if idx.start is not None and idx.stop is not None:
if idx.stop == idx.start or idx.stop == idx.start + 1:
idx = idx.start
elif isinstance(idx, slice):
if isinstance(idx.start, NXfield) and isinstance(idx.stop, NXfield):
idx = slice(idx.start.nxdata, idx.stop.nxdata)
if ((axis.reversed and idx.start < idx.stop) or
(not axis.reversed and idx.start > idx.stop)):
if (idx.start is not None and idx.stop is not None and
((axis.reversed and idx.start < idx.stop) or
(not axis.reversed and idx.start > idx.stop))):
idx = slice(idx.stop, idx.start)
if idx.start is None:
start = None
Expand All @@ -4279,7 +4283,8 @@ def convert_index(idx, axis):
idx = start
else:
idx = slice(start, stop)
elif isinstance(idx, float):
elif not isinstance(idx, numbers.Integral) and \
isinstance(idx, numbers.Real):
idx = axis.index(idx)
return idx

Expand Down

0 comments on commit ad26a32

Please sign in to comment.