Skip to content

Commit

Permalink
fix: load data_points from pipeline data in common.SetupNXdataProcessor
Browse files Browse the repository at this point in the history
Also: add ability to specify data_types for all fields in common.SetupNXdataProcessor
  • Loading branch information
keara-soloway committed Oct 11, 2024
1 parent a04f187 commit b285cc0
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions CHAP/common/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2592,12 +2592,13 @@ def process(self, data, nxname='data',
self.coords = coords
self.signals = signals
self.attrs = attrs
self.data_points = data_points
try:
setup_params = self.unwrap_pipelinedata(data)[0]
except:
setup_params = None
if isinstance(setup_params, dict):
for a in ('coords', 'signals', 'attrs'):
for a in ('coords', 'signals', 'attrs', 'data_points'):
setup_param = setup_params.get(a)
if not getattr(self, a) and setup_param:
self.logger.info(f'Using input data from pipeline for {a}')
Expand All @@ -2609,19 +2610,19 @@ def process(self, data, nxname='data',
self.logger.warning('Ignoring all input data from pipeline')
self.shape = tuple(len(c['values']) for c in self.coords)
self.extra_nxfields = extra_nxfields
self._data_points = []
self.duplicates = duplicates
self.init_nxdata()
if data_points is not None:
for d in data_points:

if self.data_points is not None:
for d in self.data_points:
self.add_data_point(d)

return self.nxdata

def add_data_point(self, data_point):
"""Add a data point to this dataset.
1. Validate `data_point`.
2. Append `data_point` to `self._data_points`.
2. Append `data_point` to `self.data_points`.
3. Update signal `NXfield`s in `self.nxdata`.
:param data_point: Data point defining a point in the
Expand All @@ -2630,13 +2631,12 @@ def add_data_point(self, data_point):
:type data_point: dict[str, object]
:returns: None
"""
self.logger.info(f'Adding data point no. {len(self._data_points)}')
self.logger.info(f'Adding data point no. {len(self.data_points)}')
self.logger.debug(f'New data point: {data_point}')
valid, msg = self.validate_data_point(data_point)
if not valid:
self.logger.error(f'Cannot add data point: {msg}')
else:
self._data_points.append(data_point)
self.update_nxdata(data_point)

def validate_data_point(self, data_point):
Expand Down Expand Up @@ -2689,11 +2689,13 @@ def init_nxdata(self):
axes = tuple(NXfield(
value=c['values'],
name=c['name'],
attrs=c.get('attrs')) for c in self.coords)
attrs=c.get('attrs'),
dtype=c.get('dtype', 'float64')) for c in self.coords)
entries = {s['name']: NXfield(
value=np.full((*self.shape, *s['shape']), 0),
name=s['name'],
attrs=s.get('attrs')) for s in self.signals}
attrs=s.get('attrs'),
dtype=s.get('dtype', 'float64')) for s in self.signals}
extra_nxfields = [NXfield(**params) for params in self.extra_nxfields]
extra_nxfields = {f.nxname: f for f in extra_nxfields}
entries.update(extra_nxfields)
Expand Down Expand Up @@ -2846,6 +2848,12 @@ def process(self, data, nxfilename, nxdata_path, data_points=None,
try:
nxfile.writevalue(
os.path.join(nxdata_path, k), np.asarray(v), index)
# self.logger.debug(
# f'Wrote to {os.path.join(nxdata_path, k)}'
# + f' in {nxfilename}'
# + f' at index {index} '
# + f' value: {np.asarray(v)}'
# + f' (type: {type(v)})')
except Exception as exc:
self.logger.error(
f'Error updating signal {k} for new data point '
Expand Down

0 comments on commit b285cc0

Please sign in to comment.