Commit 3027398

Merge remote-tracking branch 'origin/master'
sronilsson committed Dec 20, 2024
2 parents 8b301d2 + ea1d7d1 commit 3027398
Showing 3 changed files with 69 additions and 42 deletions.
75 changes: 34 additions & 41 deletions simba/mixins/train_model_mixin.py
@@ -58,7 +58,9 @@
 from simba.plotting.shap_agg_stats_visualizer import \
     ShapAggregateStatisticsVisualizer
 from simba.ui.tkinter_functions import TwoOptionQuestionPopUp
-from simba.utils.checks import (check_file_exist_and_readable, check_float,
+from simba.utils.checks import (check_all_dfs_in_list_has_same_cols,
+                                check_file_exist_and_readable,
+                                check_filepaths_in_iterable_exist, check_float,
                                 check_if_dir_exists, check_if_valid_input,
                                 check_instance, check_int, check_str,
                                 check_that_column_exist, check_valid_array,
@@ -115,52 +117,38 @@ def read_all_files_in_folder(self,
         :examples:
         >>> self.read_all_files_in_folder(file_paths=['targets_inserted/Video_1.csv', 'targets_inserted/Video_2.csv'], file_type='csv', classifier_names=['Attack'])
         """
+        check_filepaths_in_iterable_exist(file_paths=file_paths, name=self.__class__.__name__)
         timer = SimbaTimer(start=True)
-        frm_number_lst = []
-        df_concat = pd.DataFrame()
+        frm_number_lst, dfs = [], []
         if len(file_paths) == 0:
             raise NoDataError(msg="SimBA found 0 annotated frames in the project_folder/csv/targets_inserted directory", source=self.__class__.__name__)
         for file_cnt, file in enumerate(file_paths):
-            print(f"Reading in file {str(file_cnt + 1)}/{str(len(file_paths))}...")
             _, vid_name, _ = get_fn_ext(file)
-            df = (
-                read_df(file, file_type)
-                .dropna(axis=0, how="all")
-                .fillna(0)
-                .astype(np.float32)
-            )
-            frm_number_lst.extend((df.index))
+            print(f"Reading in {vid_name} (file {str(file_cnt + 1)}/{str(len(file_paths))})...")
+            df = (read_df(file, file_type).dropna(axis=0, how="all").fillna(0).astype(np.float32))
+            frm_number_lst.extend(list(df.index))
             df.index = [vid_name] * len(df)
             if classifier_names != None:
                 for clf_name in classifier_names:
                     if not clf_name in df.columns:
                         raise MissingColumnsError(msg=f"Data for video {vid_name} does not contain any annotations for behavior {clf_name}. Delete classifier {clf_name} from the SimBA project, or add annotations for behavior {clf_name} to the video {vid_name}", source=self.__class__.__name__,)
                     elif (len(set(df[clf_name].unique()) - {0, 1}) > 0 and raise_bool_clf_error):
                         raise InvalidInputError(msg=f"The annotation column for a classifier should contain only 0 or 1 values. However, in file {file} the {clf_name} field contains additional value(s): {list(set(df[clf_name].unique()) - {0, 1})}.", source=self.__class__.__name__)
-                    else:
-                        df_concat = pd.concat([df_concat, df], axis=0)
-            else:
-                df_concat = pd.concat([df_concat, df], axis=0)
-        try:
-            df_concat = df_concat.set_index("scorer")
-        except KeyError:
-            pass
-        if len(df_concat) == 0:
-            raise NoDataError(msg="SimBA found 0 annotated frames in the project_folder/csv/targets_inserted directory", source=self.__class__.__name__)
-        df_concat = df_concat.loc[
-            :, ~df_concat.columns.str.contains("^Unnamed")
-        ].fillna(0)
+            dfs.append(df)
+
+        check_all_dfs_in_list_has_same_cols(dfs=dfs, source='/project_folder/csv/targets_inserted', raise_error=True)
+        col_headers = [list(x.columns) for x in dfs]
+        dfs = [x[col_headers[0]] for x in dfs]
+        dfs = pd.concat(dfs, axis=0)
+        if 'scorer' in dfs.columns:
+            dfs = dfs.set_index("scorer")
+        dfs = dfs.loc[:, ~dfs.columns.str.contains("^Unnamed")].fillna(0)
         timer.stop_timer()
-        memory_size = get_memory_usage_of_df(df=df_concat)
-        print(
-            f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB'
-        )
-        print(
-            "{} file(s) read (elapsed time: {}s) ...".format(
-                str(len(file_paths)), timer.elapsed_time_str
-            )
-        )
-        return df_concat.astype(np.float32), frm_number_lst
+        memory_size = get_memory_usage_of_df(df=dfs)
+        print(f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB')
+        print(f"{len(file_paths)} file(s) read (elapsed time: {timer.elapsed_time_str}s) ...")
+
+        return dfs.astype(np.float32), frm_number_lst

     def read_in_all_model_names_to_remove(self, config: configparser.ConfigParser, model_cnt: int, clf_name: str) -> List[str]:
         """
@@ -1580,24 +1568,29 @@ def read_all_files_in_folder_mp_futures(self,
         """

         THREADSAFE_CORE_COUNT = 16
+        check_filepaths_in_iterable_exist(file_paths=annotations_file_paths, name=self.__class__.__name__)
         try:
             if (platform.system() == "Darwin") and (multiprocessing.get_start_method() != "spawn"):
                 multiprocessing.set_start_method("spawn", force=True)
             cpu_cnt, _ = find_core_cnt()
             if (cpu_cnt > THREADSAFE_CORE_COUNT) and (platform.system() == OS.WINDOWS.value):
                 cpu_cnt = THREADSAFE_CORE_COUNT
-            df_lst, frm_number_list = [], []
+            dfs, frm_number_list = [], []
             with concurrent.futures.ProcessPoolExecutor(max_workers=cpu_cnt) as executor:
                 results = [executor.submit(self._read_data_file_helper_futures, data, file_type, classifier_names, raise_bool_clf_error) for data in annotations_file_paths]
                 for result in concurrent.futures.as_completed(results):
-                    df_lst.append(result.result()[0])
+                    dfs.append(result.result()[0])
                     frm_number_list.extend((result.result()[-1]))
                     print(f"Reading complete {result.result()[1]} (elapsed time: {result.result()[2]}s)...")
-            df_concat = pd.concat(df_lst, axis=0).round(4)
-            if "scorer" in df_concat.columns:
-                df_concat = df_concat.drop(["scorer"], axis=1)
-
-            return df_concat, frm_number_list
+            check_all_dfs_in_list_has_same_cols(dfs=dfs, source='/project_folder/csv/targets_inserted', raise_error=True)
+            col_headers = [list(x.columns) for x in dfs]
+            dfs = [x[col_headers[0]] for x in dfs]
+            dfs = pd.concat(dfs, axis=0).round(4)
+            if "scorer" in dfs.columns: dfs = dfs.drop(["scorer"], axis=1)
+            memory_size = get_memory_usage_of_df(df=dfs)
+            print(f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB')
+            return dfs, frm_number_list
         except Exception as e:
             MultiProcessingFailedWarning(msg=f"Multi-processing file read failed, reverting to single core (increased run-time on large datasets). Exception: {e.args}")
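The multiprocess reader keeps the same submit-then-as_completed shape: one future per annotation file, results consumed in completion order so progress prints as files finish, with "spawn" forced on macOS and the worker count capped on Windows. A stripped-down sketch of that executor pattern follows; read_one and the paths are hypothetical stand-ins for _read_data_file_helper_futures and the project file list, not SimBA's helpers.

    import concurrent.futures
    import multiprocessing
    import platform


    def read_one(path: str) -> str:
        return f"parsed {path}"  # placeholder for the per-file read_df(...) work


    if __name__ == "__main__":
        # force "spawn" on macOS, as the diff above does
        if (platform.system() == "Darwin") and (multiprocessing.get_start_method(allow_none=True) != "spawn"):
            multiprocessing.set_start_method("spawn", force=True)
        paths = ["a.csv", "b.csv", "c.csv"]
        with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
            futures = [executor.submit(read_one, p) for p in paths]
            for future in concurrent.futures.as_completed(futures):
                print(future.result())  # completion order, not submission order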
19 changes: 19 additions & 0 deletions simba/sandbox/check_dfs_in_lst.py
@@ -0,0 +1,19 @@
+from typing import List
+
+import pandas as pd
+
+from simba.utils.checks import check_valid_lst
+from simba.utils.errors import MissingColumnsError
+
+
+def check_all_dfs_in_list_has_same_cols(dfs: List[pd.DataFrame], raise_error: bool = True) -> bool:
+    check_valid_lst(data=dfs, source=check_all_dfs_in_list_has_same_cols.__name__, valid_dtypes=(pd.DataFrame,), min_len=1)
+    col_headers = [list(x.columns) for x in dfs]
+    common_headers = set(col_headers[0]).intersection(*col_headers[1:])
+    all_headers = set(item for sublist in col_headers for item in sublist)
+    missing_headers = list(all_headers - common_headers)
+    if len(missing_headers) > 0:
+        if raise_error:
+            raise MissingColumnsError(msg=f"The files in the project_folder/csv/targets_inserted directory do not all contain the same headers. Some files are missing the headers: {missing_headers}", source=check_all_dfs_in_list_has_same_cols.__name__)
+        else:
+            return False
+    return True
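The helper's core is three lines of set arithmetic: the intersection of all header lists gives the columns every file shares, the union gives every column seen anywhere, and their difference is what gets reported. A toy walk-through with invented column names:

    col_headers = [["x", "y", "Attack"], ["x", "y"], ["x", "y", "Attack"]]
    common_headers = set(col_headers[0]).intersection(*col_headers[1:])  # {'x', 'y'}
    all_headers = set(h for cols in col_headers for h in cols)           # {'x', 'y', 'Attack'}
    missing_headers = list(all_headers - common_headers)
    print(missing_headers)  # ['Attack'] -- present in some files, absent from the second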
17 changes: 16 additions & 1 deletion simba/utils/checks.py
@@ -24,7 +24,8 @@
                                 DirectoryNotEmptyError, FFMPEGNotFoundError,
                                 FloatError, FrameRangeError, IntegerError,
                                 InvalidFilepathError, InvalidInputError,
-                                NoDataError, NoFilesFoundError, NoROIDataError,
+                                MissingColumnsError, NoDataError,
+                                NoFilesFoundError, NoROIDataError,
                                 NotDirectoryError, ParametersFileError,
                                 StringError)
 from simba.utils.warnings import (CorruptedFileWarning, FrameRangeWarning,
@@ -1478,3 +1479,17 @@ def check_filepaths_in_iterable_exist(file_paths: Iterable[str],
         check_str(name=f'{check_filepaths_in_iterable_exist.__name__} {file_path} {name}', value=file_path)
         if not os.path.isfile(file_path):
             raise NoFilesFoundError(msg=f'{name} {file_path} is not a valid file path')
+
+def check_all_dfs_in_list_has_same_cols(dfs: List[pd.DataFrame], raise_error: bool = True, source: str = '') -> bool:
+    """Checks that all dataframes in a list have the same column names."""
+    check_valid_lst(data=dfs, source=check_all_dfs_in_list_has_same_cols.__name__, valid_dtypes=(pd.DataFrame,), min_len=1)
+    col_headers = [list(x.columns) for x in dfs]
+    common_headers = set(col_headers[0]).intersection(*col_headers[1:])
+    all_headers = set(item for sublist in col_headers for item in sublist)
+    missing_headers = list(all_headers - common_headers)
+    if len(missing_headers) > 0:
+        if raise_error:
+            raise MissingColumnsError(msg=f"The files in the {source} directory do not all contain the same headers. Some files are missing the headers: {missing_headers}", source=check_all_dfs_in_list_has_same_cols.__name__)
+        else:
+            return False
+    return True
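With raise_error=False the check degrades to a plain boolean, useful for callers that prefer to warn or skip instead of aborting. A short usage sketch with synthetic frames (not project files):

    import pandas as pd

    from simba.utils.checks import check_all_dfs_in_list_has_same_cols

    dfs = [pd.DataFrame({"x": [1]}), pd.DataFrame({"x": [2], "y": [3]})]
    ok = check_all_dfs_in_list_has_same_cols(dfs=dfs, raise_error=False, source='targets_inserted')
    if not ok:
        print("Column sets differ across files; fix the annotations before training.")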
