From 59971bb1ed08c318aa49330db16fcbecb4f7ca62 Mon Sep 17 00:00:00 2001 From: Christopher Cooper Date: Tue, 10 Dec 2024 17:32:25 -0800 Subject: [PATCH 1/3] detach the managed job controller from job submission Previously, the ray driver program as well as a ray worker stayed in use for the entire runtime of a managed job. Now, the job controller will detach from the submitted job/ray driver and continue running in the background. This means we have to manually manage logging as well as liveness of the controller process. Two new directories are introduced for this purpose as well as plumbing. --- sky/jobs/constants.py | 2 + sky/jobs/core.py | 4 ++ sky/jobs/utils.py | 90 ++++++++++++++++++++++++++- sky/templates/jobs-controller.yaml.j2 | 9 ++- 4 files changed, 101 insertions(+), 4 deletions(-) diff --git a/sky/jobs/constants.py b/sky/jobs/constants.py index d5f32908317..873822e9291 100644 --- a/sky/jobs/constants.py +++ b/sky/jobs/constants.py @@ -2,6 +2,8 @@ JOBS_CONTROLLER_TEMPLATE = 'jobs-controller.yaml.j2' JOBS_CONTROLLER_YAML_PREFIX = '~/.sky/jobs_controller' +JOBS_CONTROLLER_PID_FILE_DIR = '~/.sky/jobs_controller_pids' +JOBS_CONTROLLER_LOGS_DIR = '~/sky_controller_logs' JOBS_TASK_YAML_PREFIX = '~/.sky/managed_jobs' diff --git a/sky/jobs/core.py b/sky/jobs/core.py index 9cde3443816..50b04bac3bb 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -119,6 +119,10 @@ def launch( 'remote_user_config_path': remote_user_config_path, 'modified_catalogs': service_catalog_common.get_modified_catalog_file_mounts(), + 'controller_pid_file_dir': + managed_job_constants.JOBS_CONTROLLER_PID_FILE_DIR, + 'controller_logs_dir': + managed_job_constants.JOBS_CONTROLLER_LOGS_DIR, **controller_utils.shared_controller_vars_to_fill( controller_utils.Controllers.JOBS_CONTROLLER, remote_user_config_path=remote_user_config_path, diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 267c205285b..13d07fc2cbb 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -9,6 +9,7 @@ import inspect import os import pathlib +import psutil import shlex import shutil import textwrap @@ -119,8 +120,38 @@ def update_managed_job_status(job_id: Optional[int] = None): else: job_ids = [job_id] for job_id_ in job_ids: - controller_status = job_lib.get_status(job_id_) - if controller_status is None or controller_status.is_terminal(): + submission_job_status = job_lib.get_status(job_id_) + if submission_job_status is None or submission_job_status.is_terminal(): + if submission_job_status == job_lib.JobStatus.SUCCEEDED: + logger.debug( + f'Job {job_id_} is already {submission_job_status}.') + # This is expected, since the submitted job will detach the + # controller and succeed, even if the controller is still + # running. Check the controller status directly. + pid_file = os.path.join( + os.path.expanduser( + managed_job_constants.JOBS_CONTROLLER_PID_FILE_DIR), + str(job_id_)) + try: + with open(pid_file, 'r') as f: + pid = int(f.read()) + logger.debug(f'Checking controller pid {pid}') + if psutil.Process(pid).is_running(): + # The controller is still running. + continue + # Otherwise, proceed to mark the job as failed. + except FileNotFoundError: + logger.debug('Submission succeeded but controller pid ' + f'file {pid_file} not found.') + # Proceed to mark the job as failed. + except ValueError: + logger.debug(f'Failed to parse the controller pid from ' + f'{pid_file}.') + # Proceed to mark the job as failed. + except psutil.NoSuchProcess: + logger.debug('Controller process not found.') + # Proceed to mark the job as failed. 
+ logger.error(f'Controller for job {job_id_} has exited abnormally. ' 'Setting the job status to FAILED_CONTROLLER.') tasks = managed_job_state.get_managed_jobs(job_id_) @@ -527,6 +558,7 @@ def stream_logs(job_id: Optional[int], 'instead.') job_id = managed_job_ids.pop() assert job_id is not None, (job_id, job_name) + # TODO: keep the following code sync with # job_lib.JobLibCodeGen.tail_logs, we do not directly call that function # as the following code need to be run in the current machine, instead @@ -536,6 +568,59 @@ def stream_logs(job_id: Optional[int], return f'No managed job contrller log found with job_id {job_id}.' log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, run_timestamp) log_lib.tail_logs(job_id=job_id, log_dir=log_dir, follow=follow) + + controller_log_path = os.path.join( + os.path.expanduser(managed_job_constants.JOBS_CONTROLLER_LOGS_DIR), + f'{job_id}.log') + + # Wait for the log file to be written + while not os.path.exists(controller_log_path): + if not follow: + # Assume that the log file hasn't been written yet. Since we + # aren't following, just return. + return '' + + job_status = managed_job_state.get_status(job_id) + # We know that the job is present in the state table because of + # earlier checks, so it should not be None. + assert job_status is not None, (job_id, job_name) + if job_status.is_terminal(): + # Don't keep waiting. If the log file is not created by this + # point, it never will be. This job may have been submitted + # using an old version that did not create the log file, so this + # is not considered an exceptional case. + return '' + + time.sleep(log_lib._SKY_LOG_WAITING_GAP_SECONDS) + + # See also log_lib.tail_logs. + with open(controller_log_path, 'r', newline='', encoding='utf-8') as f: + # Note: we do not need to care about start_stream_at here, since + # that should be in the job log printed above. + for line in f: + print(line, end='') + # Flush. + print(end='', flush=True) + + if follow: + while True: + line = f.readline() + if line is not None and line != '': + print(line, end='', flush=True) + else: + job_status = managed_job_state.get_status(job_id) + assert job_status is not None, (job_id, job_name) + if job_status.is_terminal(): + break + + time.sleep(log_lib._SKY_LOG_TAILING_GAP_SECONDS) + + # Wait for final logs to be written. + time.sleep(1 + log_lib._SKY_LOG_TAILING_GAP_SECONDS) + + # Print any remaining logs including incomplete line. + print(f.read(), end='', flush=True) + return '' if job_id is None: @@ -868,6 +953,7 @@ def stream_logs(cls, # should be removed in v0.8.0. code = textwrap.dedent("""\ import os + import time from sky.skylet import job_lib, log_lib from sky.skylet import constants diff --git a/sky/templates/jobs-controller.yaml.j2 b/sky/templates/jobs-controller.yaml.j2 index 45cdb5141d4..7367bc9a55e 100644 --- a/sky/templates/jobs-controller.yaml.j2 +++ b/sky/templates/jobs-controller.yaml.j2 @@ -33,9 +33,14 @@ setup: | run: | {{ sky_activate_python_env }} + mkdir -p {{controller_logs_dir}} + mkdir -p {{controller_pid_file_dir}} # Start the controller for the current managed job. 
-  python -u -m sky.jobs.controller {{remote_user_yaml_path}} \
-    --job-id $SKYPILOT_INTERNAL_JOB_ID {% if retry_until_up %}--retry-until-up{% endif %}
+  nohup python -u -m sky.jobs.controller {{remote_user_yaml_path}} \
+    --job-id $SKYPILOT_INTERNAL_JOB_ID {% if retry_until_up %}--retry-until-up{% endif %} \
+    > {{controller_logs_dir}}/$SKYPILOT_INTERNAL_JOB_ID.log 2>&1 &
+  # Write the controller PID file so that its liveness can be checked later.
+  echo $! > {{controller_pid_file_dir}}/$SKYPILOT_INTERNAL_JOB_ID
 
 envs:
 {%- for env_name, env_value in controller_envs.items() %}

From 50a3dce9905004fa59615d87f9f0efc9468980e5 Mon Sep 17 00:00:00 2001
From: Christopher Cooper
Date: Tue, 10 Dec 2024 17:46:12 -0800
Subject: [PATCH 2/3] fix lint

---
 sky/jobs/utils.py      | 10 +++++-----
 sky/skylet/log_lib.py  | 16 ++++++++--------
 sky/skylet/log_lib.pyi |  3 +++
 3 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py
index 13d07fc2cbb..bef25f821c9 100644
--- a/sky/jobs/utils.py
+++ b/sky/jobs/utils.py
@@ -9,7 +9,6 @@
 import inspect
 import os
 import pathlib
-import psutil
 import shlex
 import shutil
 import textwrap
@@ -19,6 +18,7 @@
 
 import colorama
 import filelock
+import psutil
 from typing_extensions import Literal
 
 from sky import backends
@@ -133,7 +133,7 @@ def update_managed_job_status(job_id: Optional[int] = None):
                     managed_job_constants.JOBS_CONTROLLER_PID_FILE_DIR),
                 str(job_id_))
             try:
