Skip to content

Commit

Permalink
make --fast robust against credential or wheel updates (#4289)
Browse files Browse the repository at this point in the history
* add config_dict['config_hash'] output to write_cluster_config

* fix docstring for write_cluster_config

This used to be true, but since #2943, 'ray' is the only provisioner.
Add other keys that are now present instead.

* when using --fast, check if config_hash matches, and if not, provision

* mock hashing method in unit test

This is needed since some files in the fake file mounts don't actually exist,
like the wheel path.

* check config hash within provision with lock held

* address other PR review comments

* rename to skip_if_no_cluster_updates

Co-authored-by: Zhanghao Wu <[email protected]>

* add assert details

Co-authored-by: Zhanghao Wu <[email protected]>

* address PR comments and update docstrings

* fix test

* update docstrings

Co-authored-by: Zhanghao Wu <[email protected]>

* address PR comments

* fix lint and tests

* Update sky/backends/cloud_vm_ray_backend.py

Co-authored-by: Zhanghao Wu <[email protected]>

* refactor skip_if_no_cluster_update var

* clarify comment

* format exception

---------

Co-authored-by: Zhanghao Wu <[email protected]>
  • Loading branch information
cg505 and Michaelvll authored Dec 4, 2024
1 parent 51a7e17 commit 3009204
Show file tree
Hide file tree
Showing 9 changed files with 340 additions and 66 deletions.
57 changes: 42 additions & 15 deletions sky/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,45 @@ def check_resources_fit_cluster(self, handle: _ResourceHandleType,
@timeline.event
@usage_lib.messages.usage.update_runtime('provision')
def provision(
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: Optional[str] = None,
retry_until_up: bool = False) -> Optional[_ResourceHandleType]:
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: Optional[str] = None,
retry_until_up: bool = False,
skip_unnecessary_provisioning: bool = False,
) -> Optional[_ResourceHandleType]:
"""Provisions resources for the given task.
Args:
task: The task to provision resources for.
to_provision: Resource config to provision. Should only be None if
cluster_name refers to an existing cluster, whose resources will
be used.
dryrun: If True, don't actually provision anything.
stream_logs: If True, stream additional logs to console.
cluster_name: Name of the cluster to provision. If None, a name will
be auto-generated. If the name refers to an existing cluster,
the existing cluster will be reused and re-provisioned.
retry_until_up: If True, retry provisioning until resources are
successfully launched.
skip_if_no_cluster_updates: If True, compare the cluster config to
the existing cluster_name's config. Skip provisioning if no
updates are needed for the existing cluster.
Returns:
A ResourceHandle object for the provisioned resources, or None if
dryrun is True.
"""
if cluster_name is None:
cluster_name = sky.backends.backend_utils.generate_cluster_name()
usage_lib.record_cluster_name_for_current_operation(cluster_name)
usage_lib.messages.usage.update_actual_task(task)
with rich_utils.safe_status(ux_utils.spinner_message('Launching')):
return self._provision(task, to_provision, dryrun, stream_logs,
cluster_name, retry_until_up)
cluster_name, retry_until_up,
skip_unnecessary_provisioning)

@timeline.event
@usage_lib.messages.usage.update_runtime('sync_workdir')
Expand Down Expand Up @@ -126,13 +151,15 @@ def register_info(self, **kwargs) -> None:

# --- Implementations of the APIs ---
def _provision(
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: str,
retry_until_up: bool = False) -> Optional[_ResourceHandleType]:
self,
task: 'task_lib.Task',
to_provision: Optional['resources.Resources'],
dryrun: bool,
stream_logs: bool,
cluster_name: str,
retry_until_up: bool = False,
skip_unnecessary_provisioning: bool = False,
) -> Optional[_ResourceHandleType]:
raise NotImplementedError

