Convert examples to all use lifecycle method decorators (#586)
* Convert examples to all use lifecycle method decorators

* Reinstate exc_info arguments on exit methods
mwaskom authored Feb 21, 2024
1 parent 8a2d5f6 commit 39b3b21
Showing 23 changed files with 80 additions and 57 deletions.
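
For orientation, here is a minimal sketch of the pattern this commit applies across the examples (the class and method names below are illustrative, not taken from the repo): setup that previously lived in __enter__ moves to a method decorated with @enter(), and teardown moves from __exit__ to a method decorated with @exit() that keeps the three exc_info arguments.

from modal import Stub, enter, exit, method

stub = Stub("lifecycle-sketch")


@stub.cls()
class Example:
    # Previously: def __enter__(self)
    @enter()
    def setup(self):
        # Runs once per container start, before any @method() call is handled.
        self.greeting = "hello"

    @method()
    def run(self, name: str) -> str:
        # Runs once per request, reusing the state prepared in setup().
        return f"{self.greeting}, {name}"

    # Previously: def __exit__(self, exc_type, exc_value, traceback)
    @exit()
    def teardown(self, exc_type, exc_value, traceback):
        # The exc_* arguments are None on a clean shutdown.
        self.greeting = None
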
6 changes: 3 additions & 3 deletions 06_gpu_and_ml/alpaca/alpaca_lora.py
@@ -56,8 +56,8 @@

stub = Stub(name="example-alpaca-lora", image=image)

# The Alpaca-LoRA model is integrated into model as a Python class with an __enter__
# method to take advantage of Modal's container lifecycle functionality.
# The Alpaca-LoRA model is integrated as a Python class with a method decorated
# using `@enter()` to take advantage of Modal's container lifecycle functionality.
#
# https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta
#
@@ -76,7 +76,7 @@ def download_models(self):
LlamaTokenizer.from_pretrained(base_model)

@enter()
def enter(self):
def setup_model(self):
"""
Container-lifecycle method for model setup. Code is taken from
https://github.com/tloen/alpaca-lora/blob/main/generate.py and minor
@@ -39,19 +39,21 @@
# ## Defining the prediction function
#
# Instead of using `@stub.function()` in the global scope,
# we put the method on a class, and define an `__enter__` method on that class.
# we put the method on a class, and define a setup method that we
# decorate with `@modal.enter()`.
#
# Modal reuses containers for successive calls to the same function, so
# we want to take advantage of this and avoid setting up the same model
# for every function call.
#
# Since the transformer model is very CPU-hungry, we allocate 8 CPUs
# to the model.
# Every container that runs will have 8 CPUs set aside for it.
# to the model. Every container that runs will have 8 CPUs set aside for it.


@stub.cls(cpu=8, retries=3)
class SentimentAnalysis:
def __enter__(self):
@modal.enter()
def setup_pipeline(self):
from transformers import pipeline

self.sentiment_pipeline = pipeline(
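
A hypothetical usage note for the class above (the calling code is not part of this hunk, and the method name predict is assumed): with the class form, each call goes through a @method(), and the @modal.enter() setup runs once per container rather than once per call.

@stub.local_entrypoint()
def main():
    # The pipeline is built once per container by the @modal.enter() method;
    # both calls below can reuse it if they land on the same container.
    analyzer = SentimentAnalysis()
    print(analyzer.predict.remote("I love lifecycle decorators"))
    print(analyzer.predict.remote("I dislike redundant model loads"))
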
4 changes: 3 additions & 1 deletion 06_gpu_and_ml/diffusers/train_and_serve_diffusers_script.py
@@ -139,6 +139,7 @@
Stub,
Volume,
asgi_app,
enter,
gpu,
method,
)
@@ -364,7 +365,8 @@ def run():
volumes=VOLUME_CONFIG, # mount the location where your model weights were saved to
)
class Model:
def __enter__(self):
@enter()
def load_model(self):
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

4 changes: 3 additions & 1 deletion 06_gpu_and_ml/dreambooth/dreambooth_app.py
@@ -36,6 +36,7 @@
Stub,
Volume,
asgi_app,
enter,
method,
)

@@ -271,7 +272,8 @@ def _exec_subprocess(cmd: list[str]):
volumes={MODEL_DIR: volume},
)
class Model:
def __enter__(self):
@enter()
def load_model(self):
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

11 changes: 6 additions & 5 deletions 06_gpu_and_ml/embeddings/text_embeddings_inference.py
@@ -7,7 +7,7 @@
import subprocess
from pathlib import Path

from modal import Image, Secret, Stub, Volume, gpu, method
from modal import Image, Secret, Stub, Volume, enter, exit, gpu, method

GPU_CONFIG = gpu.A10G()
MODEL_ID = "BAAI/bge-base-en-v1.5"
@@ -93,18 +93,19 @@ def download_model():
allow_concurrent_inputs=10,
)
class TextEmbeddingsInference:
def __enter__(self):
@enter()
def setup_server(self):
self.process = spawn_server()
self.client = AsyncClient(base_url="http://127.0.0.1:8000")

def __exit__(self, _exc_type, _exc_value, _traceback):
@exit()
def teardown_server(self, exc_type, exc_value, traceback):
self.process.terminate()

@method()
async def embed(self, inputs_with_ids: list[tuple[int, str]]):
ids, inputs = zip(*inputs_with_ids)
resp = self.client.post("/embed", json={"inputs": inputs})
resp = await resp
resp = await self.client.post("/embed", json={"inputs": inputs})
resp.raise_for_status()
outputs = resp.json()

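
A hedged sketch of the concurrency angle above (names assumed): because the @method() is an async coroutine and the class sets allow_concurrent_inputs, one container can have several requests in flight at once, all sharing the state prepared in @enter().

from modal import Stub, enter, method

stub = Stub("async-method-sketch")


@stub.cls(allow_concurrent_inputs=10)
class SlowEchoer:
    @enter()
    def setup(self):
        self.delay = 0.5  # stands in for the HTTP round-trip to the local server

    @method()
    async def echo(self, text: str) -> str:
        import asyncio

        # While this coroutine awaits, the same container can pick up other
        # inputs, up to allow_concurrent_inputs of them.
        await asyncio.sleep(self.delay)
        return text
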
2 changes: 1 addition & 1 deletion 06_gpu_and_ml/embeddings/wikipedia/main.py
@@ -137,7 +137,7 @@ def open_connection(self):
self.client = AsyncClient(base_url="http://127.0.0.1:8000", timeout=30)

@exit()
def terminate_connection(self, _exc_type, _exc_value, _traceback):
def terminate_connection(self, exc_type, exc_value, traceback):
self.process.terminate()

async def _embed(self, chunk_batch):
7 changes: 4 additions & 3 deletions 06_gpu_and_ml/falcon_bitsandbytes.py
@@ -17,7 +17,7 @@
#
# First we import the components we need from `modal`.

from modal import Image, Stub, gpu, method, web_endpoint
from modal import Image, Stub, enter, gpu, method, web_endpoint


# Spec for an image where falcon-40b-instruct is cached locally
@@ -62,7 +62,7 @@ def download_falcon_40b():
# ## The model class
#
# Next, we write the model code. We want Modal to load the model into memory just once every time a container starts up,
# so we use [class syntax](/docs/guide/lifecycle-functions) and the __enter__` method.
# so we use [class syntax](/docs/guide/lifecycle-functions) and the `@enter` decorator.
#
# Within the [@stub.cls](/docs/reference/modal.Stub#cls) decorator, we use the [gpu parameter](/docs/guide/gpu)
# to specify that we want to run our function on an [A100 GPU](/pricing). We also allow each call 10 minutes to complete,
@@ -78,7 +78,8 @@ def download_falcon_40b():
container_idle_timeout=60 * 5, # Keep runner alive for 5 minutes
)
class Falcon40B_4bit:
def __enter__(self):
@enter()
def load_model(self):
import torch
from transformers import (
AutoModelForCausalLM,
7 changes: 4 additions & 3 deletions 06_gpu_and_ml/falcon_gptq.py
@@ -15,7 +15,7 @@
#
# First we import the components we need from `modal`.

from modal import Image, Stub, gpu, method, web_endpoint
from modal import Image, Stub, enter, gpu, method, web_endpoint

# ## Define a container image
#
@@ -59,7 +59,7 @@ def download_model():
# ## The model class
#
# Next, we write the model code. We want Modal to load the model into memory just once every time a container starts up,
# so we use [class syntax](/docs/guide/lifecycle-functions) and the `__enter__` method.
# so we use [class syntax](/docs/guide/lifecycle-functions) and the `@enter` decorator.
#
# Within the [@stub.cls](/docs/reference/modal.Stub#cls) decorator, we use the [gpu parameter](/docs/guide/gpu)
# to specify that we want to run our function on an [A100 GPU](/pricing). We also allow each call 10 minutes to complete,
@@ -73,7 +73,8 @@ def download_model():
# yield the text back from the streamer in the main thread. This is an idiosyncrasy with streaming in `transformers`.
@stub.cls(gpu=gpu.A100(), timeout=60 * 10, container_idle_timeout=60 * 5)
class Falcon40BGPTQ:
def __enter__(self):
@enter()
def load_model(self):
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer

5 changes: 3 additions & 2 deletions 06_gpu_and_ml/flan_t5/flan_t5_finetune.py
@@ -22,7 +22,7 @@
from pathlib import Path

import modal
from modal import Image, Stub, Volume, method, wsgi_app
from modal import Image, Stub, Volume, enter, method, wsgi_app

VOL_MOUNT_PATH = Path("/vol")

@@ -221,7 +221,8 @@ def monitor():

@stub.cls(volumes={VOL_MOUNT_PATH: output_vol})
class Summarizer:
def __enter__(self):
@enter()
def load_model(self):
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

# Load saved tokenizer and finetuned model from training run
9 changes: 5 additions & 4 deletions 06_gpu_and_ml/obj_detection_webcam/webcam.py
@@ -34,7 +34,7 @@

from fastapi import FastAPI, Request, Response
from fastapi.staticfiles import StaticFiles
from modal import Image, Mount, Stub, asgi_app, build, method
from modal import Image, Mount, Stub, asgi_app, build, enter, method

# We need to install [transformers](https://github.com/huggingface/transformers)
# which is a package Huggingface uses for all their models, but also
@@ -65,8 +65,8 @@
#
# The object detection function has a few different features worth mentioning:
#
# * There's a container initialization step in the `__enter__` method, which
# runs on every container start. This lets us load the model only once per
# * There's a container initialization step in the method decorated with `@enter()`,
# which runs on every container start. This lets us load the model only once per
# container, so that it's reused for subsequent function calls.
# * Above we stored the model in the container image. This lets us download the model only
# when the image is (re)built, and not every time the function is called.
@@ -95,7 +95,8 @@ class ObjectDetection:
def download_model(self):
snapshot_download(repo_id=model_repo_id, cache_dir="/cache")

def __enter__(self):
@enter()
def load_model(self):
self.feature_extractor = DetrImageProcessor.from_pretrained(
model_repo_id,
cache_dir="/cache",
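
A compact sketch (the names and the repo id are assumptions) of the @build() / @enter() split used above: the @build() step runs while the image is built, so the download is baked into the image, and the @enter() step runs at container start and only reads the local cache.

from modal import Image, Stub, build, enter, method

stub = Stub(
    "build-enter-sketch",
    image=Image.debian_slim().pip_install("huggingface_hub"),
)


@stub.cls()
class CachedDetector:
    @build()
    def download(self):
        # Runs at image build time; the snapshot ends up inside the image.
        from huggingface_hub import snapshot_download

        snapshot_download(repo_id="facebook/detr-resnet-50", cache_dir="/cache")

    @enter()
    def load(self):
        # Runs at container start; the real example constructs the DETR
        # processor and model from the cached files here.
        self.cache_dir = "/cache"

    @method()
    def ready(self) -> bool:
        return self.cache_dir is not None
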
7 changes: 4 additions & 3 deletions 06_gpu_and_ml/openllama.py
@@ -8,7 +8,7 @@
#
# First we import the components we need from `modal`.

from modal import Image, Stub, gpu, method
from modal import Image, Stub, enter, gpu, method

# ## Define a container image
#
@@ -56,7 +56,7 @@ def download_models():
# ## The model class
#
# Next, we write the model code. We want Modal to load the model into memory just once every time a container starts up,
# so we use [class syntax](/docs/guide/lifecycle-functions) and the `__enter__` method.
# so we use [class syntax](/docs/guide/lifecycle-functions) and the `@enter` decorator.
#
# Within the [@stub.cls](/docs/reference/modal.Stub#cls) decorator, we use the [gpu parameter](/docs/guide/gpu)
# to specify that we want to run our function on an [A100 GPU with 20 GB of VRAM](/pricing).
@@ -67,7 +67,8 @@ def download_models():

@stub.cls(gpu=gpu.A100())
class OpenLlamaModel:
def __enter__(self):
@enter()
def load_model(self):
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

3 changes: 1 addition & 2 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py
@@ -84,8 +84,7 @@
# since Modal reuses the same containers when possible.
#
# The way to implement this is to turn the Modal function into a method on a
# class that also implement the Python context manager interface, meaning it
# has the `__enter__` method (the `__exit__` method is optional).
# class that also has lifecycle methods (decorated with `@enter()` and/or `@exit()`).
#
# We have also applied a few model optimizations to make the model run
# faster. On an A10G, the model takes about 6.5s to load into memory, and then
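
A sketch of the function-to-class move described above (names hypothetical): a plain @stub.function() repeats its setup inside every call, while the class form runs setup once per container via @enter() and keeps the result on self.

from modal import Stub, enter, method

stub = Stub("function-vs-class-sketch")


# Function form: the (pretend) expensive setup happens on every call.
@stub.function()
def classify_fn(text: str) -> int:
    table = {"good": 1, "bad": 0}  # imagine a multi-second model load here
    return table.get(text, -1)


# Class form: setup runs once per container, each call just reads self.table.
@stub.cls()
class Classifier:
    @enter()
    def load(self):
        self.table = {"good": 1, "bad": 0}

    @method()
    def classify(self, text: str) -> int:
        return self.table.get(text, -1)
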
7 changes: 4 additions & 3 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_onnx.py
@@ -10,7 +10,7 @@
import time
from pathlib import Path

from modal import Image, Stub, method
from modal import Image, Stub, enter, method

# Create a Stub representing a Modal app.

@@ -41,13 +41,14 @@ def download_models():

# ## Load model and run inference
#
# The container lifecycle [`__enter__` function](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
# The container lifecycle [`@enter` decorator](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
# loads the model at startup. Then, we evaluate it in the `run_inference` function.


@stub.cls(gpu="A10G")
class StableDiffusion:
def __enter__(self):
@enter()
def load_model(self):
from optimum.onnxruntime import ORTStableDiffusionPipeline

self.pipe = ORTStableDiffusionPipeline.from_pretrained(
2 changes: 1 addition & 1 deletion 06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py
@@ -52,7 +52,7 @@

# ## Load model and run inference
#
# The container lifecycle [`__enter__` function](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
# The container lifecycle [`@enter` decorator](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
# loads the model at startup. Then, we evaluate it in the `run_inference` function.
#
# To avoid excessive cold-starts, we set the idle timeout to 240 seconds, meaning once a GPU has loaded the model it will stay
@@ -45,7 +45,7 @@

# ## Load model and run inference
#
# The container lifecycle [`__enter__` function](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
# The container lifecycle [`@enter` decorator](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
# loads the model at startup. Then, we evaluate it in the `inference` function.
#
# To avoid excessive cold-starts, we set the idle timeout to 240 seconds, meaning once a GPU has loaded the model it will stay
3 changes: 2 additions & 1 deletion 06_gpu_and_ml/stable_lm/stable_lm.py
@@ -128,7 +128,8 @@ def __init__(
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"

def __enter__(self):
@modal.enter()
def setup_model(self):
"""
Container-lifecycle method for model setup.
"""
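
A small sketch (assumed names) of the __init__ / @modal.enter() split that the StableLM class above uses: __init__ only records lightweight configuration such as environment flags, while the @modal.enter() method does the expensive work at container start.

import modal

stub = modal.Stub("init-plus-enter-sketch")


@stub.cls()
class Greeter:
    def __init__(self, greeting: str = "hello"):
        # Cheap: just stores configuration passed at construction time.
        self.greeting = greeting

    @modal.enter()
    def warm_up(self):
        # Stands in for loading weights, setting offline env vars, etc.
        self.suffix = "!"

    @modal.method()
    def greet(self, name: str) -> str:
        return f"{self.greeting}, {name}{self.suffix}"
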
5 changes: 3 additions & 2 deletions 06_gpu_and_ml/text-to-pokemon/text_to_pokemon/main.py
@@ -9,7 +9,7 @@
import urllib.request
from datetime import timedelta

from modal import Mount, asgi_app, method
from modal import Mount, asgi_app, enter, method

from . import config, inpaint, ops, pokemon_naming
from .config import stub, volume
@@ -65,7 +65,8 @@ def image_to_byte_array(image) -> bytes:
gpu="A10G", network_file_systems={config.CACHE_DIR: volume}, keep_warm=1
)
class Model:
def __enter__(self):
@enter()
def load_model(self):
import threading

if not pokemon_naming.rnn_names_output_path.exists():
10 changes: 6 additions & 4 deletions 06_gpu_and_ml/text_generation_inference.py
@@ -16,7 +16,7 @@
import subprocess
from pathlib import Path

from modal import Image, Mount, Secret, Stub, asgi_app, gpu, method
from modal import Image, Mount, Secret, Stub, asgi_app, enter, exit, gpu, method

# Next, we set which model to serve, taking care to specify the GPU configuration required
# to fit the model into VRAM, and the quantization method (`bitsandbytes` or `gptq`) if desired.
@@ -99,7 +99,7 @@ def download_model():
#
# The inference function is best represented with Modal's [class syntax](/docs/guide/lifecycle-functions).
# The class syntax is a special representation for a Modal function which splits logic into two parts:
# 1. the `__enter__` method, which runs once per container when it starts up, and
# 1. the `@enter()` function, which runs once per container when it starts up, and
# 2. the `@method()` function, which runs per inference request.
#
# This means the model is loaded into the GPUs, and the backend for TGI is launched just once when each
@@ -124,7 +124,8 @@ def download_model():
image=tgi_image,
)
class Model:
def __enter__(self):
@enter()
def start_server(self):
import socket
import time

@@ -164,7 +165,8 @@ def webserver_ready():

print("Webserver ready!")

def __exit__(self, _exc_type, _exc_value, _traceback):
@exit()
def terminate_server(self, exc_type, exc_value, traceback):
self.launcher.terminate()

@method()
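
One more hedged sketch tying the pieces above together (all names, and the .remote_gen() calling convention, are assumptions rather than part of this diff): @enter() pays the startup cost once, a generator @method() streams results per request, and @exit() receives the exc_info triple for teardown.

from modal import Stub, enter, exit, method

stub = Stub("streaming-sketch")


@stub.cls(container_idle_timeout=60 * 5)
class FakeLLM:
    @enter()
    def start(self):
        # Stands in for launching the TGI server and polling until it is ready.
        self.tokens = ["streamed", "token", "by", "token"]

    @method()
    def generate(self, prompt: str):
        # Yielding makes this a generator method; a caller would consume it
        # with `.remote_gen()` (assumed), receiving items as they are produced.
        yield prompt + " ->"
        for token in self.tokens:
            yield token

    @exit()
    def stop(self, exc_type, exc_value, traceback):
        # Stands in for terminating the launched server process.
        self.tokens = []
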