diff --git a/src/rail/estimation/algos/lephare.py b/src/rail/estimation/algos/lephare.py index de7d8bc..d499099 100644 --- a/src/rail/estimation/algos/lephare.py +++ b/src/rail/estimation/algos/lephare.py @@ -1,5 +1,6 @@ from rail.estimation.estimator import CatEstimator, CatInformer from rail.core.common_params import SHARED_PARAMS +from ceci.config import StageParameter as Param import os import lephare as lp import numpy as np @@ -9,7 +10,7 @@ class LephareInformer(CatInformer): """Inform stage for LephareEstimator - This class will set templates and filters required for photoz estimation + This class will set templates and filters required for photoz estimation. """ name = "LephareInformer" @@ -24,16 +25,63 @@ class LephareInformer(CatInformer): err_bands=SHARED_PARAMS, ref_band=SHARED_PARAMS, redshift_col=SHARED_PARAMS, + lephare_config_file=Param( + str, + "{}/{}".format(os.path.dirname(os.path.abspath(__file__)), "lsst.para"), + msg="Path to the lephare config in .para format", + ), + star_sed=Param( + str, + "$LEPHAREDIR/examples/STAR_MOD_ALL.list", + msg="Path to text file containing list of star SED templates", + ), + qso_sed=Param( + str, + "$LEPHAREDIR/sed/QSO/SALVATO09/AGN_MOD.list", + msg="Path to text file containing list of galaxy SED templates", + ), + gal_sed=Param( + str, + "$LEPHAREDIR/examples/COSMOS_MOD.list", + msg="Path to text file containing list of quasar SED templates", + ), + star_mag_dict=Param( + dict, + dict( + lib_ascii="YES", + ), + msg="Dictionary of values sent to MagGal for stars", + ), + gal_mag_dict=Param( + dict, + dict( + lib_ascii="YES", + mod_extinc="18,26,26,33,26,33,26,33", + extinc_law=( + "SMC_prevot.dat,SB_calzetti.dat," + "SB_calzetti_bump1.dat,SB_calzetti_bump2.dat" + ), + em_lines="EMP_UV", + em_dispersion="0.5,0.75,1.,1.5,2.", + ), + msg="Dictionary of values sent to MagGal for galaxies", + ), + qso_mag_dict=Param( + dict, + dict( + lib_ascii="YES", + mod_extinc="0,1000", + eb_v="0.,0.1,0.2,0.3", + extinc_law="SB_calzetti.dat", + ), + msg="Dictionary of values sent to MagGal for quasars", + ), ) def __init__(self, args, comm=None): """Init function, init config stuff (COPIED from rail_bpz)""" CatInformer.__init__(self, args, comm=comm) - # Default local parameters - self.config_file = "{}/{}".format( - os.path.dirname(os.path.abspath(__file__)), "lsst.para" - ) - self.lephare_config = lp.read_config(self.config_file) + self.lephare_config = lp.read_config(self.config["lephare_config_file"]) def _set_config(self, lephare_config): """Update the lephare config @@ -48,7 +96,7 @@ def _set_config(self, lephare_config): def _create_filter_library(self): """Make the filter library files in lephare format""" # load filters from config file - filterLib = lp.FilterSvc.from_config(self.config_file) + filterLib = lp.FilterSvc.from_config(self.config["lephare_config_file"]) # Get location to store filter files filter_output = os.path.join( os.environ["LEPHAREWORK"], "filt", self.lephare_config["FILTER_FILE"].value @@ -64,9 +112,9 @@ def _create_sed_library(self): We separately create the star, quasar and galaxy libraries. """ sedlib = lp.Sedtolib(config_keymap=self.lephare_config) - sedlib.run(typ="STAR", star_sed="$LEPHAREDIR/examples/STAR_MOD_ALL.list") - sedlib.run(typ="QSO", qso_sed="$LEPHAREDIR/sed/QSO/SALVATO09/AGN_MOD.list") - sedlib.run(typ="GAL", gal_sed="$LEPHAREDIR/examples/COSMOS_MOD.list") + sedlib.run(typ="STAR", star_sed=self.config["star_sed"]) + sedlib.run(typ="GAL", gal_sed=self.config["gal_sed"]) + sedlib.run(typ="QSO", qso_sed=self.config["qso_sed"]) def _create_mag_library(self): """Make the magnitudes library file in lephare format. @@ -76,31 +124,15 @@ def _create_mag_library(self): TODO: replace hardcoded config options with class config options. """ maglib = lp.MagGal(config_keymap=self.lephare_config) - maglib.run(typ="STAR", lib_ascii="YES") - maglib.run( - typ="QSO", - lib_ascii="YES", - mod_extinc="0,1000", - eb_v="0.,0.1,0.2,0.3", - extinc_law="SB_calzetti.dat", - ) - maglib.run( - typ="GAL", - lib_ascii="YES", - mod_extinc="18,26,26,33,26,33,26,33", - extinc_law=( - "SMC_prevot.dat,SB_calzetti.dat," - + "SB_calzetti_bump1.dat,SB_calzetti_bump2.dat" - ), - em_lines="EMP_UV", - em_dispersion="0.5,0.75,1.,1.5,2.", - ) + maglib.run(typ="STAR", **self.config["star_mag_dict"]) + maglib.run(typ="GAL", **self.config["gal_mag_dict"]) + maglib.run(typ="QSO", **self.config["qso_mag_dict"]) def run(self): """Run rail_lephare inform stage. - This is the basic informer which takes the config and templates and - makes the inputs required for the run. + This informer takes the config and templates and makes the inputs + required for the run. In addition to the three lephare stages making the filter, sed, and magnitude libraries we also do some tasks required by all rail inform @@ -124,7 +156,7 @@ def run(self): self.szs = training_data[self.config.redshift_col] # Give principle inform config 'model' to instance. - self.model = dict(config_file=self.config_file) + self.model = dict(lephare_config_file=self.config["lephare_config_file"]) self.add_data("model", self.model) @@ -145,10 +177,27 @@ class LephareEstimator(CatEstimator): ref_band=SHARED_PARAMS, err_bands=SHARED_PARAMS, redshift_col=SHARED_PARAMS, + lephare_config_file=Param( + str, + "{}/{}".format(os.path.dirname(os.path.abspath(__file__)), "lsst.para"), + msg="Path to the lephare config in .para format", + ), ) def __init__(self, args, comm=None): CatEstimator.__init__(self, args, comm=comm) + self.lephare_config = lp.read_config(self.config["lephare_config_file"]) + self.photz = lp.PhotoZ(self.lephare_config) + + def _estimate_pdf(self, onesource): + """Return the pdf of a single source. + + Do we want to resample on RAIL z grid? + """ + # Check this is the best way to access pdf + pdf = onesource.pdfmap[11] # 11 = Bayesian galaxy redshift + # return the PDF as an array alongside lephare native zgrid + return np.array(pdf.vPDF), np.array(pdf.xaxis) # Default local parameters self.config_file = "{}/{}".format( @@ -189,9 +238,6 @@ def _process_chunk(self, start, end, data, first): pdfs = [] # np.zeros((ng, nz)) zmode = np.zeros(ng) zmean = np.zeros(ng) - # What are tb and todds? - # tb = np.zeros(ng) - # todds = np.zeros(ng) zgrid = self.zgrid # Loop over all ng galaxies! @@ -220,8 +266,7 @@ def _process_chunk(self, start, end, data, first): for i in range(ng): pdf, zgrid = self._estimate_pdf(photozlist[i]) pdfs.append(pdf) - # Take median incase multiple probability densities are equal - # TODO: why is that happening? + # Take median in case multiple probability densities are equal zmode[i] = np.median( zgrid[np.where(pdfs[i] == np.max(pdfs[i]))[0].astype(int)] )