diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index a3651bdba9a..0f55b8a7f17 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -173,6 +173,16 @@ ('available_node_types', 'ray.head.default', 'node_config', 'azure_arm_parameters', 'cloudInitSetupCommands'), ] +# These keys are expected to change when provisioning on an existing cluster, +# but they don't actually represent a change that requires re-provisioning the +# cluster. If the cluster yaml is the same except for these keys, we can safely +# skip reprovisioning. See _deterministic_cluster_yaml_hash. +_RAY_YAML_KEYS_TO_REMOVE_FOR_HASH = [ + # On first launch, availability_zones will include all possible zones. Once + # the cluster exists, it will only include the zone that the cluster is + # actually in. + ('provider', 'availability_zone'), +] def is_ip(s: str) -> bool: @@ -1087,7 +1097,7 @@ def _deterministic_cluster_yaml_hash(yaml_path: str) -> str: yaml file and all the files in the file mounts, then hash the byte sequence. The format of the byte sequence is: - 32 bytes - sha256 hash of the yaml file + 32 bytes - sha256 hash of the yaml for each file mount: file mount remote destination (UTF-8), \0 if the file mount source is a file: @@ -1111,14 +1121,29 @@ def _deterministic_cluster_yaml_hash(yaml_path: str) -> str: we construct it incrementally by using hash.update() to add new bytes. """ + # Load the yaml contents so that we can directly remove keys. + yaml_config = common_utils.read_yaml(yaml_path) + for key_list in _RAY_YAML_KEYS_TO_REMOVE_FOR_HASH: + dict_to_remove_from = yaml_config + found_key = True + for key in key_list[:-1]: + if (not isinstance(dict_to_remove_from, dict) or + key not in dict_to_remove_from): + found_key = False + break + dict_to_remove_from = dict_to_remove_from[key] + if found_key and key_list[-1] in dict_to_remove_from: + dict_to_remove_from.pop(key_list[-1]) + def _hash_file(path: str) -> bytes: return common_utils.hash_file(path, 'sha256').digest() config_hash = hashlib.sha256() - config_hash.update(_hash_file(yaml_path)) + yaml_hash = hashlib.sha256( + common_utils.dump_yaml_str(yaml_config).encode('utf-8')) + config_hash.update(yaml_hash.digest()) - yaml_config = common_utils.read_yaml(yaml_path) file_mounts = yaml_config.get('file_mounts', {}) # Remove the file mounts added by the newline. if '' in file_mounts: @@ -1126,6 +1151,11 @@ def _hash_file(path: str) -> bytes: file_mounts.pop('') for dst, src in sorted(file_mounts.items()): + if src == yaml_path: + # Skip the yaml file itself. We have already hashed a modified + # version of it. The file may include fields we don't want to hash. + continue + expanded_src = os.path.expanduser(src) config_hash.update(dst.encode('utf-8') + b'\0')