Commit 5645bb1
Merge pull request #259 from runpod/rust-core-integration
Rust core integration
justinmerrell authored Dec 28, 2023
2 parents 31f0fda + fcf5de6 commit 5645bb1
Showing 9 changed files with 264 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -16,6 +16,7 @@ __pycache__/

# C extensions
*.so
!sls_core.so

# Distribution / packaging
.Python
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,17 @@
# Change Log

## Release 1.5.0 (12/28/23)

### Added

- Optional serverless core implementation; enable with the environment variable `RUNPOD_USE_CORE=True` or `RUNPOD_CORE_PATH=/path/to/core.so`

### Changed

- Reduced `await asyncio.sleep` delays to 0 to reduce execution time.

---

## Release 1.4.2 (12/14/23)

### Fixed
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -70,4 +70,4 @@ dependencies = { file = ["requirements.txt"] }

# Used by pytest coverage
[tool.coverage.run]
omit = ["runpod/_version.py"]
omit = ["runpod/_version.py", "runpod/serverless/core.py"]
11 changes: 9 additions & 2 deletions runpod/serverless/__init__.py
@@ -12,6 +12,7 @@
import argparse
from typing import Dict, Any

+from runpod.serverless import core
from . import worker
from .modules import rp_fastapi
from .modules.rp_logger import RunPodLogger
@@ -125,7 +126,7 @@ def start(config: Dict[str, Any]):
    realtime_concurrency = _get_realtime_concurrency()

    if config["rp_args"]["rp_serve_api"]:
-        print("Starting API server.")
+        log.info("Starting API server.")
        api_server = rp_fastapi.WorkerAPI(config)

        api_server.start_uvicorn(
@@ -135,7 +136,7 @@ def start(config: Dict[str, Any]):
        )

    elif realtime_port:
-        print("Starting API server for realtime.")
+        log.info("Starting API server for realtime.")
        api_server = rp_fastapi.WorkerAPI(config)

        api_server.start_uvicorn(
@@ -144,5 +145,11 @@ def start(config: Dict[str, Any]):
            api_concurrency=realtime_concurrency
        )

+    # --------------------------------- SLS-Core --------------------------------- #
+    elif os.environ.get("RUNPOD_USE_CORE", None) or os.environ.get("RUNPOD_CORE_PATH", None):
+        log.info("Starting worker with SLS-Core.")
+        core.main(config)
+
+    # --------------------------------- Standard --------------------------------- #
+    else:
+        worker.main(config)
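As a usage sketch of the new branch (the handler below is hypothetical; note that this dispatch reads `RUNPOD_USE_CORE`/`RUNPOD_CORE_PATH`, while core.py itself resolves the shared library through `RUNPOD_SLS_CORE_PATH` or the bundled sls_core.so):

import os
os.environ["RUNPOD_USE_CORE"] = "true"  # opt into the SLS-Core branch above

import runpod

def handler(job):
    # Hypothetical handler; "input" is the conventional payload key in a job dict.
    return {"echo": job.get("input")}

# With the env var set, start() hands the config to core.main() instead of worker.main().
runpod.serverless.start({"handler": handler})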
227 changes: 227 additions & 0 deletions runpod/serverless/core.py
@@ -0,0 +1,227 @@
""" Core functionality for the runpod serverless worker. """

import asyncio
import ctypes
import inspect
import json
import os
import pathlib
from ctypes import CDLL, byref, c_char_p, c_int
from typing import Any, Callable, Dict, List, Optional

from runpod.serverless.modules.rp_logger import RunPodLogger


log = RunPodLogger()


class CGetJobResult(ctypes.Structure):  # pylint: disable=too-few-public-methods
    """
    Result of _runpod_sls_get_jobs.
    ## fields
    - `res_len`: the number of bytes written to the `dst_buf` passed to _runpod_sls_get_jobs.
    - `status_code`: reports the outcome of the call; see the status-code handling in Hook.get_jobs.
    """

    _fields_ = [("status_code", ctypes.c_int), ("res_len", ctypes.c_int)]

    def __str__(self) -> str:
        return f"CGetJobResult(res_len={self.res_len}, status_code={self.status_code})"


class Hook:  # pylint: disable=too-many-instance-attributes
    """ Singleton class for interacting with sls_core.so. """

    _instance = None

    # C function pointers
    _get_jobs: Optional[Callable] = None
    _progress_update: Optional[Callable] = None
    _stream_output: Optional[Callable] = None
    _post_output: Optional[Callable] = None
    _finish_stream: Optional[Callable] = None

    def __new__(cls):
        if Hook._instance is None:
            Hook._instance = object.__new__(cls)
            Hook._initialized = False
        return Hook._instance

    def __init__(self, rust_so_path: Optional[str] = None) -> None:
        if self._initialized:
            return

        if rust_so_path is None:
            default_path = os.path.join(
                pathlib.Path(__file__).parent.absolute(), "sls_core.so"
            )
            self.rust_so_path = os.environ.get("RUNPOD_SLS_CORE_PATH", str(default_path))
        else:
            self.rust_so_path = rust_so_path

        rust_library = CDLL(self.rust_so_path)
        buffer = ctypes.create_string_buffer(1024)  # 1 KiB
        num_bytes = rust_library._runpod_sls_crate_version(byref(buffer), c_int(len(buffer)))

        self.rust_crate_version = buffer.raw[:num_bytes].decode("utf-8")

        # Get Jobs
        self._get_jobs = rust_library._runpod_sls_get_jobs
        self._get_jobs.restype = CGetJobResult

        # Progress Update
        self._progress_update = rust_library._runpod_sls_progress_update
        self._progress_update.argtypes = [
            c_char_p, c_int,  # id_ptr, id_len
            c_char_p, c_int   # json_ptr, json_len
        ]
        self._progress_update.restype = c_int  # 1 if success, 0 if failure

        # Stream Output
        self._stream_output = rust_library._runpod_sls_stream_output
        self._stream_output.argtypes = [
            c_char_p, c_int,  # id_ptr, id_len
            c_char_p, c_int,  # json_ptr, json_len
        ]
        self._stream_output.restype = c_int  # 1 if success, 0 if failure

        # Post Output
        self._post_output = rust_library._runpod_sls_post_output
        self._post_output.argtypes = [
            c_char_p, c_int,  # id_ptr, id_len
            c_char_p, c_int,  # json_ptr, json_len
        ]
        self._post_output.restype = c_int  # 1 if success, 0 if failure

        # Finish Stream
        self._finish_stream = rust_library._runpod_sls_finish_stream
        self._finish_stream.argtypes = [c_char_p, c_int]  # id_ptr, id_len
        self._finish_stream.restype = c_int  # 1 if success, 0 if failure

        rust_library._runpod_sls_crate_version.restype = c_int

        # Initialize the core once per process (Hook is a singleton); assumed to be
        # required before any of the job-queue calls bound above.
        rust_library._runpod_sls_init.argtypes = []
        rust_library._runpod_sls_init.restype = c_int
        rust_library._runpod_sls_init()

        self._initialized = True

    def _json_serialize_job_data(self, job_data: Any) -> bytes:
        return json.dumps(job_data, ensure_ascii=False).encode("utf-8")

    def get_jobs(self, max_concurrency: int, max_jobs: int) -> List[Dict[str, Any]]:
        """Get a job or jobs from the queue, returned as a list of job dicts."""
        buffer = ctypes.create_string_buffer(1024 * 1024 * 20)  # 20 MiB buffer to store jobs in
        destination_length = len(buffer.raw)
        result: CGetJobResult = self._get_jobs(
            c_int(max_concurrency), c_int(max_jobs),
            byref(buffer), c_int(destination_length)
        )
        if result.status_code == 1:  # success: jobs were written to bytes 0..res_len of buffer.raw
            return list(json.loads(buffer.raw[: result.res_len].decode("utf-8")))

        if result.status_code not in [0, 1]:
            raise RuntimeError(f"get_jobs failed with status code {result.status_code}")

        return []  # Status code 0: still waiting for jobs

    def progress_update(self, job_id: str, json_data: bytes) -> bool:
        """
        Send a progress update to AI-API. `json_data` must already be JSON-encoded bytes.
        """
        id_bytes = job_id.encode("utf-8")
        return bool(self._progress_update(
            c_char_p(id_bytes), c_int(len(id_bytes)),
            c_char_p(json_data), c_int(len(json_data))
        ))

    def stream_output(self, job_id: str, job_output: Any) -> bool:
        """
        Send part of a streaming result to AI-API.
        """
        json_data = self._json_serialize_job_data(job_output)
        id_bytes = job_id.encode("utf-8")
        return bool(self._stream_output(
            c_char_p(id_bytes), c_int(len(id_bytes)),
            c_char_p(json_data), c_int(len(json_data))
        ))

    def post_output(self, job_id: str, job_output: Any) -> bool:
        """
        Send the result of a job to AI-API.
        Returns True if the result was successfully stored, False otherwise.
        """
        json_data = self._json_serialize_job_data(job_output)
        id_bytes = job_id.encode("utf-8")
        return bool(self._post_output(
            c_char_p(id_bytes), c_int(len(id_bytes)),
            c_char_p(json_data), c_int(len(json_data))
        ))

    def finish_stream(self, job_id: str) -> bool:
        """
        Tell the SLS queue that the result of a streaming job is complete.
        """
        id_bytes = job_id.encode("utf-8")
        return bool(self._finish_stream(
            c_char_p(id_bytes), c_int(len(id_bytes))
        ))


# -------------------------------- Process Job ------------------------------- #
async def _process_job(handler: Callable, job: Dict[str, Any]) -> None:
    """ Process a single job; a coroutine so run() can schedule it with asyncio.create_task. """
    hook = Hook()

    try:
        result = handler(job)
    except Exception as err:
        raise RuntimeError(
            f"run {job['id']}: user code raised {type(err).__name__}") from err

    if inspect.isgeneratorfunction(handler):
        for part in result:
            hook.stream_output(job['id'], part)

        hook.finish_stream(job['id'])

    else:
        hook.post_output(job['id'], result)

# -------------------------------- Run Worker -------------------------------- #
async def run(config: Dict[str, Any]) -> None:
    """ Run the worker.

    Args:
        config: A dictionary containing the following keys:
            handler: A function that takes a job and returns a result.
    """
    handler = config['handler']
    max_concurrency = config.get('max_concurrency', 4)
    max_jobs = config.get('max_jobs', 4)

    hook = Hook()

    while True:
        jobs = hook.get_jobs(max_concurrency, max_jobs)

        if not jobs:
            continue

        for job in jobs:
            asyncio.create_task(_process_job(handler, job))
            await asyncio.sleep(0)  # yield so the new task can start

        await asyncio.sleep(0)


def main(config: Dict[str, Any]) -> None:
"""Run the worker in an asyncio event loop."""
try:
work_loop = asyncio.new_event_loop()
asyncio.ensure_future(run(config), loop=work_loop)
work_loop.run_forever()
finally:
work_loop.close()
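For orientation, a rough sketch of driving the Hook API above directly, outside run() (the job IDs and result payload are illustrative; get_jobs returns an empty list while the queue is idle, and post_output/stream_output JSON-serialize whatever they are given):

from runpod.serverless.core import Hook

hook = Hook()  # loads sls_core.so and calls _runpod_sls_init once (singleton)
print(f"sls core crate version: {hook.rust_crate_version}")

for job in hook.get_jobs(max_concurrency=4, max_jobs=4):
    # Illustrative result; post_output serializes it to JSON before crossing the FFI boundary.
    hook.post_output(job["id"], {"output": "done"})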
2 changes: 1 addition & 1 deletion runpod/serverless/modules/rp_job.py
@@ -106,7 +106,7 @@ async def get_job(session: ClientSession, retry=True) -> Optional[Dict[str, Any]]:
            if retry is False:
                break

-            await asyncio.sleep(1)
+            await asyncio.sleep(0)
        else:
            job_list.add_job(next_job["id"])
            log.debug("Request ID added.", next_job['id'])
3 changes: 1 addition & 2 deletions runpod/serverless/modules/rp_scale.py
@@ -81,7 +81,6 @@ async def get_jobs(self, session):
            if job:
                yield job

-            await asyncio.sleep(1)
-
+            await asyncio.sleep(0)

        log.debug(f"Concurrency set to: {self.current_concurrency}")
Binary file added runpod/serverless/sls_core.so
12 changes: 12 additions & 0 deletions tests/test_serverless/test_worker.py
@@ -511,3 +511,15 @@ def mock_is_alive():
        # 5 calls with actual jobs
        assert mock_run_job.call_count == 5
        assert mock_send_result.call_count == 5

    # Test with sls-core
    async def test_run_worker_with_sls_core(self):
        '''
        Test run_worker with sls-core.
        '''
        os.environ["RUNPOD_USE_CORE"] = "true"

        with patch("runpod.serverless.core.main") as mock_main:
            runpod.serverless.start(self.config)

        assert mock_main.called
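Side note: the test above sets RUNPOD_USE_CORE without restoring it, which can leak into later tests. A minimal variant using `unittest.mock.patch.dict` (same names as the test above, and assuming `patch` is already imported from unittest.mock in this file) keeps the environment isolated:

    async def test_run_worker_with_sls_core_isolated(self):
        ''' Same assertion as above, without leaking the env var. '''
        with patch.dict(os.environ, {"RUNPOD_USE_CORE": "true"}), \
                patch("runpod.serverless.core.main") as mock_main:
            runpod.serverless.start(self.config)

        assert mock_main.called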
