Skip to content

Commit

Permalink
Merge pull request #6 from runpod/FastAPI
Browse files Browse the repository at this point in the history
Fast api
  • Loading branch information
justinmerrell authored Jan 23, 2023
2 parents ec190ca + 7445c24 commit ef98b95
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ jobs:
- name: Pylint Source
run: |
find . -type f -name '*.py' | xargs pylint
find . -type f -name '*.py' | xargs pylint --extension-pkg-whitelist='pydantic'
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
env
.env

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
4 changes: 4 additions & 0 deletions docs/serverless/worker.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ RUNPOD_WEBHOOK_GET_JOB= # URL to get job work from
RUNPOD_WEBHOOK_POST_OUTPUT= # URL to post output to
RUNPOD_WEBHOOK_PING= # URL to ping
RUNPOD_PING_INTERVAL= # Interval in milliseconds to ping the API (Default: 10000)

# Realtime
RUNPOD_REALTIME_PORT= # Port to listen on for realtime connections (Default: None)
RUNPOD_REALTIME_CONCURRENCY= # Number of workers to spawn (Default: 1)
```

### Additional Variables
Expand Down
5 changes: 5 additions & 0 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
'''
# pylint: disable=unused-argument,too-few-public-methods

import runpod


def validator():
'''
Expand Down Expand Up @@ -38,3 +40,6 @@ def run(model_inputs):
"seed": "1234"
}
]


runpod.serverless.start({"handler": run})
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ python-dotenv >= 0.21.0
requests >= 2.28.1
boto3 >= 1.26.15
aiohttp >= 3.8.3
fastapi[all] >= 0.89.0
12 changes: 11 additions & 1 deletion runpod/serverless/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
''' Allows serverless to be recognized as a package.'''

import os
import asyncio

from . import work_loop
from .modules import rp_fastapi


def start(config):
    '''
    Starts the serverless worker.

    When a port is configured through the environment, a WorkerAPI instance
    is created, given the config, and served with uvicorn on that port.
    Otherwise the work loop is run with asyncio.

    config: dict passed through unchanged to the API server or the work
            loop (callers in this repo pass {"handler": <callable>}).
    '''
    # The worker docs name this variable RUNPOD_REALTIME_PORT; the original
    # RUNPOD_API_PORT is still honoured so existing deployments keep working.
    # `or` (rather than a .get default) also skips a blank RUNPOD_REALTIME_PORT.
    api_port = (
        os.environ.get('RUNPOD_REALTIME_PORT')
        or os.environ.get('RUNPOD_API_PORT')
    )

    # The mode must be decided BEFORE anything blocking runs: launching the
    # work loop unconditionally first would make the API branch unreachable.
    if api_port:
        api_server = rp_fastapi.WorkerAPI()
        api_server.config = config

        api_server.start_uvicorn(api_port)
    else:
        asyncio.run(work_loop.start_worker(config))
2 changes: 1 addition & 1 deletion runpod/serverless/modules/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def log(message, level='INFO'):
Log message to stdout if RUNPOD_DEBUG is true.
'''
if os.environ.get('RUNPOD_DEBUG', 'true') == 'true':
print(f'{level} | {message}')
print(f'{level} | {message}', flush=True)


def log_secret(secret_name, secret, level='INFO'):
Expand Down
58 changes: 58 additions & 0 deletions runpod/serverless/modules/rp_fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
''' Used to launch the FastAPI web server when worker is running in API mode. '''

import os
import threading

import uvicorn
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel

from .job import run_job
from .worker_state import set_job_id
from .heartbeat import start_heartbeat


class Job(BaseModel):
    '''
    Request body accepted by the worker API's job route.
    '''
    # Identifier of the job.
    id: str

    # Input payload for the job, as a dict.
    input: dict


class WorkerAPI:
    ''' Used to launch the FastAPI web server when worker is running in API mode. '''

    def __init__(self):
        '''
        Initializes the WorkerAPI class.
        1. Starts the heartbeat thread.
        2. Initializes the FastAPI web server.
        '''
        # Daemon thread so the heartbeat never keeps the process alive on exit.
        heartbeat_thread = threading.Thread(target=start_heartbeat, daemon=True)
        heartbeat_thread.start()

        # Placeholder config; the caller is expected to assign the real one
        # (containing the "handler" callable) before the server is started.
        self.config = {"handler": None}
        self.rp_app = FastAPI()
        self.rp_app.add_api_route("/run", self.run, methods=["POST"])

    def start_uvicorn(self, api_port):
        '''
        Starts the Uvicorn server.

        api_port: port to listen on; converted with int(), so a numeric
                  string taken straight from the environment is accepted.

        The worker count is read from RUNPOD_REALTIME_CONCURRENCY and
        defaults to 1 when the variable is unset or blank.
        '''
        # Environment values are strings; uvicorn expects an integer here.
        # `or 1` also covers a variable that is present but left empty.
        # NOTE(review): uvicorn's docs say multiple workers need the app
        # passed as an import string - confirm values > 1 work with an
        # app instance.
        worker_count = int(os.environ.get('RUNPOD_REALTIME_CONCURRENCY') or 1)

        uvicorn.run(
            self.rp_app, host='0.0.0.0', port=int(api_port),
            workers=worker_count
        )

    async def run(self, job: Job):
        '''
        Performs model inference on the input data.

        Marks the job as active, passes the configured handler and the job's
        attribute dict to run_job, and returns the JSON-encodable result.
        '''
        set_job_id(job.id)

        try:
            job_results = run_job(self.config["handler"], job.__dict__)
        finally:
            # Always clear the active job id, even when the handler raises,
            # so a failed job is not left recorded as still running.
            set_job_id(None)

        return jsonable_encoder(job_results)
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ install_requires =
requests >= 2.28.1
boto3 >= 1.26.15
aiohttp >= 3.8.3
fastapi[all] >= 0.89.0

0 comments on commit ef98b95

Please sign in to comment.