Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor scriptlet loader and improve checks #1507

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 8 additions & 125 deletions internal/server/scriptlet/load/load.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
package load

import (
"fmt"
"slices"
"sort"
"sync"

"go.starlark.net/starlark"
"go.starlark.net/syntax"
)

// nameInstancePlacement is the name used in Starlark for the instance placement scriptlet.
Expand All @@ -19,124 +15,9 @@ const prefixQEMU = "qemu"
// nameAuthorization is the name used in Starlark for the Authorization scriptlet.
const nameAuthorization = "authorization"

// compile compiles a scriptlet.
func compile(programName string, src string, preDeclared []string) (*starlark.Program, error) {
isPreDeclared := func(name string) bool {
return slices.Contains(preDeclared, name)
}

// Parse, resolve, and compile a Starlark source file.
_, mod, err := starlark.SourceProgramOptions(syntax.LegacyFileOptions(), programName, src, isPreDeclared)
if err != nil {
return nil, err
}

return mod, nil
}

// validate validates a scriptlet by compiling it and checking the presence of required functions.
func validate(compiler func(string, string) (*starlark.Program, error), programName string, src string, requiredFunctions map[string][]string) error {
prog, err := compiler(programName, src)
if err != nil {
return err
}

thread := &starlark.Thread{Name: programName}
globals, err := prog.Init(thread, nil)
if err != nil {
return err
}

globals.Freeze()

var notFound []string
for funName, requiredArgs := range requiredFunctions {
// The function is missing if its name is not found in the globals.
funv := globals[funName]
if funv == nil {
notFound = append(notFound, funName)
continue
}

// The function is missing if its name is not bound to a function.
fun, ok := funv.(*starlark.Function)
if !ok {
notFound = append(notFound, funName)
}

// Get the function arguments.
argc := fun.NumParams()
var args []string
for i := range argc {
arg, _ := fun.Param(i)
args = append(args, arg)
}

// Return an error early if the function does not have the right arguments.
match := len(args) == len(requiredArgs)
if match {
sort.Strings(args)
sort.Strings(requiredArgs)
for i := range args {
if args[i] != requiredArgs[i] {
match = false
break
}
}
}

if !match {
return fmt.Errorf("The function %q defines arguments %q (expected: %q)", funName, args, requiredArgs)
}
}

switch len(notFound) {
case 0:
return nil
case 1:
return fmt.Errorf("The function %q is required but has not been found in the scriptlet", notFound[0])
default:
return fmt.Errorf("The functions %q are required but have not been found in the scriptlet", notFound)
}
}

var programsMu sync.Mutex
var programs = make(map[string]*starlark.Program)

// set compiles a scriptlet into memory. If empty src is provided the current program is deleted.
func set(compiler func(string, string) (*starlark.Program, error), programName string, src string) error {
if src == "" {
programsMu.Lock()
delete(programs, programName)
programsMu.Unlock()
} else {
prog, err := compiler(programName, src)
if err != nil {
return err
}

programsMu.Lock()
programs[programName] = prog
programsMu.Unlock()
}

return nil
}

// program returns a precompiled scriptlet program.
func program(name string, programName string) (*starlark.Program, *starlark.Thread, error) {
programsMu.Lock()
prog, found := programs[programName]
programsMu.Unlock()
if !found {
return nil, nil, fmt.Errorf("%s scriptlet not loaded", name)
}

thread := &starlark.Thread{Name: programName}

return prog, thread, nil
}

