Skip to content

Commit

Permalink
Feat: Realtime abstraction built on top of asgi (#643)
Browse files Browse the repository at this point in the history
This is an experimental decorator that configures a basic websocket
streaming server under the hood. The idea is to have a drop-in
feature that lets users create streaming handlers that can do image
generation, chat, etc. without writing the ASGI boilerplate.

**TODO**:
- [x] Add automatic type inference
- [ ] Create client examples

```
from beta9 import realtime


def load_model():
    return "my_model"


@realtime(
    cpu=1,
    memory=128,
    timeout=180,
    concurrent_requests=10,
    on_start=load_model,
    authorized=False,
)
def handler(*, context, input: str):
    print(context.on_start_value)
    return input

```

The output from this echo server looks like this:
```
luke@Lukes-MBP beta9 % websocat ws://localhost:1994/asgi/public/bbfb4433-9685-4f79-9d92-68fb57b240c3
do something
do something
echo
echo
here
here
```
  • Loading branch information
luke-lombardi authored Oct 21, 2024
1 parent 6bd5076 commit b6c9e64
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sdk/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "beta9"
version = "0.1.93"
version = "0.1.94"
description = ""
authors = ["beam.cloud <[email protected]>"]
packages = [
Expand Down
2 changes: 2 additions & 0 deletions sdk/src/beta9/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .abstractions.container import Container
from .abstractions.endpoint import ASGI as asgi
from .abstractions.endpoint import Endpoint as endpoint
from .abstractions.endpoint import RealtimeASGI as realtime
from .abstractions.function import Function as function
from .abstractions.function import Schedule as schedule
from .abstractions.image import Image
Expand All @@ -24,6 +25,7 @@
"function",
"endpoint",
"asgi",
"realtime",
"Container",
"env",
"GpuType",
Expand Down
188 changes: 188 additions & 0 deletions sdk/src/beta9/abstractions/endpoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
import threading
import traceback
from typing import Any, Callable, List, Optional, Union

from uvicorn.protocols.utils import ClientDisconnected

from .. import terminal
from ..abstractions.base.runner import (
ASGI_DEPLOYMENT_STUB_TYPE,
Expand Down Expand Up @@ -276,6 +279,191 @@ def __init__(
self.is_asgi = True


# Pause inserted between successive handler outputs while draining the
# websocket input queue (see RealtimeASGI.__call__).
REALTIME_ASGI_SLEEP_INTERVAL_SECONDS = 0.1
# Interval between "ping" heartbeat frames sent to connected websocket clients.
REALTIME_ASGI_HEALTH_CHECK_INTERVAL_SECONDS = 5


class RealtimeASGI(ASGI):
    """
    Decorator which allows you to create a realtime application built on top of ASGI / websockets.
    Your handler function will run every time a message is received over the websocket.
    Parameters:
        cpu (Union[int, float, str]):
            The number of CPU cores allocated to the container. Default is 1.0.
        memory (Union[int, str]):
            The amount of memory allocated to the container. It should be specified in
            MiB, or as a string with units (e.g. "1Gi"). Default is 128 MiB.
        gpu (Union[GpuType, str]):
            The type or name of the GPU device to be used for GPU-accelerated tasks. If not
            applicable or no GPU required, leave it empty. Default is [GpuType.NoGPU](#gputype).
        image (Union[Image, dict]):
            The container image used for the task execution. Default is [Image](#image).
        volumes (Optional[List[Volume]]):
            A list of volumes to be mounted to the ASGI application. Default is None.
        timeout (Optional[int]):
            The maximum number of seconds a task can run before it times out.
            Default is 3600. Set it to -1 to disable the timeout.
        workers (Optional[int]):
            The number of processes handling tasks per container.
            Modifying this parameter can improve throughput for certain workloads.
            Workers will share the CPU, Memory, and GPU defined.
            You may need to increase these values to increase concurrency.
            Default is 1.
        concurrent_requests (int):
            The maximum number of concurrent requests the ASGI application can handle.
            Unlike regular endpoints that process requests synchronously, ASGI applications
            can handle multiple requests concurrently. This parameter allows you to specify
            the level of concurrency. For applications with blocking operations, this can
            improve throughput by allowing the application to process other requests while
            waiting for blocking operations to complete. Default is 1.
        keep_warm_seconds (int):
            The duration in seconds to keep the task queue warm even if there are no pending
            tasks. Keeping the queue warm helps to reduce the latency when new tasks arrive.
            Default is 10s.
        max_pending_tasks (int):
            The maximum number of tasks that can be pending in the queue. If the number of
            pending tasks exceeds this value, the task queue will stop accepting new tasks.
            Default is 100.
        secrets (Optional[List[str]]):
            A list of secrets that are injected into the container as environment variables. Default is [].
        name (Optional[str]):
            An optional name for this ASGI application, used during deployment. If not specified, you must
            specify the name at deploy time with the --name argument
        authorized (bool):
            If false, allows the ASGI application to be invoked without an auth token.
            Default is True.
        autoscaler (Optional[Autoscaler]):
            Configure a deployment autoscaler - if specified, you can scale your function horizontally using
            various autoscaling strategies (Defaults to QueueDepthAutoscaler())
        callback_url (Optional[str]):
            An optional URL to send a callback to when a task is completed, timed out, or cancelled.
    Example:
        ```python
        from beta9 import realtime
        def generate_text():
            return ["this", "could", "be", "anything"]
        @realtime(
            cpu=1.0,
            memory=128,
            gpu="T4"
        )
        def handler(context):
            return generate_text()
        ```
    """

    def __init__(
        self,
        cpu: Union[int, float, str] = 1.0,
        memory: Union[int, str] = 128,
        gpu: GpuTypeAlias = GpuType.NoGPU,
        image: Image = Image(),
        timeout: int = 180,
        workers: int = 1,
        concurrent_requests: int = 1,
        keep_warm_seconds: int = 180,
        max_pending_tasks: int = 100,
        on_start: Optional[Callable] = None,
        volumes: Optional[List[Volume]] = None,
        secrets: Optional[List[str]] = None,
        name: Optional[str] = None,
        authorized: bool = True,
        autoscaler: Autoscaler = QueueDepthAutoscaler(),
        callback_url: Optional[str] = None,
    ):
        super().__init__(
            cpu=cpu,
            memory=memory,
            gpu=gpu,
            image=image,
            timeout=timeout,
            workers=workers,
            keep_warm_seconds=keep_warm_seconds,
            max_pending_tasks=max_pending_tasks,
            on_start=on_start,
            volumes=volumes,
            secrets=secrets,
            name=name,
            authorized=authorized,
            autoscaler=autoscaler,
            callback_url=callback_url,
            concurrent_requests=concurrent_requests,
        )

    def __call__(self, func):
        """Wrap *func* in a FastAPI app that invokes it for every websocket message.

        The generated app exposes a single websocket route at "/" which:
        heartbeats the client, queues each incoming message, runs the user
        handler on it, and sends the result back using a send method chosen by
        the output's type (str -> text, dict/list -> json, otherwise bytes).
        """
        import asyncio
        from queue import Queue

        from fastapi import FastAPI, WebSocket, WebSocketDisconnect, WebSocketException

        internal_asgi_app = FastAPI()
        internal_asgi_app.input_queue = Queue()

        @internal_asgi_app.websocket("/")
        async def stream(websocket: WebSocket):
            async def _heartbeat():
                # Periodically ping the client so idle connections stay alive
                # and disconnected peers are detected.
                while True:
                    try:
                        await websocket.send_json({"type": "ping"})
                        await asyncio.sleep(REALTIME_ASGI_HEALTH_CHECK_INTERVAL_SECONDS)
                    except (WebSocketDisconnect, WebSocketException, RuntimeError):
                        break

            await websocket.accept()

            heartbeat_task = asyncio.create_task(_heartbeat())
            try:
                while True:
                    try:
                        message = await websocket.receive()

                        if "text" in message:
                            data = message["text"]
                        elif "bytes" in message:
                            data = message["bytes"]
                        elif "json" in message:
                            data = message["json"]
                        elif message.get("type") == "websocket.disconnect":
                            return
                        else:
                            print(f"WS stream error - unknown message type: {message}")
                            continue

                        internal_asgi_app.input_queue.put(data)

                        while not internal_asgi_app.input_queue.empty():
                            # NOTE(review): the user handler runs synchronously
                            # here and blocks the event loop (including the
                            # heartbeat) while it executes — consider
                            # offloading long-running handlers to a thread.
                            output = internal_asgi_app.handler(
                                context=internal_asgi_app.context,
                                input=internal_asgi_app.input_queue.get(),
                            )

                            # Choose the send method based on the output type.
                            if isinstance(output, str):
                                await websocket.send_text(output)
                            elif isinstance(output, (dict, list)):
                                await websocket.send_json(output)
                            else:
                                await websocket.send_bytes(output)

                            await asyncio.sleep(REALTIME_ASGI_SLEEP_INTERVAL_SECONDS)
                    except (
                        WebSocketDisconnect,
                        WebSocketException,
                        RuntimeError,
                        ClientDisconnected,
                    ):
                        return
                    except Exception:
                        # Log handler errors but keep the connection open.
                        # Catching Exception (not BaseException) lets
                        # asyncio.CancelledError propagate, so cancellation of
                        # this websocket task is no longer silently swallowed.
                        print(f"Unhandled exception in websocket stream: {traceback.format_exc()}")
            finally:
                heartbeat_task.cancel()

        func.internal_asgi_app = internal_asgi_app
        return _CallableWrapper(func, self)


class _CallableWrapper(DeployableMixin):
deployment_stub_type = ENDPOINT_DEPLOYMENT_STUB_TYPE

Expand Down
14 changes: 12 additions & 2 deletions sdk/src/beta9/runner/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import traceback
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union

from fastapi import Depends, FastAPI, Request
from fastapi.responses import JSONResponse, Response
Expand Down Expand Up @@ -146,7 +146,17 @@ def __init__(self, logger: logging.Logger, worker: UvicornWorker) -> None:
task_id=None,
on_start_value=self.on_start_value,
)
self.app = self.handler(context)

app: Union[FastAPI, None] = None
internal_asgi_app = getattr(self.handler.handler.func, "internal_asgi_app", None)
if internal_asgi_app is not None:
app = internal_asgi_app
app.context = context
app.handler = self.handler
else:
app = self.handler(context)

self.app = app
if not is_asgi3(self.app):
raise ValueError("Invalid ASGI app returned from handler")

Expand Down

0 comments on commit b6c9e64

Please sign in to comment.