Skip to content

Commit

Permalink
stop mounting volume on .cache, just mount huggingface cache (#975)
Browse files (browse the repository at this point in the history)
  • Loading branch information
charlesfrye authored Nov 15, 2024
1 parent 8110625 commit 8ddff12
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions 06_gpu_and_ml/text-to-video/mochi.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -61,7 +61,7 @@
)
)

app = modal.App("example-mochi", image=image)
app = modal.App("example-mochi")

with image.imports():
import numpy as np
Expand All @@ -85,7 +85,7 @@
model = modal.Volume.from_name("mochi-model", create_if_missing=True)
outputs = modal.Volume.from_name("mochi-outputs", create_if_missing=True)

MODEL_CACHE = Path("/root/.cache") # remote path for saving the model
MODEL_CACHE = Path("/models") # remote path for saving the model
OUTPUTS_PATH = "/outputs" # remote path for saving video outputs

# We download the model using the `hf-transfer`
Expand Down Expand Up @@ -116,12 +116,23 @@
"transformers",
"sentencepiece",
)
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
.env(
{
"HF_HUB_ENABLE_HF_TRANSFER": "1",
"HF_HOME": str(MODEL_CACHE / "huggingface"),
}
)
)

image = image.env( # so we look for the model in the right place
{"HF_HOME": str(MODEL_CACHE / "huggingface")}
)


@app.function(
volumes={MODEL_CACHE: model}, timeout=2 * HOURS, image=download_image
volumes={MODEL_CACHE: model},
timeout=2 * HOURS,
image=download_image,
)
def download_model(
model_revision: str = "8e9673c5349979457e515fddd38911df6b4ca07f",
Expand Down Expand Up @@ -215,11 +226,11 @@ def main(
# boot takes a while, so we keep the container warm for 20 minutes after the last call finishes
timeout=1 * HOURS,
container_idle_timeout=20 * MINUTES,
image=image,
)
class Mochi:
@modal.enter()
def load_model(self):
model.reload()
ray.init()
model_path = MODEL_CACHE / "mochi-1-preview"
vae_stats_path = f"{model_path}/vae_stats.json"
Expand All @@ -232,7 +243,7 @@ def load_model(self):
f"🍡 WARNING: Mochi requires at least 4xH100 GPUs, but only {num_gpus} GPU(s) are available."
)
print(
f"🍡 loading model to {num_gpus} GPUs. This can take 5-15 minutes."
f"🍡 loading model to {num_gpus} GPUs. This can take a few minutes."
)
self.model = MochiWrapper(
num_workers=num_gpus,
Expand Down

0 comments on commit 8ddff12

Please sign in to comment.