diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py
index 5219c564500..afd681b6e4e 100644
--- a/sky/jobs/controller.py
+++ b/sky/jobs/controller.py
@@ -236,9 +236,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
                     task.num_nodes == 1):
                 continue
 
-            if job_status in [
-                    job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
-            ]:
+            if job_status in job_lib.JobStatus.user_code_failure_states():
                 # Add a grace period before the check of preemption to avoid
                 # false alarm for job failure.
                 time.sleep(5)
@@ -268,9 +266,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
                 if job_status is not None and not job_status.is_terminal():
                     # The multi-node job is still running, continue monitoring.
                     continue
-                elif job_status in [
-                        job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
-                ]:
+                elif job_status in job_lib.JobStatus.user_code_failure_states():
                     # The user code has probably crashed, fail immediately.
                     end_time = managed_job_utils.get_job_timestamp(
                         self._backend, cluster_name, get_end_time=True)
@@ -473,6 +469,7 @@ def start(job_id, dag_yaml, retry_until_up):
     """Start the controller."""
     controller_process = None
     cancelling = False
+    task_id = None
     try:
         _handle_signal(job_id)
         # TODO(suquark): In theory, we should make controller process a
@@ -491,6 +488,7 @@ def start(job_id, dag_yaml, retry_until_up):
     except exceptions.ManagedJobUserCancelledError:
         dag, _ = _get_dag_and_name(dag_yaml)
         task_id, _ = managed_job_state.get_latest_task_id_status(job_id)
+        assert task_id is not None, job_id
         logger.info(
             f'Cancelling managed job, job_id: {job_id}, task_id: {task_id}')
         managed_job_state.set_cancelling(
@@ -522,6 +520,7 @@ def start(job_id, dag_yaml, retry_until_up):
         logger.info(f'Cluster of managed job {job_id} has been cleaned up.')
 
     if cancelling:
+        assert task_id is not None, job_id  # Since it's set with cancelling
         managed_job_state.set_cancelled(
             job_id=job_id,
             callback_func=managed_job_utils.event_callback_func(
diff --git a/sky/jobs/state.py b/sky/jobs/state.py
index 6a0e3caeda3..cbc17353b13 100644
--- a/sky/jobs/state.py
+++ b/sky/jobs/state.py
@@ -575,10 +575,12 @@ def get_latest_task_id_status(
     id_statuses = _get_all_task_ids_statuses(job_id)
     if len(id_statuses) == 0:
         return None, None
-    task_id, status = id_statuses[-1]
-    for task_id, status in id_statuses:
-        if not status.is_terminal():
-            break
+    task_id, status = next(
+        ((tid, st) for tid, st in id_statuses if not st.is_terminal()),
+        id_statuses[-1],
+    )
+    # Unpack the tuple first, or it triggers a Pylint bug in recognizing
+    # the return type.
     return task_id, status
 
 
diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py
index f82e1132678..713ac16abb4 100644
--- a/sky/jobs/utils.py
+++ b/sky/jobs/utils.py
@@ -384,32 +384,15 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
             job_statuses = backend.get_job_status(handle, stream_logs=False)
             job_status = list(job_statuses.values())[0]
             assert job_status is not None, 'No job found.'
+            assert task_id is not None, job_id
+
             if job_status != job_lib.JobStatus.CANCELLED:
-                assert task_id is not None, job_id
-                if task_id < num_tasks - 1 and follow:
-                    # The log for the current job is finished. We need to
-                    # wait until next job to be started.
-                    logger.debug(
-                        f'INFO: Log for the current task ({task_id}) '
-                        'is finished. Waiting for the next task\'s log '
-                        'to be started.')
-                    # Add a newline to avoid the status display below
-                    # removing the last line of the task output.
-                    print()
-                    status_display.update(
-                        ux_utils.spinner_message(
-                            f'Waiting for the next task: {task_id + 1}'))
-                    status_display.start()
-                    original_task_id = task_id
-                    while True:
-                        task_id, managed_job_status = (
-                            managed_job_state.get_latest_task_id_status(
-                                job_id))
-                        if original_task_id != task_id:
-                            break
-                        time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
-                    continue
-                else:
+                if not follow:
+                    break
+
+                # Logs for retrying failed tasks.
+                if (job_status
+                        in job_lib.JobStatus.user_code_failure_states()):
                     task_specs = managed_job_state.get_task_specs(
                         job_id, task_id)
                     if task_specs.get('max_restarts_on_errors', 0) == 0:
@@ -422,15 +405,51 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
                         ux_utils.spinner_message(
                             'Waiting for next restart for the failed task'))
                     status_display.start()
-                    while True:
-                        _, managed_job_status = (
-                            managed_job_state.get_latest_task_id_status(
-                                job_id))
-                        if (managed_job_status !=
-                                managed_job_state.ManagedJobStatus.RUNNING):
-                            break
+
+                    def is_managed_job_status_updated(
+                        status: Optional[managed_job_state.ManagedJobStatus]
+                    ) -> bool:
+                        """Check if local managed job status reflects remote
+                        job failure.
+
+                        Ensures synchronization between remote cluster
+                        failure detection (JobStatus.FAILED) and controller
+                        retry logic.
+                        """
+                        return (status !=
+                                managed_job_state.ManagedJobStatus.RUNNING)
+
+                    while not is_managed_job_status_updated(
+                            managed_job_status :=
+                            managed_job_state.get_status(job_id)):
                         time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
                     continue
+
+                if task_id == num_tasks - 1:
+                    break
+
+                # The log for the current job is finished. We need to
+                # wait until next job to be started.
+                logger.debug(
+                    f'INFO: Log for the current task ({task_id}) '
+                    'is finished. Waiting for the next task\'s log '
+                    'to be started.')
+                # Add a newline to avoid the status display below
+                # removing the last line of the task output.
+                print()
+                status_display.update(
+                    ux_utils.spinner_message(
+                        f'Waiting for the next task: {task_id + 1}'))
+                status_display.start()
+                original_task_id = task_id
+                while True:
+                    task_id, managed_job_status = (
+                        managed_job_state.get_latest_task_id_status(job_id))
+                    if original_task_id != task_id:
+                        break
+                    time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
+                continue
+
             # The job can be cancelled by the user or the controller (when
             # the cluster is partially preempted).
             logger.debug(
diff --git a/sky/serve/replica_managers.py b/sky/serve/replica_managers.py
index c0e5220e779..e103d8c5f27 100644
--- a/sky/serve/replica_managers.py
+++ b/sky/serve/replica_managers.py
@@ -998,9 +998,7 @@ def _fetch_job_status(self) -> None:
                 # Re-raise the exception if it is not preempted.
                 raise
             job_status = list(job_statuses.values())[0]
-            if job_status in [
-                    job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
-            ]:
+            if job_status in job_lib.JobStatus.user_code_failure_states():
                 info.status_property.user_app_failed = True
                 serve_state.add_or_update_replica(self._service_name,
                                                   info.replica_id, info)
diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py
index dfd8332b019..deba48a3b30 100644
--- a/sky/skylet/job_lib.py
+++ b/sky/skylet/job_lib.py
@@ -12,7 +12,7 @@
 import sqlite3
 import subprocess
 import time
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Sequence
 
 import colorama
 import filelock
@@ -162,13 +162,17 @@ class JobStatus(enum.Enum):
     def nonterminal_statuses(cls) -> List['JobStatus']:
         return [cls.INIT, cls.SETTING_UP, cls.PENDING, cls.RUNNING]
 
-    def is_terminal(self):
+    def is_terminal(self) -> bool:
         return self not in self.nonterminal_statuses()
 
-    def __lt__(self, other):
+    @classmethod
+    def user_code_failure_states(cls) -> Sequence['JobStatus']:
+        return (cls.FAILED, cls.FAILED_SETUP)
+
+    def __lt__(self, other: 'JobStatus') -> bool:
         return list(JobStatus).index(self) < list(JobStatus).index(other)
 
-    def colored_str(self):
+    def colored_str(self) -> str:
         color = _JOB_STATUS_TO_COLOR[self]
         return f'{color}{self.value}{colorama.Style.RESET_ALL}'
 
diff --git a/tests/test_smoke.py b/tests/test_smoke.py
index 6ba81ce68f0..6e19c490409 100644
--- a/tests/test_smoke.py
+++ b/tests/test_smoke.py
@@ -3283,6 +3283,33 @@ def test_managed_jobs_recovery_multi_node_gcp():
     run_one_test(test)
 
 
+@pytest.mark.managed_jobs
+def test_managed_jobs_retry_logs():
+    """Test that managed job retry logs are properly displayed when a task fails."""
+    name = _get_cluster_name()
+    yaml_path = 'tests/test_yamls/test_managed_jobs_retry.yaml'
+
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.log') as log_file:
+        test = Test(
+            'managed_jobs_retry_logs',
+            [
+                f'sky jobs launch -n {name} {yaml_path} -y -d',
+                f'sky jobs logs -n {name} | tee {log_file.name}',
+                # First attempt
+                f'cat {log_file.name} | grep "Job started. Streaming logs..."',
+                f'cat {log_file.name} | grep "Job 1 failed"',
+                # Second attempt
+                f'cat {log_file.name} | grep "Job started. Streaming logs..." | wc -l | grep 2',
+                f'cat {log_file.name} | grep "Job 1 failed" | wc -l | grep 2',
+                # Task 2 is not reached
+                f'! cat {log_file.name} | grep "Job 2"',
+            ],
+            f'sky jobs cancel -y -n {name}',
+            timeout=7 * 60,  # 7 mins
+        )
+        run_one_test(test)
+
+
 @pytest.mark.aws
 @pytest.mark.managed_jobs
 def test_managed_jobs_cancellation_aws(aws_config_region):
diff --git a/tests/test_yamls/test_managed_jobs_retry.yaml b/tests/test_yamls/test_managed_jobs_retry.yaml
new file mode 100644
index 00000000000..76289986386
--- /dev/null
+++ b/tests/test_yamls/test_managed_jobs_retry.yaml
@@ -0,0 +1,14 @@
+resources:
+  cpus: 2+
+  job_recovery:
+    max_restarts_on_errors: 1
+
+# Task 1: Always fails
+run: |
+  echo "Task 1 starting"
+  exit 1
+---
+# Task 2: Never reached due to Task 1 failure
+run: |
+  echo "Task 2 starting"
+  exit 0
\ No newline at end of file
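
A quick illustration of the central refactor: the call sites in `controller.py`, `utils.py`, and `replica_managers.py` now test membership in a single `JobStatus.user_code_failure_states()` helper instead of hard-coding `[FAILED, FAILED_SETUP]`. The sketch below is a standalone approximation, not the real `sky.skylet.job_lib` module; the simplified `JobStatus` enum and the `should_retry` call site are hypothetical stand-ins used only to show the pattern.

```python
import enum
from typing import Sequence


class JobStatus(enum.Enum):
    """Simplified stand-in for job_lib.JobStatus (illustration only)."""
    INIT = 'INIT'
    RUNNING = 'RUNNING'
    SUCCEEDED = 'SUCCEEDED'
    FAILED = 'FAILED'
    FAILED_SETUP = 'FAILED_SETUP'
    CANCELLED = 'CANCELLED'

    @classmethod
    def user_code_failure_states(cls) -> Sequence['JobStatus']:
        # Single source of truth for "the user's code crashed", mirroring
        # the classmethod the diff adds to job_lib.JobStatus.
        return (cls.FAILED, cls.FAILED_SETUP)


def should_retry(status: JobStatus, max_restarts_on_errors: int,
                 restarts_so_far: int) -> bool:
    """Hypothetical call site: retry only user-code failures, within budget."""
    if status not in JobStatus.user_code_failure_states():
        return False
    return restarts_so_far < max_restarts_on_errors


if __name__ == '__main__':
    assert should_retry(JobStatus.FAILED_SETUP, 1, 0)
    assert not should_retry(JobStatus.CANCELLED, 1, 0)
    assert not should_retry(JobStatus.FAILED, 1, 1)  # retry budget exhausted
```

Returning a tuple (an immutable `Sequence`) keeps the helper cheap to call in the monitoring loop and prevents accidental mutation of the shared state list.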
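
The `state.py` change replaces the pre-seeded variable plus `for`/`break` loop with `next()` and a default: pick the first non-terminal task, otherwise fall back to the last entry. A minimal, self-contained sketch of that selection logic follows; the `TERMINAL` set and `latest_task` name are simplified stand-ins, not the real `ManagedJobStatus` machinery.

```python
from typing import List, Optional, Tuple

# Hypothetical, simplified status model: terminal statuses end a task.
TERMINAL = {'SUCCEEDED', 'FAILED', 'CANCELLED'}


def latest_task(id_statuses: List[Tuple[int, str]]
               ) -> Tuple[Optional[int], Optional[str]]:
    """Return the first non-terminal (task_id, status), else the last entry."""
    if not id_statuses:
        return None, None
    # next() with a default expresses "first match, or fallback" in one step,
    # the same shape as the generator expression introduced by the diff.
    task_id, status = next(
        ((tid, st) for tid, st in id_statuses if st not in TERMINAL),
        id_statuses[-1],
    )
    return task_id, status


if __name__ == '__main__':
    # Task 0 finished, task 1 still running -> task 1 is the "latest" task.
    assert latest_task([(0, 'SUCCEEDED'), (1, 'RUNNING')]) == (1, 'RUNNING')
    # Everything terminal -> fall back to the last task.
    assert latest_task([(0, 'SUCCEEDED'), (1, 'FAILED')]) == (1, 'FAILED')
    assert latest_task([]) == (None, None)
```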
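
The reworked wait loop in `utils.py` uses an assignment expression (`:=`) so the fetched status is both tested by the predicate and kept in scope after the loop. Below is a generic sketch of that polling idiom under stated assumptions: `poll_until` and the simulated status source are hypothetical names, and Python 3.8+ is assumed for the walrus operator.

```python
import time
from typing import Callable, TypeVar

T = TypeVar('T')


def poll_until(fetch: Callable[[], T],
               is_done: Callable[[T], bool],
               interval_seconds: float = 1.0) -> T:
    """Fetch repeatedly until the predicate accepts the value, then return it.

    Mirrors the diff's
    `while not is_managed_job_status_updated(status := get_status(job_id)):`
    loop: the assignment expression keeps the last fetched value available.
    """
    while not is_done(value := fetch()):
        time.sleep(interval_seconds)
    return value


if __name__ == '__main__':
    # Simulated controller-side status: RUNNING twice, then FAILED.
    statuses = iter(['RUNNING', 'RUNNING', 'FAILED'])
    final = poll_until(lambda: next(statuses),
                       is_done=lambda s: s != 'RUNNING',
                       interval_seconds=0.01)
    assert final == 'FAILED'
```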