diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index faf2af2..f826b69 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -27,8 +27,10 @@ def get_axes(nxdata, skip_axes=None): skip_axes = [] if 'unstructured_axes' in nxdata.attrs: axes = nxdata.attrs['unstructured_axes'] + elif 'axes' in nxdata.attrs: + axes = nxdata.attrs.['axes'] else: - axes = nxdata.attrs['axes'] + return [] if isinstance(axes, str): axes = [axes] return [str(a) for a in axes if a not in skip_axes] @@ -2614,27 +2616,36 @@ def add_points(nxroot, points, logger=None): raise RuntimeError( 'Unable to find detector data in strainanalysis object') axes = get_axes(nxdata_detectors[0], skip_axes=['energy']) - coords = np.asarray([nxdata_detectors[0][a].nxdata for a in axes]).T - - def get_matching_indices(all_coords, point_coords, decimals=None): - if isinstance(decimals, int): - all_coords = np.round(all_coords, decimals=decimals) - point_coords = np.round(point_coords, decimals=decimals) - coords_match = np.all(all_coords == point_coords, axis=1) - index = np.where(coords_match)[0] - return index - - # FIX: can we round to 3 decimals right away in general? - # FIX: assumes points contains a sorted and continous slice of updates - i_0 = get_matching_indices( - coords, np.asarray([points[0][a] for a in axes]), decimals=3)[0] - i_f = get_matching_indices( - coords, np.asarray([points[-1][a] for a in axes]), decimals=3)[0] - slices = {k: np.asarray([p[k] for p in points]) for k in points[0]} - for k, v in slices.items(): - if k not in axes: - logger.info(f'Updating field {k}') - nxprocess[k][i_0:i_f+1] = v + + if len(axes): + coords = np.asarray( + [nxdata_detectors[0][a].nxdata for a in axes]).T + + def get_matching_indices(all_coords, point_coords, decimals=None): + if isinstance(decimals, int): + all_coords = np.round(all_coords, decimals=decimals) + point_coords = np.round(point_coords, decimals=decimals) + coords_match = np.all(all_coords == point_coords, axis=1) + index = np.where(coords_match)[0] + return index + + # FIX: can we round to 3 decimals right away in general? + # FIX: assumes points contains a sorted and continous + # slice of updates + i_0 = get_matching_indices( + coords, + np.asarray([points[0][a] for a in axes]), decimals=3)[0] + i_f = get_matching_indices( + coords, + np.asarray([points[-1][a] for a in axes]), decimals=3)[0] + slices = {k: np.asarray([p[k] for p in points]) for k in points[0]} + for k, v in slices.items(): + if k not in axes: + logger.info(f'Updating field {k}') + nxprocess[k][i_0:i_f+1] = v + else: + for k, v in points[0].items(): + nxprocess[k].nxdata = v # Add the summed intensity for each detector for nxdata in nxdata_detectors: @@ -3208,9 +3219,13 @@ def _get_sum_axes_data(self, nxdata, detector_id, sum_axes=True): if not isinstance(sum_axes, list): if sum_axes and 'fly_axis_labels' in nxdata.attrs: sum_axes = nxdata.attrs['fly_axis_labels'] + if isinstance(sum_axes, str): + sum_axes = [sum_axes] else: sum_axes = [] - axes = [a for a in get_axes(nxdata) if a not in sum_axes] + axes = get_axes(nxdata, skip_axes=sum_axes) + if not len(axes): + return NXdata(NXfield([np.mean(data, axis=0)], 'detector_data')) dims = np.asarray([nxdata[a].nxdata for a in axes], dtype=np.float64).T sum_indices = [] unique_points = [] @@ -3390,9 +3405,12 @@ def _strain_analysis(self, strain_analysis_config): # Setup the points list with the map axes values nxdata_ref = self._nxdata_detectors[0] axes = get_axes(nxdata_ref) - points = [ - {a: nxdata_ref[a].nxdata[i] for a in axes} - for i in range(nxdata_ref[axes[0]].size)] + if len(axes): + points = [ + {a: nxdata_ref[a].nxdata[i] for a in axes} + for i in range(nxdata_ref[axes[0]].size)] + else: + points = [{}] # Loop over the detectors to fill in the nxprocess for energies, mask, mean_data, nxdata, detector in zip(