Skip to content

Commit

Permalink
Modify region definition config in extended_psf
Browse files Browse the repository at this point in the history
  • Loading branch information
bazkiaei committed Oct 8, 2023
1 parent f2769da commit b49c4eb
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 75 deletions.
186 changes: 120 additions & 66 deletions python/lsst/pipe/tasks/extended_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,62 @@
"StackBrightStarsTask",
"MeasureExtendedPsfConfig",
"MeasureExtendedPsfTask",
"DetectorsInRegion",
]

from dataclasses import dataclass
from typing import List

from lsst.afw.fits import Fits, readMetadata
from lsst.afw.image import ImageF, MaskedImageF, MaskX
from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty
from lsst.daf.base import PropertyList
from lsst.geom import Extent2I
from lsst.pex.config import ChoiceField, Config, ConfigurableField, DictField, Field, ListField
from lsst.pex.config import ChoiceField, Config, ConfigurableField, Field, ListField, ConfigDictField
from lsst.pipe.base import PipelineTaskConfig, PipelineTaskConnections, Struct, Task
from lsst.pipe.tasks.coaddBase import subBBoxIter
from lsst.pipe.base.connectionTypes import Input, Output


def find_region_for_detector(det_id, detectors_focal_plane_regions):
"""Find the focal plane region that contains a given detector.
Parameters
----------
det_id : int
The detector ID.
detectors_focal_plane_regions : dict
A dictionary containing focal plane region names as keys, and the
corresponding detector IDs as the values.
Returns
-------
_type_
_description_
Raises
------
KeyError
_description_
"""
for key, value_list in detectors_focal_plane_regions.items():
if det_id in value_list.detectors:
return key
raise KeyError(
"Detector %d is not included in any focal plane region.",
det_id,
)


class DetectorsInRegion(Config):
"""Provides a list of detectors that define a region.
"""
detectors = ListField[int](
doc="A list containing the detectors IDs.",
default=None,
)


@dataclass
class FocalPlaneRegionExtendedPsf:
"""Single extended PSF over a focal plane region.
Expand All @@ -54,13 +94,13 @@ class FocalPlaneRegionExtendedPsf:
----------
extended_psf_image : `lsst.afw.image.MaskedImageF`
Image of the extended PSF model.
detector_list : `list` [`int`]
region_detectors : `ListField` [`int`]
List of detector IDs that define the focal plane region over which this
extended PSF model has been built (and can be used).
"""

extended_psf_image: MaskedImageF
detector_list: List[int]
region_detectors: DetectorsInRegion


class ExtendedPsf:
Expand All @@ -81,7 +121,7 @@ def __init__(self, default_extended_psf=None):
self.focal_plane_regions = {}
self.detectors_focal_plane_regions = {}

def add_regional_extended_psf(self, extended_psf_image, region_name, detector_list):
def add_regional_extended_psf(self, extended_psf_image, region_name, region_detectors):
"""Add a new focal plane region, along wit hits extended PSF, to the
ExtendedPsf instance.
Expand All @@ -91,17 +131,16 @@ def add_regional_extended_psf(self, extended_psf_image, region_name, detector_li
Extended PSF model for the region.
region_name : `str`
Name of the focal plane region. Will be converted to all-uppercase.
detector_list : `list` [`int`]
List of IDs for the detectors that define the focal plane region.
region_detectors : `ListField` [`int`]
List of detector IDs for the detectors that define the focal plane region.
"""
region_name = region_name.upper()
if region_name in self.focal_plane_regions:
raise ValueError(f"Region name {region_name} is already used by this ExtendedPsf instance.")
self.focal_plane_regions[region_name] = FocalPlaneRegionExtendedPsf(
extended_psf_image=extended_psf_image, detector_list=detector_list
extended_psf_image=extended_psf_image, region_detectors=region_detectors
)
for det in detector_list:
self.detectors_focal_plane_regions[det] = region_name
self.detectors_focal_plane_regions[region_name] = region_detectors

