[k8s] sky status --k8s refactor (#4079)
* refactor

* lint

* refactor, dataclass

* refactor, dataclass

* refactor

* lint
romilbhardwaj authored Oct 15, 2024
1 parent 9243113 commit a4e2fcd
Showing 5 changed files with 209 additions and 186 deletions.
4 changes: 2 additions & 2 deletions sky/backends/backend_utils.py
@@ -56,7 +56,7 @@
 from sky.utils import ux_utils
 
 if typing.TYPE_CHECKING:
-    from sky import resources
+    from sky import resources as resources_lib
     from sky import task as task_lib
     from sky.backends import cloud_vm_ray_backend
    from sky.backends import local_docker_backend
@@ -751,7 +751,7 @@ def _restore_block(new_block: Dict[str, Any], old_block: Dict[str, Any]):
 # TODO: too many things happening here - leaky abstraction. Refactor.
 @timeline.event
 def write_cluster_config(
-        to_provision: 'resources.Resources',
+        to_provision: 'resources_lib.Resources',
         num_nodes: int,
         cluster_config_template: str,
         cluster_name: str,
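Reviewer note: the rename above only matters to the type checker. Because the import sits under `typing.TYPE_CHECKING`, it is never executed at runtime, which is why the annotation must stay a quoted string. A minimal, self-contained sketch of the pattern, reduced from the diff above:

import typing

if typing.TYPE_CHECKING:
    # Only evaluated by type checkers (mypy, pyright); never at runtime.
    from sky import resources as resources_lib


def write_cluster_config(to_provision: 'resources_lib.Resources') -> None:
    # The quoted annotation defers evaluation, so the alias does not need
    # to exist when this module is imported at runtime.
    ...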
52 changes: 3 additions & 49 deletions sky/cli.py
@@ -1464,54 +1464,8 @@ def _status_kubernetes(show_all: bool):
     Args:
         show_all (bool): Show all job information (e.g., start time, failures).
     """
-    context = kubernetes_utils.get_current_kube_config_context_name()
-    try:
-        pods = kubernetes_utils.get_skypilot_pods(context)
-    except exceptions.ResourcesUnavailableError as e:
-        with ux_utils.print_exception_no_traceback():
-            raise ValueError('Failed to get SkyPilot pods from '
-                             f'Kubernetes: {str(e)}') from e
-    all_clusters, jobs_controllers, serve_controllers = (
-        status_utils.process_skypilot_pods(pods, context))
-    all_jobs = []
-    with rich_utils.safe_status(
-            '[bold cyan]Checking in-progress managed jobs[/]') as spinner:
-        for i, (_, job_controller_info) in enumerate(jobs_controllers.items()):
-            user = job_controller_info['user']
-            pod = job_controller_info['pods'][0]
-            status_message = ('[bold cyan]Checking managed jobs controller')
-            if len(jobs_controllers) > 1:
-                status_message += f's ({i+1}/{len(jobs_controllers)})'
-            spinner.update(f'{status_message}[/]')
-            try:
-                job_list = managed_jobs.queue_from_kubernetes_pod(
-                    pod.metadata.name)
-            except RuntimeError as e:
-                logger.warning('Failed to get managed jobs from controller '
-                               f'{pod.metadata.name}: {str(e)}')
-                job_list = []
-            # Add user field to jobs
-            for job in job_list:
-                job['user'] = user
-            all_jobs.extend(job_list)
-    # Reconcile cluster state between managed jobs and clusters:
-    # To maintain a clear separation between regular SkyPilot clusters
-    # and those from managed jobs, we need to exclude the latter from
-    # the main cluster list.
-    # We do this by reconstructing managed job cluster names from each
-    # job's name and ID. We then use this set to filter out managed
-    # clusters from the main cluster list. This is necessary because there
-    # are no identifiers distinguishing clusters from managed jobs from
-    # regular clusters.
-    managed_job_cluster_names = set()
-    for job in all_jobs:
-        # Managed job cluster name is <job_name>-<job_id>
-        managed_cluster_name = f'{job["job_name"]}-{job["job_id"]}'
-        managed_job_cluster_names.add(managed_cluster_name)
-    unmanaged_clusters = [
-        c for c in all_clusters
-        if c['cluster_name'] not in managed_job_cluster_names
-    ]
+    all_clusters, unmanaged_clusters, all_jobs, context = (
+        core.status_kubernetes())
     click.echo(f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
                f'Kubernetes cluster state (context: {context})'
               f'{colorama.Style.RESET_ALL}')
@@ -1523,7 +1477,7 @@ def _status_kubernetes(show_all: bool):
                    f'{colorama.Style.RESET_ALL}')
         msg = managed_jobs.format_job_table(all_jobs, show_all=show_all)
         click.echo(msg)
-    if serve_controllers:
+    if any(['sky-serve-controller' in c.cluster_name for c in all_clusters]):
         # TODO: Parse serve controllers and show services separately.
         # Currently we show a hint that services are shown as clusters.
         click.echo(f'\n{colorama.Style.DIM}Hint: SkyServe replica pods are '
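Reviewer note: the serve-controller check changes shape because `_status_kubernetes` now receives `KubernetesSkyPilotClusterInfo` dataclasses instead of dicts, so `c['cluster_name']` becomes `c.cluster_name`. A self-contained sketch of the new predicate (the `ClusterInfo` stub here stands in for the real dataclass):

import dataclasses
from typing import List


@dataclasses.dataclass
class ClusterInfo:
    # Stub with only the field the predicate needs; the real class is
    # KubernetesSkyPilotClusterInfo in sky/provision/kubernetes/utils.py.
    cluster_name: str


def has_serve_controller(clusters: List[ClusterInfo]) -> bool:
    # Mirrors the new CLI check: substring match on the dataclass attribute.
    return any('sky-serve-controller' in c.cluster_name for c in clusters)


assert has_serve_controller([ClusterInfo('sky-serve-controller-2ea4')])
assert not has_serve_controller([ClusterInfo('mycluster')])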
77 changes: 76 additions & 1 deletion sky/core.py
@@ -1,7 +1,7 @@
 """SDK functions for cluster/job management."""
 import getpass
 import typing
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import colorama
 
@@ -11,10 +11,12 @@
 from sky import data
 from sky import exceptions
 from sky import global_user_state
+from sky import jobs as managed_jobs
 from sky import sky_logging
 from sky import status_lib
 from sky import task
 from sky.backends import backend_utils
+from sky.provision.kubernetes import utils as kubernetes_utils
 from sky.skylet import constants
 from sky.skylet import job_lib
 from sky.usage import usage_lib
@@ -111,6 +113,79 @@ def status(cluster_names: Optional[Union[str, List[str]]] = None,
         cluster_names=cluster_names)
 
 
+def status_kubernetes(
+) -> Tuple[List['kubernetes_utils.KubernetesSkyPilotClusterInfo'],
+           List['kubernetes_utils.KubernetesSkyPilotClusterInfo'], List[Dict[
+               str, Any]], Optional[str]]:
+    """Gets all SkyPilot clusters and jobs in the Kubernetes cluster.
+
+    Managed jobs and services are also included in the clusters returned.
+    The caller must parse the controllers to identify which clusters are run
+    as managed jobs or services.
+
+    Returns:
+        A tuple containing:
+        - all_clusters: List of KubernetesSkyPilotClusterInfo with info for
+          all clusters, including managed jobs, services and controllers.
+        - unmanaged_clusters: List of KubernetesSkyPilotClusterInfo with info
+          for all clusters excluding managed jobs and services. Controllers
+          are included.
+        - all_jobs: List of managed jobs from all controllers. Each entry is a
+          dictionary of job info; see jobs.queue_from_kubernetes_pod for
+          details.
+        - context: Kubernetes context used to fetch the cluster information.
+    """
+    context = kubernetes_utils.get_current_kube_config_context_name()
+    try:
+        pods = kubernetes_utils.get_skypilot_pods(context)
+    except exceptions.ResourcesUnavailableError as e:
+        with ux_utils.print_exception_no_traceback():
+            raise ValueError('Failed to get SkyPilot pods from '
+                             f'Kubernetes: {str(e)}') from e
+    all_clusters, jobs_controllers, _ = (kubernetes_utils.process_skypilot_pods(
+        pods, context))
+    all_jobs = []
+    with rich_utils.safe_status(
+            ux_utils.spinner_message(
+                '[bold cyan]Checking in-progress managed jobs[/]')) as spinner:
+        for i, job_controller_info in enumerate(jobs_controllers):
+            user = job_controller_info.user
+            pod = job_controller_info.pods[0]
+            status_message = '[bold cyan]Checking managed jobs controller'
+            if len(jobs_controllers) > 1:
+                status_message += f's ({i + 1}/{len(jobs_controllers)})'
+            spinner.update(f'{status_message}[/]')
+            try:
+                job_list = managed_jobs.queue_from_kubernetes_pod(
+                    pod.metadata.name)
+            except RuntimeError as e:
+                logger.warning('Failed to get managed jobs from controller '
+                               f'{pod.metadata.name}: {str(e)}')
+                job_list = []
+            # Add user field to jobs
+            for job in job_list:
+                job['user'] = user
+            all_jobs.extend(job_list)
+    # Reconcile cluster state between managed jobs and clusters:
+    # To maintain a clear separation between regular SkyPilot clusters
+    # and those from managed jobs, we need to exclude the latter from
+    # the main cluster list.
+    # We do this by reconstructing managed job cluster names from each
+    # job's name and ID. We then use this set to filter out managed
+    # clusters from the main cluster list. This is necessary because there
+    # are no identifiers distinguishing clusters from managed jobs from
+    # regular clusters.
+    managed_job_cluster_names = set()
+    for job in all_jobs:
+        # Managed job cluster name is <job_name>-<job_id>
+        managed_cluster_name = f'{job["job_name"]}-{job["job_id"]}'
+        managed_job_cluster_names.add(managed_cluster_name)
+    unmanaged_clusters = [
+        c for c in all_clusters
+        if c.cluster_name not in managed_job_cluster_names
+    ]
+    return all_clusters, unmanaged_clusters, all_jobs, context
+
+
 def endpoints(cluster: str,
               port: Optional[Union[int, str]] = None) -> Dict[int, str]:
     """Gets the endpoint for a given cluster and port number (endpoint).
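Reviewer note: moving this logic from the CLI into `sky.core` makes the query reusable by any Python caller. A hedged usage sketch, assuming a kubeconfig pointing at a cluster that runs SkyPilot pods (the printing is illustrative, not part of the SDK):

from sky import core

# Unpack the four-tuple documented in the docstring above.
all_clusters, unmanaged_clusters, all_jobs, context = core.status_kubernetes()

print(f'Kubernetes context: {context}')
for cluster in unmanaged_clusters:
    # KubernetesSkyPilotClusterInfo attributes set by process_skypilot_pods.
    print(cluster.cluster_name, cluster.status, cluster.resources_str)
for job in all_jobs:
    # Job records are plain dicts; 'user' is attached by status_kubernetes.
    print(job['user'], job['job_name'], job['job_id'])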
113 changes: 113 additions & 0 deletions sky/provision/kubernetes/utils.py
@@ -15,9 +15,11 @@
 import yaml
 
 import sky
+from sky import clouds
 from sky import exceptions
 from sky import sky_logging
 from sky import skypilot_config
+from sky import status_lib
 from sky.adaptors import kubernetes
 from sky.provision import constants as provision_constants
 from sky.provision.kubernetes import network_utils
@@ -30,6 +32,7 @@
 
 if typing.TYPE_CHECKING:
     from sky import backends
+    from sky import resources as resources_lib
 
 # TODO(romilb): Move constants to constants.py
 DEFAULT_NAMESPACE = 'default'
@@ -2023,3 +2026,113 @@ def get_skypilot_pods(context: Optional[str] = None) -> List[Any]:
         'kubectl get pods --selector=skypilot-cluster --all-namespaces'
     ) from None
     return pods
+
+
+@dataclasses.dataclass
+class KubernetesSkyPilotClusterInfo:
+    cluster_name_on_cloud: str
+    cluster_name: str
+    user: str
+    status: status_lib.ClusterStatus
+    pods: List[Any]
+    launched_at: float
+    resources: 'resources_lib.Resources'
+    resources_str: str
+
+
+def process_skypilot_pods(
+    pods: List[Any],
+    context: Optional[str] = None
+) -> Tuple[List[KubernetesSkyPilotClusterInfo],
+           List[KubernetesSkyPilotClusterInfo],
+           List[KubernetesSkyPilotClusterInfo]]:
+    """Process SkyPilot pods on k8s to extract cluster and controller info.
+
+    Args:
+        pods: List of Kubernetes pod objects.
+        context: Kubernetes context name, used to detect GPU label formatter.
+
+    Returns:
+        A tuple containing:
+        - List of KubernetesSkyPilotClusterInfo with all cluster info.
+        - List of KubernetesSkyPilotClusterInfo with job controller info.
+        - List of KubernetesSkyPilotClusterInfo with serve controller info.
+    """
+    # pylint: disable=import-outside-toplevel
+    from sky import resources as resources_lib
+    clusters: Dict[str, KubernetesSkyPilotClusterInfo] = {}
+    jobs_controllers: List[KubernetesSkyPilotClusterInfo] = []
+    serve_controllers: List[KubernetesSkyPilotClusterInfo] = []
+
+    for pod in pods:
+        cluster_name_on_cloud = pod.metadata.labels.get('skypilot-cluster')
+        cluster_name = cluster_name_on_cloud.rsplit(
+            '-', 1
+        )[0]  # Remove the user hash to get cluster name (e.g., mycluster-2ea4)
+        if cluster_name_on_cloud not in clusters:
+            # Parse the start time for the cluster
+            start_time = pod.status.start_time
+            if start_time is not None:
+                start_time = pod.status.start_time.timestamp()
+
+            # Parse resources
+            cpu_request = parse_cpu_or_gpu_resource(
+                pod.spec.containers[0].resources.requests.get('cpu', '0'))
+            memory_request = parse_memory_resource(
+                pod.spec.containers[0].resources.requests.get('memory', '0'),
+                unit='G')
+            gpu_count = parse_cpu_or_gpu_resource(
+                pod.spec.containers[0].resources.requests.get(
+                    'nvidia.com/gpu', '0'))
+            gpu_name = None
+            if gpu_count > 0:
+                label_formatter, _ = (detect_gpu_label_formatter(context))
+                assert label_formatter is not None, (
+                    'GPU label formatter cannot be None if there are pods '
+                    f'requesting GPUs: {pod.metadata.name}')
+                gpu_label = label_formatter.get_label_key()
+                # Get GPU name from pod node selector
+                if pod.spec.node_selector is not None:
+                    gpu_name = label_formatter.get_accelerator_from_label_value(
+                        pod.spec.node_selector.get(gpu_label))
+
+            resources = resources_lib.Resources(
+                cloud=clouds.Kubernetes(),
+                cpus=int(cpu_request),
+                memory=int(memory_request),
+                accelerators=(f'{gpu_name}:{gpu_count}'
+                              if gpu_count > 0 else None))
+            if pod.status.phase == 'Pending':
+                # If pod is pending, do not show it in the status
+                continue
+
+            cluster_info = KubernetesSkyPilotClusterInfo(
+                cluster_name_on_cloud=cluster_name_on_cloud,
+                cluster_name=cluster_name,
+                user=pod.metadata.labels.get('skypilot-user'),
+                status=status_lib.ClusterStatus.UP,
+                pods=[],
+                launched_at=start_time,
+                resources=resources,
+                resources_str='')
+            clusters[cluster_name_on_cloud] = cluster_info
+            # Check if cluster name is name of a controller
+            # Can't use controller_utils.Controllers.from_name(cluster_name)
+            # because hash is different across users
+            if 'sky-jobs-controller' in cluster_name_on_cloud:
+                jobs_controllers.append(cluster_info)
+            elif 'sky-serve-controller' in cluster_name_on_cloud:
+                serve_controllers.append(cluster_info)
+        else:
+            # Update start_time if this pod started earlier
+            pod_start_time = pod.status.start_time
+            if pod_start_time is not None:
+                pod_start_time = pod_start_time.timestamp()
+                if pod_start_time < clusters[cluster_name_on_cloud].launched_at:
+                    clusters[cluster_name_on_cloud].launched_at = pod_start_time
+        clusters[cluster_name_on_cloud].pods.append(pod)
+    # Update resources_str in clusters:
+    for cluster in clusters.values():
+        num_pods = len(cluster.pods)
+        cluster.resources_str = f'{num_pods}x {cluster.resources}'
+    return list(clusters.values()), jobs_controllers, serve_controllers
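Reviewer note: the pod loop above is effectively a keyed merge: pods sharing a `skypilot-cluster` label fold into one record, `launched_at` keeps the earliest pod start time, and `resources_str` is later rebuilt as `<num_pods>x <resources>`. A simplified, self-contained sketch of that merge using stub pods (not SkyPilot's real objects):

import dataclasses
from typing import Dict, List, Optional


@dataclasses.dataclass
class StubPod:
    cluster_label: str  # value of the 'skypilot-cluster' label
    start_time: Optional[float]


@dataclasses.dataclass
class ClusterRecord:  # mirrors the shape of KubernetesSkyPilotClusterInfo
    cluster_name_on_cloud: str
    launched_at: Optional[float]
    pods: List[StubPod]


def group_pods(pods: List[StubPod]) -> List[ClusterRecord]:
    clusters: Dict[str, ClusterRecord] = {}
    for pod in pods:
        name = pod.cluster_label
        if name not in clusters:
            clusters[name] = ClusterRecord(name, pod.start_time, [])
        else:
            prev = clusters[name].launched_at
            if pod.start_time is not None and (prev is None or
                                               pod.start_time < prev):
                # Keep the earliest pod start as the cluster launch time.
                clusters[name].launched_at = pod.start_time
        clusters[name].pods.append(pod)
    return list(clusters.values())


records = group_pods([StubPod('job-2ea4', 200.0), StubPod('job-2ea4', 100.0)])
assert records[0].launched_at == 100.0 and len(records[0].pods) == 2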