From fb83d398f63f312d1586d33ebd82fa2e060e8ce8 Mon Sep 17 00:00:00 2001 From: cblmemo Date: Thu, 14 Nov 2024 14:32:00 -0800 Subject: [PATCH 1/9] [WIP][Serve] Enable launching multiple external LB on controller. --- examples/serve/external-lb.yaml | 31 ++++ sky/serve/constants.py | 2 + sky/serve/core.py | 45 +++++- sky/serve/serve_state.py | 67 ++++++++ sky/serve/serve_utils.py | 84 +++++++++- sky/serve/service.py | 151 +++++++++++++++--- sky/serve/service_spec.py | 41 ++++- .../sky-serve-external-load-balancer.yaml.j2 | 23 +++ sky/utils/controller_utils.py | 32 ++-- sky/utils/schemas.py | 21 ++- 10 files changed, 434 insertions(+), 63 deletions(-) create mode 100644 examples/serve/external-lb.yaml create mode 100644 sky/templates/sky-serve-external-load-balancer.yaml.j2 diff --git a/examples/serve/external-lb.yaml b/examples/serve/external-lb.yaml new file mode 100644 index 00000000000..e96d5e0fda0 --- /dev/null +++ b/examples/serve/external-lb.yaml @@ -0,0 +1,31 @@ +# SkyServe YAML to run multiple Load Balancers in different region. + +name: multi-lb + +service: + readiness_probe: + path: /health + initial_delay_seconds: 20 + replicas: 2 + external_load_balancers: + - resources: + # cloud: aws + # region: us-east-2 + cloud: gcp + region: us-east1 + load_balancing_policy: round_robin + - resources: + # cloud: aws + # region: ap-northeast-1 + cloud: gcp + region: asia-northeast1 + load_balancing_policy: round_robin + +resources: + cloud: aws + ports: 8080 + cpus: 2+ + +workdir: examples/serve/http_server + +run: python3 server.py diff --git a/sky/serve/constants.py b/sky/serve/constants.py index 3974293190e..813aa0d6d0e 100644 --- a/sky/serve/constants.py +++ b/sky/serve/constants.py @@ -1,6 +1,7 @@ """Constants used for SkyServe.""" CONTROLLER_TEMPLATE = 'sky-serve-controller.yaml.j2' +EXTERNAL_LB_TEMPLATE = 'sky-serve-external-load-balancer.yaml.j2' SKYSERVE_METADATA_DIR = '~/.sky/serve' @@ -79,6 +80,7 @@ # Default port range start for controller and load balancer. 
Ports will be # automatically generated from this start port. CONTROLLER_PORT_START = 20001 +CONTROLLER_PORT_RANGE = '20001-20020' LOAD_BALANCER_PORT_START = 30001 LOAD_BALANCER_PORT_RANGE = '30001-30020' diff --git a/sky/serve/core.py b/sky/serve/core.py index abf9bfbc719..a006500679f 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -174,12 +174,21 @@ def up( vars_to_fill, output_path=controller_file.name) controller_task = task_lib.Task.from_yaml(controller_file.name) + # TODO(tian): Currently we exposed the controller port to the public + # network, for external load balancer to access. We should implement + # encrypted communication between controller and load balancer, and + # not expose the controller to the public network. + assert task.service is not None + ports_to_open_in_controller = (serve_constants.CONTROLLER_PORT_RANGE + if task.service.external_load_balancers + is not None else + serve_constants.LOAD_BALANCER_PORT_RANGE) # TODO(tian): Probably run another sky.launch after we get the load # balancer port from the controller? So we don't need to open so many # ports here. Or, we should have a nginx traffic control to refuse # any connection to the unregistered ports. controller_resources = { - r.copy(ports=[serve_constants.LOAD_BALANCER_PORT_RANGE]) + r.copy(ports=[ports_to_open_in_controller]) for r in controller_resources } controller_task.set_resources(controller_resources) @@ -267,12 +276,18 @@ def up( 'Failed to spin up the service. Please ' 'check the logs above for more details.') from None else: - lb_port = serve_utils.load_service_initialization_result( - lb_port_payload) - endpoint = backend_utils.get_endpoints( - controller_handle.cluster_name, lb_port, - skip_status_check=True).get(lb_port) - assert endpoint is not None, 'Did not get endpoint for controller.' 
+ if task.service.external_load_balancers is None: + lb_port = serve_utils.load_service_initialization_result( + lb_port_payload) + endpoint = backend_utils.get_endpoints( + controller_handle.cluster_name, + lb_port, + skip_status_check=True).get(lb_port) + assert endpoint is not None, ( + 'Did not get endpoint for controller.') + else: + endpoint = ( + 'Please query with sky serve status for the endpoint.') sky_logging.print( f'{fore.CYAN}Service name: ' @@ -320,6 +335,7 @@ def update( task: sky.Task to update. service_name: Name of the service. """ + # TODO(tian): Implement update of external LBs. _validate_service_task(task) # Always apply the policy again here, even though it might have been applied # in the CLI. This is to ensure that we apply the policy to the final DAG @@ -585,6 +601,8 @@ def status( 'requested_resources_str': (str) str representation of requested resources, 'replica_info': (List[Dict[str, Any]]) replica information, + 'external_lb_info': (Dict[str, Any]) external load balancer + information, } Each entry in replica_info has the following fields: @@ -600,6 +618,17 @@ def status( 'handle': (ResourceHandle) handle of the replica cluster, } + Each entry in external_lb_info has the following fields: + + .. code-block:: python + + { + 'lb_id': (int) index of the external load balancer, + 'cluster_name': (str) cluster name of the external load balancer, + 'port': (int) port of the external load balancer, + 'endpoint': (str) endpoint of the external load balancer, + } + For possible service statuses and replica statuses, please refer to sky.cli.serve_status. @@ -695,6 +724,8 @@ def tail_logs( sky.exceptions.ClusterNotUpError: the sky serve controller is not up. ValueError: arguments not valid, or failed to tail the logs. """ + # TODO(tian): Support tail logs for external load balancer. It should be + # similar to tail replica logs. 
if isinstance(target, str): target = serve_utils.ServiceComponent(target) if not isinstance(target, serve_utils.ServiceComponent): diff --git a/sky/serve/serve_state.py b/sky/serve/serve_state.py index 333e0138fb4..ed8c06e5459 100644 --- a/sky/serve/serve_state.py +++ b/sky/serve/serve_state.py @@ -12,6 +12,7 @@ from sky.serve import constants from sky.utils import db_utils +from sky import exceptions if typing.TYPE_CHECKING: from sky.serve import replica_managers @@ -58,6 +59,13 @@ def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None: service_name TEXT, spec BLOB, PRIMARY KEY (service_name, version))""") + cursor.execute("""\ + CREATE TABLE IF NOT EXISTS external_load_balancers ( + lb_id INTEGER, + service_name TEXT, + cluster_name TEXT, + port INTEGER, + PRIMARY KEY (service_name, lb_id))""") conn.commit() @@ -538,3 +546,62 @@ def delete_all_versions(service_name: str) -> None: """\ DELETE FROM version_specs WHERE service_name=(?)""", (service_name,)) + + +# === External Load Balancer functions === +# TODO(tian): Add a status column. +def add_external_load_balancer(service_name: str, lb_id: int, cluster_name: str, + port: int) -> None: + """Adds an external load balancer to the database.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + cursor.execute( + """\ + INSERT INTO external_load_balancers + (service_name, lb_id, cluster_name, port) + VALUES (?, ?, ?, ?)""", (service_name, lb_id, cluster_name, port)) + + +def _get_external_load_balancer_from_row(row) -> Dict[str, Any]: + from sky import core # pylint: disable=import-outside-toplevel + + # TODO(tian): Temporary workaround to avoid circular import. + # This should be fixed. + lb_id, cluster_name, port = row[:3] + try: + endpoint = core.endpoints(cluster_name, port)[port] + except exceptions.ClusterNotUpError: + # TODO(tian): Currently, when this cluster is not in the UP status, + # the endpoint query will raise an cluster is not up error. 
We should + # implement a status for external lbs as well and returns a '-' when + # it is still provisioning. + endpoint = '-' + return { + 'lb_id': lb_id, + 'cluster_name': cluster_name, + 'port': port, + 'endpoint': endpoint, + } + + +def get_external_load_balancers(service_name: str) -> List[Dict[str, Any]]: + """Gets all external load balancers of a service.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + rows = cursor.execute( + """\ + SELECT lb_id, cluster_name, port FROM external_load_balancers + WHERE service_name=(?)""", (service_name,)).fetchall() + external_load_balancers = [] + for row in rows: + external_load_balancers.append( + _get_external_load_balancer_from_row(row)) + return external_load_balancers + + +def remove_external_load_balancer(service_name: str, lb_id: int) -> None: + """Removes an external load balancer from the database.""" + with db_utils.safe_cursor(_DB_PATH) as cursor: + cursor.execute( + """\ + DELETE FROM external_load_balancers + WHERE service_name=(?) + AND lb_id=(?)""", (service_name, lb_id)) diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 3be41cc1593..bd500432467 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -56,6 +56,8 @@ # Max number of replicas to show in `sky serve status` by default. # If user wants to see all replicas, use `sky serve status --all`. _REPLICA_TRUNC_NUM = 10 +# Similar to _REPLICA_TRUNC_NUM, but for external load balancers. +_EXTERNAL_LB_TRUNC_NUM = 10 class ServiceComponent(enum.Enum): @@ -224,6 +226,13 @@ def generate_remote_load_balancer_log_file_name(service_name: str) -> str: return os.path.join(dir_name, 'load_balancer.log') +def generate_remote_external_load_balancer_log_file_name( + service_name: str, lb_id: int) -> str: + dir_name = generate_remote_service_dir_name(service_name) + # Don't expand here since it is used for remote machine. 
+ return os.path.join(dir_name, f'external_load_balancer_{lb_id}.log') + + def generate_replica_launch_log_file_name(service_name: str, replica_id: int) -> str: dir_name = generate_remote_service_dir_name(service_name) @@ -354,7 +363,8 @@ def terminate_replica(service_name: str, replica_id: int, purge: bool) -> str: def _get_service_status( service_name: str, - with_replica_info: bool = True) -> Optional[Dict[str, Any]]: + with_replica_info: bool = True, + with_external_lb_info: bool = True) -> Optional[Dict[str, Any]]: """Get the status dict of the service. Args: @@ -373,6 +383,9 @@ def _get_service_status( info.to_info_dict(with_handle=True) for info in serve_state.get_replica_infos(service_name) ] + if with_external_lb_info: + record['external_lb_info'] = serve_state.get_external_load_balancers( + service_name) return record @@ -457,7 +470,8 @@ def terminate_services(service_names: Optional[List[str]], purge: bool) -> str: messages = [] for service_name in service_names: service_status = _get_service_status(service_name, - with_replica_info=False) + with_replica_info=False, + with_external_lb_info=False) if (service_status is not None and service_status['status'] == serve_state.ServiceStatus.SHUTTING_DOWN): # Already scheduled to be terminated. 
@@ -810,10 +824,14 @@ def format_service_table(service_records: List[Dict[str, Any]], service_table = log_utils.create_table(service_columns) replica_infos = [] + external_lb_infos = [] for record in service_records: for replica in record['replica_info']: replica['service_name'] = record['name'] replica_infos.append(replica) + for external_lb in record['external_lb_info']: + external_lb['service_name'] = record['name'] + external_lb_infos.append(external_lb) service_name = record['name'] version = ','.join( @@ -824,7 +842,12 @@ def format_service_table(service_records: List[Dict[str, Any]], service_status = record['status'] status_str = service_status.colored_str() replicas = _get_replicas(record) - endpoint = get_endpoint(record) + if record['external_lb_info']: + # Don't show endpoint for services with external load balancers. + # TODO(tian): Add automatic DNS record creation and show domain here + endpoint = '-' + else: + endpoint = get_endpoint(record) policy = record['policy'] requested_resources_str = record['requested_resources_str'] @@ -841,10 +864,20 @@ def format_service_table(service_records: List[Dict[str, Any]], service_table.add_row(service_values) replica_table = _format_replica_table(replica_infos, show_all) - return (f'{service_table}\n' - f'\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}' - f'Service Replicas{colorama.Style.RESET_ALL}\n' - f'{replica_table}') + + final_table = (f'{service_table}\n' + f'\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}' + f'Service Replicas{colorama.Style.RESET_ALL}\n' + f'{replica_table}') + + if external_lb_infos: + external_lb_table = _format_external_lb_table(external_lb_infos, + show_all) + final_table += (f'\n\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}' + f'External Load Balancers{colorama.Style.RESET_ALL}\n' + f'{external_lb_table}') + + return final_table def _format_replica_table(replica_records: List[Dict[str, Any]], @@ -905,6 +938,43 @@ def _format_replica_table(replica_records: List[Dict[str, Any]], return 
f'{replica_table}{truncate_hint}' +def _format_external_lb_table(external_lb_records: List[Dict[str, Any]], + show_all: bool) -> str: + if not external_lb_records: + return 'No existing external load balancers.' + + external_lb_columns = ['SERVICE_NAME', 'ID', 'ENDPOINT'] + if show_all: + external_lb_columns.append('PORT') + external_lb_columns.append('CLUSTER_NAME') + external_lb_table = log_utils.create_table(external_lb_columns) + + truncate_hint = '' + if not show_all: + if len(external_lb_records) > _EXTERNAL_LB_TRUNC_NUM: + truncate_hint = ( + '\n... (use --all to show all external load balancers)') + external_lb_records = external_lb_records[:_EXTERNAL_LB_TRUNC_NUM] + + for record in external_lb_records: + service_name = record['service_name'] + external_lb_id = record['lb_id'] + endpoint = record['endpoint'] + port = record['port'] + cluster_name = record['cluster_name'] + + external_lb_values = [ + service_name, + external_lb_id, + endpoint, + ] + if show_all: + external_lb_values.extend([port, cluster_name]) + external_lb_table.add_row(external_lb_values) + + return f'{external_lb_table}{truncate_hint}' + + # =========================== CodeGen for Sky Serve =========================== diff --git a/sky/serve/service.py b/sky/serve/service.py index 0a1c7f34766..999d82cb3b8 100644 --- a/sky/serve/service.py +++ b/sky/serve/service.py @@ -7,25 +7,30 @@ import os import pathlib import shutil +import subprocess +import tempfile import time import traceback -from typing import Dict +from typing import Any, Dict import filelock from sky import authentication from sky import exceptions +from sky import resources as resources_lib from sky import sky_logging from sky import task as task_lib from sky.backends import backend_utils from sky.backends import cloud_vm_ray_backend from sky.serve import constants +from sky.skylet import constants as skylet_constants from sky.serve import controller from sky.serve import load_balancer from sky.serve import replica_managers 
from sky.serve import serve_state from sky.serve import serve_utils from sky.utils import common_utils +from sky.utils import controller_utils from sky.utils import subprocess_utils from sky.utils import ux_utils @@ -89,6 +94,8 @@ def _cleanup(service_name: str) -> bool: replica_infos = serve_state.get_replica_infos(service_name) info2proc: Dict[replica_managers.ReplicaInfo, multiprocessing.Process] = dict() + external_lbs = serve_state.get_external_load_balancers(service_name) + lbid2proc: Dict[int, multiprocessing.Process] = dict() for info in replica_infos: p = multiprocessing.Process(target=replica_managers.terminate_cluster, args=(info.cluster_name,)) @@ -101,6 +108,14 @@ def _cleanup(service_name: str) -> bool: replica_managers.ProcessStatus.RUNNING) serve_state.add_or_update_replica(service_name, info.replica_id, info) logger.info(f'Terminating replica {info.replica_id} ...') + for external_lb_record in external_lbs: + lb_cluster_name = external_lb_record['cluster_name'] + lb_id = external_lb_record['lb_id'] + p = multiprocessing.Process(target=replica_managers.terminate_cluster, + args=(lb_cluster_name,)) + p.start() + lbid2proc[lb_id] = p + logger.info(f'Terminating external load balancer {lb_cluster_name} ...') for info, p in info2proc.items(): p.join() if p.exitcode == 0: @@ -114,6 +129,15 @@ def _cleanup(service_name: str) -> bool: info) failed = True logger.error(f'Replica {info.replica_id} failed to terminate.') + for lb_id, p in lbid2proc.items(): + p.join() + if p.exitcode == 0: + serve_state.remove_external_load_balancer(service_name, lb_id) + logger.info( + f'External load balancer {lb_id} terminated successfully.') + else: + failed = True + logger.error(f'External load balancer {lb_id} failed to terminate.') versions = serve_state.get_service_versions(service_name) serve_state.remove_service_versions(service_name) @@ -130,6 +154,50 @@ def cleanup_version_storage(version: int) -> bool: return failed +def _get_external_lb_cluster_name(service_name: 
str, lb_id: int) -> str: + return f'sky-{service_name}-lb-{lb_id}' + + +def _start_external_load_balancer(service_name: str, controller_addr: str, + lb_id: int, lb_port: int, lb_policy: str, + lb_resources: Dict[str, Any]) -> None: + # TODO(tian): Hack. We should figure out the optimal resoruces. + if 'cpus' not in lb_resources: + lb_resources['cpus'] = '2+' + # Already checked in service spec validation. + assert 'ports' not in lb_resources + lb_resources['ports'] = [lb_port] + lbr = resources_lib.Resources.from_yaml_config(lb_resources) + lb_cluster_name = _get_external_lb_cluster_name(service_name, lb_id) + # TODO(tian): Set delete=False to debug. Remove this on production. + with tempfile.NamedTemporaryFile(prefix=lb_cluster_name, + mode='w', + delete=False) as f: + vars_to_fill = { + 'load_balancer_port': lb_port, + 'controller_addr': controller_addr, + 'load_balancing_policy': lb_policy, + 'sky_activate_python_env': skylet_constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV, + 'lb_envs': controller_utils.sky_managed_cluster_envs(), + } + common_utils.fill_template(constants.EXTERNAL_LB_TEMPLATE, + vars_to_fill, + output_path=f.name) + lb_task = task_lib.Task.from_yaml(f.name) + lb_task.set_resources(lbr) + serve_state.add_external_load_balancer(service_name, lb_id, + lb_cluster_name, lb_port) + # TODO(tian): Temporary solution for circular import. We should move + # the import to the top of the file. + import sky # pylint: disable=import-outside-toplevel + sky.launch( + task=lb_task, + stream_logs=True, + cluster_name=lb_cluster_name, + retry_until_up=True, + ) + + def _start(service_name: str, tmp_task_yaml: str, job_id: int): """Starts the service.""" # Generate ssh key pair to avoid race condition when multiple sky.launch @@ -177,12 +245,7 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int): service_name, constants.INITIAL_VERSION) shutil.copy(tmp_task_yaml, task_yaml) - # Generate load balancer log file name. 
- load_balancer_log_file = os.path.expanduser( - serve_utils.generate_remote_load_balancer_log_file_name(service_name)) - controller_process = None - load_balancer_process = None try: with filelock.FileLock( os.path.expanduser(constants.PORT_SELECTION_FILE_LOCK_PATH)): @@ -202,6 +265,12 @@ def _get_host(): # ('::1', 20001, 0, 0): cannot assign requested address return '127.0.0.1' + def _get_external_host(): + assert service_spec.external_load_balancers is not None + # TODO(tian): Use a more robust way to get the host. + return subprocess.check_output( + 'curl ifconfig.me', shell=True).decode('utf-8').strip() + controller_host = _get_host() # Start the controller. @@ -215,25 +284,55 @@ def _get_host(): # TODO(tian): Support HTTPS. controller_addr = f'http://{controller_host}:{controller_port}' - - load_balancer_port = common_utils.find_free_port( - constants.LOAD_BALANCER_PORT_START) - - # Extract the load balancing policy from the service spec - policy_name = service_spec.load_balancing_policy - - # Start the load balancer. - # TODO(tian): Probably we could enable multiple ports specified in - # service spec and we could start multiple load balancers. - # After that, we will have a mapping from replica port to endpoint. - load_balancer_process = multiprocessing.Process( - target=ux_utils.RedirectOutputForProcess( - load_balancer.run_load_balancer, - load_balancer_log_file).run, - args=(controller_addr, load_balancer_port, policy_name)) - load_balancer_process.start() - serve_state.set_service_load_balancer_port(service_name, - load_balancer_port) + load_balancer_processes = [] + + if service_spec.external_load_balancers is None: + # Generate load balancer log file name. 
+ load_balancer_log_file = os.path.expanduser( + serve_utils.generate_remote_load_balancer_log_file_name( + service_name)) + + load_balancer_port = common_utils.find_free_port( + constants.LOAD_BALANCER_PORT_START) + + # Extract the load balancing policy from the service spec + policy_name = service_spec.load_balancing_policy + + # Start the load balancer. + # TODO(tian): Probably we could enable multiple ports specified + # in service spec and we could start multiple load balancers. + # After that, we need a mapping from replica port to endpoint. + load_balancer_process = multiprocessing.Process( + target=ux_utils.RedirectOutputForProcess( + load_balancer.run_load_balancer, + load_balancer_log_file).run, + args=(controller_addr, load_balancer_port, policy_name)) + load_balancer_process.start() + load_balancer_processes.append(load_balancer_process) + serve_state.set_service_load_balancer_port( + service_name, load_balancer_port) + else: + lb_port = 8000 + for lb_id, lb_config in enumerate( + service_spec.external_load_balancers): + # Generate load balancer log file name. + load_balancer_log_file = os.path.expanduser( + serve_utils. + generate_remote_external_load_balancer_log_file_name( + service_name, lb_id)) + lb_policy = lb_config['load_balancing_policy'] + lb_resources = lb_config['resources'] + controller_external_addr = ( + f'http://{_get_external_host()}:{controller_port}') + lb_process = multiprocessing.Process( + target=ux_utils.RedirectOutputForProcess( + _start_external_load_balancer, + load_balancer_log_file).run, + args=(service_name, controller_external_addr, lb_id, + lb_port, lb_policy, lb_resources)) + lb_process.start() + load_balancer_processes.append(lb_process) + serve_state.set_service_load_balancer_port(service_name, -1) while True: _handle_signal(service_name) @@ -245,7 +344,7 @@ def _get_host(): # Kill load balancer process first since it will raise errors if failed # to connect to the controller. Then the controller process. 
process_to_kill = [ - proc for proc in [load_balancer_process, controller_process] + proc for proc in [*load_balancer_processes, controller_process] if proc is not None ] subprocess_utils.kill_children_processes( diff --git a/sky/serve/service_spec.py b/sky/serve/service_spec.py index 000eed139f1..8c888eb3dc2 100644 --- a/sky/serve/service_spec.py +++ b/sky/serve/service_spec.py @@ -2,11 +2,11 @@ import json import os import textwrap -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import yaml -from sky import serve +from sky import resources as resources_lib from sky.serve import constants from sky.utils import common_utils from sky.utils import schemas @@ -31,6 +31,7 @@ def __init__( upscale_delay_seconds: Optional[int] = None, downscale_delay_seconds: Optional[int] = None, load_balancing_policy: Optional[str] = None, + external_load_balancers: Optional[List[Dict[str, Any]]] = None, ) -> None: if max_replicas is not None and max_replicas < min_replicas: with ux_utils.print_exception_no_traceback(): @@ -57,13 +58,27 @@ def __init__( raise ValueError('readiness_path must start with a slash (/). ' f'Got: {readiness_path}') - # Add the check for unknown load balancing policies + # Use load_balancing_policy as fallback for external_load_balancers if (load_balancing_policy is not None and - load_balancing_policy not in serve.LB_POLICIES): - with ux_utils.print_exception_no_traceback(): - raise ValueError( - f'Unknown load balancing policy: {load_balancing_policy}. 
' - f'Available policies: {list(serve.LB_POLICIES.keys())}') + external_load_balancers is not None): + for lb_config in external_load_balancers: + if lb_config.get('load_balancing_policy') is None: + lb_config['load_balancing_policy'] = load_balancing_policy + + if external_load_balancers is not None: + for lb_config in external_load_balancers: + r = lb_config.get('resources') + if r is None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('`resources` must be set for ' + 'external_load_balancers.') + if 'ports' in r: + with ux_utils.print_exception_no_traceback(): + raise ValueError('`ports` must not be set for ' + 'external_load_balancers.') + # Validate resources + resources_lib.Resources.from_yaml_config(r) + self._readiness_path: str = readiness_path self._initial_delay_seconds: int = initial_delay_seconds self._readiness_timeout_seconds: int = readiness_timeout_seconds @@ -79,6 +94,8 @@ def __init__( self._upscale_delay_seconds: Optional[int] = upscale_delay_seconds self._downscale_delay_seconds: Optional[int] = downscale_delay_seconds self._load_balancing_policy: Optional[str] = load_balancing_policy + self._external_load_balancers: Optional[List[Dict[str, Any]]] = ( + external_load_balancers) self._use_ondemand_fallback: bool = ( self.dynamic_ondemand_fallback is not None and @@ -162,6 +179,8 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec': service_config['load_balancing_policy'] = config.get( 'load_balancing_policy', None) + service_config['external_load_balancers'] = config.get( + 'external_load_balancers', None) return SkyServiceSpec(**service_config) @staticmethod @@ -219,6 +238,8 @@ def add_if_not_none(section, key, value, no_empty: bool = False): self.downscale_delay_seconds) add_if_not_none('load_balancing_policy', None, self._load_balancing_policy) + add_if_not_none('external_load_balancers', None, + self._external_load_balancers) return config def probe_str(self): @@ -329,3 +350,7 @@ def 
use_ondemand_fallback(self) -> bool: @property def load_balancing_policy(self) -> Optional[str]: return self._load_balancing_policy + + @property + def external_load_balancers(self) -> Optional[List[Dict[str, Any]]]: + return self._external_load_balancers diff --git a/sky/templates/sky-serve-external-load-balancer.yaml.j2 b/sky/templates/sky-serve-external-load-balancer.yaml.j2 new file mode 100644 index 00000000000..0c196d6ad3e --- /dev/null +++ b/sky/templates/sky-serve-external-load-balancer.yaml.j2 @@ -0,0 +1,23 @@ +# The template for the sky serve load balancers + +name: load-balancer + +setup: | + {{ sky_activate_python_env }} + # Install serve dependencies. + # TODO(tian): Gather those into serve constants. + pip list | grep uvicorn > /dev/null 2>&1 || pip install uvicorn > /dev/null 2>&1 + pip list | grep fastapi > /dev/null 2>&1 || pip install fastapi > /dev/null 2>&1 + pip list | grep httpx > /dev/null 2>&1 || pip install httpx > /dev/null 2>&1 + +run: | + {{ sky_activate_python_env }} + python -u -m sky.serve.load_balancer \ + --controller-addr {{controller_addr}} \ + --load-balancer-port {{load_balancer_port}} \ + --load-balancing-policy {{load_balancing_policy}} + +envs: +{%- for env_name, env_value in lb_envs.items() %} + {{env_name}}: {{env_value}} +{%- endfor %} diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 0ab2fd7e117..be0db2ddaa9 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -391,6 +391,24 @@ def download_and_stream_latest_job_log( return log_file +# TODO(tian): Maybe move this to other places? +def sky_managed_cluster_envs() -> Dict[str, str]: + env_vars: Dict[str, str] = { + env.env_key: str(int(env.get())) for env in env_options.Options + } + env_vars.update({ + # Should not use $USER here, as that env var can be empty when + # running in a container. 
+ constants.USER_ENV_VAR: getpass.getuser(), + constants.USER_ID_ENV_VAR: common_utils.get_user_hash(), + # Skip cloud identity check to avoid the overhead. + env_options.Options.SKIP_CLOUD_IDENTITY_CHECK.env_key: '1', + # Disable minimize logging to get more details on the controller. + env_options.Options.MINIMIZE_LOGGING.env_key: '0', + }) + return env_vars + + def shared_controller_vars_to_fill( controller: Controllers, remote_user_config_path: str, local_user_config: Dict[str, Any]) -> Dict[str, str]: @@ -425,19 +443,7 @@ def shared_controller_vars_to_fill( 'sky_python_cmd': constants.SKY_PYTHON_CMD, 'local_user_config_path': local_user_config_path, } - env_vars: Dict[str, str] = { - env.env_key: str(int(env.get())) for env in env_options.Options - } - env_vars.update({ - # Should not use $USER here, as that env var can be empty when - # running in a container. - constants.USER_ENV_VAR: getpass.getuser(), - constants.USER_ID_ENV_VAR: common_utils.get_user_hash(), - # Skip cloud identity check to avoid the overhead. - env_options.Options.SKIP_CLOUD_IDENTITY_CHECK.env_key: '1', - # Disable minimize logging to get more details on the controller. - env_options.Options.MINIMIZE_LOGGING.env_key: '0', - }) + env_vars = sky_managed_cluster_envs() if skypilot_config.loaded(): # Only set the SKYPILOT_CONFIG env var if the user has a config file. env_vars[ diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 4d5cc809013..1a2b535b9d1 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -311,6 +311,7 @@ def get_service_schema(): # To avoid circular imports, only import when needed. 
# pylint: disable=import-outside-toplevel from sky.serve import load_balancing_policies + lb_policy_choices = list(load_balancing_policies.LB_POLICIES.keys()) return { '$schema': 'https://json-schema.org/draft/2020-12/schema', 'type': 'object', @@ -385,10 +386,26 @@ def get_service_schema(): 'replicas': { 'type': 'integer', }, + 'external_load_balancers': { + 'type': 'array', + 'items': { + 'type': 'object', + 'required': ['resources', 'load_balancing_policy'], + 'additionalProperties': False, + 'properties': { + 'resources': { + 'type': 'object', + }, + 'load_balancing_policy': { + 'type': 'string', + 'case_insensitive_enum': lb_policy_choices, + }, + } + } + }, 'load_balancing_policy': { 'type': 'string', - 'case_insensitive_enum': list( - load_balancing_policies.LB_POLICIES.keys()) + 'case_insensitive_enum': lb_policy_choices, }, } } From b3e0d414e3cd98abfd6acd83224f1be64daa5896 Mon Sep 17 00:00:00 2001 From: cblmemo Date: Thu, 14 Nov 2024 14:36:51 -0800 Subject: [PATCH 2/9] format --- sky/serve/serve_state.py | 2 +- sky/serve/service.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sky/serve/serve_state.py b/sky/serve/serve_state.py index ed8c06e5459..39391fdbf66 100644 --- a/sky/serve/serve_state.py +++ b/sky/serve/serve_state.py @@ -10,9 +10,9 @@ import colorama +from sky import exceptions from sky.serve import constants from sky.utils import db_utils -from sky import exceptions if typing.TYPE_CHECKING: from sky.serve import replica_managers diff --git a/sky/serve/service.py b/sky/serve/service.py index 999d82cb3b8..9c157058b6f 100644 --- a/sky/serve/service.py +++ b/sky/serve/service.py @@ -23,12 +23,12 @@ from sky.backends import backend_utils from sky.backends import cloud_vm_ray_backend from sky.serve import constants -from sky.skylet import constants as skylet_constants from sky.serve import controller from sky.serve import load_balancer from sky.serve import replica_managers from sky.serve import serve_state from sky.serve 
import serve_utils +from sky.skylet import constants as skylet_constants from sky.utils import common_utils from sky.utils import controller_utils from sky.utils import subprocess_utils @@ -177,7 +177,8 @@ def _start_external_load_balancer(service_name: str, controller_addr: str, 'load_balancer_port': lb_port, 'controller_addr': controller_addr, 'load_balancing_policy': lb_policy, - 'sky_activate_python_env': skylet_constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV, + 'sky_activate_python_env': + skylet_constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV, 'lb_envs': controller_utils.sky_managed_cluster_envs(), } common_utils.fill_template(constants.EXTERNAL_LB_TEMPLATE, From 7454b1ddb5a09cb9305501bc00af21cd8ffe997f Mon Sep 17 00:00:00 2001 From: cblmemo Date: Thu, 14 Nov 2024 17:01:25 -0800 Subject: [PATCH 3/9] expose to public internet when using external lbs --- sky/serve/service.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sky/serve/service.py b/sky/serve/service.py index 9c157058b6f..bfbe0946947 100644 --- a/sky/serve/service.py +++ b/sky/serve/service.py @@ -257,8 +257,11 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int): # inside a kubernetes cluster to allow external load balancers # (example, for high availability load balancers) to communicate # with the controller. + # Also, when we are using external load balancers, in which we + # need to get the information from a distinct machine. def _get_host(): - if 'KUBERNETES_SERVICE_HOST' in os.environ: + if ('KUBERNETES_SERVICE_HOST' in os.environ or + service_spec.external_load_balancers is not None): return '0.0.0.0' # Not using localhost to avoid using ipv6 address and causing # the following error: From d8d5cdb2d6259be29eb3b7e245e7278fd16a4cb0 Mon Sep 17 00:00:00 2001 From: cblmemo Date: Thu, 14 Nov 2024 17:42:22 -0800 Subject: [PATCH 4/9] add route 53 field in service. TODO: add record after the LB is provisioned. 
--- examples/serve/external-lb.yaml | 18 +++++---- sky/serve/service_spec.py | 72 ++++++++++++++++++++++++++++++++- sky/utils/schemas.py | 3 ++ 3 files changed, 83 insertions(+), 10 deletions(-) diff --git a/examples/serve/external-lb.yaml b/examples/serve/external-lb.yaml index e96d5e0fda0..f280a194fa3 100644 --- a/examples/serve/external-lb.yaml +++ b/examples/serve/external-lb.yaml @@ -7,18 +7,20 @@ service: path: /health initial_delay_seconds: 20 replicas: 2 + # TODO(tian): Change the config to a cloud-agnostic way. + route53_hosted_zone: aws.cblmemo.net external_load_balancers: - resources: - # cloud: aws - # region: us-east-2 - cloud: gcp - region: us-east1 + cloud: aws + region: us-east-2 + # cloud: gcp + # region: us-east1 load_balancing_policy: round_robin - resources: - # cloud: aws - # region: ap-northeast-1 - cloud: gcp - region: asia-northeast1 + cloud: aws + region: ap-northeast-1 + # cloud: gcp + # region: asia-northeast1 load_balancing_policy: round_robin resources: diff --git a/sky/serve/service_spec.py b/sky/serve/service_spec.py index 8c888eb3dc2..b0f48eec73b 100644 --- a/sky/serve/service_spec.py +++ b/sky/serve/service_spec.py @@ -32,6 +32,7 @@ def __init__( downscale_delay_seconds: Optional[int] = None, load_balancing_policy: Optional[str] = None, external_load_balancers: Optional[List[Dict[str, Any]]] = None, + route53_hosted_zone: Optional[str] = None, ) -> None: if max_replicas is not None and max_replicas < min_replicas: with ux_utils.print_exception_no_traceback(): @@ -65,6 +66,7 @@ def __init__( if lb_config.get('load_balancing_policy') is None: lb_config['load_balancing_policy'] = load_balancing_policy + self._target_hosted_zone_id: Optional[str] = None if external_load_balancers is not None: for lb_config in external_load_balancers: r = lb_config.get('resources') @@ -76,8 +78,62 @@ def __init__( with ux_utils.print_exception_no_traceback(): raise ValueError('`ports` must not be set for ' 'external_load_balancers.') + if 
route53_hosted_zone is not None: + if r.get('cloud', 'aws') != 'aws': + with ux_utils.print_exception_no_traceback(): + raise ValueError( + '`cloud` in `external_load_balancers` must be ' + 'set to `aws` if using route53_hosted_zone.') + r['cloud'] = 'aws' + if r.get('region') is None: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + '`region` in `external_load_balancers` must be ' + 'set when using route53_hosted_zone.') # Validate resources resources_lib.Resources.from_yaml_config(r) + if route53_hosted_zone is not None: + # TODO(tian): Move the import. + import boto3 # pylint: disable=import-outside-toplevel + client = boto3.client('route53') + hosted_zones = client.list_hosted_zones()['HostedZones'] + target_hosted_zone_id = None + for hz in hosted_zones: + # Amazon Route 53 treats domain name as FQDN. + # Thus a trailing dot is added to the domain name. + # Here we strip the trailing dot for comparison. + if hz['Name'].strip('.') == route53_hosted_zone: + if hz['Config']['PrivateZone']: + with ux_utils.print_exception_no_traceback(): + raise ValueError('`route53_hosted_zone` must be' + ' a public hosted zone.') + target_hosted_zone_id = hz['Id'] + break + if target_hosted_zone_id is None: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'route53_hosted_zone ({route53_hosted_zone}) ' + 'not found.') + self._target_hosted_zone_id = target_hosted_zone_id + print(f'Found hosted zone: {route53_hosted_zone} with ID: ' + f'{target_hosted_zone_id}.') + # TODO(tian): Here we dont have the service_name information. + # Skip checking it for now. We should add it back later. + # for record in client.list_resource_record_sets( + # HostedZoneId=target_hosted_zone_id + # )['ResourceRecordSets']: + # if (record['Type'] == 'A' and + # record['Name'].split('.')[0] == service_name): + # with ux_utils.print_exception_no_traceback(): + # raise ValueError( + # 'Hosted zone already has an A record with ' + # f'subdomain {service_name}. 
Please remove it ' + # 'before using it.') + else: + if route53_hosted_zone is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('`external_load_balancers` must be set ' + 'for route53_hosted_zone.') self._readiness_path: str = readiness_path self._initial_delay_seconds: int = initial_delay_seconds @@ -96,6 +152,7 @@ def __init__( self._load_balancing_policy: Optional[str] = load_balancing_policy self._external_load_balancers: Optional[List[Dict[str, Any]]] = ( external_load_balancers) + self._route53_hosted_zone = route53_hosted_zone self._use_ondemand_fallback: bool = ( self.dynamic_ondemand_fallback is not None and @@ -181,6 +238,8 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec': 'load_balancing_policy', None) service_config['external_load_balancers'] = config.get( 'external_load_balancers', None) + service_config['route53_hosted_zone'] = config.get( + 'route53_hosted_zone', None) return SkyServiceSpec(**service_config) @staticmethod @@ -237,9 +296,10 @@ def add_if_not_none(section, key, value, no_empty: bool = False): add_if_not_none('replica_policy', 'downscale_delay_seconds', self.downscale_delay_seconds) add_if_not_none('load_balancing_policy', None, - self._load_balancing_policy) + self.load_balancing_policy) add_if_not_none('external_load_balancers', None, - self._external_load_balancers) + self.external_load_balancers) + add_if_not_none('route53_hosted_zone', None, self.route53_hosted_zone) return config def probe_str(self): @@ -354,3 +414,11 @@ def load_balancing_policy(self) -> Optional[str]: @property def external_load_balancers(self) -> Optional[List[Dict[str, Any]]]: return self._external_load_balancers + + @property + def route53_hosted_zone(self) -> Optional[str]: + return self._route53_hosted_zone + + @property + def target_hosted_zone_id(self) -> Optional[str]: + return self._target_hosted_zone_id diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 1a2b535b9d1..6d078671775 100644 --- 
a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -407,6 +407,9 @@ def get_service_schema(): 'type': 'string', 'case_insensitive_enum': lb_policy_choices, }, + 'route53_hosted_zone': { + 'type': 'string', + }, } } From a68677ce318c5676aaeab78664cb92add72d380b Mon Sep 17 00:00:00 2001 From: cblmemo Date: Thu, 14 Nov 2024 21:03:58 -0800 Subject: [PATCH 5/9] change status format --- sky/serve/core.py | 2 +- sky/serve/serve_state.py | 27 +++++++++++++-------------- sky/serve/serve_utils.py | 8 ++++---- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/sky/serve/core.py b/sky/serve/core.py index a006500679f..6117ee5eaff 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -625,8 +625,8 @@ def status( { 'lb_id': (int) index of the external load balancer, 'cluster_name': (str) cluster name of the external load balancer, + 'ip': (str) ip of the external load balancer, 'port': (int) port of the external load balancer, - 'endpoint': (str) endpoint of the external load balancer, } For possible service statuses and replica statuses, please refer to diff --git a/sky/serve/serve_state.py b/sky/serve/serve_state.py index 39391fdbf66..bd1561f6a71 100644 --- a/sky/serve/serve_state.py +++ b/sky/serve/serve_state.py @@ -10,7 +10,8 @@ import colorama -from sky import exceptions +from sky import global_user_state +from sky import status_lib from sky.serve import constants from sky.utils import db_utils @@ -562,24 +563,22 @@ def add_external_load_balancer(service_name: str, lb_id: int, cluster_name: str, def _get_external_load_balancer_from_row(row) -> Dict[str, Any]: - from sky import core # pylint: disable=import-outside-toplevel - - # TODO(tian): Temporary workaround to avoid circular import. - # This should be fixed. 
lb_id, cluster_name, port = row[:3] - try: - endpoint = core.endpoints(cluster_name, port)[port] - except exceptions.ClusterNotUpError: - # TODO(tian): Currently, when this cluster is not in the UP status, - # the endpoint query will raise an cluster is not up error. We should - # implement a status for external lbs as well and returns a '-' when - # it is still provisioning. - endpoint = '-' + lb_cluster_record = global_user_state.get_cluster_from_name(cluster_name) + if (lb_cluster_record is None or + lb_cluster_record['status'] != status_lib.ClusterStatus.UP): + # TODO(tian): We should implement a status for external lbs as well + # and returns a '-' when it is still provisioning. + lb_ip = '-' + else: + lb_ip = lb_cluster_record['handle'].head_ip + if lb_ip is None: + lb_ip = '-' return { 'lb_id': lb_id, 'cluster_name': cluster_name, + 'ip': lb_ip, 'port': port, - 'endpoint': endpoint, } diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index bd500432467..71410e3ff70 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -945,8 +945,7 @@ def _format_external_lb_table(external_lb_records: List[Dict[str, Any]], external_lb_columns = ['SERVICE_NAME', 'ID', 'ENDPOINT'] if show_all: - external_lb_columns.append('PORT') - external_lb_columns.append('CLUSTER_NAME') + external_lb_columns.extend(['IP', 'PORT', 'CLUSTER_NAME']) external_lb_table = log_utils.create_table(external_lb_columns) truncate_hint = '' @@ -959,8 +958,9 @@ def _format_external_lb_table(external_lb_records: List[Dict[str, Any]], for record in external_lb_records: service_name = record['service_name'] external_lb_id = record['lb_id'] - endpoint = record['endpoint'] + lb_ip = record['ip'] port = record['port'] + endpoint = f'{lb_ip}:{port}' cluster_name = record['cluster_name'] external_lb_values = [ @@ -969,7 +969,7 @@ def _format_external_lb_table(external_lb_records: List[Dict[str, Any]], endpoint, ] if show_all: - external_lb_values.extend([port, cluster_name]) + 
external_lb_values.extend([lb_ip, port, cluster_name]) external_lb_table.add_row(external_lb_values) return f'{external_lb_table}{truncate_hint}' From 1b63e43263812ef06ed3112627fdc63f2ff60c6f Mon Sep 17 00:00:00 2001 From: cblmemo Date: Thu, 14 Nov 2024 22:39:03 -0800 Subject: [PATCH 6/9] add auto configure route 52 --- sky/core.py | 4 +- sky/serve/constants.py | 5 ++ sky/serve/core.py | 1 + sky/serve/serve_state.py | 14 ++++-- sky/serve/serve_utils.py | 6 ++- sky/serve/service.py | 104 +++++++++++++++++++++++++++++++++++---- 6 files changed, 116 insertions(+), 18 deletions(-) diff --git a/sky/core.py b/sky/core.py index 4bb12f4a21a..fcf12046b15 100644 --- a/sky/core.py +++ b/sky/core.py @@ -683,8 +683,8 @@ def cancel( sky.exceptions.CloudUserIdentityError: if we fail to get the current user identity. """ - controller_utils.check_cluster_name_not_controller( - cluster_name, operation_str='Cancelling jobs') + # controller_utils.check_cluster_name_not_controller( + # cluster_name, operation_str='Cancelling jobs') if all and job_ids: raise ValueError('Cannot specify both `all` and `job_ids`. To cancel ' diff --git a/sky/serve/constants.py b/sky/serve/constants.py index 813aa0d6d0e..0071ee36d11 100644 --- a/sky/serve/constants.py +++ b/sky/serve/constants.py @@ -16,6 +16,11 @@ # Time to wait in seconds for service to register on the controller. SERVICE_REGISTER_TIMEOUT_SECONDS = 60 +# Time to wait in seconds for service to register on the controller with +# external load balancer. We need to wait longer for external load balancer to +# be ready for the ip address of the service. +SERVICE_REGISTER_TIMEOUT_SECONDS_WITH_EXTERNAL_LB = 300 + # The time interval in seconds for load balancer to sync with controller. 
Every # time the load balancer syncs with controller, it will update all available # replica ips for each service, also send the number of requests in last query diff --git a/sky/serve/core.py b/sky/serve/core.py index 6117ee5eaff..b0c3b04dfd2 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -625,6 +625,7 @@ def status( { 'lb_id': (int) index of the external load balancer, 'cluster_name': (str) cluster name of the external load balancer, + 'region': (str) region of the external load balancer, 'ip': (str) ip of the external load balancer, 'port': (int) port of the external load balancer, } diff --git a/sky/serve/serve_state.py b/sky/serve/serve_state.py index bd1561f6a71..8b0a9a1ee18 100644 --- a/sky/serve/serve_state.py +++ b/sky/serve/serve_state.py @@ -65,6 +65,7 @@ def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None: lb_id INTEGER, service_name TEXT, cluster_name TEXT, + region TEXT, port INTEGER, PRIMARY KEY (service_name, lb_id))""") conn.commit() @@ -552,18 +553,19 @@ def delete_all_versions(service_name: str) -> None: # === External Load Balancer functions === # TODO(tian): Add a status column. 
def add_external_load_balancer(service_name: str, lb_id: int, cluster_name: str, - port: int) -> None: + region: str, port: int) -> None: """Adds an external load balancer to the database.""" with db_utils.safe_cursor(_DB_PATH) as cursor: cursor.execute( """\ INSERT INTO external_load_balancers - (service_name, lb_id, cluster_name, port) - VALUES (?, ?, ?, ?)""", (service_name, lb_id, cluster_name, port)) + (service_name, lb_id, cluster_name, region, port) + VALUES (?, ?, ?, ?, ?)""", + (service_name, lb_id, cluster_name, region, port)) def _get_external_load_balancer_from_row(row) -> Dict[str, Any]: - lb_id, cluster_name, port = row[:3] + lb_id, cluster_name, region, port = row[:4] lb_cluster_record = global_user_state.get_cluster_from_name(cluster_name) if (lb_cluster_record is None or lb_cluster_record['status'] != status_lib.ClusterStatus.UP): @@ -577,6 +579,7 @@ def _get_external_load_balancer_from_row(row) -> Dict[str, Any]: return { 'lb_id': lb_id, 'cluster_name': cluster_name, + 'region': region, 'ip': lb_ip, 'port': port, } @@ -587,7 +590,8 @@ def get_external_load_balancers(service_name: str) -> List[Dict[str, Any]]: with db_utils.safe_cursor(_DB_PATH) as cursor: rows = cursor.execute( """\ - SELECT lb_id, cluster_name, port FROM external_load_balancers + SELECT lb_id, cluster_name, region, port + FROM external_load_balancers WHERE service_name=(?)""", (service_name,)).fetchall() external_load_balancers = [] for row in rows: diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 71410e3ff70..516181c6fb5 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -539,6 +539,10 @@ def wait_service_registration(service_name: str, job_id: int) -> str: Encoded load balancer port assigned to the service. """ start_time = time.time() + # TODO(tian): Add a field in service record to indicate whether it has + # external load balancer or not. And change this timeout accordingly. 
+ # timeout = constants.SERVICE_REGISTER_TIMEOUT_SECONDS + timeout = constants.SERVICE_REGISTER_TIMEOUT_SECONDS_WITH_EXTERNAL_LB while True: record = serve_state.get_service_from_name(service_name) if record is not None: @@ -558,7 +562,7 @@ def wait_service_registration(service_name: str, job_id: int) -> str: 'To spin up more services, please ' 'tear down some existing services.') elapsed = time.time() - start_time - if elapsed > constants.SERVICE_REGISTER_TIMEOUT_SECONDS: + if elapsed > timeout: # Print the controller log to help user debug. controller_log_path = ( generate_remote_controller_log_file_name(service_name)) diff --git a/sky/serve/service.py b/sky/serve/service.py index bfbe0946947..9e30be83fd9 100644 --- a/sky/serve/service.py +++ b/sky/serve/service.py @@ -11,12 +11,14 @@ import tempfile import time import traceback -from typing import Any, Dict +import typing +from typing import Any, Dict, Optional import filelock from sky import authentication from sky import exceptions +from sky import global_user_state from sky import resources as resources_lib from sky import sky_logging from sky import task as task_lib @@ -34,6 +36,9 @@ from sky.utils import subprocess_utils from sky.utils import ux_utils +if typing.TYPE_CHECKING: + from sky.serve import service_spec + # Use the explicit logger name so that the logger is under the # `sky.serve.service` namespace when executed directly, so as # to inherit the setup from the `sky` logger. 
@@ -88,9 +93,40 @@ def cleanup_storage(task_yaml: str) -> bool: return True -def _cleanup(service_name: str) -> bool: +def _get_cluster_ip(cluster_name: str) -> Optional[str]: + record = global_user_state.get_cluster_from_name(cluster_name) + if record is None: + return None + if record['handle'].head_ip is None: + return None + return record['handle'].head_ip + + +def _get_route53_change(action: str, subdomain: str, hosted_zone: str, + record_type: str, region: str, + value: str) -> Dict[str, Any]: + return { + 'Action': action, + 'ResourceRecordSet': { + 'Name': f'{subdomain}.{hosted_zone}', + 'Type': record_type, + 'TTL': 300, + 'Region': region, + 'SetIdentifier': f'{subdomain}-{region}', + 'ResourceRecords': [{ + 'Value': value + }] + } + } + + +def _cleanup(service_name: str, + service_spec: 'service_spec.SkyServiceSpec') -> bool: """Clean up all service related resources, i.e. replicas and storage.""" failed = False + change_batch = [] + hosted_zone = service_spec.route53_hosted_zone + replica_infos = serve_state.get_replica_infos(service_name) info2proc: Dict[replica_managers.ReplicaInfo, multiprocessing.Process] = dict() @@ -111,11 +147,27 @@ def _cleanup(service_name: str) -> bool: for external_lb_record in external_lbs: lb_cluster_name = external_lb_record['cluster_name'] lb_id = external_lb_record['lb_id'] + lb_region = external_lb_record['region'] p = multiprocessing.Process(target=replica_managers.terminate_cluster, args=(lb_cluster_name,)) p.start() lbid2proc[lb_id] = p + lb_ip = _get_cluster_ip(lb_cluster_name) + assert lb_ip is not None + if hosted_zone is not None: + change_batch.append( + _get_route53_change('DELETE', service_name, hosted_zone, 'A', + lb_region, lb_ip)) logger.info(f'Terminating external load balancer {lb_cluster_name} ...') + + if change_batch: + # TODO(tian): Fix this import hack. 
+ import boto3 # pylint: disable=import-outside-toplevel + client = boto3.client('route53') + client.change_resource_record_sets( + HostedZoneId=service_spec.target_hosted_zone_id, + ChangeBatch={'Changes': change_batch}) + for info, p in info2proc.items(): p.join() if p.exitcode == 0: @@ -158,8 +210,9 @@ def _get_external_lb_cluster_name(service_name: str, lb_id: int) -> str: return f'sky-{service_name}-lb-{lb_id}' -def _start_external_load_balancer(service_name: str, controller_addr: str, - lb_id: int, lb_port: int, lb_policy: str, +def _start_external_load_balancer(service_name: str, lb_id: int, + lb_cluster_name: str, controller_addr: str, + lb_port: int, lb_policy: str, lb_resources: Dict[str, Any]) -> None: # TODO(tian): Hack. We should figure out the optimal resoruces. if 'cpus' not in lb_resources: @@ -168,7 +221,6 @@ def _start_external_load_balancer(service_name: str, controller_addr: str, assert 'ports' not in lb_resources lb_resources['ports'] = [lb_port] lbr = resources_lib.Resources.from_yaml_config(lb_resources) - lb_cluster_name = _get_external_lb_cluster_name(service_name, lb_id) # TODO(tian): Set delete=False to debug. Remove this on production. with tempfile.NamedTemporaryFile(prefix=lb_cluster_name, mode='w', @@ -187,7 +239,8 @@ def _start_external_load_balancer(service_name: str, controller_addr: str, lb_task = task_lib.Task.from_yaml(f.name) lb_task.set_resources(lbr) serve_state.add_external_load_balancer(service_name, lb_id, - lb_cluster_name, lb_port) + lb_cluster_name, + lb_resources['region'], lb_port) # TODO(tian): Temporary solution for circular import. We should move # the import to the top of the file. import sky # pylint: disable=import-outside-toplevel @@ -289,6 +342,9 @@ def _get_external_host(): # TODO(tian): Support HTTPS. controller_addr = f'http://{controller_host}:{controller_port}' load_balancer_processes = [] + # TODO(tian): Combine the following two. 
+ lbid2cluster = {} + lbid2region = {} if service_spec.external_load_balancers is None: # Generate load balancer log file name. @@ -324,19 +380,47 @@ def _get_external_host(): serve_utils. generate_remote_external_load_balancer_log_file_name( service_name, lb_id)) + lb_cluster_name = (_get_external_lb_cluster_name( + service_name, lb_id)) + lbid2cluster[lb_id] = lb_cluster_name lb_policy = lb_config['load_balancing_policy'] lb_resources = lb_config['resources'] + lbid2region[lb_id] = lb_resources['region'] controller_external_addr = ( f'http://{_get_external_host()}:{controller_port}') lb_process = multiprocessing.Process( target=ux_utils.RedirectOutputForProcess( _start_external_load_balancer, load_balancer_log_file).run, - args=(service_name, controller_external_addr, lb_id, - lb_port, lb_policy, lb_resources)) + args=(service_name, lb_id, lb_cluster_name, + controller_external_addr, lb_port, lb_policy, + lb_resources)) lb_process.start() load_balancer_processes.append(lb_process) - serve_state.set_service_load_balancer_port(service_name, -1) + + if service_spec.external_load_balancers is not None: + hosted_zone = service_spec.route53_hosted_zone + if hosted_zone is not None: + while True: + if all( + _get_cluster_ip(lb_cluster_name) is not None + for lb_cluster_name in lbid2cluster.values()): + break + time.sleep(1) + # TODO(tian): Fix this import hack. 
+ import boto3 # pylint: disable=import-outside-toplevel + client = boto3.client('route53') + change_batch = [] + for lb_id, lb_cluster_name in lbid2cluster.items(): + lb_ip = _get_cluster_ip(lb_cluster_name) + assert lb_ip is not None + change_batch.append( + _get_route53_change('CREATE', service_name, hosted_zone, + 'A', lbid2region[lb_id], lb_ip)) + client.change_resource_record_sets( + HostedZoneId=service_spec.target_hosted_zone_id, + ChangeBatch={'Changes': change_batch}) + serve_state.set_service_load_balancer_port(service_name, -1) while True: _handle_signal(service_name) @@ -355,7 +439,7 @@ def _get_external_host(): [process.pid for process in process_to_kill], force=True) for process in process_to_kill: process.join() - failed = _cleanup(service_name) + failed = _cleanup(service_name, service_spec) if failed: serve_state.set_service_status_and_active_versions( service_name, serve_state.ServiceStatus.FAILED_CLEANUP) From 6826efc057a14e1dcf36a5cb8cdfefb4fe9c7a55 Mon Sep 17 00:00:00 2001 From: cblmemo Date: Fri, 15 Nov 2024 11:53:18 -0800 Subject: [PATCH 7/9] revert debug --- sky/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sky/core.py b/sky/core.py index fcf12046b15..4bb12f4a21a 100644 --- a/sky/core.py +++ b/sky/core.py @@ -683,8 +683,8 @@ def cancel( sky.exceptions.CloudUserIdentityError: if we fail to get the current user identity. """ - # controller_utils.check_cluster_name_not_controller( - # cluster_name, operation_str='Cancelling jobs') + controller_utils.check_cluster_name_not_controller( + cluster_name, operation_str='Cancelling jobs') if all and job_ids: raise ValueError('Cannot specify both `all` and `job_ids`. 
To cancel ' From 5b7f51d890ad783be697fb11fb7887a4549f006e Mon Sep 17 00:00:00 2001 From: cblmemo Date: Fri, 15 Nov 2024 12:45:57 -0800 Subject: [PATCH 8/9] show dns endpoint --- sky/serve/constants.py | 3 +++ sky/serve/core.py | 19 +++++++++++-------- sky/serve/serve_state.py | 14 +++++++++++++- sky/serve/serve_utils.py | 18 ++++++++++-------- sky/serve/service.py | 16 ++++++++++++---- 5 files changed, 49 insertions(+), 21 deletions(-) diff --git a/sky/serve/constants.py b/sky/serve/constants.py index 0071ee36d11..c0e04afe7d6 100644 --- a/sky/serve/constants.py +++ b/sky/serve/constants.py @@ -89,6 +89,9 @@ LOAD_BALANCER_PORT_START = 30001 LOAD_BALANCER_PORT_RANGE = '30001-30020' +# Port for external load balancer. +EXTERNAL_LB_PORT = 8000 + # Initial version of service. INITIAL_VERSION = 1 diff --git a/sky/serve/core.py b/sky/serve/core.py index b0c3b04dfd2..c1867bf0806 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -239,7 +239,7 @@ def up( assert isinstance(backend, backends.CloudVmRayBackend) assert isinstance(controller_handle, backends.CloudVmRayResourceHandle) - returncode, lb_port_payload, _ = backend.run_on_head( + returncode, service_init_payload, _ = backend.run_on_head( controller_handle, code, require_outputs=True, @@ -247,7 +247,7 @@ def up( try: subprocess_utils.handle_returncode( returncode, code, 'Failed to wait for service initialization', - lb_port_payload) + service_init_payload) except exceptions.CommandError: statuses = backend.get_job_status(controller_handle, [controller_job_id], @@ -276,18 +276,20 @@ def up( 'Failed to spin up the service. 
Please ' 'check the logs above for more details.') from None else: + service_init_result = ( + serve_utils.load_service_initialization_result( + service_init_payload)) if task.service.external_load_balancers is None: - lb_port = serve_utils.load_service_initialization_result( - lb_port_payload) + assert isinstance(service_init_result, int) endpoint = backend_utils.get_endpoints( controller_handle.cluster_name, - lb_port, - skip_status_check=True).get(lb_port) + service_init_result, + skip_status_check=True).get(service_init_result) assert endpoint is not None, ( 'Did not get endpoint for controller.') else: - endpoint = ( - 'Please query with sky serve status for the endpoint.') + assert isinstance(service_init_result, str) + endpoint = service_init_result sky_logging.print( f'{fore.CYAN}Service name: ' @@ -600,6 +602,7 @@ def status( 'policy': (Optional[str]) load balancer policy description, 'requested_resources_str': (str) str representation of requested resources, + 'dns_endpoint': (Optional[str]) DNS endpoint, 'replica_info': (List[Dict[str, Any]]) replica information, 'external_lb_info': (Dict[str, Any]) external load balancer information, diff --git a/sky/serve/serve_state.py b/sky/serve/serve_state.py index 8b0a9a1ee18..a321f09f0fe 100644 --- a/sky/serve/serve_state.py +++ b/sky/serve/serve_state.py @@ -86,6 +86,8 @@ def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None: db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services', 'active_versions', f'TEXT DEFAULT {json.dumps([])!r}') +db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services', 'dns_endpoint', + 'TEXT DEFAULT NULL') _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG = 'UNIQUE constraint failed: services.name' @@ -331,10 +333,19 @@ def set_service_load_balancer_port(service_name: str, (load_balancer_port, service_name)) +def set_service_dns_endpoint(service_name: str, dns_endpoint: str) -> None: + """Sets the dns endpoint of a service.""" + with db_utils.safe_cursor(_DB_PATH) 
as cursor: + cursor.execute( + """\ + UPDATE services SET + dns_endpoint=(?) WHERE name=(?)""", (dns_endpoint, service_name)) + + def _get_service_from_row(row) -> Dict[str, Any]: (current_version, name, controller_job_id, controller_port, load_balancer_port, status, uptime, policy, _, _, requested_resources_str, - _, active_versions) = row[:13] + _, active_versions, dns_endpoint) = row[:14] return { 'name': name, 'controller_job_id': controller_job_id, @@ -351,6 +362,7 @@ def _get_service_from_row(row) -> Dict[str, Any]: # integers in json format. This is mainly for display purpose. 'active_versions': json.loads(active_versions), 'requested_resources_str': requested_resources_str, + 'dns_endpoint': dns_endpoint, } diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 516181c6fb5..9e7919b2836 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -12,7 +12,7 @@ import time import typing from typing import (Any, Callable, DefaultDict, Dict, Generic, Iterator, List, - Optional, TextIO, Type, TypeVar) + Optional, TextIO, Type, TypeVar, Union) import uuid import colorama @@ -555,6 +555,9 @@ def wait_service_registration(service_name: str, job_id: int) -> str: f'{service_name} ') lb_port = record['load_balancer_port'] if lb_port is not None: + if record['dns_endpoint'] is not None: + endpoint = f'{record["dns_endpoint"]}:{lb_port}' + return common_utils.encode_payload(endpoint) return common_utils.encode_payload(lb_port) elif len(serve_state.get_services()) >= NUM_SERVICE_THRESHOLD: with ux_utils.print_exception_no_traceback(): @@ -577,7 +580,7 @@ def wait_service_registration(service_name: str, job_id: int) -> str: time.sleep(1) -def load_service_initialization_result(payload: str) -> int: +def load_service_initialization_result(payload: str) -> Union[int, str]: return common_utils.decode_payload(payload) @@ -794,6 +797,10 @@ def _get_replicas(service_record: Dict[str, Any]) -> str: def get_endpoint(service_record: Dict[str, Any]) -> 
str: + if service_record['dns_endpoint'] is not None: + dns = service_record['dns_endpoint'] + lb_port = service_record['load_balancer_port'] + return f'{dns}:{lb_port}' # Don't use backend_utils.is_controller_accessible since it is too slow. handle = global_user_state.get_handle_from_cluster_name( SKY_SERVE_CONTROLLER_NAME) @@ -846,12 +853,7 @@ def format_service_table(service_records: List[Dict[str, Any]], service_status = record['status'] status_str = service_status.colored_str() replicas = _get_replicas(record) - if record['external_lb_info']: - # Don't show endpoint for services with external load balancers. - # TODO(tian): Add automatic DNS record creation and show domain here - endpoint = '-' - else: - endpoint = get_endpoint(record) + endpoint = get_endpoint(record) policy = record['policy'] requested_resources_str = record['requested_resources_str'] diff --git a/sky/serve/service.py b/sky/serve/service.py index 9e30be83fd9..3b31bf475a6 100644 --- a/sky/serve/service.py +++ b/sky/serve/service.py @@ -102,13 +102,17 @@ def _get_cluster_ip(cluster_name: str) -> Optional[str]: return record['handle'].head_ip +def _get_domain_name(subdomain: str, hosted_zone: str) -> str: + return f'{subdomain}.{hosted_zone}' + + def _get_route53_change(action: str, subdomain: str, hosted_zone: str, record_type: str, region: str, value: str) -> Dict[str, Any]: return { 'Action': action, 'ResourceRecordSet': { - 'Name': f'{subdomain}.{hosted_zone}', + 'Name': _get_domain_name(subdomain, hosted_zone), 'Type': record_type, 'TTL': 300, 'Region': region, @@ -372,7 +376,6 @@ def _get_external_host(): serve_state.set_service_load_balancer_port( service_name, load_balancer_port) else: - lb_port = 8000 for lb_id, lb_config in enumerate( service_spec.external_load_balancers): # Generate load balancer log file name. 
@@ -392,8 +395,10 @@ def _get_external_host(): target=ux_utils.RedirectOutputForProcess( _start_external_load_balancer, load_balancer_log_file).run, + # TODO(tian): Let the user customize the port. args=(service_name, lb_id, lb_cluster_name, - controller_external_addr, lb_port, lb_policy, + controller_external_addr, + constants.EXTERNAL_LB_PORT, lb_policy, lb_resources)) lb_process.start() load_balancer_processes.append(lb_process) @@ -420,7 +425,10 @@ def _get_external_host(): client.change_resource_record_sets( HostedZoneId=service_spec.target_hosted_zone_id, ChangeBatch={'Changes': change_batch}) - serve_state.set_service_load_balancer_port(service_name, -1) + serve_state.set_service_dns_endpoint( + service_name, _get_domain_name(service_name, hosted_zone)) + serve_state.set_service_load_balancer_port( + service_name, constants.EXTERNAL_LB_PORT) while True: _handle_signal(service_name) From 6881bac0571fce047888de22b91e69e9cb0d0b7f Mon Sep 17 00:00:00 2001 From: cblmemo Date: Fri, 15 Nov 2024 13:06:10 -0800 Subject: [PATCH 9/9] comments --- examples/serve/external-lb.yaml | 5 +++++ sky/serve/service.py | 2 ++ 2 files changed, 7 insertions(+) diff --git a/examples/serve/external-lb.yaml b/examples/serve/external-lb.yaml index f280a194fa3..e55d3c822fe 100644 --- a/examples/serve/external-lb.yaml +++ b/examples/serve/external-lb.yaml @@ -1,4 +1,9 @@ # SkyServe YAML to run multiple Load Balancers in different region. +# +# Usage: +# 1. Register the hosted zone in Route53. +# 2. Go to your DNS manager and point the name servers to Route53. +# 3. `sky serve up examples/serve/external-lb.yaml`. name: multi-lb diff --git a/sky/serve/service.py b/sky/serve/service.py index 3b31bf475a6..91900f29166 100644 --- a/sky/serve/service.py +++ b/sky/serve/service.py @@ -396,6 +396,7 @@ def _get_external_host(): _start_external_load_balancer, load_balancer_log_file).run, # TODO(tian): Let the user customize the port. + # TODO(tian): Or, default to port 80 (need root). 
args=(service_name, lb_id, lb_cluster_name, controller_external_addr, constants.EXTERNAL_LB_PORT, lb_policy, @@ -406,6 +407,7 @@ def _get_external_host(): if service_spec.external_load_balancers is not None: hosted_zone = service_spec.route53_hosted_zone if hosted_zone is not None: + # Wait for the LBs to be ready, get the IPs and set up Route53. while True: if all( _get_cluster_ip(lb_cluster_name) is not None