diff --git a/src/nexusformat/nexus/plot.py b/src/nexusformat/nexus/plot.py new file mode 100644 index 0000000..4f59416 --- /dev/null +++ b/src/nexusformat/nexus/plot.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +#----------------------------------------------------------------------------- +# Copyright (c) 2013, NeXpy Development Team. +# +# Author: Paul Kienzle, Ray Osborn +# +# Distributed under the terms of the Modified BSD License. +# +# The full license is in the file COPYING, distributed with this software. +#----------------------------------------------------------------------------- + +""" +This module provides a standard Matplotlib plotting option to the NeXus Python +API +""" +import numpy as np +from nexusformat.nexus import NXfield + + +def _fixaxes(signal, axes): + """ + Remove length-one dimensions from plottable data + """ + shape = list(signal.shape) + while 1 in shape: shape.remove(1) + newaxes = [] + for axis in axes: + if axis.size > 1: newaxes.append(axis) + return signal.nxdata.view().reshape(shape), newaxes + + +def centers(signal, axes): + """ + Return the centers of the axes. + + This works regardless if the axes contain bin boundaries or centers. + """ + def findc(axis, dimlen): + if axis.shape[0] == dimlen+1: + return (axis.nxdata[:-1] + axis.nxdata[1:])/2 + else: + assert axis.shape[0] == dimlen + return axis.nxdata + return [findc(a,signal.shape[i]) for i,a in enumerate(axes)] + + +def label(field): + """ + Return a label for a data field suitable for use on a graph axis. + """ + if hasattr(field, 'long_name'): + return field.long_name + elif hasattr(field, 'units'): + return "%s (%s)"%(field.nxname, field.units) + else: + return field.nxname + + +class PylabPlotter(object): + + """ + Matplotlib plotter class for NeXus data. + """ + + def plot(self, data, fmt, xmin, xmax, ymin, ymax, zmin, zmax, **opts): + """ + Plot the data entry. + + Raises NeXusError if the data cannot be plotted. + """ + try: + import matplotlib.pyplot as plt + except ImportError: + raise NeXusError("Default plotting package (matplotlib) not available.") + + over = False + if "over" in opts.keys(): + if opts["over"]: + over = True + del opts["over"] + + log = logx = logy = False + if "log" in opts.keys(): + if opts["log"]: log = True + del opts["log"] + if "logy" in opts.keys(): + if opts["logy"]: logy = True + del opts["logy"] + if "logx" in opts.keys(): + if opts["logx"]: logx = True + del opts["logx"] + if fmt == '': + fmt = 'o' + + if over: + plt.autoscale(enable=False) + else: + plt.autoscale(enable=True) + plt.clf() + + signal = data.nxsignal + axes = data.nxaxes + errors = data.nxerrors + title = data.nxtitle + + # Provide a new view of the data if there is a dimension of length 1 + if 1 in signal.shape: + data, axes = _fixaxes(signal, axes) + else: + data = signal.nxdata + + # Find the centers of the bins for histogrammed data + axis_data = centers(data, axes) + + #One-dimensional Plot + if len(data.shape) == 1: + if hasattr(signal, 'units'): + if not errors and signal.units == 'counts': + errors = NXfield(np.sqrt(data)) + if errors: + ebars = errors.nxdata + plt.errorbar(axis_data[0], data, ebars, fmt=fmt, **opts) + else: + plt.plot(axis_data[0], data, fmt, **opts) + if not over: + ax = plt.gca() + xlo, xhi = ax.set_xlim(auto=True) + ylo, yhi = ax.set_ylim(auto=True) + if xmin: xlo = xmin + if xmax: xhi = xmax + ax.set_xlim(xlo, xhi) + if ymin: ylo = ymin + if ymax: yhi = ymax + ax.set_ylim(ylo, yhi) + if logx: ax.set_xscale('symlog') + if log or logy: ax.set_yscale('symlog') + plt.xlabel(label(axes[0])) + plt.ylabel(label(signal)) + plt.title(title) + + #Two dimensional plot + else: + from matplotlib.image import NonUniformImage + from matplotlib.colors import LogNorm, Normalize + + if len(data.shape) > 2: + slab = [slice(None), slice(None)] + for _dim in data.shape[2:]: + slab.append(0) + data = data[slab].view().reshape(data.shape[:2]) + print "Warning: Only the top 2D slice of the data is plotted" + + x = axis_data[1] + y = axis_data[0] + if not zmin: + zmin = np.nanmin(data) + if not zmax: + zmax = np.nanmax(data) + + if log: + zmin = max(zmin, 0.01) + zmax = max(zmax, 0.01) + opts["norm"] = LogNorm(zmin, zmax) + else: + opts["norm"] = Normalize(zmin, zmax) + + ax = plt.gca() + extent = (x[0],x[-1],y[0],y[-1]) + im = NonUniformImage(ax, extent=extent, origin=None, **opts) + im.set_data(x, y, data) + im.get_cmap().set_bad('k', 1.0) + ax.images.append(im) + xlo, xhi = ax.set_xlim(x[0], x[-1]) + ylo, yhi = ax.set_ylim(y[0], y[-1]) + if xmin: + xlo = xmin + else: + xlo = x[0] + if xmax: + xhi = xmax + else: + xhi = x[-1] + if ymin: + yhi = ymin + else: + yhi = y[0] + if ymax: + yhi = ymax + else: + yhi = y[-1] + ax.set_xlim(xlo, xhi) + ax.set_ylim(ylo, yhi) + plt.xlabel(label(axes[0])) + plt.ylabel(label(axes[1])) + plt.title(title) + plt.colorbar(im) + + plt.gcf().canvas.draw_idle() + plt.ion() + plt.show() + + +plotview = PylabPlotter() + diff --git a/src/nexusformat/nexus/tree.py b/src/nexusformat/nexus/tree.py index 7a9b893..d61f2c7 100644 --- a/src/nexusformat/nexus/tree.py +++ b/src/nexusformat/nexus/tree.py @@ -997,7 +997,7 @@ def walk(self): if False: yield - def dir(self,attrs=False,recursive=False): + def dir(self, attrs=False, recursive=False): """ Prints the object directory. @@ -1006,7 +1006,7 @@ def dir(self,attrs=False,recursive=False): displayed. If 'recursive' is True, the contents of child groups are also displayed. """ - print self._str_tree(attrs=attrs,recursive=recursive) + print self._str_tree(attrs=attrs, recursive=recursive) @property def tree(self): @@ -1017,7 +1017,7 @@ def tree(self): It invokes the 'dir' method with both 'attrs' and 'recursive' set to True. """ - return self._str_tree(attrs=True,recursive=True) + return self._str_tree(attrs=True, recursive=True) def rename(self, name): if self.nxfilemode == 'r': @@ -2262,7 +2262,10 @@ def plot(self, fmt='', xmin=None, xmax=None, ymin=None, ymax=None, Raises NeXusError if the data could not be plotted. """ - from nexpy.gui.plotview import plotview + try: + from nexpy.gui.plotview import plotview + except ImportError: + from nexusformat.nexus.plot import plotview if self.is_plottable(): if 'axes' in self.attrs.keys(): @@ -3042,17 +3045,12 @@ def _set_signal(self, signal): The argument should be a valid NXfield within the group. """ current_signal = self._signal() - if current_signal is not None: - if current_signal is not signal: - current_signal.attrs['signal'] = 2 - if 'axes' in self.attrs and 'axes' not in current_signal.attrs: - current_signal.attrs['axes'] = self.attrs['axes'] - if 'axes' not in self.attrs and 'axes' in current_signal.attrs: - self.attrs['signal'] = current_signal.attrs['axes'] + if current_signal is not None and current_signal is not signal: + if 'signal' in current_signal.attrs: + del current_signal.attrs['signal'] self.attrs['signal'] = signal.nxname if signal.nxname not in self: self[signal.nxname] = signal - self[signal.nxname].attrs['signal'] = 1 return self[signal.nxname] def _axes(self): @@ -3168,8 +3166,13 @@ def plot(self, fmt='', xmin=None, xmax=None, ymin=None, ymax=None, Raises NeXusError if the data could not be plotted. """ - from nexpy.gui.plotview import plotview - + try: + from nexpy.gui.plotview import plotview + if plotview is None: + raise ImportError + except ImportError: + from nexusformat.nexus.plot import plotview + data = self if self.nxclass == "NXroot": try: