feat: merge changes required after first revision of the paper (state: d614a02)
noxthot committed Oct 9, 2024
1 parent 2ed3ed2 commit 364de7a
Showing 5 changed files with 116 additions and 114 deletions.
10 changes: 4 additions & 6 deletions README.md
@@ -58,17 +58,16 @@ How to retrieve the data is described in `data-preprocessing/README_preprocessin
The order of the following list defines the order in which the scripts should be run.
- `etl.py`: This data pipeline transforms the raw data (see previous subsection) into the format required for training, testing, and analysis with the following files.
- `train.py`: Trains a neural network on the transformed data.
- `test.py`: Evaluates the performance of the trained neural network on previously unseen test data. This file was used to compute the corresponding confusion matrix in Table 2.
- `test.py`: Evaluates the performance of the trained neural network on previously unseen test data. This file was used to compute the corresponding confusion matrix in Table 3.
- `test_shap.py`: Computes the Shapley values using the trained model on the test data.
- `validation_scores.py`: Computes the classification threshold such that the diurnal cycle is least biased on the validation data.
- `analyse_diurnal_cycles.py`: Uses the previously calculated classification threshold to generate plots that visualize how that threshold performs in reproducing the diurnal cycle on previously unseen test data. This file produces figure 2 of the paper.
- `analyse_shap_and_features.ipynb`: Visualizes the SHAP and real values of the vertical profiles, distinguishing between true positives, false positives, and false negatives, as well as providing some plots regarding cloud top and bottom height. This file was used to generate figures 3, 4, 5 and 7 of the paper.
- `flash_case_study_final.ipynb`: Visualizes network classifications at a specific time on a map of Austria. This file generates figure 6 of the paper.
- `analyse_shap_and_features.ipynb`: Visualizes the SHAP and real values of the vertical profiles, distinguishing between true positives, false positives, and false negatives, as well as providing some plots regarding cloud top and bottom height. This file was used to generate figures 1, 2, 3 and 4 of the paper.
- `flash_case_study_final.ipynb`: Visualizes network classifications at a specific time on a map of Austria. This file generates figure 5 of the paper.

### Runnable Files (Reference model):
- `reference_model.R`: Trains the reference model.
- `reference_valpred.R`: Stores the model output on the validation data (used for calculating the classification threshold later on).
- `reference_test.py`: Evaluates the trained reference model on previously unseen test data. This file was used to compute the corresponding confusion matrix in Table 2.
- `reference_test.py`: Evaluates the trained reference model on previously unseen test data. This file was used to compute the corresponding confusion matrix in Table 3.

### Helper files
- `ccc.py`: Defines some global constants.
@@ -78,4 +77,3 @@ The order of the following list defines the order in which the scripts should be run
# References
<a id="1">[1]</a>
[1] Ehrensperger, G., Simon, T., Mayr, G. & Hell, T. (2024). Identifying Lightning Processes in ERA5 Soundings with Deep Learning. arXiv (https://arxiv.org/abs/2210.11529)
>>>>>>> Stashed changes
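For orientation, here is a minimal driver sketch that runs the neural-network scripts in the order the README above prescribes. The script names are taken from the list; invoking them with plain `python` and no command-line arguments is an assumption, since the README does not show the exact calls.

```python
# Hypothetical driver: runs the pipeline scripts in the README's stated order.
# Script names come from the README; argument-free invocation is an assumption.
import subprocess

PIPELINE = [
    "etl.py",                    # transform raw data
    "train.py",                  # train the neural network
    "test.py",                   # confusion matrix on unseen test data
    "test_shap.py",              # Shapley values on the test data
    "validation_scores.py",      # classification threshold from validation data
    "analyse_diurnal_cycles.py", # diurnal-cycle plots with that threshold
]

for script in PIPELINE:
    subprocess.run(["python", script], check=True)
```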
43 changes: 32 additions & 11 deletions analyse_shap_and_features.ipynb
@@ -86,15 +86,6 @@
"PLOT_COLS = [f\"{c}_relative_shap\" for c in (ccc.LVL_TRAIN_COLS + [\"hour\", \"dayofyear\"])]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dd"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -130,7 +121,8 @@
"metadata": {},
"outputs": [],
"source": [
"used_threshold = 0.8708981871604919\n",
"used_threshold = 0.8708981871604919 # F1 score based threshold of epoch 18 on validation set\n",
"\n",
"vc_threshold = utils.getVeryConfidentThreshold(used_threshold)\n",
"\n",
"df_joined = utils.joinDataframes(dshap, dd)\n",
@@ -266,6 +258,8 @@
"df_mass = df_TP[mass_tp_sum > 0.5]\n",
"df_wind = df_TP[wind_tp_sum > 0.5]\n",
"\n",
"df_nodom = df_TP[(cloud_tp_sum <= 0.5) & (mass_tp_sum <= 0.5) & (wind_tp_sum <= 0.5)]\n",
"\n",
"df_cloud_plus_TN = pd.concat([df_cloud, df_TN])\n",
"df_cloudmass_plus_TN = pd.concat([df_cloudmass, df_TN])\n",
"df_cloudwind_plus_TN = pd.concat([df_cloudwind, df_TN])\n",
@@ -284,7 +278,34 @@
"print(f\"Number of samples in cloud-mass-dominant TPs:\\t\\t{len(df_cloudmass)}\")\n",
"print(f\"Number of samples in cloud-wind-dominant TPs:\\t\\t{len(df_cloudwind)}\")\n",
"print(f\"Number of samples in mass-dominant TPs: \\t\\t {len(df_mass)}\")\n",
"print(f\"Number of samples in wind-dominant TPs: \\t\\t{len(df_wind)}\")"
"print(f\"Number of samples in wind-dominant TPs: \\t\\t{len(df_wind)}\")\n",
"print(f\"Number of samples TPs without dominance: \\t\\t{len(df_nodom)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"Relative number of samples in cloud-dominant TPs:\\t\\t{100 * len(df_cloud) / len(df_TP)}\")\n",
"print(f\"Relative number of samples in cloud-mass-dominant TPs:\\t\\t{100 * len(df_cloudmass) / len(df_TP)}\")\n",
"print(f\"Relative number of samples in cloud-wind-dominant TPs:\\t\\t{100 * len(df_cloudwind) / len(df_TP)}\")\n",
"print(f\"Relative number of samples in mass-dominant TPs: \\t\\t {100 * len(df_mass) / len(df_TP)}\")\n",
"print(f\"Relative number of samples in wind-dominant TPs: \\t\\t{100 * len(df_wind) / len(df_TP)}\")\n",
"print(f\"Relative number of samples TPs without dominance: \\t\\t{100 * len(df_nodom) / len(df_TP)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"Number of samples being cloud and mass dominant at the same time:\\t\\t{len(df_cloud[['latitude']].join(df_mass[['longitude']], how='inner'))}\")\n",
"print(f\"Number of samples being cloud and wind dominant at the same time:\\t\\t{len(df_cloud[['latitude']].join(df_wind[['longitude']], how='inner'))}\")\n",
"print(f\"Number of samples being cloud, wind and mass dominant at the same time:\\t\\t{len(df_cloud[['latitude']].join(df_wind[['longitude']], how='inner').join(df_mass[['hour']], how='inner'))}\")\n",
"print(f\"Number of samples being mass and wind dominant at the same time:\\t\\t{len(df_mass[['latitude']].join(df_wind[['longitude']], how='inner'))}\")"
]
},
{
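The hunks above split the true positives by which variable group carries most of the SHAP attribution and add a bucket for samples without a dominant group. Below is a minimal, self-contained sketch of that logic; the toy data and the assumption that each `*_tp_sum` column holds a per-sample fraction of the total absolute SHAP value (so the three fractions sum to one) are mine, not taken from the repository.

```python
# Sketch of the dominance grouping, assuming cloud/mass/wind SHAP shares per sample.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
shares = rng.dirichlet([1.0, 1.0, 1.0], size=10)   # toy shares, each row sums to 1
df_TP = pd.DataFrame(shares, columns=["cloud_tp_sum", "mass_tp_sum", "wind_tp_sum"])

df_cloud = df_TP[df_TP["cloud_tp_sum"] > 0.5]      # cloud variables dominate
df_mass = df_TP[df_TP["mass_tp_sum"] > 0.5]        # mass variables dominate
df_wind = df_TP[df_TP["wind_tp_sum"] > 0.5]        # wind variables dominate
df_nodom = df_TP[(df_TP["cloud_tp_sum"] <= 0.5)
                 & (df_TP["mass_tp_sum"] <= 0.5)
                 & (df_TP["wind_tp_sum"] <= 0.5)]  # no group exceeds half

for name, part in [("cloud", df_cloud), ("mass", df_mass),
                   ("wind", df_wind), ("none", df_nodom)]:
    print(f"{name}-dominant TPs: {len(part)} ({100 * len(part) / len(df_TP):.1f}%)")
```

In this toy setup the three fractions sum to exactly one, so at most one group can exceed 0.5; the new overlap-counting cell above checks whether the same holds for the real SHAP sums.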
164 changes: 74 additions & 90 deletions analyse_shap_paper.ipynb
@@ -52,7 +52,8 @@
"outputs": [],
"source": [
"TARGET_MODE = 1\n",
"GRADIENT_EXPLAINER = False"
"GRADIENT_EXPLAINER = False\n",
"MODEL = \"2022_02_21__11-11__ALDIS_paper\""
]
},
{
@@ -69,7 +70,7 @@
"outputs": [],
"source": [
"tmsubdir = f'targetmode_{TARGET_MODE}'\n",
"model_root_tm_path = os.path.join(ccc.MODEL_ROOT_PATH, tmsubdir)"
"modelpath = os.path.join(ccc.MODEL_ROOT_PATH, f'targetmode_{TARGET_MODE}', MODEL)"
]
},
{
@@ -78,47 +79,14 @@
"metadata": {},
"outputs": [],
"source": [
"modeldirs = os.listdir(model_root_tm_path)\n",
"modeldirs.sort(reverse=True)\n",
"_, model_name = utils.load_model(os.path.join(f'targetmode_{TARGET_MODE}', MODEL), torch.device(\"cpu\"), \"18\")\n",
"\n",
"wmodel = widgets.Dropdown(\n",
" options=modeldirs,\n",
" value=modeldirs[0],\n",
" description='Choose a model:',\n",
")\n",
"\n",
"display(wmodel)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"modelpath = os.path.join(model_root_tm_path, wmodel.value)\n",
"shaps = os.listdir(modelpath)\n",
"shaps.sort()\n",
"\n",
"shaps = [x for x in shaps if x.endswith(\"test_scores.json\")]\n",
"prefix = \"_model_00018\"\n",
"\n",
"wshaps = widgets.Dropdown(\n",
" options=shaps,\n",
" value=shaps[0],\n",
" description='Choose a model:',\n",
")\n",
"shap_path = os.path.join(modelpath, prefix + \"_shap_parquet_bg_by_lon_lat_no_flash\")\n",
"df_path = os.path.join(modelpath, prefix + \"_test_df.pickle\")\n",
"\n",
"display(wshaps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(os.path.join(modelpath, wshaps.value), 'r') as f:\n",
" test_scores_json = json.load(f)"
"dd = pd.read_pickle(df_path)"
]
},
{
@@ -127,14 +95,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, model_name = utils.load_model(os.path.join(f'targetmode_{TARGET_MODE}', wmodel.value), torch.device(\"cpu\"), \"18\")\n",
"\n",
"prefix = \"_model_00018\"\n",
"\n",
"shap_path = os.path.join(modelpath, prefix + \"_shap_parquet_bg_by_lon_lat_no_flash\")\n",
"df_path = os.path.join(modelpath, prefix + \"_test_df.pickle\")\n",
"\n",
"dd = pd.read_pickle(df_path)"
"dd.columns"
]
},
{
@@ -143,7 +104,7 @@
"metadata": {},
"outputs": [],
"source": [
"dd.columns"
"used_threshold = 0.8708981871604919 # F1 score based threshold of epoch 18 on validation set"
]
},
{
@@ -152,10 +113,10 @@
"metadata": {},
"outputs": [],
"source": [
"vc_threshold = utils.getVeryConfidentThreshold(test_scores_json[\"used_threshold\"])\n",
"vc_threshold = utils.getVeryConfidentThreshold(used_threshold)\n",
"dd_transf = dd\n",
"\n",
"dd_transf.loc[:, \"pred_class\"] = np.where(dd_transf[\"output\"] > test_scores_json[\"used_threshold\"], \"pred_flash\", \"pred_no_flash\")\n",
"dd_transf.loc[:, \"pred_class\"] = np.where(dd_transf[\"output\"] > used_threshold, \"pred_flash\", \"pred_no_flash\")\n",
"dd_transf.loc[:, \"real_class\"] = np.where(dd_transf[\"target\"] > 0.5, \"real_flash\", \"real_no_flash\") # target col only contains 0s and 1s.\n",
"\n",
"dd_transf.loc[:, 'cat'] = np.select(\n",
@@ -225,7 +186,6 @@
"\n",
"print(\"Convert test data (excluding TNs) into spark df\", flush=True)\n",
"spark = utils.getsparksession()\n",
"#sparkdd = spark.createDataFrame(dd_transf.query(\"cat != 'TN'\"))\n",
"sparkdd = spark.createDataFrame(dd_transf)\n",
"\n",
"print(\"Join the two dfs\")\n",
@@ -281,10 +241,18 @@
"metadata": {},
"outputs": [],
"source": [
"size_of_smallest_cl = dd_enriched['cluster'].value_counts().min()\n",
"dd_enriched['cluster'].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sample_size = 61431 # dd_enriched['cluster'].value_counts().min()"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -318,7 +286,16 @@
"metadata": {},
"outputs": [],
"source": [
"df_many_cases_sampled = df_many_cases.groupby('cluster').sample(size_of_smallest_cl)"
"df_many_cases_sampled = df_many_cases.groupby('cluster')[df_many_cases.columns].apply(lambda x: x.sample(n=sample_size) if len(x) > sample_size else x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_many_cases_sampled['cluster'].value_counts()"
]
},
{
@@ -378,14 +355,14 @@
"outputs": [],
"source": [
"ptype = \"q50\" # can be mult, q50, q95\n",
"use_cache = False\n",
"use_cache = True\n",
"write_cache = False\n",
"\n",
"separate_clusters = False\n",
"\n",
"plot_group = \"vartype_grouped\" # \"confmat\": TP, FP, TN, FN; \"vartype\": CLOUD_HIGH, MASS_HIGH, WIND_HIGH; \"vartype_ext\": \"CLOUD_MASS_HIGH\", \"CLOUD_WIND_HIGH\", \"MASS_HIGH\", \"WIND_HIGH\"\n",
"\n",
"only_show_cols = [] # [\"ciwc\", \"cswc\"]\n",
"only_show_cols = [] # [\"ciwc\", \"cswc\"]\n",
"\n",
"y_axis = \"geopotential_altitude\" # level, geopotential_altitude\n",
"\n",
@@ -429,22 +406,24 @@
" 3: 'TP_WIND_HIGH',\n",
" 4: 'TN',\n",
" }\n",
"\n",
" cloud_sum = df_shap_to_plot['cswc_shapsum'] + df_shap_to_plot['ciwc_shapsum'] + df_shap_to_plot['crwc_shapsum'] + df_shap_to_plot['clwc_shapsum']\n",
" mass_sum = df_shap_to_plot['q_shapsum'] + df_shap_to_plot['t_shapsum']\n",
" wind_sum = df_shap_to_plot['u_shapsum'] + df_shap_to_plot['v_shapsum'] + df_shap_to_plot['w_shapsum']\n",
" \n",
" df_shap_to_plot.loc[:, 'cluster'] = np.select(\n",
" [\n",
" df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (df_shap_to_plot['cswc_shapsum'] + df_shap_to_plot['ciwc_shapsum'] + df_shap_to_plot['crwc_shapsum'] + df_shap_to_plot['clwc_shapsum'] > 0.5),\n",
" df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (df_shap_to_plot['q_shapsum'] + df_shap_to_plot['t_shapsum'] > 0.5),\n",
" df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (df_shap_to_plot['u_shapsum'] + df_shap_to_plot['v_shapsum'] + df_shap_to_plot['w_shapsum'] > 0.5),\n",
" df_shap_to_plot['cat'].isin(['TN']),\n",
" ], \n",
" [\n",
" 1, \n",
" 2,\n",
" 3,\n",
" 4,\n",
" ], \n",
" default=-1\n",
" )\n",
" df_tempshap_cloud = df_shap_to_plot[df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (cloud_sum > 0.5)]\n",
" df_tempshap_mass = df_shap_to_plot[df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (mass_sum > 0.5)]\n",
" df_tempshap_wind = df_shap_to_plot[df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (wind_sum > 0.5)]\n",
" df_tempshap_tn = df_shap_to_plot[df_shap_to_plot['cat'].isin(['TN'])]\n",
" \n",
" df_tempshap_tn = df_shap_to_plot[df_shap_to_plot['cat'].isin(['TN'])]\n",
"\n",
" df_tempshap_cloud.loc[:, \"cluster\"] = 1\n",
" df_tempshap_mass.loc[:, \"cluster\"] = 2\n",
" df_tempshap_wind.loc[:, \"cluster\"] = 3\n",
" df_tempshap_tn.loc[:, \"cluster\"] = 4\n",
"\n",
" df_shap_to_plot = pd.concat([df_tempshap_cloud, df_tempshap_mass, df_tempshap_wind, df_tempshap_tn], ignore_index=True)\n",
"elif plot_group == \"vartype_grouped\":\n",
" plot_clusters = {\n",
" 1: 'TP_CLOUD_MASS_HIGH',\n",
@@ -458,29 +437,25 @@
" mass_sum = df_shap_to_plot['q_shapsum'] + df_shap_to_plot['t_shapsum']\n",
" wind_sum = df_shap_to_plot['u_shapsum'] + df_shap_to_plot['v_shapsum'] + df_shap_to_plot['w_shapsum']\n",
" \n",
" df_shap_to_plot.loc[:, 'cluster'] = np.select(\n",
" [\n",
" df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (cloud_sum > 0.5) & (mass_sum > wind_sum),\n",
" df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (cloud_sum > 0.5) & (mass_sum <= wind_sum),\n",
" df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (mass_sum > 0.5),\n",
" df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (wind_sum > 0.5),\n",
" df_shap_to_plot['cat'].isin(['TN']),\n",
" ], \n",
" [\n",
" 1, \n",
" 2,\n",
" 3,\n",
" 4,\n",
" 5,\n",
" ], \n",
" default=-1\n",
" )\n",
" df_tempshap_cloudmass = df_shap_to_plot[df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (cloud_sum > 0.5) & (mass_sum > wind_sum)]\n",
" df_tempshap_cloudwind = df_shap_to_plot[df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (cloud_sum > 0.5) & (mass_sum <= wind_sum)]\n",
" df_tempshap_mass = df_shap_to_plot[df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (mass_sum > 0.5)]\n",
" df_tempshap_wind = df_shap_to_plot[df_shap_to_plot['cat'].isin(['TP_LC', 'TP_VC']) & (wind_sum > 0.5)]\n",
" df_tempshap_tn = df_shap_to_plot[df_shap_to_plot['cat'].isin(['TN'])]\n",
"\n",
" df_tempshap_cloudmass.loc[:, \"cluster\"] = 1\n",
" df_tempshap_cloudwind.loc[:, \"cluster\"] = 2\n",
" df_tempshap_mass.loc[:, \"cluster\"] = 3\n",
" df_tempshap_wind.loc[:, \"cluster\"] = 4\n",
" df_tempshap_tn.loc[:, \"cluster\"] = 5\n",
"\n",
" df_shap_to_plot = pd.concat([df_tempshap_cloudmass, df_tempshap_cloudwind, df_tempshap_mass, df_tempshap_wind, df_tempshap_tn], ignore_index=True)\n",
"elif plot_group == \"confmat\":\n",
" plot_clusters = {\n",
" 0: 'TP less confident',\n",
"# 1: 'TP very confident',\n",
"# 2: 'FN',\n",
" 3: 'FP',\n",
" 2: 'FN',\n",
"# 3: 'FP',\n",
" 4: 'TN',\n",
" }\n",
" \n",
Expand All @@ -505,6 +480,15 @@
"df_shap_to_plot[\"cluster\"].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_shap_to_plot"
]
},
{
"cell_type": "code",
"execution_count": null,
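The sampling change above replaces drawing `size_of_smallest_cl` rows from every cluster with a cap: clusters larger than `sample_size` are downsampled, smaller ones are kept whole. A small sketch of that pattern follows; the column names, cluster sizes, and cap value are illustrative only, and only the groupby/apply structure mirrors the notebook.

```python
# Sketch of the capped per-cluster sampling used before plotting.
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
df = pd.DataFrame({
    "cluster": rng.choice([1, 2, 3], size=1000, p=[0.6, 0.3, 0.1]),  # unequal clusters
    "value": rng.normal(size=1000),
})

sample_size = 200  # fixed cap, analogous to the hard-coded 61431 in the notebook
sampled = (
    df.groupby("cluster")[df.columns]
      .apply(lambda x: x.sample(n=sample_size) if len(x) > sample_size else x)
)

print(sampled["cluster"].value_counts())  # no cluster exceeds sample_size
```

Compared with downsampling every cluster to the size of the smallest one, the cap keeps more of the mid-sized clusters while still bounding the largest.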
2 changes: 1 addition & 1 deletion ccc.py
@@ -11,7 +11,7 @@

SEED_RANDOM = 1337 # Set to None to disable

DATA_ROOT_PATH = os.path.join('.', 'data_processed_archived')
DATA_ROOT_PATH = os.path.join('.', 'data', 'data_processed')
MODEL_ROOT_PATH = os.path.join('.', 'data', 'models')
DATASTATS_PATH = os.path.join(".", "data", "data_stats")
CACHE_PATH = os.path.join(".", "data", "cache")
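The constant changed here is consumed via `os.path.join` elsewhere in the repository, as the notebook diff above does with `ccc.MODEL_ROOT_PATH`. A tiny sketch of that usage; the target-mode subdirectory and model name are copied from the notebook diff, and whether these directories exist on a given checkout is not guaranteed.

```python
# Sketch of how the ccc.py path constants are typically consumed.
import os
import ccc

data_dir = ccc.DATA_ROOT_PATH  # './data/data_processed' after this commit
model_dir = os.path.join(ccc.MODEL_ROOT_PATH, "targetmode_1", "2022_02_21__11-11__ALDIS_paper")

print(data_dir)
print(model_dir)
```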