Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IP adapter support for most pipelines #5900

Merged
merged 23 commits into from
Dec 10, 2023

Conversation

a-r-r-o-w
Copy link
Member

@a-r-r-o-w a-r-r-o-w commented Nov 22, 2023

What does this PR do?

Part of #5884.

pipeline_stable_diffusion_instruct_pix2pix
import requests
from io import BytesIO
from PIL import Image

from diffusers import StableDiffusionInstructPix2PixPipeline

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, variant="fp16"
)
pipe = pipe.to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

def download_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")

img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
image = download_image(img_url).resize((512, 512))

url = "https://images.nightcafe.studio/jobs/j2VTkZXvVFoOciX20IqZ/j2VTkZXvVFoOciX20IqZ--1--opa80.jpg"
ip_adapter_image = download_image(url).resize((512, 512))

prompt = "make the mountains snowy"
image = pipe(prompt=prompt, image=image, ip_adapter_image=ip_adapter_image, num_inference_steps=100, guidance_scale=6, image_guidance_scale=1).images[0]

Not sure if IP-Adapters and Pix2Pix is a great combination from results. Nonetheless, it's quite cool!

@yiyixuxu Question: Pix2Pix uses two negative prompts and in a different ordering. I've replicated the same. Is this right?

Input Image IP Adapter Image
Input Image IP Adapter Image
Result Image
pipeline_stable_diffusion_panorama
import requests
from io import BytesIO
from PIL import Image

from diffusers import StableDiffusionPanoramaPipeline, DPMSolverMultistepScheduler

model_ckpt = "runwayml/stable-diffusion-v1-5"
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
pipe = StableDiffusionPanoramaPipeline.from_pretrained(
    model_ckpt, scheduler=scheduler, torch_dtype=torch.float16, variant="fp16"
)
pipe = pipe.to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

def download_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")

ip_adapter_image = download_image("https://cdn-uploads.huggingface.co/production/uploads/1668693456211-noauth.jpeg")

prompt = "snowman"
image = pipe(prompt, num_inference_steps=20, height=512, width=768, ip_adapter_image=ip_adapter_image).images[0]
IP Adapter Image Result
IP Adapter Image Result Image
pipeline_stable_diffusion_sag
import torch
from diffusers import StableDiffusionSAGPipeline

pipe = StableDiffusionSAGPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
)
pipe = pipe.to("cuda")

pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

import requests
from io import BytesIO
from PIL import Image

def download_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")

ip_adapter_image = download_image("https://cdn-uploads.huggingface.co/production/uploads/1672623854262-noauth.jpeg")

prompt = "a photo of an astronaut riding a horse on mars"
image1 = pipe(prompt, ip_adapter_image=ip_adapter_image, sag_scale=0.75).images[0]
image2 = pipe("", ip_adapter_image=ip_adapter_image, sag_scale=0.75).images[0]
IP Adapter Image Result
IP Adapter Image Result 1 Result 2
pipeline_stable_diffusion_safe
import torch
from diffusers import StableDiffusionPipelineSafe

pipeline = StableDiffusionPipelineSafe.from_pretrained(
    "AIML-TUDA/stable-diffusion-safe", torch_dtype=torch.float16
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

import requests
from io import BytesIO
from PIL import Image

def download_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")

ip_adapter_image = download_image("https://pbs.twimg.com/media/F08gEZZXgAADW6G.jpg").resize((480, 640))

image = pipeline(prompt="beautiful woman with red hair, the end, exploding sun, apocalypse, facing right", num_inference_steps=50, ip_adapter_image=ip_adapter_image, width=480, height=640, **SafetyConfig.WEAK).images[0]
IP Adapter Image Result
IP Adapter Image Result Image
pipeline_latent_consistency_model_text2img
import torch
from diffusers import LatentConsistencyModelPipeline

pipe = LatentConsistencyModelPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float32).to("cuda")

prompt = "A beautiful cyborg, cyberpunk, ultra realistic, photorealistic, 8k resolution"
num_inference_steps = 4
ip_adapter_image = pipe(prompt=prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=8.0).images[0]
ip_adapter_image

pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

