[AnimateDiff+Controlnet] Fix multicontrolnet support #6551

Merged
77 changes: 75 additions & 2 deletions examples/community/README.md
@@ -2989,7 +2989,7 @@ pipe = DiffusionPipeline.from_pretrained(
custom_pipeline="pipeline_animatediff_controlnet",
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
)
pipe.enable_vae_slicing()

@@ -3005,7 +3005,7 @@ result = pipe(
width=512,
height=768,
conditioning_frames=conditioning_frames,
num_inference_steps=12,
num_inference_steps=20,
).frames[0]

from diffusers.utils import export_to_gif
@@ -3029,6 +3029,79 @@ export_to_gif(result.frames[0], "result.gif")
</tr>
</table>

You can also use multiple controlnets at once!

```python
import imageio  # used by the load_video helper below to read GIF/video frames
import requests  # used by the load_video helper below to fetch remote files
import torch
from io import BytesIO
from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter
from diffusers.pipelines import DiffusionPipeline
from diffusers.schedulers import DPMSolverMultistepScheduler
from PIL import Image

motion_id = "guoyww/animatediff-motion-adapter-v1-5-2"
adapter = MotionAdapter.from_pretrained(motion_id)
controlnet1 = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
controlnet2 = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = DiffusionPipeline.from_pretrained(
model_id,
motion_adapter=adapter,
controlnet=[controlnet1, controlnet2],
vae=vae,
custom_pipeline="pipeline_animatediff_controlnet",
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
)
pipe.enable_vae_slicing()

def load_video(file_path: str):
images = []

if file_path.startswith(('http://', 'https://')):
# If the file_path is a URL
response = requests.get(file_path)
response.raise_for_status()
content = BytesIO(response.content)
vid = imageio.get_reader(content)
else:
# Assuming it's a local file path
vid = imageio.get_reader(file_path)

for frame in vid:
pil_image = Image.fromarray(frame)
images.append(pil_image)

return images

video = load_video("dance.gif")

# You need to install it using `pip install controlnet_aux`
from controlnet_aux.processor import Processor

p1 = Processor("openpose_full")
cn1 = [p1(frame) for frame in video]

p2 = Processor("canny")
cn2 = [p2(frame) for frame in video]

prompt = "astronaut in space, dancing"
negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
result = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=512,
height=768,
conditioning_frames=[cn1, cn2],
num_inference_steps=20,
)

from diffusers.utils import export_to_gif
export_to_gif(result.frames[0], "result.gif")
```
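When more than one ControlNet is passed, the pipeline wraps them in a `MultiControlNetModel` internally, and `conditioning_frames` is expected to be a list of frame lists: one sublist per ControlNet, each containing one conditioning image per generated frame (i.e. matching `num_frames`).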

### DemoFusion

This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973).
45 changes: 19 additions & 26 deletions examples/community/pipeline_animatediff_controlnet.py
@@ -14,7 +14,7 @@

import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -66,7 +66,7 @@
... custom_pipeline="pipeline_animatediff_controlnet",
... ).to(device="cuda", dtype=torch.float16)
>>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
... )
>>> pipe.enable_vae_slicing()

@@ -83,7 +83,7 @@
... height=768,
... conditioning_frames=conditioning_frames,
... num_inference_steps=12,
... ).frames[0]
... )

>>> from diffusers.utils import export_to_gif
>>> export_to_gif(result.frames[0], "result.gif")
@@ -151,7 +151,7 @@ def __init__(
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
motion_adapter: MotionAdapter,
controlnet: Union[ControlNetModel, MultiControlNetModel],
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -166,6 +166,9 @@
super().__init__()
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetModel(controlnet)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -488,6 +491,7 @@ def check_inputs(
prompt,
height,
width,
num_frames,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
@@ -557,31 +561,21 @@ def check_inputs(
or is_compiled
and isinstance(self.controlnet._orig_mod, ControlNetModel)
):
if isinstance(image, list):
for image_ in image:
self.check_image(image_, prompt, prompt_embeds)
else:
self.check_image(image, prompt, prompt_embeds)
if not isinstance(image, list):
raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(image)}")
if len(image) != num_frames:
raise ValueError(f"Excepted image to have length {num_frames} but got {len(image)=}")
elif (
isinstance(self.controlnet, MultiControlNetModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
):
if not isinstance(image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")

# When `image` is a nested list:
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
elif any(isinstance(i, list) for i in image):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
elif len(image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
)

for control_ in image:
for image_ in control_:
self.check_image(image_, prompt, prompt_embeds)
if not isinstance(image, list) or not isinstance(image[0], list):
raise TypeError(f"For multiple controlnets: `image` must be a list of lists but got {type(image)=}")
if len(image[0]) != num_frames:
raise ValueError(f"Expected each image sublist to have length {num_frames} but got {len(image[0])=}")
if any(len(img) != len(image[0]) for img in image):
raise ValueError("All conditioning frame batches for multicontrolnet must be the same size")
else:
assert False
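For reference, a minimal sketch of the `conditioning_frames` shapes that the updated checks accept; the placeholder frames and `num_frames = 16` below are assumptions for illustration, not part of the pipeline code:

```python
from PIL import Image

num_frames = 16  # assumed frame count for illustration
frame = Image.new("RGB", (512, 768))  # placeholder conditioning frame

# Single ControlNet: a flat list with exactly one conditioning image per generated frame.
single_cn_frames = [frame] * num_frames
assert isinstance(single_cn_frames, list)
assert len(single_cn_frames) == num_frames

# Multiple ControlNets: a list of lists, one sublist per ControlNet.
# The first sublist must have num_frames entries and all sublists must be the same length.
multi_cn_frames = [[frame] * num_frames, [frame] * num_frames]
assert isinstance(multi_cn_frames, list) and isinstance(multi_cn_frames[0], list)
assert len(multi_cn_frames[0]) == num_frames
assert all(len(frames) == len(multi_cn_frames[0]) for frames in multi_cn_frames)
```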

@@ -913,6 +907,7 @@ def __call__(
prompt=prompt,
height=height,
width=width,
num_frames=num_frames,
callback_steps=callback_steps,
negative_prompt=negative_prompt,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -1000,9 +995,7 @@ def __call__(
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)

cond_prepared_frames.append(prepared_frame)

conditioning_frames = cond_prepared_frames
else:
assert False