Skip to content

Commit

Permalink
chore: make easier to reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Maik Jablonka committed Sep 27, 2022
1 parent 90d48bd commit 8ab6e59
Show file tree
Hide file tree
Showing 13 changed files with 440 additions and 50 deletions.
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"esbonio.server.enabled": true
}
16 changes: 9 additions & 7 deletions paper/20220310_plot_causalimpact.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
"metadata": {},
"outputs": [],
"source": [
"from aeml.utils.io import dump_pickle, read_pickle\n",
"from aeml.models.gbdt.plot import make_forecast_plot\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from aeml.models.gbdt.plot import make_forecast_plot\n",
"from aeml.utils.io import dump_pickle, read_pickle\n",
"\n",
"TARGETS_clean = ['2-Amino-2-methylpropanol C4H11NO', 'Piperazine C4H10N2']\n",
"\n",
Expand All @@ -20,6 +19,7 @@
"plt.style.reload_library()\n",
"plt.style.use('science')\n",
"from matplotlib import rcParams\n",
"\n",
"rcParams['font.family'] = 'sans-serif'"
]
},
Expand Down Expand Up @@ -380,9 +380,6 @@
}
],
"metadata": {
"interpreter": {
"hash": "ffaf8fb1d943ea3037f4a4fdf1b01b66ee672aa3b35a840f4a4d996ec833da06"
},
"kernelspec": {
"display_name": "Python 3.8.12 ('aeml')",
"language": "python",
Expand All @@ -400,7 +397,12 @@
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"orig_nbformat": 4
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "6cba9f7d8f0983127eee8d7d40af44cfa7d9f36acbf6cef9327ee0048f1108bf"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
Expand Down
215 changes: 198 additions & 17 deletions paper/20220310_plot_forecast_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
"cells": [
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt \n",
"from aeml.utils.io import read_pickle\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from darts import TimeSeries\n",
"from darts.metrics import mape, mae, ope\n",
"from aeml.utils.io import read_pickle\n",
"\n",
"colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33', '#a65628', '#f781bf', '#999999']\n",
"\n",
Expand All @@ -19,25 +19,25 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# those were run with during/2 as forecasting horizon \n",
"\n",
"step_1_target_0 = read_pickle('/home/kjablonk/documents/aeml/scratch/20220311-081020_6_filtered-1-step-target-0-forecasts_quantiles_0.1_0.9.pkl')\n",
"step_1_target_1 = read_pickle('/home/kjablonk/documents/aeml/scratch/20220311-081020_6_filtered-1-step-target-1-forecasts_quantiles_0.1_0.9.pkl')\n",
"step_1_target_0 = read_pickle('models/20220311-081020_6_filtered-1-step-target-0-forecasts_quantiles_0.1_0.9.pkl')\n",
"step_1_target_1 = read_pickle('models/20220311-081020_6_filtered-1-step-target-1-forecasts_quantiles_0.1_0.9.pkl')\n",
"\n",
"step_30_target_0 = read_pickle('/home/kjablonk/documents/aeml/scratch/20220311-081020_6_filtered-30-step-target-0-forecasts_quantiles_0.1_0.9.pkl')\n",
"step_30_target_1 = read_pickle('/home/kjablonk/documents/aeml/scratch/20220311-081020_6_filtered-30-step-target-1-forecasts_quantiles_0.1_0.9.pkl')\n",
"step_30_target_0 = read_pickle('models/20220311-081020_6_filtered-30-step-target-0-forecasts_quantiles_0.1_0.9.pkl')\n",
"step_30_target_1 = read_pickle('models/20220311-081020_6_filtered-30-step-target-1-forecasts_quantiles_0.1_0.9.pkl')\n",
"\n",
"step_60_target_0 = read_pickle('/home/kjablonk/documents/aeml/scratch/20220311-081020_6_filtered-60-step-target-0-forecasts_quantiles_0.1_0.9.pkl')\n",
"step_60_target_1 = read_pickle('/home/kjablonk/documents/aeml/scratch/20220311-081020_6_filtered-60-step-target-1-forecasts_quantiles_0.1_0.9.pkl')"
"step_60_target_0 = read_pickle('models/20220311-081020_6_filtered-60-step-target-0-forecasts_quantiles_0.1_0.9.pkl')\n",
"step_60_target_1 = read_pickle('models/20220311-081020_6_filtered-60-step-target-1-forecasts_quantiles_0.1_0.9.pkl')"
]
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -51,9 +51,16 @@
"x_conncected = [val.total_seconds() / (60 * 60 * 24) for val in x_conncected]\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Make plots"
]
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -141,6 +148,178 @@
"fig.savefig(f'20220311_forecast_overview.pdf', bbox_inches='tight')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compute metrics"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"def get_metrics(actual, predicted): \n",
" actual = TimeSeries.from_series(actual)\n",
" predicted = TimeSeries.from_series(predicted)\n",
" mae_score = mae(actual, predicted)\n",
" mape_score = mape(actual, predicted)\n",
" ope_score = ope(actual, predicted)\n",
" return {\n",
" 'mae': mae_score,\n",
" 'mape': mape_score,\n",
" 'ope': ope_score\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### AMP"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'mae': 0.008939960985387737,\n",
" 'mape': 2.4368939957216833,\n",
" 'ope': 0.37827978394930756}"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_metrics(y_scaled_df[TARGETS_clean[1]][4056:],step_1_target_0[1].pd_dataframe()['0'][4056:])"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'mae': 0.0394073422168147,\n",
" 'mape': 11.029698450021334,\n",
" 'ope': 0.34473332080028374}"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_metrics(y_scaled_df[TARGETS_clean[1]][4056:],step_30_target_0[1].pd_dataframe()['0'][4056:])"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'mae': 0.03550057382513647,\n",
" 'mape': 9.478101733271242,\n",
" 'ope': 3.820748410348062}"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_metrics(y_scaled_df[TARGETS_clean[1]][4056:],step_60_target_0[1].pd_dataframe()['0'][4056:])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Pz "
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'mae': 0.009471103586712878,\n",
" 'mape': 4.235518674101464,\n",
" 'ope': 1.9672137804416856}"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_metrics(y_scaled_df[TARGETS_clean[0]][4056:],step_1_target_1[1].pd_dataframe()['0'][4056:])"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'mae': 0.04999564169471071,\n",
" 'mape': 23.38424190674555,\n",
" 'ope': 17.405593802762066}"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_metrics(y_scaled_df[TARGETS_clean[0]][4056:],step_30_target_1[1].pd_dataframe()['0'][4056:])"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'mae': 0.04041629135027904,\n",
" 'mape': 20.868076234780723,\n",
" 'ope': 10.372566946028094}"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_metrics(y_scaled_df[TARGETS_clean[0]][4056:],step_60_target_1[1].pd_dataframe()['0'][4056:])"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -150,9 +329,6 @@
}
],
"metadata": {
"interpreter": {
"hash": "ffaf8fb1d943ea3037f4a4fdf1b01b66ee672aa3b35a840f4a4d996ec833da06"
},
"kernelspec": {
"display_name": "Python 3.8.12 ('aeml')",
"language": "python",
Expand All @@ -170,7 +346,12 @@
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"orig_nbformat": 4
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "6cba9f7d8f0983127eee8d7d40af44cfa7d9f36acbf6cef9327ee0048f1108bf"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
Expand Down
Binary file added paper/20220311_forecast_overview.pdf
Binary file not shown.
11 changes: 8 additions & 3 deletions paper/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

## Notebooks
- `20220310_plot_causalimpact.ipynb` used to plot the results of the causal impact analysis (e.g. generated via `xgboost_causalimpact.py`)
- `20220310_plot_forecast_overview.ipynb` used to plot an overview of the historical forecasts
- `20220310_plot_forecast_overview.ipynb` used to plot an overview of the historical forecasts (Figure 7 in the main text as well as the model evaluation)
- `20220310_train_gbdt_on_all.ipynb` used to train the GBDT model on all data (subsequently used to compute the scenarios)
- `20220306_predict_w_gbdt.ipynb` example for training a GBDT model

## Scripts
### Causaul impact analysis
- `causalimpact_sweep.py` run the hyperparamter sweep
### Causal impact analysis
- `causalimpact_sweep.py` run the hyperparameter sweep (assumes [Weights and Biases](https://wandb.ai/site) is set up)
- `causalimpact_xgboost.py` run the causal impact analysis using GBDT models
- `tcn_causalimpact.py` run the analysis using TCN models
- `step_times.pkl` contains the timestamps for the step changes in our study
Expand All @@ -17,3 +17,8 @@
- `loop_over_maps_gbdt.py` / `loop_over_maps_scitas.py` used to create and submit slurm script for "scenario" analysis
- `plot_effects_gbdt.py` / `plot_effects.py` used to convert the outputs of the scenario scripts into heatmaps
- `run_gbdt_scenarios.py` / `run_scenarios.py` contain the logic for running the scenarios

### Models

Model checkpoints are archived on Zenodo (DOI: [10.5281/zenodo.5153417](https://dx.doi.org/10.5281/zenodo.5153417)) and are also available in the `models` subdirectory.
Unfortunately, we could only serialize the models as pickle files, so the same Python version and package versions are needed to reuse the models.
35 changes: 12 additions & 23 deletions paper/causalimpact_sweep.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,21 @@
from pyexpat import features
from darts.models.forecasting.gradient_boosted_model import LightGBMModel
import wandb
from darts.models import TCNModel
import pandas as pd
from darts.metrics import mape, mae
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from copy import deepcopy
import numpy as np

import logging
import click
import pickle
from functools import partial
from aeml.models.utils import split_data, choose_index

import click
import numpy as np
import pandas as pd
import wandb
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.metrics import mae
from darts.models.forecasting.gradient_boosted_model import LightGBMModel

from aeml.causalimpact.utils import get_timestep_tuples, get_causalimpact_splits
import pickle
from aeml.causalimpact.utils import _select_unrelated_x
from aeml.causalimpact.utils import get_causalimpact_splits
from aeml.models.gbdt.gbmquantile import LightGBMQuantileRegressor
from aeml.models.gbdt.run import run_model
from aeml.models.gbdt.settings import *

from darts.dataprocessing.transformers import Scaler
from darts import TimeSeries
import pandas as pd
from copy import deepcopy
import time

from aeml.models.utils import choose_index, split_data

log = logging.getLogger(__name__)

Expand Down
Loading

0 comments on commit 8ab6e59

Please sign in to comment.