// InstancePlacementCompile compiles the instance placement scriptlet.
func InstancePlacementCompile(name string, src string) (*starlark.Program, error) {
return compile(name, src, []string{
Expand All @@ -156,8 +37,8 @@ func InstancePlacementCompile(name string, src string) (*starlark.Program, error

// InstancePlacementValidate validates the instance placement scriptlet.
func InstancePlacementValidate(src string) error {
return validate(InstancePlacementCompile, nameInstancePlacement, src, map[string][]string{
"instance_placement": {"request", "candidate_members"},
return validate(InstancePlacementCompile, nameInstancePlacement, src, declaration{
required("instance_placement"): {"request", "candidate_members"},
})
}

Expand Down Expand Up @@ -199,8 +80,8 @@ func QEMUCompile(name string, src string) (*starlark.Program, error) {

// QEMUValidate validates the QEMU scriptlet.
func QEMUValidate(src string) error {
return validate(QEMUCompile, prefixQEMU, src, map[string][]string{
"qemu_hook": {"hook_name"},
return validate(QEMUCompile, prefixQEMU, src, declaration{
required("qemu_hook"): {"hook_name"},
})
}

Expand All @@ -226,8 +107,10 @@ func AuthorizationCompile(name string, src string) (*starlark.Program, error) {

// AuthorizationValidate validates the authorization scriptlet.
func AuthorizationValidate(src string) error {
return validate(AuthorizationCompile, nameAuthorization, src, map[string][]string{
"authorize": {"details", "object", "entitlement"},
return validate(AuthorizationCompile, nameAuthorization, src, declaration{
required("authorize"): {"details", "object", "entitlement"},
optional("get_instance_access"): {"project_name", "instance_name"},
optional("get_project_access"): {"project_name"},
})
}

Expand Down
225 changes: 225 additions & 0 deletions internal/server/scriptlet/load/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
package load

import (
"errors"
"fmt"
"slices"
"sort"
"strings"

"go.starlark.net/starlark"
"go.starlark.net/syntax"
)

// argMismatch represents mismatching arguments in a function.
type argMismatch struct {
gotten []string
expected []string
}

// scriptletFunction represents a possibly optional function in a scriptlet.
type scriptletFunction struct {
name string
optional bool
}

// declaration is a type alias to make scriptlet declaration easier.
type declaration = map[scriptletFunction][]string

// compile compiles a scriptlet.
func compile(programName string, src string, preDeclared []string) (*starlark.Program, error) {
isPreDeclared := func(name string) bool {
return slices.Contains(preDeclared, name)
}

// Parse, resolve, and compile a Starlark source file.
_, mod, err := starlark.SourceProgramOptions(syntax.LegacyFileOptions(), programName, src, isPreDeclared)
if err != nil {
return nil, err
}

return mod, nil
}

// required is a convenience wrapper declaring a required function.
func required(name string) scriptletFunction {
return scriptletFunction{name: name, optional: false}
}

// required is a convenience wrapper declaring an optional function.
func optional(name string) scriptletFunction {
return scriptletFunction{name: name, optional: true}
}

// optionalToString converts a Boolean describing optional functions to its string representation.
func optionalToString(optional bool) string {
if optional {
return "optional"
}

return "required"
}

// validateFunction validates a single Starlark function.
func validateFunction(funv starlark.Value, requiredArgs []string) (bool, bool, *argMismatch) {
// The function is missing if its name is not found in the globals.
if funv == nil {
return true, false, nil
}

// The function is actually not a function if its name is not bound to a function.
fun, ok := funv.(*starlark.Function)
if !ok {
return false, true, nil
}

// Get the function arguments.
argc := fun.NumParams()
var args []string
for i := range argc {
arg, _ := fun.Param(i)
args = append(args, arg)
}

// The function is invalid if it does not have the right arguments.
match := len(args) == len(requiredArgs)
if match {
sort.Strings(args)
sort.Strings(requiredArgs)
for i := range args {
if args[i] != requiredArgs[i] {
match = false
break
}
}
}

if !match {
return false, false, &argMismatch{gotten: args, expected: requiredArgs}
}

return false, false, nil
}

// validate validates a scriptlet by compiling it and checking the presence of required and optional functions.
func validate(compiler func(string, string) (*starlark.Program, error), programName string, src string, scriptletFunctions declaration) error {
// Try to compile the program.
prog, err := compiler(programName, src)
if err != nil {
return err
}

thread := &starlark.Thread{Name: programName}
globals, err := prog.Init(thread, nil)
if err != nil {
return err
}

globals.Freeze()

var missingFuns []string
mistypedFuns := make(map[scriptletFunction]string)
mismatchingFuns := make(map[scriptletFunction]*argMismatch)
errorsFound := false
for fun, requiredArgs := range scriptletFunctions {
funv := globals[fun.name]
missing, mistyped, mismatch := validateFunction(funv, requiredArgs)

if missing && !fun.optional || mistyped || mismatch != nil {
errorsFound = true
if missing {
missingFuns = append(missingFuns, fun.name)
} else if mistyped {
mistypedFuns[fun] = funv.Type()
} else {
mismatchingFuns[fun] = mismatch
}
}
}

// Return early if everything looks good.
if !errorsFound {
return nil
}

errorText := ""
sentences := 0

// String builder to format pretty error messages.
appendToError := func(text string) {
var link string
if sentences == 0 {
link = ""
} else if sentences == 1 {
link = "; additionally, "
} else {
link = "; finally, "
}

errorText += link
errorText += text
sentences++
}

switch len(missingFuns) {
case 0:
case 1:
appendToError(fmt.Sprintf("the function %q is required but has not been found in the scriptlet", missingFuns[0]))
default:
appendToError(fmt.Sprintf("the functions %q are required but have not been found in the scriptlet", missingFuns))
}

if len(mistypedFuns) != 0 {
var parts []string
for fun, ty := range mistypedFuns {
parts = append(parts, fmt.Sprintf("%q should define the scriptlet’s %s function of the same name (found a value of type %s instead)", fun.name, optionalToString(fun.optional), ty))
}

appendToError(strings.Join(parts, ", "))
}

if len(mismatchingFuns) != 0 {
var parts []string
for fun, args := range mismatchingFuns {
parts = append(parts, fmt.Sprintf("the %s function %q defines arguments %q (expected %q)", optionalToString(fun.optional), fun.name, args.gotten, args.expected))
}

appendToError(strings.Join(parts, ", "))
}

return errors.New(errorText)
}

// set compiles a scriptlet into memory. If empty src is provided the current program is deleted.
func set(compiler func(string, string) (*starlark.Program, error), programName string, src string) error {
if src == "" {
programsMu.Lock()
delete(programs, programName)
programsMu.Unlock()
} else {
prog, err := compiler(programName, src)
if err != nil {
return err
}

programsMu.Lock()
programs[programName] = prog
programsMu.Unlock()
}

return nil
}

// program returns a precompiled scriptlet program.
func program(name string, programName string) (*starlark.Program, *starlark.Thread, error) {
programsMu.Lock()
prog, found := programs[programName]
programsMu.Unlock()
if !found {
return nil, nil, fmt.Errorf("%s scriptlet not loaded", name)
}

thread := &starlark.Thread{Name: programName}

return prog, thread, nil
}
Loading