diff --git a/06_gpu_and_ml/llm-serving/sgl_vlm.py b/06_gpu_and_ml/llm-serving/sgl_vlm.py
index 0ca5432d1..10cb3c4b1 100644
--- a/06_gpu_and_ml/llm-serving/sgl_vlm.py
+++ b/06_gpu_and_ml/llm-serving/sgl_vlm.py
@@ -1,4 +1,4 @@
-# # Run LLaVA-Next on SGLang for Visual QA
+# # Run Qwen2-VL on SGLang for Visual QA
 
 # Vision-Language Models (VLMs) are like LLMs with eyes:
 # they can generate text based not just on other text,
@@ -7,7 +7,7 @@
 # This example shows how to run a VLM on Modal using the
 # [SGLang](https://github.com/sgl-project/sglang) library.
 
-# Here's a sample inference, with the image rendered directly in the terminal:
+# Here's a sample inference, with the image rendered directly (and at low resolution) in the terminal:
 
 # ![Sample output answering a question about a photo of the Statue of Liberty](https://modal-public-assets.s3.amazonaws.com/sgl_vlm_qa_sol.png)
 
@@ -32,7 +32,7 @@
 # If you want to see the model really rip, try an `"a100-80gb"` or an `"h100"`
 # on a large batch.
 
-GPU_TYPE = os.environ.get("GPU_TYPE", "a10g")
+GPU_TYPE = os.environ.get("GPU_TYPE", "l40s")
 GPU_COUNT = os.environ.get("GPU_COUNT", 1)
 
 GPU_CONFIG = f"{GPU_TYPE}:{GPU_COUNT}"
@@ -41,13 +41,13 @@
 
 MINUTES = 60  # seconds
 
-# We use a [LLaVA-NeXT](https://huggingface.co/docs/transformers/en/model_doc/llava_next)
-# model built on top of Meta's LLaMA 3 8B.
+# We use the [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)
+# model by Alibaba.
 
-MODEL_PATH = "lmms-lab/llama3-llava-next-8b"
-MODEL_REVISION = "e7e6a9fd5fd75d44b32987cba51c123338edbede"
-TOKENIZER_PATH = "lmms-lab/llama3-llava-next-8b-tokenizer"
-MODEL_CHAT_TEMPLATE = "llama-3-instruct"
+MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct"
+MODEL_REVISION = "a7a06a1cc11b4514ce9edcde0e3ca1d16e5ff2fc"
+TOKENIZER_PATH = "Qwen/Qwen2-VL-7B-Instruct"
+MODEL_CHAT_TEMPLATE = "qwen2-vl"
 
 # We download it from the Hugging Face Hub using the Python function below.
 
@@ -73,12 +73,15 @@ def download_model_to_image():
 vlm_image = (
     modal.Image.debian_slim(python_version="3.11")
     .pip_install(  # add sglang and some Python dependencies
-        "sglang[all]==0.1.17",
-        "transformers==4.40.2",
+        "transformers==4.47.1",
         "numpy<2",
         "fastapi[standard]==0.115.4",
         "pydantic==2.9.2",
         "starlette==0.41.2",
+        "torch==2.4.0",
+        "sglang[all]==0.4.1",
+        # as per the SGLang install docs: https://sgl-project.github.io/start/install.html
+        extra_options="--find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/",
     )
     .run_function(  # download the model by running a Python function
         download_model_to_image
@@ -94,11 +97,11 @@ def download_model_to_image():
 
 # The code below adds a modal `Cls` to an `App` that runs the VLM.
 
-# We define a method `generate` that takes a URL for an image URL and a question
+# We define a method `generate` that takes a URL for an image and a question
 # about the image as inputs and returns the VLM's answer.
 
 # By decorating it with `@modal.web_endpoint`, we expose it as an HTTP endpoint,
-# so it can be accessed over the public internet from any client.
+# so it can be accessed over the public Internet from any client.
 
 app = modal.App("example-sgl-vlm")
 
@@ -129,6 +132,8 @@ def start_runtime(self):
 
     @modal.web_endpoint(method="POST", docs=True)
     def generate(self, request: dict):
+        from pathlib import Path
+
         import sglang as sgl
         from term_image.image import from_file
 
@@ -140,18 +145,17 @@ def generate(self, request: dict):
 
         if image_url is None:
             image_url = "https://modal-public-assets.s3.amazonaws.com/golden-gate-bridge.jpg"
 
-        image_filename = image_url.split("/")[-1]
-        image_path = f"/tmp/{uuid4()}-{image_filename}"
         response = requests.get(image_url)
         response.raise_for_status()
-        with open(image_path, "wb") as file:
-            file.write(response.content)
+        image_filename = image_url.split("/")[-1]
+        image_path = Path(f"/tmp/{uuid4()}-{image_filename}")
+        image_path.write_bytes(response.content)
 
         @sgl.function
         def image_qa(s, image_path, question):
-            s += sgl.user(sgl.image(image_path) + question)
+            s += sgl.user(sgl.image(str(image_path)) + question)
             s += sgl.assistant(sgl.gen("answer"))
 
         question = request.get("question")
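For a quick end-to-end smoke test of the updated endpoint, here is a minimal client sketch. It only assumes what the hunks above show: the handler is exposed via `@modal.web_endpoint(method="POST", docs=True)` and reads `image_url` and `question` from the JSON body. The deployment URL is a placeholder (Modal prints the real one at deploy time), and the exact shape of the response body is not shown in this patch.

```python
# Minimal smoke test for the updated endpoint (run after `modal deploy sgl_vlm.py`).
# ENDPOINT_URL is a placeholder -- substitute the URL Modal prints at deploy time.
import requests

ENDPOINT_URL = "https://your-workspace--example-sgl-vlm-model-generate.modal.run"

resp = requests.post(
    ENDPOINT_URL,
    json={
        # both keys are read via request.get(...) in the handler above;
        # image_url falls back to the Golden Gate Bridge sample if omitted
        "image_url": "https://modal-public-assets.s3.amazonaws.com/golden-gate-bridge.jpg",
        "question": "What landmark is shown in this image?",
    },
)
resp.raise_for_status()
print(resp.text)
```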