Migrate batched_whisper example to use a Volume for the HuggingFace cache
mwaskom committed Jan 7, 2025
1 parent 15d9a2a commit b803974
Showing 1 changed file with 27 additions and 18 deletions.
06_gpu_and_ml/openai_whisper/batched_whisper.py (45 changes: 27 additions & 18 deletions)
@@ -13,7 +13,6 @@
 
 # Let's start by importing the Modal client and defining the model that we want to serve.
 
-import os
 
 import modal
 
@@ -39,10 +38,35 @@
         "datasets==3.2.0",
     )
     # Use the barebones `hf-transfer` package for maximum download speeds. No progress bar, but expect 700MB/s.
-    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
+    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": MODEL_DIR})
 )
 
-app = modal.App("example-whisper-batched-inference", image=image)
+model_cache = modal.Volume.from_name(
+    "huggingface-hub-cache", create_if_missing=True
+)
+app = modal.App(
+    "example-whisper-batched-inference",
+    image=image,
+    volumes={MODEL_DIR: model_cache},
+)
+
+# ## Caching the model weights
+
+# We'll define a function to download the model and cache it in a volume.
+# You can `modal run` against this function prior to deploying the App.
+
+
+@app.function()
+def download_model():
+    from huggingface_hub import snapshot_download
+    from transformers.utils import move_cache
+
+    snapshot_download(
+        MODEL_NAME,
+        ignore_patterns=["*.pt", "*.bin"],  # Using safetensors
+        revision=MODEL_REVISION,
+    )
+    move_cache()
 
 
 # ## The model class
@@ -72,21 +96,6 @@
     concurrency_limit=10,  # default max GPUs for Modal's free tier
 )
 class Model:
-    @modal.build()
-    def download_model(self):
-        from huggingface_hub import snapshot_download
-        from transformers.utils import move_cache
-
-        os.makedirs(MODEL_DIR, exist_ok=True)
-
-        snapshot_download(
-            MODEL_NAME,
-            local_dir=MODEL_DIR,
-            ignore_patterns=["*.pt", "*.bin"],  # Using safetensors
-            revision=MODEL_REVISION,
-        )
-        move_cache()
-
     @modal.enter()
     def load_model(self):
         import torch
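With this migration, the weights live in the shared `huggingface-hub-cache` Volume instead of being baked into the image at `@modal.build()` time: because `HF_HUB_CACHE` now points at `MODEL_DIR` (the Volume's mount point), `snapshot_download` writes directly into the Volume, which is why the `local_dir` argument, the `os.makedirs` call, and the `import os` are all dropped. A minimal usage sketch, following the `modal run` hint in the new comments (the exact invocations are an assumption, not shown in this commit):

    # Pre-populate the Volume with the model weights (assumed invocation)
    modal run 06_gpu_and_ml/openai_whisper/batched_whisper.py::download_model

    # Then deploy the App; containers read the weights from the warm cache
    modal deploy 06_gpu_and_ml/openai_whisper/batched_whisper.py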