diff --git a/docs/_static/img/ImportPoseFrame.webp b/docs/_static/img/ImportPoseFrame.webp new file mode 100644 index 000000000..9ba338817 Binary files /dev/null and b/docs/_static/img/ImportPoseFrame.webp differ diff --git a/docs/_static/img/ImportVideosFrame.webp b/docs/_static/img/ImportVideosFrame.webp new file mode 100644 index 000000000..7db02131b Binary files /dev/null and b/docs/_static/img/ImportVideosFrame.webp differ diff --git a/docs/_static/img/ProjectCreatorPopUp.webp b/docs/_static/img/ProjectCreatorPopUp.webp new file mode 100644 index 000000000..a80b33a8d Binary files /dev/null and b/docs/_static/img/ProjectCreatorPopUp.webp differ diff --git a/docs/tutorials_rst/example_1.py b/docs/tutorials_rst/example_1.py index ab5764e85..6f5445dd6 100644 --- a/docs/tutorials_rst/example_1.py +++ b/docs/tutorials_rst/example_1.py @@ -1,5 +1,5 @@ import simba -from simba.pose_importers.dlc_importer_csv import import_multiple_dlc_tracking_csv_file +from simba.pose_importers.dlc_importer_csv import import_dlc_csv_data from simba.outlier_tools.skip_outlier_correction import OutlierCorrectionSkipper from simba.utils.cli.cli_tools import feature_extraction_runner, set_video_parameters from simba.model.inference_batch import InferenceBatch @@ -27,11 +27,11 @@ # RUN THE DATA IMPORTER FOR A DIRECTORY OF FILES -import_multiple_dlc_tracking_csv_file(config_path=CONFIG_PATH, - interpolation_setting=INTERPOLATION_SETTING, - smoothing_setting=SMOOTHING_SETTING, - smoothing_time=SMOOTHING_TIME, - data_dir=DATA_DIR) +import_dlc_csv_data(config_path=CONFIG_PATH, + interpolation_setting=INTERPOLATION_SETTING, + smoothing_setting=SMOOTHING_SETTING, + smoothing_time=SMOOTHING_TIME, + data_path=DATA_DIR) # RUN THE OUTLIER CORRECTION SKIPPER diff --git a/setup.py b/setup.py index 15daf7e37..622614a88 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ setuptools.setup( name="Simba-UW-tf-dev", - version="1.95.4", + version="1.95.6", author="Simon Nilsson, Jia Jie Choong, Sophia Hwang", author_email="sronilsson@gmail.com", description="Toolkit for computer classification of behaviors in experimental animals", diff --git a/simba/SimBA.py b/simba/SimBA.py index 69e43cd9f..cc6dbc609 100644 --- a/simba/SimBA.py +++ b/simba/SimBA.py @@ -20,68 +20,48 @@ from simba.bounding_box_tools.boundary_menus import BoundaryMenus from simba.cue_light_tools.cue_light_menues import CueLightAnalyzerMenu -from simba.labelling.labelling_advanced_interface import \ - select_labelling_video_advanced +from simba.labelling.labelling_advanced_interface import select_labelling_video_advanced from simba.labelling.labelling_interface import select_labelling_video -from simba.labelling.targeted_annotations_clips import \ - select_labelling_video_targeted_clips +from simba.labelling.targeted_annotations_clips import select_labelling_video_targeted_clips from simba.model.grid_search_rf import GridSearchRandomForestClassifier from simba.model.inference_batch import InferenceBatch from simba.model.inference_validation import InferenceValidation from simba.model.train_rf import TrainRandomForestClassifier -from simba.outlier_tools.outlier_corrector_location import \ - OutlierCorrecterLocation -from simba.outlier_tools.outlier_corrector_movement import \ - OutlierCorrecterMovement -from simba.outlier_tools.skip_outlier_correction import \ - OutlierCorrectionSkipper -from simba.plotting.interactive_probability_grapher import \ - InteractiveProbabilityGrapher +from simba.outlier_tools.outlier_corrector_location import OutlierCorrecterLocation +from simba.outlier_tools.outlier_corrector_movement import OutlierCorrecterMovement +from simba.outlier_tools.skip_outlier_correction import OutlierCorrectionSkipper +from simba.plotting.interactive_probability_grapher import InteractiveProbabilityGrapher from simba.roi_tools.ROI_define import * from simba.roi_tools.ROI_menus import * from simba.roi_tools.ROI_reset import * from simba.third_party_label_appenders.BENTO_appender import BentoAppender from simba.third_party_label_appenders.BORIS_appender import BorisAppender -from simba.third_party_label_appenders.deepethogram_importer import \ - DeepEthogramImporter -from simba.third_party_label_appenders.ethovision_import import \ - ImportEthovision -from simba.third_party_label_appenders.observer_importer import \ - NoldusObserverImporter +from simba.third_party_label_appenders.deepethogram_importer import DeepEthogramImporter +from simba.third_party_label_appenders.ethovision_import import ImportEthovision +from simba.third_party_label_appenders.observer_importer import NoldusObserverImporter from simba.third_party_label_appenders.solomon_importer import SolomonImporter from simba.ui.create_project_ui import ProjectCreatorPopUp from simba.ui.machine_model_settings_ui import MachineModelSettingsPopUp from simba.ui.pop_ups.about_simba_pop_up import AboutSimBAPopUp -from simba.ui.pop_ups.animal_directing_other_animals_pop_up import \ - AnimalDirectingAnimalPopUp -from simba.ui.pop_ups.append_roi_features_animals_pop_up import \ - AppendROIFeaturesByAnimalPopUp -from simba.ui.pop_ups.append_roi_features_bodypart_pop_up import \ - AppendROIFeaturesByBodyPartPopUp +from simba.ui.pop_ups.animal_directing_other_animals_pop_up import AnimalDirectingAnimalPopUp +from simba.ui.pop_ups.append_roi_features_animals_pop_up import AppendROIFeaturesByAnimalPopUp +from simba.ui.pop_ups.append_roi_features_bodypart_pop_up import AppendROIFeaturesByBodyPartPopUp from simba.ui.pop_ups.archive_files_pop_up import ArchiveProcessedFilesPopUp from simba.ui.pop_ups.batch_preprocess_pop_up import BatchPreProcessPopUp -from simba.ui.pop_ups.boolean_conditional_slicer_pup_up import \ - BooleanConditionalSlicerPopUp -from simba.ui.pop_ups.clf_add_remove_print_pop_up import ( - AddClfPopUp, PrintModelInfoPopUp, RemoveAClassifierPopUp) +from simba.ui.pop_ups.boolean_conditional_slicer_pup_up import BooleanConditionalSlicerPopUp +from simba.ui.pop_ups.clf_add_remove_print_pop_up import (AddClfPopUp, PrintModelInfoPopUp, RemoveAClassifierPopUp) from simba.ui.pop_ups.clf_by_roi_pop_up import ClfByROIPopUp from simba.ui.pop_ups.clf_by_timebins_pop_up import TimeBinsClfPopUp -from simba.ui.pop_ups.clf_descriptive_statistics_pop_up import \ - ClfDescriptiveStatsPopUp +from simba.ui.pop_ups.clf_descriptive_statistics_pop_up import ClfDescriptiveStatsPopUp from simba.ui.pop_ups.clf_plot_pop_up import SklearnVisualizationPopUp -from simba.ui.pop_ups.clf_probability_plot_pop_up import \ - VisualizeClassificationProbabilityPopUp -from simba.ui.pop_ups.clf_validation_plot_pop_up import \ - ClassifierValidationPopUp +from simba.ui.pop_ups.clf_probability_plot_pop_up import VisualizeClassificationProbabilityPopUp +from simba.ui.pop_ups.clf_validation_plot_pop_up import ClassifierValidationPopUp from simba.ui.pop_ups.csv_2_parquet_pop_up import (Csv2ParquetPopUp, Parquet2CsvPopUp) from simba.ui.pop_ups.data_plot_pop_up import DataPlotterPopUp -from simba.ui.pop_ups.directing_animal_to_bodypart_plot_pop_up import \ - DirectingAnimalToBodyPartVisualizerPopUp -from simba.ui.pop_ups.directing_other_animals_plot_pop_up import \ - DirectingOtherAnimalsVisualizerPopUp -from simba.ui.pop_ups.direction_animal_to_bodypart_settings_pop_up import \ - DirectionAnimalToBodyPartSettingsPopUp +from simba.ui.pop_ups.directing_animal_to_bodypart_plot_pop_up import DirectingAnimalToBodyPartVisualizerPopUp +from simba.ui.pop_ups.directing_other_animals_plot_pop_up import DirectingOtherAnimalsVisualizerPopUp +from simba.ui.pop_ups.direction_animal_to_bodypart_settings_pop_up import DirectionAnimalToBodyPartSettingsPopUp from simba.ui.pop_ups.distance_plot_pop_up import DistancePlotterPopUp from simba.ui.pop_ups.fsttc_pop_up import FSTTCPopUp from simba.ui.pop_ups.gantt_pop_up import GanttPlotPopUp @@ -90,8 +70,7 @@ from simba.ui.pop_ups.kleinberg_pop_up import KleinbergPopUp from simba.ui.pop_ups.make_path_plot_pop_up import MakePathPlotPopUp from simba.ui.pop_ups.movement_analysis_pop_up import MovementAnalysisPopUp -from simba.ui.pop_ups.movement_analysis_time_bins_pop_up import \ - MovementAnalysisTimeBinsPopUp +from simba.ui.pop_ups.movement_analysis_time_bins_pop_up import MovementAnalysisTimeBinsPopUp from simba.ui.pop_ups.mutual_exclusivity_pop_up import MutualExclusivityPupUp from simba.ui.pop_ups.outlier_settings_pop_up import OutlierSettingsPopUp from simba.ui.pop_ups.path_plot_pop_up import PathPlotPopUp @@ -101,23 +80,17 @@ from simba.ui.pop_ups.quick_path_plot_pop_up import QuickLineplotPopup from simba.ui.pop_ups.remove_roi_features_pop_up import RemoveROIFeaturesPopUp from simba.ui.pop_ups.roi_analysis_pop_up import ROIAnalysisPopUp -from simba.ui.pop_ups.roi_analysis_time_bins_pop_up import \ - ROIAnalysisTimeBinsPopUp +from simba.ui.pop_ups.roi_analysis_time_bins_pop_up import ROIAnalysisTimeBinsPopUp from simba.ui.pop_ups.roi_features_plot_pop_up import VisualizeROIFeaturesPopUp -from simba.ui.pop_ups.roi_size_standardizer_popup import \ - ROISizeStandardizerPopUp +from simba.ui.pop_ups.roi_size_standardizer_popup import ROISizeStandardizerPopUp from simba.ui.pop_ups.roi_tracking_plot_pop_up import VisualizeROITrackingPopUp -from simba.ui.pop_ups.set_machine_model_parameters_pop_up import \ - SetMachineModelParameters +from simba.ui.pop_ups.set_machine_model_parameters_pop_up import SetMachineModelParameters from simba.ui.pop_ups.severity_analysis_pop_up import AnalyzeSeverityPopUp -from simba.ui.pop_ups.smoothing_interpolation_pop_up import (InterpolatePopUp, - SmoothingPopUp) -from simba.ui.pop_ups.spontaneous_alternation_pop_up import \ - SpontaneousAlternationPopUp -from simba.ui.pop_ups.subset_feature_extractor_pop_up import \ - FeatureSubsetExtractorPopUp -from simba.ui.pop_ups.third_party_annotator_appender_pop_up import \ - ThirdPartyAnnotatorAppenderPopUp +from simba.ui.pop_ups.smoothing_popup import SmoothingPopUp +from simba.ui.pop_ups.interpolate_pop_up import InterpolatePopUp +from simba.ui.pop_ups.spontaneous_alternation_pop_up import SpontaneousAlternationPopUp +from simba.ui.pop_ups.subset_feature_extractor_pop_up import FeatureSubsetExtractorPopUp +from simba.ui.pop_ups.third_party_annotator_appender_pop_up import ThirdPartyAnnotatorAppenderPopUp from simba.ui.pop_ups.validation_plot_pop_up import ValidationVideoPopUp from simba.ui.pop_ups.video_processing_pop_up import ( BackgroundRemoverPopUp, BoxBlurPopUp, BrightnessContrastPopUp, @@ -141,22 +114,20 @@ SuperimposeTextPopUp, SuperimposeTimerPopUp, SuperimposeVideoNamesPopUp, SuperimposeVideoPopUp, SuperimposeWatermarkPopUp, UpsampleVideosPopUp, VideoRotatorPopUp, VideoTemporalJoinPopUp) -from simba.ui.pop_ups.visualize_pose_in_dir_pop_up import \ - VisualizePoseInFolderPopUp +from simba.ui.pop_ups.visualize_pose_in_dir_pop_up import VisualizePoseInFolderPopUp from simba.ui.tkinter_functions import DropDownMenu, Entry_Box, FileSelect from simba.ui.video_info_ui import VideoInfoTable -from simba.utils.checks import (check_ffmpeg_available, - check_file_exist_and_readable, check_int) +from simba.ui.import_pose_frame import ImportPoseFrame +from simba.ui.import_videos_frame import ImportVideosFrame +from simba.utils.checks import (check_ffmpeg_available, check_file_exist_and_readable, check_int) from simba.utils.custom_feature_extractor import CustomFeatureExtractor from simba.utils.enums import OS, Defaults, Formats, Paths, TagNames from simba.utils.errors import InvalidInputError -from simba.utils.lookups import (get_bp_config_code_class_pairs, get_emojis, - get_icons_paths) +from simba.utils.lookups import (get_bp_config_code_class_pairs, get_emojis, get_icons_paths) from simba.utils.printing import stdout_success, stdout_warning from simba.utils.read_write import get_video_meta_data from simba.utils.warnings import FFMpegNotFoundWarning, PythonVersionWarning -from simba.video_processors.video_processing import \ - extract_frames_from_all_videos_in_directory +from simba.video_processors.video_processing import extract_frames_from_all_videos_in_directory sys.setrecursionlimit(10**6) currentPlatform = platform.system() @@ -356,12 +327,7 @@ def __init__(self, config_path: str): fg="blue", command=lambda: None, ) - interpolate_btn = Button( - further_methods_frm, - text="INTERPOLATE POSE IN SIMBA PROJECT", - fg="blue", - command=lambda: InterpolatePopUp(config_path=self.config_path), - ) + interpolate_btn = Button(further_methods_frm, text="INTERPOLATE POSE IN SIMBA PROJECT", fg="blue", command=lambda: InterpolatePopUp(config_path=self.config_path)) smooth_btn = Button( further_methods_frm, text="SMOOTH POSE IN SIMBA PROJECT", @@ -1030,173 +996,53 @@ def activate(box, *args): fg="blue", command=lambda: MovementAnalysisTimeBinsPopUp(config_path=self.config_path), ) - button_classifierbins = Button( - label_machineresults, - text="ANALYZE MACHINE PREDICTIONS: TIME-BINS", - fg="blue", - command=lambda: TimeBinsClfPopUp(config_path=self.config_path), - ) - button_classifier_ROI = Button( - label_machineresults, - text="ANALYZE MACHINE PREDICTION: BY ROI", - fg="blue", - command=lambda: ClfByROIPopUp(config_path=self.config_path), - ) - button_severity = Button( - label_machineresults, - text="ANALYZE MACHINE PREDICTION: BY SEVERITY", - fg="blue", - command=lambda: AnalyzeSeverityPopUp(config_path=self.config_path), - ) - - visualization_frm = CreateLabelFrameWithIcon( - parent=tab10, - header="DATA VISUALIZATIONS", - icon_name=Keys.DOCUMENTATION.value, - icon_link=Links.VISUALIZATION.value, - ) - sklearn_visualization_btn = Button( - visualization_frm, - text="VISUALIZE CLASSIFICATIONS", - fg="black", - command=lambda: SklearnVisualizationPopUp(config_path=self.config_path), - ) + button_classifierbins = Button(label_machineresults, text="ANALYZE MACHINE PREDICTIONS: TIME-BINS", fg="blue", command=lambda: TimeBinsClfPopUp(config_path=self.config_path)) + button_classifier_ROI = Button(label_machineresults, text="ANALYZE MACHINE PREDICTION: BY ROI", fg="blue", command=lambda: ClfByROIPopUp(config_path=self.config_path)) + button_severity = Button(label_machineresults, text="ANALYZE MACHINE PREDICTION: BY SEVERITY", fg="blue", command=lambda: AnalyzeSeverityPopUp(config_path=self.config_path)) + visualization_frm = CreateLabelFrameWithIcon(parent=tab10, header="DATA VISUALIZATIONS", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.VISUALIZATION.value) + sklearn_visualization_btn = Button(visualization_frm, text="VISUALIZE CLASSIFICATIONS", fg="black", command=lambda: SklearnVisualizationPopUp(config_path=self.config_path)) sklearn_visualization_btn.grid(row=0, column=0, sticky=NW) - gantt_visualization_btn = Button( - visualization_frm, - text="VISUALIZE GANTT", - fg="blue", - command=lambda: GanttPlotPopUp(config_path=self.config_path), - ) + gantt_visualization_btn = Button(visualization_frm, text="VISUALIZE GANTT", fg="blue", command=lambda: GanttPlotPopUp(config_path=self.config_path)) gantt_visualization_btn.grid(row=1, column=0, sticky=NW) - probability_visualization_btn = Button( - visualization_frm, - text="VISUALIZE PROBABILITIES", - fg="green", - command=lambda: VisualizeClassificationProbabilityPopUp( - config_path=self.config_path - ), - ) + probability_visualization_btn = Button(visualization_frm, text="VISUALIZE PROBABILITIES", fg="green", command=lambda: VisualizeClassificationProbabilityPopUp( config_path=self.config_path)) probability_visualization_btn.grid(row=2, column=0, sticky=NW) - path_visualization_btn = Button( - visualization_frm, - text="VISUALIZE PATHS", - fg="orange", - command=lambda: PathPlotPopUp(config_path=self.config_path), - ) + path_visualization_btn = Button(visualization_frm, text="VISUALIZE PATHS", fg="orange", command=lambda: PathPlotPopUp(config_path=self.config_path)) path_visualization_btn.grid(row=3, column=0, sticky=NW) - distance_visualization_btn = Button( - visualization_frm, - text="VISUALIZE DISTANCES", - fg="red", - command=lambda: DistancePlotterPopUp(config_path=self.config_path), - ) + distance_visualization_btn = Button(visualization_frm, text="VISUALIZE DISTANCES", fg="red", command=lambda: DistancePlotterPopUp(config_path=self.config_path)) distance_visualization_btn.grid(row=4, column=0, sticky=NW) - heatmap_clf_visualization_btn = Button( - visualization_frm, - text="VISUALIZE CLASSIFICATION HEATMAPS", - fg="pink", - command=lambda: HeatmapClfPopUp(config_path=self.config_path), - ) + heatmap_clf_visualization_btn = Button(visualization_frm, text="VISUALIZE CLASSIFICATION HEATMAPS", fg="pink", command=lambda: HeatmapClfPopUp(config_path=self.config_path)) heatmap_clf_visualization_btn.grid(row=5, column=0, sticky=NW) - data_plot_visualization_btn = Button( - visualization_frm, - text="VISUALIZE DATA PLOTS", - fg="purple", - command=lambda: DataPlotterPopUp(config_path=self.config_path), - ) + data_plot_visualization_btn = Button(visualization_frm, text="VISUALIZE DATA PLOTS", fg="purple", command=lambda: DataPlotterPopUp(config_path=self.config_path)) data_plot_visualization_btn.grid(row=6, column=0, sticky=NW) - clf_validation_btn = Button( - visualization_frm, - text="CLASSIFIER VALIDATION CLIPS", - fg="blue", - command=lambda: ClassifierValidationPopUp(config_path=self.config_path), - ) + clf_validation_btn = Button(visualization_frm, text="CLASSIFIER VALIDATION CLIPS", fg="blue", command=lambda: ClassifierValidationPopUp(config_path=self.config_path)) clf_validation_btn.grid(row=7, column=0, sticky=NW) - merge_frm = CreateLabelFrameWithIcon( - parent=tab10, - header="MERGE FRAMES", - icon_name=Keys.DOCUMENTATION.value, - icon_link=Links.CONCAT_VIDEOS.value, - ) - merge_frm_btn = Button( - merge_frm, - text="MERGE FRAMES", - fg="black", - command=lambda: ConcatenatorPopUp(config_path=self.config_path), - ) - plotlyInterface = CreateLabelFrameWithIcon( - parent=tab10, - header="PLOTLY / DASH", - icon_name=Keys.DOCUMENTATION.value, - icon_link=Links.PLOTLY.value, - ) - plotlyInterfaceTitles = [ - "Sklearn results", - "Time bin analyses", - "Probabilities", - "Severity analysis", - ] + merge_frm = CreateLabelFrameWithIcon(parent=tab10, header="MERGE FRAMES", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.CONCAT_VIDEOS.value) + merge_frm_btn = Button(merge_frm, text="MERGE FRAMES", fg="black", command=lambda: ConcatenatorPopUp(config_path=self.config_path)) + plotlyInterface = CreateLabelFrameWithIcon(parent=tab10, header="PLOTLY / DASH", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.PLOTLY.value) + plotlyInterfaceTitles = ["Sklearn results", "Time bin analyses", "Probabilities", "Severity analysis"] toIncludeVar = [] for i in range(len(plotlyInterfaceTitles) + 1): toIncludeVar.append(IntVar()) plotlyCheckbox = [0] * (len(plotlyInterfaceTitles) + 1) for i in range(len(plotlyInterfaceTitles)): - plotlyCheckbox[i] = Checkbutton( - plotlyInterface, text=plotlyInterfaceTitles[i], variable=toIncludeVar[i] - ) + plotlyCheckbox[i] = Checkbutton(plotlyInterface, text=plotlyInterfaceTitles[i], variable=toIncludeVar[i]) plotlyCheckbox[i].grid(row=i, sticky=W) - button_save_plotly_file = Button( - plotlyInterface, - text="Save SimBA / Plotly dataset", - command=lambda: self.generateSimBPlotlyFile(toIncludeVar), - ) - self.plotly_file = FileSelect( - plotlyInterface, - "SimBA Dashboard file (H5)", - title="Select SimBA/Plotly dataset (h5)", - ) - self.groups_file = FileSelect( - plotlyInterface, "SimBA Groups file (CSV)", title="Select groups file (csv" - ) - button_open_plotly_interface = Button( - plotlyInterface, - text="Open SimBA / Plotly dataset", - fg="black", - command=lambda: [self.open_plotly_interface("http://127.0.0.1:8050")], - ) + button_save_plotly_file = Button(plotlyInterface, text="Save SimBA / Plotly dataset", command=lambda: self.generateSimBPlotlyFile(toIncludeVar)) + self.plotly_file = FileSelect( plotlyInterface, "SimBA Dashboard file (H5)", title="Select SimBA/Plotly dataset (h5)") + self.groups_file = FileSelect(plotlyInterface, "SimBA Groups file (CSV)", title="Select groups file (csv)") + button_open_plotly_interface = Button(plotlyInterface, text="Open SimBA / Plotly dataset", fg="black", command=lambda: [self.open_plotly_interface("http://127.0.0.1:8050")]) # addons - lbl_addon = LabelFrame( - tab11, - text="SimBA Expansions", - pady=5, - padx=5, - font=Formats.LABELFRAME_HEADER_FORMAT.value, - fg="black", - ) - button_bel = Button( - lbl_addon, - text="Pup retrieval - Analysis Protocol 1", - fg="blue", - command=lambda: PupRetrievalPopUp(config_path=self.config_path), - ) - cue_light_analyser_btn = Button( - lbl_addon, - text="Cue light analysis", - fg="red", - command=lambda: CueLightAnalyzerMenu(config_path=self.config_path), - ) - anchored_roi_analysis_btn = Button( - lbl_addon, - text="Animal-anchored ROI analysis", - fg="orange", - command=lambda: BoundaryMenus(config_path=self.config_path), - ) + lbl_addon = LabelFrame(tab11, text="SimBA Expansions", pady=5, padx=5, font=Formats.LABELFRAME_HEADER_FORMAT.value, fg="black") + button_bel = Button(lbl_addon, text="Pup retrieval - Analysis Protocol 1", fg="blue", command=lambda: PupRetrievalPopUp(config_path=self.config_path)) + + cue_light_analyser_btn = Button(lbl_addon, text="Cue light analysis", fg="red", command=lambda: CueLightAnalyzerMenu(config_path=self.config_path)) + anchored_roi_analysis_btn = Button(lbl_addon, text="Animal-anchored ROI analysis", fg="orange", command=lambda: BoundaryMenus(config_path=self.config_path)) + - self.create_import_videos_menu(parent_frm=import_frm, idx_row=0, idx_column=0) - self.create_import_pose_menu(parent_frm=import_frm, idx_row=1, idx_column=0) + ImportVideosFrame(parent_frm=import_frm, config_path=config_path, idx_row=0, idx_column=0) + ImportPoseFrame(parent_frm=import_frm, idx_row=1, idx_column=0, config_path=config_path) further_methods_frm.grid(row=0, column=1, sticky=NW, pady=5, padx=5) extract_frm_btn.grid(row=1, column=0, sticky=NW) import_frm_dir_btn.grid(row=2, column=0, sticky=NW) diff --git a/simba/data_processors/interpolate.py b/simba/data_processors/interpolate.py new file mode 100644 index 000000000..b9ff0e26a --- /dev/null +++ b/simba/data_processors/interpolate.py @@ -0,0 +1,114 @@ +import pandas as pd +pd.options.mode.chained_assignment = None +import os +from typing import Optional, Union, List +from copy import deepcopy +import numpy as np + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +from simba.mixins.config_reader import ConfigReader +from simba.utils.checks import (check_str, check_valid_lst, check_file_exist_and_readable) +from simba.utils.enums import TagNames +from simba.utils.errors import DataHeaderError, InvalidInputError +from simba.utils.printing import SimbaTimer, log_event, stdout_success +from simba.utils.read_write import (find_files_of_filetypes_in_directory, get_fn_ext, read_df, write_df, copy_files_to_directory) +from simba.utils.data import animal_interpolator, body_part_interpolator + + +class Interpolate(ConfigReader): + """ + Interpolate missing body-parts in pose-estimation data. "Missing" is defined as either (i) when a single body-parts is None, or + when all body-parts belonging to an animal are identical (i.e., the same 2D coordinate or all None). + + .. image:: _static/img/interpolation_comparison.png + :width: 500 + :align: center + + .. note:: + `Interpolation tutorial `__. + + .. importants:: + The interpolated data overwrites the original data on disk. If the original data is required, pass ``copy_originals = True`` to save a copy of the original data. + + + :param Union[str, os.PathLike] config_path: path to SimBA project config file in Configparser format. + :param Union[str, os.PathLike] data_path: Path to a directory, path to a file, or a list of file paths to files with pose-estimation data in CSV or parquet format. + :param Optional[Literal['body-parts', 'animals']] type: If 'animals', then interpolation is performed when all body-parts belonging to an animal are identical (i.e., the same 2D coordinate or all None). If 'body-parts` then all body-parts that are None will be interpolated. Default: body-parts. + :param Optional[Literal['nearest', 'linear', 'quadratic']] method: If 'animals', then interpolation is performed when all body-parts belonging to an animal are identical (i.e., the same 2D coordinate or all None). If 'body-parts` then all body-parts that are None will be interpolated. Default: body-parts. + :param Optional[bool] multi_index_df_headers: If truth-like, then the input data is anticipated to have multiple header columns, and output columns will have multiple header columns. Default: False. + :param Optional[bool] copy_originals: If truth-like, then the pre-interpolated, original data, will be bo stored in a subdirectory of the original data. The subdirectory is named according to the type of interpolation and datetime of the operation. + + :example: + >>> interpolator = Interpolate(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', data_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/input_csv/test', type='body-parts', multi_index_df_headers=True, copy_originals=True) + >>> interpolator.run() + + """ + def __init__(self, + config_path: Union[str, os.PathLike], + data_path: Union[str, os.PathLike, List[Union[str, os.PathLike]]], + type: Optional[Literal['body-parts', 'animals']] = 'body-parts', + method: Optional[Literal['nearest', 'linear', 'quadratic']] = 'nearest', + multi_index_df_headers: Optional[bool] = False, + copy_originals: Optional[bool] = False) -> None: + + log_event(logger_name=str(self.__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals())) + ConfigReader.__init__(self, config_path=config_path, read_video_info=False) + check_str(name=f'{self.__class__.__name__} type', value=type.lower(), options=('body-parts', 'animals')) + check_str(name=f'{self.__class__.__name__} method', value=method.lower(), options=('nearest', 'linear', 'quadratic')) + if isinstance(data_path, list): + check_valid_lst(data=data_path, source=self.__class__.__name__, valid_dtypes=(str,)) + for i in data_path: check_file_exist_and_readable(file_path=i) + self.file_paths = deepcopy(data_path) + elif os.path.isdir(data_path): + self.file_paths = find_files_of_filetypes_in_directory(directory=data_path, extensions=[f'.{self.file_type}'], raise_error=True) + elif os.path.isfile(data_path): + check_file_exist_and_readable(file_path=data_path) + self.file_paths = [data_path] + else: + raise InvalidInputError(msg=f'{data_path} is not a valid data directory, or a valid file path, or a valid list of file paths', source=self.__class__.__name__) + if copy_originals: + self.originals_dir = os.path.join(os.path.dirname(self.file_paths[0]), f"Pre_{method}_{type}_interpolation_{self.datetime}") + os.makedirs(self.originals_dir) + self.type, self.method, self.multi_index_df_headers, self.copy_originals = type.lower(), method.lower(), multi_index_df_headers, copy_originals + + def __insert_multiindex_header(self, df: pd.DataFrame): + multi_idx_header = [] + for i in range(len(df.columns)): + multi_idx_header.append(("IMPORTED_POSE", "IMPORTED_POSE", list(df.columns)[i])) + df.columns = pd.MultiIndex.from_tuples(multi_idx_header) + return df + + def run(self): + print(f'Running interpolation on {len(self.file_paths)} data files...') + for file_cnt, file_path in enumerate(self.file_paths): + video_timer = SimbaTimer(start=True) + _, self.video_name, _ = get_fn_ext(filepath=file_path) + df = read_df(file_path=file_path, file_type=self.file_type, check_multiindex=self.multi_index_df_headers) + if self.multi_index_df_headers: + if len(df.columns) != len(self.bp_headers): + raise DataHeaderError( msg=f"The file {file_path} contains {len(df.columns)} columns, but your SimBA project expects {len(self.bp_headers)} columns representing {int(len(self.bp_headers) / 3)} body-parts (x, y, p). Check that the {self.body_parts_path} lists the correct body-parts associated with the project", source=self.__class__.__name__) + df.columns = self.bp_headers + df = df.apply(pd.to_numeric, errors="coerce").fillna(0) + df[df < 0] = 0 + if self.type == 'animals': + df = animal_interpolator(df=df, animal_bp_dict=self.animal_bp_dict, source=file_path, method=self.method) + else: + df = body_part_interpolator(df=df, animal_bp_dict=self.animal_bp_dict, source=file_path, method=self.method) + if self.multi_index_df_headers: + df = self.__insert_multiindex_header(df=df) + if self.copy_originals: + copy_files_to_directory(file_paths=[file_path], dir=self.originals_dir) + write_df(df=df.astype(np.int64), file_type=self.file_type, save_path=file_path, multi_idx_header=self.multi_index_df_headers) + video_timer.stop_timer() + print(f"Video {self.video_name} interpolated (elapsed time {video_timer.elapsed_time_str}) ...") + self.timer.stop_timer() + if self.copy_originals: + msg = f"{len(self.file_paths)} data file(s) interpolated using {self.type} {self.method} methods. Originals saved in {self.originals_dir} directory." + else: + msg = f"{len(self.file_paths)} data file(s) interpolated using {self.type} {self.method} methods." + stdout_success(msg=msg, elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__) + diff --git a/simba/data_processors/interpolation_smoothing.py b/simba/data_processors/interpolation_smoothing.py index 1507676b0..5db5d725a 100644 --- a/simba/data_processors/interpolation_smoothing.py +++ b/simba/data_processors/interpolation_smoothing.py @@ -20,8 +20,7 @@ from simba.utils.enums import Methods, TagNames from simba.utils.errors import DataHeaderError, NoFilesFoundError from simba.utils.printing import SimbaTimer, log_event, stdout_success -from simba.utils.read_write import (find_files_of_filetypes_in_directory, - find_video_of_file, get_fn_ext, +from simba.utils.read_write import (find_files_of_filetypes_in_directory, find_video_of_file, get_fn_ext, get_video_meta_data, read_df, write_df) diff --git a/simba/data_processors/smoothing.py b/simba/data_processors/smoothing.py new file mode 100644 index 000000000..7fcd9d2a5 --- /dev/null +++ b/simba/data_processors/smoothing.py @@ -0,0 +1,121 @@ +__author__ = "Simon Nilsson" + +import os +from typing import Union, List, Optional +from copy import deepcopy + +import pandas as pd + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +from simba.mixins.config_reader import ConfigReader +from simba.utils.checks import (check_str, check_int, check_valid_lst, check_file_exist_and_readable) +from simba.utils.enums import TagNames +from simba.utils.errors import InvalidInputError, NoFilesFoundError +from simba.utils.printing import SimbaTimer, log_event, stdout_success +from simba.utils.read_write import (find_files_of_filetypes_in_directory, find_video_of_file, get_fn_ext, read_video_info, get_video_meta_data, read_df, write_df, copy_files_to_directory) +from simba.utils.data import savgol_smoother, df_smoother + + +class Smoothing(ConfigReader): + """ + Smooth pose-estimation data according to user-defined method. + + .. image:: _static/img/smoothing.gif + :width: 600 + :align: center + + .. note:: + `Smoothing tutorial `__. + + .. importants:: + The wmoothened data overwrites the original data on disk. If the original data is required, pass ``copy_originals = True`` to save a copy of the original data. + + :param Union[str, os.PathLike] config_path: path to SimBA project config file in Configparser format. + :param Union[str, os.PathLike, List[Union[str, os.PathLike]]] data_path: Path to directory containing pose-estimation data, to a file containing pose-estimation data, or a list of paths containing pose-estimation data. + :param int time_window: Rolling time window in millisecond to use when smoothing. Larger time-windows and greater smoothing. + :param Optional[Literal["gaussian", "savitzky-golay"]] method: Type of smoothing_method. OPTIONS: ``gaussian``, ``savitzky-golay``. Default `gaussian`. + :param bool multi_index_df_headers: If True, the incoming data is multi-index columns dataframes. Default: False. + :param bool copy_originals: If truth-like, then the pre-smoothened, original data, will be bo stored in a subdirectory of the original data. The subdirectory is named according to the type of smoothing method and datetime of the operation. + + :references: + .. [1] `Video expected putput `__. + + :examples: + >>> smoother = Smoothing(data_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/input_csv/Together_1.csv', config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', method='Savitzky-Golay', time_window=500, multi_index_df_headers=True, copy_originals=True) + >>> smoother.run() + """ + + def __init__(self, + config_path: Union[str, os.PathLike], + data_path: Union[str, os.PathLike, List[Union[str, os.PathLike]]], + time_window: int, + method: Optional[Literal["gaussian", "savitzky-golay"]] = 'Savitzky-Golay', + multi_index_df_headers: Optional[bool] = False, + copy_originals: Optional[bool] = False) -> None: + + ConfigReader.__init__(self, config_path=config_path, read_video_info=False) + log_event(logger_name=str(self.__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals())) + if isinstance(data_path, list): + check_valid_lst(data=data_path, source=self.__class__.__name__, valid_dtypes=(str,)) + for i in data_path: check_file_exist_and_readable(file_path=i) + self.file_paths = deepcopy(data_path) + elif os.path.isdir(data_path): + self.file_paths = find_files_of_filetypes_in_directory(directory=data_path, extensions=[f'.{self.file_type}'], raise_error=True) + elif os.path.isfile(data_path): + check_file_exist_and_readable(file_path=data_path) + self.file_paths = [data_path] + else: + raise InvalidInputError(msg=f'{data_path} is not a valid data directory, or a valid file path, or a valid list of file paths', source=self.__class__.__name__) + check_int(value=time_window, min_value=1, name=f'{self.__class__.__name__} time_window') + check_str(name=f'{self.__class__.__name__} method', value=method.lower(), options=("gaussian", "savitzky-golay")) + if copy_originals: + self.originals_dir = os.path.join(os.path.dirname(self.file_paths[0]), f"Pre_{method}_{time_window}_smoothing_{self.datetime}") + os.makedirs(self.originals_dir) + self.multi_index_df_headers, self.method, self.time_window, self.copy_originals = multi_index_df_headers, method.lower(), time_window, copy_originals + + def __insert_multiindex_header(self, df: pd.DataFrame): + multi_idx_header = [] + for i in range(len(df.columns)): + multi_idx_header.append(("IMPORTED_POSE", "IMPORTED_POSE", list(df.columns)[i])) + df.columns = pd.MultiIndex.from_tuples(multi_idx_header) + return df + + def run(self): + print(f'Running smoothing on {len(self.file_paths)} data files...') + for file_cnt, file_path in enumerate(self.file_paths): + df = read_df(file_path=file_path, file_type=self.file_type, check_multiindex=True) + video_timer = SimbaTimer(start=True) + _, video_name, _ = get_fn_ext(filepath=file_path) + video_path = find_video_of_file(video_dir=self.video_dir, filename=video_name, raise_error=False, warning=False) + if video_path is None: + if not os.path.isfile(self.video_info_path): + raise NoFilesFoundError(msg=f"To perform smoothing, SimBA needs to read the video FPS. SimBA could not find the video {video_name} in represented in the {self.video_dir} directory or in {self.video_info_path} file. Please import the video and/or include it in the video_logs.csv file so SimBA can know the video FPS", source=self.__class__.__name__) + else: + self.video_info_df = self.read_video_info_csv(file_path=self.video_info_path) + video_info = read_video_info(vid_info_df=self.video_info_df,video_name=video_name, raise_error=False) + if video_info is None: + raise NoFilesFoundError(msg=f"To perform smoothing, SimBA needs to read the video FPS. SimBA could not find the video {video_name} in represented in the {self.video_dir} directory or in {self.video_info_path} file. Please import the video and/or include it in the video_logs.csv file so SimBA can know the video FPS", source=self.__class__.__name__) + fps = video_info[2] + else: + fps = get_video_meta_data(video_path=video_path)['fps'] + if self.method == 'savitzky-golay': + df = savgol_smoother(data=df, fps=fps, time_window=self.time_window, source=video_name) + else: + df = df_smoother(data=df, fps=fps, time_window=self.time_window, source=video_name, method='gaussian') + if self.multi_index_df_headers: + df = self.__insert_multiindex_header(df=df) + if self.copy_originals: + copy_files_to_directory(file_paths=[file_path], dir=self.originals_dir) + write_df(df=df, file_type=self.file_type, save_path=file_path, multi_idx_header=self.multi_index_df_headers) + video_timer.stop_timer() + print(f"Video {video_name} smoothed ({self.method}: {str(self.time_window)}ms) (elapsed time {video_timer.elapsed_time_str})...") + self.timer.stop_timer() + if self.copy_originals: + msg = f"{len(self.file_paths)} data file(s) smoothened using {self.method} method and {self.time_window} time-window. Originals saved in {self.originals_dir} directory." + else: + msg = f"{len(self.file_paths)} data file(s) smoothened using {self.method} method and {self.time_window} time-window." + stdout_success(msg=msg, elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__) \ No newline at end of file diff --git a/simba/mixins/config_reader.py b/simba/mixins/config_reader.py index 3120ad021..255904ffd 100644 --- a/simba/mixins/config_reader.py +++ b/simba/mixins/config_reader.py @@ -20,7 +20,7 @@ except: from typing_extensions import Literal -from simba.utils.checks import (check_file_exist_and_readable, check_float, +from simba.utils.checks import (check_file_exist_and_readable, check_if_dir_exists, check_if_filepath_list_is_empty) from simba.utils.enums import ConfigKey, Defaults, Dtypes, Keys, Paths diff --git a/simba/mixins/pop_up_mixin.py b/simba/mixins/pop_up_mixin.py index dbc881c3a..9150a55ee 100644 --- a/simba/mixins/pop_up_mixin.py +++ b/simba/mixins/pop_up_mixin.py @@ -2,35 +2,19 @@ import os from tkinter import * -from tkinter import messagebox -from typing import Callable, Dict, List, Optional, Tuple, Union +from tkinter import ttk +from typing import Callable, Dict, List, Optional, Tuple, Union, Any import PIL.Image from PIL import ImageTk from simba.mixins.config_reader import ConfigReader -from simba.pose_importers import trk_importer -from simba.pose_importers.dlc_importer_csv import ( - import_multiple_dlc_tracking_csv_file, import_single_dlc_tracking_csv_file) -from simba.pose_importers.import_mars import MarsImporter -from simba.pose_importers.madlc_importer import MADLCImporterH5 -from simba.pose_importers.read_DANNCE_mat import (import_DANNCE_file, - import_DANNCE_folder) -from simba.pose_importers.sleap_csv_importer import SLEAPImporterCSV -from simba.pose_importers.sleap_h5_importer import SLEAPImporterH5 -from simba.pose_importers.sleap_slp_importer import SLEAPImporterSLP -from simba.ui.tkinter_functions import (DropDownMenu, Entry_Box, FileSelect, - FolderSelect, hxtScrollbar) -from simba.utils.checks import (check_file_exist_and_readable, check_float, - check_if_dir_exists, check_int, check_str) -from simba.utils.enums import ConfigKey, Formats, Options +from simba.ui.tkinter_functions import (DropDownMenu, Entry_Box, FileSelect, hxtScrollbar) +from simba.utils.checks import (check_float, check_int, check_valid_lst, check_instance) +from simba.utils.enums import Formats, Options from simba.utils.errors import CountError, NoFilesFoundError -from simba.utils.lookups import (get_color_dict, get_icons_paths, - get_named_colors) -from simba.utils.read_write import (copy_multiple_videos_to_project, - copy_single_video_to_project, - find_core_cnt, read_config_entry, - read_config_file) +from simba.utils.lookups import (get_color_dict, get_icons_paths, get_named_colors) +from simba.utils.read_write import find_core_cnt class PopUpMixin(object): @@ -44,19 +28,12 @@ class PopUpMixin(object): :param bool main_scrollbar: If True, the pop-up window is scrollable. """ - def __init__( - self, - title: str, - config_path: Optional[str] = None, - main_scrollbar: Optional[bool] = True, - size: Tuple[int, int] = (960, 720), - ): + def __init__(self, + title: str, + config_path: Optional[str] = None, + main_scrollbar: Optional[bool] = True, + size: Tuple[int, int] = (960, 720)): - # self.main_frm = Toplevel() - # self.main_frm.minsize(size[0], size[1]) - # self.main_frm.wm_title(title) - # self.main_frm.lift() - # self.main_frm = Canvas(hxtScrollbar(self.main_frm)) self.root = Toplevel() self.root.minsize(size[0], size[1]) self.root.wm_title(title) @@ -76,22 +53,14 @@ def __init__( self.cpu_cnt, _ = find_core_cnt() self.menu_icons = get_icons_paths() for k in self.menu_icons.keys(): - self.menu_icons[k]["img"] = ImageTk.PhotoImage( - image=PIL.Image.open( - os.path.join( - os.path.dirname(__file__), self.menu_icons[k]["icon_path"] - ) - ) - ) + self.menu_icons[k]["img"] = ImageTk.PhotoImage(image=PIL.Image.open(os.path.join(os.path.dirname(__file__), self.menu_icons[k]["icon_path"]))) if config_path: ConfigReader.__init__(self, config_path=config_path, read_video_info=False) - def create_clf_checkboxes( - self, - main_frm: Frame, - clfs: List[str], - title: str = "SELECT CLASSIFIER ANNOTATIONS", - ): + def create_clf_checkboxes(self, + main_frm: Frame, + clfs: List[str], + title: str = "SELECT CLASSIFIER ANNOTATIONS"): """ Creates a labelframe with one checkbox per classifier, and inserts the labelframe into the bottom of the pop-up window. @@ -100,80 +69,64 @@ def create_clf_checkboxes( """ - self.choose_clf_frm = LabelFrame( - self.main_frm, text=title, font=Formats.LABELFRAME_HEADER_FORMAT.value - ) + self.choose_clf_frm = LabelFrame(self.main_frm, text=title, font=Formats.LABELFRAME_HEADER_FORMAT.value) self.clf_selections = {} for clf_cnt, clf in enumerate(clfs): self.clf_selections[clf] = BooleanVar(value=False) - self.calculate_distance_moved_cb = Checkbutton( - self.choose_clf_frm, text=clf, variable=self.clf_selections[clf] - ) + self.calculate_distance_moved_cb = Checkbutton(self.choose_clf_frm, text=clf, variable=self.clf_selections[clf]) self.calculate_distance_moved_cb.grid(row=clf_cnt, column=0, sticky=NW) self.choose_clf_frm.grid(row=self.children_cnt_main(), column=0, sticky=NW) - def create_cb_frame( - self, - main_frm: Frame, - cb_titles: List[str], - frm_title: str, - command: Optional[object] = None, - ) -> Dict[str, BooleanVar]: + def create_cb_frame(self, + cb_titles: List[str], + main_frm: Optional[Union[Frame, Canvas, LabelFrame, ttk.Frame]] = None, + frm_title: Optional[str] = '', + idx_row: Optional[int] = -1, + command: Optional[Callable[[str], Any]] = None) -> Dict[str, BooleanVar]: """ - Creates a labelframe with one checkbox per classifier, and inserts the labelframe into the bottom of the pop-up window. + Creates a labelframe with checkboxes and inserts the labelframe into a window. + + .. image:: _static/img/create_cb_frame.png + :width: 200 + :align: center + + .. note:: + One checkbox will be created per ``cb_titles``. The checkboxes will be labeled according to the ``cb_titles``. + If checking/un-checking the box should have some effect, pass that function as ``command`` which takes the name of the checked/unchecked box. - :param Frame main_frm: The tkinter pop-up window. + :param Optional[Union[Frame, Canvas, LabelFrame, ttk.Frame]] main_frm: The pop-up window to insert the labelframe into. :param List[str] cb_titles: List of strings representing the names of the checkboxes. - :param str frm_title: Title of the frame. + :param Optional[str] frm_title: Title of the frame. + :param Optional[int] idx_row: The location in main_frm to create the LabelFrame. If -1, then at the bottom. + :param Optional[Callable[[str], Any]] frm_title: Optional function callable associated with checking/unchecking the checkboxes. :return Dict[str, BooleanVar]: Dictionary holding the ``cb_titles`` as keys and the BooleanVar representing if the checkbox is ticked or not. + + :example: + >>> PopUpMixin.create_cb_frame(cb_titles=['Attack', 'Sniffing', 'Rearing'], frm_title='My classifiers') """ - cb_frm = LabelFrame( - main_frm, text=frm_title, font=Formats.LABELFRAME_HEADER_FORMAT.value - ) + check_valid_lst(data=cb_titles, source=f'{PopUpMixin.create_cb_frame.__name__} cb_titles', valid_dtypes=(str,), min_len=1) + check_int(name=f'{PopUpMixin.create_cb_frame.__name__} idx_row', value=idx_row, min_value=-1) + + if main_frm is not None: + check_instance(source=f'{PopUpMixin.create_cb_frame.__name__} parent_frm', accepted_types=(Frame, Canvas, LabelFrame, ttk.Frame), instance=main_frm) + else: + main_frm = Toplevel(); main_frm.minsize(960, 720); main_frm.lift() + if idx_row == -1: + idx_row = int(len(list(main_frm.children.keys()))) + cb_frm = LabelFrame(main_frm, text=frm_title, font=Formats.LABELFRAME_HEADER_FORMAT.value) cb_dict = {} for cnt, title in enumerate(cb_titles): cb_dict[title] = BooleanVar(value=False) if command is not None: - if isinstance(command, object): - cb = Checkbutton( - cb_frm, - text=title, - variable=cb_dict[title], - command=lambda k=cb_titles[cnt]: command(k), - ) + cb = Checkbutton(cb_frm, text=title, variable=cb_dict[title], command=lambda k=cb_titles[cnt]: command(k)) else: cb = Checkbutton(cb_frm, text=title, variable=cb_dict[title]) cb.grid(row=cnt, column=0, sticky=NW) - cb_frm.grid(row=self.children_cnt_main(), column=0, sticky=NW) + cb_frm.grid(row=idx_row, column=0, sticky=NW) + # main_frm.mainloop() return cb_dict - def frame_of_radiobuttons( - self, - main_frm: Frame, - titles: List[str], - frm_title: str, - command: Optional[object] = None, - default_idx: Optional[int] = 0, - ): - - selection_var = StringVar() - radiobuttons_frm = LabelFrame( - main_frm, text=frm_title, font=Formats.LABELFRAME_HEADER_FORMAT.value - ) - for cnt, title in enumerate(titles): - radio_button = Radiobutton( - main_frm, - text=titles[cnt], - variable=selection_var, - value=titles[cnt], - command=command, - ) - radio_button.grid(row=cnt, column=0, sticky=NW) - radiobuttons_frm.grid(row=self.children_cnt_main(), column=0, sticky=NW) - selection_var.set(value=titles[default_idx]) - return selection_var - def place_frm_at_top_right(self, frm: Toplevel): """ Place a TopLevel tkinter pop-up at the top right of the monitor. Note: call before putting scrollbars or converting to Canvas. @@ -183,93 +136,50 @@ def place_frm_at_top_right(self, frm: Toplevel): x_position = screen_width - window_width frm.geometry(f"{window_width}x{window_height}+{x_position}+{0}") - def create_dropdown_frame( - self, - main_frm: Frame, - drop_down_titles: List[str], - drop_down_options: List[str], - frm_title: str, - ) -> Dict[str, DropDownMenu]: + def create_dropdown_frame(self, + drop_down_titles: List[str], + drop_down_options: List[str], + frm_title: Optional[str] = '', + idx_row: Optional[int] = -1, + main_frm: Optional[Union[Frame, Canvas, LabelFrame, ttk.Frame]] = None) -> Dict[str, DropDownMenu]: + """ - Creates a labelframe with dropdown menus and inserts it at the bottom of the pop-up window. + Creates a labelframe with dropdowns. + + .. image:: _static/img/create_dropdown_frame.png + :width: 300 + :align: center - :param Frame main_frm: The tkinter pop-up window. - :param List[str] drop_down_titles: The dropdown menu names - :param List[str] drop_down_options: The options in each dropdown. All dropdowns must have the same options. - :param str frm_title: Title of the frame. - :return Dict[str, BooleanVar]: Dictionary holding the ``drop_down_titles`` and the drop-down menus as values. + :param Optional[Union[Frame, Canvas, LabelFrame, ttk.Frame]] main_frm: The pop-up window to insert the labelframe into. If None, one will be created. + :param List[str] drop_down_titles: The titles of the dropdown menus. + :param List[str] drop_down_options: The options in each dropdown. Note: All dropdowns must have the same options. + :param Optional[str] frm_title: Title of the frame. + :return Dict[str, BooleanVar]: Dictionary holding the ``drop_down_titles`` as keys and the drop-down menus as values. + :example: + >>> PopUpMixin.create_dropdown_frame(drop_down_titles=['Dropdown 1', 'Dropdown 2', 'Dropdown 2'], drop_down_options=['Option 1', 'Option 2'], frm_title='My dropdown frame') """ - dropdown_frm = LabelFrame( - main_frm, text=frm_title, font=Formats.LABELFRAME_HEADER_FORMAT.value - ) + check_valid_lst(data=drop_down_titles, source=f'{PopUpMixin.create_dropdown_frame.__name__} drop_down_titles', + valid_dtypes=(str,), min_len=1) + check_valid_lst(data=drop_down_options, source=f'{PopUpMixin.create_dropdown_frame.__name__} drop_down_options', valid_dtypes=(str,), min_len=2) + check_int(name=f'{PopUpMixin.create_cb_frame.__name__} idx_row', value=idx_row, min_value=-1) + if main_frm is not None: + check_instance(source=f'{PopUpMixin.create_cb_frame.__name__} parent_frm', accepted_types=(Frame, Canvas, LabelFrame, ttk.Frame), instance=main_frm) + else: + main_frm = Toplevel(); main_frm.minsize(960, 720); main_frm.lift() + if idx_row == -1: + idx_row = int(len(list(main_frm.children.keys()))) + dropdown_frm = LabelFrame(main_frm, text=frm_title, font=Formats.LABELFRAME_HEADER_FORMAT.value) dropdown_dict = {} for cnt, title in enumerate(drop_down_titles): - dropdown_dict[title] = DropDownMenu( - dropdown_frm, title, drop_down_options, "35" - ) + dropdown_dict[title] = DropDownMenu(dropdown_frm, title, drop_down_options, "35") dropdown_dict[title].setChoices(drop_down_options[0]) dropdown_dict[title].grid(row=cnt, column=0, sticky=NW) - dropdown_frm.grid(row=self.children_cnt_main(), column=0, sticky=NW) + dropdown_frm.grid(row=idx_row, column=0, sticky=NW) + # main_frm.mainloop() return dropdown_dict - def create_choose_animal_cnt_dropdown(self): - if hasattr(self, "animal_cnt_frm"): - self.animal_cnt_frm.destroy() - animal_cnt_options = set(range(1, self.project_animal_cnt + 1)) - self.animal_cnt_frm = LabelFrame( - self.main_frm, - text="SELECT NUMBER OF ANIMALS", - font=Formats.LABELFRAME_HEADER_FORMAT.value, - ) - self.animal_cnt_dropdown = DropDownMenu( - self.animal_cnt_frm, "# of animals", animal_cnt_options, "12" - ) - self.animal_cnt_dropdown.setChoices(max(animal_cnt_options)) - self.animal_cnt_confirm_btn = Button( - self.animal_cnt_frm, - text="Confirm", - command=lambda: self.update_body_parts(), - ) - self.animal_cnt_frm.grid(row=self.children_cnt_main(), sticky=NW) - self.animal_cnt_dropdown.grid(row=self.children_cnt_main(), column=1, sticky=NW) - self.animal_cnt_confirm_btn.grid( - row=self.children_cnt_main(), column=2, sticky=NW - ) - self.create_choose_body_parts_frm() - self.update_body_parts() - - def create_choose_body_parts_frm(self): - if hasattr(self, "body_part_frm"): - self.body_part_frm.destroy() - self.body_parts_dropdown_dict = {} - self.body_part_frm = LabelFrame( - self.main_frm, - text="CHOOSE ANIMAL BODY-PARTS", - font=Formats.LABELFRAME_HEADER_FORMAT.value, - name="choose animal body-parts", - ) - self.body_part_frm.grid(row=self.children_cnt_main(), sticky=NW) - - def update_body_parts(self): - for child in self.body_part_frm.winfo_children(): - child.destroy() - for animal_cnt in range(int(self.animal_cnt_dropdown.getChoices())): - animal_name = list(self.animal_bp_dict.keys())[animal_cnt] - self.body_parts_dropdown_dict[animal_name] = DropDownMenu( - self.body_part_frm, - f"{animal_name} body-part:", - self.body_parts_lst, - "25", - ) - self.body_parts_dropdown_dict[animal_name].grid( - row=animal_cnt, column=1, sticky=NW - ) - self.body_parts_dropdown_dict[animal_name].setChoices( - self.body_parts_lst[animal_cnt] - ) - def create_time_bin_entry(self): if hasattr(self, "time_bin_frm"): self.time_bin_frm.destroy() @@ -282,9 +192,7 @@ def create_time_bin_entry(self): self.time_bin_entrybox.grid(row=0, column=0, sticky=NW) self.time_bin_frm.grid(row=self.children_cnt_main(), column=0, sticky=NW) - def create_run_frm( - self, - run_function: Callable, + def create_run_frm(self,run_function: Callable, title: Optional[str] = "RUN", btn_txt_clr: Optional[str] = "black", ) -> None: @@ -310,9 +218,7 @@ def create_run_frm( self.run_frm.grid(row=self.children_cnt_main() + 1, column=0, sticky=NW) self.run_btn.grid(row=0, column=0, sticky=NW) - def create_choose_number_of_body_parts_frm( - self, project_body_parts: List[str], run_function: object - ): + def create_choose_number_of_body_parts_frm(self, project_body_parts: List[str], run_function: object): """ Many menus depend on how many animals the user choose to compute metrics for. Thus, we need to populate the menus dynamically. This function creates a single drop-down menu where the user select the number of animals the @@ -510,9 +416,7 @@ def enable_dropdown_from_checkbox( for menu in dropdown_menus: menu.disable() - def create_entry_boxes_from_entrybox( - self, count: int, parent: Frame, current_entries: list - ): + def create_entry_boxes_from_entrybox(self, count: int, parent: Frame, current_entries: list): check_int(name="CLASSIFIER COUNT", value=count, min_value=1) for entry in current_entries: entry.destroy() @@ -572,627 +476,6 @@ def enable_entrybox_from_checkbox( for box in entry_boxes: box.set_state("disable") - def create_import_pose_menu( - self, parent_frm: Frame, idx_row: int = 0, idx_column: int = 0 - ): - - def run_call( - data_type: str, - interpolation: str, - smoothing: str, - smoothing_window: str, - animal_names: dict, - data_path: str, - tracking_data_type: str or None = None, - ): - - smooth_settings = {} - smooth_settings["Method"] = smoothing - smooth_settings["Parameters"] = {} - smooth_settings["Parameters"]["Time_window"] = smoothing_window - - if smooth_settings["Method"] != "None": - check_int( - name="SMOOTHING TIME WINDOW", value=smoothing_window, min_value=1 - ) - - if self.animal_name_entry_boxes is None: - raise CountError( - msg="Select animal number and animal names BEFORE importing data.", - source=self.__class__.__name__, - ) - - animal_ids = [] - if len(list(animal_names.items())) == 1: - animal_ids.append("Animal_1") - else: - for animal_cnt, animal_entry_box in animal_names.items(): - check_str( - name=f"ANIMAL {str(animal_cnt)} NAME", - value=animal_entry_box.entry_get, - allow_blank=False, - ) - animal_ids.append(animal_entry_box.entry_get) - - self.config = read_config_file(config_path=self.config_path) - self.config.set( - ConfigKey.MULTI_ANIMAL_ID_SETTING.value, - ConfigKey.MULTI_ANIMAL_IDS.value, - ",".join(animal_ids), - ) - self.update_config() - - if data_type == "H5 (multi-animal DLC)": - dlc_multi_animal_importer = MADLCImporterH5( - config_path=self.config_path, - data_folder=data_path, - file_type=tracking_data_type, - id_lst=animal_ids, - interpolation_settings=interpolation, - smoothing_settings=smooth_settings, - ) - dlc_multi_animal_importer.run() - - if data_type == "SLP (SLEAP)": - sleap_importer = SLEAPImporterSLP( - project_path=self.config_path, - data_folder=data_path, - id_lst=animal_ids, - interpolation_settings=interpolation, - smoothing_settings=smooth_settings, - ) - sleap_importer.run() - - if data_type == "TRK (multi-animal APT)": - try: - trk_importer( - self.config_path, - data_path, - animal_ids, - interpolation, - smooth_settings, - ) - except Exception as e: - messagebox.showerror("Error", str(e)) - - if data_type == "CSV (SLEAP)": - sleap_csv_importer = SLEAPImporterCSV( - config_path=self.config_path, - data_folder=data_path, - id_lst=animal_ids, - interpolation_settings=interpolation, - smoothing_settings=smooth_settings, - ) - sleap_csv_importer.run() - - if data_type == "H5 (SLEAP)": - sleap_h5_importer = SLEAPImporterH5( - config_path=self.config_path, - data_folder=data_path, - id_lst=animal_ids, - interpolation_settings=interpolation, - smoothing_settings=smooth_settings, - ) - sleap_h5_importer.run() - - def import_menu(data_type_choice: str): - if hasattr(self, "choice_frm"): - self.choice_frm.destroy() - self.choice_frm = Frame(self.import_tracking_frm) - self.animal_name_entry_boxes = None - - self.interpolation_frm = LabelFrame( - self.choice_frm, text="INTERPOLATION METHOD", pady=5, padx=5 - ) - self.interpolation_dropdown = DropDownMenu( - self.interpolation_frm, - "Interpolation method: ", - Options.INTERPOLATION_OPTIONS_W_NONE.value, - "25", - ) - self.interpolation_dropdown.setChoices("None") - self.interpolation_frm.grid(row=0, column=0, sticky=NW) - self.interpolation_dropdown.grid(row=0, column=0, sticky=NW) - - self.smoothing_frm = LabelFrame( - self.choice_frm, text="SMOOTHING METHOD", pady=5, padx=5 - ) - self.smoothing_dropdown = DropDownMenu( - self.smoothing_frm, - "Smoothing", - Options.SMOOTHING_OPTIONS_W_NONE.value, - "25", - com=self.show_smoothing_entry_box_from_dropdown, - ) - self.smoothing_dropdown.setChoices("None") - self.smoothing_time_eb = Entry_Box( - self.smoothing_frm, - "Period (ms):", - labelwidth="12", - width=10, - validation="numeric", - ) - self.smoothing_frm.grid(row=1, column=0, sticky=NW) - self.smoothing_dropdown.grid(row=0, column=0, sticky=NW) - - if data_type_choice in [ - "CSV (DLC/DeepPoseKit)", - "MAT (DANNCE 3D)", - "JSON (BENTO)", - ]: - if data_type_choice == "CSV (DLC/DeepPoseKit)": - self.import_directory_frm = LabelFrame( - self.choice_frm, text="IMPORT DLC CSV DIRECTORY", pady=5, padx=5 - ) - self.import_directory_select = FolderSelect( - self.import_directory_frm, "Input DIRECTORY:", lblwidth=25 - ) - self.import_dir_btn = Button( - self.import_directory_frm, - fg="blue", - text="Import DIRECTORY to SimBA project", - command=lambda: import_multiple_dlc_tracking_csv_file( - config_path=self.config_path, - interpolation_setting=self.interpolation_dropdown.getChoices(), - smoothing_setting=self.smoothing_dropdown.getChoices(), - smoothing_time=self.smoothing_time_eb.entry_get, - data_dir=self.import_directory_select.folder_path, - ), - ) - - self.import_single_frm = LabelFrame( - self.choice_frm, text="IMPORT DLC CSV FILE", pady=5, padx=5 - ) - self.import_file_select = FileSelect( - self.import_single_frm, - "Input FILE:", - lblwidth=25, - file_types=[("CSV", "*.csv")], - ) - self.import_file_btn = Button( - self.import_single_frm, - fg="blue", - text="Import FILE to SimBA project", - command=lambda: import_single_dlc_tracking_csv_file( - config_path=self.config_path, - interpolation_setting=self.interpolation_dropdown.getChoices(), - smoothing_setting=self.smoothing_dropdown.getChoices(), - smoothing_time=self.smoothing_time_eb.entry_get, - file_path=self.import_file_select.file_path, - ), - ) - - elif data_type_choice == "MAT (DANNCE 3D)": - self.import_directory_frm = LabelFrame( - self.choice_frm, - text="IMPORT DANNCE MAT DIRECTORY", - pady=5, - padx=5, - ) - self.import_directory_select = FolderSelect( - self.import_directory_frm, "Input DIRECTORY:", lblwidth=25 - ) - self.import_dir_btn = Button( - self.import_directory_frm, - fg="blue", - text="Import directory to SimBA project", - command=lambda: import_DANNCE_folder( - config_path=self.config_path, - folder_path=self.import_directory_select.folder_path, - interpolation_method=self.interpolation_dropdown.getChoices(), - ), - ) - - self.import_single_frm = LabelFrame( - self.choice_frm, text="IMPORT DANNCE CSV FILE", pady=5, padx=5 - ) - self.import_file_select = FileSelect( - self.import_single_frm, "Input FILE:", lblwidth=25 - ) - self.import_file_btn = Button( - self.import_single_frm, - fg="blue", - text="Import file to SimBA project", - command=lambda: import_DANNCE_file( - config_path=self.config_path, - file_path=self.import_file_select.file_path, - interpolation_method=self.interpolation_dropdown.getChoices(), - ), - ) - - elif data_type_choice == "JSON (BENTO)": - self.import_directory_frm = LabelFrame( - self.choice_frm, - text="IMPORT MARS JSON DIRECTORY", - pady=5, - padx=5, - ) - self.import_directory_select = FolderSelect( - self.import_directory_frm, "Input DIRECTORY:", lblwidth=25 - ) - self.import_dir_btn = Button( - self.import_directory_frm, - fg="blue", - text="Import directory to SimBA project", - command=lambda: MarsImporter( - config_path=self.config_path, - data_path=self.import_directory_select.folder_path, - interpolation_method=self.interpolation_dropdown.getChoices(), - smoothing_method={ - "Method": self.smoothing_dropdown.getChoices(), - "Parameters": { - "Time_window": self.smoothing_time_eb.entry_get - }, - }, - ), - ) - - self.import_single_frm = LabelFrame( - self.choice_frm, text="IMPORT MARS JSON FILE", pady=5, padx=5 - ) - self.import_file_select = FileSelect( - self.import_single_frm, "Input FILE:", lblwidth=25 - ) - self.import_file_btn = Button( - self.import_single_frm, - fg="blue", - text="Import file to SimBA project", - command=lambda: MarsImporter( - config_path=self.config_path, - data_path=self.import_directory_select.folder_path, - interpolation_method=self.interpolation_dropdown.getChoices(), - smoothing_method={ - "Method": self.smoothing_dropdown.getChoices(), - "Parameters": { - "Time_window": self.smoothing_time_eb.entry_get - }, - }, - ), - ) - self.import_directory_frm.grid(row=2, column=0, sticky=NW) - self.import_directory_select.grid(row=0, column=0, sticky=NW) - self.import_dir_btn.grid(row=1, column=0, sticky=NW) - - self.import_single_frm.grid(row=3, column=0, sticky=NW) - self.import_file_select.grid(row=0, column=0, sticky=NW) - self.import_file_btn.grid(row=1, column=0, sticky=NW) - - elif data_type_choice in [ - "SLP (SLEAP)", - "H5 (multi-animal DLC)", - "TRK (multi-animal APT)", - "CSV (SLEAP)", - "H5 (SLEAP)", - ]: - self.animal_settings_frm = LabelFrame( - self.choice_frm, text="ANIMAL SETTINGS", pady=5, padx=5 - ) - animal_cnt_entry_box = Entry_Box( - self.animal_settings_frm, - "ANIMAL COUNT:", - "25", - validation="numeric", - ) - animal_cnt_entry_box.entry_set(val=self.project_animal_cnt) - animal_cnt_confirm = Button( - self.animal_settings_frm, - text="CONFIRM", - fg="blue", - command=lambda: self.create_animal_names_entry_boxes( - animal_cnt=animal_cnt_entry_box.entry_get - ), - ) - self.create_animal_names_entry_boxes( - animal_cnt=animal_cnt_entry_box.entry_get - ) - self.animal_settings_frm.grid(row=4, column=0, sticky=NW) - animal_cnt_entry_box.grid(row=0, column=0, sticky=NW) - animal_cnt_confirm.grid(row=0, column=1, sticky=NW) - - self.data_dir_frm = LabelFrame( - self.choice_frm, text="DATA DIRECTORY", pady=5, padx=5 - ) - self.import_frm = LabelFrame( - self.choice_frm, text="IMPORT", pady=5, padx=5 - ) - - if data_type_choice == "H5 (multi-animal DLC)": - self.tracking_type_frm = LabelFrame( - self.choice_frm, text="TRACKING DATA TYPE", pady=5, padx=5 - ) - self.dlc_data_type_option_dropdown = DropDownMenu( - self.tracking_type_frm, - "Tracking type", - Options.MULTI_DLC_TYPE_IMPORT_OPTION.value, - labelwidth=25, - ) - self.dlc_data_type_option_dropdown.setChoices( - Options.MULTI_DLC_TYPE_IMPORT_OPTION.value[1] - ) - self.tracking_type_frm.grid(row=5, column=0, sticky=NW) - self.dlc_data_type_option_dropdown.grid(row=0, column=0, sticky=NW) - - self.data_dir_select = FolderSelect( - self.data_dir_frm, "H5 DLC DIRECTORY: ", lblwidth=25 - ) - self.instructions_lbl = Label( - self.data_dir_frm, - text="Please import videos before importing the \n multi animal DLC tracking data", - ) - self.run_btn = Button( - self.import_frm, - text="IMPORT DLC .H5", - fg="blue", - command=lambda: run_call( - data_type=data_type_choice, - interpolation=self.interpolation_dropdown.getChoices(), - smoothing=self.smoothing_dropdown.getChoices(), - smoothing_window=self.smoothing_time_eb.entry_get, - animal_names=self.animal_name_entry_boxes, - data_path=self.data_dir_select.folder_path, - tracking_data_type=self.dlc_data_type_option_dropdown.getChoices(), - ), - ) - - elif data_type_choice == "SLP (SLEAP)": - self.data_dir_select = FolderSelect( - self.data_dir_frm, "SLP SLEAP DIRECTORY: ", lblwidth=25 - ) - self.instructions_lbl = Label( - self.data_dir_frm, - text="Please import videos before importing the \n multi animal SLEAP tracking data if you are tracking more than ONE animal", - ) - self.run_btn = Button( - self.import_frm, - text="IMPORT SLEAP .SLP", - fg="blue", - command=lambda: run_call( - data_type=data_type_choice, - interpolation=self.interpolation_dropdown.getChoices(), - smoothing=self.smoothing_dropdown.getChoices(), - smoothing_window=self.smoothing_time_eb.entry_get, - animal_names=self.animal_name_entry_boxes, - data_path=self.data_dir_select.folder_path, - ), - ) - - elif data_type_choice == "TRK (multi-animal APT)": - self.data_dir_select = FolderSelect( - self.data_dir_frm, "TRK APT DIRECTORY: ", lblwidth=25 - ) - self.instructions_lbl = Label( - self.data_dir_frm, - text="Please import videos before importing the \n multi animal TRK tracking data", - ) - self.run_btn = Button( - self.import_frm, - text="IMPORT APT .TRK", - fg="blue", - command=lambda: run_call( - data_type=data_type_choice, - interpolation=self.interpolation_dropdown.getChoices(), - smoothing=self.smoothing_dropdown.getChoices(), - smoothing_window=self.smoothing_time_eb.entry_get, - animal_names=self.animal_name_entry_boxes, - data_path=self.data_dir_select.folder_path, - ), - ) - - elif data_type_choice == "CSV (SLEAP)": - self.data_dir_select = FolderSelect( - self.data_dir_frm, "CSV SLEAP DIRECTORY:", lblwidth=25 - ) - self.instructions_lbl = Label( - self.data_dir_frm, - text="Please import videos before importing the sleap csv tracking data \n if you are tracking more than ONE animal", - ) - self.run_btn = Button( - self.import_frm, - text="IMPORT SLEAP .CSV", - fg="blue", - command=lambda: run_call( - data_type=data_type_choice, - interpolation=self.interpolation_dropdown.getChoices(), - smoothing=self.smoothing_dropdown.getChoices(), - smoothing_window=self.smoothing_time_eb.entry_get, - animal_names=self.animal_name_entry_boxes, - data_path=self.data_dir_select.folder_path, - ), - ) - - elif data_type_choice == "H5 (SLEAP)": - self.data_dir_select = FolderSelect( - self.data_dir_frm, "H5 SLEAP DIRECTORY", lblwidth=25 - ) - self.instructions_lbl = Label( - self.data_dir_frm, - text="Please import videos before importing the sleap h5 tracking data \n if you are tracking more than ONE animal", - ) - self.run_btn = Button( - self.import_frm, - text="IMPORT SLEAP H5", - fg="blue", - command=lambda: run_call( - data_type=data_type_choice, - interpolation=self.interpolation_dropdown.getChoices(), - smoothing=self.smoothing_dropdown.getChoices(), - smoothing_window=self.smoothing_time_eb.entry_get, - animal_names=self.animal_name_entry_boxes, - data_path=self.data_dir_select.folder_path, - ), - ) - - self.data_dir_frm.grid( - row=self.frame_children(frame=self.choice_frm), column=0, sticky=NW - ) - self.data_dir_select.grid(row=0, column=0, sticky=NW) - self.instructions_lbl.grid(row=1, column=0, sticky=NW) - self.import_frm.grid( - row=self.frame_children(frame=self.choice_frm) + 1, - column=0, - sticky=NW, - ) - self.run_btn.grid(row=0, column=0, sticky=NW) - self.choice_frm.grid(row=1, column=0, sticky=NW) - - self.import_tracking_frm = LabelFrame( - parent_frm, - text="IMPORT TRACKING DATA", - font=Formats.LABELFRAME_HEADER_FORMAT.value, - fg="black", - ) - if not hasattr(self, "config_path"): - self.instructions_lbl = Label( - self.import_tracking_frm, - text="Please CREATE PROJECT CONFIG before importing tracking data \n", - ) - self.import_tracking_frm.grid(row=0, column=0, sticky=NW) - self.instructions_lbl.grid(row=0, column=0, sticky=NW) - else: - self.config = read_config_file(config_path=self.config_path) - self.project_animal_cnt = read_config_entry( - config=self.config, - section=ConfigKey.GENERAL_SETTINGS.value, - option=ConfigKey.ANIMAL_CNT.value, - data_type="int", - ) - self.data_type_dropdown = DropDownMenu( - self.import_tracking_frm, - "DATA TYPE:", - Options.IMPORT_TYPE_OPTIONS.value, - labelwidth=25, - com=import_menu, - ) - self.data_type_dropdown.setChoices(Options.IMPORT_TYPE_OPTIONS.value[0]) - import_menu(data_type_choice=Options.IMPORT_TYPE_OPTIONS.value[0]) - self.import_tracking_frm.grid(row=idx_row, column=idx_column, sticky=NW) - self.data_type_dropdown.grid(row=0, column=0, sticky=NW) - - def create_import_videos_menu( - self, parent_frm: Frame, idx_row: int = 0, idx_column: int = 0 - ): - - def run_import(multiple_videos: bool): - if multiple_videos: - check_if_dir_exists(in_dir=self.video_directory_select.folder_path) - copy_multiple_videos_to_project( - config_path=self.config_path, - source=self.video_directory_select.folder_path, - symlink=self.multiple_videos_symlink_var.get(), - file_type=self.video_type.getChoices(), - ) - else: - check_file_exist_and_readable( - file_path=self.video_file_select.file_path - ) - copy_single_video_to_project( - simba_ini_path=self.config_path, - symlink=self.single_video_symlink_var.get(), - source_path=self.video_file_select.file_path, - ) - - import_videos_frm = LabelFrame( - parent_frm, - text="IMPORT VIDEOS", - fg="black", - font=Formats.LABELFRAME_HEADER_FORMAT.value, - ) - if not hasattr(self, "config_path"): - self.instructions_lbl = Label( - import_videos_frm, - text="Please CREATE PROJECT CONFIG before importing VIDEOS \n", - ) - import_videos_frm.grid(row=0, column=0, sticky=NW) - self.instructions_lbl.grid(row=0, column=0, sticky=NW) - - else: - import_multiple_videos_frm = LabelFrame( - import_videos_frm, text="IMPORT MULTIPLE VIDEOS" - ) - self.video_directory_select = FolderSelect( - import_multiple_videos_frm, "VIDEO DIRECTORY: ", lblwidth=25 - ) - self.video_type = DropDownMenu( - import_multiple_videos_frm, - "VIDEO FILE FORMAT: ", - Options.VIDEO_FORMAT_OPTIONS.value, - "25", - ) - self.video_type.setChoices(Options.VIDEO_FORMAT_OPTIONS.value[0]) - import_multiple_btn = Button( - import_multiple_videos_frm, - text="Import MULTIPLE videos", - fg="blue", - command=lambda: run_import(multiple_videos=True), - ) - self.multiple_videos_symlink_var = BooleanVar(value=False) - multiple_videos_symlink_cb = Checkbutton( - import_multiple_videos_frm, - text="Import SYMLINKS", - variable=self.multiple_videos_symlink_var, - ) - - import_single_frm = LabelFrame( - import_videos_frm, text="IMPORT SINGLE VIDEO", pady=5, padx=5 - ) - self.video_file_select = FileSelect( - import_single_frm, - "VIDEO PATH: ", - title="Select a video file", - lblwidth=25, - file_types=[("VIDEO FILE", Options.ALL_VIDEO_FORMAT_STR_OPTIONS.value)], - ) - import_single_btn = Button( - import_single_frm, - text="Import SINGLE video", - fg="blue", - command=lambda: run_import(multiple_videos=False), - ) - self.single_video_symlink_var = BooleanVar(value=False) - single_video_symlink_cb = Checkbutton( - import_single_frm, - text="Import SYMLINK", - variable=self.single_video_symlink_var, - ) - - import_videos_frm.grid(row=0, column=0, sticky=NW) - import_multiple_videos_frm.grid(row=0, sticky=W) - self.video_directory_select.grid(row=1, sticky=W) - self.video_type.grid(row=2, sticky=W) - multiple_videos_symlink_cb.grid(row=3, sticky=W) - import_multiple_btn.grid(row=4, sticky=W) - - import_single_frm.grid(row=1, column=0, sticky=NW) - self.video_file_select.grid(row=0, sticky=W) - single_video_symlink_cb.grid(row=1, sticky=W) - import_single_btn.grid(row=2, sticky=W) - import_videos_frm.grid(row=idx_row, column=idx_column, sticky=NW) - - def create_multiprocess_choice( - self, - parent: Frame, - cb_text: str = "Multiprocess videos (faster)", - ): - self.multiprocess_var = BooleanVar(value=False) - self.multiprocess_dropdown = DropDownMenu( - parent, "CPU cores:", list(range(2, self.cpu_cnt)), "12" - ) - multiprocess_cb = Checkbutton( - parent, - text=cb_text, - variable=self.multiprocess_var, - command=lambda: self.enable_dropdown_from_checkbox( - check_box_var=self.multiprocess_var, - dropdown_menus=[self.multiprocess_dropdown], - ), - ) - self.multiprocess_dropdown.setChoices(2) - self.multiprocess_dropdown.disable() - multiprocess_cb.grid(row=self.frame_children(frame=parent), column=0, sticky=NW) - # print(multiprocess_dropdown.popupMenugrid_info) - self.multiprocess_dropdown.grid( - row=multiprocess_cb.grid_info()["row"], column=1, sticky=NW - ) # def quit(self, e): # self.main_frm.quit() diff --git a/simba/pose_importers/dlc_importer_csv.py b/simba/pose_importers/dlc_importer_csv.py index 0e040d810..daafc4a77 100644 --- a/simba/pose_importers/dlc_importer_csv.py +++ b/simba/pose_importers/dlc_importer_csv.py @@ -1,25 +1,20 @@ __author__ = "Simon Nilsson" -import glob import os import shutil -from typing import List, Union - +from typing import List, Union, Dict, Optional, Any import pandas as pd -from simba.data_processors.interpolation_smoothing import Interpolate, Smooth -from simba.utils.checks import (check_file_exist_and_readable, - check_if_dir_exists, - check_if_filepath_list_is_empty, check_int) -from simba.utils.data import smooth_data_gaussian, smooth_data_savitzky_golay -from simba.utils.enums import Methods -from simba.utils.errors import FileExistError, NoFilesFoundError +from simba.mixins.config_reader import ConfigReader +from simba.data_processors.interpolate import Interpolate +from simba.data_processors.smoothing import Smoothing +from simba.utils.checks import (check_file_exist_and_readable, check_int, check_str, check_if_keys_exist_in_dict, check_instance) +from simba.utils.errors import FileExistError, InvalidInputError from simba.utils.printing import SimbaTimer, stdout_success -from simba.utils.read_write import (get_fn_ext, - get_number_of_header_columns_in_df, - read_config_file, - read_project_path_and_file_type) +from simba.utils.read_write import (get_fn_ext, get_number_of_header_columns_in_df, find_files_of_filetypes_in_directory) +DLC_ = 'DLC_' +DeepCut = 'DeepCut' def import_dlc_csv(config_path: Union[str, os.PathLike], source: str) -> List[str]: """ @@ -28,150 +23,104 @@ def import_dlc_csv(config_path: Union[str, os.PathLike], source: str) -> List[st :parameter str config_path: path to SimBA project config file in Configparser format :parameter str source: path to file or folder containing DLC pose-estimation CSV files - :return List[str]: Paths of imported files. + :return List[str]: Paths to location of imported files. :example: >>> import_dlc_csv(config_path='project_folder/project_config.ini', source='CSV_import/Together_1.csv') >>> ['project_folder/csv/input_csv/Together_1.csv'] """ - config = read_config_file(config_path=config_path) - project_path, file_type = read_project_path_and_file_type(config=config) - original_file_name_dir = os.path.join( - project_path, "csv", "input_csv", "original_filename" - ) - input_csv_dir = os.path.join(project_path, "csv", "input_csv") - imported_files = glob.glob(input_csv_dir + "/*." + file_type) - imported_file_names, imported_file_paths = [], [] - for file_path in imported_files: - _, video_name, _ = get_fn_ext(filepath=file_path) - imported_file_names.append(video_name) - if not os.path.exists(original_file_name_dir): - os.makedirs(original_file_name_dir) + check_file_exist_and_readable(file_path=config_path) + conf = ConfigReader(config_path=config_path) + original_file_name_dir = os.path.join(conf.input_csv_dir, "original_filename") + if not os.path.exists(original_file_name_dir): os.makedirs(original_file_name_dir) + prev_imported_file_paths = find_files_of_filetypes_in_directory(directory=conf.input_csv_dir, extensions=[f'.{conf.file_type}'], raise_warning=False, raise_error=False) + prev_imported_file_names = [get_fn_ext(x)[1] for x in prev_imported_file_paths] if os.path.isdir(source): - csv_files = glob.glob(source + "/*.csv") - check_if_filepath_list_is_empty( - csv_files, - error_msg=f"SIMBA ERROR: NO .csv files found in {source} directory.", - ) + new_data_paths = find_files_of_filetypes_in_directory(directory=source, extensions=['.csv'], raise_warning=False, raise_error=True) + elif os.path.isfile(source): + new_data_paths = [source] else: - csv_files = [source] + raise InvalidInputError(msg=f'{source} is not a valid data directory path or file path.', source=import_dlc_csv.__name__) - for file_path in csv_files: + imported_file_paths = [] + for file_cnt, file_path in enumerate(new_data_paths): video_timer = SimbaTimer(start=True) check_file_exist_and_readable(file_path=file_path) _, video_name, file_ext = get_fn_ext(filepath=file_path) - if "DLC_" in video_name: - new_file_name = video_name.split("DLC_")[0] + ".csv" - elif "DeepCut" in video_name: - new_file_name = video_name.split("DeepCut")[0] + ".csv" + if DLC_ in video_name: + new_file_name = video_name.split(DLC_)[0] + ".csv" + elif DeepCut in video_name: + new_file_name = video_name.split(DeepCut)[0] + ".csv" else: new_file_name = video_name.split(".")[0] + ".csv" new_file_name_wo_ext = new_file_name.split(".")[0] video_basename = os.path.basename(file_path) print(f"Importing {video_name} to SimBA project...") - if new_file_name_wo_ext in imported_file_names: - raise FileExistError( - "SIMBA IMPORT ERROR: {} already exist in project. Remove file from project or rename imported video file name before importing.".format( - new_file_name - ) - ) - shutil.copy(file_path, input_csv_dir) + # if new_file_name_wo_ext in prev_imported_file_names: + # raise FileExistError(f"SIMBA IMPORT ERROR: {new_file_name} already exist in project in the directory {conf.input_csv_dir}. Remove file from project or rename imported video file name before importing.") + shutil.copy(file_path, conf.input_csv_dir) shutil.copy(file_path, original_file_name_dir) - os.rename( - os.path.join(input_csv_dir, video_basename), - os.path.join(input_csv_dir, new_file_name), - ) - df = pd.read_csv(os.path.join(input_csv_dir, new_file_name)) + os.rename(os.path.join(conf.input_csv_dir, video_basename), os.path.join(conf.input_csv_dir, new_file_name)) + df = pd.read_csv(os.path.join(conf.input_csv_dir, new_file_name)) header_cols = get_number_of_header_columns_in_df(df=df) if header_cols == 3: df = df.iloc[1:] - if file_type == "parquet": - df = pd.read_csv(os.path.join(input_csv_dir, video_basename)) + if conf.file_type == "parquet": + df = pd.read_csv(os.path.join(conf.input_csv_dir, video_basename)) df = df.apply(pd.to_numeric, errors="coerce") - df.to_parquet(os.path.join(input_csv_dir, new_file_name)) - os.remove(os.path.join(input_csv_dir, video_basename)) - if file_type == "csv": - df.to_csv(os.path.join(input_csv_dir, new_file_name), index=False) - imported_file_paths.append(os.path.join(input_csv_dir, new_file_name)) + df.to_parquet(os.path.join(conf.input_csv_dir, new_file_name)) + os.remove(os.path.join(conf.input_csv_dir, video_basename)) + if conf.file_type == "csv": + df.to_csv(os.path.join(conf.input_csv_dir, new_file_name), index=False) + imported_file_paths.append(os.path.join(conf.input_csv_dir, new_file_name)) video_timer.stop_timer() - print( - f"Pose-estimation data for video {video_name} imported to SimBA project (elapsed time: {video_timer.elapsed_time_str}s)..." - ) + print(f"Pose-estimation data for video {video_name} imported to SimBA project (elapsed time: {video_timer.elapsed_time_str}s)...") return imported_file_paths +def import_dlc_csv_data(config_path: Union[str, os.PathLike], + data_path: Union[str, os.PathLike], + interpolation_settings: Optional[Dict[str, Any]] = None, + smoothing_settings: Optional[Dict[str, Any]] = None) -> None: -def import_single_dlc_tracking_csv_file( - config_path: str, - interpolation_setting: str, - smoothing_setting: str, - smoothing_time: int, - file_path: str, -): - timer = SimbaTimer(start=True) - if (smoothing_setting == Methods.GAUSSIAN.value) or ( - smoothing_setting == Methods.SAVITZKY_GOLAY.value - ): - check_int(name="SMOOTHING TIME WINDOW", value=smoothing_time, min_value=1) - check_file_exist_and_readable(file_path=file_path) - imported_file_paths = import_dlc_csv(config_path=config_path, source=file_path) - if interpolation_setting != "None": - _ = Interpolate( - input_path=imported_file_paths[0], - config_path=config_path, - method=interpolation_setting, - initial_import_multi_index=True, - ) - if (smoothing_setting == Methods.GAUSSIAN.value) or ( - smoothing_setting == Methods.SAVITZKY_GOLAY.value - ): - _ = Smooth( - config_path=config_path, - input_path=imported_file_paths[0], - time_window=smoothing_time, - smoothing_method=smoothing_setting, - initial_import_multi_index=True, - ) - timer.stop_timer() - stdout_success( - msg=f"Imported {str(len(imported_file_paths))} pose estimation file(s)", - elapsed_time=timer.elapsed_time_str, - ) + """ + Import multiple DLC CSV tracking files to SimBA project and apply specified interpolation and smoothing + parameters to the imported data. + :param Union[str, os.PathLike] config_path: Path to SimBA config file in ConfigParser format. + :param Union[str, os.PathLike] data_path: Path to directory holding DLC pose-estimation data in CSV format, or path to a single CSV file with DLC pose-estimation data. + :param Optional[Dict[str, Any]] interpolation_settings: Dictionary holding settings for interpolation. + :param Optional[Dict[str, Any]] smoothing_settings: Dictionary holding settings for smoothing. + :return None: + + :example: + >>> interpolation_settings = {'type': 'body-parts', 'method': 'linear'} + >>> smoothing_settings = None #{'time_window': 500, 'method': 'savitzky-golay'} + >>> import_dlc_csv_data(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', data_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/new_data', interpolation_settings=interpolation_settings, smoothing_settings=smoothing_settings) + + """ -def import_multiple_dlc_tracking_csv_file( - config_path: str, - interpolation_setting: str, - smoothing_setting: str, - smoothing_time: int, - data_dir: str, -): timer = SimbaTimer(start=True) - if (smoothing_setting == Methods.GAUSSIAN.value) or ( - smoothing_setting == Methods.SAVITZKY_GOLAY.value - ): - check_int(name="SMOOTHING TIME WINDOW", value=smoothing_time, min_value=1) - check_if_dir_exists(in_dir=data_dir) - imported_file_paths = import_dlc_csv(config_path=config_path, source=data_dir) - if interpolation_setting != "None": - _ = Interpolate( - input_path=os.path.dirname(imported_file_paths[0]), - config_path=config_path, - method=interpolation_setting, - initial_import_multi_index=True, - ) - if (smoothing_setting == Methods.GAUSSIAN.value) or ( - smoothing_setting == Methods.SAVITZKY_GOLAY.value - ): - _ = Smooth( - config_path=config_path, - input_path=os.path.dirname(imported_file_paths[0]), - time_window=int(smoothing_time), - smoothing_method=smoothing_setting, - initial_import_multi_index=True, - ) + check_file_exist_and_readable(file_path=config_path) + if (not os.path.isdir(data_path)) and (not os.path.isfile(data_path)): + raise InvalidInputError(msg=f'{data_path} is not a valid data directory path or file path.', source=import_dlc_csv.__name__) + if interpolation_settings is not None: + check_instance(source=f'{import_dlc_csv_data.__name__} interpolation_settings', accepted_types=(dict,), instance=interpolation_settings) + check_if_keys_exist_in_dict(data=interpolation_settings, key=['type', 'method']) + check_str(name='type', value=interpolation_settings['type'].lower(), options=['animals', 'body-parts']) + check_str(name='method', value=interpolation_settings['method'].lower(), options=['nearest', 'linear', 'quadratic']) + if smoothing_settings is not None: + check_instance(source=f'{import_dlc_csv_data.__name__} smoothing_settings', accepted_types=(dict,), instance=smoothing_settings) + check_if_keys_exist_in_dict(data=smoothing_settings, key=['time_window', 'method']) + check_int(name='time_window', value=smoothing_settings['time_window'], min_value=1) + check_str(name='method', value=smoothing_settings['method'].lower(), options=['savitzky-golay', 'gaussian']) + + imported_file_paths = import_dlc_csv(config_path=config_path, source=data_path) + if interpolation_settings != None: + interpolator = Interpolate(config_path=config_path, data_path=imported_file_paths, type=interpolation_settings['type'], method=interpolation_settings['method'], multi_index_df_headers=True, copy_originals=False) + interpolator.run() + if smoothing_settings != None: + smoother = Smoothing(config_path=config_path, data_path=imported_file_paths, time_window=smoothing_settings['time_window'], method=smoothing_settings['method'], multi_index_df_headers=True, copy_originals=False) + smoother.run() timer.stop_timer() - stdout_success( - msg=f"Imported {str(len(imported_file_paths))} pose estimation file(s)", - elapsed_time=timer.elapsed_time_str, - ) + stdout_success(msg=f"Imported {len(imported_file_paths)} pose estimation file(s) to directory", elapsed_time=timer.elapsed_time_str) diff --git a/simba/pose_importers/madlc_importer.py b/simba/pose_importers/madlc_importer.py index ea481390b..123d38c7b 100644 --- a/simba/pose_importers/madlc_importer.py +++ b/simba/pose_importers/madlc_importer.py @@ -12,18 +12,18 @@ import numpy as np import pandas as pd -from simba.data_processors.interpolation_smoothing import Interpolate, Smooth +from simba.data_processors.interpolate import Interpolate +from simba.data_processors.smoothing import Smoothing from simba.mixins.config_reader import ConfigReader from simba.mixins.pose_importer_mixin import PoseImporterMixin from simba.utils.checks import (check_file_exist_and_readable, check_if_dir_exists, check_if_keys_exist_in_dict, check_instance, - check_str, check_valid_lst) + check_str, check_valid_lst, check_int) from simba.utils.enums import Formats, Methods, Options from simba.utils.errors import BodypartColumnNotFoundError from simba.utils.printing import SimbaTimer, stdout_success -from simba.utils.read_write import (find_all_videos_in_project, - get_video_meta_data, write_df) +from simba.utils.read_write import (find_all_videos_in_project, get_video_meta_data, write_df) class MADLCImporterH5(ConfigReader, PoseImporterMixin): @@ -35,17 +35,14 @@ class MADLCImporterH5(ConfigReader, PoseImporterMixin): :parameter str data_folder: Path to folder containing maDLC data in ``.h5`` format. :parameter str file_type: Method used to perform pose-estimation in maDLC. OPTIONS: `skeleton`, `box`, `ellipse`. :param List[str] id_lst: Names of animals. - :parameter str interpolation_setting: String defining the pose-estimation interpolation method. OPTIONS: 'None', 'Animal(s): Nearest', - 'Animal(s): Linear', 'Animal(s): Quadratic','Body-parts: Nearest', 'Body-parts: Linear', - 'Body-parts: Quadratic'. - :parameter dict smoothing_settings: Dictionary defining the pose estimation smoothing method. EXAMPLE: {'Method': 'Savitzky Golay', - 'Parameters': {'Time_window': '200'}}) + :parameter Optional[Dict[str, str]] interpolation_setting: Dict defining the type and method to use to perform interpolation {'type': 'animals', 'method': 'linear'}. + :parameter Optional[Dict[str, Union[str, int]]] smoothing_settings: Dictionary defining the pose estimation smoothing method {'time_window': 500, 'method': 'gaussian'}. .. note:: `Multi-animal import tutorial `__. :examples: - >>> _ = MADLCImporterH5(config_path=r'MyConfigPath', data_folder=r'maDLCDataFolder', file_type='ellipse', id_lst=['Animal_1', 'Animal_2'], interpolation_settings='None', smoothing_settings={'Method': 'None', 'Parameters': {'Time_window': '200'}}).run() + >>> _ = MADLCImporterH5(config_path=r'MyConfigPath', data_folder=r'maDLCDataFolder', file_type='ellipse', id_lst=['Animal_1', 'Animal_2'], interpolation_settings={'type': 'animals', 'method': 'linear'}, smoothing_settings={'time_window': 500, 'method': 'gaussian'}).run() References ---------- @@ -58,7 +55,7 @@ def __init__(self, data_folder: Union[str, os.PathLike], file_type: Literal['skeleton', 'box', 'ellipse'], id_lst: List[str], - interpolation_settings: Optional[Union[None, Literal['None', 'Animal(s): Nearest', 'Animal(s): Linear', 'Animal(s): Quadratic','Body-parts: Nearest', 'Body-parts: Linear', 'Body-parts: Quadratic']]] = None, + interpolation_settings: Optional[Dict[str, str]] = None, smoothing_settings: Optional[Dict[str, Any]] = None): check_file_exist_and_readable(file_path=config_path) @@ -66,13 +63,17 @@ def __init__(self, check_str(name=f'{self.__class__.__name__} file_type', value=file_type, options=Options.MULTI_DLC_TYPE_IMPORT_OPTION.value) check_valid_lst(data=id_lst, source=f'{self.__class__.__name__} id_lst', valid_dtypes=(str,)) if interpolation_settings is not None: - check_str(name=f'{self.__class__.__name__} interpolation_settings', value=interpolation_settings, options=Options.INTERPOLATION_OPTIONS_W_NONE.value) + check_if_keys_exist_in_dict(data=interpolation_settings, key=['method', 'type'], name=f'{self.__class__.__name__} interpolation_settings') + check_str(name=f'{self.__class__.__name__} interpolation_settings type', value=interpolation_settings['type'], options=('body-parts', 'animals')) + check_str(name=f'{self.__class__.__name__} interpolation_settings method', value=interpolation_settings['method'], options=('linear', 'quadratic', 'nearest')) if smoothing_settings is not None: - check_instance(source=f'{self.__class__.__name__} smoothing_settings', instance=smoothing_settings, accepted_types=(dict,)) - check_if_keys_exist_in_dict(data=smoothing_settings, key=['Method', 'Parameters']) + check_if_keys_exist_in_dict(data=smoothing_settings, key=['method', 'time_window'], name=f'{self.__class__.__name__} smoothing_settings') + check_str(name=f'{self.__class__.__name__} smoothing_settings method', value=smoothing_settings['method'], options=('savitzky-golay', 'gaussian')) + check_int(name=f'{self.__class__.__name__} smoothing_settings time_window', value=smoothing_settings['time_window'], min_value=1) + ConfigReader.__init__(self, config_path=config_path, read_video_info=False) PoseImporterMixin.__init__(self) - self.interpolation_settings, self.smoothing_settings = (interpolation_settings, smoothing_settings) + self.interpolation_settings, self.smoothing_settings = interpolation_settings, smoothing_settings self.data_folder, self.id_lst = data_folder, id_lst self.import_log_path = os.path.join(self.logs_path, f"data_import_log_{self.datetime}.csv") self.video_paths = find_all_videos_in_project(videos_dir=self.video_dir) @@ -93,7 +94,7 @@ def run(self): video_timer = SimbaTimer(start=True) self.add_spacer, self.frame_no, self.video_data, self.video_name = (2, 1, video_data, video_name) print(f"Processing {video_name} ({cnt+1}/{len(self.input_data_paths)})...") - self.data_df = (pd.read_hdf(video_data["DATA"]).replace([np.inf, -np.inf], np.nan).fillna(0)) + self.data_df = pd.read_hdf(video_data["DATA"]).replace([np.inf, -np.inf], np.nan).fillna(0) if len(self.data_df.columns) != len(self.bp_headers): raise BodypartColumnNotFoundError( msg=f'The number of body-parts in data file {video_data["DATA"]} do not match the number of body-parts in your SimBA project. ' @@ -108,27 +109,22 @@ def run(self): self.multianimal_identification() self.save_path = os.path.join(os.path.join(self.input_csv_dir, f"{self.video_name}.{self.file_type}")) write_df(df=self.out_df, file_type=self.file_type, save_path=self.save_path, multi_idx_header=True) - if self.interpolation_settings not in ["None", None]: - self.__run_interpolation() - if self.smoothing_settings["Method"] != "None": - self.__run_smoothing() + if self.interpolation_settings is not None: + interpolator = Interpolate(config_path=self.config_path, data_path=self.save_path, type=self.interpolation_settings['type'], method=self.interpolation_settings['method'], multi_index_df_headers=True, copy_originals=False) + interpolator.run() + if self.smoothing_settings is not None: + smoother = Smoothing(config_path=self.config_path, data_path=self.save_path, time_window=self.smoothing_settings['time_window'], method=self.smoothing_settings['method'], multi_index_df_headers=True, copy_originals=False) + smoother.run() video_timer.stop_timer() stdout_success(msg=f"Video {video_name} data imported...", elapsed_time=video_timer.elapsed_time_str) self.timer.stop_timer() stdout_success(msg="All maDLC H5 data files imported", elapsed_time=self.timer.elapsed_time_str) - def __run_interpolation(self): - print(f"Interpolating missing values in video {self.video_name} (Method: {self.interpolation_settings}) ...") - _ = Interpolate(input_path=self.save_path, config_path=self.config_path, method=self.interpolation_settings, initial_import_multi_index=True) - - def __run_smoothing(self): - print(f'Performing {self.smoothing_settings["Method"]} smoothing on video {self.video_name}...') - Smooth(config_path=self.config_path, input_path=self.save_path, time_window=int(self.smoothing_settings["Parameters"]["Time_window"]), smoothing_method=self.smoothing_settings["Method"], initial_import_multi_index=True) # test = MADLCImporterH5(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', # data_folder=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/h5', # file_type='ellipse', # id_lst=['Simon', 'JJ'], -# interpolation_settings='Body-parts: Nearest', -# smoothing_settings = {'Method': 'Savitzky Golay', 'Parameters': {'Time_window': '200'}}) +# interpolation_settings= {'type': 'animals', 'method': 'linear'}, +# smoothing_settings = {'time_window': 500, 'method': 'gaussian'}) # test.run() diff --git a/simba/pose_importers/sleap_csv_importer.py b/simba/pose_importers/sleap_csv_importer.py index 7d2bb1df7..5dcec75c6 100644 --- a/simba/pose_importers/sleap_csv_importer.py +++ b/simba/pose_importers/sleap_csv_importer.py @@ -2,20 +2,20 @@ import os from copy import deepcopy +from typing import Union, Dict, List, Any, Optional import numpy as np import pandas as pd -from simba.data_processors.interpolation_smoothing import Interpolate, Smooth +from simba.data_processors.interpolate import Interpolate +from simba.data_processors.smoothing import Smoothing from simba.mixins.config_reader import ConfigReader from simba.mixins.pose_importer_mixin import PoseImporterMixin -from simba.utils.checks import check_that_column_exist +from simba.utils.checks import check_that_column_exist, check_str, check_int, check_if_keys_exist_in_dict, check_if_dir_exists, check_valid_lst from simba.utils.enums import Methods, TagNames from simba.utils.errors import CountError from simba.utils.printing import SimbaTimer, log_event, stdout_success -from simba.utils.read_write import (clean_sleap_file_name, - find_all_videos_in_project, get_fn_ext, - get_video_meta_data, write_df) +from simba.utils.read_write import (clean_sleap_file_name, find_all_videos_in_project, get_fn_ext, get_video_meta_data, write_df) TRACK = "track" INSTANCE_SCORE = "instance.score" @@ -32,88 +32,62 @@ class SLEAPImporterCSV(ConfigReader, PoseImporterMixin): :parameter str config_path: path to SimBA project config file in Configparser format :parameter str data_folder: Path to folder containing SLEAP data in `csv` format. :parameter List[str] id_lst: Animal names. This will be ignored in one animal projects and default to ``Animal_1``. - :parameter str interpolation_settings: String defining the pose-estimation interpolation method. OPTIONS: 'None', 'Animal(s): Nearest', - 'Animal(s): Linear', 'Animal(s): Quadratic','Body-parts: Nearest', 'Body-parts: Linear', - 'Body-parts: Quadratic'. - :parameter str smoothing_settings: Dictionary defining the pose estimation smoothing method. EXAMPLE: {'Method': 'Savitzky Golay', - 'Parameters': {'Time_window': '200'}}) + :param Optional[Dict[str, str]] interpolation_setting: Dict defining the type and method to use to perform interpolation {'type': 'animals', 'method': 'linear'}. + :param Optional[Dict[str, Union[str, int]]] smoothing_settings: Dictionary defining the pose estimation smoothing method {'time_window': 500, 'method': 'gaussian'}. References ---------- - .. [1] Pereira et al., SLEAP: A deep learning system for multi-animal pose tracking, `Nature Methods`, - 2022. + .. [1] Pereira et al., SLEAP: A deep learning system for multi-animal pose tracking, `Nature Methods`, 2022. - >>> sleap_csv_importer = SLEAPImporterCSV(config_path=r'project_folder/project_config.ini', data_folder=r'data_folder', actor_IDs=['Termite_1', 'Termite_2', 'Termite_3', 'Termite_4', 'Termite_5'], interpolation_settings="Body-parts: Nearest", smoothing_settings = {'Method': 'Savitzky Golay', 'Parameters': {'Time_window': '200'}}) + >>> sleap_csv_importer = SLEAPImporterCSV(config_path=r'project_folder/project_config.ini', data_folder=r'data_folder', id_lst=['Termite_1', 'Termite_2', 'Termite_3', 'Termite_4', 'Termite_5'], interpolation_settings={'type': 'animals', 'method': 'linear'}, smoothing_settings = {'time_window': 500, 'method': 'gaussian'}) >>> sleap_csv_importer.run() """ - def __init__( - self, - config_path: str, - data_folder: str, - id_lst: list, - interpolation_settings: str, - smoothing_settings: dict, - ): + def __init__(self, + config_path: Union[str, os.PathLike], + data_folder: Union[str, os.PathLike], + id_lst: List[str], + interpolation_settings: Optional[Dict[str, str]] = None, + smoothing_settings: Optional[Dict[str, Union[int, str]]] = None): + ConfigReader.__init__(self, config_path=config_path, read_video_info=False) PoseImporterMixin.__init__(self) - log_event( - logger_name=str(__class__.__name__), - log_type=TagNames.CLASS_INIT.value, - msg=self.create_log_msg_from_init_args(locals=locals()), - ) - self.interpolation_settings, self.smoothing_settings = ( - interpolation_settings, - smoothing_settings, - ) + check_if_dir_exists(in_dir=data_folder) + check_valid_lst(data=id_lst, source=f'{self.__class__.__name__} id_lst', valid_dtypes=(str,), min_len=1) + if interpolation_settings is not None: + check_if_keys_exist_in_dict(data=interpolation_settings, key=['method', 'type'], name=f'{self.__class__.__name__} interpolation_settings') + check_str(name=f'{self.__class__.__name__} interpolation_settings type', value=interpolation_settings['type'], options=('body-parts', 'animals')) + check_str(name=f'{self.__class__.__name__} interpolation_settings method', value=interpolation_settings['method'], options=('linear', 'quadratic', 'nearest')) + if smoothing_settings is not None: + check_if_keys_exist_in_dict(data=smoothing_settings, key=['method', 'time_window'], name=f'{self.__class__.__name__} smoothing_settings') + check_str(name=f'{self.__class__.__name__} smoothing_settings method', value=smoothing_settings['method'], options=('savitzky-golay', 'gaussian')) + check_int(name=f'{self.__class__.__name__} smoothing_settings time_window', value=smoothing_settings['time_window'], min_value=1) + + log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals())) + self.interpolation_settings, self.smoothing_settings = (interpolation_settings, smoothing_settings) self.data_folder, self.id_lst = data_folder, id_lst - self.import_log_path = os.path.join( - self.logs_path, f"data_import_log_{self.datetime}.csv" - ) + self.import_log_path = os.path.join(self.logs_path, f"data_import_log_{self.datetime}.csv") self.video_paths = find_all_videos_in_project(videos_dir=self.video_dir) - self.input_data_paths = self.find_data_files( - dir=self.data_folder, extensions=[".csv"] - ) + self.input_data_paths = self.find_data_files(dir=self.data_folder, extensions=[".csv"]) if self.pose_setting is Methods.USER_DEFINED.value: self.__update_config_animal_cnt() if self.animal_cnt > 1: - self.data_and_videos_lk = self.link_video_paths_to_data_paths( - data_paths=self.input_data_paths, - video_paths=self.video_paths, - filename_cleaning_func=clean_sleap_file_name, - ) + self.data_and_videos_lk = self.link_video_paths_to_data_paths(data_paths=self.input_data_paths, video_paths=self.video_paths, filename_cleaning_func=clean_sleap_file_name) self.check_multi_animal_status() - self.animal_bp_dict = self.create_body_part_dictionary( - self.multi_animal_status, - self.id_lst, - self.animal_cnt, - self.x_cols, - self.y_cols, - self.p_cols, - self.clr_lst, - ) + self.animal_bp_dict = self.create_body_part_dictionary(self.multi_animal_status, self.id_lst, self.animal_cnt, self.x_cols, self.y_cols, self.p_cols, self.clr_lst) if self.pose_setting is Methods.USER_DEFINED.value: self.update_bp_headers_file(update_bp_headers=True) else: - self.data_and_videos_lk = dict( - [ - (get_fn_ext(file_path)[1], {"DATA": file_path, "VIDEO": None}) - for file_path in self.input_data_paths - ] - ) + self.data_and_videos_lk = dict([(get_fn_ext(file_path)[1], {"DATA": file_path, "VIDEO": None}) for file_path in self.input_data_paths]) print(f"Importing {len(list(self.data_and_videos_lk.keys()))} file(s)...") def run(self): - for file_cnt, (video_name, video_data) in enumerate( - self.data_and_videos_lk.items() - ): + for file_cnt, (video_name, video_data) in enumerate(self.data_and_videos_lk.items()): output_filename = clean_sleap_file_name(filename=video_name) print(f"Analysing {output_filename}...") video_timer = SimbaTimer(start=True) self.video_name = video_name - self.save_path = os.path.join( - os.path.join(self.input_csv_dir, f"{output_filename}.{self.file_type}") - ) + self.save_path = os.path.join(os.path.join(self.input_csv_dir, f"{output_filename}.{self.file_type}")) data_df = pd.read_csv(video_data["DATA"]) if INSTANCE_SCORE in data_df.columns: data_df = data_df.drop([INSTANCE_SCORE], axis=1) @@ -123,86 +97,41 @@ def run(self): idx[TRACK] = idx[TRACK].str.replace(r"[^\d.]+", "").astype(int) data_df = data_df.iloc[:, 2:].fillna(0) if self.animal_cnt > 1: - self.data_df = pd.DataFrame( - self.transpose_multi_animal_table( - data=data_df.values, idx=idx.values, animal_cnt=self.animal_cnt - ) - ) + self.data_df = pd.DataFrame(self.transpose_multi_animal_table(data=data_df.values, idx=idx.values, animal_cnt=self.animal_cnt)) else: idx = list(idx.drop(TRACK, axis=1)["frame_idx"]) self.data_df = data_df.set_index([idx]).sort_index() self.data_df.columns = np.arange(len(self.data_df.columns)) - self.data_df = self.data_df.reindex( - range(0, self.data_df.index[-1] + 1), fill_value=0 - ) + self.data_df = self.data_df.reindex(range(0, self.data_df.index[-1] + 1), fill_value=0) if len(self.bp_headers) != len(self.data_df.columns): - raise CountError( - msg=f"SimBA project expects {len(self.bp_headers)} data columns, but your SLEAP data file {video_name} contains {len(self.data_df.columns)} columns. Missing columns: {list(set(self.bp_headers) - set(self.data_df.columns))}", - source=self.__class__.__name__, - ) + raise CountError(msg=f"SimBA project expects {len(self.bp_headers)} data columns, but your SLEAP data file {video_name} contains {len(self.data_df.columns)} columns. Missing columns: {list(set(self.bp_headers) - set(self.data_df.columns))}", source=self.__class__.__name__) self.data_df.columns = self.bp_headers self.out_df = deepcopy(self.data_df) if self.animal_cnt > 1: - self.initialize_multi_animal_ui( - animal_bp_dict=self.animal_bp_dict, - video_info=get_video_meta_data(video_data["VIDEO"]), - data_df=self.data_df, - video_path=video_data["VIDEO"], - ) + self.initialize_multi_animal_ui(animal_bp_dict=self.animal_bp_dict, + video_info=get_video_meta_data(video_data["VIDEO"]), + data_df=self.data_df, + video_path=video_data["VIDEO"]) self.multianimal_identification() - write_df( - df=self.out_df, - file_type=self.file_type, - save_path=self.save_path, - multi_idx_header=True, - ) - if self.interpolation_settings != "None": - self.__run_interpolation() - if self.smoothing_settings["Method"] != "None": - self.__run_smoothing() + write_df(df=self.out_df, file_type=self.file_type, save_path=self.save_path, multi_idx_header=True) + if self.interpolation_settings is not None: + interpolator = Interpolate(config_path=self.config_path, data_path=self.save_path, type=self.interpolation_settings['type'], method=self.interpolation_settings['method'], multi_index_df_headers=True, copy_originals=False) + interpolator.run() + if self.smoothing_settings is not None: + smoother = Smoothing(config_path=self.config_path, data_path=self.save_path, time_window=self.smoothing_settings['time_window'], method=self.smoothing_settings['method'], multi_index_df_headers=True, copy_originals=False) + smoother.run() video_timer.stop_timer() - stdout_success( - msg=f"Video {video_name} data imported...", - elapsed_time=video_timer.elapsed_time_str, - source=self.__class__.__name__, - ) + stdout_success(msg=f"Video {video_name} data imported...", elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__) self.timer.stop_timer() - stdout_success( - msg=f"{len(list(self.data_and_videos_lk.keys()))} file(s) imported to the SimBA project (project_folder/csv/input_csv directory)", - source=self.__class__.__name__, - ) - - def __run_interpolation(self): - print( - f"Interpolating missing values in video {self.video_name} (Method: {self.interpolation_settings})..." - ) - _ = Interpolate( - input_path=self.save_path, - config_path=self.config_path, - method=self.interpolation_settings, - initial_import_multi_index=True, - ) - - def __run_smoothing(self): - print( - f'Performing {self.smoothing_settings["Method"]} smoothing on video {self.video_name}...' - ) - Smooth( - config_path=self.config_path, - input_path=self.save_path, - time_window=int(self.smoothing_settings["Parameters"]["Time_window"]), - smoothing_method=self.smoothing_settings["Method"], - initial_import_multi_index=True, - ) - + stdout_success(msg=f"{len(list(self.data_and_videos_lk.keys()))} file(s) imported to the SimBA project {self.input_csv_dir}", source=self.__class__.__name__) # test = SLEAPImporterCSV(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/sleap_two_animals/project_folder/project_config.ini', # data_folder=r'/Users/simon/Desktop/envs/simba/troubleshooting/sleap_two_animals/csv_import', # id_lst=['Simon', 'jj'], -# interpolation_settings="None", -# smoothing_settings = {'Method': 'None', 'Parameters': {'Time_window': '200'}}) +# interpolation_settings={'type': 'animals', 'method': 'linear'}, +# smoothing_settings = {'time_window': 500, 'method': 'gaussian'}) # test.run() # test = SLEAPImporterCSV(config_path=r'/Users/simon/Desktop/envs/troubleshooting/Hornet/project_folder/project_config.ini', diff --git a/simba/pose_importers/sleap_h5_importer.py b/simba/pose_importers/sleap_h5_importer.py index 57792aab2..850b1a695 100644 --- a/simba/pose_importers/sleap_h5_importer.py +++ b/simba/pose_importers/sleap_h5_importer.py @@ -4,36 +4,30 @@ from typing import Any, Dict, List, Union import h5py -import numpy as np import pandas as pd -from simba.data_processors.interpolation_smoothing import Interpolate, Smooth +from simba.data_processors.interpolate import Interpolate +from simba.data_processors.smoothing import Smoothing from simba.mixins.config_reader import ConfigReader from simba.mixins.pose_importer_mixin import PoseImporterMixin -from simba.utils.checks import (check_file_exist_and_readable, - check_if_dir_exists, - check_if_keys_exist_in_dict, check_str) -from simba.utils.enums import Methods, Options, TagNames +from simba.utils.checks import check_file_exist_and_readable +from simba.utils.checks import check_str, check_int, check_if_keys_exist_in_dict, check_if_dir_exists, check_valid_lst +from simba.utils.enums import Methods, TagNames from simba.utils.errors import BodypartColumnNotFoundError -from simba.utils.printing import (SimbaTimer, log_event, stdout_success, - stdout_warning) -from simba.utils.read_write import (clean_sleap_file_name, - find_all_videos_in_project, get_fn_ext, - get_video_meta_data, write_df) +from simba.utils.printing import (SimbaTimer, log_event, stdout_success, stdout_warning) +from simba.utils.read_write import (clean_sleap_file_name, find_all_videos_in_project, get_fn_ext, get_video_meta_data, write_df) class SLEAPImporterH5(ConfigReader, PoseImporterMixin): """ Importing SLEAP pose-estimation data into SimBA project in ``H5`` format - :parameter str config_path: path to SimBA project config file in Configparser format - :parameter str data_folder: Path to folder containing SLEAP data in `csv` format. - :parameter List[str] id_lst: Animal names. This will be ignored in one animal projects and default to ``Animal_1``. - :parameter str interpolation_settings: String defining the pose-estimation interpolation method. OPTIONS: 'None', 'Animal(s): Nearest', - 'Animal(s): Linear', 'Animal(s): Quadratic','Body-parts: Nearest', 'Body-parts: Linear', - 'Body-parts: Quadratic'. - :parameter str smoothing_settings: Dictionary defining the pose estimation smoothing method. EXAMPLE: {'Method': 'Savitzky Golay', - 'Parameters': {'Time_window': '200'}} + :param str config_path: path to SimBA project config file in Configparser format + :param str data_folder: Path to folder containing SLEAP data in `H5` format. + :param List[str] id_lst: Animal names. This will be ignored in one animal projects and default to ``Animal_1``. + :param Optional[Dict[str, str]] interpolation_setting: Dict defining the type and method to use to perform interpolation {'type': 'animals', 'method': 'linear'}. + :param Optional[Dict[str, Union[str, int]]] smoothing_settings: Dictionary defining the pose estimation smoothing method {'time_window': 500, 'method': 'gaussian'}. + .. note:: `Multi-animal import tutorial `__. @@ -59,8 +53,16 @@ def __init__(self, log_event(logger_name=str(__class__.__name__),log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals())) check_file_exist_and_readable(file_path=config_path) check_if_dir_exists(in_dir=data_folder, source=self.__class__.__name__) - check_str(name=f'{self.__class__.__name__} interpolation_settings', value=interpolation_settings, options=Options.INTERPOLATION_OPTIONS_W_NONE.value) - check_if_keys_exist_in_dict(data=smoothing_settings, key='Method', name=f'{self.__class__.__name__} smoothing_settings') + check_valid_lst(data=id_lst, source=f'{self.__class__.__name__} id_lst', valid_dtypes=(str,), min_len=1) + if interpolation_settings is not None: + check_if_keys_exist_in_dict(data=interpolation_settings, key=['method', 'type'], name=f'{self.__class__.__name__} interpolation_settings') + check_str(name=f'{self.__class__.__name__} interpolation_settings type', value=interpolation_settings['type'], options=('body-parts', 'animals')) + check_str(name=f'{self.__class__.__name__} interpolation_settings method', value=interpolation_settings['method'], options=('linear', 'quadratic', 'nearest')) + if smoothing_settings is not None: + check_if_keys_exist_in_dict(data=smoothing_settings, key=['method', 'time_window'], name=f'{self.__class__.__name__} smoothing_settings') + check_str(name=f'{self.__class__.__name__} smoothing_settings method', value=smoothing_settings['method'], options=('savitzky-golay', 'gaussian')) + check_int(name=f'{self.__class__.__name__} smoothing_settings time_window', value=smoothing_settings['time_window'], min_value=1) + self.interpolation_settings, self.smoothing_settings = interpolation_settings, smoothing_settings self.data_folder, self.id_lst = data_folder, id_lst self.import_log_path = os.path.join(self.logs_path, f"data_import_log_{self.datetime}.csv") @@ -115,8 +117,7 @@ def run(self): f"The number of of body-parts expected by your SimBA project is {int(len(self.bp_headers) / 3)}. " f'The number of of body-parts contained in data file {video_data["DATA"]} is {int(len(self.data_df.columns) / 3)}. ' f"Make sure you have specified the correct number of animals and body-parts in your project.", - source=self.__class__.__name__, - ) + source=self.__class__.__name__) self.data_df.columns = self.bp_headers if self.animal_cnt > 1: self.initialize_multi_animal_ui(animal_bp_dict=self.animal_bp_dict, video_info=get_video_meta_data(video_data["VIDEO"]), data_df=self.data_df, video_path=video_data["VIDEO"]) @@ -126,29 +127,22 @@ def run(self): self.save_path = os.path.join(os.path.join(self.input_csv_dir, f"{self.output_filename}.{self.file_type}")) write_df(df=self.out_df, file_type=self.file_type, save_path=self.save_path, multi_idx_header=True) - if self.interpolation_settings != "None": - self.__run_interpolation() - if self.smoothing_settings["Method"] != "None": - self.__run_smoothing() + if self.interpolation_settings is not None: + interpolator = Interpolate(config_path=self.config_path, data_path=self.save_path, type=self.interpolation_settings['type'], method=self.interpolation_settings['method'], multi_index_df_headers=True, copy_originals=False) + interpolator.run() + if self.smoothing_settings is not None: + smoother = Smoothing(config_path=self.config_path, data_path=self.save_path, time_window=self.smoothing_settings['time_window'], method=self.smoothing_settings['method'], multi_index_df_headers=True, copy_originals=False) + smoother.run() video_timer.stop_timer() stdout_success(msg=f"Video {self.output_filename} data imported...", elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__) self.timer.stop_timer() stdout_success(msg="All SLEAP H5 data files imported", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__) - def __run_interpolation(self): - print(f"Interpolating missing values in video {self.output_filename} (Method: {self.interpolation_settings})...") - _ = Interpolate(input_path=self.save_path, config_path=self.config_path, method=self.interpolation_settings, initial_import_multi_index=True) - - def __run_smoothing(self): - print(f'Performing {self.smoothing_settings["Method"]} smoothing on video {self.output_filename}...') - Smooth(config_path=self.config_path, input_path=self.save_path, time_window=int(self.smoothing_settings["Parameters"]["Time_window"]), smoothing_method=self.smoothing_settings["Method"], initial_import_multi_index=True) - - # test = SLEAPImporterH5(config_path="/Users/simon/Desktop/envs/simba/troubleshooting/sleap_two_animals/project_folder/project_config.ini", -# data_folder=r'/Users/simon/Desktop/envs/simba/troubleshooting/sleap_two_animals/h5_import', -# id_lst=['White', 'Black'], -# interpolation_settings= "Body-parts: Nearest", #'"Body-parts: Nearest", -# smoothing_settings = {'Method': 'None', 'Parameters': {'Time_window': '200'}}) +# data_folder=r'/Users/simon/Desktop/envs/simba/troubleshooting/sleap_two_animals/h5_import', +# id_lst=['White', 'Black'], +# interpolation_settings={'type': 'body-parts', 'method': 'nearest'}, +# smoothing_settings = {'time_window': 500, 'method': 'gaussian'}) # test.run() diff --git a/simba/pose_importers/sleap_slp_importer.py b/simba/pose_importers/sleap_slp_importer.py index af0eb63aa..ce1d1a663 100644 --- a/simba/pose_importers/sleap_slp_importer.py +++ b/simba/pose_importers/sleap_slp_importer.py @@ -4,18 +4,20 @@ import json import os from collections import defaultdict +from typing import Optional, Dict, List, Any, Union import h5py import numpy as np import pandas as pd -from simba.data_processors.interpolation_smoothing import Interpolate, Smooth +from simba.data_processors.smoothing import Smoothing +from simba.data_processors.interpolate import Interpolate from simba.mixins.config_reader import ConfigReader from simba.mixins.pose_importer_mixin import PoseImporterMixin from simba.utils.enums import Methods from simba.utils.printing import SimbaTimer, stdout_success -from simba.utils.read_write import (find_all_videos_in_project, - get_video_meta_data, write_df) +from simba.utils.read_write import (find_all_videos_in_project, get_video_meta_data, write_df) +from simba.utils.checks import check_int, check_str, check_valid_lst, check_if_dir_exists, check_if_keys_exist_in_dict class SLEAPImporterSLP(ConfigReader, PoseImporterMixin): @@ -26,14 +28,11 @@ class SLEAPImporterSLP(ConfigReader, PoseImporterMixin): Importing SLEAP .SLP files into SimBA come at long runtimes. For fater runtimes, use :meth:`simba.pose_importers.sleap_h5_importer.SLEAPImporterH5` or :meth:`simba.pose_importers.sleap_csv_importer.SLEAPImporterCSV` - :parameter str config_path: path to SimBA project config file in Configparser format - :parameter str data_folder: Path to folder containing SLEAP data in `csv` format. - :parameter List[str] id_lst: Animal names. This will be ignored in one animal projects and default to ``Animal_1``. - :parameter str interpolation_settings: String defining the pose-estimation interpolation method. OPTIONS: 'None', 'Animal(s): Nearest', - 'Animal(s): Linear', 'Animal(s): Quadratic','Body-parts: Nearest', 'Body-parts: Linear', - 'Body-parts: Quadratic'. - :parameter str smoothing_settings: Dictionary defining the pose estimation smoothing method. EXAMPLE: {'Method': 'Savitzky Golay', - 'Parameters': {'Time_window': '200'}}. + :param str config_path: path to SimBA project config file in Configparser format + :param str data_folder: Path to folder containing SLEAP data in `csv` format. + :param List[str] id_lst: Animal names. This will be ignored in one animal projects and default to ``Animal_1``. + :param Optional[Dict[str, str]] interpolation_setting: Dict defining the type and method to use to perform interpolation {'type': 'animals', 'method': 'linear'}. + :param Optional[Dict[str, Union[str, int]]] smoothing_settings: Dictionary defining the pose estimation smoothing method {'time_window': 500, 'method': 'gaussian'}. Example ---------- @@ -47,44 +46,38 @@ class SLEAPImporterSLP(ConfigReader, PoseImporterMixin): """ - def __init__( - self, - project_path: str, - data_folder: str, - id_lst: list, - interpolation_settings: str, - smoothing_settings: dict, - ): + def __init__(self, + project_path: Union[str, os.PathLike], + data_folder: Union[str, os.PathLike], + id_lst: List[str], + interpolation_settings: Optional[Dict[str, str]] = None, + smoothing_settings: Optional[Dict[str, Any]] = None): + ConfigReader.__init__(self, config_path=project_path, read_video_info=False) PoseImporterMixin.__init__(self) - self.interpolation_settings, self.smoothing_settings = ( - interpolation_settings, - smoothing_settings, - ) + check_if_dir_exists(in_dir=data_folder) + check_valid_lst(data=id_lst, source=f'{self.__class__.__name__} id_lst', valid_dtypes=(str,)) + if interpolation_settings is not None: + check_if_keys_exist_in_dict(data=interpolation_settings, key=['method', 'type'], name=f'{self.__class__.__name__} interpolation_settings') + check_str(name=f'{self.__class__.__name__} interpolation_settings type', value=interpolation_settings['type'], options=('body-parts', 'animals')) + check_str(name=f'{self.__class__.__name__} interpolation_settings method', value=interpolation_settings['method'], options=('linear', 'quadratic', 'nearest')) + if smoothing_settings is not None: + check_if_keys_exist_in_dict(data=smoothing_settings, key=['method', 'time_window'], name=f'{self.__class__.__name__} smoothing_settings') + check_str(name=f'{self.__class__.__name__} smoothing_settings method', value=smoothing_settings['method'], options=('savitzky-golay', 'gaussian')) + check_int(name=f'{self.__class__.__name__} smoothing_settings time_window', value=smoothing_settings['time_window'], min_value=1) + + + self.interpolation_settings, self.smoothing_settings = interpolation_settings, smoothing_settings self.data_folder, self.id_lst = data_folder, id_lst - self.import_log_path = os.path.join( - self.logs_path, f"data_import_log_{self.datetime}.csv" - ) + self.import_log_path = os.path.join(self.logs_path, f"data_import_log_{self.datetime}.csv") self.video_paths = find_all_videos_in_project(videos_dir=self.video_dir) - self.input_data_paths = self.find_data_files( - dir=self.data_folder, extensions=[".slp"] - ) - self.data_and_videos_lk = self.link_video_paths_to_data_paths( - data_paths=self.input_data_paths, video_paths=self.video_paths - ) + self.input_data_paths = self.find_data_files(dir=self.data_folder, extensions=[".slp"]) + self.data_and_videos_lk = self.link_video_paths_to_data_paths(data_paths=self.input_data_paths, video_paths=self.video_paths) if self.pose_setting is Methods.USER_DEFINED.value: self.__update_config_animal_cnt() if self.animal_cnt > 1: self.check_multi_animal_status() - self.animal_bp_dict = self.create_body_part_dictionary( - self.multi_animal_status, - self.id_lst, - self.animal_cnt, - self.x_cols, - self.y_cols, - self.p_cols, - self.clr_lst, - ) + self.animal_bp_dict = self.create_body_part_dictionary(self.multi_animal_status, self.id_lst, self.animal_cnt, self.x_cols, self.y_cols, self.p_cols, self.clr_lst) if self.pose_setting is Methods.USER_DEFINED.value: self.update_bp_headers_file(update_bp_headers=True) print(f"Importing {len(list(self.data_and_videos_lk.keys()))} file(s)...") @@ -118,34 +111,10 @@ def __fill_missing_indexes(self): ) self.data_df = pd.concat([self.data_df, missing_df], axis=0) - def __run_interpolation(self): - print( - f"Interpolating missing values in video {self.video_name} (Method: {self.interpolation_settings})..." - ) - _ = Interpolate( - input_path=self.save_path, - config_path=self.config_path, - method=self.interpolation_settings, - initial_import_multi_index=True, - ) - - def __run_smoothing(self): - print( - f'Performing {self.smoothing_settings["Method"]} smoothing on video {self.video_name}...' - ) - Smooth( - config_path=self.config_path, - input_path=self.save_path, - time_window=int(self.smoothing_settings["Parameters"]["Time_window"]), - smoothing_method=self.smoothing_settings, - ) - def run(self): - self.analysis_dict = defaultdict(list) - self.save_paths_lst = [] - for file_cnt, (video_name, video_data) in enumerate( - self.data_and_videos_lk.items() - ): + self.analysis_dict, self.save_paths_lst = defaultdict(list), [] + + for file_cnt, (video_name, video_data) in enumerate(self.data_and_videos_lk.items()): print(f"Analysing {video_name}...") video_timer = SimbaTimer(start=True) self.video_name = video_name @@ -163,16 +132,12 @@ def run(self): for orderVar in self.sleap_dict["skeletons"][0]["nodes"]: self.analysis_dict["ordered_ids"].append((orderVar["id"])) for index in self.analysis_dict["ordered_ids"]: - self.analysis_dict["ordered_bps"].append( - self.analysis_dict["bp_names"][index] - ) + self.analysis_dict["ordered_bps"].append(self.analysis_dict["bp_names"][index]) with h5py.File(video_data["DATA"], "r") as file: self.analysis_dict["frames"] = file["frames"][:] self.analysis_dict["instances"] = file["instances"][:] - self.analysis_dict["predicted_points"] = np.reshape( - file["pred_points"][:], (file["pred_points"][:].size, 1) - ) + self.analysis_dict["predicted_points"] = np.reshape(file["pred_points"][:], (file["pred_points"][:].size, 1)) self.analysis_dict["no_frames"] = len(self.analysis_dict["frames"]) for c in itertools.product(self.id_lst, self.analysis_dict["ordered_bps"]): @@ -186,9 +151,7 @@ def run(self): # self.data_df = pd.DataFrame(columns=self.analysis_dict["xyp_headers"]) frames_lst = [l.tolist() for l in self.analysis_dict["frames"]] - self.analysis_dict["animals_in_each_frame"] = [ - x[4] - x[3] for x in frames_lst - ] + self.analysis_dict["animals_in_each_frame"] = [x[4] - x[3] for x in frames_lst] self.__create_tracks() self.initialize_multi_animal_ui( @@ -199,19 +162,14 @@ def run(self): ) if self.animal_cnt > 1: self.multianimal_identification() - self.save_path = os.path.join( - os.path.join(self.input_csv_dir, f"{self.video_name}.{self.file_type}") - ) - write_df( - df=self.out_df, - file_type=self.file_type, - save_path=self.save_path, - multi_idx_header=True, - ) - if self.interpolation_settings != "None": - self.__run_interpolation() - if self.smoothing_settings["Method"] != "None": - self.__run_smoothing() + self.save_path = os.path.join(os.path.join(self.input_csv_dir, f"{self.video_name}.{self.file_type}")) + write_df(df=self.out_df, file_type=self.file_type, save_path=self.save_path, multi_idx_header=True) + if self.interpolation_settings is not None: + interpolator = Interpolate(config_path=self.config_path, data_path=self.save_path, type=self.interpolation_settings['type'], method=self.interpolation_settings['method'], multi_index_df_headers=True, copy_originals=False) + interpolator.run() + if self.smoothing_settings is not None: + smoother = Smoothing(config_path=self.config_path, data_path=self.save_path, time_window=self.smoothing_settings['time_window'], method=self.smoothing_settings['method'], multi_index_df_headers=True, copy_originals=False) + smoother.run() video_timer.stop_timer() stdout_success( msg=f"Video {video_name} data imported...", @@ -273,11 +231,16 @@ def __create_tracks(self): self.data_df.columns = self.bp_headers + + + + + # test = SLEAPImporterSLP(project_path="/Users/simon/Desktop/envs/simba/troubleshooting/sleap_two_animals/project_folder/project_config.ini", # data_folder=r'/Users/simon/Desktop/envs/simba/troubleshooting/sleap_two_animals/slp_import', # id_lst=['Simon', 'JJ'], -# interpolation_settings="Body-parts: Nearest", -# smoothing_settings = {'Method': 'Savitzky Golay', 'Parameters': {'Time_window': '200'}}) #Savitzky Golay +# interpolation_settings={'type': 'animals', 'method': 'linear'}, +# smoothing_settings = {'time_window': 500, 'method': 'gaussian'}) #Savitzky Golay # test.run() # # print('All SLEAP imports complete.') diff --git a/simba/pose_importers/trk_importer.py b/simba/pose_importers/trk_importer.py index f6a7fbffa..892ded167 100644 --- a/simba/pose_importers/trk_importer.py +++ b/simba/pose_importers/trk_importer.py @@ -88,7 +88,7 @@ def trk_read(self, file_path: str): ) return animals_tracked_list - def import_trk(self): + def run(self): for file_cnt, file_path in enumerate(self.data_paths): _, file_name, file_ext = get_fn_ext(file_path) if self.animal_cnt > 0: diff --git a/simba/ui/create_project_ui.py b/simba/ui/create_project_ui.py index d4fce33d9..c818e8003 100644 --- a/simba/ui/create_project_ui.py +++ b/simba/ui/create_project_ui.py @@ -9,29 +9,37 @@ import PIL.Image from PIL import ImageTk + + import simba from simba.mixins.pop_up_mixin import PopUpMixin from simba.ui.pop_ups.clf_add_remove_print_pop_up import PoseResetterPopUp -from simba.ui.pop_ups.create_user_defined_pose_configuration_pop_up import \ - CreateUserDefinedPoseConfigurationPopUp -from simba.ui.tkinter_functions import (CreateLabelFrameWithIcon, DropDownMenu, - Entry_Box, FolderSelect, hxtScrollbar) +from simba.ui.pop_ups.create_user_defined_pose_configuration_pop_up import CreateUserDefinedPoseConfigurationPopUp +from simba.ui.tkinter_functions import (CreateLabelFrameWithIcon, DropDownMenu, Entry_Box, FolderSelect, hxtScrollbar) +from simba.ui.import_pose_frame import ImportPoseFrame +from simba.ui.import_videos_frame import ImportVideosFrame from simba.utils.checks import check_if_dir_exists, check_str from simba.utils.config_creator import ProjectConfigCreator from simba.utils.enums import Formats, Keys, Links, Methods, Options, Paths from simba.utils.errors import DuplicationError, MissingProjectConfigEntryError -from simba.utils.lookups import (get_body_part_configurations, - get_bp_config_codes, get_icons_paths) -from simba.video_processors.video_processing import \ - extract_frames_from_all_videos_in_directory +from simba.utils.lookups import (get_body_part_configurations, get_bp_config_codes, get_icons_paths) +from simba.video_processors.video_processing import extract_frames_from_all_videos_in_directory class ProjectCreatorPopUp(PopUpMixin): """ - Mixin for GUI pop-up windows that accept user-inputs. + Mixin for GUI pop-up windows that accept user-inputs for creating a SimBA project. + + .. image:: _static/img/ProjectCreatorPopUp.webp + :width: 800 + :align: center + + :example: + >>> ProjectCreatorPopUp() """ def __init__(self): + #PopUpMixin.__init__(self, title='') self.main_frm = Toplevel() self.main_frm.minsize(750, 750) self.main_frm.wm_title("PROJECT CONFIGURATION") @@ -40,183 +48,64 @@ def __init__(self): parent_tab = ttk.Notebook(hxtScrollbar(self.main_frm)) self.btn_icons = get_icons_paths() for k in self.btn_icons.keys(): - self.btn_icons[k]["img"] = ImageTk.PhotoImage( - image=PIL.Image.open( - os.path.join( - os.path.dirname("__file__"), self.btn_icons[k]["icon_path"] - ) - ) - ) + self.btn_icons[k]["img"] = ImageTk.PhotoImage(image=PIL.Image.open(os.path.join(os.path.dirname("__file__"), self.btn_icons[k]["icon_path"]))) self.create_project_tab = ttk.Frame(parent_tab) self.import_videos_tab = ttk.Frame(parent_tab) self.import_data_tab = ttk.Frame(parent_tab) self.extract_frms_tab = ttk.Frame(parent_tab) - parent_tab.add( - self.create_project_tab, - text=f'{"[ Create project config ]": ^20s}', - compound="left", - image=self.btn_icons["create"]["img"], - ) - parent_tab.add( - self.import_videos_tab, - text=f'{"[ Import videos ]": ^20s}', - compound="left", - image=self.btn_icons["video"]["img"], - ) - parent_tab.add( - self.import_data_tab, - text=f'{"[ Import tracking data ]": ^20s}', - compound="left", - image=self.btn_icons["pose"]["img"], - ) - parent_tab.add( - self.extract_frms_tab, - text=f'{"[ Extract frames ]": ^20s}', - compound="left", - image=self.btn_icons["frames"]["img"], - ) + parent_tab.add(self.create_project_tab, text=f'{"[ Create project config ]": ^20s}', compound="left", image=self.btn_icons["create"]["img"]) + parent_tab.add(self.import_videos_tab, text=f'{"[ Import videos ]": ^20s}', compound="left", image=self.btn_icons["video"]["img"]) + parent_tab.add(self.import_data_tab, text=f'{"[ Import tracking data ]": ^20s}', compound="left", image=self.btn_icons["pose"]["img"]) + parent_tab.add( self.extract_frms_tab, text=f'{"[ Extract frames ]": ^20s}', compound="left", image=self.btn_icons["frames"]["img"]) parent_tab.grid(row=0, column=0, sticky=NW) - self.settings_frm = CreateLabelFrameWithIcon( - parent=self.create_project_tab, - header="SETTINGS", - icon_name=Keys.DOCUMENTATION.value, - icon_link=Links.CREATE_PROJECT.value, - ) - self.general_settings_frm = LabelFrame( - self.settings_frm, - text="GENERAL PROJECT SETTINGS", - fg="black", - font=Formats.LABELFRAME_HEADER_FORMAT.value, - padx=5, - pady=5, - ) + self.settings_frm = CreateLabelFrameWithIcon(parent=self.create_project_tab, header="SETTINGS", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.CREATE_PROJECT.value) + self.general_settings_frm = LabelFrame(self.settings_frm, text="GENERAL PROJECT SETTINGS", fg="black", font=Formats.LABELFRAME_HEADER_FORMAT.value, padx=5, pady=5) - self.project_dir_select = FolderSelect( - self.general_settings_frm, "Project directory:", lblwidth="25" - ) - self.project_name_eb = Entry_Box( - self.general_settings_frm, "Project name:", labelwidth="25" - ) - self.file_type_dropdown = DropDownMenu( - self.general_settings_frm, - "Workflow file type:", - Options.WORKFLOW_FILE_TYPE_OPTIONS.value, - "25", - ) - self.file_type_dropdown.setChoices( - choice=Options.WORKFLOW_FILE_TYPE_OPTIONS.value[0] - ) + self.project_dir_select = FolderSelect(self.general_settings_frm, "Project directory:", lblwidth="25") + self.project_name_eb = Entry_Box(self.general_settings_frm, "Project name:", labelwidth="25") + self.file_type_dropdown = DropDownMenu(self.general_settings_frm, "Workflow file type:", Options.WORKFLOW_FILE_TYPE_OPTIONS.value, "25") + self.file_type_dropdown.setChoices(choice=Options.WORKFLOW_FILE_TYPE_OPTIONS.value[0]) self.clf_name_entries = [] - self.ml_settings_frm = LabelFrame( - self.settings_frm, - text="MACHINE LEARNING SETTINGS", - font=Formats.LABELFRAME_HEADER_FORMAT.value, - padx=5, - pady=5, - ) - self.clf_cnt = Entry_Box( - self.ml_settings_frm, - "Number of classifiers (behaviors): ", - "25", - validation="numeric", - ) - add_clf_btn = Button( - self.ml_settings_frm, - text="", - fg="blue", - command=lambda: self.create_entry_boxes_from_entrybox( - count=self.clf_cnt.entry_get, - parent=self.ml_settings_frm, - current_entries=self.clf_name_entries, - ), - ) + self.ml_settings_frm = LabelFrame(self.settings_frm, text="MACHINE LEARNING SETTINGS", font=Formats.LABELFRAME_HEADER_FORMAT.value, padx=5, pady=5) + self.clf_cnt = Entry_Box(self.ml_settings_frm, "Number of classifiers (behaviors): ", "25", validation="numeric") + add_clf_btn = Button(self.ml_settings_frm, text="", fg="blue", command=lambda: self.create_entry_boxes_from_entrybox(count=self.clf_cnt.entry_get, parent=self.ml_settings_frm, current_entries=self.clf_name_entries)) - self.animal_settings_frm = LabelFrame( - self.settings_frm, - text="ANIMAL SETTINGS", - font=Formats.LABELFRAME_HEADER_FORMAT.value, - ) - self.tracking_type_dropdown = DropDownMenu( - self.animal_settings_frm, - "Type of Tracking", - Options.TRACKING_TYPE_OPTIONS.value, - "25", - com=self.update_body_part_dropdown, - ) + self.animal_settings_frm = LabelFrame(self.settings_frm, text="ANIMAL SETTINGS", font=Formats.LABELFRAME_HEADER_FORMAT.value) + self.tracking_type_dropdown = DropDownMenu(self.animal_settings_frm, "Type of Tracking", Options.TRACKING_TYPE_OPTIONS.value, "25", com=self.update_body_part_dropdown) self.tracking_type_dropdown.setChoices(Options.TRACKING_TYPE_OPTIONS.value[0]) - project_animal_cnt_path = os.path.join( - os.path.dirname(simba.__file__), Paths.SIMBA_NO_ANIMALS_PATH.value - ) - self.animal_count_lst = list( - pd.read_csv(project_animal_cnt_path, header=None)[0] - ) + project_animal_cnt_path = os.path.join(os.path.dirname(simba.__file__), Paths.SIMBA_NO_ANIMALS_PATH.value) + self.animal_count_lst = list(pd.read_csv(project_animal_cnt_path, header=None)[0]) self.bp_lu = get_body_part_configurations() self.bp_config_codes = get_bp_config_codes() - self.classical_tracking_options = deepcopy( - Options.CLASSICAL_TRACKING_OPTIONS.value - ) - self.multi_tracking_options = deepcopy( - Options.MULTI_ANIMAL_TRACKING_OPTIONS.value - ) - self.three_dim_tracking_options = deepcopy( - Options.THREE_DIM_TRACKING_OPTIONS.value - ) - self.user_defined_options = [ - x - for x in list(self.bp_lu.keys()) - if x not in self.classical_tracking_options - and x not in self.multi_tracking_options - and x not in self.three_dim_tracking_options - ] + self.classical_tracking_options = deepcopy(Options.CLASSICAL_TRACKING_OPTIONS.value) + self.multi_tracking_options = deepcopy(Options.MULTI_ANIMAL_TRACKING_OPTIONS.value) + self.three_dim_tracking_options = deepcopy(Options.THREE_DIM_TRACKING_OPTIONS.value) + self.user_defined_options = [x + for x in list(self.bp_lu.keys()) + if x not in self.classical_tracking_options + and x not in self.multi_tracking_options + and x not in self.three_dim_tracking_options] for k in self.bp_lu.keys(): - self.bp_lu[k]["img"] = ImageTk.PhotoImage( - file=os.path.join( - os.path.dirname("__file__"), self.bp_lu[k]["img_path"] - ) - ) - self.classical_tracking_option_dict = { - k: self.bp_lu[k] for k in self.classical_tracking_options - } - self.multi_tracking_option_dict = { - k: self.bp_lu[k] for k in self.multi_tracking_options - } + self.bp_lu[k]["img"] = ImageTk.PhotoImage(file=os.path.join(os.path.dirname("__file__"), self.bp_lu[k]["img_path"])) + self.classical_tracking_option_dict = {k: self.bp_lu[k] for k in self.classical_tracking_options} + self.multi_tracking_option_dict = {k: self.bp_lu[k] for k in self.multi_tracking_options} self.classical_tracking_options.append(Methods.CREATE_POSE_CONFIG.value) self.multi_tracking_options.append(Methods.CREATE_POSE_CONFIG.value) self.three_dim_tracking_options.append(Methods.CREATE_POSE_CONFIG.value) self.classical_tracking_options.extend(self.user_defined_options) self.multi_tracking_options.extend(self.user_defined_options) self.three_dim_tracking_options.extend(self.user_defined_options) - self.selected_tracking_dropdown = DropDownMenu( - self.animal_settings_frm, - "Body-part config", - Options.CLASSICAL_TRACKING_OPTIONS.value, - "25", - com=self.update_img, - ) + self.selected_tracking_dropdown = DropDownMenu(self.animal_settings_frm, "Body-part config", Options.CLASSICAL_TRACKING_OPTIONS.value, "25", com=self.update_img) self.selected_tracking_dropdown.setChoices(self.classical_tracking_options[0]) - self.img_lbl = Label( - self.animal_settings_frm, - image=self.bp_lu[self.classical_tracking_options[0]]["img"], - ) - reset_btn = Button( - self.animal_settings_frm, - text="RESET USER DEFINED POSE-CONFIGS", - fg="red", - command=lambda: PoseResetterPopUp(), - ) + self.img_lbl = Label(self.animal_settings_frm, image=self.bp_lu[self.classical_tracking_options[0]]["img"]) + reset_btn = Button(self.animal_settings_frm, text="RESET USER DEFINED POSE-CONFIGS", fg="red", command=lambda: PoseResetterPopUp()) run_frm = Frame(master=self.settings_frm) - create_project_btn = Button( - run_frm, - text="CREATE PROJECT CONFIG", - fg="navy", - font=("Helvetica", 16, "bold"), - command=lambda: self.run(), - ) + create_project_btn = Button(run_frm, text="CREATE PROJECT CONFIG", fg="navy", font=("Helvetica", 16, "bold"), command=lambda: self.run()) self.settings_frm.grid(row=0, column=0, sticky=NW) self.general_settings_frm.grid(row=0, column=0, sticky=NW) self.project_dir_select.grid(row=0, column=0, sticky=NW) @@ -235,24 +124,11 @@ def __init__(self): run_frm.grid(row=3, column=0, sticky=NW) create_project_btn.grid(row=0, column=0, sticky=NW) - self.create_import_videos_menu(parent_frm=self.import_videos_tab) - self.create_import_pose_menu(parent_frm=self.import_data_tab) - - extract_frames_frm = LabelFrame( - self.extract_frms_tab, - text="EXTRACT FRAMES INTO PROJECT", - fg="black", - font=Formats.LABELFRAME_HEADER_FORMAT.value, - pady=5, - padx=5, - ) - extract_frames_note = Label( - extract_frames_frm, - text="Note: Frame extraction is not needed for any of the parts of the SimBA pipeline.\n Caution: This extract all frames from all videos in project. \n and is computationally expensive if there is a lot of videos at high frame rates/resolution.", - ) - extract_frames_btn = Button( - extract_frames_frm, text="EXTRACT FRAMES", fg="blue", command=lambda: None - ) + ImportVideosFrame(parent_frm=self.import_videos_tab, config_path=None, idx_row=0, idx_column=0) + ImportPoseFrame(parent_frm=self.import_data_tab, config_path=None, idx_row=0, idx_column=0) + extract_frames_frm = LabelFrame(self.extract_frms_tab, text="EXTRACT FRAMES INTO PROJECT", fg="black", font=Formats.LABELFRAME_HEADER_FORMAT.value, pady=5, padx=5) + extract_frames_note = Label(extract_frames_frm, text="Note: Frame extraction is not needed for any of the parts of the SimBA pipeline.\n Caution: This extract all frames from all videos in project. \n and is computationally expensive if there is a lot of videos at high frame rates/resolution.") + extract_frames_btn = Button(extract_frames_frm, text="EXTRACT FRAMES", fg="blue", command=lambda: None) extract_frames_frm.grid(row=0, column=0, sticky=NW) extract_frames_note.grid(row=0, column=0, sticky=NW) @@ -263,38 +139,16 @@ def __init__(self): def update_body_part_dropdown(self, selected_value): self.selected_tracking_dropdown.destroy() if selected_value == Methods.MULTI_TRACKING.value: - self.selected_tracking_dropdown = DropDownMenu( - self.animal_settings_frm, - "Body-part config", - self.multi_tracking_options, - "25", - com=self.update_img, - ) + self.selected_tracking_dropdown = DropDownMenu(self.animal_settings_frm, "Body-part config", self.multi_tracking_options, "25", com=self.update_img) self.selected_tracking_dropdown.setChoices(self.multi_tracking_options[0]) self.selected_tracking_dropdown.grid(row=1, column=0, sticky=NW) elif selected_value == Methods.CLASSIC_TRACKING.value: - self.selected_tracking_dropdown = DropDownMenu( - self.animal_settings_frm, - "Body-part config", - self.classical_tracking_options, - "25", - com=self.update_img, - ) - self.selected_tracking_dropdown.setChoices( - self.classical_tracking_options[0] - ) + self.selected_tracking_dropdown = DropDownMenu(self.animal_settings_frm, "Body-part config", self.classical_tracking_options, "25", com=self.update_img) + self.selected_tracking_dropdown.setChoices(self.classical_tracking_options[0]) self.selected_tracking_dropdown.grid(row=1, column=0, sticky=NW) elif selected_value == Methods.THREE_D_TRACKING.value: - self.selected_tracking_dropdown = DropDownMenu( - self.animal_settings_frm, - "Body-part config", - self.three_dim_tracking_options, - "25", - com=self.update_img, - ) - self.selected_tracking_dropdown.setChoices( - self.three_dim_tracking_options[0] - ) + self.selected_tracking_dropdown = DropDownMenu(self.animal_settings_frm, "Body-part config", self.three_dim_tracking_options, "25", com=self.update_img) + self.selected_tracking_dropdown.setChoices(self.three_dim_tracking_options[0]) self.selected_tracking_dropdown.grid(row=1, column=0, sticky=NW) self.update_img(self.selected_tracking_dropdown.getChoices()) @@ -302,19 +156,13 @@ def update_img(self, selected_value): if selected_value != Methods.CREATE_POSE_CONFIG.value: self.img_lbl.config(image=self.bp_lu[selected_value]["img"]) else: - _ = CreateUserDefinedPoseConfigurationPopUp( - master=self.main_frm, project_config_class=ProjectCreatorPopUp - ) + _ = CreateUserDefinedPoseConfigurationPopUp(master=self.main_frm, project_config_class=ProjectCreatorPopUp) def extract_frames(self): if not hasattr(self, "config_path"): - raise MissingProjectConfigEntryError( - msg="Create PROJECT CONFIG before extracting frames" - ) + raise MissingProjectConfigEntryError(msg="Create PROJECT CONFIG before extracting frames") video_dir = os.path.join(os.path.dirname(self.config_path), "videos") - extract_frames_from_all_videos_in_directory( - config_path=self.config_path, directory=video_dir - ) + extract_frames_from_all_videos_in_directory(config_path=self.config_path, directory=video_dir) def run(self): project_dir = self.project_dir_select.folder_path @@ -340,19 +188,16 @@ def run(self): config_idx = cnt animal_cnt = self.animal_count_lst[config_idx] - config_creator = ProjectConfigCreator( - project_path=project_dir, - project_name=project_name, - target_list=target_list, - pose_estimation_bp_cnt=config_code, - body_part_config_idx=config_idx, - animal_cnt=animal_cnt, - file_type=self.file_type_dropdown.getChoices(), - ) + config_creator = ProjectConfigCreator(project_path=project_dir, + project_name=project_name, + target_list=target_list, + pose_estimation_bp_cnt=config_code, + body_part_config_idx=config_idx, + animal_cnt=animal_cnt, + file_type=self.file_type_dropdown.getChoices()) self.config_path = config_creator.config_path - self.create_import_pose_menu(parent_frm=self.import_data_tab) - self.create_import_videos_menu(parent_frm=self.import_videos_tab) - + ImportPoseFrame(parent_frm=self.import_data_tab, idx_row=0, idx_column=0, config_path=self.config_path) + ImportVideosFrame(parent_frm=self.import_videos_tab, config_path=self.config_path, idx_row=0, idx_column=0) -# ProjectCreatorPopUp() +#ProjectCreatorPopUp() diff --git a/simba/ui/import_pose_frame.py b/simba/ui/import_pose_frame.py new file mode 100644 index 000000000..493ff0803 --- /dev/null +++ b/simba/ui/import_pose_frame.py @@ -0,0 +1,346 @@ +__author__ = "Simon Nilsson" + +import os +from tkinter import * +from tkinter import ttk +from typing import Dict, Optional, Union +try: + from typing import Literal +except: + from typing_extensions import Literal + + +from simba.mixins.config_reader import ConfigReader +from simba.mixins.pop_up_mixin import PopUpMixin +from simba.pose_importers.trk_importer import TRKImporter +from simba.pose_importers.dlc_importer_csv import import_dlc_csv_data +from simba.pose_importers.import_mars import MarsImporter +from simba.pose_importers.madlc_importer import MADLCImporterH5 +from simba.pose_importers.read_DANNCE_mat import (import_DANNCE_file, + import_DANNCE_folder) +from simba.pose_importers.sleap_csv_importer import SLEAPImporterCSV +from simba.pose_importers.sleap_h5_importer import SLEAPImporterH5 +from simba.pose_importers.sleap_slp_importer import SLEAPImporterSLP +from simba.ui.tkinter_functions import (DropDownMenu, Entry_Box, FileSelect, FolderSelect) +from simba.utils.checks import (check_int, check_str, check_instance) +from simba.utils.enums import ConfigKey, Formats, Options, Dtypes +from simba.utils.errors import InvalidInputError + +from simba.utils.read_write import read_config_file + + +GAUSSIAN = 'Gaussian' +SAVITZKY_GOLAY = 'Savitzky Golay' +INTERPOLATION_MAP = {'Animal(s)': 'animals', 'Body-parts': 'body-parts'} +SMOOTHING_MAP = {'Savitzky Golay': 'savitzky-golay', 'Gaussian': 'gaussian'} + +FRAME_DIR_IMPORT_TITLES = {'CSV (DLC/DeepPoseKit)': 'IMPORT DLC CSV DIRECTORY', 'MAT (DANNCE 3D)': 'IMPORT DANNCE MAT DIRECTORY', 'JSON (BENTO)': 'IMPORT MARS JSON DIRECTORY'} +FRAME_FILE_IMPORT_TITLES = {'CSV (DLC/DeepPoseKit)': 'IMPORT DLC CSV FILE', 'MAT (DANNCE 3D)': 'IMPORT DANNCE MAT FILE', 'JSON (BENTO)': 'IMPORT MARS JSON FILE'} +FILE_TYPES = {'CSV (DLC/DeepPoseKit)': '*.csv', 'MAT (DANNCE 3D)': '*.mat', 'JSON (BENTO)': '*.json'} + + +class ImportPoseFrame(ConfigReader, PopUpMixin): + + """ + .. image:: _static/img/ImportPoseFrame.webp + :width: 500 + :align: center + + :param Optional[Union[Frame, Canvas, LabelFrame, ttk.Frame]] parent_frm: Parent frame to insert the Import pose frame into. If None, one is created. + :param Optional[Union[str, os.PathLike]] config_path: + :param Optional[int] idx_row: The row in parent_frm to insert the Import pose frame into. Default: 0. + :param Optional[int] idx_column: The column in parent_frm to insert the Import pose frame into. Default: 0. + + :example: + >>> _ = ImportPoseFrame(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini') + """ + + def __init__(self, + parent_frm: Optional[Union[Frame, Canvas, LabelFrame, ttk.Frame]] = None, + config_path: Optional[Union[str, os.PathLike]] = None, + idx_row: Optional[int] = 0, + idx_column: Optional[int] = 0): + + if parent_frm is None and config_path is None: + raise InvalidInputError(msg='If parent_frm is None, please pass config_path', source=self.__class__.__name__) + + elif parent_frm is None and config_path is not None: + PopUpMixin.__init__(self, config_path=config_path, title='IMPORT POSE ESTIMATION') + parent_frm = self.main_frm + + check_instance(source=f'{self.__class__.__name__} parent_frm', accepted_types=(Frame, Canvas, LabelFrame, ttk.Frame), instance=parent_frm) + check_int(name=f'{self.__class__.__name__} idx_row', value=idx_row, min_value=0) + check_int(name=f'{self.__class__.__name__} idx_column', value=idx_column, min_value=0) + + self.import_tracking_frm = LabelFrame(parent_frm, text="IMPORT TRACKING DATA", font=Formats.LABELFRAME_HEADER_FORMAT.value, fg="black") + self.import_tracking_frm.grid(row=0, column=0, sticky=NW) + if config_path is None: + Label(self.import_tracking_frm, text="Please CREATE PROJECT CONFIG before importing tracking data \n").grid(row=0, column=0, sticky=NW) + else: + ConfigReader.__init__(self, config_path=config_path, read_video_info=False) + self.data_type_dropdown = DropDownMenu(self.import_tracking_frm, "DATA TYPE:", Options.IMPORT_TYPE_OPTIONS.value, labelwidth=25, com=self.create_import_menu) + self.data_type_dropdown.setChoices(Options.IMPORT_TYPE_OPTIONS.value[0]) + self.data_type_dropdown.grid(row=0, column=0, sticky=NW) + + self.create_import_menu(data_type_choice=Options.IMPORT_TYPE_OPTIONS.value[0]) + self.import_tracking_frm.grid(row=idx_row, column=idx_column, sticky=NW) + + # parent_frm.mainloop() + + def __show_smoothing_entry_box_from_dropdown(self, choice: str): + if (choice == GAUSSIAN) or (choice == SAVITZKY_GOLAY): + self.smoothing_time_eb.grid(row=0, column=1, sticky=E) + else: + self.smoothing_time_eb.grid_forget() + + + def __get_smooth_interpolation_settings(self, + interpolation_settings: str, + smoothing_setting: str, + smoothing_time: Union[str, int]): + + if interpolation_settings != Dtypes.NONE.value: + interpolation_settings = interpolation_settings.split(':') + interpolation_settings = {'type': INTERPOLATION_MAP[interpolation_settings[0]].lower().strip(), 'method': interpolation_settings[1].lower().strip()} + else: + interpolation_settings = None + if smoothing_setting != Dtypes.NONE.value: + check_int(name='SMOOTHING TIME', value=smoothing_time, min_value=1) + smoothing_setting = {'time_window': int(smoothing_time), 'method': SMOOTHING_MAP[smoothing_setting]} + else: + smoothing_setting = None + + return interpolation_settings, smoothing_setting + + + def __import_dlc_csv_data(self, + interpolation_settings: str, + smoothing_setting: str, + smoothing_time: Union[str, int], + data_path: Union[str, os.PathLike]): + + if not os.path.isfile(data_path) and not os.path.isdir(data_path): + raise InvalidInputError(msg=f'{data_path} is NOT a valid path', source=self.__class__.__name__) + + smoothing_settings, interpolation_settings = self.__get_smooth_interpolation_settings(interpolation_settings, smoothing_setting, smoothing_time) + import_dlc_csv_data(config_path=self.config_path, + data_path=data_path, + interpolation_settings=interpolation_settings, + smoothing_settings=smoothing_setting) + + def __multi_animal_run_call(self, + pose_estimation_tool: str, + interpolation_settings: str, + smoothing_settings: str, + smoothing_window: int, + animal_names: Dict[int, Entry_Box], + data_path: Union[str, os.PathLike], + tracking_data_type: Optional[str] = None): + + if not os.path.isfile(data_path) and not os.path.isdir(data_path): + raise InvalidInputError(msg=f'{data_path} is NOT a valid path', source=self.__class__.__name__) + smoothing_settings, interpolation_settings = self.__get_smooth_interpolation_settings(interpolation_settings, smoothing_settings, smoothing_window) + animal_ids = [] + if len(list(animal_names.items())) == 1: animal_ids.append("Animal_1") + else: + for animal_cnt, animal_entry_box in animal_names.items(): + check_str(name=f"ANIMAL {str(animal_cnt)} NAME", value=animal_entry_box.entry_get, allow_blank=False) + animal_ids.append(animal_entry_box.entry_get) + + config = read_config_file(config_path=self.config_path) + config.set(ConfigKey.MULTI_ANIMAL_ID_SETTING.value, ConfigKey.MULTI_ANIMAL_IDS.value, ",".join(animal_ids)) + with open(config, "w") as f: config.write(f) + + if pose_estimation_tool == "H5 (multi-animal DLC)": + data_importer = MADLCImporterH5(config_path=self.config_path, + data_folder=data_path, + file_type=tracking_data_type, + id_lst=animal_ids, + interpolation_settings=interpolation_settings, + smoothing_settings=smoothing_settings) + + elif pose_estimation_tool == "SLP (SLEAP)": + data_importer = SLEAPImporterSLP(project_path=self.config_path, + data_folder=data_path, + id_lst=animal_ids, + interpolation_settings=interpolation_settings, + smoothing_settings=smoothing_settings) + + elif pose_estimation_tool == "TRK (multi-animal APT)": + data_importer = TRKImporter(config_path=self.config_path, + data_path=data_path, + animal_id_lst=animal_ids, + interpolation_method=interpolation_settings, + smoothing_settings=smoothing_settings) + + elif pose_estimation_tool == "CSV (SLEAP)": + data_importer = SLEAPImporterCSV(config_path=self.config_path, + data_folder=data_path, + id_lst=animal_ids, + interpolation_settings=interpolation_settings, + smoothing_settings=smoothing_settings) + + elif pose_estimation_tool == "H5 (SLEAP)": + data_importer = SLEAPImporterH5(config_path=self.config_path, + data_folder=data_path, + id_lst=animal_ids, + interpolation_settings=interpolation_settings, + smoothing_settings=smoothing_settings) + else: + raise InvalidInputError(msg=f'pose estimation tool {pose_estimation_tool} not recognized', source=self.__class__.__name__) + data_importer.run() + + def __create_animal_names_entry_boxes(self, + animal_cnt: str) -> None: + check_int(name="NUMBER OF ANIMALS", value=animal_cnt, min_value=0) + if hasattr(self, "animal_names_frm"): + self.animal_names_frm.destroy() + if not hasattr(self, "multi_animal_id_list"): + self.multi_animal_id_list = [] + for i in range(int(animal_cnt)): + self.multi_animal_id_list.append(f"Animal {i+1}") + self.animal_names_frm = Frame(self.animal_settings_frm, pady=5, padx=5) + self.animal_name_entry_boxes = {} + for i in range(int(animal_cnt)): + self.animal_name_entry_boxes[i + 1] = Entry_Box(self.animal_names_frm, f"Animal {str(i+1)} name: ", "25") + if i <= len(self.multi_animal_id_list) - 1: + self.animal_name_entry_boxes[i + 1].entry_set(self.multi_animal_id_list[i]) + self.animal_name_entry_boxes[i + 1].grid(row=i, column=0, sticky=NW) + self.animal_names_frm.grid(row=1, column=0, sticky=NW) + + def create_import_menu(self, data_type_choice: Literal["CSV (DLC/DeepPoseKit)", "JSON (BENTO)", "H5 (multi-animal DLC)", "SLP (SLEAP)", "CSV (SLEAP)", "H5 (SLEAP)", "TRK (multi-animal APT)", "MAT (DANNCE 3D)"]): + if hasattr(self, "choice_frm"): + self.choice_frm.destroy() + + self.choice_frm = Frame(self.import_tracking_frm) + self.choice_frm.grid(row=1, column=0, sticky=NW) + self.animal_name_entry_boxes = None + self.interpolation_frm = LabelFrame(self.choice_frm, text="INTERPOLATION METHOD", pady=5, padx=5) + self.interpolation_dropdown = DropDownMenu(self.interpolation_frm, "Interpolation method: ", Options.INTERPOLATION_OPTIONS_W_NONE.value, "25") + self.interpolation_dropdown.setChoices(Options.INTERPOLATION_OPTIONS_W_NONE.value[0]) + self.interpolation_frm.grid(row=0, column=0, sticky=NW) + self.interpolation_dropdown.grid(row=0, column=0, sticky=NW) + + self.smoothing_frm = LabelFrame(self.choice_frm, text="SMOOTHING METHOD", pady=5, padx=5) + self.smoothing_dropdown = DropDownMenu(self.smoothing_frm, "Smoothing", Options.SMOOTHING_OPTIONS_W_NONE.value, "25", com=self.__show_smoothing_entry_box_from_dropdown) + self.smoothing_dropdown.setChoices(Options.SMOOTHING_OPTIONS_W_NONE.value[0]) + self.smoothing_time_eb = Entry_Box(self.smoothing_frm, "Smoothing period (milliseconds):", labelwidth="25", width=10, validation="numeric") + self.smoothing_frm.grid(row=1, column=0, sticky=NW) + self.smoothing_dropdown.grid(row=0, column=0, sticky=NW) + + if data_type_choice in ["CSV (DLC/DeepPoseKit)", "MAT (DANNCE 3D)", "JSON (BENTO)"]: # DATA TYPES WHERE NO TRACKS HAVE TO BE SPECIFIED + self.import_directory_frm = LabelFrame(self.choice_frm, text=FRAME_DIR_IMPORT_TITLES[data_type_choice], pady=5, padx=5) + self.import_directory_select = FolderSelect(self.import_directory_frm, "Input data DIRECTORY:", lblwidth=25, initialdir=self.project_path) + self.import_single_frm = LabelFrame(self.choice_frm, text=FRAME_FILE_IMPORT_TITLES[data_type_choice], pady=5, padx=5) + self.import_file_select = FileSelect(self.import_single_frm, "Input data FILE:", lblwidth=25, file_types=[("Pose data file", FILE_TYPES[data_type_choice])]) + + if data_type_choice == "CSV (DLC/DeepPoseKit)": + self.import_dir_btn = Button(self.import_directory_frm, fg="blue", text="Import DLC CSV DIRECTORY to SimBA project", command=lambda: self.__import_dlc_csv_data(interpolation_settings=self.interpolation_dropdown.getChoices(), + smoothing_setting=self.smoothing_dropdown.getChoices(), + smoothing_time=self.smoothing_time_eb.entry_get, + data_path=self.import_directory_select.folder_path)) + self.import_file_btn = Button(self.import_single_frm, fg="blue", text="Import DLC CSV FILE to SimBA project", command=lambda: self.__import_dlc_csv_data(interpolation_settings=self.interpolation_dropdown.getChoices(), + smoothing_setting=self.smoothing_dropdown.getChoices(), + smoothing_time=self.smoothing_time_eb.entry_get, + data_path=self.import_file_select.file_path)) + elif data_type_choice == "MAT (DANNCE 3D)": + self.import_dir_btn = Button(self.import_directory_frm, fg="blue", text="Import DANNCE MAT DIRECTORY to SimBA project", command=lambda: import_DANNCE_folder(config_path=self.config_path, + folder_path=self.import_directory_select.folder_path, + interpolation_method=self.interpolation_dropdown.getChoices())) + + self.import_file_btn = Button(self.import_single_frm, fg="blue", text="Import DANNCE MAT FILE to SimBA project", command=lambda: import_DANNCE_file(config_path=self.config_path, + file_path=self.import_file_select.file_path, + interpolation_method=self.interpolation_dropdown.getChoices())) + else: + self.import_dir_btn = Button(self.import_directory_frm, fg="blue", text="Import BENTO JSON DIRECTORY to SimBA project", command=lambda: MarsImporter(config_path=self.config_path, + data_path=self.import_directory_select.folder_path, + interpolation_method=self.interpolation_dropdown.getChoices(), + smoothing_method={"Method": self.smoothing_dropdown.getChoices(), "Parameters": {"Time_window": self.smoothing_time_eb.entry_get}})) + + self.import_file_btn = Button(self.import_single_frm, fg="blue", text="Import BENTO JSON FILE to SimBA project", command=lambda: MarsImporter(config_path=self.config_path, data_path=self.import_directory_select.folder_path, interpolation_method=self.interpolation_dropdown.getChoices(), + smoothing_method={"Method": self.smoothing_dropdown.getChoices(), "Parameters": {"Time_window": self.smoothing_time_eb.entry_get}})) + + self.import_directory_frm.grid(row=2, column=0, sticky=NW) + self.import_directory_select.grid(row=0, column=0, sticky=NW) + self.import_dir_btn.grid(row=1, column=0, sticky=NW) + + self.import_single_frm.grid(row=3, column=0, sticky=NW) + self.import_file_select.grid(row=0, column=0, sticky=NW) + self.import_file_btn.grid(row=1, column=0, sticky=NW) + + else: # DATA TYPES WHERE TRACKS HAVE TO BE SPECIFIED + self.animal_settings_frm = LabelFrame(self.choice_frm, text="ANIMAL SETTINGS", pady=5, padx=5) + animal_cnt_entry_box = Entry_Box(self.animal_settings_frm, "ANIMAL COUNT:", "25", validation="numeric") + animal_cnt_entry_box.entry_set(val=self.animal_cnt) + animal_cnt_confirm = Button(self.animal_settings_frm, text="CONFIRM", fg="blue", command=lambda: self.create_animal_names_entry_boxes( animal_cnt=animal_cnt_entry_box.entry_get)) + self.create_animal_names_entry_boxes(animal_cnt=animal_cnt_entry_box.entry_get) + self.animal_settings_frm.grid(row=4, column=0, sticky=NW) + animal_cnt_entry_box.grid(row=0, column=0, sticky=NW) + animal_cnt_confirm.grid(row=0, column=1, sticky=NW) + + self.data_dir_frm = LabelFrame(self.choice_frm, text="DATA DIRECTORY", pady=5, padx=5) + self.import_frm = LabelFrame(self.choice_frm, text="IMPORT", pady=5, padx=5) + + if data_type_choice == "H5 (multi-animal DLC)": + self.tracking_type_frm = LabelFrame(self.choice_frm, text="TRACKING DATA TYPE", pady=5, padx=5) + self.dlc_data_type_option_dropdown = DropDownMenu(self.tracking_type_frm, "TRACKING_TYPE", Options.MULTI_DLC_TYPE_IMPORT_OPTION.value, labelwidth=25) + self.dlc_data_type_option_dropdown.setChoices(Options.MULTI_DLC_TYPE_IMPORT_OPTION.value[1]) + self.tracking_type_frm.grid(row=5, column=0, sticky=NW) + self.dlc_data_type_option_dropdown.grid(row=0, column=0, sticky=NW) + self.data_dir_select = FolderSelect(self.data_dir_frm, "H5 DLC DIRECTORY: ", lblwidth=25) + self.instructions_lbl = Label(self.data_dir_frm, text="Please import videos BEFORE importing the \n multi animal DLC tracking data") + self.run_btn = Button(self.import_frm, text="IMPORT DLC .H5", fg="blue", command=lambda: self.__multi_animal_run_call(pose_estimation_tool=data_type_choice, + interpolation_settings=self.interpolation_dropdown.getChoices(), + smoothing_settings=self.smoothing_dropdown.getChoices(), + smoothing_window=self.smoothing_time_eb.entry_get, + animal_names=self.animal_name_entry_boxes, + data_path=self.data_dir_select.folder_path, + tracking_data_type=self.dlc_data_type_option_dropdown.getChoices())) + elif data_type_choice == "SLP (SLEAP)": + self.data_dir_select = FolderSelect(self.data_dir_frm, "SLP SLEAP DIRECTORY: ", lblwidth=25) + self.instructions_lbl = Label(self.data_dir_frm, text="Please import videos before importing the \n multi animal SLEAP tracking data if you are tracking more than ONE animal") + self.run_btn = Button(self.import_frm, text="IMPORT SLEAP .SLP", fg="blue", command=lambda: self.__multi_animal_run_call(pose_estimation_tool=data_type_choice, + interpolation_settings=self.interpolation_dropdown.getChoices(), + smoothing_settings=self.smoothing_dropdown.getChoices(), + smoothing_window=self.smoothing_time_eb.entry_get, + animal_names=self.animal_name_entry_boxes, + data_path=self.data_dir_select.folder_path)) + + elif data_type_choice == "TRK (multi-animal APT)": + self.data_dir_select = FolderSelect(self.data_dir_frm, "TRK APT DIRECTORY: ", lblwidth=25) + self.instructions_lbl = Label(self.data_dir_frm, text="Please import videos before importing the \n multi animal TRK tracking data") + self.run_btn = Button(self.import_frm, text="IMPORT APT .TRK", fg="blue", command=lambda: self.__multi_animal_run_call(pose_estimation_tool=data_type_choice, + interpolation_settings=self.interpolation_dropdown.getChoices(), + smoothing_settings=self.smoothing_dropdown.getChoices(), + smoothing_window=self.smoothing_time_eb.entry_get, + animal_names=self.animal_name_entry_boxes, + data_path=self.data_dir_select.folder_path)) + + elif data_type_choice == "CSV (SLEAP)": + self.data_dir_select = FolderSelect(self.data_dir_frm, "CSV SLEAP DIRECTORY:", lblwidth=25) + self.instructions_lbl = Label(self.data_dir_frm, text="Please import videos before importing the SLEAP tracking data \n IF you are tracking more than ONE animal") + self.run_btn = Button(self.import_frm, text="IMPORT SLEAP .CSV", fg="blue", command=lambda: self.__multi_animal_run_call(pose_estimation_tool=data_type_choice, + interpolation_settings=self.interpolation_dropdown.getChoices(), + smoothing_settings=self.smoothing_dropdown.getChoices(), + smoothing_window=self.smoothing_time_eb.entry_get, + animal_names=self.animal_name_entry_boxes, + data_path=self.data_dir_select.folder_path)) + + elif data_type_choice == "H5 (SLEAP)": + self.data_dir_select = FolderSelect(self.data_dir_frm, "H5 SLEAP DIRECTORY", lblwidth=25) + self.instructions_lbl = Label(self.data_dir_frm,text="Please import videos before importing the SLEAP H5 tracking data \n IF you are tracking more than ONE animal") + self.run_btn = Button(self.import_frm, text="IMPORT SLEAP H5", fg="blue", command=lambda: self.__multi_animal_run_call(pose_estimation_tool=data_type_choice, + interpolation_settings=self.interpolation_dropdown.getChoices(), + smoothing_settings=self.smoothing_dropdown.getChoices(), + smoothing_window=self.smoothing_time_eb.entry_get, + animal_names=self.animal_name_entry_boxes, + data_path=self.data_dir_select.folder_path)) + + self.data_dir_frm.grid(row=self.frame_children(frame=self.choice_frm), column=0, sticky=NW) + self.data_dir_select.grid(row=0, column=0, sticky=NW) + self.instructions_lbl.grid(row=1, column=0, sticky=NW) + self.import_frm.grid(row=self.frame_children(frame=self.choice_frm) + 1, column=0, sticky=NW) + self.run_btn.grid(row=0, column=0, sticky=NW) + self.choice_frm.grid(row=1, column=0, sticky=NW) + +#_ = ImportPoseFrame(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini') \ No newline at end of file diff --git a/simba/ui/import_videos_frame.py b/simba/ui/import_videos_frame.py new file mode 100644 index 000000000..e53a7fe07 --- /dev/null +++ b/simba/ui/import_videos_frame.py @@ -0,0 +1,94 @@ +import os +from tkinter import * +from tkinter import ttk +from typing import Optional, Union + +from simba.mixins.config_reader import ConfigReader +from simba.mixins.pop_up_mixin import PopUpMixin +from simba.ui.tkinter_functions import (DropDownMenu, FileSelect, FolderSelect) +from simba.utils.checks import (check_file_exist_and_readable, check_if_dir_exists, check_int, check_instance) +from simba.utils.enums import Formats, Options +from simba.utils.errors import InvalidInputError +from simba.utils.read_write import (copy_multiple_videos_to_project, copy_single_video_to_project) + +class ImportVideosFrame(PopUpMixin, ConfigReader): + + """ + .. image:: _static/img/ImportVideosFrame.webp + :width: 500 + :align: center + + :param Optional[Union[Frame, Canvas, LabelFrame, ttk.Frame]] parent_frm: Parent frame to insert the Import Videos frame into. If None, one is created. + :param Optional[Union[str, os.PathLike]] config_path: + :param Optional[int] idx_row: The row in parent_frm to insert the Videos frame into. Default: 0. + :param Optional[int] idx_column: The column in parent_frm to insert the Videos frame into. Default: 0. + + :example: + >>> ImportVideosFrame(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini') + """ + + def __init__(self, + parent_frm: Optional[Union[Frame, Canvas, LabelFrame, ttk.Frame]] = None, + config_path: Optional[Union[str, os.PathLike]] = None, + idx_row: Optional[int] = 0, + idx_column: Optional[int] = 0): + + if parent_frm is None and config_path is None: + raise InvalidInputError(msg='If parent_frm is None, please pass config_path', source=self.__class__.__name__) + + elif parent_frm is None and config_path is not None: + PopUpMixin.__init__(self, config_path=config_path, title='IMPORT VIDEO FILES') + parent_frm = self.main_frm + + check_instance(source=f'{ImportVideosFrame} parent_frm', accepted_types=(Frame, Canvas, LabelFrame, ttk.Frame), instance=parent_frm) + check_int(name=f'{ImportVideosFrame} idx_row', value=idx_row, min_value=0) + check_int(name=f'{ImportVideosFrame} idx_column', value=idx_column, min_value=0) + + import_videos_frm = LabelFrame(parent_frm, text="IMPORT VIDEOS", fg="black", font=Formats.LABELFRAME_HEADER_FORMAT.value) + if config_path is None: + Label(import_videos_frm, text="Please CREATE PROJECT CONFIG before importing VIDEOS \n").grid(row=0, column=0, sticky=NW) + import_videos_frm.grid(row=0, column=0, sticky=NW) + else: + ConfigReader.__init__(self, config_path=config_path, read_video_info=False) + import_multiple_videos_frm = LabelFrame(import_videos_frm, text="IMPORT MULTIPLE VIDEOS") + self.video_directory_select = FolderSelect(import_multiple_videos_frm, "VIDEO DIRECTORY: ", lblwidth=25) + self.video_type = DropDownMenu(import_multiple_videos_frm, "VIDEO FILE FORMAT: ", Options.VIDEO_FORMAT_OPTIONS.value, "25") + self.video_type.setChoices(Options.VIDEO_FORMAT_OPTIONS.value[0]) + import_multiple_btn = Button(import_multiple_videos_frm, text="Import MULTIPLE videos", fg="blue", command=lambda: self.__run_video_import(multiple_videos=True)) + self.multiple_videos_symlink_var = BooleanVar(value=False) + multiple_videos_symlink_cb = Checkbutton(import_multiple_videos_frm, text="Import SYMLINKS", variable=self.multiple_videos_symlink_var) + + import_single_frm = LabelFrame(import_videos_frm, text="IMPORT SINGLE VIDEO", pady=5, padx=5) + self.video_file_select = FileSelect(import_single_frm, "VIDEO PATH: ", title="Select a video file", lblwidth=25, file_types=[("VIDEO FILE", Options.ALL_VIDEO_FORMAT_STR_OPTIONS.value)]) + import_single_btn = Button(import_single_frm, text="Import SINGLE video", fg="blue", command=lambda: self.__run_video_import(multiple_videos=False)) + self.single_video_symlink_var = BooleanVar(value=False) + single_video_symlink_cb = Checkbutton(import_single_frm, text="Import SYMLINK", variable=self.single_video_symlink_var) + + import_videos_frm.grid(row=0, column=0, sticky=NW) + import_multiple_videos_frm.grid(row=0, sticky=W) + self.video_directory_select.grid(row=1, sticky=W) + self.video_type.grid(row=2, sticky=W) + multiple_videos_symlink_cb.grid(row=3, sticky=W) + import_multiple_btn.grid(row=4, sticky=W) + + import_single_frm.grid(row=1, column=0, sticky=NW) + self.video_file_select.grid(row=0, sticky=W) + single_video_symlink_cb.grid(row=1, sticky=W) + import_single_btn.grid(row=2, sticky=W) + import_videos_frm.grid(row=idx_row, column=idx_column, sticky=NW) + + #parent_frm.mainloop() + + def __run_video_import(self, multiple_videos: bool): + if multiple_videos: + check_if_dir_exists(in_dir=self.video_directory_select.folder_path) + copy_multiple_videos_to_project(config_path=self.config_path, + source=self.video_directory_select.folder_path, + symlink=self.multiple_videos_symlink_var.get(), + file_type=self.video_type.getChoices()) + + else: + check_file_exist_and_readable(file_path=self.video_file_select.file_path) + copy_single_video_to_project(simba_ini_path=self.config_path, + symlink=self.single_video_symlink_var.get(), + source_path=self.video_file_select.file_path) \ No newline at end of file diff --git a/simba/ui/pop_ups/interpolate_pop_up.py b/simba/ui/pop_ups/interpolate_pop_up.py new file mode 100644 index 000000000..4993ed081 --- /dev/null +++ b/simba/ui/pop_ups/interpolate_pop_up.py @@ -0,0 +1,80 @@ +__author__ = "Simon Nilsson" + +import os +from tkinter import * +from copy import deepcopy +from typing import Union + +from simba.data_processors.interpolate import Interpolate +from simba.mixins.config_reader import ConfigReader +from simba.mixins.pop_up_mixin import PopUpMixin +from simba.ui.tkinter_functions import DropDownMenu, FolderSelect, FileSelect +from simba.utils.checks import check_file_exist_and_readable, check_if_dir_exists +from simba.utils.enums import Options, Formats +from simba.utils.read_write import str_2_bool + +INTERPOLATOR_METHOD = {'MISSING BODY-PARTS': 'body-parts', 'MISSING ANIMALS': 'animals'} + +class InterpolatePopUp(PopUpMixin, ConfigReader): + """ + :example: + >>> InterpolatePopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini') + + """ + + def __init__(self, config_path: Union[str, os.PathLike]): + PopUpMixin.__init__(self, title="INTERPOLATE POSE") + ConfigReader.__init__(self, config_path=config_path) + self.config_path = config_path + self.settings_frm = LabelFrame(self.main_frm, text="SETTINGS", font=Formats.LABELFRAME_HEADER_FORMAT.value) + self.type_dropdown = DropDownMenu(self.settings_frm, "INTERPOLATION TYPE:", ['MISSING BODY-PARTS', 'MISSING ANIMALS'], "35") + self.method_dropdown = DropDownMenu(self.settings_frm, "INTERPOLATION METHOD:", ['NEAREST', 'LINEAR', 'QUADRATIC'], "35") + self.save_originals_dropdown = DropDownMenu(self.settings_frm, "SAVE ORIGINALS:", Options.BOOL_STR_OPTIONS.value, "35") + self.type_dropdown.setChoices('MISSING BODY-PARTS') + self.method_dropdown.setChoices('NEAREST') + self.save_originals_dropdown.setChoices(Options.BOOL_STR_OPTIONS.value[0]) + + self.settings_frm.grid(row=0, column=0, sticky=NW) + self.type_dropdown.grid(row=0, column=0, sticky=NW) + self.method_dropdown.grid(row=1, column=0, sticky=NW) + self.save_originals_dropdown.grid(row=2, column=0, sticky=NW) + + self.single_file_frm = LabelFrame(self.main_frm, text="INTERPOLATE SINGLE DATA FILE", font=Formats.LABELFRAME_HEADER_FORMAT.value) + self.selected_file = FileSelect(self.single_file_frm, "DATA PATH:", lblwidth=35, file_types=[("VIDEO FILE", ".csv .parquet")], initialdir=self.project_path) + self.run_btn_single = Button(self.single_file_frm, text="RUN SINGLE DATA FILE INTERPOLATION", fg="blue", command=lambda: self.run(multiple=False)) + + self.single_file_frm.grid(row=1, column=0, sticky=NW) + self.selected_file.grid(row=0, column=0, sticky=NW) + self.run_btn_single.grid(row=1, column=0, sticky=NW) + + self.multiple_file_frm = LabelFrame(self.main_frm, text="INTERPOLATE DIRECTORY OF DATA", font=Formats.LABELFRAME_HEADER_FORMAT.value) + self.selected_dir = FolderSelect(self.multiple_file_frm, "SELECT DIRECTORY OF DATA FILES:", lblwidth=35, initialdir=self.project_path) + self.run_btn_multiple = Button(self.multiple_file_frm, text="RUN SINGLE DATA FILE INTERPOLATION", fg="blue", command=lambda: self.run(multiple=True)) + + self.multiple_file_frm.grid(row=2, column=0, sticky=NW) + self.selected_dir.grid(row=0, column=0, sticky=NW) + self.run_btn_multiple.grid(row=1, column=0, sticky=NW) + self.main_frm.mainloop() + + def run(self, multiple): + interpolation_type = INTERPOLATOR_METHOD[self.type_dropdown.getChoices()] + interpolation_method = self.method_dropdown.getChoices().lower() + copy_originals = str_2_bool(self.save_originals_dropdown.getChoices()) + if not multiple: + data_path = self.selected_file.file_path + check_file_exist_and_readable(file_path=data_path) + data_dir = os.path.dirname(data_path) + else: + data_path = self.selected_dir.folder_path + check_if_dir_exists(in_dir=data_path) + data_dir = deepcopy(data_path) + + multi_index_df_headers = False + if data_dir == self.input_csv_dir: multi_index_df_headers = True + interpolator = Interpolate(config_path=self.config_path, + data_path=data_path, + type=interpolation_type, + method=interpolation_method, + copy_originals=copy_originals, + multi_index_df_headers=multi_index_df_headers) + interpolator.run() diff --git a/simba/ui/pop_ups/smoothing_interpolation_pop_up.py b/simba/ui/pop_ups/smoothing_interpolation_pop_up.py deleted file mode 100644 index c777a1632..000000000 --- a/simba/ui/pop_ups/smoothing_interpolation_pop_up.py +++ /dev/null @@ -1,83 +0,0 @@ -__author__ = "Simon Nilsson" - -import os -from tkinter import * - -from simba.data_processors.interpolation_smoothing import Interpolate, Smooth -from simba.mixins.config_reader import ConfigReader -from simba.mixins.pop_up_mixin import PopUpMixin -from simba.ui.tkinter_functions import DropDownMenu, Entry_Box, FolderSelect -from simba.utils.checks import check_int -from simba.utils.enums import Options -from simba.utils.errors import NotDirectoryError - - -class InterpolatePopUp(PopUpMixin, ConfigReader): - def __init__(self, config_path: str): - PopUpMixin.__init__(self, title="INTERPOLATE POSE") - ConfigReader.__init__(self, config_path=config_path) - self.input_dir = FolderSelect(self.main_frm, "DATA DIRECTORY:", lblwidth=25) - self.method_dropdown = DropDownMenu( - self.main_frm, - "INTERPOLATION METHOD:", - Options.INTERPOLATION_OPTIONS.value, - "25", - ) - self.method_dropdown.setChoices(Options.INTERPOLATION_OPTIONS.value[0]) - run_btn = Button( - self.main_frm, - text="RUN INTERPOLATION", - fg="blue", - command=lambda: self.run(), - ) - self.input_dir.grid(row=0, column=0, sticky=NW) - self.method_dropdown.grid(row=1, column=0, sticky=NW) - run_btn.grid(row=2, column=0, sticky=NW) - - def run(self): - if not os.path.isdir(self.input_dir.folder_path): - raise NotDirectoryError( - msg=f"{self.input_dir.folder_path} is not a valid directory.", - source=self.__class__.__name__, - ) - Interpolate( - config_path=self.config_path, - method=self.method_dropdown.getChoices(), - input_path=self.input_dir.folder_path, - ) - - -class SmoothingPopUp(PopUpMixin, ConfigReader): - def __init__(self, config_path: str): - PopUpMixin.__init__(self, title="SMOOTH POSE") - ConfigReader.__init__(self, config_path=config_path) - self.input_dir = FolderSelect(self.main_frm, "DATA DIRECTORY:", lblwidth=20) - self.time_window = Entry_Box( - self.main_frm, "TIME WINDOW (MS):", "20", validation="numeric" - ) - self.method_dropdown = DropDownMenu( - self.main_frm, "SMOOTHING METHOD:", Options.SMOOTHING_OPTIONS.value, "20" - ) - self.method_dropdown.setChoices(Options.SMOOTHING_OPTIONS.value[0]) - run_btn = Button( - self.main_frm, text="RUN SMOOTHING", fg="blue", command=lambda: self.run() - ) - - self.input_dir.grid(row=0, column=0, sticky=NW) - self.method_dropdown.grid(row=1, column=0, sticky=NW) - self.time_window.grid(row=2, column=0, sticky=NW) - run_btn.grid(row=3, column=0, sticky=NW) - - def run(self): - if not os.path.isdir(self.input_dir.folder_path): - raise NotDirectoryError( - msg=f"{self.input_dir.folder_path} is not a valid directory.", - source=self.__class__.__name__, - ) - check_int(name="TIME WINDOW", value=self.time_window.entry_get, min_value=1) - _ = Smooth( - config_path=self.config_path, - input_path=self.input_dir.folder_path, - time_window=self.time_window.entry_get, - smoothing_method=self.method_dropdown.getChoices(), - ) diff --git a/simba/ui/pop_ups/smoothing_popup.py b/simba/ui/pop_ups/smoothing_popup.py new file mode 100644 index 000000000..1217a8907 --- /dev/null +++ b/simba/ui/pop_ups/smoothing_popup.py @@ -0,0 +1,81 @@ +__author__ = "Simon Nilsson" + +import os +from tkinter import * +from copy import deepcopy +from typing import Union + +from simba.data_processors.smoothing import Smoothing +from simba.mixins.config_reader import ConfigReader +from simba.mixins.pop_up_mixin import PopUpMixin +from simba.ui.tkinter_functions import DropDownMenu, FolderSelect, FileSelect, Entry_Box +from simba.utils.checks import check_file_exist_and_readable, check_if_dir_exists, check_int +from simba.utils.enums import Options, Formats +from simba.utils.read_write import str_2_bool + +SMOOTHING_OPTION = {'Savitzky Golay': "savitzky-golay", "Gaussian": "gaussian"} + +class SmoothingPopUp(PopUpMixin, ConfigReader): + + """ + :example: + >>> SmoothingPopUp(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini') + """ + def __init__(self, config_path: Union[str, os.PathLike]): + PopUpMixin.__init__(self, title="SMOOTH POSE-ESTIMATION DATA") + ConfigReader.__init__(self, config_path=config_path) + self.config_path = config_path + + self.settings_frm = LabelFrame(self.main_frm, text="SETTINGS", font=Formats.LABELFRAME_HEADER_FORMAT.value) + self.time_window = Entry_Box(self.settings_frm, "TIME WINDOW (MILLISECONDS):", "35", validation="numeric") + self.method_dropdown = DropDownMenu(self.settings_frm, "SMOOTHING METHOD:", Options.SMOOTHING_OPTIONS.value, "35") + self.save_originals_dropdown = DropDownMenu(self.settings_frm, "SAVE ORIGINALS:", Options.BOOL_STR_OPTIONS.value, "35") + self.save_originals_dropdown.setChoices(Options.BOOL_STR_OPTIONS.value[0]) + + self.method_dropdown.setChoices(Options.SMOOTHING_OPTIONS.value[0]) + self.settings_frm.grid(row=0, column=0, sticky=NW) + self.time_window.grid(row=0, column=0, sticky=NW) + self.method_dropdown.grid(row=1, column=0, sticky=NW) + self.save_originals_dropdown.grid(row=2, column=0, sticky=NW) + + self.single_file_frm = LabelFrame(self.main_frm, text="SMOOTH SINGLE DATA FILE", font=Formats.LABELFRAME_HEADER_FORMAT.value) + self.selected_file = FileSelect(self.single_file_frm, "DATA PATH:", lblwidth=35, file_types=[("VIDEO FILE", ".csv .parquet")], initialdir=self.project_path) + self.run_btn_single = Button(self.single_file_frm, text="RUN SINGLE DATA FILE SMOOTHING", fg="blue", command=lambda: self.run(multiple=False)) + + self.single_file_frm.grid(row=1, column=0, sticky=NW) + self.selected_file.grid(row=0, column=0, sticky=NW) + self.run_btn_single.grid(row=1, column=0, sticky=NW) + + self.multiple_file_frm = LabelFrame(self.main_frm, text="SMOOTH DIRECTORY OF DATA", font=Formats.LABELFRAME_HEADER_FORMAT.value) + self.selected_dir = FolderSelect(self.multiple_file_frm, "SELECT DIRECTORY OF DATA FILES:", lblwidth=35, initialdir=self.project_path) + self.run_btn_multiple = Button(self.multiple_file_frm, text="RUN SINGLE DATA FILE SMOOTHING", fg="blue", command=lambda: self.run(multiple=True)) + + self.multiple_file_frm.grid(row=2, column=0, sticky=NW) + self.selected_dir.grid(row=0, column=0, sticky=NW) + self.run_btn_multiple.grid(row=1, column=0, sticky=NW) + self.main_frm.mainloop() + + def run(self, multiple): + smooth_time = self.time_window.entry_get + smooth_method = SMOOTHING_OPTION[self.method_dropdown.getChoices()] + copy_originals = str_2_bool(self.save_originals_dropdown.getChoices()) + check_int(name='TIME WINDOW (MILLISECONDS)', value=smooth_time, min_value=1) + + if not multiple: + data_path = self.selected_file.file_path + check_file_exist_and_readable(file_path=data_path) + data_dir = os.path.dirname(data_path) + else: + data_path = self.selected_dir.folder_path + check_if_dir_exists(in_dir=data_path) + data_dir = deepcopy(data_path) + multi_index_df_headers = False + if data_dir == self.input_csv_dir: multi_index_df_headers = True + + smoothing = Smoothing(config_path=self.config_path, + data_path=data_path, + time_window=int(smooth_time), + method=smooth_method, + multi_index_df_headers=multi_index_df_headers, + copy_originals=copy_originals) + smoothing.run() \ No newline at end of file diff --git a/simba/ui/pop_ups/video_processing_pop_up.py b/simba/ui/pop_ups/video_processing_pop_up.py index be5c2f98f..d9b6e7f86 100644 --- a/simba/ui/pop_ups/video_processing_pop_up.py +++ b/simba/ui/pop_ups/video_processing_pop_up.py @@ -2361,7 +2361,7 @@ def __init__(self): self.selected_frame_dir.grid(row=0, column=0, sticky="NW") run_btn_dir.grid(row=1, column=0, sticky="NW") - convert_img_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="CONVERT IMAGE DIRECTORY TO WEBP", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.VIDEO_TOOLS.value) + convert_img_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="CONVERT IMAGE FILE TO WEBP", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.VIDEO_TOOLS.value) self.selected_file = FileSelect(convert_img_frm, "IMAGE PATH:", title="Select an image file", lblwidth=25, file_types=[("VIDEO FILE", Options.ALL_IMAGE_FORMAT_OPTIONS.value)]) run_btn_frm = Button(convert_img_frm, text="RUN IMAGE WEBP CONVERSION", command=lambda: self.run_img()) diff --git a/simba/utils/data.py b/simba/utils/data.py index 8e71dfa2a..344db170c 100644 --- a/simba/utils/data.py +++ b/simba/utils/data.py @@ -8,7 +8,7 @@ from copy import deepcopy from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, Any import h5py import numpy as np @@ -33,7 +33,7 @@ check_int, check_str, check_that_column_exist, check_that_hhmmss_start_is_before_end, check_valid_array, check_valid_dataframe) -from simba.utils.enums import ConfigKey, Dtypes, Keys, Options +from simba.utils.enums import ConfigKey, Dtypes, Keys, Options, Formats from simba.utils.errors import (BodypartColumnNotFoundError, CountError, InvalidFileTypeError, InvalidInputError, NoFilesFoundError) @@ -372,6 +372,8 @@ def smooth_data_savitzky_golay( Perform Savitzky-Golay smoothing of pose-estimation data within a file. .. important:: + LEGACY: USE ``savgol_smoother`` instead. + Overwrites the input data with smoothened data. :param configparser.ConfigParser config: Parsed SimBA project_config.ini file. @@ -1032,9 +1034,7 @@ def get_mode(x: np.ndarray) -> Union[float, int]: return counts.argmax() -def run_user_defined_feature_extraction_class( - file_path: Union[str, os.PathLike], config_path: Union[str, os.PathLike] -) -> None: +def run_user_defined_feature_extraction_class(file_path: Union[str, os.PathLike], config_path: Union[str, os.PathLike]) -> None: """ Loads and executes user-defined feature extraction class within .py file. @@ -1121,6 +1121,191 @@ def run_user_defined_feature_extraction_class( user_class(config_path) +def animal_interpolator(df: pd.DataFrame, + animal_bp_dict: Dict[str, Any], + source: Optional[str] = '', + method: Optional[Literal['nearest', 'linear', 'quadratic']] = 'nearest', + verbose: Optional[bool] = True) -> pd.DataFrame: + + """ + Interpolate missing values for frames where entire animals are missing. + + .. note:: + Animals are inferred to be "missing" when all their body-parts have exactly the same value on both the x and y + plane (or None). + + :param pd.DataFrame df: The input DataFrame containing animal body part positions. + :param Dict[str, Any] animal_bp_dict: A dictionary where keys are animal names and values are dictionaries with keys "X_bps" and "Y_bps", which are lists of column names for the x and y coordinates of the animal body parts. + :param Optional[str] source: An optional string indicating the source of the DataFrame, used for logging and informative error messages. + :param Optional[Literal['nearest', 'linear', 'quadratic']] method: The interpolation method to use. Options are 'nearest', 'linear', and 'quadratic'. Defaults to 'nearest'. + :param Optional[bool] verbose: If True, prints the number of missing body parts being interpolated for each animal. + :return pd.DataFrame: The DataFrame with interpolated values for the specified animal body parts. + + :example: + >>> animal_bp_dict = {'Animal_1': {'X_bps': ['Ear_left_1_x', 'Ear_right_1_x', 'Nose_1_x', 'Center_1_x', 'Lat_left_1_x', 'Lat_right_1_x', 'Tail_base_1_x'], 'Y_bps': ['Ear_left_1_y', 'Ear_right_1_y', 'Nose_1_y', 'Center_1_y', 'Lat_left_1_y', 'Lat_right_1_y', 'Tail_base_1_y']}, 'Animal_2': {'X_bps': ['Ear_left_2_x', 'Ear_right_2_x', 'Nose_2_x', 'Center_2_x', 'Lat_left_2_x', 'Lat_right_2_x', 'Tail_base_2_x'], 'Y_bps': ['Ear_left_2_y', 'Ear_right_2_y', 'Nose_2_y', 'Center_2_y', 'Lat_left_2_y', 'Lat_right_2_y', 'Tail_base_2_y']}} + >>> df = pd.read_csv('/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv', index_col=0) + >>> interpolated_df = animal_interpolator(df=df, animal_bp_dict=animal_bp_dict, source='test') + + """ + + check_instance(source=source, instance=df, accepted_types=(pd.DataFrame,)) + check_instance(source=source, instance=animal_bp_dict, accepted_types=(dict,)) + check_valid_dataframe(df=df, source=source, valid_dtypes=Formats.NUMERIC_DTYPES.value) + check_str(name='method', value=method, options=('nearest', 'linear', 'quadratic'), raise_error=True) + + df = df.fillna(0).clip(lower=0).astype(int) + for animal_name, animal_bps in animal_bp_dict.items(): + check_if_keys_exist_in_dict(data=animal_bps, key=["X_bps", "Y_bps"]) + check_that_column_exist(df=df, column_name=animal_bps["X_bps"] + animal_bps["Y_bps"], file_name=source) + animal_df = df[animal_bps["X_bps"] + animal_bps["Y_bps"]] + missing_idx = list(animal_df[animal_df.eq(animal_df.iloc[:, 0], axis=0).all(axis="columns")].index) + if verbose: + print(f"Interpolating {len(missing_idx)} body-parts for animal {animal_name} in {source}...") + animal_df.loc[missing_idx, :] = np.nan + animal_df = animal_df.interpolate(method=method, axis=0).ffill().bfill() + df.update(animal_df) + + return df.clip(lower=0) + + +def body_part_interpolator(df: pd.DataFrame, + animal_bp_dict: Dict[str, Any], + source: Optional[str] = '', + method: Optional[Literal['nearest', 'linear', 'quadratic']] = 'nearest', + verbose: Optional[bool] = True) -> pd.DataFrame: + """ + Interpolate missing body-parts in pose-estimation data. + + .. note:: + Data is inferred to be "missing" when data for the body-part is either "None" on both the x- and y-plane or located at + (0, 0). + + :param pd.DataFrame df: The input DataFrame containing animal body part positions. + :param Dict[str, Any] animal_bp_dict: A dictionary where keys are animal names and values are dictionaries with keys "X_bps" and "Y_bps", which are lists of column names for the x and y coordinates of the animal body parts. + :param Optional[str] source: An optional string indicating the source of the DataFrame, used for logging and informative error messages. + :param Optional[Literal['nearest', 'linear', 'quadratic']] method: The interpolation method to use. Options are 'nearest', 'linear', and 'quadratic'. Defaults to 'nearest'. + :param Optional[bool] verbose: If True, prints the number of missing body parts being interpolated for each animal. + :return pd.DataFrame: The DataFrame with interpolated values for the specified animal body parts. + + :example: + >>> animal_bp_dict = {'Animal_1': {'X_bps': ['Ear_left_1_x', 'Ear_right_1_x', 'Nose_1_x', 'Center_1_x', 'Lat_left_1_x', 'Lat_right_1_x', 'Tail_base_1_x'], 'Y_bps': ['Ear_left_1_y', 'Ear_right_1_y', 'Nose_1_y', 'Center_1_y', 'Lat_left_1_y', 'Lat_right_1_y', 'Tail_base_1_y']}, 'Animal_2': {'X_bps': ['Ear_left_2_x', 'Ear_right_2_x', 'Nose_2_x', 'Center_2_x', 'Lat_left_2_x', 'Lat_right_2_x', 'Tail_base_2_x'], 'Y_bps': ['Ear_left_2_y', 'Ear_right_2_y', 'Nose_2_y', 'Center_2_y', 'Lat_left_2_y', 'Lat_right_2_y', 'Tail_base_2_y']}} + >>> df = pd.read_csv('/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv', index_col=0) + >>> interpolated_df = body_part_interpolator(df=df, animal_bp_dict=animal_bp_dict, source='test') + """ + + check_instance(source=source, instance=df, accepted_types=(pd.DataFrame,)) + check_valid_dataframe(df=df, source=source, valid_dtypes=Formats.NUMERIC_DTYPES.value) + check_str(name='method', value=method, options=('nearest', 'linear', 'quadratic'), raise_error=True) + + df = df.fillna(0).clip(lower=0).astype(int) + for animal in animal_bp_dict: + check_if_keys_exist_in_dict(data=animal_bp_dict[animal], key=["X_bps", "Y_bps"]) + for x_bps_name, y_bps_name in zip(animal_bp_dict[animal]["X_bps"], animal_bp_dict[animal]["Y_bps"]): + check_that_column_exist(df=df, column_name=[x_bps_name, y_bps_name], file_name=source) + bp_df = df[[x_bps_name, y_bps_name]].astype(int) + missing_idx = df.loc[(bp_df[x_bps_name] <= 0.0) & (bp_df[y_bps_name] <= 0.0)].index.tolist() + if verbose: + print(f"Interpolating {len(missing_idx)} {x_bps_name[:-2]} body-parts for animal {animal} in {source}...") + bp_df.loc[missing_idx, [x_bps_name, y_bps_name]] = np.nan + bp_df[x_bps_name] = bp_df[x_bps_name].interpolate(method=method, axis=0).ffill().bfill() + bp_df[y_bps_name] = bp_df[y_bps_name].interpolate(method=method, axis=0).ffill().bfill() + df.update(bp_df) + return df.clip(lower=0) + +def savgol_smoother(data: Union[pd.DataFrame, np.ndarray], + fps: float, + time_window: int, + source: Optional[str] = '', + mode: Optional[Literal['mirror', 'constant', 'nearest', 'wrap', 'interp']] = 'nearest', + polyorder: Optional[int] = 3) -> Union[pd.DataFrame, np.ndarray]: + """ + Apply Savitzky-Golay smoothing to the input data pose-estimation data + + Applies the Savitzky-Golay filter to smooth the data in a DataFrame or a NumPy array. The filter smoothes the data using a polynomial of the specified order and a window size based on the frame rate per second (fps) and the time window. + + :param Union[pd.DataFrame, np.ndarray] data: The input data to be smoothed. Can be a pandas DataFrame or a 2D NumPy array. + :param float fps: The frame rate per second of the data. + :param int time_window: The time window in milliseconds over which to apply the smoothing. + :param Optional[str] source: An optional string indicating the source of the data, used for logging and informative error messages. + :param Optional[Literal['mirror', 'constant', 'nearest', 'wrap', 'interp']] mode: The mode parameter determines the behavior at the edges of the data. Options are:'mirror', 'constant', 'nearest', 'wrap', 'interp'. Default: 'nearest'. + :param Optional[int] polyorder: The order of the polynomial used to fit the samples. + :return Union[pd.DataFrame, np.ndarray]: The smoothed data, returned as a DataFrame if the input was a DataFrame, or a NumPy array if the input was an array. + + :example: + >>> data = pd.read_csv('/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv', index_col=0) + >>> savgol_smoother(data=data.values, fps=15, time_window=1000) + """ + + check_float(name='fps', value=fps, min_value=10e-16, raise_error=True) + check_int(name='time_window', value=time_window, min_value=1, raise_error=True) + check_str(name='mode', value=mode, options=('mirror', 'constant', 'nearest', 'wrap', 'interp'), raise_error=True) + check_int(name='time_window', value=time_window, min_value=1, raise_error=True) + check_int(name='polyorder', value=polyorder, min_value=1, raise_error=True) + check_instance(source=source, instance=data, accepted_types=(pd.DataFrame, np.ndarray,)) + frms_in_smoothing_window = int(time_window / (1000 / fps)) + if frms_in_smoothing_window <= 1: + return data + if (frms_in_smoothing_window % 2) == 0: + frms_in_smoothing_window = frms_in_smoothing_window - 1 + if frms_in_smoothing_window <= polyorder: + if polyorder % 2 == 0: + frms_in_smoothing_window = polyorder + 1 + else: + frms_in_smoothing_window = polyorder + 2 + if isinstance(data, pd.DataFrame): + check_valid_dataframe(df=data, valid_dtypes=Formats.NUMERIC_DTYPES.value, source=source) + data = data.clip(lower=0) + for c in data.columns: + data[c] = savgol_filter(x=data[c].to_numpy(), window_length=frms_in_smoothing_window, polyorder=polyorder, mode=mode) + data = data.clip(lower=0) + elif isinstance(data, np.ndarray): + check_valid_array(data=data, source=source, accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value) + data = np.clip(a=data, a_min=0, a_max=None) + for c in range(data.shape[1]): + data[:, c] = savgol_filter(x=data[:, c], window_length=frms_in_smoothing_window, polyorder=polyorder, mode=mode) + data = np.clip(a=data, a_min=0, a_max=None) + return data + +def df_smoother(data: pd.DataFrame, + fps: float, + time_window: int, + source: Optional[str] = '', + std: Optional[int] = 5, + method: Optional[Literal['bartlett', 'blackman', 'boxcar', 'cosine', 'gaussian', 'hamming', 'exponential']] = 'gaussian') -> pd.DataFrame: + + """ + Smooth the data in a DataFrame using a specified window function. + + This function applies a rolling window smoothing operation to the data in the DataFrame. The type of window function + and the standard deviation for the smoothing can be specified. The window size is determined based on the frame rate + per second (fps) and the time window. + + :param pd.DataFrame data: The input data to be smoothed. + :param float fps: The frame rate per second of the data. + :param int time_window: The time window in milliseconds over which to apply the smoothing. + :param Optional[str] source: An optional string indicating the source of the data, used for logging and informative error messages. + :param Optional[int] std: The standard deviation for the window function, used when the method is 'gaussian'. + :param Optional[Literal['bartlett', 'blackman', 'boxcar', 'cosine', 'gaussian', 'hamming', 'exponential']] method: The type of window function to use for smoothing. Default 'gaussian'. + :return pd.DataFrame: The smoothed DataFrame. + """ + + check_float(name='fps', value=fps, min_value=10e-16, raise_error=True) + check_int(name='time_window', value=time_window, min_value=1, raise_error=True) + check_int(name='std', value=std, min_value=1, raise_error=True) + check_str(name='method', value=method, options=('bartlett', 'blackman', 'boxcar', 'cosine', 'gaussian', 'hamming', 'exponential'), raise_error=True) + check_valid_dataframe(df=data, valid_dtypes=Formats.NUMERIC_DTYPES.value, source=source) + frms_in_smoothing_window = int(time_window / (1000 / fps)) + if frms_in_smoothing_window <= 1: + return data + data = data.clip(lower=0) + for c in data.columns: + data[c] = data[c].rolling(window=int(frms_in_smoothing_window), win_type=method, center=True).mean(std=std).fillna(data[c]).abs() + return data.clip(lower=0) + + + + + # run_user_defined_feature_extraction_class(config_path='/Users/simon/Desktop/envs/troubleshooting/circular_features_zebrafish/project_folder/project_config.ini', file_path='/Users/simon/Desktop/fish_feature_extractor_2023_version_5.py') diff --git a/simba/utils/read_write.py b/simba/utils/read_write.py index 990cd6131..09281bc0c 100644 --- a/simba/utils/read_write.py +++ b/simba/utils/read_write.py @@ -32,7 +32,7 @@ check_if_string_value_is_valid_video_timestamp, check_instance, check_int, check_nvidea_gpu_available, check_valid_lst) -from simba.utils.enums import ConfigKey, Dtypes, Formats, Keys +from simba.utils.enums import ConfigKey, Dtypes, Formats, Keys, Options from simba.utils.errors import (DataHeaderError, DuplicationError, FeatureNumberMismatchError, FFMPEGCodecGPUError, FileExistError, @@ -558,18 +558,16 @@ def get_bp_headers(body_parts_lst: List[str]) -> list: return bp_headers -def read_video_info( - vid_info_df: pd.DataFrame, video_name: str -) -> (pd.DataFrame, float, float): +def read_video_info(vid_info_df: pd.DataFrame, + video_name: str, + raise_error: Optional[bool] = True) -> Tuple[pd.DataFrame, float, float]: """ Helper to read the metadata (pixels per mm, resolution, fps etc) from the video_info.csv for a single input file/video :parameter pd.DataFrame vid_info_df: Parsed ``project_folder/logs/video_info.csv`` file. This file can be parsed by :meth:`simba.utils.read_write.read_video_info_csv`. :parameter str video_name: Name of the video as represented in the ``Video`` column of the ``project_folder/logs/video_info.csv`` file. - :returns pd.DataFrame: One row DataFrame representing the video in the ``project_folder/logs/video_info.csv`` file. - :return float: The frame rate of the video as represented in the ``project_folder/logs/video_info.csv`` file - :return float: The pixels per millimeter of the video as represented in the ``project_folder/logs/video_info.csv`` file - :raise ParametersFileError: The video is not accurately represented in the ``project_folder/logs/video_info.csv`` file. + :parameter Optional[bool] raise_error: If True, raises error if the video cannot be found in the ``vid_info_df`` file. If False, returns None if the video cannot be found. + :returns Tuple[pd.DataFrame, float, float]: One row DataFrame representing the video in the ``project_folder/logs/video_info.csv`` file, the frame rate of the video, and the the pixels per millimeter of the video :example: >>> video_info_df = read_video_info_csv(file_path='project_folder/logs/video_info.csv') @@ -578,25 +576,19 @@ def read_video_info( video_settings = vid_info_df.loc[vid_info_df["Video"] == video_name] if len(video_settings) > 1: - raise DuplicationError( - msg=f"SimBA found multiple rows in the project_folder/logs/video_info.csv named {str(video_name)}. Please make sure that each video name is represented ONCE in the video_info.csv", - source=read_video_info.__name__, - ) + raise DuplicationError(msg=f"SimBA found multiple rows in the project_folder/logs/video_info.csv named {str(video_name)}. Please make sure that each video name is represented ONCE in the video_info.csv", source=read_video_info.__name__) elif len(video_settings) < 1: - raise ParametersFileError( - msg=f" SimBA could not find {str(video_name)} in the video_info.csv file. Make sure all videos analyzed are represented in the project_folder/logs/video_info.csv file.", - source=read_video_info.__name__, - ) + if raise_error: + raise ParametersFileError(msg=f" SimBA could not find {str(video_name)} in the video_info.csv file. Make sure all videos analyzed are represented in the project_folder/logs/video_info.csv file.", source=read_video_info.__name__) + else: + return None else: try: px_per_mm = float(video_settings["pixels/mm"]) fps = float(video_settings["fps"]) return video_settings, px_per_mm, fps except TypeError: - raise ParametersFileError( - msg=f"Make sure the videos that are going to be analyzed are represented with APPROPRIATE VALUES inside the project_folder/logs/video_info.csv file in your SimBA project. Could not interpret the fps, pixels per millimeter and/or fps as numerical values for video {video_name}", - source=read_video_info.__name__, - ) + raise ParametersFileError(msg=f"Make sure the videos that are going to be analyzed are represented with APPROPRIATE VALUES inside the project_folder/logs/video_info.csv file in your SimBA project. Could not interpret the fps, pixels per millimeter and/or fps as numerical values for video {video_name}", source=read_video_info.__name__) def find_all_videos_in_directory( @@ -728,17 +720,18 @@ def read_frm_of_video( return img -def find_video_of_file( - video_dir: Union[str, os.PathLike], filename: str, raise_error: bool = False -) -> Union[str, os.PathLike]: +def find_video_of_file(video_dir: Union[str, os.PathLike], + filename: str, + raise_error: Optional[bool] = False, + warning: Optional[bool] = True) -> Union[str, os.PathLike]: """ Helper to find the video file with the SimBA project that represents a known data file path. :param str video_dir: Directory holding putative video file. :param str filename: Data file name, e.g., ``Video_1``. - :param bool raise_error: If True, raise error if no file can be found. Else, print warning. Default: False + :param Optional[bool] raise_error: If True, raise error if no file can be found. If False, returns None if no file can be found. Default: False + :param Optional[bool] warning: If True, print warning if no file can be found. If False, no warning is printed if file cannot be found. Default: False :return str: Video path. - :raise NoFilesFoundError: No video file representing file found. :examples: >>> find_video_of_file(video_dir='project_folder/videos', filename='Together_1') @@ -746,35 +739,25 @@ def find_video_of_file( """ try: - all_files_in_video_folder = [ - f for f in next(os.walk(video_dir))[2] if not f[0] == "." - ] + all_files_in_video_folder = [f for f in next(os.walk(video_dir))[2] if not f[0] == "."] except StopIteration: - raise NoFilesFoundError( - msg=f"No files found in the {video_dir} directory", - source=find_video_of_file.__name__, - ) - all_files_in_video_folder = [ - os.path.join(video_dir, x) for x in all_files_in_video_folder - ] + if raise_error: + raise NoFilesFoundError(msg=f"No files found in the {video_dir} directory", source=find_video_of_file.__name__) + elif warning: + NoFileFoundWarning(msg=f"SimBA could not find a video file representing {filename} in the project video directory {video_dir}", source=find_video_of_file.__name__) + return None + + all_files_in_video_folder = [os.path.join(video_dir, x) for x in all_files_in_video_folder] return_path = None for file_path in all_files_in_video_folder: _, video_filename, ext = get_fn_ext(file_path) - if (video_filename == filename) and ( - (ext.lower() == ".mp4") or (ext.lower() == ".avi") - ): + if (video_filename == filename) and (ext.lower() in Options.ALL_VIDEO_FORMAT_OPTIONS.value): return_path = file_path if return_path is None and raise_error: - raise NoFilesFoundError( - msg=f"SimBA could not find a video file representing {filename} in the project video directory {video_dir}", - source=find_video_of_file.__name__, - ) - elif return_path is None: - NoFileFoundWarning( - msg=f"SimBA could not find a video file representing {filename} in the project video directory {video_dir}", - source=find_video_of_file.__name__, - ) + raise NoFilesFoundError(msg=f"SimBA could not find a video file representing {filename} in the project video directory {video_dir}", source=find_video_of_file.__name__) + elif return_path is None and warning: + NoFileFoundWarning(msg=f"SimBA could not find a video file representing {filename} in the project video directory {video_dir}", source=find_video_of_file.__name__) return return_path @@ -1896,12 +1879,10 @@ def get_unique_values_in_iterable( return cnt -def copy_files_to_directory( - file_paths: List[Union[str, os.PathLike]], - dir: Union[str, os.PathLike], - verbose: Optional[bool] = True, - integer_save_names: Optional[bool] = False, -) -> List[Union[str, os.PathLike]]: +def copy_files_to_directory(file_paths: List[Union[str, os.PathLike]], + dir: Union[str, os.PathLike], + verbose: Optional[bool] = True, + integer_save_names: Optional[bool] = False) -> List[Union[str, os.PathLike]]: """ Copy a list of files to a specified directory. diff --git a/simba/video_processors/video_processing.py b/simba/video_processors/video_processing.py index 0a76a6c55..355a3669e 100644 --- a/simba/video_processors/video_processing.py +++ b/simba/video_processors/video_processing.py @@ -138,7 +138,7 @@ def convert_to_jpeg(path: Union[str, os.PathLike, List[Union[str, os.PathLike]]] raise InvalidInputError(msg=f'{path} is not a valid file path, directory path, or list of file paths', source=convert_to_jpeg.__name__) directory, _, _ = get_fn_ext(filepath=file_paths[0]) if save_dir is None: - save_dir = os.path.join(directory, f'bmp_{datetime_}') + save_dir = os.path.join(directory, f'jpeg_{datetime_}') os.makedirs(save_dir) else: check_if_dir_exists(in_dir=save_dir, source=f'{convert_to_jpeg.__name__} save_dir', create_if_not_exist=True) @@ -237,7 +237,7 @@ def convert_to_png(path: Union[str, os.PathLike], raise InvalidInputError(msg=f'{path} is not a valid file path or directory path or list of file paths', source=convert_to_png.__name__) directory, _, _ = get_fn_ext(filepath=file_paths[0]) if save_dir is None: - save_dir = os.path.join(directory, f'bmp_{datetime_}') + save_dir = os.path.join(directory, f'png_{datetime_}') os.makedirs(save_dir) else: check_if_dir_exists(in_dir=save_dir, source=f'{convert_to_png.__name__} save_dir', create_if_not_exist=True) @@ -349,7 +349,7 @@ def convert_to_webp(path: Union[str, os.PathLike], raise InvalidInputError(msg=f'{path} is not a valid file path or directory path', source=convert_to_webp.__name__) directory, _, _ = get_fn_ext(filepath=file_paths[0]) if save_dir is None: - save_dir = os.path.join(directory, f'bmp_{datetime_}') + save_dir = os.path.join(directory, f'webp_{datetime_}') os.makedirs(save_dir) else: check_if_dir_exists(in_dir=save_dir, source=f'{convert_to_png.__name__} save_dir', create_if_not_exist=True)