From 8ddff12d3b24f12fc99fcb9c84e597486fc9bc4d Mon Sep 17 00:00:00 2001 From: Charles Frye Date: Thu, 14 Nov 2024 21:06:04 -0800 Subject: [PATCH] stop mounting volume on .cache, just mount huggingface cache (#975) --- 06_gpu_and_ml/text-to-video/mochi.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/06_gpu_and_ml/text-to-video/mochi.py b/06_gpu_and_ml/text-to-video/mochi.py index d64e12950..6135c2bf2 100644 --- a/06_gpu_and_ml/text-to-video/mochi.py +++ b/06_gpu_and_ml/text-to-video/mochi.py @@ -61,7 +61,7 @@ ) ) -app = modal.App("example-mochi", image=image) +app = modal.App("example-mochi") with image.imports(): import numpy as np @@ -85,7 +85,7 @@ model = modal.Volume.from_name("mochi-model", create_if_missing=True) outputs = modal.Volume.from_name("mochi-outputs", create_if_missing=True) -MODEL_CACHE = Path("/root/.cache") # remote path for saving the model +MODEL_CACHE = Path("/models") # remote path for saving the model OUTPUTS_PATH = "/outputs" # remote path for saving video outputs # We download the model using the `hf-transfer` @@ -116,12 +116,23 @@ "transformers", "sentencepiece", ) - .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) + .env( + { + "HF_HUB_ENABLE_HF_TRANSFER": "1", + "HF_HOME": str(MODEL_CACHE / "huggingface"), + } + ) +) + +image = image.env( # so we look for the model in the right place + {"HF_HOME": str(MODEL_CACHE / "huggingface")} ) @app.function( - volumes={MODEL_CACHE: model}, timeout=2 * HOURS, image=download_image + volumes={MODEL_CACHE: model}, + timeout=2 * HOURS, + image=download_image, ) def download_model( model_revision: str = "8e9673c5349979457e515fddd38911df6b4ca07f", @@ -215,11 +226,11 @@ def main( # boot takes a while, so we keep the container warm for 20 minutes after the last call finishes timeout=1 * HOURS, container_idle_timeout=20 * MINUTES, + image=image, ) class Mochi: @modal.enter() def load_model(self): - model.reload() ray.init() model_path = MODEL_CACHE / "mochi-1-preview" vae_stats_path = f"{model_path}/vae_stats.json" @@ -232,7 +243,7 @@ def load_model(self): f"🍡 WARNING: Mochi requires at least 4xH100 GPUs, but only {num_gpus} GPU(s) are available." ) print( - f"🍡 loading model to {num_gpus} GPUs. This can take 5-15 minutes." + f"🍡 loading model to {num_gpus} GPUs. This can take a few minutes." ) self.model = MochiWrapper( num_workers=num_gpus,