From 655b81f4c59ea6fe540d23521e33c6e7efb00bf7 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Sun, 24 Nov 2024 12:59:48 -0800
Subject: [PATCH 01/13] fix(jobs): move task retry logic to correct branch in
 `stream_logs_by_id`

---
 sky/jobs/utils.py | 45 +++++++++++++++++++++++----------------------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py
index f82e1132678..410c18341a5 100644
--- a/sky/jobs/utils.py
+++ b/sky/jobs/utils.py
@@ -384,8 +384,29 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
             job_statuses = backend.get_job_status(handle, stream_logs=False)
             job_status = list(job_statuses.values())[0]
             assert job_status is not None, 'No job found.'
+            assert task_id is not None, job_id
+            if job_status == job_lib.JobStatus.FAILED:
+                task_specs = managed_job_state.get_task_specs(
+                    job_id, task_id)
+                if task_specs.get('max_restarts_on_errors', 0) == 0:
+                    # We don't need to wait for the managed job status
+                    # update, as the job is guaranteed to be in terminal
+                    # state afterwards.
+                    break
+                print()
+                status_display.update(
+                    ux_utils.spinner_message(
+                        'Waiting for next restart for the failed task'))
+                status_display.start()
+                while True:
+                    _, managed_job_status = (
+                        managed_job_state.get_latest_task_id_status(job_id))
+                    if (managed_job_status !=
+                            managed_job_state.ManagedJobStatus.RUNNING):
+                        break
+                    time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
+                continue
             if job_status != job_lib.JobStatus.CANCELLED:
-                assert task_id is not None, job_id
                 if task_id < num_tasks - 1 and follow:
                     # The log for the current job is finished. We need to
                     # wait until next job to be started.
@@ -410,27 +431,7 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
                         time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
                     continue
                 else:
-                    task_specs = managed_job_state.get_task_specs(
-                        job_id, task_id)
-                    if task_specs.get('max_restarts_on_errors', 0) == 0:
-                        # We don't need to wait for the managed job status
-                        # update, as the job is guaranteed to be in terminal
-                        # state afterwards.
-                        break
-                    print()
-                    status_display.update(
-                        ux_utils.spinner_message(
-                            'Waiting for next restart for the failed task'))
-                    status_display.start()
-                    while True:
-                        _, managed_job_status = (
-                            managed_job_state.get_latest_task_id_status(
-                                job_id))
-                        if (managed_job_status !=
-                                managed_job_state.ManagedJobStatus.RUNNING):
-                            break
-                        time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
-                    continue
+                    break
             # The job can be cancelled by the user or the controller (when
             # the cluster is partially preempted).
             logger.debug(
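For reference, the retry wait this patch moves into the failed branch boils down to the following standalone sketch; the names (`get_managed_status`, the fake status enum, the gap constant) are toy stand-ins rather than the actual SkyPilot helpers:

    import enum
    import time

    JOB_STATUS_CHECK_GAP_SECONDS = 1  # stand-in for the real polling interval


    class ManagedJobStatus(enum.Enum):
        RUNNING = 'RUNNING'
        RECOVERING = 'RECOVERING'


    def wait_for_restart(get_managed_status, max_restarts_on_errors: int) -> bool:
        """Return True when a retry is expected and streaming should resume."""
        if max_restarts_on_errors == 0:
            # No restarts configured: the job is guaranteed to reach a
            # terminal state, so there is nothing to wait for.
            return False
        # The remote job has already failed; wait until the controller
        # reacts and the managed job status leaves RUNNING.
        while get_managed_status() == ManagedJobStatus.RUNNING:
            time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
        return True

A caller resumes log streaming only when this returns True, mirroring the `break` vs. `continue` split in the hunk above.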
From 0e4e747cd1b2d473f11c5a8520967414c58f4999 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Mon, 25 Nov 2024 11:36:56 -0800
Subject: [PATCH 02/13] refactor: use `next` for better readability

---
 sky/jobs/state.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/sky/jobs/state.py b/sky/jobs/state.py
index 6a0e3caeda3..d1693d8fcb9 100644
--- a/sky/jobs/state.py
+++ b/sky/jobs/state.py
@@ -575,11 +575,10 @@ def get_latest_task_id_status(
     id_statuses = _get_all_task_ids_statuses(job_id)
     if len(id_statuses) == 0:
         return None, None
-    task_id, status = id_statuses[-1]
-    for task_id, status in id_statuses:
-        if not status.is_terminal():
-            break
-    return task_id, status
+    return next(
+        ((tid, st) for tid, st in id_statuses if not st.is_terminal()),
+        id_statuses[-1],
+    )


 def get_status(job_id: int) -> Optional[ManagedJobStatus]:

From 45f0c46b9e30195b18527c66899a20718a8c7409 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Mon, 25 Nov 2024 12:56:24 -0800
Subject: [PATCH 03/13] refactor: add comments on why we wait until not RUNNING

---
 sky/jobs/utils.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py
index 410c18341a5..ddf1f441920 100644
--- a/sky/jobs/utils.py
+++ b/sky/jobs/utils.py
@@ -398,12 +398,21 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
                     ux_utils.spinner_message(
                         'Waiting for next restart for the failed task'))
                 status_display.start()
-                while True:
+
+                def is_managed_job_status_updated():
+                    """Check if local managed job status reflects remote
+                    job failure.
+
+                    Ensures synchronization between remote cluster failure
+                    detection (JobStatus.FAILED) and controller retry logic.
+                    """
+                    nonlocal managed_job_status
                     _, managed_job_status = (
                         managed_job_state.get_latest_task_id_status(job_id))
-                    if (managed_job_status !=
-                            managed_job_state.ManagedJobStatus.RUNNING):
-                        break
+                    return (managed_job_status !=
+                            managed_job_state.ManagedJobStatus.RUNNING)
+
+                while not is_managed_job_status_updated():
                     time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
                 continue
             if job_status != job_lib.JobStatus.CANCELLED:
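A self-contained illustration of the `next(...)` idiom from patch 02, using plain strings in place of the real `ManagedJobStatus` values:

    id_statuses = [(0, 'SUCCEEDED'), (1, 'RUNNING'), (2, 'PENDING')]


    def is_terminal(status: str) -> bool:
        return status in ('SUCCEEDED', 'FAILED', 'CANCELLED')


    # First non-terminal (task_id, status) pair, falling back to the last
    # entry when every task has already finished.
    task_id, status = next(
        ((tid, st) for tid, st in id_statuses if not is_terminal(st)),
        id_statuses[-1],
    )
    assert (task_id, status) == (1, 'RUNNING')

The fallback `id_statuses[-1]` is only returned when every entry is terminal, which matches the behavior of the removed loop.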
From 782a6d83cce1ad631db70dbffaa2c2dc83c068ca Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Mon, 25 Nov 2024 13:12:30 -0800
Subject: [PATCH 04/13] refactor: work around a pylint bug

---
 sky/jobs/controller.py | 3 +++
 sky/jobs/state.py      | 5 ++++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py
index 5219c564500..3cc65a57586 100644
--- a/sky/jobs/controller.py
+++ b/sky/jobs/controller.py
@@ -473,6 +473,7 @@ def start(job_id, dag_yaml, retry_until_up):
     """Start the controller."""
     controller_process = None
     cancelling = False
+    task_id = None
     try:
         _handle_signal(job_id)
         # TODO(suquark): In theory, we should make controller process a
@@ -491,6 +492,7 @@ def start(job_id, dag_yaml, retry_until_up):
     except exceptions.ManagedJobUserCancelledError:
         dag, _ = _get_dag_and_name(dag_yaml)
         task_id, _ = managed_job_state.get_latest_task_id_status(job_id)
+        assert task_id is not None, job_id
         logger.info(
             f'Cancelling managed job, job_id: {job_id}, task_id: {task_id}')
         managed_job_state.set_cancelling(
@@ -522,6 +524,7 @@ def start(job_id, dag_yaml, retry_until_up):
         logger.info(f'Cluster of managed job {job_id} has been cleaned up.')

         if cancelling:
+            assert task_id is not None, job_id  # Since it's set with cancelling
             managed_job_state.set_cancelled(
                 job_id=job_id,
                 callback_func=managed_job_utils.event_callback_func(

diff --git a/sky/jobs/state.py b/sky/jobs/state.py
index d1693d8fcb9..cbc17353b13 100644
--- a/sky/jobs/state.py
+++ b/sky/jobs/state.py
@@ -575,10 +575,13 @@ def get_latest_task_id_status(
     id_statuses = _get_all_task_ids_statuses(job_id)
     if len(id_statuses) == 0:
         return None, None
-    return next(
+    task_id, status = next(
         ((tid, st) for tid, st in id_statuses if not st.is_terminal()),
         id_statuses[-1],
     )
+    # Unpack the tuple first, or it triggers a pylint bug in recognizing
+    # the return type.
+    return task_id, status


 def get_status(job_id: int) -> Optional[ManagedJobStatus]:

From 33ebbf3525e04c3e124ad6541bcb2dbc494c5d34 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Mon, 25 Nov 2024 13:16:25 -0800
Subject: [PATCH 05/13] fix: also include failed_setup

---
 sky/jobs/utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py
index ddf1f441920..710e3c4b0bc 100644
--- a/sky/jobs/utils.py
+++ b/sky/jobs/utils.py
@@ -385,7 +385,11 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
             job_status = list(job_statuses.values())[0]
             assert job_status is not None, 'No job found.'
             assert task_id is not None, job_id
-            if job_status == job_lib.JobStatus.FAILED:
+
+            user_code_failure_states = [
+                job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
+            ]
+            if job_status in user_code_failure_states:
                 task_specs = managed_job_state.get_task_specs(
                     job_id, task_id)
                 if task_specs.get('max_restarts_on_errors', 0) == 0:
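A minimal sketch of the unpack-then-return workaround in patch 04; the pylint false positive is taken from the commit message rather than reproduced here, and the helper below is a hypothetical simplification:

    from typing import List, Optional, Tuple

    _TERMINAL = ('SUCCEEDED', 'FAILED', 'CANCELLED')


    def latest_task_id_status(
            id_statuses: List[Tuple[int, str]]
    ) -> Tuple[Optional[int], Optional[str]]:
        if not id_statuses:
            return None, None
        # Unpack into named variables and return those, instead of
        # returning the `next(...)` call directly, so the inferred return
        # type stays an explicit (task_id, status) pair.
        task_id, status = next(
            ((tid, st) for tid, st in id_statuses if st not in _TERMINAL),
            id_statuses[-1],
        )
        return task_id, status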
From eb7fb693ee39d080710b2acdc074ff273de9438a Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Mon, 25 Nov 2024 13:44:47 -0800
Subject: [PATCH 06/13] refactor: extract `user_code_failure_states`

---
 sky/jobs/controller.py        |  8 ++------
 sky/jobs/utils.py             |  7 ++-----
 sky/serve/replica_managers.py |  4 +---
 sky/skylet/job_lib.py         | 12 ++++++++----
 4 files changed, 13 insertions(+), 18 deletions(-)

diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py
index 3cc65a57586..afd681b6e4e 100644
--- a/sky/jobs/controller.py
+++ b/sky/jobs/controller.py
@@ -236,9 +236,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
                     task.num_nodes == 1):
                 continue

-            if job_status in [
-                    job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
-            ]:
+            if job_status in job_lib.JobStatus.user_code_failure_states():
                 # Add a grace period before the check of preemption to avoid
                 # false alarm for job failure.
                 time.sleep(5)
@@ -268,9 +266,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
             if job_status is not None and not job_status.is_terminal():
                 # The multi-node job is still running, continue monitoring.
                 continue
-            elif job_status in [
-                    job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
-            ]:
+            elif job_status in job_lib.JobStatus.user_code_failure_states():
                 # The user code has probably crashed, fail immediately.
                 end_time = managed_job_utils.get_job_timestamp(
                     self._backend, cluster_name, get_end_time=True)

diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py
index 710e3c4b0bc..6546918c41b 100644
--- a/sky/jobs/utils.py
+++ b/sky/jobs/utils.py
@@ -386,10 +386,7 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
             assert job_status is not None, 'No job found.'
             assert task_id is not None, job_id

-            user_code_failure_states = [
-                job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
-            ]
-            if job_status in user_code_failure_states:
+            if job_status in job_lib.JobStatus.user_code_failure_states():
                 task_specs = managed_job_state.get_task_specs(
                     job_id, task_id)
                 if task_specs.get('max_restarts_on_errors', 0) == 0:
@@ -403,7 +400,7 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
                         'Waiting for next restart for the failed task'))
                 status_display.start()

-                def is_managed_job_status_updated():
+                def is_managed_job_status_updated() -> bool:
                     """Check if local managed job status reflects remote
                     job failure.

diff --git a/sky/serve/replica_managers.py b/sky/serve/replica_managers.py
index c0e5220e779..e103d8c5f27 100644
--- a/sky/serve/replica_managers.py
+++ b/sky/serve/replica_managers.py
@@ -998,9 +998,7 @@ def _fetch_job_status(self) -> None:
                 # Re-raise the exception if it is not preempted.
                 raise
             job_status = list(job_statuses.values())[0]
-            if job_status in [
-                    job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
-            ]:
+            if job_status in job_lib.JobStatus.user_code_failure_states():
                 info.status_property.user_app_failed = True
                 serve_state.add_or_update_replica(self._service_name,
                                                   info.replica_id, info)

diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py
index dfd8332b019..deba48a3b30 100644
--- a/sky/skylet/job_lib.py
+++ b/sky/skylet/job_lib.py
@@ -12,7 +12,7 @@
 import sqlite3
 import subprocess
 import time
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Sequence

 import colorama
 import filelock
@@ -162,13 +162,17 @@ class JobStatus(enum.Enum):
     def nonterminal_statuses(cls) -> List['JobStatus']:
         return [cls.INIT, cls.SETTING_UP, cls.PENDING, cls.RUNNING]

-    def is_terminal(self):
+    def is_terminal(self) -> bool:
         return self not in self.nonterminal_statuses()

-    def __lt__(self, other):
+    @classmethod
+    def user_code_failure_states(cls) -> Sequence['JobStatus']:
+        return (cls.FAILED, cls.FAILED_SETUP)
+
+    def __lt__(self, other: 'JobStatus') -> bool:
         return list(JobStatus).index(self) < list(JobStatus).index(other)

-    def colored_str(self):
+    def colored_str(self) -> str:
         color = _JOB_STATUS_TO_COLOR[self]
         return f'{color}{self.value}{colorama.Style.RESET_ALL}'

From c554f512b96430019128ec6f5c731cdf3f59b884 Mon Sep 17 00:00:00 2001
From: Andy Lee
Date: Mon, 25 Nov 2024 14:00:03 -0800
Subject: [PATCH 07/13] refactor: remove `nonlocal`

---
 sky/jobs/utils.py | 28 +++++++++++++---------------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py
index 6546918c41b..5f4fa445664 100644
--- a/sky/jobs/utils.py
+++ b/sky/jobs/utils.py
@@ -14,7 +14,7 @@
 import textwrap
 import time
 import typing
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

 import colorama
 import filelock
@@ -400,20 +400,18 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
                     'Waiting for next restart for the failed task'))
                 status_display.start()

-                def is_managed_job_status_updated() -> bool:
-                    """Check if local managed job status reflects remote
-                    job failure.
-
-                    Ensures synchronization between remote cluster failure
-                    detection (JobStatus.FAILED) and controller retry logic.
- """ - nonlocal managed_job_status - _, managed_job_status = ( - managed_job_state.get_latest_task_id_status(job_id)) - return (managed_job_status != - managed_job_state.ManagedJobStatus.RUNNING) - - while not is_managed_job_status_updated(): + # Check if local managed job status reflects remote job + # failure. + # Ensures synchronization between remote cluster failure + # detection (JobStatus.FAILED) and controller retry logic. + is_managed_job_status_updated: Callable[ + [Optional[managed_job_state.ManagedJobStatus]], + bool] = (lambda status: status != managed_job_state. + ManagedJobStatus.RUNNING) + + while not is_managed_job_status_updated( + managed_job_status := managed_job_state.get_status( + job_id)): time.sleep(JOB_STATUS_CHECK_GAP_SECONDS) continue if job_status != job_lib.JobStatus.CANCELLED: From 62941d5aaf984120a71e40629c8e36cd551f3bd1 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Mon, 25 Nov 2024 17:05:09 -0800 Subject: [PATCH 08/13] fix: stop logging retry for no-follow --- sky/jobs/utils.py | 97 +++++++++++++++++++++++++---------------------- 1 file changed, 52 insertions(+), 45 deletions(-) diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 5f4fa445664..1cf80b66c86 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -386,60 +386,67 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str: assert job_status is not None, 'No job found.' assert task_id is not None, job_id - if job_status in job_lib.JobStatus.user_code_failure_states(): - task_specs = managed_job_state.get_task_specs( - job_id, task_id) - if task_specs.get('max_restarts_on_errors', 0) == 0: - # We don't need to wait for the managed job status - # update, as the job is guaranteed to be in terminal - # state afterwards. + if job_status != job_lib.JobStatus.CANCELLED: + if not follow: break - print() - status_display.update( - ux_utils.spinner_message( - 'Waiting for next restart for the failed task')) - status_display.start() - # Check if local managed job status reflects remote job - # failure. - # Ensures synchronization between remote cluster failure - # detection (JobStatus.FAILED) and controller retry logic. - is_managed_job_status_updated: Callable[ - [Optional[managed_job_state.ManagedJobStatus]], - bool] = (lambda status: status != managed_job_state. - ManagedJobStatus.RUNNING) - - while not is_managed_job_status_updated( - managed_job_status := managed_job_state.get_status( - job_id)): - time.sleep(JOB_STATUS_CHECK_GAP_SECONDS) - continue - if job_status != job_lib.JobStatus.CANCELLED: - if task_id < num_tasks - 1 and follow: - # The log for the current job is finished. We need to - # wait until next job to be started. - logger.debug( - f'INFO: Log for the current task ({task_id}) ' - 'is finished. Waiting for the next task\'s log ' - 'to be started.') - # Add a newline to avoid the status display below - # removing the last line of the task output. + # Logs for retrying failed tasks. + if job_status in job_lib.JobStatus.user_code_failure_states( + ): + task_specs = managed_job_state.get_task_specs( + job_id, task_id) + if task_specs.get('max_restarts_on_errors', 0) == 0: + # We don't need to wait for the managed job status + # update, as the job is guaranteed to be in terminal + # state afterwards. 
+ break print() status_display.update( ux_utils.spinner_message( - f'Waiting for the next task: {task_id + 1}')) + 'Waiting for next restart for the failed task')) status_display.start() - original_task_id = task_id - while True: - task_id, managed_job_status = ( - managed_job_state.get_latest_task_id_status( - job_id)) - if original_task_id != task_id: - break + + # Check if local managed job status reflects remote job + # failure. + # Ensures synchronization between remote cluster failure + # detection (JobStatus.FAILED) and controller retry + # logic. + is_managed_job_status_updated: Callable[ + [Optional[managed_job_state.ManagedJobStatus]], + bool] = (lambda status: status != managed_job_state. + ManagedJobStatus.RUNNING) + + while not is_managed_job_status_updated( + managed_job_status := + managed_job_state.get_status(job_id)): time.sleep(JOB_STATUS_CHECK_GAP_SECONDS) continue - else: + + if task_id == num_tasks - 1: break + + # The log for the current job is finished. We need to + # wait until next job to be started. + logger.debug( + f'INFO: Log for the current task ({task_id}) ' + 'is finished. Waiting for the next task\'s log ' + 'to be started.') + # Add a newline to avoid the status display below + # removing the last line of the task output. + print() + status_display.update( + ux_utils.spinner_message( + f'Waiting for the next task: {task_id + 1}')) + status_display.start() + original_task_id = task_id + while True: + task_id, managed_job_status = ( + managed_job_state.get_latest_task_id_status(job_id)) + if original_task_id != task_id: + break + time.sleep(JOB_STATUS_CHECK_GAP_SECONDS) + continue + # The job can be cancelled by the user or the controller (when # the cluster is partially preempted). logger.debug( From ebf3f91e0e155250864b4acd7e565b8bc75b49f1 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Mon, 25 Nov 2024 19:50:09 -0800 Subject: [PATCH 09/13] tests: smoke tests for managed jobs retrying --- tests/test_smoke.py | 46 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 6ba81ce68f0..894c4f633ee 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -3283,6 +3283,52 @@ def test_managed_jobs_recovery_multi_node_gcp(): run_one_test(test) +@pytest.mark.managed_jobs +def test_managed_jobs_retry_logs(): + """Test managed job retry logs are properly displayed when a task fails.""" + name = _get_cluster_name() + # Create a temporary YAML file with two tasks - first one fails, second succeeds + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml') as f: + yaml_content = textwrap.dedent(""" + resources: + cpus: 2+ + job_recovery: + max_restarts_on_errors: 1 + + # Task 1: Always fails + run: | + echo "Task 1 starting" + exit 1 + --- + # Task 2: Never reached due to Task 1 failure + run: | + echo "Task 2 starting" + exit 0 + """) + f.write(yaml_content) + f.flush() + + with tempfile.NamedTemporaryFile(mode='w', suffix='.log') as log_file: + test = Test( + 'managed_jobs_retry_logs', + [ + f'sky jobs launch -n {name} {f.name} -y -d', + f'sky jobs logs -n {name} | tee {log_file.name}', + # First attempt + f'cat {log_file.name} | grep "Job started. Streaming logs..."', + f'cat {log_file.name} | grep "Job 1 failed"', + # Second attempt + f'cat {log_file.name} | grep "Job started. Streaming logs..." | wc -l | grep 2', + f'cat {log_file.name} | grep "Job 1 failed" | wc -l | grep 2', + # Task 2 is not reached + f'! 
cat {log_file.name} | grep "Job 2"', + ], + f'sky jobs cancel -y -n {name}', + timeout=7 * 60, # 5 mins + ) + run_one_test(test) + + @pytest.mark.aws @pytest.mark.managed_jobs def test_managed_jobs_cancellation_aws(aws_config_region): From 939b057de22a25c2e8ab30f1bd7e5dd58ba45ba4 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Mon, 25 Nov 2024 20:46:24 -0800 Subject: [PATCH 10/13] format --- tests/test_smoke.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 894c4f633ee..7dde6160068 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -3311,7 +3311,7 @@ def test_managed_jobs_retry_logs(): with tempfile.NamedTemporaryFile(mode='w', suffix='.log') as log_file: test = Test( 'managed_jobs_retry_logs', - [ + [ f'sky jobs launch -n {name} {f.name} -y -d', f'sky jobs logs -n {name} | tee {log_file.name}', # First attempt From 8da7604c665654db1d7483b443e3e05debb3ede7 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Tue, 26 Nov 2024 15:33:03 -0800 Subject: [PATCH 11/13] format Co-authored-by: Tian Xia --- sky/jobs/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 1cf80b66c86..16d18e89e86 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -391,8 +391,8 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str: break # Logs for retrying failed tasks. - if job_status in job_lib.JobStatus.user_code_failure_states( - ): + if (job_status in job_lib.JobStatus.user_code_failure_states( + )): task_specs = managed_job_state.get_task_specs( job_id, task_id) if task_specs.get('max_restarts_on_errors', 0) == 0: From 6518e447c792738f4c4e007598ef094e7bf4d78a Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Tue, 26 Nov 2024 15:36:40 -0800 Subject: [PATCH 12/13] chore: extract yaml file to test_yamls/ --- tests/test_smoke.py | 59 +++++++------------ tests/test_yamls/test_managed_jobs_retry.yaml | 14 +++++ 2 files changed, 34 insertions(+), 39 deletions(-) create mode 100644 tests/test_yamls/test_managed_jobs_retry.yaml diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 7dde6160068..6e19c490409 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -3287,46 +3287,27 @@ def test_managed_jobs_recovery_multi_node_gcp(): def test_managed_jobs_retry_logs(): """Test managed job retry logs are properly displayed when a task fails.""" name = _get_cluster_name() - # Create a temporary YAML file with two tasks - first one fails, second succeeds - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml') as f: - yaml_content = textwrap.dedent(""" - resources: - cpus: 2+ - job_recovery: - max_restarts_on_errors: 1 - - # Task 1: Always fails - run: | - echo "Task 1 starting" - exit 1 - --- - # Task 2: Never reached due to Task 1 failure - run: | - echo "Task 2 starting" - exit 0 - """) - f.write(yaml_content) - f.flush() + yaml_path = 'tests/test_yamls/test_managed_jobs_retry.yaml' - with tempfile.NamedTemporaryFile(mode='w', suffix='.log') as log_file: - test = Test( - 'managed_jobs_retry_logs', - [ - f'sky jobs launch -n {name} {f.name} -y -d', - f'sky jobs logs -n {name} | tee {log_file.name}', - # First attempt - f'cat {log_file.name} | grep "Job started. Streaming logs..."', - f'cat {log_file.name} | grep "Job 1 failed"', - # Second attempt - f'cat {log_file.name} | grep "Job started. Streaming logs..." | wc -l | grep 2', - f'cat {log_file.name} | grep "Job 1 failed" | wc -l | grep 2', - # Task 2 is not reached - f'! 
cat {log_file.name} | grep "Job 2"', - ], - f'sky jobs cancel -y -n {name}', - timeout=7 * 60, # 5 mins - ) - run_one_test(test) + with tempfile.NamedTemporaryFile(mode='w', suffix='.log') as log_file: + test = Test( + 'managed_jobs_retry_logs', + [ + f'sky jobs launch -n {name} {yaml_path} -y -d', + f'sky jobs logs -n {name} | tee {log_file.name}', + # First attempt + f'cat {log_file.name} | grep "Job started. Streaming logs..."', + f'cat {log_file.name} | grep "Job 1 failed"', + # Second attempt + f'cat {log_file.name} | grep "Job started. Streaming logs..." | wc -l | grep 2', + f'cat {log_file.name} | grep "Job 1 failed" | wc -l | grep 2', + # Task 2 is not reached + f'! cat {log_file.name} | grep "Job 2"', + ], + f'sky jobs cancel -y -n {name}', + timeout=7 * 60, # 7 mins + ) + run_one_test(test) @pytest.mark.aws diff --git a/tests/test_yamls/test_managed_jobs_retry.yaml b/tests/test_yamls/test_managed_jobs_retry.yaml new file mode 100644 index 00000000000..76289986386 --- /dev/null +++ b/tests/test_yamls/test_managed_jobs_retry.yaml @@ -0,0 +1,14 @@ +resources: + cpus: 2+ + job_recovery: + max_restarts_on_errors: 1 + +# Task 1: Always fails +run: | + echo "Task 1 starting" + exit 1 +--- +# Task 2: Never reached due to Task 1 failure +run: | + echo "Task 2 starting" + exit 0 \ No newline at end of file From 2e937ad5d4fc4b3b4c670e2cd5cf67dee64342fc Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Tue, 26 Nov 2024 15:39:48 -0800 Subject: [PATCH 13/13] refactor: use `def` rather than lambda --- sky/jobs/utils.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 16d18e89e86..713ac16abb4 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -14,7 +14,7 @@ import textwrap import time import typing -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import colorama import filelock @@ -391,8 +391,8 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str: break # Logs for retrying failed tasks. - if (job_status in job_lib.JobStatus.user_code_failure_states( - )): + if (job_status + in job_lib.JobStatus.user_code_failure_states()): task_specs = managed_job_state.get_task_specs( job_id, task_id) if task_specs.get('max_restarts_on_errors', 0) == 0: @@ -406,15 +406,18 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str: 'Waiting for next restart for the failed task')) status_display.start() - # Check if local managed job status reflects remote job - # failure. - # Ensures synchronization between remote cluster failure - # detection (JobStatus.FAILED) and controller retry - # logic. - is_managed_job_status_updated: Callable[ - [Optional[managed_job_state.ManagedJobStatus]], - bool] = (lambda status: status != managed_job_state. - ManagedJobStatus.RUNNING) + def is_managed_job_status_updated( + status: Optional[managed_job_state.ManagedJobStatus] + ) -> bool: + """Check if local managed job status reflects remote + job failure. + + Ensures synchronization between remote cluster + failure detection (JobStatus.FAILED) and controller + retry logic. + """ + return (status != + managed_job_state.ManagedJobStatus.RUNNING) while not is_managed_job_status_updated( managed_job_status :=