Skip to content

Commit

Permalink
[Jobs] Allow logs for finished jobs and add sky jobs logs --refresh
Browse files Browse the repository at this point in the history
… for restartin jobs controller (#4380)

* Stream logs for finished jobs

* Allow stream logs for finished jobs

* Read files after the indicator lines

* Add refresh for `sky jobs logs`

* fix log message

* address comments

* Add smoke test

* fix smoke

* fix jobs queue smoke test

* fix storage
  • Loading branch information
Michaelvll authored Dec 3, 2024
1 parent c3c1fde commit 2157f01
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 52 deletions.
13 changes: 11 additions & 2 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3914,16 +3914,25 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
default=False,
help=('Show the controller logs of this job; useful for debugging '
'launching/recoveries, etc.'))
@click.option(
'--refresh',
'-r',
default=False,
is_flag=True,
required=False,
help='Query the latest job logs, restarting the jobs controller if stopped.'
)
@click.argument('job_id', required=False, type=int)
@usage_lib.entrypoint
def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool,
controller: bool):
controller: bool, refresh: bool):
"""Tail the log of a managed job."""
try:
managed_jobs.tail_logs(name=name,
job_id=job_id,
follow=follow,
controller=controller)
controller=controller,
refresh=refresh)
except exceptions.ClusterNotUpError:
with ux_utils.print_exception_no_traceback():
raise
Expand Down
36 changes: 28 additions & 8 deletions sky/jobs/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import traceback
import typing
from typing import Tuple
from typing import Optional, Tuple

import filelock

