Skip to content

Commit

Permalink
not reading video_info unless required.
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Aug 1, 2024
1 parent cb4c143 commit ea8d3c9
Show file tree
Hide file tree
Showing 50 changed files with 218 additions and 117 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# Setup configuration
setuptools.setup(
name="Simba-UW-tf-dev",
version="1.99.2",
version="1.99.7",
author="Simon Nilsson, Jia Jie Choong, Sophia Hwang",
author_email="[email protected]",
description="Toolkit for computer classification and analysis of behaviors in experimental animals",
Expand Down
7 changes: 7 additions & 0 deletions simba/mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
try:
from cuml.ensemble import RandomForestClassifier as cuRF
except ImportError:
cuRF = None

__all__ = ['cuRF']

106 changes: 69 additions & 37 deletions simba/mixins/train_model_mixin.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
__author__ = "Simon Nilsson"


from . import cuRF
import warnings

warnings.filterwarnings("ignore")

import ast
import concurrent
import configparser
import os
import pickle
import platform
import subprocess
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from copy import deepcopy
Expand Down Expand Up @@ -68,7 +65,7 @@
FaultyTrainingSetError,
FeatureNumberMismatchError, InvalidInputError,
MissingColumnsError, NoDataError,
SamplingError)
SamplingError, SimBAModuleNotFoundError)
from simba.utils.lookups import get_meta_data_file_headers
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import (find_core_cnt, get_fn_ext,
Expand Down Expand Up @@ -843,9 +840,7 @@ def create_shap_log(self,
out_df_shap.to_csv(self.out_df_shap_path)
out_df_raw.to_csv(self.out_df_raw_path)
shap_frm_timer.stop_timer()
print(
f"SHAP frame: {cnt + 1} / {len(shap_df)}, elapsed time: {shap_frm_timer.elapsed_time_str}..."
)
print( f"SHAP frame: {cnt + 1} / {len(shap_df)}, elapsed time: {shap_frm_timer.elapsed_time_str}...")

shap_timer.stop_timer()
stdout_success(
Expand Down Expand Up @@ -1287,16 +1282,13 @@ def partial_dependence_calculator(
f"Partial dependencies for {feature_name} complete ({feature_cnt+1}/{len(x_df.columns)})..."
)

def clf_predict_proba(
self,
clf: RandomForestClassifier,
x_df: pd.DataFrame,
multiclass: bool = False,
model_name: Optional[str] = None,
data_path: Optional[Union[str, os.PathLike]] = None,
) -> np.ndarray:
def clf_predict_proba( self,
clf: RandomForestClassifier,
x_df: pd.DataFrame,
multiclass: bool = False,
model_name: Optional[str] = None,
data_path: Optional[Union[str, os.PathLike]] = None) -> np.ndarray:
"""
:param RandomForestClassifier clf: Random forest classifier object
:param pd.DataFrame x_df: Features df
:param bool multiclass: If True, the classifier predicts more than 2 targets. Else, boolean classifier.
Expand Down Expand Up @@ -1342,9 +1334,54 @@ def clf_predict_proba(
else:
return p_vals

def clf_fit(
self, clf: RandomForestClassifier, x_df: pd.DataFrame, y_df: pd.DataFrame
) -> RandomForestClassifier:


def clf_define(self,
n_estimators: Optional[int] = 2000,
max_depth: Optional[int] = None,
max_features: Optional[Union[str, int]] = 'sqrt',
n_jobs: Optional[int] = -1,
criterion: Optional[str] = 'gini',
min_samples_leaf: Optional[int] = 1,
bootstrap: Optional[bool] = True,
verbose: Optional[int] = 1,
class_weight: Optional[dict] = None,
cuda: Optional[bool] = False) -> RandomForestClassifier:


if not cuda:
return RandomForestClassifier(n_estimators=n_estimators,
max_depth=max_depth,
max_features=max_features,
n_jobs=n_jobs,
criterion=criterion,
min_samples_leaf=min_samples_leaf,
bootstrap=bootstrap,
verbose=verbose,
class_weight=class_weight)

else:
if cuRF is not None:
return cuRF(n_estimators=n_estimators,
split_criterion=criterion,
bootstrap=bootstrap,
max_depth=max_depth,
max_features=max_features,
min_samples_leaf=min_samples_leaf,
verbose=verbose)
else:
raise SimBAModuleNotFoundError(msg='SimBA could not find the cuml library for GPU machine learning algorithms.', source=self.__class__.__name__)





def clf_fit(self,
clf: Union[RandomForestClassifier, cuRF],
x_df: pd.DataFrame,
y_df: pd.DataFrame,
) -> RandomForestClassifier:

"""
Helper to fit clf model
Expand All @@ -1357,23 +1394,20 @@ def clf_fit(
nan_target = y_df.loc[pd.to_numeric(y_df).isna()]
if len(nan_features) > 0:
raise FaultyTrainingSetError(
msg=f"{len(nan_features)} frame(s) in your project_folder/csv/targets_inserted directory contains FEATURES with non-numerical values",
source=self.__class__.__name__,
)
msg=f"{len(nan_features)} frame(s) in your project_folder/csv/targets_inserted directory contains FEATURES with non-numerical values", source=self.__class__.__name__)
if len(nan_target) > 0:
raise FaultyTrainingSetError(
msg=f"{len(nan_target)} frame(s) in your project_folder/csv/targets_inserted directory contains ANNOTATIONS with non-numerical values",
source=self.__class__.__name__,
)
return clf.fit(x_df, y_df)
raise FaultyTrainingSetError( msg=f"{len(nan_target)} frame(s) in your project_folder/csv/targets_inserted directory contains ANNOTATIONS with non-numerical values", source=self.__class__.__name__)

clf.fit(x_df, y_df)

return clf

@staticmethod
def _read_data_file_helper(
file_path: str,
file_type: str,
clf_names: Optional[List[str]] = None,
raise_bool_clf_error: bool = True,
):
def _read_data_file_helper(file_path: str,
file_type: str,
clf_names: Optional[List[str]] = None,
raise_bool_clf_error: bool = True):

"""
Private function called by :meth:`simba.train_model_functions.read_all_files_in_folder_mp`
"""
Expand All @@ -1400,9 +1434,7 @@ def _read_data_file_helper(
source=TrainModelMixin._read_data_file_helper.__name__,
)
timer.stop_timer()
print(
f"Reading complete {vid_name} (elapsed time: {timer.elapsed_time_str}s)..."
)
print(f"Reading complete {vid_name} (elapsed time: {timer.elapsed_time_str}s)...")

return df, frame_numbers

Expand Down
11 changes: 1 addition & 10 deletions simba/model/grid_search_rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,7 @@ def run(self):
print(f"MODEL {config_cnt+1} settings")
self.print_machine_model_information(meta_dict)
print(f"# {len(self.feature_names)} features.")
self.rf_clf = RandomForestClassifier(n_estimators=meta_dict[MLParamKeys.RF_ESTIMATORS.value],
max_features=meta_dict[MLParamKeys.RF_MAX_FEATURES.value],
max_depth=meta_dict[MLParamKeys.RF_MAX_DEPTH.value],
n_jobs=-1,
criterion=meta_dict[MLParamKeys.RF_CRITERION.value],
min_samples_leaf=meta_dict[MLParamKeys.MIN_LEAF.value],
bootstrap=True,
verbose=1,
class_weight=meta_dict[MLParamKeys.CLASS_WEIGHTS.value])

self.rf_clf = self.clf_define(n_estimators=meta_dict[MLParamKeys.RF_ESTIMATORS.value], max_depth=meta_dict[MLParamKeys.RF_MAX_DEPTH.value], max_features=meta_dict[MLParamKeys.RF_MAX_FEATURES.value], n_jobs=-1, criterion=meta_dict[MLParamKeys.RF_CRITERION.value], min_samples_leaf=meta_dict[MLParamKeys.MIN_LEAF.value], bootstrap=True, verbose=1, class_weight=meta_dict[MLParamKeys.CLASS_WEIGHTS.value])
print(f"Fitting {self.clf_name} model...")
self.rf_clf = self.clf_fit(clf=self.rf_clf, x_df=self.x_train, y_df=self.y_train)
if (meta_dict[MLParamKeys.PERMUTATION_IMPORTANCE.value] in Options.PERFORM_FLAGS.value):
Expand Down
22 changes: 6 additions & 16 deletions simba/model/train_rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
from typing import Union

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from simba.mixins.config_reader import ConfigReader
from simba.mixins.train_model_mixin import TrainModelMixin
from simba.utils.checks import check_if_filepath_list_is_empty, check_int
from simba.utils.enums import (ConfigKey, Dtypes, Formats, Methods,
MLParamKeys, Options, TagNames)
from simba.utils.printing import SimbaTimer, log_event, stdout_success
from simba.utils.enums import (ConfigKey, Dtypes, Formats, Methods, MLParamKeys, Options)
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import read_config_entry, write_df


Expand Down Expand Up @@ -235,17 +233,9 @@ def run(self):
name=MLParamKeys.SHAP_ABSENT.value, value=shap_target_absent_cnt
)

self.rf_clf = RandomForestClassifier(
n_estimators=n_estimators,
max_depth=self.rf_max_depth,
max_features=max_features,
n_jobs=-1,
criterion=criterion,
min_samples_leaf=min_sample_leaf,
bootstrap=True,
verbose=1,
class_weight=class_weights,
)


self.rf_clf = self.clf_define(n_estimators=n_estimators, max_depth=self.rf_max_depth, max_features=max_features, n_jobs=-1, criterion=criterion, min_samples_leaf=min_sample_leaf, verbose=1, class_weight=class_weights)

print(f"Fitting {self.clf_name} model...")
self.rf_clf = self.clf_fit(
Expand Down Expand Up @@ -397,7 +387,7 @@ def save(self) -> None:
stdout_success(msg=f"Evaluation files are in models/generated_models/model_evaluations folders", source=self.__class__.__name__)


# test = TrainRandomForestClassifier(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/mitra/project_folder/project_config.ini')
# test = TrainRandomForestClassifier(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
# test.run()
# test.save()

Expand Down
6 changes: 3 additions & 3 deletions simba/pose_importers/dlc_importer_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def import_dlc_csv(config_path: Union[str, os.PathLike], source: str) -> List[st
"""

check_file_exist_and_readable(file_path=config_path)
conf = ConfigReader(config_path=config_path)
conf = ConfigReader(config_path=config_path, read_video_info=False)
original_file_name_dir = os.path.join(conf.input_csv_dir, "original_filename")
if not os.path.exists(original_file_name_dir): os.makedirs(original_file_name_dir)
prev_imported_file_paths = find_files_of_filetypes_in_directory(directory=conf.input_csv_dir, extensions=[f'.{conf.file_type}'], raise_warning=False, raise_error=False)
Expand All @@ -62,8 +62,8 @@ def import_dlc_csv(config_path: Union[str, os.PathLike], source: str) -> List[st
new_file_name_wo_ext = new_file_name.split(".")[0]
video_basename = os.path.basename(file_path)
print(f"Importing {video_name} to SimBA project...")
# if new_file_name_wo_ext in prev_imported_file_names:
# raise FileExistError(f"SIMBA IMPORT ERROR: {new_file_name} already exist in project in the directory {conf.input_csv_dir}. Remove file from project or rename imported video file name before importing.")
if new_file_name_wo_ext in prev_imported_file_names:
raise FileExistError(f"SIMBA IMPORT ERROR: {new_file_name} already exist in project in the directory {conf.input_csv_dir}. Remove file from project or rename imported video file name before importing.")
shutil.copy(file_path, conf.input_csv_dir)
shutil.copy(file_path, original_file_name_dir)
os.rename(os.path.join(conf.input_csv_dir, video_basename), os.path.join(conf.input_csv_dir, new_file_name))
Expand Down
5 changes: 3 additions & 2 deletions simba/ui/machine_model_settings_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import webbrowser
from tkinter import *
from typing import Union

import pandas as pd

Expand All @@ -27,8 +28,8 @@ class MachineModelSettingsPopUp(PopUpMixin, ConfigReader):
GUI window for specifying ML model training parameters.
"""

def __init__(self, config_path: str):
ConfigReader.__init__(self, config_path=config_path)
def __init__(self, config_path: Union[str, os.PathLike]):
ConfigReader.__init__(self, config_path=config_path, read_video_info=False)
PopUpMixin.__init__(self, title="MACHINE MODEL SETTINGS", size=(450, 800))
if not os.path.exists(self.configs_meta_dir):
os.makedirs(self.configs_meta_dir)
Expand Down
2 changes: 1 addition & 1 deletion simba/ui/pop_ups/animal_directing_other_animals_pop_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class AnimalDirectingAnimalPopUp(ConfigReader, PopUpMixin):
>>> test = AnimalDirectingAnimalPopUp(config_path=r"C:\troubleshooting\two_black_animals_14bp\project_folder\project_config.ini")
"""
def __init__(self, config_path: Union[str, os.PathLike]):
ConfigReader.__init__(self, config_path=config_path)
ConfigReader.__init__(self, config_path=config_path, read_video_info=False)
if self.animal_cnt < 2:
raise AnimalNumberError(msg=f"Directionality between animals require at least two animals. The SimBA project is set to use {self.animal_cnt} animal.", source=self.__class__.__name__,)
if len(self.outlier_corrected_paths) == 0:
Expand Down
2 changes: 1 addition & 1 deletion simba/ui/pop_ups/append_roi_features_animals_pop_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class AppendROIFeaturesByAnimalPopUp(ConfigReader, PopUpMixin):
def __init__(self, config_path: Union[str, os.PathLike]):
ConfigReader.__init__(self, config_path=config_path)
ConfigReader.__init__(self, config_path=config_path, read_video_info=False)
if not os.path.isfile(self.roi_coordinates_path):
ROIWarning(msg=f"SIMBA ERROR: No ROIs have been defined. Please define ROIs before appending ROI-based features (no data file found at path {self.roi_coordinates_path})", source=self.__class__.__name__,)

Expand Down
2 changes: 1 addition & 1 deletion simba/ui/pop_ups/append_roi_features_bodypart_pop_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class AppendROIFeaturesByBodyPartPopUp(PopUpMixin, ConfigReader):
def __init__(self, config_path: Union[str, os.PathLike]):
ConfigReader.__init__(self, config_path=config_path)
ConfigReader.__init__(self, config_path=config_path, read_video_info=False)
if not os.path.isfile(self.roi_coordinates_path):
raise NoROIDataError(msg="SIMBA ERROR: No ROIs have been defined. Please define ROIs before appending ROI-based features",source=self.__class__.__name__,)
PopUpMixin.__init__(self, config_path=config_path, title="APPEND ROI FEATURES: BY BODY-PARTS")
Expand Down
2 changes: 1 addition & 1 deletion simba/ui/pop_ups/archive_files_pop_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class ArchiveProcessedFilesPopUp(PopUpMixin, ConfigReader):
def __init__(self, config_path: str):
PopUpMixin.__init__(self, title="ADD CLASSIFIER")
ConfigReader.__init__(self, config_path=config_path)
ConfigReader.__init__(self, config_path=config_path, read_video_info=False)
self.archive_eb = Entry_Box(self.main_frm, "ARCHIVE DIRECTORY NAME", "25")
archive_btn = Button(self.main_frm, text="RUN ARCHIVE", font=Formats.FONT_REGULAR.value, fg="blue", command=lambda: self.run())
self.archive_eb.grid(row=0, column=0, sticky=NW)
Expand Down
2 changes: 1 addition & 1 deletion simba/ui/pop_ups/boolean_conditional_slicer_pup_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class BooleanConditionalSlicerPopUp(PopUpMixin, ConfigReader):
def __init__(self, config_path: str):
ConfigReader.__init__(self, config_path=config_path)
ConfigReader.__init__(self, config_path=config_path, read_video_info=False)
PopUpMixin.__init__(self, title="CONDITIONAL BOOLEAN AGGREGATE STATISTICS", size=(600, 400))
self.rule_cnt_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="CONDITIONAL RULES #", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.AGGREGATE_BOOL_STATS.value,)
self.rule_cnt_dropdown = DropDownMenu(self.rule_cnt_frm,"# RULES:",list(range(2, 21)),"25",com=self.create_rules_frames)
Expand Down
2 changes: 1 addition & 1 deletion simba/ui/pop_ups/clf_add_remove_print_pop_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class AddClfPopUp(PopUpMixin, ConfigReader):
def __init__(self, config_path: Union[str, os.PathLike]):

PopUpMixin.__init__(self, config_path=config_path, title="ADD CLASSIFIER")
ConfigReader.__init__(self, config_path=config_path) .clf_eb = Entry_Box(self.main_frm, "CLASSIFIER NAME", "15")
ConfigReader.__init__(self, config_path=config_path, read_video_info=False) .clf_eb = Entry_Box(self.main_frm, "CLASSIFIER NAME", "15")
add_btn = Button(self.main_frm, text="ADD CLASSIFIER", command=lambda: self.run())
self.clf_eb.grid(row=0, column=0, sticky=NW)
add_btn.grid(row=1, column=0, sticky=NW)
Expand Down
2 changes: 1 addition & 1 deletion simba/ui/pop_ups/clf_annotation_counts_pop_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class ClfAnnotationCountPopUp(PopUpMixin, ConfigReader):
>>> ClfAnnotationCountPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini')
"""
def __init__(self, config_path: Union[str, os.PathLike]):
ConfigReader.__init__(self, config_path=config_path)
ConfigReader.__init__(self, config_path=config_path, read_video_info=False)
self.config_path = config_path
PopUpMixin.__init__(self, title='COUNT NUMBER OF ANNOTATIONS IN SIMBA PROJECT', config_path=config_path)
if len(self.clf_names) == 0:
Expand Down
5 changes: 3 additions & 2 deletions simba/ui/pop_ups/clf_by_roi_pop_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from collections import defaultdict
from tkinter import *
from typing import Union

from simba.mixins.config_reader import ConfigReader
from simba.mixins.pop_up_mixin import PopUpMixin
Expand All @@ -22,8 +23,8 @@ class ClfByROIPopUp(PopUpMixin, ConfigReader):
>>> _ = ClfByROIPopUp(config_path=r"C:\troubleshooting\open_field_below\project_folder\project_config.ini")
"""

def __init__(self, config_path: str):
ConfigReader.__init__(self, config_path=config_path)
def __init__(self, config_path: Union[str, os.PathLike]):
ConfigReader.__init__(self, config_path=config_path, read_video_info=False)
if not os.path.isfile(self.roi_coordinates_path):
raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path, source=self.__class__.__name__)
PopUpMixin.__init__(self, title="CLASSIFICATIONS BY ROI")
Expand Down
6 changes: 4 additions & 2 deletions simba/ui/pop_ups/clf_by_timebins_pop_up.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
__author__ = "Simon Nilsson"

import multiprocessing
import os
from tkinter import *
from typing import Union

from simba.data_processors.timebins_clf_calculator import TimeBinsClfCalculator
from simba.mixins.config_reader import ConfigReader
Expand All @@ -15,9 +17,9 @@


class TimeBinsClfPopUp(PopUpMixin, ConfigReader):
def __init__(self, config_path: str):
def __init__(self, config_path: Union[str, os.PathLike]):
PopUpMixin.__init__(self, title="CLASSIFICATION BY TIME BINS")
ConfigReader.__init__(self, config_path=config_path)
ConfigReader.__init__(self, config_path=config_path, read_video_info=False)
cbox_titles = Options.TIMEBINS_MEASURMENT_OPTIONS.value
self.timebin_entrybox = Entry_Box(self.main_frm, "Set time bin size (s)", "15", validation="numeric")
measures_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="MEASUREMENTS", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.ANALYZE_ML_RESULTS.value)
Expand Down
Loading

0 comments on commit ea8d3c9

Please sign in to comment.