Skip to content

Commit

Permalink
add single file loading for upscaling
Browse files Browse the repository at this point in the history
  • Loading branch information
DN6 committed Oct 4, 2023
1 parent 302b12d commit 61c8729
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 4 deletions.
53 changes: 49 additions & 4 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,6 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
class_embed_type = "projection"
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
else:
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")

config = {
"sample_size": image_size // vae_scale_factor,
Expand All @@ -323,6 +321,12 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
"transformer_layers_per_block": transformer_layers_per_block,
}

if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions

if "num_classes" in unet_params and type(unet_params.num_classes) == int:
config["num_class_embeds"] = unet_params.num_classes

if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
else:
Expand Down Expand Up @@ -441,6 +445,10 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]

# Relevant to StableDiffusionUpscalePipeline
if "num_class_embeds" in config:
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]

new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

Expand Down Expand Up @@ -496,6 +504,7 @@ def convert_ldm_unet_checkpoint(

if len(attentions):
paths = renew_attention_paths(attentions)

meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
Expand Down Expand Up @@ -1210,6 +1219,7 @@ def download_from_original_stable_diffusion_ckpt(
StableDiffusionControlNetPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLImg2ImgPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
Expand Down Expand Up @@ -1256,6 +1266,8 @@ def download_from_original_stable_diffusion_ckpt(
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
is_upscale = pipeline_class == StableDiffusionUpscalePipeline

config_url = None

# model_type = "v1"
Expand Down Expand Up @@ -1285,6 +1297,10 @@ def download_from_original_stable_diffusion_ckpt(
original_config_file = config_files["xl_refiner"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"

if is_upscale:
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"

if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)

Expand All @@ -1308,6 +1324,8 @@ def download_from_original_stable_diffusion_ckpt(

if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
num_in_channels = 9
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
num_in_channels = 7
elif num_in_channels is None:
num_in_channels = 4

Expand Down Expand Up @@ -1391,9 +1409,13 @@ def download_from_original_stable_diffusion_ckpt(
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

if pipeline_class == StableDiffusionUpscalePipeline:
image_size = original_config.model.params.unet_config.params.image_size

# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["upcast_attention"] = upcast_attention

path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=path, extract_ema=extract_ema
Expand Down Expand Up @@ -1458,8 +1480,29 @@ def download_from_original_stable_diffusion_ckpt(
controlnet=controlnet,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False

elif pipeline_class == StableDiffusionUpscalePipeline:
scheduler = DDIMScheduler.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler"
)
low_res_scheduler = DDPMScheduler.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
)

pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
low_res_scheduler=low_res_scheduler,
safety_checker=None,
feature_extractor=None,
)

else:
pipe = pipeline_class(
vae=vae,
Expand All @@ -1469,8 +1512,10 @@ def download_from_original_stable_diffusion_ckpt(
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False

else:
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
original_config, clip_stats_path=clip_stats_path, device=device
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
floats_tensor,
load_image,
load_numpy,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
Expand Down Expand Up @@ -479,3 +480,36 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.9 GB is allocated
assert mem_bytes < 2.9 * 10**9

def test_download_ckpt_diff_format_is_same(self):
    """Compare the x4 upscaler loaded via `from_pretrained` against `from_single_file`.

    Both pipelines are run on the same low-res image and prompt with a CPU generator
    seeded to 0 and 3 inference steps. The resulting images must have shape
    (512, 512, 3) and a cosine similarity distance below 1e-3.
    """
    low_res_image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        "/sd2-upscale/low_res_cat.png"
    )
    text_prompt = "a cat sitting on a park bench"

    def run_pipeline(pipeline):
        # Same offloading, seed and call arguments for each pipeline so that the
        # two outputs can be compared directly.
        pipeline.enable_model_cpu_offload()
        seeded_generator = torch.Generator("cpu").manual_seed(0)
        result = pipeline(
            prompt=text_prompt,
            image=low_res_image,
            generator=seeded_generator,
            output_type="np",
            num_inference_steps=3,
        )
        return result.images[0]

    # Reference: multi-folder checkpoint resolved from the model id.
    pipe = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler")
    image_from_pretrained = run_pipeline(pipe)

    # Candidate: the same model loaded from a single checkpoint file.
    pipe_from_single_file = StableDiffusionUpscalePipeline.from_single_file(
        "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
    )
    image_from_single_file = run_pipeline(pipe_from_single_file)

    expected_shape = (512, 512, 3)
    assert image_from_pretrained.shape == expected_shape
    assert image_from_single_file.shape == expected_shape

    distance = numpy_cosine_similarity_distance(
        image_from_pretrained.flatten(), image_from_single_file.flatten()
    )
    assert distance < 1e-3

0 comments on commit 61c8729

Please sign in to comment.