diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index 5e1b46d52eb..aec7bdfad63 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -378,7 +378,7 @@ def make_deploy_resources_variables( tpu_requested = True k8s_resource_key = kubernetes_utils.TPU_RESOURCE_KEY else: - k8s_resource_key = kubernetes_utils.GPU_RESOURCE_KEY + k8s_resource_key = kubernetes_utils.get_gpu_resource_key() port_mode = network_utils.get_port_mode(None) diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index 2dcf38f2365..3bdd63a9251 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -685,7 +685,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str, limits = pod_spec['spec']['containers'][0].get('resources', {}).get('limits') if limits is not None: - needs_gpus = limits.get(kubernetes_utils.GPU_RESOURCE_KEY, 0) > 0 + needs_gpus = limits.get(kubernetes_utils.get_gpu_resource_key(), 0) > 0 # TPU pods provisioned on GKE use the default containerd runtime. # Reference: https://cloud.google.com/kubernetes-engine/docs/how-to/migrate-containerd#overview # pylint: disable=line-too-long diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index e5bc4228a8e..ce10bc4bfc2 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -438,7 +438,7 @@ def detect_accelerator_resource( nodes = get_kubernetes_nodes(context) for node in nodes: cluster_resources.update(node.status.allocatable.keys()) - has_accelerator = (GPU_RESOURCE_KEY in cluster_resources or + has_accelerator = (get_gpu_resource_key() in cluster_resources or TPU_RESOURCE_KEY in cluster_resources) return has_accelerator, cluster_resources @@ -2233,10 +2233,11 @@ def get_node_accelerator_count(attribute_dict: dict) -> int: Number of accelerators allocated or available from the node. If no resource is found, it returns 0. """ - assert not (GPU_RESOURCE_KEY in attribute_dict and + gpuResourceKey = get_gpu_resource_key() + assert not (gpuResourceKey in attribute_dict and TPU_RESOURCE_KEY in attribute_dict) - if GPU_RESOURCE_KEY in attribute_dict: - return int(attribute_dict[GPU_RESOURCE_KEY]) + if gpuResourceKey in attribute_dict: + return int(attribute_dict[gpuResourceKey]) elif TPU_RESOURCE_KEY in attribute_dict: return int(attribute_dict[TPU_RESOURCE_KEY]) return 0 @@ -2395,3 +2396,22 @@ def process_skypilot_pods( num_pods = len(cluster.pods) cluster.resources_str = f'{num_pods}x {cluster.resources}' return list(clusters.values()), jobs_controllers, serve_controllers + +def get_gpu_resource_key(): + """Get the GPU resource name to use in kubernetes. + The function first checks for an environment variable. + If defined, it uses its value; otherwise, it returns the default value. + Args: + name (str): Default GPU resource name, default is "nvidia.com/gpu". + Returns: + str: The selected GPU resource name. + """ + # Retrieve GPU resource name from environment variable, if set. + # E.g., can be nvidia.com/gpu-h100, amd.com/gpu etc. + custom_name = os.getenv('CUSTOM_GPU_RESOURCE_KEY') + + # If the environment variable is not defined, return the default name + if custom_name is None: + return GPU_RESOURCE_KEY + + return custom_name diff --git a/sky/utils/kubernetes/gpu_labeler.py b/sky/utils/kubernetes/gpu_labeler.py index 14fbbdedca5..5618cac915c 100644 --- a/sky/utils/kubernetes/gpu_labeler.py +++ b/sky/utils/kubernetes/gpu_labeler.py @@ -101,7 +101,7 @@ def label(): # Get the list of nodes with GPUs gpu_nodes = [] for node in nodes: - if kubernetes_utils.GPU_RESOURCE_KEY in node.status.capacity: + if kubernetes_utils.get_gpu_resource_key() in node.status.capacity: gpu_nodes.append(node) print(f'Found {len(gpu_nodes)} GPU nodes in the cluster') @@ -142,7 +142,7 @@ def label(): if len(gpu_nodes) == 0: print('No GPU nodes found in the cluster. If you have GPU nodes, ' 'please ensure that they have the label ' - f'`{kubernetes_utils.GPU_RESOURCE_KEY}: `') + f'`{kubernetes_utils.get_gpu_resource_key()}: `') else: print('GPU labeling started - this may take 10 min or more to complete.' '\nTo check the status of GPU labeling jobs, run '