forked from NVIDIA/k8s-device-plugin
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nvidia.go
135 lines (111 loc) · 2.92 KB
/
nvidia.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
package main
import (
"fmt"
"log"
"os"
"strconv"
"strings"
"github.com/NVIDIA/gpu-monitoring-tools/bindings/go/nvml"
"golang.org/x/net/context"
pluginapi "k8s.io/kubernetes/pkg/kubelet/apis/deviceplugin/v1beta1"
)
// check is the fail-fast helper for unrecoverable errors: a nil err is a
// no-op, anything else is logged with a "Fatal:" prefix and turned into a
// panic.
func check(err error) {
	if err == nil {
		return
	}
	log.Panicln("Fatal:", err)
}
// generateFakeDeviceID derives an advertised device ID from a real device ID
// by appending the "-_-" separator followed by fakeCounter in decimal.
// extractRealDeviceID reverses the mapping.
func generateFakeDeviceID(realID string, fakeCounter uint) string {
	const layout = "%s-_-%d"
	fakeID := fmt.Sprintf(layout, realID, fakeCounter)
	return fakeID
}
// extractRealDeviceID returns the part of fakeDeviceID that precedes the
// first "-_-" separator, i.e. the real device ID that generateFakeDeviceID
// was given. An ID without the separator is returned unchanged.
func extractRealDeviceID(fakeDeviceID string) string {
	sepAt := strings.Index(fakeDeviceID, "-_-")
	if sepAt < 0 {
		return fakeDeviceID
	}
	return fakeDeviceID[:sepAt]
}
// getNumberContainersPerGPU reads the environment variable named by
// envNumberContainersPerGPU and returns its value as a uint.
//
// When the variable is not set the result is 1. A value that is not a valid
// integer, or that is smaller than 1, is treated as a fatal configuration
// error and causes a panic after being logged.
func getNumberContainersPerGPU() (numGPU uint) {
	raw, isSet := os.LookupEnv(envNumberContainersPerGPU)
	if !isSet {
		return 1 // default value
	}

	parsed, err := strconv.Atoi(raw)
	switch {
	case err != nil:
		log.Panicf("Fatal: Could not parse %s environment variable: %v\n", envNumberContainersPerGPU, err)
	case parsed < 1:
		log.Panicf("Fatal: invalid %s environment variable value: %v\n", envNumberContainersPerGPU, parsed)
	}
	return uint(parsed)
}
// getDevices builds the list of devices to advertise.
//
// Every GPU reported by NVML is listed getNumberContainersPerGPU() times,
// each time under a distinct ID produced by generateFakeDeviceID, and every
// entry is marked Healthy. Entries are ordered replica-major: all GPUs for
// replica 0, then all GPUs for replica 1, and so on.
//
// Any NVML error is fatal (see check). The result is nil when there are no
// GPUs.
func getDevices() []*pluginapi.Device {
	n, err := nvml.GetDeviceCount()
	check(err)

	log.Println("List devices")

	// Resolve the replication factor once instead of re-reading and
	// re-parsing the environment on every evaluation of the loop condition.
	replicas := getNumberContainersPerGPU()

	// Query each physical GPU a single time; its UUID does not depend on the
	// replica index, so there is no need to hit NVML once per replica.
	uuids := make([]string, 0, n)
	for i := uint(0); i < n; i++ {
		d, err := nvml.NewDeviceLite(i)
		check(err)
		uuids = append(uuids, d.UUID)
	}

	var devs []*pluginapi.Device
	for j := uint(0); j < replicas; j++ {
		for _, uuid := range uuids {
			fakeID := generateFakeDeviceID(uuid, j)
			log.Println("Device ID:", fakeID)
			devs = append(devs, &pluginapi.Device{
				ID:     fakeID,
				Health: pluginapi.Healthy,
			})
		}
	}
	return devs
}
// deviceExists reports whether devs contains a device whose ID equals id.
// An empty or nil devs yields false.
func deviceExists(devs []*pluginapi.Device, id string) bool {
	found := false
	for i := 0; i < len(devs) && !found; i++ {
		found = devs[i].ID == id
	}
	return found
}
// watchXIDs monitors the GPUs behind devs for critical XID events and sends
// every device that should be considered unhealthy on xids. It runs until ctx
// is cancelled.
//
// devs holds advertised (fake) IDs; the real GPU ID is recovered with
// extractRealDeviceID, so every advertised device that maps to the affected
// GPU is reported.
//
// Behavior:
//   - A device whose event registration fails with a "Not Supported" error is
//     reported as unhealthy right away; any other registration error panics.
//   - XID 31, 43 and 45 are treated as application errors and ignored.
//   - An event without a UUID marks every device unhealthy.
//
// Sends on xids give up as soon as ctx is cancelled, so the function cannot
// stay blocked forever on a consumer that has stopped reading.
func watchXIDs(ctx context.Context, devs []*pluginapi.Device, xids chan<- *pluginapi.Device) {
	// report delivers d to the consumer and returns false if ctx was
	// cancelled before the send could complete.
	report := func(d *pluginapi.Device) bool {
		select {
		case xids <- d:
			return true
		case <-ctx.Done():
			return false
		}
	}

	eventSet := nvml.NewEventSet()
	defer nvml.DeleteEventSet(eventSet)

	for _, d := range devs {
		realDeviceID := extractRealDeviceID(d.ID)
		err := nvml.RegisterEventForDevice(eventSet, nvml.XidCriticalError, realDeviceID)
		if err != nil && strings.HasSuffix(err.Error(), "Not Supported") {
			log.Printf("Warning: %s (%s) is too old to support healthchecking: %s. Marking it unhealthy.", realDeviceID, d.ID, err)
			if !report(d) {
				return
			}
			continue
		}
		if err != nil {
			log.Panicln("Fatal:", err)
		}
	}

	for {
		// Non-blocking cancellation check before each wait.
		select {
		case <-ctx.Done():
			return
		default:
		}

		// 5000 is the wait timeout handed to NVML (presumably milliseconds;
		// confirm against the NVML binding).
		e, err := nvml.WaitForEvent(eventSet, 5000)
		// The event is only meaningful when the wait succeeded, and only
		// critical XID events are of interest. Skip on either condition.
		if err != nil || e.Etype != nvml.XidCriticalError {
			continue
		}

		// FIXME: formalize the full list and document it.
		// http://docs.nvidia.com/deploy/xid-errors/index.html#topic_4
		// Application errors: the GPU should still be healthy
		if e.Edata == 31 || e.Edata == 43 || e.Edata == 45 {
			continue
		}

		// Without a UUID the event cannot be attributed to one GPU, so all
		// devices are reported unhealthy.
		allDevices := e.UUID == nil || len(*e.UUID) == 0
		for _, d := range devs {
			// Short-circuit keeps *e.UUID from being dereferenced when nil.
			if allDevices || extractRealDeviceID(d.ID) == *e.UUID {
				if !report(d) {
					return
				}
			}
		}
	}
}