[Jobs] Move task retry logic to correct branch in stream_logs_by_id #4407

Open
wants to merge 14 commits into base: master
11 changes: 5 additions & 6 deletions sky/jobs/controller.py
@@ -236,9 +236,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
task.num_nodes == 1):
continue

if job_status in [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]:
if job_status in job_lib.JobStatus.user_code_failure_states():
# Add a grace period before the preemption check to avoid a
# false alarm for job failure.
time.sleep(5)
@@ -268,9 +266,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
if job_status is not None and not job_status.is_terminal():
# The multi-node job is still running, continue monitoring.
continue
elif job_status in [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]:
elif job_status in job_lib.JobStatus.user_code_failure_states():
# The user code has probably crashed, fail immediately.
end_time = managed_job_utils.get_job_timestamp(
self._backend, cluster_name, get_end_time=True)
@@ -473,6 +469,7 @@ def start(job_id, dag_yaml, retry_until_up):
"""Start the controller."""
controller_process = None
cancelling = False
task_id = None
try:
_handle_signal(job_id)
# TODO(suquark): In theory, we should make controller process a
@@ -491,6 +488,7 @@ def start(job_id, dag_yaml, retry_until_up):
except exceptions.ManagedJobUserCancelledError:
dag, _ = _get_dag_and_name(dag_yaml)
task_id, _ = managed_job_state.get_latest_task_id_status(job_id)
assert task_id is not None, job_id
logger.info(
f'Cancelling managed job, job_id: {job_id}, task_id: {task_id}')
managed_job_state.set_cancelling(
@@ -522,6 +520,7 @@ def start(job_id, dag_yaml, retry_until_up):
logger.info(f'Cluster of managed job {job_id} has been cleaned up.')

if cancelling:
assert task_id is not None, job_id # Since it's set with cancelling
managed_job_state.set_cancelled(
job_id=job_id,
callback_func=managed_job_utils.event_callback_func(
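
For readers skimming the diff: both monitoring branches in controller.py now test membership in the shared helper instead of rebuilding the [FAILED, FAILED_SETUP] list inline. A minimal sketch of the pattern (classify_cluster_job is a hypothetical helper, not part of the PR; the import path follows the file shown later in this diff, sky/skylet/job_lib.py):

from typing import Optional
import time

from sky.skylet import job_lib


def classify_cluster_job(job_status: Optional[job_lib.JobStatus]) -> str:
    """Hypothetical condensation of the monitor loop's branching."""
    if job_status is not None and not job_status.is_terminal():
        # The job is still running on the cluster; keep monitoring.
        return 'keep-monitoring'
    if job_status in job_lib.JobStatus.user_code_failure_states():
        # Grace period before the preemption check, so a failure report
        # caused by a disappearing node is not misread as a user-code
        # crash (mirrors the time.sleep(5) in the hunk above).
        time.sleep(5)
        return 'possible-user-code-failure'
    return 'check-for-preemption'
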
10 changes: 6 additions & 4 deletions sky/jobs/state.py
@@ -575,10 +575,12 @@ def get_latest_task_id_status(
id_statuses = _get_all_task_ids_statuses(job_id)
if len(id_statuses) == 0:
return None, None
task_id, status = id_statuses[-1]
for task_id, status in id_statuses:
if not status.is_terminal():
break
task_id, status = next(
((tid, st) for tid, st in id_statuses if not st.is_terminal()),
id_statuses[-1],
)
# Unpack the tuple first, or it triggers a Pylint bug in recognizing
# the return type.
return task_id, status


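
The rewritten selection in get_latest_task_id_status is easier to see in isolation: next() yields the first non-terminal (task_id, status) pair and falls back to the last entry once every task has reached a terminal state. A self-contained sketch with placeholder strings standing in for ManagedJobStatus values:

# Placeholder statuses; the real code compares ManagedJobStatus members.
TERMINAL = {'SUCCEEDED', 'FAILED', 'CANCELLED'}

id_statuses = [(0, 'SUCCEEDED'), (1, 'RUNNING'), (2, 'PENDING')]

task_id, status = next(
    ((tid, st) for tid, st in id_statuses if st not in TERMINAL),
    id_statuses[-1],  # Every task is terminal: report the last one.
)
print(task_id, status)  # -> 1 RUNNING

Compared with the old for/break form, the explicit default makes the "all tasks terminal" fallback visible instead of relying on the loop variable's final value.
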
55 changes: 32 additions & 23 deletions sky/jobs/utils.py
@@ -14,7 +14,7 @@
import textwrap
import time
import typing
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import colorama
import filelock
@@ -384,8 +384,37 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
job_statuses = backend.get_job_status(handle, stream_logs=False)
job_status = list(job_statuses.values())[0]
assert job_status is not None, 'No job found.'
assert task_id is not None, job_id

if job_status in job_lib.JobStatus.user_code_failure_states():
task_specs = managed_job_state.get_task_specs(
job_id, task_id)
if task_specs.get('max_restarts_on_errors', 0) == 0:
# We don't need to wait for the managed job status
# update, as the job is guaranteed to be in terminal
# state afterwards.
break
print()
status_display.update(
ux_utils.spinner_message(
'Waiting for next restart for the failed task'))
status_display.start()

# Wait until the local managed job status reflects the remote
# job failure, keeping the remote cluster's failure detection
# (JobStatus.FAILED) in sync with the controller's retry logic.
is_managed_job_status_updated: Callable[
[Optional[managed_job_state.ManagedJobStatus]],
bool] = (lambda status: status != managed_job_state.
ManagedJobStatus.RUNNING)

while not is_managed_job_status_updated(
managed_job_status := managed_job_state.get_status(
job_id)):
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue
if job_status != job_lib.JobStatus.CANCELLED:
assert task_id is not None, job_id
if task_id < num_tasks - 1 and follow:
# The log for the current job is finished. We need to
# wait until the next job is started.
@@ -410,27 +439,7 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue
else:
task_specs = managed_job_state.get_task_specs(
job_id, task_id)
if task_specs.get('max_restarts_on_errors', 0) == 0:
# We don't need to wait for the managed job status
# update, as the job is guaranteed to be in terminal
# state afterwards.
break
print()
status_display.update(
ux_utils.spinner_message(
'Waiting for next restart for the failed task'))
status_display.start()
while True:
_, managed_job_status = (
managed_job_state.get_latest_task_id_status(
job_id))
if (managed_job_status !=
managed_job_state.ManagedJobStatus.RUNNING):
break
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue
break
# The job can be cancelled by the user or the controller (when
# the cluster is partially preempted).
logger.debug(
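
Stripped of the log-streaming machinery, the relocated block amounts to: when the cluster job reports a user-code failure and the task has max_restarts_on_errors > 0, block until the controller's view of the managed job leaves RUNNING before resuming the stream. A sketch under the assumption that the module-level imports match the existing ones in sky/jobs/utils.py; should_keep_streaming is a hypothetical condensation and the 20-second gap is a placeholder for JOB_STATUS_CHECK_GAP_SECONDS:

import time

from sky.jobs import state as managed_job_state
from sky.skylet import job_lib

JOB_STATUS_CHECK_GAP_SECONDS = 20  # Placeholder value, for illustration only.


def should_keep_streaming(job_id: int, task_id: int,
                          job_status: job_lib.JobStatus) -> bool:
    """Decide whether stream_logs_by_id should keep following the job."""
    if job_status not in job_lib.JobStatus.user_code_failure_states():
        return True
    task_specs = managed_job_state.get_task_specs(job_id, task_id)
    if task_specs.get('max_restarts_on_errors', 0) == 0:
        # No retries configured: the managed job is guaranteed to reach a
        # terminal state, so the caller can stop streaming right away.
        return False
    # Retries are configured: wait until the controller has reacted to the
    # failure (the managed job status is no longer RUNNING), then stream on.
    while (managed_job_state.get_status(job_id) ==
           managed_job_state.ManagedJobStatus.RUNNING):
        time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
    return True

Moving this wait into the user-code-failure branch, rather than the later else branch shown below it, is what the PR title refers to.
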
4 changes: 1 addition & 3 deletions sky/serve/replica_managers.py
@@ -998,9 +998,7 @@ def _fetch_job_status(self) -> None:
# Re-raise the exception if it is not preempted.
raise
job_status = list(job_statuses.values())[0]
if job_status in [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]:
if job_status in job_lib.JobStatus.user_code_failure_states():
info.status_property.user_app_failed = True
serve_state.add_or_update_replica(self._service_name,
info.replica_id, info)
12 changes: 8 additions & 4 deletions sky/skylet/job_lib.py
@@ -12,7 +12,7 @@
import sqlite3
import subprocess
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence

import colorama
import filelock
@@ -162,13 +162,17 @@ class JobStatus(enum.Enum):
def nonterminal_statuses(cls) -> List['JobStatus']:
return [cls.INIT, cls.SETTING_UP, cls.PENDING, cls.RUNNING]

def is_terminal(self):
def is_terminal(self) -> bool:
return self not in self.nonterminal_statuses()

def __lt__(self, other):
@classmethod
def user_code_failure_states(cls) -> Sequence['JobStatus']:
return (cls.FAILED, cls.FAILED_SETUP)

def __lt__(self, other: 'JobStatus') -> bool:
return list(JobStatus).index(self) < list(JobStatus).index(other)

def colored_str(self):
def colored_str(self) -> str:
color = _JOB_STATUS_TO_COLOR[self]
return f'{color}{self.value}{colorama.Style.RESET_ALL}'

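
To summarize the job_lib.py change: user_code_failure_states() becomes the single definition of "the user's code failed", and the call sites above switch to a membership test. A cut-down sketch of the enum (only a subset of the real members is reproduced, with assumed string values):

import enum
from typing import List, Sequence


class JobStatus(enum.Enum):
    # Subset of the real statuses, for illustration only.
    INIT = 'INIT'
    SETTING_UP = 'SETTING_UP'
    PENDING = 'PENDING'
    RUNNING = 'RUNNING'
    SUCCEEDED = 'SUCCEEDED'
    FAILED = 'FAILED'
    FAILED_SETUP = 'FAILED_SETUP'

    @classmethod
    def nonterminal_statuses(cls) -> List['JobStatus']:
        return [cls.INIT, cls.SETTING_UP, cls.PENDING, cls.RUNNING]

    def is_terminal(self) -> bool:
        return self not in self.nonterminal_statuses()

    @classmethod
    def user_code_failure_states(cls) -> Sequence['JobStatus']:
        # One authoritative tuple instead of ad-hoc lists at call sites.
        return (cls.FAILED, cls.FAILED_SETUP)


# Call sites (controller, serve replica manager, log streaming) now read:
#     if job_status in job_lib.JobStatus.user_code_failure_states(): ...
assert JobStatus.FAILED_SETUP in JobStatus.user_code_failure_states()
assert JobStatus.RUNNING not in JobStatus.user_code_failure_states()

Returning an immutable tuple typed as Sequence keeps callers from accidentally mutating the shared collection while still supporting the `in` membership check.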