Skip to content

Commit

Permalink
fix: minor bugs in common.NXdataReader and common.NXfieldReader
Browse files Browse the repository at this point in the history
  • Loading branch information
keara-soloway committed Mar 6, 2024
1 parent 79d5212 commit f09baa7
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions CHAP/common/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,15 @@ def read(self, filename, nxpath='/'):

class NXdataReader(Reader):
"""Reader for constructing an NXdata object from components"""
def read(self, name, nxfields, signal_name, axes_names, attrs={}):
def read(self, name, nxfield_params, signal_name, axes_names, attrs={},
inputdir='.'):
"""Return a basic NXdata object constructed from components.
:param name: The name of the NXdata group.
:type name: str
:param nxfields: List of sets of parameters for `NXfieldReader`
:param nxfield_params: List of sets of parameters for `NXfieldReader`
specifying the NXfields belonging to the NXdata.
:type nxfields: list[dict]
:type nxfield_params: list[dict]
:param signal_name: Name of the signal for the NXdata (must be one
of the names of the NXfields indicated in `nxfields`)
:type signal: str
Expand All @@ -235,15 +236,17 @@ def read(self, name, nxfields, signal_name, axes_names, attrs={}):
:param attrs: Optional dictionary of additional attributes for
the NXdata
:type attrs: dict
:param inputdir: Input directory to use for `NXfieldReader`s,
defaults to `"."`
:type inputdir: str
:returns: A new NXdata object
:rtype: nexusformat.nexus.NXdata
"""
# Third party modules
from nexusformat.nexus import NXdata

# Read in NXfields
nxfields = [NXfieldReader().read(**nxfield_params)
for nxfield_params in nxfields]
nxfields = [NXfieldReader().read(**params, inputdir=inputdir)
for params in nxfield_params]
nxfields = {nxfield.nxname: nxfield for nxfield in nxfields}

# Get signal NXfield
Expand All @@ -259,11 +262,6 @@ def read(self, name, nxfields, signal_name, axes_names, attrs={}):
# Get axes NXfield(s)
if isinstance(axes_names, str):
axes_names = [axes_names]
if len(axes_names) != nxsignal.ndim:
raise ValueError(
'`axes_names` must contain the same number of entries as the '
+ 'number of dimensions of the signal NXfield '
+ 'f({nxsignal.ndim}).')
try:
nxaxes = [nxfields[axis_name] for axis_name in axes_names]
except:
Expand All @@ -281,13 +279,12 @@ def read(self, name, nxfields, signal_name, axes_names, attrs={}):
result = NXdata(signal=nxsignal, axes=nxaxes, name=name, attrs=attrs)
self.logger.info(result.tree)
return result
#return NXdata(signal=nxsignal, axes=nxaxes, name=name, attrs=attrs)


class NXfieldReader(Reader):
"""Reader for an NXfield with options to modify certain attributes."""
def read(self, filename, nxpath, nxname=None, update_attrs=None,
slice_params=None):
slice_params=None, inputdir='.'):
"""Return a copy of the indicated NXfield from the file. Name
and attributes of the returned copy may be modified with the
`nxname` and `update_attrs` keyword arguments.
Expand Down Expand Up @@ -316,9 +313,12 @@ def read(self, filename, nxpath, nxname=None, update_attrs=None,
attributes optionally modified).
:rtype: nexusformat.nexus.NXfield
"""
import os
from nexusformat.nexus import nxload, NXfield
from CHAP.utils.general import nxcopy

if not os.path.isabs(filename):
filename = os.path.join(inputdir, filename)
nxroot = nxload(filename)
nxfield = nxroot[nxpath]

Expand Down

0 comments on commit f09baa7

Please sign in to comment.