Skip to content

Commit

Permalink
gibbs
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Nov 22, 2024
1 parent ae38f1a commit 2028686
Show file tree
Hide file tree
Showing 12 changed files with 608 additions and 130 deletions.
2 changes: 1 addition & 1 deletion docs/_static/custom.css
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@

.simba-table td {
vertical-align: middle; /* Align text vertically in the center of cells */
}
}
Binary file added docs/_static/img/get_video_slic.webm
Binary file not shown.
17 changes: 17 additions & 0 deletions docs/simba.data_processors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,20 @@ Egocentric data / video alignment
.. automodule:: simba.data_processors.egocentric_aligner
:members:
:undoc-members:


Heuristic circling detector
------------------------------------------------------------

.. automodule:: simba.data_processors.circling_detector
:members:
:undoc-members:


Heuristic freezing detector
------------------------------------------------------------

.. automodule:: simba.data_processors.freezing_detector
:members:
:undoc-members:

122 changes: 122 additions & 0 deletions simba/data_processors/circling_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import os
import numpy as np
import pandas as pd
from numba import typed
from simba.utils.read_write import find_files_of_filetypes_in_directory, read_df, get_fn_ext, read_video_info
from simba.mixins.circular_statistics import CircularStatisticsMixin
from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
from simba.mixins.timeseries_features_mixin import TimeseriesFeatureMixin
from simba.mixins.config_reader import ConfigReader
from simba.utils.enums import Formats
from typing import Union, Optional
from simba.utils.checks import check_if_dir_exists, check_str, check_valid_dataframe, check_int, check_all_file_names_are_represented_in_video_log
from simba.utils.data import detect_bouts, plug_holes_shortest_bout
from simba.utils.printing import stdout_success

CIRCLING = 'CIRCLING'

class CirclingDetector(ConfigReader):

"""
Detect circling using heuristic rules.
.. important::
Circling is detected as :underline:`present` when **the circular range of the animal is above the ``circular_range_threshold`` within the preceding ``time_threshold``** AND
**the movement of the animal (defined as the sum of the center movement) is above the ``movement_threshold`` within the preceding ``time_threshold``.**
Circling is detected as :underline:`absent` when not present.
:param Union[str, os.PathLike] data_dir: Path to directory containing pose-estimated body-part data in CSV format.
:param Union[str, os.PathLike] config_path: Path to SimBA project config file.
:param Optional[str] nose_name: The name of the pose-estimated nose body-part. Defaults to 'nose'.
:param Optional[str] left_ear_name: The name of the pose-estimated left ear body-part. Defaults to 'left_ear'.
:param Optional[str] right_ear_name: The name of the pose-estimated right ear body-part. Defaults to 'right_ear'.
:param Optional[str] tail_base_name: The name of the pose-estimated tail base body-part. Defaults to 'tail_base'.
:param Optional[str] center_name: The name of the pose-estimated center body-part. Defaults to 'center'.
:param Optional[int] time_threshold: The time window in preceding seconds in which to evaluate the animals circular range. Default: 10.
:param Optional[int] circular_range_threshold: A value in degrees, between 0-360.
:param Optional[int] movement_threshold: A movement threshold in millimeters.
:param Optional[Union[str, os.PathLike]] save_dir: Directory where to store the results. If None, then results are stored in the ``logs`` directory of the SimBA project.
References
----------
.. [1] Sabnis et al., Visual detection of seizures in mice using supervised machine learning, `biorxiv`, doi: https://doi.org/10.1101/2024.05.29.596520.
:example:
>>> CirclingDetector(data_dir=r'D:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location', config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini")
"""

def __init__(self,
data_dir: Union[str, os.PathLike],
config_path: Union[str, os.PathLike],
nose_name: Optional[str] = 'nose',
left_ear_name: Optional[str] = 'left_ear',
right_ear_name: Optional[str] = 'right_ear',
tail_base_name: Optional[str] = 'tail_base',
center_name: Optional[str] = 'center',
time_threshold: Optional[int] = 10,
circular_range_threshold: Optional[int] = 320,
movement_threshold: Optional[int] = 60,
save_dir: Optional[Union[str, os.PathLike]] = None):

check_if_dir_exists(in_dir=data_dir)
for bp_name in [nose_name, left_ear_name, right_ear_name, tail_base_name]: check_str(name='body part name', value=bp_name, allow_blank=False)
self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'])
ConfigReader.__init__(self, config_path=config_path, read_video_info=True, create_logger=False)
self.nose_heads = [f'{nose_name}_x'.lower(), f'{nose_name}_y'.lower()]
self.left_ear_heads = [f'{left_ear_name}_x'.lower(), f'{left_ear_name}_y'.lower()]
self.right_ear_heads = [f'{right_ear_name}_x'.lower(), f'{right_ear_name}_y'.lower()]
self.center_heads = [f'{center_name}_x'.lower(), f'{center_name}_y'.lower()]
self.required_field = self.nose_heads + self.left_ear_heads + self.right_ear_heads
self.save_dir = save_dir
if self.save_dir is None:
self.save_dir = os.path.join(self.logs_path, f'circling_data_{self.datetime}')
os.makedirs(self.save_dir)
else:
check_if_dir_exists(in_dir=self.save_dir)
self.time_threshold, self.circular_range_threshold, self.movement_threshold = time_threshold, circular_range_threshold, movement_threshold

