From 9dc72ba7da950c9531535f7c48b5942454233326 Mon Sep 17 00:00:00 2001 From: Joerg Reichelt Date: Mon, 15 Apr 2024 14:11:42 -0700 Subject: [PATCH] added shorthand exprssions `any`/`all`/`count` --- pkg/script/common/key-words.go | 3 + pkg/script/common/scope.go | 27 +++- pkg/script/expressions/all-expression.go | 150 ++++++++++++++++++++ pkg/script/expressions/any-expression.go | 150 ++++++++++++++++++++ pkg/script/expressions/count-expression.go | 155 +++++++++++++++++++++ pkg/script/expressions/expression-list.go | 9 ++ pkg/script/statements/if-statement.go | 91 ++++++------ pkg/script/statements/loop-statement.go | 5 + test/risk-category.yaml | 18 ++- 9 files changed, 564 insertions(+), 44 deletions(-) create mode 100644 pkg/script/expressions/all-expression.go create mode 100644 pkg/script/expressions/any-expression.go create mode 100644 pkg/script/expressions/count-expression.go diff --git a/pkg/script/common/key-words.go b/pkg/script/common/key-words.go index 9a3fb1e8..c437d0ac 100644 --- a/pkg/script/common/key-words.go +++ b/pkg/script/common/key-words.go @@ -24,8 +24,11 @@ const ( Then = "then" Else = "else" + All = "all" And = "and" + Any = "any" Contains = "contains" + Count = "count" Equal = "equal" EqualOrGreater = "equal-or-greater" EqualOrLess = "equal-or-less" diff --git a/pkg/script/common/scope.go b/pkg/script/common/scope.go index 88bcb64f..c6b54e3a 100644 --- a/pkg/script/common/scope.go +++ b/pkg/script/common/scope.go @@ -14,6 +14,7 @@ type Scope struct { Model map[string]any Risk map[string]any Methods map[string]Statement + iterator Value returnValue Value } @@ -110,6 +111,20 @@ func (what *Scope) Get(name string) (Value, bool) { return nil, false } +func (what *Scope) GetIterator() Value { + return what.iterator +} + +func (what *Scope) SetIterator(value Value) { + what.iterator = value +} + +func (what *Scope) SwapIterator(value Value) Value { + var currentIterator Value + currentIterator, what.iterator = what.iterator, value + return currentIterator +} + func (what *Scope) SetReturnValue(value Value) { what.returnValue = value } @@ -153,9 +168,15 @@ func (what *Scope) getVar(path []string) (Value, bool) { return nil, false } - field, ok := what.Vars[strings.ToLower(path[0])] - if !ok { - return nil, false + var field Value + if len(path[0]) == 0 { + field = what.iterator + } else { + var fieldOk bool + field, fieldOk = what.Vars[strings.ToLower(path[0])] + if !fieldOk { + return nil, false + } } if len(path) == 1 { diff --git a/pkg/script/expressions/all-expression.go b/pkg/script/expressions/all-expression.go new file mode 100644 index 00000000..2e2db563 --- /dev/null +++ b/pkg/script/expressions/all-expression.go @@ -0,0 +1,150 @@ +package expressions + +import ( + "fmt" + "github.com/threagile/threagile/pkg/script/common" +) + +type AllExpression struct { + literal string + in common.ValueExpression + item string + index string + expression common.BoolExpression +} + +func (what *AllExpression) ParseBool(script any) (common.BoolExpression, any, error) { + what.literal = common.ToLiteral(script) + + switch script.(type) { + case map[string]any: + for key, value := range script.(map[string]any) { + switch key { + case common.In: + item, errorExpression, itemError := new(ValueExpression).ParseValue(value) + if itemError != nil { + return nil, errorExpression, fmt.Errorf("failed to parse %q of any-expression: %v", key, itemError) + } + + what.in = item + + case common.Item: + text, ok := value.(string) + if !ok { + return nil, value, fmt.Errorf("failed to parse %q of any-expression: expected string, got %T", key, value) + } + + what.item = text + + case common.Index: + text, ok := value.(string) + if !ok { + return nil, value, fmt.Errorf("failed to parse %q of any-expression: expected string, got %T", key, value) + } + + what.index = text + + default: + if what.expression != nil { + return nil, script, fmt.Errorf("failed to parse any-expression: additional bool expression %q", key) + } + + expression, errorScript, itemError := new(ExpressionList).ParseAny(map[string]any{key: value}) + if itemError != nil { + return nil, errorScript, fmt.Errorf("failed to parse any-expression: %v", itemError) + } + + boolExpression, ok := expression.(common.BoolExpression) + if !ok { + return nil, script, fmt.Errorf("any-expression contains non-bool expression: %v", itemError) + } + + what.expression = boolExpression + } + } + + default: + return nil, script, fmt.Errorf("failed to parse any-expression: expected map[string]any, got %T", script) + } + + return what, nil, nil +} + +func (what *AllExpression) ParseAny(script any) (common.Expression, any, error) { + return what.ParseBool(script) +} + +func (what *AllExpression) EvalBool(scope *common.Scope) (bool, string, error) { + oldIterator := scope.SwapIterator(nil) + defer scope.SetIterator(oldIterator) + + inValue, errorEvalLiteral, evalError := what.in.EvalAny(scope) + if evalError != nil { + return false, errorEvalLiteral, evalError + } + + switch castValue := inValue.(type) { + case []any: + if what.expression == nil { + return true, "", nil + } + + for index, item := range castValue { + if len(what.index) > 0 { + scope.Set(what.index, index) + } + + scope.SetIterator(item) + if len(what.item) > 0 { + scope.Set(what.item, item) + } + + value, errorLiteral, expressionError := what.expression.EvalBool(scope) + if expressionError != nil { + return false, errorLiteral, fmt.Errorf("error evaluating expression #%v of any-expression: %v", index+1, expressionError) + } + + if !value { + return false, "", nil + } + } + + case map[string]any: + if what.expression == nil { + return true, "", nil + } + + for name, item := range castValue { + if len(what.index) > 0 { + scope.Set(what.index, name) + } + + scope.SetIterator(item) + if len(what.item) > 0 { + scope.Set(what.item, item) + } + + value, errorLiteral, expressionError := what.expression.EvalBool(scope) + if expressionError != nil { + return false, errorLiteral, fmt.Errorf("error evaluating expression %q of any-expression: %v", name, expressionError) + } + + if !value { + return false, "", nil + } + } + + default: + return false, what.Literal(), fmt.Errorf("failed to eval any-expression: expected iterable type, got %T", inValue) + } + + return true, "", nil +} + +func (what *AllExpression) EvalAny(scope *common.Scope) (any, string, error) { + return what.EvalBool(scope) +} + +func (what *AllExpression) Literal() string { + return what.literal +} diff --git a/pkg/script/expressions/any-expression.go b/pkg/script/expressions/any-expression.go new file mode 100644 index 00000000..51c15b89 --- /dev/null +++ b/pkg/script/expressions/any-expression.go @@ -0,0 +1,150 @@ +package expressions + +import ( + "fmt" + "github.com/threagile/threagile/pkg/script/common" +) + +type AnyExpression struct { + literal string + in common.ValueExpression + item string + index string + expression common.BoolExpression +} + +func (what *AnyExpression) ParseBool(script any) (common.BoolExpression, any, error) { + what.literal = common.ToLiteral(script) + + switch script.(type) { + case map[string]any: + for key, value := range script.(map[string]any) { + switch key { + case common.In: + item, errorExpression, itemError := new(ValueExpression).ParseValue(value) + if itemError != nil { + return nil, errorExpression, fmt.Errorf("failed to parse %q of any-expression: %v", key, itemError) + } + + what.in = item + + case common.Item: + text, ok := value.(string) + if !ok { + return nil, value, fmt.Errorf("failed to parse %q of any-expression: expected string, got %T", key, value) + } + + what.item = text + + case common.Index: + text, ok := value.(string) + if !ok { + return nil, value, fmt.Errorf("failed to parse %q of any-expression: expected string, got %T", key, value) + } + + what.index = text + + default: + if what.expression != nil { + return nil, script, fmt.Errorf("failed to parse any-expression: additional bool expression %q", key) + } + + expression, errorScript, itemError := new(ExpressionList).ParseAny(map[string]any{key: value}) + if itemError != nil { + return nil, errorScript, fmt.Errorf("failed to parse any-expression: %v", itemError) + } + + boolExpression, ok := expression.(common.BoolExpression) + if !ok { + return nil, script, fmt.Errorf("any-expression contains non-bool expression: %v", itemError) + } + + what.expression = boolExpression + } + } + + default: + return nil, script, fmt.Errorf("failed to parse any-expression: expected map[string]any, got %T", script) + } + + return what, nil, nil +} + +func (what *AnyExpression) ParseAny(script any) (common.Expression, any, error) { + return what.ParseBool(script) +} + +func (what *AnyExpression) EvalBool(scope *common.Scope) (bool, string, error) { + oldIterator := scope.SwapIterator(nil) + defer scope.SetIterator(oldIterator) + + inValue, errorEvalLiteral, evalError := what.in.EvalAny(scope) + if evalError != nil { + return false, errorEvalLiteral, evalError + } + + switch castValue := inValue.(type) { + case []any: + if what.expression == nil { + return false, "", nil + } + + for index, item := range castValue { + if len(what.index) > 0 { + scope.Set(what.index, index) + } + + scope.SetIterator(item) + if len(what.item) > 0 { + scope.Set(what.item, item) + } + + value, errorLiteral, expressionError := what.expression.EvalBool(scope) + if expressionError != nil { + return false, errorLiteral, fmt.Errorf("error evaluating expression #%v of any-expression: %v", index+1, expressionError) + } + + if value { + return true, "", nil + } + } + + case map[string]any: + if what.expression == nil { + return false, "", nil + } + + for name, item := range castValue { + if len(what.index) > 0 { + scope.Set(what.index, name) + } + + scope.SetIterator(item) + if len(what.item) > 0 { + scope.Set(what.item, item) + } + + value, errorLiteral, expressionError := what.expression.EvalBool(scope) + if expressionError != nil { + return false, errorLiteral, fmt.Errorf("error evaluating expression %q of any-expression: %v", name, expressionError) + } + + if value { + return true, "", nil + } + } + + default: + return false, what.Literal(), fmt.Errorf("failed to eval any-expression: expected iterable type, got %T", inValue) + } + + return false, "", nil +} + +func (what *AnyExpression) EvalAny(scope *common.Scope) (any, string, error) { + return what.EvalBool(scope) +} + +func (what *AnyExpression) Literal() string { + return what.literal +} diff --git a/pkg/script/expressions/count-expression.go b/pkg/script/expressions/count-expression.go new file mode 100644 index 00000000..7d36cb05 --- /dev/null +++ b/pkg/script/expressions/count-expression.go @@ -0,0 +1,155 @@ +package expressions + +import ( + "fmt" + "github.com/shopspring/decimal" + "github.com/threagile/threagile/pkg/script/common" +) + +type CountExpression struct { + literal string + in common.ValueExpression + item string + index string + expression common.BoolExpression +} + +func (what *CountExpression) ParseDecimal(script any) (common.DecimalExpression, any, error) { + what.literal = common.ToLiteral(script) + + switch script.(type) { + case map[string]any: + for key, value := range script.(map[string]any) { + switch key { + case common.In: + item, errorExpression, itemError := new(ValueExpression).ParseValue(value) + if itemError != nil { + return nil, errorExpression, fmt.Errorf("failed to parse %q of any-expression: %v", key, itemError) + } + + what.in = item + + case common.Item: + text, ok := value.(string) + if !ok { + return nil, value, fmt.Errorf("failed to parse %q of any-expression: expected string, got %T", key, value) + } + + what.item = text + + case common.Index: + text, ok := value.(string) + if !ok { + return nil, value, fmt.Errorf("failed to parse %q of any-expression: expected string, got %T", key, value) + } + + what.index = text + + default: + if what.expression != nil { + return nil, script, fmt.Errorf("failed to parse any-expression: additional bool expression %q", key) + } + + expression, errorScript, itemError := new(ExpressionList).ParseAny(map[string]any{key: value}) + if itemError != nil { + return nil, errorScript, fmt.Errorf("failed to parse any-expression: %v", itemError) + } + + boolExpression, ok := expression.(common.BoolExpression) + if !ok { + return nil, script, fmt.Errorf("any-expression contains non-bool expression: %v", itemError) + } + + what.expression = boolExpression + } + } + + default: + return nil, script, fmt.Errorf("failed to parse any-expression: expected map[string]any, got %T", script) + } + + return what, nil, nil +} + +func (what *CountExpression) ParseAny(script any) (common.Expression, any, error) { + return what.ParseDecimal(script) +} + +func (what *CountExpression) EvalDecimal(scope *common.Scope) (decimal.Decimal, string, error) { + oldIterator := scope.SwapIterator(nil) + defer scope.SetIterator(oldIterator) + + inValue, errorEvalLiteral, evalError := what.in.EvalAny(scope) + if evalError != nil { + return decimal.NewFromInt(0), errorEvalLiteral, evalError + } + + switch castValue := inValue.(type) { + case []any: + if what.expression == nil { + return decimal.NewFromInt(int64(len(castValue))), "", nil + } + + var count int64 = 0 + for index, item := range castValue { + if len(what.index) > 0 { + scope.Set(what.index, index) + } + + scope.SetIterator(item) + if len(what.item) > 0 { + scope.Set(what.item, item) + } + + value, errorLiteral, expressionError := what.expression.EvalBool(scope) + if expressionError != nil { + return decimal.NewFromInt(0), errorLiteral, fmt.Errorf("error evaluating expression #%v of any-expression: %v", index+1, expressionError) + } + + if value { + count++ + } + } + + return decimal.NewFromInt(count), "", nil + + case map[string]any: + if what.expression == nil { + return decimal.NewFromInt(int64(len(castValue))), "", nil + } + + var count int64 = 0 + for name, item := range castValue { + if len(what.index) > 0 { + scope.Set(what.index, name) + } + + scope.SetIterator(item) + if len(what.item) > 0 { + scope.Set(what.item, item) + } + + value, errorLiteral, expressionError := what.expression.EvalBool(scope) + if expressionError != nil { + return decimal.NewFromInt(0), errorLiteral, fmt.Errorf("error evaluating expression %q of any-expression: %v", name, expressionError) + } + + if value { + count++ + } + } + + return decimal.NewFromInt(count), "", nil + + default: + return decimal.NewFromInt(0), what.Literal(), fmt.Errorf("failed to eval any-expression: expected iterable type, got %T", inValue) + } +} + +func (what *CountExpression) EvalAny(scope *common.Scope) (any, string, error) { + return what.EvalDecimal(scope) +} + +func (what *CountExpression) Literal() string { + return what.literal +} diff --git a/pkg/script/expressions/expression-list.go b/pkg/script/expressions/expression-list.go index 77634a1c..cf5c5fcf 100644 --- a/pkg/script/expressions/expression-list.go +++ b/pkg/script/expressions/expression-list.go @@ -13,12 +13,21 @@ type ExpressionList struct { func (what *ExpressionList) ParseExpression(script map[string]any) (common.Expression, any, error) { for key, value := range script { switch key { + case common.All: + return new(AllExpression).ParseBool(value) + + case common.Any: + return new(AnyExpression).ParseBool(value) + case common.And: return new(AndExpression).ParseBool(value) case common.Contains: return new(ContainsExpression).ParseBool(value) + case common.Count: + return new(CountExpression).ParseDecimal(value) + case common.Equal: return new(EqualExpression).ParseBool(value) diff --git a/pkg/script/statements/if-statement.go b/pkg/script/statements/if-statement.go index 8eb881df..9fc68746 100644 --- a/pkg/script/statements/if-statement.go +++ b/pkg/script/statements/if-statement.go @@ -16,46 +16,20 @@ type IfStatement struct { func (what *IfStatement) Parse(script any) (common.Statement, any, error) { what.literal = common.ToLiteral(script) - switch script.(type) { + switch castScript := script.(type) { + case map[any]any: + for key, value := range castScript { + statement, errorScript, parseError := what.parse(key, value, script) + if parseError != nil { + return statement, errorScript, parseError + } + } + case map[string]any: - for key, value := range script.(map[string]any) { - switch key { - case common.Then: - item, errorScript, itemError := new(StatementList).Parse(value) - if itemError != nil { - return nil, errorScript, fmt.Errorf("failed to parse %q of if-statement: %v", key, itemError) - } - - what.yesPath = item - - case common.Else: - item, errorScript, itemError := new(StatementList).Parse(value) - if itemError != nil { - return nil, errorScript, fmt.Errorf("failed to parse %q of if-statement: %v", key, itemError) - } - - what.noPath = item - - default: - if what.expression != nil { - return nil, script, fmt.Errorf("if-statement has multiple expressions") - } - - expression := map[string]any{ - key: value, - } - - item, errorScript, itemError := new(expressions.ExpressionList).ParseExpression(expression) - if itemError != nil { - return nil, errorScript, fmt.Errorf("failed to parse expression of if-statement: %v", itemError) - } - - boolItem, ok := item.(common.BoolExpression) - if !ok { - return nil, script, fmt.Errorf("expression of if-statement is not a bool expression: %v", itemError) - } - - what.expression = boolItem + for key, value := range castScript { + statement, errorScript, parseError := what.parse(key, value, script) + if parseError != nil { + return statement, errorScript, parseError } } @@ -66,6 +40,45 @@ func (what *IfStatement) Parse(script any) (common.Statement, any, error) { return what, nil, nil } +func (what *IfStatement) parse(key any, value any, script any) (common.Statement, any, error) { + switch key { + case common.Then: + item, errorScript, itemError := new(StatementList).Parse(value) + if itemError != nil { + return nil, errorScript, fmt.Errorf("failed to parse %q of if-statement: %v", key, itemError) + } + + what.yesPath = item + + case common.Else: + item, errorScript, itemError := new(StatementList).Parse(value) + if itemError != nil { + return nil, errorScript, fmt.Errorf("failed to parse %q of if-statement: %v", key, itemError) + } + + what.noPath = item + + default: + if what.expression != nil { + return nil, script, fmt.Errorf("if-statement has multiple expressions") + } + + item, errorScript, itemError := new(expressions.ExpressionList).ParseExpression(map[string]any{fmt.Sprintf("%v", key): value}) + if itemError != nil { + return nil, errorScript, fmt.Errorf("failed to parse expression of if-statement: %v", itemError) + } + + boolItem, ok := item.(common.BoolExpression) + if !ok { + return nil, script, fmt.Errorf("expression of if-statement is not a bool expression: %v", itemError) + } + + what.expression = boolItem + } + + return what, nil, nil +} + func (what *IfStatement) Run(scope *common.Scope) (string, error) { if what.expression == nil { return "", nil diff --git a/pkg/script/statements/loop-statement.go b/pkg/script/statements/loop-statement.go index f37d1615..c1769ad8 100644 --- a/pkg/script/statements/loop-statement.go +++ b/pkg/script/statements/loop-statement.go @@ -66,6 +66,9 @@ func (what *LoopStatement) Parse(script any) (common.Statement, any, error) { } func (what *LoopStatement) Run(scope *common.Scope) (string, error) { + oldIterator := scope.SwapIterator(nil) + defer scope.SetIterator(oldIterator) + value, errorEvalLiteral, evalError := what.in.EvalAny(scope) if evalError != nil { return errorEvalLiteral, evalError @@ -78,6 +81,7 @@ func (what *LoopStatement) Run(scope *common.Scope) (string, error) { scope.Set(what.index, index) } + scope.SetIterator(item) if len(what.item) > 0 { scope.Set(what.item, item) } @@ -94,6 +98,7 @@ func (what *LoopStatement) Run(scope *common.Scope) (string, error) { scope.Set(what.index, name) } + scope.SetIterator(item) if len(what.item) > 0 { scope.Set(what.item, item) } diff --git a/test/risk-category.yaml b/test/risk-category.yaml index 438a9a1c..8840b587 100644 --- a/test/risk-category.yaml +++ b/test/risk-category.yaml @@ -44,11 +44,11 @@ script: - "{tech_asset.id}" most_relevant_data_asset: "{tech_asset.id}" - match: + match-old: parameter: tech_asset do: - if: - "true": "{tech_asset.out_of_scope}" + true: "{tech_asset.out_of_scope}" then: return: false - loop: @@ -62,6 +62,20 @@ script: then: return: true + match: + parameter: tech_asset + do: + - if: + and: + - false: "{tech_asset.out_of_scope}" + - any: + in: "{tech_asset.technologies}" + or: + - true: "{.attributes.sourcecode-repository}" + - true: "{.attributes.artifact-registry}" + then: + return: true + utils: get_title: parameters: