From 6a457150e61f81f33c8c037c6777fd2df9cca7b8 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Tue, 18 Apr 2023 14:16:19 -0400 Subject: [PATCH 1/6] style: edit selected files in the CHAP module to comply with the PEP8 style guide and raise the score when evaluating the package code with pylint. --- CHAP/common/models/integration.py | 642 +++++++++++------- CHAP/common/models/map.py | 542 +++++++++------ CHAP/common/processor.py | 372 ++++++----- CHAP/common/reader.py | 89 +-- CHAP/common/utils/scanparsers.py | 1013 ++++++++++++++++++++--------- CHAP/common/writer.py | 56 +- CHAP/edd/processor.py | 165 ++--- CHAP/inference/processor.py | 48 +- CHAP/pipeline.py | 38 +- CHAP/processor.py | 43 +- CHAP/reader.py | 67 +- CHAP/runner.py | 44 +- CHAP/writer.py | 65 +- 13 files changed, 1957 insertions(+), 1227 deletions(-) diff --git a/CHAP/common/models/integration.py b/CHAP/common/models/integration.py index ef1e278..e230198 100644 --- a/CHAP/common/models/integration.py +++ b/CHAP/common/models/integration.py @@ -1,17 +1,8 @@ import copy -from functools import cache, lru_cache -import json -import logging +from functools import cache import os -from time import time from typing import Literal, Optional -# from multiprocessing.pool import ThreadPool -# from nexusformat.nexus import (NXdata, -# NXdetector, -# NXfield, -# NXprocess, -# NXroot) import numpy as np from pydantic import (BaseModel, validator, @@ -20,20 +11,15 @@ conint, confloat, FilePath) -#import pyFAI, pyFAI.multi_geometry, pyFAI.units from pyFAI import load as pyfai_load from pyFAI.multi_geometry import MultiGeometry from pyFAI.units import AZIMUTHAL_UNITS, RADIAL_UNITS -#from pyspec.file.tiff import TiffFile - - -#from .map import MapConfig, SpecScans class Detector(BaseModel): - """ - Detector class to represent a single detector used in the experiment. - + """Detector class to represent a single detector used in the + experiment. + :param prefix: Prefix of the detector in the SPEC file. :type prefix: str :param poni_file: Path to the poni file. @@ -44,11 +30,12 @@ class Detector(BaseModel): prefix: constr(strip_whitespace=True, min_length=1) poni_file: FilePath mask_file: Optional[FilePath] + @validator('poni_file', allow_reuse=True) def validate_poni_file(cls, poni_file): - """ - Validate the poni file by checking if it's a valid PONI file. - + """Validate the poni file by checking if it's a valid PONI + file. + :param poni_file: Path to the poni file. :type poni_file: str :raises ValueError: If poni_file is not a valid PONI file. @@ -58,52 +45,79 @@ def validate_poni_file(cls, poni_file): poni_file = os.path.abspath(poni_file) try: ai = azimuthal_integrator(poni_file) - except: - raise(ValueError(f'{poni_file} is not a valid PONI file')) - else: - return(poni_file) + except Exception as exc: + raise ValueError(f'{poni_file} is not a valid PONI file') from exc + return poni_file + @validator('mask_file', allow_reuse=True) def validate_mask_file(cls, mask_file, values): - """ - Validate the mask file. If a mask file is provided, it checks if it's a valid TIFF file. - + """Validate the mask file. If a mask file is provided, it + checks if it's a valid TIFF file. + :param mask_file: Path to the mask file. :type mask_file: str or None :param values: A dictionary of the Detector fields. :type values: dict - :raises ValueError: If mask_file is provided and it's not a valid TIFF file. + :raises ValueError: If mask_file is provided and it's not a + valid TIFF file. :raises ValueError: If `'poni_file'` is not provided in `values`. 
:returns: Absolute path to the mask file or None. - :rtype: str or None + :rtype: str or None """ if mask_file is None: - return(mask_file) - else: - mask_file = os.path.abspath(mask_file) - poni_file = values.get('poni_file') - if poni_file is None: - raise(ValueError('Cannot validate mask file without a PONI file.')) - else: - try: - mask_array = get_mask_array(mask_file, poni_file) - except BaseException as e: - raise(ValueError(f'Unable to open {mask_file} as a TIFF file')) - else: - return(mask_file) + return mask_file + + mask_file = os.path.abspath(mask_file) + poni_file = values.get('poni_file') + if poni_file is None: + raise ValueError( + 'Cannot validate mask file without a PONI file.') + try: + mask_array = get_mask_array(mask_file, poni_file) + except BaseException as exc: + raise ValueError( + f'Unable to open {mask_file} as a TIFF file') from exc + return mask_file + @property def azimuthal_integrator(self): - return(azimuthal_integrator(self.poni_file)) + """Return the azimuthal integrator associated with this + detector. + """ + return azimuthal_integrator(self.poni_file) + @property def mask_array(self): - return(get_mask_array(self.mask_file, self.poni_file)) + """Return the mask array assocated with this detector.""" + return get_mask_array(self.mask_file, self.poni_file) + @cache def azimuthal_integrator(poni_file:str): + """Return the azimuthal integrator from a PONI file + + :param poni_file: path to a PONI file + :type poni_file: str + :return: azimuthal integrator + :rtype: pyFAI.azimuthal_integrator.AzimuthalIntegrator + """ if not isinstance(poni_file, str): poni_file = str(poni_file) - return(pyfai_load(poni_file)) + return pyfai_load(poni_file) + + @cache def get_mask_array(mask_file:str, poni_file:str): + """Return a mask array associated with a detector loaded from a + tiff file. + + :param mask_file: path to a .tiff file + :type mask_file: str + :param poni_file: path to a PONI file + :type poni_file: str + :return: the mask array loaded from `mask_file` + :rtype: numpy.ndarray + """ if mask_file is not None: if not isinstance(mask_file, str): mask_file = str(mask_file) @@ -113,37 +127,46 @@ def get_mask_array(mask_file:str, poni_file:str): mask_array = tiff.asarray() else: mask_array = np.zeros(azimuthal_integrator(poni_file).detector.shape) - return(mask_array) + return mask_array + class IntegrationConfig(BaseModel): - """ - Class representing the configuration for a raw detector data integration. + """Class representing the configuration for a raw detector data + integration. 
- :ivar tool_type: type of integration tool; always set to "integration" + :ivar tool_type: type of integration tool; always set to + "integration" :type tool_type: str, optional :ivar title: title of the integration :type title: str - :ivar integration_type: type of integration, one of "azimuthal", "radial", or "cake" + :ivar integration_type: type of integration, one of "azimuthal", + "radial", or "cake" :type integration_type: str :ivar detectors: list of detectors used in the integration :type detectors: List[Detector] - :ivar radial_units: radial units for the integration, defaults to `'q_A^-1'` + :ivar radial_units: radial units for the integration, defaults to + `'q_A^-1'` :type radial_units: str, optional :ivar radial_min: minimum radial value for the integration range :type radial_min: float, optional :ivar radial_max: maximum radial value for the integration range :type radial_max: float, optional - :ivar radial_npt: number of points in the radial range for the integration + :ivar radial_npt: number of points in the radial range for the + integration :type radial_npt: int, optional :ivar azimuthal_units: azimuthal units for the integration :type azimuthal_units: str, optional - :ivar azimuthal_min: minimum azimuthal value for the integration range + :ivar azimuthal_min: minimum azimuthal value for the integration + range :type azimuthal_min: float, optional - :ivar azimuthal_max: maximum azimuthal value for the integration range + :ivar azimuthal_max: maximum azimuthal value for the integration + range :type azimuthal_max: float, optional - :ivar azimuthal_npt: number of points in the azimuthal range for the integration + :ivar azimuthal_npt: number of points in the azimuthal range for + the integration :type azimuthal_npt: int, optional - :ivar error_model: error model for the integration, one of "poisson" or "azimuthal" + :ivar error_model: error model for the integration, one of + "poisson" or "azimuthal" :type error_model: str, optional """ tool_type: Literal['integration'] = 'integration' @@ -160,71 +183,95 @@ class IntegrationConfig(BaseModel): azimuthal_npt: conint(gt=0) = 3600 error_model: Optional[Literal['poisson', 'azimuthal']] sequence_index: Optional[conint(gt=0)] + @validator('radial_units', allow_reuse=True) def validate_radial_units(cls, radial_units): - """ - Validate the radial units for the integration. + """Validate the radial units for the integration. - :param radial_units: unvalidated radial units for the integration + :param radial_units: unvalidated radial units for the + integration :type radial_units: str - :raises ValueError: if radial units are not one of the recognized radial units + :raises ValueError: if radial units are not one of the + recognized radial units :return: validated radial units :rtype: str """ if radial_units in RADIAL_UNITS.keys(): - return(radial_units) - else: - raise(ValueError(f'Invalid radial units: {radial_units}. Must be one of {", ".join(RADIAL_UNITS.keys())}')) + return radial_units + raise ValueError( + f'Invalid radial units: {radial_units}. ' + + f'Must be one of {", ".join(RADIAL_UNITS.keys())}') + @validator('azimuthal_units', allow_reuse=True) def validate_azimuthal_units(cls, azimuthal_units): - """ - Validate that `azimuthal_units` is one of the keys in the + """Validate that `azimuthal_units` is one of the keys in the `pyFAI.units.AZIMUTHAL_UNITS` dictionary. - :param azimuthal_units: The string representing the unit to be validated. + :param azimuthal_units: The string representing the unit to be + validated. 
:type azimuthal_units: str - :raises ValueError: If `azimuthal_units` is not one of the keys in `pyFAI.units.AZIMUTHAL_UNITS` - :return: The original supplied value, if is one of the keys in `pyFAI.units.AZIMUTHAL_UNITS`. + :raises ValueError: If `azimuthal_units` is not one of the + keys in `pyFAI.units.AZIMUTHAL_UNITS` + :return: The original supplied value, if is one of the keys in + `pyFAI.units.AZIMUTHAL_UNITS`. :rtype: str """ if azimuthal_units in AZIMUTHAL_UNITS.keys(): - return(azimuthal_units) - else: - raise(ValueError(f'Invalid azimuthal units: {azimuthal_units}. Must be one of {", ".join(AZIMUTHAL_UNITS.keys())}')) + return azimuthal_units + raise ValueError( + f'Invalid azimuthal units: {azimuthal_units}. ' + + f'Must be one of {", ".join(AZIMUTHAL_UNITS.keys())}') + def validate_range_max(range_name:str): """Validate the maximum value of an integration range. - :param range_name: The name of the integration range (e.g. radial, azimuthal). + :param range_name: The name of the integration range + (e.g. radial, azimuthal). :type range_name: str :return: The callable that performs the validation. :rtype: callable """ def _validate_range_max(cls, range_max, values): - """Check if the maximum value of the integration range is greater than its minimum value. + """Check if the maximum value of the integration range is + greater than its minimum value. - :param range_max: The maximum value of the integration range. + :param range_max: The maximum value of the integration + range. :type range_max: float - :param values: The values of the other fields being validated. + :param values: The values of the other fields being + validated. :type values: dict - :raises ValueError: If the maximum value of the integration range is not greater than its minimum value. + :raises ValueError: If the maximum value of the + integration range is not greater than its minimum + value. :return: The validated maximum range value :rtype: float """ range_min = values.get(f'{range_name}_min') if range_min < range_max: - return(range_max) - else: - raise(ValueError(f'Maximum value of integration range must be greater than minimum value of integration range ({range_name}_min={range_min}).')) - return(_validate_range_max) - _validate_radial_max = validator('radial_max', allow_reuse=True)(validate_range_max('radial')) - _validate_azimuthal_max = validator('azimuthal_max', allow_reuse=True)(validate_range_max('azimuthal')) + return range_max + raise ValueError( + 'Maximum value of integration range must be ' + + 'greater than minimum value of integration range ' + + f'({range_name}_min={range_min}).') + return _validate_range_max + + _validate_radial_max = validator( + 'radial_max', + allow_reuse=True)(validate_range_max('radial')) + _validate_azimuthal_max = validator( + 'azimuthal_max', + allow_reuse=True)(validate_range_max('azimuthal')) + def validate_for_map_config(self, map_config:BaseModel): - """ - Validate the existence of the detector data file for all scan points in `map_config`. - - :param map_config: The `MapConfig` instance to validate against. + """Validate the existence of the detector data file for all + scan points in `map_config`. + + :param map_config: The `MapConfig` instance to validate + against. :type map_config: MapConfig - :raises RuntimeError: If a detector data file could not be found for a scan point occurring in `map_config`. + :raises RuntimeError: If a detector data file could not be + found for a scan point occurring in `map_config`. 
:return: None :rtype: None """ @@ -233,39 +280,55 @@ def validate_for_map_config(self, map_config:BaseModel): for scan_number in scans.scan_numbers: scanparser = scans.get_scanparser(scan_number) for scan_step_index in range(scanparser.spec_scan_npts): - # Make sure the detector data file exists for all scan points + # Make sure the detector data file exists for + # all scan points try: - detector_data_file = scanparser.get_detector_data_file(detector.prefix, scan_step_index) - except: - raise(RuntimeError(f'Could not find data file for detector prefix {detector.prefix} on scan number {scan_number} in spec file {scans.spec_file}')) + detector_data_file = \ + scanparser.get_detector_data_file( + detector.prefix, scan_step_index) + except Exception as exc: + raise RuntimeError( + 'Could not find data file for detector prefix ' + + f'{detector.prefix} ' + + f'on scan number {scan_number} ' + + f'in spec file {scans.spec_file}') from exc + def get_azimuthal_adjustments(self): - """To enable a continuous range of integration in the azimuthal direction - for radial and cake integration, obtain adjusted values for this - `IntegrationConfig`'s `azimuthal_min` and `azimuthal_max` values, the - angle amount by which those values were adjusted, and the proper location - of the discontinuity in the azimuthal direction. + """To enable a continuous range of integration in the + azimuthal direction for radial and cake integration, obtain + adjusted values for this `IntegrationConfig`'s `azimuthal_min` + and `azimuthal_max` values, the angle amount by which those + values were adjusted, and the proper location of the + discontinuity in the azimuthal direction. - :return: Adjusted chi_min, adjusted chi_max, chi_offset, chi_discontinuity + :return: Adjusted chi_min, adjusted chi_max, chi_offset, + chi_discontinuity :rtype: tuple[float,float,float,float] """ - return(get_azimuthal_adjustments(self.azimuthal_min, self.azimuthal_max)) + return get_azimuthal_adjustments(self.azimuthal_min, + self.azimuthal_max) + def get_azimuthal_integrators(self): - """Get a list of `AzimuthalIntegrator`s that correspond to the detector - configurations in this instance of `IntegrationConfig`. + """Get a list of `AzimuthalIntegrator`s that correspond to the + detector configurations in this instance of + `IntegrationConfig`. - The returned `AzimuthalIntegrator`s are (if need be) artificially rotated - in the azimuthal direction to achieve a continuous range of integration - in the azimuthal direction. + The returned `AzimuthalIntegrator`s are (if need be) + artificially rotated in the azimuthal direction to achieve a + continuous range of integration in the azimuthal direction. - :returns: A list of `AzimuthalIntegrator`s appropriate for use by this - `IntegrationConfig` tool + :returns: A list of `AzimuthalIntegrator`s appropriate for use + by this `IntegrationConfig` tool :rtype: list[pyFAI.azimuthalIntegrator.AzimuthalIntegrator] """ - chi_min, chi_max, chi_offset, chi_disc = self.get_azimuthal_adjustments() - return(get_azimuthal_integrators(tuple([detector.poni_file for detector in self.detectors]), chi_offset=chi_offset)) + chi_offset = self.get_azimuthal_adjustments()[2] + return get_azimuthal_integrators( + tuple([detector.poni_file for detector in self.detectors]), + chi_offset=chi_offset) + def get_multi_geometry_integrator(self): - """Get a `MultiGeometry` integrator suitable for use by this instance of - `IntegrationConfig`. 
+ """Get a `MultiGeometry` integrator suitable for use by this + instance of `IntegrationConfig`. :return: A `MultiGeometry` integrator :rtype: pyFAI.multi_geometry.MultiGeometry @@ -273,94 +336,138 @@ def get_multi_geometry_integrator(self): poni_files = tuple([detector.poni_file for detector in self.detectors]) radial_range = (self.radial_min, self.radial_max) azimuthal_range = (self.azimuthal_min, self.azimuthal_max) - return(get_multi_geometry_integrator(poni_files, self.radial_units, radial_range, azimuthal_range)) - def get_azimuthally_integrated_data(self, spec_scans:BaseModel, scan_number:int, scan_step_index:int): - """Return azimuthally-integrated data for the scan step specified. - - :param spec_scans: An instance of `SpecScans` containing the scan step requested. + return get_multi_geometry_integrator(poni_files, self.radial_units, + radial_range, azimuthal_range) + + def get_azimuthally_integrated_data(self, + spec_scans:BaseModel, + scan_number:int, + scan_step_index:int): + """Return azimuthally-integrated data for the scan step + specified. + + :param spec_scans: An instance of `SpecScans` containing the + scan step requested. :type spec_scans: SpecScans - :param scan_number: The number of the scan containing the scan step requested. + :param scan_number: The number of the scan containing the scan + step requested. :type scan_number: int :param scan_step_index: The index of the scan step requested. :type scan_step_index: int - :return: A 1D array of azimuthally-integrated raw detector intensities. + :return: A 1D array of azimuthally-integrated raw detector + intensities. :rtype: np.ndarray """ - detector_data = spec_scans.get_detector_data(self.detectors, scan_number, scan_step_index) + detector_data = spec_scans.get_detector_data(self.detectors, + scan_number, + scan_step_index) integrator = self.get_multi_geometry_integrator() lst_mask = [detector.mask_array for detector in self.detectors] - result = integrator.integrate1d(detector_data, lst_mask=lst_mask, npt=self.radial_npt, error_model=self.error_model) + result = integrator.integrate1d(detector_data, + lst_mask=lst_mask, + npt=self.radial_npt, + error_model=self.error_model) if result.sigma is None: - return(result.intensity) - else: - return(result.intensity, result.sigma) - def get_radially_integrated_data(self, spec_scans:BaseModel, scan_number:int, scan_step_index:int): - """Return radially-integrated data for the scan step specified. - - :param spec_scans: An instance of `SpecScans` containing the scan step requested. + return result.intensity + return result.intensity, result.sigma + + def get_radially_integrated_data(self, + spec_scans:BaseModel, + scan_number:int, + scan_step_index:int): + """Return radially-integrated data for the scan step + specified. + + :param spec_scans: An instance of `SpecScans` containing the + scan step requested. :type spec_scans: SpecScans - :param scan_number: The number of the scan containing the scan step requested. + :param scan_number: The number of the scan containing the scan + step requested. :type scan_number: int :param scan_step_index: The index of the scan step requested. :type scan_step_index: int - :return: A 1D array of radially-integrated raw detector intensities. + :return: A 1D array of radially-integrated raw detector + intensities. 
         :rtype: np.ndarray
         """
-        # Handle idiosyncracies of azimuthal ranges in pyFAI
-        # Adjust chi ranges to get a continuous range of iintegrated data
-        chi_min, chi_max, chi_offset, chi_disc = self.get_azimuthal_adjustments()
+        # Handle idiosyncrasies of azimuthal ranges in pyFAI. Adjust
+        # chi ranges to get a continuous range of integrated data.
+        chi_min, chi_max, *adjust = self.get_azimuthal_adjustments()
         # Perform radial integration on a detector-by-detector basis.
-        I_each_detector = []
+        intensity_each_detector = []
         variance_each_detector = []
         integrators = self.get_azimuthal_integrators()
-        for i,(integrator,detector) in enumerate(zip(integrators,self.detectors)):
-            detector_data = spec_scans.get_detector_data([detector], scan_number, scan_step_index)[0]
-            result = integrator.integrate_radial(detector_data, self.azimuthal_npt,
-                    unit=self.azimuthal_units, azimuth_range=(chi_min,chi_max),
-                    radial_unit=self.radial_units, radial_range=(self.radial_min,self.radial_max),
-                    mask=detector.mask_array) #, error_model=self.error_model)
-            I_each_detector.append(result.intensity)
+        for integrator,detector in zip(integrators,self.detectors):
+            detector_data = spec_scans.get_detector_data(
+                [detector], scan_number, scan_step_index)[0]
+            result = integrator.integrate_radial(
+                detector_data,
+                self.azimuthal_npt,
+                unit=self.azimuthal_units,
+                azimuth_range=(chi_min,chi_max),
+                radial_unit=self.radial_units,
+                radial_range=(self.radial_min,self.radial_max),
+                mask=detector.mask_array) #, error_model=self.error_model)
+            intensity_each_detector.append(result.intensity)
            if result.sigma is not None:
                variance_each_detector.append(result.sigma**2)
-        # Add the individual detectors' integrated intensities together
-        I = np.nansum(I_each_detector, axis=0)
+        # Add the individual detectors' integrated intensities
+        # together
+        intensity = np.nansum(intensity_each_detector, axis=0)
         # Ignore data at values of chi for which there was no data
-        I = np.where(I==0, np.nan, I)
-        if len(I_each_detector) != len(variance_each_detector):
-            return(I)
-        else:
-            # Get the standard deviation of the summed detectors' intensities
-            sigma = np.sqrt(np.nansum(variance_each_detector, axis=0))
-            return(I, sigma)
-    def get_cake_integrated_data(self, spec_scans:BaseModel, scan_number:int, scan_step_index:int):
+        intensity = np.where(intensity==0, np.nan, intensity)
+        if len(intensity_each_detector) != len(variance_each_detector):
+            return intensity
+
+        # Get the standard deviation of the summed detectors'
+        # intensities
+        sigma = np.sqrt(np.nansum(variance_each_detector, axis=0))
+        return intensity, sigma
+
+    def get_cake_integrated_data(self,
+                                 spec_scans:BaseModel,
+                                 scan_number:int,
+                                 scan_step_index:int):
         """Return cake-integrated data for the scan step specified.
-
-        :param spec_scans: An instance of `SpecScans` containing the scan step requested.
+
+        :param spec_scans: An instance of `SpecScans` containing the
+            scan step requested.
         :type spec_scans: SpecScans
-        :param scan_number: The number of the scan containing the scan step requested.
+        :param scan_number: The number of the scan containing the scan
+            step requested.
         :type scan_number: int
         :param scan_step_index: The index of the scan step requested.
         :type scan_step_index: int
-        :return: A 2D array of cake-integrated raw detector intensities.
+        :return: A 2D array of cake-integrated raw detector
+            intensities.
:rtype: np.ndarray """ - detector_data = spec_scans.get_detector_data(self.detectors, scan_number, scan_step_index) + detector_data = spec_scans.get_detector_data( + self.detectors, scan_number, scan_step_index) integrator = self.get_multi_geometry_integrator() lst_mask = [detector.mask_array for detector in self.detectors] - result = integrator.integrate2d(detector_data, lst_mask=lst_mask, - npt_rad=self.radial_npt, npt_azim=self.azimuthal_npt, - method='bbox', - error_model=self.error_model) + result = integrator.integrate2d( + detector_data, + lst_mask=lst_mask, + npt_rad=self.radial_npt, + npt_azim=self.azimuthal_npt, + method='bbox', + error_model=self.error_model) if result.sigma is None: - return(result.intensity) - else: - return(result.intensity, result.sigma) - def get_integrated_data(self, spec_scans:BaseModel, scan_number:int, scan_step_index:int): + return result.intensity + return result.intensity, result.sigma + + def get_integrated_data(self, + spec_scans:BaseModel, + scan_number:int, + scan_step_index:int): """Return integrated data for the scan step specified. - - :param spec_scans: An instance of `SpecScans` containing the scan step requested. + + :param spec_scans: An instance of `SpecScans` containing the + scan step requested. :type spec_scans: SpecScans - :param scan_number: The number of the scan containing the scan step requested. + :param scan_number: The number of the scan containing the scan + step requested. :type scan_number: int :param scan_step_index: The index of the scan step requested. :type scan_step_index: int @@ -368,84 +475,106 @@ def get_integrated_data(self, spec_scans:BaseModel, scan_number:int, scan_step_i :rtype: np.ndarray """ if self.integration_type == 'azimuthal': - return(self.get_azimuthally_integrated_data(spec_scans, scan_number, scan_step_index)) - elif self.integration_type == 'radial': - return(self.get_radially_integrated_data(spec_scans, scan_number, scan_step_index)) - elif self.integration_type == 'cake': - return(self.get_cake_integrated_data(spec_scans, scan_number, scan_step_index)) + return self.get_azimuthally_integrated_data(spec_scans, + scan_number, + scan_step_index) + if self.integration_type == 'radial': + return self.get_radially_integrated_data(spec_scans, + scan_number, + scan_step_index) + if self.integration_type == 'cake': + return self.get_cake_integrated_data(spec_scans, + scan_number, + scan_step_index) + return None @property def integrated_data_coordinates(self): - """ - Return a dictionary of coordinate arrays for navigating the dimension(s) - of the integrated data produced by this instance of `IntegrationConfig`. - - :return: A dictionary with either one or two keys: 'azimuthal' and/or - 'radial', each of which points to a 1-D `numpy` array of coordinate - values. + """Return a dictionary of coordinate arrays for navigating the + dimension(s) of the integrated data produced by this instance + of `IntegrationConfig`. + + :return: A dictionary with either one or two keys: 'azimuthal' + and/or 'radial', each of which points to a 1-D `numpy` + array of coordinate values. 
:rtype: dict[str,np.ndarray] """ if self.integration_type == 'azimuthal': - return(get_integrated_data_coordinates(radial_range=(self.radial_min,self.radial_max), - radial_npt=self.radial_npt)) - elif self.integration_type == 'radial': - return(get_integrated_data_coordinates(azimuthal_range=(self.azimuthal_min,self.azimuthal_max), - azimuthal_npt=self.azimuthal_npt)) - elif self.integration_type == 'cake': - return(get_integrated_data_coordinates(radial_range=(self.radial_min,self.radial_max), - radial_npt=self.radial_npt, - azimuthal_range=(self.azimuthal_min,self.azimuthal_max), - azimuthal_npt=self.azimuthal_npt)) + return get_integrated_data_coordinates( + radial_range=(self.radial_min,self.radial_max), + radial_npt=self.radial_npt) + if self.integration_type == 'radial': + return get_integrated_data_coordinates( + azimuthal_range=(self.azimuthal_min,self.azimuthal_max), + azimuthal_npt=self.azimuthal_npt) + if self.integration_type == 'cake': + return get_integrated_data_coordinates( + radial_range=(self.radial_min,self.radial_max), + radial_npt=self.radial_npt, + azimuthal_range=(self.azimuthal_min,self.azimuthal_max), + azimuthal_npt=self.azimuthal_npt) + return None + @property def integrated_data_dims(self): - """Return a tuple of the coordinate labels for the integrated data - produced by this instance of `IntegrationConfig`. + """Return a tuple of the coordinate labels for the integrated + data produced by this instance of `IntegrationConfig`. """ directions = list(self.integrated_data_coordinates.keys()) - dim_names = [getattr(self, f'{direction}_units') for direction in directions] - return(dim_names) + dim_names = [getattr(self, f'{direction}_units') \ + for direction in directions] + return dim_names + @property def integrated_data_shape(self): - """Return a tuple representing the shape of the integrated data - produced by this instance of `IntegrationConfig` for a single scan step. + """Return a tuple representing the shape of the integrated + data produced by this instance of `IntegrationConfig` for a + single scan step. """ - return(tuple([len(coordinate_values) for coordinate_name,coordinate_values in self.integrated_data_coordinates.items()])) + return tuple([len(coordinate_values) \ + for coordinate_name,coordinate_values \ + in self.integrated_data_coordinates.items()]) + @cache def get_azimuthal_adjustments(chi_min:float, chi_max:float): - """ - Fix chi discontinuity at 180 degrees and return the adjusted chi range, - offset, and discontinuty. + """Fix chi discontinuity at 180 degrees and return the adjusted + chi range, offset, and discontinuty. - If the discontinuity is crossed, obtain the offset to artificially rotate - detectors to achieve a continuous azimuthal integration range. + If the discontinuity is crossed, obtain the offset to artificially + rotate detectors to achieve a continuous azimuthal integration + range. :param chi_min: The minimum value of the azimuthal range. :type chi_min: float :param chi_max: The maximum value of the azimuthal range. :type chi_max: float - :return: The following four values: the adjusted minimum value of the - azimuthal range, the adjusted maximum value of the azimuthal range, the - value by which the chi angle was adjusted, the position of the chi - discontinuity. + :return: The following four values: the adjusted minimum value of + the azimuthal range, the adjusted maximum value of the + azimuthal range, the value by which the chi angle was + adjusted, the position of the chi discontinuity. 
""" # Fix chi discontinuity at 180 degrees for now. chi_disc = 180 - # If the discontinuity is crossed, artificially rotate the detectors to - # achieve a continuous azimuthal integration range + # If the discontinuity is crossed, artificially rotate the + # detectors to achieve a continuous azimuthal integration range if chi_min < chi_disc and chi_max > chi_disc: chi_offset = chi_max - chi_disc else: chi_offset = 0 - return(chi_min-chi_offset, chi_max-chi_offset, chi_offset, chi_disc) + return chi_min-chi_offset, chi_max-chi_offset, chi_offset, chi_disc + + @cache def get_azimuthal_integrators(poni_files:tuple, chi_offset=0): - """ - Return a list of `AzimuthalIntegrator` objects generated from PONI files. - - :param poni_files: Tuple of strings, each string being a path to a PONI file. : tuple + """Return a list of `AzimuthalIntegrator` objects generated from + PONI files. + + :param poni_files: Tuple of strings, each string being a path to a + PONI file. :type poni_files: tuple - :param chi_offset: The angle in degrees by which the `AzimuthalIntegrator` objects will be rotated, defaults to 0. + :param chi_offset: The angle in degrees by which the + `AzimuthalIntegrator` objects will be rotated, defaults to 0. :type chi_offset: float, optional :return: List of `AzimuthalIntegrator` objects :rtype: list[pyFAI.azimuthalIntegrator.AzimuthalIntegrator] @@ -455,62 +584,75 @@ def get_azimuthal_integrators(poni_files:tuple, chi_offset=0): ai = copy.deepcopy(azimuthal_integrator(poni_file)) ai.rot3 += chi_offset * np.pi/180 ais.append(ai) - return(ais) + return ais + + @cache -def get_multi_geometry_integrator(poni_files:tuple, radial_unit:str, radial_range:tuple, azimuthal_range:tuple): - """Return a `MultiGeometry` instance that can be used for azimuthal or cake - integration. +def get_multi_geometry_integrator(poni_files:tuple, radial_unit:str, + radial_range:tuple, azimuthal_range:tuple): + """Return a `MultiGeometry` instance that can be used for + azimuthal or cake integration. - :param poni_files: Tuple of PONI files that describe the detectors to be - integrated. + :param poni_files: Tuple of PONI files that describe the detectors + to be integrated. :type poni_files: tuple :param radial_unit: Unit to use for radial integration range. :type radial_unit: str :param radial_range: Tuple describing the range for radial integration. :type radial_range: tuple[float,float] - :param azimuthal_range:Tuple describing the range for azimuthal integration. - :type azimuthal_range: tuple[float,float] - :return: `MultiGeometry` instance that can be used for azimuthal or cake + :param azimuthal_range:Tuple describing the range for azimuthal integration. + :type azimuthal_range: tuple[float,float] + :return: `MultiGeometry` instance that can be used for azimuthal + or cake integration. 
:rtype: pyFAI.multi_geometry.MultiGeometry """ - chi_min, chi_max, chi_offset, chi_disc = get_azimuthal_adjustments(*azimuthal_range) - ais = copy.deepcopy(get_azimuthal_integrators(poni_files, chi_offset=chi_offset)) - multi_geometry = MultiGeometry(ais, - unit=radial_unit, - radial_range=radial_range, - azimuth_range=(chi_min,chi_max), - wavelength=sum([ai.wavelength for ai in ais])/len(ais), - chi_disc=chi_disc) - return(multi_geometry) + chi_min, chi_max, chi_offset, chi_disc = \ + get_azimuthal_adjustments(*azimuthal_range) + ais = copy.deepcopy(get_azimuthal_integrators(poni_files, + chi_offset=chi_offset)) + multi_geometry = MultiGeometry( + ais, + unit=radial_unit, + radial_range=radial_range, + azimuth_range=(chi_min,chi_max), + wavelength=sum([ai.wavelength for ai in ais])/len(ais), + chi_disc=chi_disc) + return multi_geometry + + @cache -def get_integrated_data_coordinates(azimuthal_range:tuple=None, azimuthal_npt:int=None, radial_range:tuple=None, radial_npt:int=None): - """ - Return a dictionary of coordinate arrays for the specified radial and/or - azimuthal integration ranges. - - :param azimuthal_range: Tuple specifying the range of azimuthal angles over - which to generate coordinates, in the format (min, max), defaults to - None. +def get_integrated_data_coordinates(azimuthal_range:tuple = None, + azimuthal_npt:int = None, + radial_range:tuple = None, + radial_npt:int = None): + """Return a dictionary of coordinate arrays for the specified + radial and/or azimuthal integration ranges. + + :param azimuthal_range: Tuple specifying the range of azimuthal + angles over which to generate coordinates, in the format (min, + max), defaults to None. :type azimuthal_range: tuple[float,float], optional - :param azimuthal_npt: Number of azimuthal coordinate points to generate, - defaults to None. + :param azimuthal_npt: Number of azimuthal coordinate points to + generate, defaults to None. :type azimuthal_npt: int, optional - :param radial_range: Tuple specifying the range of radial distances over - which to generate coordinates, in the format (min, max), defaults to - None. + :param radial_range: Tuple specifying the range of radial + distances over which to generate coordinates, in the format + (min, max), defaults to None. :type radial_range: tuple[float,float], optional - :param radial_npt: Number of radial coordinate points to generate, defaults - to None. + :param radial_npt: Number of radial coordinate points to generate, + defaults to None. :type radial_npt: int, optional - :return: A dictionary with either one or two keys: 'azimuthal' and/or - 'radial', each of which points to a 1-D `numpy` array of coordinate - values. + :return: A dictionary with either one or two keys: 'azimuthal' + and/or 'radial', each of which points to a 1-D `numpy` array + of coordinate values. 
:rtype: dict[str,np.ndarray] """ integrated_data_coordinates = {} if azimuthal_range is not None and azimuthal_npt is not None: - integrated_data_coordinates['azimuthal'] = np.linspace(*azimuthal_range, azimuthal_npt) + integrated_data_coordinates['azimuthal'] = np.linspace( + *azimuthal_range, azimuthal_npt) if radial_range is not None and radial_npt is not None: - integrated_data_coordinates['radial'] = np.linspace(*radial_range, radial_npt) - return(integrated_data_coordinates) + integrated_data_coordinates['radial'] = np.linspace( + *radial_range, radial_npt) + return integrated_data_coordinates diff --git a/CHAP/common/models/map.py b/CHAP/common/models/map.py index f8387a6..82731cf 100644 --- a/CHAP/common/models/map.py +++ b/CHAP/common/models/map.py @@ -6,17 +6,15 @@ from pydantic import (BaseModel, conint, conlist, - confloat, constr, FilePath, PrivateAttr, - ValidationError, validator) from pyspec.file.spec import FileSpec + class Sample(BaseModel): - """ - Class representing a sample metadata configuration. + """Class representing a sample metadata configuration. :ivar name: The name of the sample. :type name: str @@ -26,9 +24,9 @@ class Sample(BaseModel): name: constr(min_length=1) description: Optional[str] + class SpecScans(BaseModel): - """ - Class representing a set of scans from a single SPEC file. + """Class representing a set of scans from a single SPEC file. :ivar spec_file: Path to the SPEC file. :type spec_file: str @@ -37,10 +35,10 @@ class SpecScans(BaseModel): """ spec_file: FilePath scan_numbers: conlist(item_type=conint(gt=0), min_items=1) + @validator('spec_file', allow_reuse=True) def validate_spec_file(cls, spec_file): - """ - Validate the specified SPEC file. + """Validate the specified SPEC file. :param spec_file: Path to the SPEC file. :type spec_file: str @@ -52,19 +50,19 @@ def validate_spec_file(cls, spec_file): spec_file = os.path.abspath(spec_file) sspec_file = FileSpec(spec_file) except: - raise(ValueError(f'Invalid SPEC file {spec_file}')) - else: - return(spec_file) + raise ValueError(f'Invalid SPEC file {spec_file}') + return spec_file + @validator('scan_numbers', allow_reuse=True) def validate_scan_numbers(cls, scan_numbers, values): - """ - Validate the specified list of scan numbers. + """Validate the specified list of scan numbers. :param scan_numbers: List of scan numbers. :type scan_numbers: list of int :param values: Dictionary of values for all fields of the model. :type values: dict - :raises ValueError: If a specified scan number is not found in the SPEC file. + :raises ValueError: If a specified scan number is not found in + the SPEC file. :return: List of scan numbers. 
:rtype: list of int """ @@ -74,29 +72,32 @@ def validate_scan_numbers(cls, scan_numbers, values): for scan_number in scan_numbers: scan = spec_scans.get_scan_by_number(scan_number) if scan is None: - raise(ValueError(f'There is no scan number {scan_number} in {spec_file}')) - return(scan_numbers) + raise ValueError( + f'There is no scan number {scan_number} in {spec_file}') + return scan_numbers @property def scanparsers(self): - '''A list of `ScanParser`s for each of the scans specified by the SPEC - file and scan numbers belonging to this instance of `SpecScans` - ''' - return([self.get_scanparser(scan_no) for scan_no in self.scan_numbers]) + """A list of `ScanParser`s for each of the scans specified by + the SPEC file and scan numbers belonging to this instance of + `SpecScans` + """ + return [self.get_scanparser(scan_no) for scan_no in self.scan_numbers] def get_scanparser(self, scan_number): - """This method returns a `ScanParser` for the specified scan number in - the specified SPEC file. + """This method returns a `ScanParser` for the specified scan + number in the specified SPEC file. :param scan_number: Scan number to get a `ScanParser` for :type scan_number: int :return: `ScanParser` for the specified scan number :rtype: ScanParser """ - return(get_scanparser(self.spec_file, scan_number)) + return get_scanparser(self.spec_file, scan_number) + def get_index(self, scan_number:int, scan_step_index:int, map_config): - """This method returns a tuple representing the index of a specific step - in a specific spec scan within a map. + """This method returns a tuple representing the index of a + specific step in a specific spec scan within a map. :param scan_number: Scan number to get index for :type scan_number: int @@ -104,19 +105,29 @@ def get_index(self, scan_number:int, scan_step_index:int, map_config): :type scan_step_index: int :param map_config: Map configuration to get index for :type map_config: MapConfig - :return: Index for the specified scan number and scan step index within - the specified map configuration + :return: Index for the specified scan number and scan step + index within the specified map configuration :rtype: tuple """ index = () for independent_dimension in map_config.independent_dimensions: - coordinate_index = list(map_config.coords[independent_dimension.label]).index(independent_dimension.get_value(self, scan_number, scan_step_index)) + coordinate_index = list( + map_config.coords[independent_dimension.label] + ).index( + independent_dimension.get_value( + self, + scan_number, + scan_step_index) + ) index = (coordinate_index, *index) - return(index) - def get_detector_data(self, detectors:list, scan_number:int, scan_step_index:int): - """ - Return the raw data from the specified detectors at the specified scan - number and scan step index. + return index + + def get_detector_data(self, + detectors:list, + scan_number:int, + scan_step_index:int): + """Return the raw data from the specified detectors at the + specified scan number and scan step index. 
:param detectors: List of detector prefixes to get raw data for :type detectors: list[str] @@ -124,87 +135,113 @@ def get_detector_data(self, detectors:list, scan_number:int, scan_step_index:int :type scan_number: int :param scan_step_index: Scan step index to get data for :type scan_step_index: int - :return: Data from the specified detectors for the specified scan number - and scan step index + :return: Data from the specified detectors for the specified + scan number and scan step index :rtype: list[np.ndarray] """ - return(get_detector_data(tuple([detector.prefix for detector in detectors]), self.spec_file, scan_number, scan_step_index)) + return get_detector_data( + tuple([detector.prefix for detector in detectors]), + self.spec_file, + scan_number, + scan_step_index) + + @cache def get_available_scan_numbers(spec_file:str): scans = FileSpec(spec_file).scans scan_numbers = list(scans.keys()) - return(scan_numbers) + return scan_numbers + + @cache def get_scanparser(spec_file:str, scan_number:int): if scan_number not in get_available_scan_numbers(spec_file): - return(None) - else: - return(ScanParser(spec_file, scan_number)) + return None + return ScanParser(spec_file, scan_number) + + @lru_cache(maxsize=10) -def get_detector_data(detector_prefixes:tuple, spec_file:str, scan_number:int, scan_step_index:int): +def get_detector_data( + detector_prefixes:tuple, + spec_file:str, + scan_number:int, + scan_step_index:int): detector_data = [] scanparser = get_scanparser(spec_file, scan_number) for prefix in detector_prefixes: image_data = scanparser.get_detector_data(prefix, scan_step_index) detector_data.append(image_data) - return(detector_data) + return detector_data + class PointByPointScanData(BaseModel): - """Class representing a source of raw scalar-valued data for which a value - was recorded at every point in a `MapConfig`. + """Class representing a source of raw scalar-valued data for which + a value was recorded at every point in a `MapConfig`. - :ivar label: A user-defined label for referring to this data in the NeXus - file and in other tools. + :ivar label: A user-defined label for referring to this data in + the NeXus file and in other tools. :type label: str :ivar units: The units in which the data were recorded. :type units: str - :ivar data_type: Represents how these data were recorded at time of data - collection. + :ivar data_type: Represents how these data were recorded at time + of data collection. :type data_type: Literal['spec_motor', 'scan_column', 'smb_par'] - :ivar name: Represents the name with which these raw data were recorded at - time of data collection. + :ivar name: Represents the name with which these raw data were + recorded at time of data collection. :type name: str """ label: constr(min_length=1) units: constr(strip_whitespace=True, min_length=1) data_type: Literal['spec_motor', 'scan_column', 'smb_par'] name: constr(strip_whitespace=True, min_length=1) + @validator('label') def validate_label(cls, label): - """Validate that the supplied `label` does not conflict with any of the - values for `label` reserved for certain data needed to perform - corrections. + """Validate that the supplied `label` does not conflict with + any of the values for `label` reserved for certain data needed + to perform corrections. :param label: The value of `label` to validate :type label: str :raises ValueError: If `label` is one of the reserved values. - :return: The original supplied value `label`, if it is allowed. 
+ :return: The original supplied value `label`, if it is + allowed. :rtype: str """ - #if (not issubclass(cls,CorrectionsData)) and label in CorrectionsData.__fields__['label'].type_.__args__: - if (not issubclass(cls,CorrectionsData)) and label in CorrectionsData.reserved_labels(): - raise(ValueError(f'{cls.__name__}.label may not be any of the following reserved values: {CorrectionsData.reserved_labels()}')) - return(label) + if ((not issubclass(cls,CorrectionsData)) + and label in CorrectionsData.reserved_labels()): + raise ValueError( + f'{cls.__name__}.label may not be any of the following ' + + f'reserved values: {CorrectionsData.reserved_labels()}') + return label + def validate_for_station(self, station:str): - """Validate this instance of `PointByPointScanData` for a certain choice - of station (beamline). + """Validate this instance of `PointByPointScanData` for a + certain choice of station (beamline). :param station: The name of the station (in 'idxx' format). :type station: str - :raises TypeError: If the station is not compatible with the value of the - `data_type` attribute for this instance of PointByPointScanData. + :raises TypeError: If the station is not compatible with the + value of the `data_type` attribute for this instance of + PointByPointScanData. :return: None :rtype: None """ - if station.lower() not in ('id1a3', 'id3a') and self.data_type == 'smb_par': - raise(TypeError(f'{self.__class__.__name__}.data_type may not be "smb_par" when station is "{station}"')) - def validate_for_spec_scans(self, spec_scans:list[SpecScans], scan_step_index:Union[Literal['all'],int]='all'): - """Validate this instance of `PointByPointScanData` for a list of - `SpecScans`. - - :param spec_scans: A list of `SpecScans` whose raw data will be checked - for the presence of the data represented by this instance of - `PointByPointScanData` + if (station.lower() not in ('id1a3', 'id3a') + and self.data_type == 'smb_par'): + raise TypeError( + f'{self.__class__.__name__}.data_type may not be "smb_par" ' + + f'when station is "{station}"') + + def validate_for_spec_scans( + self, spec_scans:list[SpecScans], + scan_step_index:Union[Literal['all'],int] = 'all'): + """Validate this instance of `PointByPointScanData` for a list + of `SpecScans`. + + :param spec_scans: A list of `SpecScans` whose raw data will + be checked for the presence of the data represented by + this instance of `PointByPointScanData` :type spec_scans: list[SpecScans] :param scan_step_index: A specific scan step index to validate, defaults to `'all'`. @@ -220,39 +257,64 @@ def validate_for_spec_scans(self, spec_scans:list[SpecScans], scan_step_index:Un if scan_step_index == 'all': scan_step_index_range = range(scanparser.spec_scan_npts) else: - scan_step_index_range = range(scan_step_index,scan_step_index+1) + scan_step_index_range = range(scan_step_index, + scan_step_index + 1) for index in scan_step_index_range: try: self.get_value(scans, scan_number, index) except: - raise(RuntimeError(f'Could not find data for {self.name} (data_type "{self.data_type}") on scan number {scan_number} for index {index} in spec file {scans.spec_file}')) - def get_value(self, spec_scans:SpecScans, scan_number:int, scan_step_index:int): - """Return the value recorded for this instance of `PointByPointScanData` - at a specific scan step. 
+ raise RuntimeError( + f'Could not find data for {self.name} ' + + f'(data_type "{self.data_type}") ' + + f'on scan number {scan_number} ' + + f'for index {index} ' + + f'in spec file {scans.spec_file}') + + def get_value(self, spec_scans:SpecScans, + scan_number:int, scan_step_index:int): + """Return the value recorded for this instance of + `PointByPointScanData` at a specific scan step. - :param spec_scans: An instance of `SpecScans` in which the requested scan step occurs. + :param spec_scans: An instance of `SpecScans` in which the + requested scan step occurs. :type spec_scans: SpecScans - :param scan_number: The number of the scan in which the requested scan step occurs. + :param scan_number: The number of the scan in which the + requested scan step occurs. :type scan_number: int :param scan_step_index: The index of the requested scan step. :type scan_step_index: int - :return: The value recorded of the data represented by this instance of - `PointByPointScanData` at the scan step requested + :return: The value recorded of the data represented by this + instance of `PointByPointScanData` at the scan step + requested :rtype: float """ if self.data_type == 'spec_motor': - return(get_spec_motor_value(spec_scans.spec_file, scan_number, scan_step_index, self.name)) - elif self.data_type == 'scan_column': - return(get_spec_counter_value(spec_scans.spec_file, scan_number, scan_step_index, self.name)) - elif self.data_type == 'smb_par': - return(get_smb_par_value(spec_scans.spec_file, scan_number, self.name)) + return get_spec_motor_value(spec_scans.spec_file, + scan_number, + scan_step_index, + self.name) + if self.data_type == 'scan_column': + return get_spec_counter_value(spec_scans.spec_file, + scan_number, + scan_step_index, + self.name) + if self.data_type == 'smb_par': + return get_smb_par_value(spec_scans.spec_file, + scan_number, + self.name) + return None + @cache -def get_spec_motor_value(spec_file:str, scan_number:int, scan_step_index:int, spec_mnemonic:str): - """Return the value recorded for a SPEC motor at a specific scan step. +def get_spec_motor_value(spec_file:str, scan_number:int, + scan_step_index:int, spec_mnemonic:str): + """Return the value recorded for a SPEC motor at a specific scan + step. - :param spec_file: Location of a SPEC file in which the requested scan step occurs. + :param spec_file: Location of a SPEC file in which the requested + scan step occurs. :type spec_scans: str - :param scan_number: The number of the scan in which the requested scan step occurs. + :param scan_number: The number of the scan in which the requested + scan step occurs. :type scan_number: int :param scan_step_index: The index of the requested scan step. 
:type scan_step_index: int @@ -265,20 +327,30 @@ def get_spec_motor_value(spec_file:str, scan_number:int, scan_step_index:int, sp if spec_mnemonic in scanparser.spec_scan_motor_mnes: motor_i = scanparser.spec_scan_motor_mnes.index(spec_mnemonic) if scan_step_index >= 0: - scan_step = np.unravel_index(scan_step_index, scanparser.spec_scan_shape, order='F') - motor_value = scanparser.spec_scan_motor_vals[motor_i][scan_step[motor_i]] + scan_step = np.unravel_index( + scan_step_index, + scanparser.spec_scan_shape, + order='F') + motor_value = \ + scanparser.spec_scan_motor_vals[motor_i][scan_step[motor_i]] else: motor_value = scanparser.spec_scan_motor_vals[motor_i] else: motor_value = scanparser.get_spec_positioner_value(spec_mnemonic) - return(motor_value) + return motor_value + + @cache -def get_spec_counter_value(spec_file:str, scan_number:int, scan_step_index:int, spec_column_label:str): - """Return the value recorded for a SPEC counter at a specific scan step. +def get_spec_counter_value(spec_file:str, scan_number:int, + scan_step_index:int, spec_column_label:str): + """Return the value recorded for a SPEC counter at a specific scan + step. - :param spec_file: Location of a SPEC file in which the requested scan step occurs. + :param spec_file: Location of a SPEC file in which the requested + scan step occurs. :type spec_scans: str - :param scan_number: The number of the scan in which the requested scan step occurs. + :param scan_number: The number of the scan in which the requested + scan step occurs. :type scan_number: int :param scan_step_index: The index of the requested scan step. :type scan_step_index: int @@ -289,16 +361,20 @@ def get_spec_counter_value(spec_file:str, scan_number:int, scan_step_index:int, """ scanparser = get_scanparser(spec_file, scan_number) if scan_step_index >= 0: - return(scanparser.spec_scan_data[spec_column_label][scan_step_index]) - else: - return(scanparser.spec_scan_data[spec_column_label]) + return scanparser.spec_scan_data[spec_column_label][scan_step_index] + return scanparser.spec_scan_data[spec_column_label] + + @cache def get_smb_par_value(spec_file:str, scan_number:int, par_name:str): - """Return the value recorded for a specific scan in SMB-tyle .par file. + """Return the value recorded for a specific scan in SMB-tyle .par + file. - :param spec_file: Location of a SPEC file in which the requested scan step occurs. + :param spec_file: Location of a SPEC file in which the requested + scan step occurs. :type spec_scans: str - :param scan_number: The number of the scan in which the requested scan step occurs. + :param scan_number: The number of the scan in which the requested + scan step occurs. :type scan_number: int :param par_name: The name of the column in the .par file :type par_name: str @@ -306,121 +382,156 @@ def get_smb_par_value(spec_file:str, scan_number:int, par_name:str): :rtype: float """ scanparser = get_scanparser(spec_file, scan_number) - return(scanparser.pars[par_name]) + return scanparser.pars[par_name] + + def validate_data_source_for_map_config(data_source, values): + """Confirm that an instance of PointByPointScanData is valid for + the station and scans provided by a map configuration dictionary. + + :param data_source: the input object to validate + :type data_source: PintByPointScanData + :param values: the map configuration dictionary + :type values: dict + :raises Exception: if `data_source` cannot be validated for + `values`. + :return: `data_source`, iff it is valid. 
+ :rtype: PointByPointScanData + """ if data_source is not None: import_scanparser(values.get('station'), values.get('experiment_type')) data_source.validate_for_station(values.get('station')) data_source.validate_for_spec_scans(values.get('spec_scans')) - return(data_source) + return data_source -class CorrectionsData(PointByPointScanData): - """Class representing the special instances of `PointByPointScanData` that - are used by certain kinds of `CorrectionConfig` tools. - :ivar label: One of the reserved values required by `CorrectionConfig`, - `'presample_intensity'`, `'postsample_intensity'`, or - `'dwell_time_actual'`. - :type label: Literal['presample_intensity','postsample_intensity','dwell_time_actual'] +class CorrectionsData(PointByPointScanData): + """Class representing the special instances of + `PointByPointScanData` that are used by certain kinds of + `CorrectionConfig` tools. + + :ivar label: One of the reserved values required by + `CorrectionConfig`, `'presample_intensity'`, + `'postsample_intensity'`, or `'dwell_time_actual'`. + :type label: Literal['presample_intensity', + 'postsample_intensity', + 'dwell_time_actual'] :ivar units: The units in which the data were recorded. :type units: str - :ivar data_type: Represents how these data were recorded at time of data - collection. + :ivar data_type: Represents how these data were recorded at time + of data collection. :type data_type: Literal['scan_column', 'smb_par'] - :ivar name: Represents the name with which these raw data were recorded at - time of data collection. + :ivar name: Represents the name with which these raw data were + recorded at time of data collection. :type name: str """ - label: Literal['presample_intensity','postsample_intensity','dwell_time_actual'] + label: Literal['presample_intensity', + 'postsample_intensity', + 'dwell_time_actual'] data_type: Literal['scan_column','smb_par'] + @classmethod def reserved_labels(cls): - """Return a list of all the labels reserved for corrections-related - scalar data. + """Return a list of all the labels reserved for + corrections-related scalar data. :return: A list of reserved labels :rtype: list[str] """ - return(list(cls.__fields__['label'].type_.__args__)) + return list(cls.__fields__['label'].type_.__args__) + + class PresampleIntensity(CorrectionsData): - """Class representing a source of raw data for the intensity of the beam that - is incident on the sample. + """Class representing a source of raw data for the intensity of + the beam that is incident on the sample. :ivar label: Must be `"presample_intensity"` :type label: Literal["presample_intensity"] :ivar units: Must be `"counts"` :type units: Literal["counts"] - :ivar data_type: Represents how these data were recorded at time of data - collection. + :ivar data_type: Represents how these data were recorded at time + of data collection. :type data_type: Literal['scan_column', 'smb_par'] - :ivar name: Represents the name with which these raw data were recorded at - time of data collection. + :ivar name: Represents the name with which these raw data were + recorded at time of data collection. :type name: str """ label: Literal['presample_intensity'] = 'presample_intensity' units: Literal['counts'] = 'counts' + + class PostsampleIntensity(CorrectionsData): - """Class representing a source of raw data for the intensity of the beam that - has passed through the sample. + """Class representing a source of raw data for the intensity of + the beam that has passed through the sample. 
:ivar label: Must be `"postsample_intensity"` :type label: Literal["postsample_intensity"] :ivar units: Must be `"counts"` :type units: Literal["counts"] - :ivar data_type: Represents how these data were recorded at time of data - collection. + :ivar data_type: Represents how these data were recorded at time + of data collection. :type data_type: Literal['scan_column', 'smb_par'] - :ivar name: Represents the name with which these raw data were recorded at - time of data collection. + :ivar name: Represents the name with which these raw data were + recorded at time of data collection. :type name: str """ label: Literal['postsample_intensity'] = 'postsample_intensity' units: Literal['counts'] = 'counts' + + class DwellTimeActual(CorrectionsData): - """Class representing a source of raw data for the actual dwell time at each - scan point in SPEC (with some scan types, this value can vary slightly - point-to-point from the dwell time specified in the command). + """Class representing a source of raw data for the actual dwell + time at each scan point in SPEC (with some scan types, this value + can vary slightly point-to-point from the dwell time specified in + the command). :ivar label: Must be `"dwell_time_actual"` :type label: Literal["dwell_time_actual"] :ivar units: Must be `"counts"` :type units: Literal["counts"] - :ivar data_type: Represents how these data were recorded at time of data - collection. + :ivar data_type: Represents how these data were recorded at time + of data collection. :type data_type: Literal['scan_column', 'smb_par'] - :ivar name: Represents the name with which these raw data were recorded at - time of data collection. + :ivar name: Represents the name with which these raw data were + recorded at time of data collection. :type name: str """ label: Literal['dwell_time_actual'] = 'dwell_time_actual' units: Literal['s'] = 's' + class MapConfig(BaseModel): - """Class representing an experiment consisting of one or more SPEC scans. + """Class representing an experiment consisting of one or more SPEC + scans. :ivar title: The title for the map configuration. :type title: str - :ivar station: The name of the station at which the map was collected. + :ivar station: The name of the station at which the map was + collected. :type station: Literal['id1a3','id3a','id3b'] :ivar spec_scans: A list of the spec scans that compose the map. :type spec_scans: list[SpecScans] - :ivar independent_dimensions: A list of the sources of data representing the - raw values of each independent dimension of the map. + :ivar independent_dimensions: A list of the sources of data + representing the raw values of each independent dimension of + the map. :type independent_dimensions: list[PointByPointScanData] - :ivar presample_intensity: A source of point-by-point presample beam - intensity data. Required when applying a CorrectionConfig tool. + :ivar presample_intensity: A source of point-by-point presample + beam intensity data. Required when applying a CorrectionConfig + tool. :type presample_intensity: Optional[PresampleIntensity] - :ivar dwell_time_actual: A source of point-by-point actual dwell times for - spec scans. Required when applying a CorrectionConfig tool. + :ivar dwell_time_actual: A source of point-by-point actual dwell + times for spec scans. Required when applying a + CorrectionConfig tool. :type dwell_time_actual: Optional[DwellTimeActual] - :ivar presample_intensity: A source of point-by-point postsample beam - intensity data. 
Required when applying a CorrectionConfig tool with - `correction_type="flux_absorption"` or + :ivar presample_intensity: A source of point-by-point postsample + beam intensity data. Required when applying a CorrectionConfig + tool with `correction_type="flux_absorption"` or `correction_type="flux_absorption_background"`. :type presample_intensity: Optional[PresampleIntensity] - :ivar scalar_data: A list of the sources of data representing other scalar - raw data values collected at each point ion the map. In the NeXus file - representation of the map, datasets for these values will be included. + :ivar scalar_data: A list of the sources of data representing + other scalar raw data values collected at each point ion the + map. In the NeXus file representation of the map, datasets for + these values will be included. :type scalar_values: Optional[list[PointByPointScanData]] """ title: constr(strip_whitespace=True, min_length=1) @@ -428,20 +539,37 @@ class MapConfig(BaseModel): experiment_type: Literal['SAXSWAXS', 'EDD', 'XRF', 'TOMO'] sample: Sample spec_scans: conlist(item_type=SpecScans, min_items=1) - independent_dimensions: conlist(item_type=PointByPointScanData, min_items=1) + independent_dimensions: conlist(item_type=PointByPointScanData, + min_items=1) presample_intensity: Optional[PresampleIntensity] dwell_time_actual: Optional[DwellTimeActual] postsample_intensity: Optional[PostsampleIntensity] scalar_data: Optional[list[PointByPointScanData]] = [] _coords: dict = PrivateAttr() - _validate_independent_dimensions = validator('independent_dimensions', each_item=True, allow_reuse=True)(validate_data_source_for_map_config) - _validate_presample_intensity = validator('presample_intensity', allow_reuse=True)(validate_data_source_for_map_config) - _validate_dwell_time_actual = validator('dwell_time_actual', allow_reuse=True)(validate_data_source_for_map_config) - _validate_postsample_intensity = validator('postsample_intensity', allow_reuse=True)(validate_data_source_for_map_config) - _validate_scalar_data = validator('scalar_data', each_item=True, allow_reuse=True)(validate_data_source_for_map_config) + + _validate_independent_dimensions = validator( + 'independent_dimensions', + each_item=True, + allow_reuse=True)(validate_data_source_for_map_config) + _validate_presample_intensity = validator( + 'presample_intensity', + allow_reuse=True)(validate_data_source_for_map_config) + _validate_dwell_time_actual = validator( + 'dwell_time_actual', + allow_reuse=True)(validate_data_source_for_map_config) + _validate_postsample_intensity = validator( + 'postsample_intensity', + allow_reuse=True)(validate_data_source_for_map_config) + _validate_scalar_data = validator( + 'scalar_data', + each_item=True, + allow_reuse=True)(validate_data_source_for_map_config) + @validator('experiment_type') def validate_experiment_type(cls, value, values): - '''Ensure values for the station and experiment_type fields are compatible''' + """Ensure values for the station and experiment_type fields are + compatible + """ station = values.get('station') if station == 'id1a3': allowed_experiment_types = ['SAXSWAXS', 'EDD', 'TOMO'] @@ -452,18 +580,22 @@ def validate_experiment_type(cls, value, values): else: allowed_experiment_types = [] if value not in allowed_experiment_types: - raise(ValueError(f'For station {station}, allowed experiment types are {allowed_experiment_types} (suuplied experiment type {value} is not allowed)')) - return(value) + raise ValueError( + f'For station {station}, allowed experiment types are 
' + + f'{", ".join(allowed_experiment_types)}. ' + + f'Supplied experiment type {value} is not allowed.') + return value + @property def coords(self): - """Return a dictionary of the values of each independent dimension across - the map. + """Return a dictionary of the values of each independent + dimension across the map. :returns: A dictionary ofthe map's coordinate values. :rtype: dict[str,list[float]] """ try: - return(self._coords) + coords = self._coords except: coords = {} for independent_dimension in self.independent_dimensions: @@ -471,65 +603,89 @@ def coords(self): for scans in self.spec_scans: for scan_number in scans.scan_numbers: scanparser = scans.get_scanparser(scan_number) - for scan_step_index in range(scanparser.spec_scan_npts): - coords[independent_dimension.label].append(independent_dimension.get_value(scans, scan_number, scan_step_index)) - coords[independent_dimension.label] = np.unique(coords[independent_dimension.label]) - self._coords = coords - return(self._coords) + for scan_step_index in range( + scanparser.spec_scan_npts): + coords[independent_dimension.label].append( + independent_dimension.get_value( + scans, scan_number, scan_step_index)) + coords[independent_dimension.label] = np.unique( + coords[independent_dimension.label]) + self._coords = coords + return coords + @property def dims(self): - """Return a tuple of the independent dimension labels for the map.""" - return([point_by_point_scan_data.label for point_by_point_scan_data in self.independent_dimensions[::-1]]) + """Return a tuple of the independent dimension labels for the + map. + """ + return [point_by_point_scan_data.label \ + for point_by_point_scan_data \ + in self.independent_dimensions[::-1]] + @property def shape(self): - """Return the shape of the map -- a tuple representing the number of - unique values of each dimension across the map. + """Return the shape of the map -- a tuple representing the + number of unique values of each dimension across the map. """ - return(tuple([len(values) for key,values in self.coords.items()][::-1])) + return tuple([len(values) for key,values in self.coords.items()][::-1]) + @property def all_scalar_data(self): - """Return a list of all instances of `PointByPointScanData` for which - this map configuration will collect dataset-like data (as opposed to - axes-like data). + """Return a list of all instances of `PointByPointScanData` + for which this map configuration will collect dataset-like + data (as opposed to axes-like data). + + This will be any and all of the items in the + corrections-data-related fields, as well as any additional + items in the optional `scalar_data` field. + """ + return [getattr(self,l,None) \ + for l in CorrectionsData.reserved_labels() \ + if getattr(self,l,None) is not None] \ + + self.scalar_data - This will be any and all of the items in the corrections-data-related - fields, as well as any additional items in the optional `scalar_data` - field.""" - return([getattr(self,l,None) for l in CorrectionsData.reserved_labels() if getattr(self,l,None) is not None] + self.scalar_data) def import_scanparser(station, experiment): - '''Given the name of a CHESS station and experiment type, import the - corresponding subclass of `ScanParser` as `ScanParser`. + """Given the name of a CHESS station and experiment type, import + the corresponding subclass of `ScanParser` as `ScanParser`. 
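The `coords`, `dims`, and `shape` properties above all apply a `[::-1]` reversal, so the last entry in `independent_dimensions` becomes the first map axis. A dependency-free sketch with hypothetical coordinate values:

coords = {'x': [0.0, 0.5, 1.0], 'y': [0.0, 1.0]}             # label -> unique values
dims = [label for label in coords][::-1]                     # ['y', 'x']
shape = tuple(len(vals) for vals in coords.values())[::-1]   # (2, 3)
print(dims, shape)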
:param station: The station name ("IDxx", not the beamline acronym) :type station: str :param experiment: The experiment type :type experiment: Literal["SAXSWAXS","EDD","XRF","Tomo","Powder"] :return: None - ''' + """ station = station.lower() experiment = experiment.lower() if station in ('id1a3', 'id3a'): if experiment in ('saxswaxs', 'powder'): - from CHAP.common.utils.scanparsers import SMBLinearScanParser as ScanParser + from CHAP.common.utils.scanparsers \ + import SMBLinearScanParser as ScanParser elif experiment == 'edd': - from CHAP.common.utils.scanparsers import SMBMCAScanParser as ScanParser + from CHAP.common.utils.scanparsers \ + import SMBMCAScanParser as ScanParser elif experiment == 'tomo': - from CHAP.common.utils.scanparsers import SMBRotationScanParser as ScanParser + from CHAP.common.utils.scanparsers \ + import SMBRotationScanParser as ScanParser else: - raise(ValueError(f'Invalid experiment type for station {station}: {experiment}')) + raise ValueError( + f'Invalid experiment type for station {station}: {experiment}') elif station == 'id3b': if experiment == 'saxswaxs': - from CHAP.common.utils.scanparsers import FMBSAXSWAXSScanParser as ScanParser + from CHAP.common.utils.scanparsers \ + import FMBSAXSWAXSScanParser as ScanParser elif experiment == 'tomo': - from CHAP.common.utils.scanparsers import FMBRotationScanParser as ScanParser + from CHAP.common.utils.scanparsers \ + import FMBRotationScanParser as ScanParser elif experiment == 'xrf': - from CHAP.common.utils.scanparsers import FMBXRFScanParser as ScanParser + from CHAP.common.utils.scanparsers \ + import FMBXRFScanParser as ScanParser else: - raise(ValueError(f'Invalid experiment type for station {station}: {experiment}')) + raise ValueError( + f'Invalid experiment type for station {station}: {experiment}') else: - raise(ValueError(f'Invalid station: {station}')) + raise ValueError(f'Invalid station: {station}') globals()['ScanParser'] = ScanParser diff --git a/CHAP/common/processor.py b/CHAP/common/processor.py index 5cecbc5..21c4f7c 100755 --- a/CHAP/common/processor.py +++ b/CHAP/common/processor.py @@ -1,128 +1,132 @@ #!/usr/bin/env python #-*- coding: utf-8 -*- #pylint: disable= -''' +""" File : processor.py Author : Valentin Kuznetsov Description: Module for Processors used in multiple experiment-specific workflows. -''' +""" # system modules -import argparse -import logging import json -import sys from time import time # local modules from CHAP import Processor + class AsyncProcessor(Processor): - '''A Processor to process multiple sets of input data via asyncio module + """A Processor to process multiple sets of input data via asyncio + module :ivar mgr: The `Processor` used to process every set of input data :type mgr: Processor - ''' + """ def __init__(self, mgr): super().__init__() self.mgr = mgr - def _process(self, docs): - '''Asynchronously process the input documents with the `self.mgr` - `Processor`. - - :param docs: input documents to process + def _process(self, data): + """Asynchronously process the input documents with the + `self.mgr` `Processor`. 
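A short usage sketch for `import_scanparser` above: the station/experiment pair selects a `ScanParser` subclass and injects it into this module's namespace, so callers access it through the module (assumed here to be `CHAP.common.models.map`, where this diff places the function); the SPEC file path is hypothetical:

from CHAP.common.models import map as map_models

map_models.import_scanparser('id3b', 'SAXSWAXS')   # selects FMBSAXSWAXSScanParser
parser = map_models.ScanParser('/path/to/sample.spec', 1)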
+ + :param data: input data documents to process :type docs: iterable - ''' + """ import asyncio async def task(mgr, doc): - '''Process given data using provided `Processor` - + """Process given data using provided `Processor` + :param mgr: the object that will process given data :type mgr: Processor :param doc: the data to process :type doc: object :return: processed data :rtype: object - ''' + """ return mgr.process(doc) - async def executeTasks(mgr, docs): - '''Process given set of documents using provided task manager - + async def execute_tasks(mgr, docs): + """Process given set of documents using provided task + manager + :param mgr: the object that will process all documents :type mgr: Processor :param docs: the set of data documents to process :type doc: iterable - ''' - coRoutines = [task(mgr, d) for d in docs] - await asyncio.gather(*coRoutines) + """ + coroutines = [task(mgr, d) for d in docs] + await asyncio.gather(*coroutines) + + asyncio.run(execute_tasks(self.mgr, data)) - asyncio.run(executeTasks(self.mgr, docs)) class IntegrationProcessor(Processor): - '''A processor for integrating 2D data with pyFAI - ''' + """A processor for integrating 2D data with pyFAI""" def _process(self, data): - '''Integrate the input data with the integration method and keyword - arguments supplied and return the results. - - :param data: input data, including raw data, integration method, and - keyword args for the integration method. - :type data: tuple[typing.Union[numpy.ndarray, list[numpy.ndarray]], - callable, - dict] + """Integrate the input data with the integration method and + keyword arguments supplied and return the results. + + :param data: input data, including raw data, integration + method, and keyword args for the integration method. + :type data: tuple[typing.Union[numpy.ndarray, + list[numpy.ndarray]], callable, dict] :param integration_method: the method of a `pyFAI.azimuthalIntegrator.AzimuthalIntegrator` or - `pyFAI.multi_geometry.MultiGeometry` that returns the desired - integration results. + `pyFAI.multi_geometry.MultiGeometry` that returns the + desired integration results. :return: integrated raw data :rtype: pyFAI.containers.IntegrateResult - ''' - + """ detector_data, integration_method, integration_kwargs = data - return(integration_method(detector_data, **integration_kwargs)) + return integration_method(detector_data, **integration_kwargs) + class IntegrateMapProcessor(Processor): - '''Class representing a process that takes a map and integration - configuration and returns a `nexusformat.nexus.NXprocess` containing a map of - the integrated detector data requested. - ''' + """Class representing a process that takes a map and integration + configuration and returns a `nexusformat.nexus.NXprocess` + containing a map of the integrated detector data requested. + """ def _process(self, data): - '''Process the output of a `Reader` that contains a map and integration - configuration and return a `nexusformat.nexus.NXprocess` containing a map - of the integrated detector data requested - - :param data: Result of `Reader.read` where at least one item has the - value `'MapConfig'` for the `'schema'` key, and at least one item has - the value `'IntegrationConfig'` for the `'schema'` key. 
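`AsyncProcessor._process` above fans the input documents out with `asyncio.gather`. The same pattern reduced to a dependency-free sketch (the handler and inputs are hypothetical, and unlike `_process` this version keeps the gathered results):

import asyncio

async def handle(doc):
    return doc * 2                          # stand-in for mgr.process(doc)

async def run_all(docs):
    return await asyncio.gather(*(handle(d) for d in docs))

print(asyncio.run(run_all([1, 2, 3])))      # [2, 4, 6]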
+ """Process the output of a `Reader` that contains a map and + integration configuration and return a + `nexusformat.nexus.NXprocess` containing a map of the + integrated detector data requested + + :param data: Result of `Reader.read` where at least one item + has the value `'MapConfig'` for the `'schema'` key, and at + least one item has the value `'IntegrationConfig'` for the + `'schema'` key. :type data: list[dict[str,object]] :return: integrated data and process metadata :rtype: nexusformat.nexus.NXprocess - ''' + """ map_config, integration_config = self.get_configs(data) nxprocess = self.get_nxprocess(map_config, integration_config) - return(nxprocess) + return nxprocess def get_configs(self, data): - '''Return valid instances of `MapConfig` and `IntegrationConfig` from the - input supplied by `MultipleReader`. - - :param data: Result of `Reader.read` where at least one item has the - value `'MapConfig'` for the `'schema'` key, and at least one item has - the value `'IntegrationConfig'` for the `'schema'` key. + """Return valid instances of `MapConfig` and + `IntegrationConfig` from the input supplied by + `MultipleReader`. + + :param data: Result of `Reader.read` where at least one item + has the value `'MapConfig'` for the `'schema'` key, and at + least one item has the value `'IntegrationConfig'` for the + `'schema'` key. :type data: list[dict[str,object]] - :raises ValueError: if `data` cannot be parsed into map and integration configurations. + :raises ValueError: if `data` cannot be parsed into map and + integration configurations. :return: valid map and integration configuration objects. :rtype: tuple[MapConfig, IntegrationConfig] - ''' + """ self.logger.debug('Getting configuration objects') t0 = time() @@ -142,29 +146,30 @@ def get_configs(self, data): integration_config = item.get('data') if not map_config: - raise(ValueError('No map configuration found')) + raise ValueError('No map configuration found') if not integration_config: - raise(ValueError('No integration configuration found')) + raise ValueError('No integration configuration found') map_config = MapConfig(**map_config) integration_config = IntegrationConfig(**integration_config) - self.logger.debug(f'Got configuration objects in {time()-t0:.3f} seconds') + self.logger.debug('Got configuration objects in ' + + f'{time()-t0:.3f} seconds') - return(map_config, integration_config) + return map_config, integration_config def get_nxprocess(self, map_config, integration_config): - '''Use a `MapConfig` and `IntegrationConfig` to construct a + """Use a `MapConfig` and `IntegrationConfig` to construct a `nexusformat.nexus.NXprocess` :param map_config: a valid map configuration :type map_config: MapConfig :param integration_config: a valid integration configuration :type integration_config: IntegrationConfig - :return: the integrated detector data and metadata contained in a NeXus - structure + :return: the integrated detector data and metadata contained + in a NeXus structure :rtype: nexusformat.nexus.NXprocess - ''' + """ self.logger.debug('Constructing NXprocess') t0 = time() @@ -199,7 +204,8 @@ def get_nxprocess(self, map_config, integration_config): nxdetector.calibration_wavelength.attrs['units'] = 'm' nxdetector.attrs['poni_file'] = str(detector.poni_file) nxdetector.attrs['mask_file'] = str(detector.mask_file) - nxdetector.raw_data_files = np.full(map_config.shape, '', dtype='|S256') + nxdetector.raw_data_files = np.full(map_config.shape, + '', dtype='|S256') nxprocess.data = NXdata() @@ -226,7 +232,8 @@ def 
get_nxprocess(self, map_config, integration_config): getattr(integration_config, f'{coord_name}_units'), type_=type_) nxprocess.data[coord_units.name] = coord_values - nxprocess.data.attrs[f'{coord_units.name}_indices'] = i+len(map_config.coords) + nxprocess.data.attrs[f'{coord_units.name}_indices'] = i + len( + map_config.coords) nxprocess.data[coord_units.name].units = coord_units.unit_symbol nxprocess.data[coord_units.name].attrs['long_name'] = coord_units.label @@ -234,7 +241,9 @@ def get_nxprocess(self, map_config, integration_config): nxprocess.data.I = NXfield( value=np.empty( (*tuple( - [len(coord_values) for coord_name,coord_values in map_config.coords.items()][::-1] + [len(coord_values) \ + for coord_name,coord_values \ + in map_config.coords.items()][::-1] ), *integration_config.integrated_data_shape ) @@ -246,13 +255,17 @@ def get_nxprocess(self, map_config, integration_config): if integration_config.integration_type == 'azimuthal': integration_method = integrator.integrate1d integration_kwargs = { - 'lst_mask': [detector.mask_array for detector in integration_config.detectors], + 'lst_mask': [detector.mask_array \ + for detector \ + in integration_config.detectors], 'npt': integration_config.radial_npt } elif integration_config.integration_type == 'cake': integration_method = integrator.integrate2d integration_kwargs = { - 'lst_mask': [detector.mask_array for detector in integration_config.detectors], + 'lst_mask': [detector.mask_array \ + for detector \ + in integration_config.detectors], 'npt_rad': integration_config.radial_npt, 'npt_azim': integration_config.azimuthal_npt, 'method': 'bbox' @@ -261,7 +274,6 @@ def get_nxprocess(self, map_config, integration_config): integration_processor = IntegrationProcessor() integration_processor.logger.setLevel(self.logger.getEffectiveLevel()) integration_processor.logger.addHandler(self.logger.handlers[0]) - lst_args = [] for scans in map_config.spec_scans: for scan_number in scans.scan_numbers: scanparser = scans.get_scanparser(scan_number) @@ -286,42 +298,48 @@ def get_nxprocess(self, map_config, integration_config): self.logger.debug(f'Constructed NXprocess in {time()-t0:.3f} seconds') - return(nxprocess) + return nxprocess + class MapProcessor(Processor): - '''A Processor to take a map configuration and return a - `nexusformat.nexus.NXentry` representing that map's metadata and any - scalar-valued raw data requseted by the supplied map configuration. - ''' + """A Processor to take a map configuration and return a + `nexusformat.nexus.NXentry` representing that map's metadata and + any scalar-valued raw data requseted by the supplied map + configuration. + """ def _process(self, data): - '''Process the output of a `Reader` that contains a map configuration and - return a `nexusformat.nexus.NXentry` representing the map. + """Process the output of a `Reader` that contains a map + configuration and return a `nexusformat.nexus.NXentry` + representing the map. - :param data: Result of `Reader.read` where at least one item has the - value `'MapConfig'` for the `'schema'` key. + :param data: Result of `Reader.read` where at least one item + has the value `'MapConfig'` for the `'schema'` key. 
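`IntegrationProcessor._process`, defined earlier in this module, unpacks its input as a `(detector_data, integration_method, integration_kwargs)` tuple and simply calls the method. A sketch of that contract with stand-in values (the real method would be something like `integrator.integrate1d`):

import numpy as np

frames = [np.zeros((4, 4))]                 # stand-in detector frames

def integrate(data, npt=100):               # stand-in integration method
    return f'{len(data)} frame(s), {npt} radial points'

detector_data, integration_method, integration_kwargs = (
    frames, integrate, {'npt': 1800})
print(integration_method(detector_data, **integration_kwargs))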
:type data: list[dict[str,object]] :return: Map data & metadata :rtype: nexusformat.nexus.NXentry - ''' + """ map_config = self.get_map_config(data) nxentry = self.__class__.get_nxentry(map_config) - return(nxentry) + return nxentry def get_map_config(self, data): - '''Get an instance of `MapConfig` from a returned value of `Reader.read` + """Get an instance of `MapConfig` from a returned value of + `Reader.read` - :param data: Result of `Reader.read` where at least one item has the - value `'MapConfig'` for the `'schema'` key. + :param data: Result of `Reader.read` where at least one item + has the value `'MapConfig'` for the `'schema'` key. :type data: list[dict[str,object]] - :raises Exception: If a valid `MapConfig` cannot be constructed from `data`. - :return: a valid instance of `MapConfig` with field values taken from `data`. + :raises Exception: If a valid `MapConfig` cannot be + constructed from `data`. + :return: a valid instance of `MapConfig` with field values + taken from `data`. :rtype: MapConfig - ''' + """ - from .models.map import MapConfig + from CHAP.common.models.map import MapConfig map_config = False if isinstance(data, list): @@ -332,19 +350,21 @@ def get_map_config(self, data): break if not map_config: - raise(ValueError('No map configuration found')) + raise ValueError('No map configuration found') - return(MapConfig(**map_config)) + return MapConfig(**map_config) @staticmethod def get_nxentry(map_config): - '''Use a `MapConfig` to construct a `nexusformat.nexus.NXentry` + """Use a `MapConfig` to construct a + `nexusformat.nexus.NXentry` :param map_config: a valid map configuration :type map_config: MapConfig - :return: the map's data and metadata contained in a NeXus structure + :return: the map's data and metadata contained in a NeXus + structure :rtype: nexusformat.nexus.NXentry - ''' + """ from nexusformat.nexus import (NXcollection, NXdata, @@ -411,23 +431,25 @@ def get_nxentry(map_config): scan_number, scan_step_index) - return(nxentry) + return nxentry + class NexusToNumpyProcessor(Processor): - '''A Processor to convert the default plottable data in an `NXobject` into - an `numpy.ndarray`. - ''' + """A Processor to convert the default plottable data in an + `NXobject` into an `numpy.ndarray`. + """ def _process(self, data): - '''Return the default plottable data signal in `data` as an + """Return the default plottable data signal in `data` as an `numpy.ndarray`. 
- + :param data: input NeXus structure :type data: nexusformat.nexus.tree.NXobject - :raises ValueError: if `data` has no default plottable data signal + :raises ValueError: if `data` has no default plottable data + signal :return: default plottable data signal in `data` :rtype: numpy.ndarray - ''' + """ default_data = data.plottable_data @@ -435,31 +457,35 @@ def _process(self, data): default_data_path = data.attrs['default'] default_data = data.get(default_data_path) if default_data is None: - raise(ValueError(f'The structure of {data} contains no default data')) + raise ValueError( + f'The structure of {data} contains no default data') default_signal = default_data.attrs.get('signal') if default_signal is None: - raise(ValueError(f'The signal of {default_data} is unknown')) + raise ValueError(f'The signal of {default_data} is unknown') default_signal = default_signal.nxdata np_data = default_data[default_signal].nxdata - return(np_data) + return np_data + class NexusToXarrayProcessor(Processor): - '''A Processor to convert the default plottable data in an `NXobject` into - an `xarray.DataArray`.''' + """A Processor to convert the default plottable data in an + `NXobject` into an `xarray.DataArray`. + """ def _process(self, data): - '''Return the default plottable data signal in `data` as an + """Return the default plottable data signal in `data` as an `xarray.DataArray`. - + :param data: input NeXus structure :type data: nexusformat.nexus.tree.NXobject - :raises ValueError: if metadata for `xarray` is absent from `data` + :raises ValueError: if metadata for `xarray` is absent from + `data` :return: default plottable data signal in `data` :rtype: xarray.DataArray - ''' + """ from xarray import DataArray @@ -469,11 +495,12 @@ def _process(self, data): default_data_path = data.attrs['default'] default_data = data.get(default_data_path) if default_data is None: - raise(ValueError(f'The structure of {data} contains no default data')) + raise ValueError( + f'The structure of {data} contains no default data') default_signal = default_data.attrs.get('signal') if default_signal is None: - raise(ValueError(f'The signal of {default_data} is unknown')) + raise ValueError(f'The signal of {default_data} is unknown') default_signal = default_signal.nxdata signal_data = default_data[default_signal].nxdata @@ -487,129 +514,138 @@ def _process(self, data): axis.attrs) dims = tuple(axes) - name = default_signal - attrs = default_data[default_signal].attrs - return(DataArray(data=signal_data, + return DataArray(data=signal_data, coords=coords, dims=dims, name=name, - attrs=attrs)) + attrs=attrs) + class PrintProcessor(Processor): - '''A Processor to simply print the input data to stdout and return the - original input data, unchanged in any way. - ''' + """A Processor to simply print the input data to stdout and return + the original input data, unchanged in any way. + """ def _process(self, data): - '''Print and return the input data. + """Print and return the input data. 
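`NexusToXarrayProcessor` above rebuilds an `xarray.DataArray` from the signal, coordinates, and attributes it pulls out of the NeXus tree, assembling each coordinate as a `(dims, data, attrs)` tuple. A self-contained sketch of that constructor call with hypothetical values:

import numpy as np
from xarray import DataArray

signal_data = np.arange(5.0)
coords = {'x': ('x', np.linspace(0.0, 1.0, 5), {'units': 'mm'})}
da = DataArray(data=signal_data,
               coords=coords,
               dims=('x',),
               name='I',
               attrs={'long_name': 'intensity'})
print(da)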
:param data: Input data :type data: object :return: `data` :rtype: object - ''' + """ print(f'{self.__name__} data :') if callable(getattr(data, '_str_tree', None)): - # If data is likely an NXobject, print its tree representation - # (since NXobjects' str representations are just their nxname) + # If data is likely an NXobject, print its tree + # representation (since NXobjects' str representations are + # just their nxname) print(data._str_tree(attrs=True, recursive=True)) else: print(str(data)) - return(data) + return data + class StrainAnalysisProcessor(Processor): - '''A Processor to compute a map of sample strains by fitting bragg peaks in - 1D detector data and analyzing the difference between measured peak - locations and expected peak locations for the sample measured. - ''' + """A Processor to compute a map of sample strains by fitting bragg + peaks in 1D detector data and analyzing the difference between + measured peak locations and expected peak locations for the sample + measured. + """ def _process(self, data): - '''Process the input map detector data & configuration for the strain - analysis procedure, and return a map of sample strains. + """Process the input map detector data & configuration for the + strain analysis procedure, and return a map of sample strains. - :param data: results of `MutlipleReader.read` containing input map - detector data and strain analysis configuration + :param data: results of `MutlipleReader.read` containing input + map detector data and strain analysis configuration :type data: dict[list[str,object]] :return: map of sample strains :rtype: xarray.Dataset - ''' + """ strain_analysis_config = self.get_config(data) - return(data) + return data def get_config(self, data): - '''Get instances of the configuration objects needed by this + """Get instances of the configuration objects needed by this `Processor` from a returned value of `Reader.read` - :param data: Result of `Reader.read` where at least one item has the - value `'StrainAnalysisConfig'` for the `'schema'` key. + :param data: Result of `Reader.read` where at least one item + has the value `'StrainAnalysisConfig'` for the `'schema'` + key. :type data: list[dict[str,object]] - :raises Exception: If valid config objects cannot be constructed from `data`. - :return: valid instances of the configuration objects with field values - taken from `data`. + :raises Exception: If valid config objects cannot be + constructed from `data`. + :return: valid instances of the configuration objects with + field values taken from `data`. :rtype: StrainAnalysisConfig - ''' + """ strain_analysis_config = False if isinstance(data, list): for item in data: if isinstance(item, dict): - schema = item.get('schema') if item.get('schema') == 'StrainAnalysisConfig': strain_analysis_config = item.get('data') if not strain_analysis_config: - raise(ValueError('No strain analysis configuration found in input data')) + raise ValueError( + 'No strain analysis configuration found in input data') - return(strain_analysis_config) + return strain_analysis_config class URLResponseProcessor(Processor): - '''A Processor to decode and return data resulting from from URLReader.read''' + """A Processor to decode and return data resulting from from + URLReader.read + """ + def _process(self, data): - '''Take data returned from URLReader.read and return a decoded version - of the content. + """Take data returned from URLReader.read and return a decoded + version of the content. 
:param data: input data (output of URLReader.read) :type data: list[dict] :return: decoded data contents :rtype: object - ''' + """ data = data[0] content = data['data'] encoding = data['encoding'] - self.logger.debug(f'Decoding content of type {type(content)} with {encoding}') + self.logger.debug(f'Decoding content of type {type(content)} ' + + f'with {encoding}') try: content = content.decode(encoding) except: - self.logger.warning(f'Failed to decode content of type {type(content)} with {encoding}') + self.logger.warning('Failed to decode content of type ' + + f'{type(content)} with {encoding}') - return(content) + return content class XarrayToNexusProcessor(Processor): - '''A Processor to convert the data in an `xarray` structure to an + """A Processor to convert the data in an `xarray` structure to an `nexusformat.nexus.NXdata`. - ''' + """ def _process(self, data): - '''Return `data` represented as an `nexusformat.nexus.NXdata`. + """Return `data` represented as an `nexusformat.nexus.NXdata`. :param data: The input `xarray` structure :type data: typing.Union[xarray.DataArray, xarray.Dataset] :return: The data and metadata in `data` :rtype: nexusformat.nexus.NXdata - ''' + """ from nexusformat.nexus import NXdata, NXfield @@ -617,28 +653,28 @@ def _process(self, data): axes = [] for name, coord in data.coords.items(): - axes.append(NXfield(value=coord.data, name=name, attrs=coord.attrs)) + axes.append( + NXfield(value=coord.data, name=name, attrs=coord.attrs)) axes = tuple(axes) - return(NXdata(signal=signal, axes=axes)) + return NXdata(signal=signal, axes=axes) class XarrayToNumpyProcessor(Processor): - '''A Processor to convert the data in an `xarray.DataArray` structure to an - `numpy.ndarray`. - ''' + """A Processor to convert the data in an `xarray.DataArray` + structure to an `numpy.ndarray`. + """ def _process(self, data): - '''Return just the signal values contained in `data`. - + """Return just the signal values contained in `data`. + :param data: The input `xarray.DataArray` :type data: xarray.DataArray :return: The data in `data` :rtype: numpy.ndarray - ''' + """ - return(data.data) + return data.data if __name__ == '__main__': from CHAP.processor import main main() - diff --git a/CHAP/common/reader.py b/CHAP/common/reader.py index 083c39e..d2ff54b 100755 --- a/CHAP/common/reader.py +++ b/CHAP/common/reader.py @@ -1,44 +1,47 @@ #!/usr/bin/env python -''' +""" File : reader.py Author : Valentin Kuznetsov -Description: Module for Writers used in multiple experiment-specific workflows. -''' +Description: Module for Writers used in multiple experiment-specific + workflows. +""" # system modules -import argparse -import json -import logging import sys from time import time # local modules from CHAP import Reader + class BinaryFileReader(Reader): + """Reader for binary files""" def _read(self, filename): - '''Return a content of a given file name + """Return a content of a given file name :param filename: name of the binart file to read from :return: the content of `filename` :rtype: binary - ''' + """ + with open(filename, 'rb') as file: data = file.read() - return(data) + return data + class MultipleReader(Reader): - def read(self, readers): - '''Return resuts from multiple `Reader`s. + """Reader to deliver combined results from other Readers""" + def read(self, readers, **_read_kwargs): + """Return resuts from multiple `Reader`s. 
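`XarrayToNexusProcessor` above performs the reverse conversion, wrapping the DataArray's values and coordinates in `NXfield`s and assembling an `NXdata`. A matching sketch with hypothetical data (assumes `nexusformat` and `xarray` are installed):

import numpy as np
import xarray as xr
from nexusformat.nexus import NXdata, NXfield

da = xr.DataArray(np.arange(5.0),
                  coords={'x': np.linspace(0.0, 1.0, 5)},
                  dims=('x',), name='I')
signal = NXfield(value=da.data, name=da.name, attrs=da.attrs)
axes = tuple(NXfield(value=coord.data, name=name, attrs=coord.attrs)
             for name, coord in da.coords.items())
nxdata = NXdata(signal=signal, axes=axes)
print(nxdata.tree)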
- :param readers: a dictionary where the keys are specific names that are - used by the next item in the `Pipeline`, and the values are `Reader` - configurations. + :param readers: a dictionary where the keys are specific names + that are used by the next item in the `Pipeline`, and the + values are `Reader` configurations. :type readers: list[dict] - :return: The results of calling `Reader.read(**kwargs)` for each item - configured in `readers`. + :return: The results of calling `Reader.read(**kwargs)` for + each item configured in `readers`. :rtype: list[dict[str,object]] - ''' + """ t0 = time() self.logger.info(f'Executing "read" with {len(readers)} Readers') @@ -50,69 +53,81 @@ def read(self, readers): reader = reader_class() reader_kwargs = reader_config[reader_name] - data.extend(reader.read(**reader_kwargs)) + # Combine keyword arguments to MultipleReader.read with + # those to the reader giving precedence to those in the + # latter + combined_kwargs = {**_read_kwargs, **reader_kwargs} + data.extend(reader.read(**combined_kwargs)) self.logger.info(f'Finished "read" in {time()-t0:.3f} seconds\n') - return(data) + return data + class NexusReader(Reader): + """Reader for NeXus files""" def _read(self, filename, nxpath='/'): - '''Return the NeXus object stored at `nxpath` in the nexus file - `filename`. + """Return the NeXus object stored at `nxpath` in the nexus + file `filename`. :param filename: name of the NeXus file to read from :type filename: str - :param nxpath: path to a specific loaction in the NeXus file to read - from, defaults to `'/'` + :param nxpath: path to a specific loaction in the NeXus file + to read from, defaults to `'/'` :type nxpath: str, optional - :raises nexusformat.nexus.NeXusError: if `filename` is not a NeXus - file or `nxpath` is not in `filename`. + :raises nexusformat.nexus.NeXusError: if `filename` is not a + NeXus file or `nxpath` is not in `filename`. :return: the NeXus structure indicated by `filename` and `nxpath`. :rtype: nexusformat.nexus.NXobject - ''' + """ from nexusformat.nexus import nxload nxobject = nxload(filename)[nxpath] - return(nxobject) + return nxobject + class URLReader(Reader): - def _read(self, url, headers={}): - '''Make an HTTPS request to the provided URL and return the results. - Headers for the request are optional. + """Reader for data available over HTTPS""" + def _read(self, url, headers={}, timeout=10): + """Make an HTTPS request to the provided URL and return the + results. Headers for the request are optional. :param url: the URL to read :type url: str - :param headers: headers to attach to the request, defaults to `{}` + :param headers: headers to attach to the request, defaults to + `{}` :type headers: dict, optional :return: the content of the response :rtype: object - ''' + """ import requests - resp = requests.get(url, headers=headers) + resp = requests.get(url, headers=headers, timeout=timeout) data = resp.content self.logger.debug(f'Response content: {data}') - return(data) + return data + class YAMLReader(Reader): + """Reader for YAML files""" def _read(self, filename): - '''Return a dictionary from the contents of a yaml file. + """Return a dictionary from the contents of a yaml file. 
:param filename: name of the YAML file to read from :return: the contents of `filename` :rtype: dict - ''' + """ import yaml with open(filename) as file: data = yaml.safe_load(file) - return(data) + return data + if __name__ == '__main__': from CHAP.reader import main diff --git a/CHAP/common/utils/scanparsers.py b/CHAP/common/utils/scanparsers.py index 46e5e06..5e3cc96 100755 --- a/CHAP/common/utils/scanparsers.py +++ b/CHAP/common/utils/scanparsers.py @@ -5,27 +5,38 @@ # system modules import csv import fnmatch -from functools import cache import json import os import re -# necessary for the base class, ScanParser: +# other modules import numpy as np -from pyspec.file.spec import FileSpec +from pyspec.file.spec import FileSpec + + +class ScanParser: + """Partial implementation of a class representing a SPEC scan and + some of its metadata. + + :param spec_file_name: path to a SPEC file on the CLASSE DAQ + :type spec_file_name: str + :param scan_number: the number of a scan in the SPEC file provided + with `spec_file_name` + :type scan_number: int + """ -class ScanParser(object): def __init__(self, spec_file_name:str, scan_number:int): + """Constructor method""" self.spec_file_name = spec_file_name self.scan_number = scan_number - + self._scan_path = None self._scan_name = None self._scan_title = None - + self._spec_scan = None self._spec_command = None self._spec_macro = None @@ -33,18 +44,21 @@ def __init__(self, self._spec_scan_npts = None self._spec_scan_data = None self._spec_positioner_values = None - + self._detector_data_path = None - + def __repr__(self): - return(f'{self.__class__.__name__}({self.spec_file_name}, {self.scan_number}) -- {self.spec_command}') - + return (f'{self.__class__.__name__}(' + + f'{self.spec_file_name}, ' + + f'{self.scan_number}) ' + + f'-- {self.spec_command}') + @property def spec_file(self): - # NB This FileSpec instance is not stored as a private attribute because - # it cannot be pickled (and therefore could cause problems for - # parallel code that uses ScanParsers). - return(FileSpec(self.spec_file_name)) + # NB This FileSpec instance is not stored as a private + # attribute because it cannot be pickled (and therefore could + # cause problems for parallel code that uses ScanParsers). + return FileSpec(self.spec_file_name) @property def scan_path(self): if self._scan_path is None: @@ -100,26 +114,82 @@ def detector_data_path(self): if self._detector_data_path is None: self._detector_data_path = self.get_detector_data_path() return self._detector_data_path - + def get_scan_path(self): - return(os.path.dirname(self.spec_file_name)) + """Return the name of the directory containining the SPEC file + for this scan. + + :rtype: str + """ + return os.path.dirname(self.spec_file_name) + def get_scan_name(self): - return(None) + """Return the name of this SPEC scan (not unique to scans + within a single spec file). + + :rtype: str + """ + raise NotImplementedError + def get_scan_title(self): - return(None) + """Return the title of this spec scan (unique to each scan + within a spec file). + + :rtype: str + """ + raise NotImplementedError + def get_spec_scan(self): - return(self.spec_file.getScanByNumber(self.scan_number)) + """Return the `pyspec.file.spec.Scan` object parsed from the + spec file and scan number provided to the constructor. + + :rtype: pyspec.file.spec.Scan + """ + return self.spec_file.getScanByNumber(self.scan_number) + def get_spec_command(self): - return(self.spec_scan.command) + """Return the string command of this SPEC scan. 
+ + :rtype: str + """ + return self.spec_scan.command + def get_spec_macro(self): - return(self.spec_command.split()[0]) + """Return the macro used in this scan's SPEC command. + + :rtype: str + """ + return self.spec_command.split()[0] + def get_spec_args(self): - return(self.spec_command.split()[1:]) + """Return a list of the arguments provided to the macro for + this SPEC scan. + + :rtype: list[str] + """ + return self.spec_command.split()[1:] + def get_spec_scan_npts(self): - raise(NotImplementedError) + """Return the number of points collected in this SPEC scan + + :rtype: int + """ + raise NotImplementedError + def get_spec_scan_data(self): - return(dict(zip(self.spec_scan.labels, self.spec_scan.data.T))) + """Return a dictionary of all the counter data collected by + this SPEC scan. + + :rtype: dict[str, numpy.ndarray] + """ + return dict(zip(self.spec_scan.labels, self.spec_scan.data.T)) + def get_spec_positioner_values(self): + """Return a dictionary of all the SPEC positioner values + recorded by SPEC just before the scan began. + + :rtype: dict[str,str] + """ positioner_values = dict(self.spec_scan.motor_positions) names = list(positioner_values.keys()) mnemonics = self.spec_scan.motors @@ -127,73 +197,118 @@ def get_spec_positioner_values(self): for name,mnemonic in zip(names,mnemonics): if name != mnemonic: positioner_values[mnemonic] = positioner_values[name] - return(positioner_values) + return positioner_values + def get_detector_data_path(self): - raise(NotImplementedError) - - def get_detector_data_file(self, detector_prefix, scan_step_index:int): - raise(NotImplementedError) - def get_detector_data(self, detector_prefix, scan_step_index:int): - ''' - Return a np.ndarray of detector data. + """Return the name of the directory containing detector data + collected by this scan. + + :rtype: str + """ + raise NotImplementedError - :param detector_prefix: The detector's name in any data files, often - the EPICS macro $(P). - :type detector_substring: str + def get_detector_data_file(self, detector_prefix, scan_step_index:int): + """Return the name of the file containing detector data + collected at a certain step of this scan. - :param scan_step_index: The index of the scan step for which detector - data will be returned. + :param detector_prefix: the prefix used in filenames for the + detector + :type detector_prefix: str + :param scan_step_index: the index of the point in this scan + whose detector file name should be returned. :type scan_step_index: int + :rtype: str + """ + raise NotImplementedError - :return: The detector data - :rtype: np.ndarray - ''' - raise(NotImplementedError) + def get_detector_data(self, detector_prefix, scan_step_index:int): + """Return the detector data collected at a certain step of + this scan. + + :param detector_prefix: the prefix used in filenames for the + detector + :type detector_prefix: str + :param scan_step_index: the index of the point in this scan + whose detector data should be returned. + :type scan_step_index: int + :rtype: numpy.ndarray + """ + raise NotImplementedError def get_spec_positioner_value(self, positioner_name): + """Return the value of a spec positioner recorded before this + scan began. + + :param positioner_name: the name or mnemonic of a SPEC motor + whose position should be returned. + :raises KeyError: if `positioner_name` is not the name or + mnemonic of a SPEC motor recorded for this scan. + :raises ValueError: if the recorded string value of the + positioner in the SPEC file cannot be converted to a + float. 
+ :rtype: float + """ try: positioner_value = self.spec_positioner_values[positioner_name] positioner_value = float(positioner_value) - return(positioner_value) except KeyError: - raise(KeyError(f'{self.scan_title}: motor {positioner_name} not found for this scan')) + raise KeyError(f'{self.scan_title}: motor {positioner_name} ' + + 'not found for this scan') except ValueError: - raise(ValueError(f'{self.scan_title}: ccould not convert value of {positioner_name} to float: {positioner_value}')) + raise ValueError(f'{self.scan_title}: could not convert value of' + + f' {positioner_name} to float: ' + + f'{positioner_value}') + return positioner_value class FMBScanParser(ScanParser): - def __init__(self, spec_file_name, scan_number): - super().__init__(spec_file_name, scan_number) + """Partial implementation of a class representing a SPEC scan + collected at FMB. + """ + def get_scan_name(self): - return(os.path.basename(self.spec_file.abspath)) - def get_scan_title(self): - return(f'{self.scan_name}_{self.scan_number:03d}') + return os.path.basename(self.spec_file.abspath) + def get_scan_title(self): + return f'{self.scan_name}_{self.scan_number:03d}' class SMBScanParser(ScanParser): + """Partial implementation of a class representing a SPEC scan + collected at SMB or FAST. + """ + def __init__(self, spec_file_name, scan_number): super().__init__(spec_file_name, scan_number) - - self._pars = None # purpose: store values found in the .par file as a dictionary + + self._pars = None self.par_file_pattern = f'*-*-{self.scan_name}' - + def get_scan_name(self): - return(os.path.basename(self.scan_path)) + return os.path.basename(self.scan_path) + def get_scan_title(self): - return(f'{self.scan_name}_{self.scan_number}') - + return f'{self.scan_name}_{self.scan_number}' + @property def pars(self): if self._pars is None: self._pars = self.get_pars() - return(self._pars) - + return self._pars + def get_pars(self): + """Return a dictionary of values recorded in the .par file + associated with this SPEC scan. 
+ + :rtype: dict[str,object] + """ # JSON file holds titles for columns in the par file - json_files = fnmatch.filter(os.listdir(self.scan_path), f'{self.par_file_pattern}.json') - if not len(json_files) == 1: - raise(RuntimeError(f'{self.scan_title}: cannot find the .json file to decode the .par file')) + json_files = fnmatch.filter( + os.listdir(self.scan_path), + f'{self.par_file_pattern}.json') + if len(json_files) != 1: + raise RuntimeError(f'{self.scan_title}: cannot find the ' + + '.json file to decode the .par file') with open(os.path.join(self.scan_path, json_files[0])) as json_file: par_file_cols = json.load(json_file) try: @@ -201,11 +316,16 @@ def get_pars(self): scann_val_idx = par_col_names.index('SCAN_N') scann_col_idx = int(list(par_file_cols.keys())[scann_val_idx]) except: - raise(RuntimeError(f'{self.scan_title}: cannot find scan pars without a "SCAN_N" column in the par file')) - - par_files = fnmatch.filter(os.listdir(self.scan_path), f'{self.par_file_pattern}.par') - if not len(par_files) == 1: - raise(RuntimeError(f'{self.scan_title}: cannot find the .par file for this scan directory')) + raise RuntimeError(f'{self.scan_title}: cannot find scan pars ' + + 'without a "SCAN_N" column in the par file') + + par_files = fnmatch.filter( + os.listdir(self.scan_path), + f'{self.par_file_pattern}.par') + if len(par_files) != 1: + raise RuntimeError(f'{self.scan_title}: cannot find the .par ' + + 'file for this scan directory') + par_dict = None with open(os.path.join(self.scan_path, par_files[0])) as par_file: par_reader = csv.reader(par_file, delimiter=' ') for row in par_reader: @@ -214,7 +334,8 @@ def get_pars(self): if row_scann == self.scan_number: par_dict = {} for par_col_idx,par_col_name in par_file_cols.items(): - # Convert the string par value from the file to an int or float, if possible. + # Convert the string par value from the + # file to an int or float, if possible. par_value = row[int(par_col_idx)] try: par_value = int(par_value) @@ -224,27 +345,50 @@ def get_pars(self): except: pass par_dict[par_col_name] = par_value - return(par_dict) - raise(RuntimeError(f'{self.scan_title}: could not find scan pars for scan number {self.scan_number}')) - + + if par_dict is None: + raise RuntimeError(f'{self.scan_title}: could not find scan pars ' + + 'for scan number {self.scan_number}') + return par_dict + def get_counter_gain(self, counter_name): + """Return the gain of a counter as recorded in the comments of + a scan in a SPEC file converted to nA/V. 
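A compact sketch of the .par/.json convention that `get_pars` above decodes: the JSON file maps .par column indices to column names, the .par file holds one space-delimited row per scan, and the row for this scan is found via its `SCAN_N` column. The file contents below are hypothetical:

import csv
import json

par_file_cols = json.loads('{"0": "SCAN_N", "1": "dwell", "2": "temperature"}')
row = next(csv.reader(['1 0.1 295.0'], delimiter=' '))
if int(row[0]) == 1:                                 # match on the SCAN_N column
    pars = {name: row[int(idx)] for idx, name in par_file_cols.items()}
print(pars)   # {'SCAN_N': '1', 'dwell': '0.1', 'temperature': '295.0'}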
+ + :param counter_name: the name of the counter + :type counter_name: str + :rtype: str + """ + counter_gain = None for comment in self.spec_scan.comments: - match = re.search(f'{counter_name} gain: (?P\d+) (?P[m|u|n])A/V', comment) + match = re.search( + f'{counter_name} gain: ' # start of counter gain comments + + '(?P\d+) ' # gain numerical value + + '(?P[m|u|n])A/V', # gain units + comment) if match: unit_prefix = match['unit_prefix'] - gain_scalar = 1 if unit_prefix == 'n' else 1e3 if unit_prefix == 'u' else 1e6 + gain_scalar = 1 if unit_prefix == 'n' \ + else 1e3 if unit_prefix == 'u' \ + else 1e6 counter_gain = f'{float(match["gain_value"])*gain_scalar} nA/V' - return(counter_gain) - raise(RuntimeError(f'{self.scan_title}: could not get gain for counter {counter_name}')) + + if counter_gain is None: + raise RuntimeError(f'{self.scan_title}: could not get gain for ' + + f'counter {counter_name}') + return counter_gain class LinearScanParser(ScanParser): + """Partial implementation of a class representing a typical line + or mesh scan in SPEC. + """ def __init__(self, spec_file_name, scan_number): super().__init__(spec_file_name, scan_number) - + self._spec_scan_motor_mnes = None self._spec_scan_motor_vals = None - self._spec_scan_shape = None + self._spec_scan_shape = None self._spec_scan_dwell = None @property @@ -266,24 +410,90 @@ def spec_scan_shape(self): def spec_scan_dwell(self): if self._spec_scan_dwell is None: self._spec_scan_dwell = self.get_spec_scan_dwell() - return(self._spec_scan_dwell) + return self._spec_scan_dwell + + def get_spec_scan_motor_mnes(self): + """Return the mnemonics of the SPEC motor(s) provided to the + macro for this scan. If there is more than one motor scanned + (in a "flymesh" scan, for example), the order of motors in the + returned tuple will go from the fastest moving motor first to + the slowest moving motor last. + + :rtype: tuple + """ + raise NotImplementedError - def get_spec_scan_motor_names(self): - raise(NotImplementedError) def get_spec_scan_motor_vals(self): - raise(NotImplementedError) + """Return the values visited by each of the scanned motors. If + there is more than one motor scanned (in a "flymesh" scan, for + example), the order of motor values in the returned tuple will + go from the fastest moving motor's values first to the slowest + moving motor's values last. + + :rtype: tuple + """ + raise NotImplementedError + def get_spec_scan_shape(self): - raise(NotImplementedError) + """Return the number of points visited by each of the scanned + motors. If there is more than one motor scanned (in a + "flymesh" scan, for example), the order of number of motor + values in the returned tuple will go from the number of points + visited by the fastest moving motor first to the the number of + points visited by the slowest moving motor last. + + :rtype: tuple + """ + raise NotImplementedError + + def get_spec_scan_dwell(self): + """Return the dwell time for each point in the scan as it + appears in the command string. + + :rtype: float + """ + raise NotImplementedError + def get_spec_scan_npts(self): - return(np.prod(self.spec_scan_shape)) + """Return the number of points collected in this SPEC scan. + + :rtype: int + """ + return np.prod(self.spec_scan_shape) + def get_scan_step(self, scan_step_index:int): + """Return the index of each motor coordinate corresponding to + the index of a single point in the scan. 
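The gain comments that `get_counter_gain` above parses have the form `<counter> gain: <value> <prefix>A/V`; a standalone sketch of the same regular expression and unit scaling (the counter name and comment text are hypothetical):

import re

comment = 'ic1 gain: 5 uA/V'
match = re.search(
    'ic1 gain: '
    r'(?P<gain_value>\d+) '
    '(?P<unit_prefix>[m|u|n])A/V',
    comment)
scale = {'n': 1, 'u': 1e3, 'm': 1e6}[match['unit_prefix']]
print(f'{float(match["gain_value"])*scale} nA/V')   # 5000.0 nA/V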
If there is more than + one motor scanned (in a "flymesh" scan, for example), the + order of indices in the returned tuple will go from the index + of the value of the fastest moving motor first to the index of + the value of the slowest moving motor last. + + :param scan_step_index: the index of a single point in the + scan. + :type scan_step_index: int + :rtype: tuple + """ scan_steps = np.ndindex(self.spec_scan_shape[::-1]) i = 0 while i <= scan_step_index: scan_step = next(scan_steps) i += 1 - return(scan_step) + return scan_step + def get_scan_step_index(self, scan_step:tuple): + """Return the index of a single scan point corresponding to a + tuple of indices for each scanned motor coordinate. + + :param scan_step: a tuple of the indices of each scanned motor + coordinate. If there is more than one motor scanned (in a + "flymesh" scan, for example), the order of indices should + go from the index of the value of the fastest moving motor + first to the index of the value of the slowest moving + motor last. + :type scan_step: tuple + :trype: int + """ scan_steps = np.ndindex(self.spec_scan_shape[::-1]) scan_step_found = False scan_step_index = -1 @@ -293,166 +503,219 @@ def get_scan_step_index(self, scan_step:tuple): if next_scan_step == scan_step: scan_step_found = True break - return(scan_step_index) + return scan_step_index class FMBLinearScanParser(LinearScanParser, FMBScanParser): - def __init__(self, spec_file_name, scan_number): - super().__init__(spec_file_name, scan_number) - + """Partial implementation of a class representing a typical line + or mesh scan in SPEC collected at FMB. + """ + def get_spec_scan_motor_mnes(self): if self.spec_macro == 'flymesh': - return((self.spec_args[0], self.spec_args[5])) - elif self.spec_macro == 'flyscan': - return((self.spec_args[0],)) - elif self.spec_macro in ('tseries', 'loopscan'): - return(('Time',)) - else: - raise(RuntimeError(f'{self.scan_title}: cannot determine scan motors for scans of type {self.spec_macro}')) + return (self.spec_args[0], self.spec_args[5]) + if self.spec_macro == 'flyscan': + return (self.spec_args[0],) + if self.spec_macro in ('tseries', 'loopscan'): + return ('Time',) + raise RuntimeError(f'{self.scan_title}: cannot determine scan motors ' + + f'for scans of type {self.spec_macro}') + def get_spec_scan_motor_vals(self): if self.spec_macro == 'flymesh': - fast_mot_vals = np.linspace(float(self.spec_args[1]), float(self.spec_args[2]), int(self.spec_args[3])+1) - slow_mot_vals = np.linspace(float(self.spec_args[6]), float(self.spec_args[7]), int(self.spec_args[8])+1) - return((fast_mot_vals, slow_mot_vals)) - elif self.spec_macro == 'flyscan': - mot_vals = np.linspace(float(self.spec_args[1]), float(self.spec_args[2]), int(self.spec_args[3])+1) - return((mot_vals,)) - elif self.spec_macro in ('tseries', 'loopscan'): - return(self.spec_scan.data[:,0]) - else: - raise(RuntimeError(f'{self.scan_title}: cannot determine scan motors for scans of type {self.spec_macro}')) + fast_mot_vals = np.linspace(float(self.spec_args[1]), + float(self.spec_args[2]), + int(self.spec_args[3])+1) + slow_mot_vals = np.linspace(float(self.spec_args[6]), + float(self.spec_args[7]), + int(self.spec_args[8])+1) + return (fast_mot_vals, slow_mot_vals) + if self.spec_macro == 'flyscan': + mot_vals = np.linspace(float(self.spec_args[1]), + float(self.spec_args[2]), + int(self.spec_args[3])+1) + return (mot_vals,) + if self.spec_macro in ('tseries', 'loopscan'): + return self.spec_scan.data[:,0] + raise RuntimeError(f'{self.scan_title}: cannot 
determine scan motors ' + + f'for scans of type {self.spec_macro}') + def get_spec_scan_shape(self): if self.spec_macro == 'flymesh': fast_mot_npts = int(self.spec_args[3])+1 slow_mot_npts = int(self.spec_args[8])+1 - return((fast_mot_npts, slow_mot_npts)) - elif self.spec_macro == 'flyscan': + return (fast_mot_npts, slow_mot_npts) + if self.spec_macro == 'flyscan': mot_npts = int(self.spec_args[3])+1 - return((mot_npts,)) - elif self.spec_macro in ('tseries', 'loopscan'): - return(len(np.array(self.spec_scan.data[:,0]))) - else: - raise(RuntimeError(f'{self.scan_title}: cannot determine scan shape for scans of type {self.spec_macro}')) + return (mot_npts,) + if self.spec_macro in ('tseries', 'loopscan'): + return len(np.array(self.spec_scan.data[:,0])) + raise RuntimeError(f'{self.scan_title}: cannot determine scan shape ' + + f'for scans of type {self.spec_macro}') + def get_spec_scan_dwell(self): if self.spec_macro in ('flymesh', 'flyscan'): - return(float(self.spec_args[4])) - elif self.spec_macro in ('tseries', 'loopscan'): - return(float(self.spec_args[1])) - else: - raise(RuntimeError(f'{self.scan_title}: cannot determine dwell for scans of type {self.spec_macro}')) + return float(self.spec_args[4]) + if self.spec_macro in ('tseries', 'loopscan'): + return float(self.spec_args[1]) + raise RuntimeError(f'{self.scan_title}: cannot determine dwell for ' + + f'scans of type {self.spec_macro}') + def get_detector_data_path(self): - return(os.path.join(self.scan_path, self.scan_title)) + return os.path.join(self.scan_path, self.scan_title) class FMBSAXSWAXSScanParser(FMBLinearScanParser): - def __init__(self, spec_file_name, scan_number): - super().__init__(spec_file_name, scan_number) + """Concrete implementation of a class representing a scan taken + with the typical SAXS/WAXS setup at FMB. 
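+
+    A minimal usage sketch (the SPEC file name, scan number, and
+    detector prefix below are hypothetical placeholders, not values
+    from a real experiment):
+
+        parser = FMBSAXSWAXSScanParser('saxswaxs_sample.spec', 1)
+        image = parser.get_detector_data('PIL5', scan_step_index=0)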
+ """ def get_scan_title(self): - return(f'{self.scan_name}_{self.scan_number:03d}') + return f'{self.scan_name}_{self.scan_number:03d}' + def get_detector_data_file(self, detector_prefix, scan_step_index:int): scan_step = self.get_scan_step(scan_step_index) - file_indices = [f'{scan_step[i]:03d}' for i in range(len(self.spec_scan_shape)) if self.spec_scan_shape[i] != 1] - file_name = f'{self.scan_name}_{detector_prefix}_{self.scan_number:03d}_{"_".join(file_indices)}.tiff' + file_indices = [f'{scan_step[i]:03d}' \ + for i in range(len(self.spec_scan_shape)) \ + if self.spec_scan_shape[i] != 1] + file_name = (f'{self.scan_name}_' + + f'{detector_prefix}_' + + f'{self.scan_number:03d}_' + + '_'.join(file_indices) + + '.tiff') file_name_full = os.path.join(self.detector_data_path, file_name) if os.path.isfile(file_name_full): - return(file_name_full) - else: - raise(RuntimeError(f'{self.scan_title}: could not find detector image file for detector {detector_prefix} scan step ({scan_step})')) + return file_name_full + raise RuntimeError(f'{self.scan_title}: could not find detector image ' + + f'file for detector {detector_prefix} scan step ' + + f'({scan_step})') + def get_detector_data(self, detector_prefix, scan_step_index:int): from pyspec.file.tiff import TiffFile - image_file = self.get_detector_data_file(detector_prefix, scan_step_index) + image_file = self.get_detector_data_file(detector_prefix, + scan_step_index) with TiffFile(image_file) as tiff_file: image_data = tiff_file.asarray() - return(image_data) + return image_data class FMBXRFScanParser(FMBLinearScanParser): - def __init__(self, spec_file_name, scan_number): - super().__init__(spec_file_name, scan_number) + """Concrete implementation of a class representing a scan taken + with the typical XRF setup at FMB. + """ + def get_scan_title(self): - return(f'{self.scan_name}_scan{self.scan_number}') + return f'{self.scan_name}_scan{self.scan_number}' + def get_detector_data_file(self, detector_prefix, scan_step_index:int): scan_step = self.get_scan_step(scan_step_index) file_name = f'scan{self.scan_number}_{scan_step[1]:03d}.hdf5' file_name_full = os.path.join(self.detector_data_path, file_name) if os.path.isfile(file_name_full): - return(file_name_full) - else: - raise(RuntimeError(f'{self.scan_title}: could not find detector image file for detector {detector_prefix} scan step ({scan_step_index})')) + return file_name_full + raise RuntimeError(f'{self.scan_title}: could not find detector image ' + + f'file for detector {detector_prefix} scan step ' + + f'({scan_step_index})') + def get_detector_data(self, detector_prefix, scan_step_index:int): import h5py - detector_file = self.get_detector_data_file(detector_prefix, scan_step_index) + detector_file = self.get_detector_data_file(detector_prefix, + scan_step_index) scan_step = self.get_scan_step(scan_step_index) with h5py.File(detector_file) as h5_file: detector_data = h5_file['/entry/instrument/detector/data'][scan_step[0]] - return(detector_data) + return detector_data class SMBLinearScanParser(LinearScanParser, SMBScanParser): - def __init__(self, spec_file_name, scan_number): - super().__init__(spec_file_name, scan_number) + """Concrete implementation of a class representing a scan taken + with the typical powder diffraction setup at SMB. 
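+
+    A minimal usage sketch (the file name, scan number, and detector
+    prefix are hypothetical placeholders):
+
+        parser = SMBLinearScanParser('powder_sample.spec', 1)
+        dwell = parser.get_spec_scan_dwell()
+        frame = parser.get_detector_data('ge2', scan_step_index=0)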
+ """ + def get_spec_scan_motor_mnes(self): if self.spec_macro == 'flymesh': - return((self.spec_args[0], self.spec_args[5])) - elif self.spec_macro == 'flyscan': - return((self.spec_args[0],)) - elif self.spec_macro in ('tseries', 'loopscan'): - return(('Time',)) - else: - raise(RuntimeError(f'{self.scan_title}: cannot determine scan motors for scans of type {self.spec_macro}')) + return (self.spec_args[0], self.spec_args[5]) + if self.spec_macro == 'flyscan': + return (self.spec_args[0],) + if self.spec_macro in ('tseries', 'loopscan'): + return ('Time',) + raise RuntimeError(f'{self.scan_title}: cannot determine scan motors ' + + f'for scans of type {self.spec_macro}') + def get_spec_scan_motor_vals(self): if self.spec_macro == 'flymesh': - fast_mot_vals = np.linspace(float(self.spec_args[1]), float(self.spec_args[2]), int(self.spec_args[3])+1) - slow_mot_vals = np.linspace(float(self.spec_args[6]), float(self.spec_args[7]), int(self.spec_args[8])+1) - return((fast_mot_vals, slow_mot_vals)) - elif self.spec_macro == 'flyscan': - mot_vals = np.linspace(float(self.spec_args[1]), float(self.spec_args[2]), int(self.spec_args[3])+1) - return((mot_vals,)) - elif self.spec_macro in ('tseries', 'loopscan'): - return(self.spec_scan.data[:,0]) - else: - raise(RuntimeError(f'{self.scan_title}: cannot determine scan motors for scans of type {self.spec_macro}')) + fast_mot_vals = np.linspace(float(self.spec_args[1]), + float(self.spec_args[2]), + int(self.spec_args[3])+1) + slow_mot_vals = np.linspace(float(self.spec_args[6]), + float(self.spec_args[7]), + int(self.spec_args[8])+1) + return (fast_mot_vals, slow_mot_vals) + if self.spec_macro == 'flyscan': + mot_vals = np.linspace(float(self.spec_args[1]), + float(self.spec_args[2]), + int(self.spec_args[3])+1) + return (mot_vals,) + if self.spec_macro in ('tseries', 'loopscan'): + return self.spec_scan.data[:,0] + raise RuntimeError(f'{self.scan_title}: cannot determine scan motors ' + + f'for scans of type {self.spec_macro}') + def get_spec_scan_shape(self): if self.spec_macro == 'flymesh': fast_mot_npts = int(self.spec_args[3])+1 slow_mot_npts = int(self.spec_args[8])+1 - return((fast_mot_npts, slow_mot_npts)) - elif self.spec_macro == 'flyscan': + return (fast_mot_npts, slow_mot_npts) + if self.spec_macro == 'flyscan': mot_npts = int(self.spec_args[3])+1 - return((mot_npts,)) - elif self.spec_macro in ('tseries', 'loopscan'): - return(len(np.array(self.spec_scan.data[:,0]))) - else: - raise(RuntimeError(f'{self.scan_title}: cannot determine scan shape for scans of type {self.spec_macro}')) + return (mot_npts,) + if self.spec_macro in ('tseries', 'loopscan'): + return len(np.array(self.spec_scan.data[:,0])) + raise RuntimeError(f'{self.scan_title}: cannot determine scan shape ' + + f'for scans of type {self.spec_macro}') + def get_spec_scan_dwell(self): if self.spec_macro == 'flymesh': - return(float(self.spec_args[4])) - elif self.spec_macro == 'flyscan': - return(float(self.spec_args[-1])) - else: - raise(RuntimeError(f'{self.scan_title}: cannot determine dwell time for scans of type {self.spec_macro}')) + return float(self.spec_args[4]) + if self.spec_macro == 'flyscan': + return float(self.spec_args[-1]) + raise RuntimeError(f'{self.scan_title}: cannot determine dwell time ' + + f'for scans of type {self.spec_macro}') + def get_detector_data_path(self): - return(os.path.join(self.scan_path, str(self.scan_number))) + return os.path.join(self.scan_path, str(self.scan_number)) + def get_detector_data_file(self, detector_prefix, 
scan_step_index:int): scan_step = self.get_scan_step(scan_step_index) if len(scan_step) == 1: scan_step = (0, *scan_step) - file_name_pattern = f'{detector_prefix}_{self.scan_name}_*_{scan_step[0]}_data_{(scan_step[1]+1):06d}.h5' - file_name_matches = fnmatch.filter(os.listdir(self.detector_data_path), file_name_pattern) + file_name_pattern = (f'{detector_prefix}_' + + f'{self.scan_name}_*_' + + f'{scan_step[0]}_data_' + + f'{(scan_step[1]+1):06d}.h5') + file_name_matches = fnmatch.filter( + os.listdir(self.detector_data_path), + file_name_pattern) if len(file_name_matches) == 1: - return(os.path.join(self.detector_data_path, file_name_matches[0])) - else: - raise(RuntimeError(f'{self.scan_title}: could not find detector image file for detector {detector_prefix} scan step ({scan_step_index})')) + return os.path.join(self.detector_data_path, file_name_matches[0]) + raise RuntimeError(f'{self.scan_title}: could not find detector image ' + + f'file for detector {detector_prefix} scan step ' + + f'({scan_step_index})') + def get_detector_data(self, detector_prefix, scan_step_index:int): import h5py - image_file = self.get_detector_data_file(detector_prefix, scan_step_index) + image_file = self.get_detector_data_file(detector_prefix, + scan_step_index) with h5py.File(image_file) as h5_file: image_data = h5_file['/entry/data/data'][0] - return(image_data) + return image_data class RotationScanParser(ScanParser): + """Partial implementation of a class representing a rotation + scan. + """ + def __init__(self, spec_file_name, scan_number): super().__init__(spec_file_name, scan_number) @@ -467,106 +730,164 @@ def __init__(self, spec_file_name, scan_number): def scan_type(self): if self._scan_type is None: self._scan_type = self.get_scan_type() - return(self._scan_type) + return self._scan_type @property def theta_vals(self): if self._theta_vals is None: self._theta_vals = self.get_theta_vals() - return(self._theta_vals) + return self._theta_vals @property def horizontal_shift(self): if self._horizontal_shift is None: self._horizontal_shift = self.get_horizontal_shift() - return(self._horizontal_shift) + return self._horizontal_shift @property def vertical_shift(self): if self._vertical_shift is None: self._vertical_shift = self.get_vertical_shift() - return(self._vertical_shift) + return self._vertical_shift @property def starting_image_index(self): if self._starting_image_index is None: self._starting_image_index = self.get_starting_image_index() - return(self._starting_image_index) + return self._starting_image_index @property def starting_image_offset(self): if self._starting_image_offset is None: self._starting_image_offset = self.get_starting_image_offset() - return(self._starting_image_offset) - + return self._starting_image_offset + def get_scan_type(self): - return(None) + """Return a string identifier for the type of tomography data + being collected by this scan: df1 (dark field), bf1 (bright + field), or tf1 (sample tomography data). + + :rtype: typing.Literal['df1', 'bf1', 'tf1'] + """ + return None + def get_theta_vals(self): - raise(NotImplementedError) + """Return a dictionary of information about the angular values + visited by the rotating motor at each point in the scan. The + dictionary may contain a single key, "num", or three keys: + "num", "start", and "end" + + :rtype: dict[str, float]""" + raise NotImplementedError + def get_horizontal_shift(self): - raise(NotImplementedError) + """Return the value of the motor that shifts the sample in the + +x direction (hutch frame). 
Useful when tomography scans are + taken in a series of stacks when the sample is wider than the + width of the beam. + + :rtype: float + """ + raise NotImplementedError + def get_vertical_shift(self): - raise(NotImplementedError) + """Return the value of the motor that shifts the sample in the + +z direction (hutch frame). Useful when tomography scans are + taken in a series of stacks when the sample is taller than the + height of the beam. + + :rtype: float + """ + raise NotImplementedError + def get_starting_image_index(self): - raise(NotImplementedError) + """Return the index of the first frame of detector data + collected by this scan. + + :rtype: int + """ + raise NotImplementedError + def get_starting_image_offset(self): - raise(NotImplementedError) + """Return the offet of the index of the first "good" frame of + detector data collected by this scan from the index of the + first frame of detector data collected by this scan. + + :rtype: int + """ + raise NotImplementedError + def get_num_image(self, detector_prefix): - raise(NotImplementedError) + """Return the total number of "good" frames of detector data + collected by this scan + + :rtype: int + """ + raise NotImplementedError class FMBRotationScanParser(RotationScanParser, FMBScanParser): - def __init__(self, spec_file_name, scan_number): - super().__init__(spec_file_name, scan_number) + """Concrete implementation of a class representing a scan taken + with the typical tomography setup at FMB. + """ + def get_spec_scan_npts(self): if self.spec_macro == 'flyscan': if len(self.spec_args) == 2: # Flat field (dark or bright) - return(int(self.spec_args[0])+1) - elif len(self.spec_args) == 5: - return(int(self.spec_args[3])+1) - else: - raise(RuntimeError(f'{self.scan_title}: cannot obtain number of points from '+ - f'{self.spec_macro} with arguments {self.spec_args}')) - else: - raise(RuntimeError(f'{self.scan_title}: cannot determine number of points for scans '+ - f'of type {self.spec_macro}')) + return int(self.spec_args[0])+1 + if len(self.spec_args) == 5: + return int(self.spec_args[3])+1 + raise RuntimeError(f'{self.scan_title}: cannot obtain number of ' + + f'points from {self.spec_macro} with ' + + f'arguments {self.spec_args}') + raise RuntimeError(f'{self.scan_title}: cannot determine number of ' + + f'points for scans of type {self.spec_macro}') + def get_theta_vals(self): if self.spec_macro == 'flyscan': if len(self.spec_args) == 2: # Flat field (dark or bright) - return({'num': int(self.spec_args[0])}) - elif len(self.spec_args) == 5: - return({'start': float(self.spec_args[1]), 'end': float(self.spec_args[2]), - 'num': int(self.spec_args[3])+1}) - else: - raise(RuntimeError(f'{self.scan_title}: cannot obtain theta values from '+ - f'{self.spec_macro} with arguments {self.spec_args}')) - else: - raise(RuntimeError(f'{self.scan_title}: cannot determine theta values for scans '+ - f'of type {self.spec_macro}')) + return {'num': int(self.spec_args[0])} + if len(self.spec_args) == 5: + return {'start': float(self.spec_args[1]), + 'end': float(self.spec_args[2]), + 'num': int(self.spec_args[3])+1} + raise RuntimeError(f'{self.scan_title}: cannot obtain theta values' + + f' from {self.spec_macro} with arguments ' + + f'{self.spec_args}') + raise RuntimeError(f'{self.scan_title}: cannot determine theta values ' + + f'for scans of type {self.spec_macro}') + def get_horizontal_shift(self): - return(0.0) + return 0.0 + def get_vertical_shift(self): - return(float(self.get_spec_positioner_value('4C_samz'))) + return 
float(self.get_spec_positioner_value('4C_samz')) + def get_starting_image_index(self): - return(0) + return 0 + def get_starting_image_offset(self): - return(1) + return 1 + def get_num_image(self, detector_prefix): import h5py detector_file = self.get_detector_data_file(detector_prefix) with h5py.File(detector_file) as h5_file: num_image = h5_file['/entry/instrument/detector/data'].shape[0] - return(num_image-self.starting_image_offset) + return num_image-self.starting_image_offset + def get_detector_data_path(self): - return(self.scan_path) + return self.scan_path + def get_detector_data_file(self, detector_prefix): prefix = detector_prefix.upper() file_name = f'{self.scan_name}_{prefix}_{self.scan_number:03d}.h5' file_name_full = os.path.join(self.detector_data_path, file_name) if os.path.isfile(file_name_full): - return(file_name_full) - else: - raise(RuntimeError(f'{self.scan_title}: could not find detector image file for '+ - f'detector {detector_prefix}')) - #@cache - def get_all_detector_data_in_file(self, detector_prefix, scan_step_index=None): + return file_name_full + raise RuntimeError(f'{self.scan_title}: could not find detector image ' + + f'file for detector {detector_prefix}') + + def get_all_detector_data_in_file(self, + detector_prefix, + scan_step_index=None): import h5py detector_file = self.get_detector_data_file(detector_prefix) with h5py.File(detector_file) as h5_file: @@ -576,182 +897,233 @@ def get_all_detector_data_in_file(self, detector_prefix, scan_step_index=None): elif isinstance(scan_step_index, int): detector_data = h5_file['/entry/instrument/detector/data'][ self.starting_image_index+scan_step_index] - elif isinstance(scan_step_index, (list, tuple)) and len(scan_step_index) == 2: + elif (isinstance(scan_step_index, (list, tuple)) + and len(scan_step_index) == 2): detector_data = h5_file['/entry/instrument/detector/data'][ self.starting_image_index+scan_step_index[0]: self.starting_image_index+scan_step_index[1]] else: - raise(ValueError(f'Invalid parameter scan_step_index ({scan_step_index})')) - return(detector_data) + raise ValueError('Invalid parameter scan_step_index ' + + f'({scan_step_index})') + return detector_data + def get_detector_data(self, detector_prefix, scan_step_index=None): - return(self.get_all_detector_data_in_file(detector_prefix, scan_step_index)) + return self.get_all_detector_data_in_file(detector_prefix, + scan_step_index) class SMBRotationScanParser(RotationScanParser, SMBScanParser): + """Concrete implementation of a class representing a scan taken + with the typical tomography setup at SMB. 
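+
+    A minimal usage sketch (the file name, scan number, and detector
+    prefix are hypothetical placeholders):
+
+        parser = SMBRotationScanParser('tomo_sample.spec', 3)
+        thetas = parser.theta_vals
+        frame = parser.get_detector_data('retiga', scan_step_index=0)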
+    """
+
     def __init__(self, spec_file_name, scan_number):
         super().__init__(spec_file_name, scan_number)
+
         self.par_file_pattern = f'id*-*tomo*-{self.scan_name}'
+
     def get_spec_scan_npts(self):
-        if self.spec_macro == 'slew_ome' or self.spec_macro == 'rams4_slew_ome':
-            return(int(self.pars['nframes_real']))
-        else:
-            raise(RuntimeError(f'{self.scan_title}: cannot determine number of points for scans of type {self.spec_macro}'))
+        if self.spec_macro in ('slew_ome', 'rams4_slew_ome'):
+            return int(self.pars['nframes_real'])
+        raise RuntimeError(f'{self.scan_title}: cannot determine number of '
+                           + f'points for scans of type {self.spec_macro}')
+
     def get_scan_type(self):
-        try:
-            return(self.pars['tomo_type'])
-        except:
-            try:
-                return(self.pars['tomotype'])
-            except:
-                raise(RuntimeError(f'{self.scan_title}: cannot determine the scan_type'))
+        scan_type = self.pars.get('tomo_type',
+                                  self.pars.get('tomotype', None))
+        if scan_type is None:
+            raise RuntimeError(f'{self.scan_title}: cannot determine '
+                               + 'the scan_type')
+        return scan_type
+
     def get_theta_vals(self):
-        return({'start': float(self.pars['ome_start_real']),
-                'end': float(self.pars['ome_end_real']), 'num': int(self.pars['nframes_real'])})
+        return {'start': float(self.pars['ome_start_real']),
+                'end': float(self.pars['ome_end_real']),
+                'num': int(self.pars['nframes_real'])}
+
     def get_horizontal_shift(self):
-        try:
-            return(float(self.pars['rams4x']))
-        except:
-            try:
-                return(float(self.pars['ramsx']))
-            except:
-                raise(RuntimeError(f'{self.scan_title}: cannot determine the horizontal shift'))
+        horizontal_shift = self.pars.get('rams4x',
+                                         self.pars.get('ramsx', None))
+        if horizontal_shift is None:
+            raise RuntimeError(f'{self.scan_title}: cannot determine the '
+                               + 'horizontal shift')
+        return float(horizontal_shift)
+
     def get_vertical_shift(self):
-        try:
-            return(float(self.pars['rams4z']))
-        except:
-            try:
-                return(float(self.pars['ramsz']))
-            except:
-                raise(RuntimeError(f'{self.scan_title}: cannot determine the vertical shift'))
+        vertical_shift = self.pars.get('rams4z',
+                                       self.pars.get('ramsz', None))
+        if vertical_shift is None:
+            raise RuntimeError(f'{self.scan_title}: cannot determine the '
+                               + 'vertical shift')
+        return float(vertical_shift)
+
     def get_starting_image_index(self):
         try:
-            return(int(self.pars['junkstart']))
+            return int(self.pars['junkstart'])
         except:
-            raise(RuntimeError(f'{self.scan_title}: cannot determine first detector image index'))
+            raise RuntimeError(f'{self.scan_title}: cannot determine first '
+                               + 'detector image index')
+
     def get_starting_image_offset(self):
         try:
-            return(int(self.pars['goodstart'])-self.get_starting_image_index())
+            return (int(self.pars['goodstart'])
+                    - self.get_starting_image_index())
         except:
-            raise(RuntimeError(f'{self.scan_title}: cannot determine index offset of first good '+
-                    'detector image'))
+            raise RuntimeError(f'{self.scan_title}: cannot determine index '
+                               + 'offset of first good detector image')
+
     def get_num_image(self, detector_prefix=None):
         try:
-            return(int(self.pars['nframes_real']))
+            return int(self.pars['nframes_real'])
-#            indexRegex = re.compile(r'\d+')
+#            index_regex = re.compile(r'\d+')
 #            # At this point only tiffs
 #            path = self.get_detector_data_path()
-#            files = sorted([f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) and
-#                    f.endswith('.tif') and indexRegex.search(f)])
-#            return(len(files)-self.starting_image_offset)
+#            files = sorted([f for f in os.listdir(path) \
+#                            if os.path.isfile(os.path.join(path, f)) \
+#                            and f.endswith('.tif') \
+#                            
and index_regex.search(f)]) +# return len(files)-self.starting_image_offset except: - raise(RuntimeError(f'{self.scan_title}: cannot determine the number of good '+ - 'detector images')) + raise RuntimeError(f'{self.scan_title}: cannot determine the ' + + 'number of good detector images') + def get_detector_data_path(self): - return(os.path.join(self.scan_path, str(self.scan_number), 'nf')) + return os.path.join(self.scan_path, str(self.scan_number), 'nf') + def get_detector_data_file(self, scan_step_index:int): file_name = f'nf_{self.starting_image_index+scan_step_index:06d}.tif' file_name_full = os.path.join(self.detector_data_path, file_name) if os.path.isfile(file_name_full): - return(file_name_full) - else: - raise(RuntimeError(f'{self.scan_title}: could not find detector image file for '+ - f'scan step ({scan_step_index})')) + return file_name_full + raise RuntimeError(f'{self.scan_title}: could not find detector image ' + + f'file for scan step ({scan_step_index})') + def get_detector_data(self, detector_prefix, scan_step_index=None): if scan_step_index is None: detector_data = [] for index in range(len(self.get_num_image(detector_prefix))): - detector_data.append(self.get_detector_data(detector_prefix, index)) + detector_data.append(self.get_detector_data(detector_prefix, + index)) detector_data = np.asarray(detector_data) elif isinstance(scan_step_index, int): image_file = self.get_detector_data_file(scan_step_index) from pyspec.file.tiff import TiffFile with TiffFile(image_file) as tiff_file: detector_data = tiff_file.asarray() - elif isinstance(scan_step_index, (list, tuple)) and len(scan_step_index) == 2: + elif (isinstance(scan_step_index, (list, tuple)) + and len(scan_step_index) == 2): detector_data = [] for index in range(scan_step_index[0], scan_step_index[1]): - detector_data.append(self.get_detector_data(detector_prefix, index)) + detector_data.append(self.get_detector_data(detector_prefix, + index)) detector_data = np.asarray(detector_data) else: - raise(ValueError(f'Invalid parameter scan_step_index ({scan_step_index})')) - return(detector_data) + raise ValueError('Invalid parameter scan_step_index ' + + f'({scan_step_index})') + return detector_data class MCAScanParser(ScanParser): + """Partial implementation of a class representing a scan taken + while collecting SPEC MCA data. + """ + def __init__(self, spec_file_name, scan_number): super().__init__(spec_file_name, scan_number) - + self._dwell_time = None self._detector_num_bins = None - + @property def dwell_time(self): if self._dwell_time is None: self._dwell_time = self.get_dwell_time() - return(self._dwell_time) - + return self._dwell_time + def get_dwell_time(self): - raise(NotImplementedError) - @cache + """Return the dwell time for each scan point as it appears in + the SPEC command. + + :rtype: float + """ + raise NotImplementedError + def get_detector_num_bins(self, detector_prefix): - raise(NotImplementedError) + """Return the number of bins for the detector with the given + prefix. + + :param detector_prefix: the detector prefix as used in SPEC + MCA data files + :type detector_prefix: str + :rtype: int + """ + raise NotImplementedError + class SMBMCAScanParser(MCAScanParser, SMBScanParser): - def __init__(self, spec_file_name, scan_number): - super().__init__(spec_file_name, scan_number) - + """Concrete implementation of a class representing a scan taken + with the typical EDD setup at SMB or FAST. 
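+
+    A minimal usage sketch (the file name and scan number are
+    hypothetical placeholders; the 'mca1' prefix mirrors the MCA file
+    naming used by this class):
+
+        parser = SMBMCAScanParser('edd_sample.spec', 1)
+        num_bins = parser.get_detector_num_bins('mca1')
+        spectrum = parser.get_detector_data('mca1', scan_step_index=0)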
+    """
+
     def get_spec_scan_npts(self):
         if self.spec_macro == 'tseries':
-            return(1)
-        elif self.spec_macro == 'ascan':
-            return(int(self.spec_args[3]))
-        elif self.spec_scan == 'wbslew_scan':
-            return(1)
-        else:
-            raise(RuntimeError(f'{self.scan_title}: cannot determine number of points for scans of type {self.spec_macro}'))
+            return 1
+        if self.spec_macro == 'ascan':
+            return int(self.spec_args[3])
+        if self.spec_macro == 'wbslew_scan':
+            return 1
+        raise RuntimeError(f'{self.scan_title}: cannot determine number of '
+                           + f'points for scans of type {self.spec_macro}')

     def get_dwell_time(self):
         if self.spec_macro == 'tseries':
-            return(float(self.spec_args[1]))
-        elif self.spec_macro == 'ascan':
-            return(float(self.spec_args[4]))
-        elif self.spec_macro == 'wbslew_scan':
-            return(float(self.spec_args[3]))
-        else:
-            raise(RuntimeError(f'{self.scan_title}: cannot determine dwell time for scans of type {self.spec_macro}'))
+            return float(self.spec_args[1])
+        if self.spec_macro == 'ascan':
+            return float(self.spec_args[4])
+        if self.spec_macro == 'wbslew_scan':
+            return float(self.spec_args[3])
+        raise RuntimeError(f'{self.scan_title}: cannot determine dwell time '
+                           + f'for scans of type {self.spec_macro}')

     def get_detector_num_bins(self, detector_prefix):
-        with open(self.get_detector_file(detector_prefix)) as detector_file:
+        with open(self.get_detector_data_file(detector_prefix)) \
+                as detector_file:
             lines = detector_file.readlines()
         for line in lines:
             if line.startswith('#@CHANN'):
                 try:
-                    line_prefix, number_saved, first_saved, last_saved, reduction_coef = line.split()
-                    return(int(number_saved))
+                    line_prefix, \
+                    number_saved, \
+                    first_saved, \
+                    last_saved, \
+                    reduction_coef = line.split()
+                    return int(number_saved)
                 except:
                     continue
-        raise(RuntimeError(f'{self.scan_title}: could not find num_bins for detector {detector_prefix}'))
-
+        raise RuntimeError(f'{self.scan_title}: could not find num_bins for '
+                           + f'detector {detector_prefix}')
+
     def get_detector_data_path(self):
-        return(self.scan_path)
+        return self.scan_path

-    def get_detector_file(self, detector_prefix, scan_step_index:int=0):
+    def get_detector_data_file(self, detector_prefix, scan_step_index:int=0):
         file_name = f'spec.log.scan{self.scan_number}.mca1.mca'
         file_name_full = os.path.join(self.detector_data_path, file_name)
         if os.path.isfile(file_name_full):
-            return(file_name_full)
-        else:
-            raise(RuntimeError(f'{self.scan_title}: could not find detector image file'))
+            return file_name_full
+        raise RuntimeError(f'{self.scan_title}: could not find detector image '
+                           + 'file')

-    @cache
     def get_all_detector_data(self, detector_prefix):
-        # This should be easy with pyspec, but there are bugs in pyspec for MCA data.....
-        # or is the 'bug' from a nonstandard implementation of some macro on our end?
-        # According to spec manual and pyspec code, mca data should always begin w/ '@A'
+        # This should be easy with pyspec, but there are bugs in
+        # pyspec for MCA data..... or is the 'bug' from a nonstandard
+        # implementation of some macro on our end? 
According to spec + # manual and pyspec code, mca data should always begin w/ '@A' # In example scans, it begins with '@mca1' instead data = [] - - with open(self.get_detector_file(detector_prefix)) as detector_file: + + with open(self.get_detector_data_file(detector_prefix)) \ + as detector_file: lines = [line.strip("\\\n") for line in detector_file.readlines()] num_bins = self.get_detector_num_bins(detector_prefix) @@ -766,20 +1138,21 @@ def get_all_detector_data(self, detector_prefix): spectrum = np.zeros(num_bins) if counter == 1: b = np.array(a[1:]).astype('uint16') - spectrum[(counter-1)*25:((counter-1)*25+25)] = b + spectrum[(counter-1) * 25:((counter-1) * 25 + 25)] = b counter = counter + 1 - elif counter > 1 and counter <= (np.floor(num_bins/25.)): + elif counter > 1 and counter <= (np.floor(num_bins / 25.)): b = np.array(a).astype('uint16') - spectrum[(counter-1)*25:((counter-1)*25+25)] = b - counter = counter+1 + spectrum[(counter-1) * 25:((counter-1) * 25 + 25)] = b + counter = counter + 1 elif counter == (np.ceil(num_bins/25.)): b = np.array(a).astype('uint16') - spectrum[(counter-1)*25:((counter-1)*25+(np.mod(num_bins,25)))] = b + spectrum[(counter-1) * 25: + ((counter-1) * 25 + (np.mod(num_bins, 25)))] = b data.append(spectrum) counter = 0 - return(data) + return data def get_detector_data(self, detector_prefix, scan_step_index:int): detector_data = self.get_all_detector_data(detector_prefix) - return(detector_data[scan_step_index]) + return detector_data[scan_step_index] diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index 449ae9d..908559f 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -1,67 +1,70 @@ #!/usr/bin/env python -''' +""" File : writer.py Author : Valentin Kuznetsov Description: Module for Writers used in multiple experiment-specific workflows. -''' +""" # system modules -import argparse -import json -import logging import os -import sys # local modules from CHAP import Writer + class ExtractArchiveWriter(Writer): + """Writer for tar files from binary data""" def _write(self, data, filename): - '''Take a .tar archive represented as bytes in `data` and write the - extracted archive to files. + """Take a .tar archive represented as bytes in `data` and + write the extracted archive to files. :param data: the archive data :type data: bytes - :param filename: the name of a directory to which the archive files will - be written + :param filename: the name of a directory to which the archive + files will be written :type filename: str :return: the original `data` :rtype: bytes - ''' + """ from io import BytesIO import tarfile - tar = tarfile.open(fileobj=BytesIO(data)) - tar.extractall(path=filename) + with tarfile.open(fileobj=BytesIO(data)) as tar: + tar.extractall(path=filename) + + return data - return(data) class NexusWriter(Writer): + """Writer for NeXus files from `NXobject`-s""" def _write(self, data, filename, force_overwrite=False): - '''Write `data` to a NeXus file + """Write `data` to a NeXus file :param data: the data to write to `filename`. :type data: nexusformat.nexus.NXobject :param filename: name of the file to write to. :param force_overwrite: flag to allow data in `filename` to be - overwritten, if it already exists. + overwritten, if it already exists. 
:return: the original input data - ''' + """ from nexusformat.nexus import NXobject - + if not isinstance(data, NXobject): - raise(TypeError(f'Cannot write object of type {type(data).__name__} to a NeXus file.')) + raise TypeError('Cannot write object of type ' + + f'{type(data).__name__} to a NeXus file.') mode = 'w' if force_overwrite else 'w-' data.save(filename, mode=mode) - return(data) + return data + class YAMLWriter(Writer): + """Writer for YAML files from `dict`-s""" def _write(self, data, filename, force_overwrite=False): - '''If `data` is a `dict`, write it to `filename`. + """If `data` is a `dict`, write it to `filename`. :param data: the dictionary to write to `filename`. :type data: dict @@ -75,21 +78,24 @@ def _write(self, data, filename, force_overwrite=False): `force_overwrite` is `False`. :return: the original input data :rtype: dict - ''' + """ import yaml if not isinstance(data, (dict, list)): - raise(TypeError(f'{self.__name__}.write: input data must be a dict or list.')) + raise(TypeError(f'{self.__name__}.write: input data must be ' + + 'a dict or list.')) if not force_overwrite: if os.path.isfile(filename): - raise(RuntimeError(f'{self.__name__}: {filename} already exists.')) + raise(RuntimeError(f'{self.__name__}: {filename} already ' + + 'exists.')) with open(filename, 'w') as outf: yaml.dump(data, outf, sort_keys=False) - return(data) + return data + if __name__ == '__main__': from CHAP.writer import main diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index 0f3ba17..fb4e355 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -1,34 +1,35 @@ #!/usr/bin/env python #-*- coding: utf-8 -*- #pylint: disable= -''' +""" File : processor.py Author : Valentin Kuznetsov Description: Module for Processors used only by EDD experiments -''' +""" # system modules import json # local modules from CHAP.processor import Processor -from CHAP.common import StrainAnalysisProcessor class MCACeriaCalibrationProcessor(Processor): - '''A Processor using a CeO2 scan to obtain tuned values for the bragg - diffraction angle and linear correction parameters for MCA channel energies - for an EDD experimental setup. - ''' + """A Processor using a CeO2 scan to obtain tuned values for the + bragg diffraction angle and linear correction parameters for MCA + channel energies for an EDD experimental setup. + """ def _process(self, data): - '''Return tuned values for 2&theta and linear correction parameters for - the MCA channel energies. + """Return tuned values for 2&theta and linear correction + parameters for the MCA channel energies. - :param data: input configuration for the raw data & tuning procedure + :param data: input configuration for the raw data & tuning + procedure :type data: list[dict[str,object]] - :return: original configuration dictionary with tuned values added + :return: original configuration dictionary with tuned values + added :rtype: dict[str,float] - ''' + """ calibration_config = self.get_config(data) @@ -38,21 +39,22 @@ def _process(self, data): calibration_config.slope_calibrated = slope calibration_config.intercept_calibrated = intercept - return(calibration_config.dict()) + return calibration_config.dict() def get_config(self, data): - '''Get an instance of the configuration object needed by this + """Get an instance of the configuration object needed by this `Processor` from a returned value of `Reader.read` - :param data: Result of `Reader.read` where at least one item has the - value `'MCACeriaCalibrationConfig'` for the `'schema'` key. 
+ :param data: Result of `Reader.read` where at least one item + has the value `'MCACeriaCalibrationConfig'` for the + `'schema'` key. :type data: list[dict[str,object]] - :raises Exception: If a valid config object cannot be constructed from - `data`. - :return: a valid instance of a configuration object with field values - taken from `data`. + :raises Exception: If a valid config object cannot be + constructed from `data`. + :return: a valid instance of a configuration object with field + values taken from `data`. :rtype: MCACeriaCalibrationConfig - ''' + """ from CHAP.edd.models import MCACeriaCalibrationConfig @@ -65,23 +67,25 @@ def get_config(self, data): break if not calibration_config: - raise(ValueError('No MCA ceria calibration configuration found in input data')) + raise ValueError('No MCA ceria calibration configuration found in ' + + 'input data') - return(MCACeriaCalibrationConfig(**calibration_config)) + return MCACeriaCalibrationConfig(**calibration_config) def calibrate(self, calibration_config): - '''Iteratively calibrate 2&theta by fitting selected peaks of an MCA - spectrum until the computed strain is sufficiently small. Use the fitted - peak locations to determine linear correction parameters for the MCA's - channel energies. + """Iteratively calibrate 2&theta by fitting selected peaks of + an MCA spectrum until the computed strain is sufficiently + small. Use the fitted peak locations to determine linear + correction parameters for the MCA's channel energies. - :param calibration_config: object configuring the CeO2 calibration - procedure + :param calibration_config: object configuring the CeO2 + calibration procedure :type calibration_config: MCACeriaCalibrationConfig - :return: calibrated values of 2&theta and linear correction parameters - for MCA channel energies : tth, slope, intercept + :return: calibrated values of 2&theta and linear correction + parameters for MCA channel energies : tth, slope, + intercept :rtype: float, float, float - ''' + """ from CHAP.common.utils.fit import Fit, FitMultipeak import numpy as np @@ -107,7 +111,8 @@ def calibrate(self, calibration_config): mca_intensity_weights = flux_correct(fit_mca_energies) fit_mca_intensities = fit_mca_intensities / mca_intensity_weights - # Get the HKLs and lattice spacings that will be used for fitting + # Get the HKLs and lattice spacings that will be used for + # fitting tth = calibration_config.tth_initial_guess fit_hkls, fit_ds = calibration_config.fit_ds() c_1 = fit_hkls[:,0]**2 + fit_hkls[:,1]**2 + fit_hkls[:,2]**2 @@ -116,8 +121,8 @@ def calibrate(self, calibration_config): ### Perform the uniform fit first ### - # Get expected peak energy locations for this iteration's starting - # value of tth + # Get expected peak energy locations for this iteration's + # starting value of tth fit_lambda = 2.0 * fit_ds * np.sin(0.5*np.radians(tth)) fit_E0 = hc / fit_lambda @@ -130,12 +135,13 @@ def calibrate(self, calibration_config): fit_type='uniform', plot=False) - # Extract values of interest from the best values for the uniform fit - # parameters - uniform_fit_centers = [best_values[f'peak{i+1}_center'] for i in range(len(calibration_config.fit_hkls))] + # Extract values of interest from the best values for the + # uniform fit parameters + uniform_fit_centers = [best_values[f'peak{i+1}_center'] \ + for i in range(len(calibration_config.fit_hkls))] # uniform_a = best_values['scale_factor'] # uniform_strain = np.log( - # (uniform_a + # (uniform_a # / calibration_config.lattice_parameter_angstrom)) # 
uniform_tth = tth * (1.0 + uniform_strain) # uniform_rel_rms_error = (np.linalg.norm(residual) @@ -143,8 +149,9 @@ def calibrate(self, calibration_config): ### Next, perform the unconstrained fit ### - # Use the peak locations found in the uniform fit as the initial - # guesses for peak locations in the unconstrained fit + # Use the peak locations found in the uniform fit as the + # initial guesses for peak locations in the unconstrained + # fit best_fit, residual, best_values, best_errors, redchi, success = \ FitMultipeak.fit_multipeak( fit_mca_intensities, @@ -156,7 +163,8 @@ def calibrate(self, calibration_config): # Extract values of interest from the best values for the # unconstrained fit parameters unconstrained_fit_centers = np.array( - [best_values[f'peak{i+1}_center'] for i in range(len(calibration_config.fit_hkls))]) + [best_values[f'peak{i+1}_center'] \ + for i in range(len(calibration_config.fit_hkls))]) unconstrained_a = (0.5 * hc * np.sqrt(c_1) / (unconstrained_fit_centers * abs(np.sin(0.5*np.radians(tth))))) @@ -173,7 +181,8 @@ def calibrate(self, calibration_config): prev_tth = tth tth = unconstrained_tth - # Stop tuning tth at this iteration if differences are small enough + # Stop tuning tth at this iteration if differences are + # small enough if abs(tth - prev_tth) < calibration_config.tune_tth_tol: break @@ -187,44 +196,46 @@ def calibrate(self, calibration_config): slope = fit.best_values['slope'] intercept = fit.best_values['intercept'] - return(float(tth), float(slope), float(intercept)) + return float(tth), float(slope), float(intercept) class MCADataProcessor(Processor): - '''A Processor to return data from an MCA, restuctured to incorporate the - shape & metadata associated with a map configuration to which the MCA data - belongs, and linearly transformed according to the results of a ceria - calibration. - ''' + """A Processor to return data from an MCA, restuctured to + incorporate the shape & metadata associated with a map + configuration to which the MCA data belongs, and linearly + transformed according to the results of a ceria calibration. + """ def _process(self, data): - '''Process configurations for a map and MCA detector(s), and return the - calibrated MCA data collected over the map. + """Process configurations for a map and MCA detector(s), and + return the calibrated MCA data collected over the map. - :param data: input map configuration and results of ceria calibration + :param data: input map configuration and results of ceria + calibration :type data: list[dict[str,object]] :return: calibrated and flux-corrected MCA data :rtype: nexusformat.nexus.NXentry - ''' + """ map_config, calibration_config = self.get_configs(data) nxroot = self.get_nxroot(map_config, calibration_config) - return(nxroot) + return nxroot def get_configs(self, data): - '''Get instances of the configuration objects needed by this + """Get instances of the configuration objects needed by this `Processor` from a returned value of `Reader.read` - :param data: Result of `Reader.read` where at least one item has the - value `'MapConfig'` for the `'schema'` key, and at least one item has - the value `'MCACeriaCalibrationConfig'` for the `'schema'` key. + :param data: Result of `Reader.read` where at least one item + has the value `'MapConfig'` for the `'schema'` key, and at + least one item has the value `'MCACeriaCalibrationConfig'` + for the `'schema'` key. :type data: list[dict[str,object]] - :raises Exception: If valid config objects cannot be constructed from - `data`. 
- :return: valid instances of the configuration objects with field values - taken from `data`. + :raises Exception: If valid config objects cannot be + constructed from `data`. + :return: valid instances of the configuration objects with + field values taken from `data`. :rtype: tuple[MapConfig, MCACeriaCalibrationConfig] - ''' + """ from CHAP.common.models.map import MapConfig from CHAP.edd.models import MCACeriaCalibrationConfig @@ -241,17 +252,20 @@ def get_configs(self, data): calibration_config = item.get('data') if not map_config: - raise(ValueError('No map configuration found in input data')) + raise ValueError('No map configuration found in input data') if not calibration_config: - raise(ValueError('No MCA ceria calibration configuration found in input data')) + raise ValueError('No MCA ceria calibration configuration found in ' + + 'input data') - return(MapConfig(**map_config), MCACeriaCalibrationConfig(**calibration_config)) + return (MapConfig(**map_config), + MCACeriaCalibrationConfig(**calibration_config)) def get_nxroot(self, map_config, calibration_config): - '''Get a map of the MCA data collected by the scans in `map_config`. The - MCA data will be calibrated and flux-corrected according to the - parameters included in `calibration_config`. The data will be returned - along with relevant metadata in the form of a NeXus structure. + """Get a map of the MCA data collected by the scans in + `map_config`. The MCA data will be calibrated and + flux-corrected according to the parameters included in + `calibration_config`. The data will be returned along with + relevant metadata in the form of a NeXus structure. :param map_config: the map configuration :type map_config: MapConfig @@ -259,13 +273,12 @@ def get_nxroot(self, map_config, calibration_config): :type calibration_config: MCACeriaCalibrationConfig :return: a map of the calibrated and flux-corrected MCA data :rtype: nexusformat.nexus.NXroot - ''' + """ from CHAP.common import MapProcessor from nexusformat.nexus import (NXdata, NXdetector, - NXentry, NXinstrument, NXroot) import numpy as np @@ -277,7 +290,8 @@ def get_nxroot(self, map_config, calibration_config): nxentry.instrument = NXinstrument() nxentry.instrument.detector = NXdetector() - nxentry.instrument.detector.calibration_configuration = json.dumps(calibration_config.dict()) + nxentry.instrument.detector.calibration_configuration = json.dumps( + calibration_config.dict()) nxentry.instrument.detector.data = NXdata() nxdata = nxentry.instrument.detector.data @@ -313,10 +327,11 @@ def get_nxroot(self, map_config, calibration_config): nxentry.data.attrs['axes'], f'{calibration_config.detector_name}_channel_energy'] else: - nxentry.data.attrs['axes'] += [f'{calibration_config.detector_name}_channel_energy'] + nxentry.data.attrs['axes'] += [ + f'{calibration_config.detector_name}_channel_energy'] nxentry.data.attrs['signal'] = calibration_config.detector_name - return(nxroot) + return nxroot if __name__ == '__main__': from CHAP.processor import main diff --git a/CHAP/inference/processor.py b/CHAP/inference/processor.py index cc5a40e..7702180 100755 --- a/CHAP/inference/processor.py +++ b/CHAP/inference/processor.py @@ -1,11 +1,11 @@ -#!/usr/bin/env python -#-*- coding: utf-8 -*- -#pylint: disable= -''' -File : processor.py -Author : Valentin Kuznetsov -Description: Processor module -''' +#!/usr/bin/env python +#-*- coding: utf-8 -*- +#pylint: disable= +""" +File : processor.py +Author : Valentin Kuznetsov +Description: Processor module +""" # system modules from time 
import time @@ -14,13 +14,10 @@ from CHAP import Processor class TFaaSImageProcessor(Processor): - ''' - A Processor to get predictions from TFaaS inference server. - ''' + """A Processor to get predictions from TFaaS inference server.""" + def process(self, data, url, model, verbose=False): - ''' - process data API - ''' + """process data API""" t0 = time() self.logger.info(f'Executing "process" with url {url} model {model}') @@ -29,39 +26,42 @@ def process(self, data, url, model, verbose=False): self.logger.info(f'Finished "process" in {time()-t0:.3f} seconds\n') - return(data) + return data def _process(self, data, url, model, verbose): - '''Print and return the input data. + """Print and return the input data. - :param data: Input image data, either file name or actual image data + :param data: Input image data, either file name or actual + image data :type data: object :return: `data` :rtype: object - ''' + """ from MLaaS.tfaas_client import predictImage from pathlib import Path self.logger.info(f'input data {type(data)}') if isinstance(data, str) and Path(data).is_file(): - imgFile = data - data = predictImage(url, imgFile, model, verbose) + img_file = data + data = predictImage(url, img_file, model, verbose) else: rdict = data[0] import requests img = rdict['data'] session = requests.Session() rurl = url + '/predict/image' - payload = dict(model=model) - files = dict(image=img) - self.logger.info(f'HTTP request {rurl} with image file and {payload} payload') + payload = {'model': model} + files = {'image': img} + self.logger.info( + f'HTTP request {rurl} with image file and {payload} payload') req = session.post(rurl, files=files, data=payload ) data = req.content data = data.decode('utf-8').replace('\n', '') self.logger.info(f'HTTP response {data}') - return(data) + return data + if __name__ == '__main__': from CHAP.processor import main diff --git a/CHAP/pipeline.py b/CHAP/pipeline.py index e2d7856..074bc7f 100755 --- a/CHAP/pipeline.py +++ b/CHAP/pipeline.py @@ -11,13 +11,11 @@ import logging from time import time + class Pipeline(): - """ - Pipeline represent generic Pipeline class - """ + """Pipeline represent generic Pipeline class""" def __init__(self, items=None, kwds=None): - """ - Pipeline class constructor + """Pipeline class constructor :param items: list of objects :param kwds: list of method args for individual objects @@ -31,12 +29,10 @@ def __init__(self, items=None, kwds=None): self.logger.propagate = False def execute(self): - """ - execute API - """ + """execute API""" t0 = time() - self.logger.info(f'Executing "execute"\n') + self.logger.info('Executing "execute"\n') data = None for item, kwargs in zip(self.items, self.kwds): @@ -52,33 +48,23 @@ def execute(self): self.logger.info(f'Executed "execute" in {time()-t0:.3f} seconds') + class PipelineObject(): - """ - PipelineObject represent generic Pipeline class - """ - def __init__(self, reader, writer, processor, fitter): - """ - PipelineObject class constructor - """ + """PipelineObject represent generic Pipeline class""" + def __init__(self, reader, writer, processor): + """PipelineObject class constructor""" self.reader = reader self.writer = writer self.processor = processor def read(self, filename): - """ - read object API - """ + """read object API""" return self.reader.read(filename) def write(self, data, filename): - """ - write object API - """ + """write object API""" return self.writer.write(data, filename) def process(self, data): - """ - process object API - """ + """process object API""" return 
self.processor.process(data) - diff --git a/CHAP/processor.py b/CHAP/processor.py index 64bee50..dbdeba9 100755 --- a/CHAP/processor.py +++ b/CHAP/processor.py @@ -10,29 +10,21 @@ # system modules import argparse import inspect -import json import logging import sys from time import time -# local modules -# from pipeline import PipelineObject class Processor(): - """ - Processor represent generic processor - """ + """Processor represent generic processor""" def __init__(self): - """ - Processor constructor - """ + """Processor constructor""" self.__name__ = self.__class__.__name__ self.logger = logging.getLogger(self.__name__) self.logger.propagate = False def process(self, data, **_process_kwargs): - """ - process data API + """process data API :param _process_kwargs: keyword arguments to pass to `self._process`, defaults to `{}` @@ -55,12 +47,18 @@ def process(self, data, **_process_kwargs): self.logger.info(f'Finished "process" in {time()-t0:.3f} seconds\n') - return(data) + return data + + def _process(self, data): + """Private method to carry out the mechanics of the specific + Processor. - def _process(self, data, **kwargs): + :param data: input data + :return: processed data + """ # If needed, extract data from a returned value of Reader.read if isinstance(data, list): - if all([isinstance(d,dict) for d in data]): + if all(isinstance(d,dict) for d in data): data = data[0]['data'] # process operation is a simple print function data += "process part\n" @@ -69,7 +67,7 @@ def _process(self, data, **kwargs): class OptionParser(): - '''User based option parser''' + """User based option parser""" def __init__(self): self.parser = argparse.ArgumentParser(prog='PROG') self.parser.add_argument( @@ -84,25 +82,26 @@ def __init__(self): def main(opt_parser=OptionParser): - '''Main function''' + """Main function""" optmgr = opt_parser() opts = optmgr.parser.parse_args() - clsName = opts.processor + cls_name = opts.processor try: - processorCls = getattr(sys.modules[__name__],clsName) + processor_cls = getattr(sys.modules[__name__],cls_name) except: - print(f'Unsupported processor {clsName}') + print(f'Unsupported processor {cls_name}') sys.exit(1) - processor = processorCls() + processor = processor_cls() processor.logger.setLevel(getattr(logging, opts.log_level)) log_handler = logging.StreamHandler() - log_handler.setFormatter(logging.Formatter('{name:20}: {message}', style='{')) + log_handler.setFormatter(logging.Formatter( + '{name:20}: {message}', style='{')) processor.logger.addHandler(log_handler) data = processor.process(opts.data) - print(f"Processor {processor} operates on data {data}") + print(f'Processor {processor} operates on data {data}') if __name__ == '__main__': main() diff --git a/CHAP/reader.py b/CHAP/reader.py index 092a5e3..f1cdb6e 100755 --- a/CHAP/reader.py +++ b/CHAP/reader.py @@ -1,56 +1,52 @@ #!/usr/bin/env python -''' +""" File : reader.py Author : Valentin Kuznetsov Description: generic Reader module -''' +""" # system modules import argparse import inspect -import json import logging import sys from time import time # local modules -# from pipeline import PipelineObject + class Reader(): - ''' - Reader represent generic file writer - ''' + """Reader represent generic file writer""" def __init__(self): - ''' - Constructor of Reader class - ''' + """Constructor of Reader class""" self.__name__ = self.__class__.__name__ self.logger = logging.getLogger(self.__name__) self.logger.propagate = False def read(self, type_=None, schema=None, encoding=None, **_read_kwargs): - 
'''Read API + """Read API Wrapper to read, format, and return the data requested. - :param type_: the expected type of data read from `filename`, defualts - to `None` + :param type_: the expected type of data read from `filename`, + defualts to `None` :type type_: type, optional - :param schema: the expected schema of the data read from `filename`, - defaults to `None` + :param schema: the expected schema of the data read from + `filename`, defaults to `None` :type schema: str, otional - :param _read_kwargs: keyword arguments to pass to `self._read`, defaults - to `{}` + :param _read_kwargs: keyword arguments to pass to + `self._read`, defaults to `{}` :type _read_kwargs: dict, optional - :return: list with one item: a dictionary containing the data read from - `filename`, the name of this `Reader`, and the values of `type_` and - `schema`. + :return: list with one item: a dictionary containing the data + read from `filename`, the name of this `Reader`, and the + values of `type_` and `schema`. :rtype: list[dict[str,object]] - ''' + """ t0 = time() - self.logger.info(f'Executing "read" with type={type_}, schema={schema}, kwargs={_read_kwargs}') + self.logger.info(f'Executing "read" with type={type_}, ' + + f'schema={schema}, kwargs={_read_kwargs}') _valid_read_args = {} allowed_args = inspect.getfullargspec(self._read).args \ @@ -68,25 +64,27 @@ def read(self, type_=None, schema=None, encoding=None, **_read_kwargs): 'encoding': encoding}] self.logger.info(f'Finished "read" in {time()-t0:.3f} seconds\n') - return(data) + return data def _read(self, filename): - '''Read and return the data from requested from `filename` + """Read and return the data from requested from `filename` :param filename: Name of file to read from :return: specific number of bytes from a file - ''' + """ if not filename: - self.logger.warning('No file name is given, will skip read operation') + self.logger.warning('No file name is given, will skip ' + + 'read operation') return None with open(filename) as file: data = file.read() - return(data) + return data + class OptionParser(): - '''User based option parser''' + """User based option parser""" def __init__(self): self.parser = argparse.ArgumentParser(prog='PROG') self.parser.add_argument( @@ -100,21 +98,22 @@ def __init__(self): dest='log_level', default='INFO', help='logging level') def main(opt_parser=OptionParser): - '''Main function''' + """Main function""" optmgr = opt_parser() opts = optmgr.parser.parse_args() - clsName = opts.reader + cls_name = opts.reader try: - readerCls = getattr(sys.modules[__name__],clsName) + reader_cls = getattr(sys.modules[__name__],cls_name) except: - print(f'Unsupported reader {clsName}') + print(f'Unsupported reader {cls_name}') sys.exit(1) - reader = readerCls() + reader = reader_cls() reader.logger.setLevel(getattr(logging, opts.log_level)) log_handler = logging.StreamHandler() - log_handler.setFormatter(logging.Formatter('{name:20}: {message}', style='{')) + log_handler.setFormatter(logging.Formatter( + '{name:20}: {message}', style='{')) reader.logger.addHandler(log_handler) data = reader.read(filename=opts.filename) diff --git a/CHAP/runner.py b/CHAP/runner.py index 296f760..53154b4 100755 --- a/CHAP/runner.py +++ b/CHAP/runner.py @@ -10,42 +10,48 @@ # system modules import argparse import logging -import os -import sys import yaml # local modules from CHAP.pipeline import Pipeline + class OptionParser(): + """User based option parser""" def __init__(self): - "User based option parser" + """OptionParser class constructor""" 
self.parser = argparse.ArgumentParser(prog='PROG') - self.parser.add_argument("--config", action="store", - dest="config", default="", help="Input configuration file") - self.parser.add_argument("--interactive", action="store_true", - dest="interactive", help="Allow interactive processes") - self.parser.add_argument('--log-level', choices=logging._nameToLevel.keys(), + self.parser.add_argument( + '--config', action='store', dest='config', + default='', help='Input configuration file') + self.parser.add_argument( + '--interactive', action='store_true', dest='interactive', + help='Allow interactive processes') + self.parser.add_argument( + '--log-level', choices=logging._nameToLevel.keys(), dest='log_level', default='INFO', help='logging level') + def main(): - "Main function" + """Main function""" optmgr = OptionParser() opts = optmgr.parser.parse_args() runner(opts) + def runner(opts): - """ - Main runner function + """Main runner function - :param opts: opts is an instance of argparse.Namespace which contains all input parameters + :param opts: object containing input parameters + :type opts: OptionParser """ logger = logging.getLogger(__name__) log_level = getattr(logging, opts.log_level.upper()) logger.setLevel(log_level) log_handler = logging.StreamHandler() - log_handler.setFormatter(logging.Formatter('{name:20}: {message}', style='{')) + log_handler.setFormatter(logging.Formatter( + '{name:20}: {message}', style='{')) logger.addHandler(log_handler) config = {} @@ -57,16 +63,20 @@ def runner(opts): kwds = [] for item in pipeline_config: # load individual object with given name from its module + kwargs = {'interactive': opts.interactive} if isinstance(item, dict): name = list(item.keys())[0] - kwargs = item[name] + # Combine the "interactive" command line argument with the + # object's keywords giving precedence of "interactive" in + # the latter + kwargs = {**kwargs, **item[name]} else: name = item kwargs = {} kwargs['interactive'] = opts.interactive - modName, clsName = name.split('.') - module = __import__(f'CHAP.{modName}', fromlist=[clsName]) - obj = getattr(module, clsName)() + mod_name, cls_name = name.split('.') + module = __import__(f'CHAP.{mod_name}', fromlist=[cls_name]) + obj = getattr(module, cls_name)() obj.logger.setLevel(log_level) obj.logger.addHandler(log_handler) logger.info(f'Loaded {obj}') diff --git a/CHAP/writer.py b/CHAP/writer.py index 84533e9..eafed0a 100755 --- a/CHAP/writer.py +++ b/CHAP/writer.py @@ -8,31 +8,22 @@ # system modules import argparse import inspect -import json import logging -import os import sys from time import time -# local modules -# from pipeline import PipelineObject class Writer(): - """ - Writer represent generic file writer - """ + """Writer represent generic file writer""" def __init__(self): - """ - Constructor of Writer class - """ + """Constructor of Writer class""" self.__name__ = self.__class__.__name__ self.logger = logging.getLogger(self.__name__) self.logger.propagate = False def write(self, data, filename, **_write_kwargs): - """ - write API + """write API :param filename: Name of file to write to :param data: data to write to file @@ -40,7 +31,8 @@ def write(self, data, filename, **_write_kwargs): """ t0 = time() - self.logger.info(f'Executing "write" with filename={filename}, type(data)={type(data)}, kwargs={_write_kwargs}') + self.logger.info(f'Executing "write" with filename={filename}, ' + + f'type(data)={type(data)}, kwargs={_write_kwargs}') _valid_write_args = {} allowed_args = inspect.getfullargspec(self._write).args 
\ @@ -55,49 +47,50 @@ def write(self, data, filename, **_write_kwargs): self.logger.info(f'Finished "write" in {time()-t0:.3f} seconds\n') - return(data) + return data def _write(self, data, filename): with open(filename, 'a') as file: file.write(data) - return(data) + return data class OptionParser(): - '''User based option parser''' - def __init__(self): - self.parser = argparse.ArgumentParser(prog='PROG') - self.parser.add_argument( - '--data', action='store', - dest='data', default='', help='Input data') - self.parser.add_argument( - '--filename', action='store', - dest='filename', default='', help='Output file') - self.parser.add_argument( - '--writer', action='store', - dest='writer', default='Writer', help='Writer class name') - self.parser.add_argument( - '--log-level', choices=logging._nameToLevel.keys(), + """User based option parser""" + def __init__(self): + self.parser = argparse.ArgumentParser(prog='PROG') + self.parser.add_argument( + '--data', action='store', + dest='data', default='', help='Input data') + self.parser.add_argument( + '--filename', action='store', + dest='filename', default='', help='Output file') + self.parser.add_argument( + '--writer', action='store', + dest='writer', default='Writer', help='Writer class name') + self.parser.add_argument( + '--log-level', choices=logging._nameToLevel.keys(), dest='log_level', default='INFO', help='logging level') def main(opt_parser=OptionParser): - '''Main function''' + """Main function""" optmgr = opt_parser() opts = optmgr.parser.parse_args() - clsName = opts.writer + cls_name = opts.writer try: - writerCls = getattr(sys.modules[__name__],clsName) + writer_cls = getattr(sys.modules[__name__],cls_name) except: - print(f'Unsupported writer {clsName}') + print(f'Unsupported writer {cls_name}') sys.exit(1) - writer = writerCls() + writer = writer_cls() writer.logger.setLevel(getattr(logging, opts.log_level)) log_handler = logging.StreamHandler() - log_handler.setFormatter(logging.Formatter('{name:20}: {message}', style='{')) + log_handler.setFormatter(logging.Formatter( + '{name:20}: {message}', style='{')) writer.logger.addHandler(log_handler) data = writer.write(opts.data, opts.filename) - print(f"Writer {writer} writes to {opts.filename}, data {data}") + print(f'Writer {writer} writes to {opts.filename}, data {data}') if __name__ == '__main__': main() From 70063f76c00828719e688966110ef40d47d64bde Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Thu, 20 Apr 2023 11:05:37 -0400 Subject: [PATCH 2/6] style: edit selected files in the CHAP module to comply with PEP8 --- CHAP/common/reader.py | 9 + CHAP/common/utils/fit.py | 2156 ++++++++++++++++++--------------- CHAP/common/utils/general.py | 903 ++++++++------ CHAP/common/utils/material.py | 262 ++-- CHAP/tomo/__init__.py | 5 +- CHAP/tomo/models.py | 81 +- CHAP/tomo/processor.py | 1869 +++++++++++++++++----------- CHAP/tomo/reader.py | 6 +- CHAP/tomo/writer.py | 6 +- 9 files changed, 3082 insertions(+), 2215 deletions(-) diff --git a/CHAP/common/reader.py b/CHAP/common/reader.py index 083c39e..e0115f3 100755 --- a/CHAP/common/reader.py +++ b/CHAP/common/reader.py @@ -50,6 +50,15 @@ def read(self, readers): reader = reader_class() reader_kwargs = reader_config[reader_name] +# _valid_read_args = {} +# allowed_args = inspect.getfullargspec(self._read).args \ +# + inspect.getfullargspec(self._read).kwonlyargs +# for k, v in _read_kwargs.items(): +# if k in allowed_args: +# _valid_read_args[k] = v +# else: +# self.logger.warning(f'Ignoring invalid arg to _read: {k}') + 
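The commented-out block above mirrors the keyword-filtering idiom that Writer.write already uses: inspect the private method's signature and forward only the keyword arguments it accepts, warning about the rest. A minimal self-contained sketch of that idiom (the _Demo class and its _read method are illustrative only, not part of CHAP):

import inspect

class _Demo:
    def _read(self, filename, *, encoding='utf-8'):
        # Hypothetical private worker; only its named arguments are accepted.
        with open(filename, encoding=encoding) as file:
            return file.read()

    def read(self, **kwargs):
        # Keep only the keywords that appear in _read's signature.
        spec = inspect.getfullargspec(self._read)
        allowed = set(spec.args) | set(spec.kwonlyargs)
        valid = {key: val for key, val in kwargs.items() if key in allowed}
        # A stray keyword such as mode='r' is simply dropped here; the CHAP
        # readers and writers log a warning for each ignored argument.
        return self._read(**valid)
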
data.extend(reader.read(**reader_kwargs)) self.logger.info(f'Finished "read" in {time()-t0:.3f} seconds\n') diff --git a/CHAP/common/utils/fit.py b/CHAP/common/utils/fit.py index f646e55..160b9fa 100755 --- a/CHAP/common/utils/fit.py +++ b/CHAP/common/utils/fit.py @@ -1,83 +1,106 @@ -#!/usr/bin/env python3 - -# -*- coding: utf-8 -*- +#!/usr/bin/env python +#-*- coding: utf-8 -*- +#pylint: disable= """ -Created on Mon Dec 6 15:36:22 2021 - -@author: rv43 +File : fit.py +Author : Rolf Verberg +Description: General curve fitting module """ -import logging - -try: - from asteval import Interpreter, get_ast_names -except: - pass +# System modules from copy import deepcopy +from logging import getLogger +from os import ( + cpu_count, + mkdir, + path, +) +from re import compile as re_compile +from re import sub +from shutil import rmtree +from sys import float_info + +# Third party modules try: - from lmfit import Model, Parameters - from lmfit.model import ModelResult - from lmfit.models import ConstantModel, LinearModel, QuadraticModel, PolynomialModel,\ - ExponentialModel, StepModel, RectangleModel, ExpressionModel, GaussianModel,\ - LorentzianModel -except: - pass + from joblib import ( + Parallel, + delayed, + ) + HAVE_JOBLIB = True +except ImportError: + HAVE_JOBLIB = False +from lmfit import ( + Parameters, + Model, +) +from lmfit.model import ModelResult +from lmfit.models import ( + ConstantModel, + LinearModel, + QuadraticModel, + PolynomialModel, + ExponentialModel, + StepModel, + RectangleModel, + ExpressionModel, + GaussianModel, + LorentzianModel, +) import numpy as np -from os import cpu_count, getpid, listdir, mkdir, path -from re import compile, sub -from shutil import rmtree try: - from sympy import diff, simplify -except: + from sympy import ( + diff, + simplify, + ) +except ImportError: pass -try: - from joblib import Parallel, delayed - have_joblib = True -except: - have_joblib = False try: import xarray as xr - have_xarray = True -except: - have_xarray = False - -try: - from .general import illegal_value, is_int, is_dict_series, is_index, index_nearest, \ - almost_equal, quick_plot #, eval_expr -except: - try: - from sys import path as syspath - syspath.append(f'/nfs/chess/user/rv43/msnctools/msnctools') - from general import illegal_value, is_int, is_dict_series, is_index, index_nearest, \ - almost_equal, quick_plot #, eval_expr - except: - from general import illegal_value, is_int, is_dict_series, is_index, index_nearest, \ - almost_equal, quick_plot #, eval_expr - -from sys import float_info -float_min = float_info.min -float_max = float_info.max + HAVE_XARRAY = True +except ImportError: + HAVE_XARRAY = False + +# Local modules +from CHAP.common.utils.general import ( + is_int, + is_num, + is_dict_series, + is_index, + index_nearest, + input_num, + quick_plot, + #eval_expr, +) + +logger = getLogger(__name__) +FLOAT_MIN = float_info.min +FLOAT_MAX = float_info.max # sigma = fwhm_factor*fwhm fwhm_factor = { - 'gaussian': f'fwhm/(2*sqrt(2*log(2)))', - 'lorentzian': f'0.5*fwhm', - 'splitlorentzian': f'0.5*fwhm', # sigma = sigma_r - 'voight': f'0.2776*fwhm', # sigma = gamma - 'pseudovoight': f'0.5*fwhm'} # fraction = 0.5 + 'gaussian': 'fwhm/(2*sqrt(2*log(2)))', + 'lorentzian': '0.5*fwhm', + 'splitlorentzian': '0.5*fwhm', # sigma = sigma_r + 'voight': '0.2776*fwhm', # sigma = gamma + 'pseudovoight': '0.5*fwhm', # fraction = 0.5 +} # amplitude = height_factor*height*fwhm height_factor = { - 'gaussian': f'height*fwhm*0.5*sqrt(pi/log(2))', - 'lorentzian': f'height*fwhm*0.5*pi', - 
'splitlorentzian': f'height*fwhm*0.5*pi', # sigma = sigma_r - 'voight': f'3.334*height*fwhm', # sigma = gamma - 'pseudovoight': f'1.268*height*fwhm'} # fraction = 0.5 + 'gaussian': 'height*fwhm*0.5*sqrt(pi/log(2))', + 'lorentzian': 'height*fwhm*0.5*pi', + 'splitlorentzian': 'height*fwhm*0.5*pi', # sigma = sigma_r + 'voight': '3.334*height*fwhm', # sigma = gamma + 'pseudovoight': '1.268*height*fwhm', # fraction = 0.5 +} class Fit: - """Wrapper class for lmfit + """ + Wrapper class for lmfit. """ def __init__(self, y, x=None, models=None, normalize=True, **kwargs): + """Initialize Fit.""" + # Third party modules if not isinstance(normalize, bool): raise ValueError(f'Invalid parameter normalize ({normalize})') self._mask = None @@ -95,28 +118,33 @@ def __init__(self, y, x=None, models=None, normalize=True, **kwargs): self._y_norm = None self._y_range = None if 'try_linear_fit' in kwargs: - try_linear_fit = kwargs.pop('try_linear_fit') - if not isinstance(try_linear_fit, bool): - illegal_value(try_linear_fit, 'try_linear_fit', 'Fit.fit', raise_error=True) - self._try_linear_fit = try_linear_fit + self._try_linear_fit = kwargs.pop('try_linear_fit') + if not isinstance(self._try_linear_fit, bool): + raise ValueError( + 'Invalid value of keyword argument try_linear_fit ' + + f'({self._try_linear_fit})') if y is not None: if isinstance(y, (tuple, list, np.ndarray)): self._x = np.asarray(x) self._y = np.asarray(y) - elif have_xarray and isinstance(y, xr.DataArray): + elif HAVE_XARRAY and isinstance(y, xr.DataArray): if x is not None: - logging.warning('Ignoring superfluous input x ({x}) in Fit.__init__') + logger.warning('Ignoring superfluous input x ({x})') if y.ndim != 1: - illegal_value(y.ndim, 'DataArray dimensions', 'Fit:__init__', raise_error=True) + raise ValueError( + 'Invalid DataArray dimensions for parameter y ' + + f'({y.ndim})') self._x = np.asarray(y[y.dims[0]]) self._y = y else: - illegal_value(y, 'y', 'Fit:__init__', raise_error=True) + raise ValueError(f'Invalid parameter y ({y})') if self._x.ndim != 1: - raise ValueError(f'Invalid dimension for input x ({self._x.ndim})') + raise ValueError( + f'Invalid dimension for input x ({self._x.ndim})') if self._x.size != self._y.size: - raise ValueError(f'Inconsistent x and y dimensions ({self._x.size} vs '+ - f'{self._y.size})') + raise ValueError( + f'Inconsistent x and y dimensions ({self._x.size} vs ' + + f'{self._y.size})') if 'mask' in kwargs: self._mask = kwargs.pop('mask') if self._mask is None: @@ -127,8 +155,9 @@ def __init__(self, y, x=None, models=None, normalize=True, **kwargs): else: self._mask = np.asarray(self._mask).astype(bool) if self._x.size != self._mask.size: - raise ValueError(f'Inconsistent x and mask dimensions ({self._x.size} vs '+ - f'{self._mask.size})') + raise ValueError( + f'Inconsistent x and mask dimensions ({self._x.size} ' + + f'vs {self._mask.size})') y_masked = np.asarray(self._y)[~self._mask] y_min = float(y_masked.min()) self._y_range = float(y_masked.max())-y_min @@ -145,88 +174,111 @@ def __init__(self, y, x=None, models=None, normalize=True, **kwargs): @classmethod def fit_data(cls, y, models, x=None, normalize=True, **kwargs): - return(cls(y, x=x, models=models, normalize=normalize, **kwargs)) + """Class method for Fit.""" + return cls(y, x=x, models=models, normalize=normalize, **kwargs) @property def best_errors(self): + """Return errors in the best fit parameters.""" if self._result is None: - return(None) - return({name:self._result.params[name].stderr for name in sorted(self._result.params) - 
if name != 'tmp_normalization_offset_c'}) + return None + return( + {name:self._result.params[name].stderr + for name in sorted(self._result.params) + if name != 'tmp_normalization_offset_c'}) @property def best_fit(self): + """Return the best fit.""" if self._result is None: - return(None) - return(self._result.best_fit) + return None + return self._result.best_fit - @property def best_parameters(self): + """Return the best fit parameters.""" if self._result is None: - return(None) + return None parameters = {} for name in sorted(self._result.params): if name != 'tmp_normalization_offset_c': par = self._result.params[name] - parameters[name] = {'value': par.value, 'error': par.stderr, - 'init_value': par.init_value, 'min': par.min, 'max': par.max, - 'vary': par.vary, 'expr': par.expr} - return(parameters) + parameters[name] = { + 'value': par.value, + 'error': par.stderr, + 'init_value': par.init_value, + 'min': par.min, + 'max': par.max, + 'vary': par.vary, 'expr': par.expr + } + return parameters @property def best_results(self): - """Convert the input data array to a data set and add the fit results. + """ + Convert the input DataArray to a data set and add the fit + results. """ if self._result is None: - return(None) - if not have_xarray: - logging.warning('fit.best_results requires xarray in the conda environment') - return(None) + return None + if not HAVE_XARRAY: + logger.warning( + 'fit.best_results requires xarray in the conda environment') + return None if isinstance(self._y, xr.DataArray): best_results = self._y.to_dataset() dims = self._y.dims fit_name = f'{self._y.name}_fit' else: coords = {'x': (['x'], self._x)} - dims = ('x') + dims = ('x',) best_results = xr.Dataset(coords=coords) best_results['y'] = (dims, self._y) fit_name = 'y_fit' best_results[fit_name] = (dims, self.best_fit) if self._mask is not None: best_results['mask'] = self._mask - best_results.coords['par_names'] = ('peak', [name for name in self.best_values.keys()]) - best_results['best_values'] = (['par_names'], [v for v in self.best_values.values()]) - best_results['best_errors'] = (['par_names'], [v for v in self.best_errors.values()]) + best_results.coords['par_names'] = ('peak', self.best_values.keys()) + best_results['best_values'] = \ + (['par_names'], self.best_values.values()) + best_results['best_errors'] = \ + (['par_names'], self.best_errors.values()) best_results.attrs['components'] = self.components - return(best_results) + return best_results @property def best_values(self): + """Return values of the best fit parameters.""" if self._result is None: - return(None) - return({name:self._result.params[name].value for name in sorted(self._result.params) - if name != 'tmp_normalization_offset_c'}) + return None + return( + {name:self._result.params[name].value + for name in sorted(self._result.params) + if name != 'tmp_normalization_offset_c'}) @property def chisqr(self): + """Return the chisqr value of the best fit.""" if self._result is None: - return(None) - return(self._result.chisqr) + return None + return self._result.chisqr @property def components(self): + """Return the fit model components info.""" components = {} if self._result is None: - logging.warning('Unable to collect components in Fit.components') - return(components) + logger.warning('Unable to collect components in Fit.components') + return components for component in self._result.components: if 'tmp_normalization_offset_c' in component.param_names: continue parameters = {} for name in component.param_names: par = 
self._parameters[name] - parameters[name] = {'free': par.vary, 'value': self._result.params[name].value} + parameters[name] = { + 'free': par.vary, + 'value': self._result.params[name].value, + } if par.expr is not None: parameters[name]['expr'] = par.expr expr = None @@ -237,170 +289,200 @@ def components(self): expr = component.expr else: prefix = component.prefix - if len(prefix): + if prefix: if prefix[-1] == '_': prefix = prefix[:-1] name = f'{prefix} ({component._name})' else: name = f'{component._name}' if expr is None: - components[name] = {'parameters': parameters} + components[name] = { + 'parameters': parameters, + } else: - components[name] = {'expr': expr, 'parameters': parameters} - return(components) + components[name] = { + 'expr': expr, + 'parameters': parameters, + } + return components @property def covar(self): + """Return the covarience matrix of the best fit parameters.""" if self._result is None: - return(None) - return(self._result.covar) + return None + return self._result.covar @property def init_parameters(self): + """Return the initial parameters for the fit model.""" if self._result is None or self._result.init_params is None: - return(None) + return None parameters = {} for name in sorted(self._result.init_params): if name != 'tmp_normalization_offset_c': par = self._result.init_params[name] - parameters[name] = {'value': par.value, 'min': par.min, 'max': par.max, - 'vary': par.vary, 'expr': par.expr} - return(parameters) + parameters[name] = { + 'value': par.value, + 'min': par.min, + 'max': par.max, + 'vary': par.vary, + 'expr': par.expr, + } + return parameters @property def init_values(self): + """Return the initial values for the fit parameters.""" if self._result is None or self._result.init_params is None: - return(None) - return({name:self._result.init_params[name].value for name in - sorted(self._result.init_params) if name != 'tmp_normalization_offset_c'}) + return None + return( + {name:self._result.init_params[name].value + for name in sorted(self._result.init_params) + if name != 'tmp_normalization_offset_c'}) @property def normalization_offset(self): + """Return the normalization_offset for the fit model.""" if self._result is None: - return(None) + return None if self._norm is None: - return(0.0) + return 0.0 + if self._result.init_params is not None: + normalization_offset = float( + self._result.init_params['tmp_normalization_offset_c']) else: - if self._result.init_params is not None: - normalization_offset = float(self._result.init_params['tmp_normalization_offset_c']) - else: - normalization_offset = float(self._result.params['tmp_normalization_offset_c']) - return(normalization_offset) + normalization_offset = float( + self._result.params['tmp_normalization_offset_c']) + return normalization_offset @property def num_func_eval(self): + """ + Return the number of function evaluations for the best fit. 
+ """ if self._result is None: - return(None) - return(self._result.nfev) + return None + return self._result.nfev @property def parameters(self): - return({name:{'min': par.min, 'max': par.max, 'vary': par.vary, 'expr': par.expr} - for name, par in self._parameters.items() if name != 'tmp_normalization_offset_c'}) + """Return the fit parameter info.""" + return( + {name:{'min': par.min, 'max': par.max, 'vary': par.vary, + 'expr': par.expr} for name, par in self._parameters.items() + if name != 'tmp_normalization_offset_c'}) @property def redchi(self): + """Return the redchi value of the best fit.""" if self._result is None: - return(None) - return(self._result.redchi) + return None + return self._result.redchi @property def residual(self): + """Return the residual in the best fit.""" if self._result is None: - return(None) - return(self._result.residual) + return None + return self._result.residual @property def success(self): + """Return the success value for the fit.""" if self._result is None: - return(None) + return None if not self._result.success: -# print(f'ier = {self._result.ier}') -# print(f'lmdif_message = {self._result.lmdif_message}') -# print(f'message = {self._result.message}') -# print(f'nfev = {self._result.nfev}') -# print(f'redchi = {self._result.redchi}') -# print(f'success = {self._result.success}') - if self._result.ier == 0 or self._result.ier == 5: - logging.warning(f'ier = {self._result.ier}: {self._result.message}') - else: - logging.warning(f'ier = {self._result.ier}: {self._result.message}') - return(True) -# self.print_fit_report() -# self.plot() - return(self._result.success) + logger.warning( + f'ier = {self._result.ier}: {self._result.message}') + if self._result.ier and self._result.ier != 5: + return True + return self._result.success @property def var_names(self): - """Intended to be used with covar + """ + Return the variable names for the covarience matrix property. 
""" if self._result is None: - return(None) - return(getattr(self._result, 'var_names', None)) + return None + return getattr(self._result, 'var_names', None) @property def x(self): - return(self._x) + """Return the input x-array.""" + return self._x @property def y(self): - return(self._y) + """Return the input y-array.""" + return self._y def print_fit_report(self, result=None, show_correl=False): + """Print a fit report.""" if result is None: result = self._result if result is not None: print(result.fit_report(show_correl=show_correl)) def add_parameter(self, **parameter): + """Add a fit fit parameter to the fit model.""" if not isinstance(parameter, dict): raise ValueError(f'Invalid parameter ({parameter})') if parameter.get('expr') is not None: raise KeyError(f'Invalid "expr" key in parameter {parameter}') name = parameter['name'] if not isinstance(name, str): - raise ValueError(f'Invalid "name" value ({name}) in parameter {parameter}') + raise ValueError( + f'Invalid "name" value ({name}) in parameter {parameter}') if parameter.get('norm') is None: self._parameter_norms[name] = False else: norm = parameter.pop('norm') if self._norm is None: - logging.warning(f'Ignoring norm in parameter {name} in '+ - f'Fit.add_parameter (normalization is turned off)') + logger.warning( + f'Ignoring norm in parameter {name} in Fit.add_parameter ' + + '(normalization is turned off)') self._parameter_norms[name] = False else: if not isinstance(norm, bool): - raise ValueError(f'Invalid "norm" value ({norm}) in parameter {parameter}') + raise ValueError( + f'Invalid "norm" value ({norm}) in parameter ' + + f'{parameter}') self._parameter_norms[name] = norm vary = parameter.get('vary') if vary is not None: if not isinstance(vary, bool): - raise ValueError(f'Invalid "vary" value ({vary}) in parameter {parameter}') + raise ValueError( + f'Invalid "vary" value ({vary}) in parameter {parameter}') if not vary: if 'min' in parameter: - logging.warning(f'Ignoring min in parameter {name} in '+ - f'Fit.add_parameter (vary = {vary})') + logger.warning( + f'Ignoring min in parameter {name} in ' + + f'Fit.add_parameter (vary = {vary})') parameter.pop('min') if 'max' in parameter: - logging.warning(f'Ignoring max in parameter {name} in '+ - f'Fit.add_parameter (vary = {vary})') + logger.warning( + f'Ignoring max in parameter {name} in ' + + f'Fit.add_parameter (vary = {vary})') parameter.pop('max') if self._norm is not None and name not in self._parameter_norms: - raise ValueError(f'Missing parameter normalization type for paremeter {name}') + raise ValueError( + f'Missing parameter normalization type for parameter {name}') self._parameters.add(**parameter) - def add_model(self, model, prefix=None, parameters=None, parameter_norms=None, **kwargs): - # Create the new model -# print(f'at start add_model:\nself._parameters:\n{self._parameters}') -# print(f'at start add_model: kwargs = {kwargs}') -# print(f'parameters = {parameters}') -# print(f'parameter_norms = {parameter_norms}') -# if len(self._parameters.keys()): -# print('\nAt start adding model:') -# self._parameters.pretty_print() -# print(f'parameter_norms:\n{self._parameter_norms}') + def add_model( + self, model, prefix=None, parameters=None, parameter_norms=None, + **kwargs): + """Add a model component to the fit model.""" + # Third party modules + from asteval import ( + Interpreter, + get_ast_names, + ) + if prefix is not None and not isinstance(prefix, str): - logging.warning('Ignoring illegal prefix: {model} {type(model)}') + logger.warning('Ignoring 
illegal prefix: {model} {type(model)}') prefix = None if prefix is None: pprefix = '' @@ -410,61 +492,79 @@ def add_model(self, model, prefix=None, parameters=None, parameter_norms=None, * if isinstance(parameters, dict): parameters = (parameters, ) elif not is_dict_series(parameters): - illegal_value(parameters, 'parameters', 'Fit.add_model', raise_error=True) + raise ValueError('Invalid parameter parameters ({parameters})') parameters = deepcopy(parameters) if parameter_norms is not None: if isinstance(parameter_norms, dict): parameter_norms = (parameter_norms, ) if not is_dict_series(parameter_norms): - illegal_value(parameter_norms, 'parameter_norms', 'Fit.add_model', raise_error=True) + raise ValueError( + 'Invalid parameter parameters_norms ({parameters_norms})') new_parameter_norms = {} if callable(model): # Linear fit not yet implemented for callable models self._try_linear_fit = False if parameter_norms is None: if parameters is None: - raise ValueError('Either "parameters" or "parameter_norms" is required in '+ - f'{model}') + raise ValueError( + 'Either parameters or parameter_norms is required in ' + + f'{model}') for par in parameters: name = par['name'] if not isinstance(name, str): - raise ValueError(f'Invalid "name" value ({name}) in input parameters') + raise ValueError( + f'Invalid "name" value ({name}) in input ' + + 'parameters') if par.get('norm') is not None: norm = par.pop('norm') if not isinstance(norm, bool): - raise ValueError(f'Invalid "norm" value ({norm}) in input parameters') + raise ValueError( + f'Invalid "norm" value ({norm}) in input ' + + 'parameters') new_parameter_norms[f'{pprefix}{name}'] = norm else: for par in parameter_norms: name = par['name'] if not isinstance(name, str): - raise ValueError(f'Invalid "name" value ({name}) in input parameters') + raise ValueError( + f'Invalid "name" value ({name}) in input ' + + 'parameters') norm = par.get('norm') if norm is None or not isinstance(norm, bool): - raise ValueError(f'Invalid "norm" value ({norm}) in input parameters') + raise ValueError( + f'Invalid "norm" value ({norm}) in input ' + + 'parameters') new_parameter_norms[f'{pprefix}{name}'] = norm if parameters is not None: for par in parameters: if par.get('expr') is not None: - raise KeyError(f'Invalid "expr" key ({par.get("expr")}) in parameter '+ - f'{name} for a callable model {model}') + raise KeyError( + f'Invalid "expr" key ({par.get("expr")}) in ' + + f'parameter {name} for a callable model {model}') name = par['name'] if not isinstance(name, str): - raise ValueError(f'Invalid "name" value ({name}) in input parameters') -# RV FIX callable model will need partial deriv functions for any linear pars to get the linearized matrix, so for now skip linear solution option + raise ValueError( + f'Invalid "name" value ({name}) in input ' + + 'parameters') +# RV callable model will need partial deriv functions for any linear +# parameter to get the linearized matrix, so for now skip linear +# solution option newmodel = Model(model, prefix=prefix) elif isinstance(model, str): - if model == 'constant': # Par: c + if model == 'constant': + # Par: c newmodel = ConstantModel(prefix=prefix) new_parameter_norms[f'{pprefix}c'] = True self._linear_parameters.append(f'{pprefix}c') - elif model == 'linear': # Par: slope, intercept + elif model == 'linear': + # Par: slope, intercept newmodel = LinearModel(prefix=prefix) new_parameter_norms[f'{pprefix}slope'] = True new_parameter_norms[f'{pprefix}intercept'] = True self._linear_parameters.append(f'{pprefix}slope') 
self._linear_parameters.append(f'{pprefix}intercept') - elif model == 'quadratic': # Par: a, b, c + elif model == 'quadratic': + # Par: a, b, c newmodel = QuadraticModel(prefix=prefix) new_parameter_norms[f'{pprefix}a'] = True new_parameter_norms[f'{pprefix}b'] = True @@ -472,17 +572,21 @@ def add_model(self, model, prefix=None, parameters=None, parameter_norms=None, * self._linear_parameters.append(f'{pprefix}a') self._linear_parameters.append(f'{pprefix}b') self._linear_parameters.append(f'{pprefix}c') - elif model == 'polynomial': # Par: c0, c1,..., c7 + elif model == 'polynomial': + # Par: c0, c1,..., c7 degree = kwargs.get('degree') if degree is not None: kwargs.pop('degree') if degree is None or not is_int(degree, ge=0, le=7): - raise ValueError(f'Invalid parameter degree for build-in step model ({degree})') + raise ValueError( + 'Invalid parameter degree for build-in step model ' + + f'({degree})') newmodel = PolynomialModel(degree=degree, prefix=prefix) for i in range(degree+1): new_parameter_norms[f'{pprefix}c{i}'] = True self._linear_parameters.append(f'{pprefix}c{i}') - elif model == 'gaussian': # Par: amplitude, center, sigma (fwhm, height) + elif model == 'gaussian': + # Par: amplitude, center, sigma (fwhm, height) newmodel = GaussianModel(prefix=prefix) new_parameter_norms[f'{pprefix}amplitude'] = True new_parameter_norms[f'{pprefix}center'] = False @@ -490,10 +594,12 @@ def add_model(self, model, prefix=None, parameters=None, parameter_norms=None, * self._linear_parameters.append(f'{pprefix}amplitude') self._nonlinear_parameters.append(f'{pprefix}center') self._nonlinear_parameters.append(f'{pprefix}sigma') - # parameter norms for height and fwhm are needed to get correct errors + # parameter norms for height and fwhm are needed to + # get correct errors new_parameter_norms[f'{pprefix}height'] = True new_parameter_norms[f'{pprefix}fwhm'] = False - elif model == 'lorentzian': # Par: amplitude, center, sigma (fwhm, height) + elif model == 'lorentzian': + # Par: amplitude, center, sigma (fwhm, height) newmodel = LorentzianModel(prefix=prefix) new_parameter_norms[f'{pprefix}amplitude'] = True new_parameter_norms[f'{pprefix}center'] = False @@ -501,21 +607,27 @@ def add_model(self, model, prefix=None, parameters=None, parameter_norms=None, * self._linear_parameters.append(f'{pprefix}amplitude') self._nonlinear_parameters.append(f'{pprefix}center') self._nonlinear_parameters.append(f'{pprefix}sigma') - # parameter norms for height and fwhm are needed to get correct errors + # parameter norms for height and fwhm are needed to + # get correct errors new_parameter_norms[f'{pprefix}height'] = True new_parameter_norms[f'{pprefix}fwhm'] = False - elif model == 'exponential': # Par: amplitude, decay + elif model == 'exponential': + # Par: amplitude, decay newmodel = ExponentialModel(prefix=prefix) new_parameter_norms[f'{pprefix}amplitude'] = True new_parameter_norms[f'{pprefix}decay'] = False self._linear_parameters.append(f'{pprefix}amplitude') self._nonlinear_parameters.append(f'{pprefix}decay') - elif model == 'step': # Par: amplitude, center, sigma + elif model == 'step': + # Par: amplitude, center, sigma form = kwargs.get('form') if form is not None: kwargs.pop('form') - if form is None or form not in ('linear', 'atan', 'arctan', 'erf', 'logistic'): - raise ValueError(f'Invalid parameter form for build-in step model ({form})') + if (form is None or form not in + ('linear', 'atan', 'arctan', 'erf', 'logistic')): + raise ValueError( + 'Invalid parameter form for build-in step model ' 
+ + f'({form})') newmodel = StepModel(prefix=prefix, form=form) new_parameter_norms[f'{pprefix}amplitude'] = True new_parameter_norms[f'{pprefix}center'] = False @@ -523,13 +635,16 @@ def add_model(self, model, prefix=None, parameters=None, parameter_norms=None, * self._linear_parameters.append(f'{pprefix}amplitude') self._nonlinear_parameters.append(f'{pprefix}center') self._nonlinear_parameters.append(f'{pprefix}sigma') - elif model == 'rectangle': # Par: amplitude, center1, center2, sigma1, sigma2 + elif model == 'rectangle': + # Par: amplitude, center1, center2, sigma1, sigma2 form = kwargs.get('form') if form is not None: kwargs.pop('form') - if form is None or form not in ('linear', 'atan', 'arctan', 'erf', 'logistic'): - raise ValueError('Invalid parameter form for build-in rectangle model '+ - f'({form})') + if (form is None or form not in + ('linear', 'atan', 'arctan', 'erf', 'logistic')): + raise ValueError( + 'Invalid parameter form for build-in rectangle model ' + + f'({form})') newmodel = RectangleModel(prefix=prefix, form=form) new_parameter_norms[f'{pprefix}amplitude'] = True new_parameter_norms[f'{pprefix}center1'] = False @@ -541,86 +656,75 @@ def add_model(self, model, prefix=None, parameters=None, parameter_norms=None, * self._nonlinear_parameters.append(f'{pprefix}center2') self._nonlinear_parameters.append(f'{pprefix}sigma1') self._nonlinear_parameters.append(f'{pprefix}sigma2') - elif model == 'expression': # Par: by expression + elif model == 'expression': + # Par: by expression expr = kwargs['expr'] if not isinstance(expr, str): - raise ValueError(f'Invalid "expr" value ({expr}) in {model}') + raise ValueError( + f'Invalid "expr" value ({expr}) in {model}') kwargs.pop('expr') if parameter_norms is not None: - logging.warning('Ignoring parameter_norms (normalization determined from '+ - 'linearity)}') + logger.warning( + 'Ignoring parameter_norms (normalization ' + + 'determined from linearity)}') if parameters is not None: for par in parameters: if par.get('expr') is not None: - raise KeyError(f'Invalid "expr" key ({par.get("expr")}) in parameter '+ - f'({par}) for an expression model') + raise KeyError( + f'Invalid "expr" key ({par.get("expr")}) in ' + + f'parameter ({par}) for an expression model') if par.get('norm') is not None: - logging.warning(f'Ignoring "norm" key in parameter ({par}) '+ - '(normalization determined from linearity)}') + logger.warning( + f'Ignoring "norm" key in parameter ({par}) ' + + '(normalization determined from linearity)') par.pop('norm') name = par['name'] if not isinstance(name, str): - raise ValueError(f'Invalid "name" value ({name}) in input parameters') + raise ValueError( + f'Invalid "name" value ({name}) in input ' + + 'parameters') ast = Interpreter() - expr_parameters = [name for name in get_ast_names(ast.parse(expr)) - if name != 'x' and name not in self._parameters - and name not in ast.symtable] -# print(f'\nexpr_parameters: {expr_parameters}') -# print(f'expr = {expr}') + expr_parameters = [ + name for name in get_ast_names(ast.parse(expr)) + if name != 'x' and name not in self._parameters + and name not in ast.symtable] if prefix is None: newmodel = ExpressionModel(expr=expr) else: for name in expr_parameters: expr = sub(rf'\b{name}\b', f'{prefix}{name}', expr) - expr_parameters = [f'{prefix}{name}' for name in expr_parameters] -# print(f'\nexpr_parameters: {expr_parameters}') -# print(f'expr = {expr}') + expr_parameters = [ + f'{prefix}{name}' for name in expr_parameters] newmodel = ExpressionModel(expr=expr, name=name) 
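For orientation, this branch wraps lmfit's ExpressionModel, whose stand-alone use looks roughly like the sketch below; the decaying-exponential expression and its parameter values are invented for illustration and are not part of CHAP:

import numpy as np
from lmfit.models import ExpressionModel

# 'amp' and 'decay' become free parameters; 'x' is the independent variable.
model = ExpressionModel('amp * exp(-x/decay)')
params = model.make_params(amp=5.0, decay=2.0)

x = np.linspace(0.0, 10.0, 101)
rng = np.random.default_rng(0)
y = 4.3*np.exp(-x/1.7) + rng.normal(0.0, 0.05, x.size)

result = model.fit(y, params, x=x)
print(result.best_values)  # roughly {'amp': 4.3, 'decay': 1.7}
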
-# print(f'\nnewmodel = {newmodel.__dict__}') -# print(f'params_names = {newmodel._param_names}') -# print(f'params_names = {newmodel.param_names}') # Remove already existing names for name in newmodel.param_names.copy(): if name not in expr_parameters: newmodel._func_allargs.remove(name) newmodel._param_names.remove(name) -# print(f'params_names = {newmodel._param_names}') -# print(f'params_names = {newmodel.param_names}') else: raise ValueError(f'Unknown build-in fit model ({model})') else: - illegal_value(model, 'model', 'Fit.add_model', raise_error=True) + raise ValueError('Invalid parameter model ({model})') # Add the new model to the current one -# print('\nBefore adding model:') -# print(f'\nnewmodel = {newmodel.__dict__}') -# if len(self._parameters): -# self._parameters.pretty_print() if self._model is None: self._model = newmodel else: self._model += newmodel new_parameters = newmodel.make_params() self._parameters += new_parameters -# print('\nAfter adding model:') -# print(f'\nnewmodel = {newmodel.__dict__}') -# print(f'\nnew_parameters = {new_parameters}') -# self._parameters.pretty_print() - # Check linearity of expression model paremeters + # Check linearity of expression model parameters if isinstance(newmodel, ExpressionModel): for name in newmodel.param_names: if not diff(newmodel.expr, name, name): if name not in self._linear_parameters: self._linear_parameters.append(name) new_parameter_norms[name] = True -# print(f'\nADDING {name} TO LINEAR') else: if name not in self._nonlinear_parameters: self._nonlinear_parameters.append(name) new_parameter_norms[name] = False -# print(f'\nADDING {name} TO NONLINEAR') -# print(f'new_parameter_norms:\n{new_parameter_norms}') # Scale the default initial model parameters if self._norm is not None: @@ -633,150 +737,160 @@ def add_model(self, model, prefix=None, parameters=None, parameter_norms=None, * value = par.value*self._norm[1] _min = par.min _max = par.max - if not np.isinf(_min) and abs(_min) != float_min: + if not np.isinf(_min) and abs(_min) != FLOAT_MIN: _min *= self._norm[1] - if not np.isinf(_max) and abs(_max) != float_min: + if not np.isinf(_max) and abs(_max) != FLOAT_MIN: _max *= self._norm[1] par.set(value=value, min=_min, max=_max) -# print('\nAfter norm defaults:') -# self._parameters.pretty_print() -# print(f'parameters:\n{parameters}') -# print(f'all_parameters:\n{list(self.parameters)}') -# print(f'new_parameter_norms:\n{new_parameter_norms}') -# print(f'parameter_norms:\n{self._parameter_norms}') # Initialize the model parameters from parameters if prefix is None: - prefix = "" + prefix = '' if parameters is not None: for parameter in parameters: name = parameter['name'] if not isinstance(name, str): - raise ValueError(f'Invalid "name" value ({name}) in input parameters') + raise ValueError( + f'Invalid "name" value ({name}) in input parameters') if name not in new_parameters: name = prefix+name parameter['name'] = name if name not in new_parameters: - logging.warning(f'Ignoring superfluous parameter info for {name}') + logger.warning( + f'Ignoring superfluous parameter info for {name}') continue if name in self._parameters: parameter.pop('name') if 'norm' in parameter: if not isinstance(parameter['norm'], bool): - illegal_value(parameter['norm'], 'norm', 'Fit.add_model', - raise_error=True) + raise ValueError( + f'Invalid "norm" value ({norm}) in the ' + + f'input parameter {name}') new_parameter_norms[name] = parameter['norm'] parameter.pop('norm') if parameter.get('expr') is not None: if 'value' in parameter: - 
logging.warning(f'Ignoring value in parameter {name} '+ - f'(set by expression: {parameter["expr"]})') + logger.warning( + f'Ignoring value in parameter {name} ' + + f'(set by expression: {parameter["expr"]})') parameter.pop('value') if 'vary' in parameter: - logging.warning(f'Ignoring vary in parameter {name} '+ - f'(set by expression: {parameter["expr"]})') + logger.warning( + f'Ignoring vary in parameter {name} ' + + f'(set by expression: {parameter["expr"]})') parameter.pop('vary') if 'min' in parameter: - logging.warning(f'Ignoring min in parameter {name} '+ - f'(set by expression: {parameter["expr"]})') + logger.warning( + f'Ignoring min in parameter {name} ' + + f'(set by expression: {parameter["expr"]})') parameter.pop('min') if 'max' in parameter: - logging.warning(f'Ignoring max in parameter {name} '+ - f'(set by expression: {parameter["expr"]})') + logger.warning( + f'Ignoring max in parameter {name} ' + + f'(set by expression: {parameter["expr"]})') parameter.pop('max') if 'vary' in parameter: if not isinstance(parameter['vary'], bool): - illegal_value(parameter['vary'], 'vary', 'Fit.add_model', - raise_error=True) + raise ValueError( + f'Invalid "vary" value ({parameter["vary"]}) ' + + f'in the input parameter {name}') if not parameter['vary']: if 'min' in parameter: - logging.warning(f'Ignoring min in parameter {name} in '+ - f'Fit.add_model (vary = {parameter["vary"]})') + logger.warning( + f'Ignoring min in parameter {name} ' + + f'(vary = {parameter["vary"]})') parameter.pop('min') if 'max' in parameter: - logging.warning(f'Ignoring max in parameter {name} in '+ - f'Fit.add_model (vary = {parameter["vary"]})') + logger.warning( + f'Ignoring max in parameter {name} ' + + f'(vary = {parameter["vary"]})') parameter.pop('max') self._parameters[name].set(**parameter) parameter['name'] = name else: - illegal_value(parameter, 'parameter name', 'Fit.model', raise_error=True) - self._parameter_norms = {**self._parameter_norms, **new_parameter_norms} -# print('\nAfter parameter init:') -# self._parameters.pretty_print() -# print(f'parameters:\n{parameters}') -# print(f'new_parameter_norms:\n{new_parameter_norms}') -# print(f'parameter_norms:\n{self._parameter_norms}') -# print(f'kwargs:\n{kwargs}') + raise ValueError( + 'Invalid parameter name in parameters ({name})') + self._parameter_norms = { + **self._parameter_norms, + **new_parameter_norms, + } # Initialize the model parameters from kwargs for name, value in {**kwargs}.items(): full_name = f'{pprefix}{name}' - if full_name in new_parameter_norms and isinstance(value, (int, float)): + if (full_name in new_parameter_norms + and isinstance(value, (int, float))): kwargs.pop(name) if self._parameters[full_name].expr is None: self._parameters[full_name].set(value=value) else: - logging.warning(f'Ignoring parameter {name} in Fit.fit (set by expression: '+ - f'{self._parameters[full_name].expr})') -# print('\nAfter kwargs init:') -# self._parameters.pretty_print() -# print(f'parameter_norms:\n{self._parameter_norms}') -# print(f'kwargs:\n{kwargs}') - - # Check parameter norms (also need it for expressions to renormalize the errors) - if self._norm is not None and (callable(model) or model == 'expression'): + logger.warning( + f'Ignoring parameter {name} (set by expression: ' + + f'{self._parameters[full_name].expr})') + + # Check parameter norms + # (also need it for expressions to renormalize the errors) + if (self._norm is not None and + (callable(model) or model == 'expression')): missing_norm = False for name in 
new_parameters.valuesdict(): if name not in self._parameter_norms: print(f'new_parameters:\n{new_parameters.valuesdict()}') print(f'self._parameter_norms:\n{self._parameter_norms}') - logging.error(f'Missing parameter normalization type for {name} in {model}') + logger.error( + f'Missing parameter normalization type for {name} in ' + + f'{model}') missing_norm = True if missing_norm: raise ValueError -# print(f'at end add_model:\nself._parameters:\n{list(self.parameters)}') -# print(f'at end add_model: kwargs = {kwargs}') -# print(f'\nat end add_model: newmodel:\n{newmodel.__dict__}\n') - return(kwargs) + return kwargs def eval(self, x, result=None): + """Evaluate the best fit.""" if result is None: result = self._result if result is None: - return - return(result.eval(x=np.asarray(x))-self.normalization_offset) + return None + return result.eval(x=np.asarray(x))-self.normalization_offset - def fit(self, interactive=False, guess=False, **kwargs): + def fit(self, **kwargs): + """Fit the model to the input data.""" # Check inputs if self._model is None: - logging.error('Undefined fit model') - return - if not isinstance(interactive, bool): - illegal_value(interactive, 'interactive', 'Fit.fit', raise_error=True) - if not isinstance(guess, bool): - illegal_value(guess, 'guess', 'Fit.fit', raise_error=True) + logger.error('Undefined fit model') + return None + if 'interactive' in kwargs: + interactive = kwargs.pop('interactive') + if not isinstance(interactive, bool): + raise ValueError( + 'Invalid value of keyword argument interactive ' + + f'({interactive})') + if 'guess' in kwargs: + guess = kwargs.pop('guess') + if not isinstance(guess, bool): + raise ValueError( + f'Invalid value of keyword argument guess ({guess})') if 'try_linear_fit' in kwargs: try_linear_fit = kwargs.pop('try_linear_fit') if not isinstance(try_linear_fit, bool): - illegal_value(try_linear_fit, 'try_linear_fit', 'Fit.fit', raise_error=True) + raise ValueError( + 'Invalid value of keyword argument try_linear_fit ' + + f'({try_linear_fit})') if not self._try_linear_fit: - logging.warning('Ignore superfluous keyword argument "try_linear_fit" (not '+ - 'yet supported for callable models)') + logger.warning( + 'Ignore superfluous keyword argument "try_linear_fit" ' + + '(not yet supported for callable models)') else: self._try_linear_fit = try_linear_fit -# if self._result is None: -# if 'parameters' in kwargs: -# raise ValueError('Invalid parameter parameters ({kwargs["parameters"]})') -# else: if self._result is not None: if guess: - logging.warning('Ignoring input parameter guess in Fit.fit during refitting') + logger.warning( + 'Ignoring input parameter guess during refitting') guess = False # Check for circular expressions - # FIX TODO + # RV # for name1, par1 in self._parameters.items(): # if par1.expr is not None: @@ -786,28 +900,32 @@ def fit(self, interactive=False, guess=False, **kwargs): if self._mask is not None: self._mask = np.asarray(self._mask).astype(bool) if self._x.size != self._mask.size: - raise ValueError(f'Inconsistent x and mask dimensions ({self._x.size} vs '+ - f'{self._mask.size})') - - # Estimate initial parameters with build-in lmfit guess method (only for a single model) -# print(f'\nat start fit: kwargs = {kwargs}') -#RV print('\nAt start of fit:') -#RV self._parameters.pretty_print() -# print(f'parameter_norms:\n{self._parameter_norms}') + raise ValueError( + f'Inconsistent x and mask dimensions ({self._x.size} vs ' + + f'{self._mask.size})') + + # Estimate initial parameters with build-in 
lmfit guess method + # (only mplemented for a single model) if guess: if self._mask is None: self._parameters = self._model.guess(self._y, x=self._x) else: - self._parameters = self._model.guess(np.asarray(self._y)[~self._mask], - x=self._x[~self._mask]) -# print('\nAfter guess:') -# self._parameters.pretty_print() + self._parameters = self._model.guess( + np.asarray(self._y)[~self._mask], x=self._x[~self._mask]) # Add constant offset for a normalized model if self._result is None and self._norm is not None and self._norm[0]: - self.add_model('constant', prefix='tmp_normalization_offset_', parameters={'name': 'c', - 'value': -self._norm[0], 'vary': False, 'norm': True}) - #'value': -self._norm[0]/self._norm[1], 'vary': False, 'norm': False}) + self.add_model( + 'constant', prefix='tmp_normalization_offset_', + parameters={ + 'name': 'c', + 'value': -self._norm[0], + 'vary': False, + 'norm': True, +# 'value': -self._norm[0]/self._norm[1], +# 'vary': False, +# 'norm': False, + }) # Adjust existing parameters for refit: if 'parameters' in kwargs: @@ -815,33 +933,36 @@ def fit(self, interactive=False, guess=False, **kwargs): if isinstance(parameters, dict): parameters = (parameters, ) elif not is_dict_series(parameters): - illegal_value(parameters, 'parameters', 'Fit.fit', raise_error=True) + raise ValueError( + 'Invalid value of keyword argument parameters ' + + f'({parameters})') for par in parameters: name = par['name'] if name not in self._parameters: - raise ValueError(f'Unable to match {name} parameter {par} to an existing one') + raise ValueError( + f'Unable to match {name} parameter {par} to an ' + + 'existing one') if self._parameters[name].expr is not None: - raise ValueError(f'Unable to modify {name} parameter {par} (currently an '+ - 'expression)') + raise ValueError( + f'Unable to modify {name} parameter {par} ' + + '(currently an expression)') if par.get('expr') is not None: - raise KeyError(f'Invalid "expr" key in {name} parameter {par}') + raise KeyError( + f'Invalid "expr" key in {name} parameter {par}') self._parameters[name].set(vary=par.get('vary')) self._parameters[name].set(min=par.get('min')) self._parameters[name].set(max=par.get('max')) self._parameters[name].set(value=par.get('value')) -#RV print('\nAfter adjust:') -#RV self._parameters.pretty_print() # Apply parameter updates through keyword arguments -# print(f'kwargs = {kwargs}') -# print(f'parameter_norms = {self._parameter_norms}') for name in set(self._parameters) & set(kwargs): value = kwargs.pop(name) if self._parameters[name].expr is None: self._parameters[name].set(value=value) else: - logging.warning(f'Ignoring parameter {name} in Fit.fit (set by expression: '+ - f'{self._parameters[name].expr})') + logger.warning( + f'Ignoring parameter {name} (set by expression: ' + + f'{self._parameters[name].expr})') # Check for uninitialized parameters for name, par in self._parameters.items(): @@ -849,7 +970,8 @@ def fit(self, interactive=False, guess=False, **kwargs): value = par.value if value is None or np.isinf(value) or np.isnan(value): if interactive: - value = input_num(f'Enter an initial value for {name}', default=1.0) + value = input_num( + f'Enter an initial value for {name}', default=1.0) else: value = 1.0 if self._norm is None or name not in self._parameter_norms: @@ -862,20 +984,11 @@ def fit(self, interactive=False, guess=False, **kwargs): linear_model = self._check_linearity_model() except: linear_model = False -# print(f'\n\n--------> linear_model = {linear_model}\n') if 
kwargs.get('check_only_linearity') is not None: - return(linear_model) + return linear_model # Normalize the data and initial parameters -#RV print('\nBefore normalization:') -#RV self._parameters.pretty_print() -# print(f'parameter_norms:\n{self._parameter_norms}') self._normalize() -# print(f'norm = {self._norm}') -#RV print('\nAfter normalization:') -#RV self._parameters.pretty_print() -# self.print_fit_report() -# print(f'parameter_norms:\n{self._parameter_norms}') if linear_model: # Perform a linear fit by direct matrix solution with numpy @@ -890,46 +1003,44 @@ def fit(self, interactive=False, guess=False, **kwargs): if not linear_model: # Perform a non-linear fit with lmfit # Prevent initial values from sitting at boundaries - self._parameter_bounds = {name:{'min': par.min, 'max': par.max} for name, par in - self._parameters.items() if par.vary} + self._parameter_bounds = { + name:{'min': par.min, 'max': par.max} + for name, par in self._parameters.items() if par.vary} for par in self._parameters.values(): if par.vary: par.set(value=self._reset_par_at_boundary(par, par.value)) -# print('\nAfter checking boundaries:') -# self._parameters.pretty_print() # Perform the fit # fit_kws = None # if 'Dfun' in kwargs: # fit_kws = {'Dfun': kwargs.pop('Dfun')} -# self._result = self._model.fit(self._y_norm, self._parameters, x=self._x, -# fit_kws=fit_kws, **kwargs) +# self._result = self._model.fit( +# self._y_norm, self._parameters, x=self._x, fit_kws=fit_kws, +# **kwargs) if self._mask is None: - self._result = self._model.fit(self._y_norm, self._parameters, x=self._x, **kwargs) + self._result = self._model.fit( + self._y_norm, self._parameters, x=self._x, **kwargs) else: - self._result = self._model.fit(np.asarray(self._y_norm)[~self._mask], - self._parameters, x=self._x[~self._mask], **kwargs) -#RV print('\nAfter fit:') -# print(f'\nself._result ({self._result}):\n\t{self._result.__dict__}') -#RV self._parameters.pretty_print() -# self.print_fit_report() + self._result = self._model.fit( + np.asarray(self._y_norm)[~self._mask], self._parameters, + x=self._x[~self._mask], **kwargs) # Set internal parameter values to fit results upon success if self.success: for name, par in self._parameters.items(): if par.expr is None and par.vary: par.set(value=self._result.params[name].value) -# print('\nAfter update parameter values:') -# self._parameters.pretty_print() # Renormalize the data and results self._renormalize() -#RV print('\nAfter renormalization:') -#RV self._parameters.pretty_print() -# self.print_fit_report() - def plot(self, y=None, y_title=None, result=None, skip_init=False, plot_comp=True, - plot_comp_legends=False, plot_residual=False, plot_masked_data=True, **kwargs): + return None + + def plot( + self, y=None, y_title=None, title=None, result=None, + skip_init=False, plot_comp=True, plot_comp_legends=False, + plot_residual=False, plot_masked_data=True, **kwargs): + """Plot the best fit.""" if result is None: result = self._result if result is None: @@ -943,9 +1054,10 @@ def plot(self, y=None, y_title=None, result=None, skip_init=False, plot_comp=Tru mask = self._mask if y is not None: if not isinstance(y, (tuple, list, np.ndarray)): - illegal_value(y, 'y', 'Fit.plot') + logger.warning('Ignorint invalid parameter y ({y}') if len(y) != len(self._x): - logging.warning('Ignoring parameter y in Fit.plot (wrong dimension)') + logger.warning( + 'Ignoring parameter y in plot (wrong dimension)') y = None if y is not None: if y_title is None or not isinstance(y_title, str): @@ -973,106 
+1085,98 @@ def plot(self, y=None, y_title=None, result=None, skip_init=False, plot_comp=Tru num_components -= 1 if num_components > 1: eval_index = 0 - for modelname, y in components.items(): + for modelname, y_comp in components.items(): if modelname == 'tmp_normalization_offset_': continue if modelname == '_eval': modelname = f'eval{eval_index}' if len(modelname) > 20: modelname = f'{modelname[0:16]} ...' - if isinstance(y, (int, float)): - y *= np.ones(self._x[~mask].size) - plots += [(self._x[~mask], y, '--')] + if isinstance(y_comp, (int, float)): + y_comp *= np.ones(self._x[~mask].size) + plots += [(self._x[~mask], y_comp, '--')] if plot_comp_legends: if modelname[-1] == '_': legend.append(modelname[:-1]) else: legend.append(modelname) - title = kwargs.get('title') - if title is not None: - kwargs.pop('title') - quick_plot(tuple(plots), legend=legend, title=title, block=True, **kwargs) + quick_plot( + tuple(plots), legend=legend, title=title, block=True, **kwargs) @staticmethod - def guess_init_peak(x, y, *args, center_guess=None, use_max_for_center=True): - """ Return a guess for the initial height, center and fwhm for a peak + def guess_init_peak( + x, y, *args, center_guess=None, use_max_for_center=True): + """ + Return a guess for the initial height, center and fwhm for a + single peak. """ -# print(f'\n\nargs = {args}') -# print(f'center_guess = {center_guess}') -# quick_plot(x, y, vlines=center_guess, block=True) center_guesses = None x = np.asarray(x) y = np.asarray(y) if len(x) != len(y): - logging.error(f'Invalid x and y lengths ({len(x)}, {len(y)}), skip initial guess') - return(None, None, None) + logger.error( + f'Invalid x and y lengths ({len(x)}, {len(y)}), ' + + 'skip initial guess') + return None, None, None if isinstance(center_guess, (int, float)): - if len(args): - logging.warning('Ignoring additional arguments for single center_guess value') + if args: + logger.warning( + 'Ignoring additional arguments for single center_guess ' + + 'value') center_guesses = [center_guess] elif isinstance(center_guess, (tuple, list, np.ndarray)): if len(center_guess) == 1: - logging.warning('Ignoring additional arguments for single center_guess value') + logger.warning( + 'Ignoring additional arguments for single center_guess ' + + 'value') if not isinstance(center_guess[0], (int, float)): - raise ValueError(f'Invalid parameter center_guess ({type(center_guess[0])})') + raise ValueError( + 'Invalid parameter center_guess ' + + f'({type(center_guess[0])})') center_guess = center_guess[0] else: if len(args) != 1: - raise ValueError(f'Invalid number of arguments ({len(args)})') + raise ValueError( + f'Invalid number of arguments ({len(args)})') n = args[0] if not is_index(n, 0, len(center_guess)): raise ValueError('Invalid argument') center_guesses = center_guess center_guess = center_guesses[n] elif center_guess is not None: - raise ValueError(f'Invalid center_guess type ({type(center_guess)})') -# print(f'x = {x}') -# print(f'y = {y}') -# print(f'center_guess = {center_guess}') + raise ValueError( + f'Invalid center_guess type ({type(center_guess)})') # Sort the inputs index = np.argsort(x) x = x[index] y = y[index] miny = y.min() -# print(f'miny = {miny}') -# print(f'x_range = {x[0]} {x[-1]} {len(x)}') -# print(f'y_range = {y[0]} {y[-1]} {len(y)}') -# quick_plot(x, y, vlines=center_guess, block=True) -# xx = x -# yy = y # Set range for current peak -# print(f'n = {n}') -# print(f'center_guesses = {center_guesses}') if center_guesses is not None: if len(center_guesses) > 1: index = 
np.argsort(center_guesses) n = list(index).index(n) -# print(f'n = {n}') -# print(f'index = {index}') center_guesses = np.asarray(center_guesses)[index] -# print(f'center_guesses = {center_guesses}') if n == 0: - low = 0 - upp = index_nearest(x, (center_guesses[0]+center_guesses[1])/2) + low = 0 + upp = index_nearest( + x, (center_guesses[0]+center_guesses[1]) / 2) elif n == len(center_guesses)-1: - low = index_nearest(x, (center_guesses[n-1]+center_guesses[n])/2) - upp = len(x) + low = index_nearest( + x, (center_guesses[n-1]+center_guesses[n]) / 2) + upp = len(x) else: - low = index_nearest(x, (center_guesses[n-1]+center_guesses[n])/2) - upp = index_nearest(x, (center_guesses[n]+center_guesses[n+1])/2) -# print(f'low = {low}') -# print(f'upp = {upp}') + low = index_nearest( + x, (center_guesses[n-1]+center_guesses[n]) / 2) + upp = index_nearest( + x, (center_guesses[n]+center_guesses[n+1]) / 2) x = x[low:upp] y = y[low:upp] -# quick_plot(x, y, vlines=(x[0], center_guess, x[-1]), block=True) - # Estimate FHHM + # Estimate FWHM maxy = y.max() -# print(f'x_range = {x[0]} {x[-1]} {len(x)}') -# print(f'y_range = {y[0]} {y[-1]} {len(y)} {miny} {maxy}') -# print(f'center_guess = {center_guess}') if center_guess is None: center_index = np.argmax(y) center = x[center_index] @@ -1088,52 +1192,43 @@ def guess_init_peak(x, y, *args, center_guess=None, use_max_for_center=True): center_index = index_nearest(x, center_guess) center = center_guess height = y[center_index]-miny -# print(f'center_index = {center_index}') -# print(f'center = {center}') -# print(f'height = {height}') - half_height = miny+0.5*height -# print(f'half_height = {half_height}') + half_height = miny + 0.5*height fwhm_index1 = 0 for i in range(center_index, fwhm_index1, -1): if y[i] < half_height: fwhm_index1 = i break -# print(f'fwhm_index1 = {fwhm_index1} {x[fwhm_index1]}') fwhm_index2 = len(x)-1 for i in range(center_index, fwhm_index2): if y[i] < half_height: fwhm_index2 = i break -# print(f'fwhm_index2 = {fwhm_index2} {x[fwhm_index2]}') -# quick_plot((x,y,'o'), vlines=(x[fwhm_index1], center, x[fwhm_index2]), block=True) if fwhm_index1 == 0 and fwhm_index2 < len(x)-1: - fwhm = 2*(x[fwhm_index2]-center) + fwhm = 2 * (x[fwhm_index2]-center) elif fwhm_index1 > 0 and fwhm_index2 == len(x)-1: - fwhm = 2*(center-x[fwhm_index1]) + fwhm = 2 * (center-x[fwhm_index1]) else: fwhm = x[fwhm_index2]-x[fwhm_index1] -# print(f'fwhm_index1 = {fwhm_index1} {x[fwhm_index1]}') -# print(f'fwhm_index2 = {fwhm_index2} {x[fwhm_index2]}') -# print(f'fwhm = {fwhm}') - # Return height, center and FWHM -# quick_plot((x,y,'o'), (xx,yy), vlines=(x[fwhm_index1], center, x[fwhm_index2]), block=True) - return(height, center, fwhm) + return height, center, fwhm def _check_linearity_model(self): - """Identify the linearity of all model parameters and check if the model is linear or not + """ + Identify the linearity of all model parameters and check if + the model is linear or not. 
""" if not self._try_linear_fit: - logging.info('Skip linearity check (not yet supported for callable models)') - return(False) - free_parameters = [name for name, par in self._parameters.items() if par.vary] + logger.info( + 'Skip linearity check (not yet supported for callable models)') + return False + free_parameters = \ + [name for name, par in self._parameters.items() if par.vary] for component in self._model.components: if 'tmp_normalization_offset_c' in component.param_names: continue if isinstance(component, ExpressionModel): for name in free_parameters: if diff(component.expr, name, name): -# print(f'\t\t{component.expr} is non-linear in {name}') self._nonlinear_parameters.append(name) if name in self._linear_parameters: self._linear_parameters.remove(name) @@ -1149,38 +1244,33 @@ def _check_linearity_model(self): for nname in free_parameters: if name in self._nonlinear_parameters: if diff(expr, nname): -# print(f'\t\t{component} is non-linear in {nname} (through {name} = "{expr}")') self._nonlinear_parameters.append(nname) if nname in self._linear_parameters: self._linear_parameters.remove(nname) else: - assert(name in self._linear_parameters) -# print(f'\n\nexpr ({type(expr)}) = {expr}\nnname ({type(nname)}) = {nname}\n\n') + assert name in self._linear_parameters if diff(expr, nname, nname): -# print(f'\t\t{component} is non-linear in {nname} (through {name} = "{expr}")') self._nonlinear_parameters.append(nname) if nname in self._linear_parameters: self._linear_parameters.remove(nname) -# print(f'\nfree parameters:\n\t{free_parameters}') -# print(f'linear parameters:\n\t{self._linear_parameters}') -# print(f'nonlinear parameters:\n\t{self._nonlinear_parameters}\n') - if any(True for name in self._nonlinear_parameters if self._parameters[name].vary): - return(False) - return(True) + if any(True for name in self._nonlinear_parameters + if self._parameters[name].vary): + return False + return True def _fit_linear_model(self, x, y): - """Perform a linear fit by direct matrix solution with numpy """ + Perform a linear fit by direct matrix solution with numpy. 
+ """ + # Third party modules + from asteval import Interpreter + # Construct the matrix and the free parameter vector -# print(f'\nparameters:') -# self._parameters.pretty_print() -# print(f'\nparameter_norms:\n\t{self._parameter_norms}') -# print(f'\nlinear_parameters:\n\t{self._linear_parameters}') -# print(f'nonlinear_parameters:\n\t{self._nonlinear_parameters}') - free_parameters = [name for name, par in self._parameters.items() if par.vary] -# print(f'free parameters:\n\t{free_parameters}\n') - expr_parameters = {name:par.expr for name, par in self._parameters.items() - if par.expr is not None} + free_parameters = \ + [name for name, par in self._parameters.items() if par.vary] + expr_parameters = { + name:par.expr for name, par in self._parameters.items() + if par.expr is not None} model_parameters = [] for component in self._model.components: if 'tmp_normalization_offset_c' in component.param_names: @@ -1191,175 +1281,162 @@ def _fit_linear_model(self, x, y): if hint.get('expr') is not None: expr_parameters.pop(name) model_parameters.remove(name) -# print(f'expr parameters:\n{expr_parameters}') -# print(f'model parameters:\n\t{model_parameters}\n') norm = 1.0 if self._normalized: norm = self._norm[1] -# print(f'\n\nself._normalized = {self._normalized}\nnorm = {norm}\nself._norm = {self._norm}\n') # Add expression parameters to asteval ast = Interpreter() -# print(f'Adding to asteval sym table:') for name, expr in expr_parameters.items(): -# print(f'\tadding {name} {expr}') ast.symtable[name] = expr # Add constant parameters to asteval - # (renormalize to use correctly in evaluation of expression models) + # (renormalize to use correctly in evaluation of expression + # models) for name, par in self._parameters.items(): if par.expr is None and not par.vary: if self._parameter_norms[name]: -# print(f'\tadding {name} {par.value*norm}') ast.symtable[name] = par.value*norm else: -# print(f'\tadding {name} {par.value}') ast.symtable[name] = par.value - A = np.zeros((len(x), len(free_parameters)), dtype='float64') + mat_a = np.zeros((len(x), len(free_parameters)), dtype='float64') y_const = np.zeros(len(x), dtype='float64') have_expression_model = False for component in self._model.components: if isinstance(component, ConstantModel): name = component.param_names[0] -# print(f'\nConstant model: {name} {self._parameters[name]}\n') if name in free_parameters: -# print(f'\t\t{name} is a free constant set matrix column {free_parameters.index(name)} to 1.0') - A[:,free_parameters.index(name)] = 1.0 + mat_a[:,free_parameters.index(name)] = 1.0 else: if self._parameter_norms[name]: - delta_y_const = self._parameters[name]*np.ones(len(x)) + delta_y_const = \ + self._parameters[name] * np.ones(len(x)) else: - delta_y_const = (self._parameters[name]*norm)*np.ones(len(x)) + delta_y_const = \ + (self._parameters[name]*norm) * np.ones(len(x)) y_const += delta_y_const -# print(f'\ndelta_y_const ({type(delta_y_const)}):\n{delta_y_const}\n') elif isinstance(component, ExpressionModel): have_expression_model = True const_expr = component.expr -# print(f'\nExpression model:\nconst_expr: {const_expr}\n') for name in free_parameters: dexpr_dname = diff(component.expr, name) if dexpr_dname: - const_expr = f'{const_expr}-({str(dexpr_dname)})*{name}' -# print(f'\tconst_expr: {const_expr}') + const_expr = \ + f'{const_expr}-({str(dexpr_dname)})*{name}' if not self._parameter_norms[name]: dexpr_dname = f'({dexpr_dname})/{norm}' -# print(f'\t{component.expr} is linear in {name}\n\t\tadd "{str(dexpr_dname)}" to 
matrix as column {free_parameters.index(name)}') - fx = [(lambda _: ast.eval(str(dexpr_dname)))(ast(f'x={v}')) for v in x] -# print(f'\tfx:\n{fx}') - if len(ast.error): - raise ValueError(f'Unable to evaluate {dexpr_dname}') - A[:,free_parameters.index(name)] += fx -# if self._parameter_norms[name]: -# print(f'\t\t{component.expr} is linear in {name} add "{str(dexpr_dname)}" to matrix as column {free_parameters.index(name)}') -# A[:,free_parameters.index(name)] += fx -# else: -# print(f'\t\t{component.expr} is linear in {name} add "({str(dexpr_dname)})/{norm}" to matrix as column {free_parameters.index(name)}') -# A[:,free_parameters.index(name)] += np.asarray(fx)/norm - # FIX: find another solution if expr not supported by simplify + y_expr = [(lambda _: ast.eval(str(dexpr_dname))) + (ast(f'x={v}')) for v in x] + if ast.error: + raise ValueError( + f'Unable to evaluate {dexpr_dname}') + mat_a[:,free_parameters.index(name)] += y_expr + # RV find another solution if expr not supported by + # simplify const_expr = str(simplify(f'({const_expr})/{norm}')) -# print(f'\nconst_expr: {const_expr}') - delta_y_const = [(lambda _: ast.eval(const_expr))(ast(f'x = {v}')) for v in x] + delta_y_const = [(lambda _: ast.eval(const_expr)) + (ast(f'x = {v}')) for v in x] y_const += delta_y_const -# print(f'\ndelta_y_const ({type(delta_y_const)}):\n{delta_y_const}\n') - if len(ast.error): + if ast.error: raise ValueError(f'Unable to evaluate {const_expr}') else: - free_model_parameters = [name for name in component.param_names - if name in free_parameters or name in expr_parameters] -# print(f'\nBuild-in model ({component}):\nfree_model_parameters: {free_model_parameters}\n') - if not len(free_model_parameters): + free_model_parameters = [ + name for name in component.param_names + if name in free_parameters or name in expr_parameters] + if not free_model_parameters: y_const += component.eval(params=self._parameters, x=x) elif isinstance(component, LinearModel): - if f'{component.prefix}slope' in free_model_parameters: - A[:,free_parameters.index(f'{component.prefix}slope')] = x + name = f'{component.prefix}slope' + if name in free_model_parameters: + mat_a[:,free_parameters.index(name)] = x else: - y_const += self._parameters[f'{component.prefix}slope'].value*x - if f'{component.prefix}intercept' in free_model_parameters: - A[:,free_parameters.index(f'{component.prefix}intercept')] = 1.0 + y_const += self._parameters[name].value * x + name = f'{component.prefix}intercept' + if name in free_model_parameters: + mat_a[:,free_parameters.index(name)] = 1.0 else: - y_const += self._parameters[f'{component.prefix}intercept'].value* \ - np.ones(len(x)) + y_const += self._parameters[name].value \ + * np.ones(len(x)) elif isinstance(component, QuadraticModel): - if f'{component.prefix}a' in free_model_parameters: - A[:,free_parameters.index(f'{component.prefix}a')] = x**2 + name = f'{component.prefix}a' + if name in free_model_parameters: + mat_a[:,free_parameters.index(name)] = x**2 else: - y_const += self._parameters[f'{component.prefix}a'].value*x**2 - if f'{component.prefix}b' in free_model_parameters: - A[:,free_parameters.index(f'{component.prefix}b')] = x + y_const += self._parameters[name].value * x**2 + name = f'{component.prefix}b' + if name in free_model_parameters: + mat_a[:,free_parameters.index(name)] = x else: - y_const += self._parameters[f'{component.prefix}b'].value*x - if f'{component.prefix}c' in free_model_parameters: - A[:,free_parameters.index(f'{component.prefix}c')] = 1.0 + y_const += 
self._parameters[name].value * x
+                    name = f'{component.prefix}c'
+                    if name in free_model_parameters:
+                        mat_a[:,free_parameters.index(name)] = 1.0
                     else:
-                        y_const += self._parameters[f'{component.prefix}c'].value*np.ones(len(x))
+                        y_const += self._parameters[name].value \
+                            * np.ones(len(x))
                 else:
-                    # At this point each build-in model must be strictly proportional to each linear
-                    # model parameter. Without this assumption, the model equation is needed
-                    # For the current build-in lmfit models, this can only ever be the amplitude
-                    assert(len(free_model_parameters) == 1)
+                    # At this point each built-in model must be
+                    # strictly proportional to each linear model
+                    # parameter. Without this assumption, the model
+                    # equation is needed.
+                    # For the current built-in lmfit models, this can
+                    # only ever be the amplitude
+                    assert len(free_model_parameters) == 1
                     name = f'{component.prefix}amplitude'
-                    assert(free_model_parameters[0] == name)
-                    assert(self._parameter_norms[name])
+                    assert free_model_parameters[0] == name
+                    assert self._parameter_norms[name]
                     expr = self._parameters[name].expr
                     if expr is None:
-#                        print(f'\t{component} is linear in {name} add to matrix as column {free_parameters.index(name)}')
                         parameters = deepcopy(self._parameters)
                         parameters[name].set(value=1.0)
-                        index = free_parameters.index(name)
-                        A[:,free_parameters.index(name)] += component.eval(params=parameters, x=x)
+                        mat_a[:,free_parameters.index(name)] += component.eval(
+                            params=parameters, x=x)
                     else:
                         const_expr = expr
-#                        print(f'\tconst_expr: {const_expr}')
                         parameters = deepcopy(self._parameters)
                         parameters[name].set(value=1.0)
                         dcomp_dname = component.eval(params=parameters, x=x)
-#                        print(f'\tdcomp_dname ({type(dcomp_dname)}):\n{dcomp_dname}')
                         for nname in free_parameters:
                             dexpr_dnname = diff(expr, nname)
                             if dexpr_dnname:
-                                assert(self._parameter_norms[name])
-#                                print(f'\t\td({expr})/d{nname} = {dexpr_dnname}')
-#                                print(f'\t\t{component} is linear in {nname} (through {name} = "{expr}", add to matrix as column {free_parameters.index(nname)})')
-                                fx = np.asarray(dexpr_dnname*dcomp_dname, dtype='float64')
-#                                print(f'\t\tfx ({type(fx)}): {fx}')
-#                                print(f'free_parameters.index({nname}): {free_parameters.index(nname)}')
+                                assert self._parameter_norms[name]
+                                y_expr = np.asarray(
+                                    dexpr_dnname*dcomp_dname, dtype='float64')
                                 if self._parameter_norms[nname]:
-                                    A[:,free_parameters.index(nname)] += fx
+                                    mat_a[:,free_parameters.index(nname)] += \
+                                        y_expr
                                 else:
-                                    A[:,free_parameters.index(nname)] += fx/norm
-                                const_expr = f'{const_expr}-({dexpr_dnname})*{nname}'
-#                                print(f'\t\tconst_expr: {const_expr}')
+                                    mat_a[:,free_parameters.index(nname)] += \
+                                        y_expr/norm
+                                const_expr = \
+                                    f'{const_expr}-({dexpr_dnname})*{nname}'
                         const_expr = str(simplify(f'({const_expr})/{norm}'))
-#                        print(f'\tconst_expr: {const_expr}')
-                        fx = [(lambda _: ast.eval(const_expr))(ast(f'x = {v}')) for v in x]
-#                        print(f'\tfx: {fx}')
-                        delta_y_const = np.multiply(fx, dcomp_dname)
+                        y_expr = [
+                            (lambda _: ast.eval(const_expr))(ast(f'x = {v}'))
+                            for v in x]
+                        delta_y_const = np.multiply(y_expr, dcomp_dname)
                         y_const += delta_y_const
-#                        print(f'\ndelta_y_const ({type(delta_y_const)}):\n{delta_y_const}\n')
-#        print(A)
-#        print(y_const)
-        solution, residual, rank, s = np.linalg.lstsq(A, y-y_const, rcond=None)
-#        print(f'\nsolution ({type(solution)} {solution.shape}):\n\t{solution}')
-#        print(f'\nresidual ({type(residual)} {residual.shape}):\n\t{residual}')
-#        print(f'\nrank ({type(rank)} {rank.shape}):\n\t{rank}')
-#        print(f'\ns ({type(s)} {s.shape}):\n\t{s}\n')
-
- # Assemble result (compensate for normalization in expression models) + solution, _, _, _ = np.linalg.lstsq( + mat_a, y-y_const, rcond=None) + + # Assemble result + # (compensate for normalization in expression models) for name, value in zip(free_parameters, solution): self._parameters[name].set(value=value) - if self._normalized and (have_expression_model or len(expr_parameters)): + if (self._normalized and (have_expression_model + or expr_parameters)): for name, norm in self._parameter_norms.items(): par = self._parameters[name] if par.expr is None and norm: self._parameters[name].set(value=par.value*self._norm[1]) -# self._parameters.pretty_print() -# print(f'\nself._parameter_norms:\n\t{self._parameter_norms}') self._result = ModelResult(self._model, deepcopy(self._parameters)) self._result.best_fit = self._model.eval(params=self._parameters, x=x) - if self._normalized and (have_expression_model or len(expr_parameters)): + if (self._normalized and (have_expression_model + or expr_parameters)): if 'tmp_normalization_offset_c' in self._parameters: offset = self._parameters['tmp_normalization_offset_c'] else: offset = 0.0 - self._result.best_fit = (self._result.best_fit-offset-self._norm[0])/self._norm[1] + self._result.best_fit = \ + (self._result.best_fit-offset-self._norm[0]) / self._norm[1] if self._normalized: for name, norm in self._parameter_norms.items(): par = self._parameters[name] @@ -1367,15 +1444,12 @@ def _fit_linear_model(self, x, y): value = par.value/self._norm[1] self._parameters[name].set(value=value) self._result.params[name].set(value=value) -# self._parameters.pretty_print() self._result.residual = self._result.best_fit-y self._result.components = self._model.components self._result.init_params = None -# quick_plot((x, y, '.'), (x, y_const, 'g'), (x, self._result.best_fit, 'k'), (x, self._result.residual, 'r'), block=True) def _normalize(self): - """Normalize the data and initial parameters - """ + """Normalize the data and initial parameters.""" if self._normalized: return if self._norm is None: @@ -1383,7 +1457,8 @@ def _normalize(self): self._y_norm = np.asarray(self._y) else: if self._y is not None and self._y_norm is None: - self._y_norm = (np.asarray(self._y)-self._norm[0])/self._norm[1] + self._y_norm = \ + (np.asarray(self._y)-self._norm[0]) / self._norm[1] self._y_range = 1.0 for name, norm in self._parameter_norms.items(): par = self._parameters[name] @@ -1391,16 +1466,15 @@ def _normalize(self): value = par.value/self._norm[1] _min = par.min _max = par.max - if not np.isinf(_min) and abs(_min) != float_min: + if not np.isinf(_min) and abs(_min) != FLOAT_MIN: _min /= self._norm[1] - if not np.isinf(_max) and abs(_max) != float_min: + if not np.isinf(_max) and abs(_max) != FLOAT_MIN: _max /= self._norm[1] par.set(value=value, min=_min, max=_max) self._normalized = True def _renormalize(self): - """Renormalize the data and results - """ + """Renormalize the data and results.""" if self._norm is None or not self._normalized: return self._normalized = False @@ -1410,14 +1484,15 @@ def _renormalize(self): value = par.value*self._norm[1] _min = par.min _max = par.max - if not np.isinf(_min) and abs(_min) != float_min: + if not np.isinf(_min) and abs(_min) != FLOAT_MIN: _min *= self._norm[1] - if not np.isinf(_max) and abs(_max) != float_min: + if not np.isinf(_max) and abs(_max) != FLOAT_MIN: _max *= self._norm[1] par.set(value=value, min=_min, max=_max) if self._result is None: return - self._result.best_fit = self._result.best_fit*self._norm[1]+self._norm[0] 
+ self._result.best_fit = (self._result.best_fit*self._norm[1] + + self._norm[0]) for name, par in self._result.params.items(): if self._parameter_norms.get(name, False): if par.stderr is not None: @@ -1428,17 +1503,19 @@ def _renormalize(self): value = par.value*self._norm[1] if par.init_value is not None: par.init_value *= self._norm[1] - if not np.isinf(_min) and abs(_min) != float_min: + if not np.isinf(_min) and abs(_min) != FLOAT_MIN: _min *= self._norm[1] - if not np.isinf(_max) and abs(_max) != float_min: + if not np.isinf(_max) and abs(_max) != FLOAT_MIN: _max *= self._norm[1] par.set(value=value, min=_min, max=_max) if hasattr(self._result, 'init_fit'): - self._result.init_fit = self._result.init_fit*self._norm[1]+self._norm[0] + self._result.init_fit = (self._result.init_fit*self._norm[1] + + self._norm[0]) if hasattr(self._result, 'init_values'): init_values = {} for name, value in self._result.init_values.items(): - if name not in self._parameter_norms or self._parameters[name].expr is not None: + if (name not in self._parameter_norms + or self._parameters[name].expr is not None): init_values[name] = value elif self._parameter_norms[name]: init_values[name] = value*self._norm[1] @@ -1449,14 +1526,15 @@ def _renormalize(self): _min = par.min _max = par.max value *= self._norm[1] - if not np.isinf(_min) and abs(_min) != float_min: + if not np.isinf(_min) and abs(_min) != FLOAT_MIN: _min *= self._norm[1] - if not np.isinf(_max) and abs(_max) != float_min: + if not np.isinf(_max) and abs(_max) != FLOAT_MIN: _max *= self._norm[1] par.set(value=value, min=_min, max=_max) par.init_value = par.value - # Don't renormalize chisqr, it has no useful meaning in physical units - #self._result.chisqr *= self._norm[1]*self._norm[1] + # Don't renormalize chisqr, it has no useful meaning in + # physical units +# self._result.chisqr *= self._norm[1]*self._norm[1] if self._result.covar is not None: for i, name in enumerate(self._result.var_names): if self._parameter_norms.get(name, False): @@ -1465,13 +1543,14 @@ def _renormalize(self): self._result.covar[i,j] *= self._norm[1] if self._result.covar[j,i] is not None: self._result.covar[j,i] *= self._norm[1] - # Don't renormalize redchi, it has no useful meaning in physical units - #self._result.redchi *= self._norm[1]*self._norm[1] + # Don't renormalize redchi, it has no useful meaning in + # physical units +# self._result.redchi *= self._norm[1]*self._norm[1] if self._result.residual is not None: self._result.residual *= self._norm[1] def _reset_par_at_boundary(self, par, value): - assert(par.vary) + assert par.vary name = par.name _min = self._parameter_bounds[name]['min'] _max = self._parameter_bounds[name]['max'] @@ -1484,151 +1563,165 @@ def _reset_par_at_boundary(self, par, value): else: upp = _max-0.1*abs(_max) if value >= upp: - return(upp) + return upp else: if np.isinf(_max): if self._parameter_norms.get(name, False): - low = _min+0.1*self._y_range + low = _min + 0.1*self._y_range elif _min == 0.0: low = _min+0.1 else: - low = _min+0.1*abs(_min) + low = _min + 0.1*abs(_min) if value <= low: - return(low) + return low else: - low = 0.9*_min+0.1*_max - upp = 0.1*_min+0.9*_max + low = 0.9*_min + 0.1*_max + upp = 0.1*_min + 0.9*_max if value <= low: - return(low) - elif value >= upp: - return(upp) - return(value) + return low + if value >= upp: + return upp + return value class FitMultipeak(Fit): - """Fit data with multiple peaks + """ + Wrapper to the Fit class to fit data with multiple peaks """ def __init__(self, y, x=None, 
normalize=True): + """Initialize FitMultipeak.""" super().__init__(y, x=x, normalize=normalize) self._fwhm_max = None self._sigma_max = None @classmethod - def fit_multipeak(cls, y, centers, x=None, normalize=True, peak_models='gaussian', + def fit_multipeak( + cls, y, centers, x=None, normalize=True, peak_models='gaussian', center_exprs=None, fit_type=None, background=None, fwhm_max=None, print_report=False, plot=False, x_eval=None): - """Make sure that centers and fwhm_max are in the correct units and consistent with expr - for a uniform fit (fit_type == 'uniform') + """Class method for FitMultipeak. + + Make sure that centers and fwhm_max are in the correct units + and consistent with expr for a uniform fit (fit_type == + 'uniform'). """ - if x_eval is not None and not isinstance(x_eval, (tuple, list, np.ndarray)): + if (x_eval is not None + and not isinstance(x_eval, (tuple, list, np.ndarray))): raise ValueError(f'Invalid parameter x_eval ({x_eval})') fit = cls(y, x=x, normalize=normalize) - success = fit.fit(centers, fit_type=fit_type, peak_models=peak_models, fwhm_max=fwhm_max, - center_exprs=center_exprs, background=background, print_report=print_report, - plot=plot) + success = fit.fit( + centers, fit_type=fit_type, peak_models=peak_models, + fwhm_max=fwhm_max, center_exprs=center_exprs, + background=background, print_report=print_report, plot=plot) if x_eval is None: best_fit = fit.best_fit else: best_fit = fit.eval(x_eval) if success: - return(best_fit, fit.residual, fit.best_values, fit.best_errors, fit.redchi, \ - fit.success) - else: - return(np.array([]), np.array([]), {}, {}, float_max, False) - - def fit(self, centers, fit_type=None, peak_models=None, center_exprs=None, fwhm_max=None, - background=None, print_report=False, plot=True, param_constraint=False): + return ( + best_fit, fit.residual, fit.best_values, fit.best_errors, + fit.redchi, fit.success) + return np.array([]), np.array([]), {}, {}, FLOAT_MAX, False + + def fit( + self, centers=None, fit_type=None, peak_models=None, + center_exprs=None, fwhm_max=None, background=None, + print_report=False, plot=True, param_constraint=False, **kwargs): + """Fit the model to the input data.""" + if centers is None: + raise ValueError('Missing required parameter centers') + if not isinstance(centers, (int, float, tuple, list)): + raise ValueError(f'Invalid parameter centers ({centers})') self._fwhm_max = fwhm_max - # Create the multipeak model - self._create_model(centers, fit_type, peak_models, center_exprs, background, - param_constraint) - - # RV: Obsolete Normalize the data and results -# print('\nBefore fit before normalization in FitMultipeak:') -# self._parameters.pretty_print() -# self._normalize() -# print('\nBefore fit after normalization in FitMultipeak:') -# self._parameters.pretty_print() + self._create_model( + centers, fit_type, peak_models, center_exprs, background, + param_constraint) # Perform the fit try: if param_constraint: - super().fit(fit_kws={'xtol': 1.e-5, 'ftol': 1.e-5, 'gtol': 1.e-5}) + super().fit( + fit_kws={'xtol': 1.e-5, 'ftol': 1.e-5, 'gtol': 1.e-5}) else: super().fit() except: - return(False) + return False # Check for valid fit parameter results fit_failure = self._check_validity() success = True if fit_failure: if param_constraint: - logging.warning(' -> Should not happen with param_constraint set, fail the fit') + logger.warning( + ' -> Should not happen with param_constraint set, ' + + 'fail the fit') success = False else: - logging.info(' -> Retry fitting with constraints') - 
self.fit(centers, fit_type, peak_models, center_exprs, fwhm_max=fwhm_max, - background=background, print_report=print_report, plot=plot, - param_constraint=True) + logger.info(' -> Retry fitting with constraints') + self.fit( + centers, fit_type, peak_models, center_exprs, + fwhm_max=fwhm_max, background=background, + print_report=print_report, plot=plot, + param_constraint=True) else: - # RV: Obsolete Renormalize the data and results -# print('\nAfter fit before renormalization in FitMultipeak:') -# self._parameters.pretty_print() -# self.print_fit_report() -# self._renormalize() -# print('\nAfter fit after renormalization in FitMultipeak:') -# self._parameters.pretty_print() -# self.print_fit_report() - # Print report and plot components if requested if print_report: self.print_fit_report() if plot: - self.plot(skip_init=True, plot_comp=True, plot_comp_legends=True, - plot_residual=True) + self.plot( + skip_init=True, plot_comp=True, plot_comp_legends=True, + plot_residual=True) - return(success) + return success + + def _create_model( + self, centers, fit_type=None, peak_models=None, center_exprs=None, + background=None, param_constraint=False): + """Create the multipeak model.""" + # Third party modules + from asteval import Interpreter - def _create_model(self, centers, fit_type=None, peak_models=None, center_exprs=None, - background=None, param_constraint=False): - """Create the multipeak model - """ if isinstance(centers, (int, float)): centers = [centers] num_peaks = len(centers) if peak_models is None: peak_models = num_peaks*['gaussian'] - elif isinstance(peak_models, str) and peak_models in ('gaussian', 'lorentzian'): + elif (isinstance(peak_models, str) + and peak_models in ('gaussian', 'lorentzian')): peak_models = num_peaks*[peak_models] else: - raise ValueError(f'Invalid peak model parameter ({peak_models})') + raise ValueError(f'Invalid parameter peak model ({peak_models})') if len(peak_models) != num_peaks: - raise ValueError(f'Inconsistent number of peaks in peak_models ({len(peak_models)} vs '+ - f'{num_peaks})') + raise ValueError( + 'Inconsistent number of peaks in peak_models ' + + f'({len(peak_models)} vs {num_peaks})') if num_peaks == 1: if fit_type is not None: - logging.debug('Ignoring fit_type input for fitting one peak') + logger.debug('Ignoring fit_type input for fitting one peak') fit_type = None if center_exprs is not None: - logging.debug('Ignoring center_exprs input for fitting one peak') + logger.debug( + 'Ignoring center_exprs input for fitting one peak') center_exprs = None else: if fit_type == 'uniform': if center_exprs is None: center_exprs = [f'scale_factor*{cen}' for cen in centers] if len(center_exprs) != num_peaks: - raise ValueError(f'Inconsistent number of peaks in center_exprs '+ - f'({len(center_exprs)} vs {num_peaks})') + raise ValueError( + 'Inconsistent number of peaks in center_exprs ' + + f'({len(center_exprs)} vs {num_peaks})') elif fit_type == 'unconstrained' or fit_type is None: if center_exprs is not None: - logging.warning('Ignoring center_exprs input for unconstrained fit') + logger.warning( + 'Ignoring center_exprs input for unconstrained fit') center_exprs = None else: - raise ValueError(f'Invalid fit_type in fit_multigaussian {fit_type}') + raise ValueError( + f'Invalid parameter fit_type ({fit_type})') self._sigma_max = None if param_constraint: - min_value = float_min + min_value = FLOAT_MIN if self._fwhm_max is not None: self._sigma_max = np.zeros(num_peaks) else: @@ -1648,17 +1741,23 @@ def _create_model(self, centers, 
fit_type=None, peak_models=None, center_exprs=N elif is_dict_series(background): for model in deepcopy(background): if 'model' not in model: - raise KeyError(f'Missing keyword "model" in model in background ({model})') + raise KeyError( + 'Missing keyword "model" in model in background ' + + f'({model})') name = model.pop('model') parameters=model.pop('parameters', None) - self.add_model(name, prefix=f'bkgd_{name}_', parameters=parameters, **model) + self.add_model( + name, prefix=f'bkgd_{name}_', parameters=parameters, + **model) else: - raise ValueError(f'Invalid parameter background ({background})') + raise ValueError( + f'Invalid parameter background ({background})') # Add peaks and guess initial fit parameters ast = Interpreter() if num_peaks == 1: - height_init, cen_init, fwhm_init = self.guess_init_peak(self._x, self._y) + height_init, cen_init, fwhm_init = self.guess_init_peak( + self._x, self._y) if self._fwhm_max is not None and fwhm_init > self._fwhm_max: fwhm_init = self._fwhm_max ast(f'fwhm = {fwhm_init}') @@ -1670,16 +1769,20 @@ def _create_model(self, centers, fit_type=None, peak_models=None, center_exprs=N ast(f'fwhm = {self._fwhm_max}') sig_max = ast(fwhm_factor[peak_models[0]]) self._sigma_max[0] = sig_max - self.add_model(peak_models[0], parameters=( + self.add_model( + peak_models[0], + parameters=( {'name': 'amplitude', 'value': amp_init, 'min': min_value}, {'name': 'center', 'value': cen_init, 'min': min_value}, - {'name': 'sigma', 'value': sig_init, 'min': min_value, 'max': sig_max})) + {'name': 'sigma', 'value': sig_init, 'min': min_value, + 'max': sig_max}, + )) else: if fit_type == 'uniform': self.add_parameter(name='scale_factor', value=1.0) for i in range(num_peaks): - height_init, cen_init, fwhm_init = self.guess_init_peak(self._x, self._y, i, - center_guess=centers) + height_init, cen_init, fwhm_init = self.guess_init_peak( + self._x, self._y, i, center_guess=centers) if self._fwhm_max is not None and fwhm_init > self._fwhm_max: fwhm_init = self._fwhm_max ast(f'fwhm = {fwhm_init}') @@ -1692,61 +1795,81 @@ def _create_model(self, centers, fit_type=None, peak_models=None, center_exprs=N sig_max = ast(fwhm_factor[peak_models[i]]) self._sigma_max[i] = sig_max if fit_type == 'uniform': - self.add_model(peak_models[i], prefix=f'peak{i+1}_', parameters=( - {'name': 'amplitude', 'value': amp_init, 'min': min_value}, + self.add_model( + peak_models[i], prefix=f'peak{i+1}_', + parameters=( + {'name': 'amplitude', 'value': amp_init, + 'min': min_value}, {'name': 'center', 'expr': center_exprs[i]}, - {'name': 'sigma', 'value': sig_init, 'min': min_value, - 'max': sig_max})) + {'name': 'sigma', 'value': sig_init, + 'min': min_value, 'max': sig_max}, + )) else: - self.add_model('gaussian', prefix=f'peak{i+1}_', parameters=( - {'name': 'amplitude', 'value': amp_init, 'min': min_value}, - {'name': 'center', 'value': cen_init, 'min': min_value}, - {'name': 'sigma', 'value': sig_init, 'min': min_value, - 'max': sig_max})) + self.add_model( + 'gaussian', + prefix=f'peak{i+1}_', + parameters=( + {'name': 'amplitude', 'value': amp_init, + 'min': min_value}, + {'name': 'center', 'value': cen_init, + 'min': min_value}, + {'name': 'sigma', 'value': sig_init, + 'min': min_value, 'max': sig_max}, + )) def _check_validity(self): - """Check for valid fit parameter results - """ + """Check for valid fit parameter results.""" fit_failure = False - index = compile(r'\d+') - for name, par in self.best_parameters.items(): + index = re_compile(r'\d+') + for name, par in 
self.best_parameters().items():
             if 'bkgd' in name:
-#                if ((name == 'bkgd_c' and par['value'] <= 0.0) or
-#                        (name.endswith('amplitude') and par['value'] <= 0.0) or
-                if ((name.endswith('amplitude') and par['value'] <= 0.0) or
-                        (name.endswith('decay') and par['value'] <= 0.0)):
-                    logging.info(f'Invalid fit result for {name} ({par["value"]})')
+                if ((name.endswith('amplitude') and par['value'] <= 0.0)
+                        or (name.endswith('decay') and par['value'] <= 0.0)):
+                    logger.info(
+                        f'Invalid fit result for {name} ({par["value"]})')
                     fit_failure = True
-            elif (((name.endswith('amplitude') or name.endswith('height')) and
-                    par['value'] <= 0.0) or
-                    ((name.endswith('sigma') or name.endswith('fwhm')) and par['value'] <= 0.0) or
-                    (name.endswith('center') and par['value'] <= 0.0) or
-                    (name == 'scale_factor' and par['value'] <= 0.0)):
-                logging.info(f'Invalid fit result for {name} ({par["value"]})')
+            elif (((name.endswith('amplitude') or name.endswith('height'))
+                    and par['value'] <= 0.0)
+                    or ((name.endswith('sigma') or name.endswith('fwhm'))
+                        and par['value'] <= 0.0)
+                    or (name.endswith('center') and par['value'] <= 0.0)
+                    or (name == 'scale_factor' and par['value'] <= 0.0)):
+                logger.info(f'Invalid fit result for {name} ({par["value"]})')
                 fit_failure = True
-            if 'bkgd' not in name and name.endswith('sigma') and self._sigma_max is not None:
+            if ('bkgd' not in name and name.endswith('sigma')
+                    and self._sigma_max is not None):
                 if name == 'sigma':
                     sigma_max = self._sigma_max[0]
                 else:
-                    sigma_max = self._sigma_max[int(index.search(name).group())-1]
+                    sigma_max = self._sigma_max[
+                        int(index.search(name).group())-1]
                 if par['value'] > sigma_max:
-                    logging.info(f'Invalid fit result for {name} ({par["value"]})')
+                    logger.info(
+                        f'Invalid fit result for {name} ({par["value"]})')
                     fit_failure = True
                 elif par['value'] == sigma_max:
-                    logging.warning(f'Edge result on for {name} ({par["value"]})')
-            if 'bkgd' not in name and name.endswith('fwhm') and self._fwhm_max is not None:
+                    logger.warning(
+                        f'Edge result for {name} ({par["value"]})')
+            if ('bkgd' not in name and name.endswith('fwhm')
+                    and self._fwhm_max is not None):
                 if par['value'] > self._fwhm_max:
-                    logging.info(f'Invalid fit result for {name} ({par["value"]})')
+                    logger.info(
+                        f'Invalid fit result for {name} ({par["value"]})')
                     fit_failure = True
                 elif par['value'] == self._fwhm_max:
-                    logging.warning(f'Edge result on for {name} ({par["value"]})')
-        return(fit_failure)
+                    logger.warning(
+                        f'Edge result for {name} ({par["value"]})')
+        return fit_failure
 
 
 class FitMap(Fit):
-    """Fit a map of data
     """
-    def __init__(self, ymap, x=None, models=None, normalize=True, transpose=None, **kwargs):
+    Wrapper to the Fit class to fit data on an N-dimensional map
+    """
+    def __init__(
+            self, ymap, x=None, models=None, normalize=True, transpose=None,
+            **kwargs):
+        """Initialize FitMap."""
         super().__init__(None)
         self._best_errors = None
         self._best_fit = None
@@ -1766,64 +1889,77 @@ def __init__(self, ymap, x=None, models=None, normalize=True, transpose=None, **
         self._transpose = None
         self._try_no_bounds = True
 
-        # At this point the fastest index should always be the signal dimension so that the slowest
-        # ndim-1 dimensions are the map dimensions
+        # At this point the fastest index should always be the signal
+        # dimension so that the slowest ndim-1 dimensions are the
+        # map dimensions
        if isinstance(ymap, (tuple, list, np.ndarray)):
             self._x = np.asarray(x)
-        elif have_xarray and isinstance(ymap, xr.DataArray):
+        elif HAVE_XARRAY and isinstance(ymap, xr.DataArray):
             if x is not None:
-                logging.warning('Ignoring superfluous input x ({x}) in Fit.__init__')
+                logger.warning(f'Ignoring superfluous input x ({x})')
             self._x = np.asarray(ymap[ymap.dims[-1]])
         else:
-            illegal_value(ymap, 'ymap', 'FitMap:__init__', raise_error=True)
+            raise ValueError(f'Invalid parameter ymap ({ymap})')
         self._ymap = ymap
 
         # Verify the input parameters
         if self._x.ndim != 1:
             raise ValueError(f'Invalid dimension for input x {self._x.ndim}')
         if self._ymap.ndim < 2:
-            raise ValueError('Invalid number of dimension of the input dataset '+
-                f'{self._ymap.ndim}')
+            raise ValueError(
+                'Invalid number of dimension of the input dataset '
+                + f'{self._ymap.ndim}')
         if self._x.size != self._ymap.shape[-1]:
-            raise ValueError(f'Inconsistent x and y dimensions ({self._x.size} vs '+
-                f'{self._ymap.shape[-1]})')
+            raise ValueError(
+                f'Inconsistent x and y dimensions ({self._x.size} vs '
+                + f'{self._ymap.shape[-1]})')
         if not isinstance(normalize, bool):
-            logging.warning(f'Invalid value for normalize ({normalize}) in Fit.__init__: '+
-                'setting normalize to True')
+            logger.warning(
+                f'Invalid value for normalize ({normalize}) in Fit.__init__: '
+                + 'setting normalize to True')
             normalize = True
         if isinstance(transpose, bool) and not transpose:
             transpose = None
         if transpose is not None and self._ymap.ndim < 3:
-            logging.warning(f'Transpose meaningless for {self._ymap.ndim-1}D data maps: ignoring '+
-                'transpose')
+            logger.warning(
+                f'Transpose meaningless for {self._ymap.ndim-1}D data maps: '
+                + 'ignoring transpose')
         if transpose is not None:
-            if self._ymap.ndim == 3 and isinstance(transpose, bool) and transpose:
+            if (self._ymap.ndim == 3 and isinstance(transpose, bool)
+                    and transpose):
                 self._transpose = (1, 0)
            elif not isinstance(transpose, (tuple, list)):
-                logging.warning(f'Invalid data type for transpose ({transpose}, '+
-                    f'{type(transpose)}) in Fit.__init__: setting transpose to False')
-            elif len(transpose) != self._ymap.ndim-1:
-                logging.warning(f'Invalid dimension for transpose ({transpose}, must be equal to '+
-                    f'{self._ymap.ndim-1}) in Fit.__init__: setting transpose to False')
+                logger.warning(
+                    f'Invalid data type for transpose ({transpose}, '
+                    + f'{type(transpose)}): setting transpose to False')
+            elif len(transpose) != self._ymap.ndim-1:
+                logger.warning(
+                    f'Invalid dimension for transpose ({transpose}, must be '
+                    + f'equal to {self._ymap.ndim-1}): '
+                    + 'setting transpose to False')
             elif any(i not in transpose for i in range(len(transpose))):
-                logging.warning(f'Invalid index in transpose ({transpose}) '+
-                    f'in Fit.__init__: setting transpose to False')
+                logger.warning(
+                    f'Invalid index in transpose ({transpose}): '
+                    + 'setting transpose to False')
             elif not all(i==transpose[i] for i in range(self._ymap.ndim-1)):
                 self._transpose = transpose
             if self._transpose is not None:
                 self._inv_transpose = tuple(self._transpose.index(i)
-                    for i in range(len(self._transpose)))
+                                            for i in range(len(self._transpose)))
 
         # Flatten the map (transpose if requested)
-        # Store the flattened map in self._ymap_norm, whether normalized or not
+        # Store the flattened map in self._ymap_norm, whether
+        # normalized or not
         if self._transpose is not None:
-            self._ymap_norm = np.transpose(np.asarray(self._ymap), list(self._transpose)+
-                [len(self._transpose)])
+            self._ymap_norm = np.transpose(
+                np.asarray(self._ymap),
+                list(self._transpose) + [len(self._transpose)])
        else:
             self._ymap_norm = np.asarray(self._ymap)
         self._map_dim = int(self._ymap_norm.size/self._x.size)
         self._map_shape = self._ymap_norm.shape[:-1]
-        self._ymap_norm = np.reshape(self._ymap_norm, (self._map_dim, self._x.size))
+        self._ymap_norm = np.reshape(
+            self._ymap_norm, (self._map_dim, self._x.size))
 
         # Check if a mask is provided
         if 'mask' in kwargs:
@@ -1834,8 +1970,9 @@ def __init__(self, ymap, x=None, models=None, normalize=True, transpose=None, **
         else:
             self._mask = np.asarray(self._mask).astype(bool)
             if self._x.size != self._mask.size:
-                raise ValueError(f'Inconsistent mask dimension ({self._x.size} vs '+
-                    f'{self._mask.size})')
+                raise ValueError(
+                    f'Inconsistent mask dimension ({self._x.size} vs '
+                    + f'{self._mask.size})')
             ymap_masked = np.asarray(self._ymap_norm)[:,~self._mask]
             ymap_min = float(ymap_masked.min())
             ymap_max = float(ymap_masked.max())
@@ -1844,7 +1981,7 @@ def __init__(self, ymap, x=None, models=None, normalize=True, transpose=None, **
             self._y_range = ymap_max-ymap_min
             if normalize and self._y_range > 0.0:
                 self._norm = (ymap_min, self._y_range)
-                self._ymap_norm = (self._ymap_norm-self._norm[0])/self._norm[1]
+                self._ymap_norm = (self._ymap_norm-self._norm[0]) / self._norm[1]
            else:
                 self._redchi_cutoff *= self._y_range**2
         if models is not None:
@@ -1857,25 +1994,31 @@ def __init__(self, ymap, x=None, models=None, normalize=True, transpose=None, **
 
     @classmethod
     def fit_map(cls, ymap, models, x=None, normalize=True, **kwargs):
-        return(cls(ymap, x=x, models=models, normalize=normalize, **kwargs))
+        """Class method for FitMap."""
+        return cls(ymap, x=x, models=models, normalize=normalize, **kwargs)
 
     @property
     def best_errors(self):
-        return(self._best_errors)
+        """Return errors in the best fit parameters."""
+        return self._best_errors
 
     @property
     def best_fit(self):
-        return(self._best_fit)
+        """Return the best fits."""
+        return self._best_fit
 
     @property
     def best_results(self):
-        """Convert the input data array to a data set and add the fit results.
         """
-        if self.best_values is None or self.best_errors is None or self.best_fit is None:
-            return(None)
-        if not have_xarray:
-            logging.warning('Unable to load xarray module')
-            return(None)
+        Convert the input DataArray to a data set and add the fit
+        results.
+        """
+        if (self.best_values is None or self.best_errors is None
+                or self.best_fit is None):
+            return None
+        if not HAVE_XARRAY:
+            logger.warning('Unable to load xarray module')
+            return None
         best_values = self.best_values
         best_errors = self.best_errors
         if isinstance(self._ymap, xr.DataArray):
@@ -1883,8 +2026,9 @@ def best_results(self):
             dims = self._ymap.dims
             fit_name = f'{self._ymap.name}_fit'
         else:
-            coords = {f'dim{n}_index':([f'dim{n}_index'], range(self._ymap.shape[n]))
-                for n in range(self._ymap.ndim-1)}
+            coords = {
+                f'dim{n}_index':([f'dim{n}_index'], range(self._ymap.shape[n]))
+                for n in range(self._ymap.ndim-1)}
             coords['x'] = (['x'], self._x)
             dims = list(coords.keys())
         best_results = xr.Dataset(coords=coords)
@@ -1894,26 +2038,31 @@ def best_results(self):
         if self._mask is not None:
             best_results['mask'] = self._mask
         for n in range(best_values.shape[0]):
-            best_results[f'{self._best_parameters[n]}_values'] = (dims[:-1], best_values[n])
-            best_results[f'{self._best_parameters[n]}_errors'] = (dims[:-1], best_errors[n])
+            best_results[f'{self._best_parameters[n]}_values'] = \
+                (dims[:-1], best_values[n])
+            best_results[f'{self._best_parameters[n]}_errors'] = \
+                (dims[:-1], best_errors[n])
         best_results.attrs['components'] = self.components
-        return(best_results)
+        return best_results
 
     @property
     def best_values(self):
-        return(self._best_values)
+        """Return values of the best fit parameters."""
+        return self._best_values
 
     @property
     def chisqr(self):
-        logging.warning('property chisqr not defined for fit.FitMap')
-        return(None)
+        """Return the chisqr value of each best fit."""
+        logger.warning('Undefined property chisqr')
 
     @property
     def components(self):
+        """Return the fit model components info."""
         components = {}
         if self._result is None:
-            logging.warning('Unable to collect components in FitMap.components')
-            return(components)
+            logger.warning(
+                'Unable to collect components in FitMap.components')
+            return components
         for component in self._result.components:
             if 'tmp_normalization_offset_c' in component.param_names:
                 continue
@@ -1922,9 +2071,15 @@ def components(self):
                 if self._parameters[name].vary:
                     parameters[name] = {'free': True}
                 elif self._parameters[name].expr is not None:
-                    parameters[name] = {'free': False, 'expr': self._parameters[name].expr}
+                    parameters[name] = {
+                        'free': False,
+                        'expr': self._parameters[name].expr,
+                    }
                else:
-                    parameters[name] = {'free': False, 'value': self.init_parameters[name]['value']}
+                    parameters[name] = {
+                        'free': False,
+                        'value': self.init_parameters[name]['value'],
+                    }
             expr = None
             if isinstance(component, ExpressionModel):
                 name = component._name
@@ -1933,7 +2088,7 @@ def components(self):
                 expr = component.expr
             else:
                 prefix = component.prefix
-                if len(prefix):
+                if prefix:
                     if prefix[-1] == '_':
                         prefix = prefix[:-1]
                     name = f'{prefix} ({component._name})'
@@ -1943,70 +2098,89 @@ def components(self):
                 components[name] = {'parameters': parameters}
             else:
                 components[name] = {'expr': expr, 'parameters': parameters}
-        return(components)
+        return components
 
     @property
     def covar(self):
-        logging.warning('property covar not defined for fit.FitMap')
-        return(None)
+        """
+        Return the covariance matrices of the best fit parameters.
+        """
+        logger.warning('Undefined property covar')
 
     @property
     def max_nfev(self):
-        return(self._max_nfev)
+        """
+        Return the maximum number of function evaluations for each fit.
+        """
+        return self._max_nfev
 
     @property
     def num_func_eval(self):
-        logging.warning('property num_func_eval not defined for fit.FitMap')
-        return(None)
+        """
+        Return the number of function evaluations for each best fit.
+        """
+        logger.warning('Undefined property num_func_eval')
 
     @property
     def out_of_bounds(self):
-        return(self._out_of_bounds)
+        """Return the out_of_bounds value of each best fit."""
+        return self._out_of_bounds
 
     @property
     def redchi(self):
-        return(self._redchi)
+        """Return the redchi value of each best fit."""
+        return self._redchi
 
     @property
     def residual(self):
+        """Return the residual in each best fit."""
         if self.best_fit is None:
-            return(None)
+            return None
         if self._mask is None:
-            return(np.asarray(self._ymap)-self.best_fit)
+            residual = np.asarray(self._ymap)-self.best_fit
         else:
-            ymap_flat = np.reshape(np.asarray(self._ymap), (self._map_dim, self._x.size))
+            ymap_flat = np.reshape(
+                np.asarray(self._ymap), (self._map_dim, self._x.size))
             ymap_flat_masked = ymap_flat[:,~self._mask]
-            ymap_masked = np.reshape(ymap_flat_masked,
-                list(self._map_shape)+[ymap_flat_masked.shape[-1]])
-            return(ymap_masked-self.best_fit)
+            ymap_masked = np.reshape(
+                ymap_flat_masked,
+                list(self._map_shape) + [ymap_flat_masked.shape[-1]])
+            residual = ymap_masked-self.best_fit
+        return residual
 
     @property
     def success(self):
-        return(self._success)
+        """Return the success value for each fit."""
+        return self._success
 
     @property
     def var_names(self):
-        logging.warning('property var_names not defined for fit.FitMap')
-        return(None)
+        """
+        Return the variable names for the covariance matrix property.
+        """
+        logger.warning('Undefined property var_names')
 
     @property
     def y(self):
-        logging.warning('property y not defined for fit.FitMap')
-        return(None)
+        """Return the input y-array."""
+        logger.warning('Undefined property y')
 
     @property
     def ymap(self):
-        return(self._ymap)
+        """Return the input y-array map."""
+        return self._ymap
 
     def best_parameters(self, dims=None):
+        """Return the best fit parameters."""
         if dims is None:
-            return(self._best_parameters)
-        if not isinstance(dims, (list, tuple)) or len(dims) != len(self._map_shape):
-            illegal_value(dims, 'dims', 'FitMap.best_parameters', raise_error=True)
+            return self._best_parameters
+        if (not isinstance(dims, (list, tuple))
+                or len(dims) != len(self._map_shape)):
+            raise ValueError(f'Invalid parameter dims ({dims})')
         if self.best_values is None or self.best_errors is None:
-            logging.warning(f'Unable to obtain best parameter values for dims = {dims} in '+
-                'FitMap.best_parameters')
+            logger.warning(
+                f'Unable to obtain best parameter values for dims = {dims}')
-            return({})
+            return {}
         # Create current parameters
         parameters = deepcopy(self._parameters)
         for n, name in enumerate(self._best_parameters):
@@ -2017,26 +2191,40 @@ def best_parameters(self, dims=None):
         for name in sorted(parameters):
             if name != 'tmp_normalization_offset_c':
                 par = parameters[name]
-                parameters_dict[name] = {'value': par.value, 'error': par.stderr,
-                    'init_value': self.init_parameters[name]['value'], 'min': par.min,
-                    'max': par.max, 'vary': par.vary, 'expr': par.expr}
+                parameters_dict[name] = {
+                    'value': par.value,
+                    'error': par.stderr,
+                    'init_value': self.init_parameters[name]['value'],
+                    'min': par.min,
+                    'max': par.max,
+                    'vary': par.vary,
+                    'expr': par.expr,
+                }
-        return(parameters_dict)
+        return parameters_dict
 
     def freemem(self):
+        """Free memory allocated for parallel processing."""
         if self._memfolder is None:
             return
         try:
             rmtree(self._memfolder)
            self._memfolder = 
None except: - logging.warning('Could not clean-up automatically.') - - def plot(self, dims, y_title=None, plot_residual=False, plot_comp_legends=False, - plot_masked_data=True): - if not isinstance(dims, (list, tuple)) or len(dims) != len(self._map_shape): - illegal_value(dims, 'dims', 'FitMap.plot', raise_error=True) - if self._result is None or self.best_fit is None or self.best_values is None: - logging.warning(f'Unable to plot fit for dims = {dims} in FitMap.plot') + logger.warning('Could not clean-up automatically.') + + def plot( + self, dims=None, y_title=None, plot_residual=False, + plot_comp_legends=False, plot_masked_data=True, **kwargs): + """Plot the best fits.""" + if dims is None: + dims = [0]*len(self._map_shape) + if (not isinstance(dims, (list, tuple)) + or len(dims) != len(self._map_shape)): + raise ValueError('Invalid parameter dims ({dims})') + if (self._result is None or self.best_fit is None + or self.best_values is None): + logger.warning( + f'Unable to plot fit for dims = {dims}') return if y_title is None or not isinstance(y_title, str): y_title = 'data' @@ -2048,7 +2236,8 @@ def plot(self, dims, y_title=None, plot_residual=False, plot_comp_legends=False, plots = [(self._x, np.asarray(self._ymap[dims]), 'b.')] legend = [y_title] if plot_masked_data: - plots += [(self._x[mask], np.asarray(self._ymap)[(*dims,mask)], 'bx')] + plots += \ + [(self._x[mask], np.asarray(self._ymap)[(*dims,mask)], 'bx')] legend += ['masked data'] plots += [(self._x[~mask], self.best_fit[dims], 'k-')] legend += ['best fit'] @@ -2059,8 +2248,9 @@ def plot(self, dims, y_title=None, plot_residual=False, plot_comp_legends=False, parameters = deepcopy(self._parameters) for name in self._best_parameters: if self._parameters[name].vary: - parameters[name].set(value= - self.best_values[self._best_parameters.index(name)][dims]) + parameters[name].set( + value=self.best_values[self._best_parameters.index(name)] + [dims]) for component in self._result.components: if 'tmp_normalization_offset_c' in component.param_names: continue @@ -2071,7 +2261,7 @@ def plot(self, dims, y_title=None, plot_residual=False, plot_comp_legends=False, modelname = f'{prefix}: {component.expr}' else: prefix = component.prefix - if len(prefix): + if prefix: if prefix[-1] == '_': prefix = prefix[:-1] modelname = f'{prefix} ({component._name})' @@ -2085,46 +2275,61 @@ def plot(self, dims, y_title=None, plot_residual=False, plot_comp_legends=False, plots += [(self._x[~mask], y, '--')] if plot_comp_legends: legend.append(modelname) - quick_plot(tuple(plots), legend=legend, title=str(dims), block=True) + quick_plot( + tuple(plots), legend=legend, title=str(dims), block=True, **kwargs) def fit(self, **kwargs): -# t0 = time() + """Fit the model to the input data.""" # Check input parameters if self._model is None: - logging.error('Undefined fit model') + logger.error('Undefined fit model') if 'num_proc' in kwargs: num_proc = kwargs.pop('num_proc') if not is_int(num_proc, ge=1): - illegal_value(num_proc, 'num_proc', 'FitMap.fit', raise_error=True) + raise ValueError( + 'Invalid value for keyword argument num_proc ({num_proc})') else: num_proc = cpu_count() - if num_proc > 1 and not have_joblib: - logging.warning(f'Missing joblib in the conda environment, running FitMap serially') + if num_proc > 1 and not HAVE_JOBLIB: + logger.warning( + 'Missing joblib in the conda environment, running serially') num_proc = 1 if num_proc > cpu_count(): - logging.warning(f'The requested number of processors ({num_proc}) exceeds the maximum '+ - 
f'number of processors, num_proc reduced to ({cpu_count()})') + logger.warning( + f'The requested number of processors ({num_proc}) exceeds the ' + + 'maximum number of processors, num_proc reduced to ' + + f'({cpu_count()})') num_proc = cpu_count() if 'try_no_bounds' in kwargs: self._try_no_bounds = kwargs.pop('try_no_bounds') if not isinstance(self._try_no_bounds, bool): - illegal_value(self._try_no_bounds, 'try_no_bounds', 'FitMap.fit', raise_error=True) + raise ValueError( + 'Invalid value for keyword argument try_no_bounds ' + + f'({self._try_no_bounds})') if 'redchi_cutoff' in kwargs: self._redchi_cutoff = kwargs.pop('redchi_cutoff') if not is_num(self._redchi_cutoff, gt=0): - illegal_value(self._redchi_cutoff, 'redchi_cutoff', 'FitMap.fit', raise_error=True) + raise ValueError( + 'Invalid value for keyword argument redchi_cutoff' + + f'({self._redchi_cutoff})') if 'print_report' in kwargs: self._print_report = kwargs.pop('print_report') if not isinstance(self._print_report, bool): - illegal_value(self._print_report, 'print_report', 'FitMap.fit', raise_error=True) + raise ValueError( + 'Invalid value for keyword argument print_report' + + f'({self._print_report})') if 'plot' in kwargs: self._plot = kwargs.pop('plot') if not isinstance(self._plot, bool): - illegal_value(self._plot, 'plot', 'FitMap.fit', raise_error=True) + raise ValueError( + 'Invalid value for keyword argument plot' + + f'({self._plot})') if 'skip_init' in kwargs: self._skip_init = kwargs.pop('skip_init') if not isinstance(self._skip_init, bool): - illegal_value(self._skip_init, 'skip_init', 'FitMap.fit', raise_error=True) + raise ValueError( + 'Invalid value for keyword argument skip_init' + + f'({self._skip_init})') # Apply mask if supplied: if 'mask' in kwargs: @@ -2132,51 +2337,58 @@ def fit(self, **kwargs): if self._mask is not None: self._mask = np.asarray(self._mask).astype(bool) if self._x.size != self._mask.size: - raise ValueError(f'Inconsistent x and mask dimensions ({self._x.size} vs '+ - f'{self._mask.size})') + raise ValueError( + f'Inconsistent x and mask dimensions ({self._x.size} vs ' + + f'{self._mask.size})') # Add constant offset for a normalized single component model if self._result is None and self._norm is not None and self._norm[0]: - self.add_model('constant', prefix='tmp_normalization_offset_', parameters={'name': 'c', - 'value': -self._norm[0], 'vary': False, 'norm': True}) - #'value': -self._norm[0]/self._norm[1], 'vary': False, 'norm': False}) + self.add_model( + 'constant', + prefix='tmp_normalization_offset_', + parameters={ + 'name': 'c', + 'value': -self._norm[0], + 'vary': False, + 'norm': True, +# 'value': -self._norm[0]/self._norm[1], +# 'vary': False, +# 'norm': False, + }) # Adjust existing parameters for refit: if 'parameters' in kwargs: -# print('\nIn FitMap before adjusting existing parameters for refit:') -# self._parameters.pretty_print() -# if self._result is None: -# raise ValueError('Invalid parameter parameters ({parameters})') -# if self._best_values is None: -# raise ValueError('Valid self._best_values required for refitting in FitMap.fit') parameters = kwargs.pop('parameters') -# print(f'\nparameters:\n{parameters}') if isinstance(parameters, dict): parameters = (parameters, ) elif not is_dict_series(parameters): - illegal_value(parameters, 'parameters', 'Fit.fit', raise_error=True) + raise ValueError( + 'Invalid value for keyword argument parameters' + + f'({parameters})') for par in parameters: name = par['name'] if name not in self._parameters: - raise 
ValueError(f'Unable to match {name} parameter {par} to an existing one') + raise ValueError( + f'Unable to match {name} parameter {par} to an ' + + 'existing one') if self._parameters[name].expr is not None: - raise ValueError(f'Unable to modify {name} parameter {par} (currently an '+ - 'expression)') + raise ValueError( + f'Unable to modify {name} parameter {par} ' + + '(currently an expression)') value = par.get('value') vary = par.get('vary') if par.get('expr') is not None: - raise KeyError(f'Invalid "expr" key in {name} parameter {par}') - self._parameters[name].set(value=value, vary=vary, min=par.get('min'), - max=par.get('max')) - # Overwrite existing best values for fixed parameters when a value is specified -# print(f'best values befored resetting:\n{self._best_values}') + raise KeyError( + f'Invalid "expr" key in {name} parameter {par}') + self._parameters[name].set( + value=value, vary=vary, min=par.get('min'), + max=par.get('max')) + # Overwrite existing best values for fixed parameters + # when a value is specified if isinstance(value, (int, float)) and vary is False: for i, nname in enumerate(self._best_parameters): if nname == name: self._best_values[i] = value -# print(f'best values after resetting (value={value}, vary={vary}):\n{self._best_values}') -#RV print('\nIn FitMap after adjusting existing parameters for refit:') -#RV self._parameters.pretty_print() # Check for uninitialized parameters for name, par in self._parameters.items(): @@ -2189,57 +2401,53 @@ def fit(self, **kwargs): elif self._parameter_norms[name]: self._parameters[name].set(value=value*self._norm[1]) - # Create the best parameter list, consisting of all varying parameters plus the expression - # parameters in order to collect their errors + # Create the best parameter list, consisting of all varying + # parameters plus the expression parameters in order to + # collect their errors if self._result is None: # Initial fit - assert(self._best_parameters is None) - self._best_parameters = [name for name, par in self._parameters.items() - if par.vary or par.expr is not None] + assert self._best_parameters is None + self._best_parameters = [ + name for name, par in self._parameters.items() + if par.vary or par.expr is not None] num_new_parameters = 0 else: # Refit - assert(len(self._best_parameters)) - self._new_parameters = [name for name, par in self._parameters.items() - if name != 'tmp_normalization_offset_c' and name not in self._best_parameters and - (par.vary or par.expr is not None)] + assert self._best_parameters + self._new_parameters = [ + name for name, par in self._parameters.items() + if name != 'tmp_normalization_offset_c' + and name not in self._best_parameters + and (par.vary or par.expr is not None)] num_new_parameters = len(self._new_parameters) num_best_parameters = len(self._best_parameters) - # Flatten and normalize the best values of the previous fit, remove the remaining results - # of the previous fit + # Flatten and normalize the best values of the previous fit, + # remove the remaining results of the previous fit if self._result is not None: -# print('\nBefore flatten and normalize:') -# print(f'self._best_values:\n{self._best_values}') self._out_of_bounds = None self._max_nfev = None self._redchi = None self._success = None self._best_fit = None self._best_errors = None - assert(self._best_values is not None) - assert(self._best_values.shape[0] == num_best_parameters) - assert(self._best_values.shape[1:] == self._map_shape) + assert self._best_values is not None + assert 
self._best_values.shape[0] == num_best_parameters + assert self._best_values.shape[1:] == self._map_shape if self._transpose is not None: self._best_values = np.transpose(self._best_values, [0]+[i+1 for i in self._transpose]) - self._best_values = [np.reshape(self._best_values[i], self._map_dim) + self._best_values = [ + np.reshape(self._best_values[i], self._map_dim) for i in range(num_best_parameters)] if self._norm is not None: for i, name in enumerate(self._best_parameters): if self._parameter_norms.get(name, False): self._best_values[i] /= self._norm[1] -#RV print('\nAfter flatten and normalize:') -#RV print(f'self._best_values:\n{self._best_values}') - # Normalize the initial parameters (and best values for a refit) -# print('\nIn FitMap before normalize:') -# self._parameters.pretty_print() -# print(f'\nparameter_norms:\n{self._parameter_norms}\n') + # Normalize the initial parameters + # (and best values for a refit) self._normalize() -# print('\nIn FitMap after normalize:') -# self._parameters.pretty_print() -# print(f'\nparameter_norms:\n{self._parameter_norms}\n') # Prevent initial values from sitting at boundaries self._parameter_bounds = {name:{'min': par.min, 'max': par.max} @@ -2247,10 +2455,9 @@ def fit(self, **kwargs): for name, par in self._parameters.items(): if par.vary: par.set(value=self._reset_par_at_boundary(par, par.value)) -# print('\nAfter checking boundaries:') -# self._parameters.pretty_print() - # Set parameter bounds to unbound (only use bounds when fit fails) + # Set parameter bounds to unbound + # (only use bounds when fit fails) if self._try_no_bounds: for name in self._parameter_bounds.keys(): self._parameters[name].set(min=-np.inf, max=np.inf) @@ -2267,116 +2474,124 @@ def fit(self, **kwargs): self._success_flat = np.zeros(self._map_dim, dtype=bool) self._best_fit_flat = np.zeros((self._map_dim, x_size), dtype=self._ymap_norm.dtype) - self._best_errors_flat = [np.zeros(self._map_dim, dtype=np.float64) - for _ in range(num_best_parameters+num_new_parameters)] + self._best_errors_flat = [ + np.zeros(self._map_dim, dtype=np.float64) + for _ in range(num_best_parameters+num_new_parameters)] if self._result is None: - self._best_values_flat = [np.zeros(self._map_dim, dtype=np.float64) - for _ in range(num_best_parameters)] + self._best_values_flat = [ + np.zeros(self._map_dim, dtype=np.float64) + for _ in range(num_best_parameters)] else: self._best_values_flat = self._best_values - self._best_values_flat += [np.zeros(self._map_dim, dtype=np.float64) - for _ in range(num_new_parameters)] + self._best_values_flat += [ + np.zeros(self._map_dim, dtype=np.float64) + for _ in range(num_new_parameters)] else: self._memfolder = './joblib_memmap' try: mkdir(self._memfolder) except FileExistsError: pass - filename_memmap = path.join(self._memfolder, 'out_of_bounds_memmap') - self._out_of_bounds_flat = np.memmap(filename_memmap, dtype=bool, - shape=(self._map_dim), mode='w+') + filename_memmap = path.join( + self._memfolder, 'out_of_bounds_memmap') + self._out_of_bounds_flat = np.memmap( + filename_memmap, dtype=bool, shape=(self._map_dim), mode='w+') filename_memmap = path.join(self._memfolder, 'max_nfev_memmap') - self._max_nfev_flat = np.memmap(filename_memmap, dtype=bool, - shape=(self._map_dim), mode='w+') + self._max_nfev_flat = np.memmap( + filename_memmap, dtype=bool, shape=(self._map_dim), mode='w+') filename_memmap = path.join(self._memfolder, 'redchi_memmap') - self._redchi_flat = np.memmap(filename_memmap, dtype=np.float64, - shape=(self._map_dim), 
mode='w+') + self._redchi_flat = np.memmap( + filename_memmap, dtype=np.float64, shape=(self._map_dim), + mode='w+') filename_memmap = path.join(self._memfolder, 'success_memmap') - self._success_flat = np.memmap(filename_memmap, dtype=bool, - shape=(self._map_dim), mode='w+') + self._success_flat = np.memmap( + filename_memmap, dtype=bool, shape=(self._map_dim), mode='w+') filename_memmap = path.join(self._memfolder, 'best_fit_memmap') - self._best_fit_flat = np.memmap(filename_memmap, dtype=self._ymap_norm.dtype, - shape=(self._map_dim, x_size), mode='w+') + self._best_fit_flat = np.memmap( + filename_memmap, dtype=self._ymap_norm.dtype, + shape=(self._map_dim, x_size), mode='w+') self._best_errors_flat = [] for i in range(num_best_parameters+num_new_parameters): - filename_memmap = path.join(self._memfolder, f'best_errors_memmap_{i}') - self._best_errors_flat.append(np.memmap(filename_memmap, dtype=np.float64, - shape=self._map_dim, mode='w+')) + filename_memmap = path.join( + self._memfolder, f'best_errors_memmap_{i}') + self._best_errors_flat.append( + np.memmap(filename_memmap, dtype=np.float64, + shape=self._map_dim, mode='w+')) self._best_values_flat = [] for i in range(num_best_parameters): - filename_memmap = path.join(self._memfolder, f'best_values_memmap_{i}') - self._best_values_flat.append(np.memmap(filename_memmap, dtype=np.float64, - shape=self._map_dim, mode='w+')) + filename_memmap = path.join( + self._memfolder, f'best_values_memmap_{i}') + self._best_values_flat.append( + np.memmap(filename_memmap, dtype=np.float64, + shape=self._map_dim, mode='w+')) if self._result is not None: self._best_values_flat[i][:] = self._best_values[i][:] for i in range(num_new_parameters): - filename_memmap = path.join(self._memfolder, - f'best_values_memmap_{i+num_best_parameters}') - self._best_values_flat.append(np.memmap(filename_memmap, dtype=np.float64, - shape=self._map_dim, mode='w+')) + filename_memmap = path.join( + self._memfolder, + f'best_values_memmap_{i+num_best_parameters}') + self._best_values_flat.append( + np.memmap(filename_memmap, dtype=np.float64, + shape=self._map_dim, mode='w+')) # Update the best parameter list if num_new_parameters: self._best_parameters += self._new_parameters - # Perform the first fit to get model component info and initial parameters + # Perform the first fit to get model component info and + # initial parameters current_best_values = {} -# print(f'0 before:\n{current_best_values}') -# t1 = time() - self._result = self._fit(0, current_best_values, return_result=True, **kwargs) -# t2 = time() -# print(f'0 after:\n{current_best_values}') -# print('\nAfter the first fit:') -# self._parameters.pretty_print() -# print(self._result.fit_report(show_correl=False)) + self._result = self._fit( + 0, current_best_values, return_result=True, **kwargs) # Remove all irrelevant content from self._result - for attr in ('_abort', 'aborted', 'aic', 'best_fit', 'best_values', 'bic', 'calc_covar', - 'call_kws', 'chisqr', 'ci_out', 'col_deriv', 'covar', 'data', 'errorbars', - 'flatchain', 'ier', 'init_vals', 'init_fit', 'iter_cb', 'jacfcn', 'kws', - 'last_internal_values', 'lmdif_message', 'message', 'method', 'nan_policy', - 'ndata', 'nfev', 'nfree', 'params', 'redchi', 'reduce_fcn', 'residual', 'result', - 'scale_covar', 'show_candidates', 'calc_covar', 'success', 'userargs', 'userfcn', - 'userkws', 'values', 'var_names', 'weights', 'user_options'): + for attr in ( + '_abort', 'aborted', 'aic', 'best_fit', 'best_values', 'bic', + 'calc_covar', 'call_kws', 'chisqr', 
'ci_out', 'col_deriv', + 'covar', 'data', 'errorbars', 'flatchain', 'ier', 'init_vals', + 'init_fit', 'iter_cb', 'jacfcn', 'kws', 'last_internal_values', + 'lmdif_message', 'message', 'method', 'nan_policy', 'ndata', + 'nfev', 'nfree', 'params', 'redchi', 'reduce_fcn', 'residual', + 'result', 'scale_covar', 'show_candidates', 'calc_covar', + 'success', 'userargs', 'userfcn', 'userkws', 'values', + 'var_names', 'weights', 'user_options'): try: delattr(self._result, attr) except AttributeError: -# logging.warning(f'Unknown attribute {attr} in fit.FtMap._cleanup_result') +# logger.warning(f'Unknown attribute {attr}') pass -# t3 = time() if num_proc == 1: # Perform the remaining fits serially for n in range(1, self._map_dim): -# print(f'{n} before:\n{current_best_values}') self._fit(n, current_best_values, **kwargs) -# print(f'{n} after:\n{current_best_values}') else: # Perform the remaining fits in parallel num_fit = self._map_dim-1 -# print(f'num_fit = {num_fit}') if num_proc > num_fit: - logging.warning(f'The requested number of processors ({num_proc}) exceeds the '+ - f'number of fits, num_proc reduced to ({num_fit})') + logger.warning( + f'The requested number of processors ({num_proc}) exceeds ' + + f'the number of fits, num_proc reduced to ({num_fit})') num_proc = num_fit num_fit_per_proc = 1 else: num_fit_per_proc = round((num_fit)/num_proc) if num_proc*num_fit_per_proc < num_fit: num_fit_per_proc +=1 -# print(f'num_fit_per_proc = {num_fit_per_proc}') num_fit_batch = min(num_fit_per_proc, 40) -# print(f'num_fit_batch = {num_fit_batch}') with Parallel(n_jobs=num_proc) as parallel: - parallel(delayed(self._fit_parallel)(current_best_values, num_fit_batch, - n_start, **kwargs) for n_start in range(1, self._map_dim, num_fit_batch)) -# t4 = time() + parallel( + delayed(self._fit_parallel) + (current_best_values, num_fit_batch, + n_start, **kwargs) + for n_start in range(1, self._map_dim, num_fit_batch)) # Renormalize the initial parameters for external use if self._norm is not None and self._normalized: init_values = {} for name, value in self._result.init_values.items(): - if name not in self._parameter_norms or self._parameters[name].expr is not None: + if (name not in self._parameter_norms + or self._parameters[name].expr is not None): init_values[name] = value elif self._parameter_norms[name]: init_values[name] = value*self._norm[1] @@ -2386,36 +2601,40 @@ def fit(self, **kwargs): _min = par.min _max = par.max value = par.value*self._norm[1] - if not np.isinf(_min) and abs(_min) != float_min: + if not np.isinf(_min) and abs(_min) != FLOAT_MIN: _min *= self._norm[1] - if not np.isinf(_max) and abs(_max) != float_min: + if not np.isinf(_max) and abs(_max) != FLOAT_MIN: _max *= self._norm[1] par.set(value=value, min=_min, max=_max) par.init_value = par.value # Remap the best results -# t5 = time() - self._out_of_bounds = np.copy(np.reshape(self._out_of_bounds_flat, self._map_shape)) - self._max_nfev = np.copy(np.reshape(self._max_nfev_flat, self._map_shape)) + self._out_of_bounds = np.copy(np.reshape( + self._out_of_bounds_flat, self._map_shape)) + self._max_nfev = np.copy(np.reshape( + self._max_nfev_flat, self._map_shape)) self._redchi = np.copy(np.reshape(self._redchi_flat, self._map_shape)) - self._success = np.copy(np.reshape(self._success_flat, self._map_shape)) - self._best_fit = np.copy(np.reshape(self._best_fit_flat, - list(self._map_shape)+[x_size])) - self._best_values = np.asarray([np.reshape(par, list(self._map_shape)) - for par in self._best_values_flat]) - self._best_errors 
= np.asarray([np.reshape(par, list(self._map_shape)) - for par in self._best_errors_flat]) + self._success = np.copy(np.reshape( + self._success_flat, self._map_shape)) + self._best_fit = np.copy(np.reshape( + self._best_fit_flat, list(self._map_shape)+[x_size])) + self._best_values = np.asarray([np.reshape( + par, list(self._map_shape)) for par in self._best_values_flat]) + self._best_errors = np.asarray([np.reshape( + par, list(self._map_shape)) for par in self._best_errors_flat]) if self._inv_transpose is not None: - self._out_of_bounds = np.transpose(self._out_of_bounds, self._inv_transpose) + self._out_of_bounds = np.transpose( + self._out_of_bounds, self._inv_transpose) self._max_nfev = np.transpose(self._max_nfev, self._inv_transpose) self._redchi = np.transpose(self._redchi, self._inv_transpose) self._success = np.transpose(self._success, self._inv_transpose) - self._best_fit = np.transpose(self._best_fit, - list(self._inv_transpose)+[len(self._inv_transpose)]) - self._best_values = np.transpose(self._best_values, - [0]+[i+1 for i in self._inv_transpose]) - self._best_errors = np.transpose(self._best_errors, - [0]+[i+1 for i in self._inv_transpose]) + self._best_fit = np.transpose( + self._best_fit, + list(self._inv_transpose) + [len(self._inv_transpose)]) + self._best_values = np.transpose( + self._best_values, [0] + [i+1 for i in self._inv_transpose]) + self._best_errors = np.transpose( + self._best_errors, [0] + [i+1 for i in self._inv_transpose]) del self._out_of_bounds_flat del self._max_nfev_flat del self._redchi_flat @@ -2423,7 +2642,6 @@ def fit(self, **kwargs): del self._best_fit_flat del self._best_values_flat del self._best_errors_flat -# t6 = time() # Restore parameter bounds and renormalize the parameters for name, par in self._parameter_bounds.items(): @@ -2436,20 +2654,11 @@ def fit(self, **kwargs): value = par.value*self._norm[1] _min = par.min _max = par.max - if not np.isinf(_min) and abs(_min) != float_min: + if not np.isinf(_min) and abs(_min) != FLOAT_MIN: _min *= self._norm[1] - if not np.isinf(_max) and abs(_max) != float_min: + if not np.isinf(_max) and abs(_max) != FLOAT_MIN: _max *= self._norm[1] par.set(value=value, min=_min, max=_max) -# t7 = time() -# print(f'total run time in fit: {t7-t0:.2f} seconds') -# print(f'run time first fit: {t2-t1:.2f} seconds') -# print(f'run time remaining fits: {t4-t3:.2f} seconds') -# print(f'run time remapping results: {t6-t5:.2f} seconds') - -# print('\n\nAt end fit:') -# self._parameters.pretty_print() -# print(f'self._best_values:\n{self._best_values}\n\n') # Free the shared memory self.freemem() @@ -2457,17 +2666,11 @@ def fit(self, **kwargs): def _fit_parallel(self, current_best_values, num, n_start, **kwargs): num = min(num, self._map_dim-n_start) for n in range(num): -# print(f'{n_start+n} before:\n{current_best_values}') self._fit(n_start+n, current_best_values, **kwargs) -# print(f'{n_start+n} after:\n{current_best_values}') def _fit(self, n, current_best_values, return_result=False, **kwargs): -#RV print(f'\n\nstart FitMap._fit {n}\n') -#RV print(f'current_best_values = {current_best_values}') -#RV print(f'self._best_parameters = {self._best_parameters}') -#RV print(f'self._new_parameters = {self._new_parameters}\n\n') -# self._parameters.pretty_print() - # Set parameters to current best values, but prevent them from sitting at boundaries + # Set parameters to current best values, but prevent them from + # sitting at boundaries if self._new_parameters is None: # Initial fit for name, value in 
current_best_values.items(): @@ -2479,19 +2682,17 @@ def _fit(self, n, current_best_values, return_result=False, **kwargs): par = self._parameters[name] if name in self._new_parameters: if name in current_best_values: - par.set(value=self._reset_par_at_boundary(par, current_best_values[name])) + par.set(value=self._reset_par_at_boundary( + par, current_best_values[name])) elif par.expr is None: par.set(value=self._best_values[i][n]) -#RV print(f'\nbefore fit {n}') -#RV self._parameters.pretty_print() if self._mask is None: - result = self._model.fit(self._ymap_norm[n], self._parameters, x=self._x, **kwargs) + result = self._model.fit( + self._ymap_norm[n], self._parameters, x=self._x, **kwargs) else: - result = self._model.fit(self._ymap_norm[n][~self._mask], self._parameters, - x=self._x[~self._mask], **kwargs) -# print(f'\nafter fit {n}') -# self._parameters.pretty_print() -# print(result.fit_report(show_correl=False)) + result = self._model.fit( + self._ymap_norm[n][~self._mask], self._parameters, + x=self._x[~self._mask], **kwargs) out_of_bounds = False for name, par in self._parameter_bounds.items(): value = result.params[name].value @@ -2506,7 +2707,8 @@ def _fit(self, n, current_best_values, return_result=False, **kwargs): # Rerun fit with parameter bounds in place for name, par in self._parameter_bounds.items(): self._parameters[name].set(min=par['min'], max=par['max']) - # Set parameters to current best values, but prevent them from sitting at boundaries + # Set parameters to current best values, but prevent them + # from sitting at boundaries if self._new_parameters is None: # Initial fit for name, value in current_best_values.items(): @@ -2522,17 +2724,13 @@ def _fit(self, n, current_best_values, return_result=False, **kwargs): current_best_values[name])) elif par.expr is None: par.set(value=self._best_values[i][n]) -# print('\nbefore fit') -# self._parameters.pretty_print() -# print(result.fit_report(show_correl=False)) if self._mask is None: - result = self._model.fit(self._ymap_norm[n], self._parameters, x=self._x, **kwargs) + result = self._model.fit( + self._ymap_norm[n], self._parameters, x=self._x, **kwargs) else: - result = self._model.fit(self._ymap_norm[n][~self._mask], self._parameters, + result = self._model.fit( + self._ymap_norm[n][~self._mask], self._parameters, x=self._x[~self._mask], **kwargs) -# print(f'\nafter fit {n}') -# self._parameters.pretty_print() -# print(result.fit_report(show_correl=False)) out_of_bounds = False for name, par in self._parameter_bounds.items(): value = result.params[name].value @@ -2542,26 +2740,25 @@ def _fit(self, n, current_best_values, return_result=False, **kwargs): if not np.isinf(par['max']) and value > par['max']: out_of_bounds = True break -# print(f'{n} redchi < redchi_cutoff = {result.redchi < self._redchi_cutoff} success = {result.success} out_of_bounds = {out_of_bounds}') # Reset parameters back to unbound for name in self._parameter_bounds.keys(): self._parameters[name].set(min=-np.inf, max=np.inf) - assert(not out_of_bounds) + assert not out_of_bounds if result.redchi >= self._redchi_cutoff: result.success = False if result.nfev == result.max_nfev: -# print(f'Maximum number of function evaluations reached for n = {n}') -# logging.warning(f'Maximum number of function evaluations reached for n = {n}') if result.redchi < self._redchi_cutoff: result.success = True self._max_nfev_flat[n] = True if result.success: - assert(all(True for par in current_best_values if par in result.params.values())) + assert all( + True for par 
in current_best_values + if par in result.params.values()) for par in result.params.values(): if par.vary: current_best_values[par.name] = par.value else: - logging.warning(f'Fit for n = {n} failed: {result.lmdif_message}') + logger.warning(f'Fit for n = {n} failed: {result.lmdif_message}') # Renormalize the data and results self._renormalize(n, result) if self._print_report: @@ -2569,16 +2766,15 @@ def _fit(self, n, current_best_values, return_result=False, **kwargs): if self._plot: dims = np.unravel_index(n, self._map_shape) if self._inv_transpose is not None: - dims= tuple(dims[self._inv_transpose[i]] for i in range(len(dims))) - super().plot(result=result, y=np.asarray(self._ymap[dims]), plot_comp_legends=True, - skip_init=self._skip_init, title=str(dims)) -#RV print(f'\n\nend FitMap._fit {n}\n') -#RV print(f'current_best_values = {current_best_values}') -# self._parameters.pretty_print() -# print(result.fit_report(show_correl=False)) -#RV print(f'\nself._best_values_flat:\n{self._best_values_flat}\n\n') + dims= tuple( + dims[self._inv_transpose[i]] for i in range(len(dims))) + super().plot( + result=result, y=np.asarray(self._ymap[dims]), + plot_comp_legends=True, skip_init=self._skip_init, + title=str(dims)) if return_result: - return(result) + return result + return None def _renormalize(self, n, result): self._redchi_flat[n] = np.float64(result.redchi) @@ -2586,8 +2782,10 @@ def _renormalize(self, n, result): if self._norm is None or not self._normalized: self._best_fit_flat[n] = result.best_fit for i, name in enumerate(self._best_parameters): - self._best_values_flat[i][n] = np.float64(result.params[name].value) - self._best_errors_flat[i][n] = np.float64(result.params[name].stderr) + self._best_values_flat[i][n] = np.float64( + result.params[name].value) + self._best_errors_flat[i][n] = np.float64( + result.params[name].stderr) else: pars = set(self._parameter_norms) & set(self._best_parameters) for name, par in result.params.items(): @@ -2599,15 +2797,21 @@ def _renormalize(self, n, result): if self._print_report: if par.init_value is not None: par.init_value *= self._norm[1] - if not np.isinf(par.min) and abs(par.min) != float_min: + if (not np.isinf(par.min) + and abs(par.min) != FLOAT_MIN): par.min *= self._norm[1] - if not np.isinf(par.max) and abs(par.max) != float_min: + if (not np.isinf(par.max) + and abs(par.max) != FLOAT_MIN): par.max *= self._norm[1] - self._best_fit_flat[n] = result.best_fit*self._norm[1]+self._norm[0] + self._best_fit_flat[n] = (result.best_fit*self._norm[1] + + self._norm[0]) for i, name in enumerate(self._best_parameters): - self._best_values_flat[i][n] = np.float64(result.params[name].value) - self._best_errors_flat[i][n] = np.float64(result.params[name].stderr) + self._best_values_flat[i][n] = np.float64( + result.params[name].value) + self._best_errors_flat[i][n] = np.float64( + result.params[name].stderr) if self._plot: if not self._skip_init: - result.init_fit = result.init_fit*self._norm[1]+self._norm[0] + result.init_fit = (result.init_fit*self._norm[1] + + self._norm[0]) result.best_fit = np.copy(self._best_fit_flat[n]) diff --git a/CHAP/common/utils/general.py b/CHAP/common/utils/general.py index 90fecce..6033254 100755 --- a/CHAP/common/utils/general.py +++ b/CHAP/common/utils/general.py @@ -1,45 +1,58 @@ #!/usr/bin/env python3 - -#FIX write a function that returns a list of peak indices for a given plot -#FIX use raise_error concept on more functions to optionally raise an error - # -*- coding: utf-8 -*- +#pylint: disable= """ -Created 
on Mon Dec 6 15:36:22 2021 - -@author: rv43 +File : general.py +Author : Rolf Verberg +Description: A collection of general modules """ +#RV write function that returns a list of peak indices for a given plot +#RV use raise_error concept on more functions -from logging import getLogger -logger = getLogger(__name__) - +# System modules from ast import literal_eval +from logging import getLogger +from os import path as os_path +from os import ( + access, + R_OK, +) from re import compile as re_compile from re import split as re_split from re import sub as re_sub from sys import float_info +# Third party modules import numpy as np try: import matplotlib.pyplot as plt from matplotlib.widgets import Button -except: +except ImportError: pass -def depth_list(L): return isinstance(L, list) and max(map(depth_list, L))+1 -def depth_tuple(T): return isinstance(T, tuple) and max(map(depth_tuple, T))+1 -def unwrap_tuple(T): - if depth_tuple(T) > 1 and len(T) == 1: - T = unwrap_tuple(*T) - return T +logger = getLogger(__name__) + +def depth_list(_list): + """Return the depth of a list.""" + return isinstance(_list, list) and 1+max(map(depth_list, _list)) +def depth_tuple(_tuple): + """Return the depth of a tuple.""" + return isinstance(_tuple, tuple) and 1+max(map(depth_tuple, _tuple)) +def unwrap_tuple(_tuple): + """Unwrap a tuple.""" + if depth_tuple(_tuple) > 1 and len(_tuple) == 1: + _tuple = unwrap_tuple(*_tuple) + return _tuple def illegal_value(value, name, location=None, raise_error=False, log=True): + """Print illegal value message and/or raise error.""" if not isinstance(location, str): location = '' else: location = f'in {location} ' if isinstance(name, str): - error_msg = f'Illegal value for {name} {location}({value}, {type(value)})' + error_msg = \ + f'Illegal value for {name} {location}({value}, {type(value)})' else: error_msg = f'Illegal value {location}({value}, {type(value)})' if log: @@ -47,27 +60,34 @@ def illegal_value(value, name, location=None, raise_error=False, log=True): if raise_error: raise ValueError(error_msg) -def illegal_combination(value1, name1, value2, name2, location=None, raise_error=False, +def illegal_combination( + value1, name1, value2, name2, location=None, raise_error=False, log=True): + """Print illegal combination message and/or raise error.""" if not isinstance(location, str): location = '' else: location = f'in {location} ' if isinstance(name1, str): - error_msg = f'Illegal combination for {name1} and {name2} {location}'+ \ - f'({value1}, {type(value1)} and {value2}, {type(value2)})' + error_msg = f'Illegal combination for {name1} and {name2} {location}' \ + + f'({value1}, {type(value1)} and {value2}, {type(value2)})' else: - error_msg = f'Illegal combination {location}'+ \ - f'({value1}, {type(value1)} and {value2}, {type(value2)})' + error_msg = f'Illegal combination {location}' \ + + f'({value1}, {type(value1)} and {value2}, {type(value2)})' if log: logger.error(error_msg) if raise_error: raise ValueError(error_msg) -def test_ge_gt_le_lt(ge, gt, le, lt, func, location=None, raise_error=False, log=True): - """Check individual and mutual validity of ge, gt, le, lt qualifiers - func: is_int or is_num to test for int or numbers - Return: True upon success or False when mutually exlusive +def test_ge_gt_le_lt( + ge, gt, le, lt, func, location=None, raise_error=False, log=True): + """ + Check individual and mutual validity of ge, gt, le, lt qualifiers. 
+ + :param func: Test for integers or numbers + :type func: callable: is_int, is_num + :return: True upon success or False when mutually exlusive + :rtype: bool """ if ge is None and gt is None and le is None and lt is None: return True @@ -95,21 +115,23 @@ def test_ge_gt_le_lt(ge, gt, le, lt, func, location=None, raise_error=False, log if le is not None and ge > le: illegal_combination(ge, 'ge', le, 'le', location, raise_error, log) return False - elif lt is not None and ge >= lt: + if lt is not None and ge >= lt: illegal_combination(ge, 'ge', lt, 'lt', location, raise_error, log) return False elif gt is not None: if le is not None and gt >= le: illegal_combination(gt, 'gt', le, 'le', location, raise_error, log) return False - elif lt is not None and gt >= lt: + if lt is not None and gt >= lt: illegal_combination(gt, 'gt', lt, 'lt', location, raise_error, log) return False return True def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None): - """Return a range string representation matching the ge, gt, le, lt qualifiers - Does not validate the inputs, do that as needed before calling + """ + Return a range string representation matching the ge, gt, le, lt + qualifiers. Does not validate the inputs, do that as needed before + calling. """ range_string = '' if ge is not None: @@ -135,30 +157,41 @@ def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None): return range_string def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): - """Value is an integer in range ge <= v <= le or gt < v < lt or some combination. - Return: True if yes or False is no + """ + Value is an integer in range ge <= v <= le or gt < v < lt or some + combination. + + :return: True if yes or False is no + :rtype: bool """ return _is_int_or_num(v, 'int', ge, gt, le, lt, raise_error, log) def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): - """Value is a number in range ge <= v <= le or gt < v < lt or some combination. - Return: True if yes or False is no + """ + Value is a number in range ge <= v <= le or gt < v < lt or some + combination. + + :return: True if yes or False is no + :rtype: bool """ return _is_int_or_num(v, 'num', ge, gt, le, lt, raise_error, log) -def _is_int_or_num(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False, +def _is_int_or_num( + v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): if type_str == 'int': if not isinstance(v, int): illegal_value(v, 'v', '_is_int_or_num', raise_error, log) return False - if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, '_is_int_or_num', raise_error, log): + if not test_ge_gt_le_lt( + ge, gt, le, lt, is_int, '_is_int_or_num', raise_error, log): return False elif type_str == 'num': if not isinstance(v, (int, float)): illegal_value(v, 'v', '_is_int_or_num', raise_error, log) return False - if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, '_is_int_or_num', raise_error, log): + if not test_ge_gt_le_lt( + ge, gt, le, lt, is_num, '_is_int_or_num', raise_error, log): return False else: illegal_value(type_str, 'type_str', '_is_int_or_num', raise_error, log) @@ -186,145 +219,181 @@ def _is_int_or_num(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error= return False return True -def is_int_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): - """Value is an integer pair, each in range ge <= v[i] <= le or gt < v[i] < lt or - ge[i] <= v[i] <= le[i] or gt[i] < v[i] < lt[i] or some combination. 
- Return: True if yes or False is no +def is_int_pair( + v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """ + Value is an integer pair, each in range ge <= v[i] <= le or + gt < v[i] < lt or ge[i] <= v[i] <= le[i] or gt[i] < v[i] < lt[i] + or some combination. + + :return: True if yes or False is no + :rtype: bool """ return _is_int_or_num_pair(v, 'int', ge, gt, le, lt, raise_error, log) -def is_num_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): - """Value is a number pair, each in range ge <= v[i] <= le or gt < v[i] < lt or - ge[i] <= v[i] <= le[i] or gt[i] < v[i] < lt[i] or some combination. - Return: True if yes or False is no +def is_num_pair( + v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """ + Value is a number pair, each in range ge <= v[i] <= le or + gt < v[i] < lt or ge[i] <= v[i] <= le[i] or gt[i] < v[i] < lt[i] + or some combination. + + :return: True if yes or False is no + :rtype: bool """ return _is_int_or_num_pair(v, 'num', ge, gt, le, lt, raise_error, log) -def _is_int_or_num_pair(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False, +def _is_int_or_num_pair( + v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): if type_str == 'int': - if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], int) and - isinstance(v[1], int)): + if not (isinstance(v, (tuple, list)) and len(v) == 2 + and isinstance(v[0], int) and isinstance(v[1], int)): illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log) return False func = is_int elif type_str == 'num': - if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], (int, float)) and - isinstance(v[1], (int, float))): + if not (isinstance(v, (tuple, list)) and len(v) == 2 + and isinstance(v[0], (int, float)) + and isinstance(v[1], (int, float))): illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log) return False func = is_num else: - illegal_value(type_str, 'type_str', '_is_int_or_num_pair', raise_error, log) + illegal_value( + type_str, 'type_str', '_is_int_or_num_pair', raise_error, log) return False if ge is None and gt is None and le is None and lt is None: return True if ge is None or func(ge, log=True): ge = 2*[ge] - elif not _is_int_or_num_pair(ge, type_str, raise_error=raise_error, log=log): + elif not _is_int_or_num_pair( + ge, type_str, raise_error=raise_error, log=log): return False if gt is None or func(gt, log=True): gt = 2*[gt] - elif not _is_int_or_num_pair(gt, type_str, raise_error=raise_error, log=log): + elif not _is_int_or_num_pair( + gt, type_str, raise_error=raise_error, log=log): return False if le is None or func(le, log=True): le = 2*[le] - elif not _is_int_or_num_pair(le, type_str, raise_error=raise_error, log=log): + elif not _is_int_or_num_pair( + le, type_str, raise_error=raise_error, log=log): return False if lt is None or func(lt, log=True): lt = 2*[lt] - elif not _is_int_or_num_pair(lt, type_str, raise_error=raise_error, log=log): + elif not _is_int_or_num_pair( + lt, type_str, raise_error=raise_error, log=log): return False if (not func(v[0], ge[0], gt[0], le[0], lt[0], raise_error, log) or not func(v[1], ge[1], gt[1], le[1], lt[1], raise_error, log)): return False return True -def is_int_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): - """Value is a tuple or list of integers, each in range ge <= l[i] <= le or - gt < l[i] < lt or some combination. 
+def is_int_series( + t_or_l, ge=None, gt=None, le=None, lt=None, raise_error=False, + log=True): + """ + Value is a tuple or list of integers, each in range + ge <= l[i] <= le or gt < l[i] < lt or some combination. """ - if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, 'is_int_series', raise_error, log): + if not test_ge_gt_le_lt( + ge, gt, le, lt, is_int, 'is_int_series', raise_error, log): return False - if not isinstance(l, (tuple, list)): - illegal_value(l, 'l', 'is_int_series', raise_error, log) + if not isinstance(t_or_l, (tuple, list)): + illegal_value(t_or_l, 't_or_l', 'is_int_series', raise_error, log) return False - if any(True if not is_int(v, ge, gt, le, lt, raise_error, log) else False for v in l): + if any(not is_int(v, ge, gt, le, lt, raise_error, log) for v in t_or_l): return False return True -def is_num_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): - """Value is a tuple or list of numbers, each in range ge <= l[i] <= le or - gt < l[i] < lt or some combination. +def is_num_series( + t_or_l, ge=None, gt=None, le=None, lt=None, raise_error=False, + log=True): + """ + Value is a tuple or list of numbers, each in range ge <= l[i] <= le + or gt < l[i] < lt or some combination. """ - if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, 'is_int_series', raise_error, log): + if not test_ge_gt_le_lt( + ge, gt, le, lt, is_int, 'is_int_series', raise_error, log): return False - if not isinstance(l, (tuple, list)): - illegal_value(l, 'l', 'is_num_series', raise_error, log) + if not isinstance(t_or_l, (tuple, list)): + illegal_value(t_or_l, 't_or_l', 'is_num_series', raise_error, log) return False - if any(True if not is_num(v, ge, gt, le, lt, raise_error, log) else False for v in l): + if any(not is_num(v, ge, gt, le, lt, raise_error, log) for v in t_or_l): return False return True -def is_str_series(l, raise_error=False, log=True): - """Value is a tuple or list of strings. +def is_str_series(t_or_l, raise_error=False, log=True): """ - if (not isinstance(l, (tuple, list)) or - any(True if not isinstance(s, str) else False for s in l)): - illegal_value(l, 'l', 'is_str_series', raise_error, log) + Value is a tuple or list of strings. + """ + if (not isinstance(t_or_l, (tuple, list)) + or any(not isinstance(s, str) for s in t_or_l)): + illegal_value(t_or_l, 't_or_l', 'is_str_series', raise_error, log) return False return True -def is_dict_series(l, raise_error=False, log=True): - """Value is a tuple or list of dictionaries. +def is_dict_series(t_or_l, raise_error=False, log=True): + """ + Value is a tuple or list of dictionaries. 
""" - if (not isinstance(l, (tuple, list)) or - any(True if not isinstance(d, dict) else False for d in l)): - illegal_value(l, 'l', 'is_dict_series', raise_error, log) + if (not isinstance(t_or_l, (tuple, list)) + or any(not isinstance(d, dict) for d in t_or_l)): + illegal_value(t_or_l, 't_or_l', 'is_dict_series', raise_error, log) return False return True -def is_dict_nums(l, raise_error=False, log=True): - """Value is a dictionary with single number values +def is_dict_nums(d, raise_error=False, log=True): """ - if (not isinstance(l, dict) or - any(True if not is_num(v, log=False) else False for v in l.values())): - illegal_value(l, 'l', 'is_dict_nums', raise_error, log) + Value is a dictionary with single number values + """ + if (not isinstance(d, dict) + or any(not is_num(v, log=False) for v in d.values())): + illegal_value(d, 'd', 'is_dict_nums', raise_error, log) return False return True -def is_dict_strings(l, raise_error=False, log=True): - """Value is a dictionary with single string values +def is_dict_strings(d, raise_error=False, log=True): + """ + Value is a dictionary with single string values """ - if (not isinstance(l, dict) or - any(True if not isinstance(v, str) else False for v in l.values())): - illegal_value(l, 'l', 'is_dict_strings', raise_error, log) + if (not isinstance(d, dict) + or any(not isinstance(v, str) for v in d.values())): + illegal_value(d, 'd', 'is_dict_strings', raise_error, log) return False return True def is_index(v, ge=0, lt=None, raise_error=False, log=True): - """Value is an array index in range ge <= v < lt. - NOTE lt IS NOT included! + """ + Value is an array index in range ge <= v < lt. NOTE lt IS NOT + included! """ if isinstance(lt, int): if lt <= ge: - illegal_combination(ge, 'ge', lt, 'lt', 'is_index', raise_error, log) + illegal_combination( + ge, 'ge', lt, 'lt', 'is_index', raise_error, log) return False return is_int(v, ge=ge, lt=lt, raise_error=raise_error, log=log) def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True): - """Value is an array index range in range ge <= v[0] <= v[1] <= le or ge <= v[0] <= v[1] < lt. - NOTE le IS included! + """ + Value is an array index range in range ge <= v[0] <= v[1] <= le or + ge <= v[0] <= v[1] < lt. NOTE le IS included! 
""" if not is_int_pair(v, raise_error=raise_error, log=log): return False - if not test_ge_gt_le_lt(ge, None, le, lt, is_int, 'is_index_range', raise_error, log): + if not test_ge_gt_le_lt( + ge, None, le, lt, is_int, 'is_index_range', raise_error, log): return False - if not ge <= v[0] <= v[1] or (le is not None and v[1] > le) or (lt is not None and v[1] >= lt): + if (not ge <= v[0] <= v[1] or (le is not None and v[1] > le) + or (lt is not None and v[1] >= lt)): if le is not None: - error_msg = f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} <= {le})' + error_msg = \ + f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} <= {le})' else: - error_msg = f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} < {lt})' + error_msg = \ + f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} < {lt})' if log: logger.error(error_msg) if raise_error: @@ -333,148 +402,179 @@ def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True): return True def index_nearest(a, value): + """Return index of nearest array value.""" a = np.asarray(a) if a.ndim > 1: - raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})') + raise ValueError( + f'Invalid array dimension for parameter a ({a.ndim}, {a})') # Round up for .5 value *= 1.0+float_info.epsilon return (int)(np.argmin(np.abs(a-value))) def index_nearest_low(a, value): + """Return index of nearest array value, rounded down""" a = np.asarray(a) if a.ndim > 1: - raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})') + raise ValueError( + f'Invalid array dimension for parameter a ({a.ndim}, {a})') index = int(np.argmin(np.abs(a-value))) if value < a[index] and index > 0: index -= 1 return index def index_nearest_upp(a, value): + """Return index of nearest array value, rounded upp.""" a = np.asarray(a) if a.ndim > 1: - raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})') + raise ValueError( + f'Invalid array dimension for parameter a ({a.ndim}, {a})') index = int(np.argmin(np.abs(a-value))) if value > a[index] and index < a.size-1: index += 1 return index def round_to_n(x, n=1): + """Round to a specific number of decimals.""" if x == 0.0: return 0 - else: - return type(x)(round(x, n-1-int(np.floor(np.log10(abs(x)))))) + return type(x)(round(x, n-1-int(np.floor(np.log10(abs(x)))))) def round_up_to_n(x, n=1): - xr = round_to_n(x, n) - if abs(x/xr) > 1.0: - xr += np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n) - return type(x)(xr) + """Round up to a specific number of decimals.""" + x_round = round_to_n(x, n) + if abs(x/x_round) > 1.0: + x_round += np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n) + return type(x)(x_round) def trunc_to_n(x, n=1): - xr = round_to_n(x, n) - if abs(xr/x) > 1.0: - xr -= np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n) - return type(x)(xr) + """Truncate to a specific number of decimals.""" + x_round = round_to_n(x, n) + if abs(x_round/x) > 1.0: + x_round -= np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n) + return type(x)(x_round) def almost_equal(a, b, sig_figs): + """ + Check if equal to within a certain number of significant digits. 
+ """ if is_num(a) and is_num(b): - return abs(round_to_n(a-b, sig_figs)) < pow(10, -sig_figs+1) - else: - raise ValueError(f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, '+ - f'b: {b}, {type(b)})') - return False + return abs(round_to_n(a-b, sig_figs)) < pow(10, 1-sig_figs) + raise ValueError( + f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, ' + + f'b: {b}, {type(b)})') def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True): - """Return a list of numbers by splitting/expanding a string on any combination of - commas, whitespaces, or dashes (when split_on_dash=True) - e.g: '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12] + """ + Return a list of numbers by splitting/expanding a string on any + combination of commas, whitespaces, or dashes (when + split_on_dash=True). + e.g: '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12] """ if not isinstance(s, str): - illegal_value(s, location='string_to_list') + illegal_value(s, 's', location='string_to_list') return None - if not len(s): + if not s: return [] try: - ll = [x for x in re_split('\s+,\s+|\s+,|,\s+|\s+|,', s.strip())] + list1 = re_split(r'\s+,\s+|\s+,|,\s+|\s+|,', s.strip()) except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): return None if split_on_dash: try: - l = [] - for l1 in ll: - l2 = [literal_eval(x) for x in re_split('\s+-\s+|\s+-|-\s+|\s+|-', l1)] - if len(l2) == 1: - l += l2 - elif len(l2) == 2 and l2[1] > l2[0]: - l += [i for i in range(l2[0], l2[1]+1)] + l_of_i = [] + for v in list1: + list2 = [literal_eval(x) + for x in re_split(r'\s+-\s+|\s+-|-\s+|\s+|-', v)] + if len(list2) == 1: + l_of_i += list2 + elif len(list2) == 2 and list2[1] > list2[0]: + l_of_i += list(range(list2[0], 1+list2[1])) else: raise ValueError - except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): + except (ValueError, TypeError, SyntaxError, MemoryError, + RecursionError): return None else: - l = [literal_eval(x) for x in ll] + l_of_i = [literal_eval(x) for x in list1] if remove_duplicates: - l = list(dict.fromkeys(l)) + l_of_i = list(dict.fromkeys(l_of_i)) if sort: - l = sorted(l) - return l + l_of_i = sorted(l_of_i) + return l_of_i def get_trailing_int(string): - indexRegex = re_compile(r'\d+$') - mo = indexRegex.search(string) - if mo is None: + """Get the trailing integer in a string.""" + index_regex = re_compile(r'\d+$') + match = index_regex.search(string) + if match is None: return None - else: - return int(mo.group()) + return int(match.group()) -def input_int(s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None, +def input_int( + s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None, raise_error=False, log=True): - return _input_int_or_num('int', s, ge, gt, le, lt, default, inset, raise_error, log) + """Interactively prompt the user to enter an integer.""" + return _input_int_or_num( + 'int', s, ge, gt, le, lt, default, inset, raise_error, log) -def input_num(s=None, ge=None, gt=None, le=None, lt=None, default=None, raise_error=False, - log=True): - return _input_int_or_num('num', s, ge, gt, le, lt, default, None, raise_error,log) +def input_num( + s=None, ge=None, gt=None, le=None, lt=None, default=None, + raise_error=False, log=True): + """Interactively prompt the user to enter a number.""" + return _input_int_or_num( + 'num', s, ge, gt, le, lt, default, None, raise_error,log) -def _input_int_or_num(type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None, +def _input_int_or_num( + type_str, s=None, ge=None, gt=None, 
le=None, lt=None, default=None, inset=None, raise_error=False, log=True): + """Interactively prompt the user to enter an integer or number.""" if type_str == 'int': - if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, '_input_int_or_num', raise_error, log): + if not test_ge_gt_le_lt( + ge, gt, le, lt, is_int, '_input_int_or_num', raise_error, log): return None elif type_str == 'num': - if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, '_input_int_or_num', raise_error, log): + if not test_ge_gt_le_lt( + ge, gt, le, lt, is_num, '_input_int_or_num', raise_error, log): return None else: - illegal_value(type_str, 'type_str', '_input_int_or_num', raise_error, log) + illegal_value( + type_str, 'type_str', '_input_int_or_num', raise_error, log) return None if default is not None: - if not _is_int_or_num(default, type_str, raise_error=raise_error, log=log): + if not _is_int_or_num( + default, type_str, raise_error=raise_error, log=log): return None if ge is not None and default < ge: - illegal_combination(ge, 'ge', default, 'default', '_input_int_or_num', raise_error, + illegal_combination( + ge, 'ge', default, 'default', '_input_int_or_num', raise_error, log) return None if gt is not None and default <= gt: - illegal_combination(gt, 'gt', default, 'default', '_input_int_or_num', raise_error, + illegal_combination( + gt, 'gt', default, 'default', '_input_int_or_num', raise_error, log) return None if le is not None and default > le: - illegal_combination(le, 'le', default, 'default', '_input_int_or_num', raise_error, + illegal_combination( + le, 'le', default, 'default', '_input_int_or_num', raise_error, log) return None if lt is not None and default >= lt: - illegal_combination(lt, 'lt', default, 'default', '_input_int_or_num', raise_error, + illegal_combination( + lt, 'lt', default, 'default', '_input_int_or_num', raise_error, log) return None default_string = f' [{default}]' else: default_string = '' if inset is not None: - if (not isinstance(inset, (tuple, list)) or any(True if not isinstance(i, int) else - False for i in inset)): - illegal_value(inset, 'inset', '_input_int_or_num', raise_error, log) + if (not isinstance(inset, (tuple, list)) + or any(not isinstance(i, int) for i in inset)): + illegal_value( + inset, 'inset', '_input_int_or_num', raise_error, log) return None v_range = f'{range_string_ge_gt_le_lt(ge, gt, le, lt)}' - if len(v_range): + if v_range: v_range = f' {v_range}' if s is None: if type_str == 'int': @@ -485,89 +585,100 @@ def _input_int_or_num(type_str, s=None, ge=None, gt=None, le=None, lt=None, defa print(f'{s}{v_range}{default_string}: ') try: i = input() - if isinstance(i, str) and not len(i): + if isinstance(i, str) and not i: v = default print(f'{v}') else: v = literal_eval(i) if inset and v not in inset: - raise ValueError(f'{v} not part of the set {inset}') + raise ValueError(f'{v} not part of the set {inset}') except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): v = None - except: - if log: - logger.error('Unexpected error') - if raise_error: - raise ValueError('Unexpected error') if not _is_int_or_num(v, type_str, ge, gt, le, lt): - v = _input_int_or_num(type_str, s, ge, gt, le, lt, default, inset, raise_error, log) + v = _input_int_or_num( + type_str, s, ge, gt, le, lt, default, inset, raise_error, log) return v -def input_int_list(s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True, +def input_int_list( + s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True, sort=True, raise_error=False, log=True): - """Prompt 
the user to input a list of interger and split the entered string on any combination - of commas, whitespaces, or dashes (when split_on_dash is True) - e.g: '1 3,5-8 , 12 ' -> [1, 3, 5, 6, 7, 8, 12] - remove_duplicates: removes duplicates if True (may also change the order) - sort: sort in ascending order if True - return None upon an illegal input """ - return _input_int_or_num_list('int', s, ge, le, split_on_dash, remove_duplicates, sort, - raise_error, log) - -def input_num_list(s=None, ge=None, le=None, remove_duplicates=True, sort=True, raise_error=False, - log=True): - """Prompt the user to input a list of numbers and split the entered string on any combination - of commas or whitespaces - e.g: '1.0, 3, 5.8, 12 ' -> [1.0, 3.0, 5.8, 12.0] - remove_duplicates: removes duplicates if True (may also change the order) - sort: sort in ascending order if True - return None upon an illegal input + Prompt the user to input a list of interger and split the entered + string on any combination of commas, whitespaces, or dashes (when + split_on_dash is True). + e.g: '1 3,5-8 , 12 ' -> [1, 3, 5, 6, 7, 8, 12] + + remove_duplicates: removes duplicates if True (may also change the + order) + sort: sort in ascending order if True + return None upon an illegal input """ - return _input_int_or_num_list('num', s, ge, le, False, remove_duplicates, sort, raise_error, + return _input_int_or_num_list( + 'int', s, ge, le, split_on_dash, remove_duplicates, sort, raise_error, log) -def _input_int_or_num_list(type_str, s=None, ge=None, le=None, split_on_dash=True, +def input_num_list( + s=None, ge=None, le=None, remove_duplicates=True, sort=True, + raise_error=False, log=True): + """ + Prompt the user to input a list of numbers and split the entered + string on any combination of commas or whitespaces. + e.g: '1.0, 3, 5.8, 12 ' -> [1.0, 3.0, 5.8, 12.0] + + remove_duplicates: removes duplicates if True (may also change the + order) + sort: sort in ascending order if True + return None upon an illegal input + """ + return _input_int_or_num_list( + 'num', s, ge, le, False, remove_duplicates, sort, raise_error, log) + +def _input_int_or_num_list( + type_str, s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True, sort=True, raise_error=False, log=True): - #FIX do we want a limit on max dimension? + #RV do we want a limit on max dimension? 
if type_str == 'int': - if not test_ge_gt_le_lt(ge, None, le, None, is_int, 'input_int_or_num_list', raise_error, - log): + if not test_ge_gt_le_lt( + ge, None, le, None, is_int, 'input_int_or_num_list', + raise_error, log): return None elif type_str == 'num': - if not test_ge_gt_le_lt(ge, None, le, None, is_num, 'input_int_or_num_list', raise_error, - log): + if not test_ge_gt_le_lt( + ge, None, le, None, is_num, 'input_int_or_num_list', + raise_error, log): return None else: illegal_value(type_str, 'type_str', '_input_int_or_num_list') return None v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}' - if len(v_range): + if v_range: v_range = f' (each value in {v_range})' if s is None: print(f'Enter a series of integers{v_range}: ') else: print(f'{s}{v_range}: ') try: - l = string_to_list(input(), split_on_dash, remove_duplicates, sort) + _list = string_to_list(input(), split_on_dash, remove_duplicates, sort) except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): - l = None + _list = None except: print('Unexpected error') raise - if (not isinstance(l, list) or - any(True if not _is_int_or_num(v, type_str, ge=ge, le=le) else False for v in l)): + if (not isinstance(_list, list) or any( + not _is_int_or_num(v, type_str, ge=ge, le=le) for v in _list)): if split_on_dash: - print('Invalid input: enter a valid set of dash/comma/whitespace separated integers '+ - 'e.g. 1 3,5-8 , 12') + print('Invalid input: enter a valid set of dash/comma/whitespace ' + + 'separated integers e.g. 1 3,5-8 , 12') else: - print('Invalid input: enter a valid set of comma/whitespace separated integers '+ - 'e.g. 1 3,5 8 , 12') - l = _input_int_or_num_list(type_str, s, ge, le, split_on_dash, remove_duplicates, sort, + print('Invalid input: enter a valid set of comma/whitespace ' + + 'separated integers e.g. 
1 3,5 8 , 12') + _list = _input_int_or_num_list( + type_str, s, ge, le, split_on_dash, remove_duplicates, sort, raise_error, log) - return l + return _list def input_yesno(s=None, default=None): + """Interactively prompt the user to enter a y/n question.""" if default is not None: if not isinstance(default, str): illegal_value(default, 'default', 'input_yesno') @@ -587,7 +698,7 @@ def input_yesno(s=None, default=None): else: print(f'{s}{default_string}: ') i = input() - if isinstance(i, str) and not len(i): + if isinstance(i, str) and not i: i = default print(f'{i}') if i is not None and i.lower() in 'yes': @@ -600,28 +711,32 @@ def input_yesno(s=None, default=None): return v def input_menu(items, default=None, header=None): - if not isinstance(items, (tuple, list)) or any(True if not isinstance(i, str) else False - for i in items): + """Interactively prompt the user to select from a menu.""" + if (not isinstance(items, (tuple, list)) + or any(not isinstance(i, str) for i in items)): illegal_value(items, 'items', 'input_menu') return None if default is not None: if not (isinstance(default, str) and default in items): - logger.error(f'Invalid value for default ({default}), must be in {items}') + logger.error( + f'Invalid value for default ({default}), must be in {items}') return None - default_string = f' [{items.index(default)+1}]' + default_string = f' [{1+items.index(default)}]' else: default_string = '' if header is None: - print(f'Choose one of the following items (1, {len(items)}){default_string}:') + print( + 'Choose one of the following items ' + + f'(1, {len(items)}){default_string}:') else: print(f'{header} (1, {len(items)}){default_string}:') for i, choice in enumerate(items): print(f' {i+1}: {choice}') try: choice = input() - if isinstance(choice, str) and not len(choice): + if isinstance(choice, str) and not choice: choice = items.index(default) - print(f'{choice+1}') + print(f'{1+choice}') else: choice = literal_eval(choice) if isinstance(choice, int) and 1 <= choice <= len(items): @@ -638,141 +753,180 @@ def input_menu(items, default=None, header=None): choice = input_menu(items, default) return choice -def assert_no_duplicates_in_list_of_dicts(l: list, raise_error=False) -> list: - if not isinstance(l, list): - illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error) +def assert_no_duplicates_in_list_of_dicts(_list, raise_error=False): + """ + Assert that there are no duplicates in a list of dictionaries. 
+ """ + if not isinstance(_list, list): + illegal_value( + _list, '_list', 'assert_no_duplicates_in_list_of_dicts', + raise_error) return None - if any(True if not isinstance(d, dict) else False for d in l): - illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error) + if any(not isinstance(d, dict) for d in _list): + illegal_value( + _list, '_list', 'assert_no_duplicates_in_list_of_dicts', + raise_error) return None - if len(l) != len([dict(t) for t in {tuple(sorted(d.items())) for d in l}]): + if (len(_list) != len([dict(_tuple) for _tuple in + {tuple(sorted(d.items())) for d in _list}])): if raise_error: - raise ValueError(f'Duplicate items found in {l}') - else: - logger.error(f'Duplicate items found in {l}') + raise ValueError(f'Duplicate items found in {_list}') + logger.error(f'Duplicate items found in {_list}') return None - else: - return l + return _list -def assert_no_duplicate_key_in_list_of_dicts(l: list, key: str, raise_error=False) -> list: +def assert_no_duplicate_key_in_list_of_dicts(_list, key, raise_error=False): + """ + Assert that there are no duplicate keys in a list of dictionaries. + """ if not isinstance(key, str): - illegal_value(key, 'key', 'assert_no_duplicate_key_in_list_of_dicts', raise_error) + illegal_value( + key, 'key', 'assert_no_duplicate_key_in_list_of_dicts', + raise_error) return None - if not isinstance(l, list): - illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error) + if not isinstance(_list, list): + illegal_value( + _list, '_list', 'assert_no_duplicate_key_in_list_of_dicts', + raise_error) return None - if any(True if not isinstance(d, dict) else False for d in l): - illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error) + if any(isinstance(d, dict) for d in _list): + illegal_value( + _list, '_list', 'assert_no_duplicates_in_list_of_dicts', + raise_error) return None - keys = [d.get(key, None) for d in l] - if None in keys or len(set(keys)) != len(l): + keys = [d.get(key, None) for d in _list] + if None in keys or len(set(keys)) != len(_list): if raise_error: - raise ValueError(f'Duplicate or missing key ({key}) found in {l}') - else: - logger.error(f'Duplicate or missing key ({key}) found in {l}') + raise ValueError( + f'Duplicate or missing key ({key}) found in {_list}') + logger.error(f'Duplicate or missing key ({key}) found in {_list}') return None - else: - return l + return _list -def assert_no_duplicate_attr_in_list_of_objs(l: list, attr: str, raise_error=False) -> list: +def assert_no_duplicate_attr_in_list_of_objs(_list, attr, raise_error=False): + """ + Assert that there are no duplicate attributes in a list of objects. 
+ """ if not isinstance(attr, str): - illegal_value(attr, 'attr', 'assert_no_duplicate_attr_in_list_of_objs', raise_error) + illegal_value( + attr, 'attr', 'assert_no_duplicate_attr_in_list_of_objs', + raise_error) return None - if not isinstance(l, list): - illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_objs', raise_error) + if not isinstance(_list, list): + illegal_value( + _list, '_list', 'assert_no_duplicate_key_in_list_of_objs', + raise_error) return None - attrs = [getattr(obj, attr, None) for obj in l] - if None in attrs or len(set(attrs)) != len(l): + attrs = [getattr(obj, attr, None) for obj in _list] + if None in attrs or len(set(attrs)) != len(_list): if raise_error: - raise ValueError(f'Duplicate or missing attr ({attr}) found in {l}') - else: - logger.error(f'Duplicate or missing attr ({attr}) found in {l}') + raise ValueError( + f'Duplicate or missing attr ({attr}) found in {_list}') + logger.error(f'Duplicate or missing attr ({attr}) found in {_list}') return None - else: - return l - -def file_exists_and_readable(path): - import os - if not os.path.isfile(path): - raise ValueError(f'{path} is not a valid file') - elif not os.access(path, os.R_OK): - raise ValueError(f'{path} is not accessible for reading') - else: - return path - -def draw_mask_1d(ydata, xdata=None, current_index_ranges=None, current_mask=None, - select_mask=True, num_index_ranges_max=None, title=None, legend=None, test_mode=False): - #FIX make color blind friendly - def draw_selections(ax, current_include, current_exclude, selected_index_ranges): + return _list + +def file_exists_and_readable(f): + """Check if a file exists and is readable.""" + if not os_path.isfile(f): + raise ValueError(f'{f} is not a valid file') + if not access(f, R_OK): + raise ValueError(f'{f} is not accessible for reading') + return f + +def draw_mask_1d( + ydata, xdata=None, current_index_ranges=None, current_mask=None, + select_mask=True, num_index_ranges_max=None, title=None, legend=None, + test_mode=False): + """Display a 2D plot and have the user select a mask.""" + #RV make color blind friendly + def draw_selections(ax, current_include, current_exclude, + selected_index_ranges): + """Draw the selections.""" ax.clear() ax.set_title(title) ax.legend([legend]) ax.plot(xdata, ydata, 'k') - for (low, upp) in current_include: - xlow = 0.5*(xdata[max(0, low-1)]+xdata[low]) - xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)]) + for low, upp in current_include: + xlow = 0.5 * (xdata[max(0, low-1)]+xdata[low]) + xupp = 0.5 * (xdata[upp]+xdata[min(num_data-1, 1+upp)]) ax.axvspan(xlow, xupp, facecolor='green', alpha=0.5) - for (low, upp) in current_exclude: - xlow = 0.5*(xdata[max(0, low-1)]+xdata[low]) - xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)]) + for low, upp in current_exclude: + xlow = 0.5 * (xdata[max(0, low-1)]+xdata[low]) + xupp = 0.5 * (xdata[upp]+xdata[min(num_data-1, 1+upp)]) ax.axvspan(xlow, xupp, facecolor='red', alpha=0.5) - for (low, upp) in selected_index_ranges: - xlow = 0.5*(xdata[max(0, low-1)]+xdata[low]) - xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)]) + for low, upp in selected_index_ranges: + xlow = 0.5 * (xdata[max(0, low-1)]+xdata[low]) + xupp = 0.5 * (xdata[upp]+xdata[min(num_data-1, 1+upp)]) ax.axvspan(xlow, xupp, facecolor=selection_color, alpha=0.5) ax.get_figure().canvas.draw() def onclick(event): + """Action taken on clicking the mouse button.""" if event.inaxes in [fig.axes[0]]: selected_index_ranges.append(index_nearest_upp(xdata, event.xdata)) def onrelease(event): - if 
len(selected_index_ranges) > 0: + """Action taken on releasing the mouse button.""" + if selected_index_ranges: if isinstance(selected_index_ranges[-1], int): if event.inaxes in [fig.axes[0]]: event.xdata = index_nearest_low(xdata, event.xdata) if selected_index_ranges[-1] <= event.xdata: - selected_index_ranges[-1] = (selected_index_ranges[-1], event.xdata) + selected_index_ranges[-1] = \ + (selected_index_ranges[-1], event.xdata) else: - selected_index_ranges[-1] = (event.xdata, selected_index_ranges[-1]) - draw_selections(event.inaxes, current_include, current_exclude, selected_index_ranges) + selected_index_ranges[-1] = \ + (event.xdata, selected_index_ranges[-1]) + draw_selections( + event.inaxes, current_include, current_exclude, + selected_index_ranges) else: selected_index_ranges.pop(-1) def confirm_selection(event): + """Action taken on hitting the confirm button.""" plt.close() def clear_last_selection(event): - if len(selected_index_ranges): + """Action taken on hitting the clear button.""" + if selected_index_ranges: selected_index_ranges.pop(-1) else: - while len(current_include): + while current_include: current_include.pop() - while len(current_exclude): + while current_exclude: current_exclude.pop() selected_mask.fill(False) - draw_selections(ax, current_include, current_exclude, selected_index_ranges) + draw_selections( + ax, current_include, current_exclude, selected_index_ranges) def update_mask(mask, selected_index_ranges, unselected_index_ranges): - for (low, upp) in selected_index_ranges: - selected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp]) + """Update the plot with the selected mask.""" + for low, upp in selected_index_ranges: + selected_mask = np.logical_and( + xdata >= xdata[low], xdata <= xdata[upp]) mask = np.logical_or(mask, selected_mask) - for (low, upp) in unselected_index_ranges: - unselected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp]) + for low, upp in unselected_index_ranges: + unselected_mask = np.logical_and( + xdata >= xdata[low], xdata <= xdata[upp]) mask[unselected_mask] = False return mask def update_index_ranges(mask): - # Update the currently included index ranges (where mask is True) + """ + Update the currently included index ranges (where mask = True). 
+ """ current_include = [] for i, m in enumerate(mask): - if m == True: - if len(current_include) == 0 or type(current_include[-1]) == tuple: + if m: + if (not current_include + or isinstance(current_include[-1], tuple)): current_include.append(i) else: - if len(current_include) > 0 and isinstance(current_include[-1], int): + if current_include and isinstance(current_include[-1], int): current_include[-1] = (current_include[-1], i-1) - if len(current_include) > 0 and isinstance(current_include[-1], int): + if current_include and isinstance(current_include[-1], int): current_include[-1] = (current_include[-1], num_data-1) return current_include @@ -794,21 +948,25 @@ def update_index_ranges(mask): return None, None if current_index_ranges is not None: if not isinstance(current_index_ranges, (tuple, list)): - logger.warning('Invalid current_index_ranges parameter ({current_index_ranges}, '+ - f'{type(current_index_ranges)})') + logger.warning( + 'Invalid current_index_ranges parameter ' + + f'({current_index_ranges}, {type(current_index_ranges)})') return None, None if not isinstance(select_mask, bool): - logger.warning('Invalid select_mask parameter ({select_mask}, {type(select_mask)})') + logger.warning( + f'Invalid select_mask parameter ({select_mask}, ' + + f'{type(select_mask)})') return None, None if num_index_ranges_max is not None: - logger.warning('num_index_ranges_max input not yet implemented in draw_mask_1d') + logger.warning( + 'num_index_ranges_max input not yet implemented in draw_mask_1d') if title is None: title = 'select ranges of data' elif not isinstance(title, str): - illegal(title, 'title') + illegal_value(title, 'title') title = '' if legend is None and not isinstance(title, str): - illegal(legend, 'legend') + illegal_value(legend, 'legend') legend = None if select_mask: @@ -818,7 +976,8 @@ def update_index_ranges(mask): title = f'Click and drag to {title} you wish to exclude' selection_color = 'red' - # Set initial selected mask and the selected/unselected index ranges as needed + # Set initial selected mask and the selected/unselected index + # ranges as needed selected_index_ranges = [] unselected_index_ranges = [] selected_mask = np.full(xdata.shape, False, dtype=bool) @@ -829,17 +988,16 @@ def update_index_ranges(mask): selected_mask = np.full(xdata.shape, True, dtype=bool) else: selected_mask = np.copy(np.asarray(current_mask, dtype=bool)) - if current_index_ranges is not None and len(current_index_ranges): - current_index_ranges = sorted([(low, upp) for (low, upp) in current_index_ranges]) - for (low, upp) in current_index_ranges: + if current_index_ranges is not None and current_index_ranges: + current_index_ranges = sorted(list(current_index_ranges)) + for low, upp in current_index_ranges: if low > upp or low >= num_data or upp < 0: continue - if low < 0: - low = 0 - if upp >= num_data: - upp = num_data-1 + low = max(low, 0) + upp = min(upp, num_data-1) selected_index_ranges.append((low, upp)) - selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges) + selected_mask = update_mask( + selected_mask, selected_index_ranges, unselected_index_ranges) if current_index_ranges is not None and current_mask is not None: selected_mask = np.logical_and(current_mask, selected_mask) if current_mask is not None: @@ -849,7 +1007,7 @@ def update_index_ranges(mask): current_include = selected_index_ranges current_exclude = [] selected_index_ranges = [] - if not len(current_include): + if not current_include: if select_mask: current_exclude = 
[(0, num_data-1)] else: @@ -858,9 +1016,10 @@ def update_index_ranges(mask): if current_include[0][0] > 0: current_exclude.append((0, current_include[0][0]-1)) for i in range(1, len(current_include)): - current_exclude.append((current_include[i-1][1]+1, current_include[i][0]-1)) + current_exclude.append( + (1+current_include[i-1][1], current_include[i][0]-1)) if current_include[-1][1] < num_data-1: - current_exclude.append((current_include[-1][1]+1, num_data-1)) + current_exclude.append((1+current_include[-1][1], num_data-1)) if not test_mode: @@ -868,7 +1027,8 @@ def update_index_ranges(mask): plt.close('all') fig, ax = plt.subplots() plt.subplots_adjust(bottom=0.2) - draw_selections(ax, current_include, current_exclude, selected_index_ranges) + draw_selections( + ax, current_include, current_exclude, selected_index_ranges) # Set up event handling for click-and-drag range selection cid_click = fig.canvas.mpl_connect('button_press_event', onclick) @@ -891,20 +1051,24 @@ def update_index_ranges(mask): # Swap selection depending on select_mask if not select_mask: - selected_index_ranges, unselected_index_ranges = unselected_index_ranges, \ - selected_index_ranges + selected_index_ranges, unselected_index_ranges = \ + unselected_index_ranges, selected_index_ranges # Update the mask with the currently selected/unselected x-ranges - selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges) + selected_mask = update_mask( + selected_mask, selected_index_ranges, unselected_index_ranges) # Update the currently included index ranges (where mask is True) current_include = update_index_ranges(selected_mask) return selected_mask, current_include -def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select array bounds', +def select_image_bounds( + a, axis, low=None, upp=None, num_min=None, title='select array bounds', raise_error=False): - """Interactively select the lower and upper data bounds for a 2D numpy array. + """ + Interactively select the lower and upper data bounds for a 2D + numpy array. 
""" a = np.asarray(a) if a.ndim != 2: @@ -912,7 +1076,9 @@ def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select raise_error=raise_error) return None if axis < 0 or axis >= a.ndim: - illegal_value(axis, 'axis', location='select_image_bounds', raise_error=raise_error) + illegal_value( + axis, 'axis', location='select_image_bounds', + raise_error=raise_error) return None low_save = low upp_save = upp @@ -921,7 +1087,9 @@ def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select num_min = 1 else: if num_min < 2 or num_min > a.shape[axis]: - logger.warning('Invalid input for num_min in select_image_bounds, input ignored') + logger.warning( + 'Invalid input for num_min in select_image_bounds, ' + + 'input ignored') num_min = 1 if low is None: min_ = 0 @@ -934,16 +1102,19 @@ def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select else: quick_imshow(a[min_:max_,:], title=title, aspect='auto', extent=[0,a.shape[1], max_,min_]) - zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y') + zoom_flag = input_yesno( + 'Set lower data bound (y) or zoom in (n)?', 'y') if zoom_flag: low = input_int(' Set lower data bound', ge=0, le=low_max) break - else: - min_ = input_int(' Set lower zoom index', ge=0, le=low_max) - max_ = input_int(' Set upper zoom index', ge=min_+1, le=low_max+1) + min_ = input_int(' Set lower zoom index', ge=0, le=low_max) + max_ = input_int( + ' Set upper zoom index', ge=min_+1, le=low_max+1) else: if not is_int(low, ge=0, le=a.shape[axis]-num_min): - illegal_value(low, 'low', location='select_image_bounds', raise_error=raise_error) + illegal_value( + low, 'low', location='select_image_bounds', + raise_error=raise_error) return None if upp is None: min_ = low+num_min @@ -956,16 +1127,21 @@ def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select else: quick_imshow(a[min_:max_,:], title=title, aspect='auto', extent=[0,a.shape[1], max_,min_]) - zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y') + zoom_flag = input_yesno( + 'Set upper data bound (y) or zoom in (n)?', 'y') if zoom_flag: - upp = input_int(' Set upper data bound', ge=upp_min, le=a.shape[axis]) + upp = input_int( + ' Set upper data bound', ge=upp_min, le=a.shape[axis]) break - else: - min_ = input_int(' Set upper zoom index', ge=upp_min, le=a.shape[axis]-1) - max_ = input_int(' Set upper zoom index', ge=min_+1, le=a.shape[axis]) + min_ = input_int( + ' Set upper zoom index', ge=upp_min, le=a.shape[axis]-1) + max_ = input_int( + ' Set upper zoom index', ge=min_+1, le=a.shape[axis]) else: if not is_int(upp, ge=low+num_min, le=a.shape[axis]): - illegal_value(upp, 'upp', location='select_image_bounds', raise_error=raise_error) + illegal_value( + upp, 'upp', location='select_image_bounds', + raise_error=raise_error) return None bounds = (low, upp) a_tmp = np.copy(a) @@ -980,22 +1156,28 @@ def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select quick_imshow(a_tmp, title=title, aspect='auto') del a_tmp if not input_yesno('Accept these bounds (y/n)?', 'y'): - bounds = select_image_bounds(a, axis, low=low_save, upp=upp_save, num_min=num_min_save, + bounds = select_image_bounds( + a, axis, low=low_save, upp=upp_save, num_min=num_min_save, title=title) clear_imshow(title) return bounds -def select_one_image_bound(a, axis, bound=None, bound_name=None, title='select array bounds', +def select_one_image_bound( + a, axis, bound=None, bound_name=None, title='select array bounds', 
default='y', raise_error=False): - """Interactively select a data boundary for a 2D numpy array. + """ + Interactively select a data boundary for a 2D numpy array. """ a = np.asarray(a) if a.ndim != 2: - illegal_value(a.ndim, 'array dimension', location='select_one_image_bound', - raise_error=raise_error) + illegal_value( + a.ndim, 'array dimension', location='select_one_image_bound', + raise_error=raise_error) return None if axis < 0 or axis >= a.ndim: - illegal_value(axis, 'axis', location='select_one_image_bound', raise_error=raise_error) + illegal_value( + axis, 'axis', location='select_one_image_bound', + raise_error=raise_error) return None if bound_name is None: bound_name = 'data bound' @@ -1010,17 +1192,20 @@ def select_one_image_bound(a, axis, bound=None, bound_name=None, title='select a else: quick_imshow(a[min_:max_,:], title=title, aspect='auto', extent=[0,a.shape[1], max_,min_]) - zoom_flag = input_yesno(f'Set {bound_name} (y) or zoom in (n)?', 'y') + zoom_flag = input_yesno( + f'Set {bound_name} (y) or zoom in (n)?', 'y') if zoom_flag: bound = input_int(f' Set {bound_name}', ge=0, le=bound_max) clear_imshow(title) break - else: - min_ = input_int(' Set lower zoom index', ge=0, le=bound_max) - max_ = input_int(' Set upper zoom index', ge=min_+1, le=bound_max+1) + min_ = input_int(' Set lower zoom index', ge=0, le=bound_max) + max_ = input_int( + ' Set upper zoom index', ge=min_+1, le=bound_max+1) elif not is_int(bound, ge=0, le=a.shape[axis]-1): - illegal_value(bound, 'bound', location='select_one_image_bound', raise_error=raise_error) + illegal_value( + bound, 'bound', location='select_one_image_bound', + raise_error=raise_error) return None else: print(f'Current {bound_name} = {bound}') @@ -1033,11 +1218,13 @@ def select_one_image_bound(a, axis, bound=None, bound_name=None, title='select a quick_imshow(a_tmp, title=title, aspect='auto') del a_tmp if not input_yesno(f'Accept this {bound_name} (y/n)?', default): - bound = select_one_image_bound(a, axis, bound_name=bound_name, title=title) + bound = select_one_image_bound( + a, axis, bound_name=bound_name, title=title) clear_imshow(title) return bound def clear_imshow(title=None): + """Clear an image opened by quick_imshow().""" plt.ioff() if title is None: title = 'quick imshow' @@ -1046,6 +1233,7 @@ def clear_imshow(title=None): plt.close(fig=title) def clear_plot(title=None): + """Clear an image opened by quick_plot().""" plt.ioff() if title is None: title = 'quick plot' @@ -1053,9 +1241,11 @@ def clear_plot(title=None): raise ValueError(f'Invalid parameter title ({title})') plt.close(fig=title) -def quick_imshow(a, title=None, path=None, name=None, save_fig=False, save_only=False, - clear=True, extent=None, show_grid=False, grid_color='w', grid_linewidth=1, - block=False, **kwargs): +def quick_imshow( + a, title=None, path=None, name=None, save_fig=False, save_only=False, + clear=True, extent=None, show_grid=False, grid_color='w', + grid_linewidth=1, block=False, **kwargs): + """Display a 2D image.""" if title is not None and not isinstance(title, str): raise ValueError(f'Invalid parameter title ({title})') if path is not None and not isinstance(path, str): @@ -1071,7 +1261,7 @@ def quick_imshow(a, title=None, path=None, name=None, save_fig=False, save_only= if not title: title='quick imshow' if name is None: - ttitle = re_sub(r"\s+", '_', title) + ttitle = re_sub(r'\s+', '_', title) if path is None: path = f'{ttitle}.png' else: @@ -1081,12 +1271,15 @@ def quick_imshow(a, title=None, path=None, name=None, save_fig=False, 
save_only= path = name else: path = f'{path}/{name}' - if 'cmap' in kwargs and a.ndim == 3 and (a.shape[2] == 3 or a.shape[2] == 4): + if ('cmap' in kwargs and a.ndim == 3 + and (a.shape[2] == 3 or a.shape[2] == 4)): use_cmap = True if a.shape[2] == 4 and a[:,:,-1].min() != a[:,:,-1].max(): use_cmap = False - if any(True if a[i,j,0] != a[i,j,1] and a[i,j,0] != a[i,j,2] else False - for i in range(a.shape[0]) for j in range(a.shape[1])): + if any( + a[i,j,0] != a[i,j,1] and a[i,j,0] != a[i,j,2] + for i in range(a.shape[0]) + for j in range(a.shape[1])): use_cmap = False if use_cmap: a = a[:,:,0] @@ -1121,16 +1314,21 @@ def quick_imshow(a, title=None, path=None, name=None, save_fig=False, save_only= if block: plt.show(block=block) -def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, ylim=None, - xlabel=None, ylabel=None, legend=None, path=None, name=None, show_grid=False, - save_fig=False, save_only=False, clear=True, block=False, **kwargs): +def quick_plot( + *args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, + ylim=None, xlabel=None, ylabel=None, legend=None, path=None, name=None, + show_grid=False, save_fig=False, save_only=False, clear=True, + block=False, **kwargs): + """Display a 2D line plot.""" if title is not None and not isinstance(title, str): illegal_value(title, 'title', 'quick_plot') title = None - if xlim is not None and not isinstance(xlim, (tuple, list)) and len(xlim) != 2: + if (xlim is not None and not isinstance(xlim, (tuple, list)) + and len(xlim) != 2): illegal_value(xlim, 'xlim', 'quick_plot') xlim = None - if ylim is not None and not isinstance(ylim, (tuple, list)) and len(ylim) != 2: + if (ylim is not None and not isinstance(ylim, (tuple, list)) + and len(ylim) != 2): illegal_value(ylim, 'ylim', 'quick_plot') ylim = None if xlabel is not None and not isinstance(xlabel, str): @@ -1163,7 +1361,7 @@ def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, if title is None: title = 'quick plot' if name is None: - ttitle = re_sub(r"\s+", '_', title) + ttitle = re_sub(r'\s+', '_', title) if path is None: path = f'{ttitle}.png' else: @@ -1188,8 +1386,8 @@ def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, plt.ion() plt.figure(title) if depth_tuple(args) > 1: - for y in args: - plt.plot(*y, **kwargs) + for y in args: + plt.plot(*y, **kwargs) else: if xerr is None and yerr is None: plt.plot(*args, **kwargs) @@ -1201,7 +1399,8 @@ def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, for v in vlines: plt.axvline(v, color='r', linestyle='--', **kwargs) # if vlines is not None: -# for s in tuple(([x, x], list(plt.gca().get_ylim())) for x in vlines): +# for s in tuple( +# ([x, x], list(plt.gca().get_ylim())) for x in vlines): # plt.plot(*s, color='red', **kwargs) if xlim is not None: plt.xlim(xlim) @@ -1222,6 +1421,4 @@ def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, else: if save_fig: plt.savefig(path) - if block: - plt.show(block=block) - + plt.show(block=block) diff --git a/CHAP/common/utils/material.py b/CHAP/common/utils/material.py index 60fe4bd..2f9b25b 100755 --- a/CHAP/common/utils/material.py +++ b/CHAP/common/utils/material.py @@ -1,92 +1,100 @@ #!/usr/bin/env python3 - # -*- coding: utf-8 -*- +#pylint: disable= """ -Created on Fri May 27 12:21:25 2022 - -@author: rv43 +File : general.py +Author : Rolf Verberg +Description: Module defining the Material class """ -import logging +# System modules +from logging import 
getLogger +from os import path -import os +# Third party modules import numpy as np try: - import xrayutilities as xu - have_xu = True -except: - have_xu = False - pass + from xrayutilities import materials + from xrayutilities import simpack + HAVE_XU = True +except ImportError: + HAVE_XU = False try: from hexrd import material - have_hexrd = True -except: - have_hexrd = False - pass -if have_hexrd: + HAVE_HEXRD = True +except ImportError: + HAVE_HEXRD = False +if HAVE_HEXRD: try: from hexrd.valunits import valWUnit - except: - have_hexrd = False - pass + except ImportError: + HAVE_HEXRD = False + +POEDER_INTENSITY_CUTOFF = 1.e-8 -powder_intensity_cutoff = 1.e-8 +logger = getLogger(__name__) class Material: - """Base class for materials in an sin2psi or EDD analysis. - Right now it assumes a single material - extend its ability to do differently when test data is available """ - - def __init__(self, material_name=None, material_file=None, sgnum=None, - lattice_parameters_angstroms=None, atoms=None, pos=None, enrgy=None): + Base class for materials in an sin2psi or EDD analysis. Right now + it assumes a single material, extend its ability to do differently + when test data is available + """ + def __init__( + self, material_name=None, material_file=None, sgnum=None, + lattice_parameters_angstroms=None, atoms=None, pos=None, + enrgy=None): + """Initialize Material.""" self._enrgy = enrgy self._materials = [] self._ds_min = [] self._ds_unique = None self._hkls_unique = None if material_name is not None: - self.add_material(material_name, material_file, sgnum, lattice_parameters_angstroms, - atoms, pos) + self.add_material( + material_name, material_file, sgnum, + lattice_parameters_angstroms, atoms, pos) - @property - #FIX passing arguments to a property isn't correct? 
def lattice_parameters(self, index=0): - """Convert from internal nm units to angstrom - """ + """Convert from internal nm units to angstrom.""" matl = self._materials[index] - if isinstance(matl, xu.materials.material.Crystal): + if isinstance(matl, materials.material.Crystal): return [matl.a, matl.b, matl.c] - elif isinstance(matl, material.Material): - return [l.getVal("angstrom") for l in self._materials[index].latticeParameters[0:3]] - else: - raise ValueError('Illegal material class type') - return None + if isinstance(matl, material.Material): + return [l.getVal('angstrom') + for l in self._materials[index].latticeParameters[0:3]] + raise ValueError('Illegal material class type') - @property def ds_unique(self, tth_tol=None, tth_max=None, round_sig=8): + """Return the unique lattice spacings.""" if self._ds_unique is None: - self.get_unique_ds(tth_tol, tth_max, round_sig) + self.get_ds_unique(tth_tol, tth_max, round_sig) return self._ds_unique - @property def hkls_unique(self, tth_tol=None, tth_max=None, round_sig=8): + """Return the unique HKLs.""" if self._hkls_unique is None: - self.get_unique_ds(tth_tol, tth_max, round_sig) + self.get_ds_unique(tth_tol, tth_max, round_sig) return self._hkls_unique - def add_material(self, material_name, material_file=None, sgnum=None, - lattice_parameters_angstroms=None, atoms=None, pos=None, dmin_angstroms=0.6): + def add_material( + self, material_name, material_file=None, sgnum=None, + lattice_parameters_angstroms=None, atoms=None, pos=None, + dmin_angstroms=0.6): + """Add a material.""" # At this point only for a single material - # Unique energies works for more, but fitting with different materials is not implemented + # Unique energies works for more, but fitting with different + # materials is not implemented if len(self._materials) == 1: - exit('Multiple materials not implemented yet') + raise ValueError('Multiple materials not implemented yet') self._ds_min.append(dmin_angstroms) - self._materials.append(Material.make_material(material_name, material_file, sgnum, - lattice_parameters_angstroms, atoms, pos, dmin_angstroms)) + self._materials.append( + Material.make_material(material_name, material_file, sgnum, + lattice_parameters_angstroms, atoms, pos, dmin_angstroms)) - def get_unique_ds(self, tth_tol=None, tth_max=None, round_sig=8): - """Get the list of unique lattice spacings from material HKLs + def get_ds_unique(self, tth_tol=None, tth_max=None, round_sig=8): + """ + Get the list of unique lattice spacings from material HKLs. 
Parameters ---------- @@ -96,123 +104,145 @@ def get_unique_ds(self, tth_tol=None, tth_max=None, round_sig=8): Returns ------- - hkls : list of hkl's corresponding to the unique lattice spacings - ds : list of the unique lattice spacings + hkls: list of hkl's corresponding to the unique lattice spacings + ds: list of the unique lattice spacings """ hkls = np.empty((0,3)) ds = np.empty((0)) - ids = np.empty((0)) - for i,m in enumerate(self._materials): + ds_index = np.empty((0)) + for i, m in enumerate(self._materials): material_class_valid = False - if have_xu: - if isinstance(m, xu.materials.material.Crystal): - powder = xu.simpack.PowderDiffraction(m, en=self._enrgy) - hklsi = [hkl for hkl in powder.data if powder.data[hkl]['active']] - dsi = [m.planeDistance(hkl) for hkl in powder.data if powder.data[hkl]['active']] - mask = [True if d > self._ds_min[i] else False for d in dsi] + if HAVE_XU: + if isinstance(m, materials.material.Crystal): + powder = simpack.PowderDiffraction(m, en=self._enrgy) + hklsi = [hkl for hkl in powder.data + if powder.data[hkl]['active']] + ds_i = [m.planeDistance(hkl) for hkl in powder.data + if powder.data[hkl]['active']] + mask = [d > self._ds_min[i] for d in ds_i] hkls = np.vstack((hkls, np.array(hklsi)[mask,:])) - dsi = np.array(dsi)[mask] + ds_i = np.array(ds_i)[mask] material_class_valid = True - if have_hexrd: + if HAVE_HEXRD: if isinstance(m, material.Material): - pd = m.planeData + plane_data = m.planeData if tth_tol is not None: - pd.tThWidth = np.radians(tth_tol) + plane_data.tThWidth = np.radians(tth_tol) if tth_max is not None: - pd.exclusions = None - pd.tThMax = np.radians(tth_max) - hkls = np.vstack((hkls, pd.hkls.T)) - dsi = pd.getPlaneSpacings() + plane_data.exclusions = None + plane_data.tThMax = np.radians(tth_max) + hkls = np.vstack((hkls, plane_data.hkls.T)) + ds_i = plane_data.getPlaneSpacings() material_class_valid = True if not material_class_valid: raise ValueError('Illegal material class type') - ds = np.hstack((ds, dsi)) - ids = np.hstack((ids, i*np.ones(len(dsi)))) + ds = np.hstack((ds, ds_i)) + ds_index = np.hstack((ds_index, i*np.ones(len(ds_i)))) # Sort lattice spacings in reverse order (use -) - ds_unique, ids_unique, _ = np.unique(-ds.round(round_sig), return_index=True, - return_counts=True) + ds_unique, ds_index_unique, _ = np.unique( + -ds.round(round_sig), return_index=True, return_counts=True) ds_unique = np.abs(ds_unique) # Limit the list to unique lattice spacings - self._hkls_unique = hkls[ids_unique,:].astype(int) - self._ds_unique = ds[ids_unique] - hkl_list = np.vstack((np.arange(self._ds_unique.shape[0]), ids[ids_unique], - self._hkls_unique.T, self._ds_unique)).T - logging.info("Unique d's:") + self._hkls_unique = hkls[ds_index_unique,:].astype(int) + self._ds_unique = ds[ds_index_unique] + hkl_list = np.vstack( + (np.arange(self._ds_unique.shape[0]), ds_index[ds_index_unique], + self._hkls_unique.T, self._ds_unique)).T + logger.info("Unique d's:") for hkl in hkl_list: - logging.info(f'{hkl[0]:4.0f} {hkl[1]:.0f} {hkl[2]:.0f} {hkl[3]:.0f} {hkl[4]:.0f} '+ - f'{hkl[5]:.6f}') + logger.info( + f'{hkl[0]:4.0f} {hkl[1]:.0f} {hkl[2]:.0f} {hkl[3]:.0f} ' + + f'{hkl[4]:.0f} {hkl[5]:.6f}') return self._hkls_unique, self._ds_unique @staticmethod - def make_material(material_name, material_file=None, sgnum=None, - lattice_parameters_angstroms=None, atoms=None, pos=None, dmin_angstroms=0.6): - """ Use HeXRD to get material properties when a materials file is provided - Use xrayutilities otherwise + def make_material( + 
material_name, material_file=None, sgnum=None, + lattice_parameters_angstroms=None, atoms=None, pos=None, + dmin_angstroms=0.6): + """ + Use HeXRD to get material properties when a materials file is + provided. Use xrayutilities otherwise. """ if not isinstance(material_name, str): - raise ValueError(f'Illegal material_name: {material_name} {type(material_name)}') + raise ValueError( + f'Illegal material_name: {material_name} ' + + f'{type(material_name)}') if lattice_parameters_angstroms is not None: if material_file is not None: - logging.warning('Overwrite lattice_parameters of material_file with input values '+ - f'({lattice_parameters_angstroms})') + logger.warning( + 'Overwrite lattice_parameters of material_file with input ' + + f'values ({lattice_parameters_angstroms})') if isinstance(lattice_parameters_angstroms, (int, float)): lattice_parameters = [lattice_parameters_angstroms] - elif isinstance(lattice_parameters_angstroms, (tuple, list, np.ndarray)): + elif isinstance( + lattice_parameters_angstroms, (tuple, list, np.ndarray)): lattice_parameters = list(lattice_parameters_angstroms) else: - raise ValueError(f'Illegal lattice_parameters_angstroms: '+ - f'{lattice_parameters_angstroms} {type(lattice_parameters_angstroms)}') + raise ValueError('Illegal lattice_parameters_angstroms: ' + + f'{lattice_parameters_angstroms} ' + + f'{type(lattice_parameters_angstroms)}') if material_file is None: if not isinstance(sgnum, int): raise ValueError(f'Illegal sgnum: {sgnum} {type(sgnum)}') - if sgnum is None or lattice_parameters_angstroms is None or pos is None: - raise ValueError('Valid inputs for sgnum and lattice_parameters_angstroms and '+ - 'pos are required if materials file is not specified') + if (sgnum is None or lattice_parameters_angstroms is None + or pos is None): + raise ValueError( + 'Valid inputs for sgnum, lattice_parameters_angstroms and ' + + 'pos are required if materials file is not specified') if isinstance(pos, str): pos = [pos] use_xu = True - if (np.array(pos).ndim == 1 and isinstance(pos[0], (int, float)) and - np.array(pos).size == 3): - if have_hexrd: + if (np.array(pos).ndim == 1 and isinstance(pos[0], (int, float)) + and np.array(pos).size == 3): + if HAVE_HEXRD: pos = np.array([pos]) use_xu = False elif (np.array(pos).ndim == 2 and np.array(pos).shape[0] > 0 and np.array(pos).shape[1] == 3): - if have_hexrd: + if HAVE_HEXRD: pos = np.array(pos) use_xu = False elif not (np.array(pos).ndim == 1 and isinstance(pos[0], str) and - np.array(pos).size > 0 and have_xu): - raise ValueError(f'Illegal pos (have_xu = {have_xu}): {pos} {type(pos)}') + np.array(pos).size > 0 and HAVE_XU): + raise ValueError( + f'Illegal pos (HAVE_XU = {HAVE_XU}): {pos} {type(pos)}') if use_xu: if atoms is None: atoms = [material_name] - matl = xu.materials.Crystal(material_name, xu.materials.SGLattice(sgnum, - *lattice_parameters, atoms=atoms, pos=[p for p in np.array(pos)])) + matl = materials.Crystal( + material_name, materials.SGLattice(sgnum, + *lattice_parameters, atoms=atoms, + pos=list(np.array(pos)))) else: matl = material.Material(material_name) matl.sgnum = sgnum matl.atominfo = np.vstack((pos.T, np.ones(pos.shape[0]))).T matl.latticeParameters = lattice_parameters - matl.dmin = valWUnit('lp', 'length', dmin_angstroms, 'angstrom') + matl.dmin = valWUnit( + 'lp', 'length', dmin_angstroms, 'angstrom') exclusions = matl.planeData.get_exclusions() powder_intensity = matl.planeData.powder_intensity - exclusions = [True if exclusion or i >= len(powder_intensity) or - powder_intensity[i] 
< powder_intensity_cutoff else False - for i, exclusion in enumerate(exclusions)] + exclusions = [exclusion or i >= len(powder_intensity) or + powder_intensity[i] < POEDER_INTENSITY_CUTOFF + for i, exclusion in enumerate(exclusions)] matl.planeData.set_exclusions(exclusions) - logging.debug(f'powder_intensity = {matl.planeData.powder_intensity}') - logging.debug(f'exclusions = {matl.planeData.exclusions}') + logger.debug( + f'powder_intensity = {matl.planeData.powder_intensity}') + logger.debug(f'exclusions = {matl.planeData.exclusions}') else: - if not have_hexrd: - raise ValueError('Illegal inputs: must provide detailed material info when '+ - 'hexrd package is unavailable') + if not HAVE_HEXRD: + raise ValueError( + 'Illegal inputs: must provide detailed material info when ' + + 'hexrd package is unavailable') if sgnum is not None: - logging.warning('Ignore sgnum input when material_file is specified') - if not (os.path.splitext(material_file)[1] in ('.h5', '.hdf5', '.xtal', '.cif')): + logger.warning( + 'Ignore sgnum input when material_file is specified') + if not (path.splitext(material_file)[1] in + ('.h5', '.hdf5', '.xtal', '.cif')): raise ValueError(f'Illegal material file {material_file}') matl = material.Material(material_name, material_file, dmin=valWUnit('lp', 'length', dmin_angstroms, 'angstrom')) @@ -220,12 +250,12 @@ def make_material(material_name, material_file=None, sgnum=None, matl.latticeParameters = lattice_parameters exclusions = matl.planeData.get_exclusions() powder_intensity = matl.planeData.powder_intensity - exclusions = [True if exclusion or i >= len(powder_intensity) or - powder_intensity[i] < powder_intensity_cutoff else False + exclusions = [exclusion or i >= len(powder_intensity) or + powder_intensity[i] < POEDER_INTENSITY_CUTOFF for i, exclusion in enumerate(exclusions)] matl.planeData.set_exclusions(exclusions) - logging.debug(f'powder_intensity = {matl.planeData.powder_intensity}') - logging.debug(f'exclusions = {matl.planeData.exclusions}') + logger.debug( + f'powder_intensity = {matl.planeData.powder_intensity}') + logger.debug(f'exclusions = {matl.planeData.exclusions}') return matl - diff --git a/CHAP/tomo/__init__.py b/CHAP/tomo/__init__.py index eb545e5..8898633 100644 --- a/CHAP/tomo/__init__.py +++ b/CHAP/tomo/__init__.py @@ -1,5 +1,2 @@ -# from CHAP.tomo.reader import -from CHAP.tomo.processor import TomoDataProcessor -# from CHAP.tomo.writer import - from CHAP.common import MapProcessor +from CHAP.tomo.processor import TomoDataProcessor diff --git a/CHAP/tomo/models.py b/CHAP/tomo/models.py index 60a699b..3e2a7e6 100644 --- a/CHAP/tomo/models.py +++ b/CHAP/tomo/models.py @@ -1,15 +1,18 @@ -# system modules +'''Tomography Pydantic model classes''' -# third party imports +# Third party imports +from typing import ( + Literal, + Optional, +) from pydantic import ( - BaseModel, - StrictBool, - conint, - conlist, - confloat, - constr, + BaseModel, + StrictBool, + conint, + conlist, + confloat, + constr, ) -from typing import Literal, Optional class Detector(BaseModel): @@ -25,21 +28,26 @@ class Detector(BaseModel): :ivar pixel_size: Pixel size of the detector in mm :type pixel_size: int or list[int] :ivar lens_magnification: Lens magnification for the detector - :type lens_magnification: float, optional + :type lens_magnification: float, optional [1.0] """ prefix: constr(strip_whitespace=True, min_length=1) rows: conint(gt=0) columns: conint(gt=0) - pixel_size: conlist(item_type=confloat(gt=0, allow_inf_nan=False), min_items=1, max_items=2) + 
pixel_size: conlist(item_type=confloat(gt=0, allow_inf_nan=False), + min_items=1, max_items=2) lens_magnification: confloat(gt=0, allow_inf_nan=False) = 1.0 class TomoSetupConfig(BaseModel): """ - Class representing the configuration for the tomography reconstruction setup. + Class representing the configuration for the tomography + reconstruction setup. :ivar detectors: Detector used in the tomography experiment :type detectors: Detector + :ivar include_raw_data: Flag to designate whether raw data will be + included (True) or not (False) + :type include_raw_data: bool, optional [False] """ detector: Detector.construct() include_raw_data: Optional[StrictBool] = False @@ -47,9 +55,11 @@ class TomoSetupConfig(BaseModel): class TomoReduceConfig(BaseModel): """ - Class representing the configuration for tomography image reductions. + Class representing the configuration for tomography image + reductions. - :ivar tool_type: Type of tomography reconstruction tool; always set to "reduce_data" + :ivar tool_type: Type of tomography reconstruction tool; always set + to "reduce_data" :type tool_type: str, optional :ivar detectors: Detector used in the tomography experiment :type detectors: Detector @@ -58,16 +68,19 @@ class TomoReduceConfig(BaseModel): """ tool_type: Literal['reduce_data'] = 'reduce_data' detector: Detector = Detector.construct() - img_x_bounds: Optional[conlist(item_type=conint(ge=0), min_items=2, max_items=2)] + img_x_bounds: Optional[conlist(item_type=conint(ge=0), min_items=2, + max_items=2)] class TomoFindCenterConfig(BaseModel): """ Class representing the configuration for tomography find center axis. - :ivar tool_type: Type of tomography reconstruction tool; always set to "find_center" + :ivar tool_type: Type of tomography reconstruction tool; always set + to "find_center" :type tool_type: str, optional - :ivar center_stack_index: Stack index of tomography set to find center axis (offset 1) + :ivar center_stack_index: Stack index of tomography set to find + center axis (offset 1) :type center_stack_index: int, optional :ivar lower_row: Lower row index for center finding :type lower_row: int, optional @@ -88,9 +101,11 @@ class TomoFindCenterConfig(BaseModel): class TomoReconstructConfig(BaseModel): """ - Class representing the configuration for tomography image reconstruction. + Class representing the configuration for tomography image + reconstruction. - :ivar tool_type: Type of tomography reconstruction tool; always set to "reconstruct_data" + :ivar tool_type: Type of tomography reconstruction tool; always set + to "reconstruct_data" :type tool_type: str, optional :ivar x_bounds: Reconstructed image bounds in the x-direction :type x_bounds: list[int], optional @@ -100,26 +115,32 @@ class TomoReconstructConfig(BaseModel): :type z_bounds: list[int], optional """ tool_type: Literal['reconstruct_data'] = 'reconstruct_data' - x_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, max_items=2)] - y_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, max_items=2)] - z_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, max_items=2)] + x_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, + max_items=2)] + y_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, + max_items=2)] + z_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, + max_items=2)] class TomoCombineConfig(BaseModel): """ Class representing the configuration for combined tomography stacks. 
- :ivar tool_type: Type of tomography reconstruction tool; always set to "combine_data" + :ivar tool_type: Type of tomography reconstruction tool; always set + to "combine_data" :type tool_type: str, optional - :ivar x_bounds: Reconstructed image bounds in the x-direction + :ivar x_bounds: Combined image bounds in the x-direction :type x_bounds: list[int], optional - :ivar y_bounds: Reconstructed image bounds in the y-direction + :ivar y_bounds: Combined image bounds in the y-direction :type y_bounds: list[int], optional - :ivar z_bounds: Reconstructed image bounds in the z-direction + :ivar z_bounds: Combined image bounds in the z-direction :type z_bounds: list[int], optional """ tool_type: Literal['combine_data'] = 'combine_data' - x_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, max_items=2)] - y_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, max_items=2)] - z_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, max_items=2)] - + x_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, + max_items=2)] + y_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, + max_items=2)] + z_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, + max_items=2)] diff --git a/CHAP/tomo/processor.py b/CHAP/tomo/processor.py index 04b5463..bac852c 100644 --- a/CHAP/tomo/processor.py +++ b/CHAP/tomo/processor.py @@ -1,60 +1,94 @@ #!/usr/bin/env python #-*- coding: utf-8 -*- #pylint: disable= -''' +""" File : processor.py Author : Rolf Verberg Description: Module for Processors used only by tomography experiments -''' +""" -# system modules +# System modules from os import mkdir from os import path as os_path +from sys import exit as sys_exit from time import time -# third party modules -from nexusformat.nexus import NXobject +# Third party modules import numpy as np -# local modules -from CHAP.common.utils.general import input_int, input_num, input_yesno, select_image_bounds, \ - select_one_image_bound, draw_mask_1d, clear_plot, clear_imshow, quick_plot, quick_imshow +# Local modules +from CHAP.common.utils.general import ( + is_num, + input_int, +# input_num, + input_yesno, + select_image_bounds, + select_one_image_bound, + draw_mask_1d, + clear_plot, + clear_imshow, + quick_plot, + quick_imshow, +) from CHAP.common.utils.fit import Fit from CHAP.processor import Processor +from CHAP.reader import main -num_core_tomopy_limit = 24 +NUM_CORE_TOMOPY_LIMIT = 24 class TomoDataProcessor(Processor): - '''Class representing the processes to reconstruct a set of Tomographic images returning - either a dictionary or a `nexusformat.nexus.NXroot` object containing the (meta) data after - processing each individual step. - ''' - - def _process(self, data, interactive, reduce_data=False, find_center=False, - reconstruct_data=False, combine_data=False, output_folder='.', save_figs=None): - '''Process the output of a `Reader` that contains a map or a `nexusformat.nexus.NXroot` - object and one that contains the step specific instructions and return either a dictionary - or a `nexusformat.nexus.NXroot` returning the processed result. - - :param data: Result of `Reader.read` where at least one item is of type - `nexusformat.nexus.NXroot` or has the value `'MapConfig'` for the `'schema'` key, - and at least one item has the value `'TomoReduceConfig'` for the `'schema'` key. 
+    """
+    Class representing the processes to reconstruct a set of Tomographic
+    images returning either a dictionary or a `nexusformat.nexus.NXroot`
+    object containing the (meta) data after processing each individual
+    step.
+    """
+
+    def _process(
+            self, data, interactive=False, reduce_data=False,
+            find_center=False, reconstruct_data=False, combine_data=False,
+            output_folder='.', save_figs=None, **kwargs):
+        """
+        Process the output of a `Reader` that contains a map or a
+        `nexusformat.nexus.NXroot` object and one that contains the step
+        specific instructions and return either a dictionary or a
+        `nexusformat.nexus.NXroot` returning the processed result.
+
+        :param data: Result of `Reader.read`
         :type data: list[dict[str,object]]
+        :param interactive: Allows interactive actions
+        :type interactive: bool, optional [False]
+        :param reduce_data: Generate reduced tomography images
+        :type reduce_data: bool, optional [False]
+        :param find_center: Find the calibrated center axis info
+        :type find_center: bool, optional [False]
+        :param reconstruct_data: Reconstruct the tomography data
+        :type reconstruct_data: bool, optional [False]
+        :param combine_data: Combine the reconstructed tomography stacks
+        :type combine_data: bool, optional [False]
+        :param output_folder: Output folder name
+        :type output_folder: str, optional ['.']
+        :param save_figs: Display and/or save figures to file
+        :type save_figs: str, optional
         :return: processed (meta)data
         :rtype: dict or nexusformat.nexus.NXroot
-        '''
+        """
         if not isinstance(reduce_data, bool):
             raise ValueError(f'Invalid parameter reduce_data ({reduce_data})')
         if not isinstance(find_center, bool):
             raise ValueError(f'Invalid parameter find_center ({find_center})')
         if not isinstance(reconstruct_data, bool):
-            raise ValueError(f'Invalid parameter reconstruct_data ({reconstruct_data})')
+            raise ValueError(
+                f'Invalid parameter reconstruct_data ({reconstruct_data})')
         if not isinstance(combine_data, bool):
-            raise ValueError(f'Invalid parameter combine_data ({combine_data})')
+            raise ValueError(
+                f'Invalid parameter combine_data ({combine_data})')
 
-        tomo = Tomo(interactive=interactive, output_folder=output_folder, save_figs=save_figs)
+        tomo = Tomo(
+            interactive=interactive, output_folder=output_folder,
+            save_figs=save_figs)
 
         nxroot = None
         center_config = None
@@ -76,15 +110,16 @@ def _process(self, data, interactive, reduce_data=False, find_center=False,
         else:
             img_x_bounds = None
         if nxroot is None:
-            raise RuntimeError('Unable to reduce the data without providing a '+
-                               'reduced_data config file')
+            raise RuntimeError(
+                'Unable to reduce the data without providing a ' +
+                'reduced_data config file')
         if nxroot is None:
             map_config = configs.pop('map')
             nxroot = self.get_nxroot(map_config, tool_config)
         nxroot = tomo.gen_reduced_data(nxroot, img_x_bounds=img_x_bounds)
 
         # Find rotation axis centers for the tomography stacks
-        # Pass tool_config directly to tomo.find_centers?
+        # RV pass tool_config directly to tomo.find_centers?
if find_center or 'find_center' in configs: if 'find_center' in configs: tool_config = configs.pop('find_center') @@ -97,19 +132,24 @@ def _process(self, data, interactive, reduce_data=False, find_center=False, lower_center_offset = None upper_center_offset = None center_stack_index = None - if None in center_rows or lower_center_offset is None or upper_center_offset is None: - center_config = tomo.find_centers(nxroot, center_rows=center_rows, - center_stack_index=center_stack_index) + if (None in center_rows or lower_center_offset is None + or upper_center_offset is None): + center_config = tomo.find_centers( + nxroot, center_rows=center_rows, + center_stack_index=center_stack_index) else: - #RV make a convert to dict in basemodel? - center_config = {'lower_row': tool_config.lower_row, - 'lower_center_offset': tool_config.lower_center_offset, - 'upper_row': tool_config.upper_row, - 'upper_center_offset': tool_config.upper_center_offset, - 'center_stack_index': tool_config.center_stack_index} + # RV make a convert to dict in basemodel? + center_config = { + 'lower_row': tool_config.lower_row, + 'lower_center_offset': tool_config.lower_center_offset, + 'upper_row': tool_config.upper_row, + 'upper_center_offset': tool_config.upper_center_offset, + 'center_stack_index': tool_config.center_stack_index, + } # Reconstruct tomography stacks - # Pass tool_config and center_config directly to tomo.reconstruct_data + # RV pass tool_config and center_config directly to + # tomo.reconstruct_data? if reconstruct_data or 'reconstruct' in configs: if 'reconstruct' in configs: tool_config = configs.pop('reconstruct') @@ -120,8 +160,9 @@ def _process(self, data, interactive, reduce_data=False, find_center=False, x_bounds = None y_bounds = None z_bounds = None - nxroot = tomo.reconstruct_data(nxroot, center_config, x_bounds=x_bounds, - y_bounds=y_bounds, z_bounds=z_bounds) + nxroot = tomo.reconstruct_data( + nxroot, center_config, x_bounds=x_bounds, y_bounds=y_bounds, + z_bounds=z_bounds) center_config = None # Combine reconstructed tomography stacks @@ -135,16 +176,17 @@ def _process(self, data, interactive, reduce_data=False, find_center=False, x_bounds = None y_bounds = None z_bounds = None - nxroot = tomo.combine_data(nxroot, x_bounds=x_bounds, y_bounds=y_bounds, - z_bounds=z_bounds) + nxroot = tomo.combine_data( + nxroot, x_bounds=x_bounds, y_bounds=y_bounds, + z_bounds=z_bounds) if center_config is not None: return center_config - else: - return nxroot + return nxroot def get_configs(self, data): - '''Get instances of the configuration objects needed by this + """ + Get instances of the configuration objects needed by this `Processor` from a returned value of `Reader.read` :param data: Result of `Reader.read` where at least one item @@ -159,13 +201,22 @@ def get_configs(self, data): :return: valid instances of the configuration objects with field values taken from `data`. :rtype: dict - ''' - #:rtype: dict{'map': MapConfig, 'reduce': TomoReduceConfig} RV: Is there a way to denote optional items? - from CHAP.common.models.map import MapConfig - from CHAP.tomo.models import TomoSetupConfig, TomoReduceConfig, TomoFindCenterConfig, \ - TomoReconstructConfig, TomoCombineConfig + """ + #:rtype: dict{'map': MapConfig, 'reduce': TomoReduceConfig} + # RV is there a way to denote optional items? 
+ # Third party modules from nexusformat.nexus import NXroot + # Local modules + from CHAP.common.models.map import MapConfig + from CHAP.tomo.models import ( + TomoSetupConfig, + TomoReduceConfig, + TomoFindCenterConfig, + TomoReconstructConfig, + TomoCombineConfig, + ) + configs = {} if isinstance(data, list): for item in data: @@ -176,38 +227,59 @@ def get_configs(self, data): if schema == 'MapConfig': configs['map'] = MapConfig(**(item.get('data'))) if schema == 'TomoSetupConfig': - configs['setup'] = TomoSetupConfig(**(item.get('data'))) + configs['setup'] = TomoSetupConfig( + **(item.get('data'))) if schema == 'TomoReduceConfig': - configs['reduce'] = TomoReduceConfig(**(item.get('data'))) + configs['reduce'] = TomoReduceConfig( + **(item.get('data'))) elif schema == 'TomoFindCenterConfig': - configs['find_center'] = TomoFindCenterConfig(**(item.get('data'))) + configs['find_center'] = TomoFindCenterConfig( + **(item.get('data'))) elif schema == 'TomoReconstructConfig': - configs['reconstruct'] = TomoReconstructConfig(**(item.get('data'))) + configs['reconstruct'] = TomoReconstructConfig( + **(item.get('data'))) elif schema == 'TomoCombineConfig': - configs['combine'] = TomoCombineConfig(**(item.get('data'))) + configs['combine'] = TomoCombineConfig( + **(item.get('data'))) return configs def get_nxroot(self, map_config, tool_config): - '''Get a map of the collected tomography data from the scans in `map_config`. The - data will be reduced based on additional parameters included in `tool_config`. - The data will be returned along with relevant metadata in the form of a NeXus structure. + """ + Get a map of the collected tomography data from the scans in + `map_config`. The data will be reduced based on additional + parameters included in `tool_config`. The data will be returned + along with relevant metadata in the form of a NeXus structure. 
:param map_config: the map configuration :type map_config: MapConfig :param tool_config: the tomography image reduction configuration :type tool_config: TomoReduceConfig - :return: a map of the collected tomography data along with the data reduction configuration + :return: a map of the collected tomography data along with the + data reduction configuration :rtype: nexusformat.nexus.NXroot - ''' + """ + # System modules + from copy import deepcopy + + # Third party modules + from nexusformat.nexus import ( + NXcollection, + NXdata, + NXdetector, + NXinstrument, + NXroot, + NXsample, + NXsource, + NXsubentry, + ) + + # Local modules from CHAP.common import MapProcessor from CHAP.common.models.map import import_scanparser from CHAP.common.utils.general import index_nearest - from copy import deepcopy - from nexusformat.nexus import NXcollection, NXdata, NXdetector, NXinstrument, NXsample, \ - NXsource, NXsubentry, NXroot - include_raw_data = getattr(tool_config, "include_raw_data", False) + include_raw_data = getattr(tool_config, 'include_raw_data', False) # Construct NXroot nxroot = NXroot() @@ -237,29 +309,37 @@ def get_nxroot(self, map_config, tool_config): nxsource.attrs['station'] = map_config.station nxsource.attrs['experiment_type'] = map_config.experiment_type - # Add an NXdetector to the NXinstrument (don't fill in data fields yet) + # Add an NXdetector to the NXinstrument + # (do not fill in data fields yet) nxdetector = NXdetector() nxinstrument.detector = nxdetector nxdetector.local_name = tool_config.detector.prefix pixel_size = tool_config.detector.pixel_size if len(pixel_size) == 1: - nxdetector.x_pixel_size = pixel_size[0]/tool_config.detector.lens_magnification - nxdetector.y_pixel_size = pixel_size[0]/tool_config.detector.lens_magnification + nxdetector.x_pixel_size = \ + pixel_size[0]/tool_config.detector.lens_magnification + nxdetector.y_pixel_size = \ + pixel_size[0]/tool_config.detector.lens_magnification else: - nxdetector.x_pixel_size = pixel_size[0]/tool_config.detector.lens_magnification - nxdetector.y_pixel_size = pixel_size[1]/tool_config.detector.lens_magnification + nxdetector.x_pixel_size = \ + pixel_size[0]/tool_config.detector.lens_magnification + nxdetector.y_pixel_size = \ + pixel_size[1]/tool_config.detector.lens_magnification nxdetector.x_pixel_size.attrs['units'] = 'mm' nxdetector.y_pixel_size.attrs['units'] = 'mm' if include_raw_data: - # Add an NXsample to NXentry (don't fill in data fields yet) + # Add an NXsample to NXentry + # (do not fill in data fields yet) nxsample = NXsample() nxentry.sample = nxsample nxsample.name = map_config.sample.name nxsample.description = map_config.sample.description - # Add NXcollection's to NXentry to hold metadata about the spec scans in the map - # Also obtain the data fields in NXsample and NXdetector if requested + # Add NXcollection's to NXentry to hold metadata about the spec + # scans in the map + # Also obtain the data fields in NXsample and NXdetector if + # requested import_scanparser(map_config.station, map_config.experiment_type) image_keys = [] sequence_numbers = [] @@ -288,22 +368,24 @@ def get_nxroot(self, map_config, tool_config): image_key = 2 field_name = 'dark_field' elif scans.spec_file.endswith('_flat'): - #RV not yet tested with an actual fmb run + # RV not yet tested with an actual fmb run image_key = 1 field_name = 'bright_field' else: image_key = 0 field_name = 'tomo_fields' else: - raise RuntimeError(f'Invalid station: {station}') + raise RuntimeError( + f'Invalid station in map_config: 
{map_config.station}') # Create an NXcollection for each field type if field_name in nxentry.spec_scans: nxcollection = nxentry.spec_scans[field_name] if nxcollection.attrs['spec_file'] != str(scans.spec_file): - raise RuntimeError(f'Multiple SPEC files for a single field type not yet '+ - f'implemented; field name: {field_name}, '+ - f'SPEC file: {str(scans.spec_file)}') + raise RuntimeError( + 'Multiple SPEC files for a single field type not ' + + f'yet implemented; field name: {field_name}, ' + + f'SPEC file: {str(scans.spec_file)}') else: nxcollection = NXcollection() nxentry.spec_scans[field_name] = nxcollection @@ -314,26 +396,31 @@ def get_nxroot(self, map_config, tool_config): image_offset = scanparser.starting_image_offset if map_config.station in ('id1a3', 'id3a'): theta_vals = scanparser.theta_vals - thetas = np.linspace(theta_vals.get('start'), theta_vals.get('end'), - theta_vals.get('num')) + thetas = np.linspace(theta_vals.get('start'), + theta_vals.get('end'), theta_vals.get('num')) else: if len(scans.scan_numbers) != 1: - raise RuntimeError('Multiple scans not yet implemented for '+ - f'{map_config.station}') + raise RuntimeError( + 'Multiple scans not yet implemented for ' + + f'{map_config.station}') scan_number = scans.scan_numbers[0] thetas = [] for dim in map_config.independent_dimensions: if dim.label != 'theta': continue for index in range(scanparser.spec_scan_npts): - thetas.append(dim.get_value(scans, scan_number, index)) - if not len(thetas): - raise RuntimeError(f'Unable to obtain thetas for {field_name}') + thetas.append( + dim.get_value(scans, scan_number, index)) + if not thetas: + raise RuntimeError( + f'Unable to obtain thetas for {field_name}') if thetas[image_offset] <= 0.0 and thetas[-1] >= 180.0: image_offset = index_nearest(thetas, 0.0) - thetas = thetas[image_offset:index_nearest(thetas, 180.0)] + image_end = index_nearest(thetas, 180.0) + thetas = thetas[image_offset:image_end] elif thetas[-1]-thetas[image_offset] >= 180: - thetas = thetas[image_offset:index_nearest(thetas, thetas[0]+180.0)] + image_end = index_nearest(thetas, thetas[0]+180.0) + thetas = thetas[image_offset:image_end] else: thetas = thetas[image_offset:] @@ -349,9 +436,11 @@ def get_nxroot(self, map_config, tool_config): nxsubentry.spec_command = scanparser.spec_command # Add an NXinstrument to the scan's NXsubentry nxsubentry.instrument = NXinstrument() - # Add an NXdetector to the NXinstrument to the scan's NXsubentry + # Add an NXdetector to the NXinstrument to the scan's + # NXsubentry nxsubentry.instrument.detector = deepcopy(nxdetector) - nxsubentry.instrument.detector.frame_start_number = image_offset + nxsubentry.instrument.detector.frame_start_number = \ + image_offset nxsubentry.instrument.detector.image_key = image_key # Add an NXsample to the scan's NXsubentry nxsubentry.sample = NXsample() @@ -366,8 +455,11 @@ def get_nxroot(self, map_config, tool_config): num_image = len(thetas) image_keys += num_image*[image_key] sequence_numbers += list(range(num_image)) - image_stacks.append(scanparser.get_detector_data(tool_config.detector.prefix, - scan_step_index=(image_offset, image_offset+num_image))) + image_stacks.append( + scanparser.get_detector_data( + tool_config.detector.prefix, + scan_step_index=(image_offset, + image_offset+num_image))) rotation_angles += list(thetas) x_translations += num_image*[x_translation] z_translations += num_image*[z_translation] @@ -376,7 +468,7 @@ def get_nxroot(self, map_config, tool_config): # Add image data to NXdetector 
nxinstrument.detector.image_key = image_keys nxinstrument.detector.sequence_number = sequence_numbers - nxinstrument.detector.data = np.concatenate([image for image in image_stacks]) + nxinstrument.detector.data = np.concatenate(image_stacks) # Add image data to NXsample nxsample.rotation_angle = rotation_angles @@ -399,11 +491,13 @@ def get_nxroot(self, map_config, tool_config): # nxdata.attrs['row_indices'] = 1 # nxdata.attrs['column_indices'] = 2 - return(nxroot) + return nxroot -def nxcopy(nxobject:NXobject, exclude_nxpaths:list[str]=[], nxpath_prefix:str='') -> NXobject: - '''Function that returns a copy of a nexus object, optionally exluding certain child items. +def nxcopy(nxobject, exclude_nxpaths=None, nxpath_prefix=''): + """ + Function that returns a copy of a nexus object, optionally exluding + certain child items. :param nxobject: the original nexus object to return a "copy" of :type nxobject: nexusformat.nexus.NXobject @@ -415,17 +509,20 @@ def nxcopy(nxobject:NXobject, exclude_nxpaths:list[str]=[], nxpath_prefix:str='' :type nxpath_prefix: str :return: a copy of `nxobject` with some children optionally exluded. :rtype: NXobject - ''' + """ + # Third party modules from nexusformat.nexus import NXgroup nxobject_copy = nxobject.__class__() - if not len(nxpath_prefix): + if not nxpath_prefix: if 'default' in nxobject.attrs: nxobject_copy.attrs['default'] = nxobject.attrs['default'] else: for k, v in nxobject.attrs.items(): nxobject_copy.attrs[k] = v + if exclude_nxpaths is None: + exclude_nxpaths = [] for k, v in nxobject.items(): nxpath = os_path.join(nxpath_prefix, k) @@ -433,43 +530,61 @@ def nxcopy(nxobject:NXobject, exclude_nxpaths:list[str]=[], nxpath_prefix:str='' continue if isinstance(v, NXgroup): - nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths, - nxpath_prefix=os_path.join(nxpath_prefix, k)) + nxobject_copy[k] = nxcopy( + v, exclude_nxpaths=exclude_nxpaths, + nxpath_prefix=os_path.join(nxpath_prefix, k)) else: nxobject_copy[k] = v - return(nxobject_copy) + return nxobject_copy + + +class SetNumexprThreads: + """ + Class that sets and keeps track of the number of processors used by + the code in general and by the num_expr package specifically. + :ivar num_core: Number of processors used by the num_expr package + :type num_core: int + """ -class set_numexpr_threads: def __init__(self, num_core): + """Initialize SetNumexprThreads.""" + # System modules from multiprocessing import cpu_count if num_core is None or num_core < 1 or num_core > cpu_count(): self._num_core = cpu_count() else: self._num_core = num_core + self._num_core_org = self._num_core def __enter__(self): - import numexpr as ne + # Third party modules + from numexpr import ( + MAX_THREADS, + set_num_threads, + ) - self._num_core_org = ne.set_num_threads(min(self._num_core, ne.MAX_THREADS)) + self._num_core_org = set_num_threads( + min(self._num_core, MAX_THREADS)) def __exit__(self, exc_type, exc_value, traceback): - import numexpr as ne + # Third party modules + from numexpr import set_num_threads - ne.set_num_threads(self._num_core_org) + set_num_threads(self._num_core_org) class Tomo: - """Processing tomography data with misalignment. 
- """ - def __init__(self, interactive=False, num_core=-1, output_folder='.', save_figs=None, - test_mode=False): - """Initialize with optional config input file or dictionary - """ - from logging import getLogger + """Reconstruct a set of Tomographic images.""" + def __init__( + self, interactive=False, num_core=-1, output_folder='.', + save_figs=None, test_mode=False): + """Initialize Tomo.""" + # System modules + from logging import getLogger from multiprocessing import cpu_count self.__name__ = self.__class__.__name__ @@ -516,69 +631,57 @@ def __init__(self, interactive=False, num_core=-1, output_folder='.', save_figs= if not isinstance(self._num_core, int) or self._num_core < 0: raise ValueError(f'Invalid parameter num_core ({num_core})') if self._num_core > cpu_count(): - self._logger.warning(f'num_core = {self._num_core} is larger than the number of ' - f'available processors and reduced to {cpu_count()}') + self._logger.warning( + f'num_core = {self._num_core} is larger than the number ' + + f'of available processors and reduced to {cpu_count()}') self._num_core= cpu_count() - def read(self, filename): - extension = os_path.splitext(filename)[1] - if extension == '.yml' or extension == '.yaml': - with open(filename, 'r') as f: - config = safe_load(f) -# if len(config) > 1: -# raise ValueError(f'Multiple root entries in {filename} not yet implemented') -# if len(list(config.values())[0]) > 1: -# raise ValueError(f'Multiple sample maps in {filename} not yet implemented') - return(config) - elif extension == '.nxs': - with NXFile(filename, mode='r') as nxfile: - nxroot = nxfile.readfile() - return(nxroot) - else: - raise ValueError(f'Invalid filename extension ({extension})') - - def write(self, data, filename): - extension = os_path.splitext(filename)[1] - if extension == '.yml' or extension == '.yaml': - with open(filename, 'w') as f: - safe_dump(data, f) - elif extension == '.nxs': - data.save(filename, mode='w') - elif extension == '.nc': - data.to_netcdf(os_path=filename) - else: - raise ValueError(f'Invalid filename extension ({extension})') - def gen_reduced_data(self, data, img_x_bounds=None): - """Generate the reduced tomography images. """ - from nexusformat.nexus import NXdata, NXprocess, NXroot - - from CHAP.common.models.map import import_scanparser + Generate the reduced tomography images. 
+ + :param data: Data object containing the raw data info and + metadata required for a tomography data reduction + :type data: nexusformat.nexus.NXroot + :param img_x_bounds: Detector image bounds in the x-direction + :type img_x_bounds: tuple(int, int), list[int], optional + :return: Reduced tomography data + :rtype: nexusformat.nexus.NXroot + """ + # Third party modules + from nexusformat.nexus import ( + NXdata, + NXprocess, + NXroot, + ) self._logger.info('Generate the reduced tomography images') if img_x_bounds is not None: if not isinstance(img_x_bounds, (tuple, list)): - raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})') + raise ValueError( + f'Invalid parameter img_x_bounds ({img_x_bounds})') img_x_bounds = tuple(img_x_bounds) - if isinstance(data, dict): - # Create Nexus format object from input dictionary - wf = TomoWorkflow(**data) - if len(wf.sample_maps) > 1: - raise ValueError(f'Multiple sample maps not yet implemented') - nxroot = NXroot() - t0 = time() - for sample_map in wf.sample_maps: - self._logger.info(f'Start constructing the {sample_map.title} map.') - import_scanparser(sample_map.station) - sample_map.construct_nxentry(nxroot, include_raw_data=False) - self._logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.') - nxentry = nxroot[nxroot.attrs['default']] - # Get test mode configuration info - if self._test_mode: - self._test_config = data['sample_maps'][0]['test_mode'] - elif isinstance(data, NXroot): +# if isinstance(data, dict): +# # Create Nexus format object from input dictionary +# wf = TomoWorkflow(**data) +# if len(wf.sample_maps) > 1: +# raise ValueError('Multiple sample maps not yet implemented') +# nxroot = NXroot() +# t0 = time() +# for sample_map in wf.sample_maps: +# self._logger.info( +# f'Start constructing the {sample_map.title} map') +# import_scanparser(sample_map.station) +# sample_map.construct_nxentry(nxroot, include_raw_data=False) +# self._logger.info( +# f'Constructed all sample maps in {time()-t0:.2f} seconds') +# nxentry = nxroot[nxroot.attrs['default']] +# # Get test mode configuration info +# if self._test_mode: +# self._test_config = data['sample_maps'][0]['test_mode'] +# elif isinstance(data, NXroot): + if isinstance(data, NXroot): nxentry = data[data.attrs['default']] else: raise ValueError(f'Invalid parameter data ({data})') @@ -594,7 +697,8 @@ def gen_reduced_data(self, data, img_x_bounds=None): reduced_data = self._gen_bright(nxentry, reduced_data) # Set vertical detector bounds for image stack - img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds) + img_x_bounds = self._set_detector_bounds( + nxentry, reduced_data, img_x_bounds=img_x_bounds) self._logger.info(f'img_x_bounds = {img_x_bounds}') reduced_data['img_x_bounds'] = img_x_bounds @@ -608,20 +712,23 @@ def gen_reduced_data(self, data, img_x_bounds=None): # Generate reduced tomography fields reduced_data = self._gen_tomo(nxentry, reduced_data) - # Create a copy of the input Nexus object and remove raw and any existing reduced data + # Create a copy of the input Nexus object and remove raw and + # any existing reduced data if isinstance(data, NXroot): - exclude_items = [f'{nxentry._name}/reduced_data/data', - f'{nxentry._name}/instrument/detector/data', - f'{nxentry._name}/instrument/detector/image_key', - f'{nxentry._name}/instrument/detector/sequence_number', - f'{nxentry._name}/sample/rotation_angle', - f'{nxentry._name}/sample/x_translation', - f'{nxentry._name}/sample/z_translation', - 
f'{nxentry._name}/data/data', - f'{nxentry._name}/data/image_key', - f'{nxentry._name}/data/rotation_angle', - f'{nxentry._name}/data/x_translation', - f'{nxentry._name}/data/z_translation'] + exclude_items = [ + f'{nxentry.nxname}/reduced_data/data', + f'{nxentry.nxname}/instrument/detector/data', + f'{nxentry.nxname}/instrument/detector/image_key', + f'{nxentry.nxname}/instrument/detector/sequence_number', + f'{nxentry.nxname}/sample/rotation_angle', + f'{nxentry.nxname}/sample/x_translation', + f'{nxentry.nxname}/sample/z_translation', + f'{nxentry.nxname}/data/data', + f'{nxentry.nxname}/data/image_key', + f'{nxentry.nxname}/data/rotation_angle', + f'{nxentry.nxname}/data/x_translation', + f'{nxentry.nxname}/data/z_translation', + ] nxroot = nxcopy(data, exclude_nxpaths=exclude_items) nxentry = nxroot[nxroot.attrs['default']] @@ -631,18 +738,33 @@ def gen_reduced_data(self, data, img_x_bounds=None): if 'data' not in nxentry: nxentry.data = NXdata() nxentry.attrs['default'] = 'data' - nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data') - nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle') + nxentry.data.makelink( + nxentry.reduced_data.data.tomo_fields, name='reduced_data') + nxentry.data.makelink( + nxentry.reduced_data.rotation_angle, name='rotation_angle') nxentry.data.attrs['signal'] = 'reduced_data' - - return(nxroot) + + return nxroot def find_centers(self, nxroot, center_rows=None, center_stack_index=None): - """Find the calibrated center axis info """ - from nexusformat.nexus import NXentry, NXroot - - from CHAP.common.utils.general import is_int_pair + Find the calibrated center axis info + + :param nxroot: Data object containing the reduced data and + metadata required to find the calibrated center axis info + :type data: nexusformat.nexus.NXroot + :param center_rows: Lower and upper row indices for center + finding + :type center_rows: tuple(int, int), list[int], optional + :return: Calibrated center axis info + :rtype: dict + """ + # Third party modules + from nexusformat.nexus import ( + NXentry, + NXroot, + ) + from yaml import safe_dump self._logger.info('Find the calibrated center axis info') @@ -651,52 +773,57 @@ def find_centers(self, nxroot, center_rows=None, center_stack_index=None): nxentry = nxroot[nxroot.attrs['default']] if not isinstance(nxentry, NXentry): raise ValueError(f'Invalid nxentry ({nxentry})') - if center_rows is not None and (not isinstance(center_rows, (tuple, list)) - or len(center_rows) != 2): + if (center_rows is not None + and (not isinstance(center_rows, (tuple, list)) + or len(center_rows) != 2)): raise ValueError(f'Invalid parameter center_rows ({center_rows})') if not self._interactive and (center_rows is None or (center_rows[0] is None and center_rows[1] is None)): - self._logger.warning('center_rows unspecified, find centers at reduced data bounds') - if center_stack_index is not None and (not isinstance(center_stack_index, int) - or center_stack_index < 0): - raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})') + self._logger.warning( + 'center_rows unspecified, find centers at reduced data bounds') + if (center_stack_index is not None + and (not isinstance(center_stack_index, int) + or center_stack_index < 0)): + raise ValueError( + 'Invalid parameter center_stack_index ' + + f'({center_stack_index})') # Check if reduced data is available - if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data): + if ('reduced_data' not in nxentry + 
or 'reduced_data' not in nxentry.data): raise KeyError(f'Unable to find valid reduced data in {nxentry}.') # Select the image stack to calibrate the center axis - # reduced data axes order: stack,theta,row,column - # Note: Nexus cannot follow a link if the data it points to is too big, - # so get the data from the actual place, not from nxentry.data + # reduced data axes order: stack,theta,row,column + # Note: Nexus can't follow a link if the data it points to is + # too big get the data from the actual place, not from + # nxentry.data tomo_fields_shape = nxentry.reduced_data.data.tomo_fields.shape - if len(tomo_fields_shape) != 4 or any(True for dim in tomo_fields_shape if not dim): - raise KeyError('Unable to load the required reduced tomography stack') + if (len(tomo_fields_shape) != 4 + or any(True for dim in tomo_fields_shape if not dim)): + raise KeyError( + 'Unable to load the required reduced tomography stack') num_tomo_stacks = tomo_fields_shape[0] if num_tomo_stacks == 1: center_stack_index = 0 default = 'n' else: if self._test_mode: - center_stack_index = self._test_config['center_stack_index']-1 # make offset 0 + # Convert input value to offset 0 + center_stack_index = self._test_config['center_stack_index']-1 elif self._interactive: if center_stack_index is None: - center_stack_index = input_int('\nEnter tomography stack index to calibrate ' - 'the center axis', ge=1, le=num_tomo_stacks, - default=int(1+num_tomo_stacks/2)) - else: - if (not isinstance(center_stack_index, int) or - not 0 < center_stack_index <= num_tomo_stacks): - raise ValueError('Invalid parameter center_stack_index '+ - f'({center_stack_index})') + center_stack_index = input_int( + '\nEnter tomography stack index to calibrate the ' + + 'center axis', ge=1, le=num_tomo_stacks, + default=int(1 + num_tomo_stacks/2)) center_stack_index -= 1 else: if center_stack_index is None: center_stack_index = int(num_tomo_stacks/2) - self._logger.warning('center_stack_index unspecified, use stack '+ - f'{center_stack_index+1} to find centers') - if center_stack_index >= num_tomo_stacks: - raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})') + self._logger.warning( + 'center_stack_index unspecified, use stack ' + + f'{center_stack_index+1} to find centers') default = 'y' # Get thetas (in degrees) @@ -704,8 +831,8 @@ def find_centers(self, nxroot, center_rows=None, center_stack_index=None): # Get effective pixel_size if 'zoom_perc' in nxentry.reduced_data: - eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/ - nxentry.reduced_data.attrs['zoom_perc']) + eff_pixel_size = 100.0 * (nxentry.instrument.detector.x_pixel_size + / nxentry.reduced_data.attrs['zoom_perc']) else: eff_pixel_size = nxentry.instrument.detector.x_pixel_size @@ -725,13 +852,15 @@ def find_centers(self, nxroot, center_rows=None, center_stack_index=None): if lower_row == -1: lower_row = 0 if not 0 <= lower_row < tomo_fields_shape[2]-1: - raise ValueError(f'Invalid parameter center_rows ({center_rows})') + raise ValueError( + f'Invalid parameter center_rows ({center_rows})') else: lower_row = select_one_image_bound( - nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], - 0, bound=0, title=f'theta={round(thetas[0], 2)+0}', - bound_name='row index to find lower center', default=default, - raise_error=True) + nxentry.reduced_data.data.tomo_fields[ + center_stack_index,0,:,:], + 0, bound=0, title=f'theta={round(thetas[0], 2)+0}', + bound_name='row index to find lower center', + default=default, raise_error=True) 
else: if center_rows is None or center_rows[0] is None: lower_row = 0 @@ -740,14 +869,15 @@ def find_centers(self, nxroot, center_rows=None, center_stack_index=None): if lower_row == -1: lower_row = 0 if not 0 <= lower_row < tomo_fields_shape[2]-1: - raise ValueError(f'Invalid parameter center_rows ({center_rows})') - self._logger.debug('Finding center...') + raise ValueError( + f'Invalid parameter center_rows ({center_rows})') t0 = time() lower_center_offset = self._find_center_one_plane( - nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:], - lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=self._output_folder, - num_core=self._num_core) - self._logger.debug(f'... done in {time()-t0:.2f} seconds') + nxentry.reduced_data.data.tomo_fields[ + center_stack_index,:,lower_row,:], + lower_row, thetas, eff_pixel_size, cross_sectional_dim, + path=self._output_folder, num_core=self._num_core) + self._logger.info(f'Finding center took {time()-t0:.2f} seconds') self._logger.debug(f'lower_row = {lower_row:.2f}') self._logger.debug(f'lower_center_offset = {lower_center_offset:.2f}') @@ -760,13 +890,16 @@ def find_centers(self, nxroot, center_rows=None, center_stack_index=None): if upper_row == -1: upper_row = tomo_fields_shape[2]-1 if not lower_row < upper_row < tomo_fields_shape[2]: - raise ValueError(f'Invalid parameter center_rows ({center_rows})') + raise ValueError( + f'Invalid parameter center_rows ({center_rows})') else: upper_row = select_one_image_bound( - nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], - 0, bound=tomo_fields_shape[2]-1, title=f'theta={round(thetas[0], 2)+0}', - bound_name='row index to find upper center', default=default, - raise_error=True) + nxentry.reduced_data.data.tomo_fields[ + center_stack_index,0,:,:], + 0, bound=tomo_fields_shape[2]-1, + title=f'theta = {round(thetas[0], 2)+0}', + bound_name='row index to find upper center', + default=default, raise_error=True) else: if center_rows is None or center_rows[1] is None: upper_row = tomo_fields_shape[2]-1 @@ -775,34 +908,64 @@ def find_centers(self, nxroot, center_rows=None, center_stack_index=None): if upper_row == -1: upper_row = tomo_fields_shape[2]-1 if not lower_row < upper_row < tomo_fields_shape[2]: - raise ValueError(f'Invalid parameter center_rows ({center_rows})') - self._logger.debug('Finding center...') + raise ValueError( + f'Invalid parameter center_rows ({center_rows})') t0 = time() upper_center_offset = self._find_center_one_plane( - nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:], - upper_row, thetas, eff_pixel_size, cross_sectional_dim, path=self._output_folder, - num_core=self._num_core) - self._logger.debug(f'... 
done in {time()-t0:.2f} seconds') + nxentry.reduced_data.data.tomo_fields[ + center_stack_index,:,upper_row,:], + upper_row, thetas, eff_pixel_size, cross_sectional_dim, + path=self._output_folder, num_core=self._num_core) + self._logger.info(f'Finding center took {time()-t0:.2f} seconds') self._logger.debug(f'upper_row = {upper_row:.2f}') self._logger.debug(f'upper_center_offset = {upper_center_offset:.2f}') - center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset, - 'upper_row': upper_row, 'upper_center_offset': upper_center_offset} + center_config = { + 'lower_row': lower_row, + 'lower_center_offset': lower_center_offset, + 'upper_row': upper_row, + 'upper_center_offset': upper_center_offset, + } if num_tomo_stacks > 1: - center_config['center_stack_index'] = center_stack_index+1 # save as offset 1 + # Save as offset 1 + center_config['center_stack_index'] = center_stack_index+1 # Save test data to file if self._test_mode: - with open(f'{self._output_folder}/center_config.yaml', 'w') as f: + with open(f'{self._output_folder}/center_config.yaml', 'w', + encoding='utf8') as f: safe_dump(center_config, f) - return(center_config) + return center_config - def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None, z_bounds=None): - """Reconstruct the tomography data. + def reconstruct_data( + self, nxroot, center_info, x_bounds=None, y_bounds=None, + z_bounds=None): """ - from nexusformat.nexus import NXdata, NXentry, NXprocess, NXroot - + Reconstruct the tomography data. + + :param nxroot: Reduced data + :type data: nexusformat.nexus.NXroot + :param center_info: Calibrated center axis info + :type center_info: dict + :param x_bounds: Reconstructed image bounds in the x-direction + :type x_bounds: tuple(int, int), list[int], optional + :param y_bounds: Reconstructed image bounds in the y-direction + :type y_bounds: tuple(int, int), list[int], optional + :param z_bounds: Reconstructed image bounds in the z-direction + :type z_bounds: tuple(int, int), list[int], optional + :return: Reconstructed tomography data + :rtype: dict + """ + # Third party modules + from nexusformat.nexus import ( + NXdata, + NXentry, + NXprocess, + NXroot, + ) + + # Local modules from CHAP.common.utils.general import is_int_pair self._logger.info('Reconstruct the tomography data') @@ -828,7 +991,8 @@ def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None, z_ z_bounds = tuple(z_bounds) # Check if reduced data is available - if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data): + if ('reduced_data' not in nxentry + or 'reduced_data' not in nxentry.data): raise KeyError(f'Unable to find valid reduced data in {nxentry}.') # Create an NXprocess to store image reconstruction (meta)data @@ -839,66 +1003,82 @@ def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None, z_ lower_center_offset = center_info.get('lower_center_offset') upper_row = center_info.get('upper_row') upper_center_offset = center_info.get('upper_center_offset') - if (lower_row is None or lower_center_offset is None or upper_row is None or - upper_center_offset is None): - raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.') - center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row) + if (lower_row is None or lower_center_offset is None + or upper_row is None or upper_center_offset is None): + raise KeyError( + 'Unable to find valid calibrated center axis info in ' + + f'{center_info}.') + 
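For context (not part of the patch): the calibrated values checked above define a straight line, and the center_slope computed next is its slope. The rotation-axis offset at any detector row then follows by linear interpolation between the two calibrated rows; a small sketch with made-up numbers:

    # Made-up calibration values, for illustration only.
    lower_row, lower_center_offset = 10, 2.75
    upper_row, upper_center_offset = 950, 3.40

    center_slope = (upper_center_offset - lower_center_offset) \
        / (upper_row - lower_row)

    def center_offset_at(row):
        """Rotation-axis offset at an arbitrary detector row."""
        return lower_center_offset + (row - lower_row)*center_slope

    # Offsets at the first and last row of a 1000-row stack; these match
    # the center_offsets endpoints used below in reconstruct_data.
    num_rows = 1000
    print(center_offset_at(0), center_offset_at(num_rows - 1))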
center_slope = (upper_center_offset-lower_center_offset) \ + / (upper_row-lower_row) # Get thetas (in degrees) thetas = np.asarray(nxentry.reduced_data.rotation_angle) # Reconstruct tomography data - # reduced data axes order: stack,theta,row,column - # reconstructed data order in each stack: row/z,x,y - # Note: Nexus cannot follow a link if the data it points to is too big, - # so get the data from the actual place, not from nxentry.data + # reduced data axes order: stack,theta,row,column + # reconstructed data order in each stack: row/z,x,y + # Note: Nexus can't follow a link if the data it points to is + # too big get the data from the actual place, not from + # nxentry.data if 'zoom_perc' in nxentry.reduced_data: res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p' else: res_title = 'fullres' - load_error = False num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0] tomo_recon_stacks = num_tomo_stacks*[np.array([])] for i in range(num_tomo_stacks): - # Convert reduced data stack from theta,row,column to row,theta,column - self._logger.debug(f'Reading reduced data stack {i+1}...') + # Convert reduced data stack from theta,row,column to + # row,theta,column t0 = time() tomo_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[i]) - self._logger.debug(f'... done in {time()-t0:.2f} seconds') - if len(tomo_stack.shape) != 3 or any(True for dim in tomo_stack.shape if not dim): - raise ValueError(f'Unable to load tomography stack {i+1} for reconstruction') + self._logger.info( + f'Reading reduced data stack {i+1} took {time()-t0:.2f} ' + + 'seconds') + if (len(tomo_stack.shape) != 3 + or any(True for dim in tomo_stack.shape if not dim)): + raise RuntimeError( + f'Unable to load tomography stack {i+1} for ' + + 'reconstruction') tomo_stack = np.swapaxes(tomo_stack, 0, 1) - assert(len(thetas) == tomo_stack.shape[1]) - assert(0 <= lower_row < upper_row < tomo_stack.shape[0]) - center_offsets = [lower_center_offset-lower_row*center_slope, - upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope] + assert len(thetas) == tomo_stack.shape[1] + assert 0 <= lower_row < upper_row < tomo_stack.shape[0] + center_offsets = [ + lower_center_offset - lower_row*center_slope, + upper_center_offset + (tomo_stack.shape[0]-1-upper_row) + * center_slope, + ] t0 = time() - self._logger.debug(f'Running _reconstruct_one_tomo_stack on {self._num_core} cores ...') - tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas, - center_offsets=center_offsets, num_core=self._num_core, algorithm='gridrec') - self._logger.debug(f'... 
done in {time()-t0:.2f} seconds') - self._logger.info(f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds') + tomo_recon_stack = self._reconstruct_one_tomo_stack( + tomo_stack, thetas, center_offsets=center_offsets, + num_core=self._num_core, algorithm='gridrec') + self._logger.info( + f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds') # Combine stacks tomo_recon_stacks[i] = tomo_recon_stack # Resize the reconstructed tomography data - # reconstructed data order in each stack: row/z,x,y + # reconstructed data order in each stack: row/z,x,y if self._test_mode: x_bounds = tuple(self._test_config.get('x_bounds')) y_bounds = tuple(self._test_config.get('y_bounds')) z_bounds = None elif self._interactive: - x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks, - x_bounds=x_bounds, y_bounds=y_bounds, z_bounds=z_bounds) + x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data( + tomo_recon_stacks, x_bounds=x_bounds, y_bounds=y_bounds, + z_bounds=z_bounds) else: if x_bounds is None: - self._logger.warning('x_bounds unspecified, reconstruct data for full x-range') - elif not is_int_pair(x_bounds, ge=0, lt=tomo_recon_stacks[0].shape[1]): + self._logger.warning( + 'x_bounds unspecified, reconstruct data for full x-range') + elif not is_int_pair(x_bounds, ge=0, + lt=tomo_recon_stacks[0].shape[1]): raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') if y_bounds is None: - self._logger.warning('y_bounds unspecified, reconstruct data for full y-range') - elif not is_int_pair(y_bounds, ge=0, lt=tomo_recon_stacks[0].shape[1]): + self._logger.warning( + 'y_bounds unspecified, reconstruct data for full y-range') + elif not is_int_pair( + y_bounds, ge=0, lt=tomo_recon_stacks[0].shape[2]): raise ValueError(f'Invalid parameter y_bounds ({y_bounds})') z_bounds = None if x_bounds is None: @@ -906,19 +1086,19 @@ def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None, z_ x_slice = int(x_range[1]/2) else: x_range = (min(x_bounds), max(x_bounds)) - x_slice = int((x_bounds[0]+x_bounds[1])/2) + x_slice = int((x_bounds[0]+x_bounds[1]) / 2) if y_bounds is None: y_range = (0, tomo_recon_stacks[0].shape[2]) - y_slice = int(y_range[1]/2) + y_slice = int(y_range[1] / 2) else: y_range = (min(y_bounds), max(y_bounds)) - y_slice = int((y_bounds[0]+y_bounds[1])/2) + y_slice = int((y_bounds[0]+y_bounds[1]) / 2) if z_bounds is None: z_range = (0, tomo_recon_stacks[0].shape[0]) - z_slice = int(z_range[1]/2) + z_slice = int(z_range[1] / 2) else: z_range = (min(z_bounds), max(z_bounds)) - z_slice = int((z_bounds[0]+z_bounds[1])/2) + z_slice = int((z_bounds[0]+z_bounds[1]) / 2) # Plot a few reconstructed image slices if self._save_figs: @@ -928,24 +1108,32 @@ def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None, z_ else: basetitle = f'recon stack {i+1}' title = f'{basetitle} {res_title} xslice{x_slice}' - quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], - title=title, path=self._output_folder, save_fig=True, save_only=True) + quick_imshow( + stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], + title=title, path=self._output_folder, save_fig=True, + save_only=True) title = f'{basetitle} {res_title} yslice{y_slice}' - quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], - title=title, path=self._output_folder, save_fig=True, save_only=True) + quick_imshow( + stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], + title=title, path=self._output_folder, save_fig=True, + 
save_only=True) title = f'{basetitle} {res_title} zslice{z_slice}' - quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], - title=title, path=self._output_folder, save_fig=True, save_only=True) + quick_imshow( + stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], + title=title, path=self._output_folder, save_fig=True, + save_only=True) # Save test data to file - # reconstructed data order in each stack: row/z,x,y + # reconstructed data order in each stack: row/z,x,y if self._test_mode: for i, stack in enumerate(tomo_recon_stacks): - np.savetxt(f'{self._output_folder}/recon_stack_{i+1}.txt', - stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e') + np.savetxt( + f'{self._output_folder}/recon_stack_{i+1}.txt', + stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], + fmt='%.6e') # Add image reconstruction to reconstructed data NXprocess - # reconstructed data order in each stack: row/z,x,y + # reconstructed data order in each stack: row/z,x,y nxprocess.data = NXdata() nxprocess.attrs['default'] = 'data' for k, v in center_info.items(): @@ -956,12 +1144,17 @@ def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None, z_ nxprocess.y_bounds = y_bounds if z_bounds is not None: nxprocess.z_bounds = z_bounds - nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1], - x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks]) + nxprocess.data['reconstructed_data'] = np.asarray( + [stack[z_range[0]:z_range[1],x_range[0]:x_range[1], + y_range[0]:y_range[1]] for stack in tomo_recon_stacks]) nxprocess.data.attrs['signal'] = 'reconstructed_data' - # Create a copy of the input Nexus object and remove reduced data - exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data'] + # Create a copy of the input Nexus object and remove reduced + # data + exclude_items = [ + f'{nxentry.nxname}/reduced_data/data', + f'{nxentry.nxname}/data/reduced_data', + ] nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items) # Add the reconstructed data NXprocess to the new Nexus object @@ -970,16 +1163,36 @@ def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None, z_ if 'data' not in nxentry_copy: nxentry_copy.data = NXdata() nxentry_copy.attrs['default'] = 'data' - nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data') + nxentry_copy.data.makelink( + nxprocess.data.reconstructed_data, name='reconstructed_data') nxentry_copy.data.attrs['signal'] = 'reconstructed_data' - return(nxroot_copy) + return nxroot_copy - def combine_data(self, nxroot, x_bounds=None, y_bounds=None, z_bounds=None): + def combine_data( + self, nxroot, x_bounds=None, y_bounds=None, z_bounds=None): """Combine the reconstructed tomography stacks. 
- """ - from nexusformat.nexus import NXdata, NXentry, NXprocess, NXroot + :param nxroot: A stack of reconstructed tomography datasets + :type data: nexusformat.nexus.NXroot + :param x_bounds: Combined image bounds in the x-direction + :type x_bounds: tuple(int, int), list[int], optional + :param y_bounds: Combined image bounds in the y-direction + :type y_bounds: tuple(int, int), list[int], optional + :param z_bounds: Combined image bounds in the z-direction + :type z_bounds: tuple(int, int), list[int], optional + :return: Combined reconstructed tomography data + :rtype: dict + """ + # Third party modules + from nexusformat.nexus import ( + NXdata, + NXentry, + NXprocess, + NXroot, + ) + + # Local modules from CHAP.common.utils.general import is_int_pair self._logger.info('Combine the reconstructed tomography stacks') @@ -1003,38 +1216,47 @@ def combine_data(self, nxroot, x_bounds=None, y_bounds=None, z_bounds=None): z_bounds = tuple(z_bounds) # Check if reconstructed image data is available - if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data): - raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.') + if ('reconstructed_data' not in nxentry + or 'reconstructed_data' not in nxentry.data): + raise KeyError( + f'Unable to find valid reconstructed image data in {nxentry}') - # Create an NXprocess to store combined image reconstruction (meta)data + # Create an NXprocess to store combined image reconstruction + # (meta)data nxprocess = NXprocess() # Get the reconstructed data - # reconstructed data order: stack,row(z),x,y - # Note: Nexus cannot follow a link if the data it points to is too big, - # so get the data from the actual place, not from nxentry.data - num_tomo_stacks = nxentry.reconstructed_data.data.reconstructed_data.shape[0] + # reconstructed data order: stack,row(z),x,y + # Note: Nexus can't follow a link if the data it points to is + # too big get the data from the actual place, not from + # nxentry.data + num_tomo_stacks = \ + nxentry.reconstructed_data.data.reconstructed_data.shape[0] if num_tomo_stacks == 1: self._logger.info('Only one stack available: leaving combine_data') - return(None) + return None # Combine the reconstructed stacks - # (load one stack at a time to reduce risk of hitting Nexus data access limit) + # (load one stack at a time to reduce risk of hitting Nexus + # data access limit) t0 = time() - self._logger.debug(f'Combining the reconstructed stacks ...') - tomo_recon_combined = np.asarray(nxentry.reconstructed_data.data.reconstructed_data[0]) + tomo_recon_combined = np.asarray( + nxentry.reconstructed_data.data.reconstructed_data[0]) if num_tomo_stacks > 2: - tomo_recon_combined = np.concatenate([tomo_recon_combined]+ - [nxentry.reconstructed_data.data.reconstructed_data[i] - for i in range(1, num_tomo_stacks-1)]) + tomo_recon_combined = np.concatenate( + [tomo_recon_combined] + + [nxentry.reconstructed_data.data.reconstructed_data[i] + for i in range(1, num_tomo_stacks-1)]) if num_tomo_stacks > 1: - tomo_recon_combined = np.concatenate([tomo_recon_combined]+ - [nxentry.reconstructed_data.data.reconstructed_data[num_tomo_stacks-1]]) - self._logger.debug(f'... 
done in {time()-t0:.2f} seconds') - self._logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds') + tomo_recon_combined = np.concatenate( + [tomo_recon_combined] + + [nxentry.reconstructed_data.data.reconstructed_data[ + num_tomo_stacks-1]]) + self._logger.info( + f'Combining the reconstructed stacks took {time()-t0:.2f} seconds') # Resize the combined tomography data stacks - # combined data order: row/z,x,y + # combined data order: row/z,x,y if self._test_mode: x_bounds = None y_bounds = None @@ -1044,16 +1266,20 @@ def combine_data(self, nxroot, x_bounds=None, y_bounds=None, z_bounds=None): x_bounds = (-1, -1) if y_bounds is None and y_bounds in nxentry.reconstructed_data: y_bounds = (-1, -1) - x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined, - z_only=True) + x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data( + tomo_recon_combined, z_only=True) else: if x_bounds is None: - self._logger.warning('x_bounds unspecified, reconstruct data for full x-range') - elif not is_int_pair(x_bounds, ge=0, lt=tomo_recon_stacks[0].shape[1]): + self._logger.warning( + 'x_bounds unspecified, reconstruct data for full x-range') + elif not is_int_pair( + x_bounds, ge=0, lt=tomo_recon_combined.shape[1]): raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') if y_bounds is None: - self._logger.warning('y_bounds unspecified, reconstruct data for full y-range') - elif not is_int_pair(y_bounds, ge=0, lt=tomo_recon_stacks[0].shape[1]): + self._logger.warning( + 'y_bounds unspecified, reconstruct data for full y-range') + elif not is_int_pair( + y_bounds, ge=0, lt=tomo_recon_combined.shape[2]): raise ValueError(f'Invalid parameter y_bounds ({y_bounds})') z_bounds = None if x_bounds is None: @@ -1061,40 +1287,48 @@ def combine_data(self, nxroot, x_bounds=None, y_bounds=None, z_bounds=None): x_slice = int(x_range[1]/2) else: x_range = x_bounds - x_slice = int((x_bounds[0]+x_bounds[1])/2) + x_slice = int((x_bounds[0]+x_bounds[1]) / 2) if y_bounds is None: y_range = (0, tomo_recon_combined.shape[2]) y_slice = int(y_range[1]/2) else: y_range = y_bounds - y_slice = int((y_bounds[0]+y_bounds[1])/2) + y_slice = int((y_bounds[0]+y_bounds[1]) / 2) if z_bounds is None: z_range = (0, tomo_recon_combined.shape[0]) z_slice = int(z_range[1]/2) else: z_range = z_bounds - z_slice = int((z_bounds[0]+z_bounds[1])/2) + z_slice = int((z_bounds[0]+z_bounds[1]) / 2) # Plot a few combined image slices if self._save_figs: - quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], - title=f'recon combined xslice{x_slice}', path=self._output_folder, - save_fig=True, save_only=True) - quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], - title=f'recon combined yslice{y_slice}', path=self._output_folder, - save_fig=True, save_only=True) - quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], - title=f'recon combined zslice{z_slice}', path=self._output_folder, - save_fig=True, save_only=True) + quick_imshow( + tomo_recon_combined[z_range[0]:z_range[1],x_slice, + y_range[0]:y_range[1]], + title=f'recon combined xslice{x_slice}', + path=self._output_folder, save_fig=True, save_only=True) + quick_imshow( + tomo_recon_combined[z_range[0]:z_range[1], + x_range[0]:x_range[1],y_slice], + title=f'recon combined yslice{y_slice}', + path=self._output_folder, save_fig=True, save_only=True) + quick_imshow( + tomo_recon_combined[z_slice,x_range[0]:x_range[1], + 
y_range[0]:y_range[1]], + title=f'recon combined zslice{z_slice}', + path=self._output_folder, save_fig=True, save_only=True) # Save test data to file - # combined data order: row/z,x,y + # combined data order: row/z,x,y if self._test_mode: - np.savetxt(f'{self._output_folder}/recon_combined.txt', tomo_recon_combined[ - z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e') + np.savetxt( + f'{self._output_folder}/recon_combined.txt', + tomo_recon_combined[z_slice,x_range[0]:x_range[1], + y_range[0]:y_range[1]], fmt='%.6e') # Add image reconstruction to reconstructed data NXprocess - # combined data order: row/z,x,y + # combined data order: row/z,x,y nxprocess.data = NXdata() nxprocess.attrs['default'] = 'data' if x_bounds is not None: @@ -1104,12 +1338,15 @@ def combine_data(self, nxroot, x_bounds=None, y_bounds=None, z_bounds=None): if z_bounds is not None: nxprocess.z_bounds = z_bounds nxprocess.data['combined_data'] = tomo_recon_combined[ - z_range[0]:z_range[1],x_range[0]:x_range[1],y_range[0]:y_range[1]] + z_range[0]:z_range[1],x_range[0]:x_range[1],y_range[0]:y_range[1]] nxprocess.data.attrs['signal'] = 'combined_data' - # Create a copy of the input Nexus object and remove reconstructed data - exclude_items = [f'{nxentry._name}/reconstructed_data/data', - f'{nxentry._name}/data/reconstructed_data'] + # Create a copy of the input Nexus object and remove + # reconstructed data + exclude_items = [ + f'{nxentry.nxname}/reconstructed_data/data', + f'{nxentry.nxname}/data/reconstructed_data', + ] nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items) # Add the combined data NXprocess to the new Nexus object @@ -1118,39 +1355,51 @@ def combine_data(self, nxroot, x_bounds=None, y_bounds=None, z_bounds=None): if 'data' not in nxentry_copy: nxentry_copy.data = NXdata() nxentry_copy.attrs['default'] = 'data' - nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data') + nxentry_copy.data.makelink( + nxprocess.data.combined_data, name='combined_data') nxentry_copy.data.attrs['signal'] = 'combined_data' - return(nxroot_copy) + return nxroot_copy def _gen_dark(self, nxentry, reduced_data): - """Generate dark field. 
- """ + """Generate dark field.""" + # Third party modules from nexusformat.nexus import NXdata - from CHAP.common.models.map import get_scanparser, import_scanparser + # Local modules + from CHAP.common.models.map import ( + get_scanparser, + import_scanparser, + ) # Get the dark field images image_key = nxentry.instrument.detector.get('image_key', None) if image_key and 'data' in nxentry.instrument.detector: - field_indices = [index for index, key in enumerate(image_key) if key == 2] + field_indices = [index for index, key in enumerate(image_key) + if key == 2] tdf_stack = nxentry.instrument.detector.data[field_indices,:,:] - # RV the default NXtomo form does not accomodate bright or dark field stacks + # RV the default NXtomo form does not accomodate dark field + # stacks else: - import_scanparser(nxentry.instrument.source.attrs['station'], - nxentry.instrument.source.attrs['experiment_type']) + import_scanparser( + nxentry.instrument.source.attrs['station'], + nxentry.instrument.source.attrs['experiment_type']) dark_field_scans = nxentry.spec_scans.dark_field detector_prefix = str(nxentry.instrument.detector.local_name) tdf_stack = [] for nxsubentry_name, nxsubentry in dark_field_scans.items(): scan_number = int(nxsubentry_name.split('_')[-1]) - scanparser = get_scanparser(dark_field_scans.attrs['spec_file'], scan_number) - image_offset = int(nxsubentry.instrument.detector.frame_start_number) + scanparser = get_scanparser( + dark_field_scans.attrs['spec_file'], scan_number) + image_offset = int( + nxsubentry.instrument.detector.frame_start_number) num_image = len(nxsubentry.sample.rotation_angle) - tdf_stack.append(scanparser.get_detector_data(detector_prefix, + tdf_stack.append( + scanparser.get_detector_data( + detector_prefix, (image_offset, image_offset+num_image))) if isinstance(tdf_stack, list): - assert(len(tdf_stack) == 1) # TODO + assert len(tdf_stack) == 1 # RV tdf_stack = tdf_stack[0] # Take median @@ -1160,15 +1409,16 @@ def _gen_dark(self, nxentry, reduced_data): tdf = np.median(tdf_stack, axis=0) del tdf_stack else: - raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})') + raise RuntimeError(f'Invalid tdf_stack shape ({tdf_stack.shape})') # Remove dark field intensities above the cutoff -#RV tdf_cutoff = None - tdf_cutoff = tdf.min()+2*(np.median(tdf)-tdf.min()) +# tdf_cutoff = None + tdf_cutoff = tdf.min() + 2 * (np.median(tdf)-tdf.min()) self._logger.debug(f'tdf_cutoff = {tdf_cutoff}') if tdf_cutoff is not None: if not isinstance(tdf_cutoff, (int, float)) or tdf_cutoff < 0: - self._logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}') + self._logger.warning( + f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}') else: tdf[tdf > tdf_cutoff] = np.nan self._logger.debug(f'tdf_cutoff = {tdf_cutoff}') @@ -1176,68 +1426,82 @@ def _gen_dark(self, nxentry, reduced_data): # Remove nans tdf_mean = np.nanmean(tdf) self._logger.debug(f'tdf_mean = {tdf_mean}') - np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.) 
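Stand-alone illustration (not part of the patch) of the dark-field reduction above: a per-pixel median over the frame stack, a cutoff at min + 2*(median - min) that rejects hot pixels, and in-place replacement of the resulting NaNs by the mean dark level. The synthetic stack below is made up for illustration:

    import numpy as np

    # Synthetic stack of 20 dark-field frames with one hot pixel.
    rng = np.random.default_rng(0)
    tdf_stack = rng.normal(100.0, 2.0, (20, 8, 8))
    tdf_stack[:, 3, 4] = 5000.0

    # Per-pixel median over the stack.
    tdf = np.median(tdf_stack, axis=0)

    # Reject intensities above min + 2*(median - min), as in _gen_dark.
    tdf_cutoff = tdf.min() + 2 * (np.median(tdf) - tdf.min())
    tdf[tdf > tdf_cutoff] = np.nan

    # Replace the NaNs (and any infinities) by the mean dark level.
    tdf_mean = np.nanmean(tdf)
    np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.0)

    print(tdf_cutoff, tdf_mean, tdf[3, 4])  # the hot pixel is now ~tdf_mean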
+ np.nan_to_num( + tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.0) # Plot dark field if self._save_figs: - quick_imshow(tdf, title='dark field', path=self._output_folder, save_fig=True, - save_only=True) + quick_imshow( + tdf, title='dark field', path=self._output_folder, + save_fig=True, save_only=True) # Add dark field to reduced data NXprocess reduced_data.data = NXdata() reduced_data.data['dark_field'] = tdf - return(reduced_data) + return reduced_data def _gen_bright(self, nxentry, reduced_data): - """Generate bright field. - """ + """Generate bright field.""" + # Third party modules from nexusformat.nexus import NXdata - from CHAP.common.models.map import get_scanparser, import_scanparser + # Local modules + from CHAP.common.models.map import ( + get_scanparser, + import_scanparser, + ) # Get the bright field images image_key = nxentry.instrument.detector.get('image_key', None) if image_key and 'data' in nxentry.instrument.detector: - field_indices = [index for index, key in enumerate(image_key) if key == 1] + field_indices = [index for index, key in enumerate(image_key) + if key == 1] tbf_stack = nxentry.instrument.detector.data[field_indices,:,:] - # RV the default NXtomo form does not accomodate bright or dark field stacks + # RV the default NXtomo form does not accomodate bright + # field stacks else: - import_scanparser(nxentry.instrument.source.attrs['station'], - nxentry.instrument.source.attrs['experiment_type']) + import_scanparser( + nxentry.instrument.source.attrs['station'], + nxentry.instrument.source.attrs['experiment_type']) bright_field_scans = nxentry.spec_scans.bright_field detector_prefix = str(nxentry.instrument.detector.local_name) tbf_stack = [] for nxsubentry_name, nxsubentry in bright_field_scans.items(): scan_number = int(nxsubentry_name.split('_')[-1]) - scanparser = get_scanparser(bright_field_scans.attrs['spec_file'], scan_number) - image_offset = int(nxsubentry.instrument.detector.frame_start_number) + scanparser = get_scanparser( + bright_field_scans.attrs['spec_file'], scan_number) + image_offset = int( + nxsubentry.instrument.detector.frame_start_number) num_image = len(nxsubentry.sample.rotation_angle) - tbf_stack.append(scanparser.get_detector_data(detector_prefix, + tbf_stack.append( + scanparser.get_detector_data( + detector_prefix, (image_offset, image_offset+num_image))) if isinstance(tbf_stack, list): - assert(len(tbf_stack) == 1) # TODO + assert len(tbf_stack) == 1 # RV tbf_stack = tbf_stack[0] # Take median if more than one image - """Median or mean: It may be best to try the median because of some image - artifacts that arise due to crinkles in the upstream kapton tape windows - causing some phase contrast images to appear on the detector. - One thing that also may be useful in a future implementation is to do a - brightfield adjustment on EACH frame of the tomo based on a ROI in the - corner of the frame where there is no sample but there is the direct X-ray - beam because there is frame to frame fluctuations from the incoming beam. - We don’t typically account for them but potentially could. - """ - from nexusformat.nexus import NXdata - + # + # Median or mean: It may be best to try the median because of + # some image artifacts that arise due to crinkles in the + # upstream kapton tape windows causing some phase contrast + # images to appear on the detector. 
+ # + # One thing that also may be useful in a future implementation + # is to do a brightfield adjustment on EACH frame of the tomo + # based on a ROI in the corner of the frame where there is no + # sample but there is the direct X-ray beam because there is + # frame to frame fluctuations from the incoming beam. We don’t + # typically account for them but potentially could. if tbf_stack.ndim == 2: tbf = tbf_stack elif tbf_stack.ndim == 3: tbf = np.median(tbf_stack, axis=0) del tbf_stack else: - raise ValueError(f'Invalid tbf_stack shape ({tbf_stacks.shape})') + raise RuntimeError(f'Invalid tbf_stack shape ({tbf_stack.shape})') # Subtract dark field if 'data' in reduced_data and 'dark_field' in reduced_data.data: @@ -1251,38 +1515,47 @@ def _gen_bright(self, nxentry, reduced_data): # Plot bright field if self._save_figs: - quick_imshow(tbf, title='bright field', path=self._output_folder, save_fig=True, - save_only=True) + quick_imshow( + tbf, title='bright field', path=self._output_folder, + save_fig=True, save_only=True) # Add bright field to reduced data NXprocess - if 'data' not in reduced_data: + if 'data' not in reduced_data: reduced_data.data = NXdata() reduced_data.data['bright_field'] = tbf - return(reduced_data) + return reduced_data def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): - """Set vertical detector bounds for each image stack. - Right now the range is the same for each set in the image stack. """ - from CHAP.common.models.map import get_scanparser, import_scanparser + Set vertical detector bounds for each image stack.Right now the + range is the same for each set in the image stack. + """ + # Local modules + from CHAP.common.models.map import ( + get_scanparser, + import_scanparser, + ) from CHAP.common.utils.general import is_index_range if self._test_mode: - return(tuple(self._test_config['img_x_bounds'])) + return tuple(self._test_config['img_x_bounds']) # Get the first tomography image and the reference heights image_key = nxentry.instrument.detector.get('image_key', None) if image_key and 'data' in nxentry.instrument.detector: - field_indices = [index for index, key in enumerate(image_key) if key == 0] - first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:]) + field_indices = [index for index, key in enumerate(image_key) + if key == 0] + first_image = np.asarray( + nxentry.instrument.detector.data[field_indices[0],:,:]) theta = float(nxentry.sample.rotation_angle[field_indices[0]]) z_translation_all = nxentry.sample.z_translation[field_indices] vertical_shifts = sorted(list(set(z_translation_all))) num_tomo_stacks = len(vertical_shifts) else: - import_scanparser(nxentry.instrument.source.attrs['station'], - nxentry.instrument.source.attrs['experiment_type']) + import_scanparser( + nxentry.instrument.source.attrs['station'], + nxentry.instrument.source.attrs['experiment_type']) tomo_field_scans = nxentry.spec_scans.tomo_fields num_tomo_stacks = len(tomo_field_scans.keys()) center_stack_index = int(num_tomo_stacks/2) @@ -1290,25 +1563,29 @@ def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): vertical_shifts = [] for i, nxsubentry in enumerate(tomo_field_scans.items()): scan_number = int(nxsubentry[0].split('_')[-1]) - scanparser = get_scanparser(tomo_field_scans.attrs['spec_file'], scan_number) - image_offset = int(nxsubentry[1].instrument.detector.frame_start_number) + scanparser = get_scanparser( + tomo_field_scans.attrs['spec_file'], scan_number) + image_offset = int( + 
nxsubentry[1].instrument.detector.frame_start_number) vertical_shifts.append(nxsubentry[1].sample.z_translation) if i == center_stack_index: - first_image = scanparser.get_detector_data(detector_prefix, image_offset) + first_image = scanparser.get_detector_data( + detector_prefix, image_offset) theta = float(nxsubentry[1].sample.rotation_angle[0]) # Select image bounds - title = f'tomography image at theta={round(theta, 2)+0}' + title = f'tomography image at theta = {round(theta, 2)+0}' if img_x_bounds is not None: if is_index_range(img_x_bounds, ge=0, le=first_image.shape[0]): return img_x_bounds + if self._interactive: + self._logger.warning( + f'Invalid parameter img_x_bounds ({img_x_bounds}), ' + + 'ignoring img_x_bounds') + img_x_bounds = None else: - if self._interactive: - self._logger.warning(f'Invalid parameter img_x_bounds ({img_x_bounds}), ' + - 'ignoring img_x_bounds') - img_x_bounds = None - else: - raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})') + raise ValueError( + f'Invalid parameter img_x_bounds ({img_x_bounds})') if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'): pixel_size = nxentry.instrument.detector.x_pixel_size # Try to get a fit from the bright field @@ -1317,49 +1594,57 @@ def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): x_sum = np.sum(tbf, 1) x_sum_min = x_sum.min() x_sum_max = x_sum.max() - fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan', - guess=True) + fit = Fit.fit_data( + x_sum, 'rectangle', x=np.array(range(len(x_sum))), + form='atan', guess=True) parameters = fit.best_values x_low_fit = parameters.get('center1', None) x_upp_fit = parameters.get('center2', None) sig_low = parameters.get('sigma1', None) sig_upp = parameters.get('sigma2', None) - have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \ - sig_low is not None and sig_upp is not None and \ - 0 <= x_low_fit < x_upp_fit <= x_sum.size and \ - (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1 + have_fit = (fit.success and x_low_fit is not None + and x_upp_fit is not None and sig_low is not None + and sig_upp is not None + and 0 <= x_low_fit < x_upp_fit <= x_sum.size + and (sig_low+sig_upp) / (x_upp_fit-x_low_fit) < 0.1) if have_fit: # Set a 5% margin on each side - margin = 0.05*(x_upp_fit-x_low_fit) + margin = 0.05 * (x_upp_fit-x_low_fit) x_low_fit = max(0, x_low_fit-margin) x_upp_fit = min(tbf_shape[0], x_upp_fit+margin) if num_tomo_stacks == 1: if have_fit: - # Set the default range to enclose the full fitted window + # Set the default range to enclose the full fitted + # window x_low = int(x_low_fit) x_upp = int(x_upp_fit) else: - # Center a default range of 1 mm (RV: can we get this from the slits?) - num_x_min = int((1.0-0.5*pixel_size)/pixel_size) - x_low = int(0.5*(tbf_shape[0]-num_x_min)) + # Center a default range of 1 mm + # RV can we get this from the slits? 
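For background (not part of the patch): the detector-bound fit above models the column sum of the bright field as a smoothed rectangle with two arctan edges, accepts the fit only when the edge widths are small compared to their separation, and then pads the window by a 5% margin. The sketch below reproduces that idea with scipy.optimize.curve_fit on synthetic data; it is an illustration only, not the CHAP Fit.fit_data API:

    import numpy as np
    from scipy.optimize import curve_fit

    def rectangle_atan(x, amplitude, center1, sigma1, center2, sigma2, offset):
        # Smoothed top-hat: a rising and a falling arctan edge.
        return offset + (amplitude/np.pi) * (
            np.arctan((x - center1)/sigma1) - np.arctan((x - center2)/sigma2))

    # Synthetic column sum of a bright field illuminated between rows 200-800.
    x = np.arange(1000, dtype=float)
    x_sum = rectangle_atan(x, 1.0e6, 200.0, 5.0, 800.0, 5.0, 1.0e4)
    x_sum += np.random.default_rng(1).normal(0.0, 500.0, x.size)

    popt, _ = curve_fit(
        rectangle_atan, x, x_sum,
        p0=(x_sum.max() - x_sum.min(), 150.0, 10.0, 850.0, 10.0, x_sum.min()))
    _, x_low_fit, sig_low, x_upp_fit, sig_upp, _ = popt

    # Accept only a sharp-edged fit, then pad with a 5% margin on each side.
    if (0 <= x_low_fit < x_upp_fit <= x_sum.size
            and (sig_low + sig_upp)/(x_upp_fit - x_low_fit) < 0.1):
        margin = 0.05 * (x_upp_fit - x_low_fit)
        x_low = int(max(0, x_low_fit - margin))
        x_upp = int(min(x.size, x_upp_fit + margin))
        print(x_low, x_upp)  # roughly 170 and 830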
+ num_x_min = int((1.0 - 0.5*pixel_size) / pixel_size) + x_low = int((tbf_shape[0]-num_x_min) / 2) x_upp = x_low+num_x_min else: # Get the default range from the reference heights delta_z = vertical_shifts[1]-vertical_shifts[0] for i in range(2, num_tomo_stacks): - delta_z = min(delta_z, vertical_shifts[i]-vertical_shifts[i-1]) + delta_z = min( + delta_z, vertical_shifts[i]-vertical_shifts[i-1]) self._logger.debug(f'delta_z = {delta_z}') - num_x_min = int((delta_z-0.5*pixel_size)/pixel_size) + num_x_min = int((delta_z - 0.5*pixel_size) / pixel_size) self._logger.debug(f'num_x_min = {num_x_min}') if num_x_min > tbf_shape[0]: - self._logger.warning('Image bounds and pixel size prevent seamless stacking') + self._logger.warning( + 'Image bounds and pixel size prevent seamless ' + + 'stacking') if have_fit: - # Center the default range relative to the fitted window - x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min)) + # Center the default range relative to the fitted + # window + x_low = int((x_low_fit+x_upp_fit-num_x_min) / 2) x_upp = x_low+num_x_min else: # Center the default range - x_low = int(0.5*(tbf_shape[0]-num_x_min)) + x_low = int((tbf_shape[0]-num_x_min) / 2) x_upp = x_low+num_x_min if not self._interactive: img_x_bounds = (x_low, x_upp) @@ -1375,10 +1660,11 @@ def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): tmp[x_upp-1,:] = tmp_max quick_imshow(tmp, title=title) del tmp - quick_plot((range(x_sum.size), x_sum), - ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'), - ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'), - title='sum over theta and y') + quick_plot( + (range(x_sum.size), x_sum), + ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'), + ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'), + title='sum over theta and y') print(f'lower bound = {x_low} (inclusive)') print(f'upper bound = {x_upp} (exclusive)]') accept = input_yesno('Accept these bounds (y/n)?', 'y') @@ -1389,33 +1675,39 @@ def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): img_x_bounds = (x_low, x_upp) else: while True: - mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range', - legend='sum over theta and y') + _, img_x_bounds = draw_mask_1d( + x_sum, title='select x data range', + legend='sum over theta and y') if len(img_x_bounds) == 1: break - else: - print(f'Choose a single connected data range') + print('Choose a single connected data range') img_x_bounds = tuple(img_x_bounds[0]) - if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 < - int((delta_z-0.5*pixel_size)/pixel_size)): - self._logger.warning('Image bounds and pixel size prevent seamless stacking') + if (num_tomo_stacks > 1 + and (img_x_bounds[1]-img_x_bounds[0]+1) + < int((delta_z - 0.5*pixel_size) / pixel_size)): + self._logger.warning( + 'Image bounds and pixel size prevent seamless stacking') else: if num_tomo_stacks > 1: - raise NotImplementedError('Selecting image bounds for multiple stacks on FMB') + raise NotImplementedError( + 'Selecting image bounds for multiple stacks on FMB') # For FMB: use the first tomography image to select range - # RV: revisit if they do tomography with multiple stacks + # RV revisit if they do tomography with multiple stacks x_sum = np.sum(first_image, 1) x_sum_min = x_sum.min() x_sum_max = x_sum.max() if self._interactive: - print('Select vertical data reduction range from first tomography image') + print( + 'Select vertical data reduction range from first ' + + 'tomography image') img_x_bounds = select_image_bounds(first_image, 0, title=title) if 
img_x_bounds is None: - raise ValueError('Unable to select image bounds') + raise RuntimeError('Unable to select image bounds') else: if img_x_bounds is None: - self._logger.warning('img_x_bounds unspecified, reduced data of entire '+ - 'detector range') + self._logger.warning( + 'img_x_bounds unspecified, reduce data for entire ' + + 'detector range') img_x_bounds = (0, first_image.shape[0]) # Plot results @@ -1426,68 +1718,90 @@ def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): tmp_max = tmp.max() tmp[x_low,:] = tmp_max tmp[x_upp-1,:] = tmp_max - quick_imshow(tmp, title=title, path=self._output_folder, save_fig=True, save_only=True) - quick_plot((range(x_sum.size), x_sum), - ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'), - ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'), - title='sum over theta and y', path=self._output_folder, save_fig=True, - save_only=True) + quick_imshow( + tmp, title=title, path=self._output_folder, save_fig=True, + save_only=True) + quick_plot( + (range(x_sum.size), x_sum), + ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'), + ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'), + title='sum over theta and y', path=self._output_folder, + save_fig=True, save_only=True) del tmp - return(img_x_bounds) + return img_x_bounds def _set_zoom_or_skip(self): - """Set zoom and/or theta skip to reduce memory the requirement for the analysis. """ -# if input_yesno('\nDo you want to zoom in to reduce memory requirement (y/n)?', 'n'): -# zoom_perc = input_int(' Enter zoom percentage', ge=1, le=100) + Set zoom and/or theta skip to reduce memory the requirement + for the analysis. + """ +# if input_yesno( +# '\nDo you want to zoom in to reduce memory ' +# + 'requirement (y/n)?', 'n'): +# zoom_perc = input_int( +# ' Enter zoom percentage', ge=1, le=100) # else: # zoom_perc = None zoom_perc = None -# if input_yesno('Do you want to skip thetas to reduce memory requirement (y/n)?', 'n'): -# num_theta_skip = input_int(' Enter the number skip theta interval', ge=0, -# lt=num_theta) +# if input_yesno( +# 'Do you want to skip thetas to reduce memory ' +# + 'requirement (y/n)?', 'n'): +# num_theta_skip = input_int( +# ' Enter the number skip theta interval', +# ge=0, lt=num_theta) # else: # num_theta_skip = None num_theta_skip = None self._logger.debug(f'zoom_perc = {zoom_perc}') self._logger.debug(f'num_theta_skip = {num_theta_skip}') - return(zoom_perc, num_theta_skip) + return zoom_perc, num_theta_skip def _gen_tomo(self, nxentry, reduced_data): - """Generate tomography fields. 
- """ - import numexpr as ne - import scipy.ndimage as spi + """Generate tomography fields.""" + # Third party modules + from numexpr import evaluate + from scipy.ndimage import zoom - from CHAP.common.models.map import get_scanparser, import_scanparser + # Local modules + from CHAP.common.models.map import ( + get_scanparser, + import_scanparser, + ) # Get full bright field tbf = np.asarray(reduced_data.data.bright_field) tbf_shape = tbf.shape # Get image bounds - img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0]))) - img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1]))) + img_x_bounds = tuple( + reduced_data.get('img_x_bounds', (0, tbf_shape[0]))) + img_y_bounds = tuple( + reduced_data.get('img_y_bounds', (0, tbf_shape[1]))) # Get resized dark field # if 'dark_field' in data: -# tbf = np.asarray(reduced_data.data.dark_field[ -# img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]) +# tbf = np.asarray( +# reduced_data.data.dark_field[ +# img_x_bounds[0]:img_x_bounds[1], +# img_y_bounds[0]:img_y_bounds[1]]) # else: # self._logger.warning('Dark field unavailable') # tdf = None tdf = None # Resize bright field - if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]): - tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]] + if (img_x_bounds != (0, tbf.shape[0]) + or img_y_bounds != (0, tbf.shape[1])): + tbf = tbf[img_x_bounds[0]:img_x_bounds[1], + img_y_bounds[0]:img_y_bounds[1]] # Get the tomography images image_key = nxentry.instrument.detector.get('image_key', None) if image_key and 'data' in nxentry.instrument.detector: - field_indices_all = [index for index, key in enumerate(image_key) if key == 0] + field_indices_all = [index for index, key in enumerate(image_key) + if key == 0] z_translation_all = nxentry.sample.z_translation[field_indices_all] z_translation_levels = sorted(list(set(z_translation_all))) num_tomo_stacks = len(z_translation_levels) @@ -1498,32 +1812,42 @@ def _gen_tomo(self, nxentry, reduced_data): tomo_stacks = [] for i, z_translation in enumerate(z_translation_levels): field_indices = [field_indices_all[index] - for index, z in enumerate(z_translation_all) if z == z_translation] - horizontal_shift = list(set(nxentry.sample.x_translation[field_indices])) - assert(len(horizontal_shift) == 1) + for index, z in enumerate(z_translation_all) + if z == z_translation] + horizontal_shift = list( + set(nxentry.sample.x_translation[field_indices])) + assert len(horizontal_shift) == 1 horizontal_shifts += horizontal_shift - vertical_shift = list(set(nxentry.sample.z_translation[field_indices])) - assert(len(vertical_shift) == 1) + vertical_shift = list( + set(nxentry.sample.z_translation[field_indices])) + assert len(vertical_shift) == 1 vertical_shifts += vertical_shift - sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices] + sequence_numbers = nxentry.instrument.detector.sequence_number[ + field_indices] if thetas is None: - thetas = np.asarray(nxentry.sample.rotation_angle[field_indices]) \ - [sequence_numbers] + thetas = np.asarray( + nxentry.sample.rotation_angle[ + field_indices])[sequence_numbers] else: - assert(all(thetas[i] == nxentry.sample.rotation_angle[field_indices[index]] - for i, index in enumerate(sequence_numbers))) - assert(list(set(sequence_numbers)) == [i for i in range(len(sequence_numbers))]) - if list(sequence_numbers) == [i for i in range(len(sequence_numbers))]: - tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices]) + 
assert all( + thetas[i] == nxentry.sample.rotation_angle[ + field_indices[index]] + for i, index in enumerate(sequence_numbers)) + assert (list(set(sequence_numbers)) == + list(np.arange(0, (len(sequence_numbers))))) + if (list(sequence_numbers) == + list(np.arange(0, (len(sequence_numbers))))): + tomo_stack = np.asarray( + nxentry.instrument.detector.data[field_indices]) else: - raise ValueError('Unable to load the tomography images') + raise RuntimeError('Unable to load the tomography images') tomo_stacks.append(tomo_stack) else: - import_scanparser(nxentry.instrument.source.attrs['station'], - nxentry.instrument.source.attrs['experiment_type']) + import_scanparser( + nxentry.instrument.source.attrs['station'], + nxentry.instrument.source.attrs['experiment_type']) tomo_field_scans = nxentry.spec_scans.tomo_fields num_tomo_stacks = len(tomo_field_scans.keys()) - center_stack_index = int(num_tomo_stacks/2) detector_prefix = str(nxentry.instrument.detector.local_name) thetas = None tomo_stacks = [] @@ -1531,12 +1855,16 @@ def _gen_tomo(self, nxentry, reduced_data): vertical_shifts = [] for nxsubentry_name, nxsubentry in tomo_field_scans.items(): scan_number = int(nxsubentry_name.split('_')[-1]) - scanparser = get_scanparser(tomo_field_scans.attrs['spec_file'], scan_number) - image_offset = int(nxsubentry.instrument.detector.frame_start_number) + scanparser = get_scanparser( + tomo_field_scans.attrs['spec_file'], scan_number) + image_offset = int( + nxsubentry.instrument.detector.frame_start_number) if thetas is None: thetas = np.asarray(nxsubentry.sample.rotation_angle) num_image = len(thetas) - tomo_stacks.append(scanparser.get_detector_data(detector_prefix, + tomo_stacks.append( + scanparser.get_detector_data( + detector_prefix, (image_offset, image_offset+num_image))) horizontal_shifts.append(nxsubentry.sample.x_translation) vertical_shifts.append(nxsubentry.sample.z_translation) @@ -1544,89 +1872,91 @@ def _gen_tomo(self, nxentry, reduced_data): reduced_tomo_stacks = [] for i, tomo_stack in enumerate(tomo_stacks): # Resize the tomography images - # Right now the range is the same for each set in the image stack. 
- if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]): - t0 = time() + # Right now the range is the same for each set in the stack + if (img_x_bounds != (0, tbf.shape[0]) + or img_y_bounds != (0, tbf.shape[1])): tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1], - img_y_bounds[0]:img_y_bounds[1]].astype('float64') - self._logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds') + img_y_bounds[0]:img_y_bounds[1]].astype('float64') # Subtract dark field if tdf is not None: - t0 = time() try: - with set_numexpr_threads(self._num_core): - ne.evaluate('tomo_stack-tdf', out=tomo_stack) - except: - raise RuntimeError('Unable to subtract dark field, try reducing' - + f' the detector range (currently img_x_bounds = {img_x_bounds}, and' - + f' img_y_bounds = {img_y_bounds})') - self._logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds') + with SetNumexprThreads(self._num_core): + evaluate('tomo_stack-tdf', out=tomo_stack) + except TypeError as e: + sys_exit( + f'\nA {type(e).__name__} occurred while subtracting ' + + 'the dark field with numexpr.evaluate()' + + '\nTry reducing the detector range' + + f'\n(currently img_x_bounds = {img_x_bounds}, and ' + + f'img_y_bounds = {img_y_bounds})\n') # Normalize - t0 = time() try: - with set_numexpr_threads(self._num_core): - ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True) - except: - raise RuntimeError('Unable to normalize the tomography data, try reducing' - + f' the detector range (currently img_x_bounds = {img_x_bounds}, and' - + f' img_y_bounds = {img_y_bounds})') - self._logger.debug(f'Normalizing took {time()-t0:.2f} seconds') + with SetNumexprThreads(self._num_core): + evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True) + except TypeError as e: + sys_exit( + f'\nA {type(e).__name__} occurred while normalizing the ' + + 'tomography data with numexpr.evaluate()' + + '\nTry reducing the detector range' + + f'\n(currently img_x_bounds = {img_x_bounds}, and ' + + f'img_y_bounds = {img_y_bounds})\n') # Remove non-positive values and linearize data - t0 = time() - cutoff = 1.e-6 - with set_numexpr_threads(self._num_core): - ne.evaluate('where(tomo_stack num_core_tomopy_limit: - self._logger.debug(f'Running find_center_vo on {num_core_tomopy_limit} cores ...') - tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core_tomopy_limit) + if num_core > NUM_CORE_TOMOPY_LIMIT: + self._logger.debug( + f'Running find_center_vo on {NUM_CORE_TOMOPY_LIMIT} cores ...') + tomo_center = find_center_vo( + sinogram, ncore=NUM_CORE_TOMOPY_LIMIT) else: - self._logger.debug(f'Running find_center_vo on {num_core} cores ...') - tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core) - self._logger.debug(f'... done in {time()-t0:.2f} seconds') - self._logger.info(f'Finding center using Nghia Vo’s method took {time()-t0:.2f} seconds') + tomo_center = find_center_vo(sinogram, ncore=num_core) + self._logger.info( + f'Finding center using Nghia Vo’s method took {time()-t0:.2f} ' + + 'seconds') center_offset_vo = tomo_center-center - self._logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}') + self._logger.info( + f'Center at row {row} using Nghia Vo’s method = ' + + f'{center_offset_vo:.2f}') t0 = time() - self._logger.debug(f'Running _reconstruct_one_plane on {self._num_core} cores ...') - recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas, - eff_pixel_size, cross_sectional_dim, False, num_core) - self._logger.debug(f'... 
done in {time()-t0:.2f} seconds') - self._logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds') + recon_plane = self._reconstruct_one_plane( + sinogram_t, tomo_center, thetas, eff_pixel_size, + cross_sectional_dim, False, num_core) + self._logger.info( + f'Reconstructing row {row} took {time()-t0:.2f} seconds') title = f'edges row{row} center offset{center_offset_vo:.2f} Vo' self._plot_edges_one_plane(recon_plane, title, path=path) # Try using phase correlation method -# if input_yesno('Try finding center using phase correlation (y/n)?', 'n'): +# if input_yesno(' +# Try finding center using phase correlation (y/n)?', +# 'n'): # t0 = time() -# self._logger.debug(f'Running find_center_pc ...') -# tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=0.1, rotc_guess=tomo_center) +# tomo_center = find_center_pc( +# sinogram, sinogram, tol=0.1, rotc_guess=tomo_center) # error = 1. # while error > tol: # prev = tomo_center -# tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=tol, -# rotc_guess=tomo_center) +# tomo_center = find_center_pc( +# sinogram, sinogram, tol=tol, rotc_guess=tomo_center) # error = np.abs(tomo_center-prev) -# self._logger.debug(f'... done in {time()-t0:.2f} seconds') -# self._logger.info('Finding center using the phase correlation method took '+ -# f'{time()-t0:.2f} seconds') +# self._logger.info( +# 'Finding center using the phase correlation method ' +# + f'took {time()-t0:.2f} seconds') # center_offset = tomo_center-center -# print(f'Center at row {row} using phase correlation = {center_offset:.2f}') +# print( +# f'Center at row {row} using phase correlation = ' +# + f'{center_offset:.2f}') # t0 = time() -# self._logger.debug(f'Running _reconstruct_one_plane on {self._num_core} cores ...') -# recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas, -# eff_pixel_size, cross_sectional_dim, False, num_core) -# self._logger.debug(f'... 
done in {time()-t0:.2f} seconds') -# self._logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds') +# recon_plane = self._reconstruct_one_plane( +# sinogram_t, tomo_center, thetas, eff_pixel_size, +# cross_sectional_dim, False, num_core) +# self._logger.info( +# f'Reconstructing row {row} took {time()-t0:.2f} seconds') # -# title = f'edges row{row} center_offset{center_offset:.2f} PC' +# title = \ +# f'edges row{row} center_offset{center_offset:.2f} PC' # self._plot_edges_one_plane(recon_plane, title, path=path) # Select center location -# if input_yesno('Accept a center location (y) or continue search (n)?', 'y'): - if True: -# center_offset = input_num(' Enter chosen center offset', ge=-center, le=center, -# default=center_offset_vo) - center_offset = center_offset_vo - del sinogram_T - del recon_plane - return float(center_offset) - - # perform center finding search - while True: - center_offset_low = input_int('\nEnter lower bound for center offset', ge=-center, - le=center) - center_offset_upp = input_int('Enter upper bound for center offset', - ge=center_offset_low, le=center) - if center_offset_upp == center_offset_low: - center_offset_step = 1 - else: - center_offset_step = input_int('Enter step size for center offset search', ge=1, - le=center_offset_upp-center_offset_low) - num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step) - center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset) - for center_offset in center_offsets: - if center_offset == center_offset_vo: - continue - t0 = time() - self._logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...') - recon_plane = self._reconstruct_one_plane(sinogram_T, center_offset+center, thetas, - eff_pixel_size, cross_sectional_dim, False, num_core) - self._logger.debug(f'... 
done in {time()-t0:.2f} seconds') - self._logger.info(f'Reconstructing center_offset {center_offset} took '+ - f'{time()-t0:.2f} seconds') - title = f'edges row{row} center_offset{center_offset:.2f}' - self._plot_edges_one_plane(recon_plane, title, path=path) - if input_int('\nContinue (0) or end the search (1)', ge=0, le=1): - break - - del sinogram_T +# if input_yesno( +# 'Accept a center location (y) or continue search (n)?', +# 'y'): +# center_offset = input_num(' Enter chosen center offset', +# ge=-center, le=center, default=center_offset_vo) +# return float(center_offset) + + # Perform center finding search +# while True: +# center_offset_low = input_int( +# '\nEnter lower bound for center offset', ge=-center,le=center) +# center_offset_upp = input_int( +# 'Enter upper bound for center offset', ge=center_offset_low, +# le=center) +# if center_offset_upp == center_offset_low: +# center_offset_step = 1 +# else: +# center_offset_step = input_int( +# 'Enter step size for center offset search', ge=1, +# le=center_offset_upp-center_offset_low) +# num_center_offset = 1 + int( +# (center_offset_upp-center_offset_low) / center_offset_step) +# center_offsets = np.linspace( +# center_offset_low, center_offset_upp, num_center_offset) +# for center_offset in center_offsets: +# if center_offset == center_offset_vo: +# continue +# t0 = time() +# recon_plane = self._reconstruct_one_plane( +# sinogram_t, center_offset+center, thetas, eff_pixel_size, +# cross_sectional_dim, False, num_core) +# self._logger.info( +# f'Reconstructing center_offset {center_offset} took ' +# + 'f{time()-t0:.2f} seconds') +# title = f'edges row{row} center_offset{center_offset:.2f}' +# self._plot_edges_one_plane(recon_plane, title, path=path) +# if input_int('\nContinue (0) or end the search (1)', ge=0, le=1): +# break + + del sinogram_t del recon_plane - center_offset = input_num(' Enter chosen center offset', ge=-center, le=center) +# center_offset = input_num( +# ' Enter chosen center offset', ge=-center, le=center) + center_offset = center_offset_vo + return float(center_offset) - def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size, + def _reconstruct_one_plane( + self, tomo_plane_t, center, thetas, eff_pixel_size, cross_sectional_dim, plot_sinogram=True, num_core=1): - """Invert the sinogram for a single tomography plane. 
- """ - import scipy.ndimage as spi + """Invert the sinogram for a single tomography plane.""" + from scipy.ndimage import gaussian_filter from skimage.transform import iradon - import tomopy + from tomopy import misc - # tomo_plane_T index order: column,theta - assert(0 <= center < tomo_plane_T.shape[0]) - center_offset = center-tomo_plane_T.shape[0]/2 - two_offset = 2*int(np.round(center_offset)) + # tomo_plane_t index order: column,theta + assert 0 <= center < tomo_plane_t.shape[0] + center_offset = center-tomo_plane_t.shape[0]/2 + two_offset = 2 * int(np.round(center_offset)) two_offset_abs = np.abs(two_offset) - max_rad = int(0.55*(cross_sectional_dim/eff_pixel_size)) # 10% slack to avoid edge effects - if max_rad > 0.5*tomo_plane_T.shape[0]: - max_rad = 0.5*tomo_plane_T.shape[0] - dist_from_edge = max(1, int(np.floor((tomo_plane_T.shape[0]-two_offset_abs)/2.)-max_rad)) + # Add 10% slack to max_rad to avoid edge effects + max_rad = int(0.55 * (cross_sectional_dim/eff_pixel_size)) + if max_rad > 0.5*tomo_plane_t.shape[0]: + max_rad = 0.5*tomo_plane_t.shape[0] + dist_from_edge = max(1, int(np.floor( + (tomo_plane_t.shape[0] - two_offset_abs) / 2.0) - max_rad)) if two_offset >= 0: - self._logger.debug(f'sinogram range = [{two_offset+dist_from_edge}, {-dist_from_edge}]') - sinogram = tomo_plane_T[two_offset+dist_from_edge:-dist_from_edge,:] + self._logger.debug( + f'sinogram range = [{two_offset+dist_from_edge}, ' + + f'{-dist_from_edge}]') + sinogram = tomo_plane_t[ + two_offset+dist_from_edge:-dist_from_edge,:] else: - self._logger.debug(f'sinogram range = [{dist_from_edge}, {two_offset-dist_from_edge}]') - sinogram = tomo_plane_T[dist_from_edge:two_offset-dist_from_edge,:] + self._logger.debug( + f'sinogram range = [{dist_from_edge}, ' + + f'{two_offset-dist_from_edge}]') + sinogram = tomo_plane_t[dist_from_edge:two_offset-dist_from_edge,:] if plot_sinogram: - quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto', - path=self._output_folder, save_fig=self._save_figs, save_only=self._save_only, - block=self._block) + quick_imshow( + sinogram.T, f'sinogram center offset{center_offset:.2f}', + aspect='auto', path=self._output_folder, + save_fig=self._save_figs, save_only=self._save_only, + block=self._block) # Inverting sinogram t0 = time() recon_sinogram = iradon(sinogram, theta=thetas, circle=True) - self._logger.debug(f'Inverting sinogram took {time()-t0:.2f} seconds') + self._logger.info(f'Inverting sinogram took {time()-t0:.2f} seconds') del sinogram # Performing Gaussian filtering and removing ring artifacts - recon_parameters = None#self._config.get('recon_parameters') + recon_parameters = None #self._config.get('recon_parameters') if recon_parameters is None: sigma = 1.0 ring_width = 15 else: sigma = recon_parameters.get('gaussian_sigma', 1.0) if not is_num(sigma, ge=0.0): - self._logger.warning(f'Invalid gaussian_sigma ({sigma}) in '+ - '_reconstruct_one_plane, set to a default value of 1.0') + self._logger.warning( + f'Invalid gaussian_sigma ({sigma}) in ' + + '_reconstruct_one_plane, set to a default of 1.0') sigma = 1.0 ring_width = recon_parameters.get('ring_width', 15) if not isinstance(ring_width, int) or ring_width < 0: - self._logger.warning(f'Invalid ring_width ({ring_width}) in '+ - '_reconstruct_one_plane, set to a default value of 15') + self._logger.warning( + f'Invalid ring_width ({ring_width}) in ' + + '_reconstruct_one_plane, set to a default of 15') ring_width = 15 - t0 = time() - recon_sinogram = spi.gaussian_filter(recon_sinogram, 
sigma, mode='nearest') + recon_sinogram = gaussian_filter( + recon_sinogram, sigma, mode='nearest') recon_clean = np.expand_dims(recon_sinogram, axis=0) del recon_sinogram - recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=ring_width, ncore=num_core) - self._logger.debug(f'Filtering and removing ring artifacts took {time()-t0:.2f} seconds') + recon_clean = misc.corr.remove_ring( + recon_clean, rwidth=ring_width, ncore=num_core) return recon_clean def _plot_edges_one_plane(self, recon_plane, title, path=None): + """ + Create an "edges plot" for a single reconstructed tomography + data plane. + """ from skimage.restoration import denoise_tv_chambolle - vis_parameters = None#self._config.get('vis_parameters') + vis_parameters = None #self._config.get('vis_parameters') if vis_parameters is None: weight = 0.1 else: weight = vis_parameters.get('denoise_weight', 0.1) if not is_num(weight, ge=0.0): - self._logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+ - 'set to a default value of 0.1') + self._logger.warning( + f'Invalid weight ({weight}) in _plot_edges_one_plane, ' + + 'set to a default of 0.1') weight = 0.1 edges = denoise_tv_chambolle(recon_plane, weight=weight) vmax = np.max(edges[0,:,:]) vmin = -vmax if path is None: path = self._output_folder - quick_imshow(edges[0,:,:], f'{title} coolwarm', path=path, cmap='coolwarm', - save_fig=self._save_figs, save_only=self._save_only, block=self._block) - quick_imshow(edges[0,:,:], f'{title} gray', path=path, cmap='gray', vmin=vmin, vmax=vmax, - save_fig=self._save_figs, save_only=self._save_only, block=self._block) + quick_imshow( + edges[0,:,:], f'{title} coolwarm', path=path, cmap='coolwarm', + save_fig=self._save_figs, save_only=self._save_only, + block=self._block) + quick_imshow( + edges[0,:,:], f'{title} gray', path=path, cmap='gray', vmin=vmin, + vmax=vmax, save_fig=self._save_figs, save_only=self._save_only, + block=self._block) del edges - def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=[], num_core=1, + def _reconstruct_one_tomo_stack( + self, tomo_stack, thetas, center_offsets=None, num_core=1, algorithm='gridrec'): - """Reconstruct a single tomography stack. - """ - import tomopy + """Reconstruct a single tomography stack.""" + # Third party modules + from tomopy import ( + astra, + misc, + prep, + recon, + ) # tomo_stack order: row,theta,column - # input thetas must be in degrees - # centers_offset: tomography axis shift in pixels relative to column center + # input thetas must be in degrees + # centers_offset: tomography axis shift in pixels relative + # to column center # RV should we remove stripes? # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html # RV should we remove rings? # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html - # RV: Add an option to do (extra) secondary iterations later or to do some sort of convergence test? - if not len(center_offsets): + # RV add an option to do (extra) secondary iterations later or + # to do some sort of convergence test? 
+ if center_offsets is None: centers = np.zeros((tomo_stack.shape[0])) elif len(center_offsets) == 2: - centers = np.linspace(center_offsets[0], center_offsets[1], tomo_stack.shape[0]) + centers = np.linspace( + center_offsets[0], center_offsets[1], tomo_stack.shape[0]) else: if center_offsets.size != tomo_stack.shape[0]: - raise ValueError('center_offsets dimension mismatch in reconstruct_one_tomo_stack') + raise RuntimeError( + 'center_offsets dimension mismatch in ' + + 'reconstruct_one_tomo_stack') centers = center_offsets centers += tomo_stack.shape[2]/2 # Get reconstruction parameters - recon_parameters = None#self._config.get('recon_parameters') + recon_parameters = None #self._config.get('recon_parameters') if recon_parameters is None: sigma = 2.0 secondary_iters = 0 @@ -1862,80 +2237,99 @@ def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=[], num else: sigma = recon_parameters.get('stripe_fw_sigma', 2.0) if not is_num(sigma, ge=0): - self._logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+ - '_reconstruct_one_tomo_stack, set to a default value of 2.0') + self._logger.warning( + f'Invalid stripe_fw_sigma ({sigma}) in ' + + '_reconstruct_one_tomo_stack, set to a default of 2.0') ring_width = 15 secondary_iters = recon_parameters.get('secondary_iters', 0) if not isinstance(secondary_iters, int) or secondary_iters < 0: - self._logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+ - '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)') + self._logger.warning( + f'Invalid secondary_iters ({secondary_iters}) in ' + + '_reconstruct_one_tomo_stack, set to a default of 0 ' + + '(i.e., skip them)') ring_width = 0 ring_width = recon_parameters.get('ring_width', 15) if not isinstance(ring_width, int) or ring_width < 0: - self._logger.warning(f'Invalid ring_width ({ring_width}) in '+ - '_reconstruct_one_plane, set to a default value of 15') + self._logger.warning( + f'Invalid ring_width ({ring_width}) in ' + + '_reconstruct_one_plane, set to a default of 15') ring_width = 15 # Remove horizontal stripe - t0 = time() - if num_core > num_core_tomopy_limit: - self._logger.debug('Running remove_stripe_fw on {num_core_tomopy_limit} cores ...') - tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma, - ncore=num_core_tomopy_limit) + if num_core > NUM_CORE_TOMOPY_LIMIT: + tomo_stack = prep.stripe.remove_stripe_fw( + tomo_stack, sigma=sigma, ncore=NUM_CORE_TOMOPY_LIMIT) else: - self._logger.debug(f'Running remove_stripe_fw on {num_core} cores ...') - tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma, - ncore=num_core) - self._logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds') + tomo_stack = prep.stripe.remove_stripe_fw( + tomo_stack, sigma=sigma, ncore=num_core) # Perform initial image reconstruction self._logger.debug('Performing initial image reconstruction') t0 = time() - self._logger.debug(f'Running recon on {num_core} cores ...') - tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers, - sinogram_order=True, algorithm=algorithm, ncore=num_core) - self._logger.debug(f'... 
done in {time()-t0:.2f} seconds') - self._logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds') + tomo_recon_stack = recon( + tomo_stack, np.radians(thetas), centers, sinogram_order=True, + algorithm=algorithm, ncore=num_core) + self._logger.info( + f'Performing initial image reconstruction took {time()-t0:.2f} ' + + 'seconds') # Run optional secondary iterations if secondary_iters > 0: - self._logger.debug(f'Running {secondary_iters} secondary iterations') - #options = {'method':'SIRT_CUDA', 'proj_type':'cuda', 'num_iter':secondary_iters} - #RV: doesn't work for me: - #"Error: CUDA error 803: system has unsupported display driver/cuda driver combination." - #options = {'method':'SIRT', 'proj_type':'linear', 'MinConstraint': 0, 'num_iter':secondary_iters} - #SIRT did not finish while running overnight - #options = {'method':'SART', 'proj_type':'linear', 'num_iter':secondary_iters} options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0, - 'num_iter':secondary_iters} + self._logger.debug( + f'Running {secondary_iters} secondary iterations') +# options = { +# 'method': 'SIRT_CUDA', +# 'proj_type': 'cuda', +# 'num_iter': secondary_iters +# } +# RV doesn't work for me: +# "Error: CUDA error 803: system has unsupported display driver/cuda driver +# combination." +# options = { +# 'method': 'SIRT', +# 'proj_type': 'linear', +# 'MinConstraint': 0, +# 'num_iter':secondary_iters +# } +# SIRT did not finish while running overnight +# options = { +# 'method': 'SART', +# 'proj_type': 'linear', +# 'num_iter':secondary_iters +# } + options = { + 'method': 'SART', + 'proj_type': 'linear', + 'MinConstraint': 0, + 'num_iter': secondary_iters, + } t0 = time() - self._logger.debug(f'Running recon on {num_core} cores ...') - tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers, - init_recon=tomo_recon_stack, options=options, sinogram_order=True, - algorithm=tomopy.astra, ncore=num_core) - self._logger.debug(f'... done in {time()-t0:.2f} seconds') - self._logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds') + tomo_recon_stack = recon( + tomo_stack, np.radians(thetas), centers, + init_recon=tomo_recon_stack, options=options, + sinogram_order=True, algorithm=astra, ncore=num_core) + self._logger.info( + f'Performing secondary iterations took {time()-t0:.2f} ' + + 'seconds') # Remove ring artifacts - t0 = time() - tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack, - ncore=num_core) - self._logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds') + misc.corr.remove_ring( + tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack, + ncore=num_core) return tomo_recon_stack - def _resize_reconstructed_data(self, data, x_bounds=None, y_bounds=None, z_bounds=None, - z_only=False): - """Resize the reconstructed tomography data. 
- """ + def _resize_reconstructed_data(self, data, x_bounds=None, y_bounds=None, + z_bounds=None, z_only=False): + """Resize the reconstructed tomography data.""" # Data order: row(z),x,y or stack,row(z),x,y if isinstance(data, list): for stack in data: - assert(stack.ndim == 3) + assert stack.ndim == 3 num_tomo_stacks = len(data) tomo_recon_stacks = data else: - assert(data.ndim == 3) + assert data.ndim == 3 num_tomo_stacks = 1 tomo_recon_stacks = [data] @@ -1944,21 +2338,25 @@ def _resize_reconstructed_data(self, data, x_bounds=None, y_bounds=None, z_bound elif not z_only and x_bounds is None: # Selecting x bounds (in yz-plane) tomosum = 0 - [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,2)) - for i in range(num_tomo_stacks)] - select_x_bounds = input_yesno('\nDo you want to change the image x-bounds (y/n)?', 'y') + for i in range(num_tomo_stacks): + tomosum = tomosum + np.sum(tomo_recon_stacks[i], axis=(0,2)) + select_x_bounds = input_yesno( + '\nDo you want to change the image x-bounds (y/n)?', 'y') if not select_x_bounds: x_bounds = None else: accept = False index_ranges = None while not accept: - mask, x_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, - title='select x data range', legend='recon stack sum yz') + _, x_bounds = draw_mask_1d( + tomosum, current_index_ranges=index_ranges, + title='select x data range', + legend='recon stack sum yz') while len(x_bounds) != 1: print('Please select exactly one continuous range') - mask, x_bounds = draw_mask_1d(tomosum, title='select x data range', - legend='recon stack sum yz') + _, x_bounds = draw_mask_1d( + tomosum, title='select x data range', + legend='recon stack sum yz') x_bounds = x_bounds[0] accept = True self._logger.debug(f'x_bounds = {x_bounds}') @@ -1968,53 +2366,60 @@ def _resize_reconstructed_data(self, data, x_bounds=None, y_bounds=None, z_bound elif not z_only and y_bounds is None: # Selecting y bounds (in xz-plane) tomosum = 0 - [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,1)) - for i in range(num_tomo_stacks)] - select_y_bounds = input_yesno('\nDo you want to change the image y-bounds (y/n)?', 'y') + for i in range(num_tomo_stacks): + tomosum = tomosum + np.sum(tomo_recon_stacks[i], axis=(0,1)) + select_y_bounds = input_yesno( + '\nDo you want to change the image y-bounds (y/n)?', 'y') if not select_y_bounds: y_bounds = None else: accept = False index_ranges = None while not accept: - mask, y_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, - title='select x data range', legend='recon stack sum xz') + _, y_bounds = draw_mask_1d( + tomosum, current_index_ranges=index_ranges, + title='select x data range', + legend='recon stack sum xz') while len(y_bounds) != 1: print('Please select exactly one continuous range') - mask, y_bounds = draw_mask_1d(tomosum, title='select x data range', - legend='recon stack sum xz') + _, y_bounds = draw_mask_1d( + tomosum, title='select x data range', + legend='recon stack sum xz') y_bounds = y_bounds[0] accept = True self._logger.debug(f'y_bounds = {y_bounds}') - # Selecting z bounds (in xy-plane) (only valid for a single image stack) + # Selecting z bounds (in xy-plane) + # (only valid for a single image stack) if z_bounds == (-1, -1): z_bounds = None elif z_bounds is None and num_tomo_stacks != 1: tomosum = 0 - [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(1,2)) - for i in range(num_tomo_stacks)] - select_z_bounds = input_yesno('Do you want to change the image z-bounds (y/n)?', 'n') + for i in range(num_tomo_stacks): + tomosum 
= tomosum + np.sum(tomo_recon_stacks[i], axis=(1,2)) + select_z_bounds = input_yesno( + 'Do you want to change the image z-bounds (y/n)?', 'n') if not select_z_bounds: z_bounds = None else: accept = False index_ranges = None while not accept: - mask, z_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, - title='select x data range', legend='recon stack sum xy') + _, z_bounds = draw_mask_1d( + tomosum, current_index_ranges=index_ranges, + title='select x data range', + legend='recon stack sum xy') while len(z_bounds) != 1: print('Please select exactly one continuous range') - mask, z_bounds = draw_mask_1d(tomosum, title='select x data range', - legend='recon stack sum xy') + _, z_bounds = draw_mask_1d( + tomosum, title='select x data range', + legend='recon stack sum xy') z_bounds = z_bounds[0] accept = True self._logger.debug(f'z_bounds = {z_bounds}') - return(x_bounds, y_bounds, z_bounds) + return x_bounds, y_bounds, z_bounds if __name__ == '__main__': - from CHAP.processor import main main() - diff --git a/CHAP/tomo/reader.py b/CHAP/tomo/reader.py index 709d3d3..b90fde1 100755 --- a/CHAP/tomo/reader.py +++ b/CHAP/tomo/reader.py @@ -1,5 +1,7 @@ -#!/usr/bin/env python +'''Tomography command line reader''' + +# Local modules +from CHAP.reader import main if __name__ == '__main__': - from CHAP.reader import main main() diff --git a/CHAP/tomo/writer.py b/CHAP/tomo/writer.py index b00fa9f..47e944b 100755 --- a/CHAP/tomo/writer.py +++ b/CHAP/tomo/writer.py @@ -1,5 +1,7 @@ -#!/usr/bin/env python +'''Tomography command line writer''' + +# Local modules +from CHAP.writer import main if __name__ == '__main__': - from CHAP.writer import main main() From 8aa335e217b66e530b674deabda9c77fd1aba796 Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Thu, 20 Apr 2023 13:19:17 -0400 Subject: [PATCH 3/6] fix: fixed errors introduced in style changes CHAP/common/utils/fit.py: - Added back default values for interactive and guess in Fit.fit() - Fixed centers to a keyword argument to Fit.fit() in fit_multipeak - Added np.ndarray as a valid type for centers in fit_multipeak CHAP/common/utils/scanparsers.py: Fixed misplaced parenthesis in get_horizontal_shift() and get_vertical_shift() CHAP/edd/models.py: Fixed name of get_unique_ds to get_ds_unique CHAP/runner.py: Removed a line that was erroneously left in merge --- CHAP/common/utils/fit.py | 8 ++++++-- CHAP/common/utils/scanparsers.py | 16 ++++++++-------- CHAP/edd/models.py | 3 ++- CHAP/runner.py | 1 - 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/CHAP/common/utils/fit.py b/CHAP/common/utils/fit.py index 160b9fa..89c9f15 100755 --- a/CHAP/common/utils/fit.py +++ b/CHAP/common/utils/fit.py @@ -866,11 +866,15 @@ def fit(self, **kwargs): raise ValueError( 'Invalid value of keyword argument interactive ' + f'({interactive})') + else: + interactive = False if 'guess' in kwargs: guess = kwargs.pop('guess') if not isinstance(guess, bool): raise ValueError( f'Invalid value of keyword argument guess ({guess})') + else: + guess = False if 'try_linear_fit' in kwargs: try_linear_fit = kwargs.pop('try_linear_fit') if not isinstance(try_linear_fit, bool): @@ -1610,7 +1614,7 @@ def fit_multipeak( raise ValueError(f'Invalid parameter x_eval ({x_eval})') fit = cls(y, x=x, normalize=normalize) success = fit.fit( - centers, fit_type=fit_type, peak_models=peak_models, + centers=centers, fit_type=fit_type, peak_models=peak_models, fwhm_max=fwhm_max, center_exprs=center_exprs, background=background, print_report=print_report, plot=plot) if x_eval 
is None: @@ -1630,7 +1634,7 @@ def fit( """Fit the model to the input data.""" if centers is None: raise ValueError('Missing required parameter centers') - if not isinstance(centers, (int, float, tuple, list)): + if not isinstance(centers, (int, float, tuple, list, np.ndarray)): raise ValueError(f'Invalid parameter centers ({centers})') self._fwhm_max = fwhm_max self._create_model( diff --git a/CHAP/common/utils/scanparsers.py b/CHAP/common/utils/scanparsers.py index 5e3cc96..e11f271 100755 --- a/CHAP/common/utils/scanparsers.py +++ b/CHAP/common/utils/scanparsers.py @@ -942,19 +942,19 @@ def get_theta_vals(self): 'num': int(self.pars['nframes_real'])} def get_horizontal_shift(self): - horizontal_shift = self.pars.get('rams4x', - self.pars.get('ramsx'), None) + horizontal_shift = self.pars.get( + 'rams4x', self.pars.get('ramsx', None)) if horizontal_shift is None: - raise RuntimeError(f'{self.scan_title}: cannot determine the ' - + 'horizontal shift') + raise RuntimeError( + f'{self.scan_title}: cannot determine the horizontal shift') return horizontal_shift def get_vertical_shift(self): - vertical_shift = self.pars.get('rams4z', - self.pars.get('ramsz'), None) + vertical_shift = self.pars.get( + 'rams4z', self.pars.get('ramsz', None)) if vertical_shift is None: - raise RuntimeError(f'{self.scan_title}: cannot determine the ' - + 'vertical shift') + raise RuntimeError( + f'{self.scan_title}: cannot determine the vertical shift') return vertical_shift def get_starting_image_index(self): diff --git a/CHAP/edd/models.py b/CHAP/edd/models.py index c92a4df..573fc43 100644 --- a/CHAP/edd/models.py +++ b/CHAP/edd/models.py @@ -180,7 +180,8 @@ def unique_ds(self): :rtype: np.ndarray, np.ndarray ''' - unique_hkls, unique_ds = self.material().get_unique_ds(tth_tol=self.hkl_tth_tol, tth_max=self.tth_max) + unique_hkls, unique_ds = self.material().get_ds_unique( + tth_tol=self.hkl_tth_tol, tth_max=self.tth_max) return(unique_hkls, unique_ds) diff --git a/CHAP/runner.py b/CHAP/runner.py index 53154b4..c1bc24d 100755 --- a/CHAP/runner.py +++ b/CHAP/runner.py @@ -73,7 +73,6 @@ def runner(opts): else: name = item kwargs = {} - kwargs['interactive'] = opts.interactive mod_name, cls_name = name.split('.') module = __import__(f'CHAP.{mod_name}', fromlist=[cls_name]) obj = getattr(module, cls_name)() From db16af631a77b5fcb5ff24c87c2390d2842f3f72 Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Thu, 20 Apr 2023 15:31:28 -0400 Subject: [PATCH 4/6] style: added the .pylintrc file --- .pylintrc | 716 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 716 insertions(+) create mode 100644 .pylintrc diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..de19556 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,716 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. 
+#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist=pydantic + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules=nexusformat.nexus + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins=pylint_pydantic + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.11 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +source-roots= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. 
Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=a, + ai, + ax, + b, + d, + e, + ex, + i, + j, + k, + le, + lt, + m, + n, + Run, + s, + t0, + v, + x, + y, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs=[\w]*E0 + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. 
+module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type alias names. If left empty, type +# alias names will be checked with the set naming style. +#typealias-rgx= + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + asyncSetUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=1 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. 
This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=80 + +# Maximum number of lines in a module. +max-module-lines=5000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level=asteval.Interpreter, + asteval.get_ast_names, + asyncio, + copy.deepcopy, + h5py, + io.BytesIO, + logging.getLogger, + multiprocessing.cpu_count, + numpy, + nexusformat.nexus.nxload, + nexusformat.nexus.NXcollection, + nexusformat.nexus.NXdata, + nexusformat.nexus.NXdetector, + nexusformat.nexus.NXentry, + nexusformat.nexus.NXfield, + nexusformat.nexus.NXgroup, + nexusformat.nexus.NXinstrument, + nexusformat.nexus.NXobject, + nexusformat.nexus.NXprocess, + nexusformat.nexus.NXroot, + nexusformat.nexus.NXsample, + nexusformat.nexus.NXsource, + nexusformat.nexus.NXsubentry, + numexpr.MAX_THREADS, + numexpr.evaluate, + numexpr.set_num_threads, + pathlib.Path, + pyFAI, + pyspec.file.tiff.TiffFile, + requests, + scipy.constants.physical_constants, + scipy.ndimage.gaussian_filter, + scipy.ndimage.zoom, + skimage.transform.iradon, + skimage.restoration.denoise_tv_chambolle, + tomopy.astra, + tomopy.find_center_vo, + tomopy.misc, + tomopy.prep, + tomopy.recon, + yaml.safe_dump, + tarfile, + xarray.DataArray, + yaml, + CHAP.common.MapProcessor, + CHAP.common.models.integration.IntegrationConfig, + CHAP.common.models.map.MapConfig, + CHAP.common.models.map.get_scanparser, + CHAP.common.models.map.import_scanparser, + CHAP.common.utils.fit.Fit, + CHAP.common.utils.fit.FitMultipeak, + CHAP.common.utils.general.index_nearest, + CHAP.common.utils.general.is_int_pair, + CHAP.common.utils.general.is_index_range, + CHAP.common.utils.scanparsers.FMBSAXSWAXSScanParser, + CHAP.common.utils.scanparsers.FMBRotationScanParser, + CHAP.common.utils.scanparsers.FMBXRFScanParser, + CHAP.common.utils.scanparsers.SMBLinearScanParser, + CHAP.common.utils.scanparsers.SMBMCAScanParser, + CHAP.common.utils.scanparsers.SMBRotationScanParser, + CHAP.edd.models.MCACeriaCalibrationConfig, + CHAP.tomo.models.TomoSetupConfig, + CHAP.tomo.models.TomoReduceConfig, + CHAP.tomo.models.TomoFindCenterConfig, + CHAP.tomo.models.TomoReconstructConfig, + CHAP.tomo.models.TomoCombineConfig, + MLaaS.tfaas_client.predictImage + + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. 
+known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=new + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + logging-fstring-interpolation, + logging-not-lazy, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member, + logging-format-interpolation + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. 
This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. No available dictionaries : You need to install +# both the python package and the system dependency for enchant to work.. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. 
+ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io From 9b79497a5fb9ae3652f7240489f596e71b3bec1a Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Fri, 21 Apr 2023 10:54:24 -0400 Subject: [PATCH 5/6] style: fixed style errors found by pycodestyle Ran pycodestyle with --ignore=W503,E231,E226,E265,E722 --- CHAP/common/models/integration.py | 40 ++-- CHAP/common/models/map.py | 52 +++--- CHAP/common/processor.py | 68 +++---- CHAP/common/utils/fit.py | 298 +++++++++++++++--------------- CHAP/common/utils/general.py | 138 ++++++++++---- CHAP/common/utils/material.py | 67 ++++--- CHAP/common/utils/scanparsers.py | 237 +++++++++++++----------- CHAP/common/writer.py | 10 +- CHAP/edd/models.py | 109 ++++++----- CHAP/edd/processor.py | 88 +++++---- CHAP/edd/reader.py | 2 +- CHAP/edd/writer.py | 2 +- CHAP/inference/processor.py | 15 +- CHAP/inference/reader.py | 2 +- CHAP/inference/writer.py | 2 +- CHAP/pipeline.py | 4 +- CHAP/processor.py | 21 ++- CHAP/reader.py | 26 +-- CHAP/runner.py | 27 +-- CHAP/saxswaxs/processor.py | 2 +- CHAP/saxswaxs/reader.py | 2 +- CHAP/saxswaxs/writer.py | 2 +- CHAP/sin2psi/processor.py | 2 +- CHAP/sin2psi/reader.py | 2 +- CHAP/sin2psi/writer.py | 2 +- CHAP/tomo/models.py | 35 ++-- CHAP/tomo/processor.py | 150 ++++++++------- CHAP/writer.py | 21 ++- 28 files changed, 775 insertions(+), 651 deletions(-) diff --git a/CHAP/common/models/integration.py b/CHAP/common/models/integration.py index e230198..34da0a7 100644 --- a/CHAP/common/models/integration.py +++ b/CHAP/common/models/integration.py @@ -200,7 +200,7 @@ def validate_radial_units(cls, radial_units): return radial_units raise ValueError( f'Invalid radial units: {radial_units}. 
' - + f'Must be one of {", ".join(RADIAL_UNITS.keys())}') + f'Must be one of {", ".join(RADIAL_UNITS.keys())}') @validator('azimuthal_units', allow_reuse=True) def validate_azimuthal_units(cls, azimuthal_units): @@ -220,7 +220,7 @@ def validate_azimuthal_units(cls, azimuthal_units): return azimuthal_units raise ValueError( f'Invalid azimuthal units: {azimuthal_units}. ' - + f'Must be one of {", ".join(AZIMUTHAL_UNITS.keys())}') + f'Must be one of {", ".join(AZIMUTHAL_UNITS.keys())}') def validate_range_max(range_name:str): """Validate the maximum value of an integration range. @@ -252,8 +252,8 @@ def _validate_range_max(cls, range_max, values): return range_max raise ValueError( 'Maximum value of integration range must be ' - + 'greater than minimum value of integration range ' - + f'({range_name}_min={range_min}).') + 'greater than minimum value of integration range ' + f'({range_name}_min={range_min}).') return _validate_range_max _validate_radial_max = validator( @@ -289,9 +289,9 @@ def validate_for_map_config(self, map_config:BaseModel): except Exception as exc: raise RuntimeError( 'Could not find data file for detector prefix ' - + f'{detector.prefix} ' - + f'on scan number {scan_number} ' - + f'in spec file {scans.spec_file}') from exc + f'{detector.prefix} ' + f'on scan number {scan_number} ' + f'in spec file {scans.spec_file}') from exc def get_azimuthal_adjustments(self): """To enable a continuous range of integration in the @@ -397,17 +397,17 @@ def get_radially_integrated_data(self, intensity_each_detector = [] variance_each_detector = [] integrators = self.get_azimuthal_integrators() - for integrator,detector in zip(integrators,self.detectors): + for integrator, detector in zip(integrators, self.detectors): detector_data = spec_scans.get_detector_data( [detector], scan_number, scan_step_index)[0] result = integrator.integrate_radial( detector_data, self.azimuthal_npt, unit=self.azimuthal_units, - azimuth_range=(chi_min,chi_max), + azimuth_range=(chi_min, chi_max), radial_unit=self.radial_units, - radial_range=(self.radial_min,self.radial_max), - mask=detector.mask_array) #, error_model=self.error_model) + radial_range=(self.radial_min, self.radial_max), + mask=detector.mask_array) # , error_model=self.error_model) intensity_each_detector.append(result.intensity) if result.sigma is not None: variance_each_detector.append(result.sigma**2) @@ -415,7 +415,7 @@ def get_radially_integrated_data(self, # together intensity = np.nansum(intensity_each_detector, axis=0) # Ignore data at values of chi for which there was no data - intensity = np.where(intensity==0, np.nan, intensity) + intensity = np.where(intensity == 0, np.nan, intensity) if len(intensity_each_detector) != len(variance_each_detector): return intensity @@ -501,17 +501,17 @@ def integrated_data_coordinates(self): """ if self.integration_type == 'azimuthal': return get_integrated_data_coordinates( - radial_range=(self.radial_min,self.radial_max), + radial_range=(self.radial_min, self.radial_max), radial_npt=self.radial_npt) if self.integration_type == 'radial': return get_integrated_data_coordinates( - azimuthal_range=(self.azimuthal_min,self.azimuthal_max), + azimuthal_range=(self.azimuthal_min, self.azimuthal_max), azimuthal_npt=self.azimuthal_npt) if self.integration_type == 'cake': return get_integrated_data_coordinates( - radial_range=(self.radial_min,self.radial_max), + radial_range=(self.radial_min, self.radial_max), radial_npt=self.radial_npt, - azimuthal_range=(self.azimuthal_min,self.azimuthal_max), + 
azimuthal_range=(self.azimuthal_min, self.azimuthal_max), azimuthal_npt=self.azimuthal_npt) return None @@ -521,7 +521,7 @@ def integrated_data_dims(self): data produced by this instance of `IntegrationConfig`. """ directions = list(self.integrated_data_coordinates.keys()) - dim_names = [getattr(self, f'{direction}_units') \ + dim_names = [getattr(self, f'{direction}_units') for direction in directions] return dim_names @@ -531,8 +531,8 @@ def integrated_data_shape(self): data produced by this instance of `IntegrationConfig` for a single scan step. """ - return tuple([len(coordinate_values) \ - for coordinate_name,coordinate_values \ + return tuple([len(coordinate_values) + for coordinate_name, coordinate_values in self.integrated_data_coordinates.items()]) @@ -615,7 +615,7 @@ def get_multi_geometry_integrator(poni_files:tuple, radial_unit:str, ais, unit=radial_unit, radial_range=radial_range, - azimuth_range=(chi_min,chi_max), + azimuth_range=(chi_min, chi_max), wavelength=sum([ai.wavelength for ai in ais])/len(ais), chi_disc=chi_disc) return multi_geometry diff --git a/CHAP/common/models/map.py b/CHAP/common/models/map.py index 82731cf..f50552f 100644 --- a/CHAP/common/models/map.py +++ b/CHAP/common/models/map.py @@ -73,7 +73,7 @@ def validate_scan_numbers(cls, scan_numbers, values): scan = spec_scans.get_scan_by_number(scan_number) if scan is None: raise ValueError( - f'There is no scan number {scan_number} in {spec_file}') + f'No scan number {scan_number} in {spec_file}') return scan_numbers @property @@ -112,13 +112,9 @@ def get_index(self, scan_number:int, scan_step_index:int, map_config): index = () for independent_dimension in map_config.independent_dimensions: coordinate_index = list( - map_config.coords[independent_dimension.label] - ).index( + map_config.coords[independent_dimension.label]).index( independent_dimension.get_value( - self, - scan_number, - scan_step_index) - ) + self, scan_number, scan_step_index)) index = (coordinate_index, *index) return index @@ -200,7 +196,7 @@ def validate_label(cls, label): """Validate that the supplied `label` does not conflict with any of the values for `label` reserved for certain data needed to perform corrections. - + :param label: The value of `label` to validate :type label: str :raises ValueError: If `label` is one of the reserved values. @@ -212,13 +208,13 @@ def validate_label(cls, label): and label in CorrectionsData.reserved_labels()): raise ValueError( f'{cls.__name__}.label may not be any of the following ' - + f'reserved values: {CorrectionsData.reserved_labels()}') + f'reserved values: {CorrectionsData.reserved_labels()}') return label def validate_for_station(self, station:str): """Validate this instance of `PointByPointScanData` for a certain choice of station (beamline). - + :param station: The name of the station (in 'idxx' format). :type station: str :raises TypeError: If the station is not compatible with the @@ -231,7 +227,7 @@ def validate_for_station(self, station:str): and self.data_type == 'smb_par'): raise TypeError( f'{self.__class__.__name__}.data_type may not be "smb_par" ' - + f'when station is "{station}"') + f'when station is "{station}"') def validate_for_spec_scans( self, spec_scans:list[SpecScans], @@ -243,8 +239,8 @@ def validate_for_spec_scans( be checked for the presence of the data represented by this instance of `PointByPointScanData` :type spec_scans: list[SpecScans] - :param scan_step_index: A specific scan step index to validate, defaults - to `'all'`. 
+ :param scan_step_index: A specific scan step index to validate, + defaults to `'all'`. :type scan_step_index: Union[Literal['all'],int], optional :raises RuntimeError: If the data represented by this instance of `PointByPointScanData` is missing for the specified scan steps. @@ -257,24 +253,24 @@ def validate_for_spec_scans( if scan_step_index == 'all': scan_step_index_range = range(scanparser.spec_scan_npts) else: - scan_step_index_range = range(scan_step_index, - scan_step_index + 1) + scan_step_index_range = range( + scan_step_index, 1+scan_step_index) for index in scan_step_index_range: try: self.get_value(scans, scan_number, index) except: raise RuntimeError( f'Could not find data for {self.name} ' - + f'(data_type "{self.data_type}") ' - + f'on scan number {scan_number} ' - + f'for index {index} ' - + f'in spec file {scans.spec_file}') + f'(data_type "{self.data_type}") ' + f'on scan number {scan_number} ' + f'for index {index} ' + f'in spec file {scans.spec_file}') def get_value(self, spec_scans:SpecScans, scan_number:int, scan_step_index:int): """Return the value recorded for this instance of `PointByPointScanData` at a specific scan step. - + :param spec_scans: An instance of `SpecScans` in which the requested scan step occurs. :type spec_scans: SpecScans @@ -304,6 +300,7 @@ def get_value(self, spec_scans:SpecScans, self.name) return None + @cache def get_spec_motor_value(spec_file:str, scan_number:int, scan_step_index:int, spec_mnemonic:str): @@ -582,8 +579,8 @@ def validate_experiment_type(cls, value, values): if value not in allowed_experiment_types: raise ValueError( f'For station {station}, allowed experiment types are ' - + f'{", ".join(allowed_experiment_types)}. ' - + f'Supplied experiment type {value} is not allowed.') + f'{", ".join(allowed_experiment_types)}. ' + f'Supplied experiment type {value} is not allowed.') return value @property @@ -618,8 +615,8 @@ def dims(self): """Return a tuple of the independent dimension labels for the map. """ - return [point_by_point_scan_data.label \ - for point_by_point_scan_data \ + return [point_by_point_scan_data.label + for point_by_point_scan_data in self.independent_dimensions[::-1]] @property @@ -639,10 +636,9 @@ def all_scalar_data(self): corrections-data-related fields, as well as any additional items in the optional `scalar_data` field. """ - return [getattr(self,l,None) \ - for l in CorrectionsData.reserved_labels() \ - if getattr(self,l,None) is not None] \ - + self.scalar_data + return [getattr(self, label, None) + for label in CorrectionsData.reserved_labels() + if getattr(self, label, None) is not None] + self.scalar_data def import_scanparser(station, experiment): diff --git a/CHAP/common/processor.py b/CHAP/common/processor.py index 21c4f7c..8870647 100755 --- a/CHAP/common/processor.py +++ b/CHAP/common/processor.py @@ -4,11 +4,12 @@ """ File : processor.py Author : Valentin Kuznetsov -Description: Module for Processors used in multiple experiment-specific workflows. +Description: Module for Processors used in multiple experiment-specific + workflows. 
""" # system modules -import json +from json import dumps from time import time # local modules @@ -153,8 +154,8 @@ def get_configs(self, data): map_config = MapConfig(**map_config) integration_config = IntegrationConfig(**integration_config) - self.logger.debug('Got configuration objects in ' - + f'{time()-t0:.3f} seconds') + self.logger.debug( + f'Got configuration objects in {time()-t0:.3f} seconds') return map_config, integration_config @@ -183,13 +184,13 @@ def get_nxprocess(self, map_config, integration_config): nxprocess = NXprocess(name=integration_config.title) - nxprocess.map_config = json.dumps(map_config.dict()) - nxprocess.integration_config = json.dumps(integration_config.dict()) + nxprocess.map_config = dumps(map_config.dict()) + nxprocess.integration_config = dumps(integration_config.dict()) nxprocess.program = 'pyFAI' nxprocess.version = pyFAI.version - for k,v in integration_config.dict().items(): + for k, v in integration_config.dict().items(): if k == 'detectors': continue nxprocess.attrs[k] = v @@ -200,7 +201,8 @@ def get_nxprocess(self, map_config, integration_config): nxdetector.local_name = detector.prefix nxdetector.distance = detector.azimuthal_integrator.dist nxdetector.distance.attrs['units'] = 'm' - nxdetector.calibration_wavelength = detector.azimuthal_integrator.wavelength + nxdetector.calibration_wavelength = \ + detector.azimuthal_integrator.wavelength nxdetector.calibration_wavelength.attrs['units'] = 'm' nxdetector.attrs['poni_file'] = str(detector.poni_file) nxdetector.attrs['mask_file'] = str(detector.mask_file) @@ -213,7 +215,7 @@ def get_nxprocess(self, map_config, integration_config): *map_config.dims, *integration_config.integrated_data_dims ) - for i,dim in enumerate(map_config.independent_dimensions[::-1]): + for i, dim in enumerate(map_config.independent_dimensions[::-1]): nxprocess.data[dim.label] = NXfield( value=map_config.coords[dim.label], units=dim.units, @@ -222,8 +224,8 @@ def get_nxprocess(self, map_config, integration_config): 'local_name': dim.name}) nxprocess.data.attrs[f'{dim.label}_indices'] = i - for i,(coord_name,coord_values) in \ - enumerate(integration_config.integrated_data_coordinates.items()): + for i, (coord_name, coord_values) in enumerate( + integration_config.integrated_data_coordinates.items()): if coord_name == 'radial': type_ = pyFAI.units.RADIAL_UNITS elif coord_name == 'azimuthal': @@ -235,19 +237,16 @@ def get_nxprocess(self, map_config, integration_config): nxprocess.data.attrs[f'{coord_units.name}_indices'] = i + len( map_config.coords) nxprocess.data[coord_units.name].units = coord_units.unit_symbol - nxprocess.data[coord_units.name].attrs['long_name'] = coord_units.label + nxprocess.data[coord_units.name].attrs['long_name'] = \ + coord_units.label nxprocess.data.attrs['signal'] = 'I' nxprocess.data.I = NXfield( value=np.empty( (*tuple( - [len(coord_values) \ - for coord_name,coord_values \ - in map_config.coords.items()][::-1] - ), - *integration_config.integrated_data_shape - ) - ), + [len(coord_values) for coord_name, coord_values + in map_config.coords.items()][::-1]), + *integration_config.integrated_data_shape)), units='a.u', attrs={'long_name':'Intensity (a.u)'}) @@ -255,16 +254,16 @@ def get_nxprocess(self, map_config, integration_config): if integration_config.integration_type == 'azimuthal': integration_method = integrator.integrate1d integration_kwargs = { - 'lst_mask': [detector.mask_array \ - for detector \ + 'lst_mask': [detector.mask_array + for detector in integration_config.detectors], 'npt': 
integration_config.radial_npt } elif integration_config.integration_type == 'cake': integration_method = integrator.integrate2d integration_kwargs = { - 'lst_mask': [detector.mask_array \ - for detector \ + 'lst_mask': [detector.mask_array + for detector in integration_config.detectors], 'npt_rad': integration_config.radial_npt, 'npt_azim': integration_config.azimuthal_npt, @@ -287,14 +286,14 @@ def get_nxprocess(self, map_config, integration_config): scan_number, scan_step_index) result = integration_processor.process( - (detector_data, integration_method, integration_kwargs)) + (detector_data, + integration_method, integration_kwargs)) nxprocess.data.I[map_index] = result.intensity for detector in integration_config.detectors: - nxprocess[detector.prefix].raw_data_files[map_index] = \ + nxprocess[detector.prefix].raw_data_files[map_index] =\ scanparser.get_detector_data_file( - detector.prefix, - scan_step_index) + detector.prefix, scan_step_index) self.logger.debug(f'Constructed NXprocess in {time()-t0:.3f} seconds') @@ -375,7 +374,7 @@ def get_nxentry(map_config): nxentry = NXentry(name=map_config.title) - nxentry.map_config = json.dumps(map_config.dict()) + nxentry.map_config = dumps(map_config.dict()) nxentry[map_config.sample.name] = NXsample(**map_config.sample.dict()) @@ -386,11 +385,11 @@ def get_nxentry(map_config): nxentry.spec_scans[scans.scanparsers[0].scan_name] = \ NXfield(value=scans.scan_numbers, dtype='int8', - attrs={'spec_file':str(scans.spec_file)}) + attrs={'spec_file': str(scans.spec_file)}) nxentry.data = NXdata() nxentry.data.attrs['axes'] = map_config.dims - for i,dim in enumerate(map_config.independent_dimensions[::-1]): + for i, dim in enumerate(map_config.independent_dimensions[::-1]): nxentry.data[dim.label] = NXfield( value=map_config.coords[dim.label], units=dim.units, @@ -622,17 +621,18 @@ def _process(self, data): content = data['data'] encoding = data['encoding'] - self.logger.debug(f'Decoding content of type {type(content)} ' - + f'with {encoding}') + self.logger.debug( + f'Decoding content of type {type(content)} with {encoding}') try: content = content.decode(encoding) except: self.logger.warning('Failed to decode content of type ' - + f'{type(content)} with {encoding}') + f'{type(content)} with {encoding}') return content + class XarrayToNexusProcessor(Processor): """A Processor to convert the data in an `xarray` structure to an `nexusformat.nexus.NXdata`. @@ -659,6 +659,7 @@ def _process(self, data): return NXdata(signal=signal, axes=axes) + class XarrayToNumpyProcessor(Processor): """A Processor to convert the data in an `xarray.DataArray` structure to an `numpy.ndarray`. 
@@ -675,6 +676,7 @@ def _process(self, data): return data.data + if __name__ == '__main__': from CHAP.processor import main main() diff --git a/CHAP/common/utils/fit.py b/CHAP/common/utils/fit.py index 89c9f15..728dc93 100755 --- a/CHAP/common/utils/fit.py +++ b/CHAP/common/utils/fit.py @@ -69,8 +69,8 @@ index_nearest, input_num, quick_plot, - #eval_expr, ) +# eval_expr, logger = getLogger(__name__) FLOAT_MIN = float_info.min @@ -80,20 +80,21 @@ fwhm_factor = { 'gaussian': 'fwhm/(2*sqrt(2*log(2)))', 'lorentzian': '0.5*fwhm', - 'splitlorentzian': '0.5*fwhm', # sigma = sigma_r - 'voight': '0.2776*fwhm', # sigma = gamma - 'pseudovoight': '0.5*fwhm', # fraction = 0.5 + 'splitlorentzian': '0.5*fwhm', # sigma = sigma_r + 'voight': '0.2776*fwhm', # sigma = gamma + 'pseudovoight': '0.5*fwhm', # fraction = 0.5 } # amplitude = height_factor*height*fwhm height_factor = { 'gaussian': 'height*fwhm*0.5*sqrt(pi/log(2))', 'lorentzian': 'height*fwhm*0.5*pi', - 'splitlorentzian': 'height*fwhm*0.5*pi', # sigma = sigma_r - 'voight': '3.334*height*fwhm', # sigma = gamma - 'pseudovoight': '1.268*height*fwhm', # fraction = 0.5 + 'splitlorentzian': 'height*fwhm*0.5*pi', # sigma = sigma_r + 'voight': '3.334*height*fwhm', # sigma = gamma + 'pseudovoight': '1.268*height*fwhm', # fraction = 0.5 } + class Fit: """ Wrapper class for lmfit. @@ -122,7 +123,7 @@ def __init__(self, y, x=None, models=None, normalize=True, **kwargs): if not isinstance(self._try_linear_fit, bool): raise ValueError( 'Invalid value of keyword argument try_linear_fit ' - + f'({self._try_linear_fit})') + f'({self._try_linear_fit})') if y is not None: if isinstance(y, (tuple, list, np.ndarray)): self._x = np.asarray(x) @@ -133,7 +134,7 @@ def __init__(self, y, x=None, models=None, normalize=True, **kwargs): if y.ndim != 1: raise ValueError( 'Invalid DataArray dimensions for parameter y ' - + f'({y.ndim})') + f'({y.ndim})') self._x = np.asarray(y[y.dims[0]]) self._y = y else: @@ -144,7 +145,7 @@ def __init__(self, y, x=None, models=None, normalize=True, **kwargs): if self._x.size != self._y.size: raise ValueError( f'Inconsistent x and y dimensions ({self._x.size} vs ' - + f'{self._y.size})') + f'{self._y.size})') if 'mask' in kwargs: self._mask = kwargs.pop('mask') if self._mask is None: @@ -157,7 +158,7 @@ def __init__(self, y, x=None, models=None, normalize=True, **kwargs): if self._x.size != self._mask.size: raise ValueError( f'Inconsistent x and mask dimensions ({self._x.size} ' - + f'vs {self._mask.size})') + f'vs {self._mask.size})') y_masked = np.asarray(self._y)[~self._mask] y_min = float(y_masked.min()) self._y_range = float(y_masked.max())-y_min @@ -182,10 +183,9 @@ def best_errors(self): """Return errors in the best fit parameters.""" if self._result is None: return None - return( - {name:self._result.params[name].stderr - for name in sorted(self._result.params) - if name != 'tmp_normalization_offset_c'}) + return {name:self._result.params[name].stderr + for name in sorted(self._result.params) + if name != 'tmp_normalization_offset_c'} @property def best_fit(self): @@ -204,11 +204,11 @@ def best_parameters(self): par = self._result.params[name] parameters[name] = { 'value': par.value, - 'error': par.stderr, - 'init_value': par.init_value, - 'min': par.min, - 'max': par.max, - 'vary': par.vary, 'expr': par.expr + 'error': par.stderr, + 'init_value': par.init_value, + 'min': par.min, + 'max': par.max, + 'vary': par.vary, 'expr': par.expr } return parameters @@ -250,10 +250,9 @@ def best_values(self): """Return values of the best fit 
parameters.""" if self._result is None: return None - return( - {name:self._result.params[name].value - for name in sorted(self._result.params) - if name != 'tmp_normalization_offset_c'}) + return {name:self._result.params[name].value + for name in sorted(self._result.params) + if name != 'tmp_normalization_offset_c'} @property def chisqr(self): @@ -336,10 +335,9 @@ def init_values(self): """Return the initial values for the fit parameters.""" if self._result is None or self._result.init_params is None: return None - return( - {name:self._result.init_params[name].value - for name in sorted(self._result.init_params) - if name != 'tmp_normalization_offset_c'}) + return {name:self._result.init_params[name].value + for name in sorted(self._result.init_params) + if name != 'tmp_normalization_offset_c'} @property def normalization_offset(self): @@ -368,10 +366,9 @@ def num_func_eval(self): @property def parameters(self): """Return the fit parameter info.""" - return( - {name:{'min': par.min, 'max': par.max, 'vary': par.vary, - 'expr': par.expr} for name, par in self._parameters.items() - if name != 'tmp_normalization_offset_c'}) + return {name:{'min': par.min, 'max': par.max, 'vary': par.vary, + 'expr': par.expr} for name, par in self._parameters.items() + if name != 'tmp_normalization_offset_c'} @property def redchi(self): @@ -442,13 +439,13 @@ def add_parameter(self, **parameter): if self._norm is None: logger.warning( f'Ignoring norm in parameter {name} in Fit.add_parameter ' - + '(normalization is turned off)') + '(normalization is turned off)') self._parameter_norms[name] = False else: if not isinstance(norm, bool): raise ValueError( f'Invalid "norm" value ({norm}) in parameter ' - + f'{parameter}') + f'{parameter}') self._parameter_norms[name] = norm vary = parameter.get('vary') if vary is not None: @@ -459,12 +456,12 @@ def add_parameter(self, **parameter): if 'min' in parameter: logger.warning( f'Ignoring min in parameter {name} in ' - + f'Fit.add_parameter (vary = {vary})') + f'Fit.add_parameter (vary = {vary})') parameter.pop('min') if 'max' in parameter: logger.warning( f'Ignoring max in parameter {name} in ' - + f'Fit.add_parameter (vary = {vary})') + f'Fit.add_parameter (vary = {vary})') parameter.pop('max') if self._norm is not None and name not in self._parameter_norms: raise ValueError( @@ -508,19 +505,19 @@ def add_model( if parameters is None: raise ValueError( 'Either parameters or parameter_norms is required in ' - + f'{model}') + f'{model}') for par in parameters: name = par['name'] if not isinstance(name, str): raise ValueError( f'Invalid "name" value ({name}) in input ' - + 'parameters') + 'parameters') if par.get('norm') is not None: norm = par.pop('norm') if not isinstance(norm, bool): raise ValueError( f'Invalid "norm" value ({norm}) in input ' - + 'parameters') + 'parameters') new_parameter_norms[f'{pprefix}{name}'] = norm else: for par in parameter_norms: @@ -528,24 +525,24 @@ def add_model( if not isinstance(name, str): raise ValueError( f'Invalid "name" value ({name}) in input ' - + 'parameters') + 'parameters') norm = par.get('norm') if norm is None or not isinstance(norm, bool): raise ValueError( f'Invalid "norm" value ({norm}) in input ' - + 'parameters') + 'parameters') new_parameter_norms[f'{pprefix}{name}'] = norm if parameters is not None: for par in parameters: if par.get('expr') is not None: raise KeyError( f'Invalid "expr" key ({par.get("expr")}) in ' - + f'parameter {name} for a callable model {model}') + f'parameter {name} for a callable model {model}') 
name = par['name'] if not isinstance(name, str): raise ValueError( f'Invalid "name" value ({name}) in input ' - + 'parameters') + 'parameters') # RV callable model will need partial deriv functions for any linear # parameter to get the linearized matrix, so for now skip linear # solution option @@ -580,7 +577,7 @@ def add_model( if degree is None or not is_int(degree, ge=0, le=7): raise ValueError( 'Invalid parameter degree for build-in step model ' - + f'({degree})') + f'({degree})') newmodel = PolynomialModel(degree=degree, prefix=prefix) for i in range(degree+1): new_parameter_norms[f'{pprefix}c{i}'] = True @@ -627,7 +624,7 @@ def add_model( ('linear', 'atan', 'arctan', 'erf', 'logistic')): raise ValueError( 'Invalid parameter form for build-in step model ' - + f'({form})') + f'({form})') newmodel = StepModel(prefix=prefix, form=form) new_parameter_norms[f'{pprefix}amplitude'] = True new_parameter_norms[f'{pprefix}center'] = False @@ -644,7 +641,7 @@ def add_model( ('linear', 'atan', 'arctan', 'erf', 'logistic')): raise ValueError( 'Invalid parameter form for build-in rectangle model ' - + f'({form})') + f'({form})') newmodel = RectangleModel(prefix=prefix, form=form) new_parameter_norms[f'{pprefix}amplitude'] = True new_parameter_norms[f'{pprefix}center1'] = False @@ -666,28 +663,28 @@ def add_model( if parameter_norms is not None: logger.warning( 'Ignoring parameter_norms (normalization ' - + 'determined from linearity)}') + 'determined from linearity)}') if parameters is not None: for par in parameters: if par.get('expr') is not None: raise KeyError( f'Invalid "expr" key ({par.get("expr")}) in ' - + f'parameter ({par}) for an expression model') + f'parameter ({par}) for an expression model') if par.get('norm') is not None: logger.warning( f'Ignoring "norm" key in parameter ({par}) ' - + '(normalization determined from linearity)') + '(normalization determined from linearity)') par.pop('norm') name = par['name'] if not isinstance(name, str): raise ValueError( f'Invalid "name" value ({name}) in input ' - + 'parameters') + 'parameters') ast = Interpreter() expr_parameters = [ name for name in get_ast_names(ast.parse(expr)) - if name != 'x' and name not in self._parameters - and name not in ast.symtable] + if (name != 'x' and name not in self._parameters + and name not in ast.symtable)] if prefix is None: newmodel = ExpressionModel(expr=expr) else: @@ -765,45 +762,45 @@ def add_model( if not isinstance(parameter['norm'], bool): raise ValueError( f'Invalid "norm" value ({norm}) in the ' - + f'input parameter {name}') + f'input parameter {name}') new_parameter_norms[name] = parameter['norm'] parameter.pop('norm') if parameter.get('expr') is not None: if 'value' in parameter: logger.warning( f'Ignoring value in parameter {name} ' - + f'(set by expression: {parameter["expr"]})') + f'(set by expression: {parameter["expr"]})') parameter.pop('value') if 'vary' in parameter: logger.warning( f'Ignoring vary in parameter {name} ' - + f'(set by expression: {parameter["expr"]})') + f'(set by expression: {parameter["expr"]})') parameter.pop('vary') if 'min' in parameter: logger.warning( f'Ignoring min in parameter {name} ' - + f'(set by expression: {parameter["expr"]})') + f'(set by expression: {parameter["expr"]})') parameter.pop('min') if 'max' in parameter: logger.warning( f'Ignoring max in parameter {name} ' - + f'(set by expression: {parameter["expr"]})') + f'(set by expression: {parameter["expr"]})') parameter.pop('max') if 'vary' in parameter: if not isinstance(parameter['vary'], bool): raise 
ValueError( f'Invalid "vary" value ({parameter["vary"]}) ' - + f'in the input parameter {name}') + f'in the input parameter {name}') if not parameter['vary']: if 'min' in parameter: logger.warning( f'Ignoring min in parameter {name} ' - + f'(vary = {parameter["vary"]})') + f'(vary = {parameter["vary"]})') parameter.pop('min') if 'max' in parameter: logger.warning( f'Ignoring max in parameter {name} ' - + f'(vary = {parameter["vary"]})') + f'(vary = {parameter["vary"]})') parameter.pop('max') self._parameters[name].set(**parameter) parameter['name'] = name @@ -826,12 +823,12 @@ def add_model( else: logger.warning( f'Ignoring parameter {name} (set by expression: ' - + f'{self._parameters[full_name].expr})') + f'{self._parameters[full_name].expr})') # Check parameter norms # (also need it for expressions to renormalize the errors) - if (self._norm is not None and - (callable(model) or model == 'expression')): + if (self._norm is not None + and (callable(model) or model == 'expression')): missing_norm = False for name in new_parameters.valuesdict(): if name not in self._parameter_norms: @@ -839,7 +836,7 @@ def add_model( print(f'self._parameter_norms:\n{self._parameter_norms}') logger.error( f'Missing parameter normalization type for {name} in ' - + f'{model}') + f'{model}') missing_norm = True if missing_norm: raise ValueError @@ -865,7 +862,7 @@ def fit(self, **kwargs): if not isinstance(interactive, bool): raise ValueError( 'Invalid value of keyword argument interactive ' - + f'({interactive})') + f'({interactive})') else: interactive = False if 'guess' in kwargs: @@ -880,11 +877,11 @@ def fit(self, **kwargs): if not isinstance(try_linear_fit, bool): raise ValueError( 'Invalid value of keyword argument try_linear_fit ' - + f'({try_linear_fit})') + f'({try_linear_fit})') if not self._try_linear_fit: logger.warning( 'Ignore superfluous keyword argument "try_linear_fit" ' - + '(not yet supported for callable models)') + '(not yet supported for callable models)') else: self._try_linear_fit = try_linear_fit if self._result is not None: @@ -906,7 +903,7 @@ def fit(self, **kwargs): if self._x.size != self._mask.size: raise ValueError( f'Inconsistent x and mask dimensions ({self._x.size} vs ' - + f'{self._mask.size})') + f'{self._mask.size})') # Estimate initial parameters with build-in lmfit guess method # (only mplemented for a single model) @@ -926,10 +923,10 @@ def fit(self, **kwargs): 'value': -self._norm[0], 'vary': False, 'norm': True, + }) # 'value': -self._norm[0]/self._norm[1], # 'vary': False, # 'norm': False, - }) # Adjust existing parameters for refit: if 'parameters' in kwargs: @@ -939,17 +936,17 @@ def fit(self, **kwargs): elif not is_dict_series(parameters): raise ValueError( 'Invalid value of keyword argument parameters ' - + f'({parameters})') + f'({parameters})') for par in parameters: name = par['name'] if name not in self._parameters: raise ValueError( f'Unable to match {name} parameter {par} to an ' - + 'existing one') + 'existing one') if self._parameters[name].expr is not None: raise ValueError( f'Unable to modify {name} parameter {par} ' - + '(currently an expression)') + '(currently an expression)') if par.get('expr') is not None: raise KeyError( f'Invalid "expr" key in {name} parameter {par}') @@ -966,7 +963,7 @@ def fit(self, **kwargs): else: logger.warning( f'Ignoring parameter {name} (set by expression: ' - + f'{self._parameters[name].expr})') + f'{self._parameters[name].expr})') # Check for uninitialized parameters for name, par in self._parameters.items(): @@ -1000,8 
+997,9 @@ def fit(self, **kwargs): if self._mask is None: self._fit_linear_model(self._x, self._y_norm) else: - self._fit_linear_model(self._x[~self._mask], - np.asarray(self._y_norm)[~self._mask]) + self._fit_linear_model( + self._x[~self._mask], + np.asarray(self._y_norm)[~self._mask]) except: linear_model = False if not linear_model: @@ -1120,23 +1118,23 @@ def guess_init_peak( if len(x) != len(y): logger.error( f'Invalid x and y lengths ({len(x)}, {len(y)}), ' - + 'skip initial guess') + 'skip initial guess') return None, None, None if isinstance(center_guess, (int, float)): if args: logger.warning( 'Ignoring additional arguments for single center_guess ' - + 'value') + 'value') center_guesses = [center_guess] elif isinstance(center_guess, (tuple, list, np.ndarray)): if len(center_guess) == 1: logger.warning( 'Ignoring additional arguments for single center_guess ' - + 'value') + 'value') if not isinstance(center_guess[0], (int, float)): raise ValueError( 'Invalid parameter center_guess ' - + f'({type(center_guess[0])})') + f'({type(center_guess[0])})') center_guess = center_guess[0] else: if len(args) != 1: @@ -1328,7 +1326,7 @@ def _fit_linear_model(self, x, y): if not self._parameter_norms[name]: dexpr_dname = f'({dexpr_dname})/{norm}' y_expr = [(lambda _: ast.eval(str(dexpr_dname))) - (ast(f'x={v}')) for v in x] + (ast(f'x={v}')) for v in x] if ast.error: raise ValueError( f'Unable to evaluate {dexpr_dname}') @@ -1337,7 +1335,7 @@ def _fit_linear_model(self, x, y): # simplify const_expr = str(simplify(f'({const_expr})/{norm}')) delta_y_const = [(lambda _: ast.eval(const_expr)) - (ast(f'x = {v}')) for v in x] + (ast(f'x = {v}')) for v in x] y_const += delta_y_const if ast.error: raise ValueError(f'Unable to evaluate {const_expr}') @@ -1358,7 +1356,7 @@ def _fit_linear_model(self, x, y): mat_a[:,free_parameters.index(name)] = 1.0 else: y_const += self._parameters[name].value \ - * np.ones(len(x)) + * np.ones(len(x)) elif isinstance(component, QuadraticModel): name = f'{component.prefix}a' if name in free_model_parameters: @@ -1399,7 +1397,7 @@ def _fit_linear_model(self, x, y): parameters[name].set(value=1.0) dcomp_dname = component.eval(params=parameters, x=x) for nname in free_parameters: - dexpr_dnname = diff(expr, nname) + dexpr_dnname = diff(expr, nname) if dexpr_dnname: assert self._parameter_norms[name] y_expr = np.asarray( @@ -1425,16 +1423,16 @@ def _fit_linear_model(self, x, y): # (compensate for normalization in expression models) for name, value in zip(free_parameters, solution): self._parameters[name].set(value=value) - if (self._normalized and (have_expression_model - or expr_parameters)): + if (self._normalized + and (have_expression_model or expr_parameters)): for name, norm in self._parameter_norms.items(): par = self._parameters[name] if par.expr is None and norm: self._parameters[name].set(value=par.value*self._norm[1]) self._result = ModelResult(self._model, deepcopy(self._parameters)) self._result.best_fit = self._model.eval(params=self._parameters, x=x) - if (self._normalized and (have_expression_model - or expr_parameters)): + if (self._normalized + and (have_expression_model or expr_parameters)): if 'tmp_normalization_offset_c' in self._parameters: offset = self._parameters['tmp_normalization_offset_c'] else: @@ -1495,8 +1493,8 @@ def _renormalize(self): par.set(value=value, min=_min, max=_max) if self._result is None: return - self._result.best_fit = (self._result.best_fit*self._norm[1] - + self._norm[0]) + self._result.best_fit = ( + 
self._result.best_fit*self._norm[1] + self._norm[0]) for name, par in self._result.params.items(): if self._parameter_norms.get(name, False): if par.stderr is not None: @@ -1513,8 +1511,8 @@ def _renormalize(self): _max *= self._norm[1] par.set(value=value, min=_min, max=_max) if hasattr(self._result, 'init_fit'): - self._result.init_fit = (self._result.init_fit*self._norm[1] - + self._norm[0]) + self._result.init_fit = ( + self._result.init_fit*self._norm[1] + self._norm[0]) if hasattr(self._result, 'init_values'): init_values = {} for name, value in self._result.init_values.items(): @@ -1604,7 +1602,7 @@ def fit_multipeak( center_exprs=None, fit_type=None, background=None, fwhm_max=None, print_report=False, plot=False, x_eval=None): """Class method for FitMultipeak. - + Make sure that centers and fwhm_max are in the correct units and consistent with expr for a uniform fit (fit_type == 'uniform'). @@ -1658,7 +1656,7 @@ def fit( if param_constraint: logger.warning( ' -> Should not happen with param_constraint set, ' - + 'fail the fit') + 'fail the fit') success = False else: logger.info(' -> Retry fitting with constraints') @@ -1698,7 +1696,7 @@ def _create_model( if len(peak_models) != num_peaks: raise ValueError( 'Inconsistent number of peaks in peak_models ' - + f'({len(peak_models)} vs {num_peaks})') + f'({len(peak_models)} vs {num_peaks})') if num_peaks == 1: if fit_type is not None: logger.debug('Ignoring fit_type input for fitting one peak') @@ -1714,7 +1712,7 @@ def _create_model( if len(center_exprs) != num_peaks: raise ValueError( 'Inconsistent number of peaks in center_exprs ' - + f'({len(center_exprs)} vs {num_peaks})') + f'({len(center_exprs)} vs {num_peaks})') elif fit_type == 'unconstrained' or fit_type is None: if center_exprs is not None: logger.warning( @@ -1747,9 +1745,9 @@ def _create_model( if 'model' not in model: raise KeyError( 'Missing keyword "model" in model in background ' - + f'({model})') + f'({model})') name = model.pop('model') - parameters=model.pop('parameters', None) + parameters = model.pop('parameters', None) self.add_model( name, prefix=f'bkgd_{name}_', parameters=parameters, **model) @@ -1779,7 +1777,7 @@ def _create_model( {'name': 'amplitude', 'value': amp_init, 'min': min_value}, {'name': 'center', 'value': cen_init, 'min': min_value}, {'name': 'sigma', 'value': sig_init, 'min': min_value, - 'max': sig_max}, + 'max': sig_max}, )) else: if fit_type == 'uniform': @@ -1803,10 +1801,10 @@ def _create_model( peak_models[i], prefix=f'peak{i+1}_', parameters=( {'name': 'amplitude', 'value': amp_init, - 'min': min_value}, + 'min': min_value}, {'name': 'center', 'expr': center_exprs[i]}, {'name': 'sigma', 'value': sig_init, - 'min': min_value, 'max': sig_max}, + 'min': min_value, 'max': sig_max}, )) else: self.add_model( @@ -1814,11 +1812,11 @@ def _create_model( prefix=f'peak{i+1}_', parameters=( {'name': 'amplitude', 'value': amp_init, - 'min': min_value}, + 'min': min_value}, {'name': 'center', 'value': cen_init, - 'min': min_value}, + 'min': min_value}, {'name': 'sigma', 'value': sig_init, - 'min': min_value, 'max': sig_max}, + 'min': min_value, 'max': sig_max}, )) def _check_validity(self): @@ -1835,7 +1833,7 @@ def _check_validity(self): elif (((name.endswith('amplitude') or name.endswith('height')) and par['value'] <= 0.0) or ((name.endswith('sigma') or name.endswith('fwhm')) - and par['value'] <= 0.0) + and par['value'] <= 0.0) or (name.endswith('center') and par['value'] <= 0.0) or (name == 'scale_factor' and par['value'] <= 0.0)): 
logger.info(f'Invalid fit result for {name} ({par["value"]})') @@ -1912,22 +1910,22 @@ def __init__( if self._ymap.ndim < 2: raise ValueError( 'Invalid number of dimension of the input dataset ' - + f'{self._ymap.ndim}') + f'{self._ymap.ndim}') if self._x.size != self._ymap.shape[-1]: raise ValueError( f'Inconsistent x and y dimensions ({self._x.size} vs ' - + f'{self._ymap.shape[-1]})') + f'{self._ymap.shape[-1]})') if not isinstance(normalize, bool): logger.warning( f'Invalid value for normalize ({normalize}) in Fit.__init__: ' - + 'setting normalize to True') + 'setting normalize to True') normalize = True if isinstance(transpose, bool) and not transpose: transpose = None if transpose is not None and self._ymap.ndim < 3: logger.warning( f'Transpose meaningless for {self._ymap.ndim-1}D data maps: ' - + 'ignoring transpose') + 'ignoring transpose') if transpose is not None: if (self._ymap.ndim == 3 and isinstance(transpose, bool) and transpose): @@ -1935,20 +1933,21 @@ def __init__( elif not isinstance(transpose, (tuple, list)): logger.warning( f'Invalid data type for transpose ({transpose}, ' - + f'{type(transpose)}): setting transpose to False') + f'{type(transpose)}): setting transpose to False') elif transpose != self._ymap.ndim-1: logger.warning( f'Invalid dimension for transpose ({transpose}, must be ' - + f'equal to {self._ymap.ndim-1}): ' - + 'setting transpose to False') + f'equal to {self._ymap.ndim-1}): ' + 'setting transpose to False') elif any(i not in transpose for i in range(len(transpose))): logger.warning( f'Invalid index in transpose ({transpose}): ' - + 'setting transpose to False') - elif not all(i==transpose[i] for i in range(self._ymap.ndim-1)): + 'setting transpose to False') + elif not all(i == transpose[i] for i in range(self._ymap.ndim-1)): self._transpose = transpose if self._transpose is not None: - self._inv_transpose = tuple(self._transpose.index(i) + self._inv_transpose = tuple( + self._transpose.index(i) for i in range(len(self._transpose))) # Flatten the map (transpose if requested) @@ -1976,7 +1975,7 @@ def __init__( if self._x.size != self._mask.size: raise ValueError( f'Inconsistent mask dimension ({self._x.size} vs ' - + f'{self._mask.size})') + f'{self._mask.size})') ymap_masked = np.asarray(self._ymap_norm)[:,~self._mask] ymap_min = float(ymap_masked.min()) ymap_max = float(ymap_masked.max()) @@ -2301,39 +2300,39 @@ def fit(self, **kwargs): if num_proc > cpu_count(): logger.warning( f'The requested number of processors ({num_proc}) exceeds the ' - + 'maximum number of processors, num_proc reduced to ' - + f'({cpu_count()})') + 'maximum number of processors, num_proc reduced to ' + f'({cpu_count()})') num_proc = cpu_count() if 'try_no_bounds' in kwargs: self._try_no_bounds = kwargs.pop('try_no_bounds') if not isinstance(self._try_no_bounds, bool): raise ValueError( 'Invalid value for keyword argument try_no_bounds ' - + f'({self._try_no_bounds})') + f'({self._try_no_bounds})') if 'redchi_cutoff' in kwargs: self._redchi_cutoff = kwargs.pop('redchi_cutoff') if not is_num(self._redchi_cutoff, gt=0): raise ValueError( 'Invalid value for keyword argument redchi_cutoff' - + f'({self._redchi_cutoff})') + f'({self._redchi_cutoff})') if 'print_report' in kwargs: self._print_report = kwargs.pop('print_report') if not isinstance(self._print_report, bool): raise ValueError( 'Invalid value for keyword argument print_report' - + f'({self._print_report})') + f'({self._print_report})') if 'plot' in kwargs: self._plot = kwargs.pop('plot') if not isinstance(self._plot, 
bool): raise ValueError( 'Invalid value for keyword argument plot' - + f'({self._plot})') + f'({self._plot})') if 'skip_init' in kwargs: self._skip_init = kwargs.pop('skip_init') if not isinstance(self._skip_init, bool): raise ValueError( 'Invalid value for keyword argument skip_init' - + f'({self._skip_init})') + f'({self._skip_init})') # Apply mask if supplied: if 'mask' in kwargs: @@ -2343,7 +2342,7 @@ def fit(self, **kwargs): if self._x.size != self._mask.size: raise ValueError( f'Inconsistent x and mask dimensions ({self._x.size} vs ' - + f'{self._mask.size})') + f'{self._mask.size})') # Add constant offset for a normalized single component model if self._result is None and self._norm is not None and self._norm[0]: @@ -2355,10 +2354,10 @@ def fit(self, **kwargs): 'value': -self._norm[0], 'vary': False, 'norm': True, + }) # 'value': -self._norm[0]/self._norm[1], # 'vary': False, # 'norm': False, - }) # Adjust existing parameters for refit: if 'parameters' in kwargs: @@ -2368,19 +2367,19 @@ def fit(self, **kwargs): elif not is_dict_series(parameters): raise ValueError( 'Invalid value for keyword argument parameters' - + f'({parameters})') + f'({parameters})') for par in parameters: name = par['name'] if name not in self._parameters: raise ValueError( f'Unable to match {name} parameter {par} to an ' - + 'existing one') + 'existing one') if self._parameters[name].expr is not None: raise ValueError( f'Unable to modify {name} parameter {par} ' - + '(currently an expression)') - value = par.get('value') - vary = par.get('vary') + '(currently an expression)') + value = par.get('value') + vary = par.get('vary') if par.get('expr') is not None: raise KeyError( f'Invalid "expr" key in {name} parameter {par}') @@ -2439,8 +2438,8 @@ def fit(self, **kwargs): assert self._best_values.shape[0] == num_best_parameters assert self._best_values.shape[1:] == self._map_shape if self._transpose is not None: - self._best_values = np.transpose(self._best_values, - [0]+[i+1 for i in self._transpose]) + self._best_values = np.transpose( + self._best_values, [0]+[i+1 for i in self._transpose]) self._best_values = [ np.reshape(self._best_values[i], self._map_dim) for i in range(num_best_parameters)] @@ -2454,8 +2453,9 @@ def fit(self, **kwargs): self._normalize() # Prevent initial values from sitting at boundaries - self._parameter_bounds = {name:{'min': par.min, 'max': par.max} - for name, par in self._parameters.items() if par.vary} + self._parameter_bounds = { + name:{'min': par.min, 'max': par.max} + for name, par in self._parameters.items() if par.vary} for name, par in self._parameters.items(): if par.vary: par.set(value=self._reset_par_at_boundary(par, par.value)) @@ -2476,8 +2476,8 @@ def fit(self, **kwargs): self._max_nfev_flat = np.zeros(self._map_dim, dtype=bool) self._redchi_flat = np.zeros(self._map_dim, dtype=np.float64) self._success_flat = np.zeros(self._map_dim, dtype=bool) - self._best_fit_flat = np.zeros((self._map_dim, x_size), - dtype=self._ymap_norm.dtype) + self._best_fit_flat = np.zeros( + (self._map_dim, x_size), dtype=self._ymap_norm.dtype) self._best_errors_flat = [ np.zeros(self._map_dim, dtype=np.float64) for _ in range(num_best_parameters+num_new_parameters)] @@ -2517,17 +2517,17 @@ def fit(self, **kwargs): self._best_errors_flat = [] for i in range(num_best_parameters+num_new_parameters): filename_memmap = path.join( - self._memfolder, f'best_errors_memmap_{i}') + self._memfolder, f'best_errors_memmap_{i}') self._best_errors_flat.append( np.memmap(filename_memmap, dtype=np.float64, - 
shape=self._map_dim, mode='w+')) + shape=self._map_dim, mode='w+')) self._best_values_flat = [] for i in range(num_best_parameters): filename_memmap = path.join( self._memfolder, f'best_values_memmap_{i}') self._best_values_flat.append( np.memmap(filename_memmap, dtype=np.float64, - shape=self._map_dim, mode='w+')) + shape=self._map_dim, mode='w+')) if self._result is not None: self._best_values_flat[i][:] = self._best_values[i][:] for i in range(num_new_parameters): @@ -2536,7 +2536,7 @@ def fit(self, **kwargs): f'best_values_memmap_{i+num_best_parameters}') self._best_values_flat.append( np.memmap(filename_memmap, dtype=np.float64, - shape=self._map_dim, mode='w+')) + shape=self._map_dim, mode='w+')) # Update the best parameter list if num_new_parameters: @@ -2562,7 +2562,6 @@ def fit(self, **kwargs): try: delattr(self._result, attr) except AttributeError: -# logger.warning(f'Unknown attribute {attr}') pass if num_proc == 1: @@ -2575,20 +2574,19 @@ def fit(self, **kwargs): if num_proc > num_fit: logger.warning( f'The requested number of processors ({num_proc}) exceeds ' - + f'the number of fits, num_proc reduced to ({num_fit})') + f'the number of fits, num_proc reduced to ({num_fit})') num_proc = num_fit num_fit_per_proc = 1 else: num_fit_per_proc = round((num_fit)/num_proc) if num_proc*num_fit_per_proc < num_fit: - num_fit_per_proc +=1 + num_fit_per_proc += 1 num_fit_batch = min(num_fit_per_proc, 40) with Parallel(n_jobs=num_proc) as parallel: parallel( delayed(self._fit_parallel) - (current_best_values, num_fit_batch, - n_start, **kwargs) - for n_start in range(1, self._map_dim, num_fit_batch)) + (current_best_values, num_fit_batch, n_start, **kwargs) + for n_start in range(1, self._map_dim, num_fit_batch)) # Renormalize the initial parameters for external use if self._norm is not None and self._normalized: @@ -2770,7 +2768,7 @@ def _fit(self, n, current_best_values, return_result=False, **kwargs): if self._plot: dims = np.unravel_index(n, self._map_shape) if self._inv_transpose is not None: - dims= tuple( + dims = tuple( dims[self._inv_transpose[i]] for i in range(len(dims))) super().plot( result=result, y=np.asarray(self._ymap[dims]), @@ -2807,8 +2805,8 @@ def _renormalize(self, n, result): if (not np.isinf(par.max) and abs(par.max) != FLOAT_MIN): par.max *= self._norm[1] - self._best_fit_flat[n] = (result.best_fit*self._norm[1] - + self._norm[0]) + self._best_fit_flat[n] = ( + result.best_fit*self._norm[1] + self._norm[0]) for i, name in enumerate(self._best_parameters): self._best_values_flat[i][n] = np.float64( result.params[name].value) @@ -2816,6 +2814,6 @@ def _renormalize(self, n, result): result.params[name].stderr) if self._plot: if not self._skip_init: - result.init_fit = (result.init_fit*self._norm[1] - + self._norm[0]) + result.init_fit = ( + result.init_fit*self._norm[1] + self._norm[0]) result.best_fit = np.copy(self._best_fit_flat[n]) diff --git a/CHAP/common/utils/general.py b/CHAP/common/utils/general.py index 6033254..a5ab9e2 100755 --- a/CHAP/common/utils/general.py +++ b/CHAP/common/utils/general.py @@ -6,8 +6,8 @@ Author : Rolf Verberg Description: A collection of general modules """ -#RV write function that returns a list of peak indices for a given plot -#RV use raise_error concept on more functions +# RV write function that returns a list of peak indices for a given plot +# RV use raise_error concept on more functions # System modules from ast import literal_eval @@ -32,18 +32,24 @@ logger = getLogger(__name__) + def depth_list(_list): """Return the depth of a 
list.""" return isinstance(_list, list) and 1+max(map(depth_list, _list)) + + def depth_tuple(_tuple): """Return the depth of a tuple.""" return isinstance(_tuple, tuple) and 1+max(map(depth_tuple, _tuple)) + + def unwrap_tuple(_tuple): """Unwrap a tuple.""" if depth_tuple(_tuple) > 1 and len(_tuple) == 1: _tuple = unwrap_tuple(*_tuple) return _tuple + def illegal_value(value, name, location=None, raise_error=False, log=True): """Print illegal value message and/or raise error.""" if not isinstance(location, str): @@ -60,6 +66,7 @@ def illegal_value(value, name, location=None, raise_error=False, log=True): if raise_error: raise ValueError(error_msg) + def illegal_combination( value1, name1, value2, name2, location=None, raise_error=False, log=True): @@ -70,15 +77,16 @@ def illegal_combination( location = f'in {location} ' if isinstance(name1, str): error_msg = f'Illegal combination for {name1} and {name2} {location}' \ - + f'({value1}, {type(value1)} and {value2}, {type(value2)})' + f'({value1}, {type(value1)} and {value2}, {type(value2)})' else: error_msg = f'Illegal combination {location}' \ - + f'({value1}, {type(value1)} and {value2}, {type(value2)})' + f'({value1}, {type(value1)} and {value2}, {type(value2)})' if log: logger.error(error_msg) if raise_error: raise ValueError(error_msg) + def test_ge_gt_le_lt( ge, gt, le, lt, func, location=None, raise_error=False, log=True): """ @@ -127,6 +135,7 @@ def test_ge_gt_le_lt( return False return True + def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None): """ Return a range string representation matching the ge, gt, le, lt @@ -156,6 +165,7 @@ def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None): range_string += f'{lt})' return range_string + def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): """ Value is an integer in range ge <= v <= le or gt < v < lt or some @@ -166,6 +176,7 @@ def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): """ return _is_int_or_num(v, 'int', ge, gt, le, lt, raise_error, log) + def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): """ Value is a number in range ge <= v <= le or gt < v < lt or some @@ -176,6 +187,7 @@ def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): """ return _is_int_or_num(v, 'num', ge, gt, le, lt, raise_error, log) + def _is_int_or_num( v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): @@ -219,6 +231,7 @@ def _is_int_or_num( return False return True + def is_int_pair( v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): """ @@ -231,6 +244,7 @@ def is_int_pair( """ return _is_int_or_num_pair(v, 'int', ge, gt, le, lt, raise_error, log) + def is_num_pair( v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): """ @@ -243,6 +257,7 @@ def is_num_pair( """ return _is_int_or_num_pair(v, 'num', ge, gt, le, lt, raise_error, log) + def _is_int_or_num_pair( v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): @@ -285,11 +300,12 @@ def _is_int_or_num_pair( elif not _is_int_or_num_pair( lt, type_str, raise_error=raise_error, log=log): return False - if (not func(v[0], ge[0], gt[0], le[0], lt[0], raise_error, log) or - not func(v[1], ge[1], gt[1], le[1], lt[1], raise_error, log)): + if (not func(v[0], ge[0], gt[0], le[0], lt[0], raise_error, log) + or not func(v[1], ge[1], gt[1], le[1], lt[1], raise_error, log)): return False return True + def is_int_series( t_or_l, ge=None, gt=None, le=None, 
lt=None, raise_error=False, log=True): @@ -307,6 +323,7 @@ def is_int_series( return False return True + def is_num_series( t_or_l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): @@ -324,6 +341,7 @@ def is_num_series( return False return True + def is_str_series(t_or_l, raise_error=False, log=True): """ Value is a tuple or list of strings. @@ -334,6 +352,7 @@ def is_str_series(t_or_l, raise_error=False, log=True): return False return True + def is_dict_series(t_or_l, raise_error=False, log=True): """ Value is a tuple or list of dictionaries. @@ -344,6 +363,7 @@ def is_dict_series(t_or_l, raise_error=False, log=True): return False return True + def is_dict_nums(d, raise_error=False, log=True): """ Value is a dictionary with single number values @@ -354,6 +374,7 @@ def is_dict_nums(d, raise_error=False, log=True): return False return True + def is_dict_strings(d, raise_error=False, log=True): """ Value is a dictionary with single string values @@ -364,6 +385,7 @@ def is_dict_strings(d, raise_error=False, log=True): return False return True + def is_index(v, ge=0, lt=None, raise_error=False, log=True): """ Value is an array index in range ge <= v < lt. NOTE lt IS NOT @@ -376,6 +398,7 @@ def is_index(v, ge=0, lt=None, raise_error=False, log=True): return False return is_int(v, ge=ge, lt=lt, raise_error=raise_error, log=log) + def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True): """ Value is an array index range in range ge <= v[0] <= v[1] <= le or @@ -401,6 +424,7 @@ def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True): return False return True + def index_nearest(a, value): """Return index of nearest array value.""" a = np.asarray(a) @@ -411,6 +435,7 @@ def index_nearest(a, value): value *= 1.0+float_info.epsilon return (int)(np.argmin(np.abs(a-value))) + def index_nearest_low(a, value): """Return index of nearest array value, rounded down""" a = np.asarray(a) @@ -422,6 +447,7 @@ def index_nearest_low(a, value): index -= 1 return index + def index_nearest_upp(a, value): """Return index of nearest array value, rounded upp.""" a = np.asarray(a) @@ -433,12 +459,14 @@ def index_nearest_upp(a, value): index += 1 return index + def round_to_n(x, n=1): """Round to a specific number of decimals.""" if x == 0.0: return 0 return type(x)(round(x, n-1-int(np.floor(np.log10(abs(x)))))) + def round_up_to_n(x, n=1): """Round up to a specific number of decimals.""" x_round = round_to_n(x, n) @@ -446,6 +474,7 @@ def round_up_to_n(x, n=1): x_round += np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n) return type(x)(x_round) + def trunc_to_n(x, n=1): """Truncate to a specific number of decimals.""" x_round = round_to_n(x, n) @@ -453,6 +482,7 @@ def trunc_to_n(x, n=1): x_round -= np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n) return type(x)(x_round) + def almost_equal(a, b, sig_figs): """ Check if equal to within a certain number of significant digits. 
@@ -461,7 +491,8 @@ def almost_equal(a, b, sig_figs): return abs(round_to_n(a-b, sig_figs)) < pow(10, 1-sig_figs) raise ValueError( f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, ' - + f'b: {b}, {type(b)})') + f'b: {b}, {type(b)})') + def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True): """ @@ -483,7 +514,8 @@ def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True): try: l_of_i = [] for v in list1: - list2 = [literal_eval(x) + list2 = [ + literal_eval(x) for x in re_split(r'\s+-\s+|\s+-|-\s+|\s+|-', v)] if len(list2) == 1: l_of_i += list2 @@ -502,6 +534,7 @@ def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True): l_of_i = sorted(l_of_i) return l_of_i + def get_trailing_int(string): """Get the trailing integer in a string.""" index_regex = re_compile(r'\d+$') @@ -510,6 +543,7 @@ def get_trailing_int(string): return None return int(match.group()) + def input_int( s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None, raise_error=False, log=True): @@ -517,6 +551,7 @@ def input_int( return _input_int_or_num( 'int', s, ge, gt, le, lt, default, inset, raise_error, log) + def input_num( s=None, ge=None, gt=None, le=None, lt=None, default=None, raise_error=False, log=True): @@ -524,6 +559,7 @@ def input_num( return _input_int_or_num( 'num', s, ge, gt, le, lt, default, None, raise_error,log) + def _input_int_or_num( type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None, raise_error=False, log=True): @@ -599,6 +635,7 @@ def _input_int_or_num( type_str, s, ge, gt, le, lt, default, inset, raise_error, log) return v + def input_int_list( s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True, sort=True, raise_error=False, log=True): @@ -617,6 +654,7 @@ def input_int_list( 'int', s, ge, le, split_on_dash, remove_duplicates, sort, raise_error, log) + def input_num_list( s=None, ge=None, le=None, remove_duplicates=True, sort=True, raise_error=False, log=True): @@ -633,10 +671,11 @@ def input_num_list( return _input_int_or_num_list( 'num', s, ge, le, False, remove_duplicates, sort, raise_error, log) + def _input_int_or_num_list( type_str, s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True, sort=True, raise_error=False, log=True): - #RV do we want a limit on max dimension? + # RV do we want a limit on max dimension? if type_str == 'int': if not test_ge_gt_le_lt( ge, None, le, None, is_int, 'input_int_or_num_list', @@ -668,15 +707,16 @@ def _input_int_or_num_list( not _is_int_or_num(v, type_str, ge=ge, le=le) for v in _list)): if split_on_dash: print('Invalid input: enter a valid set of dash/comma/whitespace ' - + 'separated integers e.g. 1 3,5-8 , 12') + 'separated integers e.g. 1 3,5-8 , 12') else: print('Invalid input: enter a valid set of comma/whitespace ' - + 'separated integers e.g. 1 3,5 8 , 12') + 'separated integers e.g. 
1 3,5 8 , 12') _list = _input_int_or_num_list( type_str, s, ge, le, split_on_dash, remove_duplicates, sort, raise_error, log) return _list + def input_yesno(s=None, default=None): """Interactively prompt the user to enter a y/n question.""" if default is not None: @@ -710,6 +750,7 @@ def input_yesno(s=None, default=None): v = input_yesno(s, default) return v + def input_menu(items, default=None, header=None): """Interactively prompt the user to select from a menu.""" if (not isinstance(items, (tuple, list)) @@ -725,15 +766,14 @@ def input_menu(items, default=None, header=None): else: default_string = '' if header is None: - print( - 'Choose one of the following items ' - + f'(1, {len(items)}){default_string}:') + print('Choose one of the following items ' + f'(1, {len(items)}){default_string}:') else: print(f'{header} (1, {len(items)}){default_string}:') for i, choice in enumerate(items): print(f' {i+1}: {choice}') try: - choice = input() + choice = input() if isinstance(choice, str) and not choice: choice = items.index(default) print(f'{1+choice}') @@ -753,6 +793,7 @@ def input_menu(items, default=None, header=None): choice = input_menu(items, default) return choice + def assert_no_duplicates_in_list_of_dicts(_list, raise_error=False): """ Assert that there are no duplicates in a list of dictionaries. @@ -760,21 +801,22 @@ def assert_no_duplicates_in_list_of_dicts(_list, raise_error=False): if not isinstance(_list, list): illegal_value( _list, '_list', 'assert_no_duplicates_in_list_of_dicts', - raise_error) + raise_error) return None if any(not isinstance(d, dict) for d in _list): illegal_value( _list, '_list', 'assert_no_duplicates_in_list_of_dicts', - raise_error) + raise_error) return None if (len(_list) != len([dict(_tuple) for _tuple in - {tuple(sorted(d.items())) for d in _list}])): + {tuple(sorted(d.items())) for d in _list}])): if raise_error: raise ValueError(f'Duplicate items found in {_list}') logger.error(f'Duplicate items found in {_list}') return None return _list + def assert_no_duplicate_key_in_list_of_dicts(_list, key, raise_error=False): """ Assert that there are no duplicate keys in a list of dictionaries. @@ -803,6 +845,7 @@ def assert_no_duplicate_key_in_list_of_dicts(_list, key, raise_error=False): return None return _list + def assert_no_duplicate_attr_in_list_of_objs(_list, attr, raise_error=False): """ Assert that there are no duplicate attributes in a list of objects. 
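Aside: the duplicate checks reworked in the hunks above all rely on the same idiom, namely converting each dictionary to a hashable tuple of its sorted items and comparing set sizes. A minimal illustration, independent of the package:

dicts = [{'a': 1, 'b': 2}, {'b': 2, 'a': 1}, {'a': 3}]
# Two dicts are duplicates when their sorted (key, value) pairs coincide,
# regardless of key insertion order.
unique = {tuple(sorted(d.items())) for d in dicts}
print(len(unique) < len(dicts))  # True: the first two entries are duplicates
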
@@ -826,6 +869,7 @@ def assert_no_duplicate_attr_in_list_of_objs(_list, attr, raise_error=False): return None return _list + def file_exists_and_readable(f): """Check if a file exists and is readable.""" if not os_path.isfile(f): @@ -834,14 +878,15 @@ def file_exists_and_readable(f): raise ValueError(f'{f} is not accessible for reading') return f + def draw_mask_1d( ydata, xdata=None, current_index_ranges=None, current_mask=None, select_mask=True, num_index_ranges_max=None, title=None, legend=None, test_mode=False): """Display a 2D plot and have the user select a mask.""" - #RV make color blind friendly - def draw_selections(ax, current_include, current_exclude, - selected_index_ranges): + # RV make color blind friendly + def draw_selections( + ax, current_include, current_exclude, selected_index_ranges): """Draw the selections.""" ax.clear() ax.set_title(title) @@ -950,12 +995,12 @@ def update_index_ranges(mask): if not isinstance(current_index_ranges, (tuple, list)): logger.warning( 'Invalid current_index_ranges parameter ' - + f'({current_index_ranges}, {type(current_index_ranges)})') + f'({current_index_ranges}, {type(current_index_ranges)})') return None, None if not isinstance(select_mask, bool): logger.warning( f'Invalid select_mask parameter ({select_mask}, ' - + f'{type(select_mask)})') + f'{type(select_mask)})') return None, None if num_index_ranges_max is not None: logger.warning( @@ -1063,6 +1108,7 @@ def update_index_ranges(mask): return selected_mask, current_include + def select_image_bounds( a, axis, low=None, upp=None, num_min=None, title='select array bounds', raise_error=False): @@ -1072,8 +1118,9 @@ def select_image_bounds( """ a = np.asarray(a) if a.ndim != 2: - illegal_value(a.ndim, 'array dimension', location='select_image_bounds', - raise_error=raise_error) + illegal_value( + a.ndim, 'array dimension', location='select_image_bounds', + raise_error=raise_error) return None if axis < 0 or axis >= a.ndim: illegal_value( @@ -1089,7 +1136,7 @@ def select_image_bounds( if num_min < 2 or num_min > a.shape[axis]: logger.warning( 'Invalid input for num_min in select_image_bounds, ' - + 'input ignored') + 'input ignored') num_min = 1 if low is None: min_ = 0 @@ -1097,11 +1144,13 @@ def select_image_bounds( low_max = a.shape[axis]-num_min while True: if axis: - quick_imshow(a[:,min_:max_], title=title, aspect='auto', - extent=[min_,max_,a.shape[0],0]) + quick_imshow( + a[:,min_:max_], title=title, aspect='auto', + extent=[min_,max_,a.shape[0],0]) else: - quick_imshow(a[min_:max_,:], title=title, aspect='auto', - extent=[0,a.shape[1], max_,min_]) + quick_imshow( + a[min_:max_,:], title=title, aspect='auto', + extent=[0,a.shape[1], max_,min_]) zoom_flag = input_yesno( 'Set lower data bound (y) or zoom in (n)?', 'y') if zoom_flag: @@ -1122,11 +1171,13 @@ def select_image_bounds( upp_min = min_ while True: if axis: - quick_imshow(a[:,min_:max_], title=title, aspect='auto', - extent=[min_,max_,a.shape[0],0]) + quick_imshow( + a[:,min_:max_], title=title, aspect='auto', + extent=[min_,max_,a.shape[0],0]) else: - quick_imshow(a[min_:max_,:], title=title, aspect='auto', - extent=[0,a.shape[1], max_,min_]) + quick_imshow( + a[min_:max_,:], title=title, aspect='auto', + extent=[0,a.shape[1], max_,min_]) zoom_flag = input_yesno( 'Set upper data bound (y) or zoom in (n)?', 'y') if zoom_flag: @@ -1162,6 +1213,7 @@ def select_image_bounds( clear_imshow(title) return bounds + def select_one_image_bound( a, axis, bound=None, bound_name=None, title='select array bounds', default='y', 
raise_error=False): @@ -1187,11 +1239,13 @@ def select_one_image_bound( bound_max = a.shape[axis]-1 while True: if axis: - quick_imshow(a[:,min_:max_], title=title, aspect='auto', - extent=[min_,max_,a.shape[0],0]) + quick_imshow( + a[:,min_:max_], title=title, aspect='auto', + extent=[min_,max_,a.shape[0],0]) else: - quick_imshow(a[min_:max_,:], title=title, aspect='auto', - extent=[0,a.shape[1], max_,min_]) + quick_imshow( + a[min_:max_,:], title=title, aspect='auto', + extent=[0,a.shape[1], max_,min_]) zoom_flag = input_yesno( f'Set {bound_name} (y) or zoom in (n)?', 'y') if zoom_flag: @@ -1223,6 +1277,7 @@ def select_one_image_bound( clear_imshow(title) return bound + def clear_imshow(title=None): """Clear an image opened by quick_imshow().""" plt.ioff() @@ -1232,6 +1287,7 @@ def clear_imshow(title=None): raise ValueError(f'Invalid parameter title ({title})') plt.close(fig=title) + def clear_plot(title=None): """Clear an image opened by quick_plot().""" plt.ioff() @@ -1241,6 +1297,7 @@ def clear_plot(title=None): raise ValueError(f'Invalid parameter title ({title})') plt.close(fig=title) + def quick_imshow( a, title=None, path=None, name=None, save_fig=False, save_only=False, clear=True, extent=None, show_grid=False, grid_color='w', @@ -1259,7 +1316,7 @@ def quick_imshow( if not isinstance(block, bool): raise ValueError(f'Invalid parameter block ({block})') if not title: - title='quick imshow' + title = 'quick imshow' if name is None: ttitle = re_sub(r'\s+', '_', title) if path is None: @@ -1314,6 +1371,7 @@ def quick_imshow( if block: plt.show(block=block) + def quick_plot( *args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, ylim=None, xlabel=None, ylabel=None, legend=None, path=None, name=None, @@ -1412,7 +1470,7 @@ def quick_plot( plt.ylabel(ylabel) if show_grid: ax = plt.gca() - ax.grid(color='k')#, linewidth=1) + ax.grid(color='k') # , linewidth=1) if legend is not None: plt.legend(legend) if save_only: diff --git a/CHAP/common/utils/material.py b/CHAP/common/utils/material.py index 2f9b25b..4eecf8a 100755 --- a/CHAP/common/utils/material.py +++ b/CHAP/common/utils/material.py @@ -34,6 +34,7 @@ logger = getLogger(__name__) + class Material: """ Base class for materials in an sin2psi or EDD analysis. 
Right now @@ -61,8 +62,9 @@ def lattice_parameters(self, index=0): if isinstance(matl, materials.material.Crystal): return [matl.a, matl.b, matl.c] if isinstance(matl, material.Material): - return [l.getVal('angstrom') - for l in self._materials[index].latticeParameters[0:3]] + return [ + lpars.getVal('angstrom') + for lpars in self._materials[index].latticeParameters[0:3]] raise ValueError('Illegal material class type') def ds_unique(self, tth_tol=None, tth_max=None, round_sig=8): @@ -89,8 +91,9 @@ def add_material( raise ValueError('Multiple materials not implemented yet') self._ds_min.append(dmin_angstroms) self._materials.append( - Material.make_material(material_name, material_file, sgnum, - lattice_parameters_angstroms, atoms, pos, dmin_angstroms)) + Material.make_material( + material_name, material_file, sgnum, + lattice_parameters_angstroms, atoms, pos, dmin_angstroms)) def get_ds_unique(self, tth_tol=None, tth_max=None, round_sig=8): """ @@ -116,9 +119,9 @@ def get_ds_unique(self, tth_tol=None, tth_max=None, round_sig=8): if isinstance(m, materials.material.Crystal): powder = simpack.PowderDiffraction(m, en=self._enrgy) hklsi = [hkl for hkl in powder.data - if powder.data[hkl]['active']] + if powder.data[hkl]['active']] ds_i = [m.planeDistance(hkl) for hkl in powder.data - if powder.data[hkl]['active']] + if powder.data[hkl]['active']] mask = [d > self._ds_min[i] for d in ds_i] hkls = np.vstack((hkls, np.array(hklsi)[mask,:])) ds_i = np.array(ds_i)[mask] @@ -149,12 +152,12 @@ def get_ds_unique(self, tth_tol=None, tth_max=None, round_sig=8): self._ds_unique = ds[ds_index_unique] hkl_list = np.vstack( (np.arange(self._ds_unique.shape[0]), ds_index[ds_index_unique], - self._hkls_unique.T, self._ds_unique)).T + self._hkls_unique.T, self._ds_unique)).T logger.info("Unique d's:") for hkl in hkl_list: logger.info( f'{hkl[0]:4.0f} {hkl[1]:.0f} {hkl[2]:.0f} {hkl[3]:.0f} ' - + f'{hkl[4]:.0f} {hkl[5]:.6f}') + f'{hkl[4]:.0f} {hkl[5]:.6f}') return self._hkls_unique, self._ds_unique @@ -170,21 +173,22 @@ def make_material( if not isinstance(material_name, str): raise ValueError( f'Illegal material_name: {material_name} ' - + f'{type(material_name)}') + f'{type(material_name)}') if lattice_parameters_angstroms is not None: if material_file is not None: logger.warning( 'Overwrite lattice_parameters of material_file with input ' - + f'values ({lattice_parameters_angstroms})') + f'values ({lattice_parameters_angstroms})') if isinstance(lattice_parameters_angstroms, (int, float)): lattice_parameters = [lattice_parameters_angstroms] elif isinstance( lattice_parameters_angstroms, (tuple, list, np.ndarray)): lattice_parameters = list(lattice_parameters_angstroms) else: - raise ValueError('Illegal lattice_parameters_angstroms: ' - + f'{lattice_parameters_angstroms} ' - + f'{type(lattice_parameters_angstroms)}') + raise ValueError( + 'Illegal lattice_parameters_angstroms: ' + f'{lattice_parameters_angstroms} ' + f'{type(lattice_parameters_angstroms)}') if material_file is None: if not isinstance(sgnum, int): raise ValueError(f'Illegal sgnum: {sgnum} {type(sgnum)}') @@ -192,7 +196,7 @@ def make_material( or pos is None): raise ValueError( 'Valid inputs for sgnum, lattice_parameters_angstroms and ' - + 'pos are required if materials file is not specified') + 'pos are required if materials file is not specified') if isinstance(pos, str): pos = [pos] use_xu = True @@ -201,33 +205,34 @@ def make_material( if HAVE_HEXRD: pos = np.array([pos]) use_xu = False - elif (np.array(pos).ndim == 2 and 
np.array(pos).shape[0] > 0 and - np.array(pos).shape[1] == 3): + elif (np.array(pos).ndim == 2 and np.array(pos).shape[0] > 0 + and np.array(pos).shape[1] == 3): if HAVE_HEXRD: pos = np.array(pos) use_xu = False - elif not (np.array(pos).ndim == 1 and isinstance(pos[0], str) and - np.array(pos).size > 0 and HAVE_XU): + elif not (np.array(pos).ndim == 1 and isinstance(pos[0], str) + and np.array(pos).size > 0 and HAVE_XU): raise ValueError( f'Illegal pos (HAVE_XU = {HAVE_XU}): {pos} {type(pos)}') if use_xu: if atoms is None: atoms = [material_name] matl = materials.Crystal( - material_name, materials.SGLattice(sgnum, - *lattice_parameters, atoms=atoms, - pos=list(np.array(pos)))) + material_name, + materials.SGLattice(sgnum, *lattice_parameters, + atoms=atoms, pos=list(np.array(pos)))) else: matl = material.Material(material_name) matl.sgnum = sgnum matl.atominfo = np.vstack((pos.T, np.ones(pos.shape[0]))).T matl.latticeParameters = lattice_parameters matl.dmin = valWUnit( - 'lp', 'length', dmin_angstroms, 'angstrom') + 'lp', 'length', dmin_angstroms, 'angstrom') exclusions = matl.planeData.get_exclusions() powder_intensity = matl.planeData.powder_intensity - exclusions = [exclusion or i >= len(powder_intensity) or - powder_intensity[i] < POEDER_INTENSITY_CUTOFF + exclusions = [ + exclusion or i >= len(powder_intensity) + or powder_intensity[i] < POEDER_INTENSITY_CUTOFF for i, exclusion in enumerate(exclusions)] matl.planeData.set_exclusions(exclusions) logger.debug( @@ -237,22 +242,24 @@ def make_material( if not HAVE_HEXRD: raise ValueError( 'Illegal inputs: must provide detailed material info when ' - + 'hexrd package is unavailable') + 'hexrd package is unavailable') if sgnum is not None: logger.warning( 'Ignore sgnum input when material_file is specified') if not (path.splitext(material_file)[1] in ('.h5', '.hdf5', '.xtal', '.cif')): raise ValueError(f'Illegal material file {material_file}') - matl = material.Material(material_name, material_file, - dmin=valWUnit('lp', 'length', dmin_angstroms, 'angstrom')) + matl = material.Material( + material_name, material_file, + dmin=valWUnit('lp', 'length', dmin_angstroms, 'angstrom')) if lattice_parameters_angstroms is not None: matl.latticeParameters = lattice_parameters exclusions = matl.planeData.get_exclusions() powder_intensity = matl.planeData.powder_intensity - exclusions = [exclusion or i >= len(powder_intensity) or - powder_intensity[i] < POEDER_INTENSITY_CUTOFF - for i, exclusion in enumerate(exclusions)] + exclusions = [ + exclusion or i >= len(powder_intensity) + or powder_intensity[i] < POEDER_INTENSITY_CUTOFF + for i, exclusion in enumerate(exclusions)] matl.planeData.set_exclusions(exclusions) logger.debug( f'powder_intensity = {matl.planeData.powder_intensity}') diff --git a/CHAP/common/utils/scanparsers.py b/CHAP/common/utils/scanparsers.py index e11f271..0123a7f 100755 --- a/CHAP/common/utils/scanparsers.py +++ b/CHAP/common/utils/scanparsers.py @@ -3,13 +3,13 @@ # -*- coding: utf-8 -*- # system modules -import csv -import fnmatch -import json +from csv import reader +from fnmatch import filter as fnmatch_filter +from json import load import os import re -# other modules +# third party modules import numpy as np from pyspec.file.spec import FileSpec @@ -48,10 +48,9 @@ def __init__(self, self._detector_data_path = None def __repr__(self): - return (f'{self.__class__.__name__}(' - + f'{self.spec_file_name}, ' - + f'{self.scan_number}) ' - + f'-- {self.spec_command}') + return (f'{self.__class__.__name__}' + 
f'({self.spec_file_name}, {self.scan_number}) ' + f'-- {self.spec_command}') @property def spec_file(self): @@ -59,56 +58,67 @@ def spec_file(self): # attribute because it cannot be pickled (and therefore could # cause problems for parallel code that uses ScanParsers). return FileSpec(self.spec_file_name) + @property def scan_path(self): if self._scan_path is None: self._scan_path = self.get_scan_path() return self._scan_path + @property def scan_name(self): if self._scan_name is None: self._scan_name = self.get_scan_name() return self._scan_name + @property def scan_title(self): if self._scan_title is None: self._scan_title = self.get_scan_title() return self._scan_title + @property def spec_scan(self): if self._spec_scan is None: self._spec_scan = self.get_spec_scan() return self._spec_scan + @property def spec_command(self): if self._spec_command is None: self._spec_command = self.get_spec_command() return self._spec_command + @property def spec_macro(self): if self._spec_macro is None: self._spec_macro = self.get_spec_macro() return self._spec_macro + @property def spec_args(self): if self._spec_args is None: self._spec_args = self.get_spec_args() return self._spec_args + @property def spec_scan_npts(self): if self._spec_scan_npts is None: self._spec_scan_npts = self.get_spec_scan_npts() return self._spec_scan_npts + @property def spec_scan_data(self): if self._spec_scan_data is None: self._spec_scan_data = self.get_spec_scan_data() return self._spec_scan_data + @property def spec_positioner_values(self): if self._spec_positioner_values is None: self._spec_positioner_values = self.get_spec_positioner_values() return self._spec_positioner_values + @property def detector_data_path(self): if self._detector_data_path is None: @@ -253,11 +263,11 @@ def get_spec_positioner_value(self, positioner_name): positioner_value = float(positioner_value) except KeyError: raise KeyError(f'{self.scan_title}: motor {positioner_name} ' - + 'not found for this scan') + 'not found for this scan') except ValueError: raise ValueError(f'{self.scan_title}: could not convert value of' - + f' {positioner_name} to float: ' - + f'{positioner_value}') + f' {positioner_name} to float: ' + f'{positioner_value}') return positioner_value @@ -303,31 +313,31 @@ def get_pars(self): :rtype: dict[str,object] """ # JSON file holds titles for columns in the par file - json_files = fnmatch.filter( + json_files = fnmatch_filter( os.listdir(self.scan_path), f'{self.par_file_pattern}.json') if len(json_files) != 1: raise RuntimeError(f'{self.scan_title}: cannot find the ' - + '.json file to decode the .par file') + '.json file to decode the .par file') with open(os.path.join(self.scan_path, json_files[0])) as json_file: - par_file_cols = json.load(json_file) + par_file_cols = load(json_file) try: par_col_names = list(par_file_cols.values()) scann_val_idx = par_col_names.index('SCAN_N') scann_col_idx = int(list(par_file_cols.keys())[scann_val_idx]) except: raise RuntimeError(f'{self.scan_title}: cannot find scan pars ' - + 'without a "SCAN_N" column in the par file') + 'without a "SCAN_N" column in the par file') - par_files = fnmatch.filter( + par_files = fnmatch_filter( os.listdir(self.scan_path), f'{self.par_file_pattern}.par') if len(par_files) != 1: raise RuntimeError(f'{self.scan_title}: cannot find the .par ' - + 'file for this scan directory') + 'file for this scan directory') par_dict = None with open(os.path.join(self.scan_path, par_files[0])) as par_file: - par_reader = csv.reader(par_file, delimiter=' ') + par_reader = 
reader(par_file, delimiter=' ') for row in par_reader: if len(row) == len(par_col_names): row_scann = int(row[scann_col_idx]) @@ -348,7 +358,7 @@ def get_pars(self): if par_dict is None: raise RuntimeError(f'{self.scan_title}: could not find scan pars ' - + 'for scan number {self.scan_number}') + 'for scan number {self.scan_number}') return par_dict def get_counter_gain(self, counter_name): @@ -362,20 +372,19 @@ def get_counter_gain(self, counter_name): counter_gain = None for comment in self.spec_scan.comments: match = re.search( - f'{counter_name} gain: ' # start of counter gain comments - + '(?P\d+) ' # gain numerical value - + '(?P[m|u|n])A/V', # gain units + f'{counter_name} gain: ' # start of counter gain comments + '(?P\d+) ' # gain numerical value + '(?P[m|u|n])A/V', # gain units comment) if match: unit_prefix = match['unit_prefix'] gain_scalar = 1 if unit_prefix == 'n' \ - else 1e3 if unit_prefix == 'u' \ - else 1e6 + else 1e3 if unit_prefix == 'u' else 1e6 counter_gain = f'{float(match["gain_value"])*gain_scalar} nA/V' if counter_gain is None: raise RuntimeError(f'{self.scan_title}: could not get gain for ' - + f'counter {counter_name}') + f'counter {counter_name}') return counter_gain @@ -396,16 +405,19 @@ def spec_scan_motor_mnes(self): if self._spec_scan_motor_mnes is None: self._spec_scan_motor_mnes = self.get_spec_scan_motor_mnes() return self._spec_scan_motor_mnes + @property def spec_scan_motor_vals(self): if self._spec_scan_motor_vals is None: self._spec_scan_motor_vals = self.get_spec_scan_motor_vals() return self._spec_scan_motor_vals + @property def spec_scan_shape(self): if self._spec_scan_shape is None: self._spec_scan_shape = self.get_spec_scan_shape() return self._spec_scan_shape + @property def spec_scan_dwell(self): if self._spec_scan_dwell is None: @@ -519,7 +531,7 @@ def get_spec_scan_motor_mnes(self): if self.spec_macro in ('tseries', 'loopscan'): return ('Time',) raise RuntimeError(f'{self.scan_title}: cannot determine scan motors ' - + f'for scans of type {self.spec_macro}') + f'for scans of type {self.spec_macro}') def get_spec_scan_motor_vals(self): if self.spec_macro == 'flymesh': @@ -538,7 +550,7 @@ def get_spec_scan_motor_vals(self): if self.spec_macro in ('tseries', 'loopscan'): return self.spec_scan.data[:,0] raise RuntimeError(f'{self.scan_title}: cannot determine scan motors ' - + f'for scans of type {self.spec_macro}') + f'for scans of type {self.spec_macro}') def get_spec_scan_shape(self): if self.spec_macro == 'flymesh': @@ -551,7 +563,7 @@ def get_spec_scan_shape(self): if self.spec_macro in ('tseries', 'loopscan'): return len(np.array(self.spec_scan.data[:,0])) raise RuntimeError(f'{self.scan_title}: cannot determine scan shape ' - + f'for scans of type {self.spec_macro}') + f'for scans of type {self.spec_macro}') def get_spec_scan_dwell(self): if self.spec_macro in ('flymesh', 'flyscan'): @@ -559,7 +571,7 @@ def get_spec_scan_dwell(self): if self.spec_macro in ('tseries', 'loopscan'): return float(self.spec_args[1]) raise RuntimeError(f'{self.scan_title}: cannot determine dwell for ' - + f'scans of type {self.spec_macro}') + f'scans of type {self.spec_macro}') def get_detector_data_path(self): return os.path.join(self.scan_path, self.scan_title) @@ -575,23 +587,22 @@ def get_scan_title(self): def get_detector_data_file(self, detector_prefix, scan_step_index:int): scan_step = self.get_scan_step(scan_step_index) - file_indices = [f'{scan_step[i]:03d}' \ - for i in range(len(self.spec_scan_shape)) \ + file_indices = [f'{scan_step[i]:03d}' + for i 
in range(len(self.spec_scan_shape)) if self.spec_scan_shape[i] != 1] - file_name = (f'{self.scan_name}_' - + f'{detector_prefix}_' - + f'{self.scan_number:03d}_' - + '_'.join(file_indices) - + '.tiff') + file_name = f'{self.scan_name}_{detector_prefix}_' \ + f'{self.scan_number:03d}_{"_".join(file_indices)}.tiff' file_name_full = os.path.join(self.detector_data_path, file_name) if os.path.isfile(file_name_full): return file_name_full raise RuntimeError(f'{self.scan_title}: could not find detector image ' - + f'file for detector {detector_prefix} scan step ' - + f'({scan_step})') + f'file for detector {detector_prefix} scan step ' + f'({scan_step})') def get_detector_data(self, detector_prefix, scan_step_index:int): + # third party modules from pyspec.file.tiff import TiffFile + image_file = self.get_detector_data_file(detector_prefix, scan_step_index) with TiffFile(image_file) as tiff_file: @@ -614,16 +625,19 @@ def get_detector_data_file(self, detector_prefix, scan_step_index:int): if os.path.isfile(file_name_full): return file_name_full raise RuntimeError(f'{self.scan_title}: could not find detector image ' - + f'file for detector {detector_prefix} scan step ' - + f'({scan_step_index})') + f'file for detector {detector_prefix} scan step ' + f'({scan_step_index})') def get_detector_data(self, detector_prefix, scan_step_index:int): - import h5py - detector_file = self.get_detector_data_file(detector_prefix, - scan_step_index) + # third party modules + from h5py import File + + detector_file = self.get_detector_data_file( + detector_prefix, scan_step_index) scan_step = self.get_scan_step(scan_step_index) - with h5py.File(detector_file) as h5_file: - detector_data = h5_file['/entry/instrument/detector/data'][scan_step[0]] + with File(detector_file) as h5_file: + detector_data = \ + h5_file['/entry/instrument/detector/data'][scan_step[0]] return detector_data @@ -640,7 +654,7 @@ def get_spec_scan_motor_mnes(self): if self.spec_macro in ('tseries', 'loopscan'): return ('Time',) raise RuntimeError(f'{self.scan_title}: cannot determine scan motors ' - + f'for scans of type {self.spec_macro}') + f'for scans of type {self.spec_macro}') def get_spec_scan_motor_vals(self): if self.spec_macro == 'flymesh': @@ -659,7 +673,7 @@ def get_spec_scan_motor_vals(self): if self.spec_macro in ('tseries', 'loopscan'): return self.spec_scan.data[:,0] raise RuntimeError(f'{self.scan_title}: cannot determine scan motors ' - + f'for scans of type {self.spec_macro}') + f'for scans of type {self.spec_macro}') def get_spec_scan_shape(self): if self.spec_macro == 'flymesh': @@ -672,7 +686,7 @@ def get_spec_scan_shape(self): if self.spec_macro in ('tseries', 'loopscan'): return len(np.array(self.spec_scan.data[:,0])) raise RuntimeError(f'{self.scan_title}: cannot determine scan shape ' - + f'for scans of type {self.spec_macro}') + f'for scans of type {self.spec_macro}') def get_spec_scan_dwell(self): if self.spec_macro == 'flymesh': @@ -680,7 +694,7 @@ def get_spec_scan_dwell(self): if self.spec_macro == 'flyscan': return float(self.spec_args[-1]) raise RuntimeError(f'{self.scan_title}: cannot determine dwell time ' - + f'for scans of type {self.spec_macro}') + f'for scans of type {self.spec_macro}') def get_detector_data_path(self): return os.path.join(self.scan_path, str(self.scan_number)) @@ -690,23 +704,25 @@ def get_detector_data_file(self, detector_prefix, scan_step_index:int): if len(scan_step) == 1: scan_step = (0, *scan_step) file_name_pattern = (f'{detector_prefix}_' - + f'{self.scan_name}_*_' - + 
f'{scan_step[0]}_data_' - + f'{(scan_step[1]+1):06d}.h5') - file_name_matches = fnmatch.filter( + f'{self.scan_name}_*_' + f'{scan_step[0]}_data_' + f'{(scan_step[1]+1):06d}.h5') + file_name_matches = fnmatch_filter( os.listdir(self.detector_data_path), file_name_pattern) if len(file_name_matches) == 1: return os.path.join(self.detector_data_path, file_name_matches[0]) raise RuntimeError(f'{self.scan_title}: could not find detector image ' - + f'file for detector {detector_prefix} scan step ' - + f'({scan_step_index})') + f'file for detector {detector_prefix} scan step ' + f'({scan_step_index})') def get_detector_data(self, detector_prefix, scan_step_index:int): - import h5py - image_file = self.get_detector_data_file(detector_prefix, - scan_step_index) - with h5py.File(image_file) as h5_file: + # third party modules + from h5py import File + + image_file = self.get_detector_data_file( + detector_prefix, scan_step_index) + with File(image_file) as h5_file: image_data = h5_file['/entry/data/data'][0] return image_data @@ -731,26 +747,31 @@ def scan_type(self): if self._scan_type is None: self._scan_type = self.get_scan_type() return self._scan_type + @property def theta_vals(self): if self._theta_vals is None: self._theta_vals = self.get_theta_vals() return self._theta_vals + @property def horizontal_shift(self): if self._horizontal_shift is None: self._horizontal_shift = self.get_horizontal_shift() return self._horizontal_shift + @property def vertical_shift(self): if self._vertical_shift is None: self._vertical_shift = self.get_vertical_shift() return self._vertical_shift + @property def starting_image_index(self): if self._starting_image_index is None: self._starting_image_index = self.get_starting_image_index() return self._starting_image_index + @property def starting_image_offset(self): if self._starting_image_offset is None: @@ -834,10 +855,10 @@ def get_spec_scan_npts(self): if len(self.spec_args) == 5: return int(self.spec_args[3])+1 raise RuntimeError(f'{self.scan_title}: cannot obtain number of ' - + f'points from {self.spec_macro} with ' - + f'arguments {self.spec_args}') + f'points from {self.spec_macro} with ' + f'arguments {self.spec_args}') raise RuntimeError(f'{self.scan_title}: cannot determine number of ' - + f'points for scans of type {self.spec_macro}') + f'points for scans of type {self.spec_macro}') def get_theta_vals(self): if self.spec_macro == 'flyscan': @@ -849,10 +870,10 @@ def get_theta_vals(self): 'end': float(self.spec_args[2]), 'num': int(self.spec_args[3])+1} raise RuntimeError(f'{self.scan_title}: cannot obtain theta values' - + f' from {self.spec_macro} with arguments ' - + f'{self.spec_args}') + f' from {self.spec_macro} with arguments ' + f'{self.spec_args}') raise RuntimeError(f'{self.scan_title}: cannot determine theta values ' - + f'for scans of type {self.spec_macro}') + f'for scans of type {self.spec_macro}') def get_horizontal_shift(self): return 0.0 @@ -867,9 +888,11 @@ def get_starting_image_offset(self): return 1 def get_num_image(self, detector_prefix): - import h5py + # third party modules + from h5py import File + detector_file = self.get_detector_data_file(detector_prefix) - with h5py.File(detector_file) as h5_file: + with File(detector_file) as h5_file: num_image = h5_file['/entry/instrument/detector/data'].shape[0] return num_image-self.starting_image_offset @@ -883,33 +906,34 @@ def get_detector_data_file(self, detector_prefix): if os.path.isfile(file_name_full): return file_name_full raise RuntimeError(f'{self.scan_title}: could not find 
detector image ' - + f'file for detector {detector_prefix}') + f'file for detector {detector_prefix}') + + def get_all_detector_data_in_file( + self, detector_prefix, scan_step_index=None): + # third party modules + from h5py import File - def get_all_detector_data_in_file(self, - detector_prefix, - scan_step_index=None): - import h5py detector_file = self.get_detector_data_file(detector_prefix) - with h5py.File(detector_file) as h5_file: + with File(detector_file) as h5_file: if scan_step_index is None: detector_data = h5_file['/entry/instrument/detector/data'][ - self.starting_image_index:] + self.starting_image_index:] elif isinstance(scan_step_index, int): detector_data = h5_file['/entry/instrument/detector/data'][ - self.starting_image_index+scan_step_index] + self.starting_image_index+scan_step_index] elif (isinstance(scan_step_index, (list, tuple)) - and len(scan_step_index) == 2): + and len(scan_step_index) == 2): detector_data = h5_file['/entry/instrument/detector/data'][ - self.starting_image_index+scan_step_index[0]: - self.starting_image_index+scan_step_index[1]] + self.starting_image_index+scan_step_index[0]: + self.starting_image_index+scan_step_index[1]] else: raise ValueError('Invalid parameter scan_step_index ' - + f'({scan_step_index})') + f'({scan_step_index})') return detector_data def get_detector_data(self, detector_prefix, scan_step_index=None): - return self.get_all_detector_data_in_file(detector_prefix, - scan_step_index) + return self.get_all_detector_data_in_file( + detector_prefix, scan_step_index) class SMBRotationScanParser(RotationScanParser, SMBScanParser): @@ -926,14 +950,14 @@ def get_spec_scan_npts(self): if self.spec_macro in ('slew_ome','rams4_slew_ome'): return int(self.pars['nframes_real']) raise RuntimeError(f'{self.scan_title}: cannot determine number of ' - + f'points for scans of type {self.spec_macro}') + f'points for scans of type {self.spec_macro}') def get_scan_type(self): scan_type = self.pars.get('tomo_type', self.pars.get('tomotype', None)) if scan_type is None: raise RuntimeError(f'{self.scan_title}: cannot determine ' - + 'the scan_type') + 'the scan_type') return scan_type def get_theta_vals(self): @@ -951,7 +975,7 @@ def get_horizontal_shift(self): def get_vertical_shift(self): vertical_shift = self.pars.get( - 'rams4z', self.pars.get('ramsz', None)) + 'rams4z', self.pars.get('ramsz', None)) if vertical_shift is None: raise RuntimeError( f'{self.scan_title}: cannot determine the vertical shift') @@ -962,7 +986,7 @@ def get_starting_image_index(self): return int(self.pars['junkstart']) except: raise RuntimeError(f'{self.scan_title}: cannot determine first ' - + 'detector image index') + 'detector image index') def get_starting_image_offset(self): try: @@ -970,7 +994,7 @@ def get_starting_image_offset(self): - self.get_starting_image_index()) except: raise RuntimeError(f'{self.scan_title}: cannot determine index ' - + 'offset of first good detector image') + 'offset of first good detector image') def get_num_image(self, detector_prefix=None): try: @@ -985,7 +1009,7 @@ def get_num_image(self, detector_prefix=None): # return len(files)-self.starting_image_offset except: raise RuntimeError(f'{self.scan_title}: cannot determine the ' - + 'number of good detector images') + 'number of good detector images') def get_detector_data_path(self): return os.path.join(self.scan_path, str(self.scan_number), 'nf') @@ -996,30 +1020,32 @@ def get_detector_data_file(self, scan_step_index:int): if os.path.isfile(file_name_full): return file_name_full raise 
RuntimeError(f'{self.scan_title}: could not find detector image ' - + f'file for scan step ({scan_step_index})') + f'file for scan step ({scan_step_index})') def get_detector_data(self, detector_prefix, scan_step_index=None): if scan_step_index is None: detector_data = [] for index in range(len(self.get_num_image(detector_prefix))): - detector_data.append(self.get_detector_data(detector_prefix, - index)) + detector_data.append( + self.get_detector_data(detector_prefix, index)) detector_data = np.asarray(detector_data) elif isinstance(scan_step_index, int): - image_file = self.get_detector_data_file(scan_step_index) + # third party modules from pyspec.file.tiff import TiffFile + + image_file = self.get_detector_data_file(scan_step_index) with TiffFile(image_file) as tiff_file: detector_data = tiff_file.asarray() elif (isinstance(scan_step_index, (list, tuple)) - and len(scan_step_index) == 2): + and len(scan_step_index) == 2): detector_data = [] for index in range(scan_step_index[0], scan_step_index[1]): - detector_data.append(self.get_detector_data(detector_prefix, - index)) + detector_data.append( + self.get_detector_data(detector_prefix, index)) detector_data = np.asarray(detector_data) else: raise ValueError('Invalid parameter scan_step_index ' - + f'({scan_step_index})') + f'({scan_step_index})') return detector_data @@ -1073,7 +1099,7 @@ def get_spec_scan_npts(self): if self.spec_scan == 'wbslew_scan': return 1 raise RuntimeError(f'{self.scan_title}: cannot determine number of ' - + f'points for scans of type {self.spec_macro}') + f'points for scans of type {self.spec_macro}') def get_dwell_time(self): if self.spec_macro == 'tseries': @@ -1083,7 +1109,7 @@ def get_dwell_time(self): if self.spec_macro == 'wbslew_scan': return float(self.spec_args[3]) raise RuntimeError(f'{self.scan_title}: cannot determine dwell time ' - + f'for scans of type {self.spec_macro}') + f'for scans of type {self.spec_macro}') def get_detector_num_bins(self, detector_prefix): with open(self.get_detector_data_file(detector_prefix)) \ @@ -1092,27 +1118,24 @@ def get_detector_num_bins(self, detector_prefix): for line in lines: if line.startswith('#@CHANN'): try: - line_prefix, \ - number_saved, \ - first_saved, \ - last_saved, \ - reduction_coef = line.split() + line_prefix, number_saved, first_saved, last_saved, \ + reduction_coef = line.split() return int(number_saved) except: continue raise RuntimeError(f'{self.scan_title}: could not find num_bins for ' - + f'detector {detector_prefix}') + f'detector {detector_prefix}') def get_detector_data_path(self): return self.scan_path - def get_detector_data_file(self, detector_prefix, scan_step_index:int=0): + def get_detector_data_file(self, detector_prefix, scan_step_index=0): file_name = f'spec.log.scan{self.scan_number}.mca1.mca' file_name_full = os.path.join(self.detector_data_path, file_name) if os.path.isfile(file_name_full): return file_name_full - raise RuntimeError(f'{self.scan_title}: could not find detector image ' - + 'file') + raise RuntimeError( + f'{self.scan_title}: could not find detector image file') def get_all_detector_data(self, detector_prefix): # This should be easy with pyspec, but there are bugs in @@ -1123,7 +1146,7 @@ def get_all_detector_data(self, detector_prefix): data = [] with open(self.get_detector_data_file(detector_prefix)) \ - as detector_file: + as detector_file: lines = [line.strip("\\\n") for line in detector_file.readlines()] num_bins = self.get_detector_num_bins(detector_prefix) diff --git a/CHAP/common/writer.py 
b/CHAP/common/writer.py index 908559f..e05c17a 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -53,7 +53,7 @@ def _write(self, data, filename, force_overwrite=False): if not isinstance(data, NXobject): raise TypeError('Cannot write object of type ' - + f'{type(data).__name__} to a NeXus file.') + f'{type(data).__name__} to a NeXus file.') mode = 'w' if force_overwrite else 'w-' data.save(filename, mode=mode) @@ -83,13 +83,13 @@ def _write(self, data, filename, force_overwrite=False): import yaml if not isinstance(data, (dict, list)): - raise(TypeError(f'{self.__name__}.write: input data must be ' - + 'a dict or list.')) + raise TypeError( + f'{self.__name__}.write: input data must be a dict or list.') if not force_overwrite: if os.path.isfile(filename): - raise(RuntimeError(f'{self.__name__}: {filename} already ' - + 'exists.')) + raise RuntimeError( + f'{self.__name__}: {filename} already exists.') with open(filename, 'w') as outf: yaml.dump(data, outf, sort_keys=False) diff --git a/CHAP/edd/models.py b/CHAP/edd/models.py index 573fc43..a4b1429 100644 --- a/CHAP/edd/models.py +++ b/CHAP/edd/models.py @@ -1,3 +1,4 @@ +# third party modules import numpy as np from pathlib import PosixPath from pydantic import (BaseModel, @@ -12,14 +13,15 @@ class MCACeriaCalibrationConfig(BaseModel): - '''Class representing metadata required to perform a Ceria calibration for an + """ + Class representing metadata required to perform a Ceria calibration for an MCA detector. :ivar spec_file: Path to the SPEC file containing the CeO2 scan :ivar scan_number: Number of the CeO2 scan in `spec_file` :ivar scan_step_index: Index of the scan step to use for calibration, - optional. If not specified, the calibration routine will be performed on - the average of all MCA spectra for the scan. + optional. If not specified, the calibration routine will be performed + on the average of all MCA spectra for the scan. :ivar flux_file: csv file containing station beam energy in eV (column 0) and flux (column 1) @@ -28,8 +30,8 @@ class MCACeriaCalibrationConfig(BaseModel): :ivar num_bins: number of channels on the MCA to calibrate :ivar max_energy_kev: maximum channel energy of the MCA in keV - :ivar hexrd_h5_material_file: path to a HEXRD materials.h5 file containing an - entry for the material properties. + :ivar hexrd_h5_material_file: path to a HEXRD materials.h5 file containing + an entry for the material properties. :ivar hexrd_h5_material_name: Name of the material entry in `hexrd_h5_material_file`, defaults to `'CeO2'`. :ivar lattice_parameter_angstrom: lattice spacing in angstrom to use for @@ -59,9 +61,9 @@ class MCACeriaCalibrationConfig(BaseModel): :ivar max_iter: maximum number of iterations of the calibration routine, defaults to `10`. :ivar tune_tth_tol: stop iteratively tuning 2&theta when an iteration - produces a change in the tuned value of 2&theta that is smaller than this - value, defaults to `1e-8`. - ''' + produces a change in the tuned value of 2&theta that is smaller than + this value, defaults to `1e-8`. 
+ """ spec_file: FilePath scan_number: conint(gt=0) @@ -74,7 +76,8 @@ class MCACeriaCalibrationConfig(BaseModel): max_energy_kev: confloat(gt=0) hexrd_h5_material_file: FilePath - hexrd_h5_material_name: constr(strip_whitespace=True, min_length=1) = 'CeO2' + hexrd_h5_material_name: constr( + strip_whitespace=True, min_length=1) = 'CeO2' lattice_parameter_angstrom: confloat(gt=0) = 5.41153 tth_max: confloat(gt=0, allow_inf_nan=False) = 90.0 @@ -98,20 +101,22 @@ class MCACeriaCalibrationConfig(BaseModel): @validator('fit_include_bin_ranges', each_item=True) def validate_include_bin_range(cls, value, values): - '''Ensure no bin ranges are outside the boundary of the detector''' + """Ensure no bin ranges are outside the boundary of the detector""" num_bins = values.get('num_bins') value[1] = min(value[1], num_bins) - return(value) + return value def mca_data(self): - '''Get the 1D array of MCA data to use for calibration. - + """Get the 1D array of MCA data to use for calibration. + :return: MCA data :rtype: np.ndarray - ''' + """ + # local modules + from CHAP.common.utils.scanparsers \ + import SMBMCAScanParser as ScanParser - from CHAP.common.utils.scanparsers import SMBMCAScanParser as ScanParser scanparser = ScanParser(self.spec_file, self.scan_number) if self.scan_step_index is None: data = scanparser.get_all_detector_data(self.detector_name) @@ -120,97 +125,105 @@ def mca_data(self): else: data = data[0] else: - data = scanparser.get_detector_data(self.detector_name, self.scan_step_index) + data = scanparser.get_detector_data( + self.detector_name, self.scan_step_index) - return(np.array(data)) + return np.array(data) def mca_mask(self): - '''Get a boolean mask array to use on MCA data before fitting. + """Get a boolean mask array to use on MCA data before fitting. :return: boolean mask array :rtype: numpy.ndarray - ''' + """ mask = np.asarray([False]*self.num_bins) bin_indices = np.arange(self.num_bins) for min_, max_ in self.fit_include_bin_ranges: - _mask = np.logical_and(bin_indices > min_, bin_indices < max_) + _mask = np.logical_and(bin_indices > min_, bin_indices < max_) mask = np.logical_or(mask, _mask) - return(mask) + return mask def flux_correction_interpolation_function(self): - '''Get an interpolation function to correct MCA data for relative energy + """ + Get an interpolation function to correct MCA data for relative energy flux of the incident beam. :return: energy flux correction interpolation function :rtype: scipy.interpolate._polyint._Interpolator1D - ''' + """ flux = np.loadtxt(self.flux_file) energies = flux[:,0]/1.e3 relative_intensities = flux[:,1]/np.max(flux[:,1]) interpolation_function = interp1d(energies, relative_intensities) - return(interpolation_function) + return interpolation_function def material(self): - '''Get CeO2 as a `CHAP.common.utils.material.Material` object. + """Get CeO2 as a `CHAP.common.utils.material.Material` object. 
:return: CeO2 material :rtype: CHAP.common.utils.material.Material - ''' - + """ + # local modules from CHAP.common.utils.material import Material - material = Material(material_name=self.hexrd_h5_material_name, - material_file=self.hexrd_h5_material_file, - lattice_parameters_angstroms=self.lattice_parameter_angstrom) + + material = Material( + material_name=self.hexrd_h5_material_name, + material_file=self.hexrd_h5_material_file, + lattice_parameters_angstroms=self.lattice_parameter_angstrom) # The following kwargs will be needed if we allow the material to be # built using xrayutilities (for now, we only allow hexrd to make the # material): - # sgnum=225, - # atoms=['Ce4p', 'O2mdot'], - # pos=[(0.,0.,0.), (0.25,0.75,0.75)], - # enrgy=50000.) # Why do we need to specify an energy to get HKLs when using xrayutilities? - return(material) + # sgnum=225, + # atoms=['Ce4p', 'O2mdot'], + # pos=[(0.,0.,0.), (0.25,0.75,0.75)], + # enrgy=50000.) + # Why do we need to specify an energy to get HKLs when using + # xrayutilities? + return material def unique_ds(self): - '''Get a list of unique HKLs and their lattice spacings - + """Get a list of unique HKLs and their lattice spacings + :return: unique HKLs and their lattice spacings in angstroms :rtype: np.ndarray, np.ndarray - ''' - + """ + unique_hkls, unique_ds = self.material().get_ds_unique( tth_tol=self.hkl_tth_tol, tth_max=self.tth_max) - return(unique_hkls, unique_ds) + return unique_hkls, unique_ds def fit_ds(self): - '''Get a list of HKLs and their lattice spacings that will be fit in the + """ + Get a list of HKLs and their lattice spacings that will be fit in the calibration routine - + :return: HKLs to fit and their lattice spacings in angstroms :rtype: np.ndarray, np.ndarray - ''' - + """ + unique_hkls, unique_ds = self.unique_ds() fit_hkls = np.array([unique_hkls[i] for i in self.fit_hkls]) fit_ds = np.array([unique_ds[i] for i in self.fit_hkls]) - return(fit_hkls, fit_ds) + return fit_hkls, fit_ds def dict(self): - '''Return a representation of this configuration in a dictionary that is + """ + Return a representation of this configuration in a dictionary that is suitable for dumping to a YAML file (one that converts all instances of fields with type `PosixPath` to `str`). :return: dictionary representation of the configuration. :rtype: dict - ''' + """ d = super().dict() for k,v in d.items(): if isinstance(v, PosixPath): d[k] = str(v) - return(d) + return d diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index fb4e355..da0bcc2 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -8,11 +8,15 @@ """ # system modules -import json +from json import dumps + +# third party modules +import numpy as np # local modules from CHAP.processor import Processor + class MCACeriaCalibrationProcessor(Processor): """A Processor using a CeO2 scan to obtain tuned values for the bragg diffraction angle and linear correction parameters for MCA @@ -55,7 +59,7 @@ def get_config(self, data): values taken from `data`. 
:rtype: MCACeriaCalibrationConfig """ - + # local modules from CHAP.edd.models import MCACeriaCalibrationConfig calibration_config = False @@ -67,8 +71,8 @@ def get_config(self, data): break if not calibration_config: - raise ValueError('No MCA ceria calibration configuration found in ' - + 'input data') + raise ValueError( + 'No MCA ceria calibration configuration found in input data') return MCACeriaCalibrationConfig(**calibration_config) @@ -86,20 +90,20 @@ def calibrate(self, calibration_config): intercept :rtype: float, float, float """ + # third party modules + from scipy.constants import physical_constants + # local modules from CHAP.common.utils.fit import Fit, FitMultipeak - import numpy as np - from scipy.constants import physical_constants - hc = (physical_constants['Planck constant in eV/Hz'][0] - * physical_constants['speed of light in vacuum'][0] - * 1e7) # We'll work in keV and A, not eV and m. + # We'll work in keV and A, not eV and m. + hc = 1e7 * physical_constants['Planck constant in eV/Hz'][0] \ + * physical_constants['speed of light in vacuum'][0] # Collect raw MCA data of interest mca_data = calibration_config.mca_data() - mca_bin_energies = (np.arange(0, calibration_config.num_bins) - * (calibration_config.max_energy_kev - / calibration_config.num_bins)) + mca_bin_energies = np.arange(0, calibration_config.num_bins) \ + * (calibration_config.max_energy_kev/calibration_config.num_bins) # Mask out the corrected MCA data for fitting mca_mask = calibration_config.mca_mask() @@ -107,9 +111,10 @@ def calibrate(self, calibration_config): fit_mca_intensities = mca_data[mca_mask] # Correct raw MCA data for variable flux at different energies - flux_correct = calibration_config.flux_correction_interpolation_function() + flux_correct = \ + calibration_config.flux_correction_interpolation_function() mca_intensity_weights = flux_correct(fit_mca_energies) - fit_mca_intensities = fit_mca_intensities / mca_intensity_weights + fit_mca_intensities = fit_mca_intensities/mca_intensity_weights # Get the HKLs and lattice spacings that will be used for # fitting @@ -119,11 +124,11 @@ def calibrate(self, calibration_config): for iter_i in range(calibration_config.max_iter): - ### Perform the uniform fit first ### + # Perform the uniform fit first # Get expected peak energy locations for this iteration's # starting value of tth - fit_lambda = 2.0 * fit_ds * np.sin(0.5*np.radians(tth)) + fit_lambda = 2.0*fit_ds*np.sin(0.5*np.radians(tth)) fit_E0 = hc / fit_lambda # Run the uniform fit @@ -137,8 +142,9 @@ def calibrate(self, calibration_config): # Extract values of interest from the best values for the # uniform fit parameters - uniform_fit_centers = [best_values[f'peak{i+1}_center'] \ - for i in range(len(calibration_config.fit_hkls))] + uniform_fit_centers = [ + best_values[f'peak{i+1}_center'] + for i in range(len(calibration_config.fit_hkls))] # uniform_a = best_values['scale_factor'] # uniform_strain = np.log( # (uniform_a @@ -147,7 +153,7 @@ def calibrate(self, calibration_config): # uniform_rel_rms_error = (np.linalg.norm(residual) # / np.linalg.norm(fit_mca_intensities)) - ### Next, perform the unconstrained fit ### + # Next, perform the unconstrained fit # Use the peak locations found in the uniform fit as the # initial guesses for peak locations in the unconstrained @@ -163,19 +169,16 @@ def calibrate(self, calibration_config): # Extract values of interest from the best values for the # unconstrained fit parameters unconstrained_fit_centers = np.array( - 
[best_values[f'peak{i+1}_center'] \ + [best_values[f'peak{i+1}_center'] for i in range(len(calibration_config.fit_hkls))]) - unconstrained_a = (0.5 * hc * np.sqrt(c_1) - / (unconstrained_fit_centers - * abs(np.sin(0.5*np.radians(tth))))) + unconstrained_a = 0.5*hc*np.sqrt(c_1) \ + / (unconstrained_fit_centers*abs(np.sin(0.5*np.radians(tth)))) unconstrained_strains = np.log( - (unconstrained_a - / calibration_config.lattice_parameter_angstrom)) + unconstrained_a/calibration_config.lattice_parameter_angstrom) unconstrained_strain = np.mean(unconstrained_strains) - unconstrained_tth = tth * (1.0 + unconstrained_strain) - unconstrained_rel_rms_error = (np.linalg.norm(residual) - / np.linalg.norm(fit_mca_intensities)) - + unconstrained_tth = tth * (1.0+unconstrained_strain) + unconstrained_rel_rms_error = ( + np.linalg.norm(residual)/np.linalg.norm(fit_mca_intensities)) # Update tth for the next iteration of tuning prev_tth = tth @@ -198,6 +201,7 @@ def calibrate(self, calibration_config): return float(tth), float(slope), float(intercept) + class MCADataProcessor(Processor): """A Processor to return data from an MCA, restuctured to incorporate the shape & metadata associated with a map @@ -236,7 +240,7 @@ def get_configs(self, data): field values taken from `data`. :rtype: tuple[MapConfig, MCACeriaCalibrationConfig] """ - + # local modules from CHAP.common.models.map import MapConfig from CHAP.edd.models import MCACeriaCalibrationConfig @@ -255,7 +259,7 @@ def get_configs(self, data): raise ValueError('No map configuration found in input data') if not calibration_config: raise ValueError('No MCA ceria calibration configuration found in ' - + 'input data') + 'input data') return (MapConfig(**map_config), MCACeriaCalibrationConfig(**calibration_config)) @@ -274,14 +278,14 @@ def get_nxroot(self, map_config, calibration_config): :return: a map of the calibrated and flux-corrected MCA data :rtype: nexusformat.nexus.NXroot """ - - from CHAP.common import MapProcessor - + # third party modules from nexusformat.nexus import (NXdata, NXdetector, NXinstrument, NXroot) - import numpy as np + + # local modules + from CHAP.common import MapProcessor nxroot = NXroot() @@ -290,18 +294,17 @@ def get_nxroot(self, map_config, calibration_config): nxentry.instrument = NXinstrument() nxentry.instrument.detector = NXdetector() - nxentry.instrument.detector.calibration_configuration = json.dumps( + nxentry.instrument.detector.calibration_configuration = dumps( calibration_config.dict()) nxentry.instrument.detector.data = NXdata() nxdata = nxentry.instrument.detector.data nxdata.raw = np.empty((*map_config.shape, calibration_config.num_bins)) nxdata.raw.attrs['units'] = 'counts' - nxdata.channel_energy = (calibration_config.slope_calibrated - * np.arange(0, calibration_config.num_bins) - * (calibration_config.max_energy_kev - / calibration_config.num_bins) - + calibration_config.intercept_calibrated) + nxdata.channel_energy = calibration_config.slope_calibrated \ + * np.arange(0, calibration_config.num_bins) \ + * (calibration_config.max_energy_kev/calibration_config.num_bins) \ + + calibration_config.intercept_calibrated nxdata.channel_energy.attrs['units'] = 'keV' for scans in map_config.spec_scans: @@ -333,6 +336,9 @@ def get_nxroot(self, map_config, calibration_config): return nxroot + if __name__ == '__main__': + # local modules from CHAP.processor import main + main() diff --git a/CHAP/edd/reader.py b/CHAP/edd/reader.py index 709d3d3..b71407d 100755 --- a/CHAP/edd/reader.py +++ b/CHAP/edd/reader.py @@ -1,4 
+1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python if __name__ == '__main__': from CHAP.reader import main diff --git a/CHAP/edd/writer.py b/CHAP/edd/writer.py index b00fa9f..f786837 100755 --- a/CHAP/edd/writer.py +++ b/CHAP/edd/writer.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python if __name__ == '__main__': from CHAP.writer import main diff --git a/CHAP/inference/processor.py b/CHAP/inference/processor.py index 7702180..f46e42a 100755 --- a/CHAP/inference/processor.py +++ b/CHAP/inference/processor.py @@ -13,6 +13,7 @@ # local modules from CHAP import Processor + class TFaaSImageProcessor(Processor): """A Processor to get predictions from TFaaS inference server.""" @@ -37,25 +38,29 @@ def _process(self, data, url, model, verbose): :return: `data` :rtype: object """ + # system modules + from pathlib import Path + # local modules from MLaaS.tfaas_client import predictImage - from pathlib import Path self.logger.info(f'input data {type(data)}') if isinstance(data, str) and Path(data).is_file(): img_file = data data = predictImage(url, img_file, model, verbose) else: + # third party modules + from requests import Session + rdict = data[0] - import requests img = rdict['data'] - session = requests.Session() + session = Session() rurl = url + '/predict/image' payload = {'model': model} files = {'image': img} self.logger.info( f'HTTP request {rurl} with image file and {payload} payload') - req = session.post(rurl, files=files, data=payload ) + req = session.post(rurl, files=files, data=payload) data = req.content data = data.decode('utf-8').replace('\n', '') self.logger.info(f'HTTP response {data}') @@ -64,5 +69,7 @@ def _process(self, data, url, model, verbose): if __name__ == '__main__': + # local modules from CHAP.processor import main + main() diff --git a/CHAP/inference/reader.py b/CHAP/inference/reader.py index 709d3d3..b71407d 100755 --- a/CHAP/inference/reader.py +++ b/CHAP/inference/reader.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python if __name__ == '__main__': from CHAP.reader import main diff --git a/CHAP/inference/writer.py b/CHAP/inference/writer.py index b00fa9f..f786837 100755 --- a/CHAP/inference/writer.py +++ b/CHAP/inference/writer.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python if __name__ == '__main__': from CHAP.writer import main diff --git a/CHAP/pipeline.py b/CHAP/pipeline.py index 074bc7f..91f2bc0 100755 --- a/CHAP/pipeline.py +++ b/CHAP/pipeline.py @@ -4,7 +4,7 @@ """ File : pipeline.py Author : Valentin Kuznetsov -Description: +Description: """ # system modules @@ -16,7 +16,7 @@ class Pipeline(): """Pipeline represent generic Pipeline class""" def __init__(self, items=None, kwds=None): """Pipeline class constructor - + :param items: list of objects :param kwds: list of method args for individual objects """ diff --git a/CHAP/processor.py b/CHAP/processor.py index 5ad750b..ac388c8 100755 --- a/CHAP/processor.py +++ b/CHAP/processor.py @@ -9,9 +9,9 @@ # system modules import argparse -import inspect +from inspect import getfullargspec import logging -import sys +from sys import modules from time import time @@ -35,8 +35,8 @@ def process(self, data, **_process_kwargs): self.logger.info(f'Executing "process" with type(data)={type(data)}') _valid_process_args = {} - allowed_args = inspect.getfullargspec(self._process).args \ - + inspect.getfullargspec(self._process).kwonlyargs + allowed_args = getfullargspec(self._process).args \ + + getfullargspec(self._process).kwonlyargs for k, v in _process_kwargs.items(): if k 
in allowed_args: _valid_process_args[k] = v @@ -58,9 +58,9 @@ def _process(self, data): """ # If needed, extract data from a returned value of Reader.read if isinstance(data, list): - if all(isinstance(d,dict) for d in data): + if all(isinstance(d, dict) for d in data): data = data[0]['data'] - if data == None: + if data is None: return [] # process operation is a simple print function data += "process part\n" @@ -86,14 +86,14 @@ def __init__(self): def main(opt_parser=OptionParser): """Main function""" - optmgr = opt_parser() + optmgr = opt_parser() opts = optmgr.parser.parse_args() cls_name = opts.processor try: - processor_cls = getattr(sys.modules[__name__],cls_name) - except: + processor_cls = getattr(modules[__name__], cls_name) + except AttributeError: print(f'Unsupported processor {cls_name}') - sys.exit(1) + raise processor = processor_cls() processor.logger.setLevel(getattr(logging, opts.log_level)) @@ -105,5 +105,6 @@ def main(opt_parser=OptionParser): print(f'Processor {processor} operates on data {data}') + if __name__ == '__main__': main() diff --git a/CHAP/reader.py b/CHAP/reader.py index f1cdb6e..b180765 100755 --- a/CHAP/reader.py +++ b/CHAP/reader.py @@ -7,13 +7,11 @@ # system modules import argparse -import inspect +from inspect import getfullargspec import logging -import sys +from sys import modules from time import time -# local modules - class Reader(): """Reader represent generic file writer""" @@ -46,11 +44,11 @@ def read(self, type_=None, schema=None, encoding=None, **_read_kwargs): t0 = time() self.logger.info(f'Executing "read" with type={type_}, ' - + f'schema={schema}, kwargs={_read_kwargs}') + f'schema={schema}, kwargs={_read_kwargs}') _valid_read_args = {} - allowed_args = inspect.getfullargspec(self._read).args \ - + inspect.getfullargspec(self._read).kwonlyargs + allowed_args = getfullargspec(self._read).args \ + + getfullargspec(self._read).kwonlyargs for k, v in _read_kwargs.items(): if k in allowed_args: _valid_read_args[k] = v @@ -74,8 +72,8 @@ def _read(self, filename): """ if not filename: - self.logger.warning('No file name is given, will skip ' - + 'read operation') + self.logger.warning( + 'No file name is given, will skip read operation') return None with open(filename) as file: @@ -97,17 +95,18 @@ def __init__(self): '--log-level', choices=logging._nameToLevel.keys(), dest='log_level', default='INFO', help='logging level') + def main(opt_parser=OptionParser): """Main function""" - optmgr = opt_parser() + optmgr = opt_parser() opts = optmgr.parser.parse_args() cls_name = opts.reader try: - reader_cls = getattr(sys.modules[__name__],cls_name) - except: + reader_cls = getattr(modules[__name__], cls_name) + except AttributeError: print(f'Unsupported reader {cls_name}') - sys.exit(1) + raise reader = reader_cls() reader.logger.setLevel(getattr(logging, opts.log_level)) @@ -119,5 +118,6 @@ def main(opt_parser=OptionParser): print(f'Reader {reader} reads from {opts.filename}, data {data}') + if __name__ == '__main__': main() diff --git a/CHAP/runner.py b/CHAP/runner.py index 4a7604f..82158eb 100755 --- a/CHAP/runner.py +++ b/CHAP/runner.py @@ -4,13 +4,13 @@ """ File : runner.py Author : Valentin Kuznetsov -Description: +Description: """ # system modules import argparse import logging -import yaml +from yaml import safe_load # local modules from CHAP.pipeline import Pipeline @@ -22,28 +22,29 @@ def __init__(self): """OptionParser class constructor""" self.parser = argparse.ArgumentParser(prog='PROG') self.parser.add_argument( - '--config', 
action='store', dest='config', - default='', help='Input configuration file') + '--config', action='store', dest='config', default='', + help='Input configuration file') self.parser.add_argument( '--interactive', action='store_true', dest='interactive', help='Allow interactive processes') self.parser.add_argument( '--log-level', choices=logging._nameToLevel.keys(), dest='log_level', default='INFO', help='logging level') - self.parser.add_argument("--profile", action="store_true", dest="profile", - help="profile output") + self.parser.add_argument( + '--profile', action='store_true', dest='profile', + help='profile output') def main(): """Main function""" - optmgr = OptionParser() + optmgr = OptionParser() opts = optmgr.parser.parse_args() if opts.profile: - import cProfile # python profiler - import pstats # profiler statistics - cmd = 'runner(opts)' - cProfile.runctx(cmd, globals(), locals(), 'profile.dat') - info = pstats.Stats('profile.dat') + from cProfile import runctx # python profiler + from pstats import Stats # profiler statistics + cmd = 'runner(opts)' + runctx(cmd, globals(), locals(), 'profile.dat') + info = Stats('profile.dat') info.sort_stats('cumulative') info.print_stats() else: @@ -67,7 +68,7 @@ def runner(opts): config = {} with open(opts.config) as file: - config = yaml.safe_load(file) + config = safe_load(file) logger.info(f'Input configuration: {config}\n') pipeline_config = config.get('pipeline', []) objects = [] diff --git a/CHAP/saxswaxs/processor.py b/CHAP/saxswaxs/processor.py index 9a3c53d..971388c 100755 --- a/CHAP/saxswaxs/processor.py +++ b/CHAP/saxswaxs/processor.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python if __name__ == '__main__': from CHAP.processor import main diff --git a/CHAP/saxswaxs/reader.py b/CHAP/saxswaxs/reader.py index 709d3d3..b71407d 100755 --- a/CHAP/saxswaxs/reader.py +++ b/CHAP/saxswaxs/reader.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python if __name__ == '__main__': from CHAP.reader import main diff --git a/CHAP/saxswaxs/writer.py b/CHAP/saxswaxs/writer.py index b00fa9f..f786837 100755 --- a/CHAP/saxswaxs/writer.py +++ b/CHAP/saxswaxs/writer.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python if __name__ == '__main__': from CHAP.writer import main diff --git a/CHAP/sin2psi/processor.py b/CHAP/sin2psi/processor.py index 9a3c53d..971388c 100755 --- a/CHAP/sin2psi/processor.py +++ b/CHAP/sin2psi/processor.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python if __name__ == '__main__': from CHAP.processor import main diff --git a/CHAP/sin2psi/reader.py b/CHAP/sin2psi/reader.py index 709d3d3..b71407d 100755 --- a/CHAP/sin2psi/reader.py +++ b/CHAP/sin2psi/reader.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python if __name__ == '__main__': from CHAP.reader import main diff --git a/CHAP/sin2psi/writer.py b/CHAP/sin2psi/writer.py index b00fa9f..f786837 100755 --- a/CHAP/sin2psi/writer.py +++ b/CHAP/sin2psi/writer.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python if __name__ == '__main__': from CHAP.writer import main diff --git a/CHAP/tomo/models.py b/CHAP/tomo/models.py index 3e2a7e6..62adad0 100644 --- a/CHAP/tomo/models.py +++ b/CHAP/tomo/models.py @@ -1,4 +1,4 @@ -'''Tomography Pydantic model classes''' +"""Tomography Pydantic model classes""" # Third party imports from typing import ( @@ -18,7 +18,7 @@ class Detector(BaseModel): """ Detector class to represent the detector used in the experiment. - + :ivar prefix: Prefix of the detector in the SPEC file. 
:type prefix: str :ivar rows: Number of pixel rows on the detector @@ -33,7 +33,8 @@ class Detector(BaseModel): prefix: constr(strip_whitespace=True, min_length=1) rows: conint(gt=0) columns: conint(gt=0) - pixel_size: conlist(item_type=confloat(gt=0, allow_inf_nan=False), + pixel_size: conlist( + item_type=confloat(gt=0, allow_inf_nan=False), min_items=1, max_items=2) lens_magnification: confloat(gt=0, allow_inf_nan=False) = 1.0 @@ -68,8 +69,8 @@ class TomoReduceConfig(BaseModel): """ tool_type: Literal['reduce_data'] = 'reduce_data' detector: Detector = Detector.construct() - img_x_bounds: Optional[conlist(item_type=conint(ge=0), min_items=2, - max_items=2)] + img_x_bounds: Optional[ + conlist(item_type=conint(ge=0), min_items=2, max_items=2)] class TomoFindCenterConfig(BaseModel): @@ -115,12 +116,12 @@ class TomoReconstructConfig(BaseModel): :type z_bounds: list[int], optional """ tool_type: Literal['reconstruct_data'] = 'reconstruct_data' - x_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, - max_items=2)] - y_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, - max_items=2)] - z_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, - max_items=2)] + x_bounds: Optional[ + conlist(item_type=conint(ge=-1), min_items=2, max_items=2)] + y_bounds: Optional[ + conlist(item_type=conint(ge=-1), min_items=2, max_items=2)] + z_bounds: Optional[ + conlist(item_type=conint(ge=-1), min_items=2, max_items=2)] class TomoCombineConfig(BaseModel): @@ -138,9 +139,9 @@ class TomoCombineConfig(BaseModel): :type z_bounds: list[int], optional """ tool_type: Literal['combine_data'] = 'combine_data' - x_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, - max_items=2)] - y_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, - max_items=2)] - z_bounds: Optional[conlist(item_type=conint(ge=-1), min_items=2, - max_items=2)] + x_bounds: Optional[ + conlist(item_type=conint(ge=-1), min_items=2, max_items=2)] + y_bounds: Optional[ + conlist(item_type=conint(ge=-1), min_items=2, max_items=2)] + z_bounds: Optional[ + conlist(item_type=conint(ge=-1), min_items=2, max_items=2)] diff --git a/CHAP/tomo/processor.py b/CHAP/tomo/processor.py index bac852c..b0f3c0e 100644 --- a/CHAP/tomo/processor.py +++ b/CHAP/tomo/processor.py @@ -20,7 +20,6 @@ from CHAP.common.utils.general import ( is_num, input_int, -# input_num, input_yesno, select_image_bounds, select_one_image_bound, @@ -30,6 +29,7 @@ quick_plot, quick_imshow, ) +# input_num, from CHAP.common.utils.fit import Fit from CHAP.processor import Processor from CHAP.reader import main @@ -202,7 +202,7 @@ def get_configs(self, data): values taken from `data`. :rtype: dict """ - #:rtype: dict{'map': MapConfig, 'reduce': TomoReduceConfig} + # :rtype: dict{'map': MapConfig, 'reduce': TomoReduceConfig} # RV is there a way to denote optional items? 
# Third party modules from nexusformat.nexus import NXroot @@ -396,8 +396,9 @@ def get_nxroot(self, map_config, tool_config): image_offset = scanparser.starting_image_offset if map_config.station in ('id1a3', 'id3a'): theta_vals = scanparser.theta_vals - thetas = np.linspace(theta_vals.get('start'), - theta_vals.get('end'), theta_vals.get('num')) + thetas = np.linspace( + theta_vals.get('start'), theta_vals.get('end'), + theta_vals.get('num')) else: if len(scans.scan_numbers) != 1: raise RuntimeError( @@ -459,7 +460,7 @@ def get_nxroot(self, map_config, tool_config): scanparser.get_detector_data( tool_config.detector.prefix, scan_step_index=(image_offset, - image_offset+num_image))) + image_offset+num_image))) rotation_angles += list(thetas) x_translations += num_image*[x_translation] z_translations += num_image*[z_translation] @@ -634,7 +635,7 @@ def __init__( self._logger.warning( f'num_core = {self._num_core} is larger than the number ' + f'of available processors and reduced to {cpu_count()}') - self._num_core= cpu_count() + self._num_core = cpu_count() def gen_reduced_data(self, data, img_x_bounds=None): """ @@ -775,18 +776,19 @@ def find_centers(self, nxroot, center_rows=None, center_stack_index=None): raise ValueError(f'Invalid nxentry ({nxentry})') if (center_rows is not None and (not isinstance(center_rows, (tuple, list)) - or len(center_rows) != 2)): + or len(center_rows) != 2)): raise ValueError(f'Invalid parameter center_rows ({center_rows})') - if not self._interactive and (center_rows is None - or (center_rows[0] is None and center_rows[1] is None)): + if (not self._interactive + and (center_rows is None + or (center_rows[0] is None and center_rows[1] is None))): self._logger.warning( 'center_rows unspecified, find centers at reduced data bounds') if (center_stack_index is not None and (not isinstance(center_stack_index, int) - or center_stack_index < 0)): + or center_stack_index < 0)): raise ValueError( - 'Invalid parameter center_stack_index ' - + f'({center_stack_index})') + 'Invalid parameter center_stack_index ' + + f'({center_stack_index})') # Check if reduced data is available if ('reduced_data' not in nxentry @@ -831,8 +833,9 @@ def find_centers(self, nxroot, center_rows=None, center_stack_index=None): # Get effective pixel_size if 'zoom_perc' in nxentry.reduced_data: - eff_pixel_size = 100.0 * (nxentry.instrument.detector.x_pixel_size - / nxentry.reduced_data.attrs['zoom_perc']) + eff_pixel_size = \ + 100.0 * (nxentry.instrument.detector.x_pixel_size + / nxentry.reduced_data.attrs['zoom_perc']) else: eff_pixel_size = nxentry.instrument.detector.x_pixel_size @@ -873,10 +876,10 @@ def find_centers(self, nxroot, center_rows=None, center_stack_index=None): f'Invalid parameter center_rows ({center_rows})') t0 = time() lower_center_offset = self._find_center_one_plane( - nxentry.reduced_data.data.tomo_fields[ - center_stack_index,:,lower_row,:], - lower_row, thetas, eff_pixel_size, cross_sectional_dim, - path=self._output_folder, num_core=self._num_core) + nxentry.reduced_data.data.tomo_fields[ + center_stack_index,:,lower_row,:], + lower_row, thetas, eff_pixel_size, cross_sectional_dim, + path=self._output_folder, num_core=self._num_core) self._logger.info(f'Finding center took {time()-t0:.2f} seconds') self._logger.debug(f'lower_row = {lower_row:.2f}') self._logger.debug(f'lower_center_offset = {lower_center_offset:.2f}') @@ -933,7 +936,7 @@ def find_centers(self, nxroot, center_rows=None, center_stack_index=None): # Save test data to file if self._test_mode: with 
open(f'{self._output_folder}/center_config.yaml', 'w', - encoding='utf8') as f: + encoding='utf8') as f: safe_dump(center_config, f) return center_config @@ -1043,16 +1046,16 @@ def reconstruct_data( assert len(thetas) == tomo_stack.shape[1] assert 0 <= lower_row < upper_row < tomo_stack.shape[0] center_offsets = [ - lower_center_offset - lower_row*center_slope, - upper_center_offset + (tomo_stack.shape[0]-1-upper_row) - * center_slope, + lower_center_offset-lower_row*center_slope, + upper_center_offset + center_slope * ( + tomo_stack.shape[0]-1-upper_row), ] t0 = time() tomo_recon_stack = self._reconstruct_one_tomo_stack( tomo_stack, thetas, center_offsets=center_offsets, num_core=self._num_core, algorithm='gridrec') self._logger.info( - f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds') + f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds') # Combine stacks tomo_recon_stacks[i] = tomo_recon_stack @@ -1065,14 +1068,14 @@ def reconstruct_data( z_bounds = None elif self._interactive: x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data( - tomo_recon_stacks, x_bounds=x_bounds, y_bounds=y_bounds, - z_bounds=z_bounds) + tomo_recon_stacks, x_bounds=x_bounds, y_bounds=y_bounds, + z_bounds=z_bounds) else: if x_bounds is None: self._logger.warning( 'x_bounds unspecified, reconstruct data for full x-range') elif not is_int_pair(x_bounds, ge=0, - lt=tomo_recon_stacks[0].shape[1]): + lt=tomo_recon_stacks[0].shape[1]): raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') if y_bounds is None: self._logger.warning( @@ -1146,7 +1149,7 @@ def reconstruct_data( nxprocess.z_bounds = z_bounds nxprocess.data['reconstructed_data'] = np.asarray( [stack[z_range[0]:z_range[1],x_range[0]:x_range[1], - y_range[0]:y_range[1]] for stack in tomo_recon_stacks]) + y_range[0]:y_range[1]] for stack in tomo_recon_stacks]) nxprocess.data.attrs['signal'] = 'reconstructed_data' # Create a copy of the input Nexus object and remove reduced @@ -1246,12 +1249,12 @@ def combine_data( tomo_recon_combined = np.concatenate( [tomo_recon_combined] + [nxentry.reconstructed_data.data.reconstructed_data[i] - for i in range(1, num_tomo_stacks-1)]) + for i in range(1, num_tomo_stacks-1)]) if num_tomo_stacks > 1: tomo_recon_combined = np.concatenate( [tomo_recon_combined] + [nxentry.reconstructed_data.data.reconstructed_data[ - num_tomo_stacks-1]]) + num_tomo_stacks-1]]) self._logger.info( f'Combining the reconstructed stacks took {time()-t0:.2f} seconds') @@ -1304,18 +1307,18 @@ def combine_data( # Plot a few combined image slices if self._save_figs: quick_imshow( - tomo_recon_combined[z_range[0]:z_range[1],x_slice, - y_range[0]:y_range[1]], + tomo_recon_combined[ + z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], title=f'recon combined xslice{x_slice}', path=self._output_folder, save_fig=True, save_only=True) quick_imshow( - tomo_recon_combined[z_range[0]:z_range[1], - x_range[0]:x_range[1],y_slice], + tomo_recon_combined[ + z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], title=f'recon combined yslice{y_slice}', path=self._output_folder, save_fig=True, save_only=True) quick_imshow( - tomo_recon_combined[z_slice,x_range[0]:x_range[1], - y_range[0]:y_range[1]], + tomo_recon_combined[ + z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], title=f'recon combined zslice{z_slice}', path=self._output_folder, save_fig=True, save_only=True) @@ -1324,8 +1327,9 @@ def combine_data( if self._test_mode: np.savetxt( f'{self._output_folder}/recon_combined.txt', - 
tomo_recon_combined[z_slice,x_range[0]:x_range[1], - y_range[0]:y_range[1]], fmt='%.6e') + tomo_recon_combined[ + z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], + fmt='%.6e') # Add image reconstruction to reconstructed data NXprocess # combined data order: row/z,x,y @@ -1375,8 +1379,8 @@ def _gen_dark(self, nxentry, reduced_data): # Get the dark field images image_key = nxentry.instrument.detector.get('image_key', None) if image_key and 'data' in nxentry.instrument.detector: - field_indices = [index for index, key in enumerate(image_key) - if key == 2] + field_indices = [ + index for index, key in enumerate(image_key) if key == 2] tdf_stack = nxentry.instrument.detector.data[field_indices,:,:] # RV the default NXtomo form does not accomodate dark field # stacks @@ -1399,7 +1403,7 @@ def _gen_dark(self, nxentry, reduced_data): detector_prefix, (image_offset, image_offset+num_image))) if isinstance(tdf_stack, list): - assert len(tdf_stack) == 1 # RV + assert len(tdf_stack) == 1 # RV tdf_stack = tdf_stack[0] # Take median @@ -1455,8 +1459,8 @@ def _gen_bright(self, nxentry, reduced_data): # Get the bright field images image_key = nxentry.instrument.detector.get('image_key', None) if image_key and 'data' in nxentry.instrument.detector: - field_indices = [index for index, key in enumerate(image_key) - if key == 1] + field_indices = [ + index for index, key in enumerate(image_key) if key == 1] tbf_stack = nxentry.instrument.detector.data[field_indices,:,:] # RV the default NXtomo form does not accomodate bright # field stacks @@ -1479,7 +1483,7 @@ def _gen_bright(self, nxentry, reduced_data): detector_prefix, (image_offset, image_offset+num_image))) if isinstance(tbf_stack, list): - assert len(tbf_stack) == 1 # RV + assert len(tbf_stack) == 1 # RV tbf_stack = tbf_stack[0] # Take median if more than one image @@ -1544,8 +1548,8 @@ def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): # Get the first tomography image and the reference heights image_key = nxentry.instrument.detector.get('image_key', None) if image_key and 'data' in nxentry.instrument.detector: - field_indices = [index for index, key in enumerate(image_key) - if key == 0] + field_indices = [ + index for index, key in enumerate(image_key) if key == 0] first_image = np.asarray( nxentry.instrument.detector.data[field_indices[0],:,:]) theta = float(nxentry.sample.rotation_angle[field_indices[0]]) @@ -1603,10 +1607,10 @@ def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): sig_low = parameters.get('sigma1', None) sig_upp = parameters.get('sigma2', None) have_fit = (fit.success and x_low_fit is not None - and x_upp_fit is not None and sig_low is not None - and sig_upp is not None - and 0 <= x_low_fit < x_upp_fit <= x_sum.size - and (sig_low+sig_upp) / (x_upp_fit-x_low_fit) < 0.1) + and x_upp_fit is not None and sig_low is not None + and sig_upp is not None + and 0 <= x_low_fit < x_upp_fit <= x_sum.size + and (sig_low+sig_upp) / (x_upp_fit-x_low_fit) < 0.1) if have_fit: # Set a 5% margin on each side margin = 0.05 * (x_upp_fit-x_low_fit) @@ -1667,7 +1671,7 @@ def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): title='sum over theta and y') print(f'lower bound = {x_low} (inclusive)') print(f'upper bound = {x_upp} (exclusive)]') - accept = input_yesno('Accept these bounds (y/n)?', 'y') + accept = input_yesno('Accept these bounds (y/n)?', 'y') clear_imshow('bright field') clear_imshow(title) clear_plot('sum over theta and y') @@ -1684,7 +1688,7 @@ def 
_set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): img_x_bounds = tuple(img_x_bounds[0]) if (num_tomo_stacks > 1 and (img_x_bounds[1]-img_x_bounds[0]+1) - < int((delta_z - 0.5*pixel_size) / pixel_size)): + < int((delta_z - 0.5*pixel_size) / pixel_size)): self._logger.warning( 'Image bounds and pixel size prevent seamless stacking') else: @@ -1776,9 +1780,9 @@ def _gen_tomo(self, nxentry, reduced_data): # Get image bounds img_x_bounds = tuple( - reduced_data.get('img_x_bounds', (0, tbf_shape[0]))) + reduced_data.get('img_x_bounds', (0, tbf_shape[0]))) img_y_bounds = tuple( - reduced_data.get('img_y_bounds', (0, tbf_shape[1]))) + reduced_data.get('img_y_bounds', (0, tbf_shape[1]))) # Get resized dark field # if 'dark_field' in data: @@ -1794,14 +1798,15 @@ def _gen_tomo(self, nxentry, reduced_data): # Resize bright field if (img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1])): - tbf = tbf[img_x_bounds[0]:img_x_bounds[1], + tbf = tbf[ + img_x_bounds[0]:img_x_bounds[1], img_y_bounds[0]:img_y_bounds[1]] # Get the tomography images image_key = nxentry.instrument.detector.get('image_key', None) if image_key and 'data' in nxentry.instrument.detector: - field_indices_all = [index for index, key in enumerate(image_key) - if key == 0] + field_indices_all = [ + index for index, key in enumerate(image_key) if key == 0] z_translation_all = nxentry.sample.z_translation[field_indices_all] z_translation_levels = sorted(list(set(z_translation_all))) num_tomo_stacks = len(z_translation_levels) @@ -1811,7 +1816,8 @@ def _gen_tomo(self, nxentry, reduced_data): thetas = None tomo_stacks = [] for i, z_translation in enumerate(z_translation_levels): - field_indices = [field_indices_all[index] + field_indices = [ + field_indices_all[index] for index, z in enumerate(z_translation_all) if z == z_translation] horizontal_shift = list( @@ -1823,7 +1829,7 @@ def _gen_tomo(self, nxentry, reduced_data): assert len(vertical_shift) == 1 vertical_shifts += vertical_shift sequence_numbers = nxentry.instrument.detector.sequence_number[ - field_indices] + field_indices] if thetas is None: thetas = np.asarray( nxentry.sample.rotation_angle[ @@ -1831,12 +1837,12 @@ def _gen_tomo(self, nxentry, reduced_data): else: assert all( thetas[i] == nxentry.sample.rotation_angle[ - field_indices[index]] + field_indices[index]] for i, index in enumerate(sequence_numbers)) - assert (list(set(sequence_numbers)) == - list(np.arange(0, (len(sequence_numbers))))) - if (list(sequence_numbers) == - list(np.arange(0, (len(sequence_numbers))))): + assert (list(set(sequence_numbers)) + == list(np.arange(0, (len(sequence_numbers))))) + if (list(sequence_numbers) + == list(np.arange(0, (len(sequence_numbers))))): tomo_stack = np.asarray( nxentry.instrument.detector.data[field_indices]) else: @@ -1875,7 +1881,8 @@ def _gen_tomo(self, nxentry, reduced_data): # Right now the range is the same for each set in the stack if (img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1])): - tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1], + tomo_stack = tomo_stack[ + :,img_x_bounds[0]:img_x_bounds[1], img_y_bounds[0]:img_y_bounds[1]].astype('float64') # Subtract dark field @@ -1975,7 +1982,7 @@ def _gen_tomo(self, nxentry, reduced_data): def _find_center_one_plane( self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim, - path=None, num_core=1):#, tol=0.1): + path=None, num_core=1): # , tol=0.1): """Find center for a single tomography plane.""" from tomopy import find_center_vo @@ -2135,7 
+2142,7 @@ def _reconstruct_one_plane( del sinogram # Performing Gaussian filtering and removing ring artifacts - recon_parameters = None #self._config.get('recon_parameters') + recon_parameters = None # self._config.get('recon_parameters') if recon_parameters is None: sigma = 1.0 ring_width = 15 @@ -2168,7 +2175,7 @@ def _plot_edges_one_plane(self, recon_plane, title, path=None): """ from skimage.restoration import denoise_tv_chambolle - vis_parameters = None #self._config.get('vis_parameters') + vis_parameters = None # self._config.get('vis_parameters') if vis_parameters is None: weight = 0.1 else: @@ -2229,7 +2236,7 @@ def _reconstruct_one_tomo_stack( centers += tomo_stack.shape[2]/2 # Get reconstruction parameters - recon_parameters = None #self._config.get('recon_parameters') + recon_parameters = None # self._config.get('recon_parameters') if recon_parameters is None: sigma = 2.0 secondary_iters = 0 @@ -2304,7 +2311,7 @@ def _reconstruct_one_tomo_stack( 'num_iter': secondary_iters, } t0 = time() - tomo_recon_stack = recon( + tomo_recon_stack = recon( tomo_stack, np.radians(thetas), centers, init_recon=tomo_recon_stack, options=options, sinogram_order=True, algorithm=astra, ncore=num_core) @@ -2319,8 +2326,9 @@ def _reconstruct_one_tomo_stack( return tomo_recon_stack - def _resize_reconstructed_data(self, data, x_bounds=None, y_bounds=None, - z_bounds=None, z_only=False): + def _resize_reconstructed_data( + self, data, x_bounds=None, y_bounds=None, z_bounds=None, + z_only=False): """Resize the reconstructed tomography data.""" # Data order: row(z),x,y or stack,row(z),x,y if isinstance(data, list): diff --git a/CHAP/writer.py b/CHAP/writer.py index eafed0a..ebe42b1 100755 --- a/CHAP/writer.py +++ b/CHAP/writer.py @@ -7,9 +7,9 @@ # system modules import argparse -import inspect +from inspect import getfullargspec import logging -import sys +from sys import modules from time import time @@ -32,11 +32,11 @@ def write(self, data, filename, **_write_kwargs): t0 = time() self.logger.info(f'Executing "write" with filename={filename}, ' - + f'type(data)={type(data)}, kwargs={_write_kwargs}') + f'type(data)={type(data)}, kwargs={_write_kwargs}') _valid_write_args = {} - allowed_args = inspect.getfullargspec(self._write).args \ - + inspect.getfullargspec(self._write).kwonlyargs + allowed_args = getfullargspec(self._write).args \ + + getfullargspec(self._write).kwonlyargs for k, v in _write_kwargs.items(): if k in allowed_args: _valid_write_args[k] = v @@ -54,6 +54,7 @@ def _write(self, data, filename): file.write(data) return data + class OptionParser(): """User based option parser""" def __init__(self): @@ -71,17 +72,18 @@ def __init__(self): '--log-level', choices=logging._nameToLevel.keys(), dest='log_level', default='INFO', help='logging level') + def main(opt_parser=OptionParser): """Main function""" - optmgr = opt_parser() + optmgr = opt_parser() opts = optmgr.parser.parse_args() cls_name = opts.writer try: - writer_cls = getattr(sys.modules[__name__],cls_name) - except: + writer_cls = getattr(modules[__name__], cls_name) + except AttributeError: print(f'Unsupported writer {cls_name}') - sys.exit(1) + raise writer = writer_cls() writer.logger.setLevel(getattr(logging, opts.log_level)) @@ -92,5 +94,6 @@ def main(opt_parser=OptionParser): data = writer.write(opts.data, opts.filename) print(f'Writer {writer} writes to {opts.filename}, data {data}') + if __name__ == '__main__': main() From 66daaa41472321913e96cc980cdb26dea1196f27 Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Fri, 21 Apr 
2023 12:02:38 -0400 Subject: [PATCH 6/6] fix: fixed yaml import and interactive keyword bugs due to merge in CHAP/runner.py --- CHAP/runner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/CHAP/runner.py b/CHAP/runner.py index 611499f..d1cee14 100755 --- a/CHAP/runner.py +++ b/CHAP/runner.py @@ -59,10 +59,10 @@ def runner(opts): logger, log_handler = setLogger(log_level) config = {} with open(opts.config) as file: - config = yaml.safe_load(file) + config = safe_load(file) logger.info(f'Input configuration: {config}\n') pipeline_config = config.get('pipeline', []) - run(pipeline_config, logger, log_level, log_handler) + run(pipeline_config, opts.interactive, logger, log_level, log_handler) def setLogger(log_level="INFO"): """ @@ -79,7 +79,7 @@ def setLogger(log_level="INFO"): logger.addHandler(log_handler) return logger, log_handler -def run(pipeline_config, logger=None, log_level=None, log_handler=None): +def run(pipeline_config, interactive=False, logger=None, log_level=None, log_handler=None): """ Run given pipeline_config @@ -89,7 +89,7 @@ def run(pipeline_config, logger=None, log_level=None, log_handler=None): kwds = [] for item in pipeline_config: # load individual object with given name from its module - kwargs = {'interactive': opts.interactive} + kwargs = {'interactive': interactive} if isinstance(item, dict): name = list(item.keys())[0] # Combine the "interactive" command line argument with the object's keywords
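
Note on the fix above: run() previously read opts.interactive even though opts only exists inside runner(), and the yaml module was no longer imported under that name after the PEP8 change to "from yaml import safe_load". The patch therefore calls safe_load() directly and threads the flag through as an explicit interactive parameter (defaulting to False), which is then merged into each pipeline object's keyword arguments. A minimal sketch of driving the patched entry points directly, assuming CHAP is importable and using a hypothetical pipeline.yaml path; it mirrors the patched runner.py logic rather than replacing it:

    # Sketch only: calls the patched setLogger()/run() directly, forwarding the
    # interactive flag instead of relying on opts being visible inside run().
    from yaml import safe_load
    from CHAP.runner import run, setLogger  # names taken from the patched module

    def run_from_config(config_path, interactive=False, log_level='INFO'):
        logger, log_handler = setLogger(log_level)
        with open(config_path) as file:
            config = safe_load(file)
        pipeline_config = config.get('pipeline', [])
        # interactive is passed explicitly; run() no longer touches opts
        run(pipeline_config, interactive, logger, log_level, log_handler)

    # Example invocation (hypothetical config file name):
    # run_from_config('pipeline.yaml', interactive=True)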