diff --git a/sky/adaptors/kubernetes.py b/sky/adaptors/kubernetes.py
index 8ecfcd2ebe9..592e3968da6 100644
--- a/sky/adaptors/kubernetes.py
+++ b/sky/adaptors/kubernetes.py
@@ -95,6 +95,7 @@ def _load_config(context: str = None):
             raise ValueError(err_str) from None
     _configured = True


+@functools.lru_cache()
 @_api_logging_decorator('urllib3', logging.ERROR)
 def core_api(context: str = None):
diff --git a/sky/authentication.py b/sky/authentication.py
index ee413c22264..4a37cbd2373 100644
--- a/sky/authentication.py
+++ b/sky/authentication.py
@@ -378,10 +378,11 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
     public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
     secret_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME
     secret_field_name = clouds.Kubernetes().ssh_key_secret_field_name
-    namespace = config['provider'].get('namespace',
-        kubernetes_utils.get_current_kube_config_context_namespace())
-    context = config['provider'].get('context',
-        kubernetes_utils.get_current_kube_config_context_name())
+    namespace = config['provider'].get(
+        'namespace',
+        kubernetes_utils.get_current_kube_config_context_namespace())
+    context = config['provider'].get(
+        'context', kubernetes_utils.get_current_kube_config_context_name())
     k8s = kubernetes.kubernetes
     with open(public_key_path, 'r', encoding='utf-8') as f:
         public_key = f.read()
@@ -404,8 +405,8 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
         string_data={secret_field_name: public_key})
     if kubernetes_utils.check_secret_exists(secret_name, namespace, context):
         logger.debug(f'Key {secret_name} exists in the cluster, patching it...')
-        kubernetes.core_api(context).patch_namespaced_secret(secret_name, namespace,
-                                                             secret)
+        kubernetes.core_api(context).patch_namespaced_secret(
+            secret_name, namespace, secret)
     else:
         logger.debug(
             f'Key {secret_name} does not exist in the cluster, creating it...')
@@ -418,8 +419,8 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
     # Setup service for SSH jump pod. We create the SSH jump service here
     # because we need to know the service IP address and port to set the
     # ssh_proxy_command in the autoscaler config.
-    kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace,
-                                        context, service_type)
+    kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace, context,
+                                        service_type)
     ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
         ssh_jump_name,
         nodeport_mode,
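Note: the new `@functools.lru_cache()` on `core_api()` memoizes one API client per kubeconfig context, so repeated calls with the same `context` reuse a client instead of rebuilding it. A minimal sketch of the idea, using the official `kubernetes` package directly rather than SkyPilot's adaptor (the body is illustrative, not the adaptor's actual implementation):

```python
# Minimal sketch: memoize one CoreV1Api client per kubeconfig context.
import functools

from kubernetes import client, config


@functools.lru_cache()
def core_api(context: str = None) -> client.CoreV1Api:
    # new_client_from_config builds an ApiClient bound to the given
    # context; None falls back to the kubeconfig's current context.
    return client.CoreV1Api(
        api_client=config.new_client_from_config(context=context))
```

Because `lru_cache` keys on the `context` argument, `core_api('ctx-a')` and `core_api('ctx-b')` hold separate clients, which is what keeps per-call context switching cheap.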
diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py
index db81cf8f436..9eadd50a239 100644
--- a/sky/clouds/cloud.py
+++ b/sky/clouds/cloud.py
@@ -488,7 +488,6 @@ def get_current_user_identity(cls) -> Optional[List[str]]:
         """
         return None

-    @classmethod
     @classmethod
     def get_current_user_identity_str(cls) -> Optional[str]:
         """Returns a user friendly representation of the current identity."""
diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py
index 6a0db1a181f..4a1ad967f4e 100644
--- a/sky/clouds/kubernetes.py
+++ b/sky/clouds/kubernetes.py
@@ -316,7 +316,8 @@ def make_deploy_resources_variables(
             'timeout': str(timeout),
             'k8s_namespace':
                 kubernetes_utils.get_current_kube_config_context_namespace(),
-            'k8s_context': kubernetes_utils.get_current_kube_config_context_name(),
+            'k8s_context':
+                kubernetes_utils.get_current_kube_config_context_name(),
             'k8s_port_mode': port_mode.value,
             'k8s_networking_mode': network_utils.get_networking_mode().value,
             'k8s_ssh_key_secret_name': self.SKY_SSH_KEY_SECRET_NAME,
@@ -453,7 +454,8 @@ def get_supported_identities(cls) -> Optional[List[str]]:
         k8s = kubernetes.kubernetes
         identities = []
         try:
-            all_contexts, current_context = k8s.config.list_kube_config_contexts()
+            all_contexts, current_context = k8s.config.list_kube_config_contexts(
+            )
             # Add current context at the head of the list
             for context in all_contexts:
                 identity_str = cls.get_identity_from_context(context)
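Note: `get_supported_identities` derives one identity per kubeconfig context from `list_kube_config_contexts()`. A hedged sketch of what that call returns (field access per the official `kubernetes` client; the printing is purely illustrative):

```python
from kubernetes import config

# Returns (all contexts, the active context); each entry is a dict with a
# 'name' key and a 'context' dict holding cluster/user/namespace.
all_contexts, current_context = config.list_kube_config_contexts()
print('current:', current_context['name'])
for ctx in all_contexts:
    info = ctx['context']
    print(ctx['name'], info['cluster'], info['user'],
          info.get('namespace', 'default'))
```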
diff --git a/sky/provision/kubernetes/config.py b/sky/provision/kubernetes/config.py
index 98e89477058..e377f3029b8 100644
--- a/sky/provision/kubernetes/config.py
+++ b/sky/provision/kubernetes/config.py
@@ -23,10 +23,10 @@ def bootstrap_instances(
         region: str, cluster_name: str,
         config: common.ProvisionConfig) -> common.ProvisionConfig:
     del region, cluster_name  # unused
-    namespace = kubernetes_utils.get_namespace_from_config(config.provider_config)
+    namespace = kubernetes_utils.get_namespace_from_config(
+        config.provider_config)
     context = kubernetes_utils.get_context_from_config(config.provider_config)
-
     _configure_services(namespace, context, config.provider_config)

     networking_mode = network_utils.get_networking_mode(
@@ -42,7 +42,8 @@ def bootstrap_instances(
     # necessary roles and role bindings.
     # If not, set up the roles and bindings for skypilot-service-account
     # here.
-    _configure_autoscaler_service_account(namespace, context, config.provider_config)
+    _configure_autoscaler_service_account(namespace, context,
+                                          config.provider_config)
     _configure_autoscaler_role(namespace,
                                context,
                                config.provider_config,
@@ -52,9 +53,9 @@ def bootstrap_instances(
                                context,
                                config.provider_config,
                                binding_field='autoscaler_role_binding')
-    _configure_autoscaler_cluster_role(namespace, context, config.provider_config)
-    _configure_autoscaler_cluster_role_binding(namespace,
-                                               context,
+    _configure_autoscaler_cluster_role(namespace, context,
+                                       config.provider_config)
+    _configure_autoscaler_cluster_role_binding(namespace, context,
                                                config.provider_config)
     # SkyPilot system namespace is required for FUSE mounting. Here we just
     # create the namespace and set up the necessary permissions.
@@ -274,7 +275,8 @@ def _configure_autoscaler_service_account(
     logger.info('_configure_autoscaler_service_account: '
                 f'{not_found_msg(account_field, name)}')

-    kubernetes.core_api(context).create_namespaced_service_account(namespace, account)
+    kubernetes.core_api(context).create_namespaced_service_account(
+        namespace, account)
     logger.info('_configure_autoscaler_service_account: '
                 f'{created_msg(account_field, name)}')
@@ -316,7 +318,8 @@ def _configure_autoscaler_role(namespace: str, context: str,
             return
         logger.info('_configure_autoscaler_role: '
                     f'{updating_existing_msg(role_field, name)}')
-        kubernetes.auth_api(context).patch_namespaced_role(name, namespace, role)
+        kubernetes.auth_api(context).patch_namespaced_role(
+            name, namespace, role)
         return

     logger.info('_configure_autoscaler_role: '
@@ -387,13 +390,13 @@ def _configure_autoscaler_role_binding(
         logger.info('_configure_autoscaler_role_binding: '
                     f'{not_found_msg(binding_field, name)}')

-    kubernetes.auth_api(context).create_namespaced_role_binding(rb_namespace, binding)
+    kubernetes.auth_api(context).create_namespaced_role_binding(
+        rb_namespace, binding)
     logger.info('_configure_autoscaler_role_binding: '
                 f'{created_msg(binding_field, name)}')


-def _configure_autoscaler_cluster_role(namespace,
-                                       context,
+def _configure_autoscaler_cluster_role(namespace, context,
                                        provider_config: Dict[str, Any]) -> None:
     role_field = 'autoscaler_cluster_role'
     if role_field not in provider_config:
@@ -432,9 +435,7 @@ def _configure_autoscaler_cluster_role(namespace,


 def _configure_autoscaler_cluster_role_binding(
-        namespace,
-        context,
-        provider_config: Dict[str, Any]) -> None:
+        namespace, context, provider_config: Dict[str, Any]) -> None:
     binding_field = 'autoscaler_cluster_role_binding'
     if binding_field not in provider_config:
         logger.info('_configure_autoscaler_cluster_role_binding: '
@@ -654,7 +655,8 @@ def _configure_services(namespace: str, context: str,
         else:
             logger.info(
                 f'_configure_services: {not_found_msg("service", name)}')
-            kubernetes.core_api(context).create_namespaced_service(namespace, service)
+            kubernetes.core_api(context).create_namespaced_service(
+                namespace, service)
             logger.info(f'_configure_services: {created_msg("service", name)}')
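Note: the `_configure_autoscaler_*` helpers share a create-if-missing shape against a context-scoped client. A rough, self-contained equivalent for the ServiceAccount case, under stated assumptions (the account name and namespace are hypothetical; the real helpers resolve both from `provider_config`):

```python
from kubernetes import client, config

config.load_kube_config()  # or load_kube_config(context=...) for a specific one
core = client.CoreV1Api()

namespace = 'default'  # hypothetical; SkyPilot reads this from provider config
account = client.V1ServiceAccount(
    metadata=client.V1ObjectMeta(name='skypilot-service-account'))

# Create only if no ServiceAccount with that name exists in the namespace.
names = {
    sa.metadata.name
    for sa in core.list_namespaced_service_account(namespace).items
}
if account.metadata.name not in names:
    core.create_namespaced_service_account(namespace, account)
```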
diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py
index e6aea9aac50..f865cc521f3 100644
--- a/sky/provision/kubernetes/instance.py
+++ b/sky/provision/kubernetes/instance.py
@@ -133,8 +133,8 @@ def _lack_resource_msg(resource: str,
         return msg

     for new_node in new_nodes:
-        pod = kubernetes.core_api(context).read_namespaced_pod(new_node.metadata.name,
-                                                               namespace)
+        pod = kubernetes.core_api(context).read_namespaced_pod(
+            new_node.metadata.name, namespace)
         pod_status = pod.status.phase
         # When there are multiple pods involved while launching instance,
         # there may be a single pod causing issue while others are
@@ -366,7 +366,8 @@ def _set_env_vars_in_pods(namespace: str, context: str, new_pods: List):
                 new_pod.metadata.name, rc, stdout)


-def _check_user_privilege(namespace: str, context: str, new_nodes: List) -> None:
+def _check_user_privilege(namespace: str, context: str,
+                          new_nodes: List) -> None:
     # Checks if the default user has sufficient privilege to set up
     # the kubernetes instance pod.
     check_k8s_user_sudo_cmd = (
@@ -434,7 +435,8 @@ def _setup_ssh_in_pods(namespace: str, context: str, new_nodes: List) -> None:
     # TODO(romilb): Parallelize the setup of SSH in pods for multi-node clusters
     for new_node in new_nodes:
         pod_name = new_node.metadata.name
-        runner = command_runner.KubernetesCommandRunner(((namespace, context), pod_name))
+        runner = command_runner.KubernetesCommandRunner(
+            ((namespace, context), pod_name))
         logger.info(f'{"-"*20}Start: Set up SSH in pod {pod_name!r} {"-"*20}')
         rc, stdout, _ = runner.run(set_k8s_ssh_cmd,
                                    require_outputs=True,
@@ -444,7 +446,8 @@ def _setup_ssh_in_pods(namespace: str, context: str, new_nodes: List) -> None:
     logger.info(f'{"-"*20}End: Set up SSH in pod {pod_name!r} {"-"*20}')


-def _label_pod(namespace: str, context: str, pod_name: str, label: Dict[str, str]) -> None:
+def _label_pod(namespace: str, context: str, pod_name: str,
+               label: Dict[str, str]) -> None:
     """Label a pod."""
     kubernetes.core_api(context).patch_namespaced_pod(
         pod_name,
@@ -480,7 +483,8 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
                     'terminating pods. Waiting them to finish: '
                     f'{list(terminating_pods.keys())}')
         time.sleep(POLL_INTERVAL)
-        terminating_pods = _filter_pods(namespace, context, tags, ['Terminating'])
+        terminating_pods = _filter_pods(namespace, context, tags,
+                                        ['Terminating'])

     if len(terminating_pods) > 0:
         # If there are still terminating pods, we force delete them.
@@ -497,7 +501,8 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
                 _request_timeout=config_lib.DELETION_TIMEOUT,
                 grace_period_seconds=0)

-    running_pods = _filter_pods(namespace, context, tags, ['Pending', 'Running'])
+    running_pods = _filter_pods(namespace, context, tags,
+                                ['Pending', 'Running'])
     head_pod_name = _get_head_pod_name(running_pods)
     logger.debug(f'Found {len(running_pods)} existing pods: '
                  f'{list(running_pods.keys())}')
@@ -515,7 +520,8 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
     # Add nvidia runtime class if it exists
     nvidia_runtime_exists = False
     try:
-        nvidia_runtime_exists = kubernetes_utils.check_nvidia_runtime_class(context)
+        nvidia_runtime_exists = kubernetes_utils.check_nvidia_runtime_class(
+            context)
     except kubernetes.kubernetes.client.ApiException as e:
         logger.warning('run_instances: Error occurred while checking for '
                        f'nvidia RuntimeClass - '
@@ -569,7 +575,8 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
             }
         }

-        pod = kubernetes.core_api(context).create_namespaced_pod(namespace, pod_spec)
+        pod = kubernetes.core_api(context).create_namespaced_pod(
+            namespace, pod_spec)
         created_pods[pod.metadata.name] = pod
         if head_pod_name is None:
             head_pod_name = pod.metadata.name
@@ -777,7 +784,8 @@ def get_cluster_info(
         ssh_user = 'sky'
         get_k8s_ssh_user_cmd = 'echo $(whoami)'
         assert head_pod_name is not None
-        runner = command_runner.KubernetesCommandRunner(((namespace, context), head_pod_name))
+        runner = command_runner.KubernetesCommandRunner(
+            ((namespace, context), head_pod_name))
         rc, stdout, stderr = runner.run(get_k8s_ssh_user_cmd,
                                         require_outputs=True,
                                         separate_stderr=True,
@@ -857,8 +865,10 @@ def get_command_runners(
     """Get a command runner for the given cluster."""
     assert cluster_info.provider_config is not None, cluster_info
     instances = cluster_info.instances
-    namespace = kubernetes_utils.get_namespace_from_config(cluster_info.provider_config)
-    context = kubernetes_utils.get_context_from_config(cluster_info.provider_config)
+    namespace = kubernetes_utils.get_namespace_from_config(
+        cluster_info.provider_config)
+    context = kubernetes_utils.get_context_from_config(
+        cluster_info.provider_config)
     node_list = []
     if cluster_info.head_instance_id is not None:
         node_list = [((namespace, context), cluster_info.head_instance_id)]
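Note: `KubernetesCommandRunner` now takes a `((namespace, context), pod_name)` tuple, and `_label_pod` applies labels through the context-scoped client. A hedged sketch of the underlying label patch with the official client (pod name, namespace, and label values are made up for illustration):

```python
from kubernetes import client, config

config.load_kube_config()
core = client.CoreV1Api()

# patch_namespaced_pod(name, namespace, body) performs a strategic-merge
# patch, so this adds or overwrites just the given label.
core.patch_namespaced_pod(
    'sky-head-pod',  # hypothetical pod name
    'default',       # hypothetical namespace
    {'metadata': {'labels': {'skypilot-status': 'up-to-date'}}})
```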
diff --git a/sky/provision/kubernetes/network.py b/sky/provision/kubernetes/network.py
index d56c3cc02b3..7b086473d64 100644
--- a/sky/provision/kubernetes/network.py
+++ b/sky/provision/kubernetes/network.py
@@ -110,7 +110,8 @@ def _open_ports_using_ingress(
     # Update metadata from config
     kubernetes_utils.merge_custom_metadata(service_spec['metadata'])
     network_utils.create_or_replace_namespaced_service(
-        namespace=kubernetes_utils.get_namespace_from_config(provider_config),
+        namespace=kubernetes_utils.get_namespace_from_config(
+            provider_config),
         context=kubernetes_utils.get_context_from_config(provider_config),
         service_name=service_name,
         service_spec=service_spec,
diff --git a/sky/provision/kubernetes/network_utils.py b/sky/provision/kubernetes/network_utils.py
index 6048ec57cb9..ba126197446 100644
--- a/sky/provision/kubernetes/network_utils.py
+++ b/sky/provision/kubernetes/network_utils.py
@@ -132,9 +132,7 @@ def fill_ingress_template(namespace: str, service_details: List[Tuple[str, int,


 def create_or_replace_namespaced_ingress(
-        namespace: str,
-        context: str,
-        ingress_name: str,
+        namespace: str, context: str, ingress_name: str,
         ingress_spec: Dict[str, Union[str, int]]) -> None:
     """Creates an ingress resource for the specified service."""
     networking_api = kubernetes.networking_api(context)
@@ -158,7 +156,8 @@ def create_or_replace_namespaced_ingress(
             _request_timeout=kubernetes.API_TIMEOUT)


-def delete_namespaced_ingress(namespace: str, context: str, ingress_name: str) -> None:
+def delete_namespaced_ingress(namespace: str, context: str,
+                              ingress_name: str) -> None:
     """Deletes an ingress resource."""
     networking_api = kubernetes.networking_api(context)
     try:
@@ -209,7 +208,8 @@ def delete_namespaced_service(namespace: str, service_name: str) -> None:
         raise e


-def ingress_controller_exists(context: str, ingress_class_name: str = 'nginx') -> bool:
+def ingress_controller_exists(context: str,
+                              ingress_class_name: str = 'nginx') -> bool:
     """Checks if an ingress controller exists in the cluster."""
     networking_api = kubernetes.networking_api(context)
     ingress_classes = networking_api.list_ingress_class(
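Note: `ingress_controller_exists` now resolves its `NetworkingV1Api` client from the given context before listing ingress classes. A rough standalone version of the same check (the timeout value is illustrative):

```python
from kubernetes import client, config

config.load_kube_config()  # optionally: load_kube_config(context='my-ctx')
networking = client.NetworkingV1Api()

# List IngressClass objects and look for one named 'nginx'.
classes = networking.list_ingress_class(_request_timeout=5).items
has_nginx = any(c.metadata.name == 'nginx' for c in classes)
print('nginx ingress controller found:', has_nginx)
```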
diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py
index d34df9bb235..dc49415f989 100644
--- a/sky/provision/kubernetes/utils.py
+++ b/sky/provision/kubernetes/utils.py
@@ -595,9 +595,8 @@ def get_port(svc_name: str, namespace: str, context: str) -> int:
     return head_service.spec.ports[0].node_port


-def get_external_ip(
-        network_mode: Optional[kubernetes_enums.KubernetesNetworkingMode],
-        context: str) -> str:
+def get_external_ip(network_mode: Optional[
+    kubernetes_enums.KubernetesNetworkingMode], context: str) -> str:
     if network_mode == kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD:
         return '127.0.0.1'
     # Return the IP address of the first node with an external IP
@@ -627,8 +626,8 @@ def check_credentials(timeout: int = kubernetes.API_TIMEOUT) -> \
     try:
         ns = get_current_kube_config_context_namespace()
         context = get_current_kube_config_context_name()
-        kubernetes.core_api(context).list_namespaced_pod(ns,
-                                                         _request_timeout=timeout)
+        kubernetes.core_api(context).list_namespaced_pod(
+            ns, _request_timeout=timeout)
     except ImportError:
         # TODO(romilb): Update these error strs to also include link to docs
         # when docs are ready.
@@ -1092,13 +1091,14 @@ def setup_ssh_jump_svc(ssh_jump_name: str, namespace: str, context: str,

     # Create service
     try:
-        kubernetes.core_api(context).create_namespaced_service(namespace,
-                                                               content['service_spec'])
+        kubernetes.core_api(context).create_namespaced_service(
+            namespace, content['service_spec'])
     except kubernetes.api_exception() as e:
         # SSH Jump Pod service already exists.
         if e.status == 409:
-            ssh_jump_service = kubernetes.core_api(context).read_namespaced_service(
-                name=ssh_jump_name, namespace=namespace)
+            ssh_jump_service = kubernetes.core_api(
+                context).read_namespaced_service(name=ssh_jump_name,
+                                                 namespace=namespace)
             curr_svc_type = ssh_jump_service.spec.type
             if service_type.value == curr_svc_type:
                 # If the currently existing SSH Jump service's type is identical
@@ -1183,7 +1183,8 @@ def setup_ssh_jump_pod(ssh_jump_name: str, ssh_jump_image: str,
         logger.info('Created SSH Jump ServiceAccount.')
     # Role
     try:
-        kubernetes.auth_api(context).create_namespaced_role(namespace, content['role'])
+        kubernetes.auth_api(context).create_namespaced_role(
+            namespace, content['role'])
     except kubernetes.api_exception() as e:
         if e.status == 409:
             logger.info(
@@ -1207,8 +1208,8 @@ def setup_ssh_jump_pod(ssh_jump_name: str, ssh_jump_image: str,
         logger.info('Created SSH Jump RoleBinding.')
     # Pod
     try:
-        kubernetes.core_api(context).create_namespaced_pod(namespace,
-                                                           content['pod_spec'])
+        kubernetes.core_api(context).create_namespaced_pod(
+            namespace, content['pod_spec'])
     except kubernetes.api_exception() as e:
         if e.status == 409:
             logger.info(
@@ -1240,7 +1241,8 @@ def find(l, predicate):

     # Get the SSH jump pod name from the head pod
     try:
-        pod = kubernetes.core_api(context).read_namespaced_pod(node_id, namespace)
+        pod = kubernetes.core_api(context).read_namespaced_pod(
+            node_id, namespace)
     except kubernetes.api_exception() as e:
         if e.status == 404:
             logger.warning(f'Failed to get pod {node_id},'
@@ -1260,8 +1262,8 @@ def find(l, predicate):
             # ssh jump pod, lets remove it and the service. Otherwise, main
             # container is ready and its lifecycle management script takes
             # care of the cleaning.
-            kubernetes.core_api(context).delete_namespaced_pod(ssh_jump_name,
-                                                               namespace)
+            kubernetes.core_api(context).delete_namespaced_pod(
+                ssh_jump_name, namespace)
             kubernetes.core_api(context).delete_namespaced_service(
                 ssh_jump_name, namespace)
     except kubernetes.api_exception() as e:
@@ -1742,12 +1744,10 @@ def get_kubernetes_node_info() -> Dict[str, KubernetesNodeInfo]:


 def get_namespace_from_config(provider_config: Dict[str, Any]) -> str:
-    return provider_config.get(
-        'namespace',
-        get_current_kube_config_context_namespace())
+    return provider_config.get('namespace',
+                               get_current_kube_config_context_namespace())


 def get_context_from_config(provider_config: Dict[str, Any]) -> str:
-    return provider_config.get(
-        'context',
-        get_current_kube_config_context_name())
+    return provider_config.get('context',
+                               get_current_kube_config_context_name())
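Note: the SSH jump setup repeatedly uses a create-then-tolerate-409 idiom: attempt the create, and treat an `AlreadyExists` conflict as a signal to inspect or reuse the existing object. A hedged standalone sketch of that idiom (service name, selector, and namespace are hypothetical):

```python
from kubernetes import client, config
from kubernetes.client.rest import ApiException

config.load_kube_config()
core = client.CoreV1Api()

service = client.V1Service(
    metadata=client.V1ObjectMeta(name='sky-ssh-jump-svc'),
    spec=client.V1ServiceSpec(type='ClusterIP',
                              selector={'app': 'sky-ssh-jump'},
                              ports=[client.V1ServicePort(port=22)]))
try:
    core.create_namespaced_service('default', service)
except ApiException as e:
    if e.status == 409:  # conflict: the service already exists
        existing = core.read_namespaced_service(name='sky-ssh-jump-svc',
                                                namespace='default')
        print('existing service type:', existing.spec.type)
    else:
        raise
```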
diff --git a/tests/test_smoke.py b/tests/test_smoke.py
index c5d723b8be4..cdd2dc218ee 100644
--- a/tests/test_smoke.py
+++ b/tests/test_smoke.py
@@ -1171,6 +1171,58 @@ def test_kubernetes_storage_mounts():
     run_one_test(test)


+@pytest.mark.kubernetes
+def test_kubernetes_context_switch():
+    name = _get_cluster_name()
+    new_context = f'sky-test-context-{int(time.time())}'
+    new_namespace = f'sky-test-namespace-{int(time.time())}'
+
+    test_commands = [
+        # Launch a cluster and run a simple task
+        f'sky launch -y -c {name} --cloud kubernetes "echo Hello from original context"',
+        f'sky logs {name} 1 --status',  # Ensure job succeeded
+
+        # Get current context details and save to a file for later use in cleanup
+        'CURRENT_CONTEXT=$(kubectl config current-context); '
+        'echo "$CURRENT_CONTEXT" > /tmp/sky_test_current_context; '
+        'CURRENT_CLUSTER=$(kubectl config view -o jsonpath="{.contexts[?(@.name==\\"$CURRENT_CONTEXT\\")].context.cluster}"); '
+        'CURRENT_USER=$(kubectl config view -o jsonpath="{.contexts[?(@.name==\\"$CURRENT_CONTEXT\\")].context.user}"); '
+
+        # Create a new context with a different name and namespace
+        f'kubectl config set-context {new_context} --cluster="$CURRENT_CLUSTER" --user="$CURRENT_USER" --namespace={new_namespace}',
+
+        # Create the new namespace if it doesn't exist
+        f'kubectl create namespace {new_namespace} --dry-run=client -o yaml | kubectl apply -f -',
+
+        # Set the new context as active
+        f'kubectl config use-context {new_context}',
+
+        # Verify the new context is active
+        f'[ "$(kubectl config current-context)" = "{new_context}" ] || exit 1',
+
+        # Try to run sky exec on the original cluster (should still work)
+        f'sky exec {name} "echo Success: sky exec works after context switch"',
+
+        # Test sky queue
+        f'sky queue {name}',
+    ]
+
+    cleanup_commands = (
+        f'kubectl delete namespace {new_namespace}; '
+        f'kubectl config delete-context {new_context}; '
+        'kubectl config use-context $(cat /tmp/sky_test_current_context); '
+        'rm /tmp/sky_test_current_context; '
+        f'sky down -y {name}')
+
+    test = Test(
+        'kubernetes_context_switch',
+        test_commands,
+        cleanup_commands,
+        timeout=20 * 60,  # 20 mins
+    )
+    run_one_test(test)
+
+
 @pytest.mark.parametrize(
     'image_id',
     [