From b6c9e64789121f2fbbfe2b9ba0df388739fd331a Mon Sep 17 00:00:00 2001 From: luke-lombardi <33990301+luke-lombardi@users.noreply.github.com> Date: Mon, 21 Oct 2024 11:21:00 -0400 Subject: [PATCH] Feat: Realtime abstraction built on top of asgi (#643) This is an experimental decorator that configures a basic websocket streaming example under the hood. The idea is to have a drop-in-place feature that lets users create streaming handlers that can do image generation, chat, etc. without writing the ASGI boilerplate. **TODO**: - [x] Add automatic type inference - [ ] Create client examples ``` from beta9 import realtime def load_model(): return "my_model" @realtime( cpu=1, memory=128, timeout=180, concurrent_requests=10, on_start=load_model, authorized=False, ) def handler(*, context, input: str): print(context.on_start_value) return input ``` The output from this echo server looks like this: ``` luke@Lukes-MBP beta9 % websocat ws://localhost:1994/asgi/public/bbfb4433-9685-4f79-9d92-68fb57b240c3 do something do something echo echo here here ``` --- sdk/pyproject.toml | 2 +- sdk/src/beta9/__init__.py | 2 + sdk/src/beta9/abstractions/endpoint.py | 188 +++++++++++++++++++++++++ sdk/src/beta9/runner/endpoint.py | 14 +- 4 files changed, 203 insertions(+), 3 deletions(-) diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index a1ca1357c..f197322b1 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "beta9" -version = "0.1.93" +version = "0.1.94" description = "" authors = ["beam.cloud "] packages = [ diff --git a/sdk/src/beta9/__init__.py b/sdk/src/beta9/__init__.py index 1bc2302d6..043af96a6 100644 --- a/sdk/src/beta9/__init__.py +++ b/sdk/src/beta9/__init__.py @@ -3,6 +3,7 @@ from .abstractions.container import Container from .abstractions.endpoint import ASGI as asgi from .abstractions.endpoint import Endpoint as endpoint +from .abstractions.endpoint import RealtimeASGI as realtime from .abstractions.function import Function as function from .abstractions.function import Schedule as schedule from .abstractions.image import Image @@ -24,6 +25,7 @@ "function", "endpoint", "asgi", + "realtime", "Container", "env", "GpuType", diff --git a/sdk/src/beta9/abstractions/endpoint.py b/sdk/src/beta9/abstractions/endpoint.py index a793bcd0c..1abd113eb 100644 --- a/sdk/src/beta9/abstractions/endpoint.py +++ b/sdk/src/beta9/abstractions/endpoint.py @@ -1,7 +1,10 @@ import os import threading +import traceback from typing import Any, Callable, List, Optional, Union +from uvicorn.protocols.utils import ClientDisconnected + from .. import terminal from ..abstractions.base.runner import ( ASGI_DEPLOYMENT_STUB_TYPE, @@ -276,6 +279,191 @@ def __init__( self.is_asgi = True +REALTIME_ASGI_SLEEP_INTERVAL_SECONDS = 0.1 +REALTIME_ASGI_HEALTH_CHECK_INTERVAL_SECONDS = 5 + + +class RealtimeASGI(ASGI): + """ + Decorator which allows you to create a realtime application built on top of ASGI / websockets. + Your handler function will run every time a message is received over the websocket. + + Parameters: + cpu (Union[int, float, str]): + The number of CPU cores allocated to the container. Default is 1.0. + memory (Union[int, str]): + The amount of memory allocated to the container. It should be specified in + MiB, or as a string with units (e.g. "1Gi"). Default is 128 MiB. + gpu (Union[GpuType, str]): + The type or name of the GPU device to be used for GPU-accelerated tasks. If not + applicable or no GPU required, leave it empty. Default is [GpuType.NoGPU](#gputype). + image (Union[Image, dict]): + The container image used for the task execution. Default is [Image](#image). + volumes (Optional[List[Volume]]): + A list of volumes to be mounted to the ASGI application. Default is None. + timeout (Optional[int]): + The maximum number of seconds a task can run before it times out. + Default is 3600. Set it to -1 to disable the timeout. + workers (Optional[int]): + The number of processes handling tasks per container. + Modifying this parameter can improve throughput for certain workloads. + Workers will share the CPU, Memory, and GPU defined. + You may need to increase these values to increase concurrency. + Default is 1. + concurrent_requests (int): + The maximum number of concurrent requests the ASGI application can handle. + Unlike regular endpoints that process requests synchronously, ASGI applications + can handle multiple requests concurrently. This parameter allows you to specify + the level of concurrency. For applications with blocking operations, this can + improve throughput by allowing the application to process other requests while + waiting for blocking operations to complete. Default is 1. + keep_warm_seconds (int): + The duration in seconds to keep the task queue warm even if there are no pending + tasks. Keeping the queue warm helps to reduce the latency when new tasks arrive. + Default is 10s. + max_pending_tasks (int): + The maximum number of tasks that can be pending in the queue. If the number of + pending tasks exceeds this value, the task queue will stop accepting new tasks. + Default is 100. + secrets (Optional[List[str]): + A list of secrets that are injected into the container as environment variables. Default is []. + name (Optional[str]): + An optional name for this ASGI application, used during deployment. If not specified, you must + specify the name at deploy time with the --name argument + authorized (Optional[str]): + If false, allows the ASGI application to be invoked without an auth token. + Default is True. + autoscaler (Optional[Autoscaler]): + Configure a deployment autoscaler - if specified, you can use scale your function horizontally using + various autoscaling strategies (Defaults to QueueDepthAutoscaler()) + callback_url (Optional[str]): + An optional URL to send a callback to when a task is completed, timed out, or cancelled. + Example: + ```python + from beta9 import realtime + + def generate_text(): + return ["this", "could", "be", "anything"] + + @realtime( + cpu=1.0, + memory=128, + gpu="T4" + ) + def handler(context): + return generate_text() + ``` + """ + + def __init__( + self, + cpu: Union[int, float, str] = 1.0, + memory: Union[int, str] = 128, + gpu: GpuTypeAlias = GpuType.NoGPU, + image: Image = Image(), + timeout: int = 180, + workers: int = 1, + concurrent_requests: int = 1, + keep_warm_seconds: int = 180, + max_pending_tasks: int = 100, + on_start: Optional[Callable] = None, + volumes: Optional[List[Volume]] = None, + secrets: Optional[List[str]] = None, + name: Optional[str] = None, + authorized: bool = True, + autoscaler: Autoscaler = QueueDepthAutoscaler(), + callback_url: Optional[str] = None, + ): + super().__init__( + cpu=cpu, + memory=memory, + gpu=gpu, + image=image, + timeout=timeout, + workers=workers, + keep_warm_seconds=keep_warm_seconds, + max_pending_tasks=max_pending_tasks, + on_start=on_start, + volumes=volumes, + secrets=secrets, + name=name, + authorized=authorized, + autoscaler=autoscaler, + callback_url=callback_url, + concurrent_requests=concurrent_requests, + ) + + def __call__(self, func): + import asyncio + from queue import Queue + + from fastapi import FastAPI, WebSocket, WebSocketDisconnect, WebSocketException + + internal_asgi_app = FastAPI() + internal_asgi_app.input_queue = Queue() + + @internal_asgi_app.websocket("/") + async def stream(websocket: WebSocket): + async def _heartbeat(): + while True: + try: + await websocket.send_json({"type": "ping"}) + await asyncio.sleep(REALTIME_ASGI_HEALTH_CHECK_INTERVAL_SECONDS) + except (WebSocketDisconnect, WebSocketException, RuntimeError): + break + + await websocket.accept() + + heartbeat_task = asyncio.create_task(_heartbeat()) + try: + while True: + try: + message = await websocket.receive() + + if "text" in message: + data = message["text"] + elif "bytes" in message: + data = message["bytes"] + elif "json" in message: + data = message["json"] + elif message.get("type") == "websocket.disconnect": + return + else: + print(f"WS stream error - unknown message type: {message}") + continue + + internal_asgi_app.input_queue.put(data) + + while not internal_asgi_app.input_queue.empty(): + output = internal_asgi_app.handler( + context=internal_asgi_app.context, + input=internal_asgi_app.input_queue.get(), + ) + + if isinstance(output, str): + await websocket.send_text(output) + elif isinstance(output, dict) or isinstance(output, list): + await websocket.send_json(output) + else: + await websocket.send_bytes(output) + + await asyncio.sleep(REALTIME_ASGI_SLEEP_INTERVAL_SECONDS) + except ( + WebSocketDisconnect, + WebSocketException, + RuntimeError, + ClientDisconnected, + ): + return + except BaseException: + print(f"Unhandled exception in websocket stream: {traceback.format_exc()}") + finally: + heartbeat_task.cancel() + + func.internal_asgi_app = internal_asgi_app + return _CallableWrapper(func, self) + + class _CallableWrapper(DeployableMixin): deployment_stub_type = ENDPOINT_DEPLOYMENT_STUB_TYPE diff --git a/sdk/src/beta9/runner/endpoint.py b/sdk/src/beta9/runner/endpoint.py index a05f50ede..a86bac1d2 100644 --- a/sdk/src/beta9/runner/endpoint.py +++ b/sdk/src/beta9/runner/endpoint.py @@ -5,7 +5,7 @@ import traceback from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union from fastapi import Depends, FastAPI, Request from fastapi.responses import JSONResponse, Response @@ -146,7 +146,17 @@ def __init__(self, logger: logging.Logger, worker: UvicornWorker) -> None: task_id=None, on_start_value=self.on_start_value, ) - self.app = self.handler(context) + + app: Union[FastAPI, None] = None + internal_asgi_app = getattr(self.handler.handler.func, "internal_asgi_app", None) + if internal_asgi_app is not None: + app = internal_asgi_app + app.context = context + app.handler = self.handler + else: + app = self.handler(context) + + self.app = app if not is_asgi3(self.app): raise ValueError("Invalid ASGI app returned from handler")