diff --git a/sky/backends/backend.py b/sky/backends/backend.py
index 10b51b06038..bf74ed3acf3 100644
--- a/sky/backends/backend.py
+++ b/sky/backends/backend.py
@@ -45,20 +45,23 @@ def check_resources_fit_cluster(self, handle: _ResourceHandleType,
     @timeline.event
     @usage_lib.messages.usage.update_runtime('provision')
     def provision(
-            self,
-            task: 'task_lib.Task',
-            to_provision: Optional['resources.Resources'],
-            dryrun: bool,
-            stream_logs: bool,
-            cluster_name: Optional[str] = None,
-            retry_until_up: bool = False) -> Optional[_ResourceHandleType]:
+        self,
+        task: 'task_lib.Task',
+        to_provision: Optional['resources.Resources'],
+        dryrun: bool,
+        stream_logs: bool,
+        cluster_name: Optional[str] = None,
+        retry_until_up: bool = False,
+        skip_if_config_hash_matches: Optional[str] = None
+    ) -> Optional[_ResourceHandleType]:
         if cluster_name is None:
             cluster_name = sky.backends.backend_utils.generate_cluster_name()
         usage_lib.record_cluster_name_for_current_operation(cluster_name)
         usage_lib.messages.usage.update_actual_task(task)
         with rich_utils.safe_status(ux_utils.spinner_message('Launching')):
             return self._provision(task, to_provision, dryrun, stream_logs,
-                                   cluster_name, retry_until_up)
+                                   cluster_name, retry_until_up,
+                                   skip_if_config_hash_matches)
 
     @timeline.event
     @usage_lib.messages.usage.update_runtime('sync_workdir')
@@ -126,13 +129,15 @@ def register_info(self, **kwargs) -> None:
 
     # --- Implementations of the APIs ---
     def _provision(
-            self,
-            task: 'task_lib.Task',
-            to_provision: Optional['resources.Resources'],
-            dryrun: bool,
-            stream_logs: bool,
-            cluster_name: str,
-            retry_until_up: bool = False) -> Optional[_ResourceHandleType]:
+        self,
+        task: 'task_lib.Task',
+        to_provision: Optional['resources.Resources'],
+        dryrun: bool,
+        stream_logs: bool,
+        cluster_name: str,
+        retry_until_up: bool = False,
+        skip_if_config_hash_matches: Optional[str] = None
+    ) -> Optional[_ResourceHandleType]:
         raise NotImplementedError
 
     def _sync_workdir(self, handle: _ResourceHandleType, workdir: Path) -> None:
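
Note on the interface change above: `Backend.provision` now accepts an optional `skip_if_config_hash_matches` string, threaded through to each backend's `_provision`. The caller passes the config hash recorded for the existing cluster; a backend may skip provisioning when the freshly generated cluster config hashes to the same value. A minimal, self-contained sketch of that contract follows; the hashing scheme here is purely illustrative (the real hash is computed by SkyPilot's provisioning code and does not appear in this diff):

    import hashlib
    import json
    from typing import Any, Dict, Optional

    def _toy_config_hash(config: Dict[str, Any]) -> str:
        # Serialize with sorted keys so that logically identical configs
        # produce identical hashes.
        canonical = json.dumps(config, sort_keys=True)
        return hashlib.sha256(canonical.encode('utf-8')).hexdigest()

    def _should_skip(new_config: Dict[str, Any],
                     skip_if_config_hash_matches: Optional[str]) -> bool:
        # None means the caller did not request the fast path.
        if skip_if_config_hash_matches is None:
            return False
        return _toy_config_hash(new_config) == skip_if_config_hash_matches
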
diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index 0013e6cbaf9..caaa826ab68 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -1314,6 +1314,7 @@ def _retry_zones(
         prev_cluster_status: Optional[status_lib.ClusterStatus],
         prev_handle: Optional['CloudVmRayResourceHandle'],
         prev_cluster_ever_up: bool,
+        skip_if_config_hash_matches: Optional[str],
     ) -> Dict[str, Any]:
         """The provision retry loop."""
         # Get log_path name
@@ -1424,8 +1425,15 @@ def _retry_zones(
                 raise exceptions.ResourcesUnavailableError(
                     f'Failed to provision on cloud {to_provision.cloud} due to '
                     f'invalid cloud config: {common_utils.format_exception(e)}')
+
+            if skip_if_config_hash_matches == config_dict['config_hash']:
+                logger.info('Skipping provisioning of cluster with matching '
+                            'config hash.')
+                return config_dict
+
             if dryrun:
                 return config_dict
+
             cluster_config_file = config_dict['ray']
 
             launched_resources = to_provision.copy(region=region.name)
@@ -1937,6 +1945,7 @@ def provision_with_retries(
         to_provision_config: ToProvisionConfig,
         dryrun: bool,
         stream_logs: bool,
+        skip_if_config_hash_matches: Optional[str],
     ) -> Dict[str, Any]:
         """Provision with retries for all launchable resources."""
         cluster_name = to_provision_config.cluster_name
@@ -1986,7 +1995,8 @@ def provision_with_retries(
                     cloud_user_identity=cloud_user,
                     prev_cluster_status=prev_cluster_status,
                     prev_handle=prev_handle,
-                    prev_cluster_ever_up=prev_cluster_ever_up)
+                    prev_cluster_ever_up=prev_cluster_ever_up,
+                    skip_if_config_hash_matches=skip_if_config_hash_matches)
                 if dryrun:
                     return config_dict
             except (exceptions.InvalidClusterNameError,
@@ -2687,13 +2697,15 @@ def check_resources_fit_cluster(
         return valid_resource
 
     def _provision(
-            self,
-            task: task_lib.Task,
-            to_provision: Optional[resources_lib.Resources],
-            dryrun: bool,
-            stream_logs: bool,
-            cluster_name: str,
-            retry_until_up: bool = False) -> Optional[CloudVmRayResourceHandle]:
+        self,
+        task: task_lib.Task,
+        to_provision: Optional[resources_lib.Resources],
+        dryrun: bool,
+        stream_logs: bool,
+        cluster_name: str,
+        retry_until_up: bool = False,
+        skip_if_config_hash_matches: Optional[str] = None
+    ) -> Optional[CloudVmRayResourceHandle]:
         """Provisions using 'ray up'.
 
         Raises:
@@ -2779,7 +2791,8 @@ def _provision(
                     rich_utils.force_update_status(
                         ux_utils.spinner_message('Launching', log_path))
                 config_dict = retry_provisioner.provision_with_retries(
-                    task, to_provision_config, dryrun, stream_logs)
+                    task, to_provision_config, dryrun, stream_logs,
+                    skip_if_config_hash_matches)
                 break
             except exceptions.ResourcesUnavailableError as e:
                 # Do not remove the stopped cluster from the global state
@@ -2829,6 +2842,15 @@ def _provision(
             record = global_user_state.get_cluster_from_name(cluster_name)
             return record['handle'] if record is not None else None
 
+        config_hash = config_dict['config_hash']
+
+        if skip_if_config_hash_matches is not None:
+            record = global_user_state.get_cluster_from_name(cluster_name)
+            if (record is not None and skip_if_config_hash_matches ==
+                    config_hash == record['config_hash']):
+                logger.info('Skipping provisioning: config hash matches.')
+                return record['handle']
+
         if 'provision_record' in config_dict:
             # New provisioner is used here.
             handle = config_dict['handle']
@@ -2868,7 +2890,7 @@ def _provision(
             self._update_after_cluster_provisioned(
                 handle, to_provision_config.prev_handle, task,
                 prev_cluster_status, handle.external_ips(),
-                handle.external_ssh_ports(), lock_path)
+                handle.external_ssh_ports(), lock_path, config_hash)
             return handle
 
         cluster_config_file = config_dict['ray']
@@ -2940,7 +2962,8 @@ def _get_zone(runner):
 
         self._update_after_cluster_provisioned(
             handle, to_provision_config.prev_handle, task,
-            prev_cluster_status, ip_list, ssh_port_list, lock_path)
+            prev_cluster_status, ip_list, ssh_port_list, lock_path,
+            config_hash)
         return handle
 
     def _open_ports(self, handle: CloudVmRayResourceHandle) -> None:
@@ -2958,8 +2981,8 @@ def _update_after_cluster_provisioned(
             prev_handle: Optional[CloudVmRayResourceHandle],
             task: task_lib.Task,
             prev_cluster_status: Optional[status_lib.ClusterStatus],
-            ip_list: List[str], ssh_port_list: List[int],
-            lock_path: str) -> None:
+            ip_list: List[str], ssh_port_list: List[int], lock_path: str,
+            config_hash: str) -> None:
         usage_lib.messages.usage.update_cluster_resources(
             handle.launched_nodes, handle.launched_resources)
         usage_lib.messages.usage.update_final_cluster_status(
@@ -3019,6 +3042,7 @@ def _update_after_cluster_provisioned(
                 handle,
                 set(task.resources),
                 ready=True,
+                config_hash=config_hash,
             )
             usage_lib.messages.usage.update_final_cluster_status(
                 status_lib.ClusterStatus.UP)
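
The `_provision` hunk above guards the short-circuit with a three-way comparison: the hash passed by the caller, the hash of the freshly generated config, and the hash stored in the cluster record must all agree. A small stand-alone sketch of that check (names simplified):

    from typing import Optional

    def _can_skip(requested: Optional[str], fresh: str,
                  recorded: Optional[str]) -> bool:
        if requested is None:  # Fast path was not requested.
            return False
        # Chained equality: all three hashes must be identical. This guards
        # against the cluster record changing between the caller's read and
        # this provisioning attempt.
        return requested == fresh == recorded

    assert _can_skip('abc', 'abc', 'abc')
    assert not _can_skip('abc', 'abc', 'def')
    assert not _can_skip(None, 'abc', 'abc')
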
diff --git a/sky/backends/local_docker_backend.py b/sky/backends/local_docker_backend.py
index 2cc3f3347a5..4b29931d52f 100644
--- a/sky/backends/local_docker_backend.py
+++ b/sky/backends/local_docker_backend.py
@@ -131,13 +131,14 @@ def check_resources_fit_cluster(self, handle: 'LocalDockerResourceHandle',
         pass
 
     def _provision(
-            self,
-            task: 'task_lib.Task',
-            to_provision: Optional['resources.Resources'],
-            dryrun: bool,
-            stream_logs: bool,
-            cluster_name: str,
-            retry_until_up: bool = False
+        self,
+        task: 'task_lib.Task',
+        to_provision: Optional['resources.Resources'],
+        dryrun: bool,
+        stream_logs: bool,
+        cluster_name: str,
+        retry_until_up: bool = False,
+        skip_if_config_hash_matches: Optional[str] = None
     ) -> Optional[LocalDockerResourceHandle]:
         """Builds docker image for the task and returns cluster name as handle.
 
@@ -153,6 +154,9 @@ def _provision(
             logger.warning(
                 f'Retrying until up is not supported in backend: {self.NAME}. '
                 'Ignored the flag.')
+        if skip_if_config_hash_matches is not None:
+            logger.warning(f'Config hashing is not supported in backend: '
+                           f'{self.NAME}. Ignored skip_if_config_hash_matches.')
         if stream_logs:
             logger.info(
                 'Streaming build logs is not supported in LocalDockerBackend. '
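
LocalDockerBackend accepts the new parameter only for interface compatibility and warns that it is ignored, mirroring the warn-and-ignore convention it already applies to `retry_until_up`. A toy illustration of the pattern (the class and names below are hypothetical, not from this diff):

    import logging
    from typing import Optional

    logger = logging.getLogger(__name__)

    class ToyBackend:
        NAME = 'toy'

        def _provision(self,
                       retry_until_up: bool = False,
                       skip_if_config_hash_matches: Optional[str] = None):
            # Accept the argument so callers can treat all backends
            # uniformly, but warn that this backend cannot honor it.
            if skip_if_config_hash_matches is not None:
                logger.warning('Config hashing is not supported in backend: '
                               f'{self.NAME}. Ignoring the value.')
            # ... proceed with normal provisioning ...
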
diff --git a/sky/execution.py b/sky/execution.py
index 8fab5e583fb..70d2a4c5993 100644
--- a/sky/execution.py
+++ b/sky/execution.py
@@ -108,6 +108,7 @@ def _execute(
     idle_minutes_to_autostop: Optional[int] = None,
     no_setup: bool = False,
     clone_disk_from: Optional[str] = None,
+    skip_unnecessary_provisioning: bool = False,
     # Internal only:
     # pylint: disable=invalid-name
     _is_launched_by_jobs_controller: bool = False,
@@ -128,8 +129,9 @@ def _execute(
       Note that if errors occur during provisioning/data syncing/setting up,
       the cluster will not be torn down for debugging purposes.
     stream_logs: bool; whether to stream all tasks' outputs to the client.
-    handle: Optional[backends.ResourceHandle]; if provided, execution will use
-      an existing backend cluster handle instead of provisioning a new one.
+    handle: Optional[backends.ResourceHandle]; if provided, execution will
+      attempt to use an existing backend cluster handle instead of
+      provisioning a new one.
     backend: Backend; backend to use for executing the tasks. Defaults to
       CloudVmRayBackend()
     retry_until_up: bool; whether to retry the provisioning until the cluster
@@ -150,6 +152,11 @@ def _execute(
     idle_minutes_to_autostop: int; if provided, the cluster will be set to
       autostop after this many minutes of idleness.
     no_setup: bool; whether to skip setup commands or not when (re-)launching.
+    clone_disk_from: Optional[str]; if set, clone the disk from the specified
+      cluster.
+    skip_unnecessary_provisioning: bool; if True, compare the calculated
+      cluster config to the current cluster's config. If they match, shortcut
+      provisioning even if we have Stage.PROVISION.
 
   Returns:
     job_id: Optional[int]; the job ID of the submitted job. None if the
@@ -179,9 +186,13 @@ def _execute(
                     f'{colorama.Style.RESET_ALL}')
 
     cluster_exists = False
+    existing_config_hash = None
     if cluster_name is not None:
         cluster_record = global_user_state.get_cluster_from_name(cluster_name)
-        cluster_exists = cluster_record is not None
+        if cluster_record is not None:
+            cluster_exists = True
+            if skip_unnecessary_provisioning:
+                existing_config_hash = cluster_record['config_hash']
         # TODO(woosuk): If the cluster exists, print a warning that
         # `cpus` and `memory` are not used as a job scheduling constraint,
         # unlike `gpus`.
@@ -279,13 +290,18 @@ def _execute(
 
     try:
         if Stage.PROVISION in stages:
-            if handle is None:
-                handle = backend.provision(task,
-                                           task.best_resources,
-                                           dryrun=dryrun,
-                                           stream_logs=stream_logs,
-                                           cluster_name=cluster_name,
-                                           retry_until_up=retry_until_up)
+            assert handle is None or skip_unnecessary_provisioning, (
+                'Provisioning requested, but handle is already set. PROVISION '
+                'should be excluded from stages or '
+                'skip_unnecessary_provisioning should be set.')
+            handle = backend.provision(
+                task,
+                task.best_resources,
+                dryrun=dryrun,
+                stream_logs=stream_logs,
+                cluster_name=cluster_name,
+                retry_until_up=retry_until_up,
+                skip_if_config_hash_matches=existing_config_hash)
 
         if handle is None:
             assert dryrun, ('If not dryrun, handle must be set or '
@@ -459,6 +475,7 @@ def launch(
 
     handle = None
     stages = None
+    skip_unnecessary_provisioning = False
     # Check if cluster exists and we are doing fast provisioning
     if fast and cluster_name is not None:
         maybe_handle = global_user_state.get_handle_from_cluster_name(
@@ -472,14 +489,17 @@ def launch(
                     check_cloud_vm_ray_backend=False,
                     dryrun=dryrun)
                 handle = maybe_handle
-                # Get all stages
                 stages = [
+                    # Provisioning will be short-circuited if the existing
+                    # cluster config hash matches the calculated one.
+                    Stage.PROVISION,
                     Stage.SYNC_WORKDIR,
                     Stage.SYNC_FILE_MOUNTS,
                     Stage.PRE_EXEC,
                     Stage.EXEC,
                     Stage.DOWN,
                 ]
+                skip_unnecessary_provisioning = True
             except exceptions.ClusterNotUpError:
                 # Proceed with normal provisioning
                 pass
@@ -500,6 +520,7 @@ def launch(
         idle_minutes_to_autostop=idle_minutes_to_autostop,
         no_setup=no_setup,
         clone_disk_from=clone_disk_from,
+        skip_unnecessary_provisioning=skip_unnecessary_provisioning,
        _is_launched_by_jobs_controller=_is_launched_by_jobs_controller,
         _is_launched_by_sky_serve_controller=
         _is_launched_by_sky_serve_controller,
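
In `launch()`, the fast path keeps `Stage.PROVISION` in the stage list but sets `skip_unnecessary_provisioning`, so `_execute` still calls `backend.provision` and relies on the config-hash check to make it a no-op when nothing changed. A condensed sketch of that control flow, with the surrounding machinery stubbed out and stage names as plain strings:

    from typing import List, Optional, Tuple

    def _plan_stages(fast: bool,
                     cluster_is_up: bool) -> Tuple[Optional[List[str]], bool]:
        stages: Optional[List[str]] = None  # None -> use the default stages.
        skip_unnecessary_provisioning = False
        if fast and cluster_is_up:
            # PROVISION stays in the list; the backend short-circuits it
            # when the stored config hash matches the newly computed one.
            stages = ['PROVISION', 'SYNC_WORKDIR', 'SYNC_FILE_MOUNTS',
                      'PRE_EXEC', 'EXEC', 'DOWN']
            skip_unnecessary_provisioning = True
        return stages, skip_unnecessary_provisioning
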
diff --git a/sky/global_user_state.py b/sky/global_user_state.py
index 7c040ea55fc..0f3844a185e 100644
--- a/sky/global_user_state.py
+++ b/sky/global_user_state.py
@@ -60,7 +60,8 @@ def create_table(cursor, conn):
         owner TEXT DEFAULT null,
         cluster_hash TEXT DEFAULT null,
         storage_mounts_metadata BLOB DEFAULT null,
-        cluster_ever_up INTEGER DEFAULT 0)""")
+        cluster_ever_up INTEGER DEFAULT 0,
+        config_hash TEXT DEFAULT null)""")
 
     # Table for Cluster History
     # usage_intervals: List[Tuple[int, int]]
@@ -130,6 +131,10 @@ def create_table(cursor, conn):
         # clusters were never really UP, setting it to 1 means they won't be
         # auto-deleted during any failover.
         value_to_replace_existing_entries=1)
+
+    db_utils.add_column_to_table(cursor, conn, 'clusters', 'config_hash',
+                                 'TEXT DEFAULT null')
+
     conn.commit()
 
 
@@ -140,7 +145,8 @@ def add_or_update_cluster(cluster_name: str,
                           cluster_handle: 'backends.ResourceHandle',
                           requested_resources: Optional[Set[Any]],
                           ready: bool,
-                          is_launch: bool = True):
+                          is_launch: bool = True,
+                          config_hash: Optional[str] = None):
     """Adds or updates cluster_name -> cluster_handle mapping.
 
     Args:
@@ -191,7 +197,7 @@ def add_or_update_cluster(cluster_name: str,
         # specified.
         '(name, launched_at, handle, last_use, status, '
         'autostop, to_down, metadata, owner, cluster_hash, '
-        'storage_mounts_metadata, cluster_ever_up) '
+        'storage_mounts_metadata, cluster_ever_up, config_hash) '
         'VALUES ('
         # name
         '?, '
@@ -228,7 +234,9 @@ def add_or_update_cluster(cluster_name: str,
         'COALESCE('
         '(SELECT storage_mounts_metadata FROM clusters WHERE name=?), null), '
         # cluster_ever_up
-        '((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?)'
+        '((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?),'
+        # config_hash
+        'COALESCE(?, (SELECT config_hash FROM clusters WHERE name=?))'
         ')',
         (
             # name
@@ -260,6 +268,9 @@ def add_or_update_cluster(cluster_name: str,
             # cluster_ever_up
             cluster_name,
             int(ready),
+            # config_hash
+            config_hash,
+            cluster_name,
         ))
 
     launched_nodes = getattr(cluster_handle, 'launched_nodes', None)
@@ -570,15 +581,18 @@ def _load_storage_mounts_metadata(
 
 
 def get_cluster_from_name(
        cluster_name: Optional[str]) -> Optional[Dict[str, Any]]:
-    rows = _DB.cursor.execute('SELECT * FROM clusters WHERE name=(?)',
-                              (cluster_name,)).fetchall()
+    rows = _DB.cursor.execute(
+        'SELECT name, launched_at, handle, last_use, status, autostop, '
+        'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, '
+        'cluster_ever_up, config_hash FROM clusters WHERE name=(?)',
+        (cluster_name,)).fetchall()
     for row in rows:
         # Explicitly specify the number of fields to unpack, so that
         # we can add new fields to the database in the future without
         # breaking the previous code.
         (name, launched_at, handle, last_use, status, autostop, metadata,
-         to_down, owner, cluster_hash, storage_mounts_metadata,
-         cluster_ever_up) = row[:12]
+         to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up,
+         config_hash) = row[:13]
         # TODO: use namedtuple instead of dict
         record = {
             'name': name,
@@ -594,6 +608,7 @@ def get_cluster_from_name(
             'storage_mounts_metadata':
                 _load_storage_mounts_metadata(storage_mounts_metadata),
             'cluster_ever_up': bool(cluster_ever_up),
+            'config_hash': config_hash,
         }
         return record
     return None
@@ -601,12 +616,15 @@ def get_cluster_from_name(
 
 
 def get_clusters() -> List[Dict[str, Any]]:
     rows = _DB.cursor.execute(
-        'select * from clusters order by launched_at desc').fetchall()
+        'select name, launched_at, handle, last_use, status, autostop, '
+        'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, '
+        'cluster_ever_up, config_hash from clusters order by launched_at desc'
+    ).fetchall()
     records = []
     for row in rows:
         (name, launched_at, handle, last_use, status, autostop, metadata,
-         to_down, owner, cluster_hash, storage_mounts_metadata,
-         cluster_ever_up) = row[:12]
+         to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up,
+         config_hash) = row[:13]
         # TODO: use namedtuple instead of dict
         record = {
             'name': name,
@@ -622,6 +640,7 @@ def get_clusters() -> List[Dict[str, Any]]:
             'storage_mounts_metadata':
                 _load_storage_mounts_metadata(storage_mounts_metadata),
             'cluster_ever_up': bool(cluster_ever_up),
+            'config_hash': config_hash,
         }
         records.append(record)
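
The `COALESCE(?, (SELECT config_hash FROM clusters WHERE name=?))` expression in `add_or_update_cluster` means that passing `config_hash=None` preserves the hash already stored for the cluster rather than overwriting it with NULL. A stand-alone sqlite3 demonstration of that behavior:

    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE clusters ('
                 'name TEXT PRIMARY KEY, config_hash TEXT DEFAULT null)')
    conn.execute('INSERT INTO clusters VALUES (?, ?)', ('c1', 'hash-v1'))

    # Upsert with config_hash=None: COALESCE falls back to the stored value.
    conn.execute(
        'INSERT OR REPLACE INTO clusters (name, config_hash) VALUES '
        '(?, COALESCE(?, (SELECT config_hash FROM clusters WHERE name=?)))',
        ('c1', None, 'c1'))
    print(conn.execute('SELECT config_hash FROM clusters').fetchone())
    # -> ('hash-v1',)
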