Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added frame reader iterator for .png or .jpg #11

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion demo/video_demo_with_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from mmdet.apis import init_detector
from mmdet.registry import VISUALIZERS
from mmcv.ops.nms import batched_nms
from demo.video_frame_loader import VideoFrameReader

import masa
from masa.apis import inference_masa, init_masa, inference_detector, build_test_pipeline
Expand Down Expand Up @@ -108,7 +109,13 @@ def main():
sam_model = sam_model_registry[args.sam_type](args.sam_path)
sam_predictor = SamPredictor(sam_model.to(device))

video_reader = mmcv.VideoReader(args.video)
if os.path.isfile(args.video):
video_reader = mmcv.VideoReader(args.video)
elif os.path.isdir(args.video):
video_reader = VideoFrameReader(args.video)
else:
raise Exception('Video should be a file or directory with frames ' + args.video)
# video_reader = mmcv.
video_writer = None

#### parsing the text input
Expand Down
78 changes: 78 additions & 0 deletions demo/video_frame_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os
import cv2

class VideoFrameReader:
    """Read a directory of image frames as if it were a video.

    The reader lists ``video_path``, keeps the files whose names end with
    one of ``extensions`` and serves them in sorted order, either by
    iterating over the reader or by calling it. Every frame is resized to
    the (even-rounded) dimensions of the first frame.

    Attributes:
        fps: Frame rate reported to consumers such as a video writer.
        video_path: Directory the frames are read from.
        extensions: Lower-cased tuple of accepted file-name suffixes.
        frame_list: Sorted file names of the accepted frames.
        frame_len: Number of accepted frames.
        read_count: Index of the next frame for sequential reads.
        height, width: Output frame size, taken from the first frame and
            rounded down to even numbers.
    """

    def __init__(self, video_path, extensions=('.jpg', '.png'), fps=12):
        """Index the frames found in ``video_path``.

        Args:
            video_path: Directory containing the frame images.
            extensions: File-name suffixes to accept; matched
                case-insensitively. A single string is treated as one
                suffix.
            fps: Frame rate to expose through ``self.fps``.

        Raises:
            ValueError: If no frame matches or the first frame is
                unreadable.
        """
        self.fps = fps  # for video writer
        self.video_path = video_path
        # A bare string would otherwise be iterated character by character.
        if isinstance(extensions, str):
            extensions = (extensions,)
        # Store an immutable, lower-cased copy so the default is never shared
        # or mutated and matching can ignore case.
        self.extensions = tuple(ext.lower() for ext in extensions)
        self.read_count = 0

        # sorted() orders the names lexicographically, so numbered frames
        # need zero-padded names to come out in numeric order.
        self.frame_list = [
            name for name in sorted(os.listdir(video_path))
            if self.has_extensions(name)
        ]
        self.frame_len = len(self.frame_list)
        print(f"{video_path} has {self.frame_len} frames")

        self.init_frame_metadata()

    def init_frame_metadata(self):
        """Set ``self.height`` and ``self.width`` from the first frame.

        Raises:
            ValueError: If there are no frames or the first frame cannot be
                read.
        """
        # Raise explicitly instead of asserting: asserts vanish under -O.
        if self.frame_len <= 0:
            raise ValueError('There should be at least one frame')

        # Read the first frame to get dimensions
        frame = cv2.imread(os.path.join(self.video_path, self.frame_list[0]))
        if frame is None:
            raise ValueError('The first frame could not be read, please check the frame paths.')

        height, width = frame.shape[:2]

        # Ensure dimensions are even (presumably for the benefit of the video
        # encoder used downstream -- confirm against the writer).
        self.height = height // 2 * 2
        self.width = width // 2 * 2

    def has_extensions(self, img_name):
        """Return True if ``img_name`` ends with an accepted suffix.

        The file name is lower-cased before comparison, so ``A.JPG`` matches
        the ``.jpg`` suffix.
        """
        return img_name.lower().endswith(tuple(self.extensions))

    def __len__(self):
        """Return the number of frames."""
        return self.frame_len

    def __iter__(self):
        """Restart sequential reading and return the reader itself."""
        self.read_count = 0  # Reset the read count for new iterations
        return self

    def __next__(self):
        """Return the next frame, or raise StopIteration past the last one."""
        if self.read_count >= self.frame_len:
            raise StopIteration
        frame = self(self.read_count)  # Get the current frame
        self.read_count += 1
        return frame

    def __call__(self, index=None):
        """Load and return one frame as an image array.

        Args:
            index: Zero-based frame index. When omitted, the frame at the
                sequential cursor is read and the cursor advances by one.

        Returns:
            The frame as returned by ``cv2.imread``, resized to
            ``(self.width, self.height)`` when its size differs.

        Raises:
            IndexError: If the index is outside ``[0, frame_len)``.
            ValueError: If the image file cannot be read.
        """
        sequential = index is None
        if sequential:
            index = self.read_count

        if index < 0 or index >= self.frame_len:
            raise IndexError(f'Trying to access frame beyond the bounds of video with len {self.frame_len}, and given index {index}')

        # Advance only after the bounds check so a failed read past the end
        # does not push the cursor further out of range.
        if sequential:
            self.read_count += 1

        frame_path = os.path.join(self.video_path, self.frame_list[index])
        frame = cv2.imread(frame_path)

        if frame is None:
            raise ValueError(f'Frame at path {frame_path} could not be read.')

        # Convert single-channel (grayscale) frames to 3-channel BGR.
        if len(frame.shape) == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)

        # Resize frame if necessary
        if self.width != frame.shape[1] or self.height != frame.shape[0]:
            frame = cv2.resize(frame, (self.width, self.height))

        return frame