images1 = pipe(prompt=prompt, num_inference_steps=num_inference_steps, ip_adapter_image=ip_adapter_image, height=512, width=512, guidance_scale=12.0).images
images2 = pipe(prompt="", num_inference_steps=num_inference_steps, ip_adapter_image=ip_adapter_image, height=512, width=512, guidance_scale=12.0).images
images[0]
IP Adapter Image Result
IP Adapter Image Result 1 Result 2
pipeline_latent_consistency_models_img2img
import torch
from diffusers import LatentConsistencyModelImg2ImgPipeline

pipe = LatentConsistencyModelImg2ImgPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float32).to("cuda")

import requests
from io import BytesIO
from PIL import Image

def download_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")

image = download_image("https://cdn-uploads.huggingface.co/production/uploads/noauth/pC7XSbG0iXTjE2FPKIOd8.jpeg").resize((512, 512))
image

pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

ip_adapter_image = download_image("https://private-user-images.githubusercontent.com/12631849/282283038-9dc239c3-4483-46b9-a62b-b0c49c0c3b42.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDA5MTkwMTMsIm5iZiI6MTcwMDkxODcxMywicGF0aCI6Ii8xMjYzMTg0OS8yODIyODMwMzgtOWRjMjM5YzMtNDQ4My00NmI5LWE2MmItYjBjNDljMGMzYjQyLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFJV05KWUFYNENTVkVINTNBJTJGMjAyMzExMjUlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjMxMTI1VDEzMjUxM1omWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTNlZDUzYmExZTYzM2MzODFmMTM0NTljMGQ4OWIwMThiMDU4NTUxNDViNzhjNjRiNGMzNmU1ZGZlMzA5MDEyNWYmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.nenEJTpeGT_hIp7TjfiIzF6ADRQdjeYTzREkvvijJ4U").resize((512, 512))

prompt = "snowy mountains, princess, best quality, high quality"
num_inference_steps = 4
result = pipe(prompt=prompt, image=image, ip_adapter_image=ip_adapter_image, num_inference_steps=num_inference_steps, height=512, width=512, strength=0.6, guidance_scale=12.0).images[0]
IP Adapter Image Result
Input Image
IP Adapter Image Result Image
pipeline_stable_diffusion_ldm3d
import torch
from diffusers import StableDiffusionLDM3DPipeline

pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c", torch_dype=torch.float16).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

import requests
from io import BytesIO
from PIL import Image

def download_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")

ip_adapter_image = download_image("https://mspoweruser.com/wp-content/uploads/2023/07/Fantasy-Dragon-Perched-on-a-Castle-Best-Stable-Diffusion-Prompts.jpg").resize((512, 512))

prompt = "Fantasy illustration of a dragon perched on a castle, with a stormy sky and lightning in the background."
output = pipe(prompt, ip_adapter_image=ip_adapter_image, num_inference_steps=50, guidance_scale=4.0)
rgb_image, depth_image = output.rgb[0], output.depth[0]
IP Adapter Image Result Depth
IP Adapter Image Result Depth

Before submitting

Who can review?

@yiyixuxu @patrickvonplaten

@a-r-r-o-w
Copy link
Member Author

a-r-r-o-w commented Nov 22, 2023

@yiyixuxu I'm facing a few errors when adding IP Adapters to pipeline_stable_diffusion_attend_and_excite and need your help. I tried a few things but can't seem to make it work. Facing the same error with pipeline_stable_diffusion_sag :(

Code
!pip install git+https://github.com/a-r-r-o-w/diffusers.git@ip-adapter-txt2img transformers accelerate

import torch
from diffusers import StableDiffusionAttendAndExcitePipeline

pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

from diffusers.schedulers import DPMSolverMultistepScheduler
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

import requests
from PIL import Image
from io import BytesIO
url = "https://user-images.githubusercontent.com/12631849/282283038-9dc239c3-4483-46b9-a62b-b0c49c0c3b42.png"
response = requests.get(url)
ip_adapter_image = Image.open(BytesIO(response.content)).convert("RGB")
ip_adapter_image = ip_adapter_image.resize((512, 512))

pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

prompt = "a cat and an astronaut"

# use get_indices function to find out indices of the tokens you want to alter
pipe.get_indices(prompt)

token_indices = [2, 5]
seed = 6141
generator = torch.Generator("cuda").manual_seed(seed)

images = pipe(
    prompt=prompt,
    token_indices=token_indices,
    guidance_scale=9,
    generator=generator,
    num_inference_steps=20,
    max_iter_to_alter=12,
    ip_adapter_image=ip_adapter_image,
).images
Error log

In unet_2d_condition.py

        # 1.
        ...
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
            print(encoder_hidden_states.shape)
            print(image_embeds.shape)
            encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)

        # 2. pre-process
        sample = self.conv_in(sample)

