From 69dac602d0699c5b8517bcd1a74de23607966761 Mon Sep 17 00:00:00 2001 From: sronilsson Date: Thu, 12 Dec 2024 15:12:46 -0500 Subject: [PATCH] morans prep --- simba/mixins/feature_extraction_mixin.py | 4 +- simba/mixins/geometry_mixin.py | 53 ++++++++++++------------ simba/utils/data.py | 27 ++++-------- 3 files changed, 39 insertions(+), 45 deletions(-) diff --git a/simba/mixins/feature_extraction_mixin.py b/simba/mixins/feature_extraction_mixin.py index cbc118581..46cb81611 100644 --- a/simba/mixins/feature_extraction_mixin.py +++ b/simba/mixins/feature_extraction_mixin.py @@ -141,7 +141,8 @@ def angle3pt_serialized(data: np.ndarray) -> np.ndarray: :align: center .. seealso:: - For GPU acceleration, use :func:`simba.data_processors.cuda.statistics.get_3pt_angle`. + For GPU acceleration, use :func:`simba.data_processors.cuda.statistics.get_3pt_angle` for single frame alternative, + see :func:`simba.mixins.feature_extraction_mixin.FeatureExtractionMixin.angle3pt` :param ndarray data: 2D numerical array with frame number on x and [ax, ay, bx, by, cx, cy] on y. :return: 1d float numerical array of size data.shape[0] with angles. @@ -611,6 +612,7 @@ def jitted_line_crosses_to_nonstatic_targets(left_ear_array: np.ndarray, .. image:: _static/img/directing_moving_targets.png :width: 400 :align: center + :param np.ndarray left_ear_array: 2D array of size len(frames) x 2 with the coordinates of the observer animals left ear :param np.ndarray right_ear_array: 2D array of size len(frames) x 2 with the coordinates of the observer animals right ear :param np.ndarray nose_array: 2D array of size len(frames) x 2 with the coordinates of the observer animals nose diff --git a/simba/mixins/geometry_mixin.py b/simba/mixins/geometry_mixin.py index 61943889f..266fbda04 100644 --- a/simba/mixins/geometry_mixin.py +++ b/simba/mixins/geometry_mixin.py @@ -845,7 +845,8 @@ def view_shapes(shapes: List[Union[LineString, Polygon, MultiPolygon, MultiLineS bg_img: Optional[np.ndarray] = None, bg_clr: Optional[Tuple[int, int, int]] = None, size: Optional[int] = None, - color_palette: Optional[str] = 'Set1', + color_palette: Union[str, List[Tuple[int, int, int]]] = 'Set1', + fill_shapes: Optional[bool] = False, thickness: Optional[int] = 2, pixel_buffer: Optional[int] = 200, circle_size: Optional[int] = 2) -> np.ndarray: @@ -888,41 +889,41 @@ def view_shapes(shapes: List[Union[LineString, Polygon, MultiPolygon, MultiLineS img = np.full((max_vertices[0], max_vertices[1], 3), bg_clr, dtype=np.uint8) else: img = bg_img + check_instance(source='view_shapes color_palette', instance=color_palette, accepted_types=(list, str)) + if isinstance(color_palette, str): + check_str(name='color_palette', value=color_palette, options=Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value) + colors = create_color_palette(pallete_name=color_palette, increments=len(shapes) + 1) + else: + check_valid_lst(data=color_palette, source='color_palette', valid_dtypes=(tuple,), exact_len=len(shapes)) + for clr in color_palette: + check_if_valid_rgb_tuple(data=clr) + colors = color_palette - if color_palette is not None: - check_str(name='color_palette', value=color_palette, - options=Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value) - - colors = create_color_palette(pallete_name=color_palette, increments=len(shapes) + 1) for shape_cnt, shape in enumerate(shapes): if isinstance(shape, Polygon): - cv2.polylines(img, [np.array(shape.exterior.coords).astype(np.int32)], True, (colors[shape_cnt][::-1]), - thickness=thickness) - interior_coords = [np.array(interior.coords, dtype=np.int32).reshape((-1, 1, 2)) for interior in - shape.interiors] + if not fill_shapes: + cv2.polylines(img, [np.array(shape.exterior.coords).astype(np.int32)], True, (colors[shape_cnt][::-1]), thickness=thickness) + else: + cv2.fillPoly(img, [np.array(shape.exterior.coords).astype(np.int32)], (colors[shape_cnt][::-1])) + interior_coords = [np.array(interior.coords, dtype=np.int32).reshape((-1, 1, 2)) for interior in shape.interiors] for interior in interior_coords: - cv2.polylines(img, [interior], isClosed=True, color=(colors[shape_cnt][::-1]), - thickness=thickness, ) + if not fill_shapes: + cv2.polylines(img, [interior], isClosed=True, color=(colors[shape_cnt][::-1]), thickness=thickness) + else: + cv2.fillPoly(img, [interior], (colors[shape_cnt][::-1]), lineType=None, shift=None, offset=None) if isinstance(shape, LineString): - if color_palette is None: - cv2.polylines(img, [np.array(shape.coords, dtype=np.int32)], False, (colors[shape_cnt][::-1]), - thickness=thickness) - else: - lines = np.array(shape.coords, dtype=np.int32) - palette = create_color_palette(pallete_name=color_palette, increments=lines.shape[0]) - for i in range(1, lines.shape[0]): - p1, p2 = lines[i - 1], lines[i] - cv2.line(img, tuple(p1), tuple(p2), palette[i], thickness) + lines = np.array(shape.coords, dtype=np.int32) + for i in range(1, lines.shape[0]): + p1, p2 = lines[i - 1], lines[i] + cv2.line(img, tuple(p1), tuple(p2), colors[shape_cnt][::-1], thickness) if isinstance(shape, MultiPolygon): - multi_polygon_clrs = create_color_palette(pallete_name=color_palette, increments=len(shape.geoms) + 1) for polygon_cnt, polygon in enumerate(shape.geoms): polygon_np = np.array((polygon.convex_hull.exterior.coords), dtype=np.int32) - cv2.polylines(img, [polygon_np], True, (multi_polygon_clrs[polygon_cnt][::-1]), thickness=thickness) + cv2.polylines(img, [polygon_np], True, (colors[shape_cnt][::-1]), thickness=thickness) if isinstance(shape, MultiLineString): for line_cnt, line in enumerate(shape.geoms): - cv2.polylines(img, [np.array(shape[line_cnt].coords, dtype=np.int32)], False, - (colors[shape_cnt][::-1]), thickness=thickness) + cv2.polylines(img, [np.array(shape[line_cnt].coords, dtype=np.int32)], False, (colors[shape_cnt][::-1]), thickness=thickness) if isinstance(shape, Point): arr = np.array((shape.coords)).astype(np.int32) x, y = arr[0][0], arr[0][1] @@ -2999,7 +3000,7 @@ def bucket_img_into_grid_square(img_size: Iterable[int], :param Iterable[int] img_size: 2-value tuple, list or array representing the width and height of the image in pixels. :param Optional[float] bucket_grid_size_mm: The width/height of each square bucket in millimeters. E.g., 50 will create 5cm by 5cm squares. If None, then buckets will by defined by ``bucket_grid_size`` argument. - :param Optional[Iterable[int]] bucket_grid_size: 2-value tuple, list or array representing the grid square in number of horizontal squares x number of vertical squares. If None, then buckets will be defined by the ``bucket_size_mm`` argument. + :param Optional[Iterable[int, int]] bucket_grid_size: 2-value tuple, list or array representing the grid square in number of horizontal squares x number of vertical squares. If None, then buckets will be defined by the ``bucket_size_mm`` argument. :param Optional[float] px_per_mm: Pixels per millimeter conversion factor. Necessery if buckets are defined by ``bucket_size_mm`` argument. :param Optional[bool] add_correction: If True, performs correction by adding extra columns or rows to cover any remaining space if using ``bucket_size_mm``. Default True. :param Optional[bool] verbose: If True, prints progress / completion information. Default False. diff --git a/simba/utils/data.py b/simba/utils/data.py index c8ef3c7cc..dc42d106f 100644 --- a/simba/utils/data.py +++ b/simba/utils/data.py @@ -1037,7 +1037,10 @@ def slp_to_df_convert( return data_df -def find_ranked_colors(data: Dict[str, float], palette: str, as_hex: Optional[bool] = False) -> Dict[str, Union[Tuple[int], str]]: +def find_ranked_colors(data: Dict[str, float], + palette: str, + as_hex: Optional[bool] = False, + reverse: Optional[bool] = True) -> Dict[str, Union[Tuple[int], str]]: """ Find ranked colors for a given data dictionary values based on a specified color palette. @@ -1057,27 +1060,15 @@ def find_ranked_colors(data: Dict[str, float], palette: str, as_hex: Optional[bo """ if palette not in Options.PALETTE_OPTIONS.value: - raise InvalidInputError( - msg=f"{palette} is not a valid palette. Options {Options.PALETTE_OPTIONS.value}", - source=find_ranked_colors.__name__, - ) - check_instance( - source=find_ranked_colors.__name__, instance=data, accepted_types=dict - ) + raise InvalidInputError(msg=f"{palette} is not a valid palette. Options {Options.PALETTE_OPTIONS.value}", source=find_ranked_colors.__name__) + check_instance(source=find_ranked_colors.__name__, instance=data, accepted_types=dict) for k, v in data.items(): - check_str(name=k, value=k) - check_float(name=v, value=v) - clrs = create_color_palette( - pallete_name=palette, increments=len(list(data.keys())) - 1, as_hex=as_hex - ) + check_str(name=k, value=k); check_float(name=v, value=v) + clrs = create_color_palette(pallete_name=palette, increments=len(list(data.keys())) - 1, as_hex=as_hex) ranks, results = deepcopy(data), {} - ranks = { - key: rank - for rank, key in enumerate(sorted(ranks, key=ranks.get, reverse=True), 1) - } + ranks = {key: rank for rank, key in enumerate(sorted(ranks, key=ranks.get, reverse=reverse), 1)} for k, v in ranks.items(): results[k] = clrs[int(v) - 1] - return results