Expand Down Expand Up @@ -87,18 +87,28 @@ def __init__(self, job_id: int, dag_yaml: str,
task.update_envs(task_envs)

def _download_log_and_stream(
self,
handle: cloud_vm_ray_backend.CloudVmRayResourceHandle) -> None:
"""Downloads and streams the logs of the latest job.
self, task_id: Optional[int],
handle: Optional[cloud_vm_ray_backend.CloudVmRayResourceHandle]
) -> None:
"""Downloads and streams the logs of the current job with given task ID.
We do not stream the logs from the cluster directly, as the
donwload and stream should be faster, and more robust against
preemptions or ssh disconnection during the streaming.
"""
if handle is None:
logger.info(f'Cluster for job {self._job_id} is not found. '
'Skipping downloading and streaming the logs.')
return
managed_job_logs_dir = os.path.join(constants.SKY_LOGS_DIRECTORY,
'managed_jobs')
controller_utils.download_and_stream_latest_job_log(
log_file = controller_utils.download_and_stream_latest_job_log(
self._backend, handle, managed_job_logs_dir)
if log_file is not None:
# Set the path of the log file for the current task, so it can be
# accessed even after the job is finished
managed_job_state.set_local_log_file(self._job_id, task_id,
log_file)
logger.info(f'\n== End of logs (ID: {self._job_id}) ==')

def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
Expand Down Expand Up @@ -213,20 +223,30 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
if job_status == job_lib.JobStatus.SUCCEEDED:
end_time = managed_job_utils.get_job_timestamp(
self._backend, cluster_name, get_end_time=True)
# The job is done.
# The job is done. Set the job to SUCCEEDED first before start
# downloading and streaming the logs to make it more responsive.
managed_job_state.set_succeeded(self._job_id,
task_id,
end_time=end_time,
callback_func=callback_func)
logger.info(
f'Managed job {self._job_id} (task: {task_id}) SUCCEEDED. '
f'Cleaning up the cluster {cluster_name}.')
clusters = backend_utils.get_clusters(
cluster_names=[cluster_name],
refresh=False,
include_controller=False)
if clusters:
assert len(clusters) == 1, (clusters, cluster_name)
handle = clusters[0].get('handle')
# Best effort to download and stream the logs.
self._download_log_and_stream(task_id, handle)
# Only clean up the cluster, not the storages, because tasks may
# share storages.
recovery_strategy.terminate_cluster(cluster_name=cluster_name)
return True

# For single-node jobs, nonterminated job_status indicates a
# For single-node jobs, non-terminated job_status indicates a
# healthy cluster. We can safely continue monitoring.
# For multi-node jobs, since the job may not be set to FAILED
# immediately (depending on user program) when only some of the
Expand Down Expand Up @@ -278,7 +298,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
'The user job failed. Please check the logs below.\n'
f'== Logs of the user job (ID: {self._job_id}) ==\n')

self._download_log_and_stream(handle)
self._download_log_and_stream(task_id, handle)
managed_job_status = (
managed_job_state.ManagedJobStatus.FAILED)
if job_status == job_lib.JobStatus.FAILED_SETUP:
Expand Down
96 changes: 61 additions & 35 deletions sky/jobs/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""SDK functions for managed jobs."""
import os
import tempfile
import typing
from typing import Any, Dict, List, Optional, Union
import uuid

Expand Down Expand Up @@ -29,6 +30,9 @@
from sky.utils import timeline
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from sky.backends import cloud_vm_ray_backend


@timeline.event
@usage_lib.entrypoint
Expand Down Expand Up @@ -225,6 +229,40 @@ def queue_from_kubernetes_pod(
return jobs


def _maybe_restart_controller(
refresh: bool, stopped_message: str, spinner_message: str
) -> 'cloud_vm_ray_backend.CloudVmRayResourceHandle':
"""Restart controller if refresh is True and it is stopped."""
jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER
if refresh:
stopped_message = ''
try:
handle = backend_utils.is_controller_accessible(
controller=jobs_controller_type, stopped_message=stopped_message)
except exceptions.ClusterNotUpError as e:
if not refresh:
raise
handle = None
controller_status = e.cluster_status

if handle is not None:
return handle

sky_logging.print(f'{colorama.Fore.YELLOW}'
f'Restarting {jobs_controller_type.value.name}...'
f'{colorama.Style.RESET_ALL}')

rich_utils.force_update_status(
ux_utils.spinner_message(f'{spinner_message} - restarting '
'controller'))
handle = sky.start(jobs_controller_type.value.cluster_name)
controller_status = status_lib.ClusterStatus.UP
rich_utils.force_update_status(ux_utils.spinner_message(spinner_message))

assert handle is not None, (controller_status, refresh)
return handle


@usage_lib.entrypoint
def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
Expand Down Expand Up @@ -252,34 +290,11 @@ def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
does not exist.
RuntimeError: if failed to get the managed jobs with ssh.
"""
jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER
stopped_message = ''
if not refresh:
stopped_message = 'No in-progress managed jobs.'
try:
handle = backend_utils.is_controller_accessible(
controller=jobs_controller_type, stopped_message=stopped_message)
except exceptions.ClusterNotUpError as e:
if not refresh:
raise
handle = None
controller_status = e.cluster_status

if refresh and handle is None:
sky_logging.print(f'{colorama.Fore.YELLOW}'
'Restarting controller for latest status...'
f'{colorama.Style.RESET_ALL}')

rich_utils.force_update_status(
ux_utils.spinner_message('Checking managed jobs - restarting '
'controller'))
handle = sky.start(jobs_controller_type.value.cluster_name)
controller_status = status_lib.ClusterStatus.UP
rich_utils.force_update_status(
ux_utils.spinner_message('Checking managed jobs'))

assert handle is not None, (controller_status, refresh)

handle = _maybe_restart_controller(refresh,
stopped_message='No in-progress '
'managed jobs.',
spinner_message='Checking '
'managed jobs')
backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend)

Expand Down Expand Up @@ -371,7 +386,7 @@ def cancel(name: Optional[str] = None,

@usage_lib.entrypoint
def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool,
controller: bool) -> None:
controller: bool, refresh: bool) -> None:
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
"""Tail logs of managed jobs.
Expand All @@ -382,15 +397,26 @@ def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool,
sky.exceptions.ClusterNotUpError: the jobs controller is not up.
"""
# TODO(zhwu): Automatically restart the jobs controller
if name is not None and job_id is not None:
with ux_utils.print_exception_no_traceback():
raise ValueError('Cannot specify both name and job_id.')

jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER
handle = backend_utils.is_controller_accessible(
controller=jobs_controller_type,
job_name_or_id_str = ''
if job_id is not None:
job_name_or_id_str = str(job_id)
elif name is not None:
job_name_or_id_str = f'-n {name}'
else:
job_name_or_id_str = ''
handle = _maybe_restart_controller(
refresh,
stopped_message=(
'Please restart the jobs controller with '
f'`sky start {jobs_controller_type.value.cluster_name}`.'))
f'{jobs_controller_type.value.name.capitalize()} is stopped. To '
f'get the logs, run: {colorama.Style.BRIGHT}sky jobs logs '
f'-r {job_name_or_id_str}{colorama.Style.RESET_ALL}'),
spinner_message='Retrieving job logs')

if name is not None and job_id is not None:
raise ValueError('Cannot specify both name and job_id.')
backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend), backend

Expand Down
34 changes: 33 additions & 1 deletion sky/jobs/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def create_table(cursor, conn):
spot_job_id INTEGER,
task_id INTEGER DEFAULT 0,
task_name TEXT,
specs TEXT)""")
specs TEXT,
local_log_file TEXT DEFAULT NULL)""")
conn.commit()

db_utils.add_column_to_table(cursor, conn, 'spot', 'failure_reason', 'TEXT')
Expand Down Expand Up @@ -103,6 +104,8 @@ def create_table(cursor, conn):
value_to_replace_existing_entries=json.dumps({
'max_restarts_on_errors': 0,
}))
db_utils.add_column_to_table(cursor, conn, 'spot', 'local_log_file',
'TEXT DEFAULT NULL')

# `job_info` contains the mapping from job_id to the job_name.
# In the future, it may contain more information about each job.
Expand Down Expand Up @@ -157,6 +160,7 @@ def _get_db_path() -> str:
'task_id',
'task_name',
'specs',
'local_log_file',
# columns from the job_info table
'_job_info_job_id', # This should be the same as job_id
'job_name',
Expand Down Expand Up @@ -512,6 +516,20 @@ def set_cancelled(job_id: int, callback_func: CallbackType):
callback_func('CANCELLED')


def set_local_log_file(job_id: int, task_id: Optional[int],
local_log_file: str):
"""Set the local log file for a job."""
filter_str = 'spot_job_id=(?)'
filter_args = [local_log_file, job_id]
if task_id is not None:
filter_str += ' AND task_id=(?)'
filter_args.append(task_id)
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
'UPDATE spot SET local_log_file=(?) '
f'WHERE {filter_str}', filter_args)


# ======== utility functions ========
def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]:
"""Get non-terminal job ids by name."""
Expand Down Expand Up @@ -662,3 +680,17 @@ def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
WHERE spot_job_id=(?) AND task_id=(?)""",
(job_id, task_id)).fetchone()
return json.loads(task_specs[0])


