Skip to content

Commit

Permalink
incusd/scriptlet: Add support for checking optional functions and imp…
Browse files Browse the repository at this point in the history
…rove error messages

Signed-off-by: Benjamin Somers <[email protected]>
  • Loading branch information
bensmrs committed Dec 13, 2024
1 parent 564b775 commit ab51fa6
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 46 deletions.
14 changes: 8 additions & 6 deletions internal/server/scriptlet/load/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ func InstancePlacementCompile(name string, src string) (*starlark.Program, error

// InstancePlacementValidate validates the instance placement scriptlet.
func InstancePlacementValidate(src string) error {
return validate(InstancePlacementCompile, nameInstancePlacement, src, map[string][]string{
"instance_placement": {"request", "candidate_members"},
return validate(InstancePlacementCompile, nameInstancePlacement, src, declaration{
required("instance_placement"): {"request", "candidate_members"},
})
}

Expand Down Expand Up @@ -80,8 +80,8 @@ func QEMUCompile(name string, src string) (*starlark.Program, error) {

// QEMUValidate validates the QEMU scriptlet.
func QEMUValidate(src string) error {
return validate(QEMUCompile, prefixQEMU, src, map[string][]string{
"qemu_hook": {"hook_name"},
return validate(QEMUCompile, prefixQEMU, src, declaration{
required("qemu_hook"): {"hook_name"},
})
}

Expand All @@ -107,8 +107,10 @@ func AuthorizationCompile(name string, src string) (*starlark.Program, error) {

// AuthorizationValidate validates the authorization scriptlet.
func AuthorizationValidate(src string) error {
return validate(AuthorizationCompile, nameAuthorization, src, map[string][]string{
"authorize": {"details", "object", "entitlement"},
return validate(AuthorizationCompile, nameAuthorization, src, declaration{
required("authorize"): {"details", "object", "entitlement"},
optional("get_instance_access"): {"project_name", "instance_name"},
optional("get_project_access"): {"project_name"},
})
}

Expand Down
180 changes: 140 additions & 40 deletions internal/server/scriptlet/load/utils.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,31 @@
package load

import (
"errors"
"fmt"
"slices"
"sort"
"strings"

"go.starlark.net/starlark"
"go.starlark.net/syntax"
)

// argMismatch represents mismatching arguments in a function.
type argMismatch struct {
gotten []string
expected []string
}

// scriptletFunction represents a possibly optional function in a scriptlet.
type scriptletFunction struct {
name string
optional bool
}

// declaration is a type alias to make scriptlet declaration easier.
type declaration = map[scriptletFunction][]string

// compile compiles a scriptlet.
func compile(programName string, src string, preDeclared []string) (*starlark.Program, error) {
isPreDeclared := func(name string) bool {
Expand All @@ -24,8 +41,69 @@ func compile(programName string, src string, preDeclared []string) (*starlark.Pr
return mod, nil
}

// validate validates a scriptlet by compiling it and checking the presence of required functions.
func validate(compiler func(string, string) (*starlark.Program, error), programName string, src string, requiredFunctions map[string][]string) error {
// required is a convenience wrapper declaring a required function.
func required(name string) scriptletFunction {
return scriptletFunction{name: name, optional: false}
}

// required is a convenience wrapper declaring an optional function.
func optional(name string) scriptletFunction {
return scriptletFunction{name: name, optional: true}
}

// optionalToString converts a Boolean describing optional functions to its string representation.
func optionalToString(optional bool) string {
if optional {
return "optional"
}

return "required"
}

// validateFunction validates a single Starlark function.
func validateFunction(funv starlark.Value, requiredArgs []string) (bool, bool, *argMismatch) {
// The function is missing if its name is not found in the globals.
if funv == nil {
return true, false, nil
}

// The function is actually not a function if its name is not bound to a function.
fun, ok := funv.(*starlark.Function)
if !ok {
return false, true, nil
}

// Get the function arguments.
argc := fun.NumParams()
var args []string
for i := range argc {
arg, _ := fun.Param(i)
args = append(args, arg)
}

// The function is invalid if it does not have the right arguments.
match := len(args) == len(requiredArgs)
if match {
sort.Strings(args)
sort.Strings(requiredArgs)
for i := range args {
if args[i] != requiredArgs[i] {
match = false
break
}
}
}

if !match {
return false, false, &argMismatch{gotten: args, expected: requiredArgs}
}

return false, false, nil
}

// validate validates a scriptlet by compiling it and checking the presence of required and optional functions.
func validate(compiler func(string, string) (*starlark.Program, error), programName string, src string, scriptletFunctions declaration) error {
// Try to compile the program.
prog, err := compiler(programName, src)
if err != nil {
return err
Expand All @@ -39,55 +117,77 @@ func validate(compiler func(string, string) (*starlark.Program, error), programN

globals.Freeze()

var notFound []string
for funName, requiredArgs := range requiredFunctions {
// The function is missing if its name is not found in the globals.
funv := globals[funName]
if funv == nil {
notFound = append(notFound, funName)
continue
}

// The function is missing if its name is not bound to a function.
fun, ok := funv.(*starlark.Function)
if !ok {
notFound = append(notFound, funName)
var missingFuns []string
mistypedFuns := make(map[scriptletFunction]string)
mismatchingFuns := make(map[scriptletFunction]*argMismatch)
errorsFound := false
for fun, requiredArgs := range scriptletFunctions {
funv := globals[fun.name]
missing, mistyped, mismatch := validateFunction(funv, requiredArgs)

if missing && !fun.optional || mistyped || mismatch != nil {
errorsFound = true
if missing {
missingFuns = append(missingFuns, fun.name)
} else if mistyped {
mistypedFuns[fun] = funv.Type()
} else {
mismatchingFuns[fun] = mismatch
}
}
}

// Get the function arguments.
argc := fun.NumParams()
var args []string
for i := range argc {
arg, _ := fun.Param(i)
args = append(args, arg)
}
// Return early if everything looks good.
if !errorsFound {
return nil
}

// Return an error early if the function does not have the right arguments.
match := len(args) == len(requiredArgs)
if match {
sort.Strings(args)
sort.Strings(requiredArgs)
for i := range args {
if args[i] != requiredArgs[i] {
match = false
break
}
}
errorText := ""
sentences := 0

// String builder to format pretty error messages.
appendToError := func(text string) {
var link string
if sentences == 0 {
link = ""
} else if sentences == 1 {
link = "; additionally, "
} else {
link = "; finally, "
}

if !match {
return fmt.Errorf("The function %q defines arguments %q (expected: %q)", funName, args, requiredArgs)
}
errorText += link
errorText += text
sentences++
}

switch len(notFound) {
switch len(missingFuns) {
case 0:
return nil
case 1:
return fmt.Errorf("The function %q is required but has not been found in the scriptlet", notFound[0])
appendToError(fmt.Sprintf("the function %q is required but has not been found in the scriptlet", missingFuns[0]))
default:
return fmt.Errorf("The functions %q are required but have not been found in the scriptlet", notFound)
appendToError(fmt.Sprintf("the functions %q are required but have not been found in the scriptlet", missingFuns))
}

if len(mistypedFuns) != 0 {
var parts []string
for fun, ty := range mistypedFuns {
parts = append(parts, fmt.Sprintf("%q should define the scriptlet’s %s function of the same name (found a value of type %s instead)", fun.name, optionalToString(fun.optional), ty))
}

appendToError(strings.Join(parts, ", "))
}

if len(mismatchingFuns) != 0 {
var parts []string
for fun, args := range mismatchingFuns {
parts = append(parts, fmt.Sprintf("the %s function %q defines arguments %q (expected %q)", optionalToString(fun.optional), fun.name, args.gotten, args.expected))
}

appendToError(strings.Join(parts, ", "))
}

return errors.New(errorText)
}

// set compiles a scriptlet into memory. If empty src is provided the current program is deleted.
Expand Down

0 comments on commit ab51fa6

Please sign in to comment.