From 40f223c336e661abc074f2d9145466385eab8256 Mon Sep 17 00:00:00 2001
From: Riccardo Freschi
Date: Sun, 30 Jun 2024 07:14:54 +0200
Subject: [PATCH] first draft

---
 .../Dockerfile                              |  29 ++++
 .../model_definition.py                     |  46 +++++
 .../ray-serve-stablediffusion.yaml          | 157 ++++++++++++++++++
 .../ray_serve_stablediffusion.py            |  53 ++++++
 .../stablediffusion-nvidia-triton-server.md | 105 ++++++++++++
 5 files changed, 390 insertions(+)
 create mode 100644 gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/Dockerfile
 create mode 100644 gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/model_definition.py
 create mode 100644 gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/ray-serve-stablediffusion.yaml
 create mode 100644 gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/ray_serve_stablediffusion.py
 create mode 100644 website/docs/gen-ai/inference/stablediffusion-nvidia-triton-server.md

diff --git a/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/Dockerfile b/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/Dockerfile
new file mode 100644
index 000000000..6e8b3a20d
--- /dev/null
+++ b/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/Dockerfile
@@ -0,0 +1,29 @@
+# docker buildx build --platform=linux/amd64 -t triton-python-api:24.01-py3 -f Dockerfile .
+
+ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver
+ARG BASE_IMAGE_TAG=24.01-py3
+
+FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} as triton-python-api
+
+# Maintainer label
+LABEL maintainer="DoEKS"
+
+# Switch back to a non-root user for the subsequent commands
+USER $USER
+
+RUN pip install --timeout=2000 "ray[serve]" numpy requests fastapi Pillow scipy accelerate
+
+RUN find /opt/tritonserver/python -maxdepth 1 -type f -name \
+    "tritonserver-*.whl" | xargs -I {} pip3 install --force-reinstall --upgrade {}[all]
+
+# Set a working directory
+WORKDIR /serve_app
+
+# Copy your Ray Serve script into the container
+COPY ray_serve_stablediffusion.py /serve_app/ray_serve_stablediffusion.py
+
+# Copy Triton Server models
+COPY ./models /serve_app/models
+
+# Set the PYTHONPATH environment variable
+ENV PYTHONPATH=/serve_app:$PYTHONPATH
diff --git a/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/model_definition.py b/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/model_definition.py
new file mode 100644
index 000000000..1ee8f429d
--- /dev/null
+++ b/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/model_definition.py
@@ -0,0 +1,46 @@
+import torch
+from diffusers import AutoencoderKL
+from transformers import CLIPTextModel, CLIPTokenizer
+
+prompt = "Draw a dog"
+vae = AutoencoderKL.from_pretrained(
+    "CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=True
+)
+
+tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+vae.forward = vae.decode
+torch.onnx.export(
+    vae,
+    (torch.randn(1, 4, 64, 64), False),
+    "vae.onnx",
+    input_names=["latent_sample", "return_dict"],
+    output_names=["sample"],
+    dynamic_axes={
+        "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+    },
+    do_constant_folding=True,
+    opset_version=14,
+)
+
+text_input = tokenizer(
+    prompt,
+    padding="max_length",
+    max_length=tokenizer.model_max_length,
+    truncation=True,
+    return_tensors="pt",
+)
+
+torch.onnx.export(
+    text_encoder,
+    (text_input.input_ids.to(torch.int32)),
+    "encoder.onnx",
+    input_names=["input_ids"],
+    output_names=["last_hidden_state", "pooler_output"],
+    dynamic_axes={
+        "input_ids": {0: "batch", 1: "sequence"},
+    },
+    opset_version=14,
+    do_constant_folding=True,
+)
diff --git a/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/ray-serve-stablediffusion.yaml b/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/ray-serve-stablediffusion.yaml
new file mode 100644
index 000000000..70eea8e13
--- /dev/null
+++ b/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/ray-serve-stablediffusion.yaml
@@ -0,0 +1,157 @@
+#----------------------------------------------------------------------
+# NOTE: For deployment instructions, refer to the DoEKS website.
+#----------------------------------------------------------------------
+---
+apiVersion: v1
+kind: Namespace
+metadata:
+  name: stablediffusion
+---
+apiVersion: ray.io/v1
+kind: RayService
+metadata:
+  name: stablediffusion
+  namespace: stablediffusion
+spec:
+  serviceUnhealthySecondThreshold: 900
+  deploymentUnhealthySecondThreshold: 300
+# Ray Serve can automatically scale deployment replicas up to 4 (max_replicas below) based on incoming traffic.
+# Each replica in this example requires one GPU. On single-GPU instance sizes (g5.xlarge up to g5.16xlarge, excluding g5.12xlarge), each node can run only one replica.
+# However, g5.12xlarge comes with 4 GPUs, allowing you to run all 4 replicas on a single node.
+  serveConfigV2: |
+    applications:
+      - name: stable-diffusion-deployment
+        import_path: "ray_serve_stablediffusion:entrypoint"
+        route_prefix: "/"
+        deployments:
+          - name: stable-diffusion-nvidia-triton-server
+            autoscaling_config:
+              metrics_interval_s: 0.2
+              min_replicas: 1
+              max_replicas: 4
+              look_back_period_s: 2
+              downscale_delay_s: 600
+              upscale_delay_s: 30
+              target_num_ongoing_requests_per_replica: 1
+            graceful_shutdown_timeout_s: 6
+            max_concurrent_queries: 100
+            ray_actor_options:
+              num_cpus: 3
+              num_gpus: 1
+  rayClusterConfig:
+    rayVersion: '2.11.0'
+    enableInTreeAutoscaling: true
+    headGroupSpec:
+      headService:
+        metadata:
+          name: stablediffusion
+          namespace: stablediffusion
+      rayStartParams:
+        dashboard-host: '0.0.0.0'
+      template:
+        spec:
+          containers:
+          # Important Performance Note:
+          # This image is large (14.5GB).
+          # For faster inference scaling, consider building a custom image with only your workload's essential dependencies.
+          # Smaller images lead to faster scaling, especially across multiple nodes.
+          # Notice that we are using the same image for both the head and worker nodes. You might hit ModuleNotFoundError if you use a different image for head and worker nodes.
+          - name: head
+            image: public.ecr.aws/h6c7e9p3/triton-python-api:24.01-py3-g5
+            imagePullPolicy: IfNotPresent # Pull the image only if it is not already cached on the node
+            lifecycle:
+              preStop:
+                exec:
+                  command: ["/bin/sh", "-c", "ray stop"]
+            ports:
+            - containerPort: 6379
+              name: gcs-server
+            - containerPort: 8265
+              name: dashboard
+            - containerPort: 10001
+              name: client
+            - containerPort: 8000
+              name: serve
+            volumeMounts:
+            - mountPath: /tmp/ray
+              name: ray-logs
+            resources:
+              limits:
+                cpu: 2
+                memory: 16Gi
+              requests:
+                cpu: 2
+                memory: 16Gi
+          nodeSelector:
+            NodeGroupType: x86-cpu-karpenter
+            type: karpenter
+          volumes:
+          - name: ray-logs
+            emptyDir: {}
+    workerGroupSpecs:
+    - groupName: gpu
+      # With g5.2xlarge instances, Ray can scale up to 4 nodes, with one pod per node and 1 GPU per pod.
+      minReplicas: 1
+      maxReplicas: 4
+      rayStartParams: {}
+      template:
+        spec:
+          containers:
+          - name: worker
+            image: public.ecr.aws/h6c7e9p3/triton-python-api:24.01-py3-g5
+            env:
+            - name: MODEL_REPOSITORY
+              value: "/serve_app/models"
+            imagePullPolicy: IfNotPresent # Pull the image only if it is not already cached on the node
+            lifecycle:
+              preStop:
+                exec:
+                  command: ["/bin/sh", "-c", "ray stop"]
+            resources:
+              limits:
+                cpu: "3"
+                memory: "14Gi"
+                nvidia.com/gpu: 1
+              requests:
+                cpu: "3"
+                memory: "14Gi"
+                nvidia.com/gpu: 1
+          nodeSelector:
+            NodeGroupType: g5-gpu-karpenter
+            type: karpenter
+          tolerations:
+          - key: "nvidia.com/gpu"
+            operator: "Exists"
+            effect: "NoSchedule"
+---
+apiVersion: networking.k8s.io/v1
+kind: Ingress
+metadata:
+  name: stablediffusion
+  namespace: stablediffusion
+  annotations:
+    nginx.ingress.kubernetes.io/rewrite-target: "/$1"
+spec:
+  ingressClassName: nginx
+  rules:
+  - http:
+      paths:
+      # Ray Dashboard: you can access the dashboard using the NLB DNS name with port 8265 and the path "/dashboard". However, the NLB is currently configured as private, preventing access from outside the VPC.
+      # To access the dashboard:
+      # Option 1: Use the `kubectl port-forward` command.
+      # Option 2: Deploy a public-facing NLB by modifying the configuration file "ai-ml/jark-stack/terraform/helm-values/ingress-nginx-values.yaml".
+      - path: /dashboard/(.*)
+        pathType: ImplementationSpecific
+        backend:
+          service:
+            name: stablediffusion
+            port:
+              number: 8265
+      # Ray Serve
+      - path: /serve/(.*)
+        pathType: ImplementationSpecific
+        backend:
+          service:
+            name: stablediffusion
+            port:
+              number: 8000
diff --git a/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/ray_serve_stablediffusion.py b/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/ray_serve_stablediffusion.py
new file mode 100644
index 000000000..4415646f3
--- /dev/null
+++ b/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server/ray_serve_stablediffusion.py
@@ -0,0 +1,53 @@
+from io import BytesIO
+import numpy
+import tritonserver
+from fastapi import FastAPI
+from PIL import Image
+from ray import serve
+from fastapi.responses import Response
+import os
+
+app = FastAPI()
+
+@serve.deployment(name="stable-diffusion-nvidia-triton-server", ray_actor_options={"num_gpus": 1})
+@serve.ingress(app)
+class TritonDeployment:
+    def __init__(self):
+        # Start an in-process Triton server in explicit model-control mode, pointed at MODEL_REPOSITORY.
+
+        model_repository = [os.getenv("MODEL_REPOSITORY", "/serve_app/models")]
+
+        self._triton_server = tritonserver.Server(
+            model_repository=model_repository,
+            model_control_mode=tritonserver.ModelControlMode.EXPLICIT,
+            log_info=False,
+        )
+        self._triton_server.start(wait_until_ready=True)
+
+    @app.get("/imagine")
+    def generate(self, prompt: str):
+        print(f"Generating image for prompt: {prompt}")
+        if not self._triton_server.model("stable_diffusion").ready():
+            try:
+                self._triton_server.load("text_encoder")
+                self._triton_server.load("vae")
+                self._stable_diffusion = self._triton_server.load("stable_diffusion")
+                if not self._stable_diffusion.ready():
+                    raise Exception("Model not ready")
+            except Exception as error:
+                print(f"Error: cannot load the stable diffusion model: {error}")
+                return
+
+        for response in self._stable_diffusion.infer(inputs={"prompt": [[prompt]]}):
+            generated_image = (
+                numpy.from_dlpack(response.outputs["generated_image"])
+                .squeeze()
+                .astype(numpy.uint8)
+            )
+
+            image = Image.fromarray(generated_image)
+            file_stream = BytesIO()
+            image.save(file_stream, "PNG")
+            return Response(content=file_stream.getvalue(), media_type="image/png")
+
+entrypoint = TritonDeployment.bind()
diff --git a/website/docs/gen-ai/inference/stablediffusion-nvidia-triton-server.md b/website/docs/gen-ai/inference/stablediffusion-nvidia-triton-server.md
new file mode 100644
index 000000000..385a9b170
--- /dev/null
+++ b/website/docs/gen-ai/inference/stablediffusion-nvidia-triton-server.md
@@ -0,0 +1,105 @@
+---
+title: Stable Diffusion with NVIDIA Triton Inference Server and Ray Serve
+sidebar_position: 8
+---
+
+# Deploying Stable Diffusion with NVIDIA Triton Inference Server, Ray Serve and Gradio
+
+This pattern is based on the [Serving models with Triton Server in Ray Serve](https://docs.ray.io/en/latest/serve/tutorials/triton-server-integration.html) example from the Ray documentation. The deployed model is *Stable Diffusion v1-4* from the *Computer Vision & Learning research group* ([CompVis](https://huggingface.co/CompVis)).
+
+## What is NVIDIA Triton Server
+
+:::info
+
+Section under construction
+
+:::
+
+## Deploying the Solution
+
+To deploy the solution, follow the same steps explained in the [*Stable Diffusion on GPU* pattern](https://awslabs.github.io/data-on-eks/docs/gen-ai/inference/stablediffusion-gpus); just remember to use `/data-on-eks/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server` wherever it says `/data-on-eks/gen-ai/inference/stable-diffusion-rayserve-gpu`.
+
+The `RayService` manifest references a pre-built Docker image, hosted on a public Amazon ECR repository, which already includes everything needed to run inference.
+
+Should you wish to build it yourself as an exercise, perform the following steps:
+
+1. **Launch and connect to an Amazon EC2 G5 instance**
+1. **Build and export the model**
+1. **Build the model repository**
+1. **Build the Docker image**
+1. **Push the newly built image to an Amazon ECR repository**
+
+## Launch an Amazon EC2 G5 instance
+
+1. Select a Deep Learning AMI, e.g. *Deep Learning OSS Nvidia Driver AMI GPU PyTorch*
+1. Select a `g5.12xlarge` instance (**note**: at least `12xlarge`)
+1. Set the *Root volume* size to 300 GB
+
+**Important Note**: we selected `g5` because the TensorRT engine serialized in the next step is specific to the GPU it is built on, so the model must be built on the same GPU type (NVIDIA A10G) as the Amazon EC2 G5 instances our Ray cluster runs on.
+
+You can connect to the instance, for example, by [using an SSH client](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/connect-linux-inst-ssh.html).
+
+## Build and export the model
+
+Install the dependencies:
+
+```bash
+pip install --timeout=2000 torch diffusers transformers onnx numpy
+```
+
+Clone the repository:
+
+```bash
+git clone https://github.com/awslabs/data-on-eks.git
+```
+
+Move to the blueprint directory and execute the `model_definition.py` script:
+
+```bash
+cd ./data-on-eks/gen-ai/inference/stable-diffusion-rayserve-nvidia-triton-server
+python3 model_definition.py
+```
+
+The output is two files: `vae.onnx` and `encoder.onnx`. Convert the `vae.onnx` model to a serialized TensorRT engine:
+
+```bash
+/usr/src/tensorrt/bin/trtexec --onnx=vae.onnx --saveEngine=vae.plan --minShapes=latent_sample:1x4x64x64 --optShapes=latent_sample:4x4x64x64 --maxShapes=latent_sample:8x4x64x64 --fp16
+```
+
+## Build the model repository
+
+The Triton Inference Server requires a [model repository](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_repository.md) with a specific structure.
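+
+For this pattern, the repository you are about to assemble looks like this (each model gets a numbered version directory plus its `config.pbtxt`):
+
+```text
+models/
+├── stable_diffusion/
+│   ├── 1/
+│   │   └── model.py
+│   └── config.pbtxt
+├── text_encoder/
+│   ├── 1/
+│   │   └── model.onnx
+│   └── config.pbtxt
+└── vae/
+    ├── 1/
+    │   └── model.plan
+    └── config.pbtxt
+```
+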
+Execute the following commands to build it and copy over the necessary files:
+
+```bash
+mkdir -p models/vae/1
+mkdir -p models/text_encoder/1
+mkdir -p models/stable_diffusion/1
+mv vae.plan models/vae/1/model.plan
+mv encoder.onnx models/text_encoder/1/model.onnx
+curl -L "https://raw.githubusercontent.com/triton-inference-server/tutorials/main/Conceptual_Guide/Part_6-building_complex_pipelines/model_repository/pipeline/1/model.py" > models/stable_diffusion/1/model.py
+curl -L "https://raw.githubusercontent.com/triton-inference-server/tutorials/main/Conceptual_Guide/Part_6-building_complex_pipelines/model_repository/pipeline/config.pbtxt" > models/stable_diffusion/config.pbtxt
+curl -L "https://raw.githubusercontent.com/triton-inference-server/tutorials/main/Conceptual_Guide/Part_6-building_complex_pipelines/model_repository/text_encoder/config.pbtxt" > models/text_encoder/config.pbtxt
+curl -L "https://raw.githubusercontent.com/triton-inference-server/tutorials/main/Conceptual_Guide/Part_6-building_complex_pipelines/model_repository/vae/config.pbtxt" > models/vae/config.pbtxt
+```
+
+Copy the Ray Serve `ray_serve_stablediffusion.py` file and the model repository under a new `serve_app` directory:
+
+```bash
+sudo mkdir -p /serve_app
+sudo cp ray_serve_stablediffusion.py /serve_app/ray_serve_stablediffusion.py
+sudo cp -r models /serve_app/models
+```
+
+You are ready to build the image.
+
+## Build the Docker image
+
+```bash
+docker build -t triton-python-api:24.01-py3 -f Dockerfile .
+```
+
+## Push the newly built image to an Amazon ECR repository
+
+You can publish your image to either a public or private Amazon ECR repository. For example, follow [this guideline](https://docs.aws.amazon.com/AmazonECR/latest/public/docker-push-ecr-image.html) to publish to a public one.
+
+To run your image on the Ray cluster, replace the existing image references in `ray-serve-stablediffusion.yaml` with yours (the head and the worker group use the same image).
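+
+For example, assuming you keep the same file layout, a quick way to swap both references at once is the sketch below (replace the placeholder with the URI of the image you pushed):
+
+```bash
+sed -i 's|public.ecr.aws/h6c7e9p3/triton-python-api:24.01-py3-g5|<your-registry>/<your-repository>:<tag>|g' ray-serve-stablediffusion.yaml
+```
+
+Once the `RayService` is up, you can run a quick smoke test against the Serve endpoint. The sketch below assumes the default private ingress, so it port-forwards the `stablediffusion` service targeted by the `Ingress` above and calls the `/imagine` route defined in `ray_serve_stablediffusion.py`:
+
+```bash
+# Forward the Ray Serve port (8000) of the stablediffusion service locally.
+kubectl -n stablediffusion port-forward svc/stablediffusion 8000:8000
+
+# In a second terminal: request an image and save the PNG response.
+curl -G "http://localhost:8000/imagine" --data-urlencode "prompt=an astronaut riding a horse" -o generated_image.png
+```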