Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization of function for converting mask images to uint8 type #426

Open
wants to merge 7 commits into
base: Main
Choose a base branch
from
108 changes: 83 additions & 25 deletions modules/impact/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import cv2
import time
from impact import utils
from impact.uniformers import ensure_nhwc_mask_torch, ensure_nhwc_mask_numpy

SEG = namedtuple("SEG",
['cropped_image', 'cropped_mask', 'confidence', 'crop_region', 'bbox', 'label', 'control_net_wrapper'],
Expand Down Expand Up @@ -814,30 +815,47 @@ def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation,
mask = combine_masks2(total_masks)

finally:
mask_working_device = torch.device("cpu")
if sam_model.is_auto_mode:
sam_model.cpu()

pass

mask_working_device = torch.device("cpu")

if mask is not None:
mask = mask.float()
mask = dilate_mask(mask.cpu().numpy(), dilation)
mask = torch.from_numpy(mask)
mask = mask.to(device=mask_working_device)
else:
# Extracting batch, height and width
height, width, _ = image.shape
mask = torch.zeros(
(height, width), dtype=torch.float32, device=mask_working_device
) # empty mask

stacked_masks = convert_and_stack_masks(total_masks)

return (mask, merge_and_stack_masks(stacked_masks, group_size=3))
# return every_three_pick_last(stacked_masks)

if mask is not None:
mask = mask.float()
# Convert to CPU and add one channel dimension to the mask at the end
mask = np.expand_dims(dilate_mask(mask.cpu().numpy(), dilation), axis=-1)
mask = np.expand_dims(mask, axis=0) # Add a batch dimension to the mask at the beginning
mask = torch.from_numpy(mask)
mask = mask.to(device=mask_working_device)
else:
# Extract the batch, height, and width
height, width, _ = image.shape
# Create an empty mask with the shape (N, H, W, 1), where N is the batch size, set to 1
mask = torch.zeros(
(1, height, width, 1), dtype=torch.float32, device=mask_working_device
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason for making the mask 4-dim here? The masks used in ComfyUI are 3-dim masks (b, h, w)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The review was pending due to my mistake.

)

# Handle the stacked_masks at the return statement location
batch_masks = None

if not total_masks:
height, width, _ = image.shape
# Create a blank mask that matches the size of the input image
stacked_masks = np.zeros((1, height, width, 1), dtype=np.float32)
# As there is only one image, so batch_masks is equivalent to stacked_masks
batch_masks = stacked_masks
else:
# Attempt to convert and stack masks
stacked_masks = convert_and_stack_masks(total_masks)
if stacked_masks is not None:
stacked_masks = np.transpose(stacked_masks, (0, 2, 3, 1))
batch_masks = merge_and_stack_masks(stacked_masks, group_size=3)
else:
# If None is returned, create a blank mask that matches the size of the input image
height, width, _ = image.shape
stacked_masks = np.zeros((1, height, width, 1), dtype=np.float32)
batch_masks = stacked_masks

combined_mask = mask
return (combined_mask, batch_masks)

def segs_bitwise_and_mask(segs, mask):
mask = make_2d_mask(mask)
Expand Down Expand Up @@ -961,8 +979,46 @@ def detect_combined(self, image, threshold, dilation):
def setAux(self, x):
pass

def optimized_mask_to_uint8(mask):
    """
    Convert a mask array to uint8 in the range [0, 255].

    Values are clamped to [0, 1] before scaling, and a leading singleton
    batch axis is squeezed away so the result is directly usable by OpenCV
    routines such as cv2.findContours.

    Args:
        mask: A numpy array of any shape holding mask values, nominally in [0, 1].

    Returns:
        The converted uint8 numpy array (with a leading batch axis of size 1
        removed when the input is 3-dimensional), or None if conversion fails.
    """
    try:
        # Ensure the input is a numpy array
        if not isinstance(mask, np.ndarray):
            raise ValueError("The type of the input mask needs to be numpy.ndarray!")

        # Clamp to [0, 1]; warn when values fall outside the expected range.
        if np.any(mask < 0) or np.any(mask > 1):
            print("Warning: Values within the mask are out of the range [0,1], will be clamped.")
        mask_clamped = np.clip(mask, 0, 1)

        # Scale to [0, 255]. The values are already clamped to [0, 1], so no
        # overflow is possible and no intermediate float32 cast is needed.
        mask_uint8 = (mask_clamped * 255).astype(np.uint8)

        # Drop a leading singleton batch axis ([1, H, W] -> [H, W]).
        if mask_uint8.ndim == 3 and mask_uint8.shape[0] == 1:
            mask_uint8 = mask_uint8.squeeze(0)

        return mask_uint8

    except Exception as e:
        print(f"An error occurred during the process of optimizing mask to uint8: {repr(e)}")
        return None

def mask_to_segs(mask, combined, crop_factor, bbox_fill, drop_size=1, label='A', crop_min_size=None, detailer_hook=None, is_contour=True):
print(f'mask shape: {mask.shape}')

drop_size = max(drop_size, 1)
if mask is None:
print("[mask_to_segs] Cannot operate: MASK is empty.")
Expand All @@ -983,8 +1039,10 @@ def mask_to_segs(mask, combined, crop_factor, bbox_fill, drop_size=1, label='A',

result = []

if len(mask.shape) == 2:
mask = np.expand_dims(mask, axis=0)
# make sure the mask is in NHWC format
mask = ensure_nhwc_mask_numpy(mask)
# then we need to remove the channel dimension
mask = mask.squeeze(-1)

for i in range(mask.shape[0]):
mask_i = mask[i]
Expand Down Expand Up @@ -1019,7 +1077,7 @@ def mask_to_segs(mask, combined, crop_factor, bbox_fill, drop_size=1, label='A',
result.append(item)

else:
mask_i_uint8 = (mask_i * 255.0).astype(np.uint8)
mask_i_uint8 = optimized_mask_to_uint8(mask_i)
contours, ctree = cv2.findContours(mask_i_uint8, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
for j, contour in enumerate(contours):
hierarchy = ctree[0][j]
Expand Down
90 changes: 90 additions & 0 deletions modules/impact/uniformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import numpy as np
import torch


def ensure_nhwc_mask_torch(masks):
    """
    Ensure that the input masks tensor is in NHWC format ([N, H, W, C]).

    Accepted input layouts:
        [N, C, H, W] -> permuted to [N, H, W, C]
        [N, H, W, C] -> returned unchanged
        [1, H, W]    -> [1, H, W, 1]
        [H, W]       -> [1, H, W, 1]
        [H, W, C]    -> [1, H, W, C]

    Returns None (after printing an error) when the input is not a
    torch.Tensor with at least two dimensions, or when a 3-dim input
    cannot be interpreted as an HWC mask.
    """

    if masks is None or not isinstance(masks, torch.Tensor) or masks.ndim < 2:
        print(
            "[ERROR] The input masks are not in the expected format. The required types are torch.Tensor with at least two dimensions."
        )
        print(
            " - If it's a list or one-dimensional array, please ensure it's been transformed into an array or tensor with at least two dimensions."
        )
        print(" - If the masks is null, ensure to provide non-empty input.")
        return None

    if masks.ndim == 4:
        # Heuristic: when the last axis looks like a channel count and the
        # middle axes look like pixel dimensions, assume the tensor is
        # already NHWC and return it unchanged; otherwise treat it as NCHW
        # and permute. (Mirrors ensure_nhwc_mask_numpy so the torch and
        # numpy helpers agree; previously both branches permuted, which
        # corrupted tensors that were already NHWC.)
        N, H, W, C = masks.shape
        if C in [1, 3] and H > 3 and W > 3:
            return masks
        # [N, C, H, W] -> [N, H, W, C]
        return masks.permute(0, 2, 3, 1)
    # [1, H, W] -> [1, H, W, 1]
    elif masks.ndim == 3 and masks.shape[0] == 1:
        return masks.unsqueeze(-1)
    # [H, W] -> [1, H, W, 1]
    elif masks.ndim == 2:
        return masks.unsqueeze(0).unsqueeze(-1)
    # [H, W, C] -> [1, H, W, C]
    elif masks.ndim == 3:
        H, W, C = masks.shape
        if C in [1, 3] and H > 3 and W > 3:
            # Masks are in the HWC format, need to add a batch dimension.
            return masks.unsqueeze(0)
        else:
            print(
                "[ERROR] The three-dimensional input tensor [H, W, C] is not in the correct shape. Please ensure that the C is between [1, 3], and H and W are representing the width and height of the pixel."
            )
            return None

    return None


def ensure_nhwc_mask_numpy(masks):
    """
    Normalize a numpy mask array to NHWC layout ([N, H, W, C]).

    Returns the reshaped array, or None (after printing an error) when the
    input is not an np.ndarray with at least two dimensions, or when a
    3-dim input cannot be interpreted as an HWC mask.
    """

    valid_input = isinstance(masks, np.ndarray) and masks.ndim >= 2
    if not valid_input:
        print(
            "[ERROR] The input masks are not in the expected format. The required types are np.ndarray with at least two dimensions."
        )
        print(
            " - If it's a list or one-dimensional array, please ensure it's been transformed into an array with at least two dimensions."
        )
        print(" - If the masks is null, ensure to provide non-empty input.")
        return None

    rank = masks.ndim

    if rank == 2:
        # [H, W] -> [1, H, W, 1]
        return masks[np.newaxis, ..., np.newaxis]

    if rank == 4:
        # Looks-like-NHWC heuristic: a small channel count in the last axis
        # and pixel-sized middle axes mean the array is already NHWC;
        # otherwise assume NCHW and move the channel axis to the end.
        _, H, W, C = masks.shape
        if C in [1, 3] and H > 3 and W > 3:
            return masks
        return np.transpose(masks, (0, 2, 3, 1))

    if rank == 3:
        if masks.shape[0] == 1:
            # [1, H, W] -> [1, H, W, 1]
            return masks[..., np.newaxis]
        H, W, C = masks.shape
        if C in [1, 3] and H > 3 and W > 3:
            # [H, W, C] -> [1, H, W, C]: just prepend a batch axis.
            return masks[np.newaxis, ...]
        print(
            "[ERROR] The input array of three dimensions [H, W, C] is not in the correct shape. Please ensure that the C is between [1, 3], and H and W are representing the width and height of the pixel."
        )
        return None

    return None