Skip to content

Commit

Permalink
added shorthand exprssions any/all/count
Browse files Browse the repository at this point in the history
  • Loading branch information
joreiche committed Apr 15, 2024
1 parent 2020ccc commit 9dc72ba
Show file tree
Hide file tree
Showing 9 changed files with 564 additions and 44 deletions.
3 changes: 3 additions & 0 deletions pkg/script/common/key-words.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ const (
Then = "then"
Else = "else"

All = "all"
And = "and"
Any = "any"
Contains = "contains"
Count = "count"
Equal = "equal"
EqualOrGreater = "equal-or-greater"
EqualOrLess = "equal-or-less"
Expand Down
27 changes: 24 additions & 3 deletions pkg/script/common/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Scope struct {
Model map[string]any
Risk map[string]any
Methods map[string]Statement
iterator Value
returnValue Value
}

Expand Down Expand Up @@ -110,6 +111,20 @@ func (what *Scope) Get(name string) (Value, bool) {
return nil, false
}

func (what *Scope) GetIterator() Value {
return what.iterator
}

func (what *Scope) SetIterator(value Value) {
what.iterator = value
}

func (what *Scope) SwapIterator(value Value) Value {
var currentIterator Value
currentIterator, what.iterator = what.iterator, value
return currentIterator
}

func (what *Scope) SetReturnValue(value Value) {
what.returnValue = value
}
Expand Down Expand Up @@ -153,9 +168,15 @@ func (what *Scope) getVar(path []string) (Value, bool) {
return nil, false
}

field, ok := what.Vars[strings.ToLower(path[0])]
if !ok {
return nil, false
var field Value
if len(path[0]) == 0 {
field = what.iterator
} else {
var fieldOk bool
field, fieldOk = what.Vars[strings.ToLower(path[0])]
if !fieldOk {
return nil, false
}
}

if len(path) == 1 {
Expand Down
150 changes: 150 additions & 0 deletions pkg/script/expressions/all-expression.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package expressions

import (
"fmt"
"github.com/threagile/threagile/pkg/script/common"
)

type AllExpression struct {
literal string
in common.ValueExpression
item string
index string
expression common.BoolExpression
}

func (what *AllExpression) ParseBool(script any) (common.BoolExpression, any, error) {
what.literal = common.ToLiteral(script)

switch script.(type) {
case map[string]any:
for key, value := range script.(map[string]any) {
switch key {
case common.In:
item, errorExpression, itemError := new(ValueExpression).ParseValue(value)
if itemError != nil {
return nil, errorExpression, fmt.Errorf("failed to parse %q of any-expression: %v", key, itemError)
}

what.in = item

case common.Item:
text, ok := value.(string)
if !ok {
return nil, value, fmt.Errorf("failed to parse %q of any-expression: expected string, got %T", key, value)
}

what.item = text

case common.Index:
text, ok := value.(string)
if !ok {
return nil, value, fmt.Errorf("failed to parse %q of any-expression: expected string, got %T", key, value)
}

what.index = text

default:
if what.expression != nil {
return nil, script, fmt.Errorf("failed to parse any-expression: additional bool expression %q", key)
}

expression, errorScript, itemError := new(ExpressionList).ParseAny(map[string]any{key: value})
if itemError != nil {
return nil, errorScript, fmt.Errorf("failed to parse any-expression: %v", itemError)
}

boolExpression, ok := expression.(common.BoolExpression)
if !ok {
return nil, script, fmt.Errorf("any-expression contains non-bool expression: %v", itemError)
}

what.expression = boolExpression
}
}

default:
return nil, script, fmt.Errorf("failed to parse any-expression: expected map[string]any, got %T", script)
}

return what, nil, nil
}

func (what *AllExpression) ParseAny(script any) (common.Expression, any, error) {
return what.ParseBool(script)
}

func (what *AllExpression) EvalBool(scope *common.Scope) (bool, string, error) {
oldIterator := scope.SwapIterator(nil)
defer scope.SetIterator(oldIterator)

inValue, errorEvalLiteral, evalError := what.in.EvalAny(scope)
if evalError != nil {
return false, errorEvalLiteral, evalError
}

switch castValue := inValue.(type) {
case []any:
if what.expression == nil {
return true, "", nil
}

for index, item := range castValue {
if len(what.index) > 0 {
scope.Set(what.index, index)
}

scope.SetIterator(item)
if len(what.item) > 0 {
scope.Set(what.item, item)
}

value, errorLiteral, expressionError := what.expression.EvalBool(scope)
if expressionError != nil {
return false, errorLiteral, fmt.Errorf("error evaluating expression #%v of any-expression: %v", index+1, expressionError)
}

if !value {
return false, "", nil
}
}

case map[string]any:
if what.expression == nil {
return true, "", nil
}

for name, item := range castValue {
if len(what.index) > 0 {
scope.Set(what.index, name)
}

scope.SetIterator(item)
if len(what.item) > 0 {
scope.Set(what.item, item)
}

value, errorLiteral, expressionError := what.expression.EvalBool(scope)
if expressionError != nil {
return false, errorLiteral, fmt.Errorf("error evaluating expression %q of any-expression: %v", name, expressionError)
}

if !value {
return false, "", nil
}
}

default:
return false, what.Literal(), fmt.Errorf("failed to eval any-expression: expected iterable type, got %T", inValue)
}

return true, "", nil
}

func (what *AllExpression) EvalAny(scope *common.Scope) (any, string, error) {
return what.EvalBool(scope)
}

func (what *AllExpression) Literal() string {
return what.literal
}
150 changes: 150 additions & 0 deletions pkg/script/expressions/any-expression.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package expressions

import (
"fmt"
"github.com/threagile/threagile/pkg/script/common"
)

type AnyExpression struct {
literal string
in common.ValueExpression
item string
index string
expression common.BoolExpression
}

func (what *AnyExpression) ParseBool(script any) (common.BoolExpression, any, error) {
what.literal = common.ToLiteral(script)

switch script.(type) {
case map[string]any:
for key, value := range script.(map[string]any) {
switch key {
case common.In:
item, errorExpression, itemError := new(ValueExpression).ParseValue(value)
if itemError != nil {
return nil, errorExpression, fmt.Errorf("failed to parse %q of any-expression: %v", key, itemError)
}

what.in = item

case common.Item:
text, ok := value.(string)
if !ok {
return nil, value, fmt.Errorf("failed to parse %q of any-expression: expected string, got %T", key, value)
}

what.item = text

case common.Index:
text, ok := value.(string)
if !ok {
return nil, value, fmt.Errorf("failed to parse %q of any-expression: expected string, got %T", key, value)
}

what.index = text

default:
if what.expression != nil {
return nil, script, fmt.Errorf("failed to parse any-expression: additional bool expression %q", key)
}

expression, errorScript, itemError := new(ExpressionList).ParseAny(map[string]any{key: value})
if itemError != nil {
return nil, errorScript, fmt.Errorf("failed to parse any-expression: %v", itemError)
}

boolExpression, ok := expression.(common.BoolExpression)
if !ok {
return nil, script, fmt.Errorf("any-expression contains non-bool expression: %v", itemError)
}

what.expression = boolExpression
}
}

default:
return nil, script, fmt.Errorf("failed to parse any-expression: expected map[string]any, got %T", script)
}

return what, nil, nil
}

func (what *AnyExpression) ParseAny(script any) (common.Expression, any, error) {
return what.ParseBool(script)
}

func (what *AnyExpression) EvalBool(scope *common.Scope) (bool, string, error) {
oldIterator := scope.SwapIterator(nil)
defer scope.SetIterator(oldIterator)

inValue, errorEvalLiteral, evalError := what.in.EvalAny(scope)
if evalError != nil {
return false, errorEvalLiteral, evalError
}

switch castValue := inValue.(type) {
case []any:
if what.expression == nil {
return false, "", nil
}

for index, item := range castValue {
if len(what.index) > 0 {
scope.Set(what.index, index)
}

scope.SetIterator(item)
if len(what.item) > 0 {
scope.Set(what.item, item)
}

value, errorLiteral, expressionError := what.expression.EvalBool(scope)
if expressionError != nil {
return false, errorLiteral, fmt.Errorf("error evaluating expression #%v of any-expression: %v", index+1, expressionError)
}

if value {
return true, "", nil
}
}

case map[string]any:
if what.expression == nil {
return false, "", nil
}

for name, item := range castValue {
if len(what.index) > 0 {
scope.Set(what.index, name)
}

scope.SetIterator(item)
if len(what.item) > 0 {
scope.Set(what.item, item)
}

value, errorLiteral, expressionError := what.expression.EvalBool(scope)
if expressionError != nil {
return false, errorLiteral, fmt.Errorf("error evaluating expression %q of any-expression: %v", name, expressionError)
}

if value {
return true, "", nil
}
}

default:
return false, what.Literal(), fmt.Errorf("failed to eval any-expression: expected iterable type, got %T", inValue)
}

return false, "", nil
}

func (what *AnyExpression) EvalAny(scope *common.Scope) (any, string, error) {
return what.EvalBool(scope)
}

func (what *AnyExpression) Literal() string {
return what.literal
}
Loading

0 comments on commit 9dc72ba

Please sign in to comment.