diff --git a/pkg/nvcdi/lib-nvml.go b/pkg/nvcdi/lib-nvml.go index ab7cb8ba2..d461aa511 100644 --- a/pkg/nvcdi/lib-nvml.go +++ b/pkg/nvcdi/lib-nvml.go @@ -27,6 +27,7 @@ import ( "tags.cncf.io/container-device-interface/specs-go" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" ) @@ -52,6 +53,19 @@ func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) { } }() + if l.nvsandboxutilslib != nil { + if r := l.nvsandboxutilslib.Init(l.driverRoot); r != nvsandboxutils.SUCCESS { + l.logger.Warningf("Failed to init nvsandboxutils: %v; ignoring", r) + l.nvsandboxutilslib = nil + } + defer func() { + if l.nvsandboxutilslib == nil { + return + } + _ = l.nvsandboxutilslib.Shutdown() + }() + } + gpuDeviceSpecs, err := l.getGPUDeviceSpecs() if err != nil { return nil, err diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index d2db3b6c4..35e72d392 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -26,6 +26,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" + "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils" "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" @@ -43,6 +44,7 @@ type wrapper struct { type nvcdilib struct { logger logger.Interface nvmllib nvml.Interface + nvsandboxutilslib nvsandboxutils.Interface mode string devicelib device.Interface deviceNamers DeviceNamers @@ -107,6 +109,19 @@ func New(opts ...Option) (Interface, error) { } l.nvmllib = nvml.New(nvmlOpts...) } + if l.nvsandboxutilslib == nil { + var nvsandboxutilsOpts []nvsandboxutils.LibraryOption + // Set the library path for libnvidia-sandboxutils + candidates, err := l.driver.Libraries().Locate("libnvidia-sandboxutils.so.1") + if err != nil { + l.logger.Warningf("Ignoring error in locating libnvidia-sandboxutils.so.1: %v", err) + } else { + libNvidiaSandboxutilsPath := candidates[0] + l.logger.Infof("Using %v", libNvidiaSandboxutilsPath) + nvsandboxutilsOpts = append(nvsandboxutilsOpts, nvsandboxutils.WithLibraryPath(libNvidiaSandboxutilsPath)) + } + l.nvsandboxutilslib = nvsandboxutils.New(nvsandboxutilsOpts...) + } if l.devicelib == nil { l.devicelib = device.New(l.nvmllib) } @@ -213,7 +228,7 @@ func (l *nvcdilib) resolveMode() (rmode string) { } // getCudaVersion returns the CUDA version of the current system. -func (l *nvcdilib) getCudaVersion() (string, error) { +func (l *nvcdilib) getCudaVersionNvml() (string, error) { if hasNVML, reason := l.infolib.HasNvml(); !hasNVML { return "", fmt.Errorf("nvml not detected: %v", reason) } @@ -236,3 +251,22 @@ func (l *nvcdilib) getCudaVersion() (string, error) { } return version, nil } + +func (l *nvcdilib) getCudaVersionNvsandboxutils() (string, error) { + // Sandboxutils initialization should happen before this function is called + version, ret := l.nvsandboxutilslib.GetDriverVersion() + if ret != nvsandboxutils.SUCCESS { + return "", fmt.Errorf("%v", ret) + } + return version, nil +} + +func (l *nvcdilib) getCudaVersion() (string, error) { + version, err := l.getCudaVersionNvsandboxutils() + if err == nil { + return version, err + } + + // Fallback to NVML + return l.getCudaVersionNvml() +}