Skip to content

Commit

Permalink
morans prep
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Dec 12, 2024
1 parent 11d64a8 commit 69dac60
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 45 deletions.
4 changes: 3 additions & 1 deletion simba/mixins/feature_extraction_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def angle3pt_serialized(data: np.ndarray) -> np.ndarray:
:align: center
.. seealso::
For GPU acceleration, use :func:`simba.data_processors.cuda.statistics.get_3pt_angle`.
For GPU acceleration, use :func:`simba.data_processors.cuda.statistics.get_3pt_angle` for single frame alternative,
see :func:`simba.mixins.feature_extraction_mixin.FeatureExtractionMixin.angle3pt`
:param ndarray data: 2D numerical array with frame number on x and [ax, ay, bx, by, cx, cy] on y.
:return: 1d float numerical array of size data.shape[0] with angles.
Expand Down Expand Up @@ -611,6 +612,7 @@ def jitted_line_crosses_to_nonstatic_targets(left_ear_array: np.ndarray,
.. image:: _static/img/directing_moving_targets.png
:width: 400
:align: center
:param np.ndarray left_ear_array: 2D array of size len(frames) x 2 with the coordinates of the observer animals left ear
:param np.ndarray right_ear_array: 2D array of size len(frames) x 2 with the coordinates of the observer animals right ear
:param np.ndarray nose_array: 2D array of size len(frames) x 2 with the coordinates of the observer animals nose
Expand Down
53 changes: 27 additions & 26 deletions simba/mixins/geometry_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,8 @@ def view_shapes(shapes: List[Union[LineString, Polygon, MultiPolygon, MultiLineS
bg_img: Optional[np.ndarray] = None,
bg_clr: Optional[Tuple[int, int, int]] = None,
size: Optional[int] = None,
color_palette: Optional[str] = 'Set1',
color_palette: Union[str, List[Tuple[int, int, int]]] = 'Set1',
fill_shapes: Optional[bool] = False,
thickness: Optional[int] = 2,
pixel_buffer: Optional[int] = 200,
circle_size: Optional[int] = 2) -> np.ndarray:
Expand Down Expand Up @@ -888,41 +889,41 @@ def view_shapes(shapes: List[Union[LineString, Polygon, MultiPolygon, MultiLineS
img = np.full((max_vertices[0], max_vertices[1], 3), bg_clr, dtype=np.uint8)
else:
img = bg_img
check_instance(source='view_shapes color_palette', instance=color_palette, accepted_types=(list, str))
if isinstance(color_palette, str):
check_str(name='color_palette', value=color_palette, options=Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value)
colors = create_color_palette(pallete_name=color_palette, increments=len(shapes) + 1)
else:
check_valid_lst(data=color_palette, source='color_palette', valid_dtypes=(tuple,), exact_len=len(shapes))
for clr in color_palette:
check_if_valid_rgb_tuple(data=clr)
colors = color_palette

if color_palette is not None:
check_str(name='color_palette', value=color_palette,
options=Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value)

colors = create_color_palette(pallete_name=color_palette, increments=len(shapes) + 1)
for shape_cnt, shape in enumerate(shapes):
if isinstance(shape, Polygon):
cv2.polylines(img, [np.array(shape.exterior.coords).astype(np.int32)], True, (colors[shape_cnt][::-1]),
thickness=thickness)
interior_coords = [np.array(interior.coords, dtype=np.int32).reshape((-1, 1, 2)) for interior in
shape.interiors]
if not fill_shapes:
cv2.polylines(img, [np.array(shape.exterior.coords).astype(np.int32)], True, (colors[shape_cnt][::-1]), thickness=thickness)
else:
cv2.fillPoly(img, [np.array(shape.exterior.coords).astype(np.int32)], (colors[shape_cnt][::-1]))
interior_coords = [np.array(interior.coords, dtype=np.int32).reshape((-1, 1, 2)) for interior in shape.interiors]
for interior in interior_coords:
cv2.polylines(img, [interior], isClosed=True, color=(colors[shape_cnt][::-1]),
thickness=thickness, )
if not fill_shapes:
cv2.polylines(img, [interior], isClosed=True, color=(colors[shape_cnt][::-1]), thickness=thickness)
else:
cv2.fillPoly(img, [interior], (colors[shape_cnt][::-1]), lineType=None, shift=None, offset=None)
if isinstance(shape, LineString):
if color_palette is None:
cv2.polylines(img, [np.array(shape.coords, dtype=np.int32)], False, (colors[shape_cnt][::-1]),
thickness=thickness)
else:
lines = np.array(shape.coords, dtype=np.int32)
palette = create_color_palette(pallete_name=color_palette, increments=lines.shape[0])
for i in range(1, lines.shape[0]):
p1, p2 = lines[i - 1], lines[i]
cv2.line(img, tuple(p1), tuple(p2), palette[i], thickness)
lines = np.array(shape.coords, dtype=np.int32)
for i in range(1, lines.shape[0]):
p1, p2 = lines[i - 1], lines[i]
cv2.line(img, tuple(p1), tuple(p2), colors[shape_cnt][::-1], thickness)

if isinstance(shape, MultiPolygon):
multi_polygon_clrs = create_color_palette(pallete_name=color_palette, increments=len(shape.geoms) + 1)
for polygon_cnt, polygon in enumerate(shape.geoms):
polygon_np = np.array((polygon.convex_hull.exterior.coords), dtype=np.int32)
cv2.polylines(img, [polygon_np], True, (multi_polygon_clrs[polygon_cnt][::-1]), thickness=thickness)
cv2.polylines(img, [polygon_np], True, (colors[shape_cnt][::-1]), thickness=thickness)
if isinstance(shape, MultiLineString):
for line_cnt, line in enumerate(shape.geoms):
cv2.polylines(img, [np.array(shape[line_cnt].coords, dtype=np.int32)], False,
(colors[shape_cnt][::-1]), thickness=thickness)
cv2.polylines(img, [np.array(shape[line_cnt].coords, dtype=np.int32)], False, (colors[shape_cnt][::-1]), thickness=thickness)
if isinstance(shape, Point):
arr = np.array((shape.coords)).astype(np.int32)
x, y = arr[0][0], arr[0][1]
Expand Down Expand Up @@ -2999,7 +3000,7 @@ def bucket_img_into_grid_square(img_size: Iterable[int],
:param Iterable[int] img_size: 2-value tuple, list or array representing the width and height of the image in pixels.
:param Optional[float] bucket_grid_size_mm: The width/height of each square bucket in millimeters. E.g., 50 will create 5cm by 5cm squares. If None, then buckets will by defined by ``bucket_grid_size`` argument.
:param Optional[Iterable[int]] bucket_grid_size: 2-value tuple, list or array representing the grid square in number of horizontal squares x number of vertical squares. If None, then buckets will be defined by the ``bucket_size_mm`` argument.
:param Optional[Iterable[int, int]] bucket_grid_size: 2-value tuple, list or array representing the grid square in number of horizontal squares x number of vertical squares. If None, then buckets will be defined by the ``bucket_size_mm`` argument.
:param Optional[float] px_per_mm: Pixels per millimeter conversion factor. Necessery if buckets are defined by ``bucket_size_mm`` argument.
:param Optional[bool] add_correction: If True, performs correction by adding extra columns or rows to cover any remaining space if using ``bucket_size_mm``. Default True.
:param Optional[bool] verbose: If True, prints progress / completion information. Default False.
Expand Down
27 changes: 9 additions & 18 deletions simba/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,10 @@ def slp_to_df_convert(
return data_df


def find_ranked_colors(data: Dict[str, float], palette: str, as_hex: Optional[bool] = False) -> Dict[str, Union[Tuple[int], str]]:
def find_ranked_colors(data: Dict[str, float],
palette: str,
as_hex: Optional[bool] = False,
reverse: Optional[bool] = True) -> Dict[str, Union[Tuple[int], str]]:
"""
Find ranked colors for a given data dictionary values based on a specified color palette.
Expand All @@ -1057,27 +1060,15 @@ def find_ranked_colors(data: Dict[str, float], palette: str, as_hex: Optional[bo
"""

if palette not in Options.PALETTE_OPTIONS.value:
raise InvalidInputError(
msg=f"{palette} is not a valid palette. Options {Options.PALETTE_OPTIONS.value}",
source=find_ranked_colors.__name__,
)
check_instance(
source=find_ranked_colors.__name__, instance=data, accepted_types=dict
)
raise InvalidInputError(msg=f"{palette} is not a valid palette. Options {Options.PALETTE_OPTIONS.value}", source=find_ranked_colors.__name__)
check_instance(source=find_ranked_colors.__name__, instance=data, accepted_types=dict)
for k, v in data.items():
check_str(name=k, value=k)
check_float(name=v, value=v)
clrs = create_color_palette(
pallete_name=palette, increments=len(list(data.keys())) - 1, as_hex=as_hex
)
check_str(name=k, value=k); check_float(name=v, value=v)
clrs = create_color_palette(pallete_name=palette, increments=len(list(data.keys())) - 1, as_hex=as_hex)
ranks, results = deepcopy(data), {}
ranks = {
key: rank
for rank, key in enumerate(sorted(ranks, key=ranks.get, reverse=True), 1)
}
ranks = {key: rank for rank, key in enumerate(sorted(ranks, key=ranks.get, reverse=reverse), 1)}
for k, v in ranks.items():
results[k] = clrs[int(v) - 1]

return results


Expand Down

0 comments on commit 69dac60

Please sign in to comment.