From 8f5f4c835947beab6b514ae6a007b7a69f6c43aa Mon Sep 17 00:00:00 2001 From: cheth Date: Wed, 15 May 2024 16:04:57 -0700 Subject: [PATCH] 2021 Capstone publication --- .../daly/daly_calculator.py | 180 ++ .../hale/hale_calculator.py | 233 +++ .../life_expectancy/README.md | 12 + .../life_expectancy/lex.py | 183 ++ .../life_expectancy/lexmodel.py | 903 +++++++++ .../met_need/README.md | 45 + .../met_need/arc_main.py | 370 ++++ .../met_need/arc_method.py | 1082 +++++++++++ .../met_need/collect_submodels.py | 321 +++ .../met_need/constants.py | 201 ++ .../get_model_weights_from_holdouts.py | 210 ++ .../met_need/model_restrictions.py | 98 + .../met_need/omega_selection_strategy.py | 355 ++++ .../met_need/predictive_validity.py | 106 + .../migration/README.md | 61 + .../migration/age_sex_split.py | 122 ++ .../migration/aggregate_shocks_and_sdi.py | 193 ++ .../migration/arima_and_generate_draws.py | 263 +++ .../migration/balance_migration.py | 288 +++ .../migration/csv_to_xr.py | 124 ++ .../migration/migration_rate_to_count.py | 144 ++ .../migration/model_migration.py | 352 ++++ .../migration/model_strategy.py | 66 + .../migration/model_strategy_queries.py | 143 ++ .../migration/run_model.py | 293 +++ .../mortality/README.md | 135 ++ .../data-transformation/correlate.py | 107 + .../data-transformation/exponentiate_draws.py | 65 + .../data-transformation/intercept_shift.py | 469 +++++ .../mortality/lib/config_dataclasses.py | 57 + .../mortality/lib/downloaders.py | 809 ++++++++ .../mortality/lib/get_fatal_causes.py | 93 + .../mortality/lib/intercept_shift.py | 128 ++ .../mortality/lib/make_all_cause.py | 468 +++++ .../mortality/lib/make_hierarchies.py | 116 ++ .../mortality/lib/mortality_approximation.py | 788 ++++++++ .../mortality/lib/run_cod_model.py | 812 ++++++++ .../mortality/lib/smoothing.py | 91 + .../mortality/lib/squeeze.py | 433 +++++ .../mortality/lib/sum_to_all_cause.py | 425 ++++ .../mortality/lib/y_star.py | 495 +++++ .../mortality/models/gk-model/GKModel.py | 455 +++++ .../models/gk-model/model_parameters.py | 274 +++ .../mortality/models/gk-model/omega.py | 159 ++ .../mortality/models/gk-model/post_process.py | 122 ++ .../mortality/models/gk-model/pre_process.py | 97 + .../mortality/models/pooled_random_walk.py | 303 +++ .../mortality/models/random_walk.py | 163 ++ .../mortality/models/remove_drift.py | 175 ++ .../nonfatal/README.md | 77 + .../nonfatal/lib/check_entity_files.py | 78 + .../nonfatal/lib/constants.py | 54 + .../nonfatal/lib/indicator_from_ratio.py | 298 +++ .../nonfatal/lib/model_parameters.py | 20 + .../nonfatal/lib/model_strategy.py | 373 ++++ .../nonfatal/lib/model_strategy_queries.py | 199 ++ .../nonfatal/lib/ratio_from_indicators.py | 165 ++ .../nonfatal/lib/run_model.py | 414 ++++ .../nonfatal/lib/yld_from_prevalence.py | 84 + .../nonfatal/models/arc_method.py | 1082 +++++++++++ .../nonfatal/models/limetr.py | 1107 +++++++++++ .../models/omega_selection_strategy.py | 355 ++++ .../nonfatal/models/processing.py | 1724 +++++++++++++++++ .../nonfatal/models/validate.py | 39 + .../risk_factors/README.md | 113 ++ .../risk_factors/genem/arc_main.py | 371 ++++ .../risk_factors/genem/collect_submodels.py | 326 ++++ .../risk_factors/genem/constants.py | 197 ++ .../risk_factors/genem/create_stage.py | 539 ++++++ .../genem/get_model_weights_from_holdouts.py | 212 ++ .../risk_factors/genem/model_restrictions.py | 100 + .../risk_factors/genem/predictive_validity.py | 106 + .../risk_factors/genem/run_stagewise_mrbrt.py | 576 ++++++ .../risk_factors/paf/compute_paf.py | 869 
+++++++++ .../risk_factors/paf/compute_scalar.py | 461 +++++ .../risk_factors/paf/constants.py | 32 + .../risk_factors/paf/forecasting_db.py | 194 ++ .../risk_factors/paf/utils.py | 303 +++ .../sev/compute_future_mediator_total_sev.py | 448 +++++ .../sev/compute_past_intrinsic_sev.py | 444 +++++ .../risk_factors/sev/constants.py | 40 + .../risk_factors/sev/mediation.py | 89 + .../risk_factors/sev/rrmax.py | 52 + .../risk_factors/sev/run_workflow.py | 488 +++++ .../vaccine/aggregate_rake.py | 213 ++ .../vaccine/constants.py | 7 + .../vaccine/model_strategy.py | 100 + .../vaccine/model_strategy_queries.py | 77 + .../vaccine/run_ratio_vaccines.py | 1194 ++++++++++++ .../vaccine/run_simple_vaccines.py | 479 +++++ .../disease_burden_forecast_code/yll/yll.py | 111 ++ .../yll/yll_calculator.py | 119 ++ .../education/arc_weight_selection.py | 452 +++++ .../education/cohort_correction.py | 482 +++++ .../education/covid/apply_shocks.py | 676 +++++++ .../education/education_transform.py | 75 + .../education/forecast_education.py | 314 +++ .../education/maternal_education.py | 83 + .../fertility/__init__.py | 0 .../fertility/constants.py | 80 + .../fertility/input_transform.py | 165 ++ .../fertility_forecast_code/fertility/main.py | 318 +++ .../fertility/model_strategy.py | 93 + .../fertility/stage_1.py | 99 + .../fertility/stage_2.py | 723 +++++++ .../fertility/stage_3.py | 947 +++++++++ .../met_need/arc_forecast.py | 190 ++ .../create_log_habitable_area.py | 63 + gbd_2021/fertility_forecast_code/u5m/u5m.py | 186 ++ 109 files changed, 32862 insertions(+) create mode 100644 gbd_2021/disease_burden_forecast_code/daly/daly_calculator.py create mode 100644 gbd_2021/disease_burden_forecast_code/hale/hale_calculator.py create mode 100644 gbd_2021/disease_burden_forecast_code/life_expectancy/README.md create mode 100644 gbd_2021/disease_burden_forecast_code/life_expectancy/lex.py create mode 100644 gbd_2021/disease_burden_forecast_code/life_expectancy/lexmodel.py create mode 100644 gbd_2021/disease_burden_forecast_code/met_need/README.md create mode 100644 gbd_2021/disease_burden_forecast_code/met_need/arc_main.py create mode 100644 gbd_2021/disease_burden_forecast_code/met_need/arc_method.py create mode 100644 gbd_2021/disease_burden_forecast_code/met_need/collect_submodels.py create mode 100644 gbd_2021/disease_burden_forecast_code/met_need/constants.py create mode 100644 gbd_2021/disease_burden_forecast_code/met_need/get_model_weights_from_holdouts.py create mode 100644 gbd_2021/disease_burden_forecast_code/met_need/model_restrictions.py create mode 100644 gbd_2021/disease_burden_forecast_code/met_need/omega_selection_strategy.py create mode 100644 gbd_2021/disease_burden_forecast_code/met_need/predictive_validity.py create mode 100644 gbd_2021/disease_burden_forecast_code/migration/README.md create mode 100644 gbd_2021/disease_burden_forecast_code/migration/age_sex_split.py create mode 100644 gbd_2021/disease_burden_forecast_code/migration/aggregate_shocks_and_sdi.py create mode 100644 gbd_2021/disease_burden_forecast_code/migration/arima_and_generate_draws.py create mode 100644 gbd_2021/disease_burden_forecast_code/migration/balance_migration.py create mode 100644 gbd_2021/disease_burden_forecast_code/migration/csv_to_xr.py create mode 100644 gbd_2021/disease_burden_forecast_code/migration/migration_rate_to_count.py create mode 100644 gbd_2021/disease_burden_forecast_code/migration/model_migration.py create mode 100644 gbd_2021/disease_burden_forecast_code/migration/model_strategy.py create mode 100644 
gbd_2021/disease_burden_forecast_code/migration/model_strategy_queries.py create mode 100644 gbd_2021/disease_burden_forecast_code/migration/run_model.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/README.md create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/data-transformation/correlate.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/data-transformation/exponentiate_draws.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/data-transformation/intercept_shift.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/lib/config_dataclasses.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/lib/downloaders.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/lib/get_fatal_causes.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/lib/intercept_shift.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/lib/make_all_cause.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/lib/make_hierarchies.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/lib/mortality_approximation.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/lib/run_cod_model.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/lib/smoothing.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/lib/squeeze.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/lib/sum_to_all_cause.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/lib/y_star.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/GKModel.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/model_parameters.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/omega.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/post_process.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/pre_process.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/models/pooled_random_walk.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/models/random_walk.py create mode 100644 gbd_2021/disease_burden_forecast_code/mortality/models/remove_drift.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/README.md create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/lib/check_entity_files.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/lib/constants.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/lib/indicator_from_ratio.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/lib/model_parameters.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/lib/model_strategy.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/lib/model_strategy_queries.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/lib/ratio_from_indicators.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/lib/run_model.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/lib/yld_from_prevalence.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/models/arc_method.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/models/limetr.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/models/omega_selection_strategy.py create mode 100644 
gbd_2021/disease_burden_forecast_code/nonfatal/models/processing.py create mode 100644 gbd_2021/disease_burden_forecast_code/nonfatal/models/validate.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/README.md create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/genem/arc_main.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/genem/collect_submodels.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/genem/constants.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/genem/create_stage.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/genem/get_model_weights_from_holdouts.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/genem/model_restrictions.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/genem/predictive_validity.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/genem/run_stagewise_mrbrt.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/paf/compute_paf.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/paf/compute_scalar.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/paf/constants.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/paf/forecasting_db.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/paf/utils.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/sev/compute_future_mediator_total_sev.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/sev/compute_past_intrinsic_sev.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/sev/constants.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/sev/mediation.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/sev/rrmax.py create mode 100644 gbd_2021/disease_burden_forecast_code/risk_factors/sev/run_workflow.py create mode 100644 gbd_2021/disease_burden_forecast_code/vaccine/aggregate_rake.py create mode 100644 gbd_2021/disease_burden_forecast_code/vaccine/constants.py create mode 100644 gbd_2021/disease_burden_forecast_code/vaccine/model_strategy.py create mode 100644 gbd_2021/disease_burden_forecast_code/vaccine/model_strategy_queries.py create mode 100644 gbd_2021/disease_burden_forecast_code/vaccine/run_ratio_vaccines.py create mode 100644 gbd_2021/disease_burden_forecast_code/vaccine/run_simple_vaccines.py create mode 100644 gbd_2021/disease_burden_forecast_code/yll/yll.py create mode 100644 gbd_2021/disease_burden_forecast_code/yll/yll_calculator.py create mode 100644 gbd_2021/fertility_forecast_code/education/arc_weight_selection.py create mode 100644 gbd_2021/fertility_forecast_code/education/cohort_correction.py create mode 100644 gbd_2021/fertility_forecast_code/education/covid/apply_shocks.py create mode 100644 gbd_2021/fertility_forecast_code/education/education_transform.py create mode 100644 gbd_2021/fertility_forecast_code/education/forecast_education.py create mode 100644 gbd_2021/fertility_forecast_code/education/maternal_education.py create mode 100644 gbd_2021/fertility_forecast_code/fertility/__init__.py create mode 100644 gbd_2021/fertility_forecast_code/fertility/constants.py create mode 100644 gbd_2021/fertility_forecast_code/fertility/input_transform.py create mode 100644 gbd_2021/fertility_forecast_code/fertility/main.py create mode 100644 gbd_2021/fertility_forecast_code/fertility/model_strategy.py 
create mode 100644 gbd_2021/fertility_forecast_code/fertility/stage_1.py create mode 100644 gbd_2021/fertility_forecast_code/fertility/stage_2.py create mode 100644 gbd_2021/fertility_forecast_code/fertility/stage_3.py create mode 100644 gbd_2021/fertility_forecast_code/met_need/arc_forecast.py create mode 100644 gbd_2021/fertility_forecast_code/pop_by_habitable_area/create_log_habitable_area.py create mode 100644 gbd_2021/fertility_forecast_code/u5m/u5m.py diff --git a/gbd_2021/disease_burden_forecast_code/daly/daly_calculator.py b/gbd_2021/disease_burden_forecast_code/daly/daly_calculator.py new file mode 100644 index 0000000..fbe1276 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/daly/daly_calculator.py @@ -0,0 +1,180 @@ +r"""For each cause, compute a daly from the sum of yld and yll. + +.. math:: + + \text{DALY} = \text{YLD} + \text{YLL} + +Parallelized by cause. + +Example call for future only: + +.. code:: bash + + fhs_pipeline_dalys_console parallelize-by-cause \ + --gbd-round-id 89 \ + --versions FILEPATH \ + -v FILEPATH \ + --draws 500 \ + --years 1776:1864:1927 \ + +Example call for future & past: + +.. code:: bash + + fhs_pipeline_dalys_console parallelize-by-cause \ + --gbd-round-id 89 \ + --versions FILEPATH \ + -v FILEPATH \ + --draws 500 \ + --years 1776:1864:1927 \ + --past include +""" # noqa: D208 +from typing import Tuple, Union + +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_data_transformation.lib.validate import assert_coords_same +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_year_range_manager.lib import YearRange +from tiny_structured_logger.lib import fhs_logging + +logger = fhs_logging.get_logger() + + +def fill_missing_coords( + yld_da: xr.DataArray, yll_da: xr.DataArray, dim: str +) -> Tuple[xr.DataArray, xr.DataArray]: + """Check if there are missing coordinates between two data arrays. + + Fills any missing coordinates in the respective DAs with 0. + + Args: + yld_da (xr.DataArray): YLD data to check for missing coords. + yll_da (xr.DataArray): YLL data to check for missing coords. + dim (str): Dimension to check for missing coords. 
+ + Returns: + Tuple(xr.DataArray, xr.DataArray): yld_da and yll_da with missing coords filled with 0 + """ + yld_coord_values = tuple(yld_da[dim].values) + yll_coord_values = tuple(yll_da[dim].values) + yll_missing_coords = list(set(yld_coord_values) - set(yll_coord_values)) + yld_missing_coords = list(set(yll_coord_values) - set(yld_coord_values)) + + if yll_missing_coords: + logger.warning(f"{dim}:{yll_missing_coords} are missing from YLLs") + yll_da = expand_dimensions(yll_da, **{dim: yll_missing_coords}, fill_value=0) + + if yld_missing_coords: + logger.warning(f"{dim}:{yld_missing_coords} are missing from YLDs") + yld_da = expand_dimensions(yld_da, **{dim: yld_missing_coords}, fill_value=0) + + return yld_da, yll_da + + +def _get_years_in_slice(years: YearRange, past_or_future: str) -> YearRange: + if past_or_future == "future": + years_in_slice = years.forecast_years + elif past_or_future == "past": + years_in_slice = years.past_years + else: + raise RuntimeError("past_or_future must be `past` or `future`") + + return years_in_slice + + +def _read_and_resample_stage( + stage: str, + versions: Versions, + gbd_round_id: int, + draws: int, + years: YearRange, + past_or_future: str, + acause: str, +) -> Union[xr.DataArray, int]: + years_in_slice = _get_years_in_slice(years, past_or_future) + + logger.info(f"acause: {acause} years: {years_in_slice}") + + stage_file_metadata = versions.get(past_or_future, stage).default_data_source(gbd_round_id) + + try: + da = open_xr_scenario( + file_spec=FHSFileSpec( + version_metadata=stage_file_metadata, filename=f"{acause}.nc" + ) + ).sel(year_id=years_in_slice) + + da = resample(da, draws) + except OSError: + logger.warning("{} does not have YLDs".format(acause)) + da = 0 + return da + + +def one_cause_main( + versions: Versions, + gbd_round_id: int, + draws: int, + years: YearRange, + past_or_future: str, + acause: str, +) -> None: + """Compute a daly from the yld and yll at the cause level. + + Args: + versions (Versions): A Versions object that keeps track of all the versions and their + respective data directories. + gbd_round_id (int): What gbd_round_id that yld, yll and daly are saved under. + draws (int): How many draws to save for the daly output. + years (str): years for calculation. Will use either the past or future portion of the + year range depending on the value of past_or_future. + past_or_future (str): whether calculating past or future values. Must be "past" or + "future". + acause (str): cause to calculate dalys for. 
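The core of the calculation is coordinate alignment followed by a sum; a minimal sketch with toy xarray inputs, using plain ``xr.align`` in place of the ``fill_missing_coords``/``expand_dimensions`` helpers used here:

.. code:: python

    import xarray as xr

    # Toy YLD/YLL rates for one cause; sex_id 2 is missing from the YLLs.
    yld = xr.DataArray(
        [[0.01, 0.02], [0.03, 0.04]],
        coords={"age_group_id": [10, 11], "sex_id": [1, 2]},
        dims=["age_group_id", "sex_id"],
    )
    yll = xr.DataArray(
        [[0.05], [0.06]],
        coords={"age_group_id": [10, 11], "sex_id": [1]},
        dims=["age_group_id", "sex_id"],
    )

    # Fill coordinates missing from either array with 0, then sum.
    yld, yll = xr.align(yld, yll, join="outer", fill_value=0)
    daly = yld + yll  # DALY = YLD + YLL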
+ + Raises: + RuntimeError: if `past_or_future` is not "past" or "future" + ValueError: if the `daly` DA doesn't have YLLs or YLDs + """ + logger.info("Entering `one_cause_main` function.") + + yld = _read_and_resample_stage( + "yld", versions, gbd_round_id, draws, years, past_or_future, acause + ) + yll = _read_and_resample_stage( + "yll", versions, gbd_round_id, draws, years, past_or_future, acause + ) + + if isinstance(yld, xr.DataArray) and "acause" not in yld.dims: # type: ignore + yld = yld.expand_dims(acause=[acause]) # type: ignore + if isinstance(yll, xr.DataArray) and "acause" not in yll.dims: # type: ignore + yll = yll.expand_dims(acause=[acause]) # type: ignore + + if isinstance(yll, xr.DataArray) and isinstance(yld, xr.DataArray): + yld, yll = fill_missing_coords(yld, yll, dim="age_group_id") + yld, yll = fill_missing_coords(yld, yll, dim="sex_id") + yld, yll = fill_missing_coords(yld, yll, dim="location_id") + + assert_coords_same(yld, yll) + + daly = yld + yll + + if not isinstance(daly, xr.DataArray): + err_msg = f"{acause} is missing both YLDs and YLLs" + logger.error(err_msg) + raise ValueError(err_msg) + + daly_file_metadata = versions.get(past_or_future, "daly").default_data_source(gbd_round_id) + + save_xr_scenario( + xr_obj=daly, + file_spec=FHSFileSpec(version_metadata=daly_file_metadata, filename=f"{acause}.nc"), + metric="rate", + space="identity", + ) + + logger.info("Leaving `one_cause_main` function. DONE") \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/hale/hale_calculator.py b/gbd_2021/disease_burden_forecast_code/hale/hale_calculator.py new file mode 100644 index 0000000..f5570ee --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/hale/hale_calculator.py @@ -0,0 +1,233 @@ +r"""Compute HALE (Health-Adjusted Life Expectancy) from YLD rate and Life Table. + +**NOTE:** The input is all causes YLD rate and Life Table. HALE will be +saved in the hale stage directory, and the filename will be `hale.nc`. + +.. math:: + \text{adjusted_nLx} = \text{nLx} * \text{100000} * (1-YLD_{\text{rate}}) + +.. math:: + \text{adjusted_Tx} = \sum\limits_{a=x}^{\text{max-age}} \text{adjusted_nLa} + +.. math:: + \text{hale} = \text{adjusted_Tx} / \text{lx} + +Example calculate-hale call: + +.. code:: bash + + fhs_pipeline_hale_console calculate-hale \ + --versions FILEPATH1 \ + -v FILEPATH2 \ + -v FILEPATH3 \ + --draws 500 \ + --gbd-round-id 89 \ + --years 1776:1812:1944 + +Example parallelize call: + +.. 
code:: bash + + fhs_pipeline_hale_console parallelize \ + --versions FILEPATH1 \ + -v FILEPATH2 \ + -v FILEPATH3 \ + -v FILEPATH4 \ + --draws 500 \ + --gbd-round-id 89 \ + --years 1776:1812:1944 + +""" # noqa: E501 +from typing import List, Tuple + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_database_interface.lib.constants import DimensionConstants +from fhs_lib_database_interface.lib.query.age import get_ages +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging + +logger = fhs_logging.get_logger() + +LIFE_TABLE_NLX_FACTOR = 1e5 + + +def calculate_hale_main( + draws: int, + gbd_round_id: int, + past_or_future: str, + run_on_means: bool, + versions: Versions, + year_set: List[int], + include_nLx: bool = False, +) -> None: + """Compute hale from yld and life table. + + Args: + draws (int): the number of draws to forecast on + gbd_round_id (int): What gbd_round_id that yld, life table and hale are saved under. + past_or_future (str): whether we'll be reading data from the past or future. + run_on_means (bool): whether to run data on means instead of draws + versions (Versions): A Versions object that keeps track of all the versions and their + respective data directories. + year_set (List[int]): the years we care about in this task. + include_nLx (bool): Defaults to ``False``. Flag to calculate and save adjusted nLx. + If ``True``, will save ``adjusted_nLx`` data. + """ + yld, nLx, lx = _load_data( + draws, gbd_round_id, past_or_future, run_on_means, versions, year_set + ) + + # Compute adjusted_nLx + adjusted_nLx = (nLx * (1 - yld)) * LIFE_TABLE_NLX_FACTOR + adjusted_nLx.name = "adjusted_nLx" + + # Compute adjusted_Tx + adjusted_Tx = _compute_adjusted_Tx(adjusted_nLx, gbd_round_id) + + # Compute hale + hale = adjusted_Tx / lx + hale.name = "hale" + + # Save out hale + hale_version_metadata = versions.get(past_or_future, "hale").default_data_source( + gbd_round_id + ) + save_xr_scenario( + xr_obj=hale, + file_spec=FHSFileSpec(version_metadata=hale_version_metadata, filename="hale.nc"), + metric="number", + space="identity", + ) + if include_nLx: + save_xr_scenario( + xr_obj=adjusted_nLx, + file_spec=FHSFileSpec( + version_metadata=hale_version_metadata, + filename="adjusted_nLx.nc", + ), + metric="rate", + space="identity", + ) + + +def _load_data( + draws: int, + gbd_round_id: int, + past_or_future: str, + run_on_means: bool, + versions: Versions, + year_set: List[int], +) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]: + """Load ``yld``, ``nLx``, and ``lx`` data; resampling or aggregating as required.""" + # Setup input data dirs + life_expectancy_version_metadata = versions.get( + past_or_future, "life_expectancy" + ).default_data_source(gbd_round_id) + yld_version_metadata = versions.get(past_or_future, "yld").default_data_source( + gbd_round_id + ) + + # Read life table data + if (life_expectancy_version_metadata.data_path("lifetable_ds_agg.nc")).exists(): + life_table_file = "lifetable_ds_agg.nc" + else: + life_table_file = "lifetable_ds.nc" + + life_table = open_xr_scenario( + file_spec=FHSFileSpec( + version_metadata=life_expectancy_version_metadata, filename=life_table_file + ) + ) + + nLx: xr.DataArray = life_table["nLx"] + lx: 
xr.DataArray = life_table["lx"] + + # Read in YLD data: either from a summary file or from a file containing draws + if run_on_means: + logger.info("Reading data without draws") + yld = open_xr_scenario( + file_spec=FHSFileSpec( + version_metadata=yld_version_metadata, + sub_path=("summary_agg"), + filename="summary.nc", + ) + ) + yld = yld.sel(acause="_all", statistic="mean", drop=True) + + # Take mean over draws if they're present + if DimensionConstants.DRAW in life_table.dims: + nLx = nLx.mean(DimensionConstants.DRAW) + lx = lx.mean(DimensionConstants.DRAW) + + else: + logger.info("Reading data with draws") + yld = open_xr_scenario( + file_spec=FHSFileSpec(version_metadata=yld_version_metadata, filename="_all.nc") + ) + + if isinstance(yld, xr.Dataset): + yld = yld["value"] + yld = resample(data=yld, num_of_draws=draws) + + if "acause" in yld.dims: + yld = yld.sel(acause="_all", drop=True) + + nLx = resample(data=nLx, num_of_draws=draws) + lx = resample(data=lx, num_of_draws=draws) + + # Subset to just the years we're running on + logger.info("Subsetting yld, nLx, and lx onto relevant years") + year_set_dict = {DimensionConstants.YEAR_ID: year_set} + yld = yld.sel(**year_set_dict) + nLx = nLx.sel(**year_set_dict) + lx = lx.sel(**year_set_dict) + + return yld, nLx, lx + + +def _compute_adjusted_Tx(adjusted_nLx: xr.DataArray, gbd_round_id: int) -> xr.DataArray: + """Compute the adjusted_Tx by age group.""" + # Create age_group_id, age_group_years_start mapping + age_df = get_ages(gbd_round_id=gbd_round_id)[["age_group_id", "age_group_years_start"]] + age_dict = age_df.set_index("age_group_id")["age_group_years_start"].to_dict() + + # Compute adjusted_Tx + adjusted_Tx = adjusted_nLx.copy(deep=True) + for age_group_id in adjusted_Tx.age_group_id.data: + logger.debug(f"Computing adjusted_Tx for age group {age_group_id}") + age_group_years_start = age_dict[age_group_id] # noqa: F841 + older_age_group_ids = np.intersect1d( + age_df.query("age_group_years_start >= @age_group_years_start")[ + "age_group_id" + ].unique(), + adjusted_Tx.age_group_id.data, + ) + adjusted_Tx.loc[dict(age_group_id=age_group_id)] = adjusted_nLx.sel( + age_group_id=older_age_group_ids + ).sum("age_group_id") + + return adjusted_Tx + + +def determine_year_set(years: YearRange, past: str) -> List[int]: + """Determine which years we care about based on the value of ``past``.""" + if past == "include": + year_set = list(years.years) + elif past == "only": + year_set = list(years.past_years) + else: + year_set = list(years.forecast_years) + + return year_set + + +def determine_past_or_future(years: YearRange, year_set: List[int]) -> str: + """Determine whether the current ``year_set`` represents past or future data.""" + if set(year_set).issubset(years.past_years): + return "past" + return "future" \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/life_expectancy/README.md b/gbd_2021/disease_burden_forecast_code/life_expectancy/README.md new file mode 100644 index 0000000..cac2dde --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/life_expectancy/README.md @@ -0,0 +1,12 @@ +Life Expectancy +========================= + + +lex.py +--------- +Driver code that computes the forecasted life table based on future all-cause mortality rate. + +lexmodel.py +--------- +Contains various models that help compute the life table and handle edge cases. 
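For orientation, the quantities these modules produce (`mx`, `ax`, `qx`, `lx`, `nLx`, `ex`) relate as in a standard abridged period life table. Below is a minimal numpy sketch under the simplifying constant-mortality assumption `ax = nx / 2`; the models here estimate `ax` much more carefully (graduation method, Preston under-5 fits, old-age extrapolation).

```python
import numpy as np

nx = np.array([5.0, 5.0, 5.0, 5.0])      # age interval widths (years)
mx = np.array([0.01, 0.02, 0.05, 0.12])  # all-cause mortality rates

ax = nx / 2                               # mean years lived by those who die in the interval
qx = (mx * nx) / (1 + mx * (nx - ax))     # probability of dying in the interval
qx[-1] = 1.0                              # terminal age group
lx = np.concatenate(([1.0], np.cumprod(1 - qx)[:-1]))  # survivors at start of age x
dx = lx * qx                              # deaths in the interval
nLx = nx * (lx - dx) + ax * dx            # person-years lived in the interval
nLx[-1] = lx[-1] / mx[-1]                 # terminal group: lx / mx
Tx = np.cumsum(nLx[::-1])[::-1]           # person-years lived above age x
ex = Tx / lx                              # period life expectancy at age x
```

HALE (see `hale_calculator.py` above) then rescales `nLx` by `(1 - YLD rate)` before recomputing `Tx` and dividing by `lx`.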
+ diff --git a/gbd_2021/disease_burden_forecast_code/life_expectancy/lex.py b/gbd_2021/disease_burden_forecast_code/life_expectancy/lex.py new file mode 100644 index 0000000..9e8f40e --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/life_expectancy/lex.py @@ -0,0 +1,183 @@ +"""Calculates Life Expectancy. +""" + +import gc +from typing import Callable, List, Optional, Tuple + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_database_interface.lib.constants import DimensionConstants +# This refers to the lexmodel file that is also included in this directory, +# which also includes its own set of imports +from fhs_lib_demographic_calculation.lib import lexmodel as model +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec, VersionMetadata +from fhs_lib_file_interface.lib.versioning import validate_versions_scenarios +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from tiny_structured_logger.lib.fhs_logging import get_logger + +logger = get_logger() + +MODELS = [mxx for mxx in dir(model) if mxx.startswith("gbd")] + + +def main( + gbd_round_id: int, + draws: int, + input_version: str, + past_or_future: str, + output_scenario: int | None, + output_version: str, + lx_function: Callable, + chunk_size: int = 100, + suffix: Optional[str] = None, +) -> None: + """Entry point for lex for outside calls. + + If ``draw`` dim exists, chunk over draws. Else if ``scenario`` exists, + chunk over scenarios. + + Args: + gbd_round_id (int): GBD Round used for input and output versions. + draws (int): number of draws to keep, if draw dimension exists. + input_version (str): The version of input to use. Include the :scenario_id notation to + run in single-scenario mode + past_or_future (str): "past" or "future". + output_scenario (int | None): Scenario ID to use for output. + output_version (str): Name to create the output data under. Include the :scenario_id + notation to run in single-scenario mode. + lx_function (Callable): The function to call to calculate life expectancy. Should be + one of the models, all of which are model.gbd* in model.py. + chunk_size (int): number of draw chunks to work on within each iteration. + suffix (Optional[str]): If not None, include this suffix as part of each filename + + Raises: + ValueError: If the mx data has negatives or infinite values. 
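Example (illustrative only; the version strings are placeholders and any of the ``gbd*`` functions in ``lexmodel`` can be passed as ``lx_function``):

    .. code:: python

        from fhs_lib_demographic_calculation.lib import lexmodel as model

        main(
            gbd_round_id=89,
            draws=500,
            input_version="FILEPATH",
            past_or_future="future",
            output_scenario=None,
            output_version="FILEPATH",
            lx_function=model.gbd5_no_old_age_fit,
        )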
+ """ + input_version_metadata: VersionMetadata = ( + VersionMetadata.parse_version(input_version, default_stage="death") + .default_data_source(gbd_round_id) + .with_epoch(past_or_future) + ) + output_version_metadata: VersionMetadata = ( + VersionMetadata.parse_version(output_version, default_stage="life_expectancy") + .default_data_source(gbd_round_id) + .with_epoch(past_or_future) + ) + + validate_versions_scenarios( + versions=[input_version_metadata, output_version_metadata], + output_scenario=output_scenario, + output_epoch_stages=[(past_or_future, "death"), (past_or_future, "life_expectancy")], + ) + + mx_file_spec = FHSFileSpec(version_metadata=input_version_metadata, filename="FILEPATH") + mx = open_xr_scenario(mx_file_spec) + + if type(mx) == xr.Dataset: + logger.info(f"Input mx to lex code is a {xr.Dataset}") + mx = mx[DimensionConstants.VALUE] + + if float(mx.min()) < 0: + raise ValueError("Negatives in mx") + + if ~np.isfinite(mx).all(): + raise ValueError("Non-finites in mx") + + mx_no_point, point_coords = model.without_point_coordinates(mx) + + del mx + gc.collect() + + # Because we operate across age group id in all the calculations so + # this will be much faster. + reordered = list(mx_no_point.dims) + reordered.remove(DimensionConstants.AGE_GROUP_ID) + reordered.append(DimensionConstants.AGE_GROUP_ID) + mx_no_point = mx_no_point.transpose(*reordered) + + mx_no_point, chunk_dim = _set_chunk_dim(mx_no_point, draws) + + if chunk_dim is None: + ds = lx_function(mx_no_point) + else: # compute over chunks along either draw or scenario dim + dim_size = len(mx_no_point[chunk_dim]) + logger.info(f"Chunking over {chunk_dim} dim over {dim_size} coords") + + chunk_da_list: List[xr.DataArray] = [] + for start_idx in range(0, dim_size, chunk_size): + end_idx = ( + start_idx + chunk_size if start_idx + chunk_size <= dim_size else dim_size + ) + mx_small = mx_no_point.sel( + {chunk_dim: mx_no_point[chunk_dim].values[start_idx:end_idx]} + ) + ds_small = lx_function(mx_small) + chunk_da_list.append(ds_small) + + # Concatenate all the small dataarrays + ds = xr.concat(chunk_da_list, dim=chunk_dim) + + del mx_no_point + gc.collect() + + ds_point = ds.assign_coords(**point_coords) + + del ds + gc.collect() + + suffix = f"_{suffix}" if suffix else "" + lex_file = FHSFileSpec( + version_metadata=output_version_metadata, filename=f"FILEPATH" + ) + save_xr_scenario( + ds_point.ex, + lex_file, + metric="number", + space="identity", + mx_source=str(mx_file_spec), + model=str(lx_function.__name__), + ) + + dataset_file = FHSFileSpec( + version_metadata=output_version_metadata, filename=f"FILEPATH" + ) + + # ds contains mx, ax, lx, nLx, and ex. + save_xr_scenario( + ds_point, + dataset_file, + metric="number", + space="identity", + mx_source=str(mx_file_spec), + model=str(lx_function.__name__), + ) + + logger.info(f"wrote {dataset_file}") + + +def _set_chunk_dim(da: xr.DataArray, draws: int) -> Tuple[xr.DataArray, Optional[str]]: + """Set chunk dimension. + + If "draw" in da.dims, resample and chunk along draws. + Otherwise, set chunk_dim to be "scenario". + If neither "draw" nor "scenario" exists, then there's no dim to chunk over. + + Args: + da (xr.DataArray): input data array. + draws (int): number of draws to resample. + + Returns: + Tuple[xr.DataArray, str]: + xr.DataArray: either resampled input da, or just the input da. + str: dimension to chunk over. Could be None. 
+ """ + if DimensionConstants.DRAW in da.dims: # if draw dim exists, chunk over draws + da = resample(da, draws) + chunk_dim = DimensionConstants.DRAW + elif DimensionConstants.SCENARIO in da.dims: + chunk_dim = DimensionConstants.SCENARIO + else: + chunk_dim = None + + return da, chunk_dim \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/life_expectancy/lexmodel.py b/gbd_2021/disease_burden_forecast_code/life_expectancy/lexmodel.py new file mode 100644 index 0000000..f9e0708 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/life_expectancy/lexmodel.py @@ -0,0 +1,903 @@ +"""Models of life expectancy. + +Models of life expectancy take mortality rate as input and return +a dataset of life expectancy, mortality rate, and mean age of death. +They are models because they have to estimate the mean age of death. + +* There are two kinds of models here, those for GBD 2015 and those + for GBD 2016. They differ by what the input and output age group IDs are. + +* For GBD 2015, the input age groups start with IDs (2, 3, 4) for ENN, + PNN, and LNN. They end with IDs (20, 21) for 75-79 and 80+. + +* For GBD 2016, the input age groups start with IDs (2, 3, 4) for ENN, + PNN, and LNN. They end with IDs 235 for 95+. + +* GBD 2017 uses the same age group IDs as 2016. + +* For both 2015/2016, we predict life expectancy on young ages (28, 5), + meaning 0-1 and 1-5. For older ages, we predict to age ID 148 for 110+. + +* For 2017, we take forecasted mx straight up to compute the the life table. + That is, no extrapolation/extension is done to ameliorate the older ages. +""" + +import gc +import warnings +from typing import Tuple, Union + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_database_interface.lib.constants import DimensionConstants +from fhs_lib_database_interface.lib.query import age +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_lib_demographic_calculation.lib.constants import ( + AgeConstants, + LexmodelConstants, + LifeTableConstants, +) +from fhs_lib_demographic_calculation.lib.construct import ( + age_id_older_than, + consistent_age_group_ids, + nx_contiguous_round, + nx_from_age_group_ids, +) +from fhs_lib_demographic_calculation.lib.lifemodel import ( + old_age_fit_qx_95_plus, + old_age_fit_us_counties, +) +from fhs_lib_demographic_calculation.lib.lifetable import ( + ax_graduation_cubic, + cm_mean_age, + fm_mortality, + fm_period_life_expectancy, + fm_person_years, + fm_population, +) + +logger = get_logger() + + +def without_point_coordinates( + ds: Union[xr.DataArray, xr.Dataset] +) -> Tuple[xr.DataArray, dict]: + r"""Remove point coordinates and return them, so you can add them back later. + + The code would look like this: + no_point, saved_coords = without_point_coordinates(ds) + # Do things + return results.assign_coords(\**saved_coords) + + Args: + ds (Union[xr.DataArray, xr.Dataset]): A dataarray that may have point coords + + Returns: + Tuple[xr.DataArray, dict]: The point coordinates are copied and returned. + """ + point = dict( + (pname, ds.coords[pname].values.copy()) + for pname in ds.coords + if ds.coords[pname].shape == () + ) + return ds.drop_vars(list(point)), point + + +def append_nLx_lx(ds: xr.Dataset) -> xr.Dataset: + r"""Adds :math:`{}_nL_x` and :math:`l_x` to the dataset. + + It's sometimes called :math:`{}_nU_x`. 
+ + Args: + ds (xr.Dataset): Dataset containing mx and ax + + Returns: + xr.DataSet: ds (xr.Dataset): Dataset containing mx and ax + """ + nx = nx_from_age_group_ids(ds.mx.age_group_id) + nLx = fm_person_years(ds.mx, ds.ax, nx) + lx, dx = fm_population( + ds.mx, ds.ax, nx, LifeTableConstants.DEFAULT_INITIAL_POPULATION_SIZE + ) + return ds.assign(nLx=nLx, lx=lx) + + +def under_5_ax_preston_rake( + gbd_round_id: int, + mx: xr.DataArray, + ax: xr.DataArray, + nx: xr.DataArray, +) -> xr.DataArray: + r"""Uses :math:`m_x` and Preston's Table 3.3 to compute ax for FHS under-5 age groups. + + The steps are as follows: + + (1) Aggregate neonatal mx's to make :math:`{}_1m_0` + This assumes the nLx values are approximately correct. + + (2) Use Preston's Table 3.3 to compute :math:`{}_1a_0` + using :math:`{}_1m_0` + + (3) "Rake" the ``*neonatal`` ax's such that they aggregate to :math:`{}_1a_0`. + We start with the definition + + .. math:: + {}_1a_0 = \frac{a_2 \ d_2 + (a_3 + n_2) \ d_3 + (a_4 + n_2 + n_3) + \ d_4}{d_2 + d_3 + d_4} + + where subscripts 2, 3, and 4 denote enn, lnn, and pnn age groups. + + The above equation can be rewritten as + + .. math:: + {}_1a_0 - \frac{n_2 \ d_3 + (n_2 + n_3) \ d_4}{d_2 + d_3 + d_4} = + \frac{a_2 \ d_2 + a_3 \ d_3 + a_4 \ d_4}{d_2 + d_3 + d_4} + + Because we have :math:`{}_1a_0` and assume that the dx values are + approximately correct, the left-hand side is fixed at this point. + The entire right-hand side, where the :math:`{}_na_x` values are, + needs to be "raked" to satisfy the above equation. We simply multiply + all three :math:`{}_na_x` by the same ratio. + + Args: + mx (xr.DataArray): forecasted mortality rate. + ax (xr.DataArray): ax assuming constant mortality for under 5 age + groups. + nx (xr.DataArray): nx for FHS age groups, in years. + + Raises: + ValueError: If under 5 age group is not a subset of the `mx` data + + Returns: + xr.DataArray: ax, with under-5 age groups (2, 3, 4, 5) optimized. + """ + age_group_ids_under_5_years_old = age.get_most_detailed_age_group_ids_in_age_spans( + gbd_round_id=gbd_round_id, + start=0, + end=5, + include_birth_age_group=False, + ) + if not set(age_group_ids_under_5_years_old).issubset( + mx[DimensionConstants.AGE_GROUP_ID].values + ): + raise ValueError( + f"age group {age_group_ids_under_5_years_old} is not " + f"subset of {mx[DimensionConstants.AGE_GROUP_ID].values}" + ) + + # first compute the approximately correct dx and nLx values. + # they are "approximately" correct because they's not sensitive to ax. 
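# The steps below follow the docstring: dx and nLx provide the weights
# (step 1), Preston's Table 3.3 gives 1a0 and 4a1 from the aggregated
# under-1 and 1-4 mx (step 2), and the enn/lnn/pnn ax values are then
# scaled by a single raking constant (and the 1-4 ax by ax4_p / ax4) so
# that they aggregate to Preston's 1a0 (step 3).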
+ lx, dx = fm_population(mx, ax, nx, 1.0) + nLx = fm_person_years(mx, ax, nx) + + under_1_dx_sum = dx.sel(age_group_id=AgeConstants.AGE_GROUP_IDS_UNDER_ONE).sum( + DimensionConstants.AGE_GROUP_ID + ) + + ax4 = ax.sel(age_group_id=AgeConstants.AGE_1_TO_4_ID) # the 1-4yr ax + + mx_u5 = make_under_one_group_for_preston_with_nLx(mx, nLx) # 1m0 and 4m1 + # now compute the 1a0 and 4a1 from Preston's Table 3.3 + ax_u5 = preston_ax_fit(mx_u5) # 2 age groups: 0-1 (28) and 1-5 (5) + ax1_p = ax_u5.sel(age_group_id=AgeConstants.AGE_UNDER_ONE_ID) # Preston's 1a0 + ax4_p = ax_u5.sel(age_group_id=AgeConstants.AGE_1_TO_4_ID) # Preston's 4a1 + + left_hand_side = ( + ax1_p + - ( + nx.sel(age_group_id=2) * dx.sel(age_group_id=3) + + nx.sel(age_group_id=[2, 3]).sum(DimensionConstants.AGE_GROUP_ID) + * dx.sel(age_group_id=4) + ) + / under_1_dx_sum + ) + + right_hand_side_sum = 0 + for age_group_id in AgeConstants.AGE_GROUP_IDS_UNDER_ONE: + right_hand_side_sum += ax.sel(age_group_id=age_group_id) * dx.sel( + age_group_id=age_group_id + ) + + right_hand_side = right_hand_side_sum / under_1_dx_sum + + ax1_raking_const = left_hand_side / right_hand_side + ax1_raking_const[DimensionConstants.AGE_GROUP_ID] = AgeConstants.AGE_UNDER_ONE_ID + + ax4_ratio = ax4_p / ax4 + + # now modify our under-5 ax values + ax.loc[dict(age_group_id=AgeConstants.AGE_GROUP_IDS_UNDER_ONE)] = ( + ax.sel(age_group_id=AgeConstants.AGE_GROUP_IDS_UNDER_ONE) * ax1_raking_const + ) + ax.loc[dict(age_group_id=AgeConstants.AGE_1_TO_4_ID)] = ( + ax.sel(age_group_id=AgeConstants.AGE_1_TO_4_ID) * ax4_ratio + ) + + # safeguard against wild ax values, as a result of wild mx values + for age_group_id in AgeConstants.AGE_GROUP_IDS_UNDER_ONE + [AgeConstants.AGE_1_TO_4_ID]: + ax_age = ax.sel(age_group_id=age_group_id) + nx_age = nx.sel(age_group_id=age_group_id) + ax.loc[dict(age_group_id=age_group_id)] = ax_age.where(ax_age <= nx_age).fillna( + nx_age / 2 + ) + + return ax + + +def make_under_one_group_for_preston_with_nLx( + mx: xr.DataArray, + nLx: xr.DataArray, +) -> xr.DataArray: + r"""Create an under one age group (id 28). + + From the 0-6 days, 7-27 days, and 28-364 days (ids 2,3,4) groups. This age group is needed + for the Preston young age fit. If this is called with scalar, string dimensions on + mx, then slicing won't work and it will fail. + + Note that this method using nLx as the weight, instead of nx, based + on the definition + + .. math:: + {}_nm_x = \frac{{}_nd_x}{{}_nL_x} + + Args: + mx (xr.DataArray): Mortality data with age groups 2, 3, 4, and 5. + nLx (xr.DataArray): nLx that has age groups 2, 3, and 4 + + Returns: + xr.DataArray: mx for the under one age group (28) and + the 1-4 years age group (5). + + Raises: + RuntimeError: if age groups ids 2, 3, and 4 are NOT in input. 
+ ValueError: if `mx_under_on` contains unexpected age group ids + """ + if all(age_x in mx.age_group_id.values for age_x in AgeConstants.AGE_GROUP_IDS_UNDER_ONE): + nLx_sum = nLx.sel(age_group_id=AgeConstants.AGE_GROUP_IDS_UNDER_ONE).sum( + DimensionConstants.AGE_GROUP_ID + ) + mx_other_ages = mx.loc[dict(age_group_id=[AgeConstants.AGE_1_TO_4_ID])] + + mx_under_one_sum = 0 + for age_group_id in AgeConstants.AGE_GROUP_IDS_UNDER_ONE: + mx_under_one_sum += nLx.sel(age_group_id=age_group_id) * mx.sel( + age_group_id=age_group_id + ) + + mx_under_one = mx_under_one_sum / nLx_sum + mx_under_one.coords[DimensionConstants.AGE_GROUP_ID] = AgeConstants.AGE_UNDER_ONE_ID + mx_under_one = xr.concat( + [mx_under_one, mx_other_ages], + dim=DimensionConstants.AGE_GROUP_ID, + ) + + if not np.array_equal( + mx_under_one.age_group_id, + AgeConstants.GBD_AGE_UNDER_FIVE, + ): + raise ValueError( + f"mx_under_one contains unexpected " + f"age group ids {mx_under_one.age_group_id}" + ) + + return mx_under_one + elif all(age_x in mx.age_group_id.values for age_x in AgeConstants.GBD_AGE_UNDER_FIVE): + return mx.sel(age_group_id=AgeConstants.GBD_AGE_UNDER_FIVE) + else: + raise RuntimeError("No known young ages in input") + + +def preston_ax_fit(mx: xr.DataArray) -> xr.DataArray: + r"""This fit is from Preston, Heuveline, and Guillot, Table 3.3. + + It comes from a fit that Coale-Demeney made in their life tables for + :math:`({}_1a_0, {}_4a_1)` from `{}_1q_0`. PHG turned it into a fit from :math:`m_x` to + :math:`a_x`. This is explored in a notebook in docs to ``fbd_core.demog``. + + Args: + mx (xr.DataArray): Mortality rate. Must have age group IDs (28,5). + + Raises: + ValueError: if an incorrect age group id is in `mx` + + Returns: + xr.DataArray: Mean age, for only those age groups where predicted. + """ + if not all(ax in mx.age_group_id.values for ax in AgeConstants.GBD_AGE_UNDER_FIVE): + raise ValueError(f"Incorrect input with age ID {mx.age_group_id.values}") + + sexes = mx.sex_id + msub = mx.loc[dict(age_group_id=AgeConstants.GBD_AGE_UNDER_FIVE)] + msub_above = msub.where(msub >= LexmodelConstants.MCUT) + msub_below = msub.where(msub < LexmodelConstants.MCUT) + + ax_fit_above = msub_above * LexmodelConstants.PHG.sel( + domain="above", const="m", sex_id=sexes + ) + LexmodelConstants.PHG.sel(domain="above", const="c", sex_id=sexes) + ax_fit_below = msub_below * LexmodelConstants.PHG.sel( + domain="below", const="m", sex_id=sexes + ) + LexmodelConstants.PHG.sel(domain="below", const="c", sex_id=sexes) + + ax_fit = ax_fit_above.combine_first(ax_fit_below) + + return ax_fit + + +def gbd5_no_old_age_fit(mx: xr.DataArray) -> xr.Dataset: + """This is based on :func:`fbd_research.lex.model.gbd4_all_youth`. + + Except that age groups [31, 32, 235] are not replaced with fitted values. + + Args: + mx (xr.DataArray): Mortality rate + + Raises: + RuntimeError: if `mx` is missing some age groups + ValueError: if age group ids in `mx` are not consistent + + Returns: + xr.Dataset: Period life expectancy. + """ + hard_coded_gbd_round_id = 5 + nx_gbd = nx_contiguous_round(gbd_round_id=hard_coded_gbd_round_id) + + try: + mx = mx.sel(age_group_id=nx_gbd[DimensionConstants.AGE_GROUP_ID].values) + except KeyError: + raise RuntimeError( + f"Not all ages in incoming data. 
Have {mx.age_group_id.values} " + f"Want {nx_gbd.age_group_id.values}" + ) + + if not consistent_age_group_ids(mx.age_group_id.values): + raise ValueError(f"age group id {mx.age_group_id.values} are not consistent") + + nx_base = nx_from_age_group_ids(mx.age_group_id) + + # We want a baseline on nulls. If there is one null in a draw, we + # null out the whole draw, so this algorithm may increase null + # count, but it won't increase the bounding box on nulls. + mx_null_count = mx.where(mx.isnull(), drop=True).size + if mx_null_count > 0: + warnings.warn(f"Incoming mx has {mx_null_count} nulls") + + # first compute ax assuming constant mortality over interval + ax = cm_mean_age(mx, nx_base) + + # Graduation applies only where it's 5-year age groups. + middle_ages = nx_base.where(nx_base == nx_base.median()).dropna( + dim=DimensionConstants.AGE_GROUP_ID + ) + graduation_ages = middle_ages.age_group_id.values + # graduation ages are the age group ids that have 5-year neighbors: + # array([ 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 30, + # 31, 32]) + # Note that 6 and 32 will remain unchanged through graduation method, + # because they do not have 5-year neighbors on both sides + + # use graduation method from Preston to fine-tune ax using neighboring + # ax values. ax at age groups [2-6, 32, 235] will be set to cm_mean_age + ax = ax_graduation_cubic(mx, ax, nx_base, graduation_ages) + + # now fix the under-5 age groups to match GBD methodology + ax = under_5_ax_preston_rake(hard_coded_gbd_round_id, mx, ax, nx_base) + + ex = fm_period_life_expectancy(mx, ax, nx_gbd) + + ex_null_count = ex.where(ex.isnull(), drop=True).size + if ex_null_count > mx_null_count: + ex_null = ex.where(ex.isnull(), drop=True) + raise RuntimeError(f"Graduation created {ex_null.coords} nulls") + + ds = xr.Dataset(dict(ex=ex, mx=mx, ax=ax)) + + return append_nLx_lx(ds) + + +def gbd4_all_youth(mx: xr.DataArray) -> xr.Dataset: + """Period life expectancy that matches GBD 2016 except young ages aren't fit. + + So age groups ENN, PNN, and LNN remain. This is called + the "baseline" set of ages, not the "lifetable" set of ages. + Input ages include ENN, PNN, and LNN, and the terminal age group + is 95+. This includes the old age fit from US Counties code + and uses the graduation method. + + Args: + mx (xr.DataArray): Mortality rate + + Raises: + RuntimeError: `mx` is missing some age groups + ValueError: `mx` contains inconsistent age group ids, or `qxp` contains values outside + what is expected, or `qxp` contains values outside what is expected + + Returns: + xr.Dataset: Period life expectancy. + """ + nx_gbd = nx_contiguous_round(gbd_round_id=4) + # nx_gbd looks like array([1.917808e-02, 5.753425e-02, 9.232877e-01, + # 4.000000e+00, 5.000000e+00, ... 5.000000e+00, 5.000000e+00, 4.500000e+01] + try: + mx = mx.loc[dict(age_group_id=nx_gbd.age_group_id.values)] + except KeyError: + raise RuntimeError( + f"Not all ages in incoming data " + f"have {mx.age_group_id.values} " + f"want {nx_gbd.age_group_id.values}" + ) + if not consistent_age_group_ids(mx.age_group_id.values): + raise ValueError("`mx` contains inconsistent age group ids.") + nx_base = nx_from_age_group_ids(mx.age_group_id) + # >>> nx_base['age_group_id'] + # + # array([ 28, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + # 17, 18, 19, 20, 30, 31, 32, 235]) + # 28 is 1 year wide, 5 is 4 years wide + + # We want a baseline on nulls. 
If there is one null in a draw, we + # null out the whole draw, so this algorithm may increase null + # count, but it won't increase the bounding box on nulls. + mx_null_count = mx.where(mx.isnull(), drop=True).size + if mx_null_count > 0: + warnings.warn(f"Incoming mx has {mx_null_count} nulls") + expected_good = mx.size - mx_null_count + + # first compute ax assuming constant mortality over interval + ax = cm_mean_age(mx, nx_base) + # Graduation applies only where it's 5-year age groups. + middle_ages = nx_base.where(nx_base == nx_base.median()).dropna( + dim=DimensionConstants.AGE_GROUP_ID + ) + graduation_ages = middle_ages.age_group_id.values + # graduation ages are the age group ids that have 5-year neighbors: + # array([ 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 30, + # 31, 32]) + + # use graduation method from Preston to fine-tune ax using neighboring + # ax values. The boundary ax's will be set to cm_mean_age + ax = ax_graduation_cubic(mx, ax, nx_base, graduation_ages) + + # Mortality from mx is interval-by-interval, so we + # can leave young out of it. + # mortality {}_nq_x = \frac{m_x n_x}{1+ m_x(n_x - a_x)} + qxp = fm_mortality(mx, ax, nx_base) + if not (qxp < 1.0001).sum() >= expected_good: + raise ValueError("`qxp` contains values outside what is expected") + if not (qxp > -0.0001).sum() >= expected_good: + raise ValueError("`qxp` contains values outside what is expected") + qxf, axf, mxf = old_age_fit_us_counties(qxp) + # qxp['age_group_id'] == [2..20, 30..32, 235] + # qxf['age_group_id'] == [ 31, 32, 33, 44, 45, 148] + del qxp # frees up 1G for 100 draws + gc.collect() + + # This goes past the 95+ limit, so reduce it. + mxp, axp = condense_ages_to_terminal( + qxf, axf, AgeConstants.AGE_95_TO_100_ID, AgeConstants.AGE_95_PLUS_ID + ) + v = mxf.age_group_id.values # array([ 31, 32, 33, 44, 45, 148]) + term = np.argwhere(v == AgeConstants.AGE_95_TO_100_ID)[0][0] # 2 + mxf_drop = xr.concat( + [mxf.loc[dict(age_group_id=v[:term])], mxp], dim=DimensionConstants.AGE_GROUP_ID + ) # 31, 32, 235 + axf_drop = xr.concat( + [axf.loc[dict(age_group_id=v[:term])], axp], dim=DimensionConstants.AGE_GROUP_ID + ) # 32, 32, 235 + + del qxf, axf, mxf + gc.collect() + + mx_combined = combine_age_ranges(None, mx, mxf_drop) + ax_combined = combine_age_ranges(None, ax, axf_drop) + if not (mx_combined.age_group_id.values[-5:] == nx_gbd.age_group_id.values[-5:]).all(): + raise ValueError("`mx_combined` and `nx_gbd` age groups aren't aligned") + + del ax, mx # frees up 2G for 100 draws + gc.collect() + + ex = fm_period_life_expectancy(mx_combined, ax_combined, nx_gbd) + + ex_null_count = ex.where(ex.isnull(), drop=True).size + if ex_null_count > mx_null_count: + ex_null = ex.where(ex.isnull(), drop=True) + raise RuntimeError(f"graduation introduced null draws with bounds {ex_null.coords}") + + ds = xr.Dataset(dict(ex=ex, mx=mx_combined, ax=ax_combined)) + + del mx_combined, ax_combined # frees up 2G for 100 draws + gc.collect() + + return append_nLx_lx(ds) + + +def combine_age_ranges( + young: xr.DataArray, + middle: xr.DataArray, + old: xr.DataArray, +) -> xr.DataArray: + """Combine differenr age groups. + + Three Dataarrays have different age groups in them. + Combine them to make one dataarray. There may be overlap + among age_group IDs, so anything in the middle + is overwritten by the young or old. The dims and coords in the middle + are used to determine the dims and coords in + output. Point coordinates are kept. 
+ This works even when young, middle, and old have different age intervals, + for instance [28, 5] versus [2, 3, 4]. + + Args: + young (xr.DataArray): fit for young ages, or None if there is no fit + middle (xr.DataArray): middle age groups + old (xr.DataArray): fit for old ages + + Returns: + xr.DataArray: With dims in the same order as the middle. + """ + if young is not None: + with_young = age_id_older_than( + young.age_group_id[-1].values.tolist(), middle.age_group_id.values + ) + young_ages = young.age_group_id.values + else: + with_young = middle.age_group_id.values + young_ages = "none" + # The age IDs are just IDs, and younger ages may have larger IDs, so sort. + edge_ages = age_id_older_than(old.age_group_id[0].values.tolist(), with_young, True) + logger.debug( + f"combine_age_ranges young {young_ages} " + f"middle {middle.age_group_id.values} " + f"old {old.age_group_id.values} " + f"keep {edge_ages}" + ) + mid_cut = middle.loc[{DimensionConstants.AGE_GROUP_ID: edge_ages}] + if young is not None: + return xr.concat([young, mid_cut, old], dim=DimensionConstants.AGE_GROUP_ID) + else: + return xr.concat([mid_cut, old], dim=DimensionConstants.AGE_GROUP_ID) + + +def condense_ages_to_terminal( + qx: xr.DataArray, + ax: xr.DataArray, + terminal: int, + new_terminal: int, +) -> Tuple[xr.DataArray, xr.DataArray]: + r"""Given a life table that includes later ages, truncate it to a terminal age group. + + For the terminal age group, we know + :math:`{}_nm_x = 1/{}_na_x`, so use that as our guide. + + .. math:: + + {}_na_x = {}_na_{x_0} + {}_np_{x_0}(n_{x_0}+{}_na_{x_1}) + +{}_np_{x_0}\:{}_np_{x_1}(n_{x_0}+n_{x_1}+{}_na_{x_2}) + + Note we get :math:`q_x` and return :math:`m_x`. + + Args: + qx (xr.DataArray): Mortality with age group ids. + ax (xr.DataArray): Mean age of death. + terminal (int): Which age group will become the terminal one. + new_terminal (int): The age group to assign to the last interval. + + Raises: + ValueError: unexpected `v.shape` + + Returns: + Tuple[xr.DataArray, xr.DataArray]: + mx: :math:`{}_nm_x` for the terminal age group. + ax: :math:`{}_na_x` for the terminal age group. + """ + v = qx.age_group_id.values + term = np.argwhere(v == terminal)[0][0] + if not term + 1 < v.shape[0]: + raise ValueError("unexpected `v.shape`") + axp = ax.loc[dict(age_group_id=[terminal])] + # This will be survival from the terminal age group + # to the current age group in the loop. + npx = 1 - qx.loc[dict(age_group_id=[terminal])] + # This will be the total time from the terminal age group + # to the current age group in the loop. + nx = nx_from_age_group_ids(ax.age_group_id) + # This is a C-number, no index. + nx_running = float(nx.loc[dict(age_group_id=[terminal])]) + logger.debug(f"nx_running {type(nx_running)} {nx_running}") + + for avx in v[term + 1 :]: + # Have to set indices so that multiplication can happen. + axp.coords[DimensionConstants.AGE_GROUP_ID] = [avx] + npx.coords[DimensionConstants.AGE_GROUP_ID] = [avx] + # Here, the invariant is now true. + # npx is \prod_{i xr.DataArray: + r"""Computes diff of logit-:math:`q_x` iteratively. + + Starting from age 90-94 (id 32), to compute the :math:`q_x` for 95-99 (id 33), 100-104 + (id 44), 105-109 (id 45). + + Starting with :math:`{}_5q_{90}`, one first computes + + .. math:: + \Delta_{90} = c_{90} + \beta_{90} + \beta\_logit\_q_{90} * + \text{logit}(q_{90}) + + where :math:`c`, :math:`\beta`, and :math:`\beta\_logit\_q_x` are all + regression constants available in + :func:`fbd_core.demog.lifemodel.old_age_fit_qx_95_plus`. 
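A toy numerical sketch of the recursion spelled out next, with made-up regression constants (illustration only; the real values live in ``old_age_fit_qx_95_plus``):

    .. code:: python

        from scipy.special import expit, logit

        # Hypothetical constants for illustration only.
        const = {90: 0.10, 95: 0.12}
        beta = {90: 0.05, 95: 0.06}
        beta_logit_q90 = 0.02

        q = {90: 0.45}  # forecasted 5q90
        for age in (95, 100):
            prev = age - 5
            delta_prev = const[prev] + beta[prev] + beta_logit_q90 * logit(q[90])
            q[age] = expit(logit(q[prev]) + delta_prev)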
+ From then, for every sex, we have:: + + for i in [95, 100]: + + .. math:: + q_i = \text{expit}( \text{logit}(q_{i-5}) + \Delta_{i-5} ) + + .. math:: + \Delta_i = c_i + \beta_i + \beta\_logit\_q_{90} * \text{logit}(q_{90}) + + Args: + mx (xr.DataArray): mortality rate, with age_group_id dim. + ax (xr.DataArray): ax. + + Returns: + (xr.DataArray): lx where age group id 235 is replaced with + age group ids 33 (95-100), 44 (100-105), 45 (105-110), + and 148 (110+, where lx is set to 0). + """ + # here we compute qx from mx + nx_base = nx_from_age_group_ids(mx["age_group_id"]) + + # make a baseline qx assuming constant mortality + qx = fm_mortality(mx, ax, nx_base) + + # replace age group id 235 of qx with (33, 44) via demog team extrapolation + qx = old_age_fit_qx_95_plus(qx) # overwrite qx + # make lx from qx + lx = _qx_to_lx(qx) + # Another adjustment based on self-consistency requirement + qx = _self_consistent_qx_adjustment(qx, lx, mx) # overwrite qx + lx = _qx_to_lx(qx) # make lx again, based on adjusted qx + + return lx + + +def _qx_to_lx(qx: xr.DataArray) -> xr.DataArray: + r"""Computes :math:`l_x` based on :math:`q_x`. + + Where :math:`q_x` already contains the 95-100 (33) and 100-105 (44) age groups. Also + computes :math:`l_x` for 105-110 (45), and then set :math:`l_x` for 110+ to be 0. + + Args: + qx (xr.DataArray): Probability of dying. + + Raises: + ValueError: if `qx` is missing ages 33 and/or 34, or if `lx` has an unexpected number + of age group ids, or if final lx should have age group ids 33, 44, 45, and 148. + + Returns: + (xr.DataArray): lx. + """ + if tuple(qx["age_group_id"].values[-2:]) != ( + AgeConstants.AGE_95_TO_100_ID, + AgeConstants.AGE_100_TO_105_ID, + ): + raise ValueError("qx must have age group ids 33 and 44") + + px = 1.0 - qx # now we have survival all the way to 100-105 (44) age group + + # Because l{x+n} = lx * px, we can compute all lx's if we start with + # l_0 = 1 and iteratively apply the px's of higher age groups. + # So we compute l_105-110, since we have p_100-105 from extrapolated qx. + # We start with a set of lx's that are all 1.0 + lx = xr.full_like(px, 1) + # now expand lx to have age groups 105-110 (45) + lx = expand_dimensions(lx, fill_value=1, age_group_id=[AgeConstants.AGE_105_TO_110_ID]) + + # Since l{x+n} = lx * px, we make cumulative prduct of px down age groups + # and apply the product to ages[1:] (since ages[0]) has lx = 1.0 + ages = lx["age_group_id"] + + ppx = px.cumprod(dim="age_group_id") # the cumulative product of px + ppx.coords["age_group_id"] = ages[1:] # need to correspond to ages[1:] + lx.loc[dict(age_group_id=ages[1:])] *= ppx # lx all the way to 100-105 + + # now artificially sets lx to be 0 for the 110+ age group. + lx = expand_dimensions(lx, fill_value=0, age_group_id=[AgeConstants.AGE_110_PLUS_ID]) + + if not (lx.sel(age_group_id=2) == 1).all(): + raise ValueError("`lx` has an unexpected number of age group ids") + if not tuple(lx["age_group_id"].values[-4:]) == ( + AgeConstants.AGE_95_TO_100_ID, + AgeConstants.AGE_100_TO_105_ID, + AgeConstants.AGE_105_TO_110_ID, + AgeConstants.AGE_110_PLUS_ID, + ): + raise ValueError("final lx should have age group ids 33, 44, 45, and 148.") + + return lx + + +def _self_consistent_qx_adjustment( + qx: xr.DataArray, + lx: xr.DataArray, + mx: xr.DataArray, +) -> xr.DataArray: + r"""A universal relationship exists between the following formula. + + :math:`l_x`, :math:`q_x`, and :math:`m_x` at the terminal age group (95+): + + .. 
math:: + {}_{\infty}m_{95} = \frac{l_{95}}{T_{95}} + + where :math:`T_{95} = \int_{95}^{\infty} l_x dx`. + + Because we forecast :math:`l_{95}` and :math:`m_{95}`, and that + :math:`l_{x}` for :math:`x > 95` is extrapolated independently + (via :func:`fbd_research.lex.model.demography_team_extrapolation_of_lx`), + the above relationship does not hold. We therefore need to adjust our + values of :math:`_5l_{100}`, :math:`{}_5l_{105}` so that the relationship + holds. + + Note that we set :math:`l_{110} = 0`. + + We begin with the approximation of :math:`T_{95}`, + using Simpson's 3/8 rule: + + .. math:: + T_{95} = \int_{95}^{\infty} \ l_{x} dx \ + \approx \ \frac{3 \ n}{8}({}_5l_{95} + 3 {}_5l^{\prime}_{100} + + 3 {}_5l^{\prime}_{105} + + 3 {}_5l^{\prime}_{110}) \ + = \ T^{\prime}_{95} + + where :math:`n` is 5 years, the age group bin size, and :math:`{}^{\prime}` + denotes the current tentative value. + Also note that :math:`{}_5l_{110} = 0` in our case. + + The above formula allows us to define + + .. math:: + \alpha &= \frac{T_{95}}{T^{\prime}_{95}} \\ + &= \frac{\frac{l_{95}}{m_{95}}}{T^{\prime}_{95}} + :label: 1 + + as a "mismatch factor". + + We also declare that the ratio between :math:`{}_5q_{95}` + and :math:`{}_5q_{100}` is fixed: + + .. math:: + \beta = \frac{{}_5q_{95}}{{}_5q_{100}} + = \frac{{}_5q_{95}^{\prime}}{{}_5q_{100}^{\prime}} + < 1 + :label: 2 + + Hence we may proceed with the following derivation: + + .. math:: + {}_5l_{100} &= {}_5l_{95} \ (1 - {}_5q_{95}) \\ + &= {}_5l_{95} \ (1 - \beta \ {}_5q_{100}) + :label: 3 + + .. math:: + {}_5l_{105} &= {}_5l_{100} \ (1 - {}_5q_{100}) \\ + &= {}_5l_{95} \ (1 - {}_5q_{95}) \ (1 - {}_5q_{100}) \\ + &= {}_5l_{95} \ (1 - \beta \ {}_5q_{100}) \ + (1 - {}_5q_{100}) + :label: 4 + + .. math:: + \alpha &\approx \frac{\frac{15}{8} \ ({}_5l_{95} + 3 \ {}_5l_{100} + + 3 \ {}_5l_{105})}{\frac{15}{8} ( {}_5l_{95} + + 3 \ {}_5l^{\prime}_{100} + 3 \ {}_5l^{\prime}_{105})}\\ + &= \frac{{}_5l_{95} + 3 \ {}_5l_{95} \ (1 - {}_5q_{95}) + + 3 \ {}_5l_{95} \ (1 - {}_5q_{95})(1 - {}_5q_{100})} + {{}_5l_{95} + 3 \ {}_5l_{95}(1 - {}_5q^{\prime}_{95}) + + 3 \ {}_5l_{95}(1 - {}_5q^{\prime}_{95}) + (1 - {}_5q^{\prime}_{100})} \\ + &= \frac{{}_5l_{95} + 3 \ {}_5l_{95} \ (1 - \beta \ {}_5q_{100}) + + 3 \ {}_5l_{95} \ (1 - \beta \ {}_5q_{100}) + (1 - {}_5q_{100})} + {{}_5l_{95} + + 3 \ {}_5l_{95} \ (1 - \beta \ {}_5q^{\prime}_{100}) + + 3 \ {}_5l_{95} \ (1 - \beta \ {}_5q^{\prime}_{100}) + (1 - {}_5q^{\prime}_{100})} \\ + &= \frac{1 + 3 \ (1 - \beta \ {}_5q_{100}) + + 3 \ (1 - \beta \ {}_5q_{100})(1 - {}_5q_{100})} + {1 + 3 \ (1 - \beta \ {}_5q^{\prime}_{100}) + + 3 \ (1 - \beta {}_5q^{\prime}_{100}) + (1 - {}_5q^{\prime}_{100})} \\ + &= \frac{4 - 3 \ \beta \ {}_5q_{100} + 3 - 3 \ {}_5q_{100} - + 3 \ \beta \ {}_5q_{100} + 3 \ \beta \ {{}_5q_{100}}^2} + {4 - 3 \ \beta {}_5q^{\prime}_{100} + 3 - + 3 \ {}_5q^{\prime}_{100} - + 3 \ \beta \ {}_5q^{\prime}_{100} + + 3 \ \beta \ {{}_5q^{\prime}_{100}}^2} \\ + &= \frac{\frac{7}{3} - (2 \ \beta + 1) \ {}_5q_{100} + + \beta \ {{}_5q_{100}}^2}{\frac{7}{3} - + (2 \ \beta + 1) \ {}_5q^{\prime}_{100} + + \beta \ {{}_5q^{\prime}_{100}}^2} + :label: 5 + + where the denominator is known. + If we define + + .. math:: + \gamma = \frac{7}{3} - \alpha \ (\frac{7}{3} - + (2 \beta + 1) \ {}_5q^{\prime}_{100} + + \beta \ {{}_5q^{\prime}_{100}}^2) + :label: 6 + + then we have the quadratic equation + + .. math:: + \beta \ {{}_5q_{100}}^2 - (2 \beta + 1) \ {}_5q_{100} + \gamma = 0 + :label: 7 + + with the solution + + .. 
math:: + {}_5q_{100} = \frac{(2 \ \beta + 1) \pm \sqrt{(2 \ \beta + 1)^2 - + 4 \beta \ \gamma}}{2 \ \beta} + :label: 8 + + Because :math:`{}_5q_{100} \leq 1`, subtraction in the numerator of :eq:`8` + is the only viable solution. + + Args: + qx (xr.DataArray): qx that has age groups (..., 33, 44), + with 33 (95-100) and 44 (100-105). Should not have 235 (95+). + lx (xr.DataArray): lx that has age groups (..., 33, 44, 45, 148), + with 45 (105-110) and 148 (110+), + where (lx.sel(age_group_id=148) == 0).all(). + mx (xr.DataArray): mx, needed to compute :math: `T_{95}`. + Has age_group_id=235 (95+) instead of (33, 44, 45, 148). + + Returns: + (xr.DataArray): qx where age groups 33 (95-100) and 44 (100-105) are + "adjusted". + """ + n = 5.0 # 5 year age group width + # 33 = 95-100 yrs, 44 = 100-105 yrs, 45 = 105-110 yrs, 148 = 110+ yrs. + T_95_prime = lx.sel(age_group_id=AgeConstants.AGE_95_TO_100_ID) * (3.0 / 8.0 * n) + lx.sel( + age_group_id=AgeConstants.AGE_100_TO_110 + ).sum(DimensionConstants.AGE_GROUP_ID) * (9.0 / 8.0 * n) + alpha = ( + lx.sel(age_group_id=AgeConstants.AGE_95_TO_100_ID) + / mx.sel(age_group_id=AgeConstants.AGE_95_PLUS_ID) + ) / T_95_prime + + # these are the original, unadulterated q95 & q100 + q95_prime = qx.sel(age_group_id=AgeConstants.AGE_95_TO_100_ID) + q100_prime = qx.sel(age_group_id=AgeConstants.AGE_100_TO_105_ID) + beta = q95_prime / q100_prime + + gamma = 7.0 / 3.0 - alpha * ( + 7.0 / 3.0 - (2 * beta + 1) * q100_prime + beta * (q100_prime**2) + ) + + q100 = ((2 * beta + 1) - np.sqrt((2 * beta + 1) ** 2 - 4 * beta * gamma)) / (2 * beta) + + # unfortunately, ~20% of q100 adjusted via this approx will be > 1. + # it's even possible to end up with q100 < 0. + # it makes no sense to cap q100 to 1, because that means l105 == 0, + # which we do not want. The only option left is the following + q100 = q100.where((q100 < 1) & (q100 > 0)).fillna(q100_prime) + q95 = q100 * beta # always <= 1 because beta is always <= 1 + + # Update original qx with adjusted q95 and q100 values + qx.loc[dict(age_group_id=AgeConstants.AGE_95_TO_100_ID)] = q95 + qx.loc[dict(age_group_id=AgeConstants.AGE_100_TO_105_ID)] = q100 + + return qx \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/met_need/README.md b/gbd_2021/disease_burden_forecast_code/met_need/README.md new file mode 100644 index 0000000..662dc37 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/met_need/README.md @@ -0,0 +1,45 @@ +Pipeline code for met need for modern contraceptive use +To model met need, we use an ARC-only run of our GenEM (Generalized Ensemble Model) pipeline. +This consists of iterated parallelized calls to the arc_all_omegas function in arc_main.py, +in this case specifically with entity="met_need" and stage="met_need" + +``` +arc_main.py +Forecasts an entity using the Annualized Rate-of-Change (ARC) method. 
+``` + +``` +arc_method.py +ARC method module with functions for making forecast scenarios +``` + +``` +collect_submodels.py +Script to collect and collapse components into genem for future stage +``` + +``` +constants.py +FHS generalized ensemble model pipeline for forecasting - local constants +``` + +``` +get_model_weights_from_holdouts.py +Collects submodel predictive validity statistics to compile sampling weights for genem +``` + +``` +model_restrictions.py +Captures restrictions in which models get run for each entity/location +``` + +``` +omega_selection_strategy.py +Strategies for determining the weight for the ARC method +``` + +``` +predictive_validity.py +Utility functions for determining predictive validity +``` + diff --git a/gbd_2021/disease_burden_forecast_code/met_need/arc_main.py b/gbd_2021/disease_burden_forecast_code/met_need/arc_main.py new file mode 100644 index 0000000..e7d35e6 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/met_need/arc_main.py @@ -0,0 +1,370 @@ +"""This script forecasts an entity using the ARC method. + +This script runs predictive validity for an entity to determine the weight used +in the arc quantile method. For each entity, this script forecasts the out +of sample window using different weights, then determines the RMSE and bias for +each weight. The outputs of this script are netCDFs containing the results of +the predictive metrics for each entity. + +After the predictive validity step an omega will be selected to forecast each +entity. + +Notes regarding the truncate/capping optional flags: + +1.) truncate-quantiles are used for winsorizing only the past logit + age-standardized data (before computing the annualized rate of change). + In the case of 0.025/0.975 quantiles (calculated across locations), + the data above the 97.5th percentile are set to the 97.5th percentile, while + the data below the 2.5th percentile are set to the 2.5th percentile. + +2.) cap-quantiles are used for winsorizing only the future entities + (after generating forecasted entities). In the case of 0.01/0.99 quantiles + (calculated based on the past entities), the forecasts above the 99th + percentile are set to the 99th percentile, while the forecasts below the + 1st percentile are set to the 1st percentile.
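As an illustration of this kind of winsorization (a generic sketch on synthetic data, not the pipeline's exact implementation):

```python
import numpy as np
import xarray as xr

data = xr.DataArray(
    np.random.rand(20, 10),
    dims=["location_id", "year_id"],
    coords={"location_id": range(20), "year_id": range(1990, 2000)},
)

# Winsorize across locations at the 2.5th / 97.5th percentiles.
quantiles = data.quantile([0.025, 0.975], dim="location_id")
lower = quantiles.sel(quantile=0.025, drop=True)
upper = quantiles.sel(quantile=0.975, drop=True)
winsorized = data.clip(min=lower, max=upper)
```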
+""" + +import gc +from typing import List, Optional, Tuple + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib import filter +from fhs_lib_database_interface.lib.constants import SexConstants +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_model.lib.arc_method import arc_method +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging + +from fhs_lib_genem.lib import predictive_validity as pv +from fhs_lib_genem.lib.constants import ( + FileSystemConstants, + ModelConstants, + SEVConstants, + TransformConstants, +) + +logger = fhs_logging.get_logger() + + +def determine_entity_name_path(entity: str, stage: str) -> Tuple[str, str]: + """Take the entity name and determine name and file path.""" + if stage == "sev" and "-" in entity: # is an iSEV, specified as cause-risk + acause, rei = entity.split("-") + sub_folder = "risk_acause_specific" + file_name = "_".join([acause, rei, SEVConstants.INTRINSIC_SEV_FILENAME_SUFFIX]) + else: + sub_folder = "" + file_name = f"{entity}" + + return sub_folder, file_name + + +def _clip_past(past_mean: xr.DataArray, transform: str) -> xr.DataArray: + if transform == "logit": + # it makes sense to ceiling logit-transformable data (since its 0-1) + clipped_past = past_mean.clip(min=ModelConstants.FLOOR, max=1 - ModelConstants.FLOOR) + elif transform == "log": + # log transformable data should only be floored + clipped_past = past_mean.clip(min=ModelConstants.FLOOR) + else: + # data we won't transform shouldn't be clipped. + clipped_past = past_mean + return clipped_past + + +def _find_limits( + past_age_std_mean: xr.DataArray, + past_last_year: xr.DataArray, + upper_quantile: float, + lower_quantile: float, +) -> xr.DataArray: + """Find upper/lower limits to cap the forecasts.""" + past_age_std_quantiles = past_age_std_mean.quantile( + [lower_quantile, upper_quantile], dim=["location_id", "year_id"] + ) + upper = past_age_std_quantiles.sel(quantile=upper_quantile, drop=True) + lower = past_age_std_quantiles.sel(quantile=lower_quantile, drop=True) + + past_last_year_gt_upper = past_last_year.where(past_last_year > upper) + past_last_year_lt_lower = past_last_year.where(past_last_year < lower) + + upper_cap_lims = past_last_year_gt_upper.fillna(upper).rename("upper") + lower_cap_lims = past_last_year_lt_lower.fillna(lower).rename("lower") + + cap_lims = xr.merge([upper_cap_lims, lower_cap_lims]) + return cap_lims + + +def _reshape_bound(data: xr.DataArray, bound: xr.DataArray) -> xr.DataArray: + """Broadcast and align the dims of `bound` so that they match `data`.""" + expanded_bound, _ = xr.broadcast(bound, data) + return expanded_bound.transpose(*data.coords.dims) + + +def _cap_forecasts( + years: YearRange, + cap_quantiles: Tuple[float, float], + most_detailed_past: xr.DataArray, + past_mean: xr.DataArray, + forecast: xr.DataArray, +) -> xr.DataArray: + """Cap upper and lower bound on forecasted data, using quantiles from past data.""" + last_year = most_detailed_past.sel(year_id=years.past_end, drop=True) + lower_quantile, upper_quantile = cap_quantiles + caps = _find_limits( + past_mean, last_year, upper_quantile=upper_quantile, lower_quantile=lower_quantile + ) + returned_past = forecast.sel(year_id=years.past_years) + forecast = forecast.sel(year_id=years.forecast_years) + + lower_bound 
= _reshape_bound(forecast, caps.lower) + upper_bound = _reshape_bound(forecast, caps.upper) + + mean_clipped = forecast.clip(min=lower_bound, max=upper_bound).fillna(0) + + del forecast + gc.collect() + + capped_forecast = xr.concat([returned_past, mean_clipped], dim="year_id") + + return capped_forecast + + +def _forecast_entity( + omega: float, + past: xr.DataArray, + transform: str, + truncate: bool, + truncate_quantiles: Tuple[float, float], + replace_with_mean: bool, + reference_scenario: str, + years: YearRange, + gbd_round_id: int, + cap_forecasts: bool, + cap_quantiles: Tuple[float, float], + national_only: bool, + age_standardize: bool, + rescale_ages: bool, + remove_zero_slices: bool, +) -> xr.DataArray: + """Prepare data for forecasting, run model and post-process results.""" + most_detailed_past = filter.make_most_detailed_location( + data=past, gbd_round_id=gbd_round_id, national_only=national_only + ) + if "sex_id" not in most_detailed_past.dims or list(most_detailed_past.sex_id.values) != [ + SexConstants.BOTH_SEX_ID + ]: + most_detailed_past = filter.make_most_detailed_sex(data=most_detailed_past) + if age_standardize: + most_detailed_past = filter.make_most_detailed_age( + data=most_detailed_past, gbd_round_id=gbd_round_id + ) + + if "draw" in most_detailed_past.dims: + past_mean = most_detailed_past.mean("draw") + else: + past_mean = most_detailed_past + + clipped_past = _clip_past(past_mean=past_mean, transform=transform) + + processor = TransformConstants.TRANSFORMS[transform]( + years=years, + gbd_round_id=gbd_round_id, + age_standardize=age_standardize, + remove_zero_slices=remove_zero_slices, + rescale_age_weights=rescale_ages, + ) + + transformed_past = processor.pre_process(clipped_past) + + del clipped_past + gc.collect() + + transformed_forecast = arc_method.arc_method( + past_data_da=transformed_past, + gbd_round_id=gbd_round_id, + years=years, + diff_over_mean=ModelConstants.DIFF_OVER_MEAN, + truncate=truncate, + reference_scenario=reference_scenario, + weight_exp=omega, + replace_with_mean=replace_with_mean, + truncate_quantiles=truncate_quantiles, + scenario_roc="national", + ) + + forecast = processor.post_process(transformed_forecast, past_mean) + + if np.isnan(forecast).any(): + raise ValueError("NaNs in forecasts") + + if cap_forecasts: + forecast = _cap_forecasts( + years, cap_quantiles, most_detailed_past, past_mean, forecast + ) + + return forecast + + +def arc_all_omegas( + entity: str, + stage: str, + intrinsic: bool, + subfolder: str, + versions: Versions, + model_name: str, + omega_min: float, + omega_max: float, + omega_step_size: float, + transform: str, + truncate: bool, + truncate_quantiles: Optional[Tuple[float, float]], + replace_with_mean: bool, + reference_scenario: str, + years: YearRange, + gbd_round_id: int, + cap_forecasts: bool, + cap_quantiles: Optional[Tuple[float, float]], + national_only: bool, + age_standardize: bool, + rescale_ages: bool, + predictive_validity: bool, + remove_zero_slices: bool, +) -> None: + """Forecast an entity with different omega values. + + If a SEV, the rei input could be a risk, or it could be a cause-risk. + If it's a cause-risk (connected via hyphen), it's meant to be an + intrinsic SEV, which would come from + in_version/risk_acause_specific/{cause}_{risk}_intrinsic.nc, + and the forecasted result would go to + out_version/risk_acause_specific/{cause}_{risk}_intrinsic.nc. + + Args: + entity (str): Entity to forecast + stage (str): Stage of the run. E.x. sev, death, etc. 
+ intrinsic (bool): Whether this entity obtains the _intrinsic suffix + subfolder (str): Optional subfolder for reading and writing files. + versions (Versions): versions object with both past and future (input and output). + model_name (str): Name to save the model under. + omega_min (float): The minimum omega to try + omega_max (float): The maximum omega to try + omega_step_size (float): The step size of omegas to try between 0 and omega_max + transform (str): Space to forecast data in + truncate (bool): If True, then truncates the dataarray over the given dimensions + truncate_quantiles (Tuple[float, float]): The tuple of two floats representing the + quantiles to take + replace_with_mean (bool): If True and `truncate` is True, then replace values outside + of the upper and lower quantiles taken across `location_id` and `year_id` with + the mean across `year_id`, if False, then replace with the upper and lower bounds + themselves + reference_scenario (str): If 'median' then the reference scenario is made using the + weighted median of past annualized rate-of-change across all past years, if 'mean' + then it is made using the weighted mean of past annualized rate-of-change across + all past years + years (YearRange): forecasting year range + gbd_round_id (int): the gbd round id + cap_forecasts (bool): If used, forecasts will be capped. To forecast without caps, + don't use this + cap_quantiles (tuple[float]): Quantiles for capping the future + national_only (bool): Whether to run national-only data or not + rescale_ages (bool): whether to rescale during ARC age standardization. We are + currently only setting this to true for the sevs pipeline. + age_standardize (bool): whether to age_standardize before modeling. + predictive_validity (bool): whether to do predictive validity or real forecasts + remove_zero_slices (bool): If True, remove zero-slices along certain dimensions, when + pre-processing inputs, and add them back into outputs.
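To make the omega sweep concrete, here is a toy sketch of the idea (invented numbers; the real grid comes from `predictive_validity.get_omega_weights` and the selection logic lives in `omega_selection_strategy.py`):

```python
import numpy as np

# Hypothetical omega grid, e.g. omega_min=0, omega_max=3, omega_step_size=0.25.
omegas = np.arange(0.0, 3.0 + 0.25, 0.25)

# Pretend holdout RMSEs, one per omega, as the predictive-validity step would produce.
rmses = dict(zip(omegas, np.random.rand(len(omegas))))

# One simple selection rule: keep the omega with the lowest holdout RMSE.
best_omega = min(rmses, key=rmses.get)
```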
+ """ + logger.debug(f"Running `forecast_one_risk_main` for {entity}") + + input_version_metadata = versions.get(past_or_future="past", stage=stage) + + file_name = entity + if intrinsic: # intrinsic entities have _intrinsic attached at file name + file_name = entity + "_intrinsic" + + data = open_xr_scenario( + file_spec=FHSFileSpec( + version_metadata=input_version_metadata, + sub_path=(subfolder,), + filename=f"{file_name}.nc", + ) + ) + + # rid the past data of point coords because they throw off weighted-quantile + superfluous_coords = [d for d in data.coords.keys() if d not in data.dims] + data = data.drop_vars(superfluous_coords) + + past = data.sel(year_id=years.past_years) + + if predictive_validity: + holdouts = data.sel(year_id=years.forecast_years) + all_omega_pv_results: List[xr.DataArray] = [] + + for omega in pv.get_omega_weights(omega_min, omega_max, omega_step_size): + logger.debug("omega:{}".format(omega)) + + forecast = _forecast_entity( + omega=omega, + past=past, + transform=transform, + truncate=truncate, + truncate_quantiles=truncate_quantiles, + replace_with_mean=replace_with_mean, + reference_scenario=reference_scenario, + years=years, + gbd_round_id=gbd_round_id, + cap_forecasts=cap_forecasts, + cap_quantiles=cap_quantiles, + national_only=national_only, + age_standardize=age_standardize, + rescale_ages=rescale_ages, + remove_zero_slices=remove_zero_slices, + ) + + if predictive_validity: + all_omega_pv_results.append( + pv.calculate_predictive_validity( + forecast=forecast, holdouts=holdouts, omega=omega + ) + ) + + else: + output_version_metadata = versions.get(past_or_future="future", stage=stage) + + output_file_spec = FHSFileSpec( + version_metadata=output_version_metadata, + sub_path=(FileSystemConstants.SUBMODEL_FOLDER, model_name, subfolder), + filename=f"{file_name}_{omega}.nc", + ) + + save_xr_scenario( + xr_obj=forecast, + file_spec=output_file_spec, + metric="rate", + space="identity", + omega=omega, + transform=transform, + truncate=str(truncate), + truncate_quantiles=str(truncate_quantiles), + replace_with_mean=str(replace_with_mean), + reference_scenario=str(reference_scenario), + cap_forecasts=str(cap_forecasts), + cap_quantiles=str(cap_quantiles), + ) + + if predictive_validity: + pv_df = pv.finalize_pv_data(pv_list=all_omega_pv_results, entity=entity) + + pv.save_predictive_validity( + file_name=file_name, + gbd_round_id=gbd_round_id, + model_name=model_name, + pv_df=pv_df, + stage=stage, + subfolder=subfolder, + versions=versions, + ) diff --git a/gbd_2021/disease_burden_forecast_code/met_need/arc_method.py b/gbd_2021/disease_burden_forecast_code/met_need/arc_method.py new file mode 100644 index 0000000..0e3d8a4 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/met_need/arc_method.py @@ -0,0 +1,1082 @@ +"""Module with functions for making forecast scenarios.""" + +from typing import Any, Callable, Iterable, List, Optional, Type, Union + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_data_transformation.lib.statistic import ( + Quantiles, + weighted_mean_with_extra_dim, + weighted_quantile_with_extra_dim, +) +from fhs_lib_data_transformation.lib.truncate import truncate_dataarray +from fhs_lib_data_transformation.lib.validate import assert_coords_same +from fhs_lib_database_interface.lib.constants import DimensionConstants, ScenarioConstants +from 
fhs_lib_database_interface.lib.query import location +from fhs_lib_file_interface.lib import xarray_wrapper +from fhs_lib_file_interface.lib.version_metadata import FHSDirSpec +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_lib_model.lib import predictive_validity_metrics as pv_metrics +from fhs_lib_model.lib.constants import ArcMethodConstants +from fhs_lib_model.lib.model_protocol import ModelProtocol + +logger = get_logger() + + +class StatisticSpec: + """A type representing a choice of statistical summary, with its attendant data. + + Can compute a weighted or an unweighted form. See MeanStatistic and QuantileStatistic. + """ + + def weighted_statistic( + self, data: xr.DataArray, stat_dims: List[str], weights: xr.DataArray, extra_dim: str + ) -> xr.DataArray: + """Take a weighted summary statistic on annual_diff.""" + pass + + def unweighted_statistic(self, data: xr.DataArray, stat_dims: List[str]) -> xr.DataArray: + """Take an unweighted statistic on annual_diff.""" + pass + + +class MeanStatistic(StatisticSpec): + """A statistic that takes the mean of some things. Takes no args.""" + + def weighted_statistic( + self, data: xr.DataArray, stat_dims: List[str], weights: xr.DataArray, extra_dim: str + ) -> xr.DataArray: + """Take a weighted mean on `data`.""" + return weighted_mean_with_extra_dim(data, stat_dims, weights, extra_dim) + + def unweighted_statistic(self, data: xr.DataArray, stat_dims: List[str]) -> xr.DataArray: + """Take an unweighted mean on `data`, over the dimensions stat_dims.""" + return data.mean(stat_dims) + + +class QuantileStatistic(StatisticSpec): + """A statistic that takes some quantiles from the data.""" + + def __init__(self, quantiles: Union[float, Iterable[float]]) -> None: + """Args are the quantile fractions. + + E.g. QuantileStatistic([0.1, 0.9]) represents + the desire to take the 10th percentile and 90th percentile. You may also + pass a single number, as in QuantileStatistic(0.5) for a single quantile, + the median in that case. + """ + if not (isinstance(quantiles, float) or is_iterable_of(float, quantiles)): + raise ValueError("Arg to QuantileStatistic must be either a float or a list of floats") + + self.quantiles = quantiles + + def weighted_statistic( + self, data: xr.DataArray, stat_dims: List[str], weights: xr.DataArray, extra_dim: str + ) -> xr.DataArray: + """Take a weighted set of quantiles on `data`.""" + return weighted_quantile_with_extra_dim( + data, self.quantiles, stat_dims, weights, extra_dim + ) + + def unweighted_statistic(self, data: xr.DataArray, stat_dims: List[str]) -> xr.DataArray: + """Take an unweighted set of quantiles on `data`.""" + return data.quantile(q=self.quantiles, dim=stat_dims) + + +class ArcMethod(ModelProtocol): + """Instances of this class represent an arc_method model. + + Can be fit and used for predicting future estimates.
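As a quick, synthetic illustration of the statistic helpers defined above (only the unweighted paths, whose implementations are fully shown here, are exercised):

```python
import numpy as np
import xarray as xr

data = xr.DataArray(np.random.rand(5, 8), dims=["location_id", "year_id"])

mean_over_years = MeanStatistic().unweighted_statistic(data, ["year_id"])
scenario_quantiles = QuantileStatistic([0.15, 0.85]).unweighted_statistic(
    data, ["location_id", "year_id"]
)
# mean_over_years keeps location_id; scenario_quantiles gains a "quantile" dim of length 2.
```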
+ """ + + # Defined ARC method parameters: + number_of_holdout_years = 10 + omega_step_size = 0.25 + max_omega = 3 + pv_metric = pv_metrics.root_mean_square_error + + def __init__( + self, + past_data: xr.DataArray, + years: YearRange, + draws: int, + gbd_round_id: int, + reference_scenario_statistic: str = "mean", + reverse_scenarios: bool = False, + quantiles: Iterable[float] = ArcMethodConstants.DEFAULT_SCENARIO_QUANTILES, + mean_level_arc: bool = True, + reference_arc_dims: Optional[List[str]] = None, + scenario_arc_dims: Optional[List[str]] = None, + truncate: bool = True, + truncate_dims: Optional[List[str]] = None, + truncate_quantiles: Iterable[float] = ArcMethodConstants.DEFAULT_TRUNCATE_QUANTILES, + replace_with_mean: bool = False, + scenario_roc: str = "all", + pv_results: xr.DataArray = None, + select_omega: bool = True, + omega_selection_strategy: Optional[Callable] = None, + omega: Optional[Union[float, xr.DataArray]] = None, + pv_pre_process_func: Optional[Callable] = None, + single_scenario_mode: bool = False, + **kwargs: Any, + ) -> None: + """Creates a new ``ArcMethod`` model instance. + + Pre-conditions: + =============== + * All given ``xr.DataArray``s must have dimensions with at least 2 + coordinates. This applies for covariates and the dependent variable. + + Args: + past_data (xr.DataArray): Past data for dependent variable being forecasted + years (YearRange): forecasting timeseries + draws (int): Number of draws to generate + gbd_round_id (int): The ID of the GBD round + reference_scenario_statistic (str): The statistic used to make the reference + scenario. If "median" then the reference scenarios is made using the weighted + median of past annualized rate-of-change across all past years, "mean" then it + is made using the weighted mean of past annualized rate-of-change across all + past years. Defaults to "mean". + reverse_scenarios (bool): If ``True``, reverse the usual assumption that high=bad + and low=good. For example, we set to ``True`` for vaccine coverage, because + higher coverage is better. Defaults to ``False``. + quantiles (Iterable[float]): The quantiles to use for better and worse + scenarios. Defaults to ``0.15`` and ``0.85``. + mean_level_arc (bool): If ``True``, then take annual differences for + means-of-draws, instead of draws. Defaults to ``True``. + reference_arc_dims (Optional[List[str]]): To calculate the reference ARC, take + weighted mean or median over these dimensions. Defaults to ["year_id"] when + ``None``. + scenario_arc_dims (Optional[List[str]]): To calculate the scenario ARCs, take + weighted quantiles over these dimensions. Defaults to ["location_id", + "year_id"] when ``None``. + truncate (bool): If ``True``, then truncate (clip) the past data over the given + dimensions. Defaults to ``False``. + truncate_dims (Optional[List[str]]): A list of strings representing the dimensions + to truncate over. If ``None``, truncation occurs over location and year. + truncate_quantiles (Iterable[float]): The two floats representing the quantiles to + take. Defaults to ``0.025`` and ``0.975``. + replace_with_mean (bool): If ``True`` and `truncate` is ``True``, then replace + values outside of the upper and lower quantiles taken across location and year + with the mean across "year_id", if False, then replace with the upper and lower + bounds themselves. Defaults to ``False``. + scenario_roc (str): If "all", then the scenario rate of change is taken over all + locations. If "national_only", roc is taken over national + locations only. 
Defaults to "all". + pv_results (xr.DataArray): An array of RMSEs resulting from predictive validity + tests. The array has one dimension (weight), and the values are the RMSEs from + each tested weight. When ``pv_results`` is ``None``, the ``fit`` method will + calculate new ``pv_results``. + select_omega (bool): If ``True``, the ``fit`` method will select an omega or create + an omega distribution from ``self.pv_results`` + omega_selection_strategy (Optional[Callable]): Which strategy to use to produce the + omega(s) from the omega-RMSE array, which gets produced in the fit step. + Defaults to ``None``, but must be specified unless you are passing the model an + omega directly. Can be specified as follows: + ``model.oss.name_of_omega_selection_function``. See omega_selection_strategy.py + for all omega selection functions. + omega (Optional[Union[float, xr.DataArray]]): Power to raise the increasing year + weights Must be non-negative. It can be dataarray, but must have only one + dimension, ``draw``. It must have the same coordinates on that dimension as + ``past_data_da``. When omega is ``None``, the fit method will calculate it from + ``self.pv_results`` if select_omega is ``True``. + pv_pre_process_func (Optional[Callable]): Function to call if preprocessing pv + results. + single_scenario_mode (bool): if true, only produces one scenario, not better and + worse. + kwargs (Any): Unused additional keyword arguments + """ + if select_omega and omega_selection_strategy is None: + err_msg = ( + "Must provide an omega_selection_strategy function if select_omega is True." + ) + logger.error(err_msg) + raise ValueError(err_msg) + + self.past_data = past_data + self.years = years + self.draws = draws + self.gbd_round_id = gbd_round_id + self.pv_results = pv_results + self.select_omega = select_omega + self.omega = omega + self.pv_pre_process_func = pv_pre_process_func + self.omega_selection_strategy = omega_selection_strategy + self.reference_scenario_statistic = reference_scenario_statistic + self.reverse_scenarios = reverse_scenarios + self.quantiles = quantiles + self.mean_level_arc = mean_level_arc + self.reference_arc_dims = reference_arc_dims + self.scenario_arc_dims = scenario_arc_dims + self.truncate = truncate + self.truncate_dims = truncate_dims + self.truncate_quantiles = truncate_quantiles + self.replace_with_mean = replace_with_mean + self.scenario_roc = scenario_roc + self.single_scenario_mode = single_scenario_mode + + def fit(self) -> Union[float, xr.DataArray]: + """Runs a predictive validity process to determine omega to use for forecasting. + + If ``self.select_omega`` is ``False``, this will only calculate ``self.pv_results`` + PV results are only calculated when ``self.pv_results`` is ``None``. + + Returns: + float | xr.DataArray: Power to raise the increasing year weights -- must be + nonnegative. It can be dataarray, but must have only one dimension, + DimensionConstants.DRAW. It must have the same coordinates on that dimension + as ``past_data_da``. 
+ """ + holdout_start = self.years.past_end - self.number_of_holdout_years + pv_years = YearRange(self.years.past_start, holdout_start, self.years.past_end) + + holdouts = self.past_data.sel(year_id=pv_years.forecast_years) + omegas_to_test = np.arange( + 0, ArcMethod.max_omega + ArcMethod.omega_step_size, ArcMethod.omega_step_size + ) + + if self.pv_pre_process_func is not None: + holdouts = self.pv_pre_process_func(holdouts) + + if self.pv_results is None: + pv_result_list = [] + for test_omega in omegas_to_test: + predicted = self._arc_method(pv_years, test_omega) + if DimensionConstants.SCENARIO in predicted.coords: + predicted = predicted.sel(scenario=0, drop=True) + + assert_coords_same(predicted, self.past_data) + + predicted_holdouts = predicted.sel(year_id=pv_years.forecast_years) + + if self.pv_pre_process_func is not None: + predicted_holdouts = self.pv_pre_process_func(predicted_holdouts) + + pv_result = ArcMethod.pv_metric(predicted_holdouts, holdouts) + pv_result_da = xr.DataArray( + [pv_result], coords={"weight": [test_omega]}, dims=["weight"] + ) + pv_result_list.append(pv_result_da) + + self.pv_results = xr.concat(pv_result_list, dim="weight") + + if self.select_omega: + self.omega = self.omega_selection_strategy(rmse=self.pv_results, draws=self.draws) + + return self.omega + + def predict(self) -> xr.DataArray: + """Create projections for reference, better, and worse scenarios using the ARC method. + + Returns: + xr.DataArray: Projections for future years made with the ARC method. It will + include all the dimensions and coordinates of the + ``self.past_data``, except that the ``year_id`` dimension will + ONLY have coordinates for all of the years from + ``self.years.forecast_years``. There will also be a new + ``scenario`` dimension with the coordinates 0 for reference, + -1 for worse, and 1 for better. + """ + self.predictions = self._arc_method( + self.years, self.omega, past_resample_draws=self.draws + ).sel(year_id=self.years.forecast_years) + + return self.predictions + + def save_coefficients( + self, output_dir: FHSDirSpec, entity: str, save_omega_draws: bool = False + ) -> None: + """Saves omega. + + I.e. the power to raise the increasing year weights to, and/or PV results, + an array of RMSEs resulting from predictive validity tests. 
+ + Args: + output_dir (Path): directory to save data to + entity (str): name to give output file + save_omega_draws (bool): whether to save omega draws + + Raises: + ValueError: if no omega or PV results present to save + """ + + def is_xarray(da: Any) -> bool: + return isinstance(da, xr.Dataset) or isinstance(da, xr.DataArray) + + if self.omega is None and self.pv_results is None: + err_msg = "No omega or predictive validity results to save" + logger.error(err_msg) + raise ValueError(err_msg) + + if self.omega is not None: + if is_xarray(self.omega) and not save_omega_draws: + logger.debug( + "Computing stats of omega draws", + bindings=dict(model=self.__class__.__name__), + ) + coef_stats = self._compute_stats(self.omega) + elif is_xarray(self.omega) and save_omega_draws: + coef_stats = self.omega + elif isinstance(self.omega, float) or isinstance(self.omega, int): + logger.debug( + "omega is singleton value", + bindings=dict(model=self.__class__.__name__), + ) + coef_stats = xr.DataArray( + [self.omega], + dims=["omega"], + coords={"omega": ["value"]}, + ) + + omega_output_file = output_dir.append_sub_path(("coefficients",)).file( + f"{entity}_omega.nc" + ) + + xarray_wrapper.save_xr_scenario( + coef_stats, + omega_output_file, + metric="rate", + space="identity", + ) + + if self.pv_results is not None: + pv_output_file = output_dir.append_sub_path(("coefficients",)).file( + f"{entity}_omega_rmses.nc" + ) + + xarray_wrapper.save_xr_scenario( + self.pv_results, + pv_output_file, + metric="rate", + space="identity", + ) + + @staticmethod + def _compute_stats(da: xr.DataArray) -> Union[xr.DataArray, xr.Dataset]: + """Compute mean and variance of draws if a ``'draw'`` dim exists. + + Otherwise just return a copy of the original. + + Args: + da (xr.DataArray): data array for computation + + Returns: + Union[xr.DataArray, xr.Dataset]: the computed data + """ + if DimensionConstants.DRAW in da.dims: + mean_da = da.mean(DimensionConstants.DRAW).assign_coords(stat="mean") + var_da = da.var(DimensionConstants.DRAW).assign_coords(stat="var") + stats_da = xr.concat([mean_da, var_da], dim="stat") + else: + logger.warning( + "Draw is NOT a dim, can't compute omega stats", + bindings=dict(model=__class__.__name__, dims=da.dims), + ) + stats_da = da.copy() + return stats_da + + def _arc_method( + self, + years: YearRange, + omega: Union[float, xr.DataArray], + past_resample_draws: Optional[int] = None, + ) -> xr.DataArray: + """Run and return the `arc_method`. + + To keep the PV step and prediction step consistent put the explicit ``arc_method`` + call with all of its defined parameters here. + + Args: + years (YearRange): years to include in the past when calculating ARC + omega (Union[float, xr.DataArray]): the omega to assess for draws + past_resample_draws (Optional[int]): The number of draws to resample from the past + data. This argument is used in the predict step to avoid NaNs in the forecast + when there is a mismatch between the number of draw coordinates in the past + data and the desired number of draw coordinates. 
+ + Returns: + xr.DataArray: result of the `arc_method` function call + """ + omega_dim = ArcMethod._get_omega_dim(omega, self.draws) + + if past_resample_draws is not None and "draw" in self.past_data.dims: + past_data = resample(self.past_data, past_resample_draws) + else: + past_data = self.past_data + + return arc_method( + past_data_da=past_data, + gbd_round_id=self.gbd_round_id, + years=years, + weight_exp=omega, + reference_scenario=self.reference_scenario_statistic, + reverse_scenarios=self.reverse_scenarios, + quantiles=self.quantiles, + diff_over_mean=self.mean_level_arc, + reference_arc_dims=self.reference_arc_dims, + scenario_arc_dims=self.scenario_arc_dims, + truncate=self.truncate, + truncate_dims=self.truncate_dims, + truncate_quantiles=self.truncate_quantiles, + replace_with_mean=self.replace_with_mean, + extra_dim=omega_dim, + scenario_roc=self.scenario_roc, + single_scenario_mode=self.single_scenario_mode, + ) + + @staticmethod + def _get_omega_dim(omega: Union[float, int, xr.DataArray], draws: int) -> Optional[str]: + """Get the omega dimension if passed a data array. + + Args: + omega (Union[float, int, xr.DataArray]): the omega value or data array + draws (int): the number of draws to validate omega against + + Returns: + Optional[str]: ``'draw'``, if ``omega`` contains draw specific omegas as a + dataarray or ``None``, if ``omega`` is float. + + Raises: + ValueError: if `omega` draw dim doesn't have the expected coords + TypeError: if `omega` isn't a float, int, or data array + """ + if isinstance(omega, float) or isinstance(omega, int): + omega_dim = None + elif isinstance(omega, xr.DataArray): + if set(omega.dims) != {DimensionConstants.DRAW}: + err_msg = "`omega` can only have 'draw' as a dim" + logger.error(err_msg) + raise ValueError(err_msg) + elif sorted(list(omega[DimensionConstants.DRAW].values)) != list(range(draws)): + err_msg = "`omega`'s draw dim doesn't have the expected coords" + logger.error(err_msg) + raise ValueError(err_msg) + omega_dim = DimensionConstants.DRAW + else: + err_msg = "`omega` must be either a float, an int, or an xarray.DataArray" + logger.error(err_msg) + raise TypeError(err_msg) + + return omega_dim + + +def arc_method( + past_data_da: xr.DataArray, + gbd_round_id: int, + years: Optional[Iterable[int]] = None, + weight_exp: Union[float, int, xr.DataArray] = 1, + reference_scenario: str = "median", + reverse_scenarios: bool = False, + quantiles: Iterable[float] = ArcMethodConstants.DEFAULT_SCENARIO_QUANTILES, + diff_over_mean: bool = False, + reference_arc_dims: Optional[List[str]] = None, + scenario_arc_dims: Optional[List[str]] = None, + truncate: bool = False, + truncate_dims: Optional[List[str]] = None, + truncate_quantiles: Optional[Iterable[float]] = None, + replace_with_mean: bool = False, + extra_dim: Optional[str] = None, + scenario_roc: str = "all", + single_scenario_mode: bool = False, +) -> xr.DataArray: + """Makes rate forecasts using the Annualized Rate-of-Change (ARC) method. + + Forecasts rates by taking a weighted quantile or weighted mean of + annualized rates-of-change from past data, then walking that weighted + quantile or weighted mean out into future years. + + A reference scenarios is made using the weighted median or mean of past + annualized rate-of-change across all past years. + + Better and worse scenarios are made using weighted 15th and 85th quantiles + of past annualized rates-of-change across all locations and all past years. 
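A stripped-down numeric sketch of that scenario logic (synthetic numbers, no year weights or transforms; here both summaries collapse locations and years, which matches the scenarios but simplifies the reference, and the min/max ordering described next decides which path is labeled better vs. worse):

```python
import numpy as np

rng = np.random.default_rng(0)
arcs = rng.normal(-0.02, 0.01, size=(20, 30))    # location x year annualized changes
reference_arc = np.median(arcs)                  # reference summary (unweighted here)
scenario_arcs = np.quantile(arcs, [0.15, 0.85])  # 15th/85th percentile ARCs

last_observed = 0.6
step = np.arange(1, 28)                          # 27 future years
reference_path = last_observed + reference_arc * step
scenario_paths = last_observed + np.outer(scenario_arcs, step)  # two candidate scenario paths
```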
+ + The minimum and maximum are taken across the scenarios (values are + granular, e.g. age/sex/location/year specific) and the minimum is taken as + the better scenario and the maximum is taken as the worse scenario. If + scenarios are reversed (``reverse_scenario = True``) then do the opposite. + + Args: + past_data_da: + A dataarray of past data that must at least of the dimensions + ``year_id`` and ``location_id``. The ``year_id`` dimension must + have coordinates for all the years in ``years.past_years``. + gbd_round_id: + gbd_round_id the data comes from. + years: + years to include in the past when calculating ARC. + weight_exp: + power to raise the increasing year weights -- must be nonnegative. + It can be dataarray, but must have only one dimension, "draw", it + must have the same coordinates on that dimension as + ``past_data_da``. + reference_scenario: + If "median" then the reference scenarios is made using the + weighted median of past annualized rate-of-change across all past + years, "mean" then it is made using the weighted mean of past + annualized rate-of-change across all past years. Defaults to + "median". + reverse_scenarios: + If True, reverse the usual assumption that high=bad and low=good. + For example, we set to True for vaccine coverage, because higher + coverage is better. Defaults to False. + quantiles: + The quantiles to use for better and worse scenarios. Defaults to + ``0.15`` and ``0.85`` quantiles. + diff_over_mean: + If True, then take annual differences for means-of-draws, instead + of draws. Defaults to False. + reference_arc_dims: + To calculate the reference ARC, take weighted mean or median over + these dimensions. Defaults to ["year_id"] + scenario_arc_dims: + To calculate the scenario ARCs, take weighted quantiles over these + dimensions.Defaults to ["location_id", "year_id"] + truncate: + If True, then truncates the dataarray over the given dimensions. + Defaults to False. + truncate_dims: + A list of strings representing the dimensions to truncate over. + truncate_quantiles: + The tuple of two floats representing the quantiles to take. + replace_with_mean: + If True and `truncate` is True, then replace values outside of the + upper and lower quantiles taken across "location_id" and "year_id" + and with the mean across "year_id", if False, then replace with the + upper and lower bounds themselves. + extra_dim: + Extra dimension that exists in `weights` and `data`. It should not + be in `stat_dims`. + scenario_roc: + If "all", then the scenario rate of change is taken over all + locations. If "national_only", roc is taken over national + locations only. Defaults to "all". + single_scenario_mode: + If true, better and worse scenarios are not calculated, and the reference scenario + is returned without a scenario dimension. + + Returns: + Past and future data with reference, better, and worse scenarios. + It will include all the dimensions and coordinates of the input + dataarray and a ``scenario`` dimension with the coordinates 0 for + reference, -1 for worse, and 1 for better. The ``year_id`` + dimension will have coordinates for all of the years from + ``years.years``. + + Raises: + ValueError: If ``weight_exp`` is a negative number or if ``reference_scenario`` + is not "median" or "mean". 
+ """ + logger.debug( + "Inputs for `arc_method` call", + bindings=dict( + years=years, + weight_exp=weight_exp, + reference_scenario=reference_scenario, + reverse_scenarios=reverse_scenarios, + quantiles=quantiles, + diff_over_mean=diff_over_mean, + truncate=truncate, + replace_with_mean=replace_with_mean, + truncate_quantiles=truncate_quantiles, + extra_dim=extra_dim, + ), + ) + + years = YearRange(*years) if years else YearRange(*ArcMethodConstants.DEFAULT_YEAR_RANGE) + + past_data_da = past_data_da.sel(year_id=years.past_years) + + # Create baseline forecasts. Take weighted median or mean only across + # years, so values will be as granular as the inputs (e.g. age/sex/location + # specific) + if reference_scenario == "median": + reference_statistic = QuantileStatistic(0.5) + elif reference_scenario == "mean": + reference_statistic = MeanStatistic() + else: + raise ValueError("reference_scenario must be either 'median' or 'mean'") + + if truncate and not truncate_dims: + truncate_dims = [DimensionConstants.LOCATION_ID, DimensionConstants.YEAR_ID] + + truncate_quantiles = ( + Quantiles(*sorted(truncate_quantiles)) + if truncate_quantiles + else Quantiles(0.025, 0.975) + ) + + reference_arc_dims = reference_arc_dims or [DimensionConstants.YEAR_ID] + reference_change = arc( + past_data_da, + years, + weight_exp, + reference_arc_dims, + reference_statistic, + diff_over_mean=diff_over_mean, + truncate=truncate, + truncate_dims=truncate_dims, + truncate_quantiles=truncate_quantiles, + replace_with_mean=replace_with_mean, + extra_dim=extra_dim, + ) + reference_da = past_data_da.sel(year_id=years.past_end) + reference_change + forecast_data_da = past_data_da.combine_first(reference_da) + + if not single_scenario_mode: + forecast_data_da = _forecast_better_worse_scenarios( + past_data_da=past_data_da, + gbd_round_id=gbd_round_id, + years=years, + weight_exp=weight_exp, + reverse_scenarios=reverse_scenarios, + quantiles=quantiles, + diff_over_mean=diff_over_mean, + scenario_arc_dims=scenario_arc_dims, + replace_with_mean=replace_with_mean, + extra_dim=extra_dim, + scenario_roc=scenario_roc, + forecast_data_da=forecast_data_da, + ) + + return forecast_data_da + + +def _forecast_better_worse_scenarios( + forecast_data_da: xr.DataArray, + past_data_da: xr.DataArray, + gbd_round_id: int, + years: YearRange, + weight_exp: Union[float, int, xr.DataArray], + reverse_scenarios: bool, + quantiles: Iterable[float], + diff_over_mean: bool, + scenario_arc_dims: Optional[List[str]], + replace_with_mean: bool, + extra_dim: Optional[str], + scenario_roc: str, +) -> xr.DataArray: + try: + forecast_data_da = forecast_data_da.rename( + {DimensionConstants.QUANTILE: DimensionConstants.SCENARIO} + ) + except ValueError: + pass # There is no "quantile" point coordinate. + + forecast_data_da[DimensionConstants.SCENARIO] = ScenarioConstants.REFERENCE_SCENARIO_COORD + + # Create better and worse scenario forecasts. Take weighted 85th and 15th + # quantiles across year and location, so values will not be location + # specific (e.g. just age/sex specific). 
+ scenario_arc_dims = scenario_arc_dims or [ + DimensionConstants.LOCATION_ID, + DimensionConstants.YEAR_ID, + ] + if scenario_roc == "national": + nation_ids = location.get_location_set( + gbd_round_id=gbd_round_id, include_aggregates=False, national_only=True + )[DimensionConstants.LOCATION_ID].unique() + + arc_input = past_data_da.sel(location_id=nation_ids) + elif scenario_roc == "all": + arc_input = past_data_da + else: + raise ValueError( + f'scenario_roc should be one of "national" or "all"; got {scenario_roc}' + ) + scenario_change = arc( + arc_input, + years, + weight_exp, + scenario_arc_dims, + QuantileStatistic(quantiles), + diff_over_mean=diff_over_mean, + truncate=False, + replace_with_mean=replace_with_mean, + extra_dim=extra_dim, + ) + + scenario_change = scenario_change.rename( + {DimensionConstants.QUANTILE: DimensionConstants.SCENARIO} + ) + scenarios_da = past_data_da.sel(year_id=years.past_end) + scenario_change + + scenarios_da.coords[DimensionConstants.SCENARIO] = [ + ScenarioConstants.BETTER_SCENARIO_COORD, + ScenarioConstants.WORSE_SCENARIO_COORD, + ] + + forecast_data_da = xr.concat( + [forecast_data_da, scenarios_da], dim=DimensionConstants.SCENARIO + ) + + # Get the minimums and maximums across the scenario dimension, and set + # worse scenarios to the worst (max if normal or min if reversed), and set + # better scenarios to the best (min if normal or max if reversed). + low_values = forecast_data_da.min(DimensionConstants.SCENARIO) + high_values = forecast_data_da.max(DimensionConstants.SCENARIO) + if reverse_scenarios: + forecast_data_da.loc[ + {DimensionConstants.SCENARIO: ScenarioConstants.WORSE_SCENARIO_COORD} + ] = low_values + forecast_data_da.loc[ + {DimensionConstants.SCENARIO: ScenarioConstants.BETTER_SCENARIO_COORD} + ] = high_values + else: + forecast_data_da.loc[ + {DimensionConstants.SCENARIO: ScenarioConstants.BETTER_SCENARIO_COORD} + ] = low_values + forecast_data_da.loc[ + {DimensionConstants.SCENARIO: ScenarioConstants.WORSE_SCENARIO_COORD} + ] = high_values + + forecast_data_da = past_data_da.combine_first(forecast_data_da) + + forecast_data_da = forecast_data_da.loc[ + {DimensionConstants.SCENARIO: sorted(forecast_data_da[DimensionConstants.SCENARIO])} + ] + + return forecast_data_da + + +def arc( + past_data_da: xr.DataArray, + years: YearRange, + weight_exp: Union[float, int, xr.DataArray], + stat_dims: Iterable[str], + statistic: StatisticSpec, + diff_over_mean: bool = False, + truncate: bool = False, + truncate_dims: Optional[List[str]] = None, + truncate_quantiles: Optional[Iterable[float]] = None, + replace_with_mean: bool = False, + extra_dim: Optional[str] = None, +) -> xr.DataArray: + r"""Makes rate forecasts by forecasting the Annualized Rates-of-Change (ARC). + + Uses either weighted means or weighted quantiles. + + The steps for forecasting logged or logitted rates with ARCs are: + + (1) Annualized rate differentials (or annualized rates-of-change if data is + in log or logit space) are calculated. + + .. Math:: + + \vec{D_{p}} = + [x_{1991} - x_{1990}, x_{1992} - x_{1991}, ... x_{2016} - x_{2015}] + + where :math:`x` are values from ``past_data_da`` for each year and + :math:`\vec{D_p}` is the vector of differentials in the past. + + (2) Year weights are used to weight recent years more heavily. Year weights + are made by taking the interval + + .. 
math:: + + \vec{W} = [1, ..., n]^w + + where :math:`n` is the number of past years, :math:`\vec{w}` is the + value given by ``weight_exp``, and :math:`\vec{W}` is the vector of + year weights. + + (3) Weighted quantiles or the weighted mean of the annualized + rates-of-change are taken over the dimensions. + + .. math:: + + s = \text{weighted-statistic}(\vec{W}, \vec{D}) + + where :math:`s` is the weighted quantile or weighted mean. + + (4) Future rates-of-change are simulated by taking the interval + + .. math:: + + \vec{D_{f}} = [1, ..., m] * s + + where :math:`\vec{D_f}` is the vector of differentials in the future + and :math:`m` is the number of future years to forecast and + + (5) Lastly, these future differentials are added to the rate of the last + observed year. + + .. math:: + + \vec{X_{f}} = \vec{D_{f}} + x_{2016} = [x_{2017}, ..., x_{2040}] + + where :math:`X_{f}` is the vector of forecasted rates. + + Args: + past_data_da: + Past data with a year-id dimension. Must be in log or logit space + in order for this function to actually calculate ARCs, otherwise + it's just calculating weighted statistic of the first differences. + years: + past and future year-ids + weight_exp: + power to raise the increasing year weights -- must be nonnegative. + It can be dataarray, but must have only one dimension, "draw", it + must have the same coordinates on that dimension as + ``past_data_da``. + stat_dims: + list of dimensions to take quantiles over + statistic: A statistic to use to calculate the ARC from the annual + diff, either MeanStatistic() or QuantileStatistic(quantiles). + diff_over_mean: + If True, then take annual differences for means-of-draws, instead + of draws. Defaults to False. + truncate: + If True, then truncates the dataarray over the given dimensions. + Defaults to False. + truncate_dims: + A list of strings representing the dimensions to truncate over. + truncate_quantiles: + The iterable of two floats representing the quantiles to take. + replace_with_mean: + If True and `truncate` is True, then replace values outside of the + upper and lower quantiles taken across "location_id" and "year_id" + and with the mean across "year_id", if False, then replace with the + upper and lower bounds themselves. + extra_dim: + An extra dim to take the `statistic` over. Should exist in + `weights` and `data`. It should not be in `stat_dims`. + + Returns: + Forecasts made using the ARC method. + + Raises: + ValueError: Conditions: + + * If ``statistic`` is ill-formed. + * If ``weight_exp`` is a negative number. + * If `truncate` is True, then `truncate_quantiles` must be a list of floats. + """ + logger.debug( + "Inputs for `arc` call", + bindings=dict( + years=years, + weight_exp=weight_exp, + statistic=statistic, + stat_dims=stat_dims, + diff_over_mean=diff_over_mean, + truncate=truncate, + replace_with_mean=replace_with_mean, + truncate_quantiles=truncate_quantiles, + extra_dim=extra_dim, + ), + ) + + # Calculate the annual differentials. 
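# Editorial note (not part of the original code): a worked example of the
# weighting described above. With past values x = [0.50, 0.48, 0.45, 0.44]
# the annual differentials are D = [-0.02, -0.03, -0.01]. With weight_exp = 1
# the year weights are W = [1, 2, 3]^1 = [1, 2, 3], so the weighted-mean ARC is
# (1*(-0.02) + 2*(-0.03) + 3*(-0.01)) / (1 + 2 + 3) = -0.11 / 6 ≈ -0.0183,
# and the forecasts are x_last + [1, 2, 3, ...] * (-0.0183).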
+ if diff_over_mean and DimensionConstants.DRAW in past_data_da.dims: + annual_diff = past_data_da.mean(DimensionConstants.DRAW) + else: + annual_diff = past_data_da + annual_diff = annual_diff.sel(year_id=years.past_years).diff( + DimensionConstants.YEAR_ID, n=1 + ) + + if isinstance(weight_exp, xr.DataArray): + if DimensionConstants.DRAW not in weight_exp.dims: # pytype: disable=attribute-error + raise ValueError( + "`weight_exp` must be a float, an int, or an xarray.DataArray " + "with a 'draw' dimension" + ) + + # If annual-differences were taken over means (`annual_diff` doesn't have a "draw" + # dimension), but `year_weights` does have a "draw" dimension, then the draw dimension + # needs to be expanded for `annual_diff` such that the mean is replicated for each draw + if DimensionConstants.DRAW not in annual_diff.dims: + annual_diff = expand_dimensions( + annual_diff, draw=weight_exp[DimensionConstants.DRAW].values + ) + weight_exp = expand_dimensions( + weight_exp, year_id=annual_diff[DimensionConstants.YEAR_ID].values + ) + + year_weights = ( + xr.DataArray( + (np.arange(len(years.past_years) - 1) + 1), + dims=DimensionConstants.YEAR_ID, + coords={DimensionConstants.YEAR_ID: years.past_years[1:]}, + ) + ** weight_exp + ) + + if truncate: + if not is_iterable_of(float, truncate_quantiles): + raise ValueError( + "If `truncate` is True, then `truncate_quantiles` must be a list of floats." + ) + + truncate_dims = truncate_dims or [ + DimensionConstants.LOCATION_ID, + DimensionConstants.YEAR_ID, + ] + truncate_quantiles = Quantiles(*sorted(truncate_quantiles)) + annual_diff = truncate_dataarray( + annual_diff, + truncate_dims, + replace_with_mean=replace_with_mean, + mean_dims=[DimensionConstants.YEAR_ID], + weights=year_weights, + quantiles=truncate_quantiles, + extra_dim=extra_dim, + ) + + stat_dims = list(stat_dims) + + if (xr.DataArray(weight_exp) > 0).any(): + arc_da = statistic.weighted_statistic(annual_diff, stat_dims, year_weights, extra_dim) + elif (xr.DataArray(weight_exp) == 0).all(): + # If ``weight_exp`` is zero, then just take the unweighted mean or + # quantile. + arc_da = statistic.unweighted_statistic(annual_diff, stat_dims) + else: + raise ValueError("weight_exp must be nonnegative.") + + # Find future change by multiplying an array that counts the future + # years, by the quantiles, which is weighted if `weight_exp` > 0. We want + # the multipliers to start at 1, for the first year of forecasts, and count + # to one more than the number of years to forecast. + forecast_year_multipliers = xr.DataArray( + np.arange(len(years.forecast_years)) + 1, + dims=[DimensionConstants.YEAR_ID], + coords={DimensionConstants.YEAR_ID: years.forecast_years}, + ) + future_change = arc_da * forecast_year_multipliers + return future_change + + +def is_iterable_of(type: Type, obj: Any) -> bool: + """True iff the obj is an iterable containing only instances of the given type.""" + return hasattr(obj, "__iter__") and all([isinstance(item, type) for item in obj]) + + +def approach_value_by_year( + past_data: xr.DataArray, + years: YearRange, + target_year: int, + target_value: float, + method: str = "linear", +) -> xr.DataArray: + """Forecasts cases where a target level at a target year is known. + + For e.g., the Rockefeller project for min-risk diet scenarios, wanted to + see the effect of eradicating diet related risks by 2030 on mortality. For + this we need to reach 0 SEV for all diet related risks by 2030 and keep + the level constant at 0 for further years. 
Here the target_year is 2030
+    and target_value is 0.
+
+    Args:
+        past_data:
+            The past data with all past years.
+        years:
+            past and future year-ids
+        target_year:
+            The year at which the target value will be reached.
+        target_value:
+            The target value that needs to be achieved during the target year.
+        method:
+            The extrapolation method to be used to calculate the values for
+            intermediate years (years between years.past_end and target_year).
+            The only method currently supported is `linear`.
+
+    Raises:
+        ValueError: if method != "linear"
+
+    Returns:
+        The forecasted results.
+    """
+    if method == "linear":
+        forecast = _linear_then_constant_arc(past_data, years, target_year, target_value)
+    else:
+        raise ValueError(
+            f"Method {method} not recognized. Please see the documentation for"
+            " the list of supported methods."
+        )
+
+    return forecast
+
+
+def _linear_then_constant_arc(
+    past_data: xr.DataArray, years: YearRange, target_year: int, target_value: float
+) -> xr.DataArray:
+    r"""Makes rate forecasts by linearly extrapolating.
+
+    Extrapolates the point ARC from the last past year till the target year to reach the
+    target value.
+
+    The steps for extrapolating the point ARCs are:
+
+    (1) Calculate the rate of change between the last year of the past
+        data (e.g. 2017) and ``target_year`` (e.g. 2030).
+
+        .. math::
+
+            R =
+                \frac{target\_value - past\_last\_year\_value}
+                     {target\_year - past\_last\_year}
+
+        where :math:`R` is the slope of the desired linear trend.
+
+    (2) Calculate the rates of change between the last year of the past and
+        each future year by multiplying :math:`R` with future year weights till
+        ``target_year``.
+
+        .. math::
+
+            \vec{W} = [1, ..., m]
+
+            \vec{F_r} = \vec{W} * R
+
+        where :math:`m` is the number of years between the ``target_year`` and
+        the last year of the past, and :math:`\vec{W}` forms the vector of
+        year weights.
+        :math:`\vec{F_r}` contains the linearly extrapolated ARCs for each
+        future year till the ``target_year``.
+
+    (3) Add the future rates :math:`\vec{F_r}` to the last year of the past
+        (e.g. 2017) to get the forecasted results.
+
+    (4) Extend the forecasted results till the ``forecast_end`` year by
+        filling the ``target_value`` for all the remaining future years.
+
+    Args:
+        past_data:
+            The past data with all past years. The data is assumed to be in
+            normal space.
+        years:
+            past and future year-ids
+        target_year:
+            The year at which the target value will be reached.
+        target_value:
+            The value that needs to be achieved by the `target_year`.
+
+    Returns:
+        The forecasted results.
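+
+    Illustrative example (hypothetical numbers): if the last past year is 2022 with a
+    value of 10, ``target_year`` is 2030, and ``target_value`` is 0, then
+    :math:`R = (0 - 10) / (2030 - 2022) = -1.25`, the 2023-2030 forecasts are
+    10 - 1.25, 10 - 2.5, ..., 0, and all years after 2030 are held constant at 0.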
+ """ + pre_target_years = np.arange(years.forecast_start, target_year + 1) + post_target_years = np.arange(target_year + 1, years.forecast_end + 1) + + past_last_year = past_data.sel(year_id=years.past_end) + target_yr_arc = (target_value - past_last_year) / (target_year - years.past_end) + + forecast_year_multipliers = xr.DataArray( + np.arange(len(pre_target_years)) + 1, + dims=[DimensionConstants.YEAR_ID], + coords={DimensionConstants.YEAR_ID: pre_target_years}, + ) + + future_change = target_yr_arc * forecast_year_multipliers + forecast_bfr_target_year = past_last_year + future_change + + forecast = expand_dimensions( + forecast_bfr_target_year, fill_value=target_value, year_id=post_target_years + ) + + return forecast diff --git a/gbd_2021/disease_burden_forecast_code/met_need/collect_submodels.py b/gbd_2021/disease_burden_forecast_code/met_need/collect_submodels.py new file mode 100644 index 0000000..3c43ec0 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/met_need/collect_submodels.py @@ -0,0 +1,321 @@ +"""Script to collect and collapse components into genem for future stage. +""" + +from typing import Callable, List + +import numpy as np +import pandas as pd +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_file_interface.lib.pandas_wrapper import read_csv +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_year_range_manager.lib import YearRange +from tiny_structured_logger.lib import fhs_logging + +from fhs_lib_genem.lib.constants import ( + FileSystemConstants, + ModelConstants, + ScenarioConstants, + SEVConstants, + TransformConstants, +) + +logger = fhs_logging.get_logger() + + +def entity_specific_collection( + entity: str, + stage: str, + versions: Versions, + gbd_round_id: int, + years: YearRange, + transform: str, + intercept_shift_from_reference: bool, + uncross_scenarios: bool, +) -> None: + """Collect, sample, collapse, and export a given risk. + + Args: + entity (str): risk to collect across omegas. If intrinsic SEV, + then the rei will look like acause-rei. + stage (str): stage of run (sev, mmr, etc.) + versions (Versions): input and output versions + gbd_round_id (int): gbd round id. + years (YearRange): past_start:forecast_start:forecast_end. + transform (str): name of transform to use for processing (logit, log, no-transform). + intercept_shift_from_reference (bool): If True, and we are in multi-scenario mode, then + the intercept-shifting during the above `transform` is calculated from the + reference scenario but applied to all scenarios; if False then each scenario will + get its own shift amount. + uncross_scenarios (bool): whether to fix crossed scenarios. This is currently only used + for sevs and should be deprecated soon. 
+ + """ + input_model_weights_version_metadata = versions.get(past_or_future="future", stage=stage) + input_model_weights_file_spec = FHSFileSpec( + version_metadata=input_model_weights_version_metadata, + filename=ModelConstants.MODEL_WEIGHTS_FILE, + ) + + omega_df = read_csv(file_spec=input_model_weights_file_spec, keep_default_na=False) + + locations: List[int] = omega_df["location_id"].unique().tolist() + + future_da = get_location_draw_omegas( + versions=versions, + gbd_round_id=gbd_round_id, + stage=stage, + entity=entity, + omega_df=omega_df, + locations=locations, + ) + + # Every entity has many rows, and the "intrinsic" and "subfolder" values + # should be the same over all rows. So we only need first row here. + first_row = omega_df.query(f"entity == '{entity}'").iloc[0] + + intrinsic, subfolder = bool(first_row["intrinsic"]), str(first_row["subfolder"]) + + if intrinsic: + file_name = f"{entity}_{SEVConstants.INTRINSIC_SEV_FILENAME_SUFFIX}.nc" + else: + file_name = f"{entity}.nc" + + if intrinsic: + # Per research decision, we set all intrinsic scenarios to reference + non_ref_scenarios = [ + s + for s in future_da["scenario"].values + if s != ScenarioConstants.REFERENCE_SCENARIO_COORD + ] + for scenario in non_ref_scenarios: + future_da.loc[{"scenario": scenario}] = future_da.sel( + scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD + ) + + logger.info(f"Entering intercept-shift of {entity} submodel") + future_da = intercept_shift_processing( + stage=stage, + versions=versions, + gbd_round_id=gbd_round_id, + years=years, + transform=transform, + subfolder=subfolder, + future_da=future_da, + file_name=file_name, + shift_from_reference=intercept_shift_from_reference, + ) + + if uncross_scenarios: + future_da = fix_scenario_crossing(years=years, future_da=future_da) + + output_version_metadata = versions.get(past_or_future="future", stage=stage) + + output_file_spec = FHSFileSpec( + version_metadata=output_version_metadata, + sub_path=(subfolder,), + filename=file_name, + ) + + save_xr_scenario( + xr_obj=future_da, + file_spec=output_file_spec, + metric="rate", + space="identity", + years=str(years), + past_version=str(versions.get_version_metadata(past_or_future="past", stage=stage)), + out_version=str(versions.get_version_metadata(past_or_future="future", stage=stage)), + gbd_round_id=gbd_round_id, + ) + + +def read_location_draws( + file_spec: FHSFileSpec, location_id: int, draw_start: int, n_draws: int +) -> xr.DataArray: + """Read location-draws from file. + + Notably, this function will expand or contract the number of draws present to fit inside + the closed range [`draw_start`, `draw_start` + `n_draws`], *reassigning coordinates* from + whatever they are read in as. 
+ """ + da = open_xr_scenario(file_spec).sel(location_id=location_id).load() + if "draw" in da.dims: # some sub-models may be draw-less + da = resample(da, n_draws) + da = da.assign_coords(draw=range(draw_start, draw_start + n_draws)) + else: + da = expand_dimensions(da, draw=range(draw_start, draw_start + n_draws)) + return da + + +def fix_scenario_crossing(years: YearRange, future_da: xr.DataArray) -> xr.DataArray: + """Scenario cross the future data and fill missing results within [0, 1].""" + + # Ensure same years.past_end values across scenarios after transformations + future_da_ref = future_da.sel(scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD) + future_da_worse = future_da.sel(scenario=ScenarioConstants.WORSE_SCENARIO_COORD) + future_da_better = future_da.sel(scenario=ScenarioConstants.BETTER_SCENARIO_COORD) + + future_worse_diff = future_da_worse.sel(year_id=years.past_end) - future_da_ref.sel( + year_id=years.past_end + ) + future_better_diff = future_da_better.sel(year_id=years.past_end) - future_da_ref.sel( + year_id=years.past_end + ) + + future_new_worse = future_da_worse - future_worse_diff + future_new_better = future_da_better - future_better_diff + + future_da = xr.concat([future_new_worse, future_da_ref, future_new_better], dim="scenario") + + dam = future_da.mean("draw") + + # For SEV's, worse >= ref >= better + worse = dam.sel(scenario=ScenarioConstants.WORSE_SCENARIO_COORD) + better = dam.sel(scenario=ScenarioConstants.BETTER_SCENARIO_COORD) + ref = dam.sel(scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD) + + worse_diff = ref - worse # should be <= 0 for SEV, so we keep the > 0's + worse_diff = worse_diff.where(worse_diff < 0).fillna(0) # keep > 0's + + better_diff = ref - better # should be >= 0 for SEV, so we keep the < 0's + better_diff = better_diff.where(better_diff > 0).fillna(0) # keep < 0's + + # the worse draws that are below ref will have > 0 values added to them + future_da.loc[dict(scenario=ScenarioConstants.WORSE_SCENARIO_COORD)] = ( + future_da.sel(scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD) - worse_diff + ) + # the better draws that are above ref will have < 0 values added to them + future_da.loc[dict(scenario=ScenarioConstants.BETTER_SCENARIO_COORD)] = ( + future_da.sel(scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD) - better_diff + ) + + # It's also decided that non-ref scenarios should not have uncertainty + dim_order = ["draw"] + [x for x in future_da.dims if x != "draw"] + future_da = future_da.transpose(*dim_order) # draw-dim to 1st to broadcast + + # per RT meeting 20211201, we'll no longer save computed past SEVs + needed_years = np.concatenate(([years.past_end], years.forecast_years)) + future_da = future_da.sel(year_id=needed_years) + + future_da = future_da.where(future_da <= 1).fillna(1) + future_da = future_da.where(future_da >= 0).fillna(0) + + return future_da + + +def intercept_shift_processing( + stage: str, + versions: Versions, + gbd_round_id: int, + years: YearRange, + transform: str, + subfolder: str, + future_da: xr.DataArray, + file_name: str, + shift_from_reference: bool, +) -> xr.DataArray: + """Perform ordered draw intercept shifting of past and future data.""" + # Here we do ordered-draw intercept-shift to ensure uncertainty fan-out + past_version_metadata = versions.get(past_or_future="past", stage=stage) + + past_file_spec = FHSFileSpec( + version_metadata=past_version_metadata, + sub_path=(subfolder,), + filename=file_name, + ) + + past_da = open_xr_scenario(past_file_spec).sel( + 
sex_id=future_da["sex_id"], + age_group_id=future_da["age_group_id"], + location_id=future_da["location_id"], + ) + + if "draw" in past_da.dims and "draw" in future_da.dims: + past_da = resample(past_da, len(future_da.draw.values)) + + if "acause" in future_da.coords: + future_da = future_da.drop_vars("acause") + + if "acause" in past_da.coords: + past_da = past_da.drop_vars("acause") + + if transform != "no-transform": + # NOTE logit transform requires all inputs > 0, but some PAFs can be < 0 + # Perhaps the right thing to do is to follow the scalars pipeline + past_da = past_da.where(past_da >= ModelConstants.LOGIT_OFFSET).fillna( + ModelConstants.LOGIT_OFFSET + ) + + processor_class = TransformConstants.TRANSFORMS[transform] + future_da = processor_class.intercept_shift( + modeled_data=future_da, + past_data=past_da, + years=years, + offset=ModelConstants.LOGIT_OFFSET, + intercept_shift="unordered_draw", + shift_from_reference=shift_from_reference, + ) + + return future_da + + +def get_location_draw_omegas( + entity: str, + versions: Versions, + gbd_round_id: int, + stage: str, + omega_df: pd.DataFrame, + locations: List[int], + read_location_draws_fn: Callable = read_location_draws, +) -> xr.DataArray: + """Loop over locations and read location-draw omega files.""" + loc_das = [] + + for location_id in locations: + rows = omega_df.query(f"entity == '{entity}' & location_id == {location_id}") + + if len(rows) == 0: + raise ValueError(f"{entity} for loc {location_id} has no weight info") + + omega_das = [] # to collect the omegas + draw_start = 0 + + for _, row in rows.iterrows(): # each row is an omega-model + omega, model_name, n_draws, intrinsic, subfolder = ( + float(row["omega"]), + str(row["model_name"]), + int(row["draws"]), + bool(row["intrinsic"]), + str(row["subfolder"]), + ) + + if n_draws < 1: # this could happen if inverse_rmse_order == True + continue + + if intrinsic: + file_name = f"{entity}_{SEVConstants.INTRINSIC_SEV_FILENAME_SUFFIX}_{omega}.nc" + else: + file_name = f"{entity}_{omega}.nc" + + version_metadata = versions.get(past_or_future="future", stage=stage) + + file_spec = FHSFileSpec( + version_metadata=version_metadata, + sub_path=(FileSystemConstants.SUBMODEL_FOLDER, model_name, subfolder), + filename=file_name, + ) + + omega_das.append( + read_location_draws_fn(file_spec, location_id, draw_start, n_draws) + ) + + draw_start = draw_start + n_draws + + loc_das.append(xr.concat(omega_das, dim="draw", coords="minimal")) + + future_da = xr.concat(loc_das, dim="location_id", coords="minimal") + + return future_da diff --git a/gbd_2021/disease_burden_forecast_code/met_need/constants.py b/gbd_2021/disease_burden_forecast_code/met_need/constants.py new file mode 100644 index 0000000..e151e8d --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/met_need/constants.py @@ -0,0 +1,201 @@ +"""FHS Pipeline for BMI forecasting Local Constants.""" + +from fhs_lib_data_transformation.lib import processing +from fhs_lib_database_interface.lib.constants import ( + ScenarioConstants as ImportedScenarioConstants, +) +from frozendict import frozendict + + +class EntityConstants: + """Constants related to entities (e.g. 
acauses).""" + + DEFAULT_ENTITY = "default_entity" + MALARIA_ENTITIES = ["malaria", "malaria_act", "malaria_itn"] + NO_SEX_SPLIT_ENTITY = [ + "abuse_csa_male", + "abuse_csa_female", + "abuse_ipv", + "abuse_ipv_exp", + "met_need", + "nutrition_iron", + "inj_homicide_gun_abuse_ipv_paf", + "inj_homicide_other_abuse_ipv_paf", + "inj_homicide_knife_abuse_ipv_paf", + ] + MALARIA = "malaria" + ACT_ITN_COVARIATE = "act-itn" + + +class LocationConstants: + """Constants used for malaria locations.""" + + # locaions with ACT/ITN interventions + MALARIA_ACT_ITN_LOCS = [ + 168, + 175, + 200, + 201, + 169, + 205, + 202, + 171, + 170, + 178, + 179, + 173, + 207, + 208, + 206, + 209, + 172, + 180, + 210, + 181, + 211, + 184, + 212, + 182, + 213, + 214, + 185, + 522, + 216, + 217, + 187, + 435, + 204, + 218, + 189, + 190, + 191, + 198, + 176, + ] + # locaions without ACT/ITN interventions + NON_MALARIA_ACT_ITN_LOCS = [ + 128, + 129, + 130, + 131, + 132, + 133, + 7, + 135, + 10, + 11, + 12, + 13, + 139, + 15, + 16, + 142, + 18, + 19, + 20, + 152, + 26, + 28, + 157, + 30, + 160, + 161, + 162, + 163, + 164, + 165, + 68, + 203, + 215, + 108, + 111, + 113, + 114, + 118, + 121, + 122, + 123, + 125, + 127, + 193, + 195, + 196, + 197, + 177, + ] + + +class ModelConstants: + """Constants used in forecasting.""" + + FLOOR = 1e-6 + LOGIT_OFFSET = 1e-8 + MIN_RMSE = 1e-8 + + DIFF_OVER_MEAN = True # ARC is computed as the difference over mean values + + MODEL_WEIGHTS_FILE = "all_model_weights.csv" + + +class ScenarioConstants(ImportedScenarioConstants): + """Constants related to scenarios.""" + + DEFAULT_BETTER_QUANTILE = 0.15 + DEFAULT_WORSE_QUANTILE = 0.85 + + +class FileSystemConstants: + """Constants for the file system organization.""" + + PV_FOLDER = "pv" + SUBMODEL_FOLDER = "sub_models" + + +class SEVConstants: + """Constants used in SEVs forecasting.""" + + INTRINSIC_SEV_FILENAME_SUFFIX = "intrinsic" + + +class TransformConstants: + """Constants for transformations used during entity forecasting.""" + + TRANSFORMS = frozendict( + { + "logit": processing.LogitProcessor, + "log": processing.LogProcessor, + "no-transform": processing.NoTransformProcessor, + } + ) + + +class JobConstants: + """Constants related to submitting jobs.""" + + DEFAULT_RUNTIME = "12:00:00" + + COLLECT_SUBMODELS_RUNTIME = "05:00:00" + MRBRT_RUNTIME = "16:00:00" + + COLLECT_SUBMODELS_MEM_GB = 280 + MRBRT_MEM_GB = 275 + MODEL_WEIGHTS_MEM_GB = 160 + + COLLECT_SUBMODELS_NUM_CORES = 8 + + +class OrchestrationConstants: + """Constants used for ensemble model orchestration.""" + + OMEGA_MIN = 0.0 + OMEGA_MAX = 3.0 + OMEGA_STEP_SIZE = 0.5 + + SUBFOLDER = "risk_acause_specific" + + PV_SUFFIX = "_pv" + N_HOLDOUT_YEARS = 10 # number of holdout years for predictive validity runs + + ARC_TRANSFORM = "logit" + ARC_TRUNCATE_QUANTILES = (0.025, 0.975) + ARC_REFERENCE_SCENARIO = "mean" diff --git a/gbd_2021/disease_burden_forecast_code/met_need/get_model_weights_from_holdouts.py b/gbd_2021/disease_burden_forecast_code/met_need/get_model_weights_from_holdouts.py new file mode 100644 index 0000000..051649d --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/met_need/get_model_weights_from_holdouts.py @@ -0,0 +1,210 @@ +"""Collects submodel predictive validity statistics to compile sampling weights for genem. 
+ +The output is a file called `all_model_weights.csv` in the out version +""" + +import glob +import os +import re +from typing import List, Tuple + +import numpy as np +import pandas as pd +from fhs_lib_file_interface.lib.pandas_wrapper import read_csv, write_csv +from fhs_lib_file_interface.lib.version_metadata import ( + FHSDirSpec, + FHSFileSpec, + VersionMetadata, +) + +from fhs_lib_genem.lib.constants import FileSystemConstants, ModelConstants, SEVConstants +from fhs_lib_genem.lib.model_restrictions import ModelRestrictions + + +def pv_file_name_breakup(file_path: str) -> Tuple[str, str]: + r"""Parse full path to predictive-validity (PV) file, extracting entity name and suffix. + + Ex: if file_path is /ihme/forecasting/data/6/future/paf/huh/dud_pv.csv, + then entity = "dud", suffix = "_pv.csv". + If \*/dud_intrinsic_pv.csv, then entity = "dud", + suffix = "_intrinsic_pv.csv". + + Args: + file_path (str): full pv-file path, expected to end with "_pv.csv". + + Returns: + Tuple[str, str]: entity and suffix + """ + filename = os.path.basename(file_path) + match = re.match(r"(.*?)((_intrinsic)?_pv.csv)", filename) + if not match: + raise ValueError( + "PV file path should be of the form 'foo_pv.csv' or 'foo_intrinsic_pv.csv'" + ) + entity, suffix, _ = match.groups() + if not entity or not suffix: + raise ValueError( + "PV file path should be of the form 'foo_pv.csv' or 'foo_intrinsic_pv.csv'" + ) + return entity, suffix + + +def collect_model_rmses( + out_version: VersionMetadata, + gbd_round_id: int, + submodel_names: List[str], + subfolder: str, +) -> pd.DataFrame: + """Collect submodel omega rmse values into a dataframe. + + Loops over pv versions, parses all entity-specific _pv.csv files, + including subfolders. + + Args: + out_version (VersionMetadata): the output version for the whole model. Where this + function looks for submodels. + gbd_round_id (int): gbd_round_id used in the model; used for looking for submodels if + not provided with out_version + submodel_names (List[str]): names of all the sub-models to collect. + subfolder (str): subfolder name where intrinsics are stored. + + Returns: + (pd.DataFrame): Dataframe that contains all columns needed to compute + ensemble weights. 
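+
+        Concretely (from the code below), this is the concatenation of every ``*_pv.csv``
+        found for the listed submodels, with ``model_name``, ``subfolder``, and
+        ``intrinsic`` columns appended so each row can be traced back to its submodel.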
+ """ + combined_pv_df = pd.DataFrame([]) + + # loop over versions and entities + for submodel_name in submodel_names: + input_dir_spec = FHSDirSpec( + version_metadata=out_version, + sub_path=( + FileSystemConstants.PV_FOLDER, + submodel_name, + ), + ) + subfolder_dir_spec = FHSDirSpec( + version_metadata=out_version, + sub_path=( + FileSystemConstants.PV_FOLDER, + submodel_name, + subfolder, + ), + ) + + if not input_dir_spec.data_path().exists(): + raise FileNotFoundError(f"No such directory {input_dir_spec.data_path()}") + files = glob.glob(str(input_dir_spec.data_path() / "*_pv.csv")) + entities = dict([pv_file_name_breakup(file_path) for file_path in files]) + + sub_entities = {} + if (subfolder_dir_spec.data_path()).exists(): # check out the subfolder + files = glob.glob(str(subfolder_dir_spec.data_path() / "*_pv.csv")) + sub_entities = dict([pv_file_name_breakup(file_path) for file_path in files]) + entities.update(sub_entities) + + for ent, suffix in entities.items(): + sub_dir = subfolder if ent in sub_entities else "" + suffix = entities[ent] + + input_file_spec = FHSFileSpec( + version_metadata=input_dir_spec.version_metadata, + sub_path=tuple(list(input_dir_spec.sub_path) + [sub_dir]), + filename=ent + suffix, + ) + pv_df = read_csv(input_file_spec, keep_default_na=False) + + pv_df["model_name"] = submodel_name + pv_df["subfolder"] = sub_dir + pv_df["intrinsic"] = ( + True if SEVConstants.INTRINSIC_SEV_FILENAME_SUFFIX in suffix else False + ) + combined_pv_df = combined_pv_df.append(pv_df) + + # just to move the "entity" column to the front + if "entity" in combined_pv_df.columns: + combined_pv_df = combined_pv_df[ + ["entity"] + [col for col in combined_pv_df.columns if col != "entity"] + ] + + return combined_pv_df + + +def make_omega_weights( + submodel_names: List[str], + subfolder: str, + out_version: VersionMetadata, + gbd_round_id: int, + draws: int, + model_restrictions: ModelRestrictions, +) -> None: + """Collect submodel omega rmse values into a dataframe. + + Loops over pv versions, parses all entity-specific _pv.csv files, + including subfolders. + + Args: + submodel_names (List[str]): names of all the sub-models to collect. + gbd_round_id (int): gbd round id. + subfolder (str): subfolder name where intrinsics are stored. + out_version (VersionMetadata): the output version for the whole model. Where this + function looks for submodels. + gbd_round_id (int): gbd_round_id used in the model; used for looking for submodels if + not provided with out_version + draws (int): number of total draws for the ensemble. + model_restrictions (ModelRestrictions): any arc-only, mrbrt-only restrictions. 
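+
+    Illustrative example (hypothetical numbers): if an entity/location has two submodels
+    with RMSEs of 0.1 and 0.2 and ``draws=100``, the reciprocal-RMSE weights are roughly
+    0.667 and 0.333, which allocates 66 and 33 draws; the 1-draw shortfall is then added
+    to the lowest-RMSE row, giving 67 and 33.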
+ """ + df = collect_model_rmses( + out_version=out_version, + gbd_round_id=gbd_round_id, + submodel_names=submodel_names, + subfolder=subfolder, + ) + + out = pd.DataFrame([]) + + for entity in df["entity"].unique(): + for location_id in df["location_id"].unique(): + ent_loc_df = df.query(f"entity == '{entity}' & location_id == {location_id}") + + model_type = model_restrictions.model_type(entity, location_id) + + if model_type == "arc": + # we effectively pull 0 draws from those where rmse == np.inf + ent_loc_df.loc[ent_loc_df["model_name"] != "arc", "rmse"] = np.inf + + if model_type == "mrbrt": + # we effectively pull 0 draws from those where rmse == np.inf + ent_loc_df.loc[ent_loc_df["model_name"] == "arc", "rmse"] = np.inf + + # we use rmse values to determine draws sampled from submodels + ent_loc_df = ent_loc_df.sort_values(by="rmse", ascending=True) + + # use 1/rmse to determine weight/draws + rmse = ent_loc_df["rmse"] + ModelConstants.MIN_RMSE # padding in case of 0 + rmse_recip = 1 / rmse + model_wts = rmse_recip / rmse_recip.sum() + + # lowest rmse contributes the most draws + sub_draws = (np.round(model_wts, 3) * draws).astype(int) + + # in the event that sum(sub_draws) != draws, we make up the diff + # by adding the diff to the first element + if sub_draws.sum() != draws: + sub_draws.iloc[0] += draws - sub_draws.sum() + + # now assign sub-model weight and draws to df + ent_loc_df["model_weight"] = model_wts + ent_loc_df["draws"] = sub_draws + + out = out.append(ent_loc_df) + + write_csv( + df=out, + file_spec=FHSFileSpec( + version_metadata=out_version, filename=ModelConstants.MODEL_WEIGHTS_FILE + ), + sep=",", + na_rep=".", + index=False, + ) diff --git a/gbd_2021/disease_burden_forecast_code/met_need/model_restrictions.py b/gbd_2021/disease_burden_forecast_code/met_need/model_restrictions.py new file mode 100644 index 0000000..4ddc152 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/met_need/model_restrictions.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import Any, Iterable, List, Tuple, Union + +import yaml + +ALL_CATEGORIES = "all" +BOTH_MODELS = "both" + + +class ModelRestrictions: + """A class for capturing restrictions in which models get run for each entity/location.""" + + def __init__(self, restrictions: Iterable[Tuple[str, Union[int, str], str]]) -> None: + """Initializer. + + Args: + restrictions: A list of tuples each of which specifies a particular restriction. + The tuples contain an entity, a location and the model type to use (in that + order). The entity and location can also be "all", to indicate that it + applies to all entities. Entity-specific restrictions take precedence over + location-specific restrictions. + """ + self._original_specification = list(restrictions) + self._map = defaultdict(dict) + for restriction in self._original_specification: + entity = restriction[0] + location = ( + ALL_CATEGORIES if restriction[1] == ALL_CATEGORIES else int(restriction[1]) + ) + model_type = restriction[2] + + if location in self._map[entity]: + raise ValueError( + f"Restriction list includes multiple restrictions for {entity}/{location}" + ) + + self._map[entity][location] = model_type + + def model_type(self, entity: str, location_id: int) -> str: + """Get the model type to use for an entity and location, according to the restrictions. + + Args: + entity: entity to look up the restriction for. + location: location to look up the restriction for. 
+
+        Returns:
+            Either a model type name to use for the entity/location_id, or "both", to indicate
+            that both model types should be used.
+
+        """
+        if entity in self._map:
+            if location_id in self._map[entity]:
+                return self._map[entity][location_id]
+            elif ALL_CATEGORIES in self._map[entity]:
+                return self._map[entity][ALL_CATEGORIES]
+            else:
+                return BOTH_MODELS
+        elif ALL_CATEGORIES in self._map:
+            if location_id in self._map[ALL_CATEGORIES]:
+                return self._map[ALL_CATEGORIES][location_id]
+            elif ALL_CATEGORIES in self._map[ALL_CATEGORIES]:
+                return self._map[ALL_CATEGORIES][ALL_CATEGORIES]
+            else:
+                return BOTH_MODELS
+        else:
+            return BOTH_MODELS
+
+    def string_specifications(self) -> List[str]:
+        """Returns a list of strings for the specifications used to initialize the object.
+
+        Useful for serializing the object on the command-line, essentially a representation of
+        the way the object was initialized.
+        """
+        return [
+            " ".join([str(field) for field in spec]) for spec in self._original_specification
+        ]
+
+    def __eq__(self, other: ModelRestrictions) -> bool:
+        """Do they have the same underlying dict?"""
+        return self._map == other._map
+
+    @staticmethod
+    def yaml_representer(dumper: Any, data: ModelRestrictions) -> str:
+        """Function for passing to pyyaml telling it how to represent ModelRestrictions.
+
+        The specific tag used here tells pyyaml not to use a custom tag.
+
+        Args:
+            dumper: pyyaml dumper
+            data: ModelRestrictions object to serialize
+        """
+        return dumper.represent_sequence("tag:yaml.org,2002:seq", data._original_specification)
+
+yaml.SafeDumper.add_representer(
+    ModelRestrictions, ModelRestrictions.yaml_representer
+)
diff --git a/gbd_2021/disease_burden_forecast_code/met_need/omega_selection_strategy.py b/gbd_2021/disease_burden_forecast_code/met_need/omega_selection_strategy.py
new file mode 100644
index 0000000..3813441
--- /dev/null
+++ b/gbd_2021/disease_burden_forecast_code/met_need/omega_selection_strategy.py
@@ -0,0 +1,355 @@
+"""Strategies for determining the weight for the Annualized Rate-of-Change (ARC) method.
+
+Find where the RMSE is 1 (the RMSE is normalized so that 1 is always the lowest
+RMSE). If there are ties, take the lowest weight.
+
+There are two options for choosing the weight:
+1) Use the weight where the normalized-RMSE is 1.
+2) If none of the weights have a normalized-RMSE more than a threshold percent above
+   the minimum normalized-RMSE, use a weight of 0.0; otherwise, moving down from the
+   best weight, use the first weight whose normalized-RMSE exceeds that threshold.
+"""
+
+from typing import Any
+
+import numpy as np
+import xarray as xr
+from fhs_lib_database_interface.lib.constants import DimensionConstants
+from tiny_structured_logger.lib.fhs_logging import get_logger
+
+from fhs_lib_model.lib.constants import ArcMethodConstants
+
+logger = get_logger()
+
+
+def use_omega_with_lowest_rmse(rmse: xr.DataArray, **kwargs: Any) -> float:
+    """Use the omega (weight) with the lowest RMSE.
+
+    If there are ties, choose the smallest omega.
+
+    Args:
+        rmse:
+            Array with one dimension, "weight", that contains the tested
+            omegas as coordinates. The data is the RMSE (Root _Mean_ Square
+            Error or Root _Median_ Square Error) values.
+        kwargs:
+            Ignores any additional keyword args.
+
+    Returns:
+        The weight to use for the ARC method.
+    """
+    chosen_weight = rmse.where(rmse == rmse.min()).dropna("weight")["weight"].values[0]
+
+    logger.debug(f"`use_omega_with_lowest_rmse` weight selected: {chosen_weight}")
+    return chosen_weight
+
+
+def use_smallest_omega_within_threshold(
+    rmse: xr.DataArray, threshold: float = 0.05, **kwargs: Any
+) -> float:
+    """Returns the smallest omega possible compared to normalized-RMSE, using a threshold.
+ + If none of the weights have a normalized-RMSE (normalized by dividing by + minimum RMSE) no more than the threshold percent greater than the minimum + normalized-RMSE, which will be 1, then the weight of 0.0 is used. + Otherwise, starting at the first weight smaller than the weight of the + minimum normalized-RMSE and moving in the direction of decreasing weights, + choose the first weight that is more than the threshold percent greater + than the minimum normalized-RMSE. + + Args: + rmse: + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + threshold: + The threshold percent to use for selecting the weight. + kwargs: + Ignores any additional keyword args. + + Returns: + The weight to use for the ARC method. + """ + norm_rmse = rmse / rmse.min() + + diffs = norm_rmse - 1 + + # If there are, then the set the weight to the first weight with an + # normalized-RMSE less than threshold percent above the minimum + # normalized-RMSE. + weight_with_lowest_rmse = ( + norm_rmse.where(norm_rmse == norm_rmse.min()).dropna("weight")["weight"].values[0] + ) + weights_to_check = [w for w in norm_rmse["weight"].values if w < weight_with_lowest_rmse] + diffs_to_check = diffs.sel(weight=weights_to_check) + diffs_greater = diffs_to_check.where(diffs_to_check >= threshold).dropna("weight") + if len(diffs_greater) > 0: + # take the max weight greater than the threshold but less than the + # with the lowest RMSE. + chosen_weight = diffs_greater["weight"].values.max() + else: + chosen_weight = 0.0 + + logger.debug(f"`use_smallest_omega_within_threshold` weight selected: {chosen_weight}") + return chosen_weight + + +def use_omega_rmse_weighted_average(rmse: xr.DataArray, **kwargs: Any) -> float: + r"""Use the RMSE-weighted average of the range of tested-omegas. + + .. math:: + + \bar{\omega} = \frac{\sum\limits_{i=0}^{N}\frac{\omega_i}{RMSE_i}} + {\sum\limits_{i=0}^{N}\frac{1}{RMSE_i}} + + where :math:`N` is the largest in the range of omegas that were tested. + + *Note* under the special case when one or more of the weights has an RMSE + of 0, we consider any weights with RMSE values of zero, to be weighted + infinitely, so we just take the mean of all the weights with an RMSE of + zero. + + Args: + rmse (xarray.DataArray): + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + kwargs: + Ignores any additional keyword args. + + Returns: + The omega to use for the ARC method. + """ + zero_rmse = rmse == 0 + if zero_rmse.any(): + # Any weights with RMSE values of zero, will be weighted infinitely so + # just take the mean of all the weights with an RMSE of zero. + chosen_weight = float(rmse["weight"].where(zero_rmse).dropna("weight").mean("weight")) + else: + chosen_weight = float((rmse["weight"] / rmse).sum() / (1 / rmse).sum().values) + + logger.debug(f"`use_omega_rmse_weighted_average` weight selected: {chosen_weight}") + return chosen_weight + + +def use_average_omega_within_threshold( + rmse: xr.DataArray, threshold: float = 0.05, **kwargs: Any +) -> float: + """Take the average of the omegas with RMSEs within 5% of lowest RMSE. + + Args: + rmse (xarray.DataArray): + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. 
+ threshold (float): + The threshold percent to use for selecting the weight. + kwargs: + Ignores any additional keyword args. + + Returns: + The weight to use for the ARC method. + """ + chosen_weight = ( + rmse.where(rmse < rmse.values.min() + rmse.values.min() * threshold) + .dropna("weight")["weight"] + .values.mean() + ) + + logger.debug(f"`use_average_omega_within_threshold` weight selected: {chosen_weight}") + return chosen_weight + + +def use_average_of_zero_biased_omegas_within_threshold( + rmse: xr.DataArray, threshold: float = 0.05, **kwargs: Any +) -> float: + """Calculates weight by averaging omegas. + + Take the average of the omegas less than the omega with the lowest RMSE, + and with RMSEs within 5% of that lowest RMSE. + + Args: + rmse: + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + threshold: + The threshold percent to use for selecting the weight. + kwargs: + Ignores any additional keyword args. + + Returns: + The weight to use for the ARC method. + """ + norm_rmse = rmse / rmse.min() + + weight_with_lowest_rmse = ( + norm_rmse.where(norm_rmse == norm_rmse.min()).dropna("weight")["weight"].values[0] + ) + weights_to_check = [w for w in norm_rmse["weight"].values if w <= weight_with_lowest_rmse] + + rmses_to_check = norm_rmse.sel(weight=weights_to_check) + rmses_to_check_within_threshold = rmses_to_check.where( + rmses_to_check < 1 + threshold + ).dropna("weight") + + chosen_weight = rmses_to_check_within_threshold["weight"].values.mean() + + logger.debug( + ( + "`use_average_of_zero_biased_omegas_within_threshold` weight selected: " + f"{chosen_weight}" + ) + ) + return chosen_weight + + +def use_omega_distribution( + rmse: xr.DataArray, draws: int, threshold: float = 0.05, **kwargs: Any +) -> xr.DataArray: + """Samples omegas from a distribution (using RMSE). + + Takes the omegas with RMSEs within the threshold percent of omega with + the lowest RMSE, and takes the reciprocal RMSEs of those omegas as the + probabilities of omegas being sampled from multinomial a distribution. + + Args: + rmse: + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + draws: + The number of draws to sample from the distribution of omega values + threshold: + The threshold percent to use for selecting the weight. + kwargs: + Ignores any additional keyword args. + + Returns: + Samples from a distribution of omegas to use for the ARC method. + """ + rmses_in_threshold = rmse.where(rmse < rmse.values.min() + rmse.values.min() * threshold) + reciprocal_rmses_in_threshold = (1 / rmses_in_threshold).fillna(0) + norm_reciprocal_rmses_in_threshold = ( + reciprocal_rmses_in_threshold / reciprocal_rmses_in_threshold.sum() + ) + + omega_draws = xr.DataArray( + np.random.choice( + a=norm_reciprocal_rmses_in_threshold["weight"].values, + size=draws, + p=norm_reciprocal_rmses_in_threshold.values, + ), + coords=[list(range(draws))], + dims=[DimensionConstants.DRAW], + ) + return omega_draws + + +def use_zero_biased_omega_distribution( + rmse: xr.DataArray, draws: int, threshold: float = 0.05, **kwargs: Any +) -> xr.DataArray: + """Samples omegas from a distribution (using RMSE). 
+ + Takes the omegas with RMSEs within the threshold percent of omega with + the lowest RMSE, and takes the reciprocal RMSEs of those omegas as the + probabilities of omegas being sampled from multinomial a distribution. + + Args: + rmse: + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + draws: + The number of draws to sample from the distribution of omega values + threshold: + The threshold percent to use for selecting the weight. + kwargs: + Ignores any additional keyword args. + + Returns: + Samples from a distribution of omegas to use for the ARC method. + """ + norm_rmse = rmse / rmse.min() + + weight_with_lowest_rmse = ( + norm_rmse.where(norm_rmse == norm_rmse.min()).dropna("weight")["weight"].values[0] + ) + weights_to_check = [w for w in norm_rmse["weight"].values if w <= weight_with_lowest_rmse] + + rmses_to_check = norm_rmse.sel(weight=weights_to_check) + rmses_to_check_within_threshold = rmses_to_check.where( + rmses_to_check < 1 + threshold + ).dropna("weight") + + reciprocal_rmses_to_check_within_threshold = (1 / rmses_to_check_within_threshold).fillna( + 0 + ) + norm_reciprocal_rmses_to_check_within_threshold = ( + reciprocal_rmses_to_check_within_threshold + / reciprocal_rmses_to_check_within_threshold.sum() + ) + + omega_draws = xr.DataArray( + np.random.choice( + a=norm_reciprocal_rmses_to_check_within_threshold["weight"].values, + size=draws, + p=norm_reciprocal_rmses_to_check_within_threshold.values, + ), + coords=[list(range(draws))], + dims=[DimensionConstants.DRAW], + ) + return omega_draws + + +def adjusted_zero_biased_omega_distribution( + rmse: xr.DataArray, + draws: int, + seed: int = ArcMethodConstants.DEFAULT_RANDOM_SEED, + **kwargs: Any, +) -> xr.DataArray: + """Samples omegas from a distribution (using RMSE). + + Takes the omegas from the lowest RMSE to zero, and takes the reciprocal + RMSEs of those omegas as the probabilities of omegas being sampled from + multinomial a distribution. + + Args: + rmse: + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + draws: + The number of draws to sample from the distribution of omega values + seed: + seed to be set for random number generation. + kwargs: + Ignores any additional keyword args. + + Returns: + Samples from a distribution of omegas to use for the ARC method. 
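+
+    Illustrative example (hypothetical numbers): if omegas ``[0.0, 0.5, 1.0]`` have
+    normalized RMSEs ``[1.2, 1.0, 1.1]``, the lowest RMSE sits at omega 0.5, so only
+    ``[0.0, 0.5]`` are candidates; their reciprocal normalized RMSEs (about 0.83 and 1.0)
+    are rescaled to probabilities of roughly 0.45 and 0.55, and ``draws`` omegas are then
+    sampled from that distribution.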
+ """ + np.random.seed(seed) + + norm_rmse = rmse / rmse.min() + + weight_with_lowest_rmse = ( + norm_rmse.where(norm_rmse == norm_rmse.min()).dropna("weight")["weight"].values[0] + ) + weights_to_check = [w for w in norm_rmse["weight"].values if w <= weight_with_lowest_rmse] + + rmses_to_check = norm_rmse.sel(weight=weights_to_check) + + reciprocal_rmses_to_check = (1 / rmses_to_check).fillna(0) + norm_reciprocal_rmses_to_check = ( + reciprocal_rmses_to_check / reciprocal_rmses_to_check.sum() + ) + + omega_draws = xr.DataArray( + np.random.choice( + a=norm_reciprocal_rmses_to_check["weight"].values, + size=draws, + p=norm_reciprocal_rmses_to_check.values, + ), + coords=[list(range(draws))], + dims=[DimensionConstants.DRAW], + ) + return omega_draws diff --git a/gbd_2021/disease_burden_forecast_code/met_need/predictive_validity.py b/gbd_2021/disease_burden_forecast_code/met_need/predictive_validity.py new file mode 100644 index 0000000..d1b8554 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/met_need/predictive_validity.py @@ -0,0 +1,106 @@ +from typing import Iterable, Optional + +import numpy as np +import pandas as pd +import xarray as xr +from fhs_lib_database_interface.lib.constants import DimensionConstants +from fhs_lib_file_interface.lib.pandas_wrapper import write_csv +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions + +from fhs_lib_genem.lib.constants import FileSystemConstants + +OMEGA_DIM = "omega" + + +def get_omega_weights(min: float, max: float, step: float) -> Iterable[float]: + """Return the list of weights between ``min`` and ``max``, incrementing by ``step``.""" + return np.arange(min, max, step) + + +def root_mean_square_error( + predicted: xr.DataArray, + observed: xr.DataArray, + dims: Optional[Iterable[str]] = None, +) -> xr.DataArray: + """Dimensions-specific root-mean-square-error. + + Args: + predicted (xr.DataArray): predicted values. + observed (xr.DataArray): observed values. + dims (Optional[Iterable[str]]): list of dims to compute rms for. + + Returns: + (xr.DataArray): root-mean-square error, dims-specific. 
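+
+        For example, if ``predicted`` and ``observed`` have ``location_id``, ``year_id``,
+        and ``weight`` dimensions and ``dims=["weight"]`` is passed, the squared error is
+        averaged over ``location_id`` and ``year_id``, yielding one RMSE per ``weight``
+        coordinate.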
+ """ + dims = dims or [] + + squared_error = (predicted - observed) ** 2 + other_dims = [d for d in squared_error.dims if d not in dims] + return np.sqrt(squared_error.mean(dim=other_dims)) + + +def calculate_predictive_validity( + forecast: xr.DataArray, + holdouts: xr.DataArray, + omega: float, +) -> xr.DataArray: + """Calculate the RMSE between ``forecast`` and ``holdouts`` across location & sex.""" + # Take the mean over draw if forecast or holdouts data has them + if DimensionConstants.DRAW in forecast.dims: + forecast_mean = forecast.mean(DimensionConstants.DRAW) + else: + forecast_mean = forecast + + if DimensionConstants.DRAW in holdouts.dims: + holdouts_mean = holdouts.mean(DimensionConstants.DRAW) + else: + holdouts_mean = holdouts + + # Calculate RMSE + pv_data = root_mean_square_error( + predicted=forecast_mean.sel(scenario=0, drop=True), + observed=holdouts_mean, + dims=[DimensionConstants.LOCATION_ID, DimensionConstants.SEX_ID], + ) + + # Tag the data with a hard-coded attribute & return it + pv_data[OMEGA_DIM] = omega + return pv_data + + +def finalize_pv_data(pv_list: Iterable[xr.DataArray], entity: str) -> pd.DataFrame: + """Convert a list of PV xarrays into a pandas dataframe, and take the mean over sexes.""" + pv_xr = xr.concat(pv_list, dim=OMEGA_DIM) + + # per research decision, mean over sexes (if its present) + if DimensionConstants.SEX_ID in pv_xr.dims: + pv_xr = pv_xr.mean([DimensionConstants.SEX_ID]) + + pv_xr["entity"] = entity + return pv_xr.to_dataframe(name="rmse").reset_index() + + +def save_predictive_validity( + file_name: str, + gbd_round_id: int, + model_name: str, + pv_df: pd.DataFrame, + stage: str, + subfolder: str, + versions: Versions, +) -> None: + """Write a predictive validity dataframe to disk.""" + # Define the output file spec + pv_file_spec = FHSFileSpec( + version_metadata=versions.get(past_or_future="future", stage=stage), + sub_path=( + FileSystemConstants.PV_FOLDER, + model_name, + subfolder, + ), + filename=f"{file_name}_pv.csv", + ) + + # Write the dataframe (note that the pv output directory is "{out_version}_pv") + write_csv(df=pv_df, file_spec=pv_file_spec, sep=",", na_rep=".", index=False) diff --git a/gbd_2021/disease_burden_forecast_code/migration/README.md b/gbd_2021/disease_burden_forecast_code/migration/README.md new file mode 100644 index 0000000..b59bc6c --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/migration/README.md @@ -0,0 +1,61 @@ +# Migration pipeline code + +The general overall order for running these scripts is: +1. aggregate_shocks_and_sdi.py +2. model_migration.py +3. run_model.py +4. csv_to_xr.py +5. arima_and_generate_draws.py +6. migration_rate_to_count.py +7. age_sex_split.py +8. 
balance_migration.py + +``` +age_sex_split.py +Splits the migration into separate age-sex groups +``` + +``` +aggregate_shocks_and_sdi.py +Produce aggregate versions of shocks and reshape mean sdi for use in modeling migration +``` + +``` +arima_and_generate_draws.py +Applies random walk on every-5-year migration data without draws +``` + +``` +balance_migration.py +Combines the separate location files from the age-sex splitting of migration +``` + +``` +csv_to_xr.py +Converts .CSV predictions to xarray file and makes epsilon +``` + +``` +migration_rate_to_count.py +Converts migration rates output by draw generation step to counts for use in age-sex splitting step +``` + +``` +model_migration.py +Cleans and models UN migration estimates for forecasting +``` + +``` +model_strategy.py +Where migration modeling strategies and their parameters are managed/defined +``` + +``` +model_strategy_queries.py +Has query functions that give nonfatal modeling strategies and their parameters +``` + +``` +run_model.py +Forecasts migration with LimeTr +``` \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/migration/age_sex_split.py b/gbd_2021/disease_burden_forecast_code/migration/age_sex_split.py new file mode 100644 index 0000000..c463b4b --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/migration/age_sex_split.py @@ -0,0 +1,122 @@ +""" +This script works to split the migration into separate age-sex groups +by converting the Eurostat data to xarray. + +Example: + +.. code:: bash + + python FILEPATH/age_sex_split.py + --migration_version click_20210510_limetr_fixedint + --pattern_version 20191114_eurostat_age_sex_pattern_w_subnat + --gbd_round_id 6 + +""" + +import argparse +import pandas as pd +import xarray as xr + +from db_queries import get_location_metadata +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_file_interface.lib.file_system_manager import FileSystemManager +from fhs_lib_file_interface.lib.os_file_system import OSFileSystem +FileSystemManager.set_file_system(OSFileSystem()) +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr, save_xr +from fhs_lib_data_transformation.lib.pandas_to_xarray import df_to_xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from tiny_structured_logger.lib import fhs_logging + +logger = fhs_logging.get_logger() + +# Which location ids to use +WPP_LOCATION_IDS = pd.read_csv( + "FILEPATH.csv" + )["location_id"].unique().tolist() +# get subnational location ids +PATTERN_ID_VARS = ["age_group_id", "sex_id"] +LOCATION_SET_ID = 39 +RELEASE_ID = 6 +QATAR_LOCS = [151, 152, 140, 156, 150] #'QAT', 'SAU', 'BHR', 'ARE', 'OMN' + +def create_age_sex_xarray(gbd_round_id, pattern_version, subnat_location_ids): + logger.debug("Creating xarray of age-sex patterns for migration") + # load patterns + # location of the pattern for the Qatar-modeled countries + QATAR_PATTERN = f"FILEPATH/qatar_pattern.csv" + # Location of the pattern for other countries + EUROSTAT_PATTERN = f"FILEPATH/eurostat_pattern.csv" + qatar = pd.read_csv(QATAR_PATTERN) + eurostat = pd.read_csv(EUROSTAT_PATTERN) + # convert to xarrays + qatar = df_to_xr(qatar, dims=PATTERN_ID_VARS) + eurostat = df_to_xr(eurostat, dims=PATTERN_ID_VARS) + # create superarray to hold all locs + all_locs_xr_list = [] + # Put dataframes for each location into a list + for loc in WPP_LOCATION_IDS + subnat_location_ids: + if loc in QATAR_LOCS: + data = qatar + else: + data = eurostat + data = expand_dimensions(data, 
location_id=[loc]) + all_locs_xr_list.append(data) + # Concat all locations together + pattern = xr.concat(all_locs_xr_list, dim='location_id') + # Save all locs pattern + logger.debug("Saving age-sex pattern xarray") + pattern_dir = FBDPath('FILEPATH') + pattern_path = pattern_dir / f"combined_age_sex_pattern.nc" + save_xr(pattern, pattern_path, metric="percent", space="identity") + logger.debug("Saved age-sex pattern xarray") + return pattern + +def main(migration_version, gbd_round_id, pattern_version): + # load age-sex pattern (loc, draw, age, sex) + logger.debug("Loading age-sex migration pattern") + subnat_location_ids = get_location_metadata(gbd_round_id=gbd_round_id, + location_set_id=LOCATION_SET_ID, + release_id=RELEASE_ID).\ + query("level == 4").\ + location_id.tolist() + try: + pattern_dir = FBDPath('FILEPATH') + pattern_path = pattern_dir / "combined_age_sex_pattern.nc" + pattern = open_xr(pattern_path) + except: # Data doesn't yet exist + pattern = create_age_sex_xarray(gbd_round_id, + pattern_version, + subnat_location_ids) + # load migration counts (loc, draw, year) + logger.debug("Loading migration data") + mig_dir = FBDPath("FILEPATH") + mig_path = mig_dir / "mig_counts.nc" + migration = open_xr(mig_path) + migration = migration.squeeze(drop=True) + # end up with migration counts with age and sex (loc, draw, year, age, sex) + split_data = migration * pattern + # Save it + logger.debug("Saving age-sex split migration data") + + split_path = mig_dir / "migration_split.nc" + save_xr(split_data, split_path, metric="number", space="identity") + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--migration_version", type=str, required=True, + help="Which version of migrations to use in WPP directory") + parser.add_argument( + "--gbd_round_id", type=int, required=True, + help="Which gbd_round_id to use in file loading and saving") + parser.add_argument( + "--pattern_version", type=str, required=True, + help="Which age-sex pattern version to use in future migration \ + directory") + args = parser.parse_args() + + main(migration_version=args.migration_version, + gbd_round_id=args.gbd_round_id, + pattern_version=args.pattern_version) diff --git a/gbd_2021/disease_burden_forecast_code/migration/aggregate_shocks_and_sdi.py b/gbd_2021/disease_burden_forecast_code/migration/aggregate_shocks_and_sdi.py new file mode 100644 index 0000000..22a668d --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/migration/aggregate_shocks_and_sdi.py @@ -0,0 +1,193 @@ +""" +Produce aggregate versions of shocks for use in +modeling migration, saving them as csvs in their original directories. + +Example: + +.. 
code:: bash + +python FILEPATH/aggregate_shocks_and_sdi.py \ +--shocks_version 20210419_shocks_only_decay_weight_15 \ +--past_pop_version 20200206_etl_gbd_decomp_step4_1950_2019_run_id_192 \ +--forecast_pop_version 20200513_arc_method_new_locs_ratio_subnats_else_pop_paper \ +--gbd_round_id 6 \ +--years 1950:2020:2050 + +""" +import argparse + +from fhs_lib_data_aggregation.lib.aggregator import Aggregator +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_database_interface.lib.query.age import get_age_weights +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_file_interface.lib.file_system_manager import FileSystemManager +from fhs_lib_file_interface.lib.os_file_system import OSFileSystem +FileSystemManager.set_file_system(OSFileSystem()) +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr, save_xr +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging + +logger = fhs_logging.get_logger() + +ALL_AGE_GROUP_ID = 22 +BOTH_SEX_ID = 3 +REFERENCE_SCENARIO = 0 +SHOCK_ACAUSES = ("inj_disaster", "inj_war_execution", "inj_war_warterror") + + +def load_past_pop(gbd_round_id, version, years): + """ + Load past population data. This will generally be from 1950 to the start of + the forecasts. Takes the mean of draws. + + Args: + gbd_round_id (int): + The gbd round ID that the past population is from + version (str): + The version of past population to read from + + Returns: + xarray.DataArray: The past population xarray dataarray + """ + past_pop_path = FBDPath("FILEPATH") + past_pop_file = past_pop_path / "population.nc" + past_pop_da = open_xr(past_pop_file) + + # slice to correct years + past_pop_da = past_pop_da.sel(year_id=years.past_years) + + return past_pop_da + + +def load_forecast_pop(gbd_round_id, version, years): + """ + Load forecast population data. Aggregates if necessary. Takes mean of draws. + + Args: + gbd_round_id (int): + The gbd round ID that the past population is from + version (str): + The version of forecast population to read from + years (YearRange): + The Forecasting format years to use. 
+ + Returns: + xarray.DataArray: The past population xarray dataarray + """ + forecast_pop_path = FBDPath("FILEPATH") + try: + forecast_pop_file = forecast_pop_path / "population_agg.nc" + forecast_pop_da = open_xr(forecast_pop_file) + try: # Sometimes saved with draws in agg + forecast_pop_da = forecast_pop_da.mean("draw") + except: + pass + except OSError: # Need to make agg version + forecast_pop_file = forecast_pop_path / "population.nc" + forecast_pop_da = open_xr(forecast_pop_file) + forecast_pop_da = forecast_pop_da.mean("draw") + forecast_pop_da = Aggregator.aggregate_everything(forecast_pop_da, gbd_round_id).pop + forecast_pop_out_file = forecast_pop_path / "population_agg.nc" + save_xr(forecast_pop_da, forecast_pop_out_file, metric="number", space="identity") + + # slice to correct years + forecast_pop_da = forecast_pop_da.sel(year_id=years.forecast_years) + + return forecast_pop_da + + +def main( + shocks_version, + past_pop_version, + forecast_pop_version, + gbd_round_id, + years, +): + """ + Load pops and shocks data, aggregate, convert to csv, save + """ + past_pop_da = load_past_pop(gbd_round_id, past_pop_version, years) + forecast_pop_da = load_forecast_pop(gbd_round_id, forecast_pop_version, years) + most_detailed_ages = list( + get_age_weights( + gbd_round_id=gbd_round_id, + most_detailed=True, + )["age_group_id"] + ) + + # Give past populations dummy scenarios to be concatenated with forecast pops + past_pop_da = expand_dimensions(past_pop_da, scenario=forecast_pop_da["scenario"].values) + + forecast_pop_da = forecast_pop_da.sel( + scenario=REFERENCE_SCENARIO, age_group_id=most_detailed_ages + ) + + past_pop_da = past_pop_da.sel( + scenario=REFERENCE_SCENARIO, + age_group_id=most_detailed_ages, + location_id=forecast_pop_da.location_id, + ) + + # Combine past and forecast pop + pop_da = past_pop_da.combine_first(forecast_pop_da) + + # Save out shocks + if shocks_version: + for acause in ["inj_disaster", "inj_war_execution", "inj_war_warterror"]: + # Load data + shock_path = FBDPath( + "FILEPATH".format(g=gbd_round_id, version=shocks_version) + ) + shock_file = shock_path / "{}.nc".format(acause) + shock_da = open_xr(shock_file) + # Take mean of draws + shock_da = shock_da.mean("draw") + # Aggregate everything + pop_agg = Aggregator(pop_da) + + shock_da = pop_agg.aggregate_everything( + pop_da, gbd_round_id, data=shock_da + ).data.rate + + shock_da = shock_da.sel(age_group_id=ALL_AGE_GROUP_ID, sex_id=BOTH_SEX_ID) + # Convert to dataframe and save + shock_df = shock_da.to_dataframe(name="mortality") + shock_out_file = shock_path / "mean_{}.csv".format(acause) + shock_df.to_csv(shock_out_file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument( + "--shocks_version", + type=str, + required=False, + help="Which version of shocks to convert", + ) + + parser.add_argument( + "--past_pop_version", type=str, help="Which version of past populations to use" + ) + parser.add_argument( + "--forecast_pop_version", + type=str, + help="Which version of forecasted populations to use", + ) + parser.add_argument( + "--gbd_round_id", type=int, help="Which gbd round id to use for pops and shocks" + ) + parser.add_argument("--years", type=str, help="past_start:forecast_start:forecast_end") + args = parser.parse_args() + + years = YearRange.parse_year_range(args.years) + + if args.shocks_version: + main( + shocks_version=args.shocks_version, + 
past_pop_version=args.past_pop_version, + forecast_pop_version=args.forecast_pop_version, + gbd_round_id=args.gbd_round_id, + years=years, + ) diff --git a/gbd_2021/disease_burden_forecast_code/migration/arima_and_generate_draws.py b/gbd_2021/disease_burden_forecast_code/migration/arima_and_generate_draws.py new file mode 100644 index 0000000..cbf5e35 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/migration/arima_and_generate_draws.py @@ -0,0 +1,263 @@ +""" +Apply Random Walk on every-5-year migration data without draws. +Take mean of random walk to model latent trends. +Generate draws from normal distribution mean=point estimate, sd=past sd + +python FILEPATH/arima_and_generate_draws.py +--eps-version click_20210510_limetr_fixedint +--arima-version click_20210510_limetr_fixedint +--measure migration +--gbd-round-id 6 +--draws 1000 +--years 1950:2020:2050 +--locations-to-shock-smooth 12345 133 98 +""" +import argparse +import numpy as np +import xarray as xr +from typing import List, Tuple + +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_file_interface.lib.file_system_manager import FileSystemManager +from fhs_lib_file_interface.lib.os_file_system import OSFileSystem +FileSystemManager.set_file_system(OSFileSystem()) +from fhs_lib_file_interface.lib.xarray_wrapper import save_xr, open_xr +from fhs_lib_model.lib.random_walk.random_walk import RandomWalk +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging +from migration import remove_drift + +logger = fhs_logging.get_logger() +rng = np.random.default_rng(47) + +RATE_CAP = 10 # per 1000 people +N_PAST_ANCHOR_YEARS = 20 +UI_MAX_RATE = 10 # per 1000 people +UI_MIN_RATE = 0.5 # per 1000 people + +def save_y_star( + eps_version: str, + arima_version: str, + years: YearRange, + measure: str, + draws: int, + decay: float, + gbd_round_id: int, + locations_to_shock_smooth: List[int], +) -> None: + """ + apply random walk and save the output + """ + mig_dir = FBDPath("FILEPATH") + # load "true" observed past migration rate + past_migration_rate_path = mig_dir / "migration_single_years.nc" + past_migration_rate = open_xr(past_migration_rate_path).sel(year_id=years.past_years) + + eps_path = mig_dir / "mig_eps.nc" + ds = open_xr(eps_path) + try: + eps_preds = open_xr(mig_dir / "eps_star.nc") + except Exception: + ds = shock_smooth_locations(ds, years, locations_to_shock_smooth, N_PAST_ANCHOR_YEARS) + eps_preds = arima_migration(ds, years, draws, decay) + epsilon_hat_out = mig_dir / "eps_star.nc" + save_xr(eps_preds, epsilon_hat_out, metric="rate", space="identity") + + # cap residuals between -`RATE_CAP` and `RATE_CAP` + # population forecast ui is not credible when residuals are uncapped + eps_past = eps_preds.sel(year_id=years.past_years) + eps_preds = eps_preds.sel(year_id=years.forecast_years) + eps_preds = eps_preds.clip(min=-RATE_CAP, max=RATE_CAP) + eps_preds = xr.concat([eps_past, eps_preds], dim="year_id") + eps_preds = eps_preds.mean(dim="draw") + + pred_path = mig_dir / "mig_hat.nc" + preds = open_xr(pred_path) + preds = preds.sel(year_id=years.years) + y_star = preds + eps_preds + y_star_past = y_star.sel(year_id=years.past_years) + y_star_draws = generate_draws_from_predicted_mean( + y_star, + past_migration_rate, + years, + draws, + std_dev_bounds=(UI_MIN_RATE, UI_MAX_RATE)) + y_star_draws = xr.concat([y_star_past, y_star_draws], dim="year_id") + + save_path = FBDPath("FILEPATH") + ystar_out = save_path / "mig_star.nc" + + 
save_xr(y_star_draws, ystar_out, metric="rate", space="identity") + + +def arima_migration(epsilon_past, years, draws, decay): + """ + apply drift attenuation and fit random walk on the dataset + """ + drift_component = remove_drift.get_decayed_drift_preds(epsilon_past, years, decay) + remainder = epsilon_past - drift_component.sel(year_id=years.past_years) + ds = xr.Dataset(dict(y=remainder.copy())) + + rw_model = RandomWalk(ds, years, draws) + rw_model.fit() + rw_preds = rw_model.predict() + + return drift_component + rw_preds + + +def generate_draws_from_predicted_mean( + da: xr.DataArray, + true_past: xr.DataArray, + years: YearRange, + draws: int, + std_dev_bounds: Tuple[int, int], +) -> xr.DataArray: + """ + args: + da: y_star array of regression estimates + mean of random walk fit on residual + true_past: data array of "true" past migration rate ETLed from WPP + years: past_start:forecast_start:forecast_end + draws: number of draws to generate + std_dev_cap: rate per thousand at which to cap std_dev of generated draws + + Generates draws from a normal distribution with mean = y_star, std dev = the std dev + for each location over all past years * + (forecast_start - (forecast_start - 40)) / `scale_value` + For `scale_value`= 120, forecast year 1: sd = 0.33 * sd_past, increasing by 2.5% each year + In forecast year 80: sd = 1 * sd_past. Draws are left unordered + """ + scale_value = 120 # scale expanding uncertainty to reasonable value + + forecast_da = da.sel(year_id=years.forecast_years) + forecast_da = forecast_da.expand_dims( + draw=range(0, draws) + ) # create appropriate draw coords + + # calculate past SD for each nation, scale into future + past_da = true_past.sel(scenario=0, drop=True).expand_dims(sex_id=[3], age_group_id=[22]) + past_sd = past_da.std(dim="year_id") + past_sd = past_sd.clip(min=std_dev_bounds[0], max=std_dev_bounds[1]) # we want to cap + # predictions of sustained high migration into future based purely on extreme events in recent past + past_sd = past_sd.expand_dims(draw=range(0, draws), year_id=years.forecast_years) + past_sd = past_sd.transpose(*forecast_da.dims) + + # sd widening equation + coef_da = xr.DataArray( + data=(forecast_da.year_id.values - (years.forecast_start - 40)) / scale_value, + dims={"year_id": forecast_da.year_id.values}, + ) + + coef_da = coef_da.expand_dims( + location_id=forecast_da.location_id.values, + draw=forecast_da.draw.values, + sex_id=forecast_da.sex_id.values, + age_group_id=forecast_da.age_group_id.values, + ) + + # uncertainty is generated from 3 arrays + logger.info("generating draws from mean predicted migration rate") + forecast_with_uncertainty = rng.normal(forecast_da, past_sd * coef_da) + + nd_labeled_uncertainty = xr.DataArray( + data=forecast_with_uncertainty, + coords=forecast_da.coords, + ) + + return nd_labeled_uncertainty + + +def shock_smooth_locations( + ds: xr.DataArray, + years: YearRange, + locations_to_shock_smooth: List[int], + n_past_anchor_years: int, +): + """ + args: + ds: data array of epsilons (migration model error) + years: past_start:forecast_start:forecast_end + locations_to_shock_smooth: location ids for locations we want to shock smooth + n_past_anchor_year: number of most recent past years to calculate mean epsilon from + + This function provides a way of smoothing recent shocks/spikes/extreme migration events + prior to fitting and forecasting model-error using a random-walk. For each location + provided the epsilon of the last-past year is set to the mean of `n_past_anchor_years`. 
+ Example: if n_past_anchor_years=3, last_past_year=2022, then epsilon of 2022 will be set + to the mean epsilon off [2020, 2021, 2022]. + + Overall this means extreme events in the last past years in supplied countries will have + less impact on forecasts. + + NOTE: Because its hard to define what a `shocked` location is it might be prefereble to + find an entirely new method to forecast model error. + """ + logger.info( + f"Smoothing migration shocks in the year {years.past_end} " + f"from the following locations: {locations_to_shock_smooth}" + ) + + locs_to_shock_smooth = ds.sel(location_id=locations_to_shock_smooth) + mean_eps = locs_to_shock_smooth.sel( + year_id=list(range(years.past_end - n_past_anchor_years, years.past_end)) + ).mean(dim="year_id") + + for location in locations_to_shock_smooth: + loc_specific_mean_eps = mean_eps.sel(location_id=location).squeeze().values + ds.loc[ + { + "location_id": location, + "age_group_id": 22, + "sex_id": 3, + "year_id": years.past_end, + } + ] = loc_specific_mean_eps + + return ds + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + "--eps-version", type=str, required=True, help="The version the eps are saved under." + ) + parser.add_argument( + "--arima-version", + type=str, + required=True, + help="The version the arima results are saved under.", + ) + parser.add_argument("--measure", type=str, required=True, choices=["migration", "death"]) + parser.add_argument( + "--decay-rate", + type=float, + required=False, + default=0.1, + help="Rate at which drift on all-cause epsilons decay in future", + ) + parser.add_argument( + "--gbd-round-id", type=int, required=True, help="Which gbd round id to use" + ) + parser.add_argument("--draws", type=int, help="Number of draws") + parser.add_argument("--years", type=str, help="past_start:forecast_start:forecast_end") + parser.add_argument( + "--locations-to-shock-smooth", + type=int, + nargs="*", + help="Location ids" "to `shock smooth`, can supply multiple location ids.", + ) + args = parser.parse_args() + + years = YearRange.parse_year_range(args.years) + + save_y_star( + args.eps_version, + args.arima_version, + years, + args.measure, + args.draws, + args.decay_rate, + args.gbd_round_id, + args.locations_to_shock_smooth, + ) diff --git a/gbd_2021/disease_burden_forecast_code/migration/balance_migration.py b/gbd_2021/disease_burden_forecast_code/migration/balance_migration.py new file mode 100644 index 0000000..c386656 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/migration/balance_migration.py @@ -0,0 +1,288 @@ + +""" +Combine the separate location files from the age-sex splitting of migration. +Balance the migration data to be zero at each age-sex-draw combo. +Add any missing locations (all level 3 and level 4 locations) and fill with 0s. + +Example: + +.. 
code:: bash + python FILEPATH/balance_migration.py \ + --version click_20210510_limetr_fixedint \ + --gbd_round_id 6 +""" + +import argparse +import pandas as pd +import xarray as xr + +from db_queries import get_location_metadata +from fhs_lib_data_transformation.lib.pandas_to_xarray import df_to_xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_file_interface.lib.file_system_manager import FileSystemManager +from fhs_lib_file_interface.lib.os_file_system import OSFileSystem +FileSystemManager.set_file_system(OSFileSystem()) +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr, save_xr +from fhs_lib_summary_maker.lib.summary import compute_summary +from tiny_structured_logger.lib import fhs_logging + +logger = fhs_logging.get_logger() + +LOCATION_SET_ID = 39 +RELEASE_ID = 6 +ID_VARS = ["location_id", "year_id", "age_group_id", "sex_id", "draw"] +WPP_LOCATION_IDS = pd.read_csv( + "FILEPATH" + )["location_id"].unique().tolist() + + +def combine_and_save_mig(version): + """ + Load location csvs of migration files and combine into an xarray dataarray. + + Args: + version (str): + The version of migration to combine and save + + Returns: + xarray.DataArray: The combined migration data xarray dataarray. + """ + logger.info("Combining migration csvs to xarray") + all_locs_xr_list = [] + # Put dataframes for each location into a list + for loc in WPP_LOCATION_IDS + subnat_location_ids: + temp = pd.read_csv(f'FILEPATH/{loc}.csv') + #temp = temp.set_index(ID_VARS) + temp = df_to_xr(temp, dims=ID_VARS) + all_locs_xr_list.append(temp) + # Concat all locations together + result = xr.concat(all_locs_xr_list, dim='location_id') + + # Save to forecasting directory + result.to_netcdf(f'FILEPATH/migration_split.nc') + logger.info("Saved migration xarray to FHS WPP directory") + return result + +def balance_migration(mig_da): + """ + Ensure that net migration is zero at each loc-sex-age combo. + Calculate K = sqrt(-sum of positive values/sum of negative values) + Divide positive values by K + Multiply negative values by K + + Args: + mig_da (xarray.DataArray): + The input migration xarray dataarray that is being balanced + + Returns: + xarray.DataArray: The balanced migration data, + in dataarray. 
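+
+    Example (illustrative numbers only): if the positive values sum to +400
+    and the negative values sum to -100, then K = sqrt(400 / 100) = 2, so the
+    positive values are divided by 2 (sum becomes +200) and the negative
+    values are multiplied by 2 (sum becomes -200), giving zero net migration.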
+ """ + logger.info("Entered balancing step") + # only balance national locations (level 3), do not balance subnational (level 4) + subnat_location_ids = list(set(mig_da.location_id.values) - set(WPP_LOCATION_IDS)) + balance_mig_da = mig_da.sel(location_id=WPP_LOCATION_IDS) + no_balance_mig_da = mig_da.sel(location_id=subnat_location_ids) + + negatives = balance_mig_da.where(mig_da < 0) + positives = balance_mig_da.where(mig_da > 0) + zeros = balance_mig_da.where(mig_da == 0) + + sum_dims = [dim for dim in mig_da.dims if dim not in ( + "draw", "age_group_id", "sex_id", "year_id")] + k = positives.sum(sum_dims) / negatives.sum(sum_dims) + + # Add a case for if sum of either is zero + + # Multiply constant by positives + adjusted_positives = xr.ufuncs.sqrt(1 / -k) * positives + adjusted_negatives = xr.ufuncs.sqrt(-k) * negatives + + # Combine + balanced_mig_da = adjusted_positives.combine_first(adjusted_negatives) + balanced_mig_da = balanced_mig_da.combine_first(zeros) + balanced_mig_da = balanced_mig_da.combine_first(no_balance_mig_da) + + logger.info("Balanced migration") + return balanced_mig_da + +def add_missing_locs(mig_da, location_ids): + """ + Append any missing locations in location_id and fill with 0 net migration + Args: + mig_da (xarray.DataArray): + The input migration xarray dataarray that locations are being appended to + location_ids ([int]): + A list of the location ids that are expected in mig_da + + Returns: + xarray.DataArray: + The migration data with all location_ids + """ + missing_locs = list(set(location_ids).difference(set(mig_da.location_id.values))) + logger.info(f"Filling missing locations { {*missing_locs} } with 0") + mig_da = expand_dimensions(mig_da, location_id=missing_locs, fill_value=0) + return mig_da + +def _clean_migration_locations(migration, pop, gbd_round_id): + """Migration uses weird locations. Sometimes, locations are missing + migration data. Other times, locations have migration data but they + shouldn't. + + In the case where locations have migration data, but they should really be + part of another location (e.g. Macao is part of China), that migration will + be added into the "parent" location. + + In the case where locations are missing migration data, those locations + will get the average migration of their regions. This averaging happens + AFTER too-specific locations are merged into their parents. 
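+
+    Note: the "average migration of their regions" is effectively a
+    population-weighted mean: per-capita rates of the parent's available
+    children are multiplied by population, averaged, and divided by the
+    averaged population (see _fill_missing_locations below).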
+ """ + merged_migration = _merge_too_specific_locations(migration) + filled_migration = _fill_missing_locations( + merged_migration, pop, gbd_round_id) + return filled_migration + +def _fill_missing_locations(data_per_capita, pop, gbd_round_id): + """Missing locations need to be filled in with region averages.""" + avail_locs = set(data_per_capita.location_id.values) + desired_locs = fbd_core.db.get_modeled_locations(gbd_round_id) + missing_locs = set(desired_locs.location_id.values) - avail_locs + if not missing_locs: + return data_per_capita + logger.info("These locations are missing: {}".format(missing_locs)) + parent_locs = desired_locs.query("location_id in @missing_locs")[ + "parent_id"].values + logger.info("Children of these locations will be averaged to fill in " + "missing data: {}".format(parent_locs)) + hierarchy = desired_locs.query( + "parent_id in @parent_locs and location_id in @avail_locs" + )[ + ["location_id", "parent_id"] + ].set_index( + "location_id" + ).to_xarray()["parent_id"] + hierarchy.name = "location_id" + pop_location_slice = pop.sel(location_id=hierarchy.location_id.values) + + data = data_per_capita * pop_location_slice + + mean_data = data.sel( + location_id=hierarchy.location_id.values + ).groupby(hierarchy).mean("location_id") + pop_agged = pop_location_slice.sel( + location_id=hierarchy.location_id.values + ).groupby(hierarchy).mean("location_id") + mean_data_per_capita = (mean_data / pop_agged).fillna(0) + location_da = xr.DataArray( + desired_locs.location_id.values, + dims="location_id", + coords=[desired_locs.location_id.values]) + + filled_data_per_capita, _ = xr.broadcast(data_per_capita, location_da) + + for missing_location in desired_locs.query( + "location_id in @missing_locs").iterrows(): + loc_slice = {"location_id": missing_location[1].location_id} + loc_parent_slice = {"location_id": missing_location[1].parent_id} + filled_data_per_capita.loc[loc_slice] = ( + mean_data_per_capita.sel(**loc_parent_slice)) + match_already_existing_locations = ( + filled_data_per_capita == data_per_capita).all() + does_not_match_err_msg = ( + "Result should match input data for locations that are present.") + assert match_already_existing_locations, does_not_match_err_msg + if not match_already_existing_locations: + logger.error(does_not_match_err_msg) + raise MigrationError(does_not_match_err_msg) + has_new_locations = missing_locs.issubset( + filled_data_per_capita.location_id.values) + does_not_have_new_locs_err_msg = ( + "Missing locations {} are still missing.".format(missing_locs)) + assert has_new_locations, does_not_have_new_locs_err_msg + if not has_new_locations: + logger.error(does_not_have_new_locs_err_msg) + raise MigrationError(does_not_have_new_locs_err_msg) + return filled_data_per_capita + +def _merge_too_specific_locations(data): + """Locations that are too specific (i.e. level 4) need to be merged into + their respective parent locations.""" + avail_locs = set(data.location_id.values) + level_four_locs = fbd_core.db.get_locations_by_level(4) + too_specific_locs = set(level_four_locs.location_id) & avail_locs + if len(too_specific_locs) == 0: # nothing to merge to parent. just return input. 
+ return data + logger.info("These locations are too specific: {}".format( + too_specific_locs)) + children_into_parents = level_four_locs.query( + "location_id in @too_specific_locs" + )[ + ["location_id", "parent_id"] + ].set_index( + "location_id" + ).to_xarray()["parent_id"] + children_into_parents.name = "location_id" + children_merged = data.sel( + location_id=children_into_parents.location_id + ).groupby(children_into_parents).sum("location_id") + good_locations = avail_locs - too_specific_locs + good_data = data.sel(location_id=list(good_locations)) + merged_data = sum(broadcast_and_fill(children_merged, good_data, + fill_value=0)) + unchanged_locs = good_locations - set(children_into_parents.values) + good_locs_didnt_change = ( + data == merged_data.sel(location_id=list(unchanged_locs))).all() + good_locs_did_change_err_msg = ( + "Error: good locations were changed during the merge.") + assert good_locs_didnt_change, good_locs_did_change_err_msg + if not good_locs_didnt_change: + logger.error(good_locs_did_change_err_msg) + raise MigrationError(good_locs_did_change_err_msg) + return merged_data + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--version", type=str, + help="Which version of migration to balance.") + parser.add_argument( + "--gbd_round_id", type=int, required=True, + help="Which gbd_round_id to use in file loading and saving") + args = parser.parse_args() + + subnat_location_ids = get_location_metadata( + location_set_id=LOCATION_SET_ID, + gbd_round_id=args.gbd_round_id, + release_id=RELEASE_ID).\ + query("level == 4").\ + location_id.tolist() + # Try to load data, else combine csvs into dataarray + # Csvs are from if you split in R, xarray from split in Python + try: + mig_dir = FBDPath("FILEPATH") + mig_path = mig_dir / "migration_split.nc" + mig_da = open_xr(mig_path) + except: # Data doesn't yet exist + mig_da = combine_and_save_mig(version=args.version) + + balanced_mig_da = balance_migration(mig_da) + + # add missing locations + location_ids = get_location_metadata(location_set_id=LOCATION_SET_ID, + gbd_round_id=args.gbd_round_id, + release_id=RELEASE_ID).\ + query("level == 4 | level == 3").\ + location_id.to_list() + balanced_mig_da = add_missing_locs(balanced_mig_da, location_ids) + summary = compute_summary(balanced_mig_da) + + # Save to forecasting directory + balanced_path = mig_dir / "migration.nc" + summary_path = mig_dir / "summary.nc" + + save_xr(balanced_mig_da, balanced_path, metric="number", space="identity") + save_xr(summary, summary_path, metric="number", space="identity") diff --git a/gbd_2021/disease_burden_forecast_code/migration/csv_to_xr.py b/gbd_2021/disease_burden_forecast_code/migration/csv_to_xr.py new file mode 100644 index 0000000..6b80fda --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/migration/csv_to_xr.py @@ -0,0 +1,124 @@ +""" +* Convert csv predictions to xarray file and makes epsilon. +* Past takes csv and formats to xarray, or xarray if it's already created w/ +LimeTr output. 
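+* The epsilon written to mig_eps.nc is the model residual, i.e. the observed
+migration_rate minus the model predictions, restricted to past years.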
+ +Example Call: + +python FILEPATH/csv_to_xr.py +--mig-version click_20210510_limetr_fixedint +--model-version click_20210510_limetr_fixedint +--model-name model_6_single_years +--gbd-round-id 6 +--years 1950:2020:2050 + +""" +import argparse +import os.path +import pandas as pd +import xarray as xr + +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_file_interface.lib.file_system_manager import FileSystemManager +from fhs_lib_file_interface.lib.os_file_system import OSFileSystem +FileSystemManager.set_file_system(OSFileSystem()) +from fhs_lib_file_interface.lib.xarray_wrapper import save_xr, open_xr +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging + +logger = fhs_logging.get_logger() + +AGE_GROUP_ID = 22 +SEX_ID = 3 + +def make_eps(mig_version, model_version, model_name, gbd_round_id, years): + logger.info("Making Epsilons") + model_dir = FBDPath("FILEPATH") + model_path = model_dir / f"{model_name}.csv" + df = pd.read_csv(model_path) + # add all-sex and all-age id columns + df["sex_id"] = SEX_ID + df["age_group_id"] = AGE_GROUP_ID + # select the columns we need + df = df[["location_id", "year_id", "age_group_id", "sex_id", "predictions", + "migration_rate"]] + # set index columns + index_cols = ["location_id", "year_id", "age_group_id", "sex_id"] + + dataset = df.set_index(index_cols).to_xarray() + dataset["eps"] = dataset["migration_rate"] - dataset["predictions"] + mig_dir = FBDPath("FILEPATH") + eps_path = mig_dir / "mig_eps.nc" + save_xr(dataset["eps"].sel(year_id=years.past_years), + eps_path, metric="rate", space="identity") + + pred_path = mig_dir / "mig_hat.nc" + save_xr(dataset["predictions"].sel(year_id=years.years), + pred_path, metric="rate", space="identity") + + mig_path = mig_dir / "wpp_hat.nc" + save_xr(dataset["migration_rate"].sel(year_id=years.years), + mig_path, metric="rate", space="identity") + + +def make_past(mig_version, gbd_round_id, years): + past_dir = FBDPath("FILEPATH") + if os.path.isfile(past_dir / f"past_mig_rate_single_years.csv"): + logger.info("Past in csv, converting to xarray") + past_path = past_dir / f"past_mig_rate_single_years.csv" + df = pd.read_csv(past_path) + # add all-sex and all-age id columns + df["sex_id"] = SEX_ID + df["age_group_id"] = AGE_GROUP_ID + # select the columns we need + df = df[["location_id", + "year_id", + "age_group_id", + "sex_id", + "migration_rate"]] + # set index columns + index_cols = ["location_id", + "year_id", + "age_group_id", + "sex_id"] + + dataset = df.set_index(index_cols).to_xarray() + save_xr( + dataset.sel(year_id=years.past_years), past_dir / "wpp_past.nc", + metric="rate", space="identity") + elif os.path.isfile(past_dir / f"past_mig_rate_single_years.nc"): + dataset = open_xr(past_dir / f"past_mig_rate_single_years.nc") + save_xr( + dataset.sel(year_id=years.past_years), past_dir / "wpp_past.nc", + metric="rate", space="identity") + logger.info("Past Mig Rate Xarray already created: copied " + "to wpp_past.nc") + else: + raise OSError("Past Mig not in csv or not already an xarray") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument( + "--mig-version", type=str, required=True, + help="The version the migration data are saved under.") + parser.add_argument( + "--model-version", type=str, required=True, + help="The version the model results are saved under.") + parser.add_argument( + "--model-name", 
type=str, required=True, + help="The name of the model.") + parser.add_argument("--gbd-round-id", type=int, required=True, + help="Which gbd round id to use") + parser.add_argument("--years", type=str, + help="past_start:forecast_start:forecast_end") + args = parser.parse_args() + + years = YearRange.parse_year_range(args.years) + + make_eps(args.mig_version, args.model_version, args.model_name, + args.gbd_round_id, years) + + make_past(args.mig_version, args.gbd_round_id, years) diff --git a/gbd_2021/disease_burden_forecast_code/migration/migration_rate_to_count.py b/gbd_2021/disease_burden_forecast_code/migration/migration_rate_to_count.py new file mode 100644 index 0000000..a99ae04 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/migration/migration_rate_to_count.py @@ -0,0 +1,144 @@ +""" +Convert migration rates output by draw generation step to counts for use in +age-sex splitting step. + +Example: + +.. code:: bash + python -m pdb FILEPATH/migration_rate_to_count.py \ + --migration_version click_20210510_limetr_fixedint \ + --past_pop_version 20200206_etl_gbd_decomp_step4_1950_2019_run_id_192 \ + --forecast_pop_version 20200513_arc_method_new_locs_ratio_subnats_else_pop_paper \ + --gbd_round_id 6 \ + --years 1950:2020:2050 + +Counts are calculated on mean population to avoid over-dispersed draws. ie: a high population +draw is arbitrarily multiplied by high in-migration draw or visa-versa. +""" +import argparse + +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_file_interface.lib.file_system_manager import FileSystemManager +from fhs_lib_file_interface.lib.os_file_system import OSFileSystem +FileSystemManager.set_file_system(OSFileSystem()) +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr, save_xr +from fhs_lib_year_range_manager.lib.year_range import YearRange + +SCALE_FACTOR = 1000 + +def load_past_pop(gbd_round_id, version): + """ + Load past population data. This will generally be from 1950 to the start of + the forecasts. + + Args: + gbd_round_id (int): + The gbd round ID that the past population is from + version (str): + The version of past population to read from + + Returns: + xarray.DataArray: The past population xarray dataarray + """ + past_pop_dir = FBDPath("FILEPATH") + past_pop_path = past_pop_dir / "population.nc" + past_pop_da = open_xr(past_pop_path) + + return past_pop_da + + +def load_forecast_pop(gbd_round_id, version, years): + """ + Load forecast population data. Aggregates if necessary. + + Args: + gbd_round_id (int): The gbd round ID that the past population is from + version (str): The version of forecast population to read from + years (YearRange): The Forecasting format years to use. + draws (int): Population is resampled to number of draws supplied. 
+ + Returns: + xarray.DataArray: The past population xarray dataarray + """ + forecast_pop_dir = FBDPath("FILEPATH") + forecast_pop_path = forecast_pop_dir / "population_agg.nc" + forecast_pop_da = open_xr(forecast_pop_path) + + # slice to correct years + forecast_pop_da = forecast_pop_da.sel(year_id=years.forecast_years) + + return forecast_pop_da + + +def main(migration_version, past_pop_version, forecast_pop_version, + gbd_round_id, years): + """ + Load pops and migration rate, multiply to get counts + """ + # Load migration data + mig_dir = FBDPath("FILEPATH") + mig_path = mig_dir / "mig_star.nc" + mig_da = open_xr(mig_path) + + # Load pops + past_pop_da = load_past_pop(gbd_round_id, past_pop_version) + forecast_pop_da = load_forecast_pop(gbd_round_id, forecast_pop_version, + years) + + past_pop_da = expand_dimensions(past_pop_da, + scenario=forecast_pop_da["scenario"].values) + + # Subset to coordinates relevant to mig_da + forecast_pop_da = forecast_pop_da.sel(sex_id=3, age_group_id=22, + location_id=mig_da.location_id.values, scenario=0) + past_pop_da = past_pop_da.sel(sex_id=3, age_group_id=22, + location_id=mig_da.location_id.values, scenario=0) + + # Combine past and forecast pop + pop_da = past_pop_da.combine_first(forecast_pop_da) + + # Multiply rates by pop to get counts + mig_counts = mig_da * pop_da + mig_counts = mig_counts / SCALE_FACTOR + + # Save out + mig_counts_path = mig_dir / "mig_counts.nc" + save_xr(mig_counts, mig_counts_path, metric="number", space="identity") + + # Load past migration data + past_mig_dir = FBDPath("FILEPATH") + past_mig_path = past_mig_dir / "wpp_past.nc" + past_mig_da = open_xr(past_mig_path) + + # Multiply rates by past pop to get counts + past_mig_counts = past_mig_da * past_pop_da + past_mig_counts = past_mig_counts / SCALE_FACTOR + + past_mig_counts_path = past_mig_dir / "wpp_past_counts.nc" + save_xr(past_mig_counts, past_mig_counts_path, metric="number", space="identity") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--migration_version", type=str, required=True, + help="Which version of migration to convert") + parser.add_argument("--past_pop_version", type=str, required=True, + help="Which version of past populations to use") + parser.add_argument("--forecast_pop_version", type=str, required=True, + help="Which version of forecasted populations to use") + parser.add_argument("--gbd_round_id", type=int, required=True, + help="Which gbd round id to use for populations") + parser.add_argument("--years", type=str, required=True, + help="past_start:forecast_start:forecast_end") + args = parser.parse_args() + + years = YearRange.parse_year_range(args.years) + + main(migration_version=args.migration_version, + past_pop_version=args.past_pop_version, + forecast_pop_version=args.forecast_pop_version, + gbd_round_id=args.gbd_round_id, + years=years) diff --git a/gbd_2021/disease_burden_forecast_code/migration/model_migration.py b/gbd_2021/disease_burden_forecast_code/migration/model_migration.py new file mode 100644 index 0000000..b3c0834 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/migration/model_migration.py @@ -0,0 +1,352 @@ +""" +Clean and model UN migration estimates for forecasting. 
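+The main outputs are covariates_single_years.csv, migration_single_years.nc,
+past_mig_rate_single_years.csv/.nc, and past/forecast NetCDF files for each
+covariate (shocks, natural population increase, median age).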
+ +example call: +python FILEPATH/model_migration.py \ +--shocks-past-version 20201005_shocks_get_draws_1980_to_2019 \ +--shocks-forecast-version 20210419_shocks_only_decay_weight_15 \ +--gbd-round-id 6 \ +--output-version click_20210510_limetr_fixedint \ +--years 1950:2020:2050 \ +--subnats + +NOTE: If running with --subnats, make sure that each covariate (shocks, + natural pop increase) has past and future years, and all locations. + This avoids NaNs, which can prevent LimeTr from achieving convergence. +""" +import click +import pandas as pd + +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_file_interface.lib.file_system_manager import FileSystemManager +from fhs_lib_file_interface.lib.os_file_system import OSFileSystem +FileSystemManager.set_file_system(OSFileSystem()) +from os.path import isfile +from fhs_lib_year_range_manager.lib.year_range import YearRange +from db_queries import get_location_metadata +from tiny_structured_logger.lib import fhs_logging + +logger = fhs_logging.get_logger() + + +def _wrangle_2022_wpp( + wpp_data: pd.DataFrame, + fhs_col_name: str, + loc_mapping: pd.DataFrame + ) -> pd.DataFrame: + + logger.info("Reformatting and cleaning WPP input data") + wpp_data = wpp_data.rename(columns={"Location code": "iso_num", "Year": "year_id"}) + wpp_data.fillna(0, inplace=True) + wpp_data["iso_num"] = wpp_data["iso_num"].apply(int) + wpp_data["year_id"] = wpp_data["year_id"].apply(int) + + # convert covariate to float + wpp_data[fhs_col_name] = wpp_data[fhs_col_name].apply(float) + + bool_series = wpp_data.duplicated(keep='first') + wpp_data = wpp_data[~bool_series] + + wpp_data = wpp_data.merge(loc_mapping, how="inner", on="iso_num") + wpp_data = wpp_data.loc[:, ["location_id", "location_name", "year_id", "super_region_id", fhs_col_name]] + + return wpp_data + + +def _read_2022_wpp( + data_path: FBDPath, + wpp_col_name: str, + fhs_col_name: str, + loc_mapping: pd.DataFrame, + ) -> pd.DataFrame: + """ + load Past and Future data in expected WPP format, data in same xlsx file + as seperate sheets. 
Past as "ESTIMATES" and forecasts as "MEDIUM VARIANT" + + NOTE: read_excel requires openpyxl to be installed to function in later + versions of pandas see https://github.com/pandas-dev/pandas/issues/38424 + + NOTE: this loads data as formatted in WPP 2022, future WPP data may require + refactoring + + Args: + data_path (FBDPath): file path for WPP 2022 covariates + wpp_col_name (str): WPP column name of covariate to load + fhs_col_name (str): FHS column name to replace WPP name + loc_mapping (pd.DataFrame): dataframe mapping IHME country codes to WPP country + + + """ + logger.info("Reading WPP xlsx file") + past_data = pd.read_excel(data_path, sheet_name="Estimates", skiprows=16, + na_values="...", engine="openpyxl") + + past_data = past_data.loc[:, ["Location code", "Year", wpp_col_name]] + + + forecast_data = pd.read_excel(data_path, sheet_name="Medium variant", skiprows=16, + na_values="...", engine="openpyxl") + + forecast_data = forecast_data.loc[:, ["Location code", "Year", wpp_col_name]] + + complete_timeseries = past_data.append(forecast_data) + complete_timeseries = complete_timeseries.rename(columns={wpp_col_name: fhs_col_name}) + + #return complete_timeseries + return _wrangle_2022_wpp(complete_timeseries, fhs_col_name, loc_mapping) + + +def read_shock(past_path, forecast_path, data_column_name, years): + # Combine past/forecast shocks + logger.info("Reading in Shocks") + past_shock = pd.read_csv(past_path) + past_shock = past_shock[["location_id", "year_id", "mortality"]] + forecast_shock = pd.read_csv(forecast_path) + forecast_shock = forecast_shock[["location_id", "year_id", "mortality"]] + forecast_shock = forecast_shock.loc[ + forecast_shock["year_id"] >= years.forecast_start] + past_shock = past_shock.loc[past_shock["year_id"] < years.forecast_start] + shock = past_shock.append(forecast_shock) + shock = shock.rename({"mortality": data_column_name}, axis=1) + return shock + + +@click.command() +@click.option("--shocks-past-version", required=True, type=str) +@click.option("--shocks-forecast-version", required=True, type=str) +@click.option("--gbd-round-id", required=True, type=int) +@click.option("--output-version", required=True, type=str) +@click.option("--years", required=True, type=str, help="Years to forecast") +@click.option("--subnats/--no-subnats", required=False, default=False, + help="Whether or not to include subnats") +def main( + shocks_past_version: str, + shocks_forecast_version: str, + gbd_round_id: int, + output_version: str, + years: str, + subnats: bool, +): + """ + setup directories, creating output if needed. direct calls to other + functions e.g. wrangle_wpp_data, load_combine_data, etc + """ + + years = YearRange.parse_year_range(years) + + disaster_past_path = FBDPath("FILEPATH/mean_inj_disaster.csv") + disaster_forecast_path = FBDPath("FILEPATH/mean_inj_disaster.csv") + execution_past_path = FBDPath("FILEPATH/mean_inj_war_execution.csv") + execution_forecast_path = FBDPath("FILEPATH/mean_inj_war_execution.csv") + terror_past_path = FBDPath("FILEPATH/mean_inj_war_warterror.csv") + terror_forecast_path = FBDPath("FILEPATH/mean_inj_war_warterror.csv") + + logger.info("loading WPP data") + loc_mapping_dir = ("FILEPATH.CSV") + + covariates_2022_path = FBDPath("FILEPATH.xlsx") + + + data_dir = FBDPath("FILEPATH") + data_dir.mkdir(parents=True, exist_ok=True) + + # Test whether migration forecasts already exist. If they do, skip recreating them. 
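+    # The covariates file acts as a cache: if covariates_single_years.csv is
+    # already present for this output version it is simply read back, otherwise
+    # the WPP covariates and shocks are loaded, merged, and written out below.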
+ if isfile(f"{data_dir}/covariates_single_years.csv"): + mig_covs_pred = pd.read_csv(f"{data_dir}/covariates_single_years.csv") + logger.info(f"Result files already exist at: {data_dir}") + # If not, align UN and IHME location ids + else: + logger.info("Result files did not exist") + loc_mapping = pd.read_csv(loc_mapping_dir, encoding="ISO-8859-1") + location_hierarchy = get_location_metadata(location_set_id=21, + gbd_round_id=gbd_round_id) + loc_mapping = loc_mapping.merge(location_hierarchy, how="inner", + on="location_id") + loc_namelist = ["location_id", "iso_num", "ihme_loc_id", + "super_region_name", "region_name", "location_name_x", "super_region_id"] + loc_mapping = loc_mapping[loc_namelist] + loc_mapping = loc_mapping.rename( + columns={"location_name_x":"location_name"}) + + natural_pop_increase = _read_2022_wpp( + covariates_2022_path, + "Rate of Natural Change (per 1,000 population)", + "natural_pop_increase", + loc_mapping, + ) + migration_rate = _read_2022_wpp( + covariates_2022_path, + "Net Migration Rate (per 1,000 population)", + "migration_rate", + loc_mapping, + ) + median_age = _read_2022_wpp( + covariates_2022_path, + "Median Age, as of 1 July (years)", + "median_age", + loc_mapping, + ) + + # Read shocks + disaster_dt = read_shock(disaster_past_path, + disaster_forecast_path, "disaster", years) + execution_dt = read_shock(execution_past_path, + execution_forecast_path, "execution", years) + terror_dt = read_shock(terror_past_path, + terror_forecast_path, "terror", years) + + + merge_cols = ["location_id", "year_id"] + # [ disaster, execution_dt, terror_dt, med_age, nat_pop, mig_rate] + mig_covs_pred = pd.merge(execution_dt, disaster_dt, on=merge_cols) + mig_covs_pred = mig_covs_pred.merge(terror_dt, on=merge_cols) + mig_covs_pred = mig_covs_pred.merge(median_age, on=merge_cols) + mig_covs_pred = mig_covs_pred.merge(natural_pop_increase, + on=merge_cols + ["location_name"]) + mig_covs_pred = mig_covs_pred.merge(migration_rate, + on=merge_cols + ["location_name"]) + + mig_covs_pred.loc[:, "shocks"] = (mig_covs_pred["disaster"] + + mig_covs_pred["execution"] + + mig_covs_pred["terror"]) + + logger.info("Remove duplicates by year & location if any," + "create df list") + + # Removing duplicates, if any, from various dataframes. 
+ covar_list = [] + for df in [disaster_dt, execution_dt, terror_dt, natural_pop_increase, median_age]: + if True in df.duplicated( + subset=['year_id', 'location_id']).to_list(): + df = df.drop_duplicates(subset=['year_id', 'location_id']) + covar_list.append(df) + else: + covar_list.append(df) + + # Reassigning dataframes sans duplicates + disaster_dt = covar_list[0] + execution_dt = covar_list[1] + terror_dt = covar_list[2] + natural_pop_increase = covar_list[3] + median_age = covar_list[5] + + # Creating subnats, appending to nationals, if --subnats is True + if subnats is True: + logger.info("Creating subnationals") + subnat_locations = location_hierarchy[ + (location_hierarchy["parent_id"].isin( + [44533, + 6, + 102, + 163, + 135])) & + (location_hierarchy["location_id"] != 44533)].location_id + + mig_covs_subnat_pred = subnat_locations.to_frame() + mig_covs_subnat_pred["year_id"] = ([years.years] * + len(mig_covs_subnat_pred)) + mig_covs_subnat_pred = mig_covs_subnat_pred.explode("year_id") + subnat_location_hierarchy_dt = location_hierarchy.query( + "level > 3")[["location_id", "location_name", "path_to_top_parent"]] + + # create table of subnational locations with national id as a new col + subnat_location_hierarchy_dt.loc[:, "nat_location_id"] = pd.to_numeric( + subnat_location_hierarchy_dt.path_to_top_parent.str.split( + ",", expand=True)[3], downcast="integer") + subnat_location_hierarchy_dt.loc[:, "sr_location_id"] = pd.to_numeric( + subnat_location_hierarchy_dt.path_to_top_parent.str.split( + ",", expand=True)[1], downcast="integer") + + mig_covs_subnat_pred = pd.merge( + mig_covs_subnat_pred, + subnat_location_hierarchy_dt[["location_id", "location_name", + "nat_location_id", "sr_location_id"]], + on="location_id") + + subset_cols = ["location_id", "year_id", "disaster", "execution", + "terror", "median_age", "natural_pop_increase", + "migration_rate", "shocks", "super_region_id"] + + mig_covs_subnat_pred = pd.merge( + mig_covs_subnat_pred, + mig_covs_pred[subset_cols].rename( + columns={"location_id": "nat_location_id", "super_region_id": "sr_location_id"}), + on=["nat_location_id", "sr_location_id", "year_id"]) + mig_covs_subnat_pred = mig_covs_subnat_pred.drop("nat_location_id", + axis=1) + mig_covs_subnat_pred = pd.merge(mig_covs_subnat_pred, + on=merge_cols) + # drop past rows for Hong Kong and Macao + hong_kong_macao_forecast = mig_covs_subnat_pred[ + (mig_covs_subnat_pred["location_id"].isin([354, 361])) & + (mig_covs_subnat_pred["year_id"] >= years.forecast_start)] + mig_covs_subnat_pred = mig_covs_subnat_pred[ + ~mig_covs_subnat_pred.location_id.isin([354, 361])].append( + hong_kong_macao_forecast) + + mig_covs_subnat_pred = mig_covs_subnat_pred.rename(columns={"sr_location_id": "super_region_id"}) + mig_covs_pred = mig_covs_pred.append(mig_covs_subnat_pred) + + if True in mig_covs_pred.duplicated( + subset=['year_id', 'location_id']).to_list(): + mig_covs_pred = mig_covs_pred.drop_duplicates( + subset=['year_id', 'location_id']) + + mig_covs_pred.loc[:, "scenario"] = 0 + mig_covs_pred = mig_covs_pred.drop("super_region_id_y", axis=1) # remove duplicate col + + # Writing covars to csv + mig_covs_pred.to_csv(data_dir / "covariates_single_years.csv", + index=False) + mig_covs_pred[mig_covs_pred["year_id"].isin(years.past_years)].to_csv( + data_dir / "past_mig_rate_single_years.csv", index=False) + + p_mig_path = FBDPath("FILEPATH") + p_mig_path.mkdir(parents=True, exist_ok=True) + + merge_cols = ["location_id", "year_id", "scenario"] + + mig_covs_pred[ + 
["migration_rate"] + merge_cols].set_index( + merge_cols).to_xarray().to_netcdf( + data_dir / "migration_single_years.nc") + + # Saving past migration rate to past folder for LimeTr + mig_covs_pred[ + ["migration_rate"] + merge_cols].set_index( + merge_cols).to_xarray().sel( + year_id=years.past_years).to_netcdf( + p_mig_path / "past_mig_rate_single_years.nc") + + def _save_covariate_as_nc( + covariate_col: str, + file_name: str, + stage: str, + ) -> None: + + logger.info(f"saving {covariate_col} to both migration directory and `FILEPATH` directory") + covs_da = mig_covs_pred[[covariate_col] + merge_cols].set_index(merge_cols).to_xarray() + covs_da.to_netcdf(data_dir / f"{file_name}.nc") + + covariate_dir = FBDPath("FILEPATH") + past_covariate_dir = FBDPath("FILEPATH") + + covariate_dir.mkdir(parents=True, exist_ok=True) + past_covariate_dir.mkdir(parents=True, exist_ok=True) + + mig_covs_pred[[covariate_col] + merge_cols].set_index( + merge_cols).to_xarray().sel( + year_id=years.past_years).to_netcdf( + past_covariate_dir / f"past_{file_name}.nc") + + mig_covs_pred[[covariate_col] + merge_cols].set_index( + merge_cols).to_xarray().sel( + year_id=years.forecast_years).to_netcdf( + covariate_dir / f"forecast_{file_name}.nc") + + _save_covariate_as_nc("shocks", "shocks_single_years", "death") + _save_covariate_as_nc("natural_pop_increase", "natural_pop_increase_single_years", "population") + _save_covariate_as_nc("median_age", "median_age_single_years", "median_age") + +if __name__ == "__main__": + main() diff --git a/gbd_2021/disease_burden_forecast_code/migration/model_strategy.py b/gbd_2021/disease_burden_forecast_code/migration/model_strategy.py new file mode 100644 index 0000000..1f28f53 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/migration/model_strategy.py @@ -0,0 +1,66 @@ +r"""This module is where migration modeling strategies and their parameters are +managed/defined. + +**Modeling parameters include:** + +* pre/post processing strategy (i.e. 
processor object) +* covariates +* fixed-effects +* fixed-intercept +* random-effects +* indicators +""" + +from collections import namedtuple +from enum import Enum +from frozendict import frozendict + +from fhs_lib_model.lib import model +from fhs_lib_data_transformation.lib import processing + + +class Covariates(Enum): + """Covariates uses for modeling migration""" + POPULATION = "population" + DEATH = "death" + +VALID_COVARIATES = tuple(cov.value for cov in Covariates) + + +ModelParameters = namedtuple( + "ModelParameters", ( + "Model, " + "processor, " + "covariates, " + "fixed_effects, " + "fixed_intercept, " + "random_effects, " + "indicators, " + "spline, " + "predict_past_only, " + ) + ) +MODEL_PARAMETERS = frozendict({ + # Indicators: + "migration": frozendict({ + "LimeTr": ModelParameters( + Model=model.LimeTr, + processor=processing.NoTransformProcessor(years=None, gbd_round_id=None), + covariates={"population": processing.NoTransformProcessor(years=None, gbd_round_id=None), + "death": processing.NoTransformProcessor(years=None, gbd_round_id=None), + }, + fixed_effects={"population": [-float('inf'), float('inf')], + "death": [-float('inf'), float('inf')], + }, + fixed_intercept='unrestricted', + random_effects={ + "location_intercept": model.RandomEffect( + ["location_id"], None + ), + }, + indicators=None, + spline=None, + predict_past_only=False, + ), + }) + }) diff --git a/gbd_2021/disease_burden_forecast_code/migration/model_strategy_queries.py b/gbd_2021/disease_burden_forecast_code/migration/model_strategy_queries.py new file mode 100644 index 0000000..815a6c3 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/migration/model_strategy_queries.py @@ -0,0 +1,143 @@ +r"""This module has query functions that give nonfatal modeling strategies and +their parameters. +""" +import logging + +from migration import model_strategy + + +LOGGER = logging.getLogger(__name__) + +MIG_LIMETR_STAGES = ["migration"] + +def get_mig_model(stage, years, gbd_round_id): + r"""Gets modeling parameters associated with the given migration stage. + Technically not a query right now, but functions like get_cause_model. + + Note: + Update ``FILEPATH.MODEL_PARAMETERS`` to + reflect changes in model specification. + Args: + stage (str): + Stage being forecasted, e.g. "temperature". + years (fbd_core.YearRange): + Forecasting timeseries + Returns: + Model (Model): + Class, i.e. un-instantiated from + ``FILEPATH.model.py`` + processor (Processor): + The pre/post process strategy of the cause-stage, i.e. instance of + a class defined in ``FILEPATH.processing.py``. + covariates (dict[str, Processor]] | None): + Maps each needed covariate, i.e. independent variable, to it's + respective preprocess strategy, i.e. instance of a class defined in + ``FILEPATH.processing.py``. + node_models (list[CovModel]): + A list of NodeModels (e.g. StudyModel, OverallModel), each of + which has specifications for cov_models. + study_id_cols (Union[str, List[str]]): + The columns to use in the `col_study_id` argument to MRBRT + ``load_df`` function. If it is a list of strings, those + columns will be concatenated together (e.g. ["location_id", + "sex_id"] would yield columns with values like + ``{location_id}_{sex_id}``). This is done since MRBRT can + currently only use the one ``col_study_id`` column for random + effects. + scenario_quantiles (dict | None): + Whether to use quantiles of the stage two model when predicting + scenarios. 
Dictionary of quantiles to use is passed in + e.g.:: + + { + -1: dict(sdi=0.85), + 0: None, + 1: dict(sdi=0.15), + 2: None, + } + fixed_effects (dict[str, str] | None): + List of covariates to calculate fixed effect coefficient + estimates for. + e.g.:: + + {"haq": [-float('inf'), float('inf'), "edu": [0, 4.7]} + + fixed_intercept (str | None): + To restrict the fixed intercept to be positive or negative, pass + "positive" or "negative", respectively. "unrestricted" says to + estimate a fixed effect intercept that is not restricted to + positive or negative. + random_effects (dict[str, list[str]] | None): + A dictionary mapping covariates to the dimensions that their + random slopes will be estimated for and the standard deviation of + the gaussian prior on their variance. Of the form + ``dict[covariate, list[dimension]]``. e.g.:: + + {"haq": (["location_id", "age_group_id"], None), + "education": (["location_id"], 4.7)} + + **NOTE** that random intercepts should be included here -- effects + that aren't associated with covariates will be assumed to be random + intercepts. + indicators (dict[str, list[str]] | None): + A dictionary mapping indicators to the dimensions that they are + indicators on. e.g.:: + + {"ind_age_sex": ["age_group_id", "sex_id"], + "ind_loc": ["location_id"]} + spline (dict[str, SplineTuple]): + A dictionary mapping covariates to the spline parameters that + will be used for them of the form + {covariate: SplineParams(degrees_of_freedom, constraints)} + Each key must be a covariate. + The degrees_of_freedom int represents the degrees of freedom + on that spline. + The constraint string can be "center" indicating to apply a + centering constraint or a 2-d array defining general linear + constraints. + Raises: + ValueError: + If the given stage does NOT have any cause-strategy IDs + """ + if stage in MIG_LIMETR_STAGES: + model_strategy_name = "LimeTr" + else: + err_msg = ( + f"stage={stage} does not have a model strategy associated with it") + LOGGER.error(err_msg) + raise ValueError(err_msg) + + model_parameters = ( + model_strategy.MODEL_PARAMETERS[stage][model_strategy_name]) + + model_parameters = _update_processor_years(model_parameters, years) + + model_parameters = _update_processor_gbd_round_id(model_parameters, gbd_round_id) + + return model_parameters + + +def _update_processor_gbd_round_id(model_parameters, gbd_round_id): + """``gbd_round_id`` is entered as ``None`` in the procesor for the dependent + variable and covariates so it needs to be updated here""" + if model_parameters: + model_parameters.processor.gbd_round_id = gbd_round_id + + if model_parameters.covariates: + for cov_name in model_parameters.covariates.keys(): + model_parameters.covariates[cov_name].gbd_round_id = gbd_round_id + + return model_parameters + + +def _update_processor_years(model_parameters, years): + """``years`` is entered as ``None`` in the procesor for the dependent + variable and covariates so it needs to be updated here""" + if model_parameters: + model_parameters.processor.years = years + + if model_parameters.covariates: + for cov_name in model_parameters.covariates.keys(): + model_parameters.covariates[cov_name].years = years + + return model_parameters diff --git a/gbd_2021/disease_burden_forecast_code/migration/run_model.py b/gbd_2021/disease_burden_forecast_code/migration/run_model.py new file mode 100644 index 0000000..361afe5 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/migration/run_model.py @@ -0,0 +1,293 @@ +"""A script that forecasts migration with 
limetr. + +This run_model takes in covariate files that have already been cleaned and +produced by model_migration.py, and assumes that the data has +been prepped adequately for limetr. As such, it does no cleaning itself. + +Example call for all causes for a given stage (using default model +specifications): + +.. code:: bash + +python -m pdb FILEPATH/run_model.py \ +--gbd-round-id 6 \ +--migration-version click_20210510_limetr_fixedint \ +--stage migration \ +--draws 1000 \ +--mig-past past_mig_rate_single_years \ +--years 1950:2020:2050 \ +--version [*file paths for input versions] +""" +import argparse +import pandas as pd + +from fbd_core.etl import assert_shared_coords_same, exc +from fhs_lib_data_transformation.lib.processing import get_dataarray_from_dataset +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_file_interface.lib.file_system_manager import FileSystemManager +from fhs_lib_file_interface.lib.os_file_system import OSFileSystem +FileSystemManager.set_file_system(OSFileSystem()) +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr, save_xr +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging + +from migration import ( + model_strategy_queries, + assert_covariates_scenarios, + ) + +logger = fhs_logging.get_logger() + + +COV_MAP = {"population": "natural_pop_increase", + "death": "shocks" + } + + +def forecast_migration( + migration_version, years, gbd_round_id, mig_past, stage, draws): + r"""Forecasts given stage for given cause. + + Args: + stage (str): + Stage being forecasted (migration). + Model (Model | None): + Class, i.e. un-instantiated from + ``FILEPATH.model.py`` + processor (Processor): + The pre/post process strategy of the cause-stage, i.e. instance of + a class defined in ``FILEPATH.processing.py``. + covariates (dict[str, Processor]] | None): + Maps each needed covariate, i.e. independent variable, to it's + respective preprocess strategy, i.e. instance of a class defined in + ``FILEPATH.processing.py``. + fixed_effects (dict[str, list[float, float]] | None): + Covariates to calculate fixed-effect coefficients for, mapped to + their respective correlation restrictions (i.e. forcing + coefficients to be positive, negative, or unrestricted). e.g.:: + + {"haq": [-float('inf'), float('inf')], + "edu": [0, 4.7]} + + fixed_intercept (str | None): + To restrict the fixed intercept to be positive or negative, pass + "positive" or "negative", respectively. "unrestricted" says to + estimate a fixed effect intercept that is not restricted to + positive or negative. + If ``None`` then no fixed intercept is estimated. + Currently all of the strings get converted to unrestricted. + random_effects (dict[str, (list[str], float|None)] | None): + A dictionary mapping covariates to the dimensions that their + random slopes will be estimated for. Of the form + ``dict[covariate, RandomEffect(list[dimension], std_prior)]``. + e.g.:: + + {"haq": RandomEffect( + ["location_id"], None), + "education": RandomEffect( + ["location_id", "age_group_id"], 3), + "location_age_intercept": RandomEffect( + ["location_id", "sex_id"], 1)} + + **NOTE** that random intercepts should be included here -- effects + that aren't associated with covariates will be assumed to be random + intercepts. + indicators (dict[str, list[str]] | None): + A dictionary mapping indicators to the dimensions that they are + indicators on. 
e.g.:: + + {"ind_age_sex": ["age_group_id", "sex_id"], + "ind_loc": ["location_id"]} + + versions (Versions): + All relevant versions. e.g.:: + FILEPATH/123 + FILEPATH/124 + FILEPATH/123 + + years (fbd_core.YearRange): + Forecasting time series. + gbd_round_id (int): + The numeric ID of GBD round associated with the past data + Raises: + CoordinatesError: + If past and forecast data don't line up across all dimensions + except ``year_id`` and ``scenario``, e.g. if coordinates for of + ``age_group_id`` are missing from forecast data, but not past data. + CoordinatesError: + If the covariate data is missing coordinates from a dim it shares + with the dependent variable -- both **BEFORE** and **AFTER** + pre-processing. + CoordinatesError: + If the covariates do not have consistent scenario coords. + DimensionError: + If the past and forecasted data dims don't line up (after the past + is broadcast across the scenario dim). + RuntimeError: + If there is missing version + ValueError: + If the given stage does NOT have any cause-strategy IDs + ValueError: + If the given mig_past/stage/gbd-round-id combo does not have a + modeling strategy associated with it. + """ + logger.info(f"Entering `forecast_migration` function") + + # This will read in 'covariates_single_years.csv' from the future folder + # where it is currently saved when running model_migration.py. + # Will be saved with predictions added, as final output of run_model. + cov_single_path = FBDPath("FILEPATH/covariates_single_years.csv") + covs_single = pd.read_csv(cov_single_path) + + # Reading past data for model + past_path = ( + FBDPath(f"{gbd_round_id}/past/{stage}/{mig_past}/past_mig_rate_single_years.nc") + ) + past_data = open_xr(past_path) + + # Retrieving the Model and processor for running the model. + (Model, processor, covariates, fixed_effects, fixed_intercept, + random_effects, indicators, spline, predict_past_only + ) = _get_model_parameters(stage, years, gbd_round_id) + + logger.info(f"running limetr with the following covariates: {covariates}") + if covariates: + cov_data_list = _get_covariate_data(migration_version, past_data, covariates, years) + else: + cov_data_list = None + + # Running the model. + model_instance = Model( + past_data, years, draws=draws, covariate_data=cov_data_list, + fixed_effects=fixed_effects, fixed_intercept=fixed_intercept, + random_effects=random_effects, indicators=indicators, + gbd_round_id=gbd_round_id) + + coefficients = model_instance.fit() + + forecast_path = FBDPath("FILEPATH") + model_instance.save_coefficients(forecast_path, "migration") + + forecast_data = model_instance.predict() + + prepped_output_data = processor.post_process( + forecast_data, past_data) + # Saving model output. Not necessary, but could be useful. 
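+    # The post-processed forecast is saved at the draw level; below, the draw
+    # mean of the reference scenario is merged with the covariate table and
+    # written to model_6_single_years.csv, which csv_to_xr.py converts to
+    # xarray inputs for the residual (epsilon) step.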
+ save_xr( + prepped_output_data, forecast_path / f"{stage}_preds.nc", + metric="rate", space="identity") + + # Prepping xarray file for csv conversion and merge with covariate data + mig_preds = prepped_output_data + mig_preds_mean = mig_preds.mean("draw") + mig_preds_mean_pd = mig_preds_mean.sel(scenario=0, drop=True).to_pandas() + mig_preds_mean_pd.reset_index(inplace=True) + mig_preds_mean_pd = pd.melt( + mig_preds_mean_pd, + id_vars=["location_id"]).sort_values('location_id') + # Merging covariate and migration predictions + preds_df = pd.merge(covs_single, + mig_preds_mean_pd, + how='inner', + left_on=['location_id', 'year_id'], + right_on = ['location_id', 'year_id']) + + preds_df = (preds_df.sort_values(['location_id', "year_id"]) + .rename(columns={"value": "predictions"})) + # Saving final output of run_model.py. Next up: csv_to_xr.py + model_6_single_years_path = FBDPath("FILEPATH/model_6_single_years.csv") + preds_df.drop("scenario", axis=1) + preds_df.to_csv(model_6_single_years_path, index=False) + + logger.info(f"Leaving `forecast_migration` function. DONE") + + +def _get_covariate_data( + migration_version, + dep_var_da, + covariates, + years, + ): + """Returns a list of prepped dataarray for all of the covariates""" + cov_data_list = [] + for cov_stage, cov_processor in covariates.items(): + cov_past_path = FBDPath(f"FILEPATH/past_{COV_MAP[cov_stage]}_single_years.nc") + + cov_past_data = open_xr(cov_past_path) + + cov_past_data = get_dataarray_from_dataset(cov_past_data).rename(cov_stage) + + cov_forecast_path = FBDPath(f"FILEPATH/forecast_{COV_MAP[cov_stage]}_single_years.nc") + + cov_forecast_data = open_xr(cov_forecast_path) + cov_forecast_data = get_dataarray_from_dataset( + cov_forecast_data).rename(cov_stage) + + cov_data = cov_past_data.combine_first(cov_forecast_data) + + prepped_cov_data = cov_data + + try: + assert_shared_coords_same( + prepped_cov_data, + dep_var_da.sel(year_id=years.past_end, drop=True) + ) + except exc.CoordinatesError as ce: + err_msg = f"Coordinates do not match for \ + migration and {COV_MAP[cov_stage]}," + str(ce) + logger.error(err_msg) + raise exc.CoordinatesError(err_msg) + + cov_data_list.append(prepped_cov_data) + + assert_covariates_scenarios(cov_data_list) + return cov_data_list + + +def _get_model_parameters(stage, years, gbd_round_id): + """Gets modeling parameters associated with the given cause-stage. + + If there aren't model parameters associated with the cause-stage then the + script will exit with return code 0. + """ + model_parameters = model_strategy_queries.get_mig_model( + stage, years, gbd_round_id + ) + if not model_parameters: + logger.info( + f"{stage} is not forecasted in this pipeline. DONE") + exit(0) + else: + return model_parameters + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--draws", type=int, + help="Number of draws") + parser.add_argument( + "--migration-version", type=str, required=True, + help="Migration Version") + parser.add_argument( + "--gbd-round-id", type=int, required=True, + help="The numeric ID of GBD round associated with the past data") + parser.add_argument( + "--stage", type=str, required=False, default='migration', + help="The gbd round id associated with the data.\n" + "Default to migration.") + parser.add_argument( + "--mig-past", type=str, required=False, + default="past_mig_rate_single_years", + help="Migration past file name. 
Default to\n" + "past_mig_rate_single_years") + parser.add_argument("--years", type=str, + help="past_start:forecast_start:forecast_end") + args = parser.parse_args() + + args.years = YearRange.parse_year_range(args.years) + + + forecast_migration(**args.__dict__) diff --git a/gbd_2021/disease_burden_forecast_code/mortality/README.md b/gbd_2021/disease_burden_forecast_code/mortality/README.md new file mode 100644 index 0000000..b1f5c05 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/README.md @@ -0,0 +1,135 @@ +Mortality pipeline code + +The primary workflow for cause-specific computation begins with `run_cod_model.main()` +as the main entry point, followed by `y_star`, then `sum_to_all_cause`, then `squeeze`. + +# Data transformation + +``` +correlate.py +Correlate residual error and modeled results +``` + +``` +exponentiate_draws.py +Exponentiate a distribution, preserving the mean +``` + +``` +intercept_shift.py +various implementations of the intercept shift operation +``` + +# Lib + +``` +config_dataclasses.py +Dataclasses for capturing the configuration of the pipeline +``` + +``` +downloaders.py +Functions used by run_cod_model.py to import past data +``` + +``` +get_fatal_causes.py +Get the fatal cause dataframe for the input +``` + +``` +intercept_shift.py +Load past data and use it to apply an intercept shift to preds at the draw level +``` + +``` +intercept_shift.py +Load past data and use it to apply an intercept shift to preds at the draw level +``` + +``` +make_all_cause.py +Makes a new version of mortality that consolidates data from externally and internally modeled causes +``` + +``` +make_hierarchies.py +Functions used by Mortality stage 2 and stage 3 to set up the aggregation hierarchy and the cause set for ARIMA to operate on +``` + +``` +mortality_approximation.py +Perform mortality approximation for every cause +``` + +``` +run_cod_model.py +Functions to run cause of death model +``` + +``` +smoothing.py +Utilities related to smoothing in the latent trend model +``` + +``` +squeeze.py +Module for squeezing results into an envelope +``` + +``` +sum_to_all_cause.py +Aggregates the expected value of mortality or ylds up the cause hierarchy and computes y_hat and y_past +``` + +``` +y_star.py +Computes y-star, which is the sum of the latent trend component and the y-hat predictions from the GK model +``` + + +# Models +``` +GKModel.py +Contains an implementation of the Girosi-King ("GK") Model +``` + +``` +model_parameters.py +Contains a class that encapsulates logic/info related GK model parameters +``` + +``` +omega.py +Contains utilities for preparing the Bayesian omega priors of the GK model +``` + +``` +post_process.py +Contains utilities for processing the GK model output into a more usable form +``` + +``` +pre_process.py +contains utilities for preparing the GK model input +``` + +``` +pre_process.py +contains utilities for preparing the GK model input +``` + +``` +pooled_random_walk.py +contains tools creating a collection of correlated random walk projections +``` + +``` +random_walk.py +Contains the random walk model +``` + +``` +remove_drift.py +Contains functions for attenuating or removing the drift effect from epsilon +``` \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/mortality/data-transformation/correlate.py b/gbd_2021/disease_burden_forecast_code/mortality/data-transformation/correlate.py new file mode 100644 index 0000000..34c4552 --- /dev/null +++ 
b/gbd_2021/disease_burden_forecast_code/mortality/data-transformation/correlate.py @@ -0,0 +1,107 @@ +"""A function to correlate residual error and modeled results. + +This function correlates data of residual error, i.e., epsilon, predictions and modeled +results. The intent is to capture any time trends that remain after the explanatory variables +have been incorporated into the model. + +For example, when creating the all-cause ARIMA ensemble, we take the residual error between +modeled past and GBD past and forecast that residual into future years. +""" + +import itertools as it +from collections import namedtuple + +import xarray as xr +from fhs_lib_database_interface.lib.constants import DimensionConstants +from fhs_lib_year_range_manager.lib.year_range import YearRange + +from fhs_lib_data_transformation.lib.constants import CorrelateConstants + + +def correlate_draws( + epsilon_draws: xr.DataArray, + modeled_draws: xr.DataArray, + years: YearRange, +) -> xr.DataArray: + """Correlates draws of predicted epsilon draws with the modeled draws. + + Correlates high epsilon draws with high modeled draws (currently correlates based on first + forecasted year) to create a confidence interval that captures both past uncertainty and + model uncertainty. Returns the correctly ordered epsilons for future years. Time series of + epsilons are ordered by draw in the first predicted year and made to align in rank with the + modeled rates for each predicted year. + + Notes: + * **Pre-condition:** ``epsilon_draws`` and ``modeled_draws`` should have each have + ``"location_id"``, ``"age_group_id"``, ``"sex_id"``, ``"draw"``, and ``"year_id"`` as + dims, and only those. + + Args: + epsilon_draws (x.DataArray): Unordered draw-level epsilon results from a latent trend + model, e.g., an ARIMA or a Random Walk model. + modeled_draws (x.DataArray): Modeled draw-level predictions to correlate with, e.g., + natural-logged mortality rate draws. + years (YearRange): The forecasting time series year range. + + Returns: + xr.DataArray: Epsilon data now with correlated draws for all years. + """ + # Get both dataarrays lined up in terms of dims and coords. + epsilon_draws = epsilon_draws.transpose(*CorrelateConstants.DIM_ORDER) + modeled_draws = modeled_draws.transpose(*CorrelateConstants.DIM_ORDER) + + # We only need the first year of forecasts for ordering. + moded_draws_forecast_start = modeled_draws.sel( + **{DimensionConstants.YEAR_ID: years.forecast_start}, drop=True + ) + + # Create cartesian product of all demographic combinations. + AgeSexLocation = namedtuple("AgeSexLocation", "age sex location") + demog_slice_combos = it.starmap( + AgeSexLocation, + it.product( + epsilon_draws.age_group_id.values, + epsilon_draws.sex_id.values, + epsilon_draws.location_id.values, + ), + ) + + # The index of the year to order draws by -- the last forecast year. + order_year_index = years.forecast_end - years.past_start + + # Iterate through each demographic slice and correlate the draws for each as we go. + for demog_slice_combo in demog_slice_combos: + slice_coord_dict = { + DimensionConstants.LOCATION_ID: demog_slice_combo.location, + DimensionConstants.AGE_GROUP_ID: demog_slice_combo.age, + DimensionConstants.SEX_ID: demog_slice_combo.sex, + } + + modeled_slice = moded_draws_forecast_start.loc[slice_coord_dict] + + # Use ``argsort`` twice to get ranks. 
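+        # Illustration (hypothetical 4-draw slice): values [0.3, 0.1, 0.4, 0.2]
+        # argsort once  -> [1, 3, 0, 2] (indices that would sort the values),
+        # argsort twice -> [2, 0, 3, 1] (each value's rank, smallest = 0).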
+ modeled_order = modeled_slice.argsort( + modeled_slice.dims.index(DimensionConstants.DRAW) + ) + modeled_ranks = modeled_order.argsort( + modeled_order.dims.index(DimensionConstants.DRAW) + ) + + # Order by draw, then just take the order from the first predicted year. + epsilon_slice = epsilon_draws.loc[slice_coord_dict] + sorted_epsilon_slice = epsilon_slice.argsort( + epsilon_slice.dims.index(DimensionConstants.DRAW) + ) + + epsilon_order = [ + sorted_epsilon_slice.values[i, order_year_index] + for i in range(epsilon_slice.shape[0]) + ] + + # Apply modeled ranks to the ordered epsilons to correctly place them. + epsilon_ordered = epsilon_slice.values[epsilon_order] + epsilon_ordered = epsilon_ordered[modeled_ranks.values] + epsilon_draws.load() + epsilon_draws.loc[slice_coord_dict] = epsilon_ordered + + return epsilon_draws diff --git a/gbd_2021/disease_burden_forecast_code/mortality/data-transformation/exponentiate_draws.py b/gbd_2021/disease_burden_forecast_code/mortality/data-transformation/exponentiate_draws.py new file mode 100644 index 0000000..7fa790d --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/data-transformation/exponentiate_draws.py @@ -0,0 +1,65 @@ +import numpy as np +import xarray as xr +from fhs_lib_database_interface.lib.constants import DimensionConstants +from tiny_structured_logger.lib.fhs_logging import get_logger + +logger = get_logger() + + +def bias_exp_new(darray: xr.DataArray) -> xr.DataArray: + r"""Exponentiate a distribution, preserving the mean in a certain way (see below). + + Exponentiate a distribution (draws) and adjust the results such that the + mean of the exponentiated distribution is equal to the exponentiated + expected value of the log distribution. + + Requires that input dataarray has draw dimension. + + The idea is that when + + .. math:: + g(x)=e^x, + + then + + .. math:: + E[g(X)]!=g(E[X]). + + Therefore we make factor :math:`c` such that + + .. math:: + c=g(E[X])/E[g(X)]. + + Then we have + + .. math:: + \begin{equation} + \begin{split} + E[g(X) \ * \ c] &= cE[g(X)] \\ + &= \frac{g(E[X])}{E[g(X)]}E[g(X)] \\ + &= g(E[X]) + \end{split} + \end{equation} + + Args: + darray (xr.DataArray): + Data array to exponentiate with bias correction. Must have draw + dimension. + + Returns: + xr.DataArray: + Exponentiated results with adjusted distribution. + + Raises: + ValueError: in the absence of a "draw" dimension. + """ + if DimensionConstants.DRAW not in darray.dims: + err_msg = "There is no draw dimension to use for bias_exp_new!" + logger.error(err_msg) + raise ValueError(err_msg) + + factor = np.exp(darray.mean(DimensionConstants.DRAW)) / np.exp(darray).mean( + DimensionConstants.DRAW + ) + + return np.exp(darray) * factor diff --git a/gbd_2021/disease_burden_forecast_code/mortality/data-transformation/intercept_shift.py b/gbd_2021/disease_burden_forecast_code/mortality/data-transformation/intercept_shift.py new file mode 100644 index 0000000..d916256 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/data-transformation/intercept_shift.py @@ -0,0 +1,469 @@ +"""This module contains various implementations of the intercept shift operation. 
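+
+As a rough illustration (with made-up numbers): if a model's estimate for the last past
+year is 0.12 while the GBD estimate for that year is 0.10, the shift subtracts the 0.02
+difference from every modeled year so that the forecasts line up with the GBD past.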
+""" + +import itertools as it + +import xarray as xr +from fhs_lib_database_interface.lib.constants import DimensionConstants, ScenarioConstants +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_lib_data_transformation.lib.draws import mean_of_draws +from fhs_lib_data_transformation.lib.validate import assert_coords_same_ignoring + +logger = get_logger() + + +def point_intercept_shift( + modeled_data: xr.DataArray, + past_data: xr.DataArray, + last_past_year: int, +) -> xr.DataArray: + """Perform point intercept-shift. + + All input data should be point arrays already, so they should not + have the draw dimension. The modeled (future) data should have one scenario, + and the past data should not have any. + + Args: + modeled_data (xr.DataArray): Estimates based on FHS modeling. Expects past and + forecast, or at least a value for the last past year and forecast years. + past_data (xr.DataArray): Past estimates from GBD to base the shift on. + last_past_year (int): last past year to intercept-shift on. + + Returns: + (xr.DataArray): intercept-shifted point estimates. + """ + if DimensionConstants.DRAW in modeled_data.dims: + raise KeyError("draw dimension in modeled_data for point intercept-shift.") + + if DimensionConstants.DRAW in past_data.dims: + raise KeyError("draw dimension in past_data for point intercept-shift.") + + if DimensionConstants.SCENARIO in past_data.dims: + raise KeyError("past data should not have scenario dim.") + + # In point runs, we require ingesting future data using open_xr_scenario() + # with designated scenario, hence future data no longer has scenario dim. + if DimensionConstants.SCENARIO in modeled_data.dims: + raise KeyError("future data should not have scenario during point run.") + + last_year_modeled_data = modeled_data.sel( + **{DimensionConstants.YEAR_ID: last_past_year}, drop=True + ) + + last_year_past_data = past_data.sel( + **{DimensionConstants.YEAR_ID: last_past_year}, drop=True + ) + + diff = last_year_modeled_data - last_year_past_data + + shifted_data = modeled_data - diff + + return shifted_data + + +def mean_intercept_shift( + modeled_data: xr.DataArray, + past_data: xr.DataArray, + years: YearRange, + shift_from_reference: bool = True, +) -> xr.DataArray: + """Move the modeled_data up by the difference between the modeled and actual past data. + + Subtract that difference from the modeled data to obtain shifted data so that the GBD past + and FHS forecasts line up. Mean of draws is taken if there are draws to calculate the + offset, but it is applied to whatever draws are in the modeled_data. + + May raise IndexError if coordinates or dimensions do not match up + between modeled data and GBD data. + + Note: + Expects input data to have the same coordinate dimensions for age, sex, and location, + but not year (except the one overlapping year). + + Args: + modeled_data (xr.DataArray): Estimates based on FHS modeling. Expects past and + forecast, or at least a value for the last past year and forecast years. + past_data (xr.DataArray): Past estimates from GBD to base the shift on + years (YearRange): Forecasting time series year range + shift_from_reference (bool): Optional. Whether to shift values based on difference + from reference scenario only, or to take difference for each scenario. In most + cases these will yield the same results unless scenarios have different value in + last past year than reference. 
+ + Returns: + xr.DataArray: Modeled data that has been shifted to line up with GBD past in the last + year of past data. + """ + if shift_from_reference and DimensionConstants.SCENARIO in modeled_data.dims: + last_year_modeled_data = modeled_data.sel( + **{ + DimensionConstants.YEAR_ID: years.past_end, + DimensionConstants.SCENARIO: ScenarioConstants.REFERENCE_SCENARIO_COORD, + }, + drop=True, + ) + else: + last_year_modeled_data = modeled_data.sel( + **{DimensionConstants.YEAR_ID: years.past_end}, + drop=True, + ) + + mean_last_year_modeled_data = mean_of_draws(last_year_modeled_data) + + last_year_past_data = past_data.sel( + **{DimensionConstants.YEAR_ID: years.past_end}, drop=True + ) + mean_last_year_past_data = mean_of_draws(last_year_past_data) + + assert_coords_same_ignoring( + mean_last_year_past_data, mean_last_year_modeled_data, ["scenario"] + ) + + diff = mean_last_year_modeled_data - mean_last_year_past_data + + shifted_data = modeled_data - diff + + return shifted_data + + +def ordered_draw_intercept_shift( + modeled_data: xr.DataArray, + past_data: xr.DataArray, + past_end_year_id: int, + modeled_order_year_id: int, + shift_from_reference: bool = True, +) -> xr.DataArray: + """Move the modeled_data up by the difference between the modeled and actual past data. + + May raise IndexError if coordinates or dimensions do not match up + between modeled data and GBD data. + + Notes: + * **Steps:** + + 1) Reorder the draws for modeled-past and GBD-past by value in the last past year. + 2) Take the difference between reordered modeled-past and GBD-past in the last past + year. + 3) Add the difference to the past and forecasted modeled-data such that reordered + draw difference gets applied to the values of its original unordered draw + number. + + * **Preconditions:** + + - Expects them to have the same coordinate dimensions for age, sex, and location, + but not year (except the one overlapping year). + - GBD-past and modeled-past have the same number of draws. + - Draw coordinates are zero-indexed integers. + + Args: + modeled_data (xr.DataArray): Estimates based on FHS modeling. Expects past and + forecast, or at least a value for the last past year and forecast years. + past_data (xr.DataArray): Past estimates from GBD to base the shift on + past_end_year_id (int): Last year of past data from GBD + modeled_order_year_id (int): What year to base the draw order off of for the modeled + data. Generally is the last past year or the last forecast year. + shift_from_reference (bool): Optional. Whether to shift values based on difference + from reference scenario only, or to take difference for each scenario. In most + cases these will yield the same results unless scenarios have different value in + last past year than reference. + + Returns: + xr.DataArray: Modeled data that has been shifted to line up with GBD past in the last + year of past data. + """ + # Subset GBD past and modeled values to the year they will base the draw + # order on. Also subset modeled data to last past year. 
+ past_data = past_data.sel(**{DimensionConstants.YEAR_ID: past_end_year_id}, drop=True) + if shift_from_reference and DimensionConstants.SCENARIO in modeled_data.dims: + modeled_order_year = modeled_data.sel( + **{ + DimensionConstants.YEAR_ID: modeled_order_year_id, + DimensionConstants.SCENARIO: ScenarioConstants.REFERENCE_SCENARIO_COORD, + }, + drop=True, + ) + modeled_last_past_year = modeled_data.sel( + **{ + DimensionConstants.YEAR_ID: past_end_year_id, + DimensionConstants.SCENARIO: ScenarioConstants.REFERENCE_SCENARIO_COORD, + }, + drop=True, + ) + else: + modeled_order_year = modeled_data.sel( + **{DimensionConstants.YEAR_ID: modeled_order_year_id}, + drop=True, + ) + modeled_last_past_year = modeled_data.sel( + **{DimensionConstants.YEAR_ID: past_end_year_id}, + drop=True, + ) + + assert_coords_same_ignoring(past_data, modeled_order_year, ["scenario"]) + + # Make a copy of the modeled data that will contain the shifted data. + shifted_modeled = modeled_data.copy() + + # Keep track of the original draw number order + original_draw_order = modeled_order_year.draw.values + + # get dimensions to loop through except for dimension draw + non_draw_coords = modeled_order_year.drop_vars(DimensionConstants.DRAW).coords + coords = list(non_draw_coords.indexes.values()) + dims = list(non_draw_coords.indexes.keys()) + + for coord in it.product(*coords): + logger.debug("Shifting for coord", bindings=dict(coord=coord)) + + # Get coordinates of data slice to reorder draws for + slice_dict = {dims[i]: coord[i] for i in range(len(coord))} + + # If we've got a scenario key in the slice_dict; make a copy dict without the scenario + # key for use when subsetting our past data + if DimensionConstants.SCENARIO in slice_dict.keys(): + non_scenario_slice_dict = slice_dict.copy() + del non_scenario_slice_dict[DimensionConstants.SCENARIO] + else: + non_scenario_slice_dict = slice_dict.copy() + + # Sort the modeled values for the slice of the order year by draw + # such that the numbering actually reflects each draw's rank (e.g. draw + # 0 is the minimum and draw 99 is the maximum for 100 draws). + modeled_slice = modeled_order_year.sel(**slice_dict) + modeled_slice_draw_index = modeled_slice.coords.dims.index(DimensionConstants.DRAW) + modeled_slice_draw_rank = ( + modeled_slice.argsort(modeled_slice_draw_index) + .argsort(modeled_slice_draw_index) + .values + ) + + # Also get the slice to take the difference from, since we always take + # the difference in the last past year regardless of what year we order + # the modeled data by. + modeled_last_past_slice = modeled_last_past_year.sel(**slice_dict) + modeled_last_past_slice = modeled_last_past_slice.assign_coords( + draw=modeled_slice_draw_rank + ) + + # Sort the GBD past values for the slice of the last past year by draw + # such that the numbering actually reflects each draw's rank (e.g. draw + # 0 is the minimum and draw 99 is the maximum for 100 draws). + past_slice = past_data.sel(**non_scenario_slice_dict) + past_slice_draw_index = past_slice.coords.dims.index(DimensionConstants.DRAW) + past_slice_draw_rank = ( + past_slice.argsort(past_slice_draw_index).argsort(past_slice_draw_index).values + ) + past_slice = past_slice.assign_coords(draw=past_slice_draw_rank) + + # Take difference of sorted GBD past and sorted modeled past (again, + # just for last year of past). 
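+        # Both slices now carry rank-based ``draw`` coords (GBD ranked by last-past-year
+        # value, modeled ranked by the order year), so this subtraction aligns draws by
+        # rank rather than by their original draw numbers.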
+ diff = past_slice - modeled_last_past_slice + + # Reassign the draw numbers of the modeled data to be consistent with + # the sorted draw values for the order year of the modeled data. + # Now add the difference to shift this modeled data for all years. + modeled_all_years_slice = modeled_data.sel(**slice_dict) + modeled_all_years_slice = modeled_all_years_slice.assign_coords( + draw=modeled_slice_draw_rank + ) + shifted_slice = modeled_all_years_slice + diff + + # Now reassign the draw numbers of the **shifted**-modeled data + # **again** to their original numbers. That is, at this point draw 0 + # for the last past year may not be the minimum value. + shifted_slice = shifted_slice.assign_coords(draw=original_draw_order) + shifted_modeled.loc[slice_dict] = shifted_slice + + return shifted_modeled + + +def unordered_draw_intercept_shift( + modeled_data: xr.DataArray, + past_data: xr.DataArray, + past_end_year_id: int, + shift_from_reference: bool = True, +) -> xr.DataArray: + """Move the modeled_data up by the difference between the modeled and actual past data. + + Move the modeled_data up by the last-year difference between the modeled and actual past + data. + + May raise IndexError: If coordinates or dimensions do not match up + between modeled data and GBD data. + + Notes: + * **Steps:** + + 1) Take the difference between modeled data and GBD-past in the last past year. + 2) Add the difference to the past and forecasted modeled-data such that reordered + draw difference gets applied to the values of its original unordered draw + number. + * **Preconditions:** + + - Expects them to have the same coordinate dimensions for age, sex, and location, + but not year (except the one overlapping year). + - GBD-past and modeled-past have the same number of draws. + - Draw coordinates are zero-indexed integers. + + Args: + modeled_data (xr.DataArray): Estimates based on FHS modeling. Expects past and + forecast, or at least a value for the last past year and forecast years. + past_data (xr.DataArray): Past estimates from GBD to base the shift on. + past_end_year_id (int): Last year of past data from GBD. + shift_from_reference (bool): Optional. Whether to shift values based on difference + from reference scenario only, or to take difference for each scenario. In most + cases these will yield the same results unless scenarios have different value in + last past year than reference. + + Returns: + xr.DataArray: Modeled data that has been shifted to line up with GBD past in the last + year of past data. + """ + # Subset GBD past and modeled data to last past year. + past_data = past_data.sel(**{DimensionConstants.YEAR_ID: past_end_year_id}, drop=True) + + if shift_from_reference and DimensionConstants.SCENARIO in modeled_data.dims: + modeled_last_past_year = modeled_data.sel( + **{ + DimensionConstants.YEAR_ID: past_end_year_id, + DimensionConstants.SCENARIO: ScenarioConstants.REFERENCE_SCENARIO_COORD, + }, + drop=True, + ) + else: + modeled_last_past_year = modeled_data.sel( + **{DimensionConstants.YEAR_ID: past_end_year_id}, + drop=True, + ) + + assert_coords_same_ignoring(past_data, modeled_last_past_year, ["scenario"]) + + diff = modeled_last_past_year - past_data + + shifted_modeled = modeled_data - diff + + return shifted_modeled + + +def max_fanout_intercept_shift( + modeled_data: xr.DataArray, past_data: xr.DataArray, years: YearRange +) -> xr.DataArray: + """Ordered-draw intercept-shift based on fan-out trajectories. 
+ + Differs from ordered_draw_intercept_shift, in that trajectories are determined + by the difference between last forecast year and last past year. + + The function maximizes fan-out via + 1.) ranking the future trajectories and the last past year's values + 2.) connecting the top last past year value to the top trajectry, 2nd to 2nd, etc. + + Assumes there's no scenario dimension. SHOULD ONLY BE DONE IN NORMAL SPACE, + never in any non-linear transformation. + + Args: + modeled_data (xr.DataArray): Estimates based on FHS modeling. Expects past and + forecast, or at least a value for the last past year and forecast years. + past_data (xr.DataArray): Past estimates from GBD to base the shift on. + years (YearRange): first past year:first future year:last future year. + + Returns: + (xr.DataArray): ordered-draw intercept-shifted with maximum fan-out. + """ + assert_coords_same_ignoring( + past_data.sel(year_id=years.past_end, drop=True), + modeled_data.sel(year_id=years.past_end, drop=True), + [], + ) + + shifted_modeled = modeled_data.copy() # this will be modified in place for return + + # we determine draw sort order by trajectory = last year - last past year + trajectories = modeled_data.sel(year_id=years.forecast_end) - modeled_data.sel( + year_id=years.past_end + ) # no more year_id dim + + # these are coords after removing year and draw dims + non_draw_coords = trajectories.drop_vars(DimensionConstants.DRAW).coords + coords = list(non_draw_coords.indexes.values()) + dims = list(non_draw_coords.indexes.keys()) + + for coord in it.product(*coords): + slice_dict = {dims[i]: coord[i] for i in range(len(coord))} + + trajs = trajectories.sel(**slice_dict) # should have only draw dim now + # using argsort once gives the indices that sort the list, + # using it twice gives the rank of each value, from low to high + traj_rank = trajs.argsort().argsort().values + + # from now on we will do some calculations in "rank space". + # ranked_future has year_id/rank dims. + ranked_future = ( + modeled_data.sel(**slice_dict) + .rename({DimensionConstants.DRAW: "rank"}) + .assign_coords(rank=traj_rank) + ) + # each value is labeled by its trajectory rank now, with 0 being the lowest rank. + predicted_last_past_rank = ranked_future.sel( + year_id=years.past_end + ) # 1-D: value is predicted last past year's value, label is future trajectory rank + + # our goal is to allocate the highest trajectory rank to the highest + # observed last past rank, and lowest to lowest, etc. + # so we need to bring the last observed year into rank space as well. + observed_last_past = past_data.sel(**slice_dict, year_id=years.past_end) # draws only + # save the past draw-labels, ordered by rank, so we can convert from rank to draw later + past_draw_order = observed_last_past[DimensionConstants.DRAW].values[ + observed_last_past.argsort() + ] + + # get the rank of the past values + past_rank = observed_last_past.argsort().argsort().values + + # 1-D: value is gbd last past year's value, label is value's rank. + observed_last_past_rank = observed_last_past.rename( + {DimensionConstants.DRAW: "rank"} + ).assign_coords(rank=past_rank) + + # Now observed_last_past_rank and predicted_last_past_rank are both 1-D arrays + # with a "rank" dimension, where observed_last_past_rank is ranked by value, + # and predicted_last_past_rank is ranked by future trajectory. + # The goal now is then to attach the top-ranked future trajectory to the top-ranked + # past value, 2nd-highest to 2nd-highest, and so on. 
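+        # Hypothetical 3-draw example: observed last-past values by rank
+        # [0.08, 0.10, 0.12] and trajectory-ranked predicted last-past values
+        # [0.11, 0.09, 0.10] give diff = [-0.03, 0.01, 0.02], so the steepest
+        # trajectory ends up anchored to the highest observed value, and so on.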
+ diff = observed_last_past_rank - predicted_last_past_rank # the diff to shift + + # this is a safeguard in case xarray changes its default arithmetic behavior + if not (diff["rank"] == observed_last_past_rank["rank"]).all(): + raise ValueError("diff must inherit rank values from observed_last_past_rank") + + # diff added to the future draws in rank space. + ranked_future = diff + ranked_future # has year_id / rank dims + # with this, the top-ranked future trajectory now has the same last past year + # value as that of top-ranked gbd last past year value, completing the shift. + + # another safeguard + if not (ranked_future["rank"] == diff["rank"]).all(): + raise ValueError("ranked_future must inherit rank values from diff") + + # Now need to convert back to draw space. + # ranked_future is now in rank space and labeled by ranks of gbd last past year values, + # and we want to give it back its original draw labels. + # We've saved the rank-order of the draws in the past, so if we sort by rank, then + # the order of the ranks is now the order of the draws in `past_draw_order`, so we can + # apply that variable as the draws + ranked_future = ( + ranked_future.sortby("rank") + .rename({"rank": DimensionConstants.DRAW}) + .assign_coords(draw=past_draw_order) + ) + ranked_future = ranked_future.sortby("draw") + + # prep the shape of ranked_future before inserting into modeled_data + dim_order = modeled_data.sel(**slice_dict).dims + ranked_future = ranked_future.transpose(*dim_order) + + # modify in-place + shifted_modeled.loc[slice_dict] = ranked_future # ~ "re_aligned_future" + + return shifted_modeled diff --git a/gbd_2021/disease_burden_forecast_code/mortality/lib/config_dataclasses.py b/gbd_2021/disease_burden_forecast_code/mortality/lib/config_dataclasses.py new file mode 100644 index 0000000..9c49de7 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/lib/config_dataclasses.py @@ -0,0 +1,57 @@ +"""Dataclasses for capturing the configuration of the pipeline.""" + +from typing import Dict, Optional + +from fhs_lib_cli_tools.lib.fhs_dataclasses import BaseModel +from fhs_lib_file_interface.lib.version_metadata import VersionMetadata +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_year_range_manager.lib.year_range import YearRange + + +class CauseSpecificModelingArguments(BaseModel): + """Common modeling arguments used in mortality cause-specific. + + These attributes represent the "dials" you could turn in this repo which would + influence the modeled aspects of cause-specific mortality. These include the boolean + flags which turn on & off certain features, as well as the attributes which determine the + contents of what's being modeled (like sex and years). + """ + + acause: str + draws: int + drivers_at_reference: bool + fit_on_subnational: bool + gbd_round_id: int + sex_id: int + spline: bool + subnational: bool + version: str + years: YearRange + seed: Optional[int] + + +class CauseSpecificVersionArguments(BaseModel): + """Common version arguments used in mortality cause-specific. + + These attributes generally represent Versions objects containing references to input data. + The ``logspace_conversion_flags`` is the snowflake in this dataclass which is a dictionary + mapping the input data name to whether it should be calculated in log space (e.x. + {"death": True, "sdi": False}). 
+ """ + + versions: Versions + logspace_conversion_flags: Dict[str, bool] + output_scenario: int | None + + +class SumToAllCauseModelingArguments(BaseModel): + """Common modeling arguments used in mortality sum-to-all-cause.""" + + acause: str + agg_version: VersionMetadata + approximation: bool + draws: int + gbd_round_id: int + input_version: VersionMetadata + period: str + output_scenario: int | None diff --git a/gbd_2021/disease_burden_forecast_code/mortality/lib/downloaders.py b/gbd_2021/disease_burden_forecast_code/mortality/lib/downloaders.py new file mode 100644 index 0000000..03fd841 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/lib/downloaders.py @@ -0,0 +1,809 @@ +"""Functions used by run_cod_model.py to import past data. + +Including cause-specific mx data, SEVS, scalars, and covariates. +""" + +from typing import Dict, Iterable, List, Optional, Union + +import fhs_lib_database_interface.lib.query.risk as query_risk +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_database_interface.lib import db_session +from fhs_lib_database_interface.lib.constants import ( + CauseConstants, + CauseRiskPairConstants, + DimensionConstants, + RiskConstants, + ScenarioConstants, +) +from fhs_lib_database_interface.lib.query import age, location +from fhs_lib_database_interface.lib.query.cause import get_acause, get_cause_id +from fhs_lib_database_interface.lib.strategy_set import strategy +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario +from tiny_structured_logger.lib import fhs_logging + +from fhs_pipeline_mortality.lib.config_dataclasses import CauseSpecificModelingArguments + +logger = fhs_logging.get_logger() + +DIRECT_MODELED_PAFS = ["drugs_illicit_direct", "unsafe_sex"] +SCALAR_CAP = 1000000.0 +FLOOR = 1e-28 + + +def validate_cause( + acause: str, + gbd_round_id: int, +) -> None: + """Check whether a cause is in the cod cause list. + + Throws an assertion error if not. + + Args: + acause (str): Cause to check whether it is a modeled cod acause. + gbd_round_id (int): round of GBD to pull cause hierarchy from + + Raises: + ValueError: if acause not in causes. + """ + logger.debug("Validating cause", bindings=dict(acause=acause)) + + with db_session.create_db_session() as session: + causes = strategy.get_cause_set( + session=session, + strategy_id=CauseConstants.FATAL_GK_STRATEGY_ID, + gbd_round_id=gbd_round_id, + ).acause.values + + maternal_cause = strategy.get_cause_set( + session=session, + strategy_id=CauseConstants.PARENT_MATERNAL_STRATEGY_ID, + gbd_round_id=gbd_round_id, + ).acause.values + + ckd_cause = strategy.get_cause_set( + session=session, + strategy_id=CauseConstants.CKD_STRATEGY_ID, + gbd_round_id=gbd_round_id, + ).acause.values + + causes = np.append(causes, maternal_cause) + + causes = np.append(causes, ckd_cause) + + if acause not in causes: + raise ValueError("acause must be a valid modeled COD acause") + + +def validate_version(versions: Versions, past_or_future: str, stage: str) -> None: + """Check whether or not the version to load from exists. + + Throws an assertion error if not. + + Args: + versions (Versions): the versions object to query a version for + stage (str): the metric to check the version under (e.g. 
scalar) + past_or_future (str): past or future of the version + + Raises: + ValueError: if version not valid. + """ + version_path = versions.get(past_or_future=past_or_future, stage=stage).data_path() + + if not version_path.exists(): + raise ValueError(f"Version {version_path} is not valid - path does not exist.") + + +def empty_dem_xarray( + gbd_round_id: int, + locs: Union[int, List[int]], + sex_ids: Union[int, List[int]], + val: float = 0, + draws: int = 100, + start: int = 1990, + end: int = 2040, + scenarios: List[int] = ScenarioConstants.SCENARIOS, +) -> xr.DataArray: + """Build an empty xarray which has all dimensions required for modeling. + + i.e. location_id, age_group_id, year_id, sex_id, draw, & scenario. + + Args: + gbd_round_id (int): round of GBD to build data array around + locs (Union[int, List[int]]): list or array of locations to use as location_id + coordinates + sex_ids (Union[int, List[int]]): the sex_id values to create + val (float): value to fill the array with, default 0. + draws (int): what the length of the draw dimension should be. + start (int): the beginning of the year_id dimension, default 1990 + end (int): the end of the year_id dimension, default 2040 + scenarios (list): the scenarios to create + + Returns: + DataArray: Six-dimensional data array. + """ + demog = dict( + age_group_id=age.get_ages(gbd_round_id=gbd_round_id).age_group_id.values, + year_id=np.arange(start, end + 1), + location_id=locs, + scenario=scenarios, + draw=np.arange(draws), + sex_id=sex_ids, + ) + logger.debug("Building empty demographic array", bindings=demog) + + dims = [ + DimensionConstants.LOCATION_ID, + DimensionConstants.AGE_GROUP_ID, + DimensionConstants.YEAR_ID, + DimensionConstants.SEX_ID, + DimensionConstants.DRAW, + DimensionConstants.SCENARIO, + ] + size = [len(demog[x]) for x in dims] + vals = np.ones(size) * val + dem_array = xr.DataArray(vals, coords=demog, dims=dims) + return dem_array + + +def replace_scenario_dim(da: xr.DataArray, scenarios: List[int]) -> xr.DataArray: + """Discard nonreference scenarios, replacing them with the reference one.""" + if DimensionConstants.SCENARIO in da.dims: + da = da.sel(scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD, drop=True) + return expand_dimensions(da, scenario=scenarios) + + +def enforce_scenario_dim( + modeling_args: CauseSpecificModelingArguments, + covariate: Optional[str], + scenarios: List[int], + da: xr.DataArray, +) -> xr.DataArray: + """Transform `da` so that it has the desired scenarios, enforcing "drivers_at_reference". + + Transform `da` so that it has the desired scenarios, enforcing the "drivers_at_reference" + policy from `modeling_args`. When using drivers_at_reference, the reference scenario may be + duplicated in place of the nonreference scenarios, depending on the covariate. 
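+
+    For example, assuming scenario coordinates like [-1, 0, 1]: with
+    ``drivers_at_reference=True`` and a covariate other than "asfr", the reference
+    scenario's values are copied to every requested scenario; "asfr" keeps its own
+    scenario-specific values.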
+ """ + if ( + modeling_args.drivers_at_reference and covariate != "asfr" + ) or DimensionConstants.SCENARIO not in da.dims: + return replace_scenario_dim(da, scenarios) + else: + return da.sel(scenario=scenarios) + + +def acause_has_scalar(gbd_round_id: int, acause: str) -> bool: + """True iff the given acause "has scalar data" in the gbd_round_id.""" + with db_session.create_db_session() as session: + causes_with_scalars = strategy.get_cause_risk_pair_set( + session=session, + strategy_id=CauseRiskPairConstants.CALC_PAF_STRATEGY_ID, + gbd_round_id=gbd_round_id, + ) + + acauses_with_scalars = causes_with_scalars.cause_id.map(get_acause).unique() + + # NOTE: Change the list of causes-with-scalars from GBD round ID 6 to GBD round ID 7. + acauses_with_scalars_list = acauses_with_scalars.tolist() + + acause_has_scalar = acause in acauses_with_scalars_list + return acause_has_scalar + + +def load_scalar( + modeling_args: CauseSpecificModelingArguments, + versions: Versions, + scenarios: List[int], + log: bool = True, +) -> xr.DataArray: + """Load scalar scenario data as an xarray. + + Args: + modeling_args (CauseSpecificModelingArguments): dataclass containing + cause specific modeling arguments + versions (Versions): the versions object to query a version for + scenarios (List[int]): the scenarios you want to subset scalar to + log (bool): Whether to take the natural log of the values, default True + + Returns: + DataArray: array with scalar data for acause + """ + logger.debug( + "Loading scalars for a cause-sex", + bindings=dict(acause=modeling_args.acause, sex_id=modeling_args.sex_id), + ) + + national_only = not modeling_args.subnational + locs = location.get_location_set( + gbd_round_id=modeling_args.gbd_round_id, national_only=national_only + )[DimensionConstants.LOCATION_ID].tolist() + + validate_cause(acause=modeling_args.acause, gbd_round_id=modeling_args.gbd_round_id) + + future_file_spec = FHSFileSpec( + versions.get(past_or_future="future", stage="scalar"), f"{modeling_args.acause}.nc" + ) + past_file_spec = FHSFileSpec( + versions.get(past_or_future="past", stage="scalar"), f"{modeling_args.acause}.nc" + ) + + if acause_has_scalar(modeling_args.gbd_round_id, modeling_args.acause): + future = open_xr_scenario(file_spec=future_file_spec).sel( + sex_id=modeling_args.sex_id, + year_id=modeling_args.years.forecast_years, + location_id=locs, + ) + future = enforce_scenario_dim(modeling_args, None, scenarios, future) + past = open_xr_scenario(file_spec=past_file_spec).sel( + sex_id=modeling_args.sex_id, + year_id=modeling_args.years.past_years, + location_id=locs, + ) + past = replace_scenario_dim(past, scenarios) + past = resample(past, modeling_args.draws) + future = resample(future, modeling_args.draws) + + da = xr.concat([past, future], dim=DimensionConstants.YEAR_ID) + else: + da = empty_dem_xarray( + gbd_round_id=modeling_args.gbd_round_id, + locs=locs, + sex_ids=[modeling_args.sex_id], + val=1.0, + draws=modeling_args.draws, + start=modeling_args.years.past_start, + end=modeling_args.years.forecast_end, + scenarios=scenarios, + ) + + da = _drop_scenario_in_single_scenario_mode(da, scenarios) + + da.name = "risk_scalar" + da = da.where(da <= SCALAR_CAP).fillna(SCALAR_CAP) + if log: + da = np.log(da) + da.name = "ln_risk_scalar" + return da + + +def load_sdi( + modeling_args: CauseSpecificModelingArguments, + versions: Versions, + scenarios: List[int], + log: bool = False, +) -> xr.DataArray: + """Loads and returns sociodemographic index. 
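+
+    Past SDI is broadcast across the requested scenarios, while future SDI keeps its
+    scenario-specific values whenever a scenario dimension is present.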
+ + Args: + modeling_args (CauseSpecificModelingArguments): dataclass containing + cause specific modeling arguments + versions (Versions): the versions object to query a version for + scenarios (List[int]): the scenarios you want to subset SDI to + log (bool): whether to take the log of SDI, default False + + Returns: + DataArray: array with sdi information. + """ + logger.debug("Loading SDI") + + national_only = not modeling_args.subnational + locs = location.get_location_set( + gbd_round_id=modeling_args.gbd_round_id, national_only=national_only + )[DimensionConstants.LOCATION_ID].tolist() + + future_sdi_spec = FHSFileSpec(versions.get(past_or_future="future", stage="sdi"), "sdi.nc") + future_da = open_xr_scenario(file_spec=future_sdi_spec).sel( + year_id=modeling_args.years.forecast_years, + location_id=locs, + ) + if DimensionConstants.SCENARIO in future_da.dims: + future_da = future_da.sel(scenario=scenarios) + + past_sdi_spec = FHSFileSpec(versions.get(past_or_future="past", stage="sdi"), "sdi.nc") + past_da = open_xr_scenario(file_spec=past_sdi_spec).sel( + year_id=modeling_args.years.past_years, location_id=locs + ) + + past_da = replace_scenario_dim(past_da, scenarios) + + past_da = _drop_scenario_in_single_scenario_mode(past_da, scenarios) + + da = xr.concat([past_da, future_da], dim="year_id") + da = resample(data=da, num_of_draws=modeling_args.draws) + da.name = "sdi" + + if log: + da = np.log(da) + da.name = "ln_sdi" + + return da + + +def load_cod_data( + modeling_args: CauseSpecificModelingArguments, + versions: Versions, + past_or_future: str, + stage: str, + log: bool = True, +) -> xr.DataArray: + """Load in cause-sex specific mortality rate. + + Args: + modeling_args (CauseSpecificModelingArguments): dataclass containing + cause specific modeling arguments + versions (Versions): the versions object to query a version for + past_or_future (str): whether to use past or future + stage (str): the stage to load for + log (bool): Whether to take the natural log of the mortality values. + + Returns: + DataArray: array with acause-sex specific death rate information. + """ + logger.debug( + "Loading cause-sex-specific mortality", + bindings=dict(acause=modeling_args.acause, sex_id=modeling_args.sex_id), + ) + + validate_cause( + acause=modeling_args.acause, + gbd_round_id=versions.get_effective_gbd_round_id(past_or_future, stage), + ) + + cod_file_spec = FHSFileSpec( + versions.get(past_or_future=past_or_future, stage=stage), f"{modeling_args.acause}.nc" + ) + da = open_xr_scenario(file_spec=cod_file_spec) + + if DimensionConstants.ACAUSE in da.dims: + da = da.sel(acause=modeling_args.acause, drop=True) + + # Make an array of means that replicates it for as many as there are draws in the raw data. + # leave the last year off for draws + mean_array = da["mean"].loc[{"year_id": modeling_args.years.past_years[:-1]}] + # Note: Some legacy data does in fact have draws in the mean variable. For those cases, we + # use the draws directly. + # In normal cases, we will expand out the draw dim on the means, to combine them + # with draw-level data from the last past year. 
+ if "draw" not in mean_array.dims: + mean_array = mean_array.expand_dims(draw=range(modeling_args.draws)) + + draw_array = da["value"].loc[{"year_id": [modeling_args.years.past_end]}] + draw_array = resample(data=draw_array, num_of_draws=modeling_args.draws) + da_draw = xr.concat([mean_array, draw_array], dim="year_id") + + locdiv = ("draw", "age_group_id", "year_id", "sex_id") + locidx = np.where(~(da_draw == 0).all(locdiv))[0] + locs = da_draw.location_id.values[locidx] + + agediv = ("draw", "location_id", "year_id", "sex_id") + ageidx = np.where(~(da_draw == 0).all(agediv))[0] + ages = da_draw.age_group_id.values[ageidx] + + demdict = dict(location_id=locs, age_group_id=ages, sex_id=modeling_args.sex_id) + + if modeling_args.acause == "ntd_nema": + da_draw += FLOOR + + if log: + da_draw = np.log(da_draw) + + da_draw_sub = da_draw.loc[demdict] + return da_draw_sub + + +def load_cov( + modeling_args: CauseSpecificModelingArguments, + versions: Versions, + scenarios: List[int], + cov: str, + log: bool, +) -> xr.DataArray: + """Read a covariate and format for cod modeling. + + Fills in the country-level covariates for subnational locations if applicable. Takes + natural log of values if applicable. + + Args: + modeling_args (CauseSpecificModelingArguments): General modeling arguments. + versions (Versions): Catalog of versions having entries for the given `cov`. + scenarios (List[int]): Select these scenarios from future; expand past data to match. + Typically, just let these be the same scenarios that are in the future data. + cov (str): Covariate name to load. + log (bool): Whether to take the natural log of the covariate values. + + Returns: + DataArrray: array with covariate information loaded and formatted + """ + logger.debug( + "Loading sex specific covariate data", + bindings=dict(sex_id=modeling_args.sex_id, cov=cov), + ) + + national_only = not modeling_args.subnational + locs = location.get_location_set( + gbd_round_id=modeling_args.gbd_round_id, national_only=national_only + )[DimensionConstants.LOCATION_ID].tolist() + + cov_file = "mort_rate" if cov == "hiv" else cov + future_cov_file_spec = FHSFileSpec( + versions.get(past_or_future="future", stage=cov), f"{cov_file}.nc" + ) + + if cov in versions["past"].keys(): + past_cov_file_spec = FHSFileSpec( + versions.get(past_or_future="past", stage=cov), f"{cov_file}.nc" + ) + + raw_past = open_xr_scenario(file_spec=past_cov_file_spec).sel( + year_id=modeling_args.years.past_years, location_id=locs + ) + + raw_past = replace_scenario_dim(raw_past, scenarios) + + raw_future = open_xr_scenario(file_spec=future_cov_file_spec).sel( + year_id=modeling_args.years.forecast_years, location_id=locs + ) + + raw_future = enforce_scenario_dim(modeling_args, cov, scenarios, raw_future) + + raw = xr.concat([raw_past, raw_future], dim="year_id") + else: + # don't select locations until later because some covs only read in + # from the past may not have all locations? 
+ raw = open_xr_scenario(file_spec=future_cov_file_spec).sel( + year_id=modeling_args.years.years + ) + + raw = resample(data=raw, num_of_draws=modeling_args.draws) + + badkeys = np.setdiff1d(list(raw.coords.keys()), list(raw.coords.indexes.keys())) + for k in badkeys: + raw = raw.drop_vars(k) + + raw = _remove_aggregate_age_groups(raw) + raw = _project_sex_id(raw=raw, sex_id=modeling_args.sex_id) + + if "location_id" in raw.coords.keys(): + raw = raw.loc[{"location_id": np.intersect1d(locs, raw["location_id"].values)}] + + # NOTE this was not previously enforcing the scenario dim in the case where we're not using + # drivers_at_reference. + raw = enforce_scenario_dim(modeling_args, cov, scenarios, raw) + + raw = _drop_scenario_in_single_scenario_mode(raw, scenarios) + + if log: + raw = raw.where(raw >= FLOOR).fillna(FLOOR) + raw = np.log(raw) + + return raw + + +def load_sev( + modeling_args: CauseSpecificModelingArguments, + versions: Versions, + scenarios: List[int], + rei: str, + log: bool, + include_uncertainty: bool = False, +) -> xr.DataArray: + """Read in summary exposure value information. + + Args: + modeling_args (CauseSpecificModelingArguments): dataclass containing + cause specific modeling arguments + versions (Versions): the versions object to query a version for + scenarios (List[int]): the scenarios to subset SEV to + rei (str): Risk, etiology, or impairment to load. + log (bool): whether to take the natural log of sev values + include_uncertainty (bool): whether to include uncertainty + + Returns: + DataArray: xarray with sev data loaded and formatted + """ + logger.debug( + "Reading risk-sex specific SEV", + bindings=dict(rei=rei, sex_id=modeling_args.sex_id), + ) + national_only = not modeling_args.subnational + locs = location.get_location_set( + gbd_round_id=modeling_args.gbd_round_id, national_only=national_only + )[DimensionConstants.LOCATION_ID].tolist() + + future_file_spec = FHSFileSpec( + versions.get(past_or_future="future", stage="sev"), f"{rei}.nc" + ) + da = open_xr_scenario(file_spec=future_file_spec).sel( + sex_id=modeling_args.sex_id, location_id=locs + ) + + da = enforce_scenario_dim(modeling_args, None, scenarios, da) + + if "sev" in versions["past"].keys(): + past_file_spec = FHSFileSpec( + versions.get(past_or_future="past", stage="sev"), f"{rei}.nc" + ) + past_da = open_xr_scenario(file_spec=past_file_spec).sel( + sex_id=modeling_args.sex_id, + location_id=locs, + year_id=modeling_args.years.past_years, + ) + + past_da = replace_scenario_dim(past_da, scenarios) + + da = da.sel(year_id=modeling_args.years.forecast_years) + da = xr.concat([past_da, da], dim="year_id") + else: + da = da.sel(year_id=modeling_args.years.years) + + if not include_uncertainty: + da = da.mean("draw").expand_dims({"draw": np.arange(modeling_args.draws)}) + else: + da = resample(data=da, num_of_draws=modeling_args.draws) + + if log: + da = da.where(da >= FLOOR).fillna(FLOOR) + da = np.log(da) + + da = _drop_scenario_in_single_scenario_mode(da, scenarios) + + single_coords = np.setdiff1d(list(da.coords.keys()), da.dims) + da = da.drop_vars(single_coords) + return da + + +def load_paf_covs( + modeling_args: CauseSpecificModelingArguments, + versions: Versions, + scenarios: List[int], + listonly: bool = False, + include_uncertainty: bool = False, +) -> xr.Dataset: + """Return a Dataset of sev data for a cause or alternatively a list of applicable sevs. 
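+
+    Only most-detailed risks are used as covariates; risks whose PAFs are modeled
+    directly (see ``DIRECT_MODELED_PAFS``) are excluded.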
+ + Args: + modeling_args (CauseSpecificModelingArguments): dataclass containing + cause specific modeling arguments + versions (Versions): the versions object to query a version for + scenarios (List[int]): the scenarios you want to subset for PAFs + listonly (bool): if True, return just a list of applicable sevs, + otherwise return the full dataset + include_uncertainty (bool): Whether to include past uncertainty + (otherwise just copy the mean) + + Returns: + Dataset: dataset whose datavars are the arrays for each relevant sev + """ + logger.debug( + "Reading cause-sex-specific PAF", + bindings=dict(acause=modeling_args.acause, sex_id=modeling_args.sex_id), + ) + + with db_session.create_db_session() as session: + paf1_risk_pairs = strategy.get_cause_risk_pair_set( + session=session, + strategy_id=CauseRiskPairConstants.PAF_OF_ONE_STRATEGY_ID, + gbd_round_id=modeling_args.gbd_round_id, + ) + + cause_id = get_cause_id(acause=modeling_args.acause) + rei_ids = paf1_risk_pairs.query("cause_id == @cause_id").rei_id.values + logger.debug( + f"Reading all risk specifc PAFs for {modeling_args.acause}", + bindings=dict(acause=modeling_args.acause, cause_id=cause_id, rei_ids=rei_ids), + ) + + all_reis = [query_risk.get_rei(int(rei_id)) for rei_id in rei_ids] + + most_detailed_risks = strategy.get_risk_set( + session=session, + strategy_id=RiskConstants.FATAL_DETAILED_STRATEGY_ID, + gbd_round_id=modeling_args.gbd_round_id, + )[DimensionConstants.REI].unique() + + # subset to most detailed reis + reis = [ + rei + for rei in all_reis + if (rei in most_detailed_risks) and not (rei in DIRECT_MODELED_PAFS) + ] + + if listonly: + return reis + + reis_to_log = strategy.get_cause_risk_pair_set( + session=session, + strategy_id=CauseRiskPairConstants.SEV_LOG_TRANSFORM_STRATEGY_ID, + gbd_round_id=modeling_args.gbd_round_id, + )[DimensionConstants.REI_ID].unique() + reis_to_log = [query_risk.get_rei(int(rei_id)) for rei_id in reis_to_log] + + ds = xr.Dataset() + for r in reis: + log_rei = r in reis_to_log + da = load_sev( + modeling_args=modeling_args, + versions=versions, + scenarios=scenarios, + rei=r, + log=log_rei, + include_uncertainty=include_uncertainty, + ) + ds[r] = da + + return ds + + +def load_cod_dataset( + addcovs: List[str], + modeling_args: CauseSpecificModelingArguments, + versions: Versions, + logspace_conversion_flags: Dict[str, bool], + scenarios: Optional[Iterable[int]], + sev_covariate_draws: bool = False, +) -> xr.Dataset: + """Load in acause-sex specific mortality rate. + + Along with scalars, sevs, sdi, and other covariates if applicable. + + Args: + addcovs (list[str]): list of cause-specific covariates to add to the + dataset + modeling_args (CauseSpecificModelingArguments): dataclass containing + cause specific modeling arguments + versions (Versions): set of versions passed in at runtime. Will either be OOS versions + or standard ones based on user input ``oos`` CLI arg. + logspace_conversion_flags (Dict[str, bool]): Mapping of version name to whether or not + it should be logged. + sev_covariate_draws (bool): Whether to include draws of the past SEVs + used covariates or simply take the mean. 
+ + Returns: + Dataset: xarray dataset with cod mortality rate and scalar, sev, sdi, + and other covariate information + """ + logger.debug( + "Loading cause-sex-specific mortality dataset", + bindings=dict(acause=modeling_args.acause, sex_id=modeling_args.sex_id), + ) + + # check validity of inputs + validate_cause(acause=modeling_args.acause, gbd_round_id=modeling_args.gbd_round_id) + for past_or_future in ["past", "future"]: + for stage in list(versions[past_or_future].keys()): + if past_or_future != "future" and stage != "death": + # don't check the output; it doesn't exist yet. + validate_version(versions=versions, past_or_future=past_or_future, stage=stage) + + # parent_id, level info is needed when fitting on nationals only + national_only = not modeling_args.subnational + loc_df = location.get_location_set( + gbd_round_id=modeling_args.gbd_round_id, national_only=national_only + ) + regdf = loc_df[["location_id", "region_id", "super_region_id", "parent_id", "level"]] + if not modeling_args.subnational: + regdf = loc_df.query("level==3")[["location_id", "region_id", "super_region_id"]] + regdf.set_index("location_id", inplace=True) + year_list = modeling_args.years.years + time_array = xr.DataArray( + year_list - modeling_args.years.past_start, + dims=["year_id"], + coords=[year_list], + ) + codda = load_cod_data( + modeling_args=modeling_args, + versions=versions, + past_or_future="past", + stage="death", + log=logspace_conversion_flags["death"], + ) + agevals = codda.coords["age_group_id"].values + locvals = codda.coords["location_id"].values + demdict = dict(year_id=year_list, age_group_id=agevals, location_id=locvals) + + sdi_scenarios = scenarios or _decide_scenario_to_sel(versions, "future", "sdi") + scalar_scenarios = scenarios or _decide_scenario_to_sel(versions, "future", "scalar") + + ds = xr.Dataset( + dict( + y=codda, + sdi=load_sdi( + modeling_args=modeling_args, + versions=versions, + scenarios=sdi_scenarios, + log=logspace_conversion_flags["sdi"], + ), + ln_risk_scalar=load_scalar( + modeling_args=modeling_args, + versions=versions, + scenarios=scalar_scenarios, + log=logspace_conversion_flags["scalar"], + ), + intercept=xr.DataArray(1), + time_var=time_array, + ) + ) + + for cov in addcovs: + cov_scenarios = scenarios or _decide_scenario_to_sel(versions, "future", cov) + ds[cov] = load_cov( + modeling_args=modeling_args, + versions=versions, + scenarios=cov_scenarios, + cov=cov, + log=logspace_conversion_flags[cov], + ) + + sev_scenarios = scenarios or _decide_scenario_to_sel(versions, "future", "sev") + ds.update( + load_paf_covs( + modeling_args=modeling_args, + versions=versions, + scenarios=sev_scenarios, + include_uncertainty=sev_covariate_draws, + ) + ) + + ds_sub = ds.loc[demdict] + ds_sub.update(xr.Dataset(regdf)) + ds_sub.y.values[(ds_sub.y == -np.inf).values] = np.nan + ds_sub.ln_risk_scalar.values[(np.isnan(ds_sub.ln_risk_scalar.values))] = 0.0 + + # select just non-aggregate values + loc_vals = list( + set(regdf.reset_index().location_id.values.tolist()) + & set(ds_sub.location_id.values.tolist()) + ) + ds_sub = ds_sub.loc[{"location_id": loc_vals}] + return ds_sub + + +def _remove_aggregate_age_groups(raw: xr.DataArray) -> xr.DataArray: + """Remove aggregate age-groups, or if we only have aggregate age-groups, keep just one.""" + agg_ages = [22, 27] + if "age_group_id" in raw.coords.keys(): + # drop point coordinate + if len(raw["age_group_id"]) == 1: + return raw.squeeze("age_group_id", drop=True) + # keep all-ages, not age-standardized + elif 
raw.age_group_id.values.tolist() == agg_ages: + return raw.loc[{"age_group_id": 22}] + # given the choice between age-specific or all-ages, choose age-specific + else: + remaining_ages = np.setdiff1d(raw["age_group_id"].values, agg_ages) + return raw.loc[{"age_group_id": remaining_ages}] + else: + return raw + + +def _project_sex_id(raw: xr.DataArray, sex_id: int) -> xr.DataArray: + """Project data down to a single sex_id, knowing sex_id 3 represents data for either.""" + if "sex_id" in raw.coords.keys(): + if sex_id in raw["sex_id"].values: + return raw.sel(sex_id=sex_id, drop=True) + elif raw["sex_id"].values == 3: + return raw.squeeze("sex_id", drop=True) + else: + print("this covariate doesn't have the sex_id you're looking for") + raise SystemExit + else: + return raw + + +def _decide_scenario_to_sel(versions: Versions, past_or_future: str, stage: str) -> List[int]: + if versions.get(past_or_future, stage).scenario is not None: + return [versions.get(past_or_future, stage).scenario] + return ScenarioConstants.SCENARIOS + + +def _drop_scenario_in_single_scenario_mode( + da: xr.DataArray, scenarios: List[int] +) -> xr.DataArray: + if DimensionConstants.SCENARIO in da.dims and len(da[DimensionConstants.SCENARIO]) == 1: + da = da.sel(scenario=scenarios[0], drop=True) + return da diff --git a/gbd_2021/disease_burden_forecast_code/mortality/lib/get_fatal_causes.py b/gbd_2021/disease_burden_forecast_code/mortality/lib/get_fatal_causes.py new file mode 100644 index 0000000..c4dbda9 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/lib/get_fatal_causes.py @@ -0,0 +1,93 @@ +import pandas as pd +from fhs_lib_database_interface.lib import db_session +from fhs_lib_database_interface.lib.constants import ( + CauseConstants, + DimensionConstants, + FHSDBConstants, +) +from fhs_lib_database_interface.lib.query import cause +from fhs_lib_database_interface.lib.strategy_set import strategy + +from fhs_pipeline_mortality.lib.make_hierarchies import ( + include_up_hierarchy, + make_hierarchy_tree, +) + +ROOT_CAUSE_ID = 294 +CAUSE_COLUMNS_OF_INTEREST = [ + DimensionConstants.ACAUSE, + DimensionConstants.CAUSE_ID, + DimensionConstants.PARENT_ID_COL, + DimensionConstants.LEVEL_COL, +] + + +def get_fatal_causes_df(gbd_round_id: int) -> pd.DataFrame: + """Get the fatal cause dataframe for the input ``gbd_round_id``. 
+ + Returns: + pd.DataFrame: A dataframe containing only fatal causes and the following columns: + 'acause', 'cause_id', 'parent_id', 'level' + """ + cause_hierarchy = cause.get_cause_hierarchy(gbd_round_id=gbd_round_id) + all_causes = cause_hierarchy.copy(deep=True)[CAUSE_COLUMNS_OF_INTEREST] + + fatal_subset = _get_fatal_subset( + gbd_round_id=gbd_round_id, + cause_hierarchy=cause_hierarchy, + all_causes=all_causes, + ) + + # Subset the all-causes dataframe to just the fatal cause IDs + fatal_causes_df = all_causes[all_causes.cause_id.isin(fatal_subset)] + + _recode_cause_parents(fatal_causes_df) + + return fatal_causes_df + + +def _get_fatal_subset( + gbd_round_id: int, + cause_hierarchy: pd.DataFrame, + all_causes: pd.DataFrame, +) -> list[int]: + with db_session.create_db_session(FHSDBConstants.FORECASTING_DB_NAME) as session: + # Pull fatal cause IDs + fatal_cause_ids = strategy.get_cause_set( + session=session, + strategy_id=CauseConstants.FATAL_GK_STRATEGY_ID, + cause_set_id=CauseConstants.FHS_CAUSE_SET_ID, + gbd_round_id=gbd_round_id, + )[DimensionConstants.CAUSE_ID].values + + # Get the subset of fatal cause IDs + cause_tree, node_map = make_hierarchy_tree(cause_hierarchy, ROOT_CAUSE_ID, "cause_id") + fatal_subset = include_up_hierarchy(cause_tree, node_map, fatal_cause_ids) + + _add_special_causes_to_fatal_subset(fatal_subset, all_causes) + + return fatal_subset + + +def _add_special_causes_to_fatal_subset( + fatal_subset: list[int], all_causes: pd.DataFrame +) -> None: + """In-place update ``fatal_subset`` with special fatal causes.""" + # Isolate the fatal maternal cause IDs and add them to the running list of fatal causes + maternal_subset = all_causes[all_causes.acause.str.startswith("maternal_")][ + "cause_id" + ].tolist() + fatal_subset += maternal_subset + + # Isolate the fatal CKD cause IDs and add them to the running list of fatal causes + ckd_subset = all_causes[all_causes.acause.str.startswith("ckd_")]["cause_id"].tolist() + fatal_subset += ckd_subset + + +def _recode_cause_parents(fatal_causes_df: pd.DataFrame) -> None: + """In-place update ``fatal_causes_df`` with custom cause parent remappings.""" + # Update malaria to level 2 cause (malaria is a lot larger than the other + # children of _ntd), will change it back to lvl3 cause as a child of _ntd + # in stage 5. 
+    fatal_causes_df.loc[fatal_causes_df.acause == "malaria", "level"] = 2
+    fatal_causes_df.loc[fatal_causes_df.acause == "malaria", "parent_id"] = 295
diff --git a/gbd_2021/disease_burden_forecast_code/mortality/lib/intercept_shift.py b/gbd_2021/disease_burden_forecast_code/mortality/lib/intercept_shift.py
new file mode 100644
index 0000000..5d59cb1
--- /dev/null
+++ b/gbd_2021/disease_burden_forecast_code/mortality/lib/intercept_shift.py
@@ -0,0 +1,128 @@
+from typing import Callable, List, Union
+
+import xarray as xr
+from fhs_lib_data_transformation.lib.resample import resample
+from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec, VersionMetadata
+from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario
+from fhs_lib_year_range_manager.lib.year_range import YearRange
+from tiny_structured_logger.lib import fhs_logging
+
+DEMOGRAPHIC_DIMS = ["age_group_id", "sex_id", "location_id"]
+
+logger = fhs_logging.get_logger()
+
+
+def open_xr_as_dataarray(file: FHSFileSpec, data_var: str = "value") -> xr.DataArray:
+    """Open an xarray file and return a DataArray (not Dataset): filter to data_var."""
+    data = open_xr_scenario(file)
+    if isinstance(data, xr.Dataset):
+        data = data[data_var]
+    return data
+
+
+def select_coords_by_dataarray(
+    data: xr.DataArray, selective_da: xr.DataArray, dims: List[str]
+) -> xr.DataArray:
+    """Filter some dims of data so that it has the same coords as selective_da.
+
+    Only affects the dims mentioned in "dims". On these dims, the `data` array will have its
+    coordinates narrowed so that they match the same dim in selective_da.
+    """
+    selector = {dim: selective_da[dim].values for dim in dims}
+    try:
+        return data.sel(**selector)
+    except KeyError as err:
+        raise IndexError(err) from err
+
+
+def intercept_shift_draws(
+    preds: xr.DataArray,
+    acause: str,
+    past_version: Union[str, VersionMetadata],
+    gbd_round_id: int,
+    years: YearRange,
+    draws: int,
+    shift_function: Callable,
+) -> xr.DataArray:
+    """Load past data and use it to apply an intercept shift to preds at the draw level.
+
+    Args:
+        preds (xr.DataArray): The raw predictions to intercept shift.
+        acause (str): The short name of the cause to intercept shift predictions of.
+        past_version (Union[str, VersionMetadata]): The version of "past", i.e., GBD, data
+            to load past-year estimates from.
+        gbd_round_id (int): The numeric ID of the relevant GBD round.
+        years (YearRange): The forecasting time series year range.
+        draws (int): The number of draws to resample both GBD data and raw predictions to.
+        shift_function (Callable): The function for actually executing the intercept shift,
+            must take 3 arguments: 1) ``preds``, 2) ``past_data``, which is read in based on
+            ``past_version`` and other parameters, where ``preds`` and ``past_data`` have
+            their draws resampled to the same number of draws: ``draws``, 3) ``years.past_end``
+
+    Notes:
+        Preconditions: Expects past data to *include* the same coordinate dimensions for age,
+        sex, and location, but for year they should have one overlapping year. GBD-past
+        and modeled-past will be resampled to the given number of draws.
+
+    Returns:
+        xr.DataArray: The intercept-shifted predictions.
+
+    Raises:
+        IndexError: If coordinates do not match up between modeled data and GBD data.
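+
+    Example:
+        A minimal sketch of a ``shift_function`` (illustrative only, not one of the
+        pipeline's actual shift functions) that applies an additive draw-level shift
+        anchored at the last past year::
+
+            def additive_shift(modeled_data, past_data, past_end_year_id):
+                gap = past_data.sel(year_id=past_end_year_id, drop=True) - modeled_data.sel(
+                    year_id=past_end_year_id, drop=True
+                )
+                return modeled_data + gap
+
+        With this, ``intercept_shift_draws(preds, acause, past_version, gbd_round_id,
+        years, draws, shift_function=additive_shift)`` would shift every draw of ``preds``
+        so that it matches the past data in ``years.past_end``.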
+ """ + modeled_data = resample(preds, draws) + + if isinstance(past_version, VersionMetadata): + # The next line is a hint to the type checker + past_version = past_version # type: VersionMetadata + else: + past_version = VersionMetadata.make( + data_source=gbd_round_id, + epoch="past", + stage="death", + version=past_version, + root_dir="int", + ) + + past_file = FHSFileSpec(past_version, f"{acause}_hat.nc") + + raw_past_data = open_xr_as_dataarray(past_file) + + # Align DataArrays by throwing away known unimportant dimensions. + modeled_data = eliminate_dims(modeled_data, ["acause", "rei", "rei_id", "cause_id"]) + raw_past_data = eliminate_dims(raw_past_data, ["acause", "rei_id", "cause_id"]) + + # Some causes have more age groups/sex groups in the past data. + # Make sure past data have the same coordinates as the modeled data. + raw_past_data = select_coords_by_dataarray(raw_past_data, modeled_data, DEMOGRAPHIC_DIMS) + past_data = resample(raw_past_data, draws) + + shifted = shift_function( + modeled_data=modeled_data, + past_data=past_data, + past_end_year_id=years.past_end, + ) + + return shifted + + +def eliminate_dims(data: xr.DataArray, dims_to_eliminate: List[str]) -> xr.DataArray: + """Drop dims and coords of data, for each named dim.""" + for dim in dims_to_eliminate: + data = eliminate_dim(data, dim) + return data + + +def eliminate_dim(data: xr.DataArray, dim: str) -> xr.DataArray: + """Drop the given dim and its coords, provided it is single-valued. + + Drops whatever is present, and doesn't crash if the dim is missing, or the coord is + missing, or both. Crashes if dim is present with more than one coord. + + Handles cases where the dim is present as a "point-coord" or as an ordinary dimensional + coord with a single value. + """ + if dim in data.dims: + return data.squeeze(dim, drop=True) + if dim in data.coords: + return data.drop_vars(dim) + return data diff --git a/gbd_2021/disease_burden_forecast_code/mortality/lib/make_all_cause.py b/gbd_2021/disease_burden_forecast_code/mortality/lib/make_all_cause.py new file mode 100644 index 0000000..41c9994 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/lib/make_all_cause.py @@ -0,0 +1,468 @@ +"""FHS Pipeline Mortality Aggregated Over Causes. + +A module that makes a new version of mortality that +consolidates data from externally (e.g. shocks and HIV) and internally modeled +causes (e.g. cvd_ihd -- things modeled with GK). Parent causes to shocks and HIV +are re-aggregated to incorporate HIV and shocks data. + +This module should be run every time a new version of squeezed mortality, +shocks, or HIV is created. This new version directory will contain +**TRUE** all-cause mortality rate and risk-attributable burden. + +This script takes squeezed mortality data from +``FILEPATH``, + +hiv data from +``FILEPATH``, + +and shocks data from +``FILEPATH``, + +and creates a symlink to the files, or re-aggregated files as needed, in +``FILEPATH``. + +The exported ``_all.nc`` file is meant for the population code that produces +``population.nc`` for the subsequent pipeline. 
+""" + +from typing import Iterable, Optional, Set, Tuple + +import pandas as pd +import xarray as xr +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_database_interface.lib.constants import ( + CauseConstants, + DimensionConstants, + StageConstants, +) +from fhs_lib_database_interface.lib.query import cause +from fhs_lib_file_interface.lib.provenance import ProvenanceManager +from fhs_lib_file_interface.lib.version_metadata import ( + FHSDirSpec, + FHSFileSpec, + VersionMetadata, +) +from fhs_lib_file_interface.lib.versioning import validate_versions_scenarios +from fhs_lib_file_interface.lib.xarray_wrapper import ( + copy_xr_setting_scenario, + open_xr_scenario, + save_xr_scenario, +) +from tiny_structured_logger.lib.fhs_logging import get_logger + +logger = get_logger() + +HIV_ACAUSES = ("hiv",) +OTH_PAND = "_oth_pand" +NTD = "_ntd" +MALARIA = "malaria" + + +def add_external_causes( + covid_scalar_acauses: Iterable[str], + covid_scalar_version: Optional[VersionMetadata], + covid_version: VersionMetadata, + final_version: VersionMetadata, + gbd_round_id: int, + draws: int, + hiv_version: VersionMetadata, + shocks_version: VersionMetadata, + squeezed_mx_version: VersionMetadata, + output_scenario: Optional[int], +) -> None: + """Copy externally-modeled causes into the ``final_version``.""" + # Validate the final_version against the output_scenario + validate_versions_scenarios( + versions=[final_version], + output_scenario=output_scenario, + output_epoch_stages=[("future", "death")], + ) + + all_deaths_causes = _get_all_deaths_causes( + covid_version=covid_version, gbd_round_id=gbd_round_id + ) + + squeezed_mx_acause, _, _ = _identify_causes_for_aggregation( + covid_scalar_acauses=covid_scalar_acauses, + covid_scalar_version=covid_scalar_version, + covid_version=covid_version, + gbd_round_id=gbd_round_id, + all_deaths_causes=all_deaths_causes, + ) + + for acause in CauseConstants.SHOCKS_ACAUSES: + _include_acause_in_version(shocks_version, final_version, acause, draws) + + for acause in HIV_ACAUSES: + _include_acause_in_version(hiv_version, final_version, acause, draws) + + _include_acause_in_version( + covid_version, final_version, CauseConstants.COVID_ACAUSE, draws + ) + + for acause in squeezed_mx_acause: + _include_acause_in_version(squeezed_mx_version, final_version, acause, draws) + + _add_malaria_into_ntd(squeezed_mx_version, final_version) + + if covid_scalar_version is not None: + _apply_covid_scalar( + covid_scalar_version, final_version, squeezed_mx_version, covid_scalar_acauses + ) + + +def _get_all_deaths_causes( + covid_version: VersionMetadata, + gbd_round_id: int, +) -> Set[str]: + """Return the set of all deaths causes, having removed any special causes.""" + all_deaths_causes = set( + cause.get_stage_cause_set( + StageConstants.DEATH, + include_aggregates=True, + gbd_round_id=gbd_round_id, + ) + ) + + # Remove OTH_PAND from cause hierarchy until we get its data from the covid team (do not + # know the date). We'll have this cause in our results for the DALY paper submission. + all_deaths_causes -= {OTH_PAND} + + return all_deaths_causes + + +def _identify_causes_for_aggregation( + covid_scalar_acauses: Iterable[str], + covid_scalar_version: Optional[VersionMetadata], + covid_version: VersionMetadata, + gbd_round_id: int, + all_deaths_causes: Set[str], +) -> Tuple[Set[str], Set[str], pd.DataFrame]: + """Return the collections of causes for aggregation. 
+
+    Returns:
+        Tuple[Set[str], Set[str], pd.DataFrame]: the squeezed_mx_acause,
+            external_parents_acause, and cause_hierarchy_set collections, in that order
+    """
+    (
+        cause_hierarchy_external_model,
+        cause_hierarchy_set,
+    ) = _get_cause_hierarchy_external_model_and_set(
+        gbd_round_id=gbd_round_id,
+        covid_scalar_acauses_externally_modeled=covid_scalar_version is not None,
+        covid_scalar_acauses=covid_scalar_acauses,
+        covid_externally_modeled=True,
+    )
+    covid_cause_set = {CauseConstants.COVID_ACAUSE}
+    covid_scalar_acauses_set = _get_covid_scalar_cause_set(
+        covid_scalar_acauses=covid_scalar_acauses,
+        covid_scalar_version=covid_scalar_version,
+    )
+
+    external_parents_acause, squeezed_mx_acause = _split_external_from_squeezed_causes(
+        cause_hierarchy_external_model=cause_hierarchy_external_model,
+        all_deaths_causes=all_deaths_causes,
+        covid_cause_set=covid_cause_set,
+        covid_scalar_acauses_set=covid_scalar_acauses_set,
+    )
+
+    return squeezed_mx_acause, external_parents_acause, cause_hierarchy_set
+
+
+def _include_acause_in_version(
+    input_version: VersionMetadata,
+    output_version: VersionMetadata,
+    acause: str,
+    draws: int,
+) -> None:
+    """Copy the acause xarray file from input_version to output_version.
+
+    As an optimization, if both have the same scenario configuration, we can just symlink.
+    """
+    input_filespec = FHSFileSpec(input_version, f"{acause}.nc")
+    # construct dirspec for output version
+    output_dirspec = FHSDirSpec(version_metadata=output_version)
+    if output_version.scenario != input_version.scenario:
+        copy_xr_setting_scenario(
+            input_filespec=input_filespec,
+            output_dirspec=output_dirspec,
+            draws=draws,
+            new_filename=None,
+        )
+    else:
+        logger.info(f"Symlinking: {input_filespec.data_path()} into new path.")
+        output_filespec = FHSFileSpec(output_version, f"{acause}.nc")
+        ProvenanceManager.symlink(input_filespec, output_filespec, force=True)
+
+
+def _add_malaria_into_ntd(
+    squeezed_mx_version: VersionMetadata, final_version: VersionMetadata
+) -> None:
+    """Create an ntd output by summing the malaria and ntd inputs.
+
+    Why? Malaria is treated as a level 2 cause in stage 2 (i.e., not a child of _ntd);
+    here we change it back to a level 3 child of _ntd by adding malaria into _ntd.
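+
+    Example:
+        A rough sketch of the arithmetic with toy arrays (illustrative coordinates and
+        values only)::
+
+            ntd = xr.DataArray([0.10, float("nan")], dims=["age_group_id"], coords=[[2, 3]])
+            malaria = xr.DataArray([0.20], dims=["age_group_id"], coords=[[2]])
+            sum(da.fillna(0.0) for da in xr.broadcast(ntd, malaria))
+            # -> values [0.30, 0.00]: coordinates are unioned and NaNs do not propagate.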
+ """ + ntd = open_xr_scenario(FHSFileSpec(squeezed_mx_version, f"{NTD}.nc")) + malaria = open_xr_scenario(FHSFileSpec(squeezed_mx_version, f"{MALARIA}.nc")) + + def square_and_sum(a: xr.DataArray, b: xr.DataArray) -> xr.DataArray: + return sum(data.fillna(0.0) for data in xr.broadcast(a, b)) + + ntd_sum = square_and_sum(ntd, malaria) + + save_xr_scenario( + ntd_sum, + FHSFileSpec(final_version, f"{NTD}.nc"), + metric="rate", + space="identity", + ) + + +def _apply_covid_scalar( + covid_scalar_version: VersionMetadata, + final_version: VersionMetadata, + squeezed_mx_version: VersionMetadata, + covid_scalar_acauses: Iterable[str] = (), +) -> None: + """Apply scalar to pertussis, measles, maternal_indirect and lri.""" + for acause in covid_scalar_acauses: + acause_da = open_xr_scenario(FHSFileSpec(squeezed_mx_version, f"{acause}.nc")) + covid_scalar = open_xr_scenario(FHSFileSpec(covid_scalar_version, f"{acause}.nc")) + + covid_scalar = covid_scalar.sel( + location_id=list( + set(acause_da.location_id.values) & set(covid_scalar.location_id.values) + ), + age_group_id=acause_da.age_group_id.values, + ) + + covid_scalar = resample(covid_scalar, len(acause_da.draw.values)) + if acause == "lri": + hib_filename = "lri_hib.nc" + pneumo_filename = "lri_pneumo.nc" + non_pneumo_non_hib_filename = "lri_non_pneumo_non_hib.nc" + + # read in lri children + non_pneumo_non_hib = open_xr_scenario( + FHSFileSpec(squeezed_mx_version, non_pneumo_non_hib_filename) + ) + pneumo = open_xr_scenario(FHSFileSpec(squeezed_mx_version, pneumo_filename)) + hib = open_xr_scenario(FHSFileSpec(squeezed_mx_version, hib_filename)) + + hib = resample(hib, len(acause_da.draw.values)) + pneumo = resample(pneumo, len(acause_da.draw.values)) + non_pneumo_non_hib = resample(non_pneumo_non_hib, len(acause_da.draw.values)) + + # subtract the difference from non_pneumo_non_hib + scaled_acause_da = acause_da * covid_scalar + diff = acause_da - scaled_acause_da + remaining_draws = non_pneumo_non_hib - diff + scaled_non_pneumo_non_hib = xr.where(remaining_draws >= 0, remaining_draws, 0) + + # squeeze hib and pneumo to the scaled lri + children_broadcasted = xr.broadcast(hib, pneumo, scaled_non_pneumo_non_hib) + children_broadcasted = [data.fillna(0.0) for data in children_broadcasted] + children_sum = sum(children_broadcasted) + ratio = scaled_acause_da / children_sum + squeezed_hib = ratio * hib + squeezed_pneumo = ratio * pneumo + squeezed_non_pneumo_non_hib = ratio * scaled_non_pneumo_non_hib + + year_other = [ + x for x in hib.year_id.values if x not in squeezed_hib.year_id.values + ] + hib_other = hib.sel(year_id=year_other) + pneumo_other = pneumo.sel(year_id=year_other) + non_pneumo_non_hib_other = non_pneumo_non_hib.sel(year_id=year_other) + squeezed_hib_all_years = xr.concat([squeezed_hib, hib_other], dim="year_id") + squeezed_pneumo_all_years = xr.concat( + [squeezed_pneumo, pneumo_other], dim="year_id" + ) + squeezed_non_pneumo_non_hib_all_years = xr.concat( + [squeezed_non_pneumo_non_hib, non_pneumo_non_hib_other], dim="year_id" + ) + + hib_file_spec = FHSFileSpec(final_version, hib_filename) + pneumo_file_spec = FHSFileSpec(final_version, pneumo_filename) + non_pneumo_non_hib_file_spec = FHSFileSpec( + final_version, non_pneumo_non_hib_filename + ) + ProvenanceManager.remove(hib_file_spec) + ProvenanceManager.remove(pneumo_file_spec) + ProvenanceManager.remove(non_pneumo_non_hib_file_spec) + + save_xr_scenario( + squeezed_hib_all_years, + hib_file_spec, + metric="rate", + space="identity", + ) + save_xr_scenario( + 
squeezed_pneumo_all_years, + pneumo_file_spec, + metric="rate", + space="identity", + ) + save_xr_scenario( + squeezed_non_pneumo_non_hib_all_years, + non_pneumo_non_hib_file_spec, + metric="rate", + space="identity", + ) + + else: + scaled_acause_da = acause_da * covid_scalar + + acause_other = acause_da.sel( + year_id=[ + x for x in acause_da.year_id.values if x not in scaled_acause_da.year_id.values + ] + ) + scaled_acause_all_years = xr.concat([scaled_acause_da, acause_other], dim="year_id") + save_xr_scenario( + scaled_acause_all_years, + FHSFileSpec(final_version, f"{acause}.nc"), + metric="rate", + space="identity", + ) + save_xr_scenario( + acause_da, + FHSFileSpec(final_version, f"{acause}_original.nc"), + metric="rate", + space="identity", + ) + + +def _get_cause_hierarchy_external_model_and_set( + gbd_round_id: int, + covid_scalar_acauses_externally_modeled: bool, + covid_externally_modeled: bool, + covid_scalar_acauses: Iterable[str] = (), +) -> Tuple[pd.DataFrame, pd.DataFrame]: + """Query db and get cause hierarchy external model and set. + + Args: + gbd_round_id (int): The GBD round to query. + covid_scalar_acauses_externally_modeled (bool): Include the "covid scalar causes" + amongst the external-modeled ones. + covid_scalar_acauses (Tuple[str, ...]): Which ones are the "covid scalar causes"? + covid_externally_modeled (bool): Include lri_corona amongst the external-modeled + causes. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: cause hierarchy external model + DataFrame and cause hierarchy set DataFrame. + """ + cause_hierarchy_set = cause.get_cause_hierarchy(gbd_round_id) + + if not covid_externally_modeled: + cause_hierarchy_set = cause_hierarchy_set[ + cause_hierarchy_set.acause != CauseConstants.COVID_ACAUSE + ] + + relevant_external_modeled_acauses = CauseConstants.EXTERNAL_MODELED_ACAUSES + if covid_externally_modeled: + relevant_external_modeled_acauses += (CauseConstants.COVID_ACAUSE,) + if covid_scalar_acauses_externally_modeled: + relevant_external_modeled_acauses += covid_scalar_acauses + + # Note cause_external_modeled_subset may include OTH_PAND, even if it is removed from + # cause_hierarchy_set below. + cause_external_modeled_subset = cause_hierarchy_set.query( + f"acause in {tuple(relevant_external_modeled_acauses)}" + ) + + # Remove it from cause hierarchy + cause_hierarchy_set = cause_hierarchy_set[cause_hierarchy_set.acause != OTH_PAND] + + # Note in the caller this is _only_ used to produce a set of acauses that are parent + # acauses. + cause_hierarchy_external_model = _create_cause_to_parent_map( + cause_hierarchy_set, cause_external_modeled_subset + ) + + return cause_hierarchy_external_model, cause_hierarchy_set + + +def _create_cause_to_parent_map( + cause_id_to_acause_map: pd.DataFrame, acauses_with_path_to_top_parent: pd.DataFrame +) -> pd.DataFrame: + """Translate acauses_with_path_to_top_parent to a straightforward child-parent map. + + Given a path_to_top_parent column, return a DataFrame mapping each acause to its parent. To + do this, we need a caust_id-to-acause mapping, which must be provided by the first arg. + """ + cause_hierarchy_external_model = acauses_with_path_to_top_parent[ + [DimensionConstants.ACAUSE, DimensionConstants.PATH_TO_TOP_PARENT] + ] + + # Re-structure external cause hierarchy dataframe so each row is one cause-ancestor pair. 
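+    # For example (hypothetical ids), a row with acause "lri" and path_to_top_parent
+    # "294,295,322" becomes three rows pairing "lri" with parent_id 294, 295, and 322.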
+ cause_hierarchy_external_model[DimensionConstants.PATH_TO_TOP_PARENT] = ( + cause_hierarchy_external_model[DimensionConstants.PATH_TO_TOP_PARENT].str.split(",") + ) + s = ( + cause_hierarchy_external_model[DimensionConstants.PATH_TO_TOP_PARENT] + .apply(pd.Series, 1) + .stack() + ) + s.index = s.index.droplevel(-1) + s.name = DimensionConstants.PARENT_ID_COL + s = pd.to_numeric(s) + del cause_hierarchy_external_model[DimensionConstants.PATH_TO_TOP_PARENT] + cause_hierarchy_external_model = cause_hierarchy_external_model.join(s) + + # Merge name of ancestor causes to external cause hierarchy dataframe + cause_hierarchy_external_model = pd.merge( + cause_hierarchy_external_model, + cause_id_to_acause_map[ + [DimensionConstants.CAUSE_ID, DimensionConstants.ACAUSE] + ].rename( + columns={ + DimensionConstants.CAUSE_ID: DimensionConstants.PARENT_ID_COL, + DimensionConstants.ACAUSE: DimensionConstants.PARENT_ACAUSE, + } + ), + ) + + return cause_hierarchy_external_model + + +def _get_covid_scalar_cause_set( + covid_scalar_acauses: Iterable[str], + covid_scalar_version: VersionMetadata, +) -> Set[str]: + """Return the set of COVID scalar cause set to include in computation.""" + if covid_scalar_version is not None: + covid_scalar_acauses_set = set(covid_scalar_acauses) + else: + covid_scalar_acauses_set = set() + + return covid_scalar_acauses_set + + +def _split_external_from_squeezed_causes( + cause_hierarchy_external_model: pd.DataFrame, + all_deaths_causes: Set[str], + covid_cause_set: Set[str], + covid_scalar_acauses_set: Set[str], +) -> Tuple[Set[str], Set[str]]: + """Return the sets of external and squeezed acauses.""" + external_parents_acause = ( + set(cause_hierarchy_external_model[DimensionConstants.PARENT_ACAUSE].tolist()) + - set(CauseConstants.SHOCKS_ACAUSES) + - set(HIV_ACAUSES) + - covid_cause_set + - covid_scalar_acauses_set + ) + squeezed_mx_acause = ( + set(all_deaths_causes) + - set(CauseConstants.SHOCKS_ACAUSES) + - set(HIV_ACAUSES) + - covid_cause_set + - set(external_parents_acause) + - covid_scalar_acauses_set + - {NTD} + ) + + return external_parents_acause, squeezed_mx_acause diff --git a/gbd_2021/disease_burden_forecast_code/mortality/lib/make_hierarchies.py b/gbd_2021/disease_burden_forecast_code/mortality/lib/make_hierarchies.py new file mode 100644 index 0000000..5e589b4 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/lib/make_hierarchies.py @@ -0,0 +1,116 @@ +"""Functions used by Mortality stage 2 and stage 3. + +to set up the aggregation hierarchy and the +cause set for ARIMA to operate on. +""" + +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple + +import pandas as pd + + +class EntityNode(object): + """Initialize an EntityNode object. + + Args: + entity_id (int): id for the node to be associated with + children (Optional): Child EntityNodes or None if the node is a leaf + + """ + + def __init__(self, entity_id: int, children: Optional[List[EntityNode]] = None) -> None: + """Initialize an EntityNode object.""" + self.entity_id = entity_id + self.children = children or [] + + +def make_hierarchy_tree( + entity_hierarchy: pd.DataFrame, root_id: int, id_col_name: str +) -> Tuple[EntityNode, Dict]: + """Converts a pandas dataframe hierarchy into a tree. + + Args: + entity_hierarchy (pd.DataFrame): the pandas dataframe hierarchy. + root_id (int): id of the root in the hierarchy. + id_col_name (str): name of the column in the dataframe that the hierarchy is being set + up on. 
+
+    Returns:
+        Tuple: root node and dictionary representation of the tree.
+    """
+    hier_indexed = entity_hierarchy.set_index("parent_id").sort_index()
+    root = EntityNode(root_id)
+    node_queue = [root]
+    # Key the root node by its entity id so lookups by id work for the root as well.
+    node_map = {root_id: root}
+    while node_queue:
+        curr = node_queue.pop()
+        try:
+            children = pd.Series(hier_indexed.loc[curr.entity_id][id_col_name]).values
+        except KeyError:
+            continue  # leaf node!
+
+        for child in children:
+            if child == curr.entity_id:
+                continue
+            child_node = EntityNode(child)
+            curr.children.append(child_node)
+            node_map[child] = child_node
+            node_queue.append(child_node)
+    return root, node_map
+
+
+def include_up_hierarchy(root: EntityNode, node_map: Dict, entity_ids: List[int]) -> List:
+    """Filter the hierarchy tree.
+
+    Include a node if any of its descendants are in the entity ids.
+
+    Generally speaking, we would like to support filtering by [include/exclude] if
+    [any/all] [ancestors/descendants] are in the entity ids.
+
+    For now this includes a node if any of its descendants are in the strategy set.
+
+    Args:
+        root (EntityNode): An entity node, with some children or maybe no children.
+        node_map (Dict): Dictionary representation of the tree.
+        entity_ids (List[int]): List of entity ids to filter with.
+
+    Returns:
+        List: List of entity ids.
+
+    """
+    for entity_id in entity_ids:
+        node_map[entity_id].mark = True
+    subset = []
+    _include_up_hierarchy(root, subset)
+    return subset
+
+
+def _include_up_hierarchy(node: EntityNode, subset: List) -> bool:
+    """Mark nodes or recurse on children.
+
+    Mark nodes True, if they are markable, otherwise recurse on children to add them to the
+    hierarchy before adding self to the hierarchy.
+
+    Args:
+        node (EntityNode): Entity node to include.
+        subset (List): The subset of entity ids in the hierarchy.
+
+    Returns:
+        bool: Whether the node has been included or not.
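+
+    Example:
+        A toy hierarchy (hypothetical entity ids) showing the public entry point::
+
+            root = EntityNode(1, [EntityNode(2, [EntityNode(4)]), EntityNode(3)])
+            node_map = {
+                1: root,
+                2: root.children[0],
+                3: root.children[1],
+                4: root.children[0].children[0],
+            }
+            include_up_hierarchy(root, node_map, [4])  # -> [4, 2, 1]
+
+        Node 3 is excluded because neither it nor any of its descendants is in the
+        entity ids, while 2 and 1 are kept as ancestors of 4.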
+ + """ + if hasattr(node, "mark"): + subset.append(node.entity_id) + return True + + else: + include = False + for child in node.children: + if _include_up_hierarchy(child, subset): + include = True + if include: + node.mark = True + subset.append(node.entity_id) + return include diff --git a/gbd_2021/disease_burden_forecast_code/mortality/lib/mortality_approximation.py b/gbd_2021/disease_burden_forecast_code/mortality/lib/mortality_approximation.py new file mode 100644 index 0000000..be07270 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/lib/mortality_approximation.py @@ -0,0 +1,788 @@ +"""Perform mortality approximation for every cause.""" + +import gc +from typing import Iterable, List, Optional, Union + +import numpy as np +import pandas as pd +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.processing import strip_single_coord_dims +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_data_transformation.lib.validate import check_dataarray_shape +from fhs_lib_database_interface.lib import db_session +from fhs_lib_database_interface.lib.constants import ( + CauseConstants, + DimensionConstants, + SexConstants, +) +from fhs_lib_database_interface.lib.query import location +from fhs_lib_database_interface.lib.strategy_set.query import get_property_values +from fhs_lib_database_interface.lib.strategy_set.strategy import get_cause_set +from fhs_lib_file_interface.lib.check_input import check_versions +from fhs_lib_file_interface.lib.provenance import ProvenanceManager +from fhs_lib_file_interface.lib.version_metadata import ( + FHSDirSpec, + FHSFileSpec, + VersionMetadata, +) +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib.fhs_logging import get_logger + +logger = get_logger() + +SUPERFLUOUS_DIMS = {"variable", "acause", "cov"} +INJ_TRANS_ROAD = "inj_trans_road" +NTD_EXCEPTIONS = ["ntd_nema", "ntd_dengue"] + +SCALAR_RATIO_DEFAULT = 1 +SDI_BETA_COEFF_DEFAULT = 0 # goes into the exponential --> exp(0) = 1 +NON_SEV_COVARIATES = { + "asfr", + "sdi", + "time_var", + "sdi_time", + "hiv", +} +EXPECTED_STAGES = ( + "sdi", + "scalar", + "sev", + "asfr", + "hiv", + "death", +) +NO_COVS_SCALARS = [ + "ntd_afrtryp", + "ntd_dengue", +] + +# .nc file variable specifics for SDI and SDI * t +BETA_GLOBAL = "beta_global" # what the beta coeff is named in .nc +SDI_TIME = "sdi_time" +SDI = "sdi" + + +def get_fatal_sex_availability(gbd_round_id: int) -> pd.DataFrame: + """Get sex availability by acause for fatal causes.""" + with db_session.create_db_session() as session: + fatal_gk_causes = get_cause_set( + session=session, + gbd_round_id=gbd_round_id, + strategy_id=CauseConstants.FATAL_GK_STRATEGY_ID, + )[[DimensionConstants.ACAUSE]] + + fatal_sex_availability = ( + get_property_values(session, DimensionConstants.CAUSE, gbd_round_id) + .reset_index()[[DimensionConstants.ACAUSE, "fatal_available_sex_id"]] + .rename(columns={"fatal_available_sex_id": DimensionConstants.SEX_ID}) + ) + + fatal_gk_causes_with_sex_availability = fatal_gk_causes.merge(fatal_sex_availability) + return fatal_gk_causes_with_sex_availability + + +def _decide_base_scenario( + da: xr.DataArray, base_scenario: Optional[int] = None +) -> xr.DataArray: + """Transform the base dataaarray according to 
prescribed base scenario. + + If base_scenario is None, return original array. + If prescribed, return that scenario, and remove the scenario dim. + + Args: + da (xr.DataArray): DataArray with scenario dimension. + base_scenario (int): Base scenario to operate on. + If specified, performs all-to-one operation on given base scenario. + If not specified (None), performs one-to-one operation on + scenarios. + + Returns: + (xr.DataArray): Original da if base_scenario is None, otherwise return + da.sel(scenario=base_scenario).drop_vars("scenario") + """ + if base_scenario is not None: + return da.sel(scenario=base_scenario, drop=True) + else: + return da + + +def _load_scalar_ratio( + gbd_round_id: int, + past_or_future: str, + base_versions: Versions, + versions: Versions, + acause: str, + base_scenario: Optional[int] = None, + draws: Optional[int] = None, + run_on_means: bool = False, + national_only: bool = False, +) -> Union[xr.DataArray, int]: + """Compute scalar / scalar_0. + + Args: + gbd_round_id (int): GBD round id. + past_or_future (str): Either ``"past"`` or ``"future"``. + base_versions (Versions): Versions object with baseline scalar version. + versions (Versions): Versions object with new scalar version. + acause (str): Name of analytical cause of death. + base_scenario (Optional[int]): Base scenario to operate on. + If specified, performs all-to-one operation on given base scenario. + If not specified (None), performs one-to-one operation on + scenarios. + draws (Optional[int]): Number of draws needed. + run_on_means (bool): Use mean of draws. + national_only (bool): national locations only. + + Returns: + (xr.DataArray | int): scalar / scalar_0 + """ + base_scalar_version = base_versions.get(past_or_future, "scalar").default_data_source( + gbd_round_id + ) + base_scalar_file = FHSFileSpec(base_scalar_version, f"{acause}.nc") + + # not every cause has a scalar + if ProvenanceManager.exists(base_scalar_file): + scalar_base = open_xr_scenario(base_scalar_file) + + scalar = open_xr_scenario( + FHSFileSpec( + versions.get(past_or_future, "scalar").default_data_source(gbd_round_id), + f"{acause}.nc", + ) + ) + scalar = strip_single_coord_dims(scalar) + + missing_dims = list(set(scalar_base.dims).difference(scalar.dims)) + new_dims = {dim: scalar_base[dim].values.tolist() for dim in missing_dims} + scalar = scalar.expand_dims(new_dims) + + # If all approximations are done on a single base scenario, then we remove all other + # scenarios, drop the scenario dim, to allow broadcast arithmetic to take over + scalar_base = _decide_base_scenario(scalar_base, base_scenario) + + if draws: # only resample if asked to + scalar_base = resample(scalar_base, num_of_draws=draws) + scalar = resample(scalar, num_of_draws=draws) + + if run_on_means: + scalar_base = scalar_base.mean("draw") + scalar = scalar.mean("draw") + + if national_only: + national_locations = ( + location.get_location_set(gbd_round_id) + .query("level == 3") + .location_id.tolist() + ) + scalar_base = scalar_base.sel(location_id=national_locations) + scalar = scalar.sel(location_id=national_locations) + + scalar_ratio = scalar / scalar_base # inner join arithmetic + + else: + logger.warning(f"Base scalar for {acause} does not exist") + + scalar_ratio = SCALAR_RATIO_DEFAULT # 1 + + return scalar_ratio + + +def _copy_national_coefficients_to_subnationals( + da: xr.DataArray, gbd_round_id: int +) -> xr.DataArray: + """Copy national coefficients to subnational locations. 
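+
+    For example (hypothetical ids), if national location 102 carries betas and the GBD
+    hierarchy lists 523 and 524 as its level-4 children, the national betas are copied to
+    both subnational ids and concatenated onto the original array.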
+ + Args: + da (xr.DataArray): Betas array + gbd_round_id (int): GBD round ID + + Returns: + xr.DataArray: betas array with all subnational location IDs included + """ + # get parents and location_ids of level 4 locations present in ds + locs = location.get_location_set(gbd_round_id=gbd_round_id).query("level in [3, 4]") + + nationals = locs.query("level== 4").parent_id.unique().tolist() + nationals_in_da = [loc for loc in nationals if loc in da.location_id] + + if not np.isin(nationals, nationals_in_da).all(): + raise ValueError("Missing level 3 locations in death dataset.") + + ds_all_locs = [da] + for nat_loc in nationals_in_da: + nat_ds = da.sel(location_id=nat_loc, drop=True) + subnats = locs.query("parent_id == @nat_loc").location_id.values + subnat_ds = expand_dimensions(nat_ds, location_id=subnats) + ds_all_locs.append(subnat_ds) + + ds_all_locs = xr.concat(ds_all_locs, dim=DimensionConstants.LOCATION_ID) + return ds_all_locs + + +def _load_beta_dataset( + gbd_round_id: int, + gk_version: str, + acause: str, + past_or_future: str, + draws: Optional[int] = None, + run_on_means: bool = False, + national_only: bool = False, +) -> xr.Dataset: + """Get beta coefficients dataset from betas directory. + + Args: + gbd_round_id (int): GBD round ID for data being approximated. + gk_version (str): 1st stage mortality results. Contains `betas` + directory. + acause (str): Cause + past_or_future (str): `past` or `future` + draws (int): Optional. Number of draws in data. + run_on_means (bool): Flag for running on means only. + national_only (bool): National locations only. + + Returns: + (xr.Dataset | int): betas for all covariates. + """ + gk_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, epoch=past_or_future, stage="death", version=gk_version + ) + betas_dir = FHSDirSpec(gk_version_metadata, ("betas",)) + # most detailed data is saved by sex with suffix + sex_avail = ( + get_fatal_sex_availability(gbd_round_id).query("acause==@acause").sex_id.values.item() + ) + + betas_list = [] + + for sex_id in sex_avail: + beta_path = FHSFileSpec.from_dirspec( + betas_dir, f"{acause}_{SexConstants.SEX_DICT[sex_id]}.nc" + ) + + beta_tmp = open_xr_scenario(beta_path)[BETA_GLOBAL] + + if draws: + beta_tmp = resample(beta_tmp, draws) + + if run_on_means: + beta_tmp = beta_tmp.mean("draw") + + betas_list.append(beta_tmp) + + beta = _make_beta(betas_list) + + if national_only: + national_locations = ( + location.get_location_set(gbd_round_id).query("level == 3").location_id.tolist() + ) + beta_full = beta.sel(location_id=national_locations) + else: + # copy values in national location ids to subnationals + beta_full = _copy_national_coefficients_to_subnationals( + da=beta, + gbd_round_id=gbd_round_id, + ) + + return beta_full + + +def _load_covariates( + beta: xr.Dataset, + versions: Versions, + gbd_round_id: int, + past_or_future: str, + years: YearRange, + base_scenario: Optional[int] = None, + draws: Optional[int] = None, + run_on_means: bool = False, + national_only: bool = False, +) -> xr.Dataset: + """Load all covariates in a Dataset. + + Args: + beta (Union[xr.Dataset, int]): Covariate dataset or 0. If there are + no covariates in the betas file, this value is 0. + versions (Versions): Versions object containing all covariate + versions e.g. sdi/fake_version. + gbd_round_id (int): GBD round ID. + past_or_future (str): `past` or `future`. + years (YearRange): A YearRange object of relevant years. + base_scenario (Optional[int]): Base scenario to operate on. 
If + specified, performs all-to-one operation on given base scenario. If + not specified (None), performs one-to-one operation on + scenarios. + draws (Optional[int]): Number of draws needed. + run_on_means (bool): Take the mean of the covariate data. + national_only (bool): Run on national locations only. + + Returns: + (xr.Dataset): All covariates in a dataset + """ + beta_covs = beta.cov.values.tolist() + sdi_time = True if SDI_TIME in beta_covs else False + covs = [cov for cov in beta_covs if cov not in ["intercept", "sdi_time", "time_var"]] + + covariates = xr.Dataset() + for cov in covs: + if cov in NON_SEV_COVARIATES: + cov_version = versions.get(past_or_future, cov).default_data_source(gbd_round_id) + if cov == "hiv": + covariate = open_xr_scenario( + FHSFileSpec(cov_version.with_stage("death"), f"{cov}.nc") + ) + else: + covariate = open_xr_scenario(FHSFileSpec(cov_version, f"{cov}.nc")) + else: + # these are sev covariates + cov_version = versions.get(past_or_future, "sev").default_data_source(gbd_round_id) + covariate = open_xr_scenario(FHSFileSpec(cov_version, f"{cov}.nc")) + + # remove age/sex dims if they are aggregates + covariate = strip_single_coord_dims(covariate) + + covariate = _decide_base_scenario(covariate, base_scenario) + + if draws: + covariate = resample(covariate, num_of_draws=draws) + + if run_on_means: + covariate = covariate.mean("draw") + + if national_only: + national_locations = ( + location.get_location_set(gbd_round_id) + .query("level == 3") + .location_id.tolist() + ) + covariate = covariate.sel(location_id=national_locations) + + covariates[cov] = covariate + + # add sdi_time covariate if it exists + if sdi_time: + # time is treated as t_0, t_0 + 1, t_0 + 2,...t_n + k for t_0 = 0 + # and k = (number of years in time series - 1) + year_list = years.years + time_array = xr.DataArray( + year_list - years.past_start, dims=["year_id"], coords=[year_list] + ) + sdi_time_cov = covariates[SDI] * time_array + covariates[SDI_TIME] = sdi_time_cov + + return covariates + + +def get_exp_beta_cov_diff( + gbd_round_id: int, + past_or_future: str, + base_versions: Versions, + versions: Versions, + base_gk_version: str, + acause: str, + years: YearRange, + base_scenario: Optional[int] = None, + draws: Optional[int] = None, + run_on_means: bool = False, + national_only: bool = False, +) -> xr.Dataset: + """Compute exp(beta * (cov - cov_0)). + + Args: + gbd_round_id (int): GBD round ID + past_or_future (str): `past` or `future` + base_versions (Versions): Versions object containing base versions + for all input data. + versions (Versions): Versions object containing new versions for + approximated data. + base_gk_version (str): Version directory with saved betas files. + acause (str): Cause. + years (YearRange): YearRange object with forecast years + base_scenario (Optional[int]): Base scenario to operate on. If + specified, performs all-to-one operation on given base scenario. If + not specified (None), performs one-to-one operation on + scenarios. + draws (Optional[int]): Number of draws needed. + run_on_means (bool): Take the mean of the covariate data. + national_only (bool): Run on national locations only. 
+ + Returns: + (xr.Dataset): exp(beta * cov - cov_0) for each covariate + """ + beta = _load_beta_dataset( + gbd_round_id, + base_gk_version, + acause, + past_or_future, + draws, + run_on_means, + national_only, + ) + + original_covariates = _load_covariates( + beta, + base_versions, + gbd_round_id, + past_or_future, + years, + base_scenario, + draws, + run_on_means, + national_only, + ) + new_covariates = _load_covariates( + beta, + versions, + gbd_round_id, + past_or_future, + years, + None, + draws, + run_on_means, + national_only, + ) + + cov_diff = new_covariates - original_covariates + + exp_beta_cov_diff = xr.Dataset() + + for cov in original_covariates: + if acause.startswith(INJ_TRANS_ROAD): + beta_cov = beta.sel(cov=cov) + else: + beta_cov = beta.sel(cov=cov, drop=True) + + _exp_beta_cov_diff = np.exp(beta_cov * cov_diff[cov]) + _exp_beta_cov_diff.name = cov + exp_beta_cov_diff = exp_beta_cov_diff.combine_first(_exp_beta_cov_diff.to_dataset()) + + exp_beta_cov_diff = exp_beta_cov_diff.fillna(1) + + return exp_beta_cov_diff + + +def _make_beta( + beta_list: Union[List[xr.DataArray], List[xr.Dataset]] +) -> Union[xr.DataArray, xr.Dataset]: + """Concat a list of sex-specific beta arrays into a single array. + + Used for both SDI (beta) and SDI * t (beta_t) coefficients. + + Args: + beta_list (List[xr.DataArray]): list of sex-specific beta arrays. + + Returns: + (Union[xr.DataArray, int]): + Either dataarray of both sexes of betas, or 0. + """ + # There are three cases: + # * If beta/beta_t actually exist, then combine the sexes into a single array. + # * Otherwise, beta/beta_t are set to 0. + # * Sometimes beta/beta_t exist, but are all nulls. Set these to 0 as well. + + if len(beta_list) > 0: + # might have only one sex_id + beta = xr.concat(beta_list, dim=DimensionConstants.SEX_ID) + del beta_list + if isinstance(beta, xr.Dataset): + for var in beta: + if beta[var].isnull().all(): + beta[var] = SDI_BETA_COEFF_DEFAULT + else: + if beta.isnull().all(): # it could be that they're all nulls + beta = SDI_BETA_COEFF_DEFAULT + else: + beta = SDI_BETA_COEFF_DEFAULT # exp(0) = 1 + + return beta + + +def covariate_approximation( + base_death: xr.DataArray, + exp_beta_cov_diff: xr.Dataset, +) -> xr.Dataset: + """Perform approximation calculation by covariate. + + SEV covariates may not match dimensions of mortality exactly across all + dimensions. + First calculate approximation with non-age-specific covariates. Then use + .combine_first + to default to base_mx dimensions. + + Args: + base_death (xr.DataArray): Base mortality array + exp_beta_cov_diff (xr.Dataset): exp(beta * cov_diff) for each + covariate in betas file. Where there is no covariate data, + this is just 1. + + Returns: + (xr.Dataset): death * exp_beta_cov_diff[cov] for each cov + """ + sev_covariates = set(exp_beta_cov_diff) - NON_SEV_COVARIATES + + non_sev_covariates = set(exp_beta_cov_diff) & NON_SEV_COVARIATES + + death_post_non_sev_covs = base_death + for cov in non_sev_covariates: + death_post_non_sev_covs = death_post_non_sev_covs * exp_beta_cov_diff[cov] + + death = death_post_non_sev_covs + for cov in sev_covariates: + death = death * exp_beta_cov_diff[cov] + + death = death.combine_first(death_post_non_sev_covs) + + return death + + +def has_covs(acause: str, gk_version: str, gbd_round_id: int) -> bool: + """Check betas file for existence of covariate coefficients. 
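+
+    For example, a cause whose betas file only contains ``intercept`` and ``time_var``
+    coefficients has no real covariates and returns False, while a cause that also has an
+    ``sdi`` or SEV coefficient returns True.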
+ + Args: + acause (str): Cause + gk_version (str): version name with betas sub directory + gbd_round_id (int): GBD round ID + + Returns: + (Bool): True if cov dimension in betas file for acause has + covariates other than intercept and time_var. + """ + if acause in NTD_EXCEPTIONS: + return False + sex_avail = ( + get_fatal_sex_availability(gbd_round_id).query("acause==@acause").sex_id.values.item() + ) + covs = [] + for sex in sex_avail: + sex_name = SexConstants.SEX_DICT[sex] + cov_tmp = open_xr_scenario( + FHSFileSpec( + version_metadata=VersionMetadata.make( + data_source=gbd_round_id, epoch="future", stage="death", version=gk_version + ), + sub_path=("betas",), + filename=f"{acause}_{sex_name}.nc", + ) + ).cov.values.tolist() + for cov in cov_tmp: + if cov not in ["intercept", "time_var"]: + covs.append(cov) + + if len(covs) > 0: + return True + else: + return False + + +def _validate_draws(run_on_means: bool, draws: int) -> None: + """Validate that the `draws` and `run_on_means` parameters have been passed properly. + + Valid specification of these parameters would be where one but not the other was provided. + + Args: + run_on_means (bool): optionally specified argument from `mortality_approximation` + draws (int): optionally specified argument from `mortality_approximation` + + Raises: + ValueError: if neither `run_on_means` nor `draws` have been specified + """ + if not run_on_means: + if draws is None: + raise ValueError("draws arg must be specified if not run_on_means") + + +def mortality_approximation_calculate( + gbd_round_id: int, + past_or_future: str, + base_gk_version: str, + base_versions: Versions, + versions: Versions, + acause: str, + years: YearRange, + expected_scenarios: Iterable[int], + base_scenario: Optional[int] = None, + draws: Optional[int] = None, + run_on_means: bool = True, + national_only: bool = False, +) -> xr.DataArray: + """See mortality_approximation.""" + check_versions(base_versions, "future", EXPECTED_STAGES) + check_versions(versions, "future", EXPECTED_STAGES) + + _validate_draws(run_on_means=run_on_means, draws=draws) + + if has_covs(acause=acause, gbd_round_id=gbd_round_id, gk_version=base_gk_version): + exp_beta_cov_diff = get_exp_beta_cov_diff( + gbd_round_id, + past_or_future, + base_versions, + versions, + base_gk_version, + acause, + years, + base_scenario, + draws, + run_on_means, + national_only, + ) + else: + exp_beta_cov_diff = None + + base_death_version = base_versions.get(past_or_future, "death").with_data_source( + gbd_round_id + ) + base_death = open_xr_scenario(FHSFileSpec(base_death_version, f"{acause}.nc")) + last_past_year = base_death.sel(year_id=years.past_end) + base_death = base_death.sel(year_id=years.forecast_years) + + if draws: + last_past_year = resample(last_past_year, draws) + base_death = resample(base_death, draws) + + if national_only: + national_locations = ( + location.get_location_set(gbd_round_id).query("level == 3").location_id.tolist() + ) + base_death = base_death.sel(location_id=national_locations) + + # get coords of base death for post approximation validation + base_death_coords = dict(base_death.coords) + + # NOTE as it stands now, GK-produced mortality dataarrays have these "variable" and + # "acause" point dims that can be dropped. 
+ for dim in ["variable", "acause"]: + if dim in base_death_coords: + del base_death_coords[dim] + expected_coords = { + dim: list(base_death_coords[dim].values) for dim in base_death_coords.keys() + } + expected_coords["scenario"] = expected_scenarios + + base_death = _decide_base_scenario(base_death, base_scenario) + + # first we perform the covariate adjustments if applicable. + if exp_beta_cov_diff is not None: + death_post_cov = covariate_approximation(base_death, exp_beta_cov_diff) + else: + death_post_cov = base_death + + del base_death, exp_beta_cov_diff + gc.collect() + + # Next we perform adjustments along the scalars axis. + scalar_ratio = _load_scalar_ratio( + gbd_round_id, + past_or_future, + base_versions, + versions, + acause, + base_scenario=base_scenario, + draws=draws, + run_on_means=run_on_means, + national_only=national_only, + ) + + # Because scalars may not match death exactly across dims (age_group_id), we use + # .combine_first to default to _post_sdi if coordinate is missing. If scalar_ratio has more + # age_group_id than base_death, then the following inner-join removes extra ones. If + # scalar_ratio has fewer, then the combine_first will restore the defaults. + death = death_post_cov * scalar_ratio # inner-join + + death = death.combine_first(death_post_cov) + + del death_post_cov, scalar_ratio + gc.collect() + + death = death.drop_vars(SUPERFLUOUS_DIMS & set(death.coords)).squeeze() + + # convert year_id to int + if death.year_id.dtype != "int64": + death = death.assign_coords(year_id=death.year_id.astype("int")) + + # expand sex_id if squeezed + if "sex_id" not in death.dims and "sex_id" in expected_coords.keys(): + death = death.expand_dims("sex_id") + if len(expected_coords["sex_id"]) > 1: + death = expand_dimensions(death, sex_id=[1, 2], fill_value=np.nan) + + # ntd_dengue and ntd_afrtryp have no covariates (other than time and intercept) and no + # scalars - no approximation + if acause in NO_COVS_SCALARS: + death = xr.concat( + [death.assign_coords({"scenario": x}) for x in expected_scenarios], dim="scenario" + ) + + # validate all coords as expected + check_dataarray_shape(death, expected_coords) + + # include last past year + death = xr.concat([death, last_past_year], dim="year_id") + return death + + +def mortality_approximation( + gbd_round_id: int, + past_or_future: str, + base_gk_version: str, + base_versions: Versions, + versions: Versions, + acause: str, + years: YearRange, + expected_scenarios: Iterable[int], + base_scenario: Optional[int] = None, + draws: Optional[int] = None, + run_on_means: bool = True, + national_only: bool = False, +) -> None: + """Perform Mortality Approximation on given cause. + + Args: + gbd_round_id (int): The GBD round for this run. + past_or_future (str): Either ``"past"`` or ``"future"``. + base_gk_version (str): GK results with betas directory. + base_versions (Versions): Base versions of all inputs. + versions (Versions): Approximation versions of all inputs. + acause (str): Name of analytical cause of death. + years (YearRange): Past and forecast years. + expected_scenarios (Iterable[int]): Expected scenarios in final output. + base_scenario (Optional[int]): base scenario to operate on. + If specified, performs all-to-one operation on given base scenario. + If not specified (None), performs one-to-one operation on + scenarios. + draws (Optional[int]): Number of draws needed. + run_on_means (bool): Use means of draws. 
+ national_only (bool): National locations only + + Returns: + None + """ + death = mortality_approximation_calculate( + gbd_round_id, + past_or_future, + base_gk_version, + base_versions, + versions, + acause, + years, + expected_scenarios, + base_scenario, + draws, + run_on_means, + national_only, + ) + + save_file = FHSFileSpec( + versions.get(past_or_future, "death").default_data_source(gbd_round_id), f"{acause}.nc" + ) + + save_xr_scenario( + death, + save_file, + metric="rate", + space="identity", + base_death_version=base_versions["future"]["death"], + base_scenario=str(base_scenario), + ) diff --git a/gbd_2021/disease_burden_forecast_code/mortality/lib/run_cod_model.py b/gbd_2021/disease_burden_forecast_code/mortality/lib/run_cod_model.py new file mode 100644 index 0000000..1f92461 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/lib/run_cod_model.py @@ -0,0 +1,812 @@ +"""Functions to run cod model. +""" + +import sys +from typing import Dict, Iterable, List, Optional, Tuple + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.post_hoc_relu import post_hoc_relu +from fhs_lib_data_transformation.lib.processing import invlog_with_offset, log_with_offset +from fhs_lib_data_transformation.lib.quantiles import Quantiles +from fhs_lib_data_transformation.lib.truncate import cap_forecasts +from fhs_lib_database_interface.lib import db_session +from fhs_lib_database_interface.lib.constants import ( + CauseConstants, + DimensionConstants, + LocationConstants, + ScenarioConstants, + SexConstants, +) +from fhs_lib_database_interface.lib.strategy_set import strategy +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec, VersionMetadata +from fhs_lib_file_interface.lib.versioning import validate_versions_scenarios_listed +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_model.lib.gk_model.GKModel import ConvergenceError, GKModel +from fhs_lib_model.lib.repeat_last_year import repeat_last_year +from fhs_lib_year_range_manager.lib.year_range import YearRange +from sksparse.cholmod import CholmodNotPositiveDefiniteError +from tiny_structured_logger.lib import fhs_logging + +from fhs_pipeline_mortality.lib.config_dataclasses import ( + CauseSpecificModelingArguments, + CauseSpecificVersionArguments, +) +from fhs_pipeline_mortality.lib.downloaders import load_cod_dataset, load_paf_covs + +logger = fhs_logging.get_logger() + +ASFR = "asfr" +BETA_GLOBAL = "beta_global" +CUTOFF = 1000 +GAMMA_AGE = "gamma_age" +GAMMA_LOCATION_AGE = "gamma_location_age" +GAMMA_LOCATION = "gamma_location" +INTERCEPT = "intercept" +LN_RISK_SCALAR = "ln_risk_scalar" +NLOCS_CUTOFF = 150 +OMEGA_AMP = 0.0 +# Post-hoc constraint on the sum of all coefficients on time to be <= 0 +# ie, beta_global and gamma_age slopes on time will be set to 0 if +# the mean beta_global < 0 and beta_global + gamma_age > 0 +# called a "rectified linear unit" (RELU) in machine learning parlance. +# Constraint options are nonnegative and nonpositive. 
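+# For example (illustrative numbers only), if mean(beta_global) on time is -0.02 and
+# gamma_age on time is +0.03 for some age group, the summed slope of +0.01 violates the
+# constraint, so per the rule above both time slopes are zeroed out for that age group.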
+RELU = dict(time_var="nonpositive") +SAVE_FUTURE = True +SAVE_PAST = True +SDI = "sdi" +SDI_KNOT = 0.8 +SDI_PART1 = "sdi_part1" +SDI_PART2 = "sdi_part2" +SDI_TIME = "sdi_time" +TIME_VAR = "time_var" +Y = "y" + + +def main( + modeling_args: CauseSpecificModelingArguments, + version_args: CauseSpecificVersionArguments, + scenarios: Optional[Iterable[int]], +) -> None: + """Run a cause-of-death model using appropriate strategy for input acause & save results. + + Args: + modeling_args (CauseSpecificModelingArguments): dataclass containing + cause specific modeling arguments + version_args (CauseSpecificVersionArguments): dataclass containing + cause specific version arguments + scenarios (Optional[Iterable[int]]): scenarios to model on + """ + if scenarios and version_args.output_scenario: + raise ValueError( + "The --scenarios flag cannot be used in single-scenario mode (with the " + "--output-scenario flag)." + ) + + output_version = version_args.versions.get("future", "death") + output_underlying_version = output_version.append_version_suffix("_underlying") + + validate_versions_scenarios_listed( + versions=[v for v in version_args.versions] + [output_version], + output_versions=[version_args.versions.get("future", "death")], + output_scenario=version_args.output_scenario, + ) + + sex_name = _get_sex_name(sex_id=modeling_args.sex_id) + addcovs = _get_addcovs(modeling_args=modeling_args, acause=modeling_args.acause) + + ds = load_cod_dataset( + addcovs=addcovs, + modeling_args=modeling_args, + versions=version_args.versions, + logspace_conversion_flags=version_args.logspace_conversion_flags, + scenarios=scenarios, + ) + # Get list of SEVs to include as covariates (where PAF=1) + # Don't need to supply past versions or scenarios to get this list + reis = load_paf_covs( + modeling_args=modeling_args, + versions=version_args.versions, + scenarios=[0], + listonly=True, + ) + + skip_gk_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.SKIP_GK_STRATEGY_ID + ) + + # ntd_nema and ntd_dengue aren't stable enough to have a regular gk model run on them. 
+ if modeling_args.acause in skip_gk_causes: + _model_ntd(modeling_args, sex_name, output_version, ds, scenarios) + + if modeling_args.spline: + ds[SDI_PART1] = np.minimum(ds[SDI], SDI_KNOT) + ds[SDI_PART2] = np.maximum(ds[SDI] - SDI_KNOT, 0.0) + + interaction_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.INTERACTION_STRATEGY_ID + ) + + if modeling_args.acause in interaction_causes: + ds[SDI_TIME] = ds[SDI] * ds[TIME_VAR] + + past_da = ds["y"].copy() + ds["y"] = ds["y"].mean("draw") + + no_fixed_effect_risk_cause = get_strategy_causes( + modeling_args=modeling_args, + gk_strategy_id=CauseConstants.NO_FIXED_EFFECT_RISK_STRATEGY_ID, + ) + + if not ( + modeling_args.acause == no_fixed_effect_risk_cause[0] and modeling_args.sex_id == 1 + ): + for r in reis: + ds[r].values[np.isnan(ds[r].values)] = 0.0 + + # if forecasting subnationals but not using them for the fit, + # hold out the subnationals until the prediction step + if not modeling_args.fit_on_subnational and modeling_args.subnational: + locs = ds["level"].to_dataframe("level").reset_index() + locs_national = locs.query("level == 3").location_id.tolist() + locs_subnational = locs.query("level == 4").location_id.tolist() + ds_subnat = ds.sel(location_id=locs_subnational) + ds = ds.sel(location_id=locs_national) + + if SAVE_PAST: + scalar_da = ds[LN_RISK_SCALAR] + if output_underlying_version.scenario is None: + scalar_da = scalar_da.sel(scenario=0, drop=True) + past_underlying = past_da - scalar_da + past_underlying = past_underlying.sel(year_id=modeling_args.years.past_years) + + past_underlying_save_path = FHSFileSpec( + output_underlying_version.with_epoch("past"), + f"{modeling_args.acause}{sex_name}.nc", + ) + save_xr_scenario( + past_underlying, past_underlying_save_path, metric="rate", space="identity" + ) + + fxeff = _make_fixed_effects(modeling_args=modeling_args, ds=ds, reis=reis) + + weight_decay09_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.WEIGHT_DECAY09_STRATEGY_ID + ) + + weight_decay = 0.9 if modeling_args.acause in weight_decay09_causes else 0 + + gkmodel, params = _fit_gkmodel( + modeling_args=modeling_args, + ds=ds, + fxeff=fxeff, + reis=reis, + weight_decay=weight_decay, + ) + + # if predicting at subnationals but fitting on national only, need to add + # subnats back into the data and apply coefficients for the level 3 to + # their level 4 children + if ( + not modeling_args.fit_on_subnational + and modeling_args.subnational + and len(locs_subnational) > 0 + ): + ds = xr.concat([ds, ds_subnat], dim=DimensionConstants.LOCATION_ID) + gkmodel.dataset = ds + gkmodel.coefficients = _apply_location_parent_coefs_to_children(params=params, ds=ds) + + preds = gkmodel.predict() + + cap_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.GK_CAP_STRATEGY_ID + ) + + # cap maternal_hiv and drug causes + if modeling_args.acause in cap_causes: + past_mortality = ( + open_xr_scenario( + FHSFileSpec( + version_args.versions.get(past_or_future="past", stage="death"), + f"{modeling_args.acause}.nc", + ) + )["mean"] + .sel(sex_id=modeling_args.sex_id, year_id=modeling_args.years.past_years) + .sel(acause=modeling_args.acause, drop=True) + ) + # data is in log space - capping happens in normal space + last_past_year = invlog_with_offset(preds).sel(year_id=modeling_args.years.past_end) + preds = log_with_offset( + cap_forecasts( + forecast=invlog_with_offset(preds), + past=past_mortality, + 
quantiles=Quantiles(0.01, 0.99), + last_past_year=last_past_year, + ) + ) + + # save future results without scalar + if SAVE_FUTURE: + scalar_da = ds[LN_RISK_SCALAR] + if output_underlying_version.scenario is None: + scalar_da = scalar_da.sel(scenario=0) + scalar_da = expand_dimensions(scalar_da, scenario=[-1, 0, 1]) + future_underlying = preds - scalar_da + + future_underlying_save_path = FHSFileSpec( + output_underlying_version, f"{modeling_args.acause}{sex_name}.nc" + ) + save_xr_scenario( + future_underlying, future_underlying_save_path, metric="rate", space="log" + ) + + preds_model_path = FHSFileSpec(output_version, f"{modeling_args.acause}{sex_name}.nc") + write_cod_forecast(data=preds, file_spec=preds_model_path) + params = ( + xr.DataArray([1], dims=[DimensionConstants.SEX_ID], coords=[[modeling_args.sex_id]]) + * params + ) + params_betas_path = FHSFileSpec( + output_version, sub_path=("betas",), filename=f"{modeling_args.acause}{sex_name}.nc" + ) + write_cod_betas( + model_params=params, + file_spec=params_betas_path, + acause=modeling_args.acause, + ) + + +def _model_ntd( + modeling_args: CauseSpecificModelingArguments, + sex_name: str, + output_version: VersionMetadata, + ds: xr.Dataset, + scenarios: Optional[Iterable[int]], +) -> None: + """Special case modeling for ntd_nema. + + It isn't stable enough to have a regular gk model run on it. + """ + logger.info( + "Running repeat_last_year model", + bindings=dict(acause=modeling_args.acause, sex_name=sex_name), + ) + + output_underlying_version = output_version.append_version_suffix("_underlying") + + if SAVE_PAST: + past_underlying = ds["y"].copy() + past_underlying = past_underlying.sel(year_id=modeling_args.years.past_years) + past_underlying.name = "value" + past_underlying.coords[DimensionConstants.SEX_ID] = modeling_args.sex_id + + past_underlying_save_path = FHSFileSpec( + output_underlying_version.with_epoch("past"), + f"{modeling_args.acause}{sex_name}.nc", + ) + save_xr_scenario( + past_underlying, past_underlying_save_path, metric="rate", space="identity" + ) + + preds = _repeat_last_year(ds=ds, years=modeling_args.years, draws=modeling_args.draws) + + if output_underlying_version.scenario is None: # "Legacy" multi-scenario mode. + scenarios = scenarios or ScenarioConstants.SCENARIOS + preds = expand_dimensions(preds, scenario=scenarios) + + preds.coords[DimensionConstants.SEX_ID] = modeling_args.sex_id + + if SAVE_FUTURE: + future_underlying_save_path = FHSFileSpec( + output_underlying_version, f"{modeling_args.acause}{sex_name}.nc" + ) + save_xr_scenario(preds, future_underlying_save_path, metric="rate", space="log") + + preds_model_path = FHSFileSpec(output_version, f"{modeling_args.acause}{sex_name}.nc") + write_cod_forecast(data=preds, file_spec=preds_model_path) + sys.exit() + + +def _make_fixed_effects( + modeling_args: CauseSpecificModelingArguments, ds: xr.Dataset, reis: List[str] +) -> Dict: + """Generate the correct fixed effects to use for the GK Model.""" + # Don't put a spline on SDI for causes that aren't modeled for very many + # locations. 
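+    # The two spline covariates referenced below were constructed upstream in main()
+    # as sdi_part1 = min(sdi, SDI_KNOT) and sdi_part2 = max(sdi - SDI_KNOT, 0) with
+    # SDI_KNOT = 0.8, so the SDI effect can change slope above the knot
+    # (e.g. sdi = 0.9 contributes sdi_part1 = 0.8 and sdi_part2 = 0.1).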
+ nlocs = len(ds.location_id.values) + if (nlocs > NLOCS_CUTOFF) and modeling_args.spline: + fxeff = { + BETA_GLOBAL: [ + (INTERCEPT, 0), + (SDI_PART1, 0), + (SDI_PART2, 0), + (TIME_VAR, 0), + ] + } + else: + fxeff = { + BETA_GLOBAL: [ + (INTERCEPT, 0), + (SDI, 0), + (TIME_VAR, 0), + ] + } + + interaction_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.INTERACTION_STRATEGY_ID + ) + + if modeling_args.acause in interaction_causes: + fxeff[BETA_GLOBAL] += [(SDI_TIME, 0)] + + no_fixed_effect_risk_cause = get_strategy_causes( + modeling_args=modeling_args, + gk_strategy_id=CauseConstants.NO_FIXED_EFFECT_RISK_STRATEGY_ID, + ) + + if not ( + modeling_args.acause == no_fixed_effect_risk_cause[0] and modeling_args.sex_id == 1 + ): + for r in reis: + fxeff[BETA_GLOBAL] += [(r, 1)] + + road_traffic_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.ROAD_TRAFFIC_STRATEGY_ID + ) + + if modeling_args.acause in road_traffic_causes: + # drop time trend + for cov in [ + (TIME_VAR, 0), + ]: + try: + fxeff[BETA_GLOBAL].remove(cov) + except ValueError: + pass + + cause_maternal_hiv = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.MATERNAL_HIV_STRATEGY_ID + ) + if modeling_args.acause == cause_maternal_hiv[0]: + fxeff[BETA_GLOBAL] += [("hiv", 1)] + fxeff[BETA_GLOBAL].remove((SDI, 0)) + + if modeling_args.acause == "malaria": + fxeff[BETA_GLOBAL].remove((SDI, 0)) + fxeff[BETA_GLOBAL].remove((TIME_VAR, 0)) + + vaccine_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.VACCINE_STRATEGY_ID + ) + + if modeling_args.acause in vaccine_causes: + try: + fxeff[BETA_GLOBAL].remove((SDI_PART1, 0)) + fxeff[BETA_GLOBAL].remove((SDI_PART2, 0)) + except ValueError: + fxeff[BETA_GLOBAL].remove((SDI, 0)) + + maternal_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.MATERNAL_CAUSES_STRATEGY_ID + ) + + parent_maternal = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.PARENT_MATERNAL_STRATEGY_ID + ) + + if modeling_args.acause in maternal_causes: + if modeling_args.acause != parent_maternal[0]: + fxeff[BETA_GLOBAL].append(("asfr", 1)) + + notime_maternal_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.MATERNAL_NOTIME_STRATEGY_ID + ) + nosdi_maternal_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.MATERNAL_NOSDI_STRATEGY_ID + ) + + if modeling_args.acause in notime_maternal_causes: + fxeff[BETA_GLOBAL].remove((TIME_VAR, 0)) + + if modeling_args.acause in nosdi_maternal_causes: + fxeff[BETA_GLOBAL].remove((SDI, 0)) + + return fxeff + + +def _fit_gkmodel( + modeling_args: CauseSpecificModelingArguments, + ds: xr.Dataset, + fxeff: Dict, + reis: List[str], + weight_decay: float, +) -> Tuple[GKModel, xr.Dataset]: + gkmodel = _first_gk_model( + modeling_args=modeling_args, ds=ds, fxeff=fxeff, weight_decay=weight_decay + ) + # Try fitting with all the specified fixed effects. If a convergence + # error is given, drop all but the location-age intercepts, and set time + # to be a global variable. If the model technically converges but any of + # the variables have coefficients with unreasonably high standard + # deviations, drop said variables and try again. 
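+    # In short, the fit cascades through up to three specifications:
+    #   1) _first_gk_model: all requested fixed effects plus the cause-specific
+    #      random effects from _make_random_effects;
+    #   2) _second_gk_model: the same fixed effects minus any covariate whose
+    #      coefficient SD is NaN or too large relative to its median
+    #      (see _covariates_to_drop / _is_cov_bad), with location-age intercepts only;
+    #   3) _third_gk_model: intercept (plus time for most causes) only, used when the
+    #      fit raises ConvergenceError or CholmodNotPositiveDefiniteError.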
+ try: + params = gkmodel.fit() + cov_to_drop = _covariates_to_drop( + modeling_args=modeling_args, reis=reis, params=params + ) + if cov_to_drop: + logger.info("refitting after dropping covariates") + gkmodel = _second_gk_model( + modeling_args=modeling_args, + ds=ds, + fxeff=fxeff, + weight_decay=weight_decay, + cov_to_drop=cov_to_drop, + ) + params = gkmodel.fit() + except (ConvergenceError, CholmodNotPositiveDefiniteError): + logger.info("refitting after dropping sdi and all covs") + # drop sdi and covariates if model still doesn't converge + gkmodel = _third_gk_model( + modeling_args=modeling_args, ds=ds, fxeff=fxeff, weight_decay=weight_decay + ) + params = gkmodel.fit() + + # apply post-hoc RELU if the dictionary defined in settings is non-empty + if RELU: + params = post_hoc_relu(params=params, cov_dict=RELU) + return gkmodel, params + + +def _first_gk_model( + modeling_args: CauseSpecificModelingArguments, + ds: xr.Dataset, + fxeff: Dict, + weight_decay: float, +) -> GKModel: + """Set up our ideal model.""" + raneff = _make_random_effects(modeling_args) + constants = _make_constants(modeling_args) + + gkmodel = GKModel( + ds, + years=modeling_args.years, + fixed_effects=fxeff, + random_effects=raneff, + draws=modeling_args.draws, + constants=constants, + y=Y, + omega_amp=OMEGA_AMP, + weight_decay=weight_decay, + seed=modeling_args.seed, + ) + + return gkmodel + + +def get_strategy_causes( + modeling_args: CauseSpecificModelingArguments, + gk_strategy_id: int, +) -> List[str]: + """Return cause list based on the strategy id.""" + with db_session.create_db_session() as session: + causes = strategy.get_cause_set( + session=session, + strategy_id=gk_strategy_id, + gbd_round_id=modeling_args.gbd_round_id, + ).acause.values.tolist() + + return causes + + +def _make_random_effects(modeling_args: CauseSpecificModelingArguments) -> Dict: + """Set up the random effects for the ideal model.""" + raneff = { + GAMMA_LOCATION_AGE: [INTERCEPT], + GAMMA_AGE: [TIME_VAR], + } + + road_traffic_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.ROAD_TRAFFIC_STRATEGY_ID + ) + + if modeling_args.acause in road_traffic_causes: + # drop time if present in gamma_age + raneff = {GAMMA_LOCATION_AGE: [INTERCEPT]} + + notime_maternal_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.MATERNAL_NOTIME_STRATEGY_ID + ) + + if modeling_args.acause in notime_maternal_causes: + raneff[GAMMA_AGE].remove(TIME_VAR) + if modeling_args.acause == "malaria": + raneff = { + GAMMA_LOCATION_AGE: [INTERCEPT], + GAMMA_LOCATION: [TIME_VAR], + } + + return raneff + + +def _make_constants(modeling_args: CauseSpecificModelingArguments) -> List[str]: + """Set up the constants for the ideal model.""" + parent_maternal = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.PARENT_MATERNAL_STRATEGY_ID + ) + if modeling_args.acause == parent_maternal[0]: + constants = [LN_RISK_SCALAR, ASFR] + else: + constants = [LN_RISK_SCALAR] + return constants + + +def _second_gk_model( + modeling_args: CauseSpecificModelingArguments, + ds: xr.Dataset, + fxeff: Dict, + weight_decay: float, + cov_to_drop: List[str], +) -> GKModel: + """Set up the GK Model, dropping some number of covariates. + + Intended for the case in which some number of covariates produced weird results. 
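+
+    Note that SDI-type terms were added to ``fxeff`` as ``(covariate, 0)`` while the
+    risk/ASFR/HIV covariates were added as ``(covariate, 1)``, which is why each
+    dropped covariate is removed with the matching tuple below.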
+ """ + for covariate in cov_to_drop: + if covariate in [ + SDI, + SDI_PART1, + SDI_PART2, + ]: + fxeff[BETA_GLOBAL].remove((covariate, 0)) + else: + fxeff[BETA_GLOBAL].remove((covariate, 1)) + + raneff = {GAMMA_LOCATION_AGE: [INTERCEPT]} + gkmodel = GKModel( + ds, + years=modeling_args.years, + fixed_effects=fxeff, + random_effects=raneff, + draws=modeling_args.draws, + constants=[LN_RISK_SCALAR], + y=Y, + omega_amp=OMEGA_AMP, + weight_decay=weight_decay, + seed=modeling_args.seed, + ) + return gkmodel + + +def _third_gk_model( + modeling_args: CauseSpecificModelingArguments, + ds: xr.Dataset, + fxeff: Dict, + weight_decay: float, +) -> GKModel: + """Set up the GK model for the case in which the original model didn't converge.""" + raneff = {GAMMA_LOCATION_AGE: [INTERCEPT]} + + road_traffic_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.ROAD_TRAFFIC_STRATEGY_ID + ) + + if modeling_args.acause in road_traffic_causes: + fxeff[BETA_GLOBAL] = [(INTERCEPT, 0)] + else: + fxeff[BETA_GLOBAL] = [ + (TIME_VAR, 0), + (INTERCEPT, 0), + ] + gkmodel = GKModel( + ds, + years=modeling_args.years, + fixed_effects=fxeff, + random_effects=raneff, + draws=modeling_args.draws, + constants=[LN_RISK_SCALAR], + y=Y, + omega_amp=OMEGA_AMP, + weight_decay=weight_decay, + seed=modeling_args.seed, + ) + return gkmodel + + +def _covariates_to_drop( + modeling_args: CauseSpecificModelingArguments, reis: List[str], params: xr.Dataset +) -> List[str]: + """Determine which covariates during model fitting need to be dropped.""" + cov_list = [SDI, SDI_PART1, SDI_PART2] + reis + + maternal_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.MATERNAL_CAUSES_STRATEGY_ID + ) + + if modeling_args.acause in maternal_causes: + cov_list += ["asfr"] + + return [c for c in cov_list if _is_cov_bad(c, params)] + + +def _is_cov_bad(covariate: str, params: xr.Dataset) -> bool: + """Determine if a covariate produced weird results when fitting. + + Currently we test if the SD is NaN or too large. + """ + try: + sd = params[BETA_GLOBAL].sel(cov=covariate).std() + median_coeff = params[BETA_GLOBAL].sel(cov=covariate).median() + except (ValueError, IndexError, KeyError): + logger.warning(f"Something went wrong processing covariate {covariate}; ignoring.") + return False + + sd_too_large = abs(sd / median_coeff) > CUTOFF + return np.isnan(sd) or sd_too_large + + +def _repeat_last_year(ds: xr.Dataset, years: YearRange, draws: int) -> xr.DataArray: + """Run a model for mortality using the RepeatLastYear model. + + Args: + ds (xr.Dataset): Dataset containing y as the response variable + and time_var as the time variable + years (YearRange): past and forecasted years (e.g. 1990:2017:2040) + draws (int): number of draws to return (will all be identical) + + Returns: + xr.DataArray: the past and projected values (held constant) + """ + logger.info("running repeat_last_year") + + rly = repeat_last_year.RepeatLastYear(ds["y"], years=years) + rly.fit() + preds = expand_dimensions(rly.predict(), draw=np.arange(draws)) + return preds + + +def _get_addcovs(modeling_args: CauseSpecificModelingArguments, acause: str) -> List[str]: + """Returns a list of the additional covariates associated with a given acause. 
+ + Args: + acause (str): Cause to find covariates for + + Returns: + List[str]: The non-sev covariates associated with the input acause + """ + addcovs = [] + + maternal_causes = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.MATERNAL_CAUSES_STRATEGY_ID + ) + cause_maternal_hiv = get_strategy_causes( + modeling_args=modeling_args, gk_strategy_id=CauseConstants.MATERNAL_HIV_STRATEGY_ID + ) + + if acause in maternal_causes: + addcovs.append("asfr") + if acause == cause_maternal_hiv[0]: + addcovs.append("hiv") + logger.info("Additional covariates", bindings=dict(addcovs=addcovs)) + + return addcovs + + +def _get_sex_name(sex_id: int) -> str: + """Gets the sex name associated with a given sex_id. + + Args: + sex_id (int): Numeric sex ID. + + Returns: + str: '_male' for 1 or '_female' for 2. + + Raises: + ValueError: If ``sex_id`` is invalid. + """ + try: + return "_" + SexConstants.SEX_DICT[sex_id] + except KeyError: + raise ValueError(f"sex_id must be {SexConstants.SEX_DICT.keys()}; {sex_id} given") + + +def _apply_location_parent_coefs_to_children(params: xr.Dataset, ds: xr.Dataset) -> xr.Dataset: + """Apply location parent coefs to children. + + Adds coefficient estimates for subnational locations to a dataset of + national coefficients by setting the level 4 subnational values to have + the same values as their level 3 national parent. + + Args: + params (xr.Dataset): coefficient/parameter dataset output from the + GKModel class. + ds (xr.Dataset): cause of death dataset output from load_cod_dataset, + which includes the level 4 subnational locations in need of their + parent parameters. + + Returns: + xr.Dataset: coefficient/parameter dataset with values for all locations in ds. + + Raises: + ValueError: If no subnationals in ``ds`` or if missing level 3 parent + locations in the supplied cod dataset. + """ + if not np.isin(LocationConstants.SUBNATIONAL_LEVEL, ds["level"].values): + raise ValueError("No subnationals in ds.") + + # get parents and location_ids of level 4 locations present in ds + locs = ( + ds["parent_id"] + .where(ds.level == LocationConstants.SUBNATIONAL_LEVEL) + .to_dataframe("parent_id") + .reset_index() + .dropna() + ) + parents = params[DimensionConstants.LOCATION_ID].values + parents_with_children = locs.parent_id.unique() + + if not np.isin(parents_with_children, parents).all(): + raise ValueError("Missing level 3 parent locations in the supplied cod dataset.") + + # replicate the parent parameters for all children, so that the parent + # location random effects get applied to the children in GKModel predict + add_params = [params] + for parent in parents_with_children: + parent_params = params.sel(location_id=int(parent), drop=True) + children = locs.query("parent_id == @parent").location_id.values + child_params = expand_dimensions(parent_params, location_id=children) + add_params.append(child_params) + + add_params = xr.concat(add_params, dim=DimensionConstants.LOCATION_ID) + + return add_params + + +def write_cod_forecast(data: xr.DataArray, file_spec: FHSFileSpec) -> None: + """Save a cod model and assert that it has the appropriate dimensions. + + Args: + data (xr.DataArray): Dataarray with cod mortality rate. + file_spec (FHSFileSpec): the spec where data should be saved (includes path and + scenario ID). + + Raises: + ValueError: If data is missing dimensions. 
+ """ + keys = ["year_id", "age_group_id", "location_id", "sex_id", "draw"] + missing = [k for k in keys if k not in list(data.coords.keys())] + if len(missing) > 0: + raise ValueError(f"Data is missing dimensions: {missing}.") + save_xr_scenario( + xr_obj=data, + file_spec=file_spec, + metric="rate", + space="log", + ) + + +def write_cod_betas( + model_params: xr.Dataset, + file_spec: FHSFileSpec, + acause: str, + save_draws: bool = True, +) -> None: + """Write the betas from a gk cod model run. + + Args: + model_params (xr.Dataset): Xarray Dataset with covariate and draw information. + file_spec (FHSFileSpec): the path to save betas into, as a scenario-aware FHSFileSpec. + acause (str): The string of the acause to place saved results in correct location. + save_draws (bool): Whether to save the estimated regression coefficients + (means) or samples from their posterior distribution (draws). + """ + model_params *= xr.DataArray([1], dims=["acause"], coords=[[acause]]) + if not save_draws: + model_params = model_params.mean("draw") + save_xr_scenario( + xr_obj=model_params, + file_spec=file_spec, + metric="rate", + space="log", + ) diff --git a/gbd_2021/disease_burden_forecast_code/mortality/lib/smoothing.py b/gbd_2021/disease_burden_forecast_code/mortality/lib/smoothing.py new file mode 100644 index 0000000..7cf190d --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/lib/smoothing.py @@ -0,0 +1,91 @@ +"""This module contains utilities related smoothing in the latent trend model.""" + +from typing import List, Union + +import pandas as pd +from fhs_lib_database_interface.lib.constants import CauseConstants, DimensionConstants +from fhs_lib_database_interface.lib.strategy_set import query +from sqlalchemy.orm import Session +from tiny_structured_logger.lib.fhs_logging import get_logger + +logger = get_logger() + + +def get_smoothing_dims(acause: str, gbd_round_id: int, session: Session) -> List[str]: + """Get the dims to smooth over during the ARIMA or Random Walk modeling. + + Args: + acause: The cause for which to get smoothing dimensions. + gbd_round_id (int): The numeric ID for the GBD round. + session: session with the database, used to get the cause hierarchy. + + Returns: + List[str]: The dimensions to smooth over. + """ + cause_hierarchy_version_id = query.get_hierarchy_version_id( + session=session, + entity_type=DimensionConstants.CAUSE, + entity_set_id=CauseConstants.FHS_CAUSE_SET_ID, + gbd_round_id=gbd_round_id, + ) + + cause_hierarchy = query.get_hierarchy( + session=session, + entity_type=DimensionConstants.CAUSE, + hierarchy_version_id=cause_hierarchy_version_id, + ) + return _smoothing_dims_from_hierarchy(acause, cause_hierarchy) + + +def _smoothing_dims_from_hierarchy(acause: str, cause_hierarchy: pd.DataFrame) -> List[str]: + """The dimensions to smooth over during ARIMA modeling, based on the level of acause. + + Args: + acause (str): The cause for which to get smoothing dimensions. + cause_hierarchy (pd.Dataframe): The cause hierarchy used to determine the level of the + given cause. Must have ``acause`` and ``level`` as columns. + + Returns: + List[str]: The dimensions to smooth over. 
+ """ + LOCATION_SEX_AGE_DIMS = [ + DimensionConstants.LOCATION_ID, + DimensionConstants.SEX_ID, + DimensionConstants.AGE_GROUP_ID, + ] + REGION_SEX_AGE_DIMS = [ + DimensionConstants.REGION_ID, + DimensionConstants.SEX_ID, + DimensionConstants.AGE_GROUP_ID, + ] + SUPERREGION_SEX_AGE_DIMS = [ + DimensionConstants.SUPER_REGION_ID, + DimensionConstants.SEX_ID, + DimensionConstants.AGE_GROUP_ID, + ] + SMOOTHING_LOOKUP = { + 0: LOCATION_SEX_AGE_DIMS, + 1: LOCATION_SEX_AGE_DIMS, + 2: REGION_SEX_AGE_DIMS, + 3: SUPERREGION_SEX_AGE_DIMS, + "modeled": SUPERREGION_SEX_AGE_DIMS, + } + + cause_level = cause_hierarchy.query(f"{DimensionConstants.ACAUSE} == @acause")[ + DimensionConstants.LEVEL_COL + ].unique()[0] + + logger.debug( + "Pulling smoothing dims", + bindings=dict(acause=acause, cause_level=cause_level), + ) + + level_option: Union[int, str] + if cause_level > 2: + level_option = "modeled" + else: + level_option = cause_level + + smoothing_dims = SMOOTHING_LOOKUP[level_option] + + return smoothing_dims diff --git a/gbd_2021/disease_burden_forecast_code/mortality/lib/squeeze.py b/gbd_2021/disease_burden_forecast_code/mortality/lib/squeeze.py new file mode 100644 index 0000000..f9122f6 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/lib/squeeze.py @@ -0,0 +1,433 @@ +"""Module for squeezing results into an envelope. + +The envelope starts with _all_star.nc in the hat_star_dot directory and is generated by +squeezing children causes into their parent envelopes. + +By running `python squeeze.py --parent-cause --squeeze-version + --star-version `, the parent acause will act as +the envelope to squeeze its children causes into and must +already be in the FILEPATH directory. +The one exception to this is if is "_all", in which is it +moved from the hat_star_dot/ directory into the +FILEPATH directory to begin the squeezing process. + +The things being squeezed are the ``*_star.nc`` files from the hat_star_dot +directory and are saved in the FILEPATH directory as a +separate mortality or yld version. + +The number of draws of the ``*_star.nc`` files must match the number of draws of +the _all_star.nc file, and subsequently the resulting squeezed mortality or +ylds in the FILEPATH directory must have the same number of draws. +""" + +from typing import List, Optional + +import numpy as np +import pandas as pd +import xarray as xr +from fhs_lib_database_interface.lib.constants import DimensionConstants +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec, VersionMetadata +from fhs_lib_file_interface.lib.versioning import validate_versions_scenarios_listed +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging + +from fhs_pipeline_mortality.lib import get_fatal_causes +from fhs_pipeline_mortality.lib.intercept_shift import intercept_shift_draws + +logger = fhs_logging.get_logger() + +FLOOR = 1e-28 + + +def squeeze( + parent_cause: str, + star_version: VersionMetadata, + squeeze_version: VersionMetadata, + past_version: Optional[VersionMetadata], + gbd_round_id: int, + input_space: str, + years: YearRange, + output_scenario: Optional[int], + dryrun: bool = False, + intercept_shift: bool = False, +) -> None: + """Main method for squeezing. + + Makes sure the sum of the parent_cause's children is equal to the parent + cause's estimates, then saves the squeezed child cause results. 
+
+    Args:
+        parent_cause (str): the acause whose estimates are used as the envelope
+            which its child causes are squeezed into.
+        star_version (VersionMetadata): the version of mortality with ARIMAed residuals
+            being squeezed.
+        squeeze_version (VersionMetadata): the mortality or yld version the squeezed
+            results will be saved as. The parent_cause envelope will
+            be of this squeeze_version.
+        past_version (Optional[VersionMetadata]): past mortality version to intercept
+            shift to, used when ``intercept_shift`` is True.
+        gbd_round_id (int): numeric id of gbd round.
+        input_space (str): space the input data is saved in, e.g. log.
+        years (YearRange): the years to run.
+        output_scenario (Optional[int]): The output scenario if running in
+            single-scenario mode.
+        dryrun (bool): When True, don't write anything, just do calculations and print
+            what files would be written. Defaults to False.
+        intercept_shift (bool): When True, intercept-shift the squeezed results so that
+            the last past year matches the data in ``past_version``. Defaults to False
+            (it is also left off in unit testing to isolate certain subsystems).
+
+    Raises:
+        ValueError: if the parent and children don't have the same number of draws, or
+            if ``input_space`` is not recognized.
+    """
+    versions_to_validate = [star_version, squeeze_version]
+    if past_version:
+        versions_to_validate.append(past_version)
+
+    validate_versions_scenarios_listed(
+        versions=versions_to_validate,
+        output_versions=[squeeze_version],
+        output_scenario=output_scenario,
+    )
+
+    children_sum = xr.DataArray(name="Empty xr.DataArray")
+    children_acauses = get_children_acauses(parent_cause, gbd_round_id)
+    logger.debug(f"Children acauses are: {children_acauses}")
+
+    if input_space == "log":
+        suffix = "_star"
+    elif input_space == "normal":
+        suffix = ""
+    else:
+        err_msg = "input_space argument not recognized. Options are `log` or `normal`."
+        raise ValueError(err_msg)
+
+    for child_acause in children_acauses:
+        child_data = get_y_star(child_acause, star_version, input_space, suffix)
+        children_sum = _update_children_sum(children_sum, child_data)
+
+    parent_data = open_xr_scenario(FHSFileSpec(squeeze_version, f"{parent_cause}.nc"))
+    parent_data.load()
+
+    if (
+        parent_data.coords[DimensionConstants.DRAW].shape
+        != children_sum.coords[DimensionConstants.DRAW].shape
+    ):
+        raise ValueError("The parent and child don't have the same number of draws.")
+
+    # Calculate ratio on means, not draws.
+    parent_data_mean = parent_data.mean("draw")
+    children_sum_mean = children_sum.mean("draw")
+    ratio = parent_data_mean / children_sum_mean
+
+    # Loop through and multiply all children by ratio, saving results.
+    for child in children_acauses:
+        logger.info(f"Squeezing {child}")
+        squeezed_child = _multiply_children_by_ratio(
+            star_version, child, ratio, suffix, input_space
+        )
+
+        _save_squeezed_child_result(
+            squeezed_child,
+            child,
+            gbd_round_id,
+            years,
+            squeeze_version,
+            past_version,
+            dryrun,
+            intercept_shift,
+        )
+
+
+def _save_squeezed_child_result(
+    squeezed_child: xr.DataArray,
+    child: str,
+    gbd_round_id: int,
+    years: YearRange,
+    squeeze_version: VersionMetadata,
+    past_version: Optional[VersionMetadata],
+    dryrun: bool,
+    intercept_shift: bool,
+) -> None:
+    """Save squeezed child result."""
+    # The intercept shift was originally added to fix the t+1 discontinuity introduced
+    # by updating malaria to a level-2 cause in the squeeze. That discontinuity is now
+    # resolved upstream (malaria is updated to level 2 before the ARIMA step and a new
+    # _ntd cause without malaria is created), so intercept_shift is left set to False.
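+    # When the shift is enabled, it is applied in log space (after adding the FLOOR
+    # constant) and the result is exponentiated back into rate space.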
+ if intercept_shift: + squeezed_child = squeezed_child.dropna(how="all", dim="sex_id") + log_squeezed_child = np.log(squeezed_child + FLOOR) + log_squeezed_child_shifted = intercept_shift_draws( + preds=log_squeezed_child, + acause=child, + past_version=past_version, + gbd_round_id=gbd_round_id, + years=years, + draws=len(log_squeezed_child.draw.values), + shift_function=unordered_draw_intercept_shift, + ) + squeezed_child = np.exp(log_squeezed_child_shifted) + + # fill NaNs with 0s + squeezed_child = squeezed_child.fillna(0.0) + + out_path = FHSFileSpec(squeeze_version, f"{child}.nc") + + if dryrun: + logger.info(f"(Dry run) Not writing to {out_path}") + else: + logger.info(f"Writing to {out_path}") + save_xr_scenario(squeezed_child, out_path, metric="rate", space="identity") + + +def _remove_acause_dim(da: xr.DataArray) -> xr.DataArray: + if DimensionConstants.ACAUSE in da.coords: + da = da.drop_vars(DimensionConstants.ACAUSE) + return da.squeeze() + + +def _multiply_children_by_ratio( + star_version: VersionMetadata, + child: str, + ratio: xr.DataArray, + suffix: str, + input_space: str, +) -> xr.DataArray: + """Multiply all children by ratio, saving results.""" + path_in = FHSFileSpec(star_version, f"{child}{suffix}.nc") + child_input_data = open_xr_scenario(path_in) + if input_space == "log": + child_data = np.exp(child_input_data) + else: + child_data = child_input_data + + child_data = _remove_acause_dim(child_data) + + # for single-sex causes, add nans to create dimension sex_id + child_data = makeup_array(child_data) + + squeezed_child = child_data * ratio + + return squeezed_child + + +def _update_children_sum( + children_sum: xr.DataArray, + child_data: xr.DataArray, +) -> xr.DataArray: + """Update children sum based on acause and sex. + + For causes with coord `acause`, drop `acause`. + And for ntd causes (the dataset does not have coord `acause`), squeeze the causes, + then add nans for single-sex causes. + Lastly, broadcast child data array against children sum data array. + """ + child_data = _remove_acause_dim(child_data) + + # for single-sex causes, add nans to create dimension sex_id + child_data = makeup_array(child_data) + + if children_sum.name == "Empty xr.DataArray": + children_sum = child_data + else: + children_broadcast = xr.broadcast(children_sum, child_data) + children_broadcast = [data.fillna(0.0) for data in children_broadcast] + children_sum = sum(children_broadcast) + children_sum.load() + + return children_sum + + +def unordered_draw_intercept_shift( + modeled_data: xr.DataArray, + past_data: xr.DataArray, + past_end_year_id: int, +) -> xr.DataArray: + """Intercept shift by the last past year in log space. + + Args: + modeled_data (xr.DataArray): FHS estimates + past_data (xr.DataArray): Past estimates from GBD + past_end_year_id (int): Last year of past data + + Returns: + xr.DataArray: Shifted estimates + """ + past_data = past_data.sel(**{DimensionConstants.YEAR_ID: past_end_year_id}, drop=True) + coords = {DimensionConstants.YEAR_ID: past_end_year_id} + if DimensionConstants.SCENARIO in modeled_data: + coords = dict( + coords, **{DimensionConstants.SCENARIO: DimensionConstants.REFERENCE_SCENARIO} + ) + modeled_last_past_year = modeled_data.sel(**coords, drop=True) + diff = modeled_last_past_year - past_data + shifted_data = modeled_data - diff + return shifted_data + + +def makeup_array(child_data: xr.DataArray) -> xr.DataArray: + """Add dimension sex_id for single_sex causes. 
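+
+    The added (opposite) sex is the original data multiplied by zero, so it
+    contributes nothing when children are later summed or squeezed.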
+ + Args: + child_data (xr.DataArray): child data with single_sex causes + + Returns: + xr.DataArray: child data with opposite_sex data + """ + if DimensionConstants.SEX_ID not in child_data.dims: + na_array = child_data.copy() + sex_value = child_data[DimensionConstants.SEX_ID].values + if sex_value == 1: + opposite_sex = 2 + else: + opposite_sex = 1 + na_array = na_array * 0.0 + na_array[DimensionConstants.SEX_ID] = opposite_sex + child_data = xr.concat([child_data, na_array], dim=DimensionConstants.SEX_ID) + return child_data + return child_data + + +def _drop_measure(data: xr.DataArray) -> xr.DataArray: + """Helper function as a preprocessing step when calling xarray's open_mfdataset function. + + In order to make sure all the dimensions are compatible across dataarrays. + + Args: + data (xr.DataArray): the dataarray to preprocess + + Returns: + xr.DataArray: the original data with the measure dimension dropped + if it originally existed. + """ + try: + data = data.drop_vars("measure") + except ValueError: + pass + try: + data.rename({"ds = da.to_dataset": "value"}) + except ValueError: + pass + return data + + +def get_y_star( + acause: str, + star_version: VersionMetadata, + input_space: str, + suffix: str, + draws: bool = True, +) -> xr.DataArray: + """Gets estimates of means and modeled residuals (``*_star.nc`` file) for a cause. + + Results are returned in normal space regardless of the input space. + + Args: + acause_list (str): acause to get data for. + star_version (VersionMetadata): the version whose data is being read. + input_space (str): "log" or "normal" as describes the space of the data; it will be + transformed to normal space. + suffix (str): a str to be appended to each cause name to produce the file names. + draws (bool): if False, the mean over the draws dimension is used. + + Returns: + xr.DataArray: data in regular rate space which contains a dimension for acause. + """ + input_data = open_xr_scenario(FHSFileSpec(star_version, f"{acause}{suffix}.nc")) + + if not draws: + input_data = input_data.mean(DimensionConstants.DRAW) # take mean to get rid of draws + + if input_space == "log": + y_star = np.exp(input_data) # exponentiate into normal space + elif input_space == "normal": + y_star = input_data + return y_star + + +def get_children_acauses(acause: str, gbd_round_id: int) -> List[str]: + """Gets the children acauses for a given acause. + + Does not include any that are in the CAUSES_TO_EXCLUDE list + + Args: + acause (str): the acause of the cause to find children of. + gbd_round_id (int): numeric id of gbd round. + + Returns: + List[str]: the children acauses of the input acause. 
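+
+    For example, for ``acause="_all"`` this returns its immediate (level 1) children,
+    since ``_all`` and ``_none`` are always excluded from the result.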
+ """ + fatal_causes = get_fatal_causes.get_fatal_causes_df(gbd_round_id)[ + [ + DimensionConstants.ACAUSE, + DimensionConstants.CAUSE_ID, + DimensionConstants.PARENT_ID_COL, + ] + ] + return get_children_acauses_from_dataframe(acause, fatal_causes) + + +def get_children_acauses_from_dataframe( + acause: str, cause_hierarchy: pd.DataFrame +) -> List[str]: + """Get the children of the given acause from a dataframe representing the hierarchy.""" + cause_id = cause_hierarchy[cause_hierarchy.acause == acause].cause_id.values[0] + all_children = cause_hierarchy.query(f"parent_id == {cause_id}")[ + DimensionConstants.ACAUSE + ].values + children = [child for child in all_children if child not in ("_all", "_none")] + return children + + +def copy_all_star( + star_version: VersionMetadata, + squeeze_version: VersionMetadata, + input_space: str, + output_scenario: Optional[int], + dryrun: bool = False, +) -> None: + """Copies _all data from star_version into squeeze_version, possibly translating. + + If input_space is log, the data is converted from log rate space into regular rate space + before being saved. In that case, too, the file will be found in _all_star.nc rather than + _all.nc + + Args: + star_version (str): the version of mortality with ARIMAed residuals + squeeze_version (str): the version of mortality with squeezed results + input_space (str): space the input data is saved in. + root_dir (str): FHS file system root directory. Where the input data is stored. + output_scenario (Optional[int]): The output scenario if running in single scenario mode + dryrun (bool): flag for dry runs + """ + versions_to_validate = [star_version, squeeze_version] + + validate_versions_scenarios_listed( + versions=versions_to_validate, + output_versions=[squeeze_version], + output_scenario=output_scenario, + ) + + if input_space == "log": + suffix = "_star" + else: + suffix = "" + all_input_data = open_xr_scenario(FHSFileSpec(star_version, f"_all{suffix}.nc")) + + if input_space == "log": + all_data = np.exp(all_input_data) + else: + all_data = all_input_data + + out_path = FHSFileSpec(squeeze_version, "_all.nc") + if dryrun: + logger.info(f"(Dry run) Not writing to {out_path}") + else: + logger.info(f"Writing to {out_path}") + save_xr_scenario( + all_data, + out_path, + metric="rate", + space="identity", + ) diff --git a/gbd_2021/disease_burden_forecast_code/mortality/lib/sum_to_all_cause.py b/gbd_2021/disease_burden_forecast_code/mortality/lib/sum_to_all_cause.py new file mode 100644 index 0000000..2f6fa30 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/lib/sum_to_all_cause.py @@ -0,0 +1,425 @@ +""" +This module aggregates the expected value of mortality or ylds up the cause +hierarchy and computes the following at each step: + + y_hat = expected value of mortality or ylds + y_past = past mortality or ylds + +This is also used to aggregate post-approximation mortality results which are +saved in normal space. All results are saved in their own directory. + +It is assumed that the inputs for modeled cause specific mortality or ylds with +version , contain y_hat data in log space. That is, only draws of the +expected value. For aggregated approximated mortality results, no assumptions +are made and modeling space is set with the --space argument. All data in +normal space will be saved in the `data` root directory. Log space data will be +saved in the `int` root directory. + +Spaces: + - Modeled cause results split by sex are in log rate space. 
+ - Modeled cause results not split by sex are in normal rate space. This can be specified. + - Past cause mortality or ylds are in normal space. + - Summing to aggregate causes happens in normal rate space. +""" + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_database_interface.lib import db_session +from fhs_lib_database_interface.lib.constants import ( + CauseConstants, + DimensionConstants, + FHSDBConstants, +) +from fhs_lib_database_interface.lib.strategy_set import strategy +from fhs_lib_file_interface.lib import provenance +from fhs_lib_file_interface.lib.provenance import ProvenanceManager +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import validate_versions_scenarios_listed +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from tiny_structured_logger.lib import fhs_logging + +from fhs_pipeline_mortality.lib import get_fatal_causes +from fhs_pipeline_mortality.lib.config_dataclasses import SumToAllCauseModelingArguments + +logger = fhs_logging.get_logger() + +FLOOR = 1e-28 + +EXPECTED_VALUE_SUFFIX = "_hat" +INTERMEDIATE_ROOT_DIR = "int" +DATA_ROOT_DIR = "data" +NORMAL_SPACE = "identity" +LOG_SPACE = "log" + + +def aggregate_yhats( + modeling_args: SumToAllCauseModelingArguments, + dryrun: bool = False, +) -> None: + """Computes y hats. + + If acause is a modeled cause, then it is simply moved into the right folder. Otherwise, + the subcauses of acause are aggregated. + + Args: + modeling_args (SumToAllCauseModelingArguments): dataclass containing sum-to-all-cause + modeling arguments. + dryrun (bool): dryrun flag. Don't actually do anything, just pretend like you're doing + it. + """ + validate_versions_scenarios_listed( + versions=[modeling_args.agg_version, modeling_args.input_version], + output_versions=[modeling_args.agg_version], + output_scenario=modeling_args.output_scenario, + ) + + if modeling_args.approximation: + suffix = "" + root_dir = DATA_ROOT_DIR + space = NORMAL_SPACE + else: + suffix = EXPECTED_VALUE_SUFFIX + root_dir = INTERMEDIATE_ROOT_DIR + space = LOG_SPACE + + y_hat = _get_y_hat( + modeling_args=modeling_args, space=space, suffix=suffix, root_dir=root_dir + ) + num_draws = len(y_hat[DimensionConstants.DRAW].values) + logger.info(f"The data has {num_draws} draws.") + y_hat_out = FHSFileSpec( + modeling_args.agg_version.with_root_dir(root_dir), f"{modeling_args.acause}{suffix}.nc" + ) + if dryrun: + logger.info("(Dry run). Not saving y_hat", bindings=dict(out_file=y_hat_out)) + else: + # Save data + logger.info("Saving y_hat", bindings=dict(out_file=y_hat_out)) + save_xr_scenario( + xr_obj=y_hat, + file_spec=y_hat_out, + metric="rate", + space=space, + ) + + +def _get_y_hat( + modeling_args: SumToAllCauseModelingArguments, + space: str, + suffix: str, + root_dir: str, +) -> xr.DataArray: + """Gets expected value of cause specific mortality or yld rates. + + For modeled causes, if the data is split by sex, then it is assumed that it + is in log rate space. If the data is not split by sex, then it is assumed + that it is in normal rate space. + + For aggregate causes, it is assumed that the data is not split by sex and + is saved in log rate space. + + The resulting y_hat is in log rate space. + + Args: + modeling_args (SumToAllCauseModelingArguments): dataclass containing sum-to-all-cause + modeling arguments. + space (str): space the data should be returned in, e.g. 
"log" or "identity". + suffix (str): stage 2 should save data with "_hat" + root_dir (str): Root directory for saving data + + Returns: + xr.DataArray: The expected value of the cause specific mortality or yld rate + """ + with db_session.create_db_session(FHSDBConstants.FORECASTING_DB_NAME) as session: + gk_causes = strategy.get_cause_set( + session=session, + strategy_id=CauseConstants.FATAL_GK_STRATEGY_ID, + gbd_round_id=modeling_args.gbd_round_id, + )[DimensionConstants.ACAUSE].values + + if modeling_args.period == "past": + gk_causes = np.delete(gk_causes, np.where(gk_causes == "maternal")) + gk_causes = np.delete(gk_causes, np.where(gk_causes == "ckd")) + + if modeling_args.acause in gk_causes: + logger.info("modeled cause", bindings=(dict(modeled_cause=modeling_args.acause))) + y_hat = _get_modeled_y_hat(modeling_args=modeling_args, space=space) + + else: + logger.info( + "aggregated cause.", bindings=(dict(aggregated_cause=modeling_args.acause)) + ) + y_hat = _get_aggregated_y_hat( + modeling_args=modeling_args, space=space, suffix=suffix, root_dir=root_dir + ) + + if isinstance(y_hat, xr.Dataset): + if len(y_hat.data_vars) == 1: + y_hat = y_hat.rename({list(y_hat.data_vars.keys())[0]: "value"}) + return y_hat["value"] + logger.info( + "Using __xarray_dataarray_variable__, but other data_vars are present! " + "(probably just acause)" + ) + y_hat = y_hat.rename({"__xarray_dataarray_variable__": "value"}) + else: + y_hat.name = "value" + return y_hat + + +def _get_aggregated_y_hat( + modeling_args: SumToAllCauseModelingArguments, + space: str, + suffix: str, + root_dir: str, +) -> xr.DataArray: + """Gets expected value of cause specific mortality rates. + + For aggregate causes, it is assumed that the data is not split by sex and + is saved in log rate space. + + When the children are added to form the aggregated acause result, the + summation happens in normal space. Therefore, we must exponentiate the + children's rates, add them up, and log them to get an aggregated + y_hat in log rate space. If data is in normal space, simply sum. + + The resulting y_hat is in log rate space. + + Args: + modeling_args (SumToAllCauseModelingArguments): dataclass containing sum-to-all-cause + modeling arguments. + space (str): Space the data is loaded/saved in. (Summing happens in linear space.) + suffix (str): stage 2 should save data with "_hat" + root_dir (str): Root directory for saving data + + Raises: + Exception: If there are any missing files. + + Returns: + xr.DataArray: he expected value of the cause specific mortality rate. + """ + fatal_causes = get_fatal_causes.get_fatal_causes_df( + gbd_round_id=modeling_args.gbd_round_id + ) + + cause_id = fatal_causes[fatal_causes.acause == modeling_args.acause].cause_id.values[0] + children = fatal_causes.query("parent_id == {}".format(cause_id))[ + DimensionConstants.ACAUSE + ].values + logger.info("y_hat is a sum of children", bindings=(dict(children=children))) + + # Create a list of child acause files which are not external causes and + # check to make sure all the ones we want to sum up are actually present. 
+ + base_vm = modeling_args.agg_version.with_root_dir(root_dir) + + potential_child_files = [ + FHSFileSpec(base_vm, f"{child}{suffix}.nc") + for child in children + if child not in (CauseConstants.ALL_ACAUSE, "_none") + ] + child_files = list(filter(provenance.ProvenanceManager.exists, potential_child_files)) + + if len(potential_child_files) != len(child_files): + missing_children = list(set(potential_child_files) - set(child_files)) + msg = f"You are missing files: {missing_children}." + logger.error( + msg, + bindings=( + dict(potential_child_files=potential_child_files, child_files=child_files), + ), + ) + raise Exception(msg) + logger.debug("Summing these files: ", bindings=dict(child_files=child_files)) + + exp_y_hat_sum = None + for child_file in child_files: + logger.info("Adding child file", bindings=dict(child_file=str(child_file))) + exp_y_hat = transform_spaces( + da=open_xr_scenario(file_spec=child_file).drop_vars( + [DimensionConstants.MEASURE, "cov"], errors="ignore" + ), + src_space=space, + dest_space="identity", + ) + # Remove child acause coordinates so they don't interfere with + # broadcasting and summing. Parent acause will be added back in later. + if DimensionConstants.ACAUSE in exp_y_hat.coords: + exp_y_hat = exp_y_hat.drop_vars(DimensionConstants.ACAUSE) + if exp_y_hat_sum is None: + exp_y_hat_sum = exp_y_hat + else: + exp_y_hat_broadcasted = xr.broadcast(exp_y_hat_sum, exp_y_hat) + exp_y_hat_broadcasted = [data.fillna(0.0) for data in exp_y_hat_broadcasted] + exp_y_hat_sum = sum(exp_y_hat_broadcasted) + + y_hat = transform_spaces(da=exp_y_hat_sum, src_space="identity", dest_space=space) + if DimensionConstants.ACAUSE not in y_hat.coords: + try: + y_hat[DimensionConstants.ACAUSE] = modeling_args.acause + except ValueError: + y_hat = y_hat.squeeze("acause") + y_hat[DimensionConstants.ACAUSE] = modeling_args.acause + elif modeling_args.acause not in y_hat.coords[DimensionConstants.ACAUSE]: + y_hat[DimensionConstants.ACAUSE] = [modeling_args.acause] + return y_hat + + +def expand_sex_id(ds: xr.Dataset) -> xr.Dataset: + """Expand dimension 'sex_id' function for use in open_mfdataset. + + Args: + ds (xr.Dataset): The loaded dataset before concatenation with scalar + coordinate for sex_id dimension. + + Returns: + xr.Dataset: The dataset with expanded 1D sex_id + """ + if DimensionConstants.SEX_ID in ds.dims: + return ds + return ds.expand_dims(DimensionConstants.SEX_ID) + + +def transform_spaces(da: xr.DataArray, src_space: str, dest_space: str) -> xr.DataArray: + """Convert ``da`` from the ``src_space`` to the ``dest_space``. + + Each space should be either "identity" (linear) or "log". + """ + if src_space == dest_space: + return da + if src_space == NORMAL_SPACE and dest_space == LOG_SPACE: + return np.log(da) + if src_space == LOG_SPACE and dest_space == NORMAL_SPACE: + return np.exp(da) + raise ValueError(f"Unknown space conversion from {src_space} to {dest_space}") + + +def transform_spaces_adding_floor( + da: xr.DataArray, src_space: str, dest_space: str +) -> xr.DataArray: + """Convert ``da`` from the ``src_space`` to the ``dest_space``, maybe adding a floor. + + We add a tiny floor constant when going into log space, so that we don't get -inf. + + Each space should be either "identity" (linear) or "log". 
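+
+    For example, when converting from "identity" to "log", a zero rate becomes
+    log(0 + 1e-28) rather than -inf.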
+ """ + if src_space == NORMAL_SPACE and dest_space == LOG_SPACE: + da = da + FLOOR + return transform_spaces(da, src_space, dest_space) + + +def open_xr_scenario_resample(file_spec: FHSFileSpec, draws: int) -> xr.DataArray: + """It is extremely common that we open a file and resample it immediately. + """ + return resample(open_xr_scenario(file_spec=file_spec), draws) + + +def _get_modeled_y_hat( + modeling_args: SumToAllCauseModelingArguments, + space: str, +) -> xr.DataArray: + """Gets mortality data for a modeled acause. + + For modeled causes, if the data is split by sex, then it is assumed that it + is in log rate space. If the data is not split by sex, then it is assumed + that it is in normal rate space. + + Args: + modeling_args (SumToAllCauseModelingArguments): dataclass containing sum-to-all-cause + modeling arguments. + space (str): Space the data should be returned in + + Raises: + IOError: If this cause has no modeled mortality/ylds for this version + + Returns: + xr.DataArray: the mortality or yld data for a cause + """ + input_file_spec = FHSFileSpec(modeling_args.input_version, f"{modeling_args.acause}.nc") + logger.info( + "No children. y_hat is from mort/yld file", + bindings=dict(input_file=input_file_spec), + ) + + if ProvenanceManager.exists(input_file_spec): + # Sex-combined data is assumed to be in linear space. + src_space = "identity" + else: + # Sex-split data is assumed to be in log space. + src_space = "log" + + if ProvenanceManager.exists(input_file_spec): + y_hat = open_xr_scenario_resample(input_file_spec, modeling_args.draws) + if isinstance(y_hat, xr.Dataset): + y_hat = y_hat["value"] + else: + # Modeled data is split by sex. + potential_input_files = [ + FHSFileSpec(modeling_args.input_version, f"{modeling_args.acause}_{sex_name}.nc") + for sex_name in ["male", "female"] + ] + input_files = list(filter(ProvenanceManager.exists, potential_input_files)) + logger.info( + "Input results are split by sex. Files are", + bindings=dict(input_files=input_files), + ) + if len(input_files) == 0: + msg = "This cause has no modeled mortality/ylds for this version." + logger.error( + msg, + bindings=( + dict(acause=modeling_args.acause, version=modeling_args.input_version) + ), + ) + raise IOError(msg) + y_hat_list = [ + open_xr_scenario_resample(input_file_spec, modeling_args.draws) + for input_file_spec in input_files + ] + if modeling_args.period == "future": + y_hat_list = [ + y_hat.drop_vars([DimensionConstants.MEASURE, "cov"], errors="ignore") + for y_hat in y_hat_list + ] + + y_hat = xr.concat(map(expand_sex_id, y_hat_list), dim=DimensionConstants.SEX_ID) + + if modeling_args.period == "future" and len(input_files) == 1: + sex_id = _sex_id_from_filename(input_files[0].filename) + if DimensionConstants.SEX_ID in y_hat.dims: + y_hat.coords[DimensionConstants.SEX_ID] = [sex_id] + else: + y_hat = y_hat.expand_dims({DimensionConstants.SEX_ID: [sex_id]}) + y_hat = y_hat.to_dataset(name="value") + # Note: Previous cases treat acause as a dim, this one as a coord. 
+ y_hat.coords[DimensionConstants.ACAUSE] = modeling_args.acause + + if ProvenanceManager.exists(input_file_spec) or modeling_args.period == "past": + y_hat = _ensure_acause_dim(y_hat, acause=modeling_args.acause) + elif modeling_args.period == "future" and len(input_files) == 1: + y_hat.coords[DimensionConstants.ACAUSE] = modeling_args.acause + + y_hat = transform_spaces(y_hat, src_space=src_space, dest_space="identity") + y_hat = transform_spaces_adding_floor(y_hat, src_space="identity", dest_space=space) + return y_hat + + +def _sex_id_from_filename(filename: str) -> int: + """Return sex_id 1 or 2 based on filename containing "male" or "female".""" + if "female" in filename: + sex_id = 2 + else: + sex_id = 1 + return sex_id + + +def _ensure_acause_dim(da: xr.Dataset, acause: str) -> xr.Dataset: + """Ensure acause dimension is present in da.""" + if DimensionConstants.ACAUSE in da.dims: + if acause not in da.coords[DimensionConstants.ACAUSE]: + da[DimensionConstants.ACAUSE] = [acause] + else: + da = da.expand_dims({DimensionConstants.ACAUSE: [acause]}) + + return da diff --git a/gbd_2021/disease_burden_forecast_code/mortality/lib/y_star.py b/gbd_2021/disease_burden_forecast_code/mortality/lib/y_star.py new file mode 100644 index 0000000..2246a12 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/lib/y_star.py @@ -0,0 +1,495 @@ +r"""Compute Module. + +This module computes :math:`y^*` (y-star), which is the sum of the latent trend +component, i.e., the epsilon or residual error predictions from an ARIMA or Random Walk model, +and the :math:`\hat{y}` (y-hat) predictions from the GK model. It is assumed that the inputs +are given in log-rate space. +""" + +from functools import partial +from typing import List + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.correlate import correlate_draws +from fhs_lib_data_transformation.lib.exponentiate_draws import bias_exp_new +from fhs_lib_data_transformation.lib.intercept_shift import ( + ordered_draw_intercept_shift, + unordered_draw_intercept_shift, +) +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_database_interface.lib import db_session +from fhs_lib_database_interface.lib.constants import ( + CauseConstants, + DimensionConstants, + ScenarioConstants, +) +from fhs_lib_database_interface.lib.query.location import get_location_set +from fhs_lib_database_interface.lib.strategy_set.strategy import get_cause_set +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec, VersionMetadata +from fhs_lib_file_interface.lib.versioning import validate_versions_scenarios_listed +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_model.lib.random_walk.pooled_random_walk import PooledRandomWalk +from fhs_lib_model.lib.random_walk.random_walk import RandomWalk +from fhs_lib_model.lib.remove_drift import get_decayed_drift_preds +from fhs_lib_year_range_manager.lib.year_range import YearRange +from sqlalchemy.orm import Session +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_pipeline_mortality.lib.intercept_shift import intercept_shift_draws +from fhs_pipeline_mortality.lib.smoothing import get_smoothing_dims + +logger = get_logger() + +LOG_SPACE = "log" +NO_DRIFT_LOCATIONS = [ + 7, 11, 16, 29, 34, 39, 41, 50, 26, 30, 165, 169, 173, 180, 184, 190, 194, 196, 197, 198, + 207, 208, 214, 218, 157, 202, 217, 168, 171, 182, 191, 200, 201, 204, 215, 216, 205, 122, + 129, 422, 111, 170, 172, 175, 176, 177, 
178, 179, 181, 185, 87, 189, 193, 195, 203, 206, + 209, 210, 211, 212, 213, 435, 40, 164, +] + + +def calculate_y_star( + acause: str, + agg_version: VersionMetadata, + epsilon_version: VersionMetadata, + past_version: VersionMetadata, + years: YearRange, + gbd_round_id: int, + draws: int, + decay: float, + intercept_shift: bool, + bias_correction: bool, + national_only: bool, + seed: int | None, + output_scenario: int | None, +) -> None: + r"""Calculates the :math:`y^*` -- aggregate of latent trend & :math:`\hat{y}` predictions. + + Samples mortality residuals from a latent trend model, e.g, ARIMA or Random Walk model, in + order to aggregate the latent trend and the dependent variable predictions from the GK + model, i.e, the :math:`\hat{y}`. Formally, the operation is as follows: + + .. math:: + + y^* = \hat{y} + \hat{\epsilon} + + Args: + acause (str): Name of the target acause to aggregate to. + agg_version (VersionMetadata): Name of the aggregate version. + epsilon_version (VersionMetadata): Name of the latent trend prediction version. + past_version (VersionMetadata): The version containing actual raw data, :math:`y`, for + past years. + years (YearRange): A container for the three years, which define our forecast. + gbd_round_id (int): The numeric ID for the GBD round. + draws (int): Number of draws to take. + decay (float): Rate at which the slope of the line decays once forecasts start. + intercept_shift (bool): Whether to intercept shift the :math:`y^*` results. + bias_correction (bool): Whether to perform log bias correction. + national_only (bool): Whether to include subnational locations, or to include only + nations. + seed (Optional[int]): An optional seed to set for numpy's random number generation + output_scenario (int | None): The scenario for the output if running in single scenario + mode. + """ + validate_versions_scenarios_listed( + versions=[agg_version, epsilon_version, past_version], + output_versions=[agg_version, epsilon_version], + output_scenario=output_scenario, + ) + + legacy_scenario_mode = True if output_scenario is None else False + + logger.info(f"Computing y^* for {acause}", bindings=dict(acause=acause)) + + y_hat_infile = FHSFileSpec(agg_version, f"{acause}_hat.nc") + + logger.debug("Opening y-hat data file", bindings=dict(y_hat_file=str(y_hat_infile))) + y_hat = open_xr_scenario(y_hat_infile).sel(**{DimensionConstants.YEAR_ID: years.years}) + + # Make sure y_past only has those locations in y_hat. + location_ids = y_hat.location_id.values + + # GK intercept shift + y_hat = intercept_shift_draws( + preds=y_hat, + acause=acause, + past_version=past_version.version, + gbd_round_id=gbd_round_id, + years=years, + draws=draws, + shift_function=partial( + ordered_draw_intercept_shift, + modeled_order_year_id=years.past_end, + shift_from_reference=legacy_scenario_mode, + ), + ) + save_xr_scenario( + y_hat, + FHSFileSpec(epsilon_version, f"{acause}_shifted.nc"), + metric=DimensionConstants.RATE_METRIC, + space=LOG_SPACE, + ) + + # Get actual past data (in log space). + y_past = _get_y_past(acause=acause, years=years, past_version=past_version).sel( + **{DimensionConstants.LOCATION_ID: location_ids} + ) + + with db_session.create_db_session() as session: + is_ntd = _is_ntd(acause, gbd_round_id, session) + smoothing = get_smoothing_dims(acause, gbd_round_id, session) + + if not is_ntd: + logger.debug("Including latent trend") + + # NOTE: latent trend model should be done for every cause except any of the Neglected + # Tropical Diseases (NTDs). 
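+        # The latent-trend step: epsilon_past = y_past - y_hat over the past years
+        # (averaged over draws); _draw_epsilons then forecasts it (attenuated drift
+        # plus random walk for all-cause, pooled random walk otherwise), and finally
+        # y* = y_hat + correlated epsilon draws.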
+ logger.debug("Computing ``epsilon_past``") + + # NOTE: epsilon_hat is calculated purely from past epsilon, and is scenario-less + epsilon_hat_outfile = FHSFileSpec( + epsilon_version.with_scenario(None), f"{acause}_eps.nc" + ) + + # Calculate past epsilon. + epsilon_past = y_past.sel( + **{DimensionConstants.YEAR_ID: years.past_years} + ) - y_hat.sel(**{DimensionConstants.YEAR_ID: years.past_years}) + + if legacy_scenario_mode: + epsilon_past = epsilon_past.sel( + scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD, + drop=True, + ) + epsilon_past = epsilon_past.mean(DimensionConstants.DRAW) + + # Create future epsilon predictions. + epsilon_hat = _draw_epsilons( + epsilon_past=epsilon_past, + draws=draws, + smoothing=smoothing, + years=years, + acause=acause, + decay=decay, + gbd_round_id=gbd_round_id, + national_only=national_only, + seed=seed, + ) + + # Save epsilon predictions for future use. + save_xr_scenario( + epsilon_hat, + epsilon_hat_outfile, + metric=DimensionConstants.RATE_METRIC, + space=LOG_SPACE, + ) + + # Compute initial y-star (before bias correction and/or intercept shift). + y_star = _get_y_star( + y_hat=y_hat, + epsilon_hat=epsilon_hat, + years=years, + epsilon_version=epsilon_version, + ).copy() + + else: + logger.debug("No latent trend") + + # Without latent trend component, y-star is just y-hat. + y_star = y_hat + y_star.name = "value" + + # NOTE: by this point in the code, in legacy_scenario mode, y_star has all 3 "legacy" + # scenarios. In single-scenario mode, y_star has just the scenario of interest + + if bias_correction: + logger.debug("Performing bias correction on logged draws") + y_star = np.log(bias_exp_new(y_star)) + else: + logger.debug("Leaving logged draws raw in terms of bias") + + if intercept_shift: + logger.debug("Performing intercept shift on y^* results") + y_star = intercept_shift_draws( + preds=y_star, + acause=acause, + past_version=past_version.version, + gbd_round_id=gbd_round_id, + years=years, + draws=draws, + shift_function=partial( + unordered_draw_intercept_shift, shift_from_reference=legacy_scenario_mode + ), + ) + else: + logger.debug("NOT applying intercept shift on y^* results") + + y_star_outfile = FHSFileSpec(epsilon_version, f"{acause}_star.nc") + logger.debug("Saving ``y_star``", bindings=dict(y_star_outfile=y_star_outfile)) + save_xr_scenario( + y_star, + y_star_outfile, + metric=DimensionConstants.RATE_METRIC, + space=LOG_SPACE, + ) + + +def _is_ntd(acause: str, gbd_round_id: int, session: Session) -> bool: + ntd_cause_set = get_cause_set( + session=session, + strategy_id=CauseConstants.NTD_STRATEGY_ID, + gbd_round_id=gbd_round_id, + ) + + return acause in ntd_cause_set[DimensionConstants.ACAUSE] + + +def _get_y_star( + y_hat: xr.DataArray, + epsilon_hat: xr.DataArray, + years: YearRange, + epsilon_version: VersionMetadata, +) -> xr.DataArray: + """Returns draws of mortality or yld rates with estimated uncertainty. + + Args: + y_hat (xr.DataArray): Expected value of mortality or yld rates. + epsilon_hat (xr.DataArray): expected value of error. + years (YearRange): The forecasting time series year range. + epsilon_version (VersionMetadata): versioning information for the epsilon data + + Returns: + xr.DataArray: draws of mortality or yld rates with estimated uncertainty. 
+ """ + logger.debug("Creating ``y_star`` by adding ``y_hat`` with ``epsilon_hat``") + + draws = len(epsilon_hat.coords[DimensionConstants.DRAW]) + logger.debug( + "Make sure ``y_hat`` has the right number of draws -- resample if needed", + bindings=dict(expected_draws=draws), + ) + y_hat_resampled = resample(y_hat, draws) + + # Make sure the dimensions of two dataarrays are in the same order. + if "acause" in epsilon_hat.coords: + try: + epsilon_hat = epsilon_hat.squeeze("acause").drop_vars("acause") + except KeyError: + epsilon_hat = epsilon_hat.drop_vars("acause") + dimension_order = list(epsilon_hat.coords.dims) + if epsilon_version.scenario is None: # legacy scenarios mode. + dimension_order += [DimensionConstants.SCENARIO] + + y_hat_resampled = y_hat_resampled.transpose(*dimension_order) + + # Correlate the time series draws with modeled estimates for uncertainty. + epsilon_hat_cleaned = _clean_data(data=epsilon_hat, epsilon_version=epsilon_version) + y_hat_cleaned = _clean_data(data=y_hat_resampled, epsilon_version=epsilon_version) + + epsilon_correlated = correlate_draws( + epsilon_draws=epsilon_hat_cleaned, + modeled_draws=y_hat_cleaned, + years=years, + ) + + return y_hat_resampled + epsilon_correlated + + +def _clean_data(data: xr.DataArray, epsilon_version: VersionMetadata) -> xr.DataArray: + """Strips data. + + Strips ``acause`` and ``scenario`` dims if they exist, and expands ``sex_id`` into a one + coord dim, if there is a ``sex_id`` point coord. + + Args: + data (xr.DataArray): input data. + epsilon_version (VersionMetadata): the scenario number identifying the "reference" case + + Returns: + xr.DataArray: cleaned data. + """ + if DimensionConstants.SCENARIO in data.dims: + scenario_to_sel = determine_scenario_to_select(epsilon_version) + cleaned_data = data.sel(**{DimensionConstants.SCENARIO: scenario_to_sel}, drop=True) + else: + cleaned_data = data.copy() + + if DimensionConstants.ACAUSE in data.dims: + logger.debug("acause is a dimension") + acause = cleaned_data[DimensionConstants.ACAUSE].values[0] + cleaned_data = cleaned_data.sel(**{DimensionConstants.ACAUSE: acause}).drop_vars( + [DimensionConstants.ACAUSE] + ) + elif DimensionConstants.ACAUSE in cleaned_data.coords: + logger.debug("acause is a point coordinate") + cleaned_data = cleaned_data.drop_vars([DimensionConstants.ACAUSE]) + else: + logger.debug("acause is NOT a dim") + + # Ensure sex-id is a dimension + if DimensionConstants.SEX_ID in cleaned_data.dims: + logger.debug("sex_id is a dimension") + elif DimensionConstants.SEX_ID in cleaned_data.coords: + logger.debug("sex_id is a point coordinate") + cleaned_data = cleaned_data.expand_dims(DimensionConstants.SEX_ID) + else: + logger.debug("sex_id is not a dim") + + return cleaned_data + + +def _draw_epsilons( + epsilon_past: xr.DataArray, + draws: int, + smoothing: List[str], + years: YearRange, + acause: str, + decay: float, + gbd_round_id: int, + national_only: bool, + seed: int | None, +) -> xr.DataArray: + """Draws forecasts for epsilons. + + For all-cause, this is done by running an attenuated drift model first to remove some of + the time trend from the epsilons. Then for all causes (including all-cause) except NTD's, + a Pooled AR1 (i.e., and ARIMA variant) or Random walk model is used to forecast the + remaining residuals and generate expanding uncertainty. + + Args: + epsilon_past (xr.DataArray): Past epsilons, i.e., error of predictions based on data. + draws (int): Number of draws to grab. 
+ smoothing (List[str]): Which dimensions to smooth over during the latent trend model. + years (YearRange): The forecasting time series year range. + acause (str): The cause to forecast epsilons for. + decay (float): Rate at which the slope of the line decays once forecasts start. + gbd_round_id (int): The numeric ID for the GBD round. + national_only (bool): Whether to include subnational locations, or to include only + nations. + seed (Optional[int]): An optional seed to set for numpy's random number generation + + Returns: + xr.DataArray: Epsilon predictions for past and future years. + """ + logger.debug( + "Sampling epsilon_hat from latent trend with cross sections", + bindings=dict(smoothing_dims=smoothing), + ) + if acause == CauseConstants.ALL_ACAUSE: + # For all-cause, remove drift first, then run random walk on the remainder. + logger.debug("Computing drift term for all cause") + epsilon_past_no_drift = epsilon_past.sel(location_id=NO_DRIFT_LOCATIONS) + full_loc_l = epsilon_past.location_id.values.tolist() + with_drift_l = [x for x in full_loc_l if x not in NO_DRIFT_LOCATIONS] + epsilon_past_with_drift = epsilon_past.sel(location_id=with_drift_l) + drift_component_with_drift = get_decayed_drift_preds( + epsilon_da=epsilon_past_with_drift, + years=years, + decay=decay, + ) + # Remove drift for selected locations + drift_component_no_drift = epsilon_past_no_drift * 0 + drift_component = xr.concat( + [drift_component_with_drift, drift_component_no_drift], dim="location_id" + ) + drift_component = drift_component.fillna(0) + remainder = epsilon_past - drift_component.sel(year_id=years.past_years) + dataset = xr.Dataset(dict(y=remainder.copy())) + + else: + # If not all-cause, directly model epsilons with random walk. + logger.debug("Computing drift term for all cause", bindings=dict(acause=acause)) + dataset = xr.Dataset(dict(y=epsilon_past.copy())) + drift_component = xr.DataArray(0) + + location_set = get_location_set( + gbd_round_id=gbd_round_id, + include_aggregates=False, + national_only=national_only, + )[ + [ + DimensionConstants.LOCATION_ID, + DimensionConstants.REGION_ID, + DimensionConstants.SUPER_REGION_ID, + ] + ] + dataset.update(location_set.set_index(DimensionConstants.LOCATION_ID).to_xarray()) + + if acause == CauseConstants.ALL_ACAUSE: + logger.debug( + "All cause y^* has drift component", + bindings=dict(latent_trend_model=RandomWalk.__name__), + ) + model_obj = RandomWalk( + dataset=dataset, + years=years, + draws=draws, + seed=seed, + ) + model_obj.fit() + predictions = model_obj.predict() + epsilon_hat = drift_component + predictions + else: + # If not all-cause, use a pooled random walk whose location pooling dimension is based + # on cause level. + logger.debug( + "{acause} y^* does NOT have drift component", + bindings=dict(latent_trend_model=PooledRandomWalk.__name__, acause=acause), + ) + model_obj = PooledRandomWalk( + dataset=dataset, + years=years, + draws=draws, + dims=smoothing, + seed=seed, + ) + model_obj.fit() + epsilon_hat = model_obj.predict() + + return epsilon_hat + + +def _get_y_past(acause: str, years: YearRange, past_version: VersionMetadata) -> xr.DataArray: + """Gets the raw data for past years. + + Past data is saved in normal rate space. The past data is returned in log rate space. + + Args: + acause (str): Short name of the target acause to aggregate to. + years (YearRange): Forecasting time series year range. + gbd_round_id (int): The numeric ID for the GBD round. 
+ past_version (VersionMetadata): The version containing predictions for past years. + + Returns: + xr.DataArray: The expected value of the cause specific mortality or yld rate. + """ + y_hat_past_file = FHSFileSpec(past_version, f"{acause}_hat.nc") + + logger.debug( + "Reading in past data", + bindings=dict( + y_hat_past_file=y_hat_past_file, + past_start=years.past_start, + past_end=years.past_end, + ), + ) + + y_past = ( + open_xr_scenario(y_hat_past_file) + .sel(**{DimensionConstants.YEAR_ID: years.past_years}) + .mean(DimensionConstants.DRAW) + ) + + return y_past + + +def determine_scenario_to_select(version_metadata: VersionMetadata) -> int: + """Return either the ``version_metadata`` scenario or the reference scenario.""" + scenario_to_sel = ( + version_metadata.scenario + if version_metadata.scenario is not None + else ScenarioConstants.REFERENCE_SCENARIO_COORD + ) + return scenario_to_sel diff --git a/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/GKModel.py b/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/GKModel.py new file mode 100644 index 0000000..764a488 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/GKModel.py @@ -0,0 +1,455 @@ +"""This module contains an implementation of the Girosi-King Model, i.e., the "GK" Model. +""" + +import site +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import xarray as xr +from fhs_lib_database_interface.lib.constants import DimensionConstants +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_lib_model.lib.constants import GKModelConstants, ModelConstants, PyMBConstants +from fhs_lib_model.lib.gk_model.model_parameters import ModelParameterEffects +from fhs_lib_model.lib.gk_model.omega import omega_prior_vectors, omega_translate +from fhs_lib_model.lib.gk_model.post_process import ( + np_to_xr, + region_to_location, + transform_parameter_array, +) +from fhs_lib_model.lib.gk_model.pre_process import data_var_merge +from fhs_lib_model.lib.pymb.pymb import PyMB + +logger = get_logger() + + +class ConvergenceError(Exception): + """Error alerting non convergence of models.""" + + +class GKModel(PyMB): + """This class is an implementation of the Girosi-King Model, i.e., the "GK" Model. + + The ``GKModel`` class is a subclass of the ``PyMB`` model class, meaning that it relies + heavily on the Template Model Builder (TMB) R package. The ``GKModel``, allows for ease of + data preparation as long as the expected variables are included in a DataFrame which is + passed in at the initialization step of the class. + """ + + def __init__( + self, + dataset: Union[xr.DataArray, xr.Dataset], + years: YearRange, + fixed_effects: Dict[str, List[Tuple[str, float]]], + random_effects: Dict[str, List[str]], + draws: int, + constants: Optional[List[str]] = None, + y: Optional[str] = None, + omega_amp: Union[float, Dict[str, float]] = 0, + weight_decay: float = 0, + seed: Optional[int] = None, + ) -> None: + """Initializer for the GKModel. + + Args: + dataset: (xr.Dataset | xr.DataArray): Data to forecast. If ``xr.Dataset`` then + ``y`` paramter must be specified as the variable to forecast. + years (YearRange): The forecasting time series. + fixed_effects (Dict[str, List[Tuple[str, float]]]): The fixed effects of the model. + Maps the parameter names to their associated covariates, and the restrictions + of those covariates. 
+ e.g., + .. code:: python + { + "beta_age": [("vehicles_2_plus_4wheels_pc", 0), ("sdi", 0)], + "beta_global": [("hiv", 1), ("sdi", 0), ("intercept", 0)], + } + random_effects (Dict[str, List[str]]): The random effects of the model. Maps the + parameter names to their associated covariates. e.g., + .. code:: python + { + "gamma_location_age": ["intercept"], + "gamma_age": ["time_var"], + } + draws (int): Number of draws to make for predictions. + constants (Optional[List[str]]): The names of the constants to be added to the + model. Defaults to ``None``, i.e., no constants. + y (str): + The name of the dependent, i.e., response, variable. **NOTE:** this should be a + data variable on ``dataset``, if dataset is an ``xr.Dataset``. Defaults to "y". + omega_amp (Union[float, Dict[str, float]]): + How much to amplify the omega priors by. Defaults to ``0``. + weight_decay (float): + How much to decay the weight of the likelihood on predictions as they get + farther away from the last year in sample. Defaults to ``0``. + seed (Optional[int]): + an optional seed for the C++ and numpy random number generators + + Raises: + ValueError: If invalid arguments are given for any of the following reasons: + + * If ``y`` not a variable on ``dataset``, i.e., + ``list(dataset.data_vars.keys())``. + * If the fixed or random effects are improperly named. + * If any constraints are invalid. + """ + super().__init__(name=self.__class__.__name__) + + self._init_model() + self._init_dataset(dataset, y) + + self.years = years + self.num_of_draws = draws + + self.random_effect_names = [ + g for g in random_effects.keys() if len(random_effects[g]) > 0 + ] + + self.constants = constants if constants is not None else [] + param_effects = ModelParameterEffects(fixed_effects, random_effects, constants) + + # Create mapping of data (dependent and independent) variables to their values for this + # model. + self._init_data_map(param_effects, weight_decay, seed) + + # Initialize omega priors. + self._init_omega_priors(omega_amp) + + # Separate the data elements with draws into there own collection, ``data_draw``. + self._init_draw_data() + + # Create mapping of initial parameter values for this model. + self._init_param_map(param_effects) + + # Create an instance variable that will be the xarray dataset containing the optimized + # model parameter coefficients -- after the model-fitting process. + self.coefficients: Optional[xr.Dataset] = None + + # Create an instance variable that will be the xarray dataarray containing the model + # predictions for the all years from the first forecast year to the last forecast year. + self.predictions: Optional[xr.DataArray] = None + + # Set the seed for PyMB's numpy use as well. + self.rng = np.random.default_rng(seed=seed) + + def fit(self) -> xr.Dataset: + """Optimize and fit the coefficients of the model parameters. + + Produces a dataset containing the coefficients for the fixed and random effects fit in + the model. There is one variable per effect type (e.g. ``gamma_location_age``), and + a covariate dimension specifying which covariate is associated with which coefficients. + All covariate/effect-type pairs that were not fit (e.g., ``gamma_age`` and ``sdi``, if + ``sdi`` was included as a global fixed effect) contain ``NaN`` values. + + Returns: + xr.Dataset: A dataset containing the coefficients for the optimized model. + + Raises: + ConvergenceError: If the model cannot converge, i.e., the convergence value is + nonzero. 
+ """ + self._optimize() + + # Convert fit parameters coefficients into xarray format use to keep track of parameter + # order. + demog_coords = dict( + location_id=self.data[GKModelConstants.LOCATION_PARAM], + year_id=self.data[GKModelConstants.YEAR_PARAM], + age_group_id=self.data[GKModelConstants.AGE_PARAM], + ) + + # Get the names of the fixed effects (e.g. ``"beta_global_raw"``). + fixed_effects = [ + f + "_raw" for f in self.fixed_names if f + "_raw" in list(self.parameters.keys()) + ] + + # Get the names of the random effects. + random_effects = [r for r in self.random_names if r in list(self.parameters.keys())] + + # Get the draws of the fixed effects and the random effect parameters. + # **NOTE:** This does does not involve covariate values. + fixed_effect_arrays = [ + transform_parameter_array(self.draws(f), self.data[f.replace("raw", "constraint")]) + for f in fixed_effects + ] + random_effect_arrays = [self.draws(k) for k in random_effects] + + # Get names of fixed effect covariates (e.g., the ``"intercept"`` and ``"sdi"`` + # covariates for the ``"beta_global" parameter``). + fixed_cov_names = [self.data[k.replace("_raw", "")] for k in fixed_effects] + # Get the names for the random effect covariate names in the same way. + random_cov_names = [self.data[k] for k in random_effects] + + # If region or super region are part of the effects structure, then copy it out so the + # param exists for every location. + fixed_effect_arrays = [ + ( + region_to_location( + fixed_effect_arrays[i], self.data[GKModelConstants.REGION_PARAM] + ) + if f"beta_{GKModelConstants.REGION_PARAM}" in fixed_effects[i] + else fixed_effect_arrays[i] + ) + for i in range(len(fixed_effect_arrays)) + ] + fixed_effect_arrays = [ + ( + region_to_location( + fixed_effect_arrays[i], self.data[GKModelConstants.SUPER_REGION_PARAM] + ) + if GKModelConstants.SUPER_REGION_PARAM in fixed_effects[i] + else fixed_effect_arrays[i] + ) + for i in range(len(fixed_effect_arrays)) + ] + random_effect_arrays = [ + ( + region_to_location( + random_effect_arrays[i], self.data[GKModelConstants.REGION_PARAM] + ) + if f"gamma_{GKModelConstants.REGION_PARAM}" in random_effects[i] + else random_effect_arrays[i] + ) + for i in range(len(random_effect_arrays)) + ] + random_effect_arrays = [ + ( + region_to_location( + random_effect_arrays[i], self.data[GKModelConstants.SUPER_REGION_PARAM] + ) + if GKModelConstants.SUPER_REGION_PARAM in random_effects[i] + else random_effect_arrays[i] + ) + for i in range(len(random_effect_arrays)) + ] + + # Convert the coefficient arrays to xarray. + fixed_effect_arrays = [ + np_to_xr(r, demog_coords, fixed_cov_names[i]) + for i, r in enumerate(fixed_effect_arrays) + ] + random_effect_arrays = [ + np_to_xr(r, demog_coords, random_cov_names[i]) + for i, r in enumerate(random_effect_arrays) + ] + + # Put the random and fixed effect parameter coefficients into one xarray.Dataset. + additive_params = dict() + for i, k in enumerate(fixed_effects): + additive_params[k.replace("_raw", "")] = fixed_effect_arrays[i] + + for i, k in enumerate(random_effects): + additive_params[k] = random_effect_arrays[i] + self.coefficients = xr.Dataset(additive_params) + + return self.coefficients + + def predict(self) -> xr.DataArray: + """Generate predictions for future years from the optimized model fit. + + Generate predictions from the year ``years.past_start`` up through the year + ``years.forecast_end`` for input covariates using the fit_params. 
Variables with a + coefficient defined to be 1 (like scalars), specified in constant_vars, are added on. + Currently there is no intercept_shift option + + Returns: + xr.DataArray: data array containing the predictions from the input covariates and + fit parameters. + """ + # Convert fit_params to an array in order to sum more easily + fit_params = self.coefficients.to_array().fillna(0) + + # loop through covariates and add on their contribution to the total + contributions = [ + (self.dataset[cov] * fit_params.sel(cov=cov, drop=True)).sum("variable") + for cov in fit_params.cov.values + ] + pred_ds = sum(contributions) + + # Add on the data from the constant variables. + for var in self.constants: + pred_ds = pred_ds + self.dataset[var] + + self.predictions = pred_ds.sel(year_id=self.years.years) + return self.predictions + + def _init_model(self) -> None: + """Configure model and create the DLL for C++ extensions.""" + model: Optional[str] = None + for package_dir in site.getsitepackages(): + so_file = Path(package_dir) / f"{self.name}.so" + if so_file.is_file(): + model = str(so_file) + if model is None: + err_msg = f"No shared object library file found. Expecting ``{self.name}.so``" + logger.error(err_msg) + raise EnvironmentError(err_msg) + self.load_model(model) + + def _init_dataset( + self, dataset: Union[xr.DataArray, xr.Dataset], y: Optional[str] + ) -> None: + """Initialize, validate, and prepare input dataset.""" + if isinstance(dataset, xr.DataArray): + logger.debug("DataArray given, converting to Dataset") + dataset = xr.Dataset({y: dataset}) + + self.y = y if y is not None else ModelConstants.DEFAULT_DEPENDENT_VAR + if self.y not in list(dataset.data_vars.keys()): + err_msg = "``y`` was not found in the input dataset" + logger.error(err_msg) + raise ValueError(err_msg) + + self.dataset = dataset + + def _init_omega_priors(self, omega_amp: Union[Dict[str, float]]) -> None: + """Initialize Bayesian Omega Priors.""" + if omega_amp != 0: + omega_amp_map = omega_translate(omega_amp) + self.data.update( + omega_prior_vectors( + self.data[f"y_{GKModelConstants.OmegaLevels.U}"], + self.data[GKModelConstants.HOLDOUT_START_PARAM], + level=GKModelConstants.OmegaLevels.U, + omega_amp_map=omega_amp_map, + ) + ) + self.data.update( + omega_prior_vectors( + self.data[f"y_{GKModelConstants.OmegaLevels.T}"], + self.data[GKModelConstants.HOLDOUT_START_PARAM], + level=GKModelConstants.OmegaLevels.T, + omega_amp_map=omega_amp_map, + ) + ) + else: + self.data.update({k: 0 for k in GKModelConstants.OMEGA_PRIORS}) + + def _init_data_map( + self, + param_effects: ModelParameterEffects, + weight_decay: float, + seed: Optional[int], + ) -> None: + """Create mapping of data variables to their values for this model. + + Args: + param_effects: fixed and random effects to be added to the data. + weight_decay: weight decay value to be added to the data. 
+ seed: A seed for the GK Random number generator + """ + self.data = dict() + self.data["mean_adjust"] = int(0) + self.data["testing_prints"] = int(0) + self.data[GKModelConstants.WEIGHT_DECAY_PARAM] = weight_decay + self.data[GKModelConstants.LOCATION_PARAM] = self.dataset.location_id.values + self.data[GKModelConstants.AGE_PARAM] = self.dataset.age_group_id.values + self.data[GKModelConstants.YEAR_PARAM] = self.dataset.year_id.values + self.data[GKModelConstants.HOLDOUT_START_PARAM] = np.where( + self.data[GKModelConstants.YEAR_PARAM] == self.years.past_end + )[0][0] + self.data["covariates2"] = [] + self.data["beta2_constraint"] = np.array([]) + + final_array = data_var_merge(self.dataset) + final_array_col = data_var_merge(self.dataset, collapse=True) + param_effects.extract_param_data(self.dataset, final_array, final_array_col) + self.data["X2"] = np.zeros((0, 0, 0, 0)) + self.data["X2_draw"] = np.zeros((0, 0, 0, 0)) + self.data.update(param_effects.data) + self.fixed_names = param_effects.fixed_effect_names + self.random_names = param_effects.random_effect_names + self.data[f"y_{GKModelConstants.OmegaLevels.U}"] = final_array_col.loc[ + dict(cov=self.y) + ].values + self.data[f"y_{GKModelConstants.OmegaLevels.T}"] = final_array_col.loc[ + dict(cov=self.y) + ].values + self.data[f"y_{GKModelConstants.OmegaLevels.T}_draws"] = final_array.loc[ + dict(cov=self.y) + ] + # Get region mappings. + self.data[GKModelConstants.REGION_PARAM] = np.repeat( + 0, len(self.data[GKModelConstants.LOCATION_PARAM]) + ) + self.data[GKModelConstants.SUPER_REGION_PARAM] = np.repeat( + 0, len(self.data[GKModelConstants.LOCATION_PARAM]) + ) + if DimensionConstants.REGION_ID in list(self.dataset.data_vars.keys()): + self.data[GKModelConstants.REGION_PARAM] = pd.factorize( + self.dataset.region_id.values + )[0] + if DimensionConstants.SUPER_REGION_ID in list(self.dataset.data_vars.keys()): + self.data[GKModelConstants.SUPER_REGION_PARAM] = pd.factorize( + self.dataset.super_region_id.values + )[0] + + self.data["has_risks"] = np.zeros_like(self.data["age"], dtype=int) + + if seed is None: + self.data["set_seed"] = 0 + self.data["seed"] = 0 + else: + self.data["set_seed"] = 1 + self.data["seed"] = seed + + def _init_draw_data(self) -> None: + """Separate the data elements that have draws into their own collection.""" + self.data_draw = dict() + for k in list(self.data.keys()): + if k.endswith("draw") or k.endswith("draws"): + self.data_draw[k] = self.data.pop(k) + + def _init_param_map(self, param_effects: ModelParameterEffects) -> None: + """Create mapping of initial parameter values for this model.""" + self.init = dict() + self.init.update( + { + "log_age_sigma_pi": np.array([]), + "log_location_sigma_pi": np.array([]), + "logit_rho": np.array([]), + "pi": np.zeros((0, 0, 0)), + } + ) + self.init.update(param_effects.init) + _, _, _, k2 = self.data[ + "X2" + ].shape + r = np.sum(self.data["has_risks"]) + self.init[f"log_sigma_{GKModelConstants.OmegaLevels.U}"] = np.zeros(0 if r == 0 else 1) + self.init[f"log_sigma_{GKModelConstants.OmegaLevels.T}"] = 0 + self.init[f"log_zeta_{GKModelConstants.OmegaLevels.D}"] = np.zeros(r) + self.init["beta2_raw"] = np.zeros((1, r, 1, k2)) + + def _optimize( + self, + opt_fun: str = PyMBConstants.DEFAULT_OPT_FUNC, + method: str = PyMBConstants.DEFAULT_OPT_METHOD, + ) -> None: + """Optimize the model and store results in ``TMB_Model.TMB.fit``. + + Optimize the model and store results in ``TMB_Model.TMB.fit`` using the PyMP + ``optimize``. 
+ + Args: + opt_fun: (str): the R optimization function to use (e.g. ``'nlminb'`` or + ``'optim'``). Defaults to ``'nlminb'`` + method (str): Method to use for optimization. Defaults to ``'L-BGFS-B'``. + + Raises: + ConvergenceError: If the model cannot converge, i.e., the convergence value is + nonzero. + """ + super().optimize( + opt_fun=opt_fun, + method=method, + draws=self.num_of_draws, + random_effects=self.random_effect_names, + ) + if self.convergence != 0: + err_msg = "The model could not converge" + logger.error(err_msg, bindings=dict(convergence=self.convergence)) + raise ConvergenceError(err_msg) diff --git a/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/model_parameters.py b/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/model_parameters.py new file mode 100644 index 0000000..e996317 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/model_parameters.py @@ -0,0 +1,274 @@ +"""This module contains a class that encapsulates logic/info related GK model parameters. +""" + +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import xarray as xr +from fhs_lib_database_interface.lib.constants import DimensionConstants +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_lib_model.lib.constants import GKModelConstants + +logger = get_logger() + + +class ModelParameterEffects: + """Creates the initial random and fixed effects for GK Model.""" + + def __init__( + self, + fixed_effects: Dict[str, List[Tuple[str, float]]], + random_effects: Dict[str, List[str]], + constants: Optional[List[str]] = None, + ) -> None: + """Initializer for fixed and random effects container. + + Args: + fixed_effects (Dict[str, List[Tuple[str, float]]]): The fixed effects of the model. + Maps the parameter names to their associated covariates, and the restrictions + of those covariates. + e.g., + .. code:: python + { + "beta_age": [("vehicles_2_plus_4wheels_pc", 0), ("sdi", 0)], + "beta_global": [("hiv", 1), ("sdi", 0), ("intercept", 0)], + } + random_effects (Dict[str, List[str]]): The random effects of the model. Maps the + parameter names to their associated covariates. e.g., + .. code:: python + { + "gamma_location_age": ["intercept"], + "gamma_age": ["time_var"], + } + constants (Optional[List[str]]): The names of the constants to be added to the + model. Defaults to ``None``, i.e., no constants. + + Raises: + ValueError: If the fixed or random effects are improperly named. 
+ """ + self.fixed_effect_names = [ + "{}_{}".format("beta", t) for t in GKModelConstants.STANDARD_PARAMS + ] + missing_fixed = set(fixed_effects.keys()) - set(self.fixed_effect_names) + if len(missing_fixed): + err_msg = f"Missing fixed effects: {missing_fixed}" + logger.error(err_msg) + raise ValueError(err_msg) + + self.fixed_effects = { + k: (fixed_effects[k] if k in fixed_effects.keys() else []) + for k in self.fixed_effect_names + } + + self.random_effect_names = [ + "{}_{}".format("gamma", t) for t in GKModelConstants.STANDARD_PARAMS + ] + missing_random = set(random_effects.keys()) - set(self.random_effect_names) + if len(missing_random): + err_msg = f"Missing random effects: {missing_random}" + logger.error(err_msg) + raise ValueError(err_msg) + + self.random_effects = { + k: (random_effects[k] if k in list(random_effects.keys()) else []) + for k in self.random_effect_names + } + + self.constants = constants if constants is not None else [] + + self.init = dict() + self.data = dict() + + def extract_param_data( + self, + dataset: xr.Dataset, + final_array: xr.DataArray, + final_array_col: xr.DataArray, + ) -> None: + """Parse out the covariate information from the appropriate dictionaries. + + Args: + dataset (xr.Dataset): The dataset to extract parameter information from. + final_array (xr.DataArray): An array containing same data as in ``dataset``, except + in array form such that each data variable from ``dataset`` exists as a slice + of the array. + final_array_col (xr.DataArray): Only the dims from + ``GKModelConstants.MODELABLE_DIMS`` are expected, and in that order. + + Raises: + ValueError: If any constraints are invalid. + """ + if DimensionConstants.REGION_ID in list(dataset.data_vars.keys()): + region_size = len(np.unique(dataset.region_id.values)) + logger.debug("Nonzero region size", bindings=dict(region_size=region_size)) + else: + region_size = 0 + logger.debug("Region size is zero") + + if DimensionConstants.SUPER_REGION_ID in list(dataset.data_vars.keys()): + super_region_size = len(np.unique(dataset.super_region_id.values)) + logger.debug( + "Nonzero super region size", + bindings=dict(super_region_size=super_region_size), + ) + else: + super_region_size = 0 + logger.debug("Super Region size is zero") + + for cov in self.fixed_effect_names: + cov_type = "_".join(cov.split("_")[1:]) + + beta_const_name = f"beta_{cov_type}_constraint" + beta_raw_name = f"beta_{cov_type}_raw" + beta_mean_name = f"beta_{cov_type}_mean" + + X = f"X_{cov_type}" + X_draw = f"X_{cov_type}_draw" + + self.data[cov], self.data[beta_const_name] = self._parse_covariates( + self.fixed_effects[cov] + ) + self.data[X] = final_array_col.sel(cov=self.data[cov]).values + self.data[X_draw] = final_array.sel(cov=self.data[cov]) + param_shape = self._name_to_shape( + cov_type, + self.data[X].shape, + super_region_size, + region_size, + ) + self.init[beta_raw_name] = np.zeros(param_shape) + self.init[beta_mean_name] = np.ones(param_shape) + + for random_effect_name in self.random_effect_names: + random_effect_type = "_".join(random_effect_name.split("_")[1:]) + + self.data[random_effect_name], _ = self._parse_covariates( + self.random_effects[random_effect_name] + ) + + gamma_name = f"gamma_{random_effect_type}" + Z = f"Z_{random_effect_type}" + Z_draw = f"Z_{random_effect_type}_draw" + tau_name = f"log_tau_{random_effect_type}" + + self.data[Z] = final_array_col.sel(cov=self.data[random_effect_name]).values + self.data[Z_draw] = final_array.sel(cov=self.data[random_effect_name]) + + param_shape = 
self._name_to_shape( + random_effect_type, + self.data[Z].shape, + super_region_size, + region_size, + ) + self.init[gamma_name] = np.zeros(param_shape) + self.init[tau_name] = np.zeros( + (1, 1, 1, self.data[Z].shape[-1]) + ) + + self.data["constant"] = final_array_col.sel(cov=self.constants).values + self.data["constant_draw"] = final_array.sel(cov=self.constants) + self.data["constant_mult"] = np.ones( + (1, 1, 1, self.data["constant"].shape[3]) + ) + + @staticmethod + def _name_to_shape( + covariate: str, + dim_sizes: List[int], + super_region_size: int, + region_size: int, + ) -> Tuple[int, int, int, int]: + """Converts a covariate name to a proper shape. + + Args: + covariate (str): The name of the covariate. + dim_sizes (List[int, int, int, int]): Contains the generic size of each model-able + dimension, in the order location, age, time, and covariate. + super_region_size (int): + The size of the super region dimension. + region_size (int): + The size of the region dimension. + + Returns: + Tuple[int, int, int, int]: The dimension sizes expected for the given covariate. + """ + location_size, age_size, _, covariate_size = dim_sizes + + expected_time_size: int = 1 + expected_covariate_size: int = covariate_size + + expected_age_size: int + if GKModelConstants.AGE_PARAM in covariate: + expected_age_size = age_size + else: + expected_age_size = 1 + + expected_location_size: int + if GKModelConstants.SUPER_REGION_PARAM in covariate: + expected_location_size = super_region_size + elif GKModelConstants.REGION_PARAM in covariate: + expected_location_size = region_size + elif GKModelConstants.LOCATION_PARAM in covariate: + expected_location_size = location_size + else: + expected_location_size = 1 + + return ( + expected_location_size, + expected_age_size, + expected_time_size, + expected_covariate_size, + ) + + @staticmethod + def _parse_covariates( + covariates: Union[List[str], List[Tuple[str, float]]], + ) -> Tuple[List[str], np.ndarray]: + """Parse covariates from a list of tuples or strings. + + Args: + covariates (Union[List[str], List[Tuple[str, float]]]): List of covariate names + corresponding to model parameter. If for a fixed-effect parameter, then a list + of tuples is given where each tuple has a covariate name and the restriction, + i.e., constraint, associated with that covariate. + + Returns: + Tuple[List[int], np.array]: The list of covariate names and a 1D array of with + constraints. + + Raises: + RuntimeError: If any constraints are invalid. + """ + # Parameter has no covariates. + if len(covariates) == 0: + return [], np.array([]) + # Parameter is a **random** effect. + elif isinstance(covariates[0], str): + covariate_names = [c for c in covariates] + constraints = [0] * len(covariates) # Random effects don't have constraints. + # Parameter is a **fixed** effect. + else: + # Separate out the covariate names from their constraints. + covariate_names = [x for c in covariates for x in c if isinstance(x, str)] + constraints = [x for c in covariates for x in c if not isinstance(x, str)] + + ModelParameterEffects._check_constraints(constraints) + return covariate_names, np.array(constraints) + + @staticmethod + def _check_constraints(constraints: List[int]) -> None: + """Make sure that all constraints are either -1, 0, or 1. + + Args: + constraints (List[int]): List of integers to use as GK TMB constraints. + + Raises: + ValueError: If any constraints are invalid. 
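+
+        Example:
+            ``[1, 0, -1]`` passes validation, while something like ``[2, 0]`` raises a
+            ``ValueError`` (a purely illustrative case; only -1, 0, and 1 are valid).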
+        """
+        bad_constraints = set(constraints) - set(GKModelConstants.VALID_CONSTRAINTS)
+
+        if len(bad_constraints):
+            err_msg = "Constraints must be either 1, -1, or 0."
+            logger.error(err_msg)
+            raise ValueError(err_msg)
diff --git a/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/omega.py b/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/omega.py
new file mode 100644
index 0000000..1468ab7
--- /dev/null
+++ b/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/omega.py
@@ -0,0 +1,159 @@
+"""This module contains utilities for preparing the Bayesian Omega Priors of the GK Model.
+"""
+
+from typing import Dict, Union
+
+import numpy as np
+from scipy.optimize import brentq
+from scipy.special import gammaln
+from tiny_structured_logger.lib.fhs_logging import get_logger
+
+from fhs_lib_model.lib.constants import GKModelConstants
+
+logger = get_logger()
+
+
+def omega_translate(omega_amp: Union[float, Dict[str, float]]) -> Dict[str, float]:
+    """Creates a mapping of all omega parameters to their amplification values.
+
+    Creates a dictionary where the keys are the names of all the omega parameters and the
+    values are the amplification factors used to scale each omega parameter. If a single
+    ``float`` value is given then all omegas will have that value. Missing omega
+    parameters will be assigned an amplification value of ``1`` by default.
+
+    Notes:
+        * The omega parameters are defined in ``GKModelConstants.OMEGA_PARAMS``.
+
+    Args:
+        omega_amp (float | Dict[str, float]): original input for ``omega_amp`` in the
+            ``GKModel`` class.
+
+    Returns:
+        Dict[str, float]: Dictionary of omegas mapped to their corresponding amplification
+            values.
+    """
+    if isinstance(omega_amp, dict):
+        logger.debug(
+            "Some omegas have specific amplification values",
+            bindings=dict(omegas=list(omega_amp.keys())),
+        )
+        omega_amp_map = {o: omega_amp.get(o, 1.0) for o in GKModelConstants.OMEGA_PARAMS}
+    else:
+        logger.debug("Setting same single omega amplification value for all omegas")
+        omega_amp_map = {o: omega_amp for o in GKModelConstants.OMEGA_PARAMS}
+
+    # Note that we disable a pytype "bad return type" error because it incorrectly thinks
+    # we're returning a dictionary containing dictionaries
+    return omega_amp_map  # pytype: disable=bad-return-type
+
+
+def omega_prior_vectors(
+    response_array: np.ndarray,
+    holdout_index: int,
+    level: str,
+    omega_amp_map: Dict[str, float],
+) -> Dict[str, float]:
+    """Returns the smoothing parameter for each omega prior.
+
+    Computes the omega prior vectors for a response variable, as per the GK model process,
+    and derives the optimal smoothing parameter from each.
+
+    Notes:
+        * If there are fewer than 3 ages for the smoothing param, then the value for that
+            omega will be 0.
+
+    Args:
+        response_array (np.ndarray): 3D response variable indexed on location, age, time
+            (i.e., time in years).
+        holdout_index (int): The index at which the holdouts start.
+        level (str): The level name to apply to the end of the dict key.
+        omega_amp_map (Dict[str, float]): Maps each omega prior to its corresponding
+            amplification, i.e., how much to scale the omega priors.
+
+    Returns:
+        Dict[str, float]: Dictionary mapping each omega prior to its optimal smoothing value.
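+
+    Example:
+        The returned keys follow the pattern ``omega_loc_<level>``, ``omega_age_<level>``,
+        ``omega_loc_time_<level>``, and ``omega_age_time_<level>``; the values are the
+        smoothing parameters derived from the corresponding differenced arrays.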
+ """ + location_index, age_index, year_index = response_array.shape + + response_array = np.copy(response_array) + response_array[:, :, holdout_index:] = np.nan + response_array_adj = response_array - np.nanmean(response_array, axis=2).reshape( + (location_index, age_index, 1) + ) + response_array_adj = response_array_adj[:, :, :holdout_index] + + omega_vectors = dict() + + omega_vectors[f"omega_loc_{level}"] = ( + np.sqrt(np.abs(np.diff(response_array_adj, axis=0))) * omega_amp_map["omega_location"] + ) + + omega_vectors[f"omega_age_{level}"] = ( + np.sqrt(np.abs(np.diff(response_array_adj, axis=1))) * omega_amp_map["omega_age"] + ) + + omega_vectors[f"omega_loc_time_{level}"] = ( + np.sqrt(np.abs(np.diff(np.diff(response_array_adj, axis=2), axis=0))) + * omega_amp_map["omega_location_time"] + ) + + omega_vectors[f"omega_age_time_{level}"] = ( + np.sqrt(np.abs(np.diff(np.diff(response_array_adj, axis=2), axis=1))) + * omega_amp_map["omega_age_time"] + ) + + # If there are less than 3 ages for the smoothing param, then the value will be 0. + smoothing_params = { + k: _smoothing_param(v) if (age_index > 2 or "age" not in k) else 0 + for k, v in omega_vectors.items() + } + return smoothing_params + + +def _smoothing_param(array: np.ndarray) -> float: + """Function to define a GK smoothing parameter given its mean and standard deviation. + + This function finds the smoothing param that optimizes the 1/rate value of the Gamma + distribution in order to find an optimal smoothing parameter for the GK model. + + A reference for this code can be found here: + https://github.com/IQSS/YourCast/blob/955a88043fa97d71b922585b5bcd28cc5738d75c/R/compute.sigma.R#L194-L270 + + Args: + array (np.ndarray): The one dimensional array like object to compute smoothing + parameter over. + + Returns: + float: The appropriate smoothing value given an array. + """ + ma = np.ma.array(array, mask=np.isnan(array)) + weights = np.ones_like(array).cumsum(axis=2) ** 2 + m = np.ma.average(ma, weights=weights) + std = np.sqrt(np.ma.average((ma - m) ** 2, weights=weights)) + v = std**2 + # not sure why this is calculated but its in the reference code as well + # _ = (m**2.) / (v + m**2.) + d = brentq(_optimize_rate, 2.0000001, 1000, args=(m, v)) + + e = (d - 2) * (v + m**2) + return d / e + + +def _optimize_rate(x: float, m: float, v: float) -> float: + """Find the optimization rate corresponding to the observations. + + This function finds the optimization rate corresponding to the reciprocal-rate, + weighted-mean, and variance of the observations. + + Args: + x: The **positive** value of 1/rate to test. + m: The weighted average of an array. + v: The variance. + + Returns: + float: The optimization rate + """ + opr = np.sqrt((x / 2) - 1) * np.exp(gammaln((x / 2) - 0.5) - gammaln(x / 2)) - ( + m / np.sqrt(m**2 + v) + ) + return opr diff --git a/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/post_process.py b/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/post_process.py new file mode 100644 index 0000000..f0163a4 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/post_process.py @@ -0,0 +1,122 @@ +"""This module contains utilities for processing the model output into a more usable form. 
+""" + +from typing import Any, Dict, List, Union + +import numpy as np +import xarray as xr +from fhs_lib_database_interface.lib.constants import DimensionConstants +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_lib_model.lib.constants import GKModelConstants + +logger = get_logger() + + +def region_to_location(array: np.ndarray, region_map: List[int]) -> np.ndarray: + """Converts a region axis of an array into an axis of detailed locations. + + Given an array of coefficient values and regions transforms the array from + region-age-time-covariate-draw to location-age-time-covariate-draw. + + Args: + array (np.ndarray): 5D array with first axis to be converted from regions to detailed + locations. + region_map (List[int]): A list of region indices -- one index for each location on the + new location axis. + + Returns: + np.ndarray: 5d array with axis for locations instead of for regions. + """ + _, age_dim_size, time_dim_size, covariate_dim_size, draw_dim_size = array.shape + location_dim_size = len(region_map) + + # Initialize an empty array with the desired shape. + location_array = np.zeros( + ( + location_dim_size, + age_dim_size, + time_dim_size, + covariate_dim_size, + draw_dim_size, + ) + ) + + # For each location, assign the coefficient to it that corresponds to the region it is in. + for i in range(len(region_map)): + location_array[i, :, :, :, :] = array[region_map[i], :, :, :, :] + + return location_array + + +def transform_parameter_array(array: np.ndarray, constraints: np.ndarray) -> np.ndarray: + """Transforms an array of coefficients based on the corresponding parameter's constraints. + + Applies the associated constraints to the coefficients of each covariate for the given + parameter. + + Args: + array (np.ndarray): The mutli-dimensional Numpy array of parameter coefficients to + transform. Must have the following 5 dimensions in that order: location, age, time + (i.e., years), covariate, and draw. + constraints (np.array): Array of integers with value 0, -1, or 1 corresponding to the + constraint of each covariate. + + Returns: + np.ndarray: The 5-dimensional array of coefficients that have been transformed + according to the given constraints. + """ + arr_trans = np.copy(array) + for i in range(len(constraints)): + if constraints[i] == -1: + arr_trans[:, :, :, i, :] = -1 * np.exp(arr_trans[:, :, :, i, :]) + if constraints[i] == 1: + arr_trans[:, :, :, i, :] = np.exp(arr_trans[:, :, :, i, :]) + + return arr_trans + + +def np_to_xr( + array: np.ndarray, + demog_coords: Dict[str, Union[int, str]], + covariates: List[str], +) -> xr.DataArray: + """Convert a numpy array from TMB to an xarray for use in forecasting. + + Any single coordinate dimensions will be removed. + + Args: + array (np.ndarray): The Numpy multi-dimension array to convert into an Xarray + DataArray. + demog_coords (Dict[str, Union[int, str]]): Mapping of demographic dimensions to their + corresponding coordinates. + covariates: coordinates for covariate dim. + + Returns: + xr.DataArray: The model output that has been converted to Xarray DataArray form. 
+ """ + all_coords: Dict[str, Any] = demog_coords.copy() + all_coords[GKModelConstants.COVARIATE_DIM] = covariates + + drop_axes = [x for x in range(3) if array.shape[x] == 1] + drop_dims = [GKModelConstants.OUTPUT_DIMS[x] for x in range(3) if array.shape[x] == 1] + dims = [ + GKModelConstants.OUTPUT_DIMS[i] + for i in range(len(GKModelConstants.OUTPUT_DIMS)) + if i not in drop_axes + ] + + for dim in drop_dims: + all_coords.pop(dim) + + all_coords[DimensionConstants.DRAW] = np.arange(array.shape[-1]) + + squeezed_array = array.copy() + if len(drop_axes): + logger.debug( + "Dropping dimensions from output", + bindings=dict(drop_dims=drop_dims, drop_axes=drop_axes), + ) + squeezed_array = squeezed_array.mean(tuple(drop_axes)) + + return xr.DataArray(squeezed_array, dims=dims, coords=all_coords) diff --git a/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/pre_process.py b/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/pre_process.py new file mode 100644 index 0000000..6850115 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/models/gk-model/pre_process.py @@ -0,0 +1,97 @@ +"""This module contains utilities for preparing the model input. +""" + +import gc + +import xarray as xr +from fhs_lib_database_interface.lib.constants import DimensionConstants, ScenarioConstants +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_lib_model.lib.constants import GKModelConstants + +logger = get_logger() + + +def data_var_merge(dataset: xr.Dataset, collapse: bool = False) -> xr.DataArray: + """Merges the data variables of the given dataset into one dataarray. + + Adds dims to each data variable of the dataset such that they are all consistent in + shape, i.e., dimensions and coordinates. Also adds a new single-coordinate dim + ``"cov"``, i.e., "covariate" to each variable, where the coordinate is the name of the + variable. Then all of the data variables are concatenated over this new dimension, such + that each one corresponds to a coordinate on the ``"cov"`` dimension. + + Args: + dataset (xr.Dataset): The dataset to convert into a dataarray. + collapse (bool): Whether to collapse the dataarray into a model-able array. + + Returns: + xr.DataArray: The dataarray containing each data variable of the original dataset + as a dimension. + """ + final_das = [] + exclude_dims = [DimensionConstants.REGION_ID, DimensionConstants.SUPER_REGION_ID] + col_vars = [v for v in dataset.data_vars.keys() if v not in exclude_dims] + + for var in col_vars: + logger.debug("Adding dims to data var", bindings=dict(var=var)) + + ex_da = dataset[var].copy() + das_to_expand_by = list(dataset.data_vars.keys()) + das_to_expand_by.remove(var) + for oda_key in das_to_expand_by: + ex_da, _ = xr.broadcast(ex_da, dataset[oda_key].copy()) + + ex_da = ex_da.expand_dims({GKModelConstants.COVARIATE_DIM: [ex_da.name]}) + + final_das.append(ex_da) + + del ex_da + gc.collect() + + final_array: xr.DataArray = xr.concat(final_das, dim="cov") + del final_das + gc.collect() + + if collapse: + logger.debug("Collapsing dataarray into model-able array") + final_array = _modable_array(final_array) + + return final_array + + +def _modable_array(dataarray: xr.DataArray) -> xr.DataArray: + """Collapse an array to its model-able dimensions. 
+
+    Collapse an array to its model-able dimensions by collapsing the draw dimension (i.e.,
+    taking the mean of the draws) and any other dimensions that aren't in the designated
+    model-able dimensions, defined in ``GKModelConstants.MODELABLE_DIMS``. If the scenario
+    dimension exists, then only use the reference/default scenario.
+
+    Args:
+        dataarray (xr.DataArray): The array to collapse into a "model-able" array as
+            required by the GK Modeling interface.
+
+    Returns:
+        xr.DataArray: The collapsed model-able array.
+    """
+    # Only keep the reference scenario if scenario is a dimension.
+    if DimensionConstants.SCENARIO in dataarray.dims:
+        scenario_slice = {
+            DimensionConstants.SCENARIO: ScenarioConstants.REFERENCE_SCENARIO_COORD
+        }
+        logger.debug("Slicing to reference scenario", bindings=scenario_slice)
+        dataarray = dataarray.sel(**scenario_slice)
+
+    # Take the mean across all coordinates of any dimension that isn't a designated
+    # model-able dimension.
+    for dim in dataarray.dims:
+        if dim not in GKModelConstants.MODELABLE_DIMS:
+            logger.debug("Collapsing non-modelable dim into mean", bindings=dict(dim=dim))
+            dataarray = dataarray.mean(dim)
+
+    # Reorder the dimensions of the collapsed array so they conform to the order as they
+    # appear in the list of designated model-able dimensions.
+    reorder = [d for d in GKModelConstants.MODELABLE_DIMS if d in dataarray.dims]
+    dataarray = dataarray.transpose(*reorder)
+    return dataarray
diff --git a/gbd_2021/disease_burden_forecast_code/mortality/models/pooled_random_walk.py b/gbd_2021/disease_burden_forecast_code/mortality/models/pooled_random_walk.py
new file mode 100644
index 0000000..0678fe3
--- /dev/null
+++ b/gbd_2021/disease_burden_forecast_code/mortality/models/pooled_random_walk.py
@@ -0,0 +1,303 @@
+"""This module contains tools for creating a collection of correlated Random Walk projections.
+"""
+
+import itertools as it
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import xarray as xr
+from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions
+from fhs_lib_database_interface.lib.constants import DimensionConstants
+from fhs_lib_year_range_manager.lib.year_range import YearRange
+from tiny_structured_logger.lib.fhs_logging import get_logger
+
+from fhs_lib_model.lib.constants import ModelConstants, PooledRandomWalkConstants
+from fhs_lib_model.lib.random_walk.random_walk import RandomWalk
+
+logger = get_logger()
+
+
+class PooledRandomWalk(RandomWalk):
+    """Creates a collection of correlated Random Walk projections.
+
+    Runs a random walk model on the data while pulling strength across multiple time series.
+    Within the ``dims`` argument, list the dimensions for which specific model parameter
+    values are used. An empty list, which is the default, will use the same parameters for
+    all time series. The time dimension should always be ``"year_id"``. If you specify
+    ``"region_id"`` or ``"super_region_id"``, the ``dataset`` parameter must be an
+    ``xr.Dataset`` with the modeled value as ``"y"`` and another key, ``"region_id"`` or
+    ``"super_region_id"``, which has a ``"location_id"`` dimension and a value for each
+    location.
+    """
+
+    def __init__(
+        self,
+        dataset: Union[xr.DataArray, xr.Dataset],
+        years: YearRange,
+        draws: int,
+        y: str = ModelConstants.DEFAULT_DEPENDENT_VAR,
+        dims: Optional[List[str]] = None,
+        seed: int | None = None,
+    ) -> None:
+        """Initialize arguments and set state.
+
+        Args:
+            dataset: (xr.Dataset | xr.DataArray): Data to forecast. If ``xr.Dataset`` then
+                ``y`` parameter must be specified as the variable to forecast.
+            years (YearRange): The forecasting time series.
+            draws (int): Number of draws to make for predictions.
+            y (str): If using a ``xr.Dataset`` the dependent variable to apply analysis to.
+            dims (Optional[List[str]]): List of demographic dimensions to pool random walk
+                estimates over. Defaults to an empty list. **NOTE:** ``super_region_id``
+                pooling takes higher precedence than ``region_id`` pooling, so if both are
+                given in dims, then only ``super_region_id`` pooling will be done.
+            seed (Optional[int]): An optional seed to set for numpy's random number generation
+        """
+        super().__init__(dataset, years, draws, y, seed)
+
+        self._init_dims(dims)
+
+        # Declare the instance field that holds the projections of ``y``, i.e., the dependent
+        # variable (after the ``fit()`` and ``predict`` methods have been called).
+        self.y_hat_data: Optional[xr.DataArray] = None
+
+    def fit(self) -> None:
+        """Fit the pooled random walk model."""
+        if len(self.dims) >= 1:
+            self._pooled_fit()
+        else:
+            logger.warning("Projecting -- no pooled fit", bindings=dict(pooled_fit=False))
+            super().fit()
+
+    def predict(self) -> xr.DataArray:
+        """Generate predictions based on model fits.
+
+        Returns:
+            xr.DataArray: The random walk projections.
+        """
+        if len(self.dims) > 0:
+            self._pooled_predict()
+            return self.y_hat_data
+        else:
+            logger.warning("Projecting -- no pooled fit", bindings=dict(pooled_fit=False))
+            self.y_hat_data = super().predict()
+            return self.y_hat_data
+
+    def _init_dims(self, dims: Optional[List[str]]) -> None:
+        """Initializes the ``dims`` instance field."""
+        self.dims = dims if dims is not None else []
+
+        unexpected_dims = set(self.dims) - (
+            set(self.dataset.dims) | set(self.dataset.data_vars)
+        )
+        if unexpected_dims:
+            err_msg = f"Some of the dims to pool over are not included: {unexpected_dims}"
+            logger.error(err_msg, bindings=dict(unexpected_dims=unexpected_dims))
+            raise IndexError(err_msg)
+
+        # super-region pooling takes higher precedence than region-pooling.
+        if DimensionConstants.SUPER_REGION_ID in self.dims:
+            logger.debug(
+                (
+                    "'super_region_id' dim present, 'region_id' and 'location_id' dims will "
+                    "be removed if also present"
+                )
+            )
+            self.dims = [
+                d
+                for d in self.dims
+                if d not in (DimensionConstants.REGION_ID, DimensionConstants.LOCATION_ID)
+            ]
+        elif DimensionConstants.REGION_ID in self.dims:
+            # region pooling takes higher precedence than detailed-location-pooling.
+            if DimensionConstants.LOCATION_ID in self.dims:
+                logger.debug(
+                    (
+                        "'region_id' dim present, 'location_id' dim will be removed if also "
+                        "present"
+                    )
+                )
+                self.dims = [d for d in self.dims if d != DimensionConstants.LOCATION_ID]
+
+    def _pooled_fit(self) -> None:
+        """This method handles pooled Random Walk model fitting."""
+        # Make a copy of the input data to be pre-processed before using it in the model.
+        y_data = self.dataset[self.y].copy()
+        dims = [d for d in self.dims]
+
+        # If ``"super_region_id"`` is a dim to pool over, update coordinates of the
+        # ``location_id`` dim of the input dependent variable data to the values of the
+        # ``super_region_id`` data variable from the input dataset.
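+        # Illustrative note: after this relabeling, _fit_work estimates one sigma per pooled
+        # location group (super region or region) rather than per detailed location, which is
+        # how strength is pulled across time series.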
+ if DimensionConstants.SUPER_REGION_ID in dims: + y_data, dims = self._pre_pool_fit( + y_data=y_data, + dims=dims, + aggregate_location_dim=DimensionConstants.SUPER_REGION_ID, + ) + # super-region pooling takes higher precedence than region-pooling, but if + # ``"region_id"`` is a dim to pool over and ``"super_region_id"`` is not, then do the + # same thing with ``"region_id"`` as is done above with ``"super_region_id"``. + elif DimensionConstants.REGION_ID in dims: + y_data, dims = self._pre_pool_fit( + y_data=y_data, + dims=dims, + aggregate_location_dim=DimensionConstants.REGION_ID, + ) + + # Create the sigma place holder dataarray to be filled in with fit-parameter values. + dim_coord_dict = {d: list(y_data[d].values) for d in dims} + value_array = np.ones([len(x) for x in dim_coord_dict.values()]) + + self.sigma = xr.DataArray( + value_array, + dims=dims, + coords=dim_coord_dict, + name=PooledRandomWalkConstants.SIGMA, + ) + + year_dict = {DimensionConstants.YEAR_ID: list(self.years.past_years)} + dim_coord_sets = [np.unique(y_data[d].values) for d in dims] + + results = [] + for slice_coords in it.product(*dim_coord_sets): + result = self._fit_work( + dims=dims, + year_dict=year_dict, + y_data=y_data, + slice_coords=slice_coords, + ) + results.append(result) + + for result in results: + coord_slice_dict = result[0] + self.sigma.loc[coord_slice_dict] = result[1] + + if DimensionConstants.SUPER_REGION_ID in self.dims: + self._post_pool_fit(aggregate_location_dim=DimensionConstants.SUPER_REGION_ID) + elif DimensionConstants.REGION_ID in self.dims: + self._post_pool_fit(aggregate_location_dim=DimensionConstants.REGION_ID) + + def _pre_pool_fit( + self, + y_data: xr.DataArray, + dims: List[str], + aggregate_location_dim: str, + ) -> Tuple[xr.DataArray, List[str]]: + """Prepare inputs for pooling.""" + y_data = y_data.assign_coords(location_id=self.dataset[aggregate_location_dim].values) + dims = [ + d if d != aggregate_location_dim else DimensionConstants.LOCATION_ID for d in dims + ] + return y_data, dims + + def _post_pool_fit(self, aggregate_location_dim: str) -> None: + """Cleanup model-fit after pooling.""" + logger.info( + ( + f"Switching dim='location_id' back from {aggregate_location_dim} coords to" + f" 'location_id' coords" + ) + ) + location_ids = self.dataset[aggregate_location_dim].location_id.values + self.sigma = self.sigma.assign_coords(location_id=location_ids) + self.dataset = self.dataset.assign_coords(location_id=location_ids) + + def _pooled_predict(self) -> None: + """This method handles pooled Random Walk model projection.""" + # Set place holder dataarray for projections -- with appropriate dim/coord shape. 
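+        # (Past years of this placeholder are later overwritten with the observed data
+        # broadcast across draws, and forecast years with the random-walk projections
+        # produced in _predict_work.)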
+ self.y_hat_data = expand_dimensions( + xr.ones_like( + self.dataset[self.y].sel( + **{DimensionConstants.YEAR_ID: self.years.past_end}, drop=True + ) + ), + **{DimensionConstants.YEAR_ID: self.years.years}, + draw=list(range(self.draws)), + ) + + y_observed_data = self.dataset[self.y] + dims = list(self.sigma.coords.indexes.keys()) + dim_coord_sets = [self.sigma[d].values for d in self.sigma.dims] + + results = [] + for slice_coords in it.product(*dim_coord_sets): + result = self._predict_work( + dims=dims, + y_observed_data=y_observed_data, + y_hat_data=self.y_hat_data, + sigma=self.sigma, + years=self.years, + slice_coords=slice_coords, + ) + results.append(result) + + for result in results: + coord_slice_dict = result[0] + self.y_hat_data.loc[coord_slice_dict] = result[1] + + @staticmethod + def _fit_work( + dims: List[str], + year_dict: Dict[str, List[int]], + y_data: xr.DataArray, + slice_coords: Tuple[int], + ) -> Tuple[Dict[str, int], np.ndarray]: + """Calculate sigma based on past time series variation.""" + coord_slice_dict = {dims[i]: slice_coords[i] for i in range(len(dims))} + + y_data_slice = y_data.loc[dict(**coord_slice_dict, **year_dict)] + diff = y_data_slice.diff(DimensionConstants.YEAR_ID) + sigma = np.std(diff.values) + + return coord_slice_dict, sigma + + def _predict_work( + self, + dims: List[str], + y_observed_data: xr.DataArray, + y_hat_data: xr.DataArray, + sigma: xr.DataArray, + years: YearRange, + slice_coords: Tuple[int], + ) -> Tuple[Dict[str, int], xr.DataArray]: + """Generate predictions for a specific age, sex, location.""" + coord_slice_dict = {dims[i]: slice_coords[i] for i in range(len(slice_coords))} + + y_hat_slice = y_hat_data.loc[coord_slice_dict] + y_observed_slice = y_observed_data.loc[coord_slice_dict] + sigma_slice = sigma.loc[coord_slice_dict] + + # NaNs produced by upstream R process can be < 0, which causes error -- we want to + # replace these with Numpy NaNs. + if np.isnan(sigma_slice.values): + logger.warning("Replacing R NaNs with Numpy NaNs") + sigma_slice.values = np.nan + + # Populate past year data with observed data but copied to include fake draws. + past_year_dict = {DimensionConstants.YEAR_ID: years.past_years} + + past_shape = y_hat_slice.loc[past_year_dict] + y_hat_past = y_observed_slice.sel(**past_year_dict) * y_hat_slice.sel(**past_year_dict) + y_hat_slice.loc[past_year_dict] = y_hat_past.transpose(*past_shape.coords.dims) + + # Populate future year data with projections. + forecast_year_dict = {DimensionConstants.YEAR_ID: years.forecast_years} + + forecast_shape = y_hat_slice.sel(**forecast_year_dict) + distribution = xr.DataArray( + self.rng.normal( + loc=0, + scale=sigma_slice.values, + size=forecast_shape.shape, + ), + coords=forecast_shape.coords, + dims=forecast_shape.dims, + ) + + y_hat_forecast = y_observed_slice.sel( + **{DimensionConstants.YEAR_ID: years.past_end}, drop=True + ) + distribution.cumsum(DimensionConstants.YEAR_ID) + y_hat_slice.loc[forecast_year_dict] = y_hat_forecast.transpose( + *forecast_shape.coords.dims + ) + + return coord_slice_dict, y_hat_slice diff --git a/gbd_2021/disease_burden_forecast_code/mortality/models/random_walk.py b/gbd_2021/disease_burden_forecast_code/mortality/models/random_walk.py new file mode 100644 index 0000000..2f7bf0c --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/models/random_walk.py @@ -0,0 +1,163 @@ +"""This module contains the Random Walk model. 
+""" + +from typing import Optional, Union + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import ( + expand_dimensions, + remove_dims, +) +from fhs_lib_data_transformation.lib.processing import strip_single_coord_dims +from fhs_lib_database_interface.lib.constants import DimensionConstants +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_lib_model.lib.constants import ModelConstants + +logger = get_logger() + + +class RandomWalk: + """Runs an independent random walk model for every time series in the dataset.""" + + def __init__( + self, + dataset: Union[xr.DataArray, xr.Dataset], + years: YearRange, + draws: int, + y: str = ModelConstants.DEFAULT_DEPENDENT_VAR, + seed: int | None = None, + ) -> None: + """Initialize arguments and set state. + + Args: + dataset: (xr.Dataset | xr.DataArray): Data to forecast. If ``xr.Dataset`` then + ``y`` paramter must be specified as the variable to forecast. + years (YearRange): The forecasting time series. + draws (int): Number of draws to make for predictions. + y (str): If using a ``xr.Dataset`` the dependent variable to apply analysis to. + seed (Optional[int]): An optional seed to set for numpy's random number generation + """ + self.rng = np.random.default_rng(seed=seed) + + self.y = y + self.years = years + self.draws = draws + self._init_dataset(dataset) + + # Declare the sigma instance field that will hold the model fit values. + self.sigma: Optional[xr.DataArray] = None + + def fit(self) -> None: + """Fit the model. + + Here, fitting the model corrsponds to calculating the standard deviation of the normal + distribution used to generate the random walk. + """ + diff = self.dataset[self.y].diff(DimensionConstants.YEAR_ID) + self.sigma = diff.std(DimensionConstants.YEAR_ID) + + def predict(self) -> xr.DataArray: + """Generate predictions based on model fit. + + Produces projections into future years using the model fit. + + Notes: + * Must be run after the ``RandomWalk.fit()`` method. + + Returns: + xr.DataArray: Contains past data and future projections. 
+ """ + location_ids = self.dataset[self.y].location_id.values + age_group_ids = self.dataset[self.y].age_group_id.values + sex_ids = self.dataset[self.y].sex_id.values + location_forecast_list = [] + for location_id in location_ids: + age_forecast_list = [] + for age_group_id in age_group_ids: + sex_forecast_list = [] + for sex_id in sex_ids: + sex_forecast = self._predict_demog_slice( + location_id=location_id, + age_group_id=age_group_id, + sex_id=sex_id, + ) + sex_forecast_list.append(sex_forecast) + age_forecast_list.append( + xr.concat(sex_forecast_list, dim=DimensionConstants.SEX_ID) + ) + location_forecast_list.append( + xr.concat(age_forecast_list, dim=DimensionConstants.AGE_GROUP_ID) + ) + forecast = xr.concat(location_forecast_list, dim=DimensionConstants.LOCATION_ID) + past = self.dataset.y + + past = expand_dimensions(past, draw=range(self.draws)) + past = strip_single_coord_dims(past) + + full_time_series: xr.DataArray = xr.concat( + [past, forecast], dim=DimensionConstants.YEAR_ID + ) + + return full_time_series + + def _init_dataset(self, dataset: Union[xr.DataArray, xr.Dataset]) -> None: + """Initialize the dataset of observed/past data.""" + if isinstance(dataset, xr.DataArray): + logger.debug("DataArray given, converting to Dataset") + self.dataset = xr.Dataset({self.y: dataset}) + else: + self.dataset = dataset + + if self.y not in list(self.dataset.data_vars.keys()): + err_msg = "``y`` was not found in the input dataset" + logger.error(err_msg) + raise ValueError(err_msg) + + expected_dims = [ + DimensionConstants.YEAR_ID, + DimensionConstants.LOCATION_ID, + DimensionConstants.AGE_GROUP_ID, + DimensionConstants.SEX_ID, + ] + for dim in self.dataset.dims: + if dim not in expected_dims: + self.dataset = remove_dims(xr_obj=self.dataset, dims_to_remove=[str(dim)]) + + def _predict_demog_slice( + self, location_id: int, age_group_id: int, sex_id: int + ) -> xr.DataArray: + """Generate predictions based on model fit on a single demographic slice.""" + demographic_slice_coords = { + DimensionConstants.LOCATION_ID: location_id, + DimensionConstants.AGE_GROUP_ID: age_group_id, + DimensionConstants.SEX_ID: sex_id, + } + + past_slice = self.dataset.loc[demographic_slice_coords][self.y].values + if len(past_slice.shape) != 1: + err_msg = "Demographic slice must be 1-Dimensional (just year/time)" + logger.error(err_msg) + raise RuntimeError(err_msg) + + sigma_slice = self.sigma.loc[demographic_slice_coords].values + + distribution = self.rng.normal( + loc=0, + scale=sigma_slice, + size=(self.draws, self.years.forecast_end - self.years.past_end), + ) + forecast_slice = past_slice[-1] + np.cumsum(distribution, axis=1) + + forecast_slice_dataarray = xr.DataArray( + forecast_slice, + coords=[list(range(self.draws)), self.years.forecast_years], + dims=[DimensionConstants.DRAW, DimensionConstants.YEAR_ID], + ) + forecast_slice_dataarray.coords[DimensionConstants.LOCATION_ID] = location_id + forecast_slice_dataarray.coords[DimensionConstants.AGE_GROUP_ID] = age_group_id + forecast_slice_dataarray.coords[DimensionConstants.SEX_ID] = sex_id + + return forecast_slice_dataarray diff --git a/gbd_2021/disease_burden_forecast_code/mortality/models/remove_drift.py b/gbd_2021/disease_burden_forecast_code/mortality/models/remove_drift.py new file mode 100644 index 0000000..7d32c2b --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/mortality/models/remove_drift.py @@ -0,0 +1,175 @@ +"""This module contains functions for attenuating or removing the drift effect from epsilon. 
+ +i.e., residual error, predictions. + +This is intended to be used in relation with latent trend models such as ARIMA and Random Walk. +""" + +from typing import Tuple + +import numpy as np +import statsmodels.api as sm +import xarray as xr +from fhs_lib_database_interface.lib.constants import DimensionConstants +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib.fhs_logging import get_logger + +logger = get_logger() + + +def get_decayed_drift_preds( + epsilon_da: xr.DataArray, + years: YearRange, + decay: float, +) -> xr.DataArray: + """Generates attenuated drift predictions. + + Generates attenuated drift predictions for each demographic combo in the input dataset, + excluding the ``"year_id"`` dimension. + + Args: + epsilon_da (xr.DataArray): Dataarray with values to fit and remove the drift from. + years (YearRange): The forecasting time series year range. + decay (float): Rate at which the slope of the line decays once forecasts start. + + Returns: + xr.DataArray: Predictions for past and future years -- linear regression for past + years, and attenuated drift for future years. Predictions are for every demographic + combo available in the input dataarray. + """ + # Find the right linear regression parameters to fit the in-sample data to. + param_ds = _get_all_lr_params(da=epsilon_da) + + # Get the right shape for the prediction dataarray - fill with nans + year_shape_da = xr.DataArray( + np.ones(len(years.years)), + coords=[years.years], + dims=[DimensionConstants.YEAR_ID], + ) + pred_da = param_ds["alpha"] * year_shape_da * np.nan + + # Iterate through the demographic combos, and populating the dataarray with predictions for + # each combo. + for sex_id in pred_da[DimensionConstants.SEX_ID].values: + for age_group_id in pred_da[DimensionConstants.AGE_GROUP_ID].values: + for location_id in pred_da[DimensionConstants.LOCATION_ID].values: + slice_coord_dict = { + DimensionConstants.SEX_ID: sex_id, + DimensionConstants.AGE_GROUP_ID: age_group_id, + DimensionConstants.LOCATION_ID: location_id, + } + alpha = param_ds["alpha"].loc[slice_coord_dict].values + beta = param_ds["beta"].loc[slice_coord_dict].values + + pred_da.loc[slice_coord_dict] = _make_single_predictions( + alpha=alpha, + beta=beta, + years=years, + decay=decay, + ) + + return pred_da + + +def _get_all_lr_params(da: xr.DataArray) -> xr.Dataset: + """Fits a linear regression to each demographic time series slice. + + Iterates through all demographic time series slices in an input dataarray, fitting a linear + regression to each. The parameters of the linear regression fit are returned in a dataset. + + Args: + da (xr.DataArray): n-dimensional dataarray with ``"year_id"`` as a dimension. + + Returns: + xr.Dataset: n-1 dimensional xarray containing the intercept, i.e., alpha, and slope, + i.e., beta, for each combination of coords in ``da`` excluding ``"year_id"``. + """ + # Create a dataset to store parameters - fill with nans to start, and get rid of + # ``"year_id"`` dimension. 
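The drift-attenuation logic has two parts: an OLS fit of ``y = alpha + beta * t`` on the past residuals of each demographic slice, and a forecast that keeps adding the slope but shrinks it geometrically, ``y_{t+1} = y_t + beta * exp(-decay * k)`` for the k-th year past the holdout. A compact sketch of both steps for a single slice, with toy residuals:

```python
import numpy as np
import statsmodels.api as sm

rng = np.random.default_rng(7)

# One demographic slice of past residuals (toy values) and the forecast setup.
past_eps = 0.02 * np.arange(20) + rng.normal(scale=0.05, size=20)
n_future, decay = 30, 0.1

# Fit y = alpha + beta * t on the past years.
t = sm.add_constant(np.arange(len(past_eps)))
alpha, beta = sm.OLS(past_eps, t).fit().params

# Linear predictions for the past, then attenuate the slope year by year:
# y_{t+1} = y_t + beta * exp(-decay * years_since_holdout).
preds = list(alpha + beta * np.arange(len(past_eps)))
for k in range(n_future):
    preds.append(preds[-1] + beta * np.exp(-decay * k))

preds = np.array(preds)  # past linear fit followed by decaying-drift forecasts
```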
+ param_da = da.sel(**{DimensionConstants.YEAR_ID: da.year_id.values[0]}, drop=True) * np.nan + param_ds = xr.Dataset({"alpha": param_da.copy(), "beta": param_da.copy()}) + + # fit linear regression by location-age-sex + for sex_id in da[DimensionConstants.SEX_ID].values: + for age_group_id in da[DimensionConstants.AGE_GROUP_ID].values: + for location_id in da[DimensionConstants.LOCATION_ID].values: + slice_coord_dict = { + DimensionConstants.SEX_ID: sex_id, + DimensionConstants.AGE_GROUP_ID: age_group_id, + DimensionConstants.LOCATION_ID: location_id, + } + ts = da.loc[slice_coord_dict] + ts = ts.squeeze() + alpha, beta = _get_lr_params(ts) + + param_ds["alpha"].loc[slice_coord_dict] = alpha + param_ds["beta"].loc[slice_coord_dict] = beta + + return param_ds + + +def _get_lr_params(ts: xr.DataArray) -> Tuple[float, float]: + """Fits a linear regression to the given time series. + + Fits a linear regression to an input time series and returns the fit params. + + Args: + ts (xr.DataArray): 1D time series of epsilons. + + Returns: + Tuple[float, float]: the intercept and slope of the time series + """ + xdata = np.arange(len(ts)) + xdata = sm.add_constant(xdata) + model = sm.OLS(ts.values, xdata).fit() + alpha, beta = model.params + + return alpha, beta + + +def _make_single_predictions( + alpha: float, + beta: float, + years: YearRange, + decay: float, +) -> np.ndarray: + r"""Generates predictions of the form y = alpha + beta*time. + + These predictions are generated for the years in years.past_years (linear regression + predictions), then attenuates the slope that is added for each year after as the following. + + .. math:: + + y_{t+1} = y_t + \beta * \exp( -\text{decay} * \text{time-since-holdout} ) + + for each year in the future. + + Args: + alpha (float): intercept for linear regression + beta (float): slope for linear regression + years (YearRange): years to fit and forecast over + decay (float): rate at which the slope of the line decays once + forecasts start + + Returns: + numpy.array: the predictions generated by the input parameters and years + """ + linear_years = np.arange(len(years.past_years)) + + # First, create linear predictions for past years. + predictions = alpha + (beta * linear_years) + + # Then add the decay year-by-year for future years. + last = predictions[-1] + for year_index in range(len(years.forecast_years)): + current = last + (beta * np.exp(-decay * year_index)) + + logger.debug( + "Applying decay to future year", + bindings=dict(year_index=year_index, prediction=current), + ) + + predictions = np.append(predictions, current) + last = current + + return predictions diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/README.md b/gbd_2021/disease_burden_forecast_code/nonfatal/README.md new file mode 100644 index 0000000..1212aca --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/README.md @@ -0,0 +1,77 @@ +Nonfatal pipeline code + +The `run_model.one_cause_main()` function is the main entry point to computation. 
+ +# Lib + +``` +check_entity_files.py +Checks that the proper entity files are output by a given job +``` + +``` +constants.py +Nonfatal pipeline local constants +``` + +``` +indicator_from_ratio.py +Computes target indicator from existing ratio and indicator data +``` + +``` +model_parameters.py +Parameters to be used for model strategy +``` + +``` +model_strategy.py +Where nonfatal modeling strategies and their parameters are managed/defined +``` + +``` +model_strategy_queries.py +Has query functions that give nonfatal modeling strategies and their params +``` + +``` +ratio_from_indicators.py +Computes ratio of two indicators for past data +``` + +``` +run_model.py +Script that forecasts nonfatal measures of health +``` + +``` +yld_from_prevalence.py +Computes and saves YLDs using prevalence forecasts and average disability weight. +``` + + +# Models +``` +arc_method.py +Module with functions for making forecast scenarios +``` + +``` +limetr.py +Provides an interface to the LimeTr model +``` + +``` +omega_selection_strategy.py +Strategies for determining the weight for the Annualized Rate-of-Change (ARC) method +``` + +``` +processing.py +Contains all the functions for processing data for use in modeling +``` + +``` +validate.py +Functions related to validating inputs, and outputs of nonfatal pipeline +``` \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/lib/check_entity_files.py b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/check_entity_files.py new file mode 100644 index 0000000..305123b --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/check_entity_files.py @@ -0,0 +1,78 @@ +"""This script checks that the proper entity files are output by a given job.""" +import sys +from typing import List, Optional + +from fhs_lib_database_interface.lib.query import cause +from fhs_lib_file_interface.lib.check_input import write_log_message +from fhs_lib_file_interface.lib.version_metadata import VersionMetadata +from tiny_structured_logger.lib import fhs_logging + +logger = fhs_logging.get_logger() + + +def find_missing_entities(entity_dir: VersionMetadata, entities: List[str]) -> List[str]: + """Do the work of finding entities missing in the given directory. + """ + return [ + entity for entity in entities if not (entity_dir.data_path() / f"{entity}.nc").exists() + ] + + +def check_entities_main( + entities: tuple[str], + gbd_round_id: int, + past_or_future: str, + stage: str, + version: str, + suffix: Optional[str], + entities_source: Optional[str], +) -> None: + """Check whether there are any missing entity files. + + Writes a warning file to a warnings subdirectory in the given version if entities are + missing. + + Args: + entities (list[str]): List of entities to check + gbd_round_id (int): What gbd_round_id the results are saved under + past_or_future (str): Whether we are checking past or future + stage (str): The stage that we are checking files for (e.g. prevalence) + version (str): The version that we are checking files for + suffix (str | None, optional): Optionally append a suffix to the warning file name, + e.g. "_from_pi" + entities_source (Optional[str]): When "entities" is undefined/empty, use this "source" + parameter to load the entities from the shared database. + """ + suffix = suffix or "" + + if len(entities) == 0: + entities = cause.get_stage_cause_set(stage, gbd_round_id, source=entities_source) + logger.info(f"Called database with {stage}, {gbd_round_id}. 
got {len(entities)}") + + if len(entities) == 0: + message = ( + "run-check-entities is running against 0 entities, doing nothing. " + "That's probably a mistake." + ) + logger.error(message) + raise ValueError(message) + + entity_dir = VersionMetadata.parse_version( + f"{gbd_round_id}/{past_or_future}/{stage}/{version}" + ) + missing_entities = find_missing_entities(entity_dir, entities) + + if missing_entities: + warn_msg = ( + f"There are missing entities! Missing: " + f"{', '.join(map(str, missing_entities))}\n" + ) + warn_dir = entity_dir.data_path() / "warnings" + warn_dir.mkdir(parents=True, exist_ok=True) + warn_file = warn_dir / f"missing_entities{suffix}.txt" + + logger.warning(warn_msg) + write_log_message(warn_msg, warn_file) + sys.exit(100) + else: + logger.info("No missing entities!") diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/lib/constants.py b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/constants.py new file mode 100644 index 0000000..5b61e9d --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/constants.py @@ -0,0 +1,54 @@ +"""FHS Nonfatal Pipeline Local Constants.""" +from collections import namedtuple + +from fhs_lib_database_interface.lib.constants import StageConstants as ImportedStageConstants +from fhs_lib_database_interface.lib.query.model_strategy import RATIO_INDICATORS + + +class JobmonConstants: + """Constants related to Jobmon tasks.""" + + TOOL_NAME = "fhs_nonfatal_pipeline_tool" + PIPELINE_NAME = "fhs_nonfatal_pipeline" + TIMEOUT = 260000 # giving the entire workflow 3 days to run + + +class MADTruncateConstants: + """Constants related to MAD truncation.""" + + # Settings for mad_truncate method + MAX_MULTIPLIER = 6.0 + MEDIAN_DIMS = ("age_group_id",) + MULTIPLIER_STEP = 0.1 + PCT_COVERAGE = 0.975 + + +class ModelConstants: + """Constants related to modeling or model specification.""" + + DEFAULT_OFFSET = 1e-8 + MIN_VALUE = 1e-10 + + +class StageConstants(ImportedStageConstants): + """Stages in FHS file system.""" + + RATIO_MEASURES = tuple(RATIO_INDICATORS.keys()) + FORECAST_MEASURES = RATIO_MEASURES + ("prevalence", "incidence") + ALL_MEASURES = FORECAST_MEASURES + ("yld",) + + RatioToIndicatorMap = namedtuple( + "RatioToIndicatorMap", "target_indicator, available_indicator, ratio" + ) + + PHASE_ONE_RATIO_INDICATOR_MAPS = ( + RatioToIndicatorMap("prevalence", "death", "mp_ratio"), + RatioToIndicatorMap("incidence", "death", "mi_ratio"), + RatioToIndicatorMap("yld", "yll", "yld_yll_ratio"), + ) + PHASE_TWO_RATIO_INDICATOR_MAPS = ( + RatioToIndicatorMap("prevalence", "incidence", "pi_ratio"), + RatioToIndicatorMap("incidence", "prevalence", "pi_ratio"), + ) + + PREVALENCE_MAX = 1 - 1e-8 diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/lib/indicator_from_ratio.py b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/indicator_from_ratio.py new file mode 100644 index 0000000..99039b3 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/indicator_from_ratio.py @@ -0,0 +1,298 @@ +r"""Computes target indicator from existing ratio and indicator data. + +Parallelized by cause. + +Note: + The past version of the available indicator should line up with the + forecast version, otherwise there will be unsolvable intercept shift issues. + This mostly applies to mortality, where we need to use the past version of + mortality that mortality uses, rather than the final GBD output that + nonfatal uses to make MP and MI ratios. 
+""" +from typing import List, Optional, Tuple + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.processing import ( + LogitProcessor, + LogProcessor, + clean_cause_data, + concat_past_and_future, + mad_truncate, + make_shared_dims_conform, +) +from fhs_lib_database_interface.lib.query.model_strategy import RATIO_INDICATORS +from fhs_lib_file_interface.lib.check_input import check_versions +from fhs_lib_file_interface.lib.query.io_helper import read_single_cause +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions, validate_versions_scenarios +from fhs_lib_file_interface.lib.xarray_wrapper import save_xr_scenario +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging + +from fhs_pipeline_nonfatal.lib.constants import ( + MADTruncateConstants, + ModelConstants, + StageConstants, +) +from fhs_pipeline_nonfatal.lib.ratio_from_indicators import ratio_transformation + +logger = fhs_logging.get_logger() + + +def one_cause_main( + acause: str, + target_indicator_stage: str, + ratio_stage: str, + versions: Versions, + gbd_round_id: int, + draws: int, + years: YearRange, + output_scenario: Optional[int], + national_only: bool, +) -> None: + """Compute target indicator using existing ratio and indicator data. + + If target_is_numerator is true, use ratio multiply available_indicator, + otherwise divide available_indicator by ratio + + For example: death = mi_ratio * incidence + + Args: + acause (str): The cause for which a ratio of two indicators is being calculated + target_indicator_stage (str): What stage to save target indicator + ratio_stage (str): The ratio stage we are using to convert + versions: (Versions) A Versions object that keeps track of all the versions and their + respective data directories. + gbd_round_id: (int) What gbd_round_id the indicators and ratio are saved under + draws (int): How many draws to save for the ratio output + years (YearRange): Forecasting time series year range + output_scenario (Optional[int]): Optional output scenario ID + national_only (bool): Whether to include subnational locations, or to include only + nations. + """ + # validate versions + validate_versions_scenarios( + versions=versions, + output_scenario=output_scenario, + output_epoch_stages=[("future", target_indicator_stage)], + ) + + available_indicator_stage, target_is_numerator = _get_available_indicator_stage( + target_indicator_stage, ratio_stage + ) + + _check_versions(target_indicator_stage, ratio_stage, available_indicator_stage, versions) + + ratio_version_metadata = versions.get("future", ratio_stage).default_data_source( + gbd_round_id + ) + + # Indicator is calculated using scenario specific ratios. + # If `ratio` covariates (ie: SDI) vary by scenario, then ratios will vary by scenario. + # If `ratio` covariates are equal across scenarios, then between scenario differences are + # driven by mortality alone. + ratio = read_single_cause( + acause=acause, + stage=ratio_stage, + version_metadata=ratio_version_metadata, + ) + + # Set floor for rounding errors at ModelConstants.MIN_VALUE + ratio = ratio.where(ratio > ModelConstants.MIN_VALUE, other=ModelConstants.MIN_VALUE) + # For now, we won't worry about slicing the ratio data on the ``year_id`` + # dim, because we are for the available indicator, and so the arithmetic + # (which is inner-join logic) will force consistency on that dim. 
+ clean_ratio, _ = clean_cause_data( + ratio, + ratio_stage, + acause, + draws, + gbd_round_id, + year_ids=None, + national_only=national_only, + ) + + available_indicator = _past_and_future_data( + acause, available_indicator_stage, versions, years, gbd_round_id, draws, national_only + ) + # Set floor for rounding errors at ModelConstants.MIN_VALUE + available_indicator = available_indicator.where( + available_indicator > ModelConstants.MIN_VALUE, other=ModelConstants.MIN_VALUE + ) + # A few death values can be > 1, which throws off the logit function + if available_indicator_stage == "death": + available_indicator = available_indicator.clip(max=1).fillna( + 1 - ModelConstants.DEFAULT_OFFSET + ) + + if target_is_numerator: + modeled_target_indicator = available_indicator * clean_ratio + else: + modeled_target_indicator = ratio_transformation( + available_indicator, clean_ratio, ratio_stage + ) + + # Fill non-finite target-indicator values with zeros; these have zeros in + # the denominator. + modeled_target_indicator = modeled_target_indicator.where( + np.isfinite(modeled_target_indicator) + ).fillna(0) + + # Get target indicator past to intercept shift + target_indicator_past = read_single_cause( + acause=acause, + stage=target_indicator_stage, + version_metadata=versions.get("past", target_indicator_stage).default_data_source( + gbd_round_id + ), + ) + clean_target_indicator_past, _ = clean_cause_data( + target_indicator_past, + target_indicator_stage, + acause, + draws, + gbd_round_id, + year_ids=years.past_years, + national_only=national_only, + ) + + target_indicator_past = make_shared_dims_conform( + clean_target_indicator_past, modeled_target_indicator, ignore_dims=["year_id"] + ) + # Special exception for malaria due to problematic data + if (acause == "malaria") & (target_indicator_stage == "prevalence"): + max_multiplier = 12 + median_dims = ["age_group_id", "location_id"] + else: + max_multiplier = MADTruncateConstants.MAX_MULTIPLIER + median_dims = list(MADTruncateConstants.MEDIAN_DIMS) + + truncated_target_indicator = mad_truncate( + modeled_target_indicator, + median_dims=median_dims, + pct_coverage=MADTruncateConstants.PCT_COVERAGE, + max_multiplier=max_multiplier, + multiplier_step=MADTruncateConstants.MULTIPLIER_STEP, + ) + + # Need to intercept shift in log/logit space to avoid negatives in output. 
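The intercept shift itself is delegated to the Log/Logit processors, but the underlying idea is to align the modeled series with the observed past level in transformed space, so the shift can never push rates negative (or, in logit space, outside the unit interval). A minimal illustration of a log-space shift against the last observed year follows; this shows the general idea only, not the processors' exact ``unordered_draw`` behaviour.

```python
import numpy as np

rng = np.random.default_rng(3)

# Modeled draws for the last past year plus ten future years (toy values).
modeled = 0.028 * np.exp(rng.normal(scale=0.05, size=(100, 11)))
observed_last = 0.020  # observed rate in the last past year

# Shift in log space so the modeled last-past-year mean matches the observation;
# exponentiating back guarantees every shifted value stays strictly positive.
shift = np.log(observed_last) - np.log(modeled[:, 0].mean())
shifted = np.exp(np.log(modeled) + shift)

print(shifted[:, 0].mean())  # ~0.020
```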
+ if target_indicator_stage == "prevalence": + processor = LogitProcessor( + years=years, + gbd_round_id=gbd_round_id, + remove_zero_slices=False, + no_mean=True, + bias_adjust=False, + intercept_shift="unordered_draw", + age_standardize=False, + shift_from_reference=False, + ) + truncated_target_indicator = truncated_target_indicator.clip( + max=StageConstants.PREVALENCE_MAX + ) + else: + processor = LogProcessor( + years=years, + gbd_round_id=gbd_round_id, + remove_zero_slices=False, + no_mean=True, + bias_adjust=False, + intercept_shift="unordered_draw", + age_standardize=False, + shift_from_reference=False, + ) + + processed_data = processor.pre_process(truncated_target_indicator) + shifted_truncated_target_indicator = processor.post_process( + processed_data, target_indicator_past + ) + + target_indicator_file = FHSFileSpec( + versions.get("future", target_indicator_stage), f"{acause}.nc" + ) + + save_xr_scenario( + shifted_truncated_target_indicator, + target_indicator_file, + metric="rate", + space="identity", + ) + + +def _past_and_future_data( + acause: str, + stage: str, + versions: Versions, + years: YearRange, + gbd_round_id: int, + draws: int, + national_only: bool, +) -> xr.DataArray: + """Get past and future data for a cause/stage.""" + + def get_data(past_or_future: str, year_ids: List[int]) -> xr.DataArray: + """Internal method to pull and clean single cause data.""" + data = read_single_cause( + acause=acause, + stage=stage, + version_metadata=versions.get(past_or_future, stage).default_data_source( + gbd_round_id + ), + ) + clean_data, _ = clean_cause_data( + data, + stage, + acause, + draws, + gbd_round_id, + year_ids=year_ids, + national_only=national_only, + ) + return clean_data + + forecast_data = get_data("future", years.forecast_years) + past_data = get_data("past", years.past_years) + + return concat_past_and_future(past_data, forecast_data) + + +def _get_available_indicator_stage( + target_indicator_stage: str, ratio_stage: str +) -> Tuple[str, bool]: + """Get available indicator, and check if target is numerator. + + Args: + target_indicator_stage (str): The stage of the target indicator + ratio_stage (str): The stage of ratio to check + + Raises: + ValueError: if the `target_indicator_stage` doesn't match any ratio stage. 
+ + Returns: + Tuple[str, bool]: the available indicator stage and whether the target is the numerator + """ + if RATIO_INDICATORS[ratio_stage].numerator == target_indicator_stage: + available_indicator_stage = RATIO_INDICATORS[ratio_stage].denominator + target_is_numerator = True + elif RATIO_INDICATORS[ratio_stage].denominator == target_indicator_stage: + available_indicator_stage = RATIO_INDICATORS[ratio_stage].numerator + target_is_numerator = False + else: + raise ValueError(f"{target_indicator_stage} does not match ratio stage") + + return available_indicator_stage, target_is_numerator + + +def _check_versions( + target_indicator_stage: str, + ratio_stage: str, + available_indicator_stage: str, + versions: Versions, +) -> None: + """Checks that all expected versions are given.""" + expected_future_stages = {target_indicator_stage, ratio_stage, available_indicator_stage} + check_versions(versions, "future", expected_future_stages) + + expected_past_stages = {target_indicator_stage, available_indicator_stage} + check_versions(versions, "past", expected_past_stages) diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/lib/model_parameters.py b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/model_parameters.py new file mode 100644 index 0000000..ad26586 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/model_parameters.py @@ -0,0 +1,20 @@ +from collections import namedtuple + +ModelParameters = namedtuple( + "ModelParameters", + ( + "Model, " + "processor, " + "covariates, " + "fixed_effects, " + "fixed_intercept, " + "random_effects, " + "indicators, " + "spline, " + "predict_past_only, " + "node_models, " + "study_id_cols, " + "scenario_quantiles, " + "omega_selection_strategy, " + ), +) diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/lib/model_strategy.py b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/model_strategy.py new file mode 100644 index 0000000..e3ea0fb --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/model_strategy.py @@ -0,0 +1,373 @@ +r"""This module is where nonfatal modeling strategies and their parameters are managed/defined. + +**Modeling parameters include:** + +* pre/post processing strategy (i.e. processor object) +* covariates +* fixed-effects +* fixed-intercept +* random-effects +* indicators + +Currently, the LimeTr model for ratios is described as follows: + +.. Math:: + E[log({Y}_{l,a,s,y})] = \beta_{1}SDI + \gamma_{l,a,s}SDI + \alpha_{l,a,s} + +where :math:`E[\log{Y}]` is the expected log ratio, :math:`\beta_{1}` is the +fixed coefficient on SDI across time, :math:`\gamma_{l,a,s}` is the random +coefficient on SDI across time for each location-age-sex combination, and +:math:`\alpha_{l,a,s}` is the location-age-sex-specific random intercept. +The random slope has a Gaussian prior with mean 0 and standard deviation of +0.001. + +The LimeTr model for the indicators (i.e. prevalence or incidence) is simpler: + +.. Math:: + E[logit({Y}_{l,a,s,y})] = \beta_{1}SDI + \alpha_{l,a,s} + +where :math:`E[logit({Y}_{l,a,s,y})]` is the expected value of +logit(prevalence or incidence). 
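The ratio model in the docstring has three pieces per location-age-sex group: a fixed SDI slope shared across groups, a group-specific random SDI slope shrunk tightly by its 0.001-SD Gaussian prior, and a group-specific random intercept. A quick simulation of data with exactly that structure makes the roles concrete; it is purely illustrative, since LimeTr itself is configured through the ``MODEL_PARAMETERS`` mapping below.

```python
import numpy as np

rng = np.random.default_rng(11)

n_groups, n_years = 5, 20             # location-age-sex combos and past years
sdi = np.linspace(0.3, 0.8, n_years)  # a shared SDI trajectory (toy values)

beta = -1.5                                    # fixed SDI slope, shared by all groups
gamma = rng.normal(0.0, 0.001, size=n_groups)  # random SDI slopes, tiny prior SD
alpha = rng.normal(-3.0, 0.5, size=n_groups)   # random intercepts per group

# log-ratio for each group and year: beta*SDI + gamma_g*SDI + alpha_g (+ noise).
log_ratio = (
    (beta + gamma[:, None]) * sdi[None, :]
    + alpha[:, None]
    + rng.normal(scale=0.05, size=(n_groups, n_years))
)
```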
+""" +from fhs_lib_data_aggregation.lib import aggregation_methods +from fhs_lib_data_transformation.lib import processing +from fhs_lib_data_transformation.lib.constants import ProcessingConstants +from fhs_lib_database_interface.lib.query.model_strategy import ModelStrategyNames +from fhs_lib_model.lib.arc_method import omega_selection_strategy as oss +from fhs_lib_model.lib.arc_method.arc_method import ArcMethod +from fhs_lib_model.lib.limetr import LimeTr, RandomEffect +from frozendict import frozendict + +from fhs_pipeline_nonfatal.lib.constants import StageConstants +from fhs_pipeline_nonfatal.lib.model_parameters import ModelParameters + +MODEL_PARAMETERS = frozendict( + { + # Indicators: + StageConstants.PREVALENCE: frozendict( + { + ModelStrategyNames.ARC.value: ModelParameters( + Model=ArcMethod, + processor=processing.LogitProcessor( + years=None, + gbd_round_id=None, + remove_zero_slices=True, + no_mean=True, + bias_adjust=False, + intercept_shift=None, + age_standardize=False, + tolerance=ProcessingConstants.MAXIMUM_PRECISION, + ), + covariates=None, + fixed_effects=None, + fixed_intercept=None, + random_effects=None, + indicators=None, + spline=None, + predict_past_only=False, + node_models=None, + study_id_cols=None, + scenario_quantiles=None, + omega_selection_strategy=oss.adjusted_zero_biased_omega_distribution, + ), + ModelStrategyNames.LIMETREE.value: ModelParameters( + Model=LimeTr, + processor=processing.LogitProcessor( + years=None, + gbd_round_id=None, + remove_zero_slices=True, + no_mean=False, + bias_adjust=False, + intercept_shift="unordered_draw", + age_standardize=False, + tolerance=ProcessingConstants.MAXIMUM_PRECISION, + ), + covariates={ + "sdi": processing.NoTransformProcessor( + years=None, + gbd_round_id=None, + ) + }, + fixed_effects={"sdi": [-float("inf"), float("inf")]}, + fixed_intercept=None, + random_effects={ + "location_age_sex_intercept": RandomEffect( + ["location_id", "age_group_id", "sex_id"], None + ), + }, + indicators=None, + spline=None, + predict_past_only=False, + node_models=None, + study_id_cols=None, + scenario_quantiles=None, + omega_selection_strategy=None, + ), + ModelStrategyNames.LIMETREE_BMI.value: ModelParameters( + Model=LimeTr, + processor=processing.LogitProcessor( + years=None, + gbd_round_id=None, + remove_zero_slices=True, + no_mean=False, + bias_adjust=False, + intercept_shift="unordered_draw", + age_standardize=False, + tolerance=ProcessingConstants.MAXIMUM_PRECISION, + ), + covariates={ + "bmi": processing.NoTransformProcessor( + years=None, + gbd_round_id=None, + ) + }, + fixed_effects={"bmi": [-float("inf"), float("inf")]}, + fixed_intercept=None, + random_effects={ + "location_age_sex_intercept": RandomEffect( + ["location_id", "age_group_id", "sex_id"], None + ), + }, + indicators=None, + spline=None, + predict_past_only=False, + node_models=None, + study_id_cols=None, + scenario_quantiles=None, + omega_selection_strategy=None, + ), + ModelStrategyNames.NONE.value: None, + ModelStrategyNames.SPECTRUM.value: None, + } + ), + StageConstants.INCIDENCE: frozendict( + { + ModelStrategyNames.ARC.value: ModelParameters( + Model=ArcMethod, + processor=processing.LogProcessor( + years=None, + gbd_round_id=None, + remove_zero_slices=True, + no_mean=True, + bias_adjust=False, + intercept_shift=None, + age_standardize=False, + tolerance=ProcessingConstants.MAXIMUM_PRECISION, + ), + covariates=None, + fixed_effects=None, + fixed_intercept=None, + random_effects=None, + indicators=None, + spline=None, + predict_past_only=False, + 
node_models=None, + study_id_cols=None, + scenario_quantiles=None, + omega_selection_strategy=oss.use_smallest_omega_within_threshold, + ), + ModelStrategyNames.LIMETREE.value: ModelParameters( + Model=LimeTr, + processor=processing.LogProcessor( + years=None, + gbd_round_id=None, + remove_zero_slices=True, + no_mean=False, + bias_adjust=False, + intercept_shift="unordered_draw", + age_standardize=False, + tolerance=ProcessingConstants.MAXIMUM_PRECISION, + ), + covariates={ + "sdi": processing.NoTransformProcessor( + years=None, + gbd_round_id=None, + ) + }, + fixed_effects={"sdi": [-float("inf"), float("inf")]}, + fixed_intercept=None, + random_effects={ + "location_age_sex_intercept": RandomEffect( + ["location_id", "age_group_id", "sex_id"], None + ), + }, + indicators=None, + spline=None, + predict_past_only=False, + node_models=None, + study_id_cols=None, + scenario_quantiles=None, + omega_selection_strategy=None, + ), + ModelStrategyNames.NONE.value: None, + ModelStrategyNames.SPECTRUM.value: None, + } + ), + StageConstants.YLD: frozendict({ModelStrategyNames.NONE.value: None}), + # Ratios: + StageConstants.MI_RATIO: frozendict( + { + ModelStrategyNames.LIMETREE.value: ModelParameters( + Model=LimeTr, + processor=processing.LogProcessor( + years=None, + gbd_round_id=None, + remove_zero_slices=True, + no_mean=False, + bias_adjust=False, + intercept_shift="unordered_draw", + age_standardize=False, + tolerance=ProcessingConstants.MAXIMUM_PRECISION, + ), + covariates={ + "sdi": processing.NoTransformProcessor( + years=None, + gbd_round_id=None, + ) + }, + fixed_effects={"sdi": [-float("inf"), float("inf")]}, + fixed_intercept=None, + random_effects={ + "location_age_sex_intercept": RandomEffect( + ["location_id", "age_group_id", "sex_id"], None + ), + "sdi": RandomEffect(["location_id", "age_group_id", "sex_id"], 0.001), + }, + indicators=None, + spline=None, + predict_past_only=False, + node_models=None, + study_id_cols=None, + scenario_quantiles=None, + omega_selection_strategy=None, + ), + } + ), + StageConstants.MP_RATIO: frozendict( + { + ModelStrategyNames.LIMETREE.value: ModelParameters( + Model=LimeTr, + processor=processing.LogitProcessor( + years=None, + gbd_round_id=None, + remove_zero_slices=True, + no_mean=False, + bias_adjust=False, + intercept_shift="unordered_draw", + age_standardize=False, + tolerance=ProcessingConstants.MAXIMUM_PRECISION, + ), + covariates={ + "sdi": processing.NoTransformProcessor( + years=None, + gbd_round_id=None, + ) + }, + fixed_effects={"sdi": [-float("inf"), float("inf")]}, + fixed_intercept=None, + random_effects={ + "location_age_sex_intercept": RandomEffect( + ["location_id", "age_group_id", "sex_id"], None + ), + "sdi": RandomEffect(["location_id", "age_group_id", "sex_id"], 0.001), + }, + indicators=None, + spline=None, + predict_past_only=False, + node_models=None, + study_id_cols=None, + scenario_quantiles=None, + omega_selection_strategy=None, + ), + } + ), + StageConstants.PI_RATIO: frozendict( + { + ModelStrategyNames.LIMETREE.value: ModelParameters( + Model=LimeTr, + processor=processing.LogProcessor( + years=None, + gbd_round_id=None, + remove_zero_slices=True, + no_mean=False, + bias_adjust=False, + intercept_shift="unordered_draw", + age_standardize=False, + tolerance=ProcessingConstants.MAXIMUM_PRECISION, + ), + covariates={ + "sdi": processing.NoTransformProcessor( + years=None, + gbd_round_id=None, + ) + }, + fixed_effects={"sdi": [-float("inf"), float("inf")]}, + fixed_intercept=None, + random_effects={ + 
"location_age_sex_intercept": RandomEffect( + ["location_id", "age_group_id", "sex_id"], None + ), + "sdi": RandomEffect(["location_id", "age_group_id", "sex_id"], 0.001), + }, + indicators=None, + spline=None, + predict_past_only=False, + node_models=None, + study_id_cols=None, + scenario_quantiles=None, + omega_selection_strategy=None, + ), + } + ), + StageConstants.YLD_YLL_RATIO: frozendict( + { + ModelStrategyNames.LIMETREE.value: ModelParameters( + Model=LimeTr, + processor=processing.LogProcessor( + years=None, + gbd_round_id=None, + remove_zero_slices=True, + no_mean=False, + bias_adjust=False, + intercept_shift="unordered_draw", + age_standardize=False, + tolerance=ProcessingConstants.MAXIMUM_PRECISION, + ), + covariates={ + "sdi": processing.NoTransformProcessor( + years=None, + gbd_round_id=None, + ) + }, + fixed_effects={"sdi": [-float("inf"), float("inf")]}, + fixed_intercept=None, + random_effects={ + "location_age_sex_intercept": RandomEffect( + ["location_id", "age_group_id", "sex_id"], None + ), + "sdi": RandomEffect(["location_id", "age_group_id", "sex_id"], 0.001), + }, + indicators=None, + spline=None, + predict_past_only=False, + node_models=None, + study_id_cols=None, + scenario_quantiles=None, + omega_selection_strategy=None, + ), + } + ), + } +) + + +STAGE_AGGREGATION_METHODS = frozendict( + { + StageConstants.PREVALENCE: aggregation_methods.comorbidity, + StageConstants.INCIDENCE: aggregation_methods.summation, + StageConstants.YLD: aggregation_methods.summation, + StageConstants.DALY: aggregation_methods.summation, + StageConstants.DEATH: aggregation_methods.summation, + StageConstants.YLL: aggregation_methods.summation, + } +) diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/lib/model_strategy_queries.py b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/model_strategy_queries.py new file mode 100644 index 0000000..fcd2fea --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/model_strategy_queries.py @@ -0,0 +1,199 @@ +"""This module has query functions that give nonfatal modeling strategies and their params. +""" +from fhs_lib_database_interface.lib.constants import ( + DBConnectionParameterConstants, + DimensionConstants, + FHSDBConstants, +) +from fhs_lib_database_interface.lib.db_session import create_db_session +from fhs_lib_database_interface.lib.fhs_lru_cache import fhs_lru_cache +from fhs_lib_database_interface.lib.query.model_strategy import STAGE_STRATEGY_IDS +from fhs_lib_database_interface.lib.strategy_set.strategy import get_cause_set +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging + +from fhs_pipeline_nonfatal.lib import model_parameters + +logger = fhs_logging.get_logger() + +model_strategy_module = None + + +def _real_model_strategy_module() -> object: + """Import and return the "real" implementation of model_strategy. + + Used to load the real implementation late, so that tests can install their own mock + version, without actually running the "import" statement (since that import fails on some + platforms). + """ + import fhs_pipeline_nonfatal.lib.model_strategy + + return fhs_pipeline_nonfatal.lib.model_strategy + + +@fhs_lru_cache(1) +def get_cause_model( + acause: str, stage: str, years: YearRange, gbd_round_id: int +) -> model_parameters.ModelParameters: + r"""Gets modeling parameters associated with the given cause-stage. + + Finds the model appropriate and + 1) get cause strategies associated with stage. 
+ 2) find which strategy cause falls under + 3) Get the modeling parameters associated with that strategy and return them. + + Args: + acause (str): The cause that is being modeled. e.g. ``cvd_ihd`` + stage (str): Stage being forecasted, e.g. "yld_yll". + years (YearRange): Forecasting timeseries + gbd_round_id (int): The numeric ID of GBD round associated with the past data. + + Returns: + model_parameters.ModelParameters named tuple containing the following: + Model (Model): Class, i.e. un-instantiated from + ``fhs_pipeline_nonfatal.lib.model.py`` + + processor (Processor): The pre/post process strategy of the cause-stage, i.e. + instance of a class defined in ``fhs_lib_data_transformation.lib.processing.py``. + + covariates (dict[str, Processor]] | None): Maps each needed covariate, i.e. + independent variable, to it's respective preprocess strategy, i.e. instance of a + class defined in ``fhs_lib_data_transformation.lib.processing.py``. + + fixed_effects (dict[str, str] | None): List of covariates to calculate fixed + effect coefficient estimates for. e.g.: + {"haq": [-float('inf'), float('inf'), "edu": [0, 4.7]} + + fixed_intercept (str | None): To restrict the fixed intercept to be positive or + negative, pass "positive" or "negative", respectively. "unrestricted" says to + estimate a fixed effect intercept that is not restricted to positive or negative. + + random_effects (dict[str, list[str]] | None): A dictionary mapping covariates to + the dimensions that their random slopes will be estimated for and the standard + deviation of the gaussian prior on their variance. Of the form + ``dict[covariate, list[dimension]]``. e.g.: + {"haq": (["location_id", "age_group_id"], None), + "education": (["location_id"], 4.7)} + + indicators (dict[str, list[str]] | None): A dictionary mapping indicators to the + dimensions that they are indicators on. e.g.: + {"ind_age_sex": ["age_group_id", "sex_id"], "ind_loc": ["location_id"]} + + spline (dict): A dictionary mapping covariates to the spline parameters that will + be used for them of the form {covariate: SplineParams(degrees_of_freedom, + constraints)} Each key must be a covariate. The degrees_of_freedom int represents + the degrees of freedom on that spline. The constraint string can be "center" + indicating to apply a centering constraint or a 2-d array defining general linear + constraints. + node_models (list[CovModel]): + A list of NodeModels (e.g. StudyModel, OverallModel), each of + which has specifications for cov_models. + study_id_cols (Union[str, List[str]]): The columns to use in the `col_study_id` + argument to MRBRT ``load_df`` function. If it is a list of strings, those columns + will be concatenated together (e.g. ["location_id", "sex_id"] would yield columns + with values like ``{location_id}_{sex_id}``). This is done since MRBRT can + currently only use the one ``col_study_id`` column for random effects. + scenario_quantiles (dict | None): Whether to use quantiles of the stage two model + when predicting scenarios. Dictionary of quantiles to use is passed in e.g.: + {-1: dict(sdi=0.85), 0: None, 1: dict(sdi=0.15), 2: None,} + + Raises: + ValueError: If the given stage does NOT have any cause-strategy IDs, or if the given + acause/stage/gbd-round-id combo does not have a modeling strategy associated with + it. 
+ """ + with create_db_session( + db_name=FHSDBConstants.FORECASTING_DB_NAME, + server_conn_key=DBConnectionParameterConstants.DEFAULT_SERVER_CONN_KEY, + ) as session: + try: + strategy_ids = list(STAGE_STRATEGY_IDS[stage].keys()) + except KeyError: + raise ValueError(f"{stage} does not have available strategy IDs") + + model_strategy_name = None + for strategy_id in strategy_ids: + cause_strategy_set = list( + get_cause_set( + session=session, + gbd_round_id=gbd_round_id, + strategy_id=strategy_id, + )[DimensionConstants.ACAUSE].unique() + ) + if acause in cause_strategy_set: + # ``STAGE_STRATEGY_IDS[stage][strategy_id]`` is a ``StrategyModelSource`` + # instance, and we only need its model attribute. + model_strategy_name = STAGE_STRATEGY_IDS[stage][strategy_id].model + break # Model strategy found + + if not model_strategy_name: + raise ValueError( + f"acause={acause}, stage={stage}, gbd_round_id={gbd_round_id}" + f"does not have a model strategy associated with it." + ) + else: + logger.info( + f"acause={acause}, stage={stage}, gbd_round_id={gbd_round_id} " + f"has the {model_strategy_name} model strategy associated with it." + ) + + model_strategy = model_strategy_module or _real_model_strategy_module() + model_parameters = model_strategy.MODEL_PARAMETERS[ + stage + ][model_strategy_name] + + model_parameters = _update_processor_years(model_parameters, years) + model_parameters = _update_processor_gbd_round_id(model_parameters, gbd_round_id) + + return model_parameters + + +def _update_processor_years( + model_parameters: model_parameters.ModelParameters, years: YearRange +) -> model_parameters.ModelParameters: + """Update the model_params.processor. + + If ``years`` is entered as ``None`` in the procesor for the dependent + variable and covariates so it needs to be updated here. + + Args: + model_parameters (model_parameters.ModelParameters): named tuple containing a processor + years (YearRange): year range to update processor to + + Returns: + model_parameters.ModelParameters: model parameters where the processor years have been + updated + """ + if model_parameters: + model_parameters.processor.years = years + + if model_parameters.covariates: + for cov_name in model_parameters.covariates.keys(): + model_parameters.covariates[cov_name].years = years + + return model_parameters + + +def _update_processor_gbd_round_id( + model_parameters: model_parameters.ModelParameters, gbd_round_id: int +) -> model_parameters.ModelParameters: + """Update gbd_round_id of input model_parameters.processor. + + ``gbd_round_id`` is entered as ``None`` in the procesor for the dependent + variable and covariates so it needs to be updated here. + + Args: + model_parameters (model_parameters.ModelParameters): model parameters to update + gbd_round_id (int): gbd round + + Returns: + model_parameters.ModelParameters + """ + if model_parameters: + model_parameters.processor.gbd_round_id = gbd_round_id + + if model_parameters.covariates: + for cov_name in model_parameters.covariates.keys(): + model_parameters.covariates[cov_name].gbd_round_id = gbd_round_id + + return model_parameters diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/lib/ratio_from_indicators.py b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/ratio_from_indicators.py new file mode 100644 index 0000000..aa64684 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/ratio_from_indicators.py @@ -0,0 +1,165 @@ +r"""Computes ratio of two indicators for past data. + +This ratio would be forecasted. 
Parallelized by cause. +""" +from typing import Optional + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib import processing +from fhs_lib_database_interface.lib.query import cause +from fhs_lib_database_interface.lib.query.model_strategy import RATIO_INDICATORS +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec, VersionMetadata +from fhs_lib_file_interface.lib.versioning import Versions, validate_versions_scenarios +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging + +from fhs_pipeline_nonfatal.lib.constants import ModelConstants, StageConstants + +logger = fhs_logging.get_logger() + + +def ratio_transformation( + numerator_da: xr.DataArray, denominator_da: xr.DataArray, ratio_stage: str +) -> xr.DataArray: + """Simple division if not MP ratio. logit(m) - logit(p) if MP ratio.""" + if ratio_stage == StageConstants.MP_RATIO: + logit_numerator_da = processing.logit_with_offset(numerator_da) + logit_denominator_da = processing.logit_with_offset(denominator_da) + result = logit_numerator_da - logit_denominator_da + result = processing.invlogit_with_offset(result, bias_adjust=False) + else: + result = numerator_da / denominator_da + + return result + + +def one_cause_main( + acause: str, + draws: int, + gbd_round_id: int, + ratio_stage: str, + versions: Versions, + years: YearRange, + output_scenario: Optional[int] = None, + national_only: bool = False, +) -> None: + """Compute ratio of two indicators. + + For example: mi_ratio = death / incidence + + Args: + acause (str): The cause for which a ratio of two indicators is being calculated. + draws (int): How many draws to save for the ratio output + gbd_round_id: (int) What gbd_round_id the indicators and ratio are saved under + ratio_stage (str): What stage to save the ratio into + versions: (Versions) A Versions object that keeps track of all the versions and their + respective data directories. + years (YearRange): Forecasting timeseries abstraction + output_scenario (Optional[int]): Optional output scenario ID + national_only (bool): Whether to include subnational locations, or to include only + nations. + """ + # validate versions + validate_versions_scenarios( + versions=versions, + output_scenario=output_scenario, + output_epoch_stages=[("past", ratio_stage)], + ) + + numerator_indicator_stage = RATIO_INDICATORS[ratio_stage].numerator + # We want to infer the stage estimates of the given cause for the numerator + # indicator and the denominator indicator since there is a chance we don't + # actually have estimates for one or both of those stages. 
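For the MP ratio, ``ratio_transformation`` above works on the logit scale: the stored quantity is effectively ``invlogit(logit(m) - logit(p))``, which keeps the ratio bounded in (0, 1) even when mortality risk and prevalence sit near their limits. A toy numeric check with a plain logit (the library's ``logit_with_offset``/``invlogit_with_offset`` additionally apply a small offset, omitted here):

```python
import numpy as np


def logit(x: np.ndarray) -> np.ndarray:
    return np.log(x / (1.0 - x))


def invlogit(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


m = np.array([0.001, 0.01])  # mortality risk
p = np.array([0.02, 0.05])   # prevalence

mp_ratio = invlogit(logit(m) - logit(p))  # bounded in (0, 1)
plain_ratio = m / p                       # simple division, used for the other ratios

print(mp_ratio, plain_ratio)
```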
+ inferred_numerator_acause = cause.get_inferred_acause(acause, numerator_indicator_stage) + + numerator_indicator = open_xr_scenario( + FHSFileSpec( + version_metadata=versions.get("past", numerator_indicator_stage), + filename=f"{inferred_numerator_acause}.nc", + ) + ) + + clean_numerator_indicator, numerator_warn_msg = processing.clean_cause_data( + data=numerator_indicator, + stage=numerator_indicator_stage, + acause=acause, + draws=draws, + gbd_round_id=gbd_round_id, + year_ids=years.past_years, + national_only=national_only, + ) + _assert_all_finite(clean_numerator_indicator, "past") + + denominator_indicator_stage = RATIO_INDICATORS[ratio_stage].denominator + inferred_denominator_acause = cause.get_inferred_acause( + acause, denominator_indicator_stage + ) + + denominator_indicator = open_xr_scenario( + FHSFileSpec( + version_metadata=versions.get("past", denominator_indicator_stage), + filename=f"{inferred_denominator_acause}.nc", + ) + ) + + clean_denominator_indicator, denominator_warn_msg = processing.clean_cause_data( + data=denominator_indicator, + stage=denominator_indicator_stage, + acause=acause, + draws=draws, + gbd_round_id=gbd_round_id, + year_ids=years.past_years, + national_only=national_only, + ) + _assert_all_finite(clean_denominator_indicator, "past") + + # A few death values can be > 1, which throws off the logit function + if numerator_indicator_stage == "death": + clean_numerator_indicator = clean_numerator_indicator.clip(max=1).fillna( + 1 - ModelConstants.DEFAULT_OFFSET + ) + + ratio = ratio_transformation( + clean_numerator_indicator, clean_denominator_indicator, ratio_stage + ) + + # Fill non-finite ratios with zeros; these have zeros in the denominator + ratio = ratio.where(np.isfinite(ratio)).fillna(0) + + ratio_file_spec = FHSFileSpec( + version_metadata=versions.get("past", ratio_stage), + filename=f"{acause}.nc", + ) + + save_xr_scenario( + xr_obj=ratio, + file_spec=ratio_file_spec, + metric="rate", + space="identity", + ) + + warning_msg = "" + if numerator_warn_msg: + warning_msg += numerator_warn_msg + if denominator_warn_msg: + warning_msg += denominator_warn_msg + _write_warning_msg(acause, warning_msg, ratio_file_spec.version_metadata) + + +def _assert_all_finite(data: xr.DataArray, past_or_future: str) -> None: + """Validate that all values in `data` are finite.""" + if not np.isfinite(data).all(): + raise ValueError(f"{past_or_future} {data.name} data has non-finite values!") + + +def _write_warning_msg(acause: str, warning_msg: str, out_path: VersionMetadata) -> None: + """Write a warning msg related to creating the ratio for this cause.""" + if warning_msg: + warning_dir = out_path.data_path() / "warnings" + warning_dir.mkdir(parents=True, exist_ok=True) + warning_file = warning_dir / f"{acause}.txt" + + with open(str(warning_file), "w") as file_obj: + file_obj.write(warning_msg) diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/lib/run_model.py b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/run_model.py new file mode 100644 index 0000000..dd1fb63 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/run_model.py @@ -0,0 +1,414 @@ +"""A script that forecasts nonfatal measures of health.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import xarray as xr +from fhs_lib_data_transformation.lib import processing +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.resample import resample +from 
fhs_lib_data_transformation.lib.validate import assert_shared_coords_same +from fhs_lib_file_interface.lib.check_input import check_versions +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions, validate_versions_scenarios +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_model.lib import validate +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging + +from fhs_pipeline_nonfatal.lib import model_parameters, model_strategy_queries + +logger = fhs_logging.get_logger() +OPIOID_EXCEPTION_LOCS = [101, 102] # US and Canada + + +def one_cause_main( + acause: str, + stage: str, + versions: Versions, + years: YearRange, + draws: int, + gbd_round_id: int, + output_scenario: Optional[int], + national_only: bool, + expand_scenarios: Optional[List[int]], + seed: Optional[int], +) -> None: + r"""Forecasts given stage for given cause. + + Args: + acause (str): The cause to forecast. + stage (str): Stage being forecasted, e.g. "yld_yll". + versions (Versions): All relevant versions. + years (YearRange): Forecasting time series. + draws (int): The number of draws to compute with and output for betas and predictions. + gbd_round_id (int): The numeric ID of GBD round associated with the past data + output_scenario (Optional[int]): Optional output scenario ID + national_only (bool): Whether to include subnational locations, or to include only + nations. + expand_scenarios (Optional[List[int]]): When present, throw away all but the reference + scenario and expand the scenario dimension to these. + seed (Optional[int]): Seed for random number generator (presently just in LimeTr). + """ + # validate versions + validate_versions_scenarios( + versions=versions, + output_scenario=output_scenario, + output_epoch_stages=[("future", stage)], + ) + + # If there aren't _any_ model parameters associated with the + # cause-stage then the script will exit with return code 0. + ( + Model, + processor, + covariates, + fixed_effects, + fixed_intercept, + random_effects, + indicators, + spline, + predict_past_only, + node_models, + study_id_cols, + scenario_quantiles, + omega_selection_strategy, + ) = _get_model_parameters(acause, stage, years, gbd_round_id) + + versions_to_check = {stage} | covariates.keys() if covariates else {stage} + + # BMI is our covariate name, but these are saved as SEVs + if "bmi" in versions_to_check: + check_versions(versions, "past", ["sev"]) + check_versions(versions, "future", ["sev"]) + + versions_to_check.remove("bmi") + check_versions(versions, "past", versions_to_check) + check_versions(versions, "future", versions_to_check) + else: + check_versions(versions, "past", versions_to_check) + check_versions(versions, "future", versions_to_check) + + past_file = FHSFileSpec(versions.get("past", stage), f"{acause}.nc") + past_data = open_xr_scenario(past_file) + cleaned_past_data, _ = processing.clean_cause_data( + past_data, + stage, + acause, + None, # No draw-resampling should occur. 
+ gbd_round_id, + year_ids=years.past_years, + national_only=national_only, + ) + + location_dict = _get_location_subsets( + acause=acause, all_locations_ids=cleaned_past_data.location_id.values + ) + + future_version_metadata = versions.get("future", stage) + + location_arrays = {} + for loc_group, location_ids in location_dict.items(): + prepped_input_data = processor.pre_process( + cleaned_past_data.sel(location_id=location_ids) + ) + + if covariates: + cov_data_list = _get_covariate_data( + prepped_input_data, + covariates, + versions, + years, + draws, + gbd_round_id, + national_only, + ) + else: + cov_data_list = None + + stripped_input_data = processing.strip_single_coord_dims(prepped_input_data) + + single_scenario_mode = True if future_version_metadata.scenario is not None else False + + model_instance = Model( + stripped_input_data, + years=years, + draws=draws, + covariate_data=cov_data_list, + fixed_effects=fixed_effects, + fixed_intercept=fixed_intercept, + random_effects=random_effects, + indicators=indicators, + gbd_round_id=gbd_round_id, + spline=spline, + predict_past_only=predict_past_only, + node_models=node_models, + study_id_cols=study_id_cols, + scenario_quantiles=scenario_quantiles, + omega_selection_strategy=omega_selection_strategy, + single_scenario_mode=single_scenario_mode, + seed=seed, + ) + + model_instance.fit() + + forecast_path = versions.data_dir(gbd_round_id, "future", stage) + model_instance.save_coefficients(forecast_path, acause + loc_group) + + forecast_data = model_instance.predict() + location_arrays[loc_group] = forecast_data + + forecast_data = xr.concat(location_arrays.values(), "location_id") + + # Expand forecast data to include point coords and single coord dims that + # were stripped off before forecasting. + expanded_output_data = processing.expand_single_coord_dims( + forecast_data, prepped_input_data + ) + + cleaned_past_data_resampled = resample(cleaned_past_data, draws) + + prepped_output_data = processor.post_process( + expanded_output_data, cleaned_past_data_resampled + ) + + # Special exception for malaria PI model, all will be held constant as we don't actually + # expect malaria prevalence/incidence ratio (a.k.a the duration of illness) to continue the + # decrease that we see in the past. Without changing, we have unrealistic rapidly dropping + # prevalence even if incidence increases. + if (acause == "malaria") and (stage == "pi_ratio"): + prepped_output_data = _malaria_pi_exception(years, cleaned_past_data_resampled) + + # For malaria incidence, we hold all of Venezuela constant and clip all locations' age + # group 95+ to past maximum value. + # We do this for Venezuela because it has very unusual past data and any ARC forecasts + # create extreme and unrealistic growth in incidence so we hold it constant. + # Age group 95+ also has extreme growth in the ARC model for many locations so we set an + # upper limit to prevent exponential growth in our results. + elif (acause == "malaria") and (stage == "incidence"): + prepped_output_data = _malaria_incidence_exception( + years, cleaned_past_data_resampled, prepped_output_data + ) + + # ArcMethod generates a fixed three scenarios, and we correct for that here, when + # expand_scenarios is set, by throwing away non-reference scenarios and expanding to the + # given list. 
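+    # Illustrative sketch (hypothetical values, not from the original pipeline):
+    # if the model output carries scenarios [-1, 0, 1] but expand_scenarios=[0],
+    # the block below keeps only the reference slice and broadcasts it, i.e.
+    #   reference_only = prepped_output_data.sel(scenario=0, drop=True)
+    #   prepped_output_data = reference_only.expand_dims(scenario=[0])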
+ if expand_scenarios and ( + "scenario" not in prepped_output_data.dims + or expand_scenarios != list(prepped_output_data["scenario"].values) + ): + if "scenario" in prepped_output_data.dims: + prepped_output_data = prepped_output_data.sel(scenario=0, drop=True) + prepped_output_data = prepped_output_data.expand_dims(scenario=expand_scenarios) + + forecast_file = FHSFileSpec(future_version_metadata, f"{acause}.nc") + save_xr_scenario(prepped_output_data, forecast_file, metric="rate", space="identity") + + +def _malaria_incidence_exception( + years: YearRange, + cleaned_past_data_resampled: xr.DataArray, + prepped_output_data: xr.DataArray, +) -> xr.DataArray: + """Special malaria in Venezuela exception. + + Takes the last past year and holds it constant for all future years. + + Additionally clips age group 95+ to past max. Due to extreme draw + issues we consistently forecast unrealistic exponential growth for this + age group, so set a limit here. + """ + exception_loc_id = 133 # Venezuela + past_exception_data = cleaned_past_data_resampled.sel( + location_id=exception_loc_id, year_id=years.past_end + ).drop("year_id") + # create "future" data that's last past year held constant + fill_data_future = expand_dimensions( + past_exception_data, year_id=years.forecast_years, fill_value=past_exception_data + ) + + if "scenario" in prepped_output_data.dims: + scenarios = prepped_output_data.coords.get("scenario") + fill_data_future = fill_data_future.expand_dims(dim={"scenario": scenarios}) + + prepped_output_data = prepped_output_data.drop_sel(location_id=exception_loc_id) + # combine exception location back with the rest of the modeled data + prepped_output_data = xr.concat([prepped_output_data, fill_data_future], "location_id") + + # Need to clip age group 95+ to highest previous maximum value. For each location/age/sex + # we find the highest historical draw and limit all future draws to that value. This can + # still allow for high values but clips the exponential growth we see otherwise. + exception_age_group = 235 + + age_group_max = cleaned_past_data_resampled.sel(age_group_id=exception_age_group).max( + ["draw", "year_id"] + ) + + age_group_clipped = ( + prepped_output_data.sel(age_group_id=exception_age_group) + .sortby(age_group_max.location_id) + .clip(max=age_group_max) + ) + + prepped_output_data = xr.concat( + [prepped_output_data.drop_sel(age_group_id=exception_age_group), age_group_clipped], + "age_group_id", + ) + + return prepped_output_data + + +def _malaria_pi_exception( + years: YearRange, cleaned_past_data_resampled: xr.DataArray +) -> xr.DataArray: + """Special malaria PI model exception. + + Takes the last past year and holds it constant for all future years. + + """ + last_past_year = cleaned_past_data_resampled.sel(year_id=years.past_end).drop("year_id") + + # create "future" data that's last past year held constant + prepped_output_data = expand_dimensions( + last_past_year, year_id=years.forecast_years, fill_value=last_past_year + ) + prepped_output_data = xr.concat( + [cleaned_past_data_resampled, prepped_output_data], "year_id" + ) + + return prepped_output_data + + +def _get_location_subsets( + acause: str, all_locations_ids: Iterable[int] +) -> Dict[str, List[int]]: + """Split locations into groups depending on the cause. + + Ideally this will get offloaded to the strategy database. 
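+
+    Example (illustrative): for ``acause="mental_drug_opioids"`` this returns
+    ``{"_exception": [101, 102], "_other_locs": [...all remaining locations...]}``,
+    so the US/Canada group and the other locations are modeled separately; for
+    any other cause it returns ``{"": list(all_locations_ids)}``.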
+ """ + if acause == "mental_drug_opioids": + # special case for opioids, we're going to model US and Canada separately then model + # all other locations because they have very different patterns + location_dict = { + "_exception": OPIOID_EXCEPTION_LOCS, + "_other_locs": list(set(all_locations_ids) - set(OPIOID_EXCEPTION_LOCS)), + } + + else: + location_dict = {"": list(all_locations_ids)} + + return location_dict + + +def _get_covariate_data( + dep_var_da: xr.DataArray, + covariates: Dict[str, Any], + versions: Versions, + years: YearRange, + draws: int, + gbd_round_id: int, + national_only: bool, +) -> List[xr.DataArray]: + """Returns a list of prepped dataarray for all of the covariates.""" + cov_data_list = [] + for cov_stage, cov_processor in covariates.items(): + # Special circumstance if using BMI as a covariate--BMI is saved separately in adult + # and child files, the past also has an age standardized group which is inaccurate + # when we combine adult and child together, this section serves to open the correct + # files, remove the age standardized group, and combine the adult/child files as one + # BMI file. + cov_past_data, cov_forecast_data = _load_cov_data(versions, cov_stage) + + cov_data = processing.clean_covariate_data( + cov_past_data, + cov_forecast_data, + dep_var_da, + years, + draws, + gbd_round_id, + national_only=national_only, + ) + + prepped_cov_data = cov_processor.pre_process(cov_data) + + try: + assert_shared_coords_same( + prepped_cov_data, dep_var_da.sel(year_id=years.past_end, drop=True) + ) + except IndexError as ce: + raise IndexError(f"After pre-processing {cov_stage}, " + str(ce)) + + cov_data_list.append(prepped_cov_data) + + validate.assert_covariates_scenarios(cov_data_list) + return cov_data_list + + +def _load_cov_data( + versions: Versions, + cov_stage: str, +) -> Tuple[xr.DataArray, xr.DataArray]: + if cov_stage == "bmi": + cov_past_file_adult = FHSFileSpec(versions.get("past", "sev"), "metab_bmi_adult.nc") + cov_past_data_adult = open_xr_scenario(cov_past_file_adult).drop_sel(age_group_id=27) + + cov_past_file_child = FHSFileSpec(versions.get("past", "sev"), "metab_bmi_child.nc") + cov_past_data_child = open_xr_scenario(cov_past_file_child).drop_sel(age_group_id=27) + + cov_past_data = xr.concat([cov_past_data_child, cov_past_data_adult], "age_group_id") + cov_past_data = processing.get_dataarray_from_dataset(cov_past_data).rename(cov_stage) + + cov_forecast_file_adult = FHSFileSpec( + versions.get("future", "sev"), "metab_bmi_adult.nc" + ) + cov_forecast_data_adult = open_xr_scenario(cov_forecast_file_adult) + + cov_forecast_file_child = FHSFileSpec( + versions.get("future", "sev"), "metab_bmi_child.nc" + ) + cov_forecast_data_child = open_xr_scenario(cov_forecast_file_child) + + cov_forecast_data = xr.concat( + [cov_forecast_data_child, cov_forecast_data_adult], "age_group_id" + ) + cov_forecast_data = processing.get_dataarray_from_dataset(cov_forecast_data).rename( + cov_stage + ) + else: + cov_past_file = FHSFileSpec(versions.get("past", cov_stage), f"{cov_stage}.nc") + cov_past_data = open_xr_scenario(cov_past_file) + cov_past_data = processing.get_dataarray_from_dataset(cov_past_data).rename(cov_stage) + + cov_forecast_file = FHSFileSpec(versions.get("future", cov_stage), f"{cov_stage}.nc") + cov_forecast_data = open_xr_scenario(cov_forecast_file) + cov_forecast_data = processing.get_dataarray_from_dataset(cov_forecast_data).rename( + cov_stage + ) + + return cov_past_data, cov_forecast_data + + +def _get_model_parameters( + acause: str, 
stage: str, years: YearRange, gbd_round_id: int +) -> model_parameters.ModelParameters: + """Gets modeling parameters associated with the given cause-stage. + + If there aren't model parameters associated with the cause-stage then the + script will exit with return code 0. + + Args: + acause (str): the cause to get params for + stage (str): the stage to get params for + years (YearRange): the years to get data for + gbd_round_id (int): the gbd round of data + + Returns: + model_parameters.ModelParameters + """ + model_parameters = model_strategy_queries.get_cause_model( + acause, stage, years, gbd_round_id + ) + if not model_parameters: + logger.info(f"{acause}-{stage} is not forecasted in this pipeline. DONE") + exit(0) + raise ValueError() + + return model_parameters diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/lib/yld_from_prevalence.py b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/yld_from_prevalence.py new file mode 100644 index 0000000..359823a --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/lib/yld_from_prevalence.py @@ -0,0 +1,84 @@ +r"""Computes and saves YLDs using prevalence forecasts and average disability weight. + +.. code:: python + + yld = prevalence * disability_weight + +Parallelized by cause. +""" +from typing import Optional + +from fhs_lib_data_transformation.lib import processing +from fhs_lib_file_interface.lib.query.io_helper import read_single_cause +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions, validate_versions_scenarios +from fhs_lib_file_interface.lib.xarray_wrapper import save_xr_scenario +from fhs_lib_year_range_manager.lib.year_range import YearRange + + +def one_cause_main( + acause: str, + years: YearRange, + versions: Versions, + gbd_round_id: int, + draws: int, + output_scenario: Optional[int], + national_only: bool, +) -> None: + """Calculate yld from prevalence and disability weight. + + Args: + acause (str): The cause for yld is being calculated. + years (YearRange): Forecasting time series. + versions (Versions): A Versions object that keeps track of all the versions and their + respective data directories. + gbd_round_id (int): What gbd_round_id that yld, prevalence and disability weight are + saved under + draws (int): How many draws to save for the yld output + output_scenario (Optional[int]): Optional output scenario ID + national_only (bool): Whether to include subnational locations, or to include only + nations. 
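+
+        The saved output is ``cleaned_prevalence * cleaned_disability_weight``,
+        written to the "yld" stage for the given cause.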
+ """ + # validate versions + validate_versions_scenarios( + versions=versions, + output_scenario=output_scenario, + output_epoch_stages=[("future", "yld")], + ) + + prevalence = read_single_cause( + acause=acause, + stage="prevalence", + version_metadata=versions.get("future", "prevalence").default_data_source( + gbd_round_id + ), + ) + cleaned_prevalence, _ = processing.clean_cause_data( + prevalence, + "prevalence", + acause, + draws, + gbd_round_id, + years.forecast_years, + national_only=national_only, + ) + + disability_weight = read_single_cause( + acause=acause, + stage="disability_weight", + version_metadata=versions.get("past", "disability_weight").default_data_source( + gbd_round_id + ), + ) + cleaned_disability_weight, _ = processing.clean_cause_data( + disability_weight, + "disability_weight", + acause, + draws, + gbd_round_id, + national_only=national_only, + ) + + yld = cleaned_prevalence * cleaned_disability_weight + yld_file = FHSFileSpec(versions.get("future", "yld"), f"{acause}.nc") + save_xr_scenario(yld, yld_file, metric="rate", space="identity") diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/models/arc_method.py b/gbd_2021/disease_burden_forecast_code/nonfatal/models/arc_method.py new file mode 100644 index 0000000..02151db --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/models/arc_method.py @@ -0,0 +1,1082 @@ +"""Module with functions for making forecast scenarios.""" + +from typing import Any, Callable, Iterable, List, Optional, Type, Union + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_data_transformation.lib.statistic import ( + Quantiles, + weighted_mean_with_extra_dim, + weighted_quantile_with_extra_dim, +) +from fhs_lib_data_transformation.lib.truncate import truncate_dataarray +from fhs_lib_data_transformation.lib.validate import assert_coords_same +from fhs_lib_database_interface.lib.constants import DimensionConstants, ScenarioConstants +from fhs_lib_database_interface.lib.query import location +from fhs_lib_file_interface.lib import xarray_wrapper +from fhs_lib_file_interface.lib.version_metadata import FHSDirSpec +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_lib_model.lib import predictive_validity_metrics as pv_metrics +from fhs_lib_model.lib.constants import ArcMethodConstants +from fhs_lib_model.lib.model_protocol import ModelProtocol + +logger = get_logger() + + +class StatisticSpec: + """A type representing a choice of statistical summary, with its attendant data. + + Can compute a weighted or an unweighted form. See MeanStatistic and QuantileStatistic. + """ + + def weighted_statistic( + self, data: xr.DataArray, stat_dims: List[str], weights: xr.DataArray, extra_dim: str + ) -> xr.DataArray: + """Take a weighted summary statistic on annual_diff.""" + pass + + def unweighted_statistic(self, data: xr.DataArray, stat_dims: List[str]) -> xr.DataArray: + """Take a unweighted statistic on annual_diff.""" + pass + + +class MeanStatistic(StatisticSpec): + """A stat "take the mean of some things." 
Takes no args.""" + + def weighted_statistic( + self, data: xr.DataArray, stat_dims: List[str], weights: xr.DataArray, extra_dim: str + ) -> xr.DataArray: + """Take a weighted mean on `data`.""" + return weighted_mean_with_extra_dim(data, stat_dims, weights, extra_dim) + + def unweighted_statistic(self, data: xr.DataArray, stat_dims: List[str]) -> xr.DataArray: + """Take an unweighted mean on `dat`a`, over the dimenstions stat_dims.""" + return data.mean(stat_dims) + + +class QuantileStatistic(StatisticSpec): + """The intention of "take some quantiles from the data".""" + + def __init__(self, quantiles: Union[float, Iterable[float]]) -> None: + """Args are the quantile fractions. + + E.g. QuantileStatistic([0.1, 0.9]) represents + the desire to take the 10th percentile and 90th percentile. You may also + pass a single number, as in QuantileStatistic(0.5) for a single quantile, + the median in that case. + """ + if not (isinstance(quantiles, float) or is_iterable_of(float, quantiles)): + raise ValueError("Arg to QuantileStatistic must either float or list of floats") + + self.quantiles = quantiles + + def weighted_statistic( + self, data: xr.DataArray, stat_dims: List[str], weights: xr.DataArray, extra_dim: str + ) -> xr.DataArray: + """Take a weighted set of quantiles on `data`.""" + return weighted_quantile_with_extra_dim( + data, self.quantiles, stat_dims, weights, extra_dim + ) + + def unweighted_statistic(self, data: xr.DataArray, stat_dims: List[str]) -> xr.DataArray: + """Take an unweighted set of quantiles on `data`.""" + return data.quantile(q=self.quantiles, dim=stat_dims) + + +class ArcMethod(ModelProtocol): + """Instances of this class represent an arc_method model. + + Can be fit and used for predicting future estimates. + """ + + # Defined ARC method parameters: + number_of_holdout_years = 10 + omega_step_size = 0.25 + max_omega = 3 + pv_metric = pv_metrics.root_mean_square_error + + def __init__( + self, + past_data: xr.DataArray, + years: YearRange, + draws: int, + gbd_round_id: int, + reference_scenario_statistic: str = "mean", + reverse_scenarios: bool = False, + quantiles: Iterable[float] = ArcMethodConstants.DEFAULT_SCENARIO_QUANTILES, + mean_level_arc: bool = True, + reference_arc_dims: Optional[List[str]] = None, + scenario_arc_dims: Optional[List[str]] = None, + truncate: bool = True, + truncate_dims: Optional[List[str]] = None, + truncate_quantiles: Iterable[float] = ArcMethodConstants.DEFAULT_TRUNCATE_QUANTILES, + replace_with_mean: bool = False, + scenario_roc: str = "all", + pv_results: xr.DataArray = None, + select_omega: bool = True, + omega_selection_strategy: Optional[Callable] = None, + omega: Optional[Union[float, xr.DataArray]] = None, + pv_pre_process_func: Optional[Callable] = None, + single_scenario_mode: bool = False, + **kwargs: Any, + ) -> None: + """Creates a new ``ArcMethod`` model instance. + + Pre-conditions: + =============== + * All given ``xr.DataArray``s must have dimensions with at least 2 + coordinates. This applies for covariates and the dependent variable. + + Args: + past_data (xr.DataArray): Past data for dependent variable being forecasted + years (YearRange): forecasting timeseries + draws (int): Number of draws to generate + gbd_round_id (int): The ID of the GBD round + reference_scenario_statistic (str): The statistic used to make the reference + scenario. 
If "median" then the reference scenarios is made using the weighted + median of past annualized rate-of-change across all past years, "mean" then it + is made using the weighted mean of past annualized rate-of-change across all + past years. Defaults to "mean". + reverse_scenarios (bool): If ``True``, reverse the usual assumption that high=bad + and low=good. For example, we set to ``True`` for vaccine coverage, because + higher coverage is better. Defaults to ``False``. + quantiles (Iterable[float]): The quantiles to use for better and worse + scenarios. Defaults to ``0.15`` and ``0.85``. + mean_level_arc (bool): If ``True``, then take annual differences for + means-of-draws, instead of draws. Defaults to ``True``. + reference_arc_dims (Optional[List[str]]): To calculate the reference ARC, take + weighted mean or median over these dimensions. Defaults to ["year_id"] when + ``None``. + scenario_arc_dims (Optional[List[str]]): To calculate the scenario ARCs, take + weighted quantiles over these dimensions. Defaults to ["location_id", + "year_id"] when ``None``. + truncate (bool): If ``True``, then truncate (clip) the past data over the given + dimensions. Defaults to ``False``. + truncate_dims (Optional[List[str]]): A list of strings representing the dimensions + to truncate over. If ``None``, truncation occurs over location and year. + truncate_quantiles (Iterable[float]): The two floats representing the quantiles to + take. Defaults to ``0.025`` and ``0.975``. + replace_with_mean (bool): If ``True`` and `truncate` is ``True``, then replace + values outside of the upper and lower quantiles taken across location and year + with the mean across "year_id", if False, then replace with the upper and lower + bounds themselves. Defaults to ``False``. + scenario_roc (str): If "all", then the scenario rate of change is taken over all + locations. If "national_only", roc is taken over national + locations only. Defaults to "all". + pv_results (xr.DataArray): An array of RMSEs resulting from predictive validity + tests. The array has one dimension (weight), and the values are the RMSEs from + each tested weight. When ``pv_results`` is ``None``, the ``fit`` method will + calculate new ``pv_results``. + select_omega (bool): If ``True``, the ``fit`` method will select an omega or create + an omega distribution from ``self.pv_results`` + omega_selection_strategy (Optional[Callable]): Which strategy to use to produce the + omega(s) from the omega-RMSE array, which gets produced in the fit step. + Defaults to ``None``, but must be specified unless you are passing the model an + omega directly. Can be specified as follows: + ``model.oss.name_of_omega_selection_function``. See omega_selection_strategy.py + for all omega selection functions. + omega (Optional[Union[float, xr.DataArray]]): Power to raise the increasing year + weights Must be non-negative. It can be dataarray, but must have only one + dimension, ``draw``. It must have the same coordinates on that dimension as + ``past_data_da``. When omega is ``None``, the fit method will calculate it from + ``self.pv_results`` if select_omega is ``True``. + pv_pre_process_func (Optional[Callable]): Function to call if preprocessing pv + results. + single_scenario_mode (bool): if true, only produces one scenario, not better and + worse. + kwargs (Any): Unused additional keyword arguments + """ + if select_omega and omega_selection_strategy is None: + err_msg = ( + "Must provide an omega_selection_strategy function if select_omega is True." 
+ ) + logger.error(err_msg) + raise ValueError(err_msg) + + self.past_data = past_data + self.years = years + self.draws = draws + self.gbd_round_id = gbd_round_id + self.pv_results = pv_results + self.select_omega = select_omega + self.omega = omega + self.pv_pre_process_func = pv_pre_process_func + self.omega_selection_strategy = omega_selection_strategy + self.reference_scenario_statistic = reference_scenario_statistic + self.reverse_scenarios = reverse_scenarios + self.quantiles = quantiles + self.mean_level_arc = mean_level_arc + self.reference_arc_dims = reference_arc_dims + self.scenario_arc_dims = scenario_arc_dims + self.truncate = truncate + self.truncate_dims = truncate_dims + self.truncate_quantiles = truncate_quantiles + self.replace_with_mean = replace_with_mean + self.scenario_roc = scenario_roc + self.single_scenario_mode = single_scenario_mode + + def fit(self) -> Union[float, xr.DataArray]: + """Runs a predictive validity process to determine omega to use for forecasting. + + If ``self.select_omega`` is ``False``, this will only calculate ``self.pv_results`` + PV results are only calculated when ``self.pv_results`` is ``None``. + + Returns: + float | xr.DataArray: Power to raise the increasing year weights -- must be + nonnegative. It can be dataarray, but must have only one dimension, + DimensionConstants.DRAW. It must have the same coordinates on that dimension + as ``past_data_da``. + """ + holdout_start = self.years.past_end - self.number_of_holdout_years + pv_years = YearRange(self.years.past_start, holdout_start, self.years.past_end) + + holdouts = self.past_data.sel(year_id=pv_years.forecast_years) + omegas_to_test = np.arange( + 0, ArcMethod.max_omega + ArcMethod.omega_step_size, ArcMethod.omega_step_size + ) + + if self.pv_pre_process_func is not None: + holdouts = self.pv_pre_process_func(holdouts) + + if self.pv_results is None: + pv_result_list = [] + for test_omega in omegas_to_test: + predicted = self._arc_method(pv_years, test_omega) + if DimensionConstants.SCENARIO in predicted.coords: + predicted = predicted.sel(scenario=0, drop=True) + + assert_coords_same(predicted, self.past_data) + + predicted_holdouts = predicted.sel(year_id=pv_years.forecast_years) + + if self.pv_pre_process_func is not None: + predicted_holdouts = self.pv_pre_process_func(predicted_holdouts) + + pv_result = ArcMethod.pv_metric(predicted_holdouts, holdouts) + pv_result_da = xr.DataArray( + [pv_result], coords={"weight": [test_omega]}, dims=["weight"] + ) + pv_result_list.append(pv_result_da) + + self.pv_results = xr.concat(pv_result_list, dim="weight") + + if self.select_omega: + self.omega = self.omega_selection_strategy(rmse=self.pv_results, draws=self.draws) + + return self.omega + + def predict(self) -> xr.DataArray: + """Create projections for reference, better, and worse scenarios using the ARC method. + + Returns: + xr.DataArray: Projections for future years made with the ARC method. It will + include all the dimensions and coordinates of the + ``self.past_data``, except that the ``year_id`` dimension will + ONLY have coordinates for all of the years from + ``self.years.forecast_years``. There will also be a new + ``scenario`` dimension with the coordinates 0 for reference, + -1 for worse, and 1 for better. 
+ """ + self.predictions = self._arc_method( + self.years, self.omega, past_resample_draws=self.draws + ).sel(year_id=self.years.forecast_years) + + return self.predictions + + def save_coefficients( + self, output_dir: FHSDirSpec, entity: str, save_omega_draws: bool = False + ) -> None: + """Saves omega. + + I.e. the power to raise the increasing year weights to, and/or PV results, + an array of RMSEs resulting from predictive validity tests. + + Args: + output_dir (Path): directory to save data to + entity (str): name to give output file + save_omega_draws (bool): whether to save omega draws + + Raises: + ValueError: if no omega or PV results present to save + """ + + def is_xarray(da: Any) -> bool: + return isinstance(da, xr.Dataset) or isinstance(da, xr.DataArray) + + if self.omega is None and self.pv_results is None: + err_msg = "No omega or predictive validity results to save" + logger.error(err_msg) + raise ValueError(err_msg) + + if self.omega is not None: + if is_xarray(self.omega) and not save_omega_draws: + logger.debug( + "Computing stats of omega draws", + bindings=dict(model=self.__class__.__name__), + ) + coef_stats = self._compute_stats(self.omega) + elif is_xarray(self.omega) and save_omega_draws: + coef_stats = self.omega + elif isinstance(self.omega, float) or isinstance(self.omega, int): + logger.debug( + "omega is singleton value", + bindings=dict(model=self.__class__.__name__), + ) + coef_stats = xr.DataArray( + [self.omega], + dims=["omega"], + coords={"omega": ["value"]}, + ) + + omega_output_file = output_dir.append_sub_path(("coefficients",)).file( + f"{entity}_omega.nc" + ) + + xarray_wrapper.save_xr_scenario( + coef_stats, + omega_output_file, + metric="rate", + space="identity", + ) + + if self.pv_results is not None: + pv_output_file = output_dir.append_sub_path(("coefficients",)).file( + f"{entity}_omega_rmses.nc" + ) + + xarray_wrapper.save_xr_scenario( + self.pv_results, + pv_output_file, + metric="rate", + space="identity", + ) + + @staticmethod + def _compute_stats(da: xr.DataArray) -> Union[xr.DataArray, xr.Dataset]: + """Compute mean and variance of draws if a ``'draw'`` dim exists. + + Otherwise just return a copy of the original. + + Args: + da (xr.DataArray): data array for computation + + Returns: + Union[xr.DataArray, xr.Dataset]: the computed data + """ + if DimensionConstants.DRAW in da.dims: + mean_da = da.mean(DimensionConstants.DRAW).assign_coords(stat="mean") + var_da = da.var(DimensionConstants.DRAW).assign_coords(stat="var") + stats_da = xr.concat([mean_da, var_da], dim="stat") + else: + logger.warning( + "Draw is NOT a dim, can't compute omega stats", + bindings=dict(model=__class__.__name__, dims=da.dims), + ) + stats_da = da.copy() + return stats_da + + def _arc_method( + self, + years: YearRange, + omega: Union[float, xr.DataArray], + past_resample_draws: Optional[int] = None, + ) -> xr.DataArray: + """Run and return the `arc_method`. + + To keep the PV step and prediction step consistent put the explicit ``arc_method`` + call with all of its defined parameters here. + + Args: + years (YearRange): years to include in the past when calculating ARC + omega (Union[float, xr.DataArray]): the omega to assess for draws + past_resample_draws (Optional[int]): The number of draws to resample from the past + data. This argument is used in the predict step to avoid NaNs in the forecast + when there is a mismatch between the number of draw coordinates in the past + data and the desired number of draw coordinates. 
+ + Returns: + xr.DataArray: result of the `arc_method` function call + """ + omega_dim = ArcMethod._get_omega_dim(omega, self.draws) + + if past_resample_draws is not None and "draw" in self.past_data.dims: + past_data = resample(self.past_data, past_resample_draws) + else: + past_data = self.past_data + + return arc_method( + past_data_da=past_data, + gbd_round_id=self.gbd_round_id, + years=years, + weight_exp=omega, + reference_scenario=self.reference_scenario_statistic, + reverse_scenarios=self.reverse_scenarios, + quantiles=self.quantiles, + diff_over_mean=self.mean_level_arc, + reference_arc_dims=self.reference_arc_dims, + scenario_arc_dims=self.scenario_arc_dims, + truncate=self.truncate, + truncate_dims=self.truncate_dims, + truncate_quantiles=self.truncate_quantiles, + replace_with_mean=self.replace_with_mean, + extra_dim=omega_dim, + scenario_roc=self.scenario_roc, + single_scenario_mode=self.single_scenario_mode, + ) + + @staticmethod + def _get_omega_dim(omega: Union[float, int, xr.DataArray], draws: int) -> Optional[str]: + """Get the omega dimension if passed a data array. + + Args: + omega (Union[float, int, xr.DataArray]): the omega value or data array + draws (int): the number of draws to validate omega against + + Returns: + Optional[str]: ``'draw'``, if ``omega`` contains draw specific omegas as a + dataarray or ``None``, if ``omega`` is float. + + Raises: + ValueError: if `omega` draw dim doesn't have the expected coords + TypeError: if `omega` isn't a float, int, or data array + """ + if isinstance(omega, float) or isinstance(omega, int): + omega_dim = None + elif isinstance(omega, xr.DataArray): + if set(omega.dims) != {DimensionConstants.DRAW}: + err_msg = "`omega` can only have 'draw' as a dim" + logger.error(err_msg) + raise ValueError(err_msg) + elif sorted(list(omega[DimensionConstants.DRAW].values)) != list(range(draws)): + err_msg = "`omega`'s draw dim doesn't have the expected coords" + logger.error(err_msg) + raise ValueError(err_msg) + omega_dim = DimensionConstants.DRAW + else: + err_msg = "`omega` must be either a float, an int, or an xarray.DataArray" + logger.error(err_msg) + raise TypeError(err_msg) + + return omega_dim + + +def arc_method( + past_data_da: xr.DataArray, + gbd_round_id: int, + years: Optional[Iterable[int]] = None, + weight_exp: Union[float, int, xr.DataArray] = 1, + reference_scenario: str = "median", + reverse_scenarios: bool = False, + quantiles: Iterable[float] = ArcMethodConstants.DEFAULT_SCENARIO_QUANTILES, + diff_over_mean: bool = False, + reference_arc_dims: Optional[List[str]] = None, + scenario_arc_dims: Optional[List[str]] = None, + truncate: bool = False, + truncate_dims: Optional[List[str]] = None, + truncate_quantiles: Optional[Iterable[float]] = None, + replace_with_mean: bool = False, + extra_dim: Optional[str] = None, + scenario_roc: str = "all", + single_scenario_mode: bool = False, +) -> xr.DataArray: + """Makes rate forecasts using the Annualized Rate-of-Change (ARC) method. + + Forecasts rates by taking a weighted quantile or weighted mean of + annualized rates-of-change from past data, then walking that weighted + quantile or weighted mean out into future years. + + A reference scenario is made using the weighted median or mean of past + annualized rate-of-change across all past years. + + Better and worse scenarios are made using weighted 15th and 85th quantiles + of past annualized rates-of-change across all locations and all past years. 
+ + The minimum and maximum are taken across the scenarios (values are + granular, e.g. age/sex/location/year specific) and the minimum is taken as + the better scenario and the maximum is taken as the worse scenario. If + scenarios are reversed (``reverse_scenario = True``) then do the opposite. + + Args: + past_data_da: + A dataarray of past data that must at least of the dimensions + ``year_id`` and ``location_id``. The ``year_id`` dimension must + have coordinates for all the years in ``years.past_years``. + gbd_round_id: + gbd_round_id the data comes from. + years: + years to include in the past when calculating ARC. + weight_exp: + power to raise the increasing year weights -- must be nonnegative. + It can be dataarray, but must have only one dimension, "draw", it + must have the same coordinates on that dimension as + ``past_data_da``. + reference_scenario: + If "median" then the reference scenarios is made using the + weighted median of past annualized rate-of-change across all past + years, "mean" then it is made using the weighted mean of past + annualized rate-of-change across all past years. Defaults to + "median". + reverse_scenarios: + If True, reverse the usual assumption that high=bad and low=good. + For example, we set to True for vaccine coverage, because higher + coverage is better. Defaults to False. + quantiles: + The quantiles to use for better and worse scenarios. Defaults to + ``0.15`` and ``0.85`` quantiles. + diff_over_mean: + If True, then take annual differences for means-of-draws, instead + of draws. Defaults to False. + reference_arc_dims: + To calculate the reference ARC, take weighted mean or median over + these dimensions. Defaults to ["year_id"] + scenario_arc_dims: + To calculate the scenario ARCs, take weighted quantiles over these + dimensions.Defaults to ["location_id", "year_id"] + truncate: + If True, then truncates the dataarray over the given dimensions. + Defaults to False. + truncate_dims: + A list of strings representing the dimensions to truncate over. + truncate_quantiles: + The tuple of two floats representing the quantiles to take. + replace_with_mean: + If True and `truncate` is True, then replace values outside of the + upper and lower quantiles taken across "location_id" and "year_id" + and with the mean across "year_id", if False, then replace with the + upper and lower bounds themselves. + extra_dim: + Extra dimension that exists in `weights` and `data`. It should not + be in `stat_dims`. + scenario_roc: + If "all", then the scenario rate of change is taken over all + locations. If "national_only", roc is taken over national + locations only. Defaults to "all". + single_scenario_mode: + If true, better and worse scenarios are not calculated, and the reference scenario + is returned without a scenario dimension. + + Returns: + Past and future data with reference, better, and worse scenarios. + It will include all the dimensions and coordinates of the input + dataarray and a ``scenario`` dimension with the coordinates 0 for + reference, -1 for worse, and 1 for better. The ``year_id`` + dimension will have coordinates for all of the years from + ``years.years``. + + Raises: + ValueError: If ``weight_exp`` is a negative number or if ``reference_scenario`` + is not "median" or "mean". 
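+
+    Example (a minimal sketch with assumed, illustrative inputs; ``past_rates``
+    is a hypothetical log-space dataarray with at least ``year_id`` and
+    ``location_id`` dimensions)::
+
+        forecast = arc_method(
+            past_data_da=past_rates,
+            gbd_round_id=7,
+            years=(1990, 2023, 2050),
+            weight_exp=1.0,
+            reference_scenario="mean",
+            diff_over_mean=True,
+        )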
+ """ + logger.debug( + "Inputs for `arc_method` call", + bindings=dict( + years=years, + weight_exp=weight_exp, + reference_scenario=reference_scenario, + reverse_scenarios=reverse_scenarios, + quantiles=quantiles, + diff_over_mean=diff_over_mean, + truncate=truncate, + replace_with_mean=replace_with_mean, + truncate_quantiles=truncate_quantiles, + extra_dim=extra_dim, + ), + ) + + years = YearRange(*years) if years else YearRange(*ArcMethodConstants.DEFAULT_YEAR_RANGE) + + past_data_da = past_data_da.sel(year_id=years.past_years) + + # Create baseline forecasts. Take weighted median or mean only across + # years, so values will be as granular as the inputs (e.g. age/sex/location + # specific) + if reference_scenario == "median": + reference_statistic = QuantileStatistic(0.5) + elif reference_scenario == "mean": + reference_statistic = MeanStatistic() + else: + raise ValueError("reference_scenario must be either 'median' or 'mean'") + + if truncate and not truncate_dims: + truncate_dims = [DimensionConstants.LOCATION_ID, DimensionConstants.YEAR_ID] + + truncate_quantiles = ( + Quantiles(*sorted(truncate_quantiles)) + if truncate_quantiles + else Quantiles(0.025, 0.975) + ) + + reference_arc_dims = reference_arc_dims or [DimensionConstants.YEAR_ID] + reference_change = arc( + past_data_da, + years, + weight_exp, + reference_arc_dims, + reference_statistic, + diff_over_mean=diff_over_mean, + truncate=truncate, + truncate_dims=truncate_dims, + truncate_quantiles=truncate_quantiles, + replace_with_mean=replace_with_mean, + extra_dim=extra_dim, + ) + reference_da = past_data_da.sel(year_id=years.past_end) + reference_change + forecast_data_da = past_data_da.combine_first(reference_da) + + if not single_scenario_mode: + forecast_data_da = _forecast_better_worse_scenarios( + past_data_da=past_data_da, + gbd_round_id=gbd_round_id, + years=years, + weight_exp=weight_exp, + reverse_scenarios=reverse_scenarios, + quantiles=quantiles, + diff_over_mean=diff_over_mean, + scenario_arc_dims=scenario_arc_dims, + replace_with_mean=replace_with_mean, + extra_dim=extra_dim, + scenario_roc=scenario_roc, + forecast_data_da=forecast_data_da, + ) + + return forecast_data_da + + +def _forecast_better_worse_scenarios( + forecast_data_da: xr.DataArray, + past_data_da: xr.DataArray, + gbd_round_id: int, + years: YearRange, + weight_exp: Union[float, int, xr.DataArray], + reverse_scenarios: bool, + quantiles: Iterable[float], + diff_over_mean: bool, + scenario_arc_dims: Optional[List[str]], + replace_with_mean: bool, + extra_dim: Optional[str], + scenario_roc: str, +) -> xr.DataArray: + try: + forecast_data_da = forecast_data_da.rename( + {DimensionConstants.QUANTILE: DimensionConstants.SCENARIO} + ) + except ValueError: + pass # There is no "quantile" point coordinate. + + forecast_data_da[DimensionConstants.SCENARIO] = ScenarioConstants.REFERENCE_SCENARIO_COORD + + # Create better and worse scenario forecasts. Take weighted 85th and 15th + # quantiles across year and location, so values will not be location + # specific (e.g. just age/sex specific). 
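+    # For example (illustrative, with the default quantiles of 0.15 and 0.85):
+    # the QuantileStatistic used below yields a "quantile" dimension with two
+    # coordinates, which is renamed to "scenario" and relabeled as the better
+    # and worse scenarios further down.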
+ scenario_arc_dims = scenario_arc_dims or [ + DimensionConstants.LOCATION_ID, + DimensionConstants.YEAR_ID, + ] + if scenario_roc == "national": + nation_ids = location.get_location_set( + gbd_round_id=gbd_round_id, include_aggregates=False, national_only=True + )[DimensionConstants.LOCATION_ID].unique() + + arc_input = past_data_da.sel(location_id=nation_ids) + elif scenario_roc == "all": + arc_input = past_data_da + else: + raise ValueError( + f'scenario_roc should be one of "national" or "all"; got {scenario_roc}' + ) + scenario_change = arc( + arc_input, + years, + weight_exp, + scenario_arc_dims, + QuantileStatistic(quantiles), + diff_over_mean=diff_over_mean, + truncate=False, + replace_with_mean=replace_with_mean, + extra_dim=extra_dim, + ) + + scenario_change = scenario_change.rename( + {DimensionConstants.QUANTILE: DimensionConstants.SCENARIO} + ) + scenarios_da = past_data_da.sel(year_id=years.past_end) + scenario_change + + scenarios_da.coords[DimensionConstants.SCENARIO] = [ + ScenarioConstants.BETTER_SCENARIO_COORD, + ScenarioConstants.WORSE_SCENARIO_COORD, + ] + + forecast_data_da = xr.concat( + [forecast_data_da, scenarios_da], dim=DimensionConstants.SCENARIO + ) + + # Get the minimums and maximums across the scenario dimension, and set + # worse scenarios to the worst (max if normal or min if reversed), and set + # better scenarios to the best (min if normal or max if reversed). + low_values = forecast_data_da.min(DimensionConstants.SCENARIO) + high_values = forecast_data_da.max(DimensionConstants.SCENARIO) + if reverse_scenarios: + forecast_data_da.loc[ + {DimensionConstants.SCENARIO: ScenarioConstants.WORSE_SCENARIO_COORD} + ] = low_values + forecast_data_da.loc[ + {DimensionConstants.SCENARIO: ScenarioConstants.BETTER_SCENARIO_COORD} + ] = high_values + else: + forecast_data_da.loc[ + {DimensionConstants.SCENARIO: ScenarioConstants.BETTER_SCENARIO_COORD} + ] = low_values + forecast_data_da.loc[ + {DimensionConstants.SCENARIO: ScenarioConstants.WORSE_SCENARIO_COORD} + ] = high_values + + forecast_data_da = past_data_da.combine_first(forecast_data_da) + + forecast_data_da = forecast_data_da.loc[ + {DimensionConstants.SCENARIO: sorted(forecast_data_da[DimensionConstants.SCENARIO])} + ] + + return forecast_data_da + + +def arc( + past_data_da: xr.DataArray, + years: YearRange, + weight_exp: Union[float, int, xr.DataArray], + stat_dims: Iterable[str], + statistic: StatisticSpec, + diff_over_mean: bool = False, + truncate: bool = False, + truncate_dims: Optional[List[str]] = None, + truncate_quantiles: Optional[Iterable[float]] = None, + replace_with_mean: bool = False, + extra_dim: Optional[str] = None, +) -> xr.DataArray: + r"""Makes rate forecasts by forecasting the Annualized Rates-of-Change (ARC). + + Uses either weighted means or weighted quantiles. + + The steps for forecasting logged or logitted rates with ARCs are: + + (1) Annualized rate differentials (or annualized rates-of-change if data is + in log or logit space) are calculated. + + .. Math:: + + \vec{D_{p}} = + [x_{1991} - x_{1990}, x_{1992} - x_{1991}, ... x_{2016} - x_{2015}] + + where :math:`x` are values from ``past_data_da`` for each year and + :math:`\vec{D_p}` is the vector of differentials in the past. + + (2) Year weights are used to weight recent years more heavily. Year weights + are made by taking the interval + + .. 
math:: + + \vec{W} = [1, ..., n]^w + + where :math:`n` is the number of past years, :math:`\vec{w}` is the + value given by ``weight_exp``, and :math:`\vec{W}` is the vector of + year weights. + + (3) Weighted quantiles or the weighted mean of the annualized + rates-of-change are taken over the dimensions. + + .. math:: + + s = \text{weighted-statistic}(\vec{W}, \vec{D}) + + where :math:`s` is the weighted quantile or weighted mean. + + (4) Future rates-of-change are simulated by taking the interval + + .. math:: + + \vec{D_{f}} = [1, ..., m] * s + + where :math:`\vec{D_f}` is the vector of differentials in the future + and :math:`m` is the number of future years to forecast and + + (5) Lastly, these future differentials are added to the rate of the last + observed year. + + .. math:: + + \vec{X_{f}} = \vec{D_{f}} + x_{2016} = [x_{2017}, ..., x_{2040}] + + where :math:`X_{f}` is the vector of forecasted rates. + + Args: + past_data_da: + Past data with a year-id dimension. Must be in log or logit space + in order for this function to actually calculate ARCs, otherwise + it's just calculating weighted statistic of the first differences. + years: + past and future year-ids + weight_exp: + power to raise the increasing year weights -- must be nonnegative. + It can be dataarray, but must have only one dimension, "draw", it + must have the same coordinates on that dimension as + ``past_data_da``. + stat_dims: + list of dimensions to take quantiles over + statistic: A statistic to use to calculate the ARC from the annual + diff, either MeanStatistic() or QuantileStatistic(quantiles). + diff_over_mean: + If True, then take annual differences for means-of-draws, instead + of draws. Defaults to False. + truncate: + If True, then truncates the dataarray over the given dimensions. + Defaults to False. + truncate_dims: + A list of strings representing the dimensions to truncate over. + truncate_quantiles: + The iterable of two floats representing the quantiles to take. + replace_with_mean: + If True and `truncate` is True, then replace values outside of the + upper and lower quantiles taken across "location_id" and "year_id" + and with the mean across "year_id", if False, then replace with the + upper and lower bounds themselves. + extra_dim: + An extra dim to take the `statistic` over. Should exist in + `weights` and `data`. It should not be in `stat_dims`. + + Returns: + Forecasts made using the ARC method. + + Raises: + ValueError: Conditions: + + * If ``statistic`` is ill-formed. + * If ``weight_exp`` is a negative number. + * If `truncate` is True, then `truncate_quantiles` must be a list of floats. + """ + logger.debug( + "Inputs for `arc` call", + bindings=dict( + years=years, + weight_exp=weight_exp, + statistic=statistic, + stat_dims=stat_dims, + diff_over_mean=diff_over_mean, + truncate=truncate, + replace_with_mean=replace_with_mean, + truncate_quantiles=truncate_quantiles, + extra_dim=extra_dim, + ), + ) + + # Calculate the annual differentials. 
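+    # Worked illustration (assumed numbers): with past years [2015, ..., 2019]
+    # the diff below yields four first differences, and with weight_exp=2 the
+    # year weights computed further down are [1, 2, 3, 4] ** 2 = [1, 4, 9, 16],
+    # so the most recent rate-of-change dominates the weighted statistic.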
+ if diff_over_mean and DimensionConstants.DRAW in past_data_da.dims: + annual_diff = past_data_da.mean(DimensionConstants.DRAW) + else: + annual_diff = past_data_da + annual_diff = annual_diff.sel(year_id=years.past_years).diff( + DimensionConstants.YEAR_ID, n=1 + ) + + if isinstance(weight_exp, xr.DataArray): + if DimensionConstants.DRAW not in weight_exp.dims: # pytype: disable=attribute-error + raise ValueError( + "`weight_exp` must be a float, an int, or an xarray.DataArray " + "with a 'draw' dimension" + ) + + # If annual-differences were taken over means (`annual_diff` doesn't have a "draw" + # dimension), but `year_weights` does have a "draw" dimension, then the draw dimension + # needs to be expanded for `annual_diff` such that the mean is replicated for each draw + if DimensionConstants.DRAW not in annual_diff.dims: + annual_diff = expand_dimensions( + annual_diff, draw=weight_exp[DimensionConstants.DRAW].values + ) + weight_exp = expand_dimensions( + weight_exp, year_id=annual_diff[DimensionConstants.YEAR_ID].values + ) + + year_weights = ( + xr.DataArray( + (np.arange(len(years.past_years) - 1) + 1), + dims=DimensionConstants.YEAR_ID, + coords={DimensionConstants.YEAR_ID: years.past_years[1:]}, + ) + ** weight_exp + ) + + if truncate: + if not is_iterable_of(float, truncate_quantiles): + raise ValueError( + "If `truncate` is True, then `truncate_quantiles` must be a list of floats." + ) + + truncate_dims = truncate_dims or [ + DimensionConstants.LOCATION_ID, + DimensionConstants.YEAR_ID, + ] + truncate_quantiles = Quantiles(*sorted(truncate_quantiles)) + annual_diff = truncate_dataarray( + annual_diff, + truncate_dims, + replace_with_mean=replace_with_mean, + mean_dims=[DimensionConstants.YEAR_ID], + weights=year_weights, + quantiles=truncate_quantiles, + extra_dim=extra_dim, + ) + + stat_dims = list(stat_dims) + + if (xr.DataArray(weight_exp) > 0).any(): + arc_da = statistic.weighted_statistic(annual_diff, stat_dims, year_weights, extra_dim) + elif (xr.DataArray(weight_exp) == 0).all(): + # If ``weight_exp`` is zero, then just take the unweighted mean or + # quantile. + arc_da = statistic.unweighted_statistic(annual_diff, stat_dims) + else: + raise ValueError("weight_exp must be nonnegative.") + + # Find future change by multiplying an array that counts the future + # years, by the quantiles, which is weighted if `weight_exp` > 0. We want + # the multipliers to start at 1, for the first year of forecasts, and count + # to one more than the number of years to forecast. + forecast_year_multipliers = xr.DataArray( + np.arange(len(years.forecast_years)) + 1, + dims=[DimensionConstants.YEAR_ID], + coords={DimensionConstants.YEAR_ID: years.forecast_years}, + ) + future_change = arc_da * forecast_year_multipliers + return future_change + + +def is_iterable_of(type: Type, obj: Any) -> bool: + """True iff the obj is an iterable containing only instances of the given type.""" + return hasattr(obj, "__iter__") and all([isinstance(item, type) for item in obj]) + + +def approach_value_by_year( + past_data: xr.DataArray, + years: YearRange, + target_year: int, + target_value: float, + method: str = "linear", +) -> xr.DataArray: + """Forecasts cases where a target level at a target year is known. + + For e.g., the Rockefeller project for min-risk diet scenarios, wanted to + see the effect of eradicating diet related risks by 2030 on mortality. For + this we need to reach 0 SEV for all diet related risks by 2030 and keep + the level constant at 0 for further years. 
Here the target_year is 2030 + and target_value is 0. + + Args: + past_data: + The past data with all past years. + years: + past and future year-ids + target_year: + The year at which the target value will be reached. + target_value: + The target value that needs to be achieved during the target year. + method: + The extrapolation method to be used to calculate the values for + intermediate years (years between years.past_end and target_year). + The method currently supported is: `linear`. + + Raises: + ValueError: if method != "linear" + + Returns: + The forecasted results. + """ + if method == "linear": + forecast = _linear_then_constant_arc(past_data, years, target_year, target_value) + else: + raise ValueError( + f"Method {method} not recognized. Please see the documentation for" + " the list of supported methods." + ) + + return forecast + + +def _linear_then_constant_arc( + past_data: xr.DataArray, years: YearRange, target_year: int, target_value: float +) -> xr.DataArray: + r"""Makes rate forecasts by linearly extrapolating. + + Extraploates the point ARC from the last past year till the target year to reach the target + value. + + The steps for extrapolating the point ARCs are: + + (1) Calculate the rate of change between the last year of the past + data (eg.2017) and ``target_year`` (eg. 2030). + + .. Math:: + + R = + \frac{target\_value - past\_last\_year_value} + {target\_year- past\_last\_year} + + where :math:`R` is the slope of the desired linear trend. + + (2) Calculate the rates of change between the last year of the past and + each future year by multiplying R with future year weights till + ``target_year``. + + .. math:: + + \vec{W} = [1, ..., m] + + \vec{F_r} = \vec{W} * R + + where :math:`m` is the number of years between the ``target_year`` and + the last year of the past, and :math:`\vec{W}` forms the vector of + year weights. + :math:`\vec{F_r}` contains the linearly extrapolated ARCs for each + future year till the ``target_year``. + + (3) Add the future rates :math: `\vec{F_r}` to last year of the past + (eg. 2017) to get the forecasted results. + + (4) Extend the forecasted results till the ``forecast_end`` year by + filling the ``target_value`` for all the remaining future years. + + Args: + past_data: + The past data with all past years. The data is assumed to be in + normal space. + years: + past and future year-ids + target_year: + The year at which the target value will be reached. + target_value: + The value that needs to be achieved by the `target_year`. + + Returns: + The forecasted results. 
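+
+    Example (hypothetical numbers): with ``years.past_end = 2022``,
+    ``target_year = 2030``, ``target_value = 0`` and a last-past-year value of
+    0.8, the slope is ``(0 - 0.8) / (2030 - 2022) = -0.1`` per year, so the
+    forecasts are 0.7 in 2023, 0.6 in 2024, ..., and 0.0 in 2030 and every
+    year thereafter.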
+ """ + pre_target_years = np.arange(years.forecast_start, target_year + 1) + post_target_years = np.arange(target_year + 1, years.forecast_end + 1) + + past_last_year = past_data.sel(year_id=years.past_end) + target_yr_arc = (target_value - past_last_year) / (target_year - years.past_end) + + forecast_year_multipliers = xr.DataArray( + np.arange(len(pre_target_years)) + 1, + dims=[DimensionConstants.YEAR_ID], + coords={DimensionConstants.YEAR_ID: pre_target_years}, + ) + + future_change = target_yr_arc * forecast_year_multipliers + forecast_bfr_target_year = past_last_year + future_change + + forecast = expand_dimensions( + forecast_bfr_target_year, fill_value=target_value, year_id=post_target_years + ) + + return forecast diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/models/limetr.py b/gbd_2021/disease_burden_forecast_code/nonfatal/models/limetr.py new file mode 100644 index 0000000..9aae217 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/models/limetr.py @@ -0,0 +1,1107 @@ +"""This module provides an FHS interface to the LimeTr model.""" + +import itertools +from collections import namedtuple +from copy import deepcopy +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_database_interface.lib.constants import DimensionConstants +from fhs_lib_file_interface.lib import xarray_wrapper +from fhs_lib_file_interface.lib.file_system_manager import FileSystemManager +from fhs_lib_file_interface.lib.version_metadata import FHSDirSpec +from fhs_lib_year_range_manager.lib.year_range import YearRange +from flme import LME +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_lib_model.lib.constants import ModelConstants +from fhs_lib_model.lib.model_protocol import ModelProtocol +from fhs_lib_model.lib.utils import assert_covariate_coords, mean_of_draw + +logger = get_logger() + +RandomEffect = namedtuple("RandomEffect", "dims, prior_value") + + +class LimeTr(ModelProtocol): + """Instances of this class represent a LimeTr model. + + Can be it and used for predicting future esimates. + """ + + def __init__( + self, + past_data: xr.DataArray, + years: YearRange, + draws: Optional[int], + covariate_data: Optional[List[xr.DataArray]] = None, + fixed_effects: Optional[Dict] = None, + fixed_intercept: Optional[str] = None, + random_effects: Optional[Dict] = None, + indicators: Optional[Dict] = None, + seed: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Creates a new ``LimeTr`` model instance. + + Pre-conditions: + =============== + * All given ``xr.DataArray``s must have dimensions with at least 2 + coordinates. This applies for covariates and the dependent variable. + * There must be at least one fixed effect (counting indicator as a + fixed effect) or random effect in the model. + * Effects (fixed or random) that are on covariates must have exactly + the same name as the covariate they are associated with. e.g. you + can't have a random effect called ``haq_age``. + * Cannot have both global intercept be True and indicators. + * Cannot have fixed effects if ``covariate_data`` is ``None``. + + For covariates: + --------------- + * The dimensions within each covariate are also dimensions of the + dependent variable (i.e. ``dimensions_order_list``). + * For the dimensions shared between each covariate and the dependent + variable, the coordinates should be the same. 
+ * Each covariate's dataarray should have a ``scenario``, with at least + one coord ``scenario=0``. + * All covariates should have the same scenario coordinates -- + covariates that don't have actual scenarios should have their + reference scenario broadcast out to all the expected scenarios by + this point. + + For random effects: + ------------------- + * Random effects have at least one shared dimension. + * Random effects are mapped to a non-empty list of dimensions. + * The dimensions within each random effect are also dimensions of + the dependent variable (i.e. ``dimensions_order_list``). + + For indicators + -------------- + * Indicator are mapped to a non-empty list of dimensions. + * The dimensions within each indicator are also dimensions of the + dependent variable (i.e. ``dimensions_order_list``). + + There should be at least one effect that is specific to a covariate. + + Args: + past_data (xr.DataArray): Past data for dependent variable being forecasted. Only + mean-of-draw-level past data is used, so if draws are given the + mean will be taken. + covariate_data (list[xr.DataArray] | None, optional): Past and forecast data for + each covariate (i.e. independent variable). Each individual covariate + dataarray must be named with the stage as is defined in the FHS file system. + Only mean-of-draw-level past data is used, so if draws are given the + mean will be taken. + years (YearRange): forecasting timeseries + draws (Optional[int]): Number of draws to generate for the betas (and thus the + predictions) during the fitting. + fixed_effects (Optional[Dict]): Dict of covariates to have their + corresponding coefficients estimated and bounded by the given list. + e.g.: {"haq": [0, float('inf')], "edu": [-float('inf'), float('inf')]} + fixed_intercept (str | None, optional): To restrict the fixed intercept to be + positive or negative, pass "positive" or "negative", respectively. + "unestricted" says to estimate a fixed effect intercept that is not restricted + to positive or negative. If ``None`` then no fixed intercept is estimated. + Currently all of the strings get converted to unrestricted. + indicators (Dict | None, optional): A dictionary mapping indicators to the + dimensions that they are indicators on. + e.g.: {"ind_age_sex": ["age_group_id", "sex_id"], "ind_loc": ["location_id"]} + kwargs (Any): Unused additional keyword arguments + random_effects (Dict | None, optional): A dictionary mapping covariates to the + dimensions that their random slopes will be estimated for and the standard + deviation of the gaussian prior on their variance. + of the form {covariate: RandomEffect(list[dimension], std)...} + Any key that is not in covariate_data is assumed to be an + intercept. + The std float represents the value of the gaussian prior on + random effects variance. None means no prior. + e.g.:: + {"haq": RandomEffect(["location_id", "age_group_id"], None), + "education": RandomEffect(["location_id"], 3)} + seed (Optional[int]): an optional seed to set for reproducibility. + + Raises: + ValueError: Conditions + * If fixed effects are non-empty but ``covariate_data`` is ``None``, or If + ``fixed_effects``, ``indicators``, and ``random_effects`` are all ``None``/ + empty + * If a given random effect, indicator, or covariate has 1 or more dimensions + not included in ``self.dimensions_order_list``, i.e. 
those of the dependent + variable + * If there are no shared dims among all of the random effects, while random + effects do actually exist * If global intercept is True **and** indicators are + non-empty + * If any dimension shared between any covariate and the dependent variable + does not have the same coordinates on the dataarray each. + """ + if not fixed_effects and not random_effects and not indicators: + err_msg = ( + "`fixed_effects`, `indicators`, and `random_effects` are all " "`None`/empty." + ) + logger.error(err_msg) + raise ValueError(err_msg) + elif fixed_intercept and indicators: + err_msg = "Cannot have both global intercept be True and indicators." + logger.error(err_msg) + raise ValueError(err_msg) + elif fixed_effects and not covariate_data: + err_msg = "Cannot have fixed effects if `covariate_data` is `None`." + logger.error(err_msg) + raise ValueError(err_msg) + + self.seed = seed + self._orig_past_data = past_data.copy() + self._orig_random_effects = self._make_dims_conform( + self._orig_past_data, random_effects, ModelConstants.ParamType.RANDOM + ) + self._orig_fixed_effects = fixed_effects or dict() + self._orig_indicators = self._make_dims_conform( + self._orig_past_data, indicators, ModelConstants.ParamType.INDICATOR + ) + + needed_dims = [dim for dim in past_data.dims if dim != DimensionConstants.DRAW] + self.dimensions_order_list, self.n_grouping_dims = self._find_random_shared_dims( + self._orig_random_effects, needed_dims + ) + self.years = years + self.draws = draws + self.past_data = self._convert_xarray_to_numpy( + self._orig_past_data, self.dimensions_order_list + ) + self._orig_covariate_data = ( + dict() if not covariate_data else {cov.name: cov.copy() for cov in covariate_data} + ) + self.covariate_data = self._convert_covariates( + list(self._orig_covariate_data.values()), + self.dimensions_order_list, + self._orig_past_data, + self.years, + ) + self.fixed_effects = self._convert_fixed_effects(self._orig_fixed_effects) + self.fixed_intercept = self._convert_fixed_intercept(fixed_intercept) + self.random_effects = self._convert_random_effects( + self._orig_random_effects, self.dimensions_order_list + ) + + self.indicators = self._convert_indicators( + self._orig_indicators, self.dimensions_order_list + ) + + LimeTr._assert_covariate_params( + self._orig_fixed_effects, + self._orig_random_effects, + list(self._orig_covariate_data.values()), + ) + + ordered_dim_counts = self._get_dim_counts( + self._orig_past_data, self.dimensions_order_list + ) + + self.model_instance = LME( + dimensions=ordered_dim_counts, + n_grouping_dims=self.n_grouping_dims, + y=self.past_data, + covariates=self.covariate_data, + indicators=self.indicators, + global_effects_names=self.fixed_effects, + global_intercept=self.fixed_intercept, + random_effects=self.random_effects, + ) + + def fit(self) -> xr.Dataset: + """Fits the model and then returns the draws of the coefficients. 
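# A hedged sketch of the effect specifications the ``LimeTr`` constructor above
# expects. The covariate name ("haq"), dimensions, bounds, and prior values are
# hypothetical; only the dict shapes follow the documented contract.
# ``RandomEffect`` mirrors the namedtuple defined at the top of this module.
from collections import namedtuple

RandomEffect = namedtuple("RandomEffect", "dims, prior_value")

fixed_effects = {"haq": [0.0, float("inf")]}  # fixed slope on "haq", bounded below at zero
random_effects = {
    "haq": RandomEffect(["location_id", "age_group_id"], None),  # random slope, no prior
    "intercept_location": RandomEffect(["location_id"], 3.0),    # random intercept, prior sd 3
}
indicators = {"ind_age_sex": ["age_group_id", "sex_id"]}

# model = LimeTr(
#     past_data=past_da, years=years, draws=100, covariate_data=[haq_da],
#     fixed_effects=fixed_effects, random_effects=random_effects,
#     indicators=indicators, seed=42,
# )  # illustrative call only; needs real xarray inputs and the flme backend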
+ + Returns: + xr.Dataset: the fit coefficient draws or means + """ + self.model_instance.optimize( + trim_percentage=0.0, + inner_max_iter=1000, + inner_tol=1e-5, + outer_max_iter=1, + outer_step_size=1.0, + outer_tol=1e-6, + share_obs_std=True, + random_seed=self.seed, + ) + if self.fixed_effects or self.indicators: + self.model_instance.postVarGlobal() + + if self.random_effects: + self.model_instance.postVarRandom() + + # mean random effect estimates + coef_means = self._get_coefficient_means() + + # get covariance + covariance_dict = {} + if self.fixed_effects or self.indicators: + fixed_covariance = self.model_instance.var_beta + covariance_dict.update({ModelConstants.ParamType.FIXED: fixed_covariance}) + if self.random_effects: + random_covariance = self.model_instance.var_u + covariance_dict.update({ModelConstants.ParamType.RANDOM: random_covariance}) + self.posterior_cov_dict = np.array(covariance_dict) + + if self.draws: + coef_draws = self._generate_coefficient_draws() + return coef_draws + else: + return coef_means + + def predict(self) -> xr.DataArray: + """Apply the draws of coefficients to obtain draws of forecasted ratio or indicator.""" + self.predictions = 0 + if self.draws: + for data_var in self.coef_draws_ds.data_vars: + self._apply_coefficients(data_var) + else: + for data_var in self.coef_mean_ds.data_vars: + self._apply_coefficients(data_var) + + # expand dims other than scenario and year in case covariates don't + # include all dims. + self.predictions = expand_dimensions(self.predictions, **self._orig_past_data.coords) + return self.predictions + + def save_coefficients(self, output_dir: FHSDirSpec, entity: str) -> None: + """Saves model coefficients and posterior variance. + + Coefficients are saved in a xr.DataSet, while the posterior variance is saved directly + as a dict containing fixed and random variance objects from LimeTr + (var_u and var_beta). + + Args: + output_dir (Path): the output directory to save coefficients in + entity (str): the name of the entity to save. This will become the filename. + """ + fs = FileSystemManager.get_file_system() + + # Save draw-level coefficients. + draw_filespec = output_dir.append_sub_path(("coefficients",)).file(f"{entity}.nc") + xarray_wrapper.save_xr_scenario( + xr_obj=self.coef_mean_ds, + file_spec=draw_filespec, + metric="rate", + space="identity", + ) + + # Save covariance matrix of coefficients. + variance_dir = output_dir.append_sub_path(("variance",)) + fs.makedirs(variance_dir) + + variance_file = variance_dir.file(f"{entity}.npy") + + np.save(variance_file.data_path(), self.posterior_cov_dict) + try: + fs.chmod(variance_file, 0o775) + except PermissionError: + logger.warning( + f"Could not set group-writable permissions on {variance_file}. " + "Please set permissions manually." + ) + + def _apply_coefficients(self, data_var: str) -> None: + """Apply the fixed-effect, random-effect, or indicator coefficients to the predictions. + + Fixed effects and random effects will either be associated with a + covariate, meaning the coefficient is multiplied by covariate data and + then added to the predictions, or it is an intercept and should simply + be added to the predictions. Indicators will never have covariate + names, so will always just be added to the predictions. 
+ + Args: + data_var (str): the name of the data variable to apply coefficients to + """ + if self.draws: + dataset_to_apply = self.coef_draws_ds + else: + dataset_to_apply = self.coef_mean_ds + + covariate_names = list(self.covariate_data.keys()) + if data_var == ModelConstants.ParamType.FIXED: + param_names = dataset_to_apply[data_var]["parameter"].values + elif LimeTr._is_random_effect(data_var, covariate_names): + param_names = [data_var[len(ModelConstants.RANDOM_PREFIX) :]] + else: # Assume that coefficient is a random intercept or indicator + param_names = ["not_covariate"] + + for param_name in param_names: + if "parameter" in dataset_to_apply[data_var].dims: + coef = dataset_to_apply[data_var].sel(parameter=param_name, drop=True) + else: + coef = dataset_to_apply[data_var] + + if param_name in list(self.covariate_data.keys()): + # Since the name of the parameter appears in the list of + # covariates, we apply its coefficient to the respective + # covariate data. + if self.draws: + self.predictions = self.predictions + ( + self._orig_covariate_data[param_name] * coef + ) + else: + self.predictions = self.predictions + mean_of_draw( + self._orig_covariate_data[param_name] * coef + ) + + else: # Assume that coefficient is an intercept or indicator + self.predictions = self.predictions + coef + + def _get_coefficient_means(self) -> xr.Dataset: + """Extract coefficent mean estimates from LimeTr and combine them into xr.Dataset.""" + mean_feffect_da_list = self._get_fixed_effect_means() + mean_reffect_da_list = self._get_random_effect_means() + mean_indicator_da_list = self._get_indicator_means() + + self.coef_mean_ds = xr.merge( + mean_feffect_da_list + mean_reffect_da_list + mean_indicator_da_list + ) + + return self.coef_mean_ds + + def _get_fixed_effect_means(self) -> List[xr.DataArray]: + mean_feffect_da_list = [] + if self.fixed_intercept: + feffects = ["global_intercept"] + list(self.fixed_effects.keys()) + else: + feffects = list(self.fixed_effects.keys()) + for i, feffect_name in enumerate(feffects): + mean_fixed_da = xr.DataArray( + np.array([self.model_instance.beta_soln[i]]), + dims=("parameter"), + name=ModelConstants.ParamType.FIXED, + coords={"parameter": [feffect_name]}, + ) + mean_feffect_da_list.append(mean_fixed_da) + return mean_feffect_da_list + + def _get_random_effect_means(self) -> List[xr.DataArray]: + mean_reffect_da_list = [] + column_index = 0 + for reffect_name in self.random_effects.keys(): + random_name = ModelConstants.RANDOM_PREFIX + reffect_name + random_dims = self._orig_random_effects[reffect_name].dims + ordered_random_dims = [ + dim for dim in self.dimensions_order_list if dim in random_dims + ] + coords_dict = { + dim: list(self._orig_past_data[dim].values) for dim in ordered_random_dims + } + grouping_dims = self.dimensions_order_list[0 : self.n_grouping_dims] + len_list = [len(coords_dict[grouping_dim]) for grouping_dim in grouping_dims] + non_grouping_dims = sorted(list(set(coords_dict.keys()) - set(grouping_dims))) + coord_combos = _coord_combinations(coords_dict, non_grouping_dims) + if non_grouping_dims: + # loop through all coord combinations from right to left, i.e. + # if non_grouping_dims was ["age", "sex"] then it would loop + # through with age 1 sex 1, age 1 sex 2, age 2 sex 1, etc. 
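# A small xarray sketch of how ``_apply_coefficients`` above combines terms:
# covariate effects multiply the covariate data, while intercept-style effects
# and indicators are simply added. Names and values are hypothetical.
import xarray as xr

haq = xr.DataArray(
    [[0.5, 0.6], [0.7, 0.8]],
    dims=("location_id", "year_id"),
    coords={"location_id": [6, 102], "year_id": [2023, 2024]},
)
beta_haq = xr.DataArray(0.9)  # fixed-effect slope on the covariate
u_loc = xr.DataArray(
    [0.05, -0.02], dims="location_id", coords={"location_id": [6, 102]}
)  # random location intercept

predictions = 0
predictions = predictions + haq * beta_haq  # covariate term
predictions = predictions + u_loc           # intercept term (broadcasts over year_id)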
+ combo_ds_list = [] + for combo in coord_combos: + effect = [] + for row_index in range(0, np.prod(len_list)): + effect.append(self.model_instance.u_soln[row_index][column_index]) + reshape_list = len_list + [int(bool(i)) for i in non_grouping_dims] + combo_np = np.array(effect).reshape(reshape_list) + combo_ds = xr.DataArray( + combo_np, dims=list(combo.keys()), name=random_name, coords=combo + ).to_dataset() + combo_ds_list.append(combo_ds) + column_index = column_index + 1 + mean_random_da = xr.combine_by_coords(combo_ds_list)[random_name] + else: + effect = [] + for row_index in range(0, np.prod(len_list)): + effect.append(self.model_instance.u_soln[row_index][column_index]) + effect_np = np.array(effect).reshape(len_list) + mean_random_da = xr.DataArray( + effect_np, + dims=list(coords_dict.keys()), + name=random_name, + coords=coords_dict, + ) + column_index = column_index + 1 + + mean_reffect_da_list.append(mean_random_da) + + return mean_reffect_da_list + + def _get_indicator_means(self) -> List[xr.DataArray]: + mean_indicator_da_list = [] + for i, indicator_name in enumerate(self.indicators.keys()): + if i == 0: + start_value = len(self.fixed_effects.keys()) + indicator_dims = self._orig_indicators[indicator_name] + ordered_indicator_dims = [ + dim for dim in self.dimensions_order_list if dim in indicator_dims + ] + coords_dict = { + dim: list(self._orig_past_data[dim].values) for dim in ordered_indicator_dims + } + num_values = np.prod([len(dim) for dim in coords_dict.values()]) + end_value = start_value + num_values + indicator_vals_list = self.model_instance.beta_soln.tolist()[start_value:end_value] + len_list = [len(val) for val in coords_dict.values()] + + mean_indicator_da = xr.DataArray( + np.array(indicator_vals_list).reshape(len_list), + dims=list(coords_dict.keys()), + name=indicator_name, + coords=coords_dict, + ) + mean_indicator_da_list.append(mean_indicator_da) + + return mean_indicator_da_list + + def _generate_coefficient_draws(self) -> xr.Dataset: + """Utilize the LimeTr wrapper `outputDraws` function to generate draws of coefficients. + + Using `np.random.multivariate_normal` and convert the results to a ``xr.DataSet``. 
+ + Returns: + xr.Dataset: the coefficient draws dataset + """ + (fixed_samples, indicator_samples, random_samples) = self.model_instance.outputDraws( + n_draws=self.draws, by_type=True, combine_cov=True + ) + + coef_da_list = [] + if fixed_samples: + fixed_da = xr.DataArray( + fixed_samples.array, + dims=("parameter", DimensionConstants.DRAW), + name=ModelConstants.ParamType.FIXED, + coords={ + "parameter": fixed_samples.names, + DimensionConstants.DRAW: range(self.draws), + }, + ) + coef_da_list.append(fixed_da) + + for random_index, random_sample in enumerate(random_samples): + random_name = ModelConstants.RANDOM_PREFIX + random_sample.name + param_name = random_sample.name + random_dims = self._orig_random_effects[param_name].dims + ordered_random_dims = [ + dim for dim in self.dimensions_order_list if dim in random_dims + ] + coords_dict = { + dim: list(self._orig_past_data[dim].values) for dim in ordered_random_dims + } + coords_dict.update({DimensionConstants.DRAW: range(self.draws)}) + + random_da = xr.DataArray( + random_sample.array, + dims=list(coords_dict.keys()), + name=random_name, + coords=coords_dict, + ) + coef_da_list.append(random_da) + + for indicator_index, indicator_sample in enumerate(indicator_samples): + param_name = indicator_sample.name + indicator_dims = self._orig_indicators[param_name] + ordered_indicator_dims = [ + dim for dim in self.dimensions_order_list if dim in indicator_dims + ] + + coords_dict = { + dim: list(self._orig_past_data[dim].values) for dim in ordered_indicator_dims + } + coords_dict.update({DimensionConstants.DRAW: range(self.draws)}) + + indicator_da = xr.DataArray( + indicator_sample.array, + dims=list(coords_dict.keys()), + name=param_name, + coords=coords_dict, + ) + coef_da_list.append(indicator_da) + + self.coef_draws_ds = xr.merge(coef_da_list) + + return self.coef_draws_ds + + @staticmethod + def _is_random_effect(effect_name: str, covariate_names: List[str]) -> bool: + """Checks if the effect is a random effect, excluding intercepts.""" + if not effect_name.startswith(ModelConstants.RANDOM_PREFIX): + return False + return effect_name[len(ModelConstants.RANDOM_PREFIX) :] in covariate_names + + @staticmethod + def _get_dim_counts(data: xr.DataArray, dimensions_order_list: List[str]) -> List[int]: + """The lengths of all the dims on the dependent var that are relevant to fitting. + + Args: + data (xr.DataArray): Initial/unchanged past data of dependent variable + dimensions_order_list (list[str]): An ordered list of dimensions needed by y with + the shared grouping dimensions at the front of the list + + Returns: + list[int]: The lengths of each dimension where the dims ordered relative + to ``dimensions_order_list``. + """ + dim_counts = [len(data[dim].values) for dim in dimensions_order_list] + + return dim_counts + + @staticmethod + def _find_random_shared_dims( + random_effects: Optional[Dict], needed_dims: List[str] + ) -> Tuple[List[str], int]: + """Determine the order that dims will go in and assign the number of grouping dims. + + Args: + needed_dims (list[str]): + Dimensions that are present in the y data and must be accounted + for in the ordered list of dimensions. + random_effects (Dict | None, optional): + A dictionary mapping covariates to the dimensions that their + random slopes will be estimated for and the standard deviation + of the gaussian prior on their variance. + of the form {covariate: (list[dimension], std)...} + Any key that is not in covariate_data is assumed to be an + intercept. 
+ The std float represents the value of the gaussian prior on + random effects variance. None means no prior. + e.g.:: + + {"haq": (["location_id", "age_group_id"], None), + "education": (["location_id"], 3)} + + Returns: + Tuple[List[str], int]: dimensions_order_list, an ordered list of dimensions needed + by y with the shared grouping dimensions at the front of the list; and + n_grouping_dims, the number of dimensions to group by for optimization in + LimeTr. **NOTE:** If there are random effects, then the grouping dims + and the non-grouping dims will be sorted alphabetically + relative to their respective sets. + + Raises: + ValueError: If there are no shared dims among all of the random effects, + while random effects do actually exist, or if a given random effect has 1 or + more dimensions not included in ``self.dimensions_order_list``, i.e. those of + the dependent variable. + """ + if not random_effects: + # There are *no* random effects, so there are no dims shared among + # all of the random effects. Skip to the end. + sorted_dims_order_list = needed_dims + n_grouping_dims = 0 + else: + # Build list of all random effect dimensions + all_random_effect_dims = [ + reffect.dims for reffect in list(random_effects.values()) + ] + + # Find shared dimensions + shared_dims = list(set.intersection(*map(set, all_random_effect_dims))) + + # If there are no shared dims among all of the random effects, + # while random effects do actually exist, then LimeTr cannot run. + msg = "There are no shared dimensions!" + if not shared_dims: + logger.error(msg) + raise ValueError(msg) + + # Append the rest of the dims to shared_dims + dimensions_order_list = deepcopy(shared_dims) + for dim in needed_dims: + if dim not in dimensions_order_list: + dimensions_order_list.append(dim) + + n_grouping_dims = len(shared_dims) + + # The grouping dims and the non-grouping dims will be sorted + # alphabetically relative to their respective sets. + sorted_dims_order_list = sorted(dimensions_order_list[:n_grouping_dims]) + sorted( + dimensions_order_list[n_grouping_dims:] + ) + + LimeTr._assert_param_dims( + random_effects, needed_dims, param_type=ModelConstants.ParamType.RANDOM + ) + if set(needed_dims) != set(sorted_dims_order_list): + raise ValueError( + f"The set of `needed_dims` [{needed_dims}] does not match the set of " + f"`sorted_dims_order_list [{sorted_dims_order_list}]." + ) + + return sorted_dims_order_list, n_grouping_dims + + @staticmethod + def _convert_fixed_effects(fixed_effects: Dict) -> Dict: + """Converts dict of fixed effects to list of names required by LimeTr. + + Ignores effect restrictions for the moment until implemented by LimeTr API. + + Args: + fixed_effects (Dict): Dict of covariates to have their + corresponding coefficients estimated and bounded by the given list. + e.g.:: + + {"haq": [0, float('inf')], + "edu": [-float('inf'), float('inf')] + } + + Returns: + limetr_fixed_effects (Dict): A dict of covariate names that will be used + as fixed effects + """ + limetr_fixed_effects = fixed_effects + + return limetr_fixed_effects + + @staticmethod + def _convert_fixed_intercept(fixed_intercept: Optional[str]) -> bool: + """Converts fixed intercept str/bool to bool for LimeTr API. + + Args: + fixed_intercept (str | None, optional): To restrict the fixed intercept to be + positive or negative, pass "positive" or "negative", respectively. + "unestricted" says to estimate a fixed effect intercept that is not restricted + to positive or negative. 
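# A sketch of the shared-dimension logic in ``_find_random_shared_dims`` above,
# with hypothetical random-effect dimension lists: the dims common to every
# random effect become the grouping dims, and the remaining dependent-variable
# dims are appended after them, each group sorted alphabetically.
needed_dims = ["location_id", "age_group_id", "sex_id", "year_id"]
random_effect_dims = [
    ["location_id", "age_group_id"],  # e.g. a random slope on a covariate
    ["location_id"],                  # e.g. a random location intercept
]

shared = set.intersection(*map(set, random_effect_dims))  # {"location_id"}
grouping = sorted(shared)
non_grouping = sorted(d for d in needed_dims if d not in shared)
dimensions_order_list = grouping + non_grouping
n_grouping_dims = len(grouping)
# dimensions_order_list == ["location_id", "age_group_id", "sex_id", "year_id"]
# n_grouping_dims == 1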
If ``None`` then no fixed intercept is + estimated. Currently all of the strings get converted to + unrestricted. + + Returns: + (bool): Whether or not to have a fixed intercept + """ + if isinstance(fixed_intercept, str): + return True + + return False + + @staticmethod + def _convert_random_effects( + random_effects: Dict, dimensions_order_list: List[str] + ) -> Dict: + """Converts dict of random effects to format required by LimeTr API. + + **NOTE:** This should be used after ``LimeTr._find_random_shared_dims`` + to ensure that the random effects have been validated. + + **pre-conditions:** + * Random effects have at least one shared dimension. + * Random effects are mapped to a non-empty list of dimensions. + * The dimensions within each random effect are also dimensions of the + dependent variable (i.e. ``dimensions_order_list``). + + Args: + dimensions_order_list (list[str]): An ordered list of dimensions needed by y with + the shared grouping dimensions at the front of the list + random_effects (Dict): A dictionary mapping covariates to the + dimensions that their random slopes will be estimated for and the standard + deviation of the gaussian prior on their variance. + of the form {covariate: (list[dimension], std)...} + Any key that is not in covariate_data is assumed to be an + intercept. + The std float represents the value of the gaussian prior on + random effects variance. None means no prior. + e.g.:: + + {"haq": (["location_id", "age_group_id"], None), + "education": (["location_id"], 3)} + + Returns: + limetr_random_effects (Dict): A dictionary where key is the name of the covariate + or intercept and value is a tuple with a boolean list specifying + dimensions to impose random effects on and the standard + deviation of the gaussian prior on their variance. e.g. + + {'haq': ([True, False, False, False], None), + 'intercept_location': ([True, False, False, False], 3)} + """ + LimeTr._assert_param_dims( + random_effects, + dimensions_order_list, + param_type=ModelConstants.ParamType.RANDOM, + ) + + limetr_random_effects = {} + for effect_name, effect_tuple in random_effects.items(): + effect_bool_dims = LimeTr._get_existing_dims( + avail_dims=effect_tuple.dims, + dimensions_order_list=dimensions_order_list, + param_type=ModelConstants.ParamType.RANDOM, + param_name=effect_name, + ) + limetr_random_effects.update( + {effect_name: (effect_bool_dims, effect_tuple.prior_value)} + ) + + return limetr_random_effects + + @staticmethod + def _convert_covariates( + covariates: List[xr.DataArray], + dimensions_order_list: List[str], + dep_var_da: xr.DataArray, + years: YearRange, + ) -> Dict: + """Converts list of covariate dataarrays into 1D numpy arrays to match LimeTr API. + + We fit covariate coefficients at mean-of-draw, and reference-scenario + level, but we apply the coefficients to the covariate data that + includes draws and scenarios. + + **pre-conditions:** + * The dimensions within each covariate are also dimensions of the + dependent variable (i.e. ``dimensions_order_list``). + * Each covariate's dataarray should have a ``scenario``, with at least + one coord ``scenario=0``. + * All covariates should have the same scenario coordinates -- + covariates that don't have actual scenarios should have their + reference scenario broadcast out to all the expected scenarios by + this point. + + Args: + covariates (list[xr.DataArray]): Past and forecast data for each covariate (i.e. + independent variable). 
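# A sketch of the boolean-dimension encoding produced by
# ``_convert_random_effects`` / ``_get_existing_dims`` above. Dimension and
# effect names are hypothetical.
dimensions_order_list = ["location_id", "age_group_id", "sex_id", "year_id"]
random_effects = {
    "haq": (["location_id", "age_group_id"], None),
    "intercept_location": (["location_id"], 3.0),
}

limetr_random_effects = {
    name: ([dim in dims for dim in dimensions_order_list], prior)
    for name, (dims, prior) in random_effects.items()
}
# {"haq": ([True, True, False, False], None),
#  "intercept_location": ([True, False, False, False], 3.0)}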
Each individual covariate dataarray should be named + with the stage as is defined in the FHS file system. + dimensions_order_list (list[str]): The dimensions of the dependent variable, where + they are ordered with the random-effect grouping dims first. + dep_var_da (xr.DataArray): Past data for dependent variable being forecasted. Will + be used to infer expected coordinates on each dimension of the + covariate data. + years (YearRange): FHS year range + + Returns: + Dict: A dictionary where the key is the covariate name, value is a + tuple of (1D np.array, order_bool_list_dimensions_of_cov) + with the ordered_bool_list based on the dimensions order of the + y data. + """ + mean_ref_covariates = [ + ( + mean_of_draw(cov.sel(year_id=years.past_years).sel(scenario=0, drop=True)) + if "scenario" in cov.coords + else mean_of_draw(cov.sel(year_id=years.past_years)) + ) + for cov in covariates + ] + + cov_dim_dict = {cov.name: cov.dims for cov in mean_ref_covariates} + LimeTr._assert_param_dims(cov_dim_dict, dimensions_order_list) + assert_covariate_coords(mean_ref_covariates, dep_var_da) + + limetr_covariates = {} + for cov_da in mean_ref_covariates: + cov_name = cov_da.name + ordered_cov_dims = [dim for dim in dimensions_order_list if dim in cov_da.dims] + cov_array = LimeTr._convert_xarray_to_numpy(cov_da, ordered_cov_dims) + + cov_bool_dims = LimeTr._get_existing_dims( + avail_dims=cov_da.dims, + dimensions_order_list=dimensions_order_list, + param_type=ModelConstants.ParamType.COVARIATE, + param_name=cov_name, + ) + limetr_covariates.update({cov_name: (cov_array, cov_bool_dims)}) + + return limetr_covariates + + @staticmethod + def _make_dims_conform( + dep_var: xr.DataArray, params: Optional[Dict], param_type: str + ) -> Dict: + """Make indicators or random effects dims consistent with those of the dependent var. + + For example, we might try applying an age-sex indicator to a cause that + only affects females, so by this point the dependent variable dataarray + has no ``sex_id`` dim at all. Therefore, the indicator should be + reduced to just an age-indicator. This should not be used for other + things, such as covariates. + + Args: + dep_var (xr.DataArray): Dependent variable being forecasted -- used here to infer + expected dims. + params (Dict | None): A dictionary mapping covariates to the dimensions that their + random slopes will be estimated for and the standard deviation + of the gaussian prior on their variance of the form + {covariate: (list[dimension], std)...} for random effects or + {param: list[dimension]} for indicators + param_type (str): What the params are referring to for the purposes of printing + warning messages about dimensions being dropped (e.g. + "indicator".) + + Returns: + Dict: Updated dict that is consistent with the dependent variable. 
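# A sketch of the dimension-conforming step described in ``_make_dims_conform``
# above: an indicator defined on dims the dependent variable lacks is trimmed
# down to the dims they share, or dropped entirely. Names are hypothetical.
dep_var_dims = {"location_id", "age_group_id", "year_id"}  # e.g. a females-only cause, no sex_id
indicators = {"ind_age_sex": ["age_group_id", "sex_id"], "ind_loc": ["location_id"]}

conformed = {}
for name, dims in indicators.items():
    kept = sorted(dep_var_dims & set(dims))
    if kept:  # drop the indicator completely if no dims survive
        conformed[name] = kept
# {"ind_age_sex": ["age_group_id"], "ind_loc": ["location_id"]}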
+ """ + if not params: + return dict() + + expected_dims = set(dep_var.dims) + updated_params = {} + for name, info in params.items(): + if param_type == ModelConstants.ParamType.RANDOM: + dims = info.dims + else: + dims = info + + new_dims = sorted(expected_dims & set(dims)) + + if param_type == ModelConstants.ParamType.RANDOM: + new_info = RandomEffect(new_dims, info.prior_value) + else: + new_info = new_dims + + if new_dims and set(dims).issubset(expected_dims): + updated_params.update({name: info}) + elif new_dims: + updated_params.update({name: new_info}) + warn_msg = ( + f"The {param_type} {name}, originally across " + f"{dims}, had to be reduced to a(n) {param_type} " + f"only across {new_dims}" + ) + logger.warning( + warn_msg, + bindings=dict( + param_name=name, + param_type=param_type, + old_dims=dims, + new_dims=new_dims, + ), + ) + else: + warn_msg = ( + f"The {param_type} {name}, originally across " + f"{dims} had to be dropped completely" + ) + logger.warning( + warn_msg, + bindings=dict( + param_name=name, + param_type=param_type, + dropped_dims=dims, + ), + ) + + return updated_params + + @staticmethod + def _convert_indicators(indicators: Dict, dimensions_order_list: List[str]) -> Dict: + """Converts list of dims to create indicator variables. + + For dictionary to format required by LimeTr API. + + **pre-conditions:** + * Indicator are mapped to a non-empty list of dimensions. + * The dimensions within each indicator are also dimensions of the + dependent variable (i.e. ``dimensions_order_list``). + + Args: + dimensions_order_list (list[str]): The dimensions of the dependent variable, where + they are ordered with the random-effect grouping dims first. + indicators (Dict): A dictionary mapping indicators to the + dimensions that they are indicators on. e.g.:: + {"ind_age_sex": ["age_group_id", "sex_id"], + "ind_loc": ["location_id"]} + + Returns: + Dict: A dictionary where the key is the indicator name and value is a + boolean list specifying dimensions on which to use indicator. + """ + LimeTr._assert_param_dims(indicators, dimensions_order_list) + + limetr_indicators = {} + for indicator, indicator_dims in indicators.items(): + ind_bool_dims = LimeTr._get_existing_dims( + avail_dims=indicator_dims, + dimensions_order_list=dimensions_order_list, + param_type=ModelConstants.ParamType.INDICATOR, + param_name=indicator, + ) + limetr_indicators.update({indicator: ind_bool_dims}) + + return limetr_indicators + + @staticmethod + def _convert_xarray_to_numpy( + dataarray: xr.DataArray, dimensions_order_list: List[str] + ) -> np.ndarray: + """Converts multi-dimensional dataarray into 1-D numpy array. + + Transposes the array so that the dimensions are in the order that is + expected for computation, before flattening it to be 1-D. + + Args: + dataarray (xr.DataArray): An xarray DataArray to convert into a 1D numpy array + dimensions_order_list (list[str]): The dimensions of the dependent variable, where + they are ordered with the random-effect grouping dims first. + + Returns: + numpy_data (np.ndarray): A 1D numpy array containing the transposed and reshaped + data + """ + return dataarray.transpose(*dimensions_order_list).values.flatten() + + @staticmethod + def _get_existing_dims( + avail_dims: List[str], + dimensions_order_list: List[str], + param_type: str, + param_name: str, + ) -> List[bool]: + """Convert a list of dimension names into a list of booleans. + + Ordered based on the dimensions order list of the dependent-variable + data for the model. 
+ + For example,:: + + >>> avail_dims + ["age_group_id", "sex_id"] + >>> dimensions_order_list + ["sex_id", "location_id", "age_group_id"] + >>> LimeTr._get_existing_dims( + avail_dims, dimensions_order_list, "haq", "random_effect") + [True, False, True] + + Args: + avail_dims (list[str]): The dimensions to convert into the bool list + dimensions_order_list (list[str]): An ordered list of dimensions needed by y with + the shared grouping dimensions at the front of the list + param_type (str): The type of parameter, e.g. covariate, random-effect, or + indicator. + param_name (str): The name of the parameter e.g. "haq", "intercept", or + "age_indicator". + + Returns: + list(bool): Boolean list of whether dimension in avail dims, where order is + based on the order of ``dimensions_order_list``. + + Raises: + ValueError: If there are dims in the covariates, random-effects, or + indicators that are NOT in the dependent variable (i.e. in + ``needed_dims``). + """ + missing_dims = set(avail_dims) - set(dimensions_order_list) + if missing_dims: + err_msg = ( + f"The {param_name} {param_type} has extra dims={missing_dims} " + f"that ``past_data`` is missing" + ) + logger.error( + err_msg, + bindings=dict( + model=__class__.__name__, + param_name=param_name, + param_type=param_type, + missing_dims=missing_dims, + ), + ) + raise ValueError(err_msg) + + avail_dims = [dim in avail_dims for dim in dimensions_order_list] + + return avail_dims + + @staticmethod + def _assert_param_dims( + params: Dict, + needed_dims: List[str], + param_type: str = ModelConstants.ParamType.NONRANDOM, + ) -> None: + """Check that all dimensions of random effects are within ``needed_dims``. + + This is intended to be used on dicts that map random effects and + indicators to their respective dimensions, but NOT for fixed effects. + + Args: + params (Dict): A dictionary mapping covariates to the dimensions that their + random slopes will be estimated for. + of the form ``{covariate: list[dimension], ...}`` + needed_dims (list[str]): Dimensions that are present in the y data and must be + accounted for in the ordered list of dimensions. + param_type (str): the param type name to validate against + + Raises: + ValueError: If a given random effect has 1 or more dimensions not included + in ``needed_dims``, i.e. those of the dependent variable. + """ + err_msg = "" + err_log_bindings = dict( + param_type=param_type, + needed_dims=needed_dims, + extra_dims=dict(), + ) + for param_name, param_dims in params.items(): + if param_type == ModelConstants.ParamType.RANDOM: + param_dims = param_dims.dims + diff = set(param_dims) - set(needed_dims) + if diff: + err_msg += ( + f"`{param_name}` has dims={tuple(diff)} that don't " + f"exist in dependent variable data. 
" + ) + err_log_bindings["extra_dims"].update({param_name: diff}) + + if err_msg: + logger.error(err_msg, bindings=err_log_bindings) + raise ValueError(err_msg) + + @staticmethod + def _assert_covariate_params( + fixed_effects: Dict, random_effects: Dict, covariates: Iterable[xr.DataArray] + ) -> None: + """Throws a warning if there's no effects associated with one of the covariates.""" + covariate_names = [cov.name for cov in covariates] + effects = list(set(fixed_effects.keys()) & set(random_effects.keys())) + for effect in effects: + if effect in covariate_names: + return + warn_msg = ( + "There must be at least one effect that is associated with one of the given" + " covariates" + ) + logger.warning(warn_msg) + + +def _coord_combinations( + coords_dict: Dict[str, List[np.ndarray]], non_grouping_dims: Optional[List[str]] +) -> List[Dict[str, List[np.ndarray]]]: + """Return list of coord dictionaries in order based on non_grouping_dim coordinate values. + + E.g. if non_grouping_dims is ['age', 'sex'] and coords + dict has {"loc": [6, 8], "age": [28, 29, 30], "sex": [1, 2]} then it would + return a list that looks like this: + [ + {"loc": [6, 8], "age": [28], "sex": [1]}, + {"loc": [6, 8], "age": [28], "sex": [2]}, + {"loc": [6, 8], "age": [29], "sex": [1]}, + ... , + {"loc": [6, 8], "age": [30], "sex": [2]} + ] + + Args: + coords_dict (Dict[str, List[np.ndarray]]): Dictionary containing coordinates observed + in some xarray + non_grouping_dims (Optional[List[str]]): Optional list of dimensions to be iterated + over + + Returns: + List[Dict[str, List[np.ndarray]]]: if ``non_grouping_dims`` is None a copy of + `coords_dict`, otherwise a new list of coordinate dictionaries + """ + if non_grouping_dims: + non_grouping_coords_list = [coords_dict[dim] for dim in non_grouping_dims] + non_grouping_coord_combos = list(itertools.product(*non_grouping_coords_list)) + grouping_coords_dict = coords_dict.copy() + for dim in non_grouping_dims: + del grouping_coords_dict[dim] + coord_combo_list = [] + for combo in non_grouping_coord_combos: + combo_dict = grouping_coords_dict.copy() + for i, non_grouping_dim in enumerate(non_grouping_dims): + combo_dict.update({non_grouping_dim: [combo[i]]}) + coord_combo_list.append(combo_dict) + else: + return [coords_dict.copy()] + + return coord_combo_list diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/models/omega_selection_strategy.py b/gbd_2021/disease_burden_forecast_code/nonfatal/models/omega_selection_strategy.py new file mode 100644 index 0000000..3813441 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/models/omega_selection_strategy.py @@ -0,0 +1,355 @@ +"""Strategies for determining the weight for the Annualized Rate-of-Change (ARC) method. + +Find where the RMSE is 1 (the RMSE is normalized so that 1 is always the lowest +RMSE). If there are ties, take the lowest weight. + +There two options for choosing the weight: +1) Use the weight where the normalized-RMSE is 1. +2) If none of the weights have a normalized-RMSE no more than the +""" + +from typing import Any + +import numpy as np +import xarray as xr +from fhs_lib_database_interface.lib.constants import DimensionConstants +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_lib_model.lib.constants import ArcMethodConstants + +logger = get_logger() + + +def use_omega_with_lowest_rmse(rmse: xr.DataArray, **kwargs: Any) -> float: + """Use the omega (weight) with the lowest RMSE. + + If there are ties, choose the smallest omega. 
+ + Args: + rmse: + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + kwargs: + Ignores any additional keyword args. + + Returns: + The weight to use for the ARC method. + """ + chosen_weight = rmse.where(rmse == rmse.min()).dropna("weight")["weight"].values[0] + + logger.debug(f"`use_omega_with_lowest_rmse` weight selected: {chosen_weight}") + return chosen_weight + + +def use_smallest_omega_within_threshold( + rmse: xr.DataArray, threshold: float = 0.05, **kwargs: Any +) -> float: + """Returns the smallest omega possible compared to normalized-RMSE, using a threshold. + + If none of the weights have a normalized-RMSE (normalized by dividing by + minimum RMSE) no more than the threshold percent greater than the minimum + normalized-RMSE, which will be 1, then the weight of 0.0 is used. + Otherwise, starting at the first weight smaller than the weight of the + minimum normalized-RMSE and moving in the direction of decreasing weights, + choose the first weight that is more than the threshold percent greater + than the minimum normalized-RMSE. + + Args: + rmse: + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + threshold: + The threshold percent to use for selecting the weight. + kwargs: + Ignores any additional keyword args. + + Returns: + The weight to use for the ARC method. + """ + norm_rmse = rmse / rmse.min() + + diffs = norm_rmse - 1 + + # If there are, then the set the weight to the first weight with an + # normalized-RMSE less than threshold percent above the minimum + # normalized-RMSE. + weight_with_lowest_rmse = ( + norm_rmse.where(norm_rmse == norm_rmse.min()).dropna("weight")["weight"].values[0] + ) + weights_to_check = [w for w in norm_rmse["weight"].values if w < weight_with_lowest_rmse] + diffs_to_check = diffs.sel(weight=weights_to_check) + diffs_greater = diffs_to_check.where(diffs_to_check >= threshold).dropna("weight") + if len(diffs_greater) > 0: + # take the max weight greater than the threshold but less than the + # with the lowest RMSE. + chosen_weight = diffs_greater["weight"].values.max() + else: + chosen_weight = 0.0 + + logger.debug(f"`use_smallest_omega_within_threshold` weight selected: {chosen_weight}") + return chosen_weight + + +def use_omega_rmse_weighted_average(rmse: xr.DataArray, **kwargs: Any) -> float: + r"""Use the RMSE-weighted average of the range of tested-omegas. + + .. math:: + + \bar{\omega} = \frac{\sum\limits_{i=0}^{N}\frac{\omega_i}{RMSE_i}} + {\sum\limits_{i=0}^{N}\frac{1}{RMSE_i}} + + where :math:`N` is the largest in the range of omegas that were tested. + + *Note* under the special case when one or more of the weights has an RMSE + of 0, we consider any weights with RMSE values of zero, to be weighted + infinitely, so we just take the mean of all the weights with an RMSE of + zero. + + Args: + rmse (xarray.DataArray): + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + kwargs: + Ignores any additional keyword args. + + Returns: + The omega to use for the ARC method. + """ + zero_rmse = rmse == 0 + if zero_rmse.any(): + # Any weights with RMSE values of zero, will be weighted infinitely so + # just take the mean of all the weights with an RMSE of zero. 
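# A numeric check of the RMSE-weighted average defined in
# ``use_omega_rmse_weighted_average``, with hypothetical omegas and RMSEs.
import numpy as np

omegas = np.array([0.0, 1.0, 2.0, 3.0])
rmses = np.array([2.0, 1.0, 1.0, 2.0])

omega_bar = (omegas / rmses).sum() / (1.0 / rmses).sum()
# (0.0 + 1.0 + 2.0 + 1.5) / (0.5 + 1.0 + 1.0 + 0.5) == 1.5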
+ chosen_weight = float(rmse["weight"].where(zero_rmse).dropna("weight").mean("weight")) + else: + chosen_weight = float((rmse["weight"] / rmse).sum() / (1 / rmse).sum().values) + + logger.debug(f"`use_omega_rmse_weighted_average` weight selected: {chosen_weight}") + return chosen_weight + + +def use_average_omega_within_threshold( + rmse: xr.DataArray, threshold: float = 0.05, **kwargs: Any +) -> float: + """Take the average of the omegas with RMSEs within 5% of lowest RMSE. + + Args: + rmse (xarray.DataArray): + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + threshold (float): + The threshold percent to use for selecting the weight. + kwargs: + Ignores any additional keyword args. + + Returns: + The weight to use for the ARC method. + """ + chosen_weight = ( + rmse.where(rmse < rmse.values.min() + rmse.values.min() * threshold) + .dropna("weight")["weight"] + .values.mean() + ) + + logger.debug(f"`use_average_omega_within_threshold` weight selected: {chosen_weight}") + return chosen_weight + + +def use_average_of_zero_biased_omegas_within_threshold( + rmse: xr.DataArray, threshold: float = 0.05, **kwargs: Any +) -> float: + """Calculates weight by averaging omegas. + + Take the average of the omegas less than the omega with the lowest RMSE, + and with RMSEs within 5% of that lowest RMSE. + + Args: + rmse: + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + threshold: + The threshold percent to use for selecting the weight. + kwargs: + Ignores any additional keyword args. + + Returns: + The weight to use for the ARC method. + """ + norm_rmse = rmse / rmse.min() + + weight_with_lowest_rmse = ( + norm_rmse.where(norm_rmse == norm_rmse.min()).dropna("weight")["weight"].values[0] + ) + weights_to_check = [w for w in norm_rmse["weight"].values if w <= weight_with_lowest_rmse] + + rmses_to_check = norm_rmse.sel(weight=weights_to_check) + rmses_to_check_within_threshold = rmses_to_check.where( + rmses_to_check < 1 + threshold + ).dropna("weight") + + chosen_weight = rmses_to_check_within_threshold["weight"].values.mean() + + logger.debug( + ( + "`use_average_of_zero_biased_omegas_within_threshold` weight selected: " + f"{chosen_weight}" + ) + ) + return chosen_weight + + +def use_omega_distribution( + rmse: xr.DataArray, draws: int, threshold: float = 0.05, **kwargs: Any +) -> xr.DataArray: + """Samples omegas from a distribution (using RMSE). + + Takes the omegas with RMSEs within the threshold percent of omega with + the lowest RMSE, and takes the reciprocal RMSEs of those omegas as the + probabilities of omegas being sampled from multinomial a distribution. + + Args: + rmse: + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + draws: + The number of draws to sample from the distribution of omega values + threshold: + The threshold percent to use for selecting the weight. + kwargs: + Ignores any additional keyword args. + + Returns: + Samples from a distribution of omegas to use for the ARC method. 
+ """ + rmses_in_threshold = rmse.where(rmse < rmse.values.min() + rmse.values.min() * threshold) + reciprocal_rmses_in_threshold = (1 / rmses_in_threshold).fillna(0) + norm_reciprocal_rmses_in_threshold = ( + reciprocal_rmses_in_threshold / reciprocal_rmses_in_threshold.sum() + ) + + omega_draws = xr.DataArray( + np.random.choice( + a=norm_reciprocal_rmses_in_threshold["weight"].values, + size=draws, + p=norm_reciprocal_rmses_in_threshold.values, + ), + coords=[list(range(draws))], + dims=[DimensionConstants.DRAW], + ) + return omega_draws + + +def use_zero_biased_omega_distribution( + rmse: xr.DataArray, draws: int, threshold: float = 0.05, **kwargs: Any +) -> xr.DataArray: + """Samples omegas from a distribution (using RMSE). + + Takes the omegas with RMSEs within the threshold percent of omega with + the lowest RMSE, and takes the reciprocal RMSEs of those omegas as the + probabilities of omegas being sampled from multinomial a distribution. + + Args: + rmse: + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + draws: + The number of draws to sample from the distribution of omega values + threshold: + The threshold percent to use for selecting the weight. + kwargs: + Ignores any additional keyword args. + + Returns: + Samples from a distribution of omegas to use for the ARC method. + """ + norm_rmse = rmse / rmse.min() + + weight_with_lowest_rmse = ( + norm_rmse.where(norm_rmse == norm_rmse.min()).dropna("weight")["weight"].values[0] + ) + weights_to_check = [w for w in norm_rmse["weight"].values if w <= weight_with_lowest_rmse] + + rmses_to_check = norm_rmse.sel(weight=weights_to_check) + rmses_to_check_within_threshold = rmses_to_check.where( + rmses_to_check < 1 + threshold + ).dropna("weight") + + reciprocal_rmses_to_check_within_threshold = (1 / rmses_to_check_within_threshold).fillna( + 0 + ) + norm_reciprocal_rmses_to_check_within_threshold = ( + reciprocal_rmses_to_check_within_threshold + / reciprocal_rmses_to_check_within_threshold.sum() + ) + + omega_draws = xr.DataArray( + np.random.choice( + a=norm_reciprocal_rmses_to_check_within_threshold["weight"].values, + size=draws, + p=norm_reciprocal_rmses_to_check_within_threshold.values, + ), + coords=[list(range(draws))], + dims=[DimensionConstants.DRAW], + ) + return omega_draws + + +def adjusted_zero_biased_omega_distribution( + rmse: xr.DataArray, + draws: int, + seed: int = ArcMethodConstants.DEFAULT_RANDOM_SEED, + **kwargs: Any, +) -> xr.DataArray: + """Samples omegas from a distribution (using RMSE). + + Takes the omegas from the lowest RMSE to zero, and takes the reciprocal + RMSEs of those omegas as the probabilities of omegas being sampled from + multinomial a distribution. + + Args: + rmse: + Array with one dimension, "weight", that contains the tested + omegas as coordinates. The data is the RMSE (Root _Mean_ Square + Error or Root _Median_ Square Error) values. + draws: + The number of draws to sample from the distribution of omega values + seed: + seed to be set for random number generation. + kwargs: + Ignores any additional keyword args. + + Returns: + Samples from a distribution of omegas to use for the ARC method. 
+ """ + np.random.seed(seed) + + norm_rmse = rmse / rmse.min() + + weight_with_lowest_rmse = ( + norm_rmse.where(norm_rmse == norm_rmse.min()).dropna("weight")["weight"].values[0] + ) + weights_to_check = [w for w in norm_rmse["weight"].values if w <= weight_with_lowest_rmse] + + rmses_to_check = norm_rmse.sel(weight=weights_to_check) + + reciprocal_rmses_to_check = (1 / rmses_to_check).fillna(0) + norm_reciprocal_rmses_to_check = ( + reciprocal_rmses_to_check / reciprocal_rmses_to_check.sum() + ) + + omega_draws = xr.DataArray( + np.random.choice( + a=norm_reciprocal_rmses_to_check["weight"].values, + size=draws, + p=norm_reciprocal_rmses_to_check.values, + ), + coords=[list(range(draws))], + dims=[DimensionConstants.DRAW], + ) + return omega_draws diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/models/processing.py b/gbd_2021/disease_burden_forecast_code/nonfatal/models/processing.py new file mode 100644 index 0000000..4c68450 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/models/processing.py @@ -0,0 +1,1724 @@ +"""This module contains all the functions for processing data for use in modeling. + +It is divided into "pre-" and "post-" processing, i.e. functions that are +called before modeling and functions that are called +after modeling. +""" + +from abc import ABC +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import xarray as xr +from fhs_lib_database_interface.lib.constants import ( + AgeConstants, + DimensionConstants, + ScenarioConstants, +) +from fhs_lib_database_interface.lib.query import age, restrictions +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_lib_data_transformation.lib import age_standardize, filter +from fhs_lib_data_transformation.lib.constants import GBDRoundIdConstants, ProcessingConstants +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.draws import mean_of_draws +from fhs_lib_data_transformation.lib.exponentiate_draws import bias_exp_new +from fhs_lib_data_transformation.lib.intercept_shift import ( + mean_intercept_shift, + ordered_draw_intercept_shift, + unordered_draw_intercept_shift, +) +from fhs_lib_data_transformation.lib.resample import resample + +logger = get_logger() + + +class BaseProcessor(ABC): + """A "processor" transforms data into and out of some interesting space. + + It's defined by a pair of functions, pre_process and post_process. Implementors should + implement both. Note that the post_process function isn't a straight inverse of + pre_process, because it takes a second argument that helps it shift the data where it + belongs. + """ + + def pre_process(self, past_data: xr.DataArray) -> xr.DataArray: + """Perform pre-processing: transform past_data in some way determined by the object. + + Args: + past_data (xr.DataArray): The data to transform. + + Returns: + The past_data, transformed. + """ + return None + + def post_process( + self, modeled_data: xr.DataArray, past_data: xr.DataArray + ) -> xr.DataArray: + """Reverse the pre-processing. + + This should operate on modeled_data and use past_data to determine an intercept-shift. + + Args: + modeled_data (xr.DataArray): + Data produced using model.py methods. Assumes that it contains + forecasted years and at least years.past_end. + past_data (xr.DataArray): + Data to use for intercept shift. 
Assumes that it is in the same space produced + by the pre_process step. + + Returns: + modeled_data (xr.DataArray): + The post-processed modeled_data, transformed back to normal space from + wherever the pre_process took it. + """ + return None + + +def apply_intercept_shift( + modeled_data: xr.DataArray, + past_data: xr.DataArray, + intercept_shift: str, + years: YearRange, + shift_from_reference: bool, +) -> xr.DataArray: + """Apply one of the intercept shifts, based on intercept_shift, to modeled_data. + + We "shift" modeled_data so that it aligns with past_data at the last-past year. + + Args: + modeled_data (xr.DataArray): future data to shift. Will be aligned with the + last-past-year data in ``past_data``. + past_data (xr.DataArray): past data, the basis to which to shift. Should overlap with + modeled_data in the last-past-year in ``years``. + intercept_shift (str): the type of shift to perform: "mean", "unordered_draw", or + "ordered_draw". + years (YearRange): Years describing the range of years in the past_data and + modeled_data. Those DataArrays should both contain data for the last-past-year + (middle component). + shift_from_reference (bool): When True, the shift is calculated once from the future + reference scenario and applied to all other scenarios. When False, each scenario + has its own shift applied to match past data. + """ + if intercept_shift == "mean": + return mean_intercept_shift(modeled_data, past_data, years, shift_from_reference) + elif intercept_shift == "unordered_draw": + return unordered_draw_intercept_shift( + modeled_data, past_data, years.past_end, shift_from_reference + ) + elif intercept_shift == "ordered_draw": + return ordered_draw_intercept_shift( + modeled_data, + past_data, + years.past_end, + years.forecast_end, + shift_from_reference, + ) + else: + raise ValueError(f"Unknown intercept shift type {intercept_shift}") + + +class LogProcessor(BaseProcessor): + """A processor that transforms the data to log space and back again. + + Can also remove zero slices, take the mean, and age-standardize, depending on parameters + to the constructor. + """ + + def __init__( + self, + years: YearRange, + gbd_round_id: int, + offset: float = ProcessingConstants.DEFAULT_OFFSET, + remove_zero_slices: bool = False, + no_mean: bool = False, + bias_adjust: bool = False, + intercept_shift: Optional[str] = None, + age_standardize: bool = False, + rescale_age_weights: bool = True, + shift_from_reference: bool = True, + tolerance: float = ProcessingConstants.DEFAULT_PRECISION, + ) -> None: + """Construct a LogProcessor with additional processing as specified by the args. + + Args: + years (YearRange): Forecasting time series year range (used for age-standardize + and intercept shifting). + gbd_round_id (int): Which gbd_round_id the data is from (can be used in age + standardization). + offset (float): Optional. How much to offset the data from zero when taking the + log. + remove_zero_slices (bool): Optional. If True, remove slices of the data that are + all zeros. Default False + no_mean (bool): Optional. If True, the mean will not be taken (note that most + model.py classes expect that the input dependent data will not have draws) + bias_adjust (bool): Optional. Whether or not to perform bias adjustment on the + results coming out of log space. Generally ``True`` is the recommended option + intercept_shift (str | None): Optional. What type of intercept shift to perform, + if any. Options are "mean", "unordered_draw", and "ordered_draw". 
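# A minimal sketch of the "mean" intercept-shift idea handled by
# ``apply_intercept_shift`` above: the modeled series is moved by its gap from
# the past data in the last past year. Values are hypothetical, and the real
# helpers also handle draws, scenarios, and reference-scenario shifting.
import xarray as xr

modeled = xr.DataArray(
    [0.27, 0.29, 0.31],
    dims="year_id",
    coords={"year_id": [2022, 2023, 2024]},  # last past year + two forecast years
)
past_last = xr.DataArray(0.30)  # observed value in 2022

shift = past_last - modeled.sel(year_id=2022)
shifted = modeled + shift
# approximately [0.30, 0.32, 0.34] -- now matches the past data in 2022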
+ age_standardize (bool): Optional. Whether to age-standardize data for modeling + purposes (results are age-specific). If you want to have age-standardized data + modeled and the results also be age-standardized, then you should + age-standardize the data outside of the processor. + rescale_age_weights (bool): Whether to rescale the age weights across available + ages when age-standardizing. + shift_from_reference (bool): Optional. Whether to calculate the differences used + for the intercept shift with only the reference scenario (True), or calculate + scenario-specific differences (False). Defaults to True. + tolerance (float): tolerance value to supply to the "closeness" check. Defaults to + ``ProcessingConstants.DEFAULT_PRECISION`` + """ + self.years = years + self.offset = offset + self.gbd_round_id = gbd_round_id + self.remove_zero_slices = remove_zero_slices + self.no_mean = no_mean + self.bias_adjust = bias_adjust + self.intercept_shift_type = intercept_shift + self.age_standardize = age_standardize + self.rescale_age_weights = rescale_age_weights + self.zero_slices_dict = {} + self.age_ratio = xr.DataArray() + self.shift_from_reference = shift_from_reference + self.tolerance = tolerance + + def __eq__(self, other: Any) -> bool: + """True if the other object matches this one in a few aspects.""" + if type(self) != type(other): + return False + elif self.offset != other.offset: + return False + elif self.years != other.years: + return False + elif self.gbd_round_id != other.gbd_round_id: + return False + else: + return True + + # Note: the docstring is given in the interface declaration, BaseProcessor. + def pre_process(self, past_data: xr.DataArray, *args: Any) -> xr.DataArray: + """Transform past_data into log space. (Also do several other things).""" + if len(args) != 0: + logger.warning( + "Warning: passed extra args to LogProcessor.pre_process. They will be ignored." + ) + + if self.remove_zero_slices: + past_data, self.zero_slices_dict = _remove_all_zero_slices( + data=past_data, + dims=[DimensionConstants.AGE_GROUP_ID, DimensionConstants.SEX_ID], + tolerance=self.tolerance, + ) + if not self.no_mean: + past_data = mean_of_draws(past_data) + + past_data = log_with_offset(past_data, self.offset) + + if self.age_standardize: + age_specific_data = past_data.copy() + past_data = _get_weights_and_age_standardize( + past_data, self.gbd_round_id, rescale=self.rescale_age_weights + ) + # take the difference of age-specific and age-standardized data in last past year + self.age_ratio = ( + age_specific_data + - past_data.sel(age_group_id=AgeConstants.STANDARDIZED_AGE_GROUP_ID, drop=True) + ).sel(year_id=self.years.past_end) + + return past_data + + def post_process( + self, modeled_data: xr.DataArray, past_data: xr.DataArray + ) -> xr.DataArray: + """Transform modeled_data from log into linear space. And intercept_shift if requested. + + past_data is expected to be the same as what was passed to the corresponding + pre_process, and should still be in linear space. 
+ """ + if self.age_standardize: + modeled_data = ( + modeled_data.sel( + age_group_id=AgeConstants.STANDARDIZED_AGE_GROUP_ID, drop=True + ) + + self.age_ratio + ) + + if self.remove_zero_slices: + past_data, _ = _remove_all_zero_slices( + data=past_data, + dims=[DimensionConstants.AGE_GROUP_ID, DimensionConstants.SEX_ID], + tolerance=self.tolerance, + ) + + if self.intercept_shift_type == "mean": + past_data = mean_of_draws(past_data) + + modeled_data = invlog_with_offset( + modeled_data, self.offset, bias_adjust=self.bias_adjust + ) + + if ( + self.intercept_shift_type == "mean" + or self.intercept_shift_type == "unordered_draw" + or self.intercept_shift_type == "ordered_draw" + ): + modeled_data = self.intercept_shift( + modeled_data=modeled_data, + past_data=past_data, + offset=self.offset, + intercept_shift=self.intercept_shift_type, + years=self.years, + shift_from_reference=self.shift_from_reference, + ) + + if self.zero_slices_dict: + return _add_all_zero_slices(modeled_data, self.zero_slices_dict) + + return modeled_data + + @classmethod + def intercept_shift( + cls, + modeled_data: xr.DataArray, + past_data: xr.DataArray, + intercept_shift: str, + years: YearRange, + shift_from_reference: bool, + offset: float = ProcessingConstants.DEFAULT_OFFSET, + ) -> xr.DataArray: + """Move past and modeled data to log space, intercept-shift, and translate back. + + Performs the intercept-shift with the type given in ``intercept_shift``, using + ``years`` and ``shift_from_reference`` as parameters (see docs for + ``apply_intercept_shift``). + + Args: + modeled_data (xr.DataArray): future data to shift. Will be aligned with the + last-past-year data in ``past_data``. + past_data (xr.DataArray): past data, the basis to which to shift. Should overlap + with modeled_data in the last-past-year in ``years``. + offset (float): Offset for the log transform (see ``log_with_offset``) and inverse + transform. + intercept_shift (str): Type of intercept-shift to perform. See + ``apply_intercept_shift`` for the options. + years (YearRange): Years describing the range of years in the past_data and + modeled_data. Those DataArrays should both contain data for the last-past-year + (middle component). + shift_from_reference (bool): When True, the shift is calculated once from the + future reference scenario and applied to all other scenarios. When False, each + scenario has its own shift applied to match past data. + """ + past_data = log_with_offset(past_data, offset) + modeled_data = log_with_offset(modeled_data, offset) + modeled_data = apply_intercept_shift( + modeled_data, + past_data, + intercept_shift, + years, + shift_from_reference, + ) + modeled_data = invlog_with_offset(modeled_data, offset, bias_adjust=False) + return modeled_data + + +def intercept_shift_in_log( + modeled_data: xr.DataArray, + past_data: xr.DataArray, + intercept_shift: str, + years: YearRange, + shift_from_reference: bool, + offset: float = ProcessingConstants.DEFAULT_OFFSET, +) -> xr.DataArray: + """See LogProcessor.intercept_shift.""" + return LogProcessor.intercept_shift( + modeled_data=modeled_data, + past_data=past_data, + offset=offset, + intercept_shift=intercept_shift, + years=years, + shift_from_reference=shift_from_reference, + ) + + +class LogitProcessor(BaseProcessor): + """A Processor that transforms the data to logit space and back again. + + Can also remove zero slices, take the mean, and age-standardize, depending on parameters + to the constructor. 
+ """ + + def __init__( + self, + years: YearRange, + gbd_round_id: int, + offset: float = ProcessingConstants.DEFAULT_OFFSET, + remove_zero_slices: bool = False, + no_mean: bool = False, + bias_adjust: bool = False, + intercept_shift: Optional[str] = None, + age_standardize: bool = False, + rescale_age_weights: bool = True, + shift_from_reference: bool = True, + tolerance: float = ProcessingConstants.DEFAULT_PRECISION, + ) -> None: + """Construct a LogitProcessor with additional processing as specified by the args. + + Args: + years (YearRange): Forecasting time series year range (used for age-standardize + and intercept shifting). + gbd_round_id (int): Which gbd_round_id the data is from (can be used in age + standardization) + offset (float): Optional. How much to offset the data from zero when taking the + logit. + remove_zero_slices (bool): Optional. If True, remove slices of the data that are + all zeros. Default False + no_mean (bool): Optional. If True, the mean will not be taken (note that most + model.py classes expect that the input dependent data will not have draws) + bias_adjust (bool): Optional. Whether or not to perform bias adjustment on the + results coming out of logit space. Generally ``True`` is the recommended option + intercept_shift (str | None): Optional. What type of intercept shift to perform, + if any. Options are "mean", "unordered_draw", and "ordered_draw". + age_standardize (bool): Optional. Whether to age-standardize data for modeling + purposes (results are age-specific). If you want to have age-standardized data + modeled and the results also be age-standardized, then you should + age-standardize the data outside of the processor. + rescale_age_weights (bool): Whether to rescale the age weights across available + ages when age-standardizing. + shift_from_reference (bool): Optional. Whether to calculate the differences used + for the intercept shift with only the reference scenario (True), or calculate + scenario-specific differences (False). Defaults to True. + tolerance (float): tolerance value to supply to the "closeness" check. Defaults to + ``ProcessingConstants.DEFAULT_PRECISION`` + """ + self.years = years + self.offset = offset + self.gbd_round_id = gbd_round_id + self.remove_zero_slices = remove_zero_slices + self.no_mean = no_mean + self.bias_adjust = bias_adjust + self.intercept_shift_type = intercept_shift + self.age_standardize = age_standardize + self.rescale_age_weights = rescale_age_weights + self.zero_slices_dict = {} + self.age_ratio = xr.DataArray() + self.shift_from_reference = shift_from_reference + self.tolerance = tolerance + + def __eq__(self, other: Any) -> bool: + """True if the other object matches this one in a few aspects.""" + if type(self) != type(other): + return False + elif self.offset != other.offset: + return False + elif self.years != other.years: + return False + elif self.gbd_round_id != other.gbd_round_id: + return False + else: + return True + + def pre_process(self, past_data: xr.DataArray, *args: Any) -> xr.DataArray: + """Transform past_data into logit space. (Also do several other things).""" + if len(args) != 0: + logger.warning( + "Warning: passed extra args to LogitProcessor.pre_process. " + "They will be ignored." 
+ ) + + if self.remove_zero_slices: + past_data, self.zero_slices_dict = _remove_all_zero_slices( + data=past_data, + dims=[DimensionConstants.AGE_GROUP_ID, DimensionConstants.SEX_ID], + tolerance=self.tolerance, + ) + if not self.no_mean: + past_data = mean_of_draws(past_data) + + past_data = logit_with_offset(past_data, self.offset) + + if self.age_standardize: + age_specific_data = past_data.copy() + past_data = _get_weights_and_age_standardize( + past_data, self.gbd_round_id, rescale=self.rescale_age_weights + ) + # take the difference of age-specific and age-standardized data in last past year + self.age_ratio = ( + age_specific_data + - past_data.sel(age_group_id=AgeConstants.STANDARDIZED_AGE_GROUP_ID, drop=True) + ).sel(year_id=self.years.past_end) + + return past_data + + def post_process( + self, modeled_data: xr.DataArray, past_data: xr.DataArray + ) -> xr.DataArray: + """Transform modeled_data from logit into linear space, intercept_shift if requested. + + past_data is expected to be the same as what was passed to the corresponding + pre_process, and should still be in linear space. + """ + if self.age_standardize: + modeled_data = ( + modeled_data.sel( + age_group_id=AgeConstants.STANDARDIZED_AGE_GROUP_ID, drop=True + ) + + self.age_ratio + ) + + if self.remove_zero_slices: + past_data, _ = _remove_all_zero_slices( + data=past_data, + dims=[DimensionConstants.AGE_GROUP_ID, DimensionConstants.SEX_ID], + tolerance=self.tolerance, + ) + + if self.intercept_shift_type == "mean": + past_data = mean_of_draws(past_data) + + modeled_data = invlogit_with_offset( + modeled_data, self.offset, bias_adjust=self.bias_adjust + ) + + if ( + self.intercept_shift_type == "mean" + or self.intercept_shift_type == "unordered_draw" + or self.intercept_shift_type == "ordered_draw" + ): + modeled_data = self.intercept_shift( + modeled_data=modeled_data, + past_data=past_data, + offset=self.offset, + intercept_shift=self.intercept_shift_type, + years=self.years, + shift_from_reference=self.shift_from_reference, + ) + + if self.zero_slices_dict: + return _add_all_zero_slices(modeled_data, self.zero_slices_dict) + + return modeled_data + + @classmethod + def intercept_shift( + cls, + modeled_data: xr.DataArray, + past_data: xr.DataArray, + intercept_shift: str, + years: YearRange, + shift_from_reference: bool, + offset: float = ProcessingConstants.DEFAULT_OFFSET, + ) -> xr.DataArray: + """Move past and modeled data to logit space, intercept-shift, and translate back. + + Performs the intercept-shift with the type given in ``intercept_shift``, using + ``years`` and ``shift_from_reference`` as parameters (see docs for + ``apply_intercept_shift``). + + Translation back from + + Args: + modeled_data (xr.DataArray): future data to shift. Will be aligned with the + last-past-year data in ``past_data``. + past_data (xr.DataArray): past data, the basis to which to shift. Should overlap + with modeled_data in the last-past-year in ``years``. + offset (float): Offset for the log transform (see ``logit_with_offset``) and + inverse transform. + intercept_shift (str): Type of intercept-shift to perform. See + ``apply_intercept_shift`` for the options. + years (YearRange): Years describing the range of years in the past_data and + modeled_data. Those DataArrays should both contain data for the last-past-year + (middle component). + shift_from_reference (bool): When True, the shift is calculated once from the + future reference scenario and applied to all other scenarios. 
When False, each
+                scenario has its own shift applied to match past data.
+        """
+        past_data = logit_with_offset(past_data, offset)
+        modeled_data = logit_with_offset(modeled_data, offset)
+        modeled_data = apply_intercept_shift(
+            modeled_data,
+            past_data,
+            intercept_shift,
+            years,
+            shift_from_reference,
+        )
+        modeled_data = invlogit_with_offset(modeled_data, offset, bias_adjust=False)
+        return modeled_data
+
+
+def intercept_shift_in_logit(
+    modeled_data: xr.DataArray,
+    past_data: xr.DataArray,
+    intercept_shift: str,
+    years: YearRange,
+    shift_from_reference: bool,
+    offset: float = ProcessingConstants.DEFAULT_OFFSET,
+) -> xr.DataArray:
+    """See LogitProcessor.intercept_shift."""
+    return LogitProcessor.intercept_shift(
+        modeled_data=modeled_data,
+        past_data=past_data,
+        offset=offset,
+        intercept_shift=intercept_shift,
+        years=years,
+        shift_from_reference=shift_from_reference,
+    )
+
+
+class NoTransformProcessor(BaseProcessor):
+    """A Processor that does not apply any space transformation.
+
+    It can still remove zero slices, take the mean, and age-standardize, depending on
+    parameters to the constructor.
+    """
+
+    def __init__(
+        self,
+        years: YearRange,
+        gbd_round_id: int,
+        offset: float = ProcessingConstants.DEFAULT_OFFSET,
+        remove_zero_slices: bool = False,
+        no_mean: bool = False,
+        bias_adjust: bool = False,
+        intercept_shift: Optional[str] = None,
+        age_standardize: bool = False,
+        rescale_age_weights: bool = True,
+        shift_from_reference: bool = True,
+        tolerance: float = ProcessingConstants.DEFAULT_PRECISION,
+        **kwargs: Any,
+    ) -> None:
+        """Construct a NoTransformProcessor with additional processing as specified by the args.
+
+        Args:
+            years (YearRange): Forecasting time series year range (used for age-standardize
+                and intercept shifting).
+            gbd_round_id (int): Which gbd_round_id the data is from (can be used in age
+                standardization)
+            offset (float): Optional. Offset from zero, accepted for interface compatibility
+                with the transforming processors; this processor applies no log/logit
+                transform.
+            remove_zero_slices (bool): Optional. If True, remove slices of the data that are
+                all zeros. Default False
+            no_mean (bool): Optional. If True, the mean will not be taken (note that most
+                model.py classes expect that the input dependent data will not have draws)
+            bias_adjust (bool): Optional. Accepted for interface compatibility; since no
+                transform is applied, no bias adjustment is performed.
+            intercept_shift (str | None): Optional. What type of intercept shift to perform,
+                if any. Options are "mean", "unordered_draw", and "ordered_draw".
+            age_standardize (bool): Optional. Whether to age-standardize data for modeling
+                purposes (results are age-specific). If you want to have age-standardized data
+                modeled and the results also be age-standardized, then you should
+                age-standardize the data outside of the processor.
+            rescale_age_weights (bool): Whether to rescale the age weights across available
+                ages when age-standardizing.
+            shift_from_reference (bool): Optional. Whether to calculate the differences used
+                for the intercept shift with only the reference scenario (True), or calculate
+                scenario-specific differences (False). Defaults to True.
+            tolerance (float): tolerance value to supply to the "closeness" check. Defaults to
+                ``ProcessingConstants.DEFAULT_PRECISION``
+            kwargs: Ignored.
+ """ + self.years = years + self.offset = offset + self.gbd_round_id = gbd_round_id + self.remove_zero_slices = remove_zero_slices + self.no_mean = no_mean + self.bias_adjust = bias_adjust + self.intercept_shift_type = intercept_shift + self.age_standardize = age_standardize + self.rescale_age_weights = rescale_age_weights + self.zero_slices_dict = {} + self.age_ratio = xr.DataArray() + self.shift_from_reference = shift_from_reference + self.tolerance = tolerance + + def __eq__(self, other: Any) -> bool: + """True if the other object matches this one in a few aspects.""" + if type(self) != type(other): + return False + elif self.offset != other.offset: + return False + elif self.years != other.years: + return False + elif self.gbd_round_id != other.gbd_round_id: + return False + else: + return True + + def pre_process(self, past_data: xr.DataArray, *args: Any) -> xr.DataArray: + """Just do the "several other things" that are expected of Processors.""" + if len(args) != 0: + logger.warning( + "Warning: passed extra args to NoTransform.pre_process. They will be ignored." + ) + + if self.remove_zero_slices: + past_data, self.zero_slices_dict = _remove_all_zero_slices( + data=past_data, + dims=[DimensionConstants.AGE_GROUP_ID, DimensionConstants.SEX_ID], + tolerance=self.tolerance, + ) + + if not self.no_mean: + past_data = mean_of_draws(past_data) + + if self.age_standardize: + age_specific_data = past_data.copy() + past_data = _get_weights_and_age_standardize( + past_data, self.gbd_round_id, rescale=self.rescale_age_weights + ) + self.age_ratio = ( + age_specific_data + - past_data.sel(age_group_id=AgeConstants.STANDARDIZED_AGE_GROUP_ID, drop=True) + ).sel(year_id=self.years.past_end) + + return past_data + + def post_process( + self, modeled_data: xr.DataArray, past_data: xr.DataArray + ) -> xr.DataArray: + """Do no-op transformation, but intercept_shift if requested. + + past_data is expected to be the same as what was passed to the corresponding + pre_process, and should still be in linear space. + """ + if self.age_standardize: + modeled_data = ( + modeled_data.sel( + age_group_id=AgeConstants.STANDARDIZED_AGE_GROUP_ID, drop=True + ) + + self.age_ratio + ) + + if self.remove_zero_slices: + past_data, _ = _remove_all_zero_slices( + data=past_data, + dims=[DimensionConstants.AGE_GROUP_ID, DimensionConstants.SEX_ID], + tolerance=self.tolerance, + ) + else: + past_data = past_data + + if ( + self.intercept_shift_type == "mean" + or self.intercept_shift_type == "unordered_draw" + or self.intercept_shift_type == "ordered_draw" + ): + modeled_data = self.intercept_shift( + modeled_data=modeled_data, + past_data=past_data, + offset=self.offset, + intercept_shift=self.intercept_shift_type, + years=self.years, + shift_from_reference=self.shift_from_reference, + ) + + if self.zero_slices_dict: + return _add_all_zero_slices(modeled_data, self.zero_slices_dict) + + return modeled_data + + @classmethod + def intercept_shift( + cls, + modeled_data: xr.DataArray, + past_data: xr.DataArray, + intercept_shift: str, + years: YearRange, + shift_from_reference: bool, + offset: float = ProcessingConstants.DEFAULT_OFFSET, + ) -> xr.DataArray: + """Intercept-shift future to past, in this "null" space, i.e. 
just intercept-shift.""" + return apply_intercept_shift( + modeled_data, + past_data, + intercept_shift, + years, + shift_from_reference, + ) + + +def logit_with_offset( + data: xr.DataArray, offset: float = ProcessingConstants.DEFAULT_OFFSET +) -> xr.DataArray: + """Apply the logit transformation with an offset adjustment. + + We use an ofset to enforce the range of the logit function while maintaining rotational + symmetry about (0.5, 0). + + Args: + data (xr.DataArray): Data to transform. + offset (float): Amount to offset away from 0 and 1 before logit transform. + + Returns: + logit_data (xr.DataArray): + Data transformed into logit space with offset + + Raises: + RuntimeError: If there are Infs/Nans after transformation + """ + off = 1 - offset + norm = 0.5 * offset if offset > 0 else 0 + offset_data = data * off + norm + logit_data = np.log(offset_data / (1 - offset_data)) + + # Verify transformation, ie no infinite values (except missings) + msg = "There are Infs/Nans after transformation!" + if not np.isfinite(logit_data).all(): + logger.error(msg) + raise RuntimeError(msg) + + return logit_data + + +def bias_correct_invlogit(data: xr.DataArray) -> xr.DataArray: + """Transform out of logit space and perform an adjustment for bias. + + The adjustment in bias is due to difference in mean of logit draws vs logit of mean draws. + + Pre-conditions: + * ``draw`` dimension exists in ``data`` + + Args: + data (xr.DataArray): + Data to perform inverse logit transformation with scaling on + + Returns: + xr.DataArray: + The data taken out of logit space and adjusted with scaling methods + to account for bias (same dims/coords as data) + """ + expit_data = np.exp(data) / (1 + np.exp(data)) + mean_data = data.mean(DimensionConstants.DRAW) + inv_mean = np.exp(mean_data) / (1 + np.exp(mean_data)) + adj_expit_data = expit_data - (expit_data.mean(DimensionConstants.DRAW) - inv_mean) + + return adj_expit_data + + +def invlogit_with_offset( + data: xr.DataArray, + offset: float = ProcessingConstants.DEFAULT_OFFSET, + bias_adjust: bool = True, +) -> xr.DataArray: + """Reverse logit transform with inherent offset adjustments. + + Recall that we do a logit transform with an adjustment to squeeze 0s and 1s to fit logit + assumptions while maintaining logit centering at .5. + + Optionally use `bias_correct_invlogit` to obtain the expit data before correcting the + offset instead of plain logit function. + + Note: If negative values exist after back-transformation, then they are filled with zero. + + Args: + data (xr.DataArray): + Data to inverse logit transform (expit). + offset (float): + Amount that was offset away from 0 and 1 in the logit transform. + bias_adjust (bool): + Whether to apply bias_correct_invlogit instead of "expit". + + Returns: + expit_data (xr.DataArray): + Data transformed back into normal space out of logit space with no offset. + + Raises: + RuntimeError: If there are negative values after back-transformation. + """ + off = 1 - offset + norm = 0.5 * offset if offset > 0 else 0 + + if bias_adjust: + expit_data = bias_correct_invlogit(data) + else: + expit_data = np.exp(data) / (1 + np.exp(data)) + reset_data = ((expit_data - norm) / off).clip(min=0, max=1) + + msg = "There are negatives after back-transformation!" 
+ if (reset_data < 0).any(): + logger.error(msg) + raise RuntimeError(msg) + + return reset_data + + +def log_with_offset( + data: xr.DataArray, offset: float = ProcessingConstants.DEFAULT_OFFSET +) -> xr.DataArray: + """Take the log of data, applying a slight offset to control for the potential of zeros. + + Args: + data (xr.DataArray): + Data to log transform + offset (float): + Amount to offset away from 0 before the log transform + + Returns: + log_data (xr.DataArray): + Data transformed into log space with offset + + Raises: + RuntimeError: If there are infs/NaNs after transformation. + """ + log_data = np.log(data + offset) + + # Verify transformation, ie no infinite values (except missings) + msg = "There are Infs/Nans after transformation!" + if not np.isfinite(log_data).all(): + logger.error(msg) + raise RuntimeError(msg) + + return log_data + + +def log_with_caps(data: xr.DataArray, log_min: float, log_max: float) -> xr.DataArray: + """Take the log of data, pinning the data to within the [min,max] range. + + Args: + data (xr.DataArray): + Data to log transform + log_min (float): + Log minimum to clip data to + log_max (float): + Log maximum to clip data to + + Returns: + log_data (xr.DataArray): + Data transformed into log space with offset + + Raises: + RuntimeError: If there are infs/NaNs after transformation. + """ + log_data = np.log(data).clip(min=log_min, max=log_max) + + # Verify transformation, ie no infinite values (except missings) + msg = "There are Infs/Nans after transformation!" + if not np.isfinite(log_data).all(): + logger.error(msg) + raise RuntimeError(msg) + + return log_data + + +def invlog_with_offset( + data: xr.DataArray, + offset: float = ProcessingConstants.DEFAULT_OFFSET, + bias_adjust: bool = True, +) -> xr.DataArray: + """Undo a log transform & subtract the offset that was added in the log preprocessing. + + With bias_adjust=True, use the `bias_exp_new` function to adjust the results such that the + mean of the exponentiated distribution is equal to the exponentiated expected value of the + log distribution. The adjustment assumes that the log distribution is normally distributed. + + Args: + data (xr.DataArray): + Data to inverse log transform (exp) + offset (float): + Amount that was offset away from 0 in the log transform + bias_adjust: If true, apply the ``bias_exp_new`` function instead of ``exp``. + + Returns: + exp_data (xr.DataArray): + Data transformed out of log space with no offset + + Raises: + RuntimeError: If there are negatives after back-transformation + """ + if bias_adjust: + exp_data = (bias_exp_new(data) - offset).clip(min=0) + else: + exp_data = (np.exp(data) - offset).clip(min=0) + + msg = "There are negatives after back-transformation!" + if (exp_data < 0).any(): + logger.error(msg) + raise RuntimeError(msg) + + return exp_data + + +def _get_weights_and_age_standardize( + da: xr.DataArray, gbd_round_id: int, rescale: bool = True +) -> xr.DataArray: + """Age-standardize an xarray. + + We drop weights to ages in the array and renormalize weights to sum to 1. + + Args: + da: Data to standardize. + gbd_round_id: the round ID the age weights should come from. + rescale: Whether to renormalize across the age groups available. + + Returns: + The input ``da``, standardized. 
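+
+    Note (for orientation, not a spec): with age weights ``w_a`` renormalized to sum
+    to 1 over the age groups present in ``da`` (when ``rescale`` is True), the
+    age-standardized value is ``sum_a(w_a * x_a)``, computed independently for every
+    other coordinate.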
+ """ + age_weights = age.get_age_weights(gbd_round_id) + age_weights = age_weights.loc[ + age_weights.age_group_id.isin(da.coords["age_group_id"].values), : + ] + + age_weights = xr.DataArray( + age_weights.age_weight.values, + dims=["age_group_id"], + coords={"age_group_id": age_weights.age_group_id.values}, + ) + + age_std_da = age_standardize.age_standardize(da, age_weights, rescale=rescale) + age_std_da.name = da.name + + return age_std_da + + +def subset_to_reference( + data: xr.DataArray, + draws: Optional[int], + year_ids: Optional[List[int]] = None, +) -> xr.DataArray: + """Filter to the given years and draws, and the reference scenario. + + Args: + data (xr.DataArray): + The dependent variable data that has not been filtered to relevant + coordinates + draws (int | None): + Either the number of draws to resample to or ``None``, which + indicates that no draw-resampling should happen. + year_ids (list[int] | None): + Optional. The coords of the ``year_id`` dim to filter to. If ``None``, then + ``year_id`` dim's coords won't be filtered. Defaults to ``None``. + + Returns: + xr.DataArray: + cleaned/filtered dependent variable data + """ + data = get_dataarray_from_dataset(data) + + if "scenario" in data.dims: + data = data.sel(scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD, drop=True) + + if year_ids is not None: + data_time_slice = data.sel(year_id=year_ids) + else: + data_time_slice = data + + if draws: + if "draw" in data_time_slice.dims: + data_time_slice = resample(data_time_slice, draws) + else: + data_time_slice = expand_dimensions(data_time_slice, draw=np.arange(0, draws)) + + return data_time_slice + + +def clean_cause_data( + data: xr.DataArray, + stage: str, + acause: str, + draws: Optional[int], + gbd_round_id: int, + year_ids: Optional[Iterable[int]] = None, + national_only: bool = False, +) -> xr.DataArray: + """Filter the dependent variable data to only relevant most detailed coordinates. + + Also filters out age and sex restrictions. + + Args: + data (xr.DataArray): + The dependent variable data that has not been filtered to relevant + coordinates + stage (str): + The GBD/FHS stage that the dependent variable is, e.g. ``pi_ratio`` + acause (str): + The GBD cause of the dependent variable, e.g. ``cvd_ihd``. + draws (int | None): + Either the number of draws to resample to or ``None``, which + indicates that no draw-resampling should happen. + gbd_round_id (int): + Numeric ID for the GBD round + year_ids (list[int] | None): + Optional. The coords of the ``year_id`` dim to filter to. If ``None``, then + ``year_id`` dim's coords won't be filtered. Defaults to ``None``. + national_only (bool): Whether to include subnational locations, or to include only + nations. + + Returns: + xr.DataArray: + cleaned/filtered dependent variable data + """ + data = get_dataarray_from_dataset(data) + + if year_ids is not None: + data_time_slice = data.sel(year_id=year_ids) + else: + data_time_slice = data + + if draws is not None: + data_time_slice = resample(data_time_slice, draws) + + cleaned_data, warning_msg = _filter_relevant_coords( + data_time_slice, acause, stage, gbd_round_id, national_only + ) + return cleaned_data, warning_msg + + +def clean_covariate_data( + past: xr.DataArray, + forecast: xr.DataArray, + dep_var: xr.DataArray, + years: YearRange, + draws: int, + gbd_round_id: int, + national_only: bool = False, +) -> xr.DataArray: + """Combines past and forecasted covariate data into one array. + + Filters dims to only relevant most-detailed coords. 
+ + Raises IndexError if the past and forecasted data dims don't line up (after the past is + broadcast across the scenario dim). + + Args: + past (xr.DataArray): Past covariate data + forecast (xr.DataArray): + Forecasted covariate data + dep_var (xr.DataArray): + Dependent variable data that has already been filtered down to relevant + coordinates. Relevant coordinates will be inferred from this array. + years (YearRange): + Forecasting timeseries + draws (int): + Number of draws to include. + gbd_round_id (int): + Numeric ID for the GBD round. + national_only (bool): Whether to include subnational locations, or to include only + nations. + + Returns: + xr.DataArray: Cleaned covariate data. + + Raises: + IndexError: If + * the past and forecasted data coords don't line up. + * the covariate data is missing coordinates from a dim it shares with the + dependent variable. + """ + cov_name = forecast.name + only_forecast = forecast.sel(year_id=years.forecast_years) + + resampled_forecast = ensure_draws(only_forecast, draws) + stripped_forecast = strip_single_coord_dims(resampled_forecast) + + only_past = past.sel(year_id=years.past_years) + + resampled_past = ensure_draws(only_past, draws) + stripped_past = strip_single_coord_dims(resampled_past) + + # expand dims if scenario dim in future but not past + data = concat_past_and_future(stripped_past, stripped_forecast) + + if gbd_round_id in GBDRoundIdConstants.NO_MOST_DETAILED_IDS: + most_detailed_data = data + else: + most_detailed_data = filter.make_most_detailed(data, gbd_round_id, national_only) + + shared_dims = list( + set(most_detailed_data.dims) & set(dep_var.dims) - {DimensionConstants.YEAR_ID} + ) + expected_coords = {dim: list(dep_var[dim].values) for dim in shared_dims} + # Don't modify draws: We resampled the `most_detailed_data` but not the `dep_var`. + if DimensionConstants.DRAW in expected_coords: + del expected_coords[DimensionConstants.DRAW] + try: + cleaned_data = most_detailed_data.sel(**expected_coords) + except KeyError: + err_msg = f"`{cov_name}` is missing expected coords" + logger.error(err_msg) + raise IndexError(err_msg) + + if not np.isfinite(cleaned_data).all(): + err_msg = f"`{cov_name}` past and forecast coords don't line up`" + logger.error(err_msg) + raise IndexError(err_msg) + + return cleaned_data.rename(cov_name) + + +def ensure_draws(da: xr.DataArray, draws: Optional[int]) -> xr.DataArray: + """Resample ``da``, ensuring that it winds up with a ``draw`` dimension. + + This acts just like ``resample`` except that it adds the ``draw`` dimension if missing. + + Args: + da: the data to resample, if it has a ``draws`` dimension. + draws: The desired number of draws. + + Returns: + The resampled data + """ + if DimensionConstants.DRAW in da.dims: + return resample(da, draws) + elif draws: + return expand_dimensions(da, draw=list(range(draws))) + else: + return da + + +def mad_truncate( + da: xr.DataArray, + median_dims: Iterable[str], + pct_coverage: float, + max_multiplier: float, + multiplier_step: float, +) -> xr.DataArray: + """Truncate values based on the median absolute deviation. + + Calculates the median absolute deviation (MAD) across coordinates for each median_dims, + then finds floor and ceiling values based on a multiplier for the MAD that covers + pct_coverage of the data (only one multiplier value across whole dataarray. This could + change to separate multipliers for each of the median_dims coordinates, but is based on the + method from the legacy code right now). 
Data is truncated to be between the floors and + ceilings. + + This is used as a more flexible floor/ceiling truncation method than hard cutoffs. It is + used in the `indicator_from_ratio.py` script to control for extreme values, as dividing to + obtain the indicator from the ratio can sometimes lead to very high values if the + denominator is low and/or the numerator is high. The primary concern is the MI and MP + ratios, which have demonstrated extreme value problems stemming from the division of M by + MI or MP. However, it is a flexible method for truncating that could be used in other + situations outside of `indicator_from_ratio.py` as well. + + Note: + The MAD calculated by _mad is multiplied by a default scale of 1.4826 + for consistency with `scipy.stats.median_absolute_deviation`. + + Args: + da (xr.DataArray): + The array to calculate the multiplier for. + median_dims (list[str]): + List of dims to calculate the MAD for. E.g. if median_dims is ['age_group_id'], + then median will return a dataarray with only age_group_id dimension). + pct_coverage (float): + Percent of data in array to have between the floor and ceiling, e.g. 0.975. + max_multiplier (float): + The maximum multiplier for the MAD that is acceptable. If it is too small, then + pct_coverage of the data might not be between median +/- multiplier * MAD. + multiplier_step (float): + The amount to test multipliers by (starting value is zero + step). + + Returns: + xr.DataArray: + da truncated to be between ceiling and floor based on the MAD + """ + + def _mad_truncate( + da: xr.DataArray, + median_dims: Iterable[str], + pct_coverage: float, + max_multiplier: float, + multiplier_step: float, + ) -> xr.DataArray: + """Helper function for `mad_truncate`.""" + dims_to_median = set(da.dims).difference(set(median_dims)) + mad_da = _mad(da, median_dims) + median_da = da.median(dim=dims_to_median) + multiplier = _calculate_mad_multiplier( + da, + mad_da, + median_da, + pct_coverage, + max_multiplier=max_multiplier, + step=multiplier_step, + ) + ceiling = median_da + (multiplier * mad_da) + floor = median_da - (multiplier * mad_da) + truncated_da = da.where(da < ceiling, other=ceiling) + truncated_da = truncated_da.where(da > floor, other=floor) + return truncated_da + + if "scenario" in da.dims: + truncated_scenarios_list = [] + for scenario in da[DimensionConstants.SCENARIO].values: + sub_da = da.sel(scenario=scenario) + truncated_sub_da = _mad_truncate( + sub_da, median_dims, pct_coverage, max_multiplier, multiplier_step + ) + truncated_scenarios_list.append(truncated_sub_da) + all_scenarios_truncated = xr.concat( + truncated_scenarios_list, dim=DimensionConstants.SCENARIO + ) + else: + all_scenarios_truncated = _mad_truncate( + da, median_dims, pct_coverage, max_multiplier, multiplier_step + ) + + all_scenarios_truncated = all_scenarios_truncated.transpose(*da.dims) + + return all_scenarios_truncated + + +def strip_single_coord_dims(da: xr.DataArray) -> xr.DataArray: + """Strip off single coordinate dimensions and point coordinates. + + Args: + da (xr.DataArray): + The array to strip. + + Returns: + xr.DataArray: + Array without single coord dims or point coords, but without any other changes. 
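+
+    Example (illustrative, with toy data)::
+
+        da = xr.DataArray(
+            [[1.0, 2.0]],
+            dims=["sex_id", "year_id"],
+            coords={"sex_id": [3], "year_id": [2022, 2023], "measure": "prevalence"},
+        )
+        strip_single_coord_dims(da).dims  # -> ("year_id",)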
+ """ + stripped_da = da.copy() + for dim in list(da.coords): + if dim not in da.dims: + # dim is a point coord + stripped_da = stripped_da.drop_vars(dim) + elif len(da[dim].values) == 1: + # dim has only one coord + stripped_da = stripped_da.sel({dim: da[dim].values[0]}, drop=True) + return stripped_da + + +def expand_single_coord_dims(new_da: xr.DataArray, ref_da: xr.DataArray) -> xr.DataArray: + """Expand dataarray to include point coords and single coord dims per the "ref" dataarray. + + Args: + new_da (xr.DataArray): + The dataarray to expand + ref_da (xr.DataArray): + The dataarray to infer point coords and/or single coord dims from + + Returns: + xr.DataArray: + The expanded copy of ``new_da`` dataarray + """ + expanded_da = new_da.copy() + for dim in list(ref_da.coords): + if dim not in ref_da.dims: + # dim is a point coord + coord = ref_da[dim].values + expanded_da = expanded_da.assign_coords(**{dim: coord}) + elif len(ref_da[dim].values) == 1: + # dim has only one coord + coord = ref_da[dim].values[0] + expanded_da = expanded_da.expand_dims(**{dim: [coord]}) + return expanded_da + + +def get_dataarray_from_dataset(ds: Union[xr.DataArray, xr.Dataset]) -> xr.DataArray: + """Extract a DataArray from a Dataset, or just return the DataArray if given one.""" + if isinstance(ds, xr.Dataset): + try: + return ds["value"] + except KeyError: + return ds["sdi"] + else: + return ds + + +def remove_unexpected_dims(da: xr.DataArray) -> xr.DataArray: + """Remove unexpected single-coord dimensions or point-coordinates. + + Also, asserts that optional or unexpected dims either have just one coord or are + point-coords. + + Args: + da (xr.DataArray): + The dataarray to make conform to the expected dims + + Returns: + xr.DataArray: + The original dataarray, but with unexpected dims removed + + Raises: + IndexError: If optional or unexpected dim has more than one coord + """ + for dim in da.coords: + if dim in ProcessingConstants.EXPECTED_DIMENSIONS["optional"]: + try: + num_coords = len(da[dim]) + if num_coords > 1: + err_msg = f"{dim} is an optional dim with more than one coord" + logger.error(err_msg) + raise IndexError(err_msg) + except TypeError: + pass # dim is a point coord, and that's okay in this case + elif dim not in ProcessingConstants.EXPECTED_DIMENSIONS["required"]: + da = _remove_unexpected_dim(da, dim) + + return da + + +def _remove_unexpected_dim(da: xr.DataArray, dim: str) -> xr.DataArray: + try: + num_coords = len(da[dim]) + if num_coords > 1: + err_msg = f"{dim} is an unexpected dim with more than one coord" + logger.error(err_msg) + raise IndexError(err_msg) + else: + # dim is a single coord-dimension + da = da.sel({dim: da[dim].values[0]}, drop=True) + except TypeError: + # dim is a point-coord + da = da.drop_vars(dim) + + return da + + +def concat_past_and_future(past: xr.DataArray, future: xr.DataArray) -> xr.DataArray: + """Concatenate past and future data by expanding past scenario dimension. + + Does not account for mismatched coordinates at the moment. + + Prerequisites: + * No past scenario dimension + * No overlapping years + * Matching dims except for year and scenario + + Args: + past (xr.DataArray): + The past dataarray to concatenate + future (xr.DataArray): + The forecast dataarray to concatenate + + Returns: + xr.DataArray: + The complete time series data with past and future. + + Raises: + IndexError: If dimensions other than year and scenario do not line up. 
+ """ + if ( + DimensionConstants.SCENARIO in future.dims + and DimensionConstants.SCENARIO not in past.dims + ): + past = past.expand_dims(scenario=future[DimensionConstants.SCENARIO].values) + + inconsistent_dims = set(past.dims).symmetric_difference(set(future.dims)) + if inconsistent_dims: + err_msg = "Dimensions don't line up for past and future" + logger.error(err_msg) + raise IndexError(err_msg) + + data = xr.concat([past, future], dim=DimensionConstants.YEAR_ID) + + return data + + +def _filter_relevant_coords( + data: xr.DataArray, acause: str, stage: str, gbd_round_id: int, national_only: bool = False +) -> xr.DataArray: + """Filter dataarray to the relevant most-detailed coords for the given cause and stage. + + Args: + data: Data to filter. + acause: acause to determine the restrictions to apply. + stage: stage to determine the restrictions to apply. + gbd_round_id: gbd_round_id to determine the restrictions to apply. + national_only: whether to also filter down to national locations only. + + Returns: + The filtered version of ``data``. + """ + inferred_stage = restrictions.get_stage_to_infer_restrictions(stage) + if stage != inferred_stage: + logger.debug(f"Inferring demographic restrictions for {stage}, from {inferred_stage}") + + stage_restrictions = restrictions.get_restrictions(acause, inferred_stage, gbd_round_id) + sliced_data = data.sel(**stage_restrictions) + most_detailed_data = filter.make_most_detailed( + sliced_data, gbd_round_id, national_only=national_only + ) + + filled_with_age_data, warning_msg = _fill_age_restrictions( + most_detailed_data, acause, stage, gbd_round_id + ) + + return filled_with_age_data, warning_msg + + +def _fill_age_restrictions( + data: xr.DataArray, acause: str, stage: str, gbd_round_id: int +) -> xr.DataArray: + """Expand ``data`` to include missing needed age-groups. + + The "needed" age-groups are dependent on the acause, stage, and gbd_round_id. + + The missing groups are filled with the nearest available age-groups. + + Args: + data: DataArray to expand. + acause: The "acause" whose restrictions should apply. + stage: The "stage" of processing whose restrictions should apply. + gbd_round_id: The gbd round ID from which to load age-group data. + + Returns: + The filled-out version of ``data``. + + Raises: + RuntimeError: if we have a missing age group that is a "middle" age group -- one + "between" available age-groups. + """ + available_age_ids = list(data["age_group_id"].values) + # If we are given mortality, for example, we want to get the age + # availability of the measure/cause we are calculating ratio with. 
+    inferred_stage = restrictions.get_stage_to_infer_restrictions(stage, purpose="non_fatal")
+
+    needed_age_ids = restrictions.get_restrictions(acause, inferred_stage, gbd_round_id)[
+        "age_group_id"
+    ]
+
+    missing_age_ids = list(set(needed_age_ids) - set(available_age_ids))
+    if missing_age_ids:
+        logger.warning(f"age-group-ids:{missing_age_ids} are missing")
+        # Get oldest available age group and its data
+        oldest_avail_age_id = _get_oldest_age_id(available_age_ids, gbd_round_id)
+        oldest_avail_data = data.sel(age_group_id=oldest_avail_age_id, drop=True)
+
+        # Get youngest available age group and its data
+        youngest_avail_age_id = _get_youngest_age_id(available_age_ids, gbd_round_id)
+        youngest_avail_data = data.sel(age_group_id=youngest_avail_age_id, drop=True)
+
+        # Sort missing age groups into old and young
+        missing_old_age_ids = []
+        missing_young_age_ids = []
+        for age_id in missing_age_ids:
+            if _AgeGroupID(age_id, gbd_round_id) > oldest_avail_age_id:
+                missing_old_age_ids.append(age_id)
+            elif _AgeGroupID(age_id, gbd_round_id) < youngest_avail_age_id:
+                missing_young_age_ids.append(age_id)
+            else:
+                err_msg = (
+                    f"age_group_id={age_id} is a middle age group -- it is "
+                    "between available age-groups. This is unexpected."
+                )
+                logger.error(err_msg)
+                raise RuntimeError(err_msg)
+
+        warning_template = (
+            "age_group_id={} are being filled with the data from age_group_id={}. "
+        )
+        # Expand to include old age groups
+        with_missing_old_data = expand_dimensions(
+            data, age_group_id=missing_old_age_ids, fill_value=oldest_avail_data
+        )
+        # Expand to include young age groups
+        with_missing_young_data = expand_dimensions(
+            with_missing_old_data,
+            age_group_id=missing_young_age_ids,
+            fill_value=youngest_avail_data,
+        )
+
+        warning_msg = ""
+        if missing_old_age_ids:
+            warning_msg += warning_template.format(
+                missing_old_age_ids, int(oldest_avail_age_id)
+            )
+        if missing_young_age_ids:
+            warning_msg += warning_template.format(
+                missing_young_age_ids, int(youngest_avail_age_id)
+            )
+
+        if warning_msg:
+            logger.warning(warning_msg)
+
+        return with_missing_young_data, warning_msg
+    else:
+        return data.copy(), None
+
+
+def make_shared_dims_conform(
+    to_update: xr.DataArray, reference: xr.DataArray, ignore_dims: Optional[List[str]] = None
+) -> xr.DataArray:
+    """Make a given array conform to another on certain shared dimensions.
+
+    Note:
+        * ``to_update`` is expected to have all the same coords and maybe extra
+          for all the shared dims excluding the dims in ``ignore_dims``.
+
+    Args:
+        to_update (xr.DataArray):
+            To make conform
+        reference (xr.DataArray):
+            Dataarray to conform to
+        ignore_dims (list[str]):
+            list of dims to omit from conforming e.g. DimensionConstants.YEAR_ID
+
+    Returns:
+        ``to_update`` that has been filtered to conform to ``reference``
+    """
+    ignore_dims = ignore_dims or []
+    coord_dict = {}
+    for dim in reference.dims:
+        if dim not in ignore_dims and dim in to_update.dims:
+            coord_dict.update({dim: reference[dim].values})
+
+    return to_update.sel(coord_dict)
+
+
+class _AgeGroupID(object):
+    """Abstraction for age-group IDs.
+
+    Basically just supports comparisons, to see which of two represents a younger/older
+    age-group.
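+
+    Illustrative example (assumes GBD age metadata in which age group 235, ages 95+, starts
+    older than age group 4): ``_AgeGroupID(235, gbd_round_id) > 4`` evaluates to True.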
+ """ + + def __init__(self, age_id: int, gbd_round_id: int) -> None: + self.age_id = age_id + self.gbd_round_id = gbd_round_id + + def __gt__(self, other: int) -> bool: + return self.age_id == _get_oldest_age_id([self.age_id, other], self.gbd_round_id) + + def __lt__(self, other: int) -> bool: + return self.age_id == _get_youngest_age_id([self.age_id, other], self.gbd_round_id) + + +def _get_oldest_age_id(age_group_ids: List[int], gbd_round_id: int) -> int: + """Among the given age-groups, return the oldest one.""" + age_df = age.get_ages(gbd_round_id)[["age_group_id", "age_group_years_start"]] + relevant_age_df = age_df.query("age_group_id in @age_group_ids") + oldest_age_id = relevant_age_df.loc[relevant_age_df["age_group_years_start"].idxmax()][ + "age_group_id" + ] + return oldest_age_id + + +def _get_youngest_age_id(age_group_ids: List[int], gbd_round_id: int) -> int: + """Among the given age-groups, return the youngest one.""" + age_df = age.get_ages(gbd_round_id)[["age_group_id", "age_group_years_start"]] + relevant_age_df = age_df.query("age_group_id in @age_group_ids") + youngest_age_id = relevant_age_df.loc[relevant_age_df["age_group_years_start"].idxmin()][ + "age_group_id" + ] + return youngest_age_id + + +def _remove_all_zero_slices( + data: xr.DataArray, + dims: Iterable[str], + tolerance: float = ProcessingConstants.DEFAULT_PRECISION, +) -> Tuple[xr.DataArray, xr.DataArray]: + """A method that removes and stores all-zero slices from a dataarray. + + Args: + data (xr.DataArray): data array that contains dims in its dimensions. + dims (Iterable): an iterable of dims over which to seek zero-slices. + tolerance (float): tolerance value to supply to the "closeness" check. Defaults to + ``ProcessingConstants.DEFAULT_PRECISION`` + + Returns: + (tuple): a tuple of data array and dict. The data array is the + input data sans zero-slices. The dict keeps track of the slices + that were removed. + """ + zero_slices_dict = {} # to help keep track of all-zero slices + keep_slices_dict = {} # the complement of zero_slices_dict + + avail_dims = [dim for dim in dims if dim in data.dims] + for dim in avail_dims: + zero_coords = [] + keep_coords = [] + for coord in data[dim].values: + slice = data.sel({dim: coord}) + if np.isclose(a=slice, b=0, atol=tolerance).all(): + zero_coords.append(coord) + else: + keep_coords.append(coord) + if zero_coords: + zero_slices_dict[dim] = zero_coords + keep_slices_dict[dim] = keep_coords + + if zero_slices_dict: + data = data.sel(**keep_slices_dict) + + return data, zero_slices_dict + + +def _add_all_zero_slices( + data: xr.DataArray, new_addition_dict: Dict[str, Any] +) -> xr.DataArray: + """Adds slices of zeros to data array. + + Args: + data (xr.DataArray): data to be added to. + new_addition_dict (dict): new zero-slices to have in data. + + Returns: + (xr.DataArray): data with additional slices that are all zeros. + """ + return expand_dimensions(data, fill_value=0, **new_addition_dict) + + +def _mad(da: xr.DataArray, median_dims: Iterable[str], scale: float = 1.4826) -> xr.DataArray: + """Calculate the median absolute deviation. + + Multiplies by `scale` for consistency with scipy.stats.median_absolute_deviation. + + Args: + da (xr.DataArray): + The array to calculate MAD. + median_dims (list[str]): + List of dims to calculate the MAD for. Ex.: if median_dims is ['age_group_id'], + then it will return a dataarray with only age_group_id dimension + scale (float): + (Optional.) The scaling factor applied to the MAD. 
The default scale (1.4826) + ensures consistency with the standard deviation for normally distributed data. + + Returns: + xr.DataArray: + Array with only median_dims containing the MAD for those dims. + """ + dims_to_median = set(da.dims).difference(set(median_dims)) + return scale * (np.abs(da - da.median(dim=dims_to_median))).median(dim=dims_to_median) + + +def _calculate_mad_multiplier( + da: xr.DataArray, + mad_da: xr.DataArray, + median_da: xr.DataArray, + pct_coverage: float, + max_multiplier: float, + step: float, +) -> Optional[float]: + """Multiplier necessary to achieve desired % of data within (median +- multiplier * MAD). + + Args: + da (xr.DataArray): + The array to calculate the multiplier for. + mad_da (xr.DataArray): + The array that has the median absolute deviation across some dims of da. + median_da (xr.DataArray): + The array that has the median across some dims of da. + pct_coverage (float): + Percent of data in array to have between the floor and ceiling, e.g. .975. + max_multiplier (float): + The maximum multiplier for the MAD that is acceptable. If it is too small, then + pct_coverage of the data might not be between median +/- multiplier * MAD. + step (float): + The amount to test multipliers by (starting value is zero + step). + + Returns: + float: + Multiplier for MAD to cap values outside of multiplier * MAD, or None if we don't + find such a multiplier less than max_multiplier. + """ + for multiplier in np.arange(0 + step, max_multiplier + step, step): + ceiling = median_da + (multiplier * mad_da) + floor = median_da - (multiplier * mad_da) + between_da = da.where(da > floor).where(da < ceiling) + vals_between = between_da.count().values.item(0) + vals_total = da.count().values.item(0) + if multiplier >= max_multiplier: + logger.warning( + ( + f"Using max_multiplier! {vals_between / vals_total} " + f"pct coverage achieved using max multiplier." + ) + ) + return max_multiplier + elif (vals_between / vals_total) > pct_coverage: + return multiplier diff --git a/gbd_2021/disease_burden_forecast_code/nonfatal/models/validate.py b/gbd_2021/disease_burden_forecast_code/nonfatal/models/validate.py new file mode 100644 index 0000000..e5b4d27 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/nonfatal/models/validate.py @@ -0,0 +1,39 @@ +"""Functions related to validating inputs, and outputs of nonfatal pipeline.""" + +from typing import List + +import xarray as xr +from tiny_structured_logger.lib.fhs_logging import get_logger + +logger = get_logger() + + +def assert_covariates_scenarios(cov_data_list: List[xr.DataArray]) -> None: + """Check that all covariates have the same scenario coordinates. + + Args: + cov_data_list (list[xr.DataArray]): Past and forecast data for each covariate, i.e. + independent variable. + + Raises: + ValueError: If the covariates do not have consistent scenario coords. + """ + first_cov = cov_data_list[0] + + if "scenario" in first_cov.dims: + first_scenarios = set(first_cov["scenario"].values) + for next_cov in cov_data_list[1:]: + next_scenarios = set(next_cov["scenario"].values) + if first_scenarios.symmetric_difference(next_scenarios): + raise ValueError( + f"Covariates have inconsistent scenario coords, e.g. " + f"{first_cov.name} and {next_cov.name}" + ) + + else: + for next_cov in cov_data_list[1:]: + if "scenario" in next_cov.dims: + raise ValueError( + f"{first_cov.name} doesn't have a scenario dimension, but {next_cov.name} " + "does. If any covariates have a scenario, they all need to." 
+                )
diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/README.md b/gbd_2021/disease_burden_forecast_code/risk_factors/README.md
new file mode 100644
index 0000000..66f6001
--- /dev/null
+++ b/gbd_2021/disease_burden_forecast_code/risk_factors/README.md
@@ -0,0 +1,113 @@
+Risk factors pipeline code
+Includes code for both SEVs and PAFs/scalars
+
+# Generalized ensemble model (GenEM)
+
+```
+arc_main.py
+Forecasts an entity using the ARC method
+```
+
+```
+collect_submodels.py
+Collects and collapses components into genem for future stage
+```
+
+```
+constants.py
+Local constants for the FHS risk factors pipeline
+```
+
+```
+create_stage.py
+Create tasks for GenEM
+```
+
+```
+get_model_weights_from_holdouts.py
+Collects submodel predictive validity statistics to compile sampling weights for genem
+```
+
+```
+model_restrictions.py
+Captures restrictions in which models get run for each entity/location
+```
+
+```
+predictive_validity.py
+Calculates the RMSE between forecast and holdouts across location & sex
+```
+
+```
+run_stagewise_mrbrt.py
+Forecasts entities using MRBRT (Meta-regression, Bayesian, regularized, trimmed) model
+```
+
+# Population attributable fractions (PAFs)
+
+Note: We first compute all the cause-risk-specific PAFs using `compute_paf.py`,
+followed by the cause-only PAFs and scalars using `compute_scalar.py`.
+
+```
+compute_paf.py
+Compute and export all the cause-risk-pair PAFs for a given acause
+```
+
+```
+compute_scalar.py
+Computes aggregated acause-specific PAFs and scalars
+```
+
+```
+constants.py
+Local constants for the FHS scalars pipeline
+```
+
+```
+forecasting_db.py
+Functions related to PAF queries
+```
+
+```
+utils.py
+Utility/DB functions for the scalars pipeline
+```
+
+# Summary exposure values (SEVs)
+
+Note: the SEV pipeline consists of five primary stages:
+1. Compute past intrinsic SEVs (stored in past_sev/risk_acause_specific/). See sev/compute_past_intrinsic_sev.py
+2. (a) Predictive validity (PV) run on past years, and (b) full-draws forecast.
+3. Export sampling weights based on PV statistics.
+4. Collect and concat forecast draws based on PV statistics.
+5. Compute future intrinsic SEVs. Stages 2-5 are contained in GenEM.
+
+```
+compute_future_mediator_total_sev.py
+Compute cause-risk-specific future total SEV, given acause and risk
+```
+
+```
+compute_past_intrinsic_sev.py
+Compute the intrinsic SEV of a mediator
+```
+
+```
+constants.py
+Local constants for the FHS SEVs pipeline
+```
+
+```
+mediation.py
+Functions for understanding the mediation hierarchy
+```
+
+```
+rrmax.py
+A wrapper around the central read_rrmax that handles the case of PAFs equal to 1
+```
+
+```
+run_workflow.py
+Construct and execute SEV workflow
+```
\ No newline at end of file
diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/genem/arc_main.py b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/arc_main.py
new file mode 100644
index 0000000..039993f
--- /dev/null
+++ b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/arc_main.py
@@ -0,0 +1,371 @@
+"""This script forecasts an entity using the ARC method.
+
+This script runs predictive validity for an entity to determine the weight used
+in the arc quantile method. For each entity, this script forecasts the
+out-of-sample window using different weights, then determines the RMSE and bias
+for that weight. The outputs of this script are netCDFs containing the results
+of the predictive metrics for each entity.
+
+After the predictive validity step, an omega will be selected to forecast each
+entity.
+
+Notes regarding the truncate/capping optional flags:
+
+1.) truncate-quantiles are used for winsorizing only the past logit
+    age-standardized data (before computing the annualized rate of change).
+    With 0.025/0.975 quantiles (calculated across locations), data above the
+    97.5th percentile are set to the 97.5th percentile, while data below the
+    2.5th percentile are set to the 2.5th percentile.
+
+2.) cap-quantiles are used for winsorizing only the future entities
+    (after generating forecasted entities). With 0.01/0.99 quantiles
+    (calculated based on the past entities), forecasts above the 99th
+    percentile are set to the 99th percentile, while forecasts below the
+    1st percentile are set to the 1st percentile.
+"""
+
+import gc
+from typing import List, Optional, Tuple
+
+import numpy as np
+import xarray as xr
+from fhs_lib_data_transformation.lib import filter
+from fhs_lib_database_interface.lib.constants import SexConstants
+from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec
+from fhs_lib_file_interface.lib.versioning import Versions
+from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario
+from fhs_lib_model.lib.arc_method import arc_method
+from fhs_lib_year_range_manager.lib.year_range import YearRange
+from tiny_structured_logger.lib import fhs_logging
+
+from fhs_lib_genem.lib import predictive_validity as pv
+from fhs_lib_genem.lib.constants import (
+    FileSystemConstants,
+    ModelConstants,
+    SEVConstants,
+    TransformConstants,
+)
+
+logger = fhs_logging.get_logger()
+
+
+def determine_entity_name_path(entity: str, stage: str) -> Tuple[str, str]:
+    """Take the entity name and determine name and file path."""
+    if stage == "sev" and "-" in entity:  # is an iSEV, specified as cause-risk
+        acause, rei = entity.split("-")
+        sub_folder = "risk_acause_specific"
+        file_name = "_".join([acause, rei, SEVConstants.INTRINSIC_SEV_FILENAME_SUFFIX])
+    else:
+        sub_folder = ""
+        file_name = f"{entity}"
+
+    return sub_folder, file_name
+
+
+def _clip_past(past_mean: xr.DataArray, transform: str) -> xr.DataArray:
+    if transform == "logit":
+        # logit-transformable data gets both a floor and a ceiling (since it's 0-1)
+        clipped_past = past_mean.clip(min=ModelConstants.FLOOR, max=1 - ModelConstants.FLOOR)
+    elif transform == "log":
+        # log-transformable data should only be floored
+        clipped_past = past_mean.clip(min=ModelConstants.FLOOR)
+    else:
+        # data we won't transform shouldn't be clipped.
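+        # (This is the "no-transform" option; values stay in their natural units.)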
+ clipped_past = past_mean + return clipped_past + + +def _find_limits( + past_age_std_mean: xr.DataArray, + past_last_year: xr.DataArray, + upper_quantile: float, + lower_quantile: float, +) -> xr.DataArray: + """Find upper/lower limits to cap the forecasts.""" + past_age_std_quantiles = past_age_std_mean.quantile( + [lower_quantile, upper_quantile], dim=["location_id", "year_id"] + ) + upper = past_age_std_quantiles.sel(quantile=upper_quantile, drop=True) + lower = past_age_std_quantiles.sel(quantile=lower_quantile, drop=True) + + past_last_year_gt_upper = past_last_year.where(past_last_year > upper) + past_last_year_lt_lower = past_last_year.where(past_last_year < lower) + + upper_cap_lims = past_last_year_gt_upper.fillna(upper).rename("upper") + lower_cap_lims = past_last_year_lt_lower.fillna(lower).rename("lower") + + cap_lims = xr.merge([upper_cap_lims, lower_cap_lims]) + return cap_lims + + +def _reshape_bound(data: xr.DataArray, bound: xr.DataArray) -> xr.DataArray: + """Broadcast and align the dims of `bound` so that they match `data`.""" + expanded_bound, _ = xr.broadcast(bound, data) + return expanded_bound.transpose(*data.coords.dims) + + +def _cap_forecasts( + years: YearRange, + cap_quantiles: Tuple[float, float], + most_detailed_past: xr.DataArray, + past_mean: xr.DataArray, + forecast: xr.DataArray, +) -> xr.DataArray: + """Cap upper and lower bound on forecasted data, using quantiles from past data.""" + last_year = most_detailed_past.sel(year_id=years.past_end, drop=True) + lower_quantile, upper_quantile = cap_quantiles + caps = _find_limits( + past_mean, last_year, upper_quantile=upper_quantile, lower_quantile=lower_quantile + ) + returned_past = forecast.sel(year_id=years.past_years) + forecast = forecast.sel(year_id=years.forecast_years) + + lower_bound = _reshape_bound(forecast, caps.lower) + upper_bound = _reshape_bound(forecast, caps.upper) + + mean_clipped = forecast.clip(min=lower_bound, max=upper_bound).fillna(0) + + del forecast + gc.collect() + + capped_forecast = xr.concat([returned_past, mean_clipped], dim="year_id") + + return capped_forecast + + +def _forecast_entity( + omega: float, + past: xr.DataArray, + transform: str, + truncate: bool, + truncate_quantiles: Tuple[float, float], + replace_with_mean: bool, + reference_scenario: str, + years: YearRange, + gbd_round_id: int, + cap_forecasts: bool, + cap_quantiles: Tuple[float, float], + national_only: bool, + age_standardize: bool, + rescale_ages: bool, + remove_zero_slices: bool, +) -> xr.DataArray: + """Prepare data for forecasting, run model and post-process results.""" + most_detailed_past = filter.make_most_detailed_location( + data=past, gbd_round_id=gbd_round_id, national_only=national_only + ) + if "sex_id" not in most_detailed_past.dims or list(most_detailed_past.sex_id.values) != [ + SexConstants.BOTH_SEX_ID + ]: + most_detailed_past = filter.make_most_detailed_sex(data=most_detailed_past) + if age_standardize: + most_detailed_past = filter.make_most_detailed_age( + data=most_detailed_past, gbd_round_id=gbd_round_id + ) + + if "draw" in most_detailed_past.dims: + past_mean = most_detailed_past.mean("draw") + else: + past_mean = most_detailed_past + + clipped_past = _clip_past(past_mean=past_mean, transform=transform) + + processor = TransformConstants.TRANSFORMS[transform]( + years=years, + gbd_round_id=gbd_round_id, + age_standardize=age_standardize, + remove_zero_slices=remove_zero_slices, + rescale_age_weights=rescale_ages, + ) + + transformed_past = processor.pre_process(clipped_past) + + del 
clipped_past + gc.collect() + + transformed_forecast = arc_method.arc_method( + past_data_da=transformed_past, + gbd_round_id=gbd_round_id, + years=years, + diff_over_mean=ModelConstants.DIFF_OVER_MEAN, + truncate=truncate, + reference_scenario=reference_scenario, + weight_exp=omega, + replace_with_mean=replace_with_mean, + truncate_quantiles=truncate_quantiles, + scenario_roc="national", + ) + + forecast = processor.post_process(transformed_forecast, past_mean) + + if np.isnan(forecast).any(): + raise ValueError("NaNs in forecasts") + + if cap_forecasts: + forecast = _cap_forecasts( + years, cap_quantiles, most_detailed_past, past_mean, forecast + ) + + return forecast + + +def arc_all_omegas( + entity: str, + stage: str, + intrinsic: bool, + subfolder: str, + versions: Versions, + model_name: str, + omega_min: float, + omega_max: float, + omega_step_size: float, + transform: str, + truncate: bool, + truncate_quantiles: Optional[Tuple[float, float]], + replace_with_mean: bool, + reference_scenario: str, + years: YearRange, + gbd_round_id: int, + cap_forecasts: bool, + cap_quantiles: Optional[Tuple[float, float]], + national_only: bool, + age_standardize: bool, + rescale_ages: bool, + predictive_validity: bool, + remove_zero_slices: bool, +) -> None: + """Forecast an entity with different omega values. + + If a SEV, the rei input could be a risk, or it could be a cause-risk. + If it's a cause-risk (connected via hyphen), it's meant to be an + intrinsic SEV, which would come from + in_version/risk_acause_specific/{cause}_{risk}_intrinsic.nc, + and the forecasted result would go to + out_version/risk_acause_specific/{cause}_{risk}_intrinsic.nc. + + Args: + entity (str): Entity to forecast + stage (str): Stage of the run. E.x. sev, death, etc. + intrinsic (bool): Whether this entity obtains the _intrinsic suffix + subfolder (str): Optional subfolder for reading and writing files. + versions (Versions): versions object with both past and future (input and output). + model_name (str): Name to save the model under. + omega_min (float): The minimum omega to try + omega_max (float): The maximum omega to try + omega_step_size (float): The step size of omegas to try between 0 and omega_max + transform (str): Space to forecast data in + truncate (bool): If True, then truncates the dataarray over the given dimensions + truncate_quantiles (Tuple[float, float]): The tuple of two floats representing the + quantiles to take + replace_with_mean (bool): If True and `truncate` is True, then replace values outside + of the upper and lower quantiles taken across `location_id` and `year_id` and with + the mean across `year_id`, if False, then replace with the upper and lower bounds + themselves + reference_scenario (str): If 'median' then the reference scenario is made using the + weighted median of past annualized rate-of-change across all past years, 'mean' + then it is made using the weighted mean of past annualized rate-of-change across + all past years + years (YearRange): forecasting year range + gbd_round_id (int): the gbd round id + cap_forecasts (bool): If used, forecasts will be capped. To forecast without caps, + dont use this + cap_quantiles (tuple[float]): Quantiles for capping the future + national_only (bool): Whether to run national only data or not + rescale_ages (bool): whether to rescale during ARC age standardization. We are + currently only setting this to true for the sevs pipeline. + age_standardize (bool): whether to age_standardize before modeling. 
+ predictive_validity (bool): whether to do predictive validity or real forecasts + remove_zero_slices (bool): If True, remove zero-slices along certain dimensions, when + pre-processing inputs, and add them back in to outputs. + """ + logger.debug(f"Running `forecast_one_risk_main` for {entity}") + + input_version_metadata = versions.get(past_or_future="past", stage=stage) + + file_name = entity + if intrinsic: # intrinsic entities have _intrinsic attached at file name + file_name = entity + "_intrinsic" + + data = open_xr_scenario( + file_spec=FHSFileSpec( + version_metadata=input_version_metadata, + sub_path=(subfolder,), + filename=f"{file_name}.nc", + ) + ) + + # rid the past data of point coords because they throw off weighted-quantile + superfluous_coords = [d for d in data.coords.keys() if d not in data.dims] + data = data.drop_vars(superfluous_coords) + + past = data.sel(year_id=years.past_years) + + if predictive_validity: + holdouts = data.sel(year_id=years.forecast_years) + all_omega_pv_results: List[xr.DataArray] = [] + + # here begins the loop over omegas + for omega in pv.get_omega_weights(omega_min, omega_max, omega_step_size): + logger.debug("omega:{}".format(omega)) + + forecast = _forecast_entity( + omega=omega, + past=past, + transform=transform, + truncate=truncate, + truncate_quantiles=truncate_quantiles, + replace_with_mean=replace_with_mean, + reference_scenario=reference_scenario, + years=years, + gbd_round_id=gbd_round_id, + cap_forecasts=cap_forecasts, + cap_quantiles=cap_quantiles, + national_only=national_only, + age_standardize=age_standardize, + rescale_ages=rescale_ages, + remove_zero_slices=remove_zero_slices, + ) + + if predictive_validity: + all_omega_pv_results.append( + pv.calculate_predictive_validity( + forecast=forecast, holdouts=holdouts, omega=omega + ) + ) + + else: + output_version_metadata = versions.get(past_or_future="future", stage=stage) + + output_file_spec = FHSFileSpec( + version_metadata=output_version_metadata, + sub_path=(FileSystemConstants.SUBMODEL_FOLDER, model_name, subfolder), + filename=f"{file_name}_{omega}.nc", + ) + + save_xr_scenario( + xr_obj=forecast, + file_spec=output_file_spec, + metric="rate", + space="identity", + omega=omega, + transform=transform, + truncate=str(truncate), + truncate_quantiles=str(truncate_quantiles), + replace_with_mean=str(replace_with_mean), + reference_scenario=str(reference_scenario), + cap_forecasts=str(cap_forecasts), + cap_quantiles=str(cap_quantiles), + ) + + if predictive_validity: + pv_df = pv.finalize_pv_data(pv_list=all_omega_pv_results, entity=entity) + + pv.save_predictive_validity( + file_name=file_name, + gbd_round_id=gbd_round_id, + model_name=model_name, + pv_df=pv_df, + stage=stage, + subfolder=subfolder, + versions=versions, + ) diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/genem/collect_submodels.py b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/collect_submodels.py new file mode 100644 index 0000000..db40d3f --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/collect_submodels.py @@ -0,0 +1,326 @@ +"""Script to collect and collapse components into genem for future stage. 
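`arc_method.arc_method` itself lives in a separate module of this same commit; the parameter this loop varies is `weight_exp=omega`, which (per the docstring above) controls the weighting of past annualized rates of change. A minimal single-series sketch of that weighting idea, assuming weights proportional to `year_index ** omega` so larger omegas lean harder on recent changes; the real function additionally supports draws, scenarios, truncation, a weighted-median reference, and the diff-over-mean option:

```python
import numpy as np

past_years = np.arange(2010, 2023)              # hypothetical past time series
values = np.linspace(0.40, 0.30, past_years.size)
omega = 1.5                                     # one value from the omega grid

# Year-over-year changes, weighted so that later changes count more as omega grows.
diffs = np.diff(values)
weights = np.arange(1, diffs.size + 1) ** omega
arc = np.average(diffs, weights=weights)        # weighted mean annualized change

# Apply the annualized change cumulatively from the last observed year.
n_forecast_years = 5
forecast = values[-1] + arc * np.arange(1, n_forecast_years + 1)
```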
+""" + +from typing import Callable, List + +import numpy as np +import pandas as pd +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_file_interface.lib.pandas_wrapper import read_csv +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_year_range_manager.lib import YearRange +from tiny_structured_logger.lib import fhs_logging + +from fhs_lib_genem.lib.constants import ( + FileSystemConstants, + ModelConstants, + ScenarioConstants, + SEVConstants, + TransformConstants, +) + +logger = fhs_logging.get_logger() + + +def entity_specific_collection( + entity: str, + stage: str, + versions: Versions, + gbd_round_id: int, + years: YearRange, + transform: str, + intercept_shift_from_reference: bool, + uncross_scenarios: bool, +) -> None: + """Collect, sample, collapse, and export a given risk. + + Args: + entity (str): risk to collect across omegas. If intrinsic SEV, + then the rei will look like acause-rei. + stage (str): stage of run (sev, mmr, etc.) + versions (Versions): input and output versions + gbd_round_id (int): gbd round id. + years (YearRange): past_start:forecast_start:forecast_end. + transform (str): name of transform to use for processing (logit, log, no-transform). + intercept_shift_from_reference (bool): If True, and we are in multi-scenario mode, then + the intercept-shifting during the above `transform` is calculated from the + reference scenario but applied to all scenarios; if False then each scenario will + get its own shift amount. + uncross_scenarios (bool): whether to fix crossed scenarios. This is currently only used + for sevs and should be deprecated soon. + + """ + input_model_weights_version_metadata = versions.get(past_or_future="future", stage=stage) + input_model_weights_file_spec = FHSFileSpec( + version_metadata=input_model_weights_version_metadata, + filename=ModelConstants.MODEL_WEIGHTS_FILE, + ) + + omega_df = read_csv(file_spec=input_model_weights_file_spec, keep_default_na=False) + + locations: List[int] = omega_df["location_id"].unique().tolist() + + future_da = get_location_draw_omegas( + versions=versions, + gbd_round_id=gbd_round_id, + stage=stage, + entity=entity, + omega_df=omega_df, + locations=locations, + ) + + # Every entity has many rows, and the "intrinsic" and "subfolder" values + # should be the same over all rows. So we only need first row here. 
+ first_row = omega_df.query(f"entity == '{entity}'").iloc[0] + + intrinsic, subfolder = bool(first_row["intrinsic"]), str(first_row["subfolder"]) + + if intrinsic: + file_name = f"{entity}_{SEVConstants.INTRINSIC_SEV_FILENAME_SUFFIX}.nc" + else: + file_name = f"{entity}.nc" + + if intrinsic: + # Set all intrinsic scenarios to reference + non_ref_scenarios = [ + s + for s in future_da["scenario"].values + if s != ScenarioConstants.REFERENCE_SCENARIO_COORD + ] + for scenario in non_ref_scenarios: + future_da.loc[{"scenario": scenario}] = future_da.sel( + scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD + ) + + logger.info(f"Entering intercept-shift of {entity} submodel") + future_da = intercept_shift_processing( + stage=stage, + versions=versions, + gbd_round_id=gbd_round_id, + years=years, + transform=transform, + subfolder=subfolder, + future_da=future_da, + file_name=file_name, + shift_from_reference=intercept_shift_from_reference, + ) + + if uncross_scenarios: + future_da = fix_scenario_crossing(years=years, future_da=future_da) + + # NOTE the following removes uncertainty from the scenarios + # for scenario in non_ref_scenarios: + # da.loc[{"scenario": scenario}] = da.sel(scenario=scenario).mean("draw") + + output_version_metadata = versions.get(past_or_future="future", stage=stage) + + output_file_spec = FHSFileSpec( + version_metadata=output_version_metadata, + sub_path=(subfolder,), + filename=file_name, + ) + + save_xr_scenario( + xr_obj=future_da, + file_spec=output_file_spec, + metric="rate", + space="identity", + years=str(years), + past_version=str(versions.get_version_metadata(past_or_future="past", stage=stage)), + out_version=str(versions.get_version_metadata(past_or_future="future", stage=stage)), + gbd_round_id=gbd_round_id, + ) + + +def read_location_draws( + file_spec: FHSFileSpec, location_id: int, draw_start: int, n_draws: int +) -> xr.DataArray: + """Read location-draws from file. + + Notably, this function will expand or contract the number of draws present to fit inside + the closed range [`draw_start`, `draw_start` + `n_draws`], *reassigning coordinates* from + whatever they are read in as. 
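Each submodel contributes `draws` samples, and `read_location_draws` places them on a disjoint block of draw coordinates (via `resample`/`expand_dimensions` plus coordinate reassignment) so that concatenating submodels along `draw`, as `get_location_draw_omegas` does below, yields one combined sample per location. A small sketch of that coordinate bookkeeping with toy arrays (hypothetical values and draw split):

```python
import numpy as np
import xarray as xr


def toy_submodel(value: float, n_draws: int, draw_start: int) -> xr.DataArray:
    """Stand-in for one submodel's file: constant values on its own draw block."""
    return xr.DataArray(
        np.full(n_draws, value),
        dims=["draw"],
        coords={"draw": range(draw_start, draw_start + n_draws)},
    )


# First submodel fills draws 0-599, second fills 600-999 (hypothetical split).
arc_part = toy_submodel(0.25, n_draws=600, draw_start=0)
mrbrt_part = toy_submodel(0.27, n_draws=400, draw_start=600)

# Because the draw coordinates do not overlap, concatenation yields 1000
# distinct draws rather than duplicated labels.
combined = xr.concat([arc_part, mrbrt_part], dim="draw")
assert combined.sizes["draw"] == 1000
```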
+ """ + da = open_xr_scenario(file_spec).sel(location_id=location_id).load() + if "draw" in da.dims: # some sub-models may be draw-less + da = resample(da, n_draws) + da = da.assign_coords(draw=range(draw_start, draw_start + n_draws)) + else: + da = expand_dimensions(da, draw=range(draw_start, draw_start + n_draws)) + return da + + +def fix_scenario_crossing(years: YearRange, future_da: xr.DataArray) -> xr.DataArray: + """Scenario cross the future data and fill missing results within [0, 1].""" + # NOTE we're NOT fixing scenario-crossing in logit space here + # NOTE Code assumes worse > reference > better + + # Ensure same years.past_end values across scenarios after transformations + future_da_ref = future_da.sel(scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD) + future_da_worse = future_da.sel(scenario=ScenarioConstants.WORSE_SCENARIO_COORD) + future_da_better = future_da.sel(scenario=ScenarioConstants.BETTER_SCENARIO_COORD) + + future_worse_diff = future_da_worse.sel(year_id=years.past_end) - future_da_ref.sel( + year_id=years.past_end + ) + future_better_diff = future_da_better.sel(year_id=years.past_end) - future_da_ref.sel( + year_id=years.past_end + ) + + future_new_worse = future_da_worse - future_worse_diff + future_new_better = future_da_better - future_better_diff + + future_da = xr.concat([future_new_worse, future_da_ref, future_new_better], dim="scenario") + + dam = future_da.mean("draw") + + # For SEV's, worse >= ref >= better + worse = dam.sel(scenario=ScenarioConstants.WORSE_SCENARIO_COORD) + better = dam.sel(scenario=ScenarioConstants.BETTER_SCENARIO_COORD) + ref = dam.sel(scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD) + + worse_diff = ref - worse # should be <= 0 for SEV, so we keep the > 0's + worse_diff = worse_diff.where(worse_diff < 0).fillna(0) # keep > 0's + + better_diff = ref - better # should be >= 0 for SEV, so we keep the < 0's + better_diff = better_diff.where(better_diff > 0).fillna(0) # keep < 0's + + # the worse draws that are below ref will have > 0 values added to them + future_da.loc[dict(scenario=ScenarioConstants.WORSE_SCENARIO_COORD)] = ( + future_da.sel(scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD) - worse_diff + ) + # the better draws that are above ref will have < 0 values added to them + future_da.loc[dict(scenario=ScenarioConstants.BETTER_SCENARIO_COORD)] = ( + future_da.sel(scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD) - better_diff + ) + + # non-ref scenarios do not have uncertainty + dim_order = ["draw"] + [x for x in future_da.dims if x != "draw"] + future_da = future_da.transpose(*dim_order) # draw-dim to 1st to broadcast + + # does not save computed past SEVs + needed_years = np.concatenate(([years.past_end], years.forecast_years)) + future_da = future_da.sel(year_id=needed_years) + + future_da = future_da.where(future_da <= 1).fillna(1) + future_da = future_da.where(future_da >= 0).fillna(0) + + return future_da + + +def intercept_shift_processing( + stage: str, + versions: Versions, + gbd_round_id: int, + years: YearRange, + transform: str, + subfolder: str, + future_da: xr.DataArray, + file_name: str, + shift_from_reference: bool, +) -> xr.DataArray: + """Perform ordered draw intercept shifting of past and future data.""" + # Here we do ordered-draw intercept-shift to ensure uncertainty fan-out + past_version_metadata = versions.get(past_or_future="past", stage=stage) + + past_file_spec = FHSFileSpec( + version_metadata=past_version_metadata, + sub_path=(subfolder,), + filename=file_name, + ) + + past_da = 
open_xr_scenario(past_file_spec).sel( + sex_id=future_da["sex_id"], + age_group_id=future_da["age_group_id"], + location_id=future_da["location_id"], + ) + + if "draw" in past_da.dims and "draw" in future_da.dims: + past_da = resample(past_da, len(future_da.draw.values)) + + if "acause" in future_da.coords: + future_da = future_da.drop_vars("acause") + + if "acause" in past_da.coords: + past_da = past_da.drop_vars("acause") + + if transform != "no-transform": + # NOTE logit transform requires all inputs > 0, but some PAFs can be < 0 + past_da = past_da.where(past_da >= ModelConstants.LOGIT_OFFSET).fillna( + ModelConstants.LOGIT_OFFSET + ) + + processor_class = TransformConstants.TRANSFORMS[transform] + future_da = processor_class.intercept_shift( + modeled_data=future_da, + past_data=past_da, + years=years, + offset=ModelConstants.LOGIT_OFFSET, + intercept_shift="unordered_draw", + shift_from_reference=shift_from_reference, + ) + + return future_da + + +def get_location_draw_omegas( + entity: str, + versions: Versions, + gbd_round_id: int, + stage: str, + omega_df: pd.DataFrame, + locations: List[int], + read_location_draws_fn: Callable = read_location_draws, +) -> xr.DataArray: + """Loop over locations and read location-draw omega files.""" + loc_das = [] + + for location_id in locations: + rows = omega_df.query(f"entity == '{entity}' & location_id == {location_id}") + + if len(rows) == 0: + raise ValueError(f"{entity} for loc {location_id} has no weight info") + + omega_das = [] # to collect the omegas + draw_start = 0 + + for _, row in rows.iterrows(): # each row is an omega-model + omega, model_name, n_draws, intrinsic, subfolder = ( + float(row["omega"]), + str(row["model_name"]), + int(row["draws"]), + bool(row["intrinsic"]), + str(row["subfolder"]), + ) + + if n_draws < 1: # this could happen if inverse_rmse_order == True + continue + + if intrinsic: + file_name = f"{entity}_{SEVConstants.INTRINSIC_SEV_FILENAME_SUFFIX}_{omega}.nc" + else: + file_name = f"{entity}_{omega}.nc" + + version_metadata = versions.get(past_or_future="future", stage=stage) + + file_spec = FHSFileSpec( + version_metadata=version_metadata, + sub_path=(FileSystemConstants.SUBMODEL_FOLDER, model_name, subfolder), + filename=file_name, + ) + + omega_das.append( + read_location_draws_fn(file_spec, location_id, draw_start, n_draws) + ) + + draw_start = draw_start + n_draws + + loc_das.append(xr.concat(omega_das, dim="draw", coords="minimal")) + + future_da = xr.concat(loc_das, dim="location_id", coords="minimal") + + return future_da diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/genem/constants.py b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/constants.py new file mode 100644 index 0000000..2482cba --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/constants.py @@ -0,0 +1,197 @@ +"""FHS Pipeline for BMI forecasting Local Constants.""" + +from fhs_lib_data_transformation.lib import processing +from fhs_lib_database_interface.lib.constants import ( + ScenarioConstants as ImportedScenarioConstants, +) +from frozendict import frozendict + + +class EntityConstants: + """Constants related to entities (e.g. 
acauses).""" + + DEFAULT_ENTITY = "default_entity" + MALARIA_ENTITIES = ["malaria", "malaria_act", "malaria_itn"] + NO_SEX_SPLIT_ENTITY = [ + "abuse_csa_male", + "abuse_csa_female", + "abuse_ipv", + "abuse_ipv_exp", + "met_need", + "nutrition_iron", + "inj_homicide_gun_abuse_ipv_paf", + "inj_homicide_other_abuse_ipv_paf", + "inj_homicide_knife_abuse_ipv_paf", + ] + MALARIA = "malaria" + ACT_ITN_COVARIATE = "act-itn" + + +class LocationConstants: + """Constants used for malaria locations.""" + + # locaions with ACT/ITN interventions + MALARIA_ACT_ITN_LOCS = [ + 168, + 175, + 200, + 201, + 169, + 205, + 202, + 171, + 170, + 178, + 179, + 173, + 207, + 208, + 206, + 209, + 172, + 180, + 210, + 181, + 211, + 184, + 212, + 182, + 213, + 214, + 185, + 522, + 216, + 217, + 187, + 435, + 204, + 218, + 189, + 190, + 191, + 198, + 176, + ] + # locaions without ACT/ITN interventions + NON_MALARIA_ACT_ITN_LOCS = [ + 128, + 129, + 130, + 131, + 132, + 133, + 7, + 135, + 10, + 11, + 12, + 13, + 139, + 15, + 16, + 142, + 18, + 19, + 20, + 152, + 26, + 28, + 157, + 30, + 160, + 161, + 162, + 163, + 164, + 165, + 68, + 203, + 215, + 108, + 111, + 113, + 114, + 118, + 121, + 122, + 123, + 125, + 127, + 193, + 195, + 196, + 197, + 177, + ] + + +class ModelConstants: + """Constants used in forecasting.""" + + FLOOR = 1e-6 + LOGIT_OFFSET = 1e-8 + MIN_RMSE = 1e-8 + + DIFF_OVER_MEAN = True # ARC is computed as the difference over mean values + + MODEL_WEIGHTS_FILE = "all_model_weights.csv" + + +class ScenarioConstants(ImportedScenarioConstants): + """Constants related to scenarios.""" + + DEFAULT_BETTER_QUANTILE = 0.15 + DEFAULT_WORSE_QUANTILE = 0.85 + + +class FileSystemConstants: + """Constants for the file system organization.""" + + PV_FOLDER = "pv" + SUBMODEL_FOLDER = "sub_models" + + +class SEVConstants: + """Constants used in SEVs forecasting.""" + + INTRINSIC_SEV_FILENAME_SUFFIX = "intrinsic" + + +class TransformConstants: + """Constants for transformations used during entity forecasting.""" + + TRANSFORMS = frozendict( + { + "logit": processing.LogitProcessor, + "log": processing.LogProcessor, + "no-transform": processing.NoTransformProcessor, + } + ) + + +class JobConstants: + """Constants related to submitting jobs.""" + + EXECUTABLE = "fhs_lib_genem_console" + DEFAULT_RUNTIME = "12:00:00" + + COLLECT_SUBMODELS_RUNTIME = "05:00:00" + MRBRT_RUNTIME = "16:00:00" + + + +class OrchestrationConstants: + """Constants used for ensemble model orchestration.""" + + OMEGA_MIN = 0.0 + OMEGA_MAX = 3.0 + OMEGA_STEP_SIZE = 0.5 + + SUBFOLDER = "risk_acause_specific" + + PV_SUFFIX = "_pv" + N_HOLDOUT_YEARS = 10 # number of holdout years for predictive validity runs + + ARC_TRANSFORM = "logit" + ARC_TRUNCATE_QUANTILES = (0.025, 0.975) + ARC_REFERENCE_SCENARIO = "mean" diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/genem/create_stage.py b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/create_stage.py new file mode 100644 index 0000000..5fa3e1d --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/create_stage.py @@ -0,0 +1,539 @@ +from typing import Any, Dict, Iterable, List, Optional + +from fhs_lib_database_interface.lib.fhs_lru_cache import fhs_lru_cache +from fhs_lib_file_interface.lib.version_metadata import VersionMetadata +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_orchestration_interface.lib import cluster_tools +from fhs_lib_year_range_manager.lib.year_range import YearRange +from jobmon.client.api import Tool +from 
jobmon.client.task import Task +from typeguard import typechecked + +from fhs_lib_genem.lib.constants import JobConstants, OrchestrationConstants +from fhs_lib_genem.lib.model_restrictions import ModelRestrictions +from fhs_lib_genem.run.task import ( + get_arc_task, + get_collect_submodels_task, + get_model_weights_task, + get_stagewise_mrbrt_task, +) + +INTERCEPT_SHIFT_FROM_REFERENCE_DEFAULT = True + + +@typechecked +def create_genem_tasks( + tool: Tool, + cluster_project: str, + entities: Iterable[str], + intrinsic: Optional[Dict[str, bool]], + stage: str, + versions: Versions, + gbd_round_id: int, + years: YearRange, + draws: int, + transform: str, + mrbrt_cov_stage1: str, + mrbrt_cov_stage2: str, + national_only: bool, + scenario_quantiles: bool, + model_restrictions: ModelRestrictions, + subfolder: Optional[str], + uncross_scenarios: bool, + age_standardize: bool, + remove_zero_slices: bool, + rescale_ages: bool, + log_level: Optional[str], + intercept_shift_transform: str, + intercept_shift_from_reference: bool = INTERCEPT_SHIFT_FROM_REFERENCE_DEFAULT, + run_pv: bool = False, + run_forecast: bool = False, + run_model_weights: bool = False, + run_collect_models: bool = False, +) -> List[Task]: + """Create all the tasks for the genem.""" + validate_transform_specification( + transform=transform, intercept_shift_transform=intercept_shift_transform + ) + + run_all_steps = not any([run_pv, run_forecast, run_model_weights, run_collect_models]) + + if run_all_steps or run_pv: + # run all genem submodels for predictive validity (pv) in parallel + pv_tasks = forecast_submodels_stage( + tool=tool, + entities=entities, + intrinsic=intrinsic, + cluster_project=cluster_project, + stage=stage, + predictive_validity=True, + versions=versions, + gbd_round_id=gbd_round_id, + years=years, + draws=draws, + transform=transform, + mrbrt_cov_stage1=mrbrt_cov_stage1, + mrbrt_cov_stage2=mrbrt_cov_stage2, + national_only=national_only, + scenario_quantiles=scenario_quantiles, + remove_zero_slices=remove_zero_slices, + subfolder=subfolder, + age_standardize=age_standardize, + rescale_ages=rescale_ages, + log_level=log_level, + ) + else: + pv_tasks = [] + + if run_all_steps or run_forecast: + # run all genem submodel forecasts in parallel + forecast_tasks = forecast_submodels_stage( + tool=tool, + cluster_project=cluster_project, + entities=entities, + intrinsic=intrinsic, + stage=stage, + predictive_validity=False, + versions=versions, + gbd_round_id=gbd_round_id, + years=years, + draws=draws, + transform=transform, + mrbrt_cov_stage1=mrbrt_cov_stage1, + mrbrt_cov_stage2=mrbrt_cov_stage2, + national_only=national_only, + scenario_quantiles=scenario_quantiles, + remove_zero_slices=remove_zero_slices, + subfolder=subfolder, + age_standardize=age_standardize, + rescale_ages=rescale_ages, + log_level=log_level, + ) + else: + forecast_tasks = [] + + if run_all_steps or run_model_weights: + # get model weights from pv results. + # This step calculates weights for the collection step, using the predictive validity + # calculated in the pv step. 
+ model_weights_tasks = model_weights_stage( + tool=tool, + cluster_project=cluster_project, + out_version=versions.get_version_metadata(past_or_future="future", stage=stage), + gbd_round_id=gbd_round_id, + draws=draws, + mrbrt_cov_stage1=mrbrt_cov_stage1, + mrbrt_cov_stage2=mrbrt_cov_stage2, + model_restrictions=model_restrictions, + log_level=log_level, + ) + for model_weights_task in model_weights_tasks: + for pv_task in pv_tasks: + model_weights_task.add_upstream(pv_task) + else: + model_weights_tasks = [] + + if run_all_steps or run_collect_models: + # collect draws from submodels to make genem + # This step combines the forecasts generated in the forecast step, using the weights + # calculated during the weights step. + + transform_spec = ( + transform if intercept_shift_transform == "none" else intercept_shift_transform + ) + + collect_models_tasks = collect_submodels_stage( + tool=tool, + entities=entities, + cluster_project=cluster_project, + stage=stage, + versions=versions, + gbd_round_id=gbd_round_id, + years=years, + transform=transform_spec, + intercept_shift_from_reference=intercept_shift_from_reference, + uncross_scenarios=uncross_scenarios, + log_level=log_level, + ) + for collect_models_task in collect_models_tasks: + for forecast_task in forecast_tasks: + collect_models_task.add_upstream(forecast_task) + + for model_weights_task in model_weights_tasks: + collect_models_task.add_upstream(model_weights_task) + else: + collect_models_tasks = [] + + return pv_tasks + forecast_tasks + model_weights_tasks + collect_models_tasks + + +@typechecked +def forecast_submodels_stage( + tool: Tool, + cluster_project: str, + entities: Iterable[str], + intrinsic: Optional[Dict[str, bool]], + stage: str, + versions: Versions, + gbd_round_id: int, + years: YearRange, + draws: int, + transform: str, + mrbrt_cov_stage1: str, + mrbrt_cov_stage2: str, + predictive_validity: bool, + national_only: bool, + scenario_quantiles: bool, + remove_zero_slices: bool, + subfolder: Optional[str], + age_standardize: bool, + rescale_ages: bool, + log_level: Optional[str], +) -> List[Task]: + """Make tasks for genem submodels. + + The submodels are: + 1.) arc + 2.) mrbrt + + Args: + tool: Jobmon tool to associate tasks with + cluster_project: cluster project to run tasks under + entities: the entities to forecast + intrinsic: optional mapping of entities to intrinsic values. If not provided, + its assumed that no entities are intrinsic. + stage (str): stage to model + versions (Versions): versions of all inputs, outputs and covariates (past and future). + precalculated_version (str): version name of precalculated sevs. + gbd_round_id (int): gbd round id. + years (YearRange): past_start:forecast_start:forecast_end. + draws (int): number of draws to keep. + transform (str): transformation to perform on data before modeling. + mrbrt_cov_stage1 (str): The covariate name to be used in the MRBRT first stage. + mrbrt_cov_stage2 (str): The covariate name to be used in the MRBRT second stage. + predictive_validity (bool): whether this is a predictive-validity stage. + national_only (bool): whether to only compute for nationals. + scenario_quantiles (bool): If True, then use the scenario quantiles parameter in + the stagewise-mrbrt model specification. + subfolder (str): subfolder to read/write + age_standardize (bool): whether to age-standardize the data before modeling + rescale_ages (bool): whether to rescale during ARC age standardization. We are + currently only setting this to true for the sevs pipeline. 
+ log_level: log_level to use for tasks + + Returns: + List of tasks + """ + subfolder = subfolder or "" + + tasks = [] + for entity in entities: + entity_intrinsic = intrinsic[entity] if intrinsic else False + + tasks += genem_tasks( + tool=tool, + cluster_project=cluster_project, + entity=entity, + stage=stage, + versions=versions, + gbd_round_id=gbd_round_id, + years=years, + draws=draws, + transform=transform, + mrbrt_cov_stage1=mrbrt_cov_stage1, + mrbrt_cov_stage2=mrbrt_cov_stage2, + predictive_validity=predictive_validity, + national_only=national_only, + scenario_quantiles=scenario_quantiles, + remove_zero_slices=remove_zero_slices, + intrinsic=entity_intrinsic, + subfolder=OrchestrationConstants.SUBFOLDER if entity_intrinsic else subfolder, + age_standardize=age_standardize, + rescale_ages=rescale_ages, + log_level=log_level, + ) + + return tasks + + +@typechecked +def genem_tasks( + tool: Tool, + cluster_project: str, + entity: str, + stage: str, + versions: Versions, + gbd_round_id: int, + years: YearRange, + draws: int, + transform: str, + mrbrt_cov_stage1: str, + mrbrt_cov_stage2: str, + predictive_validity: bool, + national_only: bool, + scenario_quantiles: bool, + remove_zero_slices: bool, + intrinsic: bool, + subfolder: Optional[str], + age_standardize: bool, + rescale_ages: bool, + log_level: Optional[str], +) -> List[Task]: + """Make tasks for the genem submodels. + + Args: + tool: Jobmon tool to associate tasks with + cluster_project: cluster project to run tasks under + entity (str): "all", or individual risk/cause-risk. + stage (str): stage to forecast. + versions (Versions): versions of all inputs, outputs and covariates (past and future). + gbd_round_id (int): gbd round id. + years (YearRange): past_start:forecast_start:forecast_end. + draws (int): number of draws to keep. + transform (str): transformation to perform on data before modeling. + mrbrt_cov_stage1 (str): The covariate name to be used in the MRBRT frst stage. + mrbrt_cov_stage2 (str): The covariate name to be used in the MRBRT second stage. + predictive_validity (bool): whether this is a predictive-validity job. + national_only (bool): whether to only compute for nationals. + scenario_quantiles (bool): If True, then use the scenario quantiles parameter in + the stagewise-mrbrt model specification. + intrinsic (bool): whether entity is "intrinsic" (SEV only). + subfolder (Optional[str]): input/ouput data subfolder. + age_standardize (bool): whether to age-standardize before modeling. + rescale_ages (bool): whether to rescale during ARC age standardization. We are + currently only setting this to true for the sevs pipeline. 
+ log_level: log_level to use for tasks + + Returns: + List of tasks + """ + model_name_map = make_model_name_map( + mrbrt_cov_stage1=mrbrt_cov_stage1, + mrbrt_cov_stage2=mrbrt_cov_stage2, + ) + pv_forecast_start = years.past_end - (OrchestrationConstants.N_HOLDOUT_YEARS - 1) + pv_years = YearRange(years.past_start, pv_forecast_start, years.past_end) + + arc_compute_resources = get_compute_resources( + memory_gb=JobConstants.JOB_MEMORY, + cluster_project=cluster_project, + runtime=JobConstants.DEFAULT_RUNTIME, + ) + + mrbrt_compute_resources = get_compute_resources( + memory_gb=JobConstants.MRBRT_MEM_GB, + cluster_project=cluster_project, + runtime=JobConstants.MRBRT_RUNTIME, + ) + + tasks = [] + arc_task = get_arc_task( + tool=tool, + compute_resources=arc_compute_resources, + entity=entity, + stage=stage, + intrinsic=intrinsic, + subfolder=subfolder, + versions=versions, + model_name=model_name_map["arc"], + omega_min=OrchestrationConstants.OMEGA_MIN, + omega_max=OrchestrationConstants.OMEGA_MAX, + omega_step_size=OrchestrationConstants.OMEGA_STEP_SIZE, + transform=transform, + truncate=True, + truncate_quantiles=OrchestrationConstants.ARC_TRUNCATE_QUANTILES, + replace_with_mean=False, + reference_scenario=OrchestrationConstants.ARC_REFERENCE_SCENARIO, + years=pv_years if predictive_validity else years, + gbd_round_id=gbd_round_id, + cap_forecasts=False, + cap_quantiles=None, + national_only=national_only, + age_standardize=age_standardize, + rescale_ages=rescale_ages, + predictive_validity=predictive_validity, + remove_zero_slices=remove_zero_slices, + log_level=log_level, + ) + + tasks.append(arc_task) + + for mrbrt_cov_stage2, mrbrt_name in model_name_map["mrbrt"].items(): + stagewise_mrbrt_task = get_stagewise_mrbrt_task( + tool=tool, + compute_resources_dict=mrbrt_compute_resources, + years=pv_years if predictive_validity else years, + gbd_round_id=gbd_round_id, + versions=versions, + model_name=mrbrt_name, + stage=stage, + entity=entity, + draws=draws, + mrbrt_cov_stage1=mrbrt_cov_stage1, + mrbrt_cov_stage2=mrbrt_cov_stage2, + omega_min=OrchestrationConstants.OMEGA_MIN, + omega_max=OrchestrationConstants.OMEGA_MAX, + step=OrchestrationConstants.OMEGA_STEP_SIZE, + transform=transform, + predictive_validity=predictive_validity, + national_only=national_only, + scenario_quantiles=scenario_quantiles, + remove_zero_slices=remove_zero_slices, + intrinsic=intrinsic, + age_standardize=age_standardize, + subfolder=subfolder, + log_level=log_level, + ) + tasks.append(stagewise_mrbrt_task) + + return tasks + + +@fhs_lru_cache(1) +def make_model_name_map(mrbrt_cov_stage1: str, mrbrt_cov_stage2: str) -> dict: + """Determine arc/mrbrt model_names based on covariates. + + NOTE that "arc" maps to a single string model name, whereas "mrbrt" maps to a dict of + model names. + + Args: + mrbrt_cov_stage1 (str): The covariate name to be used in the MRBRT first stage. + mrbrt_cov_stage2 (str): The covariate name to be used in the MRBRT second stage. + + Returns: + (dict): dictionary mapping submodel type to version metadata. 
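The predictive-validity runs in `genem_tasks` above reuse the past window by treating the last `N_HOLDOUT_YEARS` past years as pseudo-forecast years. With `OrchestrationConstants.N_HOLDOUT_YEARS = 10` and a hypothetical `1990:2023:2050` year range, and assuming the usual convention that `past_end` is the last year before `forecast_start`, the arithmetic works out as follows:

```python
# Hypothetical full year range: past 1990-2022, forecast 2023-2050.
past_start, past_end = 1990, 2022
n_holdout_years = 10  # OrchestrationConstants.N_HOLDOUT_YEARS

pv_forecast_start = past_end - (n_holdout_years - 1)  # 2013

# pv_years = YearRange(1990, 2013, 2022): submodels are fit on 1990-2012 and
# asked to "forecast" 2013-2022, which are then scored against observed data.
holdout_years = list(range(pv_forecast_start, past_end + 1))
assert len(holdout_years) == n_holdout_years
```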
+ """ + mrbrt_map = { + cov_stage2: f"{mrbrt_cov_stage1}_{cov_stage2}" for cov_stage2 in [mrbrt_cov_stage2] + } + return {"arc": "arc", "mrbrt": mrbrt_map} + + +@typechecked +def model_weights_stage( + tool: Tool, + cluster_project: str, + out_version: VersionMetadata, + gbd_round_id: int, + draws: int, + model_restrictions: ModelRestrictions, + mrbrt_cov_stage1: str, + mrbrt_cov_stage2: str, + log_level: Optional[str], +) -> List[Task]: + """Generate the task for the model weights calculation.""" + model_name_map = make_model_name_map( + mrbrt_cov_stage1=mrbrt_cov_stage1, + mrbrt_cov_stage2=mrbrt_cov_stage2, + ) + submodel_names = [model_name_map["arc"]] + list(model_name_map["mrbrt"].values()) + + compute_resources = get_compute_resources( + memory_gb=JobConstants.MODEL_WEIGHTS_MEM_GB, + runtime=JobConstants.DEFAULT_RUNTIME, + cluster_project=cluster_project, + ) + + task = get_model_weights_task( + tool=tool, + compute_resources=compute_resources, + submodel_names=submodel_names, + subfolder=OrchestrationConstants.SUBFOLDER, + out_version=out_version, + gbd_round_id=gbd_round_id, + draws=draws, + model_restrictions=model_restrictions, + log_level=log_level, + ) + return [task] + + +@typechecked +def collect_submodels_stage( + tool: Tool, + cluster_project: str, + entities: Iterable[str], + stage: str, + versions: Versions, + gbd_round_id: int, + years: YearRange, + transform: str, + intercept_shift_from_reference: bool, + uncross_scenarios: bool, + log_level: Optional[str], +) -> List[Task]: + """Generate tasks for collecting genem results. + + Args: + tool: Jobmon tool to associate tasks with + cluster_project: cluster project to run tasks under + entities: entities to forecast. + stage (str): stage to forecast. + versions (Versions): versions with input and output versions. + gbd_round_id (int): gbd round id. + years (YearRange): past_start:forecast_start:forecast_end. + transform (str): name of transformation to use when intercept shifting. + intercept_shift_from_reference (bool): If True, and we are in multi-scenario mode, then + the intercept-shifting during the above `transform` is calculated from the + reference scenario but applied to all scenarios; if False then each scenario will + get its own shift amount. + uncross_scenarios (bool): whether to fix crossed scenarios. This is currently only used + for sevs and should be deprecated soon. + remove_zero_slices (bool): If True, remove zero-slices along certain dimensions, when + pre-processing inputs, and add them back in to outputs. 
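Given the single stage-2 covariate passed in, `make_model_name_map` produces one ARC name plus one MRBRT name keyed by that covariate; the MRBRT name is what later shows up as `model_name` in the model-weights file. A quick check with hypothetical covariate names, mirroring the dict comprehension above:

```python
# Hypothetical covariates; not asserting what the pipeline actually uses.
mrbrt_cov_stage1, mrbrt_cov_stage2 = "sdi", "frontier"

mrbrt_map = {
    cov_stage2: f"{mrbrt_cov_stage1}_{cov_stage2}" for cov_stage2 in [mrbrt_cov_stage2]
}
model_name_map = {"arc": "arc", "mrbrt": mrbrt_map}

assert model_name_map == {"arc": "arc", "mrbrt": {"frontier": "sdi_frontier"}}
```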
+ log_level: log_level to use for tasks + """ + compute_resources = get_compute_resources( + memory_gb=JobConstants.COLLECT_SUBMODELS_MEM_GB, + runtime=JobConstants.COLLECT_SUBMODELS_RUNTIME, + cluster_project=cluster_project, + cores=JobConstants.COLLECT_SUBMODELS_NUM_CORES, + ) + + tasks = [] + for rei in entities: + task = get_collect_submodels_task( + compute_resources=compute_resources, + tool=tool, + entity=rei, + stage=stage, + versions=versions, + gbd_round_id=gbd_round_id, + years=years, + transform=transform, + intercept_shift_from_reference=intercept_shift_from_reference, + uncross_scenarios=uncross_scenarios, + log_level=log_level, + ) + tasks.append(task) + + return tasks + + +def get_compute_resources( + memory_gb: int, + cluster_project: str, + runtime: str, + cores: int = JobConstants.DEFAULT_CORES, +) -> Dict[str, Any]: + """Return a dictionary containing keys & values required for Jobmon Task creation.""" + error_logs_dir, output_logs_dir = cluster_tools.get_logs_dirs() + + return dict( + memory=f"{memory_gb}G", + cores=cores, + runtime=runtime, + project=cluster_project, + queue=JobConstants.DEFAULT_QUEUE, + stderr=error_logs_dir, + stdout=output_logs_dir, + ) + + +def validate_transform_specification(transform: str, intercept_shift_transform: str) -> None: + """Ensure ``intercept_shift_transform`` is not none when ``transform`` is no-transform.""" + if transform == "no-transform" and intercept_shift_transform == "none": + raise ValueError( + "When ``--transform no-transform`` is set, ``--intercept-shift-transform`` cannot " + "be ``none``." + ) diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/genem/get_model_weights_from_holdouts.py b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/get_model_weights_from_holdouts.py new file mode 100644 index 0000000..5f07b08 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/get_model_weights_from_holdouts.py @@ -0,0 +1,212 @@ +"""Collects submodel predictive validity statistics to compile sampling weights for genem. + +The output is a file called `all_model_weights.csv` in the out version +""" + +import glob +import os +import re +from typing import List, Tuple + +import numpy as np +import pandas as pd +from fhs_lib_file_interface.lib.pandas_wrapper import read_csv, write_csv +from fhs_lib_file_interface.lib.version_metadata import ( + FHSDirSpec, + FHSFileSpec, + VersionMetadata, +) + +from fhs_lib_genem.lib.constants import FileSystemConstants, ModelConstants, SEVConstants +from fhs_lib_genem.lib.model_restrictions import ModelRestrictions + + +def pv_file_name_breakup(file_path: str) -> Tuple[str, str]: + r"""Parse full path to predictive-validity (PV) file, extracting entity name and suffix. + + Ex: if file_path is FILEPATH, + then entity = "dud", suffix = "_pv.csv". + If \*/dud_intrinsic_pv.csv, then entity = "dud", + suffix = "_intrinsic_pv.csv". + + Args: + file_path (str): full pv-file path, expected to end with "_pv.csv". 
+ + Returns: + Tuple[str, str]: entity and suffix + """ + filename = os.path.basename(file_path) + match = re.match(r"(.*?)((_intrinsic)?_pv.csv)", filename) + if not match: + raise ValueError( + "PV file path should be of the form 'foo_pv.csv' or 'foo_intrinsic_pv.csv'" + ) + entity, suffix, _ = match.groups() + if not entity or not suffix: + raise ValueError( + "PV file path should be of the form 'foo_pv.csv' or 'foo_intrinsic_pv.csv'" + ) + return entity, suffix + + +def collect_model_rmses( + out_version: VersionMetadata, + gbd_round_id: int, + submodel_names: List[str], + subfolder: str, +) -> pd.DataFrame: + """Collect submodel omega rmse values into a dataframe. + + Loops over pv versions, parses all entity-specific _pv.csv files, + including subfolders. + + Args: + out_version (VersionMetadata): the output version for the whole model. Where this + function looks for submodels. + gbd_round_id (int): gbd_round_id used in the model; used for looking for submodels if + not provided with out_version + submodel_names (List[str]): names of all the sub-models to collect. + subfolder (str): subfolder name where intrinsics are stored. + + Returns: + (pd.DataFrame): Dataframe that contains all columns needed to compute + ensemble weights. + """ + combined_pv_df = pd.DataFrame([]) + + # loop over versions and entities + for submodel_name in submodel_names: + input_dir_spec = FHSDirSpec( + version_metadata=out_version, + sub_path=( + FileSystemConstants.PV_FOLDER, + submodel_name, + ), + ) + subfolder_dir_spec = FHSDirSpec( + version_metadata=out_version, + sub_path=( + FileSystemConstants.PV_FOLDER, + submodel_name, + subfolder, + ), + ) + + if not input_dir_spec.data_path().exists(): + raise FileNotFoundError(f"No such directory {input_dir_spec.data_path()}") + files = glob.glob(str(input_dir_spec.data_path() / "*_pv.csv")) + entities = dict([pv_file_name_breakup(file_path) for file_path in files]) + + sub_entities = {} + if (subfolder_dir_spec.data_path()).exists(): # check out the subfolder + files = glob.glob(str(subfolder_dir_spec.data_path() / "*_pv.csv")) + sub_entities = dict([pv_file_name_breakup(file_path) for file_path in files]) + entities.update(sub_entities) + + for ent, suffix in entities.items(): + sub_dir = subfolder if ent in sub_entities else "" + suffix = entities[ent] + + input_file_spec = FHSFileSpec( + version_metadata=input_dir_spec.version_metadata, + # Note that we have to manually alter the sub_path, otherwise we'd be using + # FHSFileSpec.from_dirspec() + sub_path=tuple(list(input_dir_spec.sub_path) + [sub_dir]), + filename=ent + suffix, + ) + pv_df = read_csv(input_file_spec, keep_default_na=False) + + pv_df["model_name"] = submodel_name + pv_df["subfolder"] = sub_dir + pv_df["intrinsic"] = ( + True if SEVConstants.INTRINSIC_SEV_FILENAME_SUFFIX in suffix else False + ) + combined_pv_df = combined_pv_df.append(pv_df) + + # just to move the "entity" column to the front + if "entity" in combined_pv_df.columns: + combined_pv_df = combined_pv_df[ + ["entity"] + [col for col in combined_pv_df.columns if col != "entity"] + ] + + return combined_pv_df + + +def make_omega_weights( + submodel_names: List[str], + subfolder: str, + out_version: VersionMetadata, + gbd_round_id: int, + draws: int, + model_restrictions: ModelRestrictions, +) -> None: + """Collect submodel omega rmse values into a dataframe. + + Loops over pv versions, parses all entity-specific _pv.csv files, + including subfolders. + + Args: + submodel_names (List[str]): names of all the sub-models to collect. 
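`pv_file_name_breakup` relies on a non-greedy regex, so an `_intrinsic` marker is pulled into the suffix rather than the entity name. A standalone check of the two documented cases, reproducing the same pattern under a hypothetical name (paths are illustrative):

```python
import os
import re


def pv_file_name_breakup_demo(file_path: str):
    # Same pattern as the function above, reproduced for a standalone check.
    filename = os.path.basename(file_path)
    entity, suffix, _ = re.match(r"(.*?)((_intrinsic)?_pv.csv)", filename).groups()
    return entity, suffix


assert pv_file_name_breakup_demo("out/pv/arc/dud_pv.csv") == ("dud", "_pv.csv")
assert pv_file_name_breakup_demo("out/pv/arc/dud_intrinsic_pv.csv") == (
    "dud",
    "_intrinsic_pv.csv",
)
```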
+ gbd_round_id (int): gbd round id. + subfolder (str): subfolder name where intrinsics are stored. + out_version (VersionMetadata): the output version for the whole model. Where this + function looks for submodels. + gbd_round_id (int): gbd_round_id used in the model; used for looking for submodels if + not provided with out_version + draws (int): number of total draws for the ensemble. + model_restrictions (ModelRestrictions): any arc-only, mrbrt-only restrictions. + """ + df = collect_model_rmses( + out_version=out_version, + gbd_round_id=gbd_round_id, + submodel_names=submodel_names, + subfolder=subfolder, + ) + + out = pd.DataFrame([]) + + for entity in df["entity"].unique(): + for location_id in df["location_id"].unique(): + ent_loc_df = df.query(f"entity == '{entity}' & location_id == {location_id}") + + model_type = model_restrictions.model_type(entity, location_id) + + if model_type == "arc": + # we effectively pull 0 draws from those where rmse == np.inf + ent_loc_df.loc[ent_loc_df["model_name"] != "arc", "rmse"] = np.inf + + if model_type == "mrbrt": + # we effectively pull 0 draws from those where rmse == np.inf + ent_loc_df.loc[ent_loc_df["model_name"] == "arc", "rmse"] = np.inf + + # we use rmse values to determine draws sampled from submodels + ent_loc_df = ent_loc_df.sort_values(by="rmse", ascending=True) + + # use 1/rmse to determine weight/draws + rmse = ent_loc_df["rmse"] + ModelConstants.MIN_RMSE # padding in case of 0 + rmse_recip = 1 / rmse + model_wts = rmse_recip / rmse_recip.sum() + + # lowest rmse contributes the most draws + sub_draws = (np.round(model_wts, 3) * draws).astype(int) + + # in the event that sum(sub_draws) != draws, we make up the diff + # by adding the diff to the first element + if sub_draws.sum() != draws: + sub_draws.iloc[0] += draws - sub_draws.sum() + + # now assign sub-model weight and draws to df + ent_loc_df["model_weight"] = model_wts + ent_loc_df["draws"] = sub_draws + + out = out.append(ent_loc_df) + + write_csv( + df=out, + file_spec=FHSFileSpec( + version_metadata=out_version, filename=ModelConstants.MODEL_WEIGHTS_FILE + ), + sep=",", + na_rep=".", + index=False, + ) diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/genem/model_restrictions.py b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/model_restrictions.py new file mode 100644 index 0000000..377d35f --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/model_restrictions.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import Any, Iterable, List, Tuple, Union + +import yaml + +ALL_CATEGORIES = "all" +BOTH_MODELS = "both" + + +class ModelRestrictions: + """A class for capturing restrictions in which models get run for each entity/location.""" + + def __init__(self, restrictions: Iterable[Tuple[str, Union[int, str], str]]) -> None: + """Initializer. + + Args: + restrictions: A list of tuples each of which specifies a particular restriction. + The tuples contain an entity, a location and the model type to use (in that + order). The entity and location can also be "all", to indicate that it + applies to all entities. Entity-specific restrictions take precedence over + location-specific restrictions. 
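The weight-to-draw conversion in `make_omega_weights` above is easiest to follow with concrete numbers: inverse-RMSE weights are rounded to three decimals, multiplied by the draw count, and any rounding shortfall is folded into the lowest-RMSE submodel so the total is exact. A simplified numeric walk-through with hypothetical RMSE values (ignoring the `MIN_RMSE` padding and the restriction handling):

```python
import numpy as np
import pandas as pd

draws = 1000  # total ensemble draws (hypothetical)

# Hypothetical RMSEs for four submodels of one entity/location,
# already sorted ascending as in make_omega_weights.
ent_loc_df = pd.DataFrame(
    {
        "model_name": ["arc", "sdi_frontier", "arc", "sdi_frontier"],
        "omega": [1.0, 0.5, 2.5, 2.0],
        "rmse": [0.10, 0.20, 0.25, 0.50],
    }
)

# Inverse-RMSE weights: the best (lowest-RMSE) submodel gets the most draws.
rmse_recip = 1 / ent_loc_df["rmse"]
model_wts = rmse_recip / rmse_recip.sum()
sub_draws = (np.round(model_wts, 3) * draws).astype(int)

# Rounding can leave the total slightly off; the difference is added to the
# first (lowest-RMSE) submodel so the allocation sums to exactly `draws`.
if sub_draws.sum() != draws:
    sub_draws.iloc[0] += draws - sub_draws.sum()

ent_loc_df["draws"] = sub_draws
print(ent_loc_df["draws"].tolist(), int(ent_loc_df["draws"].sum()))
```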
+ """ + self._original_specification = list(restrictions) + self._map = defaultdict(dict) + for restriction in self._original_specification: + entity = restriction[0] + location = ( + ALL_CATEGORIES if restriction[1] == ALL_CATEGORIES else int(restriction[1]) + ) + model_type = restriction[2] + + if location in self._map[entity]: + raise ValueError( + f"Restriction list includes multiple restrictions for {entity}/{location}" + ) + + self._map[entity][location] = model_type + + def model_type(self, entity: str, location_id: int) -> str: + """Get the model type to use for an entity and location, according to the restrictions. + + Args: + entity: entity to look up the restriction for. + location: location to look up the restriction for. + + Returns: + Either a model type name to use for the entity/location_id, or "both", to indicate + that both model types should be used. + + """ + if entity in self._map: + if location_id in self._map[entity]: + return self._map[entity][location_id] + elif ALL_CATEGORIES in self._map[entity]: + return self._map[entity][ALL_CATEGORIES] + else: + return BOTH_MODELS + elif ALL_CATEGORIES in self._map: + if location_id in self._map[ALL_CATEGORIES]: + return self._map[ALL_CATEGORIES][location_id] + elif ALL_CATEGORIES in self._map[ALL_CATEGORIES]: + return self._map[ALL_CATEGORIES][ALL_CATEGORIES] + else: + return BOTH_MODELS + else: + return BOTH_MODELS + + def string_specifications(self) -> List[str]: + """Returns a list of strings for the specifications used to initialize the object. + + Useful for serializing the object on the command-line, essentially a representation of + the way the object was initialized. + """ + return [ + " ".join([str(field) for field in spec]) for spec in self._original_specification + ] + + def __eq__(self, other: ModelRestrictions) -> bool: + """Do they have the same underlying dict?""" + return self._map == other._map + + @staticmethod + def yaml_representer(dumper: Any, data: ModelRestrictions) -> str: + """Function for passing to pyyaml telling it how to represent ModelRestrictions. + + This specific tag used tells pyyaml not tuse a tag. 
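The precedence rules in `model_type` above are easiest to see with a small example: an entity-specific rule beats a location-only rule, and anything not covered falls back to running both submodel types. Assuming the package is importable as in `create_stage.py` (entity and location values are illustrative):

```python
from fhs_lib_genem.lib.model_restrictions import ModelRestrictions

restrictions = ModelRestrictions(
    [
        ("metab_bmi", "all", "arc"),  # metab_bmi in every location: ARC only
        ("all", 102, "mrbrt"),        # every other entity in location 102: MRBRT only
    ]
)

assert restrictions.model_type("metab_bmi", 102) == "arc"   # entity rule wins
assert restrictions.model_type("smoking", 102) == "mrbrt"   # location rule applies
assert restrictions.model_type("smoking", 6) == "both"      # no rule: run both submodels
```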
+ + Args: + dumper: pyyaml dumper + data: ModelRestrictions object ot serialize + """ + return dumper.represent_sequence("tag:yaml.org,2002:seq", data._original_specification) + + +# register the yaml_representer with pyyaml's safedumper (which is what we use to write yaml) +yaml.SafeDumper.add_representer( + ModelRestrictions, ModelRestrictions.yaml_representer +) diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/genem/predictive_validity.py b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/predictive_validity.py new file mode 100644 index 0000000..ecf0194 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/predictive_validity.py @@ -0,0 +1,106 @@ +from typing import Iterable, Optional + +import numpy as np +import pandas as pd +import xarray as xr +from fhs_lib_database_interface.lib.constants import DimensionConstants +from fhs_lib_file_interface.lib.pandas_wrapper import write_csv +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions + +from fhs_lib_genem.lib.constants import FileSystemConstants + +OMEGA_DIM = "omega" + + +def get_omega_weights(min: float, max: float, step: float) -> Iterable[float]: + """Return the list of weights between ``min`` and ``max``, incrementing by ``step``.""" + return np.arange(min, max, step) + + +def root_mean_square_error( + predicted: xr.DataArray, + observed: xr.DataArray, + dims: Optional[Iterable[str]] = None, +) -> xr.DataArray: + """Dimensions-specific root-mean-square-error. + + Args: + predicted (xr.DataArray): predicted values. + observed (xr.DataArray): observed values. + dims (Optional[Iterable[str]]): list of dims to compute rms for. + + Returns: + (xr.DataArray): root-mean-square error, dims-specific. 
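`get_omega_weights` is a thin wrapper around `np.arange`, so the grid is half-open: with the orchestration defaults above (`OMEGA_MIN = 0.0`, `OMEGA_MAX = 3.0`, `OMEGA_STEP_SIZE = 0.5`) the maximum omega itself is not included:

```python
import numpy as np

omega_min, omega_max, step = 0.0, 3.0, 0.5
omegas = np.arange(omega_min, omega_max, step)
print(omegas)  # [0.  0.5 1.  1.5 2.  2.5] -- 3.0 is excluded
```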
+ """ + dims = dims or [] + + squared_error = (predicted - observed) ** 2 + other_dims = [d for d in squared_error.dims if d not in dims] + return np.sqrt(squared_error.mean(dim=other_dims)) + + +def calculate_predictive_validity( + forecast: xr.DataArray, + holdouts: xr.DataArray, + omega: float, +) -> xr.DataArray: + """Calculate the RMSE between ``forecast`` and ``holdouts`` across location & sex.""" + # Take the mean over draw if forecast or holdouts data has them + if DimensionConstants.DRAW in forecast.dims: + forecast_mean = forecast.mean(DimensionConstants.DRAW) + else: + forecast_mean = forecast + + if DimensionConstants.DRAW in holdouts.dims: + holdouts_mean = holdouts.mean(DimensionConstants.DRAW) + else: + holdouts_mean = holdouts + + # Calculate RMSE + pv_data = root_mean_square_error( + predicted=forecast_mean.sel(scenario=0, drop=True), + observed=holdouts_mean, + dims=[DimensionConstants.LOCATION_ID, DimensionConstants.SEX_ID], + ) + + # Tag the data with a hard-coded attribute & return it + pv_data[OMEGA_DIM] = omega + return pv_data + + +def finalize_pv_data(pv_list: Iterable[xr.DataArray], entity: str) -> pd.DataFrame: + """Convert a list of PV xarrays into a pandas dataframe, and take the mean over sexes.""" + pv_xr = xr.concat(pv_list, dim=OMEGA_DIM) + + # mean over sexes (if it's present) + if DimensionConstants.SEX_ID in pv_xr.dims: + pv_xr = pv_xr.mean([DimensionConstants.SEX_ID]) + + pv_xr["entity"] = entity + return pv_xr.to_dataframe(name="rmse").reset_index() + + +def save_predictive_validity( + file_name: str, + gbd_round_id: int, + model_name: str, + pv_df: pd.DataFrame, + stage: str, + subfolder: str, + versions: Versions, +) -> None: + """Write a predictive validity dataframe to disk.""" + # Define the output file spec + pv_file_spec = FHSFileSpec( + version_metadata=versions.get(past_or_future="future", stage=stage), + sub_path=( + FileSystemConstants.PV_FOLDER, + model_name, + subfolder, + ), + filename=f"{file_name}_pv.csv", + ) + + # Write the dataframe (note that the pv output directory is "{out_version}_pv") + write_csv(df=pv_df, file_spec=pv_file_spec, sep=",", na_rep=".", index=False) diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/genem/run_stagewise_mrbrt.py b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/run_stagewise_mrbrt.py new file mode 100644 index 0000000..ae79544 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/genem/run_stagewise_mrbrt.py @@ -0,0 +1,576 @@ +"""A script that forecasts entities using MRBRT.""" + +from typing import Dict, List, Optional + +import numpy as np +import pandas as pd +import xarray as xr +from fhs_lib_data_transformation.lib import filter, processing +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.validate import assert_shared_coords_same +from fhs_lib_database_interface.lib.constants import DimensionConstants, StageConstants +from fhs_lib_database_interface.lib.query import location +from fhs_lib_file_interface.lib.check_input import check_versions +from fhs_lib_file_interface.lib.version_metadata import FHSDirSpec, FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_model.lib.constants import ModelConstants +from fhs_lib_model.lib.stagewise_mrbrt import StagewiseMRBRT +from fhs_lib_year_range_manager.lib import YearRange +from mrtool import LinearCovModel +from 
stagemodel import OverallModel, StudyModel +from tiny_structured_logger.lib import fhs_logging + +from fhs_lib_genem.lib import predictive_validity as pv +from fhs_lib_genem.lib.constants import ( + EntityConstants, + FileSystemConstants, + LocationConstants, + ScenarioConstants, + TransformConstants, +) + +logger = fhs_logging.get_logger() + +PROCESSOR_OFFSET = 1e-8 +SCENARIO_QUANTILES = { + -1: dict(year_id=ScenarioConstants.DEFAULT_WORSE_QUANTILE), + 0: None, + 1: dict(year_id=ScenarioConstants.DEFAULT_BETTER_QUANTILE), + 2: None, +} +STUDY_ID_COLS = "location_id" + + +def stagewise_mrbrt_all_omegas( + entity: str, + stage: str, + versions: Versions, + model_name: str, + years: YearRange, + gbd_round_id: int, + draws: int, + mrbrt_cov_stage1: str, + mrbrt_cov_stage2: str, + omega_min: float, + omega_max: float, + step: float, + transform: str, + predictive_validity: bool, + intrinsic: bool, + subfolder: str, + national_only: bool, + use_scenario_quantiles: bool, + age_standardize: bool, + remove_zero_slices: bool, +) -> None: + r"""Fit and predict a given entity. + + Models accepted should be just MRBRT and two-stage MRBRT. + + Args: + entity (str): entity. + stage (str): Name of stage to be forecasted, e.g. ``incidence`` + versions (Versions): All relevant versions. e.g. FILEPATH + model_name (str): Name to save the model under. + years (YearRange): Forecasting time series. + gbd_round_id (int): The ID of GBD round associated with the past data + draws (int): number of draws. + mrbrt_cov_stage1 (str): The covariate name to be used in the MRBRT first stage. + mrbrt_cov_stage2 (str): The covariate name to be used in the MRBRT second stage. + omega_min (float): Minimum omega weight value to run MRBRT. + omega_max (float): Maximum omega weight value to run MRBRT. + step (float): Step value to be used for the range of omega values. + transform (str): transformation to perform on data before modeling. + predictive_validity (bool): whether to run predictive-validity. If true, save csv + files with rmse values. + intrinsic (bool): whether to use intrinsic or distal SEVs. If true, use intrinsic + SEVs. Default is distal SEVs. + subfolder (str): whether to use subfolder name. If provided, use subfolder name. + national_only (bool): Filter for only national locations. + use_scenario_quantiles (bool): Use scenario quantiles specification in model + age_standardize(bool) : whether to age standardize data before modeling. + remove_zero_slices (bool): If True, remove zero-slices along certain dimensions, when + pre-processing inputs, and add them back in to outputs. 
+ """ + logger.info("in stagewise_mrbrt_all_omegas()") + + covariates, node_models = _get_covariates_and_node_models( + years=years, + gbd_round_id=gbd_round_id, + mrbrt_cov_stage1=mrbrt_cov_stage1, + mrbrt_cov_stage2=mrbrt_cov_stage2, + ) + + # Open input data + for p_or_f in ["past", "future"]: + versions_to_check = {stage} | covariates.keys() if covariates else {stage} + check_versions(versions, p_or_f, versions_to_check) + + file_name = entity + if intrinsic: + file_name = entity + "_intrinsic" + + dep_version_metadata = versions.get(past_or_future="past", stage=stage) + + dep_file_spec = FHSFileSpec( + version_metadata=dep_version_metadata, + sub_path=(subfolder,) if subfolder else (), + filename=f"{file_name}.nc", + ) + + # Clean input data + dep_data = open_xr_scenario(dep_file_spec) + dep_data.name = stage + + # since we eventually fit on the mean, there's no point to sub-sample here + cleaned_data = processing.subset_to_reference( + data=dep_data, + draws=len(dep_data["draw"]) if "draw" in dep_data.dims else None, + year_ids=years.past_years, + ) + + cleaned_data_sub = _subset_ages_and_sexes(cleaned_data, gbd_round_id, age_standardize) + + cleaned_data_sub = _subset_locations( + cleaned_data_sub, + entity, + gbd_round_id, + national_only, + mrbrt_cov_stage1, + ) + + cleaned_data_sub = cleaned_data_sub.where(cleaned_data_sub >= 0).fillna(0) + + if intrinsic: + cleaned_data_sub = cleaned_data_sub.drop_vars("acause") + + processor = TransformConstants.TRANSFORMS[transform]( + years=years, + offset=PROCESSOR_OFFSET, + gbd_round_id=gbd_round_id, + age_standardize=age_standardize, + remove_zero_slices=remove_zero_slices, + intercept_shift="mean", + ) + + prepped_input_data = processor.pre_process(cleaned_data_sub) + + cov_data_list = _get_covariate_data( + dep_var_da=prepped_input_data, + covariates=covariates, + versions=versions, + years=years, + gbd_round_id=gbd_round_id, + draws=draws, + predictive_validity=predictive_validity, + ) + intersection_locations = set(prepped_input_data.location_id.values) + for cov in cov_data_list: + intersection_locations = intersection_locations.intersection( + set(cov.location_id.values) + ) + + prepped_input_data = prepped_input_data.sel(location_id=list(intersection_locations)) + + stripped_input_data = processing.strip_single_coord_dims(prepped_input_data) + + if predictive_validity: + all_omega_pv_results: List[xr.DataArray] = [] + + # Loop over omega values + + for omega in pv.get_omega_weights(omega_min, omega_max, step): + # Separate fits by sex if needed + if ( + entity in EntityConstants.NO_SEX_SPLIT_ENTITY + or "sex_id" not in stripped_input_data.dims + ): + forecast_data = _fit_and_predict_model( + past_data=stripped_input_data, + years=years, + draws=draws, + cov_data_list=cov_data_list, + node_models=node_models, + scenario_quantiles=SCENARIO_QUANTILES if use_scenario_quantiles else None, + gbd_round_id=gbd_round_id, + stage=stage, + entity=entity, + versions=versions, + sex_id=None, + omega=omega, + predictive_validity=predictive_validity, + ) + else: + stripped_input_male = stripped_input_data.sel(sex_id=[1]) + stripped_input_female = stripped_input_data.sel(sex_id=[2]) + forecast_male = _fit_and_predict_model( + past_data=stripped_input_male, + years=years, + draws=draws, + cov_data_list=cov_data_list, + node_models=node_models, + scenario_quantiles=SCENARIO_QUANTILES if use_scenario_quantiles else None, + gbd_round_id=gbd_round_id, + stage=stage, + entity=entity, + versions=versions, + sex_id=1, + omega=omega, + 
predictive_validity=predictive_validity, + ) + forecast_female = _fit_and_predict_model( + past_data=stripped_input_female, + years=years, + draws=draws, + cov_data_list=cov_data_list, + node_models=node_models, + scenario_quantiles=SCENARIO_QUANTILES if use_scenario_quantiles else None, + gbd_round_id=gbd_round_id, + stage=stage, + entity=entity, + versions=versions, + sex_id=2, + omega=omega, + predictive_validity=predictive_validity, + ) + forecast_data = xr.concat([forecast_male, forecast_female], dim="sex_id") + + # Expand forecast data to include point coords and single coord dims + # that were stripped off before forecasting. + expanded_output_data = processing.expand_single_coord_dims( + forecast_data, prepped_input_data + ) + + prepped_output_data = processor.post_process( + expanded_output_data, + cleaned_data_sub.sel(location_id=list(intersection_locations)), + ) + + # add to model weights dataframe (if predictive_validity = True), + # or save the forecasts (if predictive_validity = False) + if predictive_validity: + all_omega_pv_results.append( + pv.calculate_predictive_validity( + forecast=prepped_output_data, holdouts=dep_data, omega=omega + ) + ) + + else: + # post-processing ("making sure the scenarios do not have uncertainty") + + prepped_output_data = prepped_output_data.transpose( + *["draw", "sex_id", "location_id", "year_id", "scenario", "age_group_id"] + ) + + # Expand back to all national ids with zeros if subset to malaria locs + if entity in EntityConstants.MALARIA_ENTITIES: + location_set = location.get_location_set( + gbd_round_id=gbd_round_id, + include_aggregates=False, + national_only=True, + ) + location_ids = location_set[DimensionConstants.LOCATION_ID].tolist() + prepped_output_data = expand_dimensions( + prepped_output_data, location_id=location_ids, fill_value=0 + ) + + for scenario in [-1, 1]: + prepped_output_data.loc[{"scenario": scenario}] = prepped_output_data.sel( + scenario=scenario + ).mean("draw") + + output_version_metadata = versions.get(past_or_future="future", stage=stage) + + output_file_spec = FHSFileSpec( + version_metadata=output_version_metadata, + sub_path=( + FileSystemConstants.SUBMODEL_FOLDER, + model_name, + subfolder, + ), + filename=f"{file_name}_{omega}.nc", + ) + + save_xr_scenario( + xr_obj=prepped_output_data, + file_spec=output_file_spec, + metric="rate", + space="identity", + ) + + if predictive_validity: + pv_df = pv.finalize_pv_data(pv_list=all_omega_pv_results, entity=entity) + + pv.save_predictive_validity( + file_name=file_name, + gbd_round_id=gbd_round_id, + model_name=model_name, + pv_df=pv_df, + stage=stage, + subfolder=subfolder, + versions=versions, + ) + + +def _calculate_weighted_se( + df: pd.DataFrame, past_data: pd.DataFrame, omega: float +) -> pd.DataFrame: + """Calculate weighted standard error; used in fitting and predicting MRBRT model.""" + renormalized_years = df["year_id"] - df["year_id"].min() + 1 + df[str(past_data.name) + ModelConstants.STANDARD_ERROR_SUFFIX] = df[ + str(past_data.name) + ModelConstants.STANDARD_ERROR_SUFFIX + ] / np.sqrt(renormalized_years**omega) + return df + + +def _fit_and_predict_model( + past_data: xr.DataArray, + years: YearRange, + draws: int, + cov_data_list: Optional[List[xr.DataArray]], + node_models: List, + scenario_quantiles: Optional[Dict], + gbd_round_id: int, + stage: str, + entity: str, + versions: Versions, + sex_id: Optional[int], + omega: float, + predictive_validity: bool, +) -> xr.DataArray: + """Instantiate, fit, save coefficients, and predict for model.""" + 
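In the MRBRT submodels, omega enters through the standard error rather than through ARC weights: `_calculate_weighted_se` above divides each observation's SE by `sqrt(renormalized_year ** omega)`, so later years receive smaller SEs and therefore more influence on the fit. A small numeric illustration of that scaling, using a hypothetical SE column name in place of the pipeline's `<stage>` + standard-error suffix:

```python
import numpy as np
import pandas as pd

omega = 2.0
df = pd.DataFrame(
    {"year_id": [2019, 2020, 2021, 2022], "value_se": [0.1, 0.1, 0.1, 0.1]}
)

# Renormalize years to 1..n, then shrink the SE of later years.
renormalized_years = df["year_id"] - df["year_id"].min() + 1  # 1, 2, 3, 4
df["value_se"] = df["value_se"] / np.sqrt(renormalized_years**omega)

print(df["value_se"].round(3).tolist())  # [0.1, 0.05, 0.033, 0.025]
```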
+ def df_func(df: pd.DataFrame) -> pd.DataFrame: + return _calculate_weighted_se(df, past_data, omega) + + model_instance = StagewiseMRBRT( + past_data=past_data, + years=years, + draws=draws, + covariate_data=cov_data_list, + node_models=node_models, + study_id_cols=STUDY_ID_COLS, + scenario_quantiles=scenario_quantiles, + df_func=df_func, + gbd_round_id=gbd_round_id, + ) + + model_instance.fit() + + save_entity = ( + f"{entity}_sex_id_{sex_id}_omega_{omega}" if sex_id else f"{entity}_omega_{omega}" + ) + + model_instance.save_coefficients( + output_dir=FHSDirSpec(versions.get("future", stage)), entity=save_entity + ) + + if predictive_validity: + forecast_data = model_instance.predict() + else: + forecast_data = model_instance.predict() + forecast_data = limit_scenario_quantiles(forecast_data) + + return forecast_data + + +def _get_covariate_data( + dep_var_da: xr.DataArray, + covariates: Dict[str, processing.BaseProcessor], + versions: Versions, + years: YearRange, + gbd_round_id: int, + draws: int, + predictive_validity: bool, +) -> List[xr.DataArray]: + """Return a list of prepped dataarray for all of the covariates.""" + cov_data_list = [] + for cov_stage, cov_processor in covariates.items(): + cov_file = cov_stage + filename = f"{cov_file}.nc" + + cov_past_version_metadata = versions.get(past_or_future="past", stage=cov_stage) + + cov_past_file_spec = FHSFileSpec( + version_metadata=cov_past_version_metadata, filename=filename + ) + + if predictive_validity: + cov_future_file_spec = cov_past_file_spec + else: + cov_future_version_metadata = versions.get( + past_or_future="future", stage=cov_stage + ) + + cov_future_file_spec = FHSFileSpec( + version_metadata=cov_future_version_metadata, filename=filename + ) + + cov_past_data = open_xr_scenario(cov_past_file_spec).sel(year_id=years.past_years) + cov_future_data = ( + open_xr_scenario(cov_future_file_spec) + .sel(year_id=years.forecast_years) + .rename(cov_stage) + ) + + intersection_locations = set(dep_var_da.location_id.values).intersection( + set(cov_past_data.location_id.values) + ) + intersection_locations = intersection_locations.intersection( + set(cov_future_data.location_id.values) + ) + intersection_locations = list(intersection_locations) + + cov_past_data = cov_past_data.sel(location_id=intersection_locations) + + cov_future_data = cov_future_data.sel(location_id=intersection_locations) + + dep_var_da = dep_var_da.sel(location_id=intersection_locations) + + if DimensionConstants.STATISTIC in cov_past_data.dims: + cov_past_data = cov_past_data.sel(statistic=DimensionConstants.MEAN, drop=True) + if DimensionConstants.STATISTIC in cov_future_data.dims: + cov_future_data = cov_future_data.sel(statistic=DimensionConstants.MEAN, drop=True) + + prepped_cov_data = processing.clean_covariate_data( + past=cov_past_data, + forecast=cov_future_data, + dep_var=dep_var_da, + years=years, + draws=draws, + gbd_round_id=gbd_round_id, + ) + if DimensionConstants.SCENARIO not in prepped_cov_data.dims: + prepped_cov_data = prepped_cov_data.expand_dims( + scenario=[ScenarioConstants.REFERENCE_SCENARIO_COORD] + ) + + transformed_cov_data = cov_processor.pre_process(prepped_cov_data) + + try: + assert_shared_coords_same( + transformed_cov_data, dep_var_da.sel(year_id=years.past_end, drop=True) + ) + except IndexError as ce: + raise IndexError(f"After pre-processing {cov_stage}, " + str(ce)) + + cov_data_list.append(transformed_cov_data) + + return cov_data_list + + +def _get_covariates_and_node_models( + years: YearRange, + gbd_round_id: int, + 
mrbrt_cov_stage1: str, + mrbrt_cov_stage2: str, +) -> tuple[dict[str, processing.BaseProcessor], list]: + """Set up covariate list and node_models based on the required covariates.""" + if mrbrt_cov_stage1 == EntityConstants.ACT_ITN_COVARIATE: + covariates = { + "malaria_act": processing.NoTransformProcessor( + years=years, gbd_round_id=gbd_round_id + ), + "malaria_itn": processing.NoTransformProcessor( + years=years, gbd_round_id=gbd_round_id + ), + } + stage_1_cov_models = [ + LinearCovModel("intercept", use_re=True), + LinearCovModel("malaria_itn"), + LinearCovModel("malaria_act"), + ] + else: + covariates = { + mrbrt_cov_stage1: processing.NoTransformProcessor( + years=years, gbd_round_id=gbd_round_id, no_mean=True + ) + } + stage_1_cov_models = [ + LinearCovModel("intercept"), + LinearCovModel( + mrbrt_cov_stage1, + use_spline=True, + spline_knots=np.linspace(0.0, 1.0, 5), + spline_l_linear=True, + spline_r_linear=True, + ), + ] + + node_models = [ + OverallModel(cov_models=stage_1_cov_models), + StudyModel( + cov_models=[ + LinearCovModel(alt_cov="intercept"), + LinearCovModel(alt_cov=mrbrt_cov_stage2), + ] + ), + ] + return covariates, node_models + + +def _subset_ages_and_sexes( + da: xr.DataArray, gbd_round_id: int, age_standardize: bool +) -> xr.DataArray: + """Subset the da to the correct ages and sexes for the given stage.""" + logger.info("in _subset_ages_and_sexes()") + + if "sex_id" not in da.dims or list(da.sex_id.values) != [3]: + da = filter.make_most_detailed_sex(da) + if age_standardize: + da = filter.make_most_detailed_age(data=da, gbd_round_id=gbd_round_id) + + return da + + +def _subset_locations( + da: xr.DataArray, + entity: str, + gbd_round_id: int, + national_only: bool, + mrbrt_cov_stage1: str, +) -> xr.DataArray: + """Subset locations to appropriate location_ids.""" + da = filter.make_most_detailed_location( + data=da, + gbd_round_id=gbd_round_id, + national_only=national_only, + ) + if entity == EntityConstants.MALARIA: + # Use locations with malaria act and itn data + if mrbrt_cov_stage1 == EntityConstants.ACT_ITN_COVARIATE: + da = da.sel(location_id=LocationConstants.MALARIA_ACT_ITN_LOCS) + elif mrbrt_cov_stage1 == StageConstants.SDI: + da = da.sel(location_id=LocationConstants.NON_MALARIA_ACT_ITN_LOCS) + else: + raise IndexError("Malaria is not forecasted with these covariates!") + + return da + + +def limit_scenario_quantiles(da: xr.DataArray) -> xr.DataArray: + """Restrict scenarios so reference does not go outside of better/worse.""" + worse = da.sel(scenario=ScenarioConstants.WORSE_SCENARIO_COORD, drop=True) + better = da.sel(scenario=ScenarioConstants.BETTER_SCENARIO_COORD, drop=True) + ref = da.sel(scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD, drop=True) + + limited_worse = expand_dimensions( + worse.where(worse > ref).fillna(ref), + scenario=[ScenarioConstants.WORSE_SCENARIO_COORD], + ) + limited_better = expand_dimensions( + better.where(better < ref).fillna(ref), + scenario=[ScenarioConstants.BETTER_SCENARIO_COORD], + ) + + limited_da = xr.concat( + [ + limited_worse, + expand_dimensions(ref, scenario=[ScenarioConstants.REFERENCE_SCENARIO_COORD]), + limited_better, + ], + dim="scenario", + ) + + return limited_da diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/paf/compute_paf.py b/gbd_2021/disease_burden_forecast_code/risk_factors/paf/compute_paf.py new file mode 100644 index 0000000..716ace7 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/paf/compute_paf.py @@ -0,0 +1,869 @@ +r"""Compute and export all 
the cause-risk-pair PAFs for given acause. + +That involves: + +1) Finding all risks associated with said acause +2) Pulling SEVs and cause-risk-specific RRs from inputs +3) Compute PAF as ``paf = 1 - 1 / (sev * (rr - 1) + 1)`` + +About the two inputs: + +1) SEV. There's a separate directory for non-vaccine SEV and for vaccine SEV. +2) RR. Like SEV, this is further divided into a non-vaccine and vaccine. + +Example call (pulling gbd paf from get_draws): + +.. code:: bash + + python compute_paf.py --acause cvd_ihd --rei metab_bmi --version test \ + --directly-modeled-paf 20190419_dm_pafs \ + --sev 20180412_trunc_widerbounds --rrmax 20180407_paf1_update \ + --vaccine-sev 20180319_new_sdi --vaccine-rrmax 20171205_refresh \ + --gbd-round-id 4 --years 1990:2017:2040 --draws 100 + +If there's already a cleaned version of gbd cause-risk PAFs stored, +one may access it via the --gbd-paf-version flag to bypass using get_draws(): + +.. code:: bash + + python compute_paf.py --acause cvd_ihd --rei metab_bmi --version test \ + --directly-modeled-paf 20190419_dm_pafs \ + --sev 20180412_trunc_widerbounds --rrmax 20180407_paf1_update \ + --vaccine-sev 20180319_new_sdi --vaccine-rrmax 20171205_refresh \ + --gbd-paf-version 20180521_2016_gbd \ + --gbd-round-id 4 --years 1990:2017:2040 --draws 100 + +If there's already a cleaned version of gbd cause-risk PAFs stored, +one may access it via the --gbd-paf-version flag to bypass using get_draws() + +Note that the 'draws' input arg is not only required, it also entails +up/down-sampling if any of the input files have number of draws not equal +to 'draws'. +""" + +import gc +from pathlib import Path +from typing import List, Optional, Tuple + +import xarray as xr +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_database_interface.lib.query import cause +from fhs_lib_file_interface.lib.query import mediation +from fhs_lib_file_interface.lib.query import rrmax as rrmax_functions +from fhs_lib_file_interface.lib.symlink_file import symlink_file_to_directory +from fhs_lib_file_interface.lib.version_metadata import ( + FHSDirSpec, + FHSFileSpec, + VersionMetadata, +) +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario +from fhs_lib_year_range_manager.lib.year_range import YearRange +from scipy.special import expit, logit +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_pipeline_scalars.lib.constants import PAFConstants +from fhs_pipeline_scalars.lib.forecasting_db import ( + get_most_detailed_acause_related_risks, + is_maybe_negative_paf, +) +from fhs_pipeline_scalars.lib.utils import ( + conditionally_triggered_transformations, + product_of_mediation, + save_paf, +) + +logger = get_logger() + + +def _read_sev( + rei: str, + sev: str, + past_sev: str, + vaccine_sev: str, + past_vaccine_sev: str, + gbd_round_id: int, + years: YearRange, + draws: int, +) -> xr.DataArray: + """Read in SEV [from the vaccine subdir if `sev` is a vaccine SEV]. + + Args: + rei (str): risk, could also be vaccine intervention. + sev (str): upstrem sev version. + past_sev (str): past (GBD) sev version. + vaccine_sev (str): input vaccine sev version. + gbd_round_id (int): gbd round id + years (YearRange): [past_start, forecast_start, forecast_end] years. + draws (int): number of draws for output file. This means input files + will be up/down-sampled to meet this criterion. + + Returns: + (xr.DataArray): SEV in dataarray form. 
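Two details of the ``_read_sev`` body that follows are easy to miss: vaccine coverage is flipped into an anti-risk exposure (``1 - coverage``) before use, and past and forecast years are concatenated along ``year_id`` after both are resampled to a common draw count. A minimal ``xarray`` sketch of just those two steps, with made-up coverage values:

.. code:: python

    import numpy as np
    import xarray as xr

    past_years, future_years = [2019, 2020, 2021], [2022, 2023]

    # Toy vaccine coverage proportions (illustrative only).
    past_cov = xr.DataArray(np.full(3, 0.8), dims="year_id", coords={"year_id": past_years})
    future_cov = xr.DataArray(np.full(2, 0.9), dims="year_id", coords={"year_id": future_years})

    # Vaccines are treated as an anti-risk: the exposure is the uncovered fraction.
    past_sev, future_sev = 1.0 - past_cov, 1.0 - future_cov

    sev = xr.concat([past_sev, future_sev], dim="year_id")
    print(sev.values)  # approximately [0.2 0.2 0.2 0.1 0.1]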
+ """ + filename = f"{rei}.nc" + + # Determine the effective versions and stage + if rei in PAFConstants.VACCINE_RISKS: + effective_future_version = vaccine_sev + effective_past_version = past_vaccine_sev + effective_stage = "vaccine" + else: + effective_future_version = sev + effective_past_version = past_sev + effective_stage = "sev" + + # Create the past and future file specifications + future_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, + epoch="future", + stage=effective_stage, + version=effective_future_version, + ) + future_file_spec = FHSFileSpec(version_metadata=future_version_metadata, filename=filename) + + past_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, + epoch="past", + stage=effective_stage, + version=effective_past_version, + ) + past_file_spec = FHSFileSpec(version_metadata=past_version_metadata, filename=filename) + + # Load the past and future data + future = open_xr_scenario(future_file_spec).sel(year_id=years.forecast_years) + past = open_xr_scenario(past_file_spec).sel(year_id=years.past_years) + + # Vaccines are treated as an anti-risk, so we subtract it from 1.0 + if rei in PAFConstants.VACCINE_RISKS: + future = 1.0 - future + past = 1.0 - past + + # It is expected that the past and the future are exlusive only in + # year_id and scenario dims. Hence we can find the overlap in loc-age-sex. + past = past.sel( + location_id=future["location_id"], + age_group_id=future["age_group_id"], + sex_id=future["sex_id"], + ) + + past = resample(past, draws) + future = resample(future, draws) + + if "scenario" in past.dims: + past = past.sel(scenario=0).drop_vars("scenario") + + out = xr.concat([past, future], dim="year_id", coords="minimal", join="inner") + + del past, future + gc.collect() + + out = out.where(out.year_id <= years.forecast_end, drop=True) + out = conditionally_triggered_transformations(out, gbd_round_id, years) + + if rei in PAFConstants.VACCINE_RISKS: + if "scenario" in out.dims: + non_reference_scenarios = [s for s in out.scenario.values if s != 0] + for scenario in non_reference_scenarios: + out.loc[dict(scenario=scenario, year_id=years.past_years)] = out.sel( + scenario=0, year_id=years.past_years + ) + + return out + + +def _read_and_process_rrmax( + acause: str, + rei: str, + rrmax: str, + vaccine_rrmax: str, + gbd_round_id: int, + years: YearRange, + draws: int, +) -> xr.DataArray: + """Identify correct parameters for `read_rrmax` and postprocess returned xarray. + + Args: + acause (str): analytical cause. + rei (str): risk, could also be vaccine intervention. + rrmax (str): input rrmax version + vaccine_rrmax (str): input vaccine rrmax version. + gbd_round_id (int): gbd round id. + years (YearRange): [past_start, forecast_start, forecast_end] years. + draws (int): number of draws for output file. This means input files + will be up/down-sampled to meet this criterion. + + Returns: + xr.DataArray: RRmax in dataarray form. + """ + if acause == "tb": # NOTE: RRMax from arbitrary child cause only for TB. 
+ cause_id = PAFConstants.TB_OTHER_CAUSE_ID + else: + cause_id = cause.get_cause_id(acause=acause) + + if rei in PAFConstants.VACCINE_RISKS: + version = vaccine_rrmax + else: + version = rrmax + + out = rrmax_functions.read_rrmax( + acause=acause, + cause_id=cause_id, + rei=rei, + gbd_round_id=gbd_round_id, + version=version, + draws=draws, + ) + out = out.where(out != 0, drop=True) + + out = conditionally_triggered_transformations(out, gbd_round_id, years) + + if rei in PAFConstants.VACCINE_RISKS: + # The values stored in vaccine data files are actually not RR, + # but rather + # r = Incidence[infection | vax] / Incidence[infection | no vax], + # as "percent reduction of diseased cases if vaccinated", + # and should be r < 1. + # We compute the actual RR as 1/r. + # Any value > 1 should be capped. + out = out.where(out <= PAFConstants.PAF_UPPER_BOUND).fillna( + PAFConstants.PAF_UPPER_BOUND + ) + out = 1.0 / out # as mentioned earlier, we compute RR as 1/r. + + if "draw" in out.dims and len(out["draw"]) != draws: + out = resample(out, draws) + + return out + + +def _get_gbd_paf( + acause: str, + rei: str, + gbd_paf_version: str, + gbd_round_id: int, +) -> Optional[xr.DataArray]: + """Load PAF from the given gbd_paf_version. + + Certain acauses may be given that don't exist in the GBD, and for these we will actually + load the *parent*'s PAF data. + + Args: + acause (str): analytical cause. + rei (str): risk, could also be vaccine intervention. + gbd_paf_version (str): the version name to load past (GBD) PAF data from. + gbd_round_id (int): gbd round id + location_ids (list[int]): locations to get pafs from. + draws (int): number of draws for output file. This means input files + will be up/down-sampled to meet this criterion. + + Returns: + Optional[xr.DataArray]: DataArray with complete demographic indices, or None if the rei + is in `PAFConstants.VACCINE_RISKS`. + """ + acause_rei_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, epoch="past", stage="paf", version=gbd_paf_version + ) + acause_rei_file_spec = FHSFileSpec( + version_metadata=acause_rei_version_metadata, + sub_path=("risk_acause_specific",), + filename=f"{acause}_{rei}.nc", + ) + + if rei in PAFConstants.VACCINE_RISKS: + return None + + if acause_rei_file_spec.data_path().exists(): + return open_xr_scenario(acause_rei_file_spec) + + # "etiology causes" must be pulled from the parent cause. + elif acause in cause.non_gbd_causes(): + parent_acause = cause.get_parent(acause=acause, gbd_round_id=gbd_round_id) + parent_acause_rei_file_spec = FHSFileSpec( + version_metadata=acause_rei_version_metadata, + sub_path=("risk_acause_specific",), + filename=f"{parent_acause}_{rei}.nc", + ) + + if parent_acause_rei_file_spec.data_path().exists(): + return open_xr_scenario(parent_acause_rei_file_spec) + + else: # the etiology's parent file does not exist + raise OSError(f"{parent_acause_rei_file_spec.data_path()} does not exist.") + + else: # the file does not exist + raise OSError(f"{acause_rei_file_spec.data_path()} does not exist.") + + +def _data_cleaning_for_paf( + paf: xr.DataArray, maybe_negative_paf: Optional[bool] = False +) -> xr.DataArray: + """Encode data cleaning customized for PAF. + + 1.) set non-finite values (nan, -inf, inf) to 0 + 2.) set > UPPER_BOUND to UPPER_BOUND + 3.) 
set < LOWER_BOUND to LOWER_BOUND + + Non-finite PAF values likely come from outer-join mismatches between + sev and rr, and we set those to 0 + for (2) and (3), per discussion with central comp, PAF values over + boundaries are simply capped, not resampled. + + Args: + paf (xr.DataArray): dataarray of PAF values + maybe_negative_paf (Optional[bool]): ``True`` is PAF is allowed to be + negative. Defaults to ``False``. + + Returns: + (xr.DataArray): + cleaned dataarray. + """ + if maybe_negative_paf: + lower_bound = PAFConstants.PAF_LOWER_BOUND + else: + lower_bound = 1 - PAFConstants.PAF_UPPER_BOUND + + return paf.fillna(lower_bound).clip(min=lower_bound, max=PAFConstants.PAF_UPPER_BOUND) + + +def _compute_correction_factor( + fhs_paf: xr.DataArray, gbd_paf: xr.DataArray, maybe_negative_paf: Optional[bool] = False +) -> xr.DataArray: + r"""Bias-correct Forecasted PAF by GBD PAF. + + This is essentially + an "intercept-shift", and it happens at the last year of past (gbd round), + and hence the input args should be single-year arrays. + + Even though PAF values should logically be in the closed interval + :math:`[-1, 1]`, we expect both ``fhs_paf`` and ``gbd_paf`` to be in the + open interval :math:`(-1, 1)` due to upstream data cleaning. Furthermore, + most cause-risk pairs are *not* protective (i.e. non-negative), so are + actually expected to be in the open interval :math:`(0, 1)`. + + This method computes correction factor in logit space by transforming + the PAF values via :math:`x_{\text{corrected}} = (1 + x) / 2`, and with the + correction factor being the difference between fhs and gbd in logit space: + + .. math:: + + \text{correction-factor} = + \text{logit}(\frac{1 + \mbox{PAF}_{\text{gbd}}}{2}) + - \text{logit}(\frac{1 + \mbox{PAF}_{\text{fhs}}}{2}) + + For cause-risk pairs that *cannot* have protective (i.e. negative PAFs), + the correction factor equation is: + + .. math:: + + \text{correction-factor} = + \text{logit}(\mbox{PAF}_{\text{gbd}}) + - \text{logit}(\mbox{PAF}_{\text{fhs}}) + + This correction factor will later be added to the forecasted PAF values + in logit space, prior to back-transformation. + + By default, non-finite correction factor values are reset to 0. These + non-finite values could come from mismatched cells from outer-join + arithmetic, commonly found along the ``age_group_id`` dimension between + GBD and FHS (np.nan) + + Args: + fhs_paf (xr.DataArray): forecasted PAF at gbd round year. + gbd_paf (xr.DataArray): gbd PAF. Only contains the gbd round year. + maybe_negative_paf (Optional[bool]): ``True`` is PAF is allowed to be + negative. Defaults to ``False``. + + Returns: + (xr.DataArray): + correction factor. + """ + # first make sure the input args are year-agnostic + if ("year_id" in fhs_paf.coords or "year_id" in fhs_paf.dims) and fhs_paf[ + "year_id" + ].size > 1: + raise ValueError("fhs_paf has year dim larger than size=1") + + if ("year_id" in gbd_paf.coords or "year_id" in gbd_paf.dims) and gbd_paf[ + "year_id" + ].size > 1: + raise ValueError("gbd_paf has year dim larger than size=1") + + with xr.set_options(arithmetic_join="outer"): + if maybe_negative_paf: + correction_factor = logit((1 + gbd_paf) / 2) - logit((1 + fhs_paf) / 2) + else: + correction_factor = logit(gbd_paf) - logit(fhs_paf) + + # the above outer-join could result in nulls. 
+ # by default, null correction factor values are reset to 0 + correction_factor = correction_factor.fillna(0) + + # We required the inputs to have just one year_id; now we remove that dim/coord from the + # result. + if "year_id" in correction_factor.dims: + correction_factor = correction_factor.squeeze("year_id") + if "year_id" in correction_factor.coords: + correction_factor = correction_factor.drop_vars("year_id") + + return correction_factor + + +def _apply_paf_correction( + fhs_paf: xr.DataArray, cf: xr.DataArray, maybe_negative_paf: bool = False +) -> xr.DataArray: + r"""Correct forecasted PAF in logit space and back-transform. + + .. math:: + + \mbox{PAF}_{\text{corrected}} = 2 * \text{expit}(\text{logit} + (\frac{1 + \mbox{PAF}_{\text{FHS}}}{2} + \text{correction-factor} + ) - 1 + + for cause-risk pairs that *can* be protective. If cause-risk pairs are not + allowed to protective (the majority of them), then the equation for the + corrected PAF is + + .. math:: + + \mbox{PAF}_{\text{corrected}} = + \mbox{PAF}_{\text{FHS}} + \text{correction-factor} + + Because a logit function's argument x is :math:`[0, 1]` and a protective + PAF can be in the range :math:`[-1, 1]`, a natural mapping from PAF space + to logit space is :math:`x_{\text{corrected}} = (1 + x) / 2`. The + back-transform to PAF space is hence :math:`2 * expit( logit(x) ) + 1`. + + The correction factor ``cf`` is also computed within the same logit space. + Once the correction is made in logit space, the resultant quantity + is mapped back to PAF-space via the aforementioned back-transform. + + Args: + fhs_paf (xr.DataArray): forecasted PAF. Has many years along year_id. + cf (xr.DataArray): correction factor, should not have year_id dim. + maybe_negative_paf (bool): ``True`` is PAF is allowed to be + negative. Defaults to ``False``. + + Returns: + (xr.DataArray): + correct PAF. + """ + if maybe_negative_paf: + corrected_paf = 2 * expit(logit((1 + fhs_paf) / 2) + cf) - 1 + else: + corrected_paf = expit(logit(fhs_paf) + cf) + + return corrected_paf + + +def _get_paf( + acause: str, + rei: str, + years: YearRange, + gbd_round_id: int, + draws: int, + sev: str, + past_sev: str, + rrmax: str, + vaccine_sev: str, + past_vaccine_sev: str, + vaccine_rrmax: str, +) -> Tuple[xr.DataArray, List[int]]: + """Calculate a PAF from 1) the SEV for the risk, and 2) the RRMax for the cause-risk pair. + + In particular, we load SEVs from the specified future versions ``sev``, ``vaccine_sev``, + plus past versions ``past_sev``, ``past_vaccine_sev``. And we load the RRMax data from the + ``rrmax``, ``vaccine_rrmax`` versions. 
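The body of ``_get_paf`` below applies the identity ``paf = 1 - 1 / (sev * (rrmax - 1) + 1)``. Its behaviour at the boundaries is easy to verify with scalars: zero exposure gives a PAF of 0, and full exposure gives ``1 - 1/rrmax``. A minimal numeric sketch with toy values:

.. code:: python

    import numpy as np

    sev = np.array([0.0, 0.2, 0.5, 1.0])  # toy summary exposure values
    rrmax = 2.5                           # toy RRmax for one cause-risk pair

    paf = 1 - 1 / (sev * (rrmax - 1) + 1)
    print(paf.round(3))  # approximately [0. 0.231 0.429 0.6]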
+ """ + sev_da = _read_sev( + rei=rei, + sev=sev, + past_sev=past_sev, + vaccine_sev=vaccine_sev, + past_vaccine_sev=past_vaccine_sev, + gbd_round_id=gbd_round_id, + years=years, + draws=draws, + ) + + rrmax_da = _read_and_process_rrmax( + acause=acause, + rei=rei, + rrmax=rrmax, + vaccine_rrmax=vaccine_rrmax, + gbd_round_id=gbd_round_id, + years=years, + draws=draws, + ) + + # Make 0 values into 1s + defaulted_values = rrmax_da.where(rrmax_da.values != 0) + rrmax_da = defaulted_values.fillna(1) + + # estimated cause-risk-specific paf + with xr.set_options(arithmetic_join="inner"): + paf = 1 - 1 / (sev_da * (rrmax_da - 1) + 1) + + location_ids = sev_da["location_id"].values.tolist() + + del sev_da, rrmax_da + gc.collect() + + return paf, location_ids + + +def compute_paf( + acause: str, + rei: str, + version: str, + years: YearRange, + gbd_round_id: int, + draws: int, + sev: str, + past_sev: str, + rrmax: str, + vaccine_sev: str, + past_vaccine_sev: str, + vaccine_rrmax: str, + gbd_paf_version: str, + save_past_data: bool, +) -> None: + r"""Compute and export PAF for the given acause-risk pair. + + Said PAF is exported to FILEPATH. + + Args: + acause (str): analytical cause. + rei (str): rei, or commonly called risk. + version (str): version to export to. + years (YearRange): [past_start, forecast_start, forecast_end] years. + gbd_round_id (int): gbd round id. + draws (int): number of draws for output file. This means input files + will be up/down-sampled to meet this criterion. + sev (str): input future sev version. + past_sev (str): input past sev version. + rrmax (str): input rrmax version. + vaccine_sev (str): input vaccine sev version. + past_vaccine_sev (str): input past vaccine sev version. + vaccine_rrmax (str): input vaccine rrmax version. + gbd_paf_version (str): gbd_paf version to read from, + if not downloading from get_draws(). + save_past_data (bool): if true, + save files for past data in FILEPATH. + """ + paf, location_ids = _get_paf( + acause, + rei, + years, + gbd_round_id, + draws, + sev, + past_sev, + rrmax, + vaccine_sev, + past_vaccine_sev, + vaccine_rrmax, + ) + + maybe_negative_paf = is_maybe_negative_paf(acause, rei, gbd_round_id) + + # Forecasted PAFs are cleaned first before further processing + paf = _data_cleaning_for_paf(paf, maybe_negative_paf) + + # now ping get_draws for gbd paf values + logger.info(f"Got estimated paf for {acause}_{rei}. 
Pulling gbd paf...") + + paf, correction_factor = _correct_paf( + gbd_round_id, + acause, + rei, + draws, + paf, + maybe_negative_paf, + gbd_paf_version, + location_ids, + ) + + paf_unmediated = _apply_mediation(acause, rei, gbd_round_id, paf) + + # we need to save the results separately in "past" and "future" + if save_past_data: + past_future_dict = {"past": years.past_years, "future": years.forecast_years} + else: + past_future_dict = {"future": years.forecast_years} + + for p_or_f, yrs in past_future_dict.items(): + save_paf( + paf=paf.sel(year_id=yrs), + gbd_round_id=gbd_round_id, + past_or_future=p_or_f, + version=version, + acause=acause, + cluster_risk=rei, + sev=sev, + rrmax=rrmax, + vaccine_sev=vaccine_sev, + vaccine_rrmax=vaccine_rrmax, + gbd_paf_version=gbd_paf_version, + ) + + if paf_unmediated is not None: + save_paf( + paf=paf_unmediated.sel(year_id=yrs), + gbd_round_id=gbd_round_id, + past_or_future=p_or_f, + version=version, + acause=acause, + cluster_risk=rei, + file_suffix="_unmediated", + sev=sev, + rrmax=rrmax, + vaccine_sev=vaccine_sev, + vaccine_rrmax=vaccine_rrmax, + gbd_paf_version=gbd_paf_version, + ) + + # now saving cause-risk-specific correction factor + if p_or_f == "past": + save_paf( + paf=correction_factor, + gbd_round_id=gbd_round_id, + past_or_future=p_or_f, + version=version, + acause=acause, + cluster_risk=rei, + file_suffix="_cf", + space="logit", + sev=sev, + rrmax=rrmax, + vaccine_sev=vaccine_sev, + vaccine_rrmax=vaccine_rrmax, + gbd_paf_version=gbd_paf_version, + ) + + del paf, correction_factor + gc.collect() + + +def _apply_mediation( + acause: str, rei: str, gbd_round_id: int, paf: xr.DataArray +) -> xr.DataArray: + """Apply to paf the mediation values for the given acause, rei.""" + # Now determine if we need to save unmediated PAF as well + med = mediation.get_mediation_matrix(gbd_round_id) + + if acause in med["acause"].values and rei in med["rei"].values: + mediator_matrix = med.sel(acause=acause, rei=rei) + if (mediator_matrix > 0).any(): # if there is ANY mediation + all_risks = get_most_detailed_acause_related_risks(acause, gbd_round_id) + # this determines how to compute the unmediated RR + mediation_prod = product_of_mediation(acause, rei, all_risks, gbd_round_id) + # reverse engineer: sev * (rrmax - 1) = 1 / (1 - paf) - 1 + # rrmax^U - 1 = (rrmax - 1) * mediation_prod + # so sev * (rrmax^U - 1) = (1 / (1 - paf) - 1) * mediation_prod + sev_rrmax_1_u = (1 / (1 - paf) - 1) * mediation_prod + paf_unmediated = 1 - 1 / (sev_rrmax_1_u + 1) + else: + paf_unmediated = None + else: + paf_unmediated = None + return paf_unmediated + + +def _correct_paf( + gbd_round_id: int, + acause: str, + rei: str, + draws: int, + paf: xr.DataArray, + maybe_negative_paf: bool, + gbd_paf_version: str, + location_ids: List[int], +) -> Tuple[xr.DataArray, xr.DataArray]: + """Determine and apply correction to ``paf`` so that it matches data in gbd_paf_version. + + Essentially, the correction is the difference, in logit space, betweeen the paf and values + in the last-past year, and it is simply added in logit space. The ``paf`` data is taken in + identity space, so logit is applied and reversed. + """ + gbd_paf = _get_gbd_paf(acause, rei, gbd_paf_version, gbd_round_id) + if gbd_paf is not None: + gbd_paf = gbd_paf.sel(location_id=location_ids) + gbd_paf = resample(gbd_paf, draws) + + logger.info(f"Pulled gbd paf for {acause}_{rei}. 
Computing adjusted paf...") + + # compute correction factor and perform adjustment + + # First make sure there's no COMPLETE mismatch between paf and gbd_paf. + # If so, an error should be raised + paf.load() + gbd_paf.load() # need to force load() because dask is lazy + if (paf - gbd_paf).size == 0: # normal arithmetic is inner-join + error_message = ( + "Complete mismatch between computed and GBD in " + f"{acause}-{rei} PAF. Are you sure you used the correct " + "version of GBD PAF?" + ) + logger.error(error_message) + raise ValueError(error_message) + + gbd_paf = _data_cleaning_for_paf(gbd_paf, maybe_negative_paf) + gbd_year = max(gbd_paf.year_id.values) + + correction_factor = _compute_correction_factor( + paf.sel(year_id=gbd_year), gbd_paf.sel(year_id=gbd_year), maybe_negative_paf + ) + + del gbd_paf + gc.collect() + + paf = _apply_paf_correction(paf, correction_factor, maybe_negative_paf) + + logger.info(f"Adjusted paf for {acause}_{rei}. Now saving...") + else: # correction factor is 0, and we leave paf as is + correction_factor = xr.zeros_like(paf) + logger.info(f"paf for {acause}_{rei} not adjusted because gbd_paf is None") + + return paf, correction_factor + + +def symlink_paf_file( + acause: str, + rei: str, + gbd_paf_version: str, + calculated_paf_version: str, + pre_calculated_paf: str, + gbd_round_id: int, + paf_type: str, + save_past_data: bool, +) -> None: + """Create symlink to files with directly-modeled PAF data. + + Creates symlinks of past and future directly-modeled PAF data files to the + directory with PAFs calculated from SEVs and RRmaxes. + + Args: + acause (str): + Indicates the cause of the cause-risk pair + rei (str): + Indicates the risk of the cause-risk pair + calculated_paf_version (str): + Output version of this script where directly-modeled PAFs are + symlinked, and calculated PAFs are saved. + pre_calculated_paf (str): + The version of PAFs with the pre-calculated PAF to be symlinked + resides. Either a directly modeled PAF version or temperature PAF + version. + gbd_round_id (int): + The numeric ID representing the GBD round. + paf_type (str): + What type of PAF is this? Right now can be "directly_modeled" or + "custom_forecast". + save_past_data (bool): if true, + save files for past data in FILEPATH + + Raises: + RuntimeError: + If symlink sub-process fails. + """ + effective_pre_calc_versions = {"future": pre_calculated_paf} + if save_past_data: + effective_pre_calc_versions["past"] = gbd_paf_version + + for epoch, effective_pre_calc_version in effective_pre_calc_versions.items(): + # Currently, we only get custom forecasts for temperature PAFs. + # As of GBD2019, they are not saved in a sub-directory. 
+ sub_dir_str = ( + "" + if paf_type == "custom_forecast" and epoch == "future" + else "risk_acause_specific" + ) + + # Define the calculated-PAF version and directory specifications + calculated_paf_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, epoch=epoch, stage="paf", version=calculated_paf_version + ) + calculated_paf_dir_spec = FHSDirSpec( + version_metadata=calculated_paf_version_metadata, + sub_path=("risk_acause_specific",), + ) + + # Define the pre-calculated-PAF version, directory and file specifications + pre_calculated_paf_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, + epoch=epoch, + stage="paf", + version=effective_pre_calc_version, + ) + pre_calculated_paf_dir_spec = FHSDirSpec( + version_metadata=pre_calculated_paf_version_metadata, sub_path=(sub_dir_str,) + ) + pre_calculated_paf_file_spec = FHSFileSpec.from_dirspec( + dir=pre_calculated_paf_dir_spec, filename=f"{acause}_{rei}.nc" + ) + + # Symlink the pre-calculated-PAF data into the calculated-PAF directory, if the + # pre-calculated-PAF data already exists + if pre_calculated_paf_file_spec.data_path().exists(): + _attempt_symlink( + source=pre_calculated_paf_file_spec.data_path(), + target=calculated_paf_dir_spec.data_path(), + target_is_directory=True, + ) + + # Otherwise, symlink the user-specifed ``acause``'s **parent** cause from the + # pre-calculated-PAF data into the calculated-PAF directory **as** the child cause + else: + # Check if cause is an etiology in GBD. + if acause not in cause.non_gbd_causes(): + raise FileNotFoundError( + f"{pre_calculated_paf_file_spec.data_path()} not found." + ) + + _symlink_parent_cause( + acause=acause, + rei=rei, + gbd_round_id=gbd_round_id, + calculated_paf_dir_spec=calculated_paf_dir_spec, + pre_calculated_paf_dir_spec=pre_calculated_paf_dir_spec, + ) + + +def _attempt_symlink(source: Path, target: Path, target_is_directory: bool) -> None: + """Tries to symlink; catches the error if the target already exists.""" + try: + symlink_file_to_directory( + source_path=source, target_path=target, target_is_directory=target_is_directory + ) + + except FileExistsError: + filename = source.name if target_is_directory else target.name + directory = target if target_is_directory else target.parent + logger.info(f"{filename} already exists in {directory}.") + + +def _symlink_parent_cause( + acause: str, + rei: str, + gbd_round_id: int, + calculated_paf_dir_spec: FHSDirSpec, + pre_calculated_paf_dir_spec: FHSDirSpec, +) -> None: + # Pull the parent acause from the database + parent_acause = cause.get_parent(acause=acause, gbd_round_id=gbd_round_id) + + # Define the full file specifications using the provided directory specifications + # Note that the pre-calculated-PAF data uses the parent acause while the calculated-PAF + # data uses the actual acause + pre_calculated_paf_file_spec = FHSFileSpec.from_dirspec( + dir=pre_calculated_paf_dir_spec, filename=f"{parent_acause}_{rei}.nc" + ) + calculated_paf_file_spec = FHSFileSpec.from_dirspec( + dir=calculated_paf_dir_spec, filename=f"{acause}_{rei}.nc" + ) + + _attempt_symlink( + source=pre_calculated_paf_file_spec.data_path(), + target=calculated_paf_file_spec.data_path(), + target_is_directory=False, + ) diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/paf/compute_scalar.py b/gbd_2021/disease_burden_forecast_code/risk_factors/paf/compute_scalar.py new file mode 100644 index 0000000..5e4aa16 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/paf/compute_scalar.py 
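The bias correction performed by ``_correct_paf``, ``_compute_correction_factor``, and ``_apply_paf_correction`` in compute_paf.py amounts to an intercept shift in logit space, anchored at the last past year. A minimal scalar sketch of the non-protective branch (toy PAF values, ``scipy`` only):

.. code:: python

    from scipy.special import expit, logit

    fhs_paf_last_past = 0.30  # forecasted PAF in the last past year (toy value)
    gbd_paf_last_past = 0.40  # GBD PAF in the same year (toy value)

    # Correction factor: difference between GBD and FHS in logit space.
    cf = logit(gbd_paf_last_past) - logit(fhs_paf_last_past)

    # Applying the shift in the anchor year recovers the GBD value...
    assert abs(expit(logit(fhs_paf_last_past) + cf) - gbd_paf_last_past) < 1e-12

    # ...and the same additive shift is applied to every forecast year.
    for fhs_paf in (0.32, 0.35, 0.40):
        print(round(float(expit(logit(fhs_paf) + cf)), 4))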
@@ -0,0 +1,461 @@ +"""This script computes aggregated acause specific PAFs and scalars. + +Example call: + +.. code:: bash + + python compute_scalar.py --acause whooping --version 20180321_arc_log \ + --gbd-round-id 4 --years 1990:2017:2040 + +Outputs: + +1) Risk-acause specific scalar. Exported to the input paf/{version} +2) Acause specific scalars. Exported to scalar/{version} +""" + +import gc +import os +from collections import defaultdict +from typing import Any, Dict, List, Optional + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_database_interface.lib.query.risk import get_risk_hierarchy +from fhs_lib_file_interface.lib.query import mediation +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec, VersionMetadata +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_pipeline_scalars.lib.constants import PAFConstants +from fhs_pipeline_scalars.lib.forecasting_db import get_most_detailed_acause_related_risks +from fhs_pipeline_scalars.lib.utils import data_value_check, save_paf + +logger = get_logger() + + +def read_paf( + acause: str, + risk: str, + gbd_round_id: int, + past_or_future: str, + version: str, + draws: int, + years: YearRange, + reference_only: bool, + unmediated: bool = False, + custom_paf: Optional[str] = None, +) -> xr.DataArray: + """Read past or forecast PAF. + + Args: + acause (str): cause name. + risk (str): risk name. + gbd_round_id (int): gbd round id. + past_or_future (str): "past" or "forecast". + version (str): str indiciating folder where data comes from. + draw (int): number of draws to keep. + years: past_start:forecast_start:forecast_end. + reference_only (bool): whether to compute reference only. + unmediated (bool): whether to read in unmediated PAF. Defaults to False. + custom_paf (Optional[str]): Version of custom PAFs. + + Returns: + paf (xr.DataArray): data array of PAF. + """ + input_file_name = f"{acause}_{risk}" + if unmediated: + input_file_name += "_unmediated" + + input_paf_file_spec = _decide_paf_file_spec_to_use( + custom_paf=custom_paf, + gbd_round_id=gbd_round_id, + input_file_name=input_file_name, + past_or_future=past_or_future, + version=version, + ) + + paf = resample(open_xr_scenario(file_spec=input_paf_file_spec), draws) + + paf = paf.where(paf.year_id <= years.forecast_end, drop=True) + + if reference_only and ("scenario" in paf.dims): + paf = paf.sel(scenario=[0]) + + # some dimensional pruning before returning + if "acause" in paf.coords: + if "acause" in paf.dims: + paf = paf.squeeze("acause") + paf = paf.reset_coords(["acause"], drop=True) + + if "rei" in paf.coords: + if "rei" in paf.dims: + paf = paf.squeeze("rei") + paf = paf.reset_coords(["rei"], drop=True) + + return paf.mean("draw") + + +def _decide_paf_file_spec_to_use( + gbd_round_id: int, + input_file_name: str, + past_or_future: str, + version: str, + custom_paf: Optional[str], +) -> FHSFileSpec: + """Decide the PAF file spec to use between the PAF and optional customized PAF inputs. + + If the ``custom_paf`` was specified and contains the ``filename``, + then that will be used. If the ``version`` contains the ``filename`` + then that will be used. 
Otherwise an error will be raised if the file + cannot be found between the two versions. + """ + reference_paf_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, + epoch=past_or_future, + stage="paf", + version=version, + ) + + reference_paf_file_spec = FHSFileSpec( + version_metadata=reference_paf_version_metadata, + sub_path=("risk_acause_specific",), + filename=f"{input_file_name}.nc", + ) + + custom_paf_file_spec = None + if custom_paf: + custom_paf_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, + epoch=past_or_future, + stage="paf", + version=custom_paf, + ) + + custom_paf_file_spec = FHSFileSpec( + version_metadata=custom_paf_version_metadata, + sub_path=("risk_acause_specific",), + filename=f"{input_file_name}.nc", + ) + + if custom_paf_file_spec and custom_paf_file_spec.data_path().PFN().exists(): + input_paf_file_spec = custom_paf_file_spec + elif reference_paf_file_spec.data_path().PFN().exists(): + input_paf_file_spec = reference_paf_file_spec + else: + raise FileNotFoundError( + f"The target file {input_file_name} could not be read from " + f"either the custom PAF version or the reference PAF version." + ) + + return input_paf_file_spec + + +def ancestor_descendant_risks( + most_detailed_risks: List[str], gbd_round_id: int +) -> Dict[str, List[str]]: + """Collect mapping of parent risk to all most-detailed descendant risks. + + Given some most-detailed risks, make a dict of all their ancestor risks, mapping to a list + of their most-detailed risks underneath them. + + Returns a dictionary. key: parent risk; value: list of most-detailed risks below that one. + + Args: + most_detailed_risks (List[str]): list of risks whose ancestors we want to know. + gbd_round_id (int): the GBD Round ID + + Returns: + Dict[str, List[str]]: a dict where keys are risks, and values are sub-risks. + Ex: {'_env': ['air_hap', 'wash', ...], + 'metab': ['metab_fpg', 'metab_bmi', ...], '_behav': ['activity', + 'nutrition_child', ...]} + + Raises: + TypeError: if path_to_top_parent for a particular risk is not recorded + as a string type. That would most likely be error on the db. + """ + risk_table = get_risk_hierarchy(gbd_round_id) + risk_id_dict = dict(risk_table[["rei_id", "rei"]].values) + + result = defaultdict(list) + + for leaf_risk in most_detailed_risks: + risk_specific_metadata = risk_table.query("rei == @leaf_risk") + + if not risk_specific_metadata.empty: + # comma-delimited string, like "169,202,82,83", or None + path_to_root = risk_specific_metadata["path_to_top_parent"].item() + else: + # If there is no metadata for the risk (e.g. for a vaccine in + # GBD2016 we didn't aggregate to all-risk from vaccines), then + # return an empty dict. + return result + + # NOTE the vaccines (hib, pcv, rota, measles, dtp3) currently have + # "None" listed in their path_to_top_parent. 
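The ``path_to_top_parent`` handling that follows maps every most-detailed risk to each of its ancestors (keeping ``_all``, dropping the leaf itself). A standalone sketch with made-up ``rei_id`` values rather than real GBD IDs:

.. code:: python

    from collections import defaultdict

    # Toy hierarchy: rei_id -> rei name (illustrative only).
    risk_id_dict = {1: "_all", 2: "_behav", 3: "nutrition", 4: "nutrition_child"}

    # Leaf risk -> comma-delimited path from the root down to the leaf.
    paths = {"nutrition_child": "1,2,3,4"}

    result = defaultdict(list)
    for leaf_risk, path_to_root in paths.items():
        for ancestor_id in path_to_root.split(",")[:-1]:  # keep "_all", ignore self
            result[risk_id_dict[int(ancestor_id)]].append(leaf_risk)

    print(dict(result))
    # {'_all': ['nutrition_child'], '_behav': ['nutrition_child'], 'nutrition': ['nutrition_child']}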
+ if path_to_root: # not None, must be string + if type(path_to_root) is not str: + raise TypeError(f"{path_to_root} is not of str type") + # first of list is "_all", last of list is itself + _all_to_self_list = path_to_root.split(",") + + for ancestor_id in _all_to_self_list[:-1]: # keep _all, ignore self + ancestor_risk = risk_id_dict[int(ancestor_id)] + result[ancestor_risk].append(leaf_risk) + + return result + + +def aggregate_paf( + acause: str, + risks: List[str], + gbd_round_id: int, + past_or_future: str, + version: str, + draws: int, + years: YearRange, + reference_only: bool, + cluster_risk: Optional[str] = None, + custom_paf: Optional[str] = None, +) -> Optional[xr.DataArray]: + """Aggregate PAFs through mediation. + + Args: + acause (str): acause. + risks (List[str]): set of risks associated with acause. + gbd_round_id (int): gbd round id. + past_or_future (str): 'past' or 'future'. + version (str): indicating folder where data comes from/goes to. + draws (int): number of draws. + years (YearRange): past_start:forecast_start:forecast_end. + reference_only (bool): if true, only return the reference scenario paf. + cluster_risk (Optional[str]): whether this is a cluster risk. Impacts the directory + where it's saved. + custom_paf (Optional[str]): Version of custom PAFs. + + Raises: + ValueError: if there are no risks for a given acause + + Returns: + paf_aggregated (Optional[xr.DataArray]): dataarray of aggregated PAF if + ``cluster_risk`` parameter is specified. + """ + logger.info(f"Start aggregating {past_or_future} PAF:") + + logger.info(f"Acause: {acause}, Risks: {risks}, Cluster Risk: {cluster_risk}") + + if len(risks) == 0: + error_message = f"0 risks for acause {acause}" + logger.error(error_message) + raise ValueError(error_message) + + # We loop over each risk, determine whether to use its total or unmediated + # PAF based on the mediation matrix, and then compute its contribution. + med = mediation.get_mediation_matrix(gbd_round_id) # the mediation matrix + + # If you enumerate risks instead of keeping tracked of not skipped + # reis, you run into an error if the first rei is skipped + # As one_minus_paf_product is never defined + reis_previously_aggregated = [] + + for i, rei in enumerate(risks): + # if cluster_risk is specified, risks is just the subset of this + # cause's risks that fall under the cluster_risk (_env, _metab, etc.) + # If cluster_risk is not specified, that risks are all the risks + # associated with this cause. + logger.info(f"Doing risk {rei}") + + unmediated = False # Set as False for now, may update to True below + + if acause in med["acause"].values and rei in med["rei"].values: + mediator_matrix = med.sel(acause=acause, rei=rei) + + # if there is *any* mediation, use pre-saved unmediated PAF + if (mediator_matrix > 0).any(): + unmediated = True + + paf = read_paf( + acause=acause, + risk=rei, + gbd_round_id=gbd_round_id, + past_or_future=past_or_future, + version=version, + draws=draws, + years=years, + reference_only=reference_only, + unmediated=unmediated, + custom_paf=custom_paf, + ) + + if not reis_previously_aggregated: + logger.debug("Index is 0. 
Starting the paf_prod.") + one_minus_paf_product = 1 - paf + else: + logger.debug(f"Index is {i}.") + # NOTE there no straight-forward way to check if two dataarrays + # have the same coordinates, so we broadcast indiscriminately + paf, one_minus_paf_product = xr.broadcast(paf, one_minus_paf_product) + + paf = paf.where(np.isfinite(paf)).fillna(0) + one_minus_paf_product = one_minus_paf_product.where( + np.isfinite(one_minus_paf_product) + ).fillna(1) + + one_minus_paf_product = (1 - paf) * one_minus_paf_product + del paf + gc.collect() + reis_previously_aggregated.append(rei) + + logger.info(f"Finished computing PAF for {acause}") + + paf_aggregated = 1 - one_minus_paf_product + + del one_minus_paf_product + gc.collect() + + paf_aggregated = paf_aggregated.clip( + min=PAFConstants.PAF_LOWER_BOUND, max=PAFConstants.PAF_UPPER_BOUND + ) + + save_paf( + paf=paf_aggregated, + gbd_round_id=gbd_round_id, + past_or_future=past_or_future, + version=version, + acause=acause, + cluster_risk=cluster_risk, + ) + + if cluster_risk: + return None + else: + return paf_aggregated + + +def compute_scalar( + acause: str, + version: str, + gbd_round_id: int, + no_update_past: bool, + save_past_data: bool, + draws: int, + years: YearRange, + reference_only: bool, + custom_paf: Optional[str], + **kwargs: Any, +) -> None: + """Compute and save scalars for acause, given input paf version. + + Args: + acause (str): cause to compute scalars for + version (str): date/version string pointing to folder to pull data from + gbd_round_id (int): gbd round id. + no_update_past (bool): whether to overwrite past scalars. + draws (int): number of draws to keep. + years (YearRange): past_start:forecast_start:forecast_end. + save_past_data (bool): if true, save files for past data in FILEPATH + reference_only (bool): if true, only return the reference scenario paf. + custom_paf (Optional[str]): Version of custom PAFs. + kwargs (dict[Any]): dictionary that captures all redundant kwargs. + """ + all_most_detailed_risks = get_most_detailed_acause_related_risks(acause, gbd_round_id) + + if not all_most_detailed_risks: + logger.info(f"{acause} does not have any cause-risk pafs") + return None + + if save_past_data: + past_and_future_needed = ["past", "future"] + else: + past_and_future_needed = ["future"] + + subrisk_map = ancestor_descendant_risks(all_most_detailed_risks, gbd_round_id=gbd_round_id) + # We require that the above mapping maps non-leaf (aggregate) risks to leaf risks, and so + # none of the values ("leaf risks") should include a key ("aggregate risk"). If they did, + # we would double-count by multiplying an aggregate's 1-paf together with the leaves' 1-paf + # values. + for values in subrisk_map.values(): + if subrisk_map.keys() & values: + raise ValueError("Bug: Some risks seem to be both aggregates and most-detailed!?") + + for past_or_future in past_and_future_needed: + logger.info(f"OH BOY WE'RE DOING THE: {past_or_future}") + output_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, + epoch=past_or_future, + stage="scalar", + version=version, + ) + output_file_spec = FHSFileSpec( + version_metadata=output_version_metadata, filename=f"{acause}.nc" + ) + + if os.path.exists(str(output_file_spec.data_path())) and no_update_past: + continue + + # Aggregate PAF for level-1 cluster risks + # We don't need to use the PAF for scalar. 
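``aggregate_paf`` above combines the risk-specific PAFs multiplicatively, so the aggregate is one minus the product of the individual complements, and ``compute_scalar`` below turns the aggregate PAF into a scalar via ``1 / (1 - paf)``. A minimal numeric sketch with toy PAF values:

.. code:: python

    import numpy as np

    pafs = np.array([0.10, 0.25, 0.40])  # toy risk-specific PAFs for one cause

    paf_aggregated = 1 - np.prod(1 - pafs)  # as in aggregate_paf
    scalar = 1.0 / (1.0 - paf_aggregated)   # as in compute_scalar
    print(round(float(paf_aggregated), 4), round(float(scalar), 4))  # 0.595 2.4691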
+ + for key in subrisk_map.keys(): # loop over all antecedent-risks + logger.info("Looping over super/parent risks.") + + leaf_descendants = subrisk_map[key] + + if leaf_descendants: + logger.info(f"Start aggregating cluster risk: {key}") + + aggregate_paf( + acause=acause, + risks=leaf_descendants, + gbd_round_id=gbd_round_id, + past_or_future=past_or_future, + version=version, + draws=draws, + years=years, + reference_only=reference_only, + cluster_risk=key, + custom_paf=custom_paf, + ) + gc.collect() + + # Aggregate PAF for all risks. + # We need to use the PAF for scalar. + paf_mediated = aggregate_paf( + acause=acause, + risks=all_most_detailed_risks, + gbd_round_id=gbd_round_id, + past_or_future=past_or_future, + version=version, + draws=draws, + years=years, + reference_only=reference_only, + custom_paf=custom_paf, + ) + + if paf_mediated is None: + logger.info("No paf_mediated. Early return.") + return + + scalar = 1.0 / (1.0 - paf_mediated) + + scalar = expand_dimensions(scalar, draw=range(draws)) + + del paf_mediated + gc.collect() + + logger.debug(f"Checking data value for {acause} scalar") + data_value_check(scalar) # make sure no NaNs or <0 in dataarray + + save_xr_scenario( + xr_obj=scalar, + file_spec=output_file_spec, + metric="number", + space="identity", + acause=acause, + version=version, + gbd_round_id=gbd_round_id, + no_update_past=str(no_update_past), + ) diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/paf/constants.py b/gbd_2021/disease_burden_forecast_code/risk_factors/paf/constants.py new file mode 100644 index 0000000..645ff0d --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/paf/constants.py @@ -0,0 +1,32 @@ +"""FHS Pipeline Scalars Local Constants.""" + + +class PAFConstants: + """PAF-calculation-related constants.""" + + # these must correspond to file name in vaccine_sev input folder + VACCINE_RISKS = ( + "vacc_dtp3", + "vacc_mcv1", + "vacc_hib3", + "vacc_pcv3", + "vacc_rotac", + "vacc_mcv2", + ) + + PAF_LOWER_BOUND = -0.999 + PAF_UPPER_BOUND = 0.999 + + # GBD has RRMax values for TB child causes that we don't model, but not for TB. + # As of GBD2019, they are all the same. Therefore we can pull RRMax for TB from + # an arbitrary TB child cause. 
+ TB_OTHER_CAUSE_ID = 934 # arbitrary TB child cause (tb_other) + + DEBUG_CR_COUNT = 3 # number of cause-risk pairs to pull for debug + + +class ClusterJobConstants: + """Constants related to submitting jobs on the cluster.""" + + COMPUTE_PAFS_RUNTIME = "8h" + COMPUTE_SCALARS_RUNTIME = "16h" diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/paf/forecasting_db.py b/gbd_2021/disease_burden_forecast_code/risk_factors/paf/forecasting_db.py new file mode 100644 index 0000000..c44b4f8 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/paf/forecasting_db.py @@ -0,0 +1,194 @@ +import collections +from typing import List + +import fhs_lib_database_interface.lib.query.risk as query_risk +import pandas as pd +from fhs_lib_database_interface.lib import db_session +from fhs_lib_database_interface.lib.constants import ( + CauseRiskPairConstants, + DimensionConstants, + FHSDBConstants, + SexConstants, +) +from fhs_lib_database_interface.lib.fhs_lru_cache import fhs_lru_cache +from fhs_lib_database_interface.lib.query import cause, entity +from fhs_lib_database_interface.lib.query.age import get_ages +from fhs_lib_database_interface.lib.query.location import get_location_set +from fhs_lib_database_interface.lib.strategy_set import strategy +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib.fhs_logging import get_logger + +logger = get_logger() + + +@fhs_lru_cache(1) +def demographic_coords(gbd_round_id: int, years: YearRange) -> collections.OrderedDict: + """Create and cache an OrderedDict of demographic indices. + + Args: + gbd_round_id (int): gbd round id. + years (YearRange): [past_start, forecast_start, forecast_end] years. + + Returns: + OrderedDict: ordered dict of all non-draw dimensions and their + coordinates. + """ + location_ids = get_location_set(gbd_round_id=gbd_round_id).location_id.values.tolist() + + age_group_ids = get_ages()[DimensionConstants.AGE_GROUP_ID].unique().tolist() + + return collections.OrderedDict( + [ + (DimensionConstants.LOCATION_ID, location_ids), + (DimensionConstants.AGE_GROUP_ID, age_group_ids), + (DimensionConstants.SEX_ID, list(SexConstants.SEX_IDS)), + (DimensionConstants.YEAR_ID, years.years), + ] + ) + + +@fhs_lru_cache(3) +def _get_precalculated_pafs(gbd_round_id: int, strategy_id: int) -> pd.DataFrame: + """Get cause-risk pairs that have directly-modeled PAFs. + + Args: + gbd_round_id (int): gbd round id. + strategy_id (int): strategy ID for this particular cause-risk pair. + + Returns: + (pd.DataFrame): spreadsheeto of the precalcualted PAF cause-risk pairs. + """ + if gbd_round_id == 4: + raise ValueError("GBD round ID 4 is no longer supported.") + + with db_session.create_db_session(FHSDBConstants.FORECASTING_DB_NAME) as session: + result = strategy.get_cause_risk_pair_set( + session=session, + strategy_id=strategy_id, + gbd_round_id=gbd_round_id, + ) + + # Attach acause, rei (to match the cause_id, rei_id). + acause_cause_id_map = cause.get_acauses(result["cause_id"].unique()) + rei_rei_id_map = query_risk.get_reis(result["rei_id"].unique()) + result = result.merge(acause_cause_id_map, how="left").merge(rei_rei_id_map, how="left") + + entity.assert_cause_risk_pairs_are_named(result) + + return result + + +def is_pre_calculated(acause: str, rei: str, gbd_round_id: int, strategy_id: int) -> bool: + """Return true if the cause-risk pair has a directly modeled PAF. + + Args: + acause (str): acause. + rei (str): risk. + gbd_round_id (int): gbd round id. 
+ strategy_id (int): strategy ID for this particular cause-risk pair. + + Returns: + (bool): whether this cause-risk pair is pre-calculated. + """ + dm_pafs = _get_precalculated_pafs(gbd_round_id, strategy_id) + return not dm_pafs.query("acause == @acause and rei == @rei").empty + + +def get_most_detailed_acause_related_risks(acause: str, gbd_round_id: int) -> List[str]: + """Return a list of most-detailed risks contributing to certain acause. + + Args: + acause (str): analytical cause. + gbd_round_id (int): gbd round id. + + Returns: + (List[str]): a list of risks associated with this acause, + ignoring the ones that are in_scalar = 0. + """ + if acause in ["rotavirus"]: + risks = ["rota"] + else: + df_acause_risk = entity.get_scalars_most_detailed_cause_risk_pairs(gbd_round_id) + risks = list(df_acause_risk.query("acause == @acause")["rei"].unique()) + return risks + + +@fhs_lru_cache(1) +def _get_maybe_negative_paf_pairs(gbd_round_id: int) -> pd.DataFrame: + """Get cause-risk pairs that *can* have negative PAFs. + + Args: + gbd_round_id (int): gbd round id. + + Returns: + (pd.DataFrame): spreadsheet of cause-risk pairs that could have + negative PAF. + """ + if gbd_round_id == 4: + raise ValueError("GBD round ID 4 is no longer supported.") + + with db_session.create_db_session(FHSDBConstants.FORECASTING_DB_NAME) as session: + result = strategy.get_cause_risk_pair_set( + session=session, + strategy_id=CauseRiskPairConstants.MAYBE_NEGATIVE_PAF_SET_ID, + gbd_round_id=gbd_round_id, + ) + + # Set only has cause_ids and rei_ids, so get acauses + acause_cause_id_map = cause.get_acauses(result["cause_id"].unique()) + rei_rei_id_map = query_risk.get_reis(result["rei_id"].unique()) + result = result.merge(acause_cause_id_map, how="left").merge(rei_rei_id_map, how="left") + + entity.assert_cause_risk_pairs_are_named(result) + + return result + + +def is_maybe_negative_paf(acause: str, rei: str, gbd_round_id: int) -> bool: + """Return true if the cause-risk pair is maybe a negative PAF. + + Args: + acause (str): acause. + rei (str): risk. + gbd_round_id (int): gbd round id. + + Returns: + (bool): whether this cause-risk pair might be protective. + """ + maybe_negative_pafs = _get_maybe_negative_paf_pairs(gbd_round_id) + return not maybe_negative_pafs.query("acause == @acause and rei == @rei").empty + + +@fhs_lru_cache(1) +def get_detailed_directly_modeled_pafs(gbd_round_id: int) -> pd.DataFrame: + """Get the directly modeled PAFs, as cause_risk pairs. + + Args: + gbd_round_id (int): gbd round id. + + Returns: + (pd.DataFrame): dataframe of directly modeled acause and rei. + """ + all_detailed_pafs = entity.get_scalars_most_detailed_cause_risk_pairs(gbd_round_id)[ + ["acause", "rei"] + ] + directly_modeled_pafs = _get_precalculated_pafs( + gbd_round_id, CauseRiskPairConstants.DIRECTLY_MODELED_STRATEGY_ID + ) # has aggregates + detailed_dm_pafs = all_detailed_pafs.merge(directly_modeled_pafs, on=["acause", "rei"]) + + return detailed_dm_pafs + + +def get_modeling_causes(gbd_round_id: int) -> List[str]: + """Return the causes we are modeling. + + Args: + gbd_round_id (int): gbd round id. + + Returns: + (List[str]): list of all acauses for the scalars pipeline. 
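Several of the lookups above (``is_pre_calculated``, ``is_maybe_negative_paf``) reduce to a membership test on a cause-risk DataFrame using ``pandas.DataFrame.query`` with ``@``-references to local variables. A minimal standalone sketch with made-up rows rather than real strategy-set data:

.. code:: python

    import pandas as pd

    pairs = pd.DataFrame(
        {"acause": ["cvd_ihd", "whooping"], "rei": ["metab_bmi", "vacc_dtp3"]}
    )

    def pair_in_set(acause: str, rei: str) -> bool:
        # `@` pulls the surrounding local variables into the query expression.
        return not pairs.query("acause == @acause and rei == @rei").empty

    print(pair_in_set("cvd_ihd", "metab_bmi"), pair_in_set("cvd_ihd", "smoking"))  # True False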
+ """ + df_acause_risk = entity.get_scalars_most_detailed_cause_risk_pairs(gbd_round_id) + acauses = list(df_acause_risk.acause.unique()) + return acauses diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/paf/utils.py b/gbd_2021/disease_burden_forecast_code/risk_factors/paf/utils.py new file mode 100644 index 0000000..ca2fc53 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/paf/utils.py @@ -0,0 +1,303 @@ +"""Utility/DB functions for the scalars pipeline. +""" + +from typing import Any, List, Optional, Union + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_database_interface.lib.constants import ( + AgeConstants, + DimensionConstants, + LocationConstants, + SexConstants, +) +from fhs_lib_file_interface.lib.query import mediation +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec, VersionMetadata +from fhs_lib_file_interface.lib.xarray_wrapper import save_xr_scenario +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_pipeline_scalars.lib.forecasting_db import demographic_coords + +logger = get_logger() + + +def product_of_mediation( + acause: str, risk: str, mediators: List[str], gbd_round_id: int +) -> Union[xr.DataArray, int]: + r"""Return :math:`\prod_{i} (1 - MF_{jic})`. + + :math:`j` is the risk whose adjusted PAF to acause c we're considering, + and :math:`i` is the mediator. + + That means this method takes in only 1 risk, and a list of multiple + mediators, and performs a product over all the :math:`(1 - MF_{jic})`'s. + + "metab_bmi" impacts "cvd_ihd" via "metab_sbp". Say sbp is the only + mediator of bmi. So if PAF_{ihd/bmi} is 0.9 and MF_{sbp/bmi} is 0.6, + then the adjusted PAF_{ihd/bmi} is 0.9 * (1 - 0.6) = 0.36, + because 0.54 of the 0.9 comes from sbp. + + The mediation factor for (acause, risk) is provided in a flat file. + + Args: + acause (str): analytical cause. + risk (str): risk related to acause. + mediators (list[str]): the risks that could potentially sit between risk and the acause + in attribution. Usually these are just all the risks that are paired with acause, + and then we filter through the mediation file for the ones that matter. + gbd_round_id (int): gbd round id. + + Returns: + mediation_products: Either the number 1 (float), if the loaded data does not describe + the given acause and rei, or else the mediation matrix, as an xarray DataArray, if + it does. + """ + logger.info( + "Computing product of mediation of risk {} via mediators {} " + "on acause {}".format(risk, mediators, acause) + ) + + med_risks = list(mediators) # just making a copy for later .remove() + # NOTE the list of mediators can include "risk" itself. 
Here we remove + # risk from the list of mediators + if risk in med_risks: + med_risks.remove(risk) + + med = mediation.get_mediation_matrix(gbd_round_id) + + # only apply mediation when mediation exists + if acause in med["acause"] and risk in med["rei"]: + mediators = list(set(med["med"].values) & set(med_risks)) + if mediators: + mediation_factors = med.sel(acause=acause, rei=risk, med=mediators) + return (1 - mediation_factors).prod("med") + else: + return 1 # default, if no mediation exists + else: + return 1 # default, if no mediation exists + + +def conditionally_triggered_transformations( + da: xr.DataArray, gbd_round_id: int, years: YearRange +) -> xr.DataArray: + """Encode dataarray transformations that are conditionally triggered. + + Any of the following conditions will trigger a transformation: + + 1.) + da only has sex_id = 3 (both sexes). In this case, expand to + sex_id = 1, 2. + 2.) + da only has age_group_id = 22 (all ages). In this case, expand to + all age_group_ids of current gbd_round + 3.) + da has different years than expected. In this case, filter/interpolate + for just DEMOGRAPHY_INDICES[YEAR_ID]. + 4.) + da has only location_id = 1 (all locations). In this case, we replace + location_id=1 with the latest gbd location ids. + 5.) + da has a point dimension of "quantile". Simply remove. + + Args: + da (xr.DataArray): may or may not become transformed. + gbd_round_id (int): gbd round id. + years (YearRange): [past_start, forecast_start, forecast_end] years. + + Returns: + (xr.DataArray): the transformed dataarray, or the original if no transformation was triggered. + """ + if ( + DimensionConstants.SEX_ID in da.dims + and len(da[DimensionConstants.SEX_ID]) == 1 + and da[DimensionConstants.SEX_ID] == SexConstants.BOTH_SEX_ID + ): + # some vaccine SEVs could be modeled this way + da = _expand_point_dim_to_all(da, gbd_round_id, years, dim=DimensionConstants.SEX_ID) + + if ( + DimensionConstants.AGE_GROUP_ID in da.dims + and len(da[DimensionConstants.AGE_GROUP_ID]) == 1 + and da[DimensionConstants.AGE_GROUP_ID] == AgeConstants.ALL_AGE_ID + ): + # if da has a few age groups but not all, no transformation here + da = _expand_point_dim_to_all( + da, gbd_round_id, years, dim=DimensionConstants.AGE_GROUP_ID + ) + + if ( + DimensionConstants.LOCATION_ID in da.dims + and len(da[DimensionConstants.LOCATION_ID]) == 1 + and da[DimensionConstants.LOCATION_ID] == LocationConstants.GLOBAL_LOCATION_ID + ): + da = _expand_point_dim_to_all( + da, gbd_round_id, years, dim=DimensionConstants.LOCATION_ID + ) + + # sometimes we're provided with not enough or too many years. + # here we first construct a superset, then crop out unwanted years.
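+    # For example (illustrative years only): if da covers 1990-2017 but demographic_coords
+    # expects 1990-2019, the two missing years are first filled with the across-year mean,
+    # and the .loc selection below then crops da down to exactly the expected year_ids.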
+ if DimensionConstants.YEAR_ID in da.dims and set( + da[DimensionConstants.YEAR_ID].values + ) != set(demographic_coords(gbd_round_id, years)[DimensionConstants.YEAR_ID]): + missing_years = list( + set(demographic_coords(gbd_round_id, years)[DimensionConstants.YEAR_ID]) + - set(da[DimensionConstants.YEAR_ID].values) + ) + if missing_years: + da = _fill_missing_years_with_mean(da, DimensionConstants.YEAR_ID, missing_years) + # Now I filter for only the ones I need + da = da.loc[ + { + DimensionConstants.YEAR_ID: demographic_coords(gbd_round_id, years)[ + DimensionConstants.YEAR_ID + ] + } + ] + + # we don't need the quantile point coordinate + if DimensionConstants.QUANTILE in da.dims and len(da[DimensionConstants.QUANTILE]) == 1: + da = da.drop_vars(DimensionConstants.QUANTILE) + + return da + + +def _fill_missing_years_with_mean( + da: xr.DataArray, year_dim: str, missing_years: List[int] +) -> xr.DataArray: + """Fill missing years with the mean of other years. + + Args: + da (xr.DataArray): should contain the year_dim dimension. + year_dim (str): str name of the year dimension. + missing_years (list): list of years that are missing. + + Returns: + (xr.DataArray): dataarray with missing years filled in. + """ + year_coords = xr.DataArray([np.nan] * len(missing_years), [(year_dim, missing_years)]) + mean_vals = da.mean(year_dim) # NOTE just fill with mean values + fill_da = mean_vals.combine_first(year_coords) + da = da.combine_first(fill_da) + return da.sortby(year_dim) + + +def _expand_point_dim_to_all( + da: xr.DataArray, gbd_round_id: int, years: YearRange, dim: str +) -> xr.DataArray: + """Expand point dim to full set of coordinates. + + Sometimes input data has only age_group_id = 22 (all ages), or sex_id = 3 (both sexes). + In that case, we'll need to convert it to the full age group dim, or sex_id = 1,2. The + expected coordinates come from demographic_coords(). + + Args: + da (xr.DataArray): da to expand. + gbd_round_id (int): gbd round id. + years (YearRange): [past_start, forecast_start, forecast_end] years. + dim (str): Point dimension of da to expand on. Typically this occurs with + age_group_id = 22 (all ages) or sex_id = 3 (both sexes). + + Returns: + (xr.DataArray): da expanded to the dim coordinates prescribed by gbd_round_id and years. + """ + if dim not in da.dims: + raise ValueError("{} is not an index dimension".format(dim)) + if len(da[dim]) != 1: + raise ValueError("Dim {} length not equal to 1".format(dim)) + + # first I need the proper coordinates + coords = demographic_coords(gbd_round_id, years)[dim] + # then make keyword argument dict + dimension_kwarg = {dim: coords} + + # expand + out = expand_dimensions(da.drop_vars(dim).squeeze(), **dimension_kwarg) + + return out + + +def data_value_check(da: xr.DataArray) -> None: + """Check sensibility of dataarray values. + + Currently, these are checked: + + 1.) No NaN/Inf + 2.) No negative values + + Args: + da (xr.DataArray): data in dataarray format. + """ + if not np.isfinite(da).all(): + raise ValueError("Array contains non-finite values") + if not (da >= 0.0).all(): + raise ValueError("Array contains values < 0") + + +def save_paf( + paf: Union[xr.DataArray, xr.Dataset], + gbd_round_id: int, + past_or_future: str, + version: str, + acause: str, + cluster_risk: Optional[str] = None, + file_suffix: str = "", + metric: str = DimensionConstants.PERCENT_METRIC, + space: str = DimensionConstants.IDENTITY_SPACE_NAME, + **kwargs: Any, +) -> None: + """Save mediated PAF at cause level.
+ + Args: + paf (Union[xr.DataArray, xr.Dataset]): DataArray or Dataset of PAF. + gbd_round_id (int): gbd round id. + past_or_future (str): 'past' or 'future'. + version (str): version, dated. + acause (str): analytical cause. + cluster_risk (Optional[str]): If none, it will be just risk. + file_suffix (str): File name suffix to attach onto output file name, should include a + leading underscore if desired. Default is an empty string. + metric (str): Metric name of the output xarray, defaults to + `DimensionConstants.PERCENT_METRIC` + space (str): Space name of the output xarray, defaults to + `DimensionConstants.IDENTITY_SPACE_NAME` + kwargs (Any): Additional keyword arguments to pass into save_xr call + """ + if past_or_future == "past": + if DimensionConstants.SCENARIO in paf.dims: + # past data should be scenarioless; all its scenarios are identical, so we can just + # grab the first one + paf = paf.sel(scenario=paf["scenario"].values[0], drop=True) + + file_attributes = kwargs + version_metadata = VersionMetadata.make( + data_source=gbd_round_id, + epoch=past_or_future, + stage="paf", + version=version, + ) + + if cluster_risk: + output_file_spec = FHSFileSpec( + version_metadata=version_metadata, + sub_path=("risk_acause_specific",), + filename=f"{acause}_{cluster_risk}{file_suffix}.nc", + ) + file_attributes = dict(file_attributes, risk=cluster_risk) + else: + output_file_spec = FHSFileSpec( + version_metadata=version_metadata, + filename=f"{acause}{file_suffix}.nc", + ) + + save_xr_scenario( + xr_obj=paf, + file_spec=output_file_spec, + metric=metric, + space=space, + acause=acause, + version=version, + gbd_round_id=gbd_round_id, + **file_attributes, + ) diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/sev/compute_future_mediator_total_sev.py b/gbd_2021/disease_burden_forecast_code/risk_factors/sev/compute_future_mediator_total_sev.py new file mode 100644 index 0000000..9eeaafd --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/sev/compute_future_mediator_total_sev.py @@ -0,0 +1,448 @@ +r"""Compute cause-risk-specific future total SEV, given acause and risk. + +Example call: + +.. code:: bash + + python compute_future_mediator_total_sev.py \ + --acause cvd_ihd \ + --rei metab_sbp \ + --gbd-round-id 6 \ + --sev-version 20210124_test \ + --past-sev-version 20201105_etl_sevs \ + --rr-version 20200831_hard_coded_inference + +Note that, to save memory footprint, this code will only analyze and +export future years. 
+""" + +import gc +import glob +from functools import reduce +from pathlib import Path +from typing import Optional + +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.intercept_shift import unordered_draw_intercept_shift +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_database_interface.lib.query import cause +from fhs_lib_file_interface.lib.pandas_wrapper import read_csv +from fhs_lib_file_interface.lib.version_metadata import ( + FHSDirSpec, + FHSFileSpec, + VersionMetadata, +) +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_genem.lib.constants import OrchestrationConstants +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging + +from fhs_pipeline_sevs.lib import rrmax as rrmax_module +from fhs_pipeline_sevs.lib.compute_past_intrinsic_sev import ( + get_product_on_right_hand_side, + newton_solver, +) +from fhs_pipeline_sevs.lib.constants import FutureSEVConstants, PastSEVConstants + +logger = fhs_logging.get_logger() + + +def keep_minimum_k( + k0: xr.DataArray, + acause: str, + rei: str, + gbd_round_id: int, + past_sev_version: str, + years: YearRange, + draws: int, + draw_start: Optional[int] = 0, +) -> xr.DataArray: + """Compare past and future k coefficients and keep the lower ones. + + `k0` contains ARC-computed past years. + For the past years, we simply replace. + For the future years, we compare to last past year first, and then keep + the lower ones. + + Args: + k0 (xr.DataArray): k coefficients freshly computed from newton solver. + acause (str): acause. + rei (str): risk. + gbd_round_id (int): gbd round id. + sev_version (str): version of future SEV. + rr_version (str): version of RR (past). + years (YearRange): past_start:future_start:future_end. + draws (int): number of draws kept in process. + draw_start (Optional[int]): starting index of draws selected. + + Returns: + (xr.DataArray): k coefficients of min(past_end, forecast_years). + """ + k_past_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, + epoch="past", + stage="sev", + version=past_sev_version, + ) + k_past_file_spec = FHSFileSpec( + version_metadata=k_past_version_metadata, + sub_path=(OrchestrationConstants.SUBFOLDER,), + filename=f"{acause}_{rei}_k_coeff.nc", + ) + + k_past = ( + open_xr_scenario(k_past_file_spec) + .load() + .sel( + draw=range(draw_start, draw_start + draws, 1), + location_id=k0["location_id"].values, + age_group_id=k0["age_group_id"].values, + sex_id=k0["sex_id"].values, + year_id=years.past_years, + ) + ) + + if not (k_past["draw"] == k0["draw"]).all(): + raise ValueError("draw coordinate mismatch between k0 and k_last_past") + + # first make sure k_past has the same scenarios as k0 + k_past = expand_dimensions(k_past, scenario=k0["scenario"].values.tolist()) + + # replace forecasted values with past where last past < forecast + k_last_past = k_past.sel(year_id=years.past_end) + + needed_years = np.concatenate(([years.past_end], years.forecast_years)) + k0_future = k0.sel(year_id=needed_years) + + # compare to last past year first, and then keep the lower ones. 
+ k0.loc[dict(year_id=needed_years)] = k0_future.where(k0_future < k_last_past).fillna( + k_last_past + ) + + return k0 + + +def future_cause_risk_sev( + acause: str, + rei: str, + gbd_round_id: int, + sev_version: str, + past_sev_version: str, + rr_version: str, + years: YearRange, + draws: int, +) -> None: + """Produce cause-risk specific SEV forecast. + + Export to the same version as where the intrinsic SEVs are. + + Args: + acause (str): acause. + rei (str): risk. + gbd_round_id (int): gbd round id. + sev_version (str): version of future SEV. + past_sev_version (str): version of past SEV. + rr_version (str): version of RR (past). + years (YearRange): past_start:future_start:future_end. + draws (int): number of draws kept in process. + """ + cause_id = cause.get_cause_id(acause=acause) + + # where to pick up the cause-risk specific *future* intrinsic SEV + in_out_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, + epoch="future", + stage="sev", + version=sev_version, + ) + in_out_dir_spec = FHSDirSpec( + version_metadata=in_out_version_metadata, sub_path=(OrchestrationConstants.SUBFOLDER,) + ) + + chunk_size = PastSEVConstants.DRAW_CHUNK_SIZE # compute k in chunks of 100 draws + + k = xr.DataArray() + rei_product = xr.DataArray() + + for i, draw_start in enumerate(range(0, draws, chunk_size)): + # burden is on the user to make sure draws & chunk_size are consistent + sev_intrinsic = open_xr_scenario( + file_spec=FHSFileSpec.from_dirspec( + dir=in_out_dir_spec, filename=f"{acause}_{rei}_intrinsic.nc" + ) + ).sel(draw=range(draw_start, draw_start + chunk_size, 1)) + + med_age_ids = sev_intrinsic.age_group_id.values + + a_is = get_product_on_right_hand_side( + acause=acause, + cause_id=cause_id, + rei=rei, + gbd_round_id=gbd_round_id, + past_or_future="future", + sev_version=sev_version, + rr_version=rr_version, + draws=chunk_size, + draw_start=draw_start, + ) + + a_is = expand_dimensions(a_is, age_group_id=med_age_ids, fill_value=0) + + # this is a more memory-friendly way to make the product + rei_product_i = reduce( + lambda x, y: x * y, + (a_is.sel(rei=risk) + 1 for risk in a_is["rei"].values), + ) + + # Note get_product_on_right_hand_side does the same read_rrmax call. + rrmax = rrmax_module.read_rrmax( + acause=acause, + cause_id=cause_id, + rei=rei, + gbd_round_id=gbd_round_id, + version=rr_version, + draws=chunk_size, + draw_start=draw_start, + ) + + sev = (((rrmax - 1) * sev_intrinsic + 1) * rei_product_i - 1) / (rrmax - 1) + + del rei_product_i + gc.collect() + + # some sev values might be greater than 1. 
Here we solve for their k's + sev_reduced = sev.where(sev <= 1).fillna(1) + + b_const = ((rrmax - 1) * sev_reduced + 1) / ((rrmax - 1) * sev_intrinsic + 1) + + del sev_reduced + gc.collect() + + a_is = a_is.sel(age_group_id=b_const["age_group_id"].values) + + b_const = b_const.where(sev > 1).fillna(1) # so log(b) = 0 + + # now compute the initial k guess + rei_sum = a_is.sum(dim="rei") + + k_i = np.log(b_const) / rei_sum # the initial k + + del rei_sum + gc.collect() + + k_i = k_i.where(sev > 1).fillna(0) # so log(a * k + 1) = 0 + + for dim in k_i.dims: # make sure all indices are matched along axis + b_const = b_const.sel(**{dim: k_i[dim].values}) + a_is = a_is.sel(**{dim: k_i[dim].values}) + + # align everything to k_i's dim order before newton solver + b_const = b_const.transpose(*k_i.dims) + a_is = a_is.transpose(*[a_is.dims[0]] + list(b_const.dims)) + + k_i = newton_solver(a_is, b_const, k_i) + + del b_const + gc.collect() + + k_i = k_i.where(sev > 1).fillna(1) # k=1 for non-problematic cells + + del sev + gc.collect() + + # Per CJLM, pick the min of K between last past year and just-computed + k_i = keep_minimum_k( + k0=k_i, + acause=acause, + rei=rei, + gbd_round_id=gbd_round_id, + past_sev_version=past_sev_version, + years=years, + draws=chunk_size, + draw_start=draw_start, + ) + + # need to recompute rei_product, now with k + rei_product_i = reduce( + lambda x, y: x * y, + (k_i * a_is.sel(rei=risk) + 1 for risk in a_is["rei"].values), + ) + + if i == 0: + k = k_i + rei_product = rei_product_i + else: + k = xr.concat([k, k_i], dim="draw") + rei_product = xr.concat([rei_product, rei_product_i], dim="draw") + del k_i, rei_product_i + + del a_is + gc.collect() + + # re-compute sev_intrinsic with k. + sev_intrinsic = open_xr_scenario( + file_spec=FHSFileSpec.from_dirspec( + dir=in_out_dir_spec, filename=f"{acause}_{rei}_intrinsic.nc" + ) + ) + sev_intrinsic = resample(sev_intrinsic, draws) + + rrmax = rrmax_module.read_rrmax( + acause=acause, + cause_id=cause_id, + rei=rei, + gbd_round_id=gbd_round_id, + version=rr_version, + draws=draws, + ) + + sev = (((rrmax - 1) * sev_intrinsic + 1) * rei_product - 1) / (rrmax - 1) + + del rrmax, sev_intrinsic, rei_product + gc.collect() + + if (sev > 1 + PastSEVConstants.TOL).any(): + max_val = float(sev.max()) + logger.warning( + ( + f"There are values in final SEV > 1 + {PastSEVConstants.TOL}, " + f"the max being {max_val}." 
+ ) + ) + + sev = sev.where(sev <= 1).fillna(1) + sev = sev.where(sev >= 0).fillna(0) + + # remove some redundant point dims + redundants = list(set(sev.coords.keys()) - set(sev.dims)) + redundants.remove("acause") # we're keeping "acause" + + for redundant in redundants: # only keep dims and "acause" + k = k.drop_vars(redundant) + sev = sev.drop_vars(redundant) + + save_xr_scenario( + xr_obj=k, + file_spec=FHSFileSpec.from_dirspec( + dir=in_out_dir_spec, filename=f"{acause}_{rei}_k_coeff.nc" + ), + metric="number", + space="identity", + gbd_round_id=gbd_round_id, + sev_version=sev_version, + rr_version=rr_version, + tol=PastSEVConstants.TOL, + rtol=PastSEVConstants.RTOL, + maxiter=PastSEVConstants.MAXITER, + ) + + save_xr_scenario( + xr_obj=sev, + file_spec=FHSFileSpec.from_dirspec(dir=in_out_dir_spec, filename=f"{acause}_{rei}.nc"), + metric="rate", + space="identity", + gbd_round_id=gbd_round_id, + sev_version=sev_version, + rr_version=rr_version, + ) + + +def combine_cause_risk_sevs_to_sev( + rei: str, + sev_version: str, + past_sev_version: str, + rr_version: str, + gbd_round_id: int, + years: YearRange, +) -> None: + """Combine cause-risk SEVs to make risk-only SEVs, for given risk. + + Cause-risk SEVs are in sev_version/risk_acause_specific/, with risk-only + SEVs exported to sev_version/ + + Args: + rei (str): risk whose cause-risk SEVs will be averaged out. + sev_version (str): future SEV version. + past_sev_version (str): past SEV vdersion. + rr_version (str): version of rr, from FILEPATH. + gbd_round_id (int): gbd round id. + draws (int): number of draws to keep. + years (YearRange): past_start:forecast_start:forecast_end. + """ + daly_df = read_csv(FutureSEVConstants.DALY_WEIGHTS_FILE_SPEC, keep_default_na=False) + + output_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, + epoch="future", + stage="sev", + version=sev_version, + ) + + glob_str = ( + output_version_metadata.data_path() / OrchestrationConstants.SUBFOLDER / f"*{rei}.nc" + ) + file_names = {Path(f).name for f in glob.glob(str(glob_str))} + + if len(file_names) == 0: + raise ValueError(f"No cause-risk files found for *{rei}.nc") + + acauses = [Path(x).name.replace(f"_{rei}.nc", "") for x in file_names] + total_daly = daly_df.query("acause in {}".format(tuple(acauses)))["DALY"].sum() + + sev = xr.DataArray(0) + + for file_name in file_names: + sev_i = open_xr_scenario( + FHSFileSpec( + version_metadata=output_version_metadata, + sub_path=(OrchestrationConstants.SUBFOLDER,), + filename=file_name, + ) + ) + + acause = file_name.replace(f"_{rei}.nc", "") + acause_daly = float(daly_df.query(f"acause == '{acause}'")["DALY"].values) + fraction = acause_daly / total_daly + sev = sev + (sev_i * fraction) + del sev_i + gc.collect() + + if "acause" in sev.coords: + sev = sev.drop_vars("acause") + + past_sev_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, + epoch="past", + stage="sev", + version=past_sev_version, + ) + past_file_spec = FHSFileSpec( + version_metadata=past_sev_version_metadata, filename=f"{rei}.nc" + ) + past_data = open_xr_scenario(past_file_spec) + + # Need to subset past_data by sev coordinates + past_data = past_data.sel( + age_group_id=sev.age_group_id.values, + sex_id=sev.sex_id.values, + location_id=sev.location_id.values, + ) + + past_data = resample(past_data, len(sev.draw.values)) + + sev = unordered_draw_intercept_shift(sev, past_data, years.past_end, years.forecast_end) + + sev = sev.clip(min=FutureSEVConstants.FLOOR, max=1 - FutureSEVConstants.FLOOR) + + # 
Now we can save it + save_xr_scenario( + xr_obj=sev, + file_spec=FHSFileSpec(version_metadata=output_version_metadata, filename=f"{rei}.nc"), + metric="rate", + space="identity", + gbd_round_id=gbd_round_id, + sev_version=sev_version, + rr_version=rr_version, + ) diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/sev/compute_past_intrinsic_sev.py b/gbd_2021/disease_burden_forecast_code/risk_factors/sev/compute_past_intrinsic_sev.py new file mode 100644 index 0000000..1199386 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/sev/compute_past_intrinsic_sev.py @@ -0,0 +1,444 @@ +"""Compute Intrinsic SEV of a mediator. +""" + +import gc +from functools import reduce +from typing import Any, List, Optional, Union + +import fhs_lib_database_interface.lib.query.cause as query_cause +import fhs_lib_database_interface.lib.query.risk as query_risk +import numpy as np +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_database_interface.lib.query.age import get_ages +from fhs_lib_file_interface.lib.query.mediation import get_mediation_matrix +from fhs_lib_file_interface.lib.version_metadata import ( + FHSDirSpec, + FHSFileSpec, + VersionMetadata, +) +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_genem.lib.constants import OrchestrationConstants +from scipy.optimize import newton +from tiny_structured_logger.lib import fhs_logging + +from fhs_pipeline_sevs.lib import rrmax as rrmax_module +from fhs_pipeline_sevs.lib.constants import PastSEVConstants + +logger = fhs_logging.get_logger() + + +def get_product_on_right_hand_side( + acause: str, + cause_id: int, + rei: str, + gbd_round_id: int, + past_or_future: str, + sev_version: str, + rr_version: str, + draws: int, + age_group_ids: Optional[List[int]] = None, + draw_start: int = 0, +) -> xr.DataArray: + r"""Compute the product term of the right-hand side of Equation (3) from the background_. + + .. _background: compute_intrinsic_background.html + + The right-hand side again: + + .. math:: + \left[ \left( RR_j^{max} \ - \ 1 \right) \ SEV_j^U \ + \ 1 \right] \ + \prod_{i=1}^n \left[ \left( RR_{i}^{M, max} \ - \ 1 \right) \ SEV_i \ + + \ 1\right] + + where :math:`i` stands for the antecedent risk. + + We take the product term as + + .. math:: + \prod_{i=1}^n \left( \ a_i \ k \ + 1 \right) + + and return :math:`a_i` as a dataarray with a risk-dimension. + + Args: + acause (str): acause. + cause_id (int): cause id for acause. + rei (str): the mediator, risk j's formal name. + gbd_round_id (int): gbd round id. + past_or_future (str): "past" or "future". + sev_version (str): version of where SEV is. + rr_version (str): version of etl-ed RR. + draws (int): number of draws kept in process. + draw_start (Optional[int]): starting index of draws selected. + + Returns: + xr.DataArray: The a_i terms of the product term. + """ + # med_da dims are ('acause', 'rei', 'med_', 'draw') + # want to filter to only slices pertaining to the acause-mediator pair + med_da = get_mediation_matrix(gbd_round_id).sel( + acause=acause, med=rei, draw=range(draw_start, draw_start + draws, 1) + ) + + risks = list(med_da["rei"].values) + + # Note that constructing a list of the single-risk arrays is not bad for memory. 
+ # Intuitively: The xr.concat function needs to have all the source arrays in memory at once + # to copy them into the new, concatenated array, so it doesn't help to try to generate one + # at a time. + risk_a_is = [ + _single_risk_a_i( + risk=risk, + med_da=med_da, + acause=acause, + cause_id=cause_id, + rei=rei, + gbd_round_id=gbd_round_id, + past_or_future=past_or_future, + sev_version=sev_version, + rr_version=rr_version, + draws=draws, + age_group_ids=age_group_ids, + draw_start=draw_start, + ) + for risk in risks + ] + a_is = xr.concat([a_i for a_i in risk_a_is if a_i is not None], dim="rei") + + a_is = a_is.fillna(0) + + return a_is + + +def _single_risk_a_i( + risk: str, + med_da: xr.DataArray, + acause: str, + cause_id: int, + rei: str, + gbd_round_id: int, + past_or_future: str, + sev_version: str, + rr_version: str, + draws: int, + age_group_ids: Optional[List[int]], + draw_start: int, +) -> Optional[xr.DataArray]: + mediation_factor = med_da.sel(rei=risk) # now should have only one dim: draw + if mediation_factor.mean() <= 0: + logger.info(f"{acause}-{rei} mediation is 0 for {risk}, skip...") + return None + + logger.info("Computing for {} in product...".format(risk)) + + sev_i = open_xr_scenario( + FHSFileSpec( + version_metadata=VersionMetadata.make( + data_source=gbd_round_id, + epoch=past_or_future, + stage="sev", + version=sev_version, + ), + filename=f"{risk}.nc", + ) + ).sel(draw=range(draw_start, draw_start + draws, 1)) + + if age_group_ids: # if we stipulate age group ids to compute + sev_age_ids = sev_i["age_group_id"].values + + # we don't need age group id 27. Only need the detailed ones. + keep_age_ids = list(set(age_group_ids) & set(sev_age_ids)) + + sev_i = sev_i.sel(age_group_id=keep_age_ids) + + # Because mediation factor (mf) is the fraction of excess risk mediated, + # mf ~ (RR_m - 1) / (RR - 1); hence RR_m ~ (RR - 1) * mf + 1 + rrmax_ij = rrmax_module.read_rrmax( + acause, cause_id, risk, gbd_round_id, rr_version, draws, draw_start + ) + + a_i = _calculate_ai(mediation_factor, sev_i, rrmax_ij) + + a_i["rei"] = risk + return a_i + + +def _calculate_ai( + mediation_factor: xr.DataArray, sev_i: xr.DataArray, rrmax_ij: Union[xr.DataArray, int] +) -> xr.DataArray: + rrmax_ij_med = (rrmax_ij - 1) * mediation_factor + 1 + a_i = (rrmax_ij_med - 1) * sev_i # a_i constant for k-calculation + return a_i + + +def newton_solver( + a_is: xr.DataArray, b_const: xr.DataArray, initial_k: xr.DataArray +) -> xr.DataArray: + """Find the k that satisfies the description under "The k Problem" in the background docs. + + Concretely, find the k such that the sum of log(k * a_is[r] + 1) across risks r is equal to + log(b_const). + + Starts with a "guess" for k, which if too far off, the solving might not converge. + + Args: + a_is (xr.DataArray): product of distal contributions, without k. + b_const (xr.DataArray): total/intrinsic contribution. + k (xr.DataArray): the initial k coefficients. + + Returns: + (xr.DataArray): the final k coefficients. 
+ """ + k_coords = initial_k.coords + + a_is = a_is.fillna(0) + + def fun(k: np.ndarray) -> np.ndarray: # the function to find root of + val = sum(np.log(k * a_is.sel(rei=risk) + 1) for risk in a_is["rei"].values) + val = val - np.log(b_const) + return val + + def dfun(k: np.ndarray) -> np.ndarray: # derivative of fun w/ k + # deriv = (a_is / (a_is * k + 1)).sum(dim="rei") + deriv = sum( + a_is.sel(rei=risk) / (k * a_is.sel(rei=risk) + 1) for risk in a_is["rei"].values + ) + return deriv + + def d2fun(k: xr.DataArray) -> xr.DataArray: + return (-(a_is**2) / ((a_is * k + 1) ** 2)).sum(dim="rei") + + k = newton( + fun, + initial_k, + fprime=dfun, + tol=PastSEVConstants.TOL, + rtol=PastSEVConstants.RTOL, + maxiter=PastSEVConstants.MAXITER, + ) + + return xr.DataArray(k, coords=k_coords) + + +def main( + acause: str, + rei: str, + gbd_round_id: int, + sev_version: str, + rr_version: str, + draws: int, + **kwargs: Any, +) -> None: + """Compute and export intrinsic SEV. + + Args: + acause (str): acause. + rei (str): risk j fomr (1). + gbd_round_id (int): gbd round id. + sev_version (str): version of past SEV. + rr_version (str): version of past RR. + draws (int): number of draws kept in process. + """ + cause_id = query_cause.get_cause_id(acause=acause) + rei_id = query_risk.get_rei_id(rei=rei) + + # will filter to these detailed age group ids + age_group_ids = get_ages(gbd_round_id=gbd_round_id)["age_group_id"].unique().tolist() + + logger.info("computing read in RRmax and SEV...") + + input_sev_version_metadata = VersionMetadata.make( + data_source=gbd_round_id, + epoch="past", + stage="sev", + version=sev_version, + ) + + chunk_size = PastSEVConstants.DRAW_CHUNK_SIZE # compute k in chunks of 100 draws + + for i, draw_start in enumerate(range(0, draws, chunk_size)): + rrmax = rrmax_module.read_rrmax( + acause, + cause_id, + rei, + gbd_round_id, + rr_version, + draws=chunk_size, + draw_start=draw_start, + ) + + sev = open_xr_scenario( + FHSFileSpec(version_metadata=input_sev_version_metadata, filename=f"{rei}.nc") + ).sel(draw=range(draw_start, draw_start + chunk_size, 1)) + + sev_age_ids = sev["age_group_id"].values.tolist() + + # we don't need age group id 27. Only need the detailed ones. + keep_age_ids = list(set(age_group_ids) & set(sev_age_ids)) + + sev = sev.sel(age_group_id=keep_age_ids) + + logger.info(f"Read in RRmax and SEV chunk {i}. " "Now right hand product...") + + a_is = get_product_on_right_hand_side( + acause, + cause_id, + rei, + gbd_round_id, + "past", + sev_version, + rr_version, + draws=chunk_size, + age_group_ids=age_group_ids, + draw_start=draw_start, + ) + + # a_is can be missing age groups if a mediator has few distals + # with different age restrictions than the mediator + # EX. drugs_alcohol starts with age_group_id 6 + # but its mediators start with age_group_id 7 + a_is = expand_dimensions(a_is, age_group_id=keep_age_ids, fill_value=0) + + # We use a generator to produce the multiplicands to save memory. 
+ rei_product_i = reduce( + lambda x, y: x * y, + (a_is.sel(rei=risk) + 1 for risk in a_is["rei"].values.tolist()), + ) + + sev_intrinsic = ((((rrmax - 1) * sev + 1) / rei_product_i) - 1) / (rrmax - 1) + + del rei_product_i + gc.collect() + + # we are setting negative SEV's to 0, to compute k + sev_intrinsic_lifted = sev_intrinsic.where(sev_intrinsic >= 0).fillna(0) + + b_const = ((rrmax - 1) * sev + 1) / ((rrmax - 1) * sev_intrinsic_lifted + 1) + + del sev_intrinsic_lifted + gc.collect() + + a_is = a_is.sel(age_group_id=b_const["age_group_id"].values) + + # we only want to update the parts where sev_intrinsic < 0 so we manipulate the initial + # values to make sure only those are touched. + b_const = b_const.where(sev_intrinsic < 0).fillna(1) # so log(b) = 0 + + # now compute the initial k guess + rei_sum = a_is.sum(dim="rei") + + rei_sum = rei_sum.sel(age_group_id=b_const["age_group_id"].values) + + k_i = np.log(b_const) / rei_sum # the initial k + + del rei_sum + gc.collect() + + k_i = k_i.where(sev_intrinsic < 0).fillna(0) # so log(a * k + 1) = 0 + k_i = k_i.where(k_i <= 1).fillna(0.9999) # adjustment to avoid -inf k + + for dim in k_i.dims: # make sure all indices are matched along axis + b_const = b_const.sel(**{dim: k_i[dim].values.tolist()}) + a_is = a_is.sel(**{dim: k_i[dim].values.tolist()}) + + # align everything to k_i's dim order before newton solver + b_const = b_const.transpose(*k_i.dims) + a_is = a_is.transpose(*[a_is.dims[0]] + list(b_const.dims)) + + k_i = newton_solver(a_is, b_const, k_i) + + k_i = k_i.where(sev_intrinsic < 0).fillna(1) # k=1 for "good" cells + + del sev_intrinsic, b_const # free up some space in memory + gc.collect() + + # need to re-compute sev with k + rei_product_i = reduce( + lambda x, y: x * y, + (k_i * a_is.sel(rei=risk) + 1 for risk in a_is["rei"].values.tolist()), + ) + + if i == 0: + k = k_i + rei_product = rei_product_i + else: + k = xr.concat([k, k_i], dim="draw") + rei_product = xr.concat( + [rei_product, rei_product_i], dim="draw" + ) + + del a_is, k_i, rei_product_i + gc.collect() + + # re-compute sev_intrinsic with k + rrmax = rrmax_module.read_rrmax( + acause, cause_id, rei, gbd_round_id, rr_version, draws=draws + ) + + sev = open_xr_scenario( + FHSFileSpec(version_metadata=input_sev_version_metadata, filename=f"{rei}.nc") + ).sel(age_group_id=keep_age_ids) + sev = resample(sev, draws) + + sev_intrinsic = ((((rrmax - 1) * sev + 1) / rei_product) - 1) / (rrmax - 1) + + del rrmax, sev, rei_product + gc.collect() + + # There will be some sev_intrinsic_final values that are slightly < 0 + # due to numerical precision issues. We set them to 0 here. + # For values less than -tol, we raise a flag. We use tol as a threshold. + if (sev_intrinsic < -PastSEVConstants.TOL).any(): + min_val = float(sev_intrinsic.min()) + logger.warning( + f"There are values in final intrinsic SEV < -{PastSEVConstants.TOL}, " + f"the min being {min_val}." 
+ ) + + sev_intrinsic = sev_intrinsic.where(sev_intrinsic >= 0).fillna(0) + sev_intrinsic = sev_intrinsic.where(sev_intrinsic <= 1).fillna(1) + + # let's remove some redundant point dims + redundants = list(set(sev_intrinsic.coords.keys()) - set(sev_intrinsic.dims)) + redundants.remove("acause") # we're keeping "acause" + + for redundant in redundants: # only keep dims and "acause" + k = k.drop_vars(redundant) + sev_intrinsic = sev_intrinsic.drop_vars(redundant) + + output_dir_spec = FHSDirSpec( + version_metadata=input_sev_version_metadata, + sub_path=(OrchestrationConstants.SUBFOLDER,), + ) + + save_xr_scenario( + xr_obj=k, + file_spec=FHSFileSpec.from_dirspec( + dir=output_dir_spec, filename=f"{acause}_{rei}_k_coeff.nc" + ), + metric="number", + space="identity", + cause_id=cause_id, + rei_id=rei_id, + gbd_round_id=gbd_round_id, + sev_version=sev_version, + rr_version=rr_version, + tol=PastSEVConstants.TOL, + rtol=PastSEVConstants.RTOL, + maxiter=PastSEVConstants.MAXITER, + ) + + save_xr_scenario( + xr_obj=sev_intrinsic, + file_spec=FHSFileSpec.from_dirspec( + dir=output_dir_spec, filename=f"{acause}_{rei}_intrinsic.nc" + ), + metric="rate", + space="identity", + cause_id=cause_id, + rei_id=rei_id, + gbd_round_id=gbd_round_id, + sev_version=sev_version, + rr_version=rr_version, + ) diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/sev/constants.py b/gbd_2021/disease_burden_forecast_code/risk_factors/sev/constants.py new file mode 100644 index 0000000..7ee4941 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/sev/constants.py @@ -0,0 +1,40 @@ +"""FHS Pipeline SEVs Local Constants.""" + +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec, VersionMetadata + + +class SEVConstants: + """Constants used in SEVs forecasting.""" + + FLOOR = 1e-3 + + # This value is not a standard REI name but is used in place of one to label the "rei" of + # an intrinsic SEV value. + INTRINSIC_SPECIAL_REI = "intrinsic" + + INTRINSIC_SEV_FILENAME_SUFFIX = "intrinsic" + + +class PastSEVConstants: + """Constants used in compute Past SEVs.""" + + TOL = 1e-4 + RTOL = 1e-2 + MAXITER = 50 + DRAW_CHUNK_SIZE = 100 # compute k in chunks of 100 draws + + +class FutureSEVConstants: + """Constants used in computing Future SEVs.""" + + FLOOR = 1e-6 + DALY_WEIGHTS_FILE_SPEC = FHSFileSpec( + version_metadata=VersionMetadata.make( + root_dir="data", + data_source="6", + epoch="past", + stage="mediation", + version="20210802_DALY_weights", + ), + filename="DALY_weights.csv", + ) diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/sev/mediation.py b/gbd_2021/disease_burden_forecast_code/risk_factors/sev/mediation.py new file mode 100644 index 0000000..ab70902 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/sev/mediation.py @@ -0,0 +1,89 @@ +"""Functions for understanding the mediation hierarchy.""" + +from typing import Dict, List, Tuple + +from fhs_lib_database_interface.lib.fhs_lru_cache import fhs_lru_cache +from fhs_lib_database_interface.lib.query.risk import get_sev_forecast_reis +from fhs_lib_file_interface.lib.query import mediation + + +def get_mediator_upstreams(gbd_round_id: int) -> Dict: + """Return mediators and their upstreams in a list of tuples. + + Args: + gbd_round_id (int): gbd round id. + + Returns: + (Dict): Dict of (mediator, upstreams) pairs, ascending-ordered by number of + upstreams. 
+ """ + mediation_matrix = mediation.get_mediation_matrix(gbd_round_id).mean("draw") + + # for this task, it's easier to deal with a linear array + stacked = mediation_matrix.stack(dims=mediation_matrix.dims) + nonzeros = stacked[stacked != 0] # where mediation actually exists + + result = [] + for mediator in mediation_matrix["med"].values: + upstreams = list(set(nonzeros.sel(med=mediator)["rei"].values)) + result.append((mediator, upstreams)) + + return dict(sorted(result, key=lambda x: len(x[1]))) + + +@fhs_lru_cache(1) +def get_cause_mediator_pairs(gbd_round_id: int) -> List[tuple]: + """Provide cause-risk pairs where risk is a mediator. + + Args: + gbd_round_id (int): gbd round id. + + Returns: + (List[tuple]): list of (cause, mediator) pairs. + """ + mediation_matrix = mediation.get_mediation_matrix(gbd_round_id) + + mean = mediation_matrix.mean(["rei", "draw"]) + stacked = mean.stack(pairs=["acause", "med"]) # making a 1-D xarray allows us to filter + greater_than_zero = stacked[stacked > 0] + return greater_than_zero.pairs.values.tolist() + + +@fhs_lru_cache(1) +def get_intrinsic_sev_pairs(gbd_round_id: int) -> List[Tuple]: + """Return all intrinsic SEV cause-risk pairs. + + Args: + gbd_round_id (int): gbd round id. + + Returns: + (List[Tuple]): list of (cause, mediator) pairs. + """ + pairs = get_cause_mediator_pairs(gbd_round_id) + + # remove PAF=1 and metab_bmd pairs (RR not available for either) + # pairs = [x for x in pairs + # if not property_values.loc[x[0], x[1]]["paf_equals_one"]] + pairs = [x for x in pairs if x[1] != "metab_bmd"] + + return pairs + + +def get_sev_intrinsic_map(gbd_round_id: int) -> Dict[str, bool]: + """Makes a dictionary mapping all the sevs to whether or not they're intrinsic.""" + # now figure out all the sevs and isevs we need to ensemble for + sevs = get_sev_forecast_reis(gbd_round_id) + + # this part deals with the intrinsic SEVs + intrinsic_pairs = get_intrinsic_sev_pairs(gbd_round_id) + + # because we ensemble mediator isevs, so should remove them from sevs + mediators = list(set(x[1] for x in intrinsic_pairs)) + sevs = list(set(sevs) - set(mediators)) + + # these mediator isevs will be ensembled + isevs = [x[0] + "_" + x[1] for x in intrinsic_pairs] + + intrinsic_map = dict(**{sev: False for sev in sevs}, **{isev: True for isev in isevs}) + + return intrinsic_map diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/sev/rrmax.py b/gbd_2021/disease_burden_forecast_code/risk_factors/sev/rrmax.py new file mode 100644 index 0000000..d2a8d90 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/sev/rrmax.py @@ -0,0 +1,52 @@ +import xarray as xr +from fhs_lib_database_interface.lib import db_session +from fhs_lib_database_interface.lib.constants import CauseRiskPairConstants +from fhs_lib_database_interface.lib.fhs_lru_cache import fhs_lru_cache +from fhs_lib_database_interface.lib.query import risk as query_risk +from fhs_lib_database_interface.lib.strategy_set import strategy +from fhs_lib_file_interface.lib.query import rrmax + +PAF_OF_ONE_RRMAX = 1000 + + +def read_rrmax( + acause: str, + cause_id: int, + rei: str, + gbd_round_id: int, + version: str, + draws: int, + draw_start: int = 0, +) -> xr.DataArray: + """A wrapper around the central read_rrmax, that handles the PAFs of 1 case. + + PAFs of 1 don't have rrmax files on disk, and should have their rrmax treated as 1000. 
+ """ + pafs_of_one = _get_pafs_of_one(gbd_round_id=gbd_round_id) + rei_id = query_risk.get_rei_id(rei=rei) + + if (cause_id, rei_id) in pafs_of_one: + return xr.DataArray(PAF_OF_ONE_RRMAX) + + return rrmax.read_rrmax( + acause=acause, + cause_id=cause_id, + rei=rei, + gbd_round_id=gbd_round_id, + version=version, + draws=draws, + draw_start=draw_start, + ) + + +@fhs_lru_cache(1) +def _get_pafs_of_one(gbd_round_id: int) -> set[tuple[int, int]]: + """Get the set of cause-risk pairs with PAFs of 1.""" + with db_session.create_db_session() as session: + pafs_of_one = strategy.get_cause_risk_pair_set( + session=session, + gbd_round_id=gbd_round_id, + strategy_id=CauseRiskPairConstants.CAUSE_RISK_PAF_EQUALS_ONE_SET_ID, + ) + + return set(zip(pafs_of_one["cause_id"], pafs_of_one["rei_id"])) diff --git a/gbd_2021/disease_burden_forecast_code/risk_factors/sev/run_workflow.py b/gbd_2021/disease_burden_forecast_code/risk_factors/sev/run_workflow.py new file mode 100644 index 0000000..35c2677 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/risk_factors/sev/run_workflow.py @@ -0,0 +1,488 @@ +from typing import Any, Dict, List, Optional, Tuple + +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_file_interface.lib.symlink_file import symlink_file_to_directory +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_genem.lib.model_restrictions import ModelRestrictions +from fhs_lib_genem.run.create_stage import create_genem_tasks +from fhs_lib_orchestration_interface.lib import cluster_tools +from fhs_lib_year_range_manager.lib.year_range import YearRange +from jobmon.client.api import Tool +from jobmon.client.task import Task +from jobmon.client.workflow import Workflow +from tiny_structured_logger.lib import fhs_logging +from typeguard import typechecked + +from fhs_pipeline_sevs.lib.mediation import ( + get_intrinsic_sev_pairs, + get_mediator_upstreams, + get_sev_intrinsic_map, +) +from fhs_pipeline_sevs.run.task import ( + get_combine_future_mediator_task, + get_compute_future_mediator_task, + get_compute_past_intrinsic_sev_task, +) + +MRBRT_COV_STAGE1 = "sdi" +MRBRT_COV_STAGE2 = "sdi" +TRANSFORM = "logit" + +# Certain risks have SEVs forecasted via processes completely separate +# from the main forecasting pipeline; i.e. custom-made by other teams. +# These files reside in the --precalculated-version input directory +# specified in console.py. "smoking_direct" is currently custom-made. +PRECOMPUTED_SEVS = ["smoking_direct"] + +TIMEOUT = 260000 # giving the entire workflow 3 days to run + +logger = fhs_logging.get_logger() + + +def create_workflow( + wf_version: Optional[str], + cluster_project: str, + past_version: str, + rr_version: str, + past_sdi_version: str, + future_sdi_version: str, + out_version: str, + precalculated_version: str, + draws: int, + years: YearRange, + gbd_round_id: int, + log_level: Optional[str], + model_restrictions: ModelRestrictions, + run_stage_1: bool = False, + run_stage_2a: bool = False, + run_stage_2b: bool = False, + run_stage_3: bool = False, + run_stage_4: bool = False, + run_stage_5: bool = False, + cluster_name: str = cluster_tools.identify_cluster(), +) -> Workflow: + """Construct and return a workflow.""" + # if all flags are False, then run all stages. 
Otherwise, only run the stages which flags + # are True + run_all_stages = not any( + [run_stage_1, run_stage_2a, run_stage_2b, run_stage_3, run_stage_4, run_stage_5] + ) + # stages 2-5 are the ensemble model + run_ensemble = any([run_all_stages, run_stage_2a, run_stage_2b, run_stage_3, run_stage_4]) + + tool = Tool(name=TOOL_NAME) + + wf_args = f"{PIPELINE_NAME}_{cluster_tools.get_wf_version(wf_version)}" + + # Initialize new workflow. + workflow = tool.create_workflow( + name=PIPELINE_NAME, + workflow_args=wf_args, + default_cluster_name=cluster_name, + ) + + if run_all_stages or run_stage_1: + # (1) compute past intrinsic SEVs + stage_1_tasks = compute_past_intrinsic_sevs_stage( + tool=tool, + cluster_project=cluster_project, + acause="all", + rei="all", + sev_version=past_version, + rr_version=rr_version, + gbd_round_id=gbd_round_id, + draws=draws, + log_level=log_level, + ) + workflow.add_tasks(stage_1_tasks) + else: + stage_1_tasks = [] + + if run_ensemble: + _symlink_precomputed( + gbd_round_id=gbd_round_id, + precalculated_version=precalculated_version, + out_version=out_version, + ) + + intrinsic_map = get_sev_intrinsic_map(gbd_round_id) + ensemble_tasks = create_genem_tasks( + tool=tool, + cluster_project=cluster_project, + entities=intrinsic_map.keys(), + intrinsic=intrinsic_map, + stage="sev", + versions=Versions( + "FILEPATH", + "FILEPATH", + "FILEPATH", + "FILEPATH", + ), + gbd_round_id=gbd_round_id, + years=years, + draws=draws, + transform=TRANSFORM, + intercept_shift_transform="none", + mrbrt_cov_stage1=MRBRT_COV_STAGE1, + mrbrt_cov_stage2=MRBRT_COV_STAGE2, + national_only=False, + scenario_quantiles=True, + subfolder=None, + uncross_scenarios=True, + age_standardize=True, + remove_zero_slices=True, + rescale_ages=True, + model_restrictions=model_restrictions, + log_level=log_level, + run_pv=run_stage_2a, + run_forecast=run_stage_2b, + run_model_weights=run_stage_3, + run_collect_models=run_stage_4, + ) + else: + ensemble_tasks = [] + + for ensemble_task in ensemble_tasks: + for task_1 in stage_1_tasks: + ensemble_task.add_upstream(task_1) + + workflow.add_tasks(ensemble_tasks) + + if run_all_stages or run_stage_5: + # (5) compute future mediator total sevs + stage_5_tasks = future_mediator_stage( + tool=tool, + cluster_project=cluster_project, + acause="all", + rei="all", + sev_version=out_version, + past_sev_version=past_version, + rr_version=rr_version, + gbd_round_id=gbd_round_id, + years=years, + draws=draws, + log_level=log_level, + ) + + for task_5 in stage_5_tasks: + for ensemble_task in ensemble_tasks: + task_5.add_upstream(ensemble_task) + + workflow.add_tasks(stage_5_tasks) + + workflow.bind() + + return workflow + + +def run_pipeline( + wf_version: Optional[str], + cluster_project: str, + past_version: str, + rr_version: str, + past_sdi_version: str, + future_sdi_version: str, + out_version: str, + precalculated_version: str, + draws: int, + years: YearRange, + gbd_round_id: int, + log_level: Optional[str], + model_restrictions: ModelRestrictions, + run_stage_1: bool = False, + run_stage_2a: bool = False, + run_stage_2b: bool = False, + run_stage_3: bool = False, + run_stage_4: bool = False, + run_stage_5: bool = False, + cluster_name: str = cluster_tools.identify_cluster(), +) -> Tuple[str, Workflow]: + """Construct and run a workflow.""" + workflow = create_workflow( + wf_version=wf_version, + cluster_project=cluster_project, + past_version=past_version, + rr_version=rr_version, + past_sdi_version=past_sdi_version, + future_sdi_version=future_sdi_version, + 
out_version=out_version, + precalculated_version=precalculated_version, + draws=draws, + years=years, + gbd_round_id=gbd_round_id, + log_level=log_level, + model_restrictions=model_restrictions, + run_stage_1=run_stage_1, + run_stage_2a=run_stage_2a, + run_stage_2b=run_stage_2b, + run_stage_3=run_stage_3, + run_stage_4=run_stage_4, + run_stage_5=run_stage_5, + cluster_name=cluster_name, + ) + workflow.run(seconds_until_timeout=TIMEOUT, resume=True) + logger.info( + "Workflow ended", + bindings=dict(status=workflow.workflow_id), + ) + + return workflow + + +@typechecked() +def compute_past_intrinsic_sevs_stage( + tool: Tool, + cluster_project: str, + acause: str, + rei: str, + gbd_round_id: int, + sev_version: str, + rr_version: str, + draws: int, + log_level: Optional[str], +) -> List[Task]: + """Make tasks for computing past intrinsic sevs.""" + # Get all the cause-mediator pairs from mediation matrix + pairs = get_intrinsic_sev_pairs(gbd_round_id) + + # Now filter the pairs based on provided acause/rei inputs + if acause != "all": + pairs = filter(lambda x: x[0] == acause, pairs) + if rei != "all": + pairs = filter(lambda x: x[1] == rei, pairs) + + compute_resources = get_compute_resources( + memory_gb=int(BASE_MEMORY_GB + JOB_MEMORY_GB_PER_DRAW * draws / 1000), + cluster_project=cluster_project, + runtime="16:00:00", + ) + + tasks = [] + for task_acause, task_rei in pairs: + task = get_compute_past_intrinsic_sev_task( + tool=tool, + compute_resources=compute_resources, + acause=task_acause, + rei=task_rei, + gbd_round_id=gbd_round_id, + sev_version=sev_version, + rr_version=rr_version, + draws=draws, + log_level=log_level, + ) + tasks.append(task) + + return tasks + + +@typechecked() +def future_mediator_stage( + tool: Tool, + cluster_project: str, + acause: str, + rei: str, + gbd_round_id: int, + sev_version: str, + past_sev_version: str, + rr_version: str, + years: YearRange, + draws: int, + log_level: Optional[str], +) -> List[Task]: + """Depending on the given acause and rei, computes/exports either. + + 1.) all pairs for a given acause + 2.) all pairs for a given rei + 3.) all cause-risk pairs + + + Args: + tool: Jobmon tool to associate tasks with + cluster_project: cluster project to run tasks under + acause (str): acause. + rei (str): risk j fomr (1). + gbd_round_id (int): gbd round id. + sev_version (str): version of future SEV. + past_sev_version (str): version of past SEV. + rr_version (str): version of past RR. + years (YearRange): past_start:forecast_start:forecast_end. + draws (int): number of draws kept in process. 
+ log_level: log_level to use for tasks + """ + # Get all the cause-mediator pairs from mediation matrix + pairs = get_intrinsic_sev_pairs(gbd_round_id) + + # Now filter the pairs based on provided acause/rei inputs + if acause != "all": + pairs = filter(lambda x: x[0] == acause, pairs) + if rei != "all": + pairs = filter(lambda x: x[1] == rei, pairs) + + # first set up all the compute/combine tasks for the cause-mediator pairs + mediator_upstreams_dict = get_mediator_upstreams(gbd_round_id) + + compute_tasks_by_risk = {} + combine_tasks_by_risk = {} + for mediator in mediator_upstreams_dict.keys(): + acauses = [x[0] for x in pairs if x[1] == mediator] + + if acauses: + compute_tasks, combine_task = _isev_to_sev_batch_tasks( + tool=tool, + cluster_project=cluster_project, + acauses=acauses, + rei=mediator, + gbd_round_id=gbd_round_id, + sev_version=sev_version, + past_sev_version=past_sev_version, + rr_version=rr_version, + years=years, + draws=draws, + log_level=log_level, + ) + compute_tasks_by_risk[mediator] = compute_tasks + combine_tasks_by_risk[mediator] = combine_task + + mediators = compute_tasks_by_risk.keys() + + for med in mediators: # these are mediators + for med_2 in mediators: # another mediator + if med_2 != med and med_2 in mediator_upstreams_dict[med]: + for task in compute_tasks_by_risk[med]: + task.add_upstream(combine_tasks_by_risk[med_2]) + + tasks = sum(list(compute_tasks_by_risk.values()), []) + list( + combine_tasks_by_risk.values() + ) + + return tasks + + +@typechecked() +def _isev_to_sev_batch_tasks( + tool: Tool, + cluster_project: str, + acauses: List[str], + rei: str, + gbd_round_id: int, + sev_version: str, + past_sev_version: str, + rr_version: str, + years: YearRange, + draws: int, + log_level: Optional[str], +) -> Tuple[List[Task], Task]: + """Compute SEVs for a given set of cause-risk iSEVs. + + Given cause-risk iSEV pairs, perform the following: + 1.) make a task to compute the cause-risk SEVs + 2.) make a task to combine the cause-risk SEVs into their risk SEVs, held on (1) + + Args: + tool: Jobmon tool to associate tasks with + cluster_project: cluster project to run tasks under + acauses (List[str]): List of acauses to make cause-risk-specific + compute_future_mediator tasks for + rei (str): REI to create combine_future_mediator task for + gbd_round_id (int): gbd round id. + sev_version (str): version of future SEV. + past_sev_version (str): version of past SEV. + rr_version (str): version of past RR. + years (YearRange): past_start:forecast_start:forecast_end. + draws (int): number of draws kept in process. 
+ log_level: log_level to use for tasks + + Returns: + Tuple[List[Task], Task]: the cause-risk-specific compute_future_mediator tasks, and + the risk-specific combine_future_mediator task -- in that order + """ + compute_resources = get_compute_resources( + memory_gb=int(BASE_MEMORY_GB + JOB_MEMORY_GB_PER_DRAW * draws / 1000), + runtime="16:00:00", + cluster_project=cluster_project, + ) + + compute_tasks = [] + for acause in acauses: + compute_task = get_compute_future_mediator_task( + tool=tool, + compute_resources=compute_resources, + acause=acause, + rei=rei, + gbd_round_id=gbd_round_id, + sev_version=sev_version, + past_sev_version=past_sev_version, + rr_version=rr_version, + years=years, + draws=draws, + log_level=log_level, + ) + + compute_tasks.append(compute_task) + + combine_task = get_combine_future_mediator_task( + tool=tool, + compute_resources=compute_resources, + rei=rei, + gbd_round_id=gbd_round_id, + past_sev_version=past_sev_version, + sev_version=sev_version, + rr_version=rr_version, + years=years, + log_level=log_level, + ) + + for compute_task in compute_tasks: + combine_task.add_upstream(compute_task) + + return compute_tasks, combine_task + + +def _symlink_precomputed( + gbd_round_id: int, precalculated_version: str, out_version: str +) -> None: + """Symlink all the SEVs that have been computed via other processes. + + Args: + gbd_round_id (int): gbd round id. + precalculated_version (str): version where pre-computed SEVs are. + out_version (str): the final output version for all the SEVs. + """ + for entity in PRECOMPUTED_SEVS: + try: + precomputed_dir = FBDPath( + gbd_round_id=gbd_round_id, + past_or_future="future", + stage="sev", + version=precalculated_version, + ) + + out_dir = precomputed_dir.set_version(out_version) + + symlink_file_to_directory(precomputed_dir / (entity + ".nc"), out_dir) + + except FileExistsError: + logger.info(f"{entity}.nc already exists.") + + except FileNotFoundError: + raise FileNotFoundError(f"{entity}.nc not found.") + + +def get_compute_resources( + memory_gb: int, + cluster_project: str, + runtime: str, + cores: int = DEFAULT_CORES, +) -> Dict[str, Any]: + """Return a dictionary containing keys & values required for Jobmon Task creation.""" + error_logs_dir, output_logs_dir = cluster_tools.get_logs_dirs() + + return dict( + memory=f"{memory_gb}G", + cores=cores, + runtime=runtime, + project=cluster_project, + queue=DEFAULT_QUEUE, + stderr=error_logs_dir, + stdout=output_logs_dir, + ) diff --git a/gbd_2021/disease_burden_forecast_code/vaccine/aggregate_rake.py b/gbd_2021/disease_burden_forecast_code/vaccine/aggregate_rake.py new file mode 100644 index 0000000..556f08a --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/vaccine/aggregate_rake.py @@ -0,0 +1,213 @@ +"""Aggregate or rake subnational estimates for one vaccine. + +For all vaccines in India, Brazil, USA and for MCV2 in China, we aggregate +subnational estimates to produce national estimates. For all other vaccines in +China only, we rake the subnational estimates from the national to ensure +location aggregation consistency. This is related to data availability around +national vs subnational vaccine campaign roll-outs from the GBD VPDs team. It +may change over time. 
+ + +example call: +python aggregate_rake.py \ + --versions FILEPATH \ + -v FILEPATH \ + --gbd-round-id 6 \ + --vaccine mcv2 \ + agg-rake +""" + +from typing import List + +import click +import xarray as xr +from fhs_lib_data_aggregation.lib.aggregator import Aggregator +from fhs_lib_database_interface.lib.constants import AgeConstants, SexConstants +from fhs_lib_database_interface.lib.query import location +from fhs_lib_file_interface.lib import xarray_wrapper +from fhs_lib_file_interface.lib.check_input import check_versions +from fhs_lib_file_interface.lib.file_system_manager import FileSystemManager +from fhs_lib_file_interface.lib.os_file_system import OSFileSystem +from fhs_lib_file_interface.lib.versioning import Versions + +# Agglocs: USA, India, Brazil +AGGLOCS = [102, 135, 163] +# Rakelocs: China +RAKELOCS = 6 + + +def aggregate_subnationals( + gbd_round_id: int, + vaccine: str, + pop_da: xr.DataArray, + vaccine_da: xr.DataArray, +) -> xr.DataArray: + """Aggregate locations for one vaccine. + + Args: + gbd_round_id (int): GBD round ID. + vaccine (str): Vaccine name e.g. mcv2 + pop_da (xr.DataArray): Population data + vaccine_da (xr.DataArray): Vaccine data + Returns: + xr.DataArray: DataArray with aggregated locations + """ + for loc in AGGLOCS: + # drop nationals for all locations except China: + vaccine_da = vaccine_da.where(vaccine_da.location_id != loc, drop=True) + if vaccine == "mcv2": # only aggregate China subnationals for MCV2 + vaccine_da = vaccine_da.where(vaccine_da.location_id != RAKELOCS, drop=True) + + # aggregate all locations up the location hierarchy + locs = location.get_location_set(gbd_round_id=gbd_round_id, include_aggregates=True) + hierarchy = locs.set_index("location_id")["parent_id"].to_xarray() + correction_factor = ( + location.get_regional_population_scalars(gbd_round_id) + .set_index(["location_id"]) + .to_xarray()["mean"] + ) + correction_factor.name = "pop_scalar" + aggregator = Aggregator(pop_da) + aggregated_da = aggregator.aggregate_locations( + data=vaccine_da, loc_hierarchy=hierarchy, correction_factor=correction_factor + ).rate + + return aggregated_da + + +def rake_china_subnationals( + gbd_round_id: int, + vaccine_da: xr.DataArray, + pop_da: xr.DataArray, +) -> xr.DataArray: + """Rake China subnational estimates from the national for all vaccines except MCV2. 
+ + Args: + gbd_round_id (int): GBD round ID + vaccine_da (xr.DataArray): Vaccine data with aggregated location estimates + pop_da (xr.DataArray): Population data + + Returns: + xr.DataArray: DataArray with raked subnational estimates for China + """ + # construct raking hierarchy with location ID 6 and subnational location IDs + loc_table = location.get_location_set(gbd_round_id=gbd_round_id) + rake_hierarchy = loc_table[["location_id", "parent_id", "level"]] + china_subnats = loc_table.query(f"parent_id == {RAKELOCS}").location_id.tolist() + china_subnats.append(RAKELOCS) + not_china = list( + set(loc_table.query("level in [3,4]").location_id.tolist()) - set(china_subnats) + ) + rake_hierarchy = rake_hierarchy.query("location_id == @china_subnats") + + # rake subnational estimates from national + aggregator = Aggregator(pop_da) + china_da = vaccine_da.sel(location_id=china_subnats) + china_da_raked = aggregator.rake_locations( + data=china_da, location_hierarchy=rake_hierarchy + ) + + raked_da = xr.concat( + [china_da_raked, vaccine_da.sel(location_id=not_china)], dim="location_id" + ) + + return raked_da + + +def agg_rake_main(versions: Versions, gbd_round_id: int, vaccine: str) -> xr.DataArray: + """Main aggregation function. + + Args: + versions (Versions): Versions object with list of versions + gbd_round_id (int): Current gbd_round_id + vaccine (str): Vaccine name e.g. mcv2 + """ + pop_da = xarray_wrapper.open_xr( + versions.data_dir(gbd_round_id, "future", "population") / "population_agg.nc" + ) + vaccine_da = xarray_wrapper.open_xr( + versions.data_dir(gbd_round_id, "future", "vaccine") / f"vacc_{vaccine}.nc" + ) + + # vaccines data has no detailed age/sex IDs + pop_da = pop_da.sel( + age_group_id=AgeConstants.VACCINE_AGE_ID, + sex_id=SexConstants.BOTH_SEX_ID, + ) + + if vaccine != "mcv2": + # rake subnational estimates for all other vaccines in China only + vaccine_da = rake_china_subnationals(gbd_round_id, vaccine_da, pop_da) + aggregated_da = aggregate_subnationals(gbd_round_id, vaccine, pop_da, vaccine_da) + + fname = ( + versions.data_dir(gbd_round_id, "future", "vaccine") / f"_agg_rake/vacc_{vaccine}.nc" + ) + xarray_wrapper.save_xr(aggregated_da, fname, metric="rate", space="identity") + + +@click.group() +@click.option( + "--versions", + "-v", + type=str, + required=True, + multiple=True, + help=("Vaccine and Population versions"), +) +@click.option( + "--vaccine", + type=str, + required=True, + help=("Vaccine name e.g. dtp3"), +) +@click.option( + "--gbd-round-id", + required=True, + type=int, + help="GBD round ID", +) +@click.pass_context +def cli( + ctx: click.Context, + versions: List[str], + vaccine: str, + gbd_round_id: int, +) -> None: + """Main cli function to parse args and pass them to the subcommands. + + Args: + ctx (click.Context): ctx object. + versions (List[str]): Population and vaccine versions + vaccine (str): Vaccine name + gbd_round_id (int): Current gbd round id + """ + versions = Versions(*versions) + check_versions(versions, "future", ["population", "vaccine"]) + ctx.obj = { + "versions": versions, + "vaccine": vaccine, + "gbd_round_id": gbd_round_id, + } + + +@cli.command() +@click.pass_context +def agg_rake(ctx: click.Context) -> None: + """Call to main function. + + Args: + ctx (click.Context): context object containing relevant params parsed + from command line args. 
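# ---------------------------------------------------------------------------
# Illustrative sketch (hedged addition, not taken from the patch itself).
# Aggregator.rake_locations above does the real hierarchy-aware work; this
# standalone helper only shows the raking factor the step is based on:
# subnational counts are scaled by a common factor so that they sum to the
# national envelope. Column names are hypothetical.
# ---------------------------------------------------------------------------
import pandas as pd


def rake_to_national(sub: pd.DataFrame, national_count: float) -> pd.DataFrame:
    """Rake subnational rates (one parent, one year) to a national count."""
    sub = sub.assign(count=sub["rate"] * sub["population"])
    factor = national_count / sub["count"].sum()
    return sub.assign(raked_rate=sub["rate"] * factor)[["location_id", "raked_rate"]]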
+ """ + FileSystemManager.set_file_system(OSFileSystem()) + + agg_rake_main( + versions=ctx.obj["versions"], + vaccine=ctx.obj["vaccine"], + gbd_round_id=ctx.obj["gbd_round_id"], + ) + + +if __name__ == "__main__": + cli() \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/vaccine/constants.py b/gbd_2021/disease_burden_forecast_code/vaccine/constants.py new file mode 100644 index 0000000..57a5375 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/vaccine/constants.py @@ -0,0 +1,7 @@ +"""Vaccine Pipeline Local Constants.""" + + +class ModelConstants: + """Constants related to modeling or model specification.""" + + DEFAULT_OFFSET = 1e-8 \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/vaccine/model_strategy.py b/gbd_2021/disease_burden_forecast_code/vaccine/model_strategy.py new file mode 100644 index 0000000..199e03f --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/vaccine/model_strategy.py @@ -0,0 +1,100 @@ +from collections import namedtuple + +from fhs_lib_data_transformation.lib import processing +from fhs_lib_database_interface.lib.query.model_strategy import ModelStrategyNames +from fhs_lib_model.lib.arc_method.arc_method import ArcMethod +from fhs_lib_model.lib.limetr import LimeTr + +from fhs_pipeline_vaccine.lib.constants import ModelConstants + +ModelParameters = namedtuple( + "ModelParameters", + ( + "Model, " + "processor, " + "covariates, " + "fixed_effects, " + "fixed_intercept, " + "random_effects, " + "indicators, " + "spline, " + "predict_past_only, " + ), +) + +MODEL_PARAMETERS = { + "mcv1": { + ModelStrategyNames.ARC.value: ModelParameters( + Model=ArcMethod, + processor=processing.LogProcessor( + years=None, + offset=ModelConstants.DEFAULT_OFFSET, + no_mean=True, + intercept_shift="unordered_draw", + gbd_round_id=6, + ), + covariates=None, + fixed_effects=None, + fixed_intercept=None, + random_effects=None, + indicators=None, + spline=None, + predict_past_only=False, + ), + ModelStrategyNames.LIMETREE.value: ModelParameters( + Model=LimeTr, + processor=processing.LogitProcessor( + years=None, + offset=ModelConstants.DEFAULT_OFFSET, + remove_zero_slices=True, + intercept_shift="unordered_draw", + gbd_round_id=6, + ), + covariates={"sdi": processing.NoTransformProcessor(gbd_round_id=None, years=None)}, + fixed_effects={"sdi": [-float("inf"), float("inf")]}, + fixed_intercept="unrestricted", + random_effects=None, + indicators=None, + spline=None, + predict_past_only=False, + ), + ModelStrategyNames.NONE.value: None, + }, + "dtp3": { + ModelStrategyNames.ARC.value: ModelParameters( + Model=ArcMethod, + processor=processing.LogProcessor( + years=None, + offset=ModelConstants.DEFAULT_OFFSET, + no_mean=True, + intercept_shift="unordered_draw", + gbd_round_id=6, + ), + covariates=None, + fixed_effects=None, + fixed_intercept=None, + random_effects=None, + indicators=None, + spline=None, + predict_past_only=False, + ), + ModelStrategyNames.LIMETREE.value: ModelParameters( + Model=LimeTr, + processor=processing.LogitProcessor( + years=None, + offset=ModelConstants.DEFAULT_OFFSET, + remove_zero_slices=True, + intercept_shift="unordered_draw", + gbd_round_id=6, + ), + covariates={"sdi": processing.NoTransformProcessor(gbd_round_id=None, years=None)}, + fixed_effects={"sdi": [-float("inf"), float("inf")]}, + fixed_intercept="unrestricted", + random_effects=None, + indicators=None, + spline=None, + predict_past_only=False, + ), + ModelStrategyNames.NONE.value: None, + }, +} \ No newline at end of file diff 
--git a/gbd_2021/disease_burden_forecast_code/vaccine/model_strategy_queries.py b/gbd_2021/disease_burden_forecast_code/vaccine/model_strategy_queries.py new file mode 100644 index 0000000..61eb541 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/vaccine/model_strategy_queries.py @@ -0,0 +1,77 @@ +from fhs_lib_year_range_manager.lib.year_range import YearRange + +from fhs_pipeline_vaccine.lib import model_strategy + + +def get_vaccine_model( + vaccine: str, model_type: str, years: YearRange, gbd_round_id: int +) -> model_strategy.ModelParameters: + """Pulling the necessary processor model to modify. + + Args: + vaccine (str): which vaccine model + model_type (str): what type of model to modify + years (YearRange): years to enter + gbd_round_id (int): which gbd round + + Returns: + model_strategy.ModelParameters + + Return the default model parameters for the given vaccine model, + with the "years" and "gbd_round_id" set to the given ones everywhere in + the parameters. + """ + model_parameters = model_strategy.MODEL_PARAMETERS[vaccine][model_type] + model_parameters = _update_processor_years(model_parameters, years) + model_parameters = _update_processor_gbd_round_id(model_parameters, gbd_round_id) + return model_parameters + + +def _update_processor_years( + model_parameters: model_strategy.ModelParameters, years: YearRange +) -> model_strategy.ModelParameters: + """Updating the years for the processor. + + Args: + model_parameters (model_strategy.ModelParameters): model_parameters + years (YearRange): years to enter + + Returns: + model_strategy.ModelParameters + + Return the default model parameters for the given vaccine model, + with the "years" and "gbd_round_id" set to the given ones everywhere in + the parameters. + """ + model_parameters.processor.years = years + + if model_parameters.covariates: + for cov_name in model_parameters.covariates.keys(): + model_parameters.covariates[cov_name].years = years + + return model_parameters + + +def _update_processor_gbd_round_id( + model_parameters: model_strategy.ModelParameters, gbd_round_id: int +) -> model_strategy.ModelParameters: + """Updating the gbd round id for the processor. + + Args: + model_parameters (model_strategy.ModelParameters): model parameters + to update + gbd_round_id (int): gbd round id to insert + + Returns: + model_strategy.ModelParameters + + gbd_round_id is entered as ``None`` in the processor for the dependent + variable and covariates so it needs to be updated here + """ + model_parameters.processor.gbd_round_id = gbd_round_id + + if model_parameters.covariates: + for cov_name in model_parameters.covariates.keys(): + model_parameters.covariates[cov_name].gbd_round_id = gbd_round_id + + return model_parameters \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/vaccine/run_ratio_vaccines.py b/gbd_2021/disease_burden_forecast_code/vaccine/run_ratio_vaccines.py new file mode 100644 index 0000000..3a31304 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/vaccine/run_ratio_vaccines.py @@ -0,0 +1,1194 @@ +"""Run Ratio Vaccines. + +Simple vaccines have been introduced in every GBD country. +Ratio vaccines were first introduced more recently and have not yet +been added to the routine schedule in all countries. These newer +generation vaccines therefore require the additional step of forecasting +introduction dates. 
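# ---------------------------------------------------------------------------
# Illustrative sketch (hedged addition, not taken from the patch itself).
# The update helpers above rely on the fact that ModelParameters, a namedtuple,
# is itself immutable while the processor objects it holds are ordinary mutable
# objects, so the placeholder years/gbd_round_id values can be filled in place
# at run time. The toy classes below are hypothetical and only demonstrate
# that rule.
# ---------------------------------------------------------------------------
from collections import namedtuple


class ToyProcessor:
    def __init__(self) -> None:
        self.years = None  # placeholder, filled in later
        self.gbd_round_id = None


ToyParams = namedtuple("ToyParams", ["Model", "processor"])
params = ToyParams(Model=None, processor=ToyProcessor())

# params.processor = ToyProcessor()  # AttributeError: namedtuple fields are read-only
params.processor.years = "1990:2020:2050"  # mutating the held object works
params.processor.gbd_round_id = 6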
Due to the typical scheduling of rotavirus, pcv and hib +vaccine administration programs, coverage for these vaccines was assumed to +converge to dtp3 coverage over time (and thus cannot exceed dtp2 coverage). +(Foreman et al) + +The ratio vaccines, and their corresponding simple vaccines, are: +- rotavirus (dtp3) +- pcv (dtp3) +- hib (dtp3) +- mcv2 (mcv1) + +For countries with known introduction dates, the coverage forecast is +the ratio * simple vaccine coverage. + +For countries without any observed or set introduction dates, use survival +analysis to simulate introduction dates. For each theoretically possible +introduction year (every year in the forecasts), generate a theoretical +scale-up curve (a forecast for what the coverage to simple vaccine ratio would +be, if the vaccine was introduced in that country, in that year, using a simple +linear mixed effects model). +Finally, match the theoretical scale-up curve to the simulated years of +introduction to get ratios for the locations without known vaccine rollout +years. This mirrors what is done by the GBD where DTP3 and measles coverage are +modeled directly but rota, pcv, and hib are modeled as ratios to DTP3 +coverage. + +Example Call +python run_ratio_vaccines.py \ +--vaccine mcv2 \ +--gbd-round-id 6 \ +--years 1980:2020:2050 \ +--draws 1000 \ +--past-ratio-version 20210606 \ +--future-ratio-version 20210606 \ +--version FILEPATH \ +-v FILEPATH \ +--intro-version 20220615_vaccine_intro \ +--gavi-version GAVI_eligible_countries_2018 \ +main +""" + +from typing import List, Optional, Tuple + +import click +import numpy as np +import pandas as pd +import statsmodels.api as sm +import xarray as xr +from fhs_lib_data_transformation.lib.processing import ( + LogitProcessor, + get_dataarray_from_dataset, + logit_with_offset, +) +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_database_interface.lib.constants import ScenarioConstants +from fhs_lib_database_interface.lib.query.location import get_location_set +from fhs_lib_file_interface.lib.check_input import check_versions +from fhs_lib_file_interface.lib.file_system_manager import FileSystemManager +from fhs_lib_file_interface.lib.os_file_system import OSFileSystem +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr, save_xr +from fhs_lib_year_range_manager.lib.year_range import YearRange +from lifelines import WeibullAFTFitter, WeibullFitter +from scipy.special import expit, logit +from scipy.stats import weibull_min +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_pipeline_vaccine.lib.constants import ModelConstants +from fhs_pipeline_vaccine.lib.run_simple_vaccines import load_past_vaccine_data + +logger = get_logger() + +RATIO_COLUMN_MAP = dict( + rotac="ratio_rotac_to_dtp3", + pcv3="ratio_pcv3_to_dtp3", + hib3="ratio_hib3_to_dtp3", + mcv2="ratio_mcv2_to_mcv1", +) +RELEVANT_MEs = { + "rotac": "vacc_rotac_dpt3_ratio", + "pcv3": "vacc_pcv3_dpt3_ratio", + "hib3": "vacc_hib3_dpt3_ratio", + "mcv2": "vacc_mcv2_mcv1_ratio", +} +# MCV2 actually starts in 1963 but our data only goes back to 1980 +START_YEARS = {"mcv2": 1980, "hib3": 1985, "pcv3": 1999, "rotac": 2005} +SCALE_DICT = {"rotac": 9.5, "pcv3": 14.5, "hib3": None, "mcv2": 35} +SMALL_VALS_THRESHOLD = 1e-4 +DEMOG_COLS = ["location_id", "year_id"] + + +def generate_intro_times( + shape: float, + scale: float, + vacc: str, + scenario: int, + num_draws: int, + start_year: int, + years: YearRange, +) -> 
List[float]: + """Generate draws within the acceptable range for ratio vaccines. + + Args: + shape (float): parameter for Weibull distribution + scale (float): parameter for Weibull distribution + vacc (str): name of the vaccine to run + scenario (int): scenario ID to use + num_draws (int): number of draws + start_year (int): year vaccine was first introduced + years (YearRange): past_start:forecast_start:forecast_end + + Returns: + List[float] + + Raises: + ValueError: + If Weibull parameters shape and scale are invalid + If the vaccine name inputted is invalid + If the scenario ID inputted is not -1, 0, or 1 + """ + # Check that valid Weibull parameters were given + + if not (shape and scale): + raise ValueError("Invalid Weibull parameters") + + # Check that vaccine is valid + if vacc not in RATIO_COLUMN_MAP.keys(): + raise ValueError( + f"Passed invalid vaccine shorthand ({vacc}), " + f"must be one of {RATIO_COLUMN_MAP.keys()}" + ) + + # Check that the scenario ID is valid + if scenario not in ScenarioConstants.SCENARIOS: + raise ValueError( + f"Passed invalid scenario argument, {scenario} - " + f"must be one of {ScenarioConstants.SCENARIOS}" + ) + + # Calculate how long ago the vaccine was first released + min_time = years.past_end - start_year + + # The scale must be above a set value or we don't get valid years + # This does not occur with hib3 + if vacc != "hib3": + min_scale = SCALE_DICT[vacc] + if scale < min_scale: + scale = min_scale + + # Use Weibull distribution to estimate the time until introduction + times = weibull_min.rvs(shape, loc=0, scale=scale, size=num_draws) + + # Make sure that the resulting time is in the future + ok_times = times[times >= min_time] + + # Repeat process until we have all of the draws needed + while len(ok_times) < num_draws: + new_draws = num_draws - len(ok_times) + new_times = weibull_min.rvs(shape, loc=0, scale=scale, size=new_draws) + ok_times = np.array(list(ok_times) + list(new_times)) + ok_times = ok_times[ok_times >= min_time] + + # Based on the scenario, select the appropriate percentile of the intro times + scenario_pctiles = {0: 50, -1: 85, 1: 15} + scenario_point_est = np.percentile(ok_times, scenario_pctiles[scenario]) + new_times = [scenario_point_est] * num_draws + + # Check again that we're not predicting dates in the past + + if np.min(new_times) < min_time: + raise ValueError("Predicting introduction years in the past.") + + return new_times + + +def generate_theoretical_scaleups( + past_ratios: pd.DataFrame, + simulated_intro_locations: List[int], + vaccine: str, + years: YearRange, + draws: int, + year_intro_col: Optional[str] = None, +) -> pd.DataFrame: + """Estimate the conditional coverage scaleup for all forecasted years. 
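# ---------------------------------------------------------------------------
# Illustrative sketch (hedged addition, not taken from the patch itself).
# The while-loop above keeps redrawing until every Weibull time lands at least
# `min_time` years after the vaccine's global start year. An equivalent
# one-shot alternative, shown only as a sketch, is inverse-CDF sampling from
# the Weibull distribution truncated below at min_time.
# ---------------------------------------------------------------------------
import numpy as np
from scipy.stats import weibull_min


def truncated_weibull_times(
    shape: float, scale: float, min_time: float, size: int, rng: np.random.Generator
) -> np.ndarray:
    """Draw Weibull(shape, scale) times conditioned on time >= min_time."""
    lower_cdf = weibull_min.cdf(min_time, shape, scale=scale)
    u = rng.uniform(lower_cdf, 1.0, size=size)  # uniform mass over the surviving tail
    return weibull_min.ppf(u, shape, scale=scale)


# Example: times = truncated_weibull_times(1.2, 14.5, 15, 1000, np.random.default_rng(0))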
+ + Args: + past_ratios (pd.DataFrame): contains information with location and year specific ratios + simulated_intro_locations (List[int]): a list of location_ids with + locations to simulate introduction dates for + vaccine (str): name of the vaccine to run + years (YearRange): past_start:forecast_start:forecast_end + draws (int): number of draws + year_intro_col (str): name of the column in past_ratios with the introduction year + + Returns: + pd.DataFrame + """ + # Select for locations that are simulated and have past ratios + sim_ratios = past_ratios.location_id.isin(simulated_intro_locations) + + # Generate theoretical curves for each year + scale_ups = [] + + # Estimate last known past year to intercept shift + for year_id in range(years.past_end, years.forecast_end + 1): + sample_ratios = past_ratios.copy() + + # If the location is simulated, use year_id as the theoretical + # introduction date + sample_ratios.loc[sim_ratios, year_intro_col] = year_id + + # Calculate the number of years since introduction using the + # hypothetical introductory year + + sample_ratios["year_id"] = sample_ratios["year_id"] - sample_ratios[year_intro_col] + + # Redefine 'year_id' so that it's the number of years since introduction + # This makes it easier to model later on + + # We only want post-introduction data + sample_ratios = sample_ratios[sample_ratios["year_id"] >= 0] + + # Clip 0s and 1s + sample_ratios[RATIO_COLUMN_MAP[vaccine]] = sample_ratios[ + RATIO_COLUMN_MAP[vaccine] + ].clip(lower=SMALL_VALS_THRESHOLD, upper=1 - SMALL_VALS_THRESHOLD) + + # Prepare variables needed for regression + sample_ratios["logit_ratio"] = logit(sample_ratios[RATIO_COLUMN_MAP[vaccine]]) + sample_ratios["log_income"] = np.log(sample_ratios["ldi"]) + not_nan_values = sample_ratios.dropna() + + # The formula is the same for most vaccines + fixed_formula = "logit_ratio ~ simple_vacc_cov + log_income + year_id" + re_formula = "~ education + 1" + + # Rotavirus equation is a bit different + if vaccine == "rotac": + fixed_formula = f"{fixed_formula} + education" + re_formula = None + elif sample_ratios.admin.max() != 0: + # If admin is 0 for every row, don't use it as a covariate + fixed_formula = f"{fixed_formula} + admin" + + # Fit model + lme_model = sm.MixedLM.from_formula( + fixed_formula, + data=not_nan_values, + groups=not_nan_values["location_id"], + re_formula=re_formula, + ) + + # Make ratio predictions + output = lme_model.fit(data=sample_ratios) + sample_ratios["predictions"] = output.predict(exog=sample_ratios) + sample_ratios["ratio"] = expit(sample_ratios["predictions"]) + + intro_year_df = sample_ratios[["location_id", "year_id", "ratio", year_intro_col]] + intro_year_df["year_id"] += intro_year_df[year_intro_col] + intro_year_df = intro_year_df[ + intro_year_df.location_id.isin(simulated_intro_locations) + ] + intro_year_df["intro_year"] = year_id + + scale_ups.append(intro_year_df) + + all_curves = pd.concat(scale_ups) + + return all_curves + + +def parametrize_weibull( + sdi_df: pd.DataFrame, + gavi_version: str, + simple_vacc_cov_scenarios: pd.DataFrame, + vaccine: str, + hib3_intro: pd.DataFrame, + simulated_intro_locations: List[int], + years: YearRange, +) -> Tuple[pd.DataFrame, int, dict]: + """Calculate Weibull parameters for introduction year selection. 
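# ---------------------------------------------------------------------------
# Illustrative sketch (hedged addition, not taken from the patch itself).
# The regression above is a linear mixed model on the logit of the ratio with
# location as the grouping factor. Note that MixedLMResults.predict(exog=...)
# returns the fixed-effects prediction only, so the random effects act as
# shrinkage during fitting rather than entering the forecast directly. The
# data frame below is hypothetical; column names mirror the ones built above.
# ---------------------------------------------------------------------------
import pandas as pd
import statsmodels.api as sm
from scipy.special import expit


def predict_scaleup_ratio(df: pd.DataFrame) -> pd.Series:
    """Fit the scale-up model and return predicted ratios in [0, 1].

    Expects columns: logit_ratio, simple_vacc_cov, log_income, year_id
    (years since introduction), education, location_id.
    """
    model = sm.MixedLM.from_formula(
        "logit_ratio ~ simple_vacc_cov + log_income + year_id",
        data=df,
        groups=df["location_id"],
        re_formula="~ education + 1",
    )
    result = model.fit()
    return expit(result.predict(exog=df))  # back-transform to ratio space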
+ + Args: + sdi_df (pd.DataFrame): dataframe with past and future SDI values + gavi_version (str): name of the file with GAVI eligibility information + simple_vacc_cov_scenarios (pd.DataFrame): + contains simple vaccine coverage estimates with scenarios + vaccine (str): name of the vaccine to run + hib3_intro (pd.DataFrame): contains location specific introduction year data about hib3 + simulated_intro_locations (list): + a list of location_ids with locations to simulate introduction dates for + years (YearRange): past_start:forecast_start:forecast_end + + Returns: + Tuple[pd.DataFrame, int, dict] + """ + weibull_inputs = sdi_df.merge( + simple_vacc_cov_scenarios, on=["location_id", "year_id", "scenario"] + ) + + # Make an indicator variable for GAVI eligibility + weibull_inputs["gavi_eligible"] = 0 + gavi_df = pd.read_csv( + "/FILEPATH/" f"{gavi_version}.csv", encoding="ISO-8859-1" + ) + + # Some files may include all locations, not just the GAVI eligible ones + if "gavi_eligible" in gavi_df.columns: + gavi_df = gavi_df[gavi_df.gavi_eligible == 1] + + gavi_locs = list(gavi_df.location_id.values) + is_gavi = weibull_inputs.location_id.isin(gavi_locs) + weibull_inputs.loc[is_gavi, "gavi_eligible"] = 1 + + weibull_preds = weibull_inputs[ + ["location_id", "year_id", "scenario", "gavi_eligible", "sdi", "simple_vacc_cov"] + ] + + weibull_preds["Intercept"] = 1 + + hib3_intro = hib3_intro[hib3_intro["hib3_intro_yr_country"] != 9999] + weibull_inputs = weibull_inputs.merge(hib3_intro, on="location_id") + + if vaccine != "hib3": + weibull_inputs = weibull_inputs[ + weibull_inputs.location_id.isin(simulated_intro_locations) + ] + + weibull_inputs["vacc_intro"] = 0 + + vacc_intro = weibull_inputs.year_id == weibull_inputs.hib3_intro_yr_country + + weibull_inputs.loc[vacc_intro, "vacc_intro"] = 1 + + post_intro = weibull_inputs.year_id > weibull_inputs.hib3_intro_yr_country + + weibull_inputs.loc[post_intro, "vacc_intro"] = 9999 + + # Subset to post-Hib intro years + weibull_inputs = weibull_inputs[ + weibull_inputs.year_id.isin(list(range(1986, years.past_end))) + ] + + weibull_inputs["time_to_intro"] = weibull_inputs["year_id"] - 1985 + + # Remove 9999 + weibull_inputs = weibull_inputs[weibull_inputs.vacc_intro != 9999] + + # Fit initial weibull model to get paramters without covariates + wb = WeibullFitter() + wb.fit( + durations=weibull_inputs["time_to_intro"], event_observed=weibull_inputs["vacc_intro"] + ) + + # Extract shape parameter for reuse in later Weibull + shape_param = wb.rho_ + + cols_needed = ["time_to_intro", "vacc_intro", "gavi_eligible", "sdi", "simple_vacc_cov"] + + if vaccine == "hib3": + cols_needed.remove("gavi_eligible") + + needed_data = weibull_inputs[cols_needed] + + # Fit weibull with covariates to extract coefficients + wb_with_cov = WeibullAFTFitter(fit_intercept=True) + wb_with_cov.fit(needed_data, duration_col="time_to_intro", event_col="vacc_intro") + + coeff_series = wb_with_cov.summary.coef.lambda_ + coeff_dict = dict(coeff_series.items()) + + return (weibull_preds, shape_param, coeff_dict) + + +def draw_years_of_introduction( + vaccine: str, + weibull_preds: pd.DataFrame, + simulated_intro_locations: List[int], + coeff_dict: dict, + shape_param: float, + years: YearRange, + draws: int, + all_curves: pd.DataFrame, +) -> xr.DataArray: + """Estimate years of introductions and their resulting scale up curves. 
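# ---------------------------------------------------------------------------
# Illustrative sketch (hedged addition, not taken from the patch itself).
# The two-step survival fit above in miniature: a covariate-free WeibullFitter
# provides the shape parameter (rho_), then a WeibullAFTFitter regresses the
# log of the scale parameter on covariates; lifelines treats every
# non-duration, non-event column of the frame as a covariate. The frame here
# is hypothetical, and the coefficient accessor mirrors the one used above.
# ---------------------------------------------------------------------------
from typing import Dict, Tuple

import pandas as pd
from lifelines import WeibullAFTFitter, WeibullFitter


def fit_intro_survival(df: pd.DataFrame) -> Tuple[float, Dict[str, float]]:
    """df columns: time_to_intro, vacc_intro (1 = introduced), plus covariates."""
    base = WeibullFitter()
    base.fit(durations=df["time_to_intro"], event_observed=df["vacc_intro"])
    shape = base.rho_

    aft = WeibullAFTFitter(fit_intercept=True)
    aft.fit(df, duration_col="time_to_intro", event_col="vacc_intro")
    scale_coeffs = dict(aft.summary.coef.lambda_.items())
    return shape, scale_coeffs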
+ + Args: + vaccine (str): name of the vaccine to run + weibull_preds (pd.DataFrame): dataframe with covariate values + simulated_intro_locations (list): + a list of location_ids with locations to simulate introduction dates for + coeff_dict (dict): dictionary with the slopes of the Weibull covariates + shape_param (float): parameter given by the Weibull distribution + years (YearRange): past_start:forecast_start:forecast_end + draws (int): number of draws + all_curves (pd.DataFrame): dataframe with the simulated + scale up curves conditioned on the year of introduction + + Returns: + xr.DataArray + """ + start_year = START_YEARS[vaccine] + + needed_preds = weibull_preds[ + (weibull_preds.location_id.isin(simulated_intro_locations)) + & (weibull_preds.year_id == start_year) + ] + + needed_preds["shape"] = shape_param + + if vaccine == "hib3": + coeff_dict["gavi_eligible"] = 0 + + # Go through each row and use the linear model + # to find the logit scale value + needed_preds["logit_scales"] = ( + coeff_dict["Intercept"] + + (coeff_dict["gavi_eligible"] * needed_preds["gavi_eligible"]) + + (coeff_dict["simple_vacc_cov"] * needed_preds["simple_vacc_cov"]) + + (coeff_dict["sdi"] * needed_preds["sdi"]) + ) + + needed_preds["scale"] = np.exp(needed_preds["logit_scales"]) + needed_preds["vaccine"] = vaccine + + draw_cols = [] + for ix in range(0, draws): + draw_cols.append(f"draw_{ix}") + + # Initialize empty columns + for col in draw_cols: + needed_preds[col] = np.nan + + needed_preds[draw_cols] = needed_preds.apply( + lambda x: pd.Series( + { + col: result + for col, result in zip( + draw_cols, + generate_intro_times( + x["shape"], + x["scale"], + x["vaccine"], + x["scenario"], + draws, + start_year, + years, + ), + ) + } + ), + axis=1, + ) + + needed_preds = needed_preds[["location_id", "scenario"] + draw_cols] + needed_preds = needed_preds.set_index(["location_id", "scenario"]).stack().reset_index() + needed_preds = needed_preds.rename(columns={0: "time_to_introduction", "level_2": "draw"}) + needed_preds.columns = ["location_id", "scenario", "draw", "time_to_introduction"] + needed_preds["draw"] = needed_preds.draw.str.replace("draw_", "") + needed_preds["draw"] = needed_preds.draw.astype("int64") + + # We don't deal with fractions of years so round the values + needed_preds["time_to_introduction"] = np.round(needed_preds["time_to_introduction"]) + needed_preds["intro_year"] = needed_preds["time_to_introduction"] + start_year + needed_preds_merged = needed_preds.merge(all_curves, on=["location_id", "intro_year"]) + needed_preds_merged = needed_preds_merged.drop( + ["time_to_introduction", "intro_year"], axis=1 + ) # location, year, scenario, draw, ratio + + needed_preds_merged["age_group_id"] = 22 + needed_preds_merged["sex_id"] = 3 + final_cols = needed_preds_merged[ + ["location_id", "scenario", "draw", "year_id", "ratio", "age_group_id", "sex_id"] + ] + final_cols = final_cols.rename(columns={"ratio": "value"}) + final_cols["year_id"] = final_cols.year_id.astype("int64") + indexed_df = final_cols.set_index( + ["location_id", "scenario", "draw", "year_id", "age_group_id", "sex_id"] + ) + + da = indexed_df.to_xarray().to_array().fillna(0).sel(variable="value", drop=True) + + return da + + +def load_introduction( + intro_version: str, + gbd_round_id: int, + okay_locations: List[int], + vaccine: str, + intro_column_name: Optional[str] = None, +) -> pd.DataFrame: + """Estimate years of introductions and their resulting scale up curves. 
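# ---------------------------------------------------------------------------
# Illustrative sketch (hedged addition, not taken from the patch itself).
# The matching step above in miniature: each simulated time-to-introduction is
# rounded to a calendar introduction year and then joined to the library of
# theoretical scale-up curves (one curve per hypothetical intro_year), so every
# draw picks up the ratio trajectory that matches its own introduction year.
# Column names mirror the real code; the frames themselves are hypothetical.
# ---------------------------------------------------------------------------
import numpy as np
import pandas as pd


def match_draws_to_curves(
    intro_draws: pd.DataFrame, curves: pd.DataFrame, start_year: int
) -> pd.DataFrame:
    """intro_draws: location_id, draw, time_to_introduction;
    curves: location_id, intro_year, year_id, ratio."""
    intro_draws = intro_draws.assign(
        intro_year=np.round(intro_draws["time_to_introduction"]) + start_year
    )
    return intro_draws.merge(curves, on=["location_id", "intro_year"])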
+ + Args: + intro_version (str): file name containing introduction dates + gbd_round_id (int): the gbd round to draw data from + okay_locations (List[int]): acceptable location IDs + vaccine (str): name of the vaccine to run + intro_column_name (str): name of the + column in the introduction data with the years + + Returns: + pd.DataFrame + """ + og_intro = pd.read_csv( + f"FILEPATH/{gbd_round_id}/past/" + f"FILEPATH/{intro_version}.csv" + ) + + if not intro_column_name: + intro_column_name = f"{vaccine}_intro_yr_country" + + rollout = og_intro[["ihme_loc_id", "me_name", "cv_intro", "location_id"]].drop_duplicates() + relevant_info = rollout[rollout.me_name == RELEVANT_MEs[vaccine]] + relevant_locs = relevant_info[relevant_info.location_id.isin(okay_locations)] + cleaned_intro = relevant_locs.rename(columns={"cv_intro": intro_column_name}).drop( + ["me_name", "ihme_loc_id"], axis=1 + ) + + return cleaned_intro + + +def load_simple_vacc_cov( + vaccine: str, version: Versions, gbd_round_id: int, draws: int +) -> pd.DataFrame: + """Load future coverage for the ratio vaccine's corresponding simple vaccine. + + Args: + vaccine (str): name of the vaccine to run + version (Versions): a versions object containing covariates to forecast ratios + gbd_round_id (int): the gbd round to draw data from + draws (int): number of draws + + Returns: + pd.DataFrame + """ + if vaccine == "mcv2": + simple_vacc = "mcv1" + else: + simple_vacc = "dtp3" + + simple_vacc_filepath = ( + version.data_dir(gbd_round_id, "future", "vaccine") / f"vacc_{simple_vacc}.nc" + ) + + simple_vacc_data = open_xr(simple_vacc_filepath) + simple_vacc_data = resample(simple_vacc_data, draws) + simple_vacc_data.name = "simple_vacc_cov" + simple_vacc_cov_df = simple_vacc_data.mean("draw").to_dataframe().reset_index() + simple_vacc_cov_scenarios = simple_vacc_cov_df[ + ["location_id", "year_id", "scenario", "simple_vacc_cov"] + ] + + return simple_vacc_cov_scenarios + + +def load_single_covariate( + version: Versions, + stage: str, + gbd_round_id: int, + years: YearRange, + draws: int, + okay_locations: List[int], +) -> pd.DataFrame: + """Load past and future covariate data. 
+ + Args: + version (Versions): a versions object containing covariates to forecast ratios + stage (str): stage of the desired covariate + gbd_round_id (int): the gbd round to draw data from + years (YearRange): past_start:forecast_start:forecast_end + draws (int): number of draws + okay_locations (List[int]): acceptable location IDs + + Returns: + pd.DataFrame + """ + if stage == "education": + entity = "maternal_education" + else: + entity = stage + + future_filepath = version.data_dir(gbd_round_id, "future", stage) / f"{entity}.nc" + past_filepath = version.data_dir(gbd_round_id, "past", stage) / f"{entity}.nc" + + future_data = open_xr(future_filepath) + past_data = open_xr(past_filepath) + + future_data = future_data.sel(year_id=years.forecast_years) + future_data = resample(future_data, draws) + + past_data = past_data.sel(location_id=okay_locations, year_id=years.past_years) + past_data = resample(past_data, draws) + + full = xr.concat([past_data, future_data], dim="year_id") + + needed = full.sel(year_id=range(years.past_start, years.forecast_end + 1)).mean("draw") + needed.name = stage + + cleaned_df = needed.to_dataframe().reset_index() + + if "age_group_id" in cleaned_df.columns: + cleaned_df = cleaned_df.drop(["age_group_id"], axis=1) + + if "sex_id" in cleaned_df.columns: + cleaned_df = cleaned_df.drop(["sex_id"], axis=1) + + return cleaned_df + + +def create_past_ratios_df( + past_ratio_version: str, + gbd_round_id: int, + vaccine: str, + year_intro_col: str, + okay_locations: List[int], + location_metadata: pd.DataFrame, + ldi: pd.DataFrame, + education: pd.DataFrame, + simple_vacc_cov: pd.DataFrame, + introductions: pd.DataFrame, + years: YearRange, +) -> pd.DataFrame: + """Create a dataframe with all covariates needed to simulate the scale-up curves. 
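# ---------------------------------------------------------------------------
# Illustrative sketch (hedged addition, not taken from the patch itself).
# The xarray pattern used above: past and future covariate estimates share
# every dimension except year_id, so they can be concatenated along year_id
# and then collapsed to a point estimate by averaging over the draw dimension.
# The toy arrays below are hypothetical.
# ---------------------------------------------------------------------------
import numpy as np
import xarray as xr

past = xr.DataArray(
    np.random.rand(2, 3, 10),
    dims=("location_id", "year_id", "draw"),
    coords={"location_id": [6, 102], "year_id": [2017, 2018, 2019]},
)
future = xr.DataArray(
    np.random.rand(2, 4, 10),
    dims=("location_id", "year_id", "draw"),
    coords={"location_id": [6, 102], "year_id": [2020, 2021, 2022, 2023]},
)
full = xr.concat([past, future], dim="year_id")  # 2017-2023 in one array
covariate_mean = full.mean("draw").to_dataframe(name="sdi").reset_index()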
+ + Args: + past_ratio_version (str): the version containing past vaccine ratios + gbd_round_id (int): the gbd round to draw data from + vaccine (str): name of the vaccine to run + year_intro_col (str): name of the column in + past_ratios with the introduction year + okay_locations (List[int]): acceptable location IDs + location_metadata (pd.DataFrame): id, name, and level of a location + ldi (pd.DataFrame): past and future LDI data + education (pd.DataFrame): past and future maternal education data + simple_vacc_cov (pd.DataFrame): contains reference data + for the corresponding simple vaccine + introductions (pd.DataFrame): contains location specific introduction dates + years (YearRange): past_start:forecast_start:forecast_end + + Returns: + pd.DataFrame + + Raises: + ValueError: + If the dataframe produced is missing necessary columns + """ + if vaccine == "mcv2": + simple_vacc = "mcv1" + else: + simple_vacc = "dtp3" + + past_ratios = pd.read_csv( + f"/FILEPATH" + f"/FILEPATH/{vaccine}_{simple_vacc}_ratios.csv" + ) + + past_ratios = past_ratios.merge(simple_vacc_cov, on=DEMOG_COLS, how="outer") + + past_ratios = past_ratios.merge( + ldi, on=["location_id", "year_id", "scenario"], how="outer" + ) + + past_ratios = past_ratios.merge( + education, on=["location_id", "year_id", "scenario"], how="outer" + ) + + past_ratios = past_ratios.drop("ihme_loc_id", axis=1) + + location_metadata = location_metadata[["location_id", "ihme_loc_id"]] + past_ratios = past_ratios.merge(location_metadata, on="location_id", how="outer") + + past_ratios = past_ratios.merge(introductions, on="location_id", how="outer") + + if not set(past_ratios.columns).issuperset( + set(["simple_vacc_cov", "ldi", "education", f"{vaccine}_intro_yr_country"]) + ): + raise ValueError("Past Ratios DataFrame is missing columns") + + past_ratios["age_group_id"] = 22 + past_ratios["sex_id"] = 3 + past_ratios[f"num_{vaccine}_intro_yr_country"] = np.nan + + not_9999 = past_ratios[year_intro_col] != 9999 + past_ratios.loc[not_9999, f"num_{vaccine}_intro_yr_country"] = ( + past_ratios["year_id"] - past_ratios[year_intro_col] + ) + + # For places with known intros, remove data before first intro year + # Keep all data for places without a known intro year + no_intros = past_ratios[year_intro_col] == 9999 + + after_intro = past_ratios[f"num_{vaccine}_intro_yr_country"] >= 0 + past_ratios = past_ratios[no_intros | after_intro] + + # Ensure you are only using allowed locations in the modeling + past_ratios = past_ratios[past_ratios.location_id.isin(okay_locations)] + + past_ratios = past_ratios[past_ratios.scenario == 0] + + # Add information needed to run the linear model R code + past_ratios["admin"] = 0 + + if set(past_ratios.location_id.unique()) != set(okay_locations): + raise ValueError( + "Past data is missing locations " + f"{set(okay_locations) - set(past_ratios.location_id.unique())}" + ) + + past_ratios = past_ratios.loc[past_ratios.year_id.isin(years.years)] + + return past_ratios + + +def forecast_simulated_ratios( + version: Versions, + gbd_round_id: int, + years: YearRange, + draws: int, + okay_locations: List[int], + past_ratio_version: str, + vaccine: str, + location_metadata: pd.DataFrame, + simple_vacc_cov: pd.DataFrame, + introductions: pd.DataFrame, + simulated_intro_locations: List[int], + year_intro_col: str, + gavi_version: str, + simple_vacc_cov_scenarios: pd.DataFrame, + hib3_intro: pd.DataFrame, +) -> None: + """A function to forecast vaccine ratios that have not yet been introduced. 
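# ---------------------------------------------------------------------------
# Illustrative sketch (hedged addition, not taken from the patch itself).
# The sentinel handling above: an introduction year of 9999 means "no known
# introduction", so the years-since-introduction feature is only defined for
# rows with a real introduction year, pre-introduction rows for those
# locations are dropped, and unintroduced locations keep all of their rows.
# Column names are hypothetical stand-ins.
# ---------------------------------------------------------------------------
import numpy as np
import pandas as pd


def mark_years_since_intro(df: pd.DataFrame, intro_col: str) -> pd.DataFrame:
    no_intro = df[intro_col] == 9999
    years_since = np.where(no_intro, np.nan, df["year_id"] - df[intro_col])
    df = df.assign(years_since_intro=years_since)
    return df[no_intro | (df["years_since_intro"] >= 0)]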
+ + Args: + version (Versions): a versions object containing covariates to forecast ratios + gbd_round_id (int): the gbd round to draw data from + years (YearRange): past_start:forecast_start:forecast_end + draws (int): number of draws + okay_locations (List[int]): acceptable location IDs + past_ratio_version (str): the version containing past vaccine ratios + vaccine (str): name of the vaccine to run + location_metadata (pd.DataFrame): id, name, and level of a location + simple_vacc_cov (pd.DataFrame): contains reference data for the + corresponding simple vaccine + introductions (pd.DataFrame): contains location specific introduction dates + simulated_intro_locations (List[int]): locations with no past introductions + year_intro_col (str): name of the column in past_ratios with the introduction year + gavi_version (str): name of the file with GAVI eligibility information + simple_vacc_cov_scenarios (pd.DataFrame): contains scenario data for the + corresponding simple vaccine + hib3_intro (pd.DataFrame): contains introduction dates for hib3 + + Returns: + None + """ + # Load in Covariates + sdi_df = load_single_covariate(version, "sdi", gbd_round_id, years, draws, okay_locations) + + ldi_df = load_single_covariate(version, "ldi", gbd_round_id, years, draws, okay_locations) + + edu_df = load_single_covariate( + version, "education", gbd_round_id, years, draws, okay_locations + ) + + past_ratios = create_past_ratios_df( + past_ratio_version, + gbd_round_id, + vaccine, + year_intro_col, + okay_locations, + location_metadata, + ldi_df, + edu_df, + simple_vacc_cov, + introductions, + years, + ) + + # Estimate curves conditional on the introduction date + all_curves = generate_theoretical_scaleups( + past_ratios, simulated_intro_locations, vaccine, years, draws, year_intro_col + ) + + # Estimate coefficients and inputs for Weibull distribution + # This will be used to select introduction years + weibull_preds, shape_param, coeff_dict = parametrize_weibull( + sdi_df, + gavi_version, + simple_vacc_cov_scenarios, + vaccine, + hib3_intro, + simulated_intro_locations, + years, + ) + + # Estimate introduction date + # Combine with simulated curves conditional on the introduction date + draw_data = draw_years_of_introduction( + vaccine, + weibull_preds, + simulated_intro_locations, + coeff_dict, + shape_param, + years, + draws, + all_curves, + ) + + return draw_data + + +def multiply_ratio_forecasts( + vaccine: str, + gbd_round_id: int, + ratios: xr.DataArray, + version: Versions, + draws: int, + years: YearRange, +) -> xr.DataArray: + """Function to multiply ratio forecasts by simple vaccine forecasts. 
+ + Args: + vaccine (str): the ratio vaccine for which we will be loading ratios + gbd_round_id (int): the gbd round to draw ratio from + ratios (xr.DataArray): the version containing vaccine ratios + version: the location to read the simple vaccine forecasts from + draws (int): number of draws + years (YearRange): past and forecast years + + Returns: + xr.DataArray + """ + if vaccine == "mcv2": + simple_vacc = "mcv1" + else: + simple_vacc = "dtp3" + + vacc_forecast_path = ( + version.data_dir(gbd_round_id, "future", "vaccine") / f"vacc_{simple_vacc}.nc" + ) + + vacc_forecast_data = open_xr(vacc_forecast_path) + vacc_forecast_data = get_dataarray_from_dataset(vacc_forecast_data).rename(simple_vacc) + + forecast = resample(vacc_forecast_data, draws) + ratios = resample(ratios, draws) + + # Also need last past year estimates for intercept shifting + years_needed = [years.past_end] + years_needed.extend(list(years.forecast_years)) + + forecast = forecast.sel(year_id=years_needed) + multiplied_ratios = forecast * ratios + + return multiplied_ratios + + +def load_forecasted_ratios( + future_ratio_version: str, + gbd_round_id: int, + vaccine: str, + known_intro_dates: pd.DataFrame, + years: YearRange, + draws: int, +) -> xr.DataArray: + """Open forecasted ratios given by the vaccine team. + + Args: + future_ratio_version (str): the version containing future vaccine ratios + gbd_round_id (int): the gbd round to draw data from + vaccine (str): name of the vaccine to run + known_intro_dates (pd.DataFrame): dataframe with location IDs with given ratios + years (YearRange): past_start:forecast_start:forecast_end + draws (int): number of draws + + Returns: + xr.DataArray + """ + if vaccine == "mcv2": + simple_vacc = "mcv1" + else: + simple_vacc = "dtp3" + + ratios = pd.read_csv( + f"FILEPATH/" + f"FILEPATH/{vaccine}_{simple_vacc}_ratios.csv" + ) + + years_needed = list(years.forecast_years) + years_needed.append(years.past_end) + + ratios_to_use = ratios[ + (ratios.location_id.isin(known_intro_dates.location_id)) + & (ratios.year_id.isin(years_needed)) + ] + + if "Unnamed: 0" in ratios_to_use.columns: + ratios_to_use = ratios_to_use.drop("Unnamed: 0", axis=1) + + ind_cols = ["location_id", "age_group_id", "sex_id", "year_id"] + + if "scenario" in ratios_to_use.columns: + ind_cols.append("scenario") + draw_level = "level_5" + else: + draw_level = "level_4" + + ratios_to_use = ratios_to_use.set_index(ind_cols).stack().reset_index() + + cleaned_ratio_df = ratios_to_use.rename(columns={draw_level: "draw", 0: "value"}) + cleaned_ratio_df["draw"] = cleaned_ratio_df.draw.str.replace("draw_", "") + cleaned_ratio_df["draw"] = cleaned_ratio_df.draw.astype("int64") + + ind_cols.append("draw") + indexed_df = cleaned_ratio_df.set_index(ind_cols) + + ratio_da = indexed_df.to_xarray().to_array().fillna(0).sel(variable="value", drop=True) + + ratio_da = resample(ratio_da, draws) + + return ratio_da + + +def forecast_ratio_vaccines( + intro_version: str, + version: Versions, + gbd_round_id: int, + years: YearRange, + draws: int, + past_ratio_version: str, + future_ratio_version: str, + vaccine: str, + year_intro_col: str, + gavi_version: str, + intro_column_name: str, +) -> None: + """Forecast ratio vaccines like pcv3, hib3, rotac, and mcv2. 
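# ---------------------------------------------------------------------------
# Illustrative sketch (hedged addition, not taken from the patch itself).
# The reshaping done above: ratio draws stored as wide draw_* columns are
# stacked into long format, the "draw_" prefix is stripped so the draw label
# becomes an integer coordinate, and the indexed frame is turned into an
# xarray DataArray with an explicit draw dimension. The frame is hypothetical.
# ---------------------------------------------------------------------------
import pandas as pd

wide = pd.DataFrame(
    {
        "location_id": [6, 6],
        "year_id": [2022, 2023],
        "draw_0": [0.10, 0.20],
        "draw_1": [0.15, 0.25],
    }
)
long = wide.set_index(["location_id", "year_id"]).stack().reset_index()
long.columns = ["location_id", "year_id", "draw", "value"]
long["draw"] = long["draw"].str.replace("draw_", "").astype("int64")
ratio_da = long.set_index(["location_id", "year_id", "draw"]).to_xarray()["value"]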
+ + Args: + intro_version (str): file name containing introduction dates + version (Versions): a versions object containing covariates to forecast ratios + gbd_round_id (int): the gbd round to draw data from + years (YearRange): past_start:forecast_start:forecast_end + draws (int): number of draws + past_ratio_version (str): the version containing past vaccine ratios + future_ratio_version (str): the version containing future vaccine ratios + vaccine (str): name of the vaccine to run + year_intro_col (str): name of the column in past_ratios with the introduction year + gavi_version (str): name of the file with GAVI eligibility information + intro_column_name (str): name of the column in the introduction data with the years + """ + # Past ratios will often have the same column name but not always + if not year_intro_col: + year_intro_col = f"{vaccine}_intro_yr_country" + + loc_6_full = get_location_set(gbd_round_id=gbd_round_id) + location_metadata = loc_6_full[(loc_6_full["level"] >= 3)] + okay_locations = location_metadata.location_id.values + + introductions = load_introduction( + intro_version, gbd_round_id, okay_locations, vaccine, intro_column_name + ) + + # hib3 information gets used for other vaccines as well + hib3_intro = load_introduction( + intro_version, gbd_round_id, okay_locations, "hib3", intro_column_name + ) + + # Load in simple vaccine coverage to use as a covariate + # Simple vaccine will be mcv1 for mcv2 and dtp3 for all other vaccines + simple_vacc_cov_scenarios = load_simple_vacc_cov(vaccine, version, gbd_round_id, draws) + simple_vacc_cov = simple_vacc_cov_scenarios[simple_vacc_cov_scenarios.scenario == 0] + + # Find what locations need to be simulated and which are given by the vaccines team + known_intro_dates = introductions[introductions[year_intro_col] != 9999] + simple_ratio_locations = np.unique(known_intro_dates.location_id.values) + needed_locs = np.unique(simple_vacc_cov.location_id.values) + simulated_intro_locations = list(set(needed_locs) - set(simple_ratio_locations)) + + simulated_ratios = forecast_simulated_ratios( + version, + gbd_round_id, + years, + draws, + okay_locations, + past_ratio_version, + vaccine, + location_metadata, + simple_vacc_cov, + introductions, + simulated_intro_locations, + year_intro_col, + gavi_version, + simple_vacc_cov_scenarios, + hib3_intro, + ) + + given_ratios = load_forecasted_ratios( + future_ratio_version, gbd_round_id, vaccine, known_intro_dates, years, draws + ) + + all_ratios = xr.concat([given_ratios, simulated_ratios], dim="location_id") + + # Go from ratio_vaccine:simple_vaccine to just ratio_vaccine coverage + final_unshifted = multiply_ratio_forecasts( + vaccine, gbd_round_id, all_ratios, version, draws, years + ) + + final_unshifted = final_unshifted.fillna(0) + + past_vaccine = load_past_vaccine_data(version, vaccine, gbd_round_id, draws, years) + past_vaccine_subset = past_vaccine.sel( + year_id=[years.past_end], location_id=final_unshifted.location_id.values + ) + + processor = LogitProcessor( + years=years, + offset=ModelConstants.DEFAULT_OFFSET, + no_mean=True, + intercept_shift="unordered_draw", + age_standardize=False, + gbd_round_id=gbd_round_id, + ) + + final_shifted = processor.post_process( + logit_with_offset(final_unshifted, 1e-8), past_vaccine_subset + ) + + if not np.isfinite(final_shifted).all(): + raise ValueError("Final values are invalid") + + forecast_path = version.data_dir(gbd_round_id, "future", "vaccine") / f"vacc_{vaccine}.nc" + save_xr( + final_shifted, + forecast_path, + 
metric="rate", + space="identity", + ) + + +@click.group() +@click.option( + "--vaccine", + type=str, + required=True, + help=("Vaccine being modeled e.g. `hib3` or 'rotac'"), +) +@click.option( + "--version", + "-v", + type=str, + required=True, + multiple=True, + help=("Vaccine and SDI versions in the form /FILEPATH/version_name"), +) +@click.option( + "--years", + type=str, + required=True, + help="Year range first_past_year:first_forecast_year:last_forecast_year", +) +@click.option( + "--gbd-round-id", + required=True, + type=int, + help="The gbd round id " "for all data", +) +@click.option( + "--draws", + required=True, + type=int, + help="Number of draws", +) +@click.option( + "--intro-version", + type=str, + required=True, + help="Version name of the file with vaccine introduction dates", +) +@click.option( + "--past-ratio-version", + type=str, + required=True, + help="Version name of the file with past ratios", +) +@click.option( + "--future-ratio-version", + type=str, + required=True, + help="Version name of the file with future ratios", +) +@click.option( + "--gavi-version", + type=str, + required=True, + help="Version name of the file indicating which locations are GAVI eligible", +) +@click.option( + "--year-intro-col", + type=str, + required=False, + help="Name of the column in the ratio files with the introduction year", +) +@click.option( + "--intro-column-name", + type=str, + required=False, + help="Name of the column in the introduction file with the introduction year", +) +@click.pass_context +def cli( + ctx: click.Context, + version: list, + vaccine: str, + gbd_round_id: int, + years: str, + draws: int, + intro_version: str, + past_ratio_version: str, + future_ratio_version: str, + gavi_version: str, + year_intro_col: str or None, + intro_column_name: str or None, +) -> None: + """Main cli function to parse args and pass them to the subcommands. + + Args: + ctx (click.Context): ctx object. + version (list): which population and vaccine versions to pass to + the script + vaccine (str): Relevant vaccine + gbd_round_id (int): Current gbd round id + years (str): years for forecast + draws (int): number of draws + intro_version (str): file name containing introduction dates + past_ratio_version (str): the version containing past vaccine ratios + future_ratio_version (str): the version containing future vaccine ratios + gavi_version (str): name of the file with GAVI eligibility information + year_intro_col (str): name of the column in past_ratios with the introduction year + intro_column_name (str): name of the column in the introduction data with the years + """ + version = Versions(*version) + years = YearRange.parse_year_range(years) + check_versions(version, "future", ["sdi", "vaccine"]) + ctx.obj = { + "version": version, + "vaccine": vaccine, + "gbd_round_id": gbd_round_id, + "years": years, + "draws": draws, + "intro_version": intro_version, + "past_ratio_version": past_ratio_version, + "future_ratio_version": future_ratio_version, + "gavi_version": gavi_version, + "year_intro_col": year_intro_col, + "intro_column_name": intro_column_name, + } + + +@cli.command() +@click.pass_context +def main(ctx: click.Context) -> None: + """Call to main function. + + Args: + ctx (click.Context): context object containing relevant params parsed + from command line args. 
+ """ + FileSystemManager.set_file_system(OSFileSystem()) + + forecast_ratio_vaccines( + version=ctx.obj["version"], + vaccine=ctx.obj["vaccine"], + gbd_round_id=ctx.obj["gbd_round_id"], + years=ctx.obj["years"], + draws=ctx.obj["draws"], + intro_version=ctx.obj["intro_version"], + past_ratio_version=ctx.obj["past_ratio_version"], + future_ratio_version=ctx.obj["future_ratio_version"], + gavi_version=ctx.obj["gavi_version"], + year_intro_col=ctx.obj["year_intro_col"], + intro_column_name=ctx.obj["intro_column_name"], + ) + + +if __name__ == "__main__": + cli() \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/vaccine/run_simple_vaccines.py b/gbd_2021/disease_burden_forecast_code/vaccine/run_simple_vaccines.py new file mode 100644 index 0000000..0fbaf0f --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/vaccine/run_simple_vaccines.py @@ -0,0 +1,479 @@ +"""Run Simple Vaccines. + +creation Date: 11/9/2020 +Purpose: This script serves to produce forecasts for MCV1 and DTP3 +Forecasts for scenarios are produced in log space, with the reference produced +in logit - the scenarios are then combined to produce output. +To clarify, better and worse scenarios are produced in log space +using the ARC method, while the reference scenario is produced +in logit space using Limetr. This decision was based off of how +scenarios look in logit and log space. The ARC method produced +wider bounds of uncertainty than Limetr, which produced overly +confined bounds of scenarios, and log is used because it makes the +scenarios diverge from reference more quickly in situations where +the coverage is close to 1 or 0. + +Example call: + +python run_simple_vaccines.py --vaccine dtp3 \ +--versions FILEPATH \ +-v FILEPATH \ +--gbd_round_id 6 \ +--years 1980:2020:2050 \ +--draws 500 \ +one-vaccine + +""" +from typing import Dict, Optional + +import click +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.processing import ( + clean_covariate_data, + get_dataarray_from_dataset, + strip_single_coord_dims, +) +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_data_transformation.lib.validate import assert_shared_coords_same +from fhs_lib_database_interface.lib.constants import DimensionConstants, ScenarioConstants +from fhs_lib_database_interface.lib.query.location import get_location_set +from fhs_lib_database_interface.lib.query.model_strategy import ModelStrategyNames +from fhs_lib_file_interface.lib.check_input import check_versions +from fhs_lib_file_interface.lib.file_system_manager import FileSystemManager +from fhs_lib_file_interface.lib.os_file_system import OSFileSystem +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr, save_xr +from fhs_lib_model.lib.validate import assert_covariates_scenarios +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib.fhs_logging import get_logger + +from fhs_pipeline_vaccine.lib import model_strategy, model_strategy_queries + +logger = get_logger() + + +def load_past_vaccine_data( + versions: Versions, + vaccine: str, + gbd_round_id: int, + draws: int, + years: YearRange, +) -> xr.DataArray: + r"""The past is loaded and resampled using draws argument. + + Args: + versions (Versions):All relevant versions. e.g. + past/vaccine/123, past/sdi/456, future/sdi/789 + vaccine (str): The vaccine to forecast. 
+ gbd_round_id (int): The numeric ID of GBD round associated with the + past data + draws (int): The number of draws to compute with and output for + betas and predictions. + years (YearRange): Forecasting time series. + + Returns: + (xr.DataArray): + """ + vacc_past_path = versions.data_dir(gbd_round_id, "past", "vaccine") / f"vacc_{vaccine}.nc" + vacc_past_data = open_xr(vacc_past_path) + vacc_past_data = get_dataarray_from_dataset(vacc_past_data).rename(vaccine) + + loc_6_full = get_location_set(gbd_round_id=gbd_round_id) + countries_and_subnats = loc_6_full.location_id.values + + past = resample(vacc_past_data, draws) + past = past.sel(year_id=years.past_years, location_id=countries_and_subnats) + + if "scenario" in past.dims: + past = past.sel(scenario=0, drop=True) + + return past + + +def load_covariates( + dep_var_da: xr.DataArray, + covariates: Dict, + versions: Versions, + gbd_round_id: int, + draws: int, + years: YearRange, +) -> xr.DataArray: + r"""To read in the covariate data and process according to covariate. + + Args: + dep_var_da (xr.DataArray): The vaccine dataarray in question to + set dimensions on + covariates (Dict): Which covariate to read in (e.g. sdi, ldi, + education) + versions (Versions):All relevant versions. e.g.:: + past/vaccine/123 + past/sdi/456 + future/sdi/789 + gbd_round_id (int): The numeric ID of GBD round associated with + the past data + draws (int): the number of draws to compute with and output for + betas and predictions. + years (YearRange): Forecasting time series. + + Raises: + IndexError: If past and forecast data don't line up across + all dimensions except ``year_id`` and ``scenario``, e.g. if + coordinates for of age_group_id are missing from forecast + data, but not past data. + If the covariate data is missing coordinates from a dim it + shares with the dependent variable -- both **BEFORE** and + AFTER pre-processing. + If the covariates do not have consistent scenario coords. + + Returns: + xr.DataArray + """ + cov_data_list = [] + for cov_stage, cov_processor in covariates.items(): + cov_past_path = versions.data_dir(gbd_round_id, "past", cov_stage) / f"{cov_stage}.nc" + cov_past_data = open_xr(cov_past_path) + cov_past_data = get_dataarray_from_dataset(cov_past_data).rename(cov_stage) + + cov_forecast_path = ( + versions.data_dir(gbd_round_id, "future", cov_stage) / f"{cov_stage}.nc" + ) + cov_forecast_data = open_xr(cov_forecast_path) + cov_forecast_data = get_dataarray_from_dataset(cov_forecast_data).rename(cov_stage) + cov_data = clean_covariate_data( + cov_past_data, + cov_forecast_data, + dep_var_da, + years, + draws, + gbd_round_id, + ) + if DimensionConstants.SCENARIO not in cov_data.dims: + cov_data = cov_data.expand_dims( + scenario=[ScenarioConstants.REFERENCE_SCENARIO_COORD] + ) + + prepped_cov_data = cov_processor.pre_process(cov_data, draws) + + try: + assert_shared_coords_same( + prepped_cov_data, dep_var_da.sel(year_id=years.past_end, drop=True) + ) + except IndexError as ce: + err_msg = f"After pre-processing {cov_stage}," + str(ce) + logger.error(err_msg) + raise IndexError(err_msg) + + cov_data_list.append(prepped_cov_data) + + assert_covariates_scenarios(cov_data_list) + return cov_data_list + + +def one_vaccine_main( + vaccine: str, + versions: Versions, + years: YearRange, + draws: int, + gbd_round_id: int, +) -> xr.DataArray: + r"""Forecasts given stage for given cause. + + Args: + vaccine (str): The vaccine to forecast. + versions (Versions): All relevant versions. 
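# ---------------------------------------------------------------------------
# Illustrative sketch (hedged addition, not taken from the patch itself).
# The scenario handling in load_covariates above: covariates saved without a
# scenario dimension are broadcast to a one-coordinate scenario axis
# (reference = 0) so that every covariate lines up with scenario-aware
# vaccine data. The toy array is hypothetical.
# ---------------------------------------------------------------------------
import numpy as np
import xarray as xr

cov = xr.DataArray(
    np.random.rand(3), dims=("year_id",), coords={"year_id": [2020, 2021, 2022]}
)
if "scenario" not in cov.dims:
    cov = cov.expand_dims(scenario=[0])  # reference scenario only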
e.g.:: + past/vaccine/123 + past/sdi/456 + future/sdi/789 + years (YearRange): Forecasting time series. + draws (int): The number of draws to compute with and output for + betas and predictions. + gbd_round_id (int): The numeric ID of GBD round associated with the + past data + + Returns: + xr.DataArray + """ + # Run linear model for reference scenario + model_parameters = _get_model_parameters( + vaccine, ModelStrategyNames.LIMETREE.value, years, gbd_round_id + ) + if model_parameters: + ( + model, + processor, + covariates, + fixed_effects, + fixed_intercept, + random_effects, + indicators, + spline, + predict_past_only, + ) = model_parameters + + versions_to_check = {"vaccine"} | covariates.keys() if covariates else {"vaccine"} + check_versions(versions, "past", versions_to_check) + versions_to_check.remove("vaccine") + check_versions(versions, "future", versions_to_check) + + vacc_da = load_past_vaccine_data(versions, vaccine, gbd_round_id, draws, years) + # convert past vaccine data into logit space + prep_vacc_da = processor.pre_process(vacc_da, draws) + + if covariates: + cov_data_list = load_covariates( + prep_vacc_da, + covariates, + versions, + gbd_round_id, + draws, + years, + ) + else: + cov_data_list = None + # remove single-variable dimensions (age_group_id, sex_id) + prep_vacc_da = strip_single_coord_dims(prep_vacc_da) + limetr_model = model( + past_data=prep_vacc_da, + years=years, + draws=draws, + covariate_data=cov_data_list, + random_effects=random_effects, + indicators=indicators, + fixed_effects=fixed_effects, + fixed_intercept=fixed_intercept, + gbd_round_id=gbd_round_id, + ) + limetr_model.fit() + forecast_path = versions.data_dir(gbd_round_id, "future", "vaccine") + limetr_model.save_coefficients(forecast_path, f"{vaccine}_limetr") + limetr_forecasts = limetr_model.predict() + vacc_da = vacc_da.drop_vars("age_group_id").squeeze("age_group_id") + vacc_da = vacc_da.drop_vars("sex_id").squeeze("sex_id") + # convert data back into rate space + limetr_forecasts = processor.post_process(limetr_forecasts, vacc_da) + # Run arc for scenarios - model should not use covariate data, just past data + model_parameters = _get_model_parameters( + vaccine, ModelStrategyNames.ARC.value, years, gbd_round_id + ) + if model_parameters: + ( + model, + processor, + covariates, + fixed_effects, + fixed_intercept, + random_effects, + indicators, + spline, + predict_past_only, + ) = model_parameters + # pre-process into log space + prep_vacc_da = processor.pre_process(vacc_da, draws) + arc_model = model( + past_data=prep_vacc_da, + years=years, + draws=draws, + select_omega=False, + omega=1.0, + reference_scenario_statistic="mean", + mean_level_arc=False, + truncate=False, + gbd_round_id=gbd_round_id, + scenario_roc="national", + ) + arc_model.fit() + forecast_path = versions.data_dir(gbd_round_id, "future", "vaccine") + arc_model.save_coefficients(forecast_path, f"{vaccine}_arc") + arc_forecasts = arc_model.predict() + full_arc = xr.concat([prep_vacc_da, arc_forecasts], "year_id") + full_arc.values[full_arc.values > 0] = 0 + # convert arc scenarios back into rate space in preparation for combination + full_arc = processor.post_process(full_arc, vacc_da) + # combine the scenarios with the reference - as arc is forecasted in log space and + # limetree in logit, both must be transformed back into rate space before this step + + forecast_da = combine_scenarios(limetr_da=limetr_forecasts, arc_da=full_arc, years=years) + forecast_da = expand_dimensions(forecast_da, age_group_id=[22]) + 
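# Descriptive note on the two expand_dimensions calls here: they restore the
# aggregate age/sex coordinates that strip_single_coord_dims removed before
# modeling. In GBD conventions age_group_id 22 is all ages and sex_id 3 is
# both sexes, so the saved forecast lines up with the past vaccine data.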
forecast_da = expand_dimensions(forecast_da, sex_id=[3]) + save_xr(forecast_da, forecast_path / f"vacc_{vaccine}.nc", metric="rate", space="identity") + return forecast_da + + +def _get_model_parameters( + vaccine: str, model_type: str, years: YearRange, gbd_round_id: int +) -> Optional[model_strategy.ModelParameters]: + """Gets modeling parameters associated with vaccine and model type. + + Args: + vaccine (str): which vaccine to include + model_type (str): which model type to use + years (YearRange): the years to combine across + gbd_round_id (int): the gbd_round_id to use + + Returns: + Optional[model_strategy.ModelParameters] + + If there aren't model parameters associated with the vaccine-model then the + script will exit with return code 0. + """ + model_parameters = model_strategy_queries.get_vaccine_model( + vaccine, model_type, years, gbd_round_id + ) + if not model_parameters: + logger.info(f"{vaccine}-{model_type} is not forecasted in this pipeline. DONE") + exit(0) + else: + return model_parameters + + +def combine_scenarios( + limetr_da: xr.DataArray, + arc_da: xr.DataArray, + years: YearRange, +) -> xr.DataArray: + """Function to combine scenarios into a single dataarray. + + Args: + limetr_da (xr.DataArray): + the output of the limetr model used to produce the reference scenario. + arc_da (xr.DataArray): + the ouptut of the arc model used to produce the better and worse + scenarios. + years (YearRange): Year range. + + Returns: + combine_da (xr.DataArray): the combined dataarray + + This function takes the reference scenario from the limetr_da and uses + it to find an offset between the limetr_da and the arc_da, effectively + replacing the reference scenario of the arc_da with the reference + scenario of the limetr_da + This function also forces the better and worse scenario to be outside of the + reference scenario by setting the better scenario to the highest value + and the worse scenario to the lowest value + Note: This function expects both data arrays to be in rate space + """ + limetr_reference = limetr_da.sel( + scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD, drop=True + ) + arc_reference = arc_da.sel(scenario=ScenarioConstants.REFERENCE_SCENARIO_COORD, drop=True) + + # get difference between arc reference and limetr reference in first forecast year + difference = arc_reference.sel(year_id=years.past_end, drop=True) - limetr_reference.sel( + year_id=years.past_end, drop=True + ) + + # add difference between arc and limetr to limetr reference + new_reference = limetr_reference + difference + # replace arc reference with limetr reference + new_forecast = xr.concat( + [ + arc_da.sel( + scenario=[ + ScenarioConstants.WORSE_SCENARIO_COORD, + ScenarioConstants.BETTER_SCENARIO_COORD, + ], + year_id=years.forecast_years, + ), + new_reference.expand_dims({"scenario": [0]}), + ], + "scenario", + ) + # Force the better and worse scenarios equal to reference if reference is + # higher than better scenario or lower than worse scenario + low_values = new_forecast.min("scenario") + high_values = new_forecast.max("scenario") + new_forecast.loc[{"scenario": ScenarioConstants.BETTER_SCENARIO_COORD}] = high_values + new_forecast.loc[{"scenario": ScenarioConstants.WORSE_SCENARIO_COORD}] = low_values + + return new_forecast + + +@click.group() +@click.option( + "--vaccine", + type=str, + required=True, + help=("Vaccine being modeled e.g. 
`vacc_dtp3`"), +) +@click.option( + "--versions", + "-v", + type=str, + required=True, + multiple=True, + help=("Vaccine and SDI versions in the form FILEPATH"), +) +@click.option( + "--years", + type=str, + required=True, + help="Year range first_past_year:first_forecast_year:last_forecast_year", +) +@click.option( + "--gbd_round_id", + required=True, + type=int, + help="The gbd round id " "for all data", +) +@click.option( + "--draws", + required=True, + type=int, + help="Number of draws", +) +@click.pass_context +def cli( + ctx: click.Context, + versions: list, + vaccine: str, + gbd_round_id: int, + years: str, + draws: int, +) -> None: + """Main cli function to parse args and pass them to the subcommands. + + Args: + ctx (click.Context): ctx object. + versions (list): which population and vaccine versions to pass to + the script + vaccine (str): Relevant vaccine + gbd_round_id (int): Current gbd round id + years (str): years for forecast + draws (int): number of draws + """ + versions = Versions(*versions) + years = YearRange.parse_year_range(years) + check_versions(versions, "future", ["sdi", "vaccine"]) + ctx.obj = { + "versions": versions, + "vaccine": vaccine, + "gbd_round_id": gbd_round_id, + "years": years, + "draws": draws, + } + + +@cli.command() +@click.pass_context +def one_vaccine(ctx: click.Context) -> None: + """Call to main function. + + Args: + ctx (click.Context): context object containing relevant params parsed + from command line args. + """ + FileSystemManager.set_file_system(OSFileSystem()) + + one_vaccine_main( + versions=ctx.obj["versions"], + vaccine=ctx.obj["vaccine"], + gbd_round_id=ctx.obj["gbd_round_id"], + years=ctx.obj["years"], + draws=ctx.obj["draws"], + ) + + +if __name__ == "__main__": + cli() \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/yll/yll.py b/gbd_2021/disease_burden_forecast_code/yll/yll.py new file mode 100644 index 0000000..59fc500 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/yll/yll.py @@ -0,0 +1,111 @@ +"""Module for calculating ylls. + +""" +import pandas as pd +import xarray as xr +from fhs_lib_database_interface.lib.query import age +from tiny_structured_logger.lib import fhs_logging + +logger = fhs_logging.get_logger() + +MAX_MEAN_AGE_OF_DEATH = 110 # Model lex is NOT available for mean-age-of-death + + +def calculate_ylls( + deaths: xr.DataArray, ax: xr.DataArray, reference_lex: pd.DataFrame, gbd_round_id: int +) -> xr.DataArray: + r"""Calculate YLLs. + + Calculates YLLs or YLL rates from reference life expectancy, deaths, and mean age of death + (:math:`a_x`). + + The output is in count or rate space according to the space of ``deaths``: If ``deaths`` is + in "counts" then the result is YLLs. If ``deaths`` is in rate space (deaths divided by + population) are given, then this function calculates YLL rates. + + 1. Take age-specific-only reference life expectancy with "reference life starting years" at + 0.01 granularity. Call this "ref lex". + + 2. Take the standard forecasting/FHS `age_group_years_start` values (usually 5-year grid + points) + loc-age-sex-specific `ax` (rounded to nearest 0.01 years), giving the mean + age-of-death within each age-group (i.e. the mean age that people die, when they die + within their age-group), to match the age start years of ref lex from (1). + + 3. The ref lex where `age_group_years_start` + `ax` = reference life starting year (at + 0.01-year granularity) is the ref lex assigned to this age group. + + YLL calculation is then simply + + .. 
math:: + + \mbox{YLL}_{las} = m_{las} \times \bar{e}_{a} + + where :math:`\mbox{YLL}_{las}` is location, age-group, sex specific YLLs (or YLL rates), + :math:`m_{las}` is location, age-group, sex specific deaths (or mortality rates), and + :math:`\bar{e}_{a}` is reference life expectancy (at life starting year), which is + `only` age-group specific, by definition. + + Args: + deaths (xarray.DataArray): + location, age-group, sex specific deaths or death rates. Has at least + ``location_id``, ``age_group_id``, and ``sex_id`` as dimensions. + ax (xarray.DataArray): + location, age-group, sex specific mean age of death within each age interval, + :math:`a_x`, i.e. the mean number of years lived within that interval among those + that `died` within that interval. Has at least ``location_id``, ``age_group_id``, + and ``sex_id`` as dimensions. + reference_lex (pandas.DataFrame): + Reference life expectancy. Should just be age-specific. Has 2 columns: + ``age_group_years_start`` and ``Ex``. Use + ``fhs_lib_database_interface.query.get_reference_lex`` to get this information in + the necessary format. + gbd_round_id (int): + Numeric ID for the GBD round. + + Returns: + xarray.DataArray: + location, age-group, sex specific YLLs or YLL rates. Has at least ``location_id``, + ``age_group_id``, and ``sex_id`` as dimensions. + + Raises: + RuntimeError: + When mean age of death as calculated excedes the maximum. + """ + logger.debug("Entering `calculate_ylls` function") + age_group_years_start = ( + # TODO: Generally all callers to ages.get_ages are getting the same two columns. + age.get_ages(gbd_round_id)[["age_group_id", "age_group_years_start"]] + .set_index(["age_group_id"]) + .to_xarray()["age_group_years_start"] + ) + + # Calculate the mean age-of-death for all of the standard-FHS age-groups by adding the ax + # (mean-years-lived within the age-group, for those who die within it) and the age-start + # of each age group. + mean_age_of_death = (age_group_years_start + ax).round(2) + mean_age_of_death.name = "mean_age_of_death" + if (mean_age_of_death > MAX_MEAN_AGE_OF_DEATH).any(): + logger.warning( + "There aren't reference life expectancies available for age groups that start " + f"after {MAX_MEAN_AGE_OF_DEATH} years. Clipping data to {MAX_MEAN_AGE_OF_DEATH}" + ) + mean_age_of_death = mean_age_of_death.clip(max=MAX_MEAN_AGE_OF_DEATH) + mean_age_of_death_df = mean_age_of_death.to_dataframe().reset_index() + + # Match the mean age-of-death for each FHS age-group with an `age_group_years_start` value + # from the table of reference life-expectancies. The corresponding life expectancy is the + # selected reference life for that age-group. + selected_ref_lex_df = pd.merge( + mean_age_of_death_df, + reference_lex, + left_on="mean_age_of_death", + right_on="age_group_years_start", + ) + selected_ref_lex = selected_ref_lex_df.set_index(list(mean_age_of_death.dims)).to_xarray()[ + "Ex" + ] + + ylls = deaths * selected_ref_lex + + logger.debug("Leaving `calculate_ylls` function") + return ylls \ No newline at end of file diff --git a/gbd_2021/disease_burden_forecast_code/yll/yll_calculator.py b/gbd_2021/disease_burden_forecast_code/yll/yll_calculator.py new file mode 100644 index 0000000..ff4ae37 --- /dev/null +++ b/gbd_2021/disease_burden_forecast_code/yll/yll_calculator.py @@ -0,0 +1,119 @@ +r"""Calculates YLLs or YLL rates. + +Calculates YLLs or YLL rates from reference life expectancy, deaths, and +mean age of death (:math:`a_x`) + +If death counts are given then this pipeline calculates YLLs. 
However, if +death rates (mortality rates--deaths divided by mid-year-population) are given, +then this function calculates YLL rates. + +YLL calculation is + +.. math:: + + \mbox{YLL}_{las} = m_{las} * \bar{e}_{a} + +where + +#. :math:`{YLL}_{las}` is location, age-group, sex specific YLLs (or YLL rates) +#. :math:`m_{las}` is location, age-group, sex specific deaths (or rates) +#. :math:`\bar{e}_{a}` is reference life expectancy (at life starting year) + +:math:`\bar{e}_{a}` is described in `fhs_pipeline_yll.lib.ylls.calculate_ylls`. + +Example call: + +.. code-block:: bash + + fhs_pipeline_yll_console \ + --gbd-round-id 5 \ + --version \ + -v FILEPATH \ + --draws 1000 \ + --years 1990:2018:2100 \ + --past include \ + parallelize_by_cause + +""" +from typing import List + +from fhs_lib_data_transformation.lib.dimension_transformation import drop_single_point +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_database_interface.lib.query import reference_lex as ref_lex_module +from fhs_lib_file_interface.lib.version_metadata import FHSFileSpec +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr_scenario, save_xr_scenario +from fhs_lib_year_range_manager.lib.year_range import YearRange +from tiny_structured_logger.lib import fhs_logging + +from fhs_pipeline_yll.lib import yll as yll_module + +logger = fhs_logging.get_logger() + + +def one_cause_main( + acause: str, + draws: int, + gbd_round_id: int, + epoch: str, + versions: Versions, + years: YearRange, +) -> None: + """Compute YLL (rate) at cause level. + + Args: + acause (str): Acause to run over + draws (int): How many draws to save for the daly output. + gbd_round_id (int): What gbd_round_id that yld, yll and daly are saved under. + epoch (str): Which epoch to calculate YLLs for, either past or future + versions (Versions): A Versions object that keeps track of all the versions and their + respective data directories. + years (YearRange): Forecasting year range + """ + years_to_compute = determine_years_to_compute(years=years, epoch=epoch) + + death_file = FHSFileSpec( + versions.get(past_or_future=epoch, stage="death"), + f"{acause}.nc", + ) + mx = resample(open_xr_scenario(death_file).sel(year_id=years_to_compute), draws) + mx = drop_single_point(mx, "acause") + + lifetable_file = FHSFileSpec( + versions.get(past_or_future=epoch, stage="life_expectancy"), + "lifetable_ds.nc", + ) + ax = resample(open_xr_scenario(lifetable_file).ax.sel(year_id=years_to_compute), draws) + ax = drop_single_point(ax, "acause") + + reference_lex = ref_lex_module.get_reference_lex(gbd_round_id) # calls db + + yll = yll_module.calculate_ylls(mx, ax, reference_lex, gbd_round_id) + + yll_slice_file = FHSFileSpec( + versions.get(past_or_future=epoch, stage="yll"), + filename=f"{acause}.nc", + ) + + save_xr_scenario( + yll, + yll_slice_file, + metric="rate", + space="identity", + death=str(death_file.data_path()), + ax=str(lifetable_file.data_path()), + ) + + logger.info(f"Leaving `one_cause_yll` function for {acause}. 
DONE") + + +def determine_years_to_compute(years: YearRange, epoch: str) -> List[int]: + """Return the list of years to compute, based on the ``epoch`` specification.""" + if epoch == "future": + years_in_slice = years.forecast_years + elif epoch == "past": + years_in_slice = years.past_years + else: + raise RuntimeError("epoch must be `past` or `future`") + + return list(years_in_slice) \ No newline at end of file diff --git a/gbd_2021/fertility_forecast_code/education/arc_weight_selection.py b/gbd_2021/fertility_forecast_code/education/arc_weight_selection.py new file mode 100644 index 0000000..93af1fe --- /dev/null +++ b/gbd_2021/fertility_forecast_code/education/arc_weight_selection.py @@ -0,0 +1,452 @@ +"""Education weight-selection for the ARC method using predictive Validity + +>>> python arc_weight_selection.py \ + --reference-scenario mean \ + --transform logit \ + --diff-over-mean \ + --truncate \ + --truncate-quantiles 0.15 0.85 \ + --max-weight 3 \ + --weight-step-size 0.25 \ + --pv-version 20181026_just_nats_capped_pv \ + --past-version 20181003_subnats_included \ + --gbd-round-id 5 \ + --years 1990:2008:2017 \ + all-weights + +""" +import logging +import os +import subprocess +import sys + +import numpy as np +import xarray as xr + +from fbd_core import argparse +from fbd_core.file_interface import FBDPath, open_xr +from fbd_research.education.forecast_education import (TRANSFORMATIONS, + arc_forecast_education) + +REFERENCE_SCENARIO = 0 +LOGGER = logging.getLogger(__name__) + + +def calc_rmse(predicted, observed, years): + predicted = predicted.sel(year_id=years.forecast_years) + observed = observed.sel(year_id=years.forecast_years) + + rmse = np.sqrt(((predicted - observed) ** 2).mean()) + + return rmse + + +def one_weight_main(reference_scenario, transform, diff_over_mean, truncate, + truncate_quantiles, replace_with_mean, + use_past_uncertainty, weight_exp, past_version, pv_version, + years, gbd_round_id, test_mode, **kwargs): + """Predictive validity for one one weight of the range of weights""" + + LOGGER.debug("diff_over_mean:{}".format(diff_over_mean)) + LOGGER.debug("truncate:{}".format(truncate)) + LOGGER.debug("truncate_quantiles:{}".format(truncate_quantiles)) + LOGGER.debug("replace_with_mean:{}".format(replace_with_mean)) + LOGGER.debug("reference_scenario:{}".format(reference_scenario)) + LOGGER.debug("use_past_uncertainty:{}".format(use_past_uncertainty)) + + LOGGER.debug("Reading in the past") + past_path = FBDPath("".format()) # Path structure removed for security + past = open_xr(past_path / "education.nc").data + past = past.transpose(*list(past.coords)) + + if not use_past_uncertainty: + LOGGER.debug("Using past means for PV") + past = past.mean("draw") + else: + LOGGER.debug("Using past draws for PV") + + if test_mode: + past = past.sel( + age_group_id=past["age_group_id"].values[:5], + draw=past["draw"].values[:5], + location_id=past["location_id"].values[:5]) + else: + pass # Use full data set. 
+ + holdouts = past.sel(year_id=years.past_years) + observed = past.sel(year_id=years.forecast_years) + + LOGGER.debug("Calculating RMSE for {}".format(weight_exp)) + predicted = arc_forecast_education( + holdouts, gbd_round_id, transform, weight_exp, years, + reference_scenario, + diff_over_mean, truncate, truncate_quantiles, replace_with_mean) + rmse = calc_rmse(predicted.sel(scenario=REFERENCE_SCENARIO, drop=True), + observed, + years) + + rmse_da = xr.DataArray( + [rmse.values], [[weight_exp]], dims=["weight"]) + + pv_path = FBDPath("".format()) # Path structure removed for security + separate_weights_path = pv_path / "each_weight" + separate_weights_path.mkdir(parents=True, exist_ok=True) + rmse_da.to_netcdf( + str(separate_weights_path / "{}_rmse.nc".format(weight_exp))) + LOGGER.info("Saving RMSE for {}".format(weight_exp)) + + +def merge_main(max_weight, weight_step_size, pv_version, gbd_round_id, + **kwargs): + """Merge RMSE values for all of the tested weights into one dataarray.""" + LOGGER.debug("Calculating RMSE for all weights") + weights_to_test = np.arange(0, max_weight, weight_step_size) + + pv_path = FBDPath("".format()) # Path structure removed for security + separate_weights_path = pv_path / "each_weight" + + rmse_results = [] + for weight_exp in weights_to_test: + rmse_da = open_xr( + separate_weights_path / "{}_rmse.nc".format(weight_exp)).data + rmse_results.append(rmse_da) + rmse_results = xr.concat(rmse_results, dim="weight") + + rmse_results.to_netcdf(str(pv_path / "education_arc_weight_rmse.nc")) + LOGGER.info("RMSE is saved") + + +def parallelize_by_weight_main(reference_scenario, transform, diff_over_mean, + truncate, truncate_quantiles, replace_with_mean, + use_past_uncertainty, max_weight, + weight_step_size, past_version, pv_version, + years, test_mode, gbd_round_id, slots, + **kwargs): + """Parallelize the script so Predictive validity is run for all weights.""" + LOGGER.debug("diff_over_mean:{}".format(diff_over_mean)) + LOGGER.debug("truncate:{}".format(truncate)) + LOGGER.debug("truncate_quantiles:{}".format(truncate_quantiles)) + LOGGER.debug("replace_with_mean:{}".format(replace_with_mean)) + LOGGER.debug("reference_scenario:{}".format(reference_scenario)) + LOGGER.debug("use_past_uncertainty:{}".format(use_past_uncertainty)) + + script = os.path.abspath(os.path.realpath(__file__)) + LOGGER.debug(script) + + weights_to_test = np.arange(0, max_weight, weight_step_size) + num_weights = len(weights_to_test) + + if test_mode: + test_mode_call = "--test-mode" + else: + test_mode_call = "" # Use full data set + + if truncate: + truncate_call = "--truncate" + else: + truncate_call = "" + + if truncate_quantiles: + truncate_quantiles_call = "--truncate-quantiles {}".format( + " ".join([str(i) for i in truncate_quantiles])) + else: + truncate_quantiles_call = "" + + if replace_with_mean: + replace_with_mean_call = "--replace-with-mean" + else: + replace_with_mean_call = "" + + if use_past_uncertainty: + use_past_uncertainty_call = "--use-past-uncertainty" + else: + use_past_uncertainty_call = "" + + if diff_over_mean: + diff_over_mean_call = "--diff-over-mean" + else: + diff_over_mean_call = "" + + one_weight_qsub = ( + "qsub -N 'edu_arc_pv' " + "-pe multi_slot {slots} " + "-t 1:{N} " + "-b y {which_python} " + "{script} " + "--reference-scenario {reference_scenario} " + "--transform {transform} " + "--max-weight {max_weight} " + "--weight-step-size {weight_step_size} " + "--pv-version {pv_version} " + "--past-version {past_version} " + "--years {years} " + 
"--gbd-round-id {gbd_round_id} " + "{truncate_call} " + "{truncate_quantiles_call} " + "{replace_with_mean_call} " + "{use_past_uncertainty_call} " + "{diff_over_mean_call} " + "{test_mode_call} " + "one-weight" + ).format( + slots=slots, + N=num_weights, + which_python=sys.executable, + script=script, + reference_scenario=reference_scenario, + transform=transform, + max_weight=max_weight, + weight_step_size=weight_step_size, + pv_version=pv_version, + past_version=past_version, + years=years, + gbd_round_id=gbd_round_id, + truncate_call=truncate_call, + truncate_quantiles_call=truncate_quantiles_call, + replace_with_mean_call=replace_with_mean_call, + use_past_uncertainty_call=use_past_uncertainty_call, + diff_over_mean_call=diff_over_mean_call, + test_mode_call=test_mode_call) + + LOGGER.debug(one_weight_qsub) + one_weight_qsub_proc = subprocess.Popen( + one_weight_qsub, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + one_weight_qsub_out, _ = one_weight_qsub_proc.communicate() + if one_weight_qsub_proc.returncode: + one_weight_qsub_err_msg = "One-weight qsub failed." + LOGGER.error(one_weight_qsub_err_msg) + raise RuntimeError(one_weight_qsub_err_msg) + LOGGER.debug(one_weight_qsub_out) + + hold_jid = [ + int(word) + for word in str( + one_weight_qsub_out).split(".")[0].split(" ") + if word.isdigit() + ][0] + LOGGER.debug(hold_jid) + + merge_slots = 5 + merge_qsub = ( + "qsub -N 'edu_arc_pv' " + "-hold_jid {hold_jid} " + "-pe multi_slot {slots} " + "-b y {which_python} " + "{script} " + "--reference-scenario {reference_scenario} " + "--transform {transform} " + "--max-weight {max_weight} " + "--weight-step-size {weight_step_size} " + "--pv-version {pv_version} " + "--past-version {past_version} " + "--gbd-round-id {gbd_round_id} " + "--years {years} " + "{truncate_call} " + "{truncate_quantiles_call} " + "{replace_with_mean_call} " + "{use_past_uncertainty_call} " + "{diff_over_mean_call} " + "{test_mode_call} " + "merge" + ).format(hold_jid=hold_jid, + slots=merge_slots, + which_python=sys.executable, + script=script, + reference_scenario=reference_scenario, + transform=transform, + max_weight=max_weight, + weight_step_size=weight_step_size, + pv_version=pv_version, + past_version=past_version, + years=years, + gbd_round_id=gbd_round_id, + truncate_call=truncate_call, + truncate_quantiles_call=truncate_quantiles_call, + replace_with_mean_call=replace_with_mean_call, + use_past_uncertainty_call=use_past_uncertainty_call, + diff_over_mean_call=diff_over_mean_call, + test_mode_call=test_mode_call) + LOGGER.debug(merge_qsub) + merge_qsub_proc = subprocess.Popen( + merge_qsub, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + merge_qsub_out, _ = merge_qsub_proc.communicate() + if merge_qsub_proc.returncode: + merge_err_msg = "Merge qsub failed." 
+ LOGGER.error(merge_err_msg) + raise RuntimeError(merge_err_msg) + + +def all_weights_main(reference_scenario, diff_over_mean, truncate, + truncate_quantiles, replace_with_mean, + use_past_uncertainty, transform, max_weight, + weight_step_size, past_version, pv_version, years, + gbd_round_id, test_mode, **kwargs): + """Predictive validity for one weight of the range of weights at a time.""" + LOGGER.debug("diff_over_mean:{}".format(diff_over_mean)) + LOGGER.debug("truncate:{}".format(truncate)) + LOGGER.debug("truncate_quantiles:{}".format(truncate_quantiles)) + LOGGER.debug("replace_with_mean:{}".format(replace_with_mean)) + LOGGER.debug("reference_scenario:{}".format(reference_scenario)) + LOGGER.debug("use_past_uncertainty:{}".format(use_past_uncertainty)) + + LOGGER.debug("Reading in the past") + past_path = FBDPath("".format()) + past = open_xr(past_path / "education.nc").data + past = past.transpose(*list(past.coords)) + + if not use_past_uncertainty: + LOGGER.debug("Using past means for PV") + past = past.mean("draw") + else: + LOGGER.debug("Using past draws for PV") + + if test_mode: + past = past.sel( + age_group_id=past["age_group_id"].values[:5], + draw=past["draw"].values[:5], + location_id=past["location_id"].values[:5]) + else: + pass # Use full data set. + + holdouts = past.sel(year_id=years.past_years) + observed = past.sel(year_id=years.forecast_years) + + LOGGER.debug("Calculating RMSE for all weights") + weights_to_test = np.arange(0, max_weight, weight_step_size) + rmse_results = [] + for weight_exp in weights_to_test: + predicted = arc_forecast_education( + holdouts, gbd_round_id, transform, weight_exp, years, + reference_scenario, + diff_over_mean, truncate, truncate_quantiles, replace_with_mean) + rmse = calc_rmse(predicted.sel(scenario=REFERENCE_SCENARIO, drop=True), + observed, + years) + + rmse_da = xr.DataArray( + [rmse.values], [[weight_exp]], dims=["weight"]) + rmse_results.append(rmse_da) + rmse_results = xr.concat(rmse_results, dim="weight") + + pv_path = FBDPath("".format()) + pv_path.mkdir(parents=True, exist_ok=True) + rmse_results.to_netcdf(str(pv_path / "education_arc_weight_rmse.nc")) + LOGGER.info("RMSE is saved") + + +if __name__ == "__main__": + + def get_weight_from_jid(weight_arg, max_weight, weight_step_size): + sge_task_id = os.environ.get("SGE_TASK_ID") + if weight_arg is not None: + return weight_arg + elif sge_task_id: + weights_to_test = np.arange(0, max_weight, weight_step_size) + sge_task_id = int(sge_task_id) - 1 + LOGGER.debug("SGE_TASK_ID: {}".format(sge_task_id)) + + weight = weights_to_test[sge_task_id] + else: + err_msg = "`weight` and `SGE_TASK_ID` can not all be NoneType." 
+ LOGGER.error(err_msg) + raise ValueError(err_msg) + return weight + + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument( + "--reference-scenario", type=str, choices=["median", "mean"], + help=("If 'median' then the reference scenarios is made using the " + "weighted median of past annualized rate-of-change across all " + "past years, 'mean' then it is made using the weighted mean of " + "past annualized rate-of-change across all past years.")) + parser.add_argument( + "--diff-over-mean", action="store_true", + help=("If True, then take annual differences for means-of-draws, " + "instead of draws.")) + parser.add_argument( + "--truncate", action="store_true", + help=("If True, then truncates the dataarray over the given " + "dimensions.")) + parser.add_argument( + "--truncate-quantiles", type=float, nargs="+", + help="The tuple of two floats representing the quantiles to take.") + parser.add_argument( + "--replace-with-mean", action="store_true", + help=("If True and `truncate` is True, then replace values outside of " + "the upper and lower quantiles taken across `location_id` and " + "`year_id` and with the mean across `year_id`, if False, then " + "replace with the upper and lower bounds themselves.")) + parser.add_argument( + "--use-past-uncertainty", action="store_true", + help=("Use past draws in PV forecasts. **Note** if you're running " + "with `diff_over_mean=True`, then it doesn't make much sense " + "use past draws at all, in fact, it will only make the run " + "slower without adding any statistical significance.")) + parser.add_argument( + "--transform", type=str, + choices=list(sorted(TRANSFORMATIONS.keys())), + help="Space to transform education to for forecasting.") + parser.add_argument( + "--max-weight", type=float, required=True, + help="The maximum weight to try. Current convention is 3.") + parser.add_argument( + "--weight-step-size", type=float, required=True, + help="The step size of weights to try between 0 and `max_weight`." + "Current convention is 2.5.") + parser.add_argument( + "--pv-version", type=str, required=True, + help="Version of education weight selection") + parser.add_argument( + "--past-version", type=str, required=True, + help="Version of past education") + parser.add_argument( + "--gbd-round-id", type=int, required=True) + parser.add_argument( + "--test-mode", action="store_true", + help="Run on smaller test data set.") + parser.add_arg_years(required=True) + + subparsers = parser.add_subparsers() + + # Run script for only one weight + one_weight = subparsers.add_parser( + "one-weight", help=one_weight_main.__doc__) + one_weight.add_argument( + "--weight-exp", type=float, help="The weight to do PV for") + one_weight.set_defaults(func=one_weight_main) + + # Merge the outputs of the parallelized jobs + merge = subparsers.add_parser("merge", help=merge_main.__doc__) + merge.set_defaults(func=merge_main) + + # Run script for all weights in parallel. + parallelize_by_weight = subparsers.add_parser( + "parallelize-by-weight", help=parallelize_by_weight_main.__doc__) + parallelize_by_weight.add_argument( + "--slots", type=int, default=35, + help="Number of slots required for each job.") + parallelize_by_weight.set_defaults(func=parallelize_by_weight_main) + + # Run script all weights. 
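+ # (Unlike the `parallelize-by-weight` subcommand, `all-weights` loops over every candidate
+ # weight serially in a single process and writes one combined RMSE file.)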
+ all_weights = subparsers.add_parser( + "all-weights", help=all_weights_main.__doc__) + all_weights.set_defaults(func=all_weights_main) + + args = parser.parse_args() + + if args.func == one_weight_main: + args.weight_exp = get_weight_from_jid( + args.weight_exp, args.max_weight, args.weight_step_size) + else: + # `args.weight_exp` only needs to be defined for `one_weight_main`. + pass + + args.func(**args.__dict__) diff --git a/gbd_2021/fertility_forecast_code/education/cohort_correction.py b/gbd_2021/fertility_forecast_code/education/cohort_correction.py new file mode 100644 index 0000000..4f3bfa8 --- /dev/null +++ b/gbd_2021/fertility_forecast_code/education/cohort_correction.py @@ -0,0 +1,482 @@ +"""This script performs cohort correction on education forecasts created by +the forecast_education.py. The correction is needed to avoid drops in years +of education within a cohort. + + Example: + .. code:: bash + + python cohort_correction.py \ + --gbd-round-id 5 \ + --forecast-version 20190610_edu_sdg_scenario \ + --output-version 20190610_edu_sdg_scenario_cohort_corrected \ + --years 1950:2018:2140 \ + --draws 1000 \ + parallelize-by-draw \ + --draw-memory 6 \ + --merge-memory 40 + +""" + +import glob +import logging +import os +import subprocess +import sys +from pathlib import Path + +import pandas as pd +import xarray as xr + +from fbd_core import argparse, db +from fbd_core.file_interface import FBDPath, open_xr, save_xr + +LOGGER = logging.getLogger(__name__) + +AGE_25_GROUP_ID = 10 + +MODELED_SEX_IDS = (1, 2) +MODELED_AGE_GROUP_IDS = tuple(range(6, 21)) + (30, 31, 32, 235) +INDEX_COLS = ('location_id', 'year_id', 'age_group_id', 'sex_id', 'scenario') + + +def get_cohort_info_from_age_year(age_group_ids, years): + """Calculates the cohort year for each age_group - year pair. + In the context of education, a cohort year is the year the cohort + turned 5 years old and not the birth year. + + For a given age_group - year pair, its cohort year is calculated + in the following manner: + + cohort_year = (year - age_group_lower_bound) + 5 + + The 5 is added at the end to force the cohort to start at age 5 + and not 0. + + For example: + Let the age group be 10 and the year be 1990. Then this group + will belong to the (1990 - 25) + 5 = **1970** cohort. + + Args: + age_group_ids (list(int)): + List of age group ids. + years (YearRange): + The past and forecasted years. + Returns: + cohort_age_df (pandas.DataFrame): + Dataframe with (age&year) to cohort mapping. + """ + + LOGGER.info("In get_cohort_info_from_age_year") + age_meta_df = db.get_ages()[[ + 'age_group_id', 'age_group_years_start']] + + all_age_year_df = pd.MultiIndex.from_product( + [age_group_ids, years.years], + names=['age_group_id', 'year_id'] + ).to_frame().reset_index(drop=True) + + cohort_age_df = all_age_year_df.merge( + age_meta_df, on='age_group_id').astype({'age_group_years_start': int}) + + # Finding the cohort from year and lower bound of age_group as described + # above. + cohort_age_df['cohort_year'] = ( + cohort_age_df['year_id'] - ( + cohort_age_df['age_group_years_start'] - 5)) + + # All cohorts don't need correction.Only those that extend into the future. 
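+ # Only flagged cohorts are passed through `apply_cohort_correction` downstream; cohorts with
+ # `need_correction == False` are re-attached unchanged in `get_corrected_da`.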
+ correction_cohorts_years = cohort_age_df[ + cohort_age_df['year_id'] > years.past_end]['cohort_year'].unique() + cohort_age_df['need_correction'] = \ + cohort_age_df['cohort_year'].isin(correction_cohorts_years) + + cohort_age_df.drop('age_group_years_start', axis=1, inplace=True) + + return cohort_age_df + + +def apply_cohort_correction(cohort_df, years): + """Perform correction on the cohort data. The purpose of the correction is + to prevent forecasted education years from decreasing within a cohort and + also from increasing after the age of 25. + + The correction id performed only on the forecasted years in a cohort. + The past data is untouched. The correction involves the following checks + and actions: + + 1. Keep track of the running maximum value. + + 2. If the current age is less than 25 and the value is less than the + running max, then replace it with the running max. + For example, consider the cohort of 2000. The cohort turned 15 in 2015 + and 20 in 2020. If the value of 2020 is less than the value at 15 then + we replace the value at 2020. This is where tracking the running max + proves useful. + + 3. If the age is greater than 25 and the cohort turned 25 in the future, + then replace the current value with the value at age 25. + For example, consider the cohort of 2000. The cohort turned 25 in 2025 + and 30 in 2030. The value at 2030 is simply replaced with the value at 25. + + 4. If the age is greater than 25 but the cohort turned 25 in the past, + then replace the current value with the latest value of the past. + For example, consider the cohort of 1980. The cohort turned 25 in the + year 2005 and last year in the past for this cohort is 2015. The forecasts + in this cohort will be replaced with the value at 2015. + + Args: + cohort_df (pandas.Dataframe): + A dataframe containing the data for a single cohort. + years (YearRange): + The past and forecasted years. + + Returns: + cohort_df (pandas.DataFrame): + Dataframe with corrected cohort data. + """ + + max_value = -1 + cohort_df['value'] = cohort_df['value'] + cohort_year = cohort_df['cohort_year'].values[0] + + age_25_year = cohort_df.query( + 'age_group_id == @AGE_25_GROUP_ID')['year_id'].values + # If cohort started in the past, track the value at age 25. + if cohort_year < years.forecast_start: + constant_forecast_val = cohort_df.loc[ + cohort_df['year_id'] < years.forecast_start, 'value'].values[-1] + + for idx, row in cohort_df.iterrows(): + max_value = max(max_value, row['value']) + + # Only consider forecasts for adults for correction. + if row['year_id'] < years.forecast_start or row["age_group_id"] <= 8: + continue + + # Update value if younger than or equal to 25, + if row['age_group_id'] <= AGE_25_GROUP_ID: + # but only if val at previous age_group_id of cohort is + # larger than current val + prev_age = row["age_group_id"] - 1 + val_previous_age_is_max = ( + cohort_df[cohort_df.age_group_id==prev_age].value.iloc[0] == + max_value) + if val_previous_age_is_max: + cohort_df.loc[idx, 'value'] = max_value + + # If older and turned 25 in the future, then replace with the + # value at age 25. + elif age_25_year > years.past_end: + age_25_val = cohort_df.loc[ + cohort_df['age_group_id'] == AGE_25_GROUP_ID, 'value'].values[0] + cohort_df.loc[idx, 'value'] = age_25_val + + # If older but had turned 25 in the past, then replace value with + # value from the last past year of the cohort. 
+ else: + cohort_df.loc[idx, 'value'] = constant_forecast_val + + return cohort_df + + +def get_corrected_da(uncorrected_draw_da, cohort_age_df, years): + """Accepts the uncorrected dataarray, converts it into cohort space, + extracts cohorts that need correction and applies cohort correction. + + Args: + uncorrected_draw_da (xr.DataArray): + Dataarray with uncorrected education cohorts. + cohort_age_df (pd.DataFrame): + Dataframe that contains the ages and year_ids associated with + each cohort years. + years (YearRange): + The past and forecasted years. + + Returns: + corrected_da (xr.DataArray): + Dataarray that contains the cohort corrected forecasts. + """ + + uncorrected_draw_df = uncorrected_draw_da.rename("value") \ + .to_dataframe() \ + .reset_index() + + uncorrected_draw_df.drop_duplicates(inplace=True) + + uncorrected_draw_df = uncorrected_draw_df.merge( + cohort_age_df, on=['year_id', 'age_group_id']) + + uncorrected_draw_df = uncorrected_draw_df.sort_values( + by=['location_id', 'sex_id', 'scenario', 'cohort_year', + 'year_id', 'age_group_id']).reset_index(drop=True) + + # Extracting cohorts that need correction. + cohorts_to_correct_df = uncorrected_draw_df.query( + "need_correction==True").copy(deep=True).reset_index(drop=True) + + # Applying correction + corrected_cohorts = cohorts_to_correct_df.groupby([ + 'location_id', 'sex_id', 'scenario', 'cohort_year'] + ).apply(apply_cohort_correction, years) + + # Re-combine with unmodified cohorts. + unmodified_cohorts_df = uncorrected_draw_df.query( + "need_correction==False").copy(deep=True).reset_index(drop=True) + combined_df = pd.concat( + [unmodified_cohorts_df, corrected_cohorts], ignore_index=True) + + # Convert back to dataarray + combined_df = combined_df.sort_values(by=list(INDEX_COLS)) + corrected_da = combined_df.drop( + ['cohort_year', 'need_correction'], axis=1 + ).set_index(list(INDEX_COLS))['value'].to_xarray() + + return corrected_da + + +def one_draw_main(gbd_round_id, years, draw, forecast_version, output_version): + """Driver function that handles the education cohort correction for a + single draw. + + Args: + gbd_round_id (int): + The gbd round id. + years (YearRange): + The past and forecasted years. + draw (int): + The draw number to perform the correction on. + forecast_version (str): + Forecast version of education. + output_version (str): + Cohort corrected version. + """ + LOGGER.info("Applying cohort correction to draw: {}".format(draw)) + input_dir = FBDPath("".format()) # Path removed for security reasons + uncorrected_da = open_xr(input_dir / "education.nc").data + # subset to national and subnational location ids + location_table = db.get_location_set(gbd_round_id) + + # modeling subnational estimates + modeled_location_ids = list(location_table["location_id"].unique()) + avail_sex_ids = [ + sex for sex in uncorrected_da["sex_id"].values + if sex in MODELED_SEX_IDS] + + # Age groups 2,3,4 and 5 gets filtered out here. Will add them back later. + avail_age_group_ids = [ + age for age in uncorrected_da["age_group_id"].values + if age in MODELED_AGE_GROUP_IDS] + + uncorrected_draw_da = uncorrected_da.sel( + sex_id=avail_sex_ids, + age_group_id=avail_age_group_ids, + location_id=modeled_location_ids + ).sel(draw=draw, drop=True) + + # Create cohort information from age groups and year ids. 
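+ # (cohort_year = year_id - age_group_years_start + 5, i.e. the year each cohort turned 5;
+ # see `get_cohort_info_from_age_year` above.)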
+ + cohort_age_df = get_cohort_info_from_age_year(avail_age_group_ids, years) + + corrected_da = get_corrected_da( + uncorrected_draw_da, cohort_age_df, years) + + # Combine with dropped age groups + dropped_age_ids = [ + age for age in uncorrected_da["age_group_id"].values + if age not in MODELED_AGE_GROUP_IDS] + + dropped_age_da = uncorrected_da.sel( + sex_id=avail_sex_ids, + age_group_id=dropped_age_ids, + location_id=modeled_location_ids).sel(draw=draw, drop=True) + + combined_da = xr.concat([dropped_age_da, corrected_da], dim='age_group_id') + combined_da['draw'] = draw + op_dir = FBDPath("".format()) + + save_xr(combined_da, op_dir / "corrected_edu_draw{}.nc".format(draw), + metric="rate", space="identity") + + +def merge_main(output_version, gbd_round_id): + """Combine all of the netcdf files generated by one_draw_main and save + the combined file as `education.nc` in the same directory. + + Args: + output_version (str): + Cohort corrected version. + gbd_round_id (int): + The gbd round id. + """ + input_dir = FBDPath("".format()) # Path removed for security reasons + file_names = list(input_dir.glob('corrected_edu_draw*.nc')) + edu_ds = xr.open_mfdataset(file_names, concat_dim="draw") + edu_da = list(edu_ds.data_vars.values())[0] + + LOGGER.info("Saving corrected education.") + edu_da.name = "value" + edu_path = input_dir / "education.nc" + edu_da.to_netcdf(str(edu_path)) + + +def parallelize_main( + draw_memory, merge_memory, years, draws, forecast_version, + output_version, + gbd_round_id): + """Subprogram for submitting an array job to perform the cohort + corrections by draw, and then a qsub for merging the draws together. + + Args: + draw_memory (int): + Cohort corrected version. + merge_memory (int): + The gbd round id. + years (YearRange): + The past and forecasted years. + draws (int): + Number of draws in the data. + forecast_version (str): + Forecast version of education. + output_version (str): + Cohort corrected version. + gbd_round_id (int): + The gbd round id. + + Raises: + (RuntimeError): + + * If qsub failed while submitting the array job. + * If qsub failed while submitting the merge job.. + """ + exec = sys.executable + script = Path(__file__).absolute() + + qsub_template = ( + "qsub -N 'cohort-correct' " + f"-l m_mem_free={draw_memory}G " + "-l fthread=1 " + "-l archive " + "-l h_rt=03:00:00 " + "-q all.q " + "-P proj_forecasting " + f"-t 1:{draws} " + f"-tc 200 " + f"-b y {exec} " + f"{script} " + f"--forecast-version {forecast_version} " + f"--output-version {output_version} " + f"--years {years} " + f"--draws {draws} " + f"--gbd-round-id {gbd_round_id} " + "one-draw-correction" + ) + + LOGGER.info(qsub_template) + qsub_proc = subprocess.Popen( + qsub_template, shell=True, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + qsub_out, _ = qsub_proc.communicate() + if qsub_proc.returncode: + err_msg = "Cohort correction qsub failed." + LOGGER.error(err_msg) + raise RuntimeError(err_msg) + + try: + hold_jid = int((qsub_out.split()[2]).split(b".")[0]) + except: + err_msg = "Error getting the hold job id." 
+ LOGGER.error(err_msg) + raise RuntimeError(err_msg) + + merge_qsub = ( + "qsub -N 'cohort-correct-merge' " + f"-hold_jid {hold_jid} " + f"-l m_mem_free={merge_memory}G " + "-l fthread=1 " + "-l archive " + "-l h_rt=01:00:00 " + "-q all.q " + "-P proj_forecasting " + f"-b y {exec} " + f"{script} " + f"--forecast-version {forecast_version} " + f"--output-version {output_version} " + f"--years {years} " + f"--draws {draws} " + f"--gbd-round-id {gbd_round_id} " + "merge-draws" + ) + + LOGGER.info(merge_qsub) + merge_qsub_proc = subprocess.Popen( + merge_qsub, shell=True, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + merge_qsub_proc.wait() + if merge_qsub_proc.returncode: + merge_err_msg = "Merge qsub failed." + LOGGER.error(merge_err_msg) + raise RuntimeError(merge_err_msg) + + +if __name__ == "__main__": + + def get_draw(draw): + sge_task_id = os.environ.get("SGE_TASK_ID") + if draw: + return draw + elif sge_task_id: + LOGGER.debug("SGE_TASK_ID: {}".format(sge_task_id)) + return int(sge_task_id) - 1 + else: + err_msg = "rei and SGE_TASK_ID can not both be NoneType." + LOGGER.error(err_msg) + raise ValueError(err_msg) + + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument( + "--gbd-round-id", type=int, required=True, help="GBD round of data") + parser.add_argument( + "--forecast-version", type=str, required=True, + help="The version of education forecasts to pull.") + parser.add_argument( + "--output-version", type=str, required=True, + help="The version of cohort corrected forecasts.") + parser.add_arg_years(required=True) + parser.add_arg_draws(required=True) + + subparsers = parser.add_subparsers() + + one_draw = subparsers.add_parser( + "one-draw-correction", help=one_draw_main.__doc__) + one_draw.add_argument( + "--draw", type=int, help="Draw to apply the cohort correction.") + one_draw = one_draw.set_defaults(func=one_draw_main) + + parallelize_by_draw = subparsers.add_parser( + "parallelize-by-draw", help=parallelize_main.__doc__) + parallelize_by_draw.add_argument( + "--draw-memory", type=int, required=True, help="Memory for each job.") + parallelize_by_draw.add_argument( + "--merge-memory", type=int, required=True, help="Memory for merge job.") + parallelize_by_draw.set_defaults(func=parallelize_main) + + merge_draws = subparsers.add_parser("merge-draws", + help=merge_main.__doc__) + merge_draws.set_defaults(func=merge_main) + + args = parser.parse_args() + + if args.func == parallelize_main: + parallelize_main( + args.draw_memory, args.merge_memory, args.years, args.draws, + args.forecast_version, args.output_version, args.gbd_round_id) + elif args.func == merge_main: + merge_main(args.output_version, args.gbd_round_id) + else: + args.draw = get_draw(args.draw) + one_draw_main( + args.gbd_round_id, args.years, args.draw, args.forecast_version, + args.output_version) diff --git a/gbd_2021/fertility_forecast_code/education/covid/apply_shocks.py b/gbd_2021/fertility_forecast_code/education/covid/apply_shocks.py new file mode 100644 index 0000000..961122c --- /dev/null +++ b/gbd_2021/fertility_forecast_code/education/covid/apply_shocks.py @@ -0,0 +1,676 @@ +""" +A script that applies education shocks due to COVID school closures to educational attainment. All +calculations are done on the reference scenario at mean level. The means are used to shift education +draws for all scenarios at the end. The output data will end 5 years earlier than the input +education data due to age group interpolation. 
(E.g. if the `years` arg is "1990:2020:2150", the +output data will have years 2020-2145.) Age group interpolation is not actually necessary to produce +shocked education because the period data will look the same without it, but it is useful for +making sensible cohort plots for vetting. + +Steps for applying COVID shocks to education: + 1. Age-split education from 5-year ages to single-year ages. (Interpolate ages < 25.) + 2. Convert single-year education to cohort space. + 3. Scale shocks such that education lost cannot be greater than the increase in educational + attainment that year. + 4. Apply broadband terms to shocks as a proxy for online education. + 5. Shift shock age/years so that all shocks occurring before age 15 are introduced at age 15. + 6. Convert COVID shocks to cohort space. + 7. Sum shocks occurring before age 15 and shocks at age 15, since that is the youngest age in + the education data. + 8. Expand dimensions of shocks to match education data. + 9. Compute cumulative shock values. + 10. Subtract cumulative shocks from education cohorts. + 11. Convert education data back to 5-year age groups in period space. + 12. Shift period education draws by the difference between shocked and unshocked education means + in period space. + +See comments in `apply_education_disruptions` to match above steps with specific functions and code +blocks. + +.. code:: bash + + python apply_shocks.py \ + --gbd-round-id 11 \ + --years 1990:2020:2150 \ + --draws 500 \ + --edu-version unshocked_edu_vers \ + --shock-version shock_vers \ + --broadband-version broadband_vers \ + --output-version-tag some_tag_append_to_output_version_name (optional) + +.. TODO:: + - Add type hints to all functions. + - Make constants.py to store global constants. +""" + + +import click +import logging +import numpy as np +import pandas as pd +import warnings +import xarray as xr + +from fhs_lib_database_interface.lib.query.age import get_ages +from fhs_lib_database_interface.lib.query.location import get_location_set +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_data_transformation.lib.processing import log_with_offset, invlog_with_offset +from fhs_lib_data_transformation.lib.resample import resample +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr, save_xr +from fhs_lib_year_range_manager.lib.year_range import YearRange + +logging.getLogger("console").setLevel(logging.INFO) +LOGGER = logging.getLogger(__name__) + + +# TODO: These will go to a constants.py since ETL code depends on some of the same variables. 
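+# (For reference: GBD age_group_ids 8-20 are the 5-year groups 15-19 through 75-79, 30/31/32 cover
+# 80-84, 85-89 and 90-94, and 235 is the 95+ terminal group.)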
+MODELED_AGES = list(range(8, 21)) + [30, 31, 32] + [235] ## standard GBD 5-year age groups +UNDER_25_AGES = [8, 9] ## ages to interpolate +COHORT_AGE_START = 5 +COHORT_AGE_START_ID = 53 ## age group ID for age 5 +YOUNGEST_AGE_IN_DATA = 15 +AGE_15_ID = 63 +ALL_SINGLE_YEAR_AGE_IDS = list(range(53, 143)) ## used to generate cohort metadata +SINGLE_YEAR_15_TO_19 = np.array(range(63, 68)) ## used to make single-year age IDs for age-splitting +TERMINAL_AGE_ID = 235 +TERMINAL_AGE_START = 95 +SHOCK_YEARS = [2020, 2021, 2022] +AGE_GROUP_WIDTH = 5 # years +PERIOD_DF_COLUMNS = [ + "location_id", + "year_id", + "sex_id", + "age_group_id", + "value", + "shocked_val", + "shocked_val_broadband_corrected" +] +PLOT_DF_ID_COLUMNS = [ + "location_id", + "ihme_loc_id", + "location_ascii_name", + "location_level", + "year_id", + "sex_id", + "age_group_id", + "age_group_name" +] +PLOT_DF_RENAME_DICT = { + "value": "unshocked", + "shocked_val": "shocked", + "shocked_val_broadband_corrected": "shocked_broadband_corr" +} +PLOT_DF_VALUE_COLUMNS = list(PLOT_DF_RENAME_DICT.values()) + + +def get_cohort_info_from_age_year(age_group_ids, years, age_metadata): + """Calculates the cohort year for each age_group - year pair. + In the context of education, a cohort year is the year the cohort + turned 5 years old and not the birth year. + + For a given age_group - year pair, its cohort year is calculated + in the following manner: + + cohort_year = (year - age_group_lower_bound) + 5 + + The 5 is added at the end to force the cohort to start at age 5 + and not 0. + + For example: + Let the age group be 10 and the year be 1990. Then this group + will belong to the (1990 - 25) + 5 = **1970** cohort. + + Args: + age_group_ids (list[int]): + List of age group ids. + years (YearRange): + The past and forecasted years. + Returns: + cohort_age_df (pandas.DataFrame): + Dataframe with (age & year) to cohort mapping. + """ + + age_metadata = age_metadata[['age_group_id', 'age_group_years_start']] + + all_age_year_df = pd.MultiIndex.from_product( + [age_group_ids, years.forecast_years], names=['age_group_id', 'year_id'] + ).to_frame().reset_index(drop=True) + + cohort_age_df = all_age_year_df.merge( + age_metadata, on='age_group_id' + ).astype({'age_group_years_start': int}) + + # Finding the cohort from year and lower bound of age_group as described + # above. + cohort_age_df['cohort_year'] = ( + cohort_age_df.year_id - cohort_age_df.age_group_years_start + COHORT_AGE_START + ) + + return cohort_age_df.drop(columns="age_group_years_start") + + +def read_in_data(gbd_round_id, draws, edu_version, shock_version, broadband_version): + """Reads in all inputs from the filesystem, and returns education draws, mean reference-only + education, COVID shock proportions, and broadband correction terms. + + Args: + gbd_round_id (int): + The GBD round ID. + draws (int): + The number of draws desired. This can affect the mean values since resampling occurs + before taking the mean. + edu_version (str): + The education version to pull. + shock_version (str): + The COVID shock version to pull. + broadband_version (str): + The broadband version to pull. + Returns: + education_resampled (xr.DataArray): + Education draws. + reference_mean_edu (xr.DataArray): + Reference-only education means. + shocks (xr.DataArray): + The location/year-specific proportion of school closures due to COVID. 
+ broadband (pd.DataFrame): + Location-specific broadband access scaled between 0 and 0.5, so that the maximum amount + of educational attainment that can be "protected" from COVID shocks via online education + is half of education lost. + """ + + input_edu_path = FBDPath(f"") + shock_path = FBDPath(f"") + broadband_path = FBDPath(f"") + + education = open_xr(input_edu_path / "education.nc").data.sel( + age_group_id=MODELED_AGES + ) + education_resampled = resample(education, draws) + LOGGER.info("Taking mean of education draws.") + reference_mean_edu = education_resampled.sel(scenario=0, drop=True).mean("draw") + + shocks = open_xr(shock_path / "props.nc").data.sel(vac_status="all", drop=True) + broadband = pd.read_csv(broadband_path / "broadband.csv") + + return education_resampled, reference_mean_edu, shocks, broadband + +def compute_average_annual_change(education, age_group_id, year_id): + """Computes the average annual absolute change within a cohort between two adjacent 5-year age + groups (i.e. the annual absolute change within a cohort between t and t+5). This provides a + rough estimate of how much education is attained each year in the lower age group. + + Args: + education (xr.DataArray): + Education means. + age_group_id (int): + The age group ID we want to calculate absolute change for. + year_id (list[int]): + The years for which we want to calculate absolute change.. + Returns: + annual_change (xr.DataArray): + Average annual absolute change over the cohort in the specified age group/years. + """ + + # The age group and years are shifted, so this is how much education a cohort is projected + # to have at `age_group_id + 1` and `year_id + 5`. + # (E.g. What 15 to 19 year olds will look like in 5 years at age 20-24.) + edu_cohort_shifted = education.shift(age_group_id=-1, year_id=-AGE_GROUP_WIDTH).sel( + age_group_id=age_group_id, year_id=year_id, drop=True + ) + # This allows us to compute the average annual absolute change over 5 years as the cohort ages. + annual_change = ( + edu_cohort_shifted - + education.sel(age_group_id=age_group_id, year_id=year_id, drop=True) + ) / AGE_GROUP_WIDTH + + return annual_change + + +def age_split_education(education, years): + """Age splits education from 5-year age groups to single-year age groups. Single-year ages below + 25 are interpolated at a constant value of annual absolute change, which is estimated via the + absolute change in the cohort over 5 years between adjacent 5-year age groups. + + Note: The script output data will end 5 years earlier than the input education data due to age + group interpolation. (E.g. if the `years` arg is "1990:2020:2150", the output data will have + years 2020-2145.) Age group interpolation is not actually necessary to produce shocked education + because the period data will look the same without it, but it is useful for making sensible + cohort plots for vetting.) + + Args: + education (xr.DataArray): + Education means with standard GBD age groups (5-year age groups). + years (list[int]): + The years of the data to be split. + Returns: + edu_age_split_df (pd.DataFrame): + Education split into single-year age groups. 
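+
+ For example, if a cohort averages 10.0 years of education in age group 9 (ages 20-24) and is
+ projected to average 11.0 years five years later in age group 10 (ages 25-29), the estimated
+ annual increment is (11.0 - 10.0) / 5 = 0.2, so single-year ages 20-24 are filled in as
+ 9.6, 9.8, 10.0, 10.2 and 10.4 years.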
+ """ + + edu_age_split_list = [] + for i, age in enumerate(MODELED_AGES[:-1]): + new_ages = SINGLE_YEAR_15_TO_19 + (i * AGE_GROUP_WIDTH) + + age_da = education.sel(age_group_id=age, year_id=years.forecast_years) + age_split = expand_dimensions(age_da, age_group_id=new_ages) + + if age in UNDER_25_AGES: + LOGGER.info(f"Interpolating single year ages for age-group-id {age}.") + + age_scalars = xr.DataArray( + [-2, -1, 0, 1, 2], dims=["age_group_id"], coords=dict(age_group_id=new_ages) + ) + + annual_change = compute_average_annual_change( + education, age_group_id=age, year_id=years.forecast_years + ) + + age_addends = ( + expand_dimensions(annual_change, age_group_id=new_ages) * age_scalars + ).rename("age_addend") + + age_split = age_split + age_addends + + edu_age_split_list.append(age_split) + + edu_age_split = xr.concat( + (edu_age_split_list + + [education.sel(age_group_id=[TERMINAL_AGE_ID], year_id=years.forecast_years)]), + dim="age_group_id" + ).rename("value") + + edu_age_split_df = edu_age_split.to_dataframe().reset_index() + + return edu_age_split_df + + +def apply_broadband_terms(shocks, broadband): + """UNESCO broadband access data is used as a proxy for the availability of online education. The + proportion of broadband access in each location is scaled to the shock proportion, and then + subtracted from the shock to create an "online education corrected shock". The broadband terms + have already been scaled between 0 and 0.5 so that students cannot regain more than half the + education they are predicted to lose from COVID school closures (per CJLM). + + Args: + shocks (xr.DataArray): + Proportion of days that school closed due to COVID in 2020, 2021, and 2022 out of 365 + days, pre-scaled + broadband (pd.DataFrame): + Proportion of broadband access per capita scaled between 0 and 0.5. + Returns: + shocks_df (pd.DataFrame): + DataFrame with both uncorrected and broadband-corrected shocks. + """ + + LOGGER.info("Applying broadband correction.") + + shocks_df = shocks.rename("shock").to_dataframe().reset_index() + shocks_df = shocks_df.merge(broadband) + shocks_df["broadband_term"] = shocks_df.shock * shocks_df.broadband_term + shocks_df["shock_minus_broadband"] = shocks_df.shock - shocks_df.broadband_term + shocks_df.loc[shocks_df.shock_minus_broadband < 0, "shock_minus_broadband"] = 0 + shocks_df.drop(columns="broadband_term", inplace=True) + + return shocks_df + + +def shift_shock_age_years(shocks_df): + """Education data is only forecasted for ages 15+, but many shocks occur before age 15. Thus, + any shocks that occur before age 15 are age/year-shifted such that they are introduced in the + future when the shocked cohort reaches age 15. + + Args: + shocks_df (pd.DataFrame): + DataFrame with both uncorrected and broadband-corrected shocks. + Returns: + shocks_df (pd.DataFrame): + Shocks age/year-shifted such that they are introduced in the future when the shocked + cohort reaches age 15. 
+ """ + + shocks_df["age"] = _get_age_from_age_id(shocks_df.age_group_id) + shocks_df["years_till_15"] = YOUNGEST_AGE_IN_DATA - shocks_df.age + shocks_df.loc[shocks_df.years_till_15 < 0, "years_till_15"] = 0 + shocks_df["shock_year"] = shocks_df.year_id + shocks_df["year_id"] = shocks_df.year_id + shocks_df.years_till_15 + shocks_df.loc[shocks_df.age < YOUNGEST_AGE_IN_DATA, "age_group_id"] = AGE_15_ID + shocks_df.loc[shocks_df.age < YOUNGEST_AGE_IN_DATA, "age"] = YOUNGEST_AGE_IN_DATA + shocks_df.drop(columns="years_till_15", inplace=True) + + return shocks_df + + +def expand_shock_cohorts(shock_cohorts, cohort_age_df): + """For the application of the COVID shocks to educational attainment to make sense, the shocks + must be cumulative and have the same dimensions as the education data. This function expands + the dimensions of the shock data such that the data can be applied to educational attainment. + + Args: + shocks_cohorts (pd.DataFrame): + Shocks in cohort space. + cohort_age_df (pandas.DataFrame): + Dataframe with (age & year) to cohort mapping. + Returns: + shock_cohorts_expanded (pd.DataFrame): + Shock cohorts with dimensions expanded to match education cohorts. + """ + + cohort_age_sex_location_df = expand_dimensions( + cohort_age_df.set_index(["age_group_id", "year_id"]).to_xarray()["cohort_year"], + location_id=shock_cohorts.location_id.unique(), sex_id=[1, 2] + ).to_dataframe().reset_index() + + shock_cohorts_expanded = shock_cohorts.merge( + cohort_age_sex_location_df, + on=["location_id", "year_id", "age_group_id", "sex_id", "cohort_year"], + how="outer" + ) + + shock_cohorts_expanded.loc[shock_cohorts_expanded.age.isnull(), "age"] = _get_age_from_age_id( + shock_cohorts_expanded.age_group_id + ) + shock_cohorts_expanded.loc[ + shock_cohorts_expanded.age_group_id == TERMINAL_AGE_ID, "age" + ] = TERMINAL_AGE_START + shock_cohorts_expanded.loc[shock_cohorts_expanded.shock.isnull(), "shock"] = 0 + shock_cohorts_expanded.loc[ + shock_cohorts_expanded.shock_minus_broadband.isnull(), "shock_minus_broadband" + ] = 0 + + return shock_cohorts_expanded + + +def make_cumulative_shocks(shock_cohorts): + """Calculates cumulative shocks over the cohort to apply to the cohort education data. + + Args: + shocks_cohorts (pd.DataFrame): + Shock cohorts with dimensions expanded to match education cohorts. + Returns: + cumulative_shocks (pd.DataFrame): + Shock cohorts with dimensions expanded to match education cohorts with cumulative + shocks. + """ + + LOGGER.info("Making cumulative shocks.") + + cumulative_shocks = shock_cohorts.sort_values(["location_id", "year_id"]) + + cumulative_shocks["cumulative_shock"] = cumulative_shocks.groupby( + ["location_id", "sex_id", "cohort_year"] + ).cumsum().shock + + cumulative_shocks["cumulative_shock_minus_broadband"] = cumulative_shocks.groupby( + ["location_id", "sex_id", "cohort_year"] + ).cumsum().shock_minus_broadband + + return cumulative_shocks + + +def shock_education(shock_cohorts, edu_cohorts): + """Applies cumulative shocks to education cohort data. + + Args: + shocks_cohorts (pd.DataFrame): + Shock cohorts with dimensions expanded to match education cohorts with cumulative + shocks. + Returns: + edu_shocked (pd.DataFrame): + Cohort data with columns for education, shocked education, and shocks. 
+ """ + + LOGGER.info("Shocking education.") + + edu_shocked = edu_cohorts.merge(shock_cohorts.query("age_group_id >= @AGE_15_ID"), how="outer") + + edu_shocked["shocked_val"] = edu_shocked.value - edu_shocked.cumulative_shock + edu_shocked["shocked_val_broadband_corrected"] = ( + edu_shocked.value - edu_shocked.cumulative_shock_minus_broadband + ) + + return edu_shocked + +def convert_to_period_space_gbd_ages(edu_cohorts_shocked): + """Most uses of educational attainment require period educational attainment with 5-year GBD + age groups. This function converts single-year age education cohorts to 5-year age group + education in period space. + + Args: + edu_cohorts_shocked (pd.DataFrame): + Cohort data with columns for education, shocked education, and shocks. + Returns: + period_df (pd.DataFrame): + Period-space education (shocked and unshocked) with GBD 5-year age groups. + """ + + LOGGER.info("Converting to period space.") + + age_bins = [SINGLE_YEAR_15_TO_19 + (i * 5) for i in range(0, int(len(MODELED_AGES) - 1))] + five_year_to_single_year_age_dict = dict(zip(MODELED_AGES[:-1], age_bins)) + + period_df_list = [] + for five_year_id, single_year_ids in five_year_to_single_year_age_dict.items(): + + period_age_df = edu_cohorts_shocked.query( + "age_group_id in @single_year_ids" + )[PERIOD_DF_COLUMNS] + + period_age_df = period_age_df.groupby( + ["location_id", "year_id", "sex_id"] + ).mean().reset_index() + + period_age_df["age_group_id"] = five_year_id + + period_df_list.append(period_age_df) + + terminal_age_df = edu_cohorts_shocked.query( + "age_group_id == @TERMINAL_AGE_ID" + )[PERIOD_DF_COLUMNS] + + period_df = pd.concat(period_df_list + [terminal_age_df], ignore_index=True) + + return period_df + + +def shift_education_draws(reference_mean_edu, period_edu_shocked, education_draws): + """Finds the log difference between reference scenario unshocked education means and shocked + education means in period space. The difference is then subtracted from the draws in log space + to make the shocked draws. + + Args: + reference_mean_edu (xr.DataArray): + Mean-level reference-only educational attainment in period space. + period_edu_shocked (xr.DataArray): + Shocked mean-level reference-only educational attainment in period space. + education_draws (xr.DataArray): + Draw-level educational attainment in period space with all scenarios. + Returns: + edu_shocked_draws (xr.DataArray): + Shocked draw-level educational attainment in period space with all scenarios. + + """ + LOGGER.info("Making draws.") + log_difference = ( + log_with_offset(reference_mean_edu, offset=0) - + log_with_offset(period_edu_shocked, offset=0) + ) + log_edu_shocked_draws = ( + log_with_offset(education_draws, offset=0) - log_difference + ) + edu_shocked_draws = invlog_with_offset(log_edu_shocked_draws, offset=0, bias_adjust=False) + + return edu_shocked_draws + + +def make_plot_dfs(edu_cohorts_shocked, period_edu_shocked, age_metadata, location_metadata): + """Makes DataFrames for later plotting and vetting. + + Args: + edu_cohorts_shocked (pd.DataFrame): + Single-year age education cohorts with shocks. + period_edu_shocked (pd.DataFrame): + 5-year age group education in period space with shocks. + age_metadata (pd.DataFrame): + DataFrame with age group metadata. + location_metadata (pd.DataFrame): + DataFrame with location metadata. + Returns: + plot_dfs["edu_cohorts_shocked"] (pd.DataFrame): + DataFrame for plotting shocked education cohorts with single-year ages. 
+ plot_dfs["period_edu_shocked"] (pd.DataFrame): + DataFrame for plotting shocked education in period space with GBD 5-year age groups. + """ + + LOGGER.info("Making plot dfs.") + + age_metadata = age_metadata[["age_group_id", "age_group_name"]] + location_metadata = location_metadata[ + ["location_id", "ihme_loc_id", "location_ascii_name", "level"] + ].rename(columns={"level": "location_level"}) + + plot_dfs = dict(edu_cohorts_shocked=edu_cohorts_shocked, period_edu_shocked=period_edu_shocked) + for key, df in plot_dfs.items(): + + plot_df = df.merge(location_metadata).merge(age_metadata).rename( + columns=PLOT_DF_RENAME_DICT + ) + + try: + plot_df_id_columns = PLOT_DF_ID_COLUMNS + ["cohort_year"] + plot_df = plot_df[plot_df_id_columns + PLOT_DF_VALUE_COLUMNS] + except: + plot_df_id_columns = PLOT_DF_ID_COLUMNS + plot_df = plot_df[plot_df_id_columns + PLOT_DF_VALUE_COLUMNS] + + plot_df = pd.melt( + frame=plot_df, + id_vars=plot_df_id_columns, + value_vars=PLOT_DF_VALUE_COLUMNS, + var_name="education_state", + value_name="education_years" + ) + + plot_dfs[key] = plot_df + + return plot_dfs["edu_cohorts_shocked"], plot_dfs["period_edu_shocked"] + + +@click.command() +@click.option("--gbd-round-id", type=int) +@click.option("--years", type=str) +@click.option("--draws", type=int) +@click.option("--edu-version", type=str) +@click.option("--shock-version", type=str) +@click.option("--broadband-version", type=str) +@click.option("--output-version-tag", type=str, default=None) +def apply_education_disruptions( + gbd_round_id: int, + years: str, + draws: int, + edu_version: str, + shock_version: str, + broadband_version: str, + output_version_tag: str +): + # The output data will end 5 years earlier than the input education data due to age group + # interpolation. + years = YearRange.parse_year_range(years) + years = YearRange(years.past_start, years.forecast_start, years.forecast_end - AGE_GROUP_WIDTH) + + output_version = edu_version + "_covid_shocks" + if output_version_tag is not None: + output_version += f"_{output_version_tag}" + + output_path = FBDPath(f"") + + age_metadata = get_ages() + location_metadata = get_location_set(gbd_round_id) + + cohort_age_df = get_cohort_info_from_age_year( + age_group_ids=ALL_SINGLE_YEAR_AGE_IDS + [TERMINAL_AGE_ID], + years=years, + age_metadata=age_metadata + ) + + education_draws, reference_mean_edu, shocks, broadband = read_in_data( + gbd_round_id, draws, edu_version, shock_version, broadband_version + ) + + # Step 1 + single_year_edu = age_split_education(reference_mean_edu, years) + + # Step 2 + edu_cohorts = single_year_edu.merge(cohort_age_df, on=['year_id', 'age_group_id']) + + # Step 3 + # This is used to scale the shocks to how much education would have been gained if not for the + # pandemic. Ideally this would be age-specific, but for now we only have ages >= 15. + annual_change_in_edu_15to19_during_pandemic = compute_average_annual_change( + reference_mean_edu, MODELED_AGES[0], SHOCK_YEARS + ) + shocks_scaled = shocks * annual_change_in_edu_15to19_during_pandemic + + # Step 4 + shocks_broadband_corrected = apply_broadband_terms(shocks_scaled, broadband) + + # Step 5 + shocks_age_year_shifted = shift_shock_age_years(shocks_broadband_corrected) + + # Step 6 + shock_cohorts = shocks_age_year_shifted.merge(cohort_age_df, on=['year_id', 'age_group_id']) + + # Step 7 + # Sum shocks occurring before age 15 and shocks at age 15 since that is the youngest age in the + # education data. 
+ shock_cohorts = shock_cohorts.groupby( + ["location_id", "year_id", "age_group_id", "sex_id", "age", "cohort_year"] + ).sum()[["shock", "shock_minus_broadband"]].reset_index() + + # Step 8 + shock_cohorts = expand_shock_cohorts(shock_cohorts, cohort_age_df) + + # Step 9 + shock_cohorts = make_cumulative_shocks(shock_cohorts) + + # Step 10 + edu_cohorts_shocked = shock_education(shock_cohorts, edu_cohorts) + + # Step 11 + period_edu_shocked = convert_to_period_space_gbd_ages(edu_cohorts_shocked) + # Convert to DataArray for use in other pipelines. + period_edu_shocked_da = period_edu_shocked.set_index( + ["location_id", "year_id", "sex_id", "age_group_id"] + ).to_xarray()["shocked_val_broadband_corrected"].rename("value") + + # Step 12 + edu_shocked_draws = shift_education_draws( + reference_mean_edu, period_edu_shocked_da, education_draws + ) + + save_xr( + edu_shocked_draws, + output_path / "education.nc", + metric="rate", + space="identity", + shock_version=shock_version, + broadband_version=broadband_version + ) + + LOGGER.info(f"Shocked education draws saved to: {output_path}") + + cohort_plot_df, period_plot_df = make_plot_dfs( + edu_cohorts_shocked, period_edu_shocked, age_metadata, location_metadata + ) + cohort_plot_df.to_csv(output_path / "cohort_plot_df.csv", index=False) + period_plot_df.to_csv(output_path / "period_plot_df.csv", index=False) + + LOGGER.info(f"Plot dfs saved to {output_path}.") + + +def _get_age_from_age_id(age_group_id_col): + return age_group_id_col - COHORT_AGE_START_ID + COHORT_AGE_START + + +if __name__ == '__main__': + apply_education_disruptions() \ No newline at end of file diff --git a/gbd_2021/fertility_forecast_code/education/education_transform.py b/gbd_2021/fertility_forecast_code/education/education_transform.py new file mode 100644 index 0000000..8bfcc29 --- /dev/null +++ b/gbd_2021/fertility_forecast_code/education/education_transform.py @@ -0,0 +1,75 @@ +"""Functions for transforming education into different modeling-spaces and then +inverse functions for moving predictions back into the identity domain. +""" +import logging + +import xarray as xr +from frozendict import frozendict +from scipy.special import expit, logit + +LOGGER = logging.getLogger(__name__) +MAX_EDU = 18 # The maximum number of years of education someone can attain + +# `EPSILON` is used to form the lower and upper bound before converting +# education data to logit space. This value was picked to include most of the +# past education, and to avoid getting infinities and negative infinities when +# converting to logit space. +EPSILON = 1e-3 + +# The maximum number of years of education someone can attain within certain +# age groups. +SPECIAL_AGE_CAP_DICT = frozendict({ + 6: 3, + 7: 8, + 8: 13 + }) + + +def normal_to_logit(data): + """Convert education from linear space to logit space, using a lower bound + of 0 and an upper bound of ``MAX_EDU``. + + Args: + data (xr.DataArray): + education in linear space. + Returns: + xr.DataArray: + education in logit space. + """ + LOGGER.debug("Converting to logit space.") + data_caps = get_edu_caps(data) + scaled_data = data / data_caps + clipped_data = scaled_data.clip(min=EPSILON, max=1 - EPSILON) + return logit(clipped_data) + + +def logit_to_normal(data): + """Convert education from logit space to linear space, using a lower bound + of 0 and an upper bound of ``MAX_EDU``. + + Args: + data (xr.DataArray): + education in logit space. + Returns: + xr.DataArray: + education in linear space. 
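A toy round trip through this transform pair, ignoring the age-specific caps in ``SPECIAL_AGE_CAP_DICT`` and using plain numpy arrays instead of ``xr.DataArray``:

    import numpy as np
    from scipy.special import expit, logit

    MAX_EDU, EPSILON = 18, 1e-3
    edu = np.array([0.0, 12.0, 18.0])               # years of attainment
    scaled = np.clip(edu / MAX_EDU, EPSILON, 1 - EPSILON)
    logit_edu = logit(scaled)                       # normal_to_logit
    back = expit(logit_edu) * MAX_EDU               # logit_to_normal
    # back -> [0.018, 12.0, 17.982]; values at the bounds are pulled in by EPSILON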
+ """ + LOGGER.debug("Converting to linear space.") + data_caps = get_edu_caps(data) + return expit(data) * data_caps + + +def log_to_normal(data): + LOGGER.debug("Converting to linear space.") + data_caps = get_edu_caps(data) + _, expanded_data_caps = xr.broadcast(data, data_caps) + return xr.ufuncs.exp(data).clip(min=0, max=expanded_data_caps) + + +def get_edu_caps(data): + data_caps = xr.ones_like(data["age_group_id"]) * MAX_EDU + for age_group_id, edu_cap in SPECIAL_AGE_CAP_DICT.items(): + data_caps = data_caps.where( + data_caps["age_group_id"] != age_group_id + ).fillna(edu_cap) + return data_caps diff --git a/gbd_2021/fertility_forecast_code/education/forecast_education.py b/gbd_2021/fertility_forecast_code/education/forecast_education.py new file mode 100644 index 0000000..3ae7def --- /dev/null +++ b/gbd_2021/fertility_forecast_code/education/forecast_education.py @@ -0,0 +1,314 @@ +"""Make a forecast and scenarios for education using the ARC method. + +>>> python forecast_education.py \ + --reference-scenario mean \ + --transform logit \ + --diff-over-mean \ + --truncate \ + --truncate-quantiles 0.15 0.85 \ + --pv-version 20180917_subnat_pv \ + --forecast-version 20190917_subnat_capped \ + --past-version 20181004_met_demand_gprdraws_gbd17final \ + --gbd-round-id 5 \ + --years 1990:2018:2100 \ + --weight-strategy use_smallest_omega_within_threshold + +""" +import logging + +import numpy as np +import xarray as xr +from frozendict import frozendict + +from fbd_core import argparse, db, etl +from fbd_core.etl import (expand_dimensions, omega_selection_strategies, + resample, scenarios) +from fbd_core.file_interface import FBDPath, open_xr, save_xr +from fbd_research.education import education_transform + +LOGGER = logging.getLogger(__name__) + +EDU_CUTOFF = 25 # no one can be educated after they are 25 +MODELED_SEX_IDS = (1, 2) +MODELED_AGE_GROUP_IDS = tuple(range(2, 21)) + (30, 31, 32, 235) +REFERENCE_SCENARIO = 0 +TRANSFORMATIONS = frozendict(( + ("no_transform", lambda x: x), + ("log", np.log), + ("logit", education_transform.normal_to_logit) +)) +INVERSE_TRANSFORMATIONS = frozendict(( + ("no_transform", lambda x: x), + ("log", education_transform.log_to_normal), + ("logit", education_transform.logit_to_normal) +)) + + +def get_num_years_to_shift(age_group_years_start): + num_years_to_shift = age_group_years_start - EDU_CUTOFF + return max(0, num_years_to_shift) + + +def lag_scenarios(data, years): + """Lag scenarios by age-years so that scenarios only deviate from the + reference in age-time pairs that have not finished education in the first + year of the forecast. + + Args: + data (xarray.DataArray): + data that must include forecasts and can include past. + years (YearRange): + forecasting timeseries + Returns: + (xarray.DataArray): + Forecasts with adjusted/lagged scenarios. 
+ """ + age_groups = db.get_ages() + age_group_dict = dict(zip(age_groups["age_group_id"], + age_groups["age_group_years_start"])) + + forecast = data.sel(year_id=years.forecast_years) + + adjusted_scenarios = [] + for age_group_id in forecast["age_group_id"].values: + age_group_years_start = age_group_dict[age_group_id] + num_years_to_shift = get_num_years_to_shift(age_group_years_start) + age_slice = forecast.sel(age_group_id=age_group_id) + + ref_age_slice = age_slice.sel(scenario=REFERENCE_SCENARIO, drop=True) + adjusted_scen_age_slice = [] + for scenario in (-1, 1): + scen_age_slice = age_slice.sel(scenario=scenario).drop( + "scenario") + + diff = scen_age_slice - ref_age_slice + shifted_diff = diff.shift(year_id=num_years_to_shift).fillna(0) + shifted_scen_slice = shifted_diff + ref_age_slice + shifted_scen_slice["scenario"] = scenario + adjusted_scen_age_slice.append(shifted_scen_slice) + ref_age_slice["scenario"] = 0 + adjusted_scen_age_slice.append(ref_age_slice) + adjusted_scen_age_slice = xr.concat(adjusted_scen_age_slice, + dim="scenario") + adjusted_scenarios.append(adjusted_scen_age_slice) + adjusted_scenarios = xr.concat(adjusted_scenarios, dim="age_group_id") + + return adjusted_scenarios.combine_first(data) + + +def arc_forecast_education( + past, gbd_round_id, transform, weight_exp, years, + reference_scenario, diff_over_mean, truncate, truncate_quantiles, + replace_with_mean, extra_dim=None): + """Forecasts education using the ARC method. + + Args: + past (xarray.DataArray): + Past data with dimensions ``location_id``, ``sex_id``, + ``age_group_id``, ``year_id``, and ``draw``. + transform (xarray.DataArray): + Space to transform education to for forecasting. + weight_exp (float): + How much to weight years based on recency + years (YearRange): + Forecasting timeseries. + reference_scenario (str): + If 'median' then the reference scenarios is made using the + weighted median of past annualized rate-of-change across all + past years, 'mean' then it is made using the weighted mean of + past annualized rate-of-change across all past years + diff_over_mean (bool): + If True, then take annual differences for means-of-draws, instead + of draws. + truncate (bool): + If True, then truncates the dataarray over the given dimensions. + truncate_quantiles (object, optional): + The tuple of two floats representing the quantiles to take + replace_with_mean (bool, optional): + If True and `truncate` is True, then replace values outside of the + upper and lower quantiles taken across "location_id" and "year_id" + and with the mean across "year_id", if False, then replace with the + upper and lower bounds themselves. + gbd_round_id (int): + The GBD round of the input data. 
+    Returns:
+        (xarray.DataArray):
+            Education forecasts.
+    """
+    LOGGER.debug("diff_over_mean:{}".format(diff_over_mean))
+    LOGGER.debug("truncate:{}".format(truncate))
+    LOGGER.debug("truncate_quantiles:{}".format(truncate_quantiles))
+    LOGGER.debug("replace_with_mean:{}".format(replace_with_mean))
+    LOGGER.debug("reference_scenario:{}".format(reference_scenario))
+
+    most_detailed_coords = _get_avail_most_detailed_coords(past, gbd_round_id)
+    most_detailed_past = past.sel(**most_detailed_coords)
+
+    zeros_dropped = most_detailed_past.where(most_detailed_past > 0)
+    for dim in zeros_dropped.dims:
+        zeros_dropped = zeros_dropped.dropna(dim=dim, how="all")
+
+    LOGGER.debug("Transforming the past to {}-space".format(transform))
+    transformed_past = TRANSFORMATIONS[transform](zeros_dropped)
+
+    LOGGER.debug("Forecasting education in the transformed space")
+    transformed_forecast = scenarios.arc_method(
+        transformed_past, gbd_round_id=gbd_round_id, years=years,
+        reference_scenario=reference_scenario, weight_exp=weight_exp,
+        diff_over_mean=diff_over_mean, truncate=truncate,
+        truncate_quantiles=truncate_quantiles,
+        replace_with_mean=replace_with_mean, reverse_scenarios=True,
+        extra_dim=extra_dim, scenario_roc="national")
+
+    LOGGER.debug("Converting the forecasts to normal/identity space")
+    forecast = INVERSE_TRANSFORMATIONS[transform](transformed_forecast)
+
+    refilled_forecast = etl.expand_dimensions(forecast, **most_detailed_coords)
+    lagged_scenarios = lag_scenarios(refilled_forecast, years)
+
+    # Since past does get clipped to avoid infs and negative infs, we need to
+    # append the actual past onto the data being saved (modelers currently
+    # expect the past to be there)
+    past_broadcast_scenarios = etl.expand_dimensions(
+        most_detailed_past, scenario=lagged_scenarios["scenario"])
+    all_data = past_broadcast_scenarios.combine_first(lagged_scenarios)
+
+    bound_err_msg = "the forecasts have NaNs"
+    if np.isnan(all_data).any():
+        LOGGER.error(bound_err_msg)
+        raise RuntimeError(bound_err_msg)
+
+    return all_data
+
+
+def forecast_edu_main(transform, past_version, forecast_version, pv_version,
+                      weight_strategy, gbd_round_id, years, reference_scenario,
+                      diff_over_mean, truncate, truncate_quantiles,
+                      replace_with_mean, draws, **kwargs):
+    LOGGER.debug("weight strategy: {}".format(weight_strategy.__name__))
+    pv_path = FBDPath("".format())  # Path removed for security reasons
+    rmse = open_xr(pv_path / "education_arc_weight_rmse.nc").data
+    weight_exp = weight_strategy(rmse, draws)
+    LOGGER.info("omega selected: {}".format(weight_exp))
+
+    LOGGER.debug("Reading in the past")
+    past_path = FBDPath("".format())  # Path removed for security reasons
+    past = resample(open_xr(past_path / "education.nc").data, draws)
+    past = past.sel(year_id=years.past_years)
+
+    if isinstance(weight_exp, (float, int)):
+        extra_dim = None
+    else:
+        if not isinstance(weight_exp, xr.DataArray):
+            omega_exp_err_msg = (
+                "`omega` must be either a float, an int, or an "
+                "xarray.DataArray")
+            LOGGER.error(omega_exp_err_msg)
+            raise RuntimeError(omega_exp_err_msg)
+        elif len(weight_exp.dims) != 1 or "draw" not in weight_exp.dims:
+            omega_exp_err_msg = (
+                "If `omega` is a xarray.DataArray, then it must have only "
+                "1 dim, `draw`")
+            LOGGER.error(omega_exp_err_msg)
+            raise RuntimeError(omega_exp_err_msg)
+        elif not weight_exp["draw"].equals(past["draw"]):
+            omega_err_msg = (
+                "If `omega` is a xarray.DataArray, then its `draw` dim "
+
"must have the coordinates as `past`") + LOGGER.error(omega_err_msg) + raise RuntimeError(omega_err_msg) + else: + extra_dim = "draw" + + forecast = arc_forecast_education( + past, gbd_round_id, transform, weight_exp, years, + reference_scenario, diff_over_mean, truncate, truncate_quantiles, + replace_with_mean, extra_dim=extra_dim) + + forecast_path = FBDPath("".format()) + if isinstance(weight_exp, xr.DataArray): + report_omega = float(weight_exp.mean()) + else: + report_omega = weight_exp + save_xr(forecast, forecast_path / "education.nc", metric="number", + space="identity", omega=report_omega, + omega_strategy=weight_strategy.__name__) + LOGGER.info("education forecasts have saved") + + +def _get_avail_most_detailed_coords(data, gbd_round_id): + location_table = db.get_location_set(gbd_round_id) + + # subset to national and subnat location ids + modeled_location_ids = list(location_table["location_id"].unique()) + avail_sex_ids = [ + sex for sex in data["sex_id"].values if sex in MODELED_SEX_IDS] + avail_age_group_ids = [ + age for age in data["age_group_id"].values + if age in MODELED_AGE_GROUP_IDS] + return { + "location_id": modeled_location_ids, + "sex_id": avail_sex_ids, + "age_group_id": avail_age_group_ids + } + + +if __name__ == "__main__": + def get_weight_strategy_func(weight_strategy_name): + weight_strategy_func = getattr( + omega_selection_strategies, weight_strategy_name) + return weight_strategy_func + + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument( + "--reference-scenario", type=str, choices=["median", "mean"], + help=("If 'median' then the reference scenarios is made using the " + "weighted median of past annualized rate-of-change across all " + "past years, 'mean' then it is made using the weighted mean of " + "past annualized rate-of-change across all past years.")) + parser.add_argument( + "--diff-over-mean", action="store_true", + help=("If True, then take annual differences for means-of-draws, " + "instead of draws.")) + parser.add_argument( + "--truncate", action="store_true", + help=("If True, then truncates the dataarray over the given " + "dimensions.")) + parser.add_argument( + "--truncate-quantiles", type=float, nargs="+", + help="The tuple of two floats representing the quantiles to take.") + parser.add_argument( + "--replace-with-mean", action="store_true", + help=("If True and `truncate` is True, then replace values outside of " + "the upper and lower quantiles taken across `location_id` and " + "`year_id` and with the mean across `year_id`, if False, then " + "replace with the upper and lower bounds themselves.")) + parser.add_argument( + "--transform", type=str, + choices=list(sorted(TRANSFORMATIONS.keys())), + help="Space to transform education to for forecasting.") + parser.add_argument( + "--forecast-version", type=str, required=True, + help="Version of education forecasts being made and saved now") + parser.add_argument( + "--past-version", type=str, required=True, + help="Version of past education") + parser.add_argument( + "--pv-version", type=str, required=True, + help=("Version of predictive validation done on a range of weights " + "used in the ARC method")) + parser.add_argument( + "--weight-strategy", type=get_weight_strategy_func, required=True, + help="How the weight used in the ARC method is selected.") + parser.add_argument( + "--gbd-round-id", type=int, required=True) + parser.add_arg_years(required=True) + parser.add_arg_draws(required=True) + + args = 
parser.parse_args()
+
+    forecast_edu_main(**args.__dict__)
diff --git a/gbd_2021/fertility_forecast_code/education/maternal_education.py b/gbd_2021/fertility_forecast_code/education/maternal_education.py
new file mode 100644
index 0000000..c16f698
--- /dev/null
+++ b/gbd_2021/fertility_forecast_code/education/maternal_education.py
@@ -0,0 +1,83 @@
+"""This module is used to calculate maternal education. The get_maternal_edu
+function returns both maternal education, and education with maternal edu filled
+in for child age-groups.
+"""
+import logging
+
+import xarray as xr
+
+from fbd_core.etl import Aggregator, expand_dimensions, resample
+from fbd_core.file_interface import FBDPath, open_xr
+
+LOGGER = logging.getLogger(__name__)
+
+SEXES = (1, 2)
+FEMALE_SEX_ID = 2
+
+CHILD_AGE_GROUPS = tuple(range(2, 8))
+MAT_AGE_GROUPS = tuple(range(8, 15))
+MAT_AGE_GROUP_ID = 198
+
+
+def get_maternal_edu(education, gbd_round_id,
+                     past_future, pop_version, location_ids):
+    """Recalculate maternal education, which according to the education team is
+    the education of women of age-group-IDs 8 to 14 multiplied by their
+    age-weights and then summed over age.
+
+    Only the age weights of groups 8 to 14 are kept, and then are rescaled so
+    that the sum of those age weights is 1.
+
+    Args:
+        education (xarray.DataArray):
+            Education data. Needs dimensions `age_group_id` and `sex_id`,
+            but probably also has dimensions `location_id`, `draw`, `year_id`
+            and maybe `scenario`.
+        gbd_round_id (int):
+            Numeric ID for the GBD round. Used to get the age-weights for the
+            round from the database.
+        past_future (str):
+            Which epoch of population ("past" or "future") to use for the
+            maternal education aggregation.
+        pop_version (str):
+            Version of population to use for the maternal education
+            aggregation.
+        location_ids (list[int]):
+            Location IDs to compute maternal education for.
+    Returns:
+        (tuple[xarray.DataArray, xarray.DataArray]):
+            * The first `xarray.DataArray` of the tuple is educational
+              attainment for all age-groups and sexes. However, children that
+              are too young to have their own education are filled in with
+              maternal education.
+            * The second `xarray.DataArray` of the tuple is maternal education
+              -- only for the maternal age-group, given by `MAT_AGE_GROUP_ID`,
+              and females, given by `FEMALE_SEX_ID`.
+    """
+
+    pop_path = FBDPath("")  # Path removed for security reasons
+
+    pop = open_xr(pop_path / "population.nc").data.sel(
+        age_group_id=list(MAT_AGE_GROUPS), sex_id=FEMALE_SEX_ID,
+        location_id=list(location_ids)
+    )
+
+    LOGGER.debug("Adding up education of moms to get maternal education.")
+    mat_slice_edu = education.sel(sex_id=FEMALE_SEX_ID,
+                                  age_group_id=list(MAT_AGE_GROUPS),
+                                  location_id=list(location_ids))
+
+    agg = Aggregator(pop)
+    mat_edu = agg.aggregate_ages(list(MAT_AGE_GROUPS), MAT_AGE_GROUP_ID,
+                                 data=mat_slice_edu).rate
+
+    # age_group_id must be dropped. If not, expand_dimensions will broadcast
+    # NaNs instead of our data into the new child age_group_id values.
+    mat_edu_expanded = expand_dimensions(mat_edu.drop("age_group_id").squeeze(),
+                                         sex_id=list(SEXES),
+                                         age_group_id=list(CHILD_AGE_GROUPS))
+
+    LOGGER.debug("Adding maternal education for both sexes and child age "
+                 "groups to education data array.")
+    # Even if ``education`` has data for child age groups, combine_first will
+    # make sure that the newly calculated maternal education is used
+    # instead.
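A minimal xarray illustration of the ``combine_first`` precedence relied on in the return statement that follows (toy age group IDs and values):

    import numpy as np
    import xarray as xr

    mat = xr.DataArray([5.0, 5.0], coords={"age_group_id": [2, 3]}, dims="age_group_id")
    edu = xr.DataArray([1.0, np.nan, 9.0], coords={"age_group_id": [2, 3, 8]}, dims="age_group_id")
    mat.combine_first(edu).values
    # -> [5., 5., 9.]: child age groups take the maternal value even where
    # ``edu`` already had data; age group 8 keeps its own education.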
+    return mat_edu_expanded.combine_first(education), mat_edu
diff --git a/gbd_2021/fertility_forecast_code/fertility/__init__.py b/gbd_2021/fertility_forecast_code/fertility/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gbd_2021/fertility_forecast_code/fertility/constants.py b/gbd_2021/fertility_forecast_code/fertility/constants.py
new file mode 100644
index 0000000..5fb73f1
--- /dev/null
+++ b/gbd_2021/fertility_forecast_code/fertility/constants.py
@@ -0,0 +1,80 @@
+"""FHS pipeline local constants for fertility forecasting.
+
+Contains all the global variables used by the fertility modeling steps for
+data input and transformation, ccfX, ccf, incremental, and conversion to ASFR.
+
+Model specifications for different modeling steps (note that most are
+age-specific) are also defined here.
+
+Global variables for the intercept shift or plotting are NOT included here.
+"""
+import numpy as np
+
+# scenario information
+REFERENCE_SCENARIO_COORD = 0
+
+# past met need start year
+# don't change unless GBD estimates are made for years prior to 1980
+PAST_MET_NEED_YEAR_START = 1970
+
+# define age-related constants for CCF forecast model
+# do not recommend changing this as it is the standard definition for CCF50
+COHORT_AGE_START = 15  # this is the starting age for "modeled" forecast
+COHORT_AGE_END = 49  # CCF50 means [15, 50), so last eligible age year is 49.
+FIVE_YEAR_AGE_GROUP_SIZE = 5
+MODELED_AGE_STEP_SIZE = 1
+MODELED_FERTILE_AGE_IDS = list(range(8, 15))
+FERTILE_AGE_IDS = list(range(7, 16))
+COVARIATE_START_AGE = 25  # this maps covariate age years to cohorts
+
+# Bounds for forecasting ccf
+CCF_LOWER_BOUND = 0.7
+CCF_UPPER_BOUND = 10.0
+CCF_BOUND_TOLERANCE = 1e-3
+
+# these are single year ages that will be inferred/appended after forecast
+YOUNG_TERMINAL_AGES = range(10, 15)
+OLD_TERMINAL_AGES = range(50, 55)
+
+PRONATAL_THRESHOLD = 1.75
+RAMP_YEARS = 5  # number of years to ramp up pronatal policy
+
+
+class DimensionConstants:
+    """Constants related to data dimensionality."""
+
+    # Common dimensions
+    REGION_ID = "region_id"
+    SUPER_REGION_ID = "super_region_id"
+    SCENARIO = "scenario"
+    LOCATION_ID = "location_id"
+    AGE_GROUP_ID = "age_group_id"
+    YEAR_ID = "year_id"
+    COHORT_ID = "cohort_id"
+    DRAW = "draw"
+    SEX_ID = "sex_id"
+    QUANTILE = "quantile"
+    WEIGHT = "weight"
+    STATISTIC = "statistic"
+    AGE = "age"
+    CCF_INDEX_DIMS = ["location_id", "year_id"]  # must be list of demog dims
+    SINGLE_YEAR_ASFR_INDEX_DIMS = ["location_id", "age", "year_id", "draw"]
+
+    FEMALE_SEX_ID = 2
+
+
+class StageConstants:
+    """Stages in FHS file system."""
+
+    CCF = "ccf"
+    ORDERED_COVARIATES = ["education", "met_need", "u5m", "urbanicity"]
+    ENSEMBLE_COV_COUNTS = [2, 3, 4]
+    STAGE_2_ARIMA_ATTENUATION_END_YEAR = 2050
+    STAGE_3_COVARIATES = ["education", "met_need"]
+
+
+class ModelConstants:
+    """Constants for model_strategy.py."""
+
+    log_logit_offset = 1e-8
+    knot_placements = np.array([0.0, 0.1, 0.365, 0.75, 1.0])
diff --git a/gbd_2021/fertility_forecast_code/fertility/input_transform.py b/gbd_2021/fertility_forecast_code/fertility/input_transform.py
new file mode 100644
index 0000000..13b920c
--- /dev/null
+++ b/gbd_2021/fertility_forecast_code/fertility/input_transform.py
@@ -0,0 +1,165 @@
+"""Helper I/O functions for the fertility pipeline."""
+import logging
+from typing import List, Union
+
+import xarray as xr
+from fhs_lib_file_interface.lib.file_interface import FBDPath
+from fhs_lib_file_interface.lib.xarray_wrapper import open_xr
+
+from fhs_pipeline_fertility.lib.constants
import PAST_MET_NEED_YEAR_START, DimensionConstants + +LOGGER = logging.getLogger(__name__) + + +# this func is used for stage 2 of fertility: CCF forecast. +# also used in stage 3 step 1 +def get_covariate_data( + stage: str, + future_version: str, + past_version: str, + year_ids: List[int], + gbd_round_id: int, + location_ids: List[int], + age_group_ids: Union[int, List[int]], + scenario: int, + draws: int, +) -> xr.DataArray: + r"""Extract age-specific covariate data in period space. + + Args: + stage (str): the stage (education, met_need, etc.) of the covariate. + future_version (str): version of future covariate. + past_version (str): version of past covariate. + year_ids (List[int]): all year_ids needed from past + future. + gbd_round_id (int): gbd round id. + location_ids (List[int]): the location ids to pull data for. + age_group_ids (Union[int, List[int]]): age group ids for ccf covariate. + scenario (int): the future scenario to pull. + draws (int): the number of draws to return. + + Returns: + (xr.DataArray): age-specific covariate data + in period space (indexed by ``year_id`` and ``age_group_id``). + """ + da = _load_fhs_covariate_data( + stage=stage, + version=future_version, + gbd_round_id=gbd_round_id, + draws=draws, + past_or_future="future", + age_group_ids=age_group_ids, + scenario=scenario, + location_ids=location_ids, + ) + + da.name = stage # the name of the covariate will be handy + + # assume the past has no scenario (None) + past_da = _load_fhs_covariate_data( + stage=stage, + version=past_version, + gbd_round_id=gbd_round_id, + draws=draws, + past_or_future="past", + age_group_ids=age_group_ids, + scenario=None, + location_ids=location_ids, + ) + + da = da.combine_first(past_da) # future year_ids overwrite if overlap + + if type(age_group_ids) is int: + da = da.drop_vars(DimensionConstants.AGE_GROUP_ID) + + # some covariates (met_need) may not have data going as far back as + # the other covariates (education), + # so we remove the extra early past years to align all past covariates + if min(year_ids) < PAST_MET_NEED_YEAR_START: + year_ids = range(PAST_MET_NEED_YEAR_START, max(year_ids) + 1) + + da = da.sel(year_id=year_ids) # errors if we have missing years + + return da + + +def _load_fhs_covariate_data( + stage: str, + version: str, + gbd_round_id: int, + past_or_future: str, + age_group_ids: Union[int, List[int]], + location_ids: Union[int, List[int]], + draws: int, + scenario: Union[int, None], +) -> xr.DataArray: + """Load covariate data for a given stage. + + This is really just a wrapper to standardize sex/scenario-related + processing we apply to read-in covariate data for fertility. + + 1.) if there is a sex_id dimension, only data for females is returned. + 2.) pull only chosen scenario from object. + + ..todo:: Replace with a centralized/core function? + + Args: + stage (str): stage of data to load. + version (str): past version name. + gbd_round_id (int): GBD round id. + past_or_future (str): either "past" or "future" (period space). + age_group_ids (Union[int, List[int]]): age group ids to filter for. + location_ids (Union[int, List[int]]): location ids to filter for. + draws (int): number of draws to pull. + scenario (int): scenario to return. Ignored if + there is no scenario dimension in the data or ``scenario=None``. + + Returns: + (xr.DataArray): requested data. 
+ """ + path = FBDPath( + gbd_round_id=gbd_round_id, past_or_future=past_or_future, stage=stage, version=version + ) + + da = open_xr(path / f"{stage}.nc").sel( + location_id=location_ids, age_group_id=age_group_ids, draw=range(draws) + ) + + if DimensionConstants.SCENARIO in da.dims and scenario is not None: + da = da.sel(scenario=scenario, drop=True) + # all of fertilty modeling is for female sex. cleaner to do all the needed + # downstream data cleaning and manipulation if the sex dimension is dropped + if DimensionConstants.SEX_ID in da.dims: + if DimensionConstants.FEMALE_SEX_ID in da[DimensionConstants.SEX_ID].values: + da = da.sel(sex_id=DimensionConstants.FEMALE_SEX_ID, drop=True) + else: + da = da.sel(sex_id=da[DimensionConstants.SEX_ID].values[0], drop=True) + + return da + + +def convert_period_to_cohort_space(da: xr.DataArray) -> xr.DataArray: + """Convert single age year data from period space to cohort space. + + Change index dim ``year_id`` to ``cohort_id`` defined by cohort birth + year as``year_id``-``age``. + + Args: + da (xr.DataArray): data in period space by single age year ``age`` + and year of data ``year_id``. + + Returns: + (xr.DataArray): data in cohort space by single age year ``age`` + and year of birth ``cohort_id`` + """ + das = [] + + for age in da[DimensionConstants.AGE].values: + da_sub = da.sel(age=age, drop=True) + # define cohort_id = year_id - age + da_sub[DimensionConstants.YEAR_ID] = da_sub[DimensionConstants.YEAR_ID] - age + da_sub = da_sub.rename({DimensionConstants.YEAR_ID: DimensionConstants.COHORT_ID}) + # add back in age dim + da_sub[DimensionConstants.AGE] = age + das.append(da_sub) + + return xr.concat(das, dim=DimensionConstants.AGE) diff --git a/gbd_2021/fertility_forecast_code/fertility/main.py b/gbd_2021/fertility_forecast_code/fertility/main.py new file mode 100644 index 0000000..bc14f52 --- /dev/null +++ b/gbd_2021/fertility_forecast_code/fertility/main.py @@ -0,0 +1,318 @@ +r"""Orchestrator of fertility pipeline. +""" +import logging +from typing import Union + +import pandas as pd +import xarray as xr +from fhs_lib_database_interface.lib.query.age import get_ages +from fhs_lib_database_interface.lib.query.location import get_location_set +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import save_xr +from fhs_lib_year_range_manager.lib.year_range import YearRange + +from fhs_pipeline_fertility.lib import stage_1, stage_2, stage_3 +from fhs_pipeline_fertility.lib.constants import ( + COHORT_AGE_END, + COHORT_AGE_START, + FERTILE_AGE_IDS, + FIVE_YEAR_AGE_GROUP_SIZE, + MODELED_FERTILE_AGE_IDS, + OLD_TERMINAL_AGES, + PRONATAL_THRESHOLD, + RAMP_YEARS, + REFERENCE_SCENARIO_COORD, + YOUNG_TERMINAL_AGES, + DimensionConstants, + StageConstants, +) +from fhs_pipeline_fertility.lib.input_transform import convert_period_to_cohort_space + +LOGGER = logging.getLogger(__name__) + + +def fertility_pipeline( + versions: Versions, + years: YearRange, + draws: int, + gbd_round_id: int, + ensemble: bool, + pronatal_bump: Union[float, None], +) -> None: + """Run the fertility pipeline. + + Fertility pipeline contains the following stages: + + 1.) starting with past ASFR, compute CCFX from age 15 through 49. + 2.) forecast CCF50 (summed over 15-49) based on (1). + 3.) map CCF50 back to ASFR. + + 4.) If pronatal_policy, add bump to each location/year ASFR according + to how its TFR is distributed amongst the ages. 
+ + Args: + versions (Versions): All relevant input/output versions. + years (YearRange): past_start:forecast_start:forecast_end. + draws (int): The number of draws to generate. + gbd_round_id (int): The numeric ID of GBD round. + ensemble (bool): whether to make an ensemble out of 2/3/4-covariate + forecasts for CCF in (2). + pronatal_bump (Union[float, None]): if pronatal, this value will be + distributed amongst single year ASFR for every future + location/year. + """ + # first collect some metadata needed to execute the stages + asfr_past_version = versions["past"]["asfr"] + + # we fit on nations data and predict to all locations, + # and we constantly switch between 5- and 1-year age groups. + # so we need to start with some location/age metadata here. + locations_df = get_location_set(gbd_round_id=gbd_round_id) + ages_df = get_ages().query(f"{DimensionConstants.AGE_GROUP_ID} in {FERTILE_AGE_IDS}")[ + ["age_group_id", "age_group_years_start", "age_group_years_end"] + ] + + # for modeled forecast, we only want age group ids 8-14 + modeled_ages_df = ages_df.query( + f"{DimensionConstants.AGE_GROUP_ID} in {MODELED_FERTILE_AGE_IDS}" + ) + + # stage 1 + ccfx, last_past_year, last_cohort = stage_1.get_past_asfr_into_ccfx( + asfr_past_version, years, gbd_round_id, draws, locations_df + ) + + # done with stage 1, save data first. + ccf_past_data_path = FBDPath( + gbd_round_id=gbd_round_id, + past_or_future="past", + stage="ccf", + version=versions["future"]["asfr"], + ) + + save_xr( + ccfx, + ccf_past_data_path / "ccfx.nc", + metric="rate", + space="identity", + versions=str(versions), + ensemble_covs=str(StageConstants.ENSEMBLE_COV_COUNTS), + ) + + # stage 2 + if ensemble: + ccf_future, fit_summary = stage_2.ensemble_forecast_ccf( + ccfx, + last_cohort, + versions, + gbd_round_id, + years, + draws, + locations_df, + modeled_ages_df, + ) + else: + ccf_future, fit_summary = stage_2.forecast_ccf( + ccfx, + last_cohort, + versions, + gbd_round_id, + years, + draws, + locations_df, + modeled_ages_df, + ) + + if type(fit_summary) is list: # means we ran an ensemble model + coefficients = pd.concat([tup[0] for tup in fit_summary]) + else: # just a two-tuple, where we only need the first element. + coefficients = fit_summary[0] + + # stage 3 + asfr_future = stage_3.forecast_asfr_from_ccf( + ccfx, + ccf_future, + last_cohort, + gbd_round_id, + versions, + locations_df, + modeled_ages_df, + years, + ) + + # add the pronatal policy boost if so desired. + if pronatal_bump: + asfr_future = add_pronatal_policy( + asfr_future, pronatal_bump, tfr_threshold=PRONATAL_THRESHOLD, ramp_years=RAMP_YEARS + ) + + # done with computations. Now save the results. 
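The pronatal boost applied above is a simple linear ramp (implemented in ``add_pronatal_policy`` later in this file); with a hypothetical ``pronatal_bump`` of 0.2 and the default ``RAMP_YEARS`` of 5, the yearly TFR bump after the threshold is first crossed works out as:

    pronatal_bump, ramp_years = 0.2, 5

    def ramped_bump(years_since_trigger):
        if years_since_trigger >= ramp_years:
            return pronatal_bump
        return pronatal_bump * years_since_trigger / ramp_years

    [round(ramped_bump(t), 2) for t in range(7)]
    # -> [0.0, 0.04, 0.08, 0.12, 0.16, 0.2, 0.2]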
+ save_data_path = FBDPath( + gbd_round_id=gbd_round_id, + past_or_future="future", + stage="asfr", + version=versions["future"]["asfr"], + ) + + # save future files with a scenario dim + future_scenario = versions["future"].get_version_metadata("asfr").scenario + + if future_scenario is None: # then defaults to reference + future_scenario = REFERENCE_SCENARIO_COORD + + save_xr( + asfr_future.expand_dims(dim={DimensionConstants.SCENARIO: [future_scenario]}), + save_data_path / "asfr_single_year.nc", + metric="rate", + space="identity", + versions=str(versions), + ensemble_covs=str(StageConstants.ENSEMBLE_COV_COUNTS), + ) + + coefficients.to_csv(save_data_path / "coefficients.csv") + + asfr_5yr = make_five_year_asfr_from_single_year(asfr_future, ages_df) + + save_xr( + asfr_5yr.expand_dims(dim={DimensionConstants.SCENARIO: [future_scenario]}), + save_data_path / "asfr.nc", + metric="rate", + space="identity", + versions=str(versions), + ensemble_covs=str(StageConstants.ENSEMBLE_COV_COUNTS), + ) + + tfr = asfr_future.sum(DimensionConstants.AGE) + + tfr_data_path = FBDPath( + gbd_round_id=gbd_round_id, + past_or_future="future", + stage="tfr", + version=versions["future"]["asfr"], + ) + + save_xr( + tfr.expand_dims(dim={DimensionConstants.SCENARIO: [future_scenario]}), + tfr_data_path / "tfr.nc", + metric="rate", + space="identity", + versions=str(versions), + ensemble_covs=str(StageConstants.ENSEMBLE_COV_COUNTS), + ) + + # ccf50 requires past/future cohort asfrs because ccf spans many years + asfr_future_cohort = convert_period_to_cohort_space(da=asfr_future) + asfr_past_cohort = xr.concat( + [ccfx.sel(age=COHORT_AGE_START), ccfx.diff(DimensionConstants.AGE)], + dim=DimensionConstants.AGE, + ) + asfr_cohort = asfr_past_cohort.combine_first(asfr_future_cohort) + ccf = asfr_cohort.sel(age=range(COHORT_AGE_START, COHORT_AGE_END + 1)).sum( + DimensionConstants.AGE + ) + + ccf_data_path = FBDPath( + gbd_round_id=gbd_round_id, + past_or_future="future", + stage="ccf", + version=versions["future"]["asfr"], + ) + + save_xr( + ccf.expand_dims(dim={DimensionConstants.SCENARIO: [future_scenario]}), + ccf_data_path / "ccf.nc", + metric="rate", + space="identity", + versions=str(versions), + ensemble_covs=str(StageConstants.ENSEMBLE_COV_COUNTS), + ) + + return + + +def make_five_year_asfr_from_single_year( + asfr_future: xr.DataArray, ages_df: pd.DataFrame +) -> xr.DataArray: + """Make 5-year asfr from single-year asfr by taking the mean. + + The asfr of a five-year age group is the mean of the single-year + asfrs within that group. + + Args: + asfr_future (xr.DataArray): forecasted single-year asfr. + ages_df (pd.DataFrame): age-related metadata. + + Returns: + (xr.DataArray): five-year asfr. + """ + asfr_5yr_list = [] + for age_start in range( + min(YOUNG_TERMINAL_AGES), max(OLD_TERMINAL_AGES) + 1, FIVE_YEAR_AGE_GROUP_SIZE + ): + ages = range(age_start, age_start + FIVE_YEAR_AGE_GROUP_SIZE) + # just take the mean over the single years for 5-year asfr + asfr_5yr = asfr_future.sel(age=ages).mean(DimensionConstants.AGE) + asfr_5yr[DimensionConstants.AGE_GROUP_ID] = int( + ages_df.query(f"age_group_years_start == {age_start}")["age_group_id"] + ) + asfr_5yr_list.append(asfr_5yr) + asfr_5yr = xr.concat(asfr_5yr_list, dim=DimensionConstants.AGE_GROUP_ID) + + return asfr_5yr + + +def add_pronatal_policy( + asfr_single_year: xr.DataArray, pronatal_bump: float, tfr_threshold: float, ramp_years: int +) -> xr.DataArray: + """Implement pronatal policy by adding boost to TFR. 
+
+    The boost is ramped up linearly from 0 to ``pronatal_bump`` over
+    ``ramp_years`` years.
+
+    Go through every location/year in the future and distribute
+    the value amongst ages by the current distribution.
+
+    Args:
+        asfr_single_year (xr.DataArray): forecasted single-year
+            asfr.
+        pronatal_bump (float): the final TFR bump expected from the policy;
+            this value is distributed amongst single-year ASFR for every
+            affected future location/year.
+        tfr_threshold (float): threshold of TFR below which the
+            pronatal policy is triggered.
+        ramp_years (int): number of years over which the pronatal
+            boost is linearly ramped up.
+
+    Returns:
+        (xr.DataArray): post pro-natal policy worldwide asfr.
+    """
+    for location_id in asfr_single_year[DimensionConstants.LOCATION_ID]:
+        policy_already_implemented = False  # starting assumption
+        policy_start_year = None  # starting assumption
+
+        for year_id in asfr_single_year[DimensionConstants.YEAR_ID]:
+            asfr = asfr_single_year.sel(location_id=location_id, year_id=year_id).mean(
+                DimensionConstants.DRAW
+            )
+            tfr = float(asfr.sum(DimensionConstants.AGE))  # tfr = sum over asfr ages
+            # starting year of policy is the first future year below threshold
+            if tfr < tfr_threshold and not policy_already_implemented:
+                policy_already_implemented = True
+                policy_start_year = year_id
+
+            if policy_already_implemented:  # meaning this year gets a boost
+                if year_id - policy_start_year >= ramp_years:  # already at full boost
+                    bump = pronatal_bump
+                else:  # during ramp up
+                    slope = (year_id - policy_start_year) / ramp_years  # ramp up slope
+                    bump = pronatal_bump * slope  # the total tfr bump for this year
+
+                asfr_dist = asfr / tfr  # prob. distribution of asfr over ages
+                asfr_bump = bump * asfr_dist  # distributes boost over ages
+                # now simply add boost to this location/year
+                asfr_single_year.loc[dict(location_id=location_id, year_id=year_id)] = (
+                    asfr_single_year.sel(location_id=location_id, year_id=year_id) + asfr_bump
+                )
+
+    return asfr_single_year
diff --git a/gbd_2021/fertility_forecast_code/fertility/model_strategy.py b/gbd_2021/fertility_forecast_code/fertility/model_strategy.py
new file mode 100644
index 0000000..fd33247
--- /dev/null
+++ b/gbd_2021/fertility_forecast_code/fertility/model_strategy.py
@@ -0,0 +1,93 @@
+"""CCF modeling strategies and their parameters for MRBRT.
+
+**Modeling parameters include:**
+
+* pre/post processing strategy (i.e.
processor object) +* covariates +* cov_models +* study_id columns +""" +from collections import namedtuple +from enum import Enum + +from fhs_lib_data_transformation.lib import processing +from fhs_lib_model.lib import model +from mrtool import LinearCovModel +from stagemodel import OverallModel + +from fhs_pipeline_fertility.lib.constants import ModelConstants, StageConstants + + +class Covariates(Enum): + """Covariates used for modeling fertility.""" + + EDUCATION = "education" + MET_NEED = "met_need" + + +VALID_COVARIATES = tuple(cov.value for cov in Covariates) + + +ModelParameters = namedtuple( + "ModelParameters", + ( + "Model, " + "processor, " + "covariates, " + "node_models, " + "study_id_cols, " + "scenario_quantiles, " + ), +) + +MODEL_PARAMETERS = { + StageConstants.CCF: { + StageConstants.CCF: ModelParameters( + Model=model.MRBRT, + processor=processing.LogitProcessor( + years=None, + offset=ModelConstants.log_logit_offset, # 1e-8 + gbd_round_id=None, + age_standardize=False, + remove_zero_slices=True, + intercept_shift="mean", + ), + covariates={ + "education": processing.NoTransformProcessor( + years=None, gbd_round_id=None, no_mean=True + ), + "met_need": processing.NoTransformProcessor( + years=None, gbd_round_id=None, no_mean=True + ), + "u5m": processing.NoTransformProcessor( + years=None, gbd_round_id=None, no_mean=True + ), + "urbanicity": processing.NoTransformProcessor( + years=None, gbd_round_id=None, no_mean=True + ), + }, + node_models=[ + OverallModel( + cov_models=[ + LinearCovModel("intercept", use_re=False), + LinearCovModel( + "education", + use_re=False, + use_spline=True, + spline_knots=ModelConstants.knot_placements, + spline_knots_type="frequency", + spline_degree=3, + spline_l_linear=True, + spline_r_linear=True, + ), + LinearCovModel("met_need", use_re=False, use_spline=False), + LinearCovModel("u5m", use_re=False, use_spline=False), + LinearCovModel("urbanicity", use_re=False, use_spline=False), + ] + ), + ], + study_id_cols="location_id", + scenario_quantiles={0: None}, + ), + } +} diff --git a/gbd_2021/fertility_forecast_code/fertility/stage_1.py b/gbd_2021/fertility_forecast_code/fertility/stage_1.py new file mode 100644 index 0000000..17e3328 --- /dev/null +++ b/gbd_2021/fertility_forecast_code/fertility/stage_1.py @@ -0,0 +1,99 @@ +"""Fertility stage 1: read in of past asfr into ccfx.""" +import logging + +import numpy as np +import pandas as pd +import xarray as xr +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr +from fhs_lib_year_range_manager.lib.year_range import YearRange + +from fhs_pipeline_fertility.lib.constants import ( + COHORT_AGE_END, + COHORT_AGE_START, + DimensionConstants, + StageConstants, +) +from fhs_pipeline_fertility.lib.input_transform import convert_period_to_cohort_space + +LOGGER = logging.getLogger(__name__) + + +def get_past_asfr_into_ccfx( + asfr_past_version: str, + years: YearRange, + gbd_round_id: int, + draws: int, + locations_df: pd.DataFrame, +) -> xr.DataArray: + """Read in past asfr data and convert to ccfx (cohort spce). + + The returned object has cohort_id instead of year_id. For the forecast + model, we only use ages from COHORT_AGE_START through COHORT_AGE_END. + + Args: + asfr_past_version (str): past asfr version. + years (YearRange): past_start:forecast_start:forecast_end. + gbd_round_id (int): gbd round id. + draws (int): number of draws to keep. + locations_df (pd.DataFrame): all location metadata. 
+ + Returns: + tuple(xr.DataArray | int) Fertility data in cohort space (indexed by + cohort birth year ``cohort_id``): + * ``ccfX``: cumulative cohort fertility up to + ``DimensionConstants.AGE``=``X``. + and + * ``last_past_year``: inferred from past data. Sets cohort ranges. + * ``last_complete_cohort``: the last cohort that has age up to + COHORT_AGE_END. + """ + LOGGER.info("Stage 1: converting past asfr into ccfx") + + # collect all national/subnationals + loc_ids = locations_df[(locations_df["level"] >= 3)].location_id.values + + # now we infer what past years we have from the past asfr data. + path = FBDPath( + gbd_round_id=gbd_round_id, + past_or_future="past", + stage="asfr", + version=asfr_past_version, + ) + + # past data for some reason labels "age" as "age group id" + asfr_past = ( + open_xr(path / "asfr.nc") + .sel( + location_id=loc_ids, + draw=range(draws), + year_id=years.past_years, + age_group_id=range(COHORT_AGE_START, COHORT_AGE_END + 1), + ) + .rename({DimensionConstants.AGE_GROUP_ID: DimensionConstants.AGE}) + ) + + # now that we're done reading in asfr, do some transformation to get ccfx. + past_years = asfr_past[DimensionConstants.YEAR_ID].values.tolist() + last_past_year = max(past_years) # last past year inferred from data. + + asfr_past_cohort = convert_period_to_cohort_space(da=asfr_past) + + valid_cohorts = range( + min(past_years) - COHORT_AGE_START, max(past_years) - COHORT_AGE_START + 1 + ) + + asfr_past_cohort = asfr_past_cohort.sel(cohort_id=valid_cohorts) + + # ccfx has cumulative fertility by age, for all past complete/incomplete + # cohorts. That means we start seeing NaNs in first complete cohort. + ccfx = asfr_past_cohort.cumsum(DimensionConstants.AGE).where(np.isfinite(asfr_past_cohort)) + # after cumsum, change name of dim so its meaning is more clear + ccfx = ccfx.sortby(DimensionConstants.AGE) + ccfx.name = StageConstants.CCF + + last_complete_cohort = last_past_year - COHORT_AGE_END + + LOGGER.info("Done completing the past cohorts...") + + return ccfx, last_past_year, last_complete_cohort diff --git a/gbd_2021/fertility_forecast_code/fertility/stage_2.py b/gbd_2021/fertility_forecast_code/fertility/stage_2.py new file mode 100644 index 0000000..1906aab --- /dev/null +++ b/gbd_2021/fertility_forecast_code/fertility/stage_2.py @@ -0,0 +1,723 @@ +"""Stage 2: forecasting past ccf using MRBRT.""" +import gc +import logging +from typing import List, Union + +import numpy as np +import pandas as pd +import xarray as xr +import xskillscore +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_model.lib.constants import ModelConstants +from fhs_lib_model.lib.model import MRBRT +from fhs_lib_year_range_manager.lib.year_range import YearRange +from mrtool import MRData +from scipy.special import expit +from statsmodels.tsa.arima.model import ARIMA, ARIMAResults + +from fhs_pipeline_fertility.lib import model_strategy +from fhs_pipeline_fertility.lib.constants import ( + CCF_BOUND_TOLERANCE, + CCF_LOWER_BOUND, + CCF_UPPER_BOUND, + COHORT_AGE_END, + COVARIATE_START_AGE, + PAST_MET_NEED_YEAR_START, + REFERENCE_SCENARIO_COORD, + DimensionConstants, + StageConstants, +) +from fhs_pipeline_fertility.lib.input_transform import get_covariate_data +from fhs_pipeline_fertility.lib.stage_3 import ordered_draw_intercept_shift + +LOGGER = logging.getLogger(__name__) + + +def ensemble_forecast_ccf( + ccfx: xr.DataArray, + last_cohort: int, + versions: Versions, + gbd_round_id: int, + years: YearRange, + draws: int, + locations_df: pd.DataFrame, + 
ages_df: pd.DataFrame,
+) -> tuple:
+    """Ensemble forecast of ccf.
+
+    Runs one submodel per covariate count in ``ENSEMBLE_COV_COUNTS`` and then
+    samples draws equally from the submodels.
+
+    Args:
+        ccfx (xr.DataArray): past cumulative cohort fertility (ccfx) draws.
+        last_cohort (int): last completed cohort.
+        versions (Versions): All relevant input versions.
+        gbd_round_id (int): gbd round id.
+        years (YearRange): past_start:forecast_start:forecast_end
+            in ``cohort`` space. Only forecast_end will be used.
+        draws (int): number of draws to keep.
+        locations_df (pd.DataFrame): all locations metadata.
+        ages_df (pd.DataFrame): age group metadata for the modeled ages.
+
+    Returns:
+        (tuple): forecasted ccf (xr.DataArray) and the list of submodel fit
+            summaries.
+    """
+    das = []
+    fit_summaries = []
+    n_submodels = len(StageConstants.ENSEMBLE_COV_COUNTS)
+    n_sub_draws = int(np.round(draws / n_submodels))
+
+    # in each iteration we run the model with a different number of covariates
+    for i, n_cov in enumerate(StageConstants.ENSEMBLE_COV_COUNTS):
+        keepers = StageConstants.ORDERED_COVARIATES[:n_cov]  # covariates needed
+        keepers.append("asfr")  # always keep this
+        # essentially we make a mock Versions object here
+        # v should look something like "{epoch}/{stage}/{version}"
+        submodel_versions = Versions(
+            *[v for v in versions.version_strs() if v.split("/")[1] in keepers]
+        )
+        # unfortunately we run the whole gamut, with all draws, and then subset draws
+        sub_da, fit_summary = forecast_ccf(
+            ccfx,
+            last_cohort,
+            submodel_versions,
+            gbd_round_id,
+            years,
+            draws,
+            locations_df,
+            ages_df,
+        )
+
+        # now subset draws
+        if i < n_submodels - 1:
+            sub_draw_range = range(i * n_sub_draws, (i + 1) * n_sub_draws)
+        else:
+            sub_draw_range = range(i * n_sub_draws, draws)
+
+        sub_da = sub_da.sel(draw=sub_draw_range)
+
+        das.append(sub_da)
+        fit_summaries.append(fit_summary)
+
+    ccf_future = xr.concat(das, dim=DimensionConstants.DRAW)
+
+    return ccf_future, fit_summaries
+
+
+def forecast_ccf(
+    ccfx: xr.DataArray,
+    last_cohort: int,
+    versions: Versions,
+    gbd_round_id: int,
+    years: YearRange,
+    draws: int,
+    locations_df: pd.DataFrame,
+    ages_df: pd.DataFrame,
+) -> tuple:
+    """Forecast ccf using MRBRT.
+
+    Forecasts ccf in three steps:
+    1.) fit on the past means (ccf and covariates).
+    2.) use (1) to obtain the fitted past mean and MRBRT-sampled future draws.
+    3.) take results from (2), run arima, to obtain past/future draws.
+
+    Args:
+        ccfx (xr.DataArray): past cumulative cohort fertility (ccfx) draws.
+        last_cohort (int): last completed cohort.
+        versions (Versions): All relevant input versions.
+        gbd_round_id (int): gbd round id.
+        years (YearRange): past_start:forecast_start:forecast_end
+            in ``cohort`` space. Only forecast_end will be used.
+        draws (int): number of draws to keep.
+        locations_df (pd.DataFrame): all locations metadata.
+        ages_df (pd.DataFrame): age group metadata for the modeled ages.
+
+    Returns:
+        (tuple): forecasted ccf (xr.DataArray) and the model fit summary.
+    """
+    # prep and subset data to get ccf50
+    ccf_past = ccfx.sel(age=COHORT_AGE_END).drop(DimensionConstants.AGE)
+    # to standardize across covariates, we start analyzing past cohorts from
+    # cohort_id = PAST_MET_NEED_YEAR_START (1970) - COVARIATE_START_AGE (25)
+    # = 1945, to cohort_id = last_cohort
+    first_cohort = PAST_MET_NEED_YEAR_START - COVARIATE_START_AGE
+    ccf_past = ccf_past.sel(cohort_id=range(first_cohort, last_cohort + 1))
+    ccf_past.name = StageConstants.CCF
+
+    # it turns out that mrbrt in stage 2 only takes year_id for time axis,
+    # and so we must rename cohort_id to year_id.
+    ccf_past = ccf_past.rename({DimensionConstants.COHORT_ID: DimensionConstants.YEAR_ID})
+
+    # Now we prep the covariate data.
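Stepping back to ``ensemble_forecast_ccf`` above: draws are split as evenly as possible across the 2-, 3- and 4-covariate submodels, with the last submodel absorbing any remainder. For example, with 100 draws:

    import numpy as np

    draws, cov_counts = 100, [2, 3, 4]         # ENSEMBLE_COV_COUNTS
    n_sub = int(np.round(draws / len(cov_counts)))
    ranges = [
        range(i * n_sub, (i + 1) * n_sub) if i < len(cov_counts) - 1
        else range(i * n_sub, draws)
        for i in range(len(cov_counts))
    ]
    # -> [range(0, 33), range(33, 66), range(66, 100)], i.e. 33 + 33 + 34 draws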
+ # research decision: fit the past on national locations + location_ids_fit = locations_df.query("level == 3").location_id.to_list() + + # combining past & future, we need the following years from covariates + covariate_years_needed = list( + range( + min(ccf_past[DimensionConstants.YEAR_ID].values) + COVARIATE_START_AGE, + years.forecast_end - COHORT_AGE_END + COVARIATE_START_AGE + 1, + ) + ) + + # infer the covariates (ordered) we have, based on the python call inputs + covariates = [ + cov for cov in StageConstants.ORDERED_COVARIATES if cov in versions["past"].keys() + ] + + # collect covariate data into a list + covariate_data_list = _get_list_of_covariate_data( + covariates, + gbd_round_id, + versions, + covariate_years_needed, + draws=draws, + location_ids=ccf_past[DimensionConstants.LOCATION_ID].values.tolist(), + ages_df=ages_df, + ) + + # pull from model_strategy.py some objects needed for running MRBRT + Model, OverallModel_node_models, study_id_cols = get_mrbrt_model_objects() + + # filter cov_models for only the covariates for which we have data + cov_models = [ + cov + for cov in OverallModel_node_models[0].cov_models + if cov.name in covariates or cov.name == "intercept" + ] + + # research decision: take mean before logit transform + ccf_past_mean_logit = logit_bounded( + ccf_past.mean(DimensionConstants.DRAW), + lower=CCF_LOWER_BOUND, + upper=CCF_UPPER_BOUND, + tol=CCF_BOUND_TOLERANCE, + ) + + # define mrbrt model object + LOGGER.info("Initiatializing MRBRT model object....") + + se_col_name = str(ccf_past.name) + ModelConstants.STANDARD_ERROR_SUFFIX + + def df_func(df: pd.DataFrame) -> pd.DataFrame: + # fit on nationals only, so df with only past years is filtered + if years.forecast_start not in df[DimensionConstants.YEAR_ID].values: + df = df.query(f"{DimensionConstants.LOCATION_ID} in {location_ids_fit}") + df[se_col_name] = 1 # set this way so mrbrt ignores random effects + df["intercept"] = 1 # Peng: add if "intercept" in LinearCovModel + return df + + ccf_past_start = int(ccf_past[DimensionConstants.YEAR_ID].min()) + ccf_forecast_start = int(ccf_past[DimensionConstants.YEAR_ID].max()) + 1 + ccf_forecast_end = years.forecast_end - COHORT_AGE_END + ccf_years = YearRange(ccf_past_start, ccf_forecast_start, ccf_forecast_end) + + # MRBRT class requires past_data to have scenario=0 dim + mrbrt = Model( + past_data=ccf_past_mean_logit.expand_dims(scenario=[REFERENCE_SCENARIO_COORD]), + years=ccf_years, + draws=draws, + cov_models=cov_models, + study_id_cols=study_id_cols, + covariate_data=covariate_data_list, + df_func=df_func, + index_cols=DimensionConstants.CCF_INDEX_DIMS, + ) + + LOGGER.info("Fitting....") + + mrbrt.fit(outer_max_iter=500) + + fit_summary = mrbrt.model_instance.summary() + + # NOTE make sure the kwargs are specified correctly + LOGGER.info("Done fitting. Now making prediction") + + # get fitted past means, and future draws from var/covar matrix. + np.random.seed(gbd_round_id) + df_past_mean, df_future = create_ccf_future_uncertainty(mrbrt, ccf_years) + + # the subsequent arima code wants column names "observed" and "predicted". 
+ # we already have "predicted", now we need to rename "ccf" to "observed" + df_past_mean = df_past_mean.rename(columns={mrbrt._orig_past_data.name: "observed"}) + df_future = df_future.rename(columns={mrbrt._orig_past_data.name: "observed"}) + + del mrbrt + gc.collect() + + # perform arima using logit-space mean-level residual + ccf_future_arima_logit = residual_arima_by_locations( + df_past_mean, + df_future, + ccf_years, + arima_attenuation_end_year=StageConstants.STAGE_2_ARIMA_ATTENUATION_END_YEAR, + ) + + # to bring back to normal space + ccf_future = expit_bounded( + ccf_future_arima_logit, lower=CCF_LOWER_BOUND, upper=CCF_UPPER_BOUND + ) + + del ccf_future_arima_logit + gc.collect() + + # now ordered-draw intercept-shift in normal space + ccf_future = ordered_draw_intercept_shift(ccf_future, ccf_past, ccf_years) + + ccf_future = ccf_future.where(ccf_future >= 0).fillna(0) # safeguard + + # last bit of clean up before returning + ccf_future = ccf_future.rename({DimensionConstants.YEAR_ID: DimensionConstants.COHORT_ID}) + + ccf_future = ccf_future.sel( + cohort_id=range(ccf_years.forecast_start, ccf_years.forecast_end + 1) + ) + ccf_future.name = StageConstants.CCF + + # now compute some skills and attach to fit_summary (1-tuple of dataframe) + df_past_mean = df_past_mean.set_index( + [DimensionConstants.LOCATION_ID, DimensionConstants.YEAR_ID] + ) + observed = df_past_mean["observed"].to_xarray() + predicted = df_past_mean["predicted"].to_xarray() + fit_summary[0]["rmse"] = float(xskillscore.rmse(observed, predicted)) + fit_summary[0]["r2"] = float(xskillscore.r2(observed, predicted)) + + return ccf_future, fit_summary + + +def get_mrbrt_model_objects() -> tuple: + """Coarse-grain the obtainment of model-related objects for MRBRT. + + Pull from model_strategy.py for mrbrt-specific paramters. + + Returns: + (tuple): objects useful for running MRBRT forecast + """ + stage_model_parameters = model_strategy.MODEL_PARAMETERS[StageConstants.CCF] + model_parameters = stage_model_parameters[StageConstants.CCF] + + ( + Model, + processor, + _, + OverallModel_node_models, + study_id_cols, + scenario_quantiles, + ) = model_parameters + + return Model, OverallModel_node_models, study_id_cols + + +def _get_list_of_covariate_data( + covariates: List[str], + gbd_round_id: int, + versions: Versions, + covariate_years_needed: List[int], + draws: int, + location_ids: List[int], + ages_df: pd.DataFrame, +) -> List[xr.DataArray]: + """Helper function for forecast_ccf() to collect covariate data. + + Args: + covariates (List[str]): names of covariates. + gbd_round_id (int): gbd round id. + versions (Versions): object containing all input/output + versions metadata. + covariate_years_needed (List[int]): list of year_ids needed + from all covariates. + draws (int): number of draws for run. + location_ids (List[int]): list of location_ids needed + from all covariates. + ages_df (pd.DataFrame): all age group-related metadata. + + Returns: + (list[xr.DataArray]): list of covariate data, each containing both + past and future. 
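+
+    Note:
+        Each covariate's ``year_id`` is shifted back by ``COVARIATE_START_AGE``
+        below, so that (with ``COVARIATE_START_AGE`` = 25) a covariate value
+        observed in 1995 is matched to the 1970 birth cohort.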
+ + """ + covariate_data_list = [] + + for cov_name in covariates: + past_version = versions["past"][cov_name] + future_version = versions["future"][cov_name] + future_scenario = versions["future"].get_version_metadata(cov_name).scenario + + if future_scenario is None: # then defaults to reference + future_scenario = REFERENCE_SCENARIO_COORD + + covariate_age_group_id = _map_cov_name_to_age_group_id(cov_name, ages_df) + + # this pulls both past and future and returns them in one big object + cov_data = get_covariate_data( + stage=cov_name, + future_version=future_version, + past_version=past_version, + year_ids=covariate_years_needed, + gbd_round_id=gbd_round_id, + location_ids=location_ids, + scenario=future_scenario, + draws=draws, + age_group_ids=covariate_age_group_id, + ) + + # MRBRT._convert_covariates expects a scenario = 0 dim + cov_data = cov_data.expand_dims(scenario=[REFERENCE_SCENARIO_COORD]) + + # shift ccf covariates year_id to match their cohorts. + cov_data[DimensionConstants.YEAR_ID] = ( + cov_data[DimensionConstants.YEAR_ID] - COVARIATE_START_AGE + ) + + covariate_data_list.append(cov_data) + + return covariate_data_list + + +def _map_cov_name_to_age_group_id(cov_name: str, ages_df: pd.DataFrame) -> int: + """Map the given covariate name to the age group id(s) it contains. + + Args: + cov_name (str): name of covariate. + ages_df (pd.DataFrame): contains age group-related metadata. + + Returns: + (int): age group id for this particular covariate. + """ + # these are hard-coded quirks related to each upstream file + if cov_name == "urbanicity": + covariate_age_group_id = 22 + elif cov_name == "u5m": + covariate_age_group_id = 1 + else: # age group id of 25-30 year olds + covariate_age_group_id = int( + ages_df.query("age_group_years_start == @COVARIATE_START_AGE")[ + DimensionConstants.AGE_GROUP_ID + ] + ) + + return covariate_age_group_id + + +def create_ccf_future_uncertainty(model: MRBRT, years: YearRange) -> MRBRT: + """Create future uncertainty. + + Modifies in-place the .prediction_df object of ``model``. + + 1.) use MRBRT.predict() to make both past fit based on mean values. + 2.) use MRBRT var/covar methods to create future draws, starting with mean. + 3.) return past mean and future draws for our custom arima function. + + Args: + model (model.MRBRT): MRBRT model object, after fit. + years (YearRange): past_start:forecast_start:forecast_end. + + Returns: + (MRBRT) mrbrt model with modified .prediction_df dataframe. + """ + # first take the MEAN past ccf/covs values. We take it from .prediction_df + # because prediction_df has sub-national locations, whereas .combined_df + # was used for fitting and only contains national locations. + df_past_mean = ( + model.prediction_df.query(f"year_id in {years.past_years.tolist()}") + .groupby(DimensionConstants.CCF_INDEX_DIMS) + .mean() + .reset_index() + ) + + # now introduce the pred_logit fit to the MEAN past ccf/covs. + # model.model_instance already has fitting coefficients based on nationals. + # But that also means model.combined_mr was made with only sub-nationals, + # so in order to fit all (including subnats) past locations (using + # already-obtained fitting coefficients) one must make a new MRData object + # with this df_past_mean to run model.model_instance.predict() on. 
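+    # MRData is mrtool's data container: load_df() below maps df_past_mean's
+    # observation, standard-error, covariate, and study-id columns so that
+    # model_instance.predict() can be evaluated on the mean past data for all
+    # locations (nationals and subnationals) using the already-fitted
+    # coefficients.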
+ mr_data = MRData() + mr_data.load_df( + data=df_past_mean, + col_obs=model._orig_past_data.name, + col_obs_se=(model._orig_past_data.name + ModelConstants.STANDARD_ERROR_SUFFIX), + col_covs=model.col_covs + model.index_cols, + col_study_id=model.study_id_col_name, + ) + df_past_mean["predicted"] = model.model_instance.predict( + data=mr_data, predict_for_study=False, sort_by_data_id=True + ) + + # new future predictions (with draws) + df_future = _create_uncertainty(model, years) # using mean of cov draws + + # these two will be used downstream for arima + return df_past_mean, df_future + + +def _create_uncertainty(model: MRBRT, years: YearRange) -> pd.DataFrame: + """Create uncertainty using covariate means. + + Args: + model (model.MRBRT): MRBRT model object, after fit. + years (YearRange): past_start:forecast_start:forecast_end. + + Returns: + (pd.DataFrame): df with new draws. + """ + sample_size = model.draws + beta_samples, gamma_samples = model.model_instance.sample_soln(sample_size) + + # df_future contains all MEAN future cov data, indexed by location/year + df_future = ( + model.prediction_df[model.col_covs + model.index_cols] + .query(f"year_id in {years.forecast_years.tolist()}") + .groupby(DimensionConstants.CCF_INDEX_DIMS) + .mean() + .reset_index() + ) + + df_future = _get_draws(df_future, model, beta_samples, gamma_samples) + + # df_future is wide-by-draw, and our canonical format is long-by-draw. + # so here we do some memory-expensive transformation. + df_future = _wide_to_long_by_draws(df_future, model.index_cols, sample_size) + + return df_future + + +def _wide_to_long_by_draws( + df: pd.DataFrame, index_cols: List[str], sample_size: int +) -> pd.DataFrame: + """Convert wide-by-draw df to long-by-draw df. + + Args: + df (pd.DataFrame): wide-by-draw df with "draw_{}" columns. + index_cols (List[str]): list of index column names. + sample_size (int): number of draws. + + Returns: + (pd.DataFrame): long-by-draw df, with "draw" columns. + """ + df.columns = [col.replace("draw_", "") for col in df.columns] + + df = pd.melt( + df, + id_vars=index_cols, + value_vars=[str(i) for i in range(sample_size)], + var_name=DimensionConstants.DRAW, + value_name="predicted", + ) + + df = df.astype({DimensionConstants.DRAW: int}) + + return df + + +def _get_draws( + df: pd.DataFrame, model: MRBRT, beta_samples: np.ndarray, gamma_samples: np.ndarray +) -> pd.DataFrame: + """Predict MRBRT model with uncertainty from covariance matrix. + + Args: + df (pd.DataFrame): has future covariates, mean by draw. + model (model.MRBRT): MRBRT object. + beta_samples (np.ndarray): beta samples. + gamme_samples (np.ndarray): gamma samples. + + Returns: + (pd.DataFrame): wide-by-draw dataframe of forecast. + """ + data_pred = MRData() + data_pred.load_df( + df, # df has the future covariates, mean-by-draws + col_covs=model.col_covs + model.index_cols, + col_study_id=DimensionConstants.LOCATION_ID, + ) + draw_ids = [f"draw_{i}" for i in range(model.draws)] + result = model.model_instance.create_draws( + data_pred, beta_samples, gamma_samples, sort_by_data_id=True + ) + result = pd.DataFrame(result, columns=draw_ids) + result = pd.concat([df[model.index_cols].reset_index(drop=True), result], axis=1) + return result + + +# this function is also used by stage 3 +def residual_arima_by_locations( + df_past_mean: pd.DataFrame, + df_future: pd.DataFrame, + years: YearRange, + arima_attenuation_end_year: Union[int, None], +) -> pd.DataFrame: + """Introduce arima uncertainty to forecasted draws. 
+ + For every location: + + 1.) use arima(1,0,0) to fit the residual between true past mean + and predicted past mean. + 2.) use said arima to forecast future residual. + + Args: + df_past_mean (pd.DataFrame): data frame that contains mean past + observed data and its fit. + df_future (pd.DataFrame): predicted future draws, and covariates. + years (YearRange): past_start:forecast_start:forecast_end. + If used in stage 2, then these are probably cohort years. + In stage 3, the years are in period space. + arima_attenuation_end_year (Union[int, None]): last year of + finite arima value. None means no attenuation. + + Returns: + (pd.DataFrame): data frame with arimaed residuals. + """ + basic_index_dims = DimensionConstants.CCF_INDEX_DIMS # for past/future + arima_pasts = [] + arima_forecasts = [] + + for location_id in df_past_mean[DimensionConstants.LOCATION_ID].unique(): + df_loc = df_past_mean.query(f"location_id == {location_id}") + df_loc = df_loc.sort_values(DimensionConstants.YEAR_ID) + + if not (np.diff(df_loc[DimensionConstants.YEAR_ID].values) == 1).all(): + raise ValueError("Arima requires sequential past years.") + + # we want to compute residual = mean(true past) - mean(predicted past) + resid = (df_loc["observed"] - df_loc["predicted"]).values + + arima = ARIMA(resid, order=(1, 0, 0), trend="n") + arima_fit = arima.fit(method="innovations_mle") + + arima_past = arima_fit.fittedvalues + + arima_past = pd.DataFrame( + { + DimensionConstants.LOCATION_ID: np.repeat(location_id, len(arima_past)), + DimensionConstants.YEAR_ID: df_loc[DimensionConstants.YEAR_ID].values, + "residual_arima": arima_past, + } + ) + + arima_pasts.append(arima_past) # collect arima fits of the past + + # NOTE: this attenuation scheme is designed for ARIMA 100 + if arima_attenuation_end_year is not None: + arima_forecast = attenuate_arima(arima_fit, years, arima_attenuation_end_year) + else: # normal arima forecast + arima_forecast = arima_fit.forecast(steps=len(years.forecast_years)) + + arima_forecast = pd.DataFrame( + { + DimensionConstants.LOCATION_ID: np.repeat( + location_id, len(years.forecast_years) + ), + DimensionConstants.YEAR_ID: years.forecast_years, + "residual_arima": arima_forecast, + } + ) + + arima_forecasts.append(arima_forecast) # collect arima forecasts + + arima_past = pd.concat(arima_pasts, axis=0) + arima_forecast = pd.concat(arima_forecasts, axis=0) + + del arima_pasts, arima_forecasts + gc.collect() + + # add arima to mrbrt forecast to make final forecast + arima_past = df_past_mean[basic_index_dims + ["predicted"]].merge( + arima_past, on=basic_index_dims, how="right" + ) + + # replace predicted with arimaed predicted, and drop un-needed column + arima_past["predicted"] = (arima_past["predicted"] + arima_past["residual_arima"]).drop( + columns="residual_arima" + ) + + # convert to xr.DataArray + arima_past = arima_past.set_index(basic_index_dims).to_xarray()["predicted"] + + # expand the past draw dim so we can later concat past and future + arima_past = arima_past.expand_dims( + draw=df_future[DimensionConstants.DRAW].unique().tolist() + ) + + # now the future part + arima_forecast = df_future[ + basic_index_dims + ["predicted", DimensionConstants.DRAW] + ].merge(arima_forecast, on=basic_index_dims, how="right") + + # replace predicted with arimaed predicted, and drop unused column + arima_forecast["predicted"] = ( + arima_forecast["predicted"] + arima_forecast["residual_arima"] + ).drop(columns="residual_arima") + + # convert to xr.DataArray + arima_forecast = 
arima_forecast.set_index( + basic_index_dims + [DimensionConstants.DRAW] + ).to_xarray()["predicted"] + + # return both past and future in one dataarray + return xr.concat([arima_past, arima_forecast], dim=DimensionConstants.YEAR_ID) + + +def logit_bounded(x: xr.DataArray, lower: float, upper: float, tol: float) -> xr.DataArray: + r"""Compute the bounded logit transformation. + + The transformation is defined by + + ..math:: + f(x, a, b) =\ + \log \left( \frac{x_\text{trunc} - a}{b - x_\text_trunc} \right) + + where :math:`a` is the lower bound, :math:`b` is the upper bound, and + :math:`x_\text{trunct}` is :math:`x` truncated to be in the closed interval + :math:`[a + \text{tol}, b - \text{tol}]` (to avoid undefined logit). + + Args: + x (xr.DataArray): data to be transformed. + lower (float): lower bound. + upper (float): upper bound + tol (float): tolerance for approaching the bounds. + + Returns: + xr.DataArray: the supplied data transformed to bounded logit space. + """ + x = x.clip(min=lower + tol, max=upper - tol) + return np.log((x - lower) / (upper - x)) + + +def expit_bounded(x: xr.DataArray, lower: float, upper: float) -> xr.DataArray: + r"""Compute the inverse of bounded logit transoformation :func:`logit_bounded`. + + Args: + x (xr.DataArray): data to be transformed. + lower (float): lower bound. + upper (float): upper bound. + """ + y = (upper - lower) * expit(x) + lower + return y + + +def attenuate_arima( + arima_fit: ARIMAResults, years: YearRange, arima_attenuation_end_year: int +) -> np.array: + """Linearly attenuates arima forecast to 0 by input end year. + + Args: + arima_fit (ARIMAResults): arima fit from running arima.fit(). + years (YearRange): past start:forecast start:forecast end. + If used in stage 2, then these are likely ``cohort`` years. + arima_attenuation_end_year (int): last year of non-zero arima. + + Returns: + (np.array): forecasted arima values, attenuated to zero by + the prescribed end year. + """ + # - is from statsmodel convention + starting_coeff = -float(arima_fit.polynomial_ar[1]) + yearly_coeff_change = -starting_coeff / (arima_attenuation_end_year - years.forecast_start) + forecast_value = arima_fit.fittedvalues[-1] # starting point + arima_forecast = np.zeros(len(years.forecast_years)) # initialize to all zeros + + # linearly attenuate + for i in range(arima_attenuation_end_year - years.forecast_start): + arima_coeff = starting_coeff + (i * yearly_coeff_change) + forecast_value = forecast_value * arima_coeff + arima_forecast[i] = forecast_value + + return arima_forecast diff --git a/gbd_2021/fertility_forecast_code/fertility/stage_3.py b/gbd_2021/fertility_forecast_code/fertility/stage_3.py new file mode 100644 index 0000000..1010aef --- /dev/null +++ b/gbd_2021/fertility_forecast_code/fertility/stage_3.py @@ -0,0 +1,947 @@ +"""Fertility stage 3. + +1.) Unfold forecasted CCF50 back to 5-year-group cohort fertility via + linear mixed effects model, using education and met_need as + fixed effect covariates, and region_id as random intercept. +2.) Interpolate the 5-year-group ASFR to get single year ASFR. +3.) Apply same ARIMA model used in stage 2 (now in period space) to single year ASFR. 
+""" +import gc +import itertools as it +import logging +from typing import Iterable, Tuple + +import numpy as np +import pandas as pd +import statsmodels.formula.api as smf +import xarray as xr +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_file_interface.lib.versioning import Versions +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr +from fhs_lib_year_range_manager.lib.year_range import YearRange +from scipy.special import expit, logit + +from fhs_pipeline_fertility.lib import stage_2 +from fhs_pipeline_fertility.lib.constants import ( + COHORT_AGE_END, + COHORT_AGE_START, + FIVE_YEAR_AGE_GROUP_SIZE, + OLD_TERMINAL_AGES, + PAST_MET_NEED_YEAR_START, + REFERENCE_SCENARIO_COORD, + YOUNG_TERMINAL_AGES, + DimensionConstants, + StageConstants, +) +from fhs_pipeline_fertility.lib.input_transform import get_covariate_data + +LOGGER = logging.getLogger(__name__) + + +def forecast_asfr_from_ccf( + ccfx: xr.DataArray, + ccf50_future: xr.DataArray, + last_cohort: int, + gbd_round_id: int, + versions: Versions, + locations_df: pd.DataFrame, + ages_df: pd.DataFrame, + years: YearRange, +) -> xr.DataArray: + """Forecast period-space asfr using forecasted ccf and education/met_need. + + This pipeline includes four/five steps: + + 1.) prepare forecasted education and met need for regression. + 2.) forecast logit(5-year-cohort-fertility / ccf50) using linear mixed + effects model, with education and met need as covariates, region_id + as random intercept. + 3.) interpolate between cohort 5-year age groups to get single-year + cohort asfr. + 4.) convert cohort asfr to period-space asfr and perform arima for future + uncertainty. + + Args: + ccfx (xr.DataArray): past ccfx (with fill-in). This has NaNs. + ccf50_future (xr.DataArray): forecasted ccf50. + last_cohort (int): last completed past cohort. + gbd_round_id (int): gbd round id. + versions (Versions): contains all relevant past/future versions. + locations_df (pd.DataFrame): locations metadata. + ages_df (pd.DataFrame): ages metadata. + years (YearRange): past_start:forecast_start:forecast_end. + + Returns: + (tuple[xr.DataArray]): past/future ccf50, logit(asfr/ccf50), + education, met need. 
+ """ + LOGGER.info("Stage 3 step 1, prepping covariate data...") + + education, met_need = prepare_covariate_data(versions, gbd_round_id, ccf50_future, ages_df) + + LOGGER.info("Stage 3 step 2, LME regression...") + # now begins step 2 (regression) of stage 3 + # where we separate fits between high-income and non-high-income locations + cohort_5yr_fert_future, cohort_5yr_fert_past_fit = forecast_five_year_cohort_asfr( + ccfx, ccf50_future, education, met_need, locations_df + ) + + del education, met_need + gc.collect() + + LOGGER.info("Stage 3 step 3, interpolation to single-year asfr...") + # step 3 is the interpolation to single year asfr in cohort space + asfr_single_year, asfr_single_year_past_mean_fit = interpolate_for_single_year_cohort_asfr( + cohort_5yr_fert_future, cohort_5yr_fert_past_fit + ) + + del cohort_5yr_fert_future, cohort_5yr_fert_past_fit + gc.collect() + + LOGGER.info("Stage 3 step 4, arima of single year asfr...") + # step 4 is the arima of single year asfr + asfr_single_year = arima_single_year_asfr( + ccfx, + asfr_single_year_past_mean_fit, + asfr_single_year, + years, + gbd_round_id, + versions, + ages_df, + ) + + return asfr_single_year + + +def prepare_covariate_data( + versions: Versions, gbd_round_id: int, ccf50_future: xr.DataArray, ages_df: pd.DataFrame +) -> Tuple[xr.DataArray, xr.DataArray]: + """Prepare education and met_need covariates for logit(asfr / ccf50) regression. + + Args: + versions (Versions): input versions metadata. + gbd_round_id (int): gbd round id. + ccf50_future (xr.DataArray): forecasted ccf, used to help inform the + reading-in of covariate files. + ages_df (pd.DataFrame): age-related metadata. + + Returns: + (Tuple[xr.DataArray, xr.DataArray]): two-tuple of education and met_need. + """ + covariates = StageConstants.STAGE_3_COVARIATES + covariate_list = [] + + for cov_name in covariates: + past_version = versions["past"][cov_name] + future_version = versions["future"][cov_name] + future_scenario = versions["future"].get_version_metadata(cov_name).scenario + + if future_scenario is None: # then defaults to reference + future_scenario = REFERENCE_SCENARIO_COORD + + # TODO confirm the 41 is necessary + cov_data = get_covariate_data( + stage=cov_name, + future_version=future_version, + past_version=past_version, + year_ids=list( + range(PAST_MET_NEED_YEAR_START, int(ccf50_future.cohort_id.max()) + 41) + ), + gbd_round_id=gbd_round_id, + location_ids=ccf50_future[DimensionConstants.LOCATION_ID].values.tolist(), + scenario=future_scenario, + draws=len(ccf50_future[DimensionConstants.DRAW].values.tolist()), + age_group_ids=ages_df[DimensionConstants.AGE_GROUP_ID].values.tolist(), + ) + + # we chose to do things at mean level + cov_data = _convert_incremental_covariates_to_cohort_space(cov_data, ages_df) + + covariate_list.append(cov_data) + + education, met_need = covariate_list + + return education, met_need + + +def forecast_five_year_cohort_asfr( + ccfx: xr.DataArray, + ccf50_future: xr.DataArray, + education: xr.DataArray, + met_need: xr.DataArray, + locations_df: pd.DataFrame, +) -> Tuple[xr.DataArray, xr.DataArray]: + """Fit and forecast logit(5-year-cohort-fertility / ccf50). + + Everything in cohort space. + + Start with ccfx, labeled by single years, to compute cohort fertility + by 5-year age groups (sometimes referred to as cohort asfr). + + Then fit these 5-year cohort fertility rates over cohort_id, + using linear mixed effects model with education & met need + as fixed effect covariates, and region_id as random intercept. 
+ + Returns predicted future draws and past means of cohort asfr. + + Args: + ccfx (xr.DataArray): past ccfx, by single years. + ccf50_future (xr.DataArray): forecasted ccf50. + education (xr.DataArray): education, past and future. Probably does + not go as far back as ccf due to data limitations. + met_need (xr.DataArray): met_need, past and future. Probably does not + go as far back as ccf due to data limitations. + locations_df (pd.DataFrame): location metadata. + + Returns: + (Tuple[xr.DataArray, xr.DataArray]): predicted future draws and + fitted past means of cohort asfr. + """ + # first compute the 5-year-age-group cohort fertility (cohort asfr) + cohort_5yr_fert = _make_5yr_cohort_fertility_from_ccfx(ccfx) + + # need some additional + last_past_cohort = ccf50_future[DimensionConstants.COHORT_ID].values.min() - 1 + # this will make a ccf50 with past through future. + ccf = ccfx.sel(age=COHORT_AGE_END, drop=True).combine_first(ccf50_future) + + # one fit/predict for high-income locations + laf_hi, laf_hi_mean_past_fit = _age_group_specific_fit_and_predict( + cohort_5yr_fert, + ccf, + education, + met_need, + last_past_cohort, + locations_df.query("super_region_name == 'High-income'"), + ) + # another fit/predict for non-high-income locations + laf_nhi, laf_nhi_mean_past_fit = _age_group_specific_fit_and_predict( + cohort_5yr_fert, + ccf, + education, + met_need, + last_past_cohort, + locations_df.query("super_region_name != 'High-income'"), + ) + # concat to get back all locations + laf_prediction = xr.concat([laf_hi, laf_nhi], dim=DimensionConstants.LOCATION_ID) + + laf_mean_past_fit = xr.concat( + [laf_hi_mean_past_fit, laf_nhi_mean_past_fit], dim=DimensionConstants.LOCATION_ID + ) + + # intercept-shift within fit means asfr fractions won't sum to 1 over age groups + ccf_fractions_future = expit(laf_prediction) # needs renormalization + ccf_fractions_future = ccf_fractions_future / ccf_fractions_future.sum( + DimensionConstants.AGE + ) # renormalization of fractional ccf + + cohort_5yr_fert_future = ccf * ccf_fractions_future # drops past + + # do the same renormalization for past mean fit + # but keep in mind that there are nans in very early cohorts, in the later + # years, due to the covariates past data starting later. Hence we must + # remove those cohorts with nans first, otherwise they'd be + # disproportionally renormalized. + # complete cohorts start from first met_need year - cohort_age_start, + # and end at last_past_cohort + laf_mean_past_fit = laf_mean_past_fit.sel( + cohort_id=range(PAST_MET_NEED_YEAR_START - COHORT_AGE_START, last_past_cohort + 1) + ) + ccf_frac_mean_past_fit = expit(laf_mean_past_fit) + # same fallacy as earlier + ccf_frac_mean_past_fit = ccf_frac_mean_past_fit / ccf_frac_mean_past_fit.sum( + DimensionConstants.AGE + ) + cohort_5yr_fert_past_fit = ( + ccf.sel(cohort_id=ccf_frac_mean_past_fit[DimensionConstants.COHORT_ID]).mean( + DimensionConstants.DRAW + ) + * ccf_frac_mean_past_fit + ) + + return cohort_5yr_fert_future, cohort_5yr_fert_past_fit + + +def _make_5yr_cohort_fertility_from_ccfx(ccfx: xr.DataArray) -> xr.DataArray: + """Compute cohort fertilility of 5-year age grouops from ccfx (single years). + + Args: + ccfx (xr.DataArray): past ccfx, by single year ages. + + Returns: + (xr.DataArray): cohort fertility by 5-year age groups. + Some cells will be NaN because the later cohorts have not + completed their fertile life yet. 
For example, the 2015 + cohort is only 18 years old by 2023, and hence won't even + have the 15-20 age group fertility yet. + """ + # because we want asfr of 5-year age groups, we need to take the diff + # of ccfx between end-ages. The starting point is the 15-19 age group, + # and we assume that asfr is 0 before 15. + age_group_end_ages = range(COHORT_AGE_START + 4, COHORT_AGE_END + 1, 5) + + # because ccf is cumulative fertility, we use .diff here to get the 5-year + # cohort fertility values we want to fit/predict. + fert_1st_age_group = ccfx.sel(age=age_group_end_ages[0]) + fert_later_age_groups = ccfx.sel(age=age_group_end_ages).diff(DimensionConstants.AGE) + + # This makes the 5-year cohort fertility whose fractions to ccf we fit/predict + cohort_5yr_fert = xr.concat( + [fert_1st_age_group, fert_later_age_groups], dim=DimensionConstants.AGE + ) # some cells will be NaNs due to cohorts being yet too young. + + return cohort_5yr_fert + + +def interpolate_for_single_year_cohort_asfr( + cohort_5yr_fert_future: xr.DataArray, cohort_5yr_fert_past_fit: xr.DataArray +) -> Tuple[xr.DataArray, xr.DataArray]: + """Interpolate to get single year cohort asfr. + + Args: + cohort_5yr_fert_future (xr.DataArray): future cohort 5-year fertility, + with draws. + cohort_5yr_fert_past_fit (xr.DataArray): past cohort 5-year fertility + mean fit. + + Returns: + (Tuple[xr.DataArray, xr.DataArray]): future single year asfr draws and + past single year asfr mean fit. + """ + interp_ages = range(COHORT_AGE_START, COHORT_AGE_END + 1) + + asfr_single_year = _interp_single_year_rates_from_five_year_age_groups( + cohort_5yr_fert_future, interp_ages + ) + + asfr_single_year_past_fit = _interp_single_year_rates_from_five_year_age_groups( + cohort_5yr_fert_past_fit, interp_ages + ) + + return asfr_single_year, asfr_single_year_past_fit + + +def _interp_single_year_rates_from_five_year_age_groups( + da_5yr: xr.DataArray, interp_ages: Iterable[int] +) -> xr.DataArray: + """Linearly interpolate 1-year rates from 5-year rates. + + Uses spline-linear interpolation. + + Args: + da_5yr (xr.DataArray): five-year rates. + interp_ages (Iterable[int]): single years to interpolate. + + Returns: + (xr.DataArray): single year rates. + """ + # first compute mean single year asfr + da_single_year = da_5yr / FIVE_YEAR_AGE_GROUP_SIZE # hence "mean" + # now reassign these mean values to some mid-age-group points + da_single_year = da_single_year.assign_coords( + age=(da_single_year[DimensionConstants.AGE] + 1) - (FIVE_YEAR_AGE_GROUP_SIZE / 2.0) + ) + + da_single_year = expand_dimensions( + da_single_year, age=[COHORT_AGE_START - 1, COHORT_AGE_END + 1], fill_value=0 + ) + + da_single_year = da_single_year.interp(age=interp_ages, method="slinear") + + # renormalize so sum over age is equal before/after interpolation. + da_single_year = da_single_year * ( + da_5yr.sum(DimensionConstants.AGE) / da_single_year.sum(DimensionConstants.AGE) + ) + + return da_single_year + + +def arima_single_year_asfr( + ccfx: xr.DataArray, + asfr_past_mean_fit: xr.DataArray, + asfr_future: xr.DataArray, + years: YearRange, + gbd_round_id: int, + versions: Versions, + ages_df: pd.DataFrame, +) -> xr.DataArray: + """Apply ARIMA to single year period-space asfr. + + Fit ARIMA 100 between observed past and predicted past, in period space, + and then add the forecast to the period space future. + + Starting with cohort space inputs (ccfx, asfr_past_mean_fit, and asfr_future): + 1.) Make period-space past/future asfrs using the cohort inputs. 
ccfx contains + all observed past period-space data, but asfr_past_mean_fit is missing a + triangle of past predicted values. Use asfr_future to fill in said missing + values. + 2.) Now in period space, for every location-age, fit past/forecast residual. + 3.) Add forecasted residual to forecast. + 4.) Infer terminal (10-15, 50-54) single year asfrs. + 5.) Ordered-draw intercept-shift with the past. + + Args: + ccfx_past (xr.DataArray): past single-year ccfx, has draws. + asfr_past_mean_fit (xr.DataArray): fitted past single-year cohort asfr. + asfr_future (xr.DataArray): future single-year cohort asfr, with draws. + years (YearRange): past_start:forecast_start:forecast_end. + gbd_round_id (int): gbd round id. + versions (Versions): contains all relevant input/output versions. + ages_df (pd.DataFrame): age metadata. + + Returns: + (xr.DataArray): post-arima future single-year asfr. + """ + df_past, asfr_past_mean_fit = _prep_past_cohort_asfr_da_into_period_asfr_df( + ccfx, asfr_past_mean_fit, asfr_future + ) + + # Now prep for the future draws. + df_future, da_predicted_last_past_year = _prep_future_cohort_asfr_da_into_period_asfr_df( + asfr_past_mean_fit, asfr_future, years + ) + + # now ready to make arima call for every age year + age_da_list = [] + + for age in df_future[DimensionConstants.AGE].unique(): + age_df_past = df_past.query(f"{DimensionConstants.AGE} == {age}") + arima_past_years = age_df_past[DimensionConstants.YEAR_ID].unique().tolist() + # age_df_past might have some years into the future already (stage 1) + age_df_future = df_future.query( + f"{DimensionConstants.AGE} == {age} & " + f"{DimensionConstants.YEAR_ID} not in {arima_past_years}" + ) + + # due to cohort-to-period space conversion, every age has a different + # forecast start year, so we make a special YearRange object for it + arima_years = YearRange( + years.past_start, max(arima_past_years) + 1, years.forecast_end + ) + # this function returns past and future together, both arimaed + past_and_future_da = stage_2.residual_arima_by_locations( + age_df_past, age_df_future, arima_years, arima_attenuation_end_year=None + ) + + # take the arimaed future years + future_da = past_and_future_da.sel(year_id=years.forecast_years) + + # need the "predicted" last past year, post arima and all that. + # as the anchor for intercept-shift. + # We'll compute the arima residual added to the first future year, + # and add it to the pre-arima predicted last past year, as the anchor. + pre_arima_first_future_year = ( + df_future.query( + f"{DimensionConstants.AGE} == {age} & " + f"{DimensionConstants.YEAR_ID} == {years.forecast_start}" + ) + .set_index(DimensionConstants.SINGLE_YEAR_ASFR_INDEX_DIMS)["predicted"] + .to_xarray() + ) + + resid_arima_first_future_year = past_and_future_da.sel( + year_id=years.forecast_start, drop=True + ) - pre_arima_first_future_year.sel(age=age, year_id=years.forecast_start, drop=True) + + # now make a future_da that has an arima-residual-added last past year + future_da = xr.concat( + [ + da_predicted_last_past_year.sel(age=age, drop=True) + + resid_arima_first_future_year, + future_da, + ], + dim=DimensionConstants.YEAR_ID, + ).sortby(DimensionConstants.YEAR_ID) + + future_da[DimensionConstants.AGE] = age + + age_da_list.append(future_da) + + # need to expit back to normal space. 
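+    # (the residual ARIMA was fit and added in logit space, so expit maps the
+    # concatenated age-specific results back to rate space)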
+ asfr_future = expit( + xr.concat(age_da_list, dim=DimensionConstants.AGE).sortby(DimensionConstants.AGE) + ) # this overwrites the input variable + + # now infer and append the terminal ages (10-14, 50-54) + asfr_future, asfr_past = infer_and_append_terminal_ages( + gbd_round_id, versions, years.past_end, asfr_future + ) + + # we also need to intercept-shift to the past + ishift_years = YearRange( + years.past_start, years.forecast_start, years.forecast_end - COHORT_AGE_END + ) + + asfr_future = ordered_draw_intercept_shift(asfr_future, asfr_past, ishift_years) + + # this is a safeguard that sets all negatives to 0 + asfr_future = asfr_future.where(asfr_future >= 0).fillna(0) + + asfr_future = asfr_future.sel(year_id=years.forecast_years) + + return asfr_future + + +def _prep_past_cohort_asfr_da_into_period_asfr_df( + ccfx: xr.DataArray, asfr_past_mean_fit: xr.DataArray, asfr_future: xr.DataArray +) -> Tuple[pd.DataFrame, xr.DataArray]: + """Prepare past period space observed/predicted mean values in pd.DataFrame. + + Start with cohort space true/predicted values in past, make period-space + dataframe. Need to use some forecasted values to fill in for the past + predicted values. + + Args: + ccfx (xr.DataArray): True past cohort space single year fertility. + Contains all past values needed. Has draws. + asfr_past_mean_fit (xr.DataArray): mean past fit of cohort-space + single year fertility. No draw dimension. For young cohorts, + some past values are missing and need to be filled in with + forecasted values, which do have draws. + asfr_future (xr.DataArray): forecasted cohort-space single year + fertility. When converted to period space, contains some + past values that will be transfered to asfr_past_mean_fit. + + Returns: + (Tuple[pd.DataFrame, xr.DataArray]): dataframe of past mean + observed/predicted values (logit) in period space. + Also return xarray of the past mean fit values in cohort space. + """ + # we call the past draws from stage 1 "true past" + asfr_true_past = xr.concat( + [ccfx.sel(age=COHORT_AGE_START), ccfx.diff(DimensionConstants.AGE)], + dim=DimensionConstants.AGE, + ) # remember this is in cohort space, and still has many nans + + asfr_true_past.name = "observed" # because these values come from GBD + asfr_past_mean_fit.name = "predicted" # because we fitted these earlier + + # need to fill in for the past mean fit to match the true past in year id. + first_missing_cohort = int(asfr_past_mean_fit[DimensionConstants.COHORT_ID].max()) + 1 + last_missing_cohort = int(asfr_true_past[DimensionConstants.COHORT_ID].max()) + + # use these forecasted cohort values to fill in for past predictions + asfr_past_draw_fit = asfr_future.sel( + location_id=asfr_past_mean_fit[DimensionConstants.LOCATION_ID], + cohort_id=range(first_missing_cohort, last_missing_cohort + 1), + ).mean(DimensionConstants.DRAW) + + # this fills the predicted past vlaues + asfr_past_mean_fit = asfr_past_mean_fit.combine_first(asfr_past_draw_fit) + + # we arima fit the past residual: past mean - past prediction + asfr_true_past_mean = asfr_true_past.mean(DimensionConstants.DRAW) + + # now we prep the past mean & fit for residual_arima_by_locations(). + # merge into a dataset is a way to align the coordinates. + ds_past = xr.merge([asfr_true_past_mean, asfr_past_mean_fit], join="inner") + df_past = ds_past.to_dataframe().reset_index().dropna() # don't need NaNs + + # Now start converting to period space. 
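+    # period year = cohort birth year + age; e.g. the 1980 birth cohort at age
+    # 25 contributes to calendar year 2005.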
+ df_past[DimensionConstants.YEAR_ID] = ( + df_past[DimensionConstants.COHORT_ID] + df_past[DimensionConstants.AGE] + ) + + # some last bits of pruning before arima + df_past = df_past.drop(DimensionConstants.COHORT_ID, axis=1) # need no more + df_past = df_past.query(f"{DimensionConstants.YEAR_ID} >= {PAST_MET_NEED_YEAR_START}") + + df_past["observed"] = logit(df_past["observed"]) # fit arima in logit + df_past["predicted"] = logit(df_past["predicted"]) + + return df_past, asfr_past_mean_fit # df_past (period), asfr_past_mean_fit (cohort) + + +def _prep_future_cohort_asfr_da_into_period_asfr_df( + asfr_past_mean_fit: xr.DataArray, asfr_future: xr.DataArray, years: YearRange +) -> Tuple[pd.DataFrame, xr.DataArray]: + """Make future period-space ASFR from cohort space ASFRs. + + Need "predicted" last past year draws as anchor for intercept-shifting. + asfr_future (in cohort space) already has most of it, except for age 49, + which is part of last complete cohort and hence has only mean fit. + So we fill in asfr_future with asfr_past_mean_fit (repeat age 49 draws), + and will later prune the year_ids of both past and future to be consistent. + + Args: + asfr_past_mean_fit (xr.DataArray): mean past fit of cohort-space + single year fertility. No draw dimension. For young cohorts, + some past values are missing and need to be filled in with + forecasted values, which do have draws. + asfr_future (xr.DataArray): forecasted cohort-space single year + fertility. When converted to period space, contains some + past values that will be transfered to asfr_past_mean_fit. + years (YearRange): past_start:forecast_start:forecast_end + + Returns: + (pd.DataFrame): dataframe of forecasted values (logit) in period space. + Also return xarray of last past year's values (logit) for + downstream intercept-shift. + + """ + asfr_future = asfr_future.combine_first(asfr_past_mean_fit) + + # also need to transform the future data a bit before arima call + asfr_future = logit(asfr_future) # arima in logit space + asfr_future.name = "predicted" + df_future = asfr_future.to_dataframe().reset_index() + df_future[DimensionConstants.YEAR_ID] = ( + df_future[DimensionConstants.COHORT_ID] + df_future[DimensionConstants.AGE] + ) + df_future = df_future.drop("cohort_id", axis=1) + + # need this object as an anchor for intercept-shift + df_predicted_last_past_year = df_future.query( + f"{DimensionConstants.YEAR_ID} == {years.past_end}" + ) + + da_predicted_last_past_year = df_predicted_last_past_year.set_index( + DimensionConstants.SINGLE_YEAR_ASFR_INDEX_DIMS + )[ + "predicted" + ].to_xarray() # later will be last past year's intercept-shift anchor + + df_future = df_future.query( + f"{DimensionConstants.YEAR_ID} >= {years.forecast_start} & " + f"{DimensionConstants.YEAR_ID} <= {years.forecast_end}" + ) + + return df_future, da_predicted_last_past_year + + +def infer_and_append_terminal_ages( + gbd_round_id: int, versions: Versions, last_past_year: int, asfr_future: xr.DataArray +) -> xr.DataArray: + """Append terminal ages to forecasted single-year ASFR. + + Infer terminal ages ASFR based on last past year's ratio. + + The terminal ages are defined by YOUNG_TERMINAL_AGES and OLD_TERMINAL_AGES. + + Args: + gbd_round_id (int): gbd round id. + versions (Versions): used to pull past single-year asfr version. + last_past_year (int): used to filter for last past year. + asfr_future (xr.DataArray): forecasted single-year asfr, without + terminal ages. 
+ + Returns: + (Tuple[xr.DataArray]): terminal-age-inferred future asfr, and the + observed past asfr. + """ + LOGGER.info("Inferring terminl ages...") + path = versions.data_dir(gbd_round_id, "past", "asfr") + + # need 2 past years here for intercept-shift in arima_single_year_asfr() + asfr_past = ( + open_xr(path / "asfr.nc") + .sel( + location_id=asfr_future[DimensionConstants.LOCATION_ID], + draw=asfr_future[DimensionConstants.DRAW], + year_id=[last_past_year - 1, last_past_year], + ) + .rename({DimensionConstants.AGE_GROUP_ID: DimensionConstants.AGE}) + ) + + # use mean-level ratios for inference + ratios = asfr_past.sel(age=YOUNG_TERMINAL_AGES, year_id=last_past_year).mean( + DimensionConstants.DRAW + ) / asfr_past.sel(age=COHORT_AGE_START, year_id=last_past_year, drop=True).mean( + DimensionConstants.DRAW + ) + + young_asfr_future = ratios * asfr_future.sel(age=COHORT_AGE_START, drop=True) + + ratios = asfr_past.sel(age=OLD_TERMINAL_AGES, year_id=last_past_year).mean( + DimensionConstants.DRAW + ) / asfr_past.sel(age=COHORT_AGE_END, year_id=last_past_year, drop=True).mean( + DimensionConstants.DRAW + ) + + old_asfr_future = ratios * asfr_future.sel(age=COHORT_AGE_END, drop=True) + + LOGGER.info("Done inferring terminal ages...") + + asfr_future = xr.concat( + [young_asfr_future, asfr_future, old_asfr_future], dim=DimensionConstants.AGE + ).sortby(DimensionConstants.AGE) + + return asfr_future, asfr_past + + +# NOTE: this should be centralized. +# ordered-draw intercept-shift should only be done in normal space. +def ordered_draw_intercept_shift( + da_future: xr.DataArray, da_past: xr.DataArray, years: YearRange +) -> xr.DataArray: + """Ordered-draw intercept-shift based on fan-out trajectories. + + Trajectories are determined by the difference between last forecast year + and last past year. + + Assumes there's no scenario dimension. Should only be done in normal space, + never in any non-linear transformation. + + Args: + da_future (xr.DataArray): Conctains last past + all future years. + da_past (xr.DataArray): Contains last past year. Should probably have + multiple past years up to the last past year. + years (YearRange): first past year:first future year:last future year. + + Returns: + (xr.DataArray): ordered-draw intercept-shifted da_future. + """ + # we determine draw sort order by trajectory = last year - last past year + trajectories = da_future.sel(year_id=years.forecast_end) - da_future.sel( + year_id=years.past_end + ) # no more year_id dim + + # these are coords after removing year and draw dims + non_draw_coords = trajectories.drop_vars(DimensionConstants.DRAW).coords + coords = list(non_draw_coords.indexes.values()) + dims = list(non_draw_coords.indexes.keys()) + + for coord in it.product(*coords): + slice_dict = {dims[i]: coord[i] for i in range(len(coord))} + + trajs = trajectories.sel(**slice_dict) # should have only draw dim now + # using argsort once gives the indices that sort the list, + # using it twice gives the rank of each value, from low to high + traj_rank = trajs.argsort().argsort().values + # in case draw labels don't start at 0, we obtain labels + # traj_rank_labels = trajectories[DimensionConstants.DRAW].\ + # values[traj_rank.values.tolist()] # np array of draw labels + + # from now on we will do some calculations in "rank space". + # ranked_future has year_id/draw dims. 
+ ranked_future = da_future.sel(**slice_dict).assign_coords(draw=traj_rank) + # each draw label corresponds to the draw value's rank now + predicted_last_past_rank = ranked_future.sel( + year_id=years.past_end + ) # now only has draws + + # our goal is to allocate the highest trajectory rank to the highest + # observed last past rank, and lowest to lowest, etc. + # so we need to bring the last observed year into rank space as well. + observed_last_past = da_past.sel(**slice_dict, year_id=years.past_end) # draws only + past_draw_labels = observed_last_past[DimensionConstants.DRAW].values + past_rank = observed_last_past.argsort().argsort().values + observed_last_past_rank = observed_last_past.assign_coords(draw=past_rank) + + # Important: diff inherits the rank labels from observed_last_past_rank + diff = observed_last_past_rank - predicted_last_past_rank + + if not ( + diff[DimensionConstants.DRAW] == observed_last_past_rank[DimensionConstants.DRAW] + ).all(): + raise ValueError("diff must inherit rank values from " "observed_last_past_rank") + + # diff added to the future draws in rank space. + # Important: inherits rank labels from diff (~observed_last_past_rank). + ranked_future = diff + ranked_future # has year_id / draw (rank) dims + + if not (ranked_future[DimensionConstants.DRAW] == diff[DimensionConstants.DRAW]).all(): + raise ValueError("ranked_future must inherit rank values from " "diff") + + # ranked_future is now labeled by the rank labels of past draws. + # it is then straight-forward to map to draw labels: + ranked_future = ranked_future.assign_coords(draw=past_draw_labels) + + # prep the shape of ranked_future before inserting into da_future + dim_order = da_future.sel(**slice_dict).dims + ranked_future = ranked_future.transpose(*dim_order) + + # modify in-place + da_future.loc[slice_dict] = ranked_future # ~ "re_aligned_future" + + return da_future + + +def _age_group_specific_fit_and_predict( + cohort_5yr_fert: xr.DataArray, + ccf: xr.DataArray, + education: xr.DataArray, + met_need: xr.DataArray, + last_past_cohort: int, + locations_df: pd.DataFrame, +) -> xr.DataArray: + """Fit age-group-specific logit ccf50 fractions using education/met_need. + + Education and met_need used as fixed effects and region_id as random intercept. + + For every 5-year cohort age group, fit + logit(mean age group fertility / mean ccf50) over mean education and mean + met_need using LME with region_id random intercept over country locations. + This is followed by prediction over draws and subnationals. + + Contains a straight-up intercept-shift in the end. + + Args: + cohort_5yr_fert (xr.DataArray): 5 year fertility in cohort space. + Has cohort_id, location_id, age, and draw dims. + ccf (xr.DataArray): ccf50. Has cohort_id, location_id, and draw dims. + education (xr.DataArray): same dims as cohort_asfr. + met_need (xr.DataArray): same dims as cohort_asfr. + last_past_cohort (int): last completed past cohort. + locations_df (pd.DataFrame): locations metadata. The region ids + within will filter for the subset locations desired for analysis. + + Returns: + (xr.DataArray): predicted logit asfr fractions. 
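+
+    Note:
+        Per 5-year age group, the fitted model is effectively
+        ``logit(age_group_fert / ccf50) ~ education + met_need`` with a random
+        intercept on ``region_id`` (see the ``smf.mixedlm`` call below).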
+ """ + # we will make a mean logit( asfr / ccf) and a draw logit( asfr / ccf) + # these objects have onyl past values because cohort_asfr only has past + logit_asfr_fractions = logit(cohort_5yr_fert / ccf) # for prediction + logit_asfr_fractions.name = "logit_5yr_fert_over_ccf" + + logit_asfr_fractions_mean = logit( + cohort_5yr_fert.mean(DimensionConstants.DRAW) / ccf.mean(DimensionConstants.DRAW) + ) # for fitting + logit_asfr_fractions_mean.name = logit_asfr_fractions.name + + # Convert to dataframes to work with statsmodels package + ds_mean = xr.merge( + [ + logit_asfr_fractions_mean, + education.mean(DimensionConstants.DRAW), + met_need.mean(DimensionConstants.DRAW), + ] + ) + + # the .dropna() call here will remove all future years from df_mean + df_mean = ds_mean.to_dataframe().dropna().reset_index() # don't fit on NaN + + # the df for prediction only needs to contain the future cohorts + ds = xr.merge([logit_asfr_fractions, education, met_need]) + # need the last_past_cohort year here for later intercept-shift + ds = ds.sel( + cohort_id=range(last_past_cohort, ccf[DimensionConstants.COHORT_ID].values.max() + 1) + ) + df = ds.to_dataframe().reset_index() # has future rows where Y = NaN + + del ds, ds_mean # no longer needed + gc.collect() + + # merge with locations_df to assign region_id to locations + df = df.merge( + locations_df[[DimensionConstants.LOCATION_ID, DimensionConstants.REGION_ID]], + how="inner", + on=DimensionConstants.LOCATION_ID, + ) # filters for locations + + df_mean = df_mean.merge( + locations_df[[DimensionConstants.LOCATION_ID, DimensionConstants.REGION_ID, "level"]], + how="inner", + on=DimensionConstants.LOCATION_ID, + ) + # df_mean_nats is for fitting and contains only national means + df_mean_nats = df_mean.query("level == 3") + + age_df_mean_list = [] # we're fitting the past for later arima + age_df_list = [] # necessary prep before concats + + for age in logit_asfr_fractions[DimensionConstants.AGE].values: + # the mean df is used only to fit + age_df_mean_nats = df_mean_nats.query(f"{DimensionConstants.AGE} == {age}") + # statsmodel mixedlm api needs "Y ~ X_1 + X_2" to specify fixed effects + md = smf.mixedlm( + f"{logit_asfr_fractions.name} ~ " f"{education.name} + {met_need.name}", + data=age_df_mean_nats, + groups=age_df_mean_nats[DimensionConstants.REGION_ID], + ) + + mdf = md.fit() + + # the random intercepts are stored in a weird dict that needs pruning + re_dict = dict([(k, v.values[0]) for k, v in mdf.random_effects.items()]) + + # no build-in to predict fixed + random effects. Predict separately. 
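+        # mdf.predict() returns only the fixed-effect (education + met_need)
+        # part; the region-level random intercept is looked up in re_dict and
+        # added to it in the expressions below.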
+ age_df = df.query(f"{DimensionConstants.AGE} == {age}") + age_df[logit_asfr_fractions.name] = mdf.predict( + exog=age_df[StageConstants.STAGE_3_COVARIATES] + ) + age_df[DimensionConstants.REGION_ID].map( + re_dict + ) # fe + re + + age_df_list.append(age_df) + + # now fit the past mean because we need this fit for later arima + age_df_mean = df_mean.query(f"{DimensionConstants.AGE} == {age}") + age_df_mean[logit_asfr_fractions.name] = mdf.predict( + exog=age_df_mean[StageConstants.STAGE_3_COVARIATES] + ) + age_df_mean[DimensionConstants.REGION_ID].map(re_dict) + + age_df_mean_list.append(age_df_mean) + + del df, df_mean + gc.collect() + + prediction_df = pd.concat(age_df_list, axis=0) + prediction_da = prediction_df.set_index(list(logit_asfr_fractions.dims)).to_xarray()[ + logit_asfr_fractions.name + ] + + past_prediction_df = pd.concat(age_df_mean_list, axis=0) + past_prediction_da = past_prediction_df.set_index( + list(logit_asfr_fractions_mean.dims) + ).to_xarray()[logit_asfr_fractions_mean.name] + + # now the intercept-shift for both + prediction_da = ( + prediction_da + + logit_asfr_fractions.sel(cohort_id=last_past_cohort, drop=True) + - prediction_da.sel(cohort_id=last_past_cohort, drop=True) + ) + # don't need last past cohort in future data no more + last_future_cohort = int(prediction_da[DimensionConstants.COHORT_ID].max()) + prediction_da = prediction_da.sel( + cohort_id=range(last_past_cohort + 1, last_future_cohort + 1) + ) + + past_prediction_da = ( + past_prediction_da + + logit_asfr_fractions_mean.sel(cohort_id=last_past_cohort, drop=True) + - past_prediction_da.sel(cohort_id=last_past_cohort, drop=True) + ) + + return prediction_da, past_prediction_da + + +def _convert_incremental_covariates_to_cohort_space( + da: xr.DataArray, ages_df: pd.DataFrame +) -> xr.DataArray: + """Convert covariate data from period to cohort space. + + Naive conversion (via reindexing) of covariates for incremental fertilty + modeling from period space to cohort space by using the start year of the + age group interval to define the cohort birth year. Differs from the + function used in ccf modeling because it is also indexed by age. + + Args: + da (xr.DataArray): period space covariates to reindex. + ages_df (pd.DataFrame): age metadata. + + Returns: + (xr.DataArray): covariate data reindexed in cohort space, with + age_group_id converted to age, the end age of the interval. 
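+
+    Example:
+        With standard GBD age metadata, the 25-29 age group
+        (``age_group_years_start`` = 25, ``age_group_years_end`` = 30) observed
+        in period year 2000 is reindexed to ``cohort_id`` = 1975 and
+        ``age`` = 29.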
+ """ + if not np.isin( + da[DimensionConstants.AGE_GROUP_ID].values, + ages_df[DimensionConstants.AGE_GROUP_ID].values, + ).all(): + raise ValueError("Missing age data") + + das = [] + for age_group_id in da[DimensionConstants.AGE_GROUP_ID].values.tolist(): + da_age = da.sel(age_group_id=age_group_id, drop=True) + ages = ages_df.query(f"age_group_id == {age_group_id}") + da_age[DimensionConstants.YEAR_ID] = da_age[DimensionConstants.YEAR_ID] - int( + ages["age_group_years_start"] + ) + da_age = da_age.rename({DimensionConstants.YEAR_ID: DimensionConstants.COHORT_ID}) + # GBD age_group_years_end is excessive by 1 + da_age[DimensionConstants.AGE] = int(ages["age_group_years_end"]) - 1 + das.append(da_age) + + return xr.concat(das, dim=DimensionConstants.AGE) diff --git a/gbd_2021/fertility_forecast_code/met_need/arc_forecast.py b/gbd_2021/fertility_forecast_code/met_need/arc_forecast.py new file mode 100644 index 0000000..462acd6 --- /dev/null +++ b/gbd_2021/fertility_forecast_code/met_need/arc_forecast.py @@ -0,0 +1,190 @@ +from fhs_lib_year_range_manager.lib.year_range import YearRange +from fhs_lib_file_interface.lib.xarray_wrapper import open_xr, save_xr +from fhs_lib_model.lib.arc_method import arc_method +from fhs_lib_file_interface.lib.file_interface import FBDPath +from fhs_lib_data_transformation.lib.constants import DimensionConstants +from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions +from fhs_lib_database_interface.lib.constants import AgeConstants, SexConstants +from fhs_lib_genem.lib.predictive_validity import root_mean_square_error +from fhs_lib_database_interface.lib.query.location import get_location_set +from fhs_lib_genem.lib.constants import TransformConstants +from fhs_lib_data_transformation.lib.resample import resample +from scipy.special import expit, logit + + + +import gc +from typing import Dict, Iterable, Optional, Tuple + +import numpy as np +import xarray as xr + +stage = "met_need" +entity = "met_need" +input_version = "20230503_met_need_with_draws" +out_version = "20230510_met_need_zero_biased_omega" +transform = "logit" +truncate = True +truncate_quantiles = (0.15, 0.85) +reference_scenario = "mean" +gbd_round_id = 7 +years = YearRange(1970,2023,2150) +pv_years = YearRange(1970, 2013, 2022) +min_omega = 0 +max_omega = 3 +draws = 1000 +omega_step_size = 0.5 +replace_with_mean = False +national_only = False +uncertainty=True + +cap_percentile = 0.95 + + +def _forecast_urbanicity( + omega: float, + past: xr.DataArray, + transform: str, + truncate: bool, + truncate_quantiles: Tuple[float, float], + replace_with_mean: bool, + reference_scenario: str, + years: YearRange, + gbd_round_id: int, + uncertainty: bool, + national_only, + extra_dim=None): + modeled_location_ids = get_location_set( + gbd_round_id, national_only=national_only + ).location_id.values + most_detailed_past = past.sel(location_id = modeled_location_ids) + original_coords = dict(most_detailed_past.coords) + original_coords.pop("draw", None) + zeros_dropped = most_detailed_past + if uncertainty: + most_detailed_past = most_detailed_past.sel(year_id=years.past_end) + else: + most_detailed_past=None + gc.collect() + if "draw" in zeros_dropped.dims: + past_mean = zeros_dropped.mean("draw") + else: + past_mean = zeros_dropped.copy() + transformed_past = logit(past_mean) + transformed_forecast = arc_method.arc_method( + past_data_da = transformed_past, + gbd_round_id = gbd_round_id, + years = years, + truncate = truncate, + reference_scenario=reference_scenario, + 
+        weight_exp=omega,
+        replace_with_mean=replace_with_mean,
+        truncate_quantiles=truncate_quantiles,
+        scenario_roc="national",
+        extra_dim=extra_dim)
+    scaled_forecast = expit(transformed_forecast)
+    zeros_appended = expand_dimensions(scaled_forecast, fill_value=0.0, **original_coords)
+    return zeros_appended
+
+
+def omega_strategy(rmse, draws, threshold=0.05, **kwargs):
+    norm_rmse = rmse / rmse.min()
+    weight_with_lowest_rmse = norm_rmse.where(
+        norm_rmse == norm_rmse.min()).dropna("weight")["weight"].values[0]
+    weights_to_check = [
+        w for w in norm_rmse["weight"].values if w <= weight_with_lowest_rmse]
+    rmses_to_check = norm_rmse.sel(weight=weights_to_check)
+    rmses_to_check_within_threshold = rmses_to_check.where(
+        rmses_to_check < 1 + threshold).dropna("weight")
+    reciprocal_rmses_to_check_within_threshold = (
+        1 / rmses_to_check_within_threshold).fillna(0)
+    norm_reciprocal_rmses_to_check_within_threshold = (
+        reciprocal_rmses_to_check_within_threshold
+        / reciprocal_rmses_to_check_within_threshold.sum())
+    omega_draws = xr.DataArray(
+        np.random.choice(
+            a=norm_reciprocal_rmses_to_check_within_threshold["weight"].values,
+            size=draws,
+            p=norm_reciprocal_rmses_to_check_within_threshold.sel(
+                pv_metric="root_mean_square_error").values),
+        coords=[list(range(draws))], dims=["draw"])
+    return omega_draws
+
+
+input_dir = f"FILEPATH"
+
+data = open_xr(f"FILEPATH.nc").data
+
+if "draw" in data.dims:
+    data = data.mean("draw")
+
+
+superfluous_coords = [d for d in data.coords.keys() if d not in data.dims]
+data = data.drop(superfluous_coords)
+
+holdouts = data.sel(year_id=pv_years.forecast_years)
+past = data.sel(year_id=pv_years.past_years)
+past = past.clip(max=0.999)
+
+all_omega_pv_results = []
+pv_metrics = [root_mean_square_error]
+pv_dims = [DimensionConstants.LOCATION_ID, DimensionConstants.SEX_ID]
+
+for omega in np.arange(min_omega, max_omega, omega_step_size):
+    all_pv_metrics_results = []
+    for pv_metric in pv_metrics:
+        predicted_holdouts = _forecast_urbanicity(
+            omega, past, transform, truncate,
+            truncate_quantiles, replace_with_mean,
+            reference_scenario, pv_years,
+            gbd_round_id, uncertainty, national_only)
+        pv_data = pv_metric(predicted_holdouts.sel(scenario=0), holdouts)
+        one_pv_metric_result = xr.DataArray(
+            [[pv_data.values]], [[omega], [pv_metric.__name__]],
+            dims=["weight", "pv_metric"])
+        all_pv_metrics_results.append(one_pv_metric_result)
+    all_pv_metrics_results = xr.concat(
+        all_pv_metrics_results, dim="pv_metric")
+    all_omega_pv_results.append(all_pv_metrics_results)
+
+
+all_omega_pv_results = xr.concat(all_omega_pv_results, dim="weight")
+
+pv_dir = FBDPath(f"FILEPATH")
+pv_dir.mkdir(parents=True, exist_ok=True)
+pv_file = pv_dir / f"FILEPATH.nc"
+all_omega_pv_results.to_netcdf(str(pv_file))
+
+
+input_dir = FBDPath(f"FILEPATH")
+past = open_xr(input_dir / f"FILEPATH.nc").data
+past = past.clip(max=0.999)
+
+pv_path = FBDPath(f"FILEPATH")
+pv_file = pv_path / f"FILEPATH.nc"
+all_pv_data = open_xr(pv_file).data
+omega = omega_strategy(all_pv_data, draws)
+
+# Dataset has only one age and sex, so ignore those dimensions.
+# We want the 0.95 quantile of mean draws over locations (the 0.95 quantile of a 346-value array).
+forecast = _forecast_urbanicity(omega, past, transform, truncate,
+                                truncate_quantiles, replace_with_mean,
+                                reference_scenario, years,
+                                gbd_round_id, uncertainty, national_only,
+                                extra_dim="draw")
+
+if isinstance(omega, xr.DataArray):
+    report_omega = float(omega.mean())
+else:
+    report_omega = omega
+
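+
+# ---------------------------------------------------------------------------
+# Illustrative sketch only -- this helper is never called by the pipeline.
+# It shows, on made-up RMSE values, how omega_strategy() builds a
+# "zero-biased" omega distribution: only weights no larger than the
+# best-scoring weight and within `threshold` of the minimum RMSE are kept,
+# and draws are sampled with probability proportional to the reciprocal of
+# their normalized RMSE.  The function name and toy numbers below are
+# hypothetical, not part of the original analysis.
+# ---------------------------------------------------------------------------
+def _demo_omega_strategy(n_draws: int = 10) -> xr.DataArray:
+    """Run omega_strategy on a toy predictive-validity DataArray."""
+    toy_rmse = xr.DataArray(
+        [[1.02], [1.00], [1.01], [1.20], [1.50], [2.00]],
+        coords={
+            "weight": [0.0, 0.5, 1.0, 1.5, 2.0, 2.5],
+            "pv_metric": ["root_mean_square_error"],
+        },
+        dims=["weight", "pv_metric"],
+    )
+    # The minimum RMSE sits at weight 0.5, so only weights {0.0, 0.5} are
+    # eligible; both are within 5% of the minimum, so draws come from that
+    # pair, slightly favoring 0.5 (the lower RMSE).
+    return omega_strategy(toy_rmse, draws=n_draws)
+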
+
+
+future_output_dir = FBDPath(f"FILEPATH")
+
+save_xr(forecast,
+        future_output_dir / f"FILEPATH.nc", metric="rate",
+        space="identity", omega_strategy="use_zero_biased_omega_distribution",
+        omega=report_omega, pv_metric=pv_metric.__name__)
\ No newline at end of file
diff --git a/gbd_2021/fertility_forecast_code/pop_by_habitable_area/create_log_habitable_area.py b/gbd_2021/fertility_forecast_code/pop_by_habitable_area/create_log_habitable_area.py
new file mode 100644
index 0000000..e02d838
--- /dev/null
+++ b/gbd_2021/fertility_forecast_code/pop_by_habitable_area/create_log_habitable_area.py
@@ -0,0 +1,63 @@
+import numpy as np
+import pandas as pd
+import xarray as xr
+from db_queries import get_location_metadata
+from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions
+from fhs_lib_database_interface.lib.query.location import get_location_set
+from fhs_lib_file_interface.lib.file_interface import FBDPath
+from fhs_lib_file_interface.lib.xarray_wrapper import open_xr, save_xr
+
+gbd_round_pop = 6
+stage = "population"
+locations_df = get_location_set(gbd_round_id=6)
+loc_ids = locations_df[(locations_df["level"] >= 3)].location_id.values
+ihme_locs = get_location_metadata(location_set_id=35, gbd_round_id=6)
+ihme_locs = ihme_locs[['location_id', 'location_name', 'local_id', 'super_region_name']]
+
+####################################
+##### load population      #########
+####################################
+
+# load past population
+past_pop_version = "VERSION"
+past_pop_dir = FBDPath(f"/FILEPATH")
+past_pop_da = open_xr(f"/FILEPATH.nc").data.sel(age_group_id=22, sex_id=3, location_id=list(loc_ids))
+# load future population
+future_pop_version = "VERSION"
+future_pop_dir = FBDPath(f"/FILEPATH")
+future_pop_da = open_xr(f"/FILEPATH.nc").data.sum(dim="age_group_id").sum(dim="sex_id").sel(scenario=0)
+
+####################################
+##### load habitable area  #########
+####################################
+area_hab = pd.read_csv("/FILEPATH.csv")
+area_hab = area_hab.rename(columns={"loc_id": "location_id"})[["location_id", "area_hab"]]
+
+########################################################
+##### Create Pop Density = Pop/Habitable Area  #########
+########################################################
+
+# Create past version
+past_pop_df_merge = past_pop_da.to_dataframe('population').reset_index().merge(area_hab).merge(ihme_locs)
+past_pop_df_merge['pop_density'] = past_pop_df_merge['population'] / past_pop_df_merge['area_hab']
+past_pop_density = xr.DataArray.from_series(past_pop_df_merge.set_index(["location_id", "year_id", "age_group_id", "sex_id"])["pop_density"])
+
+# Create future version
+future_pop_df_merge = future_pop_da.to_dataframe('population').reset_index().merge(area_hab).merge(ihme_locs)
+future_pop_df_merge['pop_density'] = future_pop_df_merge['population'] / future_pop_df_merge['area_hab']
+future_pop_density = xr.DataArray.from_series(future_pop_df_merge.set_index(["location_id", "year_id", "draw"])["pop_density"])
+
+# Save future version with the required years
+new_future_da = expand_dimensions(future_pop_density, year_id=range(2101, 2151), fill_value=future_pop_density.sel(year_id=2100, drop=True))
+da_new = xr.concat([past_pop_density, new_future_da.sel(year_id=range(2020, 2151))], "year_id").sel(year_id=range(2020, 2151))
+future_pop_density_file_path_past = "/FILEPATH/"
+save_xr(da_new, f"{future_pop_density_file_path_past}/urbanicity.nc", metric="number", space="identity")
+
+# Save past version with the required years
+da = xr.concat([past_pop_density, future_pop_density.sel(year_id=range(2020, 2023))], "year_id")
+past_pop_density_file_path_past = "/FILEPATH/"
+save_xr(da, f"{past_pop_density_file_path_past}/urbanicity.nc", metric="number", space="identity")
+
+# Save log versions of both.  NOTE: file_path_past / file_path_future were
+# not defined above; they are assumed to be output directories and are given
+# placeholder paths here, like the other paths in this script.
+file_path_past = "/FILEPATH/"
+file_path_future = "/FILEPATH/"
+log_past_da = np.log(da)
+save_xr(log_past_da, f"{file_path_past}/urbanicity.nc", metric="number", space="identity")
+log_future_da = np.log(da_new)
+save_xr(log_future_da, f"{file_path_future}/urbanicity.nc", metric="number", space="identity")
diff --git a/gbd_2021/fertility_forecast_code/u5m/u5m.py b/gbd_2021/fertility_forecast_code/u5m/u5m.py
new file mode 100644
index 0000000..0f8414b
--- /dev/null
+++ b/gbd_2021/fertility_forecast_code/u5m/u5m.py
@@ -0,0 +1,186 @@
+import gc
+from typing import Dict, Iterable, Optional, Tuple
+
+import numpy as np
+import xarray as xr
+from scipy.special import expit, logit
+
+from fhs_lib_data_transformation.lib.constants import DimensionConstants
+from fhs_lib_data_transformation.lib.dimension_transformation import expand_dimensions
+from fhs_lib_data_transformation.lib.resample import resample
+from fhs_lib_database_interface.lib.constants import AgeConstants, SexConstants
+from fhs_lib_database_interface.lib.query.location import get_location_set
+from fhs_lib_file_interface.lib.file_interface import FBDPath
+from fhs_lib_file_interface.lib.xarray_wrapper import open_xr, save_xr
+from fhs_lib_genem.lib.constants import TransformConstants
+from fhs_lib_genem.lib.predictive_validity import root_mean_square_error
+from fhs_lib_model.lib.arc_method import arc_method
+from fhs_lib_year_range_manager.lib.year_range import YearRange
+
+stage = "u5m"
+entity = "u5m"
+input_version = "20230508_u5m_future_as_fake_past"
+out_version = "20230515_u5m_to_2150"
+transform = "logit"
+truncate = True
+truncate_quantiles = (0.15, 0.85)
+reference_scenario = "mean"
+gbd_round_id = 7
+pv_years = YearRange(2020, 2030, 2050)
+years = YearRange(2020, 2051, 2150)
+min_omega = 0
+max_omega = 3
+draws = 1000
+omega_step_size = 0.5
+replace_with_mean = False
+national_only = False
+uncertainty = True
+
+cap_percentile = 0.95
+
+
+# NOTE: the helper name is a holdover; in this script it forecasts the u5m stage.
+def _forecast_urbanicity(
+    omega: float,
+    past: xr.DataArray,
+    transform: str,
+    truncate: bool,
+    truncate_quantiles: Tuple[float, float],
+    replace_with_mean: bool,
+    reference_scenario: str,
+    years: YearRange,
+    gbd_round_id: int,
+    uncertainty: bool,
+    national_only: bool,
+    extra_dim: Optional[str] = None,
+) -> xr.DataArray:
+    modeled_location_ids = get_location_set(
+        gbd_round_id, national_only=national_only
+    ).location_id.values
+    most_detailed_past = past.sel(location_id=modeled_location_ids)
+    original_coords = dict(most_detailed_past.coords)
+    original_coords.pop("draw", None)
+    zeros_dropped = most_detailed_past
+    if uncertainty:
+        most_detailed_past = most_detailed_past.sel(year_id=years.past_end)
+    else:
+        most_detailed_past = None
+    gc.collect()
+    if "draw" in zeros_dropped.dims:
+        past_mean = zeros_dropped.mean("draw")
+    else:
+        past_mean = zeros_dropped.copy()
+    transformed_past = logit(past_mean)
+    transformed_forecast = arc_method.arc_method(
+        past_data_da=transformed_past,
+        gbd_round_id=gbd_round_id,
+        years=years,
+        truncate=truncate,
+        reference_scenario=reference_scenario,
+        weight_exp=omega,
+        replace_with_mean=replace_with_mean,
+        truncate_quantiles=truncate_quantiles,
+        scenario_roc="national",
+        extra_dim=extra_dim)
+    scaled_forecast = expit(transformed_forecast)
+    zeros_appended = expand_dimensions(scaled_forecast, fill_value=0.0, **original_coords)
+    return zeros_appended
+
+
+def omega_strategy(rmse, draws, threshold=0.05, **kwargs):
+    norm_rmse = rmse / rmse.min()
+    weight_with_lowest_rmse = norm_rmse.where(
+        norm_rmse == norm_rmse.min()).dropna("weight")["weight"].values[0]
+    weights_to_check = [
+        w for w in norm_rmse["weight"].values if w <= weight_with_lowest_rmse]
+    rmses_to_check = norm_rmse.sel(weight=weights_to_check)
+    rmses_to_check_within_threshold = rmses_to_check.where(
+        rmses_to_check < 1 + threshold).dropna("weight")
+    reciprocal_rmses_to_check_within_threshold = (
+        1 / rmses_to_check_within_threshold).fillna(0)
+    norm_reciprocal_rmses_to_check_within_threshold = (
+        reciprocal_rmses_to_check_within_threshold
+        / reciprocal_rmses_to_check_within_threshold.sum())
+    omega_draws = xr.DataArray(
+        np.random.choice(
+            a=norm_reciprocal_rmses_to_check_within_threshold["weight"].values,
+            size=draws,
+            p=norm_reciprocal_rmses_to_check_within_threshold.sel(
+                pv_metric="root_mean_square_error").values),
+        coords=[list(range(draws))], dims=["draw"])
+    return omega_draws
+
+
+input_dir = f"FILEPATH"
+
+data = open_xr(f"{input_dir}/{entity}.nc").data
+
+if "draw" in data.dims:
+    data = data.mean("draw")
+
+
+superfluous_coords = [d for d in data.coords.keys() if d not in data.dims]
+data = data.drop(superfluous_coords)
+
+holdouts = data.sel(year_id=pv_years.forecast_years)
+past = data.sel(year_id=pv_years.past_years)
+past = past.clip(max=0.999)
+
+all_omega_pv_results = []
+pv_metrics = [root_mean_square_error]
+pv_dims = [DimensionConstants.LOCATION_ID, DimensionConstants.SEX_ID]
+
+for omega in np.arange(min_omega, max_omega, omega_step_size):
+    all_pv_metrics_results = []
+    for pv_metric in pv_metrics:
+        predicted_holdouts = _forecast_urbanicity(
+            omega, past, transform, truncate,
+            truncate_quantiles, replace_with_mean,
+            reference_scenario, pv_years,
+            gbd_round_id, uncertainty, national_only)
+        pv_data = pv_metric(predicted_holdouts.sel(scenario=0), holdouts)
+        one_pv_metric_result = xr.DataArray(
+            [[pv_data.values]], [[omega], [pv_metric.__name__]],
+            dims=["weight", "pv_metric"])
+        all_pv_metrics_results.append(one_pv_metric_result)
+    all_pv_metrics_results = xr.concat(
+        all_pv_metrics_results, dim="pv_metric")
+    all_omega_pv_results.append(all_pv_metrics_results)
+
+
+all_omega_pv_results = xr.concat(all_omega_pv_results, dim="weight")
+
+pv_dir = FBDPath(f"FILEPATH")
+pv_dir.mkdir(parents=True, exist_ok=True)
+pv_file = pv_dir / f"{entity}_pv.nc"
+all_omega_pv_results.to_netcdf(str(pv_file))
+
+
+input_dir = FBDPath(f"FILEPATH")
+past = open_xr(input_dir / f"{entity}.nc").data
+past = past.clip(max=0.999)
+
+pv_path = FBDPath(f"FILEPATH")
+pv_file = pv_path / f"{entity}_pv.nc"
+all_pv_data = open_xr(pv_file).data
+omega = omega_strategy(all_pv_data, draws)
+
+# Dataset has only one age and sex, so ignore those dimensions.
+# We want the 0.95 quantile of mean draws over locations (the 0.95 quantile of a 346-value array).
+forecast = _forecast_urbanicity(omega, past, transform, truncate,
+                                truncate_quantiles, replace_with_mean,
+                                reference_scenario, years,
+                                gbd_round_id, uncertainty, national_only,
+                                extra_dim="draw")
+
+if isinstance(omega, xr.DataArray):
+    report_omega = float(omega.mean())
+else:
+    report_omega = omega
+
+
+future_output_dir = FBDPath(f"FILEPATH")
+
+save_xr(forecast,
+        future_output_dir / f"{entity}.nc", metric="rate",
+        space="identity", omega_strategy="use_zero_biased_omega_distribution",
+        omega=report_omega, pv_metric=pv_metric.__name__)
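+
+
+# ---------------------------------------------------------------------------
+# Illustrative sketch only -- this helper is never called by the pipeline.
+# Both this script and the met_need script share one pattern: cap the rate
+# just below 1, logit-transform it, extrapolate with the ARC method, and
+# expit-transform back so forecasts stay in (0, 1).  The sketch below mimics
+# that flow on a toy series, using a simple unweighted mean annual change in
+# logit space as a stand-in for the full arc_method.arc_method (which uses
+# weighted rates of change, truncation, and scenarios).  The function name
+# and all numbers here are hypothetical.
+# ---------------------------------------------------------------------------
+def _demo_logit_space_arc(n_future_years: int = 5) -> xr.DataArray:
+    """Forecast a toy probability series by a mean-ARC in logit space."""
+    past_year_ids = np.arange(2015, 2021)
+    toy_past = xr.DataArray(
+        np.linspace(0.60, 0.70, past_year_ids.size),
+        coords={"year_id": past_year_ids},
+        dims=["year_id"],
+    ).clip(max=0.999)  # same cap used above to keep the logit finite
+    logit_past = logit(toy_past)
+    # Mean year-over-year change in logit space (stand-in for arc_method).
+    mean_annual_change = logit_past.diff("year_id").mean("year_id")
+    future_year_ids = np.arange(
+        past_year_ids[-1] + 1, past_year_ids[-1] + 1 + n_future_years
+    )
+    steps = xr.DataArray(
+        np.arange(1, n_future_years + 1),
+        coords={"year_id": future_year_ids},
+        dims=["year_id"],
+    )
+    logit_future = (
+        logit_past.sel(year_id=past_year_ids[-1], drop=True)
+        + mean_annual_change * steps
+    )
+    # Back-transform so the forecast stays inside (0, 1), as expit() does above.
+    return expit(logit_future)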