Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Dec 13, 2024
2 parents 10f3d03 + 69dac60 commit 422d5b4
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 46 deletions.
4 changes: 3 additions & 1 deletion simba/mixins/feature_extraction_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def angle3pt_serialized(data: np.ndarray) -> np.ndarray:
:align: center
.. seealso::
For GPU acceleration, use :func:`simba.data_processors.cuda.statistics.get_3pt_angle`.
For GPU acceleration, use :func:`simba.data_processors.cuda.statistics.get_3pt_angle` for single frame alternative,
see :func:`simba.mixins.feature_extraction_mixin.FeatureExtractionMixin.angle3pt`
:param ndarray data: 2D numerical array with frame number on x and [ax, ay, bx, by, cx, cy] on y.
:return: 1d float numerical array of size data.shape[0] with angles.
Expand Down Expand Up @@ -611,6 +612,7 @@ def jitted_line_crosses_to_nonstatic_targets(left_ear_array: np.ndarray,
.. image:: _static/img/directing_moving_targets.png
:width: 400
:align: center
:param np.ndarray left_ear_array: 2D array of size len(frames) x 2 with the coordinates of the observer animals left ear
:param np.ndarray right_ear_array: 2D array of size len(frames) x 2 with the coordinates of the observer animals right ear
:param np.ndarray nose_array: 2D array of size len(frames) x 2 with the coordinates of the observer animals nose
Expand Down
53 changes: 27 additions & 26 deletions simba/mixins/geometry_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,8 @@ def view_shapes(shapes: List[Union[LineString, Polygon, MultiPolygon, MultiLineS
bg_img: Optional[np.ndarray] = None,
bg_clr: Optional[Tuple[int, int, int]] = None,
size: Optional[int] = None,
color_palette: Optional[str] = 'Set1',
color_palette: Union[str, List[Tuple[int, int, int]]] = 'Set1',
fill_shapes: Optional[bool] = False,
thickness: Optional[int] = 2,
pixel_buffer: Optional[int] = 200,
circle_size: Optional[int] = 2) -> np.ndarray:
Expand Down Expand Up @@ -888,41 +889,41 @@ def view_shapes(shapes: List[Union[LineString, Polygon, MultiPolygon, MultiLineS
img = np.full((max_vertices[0], max_vertices[1], 3), bg_clr, dtype=np.uint8)
else:
img = bg_img
check_instance(source='view_shapes color_palette', instance=color_palette, accepted_types=(list, str))
if isinstance(color_palette, str):
check_str(name='color_palette', value=color_palette, options=Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value)
colors = create_color_palette(pallete_name=color_palette, increments=len(shapes) + 1)
else:
check_valid_lst(data=color_palette, source='color_palette', valid_dtypes=(tuple,), exact_len=len(shapes))
for clr in color_palette:
check_if_valid_rgb_tuple(data=clr)
colors = color_palette

if color_palette is not None:
check_str(name='color_palette', value=color_palette,
options=Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value)

colors = create_color_palette(pallete_name=color_palette, increments=len(shapes) + 1)
for shape_cnt, shape in enumerate(shapes):
if isinstance(shape, Polygon):
cv2.polylines(img, [np.array(shape.exterior.coords).astype(np.int32)], True, (colors[shape_cnt][::-1]),
thickness=thickness)
interior_coords = [np.array(interior.coords, dtype=np.int32).reshape((-1, 1, 2)) for interior in
shape.interiors]
if not fill_shapes:
cv2.polylines(img, [np.array(shape.exterior.coords).astype(np.int32)], True, (colors[shape_cnt][::-1]), thickness=thickness)
else:
cv2.fillPoly(img, [np.array(shape.exterior.coords).astype(np.int32)], (colors[shape_cnt][::-1]))
interior_coords = [np.array(interior.coords, dtype=np.int32).reshape((-1, 1, 2)) for interior in shape.interiors]
for interior in interior_coords:
cv2.polylines(img, [interior], isClosed=True, color=(colors[shape_cnt][::-1]),
thickness=thickness, )
if not fill_shapes:
cv2.polylines(img, [interior], isClosed=True, color=(colors[shape_cnt][::-1]), thickness=thickness)
else:
cv2.fillPoly(img, [interior], (colors[shape_cnt][::-1]), lineType=None, shift=None, offset=None)
if isinstance(shape, LineString):
if color_palette is None:
cv2.polylines(img, [np.array(shape.coords, dtype=np.int32)], False, (colors[shape_cnt][::-1]),
thickness=thickness)
else:
lines = np.array(shape.coords, dtype=np.int32)
palette = create_color_palette(pallete_name=color_palette, increments=lines.shape[0])
for i in range(1, lines.shape[0]):
p1, p2 = lines[i - 1], lines[i]
cv2.line(img, tuple(p1), tuple(p2), palette[i], thickness)
lines = np.array(shape.coords, dtype=np.int32)
for i in range(1, lines.shape[0]):
p1, p2 = lines[i - 1], lines[i]
cv2.line(img, tuple(p1), tuple(p2), colors[shape_cnt][::-1], thickness)

if isinstance(shape, MultiPolygon):
multi_polygon_clrs = create_color_palette(pallete_name=color_palette, increments=len(shape.geoms) + 1)
for polygon_cnt, polygon in enumerate(shape.geoms):
polygon_np = np.array((polygon.convex_hull.exterior.coords), dtype=np.int32)
cv2.polylines(img, [polygon_np], True, (multi_polygon_clrs[polygon_cnt][::-1]), thickness=thickness)
cv2.polylines(img, [polygon_np], True, (colors[shape_cnt][::-1]), thickness=thickness)
if isinstance(shape, MultiLineString):
for line_cnt, line in enumerate(shape.geoms):
cv2.polylines(img, [np.array(shape[line_cnt].coords, dtype=np.int32)], False,
(colors[shape_cnt][::-1]), thickness=thickness)
cv2.polylines(img, [np.array(shape[line_cnt].coords, dtype=np.int32)], False, (colors[shape_cnt][::-1]), thickness=thickness)
if isinstance(shape, Point):
arr = np.array((shape.coords)).astype(np.int32)
x, y = arr[0][0], arr[0][1]
Expand Down Expand Up @@ -2999,7 +3000,7 @@ def bucket_img_into_grid_square(img_size: Iterable[int],
:param Iterable[int] img_size: 2-value tuple, list or array representing the width and height of the image in pixels.
:param Optional[float] bucket_grid_size_mm: The width/height of each square bucket in millimeters. E.g., 50 will create 5cm by 5cm squares. If None, then buckets will by defined by ``bucket_grid_size`` argument.
:param Optional[Iterable[int]] bucket_grid_size: 2-value tuple, list or array representing the grid square in number of horizontal squares x number of vertical squares. If None, then buckets will be defined by the ``bucket_size_mm`` argument.
:param Optional[Iterable[int, int]] bucket_grid_size: 2-value tuple, list or array representing the grid square in number of horizontal squares x number of vertical squares. If None, then buckets will be defined by the ``bucket_size_mm`` argument.
:param Optional[float] px_per_mm: Pixels per millimeter conversion factor. Necessery if buckets are defined by ``bucket_size_mm`` argument.
:param Optional[bool] add_correction: If True, performs correction by adding extra columns or rows to cover any remaining space if using ``bucket_size_mm``. Default True.
:param Optional[bool] verbose: If True, prints progress / completion information. Default False.
Expand Down
4 changes: 3 additions & 1 deletion simba/model/regression/model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Dict, List, Optional, Tuple
from itertools import product
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import StratifiedKFold

from simba.model.regression.metrics import (mean_absolute_error,
mean_absolute_percentage_error,
mean_squared_error, r2_score,
Expand All @@ -15,6 +16,7 @@
from simba.utils.enums import Formats
from simba.utils.errors import DataHeaderError


def fit_xgb(x: pd.DataFrame,
y: np.ndarray,
xgb_reg: xgb.XGBRegressor) -> xgb.XGBRegressor:
Expand Down
27 changes: 9 additions & 18 deletions simba/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,10 @@ def slp_to_df_convert(
return data_df


def find_ranked_colors(data: Dict[str, float], palette: str, as_hex: Optional[bool] = False) -> Dict[str, Union[Tuple[int], str]]:
def find_ranked_colors(data: Dict[str, float],
palette: str,
as_hex: Optional[bool] = False,
reverse: Optional[bool] = True) -> Dict[str, Union[Tuple[int], str]]:
"""
Find ranked colors for a given data dictionary values based on a specified color palette.
Expand All @@ -1057,27 +1060,15 @@ def find_ranked_colors(data: Dict[str, float], palette: str, as_hex: Optional[bo
"""

if palette not in Options.PALETTE_OPTIONS.value:
raise InvalidInputError(
msg=f"{palette} is not a valid palette. Options {Options.PALETTE_OPTIONS.value}",
source=find_ranked_colors.__name__,
)
check_instance(
source=find_ranked_colors.__name__, instance=data, accepted_types=dict
)
raise InvalidInputError(msg=f"{palette} is not a valid palette. Options {Options.PALETTE_OPTIONS.value}", source=find_ranked_colors.__name__)
check_instance(source=find_ranked_colors.__name__, instance=data, accepted_types=dict)
for k, v in data.items():
check_str(name=k, value=k)
check_float(name=v, value=v)
clrs = create_color_palette(
pallete_name=palette, increments=len(list(data.keys())) - 1, as_hex=as_hex
)
check_str(name=k, value=k); check_float(name=v, value=v)
clrs = create_color_palette(pallete_name=palette, increments=len(list(data.keys())) - 1, as_hex=as_hex)
ranks, results = deepcopy(data), {}
ranks = {
key: rank
for rank, key in enumerate(sorted(ranks, key=ranks.get, reverse=True), 1)
}
ranks = {key: rank for rank, key in enumerate(sorted(ranks, key=ranks.get, reverse=reverse), 1)}
for k, v in ranks.items():
results[k] = clrs[int(v) - 1]

return results


Expand Down

0 comments on commit 422d5b4

Please sign in to comment.