adding auto1111 features to inpainting pipeline #6072

Merged · 6 commits · Dec 26, 2023
Changes from 3 commits
209 changes: 164 additions & 45 deletions src/diffusers/image_processor.py
@@ -18,7 +18,7 @@
 import numpy as np
 import PIL.Image
 import torch
-from PIL import Image
+from PIL import Image, ImageFilter
 
 from .configuration_utils import ConfigMixin, register_to_config
 from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
@@ -166,72 +166,126 @@ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:

         return image
 
-    def get_default_height_width(
+    @staticmethod
+    def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
+        """
+        Blurs an image.
+        """
+        image = image.filter(ImageFilter.GaussianBlur(blur_factor))
+
+        return image
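
A minimal usage sketch of the new static method, assuming it lives on diffusers' VaeImageProcessor (the class defined in the image_processor.py file being patched) and that mask.png is a placeholder path:

    from PIL import Image
    from diffusers.image_processor import VaeImageProcessor

    # Soften the edges of a binary inpainting mask before handing it to the pipeline.
    mask = Image.open("mask.png").convert("L")
    soft_mask = VaeImageProcessor.blur(mask, blur_factor=8)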

+    def _resize_and_fill(
         self,
-        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-    ) -> Tuple[int, int]:
+        image: PIL.Image.Image,
+        width: int,
+        height: int,
+    ) -> PIL.Image.Image:
         """
-        This function return the height and width that are downscaled to the next integer multiple of
-        `vae_scale_factor`.
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling the empty space with data from the image.
 
         Args:
-            image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
-                The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
-                shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
-                have shape `[batch, channel, height, width]`.
-            height (`int`, *optional*, defaults to `None`):
-                The height in preprocessed image. If `None`, will use the height of `image` input.
-            width (`int`, *optional*`, defaults to `None`):
-                The width in preprocessed. If `None`, will use the width of the `image` input.
+            image: The image to resize.
+            width: The width to resize the image to.
+            height: The height to resize the image to.
         """
 
+        ratio = width / height
+        src_ratio = image.width / image.height
+
+        src_w = width if ratio < src_ratio else image.width * height // image.height
+        src_h = height if ratio >= src_ratio else image.height * width // image.width
+
+        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
+        res = Image.new("RGB", (width, height))
+        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+        if ratio < src_ratio:
+            fill_height = height // 2 - src_h // 2
+            if fill_height > 0:
+                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+                res.paste(
+                    resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
+                    box=(0, fill_height + src_h),
+                )
+        elif ratio > src_ratio:
+            fill_width = width // 2 - src_w // 2
+            if fill_width > 0:
+                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+                res.paste(
+                    resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
+                    box=(fill_width + src_w, 0),
+                )
+
+        return res
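
A worked sketch of what the fill path does, assuming it is reached through the public resize method below on a VaeImageProcessor instance:

    from PIL import Image
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor()
    src = Image.new("RGB", (100, 50))  # 2:1 source going into a 1:1 target

    out = processor.resize(src, height=64, width=64, resize_mode="fill")
    # The source is scaled to 64x32 to preserve its aspect ratio; the 16-pixel
    # bands above and below are filled by stretching the edge rows of the
    # resized image outward.
    assert out.size == (64, 64)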

+    def _resize_and_crop(
+        self,
+        image: PIL.Image.Image,
+        width: int,
+        height: int,
+    ) -> PIL.Image.Image:
+        """
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
 
-        if height is None:
-            if isinstance(image, PIL.Image.Image):
-                height = image.height
-            elif isinstance(image, torch.Tensor):
-                height = image.shape[2]
-            else:
-                height = image.shape[1]
+        Args:
+            image: The image to resize.
+            width: The width to resize the image to.
+            height: The height to resize the image to.
+        """
+        ratio = width / height
+        src_ratio = image.width / image.height
 
-        if width is None:
-            if isinstance(image, PIL.Image.Image):
-                width = image.width
-            elif isinstance(image, torch.Tensor):
-                width = image.shape[3]
-            else:
-                width = image.shape[2]
+        src_w = width if ratio > src_ratio else image.width * height // image.height
+        src_h = height if ratio <= src_ratio else image.height * width // image.width
 
-        width, height = (
-            x - x % self.config.vae_scale_factor for x in (width, height)
-        )  # resize to integer multiple of vae_scale_factor
-
-        return height, width
+        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
+        res = Image.new("RGB", (width, height))
+        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+        return res
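
And the crop path under the same assumptions; the same 2:1 source now overshoots the 1:1 target and loses its left and right edges instead:

    from PIL import Image
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor()
    src = Image.new("RGB", (100, 50))

    out = processor.resize(src, height=64, width=64, resize_mode="crop")
    # The source is scaled to 128x64 to cover the target, then the centered
    # paste crops 32 pixels from each side.
    assert out.size == (64, 64)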

     def resize(
         self,
         image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: int,
+        width: int,
Comment on lines +338 to +339
Contributor:

Was it incorrect to have height and width be optional before? Did it throw an error when they weren't passed?

OK to remove Optional[int] = None if one had to pass them anyway; otherwise it would be backwards-breaking.

Collaborator (Author):

It was incorrect before and would throw an error when they were not passed.

This is the function in the current codebase:

    def resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        Resize image.

        Args:
            image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
                The image input, can be a PIL image, numpy array or pytorch tensor.
            height (`int`, *optional*, defaults to `None`):
                The height to resize to.
            width (`int`, *optional*`, defaults to `None`):
                The width to resize to.

        Returns:
            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                The resized image.
        """
        if isinstance(image, PIL.Image.Image):
            image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
        elif isinstance(image, torch.Tensor):
            image = torch.nn.functional.interpolate(
                image,
                size=(height, width),
            )
        elif isinstance(image, np.ndarray):
            image = self.numpy_to_pt(image)
            image = torch.nn.functional.interpolate(
                image,
                size=(height, width),
            )
            image = self.pt_to_numpy(image)
        return image

+        resize_mode: str = "default",  # "default", "fill", "crop"
     ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Resize image.

Args:
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
The image input, can be a PIL image, numpy array or pytorch tensor.
height (`int`, *optional*, defaults to `None`):
height (`int`):
The height to resize to.
width (`int`, *optional*`, defaults to `None`):
width (`int`):
The width to resize to.
resize_mode (`str`, *optional*, defaults to `default`):
The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
within the specified width and height, and it may not maintaining the original aspect ratio.
If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
within the dimensions, filling empty with data from image.
If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
within the dimensions, cropping the excess.
Note that resize_mode `fill` and `crop` are only supported for PIL image input.

Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
The resized image.
"""
+        if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
+            raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
         if isinstance(image, PIL.Image.Image):
-            image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+            if resize_mode == "default":
+                image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+            elif resize_mode == "fill":
+                image = self._resize_and_fill(image, width, height)
+            elif resize_mode == "crop":
+                image = self._resize_and_crop(image, width, height)
+            else:
+                raise ValueError(f"resize_mode {resize_mode} is not supported")
+
         elif isinstance(image, torch.Tensor):
             image = torch.nn.functional.interpolate(
                 image,
@@ -262,14 +316,77 @@ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
         image[image >= 0.5] = 1
         return image

+    def get_default_height_width(
+        self,
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+    ) -> Tuple[int, int]:
+        """
+        This function return the height and width that are downscaled to the next integer multiple of
+        `vae_scale_factor`.
+
+        Args:
+            image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
+                The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
+                shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
+                have shape `[batch, channel, height, width]`.
+            height (`int`, *optional*, defaults to `None`):
+                The height in preprocessed image. If `None`, will use the height of `image` input.
+            width (`int`, *optional*`, defaults to `None`):
+                The width in preprocessed. If `None`, will use the width of the `image` input.
+        """
+
+        if height is None:
+            if isinstance(image, PIL.Image.Image):
+                height = image.height
+            elif isinstance(image, torch.Tensor):
+                height = image.shape[2]
+            else:
+                height = image.shape[1]
+
+        if width is None:
+            if isinstance(image, PIL.Image.Image):
+                width = image.width
+            elif isinstance(image, torch.Tensor):
+                width = image.shape[3]
+            else:
+                width = image.shape[2]
+
+        width, height = (
+            x - x % self.config.vae_scale_factor for x in (width, height)
+        )  # resize to integer multiple of vae_scale_factor
+
+        return height, width
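
A quick numeric sketch of the snapping behaviour, assuming the default vae_scale_factor of 8:

    from PIL import Image
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor()  # vae_scale_factor defaults to 8
    # Odd dimensions are rounded down to the nearest multiple of 8.
    height, width = processor.get_default_height_width(Image.new("RGB", (513, 769)))
    assert (height, width) == (768, 512)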

     def preprocess(
         self,
-        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+        image: PipelineImageInput,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        resize_mode: str = "default",  # "default", "fill", "crop"
+        crops_coords: Optional[Tuple[int, int, int, int]] = None,
     ) -> torch.Tensor:
         """
-        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
+        Preprocess the image input.
+
+        Args:
+            image (`PipelineImageInput`):
+                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; lists of supported formats are also accepted.
+            height (`int`, *optional*, defaults to `None`):
+                The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default height.
+            width (`int`, *optional*, defaults to `None`):
+                The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default width.
+            resize_mode (`str`, *optional*, defaults to `default`):
+                The resize mode, can be one of `default`, `fill` or `crop`. If `default`, will resize the image to fit
+                within the specified width and height, and it may not maintain the original aspect ratio.
+                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
+                within the dimensions, filling the empty space with data from the image.
+                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
+                within the dimensions, cropping the excess.
+                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+            crops_coords (`Tuple[int, int, int, int]`, *optional*, defaults to `None`):
+                The crop coordinates applied to each image in the batch. If `None`, the images will not be cropped.
         """
         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

@@ -299,13 +416,15 @@ def preprocess(
)

         if isinstance(image[0], PIL.Image.Image):
+            if crops_coords is not None:
+                image = [i.crop(crops_coords) for i in image]
+            if self.config.do_resize:
+                height, width = self.get_default_height_width(image[0], height, width)
+                image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
             if self.config.do_convert_rgb:
                 image = [self.convert_to_rgb(i) for i in image]
             elif self.config.do_convert_grayscale:
                 image = [self.convert_to_grayscale(i) for i in image]
-            if self.config.do_resize:
-                height, width = self.get_default_height_width(image[0], height, width)
-                image = [self.resize(i, height, width) for i in image]
Comment on lines -306 to -308
Contributor:

Can you explain that change a bit? Why should it be done before do_convert_rgb and do_convert_grayscale? Or is the order irrelevant here, as both will give the same results?

Collaborator (Author):

I did not know this before, but resize with resample = PIL.Image.LANCZOS will turn a grayscale image into RGB:

    mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]

Moving the resize before converting ensures the processed image is the desired type. As a reference, this was actually the order in the deprecated prepare_mask_and_masked_image:

    if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
        # resize all images w.r.t passed height and width
        image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
        image = [np.array(i.convert("RGB"))[None, :] for i in image]
        image = np.concatenate(image, axis=0)

         image = self.pil_to_numpy(image)  # to np
         image = self.numpy_to_pt(image)  # to pt
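
Finally, a usage sketch tying the new preprocess arguments together; the file name and crop box are illustrative:

    from PIL import Image
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor()
    # Crop the same box out of the input, then resize with the "fill" strategy;
    # the result is a [batch, channel, height, width] tensor (normalized to
    # [-1, 1] under the default config).
    batch = processor.preprocess(
        Image.open("init.png"),
        height=512,
        width=512,
        resize_mode="fill",
        crops_coords=(64, 64, 448, 448),
    )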
