diff --git a/python/lsst/pipe/tasks/__init__.py b/python/lsst/pipe/tasks/__init__.py
index 0eea61cc1..74b03a216 100644
--- a/python/lsst/pipe/tasks/__init__.py
+++ b/python/lsst/pipe/tasks/__init__.py
@@ -1 +1,2 @@
+from .brightStarSubtraction import *
from .version import *
diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py
new file mode 100644
index 000000000..fe5088369
--- /dev/null
+++ b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py
@@ -0,0 +1,2 @@
+from .brightStarCutout import *
+from .brightStarStack import *
diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py
new file mode 100644
index 000000000..fab07073b
--- /dev/null
+++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py
@@ -0,0 +1,630 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Extract bright star cutouts; normalize and warp, optionally fit the PSF."""
+
+__all__ = ["BrightStarCutoutConnections", "BrightStarCutoutConfig", "BrightStarCutoutTask"]
+
+from typing import Any, Iterable, cast
+
+import astropy.units as u
+import numpy as np
+from astropy.coordinates import SkyCoord
+from astropy.table import Table
+from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS
+from lsst.afw.detection import Footprint, FootprintSet, Threshold
+from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs
+from lsst.afw.geom.transformFactory import makeTransform
+from lsst.afw.image import ExposureF, ImageF, MaskedImageF
+from lsst.afw.math import BackgroundList, FixedKernel, WarpingControl, warpImage
+from lsst.daf.butler import DataCoordinate
+from lsst.geom import (
+ AffineTransform,
+ Box2I,
+ Extent2D,
+ Extent2I,
+ Point2D,
+ Point2I,
+ SpherePoint,
+ arcseconds,
+ floor,
+ radians,
+)
+from lsst.meas.algorithms import (
+ BrightStarStamp,
+ BrightStarStamps,
+ KernelPsf,
+ LoadReferenceObjectsConfig,
+ ReferenceObjectLoader,
+ WarpedPsf,
+)
+from lsst.pex.config import ChoiceField, ConfigField, Field, ListField
+from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
+from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput
+from lsst.utils.timer import timeMethod
+
+NEIGHBOR_MASK_PLANE = "NEIGHBOR"
+
+
+class BrightStarCutoutConnections(
+ PipelineTaskConnections,
+ dimensions=("instrument", "visit", "detector"),
+):
+ """Connections for BrightStarCutoutTask."""
+
+ refCat = PrerequisiteInput(
+ name="gaia_dr3_20230707",
+ storageClass="SimpleCatalog",
+ doc="Reference catalog that contains bright star positions.",
+ dimensions=("skypix",),
+ multiple=True,
+ deferLoad=True,
+ )
+ inputExposure = Input(
+ name="calexp",
+ storageClass="ExposureF",
+ doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.",
+ dimensions=("visit", "detector"),
+ )
+ inputBackground = Input(
+ name="calexpBackground",
+ storageClass="Background",
+ doc="Background model for the input exposure, to be added back on during processing.",
+ dimensions=("visit", "detector"),
+ )
+ brightStarStamps = Output(
+ name="brightStarStamps",
+ storageClass="BrightStarStamps",
+ doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.",
+ dimensions=("visit", "detector"),
+ )
+
+
+class BrightStarCutoutConfig(
+ PipelineTaskConfig,
+ pipelineConnections=BrightStarCutoutConnections,
+):
+ """Configuration parameters for BrightStarCutoutTask."""
+
+ # Star selection
+ magLimit = Field[float](
+ doc="Magnitude limit, in Gaia G. Cutouts will be made for all stars brighter than this magnitude.",
+ default=18,
+ )
+ excludeArcsecRadius = Field[float](
+ doc="Stars with a star brighter than ``excludeMagLimit`` in ``excludeArcsecRadius`` are not be used.",
+ default=5,
+ )
+ excludeMagLimit = Field[float](
+ doc="Stars with a star brighter than ``excludeMagLimit`` in ``excludeArcsecRadius`` are not be used.",
+ default=20,
+ )
+ minAreaFraction = Field[float](
+ doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.",
+ default=0.1,
+ )
+ badMaskPlanes = ListField[str](
+ doc="Mask planes that identify excluded pixels for the calculation of ``minAreaFraction`` and, "
+ "optionally, fitting of the PSF.",
+ default=[
+ "BAD",
+ "CR",
+ "CROSSTALK",
+ "EDGE",
+ "NO_DATA",
+ "SAT",
+ "SUSPECT",
+ "UNMASKEDNAN",
+ NEIGHBOR_MASK_PLANE,
+ ],
+ )
+
+ # Cutout geometry
+ stampSize = ListField[int](
+ doc="Size of the stamps to be extracted, in pixels.",
+ default=(251, 251),
+ )
+ stampSizePadding = Field[float](
+ doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.",
+ default=1.1,
+ )
+ warpingKernelName = ChoiceField[str](
+ doc="Warping kernel.",
+ default="lanczos5",
+ allowed={
+ "bilinear": "bilinear interpolation",
+ "lanczos3": "Lanczos kernel of order 3",
+ "lanczos4": "Lanczos kernel of order 4",
+ "lanczos5": "Lanczos kernel of order 5",
+ },
+ )
+ maskWarpingKernelName = ChoiceField[str](
+ doc="Warping kernel for mask.",
+ default="bilinear",
+ allowed={
+ "bilinear": "bilinear interpolation",
+ "lanczos3": "Lanczos kernel of order 3",
+ "lanczos4": "Lanczos kernel of order 4",
+ "lanczos5": "Lanczos kernel of order 5",
+ },
+ )
+
+ # PSF Fitting
+ doFitPsf = Field[bool](
+ doc="Fit a scaled PSF and a pedestal to each bright star cutout.",
+ default=True,
+ )
+ useMedianVariance = Field[bool](
+ doc="Use the median of the variance plane for PSF fitting.",
+ default=False,
+ )
+ psfMaskedFluxFracThreshold = Field[float](
+ doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.",
+ default=0.97,
+ )
+
+ # Misc
+ loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig](
+ doc="Reference object loader for astrometric calibration.",
+ )
+
+
+class BrightStarCutoutTask(PipelineTask):
+ """Extract bright star cutouts; normalize and warp to the same pixel grid.
+
+ The BrightStarCutoutTask is used to extract, process, and store small image
+ cutouts (or "postage stamps") around bright stars.
+ This task essentially consists of three principal steps.
+ First, it identifies bright stars within an exposure using a reference
+ catalog and extracts a stamp around each.
+ Second, it shifts and warps each stamp to remove optical distortions and
+ sample all stars on the same pixel grid.
+ Finally, it optionally fits a PSF plus plane flux model to the cutout.
+ This final fitting procedure may be used to normalize each bright star
+ stamp prior to stacking when producing extended PSF models.
+ """
+
+ ConfigClass = BrightStarCutoutConfig
+ _DefaultName = "brightStarCutout"
+ config: BrightStarCutoutConfig
+
+ def __init__(self, initInputs=None, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ stampSize = Extent2D(*self.config.stampSize.list())
+ stampRadius = floor(stampSize / 2)
+ self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius)
+ paddedStampSize = stampSize * self.config.stampSizePadding
+ self.paddedStampRadius = floor(paddedStampSize / 2)
+ self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(
+ self.paddedStampRadius
+ )
+
+ def runQuantum(self, butlerQC, inputRefs, outputRefs):
+ inputs = butlerQC.get(inputRefs)
+ inputs["dataId"] = butlerQC.quantum.dataId
+ refObjLoader = ReferenceObjectLoader(
+ dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat],
+ refCats=inputs.pop("refCat"),
+ name=self.config.connections.refCat,
+ config=self.config.loadReferenceObjectsConfig,
+ )
+ output = self.run(**inputs, refObjLoader=refObjLoader)
+ # Only ingest Stamp if it exists; prevents ingesting an empty FITS file
+ if output:
+ butlerQC.put(output, outputRefs)
+
+ @timeMethod
+ def run(
+ self,
+ inputExposure: ExposureF,
+ inputBackground: BackgroundList,
+ refObjLoader: ReferenceObjectLoader,
+ dataId: dict[str, Any] | DataCoordinate,
+ ):
+ """Identify bright stars within an exposure using a reference catalog,
+ extract stamps around each, warp/shift stamps onto a common frame and
+ then optionally fit a PSF plus plane model.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The background-subtracted image to extract bright star stamps.
+ inputBackground : `~lsst.afw.math.BackgroundList`
+ The background model associated with the input exposure.
+ refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional
+ Loader to find objects within a reference catalog.
+ dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
+ The dataId of the exposure that bright stars are extracted from.
+ Both 'visit' and 'detector' will be persisted in the output data.
+
+ Returns
+ -------
+ brightStarResults : `~lsst.pipe.base.Struct`
+ Results as a struct with attributes:
+
+ ``brightStarStamps``
+ (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`)
+ """
+ wcs = inputExposure.getWcs()
+ bbox = inputExposure.getBBox()
+ warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName)
+
+ refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox)
+ zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians)
+ spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec]
+ pixCoords = wcs.skyToPixel(spherePoints)
+
+ # Restore original subtracted background
+ inputMI = inputExposure.getMaskedImage()
+ inputMI += inputBackground.getImage()
+
+ # Set up NEIGHBOR mask plane; associate footprints with stars
+ inputExposure.mask.addMaskPlane(NEIGHBOR_MASK_PLANE)
+ allFootprints, associations = self._associateFootprints(inputExposure, pixCoords, plane="DETECTED")
+
+ # TODO: If we eventually have better PhotoCalibs (eg FGCM), apply here
+ inputMI = inputExposure.getPhotoCalib().calibrateImage(inputMI, False)
+
+ # Set up transform
+ detector = inputExposure.detector
+ pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds
+ pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then(
+ makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians()))
+ )
+
+ # Loop over each bright star
+ stamps, goodFracs, stamps_fitPsfResults = [], [], []
+ for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore
+ footprintIndex = associations.get(starIndex, None)
+ stampMI = MaskedImageF(self.paddedStampBBox)
+
+ # Set NEIGHBOR footprints in the mask plane
+ if footprintIndex:
+ neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex]
+ self._setFootprints(inputExposure, neighborFootprints, NEIGHBOR_MASK_PLANE)
+ else:
+ self._setFootprints(inputExposure, allFootprints, NEIGHBOR_MASK_PLANE)
+
+ # Define linear shifting to recenter stamps
+ coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star
+ shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan))
+ angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians
+ rotation = makeTransform(AffineTransform.makeRotation(-angle))
+ pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation)
+
+ # Apply the warp to the star stamp (in-place)
+ warpImage(stampMI, inputExposure.maskedImage, pixToPolar, warpingControl)
+
+ # Trim to the base stamp size, check mask coverage, update metadata
+ stampMI = stampMI[self.stampBBox]
+ badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes)
+ goodFrac = np.sum(stampMI.mask.array & badMaskBitMask == 0) / stampMI.mask.array.size
+ goodFracs.append(goodFrac)
+ if goodFrac < self.config.minAreaFraction:
+ continue
+
+ # Fit a scaled PSF and a pedestal to each bright star cutout
+ psf = WarpedPsf(inputExposure.getPsf(), pixToPolar, warpingControl)
+ constantPsf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0))))
+ fitPsfResults = {}
+ if self.config.doFitPsf:
+ fitPsfResults = self._fitPsf(stampMI, constantPsf)
+ stamps_fitPsfResults.append(fitPsfResults)
+
+ # Save the stamp if the PSF fit was successful or no fit requested
+ if fitPsfResults or not self.config.doFitPsf:
+ stamp = BrightStarStamp(
+ maskedImage=stampMI,
+ psf=constantPsf,
+ wcs=makeModifiedWcs(pixToPolar, wcs, False),
+ visit=cast(int, dataId["visit"]),
+ detector=cast(int, dataId["detector"]),
+ refId=obj["id"],
+ refMag=obj["mag"],
+ position=pixCoord,
+ scale=fitPsfResults.get("scale", None),
+ scaleErr=fitPsfResults.get("scaleErr", None),
+ pedestal=fitPsfResults.get("pedestal", None),
+ pedestalErr=fitPsfResults.get("pedestalErr", None),
+ pedestalScaleCov=fitPsfResults.get("pedestalScaleCov", None),
+ xGradient=fitPsfResults.get("xGradient", None),
+ yGradient=fitPsfResults.get("yGradient", None),
+ globalReducedChiSquared=fitPsfResults.get("globalReducedChiSquared", None),
+ globalDegreesOfFreedom=fitPsfResults.get("globalDegreesOfFreedom", None),
+ psfReducedChiSquared=fitPsfResults.get("psfReducedChiSquared", None),
+ psfDegreesOfFreedom=fitPsfResults.get("psfDegreesOfFreedom", None),
+ psfMaskedFluxFrac=fitPsfResults.get("psfMaskedFluxFrac", None),
+ )
+ stamps.append(stamp)
+
+ self.log.info(
+ "Extracted %i bright star stamp%s. Excluded stars: insufficient area (%i), PSF fit failure (%i).",
+ len(stamps),
+ "" if len(stamps) == 1 else "s",
+ np.sum(np.array(goodFracs) < self.config.minAreaFraction),
+ (
+ np.sum(np.isnan([x.get("pedestal", np.nan) for x in stamps_fitPsfResults]))
+ if self.config.doFitPsf
+ else 0
+ ),
+ )
+ brightStarStamps = BrightStarStamps(stamps)
+ return Struct(brightStarStamps=brightStarStamps)
+
+ def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table:
+ """Get a bright star subset of the reference catalog.
+
+ Trim the reference catalog to only those objects within the exposure
+ bounding box dilated by half the bright star stamp size.
+ This ensures all stars that overlap the exposure are included.
+
+ Parameters
+ ----------
+ refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`
+ Loader to find objects within a reference catalog.
+ wcs : `~lsst.afw.geom.SkyWcs`
+ World coordinate system.
+ bbox : `~lsst.geom.Box2I`
+ Bounding box of the exposure.
+
+ Returns
+ -------
+ refCatBright : `~astropy.table.Table`
+ Bright star subset of the reference catalog.
+ """
+ dilatedBBox = bbox.dilatedBy(self.paddedStampRadius)
+ withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean")
+ refCatFull = withinExposure.refCat
+ fluxField: str = withinExposure.fluxField
+
+ proxFluxLimit = ((self.config.excludeMagLimit * u.ABmag).to(u.nJy)).to_value()
+ brightFluxLimit = ((self.config.magLimit * u.ABmag).to(u.nJy)).to_value()
+
+ subsetStars = refCatFull[fluxField] > np.min((proxFluxLimit, brightFluxLimit))
+ refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars))
+ proxStars = refCatSubset[fluxField] >= proxFluxLimit
+ brightStars = refCatSubset[fluxField] >= brightFluxLimit
+
+ coords = SkyCoord(refCatSubset["coord_ra"], refCatSubset["coord_dec"], unit="rad")
+ excludeArcsecRadius = self.config.excludeArcsecRadius * u.arcsec # type: ignore
+ refCatBrightIsolated = []
+ for coord in cast(Iterable[SkyCoord], coords[brightStars]):
+ neighbors = coords[proxStars]
+ seps = coord.separation(neighbors).to(u.arcsec)
+ tooClose = (seps > 0) & (seps <= excludeArcsecRadius) # not self matched
+ refCatBrightIsolated.append(not tooClose.any())
+
+ refCatBright = cast(Table, refCatSubset[brightStars][refCatBrightIsolated])
+
+ fluxNanojansky = refCatBright[fluxField][:] * u.nJy # type: ignore
+ refCatBright["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes
+
+ self.log.info(
+ "Identified %i of %i star%s which overlap%s the frame, %s brighter than %s mag, and %s no nearby "
+ "neighbors.",
+ len(refCatBright),
+ len(refCatFull),
+ "" if len(refCatFull) == 1 else "s",
+ "s" if len(refCatBright) == 1 else "",
+ "is" if len(refCatBright) == 1 else "are",
+ self.config.magLimit,
+ "has" if len(refCatBright) == 1 else "have",
+ )
+
+ return refCatBright
+
+ def _associateFootprints(
+ self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str
+ ) -> tuple[list[Footprint], dict[int, int]]:
+ """Associate footprints from a given mask plane with specific objects.
+
+ Footprints from the given mask plane are associated with objects at the
+ coordinates provided, where possible.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The input exposure with a mask plane.
+ pixCoords : `list` [`~lsst.geom.Point2D`]
+ The pixel coordinates of the objects.
+ plane : `str`
+ The mask plane used to identify masked pixels.
+
+ Returns
+ -------
+ footprints : `list` [`~lsst.afw.detection.Footprint`]
+ The footprints from the input exposure.
+ associations : `dict`[int, int]
+ Association indices between objects (key) and footprints (value).
+ """
+ detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK)
+ footprintSet = FootprintSet(inputExposure.mask, detThreshold)
+ footprints = footprintSet.getFootprints()
+ associations = {}
+ for starIndex, pixCoord in enumerate(pixCoords):
+ for footprintIndex, footprint in enumerate(footprints):
+ if footprint.contains(Point2I(pixCoord)):
+ associations[starIndex] = footprintIndex
+ break
+ self.log.debug(
+ "Associated %i of %i star%s to one each of the %i %s footprint%s.",
+ len(associations),
+ len(pixCoords),
+ "" if len(pixCoords) == 1 else "s",
+ len(footprints),
+ plane,
+ "" if len(footprints) == 1 else "s",
+ )
+ return footprints, associations
+
+ def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str):
+ """Set footprints in a given mask plane.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The input exposure to modify.
+ footprints : `list` [`~lsst.afw.detection.Footprint`]
+ The footprints to set in the mask plane.
+ maskPlane : `str`
+ The mask plane to set the footprints in.
+
+ Notes
+ -----
+ This method modifies the ``inputExposure`` object in-place.
+ """
+ detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK)
+ detThresholdValue = int(detThreshold.getValue())
+ footprintSet = FootprintSet(inputExposure.mask, detThreshold)
+
+ # Wipe any existing footprints in the mask plane
+ inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue)))
+
+ # Set the footprints in the mask plane
+ footprintSet.setFootprints(footprints)
+ footprintSet.setMask(inputExposure.mask, maskPlane)
+
+ def _fitPsf(self, stampMI: MaskedImageF, psf: KernelPsf) -> dict[str, Any]:
+ """Fit a scaled PSF and a pedestal to each bright star cutout.
+
+ Parameters
+ ----------
+ stampMI : `~lsst.afw.image.MaskedImageF`
+ The masked image of the bright star cutout.
+ psf : `~lsst.meas.algorithms.KernelPsf`
+ The PSF model to fit.
+
+ Returns
+ -------
+ fitPsfResults : `dict`[`str`, `float`]
+ The result of the PSF fitting, with keys:
+
+ ``scale`` : `float`
+ The scale factor.
+ ``scaleErr`` : `float`
+ The error on the scale factor.
+ ``pedestal`` : `float`
+ The pedestal value.
+ ``pedestalErr`` : `float`
+ The error on the pedestal value.
+ ``pedestalScaleCov`` : `float`
+ The covariance between the pedestal and scale factor.
+ ``xGradient`` : `float`
+ The gradient in the x-direction.
+ ``yGradient`` : `float`
+ The gradient in the y-direction.
+ ``globalReducedChiSquared`` : `float`
+ The global reduced chi-squared goodness-of-fit.
+ ``globalDegreesOfFreedom`` : `int`
+ The global number of degrees of freedom.
+ ``psfReducedChiSquared`` : `float`
+ The PSF BBox reduced chi-squared goodness-of-fit.
+ ``psfDegreesOfFreedom`` : `int`
+ The PSF BBox number of degrees of freedom.
+ ``psfMaskedFluxFrac`` : `float`
+ The fraction of the PSF image flux masked by bad pixels.
+ """
+ psfImage = psf.computeKernelImage(psf.getAveragePosition())
+ badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes)
+
+ # Calculate the fraction of the PSF image flux masked by bad pixels
+ psfMaskedPixels = ImageF(psfImage.getBBox())
+ psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool)
+ psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat)
+ if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold:
+ return {} # Handle cases where the PSF image is mostly masked
+
+ # Create a padded version of the input constant PSF image
+ paddedPsfImage = ImageF(stampMI.getBBox())
+ paddedPsfImage[psfImage.getBBox()] = psfImage.convertF()
+
+ # Create consistently masked data
+ badSpans = SpanSet.fromMask(stampMI.mask, badMaskBitMask)
+ goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans)
+ varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0())
+ if self.config.useMedianVariance:
+ varianceData = np.median(varianceData)
+ sigmaData = np.sqrt(varianceData)
+ imageData = goodSpans.flatten(stampMI.image.array, stampMI.getXY0()) # B
+ imageData /= sigmaData
+ psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0())
+ psfData /= sigmaData
+
+ # Fit the PSF scale factor and global pedestal
+ nData = len(imageData)
+ coefficientMatrix = np.ones((nData, 4), dtype=float) # A
+ coefficientMatrix[:, 0] = psfData
+ coefficientMatrix[:, 1] /= sigmaData
+ coefficientMatrix[:, 2:] = goodSpans.indices().T
+ coefficientMatrix[:, 2] /= sigmaData
+ coefficientMatrix[:, 3] /= sigmaData
+ try:
+ solutions, sumSquaredResiduals, *_ = np.linalg.lstsq(coefficientMatrix, imageData, rcond=None)
+ covarianceMatrix = np.linalg.inv(np.dot(coefficientMatrix.transpose(), coefficientMatrix)) # C
+ except np.linalg.LinAlgError:
+ return {} # Handle singular matrix errors
+ if sumSquaredResiduals.size == 0:
+ return {} # Handle cases where sum of the squared residuals are empty
+ scale = solutions[0]
+ if scale <= 0:
+ return {} # Handle cases where the PSF scale fit has failed
+ scaleErr = np.sqrt(covarianceMatrix[0, 0])
+ pedestal = solutions[1]
+ pedestalErr = np.sqrt(covarianceMatrix[1, 1])
+ scalePedestalCov = covarianceMatrix[0, 1]
+ xGradient = solutions[3]
+ yGradient = solutions[2]
+
+ # Calculate global (whole image) reduced chi-squared
+ globalChiSquared = np.sum(sumSquaredResiduals)
+ globalDegreesOfFreedom = nData - 4
+ globalReducedChiSquared = globalChiSquared / globalDegreesOfFreedom
+
+ # Calculate PSF BBox reduced chi-squared
+ psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox())
+ psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices()
+ psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0())
+ psfBBoxModel = (
+ psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale
+ + pedestal
+ + psfBBoxGoodSpansX * xGradient
+ + psfBBoxGoodSpansY * yGradient
+ )
+ psfBBoxVariance = psfBBoxGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0())
+ psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 / psfBBoxVariance
+ psfBBoxChiSquared = np.sum(psfBBoxResiduals)
+ psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4
+ psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom
+
+ return dict(
+ scale=scale,
+ scaleErr=scaleErr,
+ pedestal=pedestal,
+ pedestalErr=pedestalErr,
+ xGradient=xGradient,
+ yGradient=yGradient,
+ pedestalScaleCov=scalePedestalCov,
+ globalReducedChiSquared=globalReducedChiSquared,
+ globalDegreesOfFreedom=globalDegreesOfFreedom,
+ psfReducedChiSquared=psfBBoxReducedChiSquared,
+ psfDegreesOfFreedom=psfBBoxDegreesOfFreedom,
+ psfMaskedFluxFrac=psfMaskedFluxFrac,
+ )
diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py
new file mode 100644
index 000000000..dfa6be59c
--- /dev/null
+++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py
@@ -0,0 +1,205 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Stack bright star postage stamp cutouts to produce an extended PSF model."""
+
+__all__ = ["BrightStarStackConnections", "BrightStarStackConfig", "BrightStarStackTask"]
+
+# from typing import Any, Iterable, cast
+
+# import astropy.units as u
+# import numpy as np
+
+import numpy as np
+from lsst.afw.image import ImageF
+from lsst.geom import Box2I, Point2I
+from lsst.meas.algorithms import BrightStarStamps
+from lsst.pex.config import Field, ListField
+from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections
+from lsst.pipe.base.connectionTypes import Input, Output
+from lsst.utils.timer import timeMethod
+
+NEIGHBOR_MASK_PLANE = "NEIGHBOR"
+
+
+class BrightStarStackConnections(
+ PipelineTaskConnections,
+ dimensions=("instrument", "detector"),
+):
+ """Connections for BrightStarStackTask."""
+
+ brightStarStamps = Input(
+ name="brightStarStamps",
+ storageClass="BrightStarStamps",
+ doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.",
+ dimensions=("visit", "detector"),
+ multiple=True,
+ deferLoad=True,
+ )
+ extendedPsf = Output(
+ name="extendedPsf2", # extendedPsfDetector ???
+ storageClass="ExtendedPsf", # MaskedImageF
+ doc="Extended PSF model, built from stacking bright star cutouts.",
+ dimensions=("band",),
+ )
+
+
+class BrightStarStackConfig(
+ PipelineTaskConfig,
+ pipelineConnections=BrightStarStackConnections,
+):
+ """Configuration parameters for BrightStarStackTask."""
+
+ numVisitStack = Field[int](
+ doc="Number of visits to stack for each detector.",
+ default=5,
+ )
+ reducedChiSquaredThreshold = Field[float](
+ doc="Threshold for reduced chi-squared value of bright star cutouts.",
+ default=2.0,
+ )
+ badMaskPlanes = ListField[str](
+ doc="Mask planes that identify excluded (masked) pixels.",
+ default=[
+ "BAD",
+ "CR",
+ "CROSSTALK",
+ "EDGE",
+ "NO_DATA",
+ # "SAT",
+ # "SUSPECT",
+ "UNMASKEDNAN",
+ NEIGHBOR_MASK_PLANE,
+ ],
+ )
+
+
+class BrightStarStackTask(PipelineTask):
+ """Stack bright star postage stamps to produce an extended PSF model."""
+
+ ConfigClass = BrightStarStackConfig
+ _DefaultName = "brightStarStack"
+ config: BrightStarStackConfig
+
+ def __init__(self, initInputs=None, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def runQuantum(self, butlerQC, inputRefs, outputRefs):
+ inputs = butlerQC.get(inputRefs)
+ output = self.run(**inputs)
+ butlerQC.put(output, outputRefs)
+
+ @timeMethod
+ def run(
+ self,
+ brightStarStamps: BrightStarStamps,
+ ):
+ """Identify bright stars within an exposure using a reference catalog,
+ extract stamps around each, then preprocess them.
+
+ Bright star preprocessing steps are: shifting, warping and potentially
+ rotating them to the same pixel grid; computing their annular flux,
+ and; normalizing them.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The image from which bright star stamps should be extracted.
+ inputBackground : `~lsst.afw.image.Background`
+ The background model for the input exposure.
+ refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional
+ Loader to find objects within a reference catalog.
+ dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
+ The dataId of the exposure (including detector) that bright stars
+ should be extracted from.
+
+ Returns
+ -------
+ brightStarResults : `~lsst.pipe.base.Struct`
+ Results as a struct with attributes:
+
+ ``brightStarStamps``
+ (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`)
+ """
+ extendedPsfMI = None # masked image
+ extendedPsfWI = None # weight image
+ for stampsDDH in brightStarStamps:
+ stamps = stampsDDH.get()
+ if not stamps:
+ continue
+
+ for stamp in stamps:
+ stampMI = stamp.maskedImage
+ stampBBox = stampMI.getBBox()
+
+ # Apply fitted components
+ stampMI -= stamp.pedestal
+ xGrid, yGrid = np.meshgrid(stampBBox.getX().arange(), stampBBox.getY().arange())
+ xPlane = ImageF((xGrid * stamp.xGradient).astype(np.float32), xy0=stampMI.getXY0())
+ yPlane = ImageF((yGrid * stamp.yGradient).astype(np.float32), xy0=stampMI.getXY0())
+ stampMI -= xPlane
+ stampMI -= yPlane
+ stampMI *= stamp.scale
+
+ badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes)
+ stampMask = (stampMI.mask.array & badMaskBitMask).astype(bool)
+ stampMI.image.array[stampMask] = 0
+ stampMI.variance.array[stampMask] = 0
+
+ if stamp.psfReducedChiSquared > 5:
+ continue
+ stampWeight = 1 / stamp.psfReducedChiSquared
+ stampMI *= stampWeight
+ stampWI = ImageF(stampBBox, stampWeight)
+ stampWI.array[stampMask] = 0
+
+ if not extendedPsfMI:
+ extendedPsfMI = stampMI.clone()
+ extendedPsfWI = stampWI.clone()
+ else:
+ extendedPsfMI += stampMI
+ extendedPsfWI += stampWI
+
+ if extendedPsfMI:
+ extendedPsfMI /= extendedPsfWI
+ breakpoint()
+
+ # stack = []
+ # chiStack = []
+ # for loop over all groups:
+ # load up all visits for this detector
+ # drop all with GOF > thresh
+ # sigma-clip mean stack the rest
+ # append to stack
+ # compute the scatter (MAD/sigma-clipped var, etc) of the rest
+ # divide by sqrt(var plane), and append to chiStack
+ # after for-loop, combine images in median stack for final result
+ # also combine chi-images, save separately
+
+ # idea: run with two different thresholds, and compare the results
+
+ # medianStack = []
+ # for loop over all groups:
+ # load up all visits for this detector
+ # drop all with GOF > thresh
+ # median/sigma-clip stack the rest
+ # append to medianStack
+ # after for-loop, combine images in median stack for final result
diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarSubtract.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarSubtract.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_brightStarCutout.py b/tests/test_brightStarCutout.py
new file mode 100644
index 000000000..67b88d02f
--- /dev/null
+++ b/tests/test_brightStarCutout.py
@@ -0,0 +1,102 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import unittest
+
+import lsst.afw.cameraGeom.testUtils
+import lsst.afw.image
+import lsst.utils.tests
+import numpy as np
+from lsst.afw.image import ImageD, ImageF, MaskedImageF
+from lsst.afw.math import FixedKernel
+from lsst.geom import Point2I
+from lsst.meas.algorithms import KernelPsf
+from lsst.pipe.tasks.brightStarSubtraction import BrightStarCutoutConfig, BrightStarCutoutTask
+
+
+class BrightStarCutoutTestCase(lsst.utils.tests.TestCase):
+ def setUp(self):
+ # Fit values
+ self.scale = 2.34e5
+ self.pedestal = 3210.1
+ self.xGradient = 5.432
+ self.yGradient = 10.987
+
+ # Create a pedestal + 2D plane
+ xCoords = np.linspace(-50, 50, 101)
+ yCoords = np.linspace(-50, 50, 101)
+ xPlane, yPlane = np.meshgrid(xCoords, yCoords)
+ pedestal = np.ones_like(xPlane) * self.pedestal
+
+ # Create a pseudo-PSF
+ dist_from_center = np.sqrt(xPlane**2 + yPlane**2)
+ psfArray = np.exp(-dist_from_center / 5)
+ psfArray /= np.sum(psfArray)
+ fixedKernel = FixedKernel(ImageD(psfArray))
+ self.psf = KernelPsf(fixedKernel)
+
+ # Bring everything together to construct a stamp masked image
+ stampArray = psfArray * self.scale + pedestal + xPlane * self.xGradient + yPlane * self.yGradient
+ stampIm = ImageF((stampArray).astype(np.float32))
+ stampVa = ImageF(stampIm.getBBox(), 654.321)
+ self.stampMI = MaskedImageF(image=stampIm, variance=stampVa)
+ self.stampMI.setXY0(Point2I(-50, -50))
+
+ # Ensure that all mask planes required by the task are in-place;
+ # new mask plane entries will be created as necessary
+ badMaskPlanes = [
+ "BAD",
+ "CR",
+ "CROSSTALK",
+ "EDGE",
+ "NO_DATA",
+ "SAT",
+ "SUSPECT",
+ "UNMASKEDNAN",
+ "NEIGHBOR",
+ ]
+ _ = [self.stampMI.mask.addMaskPlane(mask) for mask in badMaskPlanes]
+
+ def test_fitPsf(self):
+ """Test the PSF fitting method."""
+ brightStarCutoutConfig = BrightStarCutoutConfig()
+ brightStarCutoutTask = BrightStarCutoutTask(config=brightStarCutoutConfig)
+ fitPsfResults = brightStarCutoutTask._fitPsf(
+ self.stampMI,
+ self.psf,
+ )
+ self.assertAlmostEqual(fitPsfResults["scale"], self.scale, delta=1e-3)
+ self.assertAlmostEqual(fitPsfResults["pedestal"], self.pedestal, delta=1e-5)
+ self.assertAlmostEqual(fitPsfResults["xGradient"], self.xGradient, delta=1e-7)
+ self.assertAlmostEqual(fitPsfResults["yGradient"], self.yGradient, delta=1e-7)
+
+
+def setup_module(module):
+ lsst.utils.tests.init()
+
+
+class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ lsst.utils.tests.init()
+ unittest.main()