diff --git a/common_querysets/queryset_caring_fish.py b/common_querysets/queryset_caring_fish.py new file mode 100644 index 00000000..65f8b0f7 --- /dev/null +++ b/common_querysets/queryset_caring_fish.py @@ -0,0 +1,41 @@ +from viewser import Queryset, Column + +def generate(): + """ + Contains the configuration for the input data in the form of a viewser queryset. That is the data from viewser that is used to train the model. + This configuration is "behavioral" so modifying it will affect the model's runtime behavior and integration into the deployment system. + There is no guarantee that the model will work if the input data configuration is changed here without changing the model settings and algorithm accordingly. + + Returns: + - queryset_base (Queryset): A queryset containing the base data for the model training. + """ + + # VIEWSER 6, Example configuration. Modify as needed. + + queryset_base = (Queryset("caring_fish", "priogrid_month") + # Create a new column 'ln_sb_best' using data from 'priogrid_month' and 'ged_sb_best_count_nokgi' column + # Apply logarithmic transformation, handle missing values by replacing them with NA + .with_column(Column("ln_sb_best", from_loa="priogrid_month", from_column="ged_sb_best_count_nokgi") + .transform.ops.ln().transform.missing.replace_na()) + + # Create a new column 'ln_ns_best' using data from 'priogrid_month' and 'ged_ns_best_count_nokgi' column + # Apply logarithmic transformation, handle missing values by replacing them with NA + .with_column(Column("ln_ns_best", from_loa="priogrid_month", from_column="ged_ns_best_count_nokgi") + .transform.ops.ln().transform.missing.replace_na()) + + # Create a new column 'ln_os_best' using data from 'priogrid_month' and 'ged_os_best_count_nokgi' column + # Apply logarithmic transformation, handle missing values by replacing them with NA + .with_column(Column("ln_os_best", from_loa="priogrid_month", from_column="ged_os_best_count_nokgi") + .transform.ops.ln().transform.missing.replace_na()) + + # Create columns for month and year_id + .with_column(Column("month", from_loa="month", from_column="month")) + .with_column(Column("year_id", from_loa="country_year", from_column="year_id")) + + # Create columns for country_id, col, and row + .with_column(Column("c_id", from_loa="country_year", from_column="country_id")) + .with_column(Column("col", from_loa="priogrid", from_column="col")) + .with_column(Column("row", from_loa="priogrid", from_column="row")) + ) + + return queryset_base diff --git a/models/caring_fish/README.md b/models/caring_fish/README.md new file mode 100644 index 00000000..bfca3432 --- /dev/null +++ b/models/caring_fish/README.md @@ -0,0 +1,3 @@ +# Model README +## Model name: caring_fish +## Created on: 2024-10-28 16:45:11.931747 \ No newline at end of file diff --git a/models/caring_fish/configs/config_deployment.py b/models/caring_fish/configs/config_deployment.py new file mode 100644 index 00000000..9e45b735 --- /dev/null +++ b/models/caring_fish/configs/config_deployment.py @@ -0,0 +1,20 @@ +""" +Deployment Configuration Script + +This script defines the deployment configuration settings for the application. +It includes the deployment status and any additional settings specified. + +Deployment Status: +- shadow: The deployment is shadowed and not yet active. +- deployed: The deployment is active and in use. +- baseline: The deployment is in a baseline state, for reference or comparison. +- deprecated: The deployment is deprecated and no longer supported. + +Additional settings can be included in the configuration dictionary as needed. + +""" + +def get_deployment_config(): + # Deployment settings + deployment_config = {'deployment_status': 'shadow'} + return deployment_config diff --git a/models/caring_fish/configs/config_hyperparameters.py b/models/caring_fish/configs/config_hyperparameters.py new file mode 100644 index 00000000..8dc75e49 --- /dev/null +++ b/models/caring_fish/configs/config_hyperparameters.py @@ -0,0 +1,14 @@ +def get_hp_config(): + """ + Contains the hyperparameter configurations for model training. + This configuration is "operational" so modifying these settings will impact the model's behavior during the training. + + Returns: + - hyperparameters (dict): A dictionary containing hyperparameters for training the model, which determine the model's behavior during the training phase. + """ + + hyperparameters = { + 'model': 'LightBGM', # The model algorithm used. Eg. "LSTM", "CNN", "Transformer" + # Add more hyperparameters as needed + } + return hyperparameters diff --git a/models/caring_fish/configs/config_meta.py b/models/caring_fish/configs/config_meta.py new file mode 100644 index 00000000..2a2d7ab3 --- /dev/null +++ b/models/caring_fish/configs/config_meta.py @@ -0,0 +1,19 @@ +def get_meta_config(): + """ + Contains the meta data for the model (model algorithm, name, target variable, and level of analysis). + This config is for documentation purposes only, and modifying it will not affect the model, the training, or the evaluation. + + Returns: + - meta_config (dict): A dictionary containing model meta configuration. + """ + + meta_config = { + "name": "caring_fish", # Eg. happy_kitten + "algorithm": "LightBGM", # Eg. "LSTM", "CNN", "Transformer" + # Uncomment and modify the following lines as needed for additional metadata: + # "target(S)": ["ln_sb_best", "ln_ns_best", "ln_os_best", "ln_sb_best_binarized", "ln_ns_best_binarized", "ln_os_best_binarized"], + "queryset": "escwa001_cflong", + # "level": "pgm", + # "creator": "Your name here" + } + return meta_config diff --git a/models/caring_fish/configs/config_sweep.py b/models/caring_fish/configs/config_sweep.py new file mode 100644 index 00000000..f9997c4f --- /dev/null +++ b/models/caring_fish/configs/config_sweep.py @@ -0,0 +1,29 @@ +def get_sweep_config(): + """ + Contains the configuration for hyperparameter sweeps using WandB. + This configuration is "operational" so modifying it will change the search strategy, parameter ranges, and other settings for hyperparameter tuning aimed at optimizing model performance. + + Returns: + - sweep_config (dict): A dictionary containing the configuration for hyperparameter sweeps, defining the methods and parameter ranges used to search for optimal hyperparameters. + """ + + sweep_config = { + 'method': 'grid', + } + + # Example metric setup: + metric = { + 'name': 'MSE', + 'goal': 'minimize' + } + sweep_config['metric'] = metric + + # Example parameters setup: + parameters_dict = { + 'model': { + 'value': 'LightBGM' # Eg. "LSTM", "CNN", "Transformer" + }, + } + sweep_config['parameters'] = parameters_dict + + return sweep_config diff --git a/models/caring_fish/main.py b/models/caring_fish/main.py new file mode 100644 index 00000000..576958b2 --- /dev/null +++ b/models/caring_fish/main.py @@ -0,0 +1,39 @@ +import time +import wandb +import sys +import logging +logging.basicConfig(filename='run.log', encoding='utf-8', level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) +from pathlib import Path +# Set up the path to include common_utils module +PATH = Path(__file__) +sys.path.insert(0, str(Path( + *[i for i in PATH.parts[:PATH.parts.index("views_pipeline") + 1]]) / "common_utils")) # PATH_COMMON_UTILS +# Import necessary functions for project setup and model execution +from set_path import setup_project_paths +setup_project_paths(PATH) +from utils_cli_parser import parse_args, validate_arguments +from execute_model_runs import execute_sweep_run, execute_single_run + +if __name__ == "__main__": + # Parse command-line arguments + args = parse_args() + + # Validate the arguments to ensure they are correct + validate_arguments(args) + # Log in to Weights & Biases (wandb) + wandb.login() + # Record the start time + start_t = time.time() + # Execute the model run based on the sweep flag + if args.sweep: + execute_sweep_run(args) # Execute sweep run + else: + execute_single_run(args) # Execute single run + # Record the end time + end_t = time.time() + + # Calculate and print the runtime in minutes + minutes = (end_t - start_t) / 60 + logger.info(f'Done. Runtime: {minutes:.3f} minutes') diff --git a/models/lavender_haze/requirements.txt b/models/caring_fish/requirements.txt similarity index 100% rename from models/lavender_haze/requirements.txt rename to models/caring_fish/requirements.txt diff --git a/models/lavender_haze/.DS_Store b/models/lavender_haze/.DS_Store deleted file mode 100644 index 7224dcaa..00000000 Binary files a/models/lavender_haze/.DS_Store and /dev/null differ diff --git a/models/lavender_haze/README.md b/models/lavender_haze/README.md deleted file mode 100644 index 7306bc97..00000000 --- a/models/lavender_haze/README.md +++ /dev/null @@ -1,85 +0,0 @@ -# Lavender Haze Model -## Overview -This folder contains code for Lavender Haze model, a machine learning model designed for predicting fatalities. - -The model utilizes Hurdle Model (LGBMClassifier+LGBMRegressor) for its predictions and is on pgm level of analysis. - -The model uses log fatalities. - -## Repository Structure -``` - -lavender_haze/ # should follow the naming convention adjective_noun -|-- README.md -|-- requirements.txt -| -|-- artifacts/ # ensemble stepshifter models -| |-- model_metadata_dict.py # the standard meta data dict for models -| -|-- configs/ # ... -| |-- config_deployment.py # configuration for deploying the model into different environments -| |-- config_hyperparameters.py # hyperparameters for the model -| |-- config_input_data.py # defined queryset as the input data -| |-- config_meta # metadata for the model (model architecture, name, target variable, and level of analysis) -| |-- config_sweep # sweeping parameters for weights & biases -| -|-- data/ # all input, processed, output data -| |-- generated/ # Data generated - i.e. forecast/ evaluation -| |-- processed/ # Data processed -| |-- raw/ # Data directly from VIEiWSER -| -|-- notebooks/ -| -|-- reports/ # dissemination material - internal and external -| |-- figures/ # figures for papers, reports, newsletters, and slides -| |-- papers/ # working papers, white papers, articles ect. -| |-- plots/ # plots for papers, reports, newsletters, and slides -| |-- slides/ # slides, presentation and similar -| |-- timelapse/ # plots to create timelapse and the timelapse -| -|-- src/ # all source code needed to train, test, and forecast - | - |-- dataloaders/ - | |-- get_data.py # script to get data from VIEWSER (and input drift detection) - | - |-- forecasting/ - | |-- generate_forecast.py # script to genereate true-future fc - | - |-- management/ - | |-- execute_model_runs.py # execute a single run - | |-- execute_model_tasks.py # execute various model-related tasks - | - |-- offline_evaluation/ # aka offline quality assurance - | |-- evaluate_model.py # script to evaluate a single model - | |-- evaluate_sweep.py # script to evaluate a model during sweeping - | - |-- online_evaluation/ - | - |-- training/ - | |-- train_model.py # script to train a single model - | - |-- utils/ # functions and classes - | |-- utils.py # a general utils function - | |-- utils_wandb.py # a w&b specific utils function - | - |-- visualization/ # scripts to create visualizations - |-- visual.py - - -``` - -## Setup Instructions -Clone the repository. - -Install dependencies. - -## Usage -Modify configurations in configs/. - -Run main.py. - -``` -python main.py -r calibration -t -e -``` - -Monitor progress and results using [Weights & Biases](https://wandb.ai/views_pipeline/lavender_haze). \ No newline at end of file diff --git a/models/lavender_haze/artifacts/model_metadata_dict.py b/models/lavender_haze/artifacts/model_metadata_dict.py deleted file mode 100644 index e69de29b..00000000 diff --git a/models/lavender_haze/configs/config_deployment.py b/models/lavender_haze/configs/config_deployment.py deleted file mode 100644 index e1d56586..00000000 --- a/models/lavender_haze/configs/config_deployment.py +++ /dev/null @@ -1,16 +0,0 @@ -def get_deployment_config(): - - """ - Contains the configuration for deploying the model into different environments. - This configuration is "behavioral" so modifying it will affect the model's runtime behavior and integration into the deployment system. - - Returns: - - deployment_config (dict): A dictionary containing deployment settings, determining how the model is deployed, including status, endpoints, and resource allocation. - """ - - # More deployment settings can/will be added here - deployment_config = { - "deployment_status": "shadow", # shadow, deployed, baseline, or deprecated - } - - return deployment_config \ No newline at end of file diff --git a/models/lavender_haze/configs/config_hyperparameters.py b/models/lavender_haze/configs/config_hyperparameters.py deleted file mode 100644 index 11675b1f..00000000 --- a/models/lavender_haze/configs/config_hyperparameters.py +++ /dev/null @@ -1,17 +0,0 @@ -def get_hp_config(): - hp_config = { - "steps": [*range(1, 36 + 1, 1)], - "parameters": { - "clf":{ - "learning_rate": 0.05, - "n_estimators": 100, - "n_jobs": 12 - }, - "reg":{ - "learning_rate": 0.05, - "n_estimators": 100, - "n_jobs": 12 - } - } - } - return hp_config \ No newline at end of file diff --git a/models/lavender_haze/configs/config_input_data.py b/models/lavender_haze/configs/config_input_data.py deleted file mode 100644 index 28f790c6..00000000 --- a/models/lavender_haze/configs/config_input_data.py +++ /dev/null @@ -1,167 +0,0 @@ -import numpy as np -from viewser import Queryset, Column - -def get_input_data_config(): - - thetacrit_spatial = 0.7 - return_values = 'distances' - n_nearest = 1 - power = 0.0 - - qs_broad = (Queryset("fatalities003_pgm_broad", "priogrid_month") - - # target variable - .with_column(Column("ln_ged_sb_dep", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") - .transform.missing.replace_na() - .transform.ops.ln() - ) - - # timelags 0 of conflict variables, ged_best versions - - .with_column(Column("ged_sb", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") - .transform.missing.fill() - .transform.missing.replace_na() - ) - - .with_column(Column("ged_os", from_loa="priogrid_month", from_column="ged_os_best_sum_nokgi") - .transform.missing.fill() - .transform.missing.replace_na() - ) - - .with_column(Column("ged_ns", from_loa="priogrid_month", from_column="ged_ns_best_sum_nokgi") - .transform.missing.fill() - .transform.missing.replace_na() - ) - - # Spatial lag - .with_column(Column("splag_1_1_sb_1", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") - .transform.missing.replace_na() - .transform.bool.gte(1) - .transform.temporal.time_since() - .transform.temporal.decay(24) - .transform.spatial.lag(1, 1, 0, 0) - .transform.missing.replace_na() - ) - - # Decay functions - # sb - .with_column(Column("decay_ged_sb_5", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") - .transform.missing.replace_na() - .transform.bool.gte(5) - .transform.temporal.time_since() - .transform.temporal.decay(12) - .transform.missing.replace_na() - ) - # os - .with_column(Column("decay_ged_os_5", from_loa="priogrid_month", from_column="ged_os_best_sum_nokgi") - .transform.missing.replace_na() - .transform.bool.gte(5) - .transform.temporal.time_since() - .transform.temporal.decay(12) - .transform.missing.replace_na() - ) - - # ns - .with_column(Column("decay_ged_ns_5", from_loa="priogrid_month", from_column="ged_ns_best_sum_nokgi") - .transform.missing.replace_na() - .transform.bool.gte(5) - .transform.temporal.time_since() - .transform.temporal.decay(12) - .transform.missing.replace_na() - ) - - # Trees - - .with_column(Column("treelag_1_sb", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") - .transform.missing.replace_na() - .transform.spatial.treelag(thetacrit_spatial, 1) - ) - - .with_column(Column("treelag_2_sb", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") - .transform.missing.replace_na() - .transform.spatial.treelag(thetacrit_spatial, 2) - ) - # sptime - - # continuous, sptime_dist, nu=1 - .with_column(Column("sptime_dist_k1_1_ged_sb", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") - .transform.missing.replace_na() - .transform.spatial.sptime_dist(return_values, n_nearest, 1.0, power) - ) - - .with_column(Column("sptime_dist_k1_2_ged_sb", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") - .transform.missing.replace_na() - .transform.spatial.sptime_dist(return_values, n_nearest, 10.0, power) - ) - - .with_column(Column("sptime_dist_k1_3_ged_sb", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") - .transform.missing.replace_na() - .transform.spatial.sptime_dist(return_values, n_nearest, 0.01, power) - ) - - # From natsoc - .with_column(Column("ln_ttime_mean", from_loa="priogrid_year", from_column="ttime_mean") - .transform.ops.ln() - .transform.missing.fill() - .transform.missing.replace_na() - ) - - .with_column(Column("ln_bdist3", from_loa="priogrid_year", from_column="bdist3") - .transform.ops.ln() - .transform.missing.fill() - .transform.missing.replace_na() - ) - - .with_column(Column("ln_capdist", from_loa="priogrid_year", from_column="capdist") - .transform.ops.ln() - .transform.missing.fill() - .transform.missing.replace_na() - ) - - .with_column(Column("dist_diamsec", from_loa="priogrid", from_column="dist_diamsec_s_wgs") - .transform.missing.fill() - .transform.missing.replace_na() - ) - - .with_column(Column("imr_mean", from_loa="priogrid_year", from_column="imr_mean") - .transform.missing.fill() - .transform.missing.replace_na() - ) - - # From drought - .with_column(Column("tlag1_dr_mod_gs", from_loa="priogrid_month", - from_column="tlag1_dr_mod_gs") - .transform.missing.replace_na(0) - ) - - .with_column(Column("spei1_gs_prev10_anom", from_loa="priogrid_month", - from_column="spei1_gs_prev10_anom") - .transform.missing.replace_na(0) - ) - - .with_column(Column("tlag_12_crop_sum", from_loa="priogrid_month", - from_column="tlag_12_crop_sum") - .transform.missing.replace_na(0) - ) - - .with_column(Column("spei1gsy_lowermedian_count", from_loa="priogrid_month", - from_column="spei1gsy_lowermedian_count") - .transform.missing.replace_na(0) - ) - - # Log population as control - .with_column(Column("ln_pop_gpw_sum", from_loa="priogrid_year", from_column="pop_gpw_sum") - .transform.ops.ln() - .transform.missing.fill() - .transform.missing.replace_na() - ) - - .with_theme("fatalities") - .describe("""fatalities broad model, pgm level - - Predicting ln(ged_best_sb), broad model - - """) - ) - - return qs_broad \ No newline at end of file diff --git a/models/lavender_haze/configs/config_meta.py b/models/lavender_haze/configs/config_meta.py deleted file mode 100644 index e7ca1fc5..00000000 --- a/models/lavender_haze/configs/config_meta.py +++ /dev/null @@ -1,19 +0,0 @@ -def get_meta_config(): - """ - Contains the meta data for the model (model architecture, name, target variable, and level of analysis). - This config is for documentation purposes only, and modifying it will not affect the model, the training, or the evaluation. - - Returns: - - meta_config (dict): A dictionary containing model meta configuration. - """ - meta_config = { - "name": "lavender_haze", - "algorithm": "HurdleRegression", - "model_clf": "LGBMClassifier", - "model_reg": "LGBMRegressor", - "depvar": "ln_ged_sb_dep", # IMPORTANT! The current stepshift only takes one target variable! Not compatiable with Simon's code! - "queryset": "fatalities003_pgm_broad", - "level": "pgm", - "creator": "Xiaolong" - } - return meta_config \ No newline at end of file diff --git a/models/lavender_haze/configs/config_sweep.py b/models/lavender_haze/configs/config_sweep.py deleted file mode 100644 index 36ebf61c..00000000 --- a/models/lavender_haze/configs/config_sweep.py +++ /dev/null @@ -1,26 +0,0 @@ -def get_sweep_config(): - sweep_config = { - "name": "lavender_haze", - "method": "grid" - } - - metric = { - "name": "MSE", - "goal": "minimize" - } - - sweep_config["metric"] = metric - - parameters_dict = { - "steps": {"values": [[*range(1, 36 + 1, 1)]]}, - "cls_n_estimators": {"values": [100, 200]}, - "cls_learning_rate": {"values": [0.05]}, - "cls_n_jobs": {"values": [12]}, - "reg_n_estimators": {"values": [100, 200]}, - "reg_learning_rate": {"values": [0.05]}, - "reg_n_jobs": {"values": [12]} - } - - sweep_config["parameters"] = parameters_dict - - return sweep_config \ No newline at end of file diff --git a/models/lavender_haze/data/generated/.gitkeep b/models/lavender_haze/data/generated/.gitkeep deleted file mode 100644 index e69de29b..00000000 diff --git a/models/lavender_haze/data/processed/.gitkeep b/models/lavender_haze/data/processed/.gitkeep deleted file mode 100644 index e69de29b..00000000 diff --git a/models/lavender_haze/data/raw/.gitkeep b/models/lavender_haze/data/raw/.gitkeep deleted file mode 100644 index e69de29b..00000000 diff --git a/models/lavender_haze/main.py b/models/lavender_haze/main.py deleted file mode 100644 index c199227c..00000000 --- a/models/lavender_haze/main.py +++ /dev/null @@ -1,30 +0,0 @@ -import wandb -import sys -import warnings - -from pathlib import Path -PATH = Path(__file__) -sys.path.insert(0, str(Path( - *[i for i in PATH.parts[:PATH.parts.index("views_pipeline") + 1]]) / "common_utils")) # PATH_COMMON_UTILS -from set_path import setup_project_paths, setup_root_paths -setup_project_paths(PATH) - -from utils_cli_parser import parse_args, validate_arguments -from utils_logger import setup_logging -from execute_model_runs import execute_sweep_run, execute_single_run - -warnings.filterwarnings("ignore") - -logger = setup_logging('run.log') - - -if __name__ == "__main__": - wandb.login() - - args = parse_args() - validate_arguments(args) - - if args.sweep: - execute_sweep_run(args) - else: - execute_single_run(args) diff --git a/models/lavender_haze/notebooks/notebook001.ipynb b/models/lavender_haze/notebooks/notebook001.ipynb deleted file mode 100644 index 4d25bd82..00000000 --- a/models/lavender_haze/notebooks/notebook001.ipynb +++ /dev/null @@ -1,696 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Optional, Union\n", - "import numpy as np\n", - "import pandas as pd\n", - "\n", - "from sklearn.linear_model import LinearRegression, LogisticRegression\n", - "from sklearn.base import BaseEstimator\n", - "from sklearn.utils.estimator_checks import check_estimator\n", - "from sklearn.utils.validation import check_X_y, check_array, check_is_fitted\n", - "from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor\n", - "from sklearn.ensemble import RandomForestRegressor\n", - "from sklearn.ensemble import RandomForestClassifier\n", - "from sklearn.ensemble import HistGradientBoostingRegressor\n", - "from sklearn.ensemble import HistGradientBoostingClassifier\n", - "from xgboost import XGBRegressor\n", - "from xgboost import XGBClassifier\n", - "from xgboost import XGBRFRegressor, XGBRFClassifier\n", - "from lightgbm import LGBMClassifier, LGBMRegressor\n", - "\n", - "#from lightgbm import LGBMClassifier, LGBMRegressor\n", - "\n", - "\n", - "class HurdleRegression(BaseEstimator):\n", - " \"\"\" Regression model which handles excessive zeros by fitting a two-part model and combining predictions:\n", - " 1) binary classifier\n", - " 2) continuous regression\n", - " Implementeted as a valid sklearn estimator, so it can be used in pipelines and GridSearch objects.\n", - " Args:\n", - " clf_name: currently supports either 'logistic' or 'LGBMClassifier'\n", - " reg_name: currently supports either 'linear' or 'LGBMRegressor'\n", - " clf_params: dict of parameters to pass to classifier sub-model when initialized\n", - " reg_params: dict of parameters to pass to regression sub-model when initialized\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " clf_name: str = 'logistic',\n", - " reg_name: str = 'linear',\n", - " clf_params: Optional[dict] = None,\n", - " reg_params: Optional[dict] = None):\n", - "\n", - " self.clf_name = clf_name\n", - " self.reg_name = reg_name\n", - " self.clf_params = clf_params\n", - " self.reg_params = reg_params\n", - " self.clf_fi = []\n", - " self.reg_fi = []\n", - "\n", - " @staticmethod\n", - " def _resolve_estimator(func_name: str):\n", - " \"\"\" Lookup table for supported estimators.\n", - " This is necessary because sklearn estimator default arguments\n", - " must pass equality test, and instantiated sub-estimators are not equal. \"\"\"\n", - "\n", - " funcs = {'linear': LinearRegression(),\n", - " 'logistic': LogisticRegression(solver='liblinear'),\n", - " 'LGBMRegressor': LGBMRegressor(n_estimators=250),\n", - " 'LGBMClassifier': LGBMClassifier(n_estimators=250),\n", - " 'RFRegressor': XGBRFRegressor(n_estimators=250,n_jobs=-2),\n", - " 'RFClassifier': XGBRFClassifier(n_estimators=250,n_jobs=-2),\n", - " 'GBMRegressor': GradientBoostingRegressor(n_estimators=200),\n", - " 'GBMClassifier': GradientBoostingClassifier(n_estimators=200),\n", - " 'XGBRegressor': XGBRegressor(n_estimators=100,learning_rate=0.05,n_jobs=-2),\n", - " 'XGBClassifier': XGBClassifier(n_estimators=100,learning_rate=0.05,n_jobs=-2),\n", - " 'HGBRegressor': HistGradientBoostingRegressor(max_iter=200),\n", - " 'HGBClassifier': HistGradientBoostingClassifier(max_iter=200),\n", - " }\n", - "\n", - " return funcs[func_name]\n", - "\n", - " def fit(self,\n", - " X: Union[np.ndarray, pd.DataFrame],\n", - " y: Union[np.ndarray, pd.Series]):\n", - " X, y = check_X_y(X, y, dtype=None,\n", - " accept_sparse=False,\n", - " accept_large_sparse=False,\n", - " force_all_finite='allow-nan')\n", - "\n", - " if X.shape[1] < 2:\n", - " raise ValueError('Cannot fit model when n_features = 1')\n", - "\n", - " self.clf_ = self._resolve_estimator(self.clf_name)\n", - " if self.clf_params:\n", - " self.clf_.set_params(**self.clf_params)\n", - " self.clf_.fit(X, y > 0)\n", - " self.clf_fi = self.clf_.feature_importances_\n", - "\n", - " self.reg_ = self._resolve_estimator(self.reg_name)\n", - " if self.reg_params:\n", - " self.reg_.set_params(**self.reg_params)\n", - " self.reg_.fit(X[y > 0], y[y > 0])\n", - " self.reg_fi = self.reg_.feature_importances_\n", - "\n", - " self.is_fitted_ = True\n", - " return self\n", - "\n", - "\n", - " def predict_bck(self, X: Union[np.ndarray, pd.DataFrame]):\n", - " \"\"\" Predict combined response using binary classification outcome \"\"\"\n", - " X = check_array(X, accept_sparse=False, accept_large_sparse=False)\n", - " check_is_fitted(self, 'is_fitted_')\n", - " return self.clf_.predict(X) * self.reg_.predict(X)\n", - "\n", - " def predict(self, X: Union[np.ndarray, pd.DataFrame]):\n", - " \"\"\" Predict combined response using probabilistic classification outcome \"\"\"\n", - " X = check_array(X, accept_sparse=False, accept_large_sparse=False)\n", - " check_is_fitted(self, 'is_fitted_')\n", - " return self.clf_.predict_proba(X)[:, 1] * self.reg_.predict(X)\n", - "\n", - " \n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "hp_config = {\n", - " \"clf\":{\n", - " \"learning_rate\": 0.05,\n", - " \"n_estimators\": 100,\n", - " \"n_jobs\": 12\n", - " },\n", - " \"reg\":{\n", - " \"learning_rate\": 0.05,\n", - " \"n_estimators\": 100,\n", - " \"n_jobs\": 12\n", - " }\n", - "}\n", - "common_config = {\n", - " \"name\": \"lavender_haze\",\n", - " \"algorithm\": \"HurdleRegression\",\n", - " \"clf_name\":\"LGBMClassifier\",\n", - " \"reg_name\":\"LGBMRegressor\",\n", - " \"depvar\": \"ged_sb_dep\",\n", - " \"queryset\": \"fatalities003_pgm_broad\",\n", - " \"data_train\": \"baseline\",\n", - " \"level\": \"pgm\",\n", - " 'steps': [*range(1, 36 + 1, 1)],\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "cls_model = HurdleRegression(clf_name=common_config['clf_name'], reg_name=common_config['reg_name'], clf_params=hp_config['clf'], reg_params=hp_config['reg'])" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
HurdleRegression(clf_name='LGBMClassifier',\n", - " clf_params={'learning_rate': 0.05, 'n_estimators': 100,\n", - " 'n_jobs': 12},\n", - " reg_name='LGBMRegressor',\n", - " reg_params={'learning_rate': 0.05, 'n_estimators': 100,\n", - " 'n_jobs': 12})In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
HurdleRegression(clf_name='LGBMClassifier',\n", - " clf_params={'learning_rate': 0.05, 'n_estimators': 100,\n", - " 'n_jobs': 12},\n", - " reg_name='LGBMRegressor',\n", - " reg_params={'learning_rate': 0.05, 'n_estimators': 100,\n", - " 'n_jobs': 12})
\n", - " | \n", - " | tlag1_dr_mod_gs | \n", - "spei1_gs_prev10_anom | \n", - "tlag_12_crop_sum | \n", - "spei1gsy_lowermedian_count | \n", - "ln_ged_sb_dep | \n", - "ged_sb | \n", - "ged_os | \n", - "ged_ns | \n", - "treelag_1_sb | \n", - "treelag_2_sb | \n", - "... | \n", - "dist_diamsec | \n", - "imr_mean | \n", - "ln_ttime_mean | \n", - "ln_bdist3 | \n", - "ln_capdist | \n", - "ln_pop_gpw_sum | \n", - "decay_ged_sb_5 | \n", - "decay_ged_os_5 | \n", - "decay_ged_ns_5 | \n", - "splag_1_1_sb_1 | \n", - "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
month_id | \n", - "priogrid_gid | \n", - "\n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " | \n", - " |
1 | \n", - "62356 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "... | \n", - "19.235384 | \n", - "0.0 | \n", - "7.989464 | \n", - "2.263900 | \n", - "7.817437 | \n", - "0.000000 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "
79599 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "... | \n", - "3.640055 | \n", - "100.0 | \n", - "5.251089 | \n", - "2.961998 | \n", - "7.187934 | \n", - "8.266445 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "|
79600 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "... | \n", - "3.807887 | \n", - "100.0 | \n", - "5.656525 | \n", - "0.364952 | \n", - "7.164395 | \n", - "7.805237 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "|
79601 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "... | \n", - "4.031129 | \n", - "100.0 | \n", - "5.465652 | \n", - "2.379325 | \n", - "7.141138 | \n", - "9.335159 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "|
80317 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "... | \n", - "3.000000 | \n", - "100.0 | \n", - "3.409915 | \n", - "2.520981 | \n", - "7.208015 | \n", - "12.654427 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "0.000000e+00 | \n", - "|
... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "
852 | \n", - "190496 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "... | \n", - "10.295630 | \n", - "150.0 | \n", - "5.687243 | \n", - "0.493902 | \n", - "5.910060 | \n", - "10.408626 | \n", - "4.487001e-22 | \n", - "4.487001e-22 | \n", - "4.487001e-22 | \n", - "8.473017e-11 | \n", - "
190507 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "... | \n", - "6.103278 | \n", - "419.0 | \n", - "5.335934 | \n", - "3.317541 | \n", - "5.564456 | \n", - "6.647283 | \n", - "4.487001e-22 | \n", - "4.487001e-22 | \n", - "4.487001e-22 | \n", - "8.473017e-11 | \n", - "|
190508 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "... | \n", - "5.830952 | \n", - "419.0 | \n", - "0.000000 | \n", - "3.433905 | \n", - "5.596457 | \n", - "4.562102 | \n", - "4.487001e-22 | \n", - "4.487001e-22 | \n", - "4.487001e-22 | \n", - "8.473017e-11 | \n", - "|
190510 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "... | \n", - "5.385165 | \n", - "419.0 | \n", - "5.904822 | \n", - "3.240468 | \n", - "5.716054 | \n", - "7.619576 | \n", - "4.487001e-22 | \n", - "4.487001e-22 | \n", - "4.487001e-22 | \n", - "8.473017e-11 | \n", - "|
190511 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "0.0 | \n", - "... | \n", - "5.220153 | \n", - "419.0 | \n", - "5.479170 | \n", - "3.287923 | \n", - "5.791936 | \n", - "7.596084 | \n", - "4.487001e-22 | \n", - "4.487001e-22 | \n", - "4.487001e-22 | \n", - "8.473017e-11 | \n", - "
11169720 rows × 23 columns
\n", - "