Skip to content

Commit

Permalink
lint and tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
romilbhardwaj committed Sep 5, 2024
1 parent f315be0 commit 004e920
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 65 deletions.
1 change: 1 addition & 0 deletions sky/adaptors/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def _load_config(context: str = None):
raise ValueError(err_str) from None
_configured = True


@functools.lru_cache()
@_api_logging_decorator('urllib3', logging.ERROR)
def core_api(context: str = None):
Expand Down
17 changes: 9 additions & 8 deletions sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,11 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
secret_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME
secret_field_name = clouds.Kubernetes().ssh_key_secret_field_name
namespace = config['provider'].get('namespace',
kubernetes_utils.get_current_kube_config_context_namespace())
context = config['provider'].get('context',
kubernetes_utils.get_current_kube_config_context_name())
namespace = config['provider'].get(
'namespace',
kubernetes_utils.get_current_kube_config_context_namespace())
context = config['provider'].get(
'context', kubernetes_utils.get_current_kube_config_context_name())
k8s = kubernetes.kubernetes
with open(public_key_path, 'r', encoding='utf-8') as f:
public_key = f.read()
Expand All @@ -404,8 +405,8 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
string_data={secret_field_name: public_key})
if kubernetes_utils.check_secret_exists(secret_name, namespace, context):
logger.debug(f'Key {secret_name} exists in the cluster, patching it...')
kubernetes.core_api(context).patch_namespaced_secret(secret_name, namespace,
secret)
kubernetes.core_api(context).patch_namespaced_secret(
secret_name, namespace, secret)
else:
logger.debug(
f'Key {secret_name} does not exist in the cluster, creating it...')
Expand All @@ -418,8 +419,8 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
# Setup service for SSH jump pod. We create the SSH jump service here
# because we need to know the service IP address and port to set the
# ssh_proxy_command in the autoscaler config.
kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace,
context, service_type)
kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace, context,
service_type)
ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
ssh_jump_name,
nodeport_mode,
Expand Down
1 change: 0 additions & 1 deletion sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,6 @@ def get_current_user_identity(cls) -> Optional[List[str]]:
"""
return None


@classmethod
def get_current_user_identity_str(cls) -> Optional[str]:
"""Returns a user friendly representation of the current identity."""
Expand Down
6 changes: 4 additions & 2 deletions sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,8 @@ def make_deploy_resources_variables(
'timeout': str(timeout),
'k8s_namespace':
kubernetes_utils.get_current_kube_config_context_namespace(),
'k8s_context': kubernetes_utils.get_current_kube_config_context_name(),
'k8s_context':
kubernetes_utils.get_current_kube_config_context_name(),
'k8s_port_mode': port_mode.value,
'k8s_networking_mode': network_utils.get_networking_mode().value,
'k8s_ssh_key_secret_name': self.SKY_SSH_KEY_SECRET_NAME,
Expand Down Expand Up @@ -453,7 +454,8 @@ def get_supported_identities(cls) -> Optional[List[str]]:
k8s = kubernetes.kubernetes
identities = []
try:
all_contexts, current_context = k8s.config.list_kube_config_contexts()
all_contexts, current_context = k8s.config.list_kube_config_contexts(
)
# Add current context at the head of the list
for context in all_contexts:
identity_str = cls.get_identity_from_context(context)
Expand Down
32 changes: 17 additions & 15 deletions sky/provision/kubernetes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def bootstrap_instances(
region: str, cluster_name: str,
config: common.ProvisionConfig) -> common.ProvisionConfig:
del region, cluster_name # unused
namespace = kubernetes_utils.get_namespace_from_config(config.provider_config)
namespace = kubernetes_utils.get_namespace_from_config(
config.provider_config)
context = kubernetes_utils.get_context_from_config(config.provider_config)


_configure_services(namespace, context, config.provider_config)

networking_mode = network_utils.get_networking_mode(
Expand All @@ -42,7 +42,8 @@ def bootstrap_instances(
# necessary roles and role bindings.
# If not, set up the roles and bindings for skypilot-service-account
# here.
_configure_autoscaler_service_account(namespace, context, config.provider_config)
_configure_autoscaler_service_account(namespace, context,
config.provider_config)
_configure_autoscaler_role(namespace,
context,
config.provider_config,
Expand All @@ -52,9 +53,9 @@ def bootstrap_instances(
context,
config.provider_config,
binding_field='autoscaler_role_binding')
_configure_autoscaler_cluster_role(namespace, context, config.provider_config)
_configure_autoscaler_cluster_role_binding(namespace,
context,
_configure_autoscaler_cluster_role(namespace, context,
config.provider_config)
_configure_autoscaler_cluster_role_binding(namespace, context,
config.provider_config)
# SkyPilot system namespace is required for FUSE mounting. Here we just
# create the namespace and set up the necessary permissions.
Expand Down Expand Up @@ -274,7 +275,8 @@ def _configure_autoscaler_service_account(

logger.info('_configure_autoscaler_service_account: '
f'{not_found_msg(account_field, name)}')
kubernetes.core_api(context).create_namespaced_service_account(namespace, account)
kubernetes.core_api(context).create_namespaced_service_account(
namespace, account)
logger.info('_configure_autoscaler_service_account: '
f'{created_msg(account_field, name)}')

Expand Down Expand Up @@ -316,7 +318,8 @@ def _configure_autoscaler_role(namespace: str, context: str,
return
logger.info('_configure_autoscaler_role: '
f'{updating_existing_msg(role_field, name)}')
kubernetes.auth_api(context).patch_namespaced_role(name, namespace, role)
kubernetes.auth_api(context).patch_namespaced_role(
name, namespace, role)
return

logger.info('_configure_autoscaler_role: '
Expand Down Expand Up @@ -387,13 +390,13 @@ def _configure_autoscaler_role_binding(

logger.info('_configure_autoscaler_role_binding: '
f'{not_found_msg(binding_field, name)}')
kubernetes.auth_api(context).create_namespaced_role_binding(rb_namespace, binding)
kubernetes.auth_api(context).create_namespaced_role_binding(
rb_namespace, binding)
logger.info('_configure_autoscaler_role_binding: '
f'{created_msg(binding_field, name)}')


def _configure_autoscaler_cluster_role(namespace,
context,
def _configure_autoscaler_cluster_role(namespace, context,
provider_config: Dict[str, Any]) -> None:
role_field = 'autoscaler_cluster_role'
if role_field not in provider_config:
Expand Down Expand Up @@ -432,9 +435,7 @@ def _configure_autoscaler_cluster_role(namespace,


def _configure_autoscaler_cluster_role_binding(
namespace,
context,
provider_config: Dict[str, Any]) -> None:
namespace, context, provider_config: Dict[str, Any]) -> None:
binding_field = 'autoscaler_cluster_role_binding'
if binding_field not in provider_config:
logger.info('_configure_autoscaler_cluster_role_binding: '
Expand Down Expand Up @@ -654,7 +655,8 @@ def _configure_services(namespace: str, context: str,
else:
logger.info(
f'_configure_services: {not_found_msg("service", name)}')
kubernetes.core_api(context).create_namespaced_service(namespace, service)
kubernetes.core_api(context).create_namespaced_service(
namespace, service)
logger.info(f'_configure_services: {created_msg("service", name)}')


Expand Down
34 changes: 22 additions & 12 deletions sky/provision/kubernetes/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def _lack_resource_msg(resource: str,
return msg

for new_node in new_nodes:
pod = kubernetes.core_api(context).read_namespaced_pod(new_node.metadata.name,
namespace)
pod = kubernetes.core_api(context).read_namespaced_pod(
new_node.metadata.name, namespace)
pod_status = pod.status.phase
# When there are multiple pods involved while launching instance,
# there may be a single pod causing issue while others are
Expand Down Expand Up @@ -366,7 +366,8 @@ def _set_env_vars_in_pods(namespace: str, context: str, new_pods: List):
new_pod.metadata.name, rc, stdout)


def _check_user_privilege(namespace: str, context: str, new_nodes: List) -> None:
def _check_user_privilege(namespace: str, context: str,
new_nodes: List) -> None:
# Checks if the default user has sufficient privilege to set up
# the kubernetes instance pod.
check_k8s_user_sudo_cmd = (
Expand Down Expand Up @@ -434,7 +435,8 @@ def _setup_ssh_in_pods(namespace: str, context: str, new_nodes: List) -> None:
# TODO(romilb): Parallelize the setup of SSH in pods for multi-node clusters
for new_node in new_nodes:
pod_name = new_node.metadata.name
runner = command_runner.KubernetesCommandRunner(((namespace, context), pod_name))
runner = command_runner.KubernetesCommandRunner(
((namespace, context), pod_name))
logger.info(f'{"-"*20}Start: Set up SSH in pod {pod_name!r} {"-"*20}')
rc, stdout, _ = runner.run(set_k8s_ssh_cmd,
require_outputs=True,
Expand All @@ -444,7 +446,8 @@ def _setup_ssh_in_pods(namespace: str, context: str, new_nodes: List) -> None:
logger.info(f'{"-"*20}End: Set up SSH in pod {pod_name!r} {"-"*20}')


def _label_pod(namespace: str, context: str, pod_name: str, label: Dict[str, str]) -> None:
def _label_pod(namespace: str, context: str, pod_name: str,
label: Dict[str, str]) -> None:
"""Label a pod."""
kubernetes.core_api(context).patch_namespaced_pod(
pod_name,
Expand Down Expand Up @@ -480,7 +483,8 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
'terminating pods. Waiting them to finish: '
f'{list(terminating_pods.keys())}')
time.sleep(POLL_INTERVAL)
terminating_pods = _filter_pods(namespace, context, tags, ['Terminating'])
terminating_pods = _filter_pods(namespace, context, tags,
['Terminating'])

if len(terminating_pods) > 0:
# If there are still terminating pods, we force delete them.
Expand All @@ -497,7 +501,8 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
_request_timeout=config_lib.DELETION_TIMEOUT,
grace_period_seconds=0)

running_pods = _filter_pods(namespace, context, tags, ['Pending', 'Running'])
running_pods = _filter_pods(namespace, context, tags,
['Pending', 'Running'])
head_pod_name = _get_head_pod_name(running_pods)
logger.debug(f'Found {len(running_pods)} existing pods: '
f'{list(running_pods.keys())}')
Expand All @@ -515,7 +520,8 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
# Add nvidia runtime class if it exists
nvidia_runtime_exists = False
try:
nvidia_runtime_exists = kubernetes_utils.check_nvidia_runtime_class(context)
nvidia_runtime_exists = kubernetes_utils.check_nvidia_runtime_class(
context)
except kubernetes.kubernetes.client.ApiException as e:
logger.warning('run_instances: Error occurred while checking for '
f'nvidia RuntimeClass - '
Expand Down Expand Up @@ -569,7 +575,8 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
}
}

pod = kubernetes.core_api(context).create_namespaced_pod(namespace, pod_spec)
pod = kubernetes.core_api(context).create_namespaced_pod(
namespace, pod_spec)
created_pods[pod.metadata.name] = pod
if head_pod_name is None:
head_pod_name = pod.metadata.name
Expand Down Expand Up @@ -777,7 +784,8 @@ def get_cluster_info(
ssh_user = 'sky'
get_k8s_ssh_user_cmd = 'echo $(whoami)'
assert head_pod_name is not None
runner = command_runner.KubernetesCommandRunner(((namespace, context), head_pod_name))
runner = command_runner.KubernetesCommandRunner(
((namespace, context), head_pod_name))
rc, stdout, stderr = runner.run(get_k8s_ssh_user_cmd,
require_outputs=True,
separate_stderr=True,
Expand Down Expand Up @@ -857,8 +865,10 @@ def get_command_runners(
"""Get a command runner for the given cluster."""
assert cluster_info.provider_config is not None, cluster_info
instances = cluster_info.instances
namespace = kubernetes_utils.get_namespace_from_config(cluster_info.provider_config)
context = kubernetes_utils.get_context_from_config(cluster_info.provider_config)
namespace = kubernetes_utils.get_namespace_from_config(
cluster_info.provider_config)
context = kubernetes_utils.get_context_from_config(
cluster_info.provider_config)
node_list = []
if cluster_info.head_instance_id is not None:
node_list = [((namespace, context), cluster_info.head_instance_id)]
Expand Down
3 changes: 2 additions & 1 deletion sky/provision/kubernetes/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def _open_ports_using_ingress(
# Update metadata from config
kubernetes_utils.merge_custom_metadata(service_spec['metadata'])
network_utils.create_or_replace_namespaced_service(
namespace=kubernetes_utils.get_namespace_from_config(provider_config),
namespace=kubernetes_utils.get_namespace_from_config(
provider_config),
context=kubernetes_utils.get_context_from_config(provider_config),
service_name=service_name,
service_spec=service_spec,
Expand Down
10 changes: 5 additions & 5 deletions sky/provision/kubernetes/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ def fill_ingress_template(namespace: str, service_details: List[Tuple[str, int,


def create_or_replace_namespaced_ingress(
namespace: str,
context: str,
ingress_name: str,
namespace: str, context: str, ingress_name: str,
ingress_spec: Dict[str, Union[str, int]]) -> None:
"""Creates an ingress resource for the specified service."""
networking_api = kubernetes.networking_api(context)
Expand All @@ -158,7 +156,8 @@ def create_or_replace_namespaced_ingress(
_request_timeout=kubernetes.API_TIMEOUT)


def delete_namespaced_ingress(namespace: str, context: str, ingress_name: str) -> None:
def delete_namespaced_ingress(namespace: str, context: str,
ingress_name: str) -> None:
"""Deletes an ingress resource."""
networking_api = kubernetes.networking_api(context)
try:
Expand Down Expand Up @@ -209,7 +208,8 @@ def delete_namespaced_service(namespace: str, service_name: str) -> None:
raise e


def ingress_controller_exists(context: str, ingress_class_name: str = 'nginx') -> bool:
def ingress_controller_exists(context: str,
ingress_class_name: str = 'nginx') -> bool:
"""Checks if an ingress controller exists in the cluster."""
networking_api = kubernetes.networking_api(context)
ingress_classes = networking_api.list_ingress_class(
Expand Down
Loading

0 comments on commit 004e920

Please sign in to comment.