From 7660aee30f8a55d2481c64936218f2cd5beb2c05 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Tue, 12 Dec 2023 18:03:42 +0100 Subject: [PATCH 1/3] Add converter method for ip adapters --- .../loaders/ip_adapter_conversion_utils.py | 52 +++++++++++++++ src/diffusers/loaders/unet.py | 66 +++---------------- 2 files changed, 62 insertions(+), 56 deletions(-) create mode 100644 src/diffusers/loaders/ip_adapter_conversion_utils.py diff --git a/src/diffusers/loaders/ip_adapter_conversion_utils.py b/src/diffusers/loaders/ip_adapter_conversion_utils.py new file mode 100644 index 000000000000..d6944a7b088c --- /dev/null +++ b/src/diffusers/loaders/ip_adapter_conversion_utils.py @@ -0,0 +1,52 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def _convert_ip_adapter_to_diffusers(state_dict): + updated_state_dict = {} + + if "proj.weight" in state_dict: + for key, value in state_dict.items(): + diffusers_name = key.replace("proj", "image_embeds") + updated_state_dict[diffusers_name] = value + + elif "proj.3.weight" in state_dict: + for key, value in state_dict.items(): + diffusers_name = key.replace("proj.0", "ff.net.0.proj") + diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") + diffusers_name = diffusers_name.replace("proj.3", "norm") + updated_state_dict[diffusers_name] = value + + else: + for key, value in state_dict.items(): + diffusers_name = key.replace("0.to", "2.to") + diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight") + diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias") + diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight") + + if "norm1" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value + elif "norm2" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value + elif "to_kv" in diffusers_name: + v_chunk = value.chunk(2, dim=0) + updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] + updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] + elif "to_out" in diffusers_name: + updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value + else: + updated_state_dict[diffusers_name] = value + + return updated_state_dict diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 7309c3fc709c..1717d1ca6e98 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from collections import OrderedDict, defaultdict +from collections import defaultdict from contextlib import nullcontext from typing import Callable, Dict, List, Optional, Union @@ -33,6 +33,7 @@ set_adapter_layers, set_weights_and_activate_adapters, ) +from .ip_adapter_conversion_utils import _convert_ip_adapter_to_diffusers from .utils import AttnProcsLayers @@ -725,6 +726,8 @@ def _load_ip_adapter_weights(self, state_dict): self.set_attn_processor(attn_procs) # create image projection layers. + image_proj_state_dict = state_dict["image_proj"] + if "proj.weight" in state_dict["image_proj"]: # IP-Adapter clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1] @@ -737,18 +740,8 @@ def _load_ip_adapter_weights(self, state_dict): ) image_projection.to(dtype=self.dtype, device=self.device) - # load image projection layer weights - image_proj_state_dict = {} - image_proj_state_dict.update( - { - "image_embeds.weight": state_dict["image_proj"]["proj.weight"], - "image_embeds.bias": state_dict["image_proj"]["proj.bias"], - "norm.weight": state_dict["image_proj"]["norm.weight"], - "norm.bias": state_dict["image_proj"]["norm.bias"], - } - ) - image_projection.load_state_dict(image_proj_state_dict) - del image_proj_state_dict + new_sd = _convert_ip_adapter_to_diffusers(image_proj_state_dict) + image_projection.load_state_dict(new_sd) elif "proj.3.weight" in state_dict["image_proj"]: clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0] @@ -759,20 +752,8 @@ def _load_ip_adapter_weights(self, state_dict): ) image_projection.to(dtype=self.dtype, device=self.device) - # load image projection layer weights - image_proj_state_dict = {} - image_proj_state_dict.update( - { - "ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"], - "ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"], - "ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"], - "ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"], - "norm.weight": state_dict["image_proj"]["proj.3.weight"], - "norm.bias": state_dict["image_proj"]["proj.3.bias"], - } - ) - image_projection.load_state_dict(image_proj_state_dict) - del image_proj_state_dict + new_sd = _convert_ip_adapter_to_diffusers(image_proj_state_dict) + image_projection.load_state_dict(new_sd) else: # IP-Adapter Plus @@ -789,36 +770,9 @@ def _load_ip_adapter_weights(self, state_dict): num_queries=num_image_text_embeds, ) - image_proj_state_dict = state_dict["image_proj"] - - new_sd = OrderedDict() - for k, v in image_proj_state_dict.items(): - if "0.to" in k: - k = k.replace("0.to", "2.to") - elif "1.0.weight" in k: - k = k.replace("1.0.weight", "3.0.weight") - elif "1.0.bias" in k: - k = k.replace("1.0.bias", "3.0.bias") - elif "1.1.weight" in k: - k = k.replace("1.1.weight", "3.1.net.0.proj.weight") - elif "1.3.weight" in k: - k = k.replace("1.3.weight", "3.1.net.2.weight") - - if "norm1" in k: - new_sd[k.replace("0.norm1", "0")] = v - elif "norm2" in k: - new_sd[k.replace("0.norm2", "1")] = v - elif "to_kv" in k: - v_chunk = v.chunk(2, dim=0) - new_sd[k.replace("to_kv", "to_k")] = v_chunk[0] - new_sd[k.replace("to_kv", "to_v")] = v_chunk[1] - elif "to_out" in k: - new_sd[k.replace("to_out", "to_out.0")] = v - else: - new_sd[k] = v - + new_sd = _convert_ip_adapter_to_diffusers(image_proj_state_dict) image_projection.load_state_dict(new_sd) - del image_proj_state_dict + del image_proj_state_dict self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) self.config.encoder_hid_dim_type = "ip_image_proj" From dfda680fea1fc186f79f6b2ec4b1fa559e37ae52 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 15 Dec 2023 09:16:24 +0100 Subject: [PATCH 2/3] Move converter method --- .../loaders/ip_adapter_conversion_utils.py | 52 ------------------- src/diffusers/loaders/unet.py | 45 ++++++++++++++-- 2 files changed, 41 insertions(+), 56 deletions(-) delete mode 100644 src/diffusers/loaders/ip_adapter_conversion_utils.py diff --git a/src/diffusers/loaders/ip_adapter_conversion_utils.py b/src/diffusers/loaders/ip_adapter_conversion_utils.py deleted file mode 100644 index d6944a7b088c..000000000000 --- a/src/diffusers/loaders/ip_adapter_conversion_utils.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def _convert_ip_adapter_to_diffusers(state_dict): - updated_state_dict = {} - - if "proj.weight" in state_dict: - for key, value in state_dict.items(): - diffusers_name = key.replace("proj", "image_embeds") - updated_state_dict[diffusers_name] = value - - elif "proj.3.weight" in state_dict: - for key, value in state_dict.items(): - diffusers_name = key.replace("proj.0", "ff.net.0.proj") - diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") - diffusers_name = diffusers_name.replace("proj.3", "norm") - updated_state_dict[diffusers_name] = value - - else: - for key, value in state_dict.items(): - diffusers_name = key.replace("0.to", "2.to") - diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight") - diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias") - diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight") - diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight") - - if "norm1" in diffusers_name: - updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value - elif "norm2" in diffusers_name: - updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value - elif "to_kv" in diffusers_name: - v_chunk = value.chunk(2, dim=0) - updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] - updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] - elif "to_out" in diffusers_name: - updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value - else: - updated_state_dict[diffusers_name] = value - - return updated_state_dict diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 1717d1ca6e98..d9b557ac2970 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -33,7 +33,6 @@ set_adapter_layers, set_weights_and_activate_adapters, ) -from .ip_adapter_conversion_utils import _convert_ip_adapter_to_diffusers from .utils import AttnProcsLayers @@ -665,6 +664,44 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): if hasattr(self, "peft_config"): self.peft_config.pop(adapter_name, None) + def _convert_ip_adapter_to_diffusers(self, state_dict): + updated_state_dict = {} + + if "proj.weight" in state_dict: + for key, value in state_dict.items(): + diffusers_name = key.replace("proj", "image_embeds") + updated_state_dict[diffusers_name] = value + + elif "proj.3.weight" in state_dict: + for key, value in state_dict.items(): + diffusers_name = key.replace("proj.0", "ff.net.0.proj") + diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") + diffusers_name = diffusers_name.replace("proj.3", "norm") + updated_state_dict[diffusers_name] = value + + else: + for key, value in state_dict.items(): + diffusers_name = key.replace("0.to", "2.to") + diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight") + diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias") + diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight") + + if "norm1" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value + elif "norm2" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value + elif "to_kv" in diffusers_name: + v_chunk = value.chunk(2, dim=0) + updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] + updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] + elif "to_out" in diffusers_name: + updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value + else: + updated_state_dict[diffusers_name] = value + + return updated_state_dict + def _load_ip_adapter_weights(self, state_dict): from ..models.attention_processor import ( AttnProcessor, @@ -740,7 +777,7 @@ def _load_ip_adapter_weights(self, state_dict): ) image_projection.to(dtype=self.dtype, device=self.device) - new_sd = _convert_ip_adapter_to_diffusers(image_proj_state_dict) + new_sd = self._convert_ip_adapter_to_diffusers(image_proj_state_dict) image_projection.load_state_dict(new_sd) elif "proj.3.weight" in state_dict["image_proj"]: @@ -752,7 +789,7 @@ def _load_ip_adapter_weights(self, state_dict): ) image_projection.to(dtype=self.dtype, device=self.device) - new_sd = _convert_ip_adapter_to_diffusers(image_proj_state_dict) + new_sd = self._convert_ip_adapter_to_diffusers(image_proj_state_dict) image_projection.load_state_dict(new_sd) else: @@ -770,7 +807,7 @@ def _load_ip_adapter_weights(self, state_dict): num_queries=num_image_text_embeds, ) - new_sd = _convert_ip_adapter_to_diffusers(image_proj_state_dict) + new_sd = self._convert_ip_adapter_to_diffusers(image_proj_state_dict) image_projection.load_state_dict(new_sd) del image_proj_state_dict From 0b1bfc3072140aa75a61eaf281965ab5eb08686a Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sat, 16 Dec 2023 09:18:50 +0100 Subject: [PATCH 3/3] Update to image proj converter --- src/diffusers/loaders/unet.py | 92 +++++++++++++++-------------------- 1 file changed, 40 insertions(+), 52 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index d9b557ac2970..7dec43571b1c 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -664,15 +664,35 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): if hasattr(self, "peft_config"): self.peft_config.pop(adapter_name, None) - def _convert_ip_adapter_to_diffusers(self, state_dict): + def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict): updated_state_dict = {} + image_projection = None if "proj.weight" in state_dict: + # IP-Adapter + num_image_text_embeds = 4 + clip_embeddings_dim = state_dict["proj.weight"].shape[-1] + cross_attention_dim = state_dict["proj.weight"].shape[0] // 4 + + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) + for key, value in state_dict.items(): diffusers_name = key.replace("proj", "image_embeds") updated_state_dict[diffusers_name] = value elif "proj.3.weight" in state_dict: + # IP-Adapter Full + clip_embeddings_dim = state_dict["proj.0.weight"].shape[0] + cross_attention_dim = state_dict["proj.3.weight"].shape[0] + + image_projection = MLPProjection( + cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim + ) + for key, value in state_dict.items(): diffusers_name = key.replace("proj.0", "ff.net.0.proj") diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") @@ -680,6 +700,21 @@ def _convert_ip_adapter_to_diffusers(self, state_dict): updated_state_dict[diffusers_name] = value else: + # IP-Adapter Plus + num_image_text_embeds = state_dict["latents"].shape[1] + embed_dims = state_dict["proj_in.weight"].shape[1] + output_dims = state_dict["proj_out.weight"].shape[0] + hidden_dims = state_dict["latents"].shape[2] + heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64 + + image_projection = Resampler( + embed_dims=embed_dims, + output_dims=output_dims, + hidden_dims=hidden_dims, + heads=heads, + num_queries=num_image_text_embeds, + ) + for key, value in state_dict.items(): diffusers_name = key.replace("0.to", "2.to") diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight") @@ -700,7 +735,8 @@ def _convert_ip_adapter_to_diffusers(self, state_dict): else: updated_state_dict[diffusers_name] = value - return updated_state_dict + image_projection.load_state_dict(updated_state_dict) + return image_projection def _load_ip_adapter_weights(self, state_dict): from ..models.attention_processor import ( @@ -762,56 +798,8 @@ def _load_ip_adapter_weights(self, state_dict): self.set_attn_processor(attn_procs) - # create image projection layers. - image_proj_state_dict = state_dict["image_proj"] - - if "proj.weight" in state_dict["image_proj"]: - # IP-Adapter - clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1] - cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4 - - image_projection = ImageProjection( - cross_attention_dim=cross_attention_dim, - image_embed_dim=clip_embeddings_dim, - num_image_text_embeds=num_image_text_embeds, - ) - image_projection.to(dtype=self.dtype, device=self.device) - - new_sd = self._convert_ip_adapter_to_diffusers(image_proj_state_dict) - image_projection.load_state_dict(new_sd) - - elif "proj.3.weight" in state_dict["image_proj"]: - clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0] - cross_attention_dim = state_dict["image_proj"]["proj.3.weight"].shape[0] - - image_projection = MLPProjection( - cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim - ) - image_projection.to(dtype=self.dtype, device=self.device) - - new_sd = self._convert_ip_adapter_to_diffusers(image_proj_state_dict) - image_projection.load_state_dict(new_sd) - - else: - # IP-Adapter Plus - embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1] - output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0] - hidden_dims = state_dict["image_proj"]["latents"].shape[2] - heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 - - image_projection = Resampler( - embed_dims=embed_dims, - output_dims=output_dims, - hidden_dims=hidden_dims, - heads=heads, - num_queries=num_image_text_embeds, - ) - - new_sd = self._convert_ip_adapter_to_diffusers(image_proj_state_dict) - image_projection.load_state_dict(new_sd) - del image_proj_state_dict + # convert IP-Adapter Image Projection layers to diffusers + image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"]) self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) self.config.encoder_hid_dim_type = "ip_image_proj" - - delete_adapter_layers