From a7a7d6127cd0ff4dcb4a4b63e50c07d925085fe0 Mon Sep 17 00:00:00 2001
From: Ian Sullivan <sullii@uw.edu>
Date: Thu, 21 Nov 2019 12:28:51 -0800
Subject: [PATCH] Convert DcrAssembleCoadd to a pipelineTask

---
 python/lsst/pipe/tasks/dcrAssembleCoadd.py | 216 ++++++++++++++++++---
 1 file changed, 194 insertions(+), 22 deletions(-)

diff --git a/python/lsst/pipe/tasks/dcrAssembleCoadd.py b/python/lsst/pipe/tasks/dcrAssembleCoadd.py
index fafdc3244..5f54a7a92 100644
--- a/python/lsst/pipe/tasks/dcrAssembleCoadd.py
+++ b/python/lsst/pipe/tasks/dcrAssembleCoadd.py
@@ -27,18 +27,80 @@
 import lsst.afw.image as afwImage
 import lsst.afw.table as afwTable
 import lsst.coadd.utils as coaddUtils
+from lsst.daf.butler import DeferredDatasetHandle
 from lsst.ip.diffim.dcrModel import applyDcr, calculateDcr, DcrModel
 import lsst.meas.algorithms as measAlg
 from lsst.meas.base import SingleFrameMeasurementTask
 import lsst.pex.config as pexConfig
 import lsst.pipe.base as pipeBase
-from .assembleCoadd import AssembleCoaddTask, CompareWarpAssembleCoaddTask, CompareWarpAssembleCoaddConfig
+import lsst.utils as utils
+from .assembleCoadd import (AssembleCoaddTask,
+                            CompareWarpAssembleCoaddConfig,
+                            CompareWarpAssembleCoaddTask)
+from .coaddBase import makeSkyInfo
 from .measurePsf import MeasurePsfTask
 
-__all__ = ["DcrAssembleCoaddTask", "DcrAssembleCoaddConfig"]
+__all__ = ["DcrAssembleCoaddConnections", "DcrAssembleCoaddTask", "DcrAssembleCoaddConfig"]
+
+
+class DcrAssembleCoaddConnections(pipeBase.PipelineTaskConnections,
+                                  dimensions=("tract", "patch", "abstract_filter", "skymap"),
+                                  defaultTemplates={"inputCoaddName": "deep",
+                                                    "outputCoaddName": "dcr",
+                                                    "warpType": "direct",
+                                                    "warpTypeSuffix": "",
+                                                    "fakesType": ""}):
+    inputWarps = pipeBase.connectionTypes.Input(
+        doc=("Input list of warps to be assembled i.e. stacked. "
+             "WarpType (e.g. direct, psfMatched) is controlled by the warpType config parameter"),
+        name="{inputCoaddName}Coadd_{warpType}Warp",
+        storageClass="ExposureF",
+        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
+        deferLoad=True,
+        multiple=True
+    )
+    skyMap = pipeBase.connectionTypes.Input(
+        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
+        name="{inputCoaddName}Coadd_skyMap",
+        storageClass="SkyMap",
+        dimensions=("skymap", ),
+    )
+    brightObjectMask = pipeBase.connectionTypes.PrerequisiteInput(
+        doc=("Input Bright Object Mask produced with external catalogs to be applied to the mask plane"
+             " BRIGHT_OBJECT."),
+        name="brightObjectMask",
+        storageClass="ObjectMaskCatalog",
+        dimensions=("tract", "patch", "skymap", "abstract_filter"),
+    )
+    templateExposure = pipeBase.connectionTypes.Input(
+        doc="Input coadded exposure, produced by previous call to AssembleCoadd",
+        name="{fakesType}{inputCoaddName}Coadd{warpTypeSuffix}",
+        storageClass="ExposureF",
+        dimensions=("tract", "patch", "skymap", "abstract_filter"),
+    )
+    dcrCoadds = pipeBase.connectionTypes.Output(
+        doc="Output coadded exposure, produced by stacking input warps",
+        name="{fakesType}{outputCoaddName}Coadd{warpTypeSuffix}",
+        storageClass="ExposureF",
+        dimensions=("tract", "patch", "skymap", "abstract_filter", "subfilter"),
+        multiple=True,
+    )
+    dcrNImages = pipeBase.connectionTypes.Output(
+        doc="Output image of number of input images per pixel",
+        name="{outputCoaddName}Coadd_nImage",
+        storageClass="ImageU",
+        dimensions=("tract", "patch", "skymap", "abstract_filter", "subfilter"),
+        multiple=True,
+    )
 
