Skip to content

Commit

Permalink
bar_graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Aug 8, 2024
1 parent 205b10a commit cf361a9
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 37 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# Setup configuration
setuptools.setup(
name="Simba-UW-tf-dev",
version="2.0.2",
version="2.0.4",
author="Simon Nilsson, Jia Jie Choong, Sophia Hwang",
author_email="[email protected]",
description="Toolkit for computer classification and analysis of behaviors in experimental animals",
Expand Down
90 changes: 79 additions & 11 deletions simba/mixins/plotting_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,15 +1718,14 @@ def insert_directing_line(
return img

@staticmethod
def draw_lines_on_img(
img: np.ndarray,
start_positions: np.ndarray,
end_positions: np.ndarray,
color: Tuple[int, int, int],
highlight_endpoint: Optional[bool] = False,
thickness: Optional[int] = 2,
circle_size: Optional[int] = 2,
) -> np.ndarray:
def draw_lines_on_img(img: np.ndarray,
start_positions: np.ndarray,
end_positions: np.ndarray,
color: Tuple[int, int, int],
highlight_endpoint: Optional[bool] = False,
thickness: Optional[int] = 2,
circle_size: Optional[int] = 2) -> np.ndarray:

"""
Helper to draw a set of lines onto an image.
Expand All @@ -1737,7 +1736,6 @@ def draw_lines_on_img(
:param Optional[bool] highlight_endpoint: If True, highlights the ends of the lines with circles.
:param Optional[int] thickness: The thickness of the lines.
:param Optional[int] circle_size: If ``highlight_endpoint`` is True, the size of the highlighted points.
:return np.ndarray: The image with the lines overlaid.
"""

Expand Down Expand Up @@ -1826,6 +1824,18 @@ def get_optimal_font_scales(self,
def get_optimal_circle_size(self,
frame_size: Tuple[int, int],
circle_frame_ratio: Optional[int] = 100) -> int:
"""
Calculate the optimal circle size for fitting within a rectangular frame based on a given ratio.
This method computes the diameter of a circle that fits within the smallest dimension of a rectangular
frame, scaled by a specified ratio. The resulting circle size ensures that it fits within the bounds of
the frame while maintaining the specified size ratio.
:param Tuple[int, int] frame_size: A tuple representing the dimensions of the rectangular frame (width, height).
:param Optional[int] circle_frame_ratio: An integer representing the ratio between the frame's smallest dimension and the circle's diameter. A lower ratio results in a larger circle, and a higher ratio results in a smaller circle.
:returns int: The computed diameter of the circle that fits within the smallest dimension of the frame, scaled by the `circle_frame_ratio`.
"""


check_int(name='accepted_circle_size', value=circle_frame_ratio, min_value=1)
check_valid_tuple(x=frame_size, source='frame_size', accepted_lengths=(2,), valid_dtypes=(int,))
Expand All @@ -1843,7 +1853,26 @@ def put_text(self,
font: Optional[int] = cv2.FONT_HERSHEY_DUPLEX,
text_color: Optional[Tuple[int, int, int]] = (255, 255, 255),
text_color_bg: Optional[Tuple[int, int, int]] = (0, 0, 0),
text_bg_alpha: float = 0.8):
text_bg_alpha: float = 0.8) -> np.ndarray:

"""
Draws text on an image with a background color and transparency.
This method overlays text on an image at the specified position, with options for adjusting font size,
thickness, background color, and background transparency. The text is drawn with an optional background
rectangle that can have a specified transparency level to ensure readability over various image backgrounds.
:param img: The image on which the text is to be drawn. This is a NumPy array representing the image data.
:param text: The text string to be drawn on the image.
:param pos: The position (x, y) where the text will be placed on the image. The coordinates correspond to the bottom-left corner of the text.
:param font_size: The size of the font. It determines the scale factor that is multiplied by the font-specific base size.
:param font_thickness: The thickness of the text strokes. It is an integer specifying the number of pixels for the thickness.
:param font: The font type used to render the text. It corresponds to one of the predefined OpenCV font types.
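:param text_color: The color of the text as a 3-tuple in the channel order of ``img`` (BGR for images loaded with OpenCV). By default, the text color is white.
:param text_color_bg: The background color for the text as a 3-tuple in the channel order of ``img`` (BGR for images loaded with OpenCV). By default, the background color is black.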
:param text_bg_alpha: The transparency level of the background rectangle. A value between 0 and 1, where 0 is fully transparent and 1 is fully opaque.
:return: The image with the overlaid text and background rectangle.
"""

check_valid_tuple(x=pos, accepted_lengths=(2,), valid_dtypes=(int,))
check_int(name='font_thickness', value=font_thickness, min_value=1)
Expand All @@ -1860,7 +1889,46 @@ def put_text(self,
cv2.putText(output, text, (x, y), font, font_size, text_color, font_thickness)
return output

@staticmethod
def plot_bar_chart(df: pd.DataFrame,
                   x: str,
                   y: str,
                   error: Optional[str] = None,
                   x_label: Optional[str] = None,
                   y_label: Optional[str] = None,
                   title: Optional[str] = None,
                   fig_size: Optional[Tuple[int, int]] = (10, 8),
                   palette: Optional[str] = 'magma',
                   save_path: Optional[Union[str, os.PathLike]] = None):
    """
    Draw a seaborn bar chart from two columns of a dataframe, optionally with one-sided (upward) error bars.

    One bar is drawn per row of ``df``; the x-tick labels are taken from ``df[x]`` and rotated 90 degrees.
    All arguments are validated before any figure is created, so a failed check does not leave an
    open figure behind.

    :param pd.DataFrame df: Dataframe holding the data to plot.
    :param str x: Name of the column in ``df`` holding the bar labels.
    :param str y: Name of the column in ``df`` holding the bar heights. Has to hold numeric values.
    :param Optional[str] error: Name of a numeric column in ``df`` holding the size of the error bars. Each error bar extends from the bar height upwards only. If None, no error bars are drawn.
    :param Optional[str] x_label: Optional x-axis label.
    :param Optional[str] y_label: Optional y-axis label.
    :param Optional[str] title: Optional figure title.
    :param Optional[Tuple[int, int]] fig_size: Figure size passed to ``plt.subplots(figsize=...)``.
    :param Optional[str] palette: Name of the palette passed to ``sns.barplot``.
    :param Optional[Union[str, os.PathLike]] save_path: If not None, the figure is saved at this location (600 dpi) and closed. The directory has to exist.
    :return: The matplotlib figure if ``save_path`` is None, otherwise None.
    """

    source = PlottingMixin.plot_bar_chart.__name__

    # --- validate inputs up-front, before any matplotlib resources are allocated ---
    check_instance(source=f"{source} df", instance=df, accepted_types=(pd.DataFrame,))
    check_str(name=f"{source} x", value=x, options=tuple(df.columns))
    check_str(name=f"{source} y", value=y, options=tuple(df.columns))
    check_valid_lst(data=list(df[y]), source=f"{source} y", valid_dtypes=Formats.NUMERIC_DTYPES.value)
    if error is not None:
        check_str(name=f"{source} error", value=error, options=tuple(df.columns))
        check_valid_lst(data=list(df[error]), source=f"{source} error", valid_dtypes=Formats.NUMERIC_DTYPES.value)
    if x_label is not None:
        check_str(name=f"{source} x_label", value=x_label)
    if y_label is not None:
        check_str(name=f"{source} y_label", value=y_label)
    if title is not None:
        check_str(name=f"{source} title", value=title)
    if save_path is not None:
        # The signature allows os.PathLike; normalise to a plain string before the string check
        # (NOTE(review): check_str presumably rejects non-str values - confirm).
        if isinstance(save_path, os.PathLike):
            save_path = os.fspath(save_path)
        check_str(name=f"{source} save_path", value=save_path)
        check_if_dir_exists(in_dir=os.path.dirname(save_path))

    # --- draw ---
    fig, ax = plt.subplots(figsize=fig_size)
    sns.barplot(x=x, y=y, data=df, palette=palette, ax=ax)
    ax.set_xticklabels(df[x], rotation=90, fontsize=8)
    if error is not None:
        # Use the caller-supplied columns (not hard-coded names) and loop names that do not
        # shadow the ``error`` parameter. yerr=[[0], [e]] gives a lower extent of 0 and an
        # upper extent of e, i.e., the error bar only extends upwards from the bar height.
        for bar_idx, (bar_height, bar_error) in enumerate(zip(df[y], df[error])):
            ax.errorbar(bar_idx, bar_height, yerr=[[0], [bar_error]], fmt='o', color='grey', capsize=2)

    # Set labels on the axes that were just created rather than through pyplot's "current axes" state.
    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    if title is not None:
        ax.set_title(title, ha="center", fontsize=15)

    if save_path is None:
        return fig
    fig.savefig(save_path, dpi=600, bbox_inches='tight')
    # The figure is not handed back to the caller when saved; close it so repeated calls
    # do not accumulate open figures in pyplot.
    plt.close(fig)

# from sklearn.datasets import make_blobs
# #from sklearn.datasets import ma
Expand Down
49 changes: 24 additions & 25 deletions simba/mixins/train_model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,23 +670,23 @@ def create_x_importance_log(self,

print("Creating feature importance log...")
timer = SimbaTimer(start=True)
if isinstance(rf_clf, cuRF):
cuml_tree_nodes = loads(rf_clf.get_json())
importances = self.cuml_rf_x_importances(nodes=cuml_tree_nodes, n_features=len(x_names))
else:
importances = list(rf_clf.feature_importances_)
feature_importances = [(feature, round(importance, 25)) for feature, importance in zip(x_names, importances)]
df = pd.DataFrame(feature_importances, columns=["FEATURE", "FEATURE_IMPORTANCE"]).sort_values(
by=["FEATURE_IMPORTANCE"], ascending=False)
if save_file_no != None:
self.f_importance_save_path = os.path.join(save_dir,
f"{clf_name}_{save_file_no}_feature_importance_log.csv")
self.f_importance_save_path = os.path.join(save_dir, f"{clf_name}_{save_file_no}_feature_importance_log.csv")
else:
self.f_importance_save_path = os.path.join(save_dir, f"{clf_name}_feature_importance_log.csv")
if cuRF is not None and isinstance(rf_clf, cuRF):
cuml_tree_nodes = loads(rf_clf.get_json())
importances = list(self.cuml_rf_x_importances(nodes=cuml_tree_nodes, n_features=len(x_names)))
std_importances = [np.nan] * len(importances)
else:
importances_per_tree = np.array([tree.feature_importances_ for tree in rf_clf.estimators_])
importances = list(np.mean(importances_per_tree, axis=0))
std_importances = list(np.std(importances_per_tree, axis=0))
importances = [round(importance, 25) for importance in importances]
df = pd.DataFrame({'FEATURE': x_names,'FEATURE_IMPORTANCE_MEAN': importances, 'FEATURE_IMPORTANCE_STDEV': std_importances}).sort_values(by=["FEATURE_IMPORTANCE_MEAN"], ascending=False)
df.to_csv(self.f_importance_save_path, index=False)
timer.stop_timer()
print(
f'Feature importance log saved at {self.f_importance_save_path} (elapsed time: {timer.elapsed_time_str}s)')
stdout_success(msg=f'Feature importance log saved at {self.f_importance_save_path}!', elapsed_time=timer.elapsed_time_str)

def create_x_importance_bar_chart(self,
rf_clf: RandomForestClassifier,
Expand Down Expand Up @@ -717,24 +717,23 @@ def create_x_importance_bar_chart(self,
check_int(name="FEATURE IMPORTANCE BAR COUNT", value=n_bars, min_value=1)
print("Creating feature importance bar chart...")
timer = SimbaTimer(start=True)
self.create_x_importance_log(rf_clf, x_names, clf_name, save_dir)
importances_df = pd.read_csv(os.path.join(save_dir, f"{clf_name}_feature_importance_log.csv"))
importances_head = importances_df.head(n_bars)
colors = create_color_palette(pallete_name=palette, increments=n_bars, as_rgb_ratio=True)
colors = [x[::-1] for x in colors]
ax = importances_head.plot.bar(x="FEATURE", y="FEATURE_IMPORTANCE", legend=False, rot=90, fontsize=6,
color=colors)
plt.ylabel("Feature importances' (mean decrease impurity)", fontsize=6)
plt.tight_layout()
if save_file_no != None:
save_file_path = os.path.join(save_dir, f"{clf_name}_{save_file_no}_feature_importance_bar_graph.png")
else:
save_file_path = os.path.join(save_dir, f"{clf_name}_feature_importance_bar_graph.png")
plt.savefig(save_file_path, dpi=600)
plt.close("all")
self.create_x_importance_log(rf_clf, x_names, clf_name, save_dir)
importances_df = pd.read_csv(os.path.join(save_dir, f"{clf_name}_feature_importance_log.csv"))
importances_head = importances_df.head(n_bars)
_ = PlottingMixin.plot_bar_chart(df=importances_head,
x='FEATURE',
y="FEATURE_IMPORTANCE_MEAN",
error='FEATURE_IMPORTANCE_STDEV',
x_label='FEATURE',
y_label='IMPORTANCE',
title=f'SimBA feature importances {clf_name}',
save_path=save_file_path)
timer.stop_timer()
print(
f'Feature importance bar chart complete, saved at {save_file_path} (elapsed time: {timer.elapsed_time_str}s)')
print(f'Feature importance bar chart complete, saved at {save_file_path} (elapsed time: {timer.elapsed_time_str}s)')

def dviz_classification_visualization(
self,
Expand Down

0 comments on commit cf361a9

Please sign in to comment.