From 71db786b420d5e304d943d0dc0680223d3fd0981 Mon Sep 17 00:00:00 2001 From: roomrys Date: Mon, 30 Oct 2023 22:07:42 -0700 Subject: [PATCH 01/22] Update methods to allow triangulating multiple instances at once --- sleap/gui/commands.py | 63 ++++++++++++++++++++++---------------- tests/gui/test_commands.py | 54 +++++++++++++++++++------------- 2 files changed, 69 insertions(+), 48 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index b3c83650d..ca66c3474 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -1947,7 +1947,6 @@ class AddSession(EditCommand): @staticmethod def do_action(context: CommandContext, params: dict): - camera_calibration = params["camera_calibration"] session = RecordingSession.load(filename=camera_calibration) @@ -3479,7 +3478,7 @@ def verify_views_and_instances(cls, context: CommandContext, params: dict): if session is None or instance is None: return - track = instance.track + track = instance.track # TODO(LM): Replace with InstanceGroup cams_to_include = params.get("cams_to_include", None) or session.linked_cameras # If not enough `Camcorder`s available/specified, then return @@ -3516,7 +3515,7 @@ def get_and_verify_enough_instances( frame_idx: int, context: Optional[CommandContext] = None, cams_to_include: Optional[List[Camcorder]] = None, - track: Optional[Track] = None, + track: Union[Track, int] = -1, show_dialog: bool = True, ) -> Union[Dict[Camcorder, Instance], bool]: """Get all instances accross views at this frame index. @@ -3528,7 +3527,8 @@ def get_and_verify_enough_instances( frame_idx: Frame index to get instances from (0-indexed). context: The optional command context used to display a dialog. cams_to_include: List of `Camcorder`s to include. Default is all. - track: `Track` object used to find instances accross views. Default is None. + track: `Track` object used to find instances accross views. Default is -1 + which finds all instances regardless of track. show_dialog: If True, then show a warning dialog. Default is True. Returns: @@ -3601,7 +3601,7 @@ def get_instances_across_views( session: RecordingSession, frame_idx: int, cams_to_include: Optional[List[Camcorder]] = None, - track: Optional["Track"] = None, + track: Union["Track", int] = -1, require_multiple_views: bool = False, ) -> Dict[Camcorder, "Instance"]: """Get all `Instances` accross all views at a given frame index. @@ -3610,7 +3610,8 @@ def get_instances_across_views( session: The `RecordingSession` containing the `Camcorder`s. frame_idx: Frame index to get instances from (0-indexed). cams_to_include: List of `Camcorder`s to include. Default is all. - track: `Track` object used to find instances accross views. Default is None. + track: `Track` object used to find instances accross views. Default is -1 + which find all instances regardless of track. require_multiple_views: If True, then raise and error if one or less views or instances are found. @@ -3643,7 +3644,7 @@ def get_instances_across_views( for cam, lf in views.items(): insts = lf.find(track=track) if len(insts) > 0: - instances[cam] = insts[0] + instances[cam] = insts # If not enough instances for multiple views, then raise error if len(instances) <= 1 and require_multiple_views: @@ -3699,7 +3700,7 @@ def get_all_views_at_frame( def get_instances_matrices(instances_ordered: List[Instance]) -> np.ndarray: """Gather instances from views into M x F x T x N x 2 an array. - M: # views, F: # frames = 1, T: # tracks = 1, N: # nodes, 2: x, y + M: # views, F: # frames = 1, T: # tracks, N: # nodes, 2: x, y Args: instances_ordered: List of instances from view (following the order of the @@ -3710,12 +3711,19 @@ def get_instances_matrices(instances_ordered: List[Instance]) -> np.ndarray: """ # Gather instances into M x F x T x N x 2 arrays (require specific order) - # (M = # views, F = # frames = 1, T = # tracks = 1, N = # nodes, 2 = x, y) - inst_coords = np.stack( - [inst.numpy() for inst in instances_ordered], axis=0 - ) # M x N x 2 - inst_coords = np.expand_dims(inst_coords, axis=1) # M x T=1 x N x 2 - inst_coords = np.expand_dims(inst_coords, axis=1) # M x F=1 x T=1 x N x 2 + # (M = # views, F = # frames = 1, T = # tracks, N = # nodes, 2 = x, y) + + # Get list of instances matrices from each view + + inst_coords = [ + np.stack( + [inst.numpy() for inst in instances_in_view], + axis=0, + ) + for instances_in_view in instances_ordered + ] # List[T x N x 2] + inst_coords = np.stack(inst_coords, axis=0) # M x T x N x 2 + inst_coords = np.expand_dims(inst_coords, axis=1) # M x F=1 x T x N x 2 return inst_coords @@ -3742,7 +3750,7 @@ def calculate_excluded_views( @staticmethod def calculate_reprojected_points( - session: RecordingSession, instances: Dict[Camcorder, "Instance"] + session: RecordingSession, instances: Dict[Camcorder, List[Instance]] ) -> Iterator[Tuple["Instance", np.ndarray]]: """Triangulate and reproject instance coordinates. @@ -3752,12 +3760,12 @@ def calculate_reprojected_points( https://github.com/lambdaloop/aniposelib/blob/d03b485c4e178d7cff076e9fe1ac36837db49158/aniposelib/cameras.py#L491 Args: - instances: Dict with `Camcorder` keys and `Instance` values. + instances: Dict with `Camcorder` keys and `List[Instance]` values. Returns: A zip of the ordered instances and the related reprojected coordinates. Each - element in the coordinates is a numpy array of shape (1, N, 2) where N is - the number of nodes. + element in the coordinates is a numpy array of shape (T, N, 2) where N is + the number of nodes and T is the number of reprojected instances. """ # TODO (LM): Support multiple tracks and optimize @@ -3770,23 +3778,23 @@ def calculate_reprojected_points( ] # Gather instances into M x F x T x N x 2 arrays (require specific order) - # (M = # views, F = # frames = 1, T = # tracks = 1, N = # nodes, 2 = x, y) + # (M = # views, F = # frames = 1, T = # tracks, N = # nodes, 2 = x, y) inst_coords = TriangulateSession.get_instances_matrices( instances_ordered=instances_ordered - ) # M x F=1 x T=1 x N x 2 + ) # M x F=1 x T x N x 2 points_3d = triangulate( p2d=inst_coords, calib=session.camera_cluster, excluded_views=excluded_views, - ) # F=1, T=1, N, 3 + ) # F=1, T, N, 3 # Update the views with the new 3D points inst_coords_reprojected = reproject( points_3d, calib=session.camera_cluster, excluded_views=excluded_views - ) # M x F=1 x T=1 x N x 2 + ) # M x F=1 x T x N x 2 insts_coords_list: List[np.ndarray] = np.split( inst_coords_reprojected.squeeze(), inst_coords_reprojected.shape[0], axis=0 - ) # len(M) of T=1 x N x 2 + ) # len(M) of T x N x 2 return zip(instances_ordered, insts_coords_list) @@ -3810,10 +3818,11 @@ def update_instances(session, instances: Dict[Camcorder, Instance]): ) # Update the instance coordinates. - for inst, inst_coord in instances_and_coords: - inst.update_points( - points=inst_coord[0], exclude_complete=True - ) # inst_coord is (1, N, 2) + for instances_in_view, inst_coord in instances_and_coords: + for inst_idx, inst in enumerate(instances_in_view): + inst.update_points( + points=inst_coord[inst_idx], exclude_complete=True + ) # inst_coord is (T, N, 2) def open_website(url: str): diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 35bef05d1..93d4008b4 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -220,7 +220,6 @@ def assert_videos_written(num_videos: int, labels_path: str = None): context.state["filename"] = None if csv: - context.state["filename"] = centered_pair_predictions_hdf5_path params = {"all_videos": True, "csv": csv} @@ -986,7 +985,6 @@ def test_triangulate_session_get_all_views_at_frame( def test_triangulate_session_get_instances_across_views( multiview_min_session_labels: Labels, ): - labels = multiview_min_session_labels session = labels.sessions[0] @@ -1001,10 +999,11 @@ def test_triangulate_session_get_instances_across_views( assert len(instances) == len(session.videos) for vid in session.videos: cam = session[vid] - inst = instances[cam] - assert inst.frame_idx == lf.frame_idx - assert inst.track == track - assert inst.video == vid + instances_in_view = instances[cam] + for inst in instances_in_view: + assert inst.frame_idx == lf.frame_idx + assert inst.track == track + assert inst.video == vid # Try with excluding cam views lf: LabeledFrame = labels[2] @@ -1027,10 +1026,11 @@ def test_triangulate_session_get_instances_across_views( videos_to_include ) # May not be true if no instances at that frame for cam, vid in videos_to_include.items(): - inst = instances[cam] - assert inst.frame_idx == lf.frame_idx - assert inst.track == track - assert inst.video == vid + instances_in_view = instances[cam] + for inst in instances_in_view: + assert inst.frame_idx == lf.frame_idx + assert inst.track == track + assert inst.video == vid # Try with only a single view cams_to_include = [session.linked_cameras[0]] @@ -1073,9 +1073,11 @@ def test_triangulate_session_get_and_verify_enough_instances( for cam in session.linked_cameras: if cam.name in ["side", "sideL"]: # The views that don't have an instance continue - assert instances[cam].frame_idx == lf.frame_idx - assert instances[cam].track == track - assert instances[cam].video == session[cam] + instances_in_view = instances[cam] + for inst in instances_in_view: + assert inst.frame_idx == lf.frame_idx + assert inst.track == track + assert inst.video == session[cam] # Test with cams_to_include, expect views from only those cameras cams_to_include = session.linked_cameras[-2:] @@ -1087,14 +1089,19 @@ def test_triangulate_session_get_and_verify_enough_instances( ) assert len(instances) == len(cams_to_include) for cam in cams_to_include: - assert instances[cam].frame_idx == lf.frame_idx - assert instances[cam].track == track - assert instances[cam].video == session[cam] + instances_in_view = instances[cam] + for inst in instances_in_view: + assert inst.frame_idx == lf.frame_idx + assert inst.track == track + assert inst.video == session[cam] # Test with not enough instances, expect views from only those cameras cams_to_include = session.linked_cameras[0:2] instances = TriangulateSession.get_and_verify_enough_instances( - session=session, frame_idx=lf.frame_idx, cams_to_include=cams_to_include + session=session, + frame_idx=lf.frame_idx, + cams_to_include=cams_to_include, + track=None, ) assert isinstance(instances, bool) assert not instances @@ -1229,10 +1236,15 @@ def test_triangulate_session_update_instances(multiview_min_session_labels: Labe instances_and_coordinates = TriangulateSession.calculate_reprojected_points( session=session, instances=instances ) - for inst, inst_coords in instances_and_coordinates: - assert inst_coords.shape == (1, len(inst.skeleton), 2) # Tracks, Nodes, 2 - # Assert coord are different from original - assert not np.array_equal(inst_coords, inst.points_array) + for instances_in_view, inst_coords in instances_and_coordinates: + for inst in instances_in_view: + assert inst_coords.shape == ( + len(instances_in_view), + len(inst.skeleton), + 2, + ) # Tracks, Nodes, 2 + # Assert coord are different from original + assert not np.array_equal(inst_coords, inst.points_array) # Just run for code coverage testing, do not test output here (race condition) # (see "functional core, imperative shell" pattern) From f6be8c4d6bdc56721ad3790cfe68e5647eac0362 Mon Sep 17 00:00:00 2001 From: roomrys Date: Tue, 31 Oct 2023 07:00:18 -0700 Subject: [PATCH 02/22] Return instances and coords as a dictionary with cams --- sleap/gui/commands.py | 37 ++++++++++++++++++++----------------- tests/gui/test_commands.py | 10 ++++++---- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index ca66c3474..8779967d9 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -35,6 +35,7 @@ class which inherits from `AppCommand` (or a more specialized class such as import traceback from enum import Enum from glob import glob +from itertools import product from pathlib import Path, PurePath from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union @@ -3601,9 +3602,9 @@ def get_instances_across_views( session: RecordingSession, frame_idx: int, cams_to_include: Optional[List[Camcorder]] = None, - track: Union["Track", int] = -1, + track: Union[Track, int] = -1, require_multiple_views: bool = False, - ) -> Dict[Camcorder, "Instance"]: + ) -> Dict[Camcorder, List[Instance]]: """Get all `Instances` accross all views at a given frame index. Args: @@ -3616,7 +3617,7 @@ def get_instances_across_views( or instances are found. Returns: - Dict with `Camcorder` keys and `Instances` values. + Dict with `Camcorder` keys and `List[Instance]` values. Raises: ValueError if require_multiple_view is true and one or less views or @@ -3710,18 +3711,14 @@ def get_instances_matrices(instances_ordered: List[Instance]) -> np.ndarray: M x F x T x N x 2 array of instances coordinates. """ - # Gather instances into M x F x T x N x 2 arrays (require specific order) - # (M = # views, F = # frames = 1, T = # tracks, N = # nodes, 2 = x, y) - # Get list of instances matrices from each view - inst_coords = [ np.stack( [inst.numpy() for inst in instances_in_view], axis=0, ) for instances_in_view in instances_ordered - ] # List[T x N x 2] + ] # len(M), List[T x N x 2] inst_coords = np.stack(inst_coords, axis=0) # M x T x N x 2 inst_coords = np.expand_dims(inst_coords, axis=1) # M x F=1 x T x N x 2 @@ -3751,7 +3748,7 @@ def calculate_excluded_views( @staticmethod def calculate_reprojected_points( session: RecordingSession, instances: Dict[Camcorder, List[Instance]] - ) -> Iterator[Tuple["Instance", np.ndarray]]: + ) -> Dict[Camcorder, Tuple["Instance", np.ndarray]]: """Triangulate and reproject instance coordinates. Note that the order of the instances in the list must match the order of the @@ -3768,14 +3765,13 @@ def calculate_reprojected_points( the number of nodes and T is the number of reprojected instances. """ - # TODO (LM): Support multiple tracks and optimize + # TODO (LM): Optimize excluded_views = TriangulateSession.calculate_excluded_views( session=session, instances=instances ) - instances_ordered = [ - instances[cam] for cam in session.cameras if cam in instances - ] + cams_ordered = [cam for cam in session.cameras if cam in instances] + instances_ordered = [instances[cam] for cam in cams_ordered] # Gather instances into M x F x T x N x 2 arrays (require specific order) # (M = # views, F = # frames = 1, T = # tracks, N = # nodes, 2 = x, y) @@ -3796,7 +3792,14 @@ def calculate_reprojected_points( inst_coords_reprojected.squeeze(), inst_coords_reprojected.shape[0], axis=0 ) # len(M) of T x N x 2 - return zip(instances_ordered, insts_coords_list) + insts_and_coords: Dict[Camcorder, Tuple[Instance, np.ndarray]] = { + cam: (instances_in_view, inst_coords) + for cam, instances_in_view, inst_coords in zip( + cams_ordered, instances_ordered, insts_coords_list + ) + } + + return insts_and_coords @staticmethod def update_instances(session, instances: Dict[Camcorder, Instance]): @@ -3811,14 +3814,14 @@ def update_instances(session, instances: Dict[Camcorder, Instance]): """ # Triangulate and reproject instance coordinates. - instances_and_coords: Iterator[ - Tuple["Instance", np.ndarray] + instances_and_coords: Dict[ + Camcorder, Tuple[Instance, np.ndarray] ] = TriangulateSession.calculate_reprojected_points( session=session, instances=instances ) # Update the instance coordinates. - for instances_in_view, inst_coord in instances_and_coords: + for cam, (instances_in_view, inst_coord) in instances_and_coords.items(): for inst_idx, inst in enumerate(instances_in_view): inst.update_points( points=inst_coord[inst_idx], exclude_complete=True diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 93d4008b4..4bae7fbc6 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -1189,11 +1189,13 @@ def test_triangulate_session_calculate_reprojected_points( ) # Check that we get the same number of instances as input - assert len(instances) == len(list(instances_and_coords)) + assert len(instances) == len(instances_and_coords) # Check that each instance has the same number of points - for inst, inst_coords in instances_and_coords: - assert inst_coords.shape[1] == len(inst.skeleton) # (1, 15, 2) + for instances_in_view, inst_coords in instances_and_coords.values(): + assert inst_coords.shape[0] == len(instances_in_view) + for inst in instances_in_view: + assert inst_coords.shape[1] == len(inst.skeleton) # (1, 15, 2) def test_triangulate_session_get_instances_matrices( @@ -1236,7 +1238,7 @@ def test_triangulate_session_update_instances(multiview_min_session_labels: Labe instances_and_coordinates = TriangulateSession.calculate_reprojected_points( session=session, instances=instances ) - for instances_in_view, inst_coords in instances_and_coordinates: + for instances_in_view, inst_coords in instances_and_coordinates.values(): for inst in instances_in_view: assert inst_coords.shape == ( len(instances_in_view), From c4690ecde92fb49e0c2fc8689c214893bc032f2e Mon Sep 17 00:00:00 2001 From: roomrys Date: Tue, 31 Oct 2023 09:45:00 -0700 Subject: [PATCH 03/22] Update get_instance_across_views to handle multiple frames --- sleap/gui/commands.py | 196 ++++++++++++++++++++++++++++++++---------- 1 file changed, 152 insertions(+), 44 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 8779967d9..e3e3ed9f5 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -3513,7 +3513,7 @@ def verify_views_and_instances(cls, context: CommandContext, params: dict): @staticmethod def get_and_verify_enough_instances( session: RecordingSession, - frame_idx: int, + frame_idcs: List[int], context: Optional[CommandContext] = None, cams_to_include: Optional[List[Camcorder]] = None, track: Union[Track, int] = -1, @@ -3525,7 +3525,7 @@ def get_and_verify_enough_instances( Args: session: The `RecordingSession` containing the `Camcorder`s. - frame_idx: Frame index to get instances from (0-indexed). + frame_idcs: List of frame indices to get instances from (0-indexed). context: The optional command context used to display a dialog. cams_to_include: List of `Camcorder`s to include. Default is all. track: `Track` object used to find instances accross views. Default is -1 @@ -3538,10 +3538,10 @@ def get_and_verify_enough_instances( """ try: instances: Dict[ - Camcorder, Instance + int, Dict[Camcorder, Instance] ] = TriangulateSession.get_instances_across_views( session=session, - frame_idx=frame_idx, + frame_idcs=frame_idcs, cams_to_include=cams_to_include, track=track, require_multiple_views=True, @@ -3598,10 +3598,102 @@ def verify_enough_views( return True @staticmethod - def get_instances_across_views( + def get_groups_of_instances( session: RecordingSession, frame_idx: int, cams_to_include: Optional[List[Camcorder]] = None, + ): + """Get instances grouped by `InstanceGroup` or group instances across views. + + If there are not instances in an `InstanceGroup` for all views, then try + regrouping using leftover instances. Do not add to an `InstanceGroup` if the + error is above a set threshold (i.e. there may not be the same instance labeled + across views). + + """ + + permutated_instances: Dict[ + Camcorder, List[Instance] + ] = TriangulateSession.get_permutations_of_instances( + session=session, + frame_idx=frame_idx, + cams_to_include=cams_to_include, + ) + + # Triangulate and reproject instance coordinates. + instances_and_coords: Dict[ + Camcorder, Tuple[Instance, np.ndarray] + ] = TriangulateSession.calculate_reprojected_points( + session=session, instances=permutated_instances + ) + + # Compare the instance coordinates. + reprojection_error = { + cam: np.inf * np.ones() for cam in permutated_instances.keys() + } + grouped_instances = {cam: [] for cam in permutated_instances.keys()} + for cam, (instances_in_view, inst_coord) in instances_and_coords.keys(): + for inst_idx, inst in enumerate(instances_in_view): + instance_error = np.linalg.norm( + np.nan_to_num(inst.points_array - inst_coord[inst_idx]) + ) + + return grouped_instances + + @staticmethod + def get_permutations_of_instances( + session: RecordingSession, + frame_idx: int, + cams_to_include: Optional[List[Camcorder]] = None, + ) -> Dict[Camcorder, List[Instance]]: + """Get all possible combinations of instances across views. + + Args: + session: The `RecordingSession` containing the `Camcorder`s. + frame_idx: Frame index to get instances from (0-indexed). + cams_to_include: List of `Camcorder`s to include. Default is all. + require_multiple_views: If True, then raise and error if one or less views + or instances are found. + + Raises: + ValueError if one or less views or instances are found. + + Returns: + Dict with `Camcorder` keys and `List[Instance]` values. + """ + + instances: Dict[ + Camcorder, List[Instance] + ] = TriangulateSession.get_instances_across_views( + session=session, + frame_idx=frame_idx, + cams_to_include=cams_to_include, + track=-1, # Get all instances regardless of track. + require_multiple_views=True, + ) + + # TODO(LM): Should we only do this for the selected instance? + + # Permutate instances into psuedo groups where each element is a tuple + # grouping elements from different views. + combinations: List[Tuple[Instance]] = list( + product(*instances.values()) + ) # len(prod(instances.values())) with each element of len(instances.keys()) + + # Regroup combos s.t. instances from a single view are in the same list. + cams = list(instances.keys()) + grouped_instances = {cam: [] for cam in cams} + for combination in combinations: + for cam, inst in zip(cams, combination): + grouped_instances[cam].append(inst) + + return grouped_instances + + @staticmethod + def get_instances_across_views( + session: RecordingSession, + frame_idcs: List[int], + cams_to_include: Optional[List[Camcorder]] = None, track: Union[Track, int] = -1, require_multiple_views: bool = False, ) -> Dict[Camcorder, List[Instance]]: @@ -3609,7 +3701,7 @@ def get_instances_across_views( Args: session: The `RecordingSession` containing the `Camcorder`s. - frame_idx: Frame index to get instances from (0-indexed). + frame_idcs: List of frame indices to get instances from (0-indexed). cams_to_include: List of `Camcorder`s to include. Default is all. track: `Track` object used to find instances accross views. Default is -1 which find all instances regardless of track. @@ -3625,34 +3717,39 @@ def get_instances_across_views( """ # Get all views at this frame index - views: Dict[ - Camcorder, "LabeledFrame" - ] = TriangulateSession.get_all_views_at_frame( - session=session, - frame_idx=frame_idx, - cams_to_include=cams_to_include, - ) - - # If not enough views, then raise error - if len(views) <= 1 and require_multiple_views: - raise ValueError( - "One or less views found for frame " - f"{frame_idx} in {session.camera_cluster}." + instances: Dict[int, Dict[Camcorder, List[Instance]]] = {} + for frame_idx in frame_idcs: + views: Dict[ + Camcorder, "LabeledFrame" + ] = TriangulateSession.get_all_views_at_frame( + session=session, + frame_idx=frame_idx, + cams_to_include=cams_to_include, ) - # Find all instance accross all views - instances: Dict[Camcorder, "Instance"] = {} - for cam, lf in views.items(): - insts = lf.find(track=track) - if len(insts) > 0: - instances[cam] = insts + # TODO(LM): Should we just skip this frame if not enough views? + # If not enough views, then raise error + if len(views) <= 1 and require_multiple_views: + raise ValueError( + "One or less views found for frame " + f"{frame_idx} in {session.camera_cluster}." + ) - # If not enough instances for multiple views, then raise error - if len(instances) <= 1 and require_multiple_views: - raise ValueError( - "One or less instances found for frame " - f"{frame_idx} in {session.camera_cluster}." - ) + # Find all instance accross all views + instances_in_frame: Dict[Camcorder, "Instance"] = {} + for cam, lf in views.items(): + insts = lf.find(track=track) + if len(insts) > 0: + instances_in_frame[cam] = insts + + # If not enough instances for multiple views, then raise error + if len(instances_in_frame) <= 1 and require_multiple_views: + raise ValueError( + "One or less instances found for frame " + f"{frame_idx} in {session.camera_cluster}." + ) + + instances[frame_idx] = instances_in_frame return instances @@ -3698,29 +3795,40 @@ def get_all_views_at_frame( return views @staticmethod - def get_instances_matrices(instances_ordered: List[Instance]) -> np.ndarray: + def get_instances_matrices( + instances: Dict[int, Dict[Camcorder, List[Instance]]] + ) -> np.ndarray: """Gather instances from views into M x F x T x N x 2 an array. M: # views, F: # frames = 1, T: # tracks, N: # nodes, 2: x, y + Note that frames indices are not directly used, but rather meant to act as a + marker for independent options (see `TriangulateSession.get_instance_groups`). + Args: - instances_ordered: List of instances from view (following the order of the - `RecordingSession.cameras` if using for triangulation). + instances: Dict with frame indices as keys and another Dict with `Camcorder` + keys and `List[Instance]` values. Returns: M x F x T x N x 2 array of instances coordinates. """ - # Get list of instances matrices from each view - inst_coords = [ - np.stack( - [inst.numpy() for inst in instances_in_view], - axis=0, - ) - for instances_in_view in instances_ordered - ] # len(M), List[T x N x 2] - inst_coords = np.stack(inst_coords, axis=0) # M x T x N x 2 - inst_coords = np.expand_dims(inst_coords, axis=1) # M x F=1 x T x N x 2 + # Get M X T X N X 2 array of instances coordinates for each frame + inst_coords_all_frames = [] + for frame_idx, instances_in_frame in instances.items(): + # Get list of instances matrices from each view + inst_coords_in_views = [ + np.stack( + [inst.numpy() for inst in instances_in_view], + axis=0, + ) + for instances_in_view in instances_in_frame.values() # TODO(LM): Ensure ordered correctly + ] # len(M), List[T x N x 2] + inst_coords_views = np.stack(inst_coords_in_views, axis=0) # M x T x N x 2 + inst_coords_all_frames.append( + inst_coords_views + ) # len=frame_idx, List[M x T x N x 2] + inst_coords = np.stack(inst_coords_all_frames, axis=1) # M x F x T x N x 2 return inst_coords From e4952414a146b328aa6781e9d9c9e2ddc328617f Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Wed, 1 Nov 2023 16:42:13 -0700 Subject: [PATCH 04/22] [wip] Update calculate reprojected points to support multiple frames --- sleap/gui/commands.py | 302 ++++++++++++++++++++++++++++++------------ 1 file changed, 215 insertions(+), 87 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index e3e3ed9f5..246d7c468 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -37,7 +37,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from glob import glob from itertools import product from pathlib import Path, PurePath -from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast import attr import cv2 @@ -3427,7 +3427,7 @@ def do_action(cls, context: CommandContext, params: dict): TriangulateSession.update_instances(session=session, instances=instances) @classmethod - def ask(cls, context: CommandContext, params: dict): + def ask(cls, context: CommandContext, params: dict) -> bool: """Add "instances" to params dict if enough views/instances, warning user otherwise. Args: @@ -3447,7 +3447,7 @@ def ask(cls, context: CommandContext, params: dict): return cls.verify_views_and_instances(context=context, params=params) @classmethod - def verify_views_and_instances(cls, context: CommandContext, params: dict): + def verify_views_and_instances(cls, context: CommandContext, params: dict) -> bool: """Verify that there are enough views and instances to triangulate. Also adds "instances" to params dict if there are enough views and instances. @@ -3466,6 +3466,7 @@ def verify_views_and_instances(cls, context: CommandContext, params: dict): Returns: True if enough views/instances for triangulation, False otherwise. """ + video = params.get("video", None) or context.state["video"] session = params.get("session", None) or context.labels.get_session(video) instance = params.get("instance", None) or context.state["instance"] @@ -3495,7 +3496,7 @@ def verify_views_and_instances(cls, context: CommandContext, params: dict): instances = TriangulateSession.get_and_verify_enough_instances( context=context, session=session, - frame_idx=frame_idx, + frame_inds=[frame_idx], cams_to_include=cams_to_include, track=track, show_dialog=show_dialog, @@ -3513,19 +3514,19 @@ def verify_views_and_instances(cls, context: CommandContext, params: dict): @staticmethod def get_and_verify_enough_instances( session: RecordingSession, - frame_idcs: List[int], + frame_inds: List[int], context: Optional[CommandContext] = None, cams_to_include: Optional[List[Camcorder]] = None, track: Union[Track, int] = -1, show_dialog: bool = True, - ) -> Union[Dict[Camcorder, Instance], bool]: + ) -> Union[Dict[int, Dict[Camcorder, List[Instance]]], bool]: """Get all instances accross views at this frame index. If not enough `Instance`s are available at this frame index, then return False. Args: session: The `RecordingSession` containing the `Camcorder`s. - frame_idcs: List of frame indices to get instances from (0-indexed). + frame_inds: List of frame indices to get instances from (0-indexed). context: The optional command context used to display a dialog. cams_to_include: List of `Camcorder`s to include. Default is all. track: `Track` object used to find instances accross views. Default is -1 @@ -3533,15 +3534,17 @@ def get_and_verify_enough_instances( show_dialog: If True, then show a warning dialog. Default is True. Returns: - Dict with `Camcorder` keys and `Instances` values (or False if not enough - instances at this frame index). + Dict with frame identifier keys (does not necessarily need to be the frame + index) and values of another inner dict with `Camcorder` keys and + `List[Instance]` values if enough instances are found, False otherwise. """ + try: instances: Dict[ - int, Dict[Camcorder, Instance] - ] = TriangulateSession.get_instances_across_views( + int, Dict[Camcorder, List[Instance]] + ] = TriangulateSession.get_instances_across_views_multiple_frames( session=session, - frame_idcs=frame_idcs, + frame_inds=frame_inds, cams_to_include=cams_to_include, track=track, require_multiple_views=True, @@ -3549,10 +3552,9 @@ def get_and_verify_enough_instances( return instances except ValueError: # If not enough views or instances, then return - message = ( - "One or less instances found for frame " - f"{frame_idx} in {session.camera_cluster}. " - "Multiple instances accross multiple views needed to triangulate. " + message = traceback.format_exc() + message += ( + "\nMultiple instances accross multiple views needed to triangulate. " "Skipping triangulation and reprojection." ) if show_dialog and context is not None: @@ -3692,7 +3694,7 @@ def get_permutations_of_instances( @staticmethod def get_instances_across_views( session: RecordingSession, - frame_idcs: List[int], + frame_idx: int, cams_to_include: Optional[List[Camcorder]] = None, track: Union[Track, int] = -1, require_multiple_views: bool = False, @@ -3701,7 +3703,7 @@ def get_instances_across_views( Args: session: The `RecordingSession` containing the `Camcorder`s. - frame_idcs: List of frame indices to get instances from (0-indexed). + frame_idx: Frame index to get instances from (0-indexed). cams_to_include: List of `Camcorder`s to include. Default is all. track: `Track` object used to find instances accross views. Default is -1 which find all instances regardless of track. @@ -3717,39 +3719,86 @@ def get_instances_across_views( """ # Get all views at this frame index - instances: Dict[int, Dict[Camcorder, List[Instance]]] = {} - for frame_idx in frame_idcs: - views: Dict[ - Camcorder, "LabeledFrame" - ] = TriangulateSession.get_all_views_at_frame( - session=session, - frame_idx=frame_idx, - cams_to_include=cams_to_include, + views: Dict[ + Camcorder, "LabeledFrame" + ] = TriangulateSession.get_all_views_at_frame( + session=session, + frame_idx=frame_idx, + cams_to_include=cams_to_include, + ) + + # TODO(LM): Should we just skip this frame if not enough views? + # If not enough views, then raise error + if len(views) <= 1 and require_multiple_views: + raise ValueError( + "One or less views found for frame " + f"{frame_idx} in {session.camera_cluster}." ) - # TODO(LM): Should we just skip this frame if not enough views? - # If not enough views, then raise error - if len(views) <= 1 and require_multiple_views: - raise ValueError( - "One or less views found for frame " - f"{frame_idx} in {session.camera_cluster}." - ) + # Find all instance accross all views + instances_in_frame: Dict[Camcorder, List[Instance]] = {} + for cam, lf in views.items(): + insts = lf.find(track=track) + if len(insts) > 0: + instances_in_frame[cam] = insts - # Find all instance accross all views - instances_in_frame: Dict[Camcorder, "Instance"] = {} - for cam, lf in views.items(): - insts = lf.find(track=track) - if len(insts) > 0: - instances_in_frame[cam] = insts + # If not enough instances for multiple views, then raise error + if len(instances_in_frame) <= 1 and require_multiple_views: + raise ValueError( + "One or less instances found for frame " + f"{frame_idx} in {session.camera_cluster}." + ) - # If not enough instances for multiple views, then raise error - if len(instances_in_frame) <= 1 and require_multiple_views: - raise ValueError( - "One or less instances found for frame " - f"{frame_idx} in {session.camera_cluster}." - ) + return instances_in_frame - instances[frame_idx] = instances_in_frame + @staticmethod + def get_instances_across_views_multiple_frames( + session: RecordingSession, + frame_inds: List[int], + cams_to_include: Optional[List[Camcorder]] = None, + track: Union[Track, int] = -1, + require_multiple_views: bool = False, + ) -> Dict[int, Dict[Camcorder, List[Instance]]]: + """Get all `Instances` accross all views at all given frame indices. + + Args: + session: The `RecordingSession` containing the `Camcorder`s. + frame_inds: List of frame indices to get instances from (0-indexed). + cams_to_include: List of `Camcorder`s to include. Default is all. + track: `Track` object used to find instances accross views. Default is -1 + which find all instances regardless of track. + require_multiple_views: If True, then raise and error if one or less views + or instances are found. + + Returns: + Dict with frame identifier keys (does not necessarily need to be the frame + index) and values of another inner dict with `Camcorder` keys and + `List[Instance]` values. + """ + + instances: Dict[int, Dict[Camcorder, List[Instance]]] = {} + for frame_idx in frame_inds: + try: + # Find all instance accross all views + instances_in_frame = TriangulateSession.get_instances_across_views( + session=session, + frame_idx=frame_idx, + cams_to_include=cams_to_include, + track=track, + require_multiple_views=require_multiple_views, + ) + instances[frame_idx] = instances_in_frame + except ValueError: + message = traceback.format_exc() + message += " Skipping frame." + logger.warning(f"{message}") + + if len(instances) == 0: + frame_inds_str = ", ".join([str(frame_idx) for frame_idx in frame_inds]) + raise ValueError( + "Not enough instances or views found for any frame identifiers in " + f"{frame_inds_str}." + ) return instances @@ -3796,8 +3845,9 @@ def get_all_views_at_frame( @staticmethod def get_instances_matrices( - instances: Dict[int, Dict[Camcorder, List[Instance]]] - ) -> np.ndarray: + instances: Dict[int, Dict[Camcorder, List[Instance]]], + session: Optional[RecordingSession] = None, + ) -> Tuple[np.ndarray, List[Camcorder]]: """Gather instances from views into M x F x T x N x 2 an array. M: # views, F: # frames = 1, T: # tracks, N: # nodes, 2: x, y @@ -3808,54 +3858,117 @@ def get_instances_matrices( Args: instances: Dict with frame indices as keys and another Dict with `Camcorder` keys and `List[Instance]` values. + session: The `RecordingSession` containing the `Camcorder`s. Used to order + the instances in the matrix as expected for triangulation. Returns: - M x F x T x N x 2 array of instances coordinates. + M x F x T x N x 2 array of instances coordinates and the ordered list of + `Camcorder`s by which the instances are ordered. """ + cams_ordered = None + # Get M X T X N X 2 array of instances coordinates for each frame - inst_coords_all_frames = [] - for frame_idx, instances_in_frame in instances.items(): + inst_coords_frames = [] + for instances_in_frame in instances.values(): + + # Get correct camera ordering + if cams_ordered is None: + if session is None: + logger.warning( + "No session provided. Cannot organize instance coordinates to " + "be compatible for triangulation." + ) + cams_ordered = [cam for cam in instances_in_frame] + else: + cams_ordered = [ + cam for cam in session.cameras if cam in instances_in_frame + ] + # Get list of instances matrices from each view inst_coords_in_views = [ np.stack( - [inst.numpy() for inst in instances_in_view], + [inst.numpy() for inst in instances_in_frame[cam]], axis=0, ) - for instances_in_view in instances_in_frame.values() # TODO(LM): Ensure ordered correctly + for cam in cams_ordered ] # len(M), List[T x N x 2] + inst_coords_views = np.stack(inst_coords_in_views, axis=0) # M x T x N x 2 - inst_coords_all_frames.append( + inst_coords_frames.append( inst_coords_views ) # len=frame_idx, List[M x T x N x 2] - inst_coords = np.stack(inst_coords_all_frames, axis=1) # M x F x T x N x 2 + inst_coords = np.stack(inst_coords_frames, axis=1) # M x F x T x N x 2 - return inst_coords + return inst_coords, cams_ordered @staticmethod def calculate_excluded_views( session: RecordingSession, - instances: Dict[Camcorder, "Instance"], + cameras_being_used: Union[Dict[Camcorder, List[Instance]], List[Camcorder]], ) -> Tuple[str]: """Get excluded views from dictionary of `Camcorder` to `Instance`. Args: session: The `RecordingSession` containing the `Camcorder`s. - instances: Dict with `Camcorder` key and `Instance` values. + cameras_being_used: List of `Camcorder`s. Returns: Tuple of excluded view names. """ # Calculate excluded views from included cameras - cams_excluded = set(session.cameras) - set(instances.keys()) + cams_excluded = set(session.cameras) - set(cameras_being_used) excluded_views = tuple(cam.name for cam in cams_excluded) + excluded_views = cast(Tuple[str], excluded_views) # cam.name could be Any + + return excluded_views + + @staticmethod + def calculate_excluded_views_multiple_frames( + session: RecordingSession, + instances: Dict[int, Dict[Camcorder, List[Instance]]], + ) -> Tuple[str]: + """Get excluded views from dictionary of `Camcorder` to `Instance`. + + Args: + session: The `RecordingSession` containing the `Camcorder`s. + instances: Dict with frame identifier keys (does not necessarily need to be + the frame index) and values of another inner dict with `Camcorder` keys + and `List[Instance]` values. + + Returns: + Tuple of excluded view names. + + Raises: + ValueError if excluded views are not the same across frames. + """ + + # Calculate excluded views from included cameras + excluded_views = None + for frame_idx, instances_in_frame in instances.items(): + excluded_views_prev = excluded_views + excluded_views = TriangulateSession.calculate_excluded_views( + session=session, cameras_being_used=instances_in_frame + ) + if excluded_views_prev is None: + prev_frame_idx = frame_idx + continue + elif excluded_views != excluded_views_prev: + raise ValueError( + "Excluded views are not the same across frames. " + f"\n\tExcluded views in frame identifier {prev_frame_idx}: {excluded_views_prev}. " + f"\n\tExcluded views in frame identifier {frame_idx}: {excluded_views}." + ) + prev_frame_idx = frame_idx + + excluded_views = cast(Tuple[str], excluded_views) # Could be None if no frames return excluded_views @staticmethod def calculate_reprojected_points( - session: RecordingSession, instances: Dict[Camcorder, List[Instance]] + session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]] ) -> Dict[Camcorder, Tuple["Instance", np.ndarray]]: """Triangulate and reproject instance coordinates. @@ -3865,57 +3978,72 @@ def calculate_reprojected_points( https://github.com/lambdaloop/aniposelib/blob/d03b485c4e178d7cff076e9fe1ac36837db49158/aniposelib/cameras.py#L491 Args: - instances: Dict with `Camcorder` keys and `List[Instance]` values. - + session: The `RecordingSession` containing the `Camcorder`s. + instances: Dict with frame identifier keys (does not necessarily need to be + the frame index) and values of another inner dict with `Camcorder` keys + and `List[Instance]` values. Returns: A zip of the ordered instances and the related reprojected coordinates. Each element in the coordinates is a numpy array of shape (T, N, 2) where N is the number of nodes and T is the number of reprojected instances. """ - # TODO (LM): Optimize - - excluded_views = TriangulateSession.calculate_excluded_views( - session=session, instances=instances + # Derive the excluded views from the included cameras and ensures all frames + # have the same excluded views. + excluded_views = TriangulateSession.calculate_excluded_views_multiple_frames( + instances=instances, session=session ) - cams_ordered = [cam for cam in session.cameras if cam in instances] - instances_ordered = [instances[cam] for cam in cams_ordered] # Gather instances into M x F x T x N x 2 arrays (require specific order) - # (M = # views, F = # frames = 1, T = # tracks, N = # nodes, 2 = x, y) - inst_coords = TriangulateSession.get_instances_matrices( - instances_ordered=instances_ordered - ) # M x F=1 x T x N x 2 + # (M = # views, F = # frames, T = # tracks, N = # nodes, 2 = x, y) + inst_coords, cams_ordered = TriangulateSession.get_instances_matrices( + instances=instances, session=session + ) # M x F x T x N x 2 points_3d = triangulate( p2d=inst_coords, calib=session.camera_cluster, excluded_views=excluded_views, - ) # F=1, T, N, 3 + ) # F, T, N, 3 - # Update the views with the new 3D points + # Get the reprojected 2D points from the 3D points inst_coords_reprojected = reproject( points_3d, calib=session.camera_cluster, excluded_views=excluded_views - ) # M x F=1 x T x N x 2 - insts_coords_list: List[np.ndarray] = np.split( - inst_coords_reprojected.squeeze(), inst_coords_reprojected.shape[0], axis=0 - ) # len(M) of T x N x 2 - - insts_and_coords: Dict[Camcorder, Tuple[Instance, np.ndarray]] = { - cam: (instances_in_view, inst_coords) - for cam, instances_in_view, inst_coords in zip( - cams_ordered, instances_ordered, insts_coords_list - ) - } + ) # M x F x T x N x 2 + insts_coords_frames_list: List[np.ndarray] = np.split( + inst_coords_reprojected, inst_coords_reprojected.shape[1], axis=1 + ) # len(F) of M x T x N x 2 + insts_coords_list: List[List[np.ndarray]] = [ + np.split(insts_coords_views, insts_coords_views.shape[0], axis=0) + for insts_coords_views in insts_coords_frames_list + ] # len(F) of [len(M) of T x N x 2] + + # Group together the reordered (by cam) instances and the reprojected coords. + insts_and_coords: Dict[ + int, Dict[Camcorder, Iterator[Tuple[List[Instance], List[np.ndarray]]]] + ] # Dict len(F) of [list len(M) of array (T x N x 2)] + for frame_idx, instances_in_frame in instances.items(): + insts_and_coords_in_frame = {} + for cam_idx, cam in enumerate(cams_ordered): + instances_in_frame_ordered = instances_in_frame[cam] + insts_coords_in_frame = insts_coords_list[frame_idx][cam_idx] + insts_and_coords_in_frame[cam] = zip( + instances_in_frame_ordered, insts_coords_in_frame + ) + insts_and_coords[frame_idx] = insts_and_coords_in_frame return insts_and_coords @staticmethod - def update_instances(session, instances: Dict[Camcorder, Instance]): + def update_instances( + session, instances: Dict[int, Dict[Camcorder, List[Instance]]] + ): """Triangulate, reproject, and update coordinates of `Instances`. Args: session: The `RecordingSession` containing the `Camcorder`s. - instances: Dict with `Camcorder` keys and `Instance` values. + instances: Dict with frame identifier keys (does not necessarily need to be + the frame index) and values of another inner dict with `Camcorder` keys + and `List[Instance]` values. Returns: None From 4470c0e9baf4dc905e5cb1cae0983a051a58ec65 Mon Sep 17 00:00:00 2001 From: roomrys Date: Thu, 2 Nov 2023 13:55:12 -0700 Subject: [PATCH 05/22] Finish support for multi-frame reprojection --- sleap/gui/commands.py | 140 +++++++++++++++++++++++++++++-------- tests/gui/test_commands.py | 75 +++++++++++--------- 2 files changed, 154 insertions(+), 61 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 246d7c468..4cbf1ad55 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -3871,7 +3871,6 @@ def get_instances_matrices( # Get M X T X N X 2 array of instances coordinates for each frame inst_coords_frames = [] for instances_in_frame in instances.values(): - # Get correct camera ordering if cams_ordered is None: if session is None: @@ -3967,9 +3966,9 @@ def calculate_excluded_views_multiple_frames( return excluded_views @staticmethod - def calculate_reprojected_points( + def _calculate_reprojected_points( session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]] - ) -> Dict[Camcorder, Tuple["Instance", np.ndarray]]: + ) -> Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]]: """Triangulate and reproject instance coordinates. Note that the order of the instances in the list must match the order of the @@ -3977,15 +3976,20 @@ def calculate_reprojected_points( a dictionary mapping back to its `Camcorder`. https://github.com/lambdaloop/aniposelib/blob/d03b485c4e178d7cff076e9fe1ac36837db49158/aniposelib/cameras.py#L491 + Also, this function does not handle grouping instances with their respective + coordinates by reordering by camera. + See `TriangulateSession.calculate_reprojected_points`. + Args: session: The `RecordingSession` containing the `Camcorder`s. instances: Dict with frame identifier keys (does not necessarily need to be the frame index) and values of another inner dict with `Camcorder` keys and `List[Instance]` values. Returns: - A zip of the ordered instances and the related reprojected coordinates. Each - element in the coordinates is a numpy array of shape (T, N, 2) where N is - the number of nodes and T is the number of reprojected instances. + A dictionary with frame identifier keys (does not necessarily need to be the + frame index) and values of another inner dict with `Camcorder` keys and + a zip of the `List[Instance]` and reprojected instance coordinates of shape + (T, N, 2) ordered by the `Camcorder` order in the `CameraCluster`. """ # Derive the excluded views from the included cameras and ensures all frames @@ -4009,30 +4013,112 @@ def calculate_reprojected_points( inst_coords_reprojected = reproject( points_3d, calib=session.camera_cluster, excluded_views=excluded_views ) # M x F x T x N x 2 - insts_coords_frames_list: List[np.ndarray] = np.split( - inst_coords_reprojected, inst_coords_reprojected.shape[1], axis=1 - ) # len(F) of M x T x N x 2 + + return inst_coords_reprojected, cams_ordered + + def group_instances_and_coords( + instances, inst_coords_reprojected, cams_ordered + ) -> Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]]: + """Group instances and reprojected coordinates by frame and view. + + Args: + instances: Dict with frame identifier keys (does not necessarily need to be + the frame index) and values of another inner dict with `Camcorder` keys + and `List[Instance]` values. + inst_coords_reprojected: M x F x T x N x 2 array of reprojected instance + coordinates. + cams_ordered: List of `Camcorder`s ordered by the `CameraCluster` + representing both the order and subset of cameras used to calculate + `inst_coords_reprojected`. + + Returns: + A dictionary with frame identifier keys (does not necessarily need to be the + frame index) and values of another inner dict with `Camcorder` keys and + a zip of the `List[Instance]` and reprojected instance coordinates list with + items of shape (N, 2) ordered by the `Camcorder` order in the `CameraCluster`. + """ + + # Split the reprojected coordinates into a list corresponding to instances list. insts_coords_list: List[List[np.ndarray]] = [ - np.split(insts_coords_views, insts_coords_views.shape[0], axis=0) - for insts_coords_views in insts_coords_frames_list - ] # len(F) of [len(M) of T x N x 2] + [ # Annoyingly, np.split leaves a singleton dimension, so we have to squeeze. + np.squeeze(insts_coords_in_view, axis=0) + for insts_coords_in_view in np.split( + np.squeeze(insts_coords_in_frame, axis=1), + insts_coords_in_frame.shape[0], + axis=0, + ) # len(M) of T x N x 2 + for insts_coords_track in np.split( + np.squeeze(insts_coords_in_view, axis=0), + insts_coords_in_view.shape[0], + axis=0, + ) # len(T) of N x 2 + ] + for insts_coords_in_frame in np.split( + inst_coords_reprojected, inst_coords_reprojected.shape[1], axis=1 + ) # len(F) of M x T x N x 2 + ] # len(F) of len(M) of len(T) of N x 2 # Group together the reordered (by cam) instances and the reprojected coords. insts_and_coords: Dict[ - int, Dict[Camcorder, Iterator[Tuple[List[Instance], List[np.ndarray]]]] - ] # Dict len(F) of [list len(M) of array (T x N x 2)] - for frame_idx, instances_in_frame in instances.items(): - insts_and_coords_in_frame = {} + int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]] + ] = ( + {} + ) # Dict len(F) of dict len(M) of zipped lists of len(T) instances and array of N x 2 + for frame_idx, instances_in_frame in instances.items(): # len(F) of dict + insts_and_coords_in_frame: Dict[Camcorder, Tuple[Instance, np.ndarray]] = {} for cam_idx, cam in enumerate(cams_ordered): - instances_in_frame_ordered = instances_in_frame[cam] - insts_coords_in_frame = insts_coords_list[frame_idx][cam_idx] - insts_and_coords_in_frame[cam] = zip( - instances_in_frame_ordered, insts_coords_in_frame + instances_in_frame_ordered: List[Instance] = instances_in_frame[ + cam + ] # Reorder by cam to match coordinates, len(T) + insts_coords_in_frame: np.ndarray = insts_coords_list[frame_idx][ + cam_idx + ] # len(T) of N x 2 + insts_and_coords_in_frame[cam]: Tuple[Instance, np.ndarray] = zip( + instances_in_frame_ordered, + insts_coords_in_frame, ) insts_and_coords[frame_idx] = insts_and_coords_in_frame return insts_and_coords + @staticmethod + def calculate_reprojected_points( + session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]] + ): + """Triangulate, reproject, and group coordinates of `Instances`. + + Args: + session: The `RecordingSession` containing the `Camcorder`s. + instances: Dict with frame identifier keys (does not necessarily need to be + the frame index) and values of another inner dict with `Camcorder` keys + and `List[Instance]` values. + + Returns: + A dictionary with frame identifier keys (does not necessarily need to be the + frame index) and values of another inner dict with `Camcorder` keys and + a zip of the `List[Instance]` and reprojected instance coordinates list with + items of shape (N, 2) ordered by the `Camcorder` order in the `CameraCluster`. + """ + + # Triangulate and reproject instance coordinates. + ( + inst_coords_reprojected, + cams_ordered, + ) = TriangulateSession._calculate_reprojected_points( + session=session, instances=instances + ) + + # Group together instances (the reordered by cam) and the reprojected coords. + instances_and_coords: Dict[ + int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]] + ] = TriangulateSession.group_instances_and_coords( + inst_coords_reprojected=inst_coords_reprojected, + instances=instances, + cams_ordered=cams_ordered, + ) + + return instances_and_coords + @staticmethod def update_instances( session, instances: Dict[int, Dict[Camcorder, List[Instance]]] @@ -4050,18 +4136,16 @@ def update_instances( """ # Triangulate and reproject instance coordinates. - instances_and_coords: Dict[ - Camcorder, Tuple[Instance, np.ndarray] - ] = TriangulateSession.calculate_reprojected_points( + instances_and_coords = TriangulateSession.calculate_reprojected_points( session=session, instances=instances ) + # TODO(LM): Since we only use the values here, is a dictionary overkill? # Update the instance coordinates. - for cam, (instances_in_view, inst_coord) in instances_and_coords.items(): - for inst_idx, inst in enumerate(instances_in_view): - inst.update_points( - points=inst_coord[inst_idx], exclude_complete=True - ) # inst_coord is (T, N, 2) + for instances_in_frame in instances_and_coords.values(): + for instances_in_view in instances_in_frame.values(): + for inst, inst_coord in instances_in_view: + inst.update_points(points=inst_coord, exclude_complete=True) def open_website(url: str): diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 4bae7fbc6..a96b688ed 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -1067,13 +1067,16 @@ def test_triangulate_session_get_and_verify_enough_instances( # Test with no cams_to_include, expect views from all linked cameras instances = TriangulateSession.get_and_verify_enough_instances( - session=session, frame_idx=lf.frame_idx, track=track + session=session, frame_inds=[lf.frame_idx], track=track ) - assert len(instances) == 6 # Some views don't have an instance at this track + instances_in_frame = instances[lf.frame_idx] + assert ( + len(instances_in_frame) == 6 + ) # Some views don't have an instance at this track for cam in session.linked_cameras: if cam.name in ["side", "sideL"]: # The views that don't have an instance continue - instances_in_view = instances[cam] + instances_in_view = instances_in_frame[cam] for inst in instances_in_view: assert inst.frame_idx == lf.frame_idx assert inst.track == track @@ -1083,13 +1086,14 @@ def test_triangulate_session_get_and_verify_enough_instances( cams_to_include = session.linked_cameras[-2:] instances = TriangulateSession.get_and_verify_enough_instances( session=session, - frame_idx=lf.frame_idx, + frame_inds=[lf.frame_idx], cams_to_include=cams_to_include, track=track, ) - assert len(instances) == len(cams_to_include) + instances_in_frame = instances[lf.frame_idx] + assert len(instances_in_frame) == len(cams_to_include) for cam in cams_to_include: - instances_in_view = instances[cam] + instances_in_view = instances_in_frame[cam] for inst in instances_in_view: assert inst.frame_idx == lf.frame_idx assert inst.track == track @@ -1099,7 +1103,7 @@ def test_triangulate_session_get_and_verify_enough_instances( cams_to_include = session.linked_cameras[0:2] instances = TriangulateSession.get_and_verify_enough_instances( session=session, - frame_idx=lf.frame_idx, + frame_inds=[lf.frame_idx], cams_to_include=cams_to_include, track=None, ) @@ -1181,8 +1185,8 @@ def test_triangulate_session_calculate_reprojected_points( track = multiview_min_session_labels.tracks[0] instances: Dict[ Camcorder, Instance - ] = TriangulateSession.get_instances_across_views( - session=session, frame_idx=lf.frame_idx, track=track + ] = TriangulateSession.get_instances_across_views_multiple_frames( + session=session, frame_inds=[lf.frame_idx], track=track ) instances_and_coords = TriangulateSession.calculate_reprojected_points( session=session, instances=instances @@ -1192,10 +1196,10 @@ def test_triangulate_session_calculate_reprojected_points( assert len(instances) == len(instances_and_coords) # Check that each instance has the same number of points - for instances_in_view, inst_coords in instances_and_coords.values(): - assert inst_coords.shape[0] == len(instances_in_view) - for inst in instances_in_view: - assert inst_coords.shape[1] == len(inst.skeleton) # (1, 15, 2) + for instances_in_frame in instances_and_coords.values(): + for instances_in_view in instances_in_frame.values(): + for inst, inst_coords in instances_in_view: + assert inst_coords.shape[0] == len(inst.skeleton) # (15, 2) def test_triangulate_session_get_instances_matrices( @@ -1207,18 +1211,20 @@ def test_triangulate_session_get_instances_matrices( lf: LabeledFrame = labels[0] track = labels.tracks[0] instances: Dict[ - Camcorder, Instance - ] = TriangulateSession.get_instances_across_views( - session=session, frame_idx=lf.frame_idx, track=track + int, Dict[Camcorder, List[Instance]] + ] = TriangulateSession.get_instances_across_views_multiple_frames( + session=session, frame_inds=[lf.frame_idx], track=track ) - instances_matrices = TriangulateSession.get_instances_matrices( - instances_ordered=instances.values() + instances_matrices, cams_ordered = TriangulateSession.get_instances_matrices( + instances=instances ) # Verify shape - n_views = len(instances) - n_frames = 1 - n_tracks = 1 + n_frames = len(instances) + n_views = len(instances[lf.frame_idx]) + assert n_views == len(cams_ordered) + n_tracks = len(instances[lf.frame_idx][cams_ordered[0]]) + assert n_tracks == 1 n_nodes = len(labels.skeleton) assert instances_matrices.shape == (n_views, n_frames, n_tracks, n_nodes, 2) @@ -1231,22 +1237,25 @@ def test_triangulate_session_update_instances(multiview_min_session_labels: Labe lf: LabeledFrame = multiview_min_session_labels[0] track = multiview_min_session_labels.tracks[0] instances: Dict[ - Camcorder, Instance - ] = TriangulateSession.get_instances_across_views( - session=session, frame_idx=lf.frame_idx, track=track + int, Dict[Camcorder, List[Instance]] + ] = TriangulateSession.get_instances_across_views_multiple_frames( + session=session, + frame_inds=[lf.frame_idx], + track=track, + require_multiple_views=True, ) instances_and_coordinates = TriangulateSession.calculate_reprojected_points( session=session, instances=instances ) - for instances_in_view, inst_coords in instances_and_coordinates.values(): - for inst in instances_in_view: - assert inst_coords.shape == ( - len(instances_in_view), - len(inst.skeleton), - 2, - ) # Tracks, Nodes, 2 - # Assert coord are different from original - assert not np.array_equal(inst_coords, inst.points_array) + for instances_in_frame in instances_and_coordinates.values(): + for instances_in_view in instances_in_frame.values(): + for inst, inst_coords in instances_in_view: + assert inst_coords.shape == ( + len(inst.skeleton), + 2, + ) # Nodes, 2 + # Assert coord are different from original + assert not np.array_equal(inst_coords, inst.points_array) # Just run for code coverage testing, do not test output here (race condition) # (see "functional core, imperative shell" pattern) From d83ea6457d4ee75caa4ee8549564040ae8db0f7d Mon Sep 17 00:00:00 2001 From: roomrys Date: Thu, 2 Nov 2023 14:01:41 -0700 Subject: [PATCH 06/22] Remove code to put in next PR --- sleap/gui/commands.py | 93 ------------------------------------------- 1 file changed, 93 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 4cbf1ad55..cc1e336f8 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -35,7 +35,6 @@ class which inherits from `AppCommand` (or a more specialized class such as import traceback from enum import Enum from glob import glob -from itertools import product from pathlib import Path, PurePath from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast @@ -3599,98 +3598,6 @@ def verify_enough_views( return True - @staticmethod - def get_groups_of_instances( - session: RecordingSession, - frame_idx: int, - cams_to_include: Optional[List[Camcorder]] = None, - ): - """Get instances grouped by `InstanceGroup` or group instances across views. - - If there are not instances in an `InstanceGroup` for all views, then try - regrouping using leftover instances. Do not add to an `InstanceGroup` if the - error is above a set threshold (i.e. there may not be the same instance labeled - across views). - - """ - - permutated_instances: Dict[ - Camcorder, List[Instance] - ] = TriangulateSession.get_permutations_of_instances( - session=session, - frame_idx=frame_idx, - cams_to_include=cams_to_include, - ) - - # Triangulate and reproject instance coordinates. - instances_and_coords: Dict[ - Camcorder, Tuple[Instance, np.ndarray] - ] = TriangulateSession.calculate_reprojected_points( - session=session, instances=permutated_instances - ) - - # Compare the instance coordinates. - reprojection_error = { - cam: np.inf * np.ones() for cam in permutated_instances.keys() - } - grouped_instances = {cam: [] for cam in permutated_instances.keys()} - for cam, (instances_in_view, inst_coord) in instances_and_coords.keys(): - for inst_idx, inst in enumerate(instances_in_view): - instance_error = np.linalg.norm( - np.nan_to_num(inst.points_array - inst_coord[inst_idx]) - ) - - return grouped_instances - - @staticmethod - def get_permutations_of_instances( - session: RecordingSession, - frame_idx: int, - cams_to_include: Optional[List[Camcorder]] = None, - ) -> Dict[Camcorder, List[Instance]]: - """Get all possible combinations of instances across views. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - frame_idx: Frame index to get instances from (0-indexed). - cams_to_include: List of `Camcorder`s to include. Default is all. - require_multiple_views: If True, then raise and error if one or less views - or instances are found. - - Raises: - ValueError if one or less views or instances are found. - - Returns: - Dict with `Camcorder` keys and `List[Instance]` values. - """ - - instances: Dict[ - Camcorder, List[Instance] - ] = TriangulateSession.get_instances_across_views( - session=session, - frame_idx=frame_idx, - cams_to_include=cams_to_include, - track=-1, # Get all instances regardless of track. - require_multiple_views=True, - ) - - # TODO(LM): Should we only do this for the selected instance? - - # Permutate instances into psuedo groups where each element is a tuple - # grouping elements from different views. - combinations: List[Tuple[Instance]] = list( - product(*instances.values()) - ) # len(prod(instances.values())) with each element of len(instances.keys()) - - # Regroup combos s.t. instances from a single view are in the same list. - cams = list(instances.keys()) - grouped_instances = {cam: [] for cam in cams} - for combination in combinations: - for cam, inst in zip(cams, combination): - grouped_instances[cam].append(inst) - - return grouped_instances - @staticmethod def get_instances_across_views( session: RecordingSession, From 443d410df7b1f34ac53e3dacb90cca7036971ff7 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Sat, 20 Jan 2024 07:53:38 -0800 Subject: [PATCH 07/22] (3b -> 3a) Add method to get single instance permutations (#1586) * Add method to get single instance permutations * Append a dummy instance for missing instances * Correct 'permutations' to 'products' * (3c -> 3b) Add method to test instance grouping (#1599) * (3d -> 3c) Add method for multi instance products (#1605) --- sleap/gui/commands.py | 428 +++++++++++++++++++++++++++++++++---- tests/gui/test_commands.py | 126 +++++++++-- 2 files changed, 504 insertions(+), 50 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index beaa9dcda..7b6b44d64 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -35,6 +35,7 @@ class which inherits from `AppCommand` (or a more specialized class such as import traceback from enum import Enum from glob import glob +from itertools import permutations, product from pathlib import Path, PurePath from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast @@ -3421,9 +3422,19 @@ def do_action(cls, context: CommandContext, params: dict): video = params.get("video", None) or context.state["video"] session = params.get("session", None) or context.labels.get_session(video) instances = params["instances"] + session = cast(RecordingSession, session) # Could be None if no labels or video + + # Get best instance grouping and reprojected coords + instances_and_reprojected_coords = ( + TriangulateSession.get_instance_grouping_and_reprojected_coords( + session=session, instance_hypotheses=instances + ) + ) # Update instances - TriangulateSession.update_instances(session=session, instances=instances) + TriangulateSession.update_instances( + instances_and_coords=instances_and_reprojected_coords + ) @classmethod def ask(cls, context: CommandContext, params: dict) -> bool: @@ -3479,7 +3490,6 @@ def verify_views_and_instances(cls, context: CommandContext, params: dict) -> bo if session is None or instance is None: return - track = instance.track # TODO(LM): Replace with InstanceGroup cams_to_include = params.get("cams_to_include", None) or session.linked_cameras # If not enough `Camcorder`s available/specified, then return @@ -3491,13 +3501,12 @@ def verify_views_and_instances(cls, context: CommandContext, params: dict) -> bo ): return False - # Get all instances accross views at this frame index + # Get all instances products accross views at this frame index instances = TriangulateSession.get_and_verify_enough_instances( context=context, session=session, - frame_inds=[frame_idx], + frame_idx=frame_idx, cams_to_include=cams_to_include, - track=track, show_dialog=show_dialog, ) @@ -3513,19 +3522,18 @@ def verify_views_and_instances(cls, context: CommandContext, params: dict) -> bo @staticmethod def get_and_verify_enough_instances( session: RecordingSession, - frame_inds: List[int], + frame_idx: int, context: Optional[CommandContext] = None, cams_to_include: Optional[List[Camcorder]] = None, - track: Union[Track, int] = -1, show_dialog: bool = True, ) -> Union[Dict[int, Dict[Camcorder, List[Instance]]], bool]: - """Get all instances accross views at this frame index. + """Get all instances accross views at this frame index (and products of instances). If not enough `Instance`s are available at this frame index, then return False. Args: session: The `RecordingSession` containing the `Camcorder`s. - frame_inds: List of frame indices to get instances from (0-indexed). + frame_idx: Frame index to get instances from (0-indexed). context: The optional command context used to display a dialog. cams_to_include: List of `Camcorder`s to include. Default is all. track: `Track` object used to find instances accross views. Default is -1 @@ -3533,20 +3541,18 @@ def get_and_verify_enough_instances( show_dialog: If True, then show a warning dialog. Default is True. Returns: - Dict with frame identifier keys (does not necessarily need to be the frame - index) and values of another inner dict with `Camcorder` keys and - `List[Instance]` values if enough instances are found, False otherwise. + Dict with frame identifier keys (not the frame index) and values of another + inner dict with `Camcorder` keys and `List[Instance]` values if enough + instances are found, False otherwise. """ try: instances: Dict[ int, Dict[Camcorder, List[Instance]] - ] = TriangulateSession.get_instances_across_views_multiple_frames( + ] = TriangulateSession.get_products_of_instances( session=session, - frame_inds=frame_inds, + frame_idx=frame_idx, cams_to_include=cams_to_include, - track=track, - require_multiple_views=True, ) return instances except Exception as e: @@ -3746,6 +3752,353 @@ def get_all_views_at_frame( return views + @staticmethod + def get_instance_grouping_and_reprojected_coords( + session: RecordingSession, + instance_hypotheses: Dict[int, Dict[Camcorder, List[Instance]]], + ): + """Get instance grouping and reprojected coords with lowest reprojection error. + + Triangulation of all possible groupings needs to be performed... Thus, we return + the best grouping's triangulation in this step to then be used when updating the + instance. + + Args: + session: The `RecordingSession` containing the `Camcorder`s. + instance_hypotheses: Dict with frame identifier keys (not the frame index) + and values of another inner dict with `Camcorder` keys and + `List[Instance]` values. + + + Returns: + best_instances_and_reprojected_coords: Dict with `Camcorder` keys with + `Tuple[Instance, np.ndarray]` values. + """ + + # Calculate reprojection error for each instance grouping + ( + reprojection_error_per_frame, + instances_and_coords, + ) = TriangulateSession.calculate_error_per_frame( + session=session, + instances=instance_hypotheses, + ) + + # Just for type hinting + reprojection_error_per_frame = cast( + Dict[int, float], reprojection_error_per_frame + ) + instances_and_coords = cast( + Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]], + instances_and_coords, + ) + + # Get instance grouping with lowest reprojection error + best_instances, frame_id_min_error = TriangulateSession._get_instance_grouping( + instances=instance_hypotheses, + reprojection_error_per_frame=reprojection_error_per_frame, + ) + + # Just for type hinting + best_instances = cast(Dict[Camcorder, List[Instance]], best_instances) + instances_and_coords = cast( + Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]], + instances_and_coords, + ) + + # Get the best reprojection + best_instances_and_reprojected_coords: Dict[ + Camcorder, Iterator[Tuple[Instance, np.ndarray]] + ] = instances_and_coords[frame_id_min_error] + + return best_instances_and_reprojected_coords + + @staticmethod + def _get_instance_grouping( + instances: Dict[int, Dict[Camcorder, List[Instance]]], + reprojection_error_per_frame: Dict[int, float], + ) -> Tuple[Dict[Camcorder, List[Instance]], int]: + """Get instance grouping with lowest reprojection error. + + Args: + instances: Dict with frame identifier keys (not the frame index) and values + of another inner dict with `Camcorder` keys and `List[Instance]` values. + reprojection_error_per_frame: Dict with frame identifier keys (not the frame + index) and values of reprojection error for the frame. + + Returns: + best_instances: Dict with `Camcorder` keys and `List[Instance]` values. + frame_id_min_error: The frame identifier with the lowest reprojection + """ + + frame_id_min_error: int = min( + reprojection_error_per_frame, key=reprojection_error_per_frame.get + ) + + best_instances: Dict[Camcorder, List[Instance]] = instances[frame_id_min_error] + + return best_instances, frame_id_min_error + + @staticmethod + def _calculate_reprojection_error( + session: RecordingSession, + instances: Dict[int, Dict[Camcorder, List[Instance]]], + per_instance: bool = False, + per_view: bool = False, + ) -> Tuple[ + Dict[int, Union[float, Dict[Camcorder, List[Tuple[Instance, float]]]]], + Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]], + ]: + """Calculate reprojection error per frame or per instance. + + Args: + session: The `RecordingSession` containing the `Camcorder`s. + instances: Dict with frame identifier keys (not the frame index) and values + of another inner dict with `Camcorder` keys and `List[Instance]` values. + per_instance: If True, then return a dict with frame identifier keys and + values of another inner dict with `Camcorder` keys and + `List[Tuple[Instance, float]]` values. + per_view: If True, then return a dict with frame identifier keys and values + of another inner dict with `Camcorder` keys and + `Tuple[Tuple[str, str], float]` values. If per_instance is True, then that takes precendence. + + Returns: + reprojection_per_frame: Dict with frame identifier keys (not the frame index) and values of another + inner dict with `Camcorder` keys and `List[Tuple[Instance, float]]` values + if per_instance is True, otherwise a dict with frame identifier keys and + values of reprojection error for the frame. + instances_and_coords: Dict with frame identifier keys (not the frame index) + and values of another inner dict with `Camcorder` keys and + `Iterator[Tuple[Instance, np.ndarray]]` values that contain the instance + and the reprojected coordinates. + + """ + + reprojection_error_per_frame = {} + + # Triangulate and reproject instance coordinates. + instances_and_coords: Dict[ + int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]] + ] = TriangulateSession.calculate_reprojected_points( + session=session, instances=instances + ) + for frame_id, instances_in_frame in instances_and_coords.items(): + frame_error: Union[Dict, float] = {} if per_instance or per_view else 0 + for cam, instances_in_view in instances_in_frame.items(): + # Compare instance coordinates here + instance_ids: List[Union[Track, str]] = [] + view_error: Union[List, int] = [] if per_instance else 0 + for inst, inst_coords in instances_in_view: + node_errors = np.nan_to_num(inst.numpy() - inst_coords) + instance_error = np.linalg.norm(node_errors) + + if per_instance: + view_error = cast(List, view_error) + view_error.append((inst, instance_error)) + else: + view_error = cast(int, view_error) + view_error += instance_error + + inst_id: Union[Track, str] = ( + inst.track if inst.track is not None else "None" + ) + instance_ids.append(inst_id) + + if per_instance: + frame_error = cast(Dict, frame_error) + frame_error[cam] = view_error + elif per_view: + view_error = cast(int, view_error) + frame_error = cast( + Dict[Camcorder, Tuple[Tuple[Union[Track, str], ...], int]], + frame_error, + ) + frame_error[cam] = (tuple(instance_ids), view_error) + else: + view_error = cast(int, view_error) + frame_error = cast(int, frame_error) + frame_error += view_error + + reprojection_error_per_frame[frame_id] = frame_error + + return reprojection_error_per_frame, instances_and_coords + + @staticmethod + def calculate_error_per_instance( + session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]] + ) -> Tuple[ + Dict[int, Dict[Camcorder, List[Tuple[Instance, float]]]], + Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]], + ]: + """Calculate reprojection error per instance. + + Args: + session: The `RecordingSession` containing the `Camcorder`s. + instances: Dict with frame identifier keys (not the frame index) and values + of another inner dict with `Camcorder` keys and `List[Instance]` values. + + Returns: + reprojection_error_per_instance: Dict with frame identifier keys (not the + frame index) and values of another inner dict with `Camcorder` keys and + `List[Tuple[Instance, float]]` values. + instances_and_coords: Dict with frame identifier keys (not the frame index) + and values of another inner dict with `Camcorder` keys and + `Iterator[Tuple[Instance, np.ndarray]]` values that contain the instance + and the reprojected coordinates. + """ + + ( + reprojection_error_per_instance, + instances_and_coords, + ) = TriangulateSession._calculate_reprojection_error( + session=session, instances=instances, per_instance=True + ) + + return reprojection_error_per_instance, instances_and_coords + + @staticmethod + def calculate_error_per_view( + session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]] + ) -> Tuple[ + Dict[int, Dict[Camcorder, float]], + Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]], + ]: + """Calculate reprojection error per instance. + + Args: + session: The `RecordingSession` containing the `Camcorder`s. + instances: Dict with frame identifier keys (not the frame index) and values + of another inner dict with `Camcorder` keys and `List[Instance]` values. + + Returns: + reprojection_error_per_view: Dict with frame identifier keys (not the frame + index) and values of another inner dict with `Camcorder` keys and + `float` values. + instances_and_coords: Dict with frame identifier keys (not the frame index) + and values of another inner dict with `Camcorder` keys and + `Iterator[Tuple[Instance, np.ndarray]]` values that contain the instance + and the reprojected coordinates. + """ + + ( + reprojection_error_per_view, + instances_and_coords, + ) = TriangulateSession._calculate_reprojection_error( + session=session, instances=instances, per_view=True + ) + + return reprojection_error_per_view, instances_and_coords + + @staticmethod + def calculate_error_per_frame( + session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]] + ) -> Tuple[ + Dict[int, float], + Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]], + ]: + """Calculate reprojection error per frame. + + Args: + session: The `RecordingSession` containing the `Camcorder`s. + instances: Dict with frame identifier keys (not the frame index) and values + of another inner dict with `Camcorder` keys and `List[Instance]` values. + + Returns: + reprojection_error_per_frame: Dict with frame identifier keys (not the frame + index) and values of reprojection error for the frame. + instances_and_coords: Dict with frame identifier keys (not the frame index) + and values of another inner dict with `Camcorder` keys and + `Iterator[Tuple[Instance, np.ndarray]]` values that contain the instance + and the reprojected coordinates. + """ + + ( + reprojection_error_per_frame, + instances_and_coords, + ) = TriangulateSession._calculate_reprojection_error( + session=session, instances=instances, per_instance=False + ) + + return reprojection_error_per_frame, instances_and_coords + + @staticmethod + def get_products_of_instances( + session: RecordingSession, + frame_idx: int, + cams_to_include: Optional[List[Camcorder]] = None, + ) -> Dict[int, Dict[Camcorder, List[Instance]]]: + """Get all (multi-instance) possible products of instances across views. + + Args: + session: The `RecordingSession` containing the `Camcorder`s. + frame_idx: Frame index to get instances from (0-indexed). + cams_to_include: List of `Camcorder`s to include. Default is all. + require_multiple_views: If True, then raise and error if one or less views + or instances are found. + + Returns: + Dict with frame identifier keys (not the frame index) and values of another + inner dict with `Camcorder` keys and `List[Instance]` values. Each + `List[Instance]` is of length "max number of instances in frame set". + """ + + # Get all instances accross views at this frame index, then remove selected + instances: Dict[ + Camcorder, List[Instance] + ] = TriangulateSession.get_instances_across_views( + session=session, + frame_idx=frame_idx, + cams_to_include=cams_to_include, + track=-1, # Get all instances regardless of track. + require_multiple_views=True, + ) + + # Get the skeleton from an example instance + skeleton = next(iter(instances.values()))[0].skeleton + + # Find max number of instances in other views + max_num_instances = max([len(instances) for instances in instances.values()]) + + # Create a dummy instance of all nan values + dummy_instance = Instance.from_numpy( + np.full( + shape=(len(skeleton.nodes), 2), + fill_value=np.nan, + ), + skeleton=skeleton, + ) + + # Get permutations of instances from other views + instances_permutations: Dict[Camcorder, Iterator[Tuple]] = {} + for cam, instances_in_view in instances.items(): + # Append a dummy instance to all lists of instances if less than the max length + num_missing = 1 + num_instances = len(instances_in_view) + if num_instances < max_num_instances: + num_missing = max_num_instances - num_instances + + # Extend the list first + instances_in_view.extend([dummy_instance] * num_missing) + + # Permute instances into all possible orderings w/in a view + instances_permutations[cam] = permutations(instances_in_view) + + # Get products of instances from other views into all possible groupings + # Ordering of dict_values is preserved in Python 3.7+ + products_of_instances: Iterator[Iterator[Tuple]] = product( + *instances_permutations.values() + ) + + # Reorganize products by cam and add selected instance to each permutation + instances_hypotheses: Dict[int, Dict[Camcorder, List[Instance]]] = {} + for frame_id, prod in enumerate(products_of_instances): + instances_hypotheses[frame_id] = { + cam: [*inst] for cam, inst in zip(instances.keys(), prod) + } + + # Expect "max # instances in view" ** "# views" frames (a.k.a. hypotheses) + return instances_hypotheses + @staticmethod def get_instances_matrices( instances: Dict[int, Dict[Camcorder, List[Instance]]], @@ -3800,7 +4153,9 @@ def get_instances_matrices( inst_coords_frames.append( inst_coords_views ) # len=frame_idx, List[M x T x N x 2] + inst_coords = np.stack(inst_coords_frames, axis=1) # M x F x T x N x 2 + cams_ordered = cast(List[Camcorder], cams_ordered) # Could be None if no frames return inst_coords, cams_ordered @@ -3871,7 +4226,7 @@ def calculate_excluded_views_multiple_frames( @staticmethod def _calculate_reprojected_points( session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]] - ) -> Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]]: + ) -> Tuple[np.ndarray, List[Camcorder]]: """Triangulate and reproject instance coordinates. Note that the order of the instances in the list must match the order of the @@ -3919,8 +4274,11 @@ def _calculate_reprojected_points( return inst_coords_reprojected, cams_ordered + @staticmethod def group_instances_and_coords( - instances, inst_coords_reprojected, cams_ordered + instances: Dict[int, Dict[Camcorder, List[Instance]]], + inst_coords_reprojected: np.ndarray, + cams_ordered: List[Camcorder], ) -> Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]]: """Group instances and reprojected coordinates by frame and view. @@ -3968,7 +4326,9 @@ def group_instances_and_coords( {} ) # Dict len(F) of dict len(M) of zipped lists of len(T) instances and array of N x 2 for frame_idx, instances_in_frame in instances.items(): # len(F) of dict - insts_and_coords_in_frame: Dict[Camcorder, Tuple[Instance, np.ndarray]] = {} + insts_and_coords_in_frame: Dict[ + Camcorder, Iterator[Tuple[Instance, np.ndarray]] + ] = {} for cam_idx, cam in enumerate(cams_ordered): instances_in_frame_ordered: List[Instance] = instances_in_frame[ cam @@ -3976,7 +4336,9 @@ def group_instances_and_coords( insts_coords_in_frame: np.ndarray = insts_coords_list[frame_idx][ cam_idx ] # len(T) of N x 2 - insts_and_coords_in_frame[cam]: Tuple[Instance, np.ndarray] = zip( + + # TODO(LM): I think we will need a reconsumable iterator here. + insts_and_coords_in_frame[cam] = zip( instances_in_frame_ordered, insts_coords_in_frame, ) @@ -4011,7 +4373,7 @@ def calculate_reprojected_points( session=session, instances=instances ) - # Group together instances (the reordered by cam) and the reprojected coords. + # Reorder instances (by cam) and the reprojected coords. instances_and_coords: Dict[ int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]] ] = TriangulateSession.group_instances_and_coords( @@ -4024,31 +4386,23 @@ def calculate_reprojected_points( @staticmethod def update_instances( - session, instances: Dict[int, Dict[Camcorder, List[Instance]]] + instances_and_coords: Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]] ): - """Triangulate, reproject, and update coordinates of `Instances`. + """Update coordinates of `Instances`. Args: - session: The `RecordingSession` containing the `Camcorder`s. - instances: Dict with frame identifier keys (does not necessarily need to be - the frame index) and values of another inner dict with `Camcorder` keys - and `List[Instance]` values. + instances_and_coords: Dict with `Camcorder` keys and + `Iterator[Tuple[Instance, np.ndarray]]` values containing the Instance + and it's reprojected coordinates. Returns: None """ - # Triangulate and reproject instance coordinates. - instances_and_coords = TriangulateSession.calculate_reprojected_points( - session=session, instances=instances - ) - - # TODO(LM): Since we only use the values here, is a dictionary overkill? # Update the instance coordinates. - for instances_in_frame in instances_and_coords.values(): - for instances_in_view in instances_in_frame.values(): - for inst, inst_coord in instances_in_view: - inst.update_points(points=inst_coord, exclude_complete=True) + for instances_in_view in instances_and_coords.values(): + for inst, inst_coord in instances_in_view: + inst.update_points(points=inst_coord, exclude_complete=True) def open_website(url: str): diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 2c651756f..c6a9a279a 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -1063,32 +1063,29 @@ def test_triangulate_session_get_and_verify_enough_instances( labels = multiview_min_session_labels session = labels.sessions[0] lf = labels.labeled_frames[0] - track = labels.tracks[1] # Test with no cams_to_include, expect views from all linked cameras instances = TriangulateSession.get_and_verify_enough_instances( - session=session, frame_inds=[lf.frame_idx], track=track + session=session, frame_idx=lf.frame_idx ) - instances_in_frame = instances[lf.frame_idx] + instances_in_frame = instances[0] assert ( - len(instances_in_frame) == 6 - ) # Some views don't have an instance at this track + len(instances_in_frame) == 8 + ) # All views should have same number of instances (padded with dummy instance(s)) for cam in session.linked_cameras: if cam.name in ["side", "sideL"]: # The views that don't have an instance continue instances_in_view = instances_in_frame[cam] for inst in instances_in_view: assert inst.frame_idx == lf.frame_idx - assert inst.track == track assert inst.video == session[cam] # Test with cams_to_include, expect views from only those cameras cams_to_include = session.linked_cameras[-2:] instances = TriangulateSession.get_and_verify_enough_instances( session=session, - frame_inds=[lf.frame_idx], + frame_idx=lf.frame_idx, cams_to_include=cams_to_include, - track=track, ) instances_in_frame = instances[lf.frame_idx] assert len(instances_in_frame) == len(cams_to_include) @@ -1096,21 +1093,24 @@ def test_triangulate_session_get_and_verify_enough_instances( instances_in_view = instances_in_frame[cam] for inst in instances_in_view: assert inst.frame_idx == lf.frame_idx - assert inst.track == track assert inst.video == session[cam] # Test with not enough instances, expect views from only those cameras cams_to_include = session.linked_cameras[0:2] + cam = cams_to_include[0] + video = session[cam] + lfs = labels.find(video, lf.frame_idx) + lf = lfs[0] + lf.instances = [] instances = TriangulateSession.get_and_verify_enough_instances( session=session, - frame_inds=[lf.frame_idx], + frame_idx=lf.frame_idx, cams_to_include=cams_to_include, - track=None, ) assert isinstance(instances, bool) assert not instances messages = "".join([rec.message for rec in caplog.records]) - assert "One or less instances found for frame" in messages + assert "No Instances found for" in messages def test_triangulate_session_verify_enough_views( @@ -1259,7 +1259,9 @@ def test_triangulate_session_update_instances(multiview_min_session_labels: Labe # Just run for code coverage testing, do not test output here (race condition) # (see "functional core, imperative shell" pattern) - TriangulateSession.update_instances(session=session, instances=instances) + TriangulateSession.update_instances( + instances_and_coords=instances_and_coordinates[0] + ) def test_triangulate_session_do_action(multiview_min_session_labels: Labels): @@ -1319,3 +1321,101 @@ def test_triangulate_session(multiview_min_session_labels: Labels): context.state["instance"] = instance context.state["frame_idx"] = lf.frame_idx context.triangulateSession() + + +def test_triangulate_session_get_products_of_instances( + multiview_min_session_labels: Labels, +): + """Test `TriangulateSession.get_products_of_instances`.""" + + labels = multiview_min_session_labels + session = labels.sessions[0] + lf = labels.labeled_frames[0] + selected_instance = lf.instances[0] + + instances = TriangulateSession.get_products_of_instances( + session=session, + frame_idx=lf.frame_idx, + ) + + views = TriangulateSession.get_all_views_at_frame(session, lf.frame_idx) + max_num_instances_in_view = max([len(instances) for instances in views.values()]) + assert len(instances) == max_num_instances_in_view ** len(views) + + for frame_id in instances: + instances_in_frame = instances[frame_id] + for cam in instances_in_frame: + instances_in_view = instances_in_frame[cam] + assert len(instances_in_view) == max_num_instances_in_view + for inst in instances_in_view: + try: + assert inst.frame_idx == selected_instance.frame_idx + assert inst.video == session[cam] + except: + assert inst.frame is None + assert inst.video is None + + +def test_triangulate_session_calculate_error_per_frame( + multiview_min_session_labels: Labels, +): + """Test `TriangulateSession.calculate_error_per_frame`.""" + + labels = multiview_min_session_labels + session = labels.sessions[0] + lf = labels.labeled_frames[0] + + instances = TriangulateSession.get_products_of_instances( + session=session, + frame_idx=lf.frame_idx, + ) + + ( + reprojection_error_per_frame, + instances_and_coords, + ) = TriangulateSession.calculate_error_per_frame( + session=session, instances=instances + ) + + for frame_id in instances.keys(): + assert frame_id in reprojection_error_per_frame + assert isinstance(reprojection_error_per_frame[frame_id], float) + + +def test_triangulate_session_get_instance_grouping( + multiview_min_session_labels: Labels, +): + """Test `TriangulateSession._get_instance_grouping`.""" + + labels = multiview_min_session_labels + session = labels.sessions[0] + lf = labels.labeled_frames[0] + selected_instance = lf.instances[0] + + instances = TriangulateSession.get_products_of_instances( + session=session, + frame_idx=lf.frame_idx, + ) + + ( + reprojection_error_per_frame, + instances_and_coords, + ) = TriangulateSession.calculate_error_per_frame( + session=session, instances=instances + ) + + best_instances, frame_id_min_error = TriangulateSession._get_instance_grouping( + instances=instances, reprojection_error_per_frame=reprojection_error_per_frame + ) + assert len(best_instances) == len(session.camera_cluster) + for instances_in_view in best_instances.values(): + tracks_in_view = set( + [inst.track if inst is not None else "None" for inst in instances_in_view] + ) + assert len(tracks_in_view) == len(instances_in_view) + for inst in instances_in_view: + try: + assert inst.frame_idx == selected_instance.frame_idx + except: + assert inst.frame is None + assert inst.track is None From c3a81736a0fb21266095e8a0997f498e09906525 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Fri, 12 Apr 2024 08:41:53 -0700 Subject: [PATCH 08/22] (3e -> 3a) Add `InstanceGroup` class (#1618) * Add method to get single instance permutations * Add method and (failing) test to get instance grouping * Append a dummy instance for missing instances * Update tests to accept a dummy instance * Add initial InstanceGroup class * Few extra tests for `InstanceGroup` * Remember instance grouping after testing hypotheses * Use reconsumable iterator for reprojected coords * Only triangulate user instances, add fixture, update tests * Normalize instance reprojection errors * Add `locked`, `_dummy_instance`, `numpy`, and `update_points` * Allow `PredictedPoint`s to be updated as well * Add tests for new attributes and methods * Add methods to create, add, replace, and remove instances * Use PredictedInstance for new/dummy instances * (3f -> 3e) Add `FrameGroup` class (#1665) * (3g -> 3f) Use frame group for triangulation (#1693) --- sleap/gui/commands.py | 1049 +------------ sleap/instance.py | 25 +- sleap/io/cameras.py | 1319 ++++++++++++++++- .../min_session_user_labeled.slp | Bin 0 -> 60444 bytes tests/fixtures/datasets.py | 8 + tests/gui/test_commands.py | 465 ------ tests/io/test_cameras.py | 173 ++- 7 files changed, 1594 insertions(+), 1445 deletions(-) create mode 100644 tests/data/cameras/minimal_session/min_session_user_labeled.slp diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 7b6b44d64..fbfcb4b81 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -54,7 +54,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from sleap.gui.state import GuiState from sleap.gui.suggestions import VideoFrameSuggestions from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance, Track -from sleap.io.cameras import Camcorder, RecordingSession +from sleap.io.cameras import Camcorder, InstanceGroup, FrameGroup, RecordingSession from sleap.io.convert import default_analysis_filename from sleap.io.dataset import Labels from sleap.io.format.adaptor import Adaptor @@ -3406,1003 +3406,114 @@ def do_action(cls, context: CommandContext, params: dict): ask_again: If True, then ask for views/instances again. Default is False. """ - # Check if we already ran ask - ask_again = params.get("ask_again", False) - - # Add "instances" to params dict without GUI, otherwise taken care of in ask - if ask_again: - params["show_dialog"] = False - enough_instances = cls.verify_views_and_instances( - context=context, params=params - ) - if not enough_instances: - return - - # Get params + # Get `FrameGroup` for the current frame index video = params.get("video", None) or context.state["video"] session = params.get("session", None) or context.labels.get_session(video) - instances = params["instances"] - session = cast(RecordingSession, session) # Could be None if no labels or video - - # Get best instance grouping and reprojected coords - instances_and_reprojected_coords = ( - TriangulateSession.get_instance_grouping_and_reprojected_coords( - session=session, instance_hypotheses=instances - ) - ) - - # Update instances - TriangulateSession.update_instances( - instances_and_coords=instances_and_reprojected_coords + frame_idx: int = params["frame_idx"] + frame_group: FrameGroup = ( + params.get("frame_group", None) or session.frame_groups[frame_idx] ) - @classmethod - def ask(cls, context: CommandContext, params: dict) -> bool: - """Add "instances" to params dict if enough views/instances, warning user otherwise. - - Args: - context: The command context. - params: The command parameters. - video: The `Video` object to use. Default is current video. - session: The `RecordingSession` object to use. Default is current - video's session. - frame_idx: The frame index to use. Default is current frame index. - instance: The `Instance` object to use. Default is current instance. - show_dialog: If True, then show a warning dialog. Default is True. - - Returns: - True if enough views/instances for triangulation, False otherwise. - """ - - return cls.verify_views_and_instances(context=context, params=params) - - @classmethod - def verify_views_and_instances(cls, context: CommandContext, params: dict) -> bool: - """Verify that there are enough views and instances to triangulate. - - Also adds "instances" to params dict if there are enough views and instances. - - Args: - context: The command context. - params: The command parameters. - video: The `Video` object used to lookup a `session` (if not provided). - Default is current video. - session: The `RecordingSession` object to use. Default is current - video's session. - frame_idx: The frame index to use. Default is current frame index. - instance: The `Instance` object to use. Default is current instance. - show_dialog: If True, then show a warning dialog. Default is True. - - Returns: - True if enough views/instances for triangulation, False otherwise. - """ - - video = params.get("video", None) or context.state["video"] - session = params.get("session", None) or context.labels.get_session(video) + # Get the `InstanceGroup` from `Instance` if any instance = params.get("instance", None) or context.state["instance"] - show_dialog = params.get("show_dialog", True) - - # This value could possibly be 0, so we can't use "or" - frame_idx = params.get("frame_idx", None) - frame_idx = frame_idx if frame_idx is not None else context.state["frame_idx"] - - # Return if we don't have a session for video or an instance selected. - if session is None or instance is None: - return - - cams_to_include = params.get("cams_to_include", None) or session.linked_cameras + instance_group = frame_group.get_instance_group(instance) - # If not enough `Camcorder`s available/specified, then return - if not TriangulateSession.verify_enough_views( - context=context, - session=session, - cams_to_include=cams_to_include, - show_dialog=show_dialog, - ): - return False - - # Get all instances products accross views at this frame index - instances = TriangulateSession.get_and_verify_enough_instances( - context=context, - session=session, - frame_idx=frame_idx, - cams_to_include=cams_to_include, - show_dialog=show_dialog, + # If instance_group is None, then we will try to triangulate entire frame_group + instance_groups = ( + [instance_group] + if instance_group is not None + else frame_group.instance_groups ) - # Return if not enough instances - if not instances: - return False - - # Add instances to params dict - params["instances"] = instances - - return True - - @staticmethod - def get_and_verify_enough_instances( - session: RecordingSession, - frame_idx: int, - context: Optional[CommandContext] = None, - cams_to_include: Optional[List[Camcorder]] = None, - show_dialog: bool = True, - ) -> Union[Dict[int, Dict[Camcorder, List[Instance]]], bool]: - """Get all instances accross views at this frame index (and products of instances). - - If not enough `Instance`s are available at this frame index, then return False. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - frame_idx: Frame index to get instances from (0-indexed). - context: The optional command context used to display a dialog. - cams_to_include: List of `Camcorder`s to include. Default is all. - track: `Track` object used to find instances accross views. Default is -1 - which finds all instances regardless of track. - show_dialog: If True, then show a warning dialog. Default is True. - - Returns: - Dict with frame identifier keys (not the frame index) and values of another - inner dict with `Camcorder` keys and `List[Instance]` values if enough - instances are found, False otherwise. - """ - - try: - instances: Dict[ - int, Dict[Camcorder, List[Instance]] - ] = TriangulateSession.get_products_of_instances( - session=session, - frame_idx=frame_idx, - cams_to_include=cams_to_include, - ) - return instances - except Exception as e: - # If not enough views, instances or some other error, then return - message = str(e) - message += "\n\tSkipping triangulation and reprojection." - logger.warning(message) - return False - - @staticmethod - def verify_enough_views( - session: RecordingSession, - context: Optional[CommandContext] = None, - cams_to_include: Optional[List[Camcorder]] = None, - show_dialog=True, - ): - """If not enough `Camcorder`s available/specified, then return False. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - context: The optional command context, used to display a dialog. - cams_to_include: List of `Camcorder`s to include. Default is all. - show_dialog: If True, then show a warning dialog. Default is True. - - Returns: - True if enough views are available, False otherwise. - """ - - if (cams_to_include is not None and len(cams_to_include) <= 1) or ( - len(session.videos) <= 1 - ): - message = ( - "One or less cameras available. " - "Multiple cameras needed to triangulate. " - "Skipping triangulation and reprojection." - ) - if show_dialog and context is not None: - QtWidgets.QMessageBox.warning(context.app, "Triangulation", message) - else: - logger.warning(message) - - return False - - return True - - @staticmethod - def get_instances_across_views( - session: RecordingSession, - frame_idx: int, - cams_to_include: Optional[List[Camcorder]] = None, - track: Union[Track, int] = -1, - require_multiple_views: bool = False, - ) -> Dict[Camcorder, List[Instance]]: - """Get all `Instances` accross all views at a given frame index. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - frame_idx: Frame index to get instances from (0-indexed). - cams_to_include: List of `Camcorder`s to include. Default is all. - track: `Track` object used to find instances accross views. Default is -1 - which find all instances regardless of track. - require_multiple_views: If True, then raise and error if one or less views - or instances are found. - - Returns: - Dict with `Camcorder` keys and `List[Instance]` values. - - Raises: - ValueError if require_multiple_view is true and one or less views or - instances are found. - """ - - def _message(views: bool): - views_or_instances = "views" if views else "instances" - return ( - f"One or less {views_or_instances} found for frame " - f"{frame_idx} in {session.camera_cluster}. " - "Multiple instances accross multiple views needed to triangulate." - ) - - # Get all views at this frame index - views: Dict[ - Camcorder, "LabeledFrame" - ] = TriangulateSession.get_all_views_at_frame( - session=session, + # Retain instance groups that have enough views/instances for triangulation + instance_groups = TriangulateSession.has_enough_instances( + frame_group=frame_group, + instance_groups=instance_groups, frame_idx=frame_idx, - cams_to_include=cams_to_include, - ) - - # TODO(LM): Should we just skip this frame if not enough views? - # If not enough views, then raise error - if len(views) <= 1 and require_multiple_views: - raise ValueError(_message(views=True)) - - # Find all instance accross all views - instances_in_frame: Dict[Camcorder, List[Instance]] = {} - for cam, lf in views.items(): - insts = lf.find(track=track) - if len(insts) > 0: - instances_in_frame[cam] = insts - - # If not enough instances for multiple views, then raise error - if len(instances_in_frame) <= 1 and require_multiple_views: - raise ValueError(_message(views=False)) - - return instances_in_frame - - @staticmethod - def get_instances_across_views_multiple_frames( - session: RecordingSession, - frame_inds: List[int], - cams_to_include: Optional[List[Camcorder]] = None, - track: Union[Track, int] = -1, - require_multiple_views: bool = False, - ) -> Dict[int, Dict[Camcorder, List[Instance]]]: - """Get all `Instances` accross all views at all given frame indices. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - frame_inds: List of frame indices to get instances from (0-indexed). - cams_to_include: List of `Camcorder`s to include. Default is all. - track: `Track` object used to find instances accross views. Default is -1 - which find all instances regardless of track. - require_multiple_views: If True, then raise and error if one or less views - or instances are found. - - Returns: - Dict with frame identifier keys (does not necessarily need to be the frame - index) and values of another inner dict with `Camcorder` keys and - `List[Instance]` values. - """ - - instances: Dict[int, Dict[Camcorder, List[Instance]]] = {} - for frame_idx in frame_inds: - try: - # Find all instance accross all views - instances_in_frame = TriangulateSession.get_instances_across_views( - session=session, - frame_idx=frame_idx, - cams_to_include=cams_to_include, - track=track, - require_multiple_views=require_multiple_views, - ) - instances[frame_idx] = instances_in_frame - except ValueError: - message = traceback.format_exc() - message += " Skipping frame." - logger.warning(f"{message}") - - if len(instances) == 0: - frame_inds_str = ", ".join([str(frame_idx) for frame_idx in frame_inds]) - raise ValueError( - "Not enough instances or views found for any frame identifiers in " - f"{frame_inds_str}." - ) - - return instances - - @staticmethod - def get_all_views_at_frame( - session: RecordingSession, - frame_idx, - cams_to_include: Optional[List[Camcorder]] = None, - ) -> Dict[Camcorder, "LabeledFrame"]: - """Get all views at a given frame index. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - frame_idx: Frame index to get views from (0-indexed). - cams_to_include: List of `Camcorder`s to include. Default is all. - - Returns: - Dict with `Camcorder` keys and `LabeledFrame` values. - """ - - views: Dict[Camcorder, "LabeledFrame"] = {} - videos: Dict[Camcorder, Video] = session.get_videos_from_selected_cameras( - cams_to_include=cams_to_include - ) - for cam, video in videos.items(): - lfs: List["LabeledFrame"] = session.labels.get((video, [frame_idx])) - if len(lfs) == 0: - logger.debug( - f"No LabeledFrames found for video {video} at {frame_idx}." - ) - continue - - lf = lfs[0] - if len(lf.instances) == 0: - logger.warning( - f"No Instances found for {lf}." - " There should not be empty LabeledFrames." - ) - continue - - views[cam] = lf - - return views - - @staticmethod - def get_instance_grouping_and_reprojected_coords( - session: RecordingSession, - instance_hypotheses: Dict[int, Dict[Camcorder, List[Instance]]], - ): - """Get instance grouping and reprojected coords with lowest reprojection error. - - Triangulation of all possible groupings needs to be performed... Thus, we return - the best grouping's triangulation in this step to then be used when updating the - instance. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - instance_hypotheses: Dict with frame identifier keys (not the frame index) - and values of another inner dict with `Camcorder` keys and - `List[Instance]` values. - - - Returns: - best_instances_and_reprojected_coords: Dict with `Camcorder` keys with - `Tuple[Instance, np.ndarray]` values. - """ - - # Calculate reprojection error for each instance grouping - ( - reprojection_error_per_frame, - instances_and_coords, - ) = TriangulateSession.calculate_error_per_frame( - session=session, - instances=instance_hypotheses, - ) - - # Just for type hinting - reprojection_error_per_frame = cast( - Dict[int, float], reprojection_error_per_frame - ) - instances_and_coords = cast( - Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]], - instances_and_coords, - ) - - # Get instance grouping with lowest reprojection error - best_instances, frame_id_min_error = TriangulateSession._get_instance_grouping( - instances=instance_hypotheses, - reprojection_error_per_frame=reprojection_error_per_frame, - ) - - # Just for type hinting - best_instances = cast(Dict[Camcorder, List[Instance]], best_instances) - instances_and_coords = cast( - Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]], - instances_and_coords, - ) - - # Get the best reprojection - best_instances_and_reprojected_coords: Dict[ - Camcorder, Iterator[Tuple[Instance, np.ndarray]] - ] = instances_and_coords[frame_id_min_error] - - return best_instances_and_reprojected_coords - - @staticmethod - def _get_instance_grouping( - instances: Dict[int, Dict[Camcorder, List[Instance]]], - reprojection_error_per_frame: Dict[int, float], - ) -> Tuple[Dict[Camcorder, List[Instance]], int]: - """Get instance grouping with lowest reprojection error. - - Args: - instances: Dict with frame identifier keys (not the frame index) and values - of another inner dict with `Camcorder` keys and `List[Instance]` values. - reprojection_error_per_frame: Dict with frame identifier keys (not the frame - index) and values of reprojection error for the frame. - - Returns: - best_instances: Dict with `Camcorder` keys and `List[Instance]` values. - frame_id_min_error: The frame identifier with the lowest reprojection - """ - - frame_id_min_error: int = min( - reprojection_error_per_frame, key=reprojection_error_per_frame.get - ) - - best_instances: Dict[Camcorder, List[Instance]] = instances[frame_id_min_error] - - return best_instances, frame_id_min_error - - @staticmethod - def _calculate_reprojection_error( - session: RecordingSession, - instances: Dict[int, Dict[Camcorder, List[Instance]]], - per_instance: bool = False, - per_view: bool = False, - ) -> Tuple[ - Dict[int, Union[float, Dict[Camcorder, List[Tuple[Instance, float]]]]], - Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]], - ]: - """Calculate reprojection error per frame or per instance. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - instances: Dict with frame identifier keys (not the frame index) and values - of another inner dict with `Camcorder` keys and `List[Instance]` values. - per_instance: If True, then return a dict with frame identifier keys and - values of another inner dict with `Camcorder` keys and - `List[Tuple[Instance, float]]` values. - per_view: If True, then return a dict with frame identifier keys and values - of another inner dict with `Camcorder` keys and - `Tuple[Tuple[str, str], float]` values. If per_instance is True, then that takes precendence. - - Returns: - reprojection_per_frame: Dict with frame identifier keys (not the frame index) and values of another - inner dict with `Camcorder` keys and `List[Tuple[Instance, float]]` values - if per_instance is True, otherwise a dict with frame identifier keys and - values of reprojection error for the frame. - instances_and_coords: Dict with frame identifier keys (not the frame index) - and values of another inner dict with `Camcorder` keys and - `Iterator[Tuple[Instance, np.ndarray]]` values that contain the instance - and the reprojected coordinates. - - """ - - reprojection_error_per_frame = {} - - # Triangulate and reproject instance coordinates. - instances_and_coords: Dict[ - int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]] - ] = TriangulateSession.calculate_reprojected_points( - session=session, instances=instances - ) - for frame_id, instances_in_frame in instances_and_coords.items(): - frame_error: Union[Dict, float] = {} if per_instance or per_view else 0 - for cam, instances_in_view in instances_in_frame.items(): - # Compare instance coordinates here - instance_ids: List[Union[Track, str]] = [] - view_error: Union[List, int] = [] if per_instance else 0 - for inst, inst_coords in instances_in_view: - node_errors = np.nan_to_num(inst.numpy() - inst_coords) - instance_error = np.linalg.norm(node_errors) - - if per_instance: - view_error = cast(List, view_error) - view_error.append((inst, instance_error)) - else: - view_error = cast(int, view_error) - view_error += instance_error - - inst_id: Union[Track, str] = ( - inst.track if inst.track is not None else "None" - ) - instance_ids.append(inst_id) - - if per_instance: - frame_error = cast(Dict, frame_error) - frame_error[cam] = view_error - elif per_view: - view_error = cast(int, view_error) - frame_error = cast( - Dict[Camcorder, Tuple[Tuple[Union[Track, str], ...], int]], - frame_error, - ) - frame_error[cam] = (tuple(instance_ids), view_error) - else: - view_error = cast(int, view_error) - frame_error = cast(int, frame_error) - frame_error += view_error - - reprojection_error_per_frame[frame_id] = frame_error - - return reprojection_error_per_frame, instances_and_coords - - @staticmethod - def calculate_error_per_instance( - session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]] - ) -> Tuple[ - Dict[int, Dict[Camcorder, List[Tuple[Instance, float]]]], - Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]], - ]: - """Calculate reprojection error per instance. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - instances: Dict with frame identifier keys (not the frame index) and values - of another inner dict with `Camcorder` keys and `List[Instance]` values. - - Returns: - reprojection_error_per_instance: Dict with frame identifier keys (not the - frame index) and values of another inner dict with `Camcorder` keys and - `List[Tuple[Instance, float]]` values. - instances_and_coords: Dict with frame identifier keys (not the frame index) - and values of another inner dict with `Camcorder` keys and - `Iterator[Tuple[Instance, np.ndarray]]` values that contain the instance - and the reprojected coordinates. - """ - - ( - reprojection_error_per_instance, - instances_and_coords, - ) = TriangulateSession._calculate_reprojection_error( - session=session, instances=instances, per_instance=True - ) - - return reprojection_error_per_instance, instances_and_coords - - @staticmethod - def calculate_error_per_view( - session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]] - ) -> Tuple[ - Dict[int, Dict[Camcorder, float]], - Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]], - ]: - """Calculate reprojection error per instance. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - instances: Dict with frame identifier keys (not the frame index) and values - of another inner dict with `Camcorder` keys and `List[Instance]` values. - - Returns: - reprojection_error_per_view: Dict with frame identifier keys (not the frame - index) and values of another inner dict with `Camcorder` keys and - `float` values. - instances_and_coords: Dict with frame identifier keys (not the frame index) - and values of another inner dict with `Camcorder` keys and - `Iterator[Tuple[Instance, np.ndarray]]` values that contain the instance - and the reprojected coordinates. - """ - - ( - reprojection_error_per_view, - instances_and_coords, - ) = TriangulateSession._calculate_reprojection_error( - session=session, instances=instances, per_view=True + instance=instance, ) + if instance_groups is None or len(instance_groups) == 0: + return # Not enough instances for triangulation - return reprojection_error_per_view, instances_and_coords + # Get the `FrameGroup` of shape M=include x T x N x 2 + fg_tensor = frame_group.numpy(instance_groups=instance_groups) - @staticmethod - def calculate_error_per_frame( - session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]] - ) -> Tuple[ - Dict[int, float], - Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]], - ]: - """Calculate reprojection error per frame. + # Add extra dimension for number of frames + frame_group_tensor = np.expand_dims(fg_tensor, axis=1) # M=include x F=1 xTxNx2 - Args: - session: The `RecordingSession` containing the `Camcorder`s. - instances: Dict with frame identifier keys (not the frame index) and values - of another inner dict with `Camcorder` keys and `List[Instance]` values. - - Returns: - reprojection_error_per_frame: Dict with frame identifier keys (not the frame - index) and values of reprojection error for the frame. - instances_and_coords: Dict with frame identifier keys (not the frame index) - and values of another inner dict with `Camcorder` keys and - `Iterator[Tuple[Instance, np.ndarray]]` values that contain the instance - and the reprojected coordinates. - """ - - ( - reprojection_error_per_frame, - instances_and_coords, - ) = TriangulateSession._calculate_reprojection_error( - session=session, instances=instances, per_instance=False - ) - - return reprojection_error_per_frame, instances_and_coords - - @staticmethod - def get_products_of_instances( - session: RecordingSession, - frame_idx: int, - cams_to_include: Optional[List[Camcorder]] = None, - ) -> Dict[int, Dict[Camcorder, List[Instance]]]: - """Get all (multi-instance) possible products of instances across views. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - frame_idx: Frame index to get instances from (0-indexed). - cams_to_include: List of `Camcorder`s to include. Default is all. - require_multiple_views: If True, then raise and error if one or less views - or instances are found. - - Returns: - Dict with frame identifier keys (not the frame index) and values of another - inner dict with `Camcorder` keys and `List[Instance]` values. Each - `List[Instance]` is of length "max number of instances in frame set". - """ - - # Get all instances accross views at this frame index, then remove selected - instances: Dict[ - Camcorder, List[Instance] - ] = TriangulateSession.get_instances_across_views( - session=session, - frame_idx=frame_idx, - cams_to_include=cams_to_include, - track=-1, # Get all instances regardless of track. - require_multiple_views=True, - ) - - # Get the skeleton from an example instance - skeleton = next(iter(instances.values()))[0].skeleton + # Triangulate to one 3D pose per instance + points_3d = triangulate( + p2d=frame_group_tensor, + calib=session.camera_cluster, + excluded_views=frame_group.excluded_views, + ) # F x T x N x 3 - # Find max number of instances in other views - max_num_instances = max([len(instances) for instances in instances.values()]) + # Reproject onto all views + pts_reprojected = reproject( + points_3d, + calib=session.camera_cluster, + excluded_views=frame_group.excluded_views, + ) # M=include x F=1 x T x N x 2 - # Create a dummy instance of all nan values - dummy_instance = Instance.from_numpy( - np.full( - shape=(len(skeleton.nodes), 2), - fill_value=np.nan, - ), - skeleton=skeleton, - ) + # Sqeeze back to the original shape + points_reprojected = np.squeeze(pts_reprojected, axis=1) # M=include x TxNx2 - # Get permutations of instances from other views - instances_permutations: Dict[Camcorder, Iterator[Tuple]] = {} - for cam, instances_in_view in instances.items(): - # Append a dummy instance to all lists of instances if less than the max length - num_missing = 1 - num_instances = len(instances_in_view) - if num_instances < max_num_instances: - num_missing = max_num_instances - num_instances - - # Extend the list first - instances_in_view.extend([dummy_instance] * num_missing) - - # Permute instances into all possible orderings w/in a view - instances_permutations[cam] = permutations(instances_in_view) - - # Get products of instances from other views into all possible groupings - # Ordering of dict_values is preserved in Python 3.7+ - products_of_instances: Iterator[Iterator[Tuple]] = product( - *instances_permutations.values() + # Update or create/insert ("upsert") instance points + frame_group.upsert_points( + points=points_reprojected, + instance_groups=instance_groups, + exclude_complete=True, ) - # Reorganize products by cam and add selected instance to each permutation - instances_hypotheses: Dict[int, Dict[Camcorder, List[Instance]]] = {} - for frame_id, prod in enumerate(products_of_instances): - instances_hypotheses[frame_id] = { - cam: [*inst] for cam, inst in zip(instances.keys(), prod) - } - - # Expect "max # instances in view" ** "# views" frames (a.k.a. hypotheses) - return instances_hypotheses - - @staticmethod - def get_instances_matrices( - instances: Dict[int, Dict[Camcorder, List[Instance]]], - session: Optional[RecordingSession] = None, - ) -> Tuple[np.ndarray, List[Camcorder]]: - """Gather instances from views into M x F x T x N x 2 an array. - - M: # views, F: # frames = 1, T: # tracks, N: # nodes, 2: x, y - - Note that frames indices are not directly used, but rather meant to act as a - marker for independent options (see `TriangulateSession.get_instance_groups`). - - Args: - instances: Dict with frame indices as keys and another Dict with `Camcorder` - keys and `List[Instance]` values. - session: The `RecordingSession` containing the `Camcorder`s. Used to order - the instances in the matrix as expected for triangulation. - - Returns: - M x F x T x N x 2 array of instances coordinates and the ordered list of - `Camcorder`s by which the instances are ordered. - """ - - cams_ordered = None - - # Get M X T X N X 2 array of instances coordinates for each frame - inst_coords_frames = [] - for instances_in_frame in instances.values(): - # Get correct camera ordering - if cams_ordered is None: - if session is None: - logger.warning( - "No session provided. Cannot organize instance coordinates to " - "be compatible for triangulation." - ) - cams_ordered = [cam for cam in instances_in_frame] - else: - cams_ordered = [ - cam for cam in session.cameras if cam in instances_in_frame - ] - - # Get list of instances matrices from each view - inst_coords_in_views = [ - np.stack( - [inst.numpy() for inst in instances_in_frame[cam]], - axis=0, - ) - for cam in cams_ordered - ] # len(M), List[T x N x 2] - - inst_coords_views = np.stack(inst_coords_in_views, axis=0) # M x T x N x 2 - inst_coords_frames.append( - inst_coords_views - ) # len=frame_idx, List[M x T x N x 2] - - inst_coords = np.stack(inst_coords_frames, axis=1) # M x F x T x N x 2 - cams_ordered = cast(List[Camcorder], cams_ordered) # Could be None if no frames - - return inst_coords, cams_ordered - - @staticmethod - def calculate_excluded_views( - session: RecordingSession, - cameras_being_used: Union[Dict[Camcorder, List[Instance]], List[Camcorder]], - ) -> Tuple[str]: - """Get excluded views from dictionary of `Camcorder` to `Instance`. + @classmethod + def has_enough_instances( + cls, + frame_group: FrameGroup, + instance_groups: Optional[List[InstanceGroup]], + frame_idx: Optional[int] = None, + instance: Optional[Instance] = None, + ) -> Optional[List[InstanceGroup]]: + """Filters out instance groups without enough instances for triangulation. Args: - session: The `RecordingSession` containing the `Camcorder`s. - cameras_being_used: List of `Camcorder`s. + frame_group: The `FrameGroup` object to use. + instance_groups: A list of `InstanceGroup` objects to use. + frame_idx: The frame index to use. + instance: The `Instance` object to use (only used in logging). Returns: - Tuple of excluded view names. + A list of `InstanceGroup` objects with enough instances for triangulation. """ - # Calculate excluded views from included cameras - cams_excluded = set(session.cameras) - set(cameras_being_used) - excluded_views = tuple(cam.name for cam in cams_excluded) - excluded_views = cast(Tuple[str], excluded_views) # cam.name could be Any + if instance is None: + instance = "" # Just used for logging - return excluded_views + if frame_idx is None: + frame_idx = "" # Just used for logging - @staticmethod - def calculate_excluded_views_multiple_frames( - session: RecordingSession, - instances: Dict[int, Dict[Camcorder, List[Instance]]], - ) -> Tuple[str]: - """Get excluded views from dictionary of `Camcorder` to `Instance`. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - instances: Dict with frame identifier keys (does not necessarily need to be - the frame index) and values of another inner dict with `Camcorder` keys - and `List[Instance]` values. - - Returns: - Tuple of excluded view names. - - Raises: - ValueError if excluded views are not the same across frames. - """ - - # Calculate excluded views from included cameras - excluded_views = None - for frame_idx, instances_in_frame in instances.items(): - excluded_views_prev = excluded_views - excluded_views = TriangulateSession.calculate_excluded_views( - session=session, cameras_being_used=instances_in_frame + if len(instance_groups) < 1: + logger.warning( + f"Require at least 1 instance group, but found " + f"{len(frame_group.instance_groups)} for frame group {frame_group} at " + f"frame {frame_idx}." + f"\nSkipping triangulation." ) - if excluded_views_prev is None: - prev_frame_idx = frame_idx - continue - elif excluded_views != excluded_views_prev: - raise ValueError( - "Excluded views are not the same across frames. " - f"\n\tExcluded views in frame identifier {prev_frame_idx}: {excluded_views_prev}. " - f"\n\tExcluded views in frame identifier {frame_idx}: {excluded_views}." - ) - prev_frame_idx = frame_idx - - excluded_views = cast(Tuple[str], excluded_views) # Could be None if no frames - - return excluded_views - - @staticmethod - def _calculate_reprojected_points( - session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]] - ) -> Tuple[np.ndarray, List[Camcorder]]: - """Triangulate and reproject instance coordinates. - - Note that the order of the instances in the list must match the order of the - cameras in the `CameraCluster`, that is why we require instances be passed in as - a dictionary mapping back to its `Camcorder`. - https://github.com/lambdaloop/aniposelib/blob/d03b485c4e178d7cff076e9fe1ac36837db49158/aniposelib/cameras.py#L491 - - Also, this function does not handle grouping instances with their respective - coordinates by reordering by camera. - See `TriangulateSession.calculate_reprojected_points`. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - instances: Dict with frame identifier keys (does not necessarily need to be - the frame index) and values of another inner dict with `Camcorder` keys - and `List[Instance]` values. - Returns: - A dictionary with frame identifier keys (does not necessarily need to be the - frame index) and values of another inner dict with `Camcorder` keys and - a zip of the `List[Instance]` and reprojected instance coordinates of shape - (T, N, 2) ordered by the `Camcorder` order in the `CameraCluster`. - """ - - # Derive the excluded views from the included cameras and ensures all frames - # have the same excluded views. - excluded_views = TriangulateSession.calculate_excluded_views_multiple_frames( - instances=instances, session=session - ) - - # Gather instances into M x F x T x N x 2 arrays (require specific order) - # (M = # views, F = # frames, T = # tracks, N = # nodes, 2 = x, y) - inst_coords, cams_ordered = TriangulateSession.get_instances_matrices( - instances=instances, session=session - ) # M x F x T x N x 2 - points_3d = triangulate( - p2d=inst_coords, - calib=session.camera_cluster, - excluded_views=excluded_views, - ) # F, T, N, 3 - - # Get the reprojected 2D points from the 3D points - inst_coords_reprojected = reproject( - points_3d, calib=session.camera_cluster, excluded_views=excluded_views - ) # M x F x T x N x 2 - - return inst_coords_reprojected, cams_ordered - - @staticmethod - def group_instances_and_coords( - instances: Dict[int, Dict[Camcorder, List[Instance]]], - inst_coords_reprojected: np.ndarray, - cams_ordered: List[Camcorder], - ) -> Dict[int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]]: - """Group instances and reprojected coordinates by frame and view. - - Args: - instances: Dict with frame identifier keys (does not necessarily need to be - the frame index) and values of another inner dict with `Camcorder` keys - and `List[Instance]` values. - inst_coords_reprojected: M x F x T x N x 2 array of reprojected instance - coordinates. - cams_ordered: List of `Camcorder`s ordered by the `CameraCluster` - representing both the order and subset of cameras used to calculate - `inst_coords_reprojected`. - - Returns: - A dictionary with frame identifier keys (does not necessarily need to be the - frame index) and values of another inner dict with `Camcorder` keys and - a zip of the `List[Instance]` and reprojected instance coordinates list with - items of shape (N, 2) ordered by the `Camcorder` order in the `CameraCluster`. - """ - - # Split the reprojected coordinates into a list corresponding to instances list. - insts_coords_list: List[List[np.ndarray]] = [ - [ # Annoyingly, np.split leaves a singleton dimension, so we have to squeeze. - np.squeeze(insts_coords_in_view, axis=0) - for insts_coords_in_view in np.split( - np.squeeze(insts_coords_in_frame, axis=1), - insts_coords_in_frame.shape[0], - axis=0, - ) # len(M) of T x N x 2 - for insts_coords_track in np.split( - np.squeeze(insts_coords_in_view, axis=0), - insts_coords_in_view.shape[0], - axis=0, - ) # len(T) of N x 2 - ] - for insts_coords_in_frame in np.split( - inst_coords_reprojected, inst_coords_reprojected.shape[1], axis=1 - ) # len(F) of M x T x N x 2 - ] # len(F) of len(M) of len(T) of N x 2 - - # Group together the reordered (by cam) instances and the reprojected coords. - insts_and_coords: Dict[ - int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]] - ] = ( - {} - ) # Dict len(F) of dict len(M) of zipped lists of len(T) instances and array of N x 2 - for frame_idx, instances_in_frame in instances.items(): # len(F) of dict - insts_and_coords_in_frame: Dict[ - Camcorder, Iterator[Tuple[Instance, np.ndarray]] - ] = {} - for cam_idx, cam in enumerate(cams_ordered): - instances_in_frame_ordered: List[Instance] = instances_in_frame[ - cam - ] # Reorder by cam to match coordinates, len(T) - insts_coords_in_frame: np.ndarray = insts_coords_list[frame_idx][ - cam_idx - ] # len(T) of N x 2 - - # TODO(LM): I think we will need a reconsumable iterator here. - insts_and_coords_in_frame[cam] = zip( - instances_in_frame_ordered, - insts_coords_in_frame, + return None # No instance groups found + + # Assert that there are enough views and instances + instance_groups_to_tri = [] + for instance_group in instance_groups: + instances = instance_group.get_instances(frame_group.cams_to_include) + if len(instances) < 2: + # Not enough instances + logger.warning( + f"Not enough instances in {instance_group} for triangulation." + f"\nSkipping instance group." ) - insts_and_coords[frame_idx] = insts_and_coords_in_frame - - return insts_and_coords - - @staticmethod - def calculate_reprojected_points( - session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]] - ): - """Triangulate, reproject, and group coordinates of `Instances`. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - instances: Dict with frame identifier keys (does not necessarily need to be - the frame index) and values of another inner dict with `Camcorder` keys - and `List[Instance]` values. - - Returns: - A dictionary with frame identifier keys (does not necessarily need to be the - frame index) and values of another inner dict with `Camcorder` keys and - a zip of the `List[Instance]` and reprojected instance coordinates list with - items of shape (N, 2) ordered by the `Camcorder` order in the `CameraCluster`. - """ - - # Triangulate and reproject instance coordinates. - ( - inst_coords_reprojected, - cams_ordered, - ) = TriangulateSession._calculate_reprojected_points( - session=session, instances=instances - ) - - # Reorder instances (by cam) and the reprojected coords. - instances_and_coords: Dict[ - int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]] - ] = TriangulateSession.group_instances_and_coords( - inst_coords_reprojected=inst_coords_reprojected, - instances=instances, - cams_ordered=cams_ordered, - ) - - return instances_and_coords - - @staticmethod - def update_instances( - instances_and_coords: Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]] - ): - """Update coordinates of `Instances`. - - Args: - instances_and_coords: Dict with `Camcorder` keys and - `Iterator[Tuple[Instance, np.ndarray]]` values containing the Instance - and it's reprojected coordinates. - - Returns: - None - """ + continue + instance_groups_to_tri.append(instance_group) - # Update the instance coordinates. - for instances_in_view in instances_and_coords.values(): - for inst, inst_coord in instances_in_view: - inst.update_points(points=inst_coord, exclude_complete=True) + return instance_groups_to_tri # `InstanceGroup`s with enough instances def open_website(url: str): diff --git a/sleap/instance.py b/sleap/instance.py index 1da784416..ed1fa3d07 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -500,7 +500,6 @@ def _points_dict_to_array( ) try: parray[skeleton.node_to_index(node)] = point - # parray[skeleton.node_to_index(node.name)] = point except: logger.debug( f"Could not set point for node {node} in {skeleton} " @@ -729,9 +728,31 @@ def update_points(self, points: np.ndarray, exclude_complete: bool = False): for point_new, points_old, node_name in zip( points, self._points, self.skeleton.node_names ): + + # Skip if new point is nan or old point is complete if np.isnan(point_new).any() or (exclude_complete and points_old.complete): continue - points_dict[node_name] = Point(x=point_new[0], y=point_new[1]) + + # Grab the x, y from the new point and visible, complete from the old point + x, y = point_new + visible = points_old.visible + complete = points_old.complete + + # Create a new point and add to the dict + if type(self._points) == PredictedPointArray: + # TODO(LM): The point score is meant to rate the confidence of the + # prediction, but this method updates from triangulation. + score = points_old.score + point_obj = PredictedPoint( + x=x, y=y, visible=visible, complete=complete, score=score + ) + else: + point_obj = Point(x=x, y=y, visible=visible, complete=complete) + + # Update the points dict + points_dict[node_name] = point_obj + + # Update the points if len(points_dict) > 0: Instance._points_dict_to_array(points_dict, self._points, self.skeleton) diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 0cf830feb..a9151d93c 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -1,8 +1,10 @@ """Module for storing information for camera groups.""" + +from itertools import permutations, product import logging import tempfile from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast, Set import cattr import numpy as np @@ -10,9 +12,9 @@ from aniposelib.cameras import Camera, CameraGroup, FisheyeCamera from attrs import define, field from attrs.validators import deep_iterable, instance_of -from sleap_anipose import reproject, triangulate # from sleap.io.dataset import Labels # TODO(LM): Circular import, implement Observer +from sleap.instance import PredictedInstance from sleap.io.video import Video from sleap.util import deep_iterable_converter @@ -394,6 +396,411 @@ def to_calibration_dict(self) -> Dict[str, str]: return calibration_dict +@define +class InstanceGroup: + """Defines a group of instances across the same frame index. + + Args: + camera_cluster: `CameraCluster` object. + instances: List of `Instance` objects. + + """ + + frame_idx: int = field(validator=instance_of(int)) + camera_cluster: Optional[CameraCluster] = None + locked: bool = field(default=False) + _instance_by_camcorder: Dict[Camcorder, "Instance"] = field(factory=dict) + _camcorder_by_instance: Dict["Instance", Camcorder] = field(factory=dict) + _dummy_instance: Optional["Instance"] = field(default=None) + + def __attrs_post_init__(self): + """Initialize `InstanceGroup` object.""" + + instance = None + for cam, instance in self._instance_by_camcorder.items(): + self._camcorder_by_instance[instance] = cam + + # Create a dummy instance to fill in for missing instances + if self._dummy_instance is None: + self._create_dummy_instance(instance=instance) + + def _create_dummy_instance(self, instance: Optional["Instance"] = None): + """Create a dummy instance to fill in for missing instances. + + Args: + instance: Optional `Instance` object to use as an example instance. If None, + then the first instance in the `InstanceGroup` is used. + + Raises: + ValueError: If no instances are available to create a dummy instance. + """ + + if self._dummy_instance is None: + # Get an example instance + if instance is None: + if len(self.instances) < 1: + raise ValueError( + "Cannot create a dummy instance without any instances." + ) + instance = self.instances[0] + + # Use the example instance to create a dummy instance + skeleton: "Skeleton" = instance.skeleton + self._dummy_instance = PredictedInstance.from_numpy( + points=np.full( + shape=(len(skeleton.nodes), 2), + fill_value=np.nan, + ), + point_confidences=np.full( + shape=(len(skeleton.nodes),), + fill_value=np.nan, + ), + instance_score=np.nan, + skeleton=skeleton, + ) + + @property + def dummy_instance(self) -> PredictedInstance: + """Dummy `PredictedInstance` object to fill in for missing instances. + + Also used to create instances that are not found in the `InstanceGroup`. + + Returns: + `PredictedInstance` object or None if unable to create the dummy instance. + """ + + if self._dummy_instance is None: + self._create_dummy_instance() + return self._dummy_instance + + @property + def instances(self) -> List["Instance"]: + """List of `Instance` objects.""" + return list(self._instance_by_camcorder.values()) + + @property + def cameras(self) -> List[Camcorder]: + """List of `Camcorder` objects.""" + return list(self._instance_by_camcorder.keys()) + + def numpy(self) -> np.ndarray: + """Return instances as a numpy array of shape (n_views, n_nodes, 2). + The ordering of views is based on the ordering of `Camcorder`s in the + `self.camera_cluster: CameraCluster`. + If an instance is missing for a `Camcorder`, then the instance is filled in with + the dummy instance (all NaNs). + Returns: + Numpy array of shape (n_views, n_nodes, 2). + """ + + instance_numpys: List[np.ndarray] = [] # len(M) x N x 2 + for cam in self.camera_cluster.cameras: + instance = self.get_instance(cam) or self.dummy_instance + instance_numpy: np.ndarray = instance.numpy() # N x 2 + instance_numpys.append(instance_numpy) + + return np.stack(instance_numpys, axis=0) # M x N x 2 + + def create_and_add_instance(self, cam: Camcorder, labeled_frame: "LabeledFrame"): + """Create an `Instance` at a labeled_frame and add it to the `InstanceGroup`. + + Args: + cam: `Camcorder` object that the `Instance` is for. + labeled_frame: `LabeledFrame` object that the `Instance` is contained in. + + Returns: + All nan `PredictedInstance` created and added to the `InstanceGroup`. + """ + + # Get the `Skeleton` + skeleton: "Skeleton" = self.dummy_instance.skeleton + + # Create an all nan `Instance` + instance: PredictedInstance = PredictedInstance.from_numpy( + points=self.dummy_instance.points_array, + point_confidences=self.dummy_instance.scores, + instance_score=self.dummy_instance.score, + skeleton=skeleton, + ) + instance.frame = labeled_frame + + # Add the instance to the `InstanceGroup` + self.add_instance(cam, instance) + + return instance + + def add_instance(self, cam: Camcorder, instance: "Instance"): + """Add an `Instance` to the `InstanceGroup`. + + Args: + cam: `Camcorder` object that the `Instance` is for. + instance: `Instance` object to add. + + Raises: + ValueError: If the `Camcorder` is not in the `CameraCluster`. + ValueError: If the `Instance` is already in the `InstanceGroup` at another + camera. + """ + + # Ensure the `Camcorder` is in the `CameraCluster` + self._raise_if_cam_not_in_cluster(cam=cam) + + # Ensure the `Instance` is not already in the `InstanceGroup` at another camera + if ( + instance in self._camcorder_by_instance + and self._camcorder_by_instance[instance] != cam + ): + raise ValueError( + f"Instance {instance} is already in this InstanceGroup at camera " + f"{self.get_instance(instance)}." + ) + + # Add the instance to the `InstanceGroup` + self.replace_instance(cam, instance) + + def replace_instance(self, cam: Camcorder, instance: "Instance"): + """Replace an `Instance` in the `InstanceGroup`. + + If the `Instance` is already in the `InstanceGroup`, then it is removed and + replaced. If the `Instance` is not already in the `InstanceGroup`, then it is + added. + + Args: + cam: `Camcorder` object that the `Instance` is for. + instance: `Instance` object to replace. + + Raises: + ValueError: If the `Camcorder` is not in the `CameraCluster`. + """ + + # Ensure the `Camcorder` is in the `CameraCluster` + self._raise_if_cam_not_in_cluster(cam=cam) + + # Remove the instance if it already exists + self.remove_instance(instance_or_cam=instance) + + # Replace the instance in the `InstanceGroup` + self._instance_by_camcorder[cam] = instance + self._camcorder_by_instance[instance] = cam + + def remove_instance(self, instance_or_cam: Union["Instance", Camcorder]): + """Remove an `Instance` from the `InstanceGroup`. + + Args: + instance_or_cam: `Instance` or `Camcorder` object to remove from + `InstanceGroup`. + + Raises: + ValueError: If the `Camcorder` is not in the `CameraCluster`. + """ + + if isinstance(instance_or_cam, Camcorder): + cam = instance_or_cam + + # Ensure the `Camcorder` is in the `CameraCluster` + self._raise_if_cam_not_in_cluster(cam=cam) + + # Remove the instance from the `InstanceGroup` + if cam in self._instance_by_camcorder: + instance = self._instance_by_camcorder.pop(cam) + self._camcorder_by_instance.pop(instance) + + else: + # The input is an `Instance` + instance = instance_or_cam + + # Remove the instance from the `InstanceGroup` + if instance in self._camcorder_by_instance: + cam = self._camcorder_by_instance.pop(instance) + self._instance_by_camcorder.pop(cam) + else: + logger.debug( + f"Instance {instance} not found in this InstanceGroup {self}." + ) + + def _raise_if_cam_not_in_cluster(self, cam: Camcorder): + """Raise a ValueError if the `Camcorder` is not in the `CameraCluster`.""" + + if cam not in self.camera_cluster: + raise ValueError( + f"Camcorder {cam} is not in this InstanceGroup's " + f"{self.camera_cluster}." + ) + + def get_instance(self, cam: Camcorder) -> Optional["Instance"]: + """Retrieve `Instance` linked to `Camcorder`. + + Args: + camcorder: `Camcorder` object. + + Returns: + If `Camcorder` in `self.camera_cluster`, then `Instance` object if found, else + `None` if `Camcorder` has no linked `Instance`. + """ + + if cam not in self._instance_by_camcorder: + logger.debug( + f"Camcorder {cam} has no linked `Instance` in this `InstanceGroup` " + f"{self}." + ) + return None + + return self._instance_by_camcorder[cam] + + def get_instances(self, cams: List[Camcorder]) -> List["Instance"]: + instances = [] + for cam in cams: + instance = self.get_instance(cam) + instances.append(instance) + return instance + + def get_cam(self, instance: "Instance") -> Optional[Camcorder]: + """Retrieve `Camcorder` linked to `Instance`. + + Args: + instance: `Instance` object. + + Returns: + `Camcorder` object if found, else `None`. + """ + + if instance not in self._camcorder_by_instance: + logger.debug( + f"{instance} is not in this InstanceGroup.instances: " + f"\n\t{self.instances}." + ) + return None + + return self._camcorder_by_instance[instance] + + def update_points( + self, + points: np.ndarray, + cams_to_include: Optional[List[Camcorder]] = None, + exclude_complete: bool = True, + ): + """Update the points in the `Instance` for the specified `Camcorder`s. + + Args: + points: Numpy array of shape (M, N, 2) where M is the number of views, N is + the number of Nodes, and 2 is for x, y. + cams_to_include: List of `Camcorder`s to include in the update. The order of + the `Camcorder`s in the list should match the order of the views in the + `points` array. If None, then all `Camcorder`s in the `CameraCluster` + are included. Default is None. + exclude_complete: If True, then do not update points that are marked as + complete. Default is True. + """ + + # If no `Camcorder`s specified, then update `Instance`s for all `CameraCluster` + if cams_to_include is None: + cams_to_include = self.camera_cluster.cameras + + # Check that correct shape was passed in + n_views, n_nodes, _ = points.shape + assert n_views == len(cams_to_include), ( + f"Number of views in `points` ({n_views}) does not match the number of " + f"Camcorders in `cams_to_include` ({len(cams_to_include)})." + ) + + for cam_idx, cam in enumerate(cams_to_include): + # Get the instance for the cam + instance: Optional["Instance"] = self.get_instance(cam) + if instance is None: + logger.warning( + f"Camcorder {cam.name} not found in this InstanceGroup's instances." + ) + continue + + # Update the points (and scores) for the (predicted) instance + instance.update_points( + points=points[cam_idx, :, :], exclude_complete=exclude_complete + ) + + def __getitem__( + self, idx_or_key: Union[int, Camcorder, "Instance"] + ) -> Union[Camcorder, "Instance"]: + """Grab a `Camcorder` of `Instance` from the `InstanceGroup`.""" + + def _raise_key_error(): + raise KeyError(f"Key {idx_or_key} not found in {self.__class__.__name__}.") + + # Try to find in `self.camera_cluster.cameras` + if isinstance(idx_or_key, int): + try: + return self.instances[idx_or_key] + except IndexError: + _raise_key_error() + + # Return a `Instance` if `idx_or_key` is a `Camcorder`` + if isinstance(idx_or_key, Camcorder): + return self.get_instance(idx_or_key) + + else: + # isinstance(idx_or_key, "Instance"): + try: + return self.get_cam(idx_or_key) + except: + pass + + _raise_key_error() + + def __len__(self): + return len(self.instances) + + def __repr__(self): + return f"{self.__class__.__name__}(frame_idx={self.frame_idx}, instances={len(self)}, camera_cluster={self.camera_cluster})" + + @classmethod + def from_dict(cls, d: dict) -> Optional["InstanceGroup"]: + """Creates an `InstanceGroup` object from a dictionary. + + Args: + d: Dictionary with `Camcorder` keys and `Instance` values. + + Returns: + `InstanceGroup` object or None if no "real" (determined by `frame_idx` other + than None) instances found. + """ + + # Ensure not to mutate the original dictionary + d_copy = d.copy() + + frame_idx = None + for cam, instance in d_copy.copy().items(): + camera_cluster = cam.camera_cluster + + # Remove dummy instances (determined by not having a frame index) + if instance.frame_idx is None: + d_copy.pop(cam) + # Grab the frame index from non-dummy instances + elif frame_idx is None: + frame_idx = instance.frame_idx + # Ensure all instances have the same frame index + else: + try: + assert frame_idx == instance.frame_idx + except AssertionError: + logger.warning( + f"Cannot create `InstanceGroup`: Frame index {frame_idx} " + f"does not match instance frame index {instance.frame_idx}." + ) + + if len(d_copy) == 0: + logger.warning("Cannot create `InstanceGroup`: No real instances found.") + return None + + frame_idx = cast( + int, frame_idx + ) # Could be None if no real instances in dictionary + + return cls( + frame_idx=frame_idx, + camera_cluster=camera_cluster, + instance_by_camcorder=d_copy, + ) + + @define(eq=False) class RecordingSession: """Class for storing information for a recording session. @@ -415,6 +822,12 @@ class RecordingSession: _video_by_camcorder: Dict[Camcorder, Video] = field(factory=dict) labels: Optional["Labels"] = None + # TODO(LM): Remove this, replace with `FrameGroup`s + _instance_groups_by_frame_idx: Dict[int, InstanceGroup] = field(factory=dict) + + # TODO(LM): We should serialize all locked instances in a FrameGroup (or the entire FrameGroup) + _frame_group_by_frame_idx: Dict[int, "FrameGroup"] = field(factory=dict) + @property def videos(self) -> List[Video]: """List of `Video`s.""" @@ -423,15 +836,45 @@ def videos(self) -> List[Video]: @property def linked_cameras(self) -> List[Camcorder]: - """List of `Camcorder`s in `self.camera_cluster` that are linked to a video.""" + """List of `Camcorder`s in `self.camera_cluster` that are linked to a video. - return list(self._video_by_camcorder.keys()) + The list is ordered based on the order of the `Camcorder`s in the `CameraCluster`. + """ + + return sorted( + self._video_by_camcorder.keys(), key=self.camera_cluster.cameras.index + ) @property def unlinked_cameras(self) -> List[Camcorder]: - """List of `Camcorder`s in `self.camera_cluster` that are not linked to a video.""" + """List of `Camcorder`s in `self.camera_cluster` that are not linked to a video. + + The list is ordered based on the order of the `Camcorder`s in the `CameraCluster`. + """ - return list(set(self.camera_cluster.cameras) - set(self.linked_cameras)) + return sorted( + set(self.camera_cluster.cameras) - set(self.linked_cameras), + key=self.camera_cluster.cameras.index, + ) + + # TODO(LM): Remove this + @property + def instance_groups(self) -> Dict[int, InstanceGroup]: + """Dict of `InstanceGroup`s by frame index.""" + + return self._instance_groups_by_frame_idx + + @property + def frame_groups(self) -> Dict[int, "FrameGroup"]: + """Dict of `FrameGroup`s by frame index.""" + + return self._frame_group_by_frame_idx + + @property + def frame_inds(self) -> List[int]: + """List of frame indices.""" + + return list(self.frame_groups.keys()) def get_video(self, camcorder: Camcorder) -> Optional[Video]: """Retrieve `Video` linked to `Camcorder`. @@ -519,6 +962,11 @@ def add_video(self, video: Video, camcorder: Camcorder): # Add camcorder-to-video (1-to-1) map to `RecordingSession` self._video_by_camcorder[camcorder] = video + # Sort `_videos_by_session` by order of linked `Camcorder` in `CameraCluster.cameras` + self.camera_cluster._videos_by_session[self].sort( + key=lambda video: self.camera_cluster.cameras.index(self.get_camera(video)) + ) + # Update labels cache if self.labels is not None: self.labels.update_session(self, video) @@ -572,6 +1020,37 @@ def get_videos_from_selected_cameras( return videos + # TODO(LM): There can be multiple `InstanceGroup`s per frame index + def get_instance_group(self, frame_idx: int) -> Optional[InstanceGroup]: + """Get `InstanceGroup` from frame index. + + Args: + frame_idx: Frame index. + + Returns: + `InstanceGroup` object or `None` if not found. + """ + + if frame_idx not in self.instance_groups: + logger.warning( + f"Frame index {frame_idx} not found in this RecordingSession's " + f"InstanceGroup's keys: \n\t{self.instance_groups.keys()}." + ) + return None + + return self.instance_groups[frame_idx] + + # TODO(LM): There can be multiple `InstanceGroup`s per frame index + def update_instance_group(self, frame_idx: int, instance_group: InstanceGroup): + """Update `InstanceGroup` from frame index. + + Args: + frame_idx: Frame index. + instance_groups: `InstanceGroup` object. + """ + + self._instance_groups_by_frame_idx[frame_idx] = instance_group + def __attrs_post_init__(self): self.camera_cluster.add_session(self) @@ -755,3 +1234,831 @@ def make_cattr(videos_list: List[Video]): RecordingSession, lambda x: x.to_session_dict(video_to_idx) ) return sessions_cattr + + +@define +class FrameGroup: + """Defines a group of `InstanceGroups` across views at the same frame index.""" + + # Instance attributes + frame_idx: int = field(validator=instance_of(int)) + instance_groups: List[InstanceGroup] = field( + validator=deep_iterable( + member_validator=instance_of(InstanceGroup), + iterable_validator=instance_of(list), + ), + ) # Akin to `LabeledFrame.instances` + session: RecordingSession = field(validator=instance_of(RecordingSession)) + + # Class attribute to keep track of frame indices across all `RecordingSession`s + _frame_idx_registry: Dict[RecordingSession, Set[int]] = {} + + # "Hidden" class attribute + _cams_to_include: Optional[List[Camcorder]] = None + _excluded_views: Optional[Tuple[str]] = () + _dummy_labeled_frame: Optional["LabeledFrame"] = None + + # "Hidden" instance attributes + + # TODO(LM): This dict should be updated each time a LabeledFrame is added/removed + # from the Labels object. Or if a video is added/removed from the RecordingSession. + _labeled_frames_by_cam: Dict[Camcorder, "LabeledFrame"] = field(factory=dict) + _instances_by_cam: Dict[Camcorder, Set["Instance"]] = field(factory=dict) + + # TODO(LM): This dict should be updated each time an InstanceGroup is + # added/removed/locked/unlocked + _locked_instance_groups: List[InstanceGroup] = field(factory=list) + _locked_instances_by_cam: Dict[Camcorder, Set["Instance"]] = field( + factory=dict + ) # Internally updated in `update_locked_instances_by_cam` + + def __attrs_post_init__(self): + """Initialize `FrameGroup` object.""" + + # Remove existing `FrameGroup` object from the `RecordingSession._frame_group_by_frame_idx` + self.enforce_frame_idx_unique(self.session, self.frame_idx) + + # Reorder `cams_to_include` to match `CameraCluster` order (via setter method) + if self._cams_to_include is not None: + self.cams_to_include = self._cams_to_include + + # Add frame index to registry + if self.session not in self._frame_idx_registry: + self._frame_idx_registry[self.session] = set() + + self._frame_idx_registry[self.session].add(self.frame_idx) + + # Add `FrameGroup` to `RecordingSession` + self.session._frame_group_by_frame_idx[self.frame_idx] = self + + # Initialize `_labeled_frames_by_cam` dictionary + self.update_labeled_frames_and_instances_by_cam() + + # Initialize `_locked_instance_groups` dictionary + self.update_locked_instance_groups() + + # The dummy labeled frame will only be set once for the first `FrameGroup` made + if self._dummy_labeled_frame is None: + self._dummy_labeled_frame = self.labeled_frames[0] + + @property + def cams_to_include(self) -> Optional[List[Camcorder]]: + """List of `Camcorder`s to include in this `FrameGroup`.""" + + if self._cams_to_include is None: + self._cams_to_include = self.session.camera_cluster.cameras.copy() + + # TODO(LM): Should we store this in another attribute? + # Filter cams to include based on videos linked to the session + cams_to_include = [ + cam for cam in self._cams_to_include if cam in self.session.linked_cameras + ] + + return cams_to_include + + @property + def excluded_views(self) -> Optional[Tuple[str]]: + """List of excluded views (names of Camcorders).""" + + return self._excluded_views + + @cams_to_include.setter + def cams_to_include(self, cams_to_include: List[Camcorder]): + """Setter for `cams_to_include` that sorts by `CameraCluster` order.""" + + # Sort the `Camcorder`s to include based on the order of `CameraCluster` cameras + self._cams_to_include = cams_to_include.sort( + key=self.session.camera_cluster.cameras.index + ) + + # Update the `excluded_views` attribute + excluded_cams = list( + set(self.session.camera_cluster.cameras) - set(cams_to_include) + ) + excluded_cams.sort(key=self.session.camera_cluster.cameras.index) + self._excluded_views = (cam.name for cam in excluded_cams) + + @property + def labeled_frames(self) -> List["LabeledFrame"]: + """List of `LabeledFrame`s.""" + + return list(self._labeled_frames_by_cam.values()) + + @property + def cameras(self) -> List[Camcorder]: + """List of `Camcorder`s.""" + + return list(self._labeled_frames_by_cam.keys()) + + @property + def instances_by_cam_to_include(self) -> Dict[Camcorder, Set["Instance"]]: + """List of `Camcorder`s.""" + + return {cam: self._instances_by_cam[cam] for cam in self.cams_to_include} + + @property + def locked_instance_groups(self) -> List[InstanceGroup]: + """List of locked `InstanceGroup`s.""" + + return self._locked_instance_groups + + def numpy( + self, instance_groups: Optional[List[InstanceGroup]] = None + ) -> np.ndarray: + """Numpy array of all `InstanceGroup`s in `FrameGroup.cams_to_include`. + + Args: + instance_groups: `InstanceGroup`s to include. Default is None and uses all + self.instance_groups. + + Returns: + Numpy array of shape (M, T, N, 2) where M is the number of views (determined + by self.cames_to_include), T is the number of `InstanceGroup`s, N is the + number of Nodes, and 2 is for x, y. + """ + + # Use all `InstanceGroup`s if not specified + if instance_groups is None: + instance_groups = self.instance_groups + else: + # Ensure that `InstanceGroup`s is in this `FrameGroup` + for instance_group in instance_groups: + if instance_group not in self.instance_groups: + raise ValueError( + f"InstanceGroup {instance_group} is not in this FrameGroup: " + f"{self.instance_groups}" + ) + + instance_group_numpys: List[np.ndarray] = [] # len(T) M=all x N x 2 + for instance_group in instance_groups: + instance_group_numpy = instance_group.numpy() # M=all x N x 2 + instance_group_numpys.append(instance_group_numpy) + + frame_group_numpy = np.stack(instance_group_numpys, axis=1) # M=all x T x N x 2 + cams_to_include_mask = np.array( + [1 if cam in self.cams_to_include else 0 for cam in self.cameras] + ) # M=include x 1 + + return frame_group_numpy[cams_to_include_mask] # M=include x T x N x 2 + + def add_instance( + self, + instance: "Instance", + camera: Camcorder, + instance_group: Optional[InstanceGroup] = None, + ): + """Add an (existing) `Instance` to the `FrameGroup`. + + If no `InstanceGroup` is provided, then check the `Instance` is already in an + `InstanceGroup` contained in the `FrameGroup`. + + Args: + instance: `Instance` to add to the `FrameGroup`. + camera: `Camcorder` to link the `Instance` to. + instance_group: `InstanceGroup` to add the `Instance` to. If None, then + check the `Instance` is already in an `InstanceGroup`. + + Raises: + ValueError: If the `InstanceGroup` is not in the `FrameGroup`. + ValueError: If the `Instance` is not linked to a `LabeledFrame`. + ValueError: If the frame index of the `Instance` does not match the frame index + of the `FrameGroup`. + ValueError: If the `LabeledFrame` of the `Instance` does not match the existing + `LabeledFrame` for the `Camcorder` in the `FrameGroup`. + ValueError: If the `Instance` is not in an `InstanceGroup` in the + `FrameGroup`. + """ + + # Ensure the `InstanceGroup` is in this `FrameGroup` + if instance_group is not None: + self._raise_if_instance_group_not_in_frame_group( + instance_group=instance_group + ) + + # Ensure `Instance` is compatible with `FrameGroup` + self._raise_if_instance_incompatibile(instance=instance, camera=camera) + + # Add the `Instance` to the `InstanceGroup` + if instance_group is not None: + instance_group.add_instance(cam=camera, instance=instance) + else: + self._raise_if_instance_not_in_instance_group(instance=instance) + + # Add the `Instance` to the `FrameGroup` + self._instances_by_cam[camera].add(instance) + + # Update the labeled frames if necessary + labeled_frame = self.get_labeled_frame(camera=camera) + if labeled_frame is None: + labeled_frame = instance.frame + self.add_labeled_frame(labeled_frame=labeled_frame, camera=camera) + + def add_instance_group(self, instance_group: Optional[InstanceGroup] = None): + """Add an `InstanceGroup` to the `FrameGroup`. + + Args: + instance_group: `InstanceGroup` to add to the `FrameGroup`. If None, then + create a new `InstanceGroup` and add it to the `FrameGroup`. + + Raises: + ValueError: If the `InstanceGroup` is already in the `FrameGroup`. + """ + + if instance_group is None: + # Create an empty `InstanceGroup` with the frame index of the `FrameGroup` + instance_group = InstanceGroup( + frame_idx=self.frame_idx, + camera_cluster=self.session.camera_cluster, + ) + + else: + # Ensure the `InstanceGroup` is not already in this `FrameGroup` + self._raise_if_instance_group_in_frame_group(instance_group=instance_group) + + # Ensure the `InstanceGroup` is compatible with the `FrameGroup` + self._raise_if_instance_group_incompatible(instance_group=instance_group) + + # Add the `InstanceGroup` to the `FrameGroup` + self.instance_groups.append(instance_group) + + # Add `Instance`s and `LabeledFrame`s to the `FrameGroup` + for instance in instance_group.instances: + camera = instance_group.get_cam(instance=instance) + self.add_instance(instance=instance, camera=camera) + + # TODO(LM): Integrate with RecordingSession + # Add the `InstanceGroup` to the `RecordingSession` + ... + + def get_instance_group(self, instance: "Instance") -> Optional[InstanceGroup]: + """Get `InstanceGroup` that contains `Instance` if exists. Otherwise, None. + + Args: + instance: `Instance` + + Returns: + `InstanceGroup` + """ + + instance_group: Optional[InstanceGroup] = next( + ( + instance_group + for instance_group in self.instance_groups + if instance in instance_group.instances + ), + None, + ) + + return instance_group + + def add_labeled_frame(self, labeled_frame: "LabeledFrame", camera: Camcorder): + """Add a `LabeledFrame` to the `FrameGroup`. + + Args: + labeled_frame: `LabeledFrame` to add to the `FrameGroup`. + camera: `Camcorder` to link the `LabeledFrame` to. + """ + + # Add the `LabeledFrame` to the `FrameGroup` + self._labeled_frames_by_cam[camera] = labeled_frame + + # TODO(LM): Should this be an EditCommand instead? + # Add the `LabeledFrame` to the `RecordingSession`'s `Labels` object + if labeled_frame not in self.session.labels: + self.session.labels.append(labeled_frame) + + def get_labeled_frame(self, camera: Camcorder) -> Optional["LabeledFrame"]: + """Get `LabeledFrame` for `Camcorder` if exists. Otherwise, None. + + Args: + camera: `Camcorder` + + Returns: + `LabeledFrame` + """ + + return self._labeled_frames_by_cam.get(camera, None) + + def create_and_add_labeled_frame(self, camera: Camcorder) -> "LabeledFrame": + """Create and add a `LabeledFrame` to the `FrameGroup`. + + This also adds the `LabeledFrame` to the `RecordingSession`'s `Labels` object. + + Args: + camera: `Camcorder` + + Returns: + `LabeledFrame` that was created and added to the `FrameGroup`. + """ + + video = self.session.get_video(camera) + if video is None: + # There should be a `Video` linked to all cams_to_include + raise ValueError( + f"Camcorder {camera} is not linked to a video in this " + f"RecordingSession {self.session}." + ) + + # Use _dummy_labeled_frame to access the `LabeledFrame`` class here + labeled_frame = self._dummy_labeled_frame.__class__( + video=video, frame_idx=self.frame_idx + ) + self.add_labeled_frame(labeled_frame=labeled_frame) + + return labeled_frame + + def create_and_add_instance( + self, + instance_group: InstanceGroup, + camera: Camcorder, + labeled_frame: "LabeledFrame", + ): + """Add an `Instance` to the `InstanceGroup` (and `FrameGroup`). + + Args: + instance_group: `InstanceGroup` to add the `Instance` to. + camera: `Camcorder` to link the `Instance` to. + labeled_frame: `LabeledFrame` that the `Instance` is in. + """ + + # Add the `Instance` to the `InstanceGroup` + instance = instance_group.create_and_add_instance( + cam=camera, labeled_frame=labeled_frame + ) + + # Add the `Instance` to the `FrameGroup` + self._instances_by_cam[camera].add(instance=instance) + + def create_and_add_missing_instances(self, instance_group: InstanceGroup): + """Add missing instances to `FrameGroup` from `InstanceGroup`s. + + If an `InstanceGroup` does not have an `Instance` for a `Camcorder` in + `FrameGroup.cams_to_include`, then create an `Instance` and add it to the + `InstanceGroup`. + + Args: + instance_group: `InstanceGroup` objects to add missing `Instance`s for. + + Raises: + ValueError: If a `Camcorder` in `FrameGroup.cams_to_include` is not in the + `InstanceGroup`. + """ + + # Check that the `InstanceGroup` has `LabeledFrame`s for all included views + for cam in self.cams_to_include: + + # If the `Camcorder` is in the `InstanceGroup`, then `Instance` exists + if cam in instance_group.cameras: + continue # Skip to next cam + + # Get the `LabeledFrame` for the view + labeled_frame = self.get_labeled_frame(camera=cam) + if labeled_frame is None: + # There is no `LabeledFrame` for this view, so lets make one + labeled_frame = self.create_and_add_labeled_frame(camera=cam) + + # Create an instance + self.create_and_add_instance( + instance_group=instance_group, cam=cam, labeled_frame=labeled_frame + ) + + def upsert_points( + self, + points: np.ndarray, + instance_groups: List[InstanceGroup], + exclude_complete: bool = True, + ): + """Upsert points for `Instance`s at included cams in specified `InstanceGroup`. + + This will update the points for existing `Instance`s in the `InstanceGroup`s and + also add new `Instance`s if they do not exist. + + + Included cams are specified by `FrameGroup.cams_to_include`. + + The ordering of the `InstanceGroup`s in `instance_groups` should match the + ordering of the second dimension (T) in `points`. + + Args: + points: Numpy array of shape (M, T, N, 2) where M is the number of views, T + is the number of Tracks, N is the number of Nodes, and 2 is for x, y. + instance_groups: List of `InstanceGroup` objects to update points for. + exclude_complete: If True, then only update points that are not marked as + complete. Default is True. + """ + + # Check that the correct shape was passed in + n_views, n_instances, n_nodes, n_coords = points.shape + assert n_views == len( + self.cams_to_include + ), f"Expected {len(self.cams_to_include)} views, got {n_views}." + assert n_instances == len( + instance_groups + ), f"Expected {len(instance_groups)} instances, got {n_instances}." + assert n_coords == 2, f"Expected 2 coordinates, got {n_coords}." + + # Update points for each `InstanceGroup` + for ig_idx, instance_group in enumerate(instance_groups): + # Ensure that `InstanceGroup`s is in this `FrameGroup` + self._raise_if_instance_group_not_in_frame_group( + instance_group=instance_group + ) + + # Check that the `InstanceGroup` has `Instance`s for all cams_to_include + self.create_and_add_missing_instances(instance_group=instance_group) + + # Update points for each `Instance` in `InstanceGroup` + instance_points = points[:, ig_idx, :, :] # M x N x 2 + instance_group.update_points( + points=instance_points, + cams_to_include=self.cams_to_include, + exclude_complete=exclude_complete, + ) + + def _raise_if_instance_not_in_instance_group(self, instance: "Instance"): + """Raise a ValueError if the `Instance` is not in an `InstanceGroup`. + + Args: + instance: `Instance` to check if in an `InstanceGroup`. + + Raises: + ValueError: If the `Instance` is not in an `InstanceGroup`. + """ + + instance_group = self.get_instance_group(instance=instance) + if instance_group is None: + raise ValueError( + f"Instance {instance} is not in an InstanceGroup within the FrameGroup." + ) + + def _raise_if_instance_incompatibile(self, instance: "Instance", camera: Camcorder): + """Raise a ValueError if the `Instance` is incompatible with the `FrameGroup`. + + The `Instance` is incompatible if: + 1. the `Instance` is not linked to a `LabeledFrame`. + 2. the frame index of the `Instance` does not match the frame index of the + `FrameGroup`. + 3. the `LabeledFrame` of the `Instance` does not match the existing + `LabeledFrame` for the `Camcorder` in the `FrameGroup`. + + Args: + instance: `Instance` to check compatibility of. + camera: `Camcorder` to link the `Instance` to. + """ + + labeled_frame = instance.frame + if labeled_frame is None: + raise ValueError( + f"Instance {instance} is not linked to a LabeledFrame. " + "Cannot add to FrameGroup." + ) + + frame_idx = labeled_frame.frame_idx + if frame_idx != self.frame_idx: + raise ValueError( + f"Instance {instance} frame index {frame_idx} does not match " + f"FrameGroup frame index {self.frame_idx}." + ) + + labeled_frame_fg = self.get_labeled_frame(camera=camera) + if labeled_frame_fg is None: + pass + elif labeled_frame != labeled_frame_fg: + raise ValueError( + f"Instance's LabeledFrame {labeled_frame} is not the same as " + f"FrameGroup's LabeledFrame {labeled_frame_fg} for Camcorder {camera}." + ) + + def _raise_if_instance_group_in_frame_group(self, instance_group: InstanceGroup): + """Raise a ValueError if the `InstanceGroup` is already in the `FrameGroup`. + + Args: + instance_group: `InstanceGroup` to check if already in the `FrameGroup`. + + Raises: + ValueError: If the `InstanceGroup` is already in the `FrameGroup`. + """ + + if instance_group in self.instance_groups: + raise ValueError( + f"InstanceGroup {instance_group} is already in this FrameGroup " + f"{self.instance_groups}." + ) + + def _raise_if_instance_group_incompatible(self, instance_group: InstanceGroup): + """Raise a ValueError if `InstanceGroup` is incompatible with `FrameGroup`. + + An `InstanceGroup` is incompatible if the `frame_idx` does not match the + `FrameGroup`'s `frame_idx`. + + Args: + instance_group: `InstanceGroup` to check compatibility of. + + Raises: + ValueError: If the `InstanceGroup` is incompatible with the `FrameGroup`. + """ + + if instance_group.frame_idx != self.frame_idx: + raise ValueError( + f"InstanceGroup {instance_group} frame index {instance_group.frame_idx} " + f"does not match FrameGroup frame index {self.frame_idx}." + ) + + def _raise_if_instance_group_not_in_frame_group( + self, instance_group: InstanceGroup + ): + """Raise a ValueError if `InstanceGroup` is not in this `FrameGroup`.""" + + if instance_group not in self.instance_groups: + raise ValueError( + f"InstanceGroup {instance_group} is not in this FrameGroup: " + f"{self.instance_groups}." + ) + + def update_labeled_frames_and_instances_by_cam( + self, return_instances_by_camera: bool = False + ) -> Union[Dict[Camcorder, "LabeledFrame"], Dict[Camcorder, List["Instance"]]]: + """Get all views and `Instance`s across all `RecordingSession`s. + + Updates the `_labeled_frames_by_cam` and `_instances_by_cam` + dictionary attributes. + + Args: + return_instances_by_camera: If true, then returns a dictionary with + `Camcorder` key and `Set[Instance]` values instead. Default is False. + + Returns: + Dictionary with `Camcorder` key and `LabeledFrame` value or `Set[Instance]` + value if `return_instances_by_camera` is True. + """ + + logger.debug( + "Updating LabeledFrames for FrameGroup." + "\n\tPrevious LabeledFrames by Camcorder:" + f"\n\t{self._labeled_frames_by_cam}" + ) + + views: Dict[Camcorder, "LabeledFrame"] = {} + instances_by_cam: Dict[Camcorder, Set["Instance"]] = {} + videos = self.session.get_videos_from_selected_cameras() + for cam, video in videos.items(): + lfs: List["LabeledFrame"] = self.session.labels.get( + (video, [self.frame_idx]) + ) + if len(lfs) == 0: + logger.debug( + f"No LabeledFrames found for video {video} at {self.frame_idx}." + ) + continue + + lf = lfs[0] + if len(lf.instances) == 0: + logger.warning( + f"No Instances found for {lf}." + " There should not be empty LabeledFrames." + ) + continue + + views[cam] = lf + + # Find instances in frame + insts = lf.find(track=-1, user=True) + if len(insts) > 0: + instances_by_cam[cam] = set(insts) + + # Update `_labeled_frames_by_cam` dictionary and return it + self._labeled_frames_by_cam = views + logger.debug( + f"\tUpdated LabeledFrames by Camcorder:\n\t{self._labeled_frames_by_cam}" + ) + # Update `_instances_by_camera` dictionary and return it + self._instances_by_cam = instances_by_cam + return ( + self._instances_by_cam + if return_instances_by_camera + else self._labeled_frames_by_cam + ) + + def update_locked_instance_groups(self) -> List[InstanceGroup]: + """Updates locked `InstanceGroup`s in `FrameGroup`. + + Returns: + List of locked `InstanceGroup`s. + """ + + self._locked_instance_groups: List[InstanceGroup] = [ + instance_group + for instance_group in self.instance_groups + if instance_group.locked + ] + + # Also update locked instances by cam + self.update_locked_instances_by_cam(self._locked_instance_groups) + + return self._locked_instance_groups + + def update_locked_instances_by_cam( + self, locked_instance_groups: List[InstanceGroup] = None + ) -> Dict[Camcorder, Set["Instance"]]: + """Updates locked `Instance`s in `FrameGroup`. + + Args: + locked_instance_groups: List of locked `InstanceGroup`s. Default is None. + If None, then uses `self.locked_instance_groups`. + + Returns: + Dictionary with `Camcorder` key and `Set[Instance]` value. + """ + + if locked_instance_groups is None: + locked_instance_groups = self.locked_instance_groups + + locked_instances_by_cam: Dict[Camcorder, Set["Instance"]] = {} + + # Loop through each camera and append locked instances in specific order + for cam in self.cams_to_include: + locked_instances_by_cam[cam] = set() + for instance_group in locked_instance_groups: + instance = instance_group.get_instance(cam) # Returns None if not found + + # TODO(LM): Should this be adding the dummy instance here? + # LM: No, since just using the number of locked instance groups will + # account for the dummy instances + if instance is not None: + locked_instances_by_cam[cam].add(instance) + + # Only update if there were no errors + self._locked_instances_by_cam = locked_instances_by_cam + return self._locked_instances_by_cam + + # TODO(LM): Should we move this to TriangulateSession? + def generate_hypotheses( + self, as_matrix: bool = True + ) -> Union[np.ndarray, Dict[int, List[InstanceGroup]]]: + """Generates all possible hypotheses from the `FrameGroup`. + + Args: + as_matrix: If True (defualt), then return as a matrix of + `Instance.points_array`. Else return as `Dict[int, List[InstanceGroup]]` + where `int` is the hypothesis identifier and `List[InstanceGroup]` is + the list of `InstanceGroup`s. + + Returns: + Either a `np.ndarray` of shape M x F x T x N x 2 an array if as_matrix where + M: # views, F: # frames = 1, T: # tracks, N: # nodes, 2: x, y + or a dictionary with hypothesis ID key and list of `InstanceGroup`s value. + """ + + # Get all `Instance`s for this frame index across all views to include + instances_by_camera: Dict[ + Camcorder, Set["Instance"] + ] = self.instances_by_cam_to_include + + # Get max number of instances across all views + all_instances_by_camera: List[Set["Instance"]] = instances_by_camera.values() + max_num_instances = max( + [len(instances) for instances in all_instances_by_camera], default=0 + ) + + # Create a dummy instance of all nan values + example_instance: "Instance" = next(iter(all_instances_by_camera[0])) + skeleton: "Skeleton" = example_instance.skeleton + dummy_instance: "Instance" = example_instance.from_numpy( + np.full( + shape=(len(skeleton.nodes), 2), + fill_value=np.nan, + ), + skeleton=skeleton, + ) + + def _fill_in_missing_instances( + unlocked_instances_in_view: List["Instance"], + ): + """Fill in missing instances with dummy instances up to max number. + + Note that this function will mutate the input list in addition to returning + the mutated list. + + Args: + unlocked_instances_in_view: List of instances in a view that are not in + a locked InstanceGroup. + + Returns: + List of instances in a view that are not in a locked InstanceGroup with + dummy instances appended. + """ + + # Subtracting the number of locked instance groups accounts for there being + # dummy instances in the locked instance groups. + num_instances_missing = ( + max_num_instances + - len(unlocked_instances_in_view) + - len( + self.locked_instance_groups + ) # TODO(LM): Make sure this property is getting updated properly + ) + + if num_instances_missing > 0: + # Extend the list of instances with dummy instances + unlocked_instances_in_view.extend( + [dummy_instance] * num_instances_missing + ) + + return unlocked_instances_in_view + + # For each view, get permutations of unlocked instances + unlocked_instance_permutations: Dict[ + Camcorder, Iterator[Tuple["Instance"]] + ] = {} + for cam, instances_in_view in instances_by_camera.items(): + # Gather all instances for this cam from locked `InstanceGroup`s + locked_instances_in_view: Set[ + "Instance" + ] = self._locked_instances_by_cam.get(cam, set()) + + # Remove locked instances from instances in view + unlocked_instances_in_view: List["Instance"] = list( + instances_in_view - locked_instances_in_view + ) + + # Fill in missing instances with dummy instances up to max number + unlocked_instances_in_view = _fill_in_missing_instances( + unlocked_instances_in_view + ) + + # Permuate all `Instance`s in the unlocked `InstanceGroup`s + unlocked_instance_permutations[cam] = permutations( + unlocked_instances_in_view + ) + + # Get products of instances from other views into all possible groupings + # Ordering of dict_values is preserved in Python 3.7+ + products_of_unlocked_instances: Iterator[Iterator[Tuple]] = product( + *unlocked_instance_permutations.values() + ) + + # Reorganize products by cam and add selected instance to each permutation + grouping_hypotheses: Dict[int, List[InstanceGroup]] = {} + for frame_id, prod in enumerate(products_of_unlocked_instances): + grouping_hypotheses[frame_id] = { + # TODO(LM): This is where we would create the `InstanceGroup`s instead + cam: list(inst) + for cam, inst in zip(self.cams_to_include, prod) + } + + # TODO(LM): Should we return this as instance matrices or `InstanceGroup`s? + # Answer: Definitely not instance matrices since we need to keep track of the + # `Instance`s, but I kind of wonder if we could just return a list of + # `InstanceGroup`s instead of a dict then the `InstanceGroup` + + return grouping_hypotheses + + @classmethod + def from_instance_groups( + cls, + session: RecordingSession, + instance_groups: List["InstanceGroup"], + ) -> Optional["FrameGroup"]: + """Creates a `FrameGroup` object from an `InstanceGroup` object. + + Args: + session: `RecordingSession` object. + instance_groups: A list of `InstanceGroup` objects. + + Returns: + `FrameGroup` object or None if no "real" (determined by `frame_idx` other + than None) frames found. + """ + + if len(instance_groups) == 0: + raise ValueError("instance_groups must contain at least one InstanceGroup") + + # Get frame index from first instance group + frame_idx = instance_groups[0].frame_idx + + # Create and return `FrameGroup` object + return cls( + frame_idx=frame_idx, instance_groups=instance_groups, session=session + ) + + def enforce_frame_idx_unique( + self, session: RecordingSession, frame_idx: int + ) -> bool: + """Enforces that all frame indices are unique in `RecordingSession`. + + Removes existing `FrameGroup` object from the + `RecordingSession._frame_group_by_frame_idx`. + + Args: + session: `RecordingSession` object. + frame_idx: Frame index. + """ + + if frame_idx in self._frame_idx_registry.get(session, set()): + # Remove existing `FrameGroup` object from the + # `RecordingSession._frame_group_by_frame_idx` + logger.warning( + f"Frame index {frame_idx} for FrameGroup already exists in this " + "RecordingSession. Overwriting." + ) + session._frame_group_by_frame_idx.pop(frame_idx) diff --git a/tests/data/cameras/minimal_session/min_session_user_labeled.slp b/tests/data/cameras/minimal_session/min_session_user_labeled.slp new file mode 100644 index 0000000000000000000000000000000000000000..c7d8fb2dd2d61df7f43a03774a111547f716128a GIT binary patch literal 60444 zcmeIb33wF6)<0Zf7i3335p2Sy5<(`kC4fvJghkn8UrmP)LL?+13421Cuqr!(Tybka zP(Y9!A7ZWETg#SlKP1#;q z@(&!>qgzZ2A;k#t$D<@8c1dwIR5r9ZAx@Aljw?F;Zz9k>rd3xyvvqQMbLBCjo;>*v z9YRQ{eEwTW`tO6Kg}^aVuJB({p#T1d)GK29UxEPoZy%{|8F`#8;~Gu_<0v>gHva!O zcJJJ!O$aGTFy1cs`u5DwpTF37`*EnhR$Pojz1z$0L*=}5$>uHJ?fGF`KRojG(}%}p#wHDxidNQqtd)N#ag`&kn*n$^;d9A{apeoI zAhczC!T19~h;@a?D``lGKpgFW*UREe8tcHqFg|opT>P+vl)<6R#)l3`N=`_L8=er_EVNCtrWvX6NpZdYnM$_R*DY0x= zMp9}@Y-Va&$ME54uF$a1#Pqna8S!z+OipI{$b_)aL5b;j%E>blGMNT!U{Xrj$jn%= zWn?Y!L*r6X5|T4wljAZnL$kBV->Lk+Tlc?D`3z=^e>e1h8}eNFe>e1hoASex2LHQ( z|NGpADgW;V{_j(s2ep4U^naV%aOMBq(D#oB3(ZVtjAJ802j2SP2p#x8a1kq-!oQ3_ zVb4$GYl^(*3I5Md$}_3Ad4BTiy+SL_f)9nyf9vy;`)~pQ$M>t}Z$P~ACY-;!=vxs2 zMFW!ktocWSzv`dD?1%ISzW@IBE%vLR+t};Z`a&FM z!H2>}wbyjn^Z&-7p!bJ)69`aKYm&bK@hWJ1{?SF>iV!G5pa_8?1d0$SLg2pPHKGb%FD;dHyAU11U75n&!@bhvw9Sbu)53~)Ofd7XTKBRpatV>vh}BQrG} zc9D#4lZf!hC{L8j9p#LQaCkgX?nvOw3FQBbCOtJ%D$hhY-Cjo&|3Atb72)!7P1tF) z%j=Acj`XsQNDq_ciE>6pFpgetM1;rT^)N1MKbw+~jEtJN-R^Lg+v|;TMMg(NdR?xt zCXvyuaF>(KF52ylh>nbkVr`CS#y`>#?cq9lqTLSn!0a%9&=CM($HFj-a(Kf%t|+Fd z#}nalMSBzw)e(1!qSNVM)Af3ofm~j9w8M$;;W0mq4EKmgPh^yn&)gH~iFQUsM}_fL zM}*TG?RIQavpi)qDco{cOq{|Z>%{+vy2?jqS5BGQ@94=?HJJRir;LtiFqQYH_6)4>7<`XV0l z`y>tr%{L<9`j?!{En?Y7DXA_(XJ?u!^x)Z2xm~S z7|Z;bEqWJ=lYlD=YAljRL`1W7qzE0Fvma`^N+Y=7^7Iwaz=Yu z@Qv`gB0YupBE!+^WO0-&TowS>ti0S`!rjhj7HIj!hqbV{$=jTaAsgcLMn*UoG(m7= z(_)L(!$PW)g$RZx2h!yXcSc6Cz{q`(IeeIlE$K)H3!ol13*!+k58~+Y!;$?&r!Z$^ zI&&16VVO}K(NW@p>W`v2;?7ZIR}3B-nf?(`E(bax9`ie6M7V=RMQ~dWb36yT=<&el zh#Ae=T}+DCi!gfg!&s;{Mltt~iimJIm|Y{7-;2rmx6sQASe6Aw^a}v@sS~eGCLb~ zLGh{SgA>x(Psd{6@gd6^7!x}@E{$W&wwR$_=B;d>6zX9|c8L#lvm?7ih4SB+u_(*V z!&qB0))&UcG-G4K7)>)qo1J|(d8@R@OW;rN$`E`o$@2mhu-pQCjx|1XR6;uc0UPHN z&aNViN5-&(}A1$Fn_WdP00=0{a1-_%R#%bz9m{KX>RZM>pZ5DI@E9hc2Per66Qf3^@!#V0cMQxP$dDf& zs;J<-tqQJu8nFFUGe>YXQMYB|;6G8hHIv)$iM~CbJ7vQK-UivoJLbboa2q}hygQqg zgGsy%A2#bdAkWZE=HeV!?d)$mGl<8FoYrN=Y4@AfgFhMJ&FI59KFrJaIVj z%!4O+Ee;1ZMdOJJPaJuy@&-BbI-_t@B%XNj#1Y9N3)A$LUh?TcXZAB<78XY*Bqa{z z!7LmE$j%;dNy)K1cn9xpA2%{3erRlR!jQ}Y$8Z)slK8eJI1doP19R~1n7H)VZUy%C z=+U9TKB-(f%&Q>w-G`=*Oy*a(+YjrJJ~D+rLkX@UOo3n8f_Jx1h~qcB!8oU&H6h3wdB(;3)Y+ty}N)9)yg+>TFX` zh0)#VThw6hamPx{@?O#wplLALQArs|gOU?u7IguHthpxBpHHn46C$QqhO}+fv3m?_ z>(Q+zKd_GY)Zy&Xp2@Nl#iuf!6mYyVuAKNb?>+mp65Kh&97^~EzV|HXJ!8i(q{x4q z-XZ^Yv|pZp{M|+0iV!G5pa_8?1d0$SLg0S`0mFWBB4tydYaJ6Q?0HPz^&H+}O`2$k z7W$8G4BleT9+D$k$T2#Ow-D>h1lID$@x`Uy`DG$~>C0McGnWE2>stB>P>}~)Nj~JB z>sg#iJ5!AFX%aW9S_wwK>Vx?3xIf3U9`B6mJM~u$y8Xhcm{M1HOM{k!HalOONS74f zUDg_q&8gWZGzT~v-gZ(S{gqsy^ODb`EF*dc$Rf4RWDA{*o)bmOjd^m6b;>6+9DPdi zvHDkpc^R`$OEk!~Ss>L~GD@~k@27%Kt^K)@bG=R=ou0frNwk=E4;xFFqIA_2#QLkLttZE-mqz9kLd8YDvKuw5Ndx?<#WlOZz0pS~lev?zN zkA5|ldjO%|E$0^8mGX7d5ayBFAEZIYI@W(YEA6+LncRS?l&8s=KEPsRPhA zj6Q3C4|(@J#IWvT!pcX`H%NRESWq8)8Bsm35~yhDdnA?i*wa9z*VrpttlxhWW2WT! zM#VP-&S$ElZ`2wJoUI>gNZ%l3Q^dUDr%OJx^GDztv{)G_-3Ut*vpCW~Qt7)k!`$sr zD}SbLcTz!S8=!Z+yN5VUFW3E&%gtrgsYOENnh37u@v_LoE2 zodVzZ;-2L#J~bxNkKV6+-5x7_!&ox`=A|tiEcs|7CW2Jskwa1z9rFz!rtLZ{eZwmE zB(#td7&q*G!Z(Z$zmt57x5|ill|kRI?iPqyo_L}7Lu>Lf^!e#RA6lmb>a(t^Ky__8qT#Cy zUS|#6h-2P!i`Ct|q-?P#yb02+of{FK`dz^5$bp5Dk7-U4I@dmm9$spY;IpVC`r%1? zH0WIIDehxi-fnnq*G_{LYdb1Fskz`stdLV+FTK-rQ&b#(5Y|Fc;hpOf<4wQ|MLD31|9#=k^WDG13wmXR}0SC@B^IB z!-Gz5HjZb}BTN7ECZ2_LPl(od3CN=FcajmC{`4(ahm<@k_0i3`z|lU^Mz)adM+Ne| z+e&=&z4HZP8!sWw*A56P_h>6orQ@8!Hy#Jy@T`t6pR;5X-F#tYY)Dcr;6#iu3xLX+ z_$>N{E`5Wxk$kLsKLc50;%Vs#bap2A2K^d+!-^2TVRn91QW;}E6`Y^FEIDg+x`A&H z=Mbqgsa-`}t(tCf|J5Gw!k%R)R40mM>5}t{QXGP8eZue_L&=iJe`JDG}|=9+-pzfHvMMh zGXuX~oJGgYdTdwBdOn+!vEFf`{+vLK?wJYCZRIUB#^~eBG=o-|*uKMw2GTc-1AE}i zd~+J-Q*-O2(&@}MrgS(FHu9A84P*KwkY(&^kG?U)78uq$C1;^ss0Vzk*hg`um7f6Z z4d=5G4egbeh25)thjV{hM(oyz2I=`NpkZuj&-HPfD;s)!u|dbQiv0BX&Jqo64%>?` z49TmAhV{5WgYE+~SeNyFJZMN}o#K(dX_@)X!|zO><0rhkX446d#x7s0H@dR5toLkA zqPrs+B~Oa8pH^_`=YRfiN84wz=wv(V-f6!}G)UP5fyVg+i3VvA0xcH3DCcFot^ptO zpNNJv;B#0hde0d2mYrgy&^q6v*Y(;0IvbB-^fR->O0gyk<{tjgvfLvvuZ*I_YG!=+ z=ipp63-+PTXT?l+Oh(RK#7d!++elxv?-Nnf?2jwO$Q~i0=y+TyM(ppPp)qzL`uxg6 zLZ9ZpOU`EPF0fMc9}-aK^;N}6nSoK%>LON()(WF2tvna-(W)(zeDoh%r&6Z75mr~W zSkJZ{$y@B-9oSOqWCm}cZ{EaNlz$g=Hs)H>7#i%O{m~+Yk0Ec&o+?{@{6Ww3{+Ykb|Kh%bJ@uol#wWq;j3b2S?BfUp36>S7j>Ww>!eurs~IOwAhs;(OyPH z$oiCHpL@G!(v*4Uf86*6Z<)CB^j$5kkD_&A+LpYUI*GNgkA8pvhBT!v*RW30Czt15 z8$*v;ZI6{`K1GbNx77vdbnGCUSN|A+#)pW87T*rWkZr9Hjb7VfUiJxsv2M$9VZn%9 z5)CW1C7@wdnStp&NF}&TQvF%?PL7TRz^-9C{ zxXz#dnU8>FKbIQ zj8{Sc4Rdp-)X-k{0Qe6XvP)`cY%V6y7#c6p(B2Yg=#w=$cXL1tj4|p_wCVX}FfaWi z`i6dWkMNC&8>B3=aV6m!jyZ@%-`zlEg#Lgxz)Vu|&J3 zrNh}bdoda>zX8&%-zFnI%ZxNW#^~}s=M!-~WBlpe$#i8&$F$bZ=kgxL$5UcBKQmt5 ziWWQ^5BeAmy!SL8%>{kP=(bWrd+aM>OoET9X2vE`UZlwCXjX-1l zuZYIVC1T9J9>N(10*PZ^C=e}%#r>+(X*`TfT)?mtyJjk;DovEt)S5)Es4eLzFI zj%XN_#9AhFibR8SodjsmPJqUd%=Z%i+@DN0zty30%-~#x!PLRf8GSAw8hLzA(qG7#LLd>*_g`ppQlo4ZF@X7^6i;N#8JbofK%)YcJ8zo0SGCt12?pu&|f_%Bb&Wr2JE%;IhSMv&`Y0(>$3IsE*#Tmj<|k(ap7s^8@<1X+}h>Q9NH@N zYSVv?m#dO2c~k7!Qe2V`@v>b3VevkE7mu0jW7Hfbt}@To<9yi1`Y2883aoB?Ot~I^ zxgF=S=&$ciY7w8sIrpDFscN6t3AFWx4@9hZcp7VAAG2O*uz__b9iw{pVW6S$%Ookw z{9!fB!l+#zWmVV@E!wNOIMdl11fMHYo)P+(^VqJENic>cNItZ3NG@+7i*qF(d#S-~ zNsNePa*S2|6QHs`9yCMnq0hAfoV9JCl8Tlvayga0uo!YqTL*mX)tBTLeDpC;(b4g! zb9jWH>Yr!#!8xEI8Hi|@@d6F=>~F}~tSHdPs)lI92{a5hqM?n*5u8UZlT`X9v7@$T z;T2zB_XNl?=O^$nt6aOcz8srND?FXGHL;6CgBD|Vylh@%SDvqCv!_L@!Fb!2tAN&$ zPcN@ukp4qkpA$ZI=v_I+`gAtE*8-IrRFlLDjn9O0hm&1uI4KI&r1s zij_i-y(=wfKQC5_buS`DGQ~>K--eY^?|Sx{y0Y<7oELdWtQ6asAy*2y+8H#owaF-J_F6HE`&<%3^Q$+Yg*ki3kWb26DL+lf<;ZIr(KqNs z;TzT^K$Z28-^EHHZd@tkg76J841L4MN)wz@@_ge!CD7S;DbMH6cFm=XolzZpW7qMg zKI>a1m+oyjDCbfqiJ{#|@S&qO<1AX-1)Oc)Vf2m78DLX;3r11WO85rpw+F}UnhVHV zpFkA-d7I|A)*^@2zvtC5Epjl5zBErn(VOTS)~J~fMXhP*8+2~6h@$r(Vs<;>8)k@; zYxrKT)oIti&7wE!zIo=uiJbE{QyTwvU!MuINsAei)?*YUm#PCAX16@wcwa|*8fT8tiTY2p^YYY;|kHx2IMuT)1LnnWaJCHiXK)SuT zH*e{hxc|UJwkPfqJN}~%eFZfk_Yt7d8a9$wbA2O$KH~v+YV=!9+pw?L9r)$)%|=Z4 zk>58D{CN7QpO|lSE)m~#gu%VIU-7kD`W&4^pH10)PHPG>+5T{l(1*MY`-DH}fnsH% zC(`dr?^_Td;xrp`^-~y47k?!8Q6%Odw3wl&kMW7{dHvMqa$a`7SH$l9A$&q*JYtCT z-r|}ZLz|uuPdtnt<-80{hcTqVUHsa8^hS@IgpOIX-^SWcHO-MalhNbEsvO!2sb0<& z8m^y;$Y(twIRAAT_r`2D&QKt9V({?#8iTVo!?mXa?6XMJ>@tbR*i z3_WlL$GonAE$EHuDC@UqF^iPvard=gelAE)Z!4*c;+a5YH7t!|-i?PbWb_vjAH6wy z(m~i~eJ|cpap=gMJ=lCwn|2F1-xH`=-_l~alpV-%wucPiC-#tH(5Ec~-|)4*Uamg# zjU!jDzV*&|j>?V!BRr(?B>LpavBo14_-vA@ZMyH;my_u4zrMHn+r~mBA@wKcvOZ!w z4ZcCLPHn94=tNq6M43kSi5_+pckYJK_VIJL$A0oNXfeI$8>AI`^2_*WNArCCHB%se zX{4N~^-B&QW=y_WG(q81@zr?lRr<$^bzQvxMGM(+Dnv%-)4gkCBa~q)D zr#vZUy2y*jPZ7RB9A{Bh+ygLo`dXfEd|n4;p*+$7+1g`Dc#GXZ@bNaESRv@t+;N=0=n==&N9HhTdM$Cr!#zppKg^jVSkEL%I2Swb_pG05PL^<~LwL@;Lei$rva6XGR(B8%B)wYONkw^y!o5w?8>d z>TFyb4;HjrMIjpNg>P7kr%L3>RpA>NMc>e73*Ruy+>5LG?i}I02O6Q3D-HzC+9x%T z&qwTOA)A(|JtSpW4NT!bv1gEKp@<6nh0@;BF}Pg~o1Dt+$K z^*+7paz6JYbsm0i(qQ&er6nsoS&7LfVkVn6-v1iZsxzxM=gy-@K zxG!XWJmV-(k-PdxDlLWWe;6MvblzCW`9pTy;w=<4)F<^5-n((Lq%!ti6D?;qj1*M* z*0H&q%KEv7)R3P15p<@Rzss4D_3Zk`#|+paTeL<`=kgYT>OutweedBZC8|IlO(8|l^g0~nwom}VFA8g-f)sK_uKYYax+_hc$hUQ%l ze5|w8B$fG!@C~wIE|-4g+mx=;?WuIbcMEgV&vBa`f8_4WZHY>!Sr5_aJw&#^c>A%&vX{0f4P6?ADqwEYnE+J zm^FoV?0E3ymNj92M+k}SEcDsP-q?VZsqcnp30)v(O1*ahs#c{+nPLoC+E?Jb<`P=x z#df(&1~IFylPzSQ@D2M>@Qta*9=>_tb@{L4Q-hIdZEY3Mko$}?qnSV^C zug)FND*gty``~U_yT9%`nz5$IUn}1K4N#F?5HJFmS~VSg>R7BY3Lh^g>Psr z=g65_j|ktOEzmc#4MPOZ8}fYP=k~drGwrZQTA8Lx-&i}6^Qr&q)j#gLnoHyD$JRH* zxIs6Z1}ahkeM7HR8d_|MzG3`u72*b+zYsNir!Ux)e7*@i>hx7;(Z=Ta#ysI0#*VJi zrt~lNTLGq_(Z32tq&L`a1z3w!dJ}SfSi}u`f(z}QE+bnNx8!4Oyl)<-vKz+9J2G;* z)?D7APht08!luV(K#Q@kt9*t=&EMwmG1`?UaMx7A1$>P6@X3HxeUhNs8iIUE#|WxM z`0U-fF(1Z|h7aIgYfTm4qiuRqw%GR%f);Z35-xqn=MBo#pFWpvS@-VhSI@w@XL^)& z0w1e=OZoIb`=+00Ia`-g&2#=%V$?^oDVf&%NUKLAmG#hLV$5^j$axWM^ejFv>&iC@P(6eF|d*0)Bi0j{Pk<$8unY7e+$Atmw_>+JB8Gah5 z$iNknkJ+UdP#JsoNIu%ONN6!nXY%{X1__>icJeIx=~4S18DC7NgzUTe_%pw5ok?fb zjs4`t34SL&bVJ3C-;9_|^$)l7+uKauo73i-L1+7|Jt*tNy@0$Ce^lPX(%r2=mYvl| zN~h}@0TubU8a@^2d_N#>^=4NI@B}kU_=d4ASGHK2X96GELYEj?pZzXm-Pa8*7}{L; z#)l`6>Qoj`S>IGfoDa7JDl0k9Hzo+q*5W#nkKw%=_?Ry&l`YoF0fKW0lS^M4T4wv~ z?78&&khL9lqi>MUhXQBv>;n=Fnj(|q1!vGYT9+P^bNC-j2n$lOW!c!gm2Iv z(KpC9E|@8;1M|Y%u=hb(LJH3f(?j^s>LvM;t)kp8ju`(0%Q(wmIr|B$uL_JP_)oCh zx7*~Dy)jy_YR)8c6Tl~UPMP{|sD7kEOPn{omkDAkFr1k{+sJU%s`TK#J zq2##HsIJ1MO0Jvz(rnP0C^>KDK~w6c=G1<|rUkAAHP_C1W2&&Jl5=PM@dB)cfZRJKN2}ZK zIG|IY95$$c)yRJDD&B+T<>VP(-}?>crsn2ZdrVj*L2~qVjs|N}C0CCwsw^ybMC4_e zc{zL5l~kCin!BehuPUO8lEY_h-5@Nd=JK^TB6^gZKE3)qLXMK#XKye;H%-a$qvr<$ zjhgGHU#kvlLCyKoJoSYiD7k;8W4rKpB?r*fy9)1BaslmSJ0XUsIf3S{SHKEtZlIk! z2ry7{1og?=gcX!rK~jC5z)i^+q>Uefs2n7B@bU|wwVFd{Z!RY0lRuY`d6Rv;0mv!X z6V=>8JMC4_AV`kkyydVK%>22A+-nn;{>y7Y%|W!fWeSg1auIcN;n#eXs5yy7>823H z)Z9e!K{{9=NRHz7?}=5fy7lBu)Ig9!SS0GjyO75cF^C{@DlpIF;uIs|uN-m?; zd>i;%ketS)2SGQgyT8q1iPmo*@wFmzzS*>qV{nj%#_B zl7DJ&Ezy(=Sff^DxNcJ=6IOrzSwW*@#FCEvfkw@YHUHcKh^jfER^_9Bo0=PHb`)`k zDmkKd!*T+h{JEmc3f9|I3bTTeJ8I=z646r4A^mkc$We1i^*wzdx~Ms&dh|}fTFouh z4|fsQZ6(K4e{?0#sJW(=F;J`$CFj)cS_?D?l6zYEYq(|^N)D`-mSbrF%2EK&2NeP9JOThv@u6FfxC8r7CB2ESIbN3{!EAc_$s zi`34U3N&gqDQ!?s?8TImed}SduTrv1jpU(#sI6oo+vXIokD8IJ4QvmH zQYAB)PB;l#2V^L7E42Tl5G$w|%d|{Y&`r%;Hivf;5h;HLGm~T7GwiE>^{&mxpV`c~ z*&DYNa&=ZSovFLMSU*a}vwq-t=utDDjUJc9J%*a|J1bG_29?}j^7Av`qn45bOo|buq*x09xzS7~^Xb)ps}%xrs5vUtj}{UW)tqX3(+4oqAi333`v|O+9BcCCcW}ip zlw51_WeIVAnm^~7>833%1y?3^TEMDTbFj^3zlbO3N-nlGs5y8D&7YIa$?4V` zH567*a=VS<%fSk2j<@DqA?}Rw=X!IUs^43Php4&V_OBb9P3~g zK$M*pVC{l?)EseR*A4L;Ny!!08xIt5RLL2)MhyK7tf1tM(?f0%m6aTFW9lTa>eXEG zaqVEHL2}9`ZWC)k$t^c)^#v;guYwhVBs?&Ox%c-`oNAvDIwuxRzqqsbqC)*Iole)$DG4`A5KwDp}s7&l<5#m27YQ!DR3% zHS3$$gJ2g;mF#a?V?XRE)GTnj>zCkFmXZxlwhn>V8;}*w&^LSivyeTxnk7z3^cGK? zl}Z9egTBzCRus^TT(P@VDhrr3hrvwk{1pbcXT&W9zhs&9p;T5d`!o=Dhf0M7BPvYX zmnfAMtYKe?XsJ|OusS4)c%W2Xu%}-X`!}TmgFU|!V4zlFFqb_HG-^c#WA|dwlfN

