Skip to content

Commit

Permalink
add: optional location of Whisper model cache directory
Browse files Browse the repository at this point in the history
This will enable Docker-located CloudRun to re-use cached model and not
download it on every restart

Change-Id: I720efe7818219c06c4c1ed47611388b75d59a300
  • Loading branch information
qbit-42 committed Nov 20, 2024
1 parent 4142438 commit 8412a65
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ariel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.

"""Ariel library for for end-to-end video ad dubbing using AI."""
__version__ = "0.0.23"
__version__ = "0.0.24"
7 changes: 6 additions & 1 deletion ariel/dubbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import shutil
import sys
import time
from typing import Final, Mapping, Set, Sequence
from typing import Final, Mapping, Sequence, Set
from absl import logging
from ariel import audio_processing
from ariel import colab_utils
Expand Down Expand Up @@ -609,6 +609,7 @@ def __init__(
elevenlabs_remove_cloned_voices: bool = False,
number_of_steps: int = _NUMBER_OF_STEPS,
with_verification: bool = True,
whisper_cache_dir: str | None = None,
) -> None:
"""Initializes the Dubber class with various parameters for dubbing configuration.
Expand Down Expand Up @@ -692,6 +693,8 @@ def __init__(
number_of_steps: The total number of steps in the dubbing process.
with_verification: Whether a user wishes to verify, and optionally edit,
the utterance metadata in the dubbing process.
whisper_cache_dir: If given, Whisper model downloaded from HuggingFace
will be stored under this path in the runtime
"""
self._input_file = input_file
self.output_directory = output_directory
Expand Down Expand Up @@ -742,6 +745,7 @@ def __init__(
self._dubbing_from_utterance_metadata = False
self._voice_allocation_needed = False
self._voice_properties_added = False
self._whisper_cache_dir = whisper_cache_dir
create_output_directories(output_directory)

@functools.cached_property
Expand Down Expand Up @@ -810,6 +814,7 @@ def speech_to_text_model(self) -> WhisperModel:
model_size_or_path=_DEFAULT_TRANSCRIPTION_MODEL,
device=self.device,
compute_type="float16" if self.device == "cuda" else "int8",
download_root=self._whisper_cache_dir,
)

def configure_gemini_model(
Expand Down

0 comments on commit 8412a65

Please sign in to comment.