+    def __init__(self, *, config=None):
+        super().__init__(config=config)
+        if not config.doWrite:
+            self.outputs.remove("dcrCoadds")  # drop the coadd output connection when writing is disabled
 
-class DcrAssembleCoaddConfig(CompareWarpAssembleCoaddConfig):
+
+class DcrAssembleCoaddConfig(CompareWarpAssembleCoaddConfig,
+                             pipelineConnections=DcrAssembleCoaddConnections):
     dcrNumSubfilters = pexConfig.Field(
         dtype=int,
         doc="Number of sub-filters to forward model chromatic effects to fit the supplied exposures.",
@@ -253,6 +315,60 @@ def __init__(self, *args, **kwargs):
             self.makeSubtask("measurePsfSources", schema=self.schema)
             self.makeSubtask("measurePsf", schema=self.schema)
 
+    @utils.inheritDoc(pipeBase.PipelineTask)
+    def runQuantum(self, butlerQC, inputRefs, outputRefs):
+        # The docstring below is combined with PipelineTask.runQuantum's by ``inheritDoc``
+        """
+        Notes
+        -----
+        Assemble a coadd from a set of Warps.
+
+        PipelineTask (Gen3) entry point to Coadd a set of Warps.
+        Analogous to `runDataRef`, it prepares all the data products to be
+        passed to `run`, and processes the results before returning a struct
+        of results to be written out. AssembleCoadd cannot fit all Warps in
+        memory. Therefore, its inputs are accessed subregion by subregion
+        by the Gen3 `DeferredDatasetHandle` that is analogous to the Gen2
+        `lsst.daf.persistence.ButlerDataRef`. Any updates to this method should
+        correspond to an update in `runDataRef` while both entry points
+        are used.
+        """
+        inputData = butlerQC.get(inputRefs)
+
+        # Construct skyInfo expected by run
+        # Do not remove skyMap from inputData in case makeSupplementaryDataGen3 needs it
+        skyMap = inputData["skyMap"]
+        outputDataId = butlerQC.quantum.dataId
+
+        inputData['skyInfo'] = makeSkyInfo(skyMap,
+                                           tractId=outputDataId['tract'],
+                                           patchId=outputDataId['patch'])
+
+        # Construct list of input Deferred Datasets
+        # These quack a bit like Gen2 DataRefs
+        warpRefList = inputData['inputWarps']
+        # Perform same middle steps as `runDataRef` does
+        inputs = self.prepareInputs(warpRefList)
+        self.log.info("Found %d %s", len(inputs.tempExpRefList),
+                      self.getTempExpDatasetName(self.warpType))
+        if len(inputs.tempExpRefList) == 0:
+            self.log.warn("No coadd temporary exposures found")
+            return  # nothing to coadd: exit before `run`, so no outputs are put
+
+        supplementaryData = self.makeSupplementaryDataGen3(butlerQC, inputRefs, outputRefs)
+        retStruct = self.run(inputData['skyInfo'], inputs.tempExpRefList, inputs.imageScalerList,
+                             inputs.weightList, supplementaryData=supplementaryData)
+
+        inputData.setdefault('brightObjectMask', None)
+        for subfilter in range(self.config.dcrNumSubfilters):
+            # Use the PSF of the stacked dcrModel, and do not recalculate the PSF for each subfilter
+            retStruct.dcrCoadds[subfilter].setPsf(retStruct.coaddExposure.getPsf())
+            self.processResults(retStruct.dcrCoadds[subfilter], inputData['brightObjectMask'], outputDataId)
+
+        if self.config.doWrite:
+            butlerQC.put(retStruct, outputRefs)
+        return retStruct
+
     @pipeBase.timeMethod
     def runDataRef(self, dataRef, selectDataList=None, warpRefList=None):
         """Assemble a coadd from a set of warps.
@@ -315,6 +431,23 @@ def runDataRef(self, dataRef, selectDataList=None, warpRefList=None):
 
         return results
 
+    @utils.inheritDoc(AssembleCoaddTask)
+    def makeSupplementaryDataGen3(self, butlerQC, inputRefs, outputRefs):
+        """Load the previously-generated template coadd from ``inputRefs.templateExposure``.
+
+        This can be removed entirely once we no longer support the Gen 2 butler.
+
+        Returns
+        -------
+        result : `lsst.pipe.base.Struct`
+           Result struct with components:
+
+           - ``templateCoadd``: coadded exposure (`lsst.afw.image.ExposureF`)
+        """
+        templateCoadd = butlerQC.get(inputRefs.templateExposure)
+
+        return pipeBase.Struct(templateCoadd=templateCoadd)
+
     def measureCoaddPsf(self, coaddExposure):
         """Detect sources on the coadd exposure and measure the final PSF.
 
@@ -350,7 +483,8 @@ def prepareDcrInputs(self, templateCoadd, warpRefList, weightList):
         ----------
         templateCoadd : `lsst.afw.image.ExposureF`
             The initial coadd exposure before accounting for DCR.
-        warpRefList : `list` of `lsst.daf.persistence.ButlerDataRef`
+        warpRefList : `list` of `lsst.daf.butler.DeferredDatasetHandle` or
+            `lsst.daf.persistence.ButlerDataRef`
             The data references to the input warped exposures.
         weightList : `list` of `float`
             The weight to give each input exposure in the coadd
@@ -378,9 +512,16 @@ def prepareDcrInputs(self, templateCoadd, warpRefList, weightList):
         angleDict = {}
         psfSizeDict = {}
         for visitNum, warpExpRef in enumerate(warpRefList):
-            visitInfo = warpExpRef.get(tempExpName + "_visitInfo")
-            visit = warpExpRef.dataId["visit"]
-            psf = warpExpRef.get(tempExpName).getPsf()
+            if isinstance(warpExpRef, DeferredDatasetHandle):
+                # Gen 3 API
+                visitInfo = warpExpRef.get(component="visitInfo")
+                psf = warpExpRef.get(component="psf")
+                visit = warpExpRef.datasetRefOrType.dataId["visit"]
+            else:
+                # Gen 2 API. Delete this when Gen 2 retired
+                visitInfo = warpExpRef.get(tempExpName + "_visitInfo")
+                psf = warpExpRef.get(tempExpName).getPsf()
+                visit = warpExpRef.dataId["visit"]
             psfSize = psf.computeShape().getDeterminantRadius()*sigma2fwhm
             airmass = visitInfo.getBoresightAirmass()
             parallacticAngle = visitInfo.getBoresightParAngle().asDegrees()
@@ -431,7 +572,8 @@ def run(self, skyInfo, warpRefList, imageScalerList, weightList,
         ----------
         skyInfo : `lsst.pipe.base.Struct`
             Patch geometry information, from getSkyInfo
-        warpRefList : `list` of `lsst.daf.persistence.ButlerDataRef`
+        warpRefList : `list` of `lsst.daf.butler.DeferredDatasetHandle` or
+            `lsst.daf.persistence.ButlerDataRef`
             The data references to the input warped exposures.
         imageScalerList : `list` of `lsst.pipe.task.ImageScaler`
             The image scalars correct for the zero point of the exposures.
@@ -571,7 +713,8 @@ def calculateNImage(self, dcrModels, bbox, warpRefList, spanSetMaskList, statsCt
             Best fit model of the true sky after correcting chromatic effects.
         bbox : `lsst.geom.box.Box2I`
             Bounding box of the patch to coadd.
-        warpRefList : `list` of `lsst.daf.persistence.ButlerDataRef`
+        warpRefList : `list` of `lsst.daf.butler.DeferredDatasetHandle` or
+            `lsst.daf.persistence.ButlerDataRef`
             The data references to the input warped exposures.
         spanSetMaskList : `list` of `dict` containing spanSet lists, or None
             Each element of the `dict` contains the new mask plane name
@@ -592,7 +735,12 @@ def calculateNImage(self, dcrModels, bbox, warpRefList, spanSetMaskList, statsCt
         dcrWeights = [afwImage.ImageF(bbox) for subfilter in range(self.config.dcrNumSubfilters)]
         tempExpName = self.getTempExpDatasetName(self.warpType)
         for warpExpRef, altMaskSpans in zip(warpRefList, spanSetMaskList):
-            exposure = warpExpRef.get(tempExpName + "_sub", bbox=bbox)
+            if isinstance(warpExpRef, DeferredDatasetHandle):
+                # Gen 3 API
+                exposure = warpExpRef.get(parameters={'bbox': bbox})
+            else:
+                # Gen 2 API. Delete this when Gen 2 retired
+                exposure = warpExpRef.get(tempExpName + "_sub", bbox=bbox)
             visitInfo = exposure.getInfo().getVisitInfo()
             wcs = exposure.getInfo().getWcs()
             mask = exposure.mask
@@ -644,7 +792,8 @@ def dcrAssembleSubregion(self, dcrModels, subExposures, bbox, dcrBBox, warpRefLi
             Bounding box of the subregion to coadd.
         dcrBBox : `lsst.geom.box.Box2I`
             Sub-region of the coadd which includes a buffer to allow for DCR.
-        warpRefList : `list` of `lsst.daf.persistence.ButlerDataRef`
+        warpRefList : `list` of `lsst.daf.butler.DeferredDatasetHandle` or
+            `lsst.daf.persistence.ButlerDataRef`
             The data references to the input warped exposures.
         statsCtrl : `lsst.afw.math.StatisticsControl`
             Statistics control object for coadd
@@ -664,7 +813,11 @@ def dcrAssembleSubregion(self, dcrModels, subExposures, bbox, dcrBBox, warpRefLi
         residualGeneratorList = []
 
         for warpExpRef in warpRefList:
-            exposure = subExposures[warpExpRef.dataId["visit"]]
+            if isinstance(warpExpRef, DeferredDatasetHandle):
+                visit = warpExpRef.datasetRefOrType.dataId["visit"]
+            else:
+                visit = warpExpRef.dataId["visit"]
+            exposure = subExposures[visit]
             visitInfo = exposure.getInfo().getVisitInfo()
             wcs = exposure.getInfo().getWcs()
             templateImage = dcrModels.buildMatchedTemplate(exposure=exposure,
@@ -785,7 +938,8 @@ def calculateConvergence(self, dcrModels, subExposures, bbox, warpRefList, weigh
             The pre-loaded exposures for the current subregion.
         bbox : `lsst.geom.box.Box2I`
             Sub-region to coadd
-        warpRefList : `list` of `lsst.daf.persistence.ButlerDataRef`
+        warpRefList : `list` of `lsst.daf.butler.DeferredDatasetHandle` or
+            `lsst.daf.persistence.ButlerDataRef`
             The data references to the input warped exposures.
         weightList : `list` of `float`
             The weight to give each input exposure in the coadd
@@ -807,10 +961,14 @@ def calculateConvergence(self, dcrModels, subExposures, bbox, warpRefList, weigh
         metric = 0.
         metricList = {}
         for warpExpRef, expWeight in zip(warpRefList, weightList):
-            exposure = subExposures[warpExpRef.dataId["visit"]][bbox]
+            if isinstance(warpExpRef, DeferredDatasetHandle):
+                visit = warpExpRef.datasetRefOrType.dataId["visit"]
+            else:
+                visit = warpExpRef.dataId["visit"]
+            exposure = subExposures[visit][bbox]
             singleMetric = self.calculateSingleConvergence(dcrModels, exposure, significanceImage, statsCtrl)
             metric += singleMetric
-            metricList[warpExpRef.dataId["visit"]] = singleMetric
+            metricList[visit] = singleMetric
             weight += 1.
         self.log.info("Individual metrics:\n%s", metricList)
         return 1.0 if weight == 0.0 else metric/weight
@@ -885,7 +1043,8 @@ def fillCoadd(self, dcrModels, skyInfo, warpRefList, weightList, calibration=Non
             Best fit model of the true sky after correcting chromatic effects.
         skyInfo : `lsst.pipe.base.Struct`
             Patch geometry information, from getSkyInfo
-        warpRefList : `list` of `lsst.daf.persistence.ButlerDataRef`
+        warpRefList : `list` of `lsst.daf.butler.DeferredDatasetHandle` or
+            `lsst.daf.persistence.ButlerDataRef`
             The data references to the input warped exposures.
         weightList : `list` of `float`
             The weight to give each input exposure in the coadd
@@ -1074,7 +1233,8 @@ def loadSubExposures(self, bbox, statsCtrl, warpRefList, imageScalerList, spanSe
             Sub-region to coadd
         statsCtrl : `lsst.afw.math.StatisticsControl`
             Statistics control object for coadd
-        warpRefList : `list` of `lsst.daf.persistence.ButlerDataRef`
+        warpRefList : `list` of `lsst.daf.butler.DeferredDatasetHandle` or
+            `lsst.daf.persistence.ButlerDataRef`
             The data references to the input warped exposures.
         imageScalerList : `list` of `lsst.pipe.task.ImageScaler`
             The image scalars correct for the zero point of the exposures.
@@ -1093,7 +1253,12 @@ def loadSubExposures(self, bbox, statsCtrl, warpRefList, imageScalerList, spanSe
         zipIterables = zip(warpRefList, imageScalerList, spanSetMaskList)
         subExposures = {}
         for warpExpRef, imageScaler, altMaskSpans in zipIterables:
-            exposure = warpExpRef.get(tempExpName + "_sub", bbox=bbox)
+            if isinstance(warpExpRef, DeferredDatasetHandle):
+                visit = warpExpRef.datasetRefOrType.dataId["visit"]
+                exposure = warpExpRef.get(parameters={'bbox': bbox})
+            else:
+                visit = warpExpRef.dataId["visit"]
+                exposure = warpExpRef.get(tempExpName + "_sub", bbox=bbox)
             if altMaskSpans is not None:
                 self.applyAltMaskPlanes(exposure.mask, altMaskSpans)
             imageScaler.scaleMaskedImage(exposure.maskedImage)
@@ -1104,7 +1269,7 @@ def loadSubExposures(self, bbox, statsCtrl, warpRefList, imageScalerList, spanSe
             # Set the image value of masked pixels to zero.
             # This eliminates needing the mask plane when stacking images in ``newModelFromResidual``
             exposure.image.array[(exposure.mask.array & statsCtrl.getAndMask()) > 0] = 0.
-            subExposures[warpExpRef.dataId["visit"]] = exposure
+            subExposures[visit] = exposure
         return subExposures
 
     def selectCoaddPsf(self, templateCoadd, warpRefList):
@@ -1114,7 +1279,8 @@ def selectCoaddPsf(self, templateCoadd, warpRefList):
         ----------
         templateCoadd : `lsst.afw.image.ExposureF`
             The initial coadd exposure before accounting for DCR.
-        warpRefList : `list` of `lsst.daf.persistence.ButlerDataRef`
+        warpRefList : `list` of `lsst.daf.butler.DeferredDatasetHandle` or
+            `lsst.daf.persistence.ButlerDataRef`
             The data references to the input warped exposures.
 
         Returns
@@ -1131,9 +1297,15 @@ def selectCoaddPsf(self, templateCoadd, warpRefList):
         psfSizes = np.zeros(len(ccds))
         ccdVisits = np.array(ccds["visit"])
         for warpExpRef in warpRefList:
-            psf = warpExpRef.get(tempExpName).getPsf()
+            if isinstance(warpExpRef, DeferredDatasetHandle):
+                # Gen 3 API
+                psf = warpExpRef.get(component="psf")
+                visit = warpExpRef.datasetRefOrType.dataId["visit"]
+            else:
+                # Gen 2 API. Delete this when Gen 2 retired
+                psf = warpExpRef.get(tempExpName).getPsf()
+                visit = warpExpRef.dataId["visit"]
             psfSize = psf.computeShape().getDeterminantRadius()*sigma2fwhm
-            visit = warpExpRef.dataId["visit"]
             psfSizes[ccdVisits == visit] = psfSize
         # Note that the input PSFs include DCR, which should be absent from the DcrCoadd
         # The selected PSFs are those that have a FWHM less than or equal to the smaller