From cf361a9ff18be267e5aabd38b02b930aaf3dd64b Mon Sep 17 00:00:00 2001 From: sronilsson Date: Thu, 8 Aug 2024 15:48:29 -0400 Subject: [PATCH] bar_graphs --- setup.py | 2 +- simba/mixins/plotting_mixin.py | 90 +++++++++++++++++++++++++++---- simba/mixins/train_model_mixin.py | 49 +++++++++-------- 3 files changed, 104 insertions(+), 37 deletions(-) diff --git a/setup.py b/setup.py index 2a8410ecd..288b25ecd 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ # Setup configuration setuptools.setup( name="Simba-UW-tf-dev", - version="2.0.2", + version="2.0.4", author="Simon Nilsson, Jia Jie Choong, Sophia Hwang", author_email="sronilsson@gmail.com", description="Toolkit for computer classification and analysis of behaviors in experimental animals", diff --git a/simba/mixins/plotting_mixin.py b/simba/mixins/plotting_mixin.py index 1ad941a64..244f58ce1 100644 --- a/simba/mixins/plotting_mixin.py +++ b/simba/mixins/plotting_mixin.py @@ -1718,15 +1718,14 @@ def insert_directing_line( return img @staticmethod - def draw_lines_on_img( - img: np.ndarray, - start_positions: np.ndarray, - end_positions: np.ndarray, - color: Tuple[int, int, int], - highlight_endpoint: Optional[bool] = False, - thickness: Optional[int] = 2, - circle_size: Optional[int] = 2, - ) -> np.ndarray: + def draw_lines_on_img(img: np.ndarray, + start_positions: np.ndarray, + end_positions: np.ndarray, + color: Tuple[int, int, int], + highlight_endpoint: Optional[bool] = False, + thickness: Optional[int] = 2, + circle_size: Optional[int] = 2) -> np.ndarray: + """ Helper to draw a set of lines onto an image. @@ -1737,7 +1736,6 @@ def draw_lines_on_img( :param Optional[bool] highlight_endpoint: If True, highlights the ends of the lines with circles. :param Optional[int] thickness: The thickness of the lines. :param Optional[int] circle_size: If ``highlight_endpoint`` is True, the size of the highlighted points. - :return np.ndarray: The image with the lines overlayed. """ @@ -1826,6 +1824,18 @@ def get_optimal_font_scales(self, def get_optimal_circle_size(self, frame_size: Tuple[int, int], circle_frame_ratio: Optional[int] = 100) -> int: + """ + Calculate the optimal circle size for fitting within a rectangular frame based on a given ratio. + + This method computes the diameter of a circle that fits within the smallest dimension of a rectangular + frame, scaled by a specified ratio. The resulting circle size ensures that it fits within the bounds of + the frame while maintaining the specified size ratio. + + :param Tuple[int, int] frame_size: A tuple representing the dimensions of the rectangular frame (width, height). + :param Optional[int] circle_frame_ratio: An integer representing the ratio between the frame's smallest dimension and the circle's diameter. A lower ratio results in a larger circle, and a higher ratio results in a smaller circle. + :returns int: The computed diameter of the circle that fits within the smallest dimension of the frame, scaled by the `circle_frame_ratio`. + """ + check_int(name='accepted_circle_size', value=circle_frame_ratio, min_value=1) check_valid_tuple(x=frame_size, source='frame_size', accepted_lengths=(2,), valid_dtypes=(int,)) @@ -1843,7 +1853,26 @@ def put_text(self, font: Optional[int] = cv2.FONT_HERSHEY_DUPLEX, text_color: Optional[Tuple[int, int, int]] = (255, 255, 255), text_color_bg: Optional[Tuple[int, int, int]] = (0, 0, 0), - text_bg_alpha: float = 0.8): + text_bg_alpha: float = 0.8) -> np.ndarray: + + """ + Draws text on an image with a background color and transparency. + + This method overlays text on an image at the specified position, with options for adjusting font size, + thickness, background color, and background transparency. The text is drawn with an optional background + rectangle that can have a specified transparency level to ensure readability over various image backgrounds. + + :param img: The image on which the text is to be drawn. This is a NumPy array representing the image data. + :param text: The text string to be drawn on the image. + :param pos: The position (x, y) where the text will be placed on the image. The coordinates correspond to the bottom-left corner of the text. + :param font_size: The size of the font. It determines the scale factor that is multiplied by the font-specific base size. + :param font_thickness: The thickness of the text strokes. It is an integer specifying the number of pixels for the thickness. + :param font: The font type used to render the text. It corresponds to one of the predefined OpenCV font types. + :param text_color: The color of the text in RGB format. By default, the text color is white. + :param text_color_bg: The background color for the text in RGB format. By default, the background color is black. + :param text_bg_alpha: The transparency level of the background rectangle. A value between 0 and 1, where 0 is fully transparent and 1 is fully opaque. + :return: The image with the overlaid text and background rectangle. + """ check_valid_tuple(x=pos, accepted_lengths=(2,), valid_dtypes=(int,)) check_int(name='font_thickness', value=font_thickness, min_value=1) @@ -1860,7 +1889,46 @@ def put_text(self, cv2.putText(output, text, (x, y), font, font_size, text_color, font_thickness) return output + @staticmethod + def plot_bar_chart(df: pd.DataFrame, + x: str, + y: str, + error: Optional[str] = None, + x_label: Optional[str] = None, + y_label: Optional[str] = None, + title: Optional[str] = None, + fig_size: Optional[Tuple[int, int]] = (10, 8), + palette: Optional[str] = 'magma', + save_path: Optional[Union[str, os.PathLike]] = None): + + check_instance(source=f"{PlottingMixin.plot_bar_chart.__name__} df", instance=df, accepted_types=(pd.DataFrame)) + check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} x", value=x, options=tuple(df.columns)) + check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} y", value=y, options=tuple(df.columns)) + check_valid_lst(data=list(df[y]), source=f"{PlottingMixin.plot_bar_chart.__name__} y", valid_dtypes=Formats.NUMERIC_DTYPES.value) + fig, ax = plt.subplots(figsize=fig_size) + sns.barplot(x=x, y=y, data=df, palette=palette, ax=ax) + ax.set_xticklabels(df[x], rotation=90, fontsize=8) + if error is not None: + check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} error", value=error, options=tuple(df.columns)) + check_valid_lst(data=list(df[error]), source=f"{PlottingMixin.plot_bar_chart.__name__} error",valid_dtypes=Formats.NUMERIC_DTYPES.value) + for i, (value, error) in enumerate(zip(df['FEATURE_IMPORTANCE_MEAN'], df['FEATURE_IMPORTANCE_STDEV'])): + ax.errorbar(i, value, yerr=[[0], [error]], fmt='o', color='grey', capsize=2) + if x_label is not None: + check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} x_label", value=x_label) + plt.xlabel(x_label) + if y_label is not None: + check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} y_label", value=y_label) + plt.ylabel(y_label) + if title is not None: + check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} title", value=title) + plt.title(title, ha="center", fontsize=15) + if save_path is not None: + check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} save_path", value=save_path) + check_if_dir_exists(in_dir=os.path.dirname(save_path)) + fig.savefig(save_path, dpi=600, bbox_inches='tight') + else: + return fig # from sklearn.datasets import make_blobs # #from sklearn.datasets import ma diff --git a/simba/mixins/train_model_mixin.py b/simba/mixins/train_model_mixin.py index 072004648..cf6b8963c 100644 --- a/simba/mixins/train_model_mixin.py +++ b/simba/mixins/train_model_mixin.py @@ -670,23 +670,23 @@ def create_x_importance_log(self, print("Creating feature importance log...") timer = SimbaTimer(start=True) - if isinstance(rf_clf, cuRF): - cuml_tree_nodes = loads(rf_clf.get_json()) - importances = self.cuml_rf_x_importances(nodes=cuml_tree_nodes, n_features=len(x_names)) - else: - importances = list(rf_clf.feature_importances_) - feature_importances = [(feature, round(importance, 25)) for feature, importance in zip(x_names, importances)] - df = pd.DataFrame(feature_importances, columns=["FEATURE", "FEATURE_IMPORTANCE"]).sort_values( - by=["FEATURE_IMPORTANCE"], ascending=False) if save_file_no != None: - self.f_importance_save_path = os.path.join(save_dir, - f"{clf_name}_{save_file_no}_feature_importance_log.csv") + self.f_importance_save_path = os.path.join(save_dir, f"{clf_name}_{save_file_no}_feature_importance_log.csv") else: self.f_importance_save_path = os.path.join(save_dir, f"{clf_name}_feature_importance_log.csv") + if cuRF is not None and isinstance(rf_clf, cuRF): + cuml_tree_nodes = loads(rf_clf.get_json()) + importances = list(self.cuml_rf_x_importances(nodes=cuml_tree_nodes, n_features=len(x_names))) + std_importances = [np.nan] * len(importances) + else: + importances_per_tree = np.array([tree.feature_importances_ for tree in rf_clf.estimators_]) + importances = list(np.mean(importances_per_tree, axis=0)) + std_importances = list(np.std(importances_per_tree, axis=0)) + importances = [round(importance, 25) for importance in importances] + df = pd.DataFrame({'FEATURE': x_names,'FEATURE_IMPORTANCE_MEAN': importances, 'FEATURE_IMPORTANCE_STDEV': std_importances}).sort_values(by=["FEATURE_IMPORTANCE_MEAN"], ascending=False) df.to_csv(self.f_importance_save_path, index=False) timer.stop_timer() - print( - f'Feature importance log saved at {self.f_importance_save_path} (elapsed time: {timer.elapsed_time_str}s)') + stdout_success(msg=f'Feature importance log saved at {self.f_importance_save_path}!', elapsed_time=timer.elapsed_time_str) def create_x_importance_bar_chart(self, rf_clf: RandomForestClassifier, @@ -717,24 +717,23 @@ def create_x_importance_bar_chart(self, check_int(name="FEATURE IMPORTANCE BAR COUNT", value=n_bars, min_value=1) print("Creating feature importance bar chart...") timer = SimbaTimer(start=True) - self.create_x_importance_log(rf_clf, x_names, clf_name, save_dir) - importances_df = pd.read_csv(os.path.join(save_dir, f"{clf_name}_feature_importance_log.csv")) - importances_head = importances_df.head(n_bars) - colors = create_color_palette(pallete_name=palette, increments=n_bars, as_rgb_ratio=True) - colors = [x[::-1] for x in colors] - ax = importances_head.plot.bar(x="FEATURE", y="FEATURE_IMPORTANCE", legend=False, rot=90, fontsize=6, - color=colors) - plt.ylabel("Feature importances' (mean decrease impurity)", fontsize=6) - plt.tight_layout() if save_file_no != None: save_file_path = os.path.join(save_dir, f"{clf_name}_{save_file_no}_feature_importance_bar_graph.png") else: save_file_path = os.path.join(save_dir, f"{clf_name}_feature_importance_bar_graph.png") - plt.savefig(save_file_path, dpi=600) - plt.close("all") + self.create_x_importance_log(rf_clf, x_names, clf_name, save_dir) + importances_df = pd.read_csv(os.path.join(save_dir, f"{clf_name}_feature_importance_log.csv")) + importances_head = importances_df.head(n_bars) + _ = PlottingMixin.plot_bar_chart(df=importances_head, + x='FEATURE', + y="FEATURE_IMPORTANCE_MEAN", + error='FEATURE_IMPORTANCE_STDEV', + x_label='FEATURE', + y_label='IMPORTANCE', + title=f'SimBA feature importances {clf_name}', + save_path=save_file_path) timer.stop_timer() - print( - f'Feature importance bar chart complete, saved at {save_file_path} (elapsed time: {timer.elapsed_time_str}s)') + print(f'Feature importance bar chart complete, saved at {save_file_path} (elapsed time: {timer.elapsed_time_str}s)') def dviz_classification_visualization( self,