Skip to content

Commit

Permalink
feat: added EDD scan_type 4 strainanalysis (oversampling)
Browse files Browse the repository at this point in the history
  • Loading branch information
rolfverberg committed Feb 17, 2024
1 parent 676a3e8 commit 9e49b56
Show file tree
Hide file tree
Showing 5 changed files with 420 additions and 94 deletions.
90 changes: 69 additions & 21 deletions CHAP/edd/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,29 +762,21 @@ def add_calibration(self, calibration):
setattr(self, field, getattr(calibration, field))
self.calibration_bin_ranges = calibration.include_bin_ranges

def get_tth_map(self, map_config, sum_fly_axes=False):
"""Return a map of 2&theta values to use -- may vary at each
def get_tth_map(self, map_shape):
"""Return the map of 2&theta values to use -- may vary at each
point in the map.
:param map_config: The map configuration with which the
returned map of 2&theta values will be used.
:type map_config: CHAP.common.models.map.MapConfig
:param map_shape: The shape of the suplied 2&theta map.
:return: Map of 2&theta values.
:rtype: np.ndarray
"""
if getattr(self, 'tth_map', None) is not None:
raise ValueError('Need to validate the shape')
if self.tth_map.shape != map_shape:
raise ValueError(
'Invalid "tth_map" field shape '
f'{self.tth_map.shape} (expected {map_shape})')
return self.tth_map
if not isinstance(sum_fly_axes, bool):
raise ValueError(
f'Invalid sum_fly_axes parameter ({sum_fly_axes})')
if not sum_fly_axes:
return np.full(map_config.shape, self.tth_calibrated)
map_shape = map_config.shape
fly_axis_labels = map_config.attrs.get('fly_axis_labels')
tth_shape = [map_shape[i] for i, dim in enumerate(map_config.dims)
if dim not in fly_axis_labels]
return np.full(tth_shape, self.tth_calibrated)
return np.full(map_shape, self.tth_calibrated)

def dict(self, *args, **kwargs):
"""Return a representation of this configuration in a
Expand Down Expand Up @@ -842,6 +834,7 @@ class StrainAnalysisConfig(BaseModel):
materials: list[MaterialConfig]
flux_file: FilePath
sum_fly_axes: Optional[StrictBool]
oversampling: Optional[dict] = {'num': 10}

_parfile: Optional[ParFile]

Expand Down Expand Up @@ -936,6 +929,43 @@ def validate_sum_fly_axes(cls, value, values):
value = True
return value

@validator('oversampling', always=True)
def validate_oversampling(cls, value, values):
"""Validate the oversampling field.
:param value: Field value to validate (`oversampling`).
:type value: bool
:param values: Dictionary of validated class field values.
:type values: dict
:return: The validated value for oversampling.
:rtype: bool
"""
# Local modules
from CHAP.utils.general import is_int

if 'start' in value and not is_int(value['start'], ge=0):
raise ValueError('Invalid "start" parameter in "oversampling" '
f'field ({value["start"]})')
if 'end' in value and not is_int(value['end'], gt=0):
raise ValueError('Invalid "end" parameter in "oversampling" '
f'field ({value["end"]})')
if 'width' in value and not is_int(value['width'], gt=0):
raise ValueError('Invalid "width" parameter in "oversampling" '
f'field ({value["width"]})')
if 'stride' in value and not is_int(value['stride'], gt=0):
raise ValueError('Invalid "stride" parameter in "oversampling" '
f'field ({value["stride"]})')
if 'num' in value and not is_int(value['num'], gt=0):
raise ValueError('Invalid "num" parameter in "oversampling" '
f'field ({value["num"]})')
if 'mode' in value and 'mode' not in ('valid', 'full'):
raise ValueError('Invalid "mode" parameter in "oversampling" '
f'field ({value["mode"]})')
if not ('width' in value or 'stride' in value or 'num' in value):
raise ValueError('Invalid input parameters, specify at least one '
'of "width", "stride" or "num"')
return value

