diff --git a/.github/workflows/ci_pylint.yml b/.github/workflows/ci_pylint.yml
index 7f162280..4b454ead 100644
--- a/.github/workflows/ci_pylint.yml
+++ b/.github/workflows/ci_pylint.yml
@@ -34,4 +34,4 @@ jobs:
 
       - name: Pylint Source
         run: |
-          find . -type f -name '*.py' | xargs pylint
+          find . -type f -name '*.py' | xargs pylint --extension-pkg-whitelist='pydantic'
diff --git a/.gitignore b/.gitignore
index b6e47617..ccb43d0f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+env
+.env
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/docs/serverless/worker.md b/docs/serverless/worker.md
index 30f9e74d..079e8f05 100644
--- a/docs/serverless/worker.md
+++ b/docs/serverless/worker.md
@@ -15,6 +15,10 @@ RUNPOD_WEBHOOK_GET_JOB= # URL to get job work from
 RUNPOD_WEBHOOK_POST_OUTPUT= # URL to post output to
 RUNPOD_WEBHOOK_PING= # URL to ping
 RUNPOD_PING_INTERVAL= # Interval in milliseconds to ping the API (Default: 10000)
+
+# Realtime
+RUNPOD_REALTIME_PORT= # Port to listen on for realtime connections (Default: None)
+RUNPOD_REALTIME_CONCURRENCY= # Number of workers to spawn (Default: 1)
 ```
 
 ### Additional Variables
diff --git a/infer.py b/infer.py
index 1531d6f4..466b3789 100644
--- a/infer.py
+++ b/infer.py
@@ -5,6 +5,8 @@
 '''
 # pylint: disable=unused-argument,too-few-public-methods
 
+import runpod
+
 
 def validator():
     '''
@@ -38,3 +40,6 @@ def run(model_inputs):
             "seed": "1234"
         }
     ]
+
+
+runpod.serverless.start({"handler": run})
diff --git a/requirements.txt b/requirements.txt
index dbd4d729..e09fa3ee 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,4 @@ python-dotenv >= 0.21.0
 requests >= 2.28.1
 boto3 >= 1.26.15
 aiohttp >= 3.8.3
+fastapi[all] >= 0.89.0
diff --git a/runpod/serverless/__init__.py b/runpod/serverless/__init__.py
index 3fe20b52..38f96344 100644
--- a/runpod/serverless/__init__.py
+++ b/runpod/serverless/__init__.py
@@ -1,12 +1,24 @@
 ''' Allows serverless to recognized as a package.'''
 
+import os
 import asyncio
 
 from . import work_loop
+from .modules import rp_fastapi
 
 
 def start(config):
     '''
     Starts the serverless worker.
     '''
-    asyncio.run(work_loop.start_worker(config))
+    # Realtime (API) mode is opt-in via a port variable. RUNPOD_REALTIME_PORT is
+    # the documented name; RUNPOD_API_PORT is still honoured as a fallback.
+    api_port = os.environ.get('RUNPOD_REALTIME_PORT') or os.environ.get('RUNPOD_API_PORT')
+
+    if api_port:
+        api_server = rp_fastapi.WorkerAPI()
+        api_server.config = config
+
+        api_server.start_uvicorn(api_port)
+    else:
+        asyncio.run(work_loop.start_worker(config))
diff --git a/runpod/serverless/modules/logging.py b/runpod/serverless/modules/logging.py
index 7071a656..97520b9c 100644
--- a/runpod/serverless/modules/logging.py
+++ b/runpod/serverless/modules/logging.py
@@ -17,7 +17,7 @@ def log(message, level='INFO'):
     Log message to stdout if RUNPOD_DEBUG is true.
     '''
     if os.environ.get('RUNPOD_DEBUG', 'true') == 'true':
-        print(f'{level} | {message}')
+        print(f'{level} | {message}', flush=True)
 
 
 def log_secret(secret_name, secret, level='INFO'):
diff --git a/runpod/serverless/modules/rp_fastapi.py b/runpod/serverless/modules/rp_fastapi.py
new file mode 100644
index 00000000..910b6209
--- /dev/null
+++ b/runpod/serverless/modules/rp_fastapi.py
@@ -0,0 +1,66 @@
+''' Used to launch the FastAPI web server when worker is running in API mode. '''
+
+import os
+import threading
+
+import uvicorn
+from fastapi import FastAPI
+from fastapi.encoders import jsonable_encoder
+from pydantic import BaseModel
+
+from .job import run_job
+from .worker_state import set_job_id
+from .heartbeat import start_heartbeat
+
+
+class Job(BaseModel):
+    ''' Represents a job. '''
+    id: str
+    input: dict
+
+
+class WorkerAPI:
+    ''' Used to launch the FastAPI web server when worker is running in API mode. '''
+
+    def __init__(self):
+        '''
+        Initializes the WorkerAPI class.
+        1. Starts the heartbeat thread.
+        2. Initializes the FastAPI web server.
+        '''
+        heartbeat_thread = threading.Thread(target=start_heartbeat)
+        heartbeat_thread.daemon = True
+        heartbeat_thread.start()
+
+        self.config = {"handler": None}
+        self.rp_app = FastAPI()
+        self.rp_app.add_api_route("/run", self.run, methods=["POST"])
+
+    def start_uvicorn(self, api_port):
+        '''
+        Starts the Uvicorn server on the given port.
+        The worker count is read from RUNPOD_REALTIME_CONCURRENCY (Default: 1).
+        '''
+        # Environment values are strings (possibly empty); Uvicorn needs an int.
+        # NOTE(review): Uvicorn expects an import string, not an app instance,
+        # when workers > 1 - confirm multi-worker mode behaves as intended.
+        worker_count = int(os.environ.get('RUNPOD_REALTIME_CONCURRENCY') or 1)
+
+        uvicorn.run(
+            self.rp_app, host='0.0.0.0', port=int(api_port),
+            workers=worker_count
+        )
+
+    async def run(self, job: Job):
+        '''
+        Performs model inference on the input data.
+        '''
+        set_job_id(job.id)
+
+        try:
+            job_results = run_job(self.config["handler"], job.__dict__)
+        finally:
+            # Always clear the job id, even if run_job raises.
+            set_job_id(None)
+
+        return jsonable_encoder(job_results)
diff --git a/setup.cfg b/setup.cfg
index cf6067f1..60915557 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -31,3 +31,4 @@ install_requires =
     requests >= 2.28.1
     boto3 >= 1.26.15
     aiohttp >= 3.8.3
+    fastapi[all] >= 0.89.0