From f1d587b1b5440065ef3d2baac980fad1acc0d8fd Mon Sep 17 00:00:00 2001 From: sronilsson Date: Sun, 4 Aug 2024 15:37:56 -0400 Subject: [PATCH] path plotter with roi option --- simba/mixins/plotting_mixin.py | 52 +--- simba/plotting/path_plotter.py | 359 ++++++++--------------- simba/plotting/path_plotter_mp.py | 420 ++++++++++----------------- simba/ui/pop_ups/path_plot_pop_up.py | 245 +++++++--------- simba/utils/data.py | 4 +- 5 files changed, 385 insertions(+), 695 deletions(-) diff --git a/simba/mixins/plotting_mixin.py b/simba/mixins/plotting_mixin.py index c2e3a7cf0..2b4f4afcb 100644 --- a/simba/mixins/plotting_mixin.py +++ b/simba/mixins/plotting_mixin.py @@ -1625,47 +1625,17 @@ def polygons_onto_image( return img @staticmethod - def roi_dict_onto_img( - img: np.ndarray, - roi_dict: Dict[str, pd.DataFrame], - circle_size: Optional[int] = 2, - show_center: Optional[bool] = False, - show_tags: Optional[bool] = False, - ) -> np.ndarray: - - check_valid_array( - data=img, source=f"{PlottingMixin.roi_dict_onto_img.__name__} img" - ) - check_if_keys_exist_in_dict( - data=roi_dict, - key=[ - Keys.ROI_POLYGONS.value, - Keys.ROI_CIRCLES.value, - Keys.ROI_RECTANGLES.value, - ], - name=PlottingMixin.roi_dict_onto_img.__name__, - ) - img = PlottingMixin.rectangles_onto_image( - img=img, - rectangles=roi_dict[Keys.ROI_RECTANGLES.value], - circle_size=circle_size, - show_center=show_center, - show_tags=show_tags, - ) - img = PlottingMixin.circles_onto_image( - img=img, - circles=roi_dict[Keys.ROI_CIRCLES.value], - circle_size=circle_size, - show_center=show_center, - show_tags=show_tags, - ) - img = PlottingMixin.polygons_onto_image( - img=img, - polygons=roi_dict[Keys.ROI_POLYGONS.value], - circle_size=circle_size, - show_center=show_center, - show_tags=show_tags, - ) + def roi_dict_onto_img(img: np.ndarray, + roi_dict: Dict[str, pd.DataFrame], + circle_size: Optional[int] = 2, + show_center: Optional[bool] = False, + show_tags: Optional[bool] = False) -> np.ndarray: + + check_valid_array(data=img, source=f"{PlottingMixin.roi_dict_onto_img.__name__} img") + check_if_keys_exist_in_dict(data=roi_dict, key=[Keys.ROI_POLYGONS.value, Keys.ROI_CIRCLES.value, Keys.ROI_RECTANGLES.value], name=PlottingMixin.roi_dict_onto_img.__name__) + img = PlottingMixin.rectangles_onto_image(img=img, rectangles=roi_dict[Keys.ROI_RECTANGLES.value], circle_size=circle_size, show_center=show_center, show_tags=show_tags) + img = PlottingMixin.circles_onto_image(img=img, circles=roi_dict[Keys.ROI_CIRCLES.value], circle_size=circle_size, show_center=show_center, show_tags=show_tags) + img = PlottingMixin.polygons_onto_image(img=img, polygons=roi_dict[Keys.ROI_POLYGONS.value], circle_size=circle_size, show_center=show_center, show_tags=show_tags) return img @staticmethod diff --git a/simba/plotting/path_plotter.py b/simba/plotting/path_plotter.py index 96fd0f39a..17bdf9790 100644 --- a/simba/plotting/path_plotter.py +++ b/simba/plotting/path_plotter.py @@ -10,19 +10,17 @@ from simba.mixins.config_reader import ConfigReader from simba.mixins.plotting_mixin import PlottingMixin -from simba.utils.checks import ( - check_all_file_names_are_represented_in_video_log, - check_if_keys_exist_in_dict, - check_if_string_value_is_valid_video_timestamp, check_if_valid_rgb_str, - check_instance, check_int, check_that_column_exist, - check_that_hhmmss_start_is_before_end, check_valid_lst) -from simba.utils.data import find_frame_numbers_from_time_stamp +from simba.utils.checks import (check_all_file_names_are_represented_in_video_log, + check_if_keys_exist_in_dict, + check_if_string_value_is_valid_video_timestamp, + check_instance, check_int, check_that_column_exist, + check_that_hhmmss_start_is_before_end, check_valid_lst) +from simba.utils.data import find_frame_numbers_from_time_stamp, slice_roi_dict_for_video from simba.utils.enums import Formats, TagNames from simba.utils.errors import FrameRangeError, NoSpecifiedOutputError +from simba.utils.warnings import ROIWarning from simba.utils.printing import SimbaTimer, log_event, stdout_success -from simba.utils.read_write import (find_video_of_file, get_fn_ext, - get_video_meta_data, read_df, - read_frm_of_video) +from simba.utils.read_write import (find_video_of_file, get_fn_ext, get_video_meta_data, read_df, read_frm_of_video) STYLE_WIDTH = "width" STYLE_HEIGHT = "height" @@ -66,8 +64,6 @@ class PathPlotterSingleCore(ConfigReader, PlottingMixin): :width: 500 :align: center - - :param str config_path: Path to SimBA project config file in Configparser format :param bool frame_setting: If True, individual frames will be created. :param bool video_setting: If True, compressed videos will be created. @@ -75,6 +71,7 @@ class PathPlotterSingleCore(ConfigReader, PlottingMixin): :param dict animal_attr: Animal body-parts and colors :param dict style_attr: Plot sttributes (line thickness, color, etc..) :param Optional[dict] slicing: If Dict, start time and end time of video slice to create path plot from. E.g., {'start_time': '00:00:01', 'end_time': '00:00:03'}. If None, creates path plot for entire video. + :param Optional[bool] roi: If True, also plots the ROIs associated with the video. Default False. .. note:: If style_attr['bg color'] is a dictionary, e.g., {'opacity': 100%}, then SimBA will use the video as background with set opacity. @@ -85,64 +82,36 @@ class PathPlotterSingleCore(ConfigReader, PlottingMixin): >>> path_plotter = PathPlotterSingleCore(config_path=r'MyConfigPath', frame_setting=False, video_setting=True, style_attr=style_attr, animal_attr=animal_attr, files_found=['project_folder/csv/machine_results/MyVideo.csv'], print_animal_names=True).run() """ - def __init__( - self, - config_path: Union[str, os.PathLike], - files_found: List[Union[str, os.PathLike]], - animal_attr: dict, - input_style_attr: Optional[Union[Dict[str, Any], None]] = None, - clf_attr: Optional[Dict[int, List[str]]] = None, - frame_setting: Union[bool] = False, - video_setting: Union[bool] = False, - last_frame: Union[bool] = False, - print_animal_names: Optional[bool] = True, - slicing: Optional[Dict] = None, - ): - - log_event( - logger_name=str(__class__.__name__), - log_type=TagNames.CLASS_INIT.value, - msg=self.create_log_msg_from_init_args(locals=locals()), - ) + def __init__(self, + config_path: Union[str, os.PathLike], + files_found: List[Union[str, os.PathLike]], + animal_attr: dict, + input_style_attr: Optional[Union[Dict[str, Any], None]] = None, + clf_attr: Optional[Dict[int, List[str]]] = None, + frame_setting: Union[bool] = False, + video_setting: Union[bool] = False, + last_frame: Union[bool] = False, + print_animal_names: Optional[bool] = True, + slicing: Optional[Dict] = None, + roi: Optional[bool] = False): + + log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals())) if (not frame_setting) and (not video_setting) and (not last_frame): - raise NoSpecifiedOutputError( - msg="SIMBA ERROR: Please choice to create path frames and/or video path plots", - source=self.__class__.__name__, - ) - check_valid_lst( - data=files_found, - source=self.__class__.__name__, - valid_dtypes=(str,), - min_len=1, - ) + raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please choice to create path frames and/or video path plots", source=self.__class__.__name__) + check_valid_lst(data=files_found, source=self.__class__.__name__, valid_dtypes=(str,), min_len=1) if input_style_attr is not None: - check_if_keys_exist_in_dict( - data=input_style_attr, key=STYLE_KEYS, name="input_style_attr" - ) + check_if_keys_exist_in_dict(data=input_style_attr, key=STYLE_KEYS, name="input_style_attr") ConfigReader.__init__(self, config_path=config_path) PlottingMixin.__init__(self) - ( - self.video_setting, - self.frame_setting, - self.input_style_attr, - self.files_found, - self.animal_attr, - self.clf_attr, - self.last_frame, - ) = ( - video_setting, - frame_setting, - input_style_attr, - files_found, - animal_attr, - clf_attr, - last_frame, - ) + if roi: + self.read_roi_data() + (self.video_setting, self.frame_setting, self.input_style_attr, self.files_found, self.animal_attr, self.clf_attr, self.last_frame, self.roi) = ( video_setting, frame_setting, input_style_attr, files_found, animal_attr, clf_attr, last_frame, roi) self.print_animal_names, self.slicing = print_animal_names, slicing if not os.path.exists(self.path_plot_dir): os.makedirs(self.path_plot_dir) print(f"Processing {len(self.files_found)} videos...") + def __get_styles(self): self.style_attr = {} if self.input_style_attr is not None: @@ -150,37 +119,23 @@ def __get_styles(self): if self.input_style_attr["max lines"] == "entire video": self.style_attr["max lines"] = len(self.data_df) else: - self.style_attr["max lines"] = max( - 1, - int( - int(self.input_style_attr["max lines"] / 1000) - * (int(self.video_info["fps"].values[0])) - ), - ) + self.style_attr["max lines"] = max(1, int(int(self.input_style_attr["max lines"] / 1000) * (int(self.video_info["fps"].values[0])))) self.style_attr["font thickness"] = self.input_style_attr["font thickness"] self.style_attr["line width"] = self.input_style_attr["line width"] self.style_attr["font size"] = self.input_style_attr["font size"] self.style_attr["circle size"] = self.input_style_attr["circle size"] self.style_attr["print_animal_names"] = self.print_animal_names if self.input_style_attr["width"] == "As input": - self.style_attr["width"], self.style_attr["height"] = int( - self.video_info["Resolution_width"].values[0] - ), int(self.video_info["Resolution_height"].values[0]) + self.style_attr["width"], self.style_attr["height"] = int(self.video_info["Resolution_width"].values[0]), int(self.video_info["Resolution_height"].values[0]) else: pass else: space_scaler, radius_scaler, res_scaler, font_scaler = 25, 10, 1500, 0.8 - self.style_attr["width"] = int( - self.video_info["Resolution_width"].values[0] - ) - self.style_attr["height"] = int( - self.video_info["Resolution_height"].values[0] - ) + self.style_attr["width"] = int(self.video_info["Resolution_width"].values[0]) + self.style_attr["height"] = int(self.video_info["Resolution_height"].values[0]) max_res = max(self.style_attr["width"], self.style_attr["height"]) self.style_attr["circle size"] = int(radius_scaler / (res_scaler / max_res)) - self.style_attr["font size"] = min( - 0.5, (font_scaler / (res_scaler / max_res)) - ) + self.style_attr["font size"] = min(0.5, (font_scaler / (res_scaler / max_res))) self.style_attr["bg color"] = self.color_dict["White"] self.style_attr["print_animal_names"] = self.print_animal_names self.style_attr["max lines"] = len(self.data_df) @@ -188,9 +143,7 @@ def __get_styles(self): self.style_attr["line width"] = 2 def run(self): - check_all_file_names_are_represented_in_video_log( - video_info_df=self.video_info_df, data_paths=self.files_found - ) + check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.files_found) for file_cnt, file_path in enumerate(self.files_found): video_timer = SimbaTimer(start=True) _, self.video_name, _ = get_fn_ext(file_path) @@ -204,204 +157,128 @@ def run(self): line_data.append(self.data_df[data_cols].values.astype(np.int64)) colors.append(v["color"]) if self.print_animal_names: - animal_names.append( - self.find_animal_name_from_body_part_name( - bp_name=v["bp"], bp_dict=self.animal_bp_dict - ) - ) + animal_names.append(self.find_animal_name_from_body_part_name(bp_name=v["bp"], bp_dict=self.animal_bp_dict)) if not self.print_animal_names: animal_names = None if self.slicing: - check_if_keys_exist_in_dict( - data=self.slicing, key=["start_time", "end_time"] - ) - check_if_string_value_is_valid_video_timestamp( - value=self.slicing["start_time"], name="slice start time" - ) - check_if_string_value_is_valid_video_timestamp( - value=self.slicing["end_time"], name="slice end time" - ) - check_that_hhmmss_start_is_before_end( - start_time=self.slicing["start_time"], - end_time=self.slicing["end_time"], - name="slice times", - ) - frm_numbers = find_frame_numbers_from_time_stamp( - start_time=self.slicing["start_time"], - end_time=self.slicing["end_time"], - fps=self.fps, - ) + check_if_keys_exist_in_dict(data=self.slicing, key=["start_time", "end_time"]) + check_if_string_value_is_valid_video_timestamp(value=self.slicing["start_time"], name="slice start time") + check_if_string_value_is_valid_video_timestamp(value=self.slicing["end_time"], name="slice end time") + check_that_hhmmss_start_is_before_end(start_time=self.slicing["start_time"], end_time=self.slicing["end_time"], name="slice times") + frm_numbers = find_frame_numbers_from_time_stamp(start_time=self.slicing["start_time"], end_time=self.slicing["end_time"], fps=self.fps) if len(set(frm_numbers) - set(self.data_df.index)) > 0: - raise FrameRangeError( - msg=f'The chosen time-period ({self.slicing["start_time"]} - {self.slicing["end_time"]}) does not exist in {self.video_name}.', - source=self.__class__.__name__, - ) + raise FrameRangeError(msg=f'The chosen time-period ({self.slicing["start_time"]} - {self.slicing["end_time"]}) does not exist in {self.video_name}.', source=self.__class__.__name__) for i in range(len(line_data)): line_data[i] = line_data[i][frm_numbers, :] self.__get_styles() if self.video_setting: - self.video_save_path = os.path.join( - self.path_plot_dir, f"{self.video_name}.mp4" - ) + self.video_save_path = os.path.join(self.path_plot_dir, f"{self.video_name}.mp4") self.fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value) - self.writer = cv2.VideoWriter( - self.video_save_path, - self.fourcc, - self.fps, - (self.style_attr["width"], self.style_attr["height"]), - ) + self.writer = cv2.VideoWriter(self.video_save_path, self.fourcc, self.fps, (self.style_attr["width"], self.style_attr["height"])) if self.frame_setting: - self.save_video_folder = os.path.join( - self.path_plot_dir, self.video_name - ) + self.save_video_folder = os.path.join(self.path_plot_dir, self.video_name) if os.path.exists(self.save_video_folder): shutil.rmtree(path=self.save_video_folder) os.makedirs(self.save_video_folder) + video_rois, video_roi_names = None, None + if self.roi: + video_rois, roi_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name) + if len(roi_names) == 0: + video_rois, video_roi_names = None, None + ROIWarning(msg=f'NO ROI data found for video {self.video_name}. Skipping ROI plotting for this video.') + if self.clf_attr is not None: self.clf_attr_appended = {} - check_instance( - source=self.__class__.__name__, - instance=self.clf_attr, - accepted_types=(dict,), - ) + check_instance(source=self.__class__.__name__, instance=self.clf_attr, accepted_types=(dict,)) for k, v in self.clf_attr.items(): - check_if_keys_exist_in_dict( - data=v, key=["color", "size"], name=f"clf_attr {k}" - ) - check_that_column_exist( - df=self.data_df, column_name=k, file_name=file_path - ) + check_if_keys_exist_in_dict(data=v, key=["color", "size"], name=f"clf_attr {k}") + check_that_column_exist(df=self.data_df, column_name=k, file_name=file_path) self.clf_attr_appended[k] = self.clf_attr[k] - self.clf_attr_appended[k]["clfs"] = self.data_df[k].values.astype( - np.int8 - ) - self.clf_attr_appended[k]["positions"] = self.data_df[ - [ - self.animal_attr[0]["bp"] + "_x", - self.animal_attr[0]["bp"] + "_y", - ] - ].values.astype(np.int64) + self.clf_attr_appended[k]["clfs"] = self.data_df[k].values.astype(np.int8) + self.clf_attr_appended[k]["positions"] = self.data_df[[self.animal_attr[0]["bp"] + "_x", self.animal_attr[0]["bp"] + "_y"]].values.astype(np.int64) self.clf_attr = deepcopy(self.clf_attr_appended) del self.clf_attr_appended self.video_path = None if type(self.style_attr["bg color"]) == dict: - check_if_keys_exist_in_dict( - data=self.style_attr["bg color"], - key=["type", "opacity", "frame_index"], - ) - self.video_path = find_video_of_file( - video_dir=self.video_dir, filename=self.video_name, raise_error=True - ) + check_if_keys_exist_in_dict(data=self.style_attr["bg color"], key=["type", "opacity", "frame_index"]) + self.video_path = find_video_of_file(video_dir=self.video_dir, filename=self.video_name, raise_error=True) video_meta_data = get_video_meta_data(video_path=self.video_path) if "frame_index" in self.style_attr["bg color"].keys(): - check_int( - name="Static frame index", - value=self.style_attr["bg color"]["frame_index"], - min_value=0, - ) + check_int( name="Static frame index", value=self.style_attr["bg color"]["frame_index"], min_value=0) frame_index = self.style_attr["bg color"]["frame_index"] else: - frame_index = video_meta_data["frame_count"] - 1 - self.style_attr["bg color"] = read_frm_of_video( - video_path=self.video_path, - opacity=self.style_attr["bg color"]["opacity"], - frame_index=frame_index, - ) + frame_index = video_meta_data["frame_count"] - 1 + self.style_attr["bg color"] = read_frm_of_video( video_path=self.video_path, opacity=self.style_attr["bg color"]["opacity"], frame_index=frame_index) if self.last_frame: - PlottingMixin.make_path_plot( - data=line_data, - colors=colors, - width=self.style_attr["width"], - height=self.style_attr["height"], - max_lines=self.style_attr["max lines"], - bg_clr=self.style_attr["bg color"], - circle_size=self.style_attr["circle size"], - font_size=self.style_attr["font size"], - font_thickness=self.style_attr["font thickness"], - line_width=self.style_attr["line width"], - animal_names=animal_names, - clf_attr=self.clf_attr, - save_path=os.path.join( - self.path_plot_dir, f"{self.video_name}_final_frame.png" - ), - ) - + last_frame_save_path = os.path.join(self.path_plot_dir, f"{self.video_name}_final_frame.png") + last_frm = PlottingMixin.make_path_plot(data=line_data, + colors=colors, + width=self.style_attr["width"], + height=self.style_attr["height"], + max_lines=self.style_attr["max lines"], + bg_clr=self.style_attr["bg color"], + circle_size=self.style_attr["circle size"], + font_size=self.style_attr["font size"], + font_thickness=self.style_attr["font thickness"], + line_width=self.style_attr["line width"], + animal_names=animal_names, + clf_attr=self.clf_attr, + save_path=None) + if video_rois is not None: + last_frm = PlottingMixin.roi_dict_onto_img(img=last_frm, roi_dict=video_rois, show_tags=False, show_center=False) + cv2.imwrite(filename=last_frame_save_path, img=last_frm) + stdout_success(msg=f'Last path plot frame saved at {last_frame_save_path}') bg = self.style_attr["bg color"] if self.video_setting or self.frame_setting: if self.input_style_attr is not None: self.capture = cv2.VideoCapture(self.video_path) - if (type(self.input_style_attr["bg color"]) == dict) and ( - self.input_style_attr["bg color"]["type"] - ) == "static": - bg = read_frm_of_video( - video_path=self.capture, - opacity=self.input_style_attr["bg color"]["opacity"], - frame_index=self.input_style_attr["bg color"][ - "frame_index" - ], - ) - + if (type(self.input_style_attr["bg color"]) == dict) and (self.input_style_attr["bg color"]["type"]) == "static": + bg = read_frm_of_video(video_path=self.capture, opacity=self.input_style_attr["bg color"]["opacity"], frame_index=self.input_style_attr["bg color"]["frame_index"]) for frm_cnt in range(1, line_data[0].shape[0]): plot_arrs = [x[:frm_cnt, :] for x in line_data] if self.input_style_attr is not None: - if (type(self.input_style_attr["bg color"]) == dict) and ( - self.input_style_attr["bg color"]["type"] == "moving" - ): - bg = read_frm_of_video( - video_path=self.capture, - opacity=self.input_style_attr["bg color"]["opacity"], - frame_index=frm_cnt, - ) + if (type(self.input_style_attr["bg color"]) == dict) and (self.input_style_attr["bg color"]["type"] == "moving"): + bg = read_frm_of_video(video_path=self.capture, opacity=self.input_style_attr["bg color"]["opacity"],frame_index=frm_cnt) self.clf_attr_cpy = deepcopy(self.clf_attr) if self.clf_attr is not None: for k, v in self.clf_attr.items(): self.clf_attr_cpy[k]["clfs"][frm_cnt + 1 :] = 0 - img = PlottingMixin.make_path_plot( - data=plot_arrs, - colors=colors, - width=self.style_attr["width"], - height=self.style_attr["height"], - max_lines=self.style_attr["max lines"], - bg_clr=bg, - circle_size=self.style_attr["circle size"], - font_size=self.style_attr["font size"], - font_thickness=self.style_attr["font thickness"], - line_width=self.style_attr["line width"], - animal_names=animal_names, - clf_attr=self.clf_attr_cpy, - save_path=None, - ) - + img = PlottingMixin.make_path_plot(data=plot_arrs, + colors=colors, + width=self.style_attr["width"], + height=self.style_attr["height"], + max_lines=self.style_attr["max lines"], + bg_clr=bg, + circle_size=self.style_attr["circle size"], + font_size=self.style_attr["font size"], + font_thickness=self.style_attr["font thickness"], + line_width=self.style_attr["line width"], + animal_names=animal_names, + clf_attr=self.clf_attr_cpy, + save_path=None) + + if video_rois is not None: + img = PlottingMixin.roi_dict_onto_img(img=img, roi_dict=video_rois, show_tags=False, show_center=False) if self.video_setting: self.writer.write(np.uint8(img)) if self.frame_setting: - frm_name = os.path.join( - self.save_video_folder, str() + f"{frm_cnt}.png" - ) + frm_name = os.path.join(self.save_video_folder, str() + f"{frm_cnt}.png") cv2.imwrite(frm_name, np.uint8(img)) - print( - f"Path frame: {frm_cnt + 1} / {line_data[0].shape[0]}) created. Video: {self.video_name} ({str(file_cnt + 1)}/{len(self.files_found)})" - ) + print(f"Path frame: {frm_cnt + 1} / {line_data[0].shape[0]}) created. Video: {self.video_name} ({str(file_cnt + 1)}/{len(self.files_found)})") if self.video_setting: self.writer.release() video_timer.stop_timer() - print( - f"Path visualization for video {self.video_name} saved (elapsed time {video_timer.elapsed_time_str}s)..." - ) + print(f"Path visualization for video {self.video_name} saved (elapsed time {video_timer.elapsed_time_str}s)...") self.timer.stop_timer() - stdout_success( - msg=f"Path visualizations for {len(self.files_found)} video(s) saved in {self.path_plot_dir} directory", - elapsed_time=self.timer.elapsed_time_str, - source=self.__class__.__name__, - ) + stdout_success(msg=f"Path visualizations for {len(self.files_found)} video(s) saved in {self.path_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__) # @@ -414,31 +291,33 @@ def run(self): # 'bg color': {'type': 'moving', 'opacity': 50, 'frame_index': 200}, #{'type': 'static', 'opacity': 100, 'frame_index': 200} # 'max lines': 'entire video'} # # -# animal_attr = {0: {'bp': 'Ear_right_1', 'color': (255, 0, 0)}, 1: {'bp': 'Ear_right_2', 'color': (0, 0, 255)}} #['Ear_right_1', 'Red'], 1: ['Ear_right_2', 'Green']} -# clf_attr = {'Nose to Nose': {'color': (155, 1, 10), 'size': 30}, 'Nose to Tailbase': {'color': (155, 90, 10), 'size': 30}} +# animal_attr = {0: {'bp': 'Ear_right', 'color': (255, 0, 0)}, 1: {'bp': 'Center', 'color': (0, 0, 255)}} #['Ear_right_1', 'Red'], 1: ['Ear_right_2', 'Green']} +# # # clf_attr = {'Nose to Nose': {'color': (155, 1, 10), 'size': 30}, 'Nose to Tailbase': {'color': (155, 90, 10), 'size': 30}} # style_attr = None -# test = PathPlotterSingleCore(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/project_config.ini', -# frame_setting=True, +# test = PathPlotterSingleCore(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini', +# frame_setting=False, # video_setting=True, # last_frame=True, # slicing=None,#{'start_time': '00:00:00', 'end_time': '00:00:05'}, #{'start_time': '00:00:00', 'end_time': '00:00:50'}, #{'start_time': '00:00:00', 'end_time': '00:00:01'},, #{'start_time': '00:00:00', 'end_time': '00:00:01'}, #{'start_time': '00:00:00', 'end_time': '00:00:01'}, # input_style_attr=style_attr, # animal_attr=animal_attr, -# clf_attr=clf_attr, -# print_animal_names=True, -# files_found=['/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/csv/machine_results/Trial 10.csv']) +# clf_attr=None, +# print_animal_names=False, +# roi=True, +# files_found=['/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/csv/outlier_corrected_movement_location/2022-06-20_NOB_DOT_4.csv']) # test.run() -# test = PathPlotterSingleCore(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', +# test = PathPlotterSingleCore(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', # frame_setting=False, # video_setting=True, # last_frame=True, # input_style_attr=style_attr, # animal_attr=animal_attr, -# input_clf_attr=clf_attr, -# slicing = {'start_time': '00:00:01', 'end_time': '00:00:03'}, -# files_found=['/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv']) -# test.run() +# clf_attr=None, +# roi=True, +# slicing = {'start_time': '00:00:01', 'end_time': '00:00:08'}, +# files_found=['/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv']) +# #test.run() # style_attr = {'width': 'As input', diff --git a/simba/plotting/path_plotter_mp.py b/simba/plotting/path_plotter_mp.py index 485c7a4a9..d933dd002 100644 --- a/simba/plotting/path_plotter_mp.py +++ b/simba/plotting/path_plotter_mp.py @@ -13,11 +13,10 @@ from simba.mixins.config_reader import ConfigReader from simba.mixins.plotting_mixin import PlottingMixin from simba.utils.checks import (check_file_exist_and_readable, - check_if_keys_exist_in_dict, - check_if_valid_rgb_str, check_instance, + check_if_keys_exist_in_dict, check_instance, check_int, check_that_column_exist, check_valid_lst) -from simba.utils.data import find_frame_numbers_from_time_stamp +from simba.utils.data import find_frame_numbers_from_time_stamp, slice_roi_dict_for_video from simba.utils.enums import Formats, TagNames from simba.utils.errors import FrameRangeError, NoSpecifiedOutputError from simba.utils.printing import SimbaTimer, log_event, stdout_success @@ -26,80 +25,68 @@ get_fn_ext, get_video_meta_data, read_df, read_frm_of_video, remove_a_folder) - -def path_plot_mp( - frm_rng: np.ndarray, - data: np.array, - colors: List[Tuple], - video_setting: bool, - frame_setting: bool, - video_save_dir: str, - video_name: str, - frame_folder_dir: str, - style_attr: dict, - animal_names: Union[None, List[str]], - fps: int, - clf_attr: dict, - input_style_attr: dict, - video_path: Optional[Union[str, os.PathLike]] = None, -): +from simba.utils.warnings import ROIWarning + + +def path_plot_mp(frm_rng: np.ndarray, + data: np.array, + colors: List[Tuple], + video_setting: bool, + frame_setting: bool, + video_save_dir: str, + video_name: str, + frame_folder_dir: str, + style_attr: dict, + roi: Union[dict, None], + animal_names: Union[None, List[str]], + fps: int, + clf_attr: dict, + input_style_attr: dict, + video_path: Optional[Union[str, os.PathLike]] = None): batch_id, frm_rng = frm_rng[0], frm_rng[1] if video_setting: fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value) video_save_path = os.path.join(video_save_dir, f"{batch_id}.mp4") - video_writer = cv2.VideoWriter( - video_save_path, fourcc, fps, (style_attr["width"], style_attr["height"]) - ) + video_writer = cv2.VideoWriter(video_save_path, fourcc, fps, (style_attr["width"], style_attr["height"])) if input_style_attr is not None: - if (isinstance(input_style_attr["bg color"], dict)) and ( - input_style_attr["bg color"]["type"] - ) == "moving": + if (isinstance(input_style_attr["bg color"], dict)) and (input_style_attr["bg color"]["type"]) == "moving": check_file_exist_and_readable(file_path=video_path) video_cap = cv2.VideoCapture(video_path) bg_clr = style_attr["bg color"] for frame_id in frm_rng: - if (isinstance(style_attr["bg color"], dict)) and ( - style_attr["bg color"]["type"] - ) == "moving": - bg_clr = read_frm_of_video( - video_path=video_cap, - opacity=style_attr["bg color"]["opacity"], - frame_index=frame_id, - ) + if (isinstance(style_attr["bg color"], dict)) and (style_attr["bg color"]["type"]) == "moving": + bg_clr = read_frm_of_video(video_path=video_cap, opacity=style_attr["bg color"]["opacity"], frame_index=frame_id) plot_arrs = [x[:frame_id, :] for x in data] clf_attr_cpy = deepcopy(clf_attr) if clf_attr is not None: - for k, v in clf_attr.items(): - clf_attr_cpy[k]["clfs"][frame_id + 1 :] = 0 - - img = PlottingMixin.make_path_plot( - data=plot_arrs, - colors=colors, - width=style_attr["width"], - height=style_attr["height"], - max_lines=style_attr["max lines"], - bg_clr=bg_clr, - circle_size=style_attr["circle size"], - font_size=style_attr["font size"], - font_thickness=style_attr["font thickness"], - line_width=style_attr["line width"], - animal_names=animal_names, - clf_attr=clf_attr_cpy, - save_path=None, - ) + for k, v in clf_attr.items(): clf_attr_cpy[k]["clfs"][frame_id + 1 :] = 0 + + img = PlottingMixin.make_path_plot(data=plot_arrs, + colors=colors, + width=style_attr["width"], + height=style_attr["height"], + max_lines=style_attr["max lines"], + bg_clr=bg_clr, + circle_size=style_attr["circle size"], + font_size=style_attr["font size"], + font_thickness=style_attr["font thickness"], + line_width=style_attr["line width"], + animal_names=animal_names, + clf_attr=clf_attr_cpy, + save_path=None) + if roi is not None: + img = PlottingMixin.roi_dict_onto_img(img=img, roi_dict=roi, show_tags=False, show_center=False) if video_setting: video_writer.write(np.uint8(img)) if frame_setting: frm_name = os.path.join(frame_folder_dir, f"{frame_id}.png") cv2.imwrite(frm_name, np.uint8(img)) - print( - f"Path frame created: {frame_id}, Video: {video_name}, Processing core: {batch_id}" - ) + print(f"Path frame created: {frame_id}, Video: {video_name}, Processing core: {batch_id}") if video_setting: video_writer.release() return batch_id @@ -119,6 +106,9 @@ class PathPlotterMulticore(ConfigReader, PlottingMixin): :param Optional[dict] input_style_attr: Plot sttributes (line thickness, color, etc..). If None, then autocomputed. Max lines will be set to 2s. :param Optional[dict] input_clf_attr: Dict reprenting classified behavior locations, their color and size. If None, then no classified behavior locations are shown. :param Optional[dict] slicing: If Dict, start time and end time of video slice to create path plot from. E.g., {'start_time': '00:00:01', 'end_time': '00:00:03'}. If None, creates path plot for entire video. + :param Optional[bool] roi: If True, also plots the ROIs associated with the video. Default False. + + :param int cores: Number of cores to use. .. note:: @@ -144,72 +134,37 @@ class PathPlotterMulticore(ConfigReader, PlottingMixin): >>> path_plotter.run() """ - def __init__( - self, - config_path: Union[str, os.PathLike], - files_found: List[str], - frame_setting: Optional[bool] = False, - video_setting: Optional[bool] = False, - last_frame: Optional[bool] = True, - cores: Optional[int] = -1, - print_animal_names: Optional[bool] = False, - input_style_attr: Optional[Dict] = None, - animal_attr: Dict[int, Any] = None, - clf_attr: Optional[Dict[int, List[str]]] = None, - slicing: Optional[Dict[str, str]] = None, - ): + def __init__(self, + config_path: Union[str, os.PathLike], + files_found: List[str], + frame_setting: Optional[bool] = False, + video_setting: Optional[bool] = False, + last_frame: Optional[bool] = True, + cores: Optional[int] = -1, + print_animal_names: Optional[bool] = False, + input_style_attr: Optional[Dict] = None, + animal_attr: Dict[int, Any] = None, + clf_attr: Optional[Dict[int, List[str]]] = None, + slicing: Optional[Dict[str, str]] = None, + roi: Optional[bool] = False): if platform.system() == "Darwin": multiprocessing.set_start_method("spawn", force=True) if (not frame_setting) and (not video_setting) and (not last_frame): - raise NoSpecifiedOutputError( - msg="SIMBA ERROR: Please choice to create path frames and/or video path plots", - source=self.__class__.__name__, - ) - check_valid_lst( - data=files_found, source=self.__class__.__name__, valid_dtypes=(str,) - ) - - check_int( - name=f"{self.__class__.__name__} core_cnt", - value=cores, - min_value=-1, - max_value=find_core_cnt()[0], - ) + raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please choice to create path frames and/or video path plots", source=self.__class__.__name__) + check_valid_lst(data=files_found, source=self.__class__.__name__, valid_dtypes=(str,)) + + check_int(name=f"{self.__class__.__name__} core_cnt", value=cores, min_value=-1, max_value=find_core_cnt()[0]) if cores == -1: cores = find_core_cnt()[0] ConfigReader.__init__(self, config_path=config_path) PlottingMixin.__init__(self) - log_event( - logger_name=str(__class__.__name__), - log_type=TagNames.CLASS_INIT.value, - msg=self.create_log_msg_from_init_args(locals=locals()), - ) - ( - self.video_setting, - self.frame_setting, - self.input_style_attr, - self.files_found, - self.animal_attr, - self.input_clf_attr, - self.last_frame, - self.cores, - ) = ( - video_setting, - frame_setting, - input_style_attr, - files_found, - animal_attr, - clf_attr, - last_frame, - cores, - ) - self.print_animal_names, self.clf_attr, self.slicing = ( - print_animal_names, - clf_attr, - slicing, - ) + if roi: + self.read_roi_data() + log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals())) + (self.video_setting, self.frame_setting, self.input_style_attr, self.files_found, self.animal_attr, self.input_clf_attr, self.last_frame, self.cores, self.roi) = (video_setting, frame_setting, input_style_attr, files_found, animal_attr, clf_attr, last_frame, cores, roi) + self.print_animal_names, self.clf_attr, self.slicing = (print_animal_names, clf_attr, slicing) if not os.path.exists(self.path_plot_dir): os.makedirs(self.path_plot_dir) print(f"Processing {len(self.files_found)} videos...") @@ -221,32 +176,20 @@ def __get_styles(self): if self.input_style_attr["max lines"] == "entire video": self.style_attr["max lines"] = len(self.data_df) else: - self.style_attr["max lines"] = max( - 1, - int( - int(self.input_style_attr["max lines"] / 1000) - * (int(self.video_info["fps"].values[0])) - ), - ) + self.style_attr["max lines"] = max(1, int(int(self.input_style_attr["max lines"] / 1000) * (int(self.video_info["fps"].values[0])))) self.style_attr["font thickness"] = self.input_style_attr["font thickness"] self.style_attr["line width"] = self.input_style_attr["line width"] self.style_attr["font size"] = self.input_style_attr["font size"] self.style_attr["circle size"] = self.input_style_attr["circle size"] self.style_attr["print_animal_names"] = self.print_animal_names if self.input_style_attr["width"] == "As input": - self.style_attr["width"], self.style_attr["height"] = int( - self.video_info["Resolution_width"].values[0] - ), int(self.video_info["Resolution_height"].values[0]) + self.style_attr["width"], self.style_attr["height"] = int(self.video_info["Resolution_width"].values[0]), int(self.video_info["Resolution_height"].values[0]) else: pass else: space_scaler, radius_scaler, res_scaler, font_scaler = 25, 10, 1500, 0.8 - self.style_attr["width"] = int( - self.video_info["Resolution_width"].values[0] - ) - self.style_attr["height"] = int( - self.video_info["Resolution_height"].values[0] - ) + self.style_attr["width"] = int(self.video_info["Resolution_width"].values[0]) + self.style_attr["height"] = int(self.video_info["Resolution_height"].values[0]) max_res = max(self.style_attr["width"], self.style_attr["height"]) self.style_attr["circle size"] = int(radius_scaler / (res_scaler / max_res)) self.style_attr["font size"] = font_scaler / (res_scaler / max_res) @@ -260,51 +203,29 @@ def run(self): for file_cnt, file_path in enumerate(self.files_found): video_timer = SimbaTimer(start=True) _, self.video_name, _ = get_fn_ext(file_path) - self.video_info, _, self.fps = self.read_video_info( - video_name=self.video_name - ) + self.video_info, _, self.fps = self.read_video_info(video_name=self.video_name) self.data_df = read_df(file_path, self.file_type) line_data, colors, animal_names = [], [], [] for k, v in self.animal_attr.items(): - check_if_keys_exist_in_dict( - data=v, key=["bp", "color"], name=f"animal attr {k}" - ) + check_if_keys_exist_in_dict(data=v, key=["bp", "color"], name=f"animal attr {k}") line_data.append( - self.data_df[[f'{v["bp"]}_x', f'{v["bp"]}_y']].values.astype( - np.int64 - ) - ) + self.data_df[[f'{v["bp"]}_x', f'{v["bp"]}_y']].values.astype(np.int64)) colors.append(v["color"]) if self.print_animal_names: - animal_names.append( - self.find_animal_name_from_body_part_name( - bp_name=v["bp"], bp_dict=self.animal_bp_dict - ) - ) + animal_names.append(self.find_animal_name_from_body_part_name(bp_name=v["bp"], bp_dict=self.animal_bp_dict)) if not self.print_animal_names: animal_names = None if self.slicing: - check_if_keys_exist_in_dict( - data=self.slicing, key=["start_time", "end_time"], name="slicing" - ) - frm_numbers = find_frame_numbers_from_time_stamp( - start_time=self.slicing["start_time"], - end_time=self.slicing["end_time"], - fps=self.fps, - ) + check_if_keys_exist_in_dict(data=self.slicing, key=["start_time", "end_time"], name="slicing") + frm_numbers = find_frame_numbers_from_time_stamp(start_time=self.slicing["start_time"], end_time=self.slicing["end_time"], fps=self.fps) if len(set(frm_numbers) - set(self.data_df.index)) > 0: - raise FrameRangeError( - msg=f'The chosen time-period ({self.slicing["start_time"]} - {self.slicing["end_time"]}) does not exist in {self.video_name}.', - source=self.__class__.__name__, - ) + raise FrameRangeError(msg=f'The chosen time-period ({self.slicing["start_time"]} - {self.slicing["end_time"]}) does not exist in {self.video_name}.', source=self.__class__.__name__,) for i in range(len(line_data)): line_data[i] = line_data[i][frm_numbers, :] self.__get_styles() self.temp_folder = os.path.join(self.path_plot_dir, self.video_name, "temp") - self.save_frame_folder_dir = os.path.join( - self.path_plot_dir, self.video_name - ) + self.save_frame_folder_dir = os.path.join(self.path_plot_dir, self.video_name) if self.frame_setting: if os.path.exists(self.save_frame_folder_dir): @@ -317,151 +238,120 @@ def run(self): remove_a_folder(self.temp_folder) remove_a_folder(self.video_folder) os.makedirs(self.temp_folder) - self.save_video_path = os.path.join( - self.path_plot_dir, f"{self.video_name}.mp4" - ) + self.save_video_path = os.path.join(self.path_plot_dir, f"{self.video_name}.mp4") if self.clf_attr is not None: self.clf_attr_appended = {} - check_instance( - source=self.__class__.__name__, - instance=self.clf_attr, - accepted_types=(dict,), - ) + check_instance(source=self.__class__.__name__, instance=self.clf_attr, accepted_types=(dict,)) for k, v in self.clf_attr.items(): - check_if_keys_exist_in_dict( - data=v, key=["color", "size"], name=f"clf_attr {k}" - ) - check_that_column_exist( - df=self.data_df, column_name=k, file_name=file_path - ) + check_if_keys_exist_in_dict(data=v, key=["color", "size"], name=f"clf_attr {k}") + check_that_column_exist(df=self.data_df, column_name=k, file_name=file_path) self.clf_attr_appended[k] = self.clf_attr[k] - self.clf_attr_appended[k]["clfs"] = self.data_df[k].values.astype( - np.int8 - ) - self.clf_attr_appended[k]["positions"] = self.data_df[ - [ - self.animal_attr[0]["bp"] + "_x", - self.animal_attr[0]["bp"] + "_y", - ] - ].values.astype(np.int64) + self.clf_attr_appended[k]["clfs"] = self.data_df[k].values.astype(np.int8) + self.clf_attr_appended[k]["positions"] = self.data_df[[self.animal_attr[0]["bp"] + "_x", self.animal_attr[0]["bp"] + "_y",]].values.astype(np.int64) self.clf_attr = deepcopy(self.clf_attr_appended) del self.clf_attr_appended bg_clr = self.style_attr["bg color"] self.video_path = None if isinstance(self.style_attr["bg color"], dict): - self.video_path = find_video_of_file( - video_dir=self.video_dir, filename=self.video_name, raise_error=True - ) + self.video_path = find_video_of_file(video_dir=self.video_dir, filename=self.video_name, raise_error=True) if "frame_index" in self.style_attr["bg color"].keys(): - check_int( - name="Static frame index", - value=self.style_attr["bg color"]["frame_index"], - min_value=0, - ) + check_int(name="Static frame index", value=self.style_attr["bg color"]["frame_index"], min_value=0) frame_index = self.style_attr["bg color"]["frame_index"] else: video_meta_data = get_video_meta_data(video_path=self.video_path) frame_index = video_meta_data["frame_count"] - 1 - bg_clr = read_frm_of_video( - video_path=self.video_path, - opacity=self.style_attr["bg color"]["opacity"], - frame_index=frame_index, - ) + bg_clr = read_frm_of_video(video_path=self.video_path,opacity=self.style_attr["bg color"]["opacity"],frame_index=frame_index) + + video_rois, video_roi_names = None, None + if self.roi: + video_rois, roi_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name) + if len(roi_names) == 0: + video_rois, video_roi_names = None, None + ROIWarning(msg=f'NO ROI data found for video {self.video_name}. Skipping ROI plotting for this video.') + + if self.last_frame: - PlottingMixin.make_path_plot( - data=line_data, - colors=colors, - width=self.style_attr["width"], - height=self.style_attr["height"], - max_lines=self.style_attr["max lines"], - bg_clr=bg_clr, - circle_size=self.style_attr["circle size"], - font_size=self.style_attr["font size"], - font_thickness=self.style_attr["font thickness"], - line_width=self.style_attr["line width"], - animal_names=animal_names, - clf_attr=self.clf_attr, - save_path=os.path.join( - self.path_plot_dir, f"{self.video_name}_final_frame.png" - ), - ) + last_frame_save_path = os.path.join(self.path_plot_dir, f"{self.video_name}_final_frame.png") + last_frm = PlottingMixin.make_path_plot(data=line_data, + colors=colors, + width=self.style_attr["width"], + height=self.style_attr["height"], + max_lines=self.style_attr["max lines"], + bg_clr=bg_clr, + circle_size=self.style_attr["circle size"], + font_size=self.style_attr["font size"], + font_thickness=self.style_attr["font thickness"], + line_width=self.style_attr["line width"], + animal_names=animal_names, + clf_attr=self.clf_attr, + save_path=None) + + if video_rois is not None: + last_frm = PlottingMixin.roi_dict_onto_img(img=last_frm, roi_dict=video_rois, show_tags=False, show_center=False) + cv2.imwrite(filename=last_frame_save_path, img=last_frm) + stdout_success(msg=f'Last path plot frame saved at {last_frame_save_path}') if self.video_setting or self.frame_setting: frm_range = np.arange(1, line_data[0].shape[0]) frm_range = np.array_split(frm_range, self.cores) frm_range = [(cnt, x) for cnt, x in enumerate(frm_range)] - print( - f"Creating path plots, multiprocessing (chunksize: {self.multiprocess_chunksize}, cores: {self.cores})..." - ) - with multiprocessing.Pool( - self.cores, maxtasksperchild=self.maxtasksperchild - ) as pool: - constants = functools.partial( - path_plot_mp, - data=line_data, - colors=colors, - video_setting=self.video_setting, - video_name=self.video_name, - frame_setting=self.frame_setting, - video_save_dir=self.temp_folder, - frame_folder_dir=self.save_frame_folder_dir, - style_attr=self.style_attr, - animal_names=animal_names, - fps=self.fps, - clf_attr=self.clf_attr, - input_style_attr=self.input_style_attr, - video_path=self.video_path, - ) + print(f"Creating path plots, multiprocessing (chunksize: {self.multiprocess_chunksize}, cores: {self.cores})...") + with multiprocessing.Pool(self.cores, maxtasksperchild=self.maxtasksperchild) as pool: + constants = functools.partial(path_plot_mp, + data=line_data, + colors=colors, + video_setting=self.video_setting, + video_name=self.video_name, + frame_setting=self.frame_setting, + video_save_dir=self.temp_folder, + frame_folder_dir=self.save_frame_folder_dir, + style_attr=self.style_attr, + animal_names=animal_names, + fps=self.fps, + roi=video_rois, + clf_attr=self.clf_attr, + input_style_attr=self.input_style_attr, + video_path=self.video_path) for cnt, result in enumerate( - pool.imap( - constants, frm_range, chunksize=self.multiprocess_chunksize - ) - ): + pool.imap(constants, frm_range, chunksize=self.multiprocess_chunksize)): print(f"Path batch {result+1}/{self.cores} complete...") pool.terminate() pool.join() if self.video_setting: print(f"Joining {self.video_name} multiprocessed video...") - concatenate_videos_in_folder( - in_folder=self.temp_folder, save_path=self.save_video_path - ) + concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.save_video_path) video_timer.stop_timer() - print( - f"Path plot video {self.video_name} complete (elapsed time: {video_timer.elapsed_time_str}s) ..." - ) + print(f"Path plot video {self.video_name} complete (elapsed time: {video_timer.elapsed_time_str}s) ...") self.timer.stop_timer() - stdout_success( - msg=f"Path plot visualizations for {len(self.files_found)} videos created in {self.path_plot_dir} directory", - elapsed_time=self.timer.elapsed_time_str, - source=self.__class__.__name__, - ) - - -# animal_attr = {0: {'bp': 'Ear_right_1', 'color': (255, 0, 0)}, 1: {'bp': 'Ear_right_2', 'color': (0, 0, 255)}} #['Ear_right_1', 'Red'], 1: ['Ear_right_2', 'Green']} -# style_attr = {'width': 'As input', -# 'height': 'As input', -# 'line width': 2, -# 'font size': 0.9, -# 'font thickness': 2, -# 'circle size': 5, -# 'bg color': {'type': 'moving', 'opacity': 50, 'frame_index': 200}, #{'type': 'static', 'opacity': 100, 'frame_index': 200} -# 'max lines': 'entire video'} -# clf_attr = {'Nose to Nose': {'color': (155, 1, 10), 'size': 30}, 'Nose to Tailbase': {'color': (155, 90, 10), 'size': 30}} -# #clf_attr=None - -# path_plotter = PathPlotterMulticore(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/project_config.ini', + stdout_success(msg=f"Path plot visualizations for {len(self.files_found)} videos created in {self.path_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__) + + +# animal_attr = {0: {'bp': 'Ear_right', 'color': (255, 0, 0)}, 1: {'bp': 'Tail_base', 'color': (0, 0, 255)}} #['Ear_right_1', 'Red'], 1: ['Ear_right_2', 'Green']} +# # style_attr = {'width': 'As input', +# # 'height': 'As input', +# # 'line width': 2, +# # 'font size': 0.9, +# # 'font thickness': 2, +# # 'circle size': 5, +# # 'bg color': {'type': 'moving', 'opacity': 50, 'frame_index': 200}, #{'type': 'static', 'opacity': 100, 'frame_index': 200} +# # 'max lines': 'entire video'} +# # clf_attr = {'Nose to Nose': {'color': (155, 1, 10), 'size': 30}, 'Nose to Tailbase': {'color': (155, 90, 10), 'size': 30}} +# clf_attr=None +# +# path_plotter = PathPlotterMulticore(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini', # frame_setting=False, # video_setting=True, # last_frame=True, # clf_attr=clf_attr, -# input_style_attr=style_attr, +# input_style_attr=None, # animal_attr=animal_attr, -# files_found=['/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/csv/machine_results/Trial 10.csv'], +# roi=True, +# files_found=['/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/csv/outlier_corrected_movement_location/2022-06-20_NOB_DOT_4.csv'], # cores=-1, # slicing = {'start_time': '00:00:00', 'end_time': '00:00:05'}, # {'start_time': '00:00:00', 'end_time': '00:00:05'}, # , #None, # print_animal_names=False) diff --git a/simba/ui/pop_ups/path_plot_pop_up.py b/simba/ui/pop_ups/path_plot_pop_up.py index 976dd895a..d716dc7a9 100644 --- a/simba/ui/pop_ups/path_plot_pop_up.py +++ b/simba/ui/pop_ups/path_plot_pop_up.py @@ -10,14 +10,10 @@ from simba.mixins.pop_up_mixin import PopUpMixin from simba.plotting.path_plotter import PathPlotterSingleCore from simba.plotting.path_plotter_mp import PathPlotterMulticore -from simba.ui.tkinter_functions import (CreateLabelFrameWithIcon, DropDownMenu, - Entry_Box, SimbaButton, SimbaCheckbox) -from simba.utils.checks import (check_if_filepath_list_is_empty, - check_if_string_value_is_valid_video_timestamp, - check_if_valid_rgb_str, check_int, - check_that_hhmmss_start_is_before_end) -from simba.utils.enums import Formats, Keys, Links, Paths -from simba.utils.errors import FrameRangeError +from simba.ui.tkinter_functions import (CreateLabelFrameWithIcon, DropDownMenu, Entry_Box, SimbaButton, SimbaCheckbox) +from simba.utils.checks import (check_if_string_value_is_valid_video_timestamp, check_if_valid_rgb_str, check_int, check_that_hhmmss_start_is_before_end) +from simba.utils.enums import Formats, Keys, Links +from simba.utils.errors import FrameRangeError, NoFilesFoundError, NoROIDataError from simba.utils.lookups import get_color_dict from simba.utils.read_write import get_file_name_info_in_directory @@ -25,9 +21,11 @@ class PathPlotPopUp(PopUpMixin, ConfigReader): def __init__(self, config_path: Union[str, os.PathLike]): ConfigReader.__init__(self, config_path=config_path, read_video_info=False) - self.data_path = os.path.join(self.project_path, Paths.MACHINE_RESULTS_DIR.value) - self.files_found_dict = get_file_name_info_in_directory(directory=self.data_path, file_type=self.file_type) - check_if_filepath_list_is_empty(filepaths=list(self.files_found_dict.keys()), error_msg="SIMBA ERROR: Zero files found in the project_folder/csv/machine_results directory. Create classification results before visualizing path plots") + self.machine_results_files = get_file_name_info_in_directory(directory=self.machine_results_dir, file_type=self.file_type) + self.outlier_corrected_files = get_file_name_info_in_directory(directory=self.outlier_corrected_dir, file_type=self.file_type) + self.files_found = list(set(list(self.machine_results_files.keys()) + list(self.outlier_corrected_files.keys()))) + if len(self.files_found) == 0: + raise NoFilesFoundError(msg=f'No data files found inside the {self.outlier_corrected_dir} or the {self.machine_results_dir} directory', source=self.__class__.__name__) PopUpMixin.__init__(self, title="CREATE PATH PLOTS", size=(550, 850)) self.resolution_options = deepcopy(self.resolutions) self.resolution_options.insert(0, "As input") @@ -83,7 +81,7 @@ def __init__(self, config_path: Union[str, os.PathLike]): self.video_end_time_entry.set_state("disable") self.slice_cb = Checkbutton( self.video_slicing_frm, - text="Plot only defined time-segment", + text="Plot ONLY defined time-segment", font=Formats.FONT_REGULAR.value, variable=self.slice_var, command=lambda: self.enable_entrybox_from_checkbox( @@ -95,21 +93,26 @@ def __init__(self, config_path: Union[str, os.PathLike]): self.video_start_time_entry.grid(row=1, column=0, sticky=NW) self.video_end_time_entry.grid(row=2, column=0, sticky=NW) - self.clf_frm = LabelFrame(self.main_frm, text="CHOOSE CLASSIFICATION VISUALIZATION", font=Formats.FONT_HEADER.value, pady=5, padx=5) + self.clf_frm = LabelFrame(self.main_frm, text="CLASSIFICATION VISUALIZATION", font=Formats.FONT_HEADER.value, pady=5, padx=5) self.include_clf_locations_var = BooleanVar(value=False) - self.include_clf_locations_cb = Checkbutton(self.clf_frm, text="Include classification locations", font=Formats.FONT_REGULAR.value, variable=self.include_clf_locations_var, command=self.populate_clf_location_data) + self.include_clf_locations_cb = Checkbutton(self.clf_frm, text="INCLUDE CLASSIFICATION LOCATIONS", font=Formats.FONT_REGULAR.value, variable=self.include_clf_locations_var, command=self.populate_clf_location_data) self.include_clf_locations_cb.grid(row=0, sticky=NW) self.populate_clf_location_data() + self.roi_frm = LabelFrame(self.main_frm, text="ROI VISUALIZATION", font=Formats.FONT_HEADER.value, pady=5, padx=5) + roi_cb, self.roi_var = SimbaCheckbox(parent=self.roi_frm, txt='INCLUDE ROIs', txt_img='roi', val=False) + + + self.populate_body_parts_menu(self.animal_cnt_options[0]) self.settings_frm = LabelFrame(self.main_frm, text="VISUALIZATION SETTINGS", font=Formats.FONT_HEADER.value, pady=5, padx=5) self.multiprocessing_var = BooleanVar() path_frames_cb, self.path_frames_var = SimbaCheckbox(parent=self.settings_frm, txt='CREATE FRAMES', txt_img='frames') path_videos_cb, self.path_videos_var = SimbaCheckbox(parent=self.settings_frm, txt='CREATE VIDEOS', txt_img='video') - path_last_frm_cb, self.path_last_frm_var = SimbaCheckbox(parent=self.settings_frm, txt='CREATE LAST FRAME', txt_img='finish') - self.include_animal_names_cb, self.include_animal_names_var = SimbaCheckbox(parent=self.settings_frm, txt='CREATE LAST FRAME', txt_img='id_card', val=True) + path_last_frm_cb, self.path_last_frm_var = SimbaCheckbox(parent=self.settings_frm, txt='CREATE LAST FRAME', txt_img='finish', val=True) + self.include_animal_names_cb, self.include_animal_names_var = SimbaCheckbox(parent=self.settings_frm, txt='INCLUDE ANIMAL NAMES', txt_img='id_card') self.multiprocess_cb = Checkbutton( self.settings_frm, @@ -129,10 +132,10 @@ def __init__(self, config_path: Union[str, os.PathLike]): self.run_single_video_frm = LabelFrame(self.run_frm, text="SINGLE VIDEO", font=Formats.FONT_HEADER.value, pady=5, padx=5, fg="black") self.run_single_video_btn = SimbaButton(parent=self.run_single_video_frm, txt='CREATE SINGLE VIDEO', img='rocket', txt_clr='blue', font=Formats.FONT_REGULAR.value, cmd=self.__create_path_plots, cmd_kwargs={'multiple_videos': False}) - self.single_video_dropdown = DropDownMenu( self.run_single_video_frm, "Video:", list(self.files_found_dict.keys()), "12") - self.single_video_dropdown.setChoices(list(self.files_found_dict.keys())[0]) + self.single_video_dropdown = DropDownMenu( self.run_single_video_frm, "Video:", self.files_found, "12") + self.single_video_dropdown.setChoices(self.files_found[0]) self.run_multiple_videos = LabelFrame( self.run_frm, text="MULTIPLE VIDEO", font=Formats.FONT_HEADER.value, pady=5, padx=5, fg="black") - self.run_multiple_video_btn = SimbaButton(parent=self.run_multiple_videos, txt="Create multiple videos ({} video(s) found)".format(str(len(list(self.files_found_dict.keys())))), font=Formats.FONT_REGULAR.value, img='rocket', txt_clr='blue', cmd=self.__create_path_plots, cmd_kwargs={'multiple_videos': True}) + self.run_multiple_video_btn = SimbaButton(parent=self.run_multiple_videos, txt=f"Create multiple videos ({len(self.files_found)} video(s) found)", font=Formats.FONT_REGULAR.value, img='rocket', txt_clr='blue', cmd=self.__create_path_plots, cmd_kwargs={'multiple_videos': True}) self.style_settings_frm.grid(row=0, sticky=NW) self.auto_compute_styles.grid(row=0, sticky=NW) @@ -149,10 +152,13 @@ def __init__(self, config_path: Union[str, os.PathLike]): self.video_slicing_frm.grid(row=1, sticky=NW) self.clf_frm.grid(row=2, sticky=NW) - self.body_parts_frm.grid(row=3, sticky=NW) + self.roi_frm.grid(row=3, sticky=NW) + roi_cb.grid(row=0, sticky=NW) + + self.body_parts_frm.grid(row=4, sticky=NW) self.number_of_animals_dropdown.grid(row=0, sticky=NW) - self.settings_frm.grid(row=4, sticky=NW) + self.settings_frm.grid(row=5, sticky=NW) path_frames_cb.grid(row=0, sticky=NW) path_videos_cb.grid(row=1, sticky=NW) path_last_frm_cb.grid(row=2, sticky=NW) @@ -160,7 +166,7 @@ def __init__(self, config_path: Union[str, os.PathLike]): self.multiprocess_cb.grid(row=4, column=0, sticky=NW) self.multiprocess_dropdown.grid(row=4, column=1, sticky=NW) - self.run_frm.grid(row=5, sticky=NW) + self.run_frm.grid(row=6, sticky=NW) self.run_single_video_frm.grid(row=0, sticky=NW) self.run_single_video_btn.grid(row=0, column=0, sticky=NW) self.single_video_dropdown.grid(row=0, column=1, sticky=NW) @@ -184,24 +190,11 @@ def populate_body_parts_menu(self, choice): self.bp_dropdowns, self.bp_colors = {}, {} self.bp_row_idx = [] for animal_cnt in range(int(self.number_of_animals_dropdown.getChoices())): - self.bp_dropdowns[animal_cnt] = DropDownMenu( - self.body_parts_frm, - "Body-part {}:".format(str(animal_cnt + 1)), - self.body_parts_lst, - "16", - ) + self.bp_dropdowns[animal_cnt] = DropDownMenu(self.body_parts_frm, "Body-part {}:".format(str(animal_cnt + 1)), self.body_parts_lst, "16") self.bp_dropdowns[animal_cnt].setChoices(self.body_parts_lst[animal_cnt]) self.bp_dropdowns[animal_cnt].grid(row=animal_cnt + 1, column=0, sticky=NW) - self.bp_colors[animal_cnt] = DropDownMenu( - self.body_parts_frm, - "", - self.animal_trace_clrs, - "2", - com=lambda x, k=animal_cnt: self.__set_custom_clrs(choice=x, row=k), - ) - self.bp_colors[animal_cnt].setChoices( - list(self.colors_dict.keys())[animal_cnt] - ) + self.bp_colors[animal_cnt] = DropDownMenu(self.body_parts_frm, "", self.animal_trace_clrs, "2", com=lambda x, k=animal_cnt: self.__set_custom_clrs(choice=x, row=k)) + self.bp_colors[animal_cnt].setChoices(list(self.colors_dict.keys())[animal_cnt]) self.bp_colors[animal_cnt].grid(row=animal_cnt + 1, column=1, sticky=NW) def __activate_settings(self, choice: str): @@ -210,9 +203,7 @@ def __activate_settings(self, choice: str): else: self.bg_opacity_dropdown.disable() if choice == "Video - static frame": - self.static_frm_index_eb = Entry_Box( - self.style_settings_frm, "Frame index: ", "10", validation="numeric" - ) + self.static_frm_index_eb = Entry_Box(self.style_settings_frm, "Frame index: ", "10", validation="numeric") self.static_frm_index_eb.entry_set(val=1) self.static_frm_index_eb.grid(row=8, column=1, sticky=NW) else: @@ -221,9 +212,7 @@ def __activate_settings(self, choice: str): def __set_custom_clrs(self, choice: str, row: int): if choice == "Custom": - self.custom_rgb_selections[row] = Entry_Box( - self.body_parts_frm, "RGB:", "5", entry_box_width=10 - ) + self.custom_rgb_selections[row] = Entry_Box(self.body_parts_frm, "RGB:", "5", entry_box_width=10) self.custom_rgb_selections[row].entry_set(val="255,0,0") self.custom_rgb_selections[row].grid(row=row + 1, column=3, sticky=NW) else: @@ -236,18 +225,11 @@ def populate_clf_location_data(self): size_lst = list(range(1, 51)) size_lst = ["Size: " + str(x) for x in size_lst] for clf_cnt, clf_name in enumerate(self.clf_names): - self.clf_name[clf_cnt] = DropDownMenu( - self.clf_frm, - "Classifier {}:".format(str(clf_cnt + 1)), - self.clf_names, - "16", - ) + self.clf_name[clf_cnt] = DropDownMenu(self.clf_frm, "Classifier {}:".format(str(clf_cnt + 1)), self.clf_names, "16") self.clf_name[clf_cnt].setChoices(self.clf_names[clf_cnt]) self.clf_name[clf_cnt].grid(row=clf_cnt + 1, column=0, sticky=NW) - self.clf_clr[clf_cnt] = DropDownMenu( - self.clf_frm, "", list(self.colors_dict.keys()), "2" - ) + self.clf_clr[clf_cnt] = DropDownMenu(self.clf_frm, "", list(self.colors_dict.keys()), "2") self.clf_clr[clf_cnt].setChoices(list(self.colors_dict.keys())[clf_cnt]) self.clf_clr[clf_cnt].grid(row=clf_cnt + 1, column=1, sticky=NW) @@ -287,9 +269,7 @@ def enable_style_settings(self): self.bg_clr_dropdown.enable() if self.bg_clr_dropdown.getChoices() == "Video": self.bg_opacity_dropdown.enable() - self.enable_entrybox_from_dropdown( - self.max_prior_lines_dropdown.getChoices() - ) + self.enable_entrybox_from_dropdown(self.max_prior_lines_dropdown.getChoices()) else: self.resolution_dropdown.disable() self.max_prior_lines_dropdown.disable() @@ -310,24 +290,12 @@ def __create_path_plots(self, multiple_videos: bool): height = int(self.resolution_dropdown.getChoices().split("×")[1]) else: width, height = "As input", "As input" - check_int( - name="PATH LINE WIDTH", value=self.line_width.entry_get, min_value=1 - ) - check_int( - name="PATH CIRCLE SIZE", value=self.circle_size.entry_get, min_value=1 - ) - check_int( - name="PATH FONT SIZE", value=self.font_size.entry_get, min_value=1 - ) - check_int( - name="FONT THICKNESS", value=self.font_thickness.entry_get, min_value=1 - ) + check_int(name="PATH LINE WIDTH", value=self.line_width.entry_get, min_value=1) + check_int(name="PATH CIRCLE SIZE", value=self.circle_size.entry_get, min_value=1) + check_int(name="PATH FONT SIZE", value=self.font_size.entry_get, min_value=1) + check_int(name="FONT THICKNESS", value=self.font_thickness.entry_get, min_value=1) if self.bg_clr_dropdown.getChoices() == "Video - static frame": - check_int( - name="Static frame index", - value=self.static_frm_index_eb.entry_get, - min_value=0, - ) + check_int(name="Static frame index", value=self.static_frm_index_eb.entry_get, min_value=0) bg_clr = { "type": "static", "opacity": int( @@ -353,25 +321,19 @@ def __create_path_plots(self, multiple_videos: bool): else: bg_clr = get_color_dict()[self.bg_clr_dropdown.getChoices()] - style_attr = { - "width": width, - "height": height, - "line width": int(self.line_width.entry_get), - "font size": int(self.font_size.entry_get), - "font thickness": int(self.font_thickness.entry_get), - "circle size": int(self.circle_size.entry_get), - "bg color": bg_clr, - "clf locations": self.include_clf_locations_var.get(), - } + style_attr = {"width": width, + "height": height, + "line width": int(self.line_width.entry_get), + "font size": int(self.font_size.entry_get), + "font thickness": int(self.font_thickness.entry_get), + "circle size": int(self.circle_size.entry_get), + "bg color": bg_clr, + "clf locations": self.include_clf_locations_var.get()} if self.max_prior_lines_dropdown.getChoices() == "Entire video": style_attr["max lines"] = "entire video" else: - check_int( - name="PATH MAX LINES", - value=self.max_lines_entry.entry_get, - min_value=1, - ) + check_int( name="PATH MAX LINES", value=self.max_lines_entry.entry_get, min_value=1) style_attr["max lines"] = int(self.max_lines_entry.entry_get) animal_attr = {} @@ -392,82 +354,73 @@ def __create_path_plots(self, multiple_videos: bool): self.slicing = None if self.slice_var.get(): - check_if_string_value_is_valid_video_timestamp( - value=self.video_start_time_entry.entry_get, - name="Video slicing START TIME", - ) - check_if_string_value_is_valid_video_timestamp( - value=self.video_end_time_entry.entry_get, name="Video slicing END TIME" - ) - if ( - self.video_start_time_entry.entry_get - == self.video_end_time_entry.entry_get - ): - raise FrameRangeError( - msg="The sliced start and end times cannot be identical", - source=self.__class__.__name__, - ) - check_that_hhmmss_start_is_before_end( - start_time=self.video_start_time_entry.entry_get, - end_time=self.video_end_time_entry.entry_get, - name="SLICE TIME STAMPS", - ) - self.slicing = { - "start_time": self.video_start_time_entry.entry_get, - "end_time": self.video_end_time_entry.entry_get, - } + check_if_string_value_is_valid_video_timestamp(value=self.video_start_time_entry.entry_get, name="Video slicing START TIME") + check_if_string_value_is_valid_video_timestamp(value=self.video_end_time_entry.entry_get, name="Video slicing END TIME") + if (self.video_start_time_entry.entry_get == self.video_end_time_entry.entry_get): + raise FrameRangeError(msg="The sliced start and end times cannot be identical", source=self.__class__.__name__) + check_that_hhmmss_start_is_before_end(start_time=self.video_start_time_entry.entry_get, end_time=self.video_end_time_entry.entry_get, name="SLICE TIME STAMPS") + self.slicing = {"start_time": self.video_start_time_entry.entry_get, "end_time": self.video_end_time_entry.entry_get} clf_attr = None if self.include_clf_locations_var.get(): + if multiple_videos: + if len(self.machine_results_paths) == 0: + raise NoFilesFoundError(msg=f'No DATA found in {self.machine_results_dir} directory. Un-check the classifier location checkbox, OR make sure the folder contains classification data.') + else: + data_paths = list(self.machine_results_files.values()) + else: + if self.single_video_dropdown.getChoices() not in self.machine_results_files.keys(): + raise NoFilesFoundError(msg=f'No DATA found for video in {self.single_video_dropdown.getChoices()} in directory {self.machine_results_dir}. Un-check the classifier location checkbox, OR make sure the folder contains classification data for the video.') + else: + data_paths = [self.machine_results_files[self.single_video_dropdown.getChoices()]] clf_attr = {} for cnt, (key, value) in enumerate(self.clf_name.items()): clf_attr[value.getChoices()] = {} - clf_attr[value.getChoices()]["color"] = get_color_dict()[ - self.clf_clr[cnt].getChoices() - ] + clf_attr[value.getChoices()]["color"] = get_color_dict()[self.clf_clr[cnt].getChoices()] size = "".join(filter(str.isdigit, self.clf_size[cnt].getChoices())) clf_attr[value.getChoices()]["size"] = int(size) - if multiple_videos: - data_paths = list(self.files_found_dict.values()) else: - data_paths = [ - self.files_found_dict[self.single_video_dropdown.getChoices()] - ] + if multiple_videos: + data_paths = list(self.outlier_corrected_files.values()) + else: + data_paths = [self.outlier_corrected_files[self.single_video_dropdown.getChoices()]] + + if self.roi_var.get(): + if not os.path.isfile(self.roi_coordinates_path): + raise NoROIDataError(msg=f'No SimBA ROI project data found. Expected at path {self.roi_coordinates_path}', source=self.__class__.__name__) if not self.multiprocessing_var.get(): - path_plotter = PathPlotterSingleCore( - config_path=self.config_path, - frame_setting=self.path_frames_var.get(), - video_setting=self.path_videos_var.get(), - last_frame=self.path_last_frm_var.get(), - files_found=data_paths, - input_style_attr=style_attr, - print_animal_names=self.include_animal_names_var.get(), - animal_attr=animal_attr, - clf_attr=clf_attr, - slicing=self.slicing, - ) + path_plotter = PathPlotterSingleCore(config_path=self.config_path, + frame_setting=self.path_frames_var.get(), + video_setting=self.path_videos_var.get(), + last_frame=self.path_last_frm_var.get(), + files_found=data_paths, + input_style_attr=style_attr, + print_animal_names=self.include_animal_names_var.get(), + animal_attr=animal_attr, + clf_attr=clf_attr, + slicing=self.slicing, + roi=self.roi_var.get()) else: - path_plotter = PathPlotterMulticore( - config_path=self.config_path, - frame_setting=self.path_frames_var.get(), - video_setting=self.path_videos_var.get(), - last_frame=self.path_last_frm_var.get(), - files_found=data_paths, - input_style_attr=style_attr, - print_animal_names=self.include_animal_names_var.get(), - animal_attr=animal_attr, - clf_attr=clf_attr, - cores=int(self.multiprocess_dropdown.getChoices()), - slicing=self.slicing, - ) + path_plotter = PathPlotterMulticore(config_path=self.config_path, + frame_setting=self.path_frames_var.get(), + video_setting=self.path_videos_var.get(), + last_frame=self.path_last_frm_var.get(), + files_found=data_paths, + input_style_attr=style_attr, + print_animal_names=self.include_animal_names_var.get(), + animal_attr=animal_attr, + clf_attr=clf_attr, + cores=int(self.multiprocess_dropdown.getChoices()), + slicing=self.slicing, + roi=self.roi_var.get()) threading.Thread(target=path_plotter.run()).start() -# _ = PathPlotPopUp(config_path=r"C:\troubleshooting\RAT_NOR\project_folder\project_config.ini") +#_ = PathPlotPopUp(config_path=r"/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini") # _ = PathPlotPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/project_config.ini') diff --git a/simba/utils/data.py b/simba/utils/data.py index d13a262e9..00aa9f658 100644 --- a/simba/utils/data.py +++ b/simba/utils/data.py @@ -687,9 +687,7 @@ def convert_roi_definitions( ) -def slice_roi_dict_for_video( - data: Dict[str, pd.DataFrame], video_name: str -) -> Tuple[Dict[str, pd.DataFrame], List[str]]: +def slice_roi_dict_for_video(data: Dict[str, pd.DataFrame], video_name: str) -> Tuple[Dict[str, pd.DataFrame], List[str]]: """ Given a dictionary of dataframes representing different ROIs (created by ``simba.mixins.config_reader.ConfigReader.read_roi_data``), retain only the ROIs belonging to the specified video.