diff --git a/06_gpu_and_ml/embeddings/wikipedia/download.py b/06_gpu_and_ml/embeddings/wikipedia/download.py index 1ab08a612..6d0ff4990 100644 --- a/06_gpu_and_ml/embeddings/wikipedia/download.py +++ b/06_gpu_and_ml/embeddings/wikipedia/download.py @@ -8,7 +8,7 @@ # We define our Modal Resources that we'll need volume = Volume.persisted("embedding-wikipedia") -image = Image.debian_slim().pip_install("datasets") +image = Image.debian_slim().pip_install("datasets", "apache_beam") stub = Stub(image=image) diff --git a/06_gpu_and_ml/embeddings/wikipedia/main.py b/06_gpu_and_ml/embeddings/wikipedia/main.py index 87a724f88..cbacf7405 100644 --- a/06_gpu_and_ml/embeddings/wikipedia/main.py +++ b/06_gpu_and_ml/embeddings/wikipedia/main.py @@ -2,7 +2,7 @@ import json import subprocess -from modal import Image, Secret, Stub, Volume, gpu, method +from modal import Image, Secret, Stub, Volume, build, enter, exit, gpu, method # We first set out configuration variables for our script. ## Embedding Containers Configuration @@ -66,17 +66,12 @@ def spawn_server() -> subprocess.Popen: ) -def download_model(): - spawn_server() - - tei_image = ( Image.from_registry( "ghcr.io/huggingface/text-embeddings-inference:86-0.4.0", add_python="3.10", ) .dockerfile_commands("ENTRYPOINT []") - .run_function(download_model, gpu=GPU_CONFIG) .pip_install("httpx") ) @@ -129,14 +124,20 @@ def generate_batches(xs, batch_size): retries=3, ) class TextEmbeddingsInference: - def __enter__(self): + @build() + def download_model(self): + spawn_server() + + @enter() + def open_connection(self): # If the process is running for a long time, the client does not seem to close the connections, results in a pool timeout from httpx import AsyncClient self.process = spawn_server() self.client = AsyncClient(base_url="http://127.0.0.1:8000", timeout=30) - def __exit__(self, _exc_type, _exc_value, _traceback): + @exit() + def terminate_connection(self, _exc_type, _exc_value, _traceback): self.process.terminate() async def _embed(self, chunk_batch): @@ -259,7 +260,7 @@ def upload_result_to_hf(batch_size: int) -> None: CHECKPOINT_DIR: EMBEDDING_CHECKPOINT_VOLUME, }, timeout=86400, - secret=Secret.from_name("huggingface-credentials"), + secret=Secret.from_name("huggingface-secret"), ) def embed_dataset(down_scale: float = 1, batch_size: int = 512 * 50): """