Skip to content

Commit

Permalink
cleaned
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Apr 22, 2024
1 parent 5c53bf6 commit c57d5dc
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 103 deletions.
149 changes: 106 additions & 43 deletions simba/plotting/ez_path_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,42 @@

import os
from copy import deepcopy
from typing import Union, Tuple, Optional
from typing import Optional, Tuple, Union

import cv2
import numpy as np
import pandas as pd

from simba.utils.errors import (DataHeaderError, DuplicationError, InvalidFileTypeError, InvalidInputError)
from simba.utils.checks import (check_file_exist_and_readable,
check_if_dir_exists, check_if_valid_rgb_tuple,
check_int, check_valid_tuple)
from simba.utils.errors import (DataHeaderError, DuplicationError,
InvalidFileTypeError, InvalidInputError)
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import get_fn_ext, get_video_meta_data, read_config_file, read_df, get_number_of_header_columns_in_df
from simba.utils.checks import check_file_exist_and_readable, check_if_valid_rgb_tuple, check_int, check_if_dir_exists, check_valid_tuple
from simba.utils.read_write import (get_fn_ext,
get_number_of_header_columns_in_df,
get_video_meta_data, read_config_file,
read_df)

H5 = '.h5'
CSV = '.csv'
H5 = ".h5"
CSV = ".csv"

class EzPathPlot(object):
def __init__(self,
data_path: Union[str, os.PathLike],
body_part: str,
bg_color: Optional[Tuple[int, int, int]] = (255, 255, 255),
line_color: Optional[Tuple[int, int, int]] = (147, 20, 255),
video_path: Optional[Union[str, os.PathLike]] = None,
size: Optional[Tuple[int, int]] = None,
fps: Optional[int] = None,
line_thickness: Optional[int] = 10,
circle_size: Optional[int] = 5,
last_frm_only: Optional[bool] = False,
save_path: Optional[Union[str, os.PathLike]] = None):

class EzPathPlot(object):
def __init__(
self,
data_path: Union[str, os.PathLike],
body_part: str,
bg_color: Optional[Tuple[int, int, int]] = (255, 255, 255),
line_color: Optional[Tuple[int, int, int]] = (147, 20, 255),
video_path: Optional[Union[str, os.PathLike]] = None,
size: Optional[Tuple[int, int]] = None,
fps: Optional[int] = None,
line_thickness: Optional[int] = 10,
circle_size: Optional[int] = 5,
last_frm_only: Optional[bool] = False,
save_path: Optional[Union[str, os.PathLike]] = None,
):
"""
Create a simpler path plot for a single path in a single video.
Expand Down Expand Up @@ -66,26 +74,48 @@ def __init__(self,
>>> path_plotter.run()
"""


