Skip to content

Commit

Permalink
shap warning msg bug
Browse files: browse the repository at this point in the history
  • Loading branch information
sronilsson committed May 8, 2024
1 parent 18f5e1e commit 25b46e6
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 65 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

setuptools.setup(
name="Simba-UW-tf-dev",
version="1.91.6",
version="1.91.7",
author="Simon Nilsson, Jia Jie Choong, Sophia Hwang",
author_email="[email protected]",
description="Toolkit for computer classification of behaviors in experimental animals",
Expand Down
37 changes: 16 additions & 21 deletions simba/mixins/train_model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,20 +833,18 @@ def split_and_group_df(
obs_per_split = len(data_arr[0])
return data_arr, obs_per_split

def create_shap_log(
self,
ini_file_path: str,
rf_clf: RandomForestClassifier,
x_df: pd.DataFrame,
y_df: pd.Series,
x_names: List[str],
clf_name: str,
cnt_present: int,
cnt_absent: int,
save_it: int = 100,
save_path: Optional[Union[str, os.PathLike]] = None,
save_file_no: Optional[int] = None,
) -> Union[None, Tuple[pd.DataFrame]]:
def create_shap_log(self,
ini_file_path: str,
rf_clf: RandomForestClassifier,
x_df: pd.DataFrame,
y_df: pd.Series,
x_names: List[str],
clf_name: str,
cnt_present: int,
cnt_absent: int,
save_it: int = 100,
save_path: Optional[Union[str, os.PathLike]] = None,
save_file_no: Optional[int] = None) -> Union[None, Tuple[pd.DataFrame]]:
"""
Compute SHAP values for a random forest classifier.
Expand Down Expand Up @@ -1867,12 +1865,8 @@ def create_shap_log_mp(
batch_size = 1
if len(shap_data_df) > 100:
batch_size = 100
print(
f"Computing {len(shap_data_df)} SHAP values (MULTI-CORE BATCH SIZE: {batch_size}, FOLLOW PROGRESS IN OS TERMINAL)..."
)
shap_data, _ = self.split_and_group_df(
df=shap_data_df, splits=int(len(shap_data_df) / batch_size)
)
print(f"Computing {len(shap_data_df)} SHAP values (MULTI-CORE BATCH SIZE: {batch_size}, FOLLOW PROGRESS IN OS TERMINAL)...")
shap_data, _ = self.split_and_group_df(df=shap_data_df, splits=int(len(shap_data_df) / batch_size))
shap_results, shap_raw = [], []
try:
with multiprocessing.Pool(cores, maxtasksperchild=10) as pool:
Expand Down Expand Up @@ -1930,7 +1924,8 @@ def create_shap_log_mp(
else:
return (shap_save_df, raw_save_df, int(expected_value * 100))

except:
except Exception as e:
print(e.args)
ShapWarning(
msg="Multiprocessing SHAP values failed. Revert to single core. This will negatively affect run-time. ",
source=self.__class__.__name__,
Expand Down
59 changes: 28 additions & 31 deletions simba/model/grid_search_rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ class GridSearchRandomForestClassifier(ConfigReader, TrainModelMixin):
:param str config_path: path to SimBA project config file in Configparser format
Example
----------
:example:
>>> _ = GridSearchRandomForestClassifier(config_path='MyConfigPath').run()
"""

Expand Down Expand Up @@ -295,37 +294,30 @@ def run(self):
if MLParamKeys.SHAP_MULTIPROCESS.value in meta_dict.keys():
shap_multiprocess = meta_dict[MLParamKeys.SHAP_MULTIPROCESS.value]

if (
meta_dict[MLParamKeys.SHAP_SCORES.value]
in Options.PERFORM_FLAGS.value
):
if (meta_dict[MLParamKeys.SHAP_SCORES.value] in Options.PERFORM_FLAGS.value):
if not shap_multiprocess in Options.PERFORM_FLAGS.value:
self.create_shap_log(
ini_file_path=self.config_path,
rf_clf=self.rf_clf,
x_df=self.x_train,
y_df=self.y_train,
x_names=self.feature_names,
clf_name=self.clf_name,
cnt_present=meta_dict[MLParamKeys.SHAP_PRESENT.value],
cnt_absent=meta_dict[MLParamKeys.SHAP_ABSENT.value],
save_path=self.model_dir_out,
save_it=save_n,
save_file_no=config_cnt,
)
self.create_shap_log(ini_file_path=self.config_path,
rf_clf=self.rf_clf,
x_df=self.x_train,
y_df=self.y_train,
x_names=self.feature_names,
clf_name=self.clf_name,
cnt_present=meta_dict[MLParamKeys.SHAP_PRESENT.value],
cnt_absent=meta_dict[MLParamKeys.SHAP_ABSENT.value],
save_path=self.model_dir_out,
save_it=save_n,
save_file_no=config_cnt)
else:
self.create_shap_log_mp(
ini_file_path=self.config_path,
rf_clf=self.rf_clf,
x_df=self.x_train,
y_df=self.y_train,
x_names=self.feature_names,
clf_name=self.clf_name,
cnt_present=meta_dict[MLParamKeys.SHAP_PRESENT.value],
cnt_absent=meta_dict[MLParamKeys.SHAP_ABSENT.value],
save_path=self.model_dir_out,
save_file_no=config_cnt,
)
self.create_shap_log_mp(ini_file_path=self.config_path,
rf_clf=self.rf_clf,
x_df=self.x_train,
y_df=self.y_train,
x_names=self.feature_names,
clf_name=self.clf_name,
cnt_present=meta_dict[MLParamKeys.SHAP_PRESENT.value],
cnt_absent=meta_dict[MLParamKeys.SHAP_ABSENT.value],
save_path=self.model_dir_out,
save_file_no=config_cnt)

if MLParamKeys.PARTIAL_DEPENDENCY.value in meta_dict.keys():
if (
Expand Down Expand Up @@ -355,6 +347,11 @@ def run(self):
)




# test = GridSearchRandomForestClassifier(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/open_field_below/project_folder/project_config.ini')
# test.run()

# test = GridSearchRandomForestClassifier(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini')
# test.run()

Expand Down
20 changes: 8 additions & 12 deletions simba/plotting/shap_agg_stats_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,13 @@ class ShapAggregateStatisticsVisualizer(ConfigReader):
>>> _ = ShapAggregateStatisticsVisualizer(config_path='SimBAConfigFilePath', classifier_name='Attack', shap_df='tests/test_data/test_shap/data/test_shap.csv', shap_baseline_value=4, save_path='SaveDirectory')
"""

def __init__(
self,
config_path: Union[str, os.PathLike],
shap_df: pd.DataFrame,
classifier_name: str,
shap_baseline_value: int,
visualization: Optional[bool] = True,
save_path: Optional[Union[str, os.PathLike]] = None,
):
def __init__(self,
config_path: Union[str, os.PathLike],
shap_df: pd.DataFrame,
classifier_name: str,
shap_baseline_value: int,
visualization: Optional[bool] = True,
save_path: Optional[Union[str, os.PathLike]] = None):

check_file_exist_and_readable(file_path=config_path)
check_instance(
Expand All @@ -77,9 +75,7 @@ def __init__(

ConfigReader.__init__(self, config_path=config_path)
if (self.pose_setting != "14") and (self.pose_setting != "16"):
ShapWarning(
"SHAP visualizations/aggregate stats skipped (only viable for projects with two animals and default 7 or 8 body-parts per animal) ..."
)
ShapWarning(msg="SHAP visualizations/aggregate stats skipped (only viable for projects with two animals and default 7 or 8 body-parts per animal) ...", source=self.__class__.__name__)
else:
self.classifier_name, self.shap_df, self.shap_baseline_value = (
classifier_name,
Expand Down

0 comments on commit 25b46e6

Please sign in to comment.