Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add converter method for ip adapters #6150

Merged
merged 4 commits into from
Dec 18, 2023
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 47 additions & 56 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import OrderedDict, defaultdict
from collections import defaultdict
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union

Expand Down Expand Up @@ -664,6 +664,44 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)

def _convert_ip_adapter_to_diffusers(self, state_dict):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _convert_ip_adapter_to_diffusers(self, state_dict):
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should move this method to https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/ip_adapter.py. This way the UNet loader stays cleaner.

updated_state_dict = {}

if "proj.weight" in state_dict:
for key, value in state_dict.items():
diffusers_name = key.replace("proj", "image_embeds")
updated_state_dict[diffusers_name] = value

elif "proj.3.weight" in state_dict:
for key, value in state_dict.items():
diffusers_name = key.replace("proj.0", "ff.net.0.proj")
diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
diffusers_name = diffusers_name.replace("proj.3", "norm")
updated_state_dict[diffusers_name] = value

else:
for key, value in state_dict.items():
diffusers_name = key.replace("0.to", "2.to")
diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")

if "norm1" in diffusers_name:
updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
elif "norm2" in diffusers_name:
updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
elif "to_kv" in diffusers_name:
v_chunk = value.chunk(2, dim=0)
updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_out" in diffusers_name:
updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
else:
updated_state_dict[diffusers_name] = value

return updated_state_dict

def _load_ip_adapter_weights(self, state_dict):
from ..models.attention_processor import (
AttnProcessor,
Expand Down Expand Up @@ -725,6 +763,8 @@ def _load_ip_adapter_weights(self, state_dict):
self.set_attn_processor(attn_procs)

# create image projection layers.
image_proj_state_dict = state_dict["image_proj"]

if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
Expand All @@ -737,18 +777,8 @@ def _load_ip_adapter_weights(self, state_dict):
)
image_projection.to(dtype=self.dtype, device=self.device)

# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
"norm.weight": state_dict["image_proj"]["norm.weight"],
"norm.bias": state_dict["image_proj"]["norm.bias"],
}
)
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict
new_sd = self._convert_ip_adapter_to_diffusers(image_proj_state_dict)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
new_sd = self._convert_ip_adapter_to_diffusers(image_proj_state_dict)
new_sd = self._convert_ip_adapter_image_proj_to_diffusers(image_proj_state_dict)

image_projection.load_state_dict(new_sd)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed better and cleaner!


elif "proj.3.weight" in state_dict["image_proj"]:
clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0]
Expand All @@ -759,20 +789,8 @@ def _load_ip_adapter_weights(self, state_dict):
)
image_projection.to(dtype=self.dtype, device=self.device)

# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
"ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
"ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
"ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
"norm.weight": state_dict["image_proj"]["proj.3.weight"],
"norm.bias": state_dict["image_proj"]["proj.3.bias"],
}
)
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict
new_sd = self._convert_ip_adapter_to_diffusers(image_proj_state_dict)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
new_sd = self._convert_ip_adapter_to_diffusers(image_proj_state_dict)
new_sd = self._convert_ip_adapter_image_proj_to_diffusers(image_proj_state_dict)

image_projection.load_state_dict(new_sd)

else:
# IP-Adapter Plus
Expand All @@ -789,36 +807,9 @@ def _load_ip_adapter_weights(self, state_dict):
num_queries=num_image_text_embeds,
)

image_proj_state_dict = state_dict["image_proj"]

new_sd = OrderedDict()
for k, v in image_proj_state_dict.items():
if "0.to" in k:
k = k.replace("0.to", "2.to")
elif "1.0.weight" in k:
k = k.replace("1.0.weight", "3.0.weight")
elif "1.0.bias" in k:
k = k.replace("1.0.bias", "3.0.bias")
elif "1.1.weight" in k:
k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
elif "1.3.weight" in k:
k = k.replace("1.3.weight", "3.1.net.2.weight")

if "norm1" in k:
new_sd[k.replace("0.norm1", "0")] = v
elif "norm2" in k:
new_sd[k.replace("0.norm2", "1")] = v
elif "to_kv" in k:
v_chunk = v.chunk(2, dim=0)
new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_out" in k:
new_sd[k.replace("to_out", "to_out.0")] = v
else:
new_sd[k] = v

new_sd = self._convert_ip_adapter_to_diffusers(image_proj_state_dict)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have this function return the `image_projection` module instead?

image_projection.load_state_dict(new_sd)
del image_proj_state_dict
del image_proj_state_dict

self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.config.encoder_hid_dim_type = "ip_image_proj"
Expand Down
Loading