def run(self):
agg_results = pd.DataFrame(columns=['VIDEO', 'CIRCLING FRAMES', 'CIRCLING TIME (S)', 'CIRCLING BOUT COUNTS', 'CIRCLING PCT OF SESSION', 'VIDEO TOTAL FRAMES', 'VIDEO TOTAL TIME (S)'])
agg_results_path = os.path.join(self.save_dir, 'aggregate_circling_results.csv')
check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
for file_cnt, file_path in enumerate(self.data_paths):
video_name = get_fn_ext(filepath=file_path)[1]
print(f'Analyzing {video_name} ({file_cnt+1}/{len(self.data_paths)})...')
save_file_path = os.path.join(self.save_dir, f'{video_name}.csv')
df = read_df(file_path=file_path, file_type='csv').reset_index(drop=True)
_, px_per_mm, fps = read_video_info(video_info_df=self.video_info_df, video_name=video_name)
df.columns = [str(x).lower() for x in df.columns]
check_valid_dataframe(df=df, valid_dtypes=Formats.NUMERIC_DTYPES.value, required_fields=self.required_field)

nose_arr = df[self.nose_heads].values.astype(np.float32)
left_ear_arr = df[self.left_ear_heads].values.astype(np.float32)
right_ear_arr = df[self.right_ear_heads].values.astype(np.float32)

center_shifted = FeatureExtractionMixin.create_shifted_df(df[self.center_heads])
center_1, center_2 = center_shifted.iloc[:, 0:2].values, center_shifted.iloc[:, 2:4].values

angle_degrees = CircularStatisticsMixin().direction_three_bps(nose_loc=nose_arr, left_ear_loc=left_ear_arr, right_ear_loc=right_ear_arr).astype(np.float32)
sliding_circular_range = CircularStatisticsMixin().sliding_circular_range(data=angle_degrees, time_windows=np.array([self.time_threshold], dtype=np.float64), fps=int(fps)).flatten()
movement = FeatureExtractionMixin.euclidean_distance(bp_1_x=center_1[:, 0].flatten(), bp_2_x=center_2[:, 0].flatten(), bp_1_y=center_1[:, 1].flatten(), bp_2_y=center_2[:, 1].flatten(), px_per_mm=2.15)
movement_sum = TimeseriesFeatureMixin.sliding_descriptive_statistics(data=movement.astype(np.float32), window_sizes=np.array([self.time_threshold], dtype=np.float64), sample_rate=fps, statistics=typed.List(["sum"])).astype(np.int32)[0].flatten()

circling_idx = np.argwhere(sliding_circular_range >= self.circular_range_threshold).astype(np.int32).flatten()
movement_idx = np.argwhere(movement_sum >= self.movement_threshold).astype(np.int32).flatten()
circling_idx = [x for x in movement_idx if x in circling_idx]
df[CIRCLING] = 0
df.loc[circling_idx, CIRCLING] = 1
bouts = detect_bouts(data_df=df, target_lst=[CIRCLING], fps=fps)
df = plug_holes_shortest_bout(data_df=df, clf_name=CIRCLING, fps=fps, shortest_bout=100)
df.to_csv(save_file_path)
agg_results.loc[len(agg_results)] = [video_name, len(circling_idx), round(len(circling_idx) / fps, 4), len(bouts), round((len(circling_idx) / len(df)) * 100, 4), len(df), round(len(df)/fps, 2) ]

agg_results.to_csv(agg_results_path)
stdout_success(msg=f'Results saved in {self.save_dir} directory.')



# detector = CirclingDetector(data_dir=r'D:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location', config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini")
# detector.run()

2 changes: 1 addition & 1 deletion simba/data_processors/cuda/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from numba import cuda, njit

from simba.utils.checks import check_float, check_int, check_valid_array
from simba.utils.checks import check_float, check_valid_array
from simba.utils.enums import Formats

try:
Expand Down
82 changes: 75 additions & 7 deletions simba/data_processors/cuda/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import cv2
import numpy as np
from numba import cuda
from cupyx.scipy.ndimage import rotate

from simba.data_processors.cuda.utils import _cuda_mse
from simba.mixins.image_mixin import ImageMixin
Expand Down Expand Up @@ -170,13 +171,13 @@ def _average_3d_stack_cuda(image_stack: np.ndarray) -> np.ndarray:


