diff --git a/python/lsst/pipe/tasks/extended_psf.py b/python/lsst/pipe/tasks/extended_psf.py index d480a3aa51..073ce538b0 100644 --- a/python/lsst/pipe/tasks/extended_psf.py +++ b/python/lsst/pipe/tasks/extended_psf.py @@ -28,20 +28,61 @@ "StackBrightStarsTask", "MeasureExtendedPsfConfig", "MeasureExtendedPsfTask", + "DetectorsInRegion", ] from dataclasses import dataclass -from typing import List from lsst.afw.fits import Fits, readMetadata from lsst.afw.image import ImageF, MaskedImageF, MaskX from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty from lsst.daf.base import PropertyList from lsst.geom import Extent2I -from lsst.pex.config import ChoiceField, Config, ConfigurableField, DictField, Field, ListField +from lsst.pex.config import ChoiceField, Config, ConfigDictField, ConfigurableField, Field, ListField from lsst.pipe.base import PipelineTaskConfig, PipelineTaskConnections, Struct, Task -from lsst.pipe.tasks.coaddBase import subBBoxIter from lsst.pipe.base.connectionTypes import Input, Output +from lsst.pipe.tasks.coaddBase import subBBoxIter + + +def find_region_for_detector(detector_id, detectors_focal_plane_regions): + """Find the focal plane region that contains a given detector. + + Parameters + ---------- + detector_id : `int` + The detector ID. + + detectors_focal_plane_regions : + `dict` [`str`, `lsst.pipe.tasks.extended_psf.DetectorsInRegion`] + A dictionary containing focal plane region names as keys, and the + corresponding detector IDs encoded within the values. + + Returns + ------- + key: `str` + The name of the region to which the given detector belongs. + + Raises + ------ + KeyError + Raised if the given detector is not included in any focal plane region. + """ + for region_id, detectors_in_region in detectors_focal_plane_regions.items(): + if detector_id in detectors_in_region.detectors: + return region_id + raise KeyError( + "Detector %d is not included in any focal plane region.", + detector_id, + ) + + +class DetectorsInRegion(Config): + """Provides a list of detectors that define a region.""" + + detectors = ListField[int]( + doc="A list containing the detectors IDs.", + default=[], + ) @dataclass @@ -54,13 +95,13 @@ class FocalPlaneRegionExtendedPsf: ---------- extended_psf_image : `lsst.afw.image.MaskedImageF` Image of the extended PSF model. - detector_list : `list` [`int`] + region_detectors : `lsst.pipe.tasks.extended_psf.DetectorsInRegion` List of detector IDs that define the focal plane region over which this extended PSF model has been built (and can be used). """ extended_psf_image: MaskedImageF - detector_list: List[int] + region_detectors: DetectorsInRegion class ExtendedPsf: @@ -81,8 +122,8 @@ def __init__(self, default_extended_psf=None): self.focal_plane_regions = {} self.detectors_focal_plane_regions = {} - def add_regional_extended_psf(self, extended_psf_image, region_name, detector_list): - """Add a new focal plane region, along wit hits extended PSF, to the + def add_regional_extended_psf(self, extended_psf_image, region_name, region_detectors): + """Add a new focal plane region, along with its extended PSF, to the ExtendedPsf instance. Parameters @@ -91,17 +132,17 @@ def add_regional_extended_psf(self, extended_psf_image, region_name, detector_li Extended PSF model for the region. region_name : `str` Name of the focal plane region. Will be converted to all-uppercase. - detector_list : `list` [`int`] - List of IDs for the detectors that define the focal plane region. + region_detectors : `lsst.pipe.tasks.extended_psf.DetectorsInRegion` + List of detector IDs for the detectors that define a region on the + focal plane. """ region_name = region_name.upper() if region_name in self.focal_plane_regions: raise ValueError(f"Region name {region_name} is already used by this ExtendedPsf instance.") self.focal_plane_regions[region_name] = FocalPlaneRegionExtendedPsf( - extended_psf_image=extended_psf_image, detector_list=detector_list + extended_psf_image=extended_psf_image, region_detectors=region_detectors ) - for det in detector_list: - self.detectors_focal_plane_regions[det] = region_name + self.detectors_focal_plane_regions[region_name] = region_detectors def __call__(self, detector=None): """Return the appropriate extended PSF. @@ -130,7 +171,7 @@ def __call__(self, detector=None): return self.default_extended_psf elif not self.focal_plane_regions: return self.default_extended_psf - return self.get_regional_extended_psf(detector=detector) + return self.get_extended_psf(region_name=detector) def __len__(self): """Returns the number of extended PSF models present in the instance. @@ -146,31 +187,36 @@ def __len__(self): n_regions += 1 return n_regions - def get_regional_extended_psf(self, region_name=None, detector=None): - """Returns the extended PSF for a focal plane region. + def get_extended_psf(self, region_name): + """Returns the extended PSF for a focal plane region or detector. - The region can be identified either by name, or through a detector ID. + This method takes either a region name or a detector ID as input. If + the input is a `str` type, it is assumed to be the region name and if + the input is a `int` type it is assumed to be the detector ID. Parameters ---------- - region_name : `str` or `None`, optional - Name of the region for which the extended PSF should be retrieved. - Ignored if ``detector`` is provided. Must be provided if - ``detector`` is None. - detector : `int` or `None`, optional - If provided, returns the extended PSF for the focal plane region - that includes this detector. + region_name : `str` or `int` + Name of the region (str) or detector (int) for which the extended + PSF should be retrieved. + + Returns + ------- + extended_psf_image: `lsst.afw.image.MaskedImageF` + The extended PSF model for the requested region or detector. Raises ------ ValueError - Raised if neither ``detector`` nor ``regionName`` is provided. + Raised if the input is not in the correct type. """ - if detector is None: - if region_name is None: - raise ValueError("One of either a regionName or a detector number must be provided.") + if isinstance(region_name, str): return self.focal_plane_regions[region_name].extended_psf_image - return self.focal_plane_regions[self.detectors_focal_plane_regions[detector]].extended_psf_image + elif isinstance(region_name, int): + region_name = find_region_for_detector(region_name, self.detectors_focal_plane_regions) + return self.focal_plane_regions[region_name].extended_psf_image + else: + raise ValueError("A region name with `str` type or detector number with `int` must be provided") def write_fits(self, filename): """Write this object to a file. @@ -187,7 +233,7 @@ def write_fits(self, filename): metadata["HAS_REGIONS"] = True metadata["REGION_NAMES"] = list(self.focal_plane_regions.keys()) for region, e_psf_region in self.focal_plane_regions.items(): - metadata[region] = e_psf_region.detector_list + metadata[region] = e_psf_region.region_detectors.detectors else: metadata["HAS_REGIONS"] = False fits_primary = Fits(filename, "w") @@ -260,8 +306,9 @@ def read_fits(cls, filename): # Generate extended PSF regions mappings. for r_name in focal_plane_region_names: extended_psf_image = MaskedImageF(**extended_psf_parts[r_name]) - detector_list = global_metadata.getArray(r_name) - extended_psf.add_regional_extended_psf(extended_psf_image, r_name, detector_list) + region_detectors = DetectorsInRegion() + region_detectors.detectors = global_metadata.getArray(r_name) + extended_psf.add_regional_extended_psf(extended_psf_image, r_name, region_detectors) # Instantiate ExtendedPsf. return extended_psf @@ -413,12 +460,13 @@ class MeasureExtendedPsfConfig(PipelineTaskConfig, pipelineConnections=MeasureEx target=StackBrightStarsTask, doc="Stack selected bright stars", ) - detectors_focal_plane_regions = DictField( - keytype=int, - itemtype=str, + detectors_focal_plane_regions = ConfigDictField( + keytype=str, + itemtype=DetectorsInRegion, doc=( - "Mapping from detector IDs to focal plane region names. If empty, a constant extended PSF model " - "is built from all selected bright stars." + "Mapping from focal plane region names to detector IDs. " + "If empty, a constant extended PSF model is built from all selected bright stars. " + "It's possible for a single detector to be included in multiple regions if so desired." ), default={}, ) @@ -442,15 +490,7 @@ class MeasureExtendedPsfTask(Task): def __init__(self, initInputs=None, *args, **kwargs): super().__init__(*args, **kwargs) self.makeSubtask("stack_bright_stars") - self.focal_plane_regions = { - region: [] for region in set(self.config.detectors_focal_plane_regions.values()) - } - for det, region in self.config.detectors_focal_plane_regions.items(): - self.focal_plane_regions[region].append(det) - # make no assumption on what detector IDs should be, but if we come - # across one where there are processed bright stars, but no - # corresponding focal plane region, make sure we keep track of - # it (eg to raise a warning only once) + self.detectors_focal_plane_regions = self.config.detectors_focal_plane_regions self.regionless_dets = [] def select_detector_refs(self, ref_list): @@ -463,21 +503,21 @@ def select_detector_refs(self, ref_list): `lsst.daf.butler._deferredDatasetHandle.DeferredDatasetHandle` List of available bright star stamps data references. """ - region_ref_list = {region: [] for region in self.focal_plane_regions.keys()} + region_ref_list = {region: [] for region in self.detectors_focal_plane_regions.keys()} for dataset_handle in ref_list: - det_id = dataset_handle.ref.dataId["detector"] - if det_id in self.regionless_dets: + detector_id = dataset_handle.ref.dataId["detector"] + if detector_id in self.regionless_dets: continue try: - region_name = self.config.detectors_focal_plane_regions[det_id] + region_name = find_region_for_detector(detector_id, self.detectors_focal_plane_regions) except KeyError: self.log.warning( "Bright stars were available for detector %d, but it was missing from the %s config " "field, so they will not be used to build any of the extended PSF models.", - det_id, + detector_id, "'detectors_focal_plane_regions'", ) - self.regionless_dets.append(det_id) + self.regionless_dets.append(detector_id) continue region_ref_list[region_name].append(dataset_handle) return region_ref_list @@ -485,7 +525,6 @@ def select_detector_refs(self, ref_list): def runQuantum(self, butlerQC, inputRefs, outputRefs): input_data = butlerQC.get(inputRefs) bss_ref_list = input_data["input_brightStarStamps"] - # Handle default case of a single region with empty detector list if not self.config.detectors_focal_plane_regions: self.log.info( "No detector groups were provided to MeasureExtendedPsfTask; computing a single, " @@ -505,7 +544,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): continue ext_psf = self.stack_bright_stars.run(ref_list, region_name) output_e_psf.add_regional_extended_psf( - ext_psf, region_name, self.focal_plane_regions[region_name] + ext_psf, region_name, self.detectors_focal_plane_regions[region_name] ) output = Struct(extended_psf=output_e_psf) butlerQC.put(output, outputRefs) diff --git a/tests/test_extended_psf.py b/tests/test_extended_psf.py index 4ae7ba1fb6..95219d4e1a 100644 --- a/tests/test_extended_psf.py +++ b/tests/test_extended_psf.py @@ -45,13 +45,20 @@ def setUp(self): self.default_e_psf = make_extended_psf(1)[0] self.constant_e_psf = extended_psf.ExtendedPsf(self.default_e_psf) self.regions = ["NW", "SW", "E"] - self.region_detectors = [list(range(10)), list(range(10, 20)), list(range(20, 40))] + self.region_detectors = [] + for i in range(3): + self.det = extended_psf.DetectorsInRegion() + r0 = 10*i + r1 = 10*(i+1) + self.det.detectors = list(range(r0, r1)) + self.region_detectors.append(self.det) self.regional_e_psfs = make_extended_psf(3) def tearDown(self): del self.default_e_psf del self.regions del self.region_detectors + del self.det del self.regional_e_psfs def test_constant_psf(self): @@ -79,20 +86,20 @@ def test_regional_psf_addition(self): self.assertEqual(len(with_default_e_psf), 3) # Ensure we recover the correct regional PSF. for j in range(2): - for det in self.region_detectors[j]: + for det in self.region_detectors[j].detectors: # Try it by calling the class directly. reg_psf0, reg_psf1 = starts_empty_e_psf(det), with_default_e_psf(det) self.assertMaskedImagesAlmostEqual(reg_psf0, self.regional_e_psfs[j]) self.assertMaskedImagesAlmostEqual(reg_psf1, self.regional_e_psfs[j]) # Try it by passing on a detector number to the - # get_regional_extended_psf method. - reg_psf0 = starts_empty_e_psf.get_regional_extended_psf(detector=det) - reg_psf1 = with_default_e_psf.get_regional_extended_psf(detector=det) + # get_extended_psf method. + reg_psf0 = starts_empty_e_psf.get_extended_psf(region_name=det) + reg_psf1 = with_default_e_psf.get_extended_psf(region_name=det) self.assertMaskedImagesAlmostEqual(reg_psf0, self.regional_e_psfs[j]) self.assertMaskedImagesAlmostEqual(reg_psf1, self.regional_e_psfs[j]) # Try it by passing on a region name. - reg_psf0 = starts_empty_e_psf.get_regional_extended_psf(region_name=self.regions[j]) - reg_psf1 = with_default_e_psf.get_regional_extended_psf(region_name=self.regions[j]) + reg_psf0 = starts_empty_e_psf.get_extended_psf(region_name=self.regions[j]) + reg_psf1 = with_default_e_psf.get_extended_psf(region_name=self.regions[j]) self.assertMaskedImagesAlmostEqual(reg_psf0, self.regional_e_psfs[j]) self.assertMaskedImagesAlmostEqual(reg_psf1, self.regional_e_psfs[j]) # Ensure we recover the original default PSF. @@ -118,7 +125,7 @@ def test_IO(self): self.assertMaskedImagesAlmostEqual(per_region_e_psf0(), read_e_psf0()) # And per-region extended PSFs. for j in range(3): - for det in self.region_detectors[j]: + for det in self.region_detectors[j].detectors: reg_psf0, read_reg_psf0 = per_region_e_psf0(det), read_e_psf0(det) self.assertMaskedImagesAlmostEqual(reg_psf0, read_reg_psf0) # Test IO with a single per-region extended PSF. @@ -130,7 +137,7 @@ def test_IO(self): read_e_psf1 = extended_psf.ExtendedPsf.readFits(f.name) self.assertEqual(per_region_e_psf0.detectors_focal_plane_regions, read_e_psf0.detectors_focal_plane_regions) - for det in self.region_detectors[1]: + for det in self.region_detectors[1].detectors: reg_psf1, read_reg_psf1 = per_region_e_psf1(det), read_e_psf1(det) self.assertMaskedImagesAlmostEqual(reg_psf1, read_reg_psf1)