Skip to content

Commit

Permalink
Fix SDXL Inpainting from single file with Refiner Model (huggingface#6147)
Browse files Browse the repository at this point in the history

* update

* update

* update
  • Loading branch information
DN6 authored and Jimmy committed Apr 26, 2024
1 parent 8a7d5b4 commit d63aa28
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 71 deletions.
4 changes: 4 additions & 0 deletions src/diffusers/loaders/single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,12 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
load_safety_checker = kwargs.pop("load_safety_checker", True)
prediction_type = kwargs.pop("prediction_type", None)
text_encoder = kwargs.pop("text_encoder", None)
text_encoder_2 = kwargs.pop("text_encoder_2", None)
vae = kwargs.pop("vae", None)
controlnet = kwargs.pop("controlnet", None)
adapter = kwargs.pop("adapter", None)
tokenizer = kwargs.pop("tokenizer", None)
tokenizer_2 = kwargs.pop("tokenizer_2", None)

torch_dtype = kwargs.pop("torch_dtype", None)

Expand Down Expand Up @@ -274,8 +276,10 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
load_safety_checker=load_safety_checker,
prediction_type=prediction_type,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
vae=vae,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
original_config_file=original_config_file,
config_files=config_files,
local_files_only=local_files_only,
Expand Down
130 changes: 59 additions & 71 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,9 @@ def download_from_original_stable_diffusion_ckpt(
vae_path=None,
vae=None,
text_encoder=None,
text_encoder_2=None,
tokenizer=None,
tokenizer_2=None,
config_files=None,
) -> DiffusionPipeline:
"""
Expand Down Expand Up @@ -1232,7 +1234,9 @@ def download_from_original_stable_diffusion_ckpt(
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
Expand Down Expand Up @@ -1339,7 +1343,11 @@ def download_from_original_stable_diffusion_ckpt(
else:
pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline

if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
if num_in_channels is None and pipeline_class in [
StableDiffusionInpaintPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLControlNetInpaintPipeline,
]:
num_in_channels = 9
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
num_in_channels = 7
Expand Down Expand Up @@ -1686,7 +1694,9 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor=feature_extractor,
)
elif model_type in ["SDXL", "SDXL-Refiner"]:
if model_type == "SDXL":
is_refiner = model_type == "SDXL-Refiner"

if (is_refiner is False) and (tokenizer is None):
try:
tokenizer = CLIPTokenizer.from_pretrained(
"openai/clip-vit-large-patch14", local_files_only=local_files_only
Expand All @@ -1695,7 +1705,11 @@ def download_from_original_stable_diffusion_ckpt(
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
)

if (is_refiner is False) and (text_encoder is None):
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)

if tokenizer_2 is None:
try:
tokenizer_2 = CLIPTokenizer.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
Expand All @@ -1705,95 +1719,69 @@ def download_from_original_stable_diffusion_ckpt(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
)

if text_encoder_2 is None:
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config_kwargs = {"projection_dim": 1280}
text_encoder_2 = convert_open_clip_checkpoint(
checkpoint,
config_name,
prefix="conditioner.embedders.1.model.",
has_projection=True,
local_files_only=local_files_only,
**config_kwargs,
)

if is_accelerate_available(): # SBM Now move model to cpu.
if model_type in ["SDXL", "SDXL-Refiner"]:
for param_name, param in converted_unet_checkpoint.items():
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model."

if controlnet:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
elif adapter:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
adapter=adapter,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
else:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
else:
tokenizer = None
text_encoder = None
try:
tokenizer_2 = CLIPTokenizer.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
)
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config_kwargs = {"projection_dim": 1280}
text_encoder_2 = convert_open_clip_checkpoint(
checkpoint,
config_name,
prefix="conditioner.embedders.0.model.",
prefix=prefix,
has_projection=True,
local_files_only=local_files_only,
**config_kwargs,
)

if is_accelerate_available(): # SBM Now move model to cpu.
if model_type in ["SDXL", "SDXL-Refiner"]:
for param_name, param in converted_unet_checkpoint.items():
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
if is_accelerate_available(): # SBM Now move model to cpu.
for param_name, param in converted_unet_checkpoint.items():
set_module_tensor_to_device(unet, param_name, "cpu", value=param)

pipe = StableDiffusionXLImg2ImgPipeline(
if controlnet:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
elif adapter:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
adapter=adapter,
scheduler=scheduler,
requires_aesthetics_score=True,
force_zeros_for_empty_prompt=False,
force_zeros_for_empty_prompt=True,
)

else:
pipeline_kwargs = {
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"unet": unet,
"scheduler": scheduler,
}

if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or (
pipeline_class == StableDiffusionXLInpaintPipeline
):
pipeline_kwargs.update({"requires_aesthetics_score": is_refiner})

if is_refiner:
pipeline_kwargs.update({"force_zeros_for_empty_prompt": False})

pipe = pipeline_class(**pipeline_kwargs)
else:
text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
Expand Down

0 comments on commit d63aa28

Please sign in to comment.