diff --git a/SegGPT/SegGPT_inference/seggpt_engine.py b/SegGPT/SegGPT_inference/seggpt_engine.py
index 3e678d3..afdf3e5 100644
--- a/SegGPT/SegGPT_inference/seggpt_engine.py
+++ b/SegGPT/SegGPT_inference/seggpt_engine.py
@@ -1,3 +1,5 @@
+import os
+
 import torch
 import torch.nn.functional as F
 import numpy as np
@@ -53,7 +55,7 @@ def run_one_image(img, tgt, model, device):
     return output
 
 
-def inference_image(model, device, img_path, img2_paths, tgt2_paths, out_path):
+def inference_image(model, device, img_path, img2_paths, tgt2_paths, output_dir):
     res, hres = 448, 448
 
     image = Image.open(img_path).convert("RGB")
@@ -93,15 +95,25 @@ def inference_image(model, device, img_path, img2_paths, tgt2_paths, out_path):
     """### Run SegGPT on the image"""
     # make random mask reproducible (comment out to make it change)
     torch.manual_seed(2)
-    output = run_one_image(img, tgt, model, device)
-    output = F.interpolate(
-        output[None, ...].permute(0, 3, 1, 2),
+    mask = run_one_image(img, tgt, model, device)
+    mask = F.interpolate(
+        mask[None, ...].permute(0, 3, 1, 2),
         size=[size[1], size[0]],
         mode='nearest',
     ).permute(0, 2, 3, 1)[0].numpy()
-    output = Image.fromarray((input_image * (0.6 * output / 255 + 0.4)).astype(np.uint8))
+    output = Image.fromarray((input_image * (0.6 * mask / 255 + 0.4)).astype(np.uint8))
+
+    img_name = os.path.basename(img_path)
+
+    # save segmented output
+    out_path = os.path.join(output_dir, "output_" + '.'.join(img_name.split('.')[:-1]) + '.png')
     output.save(out_path)
 
+    # save binary mask
+    mask_path = os.path.join(output_dir, "mask_" + '.'.join(img_name.split('.')[:-1]) + '.png')
+    mask_image = Image.fromarray(mask.astype(np.uint8))
+    mask_image.save(mask_path)
+
 
 def inference_video(model, device, vid_path, num_frames, img2_paths, tgt2_paths, out_path):
     res, hres = 448, 448
diff --git a/SegGPT/SegGPT_inference/seggpt_inference.py b/SegGPT/SegGPT_inference/seggpt_inference.py
index 42000f0..6a36a78 100644
--- a/SegGPT/SegGPT_inference/seggpt_inference.py
+++ b/SegGPT/SegGPT_inference/seggpt_inference.py
@@ -1,5 +1,6 @@
-import os
+import os, gc
 import argparse
+from tqdm import tqdm
 
 import torch
 import numpy as np
@@ -18,7 +19,7 @@ def get_args_parser():
                         default='seggpt_vit_large.pth')
     parser.add_argument('--model', type=str, help='dir to ckpt',
                         default='seggpt_vit_large_patch16_input896x448')
-    parser.add_argument('--input_image', type=str, help='path to input image to be tested',
+    parser.add_argument('--input_image', type=str, nargs='+', help='path to input image to be tested',
                         default=None)
     parser.add_argument('--input_video', type=str, help='path to input video to be tested',
                         default=None)
@@ -59,10 +60,10 @@ def prepare_model(chkpt_dir, arch='seggpt_vit_large_patch16_input896x448', seg_t
     if args.input_image is not None:
         assert args.prompt_image is not None and args.prompt_target is not None
 
-        img_name = os.path.basename(args.input_image)
-        out_path = os.path.join(args.output_dir, "output_" + '.'.join(img_name.split('.')[:-1]) + '.png')
-
-        inference_image(model, device, args.input_image, args.prompt_image, args.prompt_target, out_path)
+        for image in tqdm(args.input_image):
+            inference_image(model, device, image, args.prompt_image, args.prompt_target, args.output_dir)
+            torch.cuda.empty_cache()
+            gc.collect()
 
     if args.input_video is not None:
         assert args.prompt_target is not None and len(args.prompt_target) == 1
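
Usage note: with nargs='+' on --input_image, one run now segments any number of images, and inference_image writes two files per image into --output_dir (an output_<name>.png overlay and a mask_<name>.png mask); torch.cuda.empty_cache() and gc.collect() free GPU memory between images. A minimal sketch of an invocation, with placeholder image and prompt paths (only the --ckpt_path default, seggpt_vit_large.pth, is taken from the argparse context above):

    python seggpt_inference.py \
        --ckpt_path seggpt_vit_large.pth \
        --input_image img1.jpg img2.jpg img3.jpg \
        --prompt_image prompt.jpg \
        --prompt_target prompt_mask.png \
        --output_dir ./outputs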