Skip to content

Commit

Permalink
Add from single file to StableDiffusionUpscalePipeline and StableDiff…
Browse files Browse the repository at this point in the history
…usionLatentUpscalePipeline (huggingface#5194)

* add from single file

* clean up

* make style

* add single file loading for upscaling
  • Loading branch information
DN6 authored Oct 6, 2023
1 parent dddc62f commit b495a24
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 7 deletions.
53 changes: 49 additions & 4 deletions pipelines/stable_diffusion/convert_from_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,6 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
class_embed_type = "projection"
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
else:
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")

config = {
"sample_size": image_size // vae_scale_factor,
Expand All @@ -323,6 +321,12 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
"transformer_layers_per_block": transformer_layers_per_block,
}

if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions

if "num_classes" in unet_params and type(unet_params.num_classes) == int:
config["num_class_embeds"] = unet_params.num_classes

if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
else:
Expand Down Expand Up @@ -441,6 +445,10 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]

# Relevant to StableDiffusionUpscalePipeline
if "num_class_embeds" in config:
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]

new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

Expand Down Expand Up @@ -496,6 +504,7 @@ def convert_ldm_unet_checkpoint(

if len(attentions):
paths = renew_attention_paths(attentions)

meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
Expand Down Expand Up @@ -1210,6 +1219,7 @@ def download_from_original_stable_diffusion_ckpt(
StableDiffusionControlNetPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLImg2ImgPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
Expand Down Expand Up @@ -1256,6 +1266,8 @@ def download_from_original_stable_diffusion_ckpt(
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
is_upscale = pipeline_class == StableDiffusionUpscalePipeline

config_url = None

# model_type = "v1"
Expand Down Expand Up @@ -1285,6 +1297,10 @@ def download_from_original_stable_diffusion_ckpt(
original_config_file = config_files["xl_refiner"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"

if is_upscale:
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"

if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)

Expand All @@ -1308,6 +1324,8 @@ def download_from_original_stable_diffusion_ckpt(

if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
num_in_channels = 9
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
num_in_channels = 7
elif num_in_channels is None:
num_in_channels = 4

Expand Down Expand Up @@ -1391,9 +1409,13 @@ def download_from_original_stable_diffusion_ckpt(
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

if pipeline_class == StableDiffusionUpscalePipeline:
image_size = original_config.model.params.unet_config.params.image_size

# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["upcast_attention"] = upcast_attention

path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=path, extract_ema=extract_ema
Expand Down Expand Up @@ -1458,8 +1480,29 @@ def download_from_original_stable_diffusion_ckpt(
controlnet=controlnet,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False

elif pipeline_class == StableDiffusionUpscalePipeline:
scheduler = DDIMScheduler.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler"
)
low_res_scheduler = DDPMScheduler.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
)

pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
low_res_scheduler=low_res_scheduler,
safety_checker=None,
feature_extractor=None,
)

else:
pipe = pipeline_class(
vae=vae,
Expand All @@ -1469,8 +1512,10 @@ def download_from_original_stable_diffusion_ckpt(
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False

else:
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
original_config, clip_stats_path=clip_stats_path, device=device
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from transformers import CLIPTextModel, CLIPTokenizer

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import deprecate, logging
Expand Down Expand Up @@ -59,7 +60,7 @@ def preprocess(image):
return image


class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, FromSingleFileMixin):
r"""
Pipeline for upscaling Stable Diffusion output image resolution by a factor of 2.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
Expand Down Expand Up @@ -67,7 +67,9 @@ def preprocess(image):
return image


class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
class StableDiffusionUpscalePipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-guided image super-resolution using Stable Diffusion 2.
Expand Down

0 comments on commit b495a24

Please sign in to comment.