def mca_data(self, detector=None, map_index=None):
"""Get MCA data for a single or multiple detector elements.
Expand Down Expand Up @@ -972,11 +1002,29 @@ def mca_data(self, detector=None, map_index=None):
mca_data = np.reshape(
mca_data, (*self.map_config.shape, len(mca_data[0])))
if self.sum_fly_axes and fly_axis_labels:
sum_indices = []
for axis in fly_axis_labels:
sum_indices.append(self.map_config.dims.index(axis))
return np.sum(mca_data, tuple(sorted(sum_indices)))

scan_type = self.map_config.attrs['scan_type']
if scan_type == 3:
sum_indices = []
for axis in fly_axis_labels:
sum_indices.append(self.map_config.dims.index(axis))
return np.sum(mca_data, tuple(sorted(sum_indices)))
elif scan_type == 4:
# Local modules
from CHAP.edd.utils import get_rolling_sum_spectra

return get_rolling_sum_spectra(
mca_data,
self.map_config.dims.index(fly_axis_labels[0]),
self.oversampling.get('start', 0),
self.oversampling.get('end'),
self.oversampling.get('width'),
self.oversampling.get('stride'),
self.oversampling.get('num'),
self.oversampling.get('mode', 'valid'))
else:
raise ValueError(
f'scan_type {scan_type} not implemented yet '
'in StrainAnalysisConfig.mca_data()')
else:
return np.asarray(mca_data)
else:
Expand Down
70 changes: 52 additions & 18 deletions CHAP/edd/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,25 +1592,37 @@ def get_nxroot(self,
select_mask_and_hkls,
)

def linkdims(nxgroup, field_dims=[]):
def linkdims(nxgroup, field_dims=[], oversampling_axis={}):
if isinstance(field_dims, dict):
field_dims = [field_dims]
if map_config.map_type == 'structured':
axes = deepcopy(map_config.dims)
for dims in field_dims:
axes.append(dims['axes'])
nxgroup.attrs['axes'] = axes
else:
axes = ['map_index']
for dims in field_dims:
axes.append(dims['axes'])
nxgroup.attrs['axes'] = axes
nxgroup.attrs[f'map_index_indices'] = 0
for dim in map_config.dims:
nxgroup.makelink(nxentry.data[dim])
if f'{dim}_indices' in nxentry.data.attrs:
nxgroup.attrs[f'{dim}_indices'] = \
nxentry.data.attrs[f'{dim}_indices']
if dim in oversampling_axis:
bin_name = dim.replace('fly_', 'bin_')
axes[axes.index(dim)] = bin_name
nxgroup[bin_name] = NXfield(
value=oversampling_axis[dim],
units=nxentry.data[dim].units,
attrs={
'long_name':
f'oversampled {nxentry.data[dim].long_name}',
'data_type': nxentry.data[dim].data_type,
'local_name':
f'oversampled {nxentry.data[dim].local_name}'})
else:
nxgroup.makelink(nxentry.data[dim])
if f'{dim}_indices' in nxentry.data.attrs:
nxgroup.attrs[f'{dim}_indices'] = \
nxentry.data.attrs[f'{dim}_indices']
nxgroup.attrs['axes'] = axes
for dims in field_dims:
nxgroup.attrs[f'{dims["axes"]}_indices'] = dims['index']

Expand Down Expand Up @@ -1640,6 +1652,24 @@ def linkdims(nxgroup, field_dims=[]):
self.logger.debug(f'mca_data_summed.shape: {mca_data_summed.shape}')
self.logger.debug(f'effective_map_shape: {effective_map_shape}')

# Check for oversampling axis and create the binned coordinates
oversampling_axis = {}
if (map_config.attrs.get('scan_type') == 4
and strain_analysis_config.sum_fly_axes):
# Local modules
from CHAP.utils.general import rolling_average

fly_axis = map_config.attrs.get('fly_axis_labels')[0]
oversampling = strain_analysis_config.oversampling
oversampling_axis[fly_axis] = rolling_average(
nxdata[fly_axis].nxdata,
start=oversampling.get('start', 0),
end=oversampling.get('end'),
width=oversampling.get('width'),
stride=oversampling.get('stride'),
num=oversampling.get('num'),
mode=oversampling.get('mode', 'valid'))

# Loop over the detectors to perform the strain analysis
for i, detector in enumerate(strain_analysis_config.detectors):

Expand Down Expand Up @@ -1740,7 +1770,8 @@ def linkdims(nxgroup, field_dims=[]):
det_nxdata = nxdetector.data
linkdims(
det_nxdata,
{'axes': 'energy', 'index': len(effective_map_shape)})
{'axes': 'energy', 'index': len(effective_map_shape)},
oversampling_axis=oversampling_axis)
mask = detector.mca_mask()
energies = mca_bin_energies[mask]
det_nxdata.energy = NXfield(value=energies, attrs={'units': 'keV'})
Expand Down Expand Up @@ -1799,7 +1830,9 @@ def linkdims(nxgroup, field_dims=[]):
fit_nxgroup.results = NXdata()
fit_nxdata = fit_nxgroup.results
linkdims(
fit_nxdata, {'axes': 'energy', 'index': len(map_config.shape)})
fit_nxdata,
{'axes': 'energy', 'index': len(map_config.shape)},
oversampling_axis=oversampling_axis)
fit_nxdata.makelink(det_nxdata.energy)
fit_nxdata.best_fit= uniform_best_fit
fit_nxdata.residuals = uniform_residuals
Expand All @@ -1811,12 +1844,12 @@ def linkdims(nxgroup, field_dims=[]):
# fit_nxdata = fit_nxgroup.fit_hkl_centers
# linkdims(fit_nxdata)
for (hkl, center_guess, centers_fit, centers_error,
amplitudes_fit, amplitudes_error, sigmas_fit,
sigmas_error) in zip(
fit_hkls, peak_locations,
uniform_fit_centers, uniform_fit_centers_errors,
uniform_fit_amplitudes, uniform_fit_amplitudes_errors,
uniform_fit_sigmas, uniform_fit_sigmas_errors):
amplitudes_fit, amplitudes_error, sigmas_fit,
sigmas_error) in zip(
fit_hkls, peak_locations,
uniform_fit_centers, uniform_fit_centers_errors,
uniform_fit_amplitudes, uniform_fit_amplitudes_errors,
uniform_fit_sigmas, uniform_fit_sigmas_errors):
hkl_name = '_'.join(str(hkl)[1:-1].split(' '))
fit_nxgroup[hkl_name] = NXparameters()
# Report initial HKL peak centers
Expand Down Expand Up @@ -1944,8 +1977,7 @@ def animate(i):
ani.save(path)
plt.close()

tth_map = detector.get_tth_map(
map_config, strain_analysis_config.sum_fly_axes)
tth_map = detector.get_tth_map(effective_map_shape)
det_nxdata.tth.nxdata = tth_map
nominal_centers = np.asarray(
[get_peak_locations(d0, tth_map) for d0 in fit_ds])
Expand All @@ -1962,7 +1994,9 @@ def animate(i):
fit_nxgroup.results = NXdata()
fit_nxdata = fit_nxgroup.results
linkdims(
fit_nxdata, {'axes': 'energy', 'index': len(map_config.shape)})
fit_nxdata,
{'axes': 'energy', 'index': len(map_config.shape)},
oversampling_axis=oversampling_axis)
fit_nxdata.makelink(det_nxdata.energy)
fit_nxdata.best_fit= unconstrained_best_fit
fit_nxdata.residuals = unconstrained_residuals
Expand Down
2 changes: 1 addition & 1 deletion CHAP/edd/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def add_fly_axis(fly_axis_index):
add_fly_axis(0)
if scan_type in (2, 3, 5):
add_fly_axis(1)
if scan_type in (4, 5):
if scan_type == 5:
scalar_data.append(dict(
label='bin_axis', units='n/a', data_type='smb_par',
name='bin_axis'))
Expand Down
Loading

0 comments on commit 9e49b56

Please sign in to comment.