Skip to content

Commit

Permalink
shaps
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Dec 26, 2024
1 parent 0149893 commit 8feadad
Show file tree
Hide file tree
Showing 9 changed files with 469 additions and 81 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ build/
*.whl

# dev
simba/sandbox/pose_estimation
simba/sandbox/

#local docs build
_build/html/
Expand Down
12 changes: 3 additions & 9 deletions simba/labelling/labelling_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,12 @@
import simba
from simba.mixins.config_reader import ConfigReader
from simba.ui.tkinter_functions import Entry_Box
from simba.utils.checks import (check_file_exist_and_readable, check_float,
check_int, check_str, check_that_column_exist,
check_valid_boolean, check_valid_dataframe,
check_valid_dict)
from simba.utils.checks import (check_file_exist_and_readable, check_float, check_int, check_str, check_that_column_exist, check_valid_boolean, check_valid_dataframe, check_valid_dict)
from simba.utils.enums import Options, TagNames
from simba.utils.errors import FrameRangeError, NoDataError, NoFilesFoundError
from simba.utils.lookups import (get_labelling_img_kbd_bindings,
get_labelling_video_kbd_bindings)
from simba.utils.lookups import (get_labelling_img_kbd_bindings, get_labelling_video_kbd_bindings)
from simba.utils.printing import log_event, stdout_success
from simba.utils.read_write import (get_all_clf_names, get_fn_ext,
get_video_meta_data, read_config_entry,
read_df, read_frm_of_video, write_df)
from simba.utils.read_write import (get_all_clf_names, get_fn_ext, get_video_meta_data, read_config_entry, read_df, read_frm_of_video, write_df)
from simba.utils.warnings import FrameRangeWarning

PLAY_VIDEO_SCRIPT_PATH = os.path.join(os.path.dirname(simba.__file__), "labelling/play_annotation_video.py")
Expand Down
23 changes: 8 additions & 15 deletions simba/mixins/train_model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ def create_shap_log(self,
plot: bool = True,
save_it: Optional[int] = 100,
save_dir: Optional[Union[str, os.PathLike]] = None,
save_file_suffix: Optional[int] = None) -> Union[None, Tuple[pd.DataFrame]]:
save_file_suffix: Optional[int] = None) -> Union[None, Tuple[pd.DataFrame, pd.DataFrame, Dict[str, pd.DataFrame], np.ndarray]]:

