Support any v1.5 checkpoint, various other improvements and fixes... #111

Open · wants to merge 8 commits into base: main
4 changes: 3 additions & 1 deletion .gitignore
@@ -10,4 +10,6 @@ pretrained_models
 ./*.png
 ./*.mp4
 demo/tmp
-demo/outputs
+demo/outputs/*
+.idea/*
+outputs/*
308 changes: 181 additions & 127 deletions demo/animate.py

Large diffs are not rendered by default.

93 changes: 70 additions & 23 deletions demo/gradio_animate.py
@@ -8,21 +8,40 @@
 # disclosure or distribution of this material and related documentation
 # without an express license agreement from ByteDance or
 # its affiliates is strictly prohibited.
+import argparse
+import os
+
+import gradio as gr
 import imageio
 import numpy as np
-import gradio as gr
 from PIL import Image
+from huggingface_hub import hf_hub_download
+from omegaconf import OmegaConf
 
 from demo.animate import MagicAnimate
+from demo.paths import config_path, script_path, models_path, magic_models_path, source_images_path, \
+    motion_sequences_path
 
-animator = MagicAnimate()
+animator = None
 
-def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale):
-    return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale)
 
-with gr.Blocks() as demo:
+
+def list_checkpoints():
+    checkpoint_dir = os.path.join(models_path, "checkpoints")
+    # Recursively find all .ckpt and .safetensors files
+    checkpoints = [""]
+    for root, dirs, files in os.walk(checkpoint_dir):
+        for file in files:
+            if file.endswith(".ckpt") or file.endswith(".safetensors"):
+                checkpoints.append(os.path.join(root, file))
+    return checkpoints
+
+
+# prompt, reference_image, motion_sequence, seed, steps, guidance_scale, size, half_precision, checkpoint=None
+def animate(prompt, reference_image, motion_sequence_state, seed, steps, guidance_scale, size, half_precision, checkpoint=None):
+    return animator(prompt, reference_image, motion_sequence_state, seed, steps, guidance_scale, size, half_precision,
+                    checkpoint)
+
+
+with gr.Blocks() as demo:
     gr.HTML(
         """
         <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
@@ -40,25 +59,33 @@ def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale)
         </div>
         """)
     animation = gr.Video(format="mp4", label="Animation Results", autoplay=True)
-
-    with gr.Row():
-        reference_image = gr.Image(label="Reference Image")
-        motion_sequence = gr.Video(format="mp4", label="Motion Sequence")
+    checkpoint = gr.Dropdown(label="Checkpoint", choices=list_checkpoints())
+    half_precision = gr.Checkbox(label="Half precision", value=False)
+    with gr.Row():
+        prompt = gr.Textbox(label="Prompt", placeholder="Type an optional prompt here", lines=2)
+    with gr.Row():
+        reference_image = gr.Image(label="Reference Image")
+        motion_sequence = gr.Video(format="mp4", label="Motion Sequence")
 
         with gr.Column():
-            random_seed = gr.Textbox(label="Random seed", value=1, info="default: -1")
-            sampling_steps = gr.Textbox(label="Sampling steps", value=25, info="default: 25")
-            guidance_scale = gr.Textbox(label="Guidance scale", value=7.5, info="default: 7.5")
-            submit = gr.Button("Animate")
+            size = gr.Slider(label="Size", value=512, minimum=256, maximum=1024, step=256, info="default: 512", visible=False)
+            random_seed = gr.Slider(label="Random seed", value=1, info="default: -1")
+            sampling_steps = gr.Slider(label="Sampling steps", value=25, info="default: 25")
+            guidance_scale = gr.Slider(label="Guidance scale", value=7.5, info="default: 7.5", step=0.1)
+            submit = gr.Button("Animate")
 
 
     def read_video(video):
         reader = imageio.get_reader(video)
         fps = reader.get_meta_data()['fps']
         return video
 
 
     def read_image(image, size=512):
         return np.array(Image.fromarray(image).resize((size, size)))
 
 
     # when user uploads a new video
     motion_sequence.upload(
         read_video,
@@ -72,25 +99,45 @@ def read_image(image, size=512):
         reference_image
     )
     # when the `submit` button is clicked
+    # inputs: prompt, reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale, size, half_precision, checkpoint
     submit.click(
         animate,
-        [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale],
+        [prompt, reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale, size, half_precision,
+         checkpoint],
         animation
     )
 
     # Examples
     gr.Markdown("## Examples")
 
     gr.Examples(
         examples=[
-            ["inputs/applications/source_image/monalisa.png", "inputs/applications/driving/densepose/running.mp4"],
-            ["inputs/applications/source_image/demo4.png", "inputs/applications/driving/densepose/demo4.mp4"],
-            ["inputs/applications/source_image/dalle2.jpeg", "inputs/applications/driving/densepose/running2.mp4"],
-            ["inputs/applications/source_image/dalle8.jpeg", "inputs/applications/driving/densepose/dancing2.mp4"],
-            ["inputs/applications/source_image/multi1_source.png", "inputs/applications/driving/densepose/multi_dancing.mp4"],
+            [os.path.join(source_images_path, "monalisa.png"), os.path.join(motion_sequences_path, "running.mp4")],
+            [os.path.join(source_images_path, "demo4.png"), os.path.join(motion_sequences_path, "demo4.mp4")],
+            [os.path.join(source_images_path, "dalle2.jpeg"), os.path.join(motion_sequences_path, "running2.mp4")],
+            [os.path.join(source_images_path, "dalle8.jpeg"), os.path.join(motion_sequences_path, "dancing2.mp4")],
+            [os.path.join(source_images_path, "multi1_source.png"),
+             os.path.join(motion_sequences_path, "multi_dancing.mp4")],
         ],
        inputs=[reference_image, motion_sequence],
        outputs=animation,
    )
 
 if __name__ == '__main__':
+    if not os.path.exists(models_path):
+        os.mkdir(models_path)
+
+    if not os.path.exists(os.path.join(models_path, "checkpoints")):
+        os.mkdir(os.path.join(models_path, "checkpoints"))
+
+    # magic_models_path already points at pretrained_models/MagicAnimate, so one
+    # existence check suffices; the weights are fetched with git lfs clone
+    # (https://huggingface.co/zcxu-eric/MagicAnimate), not hf_hub_download
+    if not os.path.exists(magic_models_path):
+        os.system(f"git clone https://huggingface.co/zcxu-eric/MagicAnimate {magic_models_path}")
+
+    animator = MagicAnimate()
+
-    demo.launch(share=True)
+    demo.launch(share=True)
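The new animate() signature can be smoke-tested outside the Gradio UI with a short script along these lines. This is a sketch only: it assumes the repo root is on PYTHONPATH, that MagicAnimate() still constructs without arguments, and that the callable returns the path of the rendered mp4 (the demo/animate.py diff is not rendered above, so neither assumption is confirmed by the diff):

import os

import numpy as np
from PIL import Image

import demo.gradio_animate as app
from demo.animate import MagicAnimate
from demo.paths import source_images_path, motion_sequences_path

app.animator = MagicAnimate()  # normally created in the __main__ block

# Resize the reference image exactly as the UI's read_image() does (512x512).
image = np.array(Image.open(os.path.join(source_images_path, "monalisa.png")).convert("RGB").resize((512, 512)))
motion = os.path.join(motion_sequences_path, "running.mp4")

# An empty checkpoint string keeps the bundled MagicAnimate weights; any path
# returned by list_checkpoints() would select a user-supplied v1.5 checkpoint.
result = app.animate("", image, motion, 1, 25, 7.5, 512, False, checkpoint="")
print(result)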
23 changes: 23 additions & 0 deletions demo/paths.py
@@ -0,0 +1,23 @@
+import os
+
+script_path = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))
+output_path = os.path.join(script_path, "outputs")
+models_path = os.path.join(script_path, "pretrained_models")
+magic_models_path = os.path.join(models_path, "MagicAnimate")
+
+pretrained_model_path = os.path.join(models_path, "stable-diffusion-v1-5")
+pretrained_vae_path = os.path.join(models_path, "pretrained_vae")
+
+pretrained_controlnet_path = os.path.join(magic_models_path, "densepose_controlnet")
+pretrained_encoder_path = os.path.join(magic_models_path, "appearance_encoder")
+pretrained_motion_module_path = os.path.join(magic_models_path, "temporal_attention")
+motion_module = os.path.join(pretrained_motion_module_path, "temporal_attention.ckpt")
+
+pretrained_unet_path = ""
+
+config_path = os.path.join(script_path, "configs", "prompts", "animation.yaml")
+inference_config_path = os.path.join(script_path, "configs", "inference", "inference.yaml")
+
+source_images_path = os.path.join(script_path, "inputs", "applications", "source_image")
+motion_sequences_path = os.path.join(script_path, "inputs", "applications", "driving", "densepose")
+
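Together with list_checkpoints() in the demo, these constants imply the following layout under the repo root. This is a sketch of the intended layout, not something the diff states explicitly; note that only pretrained_models/checkpoints/ is scanned for user-supplied checkpoints, and pretrained_unet_path is deliberately left empty:

pretrained_models/
├── stable-diffusion-v1-5/        # base SD v1.5 in diffusers layout
├── pretrained_vae/
├── MagicAnimate/                 # cloned from huggingface.co/zcxu-eric/MagicAnimate
│   ├── densepose_controlnet/
│   ├── appearance_encoder/
│   └── temporal_attention/temporal_attention.ckpt
└── checkpoints/                  # drop any v1.5 .ckpt/.safetensors file here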

1 change: 1 addition & 0 deletions environment.yaml
@@ -2,6 +2,7 @@ name: manimate
 channels:
 - conda-forge
 - defaults
+- nvidia
 dependencies:
 - _libgcc_mutex=0.1=main
 - _openmp_mutex=5.1=1_gnu
32 changes: 32 additions & 0 deletions magicanimate/models/unet.py
@@ -506,3 +506,35 @@ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_addition
         print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
 
         return model
+
+    @classmethod
+    def from_2d_unet(cls, unet2d_model, unet_additional_kwargs=None):
+        # Extract the configuration from the 2D UNet model; cast it to a plain
+        # dict so item assignment works (diffusers configs are frozen mappings)
+        config_2d = unet2d_model.config
+        config_3d = dict(config_2d)
+
+        # Update configuration for 3D UNet specifics
+        config_3d["_class_name"] = cls.__name__
+        config_3d["down_block_types"] = [
+            "CrossAttnDownBlock3D",
+            "CrossAttnDownBlock3D",
+            "CrossAttnDownBlock3D",
+            "DownBlock3D"
+        ]
+        config_3d["up_block_types"] = [
+            "UpBlock3D",
+            "CrossAttnUpBlock3D",
+            "CrossAttnUpBlock3D",
+            "CrossAttnUpBlock3D"
+        ]
+
+        # Initialize a new 3D UNet model with the updated configuration;
+        # guard against unet_additional_kwargs being None
+        model_3d = cls.from_config(config_3d, **(unet_additional_kwargs or {}))
+
+        # Load state dict from 2D model to 3D model
+        # Note: This might require additional handling if the architectures differ significantly
+        state_dict_2d = unet2d_model.state_dict()
+        model_3d.load_state_dict(state_dict_2d, strict=False)
+
+        return model_3d

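A sketch of how from_2d_unet() would be exercised when loading an arbitrary v1.5 checkpoint, assuming the class defined in magicanimate/models/unet.py is UNet3DConditionModel and using diffusers' UNet2DConditionModel as the 2D source (the path constant comes from demo/paths.py above):

from diffusers import UNet2DConditionModel

from demo.paths import pretrained_model_path
from magicanimate.models.unet import UNet3DConditionModel

# Load the spatial (2D) UNet from any Stable Diffusion v1.5 layout.
unet2d = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")

# Copy config and spatial weights into the 3D UNet; the temporal layers keep
# their initial values (strict=False) until the motion module is loaded on top.
unet3d = UNet3DConditionModel.from_2d_unet(unet2d, unet_additional_kwargs={})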