column order in single core train data read in
sronilsson committed Dec 20, 2024
1 parent 7f3cf06 commit fd110b1
Showing 3 changed files with 64 additions and 42 deletions.
73 changes: 32 additions & 41 deletions simba/mixins/train_model_mixin.py
@@ -62,7 +62,7 @@
                                check_if_dir_exists, check_if_valid_input,
                                check_instance, check_int, check_str,
                                check_that_column_exist, check_valid_array,
-                               check_valid_dataframe, check_valid_lst)
+                               check_valid_dataframe, check_valid_lst, check_all_dfs_in_list_has_same_cols, check_filepaths_in_iterable_exist)
 from simba.utils.data import (detect_bouts, detect_bouts_multiclass,
                               get_library_version)
 from simba.utils.enums import (OS, ConfigKey, Defaults, Dtypes, Formats,
@@ -115,52 +115,38 @@ def read_all_files_in_folder(self,
         :examples:
         >>> self.read_all_files_in_folder(file_paths=['targets_inserted/Video_1.csv', 'targets_inserted/Video_2.csv'], file_type='csv', classifier_names=['Attack'])
         """
+        check_filepaths_in_iterable_exist(file_paths=file_paths, name=self.__class__.__name__)
         timer = SimbaTimer(start=True)
-        frm_number_lst = []
-        df_concat = pd.DataFrame()
+        frm_number_lst, dfs = [], []
         if len(file_paths) == 0:
             raise NoDataError(msg="SimBA found 0 annotated frames in the project_folder/csv/targets_inserted directory", source=self.__class__.__name__)
         for file_cnt, file in enumerate(file_paths):
-            print(f"Reading in file {str(file_cnt + 1)}/{str(len(file_paths))}...")
             _, vid_name, _ = get_fn_ext(file)
-            df = (
-                read_df(file, file_type)
-                .dropna(axis=0, how="all")
-                .fillna(0)
-                .astype(np.float32)
-            )
-            frm_number_lst.extend((df.index))
+            print(f"Reading in {vid_name} (file {str(file_cnt + 1)}/{str(len(file_paths))})...")
+            df = (read_df(file, file_type).dropna(axis=0, how="all").fillna(0).astype(np.float32))
+            frm_number_lst.extend(list(df.index))
             df.index = [vid_name] * len(df)
             if classifier_names != None:
                 for clf_name in classifier_names:
                     if not clf_name in df.columns:
                         raise MissingColumnsError(msg=f"Data for video {vid_name} does not contain any annotations for behavior {clf_name}. Delete classifier {clf_name} from the SimBA project, or add annotations for behavior {clf_name} to the video {vid_name}", source=self.__class__.__name__,)
                     elif (len(set(df[clf_name].unique()) - {0, 1}) > 0 and raise_bool_clf_error):
                         raise InvalidInputError(msg=f"The annotation column for a classifier should contain only 0 or 1 values. However, in file {file} the {clf_name} field contains additional value(s): {list(set(df[clf_name].unique()) - {0, 1})}.", source=self.__class__.__name__)
-                    else:
-                        df_concat = pd.concat([df_concat, df], axis=0)
-            else:
-                df_concat = pd.concat([df_concat, df], axis=0)
-        try:
-            df_concat = df_concat.set_index("scorer")
-        except KeyError:
-            pass
-        if len(df_concat) == 0:
-            raise NoDataError(msg="SimBA found 0 annotated frames in the project_folder/csv/targets_inserted directory", source=self.__class__.__name__)
-        df_concat = df_concat.loc[
-            :, ~df_concat.columns.str.contains("^Unnamed")
-        ].fillna(0)
+            dfs.append(df)
+
+        check_all_dfs_in_list_has_same_cols(dfs=dfs, source='/project_folder/csv/targets_inserted', raise_error=True)
+        col_headers = [list(x.columns) for x in dfs]
+        dfs = [x[col_headers[0]] for x in dfs]
+        dfs = pd.concat(dfs, axis=0)
+        if 'scorer' in dfs.columns:
+            dfs = dfs.set_index("scorer")
+        dfs = dfs.loc[:, ~dfs.columns.str.contains("^Unnamed")].fillna(0)
         timer.stop_timer()
-        memory_size = get_memory_usage_of_df(df=df_concat)
-        print(
-            f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB'
-        )
-        print(
-            "{} file(s) read (elapsed time: {}s) ...".format(
-                str(len(file_paths)), timer.elapsed_time_str
-            )
-        )
+        memory_size = get_memory_usage_of_df(df=dfs)
+        print(f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB')
+        print(f"{len(file_paths)} file(s) read (elapsed time: {timer.elapsed_time_str}s) ...")
-        return df_concat.astype(np.float32), frm_number_lst
+
+        return dfs.astype(np.float32), frm_number_lst

     def read_in_all_model_names_to_remove(self, config: configparser.ConfigParser, model_cnt: int, clf_name: str) -> List[str]:
         """
@@ -1580,24 +1566,29 @@ def read_all_files_in_folder_mp_futures(self,
"""

THREADSAFE_CORE_COUNT = 16
check_filepaths_in_iterable_exist(file_paths=annotations_file_paths, name=self.__class__.__name__)
try:
if (platform.system() == "Darwin") and (multiprocessing.get_start_method() != "spawn"):
multiprocessing.set_start_method("spawn", force=True)
cpu_cnt, _ = find_core_cnt()
if (cpu_cnt > THREADSAFE_CORE_COUNT) and (platform.system() == OS.WINDOWS.value):
cpu_cnt = THREADSAFE_CORE_COUNT
df_lst, frm_number_list = [], []
dfs, frm_number_list = [], []
with concurrent.futures.ProcessPoolExecutor(max_workers=cpu_cnt) as executor:
results = [executor.submit(self._read_data_file_helper_futures, data, file_type, classifier_names, raise_bool_clf_error) for data in annotations_file_paths]
for result in concurrent.futures.as_completed(results):
df_lst.append(result.result()[0])
dfs.append(result.result()[0])
frm_number_list.extend((result.result()[-1]))
print(f"Reading complete {result.result()[1]} (elapsed time: {result.result()[2]}s)...")
df_concat = pd.concat(df_lst, axis=0).round(4)
if "scorer" in df_concat.columns:
df_concat = df_concat.drop(["scorer"], axis=1)

return df_concat, frm_number_list
check_all_dfs_in_list_has_same_cols(dfs=dfs, source='/project_folder/csv/targets_inserted', raise_error=True)
col_headers = [list(x.columns) for x in dfs]
dfs = [x[col_headers[0]] for x in dfs]
dfs = pd.concat(dfs, axis=0).round(4)
if "scorer" in dfs.columns: dfs = dfs.drop(["scorer"], axis=1)
memory_size = get_memory_usage_of_df(df=dfs)
print(f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB')
return dfs, frm_number_list

except Exception as e:
MultiProcessingFailedWarning(msg=f"Multi-processing file read failed, reverting to single core (increased run-time on large datasets). Exception: {e.args}")
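The same guard and reorder now run in the multiprocess path, where they matter even more: concurrent.futures.as_completed yields results in whatever order the workers finish, so the frame that ends up first in dfs can differ between runs. A small self-contained sketch (the work function is a hypothetical stand-in for the per-file reader) demonstrates the nondeterministic completion order:

# Sketch: as_completed returns futures in completion order, not submission order.
import concurrent.futures

def work(x: int) -> int:
    return x * x  # hypothetical stand-in for reading one annotation file

if __name__ == "__main__":
    with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(work, i) for i in range(8)]
        for future in concurrent.futures.as_completed(futures):
            print(future.result())  # ordering varies from run to run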
17 changes: 17 additions & 0 deletions simba/sandbox/check_dfs_in_lst.py
@@ -0,0 +1,17 @@
+import pandas as pd
+from typing import List
+
+from simba.utils.checks import check_valid_lst
+from simba.utils.errors import MissingColumnsError
+
+def check_all_dfs_in_list_has_same_cols(dfs: List[pd.DataFrame], raise_error: bool = True) -> bool:
+    check_valid_lst(data=dfs, source=check_all_dfs_in_list_has_same_cols.__name__, valid_dtypes=(pd.DataFrame,), min_len=1)
+    col_headers = [list(x.columns) for x in dfs]
+    common_headers = set(col_headers[0]).intersection(*col_headers[1:])
+    all_headers = set(item for sublist in col_headers for item in sublist)
+    missing_headers = list(all_headers - common_headers)
+    if len(missing_headers) > 0:
+        if raise_error:
+            raise MissingColumnsError(msg=f"The data in the project_folder/csv/targets_inserted directory do not contain the same headers. Some files are missing the headers: {missing_headers}", source=check_all_dfs_in_list_has_same_cols.__name__)
+        else:
+            return False
+    return True
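The core of the check is plain set algebra: intersect all header lists, union them, and report the difference, i.e. every column that appears in some files but not all of them. A sketch with hypothetical header names:

# Union minus intersection = headers carried by some, but not all, files.
col_headers = [["Attack", "nose_x"], ["Attack"], ["Attack", "tail_y"]]
common_headers = set(col_headers[0]).intersection(*col_headers[1:])   # {'Attack'}
all_headers = set(h for headers in col_headers for h in headers)      # all three names
print(sorted(all_headers - common_headers))                           # ['nose_x', 'tail_y']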
16 changes: 15 additions & 1 deletion simba/utils/checks.py
@@ -26,7 +26,7 @@
                                 InvalidFilepathError, InvalidInputError,
                                 NoDataError, NoFilesFoundError, NoROIDataError,
                                 NotDirectoryError, ParametersFileError,
-                                StringError)
+                                StringError, MissingColumnsError)
 from simba.utils.warnings import (CorruptedFileWarning, FrameRangeWarning,
                                   InvalidValueWarning, NoDataFoundWarning)

@@ -1478,3 +1478,17 @@ def check_filepaths_in_iterable_exist(file_paths: Iterable[str],
         check_str(name=f'{check_filepaths_in_iterable_exist.__name__} {file_path} {name}', value=file_path)
         if not os.path.isfile(file_path):
             raise NoFilesFoundError(msg=f'{name} {file_path} is not a valid file path')
+
+def check_all_dfs_in_list_has_same_cols(dfs: List[pd.DataFrame], raise_error: bool = True, source: str = '') -> bool:
+    """ Checks that all dataframes in a list have the same column names. """
+    check_valid_lst(data=dfs, source=check_all_dfs_in_list_has_same_cols.__name__, valid_dtypes=(pd.DataFrame,), min_len=1)
+    col_headers = [list(x.columns) for x in dfs]
+    common_headers = set(col_headers[0]).intersection(*col_headers[1:])
+    all_headers = set(item for sublist in col_headers for item in sublist)
+    missing_headers = list(all_headers - common_headers)
+    if len(missing_headers) > 0:
+        if raise_error:
+            raise MissingColumnsError(msg=f"The data in the {source} directory do not contain the same headers. Some files are missing the headers: {missing_headers}", source=check_all_dfs_in_list_has_same_cols.__name__)
+        else:
+            return False
+    return True
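A minimal usage sketch for the new helper, with made-up frames: the first call passes because the frames share headers, while the second returns False instead of raising because raise_error=False.

# Hypothetical frames exercising check_all_dfs_in_list_has_same_cols.
import pandas as pd
from simba.utils.checks import check_all_dfs_in_list_has_same_cols

df_1 = pd.DataFrame({"Attack": [0, 1], "nose_x": [1.0, 2.0]})
df_2 = pd.DataFrame({"Attack": [1, 0]})  # 'nose_x' missing

check_all_dfs_in_list_has_same_cols(dfs=[df_1, df_1.copy()], source="targets_inserted")  # returns True
ok = check_all_dfs_in_list_has_same_cols(dfs=[df_1, df_2], source="targets_inserted", raise_error=False)
print(ok)  # False: 'nose_x' is not present in every frame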
