diff --git a/docs/simba.model_mixin.rst b/docs/simba.model_mixin.rst index d053c0748..142f5eaa9 100644 --- a/docs/simba.model_mixin.rst +++ b/docs/simba.model_mixin.rst @@ -1,6 +1,69 @@ Model mixin ---------------------------------------------- +Utilities for fit, inference, and evaluation of classifiers. + .. autoclass:: simba.mixins.train_model_mixin.TrainModelMixin :members: - :undoc-members: \ No newline at end of file + :undoc-members: + +Batch random forest inference +---------------------------------------------- + +.. autoclass:: simba.model.inference_batch.InferenceBatch + :members: + :undoc-members: + +Batch multi-class random forest inference +---------------------------------------------- + +.. autoclass:: simba.model.inference_multiclass_batch.InferenceMulticlassBatch + :members: + :undoc-members: + +Grid-search random forest classifiers +---------------------------------------------- + +.. autoclass:: simba.model.grid_search_rf.GridSearchRandomForestClassifier + :members: + :undoc-members: + +Grid-search random forest multi-classifiers +---------------------------------------------- + +.. autoclass:: simba.model.grid_search_multiclass_rf.GridSearchMulticlassRandomForestClassifier + :members: + :undoc-members: + +Random forest inference - validation +---------------------------------------------- + +.. autoclass:: simba.model.inference_validation.InferenceValidation + :members: + :undoc-members: + +Fit random forest classifier +---------------------------------------------- + +.. autoclass:: simba.model.train_rf.TrainRandomForestClassifier + :members: + :undoc-members: + + +Fit random forest classifier - multi-class +---------------------------------------------- + +.. autoclass:: simba.model.train_multiclass_rf.TrainMultiClassRandomForestClassifier + :members: + :undoc-members: + + +Ordinal classifier methods +---------------------------------------------- + +.. autoclass:: simba.model.ordinal_clf.OrdinalClassifier + :members: + :undoc-members: + + + diff --git a/simba/model/ordinal_clf.py b/simba/model/ordinal_clf.py new file mode 100644 index 000000000..bd9248978 --- /dev/null +++ b/simba/model/ordinal_clf.py @@ -0,0 +1,111 @@ +import os +from typing import Union, Optional, Dict +from sklearn.ensemble import RandomForestClassifier +import numpy as np +from joblib import Parallel, delayed +from sklearn import clone +from simba.mixins.train_model_mixin import TrainModelMixin +from simba.utils.errors import SamplingError, InvalidInputError +from simba.utils.checks import check_valid_array, check_int, check_if_dir_exists, check_file_exist_and_readable +from simba.utils.enums import Formats +from simba.utils.read_write import find_core_cnt, write_pickle, read_pickle + +ACCEPTED_MODELS = RandomForestClassifier + +class OrdinalClassifier(): + """ + This class implements a strategy for ordinal classification by fitting multiple binary classifiers to predict thresholds between classes. + + It is particularly useful for problems where the target variable has an inherent order but uneven intervals between levels. Thi includes human severity scores, for example, seizures, stereotopy, convulsion, bizarre behavior scores ranging fro 0-5. + + .. note:: + `Modified from sklego <`https://github.com/koaning/scikit-lego/blob/main/sklego/meta/ordinal_classification.py>`__. + + References + ---------- + .. [1] Frank, Eibe, and Mark Hall. “A Simple Approach to Ordinal Classification.” In Machine Learning: ECML 2001, edited by Luc De Raedt and Peter Flach, 2167:145–56. Lecture Notes in Computer Science. Berlin, Heidelberg: Springer Berlin Heidelberg, 2001. https://doi.org/10.1007/3-540-44795-4_13. + .. [2] Sabnis, Gautam, Leinani Hession, J. Matthew Mahoney, Arie Mobley, Marina Santos, and Vivek Kumar. “Visual Detection of Seizures in Mice Using Supervised Machine Learning,” May 31, 2024. https://doi.org/10.1101/2024.05.29.596520. + + :example: + >>> X = np.random.randint(0, 500, (100, 50)) + >>> y = np.random.randint(1, 6, (100)) + >>> rf_mdl = TrainModelMixin().clf_define() + >>> fitted_mdl = OrdinalClassifier.fit(X, y, rf_mdl, -1) + >>> y_hat = OrdinalClassifier.predict_proba(X, fitted_mdl) + >>> y = OrdinalClassifier.predict(X, fitted_mdl) + >>> OrdinalClassifier.save(mdl=fitted_mdl, save_path=r"C:\mdl.pk") + """ + + def __init__(self): + pass + + @staticmethod + def fit(X: np.ndarray, y: np.ndarray, clf: Union[ACCEPTED_MODELS], core_cnt: int = -1) -> Dict[int, Union[ACCEPTED_MODELS]]: + + def _fit_binary_estimator(clf, X, y, y_label): + y_bin = (y <= y_label).astype(np.int32) + return clone(clf).fit(X, y_bin) + + classes_ = np.sort(np.unique(y)) + check_valid_array(data=classes_, source=f'{__class__.__name__} y', accepted_ndims=(1,), accepted_dtypes=(int,)) + if len(classes_) < 3: + raise InvalidInputError(msg=f'Found {len(classes_)} classes in y [{classes_}], requires at least 3', source=f'{OrdinalClassifier.__name__} fit') + intervals = [classes_[i] - classes_[i-1] for i in range(1, len(classes_))] + if len(set(intervals)) != 1: + raise InvalidInputError(msg=f'The values in y ({classes_}) are not of equal interval.', source=f'{OrdinalClassifier.__name__} fit') + check_valid_array(data=X, source=f'{__class__.__name__} x', accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value) + if not isinstance(clf, (RandomForestClassifier,)) or ('predict_proba' not in dir(clf)): + raise InvalidInputError(msg=f'clf is not of valid type: {type(clf)} (accepted: {ACCEPTED_MODELS})', source=f'{OrdinalClassifier.__name__} fit') + check_int(name='core_cnt', min_value=-1, unaccepted_vals=[0], value=core_cnt) + core_cnt = [find_core_cnt()[0] if core_cnt == -1 or core_cnt > find_core_cnt()[0] else core_cnt][0] + return dict(zip(classes_[:-1], Parallel(n_jobs=core_cnt)(delayed(_fit_binary_estimator)(clf, X, y, y_label) for y_label in classes_[:-1]))) + + @staticmethod + def predict_proba(X: np.ndarray, mdl: Dict[int, Union[ACCEPTED_MODELS]]) -> np.ndarray: + OrdinalClassifier._check_valid_mdl_dict(mdls=mdl) + check_valid_array(data=X, source=f'{__class__.__name__} x', accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value) + if mdl[list(mdl.keys())[0]].n_features_ != X.shape[1]: + raise InvalidInputError(msg=f'Model expects {mdl[list(mdl.keys())[0]].n_features_} features, got {X.shape[1]}.', source=f'{OrdinalClassifier.__name__} predict') + check_valid_array(data=X, source=f'{__class__.__name__} x', accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value) + raw_proba = np.array([estimator.predict_proba(X)[:, 1] for estimator in mdl.values()]).T + return np.diff(np.column_stack((np.zeros(X.shape[0]), raw_proba, np.ones(X.shape[0]))), n=1, axis=1) + + + @staticmethod + def predict(X: np.ndarray, mdl: Dict[int, Union[ACCEPTED_MODELS]]) -> np.ndarray: + OrdinalClassifier._check_valid_mdl_dict(mdls=mdl) + check_valid_array(data=X, source=f'{__class__.__name__} x', accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value) + if mdl[list(mdl.keys())[0]].n_features_ != X.shape[1]: + raise InvalidInputError(msg=f'Model expects {mdl[list(mdl.keys())[0]].n_features_} features, got {X.shape[1]}.', source=f'{OrdinalClassifier.__name__} predict') + return np.argmax(OrdinalClassifier.predict_proba(X, mdl=mdl), axis=1) + + @staticmethod + def save(mdl: Dict[int, Union[ACCEPTED_MODELS]], save_path: Union[str, os.PathLike]): + OrdinalClassifier._check_valid_mdl_dict(mdls=mdl) + check_if_dir_exists(in_dir=os.path.dirname(save_path), source=f'{OrdinalClassifier.__name__} save') + write_pickle(data=mdl, save_path=save_path) + + + @staticmethod + def load(file_path: Union[str, os.PathLike]) -> Dict[int, Union[ACCEPTED_MODELS]]: + check_file_exist_and_readable(file_path=file_path) + return read_pickle(data_path=file_path) + + + @staticmethod + def _check_valid_mdl_dict(mdls: Dict[int, Union[ACCEPTED_MODELS]]) -> None: + features_in_cnt = [] + for mdl in mdls.values(): features_in_cnt.append(mdl.n_features_) + if len(set(features_in_cnt)) != 1: + raise InvalidInputError(msg=f'The models has different N features [{features_in_cnt}]') + +# X = np.random.randint(0, 500, (100, 50)) +# y = np.random.randint(1, 6, (100)) +# rf_mdl = TrainModelMixin().clf_define() +# fitted_mdls = OrdinalClassifier.fit(X, y, rf_mdl, -1) +# y_hat = OrdinalClassifier.predict_proba(X, fitted_mdls) +# y = OrdinalClassifier.predict(X, fitted_mdls) +# OrdinalClassifier.save(mdl=fitted_mdls, save_path=r"C:\Users\sroni\OneDrive\Desktop\mdl.pk") + +#predict_proba(X) +# ordinal_clf.predict(X) \ No newline at end of file diff --git a/simba/utils/read_write.py b/simba/utils/read_write.py index 4d8bac7bc..9b50be13d 100644 --- a/simba/utils/read_write.py +++ b/simba/utils/read_write.py @@ -1804,6 +1804,7 @@ def read_pickle(data_path: Union[str, os.PathLike], verbose: Optional[bool] = Fa :example: >>> data = read_pickle(data_path='/test/unsupervised/cluster_models') """ + data = None if os.path.isdir(data_path): if verbose: print(f"Reading in data directory {data_path}...") @@ -1841,10 +1842,7 @@ def read_pickle(data_path: Union[str, os.PathLike], verbose: Optional[bool] = Fa source=read_pickle.__name__, ) else: - raise InvalidFilepathError( - msg=f"The path {data_path} is neither a valid file or directory path", - source=read_pickle.__name__, - ) + raise InvalidFilepathError(msg=f"The path {data_path} is neither a valid file or directory path", source=read_pickle.__name__) return data