From 69c894f58224dae639e28bbee44bceb97a5e65d4 Mon Sep 17 00:00:00 2001 From: Abdullah Meda Date: Fri, 14 Jul 2023 20:31:06 +0400 Subject: [PATCH 1/6] save mask separately --- SegGPT/SegGPT_inference/seggpt_engine.py | 15 ++++++++++++--- SegGPT/SegGPT_inference/seggpt_inference.py | 5 +---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/SegGPT/SegGPT_inference/seggpt_engine.py b/SegGPT/SegGPT_inference/seggpt_engine.py index 3e678d3..0428c6e 100644 --- a/SegGPT/SegGPT_inference/seggpt_engine.py +++ b/SegGPT/SegGPT_inference/seggpt_engine.py @@ -1,3 +1,5 @@ +import os + import torch import torch.nn.functional as F import numpy as np @@ -50,10 +52,10 @@ def run_one_image(img, tgt, model, device): output = y[0, y.shape[1]//2:, :, :] output = torch.clip((output * imagenet_std + imagenet_mean) * 255, 0, 255) - return output + return output, mask -def inference_image(model, device, img_path, img2_paths, tgt2_paths, out_path): +def inference_image(model, device, img_path, img2_paths, tgt2_paths, output_dir): res, hres = 448, 448 image = Image.open(img_path).convert("RGB") @@ -93,15 +95,22 @@ def inference_image(model, device, img_path, img2_paths, tgt2_paths, out_path): """### Run SegGPT on the image""" # make random mask reproducible (comment out to make it change) torch.manual_seed(2) - output = run_one_image(img, tgt, model, device) + output, mask = run_one_image(img, tgt, model, device) output = F.interpolate( output[None, ...].permute(0, 3, 1, 2), size=[size[1], size[0]], mode='nearest', ).permute(0, 2, 3, 1)[0].numpy() output = Image.fromarray((input_image * (0.6 * output / 255 + 0.4)).astype(np.uint8)) + + img_name = os.path.basename(img_path) + out_path = os.path.join(output_dir, "output_" + '.'.join(img_name.split('.')[:-1]) + '.png') output.save(out_path) + mask_path = os.path.join(output_dir, "mask_" + '.'.join(img_name.split('.')[:-1]) + '.png') + mask_image = Image.fromarray(mask.detach().cpu().numpy().astype(np.uint8)) + mask_image.save(mask_path) + def inference_video(model, device, vid_path, num_frames, img2_paths, tgt2_paths, out_path): res, hres = 448, 448 diff --git a/SegGPT/SegGPT_inference/seggpt_inference.py b/SegGPT/SegGPT_inference/seggpt_inference.py index 42000f0..4a4b90a 100644 --- a/SegGPT/SegGPT_inference/seggpt_inference.py +++ b/SegGPT/SegGPT_inference/seggpt_inference.py @@ -59,10 +59,7 @@ def prepare_model(chkpt_dir, arch='seggpt_vit_large_patch16_input896x448', seg_t if args.input_image is not None: assert args.prompt_image is not None and args.prompt_target is not None - img_name = os.path.basename(args.input_image) - out_path = os.path.join(args.output_dir, "output_" + '.'.join(img_name.split('.')[:-1]) + '.png') - - inference_image(model, device, args.input_image, args.prompt_image, args.prompt_target, out_path) + inference_image(model, device, args.input_image, args.prompt_image, args.prompt_target, args.output_dir) if args.input_video is not None: assert args.prompt_target is not None and len(args.prompt_target) == 1 From e05f02f604367462dab3bb357e9ef9a139615a0f Mon Sep 17 00:00:00 2001 From: Abdullah Meda Date: Fri, 14 Jul 2023 21:34:05 +0400 Subject: [PATCH 2/6] minor changes --- SegGPT/SegGPT_inference/seggpt_engine.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/SegGPT/SegGPT_inference/seggpt_engine.py b/SegGPT/SegGPT_inference/seggpt_engine.py index 0428c6e..3bc501a 100644 --- a/SegGPT/SegGPT_inference/seggpt_engine.py +++ b/SegGPT/SegGPT_inference/seggpt_engine.py @@ -52,7 +52,7 @@ def run_one_image(img, tgt, model, device): output = y[0, y.shape[1]//2:, :, :] output = torch.clip((output * imagenet_std + imagenet_mean) * 255, 0, 255) - return output, mask + return output def inference_image(model, device, img_path, img2_paths, tgt2_paths, output_dir): @@ -95,13 +95,13 @@ def inference_image(model, device, img_path, img2_paths, tgt2_paths, output_dir) """### Run SegGPT on the image""" # make random mask reproducible (comment out to make it change) torch.manual_seed(2) - output, mask = run_one_image(img, tgt, model, device) - output = F.interpolate( - output[None, ...].permute(0, 3, 1, 2), + mask = run_one_image(img, tgt, model, device) + maskmask = F.interpolate( + mask[None, ...].permute(0, 3, 1, 2), size=[size[1], size[0]], mode='nearest', ).permute(0, 2, 3, 1)[0].numpy() - output = Image.fromarray((input_image * (0.6 * output / 255 + 0.4)).astype(np.uint8)) + output = Image.fromarray((input_image * (0.6 * mask / 255 + 0.4)).astype(np.uint8)) img_name = os.path.basename(img_path) out_path = os.path.join(output_dir, "output_" + '.'.join(img_name.split('.')[:-1]) + '.png') From 0ef9c3245b839d3074df3fc62ea0b8753e516f26 Mon Sep 17 00:00:00 2001 From: Abdullah Meda Date: Fri, 14 Jul 2023 21:35:03 +0400 Subject: [PATCH 3/6] minor changes --- SegGPT/SegGPT_inference/seggpt_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SegGPT/SegGPT_inference/seggpt_engine.py b/SegGPT/SegGPT_inference/seggpt_engine.py index 3bc501a..13dd029 100644 --- a/SegGPT/SegGPT_inference/seggpt_engine.py +++ b/SegGPT/SegGPT_inference/seggpt_engine.py @@ -96,7 +96,7 @@ def inference_image(model, device, img_path, img2_paths, tgt2_paths, output_dir) # make random mask reproducible (comment out to make it change) torch.manual_seed(2) mask = run_one_image(img, tgt, model, device) - maskmask = F.interpolate( + mask = F.interpolate( mask[None, ...].permute(0, 3, 1, 2), size=[size[1], size[0]], mode='nearest', From ae8fef40f8a68a0892780d7742e6baf7af2f72a0 Mon Sep 17 00:00:00 2001 From: Abdullah Meda Date: Fri, 14 Jul 2023 21:57:25 +0400 Subject: [PATCH 4/6] minor changes --- SegGPT/SegGPT_inference/seggpt_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SegGPT/SegGPT_inference/seggpt_engine.py b/SegGPT/SegGPT_inference/seggpt_engine.py index 13dd029..f86a911 100644 --- a/SegGPT/SegGPT_inference/seggpt_engine.py +++ b/SegGPT/SegGPT_inference/seggpt_engine.py @@ -108,7 +108,7 @@ def inference_image(model, device, img_path, img2_paths, tgt2_paths, output_dir) output.save(out_path) mask_path = os.path.join(output_dir, "mask_" + '.'.join(img_name.split('.')[:-1]) + '.png') - mask_image = Image.fromarray(mask.detach().cpu().numpy().astype(np.uint8)) + mask_image = Image.fromarray(mask.astype(np.uint8)) mask_image.save(mask_path) From f15909760427bdb1a1c8e6c187a35a68f641b6e9 Mon Sep 17 00:00:00 2001 From: Abdullah Meda Date: Sun, 16 Jul 2023 18:27:51 +0400 Subject: [PATCH 5/6] minor changes --- SegGPT/SegGPT_inference/seggpt_engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/SegGPT/SegGPT_inference/seggpt_engine.py b/SegGPT/SegGPT_inference/seggpt_engine.py index f86a911..afdf3e5 100644 --- a/SegGPT/SegGPT_inference/seggpt_engine.py +++ b/SegGPT/SegGPT_inference/seggpt_engine.py @@ -104,9 +104,12 @@ def inference_image(model, device, img_path, img2_paths, tgt2_paths, output_dir) output = Image.fromarray((input_image * (0.6 * mask / 255 + 0.4)).astype(np.uint8)) img_name = os.path.basename(img_path) + + # save segmented output out_path = os.path.join(output_dir, "output_" + '.'.join(img_name.split('.')[:-1]) + '.png') output.save(out_path) + # save binary mask mask_path = os.path.join(output_dir, "mask_" + '.'.join(img_name.split('.')[:-1]) + '.png') mask_image = Image.fromarray(mask.astype(np.uint8)) mask_image.save(mask_path) From bfb6a78d6066f194efa711340fbc3b578b48ee47 Mon Sep 17 00:00:00 2001 From: Abdullah Meda Date: Sun, 13 Aug 2023 23:07:58 +0400 Subject: [PATCH 6/6] multi input image --- SegGPT/SegGPT_inference/seggpt_inference.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/SegGPT/SegGPT_inference/seggpt_inference.py b/SegGPT/SegGPT_inference/seggpt_inference.py index 4a4b90a..6a36a78 100644 --- a/SegGPT/SegGPT_inference/seggpt_inference.py +++ b/SegGPT/SegGPT_inference/seggpt_inference.py @@ -1,5 +1,6 @@ -import os +import os, gc import argparse +from tqdm import tqdm import torch import numpy as np @@ -18,7 +19,7 @@ def get_args_parser(): default='seggpt_vit_large.pth') parser.add_argument('--model', type=str, help='dir to ckpt', default='seggpt_vit_large_patch16_input896x448') - parser.add_argument('--input_image', type=str, help='path to input image to be tested', + parser.add_argument('--input_image', type=str, nargs='+', help='path to input image to be tested', default=None) parser.add_argument('--input_video', type=str, help='path to input video to be tested', default=None) @@ -59,7 +60,10 @@ def prepare_model(chkpt_dir, arch='seggpt_vit_large_patch16_input896x448', seg_t if args.input_image is not None: assert args.prompt_image is not None and args.prompt_target is not None - inference_image(model, device, args.input_image, args.prompt_image, args.prompt_target, args.output_dir) + for image in tqdm(args.input_image): + inference_image(model, device, image, args.prompt_image, args.prompt_target, args.output_dir) + torch.cuda.empty_cache() + gc.collect() if args.input_video is not None: assert args.prompt_target is not None and len(args.prompt_target) == 1