
Commit

Merge branch 'dpm_zsnr' of https://github.com/Beinsezii/diffusers into dpm_zsnr
Beinsezii committed Feb 27, 2024
2 parents 555831e + 670f178 commit 6b8fabf
Showing 17 changed files with 409 additions and 169 deletions.
38 changes: 19 additions & 19 deletions .github/ISSUE_TEMPLATE/bug-report.yml
@@ -66,40 +66,40 @@ body:
Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...):
Questions on pipelines:
- Stable Diffusion @yiyixuxu @DN6 @sayakpaul @patrickvonplaten
- Stable Diffusion XL @yiyixuxu @sayakpaul @DN6 @patrickvonplaten
- Kandinsky @yiyixuxu @patrickvonplaten
- ControlNet @sayakpaul @yiyixuxu @DN6 @patrickvonplaten
- T2I Adapter @sayakpaul @yiyixuxu @DN6 @patrickvonplaten
- IF @DN6 @patrickvonplaten
- Text-to-Video / Video-to-Video @DN6 @sayakpaul @patrickvonplaten
- Wuerstchen @DN6 @patrickvonplaten
- Stable Diffusion @yiyixuxu @DN6 @sayakpaul
- Stable Diffusion XL @yiyixuxu @sayakpaul @DN6
- Kandinsky @yiyixuxu
- ControlNet @sayakpaul @yiyixuxu @DN6
- T2I Adapter @sayakpaul @yiyixuxu @DN6
- IF @DN6
- Text-to-Video / Video-to-Video @DN6 @sayakpaul
- Wuerstchen @DN6
- Other: @yiyixuxu @DN6
Questions on models:
- UNet @DN6 @yiyixuxu @sayakpaul @patrickvonplaten
- VAE @sayakpaul @DN6 @yiyixuxu @patrickvonplaten
- Transformers/Attention @DN6 @yiyixuxu @sayakpaul @DN6 @patrickvonplaten
- UNet @DN6 @yiyixuxu @sayakpaul
- VAE @sayakpaul @DN6 @yiyixuxu
- Transformers/Attention @DN6 @yiyixuxu @sayakpaul @DN6
Questions on Schedulers: @yiyixuxu @patrickvonplaten
Questions on Schedulers: @yiyixuxu
Questions on LoRA: @sayakpaul @patrickvonplaten
Questions on LoRA: @sayakpaul
Questions on Textual Inversion: @sayakpaul @patrickvonplaten
Questions on Textual Inversion: @sayakpaul
Questions on Training:
- DreamBooth @sayakpaul @patrickvonplaten
- Text-to-Image Fine-tuning @sayakpaul @patrickvonplaten
- Textual Inversion @sayakpaul @patrickvonplaten
- ControlNet @sayakpaul @patrickvonplaten
- DreamBooth @sayakpaul
- Text-to-Image Fine-tuning @sayakpaul
- Textual Inversion @sayakpaul
- ControlNet @sayakpaul
Questions on Tests: @DN6 @sayakpaul @yiyixuxu
Questions on Documentation: @stevhliu
Questions on JAX- and MPS-related things: @pcuenca
Questions on audio pipelines: @DN6 @patrickvonplaten
Questions on audio pipelines: @DN6
10 changes: 5 additions & 5 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -38,13 +38,13 @@ members/contributors who may be interested in your PR.
Core library:
- Schedulers: @yiyixuxu and @patrickvonplaten
- Pipelines: @patrickvonplaten and @sayakpaul
- Training examples: @sayakpaul and @patrickvonplaten
- Docs: @stevhliu and @yiyixuxu
- Schedulers: @yiyixuxu
- Pipelines: @sayakpaul @yiyixuxu @DN6
- Training examples: @sayakpaul
- Docs: @stevhliu and @sayakpaul
- JAX and MPS: @pcuenca
- Audio: @sanchit-gandhi
- General functionalities: @patrickvonplaten and @sayakpaul
- General functionalities: @sayakpaul @yiyixuxu @DN6
Integrations:
7 changes: 5 additions & 2 deletions examples/community/README.md
@@ -3561,14 +3561,17 @@ pipe.disable_style_aligned()

This pipeline adds experimental support for the image-to-video task using AnimateDiff. Refer to [this](https://github.com/huggingface/diffusers/pull/6328) PR for more examples and results.

This pipeline relies on a "hack" discovered by the community that allows the generation of videos given an input image with AnimateDiff. It works by creating a copy of the image `num_frames` times and progressively adding more noise to the image based on the strength and latent interpolation method.
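
The sketch below illustrates the idea in isolation. It is not the pipeline's actual implementation (which derives the per-frame noise levels from the scheduler and also supports `lerp`/`slerp` latent interpolation), and every name in it is made up for illustration.

```py
import torch


def make_video_latents(image_latent, num_frames, strength, generator=None):
    # image_latent: a single encoded frame of shape (1, C, H, W).
    # Repeat it along a new frame axis: (1, C, H, W) -> (1, C, F, H, W).
    latents = image_latent.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)
    noise = torch.randn(latents.shape, generator=generator, dtype=latents.dtype)
    for f in range(num_frames):
        # Later frames receive more noise, so they are freer to drift away from the input image.
        frame_strength = strength * (f + 1) / num_frames
        latents[:, :, f] = (1 - frame_strength) * latents[:, :, f] + frame_strength * noise[:, :, f]
    return latents
```

The pipeline exposes this mechanism through the `strength` and `latent_interpolation_method` arguments used in the example below.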

```py
import torch
from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
from diffusers.utils import export_to_gif, load_image

model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
pipe = DiffusionPipeline.from_pretrained(model_id, motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1)
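# Note: the scheduler settings above (linear beta schedule, "linspace" timestep spacing,
# clip_sample=False) match the configuration AnimateDiff motion modules are typically used
# with; deviating from them tends to hurt motion quality.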

image = load_image("snail.png")
output = pipe(
53 changes: 48 additions & 5 deletions examples/community/pipeline_animatediff_controlnet.py
@@ -24,7 +24,7 @@

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel
from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unets.unet_motion_model import MotionAdapter
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
@@ -382,6 +382,41 @@ def encode_image(self, image, device, num_images_per_prompt):
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)

image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
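                # Plain `ImageProjection` layers consume pooled CLIP image embeddings, while the other
                # IP-Adapter projection layers expect the encoder's penultimate hidden states, hence the
                # `output_hidden_state` switch below.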
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)

if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)

image_embeds.append(single_image_embeds)
else:
image_embeds = ip_adapter_image_embeds
return image_embeds
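    # Note (illustrative, not part of the original code): with classifier-free guidance enabled, each
    # tensor in the returned list concatenates the unconditional (negative) image embeddings before the
    # positive ones, mirroring the [negative, positive] ordering used for the prompt embeddings.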

# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
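        # Undo the VAE scaling factor applied at encode time (0.18215 for Stable Diffusion 1.5 VAEs)
        # before decoding the latents back to pixel space.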
latents = 1 / self.vae.config.scaling_factor * latents
@@ -767,6 +802,7 @@ def __call__(
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[PipelineImageInput] = None,
conditioning_frames: Optional[List[PipelineImageInput]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -821,6 +857,9 @@ def __call__(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
conditioning_frames (`List[PipelineImageInput]`, *optional*):
The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets
are specified, images must be passed as a list such that each element of the list can be correctly
@@ -965,9 +1004,9 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
)

if isinstance(controlnet, ControlNetModel):
conditioning_frames = self.prepare_image(
@@ -1023,7 +1062,11 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 7. Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
added_cond_kwargs = (
{"image_embeds": image_embeds}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
else None
)
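        # `added_cond_kwargs` is passed to the UNet at every denoising step; the UNet projects
        # `image_embeds` through its `encoder_hid_proj` layers so the IP-Adapter features can be
        # attended to alongside the text embeddings.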

# 7.1 Create tensor stating which controlnets to keep
controlnet_keep = []
82 changes: 62 additions & 20 deletions examples/community/pipeline_animatediff_img2video.py
@@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note:
# This pipeline relies on a "hack" discovered by the community that allows
# the generation of videos given an input image with AnimateDiff. It works
# by creating a copy of the image `num_frames` times and progressively adding
# more noise to the image based on the strength and latent interpolation method.

import inspect
from dataclasses import dataclass
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Union

@@ -25,7 +30,8 @@
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unet_motion_model import MotionAdapter
from diffusers.models.unets.unet_motion_model import MotionAdapter
from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
DDIMScheduler,
@@ -35,7 +41,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import randn_tensor


@@ -48,9 +54,10 @@
>>> from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
>>> from diffusers.utils import export_to_gif, load_image
>>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
>>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
>>> pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
>>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
>>> pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", beta_schedule="linear", steps_offset=1)
>>> image = load_image("snail.png")
>>> output = pipe(image=image, prompt="A snail moving on the ground", strength=0.8, latent_interpolation_method="slerp")
@@ -225,14 +232,9 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


@dataclass
class AnimateDiffImgToVideoPipelineOutput(BaseOutput):
frames: Union[torch.Tensor, np.ndarray]


class AnimateDiffImgToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-video generation.
Pipeline for image-to-video generation.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -503,6 +505,41 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state

return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)

image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)

if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)

image_embeds.append(single_image_embeds)
else:
image_embeds = ip_adapter_image_embeds
return image_embeds

# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
@@ -765,6 +802,7 @@ def __call__(
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -818,6 +856,9 @@ def __call__(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
@@ -842,8 +883,8 @@
Examples:
Returns:
[`AnimateDiffImgToVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`AnimateDiffImgToVideoPipelineOutput`] is
[`AnimateDiffPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""
# 0. Default height and width to unet
@@ -902,12 +943,9 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
)
if do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])

# 4. Preprocess image
image = self.image_processor.preprocess(image, height=height, width=width)
@@ -936,7 +974,11 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 8. Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
added_cond_kwargs = (
{"image_embeds": image_embeds}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
else None
)

# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -970,7 +1012,7 @@ def __call__(
callback(i, t, latents)

if output_type == "latent":
return AnimateDiffImgToVideoPipelineOutput(frames=latents)
return AnimateDiffPipelineOutput(frames=latents)

# 10. Post-processing
video_tensor = self.decode_latents(latents)
@@ -986,4 +1028,4 @@ def __call__(
if not return_dict:
return (video,)

return AnimateDiffImgToVideoPipelineOutput(frames=video)
return AnimateDiffPipelineOutput(frames=video)