-
Notifications
You must be signed in to change notification settings - Fork 146
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
071f9b6
commit b9aebfa
Showing
31 changed files
with
1,560 additions
and
500 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "abb4b6b3", | ||
"metadata": {}, | ||
"source": [ | ||
"# Shapley calculations: Example II (MULTI-CORE)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0ee78a39", | ||
"metadata": {}, | ||
"source": [ | ||
"In this example, we have previously created a classifier. We have the data used to create this classifier, and now we want to compute SHAP explainability scores for this classifier using multiprocessing (to speed things up)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1b35abe6", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from simba.mixins.train_model_mixin import TrainModelMixin\n", | ||
"from simba.mixins.config_reader import ConfigReader\n", | ||
"from simba.utils.read_write import read_df, read_config_file\n", | ||
"import glob" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "55fd00d6", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# DEFINITIONS\n", | ||
"CONFIG_PATH = r\"C:\\troubleshooting\\mitra\\project_folder\\project_config.ini\"\n", | ||
"CLASSIFIER_PATH = r\"C:\\troubleshooting\\mitra\\models\\generated_models\\grooming.sav\"\n", | ||
"CLASSIFIER_NAME = 'grooming'\n", | ||
"COUNT_PRESENT = 250\n", | ||
"COUNT_ABSENT = 250" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b4c29a21", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# READ IN THE CONFIG AND THE CLASSIFIER\n", | ||
"config = read_config_file(config_path=CONFIG_PATH)\n", | ||
"config_object = ConfigReader(config_path=CONFIG_PATH)\n", | ||
"clf = read_df(file_path=CLASSIFIER_PATH, file_type='pickle')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e1381c40", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# READ IN THE DATA \n", | ||
"\n", | ||
"#Read in the path to all files inside the project_folder/csv/targets_inserted directory\n", | ||
"file_paths = glob.glob(config_object.targets_folder + '/*' + config_object.file_type)\n", | ||
"\n", | ||
"#Reads in the data held in all files in ``file_paths`` defined above\n", | ||
"data, _ = TrainModelMixin().read_all_files_in_folder_mp(file_paths=file_paths, file_type=config.get('General settings', 'workflow_file_type').strip())\n", | ||
"\n", | ||
"#We find all behavior annotations that are NOT the targets. I.e., if SHAP values for Attack is going to be calculated, bit we need to find which other annotations exist in the data e.g., Escape and Defensive.\n", | ||
"non_target_annotations = TrainModelMixin().read_in_all_model_names_to_remove(config=config, model_cnt=config_object.clf_cnt, clf_name=CLASSIFIER_NAME)\n", | ||
"\n", | ||
"# We remove the body-part coordinate columns and the annotations which are not the target from the data \n", | ||
"data = data.reset_index(drop=True).drop(non_target_annotations + config_object.bp_headers, axis=1)\n", | ||
"\n", | ||
"# We place the target data in its own variable\n", | ||
"target_df = data.pop(CLASSIFIER_NAME)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"id": "e978462c", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Computing 20 SHAP values (MULTI-CORE BATCH SIZE: 10, FOLLOW PROGRESS IN OS TERMINAL)...\n", | ||
"Concatenating multi-processed SHAP data (batch 1/2)\n", | ||
"Concatenating multi-processed SHAP data (batch 2/2)\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"[Parallel(n_jobs=32)]: Using backend ThreadingBackend with 32 concurrent workers.\n", | ||
"[Parallel(n_jobs=32)]: Done 136 tasks | elapsed: 0.0s\n", | ||
"[Parallel(n_jobs=32)]: Done 386 tasks | elapsed: 0.0s\n", | ||
"[Parallel(n_jobs=32)]: Done 500 out of 500 | elapsed: 0.0s finished\n", | ||
"[Parallel(n_jobs=32)]: Using backend ThreadingBackend with 32 concurrent workers.\n", | ||
"[Parallel(n_jobs=32)]: Done 136 tasks | elapsed: 0.0s\n", | ||
"[Parallel(n_jobs=32)]: Done 386 tasks | elapsed: 0.0s\n", | ||
"[Parallel(n_jobs=32)]: Done 500 out of 500 | elapsed: 0.0s finished\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"SIMBA COMPLETE: SHAP calculations complete (elapsed time: 18.611s) \tcomplete\n", | ||
"SIMBA WARNING: ShapWarning: SHAP visualizations/aggregate stats skipped (only viable for projects with two animals and default 7 or 8 body-parts per animal) ... \twarning\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# We define a SHAP computer intance using the data created and defined in the prior two cells.\n", | ||
"TrainModelMixin().create_shap_log_mp(ini_file_path=CONFIG_PATH,\n", | ||
" rf_clf=clf,\n", | ||
" x_df=data,\n", | ||
" y_df=target_df,\n", | ||
" x_names=data.columns,\n", | ||
" clf_name=CLASSIFIER_NAME,\n", | ||
" cnt_present=COUNT_PRESENT,\n", | ||
" cnt_absent=COUNT_ABSENT,\n", | ||
" save_path=config_object.logs_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "11e58a5b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "simba", | ||
"language": "python", | ||
"name": "simba" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "486d83d7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"shap_example_2" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "simba", | ||
"language": "python", | ||
"name": "simba" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.