Skip to content

Commit

Permalink
Merge pull request #709 from lsst/tickets/DM-40451
Browse files Browse the repository at this point in the history
DM-40451: Update MultibandExposure.computePsfImage
  • Loading branch information
fred3m authored Sep 26, 2023
2 parents 4728e35 + 4865523 commit b62437f
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 34 deletions.
90 changes: 60 additions & 30 deletions python/lsst/afw/image/_exposure/_multiband.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,64 +27,93 @@
from lsst.pex.exceptions import InvalidParameterError
from . import Exposure, ExposureF
from ..utils import projectImage
from .._image._multiband import MultibandTripleBase, MultibandPixel
from .._image._multiband import MultibandTripleBase, MultibandPixel, MultibandImage
from .._image._multiband import tripleFromSingles, tripleFromArrays, makeTripleFromKwargs
from .._maskedImage import MaskedImage


class IncompleteDataError(Exception):
"""The PSF could not be computed due to incomplete data
Attributes
----------
missingBands: `list[str]`
The bands for which the PSF could not be calculated.
position: `Point2D`
The point at which the PSF could not be calcualted in the
missing bands.
partialPsf: `MultibandImage`
The image of the PSF using only the bands that successfully
computed a PSF image.
Parameters
----------
bands : `list` of `str`
The full list of bands in the `MultibandExposure` generating
the PSF.
"""
pass
def __init__(self, bands, position, partialPsf):
missingBands = [band for band in bands if band not in partialPsf.filters]

self.missingBands = missingBands
self.position = position
self.partialPsf = partialPsf
message = f"Failed to compute PSF at {position} in {missingBands}"
super().__init__(message)


def computePsfImage(psfModels, position, bands, useKernelImage=True):
def computePsfImage(psfModels, position, useKernelImage=True):
"""Get a multiband PSF image
The PSF Kernel Image is computed for each band
The PSF Image or PSF Kernel Image is computed for each band
and combined into a (filter, y, x) array.
Parameters
----------
psfList : `list` of `lsst.afw.detection.Psf`
psfModels : `dict[str, lsst.afw.detection.Psf]`
The list of PSFs in each band.
position : `Point2D` or `tuple`
Coordinates to evaluate the PSF.
bands: `list` or `str`
List of names for each band
useKernelImage: `bool`
Execute ``Psf.computeKernelImage`` when ``True`,
``PSF/computeImage`` when ``False``.
Returns
-------
psfs: `np.ndarray`
psfs: `lsst.afw.image.MultibandImage`
The multiband PSF image.
"""
psfs = []
psfs = {}
# Make the coordinates into a Point2D (if necessary)
if not isinstance(position, Point2D):
position = Point2D(position[0], position[1])

for bidx, psfModel in enumerate(psfModels):
incomplete = False

for band, psfModel in psfModels.items():
try:
if useKernelImage:
psf = psfModel.computeKernelImage(position)
else:
psf = psfModel.computeImage(position)
psfs.append(psf)
psfs[band] = psf
except InvalidParameterError:
# This band failed to compute the PSF due to incomplete data
# at that location. This is unlikely to be a problem for Rubin,
# however the edges of some HSC COSMOS fields contain incomplete
# data in some bands, so we track this error to distinguish it
# from unknown errors.
msg = "Failed to compute PSF at {} in band {}"
raise IncompleteDataError(msg.format(position, bands[bidx])) from None

left = np.min([psf.getBBox().getMinX() for psf in psfs])
bottom = np.min([psf.getBBox().getMinY() for psf in psfs])
right = np.max([psf.getBBox().getMaxX() for psf in psfs])
top = np.max([psf.getBBox().getMaxY() for psf in psfs])
incomplete = True

left = np.min([psf.getBBox().getMinX() for psf in psfs.values()])
bottom = np.min([psf.getBBox().getMinY() for psf in psfs.values()])
right = np.max([psf.getBBox().getMaxX() for psf in psfs.values()])
top = np.max([psf.getBBox().getMaxY() for psf in psfs.values()])
bbox = Box2I(Point2I(left, bottom), Point2I(right, top))
psfs = np.array([projectImage(psf, bbox).array for psf in psfs])
return psfs

psf_images = [projectImage(psf, bbox) for psf in psfs.values()]

mPsf = MultibandImage.fromImages(list(psfs.keys()), psf_images)

if incomplete:
raise IncompleteDataError(list(psfModels.keys()), position, mPsf)

return mPsf


class MultibandExposure(MultibandTripleBase):
Expand Down Expand Up @@ -217,7 +246,6 @@ def computePsfKernelImage(self, position):
return computePsfImage(
psfModels=self.getPsfs(),
position=position,
bands=self.filters,
useKernelImage=True,
)

Expand Down Expand Up @@ -245,7 +273,6 @@ def computePsfImage(self, position=None):
return computePsfImage(
psfModels=self.getPsfs(),
position=position,
bands=self.filters,
useKernelImage=True,
)

Expand All @@ -254,10 +281,10 @@ def getPsfs(self):
Returns
-------
psfs : `list` of `lsst.afw.detection.Psf`
psfs : `dict` of `lsst.afw.detection.Psf`
The PSF in each band
"""
return [s.getPsf() for s in self]
return {band: self[band].getPsf() for band in self.filters}

def _slice(self, filters, filterIndex, indices):
"""Slice the current object and return the result
Expand All @@ -284,12 +311,15 @@ def _slice(self, filters, filterIndex, indices):
assert isinstance(variance, MultibandPixel)
return (image, mask, variance)

_psfs = self.getPsfs()
psfs = [_psfs[band] for band in filters]

result = MultibandExposure(
filters=filters,
image=image,
mask=mask,
variance=variance,
psfs=self.getPsfs(),
psfs=psfs,
)

assert all([r.getBBox() == result._bbox for r in [result._mask, result._variance]])
Expand Down
8 changes: 4 additions & 4 deletions tests/test_multiband.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,18 +632,18 @@ def testCopy(self):

def testPsf(self):
psfImage = self.exposure.computePsfKernelImage(self.exposure.getBBox().getCenter())
self.assertFloatsAlmostEqual(psfImage, self.psfImage)
self.assertFloatsAlmostEqual(psfImage.array, self.psfImage)

newPsfs = [GaussianPsf(self.kernelSize, self.kernelSize, 1.0) for f in self.filters]
newPsfImage = [p.computeImage(p.getAveragePosition()).array for p in newPsfs]
for psf, exposure in zip(newPsfs, self.exposure.singles):
exposure.setPsf(psf)
psfImage = self.exposure.computePsfKernelImage(self.exposure.getBBox().getCenter())
self.assertFloatsAlmostEqual(psfImage, newPsfImage)
self.assertFloatsAlmostEqual(psfImage.array, newPsfImage)

psfImage = self.exposure.computePsfImage(self.exposure.getBBox().getCenter())[0]
psfImage = self.exposure.computePsfImage(self.exposure.getBBox().getCenter())["G"]
self.assertFloatsAlmostEqual(
psfImage,
psfImage.array,
self.exposure["G"].getPsf().computeImage(
self.exposure["G"].getPsf().getAveragePosition()
).array
Expand Down

0 comments on commit b62437f

Please sign in to comment.