From b396ffde362a265352633bbd3badc7d7ebe56ae5 Mon Sep 17 00:00:00 2001 From: "Amir E. Bazkiaei" Date: Sun, 25 Jun 2023 16:07:21 +1000 Subject: [PATCH 1/2] Subtract missing stars from bright star stamps --- python/lsst/pipe/tasks/processBrightStars.py | 33 +- python/lsst/pipe/tasks/subtractBrightStars.py | 470 ++++++++++++++++-- 2 files changed, 460 insertions(+), 43 deletions(-) diff --git a/python/lsst/pipe/tasks/processBrightStars.py b/python/lsst/pipe/tasks/processBrightStars.py index e1f9fdfc2..939be90c4 100644 --- a/python/lsst/pipe/tasks/processBrightStars.py +++ b/python/lsst/pipe/tasks/processBrightStars.py @@ -223,6 +223,10 @@ def __init__(self, initInputs=None, *args, **kwargs): self.modelStampSize[1] += 1 # central pixel self.modelCenter = self.modelStampSize[0] // 2, self.modelStampSize[1] // 2 + self.setModelStamp() + # configure Gaia refcat + if butler is not None: + self.makeSubtask("refObjLoader", butler=butler) def applySkyCorr(self, calexp, skyCorr): """Apply correction to the sky background level. @@ -246,7 +250,7 @@ def applySkyCorr(self, calexp, skyCorr): calexp = calexp.getMaskedImage() calexp -= skyCorr.getImage() - def extractStamps(self, inputExposure, refObjLoader=None): + def extractStamps(self, inputExposure, refObjLoader=None, inputBrightStarStamps=None): """Read the position of bright stars within an input exposure using a refCat and extract them. @@ -296,9 +300,16 @@ def extractStamps(self, inputExposure, refObjLoader=None): GFluxes = np.array(refCat["phot_g_mean_flux"]) bright = GFluxes > fluxLimit # convert to AB magnitudes - allGMags = [((gFlux * u.nJy).to(u.ABmag)).to_value() for gFlux in GFluxes[bright]] + allGMags = np.array([((gFlux * u.nJy).to(u.ABmag)).to_value() for gFlux in GFluxes[bright]]) allIds = refCat.columns.extract("id", where=bright)["id"] selectedColumns = refCat.columns.extract("coord_ra", "coord_dec", where=bright) + if inputBrightStarStamps is not None: + existings = np.array(inputBrightStarStamps.getGaiaIds()) + existed = np.isin(allIds, existings) + allGMags = allGMags[~existed] + allIds = allIds[~existed] + selectedColumns["coord_ra"] = selectedColumns["coord_ra"][~existed] + selectedColumns["coord_dec"] = selectedColumns["coord_dec"][~existed] for j, (ra, dec) in enumerate(zip(selectedColumns["coord_ra"], selectedColumns["coord_dec"])): sp = SpherePoint(ra, dec, radians) cpix = wcs.skyToPixel(sp) @@ -446,6 +457,19 @@ def warpStamps(self, stamps, pixCenters): warpTransforms.append(starWarper) return Struct(warpedStars=warpedStars, warpTransforms=warpTransforms, xy0s=xy0s, nb90Rots=nb90Rots) + def setModelStamp(self): + self.modelStampSize = [ + int(self.config.stampSize[0] * self.config.modelStampBuffer), + int(self.config.stampSize[1] * self.config.modelStampBuffer), + ] + # force it to be odd-sized so we have a central pixel + if not self.modelStampSize[0] % 2: + self.modelStampSize[0] += 1 + if not self.modelStampSize[1] % 2: + self.modelStampSize[1] += 1 + # central pixel + self.modelCenter = self.modelStampSize[0] // 2, self.modelStampSize[1] // 2 + @timeMethod def run(self, inputExposure, refObjLoader=None, dataId=None, skyCorr=None): """Identify bright stars within an exposure using a reference catalog, @@ -528,6 +552,10 @@ def run(self, inputExposure, refObjLoader=None, dataId=None, skyCorr=None): badMaskPlanes=self.config.badMaskPlanes, discardNanFluxObjects=(self.config.discardNanFluxStars), ) + # Dont create empty fits files if there is no normalized stamp! + if not len(brightStarStamps._stamps) > 0: + self.log.info("No normalized stamps exists for this exposure!") + return None return Struct(brightStarStamps=brightStarStamps) def runQuantum(self, butlerQC, inputRefs, outputRefs): @@ -540,5 +568,6 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): config=self.config.refObjLoader, ) output = self.run(**inputs, refObjLoader=refObjLoader) + # This if block prevents the code to produce an emtpy fits file in case there is no stamp. if output: butlerQC.put(output, outputRefs) diff --git a/python/lsst/pipe/tasks/subtractBrightStars.py b/python/lsst/pipe/tasks/subtractBrightStars.py index 7e5bf69ef..2d18f2805 100644 --- a/python/lsst/pipe/tasks/subtractBrightStars.py +++ b/python/lsst/pipe/tasks/subtractBrightStars.py @@ -23,11 +23,13 @@ __all__ = ["SubtractBrightStarsConnections", "SubtractBrightStarsConfig", "SubtractBrightStarsTask"] +import logging from functools import reduce from operator import ior import numpy as np from lsst.afw.image import Exposure, ExposureF, MaskedImageF +from lsst.afw.geom import SpanSet, Stencil from lsst.afw.math import ( StatisticsControl, WarpingControl, @@ -37,15 +39,22 @@ warpImage, ) from lsst.geom import Box2I, Point2D, Point2I -from lsst.pex.config import ChoiceField, Field, ListField +from lsst.meas.algorithms import LoadReferenceObjectsConfig, ReferenceObjectLoader +from lsst.meas.algorithms.brightStarStamps import BrightStarStamp, BrightStarStamps +from lsst.pex.config import ChoiceField, ConfigField, Field, ListField from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct -from lsst.pipe.base.connectionTypes import Input, Output +from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput +from lsst.pipe.tasks.processBrightStars import ProcessBrightStarsTask + +logger = logging.getLogger(__name__) class SubtractBrightStarsConnections( PipelineTaskConnections, dimensions=("instrument", "visit", "detector"), - defaultTemplates={"outputExposureName": "brightStar_subtracted", "outputBackgroundName": "brightStars"}, + defaultTemplates={"outputExposureName": "brightStar_subtracted", + "outputBackgroundName": "brightStars", + "badStampsName": "brightStars"}, ): inputExposure = Input( doc="Input exposure from which to subtract bright star stamps.", @@ -81,6 +90,14 @@ class SubtractBrightStarsConnections( "detector", ), ) + refCat = PrerequisiteInput( + doc="Reference catalog that contains bright star positions", + name="gaia_dr2_20200414", + storageClass="SimpleCatalog", + dimensions=("skypix",), + multiple=True, + deferLoad=True, + ) outputExposure = Output( doc="Exposure with bright stars subtracted.", name="{outputExposureName}_calexp", @@ -99,6 +116,15 @@ class SubtractBrightStarsConnections( "detector", ), ) + outputBadStamps = Output( + doc="The stamps the are not normalized and consequently not subtracted from the exposure.", + name="{badStampsName}_unsubtracted_stapms", + storageClass="BrightStarStamps", + dimensions=( + "visit", + "detector", + ), + ) def __init__(self, *, config=None): super().__init__(config=config) @@ -124,6 +150,22 @@ class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=Subtract doc="Magnitude limit, in Gaia G; all stars brighter than this value will be subtracted", default=18, ) + minValidAnnulusFraction = Field( + dtype=float, + doc="Minumum number of valid pixels that must fall within the annulus for the bright star to be " + "saved for subsequent generation of a PSF.", + default=0.0, + ) + numSigmaClip = Field( + dtype=float, + doc="Sigma for outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", + default=4, + ) + numIter = Field( + dtype=int, + doc="Number of iterations of outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", + default=3, + ) warpingKernelName = ChoiceField[str]( dtype=str, doc="Warping kernel", @@ -148,6 +190,16 @@ class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=Subtract "leastSquare": "find least square scaling factor", }, ) + annularFluxStatistic = ChoiceField( + dtype=str, + doc="Type of statistic to use to compute annular flux.", + default="MEANCLIP", + allowed={ + "MEAN": "mean", + "MEDIAN": "median", + "MEANCLIP": "clipped mean", + }, + ) badMaskPlanes = ListField[str]( dtype=str, doc="Mask planes that, if set, lead to associated pixels not being included in the computation of " @@ -158,11 +210,28 @@ class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=Subtract # interest) also get set to `BAD`. default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"), ) + subtractionBox = ListField( + dtype=int, + doc="Size of the stamps to be extracted, in pixels.", + default=(250, 250), + ) + subtractionBoxBuffer = Field( + dtype=float, + doc=( + "'Buffer' factor to be applied to determine the size of the stamp the processed stars will be " + "saved in. This will also be the size of the extended PSF model." + ), + default=1.1, + ) doApplySkyCorr = Field[bool]( dtype=bool, doc="Apply full focal plane sky correction before extracting stars?", default=True, ) + refObjLoader = ConfigField( + dtype=LoadReferenceObjectsConfig, + doc="Reference object loader for astrometric calibration.", + ) class SubtractBrightStarsTask(PipelineTask): @@ -182,6 +251,8 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Placeholders to set up Statistics if scalingType is leastSquare. self.statsControl, self.statsFlag = None, None + # warping control; only contains shiftingALg provided in config + self.warpCont = WarpingControl(self.config.warpingKernelName) def _setUpStatistics(self, exampleMask): """Configure statistics control and flag, for use if ``scalingType`` is @@ -213,7 +284,7 @@ def applySkyCorr(self, calexp, skyCorr): calexp = calexp.getMaskedImage() calexp -= skyCorr.getImage() - def scaleModel(self, model, star, inPlace=True, nb90Rots=0): + def scaleModel(self, model, star, inPlace=True, nb90Rots=0, psf_annular_flux=None): """Compute scaling factor to be applied to the extended PSF so that its amplitude matches that of an individual star. @@ -235,8 +306,10 @@ def scaleModel(self, model, star, inPlace=True, nb90Rots=0): The factor by which the model image should be multiplied for it to be scaled to the input bright star. """ + if psf_annular_flux is None: + psf_annular_flux = 1 if self.config.scalingType == "annularFlux": - scalingFactor = star.annularFlux + scalingFactor = star.annularFlux * psf_annular_flux elif self.config.scalingType == "leastSquare": if self.statsControl is None: self._setUpStatistics(star.stamp_im.mask) @@ -259,21 +332,331 @@ def scaleModel(self, model, star, inPlace=True, nb90Rots=0): model.image *= scalingFactor return scalingFactor + def _overRideWarperConfig(self): + """Override the warper config with the config of this task. + """ + self.warper.config.minValidAnnulusFraction = self.config.minValidAnnulusFraction + self.warper.config.numSigmaClip = self.config.numSigmaClip + self.warper.config.numIter = self.config.numIter + self.warper.config.annularFluxStatistic = self.config.annularFluxStatistic + self.warper.config.badMaskPlanes = self.config.badMaskPlanes + self.warper.config.stampSize = self.config.subtractionBox + self.warper.modelStampBuffer = self.config.subtractionBoxBuffer + self.warper.setModelStamp() + + def setMissedStarsStatsControl(self): + """Configure statistics control for processing missing stars from inputBrightStarStamps. + """ + self.missedStatsControl = StatisticsControl() + self.missedStatsControl.setNumSigmaClip(self.warper.config.numSigmaClip) + self.missedStatsControl.setNumIter(self.warper.config.numIter) + self.missedStatsFlag = stringToStatisticsProperty(self.warper.config.annularFluxStatistic) + + def setWarpTask(self): + """Create an instance of ProcessBrightStarsTask that will be used to produce stamps of stars to be + subtracted. + """ + self.warper = ProcessBrightStarsTask() + self._overRideWarperConfig() + self.warper.modelCenter = self.modelStampSize[0] // 2, self.modelStampSize[1] // 2 + + def makeBrightStarList(self, inputBrightStarStamps, inputExposure, refObjLoader): + """Make a list of bright stars that are missing from inputBrightStarStamps to be subtracted. + + Parameters + ---------- + inputBrightStarStamps : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` + Set of stamps centered on each bright star to be subtracted, produced by running + `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`. + inputExposure : `~lsst.afw.image.ExposureF` + The image from which bright stars should be subtracted. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + + Returns + ------- + brightStarList: + A list containing `lsst.meas.algorithms.brightStarStamps.BrightStarStamp` of stars to be + subtracted. + """ + self.setWarpTask() + missedStars = self.warper.extractStamps( + inputExposure, refObjLoader=refObjLoader, inputBrightStarStamps=inputBrightStarStamps + ) + self.warpOutputs = self.warper.warpStamps(missedStars.starIms, missedStars.pixCenters) + brightStarList = [ + BrightStarStamp( + stamp_im=warp, + archive_element=transform, + position=self.warpOutputs.xy0s[j], + gaiaGMag=missedStars.GMags[j], + gaiaId=missedStars.gaiaIds[j], + minValidAnnulusFraction=self.warper.config.minValidAnnulusFraction, + ) + for j, (warp, transform) in enumerate(zip(self.warpOutputs.warpedStars, + self.warspOutputs.warpTransforms)) + ] + return brightStarList + + def initAnnulusImage(self): + """Initialize an annulus image of the given star. + + Returns + ------- + annulusImage : `~lsst.afw.image.MaskedImageF` + The initialized annulus image. + """ + maskPlaneDict = self.model.mask.getMaskPlaneDict() + annulusImage = MaskedImageF(self.modelStampSize, planeDict=maskPlaneDict) + annulusImage.mask.array[:] = 2 ** maskPlaneDict["NO_DATA"] + return annulusImage + + def createAnnulus(self, brightStarStamp): + """Create an annulus of the given star. + + Parameters + ---------- + brightStarStamp : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` + A stamp of a bright star to be subtracted. + + Returns + ------- + annulus : `~lsst.afw.image.MaskedImageF` + An annulus of the given star. + """ + # Create SpanSet of annulus + outerCircle = SpanSet.fromShape( + brightStarStamp.optimalOuterRadius, Stencil.CIRCLE, offset=self.warper.modelCenter + ) + innerCircle = SpanSet.fromShape( + brightStarStamp.optimalInnerRadius, Stencil.CIRCLE, offset=self.warper.modelCenter + ) + annulus = outerCircle.intersectNot(innerCircle) + return annulus + + def applyStatsControl(self, annulusImage): + """Apply statistics control to the PSF annulus image. + + Parameters + ---------- + annulusImage : `~lsst.afw.image.MaskedImageF` + An image containing an annulus of the given model. + + Returns + ------- + annularFlux: float + The annular flux of the PSF model at the radius where the flux of the given star is determined. + """ + andMask = reduce( + ior, (annulusImage.mask.getPlaneBitMask(bm) for bm in self.warper.config.badMaskPlanes) + ) + self.missedStatsControl.setAndMask(andMask) + annulusStat = makeStatistics(annulusImage, self.missedStatsFlag, self.missedStatsControl) + return annulusStat.getValue() + + def findPsfAnnularFlux(self, brightStarStamp, maskedModel): + """Find the annular flux of the PSF model at the radius where the flux of the given star is + determined. + + Parameters + ---------- + brightStarStamp : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` + A stamp of a bright star to be subtracted. + maskedModel : `~lsst.afw.image.MaskedImageF` + A masked image of the PSF model. + + Returns + ------- + annularFlux: float + The annular flux of the PSF model at the radius where the flux of the given star is determined. + """ + annulusImage = self.initAnnulusImage() + annulus = self.createAnnulus(brightStarStamp) + annulus.copyMaskedImage(maskedModel, annulusImage) + annularFlux = self.applyStatsControl(annulusImage) + return annularFlux + + def findPsfAnnularFluxes(self, brightStarStamps): + """Find the annular fluxes of the given PSF model. + + Parameters + ---------- + brightStarStamps : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` + The stamps of stars that will be subtracted from the exposure. + + Returns + ------- + PsfAnnularFluxes: numpy.array + A two column numpy.array containing annular fluxes of the PSF at radii where the flux for stars + exist (could be found). + + Notes + ----- + While the PSF model is normalized at a certain radius, the flux of a star at that radius might be + impossible to find. Therefore, we have to scale the PSF model considering a radius where the star has + an identified flux. To do that, the flux of the model should be found and used to adjust the scaling + step. + """ + outerRadii = [] + annularFluxes = [] + maskedModel = MaskedImageF(self.model.image) + # the model has wrong bbox values. Should be fixed in extended_psf.py? + maskedModel.setXY0(0, 0) + for star in brightStarStamps: + if star.optimalOuterRadius not in outerRadii: + annularFlux = self.findPsfAnnularFlux(star, maskedModel) + outerRadii.append(star.optimalOuterRadius) + annularFluxes.append(annularFlux) + return np.array([outerRadii, annularFluxes]).T + + def preparePlaneModelStamp(self, brightStarStamp): + """Prepare the PSF model before scaling. + + Parameters + ---------- + brightStarStamp : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` + The stamp of the star to which the PSF model will be scaled. + + Returns + ------- + bbox: `~lsst.geom.Box2I` + Contains the corner coordination and the dimensions of the model stamp. + + invImage: `~lsst.afw.image.MaskedImageF` + The extended PSF model, shifted (and potentially warped) to match the bright star's positioning. + + Raises + ------ + RuntimeError + Raised if warping of the model is failed. + """ + # Set the origin. + self.model.setXY0(brightStarStamp.position) + # Create an empty destination image. + invTransform = brightStarStamp.archive_element.inverted() + invOrigin = Point2I(invTransform.applyForward(Point2D(brightStarStamp.position))) + bbox = Box2I(corner=invOrigin, dimensions=self.modelStampSize) + invImage = MaskedImageF(bbox) + # Apply inverse transform. + goodPix = warpImage(invImage, self.model, invTransform, self.warpCont) + if not goodPix: + # Do we want to find another way or just subtract the non-warped scaled model? + # Currently the code just leaves the failed ones un-subtracted. + raise RuntimeError( + f"Warping of a model failed for star {brightStarStamp.gaiaId}: " "no good pixel in output" + ) + return bbox, invImage + + def addScaledModel(self, subtractor, brightStarStamp, multipleAnnuli=False): + """Add the scaled model of the given star to the subtractor plane. + + Parameters + ---------- + subtractor : `~lsst.afw.image.MaskedImageF` + The Exposure containing the scaled model of brigth stars to be subtracted from the input + exposure. + brightStarStamp : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` + The stamp of the star of which the PSF model will be scaled and added to the subtractor. + multipleAnnuli : bool, optional + If true, the model should be scaled based on a flux at a radius other than its normalization + radius. + + Returns + ------- + subtractor : `~lsst.afw.image.MaskedImageF` + The input subtractor Exposure with the added scaled model at the given star's location in the + exposure. + invImage: `~lsst.afw.image.MaskedImageF` + The extended PSF model, shifted (and potentially warped) to match the bright star's positioning. + """ + bbox, invImage = self.preparePlaneModelStamp(brightStarStamp) + if multipleAnnuli: + cond = self.psf_annular_fluxes[:, 0] == brightStarStamp.optimalOuterRadius + psf_annular_flux = self.psf_annular_fluxes[cond, 1][0] + self.scaleModel(invImage, + brightStarStamp, + inPlace=True, + nb90Rots=self.inv90Rots, + psf_annular_flux=psf_annular_flux) + else: + self.scaleModel(invImage, brightStarStamp, inPlace=True, nb90Rots=self.inv90Rots) + # Replace NaNs before subtraction (note all NaN pixels have + # the NO_DATA flag). + invImage.image.array[np.isnan(invImage.image.array)] = 0 + bbox.clip(self.inputExpBBox) + if bbox.getArea() > 0: + subtractor[bbox] += invImage[bbox] + return subtractor, invImage + + def buildSubtractor(self, brightStarStamps, subtractor, invImages, multipleAnnuli=False): + """Build an image containing potentially multiple scaled PSF models, each at the location of a given + brigth star. + + Parameters + ---------- + brightStarStamps : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` + Set of stamps centered on each bright star to be subtracted, produced by running + `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`. + subtractor : `~lsst.afw.image.MaskedImageF` + The Exposure that will contain the scaled model of brigth stars to be subtracted from the + exposure. + invImages : `list` + A list containing extended PSF models, shifted (and potentially warped) to match the bright stars + positionings. + multipleAnnuli : bool, optional + This will be passed to addScaledModel method, by default False. + + Returns + ------- + subtractor : `~lsst.afw.image.MaskedImageF` + An Exposure containing a scaled bright star model fit to every bright star profile; its image can + then be subtracted from the input exposure. + invImages: list + A list containing the extended PSF models, shifted (and potentially warped) to match bright + stars' positionings. + """ + for star in brightStarStamps: + if star.gaiaGMag < self.config.magLimit: + try: + # Adding the scaled model at the star location to the subtractor. + subtractor, invImage = self.addScaledModel(subtractor, star, multipleAnnuli) + invImages.append(invImage) + except RuntimeError as err: + logger.error(err) + return subtractor, invImages + def runQuantum(self, butlerQC, inputRefs, outputRefs): # Docstring inherited. inputs = butlerQC.get(inputRefs) dataId = butlerQC.quantum.dataId - subtractor, _ = self.run(**inputs, dataId=dataId) + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.refObjLoader, + ) + subtractor, _, badStamps = self.run(**inputs, dataId=dataId, refObjLoader=refObjLoader) if self.config.doWriteSubtractedExposure: outputExposure = inputs["inputExposure"].clone() outputExposure.image -= subtractor.image else: outputExposure = None outputBackgroundExposure = subtractor if self.config.doWriteSubtractor else None - output = Struct(outputExposure=outputExposure, outputBackgroundExposure=outputBackgroundExposure) + # in its current state, the code produces outputBadStamps which are the stamps of stars that have not + # been subtracted from the image for any reason. If all the stars are subtracted from the calexp, the + # output is an empty fits file. + output = Struct(outputExposure=outputExposure, + outputBackgroundExposure=outputBackgroundExposure, + outputBadStamps=badStamps) butlerQC.put(output, outputRefs) - def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, dataId, skyCorr=None): + def run(self, + inputExposure, + inputBrightStarStamps, + inputExtendedPsf, + dataId, + skyCorr=None, + refObjLoader=None): """Iterate over all bright stars in an exposure to scale the extended PSF model before subtracting bright stars. @@ -296,6 +679,8 @@ def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, dataId, sk Full focal plane sky correction, obtained by running `~lsst.pipe.tasks.skyCorrection.SkyCorrectionTask`. If `doApplySkyCorr` is set to `True`, `skyCorr` cannot be `None`. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. Returns ------- @@ -307,7 +692,7 @@ def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, dataId, sk A list of small images ("stamps") containing the model, each scaled to its corresponding input bright star. """ - inputExpBBox = inputExposure.getBBox() + self.inputExpBBox = inputExposure.getBBox() if self.config.doApplySkyCorr and (skyCorr is not None): self.log.info( "Applying sky correction to exposure %s (exposure will be modified in-place).", dataId @@ -318,36 +703,39 @@ def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, dataId, sk subtractorExp = ExposureF(bbox=inputExposure.getBBox()) subtractor = subtractorExp.maskedImage # Make a copy of the input model. - model = inputExtendedPsf(dataId["detector"]).clone() - modelStampSize = model.getDimensions() - inv90Rots = 4 - inputBrightStarStamps.nb90Rots % 4 - model = rotateImageBy90(model, inv90Rots) - warpCont = WarpingControl(self.config.warpingKernelName) + self.model = inputExtendedPsf(dataId["detector"]).clone() + self.modelStampSize = self.model.getDimensions() + self.inv90Rots = 4 - inputBrightStarStamps.nb90Rots % 4 + self.model = rotateImageBy90(self.model, self.inv90Rots) + + brightStarList = self.makeBrightStarList(inputBrightStarStamps, inputExposure, refObjLoader) + self.setMissedStarsStatsControl() + # This might change when we use multiple categories of stars for creating PSF. + innerRadius = inputBrightStarStamps._innerRadius + outerRadius = inputBrightStarStamps._outerRadius + brightStarStamps, badStamps = BrightStarStamps.initAndNormalize( + brightStarList, + innerRadius=innerRadius, + outerRadius=outerRadius, + nb90Rots=self.warpOutputs.nb90Rots, + imCenter=self.warper.modelCenter, + use_archive=True, + statsControl=self.missedStatsControl, + statsFlag=self.missedStatsFlag, + badMaskPlanes=self.warper.config.badMaskPlanes, + discardNanFluxObjects=False, + forceFindFlux=True, + ) + invImages = [] - # Loop over bright stars, computing the inverse transformed and scaled - # postage stamp for each. - for star in inputBrightStarStamps: - if star.gaiaGMag < self.config.magLimit: - # Set the origin. - model.setXY0(star.position) - # Create an empty destination image. - invTransform = star.archive_element.inverted() - invOrigin = Point2I(invTransform.applyForward(Point2D(star.position))) - bbox = Box2I(corner=invOrigin, dimensions=modelStampSize) - invImage = MaskedImageF(bbox) - # Apply inverse transform. - goodPix = warpImage(invImage, model, invTransform, warpCont) - if not goodPix: - self.log.debug( - f"Warping of a model failed for star {star.gaiaId}: " "no good pixel in output" - ) - # Scale the model. - self.scaleModel(invImage, star, inPlace=True, nb90Rots=inv90Rots) - # Replace NaNs before subtraction (note all NaN pixels have - # the NO_DATA flag). - invImage.image.array[np.isnan(invImage.image.array)] = 0 - bbox.clip(inputExpBBox) - if bbox.getArea() > 0: - subtractor[bbox] += invImage[bbox] - invImages.append(invImage) - return subtractorExp, invImages + subtractor, invImages = self.buildSubtractor( + inputBrightStarStamps, subtractor, invImages, multipleAnnuli=False + ) + if len(brightStarStamps) > 0: + self.psf_annular_fluxes = self.findPsfAnnularFluxes(brightStarStamps) + subtractor, invImages = self.buildSubtractor( + brightStarStamps, subtractor, invImages, multipleAnnuli=True + ) + badStamps = BrightStarStamps(badStamps) + + return subtractorExp, invImages, badStamps From f2254f6b31f7939f708fb4b35bfb69274b464e38 Mon Sep 17 00:00:00 2001 From: Lee Kelvin Date: Sun, 16 Jul 2023 22:44:51 -0700 Subject: [PATCH 2/2] Refactor by EB/LSK following review --- python/lsst/pipe/tasks/extended_psf.py | 28 +- python/lsst/pipe/tasks/processBrightStars.py | 564 ++++++++++-------- python/lsst/pipe/tasks/subtractBrightStars.py | 535 +++++++++-------- 3 files changed, 609 insertions(+), 518 deletions(-) diff --git a/python/lsst/pipe/tasks/extended_psf.py b/python/lsst/pipe/tasks/extended_psf.py index 48461a1fb..f2a65653e 100644 --- a/python/lsst/pipe/tasks/extended_psf.py +++ b/python/lsst/pipe/tasks/extended_psf.py @@ -274,13 +274,11 @@ def readFits(cls, filename): class StackBrightStarsConfig(Config): """Configuration parameters for StackBrightStarsTask.""" - subregion_size = ListField( - dtype=int, + subregion_size = ListField[int]( doc="Size, in pixels, of the subregions over which the stacking will be " "iteratively performed.", default=(100, 100), ) - stacking_statistic = ChoiceField( - dtype=str, + stacking_statistic = ChoiceField[str]( doc="Type of statistic to use for stacking.", default="MEANCLIP", allowed={ @@ -289,28 +287,23 @@ class StackBrightStarsConfig(Config): "MEANCLIP": "clipped mean", }, ) - num_sigma_clip = Field( - dtype=float, + num_sigma_clip = Field[float]( doc="Sigma for outlier rejection; ignored if stacking_statistic != 'MEANCLIP'.", default=4, ) - num_iter = Field( - dtype=int, + num_iter = Field[int]( doc="Number of iterations of outlier rejection; ignored if stackingStatistic != 'MEANCLIP'.", default=3, ) - bad_mask_planes = ListField( - dtype=str, + bad_mask_planes = ListField[str]( doc="Mask planes that define pixels to be excluded from the stacking of the bright star stamps.", default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"), ) - do_mag_cut = Field( - dtype=bool, + do_mag_cut = Field[bool]( doc="Apply magnitude cut before stacking?", default=False, ) - mag_limit = Field( - dtype=float, + mag_limit = Field[float]( doc="Magnitude limit, in Gaia G; all stars brighter than this value will be stacked", default=18, ) @@ -324,9 +317,10 @@ class StackBrightStarsTask(Task): def _set_up_stacking(self, example_stamp): """Configure stacking statistic and control from config fields.""" - stats_control = StatisticsControl() - stats_control.setNumSigmaClip(self.config.num_sigma_clip) - stats_control.setNumIter(self.config.num_iter) + stats_control = StatisticsControl( + numSigmaClip=self.config.num_sigma_clip, + numIter=self.config.num_iter, + ) if bad_masks := self.config.bad_mask_planes: and_mask = example_stamp.mask.getPlaneBitMask(bad_masks[0]) for bm in bad_masks[1:]: diff --git a/python/lsst/pipe/tasks/processBrightStars.py b/python/lsst/pipe/tasks/processBrightStars.py index 939be90c4..d7b48736b 100644 --- a/python/lsst/pipe/tasks/processBrightStars.py +++ b/python/lsst/pipe/tasks/processBrightStars.py @@ -21,10 +21,11 @@ """Extract bright star cutouts; normalize and warp to the same pixel grid.""" -__all__ = ["ProcessBrightStarsTask"] +__all__ = ["ProcessBrightStarsConnections", "ProcessBrightStarsConfig", "ProcessBrightStarsTask"] import astropy.units as u import numpy as np +from astropy.table import Table from lsst.afw.cameraGeom import PIXELS, TAN_PIXELS from lsst.afw.detection import FootprintSet, Threshold from lsst.afw.geom.transformFactory import makeIdentityTransform, makeTransform @@ -40,7 +41,6 @@ from lsst.meas.algorithms import LoadReferenceObjectsConfig, ReferenceObjectLoader from lsst.meas.algorithms.brightStarStamps import BrightStarStamp, BrightStarStamps from lsst.pex.config import ChoiceField, ConfigField, Field, ListField -from lsst.pex.exceptions import InvalidParameterError from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput from lsst.utils.timer import timeMethod @@ -50,13 +50,13 @@ class ProcessBrightStarsConnections(PipelineTaskConnections, dimensions=("instru """Connections for ProcessBrightStarsTask.""" inputExposure = Input( - doc="Input exposure from which to extract bright star stamps", + doc="Input exposure from which to extract bright star stamps.", name="calexp", storageClass="ExposureF", dimensions=("visit", "detector"), ) skyCorr = Input( - doc="Input Sky Correction to be subtracted from the calexp if doApplySkyCorr=True", + doc="Input sky correction to be subtracted from the calexp if doApplySkyCorr=True.", name="skyCorr", storageClass="Background", dimensions=("instrument", "visit", "detector"), @@ -85,37 +85,32 @@ def __init__(self, *, config=None): class ProcessBrightStarsConfig(PipelineTaskConfig, pipelineConnections=ProcessBrightStarsConnections): """Configuration parameters for ProcessBrightStarsTask.""" - magLimit = Field( - dtype=float, + magLimit = Field[float]( doc="Magnitude limit, in Gaia G; all stars brighter than this value will be processed.", default=18, ) - stampSize = ListField( - dtype=int, + stampSize = ListField[int]( doc="Size of the stamps to be extracted, in pixels.", default=(250, 250), ) - modelStampBuffer = Field( - dtype=float, + modelStampBuffer = Field[float]( doc=( "'Buffer' factor to be applied to determine the size of the stamp the processed stars will be " "saved in. This will also be the size of the extended PSF model." ), default=1.1, ) - doRemoveDetected = Field( - dtype=bool, - doc="Whether DETECTION footprints, other than that for the central object, should be changed to BAD.", + doRemoveDetected = Field[bool]( + doc="Whether secondary DETECTION footprints (i.e., footprints of objects other than the central " + "primary object) should be changed to BAD.", default=True, ) - doApplyTransform = Field( - dtype=bool, + doApplyTransform = Field[bool]( doc="Apply transform to bright star stamps to correct for optical distortions?", default=True, ) - warpingKernelName = ChoiceField( - dtype=str, - doc="Warping kernel", + warpingKernelName = ChoiceField[str]( + doc="Warping kernel.", default="lanczos5", allowed={ "bilinear": "bilinear interpolation", @@ -124,13 +119,11 @@ class ProcessBrightStarsConfig(PipelineTaskConfig, pipelineConnections=ProcessBr "lanczos5": "Lanczos kernel of order 5", }, ) - annularFluxRadii = ListField( - dtype=int, + annularFluxRadii = ListField[int]( doc="Inner and outer radii of the annulus used to compute AnnularFlux for normalization, in pixels.", default=(70, 80), ) - annularFluxStatistic = ChoiceField( - dtype=str, + annularFluxStatistic = ChoiceField[str]( doc="Type of statistic to use to compute annular flux.", default="MEANCLIP", allowed={ @@ -139,60 +132,42 @@ class ProcessBrightStarsConfig(PipelineTaskConfig, pipelineConnections=ProcessBr "MEANCLIP": "clipped mean", }, ) - numSigmaClip = Field( - dtype=float, + numSigmaClip = Field[float]( doc="Sigma for outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", default=4, ) - numIter = Field( - dtype=int, + numIter = Field[int]( doc="Number of iterations of outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", default=3, ) - badMaskPlanes = ListField( - dtype=str, + badMaskPlanes = ListField[str]( doc="Mask planes that identify pixels to not include in the computation of the annular flux.", default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"), ) - minValidAnnulusFraction = Field( - dtype=float, + minValidAnnulusFraction = Field[float]( doc="Minumum number of valid pixels that must fall within the annulus for the bright star to be " "saved for subsequent generation of a PSF.", default=0.0, ) - doApplySkyCorr = Field( - dtype=bool, + doApplySkyCorr = Field[bool]( doc="Apply full focal plane sky correction before extracting stars?", default=True, ) - discardNanFluxStars = Field( - dtype=bool, + discardNanFluxStars = Field[bool]( doc="Should stars with NaN annular flux be discarded?", default=False, ) - refObjLoader = ConfigField( - dtype=LoadReferenceObjectsConfig, + refObjLoader = ConfigField[LoadReferenceObjectsConfig]( doc="Reference object loader for astrometric calibration.", ) class ProcessBrightStarsTask(PipelineTask): - """The description of the parameters for this Task are detailed in - :lsst-task:`~lsst.pipe.base.PipelineTask`. - - Parameters - ---------- - initInputs : `Unknown` - *args - Additional positional arguments. - **kwargs - Additional keyword arguments. - - Notes - ----- - `ProcessBrightStarsTask` is used to extract, process, and store small - image cut-outs (or "postage stamps") around bright stars. It relies on - three methods, called in succession: + """Extract bright star cutouts; normalize and warp to the same pixel grid. + + This task is used to extract, process, and store small image cut-outs + (or "postage stamps") around bright stars. It relies on three methods, + called in succession: `extractStamps` Find bright stars within the exposure using a reference catalog and @@ -211,35 +186,126 @@ class ProcessBrightStarsTask(PipelineTask): def __init__(self, initInputs=None, *args, **kwargs): super().__init__(*args, **kwargs) - # Compute (model) stamp size depending on provided "buffer" value - self.modelStampSize = [ - int(self.config.stampSize[0] * self.config.modelStampBuffer), - int(self.config.stampSize[1] * self.config.modelStampBuffer), - ] - # force it to be odd-sized so we have a central pixel - if not self.modelStampSize[0] % 2: - self.modelStampSize[0] += 1 - if not self.modelStampSize[1] % 2: - self.modelStampSize[1] += 1 - # central pixel - self.modelCenter = self.modelStampSize[0] // 2, self.modelStampSize[1] // 2 self.setModelStamp() - # configure Gaia refcat - if butler is not None: - self.makeSubtask("refObjLoader", butler=butler) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + inputs["dataId"] = str(butlerQC.quantum.dataId) + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.refObjLoader, + ) + output = self.run(**inputs, refObjLoader=refObjLoader) + # Only ingest stamp if it exists; prevent ingesting an empty FITS file. + if output: + butlerQC.put(output, outputRefs) + + @timeMethod + def run(self, inputExposure, refObjLoader=None, dataId=None, skyCorr=None): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, then preprocess them. + + Bright star preprocessing steps are: shifting, warping and potentially + rotating them to the same pixel grid; computing their annular flux, + and; normalizing them. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image from which bright star stamps should be extracted. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure (including detector) that bright stars + should be extracted from. + skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional + Full focal plane sky correction obtained by `SkyCorrectionTask`. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + if self.config.doApplySkyCorr: + self.log.info("Applying sky correction to exposure %s (exposure modified in-place).", dataId) + self.applySkyCorr(inputExposure, skyCorr) + + self.log.info("Extracting bright stars from exposure %s", dataId) + # Extract stamps around bright stars. + extractedStamps = self.extractStamps(inputExposure, refObjLoader=refObjLoader) + if not extractedStamps.starStamps: + self.log.info("No suitable bright star found.") + return None + # Warp (and shift, and potentially rotate) them. + self.log.info( + "Applying warp and/or shift to %i star stamps from exposure %s.", + len(extractedStamps.starStamps), + dataId, + ) + warpOutputs = self.warpStamps(extractedStamps.starStamps, extractedStamps.pixCenters) + warpedStars = warpOutputs.warpedStars + xy0s = warpOutputs.xy0s + brightStarList = [ + BrightStarStamp( + stamp_im=warp, + archive_element=transform, + position=xy0s[j], + gaiaGMag=extractedStamps.gMags[j], + gaiaId=extractedStamps.gaiaIds[j], + minValidAnnulusFraction=self.config.minValidAnnulusFraction, + ) + for j, (warp, transform) in enumerate(zip(warpedStars, warpOutputs.warpTransforms)) + ] + # Compute annularFlux and normalize + self.log.info( + "Computing annular flux and normalizing %i bright stars from exposure %s.", + len(warpedStars), + dataId, + ) + # annularFlux statistic set-up, excluding mask planes + statsControl = StatisticsControl( + numSigmaClip=self.config.numSigmaClip, + numIter=self.config.numIter, + ) + + innerRadius, outerRadius = self.config.annularFluxRadii + statsFlag = stringToStatisticsProperty(self.config.annularFluxStatistic) + brightStarStamps = BrightStarStamps.initAndNormalize( + brightStarList, + innerRadius=innerRadius, + outerRadius=outerRadius, + nb90Rots=warpOutputs.nb90Rots, + imCenter=self.modelCenter, + use_archive=True, + statsControl=statsControl, + statsFlag=statsFlag, + badMaskPlanes=self.config.badMaskPlanes, + discardNanFluxObjects=(self.config.discardNanFluxStars), + ) + # Do not create empty FITS files if there aren't any normalized stamps. + if not brightStarStamps._stamps: + self.log.info("No normalized stamps exist for this exposure.") + return None + return Struct(brightStarStamps=brightStarStamps) def applySkyCorr(self, calexp, skyCorr): - """Apply correction to the sky background level. + """Apply sky correction to the input exposure. - Sky corrections can be generated using the ``SkyCorrectionTask``. - As the sky model generated there extends over the full focal plane, - this should produce a more optimal sky subtraction solution. + Sky corrections can be generated using the + `~lsst.pipe.tasks.skyCorrection.SkyCorrectionTask`. + As the sky model generated via that task extends over the full focal + plane, this should produce a more optimal sky subtraction solution. Parameters ---------- calexp : `~lsst.afw.image.Exposure` or `~lsst.afw.image.MaskedImage` - Calibrated exposure. - skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional + Calibrated exposure to correct. + skyCorr : `~lsst.afw.math.backgroundList.BackgroundList` Full focal plane sky correction from ``SkyCorrectionTask``. Notes @@ -250,116 +316,193 @@ def applySkyCorr(self, calexp, skyCorr): calexp = calexp.getMaskedImage() calexp -= skyCorr.getImage() - def extractStamps(self, inputExposure, refObjLoader=None, inputBrightStarStamps=None): - """Read the position of bright stars within an input exposure using a - refCat and extract them. + def extractStamps( + self, inputExposure, filterName="phot_g_mean", refObjLoader=None, inputBrightStarStamps=None + ): + """Identify the positions of bright stars within an input exposure using + a reference catalog and extract them. Parameters ---------- inputExposure : `~lsst.afw.image.ExposureF` - The image from which bright star stamps should be extracted. + The image to extract bright star stamps from. + filterName : `str`, optional + Name of the camera filter to use for reference catalog filtering. refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional Loader to find objects within a reference catalog. + inputBrightStarStamps: + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`, optional + Provides information about the stars that have already been + extracted from the inputExposure in other steps of the pipeline. + For example, this is used in the `SubtractBrightStarsTask` to avoid + extracting stars that already have been extracted when running + `ProcessBrightStarsTask` to produce brightStarStamps. Returns ------- result : `~lsst.pipe.base.Struct` Results as a struct with attributes: - ``starIms`` + ``starStamps`` Postage stamps (`list`). ``pixCenters`` Corresponding coords to each star's center, in pixels (`list`). - ``GMags`` + ``gMags`` Corresponding (Gaia) G magnitudes (`list`). ``gaiaIds`` Corresponding unique Gaia identifiers (`np.ndarray`). """ if refObjLoader is None: refObjLoader = self.refObjLoader - starIms = [] - pixCenters = [] - GMags = [] - ids = [] + wcs = inputExposure.getWcs() - # select stars within, or close enough to input exposure from refcat - inputIm = inputExposure.maskedImage - inputExpBBox = inputExposure.getBBox() - # Attempt to include stars that are outside of the exposure but their - # stamps overlap with the exposure. + inputBBox = inputExposure.getBBox() + + # Trim the reference catalog to only those objects within the exposure + # bounding box dilated by half the bright star stamp size. This ensures + # all stars that overlap the exposure are included. dilatationExtent = Extent2I(np.array(self.config.stampSize) // 2) - # TODO (DM-25894): handle catalog with stars missing from Gaia - withinCalexp = refObjLoader.loadPixelBox( - inputExpBBox.dilatedBy(dilatationExtent), - wcs, - filterName="phot_g_mean", + withinExposure = refObjLoader.loadPixelBox( + inputBBox.dilatedBy(dilatationExtent), wcs, filterName=filterName ) - refCat = withinCalexp.refCat - # keep bright objects - fluxLimit = ((self.config.magLimit * u.ABmag).to(u.nJy)).to_value() - GFluxes = np.array(refCat["phot_g_mean_flux"]) - bright = GFluxes > fluxLimit - # convert to AB magnitudes - allGMags = np.array([((gFlux * u.nJy).to(u.ABmag)).to_value() for gFlux in GFluxes[bright]]) - allIds = refCat.columns.extract("id", where=bright)["id"] - selectedColumns = refCat.columns.extract("coord_ra", "coord_dec", where=bright) + refCat = withinExposure.refCat + fluxField = withinExposure.fluxField + + # Define ref cat bright subset: objects brighter than the mag limit. + fluxLimit = ((self.config.magLimit * u.ABmag).to(u.nJy)).to_value() # AB magnitudes. + refCatBright = Table( + refCat.extract("id", "coord_ra", "coord_dec", fluxField, where=refCat[fluxField] > fluxLimit) + ) + refCatBright["mag"] = (refCatBright[fluxField][:] * u.nJy).to(u.ABmag).to_value() # AB magnitudes. + + # Remove input bright stars (if provided) from the bright subset. if inputBrightStarStamps is not None: - existings = np.array(inputBrightStarStamps.getGaiaIds()) - existed = np.isin(allIds, existings) - allGMags = allGMags[~existed] - allIds = allIds[~existed] - selectedColumns["coord_ra"] = selectedColumns["coord_ra"][~existed] - selectedColumns["coord_dec"] = selectedColumns["coord_dec"][~existed] - for j, (ra, dec) in enumerate(zip(selectedColumns["coord_ra"], selectedColumns["coord_dec"])): - sp = SpherePoint(ra, dec, radians) - cpix = wcs.skyToPixel(sp) - try: - starIm = inputExposure.getCutout(sp, Extent2I(self.config.stampSize)) - except InvalidParameterError: - # star is beyond boundary - bboxCorner = np.array(cpix) - np.array(self.config.stampSize) / 2 - # compute bbox as it would be otherwise - idealBBox = Box2I(Point2I(bboxCorner), Extent2I(self.config.stampSize)) - clippedStarBBox = Box2I(idealBBox) - clippedStarBBox.clip(inputExpBBox) - if clippedStarBBox.getArea() > 0: - # create full-sized stamp with all pixels - # flagged as NO_DATA - starIm = ExposureF(bbox=idealBBox) - starIm.image[:] = np.nan - starIm.mask.set(inputExposure.mask.getPlaneBitMask("NO_DATA")) - # recover pixels from intersection with the exposure - clippedIm = inputIm.Factory(inputIm, clippedStarBBox) - starIm.maskedImage[clippedStarBBox] = clippedIm - # set detector and wcs, used in warpStars - starIm.setDetector(inputExposure.getDetector()) - starIm.setWcs(inputExposure.getWcs()) - else: - continue + # Extract the IDs of stars that have already been extracted. + existing = np.isin(refCatBright["id"][:], inputBrightStarStamps.getGaiaIds()) + refCatBright = refCatBright[~existing] + + # Loop over each reference bright star, extract a stamp around it. + pixCenters = [] + starStamps = [] + badRows = [] + for row, object in enumerate(refCatBright): + coordSky = SpherePoint(object["coord_ra"], object["coord_dec"], radians) + coordPix = wcs.skyToPixel(coordSky) + # TODO: Replace this method with exposure getCutout after DM-40042. + starStamp = self._getCutout(inputExposure, coordPix, self.config.stampSize.list()) + if not starStamp: + badRows.append(row) + continue if self.config.doRemoveDetected: - # give detection footprint of other objects the BAD flag - detThreshold = Threshold(starIm.mask.getPlaneBitMask("DETECTED"), Threshold.BITMASK) - omask = FootprintSet(starIm.mask, detThreshold) - allFootprints = omask.getFootprints() - otherFootprints = [] - for fs in allFootprints: - if not fs.contains(Point2I(cpix)): - otherFootprints.append(fs) - nbMatchingFootprints = len(allFootprints) - len(otherFootprints) - if not nbMatchingFootprints == 1: - self.log.warning( - "Failed to uniquely identify central DETECTION footprint for star " - "%s; found %d footprints instead.", - allIds[j], - nbMatchingFootprints, - ) - omask.setFootprints(otherFootprints) - omask.setMask(starIm.mask, "BAD") - starIms.append(starIm) - pixCenters.append(cpix) - GMags.append(allGMags[j]) - ids.append(allIds[j]) - return Struct(starIms=starIms, pixCenters=pixCenters, GMags=GMags, gaiaIds=ids) + self._replaceSecondaryFootprints(starStamp, coordPix, object["id"]) + starStamps.append(starStamp) + pixCenters.append(coordPix) + + # Remove bad rows from the reference catalog; set up return data. + refCatBright.remove_rows(badRows) + gMags = list(refCatBright["mag"][:]) + ids = list(refCatBright["id"][:]) + return Struct(starStamps=starStamps, pixCenters=pixCenters, gMags=gMags, gaiaIds=ids) + + def _getCutout(self, inputExposure, coordPix: Point2D, stampSize: list[int]): + """Get a cutout from an input exposure, handling edge cases. + + Generate a cutout from an input exposure centered on a given position + and with a given size. + If any part of the cutout is outside the input exposure bounding box, + the cutout is padded with NaNs. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image to extract bright star stamps from. + coordPix : `~lsst.geom.Point2D` + Center of the cutout in pixel space. + stampSize : `list` [`int`] + Size of the cutout, in pixels. + + Returns + ------- + stamp : `~lsst.afw.image.ExposureF` or `None` + The cutout, or `None` if the cutout is entirely outside the input + exposure bounding box. + + Notes + ----- + This method is a short-term workaround until DM-40042 is implemented. + At that point, it should be replaced by a call to the Exposure method + ``getCutout``, which will handle edge cases automatically. + """ + # TODO: Replace this method with exposure getCutout after DM-40042. + corner = Point2I(np.array(coordPix) - np.array(stampSize) / 2) + dimensions = Extent2I(stampSize) + stampBBox = Box2I(corner, dimensions) + overlapBBox = Box2I(stampBBox) + overlapBBox.clip(inputExposure.getBBox()) + if overlapBBox.getArea() > 0: + # Create full-sized stamp with pixels initially flagged as NO_DATA. + stamp = ExposureF(bbox=stampBBox) + stamp.image[:] = np.nan + stamp.mask.set(inputExposure.mask.getPlaneBitMask("NO_DATA")) + # Restore pixels which overlap the input exposure. + inputMI = inputExposure.maskedImage + overlap = inputMI.Factory(inputMI, overlapBBox) + stamp.maskedImage[overlapBBox] = overlap + # Set detector and WCS. + stamp.setDetector(inputExposure.getDetector()) + stamp.setWcs(inputExposure.getWcs()) + else: + stamp = None + return stamp + + def _replaceSecondaryFootprints(self, stamp, coordPix, objectId, find="DETECTED", replace="BAD"): + """Replace all secondary footprints in a stamp with another mask flag. + + This method identifies all secondary footprints in a stamp as those + whose ``find`` footprints do not overlap the given pixel coordinates. + If then sets these secondary footprints to the ``replace`` flag. + + Parameters + ---------- + stamp : `~lsst.afw.image.ExposureF` + The postage stamp to modify. + coordPix : `~lsst.geom.Point2D` + The pixel coordinates of the central primary object. + objectId : `int` + The unique identifier of the central primary object. + find : `str`, optional + The mask plane to use to identify secondary footprints. + replace : `str`, optional + The mask plane to set secondary footprints to. + + Notes + ----- + This method modifies the input ``stamp`` in-place. + """ + # Find a FootprintSet given an Image and a threshold. + detThreshold = Threshold(stamp.mask.getPlaneBitMask(find), Threshold.BITMASK) + footprintSet = FootprintSet(stamp.mask, detThreshold) + allFootprints = footprintSet.getFootprints() + # Identify secondary objects (i.e., not the central primary object). + secondaryFootprints = [] + for footprint in allFootprints: + if not footprint.contains(Point2I(coordPix)): + secondaryFootprints.append(footprint) + # Set secondary object footprints to BAD. + # Note: the value of numPrimaryFootprints can only be 0 or 1. If it is + # 0, then the primary object was not found overlapping a footprint. + # This can occur for low-S/N stars, for example. Processing can still + # continue beyond this point in an attempt to utilize this faint flux. + if (numPrimaryFootprints := len(allFootprints) - len(secondaryFootprints)) == 0: + self.log.info( + "Could not uniquely identify central %s footprint for star %s; " + "found %d footprints instead.", + find, + objectId, + numPrimaryFootprints, + ) + footprintSet.setFootprints(secondaryFootprints) + footprintSet.setMask(stamp.mask, replace) def warpStamps(self, stamps, pixCenters): """Warps and shifts all given stamps so they are sampled on the same @@ -458,116 +601,15 @@ def warpStamps(self, stamps, pixCenters): return Struct(warpedStars=warpedStars, warpTransforms=warpTransforms, xy0s=xy0s, nb90Rots=nb90Rots) def setModelStamp(self): + """Compute (model) stamp size depending on provided buffer value.""" self.modelStampSize = [ int(self.config.stampSize[0] * self.config.modelStampBuffer), int(self.config.stampSize[1] * self.config.modelStampBuffer), ] - # force it to be odd-sized so we have a central pixel + # Force stamp to be odd-sized so we have a central pixel. if not self.modelStampSize[0] % 2: self.modelStampSize[0] += 1 if not self.modelStampSize[1] % 2: self.modelStampSize[1] += 1 - # central pixel + # Central pixel. self.modelCenter = self.modelStampSize[0] // 2, self.modelStampSize[1] // 2 - - @timeMethod - def run(self, inputExposure, refObjLoader=None, dataId=None, skyCorr=None): - """Identify bright stars within an exposure using a reference catalog, - extract stamps around each, then preprocess them. The preprocessing - steps are: shifting, warping and potentially rotating them to the same - pixel grid; computing their annular flux and normalizing them. - - Parameters - ---------- - inputExposure : `~lsst.afw.image.ExposureF` - The image from which bright star stamps should be extracted. - refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional - Loader to find objects within a reference catalog. - dataId : `dict` or `~lsst.daf.butler.DataCoordinate` - The dataId of the exposure (and detector) bright stars should be - extracted from. - skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional - Full focal plane sky correction obtained by `SkyCorrectionTask`. - - Returns - ------- - result : `~lsst.pipe.base.Struct` - Results as a struct with attributes: - - ``brightStarStamps`` - (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) - """ - if self.config.doApplySkyCorr: - self.log.info( - "Applying sky correction to exposure %s (exposure will be modified in-place).", dataId - ) - self.applySkyCorr(inputExposure, skyCorr) - self.log.info("Extracting bright stars from exposure %s", dataId) - # Extract stamps around bright stars - extractedStamps = self.extractStamps(inputExposure, refObjLoader=refObjLoader) - if not extractedStamps.starIms: - self.log.info("No suitable bright star found.") - return None - # Warp (and shift, and potentially rotate) them - self.log.info( - "Applying warp and/or shift to %i star stamps from exposure %s.", - len(extractedStamps.starIms), - dataId, - ) - warpOutputs = self.warpStamps(extractedStamps.starIms, extractedStamps.pixCenters) - warpedStars = warpOutputs.warpedStars - xy0s = warpOutputs.xy0s - brightStarList = [ - BrightStarStamp( - stamp_im=warp, - archive_element=transform, - position=xy0s[j], - gaiaGMag=extractedStamps.GMags[j], - gaiaId=extractedStamps.gaiaIds[j], - minValidAnnulusFraction=self.config.minValidAnnulusFraction, - ) - for j, (warp, transform) in enumerate(zip(warpedStars, warpOutputs.warpTransforms)) - ] - # Compute annularFlux and normalize - self.log.info( - "Computing annular flux and normalizing %i bright stars from exposure %s.", - len(warpedStars), - dataId, - ) - # annularFlux statistic set-up, excluding mask planes - statsControl = StatisticsControl() - statsControl.setNumSigmaClip(self.config.numSigmaClip) - statsControl.setNumIter(self.config.numIter) - innerRadius, outerRadius = self.config.annularFluxRadii - statsFlag = stringToStatisticsProperty(self.config.annularFluxStatistic) - brightStarStamps = BrightStarStamps.initAndNormalize( - brightStarList, - innerRadius=innerRadius, - outerRadius=outerRadius, - nb90Rots=warpOutputs.nb90Rots, - imCenter=self.modelCenter, - use_archive=True, - statsControl=statsControl, - statsFlag=statsFlag, - badMaskPlanes=self.config.badMaskPlanes, - discardNanFluxObjects=(self.config.discardNanFluxStars), - ) - # Dont create empty fits files if there is no normalized stamp! - if not len(brightStarStamps._stamps) > 0: - self.log.info("No normalized stamps exists for this exposure!") - return None - return Struct(brightStarStamps=brightStarStamps) - - def runQuantum(self, butlerQC, inputRefs, outputRefs): - inputs = butlerQC.get(inputRefs) - inputs["dataId"] = str(butlerQC.quantum.dataId) - refObjLoader = ReferenceObjectLoader( - dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], - refCats=inputs.pop("refCat"), - name=self.config.connections.refCat, - config=self.config.refObjLoader, - ) - output = self.run(**inputs, refObjLoader=refObjLoader) - # This if block prevents the code to produce an emtpy fits file in case there is no stamp. - if output: - butlerQC.put(output, outputRefs) diff --git a/python/lsst/pipe/tasks/subtractBrightStars.py b/python/lsst/pipe/tasks/subtractBrightStars.py index 2d18f2805..aa313f4a5 100644 --- a/python/lsst/pipe/tasks/subtractBrightStars.py +++ b/python/lsst/pipe/tasks/subtractBrightStars.py @@ -28,8 +28,8 @@ from operator import ior import numpy as np -from lsst.afw.image import Exposure, ExposureF, MaskedImageF from lsst.afw.geom import SpanSet, Stencil +from lsst.afw.image import Exposure, ExposureF, MaskedImageF from lsst.afw.math import ( StatisticsControl, WarpingControl, @@ -52,9 +52,11 @@ class SubtractBrightStarsConnections( PipelineTaskConnections, dimensions=("instrument", "visit", "detector"), - defaultTemplates={"outputExposureName": "brightStar_subtracted", - "outputBackgroundName": "brightStars", - "badStampsName": "brightStars"}, + defaultTemplates={ + "outputExposureName": "brightStar_subtracted", + "outputBackgroundName": "brightStars", + "badStampsName": "brightStars", + }, ): inputExposure = Input( doc="Input exposure from which to subtract bright star stamps.", @@ -117,8 +119,8 @@ class SubtractBrightStarsConnections( ), ) outputBadStamps = Output( - doc="The stamps the are not normalized and consequently not subtracted from the exposure.", - name="{badStampsName}_unsubtracted_stapms", + doc="The stamps that are not normalized and consequently not subtracted from the exposure.", + name="{badStampsName}_unsubtracted_stamps", storageClass="BrightStarStamps", dimensions=( "visit", @@ -136,38 +138,31 @@ class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=Subtract """Configuration parameters for SubtractBrightStarsTask""" doWriteSubtractor = Field[bool]( - dtype=bool, doc="Should an exposure containing all bright star models be written to disk?", default=True, ) doWriteSubtractedExposure = Field[bool]( - dtype=bool, doc="Should an exposure with bright stars subtracted be written to disk?", default=True, ) magLimit = Field[float]( - dtype=float, doc="Magnitude limit, in Gaia G; all stars brighter than this value will be subtracted", default=18, ) - minValidAnnulusFraction = Field( - dtype=float, - doc="Minumum number of valid pixels that must fall within the annulus for the bright star to be " + minValidAnnulusFraction = Field[float]( + doc="Minimum number of valid pixels that must fall within the annulus for the bright star to be " "saved for subsequent generation of a PSF.", default=0.0, ) - numSigmaClip = Field( - dtype=float, + numSigmaClip = Field[float]( doc="Sigma for outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", default=4, ) - numIter = Field( - dtype=int, + numIter = Field[int]( doc="Number of iterations of outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", default=3, ) warpingKernelName = ChoiceField[str]( - dtype=str, doc="Warping kernel", default="lanczos5", allowed={ @@ -180,7 +175,6 @@ class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=Subtract }, ) scalingType = ChoiceField[str]( - dtype=str, doc="How the model should be scaled to each bright star; implemented options are " "`annularFlux` to reuse the annular flux of each stamp, or `leastSquare` to perform " "least square fitting on each pixel with no bad mask plane set.", @@ -190,8 +184,7 @@ class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=Subtract "leastSquare": "find least square scaling factor", }, ) - annularFluxStatistic = ChoiceField( - dtype=str, + annularFluxStatistic = ChoiceField[str]( doc="Type of statistic to use to compute annular flux.", default="MEANCLIP", allowed={ @@ -201,7 +194,6 @@ class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=Subtract }, ) badMaskPlanes = ListField[str]( - dtype=str, doc="Mask planes that, if set, lead to associated pixels not being included in the computation of " "the scaling factor (`BAD` should always be included). Ignored if scalingType is `annularFlux`, " "as the stamps are expected to already be normalized.", @@ -210,26 +202,24 @@ class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=Subtract # interest) also get set to `BAD`. default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"), ) - subtractionBox = ListField( - dtype=int, + subtractionBox = ListField[int]( doc="Size of the stamps to be extracted, in pixels.", default=(250, 250), ) - subtractionBoxBuffer = Field( - dtype=float, + subtractionBoxBuffer = Field[float]( doc=( - "'Buffer' factor to be applied to determine the size of the stamp the processed stars will be " - "saved in. This will also be the size of the extended PSF model." + "'Buffer' (multiplicative) factor to be applied to determine the size of the stamp the " + "processed stars will be saved in. This is also the size of the extended PSF model. The buffer " + "region is masked and contain no data and subtractionBox determines the region where contains " + "the data." ), default=1.1, ) doApplySkyCorr = Field[bool]( - dtype=bool, doc="Apply full focal plane sky correction before extracting stars?", default=True, ) - refObjLoader = ConfigField( - dtype=LoadReferenceObjectsConfig, + refObjLoader = ConfigField[LoadReferenceObjectsConfig]( doc="Reference object loader for astrometric calibration.", ) @@ -251,18 +241,142 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Placeholders to set up Statistics if scalingType is leastSquare. self.statsControl, self.statsFlag = None, None - # warping control; only contains shiftingALg provided in config - self.warpCont = WarpingControl(self.config.warpingKernelName) + # Warping control; only contains shiftingALg provided in config. + self.warpControl = WarpingControl(self.config.warpingKernelName) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + # Docstring inherited. + inputs = butlerQC.get(inputRefs) + dataId = butlerQC.quantum.dataId + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.refObjLoader, + ) + subtractor, _, badStamps = self.run(**inputs, dataId=dataId, refObjLoader=refObjLoader) + if self.config.doWriteSubtractedExposure: + outputExposure = inputs["inputExposure"].clone() + outputExposure.image -= subtractor.image + else: + outputExposure = None + outputBackgroundExposure = subtractor if self.config.doWriteSubtractor else None + # In its current state, the code produces outputBadStamps which are the + # stamps of stars that have not been subtracted from the image for any + # reason. If all the stars are subtracted from the calexp, the output + # is an empty fits file. + output = Struct( + outputExposure=outputExposure, + outputBackgroundExposure=outputBackgroundExposure, + outputBadStamps=badStamps, + ) + butlerQC.put(output, outputRefs) + + def run( + self, inputExposure, inputBrightStarStamps, inputExtendedPsf, dataId, skyCorr=None, refObjLoader=None + ): + """Iterate over all bright stars in an exposure to scale the extended + PSF model before subtracting bright stars. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image from which bright stars should be subtracted. + inputBrightStarStamps : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` + Set of stamps centered on each bright star to be subtracted, + produced by running + `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`. + inputExtendedPsf : `~lsst.pipe.tasks.extended_psf.ExtendedPsf` + Extended PSF model, produced by + `~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure (and detector) bright stars should be + subtracted from. + skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional + Full focal plane sky correction, obtained by running + `~lsst.pipe.tasks.skyCorrection.SkyCorrectionTask`. If + `doApplySkyCorr` is set to `True`, `skyCorr` cannot be `None`. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + + Returns + ------- + subtractorExp : `~lsst.afw.image.ExposureF` + An Exposure containing a scaled bright star model fit to every + bright star profile; its image can then be subtracted from the + input exposure. + invImages : `list` [`~lsst.afw.image.MaskedImageF`] + A list of small images ("stamps") containing the model, each scaled + to its corresponding input bright star. + """ + self.inputExpBBox = inputExposure.getBBox() + if self.config.doApplySkyCorr and (skyCorr is not None): + self.log.info( + "Applying sky correction to exposure %s (exposure will be modified in-place).", dataId + ) + self.applySkyCorr(inputExposure, skyCorr) + + # Create an empty image the size of the exposure. + # TODO: DM-31085 (set mask planes). + subtractorExp = ExposureF(bbox=inputExposure.getBBox()) + subtractor = subtractorExp.maskedImage + + # Make a copy of the input model. + self.model = inputExtendedPsf(dataId["detector"]).clone() + self.modelStampSize = self.model.getDimensions() + # Number of 90 deg. rotations to reverse each stamp's rotation. + self.inv90Rots = 4 - inputBrightStarStamps.nb90Rots % 4 + self.model = rotateImageBy90(self.model, self.inv90Rots) + + brightStarList = self.makeBrightStarList(inputBrightStarStamps, inputExposure, refObjLoader) + invImages = [] + subtractor, invImages = self.buildSubtractor( + inputBrightStarStamps, subtractor, invImages, multipleAnnuli=False + ) + if brightStarList: + self.setMissedStarsStatsControl() + # This may change when multiple star bins are used for PSF + # creation. + innerRadius = inputBrightStarStamps._innerRadius + outerRadius = inputBrightStarStamps._outerRadius + brightStarStamps, badStamps = BrightStarStamps.initAndNormalize( + brightStarList, + innerRadius=innerRadius, + outerRadius=outerRadius, + nb90Rots=self.warpOutputs.nb90Rots, + imCenter=self.warper.modelCenter, + use_archive=True, + statsControl=self.missedStatsControl, + statsFlag=self.missedStatsFlag, + badMaskPlanes=self.warper.config.badMaskPlanes, + discardNanFluxObjects=False, + forceFindFlux=True, + ) + + self.psf_annular_fluxes = self.findPsfAnnularFluxes(brightStarStamps) + subtractor, invImages = self.buildSubtractor( + brightStarStamps, subtractor, invImages, multipleAnnuli=True + ) + else: + badStamps = [] + badStamps = BrightStarStamps(badStamps) + + return subtractorExp, invImages, badStamps def _setUpStatistics(self, exampleMask): """Configure statistics control and flag, for use if ``scalingType`` is `leastSquare`. """ if self.config.scalingType == "leastSquare": - self.statsControl = StatisticsControl() # Set the mask planes which will be ignored. - andMask = reduce(ior, (exampleMask.getPlaneBitMask(bm) for bm in self.config.badMaskPlanes)) - self.statsControl.setAndMask(andMask) + andMask = reduce( + ior, + (exampleMask.getPlaneBitMask(bm) for bm in self.config.badMaskPlanes), + ) + self.statsControl = StatisticsControl( + andMask=andMask, + ) self.statsFlag = stringToStatisticsProperty("SUM") def applySkyCorr(self, calexp, skyCorr): @@ -284,7 +398,7 @@ def applySkyCorr(self, calexp, skyCorr): calexp = calexp.getMaskedImage() calexp -= skyCorr.getImage() - def scaleModel(self, model, star, inPlace=True, nb90Rots=0, psf_annular_flux=None): + def scaleModel(self, model, star, inPlace=True, nb90Rots=0, psf_annular_flux=1.0): """Compute scaling factor to be applied to the extended PSF so that its amplitude matches that of an individual star. @@ -292,13 +406,18 @@ def scaleModel(self, model, star, inPlace=True, nb90Rots=0, psf_annular_flux=Non ---------- model : `~lsst.afw.image.MaskedImageF` The extended PSF model, shifted (and potentially warped) to match - the bright star's positioning. + the bright star position. star : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` A stamp centered on the bright star to be subtracted. inPlace : `bool` Whether the model should be scaled in place. Default is `True`. nb90Rots : `int` The number of 90-degrees rotations to apply to the star stamp. + psf_annular_flux: `float`, optional + The annular flux of the PSF model at the radius where the flux of + the given star is determined. This is 1 for stars present in + inputBrightStarStamps, but can be different for stars that are + missing from inputBrightStarStamps. Returns ------- @@ -306,8 +425,6 @@ def scaleModel(self, model, star, inPlace=True, nb90Rots=0, psf_annular_flux=Non The factor by which the model image should be multiplied for it to be scaled to the input bright star. """ - if psf_annular_flux is None: - psf_annular_flux = 1 if self.config.scalingType == "annularFlux": scalingFactor = star.annularFlux * psf_annular_flux elif self.config.scalingType == "leastSquare": @@ -332,9 +449,13 @@ def scaleModel(self, model, star, inPlace=True, nb90Rots=0, psf_annular_flux=Non model.image *= scalingFactor return scalingFactor - def _overRideWarperConfig(self): + def _overrideWarperConfig(self): """Override the warper config with the config of this task. + + This override is necessary for stars that are missing from the + inputBrightStarStamps object but still need to be subtracted. """ + # TODO: Replace these copied values with a warperConfig. self.warper.config.minValidAnnulusFraction = self.config.minValidAnnulusFraction self.warper.config.numSigmaClip = self.config.numSigmaClip self.warper.config.numIter = self.config.numIter @@ -342,31 +463,37 @@ def _overRideWarperConfig(self): self.warper.config.badMaskPlanes = self.config.badMaskPlanes self.warper.config.stampSize = self.config.subtractionBox self.warper.modelStampBuffer = self.config.subtractionBoxBuffer + self.warper.config.magLimit = self.config.magLimit self.warper.setModelStamp() def setMissedStarsStatsControl(self): - """Configure statistics control for processing missing stars from inputBrightStarStamps. + """Configure statistics control for processing missing stars from + inputBrightStarStamps. """ - self.missedStatsControl = StatisticsControl() - self.missedStatsControl.setNumSigmaClip(self.warper.config.numSigmaClip) - self.missedStatsControl.setNumIter(self.warper.config.numIter) + self.missedStatsControl = StatisticsControl( + numSigmaClip=self.warper.config.numSigmaClip, + numIter=self.warper.config.numIter, + ) self.missedStatsFlag = stringToStatisticsProperty(self.warper.config.annularFluxStatistic) def setWarpTask(self): - """Create an instance of ProcessBrightStarsTask that will be used to produce stamps of stars to be - subtracted. + """Create an instance of ProcessBrightStarsTask that will be used to + produce stamps of stars to be subtracted. """ self.warper = ProcessBrightStarsTask() - self._overRideWarperConfig() + self._overrideWarperConfig() self.warper.modelCenter = self.modelStampSize[0] // 2, self.modelStampSize[1] // 2 def makeBrightStarList(self, inputBrightStarStamps, inputExposure, refObjLoader): - """Make a list of bright stars that are missing from inputBrightStarStamps to be subtracted. + """Make a list of bright stars that are missing from + inputBrightStarStamps to be subtracted. Parameters ---------- - inputBrightStarStamps : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` - Set of stamps centered on each bright star to be subtracted, produced by running + inputBrightStarStamps : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` + Set of stamps centered on each bright star to be subtracted, + produced by running `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`. inputExposure : `~lsst.afw.image.ExposureF` The image from which bright stars should be subtracted. @@ -376,26 +503,31 @@ def makeBrightStarList(self, inputBrightStarStamps, inputExposure, refObjLoader) Returns ------- brightStarList: - A list containing `lsst.meas.algorithms.brightStarStamps.BrightStarStamp` of stars to be - subtracted. + A list containing + `lsst.meas.algorithms.brightStarStamps.BrightStarStamp` of stars to + be subtracted. """ self.setWarpTask() missedStars = self.warper.extractStamps( inputExposure, refObjLoader=refObjLoader, inputBrightStarStamps=inputBrightStarStamps ) - self.warpOutputs = self.warper.warpStamps(missedStars.starIms, missedStars.pixCenters) - brightStarList = [ - BrightStarStamp( - stamp_im=warp, - archive_element=transform, - position=self.warpOutputs.xy0s[j], - gaiaGMag=missedStars.GMags[j], - gaiaId=missedStars.gaiaIds[j], - minValidAnnulusFraction=self.warper.config.minValidAnnulusFraction, - ) - for j, (warp, transform) in enumerate(zip(self.warpOutputs.warpedStars, - self.warspOutputs.warpTransforms)) - ] + if missedStars.starStamps: + self.warpOutputs = self.warper.warpStamps(missedStars.starStamps, missedStars.pixCenters) + brightStarList = [ + BrightStarStamp( + stamp_im=warp, + archive_element=transform, + position=self.warpOutputs.xy0s[j], + gaiaGMag=missedStars.gMags[j], + gaiaId=missedStars.gaiaIds[j], + minValidAnnulusFraction=self.warper.config.minValidAnnulusFraction, + ) + for j, (warp, transform) in enumerate( + zip(self.warpOutputs.warpedStars, self.warpOutputs.warpTransforms) + ) + ] + else: + brightStarList = [] return brightStarList def initAnnulusImage(self): @@ -412,11 +544,19 @@ def initAnnulusImage(self): return annulusImage def createAnnulus(self, brightStarStamp): - """Create an annulus of the given star. + """Create a circular annulus around the given star. + + The circular annulus is set based on the inner and outer optimal radii. + These radii describe the annulus where the flux of the star is found. + The aim is to create the same annulus for the PSF model, eventually + measuring the model flux around that annulus. + An optimal radius usually differs from the radius where the PSF model + is normalized. Parameters ---------- - brightStarStamp : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` + brightStarStamp : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` A stamp of a bright star to be subtracted. Returns @@ -424,7 +564,7 @@ def createAnnulus(self, brightStarStamp): annulus : `~lsst.afw.image.MaskedImageF` An annulus of the given star. """ - # Create SpanSet of annulus + # Create SpanSet of annulus. outerCircle = SpanSet.fromShape( brightStarStamp.optimalOuterRadius, Stencil.CIRCLE, offset=self.warper.modelCenter ) @@ -445,7 +585,8 @@ def applyStatsControl(self, annulusImage): Returns ------- annularFlux: float - The annular flux of the PSF model at the radius where the flux of the given star is determined. + The annular flux of the PSF model at the radius where the flux of + the given star is determined. """ andMask = reduce( ior, (annulusImage.mask.getPlaneBitMask(bm) for bm in self.warper.config.badMaskPlanes) @@ -455,20 +596,25 @@ def applyStatsControl(self, annulusImage): return annulusStat.getValue() def findPsfAnnularFlux(self, brightStarStamp, maskedModel): - """Find the annular flux of the PSF model at the radius where the flux of the given star is - determined. + """Find the annular flux of the PSF model within a specified annulus. + + This flux will be used for re-scaling the PSF to the level of stars + with bad stamps. Stars with bad stamps are those without a flux within + the normalization annulus. Parameters ---------- - brightStarStamp : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` + brightStarStamp : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` A stamp of a bright star to be subtracted. maskedModel : `~lsst.afw.image.MaskedImageF` A masked image of the PSF model. Returns ------- - annularFlux: float - The annular flux of the PSF model at the radius where the flux of the given star is determined. + annularFlux: float (between 0 and 1) + The annular flux of the PSF model at the radius where the flux of + the given star is determined. """ annulusImage = self.initAnnulusImage() annulus = self.createAnnulus(brightStarStamp) @@ -481,26 +627,28 @@ def findPsfAnnularFluxes(self, brightStarStamps): Parameters ---------- - brightStarStamps : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` + brightStarStamps : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` The stamps of stars that will be subtracted from the exposure. Returns ------- PsfAnnularFluxes: numpy.array - A two column numpy.array containing annular fluxes of the PSF at radii where the flux for stars - exist (could be found). + A two column numpy.array containing annular fluxes of the PSF at + radii where the flux for stars exist (could be found). Notes ----- - While the PSF model is normalized at a certain radius, the flux of a star at that radius might be - impossible to find. Therefore, we have to scale the PSF model considering a radius where the star has - an identified flux. To do that, the flux of the model should be found and used to adjust the scaling - step. + While the PSF model is normalized at a certain radius, the annular flux + of a star around that radius might be impossible to find. Therefore, we + have to scale the PSF model considering a radius where the star has an + identified flux. To do that, the flux of the model should be found and + used to adjust the scaling step. """ outerRadii = [] annularFluxes = [] maskedModel = MaskedImageF(self.model.image) - # the model has wrong bbox values. Should be fixed in extended_psf.py? + # The model has wrong bbox values. Should be fixed in extended_psf.py? maskedModel.setXY0(0, 0) for star in brightStarStamps: if star.optimalOuterRadius not in outerRadii: @@ -510,25 +658,39 @@ def findPsfAnnularFluxes(self, brightStarStamps): return np.array([outerRadii, annularFluxes]).T def preparePlaneModelStamp(self, brightStarStamp): - """Prepare the PSF model before scaling. + """Prepare the PSF plane model stamp. + + It is called PlaneModel because, while it is a PSF model stamp that is + warped and rotated to the same orientation of a chosen star, it is not + yet scaled to the brightness level of the star. Parameters ---------- - brightStarStamp : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` + brightStarStamp : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` The stamp of the star to which the PSF model will be scaled. Returns ------- bbox: `~lsst.geom.Box2I` - Contains the corner coordination and the dimensions of the model stamp. + Contains the corner coordination and the dimensions of the model + stamp. invImage: `~lsst.afw.image.MaskedImageF` - The extended PSF model, shifted (and potentially warped) to match the bright star's positioning. + The extended PSF model, shifted (and potentially warped and + rotated) to match the bright star position. Raises ------ RuntimeError - Raised if warping of the model is failed. + Raised if warping of the model failed. + + Notes + ----- + Since detectors have different orientations, the PSF model should be + rotated to match the orientation of the detectors in some cases. To do + that, the code uses the inverse of the transform that is applied to the + bright star stamp to match the orientation of the detector. """ # Set the origin. self.model.setXY0(brightStarStamp.position) @@ -538,12 +700,13 @@ def preparePlaneModelStamp(self, brightStarStamp): bbox = Box2I(corner=invOrigin, dimensions=self.modelStampSize) invImage = MaskedImageF(bbox) # Apply inverse transform. - goodPix = warpImage(invImage, self.model, invTransform, self.warpCont) + goodPix = warpImage(invImage, self.model, invTransform, self.warpControl) if not goodPix: - # Do we want to find another way or just subtract the non-warped scaled model? + # Do we want to find another way or just subtract the non-warped + # scaled model? # Currently the code just leaves the failed ones un-subtracted. raise RuntimeError( - f"Warping of a model failed for star {brightStarStamp.gaiaId}: " "no good pixel in output" + f"Warping of a model failed for star {brightStarStamp.gaiaId}: no good pixel in output." ) return bbox, invImage @@ -553,189 +716,81 @@ def addScaledModel(self, subtractor, brightStarStamp, multipleAnnuli=False): Parameters ---------- subtractor : `~lsst.afw.image.MaskedImageF` - The Exposure containing the scaled model of brigth stars to be subtracted from the input - exposure. - brightStarStamp : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` - The stamp of the star of which the PSF model will be scaled and added to the subtractor. + The full image containing the scaled model of bright stars to be + subtracted from the input exposure. + brightStarStamp : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp` + The stamp of the star of which the PSF model will be scaled and + added to the subtractor. multipleAnnuli : bool, optional - If true, the model should be scaled based on a flux at a radius other than its normalization - radius. + If true, the model should be scaled based on a flux at a radius + other than its normalization radius. Returns ------- subtractor : `~lsst.afw.image.MaskedImageF` - The input subtractor Exposure with the added scaled model at the given star's location in the - exposure. + The input subtractor full image with the added scaled model at the + given star's location in the exposure. invImage: `~lsst.afw.image.MaskedImageF` - The extended PSF model, shifted (and potentially warped) to match the bright star's positioning. + The extended PSF model, shifted (and potentially warped) to match + the bright star position. """ bbox, invImage = self.preparePlaneModelStamp(brightStarStamp) - if multipleAnnuli: - cond = self.psf_annular_fluxes[:, 0] == brightStarStamp.optimalOuterRadius - psf_annular_flux = self.psf_annular_fluxes[cond, 1][0] - self.scaleModel(invImage, - brightStarStamp, - inPlace=True, - nb90Rots=self.inv90Rots, - psf_annular_flux=psf_annular_flux) - else: - self.scaleModel(invImage, brightStarStamp, inPlace=True, nb90Rots=self.inv90Rots) - # Replace NaNs before subtraction (note all NaN pixels have - # the NO_DATA flag). - invImage.image.array[np.isnan(invImage.image.array)] = 0 bbox.clip(self.inputExpBBox) if bbox.getArea() > 0: + if multipleAnnuli: + cond = self.psf_annular_fluxes[:, 0] == brightStarStamp.optimalOuterRadius + psf_annular_flux = self.psf_annular_fluxes[cond, 1][0] + self.scaleModel( + invImage, + brightStarStamp, + inPlace=True, + nb90Rots=self.inv90Rots, + psf_annular_flux=psf_annular_flux, + ) + else: + self.scaleModel(invImage, brightStarStamp, inPlace=True, nb90Rots=self.inv90Rots) + # Replace NaNs before subtraction (all NaNs have the NO_DATA flag). + invImage.image.array[np.isnan(invImage.image.array)] = 0 subtractor[bbox] += invImage[bbox] return subtractor, invImage def buildSubtractor(self, brightStarStamps, subtractor, invImages, multipleAnnuli=False): - """Build an image containing potentially multiple scaled PSF models, each at the location of a given - brigth star. + """Build an image containing potentially multiple scaled PSF models, + each at the location of a given bright star. Parameters ---------- - brightStarStamps : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` - Set of stamps centered on each bright star to be subtracted, produced by running + brightStarStamps : + `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` + Set of stamps centered on each bright star to be subtracted, + produced by running `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`. subtractor : `~lsst.afw.image.MaskedImageF` - The Exposure that will contain the scaled model of brigth stars to be subtracted from the - exposure. + The Exposure that will contain the scaled model of bright stars to + be subtracted from the exposure. invImages : `list` - A list containing extended PSF models, shifted (and potentially warped) to match the bright stars - positionings. + A list containing extended PSF models, shifted (and potentially + warped) to match the bright stars positions. multipleAnnuli : bool, optional This will be passed to addScaledModel method, by default False. Returns ------- subtractor : `~lsst.afw.image.MaskedImageF` - An Exposure containing a scaled bright star model fit to every bright star profile; its image can - then be subtracted from the input exposure. + An Exposure containing a scaled bright star model fit to every + bright star profile; its image can then be subtracted from the + input exposure. invImages: list - A list containing the extended PSF models, shifted (and potentially warped) to match bright - stars' positionings. + A list containing the extended PSF models, shifted (and potentially + warped) to match bright stars' positions. """ for star in brightStarStamps: if star.gaiaGMag < self.config.magLimit: try: - # Adding the scaled model at the star location to the subtractor. + # Add the scaled model at the star location to subtractor. subtractor, invImage = self.addScaledModel(subtractor, star, multipleAnnuli) invImages.append(invImage) except RuntimeError as err: logger.error(err) return subtractor, invImages - - def runQuantum(self, butlerQC, inputRefs, outputRefs): - # Docstring inherited. - inputs = butlerQC.get(inputRefs) - dataId = butlerQC.quantum.dataId - refObjLoader = ReferenceObjectLoader( - dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], - refCats=inputs.pop("refCat"), - name=self.config.connections.refCat, - config=self.config.refObjLoader, - ) - subtractor, _, badStamps = self.run(**inputs, dataId=dataId, refObjLoader=refObjLoader) - if self.config.doWriteSubtractedExposure: - outputExposure = inputs["inputExposure"].clone() - outputExposure.image -= subtractor.image - else: - outputExposure = None - outputBackgroundExposure = subtractor if self.config.doWriteSubtractor else None - # in its current state, the code produces outputBadStamps which are the stamps of stars that have not - # been subtracted from the image for any reason. If all the stars are subtracted from the calexp, the - # output is an empty fits file. - output = Struct(outputExposure=outputExposure, - outputBackgroundExposure=outputBackgroundExposure, - outputBadStamps=badStamps) - butlerQC.put(output, outputRefs) - - def run(self, - inputExposure, - inputBrightStarStamps, - inputExtendedPsf, - dataId, - skyCorr=None, - refObjLoader=None): - """Iterate over all bright stars in an exposure to scale the extended - PSF model before subtracting bright stars. - - Parameters - ---------- - inputExposure : `~lsst.afw.image.ExposureF` - The image from which bright stars should be subtracted. - inputBrightStarStamps : - `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps` - Set of stamps centered on each bright star to be subtracted, - produced by running - `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`. - inputExtendedPsf : `~lsst.pipe.tasks.extended_psf.ExtendedPsf` - Extended PSF model, produced by - `~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`. - dataId : `dict` or `~lsst.daf.butler.DataCoordinate` - The dataId of the exposure (and detector) bright stars should be - subtracted from. - skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional - Full focal plane sky correction, obtained by running - `~lsst.pipe.tasks.skyCorrection.SkyCorrectionTask`. If - `doApplySkyCorr` is set to `True`, `skyCorr` cannot be `None`. - refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional - Loader to find objects within a reference catalog. - - Returns - ------- - subtractorExp : `~lsst.afw.image.ExposureF` - An Exposure containing a scaled bright star model fit to every - bright star profile; its image can then be subtracted from the - input exposure. - invImages : `list` [`~lsst.afw.image.MaskedImageF`] - A list of small images ("stamps") containing the model, each scaled - to its corresponding input bright star. - """ - self.inputExpBBox = inputExposure.getBBox() - if self.config.doApplySkyCorr and (skyCorr is not None): - self.log.info( - "Applying sky correction to exposure %s (exposure will be modified in-place).", dataId - ) - self.applySkyCorr(inputExposure, skyCorr) - # Create an empty image the size of the exposure. - # TODO: DM-31085 (set mask planes). - subtractorExp = ExposureF(bbox=inputExposure.getBBox()) - subtractor = subtractorExp.maskedImage - # Make a copy of the input model. - self.model = inputExtendedPsf(dataId["detector"]).clone() - self.modelStampSize = self.model.getDimensions() - self.inv90Rots = 4 - inputBrightStarStamps.nb90Rots % 4 - self.model = rotateImageBy90(self.model, self.inv90Rots) - - brightStarList = self.makeBrightStarList(inputBrightStarStamps, inputExposure, refObjLoader) - self.setMissedStarsStatsControl() - # This might change when we use multiple categories of stars for creating PSF. - innerRadius = inputBrightStarStamps._innerRadius - outerRadius = inputBrightStarStamps._outerRadius - brightStarStamps, badStamps = BrightStarStamps.initAndNormalize( - brightStarList, - innerRadius=innerRadius, - outerRadius=outerRadius, - nb90Rots=self.warpOutputs.nb90Rots, - imCenter=self.warper.modelCenter, - use_archive=True, - statsControl=self.missedStatsControl, - statsFlag=self.missedStatsFlag, - badMaskPlanes=self.warper.config.badMaskPlanes, - discardNanFluxObjects=False, - forceFindFlux=True, - ) - - invImages = [] - subtractor, invImages = self.buildSubtractor( - inputBrightStarStamps, subtractor, invImages, multipleAnnuli=False - ) - if len(brightStarStamps) > 0: - self.psf_annular_fluxes = self.findPsfAnnularFluxes(brightStarStamps) - subtractor, invImages = self.buildSubtractor( - brightStarStamps, subtractor, invImages, multipleAnnuli=True - ) - badStamps = BrightStarStamps(badStamps) - - return subtractorExp, invImages, badStamps