Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality to save segmentation mask during inference for images #50

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions SegGPT/SegGPT_inference/seggpt_engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import torch
import torch.nn.functional as F
import numpy as np
Expand Down Expand Up @@ -53,7 +55,7 @@ def run_one_image(img, tgt, model, device):
return output


def inference_image(model, device, img_path, img2_paths, tgt2_paths, out_path):
def inference_image(model, device, img_path, img2_paths, tgt2_paths, output_dir):
res, hres = 448, 448

image = Image.open(img_path).convert("RGB")
Expand Down Expand Up @@ -93,15 +95,25 @@ def inference_image(model, device, img_path, img2_paths, tgt2_paths, out_path):
"""### Run SegGPT on the image"""
# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)
output = run_one_image(img, tgt, model, device)
output = F.interpolate(
output[None, ...].permute(0, 3, 1, 2),
mask = run_one_image(img, tgt, model, device)
mask = F.interpolate(
mask[None, ...].permute(0, 3, 1, 2),
size=[size[1], size[0]],
mode='nearest',
).permute(0, 2, 3, 1)[0].numpy()
output = Image.fromarray((input_image * (0.6 * output / 255 + 0.4)).astype(np.uint8))
output = Image.fromarray((input_image * (0.6 * mask / 255 + 0.4)).astype(np.uint8))

img_name = os.path.basename(img_path)

# save the overlay image (input image modulated by the predicted mask)
out_path = os.path.join(output_dir, "output_" + '.'.join(img_name.split('.')[:-1]) + '.png')
output.save(out_path)

# save the predicted mask itself (resized model output cast to uint8; not thresholded here)
mask_path = os.path.join(output_dir, "mask_" + '.'.join(img_name.split('.')[:-1]) + '.png')
mask_image = Image.fromarray(mask.astype(np.uint8))
mask_image.save(mask_path)


def inference_video(model, device, vid_path, num_frames, img2_paths, tgt2_paths, out_path):
res, hres = 448, 448
Expand Down
13 changes: 7 additions & 6 deletions SegGPT/SegGPT_inference/seggpt_inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import os, gc
import argparse
from tqdm import tqdm

import torch
import numpy as np
Expand All @@ -18,7 +19,7 @@ def get_args_parser():
default='seggpt_vit_large.pth')
parser.add_argument('--model', type=str, help='dir to ckpt',
default='seggpt_vit_large_patch16_input896x448')
parser.add_argument('--input_image', type=str, help='path to input image to be tested',
parser.add_argument('--input_image', type=str, nargs='+', help='path to input image to be tested',
default=None)
parser.add_argument('--input_video', type=str, help='path to input video to be tested',
default=None)
Expand Down Expand Up @@ -59,10 +60,10 @@ def prepare_model(chkpt_dir, arch='seggpt_vit_large_patch16_input896x448', seg_t
if args.input_image is not None:
assert args.prompt_image is not None and args.prompt_target is not None

img_name = os.path.basename(args.input_image)
out_path = os.path.join(args.output_dir, "output_" + '.'.join(img_name.split('.')[:-1]) + '.png')

inference_image(model, device, args.input_image, args.prompt_image, args.prompt_target, out_path)
for image in tqdm(args.input_image):
inference_image(model, device, image, args.prompt_image, args.prompt_target, args.output_dir)
torch.cuda.empty_cache()
gc.collect()

if args.input_video is not None:
assert args.prompt_target is not None and len(args.prompt_target) == 1
Expand Down