Skip to content

Commit

Permalink
move apply_overlay to image processor
Browse files Browse the repository at this point in the history
  • Loading branch information
yiyixuxu committed Dec 26, 2023
1 parent 615e331 commit 294a05a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 37 deletions.
35 changes: 34 additions & 1 deletion src/diffusers/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import PIL.Image
import torch
from PIL import Image, ImageFilter
from PIL import Image, ImageFilter, ImageOps

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
Expand Down Expand Up @@ -613,6 +613,39 @@ def postprocess(
if output_type == "pil":
return self.numpy_to_pil(image)

def apply_overlay(
self,
mask: PIL.Image.Image,
init_image: PIL.Image.Image,
image: PIL.Image.Image,
crop_coords: Optional[Tuple[int, int, int, int]] = None,
) -> PIL.Image.Image:
"""
overlay the inpaint output to the original image
"""

width, height = image.width, image.height

init_image = self.resize(init_image, width=width, height=height)
mask = self.resize(mask, width=width, height=height)

init_image_masked = PIL.Image.new("RGBa", (width, height))
init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
init_image_masked = init_image_masked.convert("RGBA")

if crop_coords is not None:
x, y, w, h = crop_coords
base_image = PIL.Image.new("RGBA", (width, height))
image = self.resize(image, height=h, width=w, resize_mode="crop")
base_image.paste(image, (x, y))
image = base_image.convert("RGB")

image = image.convert("RGBA")
image.alpha_composite(init_image_masked)
image = image.convert("RGB")

return image


class VaeImageProcessorLDM3D(VaeImageProcessor):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from packaging import version
from PIL import ImageOps
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...configuration_utils import FrozenDict
Expand Down Expand Up @@ -880,39 +879,6 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32
assert emb.shape == (w.shape[0], embedding_dim)
return emb

def apply_overlay(
self,
mask: PIL.Image.Image,
init_image: PIL.Image.Image,
image: PIL.Image.Image,
crop_coords: Optional[Tuple[int, int, int, int]] = None,
) -> PIL.Image.Image:
"""
overlay the inpaint output to the original image
"""

width, height = image.width, image.height

init_image = self.image_processor.resize(init_image, width=width, height=height)
mask = self.mask_processor.resize(mask, width=width, height=height)

init_image_masked = PIL.Image.new("RGBa", (width, height))
init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
init_image_masked = init_image_masked.convert("RGBA")

if crop_coords is not None:
x, y, w, h = crop_coords
base_image = PIL.Image.new("RGBA", (width, height))
image = self.image_processor.resize(image, height=h, width=w, resize_mode="crop")
base_image.paste(image, (x, y))
image = base_image.convert("RGB")

image = image.convert("RGBA")
image.alpha_composite(init_image_masked)
image = image.convert("RGB")

return image

@property
def guidance_scale(self):
return self._guidance_scale
Expand Down Expand Up @@ -1370,7 +1336,7 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

if padding_mask_crop is not None:
image = [self.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]

# Offload all models
self.maybe_free_model_hooks()
Expand Down

0 comments on commit 294a05a

Please sign in to comment.