From 99fda50285167aa96dffb448857f9ef262e233a3 Mon Sep 17 00:00:00 2001 From: Erik Dunteman <44653944+erik-dunteman@users.noreply.github.com> Date: Wed, 20 Nov 2024 00:40:37 -0600 Subject: [PATCH] Updates Mochi to use Diffusers (#986) * move mochi example to new diffusers native implementation * commit outputs * final edits * apply changes * fix ruff * minor text fixes * add missing # --------- Co-authored-by: Charles Frye --- 06_gpu_and_ml/text-to-video/mochi.py | 396 +++++++++------------------ 1 file changed, 133 insertions(+), 263 deletions(-) diff --git a/06_gpu_and_ml/text-to-video/mochi.py b/06_gpu_and_ml/text-to-video/mochi.py index 6135c2bf2..bc0365371 100644 --- a/06_gpu_and_ml/text-to-video/mochi.py +++ b/06_gpu_and_ml/text-to-video/mochi.py @@ -16,10 +16,9 @@ # # Note that the Mochi model, at time of writing, -# requires several minutes on four H100s to produce +# requires several minutes on one H100 to produce # a high-quality clip of even a few seconds. -# It also takes a five to ten minutes to boot up. -# So a single video generation therefore costs about $2 +# So a single video generation therefore costs about $0.33 # at our ~$5/hr rate for H100s. # Keep your eyes peeled for improved efficiency @@ -28,71 +27,62 @@ # ## Setting up the environment for Mochi -# We start by defining the environment the model runs in. -# We'll need the [full CUDA toolkit](https://modal.com/docs/guide/cuda), -# [Flash Attention](https://arxiv.org/abs/2205.14135) for fast attention kernels, -# and the Mochi model code. +# At the time of writing, Mochi is supported natively in the [`diffusers`](https://github.com/huggingface/diffusers) library, +# but only in a pre-release version. +# So we'll need to install `diffusers` and `transformers` from GitHub. -import json -import os -import tempfile +import string import time from pathlib import Path import modal -MINUTES = 60 -HOURS = 60 * MINUTES - - -cuda_version = "12.3.1" # should be no greater than host CUDA version -flavor = "devel" # includes full CUDA toolkit -os_version = "ubuntu22.04" -tag = f"{cuda_version}-{flavor}-{os_version}" +app = modal.App() image = ( - modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11") - .entrypoint([]) - .apt_install("git", "ffmpeg") - .pip_install("torch==2.4.0", "packaging", "ninja", "wheel", "setuptools") - .pip_install("flash-attn==2.6.3", extra_options="--no-build-isolation") + modal.Image.debian_slim(python_version="3.11") + .apt_install("git") .pip_install( - "git+https://github.com/genmoai/models.git@075b6e36db58f1242921deff83a1066887b9c9e1" + "torch==2.5.1", + "accelerate==1.1.1", + "hf_transfer==0.1.8", + "sentencepiece==0.2.0", + "imageio==2.36.0", + "imageio-ffmpeg==0.5.1", + "git+https://github.com/huggingface/transformers@30335093276212ce74938bdfd85bfd5df31a668a", + "git+https://github.com/huggingface/diffusers@99c0483b67427de467f11aa35d54678fd36a7ea2", + ) + .env( + { + "HF_HUB_ENABLE_HF_TRANSFER": "1", + "HF_HOME": "/models", + } ) ) -app = modal.App("example-mochi") +# ## Saving outputs -with image.imports(): - import numpy as np - import ray - import torch - from einops import rearrange - from mochi_preview.handler import MochiWrapper - from PIL import Image - from tqdm import tqdm +# On Modal, we save large or expensive-to-compute data to +# [distributed Volumes](https://modal.com/docs/guide/volumes) -# ## Saving model weights and outputs +# We'll use this for saving our Mochi weights, as well as our video outputs. -# Mochi weighs in at ~80 GB (~20B params, released in full 32bit precision) -# and can take several minutes to generate videos. +VOLUME_NAME = "mochi-outputs" +outputs = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True) +OUTPUTS_PATH = Path("/outputs") # remote path for saving video outputs -# On Modal, we save large or expensive-to-compute data to -# [distributed Volumes](https://modal.com/docs/guide/volumes) -# so that they are accessible from any Modal Function -# or downloadable via the Modal dashboard or CLI. +MODEL_VOLUME_NAME = "mochi-model" +model = modal.Volume.from_name(MODEL_VOLUME_NAME, create_if_missing=True) +MODEL_PATH = Path("/models") # remote path for saving model weights -model = modal.Volume.from_name("mochi-model", create_if_missing=True) -outputs = modal.Volume.from_name("mochi-outputs", create_if_missing=True) +MINUTES = 60 +HOURS = 60 * MINUTES -MODEL_CACHE = Path("/models") # remote path for saving the model -OUTPUTS_PATH = "/outputs" # remote path for saving video outputs +# ## Downloading the model -# We download the model using the `hf-transfer` -# library from Hugging Face and additionally download -# the text encoder (Google's T5 XXL) using `transformers`. +# We download the model weights into Volume cache to speed up cold starts. -# This can takes five to thirty minutes, depending on traffic +# This download takes five minutes or more, depending on traffic # and network speed. # If you want to launch the download first, @@ -107,232 +97,128 @@ # even if you close your terminal or shut down your computer # while it's running. -download_image = ( - modal.Image.debian_slim(python_version="3.11") - .pip_install( - "huggingface_hub", - "hf-transfer", - "torch", - "transformers", - "sentencepiece", - ) - .env( - { - "HF_HUB_ENABLE_HF_TRANSFER": "1", - "HF_HOME": str(MODEL_CACHE / "huggingface"), - } - ) -) -image = image.env( # so we look for the model in the right place - {"HF_HOME": str(MODEL_CACHE / "huggingface")} -) +with image.imports(): + import torch + from diffusers import MochiPipeline + from diffusers.utils import export_to_video @app.function( - volumes={MODEL_CACHE: model}, - timeout=2 * HOURS, - image=download_image, + image=image, + volumes={ + MODEL_PATH: model, + }, + timeout=20 * MINUTES, ) -def download_model( - model_revision: str = "8e9673c5349979457e515fddd38911df6b4ca07f", -): - from huggingface_hub import snapshot_download - from transformers import T5EncoderModel, T5Tokenizer - - model.reload() - print("🍡 downloading Mochi model") - - snapshot_download( - repo_id="genmo/mochi-1-preview", - local_dir=MODEL_CACHE / "mochi-1-preview", - revision=model_revision, +def download_model(revision="83359d26a7e2bbe200ecbfda8ebff850fd03b545"): + # uses HF_HOME to point download to the model volume + MochiPipeline.from_pretrained( + "genmo/mochi-1-preview", + torch_dtype=torch.bfloat16, + revision=revision, ) - print("🍡 model downloaded") - print("🍡 downloading text encoder") - T5Tokenizer.from_pretrained("google/t5-v1_1-xxl", legacy=False) - T5EncoderModel.from_pretrained("google/t5-v1_1-xxl") +# ## Setting up our Mochi class - model.commit() - print("🍡 text encoder downloaded") - -# ## Running Mochi inference - -# We can trigger Mochi inference from our local machine by running the code in -# the local entrypoint below. - -# It ensures the model is downloaded to a remote volume, -# spins up a new replica to generate a video, also saved remotely, -# and then downloads the video to the local machine. - -# You can trigger it with: - -# ```bash -# modal run --detach mochi -# ``` - - -@app.local_entrypoint() -def main( - prompt: str = "A cat playing drums in a jazz ensemble", - num_inference_steps: int = 200, -): - from pathlib import Path - - mochi = Mochi() - local_dir = Path("/tmp/mochi") - local_dir.mkdir(exist_ok=True, parents=True) - download_model.remote() - remote_path = Path( - mochi.generate_video.remote( - prompt=prompt, num_inference_steps=num_inference_steps - ) - ) - local_path = local_dir / remote_path.name - local_path.write_bytes(b"".join(outputs.read_file(remote_path.name))) - print("🍡 video saved locally at", local_path) - - -# To deploy Mochi, run -# ```bash -# modal deploy mochi -# ``` - -# And then use it from another Python process that has access to your Modal credentials: - -# ```python -# import modal +# We'll use the `@cls` decorator to define a [Modal Class](https://modal.com/docs/guide/lifecycle-functions) +# which we use to control the lifecycle of our cloud container. # -# Mochi = modal.Cls.lookup("example-mochi", "Mochi") -# remote_path = Mochi().generate_video.remote(prompt="A cat playing drums in a jazz ensemble") -# ``` - - -# The Mochi inference logic is defined in the Modal [`Cls`](https://modal.com/docs/guide/lifecycle-functions) below. - -# See [the Mochi GitHub repo](https://github.com/genmoai/models) -# for more details on running Mochi. - - +# We configure it to use our image, the distributed volume, and a single H100 GPU. @app.cls( - gpu=modal.gpu.H100(count=4), + image=image, volumes={ - MODEL_CACHE: model, - OUTPUTS_PATH: outputs, # videos are saved to (distributed) disk + OUTPUTS_PATH: outputs, # videos will be saved to a distributed volume + MODEL_PATH: model, }, - # boot takes a while, so we keep the container warm for 20 minutes after the last call finishes + gpu=modal.gpu.H100(count=1), timeout=1 * HOURS, - container_idle_timeout=20 * MINUTES, - image=image, ) class Mochi: @modal.enter() def load_model(self): - ray.init() - model_path = MODEL_CACHE / "mochi-1-preview" - vae_stats_path = f"{model_path}/vae_stats.json" - vae_checkpoint_path = f"{model_path}/vae.safetensors" - model_config_path = f"{model_path}/dit-config.yaml" - model_checkpoint_path = f"{model_path}/dit.safetensors" - num_gpus = torch.cuda.device_count() - if num_gpus < 4: - print( - f"🍡 WARNING: Mochi requires at least 4xH100 GPUs, but only {num_gpus} GPU(s) are available." - ) - print( - f"🍡 loading model to {num_gpus} GPUs. This can take a few minutes." - ) - self.model = MochiWrapper( - num_workers=num_gpus, - vae_stats_path=vae_stats_path, - vae_checkpoint_path=vae_checkpoint_path, - dit_config_path=model_config_path, - dit_checkpoint_path=model_checkpoint_path, + # our HF_HOME env var points to the model volume as the cache + self.pipe = MochiPipeline.from_pretrained( + "genmo/mochi-1-preview", + torch_dtype=torch.bfloat16, ) - print("🍡 model loaded") - - @modal.exit() - def graceful_exit(self): - ray.shutdown() + self.pipe.enable_model_cpu_offload() + self.pipe.enable_vae_tiling() @modal.method() - def generate_video( + def generate( self, - prompt="", + prompt, negative_prompt="", - width=848, - height=480, - num_frames=163, - seed=12345, - cfg_scale=4.5, num_inference_steps=200, + guidance_scale=4.5, + num_frames=19, ): - # credit: https://github.com/genmoai/models/blob/7c7d33c49d53bbf939fd6676610e949f3008b5a8/src/mochi_preview/infer.py#L63 - - # sigma_schedule should be a list of floats of length (num_inference_steps + 1), - # such that sigma_schedule[0] == 1.0 and sigma_schedule[-1] == 0.0 and monotonically decreasing. - sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025) - - # cfg_schedule should be a list of floats of length num_inference_steps. - # For simplicity, we just use the same cfg scale at all timesteps, - # but more optimal schedules may use varying cfg, e.g: - # [5.0] * (num_inference_steps // 2) + [4.5] * (num_inference_steps // 2) - cfg_schedule = [cfg_scale] * num_inference_steps - - args = { - "height": height, - "width": width, - "num_frames": num_frames, - "mochi_args": { - "sigma_schedule": sigma_schedule, - "cfg_schedule": cfg_schedule, - "num_inference_steps": num_inference_steps, - "batch_cfg": True, - }, - "prompt": [prompt], - "negative_prompt": [negative_prompt], - "seed": seed, - } + frames = self.pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_frames=num_frames, + ).frames[0] + + # save to disk using prompt as filename + mp4_name = slugify(prompt) + export_to_video(frames, Path(OUTPUTS_PATH) / mp4_name) + outputs.commit() + return mp4_name - final_frames = None - for cur_progress, frames, finished in tqdm( - self.model(args), total=num_inference_steps + 1 - ): - final_frames = frames - assert isinstance(final_frames, np.ndarray) - assert final_frames.dtype == np.float32 +# ## Running Mochi inference - final_frames = rearrange(final_frames, "t b h w c -> b t h w c") - final_frames = final_frames[0] +# We can trigger Mochi inference from our local machine by running the code in +# the local entrypoint below. - output_path = os.path.join( - OUTPUTS_PATH, f"output_{int(time.time())}.mp4" - ) +# It ensures the model is downloaded to a remote volume, +# spins up a new replica to generate a video, also saved remotely, +# and then downloads the video to the local machine. + +# You can trigger it with: +# ```bash +# modal run --detach mochi +# ``` + +# Optional command line flags can be viewed with: +# ```bash +# modal run mochi --help +# ``` - with tempfile.TemporaryDirectory() as tmpdir: - frame_paths = [] - for i, frame in enumerate(final_frames): - frame = (frame * 255).astype(np.uint8) - frame_img = Image.fromarray(frame) - frame_path = os.path.join(tmpdir, f"frame_{i:04d}.png") - frame_img.save(frame_path) - frame_paths.append(frame_path) +# Using these flags, you can tweak your generation from the command line: +# ```bash +# modal run --detach mochi --prompt="a cat playing drums in a jazz ensemble" --num-inference-steps=64 +# ``` - frame_pattern = os.path.join(tmpdir, "frame_%04d.png") - ffmpeg_cmd = f"ffmpeg -y -r 30 -i {frame_pattern} -vcodec libx264 -pix_fmt yuv420p {output_path}" - os.system(ffmpeg_cmd) - json_path = os.path.splitext(output_path)[0] + ".json" - with open(json_path, "w") as f: - json.dump(args, f, indent=4) +@app.local_entrypoint() +def main( + prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.", + negative_prompt="", + num_inference_steps=200, + guidance_scale=4.5, + num_frames=19, # produces ~1s of video +): + mochi = Mochi() + mp4_name = mochi.generate.remote( + prompt=str(prompt), + negative_prompt=str(negative_prompt), + num_inference_steps=int(num_inference_steps), + guidance_scale=float(guidance_scale), + num_frames=int(num_frames), + ) + print(f"🍡 video saved to volume at {mp4_name}") - outputs.commit() - print(f"Video saved remotely at: {output_path}") - return output_path + local_dir = Path("/tmp/mochi") + local_dir.mkdir(exist_ok=True, parents=True) + local_path = local_dir / mp4_name + local_path.write_bytes(b"".join(outputs.read_file(mp4_name))) + print(f"🍡 video saved locally at {local_path}") # ## Addenda @@ -340,26 +226,10 @@ def generate_video( # The remainder of the code in this file is utility code. -def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): - if linear_steps is None: - linear_steps = num_steps // 2 - linear_sigma_schedule = [ - i * threshold_noise / linear_steps for i in range(linear_steps) - ] - threshold_noise_step_diff = linear_steps - threshold_noise * num_steps - quadratic_steps = num_steps - linear_steps - quadratic_coef = threshold_noise_step_diff / ( - linear_steps * quadratic_steps**2 - ) - linear_coef = ( - threshold_noise / linear_steps - - 2 * threshold_noise_step_diff / (quadratic_steps**2) - ) - const = quadratic_coef * (linear_steps**2) - quadratic_sigma_schedule = [ - quadratic_coef * (i**2) + linear_coef * i + const - for i in range(linear_steps, num_steps) - ] - sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] - sigma_schedule = [1.0 - x for x in sigma_schedule] - return sigma_schedule +def slugify(prompt): + for char in string.punctuation: + prompt = prompt.replace(char, "") + prompt = prompt.replace(" ", "_") + prompt = prompt[:230] # since filenames can't be longer than 255 characters + mp4_name = str(int(time.time())) + "_" + prompt + ".mp4" + return mp4_name