From 3bd67e1e5d68d18776b8d61ec37450b2bf4e40e5 Mon Sep 17 00:00:00 2001 From: sronilsson Date: Thu, 25 Apr 2024 15:49:41 +0000 Subject: [PATCH] cleaned --- .../directing_other_animals_calculator.py | 260 +++++++--- simba/mixins/config_reader.py | 28 +- .../feature_extraction_supplement_mixin.py | 32 +- simba/mixins/plotting_mixin.py | 336 +++++++++--- simba/plotting/ROI_feature_visualizer.py | 388 +++++++++++--- simba/plotting/ROI_feature_visualizer_mp.py | 491 +++++++++++++----- simba/plotting/ROI_plotter.py | 363 ++++++++++--- simba/plotting/ROI_plotter_mp.py | 452 ++++++++++++---- .../plotting/directing_animals_visualizer.py | 184 +++++-- .../directing_animals_visualizer_mp.py | 302 +++++++---- simba/roi_tools/ROI_analyzer.py | 343 +++++++++--- simba/roi_tools/ROI_directing_analyzer.py | 145 ++++-- simba/roi_tools/ROI_feature_analyzer.py | 250 +++++++-- simba/roi_tools/ROI_time_bin_calculator.py | 226 ++++++-- .../append_roi_features_animals_pop_up.py | 45 +- .../append_roi_features_bodypart_pop_up.py | 23 +- .../directing_other_animals_plot_pop_up.py | 174 +++++-- simba/ui/pop_ups/roi_analysis_pop_up.py | 23 +- simba/ui/pop_ups/roi_features_plot_pop_up.py | 153 ++++-- simba/ui/pop_ups/roi_tracking_plot_pop_up.py | 173 ++++-- simba/utils/checks.py | 56 +- simba/utils/data.py | 39 +- simba/utils/read_write.py | 45 +- 23 files changed, 3429 insertions(+), 1102 deletions(-) diff --git a/simba/data_processors/directing_other_animals_calculator.py b/simba/data_processors/directing_other_animals_calculator.py index 714445f54..a257419c0 100644 --- a/simba/data_processors/directing_other_animals_calculator.py +++ b/simba/data_processors/directing_other_animals_calculator.py @@ -9,12 +9,14 @@ from simba.mixins.config_reader import ConfigReader from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin -from simba.utils.checks import (check_all_file_names_are_represented_in_video_log, - check_that_dir_has_list_of_filenames, check_file_exist_and_readable) +from simba.utils.checks import ( + check_all_file_names_are_represented_in_video_log, + check_file_exist_and_readable, check_that_dir_has_list_of_filenames) from simba.utils.enums import TagNames from simba.utils.errors import AnimalNumberError, CountError, InvalidInputError from simba.utils.printing import SimbaTimer, log_event, stdout_success -from simba.utils.read_write import get_fn_ext, read_df, write_df, read_data_paths +from simba.utils.read_write import (get_fn_ext, read_data_paths, read_df, + write_df) class DirectingOtherAnimalsAnalyzer(ConfigReader, FeatureExtractionMixin): @@ -42,34 +44,67 @@ class DirectingOtherAnimalsAnalyzer(ConfigReader, FeatureExtractionMixin): >>> directing_analyzer.run() """ - def __init__(self, - config_path: Union[str, os.PathLike], - data_paths: Optional[Union[str, os.PathLike, None]] = None, - bool_tables: Optional[bool] = True, - summary_tables: Optional[bool] = False, - append_bool_tables_to_features: Optional[bool] = False, - aggregate_statistics_tables: Optional[bool] = False): + def __init__( + self, + config_path: Union[str, os.PathLike], + data_paths: Optional[Union[str, os.PathLike, None]] = None, + bool_tables: Optional[bool] = True, + summary_tables: Optional[bool] = False, + append_bool_tables_to_features: Optional[bool] = False, + aggregate_statistics_tables: Optional[bool] = False, + ): check_file_exist_and_readable(file_path=config_path) ConfigReader.__init__(self, config_path=config_path) FeatureExtractionMixin.__init__(self, config_path=config_path) - data_paths = read_data_paths(path=data_paths, default=self.outlier_corrected_paths, default_name=self.outlier_corrected_dir, file_type=self.file_type) - log_event(logger_name=str(self.__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals())) + data_paths = read_data_paths( + path=data_paths, + default=self.outlier_corrected_paths, + default_name=self.outlier_corrected_dir, + file_type=self.file_type, + ) + log_event( + logger_name=str(self.__class__.__name__), + log_type=TagNames.CLASS_INIT.value, + msg=self.create_log_msg_from_init_args(locals=locals()), + ) if self.animal_cnt < 2: - raise AnimalNumberError("Cannot analyze directionality between animals in a project with less than two animals.", source=self.__class__.__name__,) + raise AnimalNumberError( + "Cannot analyze directionality between animals in a project with less than two animals.", + source=self.__class__.__name__, + ) self.animal_permutations = list(itertools.permutations(self.animal_bp_dict, 2)) - self.bool_tables, self.summary_tables, self.aggregate_statistics_tables, self.append_bool_tables_to_features = bool_tables, summary_tables, aggregate_statistics_tables, append_bool_tables_to_features + ( + self.bool_tables, + self.summary_tables, + self.aggregate_statistics_tables, + self.append_bool_tables_to_features, + ) = ( + bool_tables, + summary_tables, + aggregate_statistics_tables, + append_bool_tables_to_features, + ) self.data_paths = data_paths if self.append_bool_tables_to_features: - check_that_dir_has_list_of_filenames(dir=self.features_dir, file_name_lst=self.outlier_corrected_paths, file_type=self.file_type) + check_that_dir_has_list_of_filenames( + dir=self.features_dir, + file_name_lst=self.outlier_corrected_paths, + file_type=self.file_type, + ) print(f"Processing {len(self.data_paths)} video(s)...") if not self.check_directionality_viable()[0]: - raise InvalidInputError(msg='You are not tracking the necessary body-parts to calculate direction.', source=self.__class__.__name__) + raise InvalidInputError( + msg="You are not tracking the necessary body-parts to calculate direction.", + source=self.__class__.__name__, + ) # def run(self): if self.aggregate_statistics_tables: - check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths) + check_all_file_names_are_represented_in_video_log( + video_info_df=self.video_info_df, data_paths=self.data_paths + ) self.results = {} for file_cnt, file_path in enumerate(self.data_paths): video_timer = SimbaTimer(start=True) @@ -79,83 +114,179 @@ def run(self): data_df = read_df(file_path=file_path, file_type=self.file_type) direct_bp_dict = self.check_directionality_cords() for animal_permutation in self.animal_permutations: - self.results[video_name][f"{animal_permutation[0]} directing towards {animal_permutation[1]}"] = {} - first_animal_bps, second_animal_bps = (direct_bp_dict[animal_permutation[0]], self.animal_bp_dict[animal_permutation[1]]) - first_ear_left_arr = data_df[[first_animal_bps["Ear_left"]["X_bps"],first_animal_bps["Ear_left"]["Y_bps"]]].values - first_ear_right_arr = data_df[[first_animal_bps["Ear_right"]["X_bps"], first_animal_bps["Ear_right"]["Y_bps"]]].values - first_nose_arr = data_df[[first_animal_bps["Nose"]["X_bps"], first_animal_bps["Nose"]["Y_bps"]]].values - second_animal_x_bps, second_animal_y_bps = (second_animal_bps["X_bps"], second_animal_bps["Y_bps"]) + self.results[video_name][ + f"{animal_permutation[0]} directing towards {animal_permutation[1]}" + ] = {} + first_animal_bps, second_animal_bps = ( + direct_bp_dict[animal_permutation[0]], + self.animal_bp_dict[animal_permutation[1]], + ) + first_ear_left_arr = data_df[ + [ + first_animal_bps["Ear_left"]["X_bps"], + first_animal_bps["Ear_left"]["Y_bps"], + ] + ].values + first_ear_right_arr = data_df[ + [ + first_animal_bps["Ear_right"]["X_bps"], + first_animal_bps["Ear_right"]["Y_bps"], + ] + ].values + first_nose_arr = data_df[ + [ + first_animal_bps["Nose"]["X_bps"], + first_animal_bps["Nose"]["Y_bps"], + ] + ].values + second_animal_x_bps, second_animal_y_bps = ( + second_animal_bps["X_bps"], + second_animal_bps["Y_bps"], + ) for x_bp, y_bp in zip(second_animal_x_bps, second_animal_y_bps): target_cord_arr = data_df[[x_bp, y_bp]].values - direction_data = self.jitted_line_crosses_to_nonstatic_targets(left_ear_array=first_ear_left_arr, - right_ear_array=first_ear_right_arr, - nose_array=first_nose_arr, - target_array=target_cord_arr) + direction_data = self.jitted_line_crosses_to_nonstatic_targets( + left_ear_array=first_ear_left_arr, + right_ear_array=first_ear_right_arr, + nose_array=first_nose_arr, + target_array=target_cord_arr, + ) x_min = np.minimum(direction_data[:, 1], first_nose_arr[:, 0]) y_min = np.minimum(direction_data[:, 2], first_nose_arr[:, 1]) delta_x = abs((direction_data[:, 1] - first_nose_arr[:, 0]) / 2) delta_y = abs((direction_data[:, 2] - first_nose_arr[:, 1]) / 2) x_middle, y_middle = np.add(x_min, delta_x), np.add(y_min, delta_y) - direction_data = np.concatenate((y_middle.reshape(-1, 1), direction_data), axis=1) - direction_data = np.concatenate((x_middle.reshape(-1, 1), direction_data), axis=1) + direction_data = np.concatenate( + (y_middle.reshape(-1, 1), direction_data), axis=1 + ) + direction_data = np.concatenate( + (x_middle.reshape(-1, 1), direction_data), axis=1 + ) direction_data = np.delete(direction_data, [2, 3, 4], 1) direction_data = np.hstack((direction_data, target_cord_arr)) - bp_data = pd.DataFrame(direction_data,columns=["Eye_x", "Eye_y", "Directing_BOOL", x_bp, y_bp],) + bp_data = pd.DataFrame( + direction_data, + columns=["Eye_x", "Eye_y", "Directing_BOOL", x_bp, y_bp], + ) bp_data = bp_data[["Eye_x", "Eye_y", x_bp, y_bp, "Directing_BOOL"]] bp_data.insert(loc=0, column="Animal_2_body_part", value=x_bp[:-2]) - bp_data.insert(loc=0, column="Animal_2", value=animal_permutation[1]) - bp_data.insert(loc=0, column="Animal_1", value=animal_permutation[0]) - self.results[video_name][f"{animal_permutation[0]} directing towards {animal_permutation[1]}"][x_bp[:-2]] = bp_data + bp_data.insert( + loc=0, column="Animal_2", value=animal_permutation[1] + ) + bp_data.insert( + loc=0, column="Animal_1", value=animal_permutation[0] + ) + self.results[video_name][ + f"{animal_permutation[0]} directing towards {animal_permutation[1]}" + ][x_bp[:-2]] = bp_data video_timer.stop_timer() - print(f"Direction analysis complete for video {video_name} ({file_cnt + 1}/{len(self.outlier_corrected_paths)}, elapsed time: {video_timer.elapsed_time_str}s)...") + print( + f"Direction analysis complete for video {video_name} ({file_cnt + 1}/{len(self.outlier_corrected_paths)}, elapsed time: {video_timer.elapsed_time_str}s)..." + ) if self.bool_tables: - save_dir = os.path.join(self.logs_path, f"Animal_directing_animal_booleans_{self.datetime}") - if not os.path.isdir(save_dir): os.makedirs(save_dir) + save_dir = os.path.join( + self.logs_path, f"Animal_directing_animal_booleans_{self.datetime}" + ) + if not os.path.isdir(save_dir): + os.makedirs(save_dir) for video_cnt, (video_name, video_data) in enumerate(self.results.items()): - print(f"Saving boolean directing tables for video {video_name} (Video {video_cnt + 1}/{len(self.results.keys())})...") + print( + f"Saving boolean directing tables for video {video_name} (Video {video_cnt + 1}/{len(self.results.keys())})..." + ) video_df = pd.DataFrame() for animal_permutation, animal_permutation_data in video_data.items(): - for body_part_name, body_part_data in animal_permutation_data.items(): - video_df[f"{animal_permutation}_{body_part_name}"] = body_part_data["Directing_BOOL"] + for ( + body_part_name, + body_part_data, + ) in animal_permutation_data.items(): + video_df[f"{animal_permutation}_{body_part_name}"] = ( + body_part_data["Directing_BOOL"] + ) if self.append_bool_tables_to_features: - print(f"Adding directionality tables to features data for video {video_name}...") - df = read_df(file_path=os.path.join(self.features_dir, f"{video_name}.{self.file_type}"), file_type=self.file_type) + print( + f"Adding directionality tables to features data for video {video_name}..." + ) + df = read_df( + file_path=os.path.join( + self.features_dir, f"{video_name}.{self.file_type}" + ), + file_type=self.file_type, + ) if len(df) != len(video_df): - raise CountError(msg=f"Failed to join data files as they contains different number of frames: the file representing video {video_name} in directory {self.outlier_corrected_dir} contains {len(video_df)} frames, and the file representing video {video_name} in directory {self.features_dir} contains {len(df)} frames.") + raise CountError( + msg=f"Failed to join data files as they contains different number of frames: the file representing video {video_name} in directory {self.outlier_corrected_dir} contains {len(video_df)} frames, and the file representing video {video_name} in directory {self.features_dir} contains {len(df)} frames." + ) else: - df = pd.concat([df.reset_index(drop=True), video_df.reset_index(drop=True)], axis=1) - write_df(df=df, file_type=self.file_type, save_path=os.path.join(self.features_dir, f"{video_name}.{self.file_type}")) + df = pd.concat( + [ + df.reset_index(drop=True), + video_df.reset_index(drop=True), + ], + axis=1, + ) + write_df( + df=df, + file_type=self.file_type, + save_path=os.path.join( + self.features_dir, f"{video_name}.{self.file_type}" + ), + ) video_df.to_csv(os.path.join(save_dir, f"{video_name}.csv")) - stdout_success(msg=f"All boolean tables saved in {save_dir}!",source=self.__class__.__name__) + stdout_success( + msg=f"All boolean tables saved in {save_dir}!", + source=self.__class__.__name__, + ) if self.aggregate_statistics_tables: print("Computing summary statistics...") - save_path = os.path.join(self.logs_path, f"Direction_aggregate_summary_data_{self.datetime}.csv") - out_df = pd.DataFrame(columns=['VIDEO', 'ANIMAL PERMUTATION', 'VALUE (S)']) + save_path = os.path.join( + self.logs_path, f"Direction_aggregate_summary_data_{self.datetime}.csv" + ) + out_df = pd.DataFrame(columns=["VIDEO", "ANIMAL PERMUTATION", "VALUE (S)"]) for video_name, video_data in self.results.items(): _, _, fps = self.read_video_info(video_name=video_name) for animal_permutation, permutation_data in video_data.items(): idx_directing = set() for bp_name, bp_data in permutation_data.items(): - idx_directing.update(list(bp_data.index[bp_data["Directing_BOOL"] == 1])) + idx_directing.update( + list(bp_data.index[bp_data["Directing_BOOL"] == 1]) + ) value = round(len(idx_directing) / fps, 3) out_df.loc[len(out_df)] = [video_name, animal_permutation, value] - self.out_df = (out_df.sort_values(by=["VIDEO", "ANIMAL PERMUTATION"]).set_index("VIDEO")) + self.out_df = out_df.sort_values( + by=["VIDEO", "ANIMAL PERMUTATION"] + ).set_index("VIDEO") self.out_df.to_csv(save_path) - stdout_success(msg=f"Summary directional statistics saved at {save_path}", source=self.__class__.__name__) + stdout_success( + msg=f"Summary directional statistics saved at {save_path}", + source=self.__class__.__name__, + ) if self.summary_tables: self.transpose_results() - save_dir = os.path.join(self.logs_path, f"detailed_directionality_summary_dataframes_{self.datetime}") - if not os.path.exists(save_dir): os.makedirs(save_dir) - for video_cnt, (video_name, video_data) in enumerate(self.directionality_df_dict.items()): + save_dir = os.path.join( + self.logs_path, + f"detailed_directionality_summary_dataframes_{self.datetime}", + ) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + for video_cnt, (video_name, video_data) in enumerate( + self.directionality_df_dict.items() + ): save_name = os.path.join(save_dir, f"{video_name}.csv") video_data.to_csv(save_name) - stdout_success(f"All detailed directional data saved in the {save_dir} directory!", source=self.__class__.__name__) + stdout_success( + f"All detailed directional data saved in the {save_dir} directory!", + source=self.__class__.__name__, + ) self.timer.stop_timer() - stdout_success(msg="All directional data saved in SimBA project", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__) + stdout_success( + msg="All directional data saved in SimBA project", + elapsed_time=self.timer.elapsed_time_str, + source=self.__class__.__name__, + ) def transpose_results(self): self.directionality_df_dict = {} @@ -163,12 +294,23 @@ def transpose_results(self): out_df_lst = [] for animal_permutation, permutation_data in video_data.items(): for bp_name, bp_data in permutation_data.items(): - directing_df = (bp_data[bp_data["Directing_BOOL"] == 1].reset_index().rename( - columns={"index": "Frame_#", bp_name + "_x": "Animal_2_bodypart_x", - bp_name + "_y": "Animal_2_bodypart_y"})) + directing_df = ( + bp_data[bp_data["Directing_BOOL"] == 1] + .reset_index() + .rename( + columns={ + "index": "Frame_#", + bp_name + "_x": "Animal_2_bodypart_x", + bp_name + "_y": "Animal_2_bodypart_y", + } + ) + ) directing_df.insert(loc=0, column="Video", value=video_name) out_df_lst.append(directing_df) - self.directionality_df_dict[video_name] = pd.concat(out_df_lst, axis=0).drop("Directing_BOOL", axis=1) + self.directionality_df_dict[video_name] = pd.concat( + out_df_lst, axis=0 + ).drop("Directing_BOOL", axis=1) + # test = DirectingOtherAnimalsAnalyzer(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', # bool_tables=True, diff --git a/simba/mixins/config_reader.py b/simba/mixins/config_reader.py index 36185ebe1..53023f1ff 100644 --- a/simba/mixins/config_reader.py +++ b/simba/mixins/config_reader.py @@ -225,8 +225,12 @@ def read_roi_data(self) -> None: source=self.__class__.__name__, ) else: - self.rectangles_df = pd.read_hdf(self.roi_coordinates_path, key=Keys.ROI_RECTANGLES.value) - if ("Center_X" in self.rectangles_df.columns) and (self.rectangles_df["Center_X"].isnull().values.any()): + self.rectangles_df = pd.read_hdf( + self.roi_coordinates_path, key=Keys.ROI_RECTANGLES.value + ) + if ("Center_X" in self.rectangles_df.columns) and ( + self.rectangles_df["Center_X"].isnull().values.any() + ): for idx, row in self.rectangles_df.iterrows(): self.rectangles_df.loc[idx]["Center_X"] = row["Tags"]["Center tag"][ 0 @@ -264,11 +268,23 @@ def read_roi_data(self) -> None: self.roi_types_names_lst = list(self.roi_types_names_lst) for shape_type, shape_data in self.roi_dict.items(): if shape_type == Keys.ROI_CIRCLES.value: - self.roi_dict[Keys.ROI_CIRCLES.value]["Center_X"] = self.roi_dict[ Keys.ROI_CIRCLES.value]["centerX"] - self.roi_dict[Keys.ROI_CIRCLES.value]["Center_Y"] = self.roi_dict[Keys.ROI_CIRCLES.value]["centerY"] + self.roi_dict[Keys.ROI_CIRCLES.value]["Center_X"] = self.roi_dict[ + Keys.ROI_CIRCLES.value + ]["centerX"] + self.roi_dict[Keys.ROI_CIRCLES.value]["Center_Y"] = self.roi_dict[ + Keys.ROI_CIRCLES.value + ]["centerY"] elif shape_type == Keys.ROI_RECTANGLES.value: - self.roi_dict[Keys.ROI_RECTANGLES.value]["Center_X"] = self.roi_dict[Keys.ROI_RECTANGLES.value]["Bottom_right_X"] - (self.roi_dict[Keys.ROI_RECTANGLES.value]["width"] / 2) - self.roi_dict[Keys.ROI_RECTANGLES.value]["Center_Y"] = self.roi_dict[Keys.ROI_RECTANGLES.value]["Bottom_right_Y"] - (self.roi_dict[Keys.ROI_RECTANGLES.value]["height"] / 2) + self.roi_dict[Keys.ROI_RECTANGLES.value][ + "Center_X" + ] = self.roi_dict[Keys.ROI_RECTANGLES.value]["Bottom_right_X"] - ( + self.roi_dict[Keys.ROI_RECTANGLES.value]["width"] / 2 + ) + self.roi_dict[Keys.ROI_RECTANGLES.value][ + "Center_Y" + ] = self.roi_dict[Keys.ROI_RECTANGLES.value]["Bottom_right_Y"] - ( + self.roi_dict[Keys.ROI_RECTANGLES.value]["height"] / 2 + ) elif shape_type == Keys.ROI_POLYGONS.value: try: self.roi_dict[Keys.ROI_POLYGONS.value]["Center_X"] = ( diff --git a/simba/mixins/feature_extraction_supplement_mixin.py b/simba/mixins/feature_extraction_supplement_mixin.py index 57589f331..aa2fe0c90 100644 --- a/simba/mixins/feature_extraction_supplement_mixin.py +++ b/simba/mixins/feature_extraction_supplement_mixin.py @@ -740,14 +740,31 @@ def distance_and_velocity( >>> sum_movement, avg_velocity = FeatureExtractionSupplemental.distance_and_velocity(x=x, fps=10, pixels_per_mm=10, centimeters=True) """ - check_valid_array(data=x, source=FeatureExtractionSupplemental.distance_and_velocity.__name__, accepted_ndims=(1, 2), accepted_dtypes=(np.float32, np.float64, np.int32, np.int64, int, float)) - check_float(name=f"{FeatureExtractionSupplemental.distance_and_velocity.__name__} fps",value=fps,min_value=1) - check_float(name=f"{FeatureExtractionSupplemental.distance_and_velocity.__name__} pixels_per_mm",value=pixels_per_mm,min_value=10e-6) + check_valid_array( + data=x, + source=FeatureExtractionSupplemental.distance_and_velocity.__name__, + accepted_ndims=(1, 2), + accepted_dtypes=(np.float32, np.float64, np.int32, np.int64, int, float), + ) + check_float( + name=f"{FeatureExtractionSupplemental.distance_and_velocity.__name__} fps", + value=fps, + min_value=1, + ) + check_float( + name=f"{FeatureExtractionSupplemental.distance_and_velocity.__name__} pixels_per_mm", + value=pixels_per_mm, + min_value=10e-6, + ) if x.ndim == 2: - check_valid_array(data=x, source=FeatureExtractionSupplemental.distance_and_velocity.__name__, accepted_axis_1_shape=(2,)) + check_valid_array( + data=x, + source=FeatureExtractionSupplemental.distance_and_velocity.__name__, + accepted_axis_1_shape=(2,), + ) t = np.full((x.shape[0]), 0.0) for i in range(1, x.shape[0]): - t[i] = np.linalg.norm(x[i] - x[i-1]) + t[i] = np.linalg.norm(x[i] - x[i - 1]) x = np.copy(t) / pixels_per_mm movement = np.sum(x) / pixels_per_mm v = [] @@ -761,11 +778,12 @@ def distance_and_velocity( x = np.random.randint(0, 100, (100, 2)) -FeatureExtractionSupplemental.distance_and_velocity(x=x, fps=10, pixels_per_mm=10, centimeters=True) +FeatureExtractionSupplemental.distance_and_velocity( + x=x, fps=10, pixels_per_mm=10, centimeters=True +) # # sum_movement, avg_velocity = - # df = read_df(file_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/targets_inserted/Together_1.csv', file_type='csv') # # df = pd.DataFrame(np.random.randint(0, 2, (100, 2)), columns=['Attack', 'Sniffing']) diff --git a/simba/mixins/plotting_mixin.py b/simba/mixins/plotting_mixin.py index bb3f046c2..c677e6049 100644 --- a/simba/mixins/plotting_mixin.py +++ b/simba/mixins/plotting_mixin.py @@ -32,8 +32,9 @@ check_if_keys_exist_in_dict, check_if_valid_rgb_tuple, check_instance, check_int, check_str, check_that_column_exist, - check_valid_array, check_valid_lst, check_valid_dataframe) -from simba.utils.enums import Formats, Options, TextOptions, Keys + check_valid_array, check_valid_dataframe, + check_valid_lst) +from simba.utils.enums import Formats, Keys, Options, TextOptions from simba.utils.errors import InvalidInputError from simba.utils.lookups import (get_categorical_palettes, get_color_dict, get_named_colors) @@ -1671,53 +1672,117 @@ def make_path_plot( else: return img - - @staticmethod - def rectangles_onto_image(img: np.ndarray, - rectangles: pd.DataFrame, - show_center: Optional[bool] = False, - show_tags: Optional[bool] = False, - circle_size: Optional[int] = 2) -> np.ndarray: + def rectangles_onto_image( + img: np.ndarray, + rectangles: pd.DataFrame, + show_center: Optional[bool] = False, + show_tags: Optional[bool] = False, + circle_size: Optional[int] = 2, + ) -> np.ndarray: check_valid_array(data=img, source=PlottingMixin.rectangles_onto_image.__name__) - check_valid_dataframe(df=rectangles, source=PlottingMixin.rectangles_onto_image.__name__, required_fields=['topLeftX', 'topLeftY', 'Bottom_right_X', 'Bottom_right_Y', 'Color BGR', 'Thickness', 'Center_X', 'Center_Y', 'Tags']) - check_int(name=PlottingMixin.rectangles_onto_image.__name__, value=circle_size, min_value=1) + check_valid_dataframe( + df=rectangles, + source=PlottingMixin.rectangles_onto_image.__name__, + required_fields=[ + "topLeftX", + "topLeftY", + "Bottom_right_X", + "Bottom_right_Y", + "Color BGR", + "Thickness", + "Center_X", + "Center_Y", + "Tags", + ], + ) + check_int( + name=PlottingMixin.rectangles_onto_image.__name__, + value=circle_size, + min_value=1, + ) for _, row in rectangles.iterrows(): - img = cv2.rectangle(img, (int(row["topLeftX"]), int(row["topLeftY"])), (int(row["Bottom_right_X"]), int(row["Bottom_right_Y"])), row["Color BGR"], int(row["Thickness"])) + img = cv2.rectangle( + img, + (int(row["topLeftX"]), int(row["topLeftY"])), + (int(row["Bottom_right_X"]), int(row["Bottom_right_Y"])), + row["Color BGR"], + int(row["Thickness"]), + ) if show_center: - img = cv2.circle(img, (int(row["Center_X"]), int(row["Center_Y"])), circle_size, row["Color BGR"], -1) + img = cv2.circle( + img, + (int(row["Center_X"]), int(row["Center_Y"])), + circle_size, + row["Color BGR"], + -1, + ) if show_tags: for tag_name, tag_data in row["Tags"].items(): - img = cv2.circle(img, tuple(tag_data), circle_size, row["Color BGR"], -1) + img = cv2.circle( + img, tuple(tag_data), circle_size, row["Color BGR"], -1 + ) return img @staticmethod - def circles_onto_image(img: np.ndarray, - circles: pd.DataFrame, - show_center: Optional[bool] = False, - show_tags: Optional[bool] = False, - circle_size: Optional[int] = 2) -> np.ndarray: + def circles_onto_image( + img: np.ndarray, + circles: pd.DataFrame, + show_center: Optional[bool] = False, + show_tags: Optional[bool] = False, + circle_size: Optional[int] = 2, + ) -> np.ndarray: check_valid_array(data=img, source=PlottingMixin.circles_onto_image.__name__) - check_valid_dataframe(df=circles, source=PlottingMixin.circles_onto_image.__name__, required_fields=['centerX', 'centerY', 'radius', 'Color BGR', 'Thickness', 'Tags']) - check_int(name=PlottingMixin.circles_onto_image.__name__, value=circle_size, min_value=1) + check_valid_dataframe( + df=circles, + source=PlottingMixin.circles_onto_image.__name__, + required_fields=[ + "centerX", + "centerY", + "radius", + "Color BGR", + "Thickness", + "Tags", + ], + ) + check_int( + name=PlottingMixin.circles_onto_image.__name__, + value=circle_size, + min_value=1, + ) for _, row in circles.iterrows(): - img = cv2.circle(img, (int(row["centerX"]), int(row["centerY"])), row["radius"], row["Color BGR"], int(row["Thickness"])) + img = cv2.circle( + img, + (int(row["centerX"]), int(row["centerY"])), + row["radius"], + row["Color BGR"], + int(row["Thickness"]), + ) if show_center: - img = cv2.circle(img, (int(row["Center_X"]), int(row["Center_Y"])), circle_size, row["Color BGR"], -1) + img = cv2.circle( + img, + (int(row["Center_X"]), int(row["Center_Y"])), + circle_size, + row["Color BGR"], + -1, + ) if show_tags: for tag_data in row["Tags"].values(): - img = cv2.circle(img, tuple(tag_data), circle_size, row["Color BGR"], -1) + img = cv2.circle( + img, tuple(tag_data), circle_size, row["Color BGR"], -1 + ) return img @staticmethod - def polygons_onto_image(img: np.ndarray, - polygons: pd.DataFrame, - show_center: Optional[bool] = False, - show_tags: Optional[bool] = False, - circle_size: Optional[int] = 2) -> np.ndarray: - + def polygons_onto_image( + img: np.ndarray, + polygons: pd.DataFrame, + show_center: Optional[bool] = False, + show_tags: Optional[bool] = False, + circle_size: Optional[int] = 2, + ) -> np.ndarray: """ Helper to insert polygon overlays onto an image. @@ -1729,42 +1794,104 @@ def polygons_onto_image(img: np.ndarray, :return: """ - check_valid_array(data=img, source=f'{PlottingMixin.polygons_onto_image.__name__} img') - check_valid_dataframe(df=polygons, source=f'{PlottingMixin.polygons_onto_image.__name__} polygons', required_fields=['vertices', 'Center_X', 'Center_Y', 'Color BGR', 'Thickness', 'Tags']) - check_int(name=PlottingMixin.polygons_onto_image.__name__, value=circle_size, min_value=1) + check_valid_array( + data=img, source=f"{PlottingMixin.polygons_onto_image.__name__} img" + ) + check_valid_dataframe( + df=polygons, + source=f"{PlottingMixin.polygons_onto_image.__name__} polygons", + required_fields=[ + "vertices", + "Center_X", + "Center_Y", + "Color BGR", + "Thickness", + "Tags", + ], + ) + check_int( + name=PlottingMixin.polygons_onto_image.__name__, + value=circle_size, + min_value=1, + ) for _, row in polygons.iterrows(): - img = cv2.polylines(img, [row["vertices"].astype(int)], True, row["Color BGR"], thickness=int(row["Thickness"])) + img = cv2.polylines( + img, + [row["vertices"].astype(int)], + True, + row["Color BGR"], + thickness=int(row["Thickness"]), + ) if show_center: - img = cv2.circle(img, (int(row["Center_X"]), int(row["Center_Y"])), circle_size, row["Color BGR"], -1) + img = cv2.circle( + img, + (int(row["Center_X"]), int(row["Center_Y"])), + circle_size, + row["Color BGR"], + -1, + ) if show_tags: for tag_data in row["vertices"]: - img = cv2.circle(img, tuple(tag_data), circle_size, row["Color BGR"], -1) + img = cv2.circle( + img, tuple(tag_data), circle_size, row["Color BGR"], -1 + ) return img @staticmethod - def roi_dict_onto_img(img: np.ndarray, - roi_dict: Dict[str, pd.DataFrame], - circle_size: Optional[int] = 2, - show_center: Optional[bool] = False, - show_tags: Optional[bool] = False) -> np.ndarray: - - check_valid_array(data=img, source=f'{PlottingMixin.roi_dict_onto_img.__name__} img') - check_if_keys_exist_in_dict(data=roi_dict, key=[Keys.ROI_POLYGONS.value, Keys.ROI_CIRCLES.value, Keys.ROI_RECTANGLES.value], name=PlottingMixin.roi_dict_onto_img.__name__) - img = PlottingMixin.rectangles_onto_image(img=img, rectangles=roi_dict[Keys.ROI_RECTANGLES.value], circle_size=circle_size, show_center=show_center, show_tags=show_tags) - img = PlottingMixin.circles_onto_image(img=img, circles=roi_dict[Keys.ROI_CIRCLES.value], circle_size=circle_size, show_center=show_center, show_tags=show_tags) - img = PlottingMixin.polygons_onto_image(img=img, polygons=roi_dict[Keys.ROI_POLYGONS.value], circle_size=circle_size, show_center=show_center, show_tags=show_tags) + def roi_dict_onto_img( + img: np.ndarray, + roi_dict: Dict[str, pd.DataFrame], + circle_size: Optional[int] = 2, + show_center: Optional[bool] = False, + show_tags: Optional[bool] = False, + ) -> np.ndarray: + + check_valid_array( + data=img, source=f"{PlottingMixin.roi_dict_onto_img.__name__} img" + ) + check_if_keys_exist_in_dict( + data=roi_dict, + key=[ + Keys.ROI_POLYGONS.value, + Keys.ROI_CIRCLES.value, + Keys.ROI_RECTANGLES.value, + ], + name=PlottingMixin.roi_dict_onto_img.__name__, + ) + img = PlottingMixin.rectangles_onto_image( + img=img, + rectangles=roi_dict[Keys.ROI_RECTANGLES.value], + circle_size=circle_size, + show_center=show_center, + show_tags=show_tags, + ) + img = PlottingMixin.circles_onto_image( + img=img, + circles=roi_dict[Keys.ROI_CIRCLES.value], + circle_size=circle_size, + show_center=show_center, + show_tags=show_tags, + ) + img = PlottingMixin.polygons_onto_image( + img=img, + polygons=roi_dict[Keys.ROI_POLYGONS.value], + circle_size=circle_size, + show_center=show_center, + show_tags=show_tags, + ) return img @staticmethod - def insert_directing_line(directing_df: pd.DataFrame, - img: np.ndarray, - shape_name: str, - animal_name: str, - frame_id: int, - color: Optional[Tuple[int]] = (0, 0, 255), - thickness: Optional[int] = 2, - style: Optional[str] = 'lines') -> np.ndarray: - + def insert_directing_line( + directing_df: pd.DataFrame, + img: np.ndarray, + shape_name: str, + animal_name: str, + frame_id: int, + color: Optional[Tuple[int]] = (0, 0, 255), + thickness: Optional[int] = 2, + style: Optional[str] = "lines", + ) -> np.ndarray: """ Helper to insert lines between the actor 'eye' and the ROI centers. @@ -1779,26 +1906,58 @@ def insert_directing_line(directing_df: pd.DataFrame, :return np.ndarray: The input image with the line. """ - check_valid_array(data=img, source=PlottingMixin.insert_directing_line.__name__) - check_valid_dataframe(df=directing_df, source=PlottingMixin.rectangles_onto_image.__name__, required_fields=['ROI', 'Animal', 'Frame', 'ROI_edge_1_x', 'ROI_edge_1_y', 'ROI_edge_2_x', 'ROI_edge_2_y']) - r = directing_df.loc[(directing_df["ROI"] == shape_name) & (directing_df["Animal"] == animal_name) & (directing_df["Frame"] == frame_id)].reset_index(drop=True) - if style == 'funnel': - convex_hull_arr = (np.array([[r["ROI_edge_1_x"], r["ROI_edge_1_y"]], [r["ROI_edge_2_x"], r["ROI_edge_2_y"]], [r["Eye_x"], r["Eye_y"]]]).reshape(-1, 2).astype(int)) + check_valid_dataframe( + df=directing_df, + source=PlottingMixin.rectangles_onto_image.__name__, + required_fields=[ + "ROI", + "Animal", + "Frame", + "ROI_edge_1_x", + "ROI_edge_1_y", + "ROI_edge_2_x", + "ROI_edge_2_y", + ], + ) + r = directing_df.loc[ + (directing_df["ROI"] == shape_name) + & (directing_df["Animal"] == animal_name) + & (directing_df["Frame"] == frame_id) + ].reset_index(drop=True) + if style == "funnel": + convex_hull_arr = ( + np.array( + [ + [r["ROI_edge_1_x"], r["ROI_edge_1_y"]], + [r["ROI_edge_2_x"], r["ROI_edge_2_y"]], + [r["Eye_x"], r["Eye_y"]], + ] + ) + .reshape(-1, 2) + .astype(int) + ) img = cv2.fillPoly(img, [convex_hull_arr], color) else: - img = cv2.line(img, (int(r["Eye_x"]), int(r["Eye_y"])), (int(r["ROI_x"]), int(r["ROI_y"])), color, thickness) + img = cv2.line( + img, + (int(r["Eye_x"]), int(r["Eye_y"])), + (int(r["ROI_x"]), int(r["ROI_y"])), + color, + thickness, + ) return img - @staticmethod - def draw_lines_on_img(img: np.ndarray, - start_positions: np.ndarray, - end_positions: np.ndarray, - color: Tuple[int, int, int], - highlight_endpoint: Optional[bool] = False, - thickness: Optional[int] = 2, - circle_size: Optional[int] = 2) -> np.ndarray: + def draw_lines_on_img( + img: np.ndarray, + start_positions: np.ndarray, + end_positions: np.ndarray, + color: Tuple[int, int, int], + highlight_endpoint: Optional[bool] = False, + thickness: Optional[int] = 2, + circle_size: Optional[int] = 2, + ) -> np.ndarray: """ Helper to draw a set of lines onto an image. @@ -1813,14 +1972,41 @@ def draw_lines_on_img(img: np.ndarray, :return np.ndarray: The image with the lines overlayed. """ - check_valid_array(data=start_positions, source=f'{PlottingMixin.draw_lines_on_img.__name__} img') - check_valid_array(data=start_positions, source=f'{PlottingMixin.draw_lines_on_img.__name__} start_positions', accepted_ndims=(2,), accepted_dtypes=(np.int64,), min_axis_0=1) - check_valid_array(data=end_positions, source=f'{PlottingMixin.draw_lines_on_img.__name__} end_positions', accepted_shapes=[(start_positions.shape[0], 2),]) + check_valid_array( + data=start_positions, + source=f"{PlottingMixin.draw_lines_on_img.__name__} img", + ) + check_valid_array( + data=start_positions, + source=f"{PlottingMixin.draw_lines_on_img.__name__} start_positions", + accepted_ndims=(2,), + accepted_dtypes=(np.int64,), + min_axis_0=1, + ) + check_valid_array( + data=end_positions, + source=f"{PlottingMixin.draw_lines_on_img.__name__} end_positions", + accepted_shapes=[ + (start_positions.shape[0], 2), + ], + ) check_if_valid_rgb_tuple(data=color) for i in range(start_positions.shape[0]): - cv2.line(img, (start_positions[i][0], start_positions[i][1]), (end_positions[i][0], end_positions[i][1]), color, thickness) + cv2.line( + img, + (start_positions[i][0], start_positions[i][1]), + (end_positions[i][0], end_positions[i][1]), + color, + thickness, + ) if highlight_endpoint: - cv2.circle(img, (end_positions[i][0], end_positions[i][1]), circle_size, color, -1) + cv2.circle( + img, + (end_positions[i][0], end_positions[i][1]), + circle_size, + color, + -1, + ) return img diff --git a/simba/plotting/ROI_feature_visualizer.py b/simba/plotting/ROI_feature_visualizer.py index 02b37e603..f8e69eb53 100644 --- a/simba/plotting/ROI_feature_visualizer.py +++ b/simba/plotting/ROI_feature_visualizer.py @@ -2,31 +2,45 @@ import itertools import os +from typing import Any, Dict, List, Union import cv2 import numpy as np -from typing import Union, Dict, Any, List from simba.mixins.config_reader import ConfigReader from simba.mixins.plotting_mixin import PlottingMixin from simba.roi_tools.ROI_feature_analyzer import ROIFeatureCreator -from simba.utils.checks import check_file_exist_and_readable, check_if_keys_exist_in_dict, check_valid_lst, check_video_and_data_frm_count_align, check_valid_array, check_valid_dataframe, check_int -from simba.utils.enums import Formats, TextOptions, Keys +from simba.utils.checks import (check_file_exist_and_readable, + check_if_keys_exist_in_dict, check_int, + check_valid_array, check_valid_dataframe, + check_valid_lst, + check_video_and_data_frm_count_align) +from simba.utils.data import slice_roi_dict_for_video +from simba.utils.enums import Formats, Keys, TextOptions +from simba.utils.errors import (BodypartColumnNotFoundError, NoFilesFoundError, + ROICoordinatesNotFoundError) from simba.utils.printing import stdout_success from simba.utils.read_write import get_fn_ext, get_video_meta_data, read_df -from simba.utils.errors import ROICoordinatesNotFoundError, NoFilesFoundError, BodypartColumnNotFoundError -from simba.utils.data import slice_roi_dict_for_video from simba.utils.warnings import DuplicateNamesWarning -ROI_CENTERS = 'roi_centers' -ROI_EAR_TAGS = 'roi_ear_tags' -DIRECTIONALITY = 'directionality' -DIRECTIONALITY_STYLE = 'directionality_style' -BORDER_COLOR = 'border_color' -POSE = 'pose_estimation' -ANIMAL_NAMES = 'animal_names' +ROI_CENTERS = "roi_centers" +ROI_EAR_TAGS = "roi_ear_tags" +DIRECTIONALITY = "directionality" +DIRECTIONALITY_STYLE = "directionality_style" +BORDER_COLOR = "border_color" +POSE = "pose_estimation" +ANIMAL_NAMES = "animal_names" + +STYLE_KEYS = [ + ROI_CENTERS, + ROI_EAR_TAGS, + DIRECTIONALITY, + BORDER_COLOR, + POSE, + DIRECTIONALITY_STYLE, + ANIMAL_NAMES, +] -STYLE_KEYS = [ROI_CENTERS, ROI_EAR_TAGS, DIRECTIONALITY, BORDER_COLOR, POSE, DIRECTIONALITY_STYLE, ANIMAL_NAMES] class ROIfeatureVisualizer(ConfigReader): """ @@ -60,44 +74,87 @@ class ROIfeatureVisualizer(ConfigReader): >>> test.run() """ - def __init__(self, - config_path: Union[str, os.PathLike], - video_path: Union[str, os.PathLike], - body_parts: List[str], - style_attr: Dict[str, Any]): + def __init__( + self, + config_path: Union[str, os.PathLike], + video_path: Union[str, os.PathLike], + body_parts: List[str], + style_attr: Dict[str, Any], + ): check_file_exist_and_readable(file_path=config_path) check_file_exist_and_readable(file_path=video_path) - check_if_keys_exist_in_dict(data=style_attr, key=STYLE_KEYS, name=f'{self.__class__.__name__} style_attr') + check_if_keys_exist_in_dict( + data=style_attr, + key=STYLE_KEYS, + name=f"{self.__class__.__name__} style_attr", + ) _, self.video_name, _ = get_fn_ext(video_path) ConfigReader.__init__(self, config_path=config_path) if not os.path.isfile(self.roi_coordinates_path): - raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path) + raise ROICoordinatesNotFoundError( + expected_file_path=self.roi_coordinates_path + ) self.read_roi_data() - self.roi_dict, shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name) - self.data_path = os.path.join(self.outlier_corrected_dir, f"{self.video_name}.{self.file_type}") + self.roi_dict, shape_names = slice_roi_dict_for_video( + data=self.roi_dict, video_name=self.video_name + ) + self.data_path = os.path.join( + self.outlier_corrected_dir, f"{self.video_name}.{self.file_type}" + ) if not os.path.isfile(self.data_path): - raise NoFilesFoundError( msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create ROI visualizations", source=self.__class__.__name__) + raise NoFilesFoundError( + msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create ROI visualizations", + source=self.__class__.__name__, + ) if not os.path.exists(self.roi_features_save_dir): os.makedirs(self.roi_features_save_dir) - self.save_path = os.path.join(self.roi_features_save_dir, f"{self.video_name}.mp4") - check_valid_lst(data=body_parts, source=f'{self.__class__.__name__} body-parts', valid_dtypes=(str,), min_len=1) + self.save_path = os.path.join( + self.roi_features_save_dir, f"{self.video_name}.mp4" + ) + check_valid_lst( + data=body_parts, + source=f"{self.__class__.__name__} body-parts", + valid_dtypes=(str,), + min_len=1, + ) for bp in body_parts: if bp not in self.body_parts_lst: - raise BodypartColumnNotFoundError(msg=f'The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}', source=self.__class__.__name__) - self.roi_feature_creator = ROIFeatureCreator(config_path=config_path, body_parts=body_parts, append_data=False, data_path=self.data_path) + raise BodypartColumnNotFoundError( + msg=f"The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}", + source=self.__class__.__name__, + ) + self.roi_feature_creator = ROIFeatureCreator( + config_path=config_path, + body_parts=body_parts, + append_data=False, + data_path=self.data_path, + ) self.roi_feature_creator.run() self.bp_lk = self.roi_feature_creator.bp_lk - self.animal_bp_names = [f'{v[0]} {v[1]}' for v in self.bp_lk.values()] + self.animal_bp_names = [f"{v[0]} {v[1]}" for v in self.bp_lk.values()] self.animal_names = [v[0] for v in self.bp_lk.values()] self.video_meta_data = get_video_meta_data(video_path) self.fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value) self.cap = cv2.VideoCapture(video_path) - self.max_dim = max(self.video_meta_data["width"], self.video_meta_data["height"]) - self.circle_size = int(TextOptions.RADIUS_SCALER.value / (TextOptions.RESOLUTION_SCALER.value / self.max_dim)) - self.font_size = float(TextOptions.FONT_SCALER.value / (TextOptions.RESOLUTION_SCALER.value / self.max_dim)) - self.spacing_scale = int(TextOptions.SPACE_SCALER.value / (TextOptions.RESOLUTION_SCALER.value / self.max_dim)) - check_video_and_data_frm_count_align(video=video_path, data=self.data_path, name=video_path, raise_error=False) + self.max_dim = max( + self.video_meta_data["width"], self.video_meta_data["height"] + ) + self.circle_size = int( + TextOptions.RADIUS_SCALER.value + / (TextOptions.RESOLUTION_SCALER.value / self.max_dim) + ) + self.font_size = float( + TextOptions.FONT_SCALER.value + / (TextOptions.RESOLUTION_SCALER.value / self.max_dim) + ) + self.spacing_scale = int( + TextOptions.SPACE_SCALER.value + / (TextOptions.RESOLUTION_SCALER.value / self.max_dim) + ) + check_video_and_data_frm_count_align( + video=video_path, data=self.data_path, name=video_path, raise_error=False + ) self.style_attr = style_attr self.direct_viable = self.roi_feature_creator.roi_directing_viable self.data_df = read_df(file_path=self.data_path, file_type=self.file_type) @@ -109,24 +166,72 @@ def __calc_text_locs(self): self.loc_dict = {} for animal_cnt, animal_data in self.bp_lk.items(): animal, animal_bp, _ = animal_data - animal_name = f'{animal} {animal_bp}' + animal_name = f"{animal} {animal_bp}" self.loc_dict[animal_name] = {} self.loc_dict[animal] = {} for shape in self.shape_names: self.loc_dict[animal_name][shape] = {} - self.loc_dict[animal_name][shape]["in_zone_text"] = f"{shape} {animal_name} in zone" - self.loc_dict[animal_name][shape]["distance_text"] = f"{shape} {animal_name} distance" - self.loc_dict[animal_name][shape]["in_zone_text_loc"] = ((self.video_meta_data["width"] + 5), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + self.spacing_scale * add_spacer)) - self.loc_dict[animal_name][shape]["in_zone_data_loc"] = (int(self.img_w_border_w - (self.img_w_border_w / 8)), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + self.spacing_scale * add_spacer)) + self.loc_dict[animal_name][shape][ + "in_zone_text" + ] = f"{shape} {animal_name} in zone" + self.loc_dict[animal_name][shape][ + "distance_text" + ] = f"{shape} {animal_name} distance" + self.loc_dict[animal_name][shape]["in_zone_text_loc"] = ( + (self.video_meta_data["width"] + 5), + ( + self.video_meta_data["height"] + - (self.video_meta_data["height"] + 10) + + self.spacing_scale * add_spacer + ), + ) + self.loc_dict[animal_name][shape]["in_zone_data_loc"] = ( + int(self.img_w_border_w - (self.img_w_border_w / 8)), + ( + self.video_meta_data["height"] + - (self.video_meta_data["height"] + 10) + + self.spacing_scale * add_spacer + ), + ) add_spacer += 1 - self.loc_dict[animal_name][shape]["distance_text_loc"] = ((self.video_meta_data["width"] + 5), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + self.spacing_scale * add_spacer)) - self.loc_dict[animal_name][shape]["distance_data_loc"] = (int(self.img_w_border_w - (self.img_w_border_w / 8)), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + self.spacing_scale * add_spacer)) + self.loc_dict[animal_name][shape]["distance_text_loc"] = ( + (self.video_meta_data["width"] + 5), + ( + self.video_meta_data["height"] + - (self.video_meta_data["height"] + 10) + + self.spacing_scale * add_spacer + ), + ) + self.loc_dict[animal_name][shape]["distance_data_loc"] = ( + int(self.img_w_border_w - (self.img_w_border_w / 8)), + ( + self.video_meta_data["height"] + - (self.video_meta_data["height"] + 10) + + self.spacing_scale * add_spacer + ), + ) add_spacer += 1 if self.direct_viable and self.style_attr[DIRECTIONALITY]: self.loc_dict[animal][shape] = {} - self.loc_dict[animal][shape]["directing_text"] = f"{shape} {animal} facing" - self.loc_dict[animal][shape]["directing_text_loc"] = ((self.video_meta_data["width"] + 5), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + self.spacing_scale * add_spacer)) - self.loc_dict[animal][shape]["directing_data_loc"] = (int(self.img_w_border_w - (self.img_w_border_w / 8)), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + self.spacing_scale * add_spacer)) + self.loc_dict[animal][shape][ + "directing_text" + ] = f"{shape} {animal} facing" + self.loc_dict[animal][shape]["directing_text_loc"] = ( + (self.video_meta_data["width"] + 5), + ( + self.video_meta_data["height"] + - (self.video_meta_data["height"] + 10) + + self.spacing_scale * add_spacer + ), + ) + self.loc_dict[animal][shape]["directing_data_loc"] = ( + int(self.img_w_border_w - (self.img_w_border_w / 8)), + ( + self.video_meta_data["height"] + - (self.video_meta_data["height"] + 10) + + self.spacing_scale * add_spacer + ), + ) add_spacer += 1 def __create_shape_dicts(self): @@ -134,7 +239,10 @@ def __create_shape_dicts(self): for shape, df in self.roi_dict.items(): if not df["Name"].is_unique: df = df.drop_duplicates(subset=["Name"], keep="first") - DuplicateNamesWarning(msg=f'Some of your ROIs with the same shape ({shape}) has the same names for video {self.video_name}. E.g., you have two rectangles named "My rectangle". SimBA prefers ROI shapes with unique names. SimBA will keep one of the unique shape names and drop the rest.', source=self.__class__.__name__) + DuplicateNamesWarning( + msg=f'Some of your ROIs with the same shape ({shape}) has the same names for video {self.video_name}. E.g., you have two rectangles named "My rectangle". SimBA prefers ROI shapes with unique names. SimBA will keep one of the unique shape names and drop the rest.', + source=self.__class__.__name__, + ) d = df.set_index("Name").to_dict(orient="index") shape_dicts = {**shape_dicts, **d} return shape_dicts @@ -142,77 +250,187 @@ def __create_shape_dicts(self): def __insert_texts(self, shape_df): for cnt, animal_data in self.bp_lk.items(): animal, animal_bp, _ = animal_data - animal_name = f'{animal} {animal_bp}' + animal_name = f"{animal} {animal_bp}" for _, shape in shape_df.iterrows(): shape_name, shape_color = shape["Name"], shape["Color BGR"] - cv2.putText(self.img_w_border, self.loc_dict[animal_name][shape_name]["in_zone_text"], self.loc_dict[animal_name][shape_name]["in_zone_text_loc"], self.font, self.font_size, shape_color, 1) - cv2.putText(self.img_w_border, self.loc_dict[animal_name][shape_name]["distance_text"], self.loc_dict[animal_name][shape_name]["distance_text_loc"], self.font, self.font_size, shape_color, 1) + cv2.putText( + self.img_w_border, + self.loc_dict[animal_name][shape_name]["in_zone_text"], + self.loc_dict[animal_name][shape_name]["in_zone_text_loc"], + self.font, + self.font_size, + shape_color, + 1, + ) + cv2.putText( + self.img_w_border, + self.loc_dict[animal_name][shape_name]["distance_text"], + self.loc_dict[animal_name][shape_name]["distance_text_loc"], + self.font, + self.font_size, + shape_color, + 1, + ) if self.direct_viable: - cv2.putText(self.img_w_border, self.loc_dict[animal][shape_name]["directing_text"], self.loc_dict[animal][shape_name]["directing_text_loc"], self.font, self.font_size, shape_color, 1,) + cv2.putText( + self.img_w_border, + self.loc_dict[animal][shape_name]["directing_text"], + self.loc_dict[animal][shape_name]["directing_text_loc"], + self.font, + self.font_size, + shape_color, + 1, + ) def run(self): self.frame_cnt = 0 while self.cap.isOpened(): ret, self.img = self.cap.read() if ret: - self.img_w_border = cv2.copyMakeBorder(self.img, 0, 0, 0, self.video_meta_data["width"], borderType=cv2.BORDER_CONSTANT, value=self.style_attr[BORDER_COLOR]) + self.img_w_border = cv2.copyMakeBorder( + self.img, + 0, + 0, + 0, + self.video_meta_data["width"], + borderType=cv2.BORDER_CONSTANT, + value=self.style_attr[BORDER_COLOR], + ) if self.frame_cnt == 0: - self.img_w_border_h, self.img_w_border_w = (self.img_w_border.shape[0], self.img_w_border.shape[1]) + self.img_w_border_h, self.img_w_border_w = ( + self.img_w_border.shape[0], + self.img_w_border.shape[1], + ) self.__calc_text_locs() - self.writer = cv2.VideoWriter(self.save_path, self.fourcc, self.video_meta_data["fps"], (self.img_w_border_w, self.img_w_border_h)) + self.writer = cv2.VideoWriter( + self.save_path, + self.fourcc, + self.video_meta_data["fps"], + (self.img_w_border_w, self.img_w_border_h), + ) self.__insert_texts(self.roi_dict[Keys.ROI_RECTANGLES.value]) self.__insert_texts(self.roi_dict[Keys.ROI_CIRCLES.value]) self.__insert_texts(self.roi_dict[Keys.ROI_POLYGONS.value]) if self.style_attr[POSE]: for animal_name, bp_data in self.animal_bp_dict.items(): - for bp_cnt, bp in enumerate(zip(bp_data['X_bps'], bp_data['Y_bps'])): - bp_cords = self.data_df.loc[self.frame_cnt, list(bp)].values.astype(np.int64) - cv2.circle(self.img_w_border, (bp_cords[0], bp_cords[1]), 0, self.animal_bp_dict[animal_name]["colors"][bp_cnt], self.circle_size) + for bp_cnt, bp in enumerate( + zip(bp_data["X_bps"], bp_data["Y_bps"]) + ): + bp_cords = self.data_df.loc[ + self.frame_cnt, list(bp) + ].values.astype(np.int64) + cv2.circle( + self.img_w_border, + (bp_cords[0], bp_cords[1]), + 0, + self.animal_bp_dict[animal_name]["colors"][bp_cnt], + self.circle_size, + ) if self.style_attr[ANIMAL_NAMES]: for animal_name, bp_data in self.animal_bp_dict.items(): - headers = [bp_data['X_bps'][-1], bp_data['Y_bps'][-1]] - bp_cords = self.data_df.loc[self.frame_cnt, headers].values.astype(np.int64) - cv2.putText(self.img_w_border, animal_name, (bp_cords[0], bp_cords[1]), self.font, self.font_size, self.animal_bp_dict[animal_name]["colors"][0], 1) - - self.img_w_border = PlottingMixin.roi_dict_onto_img(img=self.img_w_border, - roi_dict=self.roi_dict, - circle_size=self.circle_size, - show_tags=self.style_attr[ROI_EAR_TAGS], - show_center=self.style_attr[ROI_CENTERS]) - - for animal_name, shape_name in itertools.product(self.animal_bp_names, self.shape_names): + headers = [bp_data["X_bps"][-1], bp_data["Y_bps"][-1]] + bp_cords = self.data_df.loc[ + self.frame_cnt, headers + ].values.astype(np.int64) + cv2.putText( + self.img_w_border, + animal_name, + (bp_cords[0], bp_cords[1]), + self.font, + self.font_size, + self.animal_bp_dict[animal_name]["colors"][0], + 1, + ) + + self.img_w_border = PlottingMixin.roi_dict_onto_img( + img=self.img_w_border, + roi_dict=self.roi_dict, + circle_size=self.circle_size, + show_tags=self.style_attr[ROI_EAR_TAGS], + show_center=self.style_attr[ROI_CENTERS], + ) + + for animal_name, shape_name in itertools.product( + self.animal_bp_names, self.shape_names + ): in_zone_col_name = f"{shape_name} {animal_name} in zone" distance_col_name = f"{shape_name} {animal_name} distance" - in_zone_value = str(bool(self.roi_feature_creator.out_df.loc[self.frame_cnt, in_zone_col_name])) - distance_value = round(self.roi_feature_creator.out_df.loc[self.frame_cnt, distance_col_name], 2) - cv2.putText(self.img_w_border, in_zone_value, self.loc_dict[animal_name][shape_name]["in_zone_data_loc"], self.font, self.font_size, self.shape_dicts[shape_name]["Color BGR"], 1) - cv2.putText(self.img_w_border, str(distance_value), self.loc_dict[animal_name][shape_name]["distance_data_loc"], self.font, self.font_size, self.shape_dicts[shape_name]["Color BGR"], 1) + in_zone_value = str( + bool( + self.roi_feature_creator.out_df.loc[ + self.frame_cnt, in_zone_col_name + ] + ) + ) + distance_value = round( + self.roi_feature_creator.out_df.loc[ + self.frame_cnt, distance_col_name + ], + 2, + ) + cv2.putText( + self.img_w_border, + in_zone_value, + self.loc_dict[animal_name][shape_name]["in_zone_data_loc"], + self.font, + self.font_size, + self.shape_dicts[shape_name]["Color BGR"], + 1, + ) + cv2.putText( + self.img_w_border, + str(distance_value), + self.loc_dict[animal_name][shape_name]["distance_data_loc"], + self.font, + self.font_size, + self.shape_dicts[shape_name]["Color BGR"], + 1, + ) if self.direct_viable and self.style_attr[DIRECTIONALITY]: - for animal_name, shape_name in itertools.product(self.animal_names, self.shape_names): + for animal_name, shape_name in itertools.product( + self.animal_names, self.shape_names + ): facing_col_name = f"{shape_name} {animal_name} facing" - facing_value = self.roi_feature_creator.out_df.loc[self.frame_cnt, facing_col_name] - cv2.putText(self.img_w_border, str(bool(facing_value)), self.loc_dict[animal_name][shape_name]["directing_data_loc"], self.font, self.font_size, self.shape_dicts[shape_name]["Color BGR"], 1) + facing_value = self.roi_feature_creator.out_df.loc[ + self.frame_cnt, facing_col_name + ] + cv2.putText( + self.img_w_border, + str(bool(facing_value)), + self.loc_dict[animal_name][shape_name][ + "directing_data_loc" + ], + self.font, + self.font_size, + self.shape_dicts[shape_name]["Color BGR"], + 1, + ) if facing_value: - self.img_w_border = PlottingMixin.insert_directing_line(directing_df=self.directing_df, - img=self.img_w_border, - shape_name=shape_name, - animal_name=animal_name, - frame_id=self.frame_cnt, - color=self.shape_dicts[shape_name]['Color BGR'], - thickness=self.shape_dicts[shape_name]['Thickness'], - style=self.style_attr[DIRECTIONALITY_STYLE]) + self.img_w_border = PlottingMixin.insert_directing_line( + directing_df=self.directing_df, + img=self.img_w_border, + shape_name=shape_name, + animal_name=animal_name, + frame_id=self.frame_cnt, + color=self.shape_dicts[shape_name]["Color BGR"], + thickness=self.shape_dicts[shape_name]["Thickness"], + style=self.style_attr[DIRECTIONALITY_STYLE], + ) self.frame_cnt += 1 self.writer.write(np.uint8(self.img_w_border)) - print(f"Frame: {self.frame_cnt} / {self.video_meta_data['frame_count']}. Video: {self.video_name} ...") + print( + f"Frame: {self.frame_cnt} / {self.video_meta_data['frame_count']}. Video: {self.video_name} ..." + ) else: break self.timer.stop_timer() self.cap.release() self.writer.release() - stdout_success(msg=f"Feature video {self.video_name} saved in {self.save_path} directory ...", elapsed_time=self.timer.elapsed_time_str) - - + stdout_success( + msg=f"Feature video {self.video_name} saved in {self.save_path} directory ...", + elapsed_time=self.timer.elapsed_time_str, + ) # style_attr = {'roi_centers': True, 'roi_ear_tags': True, 'directionality': True, 'directionality_style': 'lines', 'border_color': (0, 0, 0), 'pose_estimation': True, 'animal_names': True} @@ -223,8 +441,6 @@ def run(self): # test.run() - - # style_attr = {'roi_centers': True, 'roi_ear_tags': True, 'directionality': True, 'directionality_style': 'funnel', 'border_color': (0, 128, 0), 'pose_estimation': True, 'animal_names': True} # test = ROIfeatureVisualizer(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', # video_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/videos/Together_1.avi', diff --git a/simba/plotting/ROI_feature_visualizer_mp.py b/simba/plotting/ROI_feature_visualizer_mp.py index 7eee20a1d..d7aa81bf5 100644 --- a/simba/plotting/ROI_feature_visualizer_mp.py +++ b/simba/plotting/ROI_feature_visualizer_mp.py @@ -5,134 +5,233 @@ import multiprocessing import os import platform -import pandas as pd -import numpy as np -from typing import Any, Dict, Optional, Union, List +from typing import Any, Dict, List, Optional, Union import cv2 +import numpy as np +import pandas as pd from simba.mixins.config_reader import ConfigReader from simba.mixins.plotting_mixin import PlottingMixin from simba.roi_tools.ROI_feature_analyzer import ROIFeatureCreator from simba.utils.checks import (check_file_exist_and_readable, - check_int, - check_if_keys_exist_in_dict, - check_valid_lst, check_video_and_data_frm_count_align) -from simba.utils.enums import TextOptions, Formats + check_if_keys_exist_in_dict, check_int, + check_valid_lst, + check_video_and_data_frm_count_align) +from simba.utils.data import slice_roi_dict_for_video +from simba.utils.enums import Formats, TextOptions +from simba.utils.errors import (BodypartColumnNotFoundError, NoFilesFoundError, + ROICoordinatesNotFoundError) from simba.utils.printing import stdout_success from simba.utils.read_write import (concatenate_videos_in_folder, find_core_cnt, get_fn_ext, get_video_meta_data, read_df, remove_a_folder) -from simba.utils.data import slice_roi_dict_for_video -from simba.utils.errors import ROICoordinatesNotFoundError, NoFilesFoundError, BodypartColumnNotFoundError from simba.utils.warnings import DuplicateNamesWarning -ROI_CENTERS = 'roi_centers' -ROI_EAR_TAGS = 'roi_ear_tags' -DIRECTIONALITY = 'directionality' -DIRECTIONALITY_STYLE = 'directionality_style' -BORDER_COLOR = 'border_color' -POSE = 'pose_estimation' -ANIMAL_NAMES = 'animal_names' -STYLE_KEYS = [ROI_CENTERS, ROI_EAR_TAGS, DIRECTIONALITY, BORDER_COLOR, POSE, DIRECTIONALITY_STYLE, ANIMAL_NAMES] - - - -def _roi_feature_visualizer_mp(data: pd.DataFrame, - text_locations: dict, - font_size: float, - circle_size: float, - save_temp_dir: str, - video_meta_data: dict, - shape_info: dict, - shape_names: list, - style_attr: dict, - video_path: str, - animal_names: list, - roi_dict: dict, - bp_lk: dict, - animal_bps: dict, - animal_bp_names: List[str], - animal_bp_dict: dict, - roi_features_df: pd.DataFrame, - directing_data: Union[pd.DataFrame, None]): +ROI_CENTERS = "roi_centers" +ROI_EAR_TAGS = "roi_ear_tags" +DIRECTIONALITY = "directionality" +DIRECTIONALITY_STYLE = "directionality_style" +BORDER_COLOR = "border_color" +POSE = "pose_estimation" +ANIMAL_NAMES = "animal_names" +STYLE_KEYS = [ + ROI_CENTERS, + ROI_EAR_TAGS, + DIRECTIONALITY, + BORDER_COLOR, + POSE, + DIRECTIONALITY_STYLE, + ANIMAL_NAMES, +] + + +def _roi_feature_visualizer_mp( + data: pd.DataFrame, + text_locations: dict, + font_size: float, + circle_size: float, + save_temp_dir: str, + video_meta_data: dict, + shape_info: dict, + shape_names: list, + style_attr: dict, + video_path: str, + animal_names: list, + roi_dict: dict, + bp_lk: dict, + animal_bps: dict, + animal_bp_names: List[str], + animal_bp_dict: dict, + roi_features_df: pd.DataFrame, + directing_data: Union[pd.DataFrame, None], +): def __insert_texts(shape_info: dict, img: np.ndarray): for shape_name, shape_info in shape_info.items(): shape_color = shape_info["Color BGR"] for cnt, animal_data in bp_lk.items(): animal, animal_bp, _ = animal_data - animal_name = f'{animal} {animal_bp}' - cv2.putText(img,text_locations[animal_name][shape_name]["in_zone_text"],text_locations[animal_name][shape_name]["in_zone_text_loc"],font, font_size ,shape_color,1) - cv2.putText(img, text_locations[animal_name][shape_name]["distance_text"], text_locations[animal_name][shape_name]["distance_text_loc"], font, font_size, shape_color, 1) + animal_name = f"{animal} {animal_bp}" + cv2.putText( + img, + text_locations[animal_name][shape_name]["in_zone_text"], + text_locations[animal_name][shape_name]["in_zone_text_loc"], + font, + font_size, + shape_color, + 1, + ) + cv2.putText( + img, + text_locations[animal_name][shape_name]["distance_text"], + text_locations[animal_name][shape_name]["distance_text_loc"], + font, + font_size, + shape_color, + 1, + ) if directing_data is not None and style_attr[DIRECTIONALITY]: - cv2.putText(img, text_locations[animal][shape_name]["directing_text"], text_locations[animal][shape_name][ "directing_text_loc"], font, font_size, shape_color, 1) + cv2.putText( + img, + text_locations[animal][shape_name]["directing_text"], + text_locations[animal][shape_name]["directing_text_loc"], + font, + font_size, + shape_color, + 1, + ) return img - fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value) font = cv2.FONT_HERSHEY_COMPLEX group_cnt = int(data["group"].values[0]) start_frm, current_frm, end_frm = data.index[0], data.index[0], data.index[-1] save_path = os.path.join(save_temp_dir, f"{group_cnt}.mp4") - writer = cv2.VideoWriter(save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"] * 2, video_meta_data["height"])) + writer = cv2.VideoWriter( + save_path, + fourcc, + video_meta_data["fps"], + (video_meta_data["width"] * 2, video_meta_data["height"]), + ) cap = cv2.VideoCapture(video_path) cap.set(1, start_frm) while current_frm <= end_frm: ret, img = cap.read() if ret: - img = cv2.copyMakeBorder(img, 0, 0, 0, int(video_meta_data["width"]), borderType=cv2.BORDER_CONSTANT, value=style_attr[BORDER_COLOR]) + img = cv2.copyMakeBorder( + img, + 0, + 0, + 0, + int(video_meta_data["width"]), + borderType=cv2.BORDER_CONSTANT, + value=style_attr[BORDER_COLOR], + ) img = __insert_texts(shape_info=shape_info, img=img) if style_attr[POSE]: for animal_name, bp_data in animal_bp_dict.items(): - for bp_cnt, bp in enumerate(zip(bp_data['X_bps'], bp_data['Y_bps'])): - bp_cords = data.loc[current_frm, list(bp)].values.astype(np.int64) - cv2.circle(img, (bp_cords[0], bp_cords[1]), 0, animal_bp_dict[animal_name]["colors"][bp_cnt], circle_size) + for bp_cnt, bp in enumerate( + zip(bp_data["X_bps"], bp_data["Y_bps"]) + ): + bp_cords = data.loc[current_frm, list(bp)].values.astype( + np.int64 + ) + cv2.circle( + img, + (bp_cords[0], bp_cords[1]), + 0, + animal_bp_dict[animal_name]["colors"][bp_cnt], + circle_size, + ) if style_attr[ANIMAL_NAMES]: for animal_name, bp_data in animal_bp_dict.items(): - headers = [bp_data['X_bps'][-1], bp_data['Y_bps'][-1]] + headers = [bp_data["X_bps"][-1], bp_data["Y_bps"][-1]] bp_cords = data.loc[current_frm, headers].values.astype(np.int64) - cv2.putText(img, animal_name, (bp_cords[0], bp_cords[1]), font, font_size, animal_bp_dict[animal_name]["colors"][0], 1) - - img = PlottingMixin.roi_dict_onto_img(img=img, - roi_dict=roi_dict, - circle_size=circle_size, - show_tags=style_attr[ROI_EAR_TAGS], - show_center=style_attr[ROI_CENTERS]) + cv2.putText( + img, + animal_name, + (bp_cords[0], bp_cords[1]), + font, + font_size, + animal_bp_dict[animal_name]["colors"][0], + 1, + ) + + img = PlottingMixin.roi_dict_onto_img( + img=img, + roi_dict=roi_dict, + circle_size=circle_size, + show_tags=style_attr[ROI_EAR_TAGS], + show_center=style_attr[ROI_CENTERS], + ) - for animal_name, shape_name in itertools.product(animal_bp_names, shape_names): + for animal_name, shape_name in itertools.product( + animal_bp_names, shape_names + ): in_zone_col_name = f"{shape_name} {animal_name} in zone" distance_col_name = f"{shape_name} {animal_name} distance" - in_zone_value = str(bool(roi_features_df.loc[current_frm, in_zone_col_name])) - distance_value = round(roi_features_df.loc[current_frm, distance_col_name], 2) - cv2.putText(img, in_zone_value, text_locations[animal_name][shape_name]["in_zone_data_loc"], font, font_size, shape_info[shape_name]["Color BGR"], 1) - cv2.putText(img, str(distance_value), text_locations[animal_name][shape_name]["distance_data_loc"], font, font_size, shape_info[shape_name]["Color BGR"], 1) + in_zone_value = str( + bool(roi_features_df.loc[current_frm, in_zone_col_name]) + ) + distance_value = round( + roi_features_df.loc[current_frm, distance_col_name], 2 + ) + cv2.putText( + img, + in_zone_value, + text_locations[animal_name][shape_name]["in_zone_data_loc"], + font, + font_size, + shape_info[shape_name]["Color BGR"], + 1, + ) + cv2.putText( + img, + str(distance_value), + text_locations[animal_name][shape_name]["distance_data_loc"], + font, + font_size, + shape_info[shape_name]["Color BGR"], + 1, + ) if (directing_data is not None) and (style_attr[DIRECTIONALITY]): - for animal_name, shape_name in itertools.product(animal_names, shape_names): + for animal_name, shape_name in itertools.product( + animal_names, shape_names + ): facing_col_name = f"{shape_name} {animal_name} facing" facing_value = roi_features_df.loc[current_frm, facing_col_name] - cv2.putText(img, str(bool(facing_value)), text_locations[animal_name][shape_name]["directing_data_loc"], font, font_size, shape_info[shape_name]["Color BGR"], 1) + cv2.putText( + img, + str(bool(facing_value)), + text_locations[animal_name][shape_name]["directing_data_loc"], + font, + font_size, + shape_info[shape_name]["Color BGR"], + 1, + ) if facing_value: - img = PlottingMixin.insert_directing_line(directing_df=directing_data, - img=img, - shape_name=shape_name, - animal_name=animal_name, - frame_id=current_frm, - color=shape_info[shape_name]['Color BGR'], - thickness=shape_info[shape_name]['Thickness'], - style=style_attr[DIRECTIONALITY_STYLE]) + img = PlottingMixin.insert_directing_line( + directing_df=directing_data, + img=img, + shape_name=shape_name, + animal_name=animal_name, + frame_id=current_frm, + color=shape_info[shape_name]["Color BGR"], + thickness=shape_info[shape_name]["Thickness"], + style=style_attr[DIRECTIONALITY_STYLE], + ) current_frm += 1 writer.write(np.uint8(img)) - print(f"Multiprocessing frame: {current_frm} / {video_meta_data['frame_count']} on core {group_cnt}...") + print( + f"Multiprocessing frame: {current_frm} / {video_meta_data['frame_count']} on core {group_cnt}..." + ) else: break writer.release() return group_cnt - - - class ROIfeatureVisualizerMultiprocess(ConfigReader): """ Visualize features that depend on the relationships between the location of the animals and user-defined @@ -164,54 +263,103 @@ class ROIfeatureVisualizerMultiprocess(ConfigReader): >>> _ = ROIfeatureVisualizerMultiprocess(config_path='test_/project_folder/project_config.ini', video_name='Together_1.avi', style_attr=style_attr, core_cnt=3).run() """ - def __init__(self, - config_path: Union[str, os.PathLike], - video_path: Union[str, os.PathLike], - body_parts: List[str], - style_attr: Dict[str, Any], - core_cnt: Optional[int] = -1): + def __init__( + self, + config_path: Union[str, os.PathLike], + video_path: Union[str, os.PathLike], + body_parts: List[str], + style_attr: Dict[str, Any], + core_cnt: Optional[int] = -1, + ): if platform.system() == "Darwin": multiprocessing.set_start_method("spawn", force=True) - check_int(name=f"{self.__class__.__name__} core_cnt", value=core_cnt, min_value=-1, max_value=find_core_cnt()[0]) - if core_cnt == -1: core_cnt = find_core_cnt()[0] + check_int( + name=f"{self.__class__.__name__} core_cnt", + value=core_cnt, + min_value=-1, + max_value=find_core_cnt()[0], + ) + if core_cnt == -1: + core_cnt = find_core_cnt()[0] check_file_exist_and_readable(file_path=video_path) ConfigReader.__init__(self, config_path=config_path) PlottingMixin.__init__(self) - check_if_keys_exist_in_dict(data=style_attr, key=STYLE_KEYS, name=f'{self.__class__.__name__} style_attr') + check_if_keys_exist_in_dict( + data=style_attr, + key=STYLE_KEYS, + name=f"{self.__class__.__name__} style_attr", + ) if not os.path.isfile(self.roi_coordinates_path): - raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path) + raise ROICoordinatesNotFoundError( + expected_file_path=self.roi_coordinates_path + ) self.read_roi_data() _, self.video_name, _ = get_fn_ext(video_path) - self.roi_dict, self.shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name) + self.roi_dict, self.shape_names = slice_roi_dict_for_video( + data=self.roi_dict, video_name=self.video_name + ) self.core_cnt, self.style_attr = core_cnt, style_attr - self.save_path = os.path.join(self.roi_features_save_dir, f"{self.video_name}.mp4") + self.save_path = os.path.join( + self.roi_features_save_dir, f"{self.video_name}.mp4" + ) if not os.path.exists(self.roi_features_save_dir): os.makedirs(self.roi_features_save_dir) self.save_temp_dir = os.path.join(self.roi_features_save_dir, "temp") if os.path.exists(self.save_temp_dir): remove_a_folder(folder_dir=self.save_temp_dir) os.makedirs(self.save_temp_dir) - self.data_path = os.path.join(self.outlier_corrected_dir, f"{self.video_name}.{self.file_type}") + self.data_path = os.path.join( + self.outlier_corrected_dir, f"{self.video_name}.{self.file_type}" + ) if not os.path.isfile(self.data_path): - raise NoFilesFoundError( msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create ROI visualizations", source=self.__class__.__name__) - check_valid_lst(data=body_parts, source=f'{self.__class__.__name__} body-parts', valid_dtypes=(str,), min_len=1) + raise NoFilesFoundError( + msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create ROI visualizations", + source=self.__class__.__name__, + ) + check_valid_lst( + data=body_parts, + source=f"{self.__class__.__name__} body-parts", + valid_dtypes=(str,), + min_len=1, + ) for bp in body_parts: if bp not in self.body_parts_lst: - raise BodypartColumnNotFoundError(msg=f'The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}', source=self.__class__.__name__) - self.roi_feature_creator = ROIFeatureCreator(config_path=config_path, body_parts=body_parts, append_data=False, data_path=self.data_path) + raise BodypartColumnNotFoundError( + msg=f"The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}", + source=self.__class__.__name__, + ) + self.roi_feature_creator = ROIFeatureCreator( + config_path=config_path, + body_parts=body_parts, + append_data=False, + data_path=self.data_path, + ) self.roi_feature_creator.run() self.bp_lk = self.roi_feature_creator.bp_lk self.animal_names = [v[0] for v in self.bp_lk.values()] - self.animal_bp_names = [f'{v[0]} {v[1]}' for v in self.bp_lk.values()] + self.animal_bp_names = [f"{v[0]} {v[1]}" for v in self.bp_lk.values()] self.video_meta_data = get_video_meta_data(video_path) self.fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value) self.cap = cv2.VideoCapture(video_path) - self.max_dim = max(self.video_meta_data["width"], self.video_meta_data["height"]) - self.circle_size = int(TextOptions.RADIUS_SCALER.value / (TextOptions.RESOLUTION_SCALER.value / self.max_dim)) - self.font_size = float(TextOptions.FONT_SCALER.value / (TextOptions.RESOLUTION_SCALER.value / self.max_dim)) - self.spacing_scale = int(TextOptions.SPACE_SCALER.value / (TextOptions.RESOLUTION_SCALER.value / self.max_dim)) - check_video_and_data_frm_count_align(video=video_path, data=self.data_path, name=video_path, raise_error=False) + self.max_dim = max( + self.video_meta_data["width"], self.video_meta_data["height"] + ) + self.circle_size = int( + TextOptions.RADIUS_SCALER.value + / (TextOptions.RESOLUTION_SCALER.value / self.max_dim) + ) + self.font_size = float( + TextOptions.FONT_SCALER.value + / (TextOptions.RESOLUTION_SCALER.value / self.max_dim) + ) + self.spacing_scale = int( + TextOptions.SPACE_SCALER.value + / (TextOptions.RESOLUTION_SCALER.value / self.max_dim) + ) + check_video_and_data_frm_count_align( + video=video_path, data=self.data_path, name=video_path, raise_error=False + ) self.style_attr = style_attr self.direct_viable = self.roi_feature_creator.roi_directing_viable self.data_df = read_df(file_path=self.data_path, file_type=self.file_type) @@ -225,7 +373,10 @@ def __create_shape_dicts(self): for shape, df in self.roi_dict.items(): if not df["Name"].is_unique: df = df.drop_duplicates(subset=["Name"], keep="first") - DuplicateNamesWarning(msg=f'Some of your ROIs with the same shape ({shape}) has the same names for video {self.video_name}. E.g., you have two rectangles named "My rectangle". SimBA prefers ROI shapes with unique names. SimBA will keep one of the unique shape names and drop the rest.', source=self.__class__.__name__) + DuplicateNamesWarning( + msg=f'Some of your ROIs with the same shape ({shape}) has the same names for video {self.video_name}. E.g., you have two rectangles named "My rectangle". SimBA prefers ROI shapes with unique names. SimBA will keep one of the unique shape names and drop the rest.', + source=self.__class__.__name__, + ) d = df.set_index("Name").to_dict(orient="index") shape_dicts = {**shape_dicts, **d} return shape_dicts @@ -235,61 +386,131 @@ def __calc_text_locs(self): self.loc_dict = {} for animal_cnt, animal_data in self.bp_lk.items(): animal, animal_bp, _ = animal_data - animal_name = f'{animal} {animal_bp}' + animal_name = f"{animal} {animal_bp}" self.loc_dict[animal_name] = {} self.loc_dict[animal] = {} for shape in self.shape_names: self.loc_dict[animal_name][shape] = {} - self.loc_dict[animal_name][shape]["in_zone_text"] = f"{shape} {animal_name} in zone" - self.loc_dict[animal_name][shape]["distance_text"] = f"{shape} {animal_name} distance" - self.loc_dict[animal_name][shape]["in_zone_text_loc"] = ((self.video_meta_data["width"] + 5), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + self.spacing_scale * add_spacer)) - self.loc_dict[animal_name][shape]["in_zone_data_loc"] = (int(self.img_w_border_w - (self.img_w_border_w / 8)), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + self.spacing_scale * add_spacer)) + self.loc_dict[animal_name][shape][ + "in_zone_text" + ] = f"{shape} {animal_name} in zone" + self.loc_dict[animal_name][shape][ + "distance_text" + ] = f"{shape} {animal_name} distance" + self.loc_dict[animal_name][shape]["in_zone_text_loc"] = ( + (self.video_meta_data["width"] + 5), + ( + self.video_meta_data["height"] + - (self.video_meta_data["height"] + 10) + + self.spacing_scale * add_spacer + ), + ) + self.loc_dict[animal_name][shape]["in_zone_data_loc"] = ( + int(self.img_w_border_w - (self.img_w_border_w / 8)), + ( + self.video_meta_data["height"] + - (self.video_meta_data["height"] + 10) + + self.spacing_scale * add_spacer + ), + ) add_spacer += 1 - self.loc_dict[animal_name][shape]["distance_text_loc"] = ((self.video_meta_data["width"] + 5), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + self.spacing_scale * add_spacer)) - self.loc_dict[animal_name][shape]["distance_data_loc"] = (int(self.img_w_border_w - (self.img_w_border_w / 8)), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + self.spacing_scale * add_spacer)) + self.loc_dict[animal_name][shape]["distance_text_loc"] = ( + (self.video_meta_data["width"] + 5), + ( + self.video_meta_data["height"] + - (self.video_meta_data["height"] + 10) + + self.spacing_scale * add_spacer + ), + ) + self.loc_dict[animal_name][shape]["distance_data_loc"] = ( + int(self.img_w_border_w - (self.img_w_border_w / 8)), + ( + self.video_meta_data["height"] + - (self.video_meta_data["height"] + 10) + + self.spacing_scale * add_spacer + ), + ) add_spacer += 1 if self.direct_viable and self.style_attr[DIRECTIONALITY]: self.loc_dict[animal][shape] = {} - self.loc_dict[animal][shape]["directing_text"] = f"{shape} {animal} facing" - self.loc_dict[animal][shape]["directing_text_loc"] = ((self.video_meta_data["width"] + 5), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + self.spacing_scale * add_spacer)) - self.loc_dict[animal][shape]["directing_data_loc"] = (int(self.img_w_border_w - (self.img_w_border_w / 8)), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + self.spacing_scale * add_spacer)) + self.loc_dict[animal][shape][ + "directing_text" + ] = f"{shape} {animal} facing" + self.loc_dict[animal][shape]["directing_text_loc"] = ( + (self.video_meta_data["width"] + 5), + ( + self.video_meta_data["height"] + - (self.video_meta_data["height"] + 10) + + self.spacing_scale * add_spacer + ), + ) + self.loc_dict[animal][shape]["directing_data_loc"] = ( + int(self.img_w_border_w - (self.img_w_border_w / 8)), + ( + self.video_meta_data["height"] + - (self.video_meta_data["height"] + 10) + + self.spacing_scale * add_spacer + ), + ) add_spacer += 1 - def __get_border_img_size(self, video_path: Union[str, os.PathLike]): cap = cv2.VideoCapture(video_path) cap.set(1, 1) _, img = self.cap.read() - bordered_img = cv2.copyMakeBorder(img, 0, 0, 0, int(self.video_meta_data["width"]), borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0]) + bordered_img = cv2.copyMakeBorder( + img, + 0, + 0, + 0, + int(self.video_meta_data["width"]), + borderType=cv2.BORDER_CONSTANT, + value=[0, 0, 0], + ) cap.release() return bordered_img.shape[0], bordered_img.shape[1] def run(self): - self.img_w_border_h, self.img_w_border_w = self.__get_border_img_size(video_path=self.video_path) + self.img_w_border_h, self.img_w_border_w = self.__get_border_img_size( + video_path=self.video_path + ) self.__calc_text_locs() - data_arr, frm_per_core = PlottingMixin.split_and_group_df(self, df=self.roi_feature_creator.out_df, splits=self.core_cnt, include_split_order=True) - print(f"Creating ROI feature images, multiprocessing (chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})...") - with multiprocessing.Pool(self.core_cnt, maxtasksperchild=self.maxtasksperchild) as pool: - constants = functools.partial(_roi_feature_visualizer_mp, - text_locations=self.loc_dict, - font_size=self.font_size, - circle_size=self.circle_size, - video_meta_data=self.video_meta_data, - shape_info=self.shape_dicts, - roi_dict=self.roi_dict, - style_attr=self.style_attr, - save_temp_dir=self.save_temp_dir, - directing_data=self.directing_df, - shape_names=self.shape_names, - animal_bp_names=self.animal_bp_names, - video_path=self.video_path, - animal_names=self.animal_names, - animal_bp_dict= self.animal_bp_dict, - bp_lk=self.bp_lk, - roi_features_df=self.roi_features_df, - animal_bps=self.animal_bp_dict) - for cnt, result in enumerate(pool.imap(constants, data_arr, chunksize=self.multiprocess_chunksize)): - print(f'Batch core {result+1}/{self.core_cnt} complete...') + data_arr, frm_per_core = PlottingMixin.split_and_group_df( + self, + df=self.roi_feature_creator.out_df, + splits=self.core_cnt, + include_split_order=True, + ) + print( + f"Creating ROI feature images, multiprocessing (chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})..." + ) + with multiprocessing.Pool( + self.core_cnt, maxtasksperchild=self.maxtasksperchild + ) as pool: + constants = functools.partial( + _roi_feature_visualizer_mp, + text_locations=self.loc_dict, + font_size=self.font_size, + circle_size=self.circle_size, + video_meta_data=self.video_meta_data, + shape_info=self.shape_dicts, + roi_dict=self.roi_dict, + style_attr=self.style_attr, + save_temp_dir=self.save_temp_dir, + directing_data=self.directing_df, + shape_names=self.shape_names, + animal_bp_names=self.animal_bp_names, + video_path=self.video_path, + animal_names=self.animal_names, + animal_bp_dict=self.animal_bp_dict, + bp_lk=self.bp_lk, + roi_features_df=self.roi_features_df, + animal_bps=self.animal_bp_dict, + ) + for cnt, result in enumerate( + pool.imap(constants, data_arr, chunksize=self.multiprocess_chunksize) + ): + print(f"Batch core {result+1}/{self.core_cnt} complete...") print(f"Joining {self.video_name} multi-processed video...") concatenate_videos_in_folder( in_folder=self.save_temp_dir, diff --git a/simba/plotting/ROI_plotter.py b/simba/plotting/ROI_plotter.py index 59028158c..c16cbcf96 100644 --- a/simba/plotting/ROI_plotter.py +++ b/simba/plotting/ROI_plotter.py @@ -1,23 +1,29 @@ __author__ = "Simon Nilsson" -import os -from typing import Tuple, Optional, Union, Dict, List import itertools +import os +from typing import Dict, List, Optional, Tuple, Union import cv2 + from simba.mixins.config_reader import ConfigReader from simba.mixins.plotting_mixin import PlottingMixin from simba.sandbox.ROI_analyzer import ROIAnalyzer -from simba.utils.data import create_color_palettes, slice_roi_dict_for_video, detect_bouts -from simba.utils.checks import check_float, check_if_keys_exist_in_dict, check_file_exist_and_readable, check_video_and_data_frm_count_align, check_valid_lst -from simba.utils.enums import Formats, Paths, TagNames, TextOptions, Keys -from simba.utils.errors import DuplicationError, NoFilesFoundError, CountError, BodypartColumnNotFoundError, ROICoordinatesNotFoundError +from simba.utils.checks import (check_file_exist_and_readable, check_float, + check_if_keys_exist_in_dict, check_valid_lst, + check_video_and_data_frm_count_align) +from simba.utils.data import (create_color_palettes, detect_bouts, + slice_roi_dict_for_video) +from simba.utils.enums import Formats, Keys, Paths, TagNames, TextOptions +from simba.utils.errors import (BodypartColumnNotFoundError, CountError, + DuplicationError, NoFilesFoundError, + ROICoordinatesNotFoundError) from simba.utils.printing import SimbaTimer, log_event, stdout_success -from simba.utils.read_write import (get_fn_ext, get_video_meta_data) +from simba.utils.read_write import get_fn_ext, get_video_meta_data from simba.utils.warnings import DuplicateNamesWarning -SHOW_BODY_PARTS = 'show_body_part' -SHOW_ANIMAL_NAMES = 'show_animal_name' +SHOW_BODY_PARTS = "show_body_part" +SHOW_ANIMAL_NAMES = "show_animal_name" STYLE_KEYS = [SHOW_BODY_PARTS, SHOW_ANIMAL_NAMES] @@ -52,46 +58,97 @@ class ROIPlot(ConfigReader): >>> test.run() """ - def __init__(self, - config_path: Union[str, os.PathLike], - video_path: Union[str, os.PathLike], - style_attr: Dict[str, bool], - body_parts: List[str], - threshold: Optional[float] = 0.0): - - check_float(name=f'{self.__class__.__name__} threshold', value=threshold, min_value=0.0, max_value=1.0) - check_if_keys_exist_in_dict(data=style_attr, key=STYLE_KEYS, name=f'{self.__class__.__name__} style_attr') + def __init__( + self, + config_path: Union[str, os.PathLike], + video_path: Union[str, os.PathLike], + style_attr: Dict[str, bool], + body_parts: List[str], + threshold: Optional[float] = 0.0, + ): + + check_float( + name=f"{self.__class__.__name__} threshold", + value=threshold, + min_value=0.0, + max_value=1.0, + ) + check_if_keys_exist_in_dict( + data=style_attr, + key=STYLE_KEYS, + name=f"{self.__class__.__name__} style_attr", + ) check_file_exist_and_readable(file_path=video_path) _, self.video_name, _ = get_fn_ext(video_path) ConfigReader.__init__(self, config_path=config_path) if not os.path.isfile(self.roi_coordinates_path): - raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path) - self.data_path = os.path.join(self.outlier_corrected_dir, f'{self.video_name}.{self.file_type}') + raise ROICoordinatesNotFoundError( + expected_file_path=self.roi_coordinates_path + ) + self.data_path = os.path.join( + self.outlier_corrected_dir, f"{self.video_name}.{self.file_type}" + ) if not os.path.isfile(self.data_path): - raise NoFilesFoundError( msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create ROI visualizations", source=self.__class__.__name__) - log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals())) - check_valid_lst(data=body_parts, source=f'{self.__class__.__name__} body-parts', valid_dtypes=(str,), min_len=1) + raise NoFilesFoundError( + msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create ROI visualizations", + source=self.__class__.__name__, + ) + log_event( + logger_name=str(__class__.__name__), + log_type=TagNames.CLASS_INIT.value, + msg=self.create_log_msg_from_init_args(locals=locals()), + ) + check_valid_lst( + data=body_parts, + source=f"{self.__class__.__name__} body-parts", + valid_dtypes=(str,), + min_len=1, + ) if len(set(body_parts)) != len(body_parts): - raise DuplicationError(msg=f'All body-part entries have to be unique. Got {body_parts}', source=self.__class__.__name__) + raise DuplicationError( + msg=f"All body-part entries have to be unique. Got {body_parts}", + source=self.__class__.__name__, + ) for bp in body_parts: if bp not in self.body_parts_lst: - raise BodypartColumnNotFoundError(msg=f'The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}', source=self.__class__.__name__) - - self.roi_analyzer = ROIAnalyzer(config_path=self.config_path, data_path=self.data_path, detailed_bout_data=True, threshold=threshold, body_parts=body_parts) + raise BodypartColumnNotFoundError( + msg=f"The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}", + source=self.__class__.__name__, + ) + + self.roi_analyzer = ROIAnalyzer( + config_path=self.config_path, + data_path=self.data_path, + detailed_bout_data=True, + threshold=threshold, + body_parts=body_parts, + ) self.roi_analyzer.run() self.roi_entries_df = self.roi_analyzer.detailed_df self.data_df, self.style_attr = self.roi_analyzer.data_df, style_attr self.save_dir = os.path.join(self.project_path, Paths.ROI_ANALYSIS.value) - if not os.path.exists(self.save_dir): os.makedirs(self.save_dir) + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) self.video_save_path = os.path.join(self.save_dir, f"{self.video_name}.mp4") self.read_roi_data() self.shape_columns = [] - self.roi_dict, self.shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name) + self.roi_dict, self.shape_names = slice_roi_dict_for_video( + data=self.roi_dict, video_name=self.video_name + ) if len(self.shape_names) == 0: - raise CountError(msg=f'No drawn ROIs detected for video {self.video_name}, please draw ROIs on this video before visualizing ROIs', source=self.__class__.__name__) - self.animal_names = [self.find_animal_name_from_body_part_name(bp_name=x, bp_dict=self.animal_bp_dict) for x in body_parts] + raise CountError( + msg=f"No drawn ROIs detected for video {self.video_name}, please draw ROIs on this video before visualizing ROIs", + source=self.__class__.__name__, + ) + self.animal_names = [ + self.find_animal_name_from_body_part_name( + bp_name=x, bp_dict=self.animal_bp_dict + ) + for x in body_parts + ] for x in itertools.product(self.animal_names, self.shape_names): - self.data_df[f"{x[0]}_{x[1]}"] = 0; self.shape_columns.append(f"{x[0]}_{x[1]}") + self.data_df[f"{x[0]}_{x[1]}"] = 0 + self.shape_columns.append(f"{x[0]}_{x[1]}") self.bp_dict = self.roi_analyzer.bp_dict self.__insert_data() self.video_path = video_path @@ -100,12 +157,16 @@ def __init__(self, self.threshold, self.body_parts = threshold, body_parts def __insert_data(self): - roi_entries_dict = self.roi_entries_df[["ANIMAL", "SHAPE NAME", "START FRAME", "END FRAME"]].to_dict(orient="records") + roi_entries_dict = self.roi_entries_df[ + ["ANIMAL", "SHAPE NAME", "START FRAME", "END FRAME"] + ].to_dict(orient="records") for entry_dict in roi_entries_dict: entry, exit = int(entry_dict["START FRAME"]), int(entry_dict["END FRAME"]) entry_dict["frame_range"] = list(range(entry, exit + 1)) - col_name = f'{entry_dict["ANIMAL"]}_{entry_dict["SHAPE NAME"]}' - self.data_df[col_name][self.data_df.index.isin(entry_dict["frame_range"])] = 1 + col_name = f'{entry_dict["ANIMAL"]}_{entry_dict["SHAPE NAME"]}' + self.data_df[col_name][ + self.data_df.index.isin(entry_dict["frame_range"]) + ] = 1 def __calc_text_locs(self) -> dict: loc_dict = {} @@ -114,13 +175,57 @@ def __calc_text_locs(self) -> dict: loc_dict[animal_name] = {} for shape in self.shape_names: loc_dict[animal_name][shape] = {} - loc_dict[animal_name][shape]["timer_text"] = f"{shape} {animal_name} timer:" - loc_dict[animal_name][shape]["entries_text"] = f"{shape} {animal_name} entries:" - loc_dict[animal_name][shape]["timer_text_loc"] = ((self.video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), (self.video_meta_data["height"] - (self.video_meta_data["height"] + TextOptions.BORDER_BUFFER_Y.value) + self.scalers["space_size"] * line_spacer)) - loc_dict[animal_name][shape]["timer_data_loc"] = (int(self.border_img_w - (self.border_img_w / 8)), (self.video_meta_data["height"] - (self.video_meta_data["height"] + TextOptions.BORDER_BUFFER_Y.value) + self.scalers["space_size"] * line_spacer)) + loc_dict[animal_name][shape][ + "timer_text" + ] = f"{shape} {animal_name} timer:" + loc_dict[animal_name][shape][ + "entries_text" + ] = f"{shape} {animal_name} entries:" + loc_dict[animal_name][shape]["timer_text_loc"] = ( + (self.video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), + ( + self.video_meta_data["height"] + - ( + self.video_meta_data["height"] + + TextOptions.BORDER_BUFFER_Y.value + ) + + self.scalers["space_size"] * line_spacer + ), + ) + loc_dict[animal_name][shape]["timer_data_loc"] = ( + int(self.border_img_w - (self.border_img_w / 8)), + ( + self.video_meta_data["height"] + - ( + self.video_meta_data["height"] + + TextOptions.BORDER_BUFFER_Y.value + ) + + self.scalers["space_size"] * line_spacer + ), + ) line_spacer += TextOptions.LINE_SPACING.value - loc_dict[animal_name][shape]["entries_text_loc"] = ((self.video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), (self.video_meta_data["height"] - (self.video_meta_data["height"] + TextOptions.BORDER_BUFFER_Y.value) + self.scalers["space_size"] * line_spacer)) - loc_dict[animal_name][shape]["entries_data_loc"] = (int(self.border_img_w - (self.border_img_w / 8)), (self.video_meta_data["height"]- (self.video_meta_data["height"] + TextOptions.BORDER_BUFFER_Y.value) + self.scalers["space_size"] * line_spacer)) + loc_dict[animal_name][shape]["entries_text_loc"] = ( + (self.video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), + ( + self.video_meta_data["height"] + - ( + self.video_meta_data["height"] + + TextOptions.BORDER_BUFFER_Y.value + ) + + self.scalers["space_size"] * line_spacer + ), + ) + loc_dict[animal_name][shape]["entries_data_loc"] = ( + int(self.border_img_w - (self.border_img_w / 8)), + ( + self.video_meta_data["height"] + - ( + self.video_meta_data["height"] + + TextOptions.BORDER_BUFFER_Y.value + ) + + self.scalers["space_size"] * line_spacer + ), + ) line_spacer += TextOptions.LINE_SPACING.value return loc_dict @@ -128,8 +233,24 @@ def __insert_texts(self, shape_df): for animal_name in self.animal_names: for _, shape in shape_df.iterrows(): shape_name, shape_color = shape["Name"], shape["Color BGR"] - cv2.putText(self.border_img, self.loc_dict[animal_name][shape_name]["timer_text"], self.loc_dict[animal_name][shape_name]["timer_text_loc"], TextOptions.FONT.value, self.scalers["font_size"], shape_color, TextOptions.TEXT_THICKNESS.value) - cv2.putText(self.border_img, self.loc_dict[animal_name][shape_name]["entries_text"], self.loc_dict[animal_name][shape_name]["entries_text_loc"], TextOptions.FONT.value, self.scalers["font_size"], shape_color, TextOptions.TEXT_THICKNESS.value) + cv2.putText( + self.border_img, + self.loc_dict[animal_name][shape_name]["timer_text"], + self.loc_dict[animal_name][shape_name]["timer_text_loc"], + TextOptions.FONT.value, + self.scalers["font_size"], + shape_color, + TextOptions.TEXT_THICKNESS.value, + ) + cv2.putText( + self.border_img, + self.loc_dict[animal_name][shape_name]["entries_text"], + self.loc_dict[animal_name][shape_name]["entries_text_loc"], + TextOptions.FONT.value, + self.scalers["font_size"], + shape_color, + TextOptions.TEXT_THICKNESS.value, + ) def __create_counters(self) -> dict: cnt_dict = {} @@ -145,18 +266,32 @@ def __create_counters(self) -> dict: def __calculate_cumulative(self): for animal_name in self.animal_names: for shape in self.shape_names: - self.data_df[f"{animal_name}_{shape}_cum_sum_time"] = (self.data_df[f"{animal_name}_{shape}"].cumsum() / self.video_meta_data['fps']) - roi_bouts = list(detect_bouts(data_df=self.data_df, target_lst=[f"{animal_name}_{shape}"], fps=self.video_meta_data['fps'])["Start_frame"]) + self.data_df[f"{animal_name}_{shape}_cum_sum_time"] = ( + self.data_df[f"{animal_name}_{shape}"].cumsum() + / self.video_meta_data["fps"] + ) + roi_bouts = list( + detect_bouts( + data_df=self.data_df, + target_lst=[f"{animal_name}_{shape}"], + fps=self.video_meta_data["fps"], + )["Start_frame"] + ) self.data_df[f"{animal_name}_{shape}_entry"] = 0 self.data_df.loc[roi_bouts, f"{animal_name}_{shape}_entry"] = 1 - self.data_df[f"{animal_name}_{shape}_cum_sum_entries"] = (self.data_df[f"{animal_name}_{shape}_entry"].cumsum()) + self.data_df[f"{animal_name}_{shape}_cum_sum_entries"] = self.data_df[ + f"{animal_name}_{shape}_entry" + ].cumsum() def __create_shape_dicts(self): shape_dicts = {} for shape, df in self.roi_dict.items(): if not df["Name"].is_unique: df = df.drop_duplicates(subset=["Name"], keep="first") - DuplicateNamesWarning(f'Some of your ROIs with the same shape ({shape}) has the same names for video {self.video_name}. E.g., you have two rectangles named "My rectangle". SimBA prefers ROI shapes with unique names. SimBA will keep one of the unique shape names and drop the rest.', source=self.__class__.__name__) + DuplicateNamesWarning( + f'Some of your ROIs with the same shape ({shape}) has the same names for video {self.video_name}. E.g., you have two rectangles named "My rectangle". SimBA prefers ROI shapes with unique names. SimBA will keep one of the unique shape names and drop the rest.', + source=self.__class__.__name__, + ) d = df.set_index("Name").to_dict(orient="index") shape_dicts = {**shape_dicts, **d} return shape_dicts @@ -165,8 +300,19 @@ def __get_bordered_img_size(self) -> Tuple[int, int]: cap = cv2.VideoCapture(self.video_path) cap.set(1, 1) _, img = self.cap.read() - self.base_img = cv2.copyMakeBorder(img, 0, 0, 0, int(self.video_meta_data["width"]), borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0]) - self.base_img_h, self.base_img_w = self.base_img.shape[0], self.base_img.shape[1] + self.base_img = cv2.copyMakeBorder( + img, + 0, + 0, + 0, + int(self.video_meta_data["width"]), + borderType=cv2.BORDER_CONSTANT, + value=[0, 0, 0], + ) + self.base_img_h, self.base_img_w = ( + self.base_img.shape[0], + self.base_img.shape[1], + ) cap.release() return self.base_img_h, self.base_img_w @@ -174,49 +320,132 @@ def run(self): video_timer = SimbaTimer(start=True) max_dim = max(self.video_meta_data["width"], self.video_meta_data["height"]) self.scalers = {} - self.scalers["circle_size"] = int(TextOptions.RADIUS_SCALER.value / (TextOptions.RESOLUTION_SCALER.value / max_dim)) - self.scalers["font_size"] = float(TextOptions.FONT_SCALER.value / (TextOptions.RESOLUTION_SCALER.value / max_dim)) - self.scalers["space_size"] = int(TextOptions.SPACE_SCALER.value / (TextOptions.RESOLUTION_SCALER.value / max_dim)) - color_lst = create_color_palettes(self.roi_analyzer.animal_cnt, len(self.body_parts))[0] + self.scalers["circle_size"] = int( + TextOptions.RADIUS_SCALER.value + / (TextOptions.RESOLUTION_SCALER.value / max_dim) + ) + self.scalers["font_size"] = float( + TextOptions.FONT_SCALER.value + / (TextOptions.RESOLUTION_SCALER.value / max_dim) + ) + self.scalers["space_size"] = int( + TextOptions.SPACE_SCALER.value + / (TextOptions.RESOLUTION_SCALER.value / max_dim) + ) + color_lst = create_color_palettes( + self.roi_analyzer.animal_cnt, len(self.body_parts) + )[0] self.border_img_h, self.border_img_w = self.__get_bordered_img_size() fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value) - writer = cv2.VideoWriter(self.video_save_path, fourcc, self.video_meta_data["fps"], (self.border_img_w, self.border_img_h)) + writer = cv2.VideoWriter( + self.video_save_path, + fourcc, + self.video_meta_data["fps"], + (self.border_img_w, self.border_img_h), + ) self.loc_dict = self.__calc_text_locs() self.cnt_dict = self.__create_counters() self.shape_dicts = self.__create_shape_dicts() self.__calculate_cumulative() - check_video_and_data_frm_count_align(video=self.video_path, data=self.data_df, name=self.video_name, raise_error=False) + check_video_and_data_frm_count_align( + video=self.video_path, + data=self.data_df, + name=self.video_name, + raise_error=False, + ) frame_cnt = 0 while self.cap.isOpened(): ret, img = self.cap.read() if ret: - self.border_img = cv2.copyMakeBorder(img, 0, 0, 0, int(self.video_meta_data["width"]), borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0]) + self.border_img = cv2.copyMakeBorder( + img, + 0, + 0, + 0, + int(self.video_meta_data["width"]), + borderType=cv2.BORDER_CONSTANT, + value=[0, 0, 0], + ) self.__insert_texts(self.roi_dict[Keys.ROI_RECTANGLES.value]) self.__insert_texts(self.roi_dict[Keys.ROI_CIRCLES.value]) self.__insert_texts(self.roi_dict[Keys.ROI_POLYGONS.value]) - self.img_w_border = PlottingMixin.roi_dict_onto_img(img=self.border_img, roi_dict=self.roi_dict) + self.img_w_border = PlottingMixin.roi_dict_onto_img( + img=self.border_img, roi_dict=self.roi_dict + ) for animal_cnt, animal_name in enumerate(self.animal_names): - bp_data = (self.data_df.loc[frame_cnt, self.bp_dict[animal_name]].fillna(0.0).values) + bp_data = ( + self.data_df.loc[frame_cnt, self.bp_dict[animal_name]] + .fillna(0.0) + .values + ) if self.threshold < bp_data[2]: if self.style_attr[SHOW_BODY_PARTS]: - cv2.circle(self.border_img, (int(bp_data[0]), int(bp_data[1])), self.scalers["circle_size"], color_lst[animal_cnt], -1) + cv2.circle( + self.border_img, + (int(bp_data[0]), int(bp_data[1])), + self.scalers["circle_size"], + color_lst[animal_cnt], + -1, + ) if self.style_attr[SHOW_ANIMAL_NAMES]: - cv2.putText(self.border_img, animal_name, (int(bp_data[0]), int(bp_data[1])), self.font, self.scalers["font_size"], color_lst[animal_cnt], TextOptions.TEXT_THICKNESS.value) + cv2.putText( + self.border_img, + animal_name, + (int(bp_data[0]), int(bp_data[1])), + self.font, + self.scalers["font_size"], + color_lst[animal_cnt], + TextOptions.TEXT_THICKNESS.value, + ) for animal_cnt, animal_name in enumerate(self.animal_names): for shape in self.shape_names: - time = str(round(self.data_df.loc[frame_cnt, f"{animal_name}_{shape}_cum_sum_time"], 2)) - entries = str(int(self.data_df.loc[frame_cnt, f"{animal_name}_{shape}_cum_sum_entries"])) - cv2.putText(self.border_img, time, self.loc_dict[animal_name][shape]["timer_data_loc"], self.font, self.scalers["font_size"], self.shape_dicts[shape]["Color BGR"], TextOptions.TEXT_THICKNESS.value) - cv2.putText(self.border_img, entries, self.loc_dict[animal_name][shape]["entries_data_loc"], self.font, self.scalers["font_size"], self.shape_dicts[shape]["Color BGR"], TextOptions.TEXT_THICKNESS.value) + time = str( + round( + self.data_df.loc[ + frame_cnt, f"{animal_name}_{shape}_cum_sum_time" + ], + 2, + ) + ) + entries = str( + int( + self.data_df.loc[ + frame_cnt, f"{animal_name}_{shape}_cum_sum_entries" + ] + ) + ) + cv2.putText( + self.border_img, + time, + self.loc_dict[animal_name][shape]["timer_data_loc"], + self.font, + self.scalers["font_size"], + self.shape_dicts[shape]["Color BGR"], + TextOptions.TEXT_THICKNESS.value, + ) + cv2.putText( + self.border_img, + entries, + self.loc_dict[animal_name][shape]["entries_data_loc"], + self.font, + self.scalers["font_size"], + self.shape_dicts[shape]["Color BGR"], + TextOptions.TEXT_THICKNESS.value, + ) writer.write(self.border_img) - print(f"Frame: {frame_cnt+1} / {self.video_meta_data['frame_count']}, Video: {self.video_name}.") + print( + f"Frame: {frame_cnt+1} / {self.video_meta_data['frame_count']}, Video: {self.video_name}." + ) frame_cnt += 1 else: break writer.release() video_timer.stop_timer() - stdout_success(msg=f"Video {self.video_name} created. Video saved at {self.video_save_path}", elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__) - + stdout_success( + msg=f"Video {self.video_name} created. Video saved at {self.video_save_path}", + elapsed_time=video_timer.elapsed_time_str, + source=self.__class__.__name__, + ) # test = ROIPlot(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini', diff --git a/simba/plotting/ROI_plotter_mp.py b/simba/plotting/ROI_plotter_mp.py index e2274ac9a..1c6e80dba 100644 --- a/simba/plotting/ROI_plotter_mp.py +++ b/simba/plotting/ROI_plotter_mp.py @@ -1,12 +1,12 @@ __author__ = "Simon Nilsson" import functools +import itertools import multiprocessing import os import platform import shutil -import itertools -from typing import Optional, Union, Dict, Tuple, List +from typing import Dict, List, Optional, Tuple, Union import cv2 import numpy as np @@ -14,82 +14,155 @@ from simba.mixins.config_reader import ConfigReader from simba.mixins.plotting_mixin import PlottingMixin -#from simba.roi_tools.ROI_analyzer import ROIAnalyzer +# from simba.roi_tools.ROI_analyzer import ROIAnalyzer from simba.sandbox.ROI_analyzer import ROIAnalyzer -from simba.utils.data import create_color_palettes, detect_bouts, slice_roi_dict_for_video -from simba.utils.enums import Paths, TagNames, TextOptions, Formats, Keys -from simba.utils.errors import NoFilesFoundError, CountError, BodypartColumnNotFoundError, ROICoordinatesNotFoundError +from simba.utils.checks import (check_file_exist_and_readable, check_float, + check_if_keys_exist_in_dict, check_int, + check_valid_lst, + check_video_and_data_frm_count_align) +from simba.utils.data import (create_color_palettes, detect_bouts, + slice_roi_dict_for_video) +from simba.utils.enums import Formats, Keys, Paths, TagNames, TextOptions +from simba.utils.errors import (BodypartColumnNotFoundError, CountError, + NoFilesFoundError, ROICoordinatesNotFoundError) from simba.utils.printing import SimbaTimer, log_event, stdout_success -from simba.utils.read_write import (concatenate_videos_in_folder, get_fn_ext, get_video_meta_data, find_core_cnt) +from simba.utils.read_write import (concatenate_videos_in_folder, + find_core_cnt, get_fn_ext, + get_video_meta_data) from simba.utils.warnings import DuplicateNamesWarning -from simba.utils.checks import (check_float, - check_int, - check_if_keys_exist_in_dict, - check_file_exist_and_readable, - check_video_and_data_frm_count_align, - check_valid_lst) + pd.options.mode.chained_assignment = None -SHOW_BODY_PARTS = 'show_body_part' -SHOW_ANIMAL_NAMES = 'show_animal_name' +SHOW_BODY_PARTS = "show_body_part" +SHOW_ANIMAL_NAMES = "show_animal_name" STYLE_KEYS = [SHOW_BODY_PARTS, SHOW_ANIMAL_NAMES] -def _roi_plotter_mp(data: pd.DataFrame, - loc_dict: dict, - scalers: dict, - video_meta_data: dict, - save_temp_directory: str, - shape_meta_data: dict, - video_shape_names: list, - input_video_path: str, - body_part_dict: dict, - roi_dict: Dict[str, pd.DataFrame], - colors: list, - style_attr: dict, - animal_ids: list, - threshold: float): +def _roi_plotter_mp( + data: pd.DataFrame, + loc_dict: dict, + scalers: dict, + video_meta_data: dict, + save_temp_directory: str, + shape_meta_data: dict, + video_shape_names: list, + input_video_path: str, + body_part_dict: dict, + roi_dict: Dict[str, pd.DataFrame], + colors: list, + style_attr: dict, + animal_ids: list, + threshold: float, +): def __insert_texts(shape_df): for animal_name in animal_ids: for _, shape in shape_df.iterrows(): shape_name, shape_color = shape["Name"], shape["Color BGR"] - cv2.putText(border_img, loc_dict[animal_name][shape_name]["timer_text"], loc_dict[animal_name][shape_name]["timer_text_loc"], TextOptions.FONT.value, scalers["font_size"], shape_color, TextOptions.TEXT_THICKNESS.value) - cv2.putText(border_img, loc_dict[animal_name][shape_name]["entries_text"], loc_dict[animal_name][shape_name]["entries_text_loc"], TextOptions.FONT.value, scalers["font_size"], shape_color, TextOptions.TEXT_THICKNESS.value) + cv2.putText( + border_img, + loc_dict[animal_name][shape_name]["timer_text"], + loc_dict[animal_name][shape_name]["timer_text_loc"], + TextOptions.FONT.value, + scalers["font_size"], + shape_color, + TextOptions.TEXT_THICKNESS.value, + ) + cv2.putText( + border_img, + loc_dict[animal_name][shape_name]["entries_text"], + loc_dict[animal_name][shape_name]["entries_text_loc"], + TextOptions.FONT.value, + scalers["font_size"], + shape_color, + TextOptions.TEXT_THICKNESS.value, + ) return border_img fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value) group_cnt = int(data["group"].values[0]) start_frm, current_frm, end_frm = data.index[0], data.index[0], data.index[-1] save_path = os.path.join(save_temp_directory, f"{group_cnt}.mp4") - writer = cv2.VideoWriter(save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"] * 2, video_meta_data["height"])) + writer = cv2.VideoWriter( + save_path, + fourcc, + video_meta_data["fps"], + (video_meta_data["width"] * 2, video_meta_data["height"]), + ) cap = cv2.VideoCapture(input_video_path) cap.set(1, start_frm) while current_frm < end_frm: ret, img = cap.read() if ret: - border_img = cv2.copyMakeBorder(img, 0, 0, 0, int(video_meta_data["width"]), borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0]) + border_img = cv2.copyMakeBorder( + img, + 0, + 0, + 0, + int(video_meta_data["width"]), + borderType=cv2.BORDER_CONSTANT, + value=[0, 0, 0], + ) border_img = __insert_texts(roi_dict[Keys.ROI_RECTANGLES.value]) border_img = __insert_texts(roi_dict[Keys.ROI_CIRCLES.value]) border_img = __insert_texts(roi_dict[Keys.ROI_POLYGONS.value]) - border_img = PlottingMixin.roi_dict_onto_img(img=border_img, roi_dict=roi_dict) + border_img = PlottingMixin.roi_dict_onto_img( + img=border_img, roi_dict=roi_dict + ) for animal_cnt, animal_name in enumerate(animal_ids): if style_attr[SHOW_BODY_PARTS] or style_attr[SHOW_ANIMAL_NAMES]: bp_data = data.loc[current_frm, body_part_dict[animal_name]].values if threshold < bp_data[2]: if style_attr[SHOW_BODY_PARTS]: - cv2.circle(border_img, (int(bp_data[0]), int(bp_data[1])), scalers["circle_size"], colors[animal_cnt], -1) + cv2.circle( + border_img, + (int(bp_data[0]), int(bp_data[1])), + scalers["circle_size"], + colors[animal_cnt], + -1, + ) if style_attr[SHOW_ANIMAL_NAMES]: - cv2.putText(border_img, animal_name, (int(bp_data[0]), int(bp_data[1])), TextOptions.FONT.value, scalers["font_size"], colors[animal_cnt], TextOptions.TEXT_THICKNESS.value) + cv2.putText( + border_img, + animal_name, + (int(bp_data[0]), int(bp_data[1])), + TextOptions.FONT.value, + scalers["font_size"], + colors[animal_cnt], + TextOptions.TEXT_THICKNESS.value, + ) for shape_name in video_shape_names: - timer = round(data.loc[current_frm, f"{animal_name}_{shape_name}_cum_sum_time"], 2) - entries = data.loc[current_frm, f"{animal_name}_{shape_name}_cum_sum_entries"] - cv2.putText(border_img, str(timer), loc_dict[animal_name][shape_name]["timer_data_loc"], TextOptions.FONT.value, scalers["font_size"], shape_meta_data[shape_name]["Color BGR"], TextOptions.TEXT_THICKNESS.value) - cv2.putText(border_img, str(entries), loc_dict[animal_name][shape_name]["entries_data_loc"], TextOptions.FONT.value, scalers["font_size"], shape_meta_data[shape_name]["Color BGR"], TextOptions.TEXT_THICKNESS.value) + timer = round( + data.loc[ + current_frm, f"{animal_name}_{shape_name}_cum_sum_time" + ], + 2, + ) + entries = data.loc[ + current_frm, f"{animal_name}_{shape_name}_cum_sum_entries" + ] + cv2.putText( + border_img, + str(timer), + loc_dict[animal_name][shape_name]["timer_data_loc"], + TextOptions.FONT.value, + scalers["font_size"], + shape_meta_data[shape_name]["Color BGR"], + TextOptions.TEXT_THICKNESS.value, + ) + cv2.putText( + border_img, + str(entries), + loc_dict[animal_name][shape_name]["entries_data_loc"], + TextOptions.FONT.value, + scalers["font_size"], + shape_meta_data[shape_name]["Color BGR"], + TextOptions.TEXT_THICKNESS.value, + ) writer.write(border_img) current_frm += 1 print(f"Multi-processing video frame {current_frm} on core {group_cnt}...") @@ -136,54 +209,115 @@ class ROIPlotMultiprocess(ConfigReader): >>> test.run() """ - def __init__(self, - config_path: Union[str, os.PathLike], - video_path: Union[str, os.PathLike], - body_parts: List[str], - style_attr: Dict[str, bool], - threshold: Optional[float] = 0.0, - core_cnt: Optional[int] = -1): + def __init__( + self, + config_path: Union[str, os.PathLike], + video_path: Union[str, os.PathLike], + body_parts: List[str], + style_attr: Dict[str, bool], + threshold: Optional[float] = 0.0, + core_cnt: Optional[int] = -1, + ): # if platform.system() == "Darwin": multiprocessing.set_start_method("spawn", force=True) - check_float(name=f'{self.__class__.__name__} threshold', value=threshold, min_value=0.0, max_value=1.0) - check_int(name=f'{self.__class__.__name__} core_cnt', value=core_cnt, min_value=-1) - if core_cnt == -1: core_cnt = find_core_cnt()[0] - check_if_keys_exist_in_dict(data=style_attr, key=STYLE_KEYS, name=f'{self.__class__.__name__} style_attr') + check_float( + name=f"{self.__class__.__name__} threshold", + value=threshold, + min_value=0.0, + max_value=1.0, + ) + check_int( + name=f"{self.__class__.__name__} core_cnt", value=core_cnt, min_value=-1 + ) + if core_cnt == -1: + core_cnt = find_core_cnt()[0] + check_if_keys_exist_in_dict( + data=style_attr, + key=STYLE_KEYS, + name=f"{self.__class__.__name__} style_attr", + ) check_file_exist_and_readable(file_path=video_path) _, self.video_name, _ = get_fn_ext(video_path) ConfigReader.__init__(self, config_path=config_path) if not os.path.isfile(self.roi_coordinates_path): - raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path) - self.data_path = os.path.join(self.outlier_corrected_dir, f'{self.video_name}.{self.file_type}') + raise ROICoordinatesNotFoundError( + expected_file_path=self.roi_coordinates_path + ) + self.data_path = os.path.join( + self.outlier_corrected_dir, f"{self.video_name}.{self.file_type}" + ) if not os.path.isfile(self.data_path): - raise NoFilesFoundError( msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create ROI visualizations", source=self.__class__.__name__) - check_valid_lst(data=body_parts, source=f'{self.__class__.__name__} body-parts', valid_dtypes=(str,), min_len=1) + raise NoFilesFoundError( + msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create ROI visualizations", + source=self.__class__.__name__, + ) + check_valid_lst( + data=body_parts, + source=f"{self.__class__.__name__} body-parts", + valid_dtypes=(str,), + min_len=1, + ) if len(set(body_parts)) != len(body_parts): - raise CountError(msg=f'All body-part entries have to be unique. Got {body_parts}', source=self.__class__.__name__) - log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals())) + raise CountError( + msg=f"All body-part entries have to be unique. Got {body_parts}", + source=self.__class__.__name__, + ) + log_event( + logger_name=str(__class__.__name__), + log_type=TagNames.CLASS_INIT.value, + msg=self.create_log_msg_from_init_args(locals=locals()), + ) for bp in body_parts: if bp not in self.body_parts_lst: - raise BodypartColumnNotFoundError(msg=f'The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}', source=self.__class__.__name__) - self.roi_analyzer = ROIAnalyzer(config_path=config_path, data_path=self.data_path, detailed_bout_data=True, threshold=threshold, body_parts=body_parts) + raise BodypartColumnNotFoundError( + msg=f"The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}", + source=self.__class__.__name__, + ) + self.roi_analyzer = ROIAnalyzer( + config_path=config_path, + data_path=self.data_path, + detailed_bout_data=True, + threshold=threshold, + body_parts=body_parts, + ) self.roi_analyzer.run() self.roi_entries_df = self.roi_analyzer.detailed_df self.data_df, self.style_attr = self.roi_analyzer.data_df, style_attr self.out_parent_dir = os.path.join(self.project_path, Paths.ROI_ANALYSIS.value) - if not os.path.exists(self.out_parent_dir): os.makedirs(self.out_parent_dir) - self.video_save_path = os.path.join(self.out_parent_dir, f"{self.video_name}.mp4") + if not os.path.exists(self.out_parent_dir): + os.makedirs(self.out_parent_dir) + self.video_save_path = os.path.join( + self.out_parent_dir, f"{self.video_name}.mp4" + ) self.read_roi_data() self.shape_columns = [] - self.roi_dict, self.shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name) + self.roi_dict, self.shape_names = slice_roi_dict_for_video( + data=self.roi_dict, video_name=self.video_name + ) if len(self.shape_names) == 0: - raise CountError(msg=f'No drawn ROIs detected for video {self.video_name}, please draw ROIs on this video before visualizing ROIs', source=self.__class__.__name__) - self.animal_names = [self.find_animal_name_from_body_part_name(bp_name=x, bp_dict=self.animal_bp_dict) for x in body_parts] + raise CountError( + msg=f"No drawn ROIs detected for video {self.video_name}, please draw ROIs on this video before visualizing ROIs", + source=self.__class__.__name__, + ) + self.animal_names = [ + self.find_animal_name_from_body_part_name( + bp_name=x, bp_dict=self.animal_bp_dict + ) + for x in body_parts + ] for x in itertools.product(self.animal_names, self.shape_names): - self.data_df[f"{x[0]}_{x[1]}"] = 0; self.shape_columns.append(f"{x[0]}_{x[1]}") + self.data_df[f"{x[0]}_{x[1]}"] = 0 + self.shape_columns.append(f"{x[0]}_{x[1]}") self.bp_dict = self.roi_analyzer.bp_dict self.__insert_data() - self.video_path, self.core_cnt, self.threshold, self.body_parts = video_path, core_cnt, threshold, body_parts + self.video_path, self.core_cnt, self.threshold, self.body_parts = ( + video_path, + core_cnt, + threshold, + body_parts, + ) self.cap = cv2.VideoCapture(self.video_path) self.video_meta_data = get_video_meta_data(self.video_path) self.temp_folder = os.path.join(self.out_parent_dir, self.video_name, "temp") @@ -192,12 +326,16 @@ def __init__(self, os.makedirs(self.temp_folder) def __insert_data(self): - roi_entries_dict = self.roi_entries_df[["ANIMAL", "SHAPE NAME", "START FRAME", "END FRAME"]].to_dict(orient="records") + roi_entries_dict = self.roi_entries_df[ + ["ANIMAL", "SHAPE NAME", "START FRAME", "END FRAME"] + ].to_dict(orient="records") for entry_dict in roi_entries_dict: entry, exit = int(entry_dict["START FRAME"]), int(entry_dict["END FRAME"]) entry_dict["frame_range"] = list(range(entry, exit + 1)) - col_name = f'{entry_dict["ANIMAL"]}_{entry_dict["SHAPE NAME"]}' - self.data_df[col_name][self.data_df.index.isin(entry_dict["frame_range"])] = 1 + col_name = f'{entry_dict["ANIMAL"]}_{entry_dict["SHAPE NAME"]}' + self.data_df[col_name][ + self.data_df.index.isin(entry_dict["frame_range"]) + ] = 1 def __calc_text_locs(self) -> dict: loc_dict = {} @@ -206,13 +344,57 @@ def __calc_text_locs(self) -> dict: loc_dict[animal_name] = {} for shape in self.shape_names: loc_dict[animal_name][shape] = {} - loc_dict[animal_name][shape]["timer_text"] = f"{shape} {animal_name} timer:" - loc_dict[animal_name][shape]["entries_text"] = f"{shape} {animal_name} entries:" - loc_dict[animal_name][shape]["timer_text_loc"] = ((self.video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), (self.video_meta_data["height"] - (self.video_meta_data["height"] + TextOptions.BORDER_BUFFER_Y.value) + self.scalers["space_size"] * line_spacer)) - loc_dict[animal_name][shape]["timer_data_loc"] = (int(self.border_img_w - (self.border_img_w / 8)), (self.video_meta_data["height"] - (self.video_meta_data["height"] + TextOptions.BORDER_BUFFER_Y.value) + self.scalers["space_size"] * line_spacer)) + loc_dict[animal_name][shape][ + "timer_text" + ] = f"{shape} {animal_name} timer:" + loc_dict[animal_name][shape][ + "entries_text" + ] = f"{shape} {animal_name} entries:" + loc_dict[animal_name][shape]["timer_text_loc"] = ( + (self.video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), + ( + self.video_meta_data["height"] + - ( + self.video_meta_data["height"] + + TextOptions.BORDER_BUFFER_Y.value + ) + + self.scalers["space_size"] * line_spacer + ), + ) + loc_dict[animal_name][shape]["timer_data_loc"] = ( + int(self.border_img_w - (self.border_img_w / 8)), + ( + self.video_meta_data["height"] + - ( + self.video_meta_data["height"] + + TextOptions.BORDER_BUFFER_Y.value + ) + + self.scalers["space_size"] * line_spacer + ), + ) line_spacer += TextOptions.LINE_SPACING.value - loc_dict[animal_name][shape]["entries_text_loc"] = ((self.video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), (self.video_meta_data["height"] - (self.video_meta_data["height"] + TextOptions.BORDER_BUFFER_Y.value) + self.scalers["space_size"] * line_spacer)) - loc_dict[animal_name][shape]["entries_data_loc"] = (int(self.border_img_w - (self.border_img_w / 8)), (self.video_meta_data["height"]- (self.video_meta_data["height"] + TextOptions.BORDER_BUFFER_Y.value) + self.scalers["space_size"] * line_spacer)) + loc_dict[animal_name][shape]["entries_text_loc"] = ( + (self.video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), + ( + self.video_meta_data["height"] + - ( + self.video_meta_data["height"] + + TextOptions.BORDER_BUFFER_Y.value + ) + + self.scalers["space_size"] * line_spacer + ), + ) + loc_dict[animal_name][shape]["entries_data_loc"] = ( + int(self.border_img_w - (self.border_img_w / 8)), + ( + self.video_meta_data["height"] + - ( + self.video_meta_data["height"] + + TextOptions.BORDER_BUFFER_Y.value + ) + + self.scalers["space_size"] * line_spacer + ), + ) line_spacer += TextOptions.LINE_SPACING.value return loc_dict @@ -230,19 +412,32 @@ def __create_counters(self) -> dict: def __calculate_cumulative(self): for animal_name in self.animal_names: for shape in self.shape_names: - self.data_df[f"{animal_name}_{shape}_cum_sum_time"] = (self.data_df[f"{animal_name}_{shape}"].cumsum() / self.video_meta_data['fps']) - roi_bouts = list(detect_bouts(data_df=self.data_df, target_lst=[f"{animal_name}_{shape}"], fps=self.video_meta_data['fps'])["Start_frame"]) + self.data_df[f"{animal_name}_{shape}_cum_sum_time"] = ( + self.data_df[f"{animal_name}_{shape}"].cumsum() + / self.video_meta_data["fps"] + ) + roi_bouts = list( + detect_bouts( + data_df=self.data_df, + target_lst=[f"{animal_name}_{shape}"], + fps=self.video_meta_data["fps"], + )["Start_frame"] + ) self.data_df[f"{animal_name}_{shape}_entry"] = 0 self.data_df.loc[roi_bouts, f"{animal_name}_{shape}_entry"] = 1 - self.data_df[f"{animal_name}_{shape}_cum_sum_entries"] = (self.data_df[f"{animal_name}_{shape}_entry"].cumsum()) - + self.data_df[f"{animal_name}_{shape}_cum_sum_entries"] = self.data_df[ + f"{animal_name}_{shape}_entry" + ].cumsum() def __create_shape_dicts(self): shape_dicts = {} for shape, df in self.roi_dict.items(): if not df["Name"].is_unique: df = df.drop_duplicates(subset=["Name"], keep="first") - DuplicateNamesWarning(f'Some of your ROIs with the same shape ({shape}) has the same names for video {self.video_name}. E.g., you have two rectangles named "My rectangle". SimBA prefers ROI shapes with unique names. SimBA will keep one of the unique shape names and drop the rest.', source=self.__class__.__name__) + DuplicateNamesWarning( + f'Some of your ROIs with the same shape ({shape}) has the same names for video {self.video_name}. E.g., you have two rectangles named "My rectangle". SimBA prefers ROI shapes with unique names. SimBA will keep one of the unique shape names and drop the rest.', + source=self.__class__.__name__, + ) d = df.set_index("Name").to_dict(orient="index") shape_dicts = {**shape_dicts, **d} return shape_dicts @@ -251,7 +446,15 @@ def __get_bordered_img_size(self) -> Tuple[int, int]: cap = cv2.VideoCapture(self.video_path) cap.set(1, 1) _, img = self.cap.read() - bordered_img = cv2.copyMakeBorder(img, 0, 0, 0, int(self.video_meta_data["width"]), borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0]) + bordered_img = cv2.copyMakeBorder( + img, + 0, + 0, + 0, + int(self.video_meta_data["width"]), + borderType=cv2.BORDER_CONSTANT, + value=[0, 0, 0], + ) cap.release() return bordered_img.shape[0], bordered_img.shape[1] @@ -259,47 +462,78 @@ def run(self): video_timer = SimbaTimer(start=True) max_dim = max(self.video_meta_data["width"], self.video_meta_data["height"]) self.scalers = {} - self.scalers["circle_size"] = int(TextOptions.RADIUS_SCALER.value / (TextOptions.RESOLUTION_SCALER.value / max_dim)) - self.scalers["font_size"] = float(TextOptions.FONT_SCALER.value / (TextOptions.RESOLUTION_SCALER.value / max_dim)) - self.scalers["space_size"] = int(TextOptions.SPACE_SCALER.value / (TextOptions.RESOLUTION_SCALER.value / max_dim)) - color_lst = create_color_palettes(self.roi_analyzer.animal_cnt, len(self.body_parts))[0] + self.scalers["circle_size"] = int( + TextOptions.RADIUS_SCALER.value + / (TextOptions.RESOLUTION_SCALER.value / max_dim) + ) + self.scalers["font_size"] = float( + TextOptions.FONT_SCALER.value + / (TextOptions.RESOLUTION_SCALER.value / max_dim) + ) + self.scalers["space_size"] = int( + TextOptions.SPACE_SCALER.value + / (TextOptions.RESOLUTION_SCALER.value / max_dim) + ) + color_lst = create_color_palettes( + self.roi_analyzer.animal_cnt, len(self.body_parts) + )[0] self.border_img_h, self.border_img_w = self.__get_bordered_img_size() self.loc_dict = self.__calc_text_locs() self.cnt_dict = self.__create_counters() self.shape_dicts = self.__create_shape_dicts() self.__calculate_cumulative() - check_video_and_data_frm_count_align(video=self.video_path, data=self.data_df, name=self.video_name, raise_error=False) + check_video_and_data_frm_count_align( + video=self.video_path, + data=self.data_df, + name=self.video_name, + raise_error=False, + ) data_lst = np.array_split(self.data_df.fillna(0), self.core_cnt) for cnt in range(len(data_lst)): data_lst[cnt]["group"] = cnt - print(f"Creating ROI images, multiprocessing (determined chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})...") + print( + f"Creating ROI images, multiprocessing (determined chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})..." + ) del self.roi_analyzer.logger - with multiprocessing.Pool(self.core_cnt, maxtasksperchild=self.maxtasksperchild) as pool: - constants = functools.partial(_roi_plotter_mp, - loc_dict=self.loc_dict, - scalers=self.scalers, - video_meta_data=self.video_meta_data, - save_temp_directory=self.temp_folder, - body_part_dict=self.bp_dict, - input_video_path=self.video_path, - roi_dict=self.roi_dict, - video_shape_names=self.shape_names, - shape_meta_data=self.shape_dicts, - colors=color_lst, - style_attr=self.style_attr, - animal_ids=self.animal_names, - threshold=self.threshold) - - for cnt, result in enumerate(pool.imap(constants, data_lst, chunksize=self.multiprocess_chunksize)): - print(f'Image batch {result+1} / {len(data_lst)} complete...') + with multiprocessing.Pool( + self.core_cnt, maxtasksperchild=self.maxtasksperchild + ) as pool: + constants = functools.partial( + _roi_plotter_mp, + loc_dict=self.loc_dict, + scalers=self.scalers, + video_meta_data=self.video_meta_data, + save_temp_directory=self.temp_folder, + body_part_dict=self.bp_dict, + input_video_path=self.video_path, + roi_dict=self.roi_dict, + video_shape_names=self.shape_names, + shape_meta_data=self.shape_dicts, + colors=color_lst, + style_attr=self.style_attr, + animal_ids=self.animal_names, + threshold=self.threshold, + ) + + for cnt, result in enumerate( + pool.imap(constants, data_lst, chunksize=self.multiprocess_chunksize) + ): + print(f"Image batch {result+1} / {len(data_lst)} complete...") print(f"Joining {self.video_name} multi-processed ROI video...") - concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.video_save_path, video_format="mp4") + concatenate_videos_in_folder( + in_folder=self.temp_folder, + save_path=self.video_save_path, + video_format="mp4", + ) video_timer.stop_timer() pool.terminate() pool.join() - stdout_success(msg=f"Video {self.video_name} created. ROI video saved at {self.video_save_path}", elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__, ) - + stdout_success( + msg=f"Video {self.video_name} created. ROI video saved at {self.video_save_path}", + elapsed_time=video_timer.elapsed_time_str, + source=self.__class__.__name__, + ) # test = ROIPlotMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/open_field_below/project_folder/project_config.ini', @@ -316,7 +550,6 @@ def run(self): # test.run() - # test = ROIPlotMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini', # video_path="/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/videos/SI_DAY3_308_CD1_PRESENT.mp4", # core_cnt=-1, @@ -332,7 +565,6 @@ def run(self): # test.run() - # # test = ROIPlotMultiprocess(ini_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini', # video_path="2022-06-20_NOB_DOT_4.mp4", diff --git a/simba/plotting/directing_animals_visualizer.py b/simba/plotting/directing_animals_visualizer.py index c7cd72fb1..2f7d6c8a5 100644 --- a/simba/plotting/directing_animals_visualizer.py +++ b/simba/plotting/directing_animals_visualizer.py @@ -1,31 +1,43 @@ __author__ = "Simon Nilsson" import os -from typing import Dict, Union, Any +from typing import Any, Dict, Union import cv2 import numpy as np -from simba.data_processors.directing_other_animals_calculator import DirectingOtherAnimalsAnalyzer +from simba.data_processors.directing_other_animals_calculator import \ + DirectingOtherAnimalsAnalyzer from simba.mixins.config_reader import ConfigReader from simba.mixins.plotting_mixin import PlottingMixin +from simba.utils.checks import (check_file_exist_and_readable, + check_if_keys_exist_in_dict, + check_if_valid_rgb_tuple, check_valid_array, + check_valid_lst, + check_video_and_data_frm_count_align) from simba.utils.data import create_color_palettes from simba.utils.enums import Formats, TextOptions -from simba.utils.errors import NoFilesFoundError, AnimalNumberError +from simba.utils.errors import AnimalNumberError, NoFilesFoundError from simba.utils.printing import stdout_success from simba.utils.read_write import get_fn_ext, get_video_meta_data, read_df from simba.utils.warnings import NoDataFoundWarning -from simba.utils.checks import check_file_exist_and_readable, check_if_keys_exist_in_dict, check_if_valid_rgb_tuple, check_video_and_data_frm_count_align, check_valid_lst, check_valid_array +DIRECTION_THICKNESS = "direction_thickness" +DIRECTIONALITY_COLOR = "directionality_color" +CIRCLE_SIZE = "circle_size" +HIGHLIGHT_ENDPOINTS = "highlight_endpoints" +SHOW_POSE = "show_pose" +ANIMAL_NAMES = "animal_names" -DIRECTION_THICKNESS = 'direction_thickness' -DIRECTIONALITY_COLOR = 'directionality_color' -CIRCLE_SIZE = 'circle_size' -HIGHLIGHT_ENDPOINTS = 'highlight_endpoints' -SHOW_POSE = 'show_pose' -ANIMAL_NAMES = 'animal_names' +STYLE_ATTR = [ + DIRECTION_THICKNESS, + DIRECTIONALITY_COLOR, + CIRCLE_SIZE, + HIGHLIGHT_ENDPOINTS, + SHOW_POSE, + ANIMAL_NAMES, +] -STYLE_ATTR = [DIRECTION_THICKNESS, DIRECTIONALITY_COLOR, CIRCLE_SIZE, HIGHLIGHT_ENDPOINTS, SHOW_POSE, ANIMAL_NAMES] class DirectingOtherAnimalsVisualizer(ConfigReader, PlottingMixin): """ @@ -54,62 +66,101 @@ class DirectingOtherAnimalsVisualizer(ConfigReader, PlottingMixin): >>> test.run() """ - def __init__(self, - config_path: Union[str, os.PathLike], - video_path: Union[str, os.PathLike], - style_attr: Dict[str, Any]): + + def __init__( + self, + config_path: Union[str, os.PathLike], + video_path: Union[str, os.PathLike], + style_attr: Dict[str, Any], + ): check_file_exist_and_readable(file_path=video_path) check_file_exist_and_readable(file_path=config_path) - check_if_keys_exist_in_dict(data=style_attr, key=STYLE_ATTR, name=f'{self.__class__.__name__} style_attr') + check_if_keys_exist_in_dict( + data=style_attr, + key=STYLE_ATTR, + name=f"{self.__class__.__name__} style_attr", + ) ConfigReader.__init__(self, config_path=config_path) PlottingMixin.__init__(self) if self.animal_cnt < 2: - raise AnimalNumberError("Cannot analyze directionality between animals in a project with less than two animals.", source=self.__class__.__name__,) + raise AnimalNumberError( + "Cannot analyze directionality between animals in a project with less than two animals.", + source=self.__class__.__name__, + ) self.animal_names = [k for k in self.animal_bp_dict.keys()] _, self.video_name, _ = get_fn_ext(video_path) - self.data_path = os.path.join(self.outlier_corrected_dir, f"{self.video_name}.{self.file_type}") + self.data_path = os.path.join( + self.outlier_corrected_dir, f"{self.video_name}.{self.file_type}" + ) if not os.path.isfile(self.data_path): - raise NoFilesFoundError( msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create directionality visualizations", source=self.__class__.__name__) - self.direction_analyzer = DirectingOtherAnimalsAnalyzer(config_path=config_path, - bool_tables=False, - summary_tables=False, - aggregate_statistics_tables=False, - data_paths=self.data_path) + raise NoFilesFoundError( + msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create directionality visualizations", + source=self.__class__.__name__, + ) + self.direction_analyzer = DirectingOtherAnimalsAnalyzer( + config_path=config_path, + bool_tables=False, + summary_tables=False, + aggregate_statistics_tables=False, + data_paths=self.data_path, + ) self.direction_analyzer.run() self.direction_analyzer.transpose_results() self.fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value) self.style_attr = style_attr self.direction_colors = {} if isinstance(self.style_attr[DIRECTIONALITY_COLOR], list): - check_valid_lst(data=self.style_attr[DIRECTIONALITY_COLOR], source=f'{self.__class__.__name__} colors', valid_dtypes=(tuple,), min_len=self.animal_cnt) + check_valid_lst( + data=self.style_attr[DIRECTIONALITY_COLOR], + source=f"{self.__class__.__name__} colors", + valid_dtypes=(tuple,), + min_len=self.animal_cnt, + ) for i in range(len(self.animal_names)): check_if_valid_rgb_tuple(data=self.style_attr[DIRECTIONALITY_COLOR][i]) - self.direction_colors[self.animal_names[i]] = (self.style_attr[DIRECTIONALITY_COLOR][i]) + self.direction_colors[self.animal_names[i]] = self.style_attr[ + DIRECTIONALITY_COLOR + ][i] if isinstance(self.style_attr[DIRECTIONALITY_COLOR], tuple): check_if_valid_rgb_tuple(self.style_attr[DIRECTIONALITY_COLOR]) for i in range(len(self.animal_names)): - self.direction_colors[self.animal_names[i]] = self.style_attr[DIRECTIONALITY_COLOR] + self.direction_colors[self.animal_names[i]] = self.style_attr[ + DIRECTIONALITY_COLOR + ] else: - self.random_colors = create_color_palettes(1, int(self.animal_cnt ** 2))[0] - self.random_colors = [[int(item) for item in sublist] for sublist in self.random_colors] + self.random_colors = create_color_palettes(1, int(self.animal_cnt**2))[0] + self.random_colors = [ + [int(item) for item in sublist] for sublist in self.random_colors + ] for cnt in range(len(self.animal_names)): - self.direction_colors[self.animal_names[cnt]] = (self.random_colors[cnt]) + self.direction_colors[self.animal_names[cnt]] = self.random_colors[cnt] self.data_dict = self.direction_analyzer.directionality_df_dict if not os.path.exists(self.directing_animals_video_output_path): os.makedirs(self.directing_animals_video_output_path) self.data_df = read_df(self.data_path, file_type=self.file_type) - self.video_save_path = os.path.join(self.directing_animals_video_output_path, f"{self.video_name}.mp4") + self.video_save_path = os.path.join( + self.directing_animals_video_output_path, f"{self.video_name}.mp4" + ) self.cap = cv2.VideoCapture(video_path) self.video_meta_data = get_video_meta_data(video_path) - check_video_and_data_frm_count_align(video=video_path, data=self.data_path, name=video_path, raise_error=False) + check_video_and_data_frm_count_align( + video=video_path, data=self.data_path, name=video_path, raise_error=False + ) print(f"Processing video {self.video_name}...") def run(self): video_data = self.data_dict[self.video_name] - self.writer = cv2.VideoWriter(self.video_save_path, self.fourcc, self.video_meta_data["fps"], (self.video_meta_data["width"], self.video_meta_data["height"])) + self.writer = cv2.VideoWriter( + self.video_save_path, + self.fourcc, + self.video_meta_data["fps"], + (self.video_meta_data["width"], self.video_meta_data["height"]), + ) if len(video_data) < 1: - NoDataFoundWarning(msg=f"SimBA skipping video {self.video_name}: No animals are directing each other in the video.") + NoDataFoundWarning( + msg=f"SimBA skipping video {self.video_name}: No animals are directing each other in the video." + ) else: frm_cnt = 0 while self.cap.isOpened(): @@ -117,35 +168,68 @@ def run(self): if ret: bp_data = self.data_df.iloc[frm_cnt] if self.style_attr[SHOW_POSE]: - for animal_cnt, (animal_name, animal_bps) in enumerate(self.animal_bp_dict.items()): - for bp_cnt, bp in enumerate(zip(animal_bps["X_bps"], animal_bps["Y_bps"])): + for animal_cnt, (animal_name, animal_bps) in enumerate( + self.animal_bp_dict.items() + ): + for bp_cnt, bp in enumerate( + zip(animal_bps["X_bps"], animal_bps["Y_bps"]) + ): x_bp, y_bp = bp_data[bp[0]], bp_data[bp[1]] - cv2.circle(img, (int(x_bp), int(y_bp)), self.style_attr[CIRCLE_SIZE], self.animal_bp_dict[animal_name]["colors"][bp_cnt], -1) + cv2.circle( + img, + (int(x_bp), int(y_bp)), + self.style_attr[CIRCLE_SIZE], + self.animal_bp_dict[animal_name]["colors"][bp_cnt], + -1, + ) if self.style_attr[ANIMAL_NAMES]: for animal_name, bp_v in self.animal_bp_dict.items(): - headers = [bp_v['X_bps'][-1], bp_v['Y_bps'][-1]] - bp_cords = self.data_df.loc[frm_cnt, headers].values.astype(np.int64) - cv2.putText(img, animal_name, (bp_cords[0], bp_cords[1]), TextOptions.FONT.value, 2, self.animal_bp_dict[animal_name]["colors"][0], 1) + headers = [bp_v["X_bps"][-1], bp_v["Y_bps"][-1]] + bp_cords = self.data_df.loc[frm_cnt, headers].values.astype( + np.int64 + ) + cv2.putText( + img, + animal_name, + (bp_cords[0], bp_cords[1]), + TextOptions.FONT.value, + 2, + self.animal_bp_dict[animal_name]["colors"][0], + 1, + ) if frm_cnt in list(video_data["Frame_#"].unique()): img_data = video_data[video_data["Frame_#"] == frm_cnt] for animal_name in img_data["Animal_1"].unique(): - animal_img_data = img_data[img_data["Animal_1"] == animal_name].reset_index(drop=True) - img = PlottingMixin.draw_lines_on_img(img=img, - start_positions=animal_img_data[['Eye_x', 'Eye_y']].values.astype(np.int64), - end_positions=animal_img_data[['Animal_2_bodypart_x', 'Animal_2_bodypart_y']].values.astype(np.int64), - color=tuple(self.direction_colors[animal_name]), - highlight_endpoint=self.style_attr[HIGHLIGHT_ENDPOINTS], - thickness=self.style_attr[DIRECTION_THICKNESS], - circle_size=self.style_attr[CIRCLE_SIZE]) + animal_img_data = img_data[ + img_data["Animal_1"] == animal_name + ].reset_index(drop=True) + img = PlottingMixin.draw_lines_on_img( + img=img, + start_positions=animal_img_data[ + ["Eye_x", "Eye_y"] + ].values.astype(np.int64), + end_positions=animal_img_data[ + ["Animal_2_bodypart_x", "Animal_2_bodypart_y"] + ].values.astype(np.int64), + color=tuple(self.direction_colors[animal_name]), + highlight_endpoint=self.style_attr[HIGHLIGHT_ENDPOINTS], + thickness=self.style_attr[DIRECTION_THICKNESS], + circle_size=self.style_attr[CIRCLE_SIZE], + ) frm_cnt += 1 self.writer.write(np.uint8(img)) - print(f"Frame: {frm_cnt} / {self.video_meta_data['frame_count']}. Video: {self.video_name}") + print( + f"Frame: {frm_cnt} / {self.video_meta_data['frame_count']}. Video: {self.video_name}" + ) else: break self.writer.release() self.timer.stop_timer() - stdout_success(msg=f"Directionality video {self.video_name} saved in {self.directing_animals_video_output_path} directory", elapsed_time=self.timer.elapsed_time_str) + stdout_success( + msg=f"Directionality video {self.video_name} saved in {self.directing_animals_video_output_path} directory", + elapsed_time=self.timer.elapsed_time_str, + ) # style_attr = {SHOW_POSE: True, diff --git a/simba/plotting/directing_animals_visualizer_mp.py b/simba/plotting/directing_animals_visualizer_mp.py index f30747fa9..d5a8fea94 100644 --- a/simba/plotting/directing_animals_visualizer_mp.py +++ b/simba/plotting/directing_animals_visualizer_mp.py @@ -4,45 +4,67 @@ import multiprocessing import os import platform -import pandas as pd +from typing import Any, Dict, Optional, Tuple, Union + import cv2 import numpy as np -from typing import Optional, Union, Dict, Any, Tuple +import pandas as pd -from simba.data_processors.directing_other_animals_calculator import DirectingOtherAnimalsAnalyzer +from simba.data_processors.directing_other_animals_calculator import \ + DirectingOtherAnimalsAnalyzer from simba.mixins.config_reader import ConfigReader from simba.mixins.plotting_mixin import PlottingMixin -from simba.utils.checks import check_file_exist_and_readable, check_if_keys_exist_in_dict, check_if_valid_rgb_tuple, check_valid_lst, check_video_and_data_frm_count_align, check_int +from simba.utils.checks import (check_file_exist_and_readable, + check_if_keys_exist_in_dict, + check_if_valid_rgb_tuple, check_int, + check_valid_lst, + check_video_and_data_frm_count_align) from simba.utils.data import create_color_palettes from simba.utils.enums import Formats, TextOptions from simba.utils.errors import AnimalNumberError, NoFilesFoundError from simba.utils.printing import SimbaTimer, stdout_success -from simba.utils.read_write import (concatenate_videos_in_folder, find_core_cnt, get_fn_ext, get_video_meta_data, read_df) +from simba.utils.read_write import (concatenate_videos_in_folder, + find_core_cnt, get_fn_ext, + get_video_meta_data, read_df) from simba.utils.warnings import NoDataFoundWarning -DIRECTION_THICKNESS = 'direction_thickness' -DIRECTIONALITY_COLOR = 'directionality_color' -CIRCLE_SIZE = 'circle_size' -HIGHLIGHT_ENDPOINTS = 'highlight_endpoints' -SHOW_POSE = 'show_pose' -ANIMAL_NAMES = 'animal_names' -STYLE_ATTR = [DIRECTION_THICKNESS, DIRECTIONALITY_COLOR, CIRCLE_SIZE, HIGHLIGHT_ENDPOINTS, SHOW_POSE, ANIMAL_NAMES] - - -def _directing_animals_mp(frm_range: Tuple[int, np.ndarray], - directionality_data: pd.DataFrame, - pose_data: pd.DataFrame, - style_attr: dict, - animal_bp_dict: dict, - save_temp_dir: str, - video_path: str, - video_meta_data: dict, - colors: list): +DIRECTION_THICKNESS = "direction_thickness" +DIRECTIONALITY_COLOR = "directionality_color" +CIRCLE_SIZE = "circle_size" +HIGHLIGHT_ENDPOINTS = "highlight_endpoints" +SHOW_POSE = "show_pose" +ANIMAL_NAMES = "animal_names" +STYLE_ATTR = [ + DIRECTION_THICKNESS, + DIRECTIONALITY_COLOR, + CIRCLE_SIZE, + HIGHLIGHT_ENDPOINTS, + SHOW_POSE, + ANIMAL_NAMES, +] + + +def _directing_animals_mp( + frm_range: Tuple[int, np.ndarray], + directionality_data: pd.DataFrame, + pose_data: pd.DataFrame, + style_attr: dict, + animal_bp_dict: dict, + save_temp_dir: str, + video_path: str, + video_meta_data: dict, + colors: list, +): batch = frm_range[0] fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value) start_frm, current_frm, end_frm = frm_range[1][0], frm_range[1][0], frm_range[1][-1] save_path = os.path.join(save_temp_dir, f"{batch}.mp4") - writer = cv2.VideoWriter(save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"], video_meta_data["height"])) + writer = cv2.VideoWriter( + save_path, + fourcc, + video_meta_data["fps"], + (video_meta_data["width"], video_meta_data["height"]), + ) cap = cv2.VideoCapture(video_path) cap.set(1, start_frm) while current_frm <= end_frm: @@ -50,30 +72,62 @@ def _directing_animals_mp(frm_range: Tuple[int, np.ndarray], if ret: frm_data = pose_data.iloc[current_frm] if style_attr[SHOW_POSE]: - for animal_cnt, (animal_name, animal_bps) in enumerate(animal_bp_dict.items()): - for bp_cnt, bp in enumerate(zip(animal_bps["X_bps"], animal_bps["Y_bps"])): + for animal_cnt, (animal_name, animal_bps) in enumerate( + animal_bp_dict.items() + ): + for bp_cnt, bp in enumerate( + zip(animal_bps["X_bps"], animal_bps["Y_bps"]) + ): x_bp, y_bp = frm_data[bp[0]], frm_data[bp[1]] - cv2.circle(img, (int(x_bp), int(y_bp)), style_attr[CIRCLE_SIZE], animal_bp_dict[animal_name]["colors"][bp_cnt], -1) + cv2.circle( + img, + (int(x_bp), int(y_bp)), + style_attr[CIRCLE_SIZE], + animal_bp_dict[animal_name]["colors"][bp_cnt], + -1, + ) if style_attr[ANIMAL_NAMES]: for animal_name, bp_data in animal_bp_dict.items(): - headers = [bp_data['X_bps'][-1], bp_data['Y_bps'][-1]] - bp_cords = pose_data.loc[current_frm, headers].values.astype(np.int64) - cv2.putText(img, animal_name, (bp_cords[0], bp_cords[1]), TextOptions.FONT.value, 2, animal_bp_dict[animal_name]["colors"][0], 1) + headers = [bp_data["X_bps"][-1], bp_data["Y_bps"][-1]] + bp_cords = pose_data.loc[current_frm, headers].values.astype( + np.int64 + ) + cv2.putText( + img, + animal_name, + (bp_cords[0], bp_cords[1]), + TextOptions.FONT.value, + 2, + animal_bp_dict[animal_name]["colors"][0], + 1, + ) if current_frm in list(directionality_data["Frame_#"].unique()): - img_data = directionality_data[directionality_data["Frame_#"] == current_frm] + img_data = directionality_data[ + directionality_data["Frame_#"] == current_frm + ] for animal_name in img_data["Animal_1"].unique(): - animal_img_data = img_data[img_data["Animal_1"] == animal_name].reset_index(drop=True) - img = PlottingMixin.draw_lines_on_img(img=img, - start_positions=animal_img_data[['Eye_x', 'Eye_y']].values.astype(np.int64), - end_positions=animal_img_data[['Animal_2_bodypart_x','Animal_2_bodypart_y']].values.astype(np.int64), - color=tuple(colors[animal_name]), - highlight_endpoint=style_attr[HIGHLIGHT_ENDPOINTS], - thickness=style_attr[DIRECTION_THICKNESS], - circle_size=style_attr[CIRCLE_SIZE]) + animal_img_data = img_data[ + img_data["Animal_1"] == animal_name + ].reset_index(drop=True) + img = PlottingMixin.draw_lines_on_img( + img=img, + start_positions=animal_img_data[ + ["Eye_x", "Eye_y"] + ].values.astype(np.int64), + end_positions=animal_img_data[ + ["Animal_2_bodypart_x", "Animal_2_bodypart_y"] + ].values.astype(np.int64), + color=tuple(colors[animal_name]), + highlight_endpoint=style_attr[HIGHLIGHT_ENDPOINTS], + thickness=style_attr[DIRECTION_THICKNESS], + circle_size=style_attr[CIRCLE_SIZE], + ) current_frm += 1 writer.write(np.uint8(img)) - print(f"Frame: {current_frm} / {video_meta_data['frame_count']}. Core batch: {batch}") + print( + f"Frame: {current_frm} / {video_meta_data['frame_count']}. Core batch: {batch}" + ) else: break @@ -81,6 +135,7 @@ def _directing_animals_mp(frm_range: Tuple[int, np.ndarray], writer.release() return batch + class DirectingOtherAnimalsVisualizerMultiprocess(ConfigReader, PlottingMixin): """ Class for visualizing when animals are directing towards body-parts of other animals using multiprocessing. @@ -109,103 +164,170 @@ class DirectingOtherAnimalsVisualizerMultiprocess(ConfigReader, PlottingMixin): >>> test.run() """ - def __init__(self, - config_path: Union[str, os.PathLike], - video_path: Union[str, os.PathLike], - style_attr: Dict[str, Any], - core_cnt: Optional[int] = -1): + def __init__( + self, + config_path: Union[str, os.PathLike], + video_path: Union[str, os.PathLike], + style_attr: Dict[str, Any], + core_cnt: Optional[int] = -1, + ): if platform.system() == "Darwin": multiprocessing.set_start_method("spawn", force=True) check_file_exist_and_readable(file_path=video_path) check_file_exist_and_readable(file_path=config_path) - check_if_keys_exist_in_dict(data=style_attr, key=STYLE_ATTR, name=f'{self.__class__.__name__} style_attr') - check_int(name=f"{self.__class__.__name__} core_cnt", value=core_cnt, min_value=-1, max_value=find_core_cnt()[0]) - if core_cnt == -1: core_cnt = find_core_cnt()[0] + check_if_keys_exist_in_dict( + data=style_attr, + key=STYLE_ATTR, + name=f"{self.__class__.__name__} style_attr", + ) + check_int( + name=f"{self.__class__.__name__} core_cnt", + value=core_cnt, + min_value=-1, + max_value=find_core_cnt()[0], + ) + if core_cnt == -1: + core_cnt = find_core_cnt()[0] ConfigReader.__init__(self, config_path=config_path) PlottingMixin.__init__(self) if self.animal_cnt < 2: - raise AnimalNumberError("Cannot analyze directionality between animals in a project with less than two animals.", source=self.__class__.__name__,) + raise AnimalNumberError( + "Cannot analyze directionality between animals in a project with less than two animals.", + source=self.__class__.__name__, + ) self.animal_names = [k for k in self.animal_bp_dict.keys()] _, self.video_name, _ = get_fn_ext(video_path) - self.data_path = os.path.join(self.outlier_corrected_dir, f"{self.video_name}.{self.file_type}") + self.data_path = os.path.join( + self.outlier_corrected_dir, f"{self.video_name}.{self.file_type}" + ) if not os.path.isfile(self.data_path): - raise NoFilesFoundError( msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create directionality visualizations", source=self.__class__.__name__) - self.direction_analyzer = DirectingOtherAnimalsAnalyzer(config_path=config_path, - bool_tables=False, - summary_tables=False, - aggregate_statistics_tables=False, - data_paths=self.data_path) + raise NoFilesFoundError( + msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create directionality visualizations", + source=self.__class__.__name__, + ) + self.direction_analyzer = DirectingOtherAnimalsAnalyzer( + config_path=config_path, + bool_tables=False, + summary_tables=False, + aggregate_statistics_tables=False, + data_paths=self.data_path, + ) self.direction_analyzer.run() self.direction_analyzer.transpose_results() self.fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value) self.style_attr = style_attr self.direction_colors = {} if isinstance(self.style_attr[DIRECTIONALITY_COLOR], list): - check_valid_lst(data=self.style_attr[DIRECTIONALITY_COLOR], source=f'{self.__class__.__name__} colors', valid_dtypes=(tuple,), min_len=self.animal_cnt) + check_valid_lst( + data=self.style_attr[DIRECTIONALITY_COLOR], + source=f"{self.__class__.__name__} colors", + valid_dtypes=(tuple,), + min_len=self.animal_cnt, + ) for i in range(len(self.animal_names)): check_if_valid_rgb_tuple(data=self.style_attr[DIRECTIONALITY_COLOR][i]) - self.direction_colors[self.animal_names[i]] = (self.style_attr[DIRECTIONALITY_COLOR][i]) + self.direction_colors[self.animal_names[i]] = self.style_attr[ + DIRECTIONALITY_COLOR + ][i] if isinstance(self.style_attr[DIRECTIONALITY_COLOR], tuple): check_if_valid_rgb_tuple(self.style_attr[DIRECTIONALITY_COLOR]) for i in range(len(self.animal_names)): - self.direction_colors[self.animal_names[i]] = self.style_attr[DIRECTIONALITY_COLOR] + self.direction_colors[self.animal_names[i]] = self.style_attr[ + DIRECTIONALITY_COLOR + ] else: - self.random_colors = create_color_palettes(1, int(self.animal_cnt ** 2))[0] - self.random_colors = [[int(item) for item in sublist] for sublist in self.random_colors] + self.random_colors = create_color_palettes(1, int(self.animal_cnt**2))[0] + self.random_colors = [ + [int(item) for item in sublist] for sublist in self.random_colors + ] for cnt in range(len(self.animal_names)): - self.direction_colors[self.animal_names[cnt]] = (self.random_colors[cnt]) + self.direction_colors[self.animal_names[cnt]] = self.random_colors[cnt] self.data_dict = self.direction_analyzer.directionality_df_dict if not os.path.exists(self.directing_animals_video_output_path): os.makedirs(self.directing_animals_video_output_path) self.data_df = read_df(self.data_path, file_type=self.file_type) - self.video_save_path = os.path.join(self.directing_animals_video_output_path, f"{self.video_name}.mp4") + self.video_save_path = os.path.join( + self.directing_animals_video_output_path, f"{self.video_name}.mp4" + ) self.cap = cv2.VideoCapture(video_path) self.video_meta_data = get_video_meta_data(video_path) - check_video_and_data_frm_count_align(video=video_path, data=self.data_path, name=video_path, raise_error=False) - self.video_save_path = os.path.join(self.directing_animals_video_output_path, f"{self.video_name}.mp4") + check_video_and_data_frm_count_align( + video=video_path, data=self.data_path, name=video_path, raise_error=False + ) + self.video_save_path = os.path.join( + self.directing_animals_video_output_path, f"{self.video_name}.mp4" + ) if not os.path.exists(self.directing_animals_video_output_path): os.makedirs(self.directing_animals_video_output_path) - self.save_path = os.path.join(self.directing_animals_video_output_path, self.video_name + ".mp4") - self.save_temp_path = os.path.join(self.directing_animals_video_output_path, "temp") + self.save_path = os.path.join( + self.directing_animals_video_output_path, self.video_name + ".mp4" + ) + self.save_temp_path = os.path.join( + self.directing_animals_video_output_path, "temp" + ) if os.path.exists(self.save_temp_path): self.remove_a_folder(folder_dir=self.save_temp_path) os.makedirs(self.save_temp_path) self.core_cnt, self.video_path = core_cnt, video_path print(f"Processing video {self.video_name}...") - def run(self): video_data = self.data_dict[self.video_name] - self.writer = cv2.VideoWriter(self.video_save_path, self.fourcc, self.video_meta_data["fps"], (self.video_meta_data["width"], self.video_meta_data["height"])) + self.writer = cv2.VideoWriter( + self.video_save_path, + self.fourcc, + self.video_meta_data["fps"], + (self.video_meta_data["width"], self.video_meta_data["height"]), + ) if len(video_data) < 1: - NoDataFoundWarning(msg=f"SimBA skipping video {self.video_name}: No animals are directing each other in the video.") + NoDataFoundWarning( + msg=f"SimBA skipping video {self.video_name}: No animals are directing each other in the video." + ) else: - frm_data = np.array_split(list(range(0, self.video_meta_data["frame_count"]+1)), self.core_cnt) + frm_data = np.array_split( + list(range(0, self.video_meta_data["frame_count"] + 1)), self.core_cnt + ) frm_ranges = [] - for i in range(len(frm_data)): frm_ranges.append((i, frm_data[i])) - print(f"Creating directing images, multiprocessing (chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})...") - with multiprocessing.Pool(self.core_cnt, maxtasksperchild=self.maxtasksperchild) as pool: - constants = functools.partial(_directing_animals_mp, - directionality_data=video_data, - pose_data=self.data_df, - video_meta_data=self.video_meta_data, - style_attr=self.style_attr, - save_temp_dir=self.save_temp_path, - video_path=self.video_path, - animal_bp_dict=self.animal_bp_dict, - colors=self.direction_colors) - for cnt, result in enumerate(pool.imap(constants, frm_ranges, chunksize=self.multiprocess_chunksize)): - print(f'Core batch {result+1} complete...') + for i in range(len(frm_data)): + frm_ranges.append((i, frm_data[i])) + print( + f"Creating directing images, multiprocessing (chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})..." + ) + with multiprocessing.Pool( + self.core_cnt, maxtasksperchild=self.maxtasksperchild + ) as pool: + constants = functools.partial( + _directing_animals_mp, + directionality_data=video_data, + pose_data=self.data_df, + video_meta_data=self.video_meta_data, + style_attr=self.style_attr, + save_temp_dir=self.save_temp_path, + video_path=self.video_path, + animal_bp_dict=self.animal_bp_dict, + colors=self.direction_colors, + ) + for cnt, result in enumerate( + pool.imap( + constants, frm_ranges, chunksize=self.multiprocess_chunksize + ) + ): + print(f"Core batch {result+1} complete...") print(f"Joining {self.video_name} multi-processed video...") - concatenate_videos_in_folder(in_folder=self.save_temp_path, - save_path=self.save_path, - video_format="mp4", - remove_splits=True) + concatenate_videos_in_folder( + in_folder=self.save_temp_path, + save_path=self.save_path, + video_format="mp4", + remove_splits=True, + ) self.timer.stop_timer() pool.terminate() pool.join() - stdout_success(msg=f"Video {self.video_name} complete. Video saved in {self.directing_animals_video_output_path} directory", elapsed_time=self.timer.elapsed_time_str) + stdout_success( + msg=f"Video {self.video_name} complete. Video saved in {self.directing_animals_video_output_path} directory", + elapsed_time=self.timer.elapsed_time_str, + ) # style_attr = {SHOW_POSE: True, diff --git a/simba/roi_tools/ROI_analyzer.py b/simba/roi_tools/ROI_analyzer.py index 1c26f5cb7..a6f43475f 100644 --- a/simba/roi_tools/ROI_analyzer.py +++ b/simba/roi_tools/ROI_analyzer.py @@ -1,20 +1,26 @@ __author__ = "Simon Nilsson" import os -from typing import Optional, Union, List +from typing import List, Optional, Union + import numpy as np import pandas as pd from simba.mixins.config_reader import ConfigReader from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin -from simba.mixins.feature_extraction_supplement_mixin import FeatureExtractionSupplemental -from simba.utils.checks import check_file_exist_and_readable, check_float, check_valid_lst, check_all_file_names_are_represented_in_video_log, check_that_column_exist +from simba.mixins.feature_extraction_supplement_mixin import \ + FeatureExtractionSupplemental +from simba.utils.checks import ( + check_all_file_names_are_represented_in_video_log, + check_file_exist_and_readable, check_float, check_that_column_exist, + check_valid_lst) +from simba.utils.data import detect_bouts, slice_roi_dict_for_video from simba.utils.enums import Keys -from simba.utils.errors import (MissingColumnsError, CountError, ROICoordinatesNotFoundError) +from simba.utils.errors import (CountError, MissingColumnsError, + ROICoordinatesNotFoundError) from simba.utils.printing import stdout_success -from simba.utils.read_write import get_fn_ext, read_df, read_data_paths +from simba.utils.read_write import get_fn_ext, read_data_paths, read_df from simba.utils.warnings import NoDataFoundWarning -from simba.utils.data import slice_roi_dict_for_video, detect_bouts class ROIAnalyzer(ConfigReader, FeatureExtractionMixin): @@ -38,31 +44,54 @@ class ROIAnalyzer(ConfigReader, FeatureExtractionMixin): >>> test.save() """ - def __init__(self, - config_path: Union[str, os.PathLike], - data_path: Optional[Union[str, os.PathLike, List[str]]] = None, - detailed_bout_data: Optional[bool] = False, - calculate_distances: Optional[bool] = False, - threshold: Optional[float] = 0.0, - body_parts: Optional[List[str]] = None): + def __init__( + self, + config_path: Union[str, os.PathLike], + data_path: Optional[Union[str, os.PathLike, List[str]]] = None, + detailed_bout_data: Optional[bool] = False, + calculate_distances: Optional[bool] = False, + threshold: Optional[float] = 0.0, + body_parts: Optional[List[str]] = None, + ): check_file_exist_and_readable(file_path=config_path) ConfigReader.__init__(self, config_path=config_path) if not os.path.isfile(self.roi_coordinates_path): - raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path) + raise ROICoordinatesNotFoundError( + expected_file_path=self.roi_coordinates_path + ) self.read_roi_data() FeatureExtractionMixin.__init__(self) if detailed_bout_data and (not os.path.exists(self.detailed_roi_data_dir)): os.makedirs(self.detailed_roi_data_dir) - self.data_paths = read_data_paths(path=data_path, default=self.outlier_corrected_paths, default_name=self.outlier_corrected_dir, file_type=self.file_type) + self.data_paths = read_data_paths( + path=data_path, + default=self.outlier_corrected_paths, + default_name=self.outlier_corrected_dir, + file_type=self.file_type, + ) - check_float(name="Body-part probability threshold", value=threshold, min_value=0.0, max_value=1.0) - check_valid_lst(data=body_parts, source=f'{self.__class__.__name__} body-parts', valid_dtypes=(str,)) + check_float( + name="Body-part probability threshold", + value=threshold, + min_value=0.0, + max_value=1.0, + ) + check_valid_lst( + data=body_parts, + source=f"{self.__class__.__name__} body-parts", + valid_dtypes=(str,), + ) if len(set(body_parts)) != len(body_parts): - raise CountError(msg=f'All body-part entries have to be unique. Got {body_parts}', source=self.__class__.__name__) + raise CountError( + msg=f"All body-part entries have to be unique. Got {body_parts}", + source=self.__class__.__name__, + ) self.bp_dict, self.bp_lk = {}, {} for bp in body_parts: - animal = self.find_animal_name_from_body_part_name(bp_name=bp, bp_dict=self.animal_bp_dict) + animal = self.find_animal_name_from_body_part_name( + bp_name=bp, bp_dict=self.animal_bp_dict + ) self.bp_dict[animal] = [f'{bp}_{"x"}', f'{bp}_{"y"}', f'{bp}_{"p"}'] self.bp_lk[animal] = bp self.roi_headers = [v for k, v in self.bp_dict.items()] @@ -71,100 +100,266 @@ def __init__(self, self.detailed_bout_data = detailed_bout_data def run(self): - check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths) - self.movements_df = pd.DataFrame(columns=["VIDEO", "ANIMAL", "SHAPE", "MEASUREMENT", "VALUE"]) - self.entry_results = pd.DataFrame(columns=["VIDEO", "ANIMAL", "SHAPE", "ENTRY COUNT"]) - self.time_results = pd.DataFrame(columns=["VIDEO", "ANIMAL", "SHAPE", "TIME (S)"]) + check_all_file_names_are_represented_in_video_log( + video_info_df=self.video_info_df, data_paths=self.data_paths + ) + self.movements_df = pd.DataFrame( + columns=["VIDEO", "ANIMAL", "SHAPE", "MEASUREMENT", "VALUE"] + ) + self.entry_results = pd.DataFrame( + columns=["VIDEO", "ANIMAL", "SHAPE", "ENTRY COUNT"] + ) + self.time_results = pd.DataFrame( + columns=["VIDEO", "ANIMAL", "SHAPE", "TIME (S)"] + ) self.roi_bout_results = [] for file_cnt, file_path in enumerate(self.data_paths): _, video_name, _ = get_fn_ext(file_path) print(f"Analysing ROI data for video {video_name}...") - video_settings, pix_per_mm, self.fps = self.read_video_info(video_name=video_name) - self.roi_dict, video_shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=video_name) + video_settings, pix_per_mm, self.fps = self.read_video_info( + video_name=video_name + ) + self.roi_dict, video_shape_names = slice_roi_dict_for_video( + data=self.roi_dict, video_name=video_name + ) if video_shape_names == 0: - NoDataFoundWarning(msg=f"Skipping video {video_name}: No user-defined ROI data found for this video...") + NoDataFoundWarning( + msg=f"Skipping video {video_name}: No user-defined ROI data found for this video..." + ) continue else: self.data_df = read_df(file_path, self.file_type).reset_index(drop=True) if len(self.bp_headers) != len(self.data_df.columns): - raise MissingColumnsError(msg=f"The data file {file_path} contains {len(self.data_df.columns)} body-part columns, but the project is made for {len(self.bp_headers)} body-part columns as suggested by the {self.body_parts_path} file", source=self.__class__.__name__) + raise MissingColumnsError( + msg=f"The data file {file_path} contains {len(self.data_df.columns)} body-part columns, but the project is made for {len(self.bp_headers)} body-part columns as suggested by the {self.body_parts_path} file", + source=self.__class__.__name__, + ) self.data_df.columns = self.bp_headers - check_that_column_exist(df=self.data_df, column_name=self.roi_headers, file_name=file_path) + check_that_column_exist( + df=self.data_df, column_name=self.roi_headers, file_name=file_path + ) for animal_name, bp_names in self.bp_dict.items(): - animal_df = self.data_df[self.bp_dict[animal_name]].reset_index(drop=True) + animal_df = self.data_df[self.bp_dict[animal_name]].reset_index( + drop=True + ) animal_bout_results = {} for _, row in self.roi_dict[Keys.ROI_RECTANGLES.value].iterrows(): - roi_coords = np.array([[row["topLeftX"], row["topLeftY"]], [row["Bottom_right_X"], row["Bottom_right_Y"]]]) - animal_df[row['Name']] = FeatureExtractionMixin.framewise_inside_rectangle_roi(bp_location=animal_df.values[:, 0:2], roi_coords=roi_coords) - animal_df.loc[animal_df[bp_names[2]] < self.threshold, row["Name"]] = 0 - roi_bouts = detect_bouts(data_df=animal_df, target_lst=[row['Name']], fps=self.fps) - roi_bouts['ANIMAL'] = animal_name; roi_bouts['VIDEO'] = video_name + roi_coords = np.array( + [ + [row["topLeftX"], row["topLeftY"]], + [row["Bottom_right_X"], row["Bottom_right_Y"]], + ] + ) + animal_df[row["Name"]] = ( + FeatureExtractionMixin.framewise_inside_rectangle_roi( + bp_location=animal_df.values[:, 0:2], + roi_coords=roi_coords, + ) + ) + animal_df.loc[ + animal_df[bp_names[2]] < self.threshold, row["Name"] + ] = 0 + roi_bouts = detect_bouts( + data_df=animal_df, target_lst=[row["Name"]], fps=self.fps + ) + roi_bouts["ANIMAL"] = animal_name + roi_bouts["VIDEO"] = video_name self.roi_bout_results.append(roi_bouts) - animal_bout_results[row['Name']] = roi_bouts - self.entry_results.loc[len(self.entry_results)] = [video_name, animal_name, row['Name'], len(roi_bouts)] - self.time_results.loc[len(self.time_results)] = [video_name, animal_name, row['Name'], roi_bouts['Bout_time'].sum()] + animal_bout_results[row["Name"]] = roi_bouts + self.entry_results.loc[len(self.entry_results)] = [ + video_name, + animal_name, + row["Name"], + len(roi_bouts), + ] + self.time_results.loc[len(self.time_results)] = [ + video_name, + animal_name, + row["Name"], + roi_bouts["Bout_time"].sum(), + ] for _, row in self.roi_dict[Keys.ROI_CIRCLES.value].iterrows(): center_x, center_y = row["centerX"], row["centerY"] - animal_df[f'{row["Name"]}_distance'] = FeatureExtractionMixin.framewise_euclidean_distance_roi(location_1=animal_df.values[:, 0:2], location_2=np.array([center_x, center_y])) + animal_df[f'{row["Name"]}_distance'] = ( + FeatureExtractionMixin.framewise_euclidean_distance_roi( + location_1=animal_df.values[:, 0:2], + location_2=np.array([center_x, center_y]), + ) + ) animal_df[row["Name"]] = 0 - animal_df.loc[animal_df[row["Name"]] <= row["radius"], row["Name"]] = 1 - animal_df.loc[animal_df[bp_names[2]] < self.threshold, row["Name"]] = 0 - roi_bouts = detect_bouts(data_df=animal_df, target_lst=[row['Name']], fps=self.fps) - roi_bouts['ANIMAL'] = animal_name; roi_bouts['VIDEO'] = video_name + animal_df.loc[ + animal_df[row["Name"]] <= row["radius"], row["Name"] + ] = 1 + animal_df.loc[ + animal_df[bp_names[2]] < self.threshold, row["Name"] + ] = 0 + roi_bouts = detect_bouts( + data_df=animal_df, target_lst=[row["Name"]], fps=self.fps + ) + roi_bouts["ANIMAL"] = animal_name + roi_bouts["VIDEO"] = video_name self.roi_bout_results.append(roi_bouts) - animal_bout_results[row['Name']] = roi_bouts - self.entry_results.loc[len(self.entry_results)] = [video_name, animal_name, row['Name'], len(roi_bouts)] - self.time_results.loc[len(self.time_results)] = [video_name, animal_name, row['Name'], roi_bouts['Bout_time'].sum()] + animal_bout_results[row["Name"]] = roi_bouts + self.entry_results.loc[len(self.entry_results)] = [ + video_name, + animal_name, + row["Name"], + len(roi_bouts), + ] + self.time_results.loc[len(self.time_results)] = [ + video_name, + animal_name, + row["Name"], + roi_bouts["Bout_time"].sum(), + ] for _, row in self.roi_dict[Keys.ROI_POLYGONS.value].iterrows(): - roi_coords = np.array(list(zip(row["vertices"][:, 0], row["vertices"][:, 1]))) - animal_df[row['Name']] = FeatureExtractionMixin.framewise_inside_polygon_roi(bp_location=animal_df.values[:, 0:2], roi_coords=roi_coords) - animal_df.loc[animal_df[bp_names[2]] < self.threshold, row["Name"]] = 0 - roi_bouts = detect_bouts(data_df=animal_df, target_lst=[row['Name']], fps=self.fps) - roi_bouts['ANIMAL'] = animal_name; roi_bouts['VIDEO'] = video_name + roi_coords = np.array( + list(zip(row["vertices"][:, 0], row["vertices"][:, 1])) + ) + animal_df[row["Name"]] = ( + FeatureExtractionMixin.framewise_inside_polygon_roi( + bp_location=animal_df.values[:, 0:2], + roi_coords=roi_coords, + ) + ) + animal_df.loc[ + animal_df[bp_names[2]] < self.threshold, row["Name"] + ] = 0 + roi_bouts = detect_bouts( + data_df=animal_df, target_lst=[row["Name"]], fps=self.fps + ) + roi_bouts["ANIMAL"] = animal_name + roi_bouts["VIDEO"] = video_name self.roi_bout_results.append(roi_bouts) - animal_bout_results[row['Name']] = roi_bouts - self.entry_results.loc[len(self.entry_results)] = [video_name, animal_name, row['Name'], len(roi_bouts)] - self.time_results.loc[len(self.time_results)] = [video_name, animal_name, row['Name'], roi_bouts['Bout_time'].sum()] + animal_bout_results[row["Name"]] = roi_bouts + self.entry_results.loc[len(self.entry_results)] = [ + video_name, + animal_name, + row["Name"], + len(roi_bouts), + ] + self.time_results.loc[len(self.time_results)] = [ + video_name, + animal_name, + row["Name"], + roi_bouts["Bout_time"].sum(), + ] if self.calculate_distances: for roi_name, roi_data in animal_bout_results.items(): if len(roi_data) == 0: - self.movements_df.loc[len(self.movements_df)] = [video_name, animal_name, roi_name, "Movement (cm)", 0] - self.movements_df.loc[len(self.movements_df)] = [video_name, animal_name, roi_name, "Average velocity (cm/s)", "None"] + self.movements_df.loc[len(self.movements_df)] = [ + video_name, + animal_name, + roi_name, + "Movement (cm)", + 0, + ] + self.movements_df.loc[len(self.movements_df)] = [ + video_name, + animal_name, + roi_name, + "Average velocity (cm/s)", + "None", + ] else: distances, velocities = [], [] - roi_frames = roi_data[['Start_frame', 'End_frame']].values + roi_frames = roi_data[ + ["Start_frame", "End_frame"] + ].values for event in roi_frames: - event_pose = animal_df.loc[np.arange(event[0], event[1]+1), bp_names] - event_pose = event_pose[event_pose[bp_names[2]] > self.threshold][bp_names[:2]].values + event_pose = animal_df.loc[ + np.arange(event[0], event[1] + 1), bp_names + ] + event_pose = event_pose[ + event_pose[bp_names[2]] > self.threshold + ][bp_names[:2]].values if event_pose.shape[0] > 1: - distance, velocity = FeatureExtractionSupplemental.distance_and_velocity(x=event_pose, fps=self.fps, pixels_per_mm=pix_per_mm, centimeters=True) - distances.append(distance); velocities.append(velocity) - self.movements_df.loc[len(self.movements_df)] = [video_name, animal_name, roi_name, "Movement (cm)", sum(distances)] - self.movements_df.loc[len(self.movements_df)] = [video_name, animal_name, roi_name, "Average velocity (cm/s)", np.average(velocities)] + distance, velocity = ( + FeatureExtractionSupplemental.distance_and_velocity( + x=event_pose, + fps=self.fps, + pixels_per_mm=pix_per_mm, + centimeters=True, + ) + ) + distances.append(distance) + velocities.append(velocity) + self.movements_df.loc[len(self.movements_df)] = [ + video_name, + animal_name, + roi_name, + "Movement (cm)", + sum(distances), + ] + self.movements_df.loc[len(self.movements_df)] = [ + video_name, + animal_name, + roi_name, + "Average velocity (cm/s)", + np.average(velocities), + ] self.detailed_df = pd.concat(self.roi_bout_results, axis=0) - self.detailed_df = self.detailed_df.rename(columns={"Event": "SHAPE NAME", "Start_time": "START TIME", 'End Time': 'END TIME', 'Start_frame': 'START FRAME', 'End_frame': 'END FRAME', 'Bout_time': 'DURATION (S)'}) + self.detailed_df = self.detailed_df.rename( + columns={ + "Event": "SHAPE NAME", + "Start_time": "START TIME", + "End Time": "END TIME", + "Start_frame": "START FRAME", + "End_frame": "END FRAME", + "Bout_time": "DURATION (S)", + } + ) self.detailed_df["BODY-PART"] = self.detailed_df["ANIMAL"].map(self.bp_lk) - self.detailed_df = self.detailed_df[['VIDEO', 'ANIMAL', 'BODY-PART', 'SHAPE NAME', 'START TIME', 'END TIME', 'START FRAME', 'END FRAME', 'DURATION (S)']] + self.detailed_df = self.detailed_df[ + [ + "VIDEO", + "ANIMAL", + "BODY-PART", + "SHAPE NAME", + "START TIME", + "END TIME", + "START FRAME", + "END FRAME", + "DURATION (S)", + ] + ] def save(self): self.entry_results["BODY-PART"] = self.entry_results["ANIMAL"].map(self.bp_lk) self.time_results["BODY-PART"] = self.time_results["ANIMAL"].map(self.bp_lk) - self.entry_results = self.entry_results[['VIDEO', 'ANIMAL', 'BODY-PART', 'SHAPE', 'ENTRY COUNT']] - self.time_results = self.time_results[['VIDEO', 'ANIMAL', 'BODY-PART', 'SHAPE', 'TIME (S)']] - self.entry_results.to_csv(os.path.join(self.logs_path, f'{"ROI_entry_data"}_{self.datetime}.csv')) - self.time_results.to_csv(os.path.join(self.logs_path, f'{"ROI_time_data"}_{self.datetime}.csv')) + self.entry_results = self.entry_results[ + ["VIDEO", "ANIMAL", "BODY-PART", "SHAPE", "ENTRY COUNT"] + ] + self.time_results = self.time_results[ + ["VIDEO", "ANIMAL", "BODY-PART", "SHAPE", "TIME (S)"] + ] + self.entry_results.to_csv( + os.path.join(self.logs_path, f'{"ROI_entry_data"}_{self.datetime}.csv') + ) + self.time_results.to_csv( + os.path.join(self.logs_path, f'{"ROI_time_data"}_{self.datetime}.csv') + ) if self.detailed_bout_data: - detailed_path = os.path.join(self.logs_path, f'{"Detailed_ROI_data"}_{self.datetime}.csv') + detailed_path = os.path.join( + self.logs_path, f'{"Detailed_ROI_data"}_{self.datetime}.csv' + ) self.detailed_df.to_csv(detailed_path) - print(f'Detailed ROI data saved at {detailed_path}...') + print(f"Detailed ROI data saved at {detailed_path}...") if self.calculate_distances: - movement_path = os.path.join(self.logs_path, f'{"ROI_movement_data"}_{self.datetime}.csv') + movement_path = os.path.join( + self.logs_path, f'{"ROI_movement_data"}_{self.datetime}.csv' + ) self.movements_df["BODY-PART"] = self.movements_df["ANIMAL"].map(self.bp_lk) - self.movements_df = self.movements_df[['VIDEO', 'ANIMAL', 'BODY-PART', 'SHAPE', 'MEASUREMENT', 'VALUE']] + self.movements_df = self.movements_df[ + ["VIDEO", "ANIMAL", "BODY-PART", "SHAPE", "MEASUREMENT", "VALUE"] + ] self.movements_df.to_csv(movement_path) - print(f'ROI aggregate movement data saved at {movement_path}...') - stdout_success(msg=f'ROI time and ROI entry saved in the {self.logs_path} directory in CSV format.') + print(f"ROI aggregate movement data saved at {movement_path}...") + stdout_success( + msg=f"ROI time and ROI entry saved in the {self.logs_path} directory in CSV format." + ) + # test = ROIAnalyzer(config_path = r"/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini", # data_path=None, diff --git a/simba/roi_tools/ROI_directing_analyzer.py b/simba/roi_tools/ROI_directing_analyzer.py index 8701c79be..e15e2546f 100644 --- a/simba/roi_tools/ROI_directing_analyzer.py +++ b/simba/roi_tools/ROI_directing_analyzer.py @@ -1,7 +1,7 @@ __author__ = "Simon Nilsson" import os -from typing import Union, Optional +from typing import Optional, Union import numpy as np import pandas as pd @@ -9,11 +9,12 @@ from simba.mixins.config_reader import ConfigReader from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin -from simba.utils.read_write import get_fn_ext, read_df, read_data_paths -from simba.utils.errors import ROICoordinatesNotFoundError, InvalidInputError from simba.utils.checks import check_file_exist_and_readable from simba.utils.data import slice_roi_dict_for_video +from simba.utils.errors import InvalidInputError, ROICoordinatesNotFoundError from simba.utils.printing import SimbaTimer, stdout_success +from simba.utils.read_write import get_fn_ext, read_data_paths, read_df + class DirectingROIAnalyzer(ConfigReader, FeatureExtractionMixin): """ @@ -33,37 +34,57 @@ class DirectingROIAnalyzer(ConfigReader, FeatureExtractionMixin): >>> test.save() """ - def __init__(self, - config_path: Union[str, os.PathLike], - data_path: Optional[Union[str, os.PathLike]] = None): + def __init__( + self, + config_path: Union[str, os.PathLike], + data_path: Optional[Union[str, os.PathLike]] = None, + ): check_file_exist_and_readable(file_path=config_path) ConfigReader.__init__(self, config_path=config_path) FeatureExtractionMixin.__init__(self, config_path=config_path) - self.data_paths = read_data_paths(path=data_path, default=self.outlier_corrected_paths, default_name=self.outlier_corrected_dir, file_type=self.file_type) + self.data_paths = read_data_paths( + path=data_path, + default=self.outlier_corrected_paths, + default_name=self.outlier_corrected_dir, + file_type=self.file_type, + ) if not os.path.isfile(self.roi_coordinates_path): - raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path) + raise ROICoordinatesNotFoundError( + expected_file_path=self.roi_coordinates_path + ) if not self.check_directionality_viable()[0]: - raise InvalidInputError(msg='Cannot compute directionality towards ROIs. The ear and nose data is tracked in the project', source=self.__class__.__name__) + raise InvalidInputError( + msg="Cannot compute directionality towards ROIs. The ear and nose data is tracked in the project", + source=self.__class__.__name__, + ) self.read_roi_data() self.direct_bp_dict = self.check_directionality_cords() - def __format_direction_data(self, - direction_data: np.ndarray, - nose_arr: np.ndarray, - roi_center: np.ndarray, - animal_name: str, - shape_name: str) -> pd.DataFrame: + def __format_direction_data( + self, + direction_data: np.ndarray, + nose_arr: np.ndarray, + roi_center: np.ndarray, + animal_name: str, + shape_name: str, + ) -> pd.DataFrame: x_min = np.minimum(direction_data[:, 1], nose_arr[:, 0]) y_min = np.minimum(direction_data[:, 2], nose_arr[:, 1]) delta_x = abs((direction_data[:, 1] - nose_arr[:, 0]) / 2) delta_y = abs((direction_data[:, 2] - nose_arr[:, 1]) / 2) x_middle, y_middle = np.add(x_min, delta_x), np.add(y_min, delta_y) - direction_data = np.concatenate((y_middle.reshape(-1, 1), direction_data), axis=1) - direction_data = np.concatenate((x_middle.reshape(-1, 1), direction_data), axis=1) + direction_data = np.concatenate( + (y_middle.reshape(-1, 1), direction_data), axis=1 + ) + direction_data = np.concatenate( + (x_middle.reshape(-1, 1), direction_data), axis=1 + ) direction_data = np.delete(direction_data, [2, 3, 4], 1) - bp_data = pd.DataFrame(direction_data, columns=["Eye_x", "Eye_y", "Directing_BOOL"]) + bp_data = pd.DataFrame( + direction_data, columns=["Eye_x", "Eye_y", "Directing_BOOL"] + ) bp_data["ROI_x"] = roi_center[0] bp_data["ROI_y"] = roi_center[1] bp_data = bp_data[["Eye_x", "Eye_y", "ROI_x", "ROI_y", "Directing_BOOL"]] @@ -119,18 +140,21 @@ def calc(A, B, C): return results - def __find_roi_intersections(self, - bp_data: pd.DataFrame, - shape_info: dict): - + def __find_roi_intersections(self, bp_data: pd.DataFrame, shape_info: dict): eye_lines = bp_data[["Eye_x", "Eye_y", "ROI_x", "ROI_y"]].values.astype(int) roi_lines = None if shape_info["Shape_type"] == "Rectangle": top_left_x, top_left_y = (shape_info["topLeftX"], shape_info["topLeftY"]) - bottom_right_x, bottom_right_y = (shape_info["Bottom_right_X"], shape_info["Bottom_right_Y"]) + bottom_right_x, bottom_right_y = ( + shape_info["Bottom_right_X"], + shape_info["Bottom_right_Y"], + ) top_right_x, top_right_y = top_left_x + shape_info["width"], top_left_y - bottom_left_x, bottom_left_y = (bottom_right_x - shape_info["width"], bottom_right_y) + bottom_left_x, bottom_left_y = ( + bottom_right_x - shape_info["width"], + bottom_right_y, + ) roi_lines = np.array( [ [top_left_x, top_left_y, bottom_left_x, bottom_left_y], @@ -181,43 +205,72 @@ def run(self): for file_cnt, file_path in enumerate(self.data_paths): _, self.video_name, _ = get_fn_ext(file_path) video_timer = SimbaTimer(start=True) - print(f'Analyzing ROI directionality in video {self.video_name}...') + print(f"Analyzing ROI directionality in video {self.video_name}...") data_df = read_df(file_path=file_path, file_type=self.file_type) - video_roi_dict, shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name) + video_roi_dict, shape_names = slice_roi_dict_for_video( + data=self.roi_dict, video_name=self.video_name + ) for animal_name, bps in self.direct_bp_dict.items(): - ear_left_arr = data_df[[bps["Ear_left"]["X_bps"], bps["Ear_left"]["Y_bps"]]].values - ear_right_arr = data_df[[bps["Ear_right"]["X_bps"], bps["Ear_right"]["Y_bps"]]].values + ear_left_arr = data_df[ + [bps["Ear_left"]["X_bps"], bps["Ear_left"]["Y_bps"]] + ].values + ear_right_arr = data_df[ + [bps["Ear_right"]["X_bps"], bps["Ear_right"]["Y_bps"]] + ].values nose_arr = data_df[[bps["Nose"]["X_bps"], bps["Nose"]["Y_bps"]]].values for roi_type, roi_type_data in video_roi_dict.items(): for _, row in roi_type_data.iterrows(): roi_center = np.array([row["Center_X"], row["Center_Y"]]) roi_name = row["Name"] - direction_data = FeatureExtractionMixin.jitted_line_crosses_to_static_targets(left_ear_array=ear_left_arr, - right_ear_array=ear_right_arr, - nose_array=nose_arr, - target_array=roi_center) - bp_data = self.__format_direction_data(direction_data=direction_data, - nose_arr=nose_arr, - roi_center=roi_center, - animal_name=animal_name, - shape_name=roi_name) - - eye_roi_intersections = pd.DataFrame(self.__find_roi_intersections(bp_data=bp_data, shape_info=row), columns=["ROI_edge_1_x", "ROI_edge_1_y", "ROI_edge_2_x", "ROI_edge_2_y"]) - self.results.append(pd.concat([bp_data, eye_roi_intersections], axis=1)) + direction_data = FeatureExtractionMixin.jitted_line_crosses_to_static_targets( + left_ear_array=ear_left_arr, + right_ear_array=ear_right_arr, + nose_array=nose_arr, + target_array=roi_center, + ) + bp_data = self.__format_direction_data( + direction_data=direction_data, + nose_arr=nose_arr, + roi_center=roi_center, + animal_name=animal_name, + shape_name=roi_name, + ) + + eye_roi_intersections = pd.DataFrame( + self.__find_roi_intersections( + bp_data=bp_data, shape_info=row + ), + columns=[ + "ROI_edge_1_x", + "ROI_edge_1_y", + "ROI_edge_2_x", + "ROI_edge_2_y", + ], + ) + self.results.append( + pd.concat([bp_data, eye_roi_intersections], axis=1) + ) video_timer.stop_timer() - print(f'ROI directionality analyzed in video {self.video_name}... (elapsed time: {video_timer.elapsed_time_str}s)') + print( + f"ROI directionality analyzed in video {self.video_name}... (elapsed time: {video_timer.elapsed_time_str}s)" + ) self.results_df = pd.concat(self.results, axis=0) def save(self, path: Optional[Union[str, os.PathLike]] = None): - if not hasattr(self, 'results_df'): - raise InvalidInputError(msg='Run the ROI direction analyzer before saving') + if not hasattr(self, "results_df"): + raise InvalidInputError(msg="Run the ROI direction analyzer before saving") if path is None: - path = os.path.join(self.logs_path, f'ROI_directionality_summary_{self.datetime}.csv') + path = os.path.join( + self.logs_path, f"ROI_directionality_summary_{self.datetime}.csv" + ) self.results_df.to_csv(path) - stdout_success(msg=f'Detailed ROI directionality data saved in {path}', source=self.__class__.__name__) + stdout_success( + msg=f"Detailed ROI directionality data saved in {path}", + source=self.__class__.__name__, + ) + # # test = DirectingROIAnalyzer(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini') # test.run() # test.save() - diff --git a/simba/roi_tools/ROI_feature_analyzer.py b/simba/roi_tools/ROI_feature_analyzer.py index 365ed9ff0..ca0b76a68 100644 --- a/simba/roi_tools/ROI_feature_analyzer.py +++ b/simba/roi_tools/ROI_feature_analyzer.py @@ -9,12 +9,17 @@ from simba.mixins.config_reader import ConfigReader from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin from simba.roi_tools.ROI_directing_analyzer import DirectingROIAnalyzer -from simba.utils.enums import TagNames, Keys -from simba.utils.errors import NoFilesFoundError, CountError, BodypartColumnNotFoundError, ROICoordinatesNotFoundError -from simba.utils.printing import SimbaTimer, log_event, stdout_success -from simba.utils.read_write import get_fn_ext, read_df, write_df, find_files_of_filetypes_in_directory, read_data_paths -from simba.utils.checks import check_valid_lst, check_that_column_exist, check_all_file_names_are_represented_in_video_log, check_file_exist_and_readable +from simba.utils.checks import ( + check_all_file_names_are_represented_in_video_log, + check_file_exist_and_readable, check_that_column_exist, check_valid_lst) from simba.utils.data import slice_roi_dict_for_video +from simba.utils.enums import Keys, TagNames +from simba.utils.errors import (BodypartColumnNotFoundError, CountError, + NoFilesFoundError, ROICoordinatesNotFoundError) +from simba.utils.printing import SimbaTimer, log_event, stdout_success +from simba.utils.read_write import (find_files_of_filetypes_in_directory, + get_fn_ext, read_data_paths, read_df, + write_df) class ROIFeatureCreator(ConfigReader, FeatureExtractionMixin): @@ -40,116 +45,259 @@ class ROIFeatureCreator(ConfigReader, FeatureExtractionMixin): >>> roi_featurizer.save() """ - def __init__(self, - config_path: Union[str, os.PathLike], - body_parts: List[str], - data_path: Optional[Union[str, os.PathLike]] = None, - append_data: Optional[bool] = False): + def __init__( + self, + config_path: Union[str, os.PathLike], + body_parts: List[str], + data_path: Optional[Union[str, os.PathLike]] = None, + append_data: Optional[bool] = False, + ): - check_valid_lst(data=body_parts, source=f'{self.__class__.__name__} body-parts', valid_dtypes=(str,), min_len=1) + check_valid_lst( + data=body_parts, + source=f"{self.__class__.__name__} body-parts", + valid_dtypes=(str,), + min_len=1, + ) if len(set(body_parts)) != len(body_parts): - raise CountError(msg=f'All body-part entries have to be unique. Got {body_parts}', source=self.__class__.__name__) - log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals())) + raise CountError( + msg=f"All body-part entries have to be unique. Got {body_parts}", + source=self.__class__.__name__, + ) + log_event( + logger_name=str(__class__.__name__), + log_type=TagNames.CLASS_INIT.value, + msg=self.create_log_msg_from_init_args(locals=locals()), + ) ConfigReader.__init__(self, config_path=config_path) FeatureExtractionMixin.__init__(self, config_path=config_path) self.read_roi_data() if not os.path.isfile(self.roi_coordinates_path): - raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path) + raise ROICoordinatesNotFoundError( + expected_file_path=self.roi_coordinates_path + ) for bp in body_parts: if bp not in self.body_parts_lst: - raise BodypartColumnNotFoundError(msg=f'The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}', source=self.__class__.__name__) + raise BodypartColumnNotFoundError( + msg=f"The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}", + source=self.__class__.__name__, + ) self.bp_lk = {} for cnt, bp in enumerate(body_parts): - animal = self.find_animal_name_from_body_part_name(bp_name=bp, bp_dict=self.animal_bp_dict) - self.bp_lk[cnt] = [animal, bp, [f'{bp}_{"x"}', f'{bp}_{"y"}', f'{bp}_{"p"}']] + animal = self.find_animal_name_from_body_part_name( + bp_name=bp, bp_dict=self.animal_bp_dict + ) + self.bp_lk[cnt] = [ + animal, + bp, + [f'{bp}_{"x"}', f'{bp}_{"y"}', f'{bp}_{"p"}'], + ] self.roi_directing_viable = self.check_directionality_viable()[0] - log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals())) - self.data_paths = read_data_paths(path=data_path, default=self.outlier_corrected_paths, default_name=self.outlier_corrected_dir, file_type=self.file_type) + log_event( + logger_name=str(__class__.__name__), + log_type=TagNames.CLASS_INIT.value, + msg=self.create_log_msg_from_init_args(locals=locals()), + ) + self.data_paths = read_data_paths( + path=data_path, + default=self.outlier_corrected_paths, + default_name=self.outlier_corrected_dir, + file_type=self.file_type, + ) if self.roi_directing_viable: print("Directionality calculations are VIABLE.") - self.directing_analyzer = DirectingROIAnalyzer( config_path=config_path, data_path=self.data_paths) + self.directing_analyzer = DirectingROIAnalyzer( + config_path=config_path, data_path=self.data_paths + ) self.directing_analyzer.run() self.dr = self.directing_analyzer.results_df else: self.directing_analyzer = None self.dr = None - if len(self.outlier_corrected_paths) == 0: raise NoFilesFoundError(msg=f'No data found in the {self.outlier_corrected_dir} directory', source=self.__class__.__name__) - if len(self.feature_file_paths) == 0: raise NoFilesFoundError(msg=f'No data found in the {self.features_dir} directory', source=self.__class__.__name__) + if len(self.outlier_corrected_paths) == 0: + raise NoFilesFoundError( + msg=f"No data found in the {self.outlier_corrected_dir} directory", + source=self.__class__.__name__, + ) + if len(self.feature_file_paths) == 0: + raise NoFilesFoundError( + msg=f"No data found in the {self.features_dir} directory", + source=self.__class__.__name__, + ) self.append_data = append_data - print(f"Processing {len(self.outlier_corrected_paths)} video(s) for ROI features...") - + print( + f"Processing {len(self.outlier_corrected_paths)} video(s) for ROI features..." + ) def run(self): - check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.outlier_corrected_paths) - self.summary = pd.DataFrame(columns=['VIDEO', 'ANIMAL', 'SHAPE NAME', 'MEASUREMENT', 'VALUE']) + check_all_file_names_are_represented_in_video_log( + video_info_df=self.video_info_df, data_paths=self.outlier_corrected_paths + ) + self.summary = pd.DataFrame( + columns=["VIDEO", "ANIMAL", "SHAPE NAME", "MEASUREMENT", "VALUE"] + ) if self.append_data: _o_paths = set([get_fn_ext(x)[1] for x in self.outlier_corrected_paths]) _f_paths = set([get_fn_ext(x)[1] for x in self.feature_file_paths]) o_paths = _o_paths.difference(_f_paths) f_paths = _f_paths.difference(_o_paths) if len(o_paths) != 0 or len(f_paths) != 0: - raise NoFilesFoundError(msg=f'Before appending ROI features, make sure each video is represented in both the {self.outlier_corrected_dir} and {self.features_dir} directory, you have data files that are only represented in one folder {o_paths} {f_paths}', source=self.__class__.__name__) + raise NoFilesFoundError( + msg=f"Before appending ROI features, make sure each video is represented in both the {self.outlier_corrected_dir} and {self.features_dir} directory, you have data files that are only represented in one folder {o_paths} {f_paths}", + source=self.__class__.__name__, + ) for file_cnt, file_path in enumerate(self.data_paths): video_timer = SimbaTimer(start=True) _, self.video_name, _ = get_fn_ext(file_path) - features_file_path = os.path.join(self.features_dir, f'{self.video_name}.{self.file_type}') - self.out_df = read_df(file_path=features_file_path, file_type=self.file_type) - _, self.pixels_per_mm, self.fps = self.read_video_info(video_name=self.video_name) + features_file_path = os.path.join( + self.features_dir, f"{self.video_name}.{self.file_type}" + ) + self.out_df = read_df( + file_path=features_file_path, file_type=self.file_type + ) + _, self.pixels_per_mm, self.fps = self.read_video_info( + video_name=self.video_name + ) data_df = read_df(file_path=file_path, file_type=self.file_type) - self.video_roi_dict, self.shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name) + self.video_roi_dict, self.shape_names = slice_roi_dict_for_video( + data=self.roi_dict, video_name=self.video_name + ) for animal_cnt, animal_data in self.bp_lk.items(): animal_name, body_part_name, bp_cols = animal_data - check_that_column_exist(df=data_df, column_name=bp_cols, file_name=file_path) + check_that_column_exist( + df=data_df, column_name=bp_cols, file_name=file_path + ) animal_df = data_df[bp_cols] for _, row in self.video_roi_dict[Keys.ROI_RECTANGLES.value].iterrows(): - roi_border = np.array([[row["topLeftX"], row["topLeftY"]], [row["Bottom_right_X"], row["Bottom_right_Y"]]]) + roi_border = np.array( + [ + [row["topLeftX"], row["topLeftY"]], + [row["Bottom_right_X"], row["Bottom_right_Y"]], + ] + ) roi_name = row["Name"] roi_center = np.array([row["Center_X"], row["Center_Y"]]) c = f"{roi_name} {animal_name} {body_part_name} distance" - self.out_df[c] = FeatureExtractionMixin.framewise_euclidean_distance_roi(location_1=animal_df.values[:, 0:2], location_2=roi_center, px_per_mm=self.pixels_per_mm) - self.summary.loc[len(self.summary)] = [self.video_name, animal_name, roi_name, 'Average distance (mm)', self.out_df[c].mean().round(4)] + self.out_df[c] = ( + FeatureExtractionMixin.framewise_euclidean_distance_roi( + location_1=animal_df.values[:, 0:2], + location_2=roi_center, + px_per_mm=self.pixels_per_mm, + ) + ) + self.summary.loc[len(self.summary)] = [ + self.video_name, + animal_name, + roi_name, + "Average distance (mm)", + self.out_df[c].mean().round(4), + ] c = f"{roi_name} {animal_name} {body_part_name} in zone" - self.out_df[c] = FeatureExtractionMixin.framewise_inside_rectangle_roi(bp_location=animal_df.values[:, 0:2], roi_coords=roi_border) + self.out_df[c] = ( + FeatureExtractionMixin.framewise_inside_rectangle_roi( + bp_location=animal_df.values[:, 0:2], roi_coords=roi_border + ) + ) for _, row in self.video_roi_dict[Keys.ROI_CIRCLES.value].iterrows(): roi_center = np.array([row["centerX"], row["centerY"]]) roi_name, radius = row["Name"], row["radius"] c = f"{roi_name} {animal_name} {body_part_name} distance" - self.out_df[c] = FeatureExtractionMixin.framewise_euclidean_distance_roi(location_1=animal_df.values[:, 0:2], location_2=roi_center, px_per_mm=self.pixels_per_mm) - self.summary.loc[len(self.summary)] = [self.video_name, animal_name, roi_name, 'Average distance (mm)', self.out_df[c].mean().round(4)] + self.out_df[c] = ( + FeatureExtractionMixin.framewise_euclidean_distance_roi( + location_1=animal_df.values[:, 0:2], + location_2=roi_center, + px_per_mm=self.pixels_per_mm, + ) + ) + self.summary.loc[len(self.summary)] = [ + self.video_name, + animal_name, + roi_name, + "Average distance (mm)", + self.out_df[c].mean().round(4), + ] in_zone = f"{roi_name} {animal_name} {body_part_name} in zone" self.out_df[in_zone] = 0 self.out_df.loc[self.out_df[c] <= row["radius"], in_zone] = 1 for _, row in self.roi_dict[Keys.ROI_POLYGONS.value].iterrows(): - roi_coords = np.array(list(zip(row["vertices"][:, 0], row["vertices"][:, 1]))) + roi_coords = np.array( + list(zip(row["vertices"][:, 0], row["vertices"][:, 1])) + ) roi_center = np.array([row["Center_X"], row["Center_Y"]]) - roi_name = row['Name'] + roi_name = row["Name"] c = f"{roi_name} {animal_name} {body_part_name} distance" - self.out_df[c] = FeatureExtractionMixin.framewise_euclidean_distance_roi(location_1=animal_df.values[:, 0:2], location_2=roi_center, px_per_mm=self.pixels_per_mm) - self.summary.loc[len(self.summary)] = [self.video_name, animal_name, roi_name, 'Average distance (mm)', self.out_df[c].mean().round(4)] + self.out_df[c] = ( + FeatureExtractionMixin.framewise_euclidean_distance_roi( + location_1=animal_df.values[:, 0:2], + location_2=roi_center, + px_per_mm=self.pixels_per_mm, + ) + ) + self.summary.loc[len(self.summary)] = [ + self.video_name, + animal_name, + roi_name, + "Average distance (mm)", + self.out_df[c].mean().round(4), + ] c = f"{roi_name} {animal_name} {body_part_name} in zone" - self.out_df[c] = FeatureExtractionMixin.framewise_inside_polygon_roi(bp_location=animal_df.values[:, 0:2], roi_coords=roi_coords) + self.out_df[c] = ( + FeatureExtractionMixin.framewise_inside_polygon_roi( + bp_location=animal_df.values[:, 0:2], roi_coords=roi_coords + ) + ) if self.roi_directing_viable: - animal_dr = self.dr.loc[(self.dr['Video'] == self.video_name) & (self.dr['Animal'] == animal_name)] + animal_dr = self.dr.loc[ + (self.dr["Video"] == self.video_name) + & (self.dr["Animal"] == animal_name) + ] for shape_name in self.shape_names: c = f"{shape_name} {animal_name} facing" - animal_shape_idx = list(animal_dr.loc[(animal_dr['ROI'] == shape_name) & (animal_dr['Directing_BOOL'] == 1)]['Frame']) + animal_shape_idx = list( + animal_dr.loc[ + (animal_dr["ROI"] == shape_name) + & (animal_dr["Directing_BOOL"] == 1) + ]["Frame"] + ) self.out_df[c] = 0 self.out_df.loc[animal_shape_idx, c] = 1 - self.summary.loc[len(self.summary)] = [self.video_name, animal_name, shape_name, 'Total direction time (s)', round((self.out_df[c].sum() / self.fps), 4)] + self.summary.loc[len(self.summary)] = [ + self.video_name, + animal_name, + shape_name, + "Total direction time (s)", + round((self.out_df[c].sum() / self.fps), 4), + ] video_timer.stop_timer() if self.append_data: - write_df(df=self.out_df, file_type=self.file_type, save_path=features_file_path) - print(f'New file with ROI features created at {features_file_path} saved (File {file_cnt+1}/{len(self.outlier_corrected_paths)}), elapsed time: {video_timer.elapsed_time_str}s') + write_df( + df=self.out_df, + file_type=self.file_type, + save_path=features_file_path, + ) + print( + f"New file with ROI features created at {features_file_path} saved (File {file_cnt+1}/{len(self.outlier_corrected_paths)}), elapsed time: {video_timer.elapsed_time_str}s" + ) self.timer.stop_timer() - stdout_success(msg=f'ROI features analysed for {len(self.data_paths)} videos', elapsed_time=self.timer.elapsed_time_str) + stdout_success( + msg=f"ROI features analysed for {len(self.data_paths)} videos", + elapsed_time=self.timer.elapsed_time_str, + ) def save(self): - save_path = os.path.join(self.logs_path, f"ROI_features_summary_{self.datetime}.csv") + save_path = os.path.join( + self.logs_path, f"ROI_features_summary_{self.datetime}.csv" + ) self.summary.to_csv(save_path) print(f"ROI feature summary data saved at {save_path}") self.timer.stop_timer() - stdout_success(msg=f'{len(self.outlier_corrected_paths)} new files with ROI features saved in {self.features_dir}', elapsed_time=self.timer.elapsed_time_str) + stdout_success( + msg=f"{len(self.outlier_corrected_paths)} new files with ROI features saved in {self.features_dir}", + elapsed_time=self.timer.elapsed_time_str, + ) + + # # roi_featurizer = ROIFeatureCreator(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', # body_parts=['Nose_1', 'Nose_2'], diff --git a/simba/roi_tools/ROI_time_bin_calculator.py b/simba/roi_tools/ROI_time_bin_calculator.py index 78e15c4b8..05ec9b00e 100644 --- a/simba/roi_tools/ROI_time_bin_calculator.py +++ b/simba/roi_tools/ROI_time_bin_calculator.py @@ -10,12 +10,13 @@ from simba.mixins.config_reader import ConfigReader from simba.mixins.feature_extraction_supplement_mixin import \ FeatureExtractionSupplemental -#from simba.roi_tools.ROI_analyzer import ROIAnalyzer +# from simba.roi_tools.ROI_analyzer import ROIAnalyzer from simba.sandbox.ROI_analyzer import ROIAnalyzer from simba.utils.checks import check_float, check_if_filepath_list_is_empty -from simba.utils.errors import FrameRangeError, ROICoordinatesNotFoundError, BodypartColumnNotFoundError, DuplicationError +from simba.utils.errors import (BodypartColumnNotFoundError, DuplicationError, + FrameRangeError, ROICoordinatesNotFoundError) from simba.utils.printing import SimbaTimer, stdout_success -from simba.utils.read_write import get_fn_ext, read_df, read_data_paths +from simba.utils.read_write import get_fn_ext, read_data_paths, read_df class ROITimebinCalculator(ConfigReader): @@ -44,44 +45,110 @@ class ROITimebinCalculator(ConfigReader): >>> calculator.save() """ - def __init__(self, - config_path: Union[str, os.PathLike], - bin_length: float, - body_parts: List[str], - data_path: Optional[Union[str, os.PathLike, List[str]]] = None, - threshold: Optional[float] = 0.0, - movement: Optional[bool] = False): + def __init__( + self, + config_path: Union[str, os.PathLike], + bin_length: float, + body_parts: List[str], + data_path: Optional[Union[str, os.PathLike, List[str]]] = None, + threshold: Optional[float] = 0.0, + movement: Optional[bool] = False, + ): ConfigReader.__init__(self, config_path=config_path) if not os.path.isfile(self.roi_coordinates_path): - raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path) + raise ROICoordinatesNotFoundError( + expected_file_path=self.roi_coordinates_path + ) check_float(name="bin_length", value=bin_length, min_value=10e-6) check_float(name="threshold", value=threshold, min_value=0.0, max_value=1.0) - self.data_paths = read_data_paths(path=data_path, default=self.outlier_corrected_paths, default_name=self.outlier_corrected_dir, file_type=self.file_type) + self.data_paths = read_data_paths( + path=data_path, + default=self.outlier_corrected_paths, + default_name=self.outlier_corrected_dir, + file_type=self.file_type, + ) self.read_roi_data() - self.bin_length, self.body_parts, self.threshold = (bin_length, body_parts, threshold) - self.save_path_time = os.path.join(self.logs_path, f"ROI_time_bins_{bin_length}s_time_data_{self.datetime}.csv") - self.save_path_entries = os.path.join(self.logs_path, f"ROI_time_bins_{bin_length}s_entry_data_{self.datetime}.csv") + self.bin_length, self.body_parts, self.threshold = ( + bin_length, + body_parts, + threshold, + ) + self.save_path_time = os.path.join( + self.logs_path, f"ROI_time_bins_{bin_length}s_time_data_{self.datetime}.csv" + ) + self.save_path_entries = os.path.join( + self.logs_path, + f"ROI_time_bins_{bin_length}s_entry_data_{self.datetime}.csv", + ) for bp in body_parts: if bp not in self.body_parts_lst: - raise BodypartColumnNotFoundError(msg=f'The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}', source=self.__class__.__name__) + raise BodypartColumnNotFoundError( + msg=f"The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}", + source=self.__class__.__name__, + ) if len(set(body_parts)) != len(body_parts): - raise DuplicationError(msg=f'All body-part entries have to be unique. Got {body_parts}', source=self.__class__.__name__) - self.roi_analyzer = ROIAnalyzer(config_path=self.config_path, data_path=self.outlier_corrected_dir, calculate_distances=False, threshold=threshold, body_parts=body_parts) + raise DuplicationError( + msg=f"All body-part entries have to be unique. Got {body_parts}", + source=self.__class__.__name__, + ) + self.roi_analyzer = ROIAnalyzer( + config_path=self.config_path, + data_path=self.outlier_corrected_dir, + calculate_distances=False, + threshold=threshold, + body_parts=body_parts, + ) self.roi_analyzer.run() self.animal_names = list(self.roi_analyzer.bp_dict.keys()) self.bp_dict = self.roi_analyzer.bp_dict self.entries_exits_df = self.roi_analyzer.detailed_df self.movement = movement if movement: - self.save_path_movement_velocity = os.path.join(self.logs_path, f"ROI_time_bins_{bin_length}s_movement_velocity_data_{self.datetime}.csv") - self.movement_timebins = TimeBinsMovementCalculator(config_path=config_path, bin_length=bin_length, body_parts=body_parts, plots=False) + self.save_path_movement_velocity = os.path.join( + self.logs_path, + f"ROI_time_bins_{bin_length}s_movement_velocity_data_{self.datetime}.csv", + ) + self.movement_timebins = TimeBinsMovementCalculator( + config_path=config_path, + bin_length=bin_length, + body_parts=body_parts, + plots=False, + ) self.movement_timebins.run() def run(self): - self.results_entries = pd.DataFrame(columns=["VIDEO","SHAPE","ANIMAL","BODY-PART","TIME BIN #","ENTRY COUNT",]) - self.results_time = pd.DataFrame(columns=["VIDEO","SHAPE","ANIMAL","BODY-PART","TIME BIN #","TIME INSIDE SHAPE (S)"]) - self.results_movement_velocity = pd.DataFrame(columns=["VIDEO","SHAPE","ANIMAL","BODY-PART","TIME BIN #","DISTANCE (CM)","VELOCITY (CM/S)"]) + self.results_entries = pd.DataFrame( + columns=[ + "VIDEO", + "SHAPE", + "ANIMAL", + "BODY-PART", + "TIME BIN #", + "ENTRY COUNT", + ] + ) + self.results_time = pd.DataFrame( + columns=[ + "VIDEO", + "SHAPE", + "ANIMAL", + "BODY-PART", + "TIME BIN #", + "TIME INSIDE SHAPE (S)", + ] + ) + self.results_movement_velocity = pd.DataFrame( + columns=[ + "VIDEO", + "SHAPE", + "ANIMAL", + "BODY-PART", + "TIME BIN #", + "DISTANCE (CM)", + "VELOCITY (CM/S)", + ] + ) print(f"Analyzing time-bin data for {len(self.data_paths)} video(s)...") for file_cnt, file_path in enumerate(self.data_paths): video_timer = SimbaTimer(start=True) @@ -89,44 +156,103 @@ def run(self): _, px_per_mm, fps = self.read_video_info(video_name=self.video_name) frames_per_bin = int(fps * self.bin_length) if frames_per_bin == 0: - raise FrameRangeError(msg=f"The specified time-bin length of {self.bin_length} is TOO SHORT for video {self.video_name} which has a specified FPS of {fps}. This results in time bins that are LESS THAN a single frame.", source=self.__class__.__name__) - video_frms = list(range(0, len(read_df(file_path=file_path, file_type=self.file_type)))) - frame_bins = [video_frms[i : i + (frames_per_bin)] for i in range(0, len(video_frms), frames_per_bin)] - self.video_data = self.entries_exits_df[self.entries_exits_df["VIDEO"] == self.video_name] - for animal_name, shape_name in list(itertools.product(self.animal_names, self.shape_names)): - data_df = self.video_data.loc[(self.video_data["SHAPE NAME"] == shape_name) & (self.video_data["ANIMAL"] == animal_name)] + raise FrameRangeError( + msg=f"The specified time-bin length of {self.bin_length} is TOO SHORT for video {self.video_name} which has a specified FPS of {fps}. This results in time bins that are LESS THAN a single frame.", + source=self.__class__.__name__, + ) + video_frms = list( + range(0, len(read_df(file_path=file_path, file_type=self.file_type))) + ) + frame_bins = [ + video_frms[i : i + (frames_per_bin)] + for i in range(0, len(video_frms), frames_per_bin) + ] + self.video_data = self.entries_exits_df[ + self.entries_exits_df["VIDEO"] == self.video_name + ] + for animal_name, shape_name in list( + itertools.product(self.animal_names, self.shape_names) + ): + data_df = self.video_data.loc[ + (self.video_data["SHAPE NAME"] == shape_name) + & (self.video_data["ANIMAL"] == animal_name) + ] body_part = self.bp_dict[animal_name][0][:-2] entry_frms = list(data_df["START FRAME"]) - inside_shape_frms = [list(range(x, y)) for x, y in zip(list(data_df["START FRAME"].astype(int)), list(data_df["END FRAME"].astype(int) + 1))] + inside_shape_frms = [ + list(range(x, y)) + for x, y in zip( + list(data_df["START FRAME"].astype(int)), + list(data_df["END FRAME"].astype(int) + 1), + ) + ] inside_shape_frms = [i for s in inside_shape_frms for i in s] for bin_cnt, bin_frms in enumerate(frame_bins): - frms_inside_roi_in_timebin = [x for x in inside_shape_frms if x in bin_frms] + frms_inside_roi_in_timebin = [ + x for x in inside_shape_frms if x in bin_frms + ] entry_roi_in_timebin = [x for x in entry_frms if x in bin_frms] - self.results_time.loc[len(self.results_time)] = [self.video_name,shape_name,animal_name,body_part,bin_cnt,len(frms_inside_roi_in_timebin) / fps] - self.results_entries.loc[len(self.results_entries)] = [self.video_name,shape_name,animal_name,body_part,bin_cnt,len(entry_roi_in_timebin)] + self.results_time.loc[len(self.results_time)] = [ + self.video_name, + shape_name, + animal_name, + body_part, + bin_cnt, + len(frms_inside_roi_in_timebin) / fps, + ] + self.results_entries.loc[len(self.results_entries)] = [ + self.video_name, + shape_name, + animal_name, + body_part, + bin_cnt, + len(entry_roi_in_timebin), + ] if self.movement: if len(frms_inside_roi_in_timebin) > 0: - bin_move = (self.movement_timebins.movement_dict[self.video_name].iloc[frms_inside_roi_in_timebin].values.flatten().astype(np.float32)) + bin_move = ( + self.movement_timebins.movement_dict[self.video_name] + .iloc[frms_inside_roi_in_timebin] + .values.flatten() + .astype(np.float32) + ) print(bin_move) - _, velocity = (FeatureExtractionSupplemental.distance_and_velocity(x=bin_move,fps=fps, pixels_per_mm=1, centimeters=True)) - self.results_movement_velocity.loc[len(self.results_movement_velocity)] = [self.video_name, - shape_name, - animal_name, - body_part, - bin_cnt, - bin_move[1:].sum() / 10, - velocity] + _, velocity = ( + FeatureExtractionSupplemental.distance_and_velocity( + x=bin_move, + fps=fps, + pixels_per_mm=1, + centimeters=True, + ) + ) + self.results_movement_velocity.loc[ + len(self.results_movement_velocity) + ] = [ + self.video_name, + shape_name, + animal_name, + body_part, + bin_cnt, + bin_move[1:].sum() / 10, + velocity, + ] else: - self.results_movement_velocity.loc[len(self.results_movement_velocity)] = [self.video_name, - shape_name, - animal_name, - body_part, - bin_cnt, - 0, - 0] + self.results_movement_velocity.loc[ + len(self.results_movement_velocity) + ] = [ + self.video_name, + shape_name, + animal_name, + body_part, + bin_cnt, + 0, + 0, + ] video_timer.stop_timer() - print(f"Video {self.video_name} complete (elapsed time {video_timer.elapsed_time_str}s)") + print( + f"Video {self.video_name} complete (elapsed time {video_timer.elapsed_time_str}s)" + ) def save(self): self.results_time.sort_values( diff --git a/simba/ui/pop_ups/append_roi_features_animals_pop_up.py b/simba/ui/pop_ups/append_roi_features_animals_pop_up.py index 9d7a0d1a0..383595c28 100644 --- a/simba/ui/pop_ups/append_roi_features_animals_pop_up.py +++ b/simba/ui/pop_ups/append_roi_features_animals_pop_up.py @@ -1,14 +1,14 @@ __author__ = "Simon Nilsson" import os +import threading from tkinter import * from typing import Union -import threading from simba.mixins.config_reader import ConfigReader from simba.mixins.pop_up_mixin import PopUpMixin from simba.roi_tools.ROI_feature_analyzer import ROIFeatureCreator -from simba.ui.tkinter_functions import (CreateLabelFrameWithIcon, DropDownMenu) +from simba.ui.tkinter_functions import CreateLabelFrameWithIcon, DropDownMenu from simba.utils.enums import Formats, Keys, Links from simba.utils.errors import NoROIDataError @@ -17,13 +17,32 @@ class AppendROIFeaturesByAnimalPopUp(ConfigReader, PopUpMixin): def __init__(self, config_path: Union[str, os.PathLike]): ConfigReader.__init__(self, config_path=config_path) if not os.path.isfile(self.roi_coordinates_path): - raise NoROIDataError(msg=f"SIMBA ERROR: No ROIs have been defined. Please define ROIs before appending ROI-based features (no data file found at path {self.roi_coordinates_path})", source=self.__class__.__name__) + raise NoROIDataError( + msg=f"SIMBA ERROR: No ROIs have been defined. Please define ROIs before appending ROI-based features (no data file found at path {self.roi_coordinates_path})", + source=self.__class__.__name__, + ) - PopUpMixin.__init__(self, title="APPEND ROI FEATURES: BY ANIMALS", size=(400, 400)) - self.animal_cnt_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="SELECT NUMBER OF ANIMALS", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.ROI_FEATURES.value) - self.animal_cnt_dropdown = DropDownMenu(self.animal_cnt_frm,"# of animals",list(range(1, self.animal_cnt + 1)),labelwidth=20) + PopUpMixin.__init__( + self, title="APPEND ROI FEATURES: BY ANIMALS", size=(400, 400) + ) + self.animal_cnt_frm = CreateLabelFrameWithIcon( + parent=self.main_frm, + header="SELECT NUMBER OF ANIMALS", + icon_name=Keys.DOCUMENTATION.value, + icon_link=Links.ROI_FEATURES.value, + ) + self.animal_cnt_dropdown = DropDownMenu( + self.animal_cnt_frm, + "# of animals", + list(range(1, self.animal_cnt + 1)), + labelwidth=20, + ) self.animal_cnt_dropdown.setChoices(1) - self.animal_cnt_confirm_btn = Button(self.animal_cnt_frm,text="Confirm",command=lambda: self.create_settings_frm()) + self.animal_cnt_confirm_btn = Button( + self.animal_cnt_frm, + text="Confirm", + command=lambda: self.create_settings_frm(), + ) self.animal_cnt_frm.grid(row=0, column=0, sticky=NW) self.animal_cnt_dropdown.grid(row=0, column=0, sticky=NW) self.animal_cnt_confirm_btn.grid(row=0, column=1, sticky=NW) @@ -45,16 +64,20 @@ def run(self): for bp_cnt, bp_dropdown in self.body_parts_dropdowns.items(): body_parts.append(bp_dropdown.getChoices()) - roi_feature_creator = ROIFeatureCreator(config_path=self.config_path, body_parts=body_parts, data_path=None, append_data=True) + roi_feature_creator = ROIFeatureCreator( + config_path=self.config_path, + body_parts=body_parts, + data_path=None, + append_data=True, + ) threading.Thread(target=roi_feature_creator.run()).start() self.root.destroy() -#AppendROIFeaturesByAnimalPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/open_field_below/project_folder/project_config.ini') - +# AppendROIFeaturesByAnimalPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/open_field_below/project_folder/project_config.ini') -#AppendROIFeaturesByAnimalPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/locomotion/project_folder/project_config.ini') +# AppendROIFeaturesByAnimalPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/locomotion/project_folder/project_config.ini') # AppendROIFeaturesByAnimalPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini') diff --git a/simba/ui/pop_ups/append_roi_features_bodypart_pop_up.py b/simba/ui/pop_ups/append_roi_features_bodypart_pop_up.py index e420edbbe..a0b2ac9c1 100644 --- a/simba/ui/pop_ups/append_roi_features_bodypart_pop_up.py +++ b/simba/ui/pop_ups/append_roi_features_bodypart_pop_up.py @@ -1,7 +1,7 @@ __author__ = "Simon Nilsson" import os -from typing import Union import threading +from typing import Union from simba.mixins.config_reader import ConfigReader from simba.mixins.pop_up_mixin import PopUpMixin @@ -13,9 +13,16 @@ class AppendROIFeaturesByBodyPartPopUp(PopUpMixin, ConfigReader): def __init__(self, config_path: Union[str, os.PathLike]): ConfigReader.__init__(self, config_path=config_path) if not os.path.isfile(self.roi_coordinates_path): - raise NoROIDataError(msg="SIMBA ERROR: No ROIs have been defined. Please define ROIs before appending ROI-based features", source=self.__class__.__name__) - PopUpMixin.__init__(self, config_path=config_path, title="APPEND ROI FEATURES: BY BODY-PARTS") - self.create_choose_number_of_body_parts_frm(project_body_parts=self.project_bps, run_function=self.run) + raise NoROIDataError( + msg="SIMBA ERROR: No ROIs have been defined. Please define ROIs before appending ROI-based features", + source=self.__class__.__name__, + ) + PopUpMixin.__init__( + self, config_path=config_path, title="APPEND ROI FEATURES: BY BODY-PARTS" + ) + self.create_choose_number_of_body_parts_frm( + project_body_parts=self.project_bps, run_function=self.run + ) self.main_frm.mainloop() def run(self): @@ -23,10 +30,16 @@ def run(self): for bp_cnt, bp_dropdown in self.body_parts_dropdowns.items(): body_parts.append(bp_dropdown.getChoices()) - roi_feature_creator = ROIFeatureCreator(config_path=self.config_path, body_parts=body_parts, data_path=None, append_data=True) + roi_feature_creator = ROIFeatureCreator( + config_path=self.config_path, + body_parts=body_parts, + data_path=None, + append_data=True, + ) threading.Thread(target=roi_feature_creator.run()).start() self.root.destroy() + # _ = AppendROIFeaturesByBodyPartPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/open_field_below/project_folder/project_config.ini') diff --git a/simba/ui/pop_ups/directing_other_animals_plot_pop_up.py b/simba/ui/pop_ups/directing_other_animals_plot_pop_up.py index 95d6fc312..075370f92 100644 --- a/simba/ui/pop_ups/directing_other_animals_plot_pop_up.py +++ b/simba/ui/pop_ups/directing_other_animals_plot_pop_up.py @@ -1,9 +1,8 @@ __author__ = "Simon Nilsson" import os -from tkinter import * import threading - +from tkinter import * from typing import Union from simba.mixins.config_reader import ConfigReader @@ -15,57 +14,130 @@ from simba.ui.tkinter_functions import CreateLabelFrameWithIcon, DropDownMenu from simba.utils.enums import Formats, Keys, Links from simba.utils.errors import AnimalNumberError -from simba.utils.read_write import find_all_videos_in_directory from simba.utils.lookups import get_color_dict +from simba.utils.read_write import find_all_videos_in_directory -DIRECTION_THICKNESS = 'direction_thickness' -DIRECTIONALITY_COLOR = 'directionality_color' -CIRCLE_SIZE = 'circle_size' -HIGHLIGHT_ENDPOINTS = 'highlight_endpoints' -SHOW_POSE = 'show_pose' -ANIMAL_NAMES = 'animal_names' +DIRECTION_THICKNESS = "direction_thickness" +DIRECTIONALITY_COLOR = "directionality_color" +CIRCLE_SIZE = "circle_size" +HIGHLIGHT_ENDPOINTS = "highlight_endpoints" +SHOW_POSE = "show_pose" +ANIMAL_NAMES = "animal_names" class DirectingOtherAnimalsVisualizerPopUp(PopUpMixin, ConfigReader): - def __init__(self, - config_path: Union[str, os.PathLike]): + def __init__(self, config_path: Union[str, os.PathLike]): ConfigReader.__init__(self, config_path=config_path) if self.animal_cnt == 1: - raise AnimalNumberError(msg="Cannot visualize directionality between animals in a 1 animal project.", source=self.__class__.__name__,) + raise AnimalNumberError( + msg="Cannot visualize directionality between animals in a 1 animal project.", + source=self.__class__.__name__, + ) PopUpMixin.__init__(self, title="CREATE ANIMAL DIRECTION VIDEOS") self.color_dict = get_color_dict() self.color_lst = list(self.color_dict.keys()) self.color_lst.insert(0, "random") self.size_lst = list(range(1, 11)) - self.files_found_dict = find_all_videos_in_directory(directory=self.video_dir, as_dict=True) + self.files_found_dict = find_all_videos_in_directory( + directory=self.video_dir, as_dict=True + ) self.show_pose_var = BooleanVar(value=True) self.show_animal_names_var = BooleanVar(value=True) self.highlight_direction_endpoints_var = BooleanVar(value=True) self.multiprocess_var = BooleanVar(value=False) - self.style_settings_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="STYLE SETTINGS", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.DIRECTING_ANIMALS_PLOTS.value) - self.show_pose_cb = Checkbutton(self.style_settings_frm, text="Show pose-estimated body-parts", variable=self.show_pose_var,) - self.highlight_direction_endpoints_cb = Checkbutton(self.style_settings_frm, text="Highlight direction end-points", variable=self.highlight_direction_endpoints_var) - self.show_animal_names_cb = Checkbutton(self.style_settings_frm, text="Show animal names", variable=self.show_animal_names_var) - - self.direction_clr_dropdown = DropDownMenu(self.style_settings_frm, "Direction color:", self.color_lst, "16") - self.pose_size_dropdown = DropDownMenu(self.style_settings_frm, "Pose circle size:", self.size_lst, "16") - self.line_thickness = DropDownMenu(self.style_settings_frm, "Line thickness:", self.size_lst, "16") + self.style_settings_frm = CreateLabelFrameWithIcon( + parent=self.main_frm, + header="STYLE SETTINGS", + icon_name=Keys.DOCUMENTATION.value, + icon_link=Links.DIRECTING_ANIMALS_PLOTS.value, + ) + self.show_pose_cb = Checkbutton( + self.style_settings_frm, + text="Show pose-estimated body-parts", + variable=self.show_pose_var, + ) + self.highlight_direction_endpoints_cb = Checkbutton( + self.style_settings_frm, + text="Highlight direction end-points", + variable=self.highlight_direction_endpoints_var, + ) + self.show_animal_names_cb = Checkbutton( + self.style_settings_frm, + text="Show animal names", + variable=self.show_animal_names_var, + ) + + self.direction_clr_dropdown = DropDownMenu( + self.style_settings_frm, "Direction color:", self.color_lst, "16" + ) + self.pose_size_dropdown = DropDownMenu( + self.style_settings_frm, "Pose circle size:", self.size_lst, "16" + ) + self.line_thickness = DropDownMenu( + self.style_settings_frm, "Line thickness:", self.size_lst, "16" + ) self.line_thickness.setChoices(choice=4) self.pose_size_dropdown.setChoices(choice=3) self.direction_clr_dropdown.setChoices(choice="random") - multiprocess_cb = Checkbutton(self.style_settings_frm, text="Multi-process (faster)",variable=self.multiprocess_var, command=lambda: self.enable_dropdown_from_checkbox(check_box_var=self.multiprocess_var, dropdown_menus=[self.multiprocess_dropdown])) - self.multiprocess_dropdown = DropDownMenu(self.style_settings_frm, "CPU cores:", list(range(2, self.cpu_cnt)), "12") + multiprocess_cb = Checkbutton( + self.style_settings_frm, + text="Multi-process (faster)", + variable=self.multiprocess_var, + command=lambda: self.enable_dropdown_from_checkbox( + check_box_var=self.multiprocess_var, + dropdown_menus=[self.multiprocess_dropdown], + ), + ) + self.multiprocess_dropdown = DropDownMenu( + self.style_settings_frm, "CPU cores:", list(range(2, self.cpu_cnt)), "12" + ) self.multiprocess_dropdown.setChoices(2) self.multiprocess_dropdown.disable() - self.run_frm = LabelFrame(self.main_frm,text="RUN",font=Formats.LABELFRAME_HEADER_FORMAT.value,pady=5,padx=5,fg="black") - self.run_single_video_frm = LabelFrame(self.run_frm, text="SINGLE VIDEO", font=Formats.LABELFRAME_HEADER_FORMAT.value, pady=5, padx=5, fg="black") - self.run_single_video_btn = Button(self.run_single_video_frm, text="Create single video", fg="blue", command=lambda: self.__create_directionality_plots(multiple_videos=False)) - self.single_video_dropdown = DropDownMenu(self.run_single_video_frm, "Video:", list(self.files_found_dict.keys()),"12") + self.run_frm = LabelFrame( + self.main_frm, + text="RUN", + font=Formats.LABELFRAME_HEADER_FORMAT.value, + pady=5, + padx=5, + fg="black", + ) + self.run_single_video_frm = LabelFrame( + self.run_frm, + text="SINGLE VIDEO", + font=Formats.LABELFRAME_HEADER_FORMAT.value, + pady=5, + padx=5, + fg="black", + ) + self.run_single_video_btn = Button( + self.run_single_video_frm, + text="Create single video", + fg="blue", + command=lambda: self.__create_directionality_plots(multiple_videos=False), + ) + self.single_video_dropdown = DropDownMenu( + self.run_single_video_frm, + "Video:", + list(self.files_found_dict.keys()), + "12", + ) self.single_video_dropdown.setChoices(list(self.files_found_dict.keys())[0]) - self.run_multiple_videos = LabelFrame(self.run_frm, text="MULTIPLE VIDEO", font=Formats.LABELFRAME_HEADER_FORMAT.value, pady=5, padx=5, fg="black") - self.run_multiple_video_btn = Button(self.run_multiple_videos, text=f"Create multiple videos ({len(list(self.files_found_dict.keys()))} video(s) found)", fg="blue", command=lambda: self.__create_directionality_plots(multiple_videos=True)) + self.run_multiple_videos = LabelFrame( + self.run_frm, + text="MULTIPLE VIDEO", + font=Formats.LABELFRAME_HEADER_FORMAT.value, + pady=5, + padx=5, + fg="black", + ) + self.run_multiple_video_btn = Button( + self.run_multiple_videos, + text=f"Create multiple videos ({len(list(self.files_found_dict.keys()))} video(s) found)", + fg="blue", + command=lambda: self.__create_directionality_plots(multiple_videos=True), + ) self.style_settings_frm.grid(row=0, column=0, sticky=NW) self.show_pose_cb.grid(row=0, column=0, sticky=NW) @@ -90,34 +162,40 @@ def __create_directionality_plots(self, multiple_videos: bool): if multiple_videos: video_paths = list(self.files_found_dict.values()) else: - video_paths = [self.files_found_dict[self.single_video_dropdown.getChoices()]] - + video_paths = [ + self.files_found_dict[self.single_video_dropdown.getChoices()] + ] direction_clr = self.direction_clr_dropdown.getChoices() - if direction_clr != 'random': + if direction_clr != "random": direction_clr = self.color_dict[direction_clr] - - style_attr = {SHOW_POSE: self.show_pose_var.get(), - CIRCLE_SIZE: int(self.pose_size_dropdown.getChoices()), - DIRECTIONALITY_COLOR: direction_clr, - DIRECTION_THICKNESS: int(self.line_thickness.getChoices()), - HIGHLIGHT_ENDPOINTS: self.highlight_direction_endpoints_var.get(), - ANIMAL_NAMES: self.show_animal_names_var.get()} + style_attr = { + SHOW_POSE: self.show_pose_var.get(), + CIRCLE_SIZE: int(self.pose_size_dropdown.getChoices()), + DIRECTIONALITY_COLOR: direction_clr, + DIRECTION_THICKNESS: int(self.line_thickness.getChoices()), + HIGHLIGHT_ENDPOINTS: self.highlight_direction_endpoints_var.get(), + ANIMAL_NAMES: self.show_animal_names_var.get(), + } for video_path in video_paths: if not self.multiprocess_var.get(): - visualizer = DirectingOtherAnimalsVisualizer(config_path=self.config_path, - video_path=video_path, - style_attr=style_attr) + visualizer = DirectingOtherAnimalsVisualizer( + config_path=self.config_path, + video_path=video_path, + style_attr=style_attr, + ) else: - visualizer = DirectingOtherAnimalsVisualizerMultiprocess(config_path=self.config_path, - video_path=video_path, - style_attr=style_attr, - core_cnt=int(self.multiprocess_dropdown.getChoices())) + visualizer = DirectingOtherAnimalsVisualizerMultiprocess( + config_path=self.config_path, + video_path=video_path, + style_attr=style_attr, + core_cnt=int(self.multiprocess_dropdown.getChoices()), + ) threading.Thread(target=visualizer.run()).start() -#_ = DirectingOtherAnimalsVisualizerPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini') -#_ = DirectingOtherAnimalsVisualizerPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/Two_animals_16bps/project_folder/project_config.ini') +# _ = DirectingOtherAnimalsVisualizerPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini') +# _ = DirectingOtherAnimalsVisualizerPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/Two_animals_16bps/project_folder/project_config.ini') diff --git a/simba/ui/pop_ups/roi_analysis_pop_up.py b/simba/ui/pop_ups/roi_analysis_pop_up.py index 076a14357..d26cd55fb 100644 --- a/simba/ui/pop_ups/roi_analysis_pop_up.py +++ b/simba/ui/pop_ups/roi_analysis_pop_up.py @@ -83,23 +83,30 @@ def create_settings_frm(self): def run(self): settings = {} - check_float(name="Probability threshold", value=self.probability_entry.entry_get, min_value=0.00, max_value=1.00) + check_float( + name="Probability threshold", + value=self.probability_entry.entry_get, + min_value=0.00, + max_value=1.00, + ) settings["threshold"] = float(self.probability_entry.entry_get) body_parts = [] for cnt, dropdown in self.body_parts_dropdowns.items(): body_parts.append(dropdown.getChoices()) - roi_analyzer = ROIAnalyzer(config_path=self.config_path, - data_path=None, - calculate_distances=self.calculate_distance_moved_var.get(), - detailed_bout_data=self.detailed_roi_var.get(), - threshold=float(self.probability_entry.entry_get), - body_parts=body_parts) + roi_analyzer = ROIAnalyzer( + config_path=self.config_path, + data_path=None, + calculate_distances=self.calculate_distance_moved_var.get(), + detailed_bout_data=self.detailed_roi_var.get(), + threshold=float(self.probability_entry.entry_get), + body_parts=body_parts, + ) roi_analyzer.run() roi_analyzer.save() self.root.destroy() -#_ = ROIAnalysisPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini') +# _ = ROIAnalysisPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini') # ROIAnalysisPopUp(config_path='/Users/simon/Desktop/envs/simba_dev/tests/data/test_projects/mouse_open_field/project_folder/project_config.ini') diff --git a/simba/ui/pop_ups/roi_features_plot_pop_up.py b/simba/ui/pop_ups/roi_features_plot_pop_up.py index 4fd60fe04..276426980 100644 --- a/simba/ui/pop_ups/roi_features_plot_pop_up.py +++ b/simba/ui/pop_ups/roi_features_plot_pop_up.py @@ -1,8 +1,8 @@ __author__ = "Simon Nilsson" -from tkinter import * import threading +from tkinter import * from simba.mixins.config_reader import ConfigReader from simba.mixins.pop_up_mixin import PopUpMixin @@ -14,27 +14,33 @@ from simba.utils.checks import check_float from simba.utils.enums import ConfigKey, Formats, Keys, Links from simba.utils.errors import NoFilesFoundError -from simba.utils.read_write import (find_files_of_filetypes_in_directory, get_fn_ext, find_all_videos_in_directory) - -ROI_CENTERS = 'roi_centers' -ROI_EAR_TAGS = 'roi_ear_tags' -DIRECTIONALITY = 'directionality' -DIRECTIONALITY_STYLE = 'directionality_style' -BORDER_COLOR = 'border_color' -POSE = 'pose_estimation' -ANIMAL_NAMES = 'animal_names' +from simba.utils.read_write import (find_all_videos_in_directory, + find_files_of_filetypes_in_directory, + get_fn_ext) +ROI_CENTERS = "roi_centers" +ROI_EAR_TAGS = "roi_ear_tags" +DIRECTIONALITY = "directionality" +DIRECTIONALITY_STYLE = "directionality_style" +BORDER_COLOR = "border_color" +POSE = "pose_estimation" +ANIMAL_NAMES = "animal_names" class VisualizeROIFeaturesPopUp(PopUpMixin, ConfigReader): def __init__(self, config_path: str): PopUpMixin.__init__(self, title="VISUALIZE ROI FEATURES", size=(400, 500)) ConfigReader.__init__(self, config_path=config_path) - self.video_file_paths = find_all_videos_in_directory(directory=self.video_dir, as_dict=True) + self.video_file_paths = find_all_videos_in_directory( + directory=self.video_dir, as_dict=True + ) self.video_list = [k for k in self.video_file_paths.keys()] if len(self.video_list) == 0: - raise NoFilesFoundError(msg="SIMBA ERROR: No videos in SimBA project. Import videos into you SimBA project to visualize ROI features.", source=self.__class__.__name__) + raise NoFilesFoundError( + msg="SIMBA ERROR: No videos in SimBA project. Import videos into you SimBA project to visualize ROI features.", + source=self.__class__.__name__, + ) self.settings_frm = CreateLabelFrameWithIcon( parent=self.main_frm, @@ -74,8 +80,11 @@ def __init__(self, config_path: str): self.settings_frm, text="Show ROI ear tags", variable=self.show_ROI_tags_var ) - show_animal_names_cb = Checkbutton(self.settings_frm, text="Show animal names", variable=self.show_animal_names_var) - + show_animal_names_cb = Checkbutton( + self.settings_frm, + text="Show animal names", + variable=self.show_animal_names_var, + ) show_roi_directionality_cb = Checkbutton( self.settings_frm, @@ -96,32 +105,67 @@ def __init__(self, config_path: str): dropdown_menus=[self.multiprocess_dropdown], ), ) - self.multiprocess_dropdown = DropDownMenu(self.settings_frm, "CPU cores:", list(range(2, self.cpu_cnt)), "12") + self.multiprocess_dropdown = DropDownMenu( + self.settings_frm, "CPU cores:", list(range(2, self.cpu_cnt)), "12" + ) self.multiprocess_dropdown.setChoices(2) self.multiprocess_dropdown.disable() - self.directionality_type_dropdown = DropDownMenu(self.settings_frm, "Direction type:", ["funnel", "Lines"], "12") + self.directionality_type_dropdown = DropDownMenu( + self.settings_frm, "Direction type:", ["funnel", "Lines"], "12" + ) self.directionality_type_dropdown.setChoices(choice="funnel") self.directionality_type_dropdown.disable() - self.body_parts_frm = LabelFrame(self.main_frm, text="SELECT BODY-PARTS", pady=10, font=Formats.LABELFRAME_HEADER_FORMAT.value, fg="black") - self.animal_cnt_dropdown = DropDownMenu(self.body_parts_frm, "Number of animals", list(range(1, self.animal_cnt + 1)), "20", com=lambda x: self.__populate_bp_dropdown(bp_cnt=x)) + self.body_parts_frm = LabelFrame( + self.main_frm, + text="SELECT BODY-PARTS", + pady=10, + font=Formats.LABELFRAME_HEADER_FORMAT.value, + fg="black", + ) + self.animal_cnt_dropdown = DropDownMenu( + self.body_parts_frm, + "Number of animals", + list(range(1, self.animal_cnt + 1)), + "20", + com=lambda x: self.__populate_bp_dropdown(bp_cnt=x), + ) self.animal_cnt_dropdown.setChoices(1) self.__populate_bp_dropdown(bp_cnt=1) self.animal_cnt_dropdown.grid(row=0, column=0, sticky=NW) - self.single_video_frm = LabelFrame(self.main_frm, text="Visualize ROI features on SINGLE video", pady=10, padx=10, font=Formats.LABELFRAME_HEADER_FORMAT.value, fg="black") - self.single_video_dropdown = DropDownMenu(self.single_video_frm, "Select video", self.video_list, "15") + self.single_video_frm = LabelFrame( + self.main_frm, + text="Visualize ROI features on SINGLE video", + pady=10, + padx=10, + font=Formats.LABELFRAME_HEADER_FORMAT.value, + fg="black", + ) + self.single_video_dropdown = DropDownMenu( + self.single_video_frm, "Select video", self.video_list, "15" + ) self.single_video_dropdown.setChoices(self.video_list[0]) - self.single_video_btn = Button(self.single_video_frm, text="Visualize ROI features for SINGLE video", command=lambda: self.run(multiple=False),) + self.single_video_btn = Button( + self.single_video_frm, + text="Visualize ROI features for SINGLE video", + command=lambda: self.run(multiple=False), + ) - self.all_videos_frm = LabelFrame(self.main_frm, - text="Visualize ROI features on ALL videos", - pady=10, - padx=10, - font=Formats.LABELFRAME_HEADER_FORMAT.value, - fg="black") + self.all_videos_frm = LabelFrame( + self.main_frm, + text="Visualize ROI features on ALL videos", + pady=10, + padx=10, + font=Formats.LABELFRAME_HEADER_FORMAT.value, + fg="black", + ) - self.all_videos_btn = Button(self.all_videos_frm, text="Generate ROI visualization on ALL videos", command=lambda: self.run(multiple=True)) + self.all_videos_btn = Button( + self.all_videos_frm, + text="Generate ROI visualization on ALL videos", + command=lambda: self.run(multiple=True), + ) self.settings_frm.grid(row=0, column=0, sticky=NW) self.threshold_entry_box.grid(row=0, sticky=NW) threshold_label.grid(row=1, sticky=NW) @@ -161,39 +205,52 @@ def __populate_bp_dropdown(self, bp_cnt: int): self.bp_dropdown_dict[cnt].grid(row=cnt + 1, column=0, sticky=NW) def run(self, multiple: bool): - check_float(name="Body-part probability threshold", value=self.threshold_entry_box.entry_get, min_value=0.0, max_value=1.0) - style_attr = {ROI_CENTERS: self.show_ROI_centers_var.get(), - ROI_EAR_TAGS: self.show_ROI_tags_var.get(), - POSE: self.show_pose_var.get(), - ANIMAL_NAMES: self.show_animal_names_var.get(), - DIRECTIONALITY: self.show_direction_var.get(), - BORDER_COLOR: self.colors_dict[self.border_clr_dropdown.getChoices()], - DIRECTIONALITY_STYLE: self.directionality_type_dropdown.getChoices()} + check_float( + name="Body-part probability threshold", + value=self.threshold_entry_box.entry_get, + min_value=0.0, + max_value=1.0, + ) + style_attr = { + ROI_CENTERS: self.show_ROI_centers_var.get(), + ROI_EAR_TAGS: self.show_ROI_tags_var.get(), + POSE: self.show_pose_var.get(), + ANIMAL_NAMES: self.show_animal_names_var.get(), + DIRECTIONALITY: self.show_direction_var.get(), + BORDER_COLOR: self.colors_dict[self.border_clr_dropdown.getChoices()], + DIRECTIONALITY_STYLE: self.directionality_type_dropdown.getChoices(), + } if multiple: video_paths = [v for k, v in self.video_file_paths.items()] else: - video_paths = [self.video_file_paths[self.single_video_dropdown.getChoices()]] + video_paths = [ + self.video_file_paths[self.single_video_dropdown.getChoices()] + ] body_parts = [v.getChoices() for v in self.bp_dropdown_dict.values()] for video_path in video_paths: if not self.multiprocess_var.get(): - roi_feature_visualizer = ROIfeatureVisualizer(config_path=self.config_path, - video_path=video_path, - body_parts=body_parts, - style_attr=style_attr) + roi_feature_visualizer = ROIfeatureVisualizer( + config_path=self.config_path, + video_path=video_path, + body_parts=body_parts, + style_attr=style_attr, + ) else: core_cnt = int(self.multiprocess_dropdown.getChoices()) - roi_feature_visualizer = ROIfeatureVisualizerMultiprocess(config_path=self.config_path, - video_path=video_path, - body_parts=body_parts, - style_attr=style_attr, - core_cnt=core_cnt) + roi_feature_visualizer = ROIfeatureVisualizerMultiprocess( + config_path=self.config_path, + video_path=video_path, + body_parts=body_parts, + style_attr=style_attr, + core_cnt=core_cnt, + ) threading.Thread(target=roi_feature_visualizer.run()).start() -#_ = VisualizeROIFeaturesPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini') +# _ = VisualizeROIFeaturesPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini') -#_ = VisualizeROIFeaturesPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/Two_animals_16bps/project_folder/project_config.ini') +# _ = VisualizeROIFeaturesPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/Two_animals_16bps/project_folder/project_config.ini') # _ = VisualizeROIFeaturesPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/Termites_5/project_folder/project_config.ini') # ROIAnalysisPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/Termites_5/project_folder/project_config.ini') diff --git a/simba/ui/pop_ups/roi_tracking_plot_pop_up.py b/simba/ui/pop_ups/roi_tracking_plot_pop_up.py index 85bece879..c5e899386 100644 --- a/simba/ui/pop_ups/roi_tracking_plot_pop_up.py +++ b/simba/ui/pop_ups/roi_tracking_plot_pop_up.py @@ -1,18 +1,19 @@ __author__ = "Simon Nilsson" import os +import threading from tkinter import * from typing import Union -import threading from simba.mixins.config_reader import ConfigReader from simba.mixins.pop_up_mixin import PopUpMixin from simba.plotting.ROI_plotter import ROIPlot from simba.plotting.ROI_plotter_mp import ROIPlotMultiprocess -from simba.ui.tkinter_functions import (CreateLabelFrameWithIcon, DropDownMenu, Entry_Box, FileSelect) -from simba.utils.checks import check_float, check_file_exist_and_readable +from simba.ui.tkinter_functions import (CreateLabelFrameWithIcon, DropDownMenu, + Entry_Box, FileSelect) +from simba.utils.checks import check_file_exist_and_readable, check_float from simba.utils.enums import Formats, Keys, Links, Options -from simba.utils.read_write import (find_all_videos_in_directory) +from simba.utils.read_write import find_all_videos_in_directory class VisualizeROITrackingPopUp(PopUpMixin, ConfigReader): @@ -20,35 +21,118 @@ def __init__(self, config_path: Union[str, os.PathLike]): check_file_exist_and_readable(file_path=config_path) ConfigReader.__init__(self, config_path=config_path, read_video_info=False) self.read_roi_data() - self.video_file_paths = find_all_videos_in_directory(directory=self.video_dir, as_dict=True, raise_error=True) + self.video_file_paths = find_all_videos_in_directory( + directory=self.video_dir, as_dict=True, raise_error=True + ) self.video_names = list(self.video_file_paths.keys()) PopUpMixin.__init__(self, title="VISUALIZE ROI TRACKING", size=(800, 500)) self.multiprocess_var = BooleanVar() self.show_pose_var = BooleanVar(value=True) self.animal_name_var = BooleanVar(value=True) - self.settings_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="SETTINGS", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.ROI_DATA_PLOT.value) - self.threshold_entry_box = Entry_Box(self.settings_frm, "Body-part probability threshold", "30") + self.settings_frm = CreateLabelFrameWithIcon( + parent=self.main_frm, + header="SETTINGS", + icon_name=Keys.DOCUMENTATION.value, + icon_link=Links.ROI_DATA_PLOT.value, + ) + self.threshold_entry_box = Entry_Box( + self.settings_frm, "Body-part probability threshold", "30" + ) self.threshold_entry_box.entry_set(0.0) - threshold_label = Label(self.settings_frm, text="Note: body-part locations detected with probabilities \n below this threshold is removed from the visualization(s).", font=("Helvetica", 10, "italic")) - self.show_pose_cb = Checkbutton(self.settings_frm, text="Show pose-estimated location", variable=self.show_pose_var) - self.show_animal_name_cb = Checkbutton(self.settings_frm, text="Show animal names", variable=self.animal_name_var) - self.multiprocess_cb = Checkbutton(self.settings_frm, text="Multi-process (faster)", variable=self.multiprocess_var, command=lambda: self.enable_dropdown_from_checkbox(check_box_var=self.multiprocess_var, dropdown_menus=[self.multiprocess_dropdown])) - self.multiprocess_dropdown = DropDownMenu(self.settings_frm, "CPU cores:", list(range(2, self.cpu_cnt)), "12") + threshold_label = Label( + self.settings_frm, + text="Note: body-part locations detected with probabilities \n below this threshold is removed from the visualization(s).", + font=("Helvetica", 10, "italic"), + ) + self.show_pose_cb = Checkbutton( + self.settings_frm, + text="Show pose-estimated location", + variable=self.show_pose_var, + ) + self.show_animal_name_cb = Checkbutton( + self.settings_frm, text="Show animal names", variable=self.animal_name_var + ) + self.multiprocess_cb = Checkbutton( + self.settings_frm, + text="Multi-process (faster)", + variable=self.multiprocess_var, + command=lambda: self.enable_dropdown_from_checkbox( + check_box_var=self.multiprocess_var, + dropdown_menus=[self.multiprocess_dropdown], + ), + ) + self.multiprocess_dropdown = DropDownMenu( + self.settings_frm, "CPU cores:", list(range(2, self.cpu_cnt)), "12" + ) self.multiprocess_dropdown.setChoices(2) self.multiprocess_dropdown.disable() - self.body_parts_frm = LabelFrame(self.main_frm, text="SELECT BODY-PARTS", pady=10, font=Formats.LABELFRAME_HEADER_FORMAT.value, fg="black") - self.animal_cnt_dropdown = DropDownMenu(self.body_parts_frm, "NUMBER OF ANIMALS", list(range(1, self.animal_cnt + 1)), "20", com=lambda x: self.__populate_bp_dropdown(bp_cnt=x)) + self.body_parts_frm = LabelFrame( + self.main_frm, + text="SELECT BODY-PARTS", + pady=10, + font=Formats.LABELFRAME_HEADER_FORMAT.value, + fg="black", + ) + self.animal_cnt_dropdown = DropDownMenu( + self.body_parts_frm, + "NUMBER OF ANIMALS", + list(range(1, self.animal_cnt + 1)), + "20", + com=lambda x: self.__populate_bp_dropdown(bp_cnt=x), + ) self.animal_cnt_dropdown.setChoices(1) self.__populate_bp_dropdown(bp_cnt=1) - self.run_frm = LabelFrame(self.main_frm, text="RUN VISUALIZATION", pady=10, font=Formats.LABELFRAME_HEADER_FORMAT.value, fg="black") - self.single_video_frm = LabelFrame(self.run_frm, text="SINGLE video", pady=10, font=Formats.LABELFRAME_HEADER_FORMAT.value, fg="black") - self.single_video_dropdown = DropDownMenu(self.single_video_frm, "Select video", self.video_names, "15", com=lambda x: self.update_file_select_box_from_dropdown(filename=x, fileselectbox=self.select_video_file_select)) + self.run_frm = LabelFrame( + self.main_frm, + text="RUN VISUALIZATION", + pady=10, + font=Formats.LABELFRAME_HEADER_FORMAT.value, + fg="black", + ) + self.single_video_frm = LabelFrame( + self.run_frm, + text="SINGLE video", + pady=10, + font=Formats.LABELFRAME_HEADER_FORMAT.value, + fg="black", + ) + self.single_video_dropdown = DropDownMenu( + self.single_video_frm, + "Select video", + self.video_names, + "15", + com=lambda x: self.update_file_select_box_from_dropdown( + filename=x, fileselectbox=self.select_video_file_select + ), + ) self.single_video_dropdown.setChoices(self.video_names[0]) - self.select_video_file_select = FileSelect(self.single_video_frm, "", lblwidth="1", file_types=[("VIDEO FILE", Options.ALL_VIDEO_FORMAT_STR_OPTIONS.value)], dropdown=self.single_video_dropdown) + self.select_video_file_select = FileSelect( + self.single_video_frm, + "", + lblwidth="1", + file_types=[("VIDEO FILE", Options.ALL_VIDEO_FORMAT_STR_OPTIONS.value)], + dropdown=self.single_video_dropdown, + ) self.select_video_file_select.filePath.set(self.video_names[0]) - self.single_video_btn = Button(self.single_video_frm, text="Create SINGLE ROI video", fg="blue", command=lambda: self.run(multiple=False)) - self.all_videos_frm = LabelFrame(self.run_frm,text="ALL videos",pady=10,font=Formats.LABELFRAME_HEADER_FORMAT.value,fg="black") - self.all_videos_btn = Button(self.all_videos_frm, text=f"Create ALL ROI videos ({len(self.video_names)} videos found)",fg="red", command=lambda: self.run(multiple=True)) + self.single_video_btn = Button( + self.single_video_frm, + text="Create SINGLE ROI video", + fg="blue", + command=lambda: self.run(multiple=False), + ) + self.all_videos_frm = LabelFrame( + self.run_frm, + text="ALL videos", + pady=10, + font=Formats.LABELFRAME_HEADER_FORMAT.value, + fg="black", + ) + self.all_videos_btn = Button( + self.all_videos_frm, + text=f"Create ALL ROI videos ({len(self.video_names)} videos found)", + fg="red", + command=lambda: self.run(multiple=True), + ) self.settings_frm.grid(row=0, column=0, sticky=NW) self.threshold_entry_box.grid(row=0, column=0, sticky=NW) threshold_label.grid(row=1, column=0, sticky=NW) @@ -76,7 +160,12 @@ def __populate_bp_dropdown(self, bp_cnt: int): self.bp_dropdown_dict = {} for cnt in range(int(self.animal_cnt_dropdown.getChoices())): - self.bp_dropdown_dict[cnt] = DropDownMenu(self.body_parts_frm, self.multi_animal_id_list[cnt], self.body_parts_lst, "20") + self.bp_dropdown_dict[cnt] = DropDownMenu( + self.body_parts_frm, + self.multi_animal_id_list[cnt], + self.body_parts_lst, + "20", + ) self.bp_dropdown_dict[cnt].setChoices(self.body_parts_lst[cnt]) self.bp_dropdown_dict[cnt].grid(row=cnt + 1, column=0, sticky=NW) @@ -88,27 +177,39 @@ def run(self, multiple: bool): for video_path in videos: self.check_if_selected_video_path_exist_in_project(video_path=video_path) - check_float(name="Body-part probability threshold", value=self.threshold_entry_box.entry_get, min_value=0.0, max_value=1.0) - style_attr = {'show_body_part': self.show_pose_var.get(), - 'show_animal_name': self.animal_name_var.get()} + check_float( + name="Body-part probability threshold", + value=self.threshold_entry_box.entry_get, + min_value=0.0, + max_value=1.0, + ) + style_attr = { + "show_body_part": self.show_pose_var.get(), + "show_animal_name": self.animal_name_var.get(), + } body_parts = [v.getChoices() for k, v in self.bp_dropdown_dict.items()] for video_path in videos: if not self.multiprocess_var.get(): - roi_plotter = ROIPlot(config_path=self.config_path, - video_path=video_path, - style_attr=style_attr, - threshold=float(self.threshold_entry_box.entry_get), - body_parts=body_parts) + roi_plotter = ROIPlot( + config_path=self.config_path, + video_path=video_path, + style_attr=style_attr, + threshold=float(self.threshold_entry_box.entry_get), + body_parts=body_parts, + ) else: core_cnt = self.multiprocess_dropdown.getChoices() - roi_plotter = ROIPlotMultiprocess(config_path=self.config_path, - video_path=video_path, - core_cnt=int(core_cnt), - style_attr=style_attr, - threshold=float(self.threshold_entry_box.entry_get), - body_parts=body_parts) + roi_plotter = ROIPlotMultiprocess( + config_path=self.config_path, + video_path=video_path, + core_cnt=int(core_cnt), + style_attr=style_attr, + threshold=float(self.threshold_entry_box.entry_get), + body_parts=body_parts, + ) threading.Thread(target=roi_plotter.run()).start() + # _ = VisualizeROITrackingPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini') # _ = VisualizeROITrackingPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/Termites_5/project_folder/project_config.ini') # _ = VisualizeROITrackingPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini') diff --git a/simba/utils/checks.py b/simba/utils/checks.py index 6acd7ce67..f35ed4fe7 100644 --- a/simba/utils/checks.py +++ b/simba/utils/checks.py @@ -3,12 +3,12 @@ import ast import glob import os -import cv2 import re import subprocess from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +import cv2 import numpy as np import pandas as pd import trafaret as t @@ -17,12 +17,13 @@ from simba.utils.errors import (ArrayError, ColumnNotFoundError, CorruptedFileError, CountError, DirectoryNotEmptyError, FFMPEGNotFoundError, - FloatError, IntegerError, InvalidFilepathError, - InvalidInputError, NoDataError, - NoFilesFoundError, NoROIDataError, + FloatError, FrameRangeError, IntegerError, + InvalidFilepathError, InvalidInputError, + NoDataError, NoFilesFoundError, NoROIDataError, NotDirectoryError, ParametersFileError, - StringError, FrameRangeError) -from simba.utils.warnings import NoDataFoundWarning, FrameRangeWarning + StringError) +from simba.utils.warnings import FrameRangeWarning, NoDataFoundWarning + def check_file_exist_and_readable(file_path: Union[str, os.PathLike]) -> None: """ @@ -1285,10 +1286,13 @@ def check_valid_tuple( source=source, ) -def check_video_and_data_frm_count_align(video: Union[str, os.PathLike, cv2.VideoCapture], - data: Union[str, os.PathLike, pd.DataFrame], - name: Optional[str] = '', - raise_error: Optional[bool] = True) -> None: + +def check_video_and_data_frm_count_align( + video: Union[str, os.PathLike, cv2.VideoCapture], + data: Union[str, os.PathLike, pd.DataFrame], + name: Optional[str] = "", + raise_error: Optional[bool] = True, +) -> None: """ Check if the frame count of a video matches the row count of a data file. @@ -1310,22 +1314,40 @@ def _count_generator(reader): yield b b = reader(1024 * 1024) - check_instance(source=f'{check_video_and_data_frm_count_align.__name__} video', instance=video, accepted_types=(str, cv2.VideoCapture)) - check_instance(source=f'{check_video_and_data_frm_count_align.__name__} data', instance=data, accepted_types=(str, pd.DataFrame)) - check_str(name=f'{check_video_and_data_frm_count_align.__name__} name', value=name, allow_blank=True) + check_instance( + source=f"{check_video_and_data_frm_count_align.__name__} video", + instance=video, + accepted_types=(str, cv2.VideoCapture), + ) + check_instance( + source=f"{check_video_and_data_frm_count_align.__name__} data", + instance=data, + accepted_types=(str, pd.DataFrame), + ) + check_str( + name=f"{check_video_and_data_frm_count_align.__name__} name", + value=name, + allow_blank=True, + ) if isinstance(video, str): check_file_exist_and_readable(file_path=video) video = cv2.VideoCapture(video) video_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) if isinstance(data, str): check_file_exist_and_readable(file_path=data) - with open(data, 'rb') as fp: + with open(data, "rb") as fp: c_generator = _count_generator(fp.raw.read) - data_count = (sum(buffer.count(b'\n') for buffer in c_generator)) - 1 + data_count = (sum(buffer.count(b"\n") for buffer in c_generator)) - 1 else: data_count = len(data) if data_count != video_count: if not raise_error: - FrameRangeWarning(msg=f'The video {name} has {video_count} frames, but the associated data file for this video has {data_count} rows', source=check_video_and_data_frm_count_align.__name__) + FrameRangeWarning( + msg=f"The video {name} has {video_count} frames, but the associated data file for this video has {data_count} rows", + source=check_video_and_data_frm_count_align.__name__, + ) else: - raise FrameRangeError(msg=f'The video {name} has {video_count} frames, but the associated data file for this video has {data_count} rows', source=check_video_and_data_frm_count_align.__name__) \ No newline at end of file + raise FrameRangeError( + msg=f"The video {name} has {video_count} frames, but the associated data file for this video has {data_count} rows", + source=check_video_and_data_frm_count_align.__name__, + ) diff --git a/simba/utils/data.py b/simba/utils/data.py index 9c6a66dd2..3e6177afc 100644 --- a/simba/utils/data.py +++ b/simba/utils/data.py @@ -25,13 +25,14 @@ from simba.utils.checks import (check_file_exist_and_readable, check_float, check_if_dir_exists, + check_if_keys_exist_in_dict, check_if_module_has_import, check_if_string_value_is_valid_video_timestamp, check_instance, check_int, check_str, check_that_column_exist, check_that_hhmmss_start_is_before_end, - check_valid_array, check_valid_dataframe, check_if_keys_exist_in_dict) -from simba.utils.enums import ConfigKey, Dtypes, Options, Keys + check_valid_array, check_valid_dataframe) +from simba.utils.enums import ConfigKey, Dtypes, Keys, Options from simba.utils.errors import (BodypartColumnNotFoundError, CountError, InvalidFileTypeError, InvalidInputError, NoFilesFoundError) @@ -646,22 +647,42 @@ def convert_roi_definitions( source=convert_roi_definitions.__name__, ) -def slice_roi_dict_for_video(data: Dict[str, pd.DataFrame], video_name: str) -> Tuple[Dict[str, pd.DataFrame], List[str]]: + +def slice_roi_dict_for_video( + data: Dict[str, pd.DataFrame], video_name: str +) -> Tuple[Dict[str, pd.DataFrame], List[str]]: """ Given a dictionary of dataframes representing different ROIs (created by ``simba.mixins.config_reader.ConfigReader.read_roi_data``), retain only the ROIs belonging to the specified video. """ - check_if_keys_exist_in_dict(data=data, key=[Keys.ROI_RECTANGLES.value, Keys.ROI_CIRCLES.value, Keys.ROI_POLYGONS.value], name=slice_roi_dict_for_video.__name__) + check_if_keys_exist_in_dict( + data=data, + key=[ + Keys.ROI_RECTANGLES.value, + Keys.ROI_CIRCLES.value, + Keys.ROI_POLYGONS.value, + ], + name=slice_roi_dict_for_video.__name__, + ) new_data, shape_names = {}, [] for k, v in data.items(): - check_instance(source=f'{slice_roi_dict_for_video.__name__} {k}', instance=v, accepted_types=(pd.DataFrame,)) - check_that_column_exist(df=v, column_name='Video', file_name=slice_roi_dict_for_video.__name__) - check_that_column_exist(df=v, column_name='Name', file_name=slice_roi_dict_for_video.__name__) - v = v[v['Video'] == video_name] + check_instance( + source=f"{slice_roi_dict_for_video.__name__} {k}", + instance=v, + accepted_types=(pd.DataFrame,), + ) + check_that_column_exist( + df=v, column_name="Video", file_name=slice_roi_dict_for_video.__name__ + ) + check_that_column_exist( + df=v, column_name="Name", file_name=slice_roi_dict_for_video.__name__ + ) + v = v[v["Video"] == video_name] new_data[k] = v.reset_index(drop=True) - shape_names.extend((list(v['Name'].unique()))) + shape_names.extend((list(v["Name"].unique()))) return new_data, shape_names + def freedman_diaconis(data: np.array) -> (float, int): """ Use Freedman-Diaconis rule to compute optimal count of histogram bins and their width. diff --git a/simba/utils/read_write.py b/simba/utils/read_write.py index de2f299c6..bd4967b1a 100644 --- a/simba/utils/read_write.py +++ b/simba/utils/read_write.py @@ -1953,14 +1953,19 @@ def seconds_to_timestamp(seconds: int) -> str: return "{:02d}:{:02d}:{:02d}".format(hours, minutes, seconds) -def read_data_paths(path: Union[str, os.PathLike], - default: List[Union[str, os.PathLike]], - default_name: Optional[str] = '', - file_type: Optional[str] = 'csv') -> List[str]: +def read_data_paths( + path: Union[str, os.PathLike], + default: List[Union[str, os.PathLike]], + default_name: Optional[str] = "", + file_type: Optional[str] = "csv", +) -> List[str]: if path is None: if len(default) == 0: - raise NoFilesFoundError(msg = f"No files in format found in {default_name}", source=read_data_paths.__name__) + raise NoFilesFoundError( + msg=f"No files in format found in {default_name}", + source=read_data_paths.__name__, + ) else: for i in default: check_file_exist_and_readable(file_path=i) @@ -1970,22 +1975,34 @@ def read_data_paths(path: Union[str, os.PathLike], check_file_exist_and_readable(file_path=path) data_paths = [path] elif os.path.isdir(s=path): - data_paths = find_files_of_filetypes_in_directory(directory=path, extensions=[f'.{file_type}'], raise_error=True) + data_paths = find_files_of_filetypes_in_directory( + directory=path, extensions=[f".{file_type}"], raise_error=True + ) if len(data_paths) == 0: - raise NoFilesFoundError(msg=f"No files in format {file_type} found in {default_name}", source=read_data_paths.__name__) + raise NoFilesFoundError( + msg=f"No files in format {file_type} found in {default_name}", + source=read_data_paths.__name__, + ) else: - raise NoFilesFoundError(msg=f"{path} is not a valid path string", source=read_data_paths.__name__) + raise NoFilesFoundError( + msg=f"{path} is not a valid path string", + source=read_data_paths.__name__, + ) elif isinstance(path, (list, tuple)): - check_valid_lst(data=path, source=f'{read_data_paths.__name__} path', valid_dtypes=(str,), min_len=1) + check_valid_lst( + data=path, + source=f"{read_data_paths.__name__} path", + valid_dtypes=(str,), + min_len=1, + ) data_paths = [] for i in path: check_file_exist_and_readable(file_path=i) data_paths.append(i) else: - raise NoFilesFoundError(msg=f"{type(path)} is not a valid type for path", source=read_data_paths.__name__) + raise NoFilesFoundError( + msg=f"{type(path)} is not a valid type for path", + source=read_data_paths.__name__, + ) return data_paths - - - -