def create_average_frm_cuda(video_path: Union[str, os.PathLike],
start_frm: Optional[int] = None,
end_frm: Optional[int] = None,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
save_path: Optional[Union[str, os.PathLike]] = None,
batch_size: Optional[int] = 6000,
verbose: Optional[bool] = False) -> Union[None, np.ndarray]:
start_frm: Optional[int] = None,
end_frm: Optional[int] = None,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
save_path: Optional[Union[str, os.PathLike]] = None,
batch_size: Optional[int] = 6000,
verbose: Optional[bool] = False) -> Union[None, np.ndarray]:
"""
Computes the average frame using GPU acceleration from a specified range of frames or time interval in a video file.
This average frame typically used for background substraction.
Expand Down Expand Up @@ -908,5 +909,72 @@ def sliding_psnr(data: np.ndarray,
_sliding_psnr[bpg, THREADS_PER_BLOCK](data_dev, stride_dev, results_dev)
return results_dev.copy_to_host()

def rotate_img_stack_cupy(imgs: np.ndarray,
rotation_degrees: Optional[float] = 180,
batch_size: Optional[int] = 500) -> np.ndarray:
"""
Rotates a stack of images by a specified number of degrees using GPU acceleration with CuPy.
Accepts a 3D (single-channel images) or 4D (multichannel images) NumPy array, rotates each image in the stack by the specified degree around the center, and returns the result as a NumPy array.
:param np.ndarray imgs: The input stack of images to be rotated. Expected to be a NumPy array with 3 or 4 dimensions. 3D shape: (num_images, height, width) - 4D shape: (num_images, height, width, channels)
:param Optional[float] rotation_degrees: The angle by which the images should be rotated, in degrees. Must be between 1 and 359 degrees. Defaults to 180 degrees.
:param Optional[int] batch_size: Number of images to process on GPU in each batch. Decrease if data can't fit on GPU RAM.
:returns: A NumPy array containing the rotated images with the same shape as the input.
:rtype: np.ndarray
:example:
>>> video_path = r"/mnt/c/troubleshooting/mitra/project_folder/videos/F0_gq_Saline_0626_clipped.mp4"
>>> imgs = read_img_batch_from_video_gpu(video_path=video_path)
>>> imgs = np.stack(np.array(list(imgs.values())), axis=0)
>>> imgs = rotate_img_stack_cupy(imgs=imgs, rotation=50)
"""

check_valid_array(data=imgs, source=f'{rotate_img_stack_cupy.__name__} imgs', accepted_ndims=(3, 4))
check_int(name=f'{rotate_img_stack_cupy.__name__} rotation', value=rotation_degrees, min_value=1, max_value=359)
results = cp.full_like(imgs, fill_value=np.nan, dtype=np.uint8)
for l in range(0, imgs.shape[0], batch_size):
r = l + batch_size
batch_imgs = cp.array(imgs[l:r])
results[l:r] = rotate(input=batch_imgs, angle=rotation_degrees, axes=(2, 1), reshape=True)
return results.get()

def rotate_video_cupy(video_path: Union[str, os.PathLike],
save_path: Optional[Union[str, os.PathLike]] = None,
rotation_degrees: Optional[float] = 180,
batch_cnt: Optional[int] = 1) -> None:
"""
Rotates a video by a specified angle using GPU acceleration and CuPy for image processing.
:param Union[str, os.PathLike] video_path: Path to the input video file.
:param Optional[Union[str, os.PathLike]] save_path: Path to save the rotated video. If None, saves the video in the same directory as the input with '_rotated_<rotation_degrees>' appended to the filename.
:param nptional[float] rotation_degrees: Degrees to rotate the video. Must be between 1 and 359 degrees. Default is 180.
:param Optional[int] batch_cnt: Number of batches to split the video frames into for processing. Higher values reduce memory usage. Default is 1.
:returns: None.
:example:
>>> video_path = r"/mnt/c/troubleshooting/mitra/project_folder/videos/F0_gq_Saline_0626_clipped.mp4"
>>> rotate_video_cupy(video_path=video_path, rotation_degrees=45)
"""

timer = SimbaTimer(start=True)
check_int(name=f'{rotate_img_stack_cupy.__name__} rotation', value=rotation_degrees, min_value=1, max_value=359)
check_int(name=f'{rotate_img_stack_cupy.__name__} batch_cnt', value=batch_cnt, min_value=1)
if save_path is None:
video_dir, video_name, _ = get_fn_ext(filepath=video_path)
save_path = os.path.join(video_dir, f'{video_name}_rotated_{rotation_degrees}.mp4')
video_meta_data = get_video_meta_data(video_path=video_path)
fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
is_clr = ImageMixin.is_video_color(video=video_path)
frm_ranges = np.arange(0, video_meta_data['frame_count'])
frm_ranges = np.array_split(frm_ranges, batch_cnt)
for frm_batch, frm_range in enumerate(frm_ranges):
imgs = read_img_batch_from_video_gpu(video_path=video_path, start_frm=frm_range[0], end_frm=frm_range[-1])
imgs = np.stack(np.array(list(imgs.values())), axis=0)
imgs = rotate_img_stack_cupy(imgs=imgs, rotation_degrees=rotation_degrees)
if frm_batch == 0:
writer = cv2.VideoWriter(save_path, fourcc, video_meta_data['fps'], (imgs.shape[2], imgs.shape[1]), isColor=is_clr)
for img in imgs: writer.write(img)
writer.release()
timer.stop_timer()
stdout_success(f'Rotated video saved at {save_path}', source=rotate_video_cupy.__name__)
Loading

0 comments on commit 2028686

Please sign in to comment.