diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index 7be26cbbb5f1..e94c251e21d9 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -18,7 +18,7 @@
 import numpy as np
 import PIL.Image
 import torch
-from PIL import Image, ImageFilter
+from PIL import Image, ImageFilter, ImageOps
 
 from .configuration_utils import ConfigMixin, register_to_config
 from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
@@ -613,6 +613,39 @@ def postprocess(
         if output_type == "pil":
             return self.numpy_to_pil(image)
 
+    def apply_overlay(
+        self,
+        mask: PIL.Image.Image,
+        init_image: PIL.Image.Image,
+        image: PIL.Image.Image,
+        crop_coords: Optional[Tuple[int, int, int, int]] = None,
+    ) -> PIL.Image.Image:
+        """
+        overlay the inpaint output to the original image
+        """
+
+        width, height = image.width, image.height
+
+        init_image = self.resize(init_image, width=width, height=height)
+        mask = self.resize(mask, width=width, height=height)
+
+        init_image_masked = PIL.Image.new("RGBa", (width, height))
+        init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
+        init_image_masked = init_image_masked.convert("RGBA")
+
+        if crop_coords is not None:
+            x, y, w, h = crop_coords
+            base_image = PIL.Image.new("RGBA", (width, height))
+            image = self.resize(image, height=h, width=w, resize_mode="crop")
+            base_image.paste(image, (x, y))
+            image = base_image.convert("RGB")
+
+        image = image.convert("RGBA")
+        image.alpha_composite(init_image_masked)
+        image = image.convert("RGB")
+
+        return image
+
 
 class VaeImageProcessorLDM3D(VaeImageProcessor):
     """
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 5f62bd135f7b..6940842f5d65 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -13,13 +13,12 @@
 # limitations under the License.
 
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import PIL.Image
 import torch
 from packaging import version
-from PIL import ImageOps
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...configuration_utils import FrozenDict
@@ -880,39 +879,6 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32
         assert emb.shape == (w.shape[0], embedding_dim)
         return emb
 
-    def apply_overlay(
-        self,
-        mask: PIL.Image.Image,
-        init_image: PIL.Image.Image,
-        image: PIL.Image.Image,
-        crop_coords: Optional[Tuple[int, int, int, int]] = None,
-    ) -> PIL.Image.Image:
-        """
-        overlay the inpaint output to the original image
-        """
-
-        width, height = image.width, image.height
-
-        init_image = self.image_processor.resize(init_image, width=width, height=height)
-        mask = self.mask_processor.resize(mask, width=width, height=height)
-
-        init_image_masked = PIL.Image.new("RGBa", (width, height))
-        init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
-        init_image_masked = init_image_masked.convert("RGBA")
-
-        if crop_coords is not None:
-            x, y, w, h = crop_coords
-            base_image = PIL.Image.new("RGBA", (width, height))
-            image = self.image_processor.resize(image, height=h, width=w, resize_mode="crop")
-            base_image.paste(image, (x, y))
-            image = base_image.convert("RGB")
-
-        image = image.convert("RGBA")
-        image.alpha_composite(init_image_masked)
-        image = image.convert("RGB")
-
-        return image
-
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -1370,7 +1336,7 @@ def __call__(
         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
 
         if padding_mask_crop is not None:
-            image = [self.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
 
         # Offload all models
         self.maybe_free_model_hooks()
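For reference, here is a minimal usage sketch (not part of the diff) of the relocated `apply_overlay`, called on a standalone `VaeImageProcessor`; the file names are placeholders, and the inpaint pipeline reaches the same code path through `self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords)` when `padding_mask_crop` is set.

```python
# Sketch only: composite an inpainting result back onto the untouched original.
# "original.png", "mask.png" and "inpainted.png" are placeholder file names.
import PIL.Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()

init_image = PIL.Image.open("original.png").convert("RGB")  # full original image
mask = PIL.Image.open("mask.png").convert("L")              # white = region that was inpainted
inpainted = PIL.Image.open("inpainted.png").convert("RGB")  # pipeline output

# Unmasked pixels come from init_image (pasted with the inverted mask as alpha),
# masked pixels keep the inpainted content; crop_coords is only needed when the
# model was run on a crop of the original (padding_mask_crop).
result = processor.apply_overlay(mask, init_image, inpainted)
result.save("composited.png")
```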