1-ZZ51JwRQom@s)@!N(}>h z^-}O_wU&XosUld8Dm4wP#Lt9oN^Jw9#&&^$QsaR36o>|?bue4L{a0!p*oy`Wtd-ga zy03+>Y5p1r+zQ_=gSAkw71Wvt*3(3+dZjjk9otr5pw>u8eofFQwGxch*WhYK^VdvZ z3hA$(6Ep=zB%;(%(7bbiMy;h_EEI2fg49$nRzgGysI9;>Fz2-_uZRM>MdrFMfE_YhbipoRlurX3t3o+A}((|@5SpvD7}LywItWPh#Jd@%o5AZV1@ z59Y3R0ynh=#4ZAHRILS})zlz{s5K$T#sqOstJH?j|9TE$RX~jhHXoz8c&Eos3)&+o zH6zH8Tv&~UQai%v)=)$-rG^B(91puROQ|KHZ`&_0P-;r(@vn+KtWsOTYU>a)RccIF zqgRVxg(ri+l|3Sg z<*&xTpf`IB5!ZzRt)NzCuy0-wR#2)m82v>=3R0~h?Os4st=d4kt%F#lR&Sske-Kej zsp3H1t{@OqsyXN}cf;DY)v69-#IH>(wYtNC9PoHcsq$cs&lQL&l|b~gWS|LB5v0q5 z5Z|a$8N?WS3Zhs*g%G|L{wgWrXn|HxD~6b!Ln2z{uN=a-X-n6`T41LIt*n2gC!n$j zlhe3QA#pTFrI80UfsRsf#3(f!R-;;Z#9AK@*LtHyngdQIU9->yAu<9KL z?+vIv!Ep9W_6W-< zl~AinW{BT zjBjU%{issw#H!y}>|OHLJmI)Kz7`NI7#+0+igmORtW&iXihkFZ!U{@F6#L`y;>x7d zMj>Cd01r`Xq!_*CLyuZ3#U4==W~$aqG3q}o{7tEyqSxyMJ!%aV>aziBwU&z2<72U_ zQL2d0hfj%_D%C`^j0)nqtyC4EP1XvU{MAJm^yb?a3$a3wY9l*c;##6q9ihL9Cm3q= z5o3;cKB_5INQ|1hAsz&&M)KrYaa~ZVl90*`t-S_l zOr>gx`9pbO1+}ux&WQqRrNWIlJ_>9~lu9?|cZuLTR{n}N%nJ6_(S_W(sughb>feBF z0hMqVH~k^Euy%p*t>88HxAiC$bgXix;VPq6(jh;V10Pju$XIQ@6IM`a$ruaviFK;f zl+i0R0pAIzEyE^myyz=rr>fSPF}r;Td!&GxGo0Ij(EmrQpw_4{#+MaO(3M&>B&4ah zE+{o?bni8Yy=v_mWB+oPsanH^tdjXMYAu`fPPi@v)U;u^*)&}2#R~LZwZ@H^AohuB zts8CDCeTf-d87ULxbP6A_Km$N9wG%%YT(d`U&8*)QflFtE6agbsWowoXNQaHf>Ikt zZ+r>lP^CtWR=yqRW+=6C=#kdoZ)(jPy=i-K$D-8EF;=w&L~W&pj-iR)$ONgSb3GDP z7Ex;I8126k?>>~;I@DPK)~T)5*oof*)>don7}xiM9IDpbX+0Eh3#h%rJ>-uwBElBv zA!;oiW6fZ=UkIql!`Er?v4Ax@Eoe8W)abFAir;Yssnrwnn7B47HGA|$Wx+!v7P1!9nm+os#Q{;Zwhw(v{AN3##t)N2dwpKW-EP6?6x^fM0Mhz@ z0l(id)mlIYazxxwY62M@>xyS6N^Kyc`ChSpl&Uz!52az15T!bf{@Xe5QMF2r`C>od zrsc1e!=`6{w!IK5sMT|5v0uOnY84&(LFFmYM5KUXKC!TfT;ZM?NO@?>1Qv1SE<#8^lRv>oUYGo^H)yk^7e;HWq7?L{-L*MGsmX!T1lcEI)nbH#9BqofHGuAdfae-p?3wPxYfvMQEtYt zgye+G)D-Nkj3)&gIc*8((@17|T>LP0eop~ov6;^kX@;by4v$SsPZ*pOpP9g5x)%o( zaP-Fr>qyltU@SIQ)*i@{SECH6k;6LDQj=0LnPFm?iPAIKzs!q+3pn~?lt-J80)}l; z01Y<#(*j0GDT%RCc{dI!;22Ez+tk2Xw@dKhuE|{xKEZ2Xv46oL;IpZKf?W1~-ZJMD$Hu_pEfuUBB)BEf$`Sifrt++$yRf$#( z-a~}=!^@?ks$W$f=|dGCM=9)xHxT&l|E2msA zLf(d|a6(n0m2rqNvVDdR z37+#K%)2oNl!&TuRaJ>rPE};$Psk;rDx6T2XysIsgdw+_TD0IC_eTu(p^D2<3OnM> zXFfvzIaYoz(}x6`{0OV#xxYkCh#a}9szfWNMleF&hI8PAszfWNrX&o>3k9+P=a^n< zfDcvNE&KQ(?kZTbt z`$O{Lq-;QmQ+f~ap`;2SqVt8?6@2xOaUdabSd^)Mewsx2k* z5ea#dA{R!-w!gq$(XffK3{ z88b;3l6TpX4XEm-;=_EXVqZU1aYK=hkoP0u{NAxXBpB*P7%m0ex=_jnRE4XmO0;sC z1|#H*Q58<8O0;q!Bw+|{^!RswRlCOy_o0fT{Zw@XBjjy3$Czh^`jB9%A7P{4Q>p^6**RK*9A`3U{z zxZ08ILxNp?gw3QxzwL?e9OTGVRVA{e7L1U$;T$-jD)H9lH>3)3U`KydXTLYdhbkWQ zQ}q!!J|7`(LvYr|r23G+_9JX53G@B)3%noq;apW!qLt$dM#$Sx6;7y1v~s8<49VLK z$_7;R8%LTCRs6+I)mFg>c|Q_1+#K&if*XE>k4i$~FOl~nN3Nr<8<^jYkG`JZLlw&xyEP~#7$I-NIrix}!iNO+`w_N99)S{36+TB*qLtGw z7$I*%RXCw4(aLEr2}5q}h(!ykdc6H`AF9~cPt^{=2>tDNvHM^j5;XH8d`uGh?ZHJ0 zs=`%OCEmI<%ty#2qAHwFm1yO3L{;-{2Uvr@swMUh@}Y_y{8a4}jF7h>IDIpceMr#9 zkFc{OB>ob48*=2TsuHc7F2M-7syGKus7kbQx=O;3ygL`x;6KMTzYp@Eiiv)zJ`s$N zx8WTBd~>)D2}b!5c9Vp}Um|Zqj$Bn$qLtG<7$H{`=fDY7i59D$UJ|Bc^1ky>#EFw) zAMSO_OMZL|%EG=d*@usR`52UieGzgDKFmbCc{BF86FG`NV$9H@T%ECT; z=!B2I&~BU*`zZgyKKL+)k3m_Kcd$b#6>}xN_;NuIlKX1mqj&cm%jPo%l U3;VjtF?ch}$Dl0i>mkSde;(YT_W%F@ literal 0 HcmV?d00001 diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index 790f29946..529087088 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -278,3 +278,11 @@ def multiview_min_session_labels(): "tests/data/cameras/minimal_session/min_session.slp", video_search=["tests/data/videos/"], ) + + +@pytest.fixture +def multiview_min_session_user_labels(): + return Labels.load_file( + "tests/data/cameras/minimal_session/min_session_user_labeled.slp", + video_search=["tests/data/videos/"], + ) diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index c6a9a279a..78219c52c 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -954,468 +954,3 @@ def test_AddSession( assert len(labels.sessions) == 2 assert context.state["session"] is session assert labels.sessions[1] is not session - - -def test_triangulate_session_get_all_views_at_frame( - multiview_min_session_labels: Labels, -): - labels = multiview_min_session_labels - session = labels.sessions[0] - lf = labels.labeled_frames[0] - frame_idx = lf.frame_idx - - # Test with no cams_to_include, expect views from all linked cameras - views = TriangulateSession.get_all_views_at_frame(session, frame_idx) - assert len(views) == len(session.linked_cameras) - for cam in session.linked_cameras: - assert views[cam].frame_idx == frame_idx - assert views[cam].video == session[cam] - - # Test with cams_to_include, expect views from only those cameras - cams_to_include = session.linked_cameras[0:2] - views = TriangulateSession.get_all_views_at_frame( - session, frame_idx, cams_to_include=cams_to_include - ) - assert len(views) == len(cams_to_include) - for cam in cams_to_include: - assert views[cam].frame_idx == frame_idx - assert views[cam].video == session[cam] - - -def test_triangulate_session_get_instances_across_views( - multiview_min_session_labels: Labels, -): - labels = multiview_min_session_labels - session = labels.sessions[0] - - # Test get_instances_across_views - lf: LabeledFrame = labels[0] - track = labels.tracks[0] - instances: Dict[ - Camcorder, Instance - ] = TriangulateSession.get_instances_across_views( - session=session, frame_idx=lf.frame_idx, track=track - ) - assert len(instances) == len(session.videos) - for vid in session.videos: - cam = session[vid] - instances_in_view = instances[cam] - for inst in instances_in_view: - assert inst.frame_idx == lf.frame_idx - assert inst.track == track - assert inst.video == vid - - # Try with excluding cam views - lf: LabeledFrame = labels[2] - track = labels.tracks[1] - cams_to_include = session.linked_cameras[:4] - videos_to_include: Dict[ - Camcorder, Video - ] = session.get_videos_from_selected_cameras(cams_to_include=cams_to_include) - assert len(cams_to_include) == 4 - assert len(videos_to_include) == len(cams_to_include) - instances: Dict[ - Camcorder, Instance - ] = TriangulateSession.get_instances_across_views( - session=session, - frame_idx=lf.frame_idx, - track=track, - cams_to_include=cams_to_include, - ) - assert len(instances) == len( - videos_to_include - ) # May not be true if no instances at that frame - for cam, vid in videos_to_include.items(): - instances_in_view = instances[cam] - for inst in instances_in_view: - assert inst.frame_idx == lf.frame_idx - assert inst.track == track - assert inst.video == vid - - # Try with only a single view - cams_to_include = [session.linked_cameras[0]] - with pytest.raises(ValueError): - instances = TriangulateSession.get_instances_across_views( - session=session, - frame_idx=lf.frame_idx, - cams_to_include=cams_to_include, - track=track, - require_multiple_views=True, - ) - - # Try with multiple views, but not enough instances - track = labels.tracks[1] - cams_to_include = session.linked_cameras[4:6] - with pytest.raises(ValueError): - instances = TriangulateSession.get_instances_across_views( - session=session, - frame_idx=lf.frame_idx, - cams_to_include=cams_to_include, - track=track, - require_multiple_views=True, - ) - - -def test_triangulate_session_get_and_verify_enough_instances( - multiview_min_session_labels: Labels, - caplog, -): - labels = multiview_min_session_labels - session = labels.sessions[0] - lf = labels.labeled_frames[0] - - # Test with no cams_to_include, expect views from all linked cameras - instances = TriangulateSession.get_and_verify_enough_instances( - session=session, frame_idx=lf.frame_idx - ) - instances_in_frame = instances[0] - assert ( - len(instances_in_frame) == 8 - ) # All views should have same number of instances (padded with dummy instance(s)) - for cam in session.linked_cameras: - if cam.name in ["side", "sideL"]: # The views that don't have an instance - continue - instances_in_view = instances_in_frame[cam] - for inst in instances_in_view: - assert inst.frame_idx == lf.frame_idx - assert inst.video == session[cam] - - # Test with cams_to_include, expect views from only those cameras - cams_to_include = session.linked_cameras[-2:] - instances = TriangulateSession.get_and_verify_enough_instances( - session=session, - frame_idx=lf.frame_idx, - cams_to_include=cams_to_include, - ) - instances_in_frame = instances[lf.frame_idx] - assert len(instances_in_frame) == len(cams_to_include) - for cam in cams_to_include: - instances_in_view = instances_in_frame[cam] - for inst in instances_in_view: - assert inst.frame_idx == lf.frame_idx - assert inst.video == session[cam] - - # Test with not enough instances, expect views from only those cameras - cams_to_include = session.linked_cameras[0:2] - cam = cams_to_include[0] - video = session[cam] - lfs = labels.find(video, lf.frame_idx) - lf = lfs[0] - lf.instances = [] - instances = TriangulateSession.get_and_verify_enough_instances( - session=session, - frame_idx=lf.frame_idx, - cams_to_include=cams_to_include, - ) - assert isinstance(instances, bool) - assert not instances - messages = "".join([rec.message for rec in caplog.records]) - assert "No Instances found for" in messages - - -def test_triangulate_session_verify_enough_views( - multiview_min_session_labels: Labels, caplog -): - labels = multiview_min_session_labels - session = labels.sessions[0] - - # Test with enough views - enough_views = TriangulateSession.verify_enough_views( - session=session, show_dialog=False - ) - assert enough_views - messages = "".join([rec.message for rec in caplog.records]) - assert len(messages) == 0 - caplog.clear() - - # Test with not enough views - cams_to_include = [session.linked_cameras[0]] - enough_views = TriangulateSession.verify_enough_views( - session=session, cams_to_include=cams_to_include, show_dialog=False - ) - assert not enough_views - messages = "".join([rec.message for rec in caplog.records]) - assert "One or less cameras available." in messages - - -def test_triangulate_session_verify_views_and_instances( - multiview_min_session_labels: Labels, -): - labels = multiview_min_session_labels - session = labels.sessions[0] - - # Test with enough views and instances - lf = labels.labeled_frames[0] - instance = lf.instances[0] - - context = CommandContext.from_labels(labels) - params = { - "video": session.videos[0], - "session": session, - "frame_idx": lf.frame_idx, - "instance": instance, - "show_dialog": False, - } - enough_views = TriangulateSession.verify_views_and_instances(context, params) - assert enough_views - assert "instances" in params - - # Test with not enough views - cams_to_include = [session.linked_cameras[0]] - params = { - "video": session.videos[0], - "session": session, - "frame_idx": lf.frame_idx, - "instance": instance, - "cams_to_include": cams_to_include, - "show_dialog": False, - } - enough_views = TriangulateSession.verify_views_and_instances(context, params) - assert not enough_views - assert "instances" not in params - - -def test_triangulate_session_calculate_reprojected_points( - multiview_min_session_labels: Labels, -): - """Test `TriangulateSession.calculate_reprojected_points`.""" - - session = multiview_min_session_labels.sessions[0] - lf: LabeledFrame = multiview_min_session_labels[0] - track = multiview_min_session_labels.tracks[0] - instances: Dict[ - Camcorder, Instance - ] = TriangulateSession.get_instances_across_views_multiple_frames( - session=session, frame_inds=[lf.frame_idx], track=track - ) - instances_and_coords = TriangulateSession.calculate_reprojected_points( - session=session, instances=instances - ) - - # Check that we get the same number of instances as input - assert len(instances) == len(instances_and_coords) - - # Check that each instance has the same number of points - for instances_in_frame in instances_and_coords.values(): - for instances_in_view in instances_in_frame.values(): - for inst, inst_coords in instances_in_view: - assert inst_coords.shape[0] == len(inst.skeleton) # (15, 2) - - -def test_triangulate_session_get_instances_matrices( - multiview_min_session_labels: Labels, -): - """Test `TriangulateSession.get_instance_matrices`.""" - labels = multiview_min_session_labels - session = labels.sessions[0] - lf: LabeledFrame = labels[0] - track = labels.tracks[0] - instances: Dict[ - int, Dict[Camcorder, List[Instance]] - ] = TriangulateSession.get_instances_across_views_multiple_frames( - session=session, frame_inds=[lf.frame_idx], track=track - ) - instances_matrices, cams_ordered = TriangulateSession.get_instances_matrices( - instances=instances - ) - - # Verify shape - n_frames = len(instances) - n_views = len(instances[lf.frame_idx]) - assert n_views == len(cams_ordered) - n_tracks = len(instances[lf.frame_idx][cams_ordered[0]]) - assert n_tracks == 1 - n_nodes = len(labels.skeleton) - assert instances_matrices.shape == (n_views, n_frames, n_tracks, n_nodes, 2) - - -def test_triangulate_session_update_instances(multiview_min_session_labels: Labels): - """Test `RecordingSession.update_instances`.""" - - # Test update_instances - session = multiview_min_session_labels.sessions[0] - lf: LabeledFrame = multiview_min_session_labels[0] - track = multiview_min_session_labels.tracks[0] - instances: Dict[ - int, Dict[Camcorder, List[Instance]] - ] = TriangulateSession.get_instances_across_views_multiple_frames( - session=session, - frame_inds=[lf.frame_idx], - track=track, - require_multiple_views=True, - ) - instances_and_coordinates = TriangulateSession.calculate_reprojected_points( - session=session, instances=instances - ) - for instances_in_frame in instances_and_coordinates.values(): - for instances_in_view in instances_in_frame.values(): - for inst, inst_coords in instances_in_view: - assert inst_coords.shape == ( - len(inst.skeleton), - 2, - ) # Nodes, 2 - # Assert coord are different from original - assert not np.array_equal(inst_coords, inst.points_array) - - # Just run for code coverage testing, do not test output here (race condition) - # (see "functional core, imperative shell" pattern) - TriangulateSession.update_instances( - instances_and_coords=instances_and_coordinates[0] - ) - - -def test_triangulate_session_do_action(multiview_min_session_labels: Labels): - """Test `TriangulateSession.do_action`.""" - - labels = multiview_min_session_labels - session = labels.sessions[0] - - # Test with enough views and instances - lf = labels.labeled_frames[0] - instance = lf.instances[0] - - context = CommandContext.from_labels(labels) - params = { - "video": session.videos[0], - "session": session, - "frame_idx": lf.frame_idx, - "instance": instance, - "ask_again": True, - } - TriangulateSession.do_action(context, params) - - # Test with not enough views - cams_to_include = [session.linked_cameras[0]] - params = { - "video": session.videos[0], - "session": session, - "frame_idx": lf.frame_idx, - "instance": instance, - "cams_to_include": cams_to_include, - "ask_again": True, - } - TriangulateSession.do_action(context, params) - - -def test_triangulate_session(multiview_min_session_labels: Labels): - """Test `TriangulateSession`.""" - - labels = multiview_min_session_labels - session = labels.sessions[0] - video = session.videos[0] - lf = labels.labeled_frames[0] - instance = lf.instances[0] - context = CommandContext.from_labels(labels) - - # Test with enough views and instances so we don't get any GUI pop-ups - context.triangulateSession( - frame_idx=lf.frame_idx, - video=video, - instance=instance, - session=session, - ) - - # Test with using state to gather params - context.state["session"] = session - context.state["video"] = video - context.state["instance"] = instance - context.state["frame_idx"] = lf.frame_idx - context.triangulateSession() - - -def test_triangulate_session_get_products_of_instances( - multiview_min_session_labels: Labels, -): - """Test `TriangulateSession.get_products_of_instances`.""" - - labels = multiview_min_session_labels - session = labels.sessions[0] - lf = labels.labeled_frames[0] - selected_instance = lf.instances[0] - - instances = TriangulateSession.get_products_of_instances( - session=session, - frame_idx=lf.frame_idx, - ) - - views = TriangulateSession.get_all_views_at_frame(session, lf.frame_idx) - max_num_instances_in_view = max([len(instances) for instances in views.values()]) - assert len(instances) == max_num_instances_in_view ** len(views) - - for frame_id in instances: - instances_in_frame = instances[frame_id] - for cam in instances_in_frame: - instances_in_view = instances_in_frame[cam] - assert len(instances_in_view) == max_num_instances_in_view - for inst in instances_in_view: - try: - assert inst.frame_idx == selected_instance.frame_idx - assert inst.video == session[cam] - except: - assert inst.frame is None - assert inst.video is None - - -def test_triangulate_session_calculate_error_per_frame( - multiview_min_session_labels: Labels, -): - """Test `TriangulateSession.calculate_error_per_frame`.""" - - labels = multiview_min_session_labels - session = labels.sessions[0] - lf = labels.labeled_frames[0] - - instances = TriangulateSession.get_products_of_instances( - session=session, - frame_idx=lf.frame_idx, - ) - - ( - reprojection_error_per_frame, - instances_and_coords, - ) = TriangulateSession.calculate_error_per_frame( - session=session, instances=instances - ) - - for frame_id in instances.keys(): - assert frame_id in reprojection_error_per_frame - assert isinstance(reprojection_error_per_frame[frame_id], float) - - -def test_triangulate_session_get_instance_grouping( - multiview_min_session_labels: Labels, -): - """Test `TriangulateSession._get_instance_grouping`.""" - - labels = multiview_min_session_labels - session = labels.sessions[0] - lf = labels.labeled_frames[0] - selected_instance = lf.instances[0] - - instances = TriangulateSession.get_products_of_instances( - session=session, - frame_idx=lf.frame_idx, - ) - - ( - reprojection_error_per_frame, - instances_and_coords, - ) = TriangulateSession.calculate_error_per_frame( - session=session, instances=instances - ) - - best_instances, frame_id_min_error = TriangulateSession._get_instance_grouping( - instances=instances, reprojection_error_per_frame=reprojection_error_per_frame - ) - assert len(best_instances) == len(session.camera_cluster) - for instances_in_view in best_instances.values(): - tracks_in_view = set( - [inst.track if inst is not None else "None" for inst in instances_in_view] - ) - assert len(tracks_in_view) == len(instances_in_view) - for inst in instances_in_view: - try: - assert inst.frame_idx == selected_instance.frame_idx - except: - assert inst.frame is None - assert inst.track is None diff --git a/tests/io/test_cameras.py b/tests/io/test_cameras.py index 35ecaa50e..f5b0ba014 100644 --- a/tests/io/test_cameras.py +++ b/tests/io/test_cameras.py @@ -1,12 +1,18 @@ """Module to test functions in `sleap.io.cameras`.""" -from typing import Dict, List +from typing import Dict, List, Tuple, Union import numpy as np import pytest -from sleap.io.cameras import Camcorder, CameraCluster, RecordingSession -from sleap.io.dataset import Instance, LabeledFrame, Labels, LabelsDataCache +from sleap.io.cameras import ( + Camcorder, + CameraCluster, + InstanceGroup, + FrameGroup, + RecordingSession, +) +from sleap.io.dataset import Instance, Labels from sleap.io.video import Video @@ -280,3 +286,164 @@ def test_recording_session_remove_video(multiview_min_session_labels: Labels): session.remove_video(video) assert labels_cache._session_by_video.get(video, None) is None assert video not in session.videos + + +# TODO(LM): Remove after adding method to (de)seralize `InstanceGroup` +def create_instance_group( + labels: Labels, + frame_idx: int, + add_dummy: bool = False, +) -> Union[ + InstanceGroup, Tuple[InstanceGroup, Dict[Camcorder, Instance], Instance, Camcorder] +]: + """Create an `InstanceGroup` from a `Labels` object. + + Args: + labels: The `Labels` object to use. + frame_idx: The frame index to use. + add_dummy: Whether to add a dummy instance to the `InstanceGroup`. + + Returns: + The `InstanceGroup` object. + """ + + session = labels.sessions[0] + + lf = labels.labeled_frames[0] + instance = lf.instances[0] + + instance_by_camera = {} + for cam in session.linked_cameras: + video = session.get_video(cam) + lfs_in_view = labels.find(video=video, frame_idx=frame_idx) + if len(lfs_in_view) > 0: + instance = lfs_in_view[0].instances[0] + instance_by_camera[cam] = instance + + # Add a dummy instance to make sure it gets ignored + if add_dummy: + dummy_instance = Instance.from_numpy( + np.full( + shape=(len(instance.skeleton.nodes), 2), + fill_value=np.nan, + ), + skeleton=instance.skeleton, + ) + instance_by_camera[cam] = dummy_instance + + instance_group = InstanceGroup.from_dict(d=instance_by_camera) + return ( + (instance_group, instance_by_camera, dummy_instance, cam) + if add_dummy + else instance_group + ) + + +def test_instance_group(multiview_min_session_labels: Labels): + """Test `InstanceGroup` data structure.""" + + labels = multiview_min_session_labels + session = labels.sessions[0] + camera_cluster = session.camera_cluster + + lf = labels.labeled_frames[0] + frame_idx = lf.frame_idx + + # Test `from_dict` + instance_group, instance_by_camera, dummy_instance, cam = create_instance_group( + labels=labels, frame_idx=frame_idx, add_dummy=True + ) + assert isinstance(instance_group, InstanceGroup) + assert instance_group.frame_idx == frame_idx + assert instance_group.camera_cluster == camera_cluster + for camera in session.linked_cameras: + if camera == cam: + assert instance_by_camera[camera] == dummy_instance + assert camera not in instance_group.cameras + else: + instance = instance_group[camera] + assert isinstance(instance, Instance) + assert instance_group[camera] == instance_by_camera[camera] + assert instance_group[instance] == camera + + # Test `__repr__` + print(instance_group) + + # Test `__len__` + assert len(instance_group) == len(instance_by_camera) - 1 + + # Test `get_cam` + assert instance_group.get_cam(dummy_instance) is None + + # Test `get_instance` + assert instance_group.get_instance(cam) is None + + # Test `instances` property + assert len(instance_group.instances) == len(instance_by_camera) - 1 + + # Test `cameras` property + assert len(instance_group.cameras) == len(instance_by_camera) - 1 + + # Test `__getitem__` with `int` key + assert isinstance(instance_group[0], Instance) + with pytest.raises(KeyError): + instance_group[len(instance_group)] + + # Test `_dummy_instance` property + assert ( + instance_group.dummy_instance.skeleton == instance_group.instances[0].skeleton + ) + assert isinstance(instance_group.dummy_instance, Instance) + + # Test `numpy` method + instance_group_numpy = instance_group.numpy() + assert isinstance(instance_group_numpy, np.ndarray) + n_views, n_nodes, n_coords = instance_group_numpy.shape + assert n_views == len(instance_group.camera_cluster.cameras) + assert n_nodes == len(instance_group.dummy_instance.skeleton.nodes) + assert n_coords == 2 + + # Test `update_points` method + instance_group.update_points(np.full((n_views, n_nodes, n_coords), 0)) + instance_group_numpy = instance_group.numpy() + np.nan_to_num(instance_group_numpy, nan=0) + assert np.all(np.nan_to_num(instance_group_numpy, nan=0) == 0) + + # Populate with only dummy instance and test `from_dict` + instance_by_camera = {cam: dummy_instance} + instance_group = InstanceGroup.from_dict(d=instance_by_camera) + assert instance_group is None + + +def test_frame_group(multiview_min_session_labels: Labels): + """Test `FrameGroup` data structure.""" + + labels = multiview_min_session_labels + session = labels.sessions[0] + + # Test `from_instance_groups` from list of instance groups + frame_idx_1 = 0 + instance_group = create_instance_group(labels=labels, frame_idx=frame_idx_1) + instance_groups: List[InstanceGroup] = [instance_group] + frame_group_1 = FrameGroup.from_instance_groups( + session=session, instance_groups=instance_groups + ) + assert isinstance(frame_group_1, FrameGroup) + assert session in frame_group_1._frame_idx_registry + assert len(frame_group_1._frame_idx_registry) == 1 + assert frame_group_1._frame_idx_registry[session] == {frame_idx_1} + + # Test `_frame_idx_registry` property + frame_idx_2 = 1 + instance_group = create_instance_group(labels=labels, frame_idx=frame_idx_2) + instance_groups: List[InstanceGroup] = [instance_group] + frame_group_2 = FrameGroup.from_instance_groups( + session=session, instance_groups=instance_groups + ) + assert isinstance(frame_group_2, FrameGroup) + assert session in frame_group_2._frame_idx_registry + assert len(frame_group_2._frame_idx_registry) == 1 + assert frame_group_2._frame_idx_registry[session] == {frame_idx_1, frame_idx_2} + assert frame_group_1._frame_idx_registry == frame_group_2._frame_idx_registry + + # TODO(LM): Test `generate_hypotheses` From 9d619803a92c4a7b71ea8c7ababbebd744ffd96d Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Fri, 12 Apr 2024 13:18:45 -0700 Subject: [PATCH 09/22] Only use user-`Instance`s for triangulation --- sleap/gui/commands.py | 2 +- sleap/io/cameras.py | 31 +++++++++++++++++++++++++++---- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index fbfcb4b81..ed9c377d2 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -3436,7 +3436,7 @@ def do_action(cls, context: CommandContext, params: dict): return # Not enough instances for triangulation # Get the `FrameGroup` of shape M=include x T x N x 2 - fg_tensor = frame_group.numpy(instance_groups=instance_groups) + fg_tensor = frame_group.numpy(instance_groups=instance_groups, pred_as_nan=True) # Add extra dimension for number of frames frame_group_tensor = np.expand_dims(fg_tensor, axis=1) # M=include x F=1 xTxNx2 diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index a9151d93c..d0df2a053 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -483,19 +483,36 @@ def cameras(self) -> List[Camcorder]: """List of `Camcorder` objects.""" return list(self._instance_by_camcorder.keys()) - def numpy(self) -> np.ndarray: + def numpy(self, pred_as_nan: bool = False) -> np.ndarray: """Return instances as a numpy array of shape (n_views, n_nodes, 2). + The ordering of views is based on the ordering of `Camcorder`s in the `self.camera_cluster: CameraCluster`. + If an instance is missing for a `Camcorder`, then the instance is filled in with the dummy instance (all NaNs). + + Args: + pred_as_nan: If True, then replaces `PredictedInstance`s with all nan + self.dummy_instance. Default is False. + Returns: Numpy array of shape (n_views, n_nodes, 2). """ instance_numpys: List[np.ndarray] = [] # len(M) x N x 2 for cam in self.camera_cluster.cameras: - instance = self.get_instance(cam) or self.dummy_instance + instance = self.get_instance(cam) + + # Determine whether to use a dummy (all nan) instance + instance_is_missing = instance is None + instance_as_nan = pred_as_nan and isinstance(instance, PredictedInstance) + use_dummy_instance = instance_is_missing or instance_as_nan + + # Add the dummy instance if the instance is missing + if use_dummy_instance: + instance = self.dummy_instance # This is an all nan PredictedInstance + instance_numpy: np.ndarray = instance.numpy() # N x 2 instance_numpys.append(instance_numpy) @@ -1363,13 +1380,17 @@ def locked_instance_groups(self) -> List[InstanceGroup]: return self._locked_instance_groups def numpy( - self, instance_groups: Optional[List[InstanceGroup]] = None + self, + instance_groups: Optional[List[InstanceGroup]] = None, + pred_as_nan: bool = False, ) -> np.ndarray: """Numpy array of all `InstanceGroup`s in `FrameGroup.cams_to_include`. Args: instance_groups: `InstanceGroup`s to include. Default is None and uses all self.instance_groups. + pred_as_nan: If True, then replaces `PredictedInstance`s with all nan + self.dummy_instance. Default is False. Returns: Numpy array of shape (M, T, N, 2) where M is the number of views (determined @@ -1391,7 +1412,9 @@ def numpy( instance_group_numpys: List[np.ndarray] = [] # len(T) M=all x N x 2 for instance_group in instance_groups: - instance_group_numpy = instance_group.numpy() # M=all x N x 2 + instance_group_numpy = instance_group.numpy( + pred_as_nan=pred_as_nan + ) # M=all x N x 2 instance_group_numpys.append(instance_group_numpy) frame_group_numpy = np.stack(instance_group_numpys, axis=1) # M=all x T x N x 2 From 08eb996baeef5929d5ecaff40a0f0bd9af99f91e Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Fri, 12 Apr 2024 14:12:49 -0700 Subject: [PATCH 10/22] Remove unused code --- sleap/io/cameras.py | 200 +------------------------------------------- 1 file changed, 1 insertion(+), 199 deletions(-) diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index d0df2a053..28bef62c8 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -1,6 +1,5 @@ """Module for storing information for camera groups.""" -from itertools import permutations, product import logging import tempfile from pathlib import Path @@ -1282,13 +1281,6 @@ class FrameGroup: _labeled_frames_by_cam: Dict[Camcorder, "LabeledFrame"] = field(factory=dict) _instances_by_cam: Dict[Camcorder, Set["Instance"]] = field(factory=dict) - # TODO(LM): This dict should be updated each time an InstanceGroup is - # added/removed/locked/unlocked - _locked_instance_groups: List[InstanceGroup] = field(factory=list) - _locked_instances_by_cam: Dict[Camcorder, Set["Instance"]] = field( - factory=dict - ) # Internally updated in `update_locked_instances_by_cam` - def __attrs_post_init__(self): """Initialize `FrameGroup` object.""" @@ -1311,9 +1303,6 @@ def __attrs_post_init__(self): # Initialize `_labeled_frames_by_cam` dictionary self.update_labeled_frames_and_instances_by_cam() - # Initialize `_locked_instance_groups` dictionary - self.update_locked_instance_groups() - # The dummy labeled frame will only be set once for the first `FrameGroup` made if self._dummy_labeled_frame is None: self._dummy_labeled_frame = self.labeled_frames[0] @@ -1367,18 +1356,6 @@ def cameras(self) -> List[Camcorder]: return list(self._labeled_frames_by_cam.keys()) - @property - def instances_by_cam_to_include(self) -> Dict[Camcorder, Set["Instance"]]: - """List of `Camcorder`s.""" - - return {cam: self._instances_by_cam[cam] for cam in self.cams_to_include} - - @property - def locked_instance_groups(self) -> List[InstanceGroup]: - """List of locked `InstanceGroup`s.""" - - return self._locked_instance_groups - def numpy( self, instance_groups: Optional[List[InstanceGroup]] = None, @@ -1545,7 +1522,6 @@ def add_labeled_frame(self, labeled_frame: "LabeledFrame", camera: Camcorder): # Add the `LabeledFrame` to the `FrameGroup` self._labeled_frames_by_cam[camera] = labeled_frame - # TODO(LM): Should this be an EditCommand instead? # Add the `LabeledFrame` to the `RecordingSession`'s `Labels` object if labeled_frame not in self.session.labels: self.session.labels.append(labeled_frame) @@ -1845,7 +1821,7 @@ def update_labeled_frames_and_instances_by_cam( views[cam] = lf # Find instances in frame - insts = lf.find(track=-1, user=True) + insts = lf.find(track=-1) if len(insts) > 0: instances_by_cam[cam] = set(insts) @@ -1862,180 +1838,6 @@ def update_labeled_frames_and_instances_by_cam( else self._labeled_frames_by_cam ) - def update_locked_instance_groups(self) -> List[InstanceGroup]: - """Updates locked `InstanceGroup`s in `FrameGroup`. - - Returns: - List of locked `InstanceGroup`s. - """ - - self._locked_instance_groups: List[InstanceGroup] = [ - instance_group - for instance_group in self.instance_groups - if instance_group.locked - ] - - # Also update locked instances by cam - self.update_locked_instances_by_cam(self._locked_instance_groups) - - return self._locked_instance_groups - - def update_locked_instances_by_cam( - self, locked_instance_groups: List[InstanceGroup] = None - ) -> Dict[Camcorder, Set["Instance"]]: - """Updates locked `Instance`s in `FrameGroup`. - - Args: - locked_instance_groups: List of locked `InstanceGroup`s. Default is None. - If None, then uses `self.locked_instance_groups`. - - Returns: - Dictionary with `Camcorder` key and `Set[Instance]` value. - """ - - if locked_instance_groups is None: - locked_instance_groups = self.locked_instance_groups - - locked_instances_by_cam: Dict[Camcorder, Set["Instance"]] = {} - - # Loop through each camera and append locked instances in specific order - for cam in self.cams_to_include: - locked_instances_by_cam[cam] = set() - for instance_group in locked_instance_groups: - instance = instance_group.get_instance(cam) # Returns None if not found - - # TODO(LM): Should this be adding the dummy instance here? - # LM: No, since just using the number of locked instance groups will - # account for the dummy instances - if instance is not None: - locked_instances_by_cam[cam].add(instance) - - # Only update if there were no errors - self._locked_instances_by_cam = locked_instances_by_cam - return self._locked_instances_by_cam - - # TODO(LM): Should we move this to TriangulateSession? - def generate_hypotheses( - self, as_matrix: bool = True - ) -> Union[np.ndarray, Dict[int, List[InstanceGroup]]]: - """Generates all possible hypotheses from the `FrameGroup`. - - Args: - as_matrix: If True (defualt), then return as a matrix of - `Instance.points_array`. Else return as `Dict[int, List[InstanceGroup]]` - where `int` is the hypothesis identifier and `List[InstanceGroup]` is - the list of `InstanceGroup`s. - - Returns: - Either a `np.ndarray` of shape M x F x T x N x 2 an array if as_matrix where - M: # views, F: # frames = 1, T: # tracks, N: # nodes, 2: x, y - or a dictionary with hypothesis ID key and list of `InstanceGroup`s value. - """ - - # Get all `Instance`s for this frame index across all views to include - instances_by_camera: Dict[ - Camcorder, Set["Instance"] - ] = self.instances_by_cam_to_include - - # Get max number of instances across all views - all_instances_by_camera: List[Set["Instance"]] = instances_by_camera.values() - max_num_instances = max( - [len(instances) for instances in all_instances_by_camera], default=0 - ) - - # Create a dummy instance of all nan values - example_instance: "Instance" = next(iter(all_instances_by_camera[0])) - skeleton: "Skeleton" = example_instance.skeleton - dummy_instance: "Instance" = example_instance.from_numpy( - np.full( - shape=(len(skeleton.nodes), 2), - fill_value=np.nan, - ), - skeleton=skeleton, - ) - - def _fill_in_missing_instances( - unlocked_instances_in_view: List["Instance"], - ): - """Fill in missing instances with dummy instances up to max number. - - Note that this function will mutate the input list in addition to returning - the mutated list. - - Args: - unlocked_instances_in_view: List of instances in a view that are not in - a locked InstanceGroup. - - Returns: - List of instances in a view that are not in a locked InstanceGroup with - dummy instances appended. - """ - - # Subtracting the number of locked instance groups accounts for there being - # dummy instances in the locked instance groups. - num_instances_missing = ( - max_num_instances - - len(unlocked_instances_in_view) - - len( - self.locked_instance_groups - ) # TODO(LM): Make sure this property is getting updated properly - ) - - if num_instances_missing > 0: - # Extend the list of instances with dummy instances - unlocked_instances_in_view.extend( - [dummy_instance] * num_instances_missing - ) - - return unlocked_instances_in_view - - # For each view, get permutations of unlocked instances - unlocked_instance_permutations: Dict[ - Camcorder, Iterator[Tuple["Instance"]] - ] = {} - for cam, instances_in_view in instances_by_camera.items(): - # Gather all instances for this cam from locked `InstanceGroup`s - locked_instances_in_view: Set[ - "Instance" - ] = self._locked_instances_by_cam.get(cam, set()) - - # Remove locked instances from instances in view - unlocked_instances_in_view: List["Instance"] = list( - instances_in_view - locked_instances_in_view - ) - - # Fill in missing instances with dummy instances up to max number - unlocked_instances_in_view = _fill_in_missing_instances( - unlocked_instances_in_view - ) - - # Permuate all `Instance`s in the unlocked `InstanceGroup`s - unlocked_instance_permutations[cam] = permutations( - unlocked_instances_in_view - ) - - # Get products of instances from other views into all possible groupings - # Ordering of dict_values is preserved in Python 3.7+ - products_of_unlocked_instances: Iterator[Iterator[Tuple]] = product( - *unlocked_instance_permutations.values() - ) - - # Reorganize products by cam and add selected instance to each permutation - grouping_hypotheses: Dict[int, List[InstanceGroup]] = {} - for frame_id, prod in enumerate(products_of_unlocked_instances): - grouping_hypotheses[frame_id] = { - # TODO(LM): This is where we would create the `InstanceGroup`s instead - cam: list(inst) - for cam, inst in zip(self.cams_to_include, prod) - } - - # TODO(LM): Should we return this as instance matrices or `InstanceGroup`s? - # Answer: Definitely not instance matrices since we need to keep track of the - # `Instance`s, but I kind of wonder if we could just return a list of - # `InstanceGroup`s instead of a dict then the `InstanceGroup` - - return grouping_hypotheses - @classmethod def from_instance_groups( cls, From 34a4dcd1f4b9c6ad1a1b460636b07a54c2f903d6 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Fri, 12 Apr 2024 14:17:32 -0700 Subject: [PATCH 11/22] Use `LabeledFrame` class instead of dummy labeled frame --- sleap/io/cameras.py | 72 +++++++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 39 deletions(-) diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 28bef62c8..50f0a12a9 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -13,7 +13,7 @@ from attrs.validators import deep_iterable, instance_of # from sleap.io.dataset import Labels # TODO(LM): Circular import, implement Observer -from sleap.instance import PredictedInstance +from sleap.instance import LabeledFrame, Instance, PredictedInstance from sleap.io.video import Video from sleap.util import deep_iterable_converter @@ -408,9 +408,9 @@ class InstanceGroup: frame_idx: int = field(validator=instance_of(int)) camera_cluster: Optional[CameraCluster] = None locked: bool = field(default=False) - _instance_by_camcorder: Dict[Camcorder, "Instance"] = field(factory=dict) - _camcorder_by_instance: Dict["Instance", Camcorder] = field(factory=dict) - _dummy_instance: Optional["Instance"] = field(default=None) + _instance_by_camcorder: Dict[Camcorder, Instance] = field(factory=dict) + _camcorder_by_instance: Dict[Instance, Camcorder] = field(factory=dict) + _dummy_instance: Optional[Instance] = field(default=None) def __attrs_post_init__(self): """Initialize `InstanceGroup` object.""" @@ -423,7 +423,7 @@ def __attrs_post_init__(self): if self._dummy_instance is None: self._create_dummy_instance(instance=instance) - def _create_dummy_instance(self, instance: Optional["Instance"] = None): + def _create_dummy_instance(self, instance: Optional[Instance] = None): """Create a dummy instance to fill in for missing instances. Args: @@ -473,7 +473,7 @@ def dummy_instance(self) -> PredictedInstance: return self._dummy_instance @property - def instances(self) -> List["Instance"]: + def instances(self) -> List[Instance]: """List of `Instance` objects.""" return list(self._instance_by_camcorder.values()) @@ -517,7 +517,7 @@ def numpy(self, pred_as_nan: bool = False) -> np.ndarray: return np.stack(instance_numpys, axis=0) # M x N x 2 - def create_and_add_instance(self, cam: Camcorder, labeled_frame: "LabeledFrame"): + def create_and_add_instance(self, cam: Camcorder, labeled_frame: LabeledFrame): """Create an `Instance` at a labeled_frame and add it to the `InstanceGroup`. Args: @@ -545,7 +545,7 @@ def create_and_add_instance(self, cam: Camcorder, labeled_frame: "LabeledFrame") return instance - def add_instance(self, cam: Camcorder, instance: "Instance"): + def add_instance(self, cam: Camcorder, instance: Instance): """Add an `Instance` to the `InstanceGroup`. Args: @@ -574,7 +574,7 @@ def add_instance(self, cam: Camcorder, instance: "Instance"): # Add the instance to the `InstanceGroup` self.replace_instance(cam, instance) - def replace_instance(self, cam: Camcorder, instance: "Instance"): + def replace_instance(self, cam: Camcorder, instance: Instance): """Replace an `Instance` in the `InstanceGroup`. If the `Instance` is already in the `InstanceGroup`, then it is removed and @@ -599,7 +599,7 @@ def replace_instance(self, cam: Camcorder, instance: "Instance"): self._instance_by_camcorder[cam] = instance self._camcorder_by_instance[instance] = cam - def remove_instance(self, instance_or_cam: Union["Instance", Camcorder]): + def remove_instance(self, instance_or_cam: Union[Instance, Camcorder]): """Remove an `Instance` from the `InstanceGroup`. Args: @@ -643,7 +643,7 @@ def _raise_if_cam_not_in_cluster(self, cam: Camcorder): f"{self.camera_cluster}." ) - def get_instance(self, cam: Camcorder) -> Optional["Instance"]: + def get_instance(self, cam: Camcorder) -> Optional[Instance]: """Retrieve `Instance` linked to `Camcorder`. Args: @@ -663,14 +663,14 @@ def get_instance(self, cam: Camcorder) -> Optional["Instance"]: return self._instance_by_camcorder[cam] - def get_instances(self, cams: List[Camcorder]) -> List["Instance"]: + def get_instances(self, cams: List[Camcorder]) -> List[Instance]: instances = [] for cam in cams: instance = self.get_instance(cam) instances.append(instance) return instance - def get_cam(self, instance: "Instance") -> Optional[Camcorder]: + def get_cam(self, instance: Instance) -> Optional[Camcorder]: """Retrieve `Camcorder` linked to `Instance`. Args: @@ -721,7 +721,7 @@ def update_points( for cam_idx, cam in enumerate(cams_to_include): # Get the instance for the cam - instance: Optional["Instance"] = self.get_instance(cam) + instance: Optional[Instance] = self.get_instance(cam) if instance is None: logger.warning( f"Camcorder {cam.name} not found in this InstanceGroup's instances." @@ -734,8 +734,8 @@ def update_points( ) def __getitem__( - self, idx_or_key: Union[int, Camcorder, "Instance"] - ) -> Union[Camcorder, "Instance"]: + self, idx_or_key: Union[int, Camcorder, Instance] + ) -> Union[Camcorder, Instance]: """Grab a `Camcorder` of `Instance` from the `InstanceGroup`.""" def _raise_key_error(): @@ -753,7 +753,7 @@ def _raise_key_error(): return self.get_instance(idx_or_key) else: - # isinstance(idx_or_key, "Instance"): + # isinstance(idx_or_key, Instance): try: return self.get_cam(idx_or_key) except: @@ -1272,14 +1272,13 @@ class FrameGroup: # "Hidden" class attribute _cams_to_include: Optional[List[Camcorder]] = None _excluded_views: Optional[Tuple[str]] = () - _dummy_labeled_frame: Optional["LabeledFrame"] = None # "Hidden" instance attributes # TODO(LM): This dict should be updated each time a LabeledFrame is added/removed # from the Labels object. Or if a video is added/removed from the RecordingSession. - _labeled_frames_by_cam: Dict[Camcorder, "LabeledFrame"] = field(factory=dict) - _instances_by_cam: Dict[Camcorder, Set["Instance"]] = field(factory=dict) + _labeled_frames_by_cam: Dict[Camcorder, LabeledFrame] = field(factory=dict) + _instances_by_cam: Dict[Camcorder, Set[Instance]] = field(factory=dict) def __attrs_post_init__(self): """Initialize `FrameGroup` object.""" @@ -1303,10 +1302,6 @@ def __attrs_post_init__(self): # Initialize `_labeled_frames_by_cam` dictionary self.update_labeled_frames_and_instances_by_cam() - # The dummy labeled frame will only be set once for the first `FrameGroup` made - if self._dummy_labeled_frame is None: - self._dummy_labeled_frame = self.labeled_frames[0] - @property def cams_to_include(self) -> Optional[List[Camcorder]]: """List of `Camcorder`s to include in this `FrameGroup`.""" @@ -1345,7 +1340,7 @@ def cams_to_include(self, cams_to_include: List[Camcorder]): self._excluded_views = (cam.name for cam in excluded_cams) @property - def labeled_frames(self) -> List["LabeledFrame"]: + def labeled_frames(self) -> List[LabeledFrame]: """List of `LabeledFrame`s.""" return list(self._labeled_frames_by_cam.values()) @@ -1403,7 +1398,7 @@ def numpy( def add_instance( self, - instance: "Instance", + instance: Instance, camera: Camcorder, instance_group: Optional[InstanceGroup] = None, ): @@ -1490,7 +1485,7 @@ def add_instance_group(self, instance_group: Optional[InstanceGroup] = None): # Add the `InstanceGroup` to the `RecordingSession` ... - def get_instance_group(self, instance: "Instance") -> Optional[InstanceGroup]: + def get_instance_group(self, instance: Instance) -> Optional[InstanceGroup]: """Get `InstanceGroup` that contains `Instance` if exists. Otherwise, None. Args: @@ -1511,7 +1506,7 @@ def get_instance_group(self, instance: "Instance") -> Optional[InstanceGroup]: return instance_group - def add_labeled_frame(self, labeled_frame: "LabeledFrame", camera: Camcorder): + def add_labeled_frame(self, labeled_frame: LabeledFrame, camera: Camcorder): """Add a `LabeledFrame` to the `FrameGroup`. Args: @@ -1526,7 +1521,7 @@ def add_labeled_frame(self, labeled_frame: "LabeledFrame", camera: Camcorder): if labeled_frame not in self.session.labels: self.session.labels.append(labeled_frame) - def get_labeled_frame(self, camera: Camcorder) -> Optional["LabeledFrame"]: + def get_labeled_frame(self, camera: Camcorder) -> Optional[LabeledFrame]: """Get `LabeledFrame` for `Camcorder` if exists. Otherwise, None. Args: @@ -1538,7 +1533,7 @@ def get_labeled_frame(self, camera: Camcorder) -> Optional["LabeledFrame"]: return self._labeled_frames_by_cam.get(camera, None) - def create_and_add_labeled_frame(self, camera: Camcorder) -> "LabeledFrame": + def create_and_add_labeled_frame(self, camera: Camcorder) -> LabeledFrame: """Create and add a `LabeledFrame` to the `FrameGroup`. This also adds the `LabeledFrame` to the `RecordingSession`'s `Labels` object. @@ -1558,8 +1553,7 @@ def create_and_add_labeled_frame(self, camera: Camcorder) -> "LabeledFrame": f"RecordingSession {self.session}." ) - # Use _dummy_labeled_frame to access the `LabeledFrame`` class here - labeled_frame = self._dummy_labeled_frame.__class__( + labeled_frame = LabeledFrame( video=video, frame_idx=self.frame_idx ) self.add_labeled_frame(labeled_frame=labeled_frame) @@ -1570,7 +1564,7 @@ def create_and_add_instance( self, instance_group: InstanceGroup, camera: Camcorder, - labeled_frame: "LabeledFrame", + labeled_frame: LabeledFrame, ): """Add an `Instance` to the `InstanceGroup` (and `FrameGroup`). @@ -1674,7 +1668,7 @@ def upsert_points( exclude_complete=exclude_complete, ) - def _raise_if_instance_not_in_instance_group(self, instance: "Instance"): + def _raise_if_instance_not_in_instance_group(self, instance: Instance): """Raise a ValueError if the `Instance` is not in an `InstanceGroup`. Args: @@ -1690,7 +1684,7 @@ def _raise_if_instance_not_in_instance_group(self, instance: "Instance"): f"Instance {instance} is not in an InstanceGroup within the FrameGroup." ) - def _raise_if_instance_incompatibile(self, instance: "Instance", camera: Camcorder): + def _raise_if_instance_incompatibile(self, instance: Instance, camera: Camcorder): """Raise a ValueError if the `Instance` is incompatible with the `FrameGroup`. The `Instance` is incompatible if: @@ -1776,7 +1770,7 @@ def _raise_if_instance_group_not_in_frame_group( def update_labeled_frames_and_instances_by_cam( self, return_instances_by_camera: bool = False - ) -> Union[Dict[Camcorder, "LabeledFrame"], Dict[Camcorder, List["Instance"]]]: + ) -> Union[Dict[Camcorder, LabeledFrame], Dict[Camcorder, List[Instance]]]: """Get all views and `Instance`s across all `RecordingSession`s. Updates the `_labeled_frames_by_cam` and `_instances_by_cam` @@ -1797,11 +1791,11 @@ def update_labeled_frames_and_instances_by_cam( f"\n\t{self._labeled_frames_by_cam}" ) - views: Dict[Camcorder, "LabeledFrame"] = {} - instances_by_cam: Dict[Camcorder, Set["Instance"]] = {} + views: Dict[Camcorder, LabeledFrame] = {} + instances_by_cam: Dict[Camcorder, Set[Instance]] = {} videos = self.session.get_videos_from_selected_cameras() for cam, video in videos.items(): - lfs: List["LabeledFrame"] = self.session.labels.get( + lfs: List[LabeledFrame] = self.session.labels.get( (video, [self.frame_idx]) ) if len(lfs) == 0: From a330c69fb36e6296cbbf8b7a73b63e0599bb5b56 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Mon, 15 Apr 2024 11:59:43 -0700 Subject: [PATCH 12/22] Limit which methods can update `Labels.labeled_frames` --- sleap/gui/dialogs/delete.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/gui/dialogs/delete.py b/sleap/gui/dialogs/delete.py index 7e8d39e6b..a0a281e74 100644 --- a/sleap/gui/dialogs/delete.py +++ b/sleap/gui/dialogs/delete.py @@ -216,7 +216,7 @@ def _delete(self, lf_inst_list: List[Tuple[LabeledFrame, Instance]]): for lf, inst in lf_inst_list: self.context.labels.remove_instance(lf, inst, in_transaction=True) if not lf.instances: - self.context.labels.remove(lf) + self.context.labels.remove_frame(lf=lf, update_cache=False) # Update caches since we skipped doing this after each deletion self.context.labels.update_cache() From e78642ec88ed474c26772dd66deffda0a6660e0b Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Mon, 15 Apr 2024 12:00:48 -0700 Subject: [PATCH 13/22] Update cache when Labels. remove_session_video --- sleap/io/dataset.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 1221cadae..e12c963ba 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -274,6 +274,11 @@ def remove_video(self, video: Video): del self._lf_by_video[video] if video in self._frame_idx_map: del self._frame_idx_map[video] + self.remove_session_video(video=video) + + def remove_session_video(self, video: Video): + """Remove video from session in cache.""" + if video in self._session_by_video: del self._session_by_video[video] @@ -442,8 +447,7 @@ def _del_count_cache(self, video, video_idx, frame_idx, type_key: str): @attr.s(auto_attribs=True, repr=False, str=False) class Labels(MutableSequence): - """ - The :class:`Labels` class collects the data for a SLEAP project. + """The :class:`Labels` class collects the data for a SLEAP project. This class is front-end for all interactions with loading, writing, and modifying these labels. The actual storage backend for the data @@ -1657,6 +1661,7 @@ def remove_video(self, video: Video): # Delete video self.videos.remove(video) + self.remove_session_video(video) self._cache.remove_video(video) def add_session(self, session: RecordingSession): @@ -1710,9 +1715,9 @@ def remove_session_video(self, session: RecordingSession, video: Video): video: `Video` instance """ - self._cache._session_by_video.pop(video, None) if video in session.videos: session.remove_video(video) + self._cache.remove_session_video(video) @classmethod def from_json(cls, *args, **kwargs): From 712319c9f74e0808319b0e6f7c6bd04ee1ae53f7 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Mon, 15 Apr 2024 13:33:59 -0700 Subject: [PATCH 14/22] Remove RecordingSession.instance_groups --- sleap/io/cameras.py | 46 ++------------------------------------------- 1 file changed, 2 insertions(+), 44 deletions(-) diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 50f0a12a9..851e3b93b 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -835,13 +835,8 @@ class RecordingSession: # TODO(LM): Consider implementing Observer pattern for `camera_cluster` and `labels` camera_cluster: CameraCluster = field(factory=CameraCluster) metadata: dict = field(factory=dict) + labels: Optional["Labels"] = field(default=None) _video_by_camcorder: Dict[Camcorder, Video] = field(factory=dict) - labels: Optional["Labels"] = None - - # TODO(LM): Remove this, replace with `FrameGroup`s - _instance_groups_by_frame_idx: Dict[int, InstanceGroup] = field(factory=dict) - - # TODO(LM): We should serialize all locked instances in a FrameGroup (or the entire FrameGroup) _frame_group_by_frame_idx: Dict[int, "FrameGroup"] = field(factory=dict) @property @@ -873,13 +868,6 @@ def unlinked_cameras(self) -> List[Camcorder]: key=self.camera_cluster.cameras.index, ) - # TODO(LM): Remove this - @property - def instance_groups(self) -> Dict[int, InstanceGroup]: - """Dict of `InstanceGroup`s by frame index.""" - - return self._instance_groups_by_frame_idx - @property def frame_groups(self) -> Dict[int, "FrameGroup"]: """Dict of `FrameGroup`s by frame index.""" @@ -1036,36 +1024,6 @@ def get_videos_from_selected_cameras( return videos - # TODO(LM): There can be multiple `InstanceGroup`s per frame index - def get_instance_group(self, frame_idx: int) -> Optional[InstanceGroup]: - """Get `InstanceGroup` from frame index. - - Args: - frame_idx: Frame index. - - Returns: - `InstanceGroup` object or `None` if not found. - """ - - if frame_idx not in self.instance_groups: - logger.warning( - f"Frame index {frame_idx} not found in this RecordingSession's " - f"InstanceGroup's keys: \n\t{self.instance_groups.keys()}." - ) - return None - - return self.instance_groups[frame_idx] - - # TODO(LM): There can be multiple `InstanceGroup`s per frame index - def update_instance_group(self, frame_idx: int, instance_group: InstanceGroup): - """Update `InstanceGroup` from frame index. - - Args: - frame_idx: Frame index. - instance_groups: `InstanceGroup` object. - """ - - self._instance_groups_by_frame_idx[frame_idx] = instance_group def __attrs_post_init__(self): self.camera_cluster.add_session(self) @@ -1299,7 +1257,7 @@ def __attrs_post_init__(self): # Add `FrameGroup` to `RecordingSession` self.session._frame_group_by_frame_idx[self.frame_idx] = self - # Initialize `_labeled_frames_by_cam` dictionary + # Initialize `_labeled_frames_by_cam` and `_instances_by_cam` dictionary self.update_labeled_frames_and_instances_by_cam() @property From ddbd19c89211d8d2add3d4213e41cbd5d14899e5 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Mon, 15 Apr 2024 16:50:40 -0700 Subject: [PATCH 15/22] [wip] Maintain cached FrameGroup dictionaries --- sleap/io/cameras.py | 251 +++++++++++++++++++++++++-------------- tests/io/test_cameras.py | 4 + 2 files changed, 169 insertions(+), 86 deletions(-) diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 851e3b93b..4dd40d505 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -408,10 +408,14 @@ class InstanceGroup: frame_idx: int = field(validator=instance_of(int)) camera_cluster: Optional[CameraCluster] = None locked: bool = field(default=False) + _name: str = field(default="inst_group_0") _instance_by_camcorder: Dict[Camcorder, Instance] = field(factory=dict) _camcorder_by_instance: Dict[Instance, Camcorder] = field(factory=dict) _dummy_instance: Optional[Instance] = field(default=None) + # Class attributes + _name_registry: Set[str] = set() + def __attrs_post_init__(self): """Initialize `InstanceGroup` object.""" @@ -472,6 +476,12 @@ def dummy_instance(self) -> PredictedInstance: self._create_dummy_instance() return self._dummy_instance + @property + def name(self) -> str: + """Name of the `InstanceGroup`.""" + + return self._name + @property def instances(self) -> List[Instance]: """List of `Instance` objects.""" @@ -482,6 +492,11 @@ def cameras(self) -> List[Camcorder]: """List of `Camcorder` objects.""" return list(self._instance_by_camcorder.keys()) + @property + def instance_by_camcorder(self) -> Dict[Camcorder, Instance]: + """Dictionary of `Instance` objects by `Camcorder`.""" + return self._instance_by_camcorder + def numpy(self, pred_as_nan: bool = False) -> np.ndarray: """Return instances as a numpy array of shape (n_views, n_nodes, 2). @@ -767,6 +782,9 @@ def __len__(self): def __repr__(self): return f"{self.__class__.__name__}(frame_idx={self.frame_idx}, instances={len(self)}, camera_cluster={self.camera_cluster})" + def __hash__(self) -> int: + return hash(self._name) + @classmethod def from_dict(cls, d: dict) -> Optional["InstanceGroup"]: """Creates an `InstanceGroup` object from a dictionary. @@ -1024,7 +1042,6 @@ def get_videos_from_selected_cameras( return videos - def __attrs_post_init__(self): self.camera_cluster.add_session(self) @@ -1216,7 +1233,7 @@ class FrameGroup: # Instance attributes frame_idx: int = field(validator=instance_of(int)) - instance_groups: List[InstanceGroup] = field( + _instance_groups: List[InstanceGroup] = field( validator=deep_iterable( member_validator=instance_of(InstanceGroup), iterable_validator=instance_of(list), @@ -1235,7 +1252,8 @@ class FrameGroup: # TODO(LM): This dict should be updated each time a LabeledFrame is added/removed # from the Labels object. Or if a video is added/removed from the RecordingSession. - _labeled_frames_by_cam: Dict[Camcorder, LabeledFrame] = field(factory=dict) + _labeled_frame_by_cam: Dict[Camcorder, LabeledFrame] = field(factory=dict) + _cam_by_labeled_frame: Dict[LabeledFrame, Camcorder] = field(factory=dict) _instances_by_cam: Dict[Camcorder, Set[Instance]] = field(factory=dict) def __attrs_post_init__(self): @@ -1257,8 +1275,30 @@ def __attrs_post_init__(self): # Add `FrameGroup` to `RecordingSession` self.session._frame_group_by_frame_idx[self.frame_idx] = self - # Initialize `_labeled_frames_by_cam` and `_instances_by_cam` dictionary - self.update_labeled_frames_and_instances_by_cam() + # Build `_labeled_frame_by_cam` and `_instances_by_cam` dictionary + for camera in self.session.camera_cluster.cameras: + self._instances_by_cam[camera] = set() + self.instance_groups = self._instance_groups + + @property + def instance_groups(self) -> List[InstanceGroup]: + """List of `InstanceGroup`s.""" + + return self._instance_groups + + @instance_groups.setter + def instance_groups(self, instance_groups: List[InstanceGroup]): + """Setter for `instance_groups` that updates `LabeledFrame`s and `Instance`s.""" + + instance_groups_to_remove = set(self.instance_groups) - set(instance_groups) + instance_groups_to_add = set(instance_groups) - set(self.instance_groups) + + # Update the `_labeled_frame_by_cam` and `_instances_by_cam` dictionary + for instance_group in instance_groups_to_remove: + self.remove_instance_group(instance_group=instance_group) + + for instance_group in instance_groups_to_add: + self.add_instance_group(instance_group=instance_group) @property def cams_to_include(self) -> Optional[List[Camcorder]]: @@ -1301,13 +1341,15 @@ def cams_to_include(self, cams_to_include: List[Camcorder]): def labeled_frames(self) -> List[LabeledFrame]: """List of `LabeledFrame`s.""" - return list(self._labeled_frames_by_cam.values()) + # TODO(LM): Revisit whether we need to return a list instead of a view object + return list(self._labeled_frame_by_cam.values()) @property def cameras(self) -> List[Camcorder]: """List of `Camcorder`s.""" - return list(self._labeled_frames_by_cam.keys()) + # TODO(LM): Revisit whether we need to return a list instead of a view object + return list(self._labeled_frame_by_cam.keys()) def numpy( self, @@ -1406,6 +1448,33 @@ def add_instance( labeled_frame = instance.frame self.add_labeled_frame(labeled_frame=labeled_frame, camera=camera) + def remove_instance(self, instance: Instance): + """Removes an `Instance` from the `FrameGroup`. + + Args: + instance: `Instance` to remove from the `FrameGroup`. + """ + + instance_group = self.get_instance_group(instance=instance) + + if instance_group is None: + logger.warning( + f"Instance {instance} not found in this FrameGroup.instance_groups: " + f"{self.instance_groups}." + ) + return + + # Remove the `Instance` from the `InstanceGroup` + camera = instance_group.get_cam(instance=instance) + instance_group.remove_instance(instance=instance) + + # Remove the `Instance` from the `FrameGroup` + self._instances_by_cam[camera].remove(instance) + + # Remove "empty" `LabeledFrame`s from the `FrameGroup` + if len(self._instances_by_cam[camera]) < 1: + self.remove_labeled_frame(labeled_frame_or_camera=camera) + def add_instance_group(self, instance_group: Optional[InstanceGroup] = None): """Add an `InstanceGroup` to the `FrameGroup`. @@ -1435,13 +1504,30 @@ def add_instance_group(self, instance_group: Optional[InstanceGroup] = None): self.instance_groups.append(instance_group) # Add `Instance`s and `LabeledFrame`s to the `FrameGroup` - for instance in instance_group.instances: - camera = instance_group.get_cam(instance=instance) + for camera, instance in instance_group.instance_by_camcorder.items(): self.add_instance(instance=instance, camera=camera) - # TODO(LM): Integrate with RecordingSession - # Add the `InstanceGroup` to the `RecordingSession` - ... + def remove_instance_group(self, instance_group: InstanceGroup): + """Remove an `InstanceGroup` from the `FrameGroup`.""" + + if instance_group not in self.instance_groups: + logger.warning( + f"InstanceGroup {instance_group} not found in this FrameGroup: " + f"{self.instance_groups}." + ) + return + + # Remove the `InstanceGroup` from the `FrameGroup` + self.instance_groups.remove(instance_group) + + # Remove the `Instance`s from the `FrameGroup` + for camera, instance in instance_group.instance_by_camcorder.items(): + self._instances_by_cam[camera].remove(instance) + + # Remove the `LabeledFrame` from the `FrameGroup` + labeled_frame = self.get_labeled_frame(camera=camera) + if labeled_frame is not None: + self.remove_labeled_frame(camera=camera) def get_instance_group(self, instance: Instance) -> Optional[InstanceGroup]: """Get `InstanceGroup` that contains `Instance` if exists. Otherwise, None. @@ -1470,15 +1556,62 @@ def add_labeled_frame(self, labeled_frame: LabeledFrame, camera: Camcorder): Args: labeled_frame: `LabeledFrame` to add to the `FrameGroup`. camera: `Camcorder` to link the `LabeledFrame` to. + + Raises: + ValueError: If the `LabeledFrame` is not compatible with the `FrameGroup`. """ + # Some checks to ensure the `LabeledFrame` is compatible with the `FrameGroup` + if not isinstance(labeled_frame, LabeledFrame): + raise ValueError( + f"Cannot add LabeledFrame: {labeled_frame} is not a LabeledFrame." + ) + elif labeled_frame.frame_idx != self.frame_idx: + raise ValueError( + f"Cannot add LabeledFrame: Frame index {labeled_frame.frame_idx} does " + f"not match FrameGroup frame index {self.frame_idx}." + ) + elif not isinstance(camera, Camcorder): + raise ValueError(f"Cannot add LabeledFrame: {camera} is not a Camcorder.") + # Add the `LabeledFrame` to the `FrameGroup` - self._labeled_frames_by_cam[camera] = labeled_frame + self._labeled_frame_by_cam[camera] = labeled_frame + self._cam_by_labeled_frame[labeled_frame] = camera # Add the `LabeledFrame` to the `RecordingSession`'s `Labels` object - if labeled_frame not in self.session.labels: + if (self.session.labels is not None) and ( + labeled_frame not in self.session.labels + ): self.session.labels.append(labeled_frame) + def remove_labeled_frame( + self, labeled_frame_or_camera: Union[LabeledFrame, Camcorder] + ): + """Remove a `LabeledFrame` from the `FrameGroup`. + + Args: + labeled_frame_or_camera: `LabeledFrame` or `Camcorder` to remove the + `LabeledFrame` for. + """ + + if isinstance(labeled_frame_or_camera, LabeledFrame): + labeled_frame: LabeledFrame = labeled_frame_or_camera + camera = self.get_camera(labeled_frame=labeled_frame) + + elif isinstance(labeled_frame_or_camera, Camcorder): + camera: Camcorder = labeled_frame_or_camera + labeled_frame = self.get_labeled_frame(camera=camera) + + else: + logger.warning( + f"Cannot remove LabeledFrame: {labeled_frame_or_camera} is not a " + "LabeledFrame or Camcorder." + ) + + # Remove the `LabeledFrame` from the `FrameGroup` + self._labeled_frame_by_cam.pop(camera, None) + self._cam_by_labeled_frame.pop(labeled_frame, None) + def get_labeled_frame(self, camera: Camcorder) -> Optional[LabeledFrame]: """Get `LabeledFrame` for `Camcorder` if exists. Otherwise, None. @@ -1489,9 +1622,21 @@ def get_labeled_frame(self, camera: Camcorder) -> Optional[LabeledFrame]: `LabeledFrame` """ - return self._labeled_frames_by_cam.get(camera, None) + return self._labeled_frame_by_cam.get(camera, None) + + def get_camera(self, labeled_frame: LabeledFrame) -> Optional[Camcorder]: + """Get `Camcorder` for `LabeledFrame` if exists. Otherwise, None. + + Args: + labeled_frame: `LabeledFrame` - def create_and_add_labeled_frame(self, camera: Camcorder) -> LabeledFrame: + Returns: + `Camcorder` + """ + + return self._cam_by_labeled_frame.get(labeled_frame, None) + + def _create_and_add_labeled_frame(self, camera: Camcorder) -> LabeledFrame: """Create and add a `LabeledFrame` to the `FrameGroup`. This also adds the `LabeledFrame` to the `RecordingSession`'s `Labels` object. @@ -1511,14 +1656,12 @@ def create_and_add_labeled_frame(self, camera: Camcorder) -> LabeledFrame: f"RecordingSession {self.session}." ) - labeled_frame = LabeledFrame( - video=video, frame_idx=self.frame_idx - ) + labeled_frame = LabeledFrame(video=video, frame_idx=self.frame_idx) self.add_labeled_frame(labeled_frame=labeled_frame) return labeled_frame - def create_and_add_instance( + def _create_and_add_instance( self, instance_group: InstanceGroup, camera: Camcorder, @@ -1566,10 +1709,10 @@ def create_and_add_missing_instances(self, instance_group: InstanceGroup): labeled_frame = self.get_labeled_frame(camera=cam) if labeled_frame is None: # There is no `LabeledFrame` for this view, so lets make one - labeled_frame = self.create_and_add_labeled_frame(camera=cam) + labeled_frame = self._create_and_add_labeled_frame(camera=cam) # Create an instance - self.create_and_add_instance( + self._create_and_add_instance( instance_group=instance_group, cam=cam, labeled_frame=labeled_frame ) @@ -1726,70 +1869,6 @@ def _raise_if_instance_group_not_in_frame_group( f"{self.instance_groups}." ) - def update_labeled_frames_and_instances_by_cam( - self, return_instances_by_camera: bool = False - ) -> Union[Dict[Camcorder, LabeledFrame], Dict[Camcorder, List[Instance]]]: - """Get all views and `Instance`s across all `RecordingSession`s. - - Updates the `_labeled_frames_by_cam` and `_instances_by_cam` - dictionary attributes. - - Args: - return_instances_by_camera: If true, then returns a dictionary with - `Camcorder` key and `Set[Instance]` values instead. Default is False. - - Returns: - Dictionary with `Camcorder` key and `LabeledFrame` value or `Set[Instance]` - value if `return_instances_by_camera` is True. - """ - - logger.debug( - "Updating LabeledFrames for FrameGroup." - "\n\tPrevious LabeledFrames by Camcorder:" - f"\n\t{self._labeled_frames_by_cam}" - ) - - views: Dict[Camcorder, LabeledFrame] = {} - instances_by_cam: Dict[Camcorder, Set[Instance]] = {} - videos = self.session.get_videos_from_selected_cameras() - for cam, video in videos.items(): - lfs: List[LabeledFrame] = self.session.labels.get( - (video, [self.frame_idx]) - ) - if len(lfs) == 0: - logger.debug( - f"No LabeledFrames found for video {video} at {self.frame_idx}." - ) - continue - - lf = lfs[0] - if len(lf.instances) == 0: - logger.warning( - f"No Instances found for {lf}." - " There should not be empty LabeledFrames." - ) - continue - - views[cam] = lf - - # Find instances in frame - insts = lf.find(track=-1) - if len(insts) > 0: - instances_by_cam[cam] = set(insts) - - # Update `_labeled_frames_by_cam` dictionary and return it - self._labeled_frames_by_cam = views - logger.debug( - f"\tUpdated LabeledFrames by Camcorder:\n\t{self._labeled_frames_by_cam}" - ) - # Update `_instances_by_camera` dictionary and return it - self._instances_by_cam = instances_by_cam - return ( - self._instances_by_cam - if return_instances_by_camera - else self._labeled_frames_by_cam - ) - @classmethod def from_instance_groups( cls, diff --git a/tests/io/test_cameras.py b/tests/io/test_cameras.py index f5b0ba014..767e54c62 100644 --- a/tests/io/test_cameras.py +++ b/tests/io/test_cameras.py @@ -447,3 +447,7 @@ def test_frame_group(multiview_min_session_labels: Labels): assert frame_group_1._frame_idx_registry == frame_group_2._frame_idx_registry # TODO(LM): Test `generate_hypotheses` + + +if __name__ == "__main__": + pytest.main([f"{__file__}::test_frame_group"]) \ No newline at end of file From 9454943d35d5475fcd2b6d76fb5ac0ebfe21821f Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Tue, 16 Apr 2024 16:44:47 -0700 Subject: [PATCH 16/22] Add unique name (per FrameGroup) requirement for InstanceGroup --- sleap/io/cameras.py | 178 +++++++++++++++++++++++++++++---------- tests/io/test_cameras.py | 14 ++- 2 files changed, 141 insertions(+), 51 deletions(-) diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 4dd40d505..a79f85f19 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -405,16 +405,14 @@ class InstanceGroup: """ + _name: str = field() frame_idx: int = field(validator=instance_of(int)) - camera_cluster: Optional[CameraCluster] = None - locked: bool = field(default=False) - _name: str = field(default="inst_group_0") _instance_by_camcorder: Dict[Camcorder, Instance] = field(factory=dict) _camcorder_by_instance: Dict[Instance, Camcorder] = field(factory=dict) _dummy_instance: Optional[Instance] = field(default=None) # Class attributes - _name_registry: Set[str] = set() + camera_cluster: Optional[CameraCluster] = None def __attrs_post_init__(self): """Initialize `InstanceGroup` object.""" @@ -482,6 +480,63 @@ def name(self) -> str: return self._name + @name.setter + def name(self, name: str): + """Set the name of the `InstanceGroup`.""" + + raise ValueError( + "Cannot set name directly. Use `set_name` method instead (preferably " + "through FrameGroup.set_instance_group_name)." + ) + + def set_name(self, name: str, name_registry: Set[str]): + """Set the name of the `InstanceGroup`. + + This function mutates the name_registry input (see side-effect). + + Args: + name: Name to set for the `InstanceGroup`. + name_registry: Set of names to check for uniqueness. + + Raises: + ValueError: If the name is already in use (in the name_registry). + """ + + # Check if the name is already in use + if name in name_registry: + raise ValueError( + f"Name {name} already in use. Please use a unique name not currently " + f"in the registry: {name_registry}" + ) + + # Remove the old name from the registry + if self._name in name_registry: + name_registry.remove(self._name) + + self._name = name + name_registry.add(name) + + @classmethod + def return_unique_name(cls, name_registry: Set[str]) -> str: + """Return a unique name for the `InstanceGroup`. + + Args: + name_registry: Set of names to check for uniqueness. + + Returns: + Unique name for the `InstanceGroup`. + """ + + base_name = "instance_group_" + count = len(name_registry) + new_name = f"{base_name}{count}" + + while new_name in name_registry: + count += 1 + new_name = f"{base_name}{count}" + + return new_name + @property def instances(self) -> List[Instance]: """List of `Instance` objects.""" @@ -786,11 +841,18 @@ def __hash__(self) -> int: return hash(self._name) @classmethod - def from_dict(cls, d: dict) -> Optional["InstanceGroup"]: + def from_dict( + cls, d: dict, name: str, name_registry: Set[str] + ) -> Optional["InstanceGroup"]: """Creates an `InstanceGroup` object from a dictionary. Args: d: Dictionary with `Camcorder` keys and `Instance` values. + name: Name to use for the `InstanceGroup`. + name_registry: Set of names to check for uniqueness. + + Raises: + ValueError: If the `InstanceGroup` name is already in use. Returns: `InstanceGroup` object or None if no "real" (determined by `frame_idx` other @@ -811,24 +873,23 @@ def from_dict(cls, d: dict) -> Optional["InstanceGroup"]: elif frame_idx is None: frame_idx = instance.frame_idx # Ensure all instances have the same frame index - else: - try: - assert frame_idx == instance.frame_idx - except AssertionError: - logger.warning( - f"Cannot create `InstanceGroup`: Frame index {frame_idx} " - f"does not match instance frame index {instance.frame_idx}." - ) + elif frame_idx != instance.frame_idx: + raise ValueError( + f"Cannot create `InstanceGroup`: Frame index {frame_idx} does " + f"not match instance frame index {instance.frame_idx}." + ) if len(d_copy) == 0: - logger.warning("Cannot create `InstanceGroup`: No real instances found.") - return None + raise ValueError("Cannot create `InstanceGroup`: No real instances found.") - frame_idx = cast( - int, frame_idx - ) # Could be None if no real instances in dictionary + if name in name_registry: + raise ValueError( + f"Cannot create `InstanceGroup`: Name {name} already in use. Please " + f"use a unique name that is not in the registry: {name_registry}." + ) return cls( + name=name, frame_idx=frame_idx, camera_cluster=camera_cluster, instance_by_camcorder=d_copy, @@ -1240,7 +1301,9 @@ class FrameGroup: ), ) # Akin to `LabeledFrame.instances` session: RecordingSession = field(validator=instance_of(RecordingSession)) + _instance_group_name_registry: Set[str] = field(factory=set) + # TODO(LM): Should we move this to an instance attribute of `RecordingSession`? # Class attribute to keep track of frame indices across all `RecordingSession`s _frame_idx_registry: Dict[RecordingSession, Set[int]] = {} @@ -1259,6 +1322,17 @@ class FrameGroup: def __attrs_post_init__(self): """Initialize `FrameGroup` object.""" + # Check that `InstanceGroup` names unique (later added via add_instance_group) + instance_group_name_registry_copy = set(self._instance_group_name_registry) + for instance_group in self.instance_groups: + if instance_group.name in instance_group_name_registry_copy: + raise ValueError( + f"InstanceGroup name {instance_group.name} already in use. " + f"Please use a unique name not currently in the registry: " + f"{self._instance_group_name_registry}" + ) + instance_group_name_registry_copy.add(instance_group.name) + # Remove existing `FrameGroup` object from the `RecordingSession._frame_group_by_frame_idx` self.enforce_frame_idx_unique(self.session, self.frame_idx) @@ -1278,7 +1352,8 @@ def __attrs_post_init__(self): # Build `_labeled_frame_by_cam` and `_instances_by_cam` dictionary for camera in self.session.camera_cluster.cameras: self._instances_by_cam[camera] = set() - self.instance_groups = self._instance_groups + for instance_group in self.instance_groups: + self.add_instance_group(instance_group) @property def instance_groups(self) -> List[InstanceGroup]: @@ -1405,7 +1480,8 @@ def add_instance( """Add an (existing) `Instance` to the `FrameGroup`. If no `InstanceGroup` is provided, then check the `Instance` is already in an - `InstanceGroup` contained in the `FrameGroup`. + `InstanceGroup` contained in the `FrameGroup`. Otherwise, add the `Instance` to + the `InstanceGroup` and `FrameGroup`. Args: instance: `Instance` to add to the `FrameGroup`. @@ -1478,6 +1554,11 @@ def remove_instance(self, instance: Instance): def add_instance_group(self, instance_group: Optional[InstanceGroup] = None): """Add an `InstanceGroup` to the `FrameGroup`. + This method updates the underlying dictionaries in calling add_instance: + - `_instances_by_cam` + - `_labeled_frame_by_cam` + - `_cam_by_labeled_frame` + Args: instance_group: `InstanceGroup` to add to the `FrameGroup`. If None, then create a new `InstanceGroup` and add it to the `FrameGroup`. @@ -1487,21 +1568,29 @@ def add_instance_group(self, instance_group: Optional[InstanceGroup] = None): """ if instance_group is None: + + # Find a unique name for the `InstanceGroup` + instance_group_name = InstanceGroup.return_unique_name( + name_registry=self._instance_group_name_registry + ) + # Create an empty `InstanceGroup` with the frame index of the `FrameGroup` instance_group = InstanceGroup( + name=instance_group_name, frame_idx=self.frame_idx, camera_cluster=self.session.camera_cluster, ) - else: - # Ensure the `InstanceGroup` is not already in this `FrameGroup` - self._raise_if_instance_group_in_frame_group(instance_group=instance_group) - # Ensure the `InstanceGroup` is compatible with the `FrameGroup` self._raise_if_instance_group_incompatible(instance_group=instance_group) # Add the `InstanceGroup` to the `FrameGroup` - self.instance_groups.append(instance_group) + # We only expect this to be false on initialization + if instance_group not in self.instance_groups: + self.instance_groups.append(instance_group) + + # Add instance group name to the registry + self._instance_group_name_registry.add(instance_group.name) # Add `Instance`s and `LabeledFrame`s to the `FrameGroup` for camera, instance in instance_group.instance_by_camcorder.items(): @@ -1519,6 +1608,7 @@ def remove_instance_group(self, instance_group: InstanceGroup): # Remove the `InstanceGroup` from the `FrameGroup` self.instance_groups.remove(instance_group) + self._instance_group_name_registry.remove(instance_group.name) # Remove the `Instance`s from the `FrameGroup` for camera, instance in instance_group.instance_by_camcorder.items(): @@ -1550,6 +1640,15 @@ def get_instance_group(self, instance: Instance) -> Optional[InstanceGroup]: return instance_group + def set_instance_group_name(self, instance_group: InstanceGroup, name: str): + """Set the name of an `InstanceGroup` in the `FrameGroup`.""" + + self._raise_if_instance_group_not_in_frame_group(instance_group=instance_group) + + instance_group.set_name( + name=name, name_registry=self._instance_group_name_registry + ) + def add_labeled_frame(self, labeled_frame: LabeledFrame, camera: Camcorder): """Add a `LabeledFrame` to the `FrameGroup`. @@ -1823,27 +1922,12 @@ def _raise_if_instance_incompatibile(self, instance: Instance, camera: Camcorder f"FrameGroup's LabeledFrame {labeled_frame_fg} for Camcorder {camera}." ) - def _raise_if_instance_group_in_frame_group(self, instance_group: InstanceGroup): - """Raise a ValueError if the `InstanceGroup` is already in the `FrameGroup`. - - Args: - instance_group: `InstanceGroup` to check if already in the `FrameGroup`. - - Raises: - ValueError: If the `InstanceGroup` is already in the `FrameGroup`. - """ - - if instance_group in self.instance_groups: - raise ValueError( - f"InstanceGroup {instance_group} is already in this FrameGroup " - f"{self.instance_groups}." - ) - def _raise_if_instance_group_incompatible(self, instance_group: InstanceGroup): """Raise a ValueError if `InstanceGroup` is incompatible with `FrameGroup`. - An `InstanceGroup` is incompatible if the `frame_idx` does not match the - `FrameGroup`'s `frame_idx`. + An `InstanceGroup` is incompatible if + - the `frame_idx` does not match the `FrameGroup`'s `frame_idx`. + - the `InstanceGroup.name` is already used in the `FrameGroup`. Args: instance_group: `InstanceGroup` to check compatibility of. @@ -1858,6 +1942,14 @@ def _raise_if_instance_group_incompatible(self, instance_group: InstanceGroup): f"does not match FrameGroup frame index {self.frame_idx}." ) + if instance_group.name in self._instance_group_name_registry: + raise ValueError( + f"InstanceGroup name {instance_group.name} is already registered in " + "this FrameGroup's list of names: " + f"{self._instance_group_name_registry}\n" + "Please use a unique name for the new InstanceGroup." + ) + def _raise_if_instance_group_not_in_frame_group( self, instance_group: InstanceGroup ): diff --git a/tests/io/test_cameras.py b/tests/io/test_cameras.py index 767e54c62..236565c8a 100644 --- a/tests/io/test_cameras.py +++ b/tests/io/test_cameras.py @@ -331,7 +331,9 @@ def create_instance_group( ) instance_by_camera[cam] = dummy_instance - instance_group = InstanceGroup.from_dict(d=instance_by_camera) + instance_group = InstanceGroup.from_dict( + d=instance_by_camera, name="test_instance_group", name_registry={} + ) return ( (instance_group, instance_by_camera, dummy_instance, cam) if add_dummy @@ -411,8 +413,8 @@ def test_instance_group(multiview_min_session_labels: Labels): # Populate with only dummy instance and test `from_dict` instance_by_camera = {cam: dummy_instance} - instance_group = InstanceGroup.from_dict(d=instance_by_camera) - assert instance_group is None + with pytest.raises(ValueError): + instance_group = InstanceGroup.from_dict(d=instance_by_camera, name="test_instance_group", name_registry={}) def test_frame_group(multiview_min_session_labels: Labels): @@ -446,8 +448,4 @@ def test_frame_group(multiview_min_session_labels: Labels): assert frame_group_2._frame_idx_registry[session] == {frame_idx_1, frame_idx_2} assert frame_group_1._frame_idx_registry == frame_group_2._frame_idx_registry - # TODO(LM): Test `generate_hypotheses` - - -if __name__ == "__main__": - pytest.main([f"{__file__}::test_frame_group"]) \ No newline at end of file + # TODO(LM): Test underlying dictionaries more thoroughly From 545dd4459a80cb63c0b694033be4471f5bbae36b Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Tue, 16 Apr 2024 16:45:08 -0700 Subject: [PATCH 17/22] Lint --- tests/io/test_cameras.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/io/test_cameras.py b/tests/io/test_cameras.py index 236565c8a..6b2f1f354 100644 --- a/tests/io/test_cameras.py +++ b/tests/io/test_cameras.py @@ -414,7 +414,9 @@ def test_instance_group(multiview_min_session_labels: Labels): # Populate with only dummy instance and test `from_dict` instance_by_camera = {cam: dummy_instance} with pytest.raises(ValueError): - instance_group = InstanceGroup.from_dict(d=instance_by_camera, name="test_instance_group", name_registry={}) + instance_group = InstanceGroup.from_dict( + d=instance_by_camera, name="test_instance_group", name_registry={} + ) def test_frame_group(multiview_min_session_labels: Labels): From 3fe02469d587b1b1fe95abb3d08256b1e2ab6544 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Tue, 16 Apr 2024 17:09:30 -0700 Subject: [PATCH 18/22] Fix remove_video bug --- sleap/io/cameras.py | 2 +- sleap/io/dataset.py | 22 +++++++++++++--------- tests/io/test_dataset.py | 2 +- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index a79f85f19..47fd4b941 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -1076,7 +1076,7 @@ def remove_video(self, video: Video): # Update labels cache if self.labels is not None and self.labels.get_session(video) is not None: - self.labels.remove_session_video(self, video) + self.labels.remove_session_video(video=video) def get_videos_from_selected_cameras( self, cams_to_include: Optional[List[Camcorder]] = None diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index e12c963ba..6c91af612 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -278,7 +278,7 @@ def remove_video(self, video: Video): def remove_session_video(self, video: Video): """Remove video from session in cache.""" - + if video in self._session_by_video: del self._session_by_video[video] @@ -1661,8 +1661,8 @@ def remove_video(self, video: Video): # Delete video self.videos.remove(video) - self.remove_session_video(video) - self._cache.remove_video(video) + self.remove_session_video(video=video) + self._cache.remove_video(video=video) def add_session(self, session: RecordingSession): """Add a recording session to the labels. @@ -1707,17 +1707,21 @@ def get_session(self, video: Video) -> Optional[RecordingSession]: """ return self._cache._session_by_video.get(video, None) - def remove_session_video(self, session: RecordingSession, video: Video): - """Remove a video from a recording session. + def remove_session_video(self, video: Video): + """Remove a video from its linked recording session (if any). Args: - session: `RecordingSession` instance video: `Video` instance """ - if video in session.videos: - session.remove_video(video) - self._cache.remove_session_video(video) + session = self.get_session(video) + + if session is None: + return + + # Need to remove from cache first to avoid circular reference + self._cache.remove_session_video(video=video) + session.remove_video(video) @classmethod def from_json(cls, *args, **kwargs): diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 020dd64ed..a544b7703 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -1030,7 +1030,7 @@ def test_add_session_and_update_session( assert labels._cache._session_by_video == {video: session} assert labels.get_session(video) == session - labels.remove_session_video(session, video) + labels.remove_session_video(video=video) assert video not in session.videos assert video not in labels._cache._session_by_video From 118cf37280776b5bd1df0713af7c62a7eeca767c Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Wed, 17 Apr 2024 13:41:33 -0700 Subject: [PATCH 19/22] Add RecordingSession.new_frame_group method --- sleap/io/cameras.py | 22 ++++++++++++++++++---- tests/io/test_cameras.py | 20 ++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 47fd4b941..450b072a9 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -1016,9 +1016,7 @@ def add_video(self, video: Video, camcorder: Camcorder): """ # Ensure the `Camcorder` is in this `RecordingSession`'s `CameraCluster` - try: - assert camcorder in self.camera_cluster - except AssertionError: + if camcorder not in self.camera_cluster: raise ValueError( f"Camcorder {camcorder.name} is not in this RecordingSession's " f"{self.camera_cluster}." @@ -1078,6 +1076,21 @@ def remove_video(self, video: Video): if self.labels is not None and self.labels.get_session(video) is not None: self.labels.remove_session_video(video=video) + def new_frame_group(self, frame_idx: int): + """Creates and adds an empty `FrameGroup` to the `RecordingSession`. + + Args: + frame_idx: Frame index for the `FrameGroup`. + + Returns: + `FrameGroup` object. + """ + + # `FrameGroup.__attrs_post_init` will manage `_frame_group_by_frame_idx` + frame_group = FrameGroup(frame_idx=frame_idx, session=self) + + return frame_group + def get_videos_from_selected_cameras( self, cams_to_include: Optional[List[Camcorder]] = None ) -> Dict[Camcorder, Video]: @@ -1294,13 +1307,14 @@ class FrameGroup: # Instance attributes frame_idx: int = field(validator=instance_of(int)) + session: RecordingSession = field(validator=instance_of(RecordingSession)) _instance_groups: List[InstanceGroup] = field( + factory=list, validator=deep_iterable( member_validator=instance_of(InstanceGroup), iterable_validator=instance_of(list), ), ) # Akin to `LabeledFrame.instances` - session: RecordingSession = field(validator=instance_of(RecordingSession)) _instance_group_name_registry: Set[str] = field(factory=set) # TODO(LM): Should we move this to an instance attribute of `RecordingSession`? diff --git a/tests/io/test_cameras.py b/tests/io/test_cameras.py index 6b2f1f354..21f322ad5 100644 --- a/tests/io/test_cameras.py +++ b/tests/io/test_cameras.py @@ -170,6 +170,13 @@ def test_recording_session( # Test __repr__ assert f"{session.__class__.__name__}(" in repr(session) + # Test new_frame_group + frame_group = session.new_frame_group(frame_idx=0) + assert isinstance(frame_group, FrameGroup) + assert frame_group.session == session + assert frame_group.frame_idx == 0 + assert frame_group == session.frame_groups[0] + # Test add_video camcorder = session.camera_cluster.cameras[0] session.add_video(centered_pair_vid, camcorder) @@ -450,4 +457,17 @@ def test_frame_group(multiview_min_session_labels: Labels): assert frame_group_2._frame_idx_registry[session] == {frame_idx_1, frame_idx_2} assert frame_group_1._frame_idx_registry == frame_group_2._frame_idx_registry + frame_idx_3 = 2 + frame_group_3 = FrameGroup(frame_idx=frame_idx_3, session=session) + assert isinstance(frame_group_3, FrameGroup) + assert session in frame_group_3._frame_idx_registry + assert len(frame_group_3._frame_idx_registry) == 1 + assert frame_group_3._frame_idx_registry[session] == { + frame_idx_1, + frame_idx_2, + frame_idx_3, + } + assert frame_group_1._frame_idx_registry == frame_group_3._frame_idx_registry + assert len(frame_group_3.instance_groups) == 0 + # TODO(LM): Test underlying dictionaries more thoroughly From f7b178b4e913e0daa4e524c9164f653b93c6ba72 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Wed, 17 Apr 2024 16:39:44 -0700 Subject: [PATCH 20/22] Add TODO comments for later --- sleap/io/dataset.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 6c91af612..dc08e07cf 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -279,6 +279,7 @@ def remove_video(self, video: Video): def remove_session_video(self, video: Video): """Remove video from session in cache.""" + # TODO(LM): Also remove LabeledFrames from frame_group if video in self._session_by_video: del self._session_by_video[video] @@ -971,6 +972,9 @@ def remove_frame(self, lf: LabeledFrame, update_cache: bool = True): update_cache: If True, update the internal frame cache. If False, cache update can be postponed (useful when removing many frames). """ + + # TODO(LM): Remove LabeledFrame from any frame groups it's in. + self.labeled_frames.remove(lf) if update_cache: self._cache.remove_frame(lf) @@ -981,6 +985,8 @@ def remove_frames(self, lfs: List[LabeledFrame]): Args: lfs: A sequence of labeled frames to remove. """ + + # TODO(LM): Remove LabeledFrame from any frame groups it's in. to_remove = set(lfs) self.labeled_frames = [lf for lf in self.labeled_frames if lf not in to_remove] self.update_cache() @@ -1004,6 +1010,8 @@ def remove_empty_instances(self, keep_empty_frames: bool = True): def remove_empty_frames(self): """Remove frames with no instances.""" + + # TODO(LM): Remove LabeledFrame from any frame groups it's in. self.labeled_frames = [ lf for lf in self.labeled_frames if len(lf.instances) > 0 ] @@ -1854,6 +1862,8 @@ def remove_user_instances(self, new_labels: Optional["Labels"] = None): # Keep only labeled frames with no conflicting predictions. self.labeled_frames = keep_lfs + # TODO(LM): Remove LabeledFrame from any frame groups it's in. + def remove_predictions(self, new_labels: Optional["Labels"] = None): """Clear predicted instances from the labels. @@ -1890,6 +1900,8 @@ def remove_predictions(self, new_labels: Optional["Labels"] = None): # Keep only labeled frames with no conflicting predictions. self.labeled_frames = keep_lfs + # TODO(LM): Remove LabeledFrame from any frame groups it's in. + def remove_untracked_instances(self, remove_empty_frames: bool = True): """Remove instances that do not have a track assignment. @@ -2007,6 +2019,7 @@ def merge_matching_frames(self, video: Optional[Video] = None): for vid in {lf.video for lf in self.labeled_frames}: self.merge_matching_frames(video=vid) else: + # TODO(LM): Remove LabeledFrame from any frame groups it's in. self.labeled_frames = LabeledFrame.merge_frames( self.labeled_frames, video=video ) From ca6319a5bea58542dabdd35026272edfbc430cc4 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Wed, 17 Apr 2024 16:54:25 -0700 Subject: [PATCH 21/22] Fix RecordingSesssion.remove_video bug --- sleap/io/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index dc08e07cf..185b235a8 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1729,7 +1729,8 @@ def remove_session_video(self, video: Video): # Need to remove from cache first to avoid circular reference self._cache.remove_session_video(video=video) - session.remove_video(video) + if session.get_camera(video) is not None: + session.remove_video(video) @classmethod def from_json(cls, *args, **kwargs): From 176efb26e30961980ddcb076c73032a14e8bba85 Mon Sep 17 00:00:00 2001 From: roomrys Date: Thu, 18 Apr 2024 12:53:08 -0700 Subject: [PATCH 22/22] Remove FrameGroup._frame_idx_registry class attribute --- sleap/io/cameras.py | 14 ++------------ tests/io/test_cameras.py | 28 ++++++++++++---------------- 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 450b072a9..d8bac5807 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -1317,10 +1317,6 @@ class FrameGroup: ) # Akin to `LabeledFrame.instances` _instance_group_name_registry: Set[str] = field(factory=set) - # TODO(LM): Should we move this to an instance attribute of `RecordingSession`? - # Class attribute to keep track of frame indices across all `RecordingSession`s - _frame_idx_registry: Dict[RecordingSession, Set[int]] = {} - # "Hidden" class attribute _cams_to_include: Optional[List[Camcorder]] = None _excluded_views: Optional[Tuple[str]] = () @@ -1354,12 +1350,6 @@ def __attrs_post_init__(self): if self._cams_to_include is not None: self.cams_to_include = self._cams_to_include - # Add frame index to registry - if self.session not in self._frame_idx_registry: - self._frame_idx_registry[self.session] = set() - - self._frame_idx_registry[self.session].add(self.frame_idx) - # Add `FrameGroup` to `RecordingSession` self.session._frame_group_by_frame_idx[self.frame_idx] = self @@ -2016,11 +2006,11 @@ def enforce_frame_idx_unique( frame_idx: Frame index. """ - if frame_idx in self._frame_idx_registry.get(session, set()): + if session.frame_groups.get(frame_idx, None) is not None: # Remove existing `FrameGroup` object from the # `RecordingSession._frame_group_by_frame_idx` logger.warning( f"Frame index {frame_idx} for FrameGroup already exists in this " "RecordingSession. Overwriting." ) - session._frame_group_by_frame_idx.pop(frame_idx) + session.frame_groups.pop(frame_idx) diff --git a/tests/io/test_cameras.py b/tests/io/test_cameras.py index 21f322ad5..16fbfd0a7 100644 --- a/tests/io/test_cameras.py +++ b/tests/io/test_cameras.py @@ -440,11 +440,12 @@ def test_frame_group(multiview_min_session_labels: Labels): session=session, instance_groups=instance_groups ) assert isinstance(frame_group_1, FrameGroup) - assert session in frame_group_1._frame_idx_registry - assert len(frame_group_1._frame_idx_registry) == 1 - assert frame_group_1._frame_idx_registry[session] == {frame_idx_1} + assert frame_idx_1 in session.frame_groups + assert len(session.frame_groups) == 1 + assert frame_group_1 == session.frame_groups[frame_idx_1] + assert len(frame_group_1.instance_groups) == 1 - # Test `_frame_idx_registry` property + # Test `RecordingSession.frame_groups` property frame_idx_2 = 1 instance_group = create_instance_group(labels=labels, frame_idx=frame_idx_2) instance_groups: List[InstanceGroup] = [instance_group] @@ -452,22 +453,17 @@ def test_frame_group(multiview_min_session_labels: Labels): session=session, instance_groups=instance_groups ) assert isinstance(frame_group_2, FrameGroup) - assert session in frame_group_2._frame_idx_registry - assert len(frame_group_2._frame_idx_registry) == 1 - assert frame_group_2._frame_idx_registry[session] == {frame_idx_1, frame_idx_2} - assert frame_group_1._frame_idx_registry == frame_group_2._frame_idx_registry + assert frame_idx_2 in session.frame_groups + assert len(session.frame_groups) == 2 + assert frame_group_2 == session.frame_groups[frame_idx_2] + assert len(frame_group_2.instance_groups) == 1 frame_idx_3 = 2 frame_group_3 = FrameGroup(frame_idx=frame_idx_3, session=session) assert isinstance(frame_group_3, FrameGroup) - assert session in frame_group_3._frame_idx_registry - assert len(frame_group_3._frame_idx_registry) == 1 - assert frame_group_3._frame_idx_registry[session] == { - frame_idx_1, - frame_idx_2, - frame_idx_3, - } - assert frame_group_1._frame_idx_registry == frame_group_3._frame_idx_registry + assert frame_idx_3 in session.frame_groups + assert len(session.frame_groups) == 3 + assert frame_group_3 == session.frame_groups[frame_idx_3] assert len(frame_group_3.instance_groups) == 0 # TODO(LM): Test underlying dictionaries more thoroughly