diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index a233b7a944f..9402b349e4f 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -530,7 +530,9 @@ def _create_pods(region: str, cluster_name_on_cloud: str, 'override runtimeClassName in ~/.sky/config.yaml. ' 'For more details, refer to https://skypilot.readthedocs.io/en/latest/reference/config.html') # pylint: disable=line-too-long - if nvidia_runtime_exists: + needs_gpus = (pod_spec['spec']['containers'][0].get('resources', {}).get( + 'limits', {}).get('nvidia.com/gpu', 0) > 0) + if nvidia_runtime_exists and needs_gpus: pod_spec['spec']['runtimeClassName'] = 'nvidia' created_pods = {}