encoder_hidden_states.shape: torch.Size([1, 77, 768])
image_embeds.shape: torch.Size([2, 4, 768])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-16-2499df1ec676> in <cell line: 10>()
      8 generator = torch.Generator("cuda").manual_seed(seed)
      9 
---> 10 images = pipe(
     11     prompt=prompt,
     12     token_indices=token_indices,

4 frames
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py in __call__(self, prompt, token_indices, height, width, num_inference_steps, guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, ip_adapter_image, output_type, return_dict, callback, callback_steps, cross_attention_kwargs, max_iter_to_alter, thresholds, scale_factor, attn_res, clip_skip)
    978                         text_embedding = text_embedding.unsqueeze(0)
    979 
--> 980                         self.unet(
    981                             latent,
    982                             t,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

/usr/local/lib/python3.10/dist-packages/diffusers/models/unet_2d_condition.py in forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, down_intrablock_additional_residuals, encoder_attention_mask, return_dict)
   1030             image_embeds = added_cond_kwargs.get("image_embeds")
   1031             image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
-> 1032             encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
   1033 
   1034         # 2. pre-process

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 2 for tensor number 1 in the list.

Edit: fixed

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@a-r-r-o-w a-r-r-o-w changed the title [WIP] IP adapter support for all txt2img pipelines [WIP] IP adapter support for most pipelines Nov 25, 2023
@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review November 25, 2023 13:43
@a-r-r-o-w
Copy link
Member Author

a-r-r-o-w commented Nov 25, 2023

Wondering how to support semantic stable diffusion.

Here, edit_concepts is concatenated with prompts and negative prompts:

text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])

I tried this just for the sake of getting it working but the results weren't great and lost all meaning related to the semantic stable diffusion task.

        if ip_adapter_image is not None:
            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
            if do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds, *[image_embeds] * enabled_editing_prompts])

Edit: Probably doesn't make sense to add here.

@a-r-r-o-w
Copy link
Member Author

Blip diffusion does not seem to have a feature_extractor as required in the IP Adapter encode_image. Not sure how to support it there either.

@yiyixuxu
Copy link
Collaborator

@a-r-r-o-w
ohh thanks so much! I will take a look tomorrow

I think it makes sense to add to LCM, not sure about rest of the pipelines
cc @sayakpaul and @patrickvonplaten here, let me know what you guys think

@yiyixuxu
Copy link
Collaborator

@a-r-r-o-w
don't worry about blip - it is supposed to do similar task as IP-adapter so we do not need to add it

Blip diffusion does not seem to have a feature_extractor as required in the IP Adapter encode_image. Not sure how to support it there either.

@charchit7
Copy link
Contributor

Addition : for StableDiffusionControlNetImg2ImgPipeline #5901

@sayakpaul
Copy link
Member

IIUC, IP Adapter is quite generic and should help the SD family of pipelines that we have. So, I think it's okay to add it to the other pipelines. This is also a use-case brought up by @apolinario.

@wwirving
Copy link

wwirving commented Nov 27, 2023

Hi all thanks for the great work! I've left a comment in #5887 regarding this - but wanted to check if this support will extend to the StableDiffusionXLControlNet pipelines?

@a-r-r-o-w
Copy link
Member Author

a-r-r-o-w commented Nov 27, 2023

Hi all thanks for the great work! I've left a comment in #5887 regarding this - but wanted to check if this support will extend to the StableDiffusionXLControlNet pipelines?

@wwirving Sure, I'll add support for all pipelines where it makes sense to have this.

