Skip to content

Commit

Permalink
Add resolution grouping to image dataset import for DLC
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed Jan 19, 2025
1 parent 7785f66 commit 4df2010
Showing 1 changed file with 55 additions and 17 deletions.
72 changes: 55 additions & 17 deletions sleap/io/format/deeplabcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
import pandas as pd

from typing import List, Optional, Dict
from typing import List, Optional, Dict, Tuple

from sleap import Labels, Video, Skeleton
from sleap.instance import Instance, LabeledFrame, Point, Track
Expand Down Expand Up @@ -80,8 +80,21 @@ def read(
)

@classmethod
def make_video_for_image_list(cls, image_dir, filenames) -> Video:
"""Creates a Video object from frame images."""
def make_video_for_image_list(
cls, image_dir, filenames
) -> Tuple[List[Video], List[int], List[int]]:
"""Creates a Video object from frame images.
Args:
image_dir: Directory where images are stored.
filenames: List of image filenames.
Returns:
Tuple containing:
- List of Video objects created from the images.
- List of video indices for each image.
- List of frame indices for each image.
"""

# the image filenames in the csv may not match where the user has them
# so we'll change the directory to match where the user has the csv
Expand All @@ -91,9 +104,34 @@ def fix_img_path(img_dir, img_filename):
img_filename = os.path.join(img_dir, img_filename)
return img_filename

def get_shape(filename):
import cv2

img = cv2.imread(filename)
return img.shape[:2]

filenames = list(map(lambda f: fix_img_path(image_dir, f), filenames))

return Video.from_image_filenames(filenames)
# Group by shape.
shapes = list(map(get_shape, filenames))
imgs_by_shape = {}
for filename, shape in zip(filenames, shapes):
if shape not in imgs_by_shape:
imgs_by_shape[shape] = []
imgs_by_shape[shape].append(filename)

# Create videos for each shape group.
videos = []
video_inds = []
frame_inds = []
for video_ind, (shape, img_fns) in enumerate(imgs_by_shape.items()):
videos.append(
Video.from_image_filenames(img_fns, height=shape[0], width=shape[1])
)
video_inds.extend([video_ind] * len(img_fns))
frame_inds.extend(range(len(img_fns)))

return videos, video_inds, frame_inds

@classmethod
def read_frames(
Expand Down Expand Up @@ -147,23 +185,21 @@ def read_frames(
# Old format has filenames in a single column.
img_files = data.iloc[:, 0]

if full_video:
video = full_video
index_frames_by_original_index = True
else:
# Create the Video object
if not full_video:
# Create the Video objects grouped by shape
img_dir = os.path.dirname(filename)
video = cls.make_video_for_image_list(img_dir, img_files)

# The frames in the video we created will be indexed from 0 to N
# rather than having their index from the original source video.
index_frames_by_original_index = False
videos, video_inds, frame_inds = cls.make_video_for_image_list(
img_dir, img_files
)

lfs = []
for i in range(len(data)):

# Figure out frame index to use.
if index_frames_by_original_index:
# Figure out the video and frame index to use.
if full_video:
# Use the input provided one.
video = full_video

# Extract "0123" from "path/img0123.png" as original frame index.
frame_idx_match = re.search("(?<=img)(\\d+)(?=\\.png)", img_files[i])

Expand All @@ -174,7 +210,9 @@ def read_frames(
f"Unable to determine frame index for image {img_files[i]}"
)
else:
frame_idx = i
# Get from pregrouped list.
video = videos[video_inds[i]]
frame_idx = frame_inds[i]

instances = []
if is_multianimal:
Expand Down

0 comments on commit 4df2010

Please sign in to comment.