diff --git a/CHAP/edd/models.py b/CHAP/edd/models.py
index 2589c43..31abc14 100755
--- a/CHAP/edd/models.py
+++ b/CHAP/edd/models.py
@@ -762,29 +762,21 @@ def add_calibration(self, calibration):
             setattr(self, field, getattr(calibration, field))
         self.calibration_bin_ranges = calibration.include_bin_ranges
 
-    def get_tth_map(self, map_config, sum_fly_axes=False):
-        """Return a map of 2&theta values to use -- may vary at each
+    def get_tth_map(self, map_shape):
+        """Return the map of 2&theta values to use -- may vary at each
         point in the map.
 
-        :param map_config: The map configuration with which the
-            returned map of 2&theta values will be used.
-        :type map_config: CHAP.common.models.map.MapConfig
+        :param map_shape: The shape of the returned map of 2&theta
+            values.
+        :type map_shape: tuple
        :return: Map of 2&theta values.
        :rtype: np.ndarray
        """
        if getattr(self, 'tth_map', None) is not None:
-            raise ValueError('Need to validate the shape')
+            if self.tth_map.shape != map_shape:
+                raise ValueError(
+                    'Invalid "tth_map" field shape '
+                    f'{self.tth_map.shape} (expected {map_shape})')
             return self.tth_map
-        if not isinstance(sum_fly_axes, bool):
-            raise ValueError(
-                f'Invalid sum_fly_axes parameter ({sum_fly_axes})')
-        if not sum_fly_axes:
-            return np.full(map_config.shape, self.tth_calibrated)
-        map_shape = map_config.shape
-        fly_axis_labels = map_config.attrs.get('fly_axis_labels')
-        tth_shape = [map_shape[i] for i, dim in enumerate(map_config.dims)
-                     if dim not in fly_axis_labels]
-        return np.full(tth_shape, self.tth_calibrated)
+        return np.full(map_shape, self.tth_calibrated)
 
     def dict(self, *args, **kwargs):
         """Return a representation of this configuration in a
@@ -842,6 +834,7 @@ class StrainAnalysisConfig(BaseModel):
     materials: list[MaterialConfig]
     flux_file: FilePath
     sum_fly_axes: Optional[StrictBool]
+    oversampling: Optional[dict] = {'num': 10}
     _parfile: Optional[ParFile]
 
@@ -936,6 +929,43 @@ def validate_sum_fly_axes(cls, value, values):
             value = True
         return value
 
+    @validator('oversampling', always=True)
+    def validate_oversampling(cls, value, values):
+        """Validate the oversampling field.
+
+        :param value: Field value to validate (`oversampling`).
+        :type value: dict
+        :param values: Dictionary of validated class field values.
+        :type values: dict
+        :return: The validated value for oversampling.
+        :rtype: dict
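+
+        An illustrative example of a valid `oversampling` value,
+        summing spectra in windows of 10 points with a step of 5
+        points: `{'start': 0, 'width': 10, 'stride': 5}`.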
+        """
+        # Local modules
+        from CHAP.utils.general import is_int
+
+        if 'start' in value and not is_int(value['start'], ge=0):
+            raise ValueError('Invalid "start" parameter in "oversampling" '
+                             f'field ({value["start"]})')
+        if 'end' in value and not is_int(value['end'], gt=0):
+            raise ValueError('Invalid "end" parameter in "oversampling" '
+                             f'field ({value["end"]})')
+        if 'width' in value and not is_int(value['width'], gt=0):
+            raise ValueError('Invalid "width" parameter in "oversampling" '
+                             f'field ({value["width"]})')
+        if 'stride' in value and not is_int(value['stride'], gt=0):
+            raise ValueError('Invalid "stride" parameter in "oversampling" '
+                             f'field ({value["stride"]})')
+        if 'num' in value and not is_int(value['num'], gt=0):
+            raise ValueError('Invalid "num" parameter in "oversampling" '
+                             f'field ({value["num"]})')
+        if 'mode' in value and value['mode'] not in ('valid', 'full'):
+            raise ValueError('Invalid "mode" parameter in "oversampling" '
+                             f'field ({value["mode"]})')
+        if not ('width' in value or 'stride' in value or 'num' in value):
+            raise ValueError('Invalid input parameters, specify at least one '
+                             'of "width", "stride" or "num"')
+        return value
+
     def mca_data(self, detector=None, map_index=None):
         """Get MCA data for a single or multiple detector elements.
 
@@ -972,11 +1002,29 @@ def mca_data(self, detector=None, map_index=None):
             mca_data = np.reshape(
                 mca_data, (*self.map_config.shape, len(mca_data[0])))
             if self.sum_fly_axes and fly_axis_labels:
-                sum_indices = []
-                for axis in fly_axis_labels:
-                    sum_indices.append(self.map_config.dims.index(axis))
-                return np.sum(mca_data, tuple(sorted(sum_indices)))
-
+                scan_type = self.map_config.attrs['scan_type']
+                if scan_type == 3:
+                    sum_indices = []
+                    for axis in fly_axis_labels:
+                        sum_indices.append(self.map_config.dims.index(axis))
+                    return np.sum(mca_data, tuple(sorted(sum_indices)))
+                elif scan_type == 4:
+                    # Local modules
+                    from CHAP.edd.utils import get_rolling_sum_spectra
+
+                    return get_rolling_sum_spectra(
+                        mca_data,
+                        self.map_config.dims.index(fly_axis_labels[0]),
+                        self.oversampling.get('start', 0),
+                        self.oversampling.get('end'),
+                        self.oversampling.get('width'),
+                        self.oversampling.get('stride'),
+                        self.oversampling.get('num'),
+                        self.oversampling.get('mode', 'valid'))
+                else:
+                    raise ValueError(
+                        f'scan_type {scan_type} not implemented yet '
+                        'in StrainAnalysisConfig.mca_data()')
             else:
                 return np.asarray(mca_data)
         else:
diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py
index 97aeb2f..91bb27d 100755
--- a/CHAP/edd/processor.py
+++ b/CHAP/edd/processor.py
@@ -1592,25 +1592,37 @@ def get_nxroot(self, select_mask_and_hkls,
             )
 
-        def linkdims(nxgroup, field_dims=[]):
+        def linkdims(nxgroup, field_dims=[], oversampling_axis={}):
             if isinstance(field_dims, dict):
                 field_dims = [field_dims]
             if map_config.map_type == 'structured':
                 axes = deepcopy(map_config.dims)
                 for dims in field_dims:
                     axes.append(dims['axes'])
-                nxgroup.attrs['axes'] = axes
             else:
                 axes = ['map_index']
                 for dims in field_dims:
                     axes.append(dims['axes'])
-                nxgroup.attrs['axes'] = axes
                 nxgroup.attrs[f'map_index_indices'] = 0
             for dim in map_config.dims:
-                nxgroup.makelink(nxentry.data[dim])
-                if f'{dim}_indices' in nxentry.data.attrs:
-                    nxgroup.attrs[f'{dim}_indices'] = \
-                        nxentry.data.attrs[f'{dim}_indices']
+                if dim in oversampling_axis:
+                    bin_name = dim.replace('fly_', 'bin_')
+                    axes[axes.index(dim)] = bin_name
+                    nxgroup[bin_name] = NXfield(
+                        value=oversampling_axis[dim],
+                        units=nxentry.data[dim].units,
+                        attrs={
+                            'long_name':
+
f'oversampled {nxentry.data[dim].long_name}', + 'data_type': nxentry.data[dim].data_type, + 'local_name': + f'oversampled {nxentry.data[dim].local_name}'}) + else: + nxgroup.makelink(nxentry.data[dim]) + if f'{dim}_indices' in nxentry.data.attrs: + nxgroup.attrs[f'{dim}_indices'] = \ + nxentry.data.attrs[f'{dim}_indices'] + nxgroup.attrs['axes'] = axes for dims in field_dims: nxgroup.attrs[f'{dims["axes"]}_indices'] = dims['index'] @@ -1640,6 +1652,24 @@ def linkdims(nxgroup, field_dims=[]): self.logger.debug(f'mca_data_summed.shape: {mca_data_summed.shape}') self.logger.debug(f'effective_map_shape: {effective_map_shape}') + # Check for oversampling axis and create the binned coordinates + oversampling_axis = {} + if (map_config.attrs.get('scan_type') == 4 + and strain_analysis_config.sum_fly_axes): + # Local modules + from CHAP.utils.general import rolling_average + + fly_axis = map_config.attrs.get('fly_axis_labels')[0] + oversampling = strain_analysis_config.oversampling + oversampling_axis[fly_axis] = rolling_average( + nxdata[fly_axis].nxdata, + start=oversampling.get('start', 0), + end=oversampling.get('end'), + width=oversampling.get('width'), + stride=oversampling.get('stride'), + num=oversampling.get('num'), + mode=oversampling.get('mode', 'valid')) + # Loop over the detectors to perform the strain analysis for i, detector in enumerate(strain_analysis_config.detectors): @@ -1740,7 +1770,8 @@ def linkdims(nxgroup, field_dims=[]): det_nxdata = nxdetector.data linkdims( det_nxdata, - {'axes': 'energy', 'index': len(effective_map_shape)}) + {'axes': 'energy', 'index': len(effective_map_shape)}, + oversampling_axis=oversampling_axis) mask = detector.mca_mask() energies = mca_bin_energies[mask] det_nxdata.energy = NXfield(value=energies, attrs={'units': 'keV'}) @@ -1799,7 +1830,9 @@ def linkdims(nxgroup, field_dims=[]): fit_nxgroup.results = NXdata() fit_nxdata = fit_nxgroup.results linkdims( - fit_nxdata, {'axes': 'energy', 'index': len(map_config.shape)}) + fit_nxdata, + {'axes': 'energy', 'index': len(map_config.shape)}, + oversampling_axis=oversampling_axis) fit_nxdata.makelink(det_nxdata.energy) fit_nxdata.best_fit= uniform_best_fit fit_nxdata.residuals = uniform_residuals @@ -1811,12 +1844,12 @@ def linkdims(nxgroup, field_dims=[]): # fit_nxdata = fit_nxgroup.fit_hkl_centers # linkdims(fit_nxdata) for (hkl, center_guess, centers_fit, centers_error, - amplitudes_fit, amplitudes_error, sigmas_fit, - sigmas_error) in zip( - fit_hkls, peak_locations, - uniform_fit_centers, uniform_fit_centers_errors, - uniform_fit_amplitudes, uniform_fit_amplitudes_errors, - uniform_fit_sigmas, uniform_fit_sigmas_errors): + amplitudes_fit, amplitudes_error, sigmas_fit, + sigmas_error) in zip( + fit_hkls, peak_locations, + uniform_fit_centers, uniform_fit_centers_errors, + uniform_fit_amplitudes, uniform_fit_amplitudes_errors, + uniform_fit_sigmas, uniform_fit_sigmas_errors): hkl_name = '_'.join(str(hkl)[1:-1].split(' ')) fit_nxgroup[hkl_name] = NXparameters() # Report initial HKL peak centers @@ -1944,8 +1977,7 @@ def animate(i): ani.save(path) plt.close() - tth_map = detector.get_tth_map( - map_config, strain_analysis_config.sum_fly_axes) + tth_map = detector.get_tth_map(effective_map_shape) det_nxdata.tth.nxdata = tth_map nominal_centers = np.asarray( [get_peak_locations(d0, tth_map) for d0 in fit_ds]) @@ -1962,7 +1994,9 @@ def animate(i): fit_nxgroup.results = NXdata() fit_nxdata = fit_nxgroup.results linkdims( - fit_nxdata, {'axes': 'energy', 'index': len(map_config.shape)}) + fit_nxdata, + 
{'axes': 'energy', 'index': len(map_config.shape)},
+            oversampling_axis=oversampling_axis)
         fit_nxdata.makelink(det_nxdata.energy)
         fit_nxdata.best_fit= unconstrained_best_fit
         fit_nxdata.residuals = unconstrained_residuals
diff --git a/CHAP/edd/reader.py b/CHAP/edd/reader.py
index fc94e7c..f3ae7ce 100755
--- a/CHAP/edd/reader.py
+++ b/CHAP/edd/reader.py
@@ -88,7 +88,7 @@ def add_fly_axis(fly_axis_index):
     add_fly_axis(0)
     if scan_type in (2, 3, 5):
         add_fly_axis(1)
-    if scan_type in (4, 5):
+    if scan_type == 5:
         scalar_data.append(dict(
             label='bin_axis', units='n/a', data_type='smb_par',
             name='bin_axis'))
diff --git a/CHAP/edd/utils.py b/CHAP/edd/utils.py
index 011049c..67dba10 100755
--- a/CHAP/edd/utils.py
+++ b/CHAP/edd/utils.py
@@ -25,6 +25,7 @@ def get_peak_locations(ds, tth):
     return hc / (2. * ds * np.sin(0.5 * np.radians(tth)))
 
+
 def make_material(name, sgnum, lattice_parameters, dmin=0.6):
     """Return a hexrd.material.Material with the given properties.
 
@@ -59,6 +60,7 @@ def make_material(name, sgnum, lattice_parameters, dmin=0.6):
     return material
 
+
 def get_unique_hkls_ds(materials, tth_tol=None, tth_max=None, round_sig=8):
     """Return the unique HKLs and lattice spacings for the given list
     of materials.
 
@@ -107,6 +109,7 @@
     return hkls_unique, ds_unique
 
+
 def select_tth_initial_guess(x, y, hkls, ds, tth_initial_guess=5.0,
                              interactive=False, filename=None):
     """Show a matplotlib figure of a reference MCA spectrum on top of
 
@@ -269,6 +272,7 @@ def confirm(event):
     return tth_new_guess
 
+
 def select_material_params(x, y, tth, materials=[], label='Reference Data',
                            interactive=False, filename=None):
     """Interactively select the lattice parameters and space group for
 
@@ -517,6 +521,7 @@ def confirm(event):
     return new_materials
 
+
 def select_mask_and_hkls(x, y, hkls, ds, tth, preselected_bin_ranges=[],
                          preselected_hkl_indices=[], detector_name=None,
                          ref_map=None, flux_energy_range=None,
                          calibration_bin_ranges=None,
 
@@ -974,6 +979,75 @@ def confirm(event):
     return selected_bin_ranges, selected_hkl_indices
 
+
+def get_rolling_sum_spectra(
+        y, bin_axis, start=0, end=None, width=None, stride=None, num=None,
+        mode='valid'):
+    """
+    Return the rolling sum of the spectra over a specified axis.
+
+    :param y: The map of spectra to sum.
+    :type y: numpy.ndarray
+    :param bin_axis: The axis over which to sum.
+    :type bin_axis: int
+    :param start: First index of the summation range, defaults to `0`.
+    :type start: int, optional
+    :param end: Last index of the summation range, defaults to the
+        size of `bin_axis`.
+    :type end: int, optional
+    :param width: Width of the summation window.
+    :type width: int, optional
+    :param stride: Stride between consecutive summation windows.
+    :type stride: int, optional
+    :param num: Number of summation windows.
+    :type num: int, optional
+    :param mode: Whether to include only complete windows (`'valid'`)
+        or truncated trailing windows as well (`'full'`), defaults to
+        `'valid'`.
+    :type mode: str, optional
+    :return: Rolling sum of the spectra.
+    :rtype: numpy.ndarray
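+
+    As an illustrative example, assuming `y` is a (5, 10, 4096) stack
+    of MCA spectra with the fly axis along dimension 1:
+
+        >>> ry = get_rolling_sum_spectra(y, 1, width=4, stride=2)
+        >>> ry.shape
+        (5, 4, 4096)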
+ """ + y = np.asarray(y) + if not 0 <= bin_axis < y.ndim-1: + raise ValueError(f'Invalid "bin_axis" parameter ({bin_axis})') + size = y.shape[bin_axis] + if not 0 <= start < size: + raise ValueError(f'Invalid "start" parameter ({start})') + if end is None: + end = size + elif not start < end <= size: + raise ValueError('Invalid "start" and "end" combination ' + f'({start} and {end})') + + size = end-start + if stride is None: + if width is None: + width = max(1, int(size/num)) + stride = width + else: + width = max(1, min(width, size)) + if num is None: + stride = width + else: + stride = max(1, int((size-width) / (num-1))) + else: + stride = max(1, min(stride, size-stride)) + if width is None: + width = stride + if mode == 'valid': + num = 1 + max(0, int((size-width) / stride)) + else: + num = int(size/stride) + if num*stride < size: + num += 1 + bin_ranges = [(start+n*stride, min(start+size, start+n*stride+width)) + for n in range(num)] + + y_shape = y.shape + y_ndim = y.ndim + swap_axis = False + if y_ndim > 2 and bin_axis != y_ndim-2: + y = np.swapaxes(y, bin_axis, y_ndim-2) + swap_axis = True + if y_ndim > 3: + map_shape = y.shape[0:y_ndim-2] + y = y.reshape((np.prod(map_shape), *y.shape[y_ndim-2:])) + if y_ndim == 2: + y = np.expand_dims(y, 0) + + ry = np.zeros((y.shape[0], num, y.shape[-1]), dtype=y.dtype) + for dim in range(y.shape[0]): + for n in range(num): + ry[dim, n] = np.sum(y[dim,bin_ranges[n][0]:bin_ranges[n][1]], 0) + + if y_ndim > 3: + ry = np.reshape(ry, (*map_shape, num, y_shape[-1])) + if y_ndim == 2: + ry = np.squeeze(ry) + if swap_axis: + ry = np.swapaxes(ry, bin_axis, y_ndim-2) + + return ry + + def get_spectra_fits(spectra, energies, peak_locations, fit_params): """Return twenty arrays of fit results for the map of spectra provided: uniform centers, uniform center errors, uniform @@ -1005,10 +1079,13 @@ def get_spectra_fits(spectra, energies, peak_locations, fit_params): numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray] """ - from CHAP.utils.fit import FitMap + from CHAP.utils.fit import Fit, FitMap # Perform fit to get measured peak positions - fit = FitMap(spectra, x=energies) + if spectra.ndim == 1: + fit = Fit(spectra, x=energies) + else: + fit = FitMap(spectra, x=energies) num_peak = len(peak_locations) delta = 0.1 * (energies[-1]-energies[0]) centers_range = ( @@ -1022,31 +1099,44 @@ def get_spectra_fits(spectra, energies, peak_locations, fit_params): fwhm_max=fit_params.fwhm_max, centers_range=centers_range) fit.fit(num_proc=fit_params.num_proc) -#RV fit.fit(num_proc=1, plot=True) - uniform_fit_centers = [ - fit.best_values[ - fit.best_parameters().index(f'peak{i+1}_center')] - for i in range(num_peak)] - uniform_fit_centers_errors = [ - fit.best_errors[ - fit.best_parameters().index(f'peak{i+1}_center')] - for i in range(num_peak)] - uniform_fit_amplitudes = [ - fit.best_values[ - fit.best_parameters().index(f'peak{i+1}_amplitude')] - for i in range(num_peak)] - uniform_fit_amplitudes_errors = [ - fit.best_errors[ - fit.best_parameters().index(f'peak{i+1}_amplitude')] - for i in range(num_peak)] - uniform_fit_sigmas = [ - fit.best_values[ - fit.best_parameters().index(f'peak{i+1}_sigma')] - for i in range(num_peak)] - uniform_fit_sigmas_errors = [ - fit.best_errors[ - fit.best_parameters().index(f'peak{i+1}_sigma')] - for i in range(num_peak)] + if spectra.ndim == 1: + uniform_fit_centers = [ + fit.best_values[f'peak{i+1}_center'] for i in range(num_peak)] + uniform_fit_centers_errors = [ + fit.best_errors[f'peak{i+1}_center'] for 
i in range(num_peak)] + uniform_fit_amplitudes = [ + fit.best_values[f'peak{i+1}_amplitude'] for i in range(num_peak)] + uniform_fit_amplitudes_errors = [ + fit.best_errors[f'peak{i+1}_amplitude'] for i in range(num_peak)] + uniform_fit_sigmas = [ + fit.best_values[f'peak{i+1}_sigma'] for i in range(num_peak)] + uniform_fit_sigmas_errors = [ + fit.best_errors[f'peak{i+1}_sigma'] for i in range(num_peak)] + else: + uniform_fit_centers = [ + fit.best_values[ + fit.best_parameters().index(f'peak{i+1}_center')] + for i in range(num_peak)] + uniform_fit_centers_errors = [ + fit.best_errors[ + fit.best_parameters().index(f'peak{i+1}_center')] + for i in range(num_peak)] + uniform_fit_amplitudes = [ + fit.best_values[ + fit.best_parameters().index(f'peak{i+1}_amplitude')] + for i in range(num_peak)] + uniform_fit_amplitudes_errors = [ + fit.best_errors[ + fit.best_parameters().index(f'peak{i+1}_amplitude')] + for i in range(num_peak)] + uniform_fit_sigmas = [ + fit.best_values[ + fit.best_parameters().index(f'peak{i+1}_sigma')] + for i in range(num_peak)] + uniform_fit_sigmas_errors = [ + fit.best_errors[ + fit.best_parameters().index(f'peak{i+1}_sigma')] + for i in range(num_peak)] uniform_best_fit = fit.best_fit uniform_residuals = fit.residual uniform_redchi = fit.redchi @@ -1056,33 +1146,44 @@ def get_spectra_fits(spectra, energies, peak_locations, fit_params): fit.create_multipeak_model(fit_type='unconstrained') fit.fit(num_proc=fit_params.num_proc, rel_amplitude_cutoff=fit_params.rel_amplitude_cutoff) -#RV fit.fit(num_proc=1, plot=True) - unconstrained_fit_centers = np.array( - [fit.best_values[ - fit.best_parameters()\ - .index(f'peak{i+1}_center')] - for i in range(num_peak)]) - unconstrained_fit_centers_errors = np.array( - [fit.best_errors[ - fit.best_parameters()\ - .index(f'peak{i+1}_center')] - for i in range(num_peak)]) - unconstrained_fit_amplitudes = [ - fit.best_values[ - fit.best_parameters().index(f'peak{i+1}_amplitude')] - for i in range(num_peak)] - unconstrained_fit_amplitudes_errors = [ - fit.best_errors[ - fit.best_parameters().index(f'peak{i+1}_amplitude')] - for i in range(num_peak)] - unconstrained_fit_sigmas = [ - fit.best_values[ - fit.best_parameters().index(f'peak{i+1}_sigma')] - for i in range(num_peak)] - unconstrained_fit_sigmas_errors = [ - fit.best_errors[ - fit.best_parameters().index(f'peak{i+1}_sigma')] - for i in range(num_peak)] + if spectra.ndim == 1: + unconstrained_fit_centers = [ + fit.best_values[f'peak{i+1}_center'] for i in range(num_peak)] + unconstrained_fit_centers_errors = [ + fit.best_errors[f'peak{i+1}_center'] for i in range(num_peak)] + unconstrained_fit_amplitudes = [ + fit.best_values[f'peak{i+1}_amplitude'] for i in range(num_peak)] + unconstrained_fit_amplitudes_errors = [ + fit.best_errors[f'peak{i+1}_amplitude'] for i in range(num_peak)] + unconstrained_fit_sigmas = [ + fit.best_values[f'peak{i+1}_sigma'] for i in range(num_peak)] + unconstrained_fit_sigmas_errors = [ + fit.best_errors[f'peak{i+1}_sigma'] for i in range(num_peak)] + else: + unconstrained_fit_centers = np.array( + [fit.best_values[ + fit.best_parameters().index(f'peak{i+1}_center')] + for i in range(num_peak)]) + unconstrained_fit_centers_errors = np.array( + [fit.best_errors[ + fit.best_parameters().index(f'peak{i+1}_center')] + for i in range(num_peak)]) + unconstrained_fit_amplitudes = [ + fit.best_values[ + fit.best_parameters().index(f'peak{i+1}_amplitude')] + for i in range(num_peak)] + unconstrained_fit_amplitudes_errors = [ + fit.best_errors[ + 
fit.best_parameters().index(f'peak{i+1}_amplitude')]
+            for i in range(num_peak)]
+        unconstrained_fit_sigmas = [
+            fit.best_values[
+                fit.best_parameters().index(f'peak{i+1}_sigma')]
+            for i in range(num_peak)]
+        unconstrained_fit_sigmas_errors = [
+            fit.best_errors[
+                fit.best_parameters().index(f'peak{i+1}_sigma')]
+            for i in range(num_peak)]
     unconstrained_best_fit = fit.best_fit
     unconstrained_residuals = fit.residual
     unconstrained_redchi = fit.redchi
diff --git a/CHAP/utils/general.py b/CHAP/utils/general.py
index 0e9717f..51cfe85 100755
--- a/CHAP/utils/general.py
+++ b/CHAP/utils/general.py
@@ -909,6 +909,149 @@ def file_exists_and_readable(f):
     return f
 
+
+def rolling_average(
+        y, x=None, dtype=None, start=0, end=None, width=None,
+        stride=None, num=None, average=True, mode='valid',
+        use_convolve=None):
+    """
+    Return the rolling sum or average of an array over the last
+    dimension.
+
+    :param y: The array of data to sum or average.
+    :type y: array-like
+    :param x: Independent dimension of the rolling sum or average,
+        defaults to `None`.
+    :type x: array-like, optional
+    :param dtype: Type of the output array, defaults to the type of
+        `y` for averages and to `numpy.float32` for sums.
+    :type dtype: data-type, optional
+    :param start: First index of the summation range, defaults to `0`.
+    :type start: int, optional
+    :param end: Last index of the summation range, defaults to the
+        size of the last dimension.
+    :type end: int, optional
+    :param width: Width of the summation window.
+    :type width: int, optional
+    :param stride: Stride between consecutive summation windows.
+    :type stride: int, optional
+    :param num: Number of summation windows.
+    :type num: int, optional
+    :param average: Return the average instead of the sum, defaults
+        to `True`.
+    :type average: bool, optional
+    :param mode: Whether to include only complete windows (`'valid'`)
+        or truncated trailing windows as well (`'full'`), defaults to
+        `'valid'`.
+    :type mode: str, optional
+    :param use_convolve: Use `numpy.convolve` to compute the result,
+        defaults to `True` for 1D input and `False` otherwise.
+    :type use_convolve: bool, optional
+    :return: Rolling sum or average of `y`, and of `x` when supplied.
+    :rtype: numpy.ndarray
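+
+    As an illustrative example, averaging a 1D signal in windows of
+    four points with a stride of two:
+
+        >>> rolling_average(np.arange(10, dtype=float), width=4, stride=2)
+        array([1.5, 3.5, 5.5, 7.5])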
+    """
+    y = np.asarray(y)
+    y_shape = y.shape
+    if y.ndim == 1:
+        y = np.expand_dims(y, 0)
+    else:
+        y = y.reshape((np.prod(y.shape[0:-1]), y.shape[-1]))
+    if x is not None:
+        x = np.asarray(x)
+        if x.ndim != 1:
+            raise ValueError('Parameter "x" must be a 1D array-like')
+        if x.size != y.shape[1]:
+            raise ValueError(f'Dimensions of "x" and "y[1]" do not '
+                             f'match ({x.size} vs {y.shape[1]})')
+    if dtype is None:
+        if average:
+            dtype = y.dtype
+        else:
+            dtype = np.float32
+    if width is None and stride is None and num is None:
+        raise ValueError('Invalid input parameters, specify at least one of '
+                         '"width", "stride" or "num"')
+    if width is not None and not is_int(width, ge=1):
+        raise ValueError(f'Invalid "width" parameter ({width})')
+    if stride is not None and not is_int(stride, ge=1):
+        raise ValueError(f'Invalid "stride" parameter ({stride})')
+    if num is not None and not is_int(num, ge=1):
+        raise ValueError(f'Invalid "num" parameter ({num})')
+    if not isinstance(average, bool):
+        raise ValueError(f'Invalid "average" parameter ({average})')
+    if mode not in ('valid', 'full'):
+        raise ValueError(f'Invalid "mode" parameter ({mode})')
+    size = y.shape[1]
+    if size < 2:
+        raise ValueError(f'Invalid y[1] dimension ({size})')
+    if not is_int(start, ge=0, lt=size):
+        raise ValueError(f'Invalid "start" parameter ({start})')
+    if end is None:
+        end = size
+    elif not is_int(end, gt=start, le=size):
+        raise ValueError(f'Invalid "end" parameter ({end})')
+    if use_convolve is None:
+        if len(y_shape) == 1:
+            use_convolve = True
+        else:
+            use_convolve = False
+    if use_convolve and (start or end < size):
+        y = np.take(y, range(start, end), axis=1)
+        if x is not None:
+            x = x[start:end]
+        size = y.shape[1]
+    else:
+        size = end-start
+
+    if stride is None:
+        if width is None:
+            width = max(1, int(size/num))
+            stride = width
+        else:
+            width = min(width, size)
+            if num is None:
+                stride = width
+            else:
+                stride = max(1, int((size-width) / (num-1)))
+    else:
+        stride = min(stride, size-stride)
+        if width is None:
+            width = stride
+
+    if mode == 'valid':
+        num = 1 + max(0, int((size-width) / stride))
+    else:
+        num = int(size/stride)
+        if num*stride < size:
+            num += 1
+
+    if use_convolve:
+        n_start = 0
+        n_end = width
+        weight = np.empty((num))
+        for n in range(num):
+            n_num = n_end-n_start
+            weight[n] = n_num
+            n_start += stride
+            n_end = min(size, n_end+stride)
+
+        window = np.ones((width))
+        if x is not None:
+            if mode == 'valid':
+                rx = np.convolve(x, window)[width-1:1-width:stride]
+            else:
+                rx = np.convolve(x, window)[width-1::stride]
+            rx /= weight
+
+        ry = []
+        if mode == 'valid':
+            for i in range(y.shape[0]):
+                ry.append(np.convolve(y[i], window)[width-1:1-width:stride])
+        else:
+            for i in range(y.shape[0]):
+                ry.append(np.convolve(y[i], window)[width-1::stride])
+        ry = np.reshape(ry, (*y_shape[0:-1], num))
+        if len(y_shape) == 1:
+            ry = np.squeeze(ry)
+        if average:
+            ry = (np.asarray(ry).astype(np.float32)/weight).astype(dtype)
+        elif mode != 'valid':
+            weight = np.where(weight < width, width/weight, 1.0)
+            ry = (np.asarray(ry).astype(np.float32)*weight).astype(dtype)
+    else:
+        ry = np.zeros((num, y.shape[0]), dtype=y.dtype)
+        if x is not None:
+            rx = np.zeros(num, dtype=x.dtype)
+        n_start = start
+        n_end = n_start+width
+        for n in range(num):
+            y_sum = np.sum(y[:, n_start:n_end], 1)
+            n_num = n_end-n_start
+            if n_num < width:
+                y_sum = y_sum*width/n_num
+            ry[n] = y_sum
+            if x is not None:
+                rx[n] = np.sum(x[n_start:n_end])/n_num
+            n_start += stride
+            n_end = min(start+size, n_end+stride)
+        ry = np.reshape(ry.T, (*y_shape[0:-1], num))
+        if len(y_shape) == 1:
+            ry = np.squeeze(ry)
+        if average:
+            ry = (ry.astype(np.float32)/width).astype(dtype)
+
+    if x is None:
+        return ry
+    return ry, rx
+
+
 def select_mask_1d(
         y, x=None, label=None, ref_data=[], preselected_index_ranges=None,
         preselected_mask=None, title=None, xlabel=None, ylabel=None,