-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
changing model directories for testing
- Loading branch information
Borbála Farkas
authored and
Borbála Farkas
committed
Oct 28, 2024
1 parent
7fdf35d
commit 2f283d1
Showing
35 changed files
with
1,580 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Lavender Haze Model | ||
## Overview | ||
This folder contains code for Lavender Haze model, a machine learning model designed for predicting fatalities. | ||
|
||
The model utilizes Hurdle Model (LGBMClassifier+LGBMRegressor) for its predictions and is on pgm level of analysis. | ||
|
||
The model uses log fatalities. | ||
|
||
## Repository Structure | ||
``` | ||
lavender_haze/ # should follow the naming convention adjective_noun | ||
|-- README.md | ||
|-- requirements.txt | ||
| | ||
|-- artifacts/ # ensemble stepshifter models | ||
| |-- model_metadata_dict.py # the standard meta data dict for models | ||
| | ||
|-- configs/ # ... | ||
| |-- config_deployment.py # configuration for deploying the model into different environments | ||
| |-- config_hyperparameters.py # hyperparameters for the model | ||
| |-- config_input_data.py # defined queryset as the input data | ||
| |-- config_meta # metadata for the model (model architecture, name, target variable, and level of analysis) | ||
| |-- config_sweep # sweeping parameters for weights & biases | ||
| | ||
|-- data/ # all input, processed, output data | ||
| |-- generated/ # Data generated - i.e. forecast/ evaluation | ||
| |-- processed/ # Data processed | ||
| |-- raw/ # Data directly from VIEiWSER | ||
| | ||
|-- notebooks/ | ||
| | ||
|-- reports/ # dissemination material - internal and external | ||
| |-- figures/ # figures for papers, reports, newsletters, and slides | ||
| |-- papers/ # working papers, white papers, articles ect. | ||
| |-- plots/ # plots for papers, reports, newsletters, and slides | ||
| |-- slides/ # slides, presentation and similar | ||
| |-- timelapse/ # plots to create timelapse and the timelapse | ||
| | ||
|-- src/ # all source code needed to train, test, and forecast | ||
| | ||
|-- dataloaders/ | ||
| |-- get_data.py # script to get data from VIEWSER (and input drift detection) | ||
| | ||
|-- forecasting/ | ||
| |-- generate_forecast.py # script to genereate true-future fc | ||
| | ||
|-- management/ | ||
| |-- execute_model_runs.py # execute a single run | ||
| |-- execute_model_tasks.py # execute various model-related tasks | ||
| | ||
|-- offline_evaluation/ # aka offline quality assurance | ||
| |-- evaluate_model.py # script to evaluate a single model | ||
| |-- evaluate_sweep.py # script to evaluate a model during sweeping | ||
| | ||
|-- online_evaluation/ | ||
| | ||
|-- training/ | ||
| |-- train_model.py # script to train a single model | ||
| | ||
|-- utils/ # functions and classes | ||
| |-- utils.py # a general utils function | ||
| |-- utils_wandb.py # a w&b specific utils function | ||
| | ||
|-- visualization/ # scripts to create visualizations | ||
|-- visual.py | ||
``` | ||
|
||
## Setup Instructions | ||
Clone the repository. | ||
|
||
Install dependencies. | ||
|
||
## Usage | ||
Modify configurations in configs/. | ||
|
||
Run main.py. | ||
|
||
``` | ||
python main.py -r calibration -t -e | ||
``` | ||
|
||
Monitor progress and results using [Weights & Biases](https://wandb.ai/views_pipeline/lavender_haze). |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
def get_deployment_config(): | ||
|
||
""" | ||
Contains the configuration for deploying the model into different environments. | ||
This configuration is "behavioral" so modifying it will affect the model's runtime behavior and integration into the deployment system. | ||
Returns: | ||
- deployment_config (dict): A dictionary containing deployment settings, determining how the model is deployed, including status, endpoints, and resource allocation. | ||
""" | ||
|
||
# More deployment settings can/will be added here | ||
deployment_config = { | ||
"deployment_status": "shadow", # shadow, deployed, baseline, or deprecated | ||
} | ||
|
||
return deployment_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
def get_hp_config(): | ||
hp_config = { | ||
"steps": [*range(1, 36 + 1, 1)], | ||
"parameters": { | ||
"clf":{ | ||
"learning_rate": 0.05, | ||
"n_estimators": 100, | ||
"n_jobs": 12 | ||
}, | ||
"reg":{ | ||
"learning_rate": 0.05, | ||
"n_estimators": 100, | ||
"n_jobs": 12 | ||
} | ||
} | ||
} | ||
return hp_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
import numpy as np | ||
from viewser import Queryset, Column | ||
|
||
def get_input_data_config(): | ||
|
||
thetacrit_spatial = 0.7 | ||
return_values = 'distances' | ||
n_nearest = 1 | ||
power = 0.0 | ||
|
||
qs_broad = (Queryset("fatalities003_pgm_broad", "priogrid_month") | ||
|
||
# target variable | ||
.with_column(Column("ln_ged_sb_dep", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") | ||
.transform.missing.replace_na() | ||
.transform.ops.ln() | ||
) | ||
|
||
# timelags 0 of conflict variables, ged_best versions | ||
|
||
.with_column(Column("ged_sb", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") | ||
.transform.missing.fill() | ||
.transform.missing.replace_na() | ||
) | ||
|
||
.with_column(Column("ged_os", from_loa="priogrid_month", from_column="ged_os_best_sum_nokgi") | ||
.transform.missing.fill() | ||
.transform.missing.replace_na() | ||
) | ||
|
||
.with_column(Column("ged_ns", from_loa="priogrid_month", from_column="ged_ns_best_sum_nokgi") | ||
.transform.missing.fill() | ||
.transform.missing.replace_na() | ||
) | ||
|
||
# Spatial lag | ||
.with_column(Column("splag_1_1_sb_1", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") | ||
.transform.missing.replace_na() | ||
.transform.bool.gte(1) | ||
.transform.temporal.time_since() | ||
.transform.temporal.decay(24) | ||
.transform.spatial.lag(1, 1, 0, 0) | ||
.transform.missing.replace_na() | ||
) | ||
|
||
# Decay functions | ||
# sb | ||
.with_column(Column("decay_ged_sb_5", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") | ||
.transform.missing.replace_na() | ||
.transform.bool.gte(5) | ||
.transform.temporal.time_since() | ||
.transform.temporal.decay(12) | ||
.transform.missing.replace_na() | ||
) | ||
# os | ||
.with_column(Column("decay_ged_os_5", from_loa="priogrid_month", from_column="ged_os_best_sum_nokgi") | ||
.transform.missing.replace_na() | ||
.transform.bool.gte(5) | ||
.transform.temporal.time_since() | ||
.transform.temporal.decay(12) | ||
.transform.missing.replace_na() | ||
) | ||
|
||
# ns | ||
.with_column(Column("decay_ged_ns_5", from_loa="priogrid_month", from_column="ged_ns_best_sum_nokgi") | ||
.transform.missing.replace_na() | ||
.transform.bool.gte(5) | ||
.transform.temporal.time_since() | ||
.transform.temporal.decay(12) | ||
.transform.missing.replace_na() | ||
) | ||
|
||
# Trees | ||
|
||
.with_column(Column("treelag_1_sb", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") | ||
.transform.missing.replace_na() | ||
.transform.spatial.treelag(thetacrit_spatial, 1) | ||
) | ||
|
||
.with_column(Column("treelag_2_sb", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") | ||
.transform.missing.replace_na() | ||
.transform.spatial.treelag(thetacrit_spatial, 2) | ||
) | ||
# sptime | ||
|
||
# continuous, sptime_dist, nu=1 | ||
.with_column(Column("sptime_dist_k1_1_ged_sb", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") | ||
.transform.missing.replace_na() | ||
.transform.spatial.sptime_dist(return_values, n_nearest, 1.0, power) | ||
) | ||
|
||
.with_column(Column("sptime_dist_k1_2_ged_sb", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") | ||
.transform.missing.replace_na() | ||
.transform.spatial.sptime_dist(return_values, n_nearest, 10.0, power) | ||
) | ||
|
||
.with_column(Column("sptime_dist_k1_3_ged_sb", from_loa="priogrid_month", from_column="ged_sb_best_sum_nokgi") | ||
.transform.missing.replace_na() | ||
.transform.spatial.sptime_dist(return_values, n_nearest, 0.01, power) | ||
) | ||
|
||
# From natsoc | ||
.with_column(Column("ln_ttime_mean", from_loa="priogrid_year", from_column="ttime_mean") | ||
.transform.ops.ln() | ||
.transform.missing.fill() | ||
.transform.missing.replace_na() | ||
) | ||
|
||
.with_column(Column("ln_bdist3", from_loa="priogrid_year", from_column="bdist3") | ||
.transform.ops.ln() | ||
.transform.missing.fill() | ||
.transform.missing.replace_na() | ||
) | ||
|
||
.with_column(Column("ln_capdist", from_loa="priogrid_year", from_column="capdist") | ||
.transform.ops.ln() | ||
.transform.missing.fill() | ||
.transform.missing.replace_na() | ||
) | ||
|
||
.with_column(Column("dist_diamsec", from_loa="priogrid", from_column="dist_diamsec_s_wgs") | ||
.transform.missing.fill() | ||
.transform.missing.replace_na() | ||
) | ||
|
||
.with_column(Column("imr_mean", from_loa="priogrid_year", from_column="imr_mean") | ||
.transform.missing.fill() | ||
.transform.missing.replace_na() | ||
) | ||
|
||
# From drought | ||
.with_column(Column("tlag1_dr_mod_gs", from_loa="priogrid_month", | ||
from_column="tlag1_dr_mod_gs") | ||
.transform.missing.replace_na(0) | ||
) | ||
|
||
.with_column(Column("spei1_gs_prev10_anom", from_loa="priogrid_month", | ||
from_column="spei1_gs_prev10_anom") | ||
.transform.missing.replace_na(0) | ||
) | ||
|
||
.with_column(Column("tlag_12_crop_sum", from_loa="priogrid_month", | ||
from_column="tlag_12_crop_sum") | ||
.transform.missing.replace_na(0) | ||
) | ||
|
||
.with_column(Column("spei1gsy_lowermedian_count", from_loa="priogrid_month", | ||
from_column="spei1gsy_lowermedian_count") | ||
.transform.missing.replace_na(0) | ||
) | ||
|
||
# Log population as control | ||
.with_column(Column("ln_pop_gpw_sum", from_loa="priogrid_year", from_column="pop_gpw_sum") | ||
.transform.ops.ln() | ||
.transform.missing.fill() | ||
.transform.missing.replace_na() | ||
) | ||
|
||
.with_theme("fatalities") | ||
.describe("""fatalities broad model, pgm level | ||
Predicting ln(ged_best_sb), broad model | ||
""") | ||
) | ||
|
||
return qs_broad |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
def get_meta_config(): | ||
""" | ||
Contains the meta data for the model (model architecture, name, target variable, and level of analysis). | ||
This config is for documentation purposes only, and modifying it will not affect the model, the training, or the evaluation. | ||
Returns: | ||
- meta_config (dict): A dictionary containing model meta configuration. | ||
""" | ||
meta_config = { | ||
"name": "lavender_haze", | ||
"algorithm": "HurdleRegression", | ||
"model_clf": "LGBMClassifier", | ||
"model_reg": "LGBMRegressor", | ||
"depvar": "ln_ged_sb_dep", # IMPORTANT! The current stepshift only takes one target variable! Not compatiable with Simon's code! | ||
"queryset": "fatalities003_pgm_broad", | ||
"level": "pgm", | ||
"creator": "Xiaolong" | ||
} | ||
return meta_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
def get_sweep_config(): | ||
sweep_config = { | ||
"name": "lavender_haze", | ||
"method": "grid" | ||
} | ||
|
||
metric = { | ||
"name": "MSE", | ||
"goal": "minimize" | ||
} | ||
|
||
sweep_config["metric"] = metric | ||
|
||
parameters_dict = { | ||
"steps": {"values": [[*range(1, 36 + 1, 1)]]}, | ||
"cls_n_estimators": {"values": [100, 200]}, | ||
"cls_learning_rate": {"values": [0.05]}, | ||
"cls_n_jobs": {"values": [12]}, | ||
"reg_n_estimators": {"values": [100, 200]}, | ||
"reg_learning_rate": {"values": [0.05]}, | ||
"reg_n_jobs": {"values": [12]} | ||
} | ||
|
||
sweep_config["parameters"] = parameters_dict | ||
|
||
return sweep_config |
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import wandb | ||
import sys | ||
import warnings | ||
|
||
from pathlib import Path | ||
PATH = Path(__file__) | ||
sys.path.insert(0, str(Path( | ||
*[i for i in PATH.parts[:PATH.parts.index("views_pipeline") + 1]]) / "common_utils")) # PATH_COMMON_UTILS | ||
from set_path import setup_project_paths, setup_root_paths | ||
setup_project_paths(PATH) | ||
|
||
from utils_cli_parser import parse_args, validate_arguments | ||
from utils_logger import setup_logging | ||
from execute_model_runs import execute_sweep_run, execute_single_run | ||
|
||
warnings.filterwarnings("ignore") | ||
|
||
logger = setup_logging('run.log') | ||
|
||
|
||
if __name__ == "__main__": | ||
wandb.login() | ||
|
||
args = parse_args() | ||
validate_arguments(args) | ||
|
||
if args.sweep: | ||
execute_sweep_run(args) | ||
else: | ||
execute_single_run(args) |
Oops, something went wrong.