I think what @sayakpaul mentioned and @apolinario brought up is the way to go. It is kinda similar to the img2img task in general, but it has other use cases. Let's say you wanted to do img2img with SD-SAG pipeline. It turns out you wouldn't be able to because there's no implementation of SAGImg2Img. I believe the intended use of IP Adapters is to replace the use of text prompts with image prompts, or use them together. I've added examples above demonstrating generations without using a prompt and only an ip_adapter_image. The results I've seen after hundreds of generations, from the currently supported pipelines in this PR, is quite amazing and it would be really cool to see it work with other pipelines as well.

@a-r-r-o-w
Copy link
Member Author

Hi all thanks for the great work! I've left a comment in #5887 regarding this - but wanted to check if this support will extend to the StableDiffusionXLControlNet pipelines?

@wwirving Just noticed that #5713 already adds support for SDXLControlNet :)

@yiyixuxu
Copy link
Collaborator

@a-r-r-o-w
I will wait for @patrickvonplaten 's feedback on which pipelines we want to add ip-adapter to:)
for now let's focus on the LCM pipelines:

  1. add example on the PR to see how the IP-adapter works with these 2 pipelines
  2. add ip-adapter related tests to these pipelines

Thanks!
YiYi

@a-r-r-o-w a-r-r-o-w changed the title [WIP] IP adapter support for most pipelines IP adapter support for most pipelines Nov 28, 2023
@a-r-r-o-w
Copy link
Member Author

Hi @yiyixuxu @patrickvonplaten! Is there anything more you'd like me to add here?

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Dec 7, 2023

thanks! looks good
we recently merged a ip-adapter PR that changed the code a little bit
I just updated the branch, I think you will need to run make fix-copies here

@a-r-r-o-w
Copy link
Member Author

i think i've got all 7 pipelines here covered and synced to the latest ip-adapter implementation correctly 🤞

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@yiyixuxu yiyixuxu requested a review from sayakpaul December 7, 2023 21:39
@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Dec 7, 2023

cc @sayakpaul for a final review here

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking solid! I just have this question here:
https://github.com/huggingface/diffusers/pull/5900/files#r1420300947

@a-r-r-o-w a-r-r-o-w requested a review from sayakpaul December 9, 2023 11:41
Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking solid! I think once we take care of the if False situation, we can merge :)

@a-r-r-o-w
Copy link
Member Author

Looking solid! I think once we take care of the if False situation, we can merge :)

Already done ;)

a3ac5ce

@sayakpaul sayakpaul merged commit 88bdd97 into huggingface:main Dec 10, 2023
14 checks passed
@sayakpaul
Copy link
Member

Thank you for your very nice contribution!

@a-r-r-o-w a-r-r-o-w deleted the ip-adapter-txt2img branch December 10, 2023 15:50
sayakpaul added a commit that referenced this pull request Dec 11, 2023
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
donhardman pushed a commit to donhardman/diffusers that referenced this pull request Dec 18, 2023
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
sayakpaul added a commit that referenced this pull request Dec 26, 2023
* add: script to train lcm lora for sdxl with 🤗 datasets

* suit up the args.

* remove comments.

* fix num_update_steps

* fix batch unmarshalling

* fix num_update_steps_per_epoch

* fix; dataloading.

* fix microconditions.

* unconditional predictions debug

* fix batch size.

* no need to use use_auth_token

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* make vae encoding batch size an arg

* final serialization in kohya

* style

* state dict rejigging

* feat: no separate teacher unet.

* debug

* fix state dict serialization

* debug

* debug

* debug

* remove prints.

* remove kohya utility and make style

* fix serialization

* fix

* add test

* add peft dependency.

* add: peft

* remove peft

* autocast device determination from accelerator

* autocast

* reduce lora rank.

* remove unneeded space

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* style

* remove prompt dropout.

* also save in native diffusers ckpt format.

* debug

* debug

* debug

* better formation of the null embeddings.

* remove space.

* autocast fixes.

* autocast fix.

* hacky

* remove lora_sayak

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* style

* make log validation leaner.

* move back enabled in.

* fix: log_validation call.

* add: checkpointing tests

* taking my chances to see if disabling autocasting has any effect?

