Skip to content

Commit

Permalink
add: edd.SliceNDdataReader
Browse files Browse the repository at this point in the history
  • Loading branch information
keara-soloway committed Oct 16, 2024
1 parent 9f1e9b2 commit d02bfbf
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHAP/edd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
SetupNXdataReader,
UpdateNXdataReader,
NXdataSliceReader,
SliceNXdataReader,
)
# from CHAP.edd.writer import
53 changes: 53 additions & 0 deletions CHAP/edd/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,59 @@ def read(self, filename, dataset_id, detectors=None):
return {'coords': coords, 'signals': signals,
'attrs': attrs, 'data_points': data_points}

class SliceNXdataReader(Reader):
"""A reader class to load and slice an NXdata field from a NeXus
file. This class reads EDD (Energy Dispersive Diffraction) data
from an NXdata group and slices all fields according to the
provided slicing parameters.
"""
def read(self, filename, scan_number, inputdir=None):
"""Reads an NXdata group from a NeXus file and slices the
fields within it based on the provided scan number.
:param filename: The name of the NeXus file to read.
:type filename: str
:param scan_number: The scan number to use for slicing the
data.
:type scan_number: int
:param inputdir: The directory containing the input file,
defaults to None.
:type inputdir: str, optional
:return: The root object of the NeXus file with sliced NXdata
fields.
:rtype: NXroot
:raises ValueError: If no NXdata group is found in the file.
"""
import os
import numpy as np
from nexusformat.nexus import NXentry, NXfield

from CHAP.common import NexusReader
from CHAP.utils.general import nxcopy

reader = NexusReader()
nxroot = nxcopy(reader.read(os.path.join(inputdir, filename)))
nxdata = None
for nxname, nxobject in nxroot.items():
if isinstance(nxobject, NXentry):
nxdata = nxobject.data
if nxdata is None:
msg = 'Could not find NXdata group'
self.logger.error(msg)
raise ValueError(msg)

indices = np.argwhere(nxdata.SCAN_N.nxdata == scan_number).flatten()
for nxname, nxobject in nxdata.items():
if isinstance(nxobject, NXfield):
nxdata[nxname] = NXfield(
value=nxobject.nxdata[indices],
dtype=nxdata[nxname].dtype,
attrs=nxdata[nxname].attrs,
)

return nxroot

class UpdateNXdataReader(Reader):
"""Companion to `edd.SetupNXdataReader` and
Expand Down

0 comments on commit d02bfbf

Please sign in to comment.