Skip to content

Commit

Permalink
Merge pull request #1507 from bensmrs/main
Browse files Browse the repository at this point in the history
Refactor scriptlet loader and improve checks
  • Loading branch information
stgraber authored Dec 13, 2024
2 parents 409c0d4 + ab51fa6 commit c280784
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 125 deletions.
133 changes: 8 additions & 125 deletions internal/server/scriptlet/load/load.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
package load

import (
"fmt"
"slices"
"sort"
"sync"

"go.starlark.net/starlark"
"go.starlark.net/syntax"
)

// nameInstancePlacement is the name used in Starlark for the instance placement scriptlet.
Expand All @@ -19,124 +15,9 @@ const prefixQEMU = "qemu"
// nameAuthorization is the name used in Starlark for the Authorization scriptlet.
const nameAuthorization = "authorization"

// compile compiles a scriptlet.
func compile(programName string, src string, preDeclared []string) (*starlark.Program, error) {
isPreDeclared := func(name string) bool {
return slices.Contains(preDeclared, name)
}

// Parse, resolve, and compile a Starlark source file.
_, mod, err := starlark.SourceProgramOptions(syntax.LegacyFileOptions(), programName, src, isPreDeclared)
if err != nil {
return nil, err
}

return mod, nil
}

// validate validates a scriptlet by compiling it and checking the presence of required functions.
func validate(compiler func(string, string) (*starlark.Program, error), programName string, src string, requiredFunctions map[string][]string) error {
prog, err := compiler(programName, src)
if err != nil {
return err
}

thread := &starlark.Thread{Name: programName}
globals, err := prog.Init(thread, nil)
if err != nil {
return err
}

globals.Freeze()

var notFound []string
for funName, requiredArgs := range requiredFunctions {
// The function is missing if its name is not found in the globals.
funv := globals[funName]
if funv == nil {
notFound = append(notFound, funName)
continue
}

// The function is missing if its name is not bound to a function.
fun, ok := funv.(*starlark.Function)
if !ok {
notFound = append(notFound, funName)
}

// Get the function arguments.
argc := fun.NumParams()
var args []string
for i := range argc {
arg, _ := fun.Param(i)
args = append(args, arg)
}

// Return an error early if the function does not have the right arguments.
match := len(args) == len(requiredArgs)
if match {
sort.Strings(args)
sort.Strings(requiredArgs)
for i := range args {
if args[i] != requiredArgs[i] {
match = false
break
}
}
}

if !match {
return fmt.Errorf("The function %q defines arguments %q (expected: %q)", funName, args, requiredArgs)
}
}

switch len(notFound) {
case 0:
return nil
case 1:
return fmt.Errorf("The function %q is required but has not been found in the scriptlet", notFound[0])
default:
return fmt.Errorf("The functions %q are required but have not been found in the scriptlet", notFound)
}
}

var programsMu sync.Mutex
var programs = make(map[string]*starlark.Program)

// set compiles a scriptlet into memory. If empty src is provided the current program is deleted.
func set(compiler func(string, string) (*starlark.Program, error), programName string, src string) error {
if src == "" {
programsMu.Lock()
delete(programs, programName)
programsMu.Unlock()
} else {
prog, err := compiler(programName, src)
if err != nil {
return err
}

programsMu.Lock()
programs[programName] = prog
programsMu.Unlock()
}

return nil
}

// program returns a precompiled scriptlet program.
func program(name string, programName string) (*starlark.Program, *starlark.Thread, error) {
programsMu.Lock()
prog, found := programs[programName]
programsMu.Unlock()
if !found {
return nil, nil, fmt.Errorf("%s scriptlet not loaded", name)
}

thread := &starlark.Thread{Name: programName}

return prog, thread, nil
}