* start debugging

* name

* name

* name

* more debug

* more debug

* index

* remove index.

* print length

* print length

* print length

* move unet.train() after add_adapter()

* disable some prints.

* enable_adapters() manually.

* remove prints.

* some changes.

* fix params_to_optimize

* more fixes

* debug

* debug

* remove print

* disable grad for certain contexts.

* Add support for IPAdapterFull (#5911)

* Add support for IPAdapterFull


Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* Fix a bug in `add_noise` function  (#6085)

* fix

* copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>

* [Advanced Diffusion Script] Add Widget default text (#6100)

add widget

* [Advanced Training Script] Fix pipe example (#6106)

* IP-Adapter for StableDiffusionControlNetImg2ImgPipeline (#5901)

* adapter for StableDiffusionControlNetImg2ImgPipeline

* fix-copies

* fix-copies

---------

Co-authored-by: Sayak Paul <[email protected]>

* IP adapter support for most pipelines (#5900)

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

* fix: lora_alpha

* make vae casting conditional/

* param upcasting

* propagate comments from #6145

Co-authored-by: dg845 <[email protected]>

* [Peft] fix saving / loading when unet is not "unet" (#6046)

* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul <[email protected]>

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------

Co-authored-by: Sayak Paul <[email protected]>

* [Wuerstchen] fix fp16 training and correct lora args (#6245)

fix fp16 training

Co-authored-by: Sayak Paul <[email protected]>

* [docs] fix: animatediff docs (#6339)

fix: animatediff docs

* add: note about the new script in readme_sdxl.

* Revert "[Peft] fix saving / loading when unet is not "unet" (#6046)"

This reverts commit 4c7e983.

* Revert "[Wuerstchen] fix fp16 training and correct lora args (#6245)"

This reverts commit 0bb9cf0.

* Revert "[docs] fix: animatediff docs (#6339)"

This reverts commit 11659a6.

* remove tokenize_prompt().

* assistive comments around enable_adapters() and diable_adapters().

---------

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Fabio Rigano <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: apolinário <[email protected]>
Co-authored-by: Charchit Sharma <[email protected]>
Co-authored-by: Aryan V S <[email protected]>
Co-authored-by: dg845 <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
donhardman pushed a commit to donhardman/diffusers that referenced this pull request Dec 29, 2023
* add: script to train lcm lora for sdxl with 🤗 datasets

* suit up the args.

* remove comments.

* fix num_update_steps

* fix batch unmarshalling

* fix num_update_steps_per_epoch

* fix; dataloading.

* fix microconditions.

* unconditional predictions debug

* fix batch size.

* no need to use use_auth_token

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* make vae encoding batch size an arg

* final serialization in kohya

* style

* state dict rejigging

* feat: no separate teacher unet.

* debug

* fix state dict serialization

* debug

* debug

* debug

* remove prints.

* remove kohya utility and make style

* fix serialization

* fix

* add test

* add peft dependency.

* add: peft

* remove peft

* autocast device determination from accelerator

* autocast

* reduce lora rank.

* remove unneeded space

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* style

* remove prompt dropout.

* also save in native diffusers ckpt format.

* debug

* debug

* debug

* better formation of the null embeddings.

* remove space.

* autocast fixes.

* autocast fix.

* hacky

* remove lora_sayak

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* style

* make log validation leaner.

* move back enabled in.

* fix: log_validation call.

* add: checkpointing tests

* taking my chances to see if disabling autocasting has any effect?

* start debugging

* name

* name

* name

* more debug

* more debug

* index

* remove index.

* print length

* print length

* print length

* move unet.train() after add_adapter()

* disable some prints.

* enable_adapters() manually.

* remove prints.

* some changes.

* fix params_to_optimize

* more fixes

* debug

* debug

* remove print

* disable grad for certain contexts.

* Add support for IPAdapterFull (huggingface#5911)

* Add support for IPAdapterFull


Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* Fix a bug in `add_noise` function  (huggingface#6085)

* fix

* copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>

* [Advanced Diffusion Script] Add Widget default text (huggingface#6100)

add widget

* [Advanced Training Script] Fix pipe example (huggingface#6106)

* IP-Adapter for StableDiffusionControlNetImg2ImgPipeline (huggingface#5901)

* adapter for StableDiffusionControlNetImg2ImgPipeline

* fix-copies

* fix-copies

---------

Co-authored-by: Sayak Paul <[email protected]>

* IP adapter support for most pipelines (huggingface#5900)

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

* fix: lora_alpha

* make vae casting conditional/

* param upcasting

* propagate comments from huggingface#6145

Co-authored-by: dg845 <[email protected]>

* [Peft] fix saving / loading when unet is not "unet" (huggingface#6046)

* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul <[email protected]>

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------

Co-authored-by: Sayak Paul <[email protected]>

* [Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)

fix fp16 training

Co-authored-by: Sayak Paul <[email protected]>

* [docs] fix: animatediff docs (huggingface#6339)

fix: animatediff docs

* add: note about the new script in readme_sdxl.

* Revert "[Peft] fix saving / loading when unet is not "unet" (huggingface#6046)"

This reverts commit 4c7e983.

* Revert "[Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)"

This reverts commit 0bb9cf0.

* Revert "[docs] fix: animatediff docs (huggingface#6339)"

This reverts commit 11659a6.

* remove tokenize_prompt().

* assistive comments around enable_adapters() and diable_adapters().

---------

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Fabio Rigano <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: apolinário <[email protected]>
Co-authored-by: Charchit Sharma <[email protected]>
Co-authored-by: Aryan V S <[email protected]>
Co-authored-by: dg845 <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
antoine-scenario pushed a commit to antoine-scenario/diffusers that referenced this pull request Jan 2, 2024
* add: script to train lcm lora for sdxl with 🤗 datasets

* suit up the args.

* remove comments.

* fix num_update_steps

* fix batch unmarshalling

* fix num_update_steps_per_epoch

* fix; dataloading.

* fix microconditions.

* unconditional predictions debug

* fix batch size.

* no need to use use_auth_token

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* make vae encoding batch size an arg

* final serialization in kohya

* style

* state dict rejigging

* feat: no separate teacher unet.

* debug

* fix state dict serialization

* debug

* debug

* debug

* remove prints.

* remove kohya utility and make style

* fix serialization

* fix

* add test

* add peft dependency.

* add: peft

* remove peft

* autocast device determination from accelerator

* autocast

* reduce lora rank.

* remove unneeded space

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* style

* remove prompt dropout.

* also save in native diffusers ckpt format.

* debug

* debug

* debug

* better formation of the null embeddings.

* remove space.

* autocast fixes.

* autocast fix.

* hacky

* remove lora_sayak

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* style

* make log validation leaner.

* move back enabled in.

* fix: log_validation call.

* add: checkpointing tests

* taking my chances to see if disabling autocasting has any effect?

* start debugging

* name

* name

* name

* more debug

* more debug

* index

* remove index.

* print length

* print length

* print length

* move unet.train() after add_adapter()

* disable some prints.

* enable_adapters() manually.

* remove prints.

* some changes.

* fix params_to_optimize

* more fixes

* debug

* debug

* remove print

* disable grad for certain contexts.

* Add support for IPAdapterFull (huggingface#5911)

* Add support for IPAdapterFull


Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* Fix a bug in `add_noise` function  (huggingface#6085)

* fix

* copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>

* [Advanced Diffusion Script] Add Widget default text (huggingface#6100)

add widget

* [Advanced Training Script] Fix pipe example (huggingface#6106)

* IP-Adapter for StableDiffusionControlNetImg2ImgPipeline (huggingface#5901)

* adapter for StableDiffusionControlNetImg2ImgPipeline

* fix-copies

* fix-copies

---------

Co-authored-by: Sayak Paul <[email protected]>

* IP adapter support for most pipelines (huggingface#5900)

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

* fix: lora_alpha

* make vae casting conditional/

* param upcasting

* propagate comments from huggingface#6145

Co-authored-by: dg845 <[email protected]>

* [Peft] fix saving / loading when unet is not "unet" (huggingface#6046)

* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul <[email protected]>

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------

Co-authored-by: Sayak Paul <[email protected]>

* [Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)

fix fp16 training

Co-authored-by: Sayak Paul <[email protected]>

* [docs] fix: animatediff docs (huggingface#6339)

fix: animatediff docs

* add: note about the new script in readme_sdxl.

* Revert "[Peft] fix saving / loading when unet is not "unet" (huggingface#6046)"

This reverts commit 4c7e983.

* Revert "[Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)"

This reverts commit 0bb9cf0.

* Revert "[docs] fix: animatediff docs (huggingface#6339)"

This reverts commit 11659a6.

* remove tokenize_prompt().

* assistive comments around enable_adapters() and diable_adapters().

---------

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Fabio Rigano <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: apolinário <[email protected]>
Co-authored-by: Charchit Sharma <[email protected]>
Co-authored-by: Aryan V S <[email protected]>
Co-authored-by: dg845 <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* add: script to train lcm lora for sdxl with 🤗 datasets

* suit up the args.

* remove comments.

* fix num_update_steps

* fix batch unmarshalling

* fix num_update_steps_per_epoch

* fix; dataloading.

* fix microconditions.

* unconditional predictions debug

* fix batch size.

* no need to use use_auth_token

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* make vae encoding batch size an arg

* final serialization in kohya

* style

* state dict rejigging

* feat: no separate teacher unet.

* debug

* fix state dict serialization

* debug

* debug

* debug

* remove prints.

* remove kohya utility and make style

* fix serialization

* fix

* add test

* add peft dependency.

* add: peft

* remove peft

* autocast device determination from accelerator

* autocast

* reduce lora rank.

* remove unneeded space

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* style

* remove prompt dropout.

* also save in native diffusers ckpt format.

* debug

* debug

* debug

* better formation of the null embeddings.

* remove space.

* autocast fixes.

* autocast fix.

* hacky

* remove lora_sayak

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* style

* make log validation leaner.

* move back enabled in.

* fix: log_validation call.

* add: checkpointing tests

* taking my chances to see if disabling autocasting has any effect?

* start debugging

* name

* name

* name

* more debug

* more debug

* index

* remove index.

* print length

* print length

* print length

* move unet.train() after add_adapter()

* disable some prints.

* enable_adapters() manually.

* remove prints.

* some changes.

* fix params_to_optimize

* more fixes

* debug

* debug

* remove print

* disable grad for certain contexts.

* Add support for IPAdapterFull (huggingface#5911)

* Add support for IPAdapterFull


Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* Fix a bug in `add_noise` function  (huggingface#6085)

* fix

* copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>

* [Advanced Diffusion Script] Add Widget default text (huggingface#6100)

add widget

* [Advanced Training Script] Fix pipe example (huggingface#6106)

* IP-Adapter for StableDiffusionControlNetImg2ImgPipeline (huggingface#5901)

* adapter for StableDiffusionControlNetImg2ImgPipeline

* fix-copies

* fix-copies

---------

Co-authored-by: Sayak Paul <[email protected]>

* IP adapter support for most pipelines (huggingface#5900)

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

* fix: lora_alpha

* make vae casting conditional/

* param upcasting

* propagate comments from huggingface#6145

Co-authored-by: dg845 <[email protected]>

* [Peft] fix saving / loading when unet is not "unet" (huggingface#6046)

* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul <[email protected]>

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------

Co-authored-by: Sayak Paul <[email protected]>

* [Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)

fix fp16 training

Co-authored-by: Sayak Paul <[email protected]>

* [docs] fix: animatediff docs (huggingface#6339)

fix: animatediff docs

* add: note about the new script in readme_sdxl.

* Revert "[Peft] fix saving / loading when unet is not "unet" (huggingface#6046)"

This reverts commit 4c7e983.

* Revert "[Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)"

This reverts commit 0bb9cf0.

* Revert "[docs] fix: animatediff docs (huggingface#6339)"

This reverts commit 11659a6.

* remove tokenize_prompt().

* assistive comments around enable_adapters() and diable_adapters().

---------

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Fabio Rigano <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: apolinário <[email protected]>
Co-authored-by: Charchit Sharma <[email protected]>
Co-authored-by: Aryan V S <[email protected]>
Co-authored-by: dg845 <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants