Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files (browse the repository at this point in the history)
# Conflicts:
#	simba/utils/lookups.py
  • Loading branch information
sronilsson committed Feb 21, 2024
2 parents 8eeefc1 + 0106801 commit 3b55c05
Show file tree
Hide file tree
Showing 7 changed files with 1,573 additions and 474 deletions.
8 changes: 6 additions & 2 deletions simba/SimBA.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,9 +1467,13 @@ def run_feature_extraction(self):
source=self.__class__.__name__,
)
if self.pose_setting == "8":
feature_extractor = feature_extractor_classes[self.pose_setting][self.animal_cnt](config_path=self.config_path)
feature_extractor = feature_extractor_classes[self.pose_setting][
self.animal_cnt
](config_path=self.config_path)
else:
feature_extractor = feature_extractor_classes[self.pose_setting](config_path=self.config_path)
feature_extractor = feature_extractor_classes[self.pose_setting](
config_path=self.config_path
)
feature_extractor.run()

def set_distance_mm(self):
Expand Down
1,922 changes: 1,480 additions & 442 deletions simba/feature_extractors/amber_feature_extractor.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions simba/labelling/targeted_annotations_clips.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def __init__(
)

self.video_frm_label(frm_number=self.min_frm_no)
self.nav_bar = self.h_nav_bar(parent=self.main_frm, change_frm_func=self.change_frame_targeted_annotations
self.nav_bar = self.h_nav_bar(
parent=self.main_frm, change_frm_func=self.change_frame_targeted_annotations
)
self.selection_pane = self.targeted_clips_pane(parent=self.main_frm)
self.play_video_frame = self.v_navigation_pane_targeted_clips_version(
Expand All @@ -53,4 +54,4 @@ def select_labelling_video_targeted_clips(config_path: Union[str, os.PathLike]):
_ = TargetedAnnotatorWithClips(config_path=config_path, video_path=video_file_path)


#test = TargetedAnnotatorWithClips(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', video_name='Together_1')
# test = TargetedAnnotatorWithClips(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', video_name='Together_1')
22 changes: 16 additions & 6 deletions simba/ui/create_project_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,23 @@ def __init__(self):
)
self.tracking_type_dropdown.setChoices(Options.TRACKING_TYPE_OPTIONS.value[0])

project_animal_cnt_path = os.path.join(os.path.dirname(simba.__file__), Paths.SIMBA_NO_ANIMALS_PATH.value)
self.animal_count_lst = list(pd.read_csv(project_animal_cnt_path, header=None)[0])
project_animal_cnt_path = os.path.join(
os.path.dirname(simba.__file__), Paths.SIMBA_NO_ANIMALS_PATH.value
)
self.animal_count_lst = list(
pd.read_csv(project_animal_cnt_path, header=None)[0]
)
self.bp_lu = get_body_part_configurations()
self.bp_config_codes = get_bp_config_codes()
self.classical_tracking_options = deepcopy(Options.CLASSICAL_TRACKING_OPTIONS.value)
self.multi_tracking_options = deepcopy(Options.MULTI_ANIMAL_TRACKING_OPTIONS.value)
self.three_dim_tracking_options = deepcopy(Options.THREE_DIM_TRACKING_OPTIONS.value)
self.classical_tracking_options = deepcopy(
Options.CLASSICAL_TRACKING_OPTIONS.value
)
self.multi_tracking_options = deepcopy(
Options.MULTI_ANIMAL_TRACKING_OPTIONS.value
)
self.three_dim_tracking_options = deepcopy(
Options.THREE_DIM_TRACKING_OPTIONS.value
)
self.user_defined_options = [
x
for x in list(self.bp_lu.keys())
Expand Down Expand Up @@ -343,4 +353,4 @@ def run(self):
self.create_import_videos_menu(parent_frm=self.import_videos_tab)


#ProjectCreatorPopUp()
# ProjectCreatorPopUp()
86 changes: 66 additions & 20 deletions simba/unsupervised/cluster_xai_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ def run(self):
)
with pd.ExcelWriter(self.save_path, mode="w") as writer:
pd.DataFrame().to_excel(writer, sheet_name=" ", index=True)
if self.settings[GINI_IMPORTANCE] or self.settings[PERMUTATION_IMPORTANCE] or (self.settings[SHAP][METHOD] == PAIRED):
if (
self.settings[GINI_IMPORTANCE]
or self.settings[PERMUTATION_IMPORTANCE]
or (self.settings[SHAP][METHOD] == PAIRED)
):
self.__train_paired_rf_models()
if self.settings[GINI_IMPORTANCE]:
self.__gini_importance()
Expand Down Expand Up @@ -214,8 +218,7 @@ def __permutation_importance(self):
elapsed_time=timer.elapsed_time_str,
)


def __train_all_against_one_rf_models(self, n_estimators: Optional[int] = 100):
def __train_all_against_one_rf_models(self, n_estimators: Optional[int] = 100):
all_against_one_rf_mdls = {}
rf_clf = RandomForestClassifier(
n_estimators=n_estimators,
Expand All @@ -226,9 +229,13 @@ def __train_all_against_one_rf_models(self, n_estimators: Optional[int] = 100):
bootstrap=True,
verbose=1,
)
for cluster_id in sorted(self.x_y_df['CLUSTER'].unique()):
cluster_df = self.x_y_df[self.x_y_df['CLUSTER'] == cluster_id].drop(['CLUSTER'], axis=1)
noncluster_df = self.x_y_df[self.x_y_df['CLUSTER'] != cluster_id].drop(['CLUSTER'], axis=1)
for cluster_id in sorted(self.x_y_df["CLUSTER"].unique()):
cluster_df = self.x_y_df[self.x_y_df["CLUSTER"] == cluster_id].drop(
["CLUSTER"], axis=1
)
noncluster_df = self.x_y_df[self.x_y_df["CLUSTER"] != cluster_id].drop(
["CLUSTER"], axis=1
)
cluster_df[TARGET] = 1
noncluster_df[TARGET] = 0

Expand Down Expand Up @@ -308,22 +315,61 @@ def __shap_values(self):
print("Calculating one-against-all shap values ...")
mdls = self.__train_all_against_one_rf_models()
for cluster_id, cluster_mdl in mdls.items():
print(f'Computing SHAP for cluster {cluster_id}...')
explainer = shap.TreeExplainer(cluster_mdl,data=None, model_output="raw", feature_perturbation="tree_path_dependent")
cluster_one_sample = self.x_y_df[self.x_y_df['CLUSTER'] == cluster_id].sample(n=self.settings[SHAP][SAMPLE])
cluster_two_sample = self.x_y_df[self.x_y_df['CLUSTER'] != cluster_id].sample(n=self.settings[SHAP][SAMPLE])
cluster_one_shap = pd.DataFrame(explainer.shap_values(cluster_one_sample, check_additivity=False)[1],columns=cluster_one_sample.columns, index=cluster_one_sample.index)
cluster_two_shap = pd.DataFrame(explainer.shap_values(cluster_two_sample, check_additivity=False)[1],columns=cluster_two_sample.columns, index=cluster_one_sample.index)
cluster_two_shap['CLUSTER'] = cluster_two_sample['CLUSTER'].values
cluster_one_shap['CLUSTER'] = cluster_one_sample['CLUSTER'].values
print(f"Computing SHAP for cluster {cluster_id}...")
explainer = shap.TreeExplainer(
cluster_mdl,
data=None,
model_output="raw",
feature_perturbation="tree_path_dependent",
)
cluster_one_sample = self.x_y_df[
self.x_y_df["CLUSTER"] == cluster_id
].sample(n=self.settings[SHAP][SAMPLE])
cluster_two_sample = self.x_y_df[
self.x_y_df["CLUSTER"] != cluster_id
].sample(n=self.settings[SHAP][SAMPLE])
cluster_one_shap = pd.DataFrame(
explainer.shap_values(cluster_one_sample, check_additivity=False)[
1
],
columns=cluster_one_sample.columns,
index=cluster_one_sample.index,
)
cluster_two_shap = pd.DataFrame(
explainer.shap_values(cluster_two_sample, check_additivity=False)[
1
],
columns=cluster_two_sample.columns,
index=cluster_one_sample.index,
)
cluster_two_shap["CLUSTER"] = cluster_two_sample["CLUSTER"].values
cluster_one_shap["CLUSTER"] = cluster_one_sample["CLUSTER"].values
results = pd.concat([cluster_one_shap, cluster_two_shap], axis=0)
self.__save_results(df=results, name=f"SHAP CLUSTER {cluster_id} vs. ALL")
self.__save_results(
df=results, name=f"SHAP CLUSTER {cluster_id} vs. ALL"
)
timer.stop_timer()
stdout_success(msg=f"SHAP one-vs-all values complete", elapsed_time=timer.elapsed_time_str)
stdout_success(
msg=f"SHAP one-vs-all values complete",
elapsed_time=timer.elapsed_time_str,
)


settings = {'gini_importance': False, 'permutation_importance': False, 'shap': {'method': 'cluster_paired', 'run': True, 'sample': 10}}
settings = {'gini_importance': False, 'permutation_importance': False, 'shap': {'method': 'One-against-all', 'run': True, 'sample': 10}}
calculator = ClusterXAICalculator(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/NG_Unsupervised/project_folder/project_config.ini', data_path='/Users/simon/Desktop/envs/simba/troubleshooting/NG_Unsupervised/project_folder/cluster_mdls/hopeful_khorana.pickle', settings=settings)
settings = {
"gini_importance": False,
"permutation_importance": False,
"shap": {"method": "cluster_paired", "run": True, "sample": 10},
}
settings = {
"gini_importance": False,
"permutation_importance": False,
"shap": {"method": "One-against-all", "run": True, "sample": 10},
}
calculator = ClusterXAICalculator(
config_path="/Users/simon/Desktop/envs/simba/troubleshooting/NG_Unsupervised/project_folder/project_config.ini",
data_path="/Users/simon/Desktop/envs/simba/troubleshooting/NG_Unsupervised/project_folder/cluster_mdls/hopeful_khorana.pickle",
settings=settings,
)


calculator.run()
calculator.run()
2 changes: 1 addition & 1 deletion simba/unsupervised/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class UMLOptions(Enum):
CATEGORICAL_OPTIONS = ["VIDEO", "CLASSIFIER", "CLUSTER"]
CONTINUOUS_OPTIONS = ["START_FRAME", "END_FRAME", "PROBABILITY"]
SPEED_OPTIONS = [round(x, 1) for x in list(np.arange(0.1, 2.1, 0.1))]
SHAP_CLUSTER_METHODS = ["Paired clusters", 'One-against-all']
SHAP_CLUSTER_METHODS = ["Paired clusters", "One-against-all"]
DR_ALGO_OPTIONS = ["UMAP", "TSNE"]
CLUSTERING_ALGO_OPTIONS = ["HDBSCAN"]
VARIANCE_OPTIONS = list(range(0, 100, 10))
Expand Down
2 changes: 1 addition & 1 deletion simba/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ class Options(Enum):
"Multi-animals; 4 body-parts",
"Multi-animals; 7 body-parts",
"Multi-animals; 8 body-parts",
"AMBER"
"AMBER",
]
THREE_DIM_TRACKING_OPTIONS = ["3D tracking"]
TRAIN_TEST_SPLIT = ["FRAMES", "BOUTS"]
Expand Down

0 comments on commit 3b55c05

Please sign in to comment.