def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
"""Get the local log directory for a job."""
filter_str = 'spot_job_id=(?)'
filter_args = [job_id]
if task_id is not None:
filter_str += ' AND task_id=(?)'
filter_args.append(task_id)
with db_utils.safe_cursor(_DB_PATH) as cursor:
local_log_file = cursor.execute(
f'SELECT local_log_file FROM spot '
f'WHERE {filter_str}', filter_args).fetchone()
return local_log_file[-1] if local_log_file else None
18 changes: 16 additions & 2 deletions sky/jobs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,24 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
if managed_job_status.is_failed():
job_msg = ('\nFailure reason: '
f'{managed_job_state.get_failure_reason(job_id)}')
log_file = managed_job_state.get_local_log_file(job_id, None)
if log_file is not None:
with open(log_file, 'r', encoding='utf-8') as f:
# Stream the logs to the console without reading the whole
# file into memory.
start_streaming = False
for line in f:
if log_lib.LOG_FILE_START_STREAMING_AT in line:
start_streaming = True
if start_streaming:
print(line, end='', flush=True)
return ''
return (f'{colorama.Fore.YELLOW}'
f'Job {job_id} is already in terminal state '
f'{managed_job_status.value}. Logs will not be shown.'
f'{colorama.Style.RESET_ALL}{job_msg}')
f'{managed_job_status.value}. For more details, run: '
f'sky jobs logs --controller {job_id}'
f'{colorama.Style.RESET_ALL}'
f'{job_msg}')
backend = backends.CloudVmRayBackend()
task_id, managed_job_status = (
managed_job_state.get_latest_task_id_status(job_id))
Expand Down
4 changes: 3 additions & 1 deletion sky/skylet/log_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

logger = sky_logging.init_logger(__name__)

LOG_FILE_START_STREAMING_AT = 'Waiting for task resources on '


class _ProcessingArgs:
"""Arguments for processing logs."""
Expand Down Expand Up @@ -435,7 +437,7 @@ def tail_logs(job_id: Optional[int],
time.sleep(_SKY_LOG_WAITING_GAP_SECONDS)
status = job_lib.update_job_status([job_id], silent=True)[0]

start_stream_at = 'Waiting for task resources on '
start_stream_at = LOG_FILE_START_STREAMING_AT
# Explicitly declare the type to avoid mypy warning.
lines: Iterable[str] = []
if follow and status in [
Expand Down
3 changes: 3 additions & 0 deletions sky/skylet/log_lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ from sky.skylet import constants as constants
from sky.skylet import job_lib as job_lib
from sky.utils import log_utils as log_utils

LOG_FILE_START_STREAMING_AT: str = ...


class _ProcessingArgs:
log_path: str
stream_logs: bool
Expand Down
Loading

0 comments on commit 2157f01

Please sign in to comment.