Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] HTTPS Support #3380

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions examples/serve/https/service.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SkyServe YAML to run an HTTPS server.
#
# Usage:
#   For testing purposes, generate a self-signed certificate and key,
#   or bring your own:
#   $ openssl req -x509 -newkey rsa:2048 -days 36500 -nodes \
#       -keyout <key-path> -out <cert-path>
#   Then:
#   $ sky serve up -n https examples/serve/https/service.yaml \
#       --env TLS_KEYFILE_ENV_VAR=<key-path> \
#       --env TLS_CERTFILE_ENV_VAR=<cert-path>
#   The endpoint will be printed in the console. You could also
#   check the endpoint by running:
#   $ sky serve status --endpoint https

# Both values are left empty here and supplied with `--env` at launch time.
envs:
  TLS_KEYFILE_ENV_VAR:
  TLS_CERTFILE_ENV_VAR:

service:
  readiness_probe: /
  replicas: 1
  # Local paths of the TLS private key and certificate.
  tls:
    keyfile: $TLS_KEYFILE_ENV_VAR
    certfile: $TLS_CERTFILE_ENV_VAR

resources:
  ports: 8080
  cpus: 2+

run: python3 -m http.server 8080
55 changes: 53 additions & 2 deletions sky/serve/core.py
Original file line number Diff line number Diff line change
@@ -91,6 +91,38 @@ def _validate_service_task(task: 'sky.Task') -> None:
'Please specify the same port instead.')


def _rewrite_tls_credential_paths_and_get_tls_env_vars(
        service_name: str, task: 'sky.Task') -> Dict[str, Any]:
    """Swap the task's TLS credential paths for their remote counterparts.

    When the service declares a TLS credential, ``task.service`` is mutated
    in place so that its credential refers to the remote key/cert file names
    generated for ``service_name``; the original (local) paths are handed
    back in the returned mapping.

    Args:
        service_name: Name of the service.
        task: sky.Task whose service spec is rewritten.

    Returns:
        The generated template variables for TLS. Only ``{'use_tls': False}``
        when the service has no TLS credential.
    """
    spec = task.service
    # Already checked by _validate_service_task
    assert spec is not None

    local_credential = spec.tls_credential
    if local_credential is None:
        return {'use_tls': False}

    remote_keyfile = serve_utils.generate_remote_tls_keyfile_name(
        service_name)
    remote_certfile = serve_utils.generate_remote_tls_certfile_name(
        service_name)

    # From here on the spec only knows about the remote locations; the local
    # ones survive solely through `local_credential`.
    spec.tls_credential = serve_utils.TLSCredential(remote_keyfile,
                                                    remote_certfile)
    return {
        'use_tls': True,
        'remote_tls_keyfile': remote_keyfile,
        'remote_tls_certfile': remote_certfile,
        'local_tls_keyfile': local_credential.keyfile,
        'local_tls_certfile': local_credential.certfile,
    }