check_file_exist_and_readable(file_path=data_path)
check_if_valid_rgb_tuple(data=bg_color)
check_if_valid_rgb_tuple(data=line_color)
check_int(name=f'{self.__class__.__name__} line_thickness', value=line_thickness, min_value=1)
check_int(name=f'{self.__class__.__name__} circle_size', value=circle_size, min_value=1)
check_int(
name=f"{self.__class__.__name__} line_thickness",
value=line_thickness,
min_value=1,
)
check_int(
name=f"{self.__class__.__name__} circle_size",
value=circle_size,
min_value=1,
)
if save_path is not None:
check_if_dir_exists(in_dir=os.path.dirname(save_path), create_if_not_exist=True)
check_if_dir_exists(
in_dir=os.path.dirname(save_path), create_if_not_exist=True
)
if line_color == bg_color:
raise DuplicationError(msg=f"The line and background cannot be identical - ({line_color})", source=self.__class__.__name__)
raise DuplicationError(
msg=f"The line and background cannot be identical - ({line_color})",
source=self.__class__.__name__,
)
if video_path is not None:
video_meta_data = get_video_meta_data(video_path=video_path)
self.height, self.width = int(video_meta_data['height']), int(video_meta_data['width'])
self.fps = int(video_meta_data['fps'])
self.height, self.width = int(video_meta_data["height"]), int(
video_meta_data["width"]
)
self.fps = int(video_meta_data["fps"])
else:
if (size is None) or (fps is None):
raise InvalidInputError(msg='If video path is None, then pass size and fps', source=self.__class__.__name__)
check_valid_tuple(x=size, source=f'{self.__class__.__name__} size', accepted_lengths=(2,), valid_dtypes=(int,))
raise InvalidInputError(
msg="If video path is None, then pass size and fps",
source=self.__class__.__name__,
)
check_valid_tuple(
x=size,
source=f"{self.__class__.__name__} size",
accepted_lengths=(2,),
valid_dtypes=(int,),
)
self.height, self.width = size[1], size[0]
check_int(name=f'{self.__class__.__name__} fps', value=fps, min_value=1)
check_int(name=f"{self.__class__.__name__} fps", value=fps, min_value=1)
self.fps = int(fps)
dir, file_name, ext = get_fn_ext(filepath=data_path)
if ext.lower() == H5:
Expand All @@ -101,18 +131,32 @@ def __init__(self,
elif ext.lower() == CSV:
self.data = pd.read_csv(data_path)
else:
raise InvalidFileTypeError(msg=f"File type {ext} is not supported (OPTIONS: h5 or csv)")
raise InvalidFileTypeError(
msg=f"File type {ext} is not supported (OPTIONS: h5 or csv)"
)
if len(self.data.columns[0]) == 4:
self.data = self.data.loc[3:]
elif len(self.data.columns[0]) == 3:
self.data = self.data.loc[2:]
body_parts_available = list(set([x[:-2] for x in self.data.columns]))
if body_part not in body_parts_available:
raise DataHeaderError(msg=f"Body-part {body_part} is not present in the data file. The body-parts available are: {body_parts_available}", source=self.__class__.__name__)
bps = [f'{body_part}_x', f'{body_part}_y']
raise DataHeaderError(
msg=f"Body-part {body_part} is not present in the data file. The body-parts available are: {body_parts_available}",
source=self.__class__.__name__,
)
bps = [f"{body_part}_x", f"{body_part}_y"]
if (bps[0] not in self.data.columns) or (bps[1] not in self.data.columns):
raise DataHeaderError(msg=f"Could not finc column {bps[0]} and/or column {bps[1]} in the data file {data_path}", source=self.__class__.__name__)
self.data = self.data[bps].fillna(method="ffill").astype(int).reset_index(drop=True).values
raise DataHeaderError(
msg=f"Could not finc column {bps[0]} and/or column {bps[1]} in the data file {data_path}",
source=self.__class__.__name__,
)
self.data = (
self.data[bps]
.fillna(method="ffill")
.astype(int)
.reset_index(drop=True)
.values
)
if (save_path is None) and (not last_frm_only):
self.save_name = os.path.join(dir, f"{file_name}_line_plot.mp4")
elif (save_path is None) and (last_frm_only):
Expand All @@ -121,20 +165,35 @@ def __init__(self,
self.save_name = save_path
self.bg_img = np.zeros([self.height, self.width, 3])
self.bg_img[:] = [bg_color]
self.line_color, self.line_thickness, self.circle_size, self.last_frm = line_color, line_thickness, circle_size, last_frm_only
self.line_color, self.line_thickness, self.circle_size, self.last_frm = (
line_color,
line_thickness,
circle_size,
last_frm_only,
)
self.timer = SimbaTimer(start=True)

def run(self):
if not self.last_frm:
self.writer = cv2.VideoWriter(self.save_name, 0x7634706D, self.fps, (self.width, self.height))
for i in range(1, self.data.shape[0]+1):
line_data = self.data[:i+1]
self.writer = cv2.VideoWriter(
self.save_name, 0x7634706D, self.fps, (self.width, self.height)
)
for i in range(1, self.data.shape[0] + 1):
line_data = self.data[: i + 1]
img = deepcopy(self.bg_img)
for j in range(1, line_data.shape[0]):
x1, y1 = line_data[j-1][0], line_data[j-1][1]
x1, y1 = line_data[j - 1][0], line_data[j - 1][1]
x2, y2 = line_data[j][0], line_data[j][1]
cv2.line(img, (x1, y1), (x2, y2), self.line_color, self.line_thickness)
cv2.circle(img, (line_data[-1][0], line_data[-1][1]), self.circle_size, self.line_color, -1)
cv2.line(
img, (x1, y1), (x2, y2), self.line_color, self.line_thickness
)
cv2.circle(
img,
(line_data[-1][0], line_data[-1][1]),
self.circle_size,
self.line_color,
-1,
)
self.writer.write(img.astype(np.uint8))
print(f"Frame {i}/{len(self.data)} complete...")
self.writer.release()
Expand All @@ -146,7 +205,11 @@ def run(self):
cv2.line(img, (x1, y1), (x2, y2), self.line_color, self.line_thickness)
cv2.imwrite(filename=self.save_name, img=img)
self.timer.stop_timer()
stdout_success(msg=f"Path plot saved at {self.save_name}", elapsed_time=self.timer.elapsed_time_str)
stdout_success(
msg=f"Path plot saved at {self.save_name}",
elapsed_time=self.timer.elapsed_time_str,
)


# path_plotter = EzPathPlot(data_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/h5/Together_1DLC_resnet50_two_black_mice_DLC_052820May27shuffle1_150000_el.h5',
# size=(2056, 1549),
Expand All @@ -155,4 +218,4 @@ def run(self):
# bg_color=(255, 255, 255),
# last_frm_only=False,
# line_color=(147,20,255))
# path_plotter.run()
# path_plotter.run()
77 changes: 54 additions & 23 deletions simba/ui/pop_ups/make_path_plot_pop_up.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,38 @@
from tkinter import *
import threading
from tkinter import *

from simba.mixins.pop_up_mixin import PopUpMixin
from simba.plotting.ez_path_plot import EzPathPlot
from simba.ui.tkinter_functions import (CreateLabelFrameWithIcon, DropDownMenu, Entry_Box, FileSelect)
from simba.ui.tkinter_functions import (CreateLabelFrameWithIcon, DropDownMenu,
Entry_Box, FileSelect)
from simba.utils.checks import (check_file_exist_and_readable,
check_if_valid_rgb_tuple, check_int, check_str)
from simba.utils.enums import Keys, Links, Options
from simba.utils.lookups import get_color_dict
from simba.utils.checks import check_file_exist_and_readable, check_int, check_if_valid_rgb_tuple, check_str


class MakePathPlotPopUp(PopUpMixin):
def __init__(self):
PopUpMixin.__init__(self, title="CREATE SIMPLE PATH PLOT", size=(500, 300))
settings_frm = CreateLabelFrameWithIcon(parent=self.main_frm, header="SETTINGS", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.VIDEO_TOOLS.value,
settings_frm = CreateLabelFrameWithIcon(
parent=self.main_frm,
header="SETTINGS",
icon_name=Keys.DOCUMENTATION.value,
icon_link=Links.VIDEO_TOOLS.value,
)
self.video_path = FileSelect(
settings_frm,
"VIDEO PATH: ",
lblwidth="30",
file_types=[("VIDEO FILE", Options.ALL_VIDEO_FORMAT_STR_OPTIONS.value)],
)
self.video_path = FileSelect(settings_frm, "VIDEO PATH: ", lblwidth="30", file_types=[("VIDEO FILE", Options.ALL_VIDEO_FORMAT_STR_OPTIONS.value)])
self.body_part = Entry_Box(settings_frm, "BODY PART: ", "30")
self.data_path = FileSelect(
settings_frm, "DATA PATH (e.g., H5 or CSV file): ", lblwidth="30"
)
color_lst = list(get_color_dict().keys())
self.background_color = DropDownMenu(settings_frm, "BACKGROUND COLOR: ", color_lst, "30"
self.background_color = DropDownMenu(
settings_frm, "BACKGROUND COLOR: ", color_lst, "30"
)
self.background_color.setChoices(choice="White")
self.line_color = DropDownMenu(settings_frm, "LINE COLOR: ", color_lst, "30")
Expand All @@ -29,9 +41,13 @@ def __init__(self):
settings_frm, "LINE THICKNESS: ", list(range(1, 11)), "30"
)
self.line_thickness.setChoices(choice=1)
self.circle_size = DropDownMenu(settings_frm, "CIRCLE SIZE: ", list(range(1, 11)), "30")
self.last_frm_only_dropdown = DropDownMenu(settings_frm, "LAST FRAME ONLY: ", ['TRUE', 'FALSE'], "30")
self.last_frm_only_dropdown.setChoices('FALSE')
self.circle_size = DropDownMenu(
settings_frm, "CIRCLE SIZE: ", list(range(1, 11)), "30"
)
self.last_frm_only_dropdown = DropDownMenu(
settings_frm, "LAST FRAME ONLY: ", ["TRUE", "FALSE"], "30"
)
self.last_frm_only_dropdown.setChoices("FALSE")
self.circle_size.setChoices(choice=5)
settings_frm.grid(row=0, sticky=W)
self.video_path.grid(row=0, sticky=W)
Expand All @@ -42,7 +58,11 @@ def __init__(self):
self.line_thickness.grid(row=5, sticky=W)
self.circle_size.grid(row=6, sticky=W)
self.last_frm_only_dropdown.grid(row=7, sticky=W)
Label(settings_frm, fg='green', text=" NOTE: For more complex path plots, faster, \n see 'CREATE PATH PLOTS' under the [VISUALIZATIONS] tab after loading your SimBA project").grid(row=8, sticky=W)
Label(
settings_frm,
fg="green",
text=" NOTE: For more complex path plots, faster, \n see 'CREATE PATH PLOTS' under the [VISUALIZATIONS] tab after loading your SimBA project",
).grid(row=8, sticky=W)
self.create_run_frm(run_function=self.run)
self.main_frm.mainloop()

Expand All @@ -55,24 +75,35 @@ def run(self):
circle_size = self.circle_size.getChoices()
bp = self.body_part.entry_get
check_file_exist_and_readable(file_path=data_path)
check_int(name=f'{self.__class__.__name__} line_thickness', value=line_thickness, min_value=1)
check_int(name=f'{self.__class__.__name__} circle_size', value=circle_size, min_value=1)
check_int(
name=f"{self.__class__.__name__} line_thickness",
value=line_thickness,
min_value=1,
)
check_int(
name=f"{self.__class__.__name__} circle_size",
value=circle_size,
min_value=1,
)
check_if_valid_rgb_tuple(data=background_color)
check_if_valid_rgb_tuple(data=line_color)
check_str(name=f'{self.__class__.__name__} body-part', value=bp)
check_str(name=f"{self.__class__.__name__} body-part", value=bp)
last_frm = self.last_frm_only_dropdown.getChoices()
if last_frm == 'TRUE':
if last_frm == "TRUE":
last_frm = True
else:
last_frm = False
plotter = EzPathPlot(data_path=data_path,
video_path=video_path,
body_part=bp,
bg_color=background_color,
line_color=line_color,
line_thickness=int(line_thickness),
circle_size=int(circle_size),
last_frm_only=last_frm)
plotter = EzPathPlot(
data_path=data_path,
video_path=video_path,
body_part=bp,
bg_color=background_color,
line_color=line_color,
line_thickness=int(line_thickness),
circle_size=int(circle_size),
last_frm_only=last_frm,
)
threading.Thread(target=plotter.run).start()

#MakePathPlotPopUp()

# MakePathPlotPopUp()
Loading

0 comments on commit c57d5dc

Please sign in to comment.