diff --git a/cmd/plugins/cdi-device-injector/cdi-device-injector.go b/cmd/plugins/cdi-device-injector/cdi-device-injector.go index 07834b4c2..4a8cfc5f6 100644 --- a/cmd/plugins/cdi-device-injector/cdi-device-injector.go +++ b/cmd/plugins/cdi-device-injector/cdi-device-injector.go @@ -19,6 +19,7 @@ import ( "errors" "flag" "fmt" + "path/filepath" "strings" "github.com/sirupsen/logrus" @@ -46,7 +47,7 @@ type plugin struct { } // CreateContainer handles container creation requests. -func (p *plugin) CreateContainer(_ context.Context, pod *api.PodSandbox, container *api.Container) (_ *api.ContainerAdjustment, _ []*api.ContainerUpdate, err error) { +func (p *plugin) CreateContainer(ctx context.Context, pod *api.PodSandbox, container *api.Container) (_ *api.ContainerAdjustment, _ []*api.ContainerUpdate, err error) { defer func() { if err != nil { log.Error(err) @@ -60,6 +61,14 @@ func (p *plugin) CreateContainer(_ context.Context, pod *api.PodSandbox, contain log.Infof("CreateContainer %s", name) } + allowedCDIDevicesPattern, err := getAllowedCDIDevicesPattern(ctx, pod) + if err != nil { + return nil, nil, fmt.Errorf("failed to get allowed CDI devices: %w", err) + } + if allowedCDIDevicesPattern == "" { + return nil, nil, nil + } + cdiDevices, err := parseCdiDevices(pod.Annotations, container.Name) if err != nil { return nil, nil, fmt.Errorf("failed to parse CDI Device annotations: %w", err) @@ -69,14 +78,46 @@ func (p *plugin) CreateContainer(_ context.Context, pod *api.PodSandbox, contain return nil, nil, nil } + var allowedCDIDevices []string + for _, cdiDevice := range cdiDevices { + match, _ := filepath.Match(allowedCDIDevicesPattern, cdiDevice) + if !match { + continue + } + allowedCDIDevices = append(allowedCDIDevices, cdiDevice) + } + adjust := &api.ContainerAdjustment{} - if _, err := p.cdiCache.InjectDevices(adjust, cdiDevices...); err != nil { + if _, err := p.cdiCache.InjectDevices(adjust, allowedCDIDevices...); err != nil { return nil, nil, fmt.Errorf("CDI device injection failed: %w", err) } return adjust, nil, nil } +func getAllowedCDIDevicesPattern(ctx context.Context, pod *api.PodSandbox) (string, error) { + namespace := pod.GetNamespace() + if namespace == "" { + return "", nil + } + + annotations, err := getAnnotationsForNamespace(ctx, namespace) + if err != nil { + return "", fmt.Errorf("could not get annotations for namespace: %w", err) + } + + pattern, ok := annotations[cdiDeviceKey+"/allow"] + if !ok { + return "", nil + } + return pattern, nil +} + +// TODO: We probably need to use a kubect client here to get the namespace annotations. +func getAnnotationsForNamespace(_ context.Context, namespace string) (map[string]string, error) { + return nil, nil +} + func parseCdiDevices(annotations map[string]string, ctr string) ([]string, error) { var errs error var cdiDevices []string