def __call__(self, detector=None):
"""Return the appropriate extended PSF.
Expand Down Expand Up @@ -130,7 +169,7 @@ def __call__(self, detector=None):
return self.default_extended_psf
elif not self.focal_plane_regions:
return self.default_extended_psf
return self.get_regional_extended_psf(detector=detector)
return self.get_extended_psf(region_name=detector)

def __len__(self):
"""Returns the number of extended PSF models present in the instance.
Expand All @@ -146,31 +185,38 @@ def __len__(self):
n_regions += 1
return n_regions

def get_regional_extended_psf(self, region_name=None, detector=None):
"""Returns the extended PSF for a focal plane region.
The region can be identified either by name, or through a detector ID.
def get_extended_psf(self, region_name):
"""Returns the extended PSF for a focal plane region or detector.
Parameters
----------
region_name : `str` or `None`, optional
Name of the region for which the extended PSF should be retrieved.
Ignored if ``detector`` is provided. Must be provided if
``detector`` is None.
detector : `int` or `None`, optional
If provided, returns the extended PSF for the focal plane region
that includes this detector.
region_name : `str` or `int`
Name of the region (str) or detector (int) for which the extended
PSF should be retrieved.
Returns
-------
lsst.afw.image._maskedImage.MaskedImageF
The extended PSF model for the requested region or detector.
Raises
------
ValueError
Raised if neither ``detector`` nor ``regionName`` is provided.
Raised if the input is not in the correct type.
Notes
-----
This method takes either a region name or a detector ID as input. If
the input is a `str` type, it is assumed to be the region name and if
the input is a `int` type it is assumed to be the detector ID.
"""
if detector is None:
if region_name is None:
raise ValueError("One of either a regionName or a detector number must be provided.")
if isinstance(region_name, str):
return self.focal_plane_regions[region_name].extended_psf_image
elif isinstance(region_name, int):
region_name = find_region_for_detector(region_name, self.detectors_focal_plane_regions)
return self.focal_plane_regions[region_name].extended_psf_image
return self.focal_plane_regions[self.detectors_focal_plane_regions[detector]].extended_psf_image
else:
raise ValueError("A region name with `str` type or detector number with `int` must be provided")

def write_fits(self, filename):
"""Write this object to a file.
Expand All @@ -187,7 +233,7 @@ def write_fits(self, filename):
metadata["HAS_REGIONS"] = True
metadata["REGION_NAMES"] = list(self.focal_plane_regions.keys())
for region, e_psf_region in self.focal_plane_regions.items():
metadata[region] = e_psf_region.detector_list
metadata[region] = e_psf_region.region_detectors.detectors
else:
metadata["HAS_REGIONS"] = False
fits_primary = Fits(filename, "w")
Expand Down Expand Up @@ -260,8 +306,9 @@ def read_fits(cls, filename):
# Generate extended PSF regions mappings.
for r_name in focal_plane_region_names:
extended_psf_image = MaskedImageF(**extended_psf_parts[r_name])
detector_list = global_metadata.getArray(r_name)
extended_psf.add_regional_extended_psf(extended_psf_image, r_name, detector_list)
region_detectors = DetectorsInRegion()
region_detectors.detectors = global_metadata.getArray(r_name)
extended_psf.add_regional_extended_psf(extended_psf_image, r_name, region_detectors)
# Instantiate ExtendedPsf.
return extended_psf