"""
Compute SHAP values for a random forest classifier.
Expand Down Expand Up @@ -1675,9 +1675,7 @@ def read_all_files_in_folder_mp_futures(self,
raise_bool_clf_error=raise_bool_clf_error,
)

def check_raw_dataset_integrity(
self, df: pd.DataFrame, logs_path: Optional[Union[str, os.PathLike]]
) -> None:
def check_raw_dataset_integrity(self, df: pd.DataFrame, logs_path: Optional[Union[str, os.PathLike]]) -> None:
"""
Helper to check column-wise NaNs in raw input data for fitting model.
Expand Down Expand Up @@ -1744,7 +1742,7 @@ def create_shap_log_mp(self,
verbose: bool = True,
save_dir: Optional[Union[str, os.PathLike]] = None,
save_file_suffix: Optional[int] = None,
plot: bool = False) -> Union[None, Tuple[pd.DataFrame]]:
plot: bool = False) -> Union[None, Tuple[pd.DataFrame, pd.DataFrame, Dict[str, pd.DataFrame], np.ndarray]]:
"""
Compute SHAP values using multiprocessing.
Expand Down Expand Up @@ -1814,29 +1812,24 @@ def create_shap_log_mp(self,
if len(absent_df) < cnt_absent:
NotEnoughDataWarning(msg=f"Train data contains {len(absent_df)} behavior-absent annotations. This is less the number of frames you specified to calculate shap values for ({str(cnt_absent)}). SimBA will calculate shap scores for the {len(absent_df)} behavior-absent frames available", source=TrainModelMixin.create_shap_log_mp.__name__)
cnt_absent = len(absent_df)
shap_data = pd.concat(
[present_df.sample(cnt_present, replace=False), absent_df.sample(cnt_absent, replace=False)],
axis=0).reset_index(drop=True)
shap_data = pd.concat([present_df.sample(cnt_present, replace=False), absent_df.sample(cnt_absent, replace=False)], axis=0).reset_index(drop=True)
batch_cnt = max(1, int(np.ceil(len(shap_data) / chunk_size)))
shap_data = np.array_split(shap_data, batch_cnt)
shap_data = [(x, y) for x, y in enumerate(shap_data)]
explainer = TrainModelMixin().define_tree_explainer(clf=rf_clf)
expected_value = explainer.expected_value[1]
shap_results, shap_raw = [], []
print(
f"Computing {cnt_present + cnt_absent} SHAP values. Follow progress in OS terminal... (CORES: {core_cnt}, CHUNK SIZE: {chunk_size})")
print(f"Computing {cnt_present + cnt_absent} SHAP values. Follow progress in OS terminal... (CORES: {core_cnt}, CHUNK SIZE: {chunk_size})")
with multiprocessing.Pool(core_cnt, maxtasksperchild=Defaults.MAXIMUM_MAX_TASK_PER_CHILD.value) as pool:
constants = functools.partial(TrainModelMixin._create_shap_mp_helper, explainer=explainer, clf_name=clf_name, verbose=verbose)
for cnt, result in enumerate(pool.imap_unordered(constants, shap_data, chunksize=1)):
proba = TrainModelMixin().clf_predict_proba(clf=rf_clf,
x_df=shap_data[result[1]][1].drop(clf_name, axis=1),
model_name=clf_name).reshape(-1, 1)
proba = TrainModelMixin().clf_predict_proba(clf=rf_clf, x_df=shap_data[result[1]][1].drop(clf_name, axis=1), model_name=clf_name).reshape(-1, 1)
shap_sum = np.sum(result[0], axis=1).reshape(-1, 1)
batch_shap_results = np.hstack((result[0], np.full((result[0].shape[0]), expected_value).reshape(-1, 1), shap_sum, proba, shap_data[result[1]][1][clf_name].values.reshape(-1, 1))).astype(np.float32)
batch_shap_results = np.hstack((result[0], np.full((result[0].shape[0]), expected_value).reshape(-1, 1), shap_sum + expected_value, proba, shap_data[result[1]][1][clf_name].values.reshape(-1, 1))).astype(np.float32)
shap_results.append(batch_shap_results)
shap_raw.append(shap_data[result[1]][1].drop(clf_name, axis=1))
if verbose:
print(f"Completed SHAP data (Batch {result[1] + 1}/{len(shap_data)}).")
print(f"Completed SHAP batch (Batch {result[1] + 1}/{len(shap_data)}).")

pool.terminate(); pool.join()
shap_df = pd.DataFrame(data=np.row_stack(shap_results), columns=list(x_names) + ["Expected_value", "Sum", "Prediction_probability", clf_name])
Expand Down
26 changes: 12 additions & 14 deletions simba/model/grid_search_multiclass_rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,24 +271,22 @@ def run(self):

if (meta_dict[MLParamKeys.SHAP_SCORES.value] in Options.PERFORM_FLAGS.value):
if not shap_multiprocess in Options.PERFORM_FLAGS.value:
self.create_shap_log(
ini_file_path=self.config_path,
rf_clf=self.rf_clf,
x_df=self.x_train,
y_df=self.y_train,
x_names=self.feature_names,
clf_name=self.clf_name,
cnt_present=meta_dict[MLParamKeys.SHAP_PRESENT.value],
cnt_absent=meta_dict[MLParamKeys.SHAP_ABSENT.value],
save_path=self.model_dir_out,
save_it=save_n,
save_file_no=config_cnt,
)
self.create_shap_log(rf_clf=self.rf_clf,
x=self.x_train,
y=self.y_train,
x_names=list(self.feature_names),
clf_name=self.clf_name,
cnt_present=meta_dict[MLParamKeys.SHAP_PRESENT.value],
cnt_absent=meta_dict[MLParamKeys.SHAP_ABSENT.value],
verbose=True,
plot=shap_plot,
save_it=save_n,
save_dir=self.model_dir_out)
else:
self.create_shap_log_mp(rf_clf=self.rf_clf,
x=self.x_train,
y=self.y_train,
x_names=self.feature_names,
x_names=list(self.feature_names),
clf_name=self.clf_name,
cnt_present=meta_dict[MLParamKeys.SHAP_PRESENT.value],
cnt_absent=meta_dict[MLParamKeys.SHAP_ABSENT.value],
Expand Down
20 changes: 10 additions & 10 deletions simba/model/grid_search_rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,22 +125,22 @@ def run(self):
shap_multiprocess = meta_dict[MLParamKeys.SHAP_MULTIPROCESS.value]
if (meta_dict[MLParamKeys.SHAP_SCORES.value] in Options.PERFORM_FLAGS.value):
if not shap_multiprocess in Options.PERFORM_FLAGS.value:
self.create_shap_log(ini_file_path=self.config_path,
rf_clf=self.rf_clf,
x_df=self.x_train,
y_df=self.y_train,
x_names=self.feature_names,
self.create_shap_log(rf_clf=self.rf_clf,
x=self.x_train,
y=self.y_train,
x_names=list(self.feature_names),
clf_name=self.clf_name,
cnt_present=meta_dict[MLParamKeys.SHAP_PRESENT.value],
cnt_absent=meta_dict[MLParamKeys.SHAP_ABSENT.value],
save_path=self.model_dir_out,
verbose=True,
plot=shap_plot,
save_it=save_n,
save_file_no=config_cnt)
save_dir=self.model_dir_out)
else:
self.create_shap_log_mp(rf_clf=self.rf_clf,
x=self.x_train,
y=self.y_train,
x_names=self.feature_names,
x_names=list(self.feature_names),
clf_name=self.clf_name,
cnt_present=meta_dict[MLParamKeys.SHAP_PRESENT.value],
cnt_absent=meta_dict[MLParamKeys.SHAP_ABSENT.value],
Expand All @@ -159,8 +159,8 @@ def run(self):


#
# test = GridSearchRandomForestClassifier(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
# test.run()
# Developer smoke test. Guarded so that it runs only when this file is executed
# directly as a script; without the guard, merely importing this module would
# start a grid search against the hard-coded local project path below.
if __name__ == "__main__":
    test = GridSearchRandomForestClassifier(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
    test.run()

#
# test = GridSearchRandomForestClassifier(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
Expand Down
26 changes: 13 additions & 13 deletions simba/model/train_multiclass_rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,23 +246,23 @@ def run(self):


if not self.shap_multiprocess in Options.PERFORM_FLAGS.value:
self.create_shap_log(
ini_file_path=self.config_path,
rf_clf=self.rf_clf,
x_df=self.x_train,
y_df=self.y_train,
x_names=self.feature_names,
clf_name=self.clf_name,
cnt_present=self.shap_target_present_cnt,
cnt_absent=self.shap_target_absent_cnt,
save_it=self.shap_save_n,
save_path=self.eval_out_path,
)
self.create_shap_log(rf_clf=self.rf_clf,
x=self.x_train,
y=self.y_train,
x_names=list(self.feature_names),
clf_name=self.clf_name,
cnt_present=self.shap_target_present_cnt,
cnt_absent=self.shap_target_absent_cnt,
verbose=True,
plot=shap_plot,
save_it=self.shap_save_n,
save_dir=self.eval_out_path)

else:
self.create_shap_log_mp(rf_clf=self.rf_clf,
x=self.x_train,
y=self.y_train,
x_names=self.feature_names,
x_names=list(self.feature_names),
clf_name=self.clf_name,
cnt_present=self.shap_target_present_cnt,
cnt_absent=self.shap_target_absent_cnt,
Expand Down
32 changes: 15 additions & 17 deletions simba/model/train_rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self,
self.feature_names = self.x_df.columns
self.check_sampled_dataset_integrity(x_df=self.x_df, y_df=self.y_df)
print(f"Number of features in dataset: {len(self.x_df.columns)}")
print(f"Number of {self.clf_name} frames in dataset: {self.y_df.sum()} ({str(round(self.y_df.sum() / len(self.y_df), 4) * 100)}%)")
print(f"Number of {self.clf_name} frames in dataset: {int(self.y_df.sum())} ({str(round(self.y_df.sum() / len(self.y_df), 4) * 100)}%)")

def perform_sampling(self):
"""
Expand Down Expand Up @@ -283,23 +283,22 @@ def run(self):
if generate_shap_scores in Options.PERFORM_FLAGS.value:
shap_plot = self.bp_config in {'14', '16'}
if not shap_multiprocess in Options.PERFORM_FLAGS.value:
self.create_shap_log(
ini_file_path=self.config_path,
rf_clf=self.rf_clf,
x_df=self.x_train,
y_df=self.y_train,
x_names=self.feature_names,
clf_name=self.clf_name,
cnt_present=shap_target_present_cnt,
cnt_absent=shap_target_absent_cnt,
save_it=shap_save_n,
save_path=self.eval_out_path,
)
self.create_shap_log(rf_clf=self.rf_clf,
x=self.x_train,
y=self.y_train,
x_names=list(self.feature_names),
clf_name=self.clf_name,
cnt_present=shap_target_present_cnt,
cnt_absent=shap_target_absent_cnt,
verbose=True,
plot=shap_plot,
save_it=shap_save_n,
save_dir=self.eval_out_path)
else:
self.create_shap_log_mp(rf_clf=self.rf_clf,
x=self.x_train,
y=self.y_train,
x_names=self.feature_names,
x_names=list(self.feature_names),
clf_name=self.clf_name,
cnt_present=shap_target_present_cnt,
cnt_absent=shap_target_absent_cnt,
Expand Down Expand Up @@ -356,11 +355,10 @@ def save(self) -> None:
if not os.listdir(self.model_dir_out):
os.makedirs(self.model_dir_out)
self.save_rf_model(self.rf_clf, self.clf_name, self.model_dir_out)
stdout_success(msg=f"Classifier {self.clf_name} saved in models/generated_models directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
stdout_success(msg=f"Evaluation files are in models/generated_models/model_evaluations folders", source=self.__class__.__name__)
stdout_success(msg=f"Classifier {self.clf_name} saved in {self.model_dir_out} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
stdout_success(msg=f"Evaluation files are in {self.eval_out_path} folders", source=self.__class__.__name__)


#
# test = TrainRandomForestClassifier(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
# test.run()
# test.save()
Expand Down
Loading

0 comments on commit 8feadad

Please sign in to comment.