diff --git a/sky/backends/backend.py b/sky/backends/backend.py
index 10b51b06038..d5fd6f19925 100644
--- a/sky/backends/backend.py
+++ b/sky/backends/backend.py
@@ -45,20 +45,45 @@ def check_resources_fit_cluster(self, handle: _ResourceHandleType,
     @timeline.event
     @usage_lib.messages.usage.update_runtime('provision')
     def provision(
-            self,
-            task: 'task_lib.Task',
-            to_provision: Optional['resources.Resources'],
-            dryrun: bool,
-            stream_logs: bool,
-            cluster_name: Optional[str] = None,
-            retry_until_up: bool = False) -> Optional[_ResourceHandleType]:
+        self,
+        task: 'task_lib.Task',
+        to_provision: Optional['resources.Resources'],
+        dryrun: bool,
+        stream_logs: bool,
+        cluster_name: Optional[str] = None,
+        retry_until_up: bool = False,
+        skip_unnecessary_provisioning: bool = False,
+    ) -> Optional[_ResourceHandleType]:
+        """Provisions resources for the given task.
+
+        Args:
+            task: The task to provision resources for.
+            to_provision: Resource config to provision. Should only be None if
+                cluster_name refers to an existing cluster, whose resources will
+                be used.
+            dryrun: If True, don't actually provision anything.
+            stream_logs: If True, stream additional logs to console.
+            cluster_name: Name of the cluster to provision. If None, a name will
+                be auto-generated. If the name refers to an existing cluster,
+                the existing cluster will be reused and re-provisioned.
+            retry_until_up: If True, retry provisioning until resources are
+                successfully launched.
+            skip_unnecessary_provisioning: If True, compare the cluster config
+                to the existing cluster_name's config. Skip provisioning if no
+                updates are needed for the existing cluster.
+
+        Returns:
+            A ResourceHandle object for the provisioned resources, or None if
+            dryrun is True.
+        """
         if cluster_name is None:
             cluster_name = sky.backends.backend_utils.generate_cluster_name()
         usage_lib.record_cluster_name_for_current_operation(cluster_name)
         usage_lib.messages.usage.update_actual_task(task)
         with rich_utils.safe_status(ux_utils.spinner_message('Launching')):
             return self._provision(task, to_provision, dryrun, stream_logs,
-                                   cluster_name, retry_until_up)
+                                   cluster_name, retry_until_up,
+                                   skip_unnecessary_provisioning)

     @timeline.event
     @usage_lib.messages.usage.update_runtime('sync_workdir')
@@ -126,13 +151,15 @@ def register_info(self, **kwargs) -> None:
     # --- Implementations of the APIs ---

     def _provision(
-            self,
-            task: 'task_lib.Task',
-            to_provision: Optional['resources.Resources'],
-            dryrun: bool,
-            stream_logs: bool,
-            cluster_name: str,
-            retry_until_up: bool = False) -> Optional[_ResourceHandleType]:
+        self,
+        task: 'task_lib.Task',
+        to_provision: Optional['resources.Resources'],
+        dryrun: bool,
+        stream_logs: bool,
+        cluster_name: str,
+        retry_until_up: bool = False,
+        skip_unnecessary_provisioning: bool = False,
+    ) -> Optional[_ResourceHandleType]:
        raise NotImplementedError

     def _sync_workdir(self, handle: _ResourceHandleType, workdir: Path) -> None:
diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py
index 8daeedc6a96..67797c19406 100644
--- a/sky/backends/backend_utils.py
+++ b/sky/backends/backend_utils.py
@@ -3,6 +3,7 @@
 import enum
 import fnmatch
 import functools
+import hashlib
 import os
 import pathlib
 import pprint
@@ -644,11 +645,17 @@ def write_cluster_config(
         keep_launch_fields_in_existing_config: bool = True) -> Dict[str, str]:
     """Fills in cluster configuration templates and writes them out.

-    Returns: {provisioner: path to yaml, the provisioning spec}.
- 'provisioner' can be - - 'ray' - - 'tpu-create-script' (if TPU is requested) - - 'tpu-delete-script' (if TPU is requested) + Returns: + Dict with the following keys: + - 'ray': Path to the generated Ray yaml config file + - 'cluster_name': Name of the cluster + - 'cluster_name_on_cloud': Name of the cluster as it appears in the + cloud provider + - 'config_hash': Hash of the cluster config and file mounts contents. + Can be missing if we unexpectedly failed to calculate the hash for + some reason. In that case we will continue without the optimization to + skip provisioning. + Raises: exceptions.ResourcesUnavailableError: if the region/zones requested does not appear in the catalog, or an ssh_proxy_command is specified but @@ -864,6 +871,12 @@ def write_cluster_config( if dryrun: # If dryrun, return the unfinished tmp yaml path. config_dict['ray'] = tmp_yaml_path + try: + config_dict['config_hash'] = _deterministic_cluster_yaml_hash( + tmp_yaml_path) + except Exception as e: # pylint: disable=broad-except + logger.warning(f'Failed to calculate config_hash: {e}') + logger.debug('Full exception:', exc_info=e) return config_dict _add_auth_to_cluster_config(cloud, tmp_yaml_path) @@ -886,6 +899,17 @@ def write_cluster_config( yaml_config = common_utils.read_yaml(tmp_yaml_path) config_dict['cluster_name_on_cloud'] = yaml_config['cluster_name'] + # Make sure to do this before we optimize file mounts. Optimization is + # non-deterministic, but everything else before this point should be + # deterministic. + try: + config_dict['config_hash'] = _deterministic_cluster_yaml_hash( + tmp_yaml_path) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to calculate config_hash: ' + f'{common_utils.format_exception(e)}') + logger.debug('Full exception:', exc_info=e) + # Optimization: copy the contents of source files in file_mounts to a # special dir, and upload that as the only file_mount instead. Delay # calling this optimization until now, when all source files have been @@ -994,6 +1018,115 @@ def get_ready_nodes_counts(pattern, output): return ready_head, ready_workers +@timeline.event +def _deterministic_cluster_yaml_hash(yaml_path: str) -> str: + """Hash the cluster yaml and contents of file mounts to a unique string. + + Two invocations of this function should return the same string if and only + if the contents of the yaml are the same and the file contents of all the + file_mounts specified in the yaml are the same. + + Limitations: + - This function can be expensive if the file mounts are large. (E.g. a few + seconds for ~1GB.) This should be okay since we expect that the + file_mounts in the cluster yaml (the wheel and cloud credentials) will be + small. + - Symbolic links are not explicitly handled. Some symbolic link changes may + not be detected. + + Implementation: We create a byte sequence that captures the state of the + yaml file and all the files in the file mounts, then hash the byte sequence. 
+
+    The format of the byte sequence is:
+    32 bytes - sha256 hash of the yaml file
+    for each file mount:
+      file mount remote destination (UTF-8), \0
+      if the file mount source is a file:
+        'file' encoded to UTF-8
+        32 bytes - sha256 hash of the file contents
+      if the file mount source is a directory:
+        'dir' encoded to UTF-8
+        for each directory and subdirectory within the file mount (starting from
+          the root and descending recursively):
+          name of the directory (UTF-8), \0
+          name of each subdirectory within the directory (UTF-8) terminated by \0
+          \0
+          for each file in the directory:
+            name of the file (UTF-8), \0
+            32 bytes - sha256 hash of the file contents
+          \0
+      if the file mount source is something else or does not exist, nothing
+      \0\0
+
+    Rather than constructing the whole byte sequence, which may be quite large,
+    we construct it incrementally by using hash.update() to add new bytes.
+    """
+
+    def _hash_file(path: str) -> bytes:
+        return common_utils.hash_file(path, 'sha256').digest()
+
+    config_hash = hashlib.sha256()
+
+    config_hash.update(_hash_file(yaml_path))
+
+    yaml_config = common_utils.read_yaml(yaml_path)
+    file_mounts = yaml_config.get('file_mounts', {})
+    # Remove the empty ('': '') file mount entry added by the newline.
+    if '' in file_mounts:
+        assert file_mounts[''] == '', file_mounts['']
+        file_mounts.pop('')
+
+    for dst, src in sorted(file_mounts.items()):
+        expanded_src = os.path.expanduser(src)
+        config_hash.update(dst.encode('utf-8') + b'\0')
+
+        # If the file mount source is a symlink, this should be true. In that
+        # case we hash the contents of the symlink destination.
+        if os.path.isfile(expanded_src):
+            config_hash.update('file'.encode('utf-8'))
+            config_hash.update(_hash_file(expanded_src))
+
+        # This can also be a symlink to a directory. os.walk will treat it as a
+        # normal directory and list the contents of the symlink destination.
+        elif os.path.isdir(expanded_src):
+            config_hash.update('dir'.encode('utf-8'))
+
+            # Aside from expanded_src, os.walk will list symlinks to directories
+            # but will not recurse into them.
+            for (dirpath, dirnames, filenames) in os.walk(expanded_src):
+                config_hash.update(dirpath.encode('utf-8') + b'\0')
+
+                # Note: inplace sort will also affect the traversal order of
+                # os.walk. We need it so that the os.walk order is
+                # deterministic.
+                dirnames.sort()
+                # This includes symlinks to directories. os.walk will recurse
+                # into all the directories but not the symlinks. We don't hash
+                # the link destination, so if a symlink to a directory changes,
+                # we won't notice.
+                for dirname in dirnames:
+                    config_hash.update(dirname.encode('utf-8') + b'\0')
+                config_hash.update(b'\0')
+
+                filenames.sort()
+                # This includes symlinks to files. We could hash the symlink
+                # destination itself but instead just hash the destination
+                # contents.
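+                # (If a symlink is dangling, _hash_file will raise when it
+                # tries to open it; write_cluster_config catches the error
+                # and simply proceeds without the config_hash optimization.)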
+ for filename in filenames: + config_hash.update(filename.encode('utf-8') + b'\0') + config_hash.update( + _hash_file(os.path.join(dirpath, filename))) + config_hash.update(b'\0') + + else: + logger.debug( + f'Unexpected file_mount that is not a file or dir: {src}') + + config_hash.update(b'\0\0') + + return config_hash.hexdigest() + + def get_docker_user(ip: str, cluster_config_file: str) -> str: """Find docker container username.""" ssh_credentials = ssh_credential_from_yaml(cluster_config_file) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index d00560ece23..a048e5e5ab3 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -1146,6 +1146,7 @@ def __init__( prev_cluster_status: Optional[status_lib.ClusterStatus], prev_handle: Optional['CloudVmRayResourceHandle'], prev_cluster_ever_up: bool, + prev_config_hash: Optional[str], ) -> None: assert cluster_name is not None, 'cluster_name must be specified.' self.cluster_name = cluster_name @@ -1154,6 +1155,7 @@ def __init__( self.prev_cluster_status = prev_cluster_status self.prev_handle = prev_handle self.prev_cluster_ever_up = prev_cluster_ever_up + self.prev_config_hash = prev_config_hash def __init__(self, log_dir: str, @@ -1315,8 +1317,21 @@ def _retry_zones( prev_cluster_status: Optional[status_lib.ClusterStatus], prev_handle: Optional['CloudVmRayResourceHandle'], prev_cluster_ever_up: bool, + skip_if_config_hash_matches: Optional[str], ) -> Dict[str, Any]: - """The provision retry loop.""" + """The provision retry loop. + + Returns a config_dict with the following fields: + All fields from backend_utils.write_cluster_config(). See its + docstring. + - 'provisioning_skipped': True if provisioning was short-circuited + by skip_if_config_hash_matches, False otherwise. + - 'handle': The provisioned cluster handle. + - 'provision_record': (Only if using the new skypilot provisioner) The + record returned by provisioner.bulk_provision(). + - 'resources_vars': (Only if using the new skypilot provisioner) The + resources variables given by make_deploy_resources_variables(). + """ # Get log_path name log_path = os.path.join(self.log_dir, 'provision.log') log_abs_path = os.path.abspath(log_path) @@ -1425,8 +1440,18 @@ def _retry_zones( raise exceptions.ResourcesUnavailableError( f'Failed to provision on cloud {to_provision.cloud} due to ' f'invalid cloud config: {common_utils.format_exception(e)}') + + if ('config_hash' in config_dict and + skip_if_config_hash_matches == config_dict['config_hash']): + logger.debug('Skipping provisioning of cluster with matching ' + 'config hash.') + config_dict['provisioning_skipped'] = True + return config_dict + config_dict['provisioning_skipped'] = False + if dryrun: return config_dict + cluster_config_file = config_dict['ray'] launched_resources = to_provision.copy(region=region.name) @@ -1938,8 +1963,13 @@ def provision_with_retries( to_provision_config: ToProvisionConfig, dryrun: bool, stream_logs: bool, + skip_unnecessary_provisioning: bool, ) -> Dict[str, Any]: - """Provision with retries for all launchable resources.""" + """Provision with retries for all launchable resources. + + Returns the config_dict from _retry_zones() - see its docstring for + details. 
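+
+        If skip_unnecessary_provisioning is True and the cluster's previously
+        recorded config hash matches the newly calculated one, provisioning is
+        skipped and the returned config_dict will have 'provisioning_skipped'
+        set to True.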
+ """ cluster_name = to_provision_config.cluster_name to_provision = to_provision_config.resources num_nodes = to_provision_config.num_nodes @@ -1948,6 +1978,8 @@ def provision_with_retries( prev_cluster_ever_up = to_provision_config.prev_cluster_ever_up launchable_retries_disabled = (self._dag is None or self._optimize_target is None) + skip_if_config_hash_matches = (to_provision_config.prev_config_hash if + skip_unnecessary_provisioning else None) failover_history: List[Exception] = list() @@ -1987,7 +2019,8 @@ def provision_with_retries( cloud_user_identity=cloud_user, prev_cluster_status=prev_cluster_status, prev_handle=prev_handle, - prev_cluster_ever_up=prev_cluster_ever_up) + prev_cluster_ever_up=prev_cluster_ever_up, + skip_if_config_hash_matches=skip_if_config_hash_matches) if dryrun: return config_dict except (exceptions.InvalidClusterNameError, @@ -2688,14 +2721,21 @@ def check_resources_fit_cluster( return valid_resource def _provision( - self, - task: task_lib.Task, - to_provision: Optional[resources_lib.Resources], - dryrun: bool, - stream_logs: bool, - cluster_name: str, - retry_until_up: bool = False) -> Optional[CloudVmRayResourceHandle]: - """Provisions using 'ray up'. + self, + task: task_lib.Task, + to_provision: Optional[resources_lib.Resources], + dryrun: bool, + stream_logs: bool, + cluster_name: str, + retry_until_up: bool = False, + skip_unnecessary_provisioning: bool = False, + ) -> Optional[CloudVmRayResourceHandle]: + """Provisions the cluster, or re-provisions an existing cluster. + + Use the SKYPILOT provisioner if it's supported by the cloud, otherwise + use 'ray up'. + + See also docstring for Backend.provision(). Raises: exceptions.ClusterOwnerIdentityMismatchError: if the cluster @@ -2780,7 +2820,8 @@ def _provision( rich_utils.force_update_status( ux_utils.spinner_message('Launching', log_path)) config_dict = retry_provisioner.provision_with_retries( - task, to_provision_config, dryrun, stream_logs) + task, to_provision_config, dryrun, stream_logs, + skip_unnecessary_provisioning) break except exceptions.ResourcesUnavailableError as e: # Do not remove the stopped cluster from the global state @@ -2830,11 +2871,23 @@ def _provision( record = global_user_state.get_cluster_from_name(cluster_name) return record['handle'] if record is not None else None + if config_dict['provisioning_skipped']: + # Skip further provisioning. + # In this case, we won't have certain fields in the config_dict + # ('handle', 'provision_record', 'resources_vars') + # We need to return the handle - but it should be the existing + # handle for the cluster. + record = global_user_state.get_cluster_from_name(cluster_name) + assert record is not None and record['handle'] is not None, ( + cluster_name, record) + return record['handle'] + if 'provision_record' in config_dict: # New provisioner is used here. handle = config_dict['handle'] provision_record = config_dict['provision_record'] resources_vars = config_dict['resources_vars'] + config_hash = config_dict.get('config_hash', None) # Setup SkyPilot runtime after the cluster is provisioned # 1. Wait for SSH to be ready. 
@@ -2869,7 +2922,7 @@ def _provision( self._update_after_cluster_provisioned( handle, to_provision_config.prev_handle, task, prev_cluster_status, handle.external_ips(), - handle.external_ssh_ports(), lock_path) + handle.external_ssh_ports(), lock_path, config_hash) return handle cluster_config_file = config_dict['ray'] @@ -2941,7 +2994,8 @@ def _get_zone(runner): self._update_after_cluster_provisioned( handle, to_provision_config.prev_handle, task, - prev_cluster_status, ip_list, ssh_port_list, lock_path) + prev_cluster_status, ip_list, ssh_port_list, lock_path, + config_hash) return handle def _open_ports(self, handle: CloudVmRayResourceHandle) -> None: @@ -2959,8 +3013,8 @@ def _update_after_cluster_provisioned( prev_handle: Optional[CloudVmRayResourceHandle], task: task_lib.Task, prev_cluster_status: Optional[status_lib.ClusterStatus], - ip_list: List[str], ssh_port_list: List[int], - lock_path: str) -> None: + ip_list: List[str], ssh_port_list: List[int], lock_path: str, + config_hash: str) -> None: usage_lib.messages.usage.update_cluster_resources( handle.launched_nodes, handle.launched_resources) usage_lib.messages.usage.update_final_cluster_status( @@ -3020,6 +3074,7 @@ def _update_after_cluster_provisioned( handle, set(task.resources), ready=True, + config_hash=config_hash, ) usage_lib.messages.usage.update_final_cluster_status( status_lib.ClusterStatus.UP) @@ -4318,6 +4373,7 @@ def _check_existing_cluster( # cluster is terminated (through console or auto-dwon), the record will # become None and the cluster_ever_up should be considered as False. cluster_ever_up = record is not None and record['cluster_ever_up'] + prev_config_hash = record['config_hash'] if record is not None else None logger.debug(f'cluster_ever_up: {cluster_ever_up}') logger.debug(f'record: {record}') @@ -4356,7 +4412,8 @@ def _check_existing_cluster( handle.launched_nodes, prev_cluster_status=prev_cluster_status, prev_handle=handle, - prev_cluster_ever_up=cluster_ever_up) + prev_cluster_ever_up=cluster_ever_up, + prev_config_hash=prev_config_hash) usage_lib.messages.usage.set_new_cluster() # Use the task_cloud, because the cloud in `to_provision` can be changed # later during the retry. @@ -4397,7 +4454,8 @@ def _check_existing_cluster( task.num_nodes, prev_cluster_status=None, prev_handle=None, - prev_cluster_ever_up=False) + prev_cluster_ever_up=False, + prev_config_hash=prev_config_hash) def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, file_mounts: Optional[Dict[Path, Path]]): diff --git a/sky/backends/local_docker_backend.py b/sky/backends/local_docker_backend.py index 2cc3f3347a5..c10e51e7975 100644 --- a/sky/backends/local_docker_backend.py +++ b/sky/backends/local_docker_backend.py @@ -131,13 +131,14 @@ def check_resources_fit_cluster(self, handle: 'LocalDockerResourceHandle', pass def _provision( - self, - task: 'task_lib.Task', - to_provision: Optional['resources.Resources'], - dryrun: bool, - stream_logs: bool, - cluster_name: str, - retry_until_up: bool = False + self, + task: 'task_lib.Task', + to_provision: Optional['resources.Resources'], + dryrun: bool, + stream_logs: bool, + cluster_name: str, + retry_until_up: bool = False, + skip_unnecessary_provisioning: bool = False, ) -> Optional[LocalDockerResourceHandle]: """Builds docker image for the task and returns cluster name as handle. @@ -153,6 +154,9 @@ def _provision( logger.warning( f'Retrying until up is not supported in backend: {self.NAME}. 
'
                 'Ignored the flag.')
+        if skip_unnecessary_provisioning:
+            logger.warning(f'skip_unnecessary_provisioning is not supported in '
+                           f'backend: {self.NAME}. Ignored the flag.')
         if stream_logs:
             logger.info(
                 'Streaming build logs is not supported in LocalDockerBackend. '
diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py
index 387d695d637..67c6e09b27e 100644
--- a/sky/clouds/service_catalog/common.py
+++ b/sky/clouds/service_catalog/common.py
@@ -15,6 +15,7 @@ from sky.clouds import cloud as cloud_lib
 from sky.clouds import cloud_registry
 from sky.clouds.service_catalog import constants
+from sky.utils import common_utils
 from sky.utils import rich_utils
 from sky.utils import ux_utils

@@ -69,8 +70,7 @@ def is_catalog_modified(filename: str) -> bool:
     meta_path = os.path.join(_ABSOLUTE_VERSIONED_CATALOG_DIR, '.meta', filename)
     md5_filepath = meta_path + '.md5'
     if os.path.exists(md5_filepath):
-        with open(catalog_path, 'rb') as f:
-            file_md5 = hashlib.md5(f.read()).hexdigest()
+        file_md5 = common_utils.hash_file(catalog_path, 'md5').hexdigest()
         with open(md5_filepath, 'r', encoding='utf-8') as f:
             last_md5 = f.read()
         return file_md5 != last_md5
diff --git a/sky/execution.py b/sky/execution.py
index 350a482a418..99261651144 100644
--- a/sky/execution.py
+++ b/sky/execution.py
@@ -108,6 +108,7 @@ def _execute(
     idle_minutes_to_autostop: Optional[int] = None,
     no_setup: bool = False,
     clone_disk_from: Optional[str] = None,
+    skip_unnecessary_provisioning: bool = False,
     # Internal only:
     # pylint: disable=invalid-name
     _is_launched_by_jobs_controller: bool = False,
@@ -128,8 +129,9 @@ def _execute(
       Note that if errors occur during provisioning/data syncing/setting up,
       the cluster will not be torn down for debugging purposes.
     stream_logs: bool; whether to stream all tasks' outputs to the client.
-    handle: Optional[backends.ResourceHandle]; if provided, execution will use
-      an existing backend cluster handle instead of provisioning a new one.
+    handle: Optional[backends.ResourceHandle]; if provided, execution will
+      attempt to use an existing backend cluster handle instead of
+      provisioning a new one.
     backend: Backend; backend to use for executing the tasks. Defaults to
       CloudVmRayBackend()
     retry_until_up: bool; whether to retry the provisioning until the cluster
@@ -150,6 +152,11 @@ def _execute(
     idle_minutes_to_autostop: int; if provided, the cluster will be set to
       autostop after this many minutes of idleness.
     no_setup: bool; whether to skip setup commands or not when (re-)launching.
+    clone_disk_from: Optional[str]; if set, clone the disk from the specified
+      cluster.
+    skip_unnecessary_provisioning: bool; if True, compare the calculated
+      cluster config to the current cluster's config. If they match,
+      short-circuit provisioning even if Stage.PROVISION is in stages.

   Returns:
     job_id: Optional[int]; the job ID of the submitted job. None if the
@@ -287,13 +294,18 @@ def _execute(

     try:
         if Stage.PROVISION in stages:
-            if handle is None:
-                handle = backend.provision(task,
-                                           task.best_resources,
-                                           dryrun=dryrun,
-                                           stream_logs=stream_logs,
-                                           cluster_name=cluster_name,
-                                           retry_until_up=retry_until_up)
+            assert handle is None or skip_unnecessary_provisioning, (
+                'Provisioning requested, but handle is already set. PROVISION '
+                'should be excluded from stages or '
+                'skip_unnecessary_provisioning should be set.
') + handle = backend.provision( + task, + task.best_resources, + dryrun=dryrun, + stream_logs=stream_logs, + cluster_name=cluster_name, + retry_until_up=retry_until_up, + skip_unnecessary_provisioning=skip_unnecessary_provisioning) if handle is None: assert dryrun, ('If not dryrun, handle must be set or ' @@ -467,6 +479,7 @@ def launch( handle = None stages = None + skip_unnecessary_provisioning = False # Check if cluster exists and we are doing fast provisioning if fast and cluster_name is not None: cluster_status, maybe_handle = ( @@ -500,12 +513,16 @@ def launch( if cluster_status == status_lib.ClusterStatus.UP: handle = maybe_handle stages = [ + # Provisioning will be short-circuited if the existing + # cluster config hash matches the calculated one. + Stage.PROVISION, Stage.SYNC_WORKDIR, Stage.SYNC_FILE_MOUNTS, Stage.PRE_EXEC, Stage.EXEC, Stage.DOWN, ] + skip_unnecessary_provisioning = True return _execute( entrypoint=entrypoint, @@ -523,6 +540,7 @@ def launch( idle_minutes_to_autostop=idle_minutes_to_autostop, no_setup=no_setup, clone_disk_from=clone_disk_from, + skip_unnecessary_provisioning=skip_unnecessary_provisioning, _is_launched_by_jobs_controller=_is_launched_by_jobs_controller, _is_launched_by_sky_serve_controller= _is_launched_by_sky_serve_controller, diff --git a/sky/global_user_state.py b/sky/global_user_state.py index e9f15df4f52..2a5cbc7eb3f 100644 --- a/sky/global_user_state.py +++ b/sky/global_user_state.py @@ -61,7 +61,8 @@ def create_table(cursor, conn): cluster_hash TEXT DEFAULT null, storage_mounts_metadata BLOB DEFAULT null, cluster_ever_up INTEGER DEFAULT 0, - status_updated_at INTEGER DEFAULT null)""") + status_updated_at INTEGER DEFAULT null, + config_hash TEXT DEFAULT null)""") # Table for Cluster History # usage_intervals: List[Tuple[int, int]] @@ -135,6 +136,9 @@ def create_table(cursor, conn): db_utils.add_column_to_table(cursor, conn, 'clusters', 'status_updated_at', 'INTEGER DEFAULT null') + db_utils.add_column_to_table(cursor, conn, 'clusters', 'config_hash', + 'TEXT DEFAULT null') + conn.commit() @@ -145,7 +149,8 @@ def add_or_update_cluster(cluster_name: str, cluster_handle: 'backends.ResourceHandle', requested_resources: Optional[Set[Any]], ready: bool, - is_launch: bool = True): + is_launch: bool = True, + config_hash: Optional[str] = None): """Adds or updates cluster_name -> cluster_handle mapping. Args: @@ -197,7 +202,8 @@ def add_or_update_cluster(cluster_name: str, # specified. '(name, launched_at, handle, last_use, status, ' 'autostop, to_down, metadata, owner, cluster_hash, ' - 'storage_mounts_metadata, cluster_ever_up, status_updated_at) ' + 'storage_mounts_metadata, cluster_ever_up, status_updated_at, ' + 'config_hash) ' 'VALUES (' # name '?, ' @@ -236,7 +242,9 @@ def add_or_update_cluster(cluster_name: str, # cluster_ever_up '((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?),' # status_updated_at - '?' 
+ '?,' + # config_hash + 'COALESCE(?, (SELECT config_hash FROM clusters WHERE name=?))' ')', ( # name @@ -270,6 +278,9 @@ def add_or_update_cluster(cluster_name: str, int(ready), # status_updated_at status_updated_at, + # config_hash + config_hash, + cluster_name, )) launched_nodes = getattr(cluster_handle, 'launched_nodes', None) @@ -585,15 +596,15 @@ def get_cluster_from_name( rows = _DB.cursor.execute( 'SELECT name, launched_at, handle, last_use, status, autostop, ' 'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, ' - 'cluster_ever_up, status_updated_at FROM clusters WHERE name=(?)', - (cluster_name,)).fetchall() + 'cluster_ever_up, status_updated_at, config_hash ' + 'FROM clusters WHERE name=(?)', (cluster_name,)).fetchall() for row in rows: # Explicitly specify the number of fields to unpack, so that # we can add new fields to the database in the future without # breaking the previous code. (name, launched_at, handle, last_use, status, autostop, metadata, to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up, - status_updated_at) = row[:13] + status_updated_at, config_hash) = row[:14] # TODO: use namedtuple instead of dict record = { 'name': name, @@ -610,6 +621,7 @@ def get_cluster_from_name( _load_storage_mounts_metadata(storage_mounts_metadata), 'cluster_ever_up': bool(cluster_ever_up), 'status_updated_at': status_updated_at, + 'config_hash': config_hash, } return record return None @@ -619,13 +631,13 @@ def get_clusters() -> List[Dict[str, Any]]: rows = _DB.cursor.execute( 'select name, launched_at, handle, last_use, status, autostop, ' 'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, ' - 'cluster_ever_up, status_updated_at from clusters ' - 'order by launched_at desc').fetchall() + 'cluster_ever_up, status_updated_at, config_hash ' + 'from clusters order by launched_at desc').fetchall() records = [] for row in rows: (name, launched_at, handle, last_use, status, autostop, metadata, to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up, - status_updated_at) = row[:13] + status_updated_at, config_hash) = row[:14] # TODO: use namedtuple instead of dict record = { 'name': name, @@ -642,6 +654,7 @@ def get_clusters() -> List[Dict[str, Any]]: _load_storage_mounts_metadata(storage_mounts_metadata), 'cluster_ever_up': bool(cluster_ever_up), 'status_updated_at': status_updated_at, + 'config_hash': config_hash, } records.append(record) diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 5fce435b770..3fcdd24e505 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -697,3 +697,22 @@ def truncate_long_string(s: str, max_length: int = 35) -> str: if len(prefix) < max_length: prefix += s[len(prefix):max_length] return prefix + '...' + + +def hash_file(path: str, hash_alg: str) -> 'hashlib._Hash': + # In python 3.11, hashlib.file_digest is available, but for <3.11 we have to + # do it manually. + # This implementation is simplified from the implementation in CPython. + # TODO(cooperc): Use hashlib.file_digest once we move to 3.11+. + # Beware of f.read() as some files may be larger than memory. 
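+    # Instead, read and hash in fixed-size chunks (2**18 bytes = 256 KiB) so
+    # that memory usage stays bounded regardless of file size.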
+ with open(path, 'rb') as f: + file_hash = hashlib.new(hash_alg) + buf = bytearray(2**18) + view = memoryview(buf) + while True: + size = f.readinto(buf) + if size == 0: + # EOF + break + file_hash.update(view[:size]) + return file_hash diff --git a/tests/unit_tests/test_backend_utils.py b/tests/unit_tests/test_backend_utils.py index 5da4410abb9..c9aa21567c2 100644 --- a/tests/unit_tests/test_backend_utils.py +++ b/tests/unit_tests/test_backend_utils.py @@ -22,6 +22,8 @@ return_value='~/.aws/credentials') @mock.patch('sky.backends.backend_utils._get_yaml_path_from_cluster_name', return_value='/tmp/fake/path') +@mock.patch('sky.backends.backend_utils._deterministic_cluster_yaml_hash', + return_value='fake-hash') @mock.patch('sky.utils.common_utils.fill_template') def test_write_cluster_config_w_remote_identity(mock_fill_template, *mocks) -> None:
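
For reference, a minimal usage sketch of the new `common_utils.hash_file` helper (the file path here is hypothetical). Per the TODO in the diff, the helper mirrors `hashlib.file_digest`, so on Python 3.11+ the two should agree:

```python
import hashlib
import sys

from sky.utils import common_utils

path = '/tmp/example-cluster.yaml'  # hypothetical file for illustration
digest = common_utils.hash_file(path, 'sha256').hexdigest()

# Sanity check against the stdlib implementation, available on Python 3.11+.
if sys.version_info >= (3, 11):
    with open(path, 'rb') as f:
        assert hashlib.file_digest(f, 'sha256').hexdigest() == digest
```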