def _sync_workdir(self, handle: _ResourceHandleType, workdir: Path) -> None:
Expand Down
143 changes: 138 additions & 5 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import enum
import fnmatch
import functools
import hashlib
import os
import pathlib
import pprint
Expand Down Expand Up @@ -644,11 +645,17 @@ def write_cluster_config(
keep_launch_fields_in_existing_config: bool = True) -> Dict[str, str]:
"""Fills in cluster configuration templates and writes them out.
Returns: {provisioner: path to yaml, the provisioning spec}.
'provisioner' can be
- 'ray'
- 'tpu-create-script' (if TPU is requested)
- 'tpu-delete-script' (if TPU is requested)
Returns:
Dict with the following keys:
- 'ray': Path to the generated Ray yaml config file
- 'cluster_name': Name of the cluster
- 'cluster_name_on_cloud': Name of the cluster as it appears in the
cloud provider
- 'config_hash': Hash of the cluster config and file mounts contents.
Can be missing if we unexpectedly failed to calculate the hash for
some reason. In that case we will continue without the optimization to
skip provisioning.
Raises:
exceptions.ResourcesUnavailableError: if the region/zones requested does
not appear in the catalog, or an ssh_proxy_command is specified but
Expand Down Expand Up @@ -903,6 +910,12 @@ def write_cluster_config(
if dryrun:
# If dryrun, return the unfinished tmp yaml path.
config_dict['ray'] = tmp_yaml_path
try:
config_dict['config_hash'] = _deterministic_cluster_yaml_hash(
tmp_yaml_path)
except Exception as e: # pylint: disable=broad-except
logger.warning(f'Failed to calculate config_hash: {e}')
logger.debug('Full exception:', exc_info=e)
return config_dict
_add_auth_to_cluster_config(cloud, tmp_yaml_path)

Expand All @@ -925,6 +938,17 @@ def write_cluster_config(
yaml_config = common_utils.read_yaml(tmp_yaml_path)
config_dict['cluster_name_on_cloud'] = yaml_config['cluster_name']

# Make sure to do this before we optimize file mounts. Optimization is
# non-deterministic, but everything else before this point should be
# deterministic.
try:
config_dict['config_hash'] = _deterministic_cluster_yaml_hash(
tmp_yaml_path)
except Exception as e: # pylint: disable=broad-except
logger.warning('Failed to calculate config_hash: '
f'{common_utils.format_exception(e)}')
logger.debug('Full exception:', exc_info=e)

# Optimization: copy the contents of source files in file_mounts to a
# special dir, and upload that as the only file_mount instead. Delay
# calling this optimization until now, when all source files have been
Expand Down Expand Up @@ -1033,6 +1057,115 @@ def get_ready_nodes_counts(pattern, output):
return ready_head, ready_workers


@timeline.event
def _deterministic_cluster_yaml_hash(yaml_path: str) -> str:
"""Hash the cluster yaml and contents of file mounts to a unique string.
Two invocations of this function should return the same string if and only
if the contents of the yaml are the same and the file contents of all the
file_mounts specified in the yaml are the same.
Limitations:
- This function can be expensive if the file mounts are large. (E.g. a few
seconds for ~1GB.) This should be okay since we expect that the
file_mounts in the cluster yaml (the wheel and cloud credentials) will be
small.
- Symbolic links are not explicitly handled. Some symbolic link changes may
not be detected.
Implementation: We create a byte sequence that captures the state of the
yaml file and all the files in the file mounts, then hash the byte sequence.
The format of the byte sequence is:
32 bytes - sha256 hash of the yaml file
for each file mount:
file mount remote destination (UTF-8), \0
if the file mount source is a file:
'file' encoded to UTF-8
32 byte sha256 hash of the file contents
if the file mount source is a directory:
'dir' encoded to UTF-8
for each directory and subdirectory withinin the file mount (starting from
the root and descending recursively):
name of the directory (UTF-8), \0
name of each subdirectory within the directory (UTF-8) terminated by \0
\0
for each file in the directory:
name of the file (UTF-8), \0
32 bytes - sha256 hash of the file contents
\0
if the file mount source is something else or does not exist, nothing
\0\0
Rather than constructing the whole byte sequence, which may be quite large,
we construct it incrementally by using hash.update() to add new bytes.
"""

def _hash_file(path: str) -> bytes:
return common_utils.hash_file(path, 'sha256').digest()

config_hash = hashlib.sha256()

config_hash.update(_hash_file(yaml_path))

yaml_config = common_utils.read_yaml(yaml_path)
file_mounts = yaml_config.get('file_mounts', {})
# Remove the file mounts added by the newline.
if '' in file_mounts:
assert file_mounts[''] == '', file_mounts['']
file_mounts.pop('')

for dst, src in sorted(file_mounts.items()):
expanded_src = os.path.expanduser(src)
config_hash.update(dst.encode('utf-8') + b'\0')

# If the file mount source is a symlink, this should be true. In that
# case we hash the contents of the symlink destination.
if os.path.isfile(expanded_src):
config_hash.update('file'.encode('utf-8'))
config_hash.update(_hash_file(expanded_src))

# This can also be a symlink to a directory. os.walk will treat it as a
# normal directory and list the contents of the symlink destination.
elif os.path.isdir(expanded_src):
config_hash.update('dir'.encode('utf-8'))

# Aside from expanded_src, os.walk will list symlinks to directories
# but will not recurse into them.
for (dirpath, dirnames, filenames) in os.walk(expanded_src):
config_hash.update(dirpath.encode('utf-8') + b'\0')

# Note: inplace sort will also affect the traversal order of
# os.walk. We need it so that the os.walk order is
# deterministic.
dirnames.sort()
# This includes symlinks to directories. os.walk will recurse
# into all the directories but not the symlinks. We don't hash
# the link destination, so if a symlink to a directory changes,
# we won't notice.
for dirname in dirnames:
config_hash.update(dirname.encode('utf-8') + b'\0')
config_hash.update(b'\0')

filenames.sort()
# This includes symlinks to files. We could hash the symlink
# destination itself but instead just hash the destination
# contents.
for filename in filenames:
config_hash.update(filename.encode('utf-8') + b'\0')
config_hash.update(
_hash_file(os.path.join(dirpath, filename)))
config_hash.update(b'\0')

else:
logger.debug(
f'Unexpected file_mount that is not a file or dir: {src}')

config_hash.update(b'\0\0')

return config_hash.hexdigest()


def get_docker_user(ip: str, cluster_config_file: str) -> str:
"""Find docker container username."""
ssh_credentials = ssh_credential_from_yaml(cluster_config_file)
Expand Down
Loading

0 comments on commit 3009204

Please sign in to comment.