// InstancePlacementCompile compiles the instance placement scriptlet.
func InstancePlacementCompile(name string, src string) (*starlark.Program, error) {
return compile(name, src, []string{
Expand All @@ -156,8 +37,8 @@ func InstancePlacementCompile(name string, src string) (*starlark.Program, error

// InstancePlacementValidate validates the instance placement scriptlet.
func InstancePlacementValidate(src string) error {
return validate(InstancePlacementCompile, nameInstancePlacement, src, map[string][]string{
"instance_placement": {"request", "candidate_members"},
return validate(InstancePlacementCompile, nameInstancePlacement, src, declaration{
required("instance_placement"): {"request", "candidate_members"},
})
}

Expand Down Expand Up @@ -199,8 +80,8 @@ func QEMUCompile(name string, src string) (*starlark.Program, error) {

// QEMUValidate validates the QEMU scriptlet.
func QEMUValidate(src string) error {
return validate(QEMUCompile, prefixQEMU, src, map[string][]string{
"qemu_hook": {"hook_name"},
return validate(QEMUCompile, prefixQEMU, src, declaration{
required("qemu_hook"): {"hook_name"},
})
}

Expand All @@ -226,8 +107,10 @@ func AuthorizationCompile(name string, src string) (*starlark.Program, error) {

// AuthorizationValidate validates the authorization scriptlet.
func AuthorizationValidate(src string) error {
return validate(AuthorizationCompile, nameAuthorization, src, map[string][]string{
"authorize": {"details", "object", "entitlement"},
return validate(AuthorizationCompile, nameAuthorization, src, declaration{
required("authorize"): {"details", "object", "entitlement"},
optional("get_instance_access"): {"project_name", "instance_name"},
optional("get_project_access"): {"project_name"},
})
}

Expand Down
225 changes: 225 additions & 0 deletions internal/server/scriptlet/load/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
package load

import (
"errors"
"fmt"
"slices"
"sort"
"strings"

"go.starlark.net/starlark"
"go.starlark.net/syntax"
)

// argMismatch represents mismatching arguments in a function.
type argMismatch struct {
gotten []string
expected []string
}

// scriptletFunction represents a possibly optional function in a scriptlet.
type scriptletFunction struct {
name string
optional bool
}

// declaration is a type alias to make scriptlet declaration easier.
type declaration = map[scriptletFunction][]string

// compile compiles a scriptlet.
func compile(programName string, src string, preDeclared []string) (*starlark.Program, error) {
isPreDeclared := func(name string) bool {
return slices.Contains(preDeclared, name)
}

// Parse, resolve, and compile a Starlark source file.
_, mod, err := starlark.SourceProgramOptions(syntax.LegacyFileOptions(), programName, src, isPreDeclared)
if err != nil {
return nil, err
}

return mod, nil
}

// required is a convenience wrapper declaring a required function.
func required(name string) scriptletFunction {
return scriptletFunction{name: name, optional: false}
}

// required is a convenience wrapper declaring an optional function.
func optional(name string) scriptletFunction {
return scriptletFunction{name: name, optional: true}
}

// optionalToString converts a Boolean describing optional functions to its string representation.
func optionalToString(optional bool) string {
if optional {
return "optional"
}

return "required"
}

// validateFunction validates a single Starlark function.
func validateFunction(funv starlark.Value, requiredArgs []string) (bool, bool, *argMismatch) {
// The function is missing if its name is not found in the globals.
if funv == nil {
return true, false, nil
}

// The function is actually not a function if its name is not bound to a function.
fun, ok := funv.(*starlark.Function)
if !ok {
return false, true, nil
}

// Get the function arguments.
argc := fun.NumParams()
var args []string
for i := range argc {
arg, _ := fun.Param(i)
args = append(args, arg)
}

// The function is invalid if it does not have the right arguments.
match := len(args) == len(requiredArgs)
if match {
sort.Strings(args)
sort.Strings(requiredArgs)
for i := range args {
if args[i] != requiredArgs[i] {
match = false
break
}
}
}

if !match {
return false, false, &argMismatch{gotten: args, expected: requiredArgs}
}

return false, false, nil
}

// validate validates a scriptlet by compiling it and checking the presence of required and optional functions.
func validate(compiler func(string, string) (*starlark.Program, error), programName string, src string, scriptletFunctions declaration) error {
// Try to compile the program.
prog, err := compiler(programName, src)
if err != nil {
return err
}

thread := &starlark.Thread{Name: programName}
globals, err := prog.Init(thread, nil)
if err != nil {
return err
}

globals.Freeze()

var missingFuns []string
mistypedFuns := make(map[scriptletFunction]string)
mismatchingFuns := make(map[scriptletFunction]*argMismatch)
errorsFound := false
for fun, requiredArgs := range scriptletFunctions {
funv := globals[fun.name]
missing, mistyped, mismatch := validateFunction(funv, requiredArgs)

if missing && !fun.optional || mistyped || mismatch != nil {
errorsFound = true
if missing {
missingFuns = append(missingFuns, fun.name)
} else if mistyped {
mistypedFuns[fun] = funv.Type()
} else {
mismatchingFuns[fun] = mismatch
}
}
}

// Return early if everything looks good.
if !errorsFound {
return nil
}

errorText := ""
sentences := 0

// String builder to format pretty error messages.
appendToError := func(text string) {
var link string
if sentences == 0 {
link = ""
} else if sentences == 1 {
link = "; additionally, "
} else {
link = "; finally, "
}

errorText += link
errorText += text
sentences++
}

switch len(missingFuns) {
case 0:
case 1:
appendToError(fmt.Sprintf("the function %q is required but has not been found in the scriptlet", missingFuns[0]))
default:
appendToError(fmt.Sprintf("the functions %q are required but have not been found in the scriptlet", missingFuns))
}

if len(mistypedFuns) != 0 {
var parts []string
for fun, ty := range mistypedFuns {
parts = append(parts, fmt.Sprintf("%q should define the scriptlet’s %s function of the same name (found a value of type %s instead)", fun.name, optionalToString(fun.optional), ty))
}

appendToError(strings.Join(parts, ", "))
}

if len(mismatchingFuns) != 0 {
var parts []string
for fun, args := range mismatchingFuns {
parts = append(parts, fmt.Sprintf("the %s function %q defines arguments %q (expected %q)", optionalToString(fun.optional), fun.name, args.gotten, args.expected))
}

appendToError(strings.Join(parts, ", "))
}

return errors.New(errorText)
}

// set compiles a scriptlet into memory. If empty src is provided the current program is deleted.
func set(compiler func(string, string) (*starlark.Program, error), programName string, src string) error {
if src == "" {
programsMu.Lock()
delete(programs, programName)
programsMu.Unlock()
} else {
prog, err := compiler(programName, src)
if err != nil {
return err
}

programsMu.Lock()
programs[programName] = prog
programsMu.Unlock()
}

return nil
}

// program returns a precompiled scriptlet program.
func program(name string, programName string) (*starlark.Program, *starlark.Thread, error) {
programsMu.Lock()
prog, found := programs[programName]
programsMu.Unlock()
if !found {
return nil, nil, fmt.Errorf("%s scriptlet not loaded", name)
}

thread := &starlark.Thread{Name: programName}

return prog, thread, nil
}

0 comments on commit c280784

Please sign in to comment.