diff --git a/tools/inference_video.py b/tools/inference_video.py new file mode 100644 index 0000000..3ad3c06 --- /dev/null +++ b/tools/inference_video.py @@ -0,0 +1,95 @@ +import argparse +import warnings + +import cv2 +import mmcv +import ssod +from mmcv.runner import load_checkpoint +from mmdet.apis import inference_detector, init_detector +from mmdet.core import get_classes + +config_name = 'configs/consistent-teacher/consistent_teacher_r50_fpn_coco_720k_fulldata.py' +checkpoint = '../ckpt_logs/ConsistentTeacher/consistent_teacher_r50_fpn_coco_720k_fulldata/consistent_teacher_r50_fpn_coco_720k_fulldata_iter_720000-d932808f.pth' +video = '/home/yangxingyi/视频/vlc-record-2023-03-11-21h13m23s-videoplayback.mp4-.mp4' +PALETTE = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), + (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70), + (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0), + (175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255), + (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157), + (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118), + (255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182), + (0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255), + (78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255), + (134, 134, 103), (145, 148, 174), (255, 208, 186), + (197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255), + (151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105), + (166, 196, 102), (208, 195, 210), (255, 109, 65), (0, 143, 149), + (179, 0, 194), (209, 99, 106), (5, 121, 0), (227, 255, 205), + (147, 186, 208), (153, 69, 1), (3, 95, 161), (163, 255, 0), + (119, 0, 170), (0, 182, 199), (0, 165, 120), (183, 130, 88), + (95, 32, 0), (130, 114, 135), (110, 129, 133), (166, 74, 118), + (219, 142, 185), (79, 210, 114), (178, 90, 62), (65, 70, 15), + (127, 167, 115), (59, 105, 106), (142, 108, 45), (196, 172, 0), + (95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122), + (191, 162, 208)] + +def parse_args(): + parser = argparse.ArgumentParser(description='MMDetection video demo') + parser.add_argument('--video', default=video, help='Video file') + parser.add_argument('--config', default=config_name, help='Config file') + parser.add_argument('--checkpoint', default=checkpoint, help='Checkpoint file') + parser.add_argument( + '--device', default='cpu', help='Device used for inference') + parser.add_argument( + '--score-thr', type=float, default=0.3, help='Bbox score threshold') + parser.add_argument('--out', type=str, help='Output video file') + parser.add_argument('--show', action='store_true', help='Show video') + parser.add_argument( + '--wait-time', + type=float, + default=1, + help='The interval of show (s), 0 is block') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + assert args.out or args.show, \ + ('Please specify at least one operation (save/show the ' + 'video) with the argument "--out" or "--show"') + + model = init_detector(args.config, checkpoint=None, device=args.device) + checkpoint = load_checkpoint(model, args.checkpoint, revise_keys=[(r'^teacher\.', '')]) + if 'CLASSES' in checkpoint.get('meta', {}): + model.CLASSES = checkpoint['meta']['CLASSES'] + else: + warnings.simplefilter('once') + warnings.warn('Class names are not saved in the checkpoint\'s ' + 'meta data, use COCO classes by default.') + model.CLASSES = get_classes('coco') + + video_reader = mmcv.VideoReader(args.video) + video_writer = None + if args.out: + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + video_writer = cv2.VideoWriter( + args.out, fourcc, video_reader.fps, + (video_reader.width, video_reader.height)) + + for frame in mmcv.track_iter_progress(video_reader): + result = inference_detector(model, frame) + frame = model.show_result(frame, result, bbox_color=PALETTE, score_thr=args.score_thr) + if args.show: + cv2.namedWindow('video', 0) + mmcv.imshow(frame, 'video', args.wait_time) + if args.out: + video_writer.write(frame) + + if video_writer: + video_writer.release() + cv2.destroyAllWindows() + + +if __name__ == '__main__': + main()