Expand All @@ -280,7 +327,7 @@ class StackBrightStarsConfig(Config):
)
stacking_statistic = ChoiceField[str](
doc="Type of statistic to use for stacking.",
default="MEANCLIP",
default="MEDIAN",
allowed={
"MEAN": "mean",
"MEDIAN": "median",
Expand Down Expand Up @@ -413,9 +460,9 @@ class MeasureExtendedPsfConfig(PipelineTaskConfig, pipelineConnections=MeasureEx
target=StackBrightStarsTask,
doc="Stack selected bright stars",
)
detectors_focal_plane_regions = DictField(
keytype=int,
itemtype=str,
detectors_focal_plane_regions = ConfigDictField(
keytype=str,
itemtype=DetectorsInRegion,
doc=(
"Mapping from detector IDs to focal plane region names. If empty, a constant extended PSF model "
"is built from all selected bright stars."
Expand All @@ -442,17 +489,32 @@ class MeasureExtendedPsfTask(Task):
def __init__(self, initInputs=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.makeSubtask("stack_bright_stars")
self.focal_plane_regions = {
region: [] for region in set(self.config.detectors_focal_plane_regions.values())
}
for det, region in self.config.detectors_focal_plane_regions.items():
self.focal_plane_regions[region].append(det)
# make no assumption on what detector IDs should be, but if we come
# across one where there are processed bright stars, but no
# corresponding focal plane region, make sure we keep track of
# it (eg to raise a warning only once)
self._set_detectors_focal_plane_regions()
self.regionless_dets = []

def _set_detectors_focal_plane_regions(self):
"""Set the mapping from detector IDs to focal plane regions."""
if not self.config.detectors_focal_plane_regions:
self._set_default_detectors_focal_plane_regions()
else:
self.detectors_focal_plane_regions = self.config.detectors_focal_plane_regions

def _set_default_detectors_focal_plane_regions(self):
"""Set the default mapping from detector IDs to focal plane regions."""
self.detectors_focal_plane_regions = {}
# The range for this for loop needs some more thoughts! We might
# want to use the instrument property to read how many detectors
# are there!
# TODO: DM-41101 Modify the default per-detector regions
for i in range(104):
det = DetectorsInRegion()
det.detectors = [i]
self.detectors_focal_plane_regions[f"region{i:02d}"] = det
self.log.info(
"No detector groups were provided to MeasureExtendedPsfTask; computing a single, "
"extended PSF model for each detector that has available observations."
)

def select_detector_refs(self, ref_list):
"""Split available sets of bright star stamps according to focal plane
regions.
Expand All @@ -463,13 +525,13 @@ def select_detector_refs(self, ref_list):
`lsst.daf.butler._deferredDatasetHandle.DeferredDatasetHandle`
List of available bright star stamps data references.
"""
region_ref_list = {region: [] for region in self.focal_plane_regions.keys()}
region_ref_list = {region: [] for region in self.detectors_focal_plane_regions.keys()}
for dataset_handle in ref_list:
det_id = dataset_handle.ref.dataId["detector"]
if det_id in self.regionless_dets:
continue
try:
region_name = self.config.detectors_focal_plane_regions[det_id]
region_name = find_region_for_detector(det_id, self.detectors_focal_plane_regions)
except KeyError:
self.log.warning(
"Bright stars were available for detector %d, but it was missing from the %s config "
Expand All @@ -485,27 +547,19 @@ def select_detector_refs(self, ref_list):
def runQuantum(self, butlerQC, inputRefs, outputRefs):
input_data = butlerQC.get(inputRefs)
bss_ref_list = input_data["input_brightStarStamps"]
# Handle default case of a single region with empty detector list
if not self.config.detectors_focal_plane_regions:
self.log.info(
"No detector groups were provided to MeasureExtendedPsfTask; computing a single, "
"constant extended PSF model over all available observations."
)
output_e_psf = ExtendedPsf(self.stack_bright_stars.run(bss_ref_list))
else:
output_e_psf = ExtendedPsf()
region_ref_list = self.select_detector_refs(bss_ref_list)
for region_name, ref_list in region_ref_list.items():
if not ref_list:
# no valid references found
self.log.warning(
"No valid brightStarStamps reference found for region '%s'; skipping it.",
region_name,
)
continue
ext_psf = self.stack_bright_stars.run(ref_list, region_name)
output_e_psf.add_regional_extended_psf(
ext_psf, region_name, self.focal_plane_regions[region_name]
output_e_psf = ExtendedPsf()
region_ref_list = self.select_detector_refs(bss_ref_list)
for region_name, ref_list in region_ref_list.items():
if not ref_list:
# no valid references found
self.log.warning(
"No valid brightStarStamps reference found for region '%s'; skipping it.",
region_name,
)
continue
ext_psf = self.stack_bright_stars.run(ref_list, region_name)
output_e_psf.add_regional_extended_psf(
ext_psf, region_name, self.detectors_focal_plane_regions[region_name]
)
output = Struct(extended_psf=output_e_psf)
butlerQC.put(output, outputRefs)
25 changes: 16 additions & 9 deletions tests/test_extended_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,20 @@ def setUp(self):
self.default_e_psf = make_extended_psf(1)[0]
self.constant_e_psf = extended_psf.ExtendedPsf(self.default_e_psf)
self.regions = ["NW", "SW", "E"]
self.region_detectors = [list(range(10)), list(range(10, 20)), list(range(20, 40))]
self.region_detectors = []
for i in range(3):
self.det = extended_psf.DetectorsInRegion()
r0 = 10*i
r1 = 10*(i+1)
self.det.detectors = list(range(r0, r1))
self.region_detectors.append(self.det)
self.regional_e_psfs = make_extended_psf(3)

def tearDown(self):
del self.default_e_psf
del self.regions
del self.region_detectors
del self.det
del self.regional_e_psfs

def test_constant_psf(self):
Expand Down Expand Up @@ -79,20 +86,20 @@ def test_regional_psf_addition(self):
self.assertEqual(len(with_default_e_psf), 3)
# Ensure we recover the correct regional PSF.
for j in range(2):
for det in self.region_detectors[j]:
for det in self.region_detectors[j].detectors:
# Try it by calling the class directly.
reg_psf0, reg_psf1 = starts_empty_e_psf(det), with_default_e_psf(det)
self.assertMaskedImagesAlmostEqual(reg_psf0, self.regional_e_psfs[j])
self.assertMaskedImagesAlmostEqual(reg_psf1, self.regional_e_psfs[j])
# Try it by passing on a detector number to the
# get_regional_extended_psf method.
reg_psf0 = starts_empty_e_psf.get_regional_extended_psf(detector=det)
reg_psf1 = with_default_e_psf.get_regional_extended_psf(detector=det)
# get_extended_psf method.
reg_psf0 = starts_empty_e_psf.get_extended_psf(region_name=det)
reg_psf1 = with_default_e_psf.get_extended_psf(region_name=det)
self.assertMaskedImagesAlmostEqual(reg_psf0, self.regional_e_psfs[j])
self.assertMaskedImagesAlmostEqual(reg_psf1, self.regional_e_psfs[j])
# Try it by passing on a region name.
reg_psf0 = starts_empty_e_psf.get_regional_extended_psf(region_name=self.regions[j])
reg_psf1 = with_default_e_psf.get_regional_extended_psf(region_name=self.regions[j])
reg_psf0 = starts_empty_e_psf.get_extended_psf(region_name=self.regions[j])
reg_psf1 = with_default_e_psf.get_extended_psf(region_name=self.regions[j])
self.assertMaskedImagesAlmostEqual(reg_psf0, self.regional_e_psfs[j])
self.assertMaskedImagesAlmostEqual(reg_psf1, self.regional_e_psfs[j])
# Ensure we recover the original default PSF.
Expand All @@ -118,7 +125,7 @@ def test_IO(self):
self.assertMaskedImagesAlmostEqual(per_region_e_psf0(), read_e_psf0())
# And per-region extended PSFs.
for j in range(3):
for det in self.region_detectors[j]:
for det in self.region_detectors[j].detectors:
reg_psf0, read_reg_psf0 = per_region_e_psf0(det), read_e_psf0(det)
self.assertMaskedImagesAlmostEqual(reg_psf0, read_reg_psf0)
# Test IO with a single per-region extended PSF.
Expand All @@ -130,7 +137,7 @@ def test_IO(self):
read_e_psf1 = extended_psf.ExtendedPsf.readFits(f.name)
self.assertEqual(per_region_e_psf0.detectors_focal_plane_regions,
read_e_psf0.detectors_focal_plane_regions)
for det in self.region_detectors[1]:
for det in self.region_detectors[1].detectors:
reg_psf1, read_reg_psf1 = per_region_e_psf1(det), read_e_psf1(det)
self.assertMaskedImagesAlmostEqual(reg_psf1, read_reg_psf1)

Expand Down

0 comments on commit b49c4eb

Please sign in to comment.