diff --git a/sleap/io/format/deeplabcut.py b/sleap/io/format/deeplabcut.py
index 5892dba1a..d5d008ca2 100644
--- a/sleap/io/format/deeplabcut.py
+++ b/sleap/io/format/deeplabcut.py
@@ -19,7 +19,7 @@
 import numpy as np
 import pandas as pd
 
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, Tuple
 
 from sleap import Labels, Video, Skeleton
 from sleap.instance import Instance, LabeledFrame, Point, Track
@@ -80,8 +80,21 @@ def read(
         )
 
     @classmethod
-    def make_video_for_image_list(cls, image_dir, filenames) -> Video:
-        """Creates a Video object from frame images."""
+    def make_video_for_image_list(
+        cls, image_dir, filenames
+    ) -> Tuple[List[Video], List[int], List[int]]:
+        """Creates Video objects from frame images, grouped by image shape.
+
+        Args:
+            image_dir: Directory where images are stored.
+            filenames: List of image filenames.
+
+        Returns:
+            Tuple containing:
+            - List of Video objects created from the images (one per shape).
+            - List of video indices for each image, in input order.
+            - List of frame indices for each image, in input order.
+        """
 
         # the image filenames in the csv may not match where the user has them
         # so we'll change the directory to match where the user has the csv
@@ -91,9 +104,39 @@ def fix_img_path(img_dir, img_filename):
             img_filename = os.path.join(img_dir, img_filename)
             return img_filename
 
+        def get_shape(filename):
+            import cv2
+
+            img = cv2.imread(filename)
+            if img is None:
+                # cv2.imread returns None instead of raising on unreadable files.
+                raise ValueError(f"Unable to read image {filename}")
+            return img.shape[:2]
+
         filenames = list(map(lambda f: fix_img_path(image_dir, f), filenames))
 
-        return Video.from_image_filenames(filenames)
+        # Group by shape, recording each image's video and frame index in the
+        # original order so that callers can look them up by csv row.
+        imgs_by_shape = {}
+        video_ind_by_shape = {}
+        video_inds = []
+        frame_inds = []
+        for filename in filenames:
+            shape = get_shape(filename)
+            if shape not in imgs_by_shape:
+                video_ind_by_shape[shape] = len(imgs_by_shape)
+                imgs_by_shape[shape] = []
+            video_inds.append(video_ind_by_shape[shape])
+            frame_inds.append(len(imgs_by_shape[shape]))
+            imgs_by_shape[shape].append(filename)
+
+        # Create one video per shape group (dict order matches video indices).
+        videos = [
+            Video.from_image_filenames(img_fns, height=shape[0], width=shape[1])
+            for shape, img_fns in imgs_by_shape.items()
+        ]
+
+        return videos, video_inds, frame_inds
 
     @classmethod
     def read_frames(
@@ -147,23 +190,21 @@ def read_frames(
             # Old format has filenames in a single column.
             img_files = data.iloc[:, 0]
 
-        if full_video:
-            video = full_video
-            index_frames_by_original_index = True
-        else:
-            # Create the Video object
+        if not full_video:
+            # Create the Video objects grouped by shape
             img_dir = os.path.dirname(filename)
-            video = cls.make_video_for_image_list(img_dir, img_files)
-
-            # The frames in the video we created will be indexed from 0 to N
-            # rather than having their index from the original source video.
-            index_frames_by_original_index = False
+            videos, video_inds, frame_inds = cls.make_video_for_image_list(
+                img_dir, img_files
+            )
 
         lfs = []
         for i in range(len(data)):
 
-            # Figure out frame index to use.
-            if index_frames_by_original_index:
+            # Figure out the video and frame index to use.
+            if full_video:
+                # Use the input provided one.
+                video = full_video
+
                 # Extract "0123" from "path/img0123.png" as original frame index.
                 frame_idx_match = re.search("(?<=img)(\\d+)(?=\\.png)", img_files[i])
 
@@ -174,7 +215,9 @@ def read_frames(
                         f"Unable to determine frame index for image {img_files[i]}"
                     )
             else:
-                frame_idx = i
+                # Get from pregrouped list.
+                video = videos[video_inds[i]]
+                frame_idx = frame_inds[i]
 
             instances = []
             if is_multianimal: