shap
sronilsson committed Oct 23, 2024
1 parent ae1e470 commit 95c656d
Showing 21 changed files with 2,199 additions and 55 deletions.
Binary file added docs/_static/img/poly_area_cuda.webp
Binary file not shown.
Binary file added docs/_static/img/scale_pose_img_sizes.webp
Binary file not shown.
8 changes: 8 additions & 0 deletions docs/simba.third_party_label_appenders.rst
@@ -44,5 +44,13 @@ Generic third-party appender tool
------------------------------------------------------------------

.. automodule:: simba.third_party_label_appenders.third_party_appender
:members:
:show-inheritance:


Annotation format converters
---------------------------------------

.. automodule:: simba.third_party_label_appenders.converters
:members:
:show-inheritance:
1 change: 1 addition & 0 deletions lightning_logs/version_65/hparams.yaml
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_66/hparams.yaml
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions lightning_logs/version_67/hparams.yaml
@@ -0,0 +1 @@
{}
1,871 changes: 1,871 additions & 0 deletions lightning_logs/version_67/metrics.csv

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions misc/ex_yolo_model.yaml
@@ -0,0 +1,7 @@
path: C:\troubleshooting\coco_data # dataset root dir
train: ../images/501_MA142_Gi_CNO_0514 # train images (relative to 'path') 128 images
val: ../images/F0_gq_CNO_0621 # val images (relative to 'path') 128 images
test: ../images/FL_gq_CNO_0625_78

names:
0: animal_1
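For context, a dataset YAML like this is what the new fit_yolo() helper in simba/bounding_box_tools/yolo/model.py (further down in this commit) hands to Ultralytics for training. A minimal sketch of using it directly (the weights filename, epoch count, and output directory are illustrative, not taken from this commit):

from ultralytics import YOLO

model = YOLO("yolov8n-obb.pt")                 # pre-trained weights to fine-tune
model.train(data="misc/ex_yolo_model.yaml",    # the dataset YAML shown above
            epochs=5, batch=16, project=r"C:\troubleshooting\coco_data\mdl")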
13 changes: 13 additions & 0 deletions requirements-gpu.txt
@@ -0,0 +1,13 @@
--extra-index-url https://pypi.nvidia.com

cupy-cuda12x==13.3.0
shap==0.46.1.dev78
cuml-cu12==24.10.0
torch==2.5.0
ultralytics==8.3.19






2 changes: 1 addition & 1 deletion simba/SimBA.py
@@ -260,7 +260,7 @@ def __init__(self, config_path: str):
smooth_btn = SimbaButton(parent=further_methods_frm, txt="SMOOTH POSE IN SIMBA PROJECT", txt_clr='blue', compound='right', img='wand_blue', font=Formats.FONT_REGULAR.value, cmd=SmoothingPopUp, cmd_kwargs={'config_path': lambda:self.config_path})

label_setscale = CreateLabelFrameWithIcon(parent=tab3, header="VIDEO PARAMETERS (FPS, RESOLUTION, PPX/MM ....)", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.VIDEO_PARAMETERS.value)
self.distance_in_mm_eb = Entry_Box(label_setscale, "KNOWN DISTANCE (MILLIMETERS)", "25", validation="numeric")
self.distance_in_mm_eb = Entry_Box(label_setscale, "KNOWN DISTANCE (MILLIMETERS)", "35", validation="numeric")
button_setdistanceinmm = SimbaButton(parent=label_setscale, txt="AUTO-POPULATE", txt_clr='blue', font=Formats.FONT_REGULAR.value, cmd=self.set_distance_mm)

button_setscale = SimbaButton(parent=label_setscale, txt="CONFIGURE VIDEO PARAMETERS", txt_clr='blue', font=Formats.FONT_REGULAR.value, cmd=self.create_video_info_table, img='calipher')
98 changes: 76 additions & 22 deletions simba/bounding_box_tools/yolo/model.py
@@ -1,56 +1,84 @@
import functools
import multiprocessing
import os
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, Dict

import numpy as np
import pandas as pd
import torch
from ultralytics import YOLO

from simba.utils.enums import Defaults
from simba.utils.read_write import (find_core_cnt, get_video_meta_data,
read_img_batch_from_video_gpu)
from simba.utils.printing import stdout_success, SimbaTimer
from simba.utils.checks import check_file_exist_and_readable, check_if_dir_exists, check_int, get_fn_ext
from simba.utils.read_write import (get_video_meta_data)


def fit_yolo(initial_weights: Union[str, os.PathLike],
data: Union[str, os.PathLike],
model_yaml: Union[str, os.PathLike],
project_path: Union[str, os.PathLike],
epochs: Optional[int] = 5,
batch: Optional[Union[int, float]] = 16):
"""
Trains a YOLO model using specified initial weights and a configuration YAML file.
:param initial_weights:
:param data:
:param project_path:
:param epochs:
:param batch:
:return:
.. note::
`Download initial weights <https://docs.ultralytics.com/tasks/obb/#export>`__.
`Example model_yaml <https://github.com/sgoldenlab/simba/blob/master/misc/ex_yolo_model.yaml>`__.
:param initial_weights: Path to the pre-trained YOLO model weights (usually a `.pt` file). Example weights can be found [here](https://docs.ultralytics.com/tasks/obb/#export).
:param model_yaml: YAML file containing paths to the training, validation, and testing datasets and the object class mappings. Example YAML file can be found [here](https://github.com/sgoldenlab/simba/blob/master/misc/ex_yolo_model.yaml).
:param project_path: Directory path where the trained model, logs, and results will be saved.
:param epochs: Number of epochs to train the model. Default is 5.
:param batch: Batch size for training. Default is 16.
:return: None. The trained model and associated training logs are saved in the specified `project_path`.
:example:
>>> fit_yolo(initial_weights=r"C:\troubleshooting\coco_data\weights\yolov8n-obb.pt", model_yaml=r"C:\troubleshooting\coco_data\model.yaml", project_path=r"C:\troubleshooting\coco_data\mdl", batch=16)
"""

if not torch.cuda.is_available():
raise ModuleNotFoundError('No GPU detected.')
check_file_exist_and_readable(file_path=initial_weights)
check_file_exist_and_readable(file_path=model_yaml)
check_if_dir_exists(in_dir=project_path)
check_int(name='epochs', value=epochs, min_value=1)
model = YOLO(initial_weights)
model.train(data=data, epochs=epochs, project=project_path, batch=batch)
model.train(data=model_yaml, epochs=epochs, project=project_path, batch=batch)

def inference_yolo(weights: Union[str, os.PathLike],
video_path: Union[str, os.PathLike],
batch: Optional[Union[int, float]] = 100,
verbose: Optional[bool] = False,
save_dir: Optional[Union[str, os.PathLike]] = None):
save_dir: Optional[Union[str, os.PathLike]] = None,
gpu: Optional[bool] = False) -> Union[None, Dict[str, pd.DataFrame]]:
"""
Performs object detection inference on a video using a YOLO model.
This function runs YOLO-based object detection on a given video file, optionally utilizing GPU acceleration for
inference, and either returns the results or saves them to a specified directory. The function outputs bounding box
coordinates and class confidence scores for detected objects in each frame of the video.
:param Union[str, os.PathLike] weights: Path to the YOLO model weights file.
:param Union[str, os.PathLike] video_path: Path to the input video file for performing inference.
:param Optional[bool] verbose: If True, outputs progress information and timing. Defaults to False.
:param Optional[Union[str, os.PathLike]] save_dir: Directory to save the inference results as CSV files. If not provided, results are returned as a dictionary. Defaults to None.
:param Optional[bool] gpu: If True, performs inference on the GPU. Defaults to False.
:example:
>>> inference_yolo(weights=r"/mnt/c/troubleshooting/coco_data/mdl/train8/weights/best.pt", video_path=r"/mnt/c/troubleshooting/mitra/project_folder/videos/FRR_gq_Saline_0624.mp4", save_dir=r"/mnt/c/troubleshooting/coco_data/mdl/results", verbose=True, gpu=True)
"""

timer = SimbaTimer(start=True)
torch.set_num_threads(8)
model = YOLO(weights, verbose=verbose)
# model.export(format='engine')
# model.to('cuda')
results = []
if gpu:
model.export(format='engine')
model.to('cuda')
results = {}
out_cols = ['FRAME', 'CLASS', 'CONFIDENCE', 'X1', 'Y1', 'X2', 'Y2', 'X3', 'Y3', 'X4', 'Y4']
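# each result row follows out_cols: frame index, class id, confidence, then the four (x, y) corners of the detected box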
if save_dir is not None:
check_if_dir_exists(in_dir=save_dir, source=inference_yolo.__name__)
if os.path.isfile(video_path):
_ = get_video_meta_data(video_path=video_path)
_, video_name, _ = get_fn_ext(filepath=video_path)
video_out = []
video_results = model(video_path)
for frm_cnt, frm in enumerate(video_results):
if frm.obb is not None:
@@ -62,11 +90,37 @@ def inference_yolo(weights: Union[str, os.PathLike],
cls_data = data[np.argwhere(data[:, -1] == c)].reshape(-1, data.shape[1])
cls_data = cls_data[np.argmax(data[:, -2].flatten())]
cord_data = np.array([cls_data[0], cls_data[1], cls_data[0], cls_data[3], cls_data[1], cls_data[3], cls_data[2], cls_data[1]]).astype(np.int32)
results.append([frm_cnt, cls_data[-1], cls_data[-2]] + list(cord_data))
results = pd.DataFrame(results, columns=out_cols)
video_out.append([frm_cnt, cls_data[-1], cls_data[-2]] + list(cord_data))
results[video_name] = pd.DataFrame(video_out, columns=out_cols)

if not save_dir:
return results
else:
for k, v in results.items():
save_path = os.path.join(save_dir, f'{k}.csv')
v.to_csv(save_path)
if verbose:
timer.stop_timer()
stdout_success(f'Results saved in {save_dir} directory', elapsed_time=timer.elapsed_time_str)
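# A minimal sketch of consuming the in-memory results when save_dir is None (paths reuse the docstring example; column names follow out_cols above):
# >>> dfs = inference_yolo(weights=r"/mnt/c/troubleshooting/coco_data/mdl/train8/weights/best.pt",
# ...                      video_path=r"/mnt/c/troubleshooting/mitra/project_folder/videos/FRR_gq_Saline_0624.mp4")
# >>> for video_name, df in dfs.items():
# ...     print(video_name, len(df), df['CONFIDENCE'].mean())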






# fit_yolo(initial_weights=r"/mnt/c/troubleshooting/coco_data/weights/yolov8n-obb.pt",
# model_yaml=r"/mnt/c/troubleshooting/coco_data/model.yaml",
# project_path=r"/mnt/c/troubleshooting/coco_data/mdl",
# batch=16, epochs=100)


# inference_yolo(weights=r"/mnt/c/troubleshooting/coco_data/mdl/train8/weights/best.pt",
# video_path=r"/mnt/c/troubleshooting/mitra/project_folder/videos/FRR_gq_Saline_0624.mp4",
# save_dir=r"/mnt/c/troubleshooting/coco_data/mdl/results",
# verbose=True,
# gpu=True)
#
#


# r = inference_yolo(weights=r"C:\troubleshooting\coco_data\mdl\train\weights\best.pt",
4 changes: 2 additions & 2 deletions simba/data_processors/cuda/circular_statistics.py
Expand Up @@ -248,7 +248,7 @@ def sliding_circular_std(x: np.ndarray,
sample_rate: float,
batch_size: Optional[int] = int(5e+7)) -> np.ndarray:

"""
r"""
Calculate the sliding circular standard deviation of a time series on the GPU.
This function computes the circular standard deviation over a sliding window for a given time series array.
@@ -307,7 +307,7 @@ def sliding_rayleigh_z(x: np.ndarray,
sample_rate: float,
batch_size: Optional[int] = int(5e+7)) -> Tuple[np.ndarray, np.ndarray]:

"""
r"""
Computes the Rayleigh Z-statistic over a sliding window for a given time series of angles
This function calculates the Rayleigh Z-statistic, which tests the null hypothesis that the population of angles
6 changes: 6 additions & 0 deletions simba/data_processors/cuda/geometry.py
@@ -308,6 +308,12 @@ def poly_area(data: np.ndarray,
.. seealso::
:func:`~simba.feature_extractors.perimeter_jit.jitted_hull`.
.. image:: _static/img/simba.data_processors.cuda.geometry.poly_area_cuda.webp
:width: 450
:align: center
:param data: A 3D numpy array of shape (N, M, 2), where N is the number of polygons, M is the number of points per polygon, and 2 represents the x and y coordinates.
:param pixels_per_mm: Optional scaling factor to convert the area from pixels squared to square millimeters. Default is 1.0.
:param batch_size: Optional batch size for processing the data in chunks to fit in memory. Default is 0.5e+7.
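A minimal usage sketch for poly_area (not from the commit; random coordinates for illustration, and a CUDA-capable GPU plus float32 input are assumed):
>>> import numpy as np
>>> from simba.data_processors.cuda.geometry import poly_area
>>> data = np.random.randint(0, 500, (100, 6, 2)).astype(np.float32)  # 100 polygons with 6 vertices each
>>> areas = poly_area(data=data, pixels_per_mm=4.5)  # expected: one area per polygon, converted to square millimeters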
42 changes: 42 additions & 0 deletions simba/mixins/image_mixin.py
@@ -1807,6 +1807,48 @@ def is_video_color(video: Union[str, os.PathLike, cv2.VideoCapture]):
else:
return False

@staticmethod
def resize_img_dict(imgs: Dict[str, np.ndarray],
size: Union[Literal['min', 'max'], Tuple[int, int]],
interpolation: Optional[int] = cv2.INTER_LINEAR) -> Dict[str, np.ndarray]:

"""
Resize a dictionary of images to a specified size.
:param Dict[str, np.ndarray] imgs: A dictionary where keys are image names (strings) and values are NumPy arrays representing the images.
:param Union[Literal['min', 'max'], Tuple[int, int]] size: The target size for the resizing operation: `'min'` resizes all images to the smallest height and width found among the input images; `'max'` resizes all images to the largest height and width found among the input images; a tuple of two integers `(height, width)` explicitly specifies the target size for all images.
:param interpolation: Interpolation method to use for resizing. This can be one of OpenCV's interpolation methods.
:return: A dictionary of resized images, where the keys match the original dictionary, and the values are the resized images as NumPy arrays.
:rtype: Dict[str, np.ndarray]
"""

check_instance(source=ImageMixin.resize_img_dict.__name__, instance=imgs, accepted_types=(dict,))
check_instance(source=ImageMixin.resize_img_dict.__name__, instance=size, accepted_types=(tuple, str,))
results = {}
if size == 'min':
target_h, target_w = np.inf, np.inf
for k, v in imgs.items():
target_h, target_w = min(v.shape[0], target_h), min(v.shape[1], target_w)
elif size == 'max':
target_h, target_w = -np.inf, -np.inf
for k, v in imgs.items():
target_h, target_w = max(v.shape[0], target_h), max(v.shape[1], target_w)
elif isinstance(size, tuple):
check_valid_tuple(x=size, accepted_lengths=(2,), valid_dtypes=(int,))
check_int(name=ImageMixin.resize_img_dict.__name__, value=size[0], min_value=1)
check_int(name=ImageMixin.resize_img_dict.__name__, value=size[1], min_value=1)
target_h, target_w = size[0], size[1]
else:
raise InvalidInputError(msg=f'{size} is not a valid size argument.', source=ImageMixin.resize_img_dict.__name__)

for k, v in imgs.items():
check_if_valid_img(data=v, source=ImageMixin.resize_img_dict.__name__)
results[k] = cv2.resize(v, dsize=(target_w, target_h), fx=0, fy=0, interpolation=interpolation)

return results
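# A minimal usage sketch (hypothetical image names and shapes, not part of this commit):
# >>> imgs = {'frm_0': np.zeros((480, 640, 3), dtype=np.uint8), 'frm_1': np.zeros((360, 480, 3), dtype=np.uint8)}
# >>> ImageMixin.resize_img_dict(imgs=imgs, size='min')       # both images resized to 360x480
# >>> ImageMixin.resize_img_dict(imgs=imgs, size=(256, 256))  # both images resized to 256x256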
def blah(self):
print(1)


#x = ImageMixin.get_blob_locations(video_path=r"C:\troubleshooting\RAT_NOR\project_folder\videos\2022-06-20_NOB_DOT_4_downsampled_bg_subtracted.mp4", gpu=True)
# imgs = ImageMixin().read_all_img_in_dir(dir='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/videos/examples')
10 changes: 7 additions & 3 deletions simba/model/train_rf.py
@@ -376,9 +376,13 @@ def save(self) -> None:



test = TrainRandomForestClassifier(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
test.run()
test.save()




# test = TrainRandomForestClassifier(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
# test.run()
# test.save()



21 changes: 11 additions & 10 deletions simba/plotting/pose_plotter_mp.py
@@ -10,16 +10,12 @@

from simba.mixins.config_reader import ConfigReader
from simba.mixins.plotting_mixin import PlottingMixin
from simba.utils.checks import (check_instance, check_int, check_str,
check_that_column_exist)
from simba.utils.checks import (check_instance, check_int, check_str, check_that_column_exist, check_valid_boolean, check_nvidea_gpu_available)
from simba.utils.data import create_color_palette
from simba.utils.enums import OS, Formats, Options
from simba.utils.errors import CountError, InvalidFilepathError
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import (concatenate_videos_in_folder,
find_core_cnt,
find_files_of_filetypes_in_directory,
get_fn_ext, get_video_meta_data, read_df)
from simba.utils.read_write import (concatenate_videos_in_folder, find_core_cnt, find_files_of_filetypes_in_directory, get_fn_ext, get_video_meta_data, read_df)
from simba.utils.warnings import FrameRangeWarning


@@ -65,8 +61,7 @@ class PosePlotterMultiProcess():
:param str in_directory: Path to SimBA project directory containing pose-estimation data in parquet or CSV format.
:param str out_directory: Directory to where to save the pose-estimation videos.
:param Optional[int] circle_size: Size of the circles denoting the location of the pose-estimated body-parts.
:param Optional[dict] clr_attr: Python dict where animals are keys and color attributes values. E.g., {'Animal_1': (255, 107, 198)}. If None,
random palettes will be used.
:param Optional[dict] clr_attr: Python dict where animals are keys and color attributes values. E.g., {'Animal_1': (255, 107, 198)}. If None, random palettes will be used.
.. image:: _static/img/pose_plotter.png
:width: 600
@@ -83,6 +78,7 @@ def __init__(self,
palettes: Optional[Dict[str, str]] = None,
circle_size: Optional[int] = None,
core_cnt: Optional[int] = -1,
gpu: Optional[bool] = False,
sample_time: Optional[int] = None) -> None:

if os.path.isdir(data_path):
@@ -122,6 +118,11 @@ def __init__(self,
if not os.path.exists(self.out_dir):
os.makedirs(self.out_dir)
self.data = {}
check_valid_boolean(value=[gpu])
if gpu and check_nvidea_gpu_available():
self.gpu = True
else:
self.gpu = False
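# self.gpu drives the GPU-accelerated video concatenation below: enabled only when requested and an NVIDIA GPU is actually available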
for file in files_found:
self.data[file] = self.config.find_video_of_file(video_dir=self.config.video_dir, filename=get_fn_ext(file)[1])
if platform.system() == OS.MAC.value:
@@ -162,11 +163,11 @@ def run(self):
circle_size=video_circle_size,
video_save_dir=self.temp_folder)
for cnt, result in enumerate(pool.imap(constants, pose_lst, chunksize=self.config.multiprocess_chunksize)):
print(f"Image {obs_per_split*(cnt+1)}/{len(pose_df)}, Video {file_cnt+1}/{len(list(self.data.keys()))}...")
print(f"Image {max(len(pose_df), obs_per_split*(cnt+1))}/{len(pose_df)}, Video {file_cnt+1}/{len(list(self.data.keys()))}...")
pool.terminate()
pool.join()
print(f"Joining {video_name} multi-processed video...")
concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=save_video_path, remove_splits=True)
concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=save_video_path, remove_splits=True, gpu=self.gpu)
video_timer.stop_timer()
stdout_success(msg=f"Pose video {video_name} complete and saved at {save_video_path}", elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__)
self.config.timer.stop_timer()
3 changes: 2 additions & 1 deletion simba/pose_importers/sleap_csv_importer.py
@@ -99,7 +99,8 @@ def run(self):
idx = data_df.iloc[:, :2]
check_that_column_exist(df=idx, column_name=TRACK, file_name=video_name)
idx[TRACK] = idx[TRACK].fillna("track_1")
idx[TRACK] = idx[TRACK].str.replace(r"[^\d.]+", "").astype(int)
#idx[TRACK] = idx[TRACK].str.replace(r"[^\d.]+", "").astype(int)
idx[TRACK] = idx[TRACK].str.replace(r"[^\d.]+", "", regex=True).astype(int)
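# e.g. "track_0" -> 0, "track_1" -> 1: strip everything but digits (and '.') from the track label, then cast to int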
data_df = data_df.iloc[:, 2:].fillna(0)
if self.animal_cnt > 1:
self.data_df = pd.DataFrame(self.transpose_multi_animal_table(data=data_df.values, idx=idx.values, animal_cnt=self.animal_cnt))
