diff --git a/src/rail/pipelines/degradation/truth_to_observed.py b/src/rail/pipelines/degradation/truth_to_observed.py new file mode 100644 index 0000000..a6708f2 --- /dev/null +++ b/src/rail/pipelines/degradation/truth_to_observed.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +# coding: utf-8 + +# Prerquisites, os, and numpy +import os +import numpy as np + +# Various rail modules +from rail.tools.photometry_tools import Dereddener, Reddener + +from rail.core.stage import RailStage, RailPipeline + +import ceci + +from rail.core.utils import RAILDIR + +from rail.creation.degraders.unrec_bl_model import UnrecBlModel + +from .spectroscopic_selection_pipeline import SELECTORS, CommonConfigParams +from .apply_phot_errors import ERROR_MODELS + + +if 'PZ_DUSTMAP_DIR' not in os.environ: # pragma: no cover + os.environ['PZ_DUSTMAP_DIR'] = '.' + +dustmap_dir = os.path.expandvars("${PZ_DUSTMAP_DIR}") + + +class TruthToObservedPipeline(RailPipeline): + + default_input_dict = dict(input='dummy.in') + + def __init__(self, error_models=None, selectors=None, blending=False): + RailPipeline.__init__(self) + + DS = RailStage.data_store + DS.__class__.allow_overwrite = True + + if error_models is None: + error_models = ERROR_MODELS.copy() + + if selectors is None: + selectors = SELECTORS.copy() + + config_pars = CommonConfigParams.copy() + + self.reddener = Reddener.build( + dustmap_dir=dustmap_dir, + copy_all_cols=True, + ) + previous_stage = self.reddener + + if blending: + self.unrec_bl = UnrecBlModel.build() + previous_stage = self.unrec_bl + + for key, val in error_models.items(): + error_model_class = ceci.PipelineStage.get_stage(val['ErrorModel'], val['Module']) + the_error_model = error_model_class.make_and_connect( + name=f'error_model_{key}', + connections=dict(input=previous_stage.io.output), + hdf5_groupname='', + ) + self.add_stage(the_error_model) + previous_stage = the_error_model + + dereddener_errors = Dereddener.make_and_connect( + name=f"deredden_{key}", + dustmap_dir=dustmap_dir, + connections=dict(input=previous_stage.io.output), + copy_all_cols=True, + ) + self.add_stage(dereddener_errors) + previous_stage = dereddener_errors + + for key2, val2 in selectors.items(): + the_class = ceci.PipelineStage.get_stage(val2['Select'], val2['Module']) + the_selector = the_class.make_and_connect( + name=f'select_{key}_{key2}', + connections=dict(input=previous_stage.io.output), + **config_pars, + ) + self.add_stage(the_selector) diff --git a/tests/astro_tools/test_pipline.py b/tests/astro_tools/test_pipline.py index 75efe47..61c9910 100644 --- a/tests/astro_tools/test_pipline.py +++ b/tests/astro_tools/test_pipline.py @@ -4,13 +4,14 @@ import pytest @pytest.mark.parametrize( - "pipeline_class", + "pipeline_class, options", [ - 'rail.pipelines.degradation.apply_phot_errors.ApplyPhotErrorsPipeline', - 'rail.pipelines.degradation.blending.BlendingPipeline', - 'rail.pipelines.degradation.spectroscopic_selection_pipeline.SpectroscopicSelectionPipeline', + ('rail.pipelines.degradation.apply_phot_errors.ApplyPhotErrorsPipeline', {}), + ('rail.pipelines.degradation.blending.BlendingPipeline', {}), + ('rail.pipelines.degradation.spectroscopic_selection_pipeline.SpectroscopicSelectionPipeline', {}), + ('rail.pipelines.degradation.truth_to_observed.TruthToObservedPipeline', {'blending':True}), ] ) -def test_build_and_read_pipeline(pipeline_class): - build_and_read_pipeline(pipeline_class) +def test_build_and_read_pipeline(pipeline_class, options): + build_and_read_pipeline(pipeline_class, **options)