Skip to content

Commit

Permalink
update hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
rachelspark committed Jan 24, 2024
1 parent ce68de2 commit 4ba4ebb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
2 changes: 1 addition & 1 deletion 06_gpu_and_ml/embeddings/wikipedia/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# We define the Modal resources that we'll need
volume = Volume.persisted("embedding-wikipedia")
image = Image.debian_slim().pip_install("datasets")
image = Image.debian_slim().pip_install("datasets", "apache_beam")
stub = Stub(image=image)


Expand Down
19 changes: 10 additions & 9 deletions 06_gpu_and_ml/embeddings/wikipedia/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import subprocess

from modal import Image, Secret, Stub, Volume, gpu, method
from modal import Image, Secret, Stub, Volume, gpu, method, build, enter, exit

# We first set out the configuration variables for our script.
## Embedding Containers Configuration
Expand Down Expand Up @@ -66,17 +66,12 @@ def spawn_server() -> subprocess.Popen:
)


def download_model():
spawn_server()


tei_image = (
Image.from_registry(
"ghcr.io/huggingface/text-embeddings-inference:86-0.4.0",
add_python="3.10",
)
.dockerfile_commands("ENTRYPOINT []")
.run_function(download_model, gpu=GPU_CONFIG)
.pip_install("httpx")
)

Expand Down Expand Up @@ -129,14 +124,20 @@ def generate_batches(xs, batch_size):
retries=3,
)
class TextEmbeddingsInference:
def __enter__(self):
@build()
def download_model(self):
spawn_server()

@enter()
def open_connection(self):
# If the process runs for a long time, the client does not seem to close its connections, which results in a pool timeout
from httpx import AsyncClient

self.process = spawn_server()
self.client = AsyncClient(base_url="http://127.0.0.1:8000", timeout=30)

def __exit__(self, _exc_type, _exc_value, _traceback):
@exit()
def terminate_connection(self, _exc_type, _exc_value, _traceback):
self.process.terminate()

async def _embed(self, chunk_batch):
Expand Down Expand Up @@ -259,7 +260,7 @@ def upload_result_to_hf(batch_size: int) -> None:
CHECKPOINT_DIR: EMBEDDING_CHECKPOINT_VOLUME,
},
timeout=86400,
secret=Secret.from_name("huggingface-credentials"),
secret=Secret.from_name("huggingface-secret"),
)
def embed_dataset(down_scale: float = 1, batch_size: int = 512 * 50):
"""
Expand Down

0 comments on commit 4ba4ebb

Please sign in to comment.