Skip to content

Commit

Permalink
Add support for IPAdapterFull (#5911)
Browse files Browse the repository at this point in the history
* Add support for IPAdapterFull


Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
  • Loading branch information
3 people authored Dec 7, 2023
1 parent 6bf1ca2 commit b65928b
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 1 deletion.
63 changes: 63 additions & 0 deletions docs/source/en/using-diffusers/loading_adapters.md
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,69 @@ image.save("sdxl_t2i.png")
</div>
</div>

You can use the IP-Adapter face model to apply specific faces to your images. It is an effective way to maintain consistent characters in your image generations.
Weights are loaded with the same method used for the other IP-Adapters.

```python
# Load ip-adapter-full-face_sd15.bin
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
```

<Tip>

It is recommended to use `DDIMScheduler` and `EulerDiscreteScheduler` for face model.


</Tip>

```python
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image

noise_scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1
)

pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
scheduler=noise_scheduler,
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")

pipeline.set_ip_adapter_scale(0.7)

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")

generator = torch.Generator(device="cpu").manual_seed(33)

image = pipeline(
prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
ip_adapter_image=image,
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=50, num_images_per_prompt=1, width=512, height=704,
generator=generator,
).images[0]
```

<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ipadapter_full_face_output.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
</div>
</div>

### LCM-Lora

Expand Down
29 changes: 28 additions & 1 deletion src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn

from ..models.embeddings import ImageProjection, Resampler
from ..models.embeddings import ImageProjection, MLPProjection, Resampler
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
USE_PEFT_BACKEND,
Expand Down Expand Up @@ -675,6 +675,9 @@ def _load_ip_adapter_weights(self, state_dict):
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
num_image_text_embeds = 4
elif "proj.3.weight" in state_dict["image_proj"]:
# IP-Adapter Full Face
num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
Expand Down Expand Up @@ -744,8 +747,32 @@ def _load_ip_adapter_weights(self, state_dict):
"norm.bias": state_dict["image_proj"]["norm.bias"],
}
)
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict

elif "proj.3.weight" in state_dict["image_proj"]:
clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0]
cross_attention_dim = state_dict["image_proj"]["proj.3.weight"].shape[0]

image_projection = MLPProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
)
image_projection.to(dtype=self.dtype, device=self.device)

# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
"ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
"ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
"ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
"norm.weight": state_dict["image_proj"]["proj.3.weight"],
"norm.bias": state_dict["image_proj"]["proj.3.bias"],
}
)
image_projection.load_state_dict(image_proj_state_dict)
del image_proj_state_dict

else:
# IP-Adapter Plus
Expand Down
12 changes: 12 additions & 0 deletions src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,18 @@ def forward(self, image_embeds: torch.FloatTensor):
return image_embeds


class MLPProjection(nn.Module):
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
super().__init__()
from .attention import FeedForward

self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
self.norm = nn.LayerNorm(cross_attention_dim)

def forward(self, image_embeds: torch.FloatTensor):
return self.norm(self.ff(image_embeds))


class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
Expand Down
19 changes: 19 additions & 0 deletions tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,25 @@ def test_inpainting(self):

assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)

def test_text_to_image_full_face(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
)
pipeline.to(torch_device)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
pipeline.set_ip_adapter_scale(0.7)

inputs = self.get_dummy_inputs()
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()

expected_slice = np.array(
[0.1706543, 0.1303711, 0.12573242, 0.21777344, 0.14550781, 0.14038086, 0.40820312, 0.41455078, 0.42529297]
)

assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)


@slow
@require_torch_gpu
Expand Down

0 comments on commit b65928b

Please sign in to comment.