Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update cls hooks #561

Merged
merged 2 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 06_gpu_and_ml/embeddings/wikipedia/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# We define our Modal Resources that we'll need
volume = Volume.persisted("embedding-wikipedia")
image = Image.debian_slim().pip_install("datasets")
image = Image.debian_slim().pip_install("datasets", "apache_beam")
stub = Stub(image=image)


Expand Down
19 changes: 10 additions & 9 deletions 06_gpu_and_ml/embeddings/wikipedia/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import subprocess

from modal import Image, Secret, Stub, Volume, gpu, method
from modal import Image, Secret, Stub, Volume, build, enter, exit, gpu, method

# We first set out configuration variables for our script.
## Embedding Containers Configuration
Expand Down Expand Up @@ -66,17 +66,12 @@ def spawn_server() -> subprocess.Popen:
)


def download_model():
spawn_server()


tei_image = (
Image.from_registry(
"ghcr.io/huggingface/text-embeddings-inference:86-0.4.0",
add_python="3.10",
)
.dockerfile_commands("ENTRYPOINT []")
.run_function(download_model, gpu=GPU_CONFIG)
.pip_install("httpx")
)

Expand Down Expand Up @@ -129,14 +124,20 @@ def generate_batches(xs, batch_size):
retries=3,
)
class TextEmbeddingsInference:
def __enter__(self):
@build()
def download_model(self):
spawn_server()

@enter()
def open_connection(self):
# If the process is running for a long time, the client does not seem to close the connections, results in a pool timeout
from httpx import AsyncClient

self.process = spawn_server()
self.client = AsyncClient(base_url="http://127.0.0.1:8000", timeout=30)

def __exit__(self, _exc_type, _exc_value, _traceback):
@exit()
def terminate_connection(self, _exc_type, _exc_value, _traceback):
self.process.terminate()

async def _embed(self, chunk_batch):
Expand Down Expand Up @@ -259,7 +260,7 @@ def upload_result_to_hf(batch_size: int) -> None:
CHECKPOINT_DIR: EMBEDDING_CHECKPOINT_VOLUME,
},
timeout=86400,
secret=Secret.from_name("huggingface-credentials"),
secret=Secret.from_name("huggingface-secret"),
)
def embed_dataset(down_scale: float = 1, batch_size: int = 512 * 50):
"""
Expand Down
Loading