-                with open(pid_file, 'r') as f:
+                with open(pid_file, 'r', encoding='utf-8') as f:
                     pid = int(f.read())
                 logger.debug(f'Checking controller pid {pid}')
                 if psutil.Process(pid).is_running():
@@ -591,7 +591,7 @@ def stream_logs(job_id: Optional[int],
                 # is not considered an exceptional case.
                 return ''
 
-            time.sleep(log_lib._SKY_LOG_WAITING_GAP_SECONDS)
+            time.sleep(log_lib.SKY_LOG_WAITING_GAP_SECONDS)
 
         # See also log_lib.tail_logs.
         with open(controller_log_path, 'r', newline='', encoding='utf-8') as f:
@@ -613,10 +613,10 @@ def stream_logs(job_id: Optional[int],
                         if job_status.is_terminal():
                             break
 
-                    time.sleep(log_lib._SKY_LOG_TAILING_GAP_SECONDS)
+                    time.sleep(log_lib.SKY_LOG_TAILING_GAP_SECONDS)
 
                 # Wait for final logs to be written.
-                time.sleep(1 + log_lib._SKY_LOG_TAILING_GAP_SECONDS)
+                time.sleep(1 + log_lib.SKY_LOG_TAILING_GAP_SECONDS)
 
                 # Print any remaining logs including incomplete line.
                 print(f.read(), end='', flush=True)
diff --git a/sky/skylet/log_lib.py b/sky/skylet/log_lib.py
index 8a40982972a..ac2b488baf0 100644
--- a/sky/skylet/log_lib.py
+++ b/sky/skylet/log_lib.py
@@ -25,9 +25,9 @@
 from sky.utils import subprocess_utils
 from sky.utils import ux_utils
 
-_SKY_LOG_WAITING_GAP_SECONDS = 1
-_SKY_LOG_WAITING_MAX_RETRY = 5
-_SKY_LOG_TAILING_GAP_SECONDS = 0.2
+SKY_LOG_WAITING_GAP_SECONDS = 1
+SKY_LOG_WAITING_MAX_RETRY = 5
+SKY_LOG_TAILING_GAP_SECONDS = 0.2
 # Peek the head of the lines to check if we need to start
 # streaming when tail > 0.
 PEEK_HEAD_LINES_FOR_START_STREAM = 20
@@ -336,7 +336,7 @@ def _follow_job_logs(file,
             ]:
                 if wait_last_logs:
                     # Wait all the logs are printed before exit.
- time.sleep(1 + _SKY_LOG_TAILING_GAP_SECONDS) + time.sleep(1 + SKY_LOG_TAILING_GAP_SECONDS) wait_last_logs = False continue status_str = status.value if status is not None else 'None' @@ -345,7 +345,7 @@ def _follow_job_logs(file, f'Job finished (status: {status_str}).')) return - time.sleep(_SKY_LOG_TAILING_GAP_SECONDS) + time.sleep(SKY_LOG_TAILING_GAP_SECONDS) status = job_lib.get_status_no_lock(job_id) @@ -426,15 +426,15 @@ def tail_logs(job_id: Optional[int], retry_cnt += 1 if os.path.exists(log_path) and status != job_lib.JobStatus.INIT: break - if retry_cnt >= _SKY_LOG_WAITING_MAX_RETRY: + if retry_cnt >= SKY_LOG_WAITING_MAX_RETRY: print( f'{colorama.Fore.RED}ERROR: Logs for ' f'{job_str} (status: {status.value}) does not exist ' f'after retrying {retry_cnt} times.{colorama.Style.RESET_ALL}') return - print(f'INFO: Waiting {_SKY_LOG_WAITING_GAP_SECONDS}s for the logs ' + print(f'INFO: Waiting {SKY_LOG_WAITING_GAP_SECONDS}s for the logs ' 'to be written...') - time.sleep(_SKY_LOG_WAITING_GAP_SECONDS) + time.sleep(SKY_LOG_WAITING_GAP_SECONDS) status = job_lib.update_job_status([job_id], silent=True)[0] start_stream_at = LOG_FILE_START_STREAMING_AT diff --git a/sky/skylet/log_lib.pyi b/sky/skylet/log_lib.pyi index 89d1628ec11..c7028e121aa 100644 --- a/sky/skylet/log_lib.pyi +++ b/sky/skylet/log_lib.pyi @@ -13,6 +13,9 @@ from sky.skylet import constants as constants from sky.skylet import job_lib as job_lib from sky.utils import log_utils as log_utils +SKY_LOG_WAITING_GAP_SECONDS: int = ... +SKY_LOG_WAITING_MAX_RETRY: int = ... +SKY_LOG_TAILING_GAP_SECONDS: float = ... LOG_FILE_START_STREAMING_AT: str = ... From fa6d8faad8862b580b8784e82246c341dbd397ce Mon Sep 17 00:00:00 2001 From: Christopher Cooper Date: Thu, 12 Dec 2024 18:21:04 -0800 Subject: [PATCH 3/3] limit parallelism --- sky/jobs/controller.py | 46 +++++++++++++++++++++++++++++------------- sky/jobs/semaphore.py | 46 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 14 deletions(-) create mode 100644 sky/jobs/semaphore.py diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py index 72dce3e50d7..49d87fb4822 100644 --- a/sky/jobs/controller.py +++ b/sky/jobs/controller.py @@ -9,6 +9,7 @@ from typing import Optional, Tuple import filelock +import psutil from sky import exceptions from sky import sky_logging @@ -18,6 +19,7 @@ from sky.jobs import recovery_strategy from sky.jobs import state as managed_job_state from sky.jobs import utils as managed_job_utils +from sky.jobs import semaphore from sky.skylet import constants from sky.skylet import job_lib from sky.usage import usage_lib @@ -30,6 +32,9 @@ if typing.TYPE_CHECKING: import sky +_JOB_SEMAPHORE_LOCK_DIR = os.path.expanduser('~/.sky/job_semaphore') +_JOB_LAUNCH_SEMAPHORE_LOCK_DIR = os.path.expanduser('~/.sky/job_launch_semaphore') + # Use the explicit logger name so that the logger is under the # `sky.jobs.controller` namespace when executed directly, so as # to inherit the setup from the `sky` logger. 
@@ -191,17 +196,19 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: f'Submitted managed job {self._job_id} (task: {task_id}, name: ' f'{task.name!r}); {constants.TASK_ID_ENV_VAR}: {task_id_env_var}') - logger.info('Started monitoring.') - managed_job_state.set_starting(job_id=self._job_id, - task_id=task_id, - callback_func=callback_func) - remote_job_submitted_at = self._strategy_executor.launch() - assert remote_job_submitted_at is not None, remote_job_submitted_at - - managed_job_state.set_started(job_id=self._job_id, - task_id=task_id, - start_time=remote_job_submitted_at, - callback_func=callback_func) + with semaphore.FileLockSemaphore(lock_dir_path=_JOB_LAUNCH_SEMAPHORE_LOCK_DIR, lock_count=_get_launch_parallelism()): + logger.info('Started monitoring.') + managed_job_state.set_starting(job_id=self._job_id, + task_id=task_id, + callback_func=callback_func) + remote_job_submitted_at = self._strategy_executor.launch() + assert remote_job_submitted_at is not None, remote_job_submitted_at + + managed_job_state.set_started(job_id=self._job_id, + task_id=task_id, + start_time=remote_job_submitted_at, + callback_func=callback_func) + while True: time.sleep(managed_job_utils.JOB_STATUS_CHECK_GAP_SECONDS) @@ -426,7 +433,14 @@ def _update_failed_task_state( job_id=self._job_id, task_id=task_id, task=self._dag.tasks[task_id])) + +def _get_job_parallelism() -> int: + # Assume a running job uses 400MB memory. + job_memory = 400 * 1024 * 1024 + return psutil.virtual_memory().total // job_memory +def _get_launch_parallelism() -> int: + return os.cpu_count() * 4 def _run_controller(job_id: int, dag_yaml: str, retry_until_up: bool): """Runs the controller in a remote process for interruption.""" @@ -563,8 +577,7 @@ def start(job_id, dag_yaml, retry_until_up): failure_reason=('Unexpected error occurred. For details, ' f'run: sky jobs logs --controller {job_id}')) - -if __name__ == '__main__': +def main(): parser = argparse.ArgumentParser() parser.add_argument('--job-id', required=True, @@ -580,4 +593,9 @@ def start(job_id, dag_yaml, retry_until_up): # We start process with 'spawn', because 'fork' could result in weird # behaviors; 'spawn' is also cross-platform. multiprocessing.set_start_method('spawn', force=True) - start(args.job_id, args.dag_yaml, args.retry_until_up) + + with semaphore.FileLockSemaphore(lock_dir_path=_JOB_SEMAPHORE_LOCK_DIR, lock_count=_get_job_parallelism()): + start(args.job_id, args.dag_yaml, args.retry_until_up) + +if __name__ == '__main__': + main() diff --git a/sky/jobs/semaphore.py b/sky/jobs/semaphore.py new file mode 100644 index 00000000000..961eaad2d56 --- /dev/null +++ b/sky/jobs/semaphore.py @@ -0,0 +1,46 @@ +"""A file-lock based semaphore to limit parallelism of the jobs controller.""" + +import time +import filelock +import os +from typing import List + +class FileLockSemaphore: + """A cross-process semaphore-like mechanism using file locks. + + Some semaphore uses are unsupported: + - Each release() call must have a corresponding acquire(), that is, the + semaphore value cannot go above the initial value (lock_count). + - All processes must use the same lock_count. This is not enforced by the + FileLockSemaphore class. 
+ """ + def __init__(self, lock_dir_path: str, lock_count: int): + self.lock_dir_path = lock_dir_path + self.locks = [filelock.FileLock(os.path.join(lock_dir_path, f"{i}.lock")) for i in range(lock_count)] + self.acquired_locks: List[filelock.FileLock] = [] + + def acquire(self): + while True: + for lock in self.locks: + try: + lock.acquire(blocking=False) + self.acquired_locks.append(lock) + return + except filelock.Timeout: + pass + time.sleep(0.05) + + def release(self): + if self.acquired_locks: + self.acquired_locks.pop().release() + + def __enter__(self): + self.acquire() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.release() + + def __del__(self): + for lock in self.acquired_locks: + lock.release()