Skip to content

Commit

Permalink
🎨 allow custom subselection, add NA if not available
Browse files Browse the repository at this point in the history
- Figure 2: add custom selection of models to aggregate best 5 models
  of several datasets (custom plotting for paper)
- rotate performance label
- add NA if model did not run (here: error or not finished within 24h)
  • Loading branch information
Henry committed Nov 26, 2023
1 parent c9e00e4 commit 748a1d7
Show file tree
Hide file tree
Showing 4 changed files with 322 additions and 90 deletions.
207 changes: 163 additions & 44 deletions project/01_2_performance_plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"import random\n",
"from pathlib import Path\n",
"\n",
"from IPython.display import display\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
Expand Down Expand Up @@ -119,6 +120,7 @@
"# Machine parsed metadata from rawfile workflow\n",
"fn_rawfile_metadata: str = 'data/dev_datasets/HeLa_6070/files_selected_metadata_N50.csv'\n",
"models: str = 'Median,CF,DAE,VAE' # picked models to compare (comma separated)\n",
"sel_models: str = '' # user defined comparison (comma separated)\n",
"# Restrict plotting to top N methods for imputation based on error of validation data, maximum 10\n",
"plot_to_n: int = 5"
]
Expand Down Expand Up @@ -184,7 +186,10 @@
"METRIC = 'MAE'\n",
"MIN_FREQ = None\n",
"MODELS_PASSED = args.models.split(',')\n",
"MODELS = MODELS_PASSED.copy()"
"MODELS = MODELS_PASSED.copy()\n",
"SEL_MODELS = None\n",
"if args.sel_models:\n",
" SEL_MODELS = args.sel_models.split(',')"
]
},
{
Expand Down Expand Up @@ -243,7 +248,7 @@
"id": "ffc6d140-f48e-4477-84f3-47a196e0a3d8",
"metadata": {},
"source": [
"## Across data completeness"
"## Data completeness across the entire dataset"
]
},
{
Expand All @@ -258,7 +263,6 @@
"# load frequency of training features...\n",
"# needs to be pickle -> index.name needed\n",
"freq_feat = vaep.io.datasplits.load_freq(args.data, file='freq_features.json')\n",
"\n",
"freq_feat.head() # training data"
]
},
Expand All @@ -272,7 +276,15 @@
"outputs": [],
"source": [
"prop = freq_feat / len(data.train_X.index.levels[0])\n",
"prop.to_frame()"
"prop.sort_values().to_frame().plot()"
]
},
{
"cell_type": "markdown",
"id": "19e5adfb",
"metadata": {},
"source": [
"View training data in wide format"
]
},
{
Expand All @@ -288,6 +300,14 @@
"data.train_X"
]
},
{
"cell_type": "markdown",
"id": "21102a1d",
"metadata": {},
"source": [
"Number of samples and features:"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -301,6 +321,14 @@
"print(f\"N samples: {N_SAMPLES:,d}, M features: {M_FEAT}\")"
]
},
{
"cell_type": "markdown",
"id": "61186a4e",
"metadata": {},
"source": [
"Collect outputs in excel file:"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -312,15 +340,16 @@
"source": [
"fname = args.folder_experiment / '01_2_performance_summary.xlsx'\n",
"dumps[fname.stem] = fname\n",
"writer = pd.ExcelWriter(fname)"
"writer = pd.ExcelWriter(fname)\n",
"print(f\"Saving to: {fname}\")"
]
},
{
"cell_type": "markdown",
"id": "bbe028c4-190d-4d50-b8a7-d109817d7b98",
"metadata": {},
"source": [
"# Model specifications\n",
"## Model specifications\n",
"- used for bar plot annotations"
]
},
Expand Down Expand Up @@ -365,19 +394,8 @@
"outputs": [],
"source": [
"# index name\n",
"freq_feat.index.name = data.train_X.columns.name"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8088a91f-6aaa-4b9d-b855-332d2bbf5780",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# index name\n",
"freq_feat.index.name = data.train_X.columns.name\n",
"# sample index name\n",
"sample_index_name = data.train_X.index.name"
]
},
Expand Down Expand Up @@ -446,7 +464,7 @@
"lines_to_next_cell": 0
},
"source": [
"## Select top N for plotting and set colors"
"### Select top N for plotting and set colors"
]
},
{
Expand Down Expand Up @@ -478,18 +496,20 @@
"source": [
"mae_stats_ordered_val = errors_val.abs().describe()[ORDER_MODELS]\n",
"mae_stats_ordered_val.to_excel(writer, sheet_name='mae_stats_ordered_val', float_format='%.5f')\n",
"mae_stats_ordered_val"
"mae_stats_ordered_val.T"
]
},
{
"cell_type": "markdown",
"id": "f5b33f93",
"metadata": {
"lines_to_next_cell": 0
},
"metadata": {},
"source": [
"Hack color order, by assing CF, DAE and VAE unique colors no matter their order\n",
"Could be extended to all supported imputation methods"
"Some models have fixed colors, others are assigned randomly\n",
"\n",
"> Note\n",
">\n",
"> 1. The order of \"new\" models is important for the color assignment.\n",
"> 2. User-defined model keys for the same model with two configurations will yield different colors."
]
},
{
Expand All @@ -514,12 +534,9 @@
},
"outputs": [],
"source": [
"# For top_N -> define colors\n",
"TOP_N_ORDER = ORDER_MODELS[:args.plot_to_n]\n",
"\n",
"TOP_N_COLOR_PALETTE = {model: color for model,\n",
" color in zip(TOP_N_ORDER, COLORS_TO_USE)}\n",
"\n",
"TOP_N_ORDER"
]
},
Expand Down Expand Up @@ -678,7 +695,7 @@
},
"outputs": [],
"source": [
"errors_val.describe() # mean of means"
"errors_val.describe()[ORDER_MODELS].T # mean of means"
]
},
{
Expand All @@ -692,7 +709,7 @@
"outputs": [],
"source": [
"c_avg_error = 2\n",
"mask = (errors_val[MODELS] >= c_avg_error).any(axis=1)\n",
"mask = (errors_val[TOP_N_ORDER] >= c_avg_error).any(axis=1)\n",
"errors_val.loc[mask]"
]
},
Expand All @@ -715,15 +732,16 @@
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(8, 3))\n",
"ax, errors_binned = vaep.plotting.errors.plot_errors_binned(\n",
"ax, errors_binned = vaep.plotting.errors.plot_errors_by_median(\n",
" pred_val[\n",
" [TARGET_COL] + TOP_N_ORDER\n",
" ],\n",
" feat_medians=data.train_X.median(),\n",
" ax=ax,\n",
" palette=TOP_N_COLOR_PALETTE,\n",
" metric_name=METRIC,)\n",
"ax.set_ylabel(f\"Average error ({METRIC})\")\n",
"fname = args.out_figures / f'2_{group}_errors_binned_by_int_val.pdf'\n",
"fname = args.out_figures / f'2_{group}_errors_binned_by_feat_median_val.pdf'\n",
"figures[fname.stem] = fname\n",
"vaep.savefig(ax.get_figure(), name=fname)"
]
Expand Down Expand Up @@ -845,7 +863,7 @@
"lines_to_next_cell": 0
},
"source": [
"## Intensity distribution as histogram\n",
"### Intensity distribution as histogram\n",
"Plot top 4 models predictions for intensities in test data"
]
},
Expand Down Expand Up @@ -880,8 +898,8 @@
" ax=ax,\n",
" alpha=0.5,\n",
" )\n",
" _ = [(l.set_rotation(90))\n",
" for l in ax.get_xticklabels()]\n",
" _ = [(l_.set_rotation(90))\n",
" for l_ in ax.get_xticklabels()]\n",
" ax.legend()\n",
"\n",
"axes[0].set_ylabel('Number of observations')\n",
Expand Down Expand Up @@ -1217,7 +1235,7 @@
" build_text,\n",
" axis=1)\n",
"except KeyError:\n",
" logger.warning(\"No model PIMMS models in comparsion. Using empty text\")\n",
" logger.warning(\"No PIMMS models in comparsion. Using empty text\")\n",
" text = pd.Series('', index=model_configs.columns)\n",
"\n",
"_to_plot.loc[\"text\"] = text\n",
Expand All @@ -1235,12 +1253,13 @@
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(4, 2))\n",
"ax = _to_plot.loc[[feature_names.name]].plot.bar(rot=0,\n",
" ylabel=f\"{METRIC} for {feature_names.name} ({n_in_comparison:,} intensities)\",\n",
" # title=f'performance on test data (based on {n_in_comparison:,} measurements)',\n",
" color=COLORS_TO_USE,\n",
" ax=ax,\n",
" width=.8)\n",
"ax = _to_plot.loc[[feature_names.name]].plot.bar(\n",
" rot=0,\n",
" ylabel=f\"{METRIC} for {feature_names.name} ({n_in_comparison:,} intensities)\",\n",
" # title=f'performance on test data (based on {n_in_comparison:,} measurements)',\n",
" color=COLORS_TO_USE,\n",
" ax=ax,\n",
" width=.8)\n",
"ax = vaep.plotting.add_height_to_barplot(ax, size=5)\n",
"ax = vaep.plotting.add_text_to_barplot(ax, _to_plot.loc[\"text\"], size=5)\n",
"ax.set_xticklabels([])\n",
Expand Down Expand Up @@ -1273,7 +1292,7 @@
"id": "d88c21c7",
"metadata": {},
"source": [
"Plot error by median feature intensity"
"### Plot error by median feature intensity"
]
},
{
Expand Down Expand Up @@ -1306,6 +1325,106 @@
"errors_binned"
]
},
{
"cell_type": "markdown",
"id": "26370a1a",
"metadata": {},
"source": [
"### Custom model selection"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "712faf9a",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"if SEL_MODELS:\n",
" metrics = vaep.models.Metrics()\n",
" test_metrics = metrics.add_metrics(\n",
" pred_test[['observed', *SEL_MODELS]], key='test data')\n",
" test_metrics = pd.DataFrame(test_metrics)[SEL_MODELS]\n",
" test_metrics\n",
"\n",
" n_in_comparison = int(test_metrics.loc['N'].unique()[0])\n",
" n_in_comparison\n",
"\n",
" _to_plot = test_metrics.loc[METRIC].to_frame().T\n",
" _to_plot.index = [feature_names.name]\n",
" _to_plot\n",
"\n",
" try:\n",
" text = model_configs[[\"latent_dim\", \"hidden_layers\"]].apply(\n",
" build_text,\n",
" axis=1)\n",
" except KeyError:\n",
" logger.warning(\"No PIMMS models in comparsion. Using empty text\")\n",
" text = pd.Series('', index=model_configs.columns)\n",
"\n",
" _to_plot.loc[\"text\"] = text\n",
" _to_plot = _to_plot.fillna('')\n",
" _to_plot\n",
"\n",
" fig, ax = plt.subplots(figsize=(4, 2))\n",
" ax = _to_plot.loc[[feature_names.name]].plot.bar(\n",
" rot=0,\n",
" ylabel=f\"{METRIC} for {feature_names.name} ({n_in_comparison:,} intensities)\",\n",
" # title=f'performance on test data (based on {n_in_comparison:,} measurements)',\n",
" color=COLORS_TO_USE,\n",
" ax=ax,\n",
" width=.8)\n",
" ax = vaep.plotting.add_height_to_barplot(ax, size=5)\n",
" ax = vaep.plotting.add_text_to_barplot(ax, _to_plot.loc[\"text\"], size=5)\n",
" ax.set_xticklabels([])\n",
" fname = args.out_figures / f'2_{group}_performance_test_sel.pdf'\n",
" figures[fname.stem] = fname\n",
" vaep.savefig(fig, name=fname)\n",
"\n",
" dumps[fname.stem] = fname.with_suffix('.csv')\n",
" _to_plot_long = _to_plot.T\n",
" _to_plot_long = _to_plot_long.rename(\n",
" {feature_names.name: 'metric_value'}, axis=1)\n",
" _to_plot_long['data level'] = feature_names.name\n",
" _to_plot_long = _to_plot_long.set_index('data level', append=True)\n",
" _to_plot_long.to_csv(fname.with_suffix('.csv'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a578570",
"metadata": {},
"outputs": [],
"source": [
"# custom selection\n",
"if SEL_MODELS:\n",
" vaep.plotting.make_large_descriptors(6)\n",
" fig, ax = plt.subplots(figsize=(8, 2))\n",
"\n",
" ax, errors_binned = vaep.plotting.errors.plot_errors_by_median(\n",
" pred=pred_test[\n",
" [TARGET_COL] + SEL_MODELS\n",
" ],\n",
" feat_medians=data.train_X.median(),\n",
" ax=ax,\n",
" metric_name=METRIC,\n",
" palette=COLORS_TO_USE\n",
" )\n",
" ax.set_ylim(0, 1.5)\n",
" # for text in ax.legend().get_texts():\n",
" # text.set_fontsize(6)\n",
" fname = args.out_figures / f'2_{group}_test_errors_binned_by_feat_medians_sel.pdf'\n",
" figures[fname.stem] = fname\n",
" vaep.savefig(ax.get_figure(), name=fname)\n",
" # vaep.plotting.make_large_descriptors(6)\n",
" dumps[fname.stem] = fname.with_suffix('.csv')\n",
" errors_binned.to_csv(fname.with_suffix('.csv'))\n",
" display(errors_binned)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
Loading

0 comments on commit 748a1d7

Please sign in to comment.