diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 7309c3fc709c..7dec43571b1c 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from collections import OrderedDict, defaultdict
+from collections import defaultdict
 from contextlib import nullcontext
 from typing import Callable, Dict, List, Optional, Union
 
@@ -664,6 +664,80 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             if hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)
 
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
+        updated_state_dict = {}
+        image_projection = None
+
+        if "proj.weight" in state_dict:
+            # IP-Adapter
+            num_image_text_embeds = 4
+            clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+            cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
+
+            image_projection = ImageProjection(
+                cross_attention_dim=cross_attention_dim,
+                image_embed_dim=clip_embeddings_dim,
+                num_image_text_embeds=num_image_text_embeds,
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj", "image_embeds")
+                updated_state_dict[diffusers_name] = value
+
+        elif "proj.3.weight" in state_dict:
+            # IP-Adapter Full
+            clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
+            cross_attention_dim = state_dict["proj.3.weight"].shape[0]
+
+            image_projection = MLPProjection(
+                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj.0", "ff.net.0.proj")
+                diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
+                diffusers_name = diffusers_name.replace("proj.3", "norm")
+                updated_state_dict[diffusers_name] = value
+
+        else:
+            # IP-Adapter Plus
+            num_image_text_embeds = state_dict["latents"].shape[1]
+            embed_dims = state_dict["proj_in.weight"].shape[1]
+            output_dims = state_dict["proj_out.weight"].shape[0]
+            hidden_dims = state_dict["latents"].shape[2]
+            heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
+
+            image_projection = Resampler(
+                embed_dims=embed_dims,
+                output_dims=output_dims,
+                hidden_dims=hidden_dims,
+                heads=heads,
+                num_queries=num_image_text_embeds,
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("0.to", "2.to")
+                diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
+                diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
+                diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
+
+                if "norm1" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
+                elif "norm2" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
+                elif "to_kv" in diffusers_name:
+                    v_chunk = value.chunk(2, dim=0)
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+                elif "to_out" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
+                else:
+                    updated_state_dict[diffusers_name] = value
+
+        image_projection.load_state_dict(updated_state_dict)
+        return image_projection
+
     def _load_ip_adapter_weights(self, state_dict):
         from ..models.attention_processor import (
             AttnProcessor,
@@ -724,103 +798,8 @@ def _load_ip_adapter_weights(self, state_dict):
 
         self.set_attn_processor(attn_procs)
 
-        # create image projection layers.
-        if "proj.weight" in state_dict["image_proj"]:
-            # IP-Adapter
-            clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
-            cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
-
-            image_projection = ImageProjection(
-                cross_attention_dim=cross_attention_dim,
-                image_embed_dim=clip_embeddings_dim,
-                num_image_text_embeds=num_image_text_embeds,
-            )
-            image_projection.to(dtype=self.dtype, device=self.device)
-
-            # load image projection layer weights
-            image_proj_state_dict = {}
-            image_proj_state_dict.update(
-                {
-                    "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
-                    "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
-                    "norm.weight": state_dict["image_proj"]["norm.weight"],
-                    "norm.bias": state_dict["image_proj"]["norm.bias"],
-                }
-            )
-            image_projection.load_state_dict(image_proj_state_dict)
-            del image_proj_state_dict
-
-        elif "proj.3.weight" in state_dict["image_proj"]:
-            clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0]
-            cross_attention_dim = state_dict["image_proj"]["proj.3.weight"].shape[0]
-
-            image_projection = MLPProjection(
-                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
-            )
-            image_projection.to(dtype=self.dtype, device=self.device)
-
-            # load image projection layer weights
-            image_proj_state_dict = {}
-            image_proj_state_dict.update(
-                {
-                    "ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
-                    "ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
-                    "ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
-                    "ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
-                    "norm.weight": state_dict["image_proj"]["proj.3.weight"],
-                    "norm.bias": state_dict["image_proj"]["proj.3.bias"],
-                }
-            )
-            image_projection.load_state_dict(image_proj_state_dict)
-            del image_proj_state_dict
-
-        else:
-            # IP-Adapter Plus
-            embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
-            output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
-            hidden_dims = state_dict["image_proj"]["latents"].shape[2]
-            heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
-
-            image_projection = Resampler(
-                embed_dims=embed_dims,
-                output_dims=output_dims,
-                hidden_dims=hidden_dims,
-                heads=heads,
-                num_queries=num_image_text_embeds,
-            )
-
-            image_proj_state_dict = state_dict["image_proj"]
-
-            new_sd = OrderedDict()
-            for k, v in image_proj_state_dict.items():
-                if "0.to" in k:
-                    k = k.replace("0.to", "2.to")
-                elif "1.0.weight" in k:
-                    k = k.replace("1.0.weight", "3.0.weight")
-                elif "1.0.bias" in k:
-                    k = k.replace("1.0.bias", "3.0.bias")
-                elif "1.1.weight" in k:
-                    k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
-                elif "1.3.weight" in k:
-                    k = k.replace("1.3.weight", "3.1.net.2.weight")
-
-                if "norm1" in k:
-                    new_sd[k.replace("0.norm1", "0")] = v
-                elif "norm2" in k:
-                    new_sd[k.replace("0.norm2", "1")] = v
-                elif "to_kv" in k:
-                    v_chunk = v.chunk(2, dim=0)
-                    new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
-                    new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
-                elif "to_out" in k:
-                    new_sd[k.replace("to_out", "to_out.0")] = v
-                else:
-                    new_sd[k] = v
-
-            image_projection.load_state_dict(new_sd)
-            del image_proj_state_dict
+        # convert IP-Adapter Image Projection layers to diffusers
+        image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
 
         self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
         self.config.encoder_hid_dim_type = "ip_image_proj"
-
-        delete_adapter_layers
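Note (not part of the patch): a minimal sketch of what the extracted helper does for a plain IP-Adapter `image_proj` state dict. The `unet` handle and the tensor shapes below are illustrative assumptions; diffusers' `ImageProjection` stores the linear layer under `image_embeds` and keeps the `norm` key unchanged, which is exactly what the `key.replace("proj", "image_embeds")` renaming produces.

    import torch

    # Hypothetical checkpoint slice: `proj` maps 1024-dim CLIP image embeds to
    # 4 tokens of 768-dim cross-attention context, hence 4 * 768 output rows.
    image_proj_sd = {
        "proj.weight": torch.randn(4 * 768, 1024),
        "proj.bias": torch.randn(4 * 768),
        "norm.weight": torch.randn(768),
        "norm.bias": torch.randn(768),
    }

    # `unet` is assumed to be a UNet2DConditionModel, which mixes in this loader.
    image_projection = unet._convert_ip_adapter_image_proj_to_diffusers(image_proj_sd)
    # -> ImageProjection(cross_attention_dim=768, image_embed_dim=1024,
    #    num_image_text_embeds=4), with the renamed weights already loaded.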
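One conversion detail worth calling out in the IP-Adapter Plus branch: the original Resampler checkpoints store a fused key/value projection under `to_kv`, while diffusers attention modules expect separate `to_k`/`to_v` weights, so the fused matrix is split in half along the output dimension. A standalone illustration, with shapes assumed for the example (heads=12, dim_head=64 as implied by the `// 64` above, hidden_dims=1280):

    import torch

    to_kv_weight = torch.randn(2 * 12 * 64, 1280)  # fused [K; V] projection
    to_k_weight, to_v_weight = to_kv_weight.chunk(2, dim=0)
    assert to_k_weight.shape == to_v_weight.shape == (12 * 64, 1280)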