diff --git a/internal/server/scriptlet/load/load.go b/internal/server/scriptlet/load/load.go index 04f1f9785c2..e91a51df4ab 100644 --- a/internal/server/scriptlet/load/load.go +++ b/internal/server/scriptlet/load/load.go @@ -37,8 +37,8 @@ func InstancePlacementCompile(name string, src string) (*starlark.Program, error // InstancePlacementValidate validates the instance placement scriptlet. func InstancePlacementValidate(src string) error { - return validate(InstancePlacementCompile, nameInstancePlacement, src, map[string][]string{ - "instance_placement": {"request", "candidate_members"}, + return validate(InstancePlacementCompile, nameInstancePlacement, src, declaration{ + required("instance_placement"): {"request", "candidate_members"}, }) } @@ -80,8 +80,8 @@ func QEMUCompile(name string, src string) (*starlark.Program, error) { // QEMUValidate validates the QEMU scriptlet. func QEMUValidate(src string) error { - return validate(QEMUCompile, prefixQEMU, src, map[string][]string{ - "qemu_hook": {"hook_name"}, + return validate(QEMUCompile, prefixQEMU, src, declaration{ + required("qemu_hook"): {"hook_name"}, }) } @@ -107,8 +107,10 @@ func AuthorizationCompile(name string, src string) (*starlark.Program, error) { // AuthorizationValidate validates the authorization scriptlet. func AuthorizationValidate(src string) error { - return validate(AuthorizationCompile, nameAuthorization, src, map[string][]string{ - "authorize": {"details", "object", "entitlement"}, + return validate(AuthorizationCompile, nameAuthorization, src, declaration{ + required("authorize"): {"details", "object", "entitlement"}, + optional("get_instance_access"): {"project_name", "instance_name"}, + optional("get_project_access"): {"project_name"}, }) } diff --git a/internal/server/scriptlet/load/utils.go b/internal/server/scriptlet/load/utils.go index cf3426a4fad..a01fd015f25 100644 --- a/internal/server/scriptlet/load/utils.go +++ b/internal/server/scriptlet/load/utils.go @@ -1,14 +1,31 @@ package load import ( + "errors" "fmt" "slices" "sort" + "strings" "go.starlark.net/starlark" "go.starlark.net/syntax" ) +// argMismatch represents mismatching arguments in a function. +type argMismatch struct { + gotten []string + expected []string +} + +// scriptletFunction represents a possibly optional function in a scriptlet. +type scriptletFunction struct { + name string + optional bool +} + +// declaration is a type alias to make scriptlet declaration easier. +type declaration = map[scriptletFunction][]string + // compile compiles a scriptlet. func compile(programName string, src string, preDeclared []string) (*starlark.Program, error) { isPreDeclared := func(name string) bool { @@ -24,8 +41,69 @@ func compile(programName string, src string, preDeclared []string) (*starlark.Pr return mod, nil } -// validate validates a scriptlet by compiling it and checking the presence of required functions. -func validate(compiler func(string, string) (*starlark.Program, error), programName string, src string, requiredFunctions map[string][]string) error { +// required is a convenience wrapper declaring a required function. +func required(name string) scriptletFunction { + return scriptletFunction{name: name, optional: false} +} + +// required is a convenience wrapper declaring an optional function. +func optional(name string) scriptletFunction { + return scriptletFunction{name: name, optional: true} +} + +// optionalToString converts a Boolean describing optional functions to its string representation. +func optionalToString(optional bool) string { + if optional { + return "optional" + } + + return "required" +} + +// validateFunction validates a single Starlark function. +func validateFunction(funv starlark.Value, requiredArgs []string) (bool, bool, *argMismatch) { + // The function is missing if its name is not found in the globals. + if funv == nil { + return true, false, nil + } + + // The function is actually not a function if its name is not bound to a function. + fun, ok := funv.(*starlark.Function) + if !ok { + return false, true, nil + } + + // Get the function arguments. + argc := fun.NumParams() + var args []string + for i := range argc { + arg, _ := fun.Param(i) + args = append(args, arg) + } + + // The function is invalid if it does not have the right arguments. + match := len(args) == len(requiredArgs) + if match { + sort.Strings(args) + sort.Strings(requiredArgs) + for i := range args { + if args[i] != requiredArgs[i] { + match = false + break + } + } + } + + if !match { + return false, false, &argMismatch{gotten: args, expected: requiredArgs} + } + + return false, false, nil +} + +// validate validates a scriptlet by compiling it and checking the presence of required and optional functions. +func validate(compiler func(string, string) (*starlark.Program, error), programName string, src string, scriptletFunctions declaration) error { + // Try to compile the program. prog, err := compiler(programName, src) if err != nil { return err @@ -39,55 +117,77 @@ func validate(compiler func(string, string) (*starlark.Program, error), programN globals.Freeze() - var notFound []string - for funName, requiredArgs := range requiredFunctions { - // The function is missing if its name is not found in the globals. - funv := globals[funName] - if funv == nil { - notFound = append(notFound, funName) - continue - } - - // The function is missing if its name is not bound to a function. - fun, ok := funv.(*starlark.Function) - if !ok { - notFound = append(notFound, funName) + var missingFuns []string + mistypedFuns := make(map[scriptletFunction]string) + mismatchingFuns := make(map[scriptletFunction]*argMismatch) + errorsFound := false + for fun, requiredArgs := range scriptletFunctions { + funv := globals[fun.name] + missing, mistyped, mismatch := validateFunction(funv, requiredArgs) + + if missing && !fun.optional || mistyped || mismatch != nil { + errorsFound = true + if missing { + missingFuns = append(missingFuns, fun.name) + } else if mistyped { + mistypedFuns[fun] = funv.Type() + } else { + mismatchingFuns[fun] = mismatch + } } + } - // Get the function arguments. - argc := fun.NumParams() - var args []string - for i := range argc { - arg, _ := fun.Param(i) - args = append(args, arg) - } + // Return early if everything looks good. + if !errorsFound { + return nil + } - // Return an error early if the function does not have the right arguments. - match := len(args) == len(requiredArgs) - if match { - sort.Strings(args) - sort.Strings(requiredArgs) - for i := range args { - if args[i] != requiredArgs[i] { - match = false - break - } - } + errorText := "" + sentences := 0 + + // String builder to format pretty error messages. + appendToError := func(text string) { + var link string + if sentences == 0 { + link = "" + } else if sentences == 1 { + link = "; additionally, " + } else { + link = "; finally, " } - if !match { - return fmt.Errorf("The function %q defines arguments %q (expected: %q)", funName, args, requiredArgs) - } + errorText += link + errorText += text + sentences++ } - switch len(notFound) { + switch len(missingFuns) { case 0: - return nil case 1: - return fmt.Errorf("The function %q is required but has not been found in the scriptlet", notFound[0]) + appendToError(fmt.Sprintf("the function %q is required but has not been found in the scriptlet", missingFuns[0])) default: - return fmt.Errorf("The functions %q are required but have not been found in the scriptlet", notFound) + appendToError(fmt.Sprintf("the functions %q are required but have not been found in the scriptlet", missingFuns)) + } + + if len(mistypedFuns) != 0 { + var parts []string + for fun, ty := range mistypedFuns { + parts = append(parts, fmt.Sprintf("%q should define the scriptlet’s %s function of the same name (found a value of type %s instead)", fun.name, optionalToString(fun.optional), ty)) + } + + appendToError(strings.Join(parts, ", ")) } + + if len(mismatchingFuns) != 0 { + var parts []string + for fun, args := range mismatchingFuns { + parts = append(parts, fmt.Sprintf("the %s function %q defines arguments %q (expected %q)", optionalToString(fun.optional), fun.name, args.gotten, args.expected)) + } + + appendToError(strings.Join(parts, ", ")) + } + + return errors.New(errorText) } // set compiles a scriptlet into memory. If empty src is provided the current program is deleted.