From 6091a11ecd82a1d5a7f8f46764ac75b85c638718 Mon Sep 17 00:00:00 2001 From: sronilsson Date: Mon, 25 Nov 2024 16:02:18 -0500 Subject: [PATCH] kmeans --- simba/data_processors/cuda/statistics.py | 27 ++++++++++++++++++++++ simba/data_processors/freezing_detector.py | 4 ++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/simba/data_processors/cuda/statistics.py b/simba/data_processors/cuda/statistics.py index 54a838d65..7622addeb 100644 --- a/simba/data_processors/cuda/statistics.py +++ b/simba/data_processors/cuda/statistics.py @@ -17,6 +17,10 @@ from cupyx.scipy.spatial.distance import cdist except: import numpy as cp +try: + from cuml.cluster import KMeans +except: + from sklearn.cluster import KMeans from simba.utils.checks import check_int, check_valid_array, check_valid_tuple from simba.utils.enums import Formats @@ -627,3 +631,26 @@ def davis_bouldin(x: np.ndarray, max_ratio = max(max_ratio, ratio) db_index += max_ratio return db_index / n_labels + + +def kmeans_cuml(data: np.ndarray, + k: int = 2, + max_iter: int = 300, + output_type: Optional[str] = None, + sample_n: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]: + """CRAP, SLOWER THAN SCIKIT""" + + check_valid_array(data=data, source=f'{kmeans_cuml.__name__} data', accepted_dtypes=Formats.NUMERIC_DTYPES.value) + check_int(name=f'{kmeans_cuml.__name__} k', value=k, min_value=1) + check_int(name=f'{kmeans_cuml.__name__} max_iter', value=max_iter, min_value=1) + kmeans = KMeans(n_clusters=k, max_iter=max_iter) + if sample_n is not None: + check_int(name=f'{kmeans_cuml.__name__} sample', value=sample_n, min_value=1) + sample = min(sample_n, data.shape[0]) + data_idx = np.random.choice(np.arange(data.shape[0]), sample) + mdl = kmeans.fit(data[data_idx]) + else: + mdl = kmeans.fit(data) + + return (mdl.cluster_centers_, mdl.predict(data)) + diff --git a/simba/data_processors/freezing_detector.py b/simba/data_processors/freezing_detector.py index 0d43f59b9..daedd5422 100644 --- a/simba/data_processors/freezing_detector.py +++ b/simba/data_processors/freezing_detector.py @@ -130,5 +130,5 @@ def run(self): stdout_success(msg=f'Results saved in {self.save_dir} directory.') # -# FreezingDetector(data_dir=r'D:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location', -# config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini") \ No newline at end of file +# FreezingDetector(data_dir=r'C:\troubleshooting\mitra\project_folder\csv\outlier_corrected_movement_location', +# config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini") \ No newline at end of file