@usage_lib.entrypoint
def up(
task: 'sky.Task',
@@ -134,6 +166,9 @@ def up(
controller_utils.maybe_translate_local_file_mounts_and_sync_up(
task, path='serve')

tls_template_vars = _rewrite_tls_credential_paths_and_get_tls_env_vars(
service_name, task)

with tempfile.NamedTemporaryFile(
prefix=f'service-task-{service_name}-',
mode='w',
@@ -162,6 +197,7 @@ def up(
'remote_user_config_path': remote_config_yaml_path,
'modified_catalogs':
service_catalog_common.get_modified_catalog_file_mounts(),
**tls_template_vars,
**controller_utils.shared_controller_vars_to_fill(
controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
remote_user_config_path=remote_config_yaml_path,
@@ -267,10 +303,16 @@ def up(
else:
lb_port = serve_utils.load_service_initialization_result(
lb_port_payload)
endpoint = backend_utils.get_endpoints(
socket_endpoint = backend_utils.get_endpoints(
controller_handle.cluster_name, lb_port,
skip_status_check=True).get(lb_port)
assert endpoint is not None, 'Did not get endpoint for controller.'
assert socket_endpoint is not None, (
'Did not get endpoint for controller.')
# Already checked by _validate_service_task
assert task.service is not None
protocol = ('http'
if task.service.tls_credential is None else 'https')
endpoint = f'{protocol}://{socket_endpoint}'

sky_logging.print(
f'{fore.CYAN}Service name: '
@@ -319,6 +361,14 @@ def update(
service_name: Name of the service.
"""
_validate_service_task(task)

assert task.service is not None
if task.service.tls_credential is not None:
logger.warning('Updating TLS keyfile and certfile is not supported. '
'Any updates to the keyfile and certfile will not take '
'effect. To update TLS keyfile and certfile, please '
'tear down the service and spin up a new one.')

handle = backend_utils.is_controller_accessible(
controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
stopped_message=
@@ -574,6 +624,7 @@ def status(
'policy': (Optional[str]) load balancer policy description,
'requested_resources_str': (str) str representation of
requested resources,
'tls_encrypted': (bool) whether the service is TLS encrypted,
'replica_info': (List[Dict[str, Any]]) replica information,
}

28 changes: 22 additions & 6 deletions sky/serve/load_balancer.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
import asyncio
import logging
import threading
from typing import Dict, Union
from typing import Dict, Optional, Union

import aiohttp
import fastapi
@@ -27,12 +27,14 @@ class SkyServeLoadBalancer:
policy.
"""

def __init__(self, controller_url: str, load_balancer_port: int) -> None:
def __init__(self, controller_url: str, load_balancer_port: int,
tls_credential: Optional[serve_utils.TLSCredential]) -> None:
"""Initialize the load balancer.

Args:
controller_url: The URL of the controller.
load_balancer_port: The port where the load balancer listens to.
tls_credentials: The TLS credentials for HTTPS endpoint.
"""
self._app = fastapi.FastAPI()
self._controller_url: str = controller_url
@@ -41,6 +43,8 @@ def __init__(self, controller_url: str, load_balancer_port: int) -> None:
lb_policies.RoundRobinPolicy())
self._request_aggregator: serve_utils.RequestsAggregator = (
serve_utils.RequestTimestamp())
self._tls_credential: Optional[serve_utils.TLSCredential] = (
tls_credential)
# TODO(tian): httpx.Client has a resource limit of 100 max connections
# for each client. We should wait for feedback on the best max
# connections.
@@ -217,15 +221,27 @@ async def startup():
# Register controller synchronization task
asyncio.create_task(self._sync_with_controller())

uvicorn_tls_kwargs = ({} if self._tls_credential is None else
self._tls_credential.dump_uvicorn_kwargs())

protocol = 'https' if self._tls_credential is not None else 'http'

logger.info('SkyServe Load Balancer started on '
f'http://0.0.0.0:{self._load_balancer_port}')
f'{protocol}://0.0.0.0:{self._load_balancer_port}')

uvicorn.run(self._app, host='0.0.0.0', port=self._load_balancer_port)
uvicorn.run(self._app,
host='0.0.0.0',
port=self._load_balancer_port,
**uvicorn_tls_kwargs)


def run_load_balancer(controller_addr: str, load_balancer_port: int):
def run_load_balancer(
controller_addr: str,
load_balancer_port: int,
tls_credential: Optional[serve_utils.TLSCredential] = None):
load_balancer = SkyServeLoadBalancer(controller_url=controller_addr,
load_balancer_port=load_balancer_port)
load_balancer_port=load_balancer_port,
tls_credential=tls_credential)
load_balancer.run()


15 changes: 10 additions & 5 deletions sky/serve/serve_state.py
Original file line number Diff line number Diff line change
@@ -76,6 +76,9 @@ def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None:
db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services',
'active_versions',
f'TEXT DEFAULT {json.dumps([])!r}')
# Whether the service's load balancer is encrypted with TLS.
db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services', 'tls_encrypted',
'INTEGER DEFAULT 0')
_UNIQUE_CONSTRAINT_FAILED_ERROR_MSG = 'UNIQUE constraint failed: services.name'


@@ -241,7 +244,8 @@ def from_replica_statuses(


def add_service(name: str, controller_job_id: int, policy: str,
requested_resources_str: str, status: ServiceStatus) -> bool:
requested_resources_str: str, status: ServiceStatus,
tls_encrypted: bool) -> bool:
"""Add a service in the database.

Returns:
@@ -254,10 +258,10 @@ def add_service(name: str, controller_job_id: int, policy: str,
"""\
INSERT INTO services
(name, controller_job_id, status, policy,
requested_resources_str)
VALUES (?, ?, ?, ?, ?)""",
requested_resources_str, tls_encrypted)
VALUES (?, ?, ?, ?, ?, ?)""",
(name, controller_job_id, status.value, policy,
requested_resources_str))
requested_resources_str, int(tls_encrypted)))

except sqlite3.IntegrityError as e:
if str(e) != _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG:
@@ -324,7 +328,7 @@ def set_service_load_balancer_port(service_name: str,
def _get_service_from_row(row) -> Dict[str, Any]:
(current_version, name, controller_job_id, controller_port,
load_balancer_port, status, uptime, policy, _, _, requested_resources_str,
_, active_versions) = row[:13]
_, active_versions, tls_encrypted) = row[:14]
return {
'name': name,
'controller_job_id': controller_job_id,
@@ -341,6 +345,7 @@ def _get_service_from_row(row) -> Dict[str, Any]:
# integers in json format. This is mainly for display purpose.
'active_versions': json.loads(active_versions),
'requested_resources_str': requested_resources_str,
'tls_encrypted': bool(tls_encrypted),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For an existing service, this value will be None. Could bool(tls_encrypted) fail in that case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, bool(None) will return False:

python
Python 3.9.18 (main, Sep 11 2023, 13:41:44) 
[GCC 11.2.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> bool(None)
False

But I agree that it is confusing. Changed it to 0 🫡

}


29 changes: 28 additions & 1 deletion sky/serve/serve_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""User interface with the SkyServe."""
import base64
import collections
import dataclasses
import enum
import os
import pathlib
@@ -86,6 +87,19 @@ class UpdateMode(enum.Enum):
BLUE_GREEN = 'blue_green'


@dataclasses.dataclass
class TLSCredential:
    """A key file / certificate file pair used to serve over TLS."""
    # Path to the TLS private key file; a leading '~' is allowed.
    keyfile: str
    # Path to the TLS certificate file; a leading '~' is allowed.
    certfile: str

    def dump_uvicorn_kwargs(self) -> Dict[str, str]:
        """Build the ``ssl_keyfile``/``ssl_certfile`` kwargs for uvicorn.

        Both paths go through ``os.path.expanduser`` so that '~' is resolved
        on the machine where this method is called.
        """
        expanded_key, expanded_cert = map(os.path.expanduser,
                                          (self.keyfile, self.certfile))
        return dict(ssl_keyfile=expanded_key, ssl_certfile=expanded_cert)


DEFAULT_UPDATE_MODE = UpdateMode.ROLLING

_SIGNAL_TO_ERROR = {
@@ -237,6 +251,18 @@ def generate_replica_log_file_name(service_name: str, replica_id: int) -> str:
return os.path.join(dir_name, f'replica_{replica_id}.log')


def generate_remote_tls_keyfile_name(service_name: str) -> str:
    """Return the TLS keyfile path inside the service's remote directory."""
    # Deliberately left unexpanded: the path is resolved on the remote
    # machine, not on the one calling this function.
    return os.path.join(generate_remote_service_dir_name(service_name),
                        'tls_keyfile')


def generate_remote_tls_certfile_name(service_name: str) -> str:
    """Return the TLS certfile path inside the service's remote directory."""
    # Deliberately left unexpanded: the path is resolved on the remote
    # machine, not on the one calling this function.
    return os.path.join(generate_remote_service_dir_name(service_name),
                        'tls_certfile')


def generate_replica_cluster_name(service_name: str, replica_id: int) -> str:
    """Return the cluster name of a replica: ``<service_name>-<replica_id>``."""
    return '{}-{}'.format(service_name, replica_id)

@@ -793,7 +819,8 @@ def get_endpoint(service_record: Dict[str, Any]) -> str:
if endpoint is None:
return '-'
assert isinstance(endpoint, str), endpoint
return endpoint
protocol = 'https' if service_record['tls_encrypted'] else 'http'
return f'{protocol}://{endpoint}'


def format_service_table(service_records: List[Dict[str, Any]],
7 changes: 4 additions & 3 deletions sky/serve/service.py
Original file line number Diff line number Diff line change
@@ -150,7 +150,8 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
controller_job_id=job_id,
policy=service_spec.autoscaling_policy_str(),
requested_resources_str=backend_utils.get_task_resources_str(task),
status=serve_state.ServiceStatus.CONTROLLER_INIT)
status=serve_state.ServiceStatus.CONTROLLER_INIT,
tls_encrypted=service_spec.tls_credential is not None)
# Directly throw an error here. See sky/serve/api.py::up
# for more details.
if not success:
@@ -213,7 +214,6 @@ def _get_host():
serve_state.set_service_controller_port(service_name,
controller_port)

# TODO(tian): Support HTTPS.
controller_addr = f'http://{controller_host}:{controller_port}'

load_balancer_port = common_utils.find_free_port(
@@ -227,7 +227,8 @@ def _get_host():
target=ux_utils.RedirectOutputForProcess(
load_balancer.run_load_balancer,
load_balancer_log_file).run,
args=(controller_addr, load_balancer_port))
args=(controller_addr, load_balancer_port,
service_spec.tls_credential))
load_balancer_process.start()
serve_state.set_service_load_balancer_port(service_name,
load_balancer_port)
Loading