diff --git a/simba/mixins/train_model_mixin.py b/simba/mixins/train_model_mixin.py index 14fbef863..eec56bd1e 100644 --- a/simba/mixins/train_model_mixin.py +++ b/simba/mixins/train_model_mixin.py @@ -58,7 +58,9 @@ from simba.plotting.shap_agg_stats_visualizer import \ ShapAggregateStatisticsVisualizer from simba.ui.tkinter_functions import TwoOptionQuestionPopUp -from simba.utils.checks import (check_file_exist_and_readable, check_float, +from simba.utils.checks import (check_all_dfs_in_list_has_same_cols, + check_file_exist_and_readable, + check_filepaths_in_iterable_exist, check_float, check_if_dir_exists, check_if_valid_input, check_instance, check_int, check_str, check_that_column_exist, check_valid_array, @@ -115,20 +117,16 @@ def read_all_files_in_folder(self, :examples: >>> self.read_all_files_in_folder(file_paths=['targets_inserted/Video_1.csv', 'targets_inserted/Video_2.csv'], file_type='csv', classifier_names=['Attack']) """ - + check_filepaths_in_iterable_exist(file_paths=file_paths, name=self.__class__.__name__) timer = SimbaTimer(start=True) - frm_number_lst = [] - df_concat = pd.DataFrame() + frm_number_lst, dfs = [], [] + if len(file_paths) == 0: + raise NoDataError(msg="SimBA found 0 annotated frames in the project_folder/csv/targets_inserted directory", source=self.__class__.__name__) for file_cnt, file in enumerate(file_paths): - print(f"Reading in file {str(file_cnt + 1)}/{str(len(file_paths))}...") _, vid_name, _ = get_fn_ext(file) - df = ( - read_df(file, file_type) - .dropna(axis=0, how="all") - .fillna(0) - .astype(np.float32) - ) - frm_number_lst.extend((df.index)) + print(f"Reading in {vid_name} (file {str(file_cnt + 1)}/{str(len(file_paths))})...") + df = (read_df(file, file_type).dropna(axis=0, how="all").fillna(0).astype(np.float32)) + frm_number_lst.extend(list(df.index)) df.index = [vid_name] * len(df) if classifier_names != None: for clf_name in classifier_names: @@ -136,31 +134,21 @@ def read_all_files_in_folder(self, raise MissingColumnsError(msg=f"Data for video {vid_name} does not contain any annotations for behavior {clf_name}. Delete classifier {clf_name} from the SimBA project, or add annotations for behavior {clf_name} to the video {vid_name}", source=self.__class__.__name__,) elif (len(set(df[clf_name].unique()) - {0, 1}) > 0 and raise_bool_clf_error): raise InvalidInputError(msg=f"The annotation column for a classifier should contain only 0 or 1 values. However, in file {file} the {clf_name} field contains additional value(s): {list(set(df[clf_name].unique()) - {0, 1})}.", source=self.__class__.__name__) - else: - df_concat = pd.concat([df_concat, df], axis=0) - else: - df_concat = pd.concat([df_concat, df], axis=0) - try: - df_concat = df_concat.set_index("scorer") - except KeyError: - pass - if len(df_concat) == 0: - raise NoDataError(msg="SimBA found 0 annotated frames in the project_folder/csv/targets_inserted directory", source=self.__class__.__name__) - df_concat = df_concat.loc[ - :, ~df_concat.columns.str.contains("^Unnamed") - ].fillna(0) + dfs.append(df) + + check_all_dfs_in_list_has_same_cols(dfs=dfs, source='/project_folder/csv/targets_inserted', raise_error=True) + col_headers = [list(x.columns) for x in dfs] + dfs = [x[col_headers[0]] for x in dfs] + dfs = pd.concat(dfs, axis=0) + if 'scorer' in dfs.columns: + dfs = dfs.set_index("scorer") + dfs = dfs.loc[:, ~dfs.columns.str.contains("^Unnamed")].fillna(0) timer.stop_timer() - memory_size = get_memory_usage_of_df(df=df_concat) - print( - f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB' - ) - print( - "{} file(s) read (elapsed time: {}s) ...".format( - str(len(file_paths)), timer.elapsed_time_str - ) - ) + memory_size = get_memory_usage_of_df(df=dfs) + print(f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB') + print(f"{len(file_paths)} file(s) read (elapsed time: {timer.elapsed_time_str}s) ...") - return df_concat.astype(np.float32), frm_number_lst + return dfs.astype(np.float32), frm_number_lst def read_in_all_model_names_to_remove(self, config: configparser.ConfigParser, model_cnt: int, clf_name: str) -> List[str]: """ @@ -1580,24 +1568,29 @@ def read_all_files_in_folder_mp_futures(self, """ THREADSAFE_CORE_COUNT = 16 + check_filepaths_in_iterable_exist(file_paths=annotations_file_paths, name=self.__class__.__name__) try: if (platform.system() == "Darwin") and (multiprocessing.get_start_method() != "spawn"): multiprocessing.set_start_method("spawn", force=True) cpu_cnt, _ = find_core_cnt() if (cpu_cnt > THREADSAFE_CORE_COUNT) and (platform.system() == OS.WINDOWS.value): cpu_cnt = THREADSAFE_CORE_COUNT - df_lst, frm_number_list = [], [] + dfs, frm_number_list = [], [] with concurrent.futures.ProcessPoolExecutor(max_workers=cpu_cnt) as executor: results = [executor.submit(self._read_data_file_helper_futures, data, file_type, classifier_names, raise_bool_clf_error) for data in annotations_file_paths] for result in concurrent.futures.as_completed(results): - df_lst.append(result.result()[0]) + dfs.append(result.result()[0]) frm_number_list.extend((result.result()[-1])) print(f"Reading complete {result.result()[1]} (elapsed time: {result.result()[2]}s)...") - df_concat = pd.concat(df_lst, axis=0).round(4) - if "scorer" in df_concat.columns: - df_concat = df_concat.drop(["scorer"], axis=1) - return df_concat, frm_number_list + check_all_dfs_in_list_has_same_cols(dfs=dfs, source='/project_folder/csv/targets_inserted', raise_error=True) + col_headers = [list(x.columns) for x in dfs] + dfs = [x[col_headers[0]] for x in dfs] + dfs = pd.concat(dfs, axis=0).round(4) + if "scorer" in dfs.columns: dfs = dfs.drop(["scorer"], axis=1) + memory_size = get_memory_usage_of_df(df=dfs) + print(f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB') + return dfs, frm_number_list except Exception as e: MultiProcessingFailedWarning(msg=f"Multi-processing file read failed, reverting to single core (increased run-time on large datasets). Exception: {e.args}") diff --git a/simba/sandbox/check_dfs_in_lst.py b/simba/sandbox/check_dfs_in_lst.py new file mode 100644 index 000000000..f08b35203 --- /dev/null +++ b/simba/sandbox/check_dfs_in_lst.py @@ -0,0 +1,19 @@ +from typing import List + +import pandas as pd + +from simba.utils.checks import check_valid_lst + + +def check_all_dfs_in_list_has_same_cols(dfs: List[pd.DataFrame], raise_error: bool = True) -> bool: + check_valid_lst(data=dfs, source=check_all_dfs_in_list_has_same_cols.__name__, valid_dtypes=(pd.DataFrame,), min_len=1) + col_headers = [list(x.columns) for x in dfs] + common_headers = set(col_headers[0]).intersection(*col_headers[1:]) + all_headers = set(item for sublist in col_headers for item in sublist) + missing_headers = list(all_headers - common_headers) + if len(missing_headers) > 0: + if raise_error: + raise MissingColumnsError(msg=f"The data in project_folder/csv/targets_inserted directory do not contain the same headers. Some files are missing the headers: {missing_headers}", source=self.__class__.__name__) + else: + return False + return True \ No newline at end of file diff --git a/simba/utils/checks.py b/simba/utils/checks.py index 991faff46..b47747432 100644 --- a/simba/utils/checks.py +++ b/simba/utils/checks.py @@ -24,7 +24,8 @@ DirectoryNotEmptyError, FFMPEGNotFoundError, FloatError, FrameRangeError, IntegerError, InvalidFilepathError, InvalidInputError, - NoDataError, NoFilesFoundError, NoROIDataError, + MissingColumnsError, NoDataError, + NoFilesFoundError, NoROIDataError, NotDirectoryError, ParametersFileError, StringError) from simba.utils.warnings import (CorruptedFileWarning, FrameRangeWarning, @@ -1478,3 +1479,17 @@ def check_filepaths_in_iterable_exist(file_paths: Iterable[str], check_str(name=f'{check_filepaths_in_iterable_exist.__name__} {file_path} {name}', value=file_path) if not os.path.isfile(file_path): raise NoFilesFoundError(msg=f'{name} {file_path} is not a valid file path') + +def check_all_dfs_in_list_has_same_cols(dfs: List[pd.DataFrame], raise_error: bool = True, source: str = '') -> bool: + """ Checks that all dataframes in list has the same column names""" + check_valid_lst(data=dfs, source=check_all_dfs_in_list_has_same_cols.__name__, valid_dtypes=(pd.DataFrame,), min_len=1) + col_headers = [list(x.columns) for x in dfs] + common_headers = set(col_headers[0]).intersection(*col_headers[1:]) + all_headers = set(item for sublist in col_headers for item in sublist) + missing_headers = list(all_headers - common_headers) + if len(missing_headers) > 0: + if raise_error: + raise MissingColumnsError(msg=f"The data in {source} directory do not contain the same headers. Some files are missing the headers: {missing_headers}", source=check_all_dfs_in_list_has_same_cols.__name__) + else: + return False + return True \ No newline at end of file