diff --git a/examples/serve/https/service.yaml b/examples/serve/https/service.yaml new file mode 100644 index 00000000000..b4d48f4e495 --- /dev/null +++ b/examples/serve/https/service.yaml @@ -0,0 +1,31 @@ +# SkyServe YAML to run an HTTPS server. +# +# Usage: +# For testing purposes, generate a self-signed certificate and key, +# or bring your own: +# $ openssl req -x509 -newkey rsa:2048 -days 36500 -nodes \ +# -keyout KEY_PATH -out CERT_PATH +# Then: +# $ sky serve up -n https examples/serve/https/service.yaml \ +# --env TLS_KEYFILE_ENV_VAR=KEY_PATH \ +# --env TLS_CERTFILE_ENV_VAR=CERT_PATH +# The endpoint will be printed in the console. You could also +# check the endpoint by running: +# $ sky serve status --endpoint https + +envs: + TLS_KEYFILE_ENV_VAR: + TLS_CERTFILE_ENV_VAR: + +service: + readiness_probe: / + replicas: 1 + tls: + keyfile: $TLS_KEYFILE_ENV_VAR + certfile: $TLS_CERTFILE_ENV_VAR + +resources: + ports: 8080 + cpus: 2+ + +run: python3 -m http.server 8080 diff --git a/sky/serve/core.py b/sky/serve/core.py index ea8f380a2e7..156392a444e 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -91,6 +91,38 @@ def _validate_service_task(task: 'sky.Task') -> None: 'Please specify the same port instead.') +def _rewrite_tls_credential_paths_and_get_tls_env_vars( + service_name: str, task: 'sky.Task') -> Dict[str, Any]: + """Rewrite the paths of TLS credentials in the task. + + Args: + service_name: Name of the service. + task: sky.Task to rewrite. + + Returns: + The generated template variables for TLS. 
+ """ + service_spec = task.service + # Already checked by _validate_service_task + assert service_spec is not None + if service_spec.tls_credential is None: + return {'use_tls': False} + remote_tls_keyfile = ( + serve_utils.generate_remote_tls_keyfile_name(service_name)) + remote_tls_certfile = ( + serve_utils.generate_remote_tls_certfile_name(service_name)) + tls_template_vars = { + 'use_tls': True, + 'remote_tls_keyfile': remote_tls_keyfile, + 'remote_tls_certfile': remote_tls_certfile, + 'local_tls_keyfile': service_spec.tls_credential.keyfile, + 'local_tls_certfile': service_spec.tls_credential.certfile, + } + service_spec.tls_credential = serve_utils.TLSCredential( + remote_tls_keyfile, remote_tls_certfile) + return tls_template_vars + + @usage_lib.entrypoint def up( task: 'sky.Task', @@ -134,6 +166,9 @@ def up( controller_utils.maybe_translate_local_file_mounts_and_sync_up( task, path='serve') + tls_template_vars = _rewrite_tls_credential_paths_and_get_tls_env_vars( + service_name, task) + with tempfile.NamedTemporaryFile( prefix=f'service-task-{service_name}-', mode='w', @@ -162,6 +197,7 @@ def up( 'remote_user_config_path': remote_config_yaml_path, 'modified_catalogs': service_catalog_common.get_modified_catalog_file_mounts(), + **tls_template_vars, **controller_utils.shared_controller_vars_to_fill( controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER, remote_user_config_path=remote_config_yaml_path, @@ -267,10 +303,16 @@ def up( else: lb_port = serve_utils.load_service_initialization_result( lb_port_payload) - endpoint = backend_utils.get_endpoints( + socket_endpoint = backend_utils.get_endpoints( controller_handle.cluster_name, lb_port, skip_status_check=True).get(lb_port) - assert endpoint is not None, 'Did not get endpoint for controller.' 
+ assert socket_endpoint is not None, ( + 'Did not get endpoint for controller.') + # Already checked by _validate_service_task + assert task.service is not None + protocol = ('http' + if task.service.tls_credential is None else 'https') + endpoint = f'{protocol}://{socket_endpoint}' sky_logging.print( f'{fore.CYAN}Service name: ' @@ -319,6 +361,14 @@ def update( service_name: Name of the service. """ _validate_service_task(task) + + assert task.service is not None + if task.service.tls_credential is not None: + logger.warning('Updating TLS keyfile and certfile is not supported. ' + 'Any updates to the keyfile and certfile will not take ' + 'effect. To update TLS keyfile and certfile, please ' + 'tear down the service and spin up a new one.') + handle = backend_utils.is_controller_accessible( controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER, stopped_message= @@ -574,6 +624,7 @@ def status( 'policy': (Optional[str]) load balancer policy description, 'requested_resources_str': (str) str representation of requested resources, + 'tls_encrypted': (bool) whether the service is TLS encrypted, 'replica_info': (List[Dict[str, Any]]) replica information, } diff --git a/sky/serve/load_balancer.py b/sky/serve/load_balancer.py index c15f71e214a..35ef1dea3ca 100644 --- a/sky/serve/load_balancer.py +++ b/sky/serve/load_balancer.py @@ -2,7 +2,7 @@ import asyncio import logging import threading -from typing import Dict, Union +from typing import Dict, Optional, Union import aiohttp import fastapi @@ -27,12 +27,14 @@ class SkyServeLoadBalancer: policy. """ - def __init__(self, controller_url: str, load_balancer_port: int) -> None: + def __init__(self, controller_url: str, load_balancer_port: int, + tls_credential: Optional[serve_utils.TLSCredential]) -> None: """Initialize the load balancer. Args: controller_url: The URL of the controller. load_balancer_port: The port where the load balancer listens to. + tls_credential: The TLS credential for the HTTPS endpoint. 
""" self._app = fastapi.FastAPI() self._controller_url: str = controller_url @@ -41,6 +43,8 @@ def __init__(self, controller_url: str, load_balancer_port: int) -> None: lb_policies.RoundRobinPolicy()) self._request_aggregator: serve_utils.RequestsAggregator = ( serve_utils.RequestTimestamp()) + self._tls_credential: Optional[serve_utils.TLSCredential] = ( + tls_credential) # TODO(tian): httpx.Client has a resource limit of 100 max connections # for each client. We should wait for feedback on the best max # connections. @@ -217,15 +221,27 @@ async def startup(): # Register controller synchronization task asyncio.create_task(self._sync_with_controller()) + uvicorn_tls_kwargs = ({} if self._tls_credential is None else + self._tls_credential.dump_uvicorn_kwargs()) + + protocol = 'https' if self._tls_credential is not None else 'http' + logger.info('SkyServe Load Balancer started on ' - f'http://0.0.0.0:{self._load_balancer_port}') + f'{protocol}://0.0.0.0:{self._load_balancer_port}') - uvicorn.run(self._app, host='0.0.0.0', port=self._load_balancer_port) + uvicorn.run(self._app, + host='0.0.0.0', + port=self._load_balancer_port, + **uvicorn_tls_kwargs) -def run_load_balancer(controller_addr: str, load_balancer_port: int): +def run_load_balancer( + controller_addr: str, + load_balancer_port: int, + tls_credential: Optional[serve_utils.TLSCredential] = None): load_balancer = SkyServeLoadBalancer(controller_url=controller_addr, - load_balancer_port=load_balancer_port) + load_balancer_port=load_balancer_port, + tls_credential=tls_credential) load_balancer.run() diff --git a/sky/serve/serve_state.py b/sky/serve/serve_state.py index 333e0138fb4..0c47513f1d4 100644 --- a/sky/serve/serve_state.py +++ b/sky/serve/serve_state.py @@ -76,6 +76,9 @@ def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None: db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services', 'active_versions', f'TEXT DEFAULT {json.dumps([])!r}') +# Whether the service's load balancer 
is encrypted with TLS. +db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services', 'tls_encrypted', + 'INTEGER DEFAULT 0') _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG = 'UNIQUE constraint failed: services.name' @@ -241,7 +244,8 @@ def from_replica_statuses( def add_service(name: str, controller_job_id: int, policy: str, - requested_resources_str: str, status: ServiceStatus) -> bool: + requested_resources_str: str, status: ServiceStatus, + tls_encrypted: bool) -> bool: """Add a service in the database. Returns: @@ -254,10 +258,10 @@ def add_service(name: str, controller_job_id: int, policy: str, """\ INSERT INTO services (name, controller_job_id, status, policy, - requested_resources_str) - VALUES (?, ?, ?, ?, ?)""", + requested_resources_str, tls_encrypted) + VALUES (?, ?, ?, ?, ?, ?)""", (name, controller_job_id, status.value, policy, - requested_resources_str)) + requested_resources_str, int(tls_encrypted))) except sqlite3.IntegrityError as e: if str(e) != _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG: @@ -324,7 +328,7 @@ def set_service_load_balancer_port(service_name: str, def _get_service_from_row(row) -> Dict[str, Any]: (current_version, name, controller_job_id, controller_port, load_balancer_port, status, uptime, policy, _, _, requested_resources_str, - _, active_versions) = row[:13] + _, active_versions, tls_encrypted) = row[:14] return { 'name': name, 'controller_job_id': controller_job_id, @@ -341,6 +345,7 @@ def _get_service_from_row(row) -> Dict[str, Any]: # integers in json format. This is mainly for display purpose. 
'active_versions': json.loads(active_versions), 'requested_resources_str': requested_resources_str, + 'tls_encrypted': bool(tls_encrypted), } diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 6e7b6f6eb4a..e35873382ce 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -1,6 +1,7 @@ """User interface with the SkyServe.""" import base64 import collections +import dataclasses import enum import os import pathlib @@ -86,6 +87,19 @@ class UpdateMode(enum.Enum): BLUE_GREEN = 'blue_green' +@dataclasses.dataclass +class TLSCredential: + """TLS credential for the service.""" + keyfile: str + certfile: str + + def dump_uvicorn_kwargs(self) -> Dict[str, str]: + return { + 'ssl_keyfile': os.path.expanduser(self.keyfile), + 'ssl_certfile': os.path.expanduser(self.certfile), + } + + DEFAULT_UPDATE_MODE = UpdateMode.ROLLING _SIGNAL_TO_ERROR = { @@ -237,6 +251,18 @@ def generate_replica_log_file_name(service_name: str, replica_id: int) -> str: return os.path.join(dir_name, f'replica_{replica_id}.log') +def generate_remote_tls_keyfile_name(service_name: str) -> str: + dir_name = generate_remote_service_dir_name(service_name) + # Don't expand here since it is used for remote machine. + return os.path.join(dir_name, 'tls_keyfile') + + +def generate_remote_tls_certfile_name(service_name: str) -> str: + dir_name = generate_remote_service_dir_name(service_name) + # Don't expand here since it is used for remote machine. 
+ return os.path.join(dir_name, 'tls_certfile') + + def generate_replica_cluster_name(service_name: str, replica_id: int) -> str: return f'{service_name}-{replica_id}' @@ -793,7 +819,8 @@ def get_endpoint(service_record: Dict[str, Any]) -> str: if endpoint is None: return '-' assert isinstance(endpoint, str), endpoint - return endpoint + protocol = 'https' if service_record['tls_encrypted'] else 'http' + return f'{protocol}://{endpoint}' def format_service_table(service_records: List[Dict[str, Any]], diff --git a/sky/serve/service.py b/sky/serve/service.py index 956a4839a87..11184c7abe8 100644 --- a/sky/serve/service.py +++ b/sky/serve/service.py @@ -150,7 +150,8 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int): controller_job_id=job_id, policy=service_spec.autoscaling_policy_str(), requested_resources_str=backend_utils.get_task_resources_str(task), - status=serve_state.ServiceStatus.CONTROLLER_INIT) + status=serve_state.ServiceStatus.CONTROLLER_INIT, + tls_encrypted=service_spec.tls_credential is not None) # Directly throw an error here. See sky/serve/api.py::up # for more details. if not success: @@ -213,7 +214,6 @@ def _get_host(): serve_state.set_service_controller_port(service_name, controller_port) - # TODO(tian): Support HTTPS. 
controller_addr = f'http://{controller_host}:{controller_port}' load_balancer_port = common_utils.find_free_port( @@ -227,7 +227,8 @@ def _get_host(): target=ux_utils.RedirectOutputForProcess( load_balancer.run_load_balancer, load_balancer_log_file).run, - args=(controller_addr, load_balancer_port)) + args=(controller_addr, load_balancer_port, + service_spec.tls_credential)) load_balancer_process.start() serve_state.set_service_load_balancer_port(service_name, load_balancer_port) diff --git a/sky/serve/service_spec.py b/sky/serve/service_spec.py index 2eff6f40a9d..d87d8771dce 100644 --- a/sky/serve/service_spec.py +++ b/sky/serve/service_spec.py @@ -7,6 +7,7 @@ import yaml from sky.serve import constants +from sky.serve import serve_utils from sky.utils import common_utils from sky.utils import schemas from sky.utils import ux_utils @@ -24,6 +25,7 @@ def __init__( max_replicas: Optional[int] = None, target_qps_per_replica: Optional[float] = None, post_data: Optional[Dict[str, Any]] = None, + tls_credential: Optional[serve_utils.TLSCredential] = None, readiness_headers: Optional[Dict[str, str]] = None, dynamic_ondemand_fallback: Optional[bool] = None, base_ondemand_fallback_replicas: Optional[int] = None, @@ -62,6 +64,8 @@ def __init__( self._max_replicas: Optional[int] = max_replicas self._target_qps_per_replica: Optional[float] = target_qps_per_replica self._post_data: Optional[Dict[str, Any]] = post_data + self._tls_credential: Optional[serve_utils.TLSCredential] = ( + tls_credential) self._readiness_headers: Optional[Dict[str, str]] = readiness_headers self._dynamic_ondemand_fallback: Optional[ bool] = dynamic_ondemand_fallback @@ -150,6 +154,13 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec': service_config['dynamic_ondemand_fallback'] = policy_section.get( 'dynamic_ondemand_fallback', None) + tls_section = config.get('tls', None) + if tls_section is not None: + service_config['tls_credential'] = serve_utils.TLSCredential( + 
keyfile=tls_section.get('keyfile', None), + certfile=tls_section.get('certfile', None), + ) + return SkyServiceSpec(**service_config) @staticmethod @@ -205,6 +216,9 @@ def add_if_not_none(section, key, value, no_empty: bool = False): self.upscale_delay_seconds) add_if_not_none('replica_policy', 'downscale_delay_seconds', self.downscale_delay_seconds) + if self.tls_credential is not None: + add_if_not_none('tls', 'keyfile', self.tls_credential.keyfile) + add_if_not_none('tls', 'certfile', self.tls_credential.certfile) return config def probe_str(self): @@ -249,12 +263,19 @@ def autoscaling_policy_str(self): f'replica{max_plural} (target QPS per replica: ' f'{self.target_qps_per_replica})') + def tls_str(self): + if self.tls_credential is None: + return 'No TLS Enabled' + return (f'Keyfile: {self.tls_credential.keyfile}, ' + f'Certfile: {self.tls_credential.certfile}') + def __repr__(self) -> str: return textwrap.dedent(f"""\ Readiness probe method: {self.probe_str()} Readiness initial delay seconds: {self.initial_delay_seconds} Readiness probe timeout seconds: {self.readiness_timeout_seconds} Replica autoscaling policy: {self.autoscaling_policy_str()} + TLS Certificates: {self.tls_str()} Spot Policy: {self.spot_policy_str()} """) @@ -287,6 +308,15 @@ def target_qps_per_replica(self) -> Optional[float]: def post_data(self) -> Optional[Dict[str, Any]]: return self._post_data + @property + def tls_credential(self) -> Optional[serve_utils.TLSCredential]: + return self._tls_credential + + @tls_credential.setter + def tls_credential(self, + value: Optional[serve_utils.TLSCredential]) -> None: + self._tls_credential = value + @property def readiness_headers(self) -> Optional[Dict[str, str]]: return self._readiness_headers diff --git a/sky/templates/sky-serve-controller.yaml.j2 b/sky/templates/sky-serve-controller.yaml.j2 index 507a6e3a325..dfdce1379b1 100644 --- a/sky/templates/sky-serve-controller.yaml.j2 +++ b/sky/templates/sky-serve-controller.yaml.j2 @@ -29,6 +29,10 @@ 
file_mounts: {%- for remote_catalog_path, local_catalog_path in modified_catalogs.items() %} {{remote_catalog_path}}: {{local_catalog_path}} {%- endfor %} +{%- if use_tls %} + {{remote_tls_keyfile}}: {{local_tls_keyfile}} + {{remote_tls_certfile}}: {{local_tls_certfile}} +{%- endif %} run: | # Activate the Python environment, so that cloud SDKs can be found in the diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index d9f105db8b0..d43a04d090a 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -362,6 +362,19 @@ def get_service_schema(): 'replicas': { 'type': 'integer', }, + 'tls': { + 'type': 'object', + 'required': ['keyfile', 'certfile'], + 'additionalProperties': False, + 'properties': { + 'keyfile': { + 'type': 'string', + }, + 'certfile': { + 'type': 'string', + }, + }, + }, } } diff --git a/tests/skyserve/cancel/send_cancel_request.py b/tests/skyserve/cancel/send_cancel_request.py index 48c2b2bec63..8ec4fcf524b 100644 --- a/tests/skyserve/cancel/send_cancel_request.py +++ b/tests/skyserve/cancel/send_cancel_request.py @@ -22,7 +22,7 @@ async def main(): timeout = 2 async with aiohttp.ClientSession() as session: - task = asyncio.create_task(fetch(session, f'http://{args.endpoint}/')) + task = asyncio.create_task(fetch(session, f'{args.endpoint}/')) await asyncio.sleep(timeout) # We manually cancel requests for test purposes. 
diff --git a/tests/skyserve/https/service.yaml b/tests/skyserve/https/service.yaml new file mode 100644 index 00000000000..5874b560017 --- /dev/null +++ b/tests/skyserve/https/service.yaml @@ -0,0 +1,19 @@ +envs: + TLS_KEYFILE_ENV_VAR: + TLS_CERTFILE_ENV_VAR: + +service: + readiness_probe: /health + replicas: 1 + tls: + keyfile: $TLS_KEYFILE_ENV_VAR + certfile: $TLS_CERTFILE_ENV_VAR + +resources: + ports: 8081 + cpus: 2+ + +workdir: examples/serve/http_server + +# Use 8081 to test jupyterhub service is terminated +run: python3 server.py --port 8081 diff --git a/tests/skyserve/llm/get_response.py b/tests/skyserve/llm/get_response.py index 9dd6ea53804..4fe7c55c3c0 100644 --- a/tests/skyserve/llm/get_response.py +++ b/tests/skyserve/llm/get_response.py @@ -19,7 +19,7 @@ }, ] - url = f'http://{args.endpoint}/v1/chat/completions' + url = f'{args.endpoint}/v1/chat/completions' resp = requests.post(url, json={ 'model': 'fastchat-t5-3b-v1.0', diff --git a/tests/skyserve/load_balancer/test_round_robin.py b/tests/skyserve/load_balancer/test_round_robin.py index 80763d1259c..564cfd5c665 100644 --- a/tests/skyserve/load_balancer/test_round_robin.py +++ b/tests/skyserve/load_balancer/test_round_robin.py @@ -16,7 +16,7 @@ replica_ips = [] for r in range(args.replica_num): - url = f'http://{args.endpoint}/get_ip' + url = f'{args.endpoint}/get_ip' resp = requests.get(url) assert resp.status_code == 200, resp.text assert 'ip' in resp.json(), resp.json() @@ -29,7 +29,7 @@ for i in range(_REPEAT): for r in range(args.replica_num): - url = f'http://{args.endpoint}/get_ip' + url = f'{args.endpoint}/get_ip' resp = requests.get(url) assert resp.status_code == 200, resp.text assert 'ip' in resp.json(), resp.json() diff --git a/tests/skyserve/streaming/send_streaming_request.py b/tests/skyserve/streaming/send_streaming_request.py index 7c56d929761..b7f09922bef 100644 --- a/tests/skyserve/streaming/send_streaming_request.py +++ b/tests/skyserve/streaming/send_streaming_request.py @@ -8,7 
+8,7 @@ parser = argparse.ArgumentParser() parser.add_argument('--endpoint', type=str, required=True) args = parser.parse_args() -url = f'http://{args.endpoint}/' +url = f'{args.endpoint}/' expected = WORD_TO_STREAM.split() index = 0 diff --git a/tests/test_smoke.py b/tests/test_smoke.py index ed86f93ca27..daadac6de06 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -3684,7 +3684,7 @@ def _get_skyserve_http_test(name: str, cloud: str, f'sky serve up -n {name} -y tests/skyserve/http/{cloud}.yaml', _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'curl http://$endpoint | grep "Hi, SkyPilot here"', + 'curl $endpoint | grep "Hi, SkyPilot here"', ], _TEARDOWN_SERVICE.format(name=name), timeout=timeout_minutes * 60, @@ -3807,11 +3807,11 @@ def test_skyserve_spot_recovery(): f'sky serve up -n {name} -y tests/skyserve/spot/recovery.yaml', _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', + 'request_output=$(curl $endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', _terminate_gcp_replica(name, zone, 1), _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', + 'request_output=$(curl $endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', ], _TEARDOWN_SERVICE.format(name=name), timeout=20 * 60, @@ -3909,7 +3909,7 @@ def test_skyserve_user_bug_restart(generic_cloud: str): f'echo "$s" | grep -B 100 "NO_REPLICA" | grep "0/0"', f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/auto_restart.yaml', f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'until curl http://$endpoint | grep "Hi, 
SkyPilot here!"; do sleep 2; done; sleep 2; ' + 'until curl $endpoint | grep "Hi, SkyPilot here!"; do sleep 2; done; sleep 2; ' + _check_replica_in_status(name, [(1, False, 'READY'), (1, False, 'FAILED')]), ], @@ -3957,7 +3957,7 @@ def test_skyserve_auto_restart(): f'sky serve up -n {name} -y tests/skyserve/auto_restart.yaml', _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', + 'request_output=$(curl $endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', # sleep for 20 seconds (initial delay) to make sure it will # be restarted f'sleep 20', @@ -3977,7 +3977,7 @@ def test_skyserve_auto_restart(): ' sleep 10;' f'done); sleep {serve.LB_CONTROLLER_SYNC_INTERVAL_SECONDS};', f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', + 'request_output=$(curl $endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', ], _TEARDOWN_SERVICE.format(name=name), timeout=20 * 60, @@ -4062,7 +4062,7 @@ def test_skyserve_large_readiness_timeout(generic_cloud: str): f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/readiness_timeout/task_large_timeout.yaml', _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'request_output=$(curl http://$endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', + 'request_output=$(curl $endpoint); echo "$request_output"; echo "$request_output" | grep "Hi, SkyPilot here"', ], _TEARDOWN_SERVICE.format(name=name), timeout=20 * 60, @@ -4081,14 +4081,14 @@ def test_skyserve_update(generic_cloud: str): [ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/update/old.yaml', 
_SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2), - f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"', + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl $endpoint | grep "Hi, SkyPilot here"', f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/new.yaml', # sleep before update is registered. 'sleep 20', f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'until curl http://$endpoint | grep "Hi, new SkyPilot here!"; do sleep 2; done;' + 'until curl $endpoint | grep "Hi, new SkyPilot here!"; do sleep 2; done;' # Make sure the traffic is not mixed - 'curl http://$endpoint | grep "Hi, new SkyPilot here"', + 'curl $endpoint | grep "Hi, new SkyPilot here"', # The latest 2 version should be READY and the older versions should be shutting down (_check_replica_in_status(name, [(2, False, 'READY'), (2, False, 'SHUTTING_DOWN')]) + @@ -4114,14 +4114,14 @@ def test_skyserve_rolling_update(generic_cloud: str): [ f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/update/old.yaml', _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2), - f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"', + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl $endpoint | grep "Hi, SkyPilot here"', f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/new.yaml', # Make sure the traffic is mixed across two versions, the replicas # with even id will sleep 60 seconds before being ready, so we # should be able to get observe the period that the traffic is mixed # across two versions. 
f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'until curl http://$endpoint | grep "Hi, new SkyPilot here!"; do sleep 2; done; sleep 2; ' + 'until curl $endpoint | grep "Hi, new SkyPilot here!"; do sleep 2; done; sleep 2; ' # The latest version should have one READY and the one of the older versions should be shutting down f'{single_new_replica} {_check_service_version(name, "1,2")} ' # Check the output from the old version, immediately after the @@ -4130,7 +4130,7 @@ def test_skyserve_rolling_update(generic_cloud: str): # TODO(zhwu): we should have a more generalized way for checking the # mixed version of replicas to avoid depending on the specific # round robin load balancing policy. - 'curl http://$endpoint | grep "Hi, SkyPilot here"', + 'curl $endpoint | grep "Hi, SkyPilot here"', ], _TEARDOWN_SERVICE.format(name=name), timeout=20 * 60, @@ -4149,7 +4149,7 @@ def test_skyserve_fast_update(generic_cloud: str): [ f'sky serve up -n {name} -y --cloud {generic_cloud} tests/skyserve/update/bump_version_before.yaml', _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2), - f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"', + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl $endpoint | grep "Hi, SkyPilot here"', f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/bump_version_after.yaml', # sleep to wait for update to be registered. 
'sleep 40', @@ -4162,7 +4162,7 @@ def test_skyserve_fast_update(generic_cloud: str): _check_service_version(name, "2")), _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=3) + _check_service_version(name, "2"), - f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"', + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl $endpoint | grep "Hi, SkyPilot here"', # Test rolling update f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/bump_version_before.yaml', # sleep to wait for update to be registered. @@ -4172,7 +4172,7 @@ def test_skyserve_fast_update(generic_cloud: str): (1, False, 'SHUTTING_DOWN')]), _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) + _check_service_version(name, "3"), - f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl http://$endpoint | grep "Hi, SkyPilot here"', + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; curl $endpoint | grep "Hi, SkyPilot here"', ], _TEARDOWN_SERVICE.format(name=name), timeout=30 * 60, @@ -4191,7 +4191,7 @@ def test_skyserve_update_autoscale(generic_cloud: str): _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) + _check_service_version(name, "1"), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'curl http://$endpoint | grep "Hi, SkyPilot here"', + 'curl $endpoint | grep "Hi, SkyPilot here"', f'sky serve update {name} --cloud {generic_cloud} --mode blue_green -y tests/skyserve/update/num_min_one.yaml', # sleep before update is registered. 'sleep 20', @@ -4199,7 +4199,7 @@ def test_skyserve_update_autoscale(generic_cloud: str): _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1) + _check_service_version(name, "2"), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'curl http://$endpoint | grep "Hi, SkyPilot here!"', + 'curl $endpoint | grep "Hi, SkyPilot here!"', # Rolling Update f'sky serve update {name} --cloud {generic_cloud} -y tests/skyserve/update/num_min_two.yaml', # sleep before update is registered. 
@@ -4208,7 +4208,7 @@ def test_skyserve_update_autoscale(generic_cloud: str): _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) + _check_service_version(name, "3"), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'curl http://$endpoint | grep "Hi, SkyPilot here!"', + 'curl $endpoint | grep "Hi, SkyPilot here!"', ], _TEARDOWN_SERVICE.format(name=name), timeout=30 * 60, @@ -4258,7 +4258,7 @@ def test_skyserve_new_autoscaler_update(mode: str, generic_cloud: str): _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=2) + _check_service_version(name, "1"), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 's=$(curl http://$endpoint); echo "$s"; echo "$s" | grep "Hi, SkyPilot here"', + 's=$(curl $endpoint); echo "$s"; echo "$s" | grep "Hi, SkyPilot here"', f'sky serve update {name} --cloud {generic_cloud} --mode {mode} -y tests/skyserve/update/new_autoscaler_after.yaml', # Wait for update to be registered f'sleep 90', @@ -4270,7 +4270,7 @@ def test_skyserve_new_autoscaler_update(mode: str, generic_cloud: str): *update_check, _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=5), f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' - 'curl http://$endpoint | grep "Hi, SkyPilot here"', + 'curl $endpoint | grep "Hi, SkyPilot here"', _check_replica_in_status(name, [(4, True, 'READY'), (1, False, 'READY')]), ], @@ -4328,6 +4328,40 @@ def test_skyserve_failures(generic_cloud: str): run_one_test(test) +@pytest.mark.serve +def test_skyserve_https(generic_cloud: str): + """Test skyserve with https""" + name = _get_service_name() + + with tempfile.TemporaryDirectory() as tempdir: + keyfile = os.path.join(tempdir, 'key.pem') + certfile = os.path.join(tempdir, 'cert.pem') + subprocess_utils.run_no_outputs( + f'openssl req -x509 -newkey rsa:2048 -days 36500 -nodes ' + f'-subj "/" -keyout {keyfile} -out {certfile}') + + test = Test( + f'test-skyserve-https', + [ + f'sky serve up -n {name} --cloud {generic_cloud} -y tests/skyserve/https/service.yaml ' + f'--env 
TLS_KEYFILE_ENV_VAR={keyfile} --env TLS_CERTFILE_ENV_VAR={certfile}', + _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1), + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' + 'curl $endpoint -k | grep "Hi, SkyPilot here"', + # Self-signed certificate should fail without -k. + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' + 'curl $endpoint 2>&1 | grep "self signed certificate"', + # curl with wrong scheme (http) should fail. + f'{_SERVE_ENDPOINT_WAIT.format(name=name)}; ' + 'http_endpoint="${endpoint/https:/http:}"; ' + 'curl $http_endpoint 2>&1 | grep "Empty reply from server"', + ], + _TEARDOWN_SERVICE.format(name=name), + timeout=20 * 60, + ) + run_one_test(test) + + # TODO(Ziming, Tian): Add tests for autoscaling.