diff --git a/python/lsst/pipe/tasks/__init__.py b/python/lsst/pipe/tasks/__init__.py index 0eea61cc1..74b03a216 100644 --- a/python/lsst/pipe/tasks/__init__.py +++ b/python/lsst/pipe/tasks/__init__.py @@ -1 +1,2 @@ +from .brightStarSubtraction import * from .version import * diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py new file mode 100644 index 000000000..fe5088369 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py @@ -0,0 +1,2 @@ +from .brightStarCutout import * +from .brightStarStack import * diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py new file mode 100644 index 000000000..2a737cfcc --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -0,0 +1,636 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Extract bright star cutouts; normalize and warp, optionally fit the PSF.""" + +__all__ = ["BrightStarCutoutConnections", "BrightStarCutoutConfig", "BrightStarCutoutTask"] + +from typing import Any, Iterable, cast + +import astropy.units as u +import numpy as np +from astropy.coordinates import SkyCoord +from astropy.table import Table +from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS +from lsst.afw.detection import Footprint, FootprintSet, Threshold +from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs +from lsst.afw.geom.transformFactory import makeTransform +from lsst.afw.image import ExposureF, ImageF, MaskedImageF +from lsst.afw.math import BackgroundList, FixedKernel, WarpingControl, warpImage +from lsst.daf.butler import DataCoordinate +from lsst.geom import ( + AffineTransform, + Box2I, + Extent2D, + Extent2I, + Point2D, + Point2I, + SpherePoint, + arcseconds, + floor, + radians, +) +from lsst.meas.algorithms import ( + BrightStarStamp, + BrightStarStamps, + KernelPsf, + LoadReferenceObjectsConfig, + ReferenceObjectLoader, + WarpedPsf, +) +from lsst.pex.config import ChoiceField, ConfigField, Field, ListField +from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput +from lsst.utils.timer import timeMethod + +NEIGHBOR_MASK_PLANE = "NEIGHBOR" + + +class BrightStarCutoutConnections( + PipelineTaskConnections, + dimensions=("instrument", "visit", "detector"), +): + """Connections for BrightStarCutoutTask.""" + + refCat = PrerequisiteInput( + name="gaia_dr3_20230707", + storageClass="SimpleCatalog", + doc="Reference catalog that contains bright star positions.", + dimensions=("skypix",), + multiple=True, + deferLoad=True, + ) + inputExposure = Input( + name="calexp", + storageClass="ExposureF", + doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.", + dimensions=("visit", "detector"), + ) + inputBackground = Input( + name="calexpBackground", + storageClass="Background", + doc="Background model for the input exposure, to be added back on during processing.", + dimensions=("visit", "detector"), + ) + brightStarStamps = Output( + name="brightStarStamps", + storageClass="BrightStarStamps", + doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", + dimensions=("visit", "detector"), + ) + + +class BrightStarCutoutConfig( + PipelineTaskConfig, + pipelineConnections=BrightStarCutoutConnections, +): + """Configuration parameters for BrightStarCutoutTask.""" + + # Star selection + magRange = ListField[float]( + doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", + default=[0, 18], + ) + excludeArcsecRadius = Field[float]( + doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + default=5, + ) + excludeMagRange = ListField[float]( + doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + default=[0, 20], + ) + minAreaFraction = Field[float]( + doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.", + default=0.1, + ) + badMaskPlanes = ListField[str]( + doc="Mask planes that identify excluded pixels for the calculation of ``minAreaFraction`` and, " + "optionally, fitting of the PSF.", + default=[ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + NEIGHBOR_MASK_PLANE, + ], + ) + + # Cutout geometry + stampSize = ListField[int]( + doc="Size of the stamps to be extracted, in pixels.", + default=(251, 251), + ) + stampSizePadding = Field[float]( + doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.", + default=1.1, + ) + warpingKernelName = ChoiceField[str]( + doc="Warping kernel.", + default="lanczos5", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + }, + ) + maskWarpingKernelName = ChoiceField[str]( + doc="Warping kernel for mask.", + default="bilinear", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + }, + ) + + # PSF Fitting + doFitPsf = Field[bool]( + doc="Fit a scaled PSF and a pedestal to each bright star cutout.", + default=True, + ) + useMedianVariance = Field[bool]( + doc="Use the median of the variance plane for PSF fitting.", + default=False, + ) + psfMaskedFluxFracThreshold = Field[float]( + doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", + default=0.97, + ) + + # Misc + loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig]( + doc="Reference object loader for astrometric calibration.", + ) + + +class BrightStarCutoutTask(PipelineTask): + """Extract bright star cutouts; normalize and warp to the same pixel grid. + + The BrightStarCutoutTask is used to extract, process, and store small image + cutouts (or "postage stamps") around bright stars. + This task essentially consists of three principal steps. + First, it identifies bright stars within an exposure using a reference + catalog and extracts a stamp around each. + Second, it shifts and warps each stamp to remove optical distortions and + sample all stars on the same pixel grid. + Finally, it optionally fits a PSF plus plane flux model to the cutout. + This final fitting procedure may be used to normalize each bright star + stamp prior to stacking when producing extended PSF models. + """ + + ConfigClass = BrightStarCutoutConfig + _DefaultName = "brightStarCutout" + config: BrightStarCutoutConfig + + def __init__(self, initInputs=None, *args, **kwargs): + super().__init__(*args, **kwargs) + stampSize = Extent2D(*self.config.stampSize.list()) + stampRadius = floor(stampSize / 2) + self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius) + paddedStampSize = stampSize * self.config.stampSizePadding + self.paddedStampRadius = floor(paddedStampSize / 2) + self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( + self.paddedStampRadius + ) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + inputs["dataId"] = butlerQC.quantum.dataId + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.loadReferenceObjectsConfig, + ) + output = self.run(**inputs, refObjLoader=refObjLoader) + # Only ingest Stamp if it exists; prevents ingesting an empty FITS file + if output: + butlerQC.put(output, outputRefs) + + @timeMethod + def run( + self, + inputExposure: ExposureF, + inputBackground: BackgroundList, + refObjLoader: ReferenceObjectLoader, + dataId: dict[str, Any] | DataCoordinate, + ): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, warp/shift stamps onto a common frame and + then optionally fit a PSF plus plane model. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The background-subtracted image to extract bright star stamps. + inputBackground : `~lsst.afw.math.BackgroundList` + The background model associated with the input exposure. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure that bright stars are extracted from. + Both 'visit' and 'detector' will be persisted in the output data. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + wcs = inputExposure.getWcs() + bbox = inputExposure.getBBox() + warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName) + + refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox) + zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians) + spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec] + pixCoords = wcs.skyToPixel(spherePoints) + + # Restore original subtracted background + inputMI = inputExposure.getMaskedImage() + inputMI += inputBackground.getImage() + + # Set up NEIGHBOR mask plane; associate footprints with stars + inputExposure.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) + allFootprints, associations = self._associateFootprints(inputExposure, pixCoords, plane="DETECTED") + + # TODO: If we eventually have better PhotoCalibs (eg FGCM), apply here + inputMI = inputExposure.getPhotoCalib().calibrateImage(inputMI, False) + + # Set up transform + detector = inputExposure.detector + pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds + pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then( + makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians())) + ) + + # Loop over each bright star + stamps, goodFracs, stamps_fitPsfResults = [], [], [] + for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore + footprintIndex = associations.get(starIndex, None) + stampMI = MaskedImageF(self.paddedStampBBox) + + # Set NEIGHBOR footprints in the mask plane + if footprintIndex: + neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex] + self._setFootprints(inputExposure, neighborFootprints, NEIGHBOR_MASK_PLANE) + else: + self._setFootprints(inputExposure, allFootprints, NEIGHBOR_MASK_PLANE) + + # Define linear shifting to recenter stamps + coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star + shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan)) + angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians + rotation = makeTransform(AffineTransform.makeRotation(-angle)) + pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation) + + # Apply the warp to the star stamp (in-place) + warpImage(stampMI, inputExposure.maskedImage, pixToPolar, warpingControl) + + # Trim to the base stamp size, check mask coverage, update metadata + stampMI = stampMI[self.stampBBox] + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + goodFrac = np.sum(stampMI.mask.array & badMaskBitMask == 0) / stampMI.mask.array.size + goodFracs.append(goodFrac) + if goodFrac < self.config.minAreaFraction: + continue + + # Fit a scaled PSF and a pedestal to each bright star cutout + psf = WarpedPsf(inputExposure.getPsf(), pixToPolar, warpingControl) + constantPsf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) + fitPsfResults = {} + if self.config.doFitPsf: + fitPsfResults = self._fitPsf(stampMI, constantPsf) + stamps_fitPsfResults.append(fitPsfResults) + + # Save the stamp if the PSF fit was successful or no fit requested + if fitPsfResults or not self.config.doFitPsf: + stamp = BrightStarStamp( + maskedImage=stampMI, + psf=constantPsf, + wcs=makeModifiedWcs(pixToPolar, wcs, False), + visit=cast(int, dataId["visit"]), + detector=cast(int, dataId["detector"]), + refId=obj["id"], + refMag=obj["mag"], + position=pixCoord, + scale=fitPsfResults.get("scale", None), + scaleErr=fitPsfResults.get("scaleErr", None), + pedestal=fitPsfResults.get("pedestal", None), + pedestalErr=fitPsfResults.get("pedestalErr", None), + pedestalScaleCov=fitPsfResults.get("pedestalScaleCov", None), + xGradient=fitPsfResults.get("xGradient", None), + yGradient=fitPsfResults.get("yGradient", None), + globalReducedChiSquared=fitPsfResults.get("globalReducedChiSquared", None), + globalDegreesOfFreedom=fitPsfResults.get("globalDegreesOfFreedom", None), + psfReducedChiSquared=fitPsfResults.get("psfReducedChiSquared", None), + psfDegreesOfFreedom=fitPsfResults.get("psfDegreesOfFreedom", None), + psfMaskedFluxFrac=fitPsfResults.get("psfMaskedFluxFrac", None), + ) + stamps.append(stamp) + + self.log.info( + "Extracted %i bright star stamp%s. Excluded stars: insufficient area (%i), PSF fit failure (%i).", + len(stamps), + "" if len(stamps) == 1 else "s", + np.sum(np.array(goodFracs) < self.config.minAreaFraction), + ( + np.sum(np.isnan([x.get("pedestal", np.nan) for x in stamps_fitPsfResults])) + if self.config.doFitPsf + else 0 + ), + ) + brightStarStamps = BrightStarStamps(stamps) + return Struct(brightStarStamps=brightStarStamps) + + def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: + """Get a bright star subset of the reference catalog. + + Trim the reference catalog to only those objects within the exposure + bounding box dilated by half the bright star stamp size. + This ensures all stars that overlap the exposure are included. + + Parameters + ---------- + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` + Loader to find objects within a reference catalog. + wcs : `~lsst.afw.geom.SkyWcs` + World coordinate system. + bbox : `~lsst.geom.Box2I` + Bounding box of the exposure. + + Returns + ------- + refCatBright : `~astropy.table.Table` + Bright star subset of the reference catalog. + """ + dilatedBBox = bbox.dilatedBy(self.paddedStampRadius) + withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean") + refCatFull = withinExposure.refCat + fluxField: str = withinExposure.fluxField + + proxFluxRange = sorted(((self.config.excludeMagRange * u.ABmag).to(u.nJy)).to_value()) + brightFluxRange = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value()) + + subsetStars = (refCatFull[fluxField] > np.min((proxFluxRange[0], brightFluxRange[0]))) & ( + refCatFull[fluxField] < np.max((proxFluxRange[1], brightFluxRange[1])) + ) + refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars)) + + proxStars = (refCatSubset[fluxField] >= proxFluxRange[0]) & ( + refCatSubset[fluxField] <= proxFluxRange[1] + ) + brightStars = (refCatSubset[fluxField] >= brightFluxRange[0]) & ( + refCatSubset[fluxField] <= brightFluxRange[1] + ) + + coords = SkyCoord(refCatSubset["coord_ra"], refCatSubset["coord_dec"], unit="rad") + excludeArcsecRadius = self.config.excludeArcsecRadius * u.arcsec # type: ignore + refCatBrightIsolated = [] + for coord in cast(Iterable[SkyCoord], coords[brightStars]): + neighbors = coords[proxStars] + seps = coord.separation(neighbors).to(u.arcsec) + tooClose = (seps > 0) & (seps <= excludeArcsecRadius) # not self matched + refCatBrightIsolated.append(not tooClose.any()) + + refCatBright = cast(Table, refCatSubset[brightStars][refCatBrightIsolated]) + + fluxNanojansky = refCatBright[fluxField][:] * u.nJy # type: ignore + refCatBright["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes + + self.log.info( + "Identified %i of %i star%s which overlap%s the frame, in the range %s mag, and %s no nearby " + "neighbors.", + len(refCatBright), + len(refCatFull), + "" if len(refCatFull) == 1 else "s", + "s" if len(refCatBright) == 1 else "", + self.config.magRange, + "has" if len(refCatBright) == 1 else "have", + ) + + return refCatBright + + def _associateFootprints( + self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str + ) -> tuple[list[Footprint], dict[int, int]]: + """Associate footprints from a given mask plane with specific objects. + + Footprints from the given mask plane are associated with objects at the + coordinates provided, where possible. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure with a mask plane. + pixCoords : `list` [`~lsst.geom.Point2D`] + The pixel coordinates of the objects. + plane : `str` + The mask plane used to identify masked pixels. + + Returns + ------- + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints from the input exposure. + associations : `dict`[int, int] + Association indices between objects (key) and footprints (value). + """ + detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK) + footprintSet = FootprintSet(inputExposure.mask, detThreshold) + footprints = footprintSet.getFootprints() + associations = {} + for starIndex, pixCoord in enumerate(pixCoords): + for footprintIndex, footprint in enumerate(footprints): + if footprint.contains(Point2I(pixCoord)): + associations[starIndex] = footprintIndex + break + self.log.debug( + "Associated %i of %i star%s to one each of the %i %s footprint%s.", + len(associations), + len(pixCoords), + "" if len(pixCoords) == 1 else "s", + len(footprints), + plane, + "" if len(footprints) == 1 else "s", + ) + return footprints, associations + + def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str): + """Set footprints in a given mask plane. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure to modify. + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints to set in the mask plane. + maskPlane : `str` + The mask plane to set the footprints in. + + Notes + ----- + This method modifies the ``inputExposure`` object in-place. + """ + detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK) + detThresholdValue = int(detThreshold.getValue()) + footprintSet = FootprintSet(inputExposure.mask, detThreshold) + + # Wipe any existing footprints in the mask plane + inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue))) + + # Set the footprints in the mask plane + footprintSet.setFootprints(footprints) + footprintSet.setMask(inputExposure.mask, maskPlane) + + def _fitPsf(self, stampMI: MaskedImageF, psf: KernelPsf) -> dict[str, Any]: + """Fit a scaled PSF and a pedestal to each bright star cutout. + + Parameters + ---------- + stampMI : `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + psf : `~lsst.meas.algorithms.KernelPsf` + The PSF model to fit. + + Returns + ------- + fitPsfResults : `dict`[`str`, `float`] + The result of the PSF fitting, with keys: + + ``scale`` : `float` + The scale factor. + ``scaleErr`` : `float` + The error on the scale factor. + ``pedestal`` : `float` + The pedestal value. + ``pedestalErr`` : `float` + The error on the pedestal value. + ``pedestalScaleCov`` : `float` + The covariance between the pedestal and scale factor. + ``xGradient`` : `float` + The gradient in the x-direction. + ``yGradient`` : `float` + The gradient in the y-direction. + ``globalReducedChiSquared`` : `float` + The global reduced chi-squared goodness-of-fit. + ``globalDegreesOfFreedom`` : `int` + The global number of degrees of freedom. + ``psfReducedChiSquared`` : `float` + The PSF BBox reduced chi-squared goodness-of-fit. + ``psfDegreesOfFreedom`` : `int` + The PSF BBox number of degrees of freedom. + ``psfMaskedFluxFrac`` : `float` + The fraction of the PSF image flux masked by bad pixels. + """ + psfImage = psf.computeKernelImage(psf.getAveragePosition()) + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + + # Calculate the fraction of the PSF image flux masked by bad pixels + psfMaskedPixels = ImageF(psfImage.getBBox()) + psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) + psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat) + if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: + return {} # Handle cases where the PSF image is mostly masked + + # Create a padded version of the input constant PSF image + paddedPsfImage = ImageF(stampMI.getBBox()) + paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() + + # Create consistently masked data + badSpans = SpanSet.fromMask(stampMI.mask, badMaskBitMask) + goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans) + varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + if self.config.useMedianVariance: + varianceData = np.median(varianceData) + sigmaData = np.sqrt(varianceData) + imageData = goodSpans.flatten(stampMI.image.array, stampMI.getXY0()) # B + imageData /= sigmaData + psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) + psfData /= sigmaData + + # Fit the PSF scale factor and global pedestal + nData = len(imageData) + coefficientMatrix = np.ones((nData, 4), dtype=float) # A + coefficientMatrix[:, 0] = psfData + coefficientMatrix[:, 1] /= sigmaData + coefficientMatrix[:, 2:] = goodSpans.indices().T + coefficientMatrix[:, 2] /= sigmaData + coefficientMatrix[:, 3] /= sigmaData + try: + solutions, sumSquaredResiduals, *_ = np.linalg.lstsq(coefficientMatrix, imageData, rcond=None) + covarianceMatrix = np.linalg.inv(np.dot(coefficientMatrix.transpose(), coefficientMatrix)) # C + except np.linalg.LinAlgError: + return {} # Handle singular matrix errors + if sumSquaredResiduals.size == 0: + return {} # Handle cases where sum of the squared residuals are empty + scale = solutions[0] + if scale <= 0: + return {} # Handle cases where the PSF scale fit has failed + scaleErr = np.sqrt(covarianceMatrix[0, 0]) + pedestal = solutions[1] + pedestalErr = np.sqrt(covarianceMatrix[1, 1]) + scalePedestalCov = covarianceMatrix[0, 1] + xGradient = solutions[3] + yGradient = solutions[2] + + # Calculate global (whole image) reduced chi-squared + globalChiSquared = np.sum(sumSquaredResiduals) + globalDegreesOfFreedom = nData - 4 + globalReducedChiSquared = globalChiSquared / globalDegreesOfFreedom + + # Calculate PSF BBox reduced chi-squared + psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox()) + psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices() + psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) + psfBBoxModel = ( + psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + + pedestal + + psfBBoxGoodSpansX * xGradient + + psfBBoxGoodSpansY * yGradient + ) + psfBBoxVariance = psfBBoxGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 / psfBBoxVariance + psfBBoxChiSquared = np.sum(psfBBoxResiduals) + psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4 + psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom + + return dict( + scale=scale, + scaleErr=scaleErr, + pedestal=pedestal, + pedestalErr=pedestalErr, + xGradient=xGradient, + yGradient=yGradient, + pedestalScaleCov=scalePedestalCov, + globalReducedChiSquared=globalReducedChiSquared, + globalDegreesOfFreedom=globalDegreesOfFreedom, + psfReducedChiSquared=psfBBoxReducedChiSquared, + psfDegreesOfFreedom=psfBBoxDegreesOfFreedom, + psfMaskedFluxFrac=psfMaskedFluxFrac, + ) diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py new file mode 100644 index 000000000..5e3acf075 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py @@ -0,0 +1,258 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Stack bright star postage stamp cutouts to produce an extended PSF model.""" + +__all__ = ["BrightStarStackConnections", "BrightStarStackConfig", "BrightStarStackTask"] + +import numpy as np +from astropy.stats import sigma_clip +from lsst.afw.image import ImageF, MaskedImageF +from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty +from lsst.geom import Box2I, Point2I +from lsst.meas.algorithms import BrightStarStamps +from lsst.pex.config import Field, ListField +from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output +from lsst.utils.timer import timeMethod + +NEIGHBOR_MASK_PLANE = "NEIGHBOR" + + +class BrightStarStackConnections( + PipelineTaskConnections, + dimensions=("instrument", "detector"), +): + """Connections for BrightStarStackTask.""" + + brightStarStamps = Input( + name="brightStarStamps", + storageClass="BrightStarStamps", + doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", + dimensions=("visit", "detector"), + multiple=True, + deferLoad=True, + ) + extendedPsf = Output( + name="extendedPsf2", # extendedPsfDetector ??? + storageClass="ExtendedPsf", # MaskedImageF + doc="Extended PSF model, built from stacking bright star cutouts.", + dimensions=("band",), + ) + + +class BrightStarStackConfig( + PipelineTaskConfig, + pipelineConnections=BrightStarStackConnections, +): + """Configuration parameters for BrightStarStackTask.""" + + subsetStampNumber = Field[int]( + doc="Number of stamps per subset to generate stacked images for.", + default=20, + ) + badMaskPlanes = ListField[str]( + doc="Mask planes that identify excluded (masked) pixels.", + default=[ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + # "SAT", + # "SUSPECT", + "UNMASKEDNAN", + NEIGHBOR_MASK_PLANE, + ], + ) + + +class BrightStarStackTask(PipelineTask): + """Stack bright star postage stamps to produce an extended PSF model.""" + + ConfigClass = BrightStarStackConfig + _DefaultName = "brightStarStack" + config: BrightStarStackConfig + + def __init__(self, initInputs=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + output = self.run(**inputs) + butlerQC.put(output, outputRefs) + + @timeMethod + def run( + self, + brightStarStamps: BrightStarStamps, + ): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, then preprocess them. + + Bright star preprocessing steps are: shifting, warping and potentially + rotating them to the same pixel grid; computing their annular flux, + and; normalizing them. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image from which bright star stamps should be extracted. + inputBackground : `~lsst.afw.image.Background` + The background model for the input exposure. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure (including detector) that bright stars + should be extracted from. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + statisticsControl = StatisticsControl( + numSigmaClip=3, + numIter=5, + ) + + extendedPsfMI = None + extendedImages = [] + tempImages = [] + tempVariances = [] + tempIndex = 0 + for stampsDDH in brightStarStamps: + stamps = stampsDDH.get() + if not stamps: + continue + tempIndex += 1 + + for stamp in stamps: + stampMI = stamp.maskedImage + stampBBox = stampMI.getBBox() + + # Apply fitted components + stampMI -= stamp.pedestal + xGrid, yGrid = np.meshgrid(stampBBox.getX().arange(), stampBBox.getY().arange()) + xPlane = ImageF((xGrid * stamp.xGradient).astype(np.float32), xy0=stampMI.getXY0()) + yPlane = ImageF((yGrid * stamp.yGradient).astype(np.float32), xy0=stampMI.getXY0()) + stampMI -= xPlane + stampMI -= yPlane + stampMI *= stamp.scale + + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + breakpoint() + stampMask = (stampMI.mask.array & badMaskBitMask).astype(bool) + stampMI.image.array[stampMask] = 0 + stampMI.variance.array[stampMask] = 0 + + tempImages.append(stampMI.image.array) + tempVariances.append(stampMI.variance.array) + + if tempIndex == self.config.numVisitStack: + # for i in range(tempImage.shape[0]): + # for j in range(tempImage.shape[1]): + # pixel_values = tempImage[i, j, :] + # stats = afwMath.makeStatistics(pixel_values, afwMath.MEANCLIP, sctrl) + # mean_image[i, j] = stats.getValue(afwMath.MEANCLIP) + + tempImages2 = np.stack(tempImages) + tempVariances2 = np.stack(tempVariances) + + clippedImages = sigma_clip(tempImages2, axis=2, sigma=3) + clippedVariances = sigma_clip(tempVariances2, axis=2, sigma=3) + + clippedImage = np.mean(clippedImages, axis=2) + extendedImages.append(clippedImage) + + tempIndex = 0 + + breakpoint() + extendedImages2 = np.stack(extendedImages) + clippedImages2 = sigma_clip(extendedImages2, axis=0, sigma=3) + extendedImage = np.mean(clippedImages2, axis=0) + extendedPsfMI = MaskedImageF(image=ImageF(extendedImage), variance=ImageF(extendedImage)) + + return Struct(extendedPsf=extendedPsfMI) + + # if stamp.psfReducedChiSquared > 5: + # continue + # stampWeight = 1 / stamp.psfReducedChiSquared + # stampMI *= stampWeight + # stampWI = ImageF(stampBBox, stampWeight) + # stampWI.array[stampMask] = 0 + + # if not extendedPsfMI: + # extendedPsfMI = stampMI.clone() + # extendedPsfWI = stampWI.clone() + # else: + # extendedPsfMI += stampMI + # extendedPsfWI += stampWI + + # if extendedPsfMI: + # extendedPsfMI /= extendedPsfWI + # breakpoint() + + # stack = [] + # chiStack = [] + # for loop over all groups: + # load up all visits for this detector + # drop all with GOF > thresh + # sigma-clip mean stack the rest + # append to stack + # compute the scatter (MAD/sigma-clipped var, etc) of the rest + # divide by sqrt(var plane), and append to chiStack + # after for-loop, combine images in median stack for final result + # also combine chi-images, save separately + + # idea: run with two different thresholds, and compare the results + + # medianStack = [] + # for loop over all groups: + # load up all visits for this detector + # drop all with GOF > thresh + # median/sigma-clip stack the rest + # append to medianStack + # after for-loop, combine images in median stack for final result + + def _configureStacking(self, numSigmaClip, numIter, badMaskBitMask, stackingStatistic): + """Configure stacking statistic and control from config fields.""" + statisticsControl = StatisticsControl(numSigmaClip=numSigmaClip, numIter=numIter) + statisticsFlag = stringToStatisticsProperty(stackingStatistic) + statisticsControl.setAndMask(badMaskBitMask) + return statisticsControl, statisticsFlag + + def _configureStacking(self, example_stamp): + """Configure stacking statistic and control from config fields.""" + stats_control = StatisticsControl( + numSigmaClip=self.config.num_sigma_clip, + numIter=self.config.num_iter, + ) + if bad_masks := self.config.bad_mask_planes: + and_mask = example_stamp.mask.getPlaneBitMask(bad_masks[0]) + for bm in bad_masks[1:]: + and_mask = and_mask | example_stamp.mask.getPlaneBitMask(bm) + stats_control.setAndMask(and_mask) + stats_flags = stringToStatisticsProperty(self.config.stacking_statistic) + return stats_control, stats_flags diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarSubtract.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarSubtract.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_brightStarCutout.py b/tests/test_brightStarCutout.py new file mode 100644 index 000000000..67b88d02f --- /dev/null +++ b/tests/test_brightStarCutout.py @@ -0,0 +1,102 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest + +import lsst.afw.cameraGeom.testUtils +import lsst.afw.image +import lsst.utils.tests +import numpy as np +from lsst.afw.image import ImageD, ImageF, MaskedImageF +from lsst.afw.math import FixedKernel +from lsst.geom import Point2I +from lsst.meas.algorithms import KernelPsf +from lsst.pipe.tasks.brightStarSubtraction import BrightStarCutoutConfig, BrightStarCutoutTask + + +class BrightStarCutoutTestCase(lsst.utils.tests.TestCase): + def setUp(self): + # Fit values + self.scale = 2.34e5 + self.pedestal = 3210.1 + self.xGradient = 5.432 + self.yGradient = 10.987 + + # Create a pedestal + 2D plane + xCoords = np.linspace(-50, 50, 101) + yCoords = np.linspace(-50, 50, 101) + xPlane, yPlane = np.meshgrid(xCoords, yCoords) + pedestal = np.ones_like(xPlane) * self.pedestal + + # Create a pseudo-PSF + dist_from_center = np.sqrt(xPlane**2 + yPlane**2) + psfArray = np.exp(-dist_from_center / 5) + psfArray /= np.sum(psfArray) + fixedKernel = FixedKernel(ImageD(psfArray)) + self.psf = KernelPsf(fixedKernel) + + # Bring everything together to construct a stamp masked image + stampArray = psfArray * self.scale + pedestal + xPlane * self.xGradient + yPlane * self.yGradient + stampIm = ImageF((stampArray).astype(np.float32)) + stampVa = ImageF(stampIm.getBBox(), 654.321) + self.stampMI = MaskedImageF(image=stampIm, variance=stampVa) + self.stampMI.setXY0(Point2I(-50, -50)) + + # Ensure that all mask planes required by the task are in-place; + # new mask plane entries will be created as necessary + badMaskPlanes = [ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + "NEIGHBOR", + ] + _ = [self.stampMI.mask.addMaskPlane(mask) for mask in badMaskPlanes] + + def test_fitPsf(self): + """Test the PSF fitting method.""" + brightStarCutoutConfig = BrightStarCutoutConfig() + brightStarCutoutTask = BrightStarCutoutTask(config=brightStarCutoutConfig) + fitPsfResults = brightStarCutoutTask._fitPsf( + self.stampMI, + self.psf, + ) + self.assertAlmostEqual(fitPsfResults["scale"], self.scale, delta=1e-3) + self.assertAlmostEqual(fitPsfResults["pedestal"], self.pedestal, delta=1e-5) + self.assertAlmostEqual(fitPsfResults["xGradient"], self.xGradient, delta=1e-7) + self.assertAlmostEqual(fitPsfResults["yGradient"], self.yGradient, delta=1e-7) + + +def setup_module(module): + lsst.utils.tests.init() + + +class MemoryTestCase(lsst.utils.tests.MemoryTestCase): + pass + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main()