diff --git a/go.mod b/go.mod index 81e677a27a..23368dd02b 100644 --- a/go.mod +++ b/go.mod @@ -29,14 +29,13 @@ require ( github.com/docker/go-connections v0.4.0 github.com/drone/envsubst/v2 v2.0.0-20210730161058-179042472c46 github.com/fatih/color v1.15.0 - github.com/fatih/structs v1.1.0 github.com/fortytw2/leaktest v1.3.0 github.com/fsnotify/fsnotify v1.6.0 github.com/github/smimesign v0.2.0 github.com/go-git/go-git/v5 v5.11.0 github.com/go-kit/log v0.2.1 github.com/go-logfmt/logfmt v0.6.0 - github.com/go-logr/logr v1.4.1 + github.com/go-logr/logr v1.4.1 // indirect github.com/go-sourcemap/sourcemap v2.1.3+incompatible github.com/go-sql-driver/mysql v1.7.1 github.com/gogo/protobuf v1.3.2 @@ -45,7 +44,6 @@ require ( github.com/google/cadvisor v0.47.0 github.com/google/dnsmasq_exporter v0.2.1-0.20230620100026-44b14480804a github.com/google/go-cmp v0.6.0 - github.com/google/go-jsonnet v0.18.0 github.com/google/pprof v0.0.0-20240117000934-35fc243c5815 github.com/google/renameio/v2 v2.0.0 github.com/google/uuid v1.4.0 @@ -162,7 +160,6 @@ require ( github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.8.4 github.com/testcontainers/testcontainers-go v0.25.0 - github.com/testcontainers/testcontainers-go/modules/k3s v0.0.0-20230615142642-c175df34bd1d github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/vincent-petithory/dataurl v1.0.0 github.com/webdevops/azure-metrics-exporter v0.0.0-20230717202958-8701afc2b013 @@ -226,7 +223,7 @@ require ( gopkg.in/yaml.v3 v3.0.1 gotest.tools v2.2.0+incompatible k8s.io/api v0.28.3 - k8s.io/apiextensions-apiserver v0.28.0 + k8s.io/apiextensions-apiserver v0.28.0 // indirect k8s.io/client-go v0.28.3 k8s.io/component-base v0.28.1 k8s.io/klog/v2 v2.100.1 @@ -766,3 +763,9 @@ exclude ( ) replace github.com/github/smimesign => github.com/grafana/smimesign v0.2.1-0.20220408144937-2a5adf3481d3 + +// Submodules. +// TODO(rfratto): Change all imports of github.com/grafana/river in favor of +// importing github.com/grafana/alloy/syntax and change module and package +// names to remove references of "river". +replace github.com/grafana/river => ./syntax diff --git a/go.sum b/go.sum index 95dd6c93ba..93cd5ab33b 100644 --- a/go.sum +++ b/go.sum @@ -666,12 +666,10 @@ github.com/fatih/camelcase v1.0.0 h1:hxNvNX/xYBp0ovncs8WyWZrOrpBNub/JfaMvbURyft8 github.com/fatih/camelcase v1.0.0/go.mod h1:yN2Sb0lFhZJUdVvtELVWefmrXpuZESvPmqwoZc+/fpc= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= -github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= github.com/fatih/structs v0.0.0-20180123065059-ebf56d35bba7/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= -github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= @@ -957,8 +955,6 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-github/v32 v32.1.0/go.mod h1:rIEpZD9CTDQwDK9GDrtMTycQNA4JU3qBsCizh3q2WCI= -github.com/google/go-jsonnet v0.18.0 h1:/6pTy6g+Jh1a1I2UMoAODkqELFiVIdOxbNwv0DDzoOg= -github.com/google/go-jsonnet v0.18.0/go.mod h1:C3fTzyVJDslXdiTqw/bTFk7vSGyCtH3MGRbDfvEwGd0= github.com/google/go-querystring v0.0.0-20170111101155-53e6ce116135/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= @@ -1084,8 +1080,6 @@ github.com/grafana/pyroscope/ebpf v0.4.3 h1:gPfm2FKabdycRfFIej/s0awSzsbAaoSefaeh github.com/grafana/pyroscope/ebpf v0.4.3/go.mod h1:Iv66aj9WsDWR8bGMPQzCQPCgVgCru0KizGrbcR3YmLk= github.com/grafana/regexp v0.0.0-20221123153739-15dc172cd2db h1:7aN5cccjIqCLTzedH7MZzRZt5/lsAHch6Z3L2ZGn5FA= github.com/grafana/regexp v0.0.0-20221123153739-15dc172cd2db/go.mod h1:M5qHK+eWfAv8VR/265dIuEpL3fNfeC21tXXp9itM24A= -github.com/grafana/river v0.3.1-0.20240123144725-960753160cd1 h1:mCOKdWkLv8n9X0ORWrPR+W/zLOAa1o6iM+Dfy0ofQUs= -github.com/grafana/river v0.3.1-0.20240123144725-960753160cd1/go.mod h1:tAiNX2zt3HUsNyPNUDSvE6AgQ4+kqJvljBI+ACppMtM= github.com/grafana/smimesign v0.2.1-0.20220408144937-2a5adf3481d3 h1:UPkAxuhlAcRmJT3/qd34OMTl+ZU7BLLfOO2+NXBlJpY= github.com/grafana/smimesign v0.2.1-0.20220408144937-2a5adf3481d3/go.mod h1:iZiiwNT4HbtGRVqCQu7uJPEZCuEE5sfSSttcnePkDl4= github.com/grafana/snowflake-prometheus-exporter v0.0.0-20221213150626-862cad8e9538 h1:tkT0yha3JzB5S5VNjfY4lT0cJAe20pU8XGt3Nuq73rM= @@ -1529,7 +1523,6 @@ github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaO github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -2074,7 +2067,6 @@ github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtr github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY= github.com/sercand/kuberesolver/v5 v5.1.1 h1:CYH+d67G0sGBj7q5wLK61yzqJJ8gLLC8aeprPTHb6yY= github.com/sercand/kuberesolver/v5 v5.1.1/go.mod h1:Fs1KbKhVRnB2aDWN12NjKCB+RgYMWZJ294T3BtmVCpQ= -github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/shirou/gopsutil v0.0.0-20181107111621-48177ef5f880/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= @@ -2196,8 +2188,6 @@ github.com/tencentcloud/tencentcloud-sdk-go v1.0.162/go.mod h1:asUz5BPXxgoPGaRgZ github.com/tent/http-link-go v0.0.0-20130702225549-ac974c61c2f9/go.mod h1:RHkNRtSLfOK7qBTHaeSX1D6BNpI3qw7NTxsmNr4RvN8= github.com/testcontainers/testcontainers-go v0.25.0 h1:erH6cQjsaJrH+rJDU9qIf89KFdhK0Bft0aEZHlYC3Vs= github.com/testcontainers/testcontainers-go v0.25.0/go.mod h1:4sC9SiJyzD1XFi59q8umTQYWxnkweEc5OjVtTUlJzqQ= -github.com/testcontainers/testcontainers-go/modules/k3s v0.0.0-20230615142642-c175df34bd1d h1:KyYCHo9iBoQYw5AzcozD/77uNbFlRjTmMTA7QjSxHOQ= -github.com/testcontainers/testcontainers-go/modules/k3s v0.0.0-20230615142642-c175df34bd1d/go.mod h1:Pa91ahCbzRB6d9FBi6UAjurTEm7WmyBVeuklLkwAKKs= github.com/tg123/go-htpasswd v1.2.1 h1:i4wfsX1KvvkyoMiHZzjS0VzbAPWfxzI8INcZAKtutoU= github.com/tg123/go-htpasswd v1.2.1/go.mod h1:erHp1B86KXdwQf1X5ZrLb7erXZnWueEQezb2dql4q58= github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= diff --git a/syntax/ast/ast.go b/syntax/ast/ast.go new file mode 100644 index 0000000000..992ee0c71a --- /dev/null +++ b/syntax/ast/ast.go @@ -0,0 +1,328 @@ +// Package ast exposes AST elements used by River. +// +// The various interfaces exposed by ast are all closed; only types within this +// package can satisfy an AST interface. +package ast + +import ( + "fmt" + "reflect" + "strings" + + "github.com/grafana/river/token" +) + +// Node represents any node in the AST. +type Node interface { + astNode() +} + +// Stmt is a type of statement within the body of a file or block. +type Stmt interface { + Node + astStmt() +} + +// Expr is an expression within the AST. +type Expr interface { + Node + astExpr() +} + +// File is a parsed file. +type File struct { + Name string // Filename provided to parser + Body Body // Content of File + Comments []CommentGroup // List of all comments in the File +} + +// Body is a list of statements. +type Body []Stmt + +// A CommentGroup represents a sequence of comments that are not separated by +// any empty lines or other non-comment tokens. +type CommentGroup []*Comment + +// A Comment represents a single line or block comment. +// +// The Text field contains the comment text without any carriage returns (\r) +// that may have been present in the source. Since carriage returns get +// removed, EndPos will not be accurate for any comment which contained +// carriage returns. +type Comment struct { + StartPos token.Pos // Starting position of comment + // Text of the comment. Text will not contain '\n' for line comments. + Text string +} + +// AttributeStmt is a key-value pair being set in a Body or BlockStmt. +type AttributeStmt struct { + Name *Ident + Value Expr +} + +// BlockStmt declares a block. +type BlockStmt struct { + Name []string + NamePos token.Pos + Label string + LabelPos token.Pos + Body Body + + LCurlyPos, RCurlyPos token.Pos +} + +// Ident holds an identifier with its position. +type Ident struct { + Name string + NamePos token.Pos +} + +// IdentifierExpr refers to a named value. +type IdentifierExpr struct { + Ident *Ident +} + +// LiteralExpr is a constant value of a specific token kind. +type LiteralExpr struct { + Kind token.Token + ValuePos token.Pos + + // Value holds the unparsed literal value. For example, if Kind == + // token.STRING, then Value would be wrapped in the original quotes (e.g., + // `"foobar"`). + Value string +} + +// ArrayExpr is an array of values. +type ArrayExpr struct { + Elements []Expr + LBrackPos, RBrackPos token.Pos +} + +// ObjectExpr declares an object of key-value pairs. +type ObjectExpr struct { + Fields []*ObjectField + LCurlyPos, RCurlyPos token.Pos +} + +// ObjectField defines an individual key-value pair within an object. +// ObjectField does not implement Node. +type ObjectField struct { + Name *Ident + Quoted bool // True if the name was wrapped in quotes + Value Expr +} + +// AccessExpr accesses a field in an object value by name. +type AccessExpr struct { + Value Expr + Name *Ident +} + +// IndexExpr accesses an index in an array value. +type IndexExpr struct { + Value, Index Expr + LBrackPos, RBrackPos token.Pos +} + +// CallExpr invokes a function value with a set of arguments. +type CallExpr struct { + Value Expr + Args []Expr + + LParenPos, RParenPos token.Pos +} + +// UnaryExpr performs a unary operation on a single value. +type UnaryExpr struct { + Kind token.Token + KindPos token.Pos + Value Expr +} + +// BinaryExpr performs a binary operation against two values. +type BinaryExpr struct { + Kind token.Token + KindPos token.Pos + Left, Right Expr +} + +// ParenExpr represents an expression wrapped in parentheses. +type ParenExpr struct { + Inner Expr + LParenPos, RParenPos token.Pos +} + +// Type assertions + +var ( + _ Node = (*File)(nil) + _ Node = (*Body)(nil) + _ Node = (*AttributeStmt)(nil) + _ Node = (*BlockStmt)(nil) + _ Node = (*Ident)(nil) + _ Node = (*IdentifierExpr)(nil) + _ Node = (*LiteralExpr)(nil) + _ Node = (*ArrayExpr)(nil) + _ Node = (*ObjectExpr)(nil) + _ Node = (*AccessExpr)(nil) + _ Node = (*IndexExpr)(nil) + _ Node = (*CallExpr)(nil) + _ Node = (*UnaryExpr)(nil) + _ Node = (*BinaryExpr)(nil) + _ Node = (*ParenExpr)(nil) + + _ Stmt = (*AttributeStmt)(nil) + _ Stmt = (*BlockStmt)(nil) + + _ Expr = (*IdentifierExpr)(nil) + _ Expr = (*LiteralExpr)(nil) + _ Expr = (*ArrayExpr)(nil) + _ Expr = (*ObjectExpr)(nil) + _ Expr = (*AccessExpr)(nil) + _ Expr = (*IndexExpr)(nil) + _ Expr = (*CallExpr)(nil) + _ Expr = (*UnaryExpr)(nil) + _ Expr = (*BinaryExpr)(nil) + _ Expr = (*ParenExpr)(nil) +) + +func (n *File) astNode() {} +func (n Body) astNode() {} +func (n CommentGroup) astNode() {} +func (n *Comment) astNode() {} +func (n *AttributeStmt) astNode() {} +func (n *BlockStmt) astNode() {} +func (n *Ident) astNode() {} +func (n *IdentifierExpr) astNode() {} +func (n *LiteralExpr) astNode() {} +func (n *ArrayExpr) astNode() {} +func (n *ObjectExpr) astNode() {} +func (n *AccessExpr) astNode() {} +func (n *IndexExpr) astNode() {} +func (n *CallExpr) astNode() {} +func (n *UnaryExpr) astNode() {} +func (n *BinaryExpr) astNode() {} +func (n *ParenExpr) astNode() {} + +func (n *AttributeStmt) astStmt() {} +func (n *BlockStmt) astStmt() {} + +func (n *IdentifierExpr) astExpr() {} +func (n *LiteralExpr) astExpr() {} +func (n *ArrayExpr) astExpr() {} +func (n *ObjectExpr) astExpr() {} +func (n *AccessExpr) astExpr() {} +func (n *IndexExpr) astExpr() {} +func (n *CallExpr) astExpr() {} +func (n *UnaryExpr) astExpr() {} +func (n *BinaryExpr) astExpr() {} +func (n *ParenExpr) astExpr() {} + +// StartPos returns the position of the first character belonging to a Node. +func StartPos(n Node) token.Pos { + if n == nil || reflect.ValueOf(n).IsZero() { + return token.NoPos + } + switch n := n.(type) { + case *File: + return StartPos(n.Body) + case Body: + if len(n) == 0 { + return token.NoPos + } + return StartPos(n[0]) + case CommentGroup: + if len(n) == 0 { + return token.NoPos + } + return StartPos(n[0]) + case *Comment: + return n.StartPos + case *AttributeStmt: + return StartPos(n.Name) + case *BlockStmt: + return n.NamePos + case *Ident: + return n.NamePos + case *IdentifierExpr: + return StartPos(n.Ident) + case *LiteralExpr: + return n.ValuePos + case *ArrayExpr: + return n.LBrackPos + case *ObjectExpr: + return n.LCurlyPos + case *AccessExpr: + return StartPos(n.Value) + case *IndexExpr: + return StartPos(n.Value) + case *CallExpr: + return StartPos(n.Value) + case *UnaryExpr: + return n.KindPos + case *BinaryExpr: + return StartPos(n.Left) + case *ParenExpr: + return n.LParenPos + default: + panic(fmt.Sprintf("Unhandled Node type %T", n)) + } +} + +// EndPos returns the position of the final character in a Node. +func EndPos(n Node) token.Pos { + if n == nil || reflect.ValueOf(n).IsZero() { + return token.NoPos + } + switch n := n.(type) { + case *File: + return EndPos(n.Body) + case Body: + if len(n) == 0 { + return token.NoPos + } + return EndPos(n[len(n)-1]) + case CommentGroup: + if len(n) == 0 { + return token.NoPos + } + return EndPos(n[len(n)-1]) + case *Comment: + return n.StartPos.Add(len(n.Text) - 1) + case *AttributeStmt: + return EndPos(n.Value) + case *BlockStmt: + return n.RCurlyPos + case *Ident: + return n.NamePos.Add(len(n.Name) - 1) + case *IdentifierExpr: + return EndPos(n.Ident) + case *LiteralExpr: + return n.ValuePos.Add(len(n.Value) - 1) + case *ArrayExpr: + return n.RBrackPos + case *ObjectExpr: + return n.RCurlyPos + case *AccessExpr: + return EndPos(n.Name) + case *IndexExpr: + return n.RBrackPos + case *CallExpr: + return n.RParenPos + case *UnaryExpr: + return EndPos(n.Value) + case *BinaryExpr: + return EndPos(n.Right) + case *ParenExpr: + return n.RParenPos + default: + panic(fmt.Sprintf("Unhandled Node type %T", n)) + } +} + +// GetBlockName retrieves the "." delimited block name. +func (block *BlockStmt) GetBlockName() string { + return strings.Join(block.Name, ".") +} diff --git a/syntax/ast/walk.go b/syntax/ast/walk.go new file mode 100644 index 0000000000..df3f82d9a3 --- /dev/null +++ b/syntax/ast/walk.go @@ -0,0 +1,73 @@ +package ast + +import "fmt" + +// A Visitor has its Visit method invoked for each node encountered by Walk. If +// the resulting visitor w is not nil, Walk visits each of the children of node +// with the visitor w, followed by a call of w.Visit(nil). +type Visitor interface { + Visit(node Node) (w Visitor) +} + +// Walk traverses an AST in depth-first order: it starts by calling +// v.Visit(node); node must not be nil. If the visitor w returned by +// v.Visit(node) is not nil, Walk is invoked recursively with visitor w for +// each of the non-nil children of node, followed by a call of w.Visit(nil). +func Walk(v Visitor, node Node) { + if v = v.Visit(node); v == nil { + return + } + + // Walk children. The order of the cases matches the declared order of nodes + // in ast.go. + switch n := node.(type) { + case *File: + Walk(v, n.Body) + case Body: + for _, s := range n { + Walk(v, s) + } + case *AttributeStmt: + Walk(v, n.Name) + Walk(v, n.Value) + case *BlockStmt: + Walk(v, n.Body) + case *Ident: + // Nothing to do + case *IdentifierExpr: + Walk(v, n.Ident) + case *LiteralExpr: + // Nothing to do + case *ArrayExpr: + for _, e := range n.Elements { + Walk(v, e) + } + case *ObjectExpr: + for _, f := range n.Fields { + Walk(v, f.Name) + Walk(v, f.Value) + } + case *AccessExpr: + Walk(v, n.Value) + Walk(v, n.Name) + case *IndexExpr: + Walk(v, n.Value) + Walk(v, n.Index) + case *CallExpr: + Walk(v, n.Value) + for _, a := range n.Args { + Walk(v, a) + } + case *UnaryExpr: + Walk(v, n.Value) + case *BinaryExpr: + Walk(v, n.Left) + Walk(v, n.Right) + case *ParenExpr: + Walk(v, n.Inner) + default: + panic(fmt.Sprintf("river/ast: unexpected node type %T", n)) + } + + v.Visit(nil) +} diff --git a/syntax/cmd/riverfmt/main.go b/syntax/cmd/riverfmt/main.go new file mode 100644 index 0000000000..d7b4433f5a --- /dev/null +++ b/syntax/cmd/riverfmt/main.go @@ -0,0 +1,103 @@ +package main + +import ( + "bytes" + "errors" + "flag" + "fmt" + "io" + "os" + + "github.com/grafana/river/diag" + "github.com/grafana/river/parser" + "github.com/grafana/river/printer" +) + +func main() { + err := run() + + var diags diag.Diagnostics + if errors.As(err, &diags) { + for _, diag := range diags { + fmt.Fprintln(os.Stderr, diag) + } + os.Exit(1) + } else if err != nil { + fmt.Fprintf(os.Stderr, "error: %s\n", err) + os.Exit(1) + } +} + +func run() error { + var ( + write bool + ) + + fs := flag.NewFlagSet("riverfmt", flag.ExitOnError) + fs.BoolVar(&write, "w", write, "write result to (source) file instead of stdout") + + if err := fs.Parse(os.Args[1:]); err != nil { + return err + } + + args := fs.Args() + switch len(args) { + case 0: + if write { + return fmt.Errorf("cannot use -w with standard input") + } + return format("", nil, os.Stdin, write) + + case 1: + fi, err := os.Stat(args[0]) + if err != nil { + return err + } + if fi.IsDir() { + return fmt.Errorf("cannot format a directory") + } + f, err := os.Open(args[0]) + if err != nil { + return err + } + defer f.Close() + return format(args[0], fi, f, write) + + default: + return fmt.Errorf("can only format one file") + } +} + +func format(filename string, fi os.FileInfo, r io.Reader, write bool) error { + bb, err := io.ReadAll(r) + if err != nil { + return err + } + + f, err := parser.ParseFile(filename, bb) + if err != nil { + return err + } + + var buf bytes.Buffer + if err := printer.Fprint(&buf, f); err != nil { + return err + } + + // Add a newline at the end + _, _ = buf.Write([]byte{'\n'}) + + if !write { + _, err := io.Copy(os.Stdout, &buf) + return err + } + + wf, err := os.OpenFile(filename, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, fi.Mode().Perm()) + if err != nil { + return err + } + defer wf.Close() + + _, err = io.Copy(wf, &buf) + return err +} diff --git a/syntax/diag/diag.go b/syntax/diag/diag.go new file mode 100644 index 0000000000..a49487af61 --- /dev/null +++ b/syntax/diag/diag.go @@ -0,0 +1,95 @@ +// Package diag exposes error types used throughout River and a method to +// pretty-print them to the screen. +package diag + +import ( + "fmt" + + "github.com/grafana/river/token" +) + +// Severity denotes the severity level of a diagnostic. The zero value of +// severity is invalid. +type Severity int + +// Supported severity levels. +const ( + SeverityLevelWarn Severity = iota + 1 + SeverityLevelError +) + +// Diagnostic is an individual diagnostic message. Diagnostic messages can have +// different levels of severities. +type Diagnostic struct { + // Severity holds the severity level of this Diagnostic. + Severity Severity + + // StartPos refers to a position in a file where this Diagnostic starts. + StartPos token.Position + + // EndPos refers to an optional position in a file where this Diagnostic + // ends. If EndPos is the zero value, the Diagnostic should be treated as + // only covering a single character (i.e., StartPos == EndPos). + // + // When defined, EndPos must have the same Filename value as the StartPos. + EndPos token.Position + + Message string + Value string +} + +// As allows d to be interpreted as a list of Diagnostics. +func (d Diagnostic) As(v interface{}) bool { + switch v := v.(type) { + case *Diagnostics: + *v = Diagnostics{d} + return true + } + + return false +} + +// Error implements error. +func (d Diagnostic) Error() string { + return fmt.Sprintf("%s: %s", d.StartPos, d.Message) +} + +// Diagnostics is a collection of diagnostic messages. +type Diagnostics []Diagnostic + +// Add adds an individual Diagnostic to the diagnostics list. +func (ds *Diagnostics) Add(d Diagnostic) { + *ds = append(*ds, d) +} + +// Error implements error. +func (ds Diagnostics) Error() string { + switch len(ds) { + case 0: + return "no errors" + case 1: + return ds[0].Error() + default: + return fmt.Sprintf("%s (and %d more diagnostics)", ds[0], len(ds)-1) + } +} + +// ErrorOrNil returns an error interface if the list diagnostics is non-empty, +// nil otherwise. +func (ds Diagnostics) ErrorOrNil() error { + if len(ds) == 0 { + return nil + } + return ds +} + +// HasErrors reports whether the list of Diagnostics contain any error-level +// diagnostic. +func (ds Diagnostics) HasErrors() bool { + for _, d := range ds { + if d.Severity == SeverityLevelError { + return true + } + } + return false +} diff --git a/syntax/diag/printer.go b/syntax/diag/printer.go new file mode 100644 index 0000000000..03994d68cf --- /dev/null +++ b/syntax/diag/printer.go @@ -0,0 +1,266 @@ +package diag + +import ( + "bufio" + "fmt" + "io" + "strconv" + "strings" + + "github.com/fatih/color" + "github.com/grafana/river/token" +) + +const tabWidth = 4 + +// PrinterConfig controls different settings for the Printer. +type PrinterConfig struct { + // When Color is true, the printer will output with color and special + // formatting characters (such as underlines). + // + // This should be disabled when not printing to a terminal. + Color bool + + // ContextLinesBefore and ContextLinesAfter controls how many context lines + // before and after the range of the diagnostic are printed. + ContextLinesBefore, ContextLinesAfter int +} + +// A Printer pretty-prints Diagnostics. +type Printer struct { + cfg PrinterConfig +} + +// NewPrinter creates a new diagnostics Printer with the provided config. +func NewPrinter(cfg PrinterConfig) *Printer { + return &Printer{cfg: cfg} +} + +// Fprint creates a Printer with default settings and prints diagnostics to the +// provided writer. files is used to look up file contents by name for printing +// diagnostics context. files may be set to nil to avoid printing context. +func Fprint(w io.Writer, files map[string][]byte, diags Diagnostics) error { + p := NewPrinter(PrinterConfig{ + Color: false, + ContextLinesBefore: 1, + ContextLinesAfter: 1, + }) + return p.Fprint(w, files, diags) +} + +// Fprint pretty-prints errors to a writer. files is used to look up file +// contents by name when printing context. files may be nil to avoid printing +// context. +func (p *Printer) Fprint(w io.Writer, files map[string][]byte, diags Diagnostics) error { + // Create a buffered writer since we'll have many small calls to Write while + // we print errors. + // + // Buffers writers track the first write error received and will return it + // (if any) when flushing, so we can ignore write errors throughout the code + // until the very end. + bw := bufio.NewWriter(w) + + for i, diag := range diags { + p.printDiagnosticHeader(bw, diag) + + // If there's no ending position, set the ending position to be the same as + // the start. + if !diag.EndPos.Valid() { + diag.EndPos = diag.StartPos + } + + // We can print the file context if it was found. + fileContents, foundFile := files[diag.StartPos.Filename] + if foundFile && diag.StartPos.Filename == diag.EndPos.Filename { + p.printRange(bw, fileContents, diag) + } + + // Print a blank line to separate diagnostics. + if i+1 < len(diags) { + fmt.Fprintf(bw, "\n") + } + } + + return bw.Flush() +} + +func (p *Printer) printDiagnosticHeader(w io.Writer, diag Diagnostic) { + if p.cfg.Color { + switch diag.Severity { + case SeverityLevelError: + cw := color.New(color.FgRed, color.Bold) + _, _ = cw.Fprintf(w, "Error: ") + case SeverityLevelWarn: + cw := color.New(color.FgYellow, color.Bold) + _, _ = cw.Fprintf(w, "Warning: ") + } + + cw := color.New(color.Bold) + _, _ = cw.Fprintf(w, "%s: %s\n", diag.StartPos, diag.Message) + return + } + + switch diag.Severity { + case SeverityLevelError: + _, _ = fmt.Fprintf(w, "Error: ") + case SeverityLevelWarn: + _, _ = fmt.Fprintf(w, "Warning: ") + } + fmt.Fprintf(w, "%s: %s\n", diag.StartPos, diag.Message) +} + +func (p *Printer) printRange(w io.Writer, file []byte, diag Diagnostic) { + var ( + start = diag.StartPos + end = diag.EndPos + ) + + fmt.Fprintf(w, "\n") + + var ( + lines = strings.Split(string(file), "\n") + + startLine = max(start.Line-p.cfg.ContextLinesBefore, 1) + endLine = min(end.Line+p.cfg.ContextLinesAfter, len(lines)) + + multiline = end.Line-start.Line > 0 + ) + + prefixWidth := len(strconv.Itoa(endLine)) + + for lineNum := startLine; lineNum <= endLine; lineNum++ { + line := lines[lineNum-1] + + // Print line number and margin. + printPaddedNumber(w, prefixWidth, lineNum) + fmt.Fprintf(w, " | ") + + if multiline { + // Use 0 for the column number so we never consider the starting line for + // showing |. + if inRange(lineNum, 0, start, end) { + fmt.Fprint(w, "| ") + } else { + fmt.Fprint(w, " ") + } + } + + // Print the line, but filter out any \r and replace tabs with spaces. + for _, ch := range line { + if ch == '\r' { + continue + } + if ch == '\t' || ch == '\v' { + printCh(w, tabWidth, ' ') + continue + } + fmt.Fprintf(w, "%c", ch) + } + + fmt.Fprintf(w, "\n") + + // Print the focus indicator if we're on a line that needs it. + // + // The focus indicator line must preserve whitespace present in the line + // above it prior to the focus '^' characters. Tab characters are replaced + // with spaces for consistent printing. + if lineNum == start.Line || (multiline && lineNum == end.Line) { + printCh(w, prefixWidth, ' ') // Add empty space where line number would be + + // Print the margin after the blank line number. On multi-line errors, + // the arrow is printed all the way to the margin, with straight + // lines going down in between the lines. + switch { + case multiline && lineNum == start.Line: + // |_ would look like an incorrect right angle, so the second bar + // is dropped. + fmt.Fprintf(w, " | _") + case multiline && lineNum == end.Line: + fmt.Fprintf(w, " | |_") + default: + fmt.Fprintf(w, " | ") + } + + p.printFocus(w, line, lineNum, diag) + fmt.Fprintf(w, "\n") + } + } +} + +// printFocus prints the focus indicator for the line number specified by line. +// The contents of the line should be represented by data so whitespace can be +// retained (injecting spaces where a tab should be, etc.). +func (p *Printer) printFocus(w io.Writer, data string, line int, diag Diagnostic) { + for i, ch := range data { + column := i + 1 + + if line == diag.EndPos.Line && column > diag.EndPos.Column { + // Stop printing the formatting line after printing all the ^. + break + } + + blank := byte(' ') + if diag.EndPos.Line-diag.StartPos.Line > 0 { + blank = byte('_') + } + + switch { + case ch == '\t' || ch == '\v': + printCh(w, tabWidth, blank) + case inRange(line, column, diag.StartPos, diag.EndPos): + fmt.Fprintf(w, "%c", '^') + default: + // Print a space. + fmt.Fprintf(w, "%c", blank) + } + } +} + +func inRange(line, col int, start, end token.Position) bool { + if line < start.Line || line > end.Line { + return false + } + + switch line { + case start.Line: + // If the current line is on the starting line, we have to be past the + // starting column. + return col >= start.Column + case end.Line: + // If the current line is on the ending line, we have to be before the + // final column. + return col <= end.Column + default: + // Otherwise, every column across all the lines in between + // is in the range. + return true + } +} + +func printPaddedNumber(w io.Writer, width int, num int) { + numStr := strconv.Itoa(num) + for i := 0; i < width-len(numStr); i++ { + _, _ = w.Write([]byte{' '}) + } + _, _ = w.Write([]byte(numStr)) +} + +func printCh(w io.Writer, count int, ch byte) { + for i := 0; i < count; i++ { + _, _ = w.Write([]byte{ch}) + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/syntax/diag/printer_test.go b/syntax/diag/printer_test.go new file mode 100644 index 0000000000..4558e666ef --- /dev/null +++ b/syntax/diag/printer_test.go @@ -0,0 +1,222 @@ +package diag_test + +import ( + "bytes" + "fmt" + "testing" + + "github.com/grafana/river/diag" + "github.com/grafana/river/token" + "github.com/stretchr/testify/require" +) + +func TestFprint(t *testing.T) { + // In all tests below, the filename is "testfile" and the severity is an + // error. + + tt := []struct { + name string + input string + start, end token.Position + diag diag.Diagnostic + expect string + }{ + { + name: "highlight on same line", + start: token.Position{Line: 2, Column: 2}, + end: token.Position{Line: 2, Column: 5}, + input: `test.block "label" { + attr = 1 + other_attr = 2 +}`, + expect: `Error: testfile:2:2: synthetic error + +1 | test.block "label" { +2 | attr = 1 + | ^^^^ +3 | other_attr = 2 +`, + }, + + { + name: "end positions should be optional", + start: token.Position{Line: 1, Column: 4}, + input: `foo,bar`, + expect: `Error: testfile:1:4: synthetic error + +1 | foo,bar + | ^ +`, + }, + + { + name: "padding should be inserted to fit line numbers of different lengths", + start: token.Position{Line: 9, Column: 1}, + end: token.Position{Line: 9, Column: 6}, + input: `LINE_1 +LINE_2 +LINE_3 +LINE_4 +LINE_5 +LINE_6 +LINE_7 +LINE_8 +LINE_9 +LINE_10 +LINE_11`, + expect: `Error: testfile:9:1: synthetic error + + 8 | LINE_8 + 9 | LINE_9 + | ^^^^^^ +10 | LINE_10 +`, + }, + + { + name: "errors which cross multiple lines can be printed from start of line", + start: token.Position{Line: 2, Column: 1}, + end: token.Position{Line: 6, Column: 7}, + input: `FILE_BEGIN +START +TEXT + TEXT + TEXT + DONE after +FILE_END`, + expect: `Error: testfile:2:1: synthetic error + +1 | FILE_BEGIN +2 | START + | _^^^^^ +3 | | TEXT +4 | | TEXT +5 | | TEXT +6 | | DONE after + | |_____________^^^^ +7 | FILE_END +`, + }, + + { + name: "errors which cross multiple lines can be printed from middle of line", + start: token.Position{Line: 2, Column: 8}, + end: token.Position{Line: 6, Column: 7}, + input: `FILE_BEGIN +before START +TEXT + TEXT + TEXT + DONE after +FILE_END`, + expect: `Error: testfile:2:8: synthetic error + +1 | FILE_BEGIN +2 | before START + | ________^^^^^ +3 | | TEXT +4 | | TEXT +5 | | TEXT +6 | | DONE after + | |_____________^^^^ +7 | FILE_END +`, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + files := map[string][]byte{ + "testfile": []byte(tc.input), + } + + tc.start.Filename = "testfile" + tc.end.Filename = "testfile" + + diags := diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: tc.start, + EndPos: tc.end, + Message: "synthetic error", + }} + + var buf bytes.Buffer + _ = diag.Fprint(&buf, files, diags) + requireEqualStrings(t, tc.expect, buf.String()) + }) + } +} + +func TestFprint_MultipleDiagnostics(t *testing.T) { + fileA := `old_field = 15 +3 & 4` + fileB := `old_field = 22` + + files := map[string][]byte{ + "file_a": []byte(fileA), + "file_b": []byte(fileB), + } + + diags := diag.Diagnostics{ + { + Severity: diag.SeverityLevelWarn, + StartPos: token.Position{Filename: "file_a", Line: 1, Column: 1}, + EndPos: token.Position{Filename: "file_a", Line: 1, Column: 9}, + Message: "old_field is deprecated", + }, + { + Severity: diag.SeverityLevelError, + StartPos: token.Position{Filename: "file_a", Line: 2, Column: 3}, + Message: "unrecognized operator &", + }, + { + Severity: diag.SeverityLevelWarn, + StartPos: token.Position{Filename: "file_b", Line: 1, Column: 1}, + EndPos: token.Position{Filename: "file_b", Line: 1, Column: 9}, + Message: "old_field is deprecated", + }, + } + + expect := `Warning: file_a:1:1: old_field is deprecated + +1 | old_field = 15 + | ^^^^^^^^^ +2 | 3 & 4 + +Error: file_a:2:3: unrecognized operator & + +1 | old_field = 15 +2 | 3 & 4 + | ^ + +Warning: file_b:1:1: old_field is deprecated + +1 | old_field = 22 + | ^^^^^^^^^ +` + + var buf bytes.Buffer + _ = diag.Fprint(&buf, files, diags) + requireEqualStrings(t, expect, buf.String()) +} + +// requireEqualStrings is like require.Equal with two strings but it +// pretty-prints multiline strings to make it easier to compare. +func requireEqualStrings(t *testing.T, expected, actual string) { + if expected == actual { + return + } + + msg := fmt.Sprintf( + "Not equal:\n"+ + "raw expected: %#v\n"+ + "raw actual : %#v\n"+ + "\n"+ + "expected:\n%s\n"+ + "actual:\n%s\n", + expected, actual, + expected, actual, + ) + + require.Fail(t, msg) +} diff --git a/syntax/encoding/riverjson/riverjson.go b/syntax/encoding/riverjson/riverjson.go new file mode 100644 index 0000000000..fc69882918 --- /dev/null +++ b/syntax/encoding/riverjson/riverjson.go @@ -0,0 +1,313 @@ +// Package riverjson encodes River as JSON. +package riverjson + +import ( + "encoding/json" + "fmt" + "reflect" + "sort" + "strings" + + "github.com/grafana/river/internal/reflectutil" + "github.com/grafana/river/internal/rivertags" + "github.com/grafana/river/internal/value" + "github.com/grafana/river/token/builder" +) + +var goRiverDefaulter = reflect.TypeOf((*value.Defaulter)(nil)).Elem() + +// MarshalBody marshals the provided Go value to a JSON representation of +// River. MarshalBody panics if not given a struct with River tags or a map[string]any. +func MarshalBody(val interface{}) ([]byte, error) { + rv := reflect.ValueOf(val) + return json.Marshal(encodeStructAsBody(rv)) +} + +func encodeStructAsBody(rv reflect.Value) jsonBody { + for rv.Kind() == reflect.Pointer { + if rv.IsNil() { + return jsonBody{} + } + rv = rv.Elem() + } + + if rv.Kind() == reflect.Invalid { + return jsonBody{} + } + + body := jsonBody{} + + switch rv.Kind() { + case reflect.Struct: + fields := rivertags.Get(rv.Type()) + defaults := reflect.New(rv.Type()).Elem() + if defaults.CanAddr() && defaults.Addr().Type().Implements(goRiverDefaulter) { + defaults.Addr().Interface().(value.Defaulter).SetToDefault() + } + + for _, field := range fields { + fieldVal := reflectutil.Get(rv, field) + fieldValDefault := reflectutil.Get(defaults, field) + + isEqual := fieldVal.Comparable() && fieldVal.Equal(fieldValDefault) + isZero := fieldValDefault.IsZero() && fieldVal.IsZero() + + if field.IsOptional() && (isEqual || isZero) { + continue + } + + body = append(body, encodeFieldAsStatements(nil, field, fieldVal)...) + } + + case reflect.Map: + if rv.Type().Key().Kind() != reflect.String { + panic("river/encoding/riverjson: unsupported map type; expected map[string]T, got " + rv.Type().String()) + } + + iter := rv.MapRange() + for iter.Next() { + mapKey, mapValue := iter.Key(), iter.Value() + + body = append(body, jsonAttr{ + Name: mapKey.String(), + Type: "attr", + Value: buildJSONValue(value.FromRaw(mapValue)), + }) + } + + default: + panic(fmt.Sprintf("river/encoding/riverjson: can only encode struct or map[string]T values to bodies, got %s", rv.Kind())) + } + + return body +} + +// encodeFieldAsStatements encodes an individual field from a struct as a set +// of statements. One field may map to multiple statements in the case of a +// slice of blocks. +func encodeFieldAsStatements(prefix []string, field rivertags.Field, fieldValue reflect.Value) []jsonStatement { + fieldName := strings.Join(field.Name, ".") + + for fieldValue.Kind() == reflect.Pointer { + if fieldValue.IsNil() { + break + } + fieldValue = fieldValue.Elem() + } + + switch { + case field.IsAttr(): + return []jsonStatement{jsonAttr{ + Name: fieldName, + Type: "attr", + Value: buildJSONValue(value.FromRaw(fieldValue)), + }} + + case field.IsBlock(): + fullName := mergeStringSlice(prefix, field.Name) + + switch { + case fieldValue.Kind() == reflect.Map: + // Iterate over the map and add each element as an attribute into it. + + if fieldValue.Type().Key().Kind() != reflect.String { + panic("river/encoding/riverjson: unsupported map type for block; expected map[string]T, got " + fieldValue.Type().String()) + } + + statements := []jsonStatement{} + + iter := fieldValue.MapRange() + for iter.Next() { + mapKey, mapValue := iter.Key(), iter.Value() + + statements = append(statements, jsonAttr{ + Name: mapKey.String(), + Type: "attr", + Value: buildJSONValue(value.FromRaw(mapValue)), + }) + } + + return []jsonStatement{jsonBlock{ + Name: strings.Join(fullName, "."), + Type: "block", + Body: statements, + }} + + case fieldValue.Kind() == reflect.Slice, fieldValue.Kind() == reflect.Array: + statements := []jsonStatement{} + + for i := 0; i < fieldValue.Len(); i++ { + elem := fieldValue.Index(i) + + // Recursively call encodeField for each element in the slice/array. + // The recursive call will hit the case below and add a new block for + // each field encountered. + statements = append(statements, encodeFieldAsStatements(prefix, field, elem)...) + } + + return statements + + case fieldValue.Kind() == reflect.Struct: + if fieldValue.IsZero() { + // It shouldn't be possible to have a required block which is unset, but + // we'll encode something anyway. + return []jsonStatement{jsonBlock{ + Name: strings.Join(fullName, "."), + Type: "block", + + // Never set this to nil, since the API contract always expects blocks + // to have an array value for the body. + Body: []jsonStatement{}, + }} + } + + return []jsonStatement{jsonBlock{ + Name: strings.Join(fullName, "."), + Type: "block", + Label: getBlockLabel(fieldValue), + Body: encodeStructAsBody(fieldValue), + }} + } + + case field.IsEnum(): + // Blocks within an enum have a prefix set. + newPrefix := mergeStringSlice(prefix, field.Name) + + switch { + case fieldValue.Kind() == reflect.Slice, fieldValue.Kind() == reflect.Array: + statements := []jsonStatement{} + for i := 0; i < fieldValue.Len(); i++ { + statements = append(statements, encodeEnumElementToStatements(newPrefix, fieldValue.Index(i))...) + } + return statements + + default: + panic(fmt.Sprintf("river/encoding/riverjson: unrecognized enum kind %s", fieldValue.Kind())) + } + } + + return nil +} + +func mergeStringSlice(a, b []string) []string { + if len(a) == 0 { + return b + } else if len(b) == 0 { + return a + } + + res := make([]string, 0, len(a)+len(b)) + res = append(res, a...) + res = append(res, b...) + return res +} + +// getBlockLabel returns the label for a given block. +func getBlockLabel(rv reflect.Value) string { + tags := rivertags.Get(rv.Type()) + for _, tag := range tags { + if tag.Flags&rivertags.FlagLabel != 0 { + return reflectutil.Get(rv, tag).String() + } + } + + return "" +} + +func encodeEnumElementToStatements(prefix []string, enumElement reflect.Value) []jsonStatement { + for enumElement.Kind() == reflect.Pointer { + if enumElement.IsNil() { + return nil + } + enumElement = enumElement.Elem() + } + + fields := rivertags.Get(enumElement.Type()) + + statements := []jsonStatement{} + + // Find the first non-zero field and encode it. + for _, field := range fields { + fieldVal := reflectutil.Get(enumElement, field) + if !fieldVal.IsValid() || fieldVal.IsZero() { + continue + } + + statements = append(statements, encodeFieldAsStatements(prefix, field, fieldVal)...) + break + } + + return statements +} + +// MarshalValue marshals the provided Go value to a JSON representation of +// River. +func MarshalValue(val interface{}) ([]byte, error) { + riverValue := value.Encode(val) + return json.Marshal(buildJSONValue(riverValue)) +} + +func buildJSONValue(v value.Value) jsonValue { + if tk, ok := v.Interface().(builder.Tokenizer); ok { + return jsonValue{ + Type: "capsule", + Value: tk.RiverTokenize()[0].Lit, + } + } + + switch v.Type() { + case value.TypeNull: + return jsonValue{Type: "null"} + + case value.TypeNumber: + return jsonValue{Type: "number", Value: v.Number().Float()} + + case value.TypeString: + return jsonValue{Type: "string", Value: v.Text()} + + case value.TypeBool: + return jsonValue{Type: "bool", Value: v.Bool()} + + case value.TypeArray: + elements := []interface{}{} + + for i := 0; i < v.Len(); i++ { + element := v.Index(i) + + elements = append(elements, buildJSONValue(element)) + } + + return jsonValue{Type: "array", Value: elements} + + case value.TypeObject: + keys := v.Keys() + + // If v isn't an ordered object (i.e., a go map), sort the keys so they + // have a deterministic print order. + if !v.OrderedKeys() { + sort.Strings(keys) + } + + fields := []jsonObjectField{} + + for i := 0; i < len(keys); i++ { + field, _ := v.Key(keys[i]) + + fields = append(fields, jsonObjectField{ + Key: keys[i], + Value: buildJSONValue(field), + }) + } + + return jsonValue{Type: "object", Value: fields} + + case value.TypeFunction: + return jsonValue{Type: "function", Value: v.Describe()} + + case value.TypeCapsule: + return jsonValue{Type: "capsule", Value: v.Describe()} + + default: + panic(fmt.Sprintf("river/encoding/riverjson: unrecognized value type %q", v.Type())) + } +} diff --git a/syntax/encoding/riverjson/riverjson_test.go b/syntax/encoding/riverjson/riverjson_test.go new file mode 100644 index 0000000000..0eeb321b59 --- /dev/null +++ b/syntax/encoding/riverjson/riverjson_test.go @@ -0,0 +1,363 @@ +package riverjson_test + +import ( + "testing" + + river "github.com/grafana/river" + "github.com/grafana/river/encoding/riverjson" + "github.com/grafana/river/rivertypes" + "github.com/stretchr/testify/require" +) + +func TestValues(t *testing.T) { + tt := []struct { + name string + input interface{} + expectJSON string + }{ + { + name: "null", + input: nil, + expectJSON: `{ "type": "null", "value": null }`, + }, + { + name: "number", + input: 54, + expectJSON: `{ "type": "number", "value": 54 }`, + }, + { + name: "string", + input: "Hello, world!", + expectJSON: `{ "type": "string", "value": "Hello, world!" }`, + }, + { + name: "bool", + input: true, + expectJSON: `{ "type": "bool", "value": true }`, + }, + { + name: "simple array", + input: []int{0, 1, 2, 3, 4}, + expectJSON: `{ + "type": "array", + "value": [ + { "type": "number", "value": 0 }, + { "type": "number", "value": 1 }, + { "type": "number", "value": 2 }, + { "type": "number", "value": 3 }, + { "type": "number", "value": 4 } + ] + }`, + }, + { + name: "nested array", + input: []interface{}{"testing", []int{0, 1, 2}}, + expectJSON: `{ + "type": "array", + "value": [ + { "type": "string", "value": "testing" }, + { + "type": "array", + "value": [ + { "type": "number", "value": 0 }, + { "type": "number", "value": 1 }, + { "type": "number", "value": 2 } + ] + } + ] + }`, + }, + { + name: "object", + input: map[string]any{"foo": "bar", "fizz": "buzz", "year": 2023}, + expectJSON: `{ + "type": "object", + "value": [ + { "key": "fizz", "value": { "type": "string", "value": "buzz" }}, + { "key": "foo", "value": { "type": "string", "value": "bar" }}, + { "key": "year", "value": { "type": "number", "value": 2023 }} + ] + }`, + }, + { + name: "function", + input: func(i int) int { return i * 2 }, + expectJSON: `{ "type": "function", "value": "function" }`, + }, + { + name: "capsule", + input: rivertypes.Secret("foo"), + expectJSON: `{ "type": "capsule", "value": "(secret)" }`, + }, + { + // nil arrays and objects must always be [] instead of null as that's + // what the API definition says they should be. + name: "nil array", + input: ([]any)(nil), + expectJSON: `{ "type": "array", "value": [] }`, + }, + { + // nil arrays and objects must always be [] instead of null as that's + // what the API definition says they should be. + name: "nil object", + input: (map[string]any)(nil), + expectJSON: `{ "type": "object", "value": [] }`, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + actual, err := riverjson.MarshalValue(tc.input) + require.NoError(t, err) + require.JSONEq(t, tc.expectJSON, string(actual)) + }) + } +} + +func TestBlock(t *testing.T) { + // Zero values should be omitted from result. + + val := testBlock{ + Number: 5, + Array: []any{1, 2, 3}, + Labeled: []labeledBlock{ + { + TestBlock: testBlock{Boolean: true}, + Label: "label_a", + }, + { + TestBlock: testBlock{String: "foo"}, + Label: "label_b", + }, + }, + Blocks: []testBlock{ + {String: "hello"}, + {String: "world"}, + }, + } + + expect := `[ + { + "name": "number", + "type": "attr", + "value": { "type": "number", "value": 5 } + }, + { + "name": "array", + "type": "attr", + "value": { + "type": "array", + "value": [ + { "type": "number", "value": 1 }, + { "type": "number", "value": 2 }, + { "type": "number", "value": 3 } + ] + } + }, + { + "name": "labeled_block", + "type": "block", + "label": "label_a", + "body": [{ + "name": "boolean", + "type": "attr", + "value": { "type": "bool", "value": true } + }] + }, + { + "name": "labeled_block", + "type": "block", + "label": "label_b", + "body": [{ + "name": "string", + "type": "attr", + "value": { "type": "string", "value": "foo" } + }] + }, + { + "name": "inner_block", + "type": "block", + "body": [{ + "name": "string", + "type": "attr", + "value": { "type": "string", "value": "hello" } + }] + }, + { + "name": "inner_block", + "type": "block", + "body": [{ + "name": "string", + "type": "attr", + "value": { "type": "string", "value": "world" } + }] + } + ]` + + actual, err := riverjson.MarshalBody(val) + require.NoError(t, err) + require.JSONEq(t, expect, string(actual)) +} + +func TestBlock_Empty_Required_Block_Slice(t *testing.T) { + type wrapper struct { + Blocks []testBlock `river:"some_block,block"` + } + + tt := []struct { + name string + val any + }{ + {"nil block slice", wrapper{Blocks: nil}}, + {"empty block slice", wrapper{Blocks: []testBlock{}}}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expect := `[]` + + actual, err := riverjson.MarshalBody(tc.val) + require.NoError(t, err) + require.JSONEq(t, expect, string(actual)) + }) + } +} + +type testBlock struct { + Number int `river:"number,attr,optional"` + String string `river:"string,attr,optional"` + Boolean bool `river:"boolean,attr,optional"` + Array []any `river:"array,attr,optional"` + Object map[string]any `river:"object,attr,optional"` + + Labeled []labeledBlock `river:"labeled_block,block,optional"` + Blocks []testBlock `river:"inner_block,block,optional"` +} + +type labeledBlock struct { + TestBlock testBlock `river:",squash"` + Label string `river:",label"` +} + +func TestNilBody(t *testing.T) { + actual, err := riverjson.MarshalBody(nil) + require.NoError(t, err) + require.JSONEq(t, `[]`, string(actual)) +} + +func TestEmptyBody(t *testing.T) { + type block struct{} + + actual, err := riverjson.MarshalBody(block{}) + require.NoError(t, err) + require.JSONEq(t, `[]`, string(actual)) +} + +func TestHideDefaults(t *testing.T) { + tt := []struct { + name string + val defaultsBlock + expectJSON string + }{ + { + name: "no defaults", + val: defaultsBlock{ + Name: "Jane", + Age: 41, + }, + expectJSON: `[ + { "name": "name", "type": "attr", "value": { "type": "string", "value": "Jane" }}, + { "name": "age", "type": "attr", "value": { "type": "number", "value": 41 }} + ]`, + }, + { + name: "some defaults", + val: defaultsBlock{ + Name: "John Doe", + Age: 41, + }, + expectJSON: `[ + { "name": "age", "type": "attr", "value": { "type": "number", "value": 41 }} + ]`, + }, + { + name: "all defaults", + val: defaultsBlock{ + Name: "John Doe", + Age: 35, + }, + expectJSON: `[]`, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + actual, err := riverjson.MarshalBody(tc.val) + require.NoError(t, err) + require.JSONEq(t, tc.expectJSON, string(actual)) + }) + } +} + +type defaultsBlock struct { + Name string `river:"name,attr,optional"` + Age int `river:"age,attr,optional"` +} + +var _ river.Defaulter = (*defaultsBlock)(nil) + +func (d *defaultsBlock) SetToDefault() { + *d = defaultsBlock{ + Name: "John Doe", + Age: 35, + } +} + +func TestMapBlocks(t *testing.T) { + type block struct { + Value map[string]any `river:"block,block,optional"` + } + val := block{Value: map[string]any{"field": "value"}} + + expect := `[{ + "name": "block", + "type": "block", + "body": [{ + "name": "field", + "type": "attr", + "value": { "type": "string", "value": "value" } + }] + }]` + + bb, err := riverjson.MarshalBody(val) + require.NoError(t, err) + require.JSONEq(t, expect, string(bb)) +} + +func TestRawMap(t *testing.T) { + val := map[string]any{"field": "value"} + + expect := `[{ + "name": "field", + "type": "attr", + "value": { "type": "string", "value": "value" } + }]` + + bb, err := riverjson.MarshalBody(val) + require.NoError(t, err) + require.JSONEq(t, expect, string(bb)) +} + +func TestRawMap_Capsule(t *testing.T) { + val := map[string]any{"capsule": rivertypes.Secret("foo")} + + expect := `[{ + "name": "capsule", + "type": "attr", + "value": { "type": "capsule", "value": "(secret)" } + }]` + + bb, err := riverjson.MarshalBody(val) + require.NoError(t, err) + require.JSONEq(t, expect, string(bb)) +} diff --git a/syntax/encoding/riverjson/types.go b/syntax/encoding/riverjson/types.go new file mode 100644 index 0000000000..3170331e46 --- /dev/null +++ b/syntax/encoding/riverjson/types.go @@ -0,0 +1,41 @@ +package riverjson + +// Various concrete types used to marshal River values. +type ( + // jsonStatement is a statement within a River body. + jsonStatement interface{ isStatement() } + + // A jsonBody is a collection of statements. + jsonBody = []jsonStatement + + // jsonBlock represents a River block as JSON. jsonBlock is a jsonStatement. + jsonBlock struct { + Name string `json:"name"` + Type string `json:"type"` // Always "block" + Label string `json:"label,omitempty"` + Body []jsonStatement `json:"body"` + } + + // jsonAttr represents a River attribute as JSON. jsonAttr is a + // jsonStatement. + jsonAttr struct { + Name string `json:"name"` + Type string `json:"type"` // Always "attr" + Value jsonValue `json:"value"` + } + + // jsonValue represents a single River value as JSON. + jsonValue struct { + Type string `json:"type"` + Value interface{} `json:"value"` + } + + // jsonObjectField represents a field within a River object. + jsonObjectField struct { + Key string `json:"key"` + Value interface{} `json:"value"` + } +) + +func (jsonBlock) isStatement() {} +func (jsonAttr) isStatement() {} diff --git a/syntax/go.mod b/syntax/go.mod new file mode 100644 index 0000000000..6b85c2e5da --- /dev/null +++ b/syntax/go.mod @@ -0,0 +1,18 @@ +module github.com/grafana/river + +go 1.21.0 + +require ( + github.com/fatih/color v1.15.0 + github.com/ohler55/ojg v1.20.1 + github.com/stretchr/testify v1.8.4 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.17 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.6.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/syntax/go.sum b/syntax/go.sum new file mode 100644 index 0000000000..a972b9838b --- /dev/null +++ b/syntax/go.sum @@ -0,0 +1,22 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= +github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/ohler55/ojg v1.20.1 h1:Io65sHjMjYPI7yuhUr8VdNmIQdYU6asKeFhOs8xgBnY= +github.com/ohler55/ojg v1.20.1/go.mod h1:uHcD1ErbErC27Zhb5Df2jUjbseLLcmOCo6oxSr3jZxo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/syntax/internal/reflectutil/walk.go b/syntax/internal/reflectutil/walk.go new file mode 100644 index 0000000000..17cbb25d49 --- /dev/null +++ b/syntax/internal/reflectutil/walk.go @@ -0,0 +1,89 @@ +package reflectutil + +import ( + "reflect" + + "github.com/grafana/river/internal/rivertags" +) + +// GetOrAlloc returns the nested field of value corresponding to index. +// GetOrAlloc panics if not given a struct. +func GetOrAlloc(value reflect.Value, field rivertags.Field) reflect.Value { + return GetOrAllocIndex(value, field.Index) +} + +// GetOrAllocIndex returns the nested field of value corresponding to index. +// GetOrAllocIndex panics if not given a struct. +// +// It is similar to [reflect/Value.FieldByIndex] but can handle traversing +// through nil pointers. If allocate is true, GetOrAllocIndex allocates any +// intermediate nil pointers while traversing the struct. +func GetOrAllocIndex(value reflect.Value, index []int) reflect.Value { + if len(index) == 1 { + return value.Field(index[0]) + } + + if value.Kind() != reflect.Struct { + panic("GetOrAlloc must be given a Struct, but found " + value.Kind().String()) + } + + for _, next := range index { + value = deferencePointer(value).Field(next) + } + + return value +} + +func deferencePointer(value reflect.Value) reflect.Value { + for value.Kind() == reflect.Pointer { + if value.IsNil() { + value.Set(reflect.New(value.Type().Elem())) + } + value = value.Elem() + } + + return value +} + +// Get returns the nested field of value corresponding to index. Get panics if +// not given a struct. +// +// It is similar to [reflect/Value.FieldByIndex] but can handle traversing +// through nil pointers. If Get traverses through a nil pointer, a non-settable +// zero value for the final field is returned. +func Get(value reflect.Value, field rivertags.Field) reflect.Value { + if len(field.Index) == 1 { + return value.Field(field.Index[0]) + } + + if value.Kind() != reflect.Struct { + panic("Get must be given a Struct, but found " + value.Kind().String()) + } + + for i, next := range field.Index { + for value.Kind() == reflect.Pointer { + if value.IsNil() { + return getZero(value, field.Index[i:]) + } + value = value.Elem() + } + + value = value.Field(next) + } + + return value +} + +// getZero returns a non-settable zero value while walking value. +func getZero(value reflect.Value, index []int) reflect.Value { + typ := value.Type() + + for _, next := range index { + for typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + typ = typ.Field(next).Type + } + + return reflect.Zero(typ) +} diff --git a/syntax/internal/reflectutil/walk_test.go b/syntax/internal/reflectutil/walk_test.go new file mode 100644 index 0000000000..f536770e5a --- /dev/null +++ b/syntax/internal/reflectutil/walk_test.go @@ -0,0 +1,72 @@ +package reflectutil_test + +import ( + "reflect" + "testing" + + "github.com/grafana/river/internal/reflectutil" + "github.com/grafana/river/internal/rivertags" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDeeplyNested_Access(t *testing.T) { + type Struct struct { + Field1 struct { + Field2 struct { + Field3 struct { + Value string + } + } + } + } + + var s Struct + s.Field1.Field2.Field3.Value = "Hello, world!" + + rv := reflect.ValueOf(&s).Elem() + innerValue := reflectutil.GetOrAlloc(rv, rivertags.Field{Index: []int{0, 0, 0, 0}}) + assert.True(t, innerValue.CanSet()) + assert.Equal(t, reflect.String, innerValue.Kind()) +} + +func TestDeeplyNested_Allocate(t *testing.T) { + type Struct struct { + Field1 *struct { + Field2 *struct { + Field3 *struct { + Value string + } + } + } + } + + var s Struct + + rv := reflect.ValueOf(&s).Elem() + innerValue := reflectutil.GetOrAlloc(rv, rivertags.Field{Index: []int{0, 0, 0, 0}}) + require.True(t, innerValue.CanSet()) + require.Equal(t, reflect.String, innerValue.Kind()) + + innerValue.Set(reflect.ValueOf("Hello, world!")) + require.Equal(t, "Hello, world!", s.Field1.Field2.Field3.Value) +} + +func TestDeeplyNested_NoAllocate(t *testing.T) { + type Struct struct { + Field1 *struct { + Field2 *struct { + Field3 *struct { + Value string + } + } + } + } + + var s Struct + + rv := reflect.ValueOf(&s).Elem() + innerValue := reflectutil.Get(rv, rivertags.Field{Index: []int{0, 0, 0, 0}}) + assert.False(t, innerValue.CanSet()) + assert.Equal(t, reflect.String, innerValue.Kind()) +} diff --git a/syntax/internal/rivertags/rivertags.go b/syntax/internal/rivertags/rivertags.go new file mode 100644 index 0000000000..8186a644e0 --- /dev/null +++ b/syntax/internal/rivertags/rivertags.go @@ -0,0 +1,346 @@ +// Package rivertags decodes a struct type into river object +// and structural tags. +package rivertags + +import ( + "fmt" + "reflect" + "strings" +) + +// Flags is a bitmap of flags associated with a field on a struct. +type Flags uint + +// Valid flags. +const ( + FlagAttr Flags = 1 << iota // FlagAttr treats a field as attribute + FlagBlock // FlagBlock treats a field as a block + FlagEnum // FlagEnum treats a field as an enum of blocks + + FlagOptional // FlagOptional marks a field optional for decoding/encoding + FlagLabel // FlagLabel will store block labels in the field + FlagSquash // FlagSquash will expose inner fields from a struct as outer fields. +) + +// String returns the flags as a string. +func (f Flags) String() string { + attrs := make([]string, 0, 5) + + if f&FlagAttr != 0 { + attrs = append(attrs, "attr") + } + if f&FlagBlock != 0 { + attrs = append(attrs, "block") + } + if f&FlagEnum != 0 { + attrs = append(attrs, "enum") + } + if f&FlagOptional != 0 { + attrs = append(attrs, "optional") + } + if f&FlagLabel != 0 { + attrs = append(attrs, "label") + } + if f&FlagSquash != 0 { + attrs = append(attrs, "squash") + } + + return fmt.Sprintf("Flags(%s)", strings.Join(attrs, ",")) +} + +// GoString returns the %#v format of Flags. +func (f Flags) GoString() string { return f.String() } + +// Field is a tagged field within a struct. +type Field struct { + Name []string // Name of tagged field. + Index []int // Index into field. Use [reflectutil.GetOrAlloc] to retrieve a Value. + Flags Flags // Flags assigned to field. +} + +// Equals returns true if two fields are equal. +func (f Field) Equals(other Field) bool { + // Compare names + { + if len(f.Name) != len(other.Name) { + return false + } + + for i := 0; i < len(f.Name); i++ { + if f.Name[i] != other.Name[i] { + return false + } + } + } + + // Compare index. + { + if len(f.Index) != len(other.Index) { + return false + } + + for i := 0; i < len(f.Index); i++ { + if f.Index[i] != other.Index[i] { + return false + } + } + } + + // Finally, compare flags. + return f.Flags == other.Flags +} + +// IsAttr returns whether f is for an attribute. +func (f Field) IsAttr() bool { return f.Flags&FlagAttr != 0 } + +// IsBlock returns whether f is for a block. +func (f Field) IsBlock() bool { return f.Flags&FlagBlock != 0 } + +// IsEnum returns whether f represents an enum of blocks, where only one block +// is set at a time. +func (f Field) IsEnum() bool { return f.Flags&FlagEnum != 0 } + +// IsOptional returns whether f is optional. +func (f Field) IsOptional() bool { return f.Flags&FlagOptional != 0 } + +// IsLabel returns whether f is label. +func (f Field) IsLabel() bool { return f.Flags&FlagLabel != 0 } + +// Get returns the list of tagged fields for some struct type ty. Get panics if +// ty is not a struct type. +// +// Get examines each tagged field in ty for a river key. The river key is then +// parsed as containing a name for the field, followed by a required +// comma-separated list of options. The name may be empty for fields which do +// not require a name. Get will ignore any field that is not tagged with a +// river key. +// +// Get will treat anonymous struct fields as if the inner fields were fields in +// the outer struct. +// +// Examples of struct field tags and their meanings: +// +// // Field is used as a required block named "my_block". +// Field struct{} `river:"my_block,block"` +// +// // Field is used as an optional block named "my_block". +// Field struct{} `river:"my_block,block,optional"` +// +// // Field is used as a required attribute named "my_attr". +// Field string `river:"my_attr,attr"` +// +// // Field is used as an optional attribute named "my_attr". +// Field string `river:"my_attr,attr,optional"` +// +// // Field is used for storing the label of the block which the struct +// // represents. +// Field string `river:",label"` +// +// // Attributes and blocks inside of Field are exposed as top-level fields. +// Field struct{} `river:",squash"` +// +// Blocks []struct{} `river:"my_block_prefix,enum"` +// +// With the exception of the `river:",label"` and `river:",squash" tags, all +// tagged fields must have a unique name. +// +// The type of tagged fields may be any Go type, with the exception of +// `river:",label"` tags, which must be strings. +func Get(ty reflect.Type) []Field { + if k := ty.Kind(); k != reflect.Struct { + panic(fmt.Sprintf("rivertags: Get requires struct kind, got %s", k)) + } + + var ( + fields []Field + + usedNames = make(map[string][]int) + usedLabelField = []int(nil) + ) + + for _, field := range reflect.VisibleFields(ty) { + // River does not support embedding of fields + if field.Anonymous { + panic(fmt.Sprintf("river: anonymous fields not supported %s", printPathToField(ty, field.Index))) + } + + tag, tagged := field.Tag.Lookup("river") + if !tagged { + continue + } + + if !field.IsExported() { + panic(fmt.Sprintf("river: river tag found on unexported field at %s", printPathToField(ty, field.Index))) + } + + options := strings.SplitN(tag, ",", 2) + if len(options) == 0 { + panic(fmt.Sprintf("river: unsupported empty tag at %s", printPathToField(ty, field.Index))) + } + if len(options) != 2 { + panic(fmt.Sprintf("river: field %s tag is missing options", printPathToField(ty, field.Index))) + } + + fullName := options[0] + + tf := Field{ + Name: strings.Split(fullName, "."), + Index: field.Index, + } + + if first, used := usedNames[fullName]; used && fullName != "" { + panic(fmt.Sprintf("river: field name %s already used by %s", fullName, printPathToField(ty, first))) + } + usedNames[fullName] = tf.Index + + flags, ok := parseFlags(options[1]) + if !ok { + panic(fmt.Sprintf("river: unrecognized river tag format %q at %s", tag, printPathToField(ty, tf.Index))) + } + tf.Flags = flags + + if len(tf.Name) > 1 && tf.Flags&(FlagBlock|FlagEnum) == 0 { + panic(fmt.Sprintf("river: field names with `.` may only be used by blocks or enums (found at %s)", printPathToField(ty, tf.Index))) + } + + if tf.Flags&FlagEnum != 0 { + if err := validateEnum(field); err != nil { + panic(err) + } + } + + if tf.Flags&FlagLabel != 0 { + if fullName != "" { + panic(fmt.Sprintf("river: label field at %s must not have a name", printPathToField(ty, tf.Index))) + } + if field.Type.Kind() != reflect.String { + panic(fmt.Sprintf("river: label field at %s must be a string", printPathToField(ty, tf.Index))) + } + + if usedLabelField != nil { + panic(fmt.Sprintf("river: label field already used by %s", printPathToField(ty, tf.Index))) + } + usedLabelField = tf.Index + } + + if tf.Flags&FlagSquash != 0 { + if fullName != "" { + panic(fmt.Sprintf("river: squash field at %s must not have a name", printPathToField(ty, tf.Index))) + } + + innerType := deferenceType(field.Type) + + switch { + case isStructType(innerType): // Squashed struct + // Get the inner fields from the squashed struct and append each of them. + // The index of the squashed field is prepended to the index of the inner + // struct. + innerFields := Get(deferenceType(field.Type)) + for _, innerField := range innerFields { + fields = append(fields, Field{ + Name: innerField.Name, + Index: append(field.Index, innerField.Index...), + Flags: innerField.Flags, + }) + } + + default: + panic(fmt.Sprintf("rivertags: squash field requires struct, got %s", innerType)) + } + + continue + } + + if fullName == "" && tf.Flags&(FlagLabel|FlagSquash) == 0 /* (e.g., *not* a label or squash) */ { + panic(fmt.Sprintf("river: non-empty field name required at %s", printPathToField(ty, tf.Index))) + } + + fields = append(fields, tf) + } + + return fields +} + +func parseFlags(input string) (f Flags, ok bool) { + switch input { + case "attr": + f |= FlagAttr + case "attr,optional": + f |= FlagAttr | FlagOptional + case "block": + f |= FlagBlock + case "block,optional": + f |= FlagBlock | FlagOptional + case "enum": + f |= FlagEnum + case "enum,optional": + f |= FlagEnum | FlagOptional + case "label": + f |= FlagLabel + case "squash": + f |= FlagSquash + default: + return + } + + return f, true +} + +func printPathToField(structTy reflect.Type, path []int) string { + var sb strings.Builder + + sb.WriteString(structTy.String()) + sb.WriteString(".") + + cur := structTy + for i, elem := range path { + sb.WriteString(cur.Field(elem).Name) + + if i+1 < len(path) { + sb.WriteString(".") + } + + cur = cur.Field(i).Type + } + + return sb.String() +} + +func deferenceType(ty reflect.Type) reflect.Type { + for ty.Kind() == reflect.Pointer { + ty = ty.Elem() + } + return ty +} + +func isStructType(ty reflect.Type) bool { + return ty.Kind() == reflect.Struct +} + +// validateEnum ensures that an enum field is valid. Valid enum fields are +// slices of structs containing nothing but non-slice blocks. +func validateEnum(field reflect.StructField) error { + kind := field.Type.Kind() + if kind != reflect.Slice && kind != reflect.Array { + return fmt.Errorf("enum fields can only be slices or arrays") + } + + elementType := deferenceType(field.Type.Elem()) + if elementType.Kind() != reflect.Struct { + return fmt.Errorf("enum fields can only be a slice or array of structs") + } + + enumElementFields := Get(elementType) + for _, field := range enumElementFields { + if !field.IsBlock() { + return fmt.Errorf("fields in an enum element may only be blocks, got " + field.Flags.String()) + } + + fieldType := deferenceType(elementType.FieldByIndex(field.Index).Type) + if fieldType.Kind() != reflect.Struct { + return fmt.Errorf("blocks in an enum element may only be structs, got " + fieldType.Kind().String()) + } + } + + return nil +} diff --git a/syntax/internal/rivertags/rivertags_test.go b/syntax/internal/rivertags/rivertags_test.go new file mode 100644 index 0000000000..43370b33b4 --- /dev/null +++ b/syntax/internal/rivertags/rivertags_test.go @@ -0,0 +1,182 @@ +package rivertags_test + +import ( + "reflect" + "testing" + + "github.com/grafana/river/internal/rivertags" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Get(t *testing.T) { + type Struct struct { + IgnoreMe bool + + ReqAttr string `river:"req_attr,attr"` + OptAttr string `river:"opt_attr,attr,optional"` + ReqBlock struct{} `river:"req_block,block"` + OptBlock struct{} `river:"opt_block,block,optional"` + ReqEnum []struct{} `river:"req_enum,enum"` + OptEnum []struct{} `river:"opt_enum,enum,optional"` + Label string `river:",label"` + } + + fs := rivertags.Get(reflect.TypeOf(Struct{})) + + expect := []rivertags.Field{ + {[]string{"req_attr"}, []int{1}, rivertags.FlagAttr}, + {[]string{"opt_attr"}, []int{2}, rivertags.FlagAttr | rivertags.FlagOptional}, + {[]string{"req_block"}, []int{3}, rivertags.FlagBlock}, + {[]string{"opt_block"}, []int{4}, rivertags.FlagBlock | rivertags.FlagOptional}, + {[]string{"req_enum"}, []int{5}, rivertags.FlagEnum}, + {[]string{"opt_enum"}, []int{6}, rivertags.FlagEnum | rivertags.FlagOptional}, + {[]string{""}, []int{7}, rivertags.FlagLabel}, + } + + require.Equal(t, expect, fs) +} + +func TestEmbedded(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr"` + InnerField2 string `river:"inner_field_2,attr"` + } + + type Struct struct { + Field1 string `river:"parent_field_1,attr"` + InnerStruct + Field2 string `river:"parent_field_2,attr"` + } + require.PanicsWithValue(t, "river: anonymous fields not supported rivertags_test.Struct.InnerStruct", func() { rivertags.Get(reflect.TypeOf(Struct{})) }) +} + +func TestSquash(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr"` + InnerField2 string `river:"inner_field_2,attr"` + } + + type Struct struct { + Field1 string `river:"parent_field_1,attr"` + Inner InnerStruct `river:",squash"` + Field2 string `river:"parent_field_2,attr"` + } + + type StructWithPointer struct { + Field1 string `river:"parent_field_1,attr"` + Inner *InnerStruct `river:",squash"` + Field2 string `river:"parent_field_2,attr"` + } + + expect := []rivertags.Field{ + { + Name: []string{"parent_field_1"}, + Index: []int{0}, + Flags: rivertags.FlagAttr, + }, + { + Name: []string{"inner_field_1"}, + Index: []int{1, 0}, + Flags: rivertags.FlagAttr, + }, + { + Name: []string{"inner_field_2"}, + Index: []int{1, 1}, + Flags: rivertags.FlagAttr, + }, + { + Name: []string{"parent_field_2"}, + Index: []int{2}, + Flags: rivertags.FlagAttr, + }, + } + + structActual := rivertags.Get(reflect.TypeOf(Struct{})) + assert.Equal(t, expect, structActual) + + structPointerActual := rivertags.Get(reflect.TypeOf(StructWithPointer{})) + assert.Equal(t, expect, structPointerActual) +} + +func TestDeepSquash(t *testing.T) { + type Inner2Struct struct { + InnerField1 string `river:"inner_field_1,attr"` + InnerField2 string `river:"inner_field_2,attr"` + } + + type InnerStruct struct { + Inner2Struct Inner2Struct `river:",squash"` + } + + type Struct struct { + Inner InnerStruct `river:",squash"` + } + + expect := []rivertags.Field{ + { + Name: []string{"inner_field_1"}, + Index: []int{0, 0, 0}, + Flags: rivertags.FlagAttr, + }, + { + Name: []string{"inner_field_2"}, + Index: []int{0, 0, 1}, + Flags: rivertags.FlagAttr, + }, + } + + structActual := rivertags.Get(reflect.TypeOf(Struct{})) + assert.Equal(t, expect, structActual) +} + +func Test_Get_Panics(t *testing.T) { + expectPanic := func(t *testing.T, expect string, v interface{}) { + t.Helper() + require.PanicsWithValue(t, expect, func() { + _ = rivertags.Get(reflect.TypeOf(v)) + }) + } + + t.Run("Tagged fields must be exported", func(t *testing.T) { + type Struct struct { + attr string `river:"field,attr"` // nolint:unused //nolint:rivertags + } + expect := `river: river tag found on unexported field at rivertags_test.Struct.attr` + expectPanic(t, expect, Struct{}) + }) + + t.Run("Options are required", func(t *testing.T) { + type Struct struct { + Attr string `river:"field"` //nolint:rivertags + } + expect := `river: field rivertags_test.Struct.Attr tag is missing options` + expectPanic(t, expect, Struct{}) + }) + + t.Run("Field names must be unique", func(t *testing.T) { + type Struct struct { + Attr string `river:"field1,attr"` + Block string `river:"field1,block,optional"` //nolint:rivertags + } + expect := `river: field name field1 already used by rivertags_test.Struct.Attr` + expectPanic(t, expect, Struct{}) + }) + + t.Run("Name is required for non-label field", func(t *testing.T) { + type Struct struct { + Attr string `river:",attr"` //nolint:rivertags + } + expect := `river: non-empty field name required at rivertags_test.Struct.Attr` + expectPanic(t, expect, Struct{}) + }) + + t.Run("Only one label field may exist", func(t *testing.T) { + type Struct struct { + Label1 string `river:",label"` + Label2 string `river:",label"` + } + expect := `river: label field already used by rivertags_test.Struct.Label2` + expectPanic(t, expect, Struct{}) + }) +} diff --git a/syntax/internal/stdlib/constants.go b/syntax/internal/stdlib/constants.go new file mode 100644 index 0000000000..89525f855f --- /dev/null +++ b/syntax/internal/stdlib/constants.go @@ -0,0 +1,19 @@ +package stdlib + +import ( + "os" + "runtime" +) + +var constants = map[string]string{ + "hostname": "", // Initialized via init function + "os": runtime.GOOS, + "arch": runtime.GOARCH, +} + +func init() { + hostname, err := os.Hostname() + if err == nil { + constants["hostname"] = hostname + } +} diff --git a/syntax/internal/stdlib/stdlib.go b/syntax/internal/stdlib/stdlib.go new file mode 100644 index 0000000000..e73b950af8 --- /dev/null +++ b/syntax/internal/stdlib/stdlib.go @@ -0,0 +1,132 @@ +// Package stdlib contains standard library functions exposed to River configs. +package stdlib + +import ( + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/rivertypes" + "github.com/ohler55/ojg/jp" + "github.com/ohler55/ojg/oj" +) + +// Identifiers holds a list of stdlib identifiers by name. All interface{} +// values are River-compatible values. +// +// Function identifiers are Go functions with exactly one non-error return +// value, with an optionally supported error return value as the second return +// value. +var Identifiers = map[string]interface{}{ + // See constants.go for the definition. + "constants": constants, + + "env": os.Getenv, + + "nonsensitive": func(secret rivertypes.Secret) string { + return string(secret) + }, + + // concat is implemented as a raw function so it can bypass allocations + // converting arguments into []interface{}. concat is optimized to allow it + // to perform well when it is in the hot path for combining targets from many + // other blocks. + "concat": value.RawFunction(func(funcValue value.Value, args ...value.Value) (value.Value, error) { + if len(args) == 0 { + return value.Array(), nil + } + + // finalSize is the final size of the resulting concatenated array. We type + // check our arguments while computing what finalSize will be. + var finalSize int + for i, arg := range args { + if arg.Type() != value.TypeArray { + return value.Null, value.ArgError{ + Function: funcValue, + Argument: arg, + Index: i, + Inner: value.TypeError{ + Value: arg, + Expected: value.TypeArray, + }, + } + } + + finalSize += arg.Len() + } + + // Optimization: if there's only one array, we can just return it directly. + // This is done *after* the previous loop to ensure that args[0] is a River + // array. + if len(args) == 1 { + return args[0], nil + } + + raw := make([]value.Value, 0, finalSize) + for _, arg := range args { + for i := 0; i < arg.Len(); i++ { + raw = append(raw, arg.Index(i)) + } + } + + return value.Array(raw...), nil + }), + + "json_decode": func(in string) (interface{}, error) { + var res interface{} + err := json.Unmarshal([]byte(in), &res) + if err != nil { + return nil, err + } + return res, nil + }, + + "json_path": func(jsonString string, path string) (interface{}, error) { + jsonPathExpr, err := jp.ParseString(path) + if err != nil { + return nil, err + } + + jsonExpr, err := oj.ParseString(jsonString) + if err != nil { + return nil, err + } + + return jsonPathExpr.Get(jsonExpr), nil + }, + + "coalesce": value.RawFunction(func(funcValue value.Value, args ...value.Value) (value.Value, error) { + if len(args) == 0 { + return value.Null, nil + } + + for _, arg := range args { + if arg.Type() == value.TypeNull { + continue + } + + if !arg.Reflect().IsZero() { + if argType := value.RiverType(arg.Reflect().Type()); (argType == value.TypeArray || argType == value.TypeObject) && arg.Len() == 0 { + continue + } + + return arg, nil + } + } + + return args[len(args)-1], nil + }), + + "format": fmt.Sprintf, + "join": strings.Join, + "replace": strings.ReplaceAll, + "split": strings.Split, + "to_lower": strings.ToLower, + "to_upper": strings.ToUpper, + "trim": strings.Trim, + "trim_prefix": strings.TrimPrefix, + "trim_suffix": strings.TrimSuffix, + "trim_space": strings.TrimSpace, +} diff --git a/syntax/internal/value/capsule.go b/syntax/internal/value/capsule.go new file mode 100644 index 0000000000..7a522ff337 --- /dev/null +++ b/syntax/internal/value/capsule.go @@ -0,0 +1,53 @@ +package value + +import ( + "fmt" +) + +// Capsule is a marker interface for Go values which forces a type to be +// represented as a River capsule. This is useful for types whose underlying +// value is not a capsule, such as: +// +// // Secret is a secret value. It would normally be a River string since the +// // underlying Go type is string, but it's a capsule since it implements +// // the Capsule interface. +// type Secret string +// +// func (s Secret) RiverCapsule() {} +// +// Extension interfaces are used to describe additional behaviors for Capsules. +// ConvertibleCapsule allows defining custom conversion rules to convert +// between other Go values. +type Capsule interface { + RiverCapsule() +} + +// ErrNoConversion is returned by implementations of ConvertibleCapsule to +// denote that a custom conversion from or to a specific type is unavailable. +var ErrNoConversion = fmt.Errorf("no custom capsule conversion available") + +// ConvertibleFromCapsule is a Capsule which supports custom conversion rules +// from any Go type which is not the same as the capsule type. +type ConvertibleFromCapsule interface { + Capsule + + // ConvertFrom should modify the ConvertibleCapsule value based on the value + // of src. + // + // ConvertFrom should return ErrNoConversion if no conversion is available + // from src. + ConvertFrom(src interface{}) error +} + +// ConvertibleIntoCapsule is a Capsule which supports custom conversion rules +// into any Go type which is not the same as the capsule type. +type ConvertibleIntoCapsule interface { + Capsule + + // ConvertInto should convert its value and store it into dst. dst will be a + // pointer to a value which ConvertInto is expected to update. + // + // ConvertInto should return ErrNoConversion if no conversion into dst is + // available. + ConvertInto(dst interface{}) error +} diff --git a/syntax/internal/value/decode.go b/syntax/internal/value/decode.go new file mode 100644 index 0000000000..20df78eb6a --- /dev/null +++ b/syntax/internal/value/decode.go @@ -0,0 +1,674 @@ +package value + +import ( + "encoding" + "errors" + "fmt" + "math" + "reflect" + "time" + + "github.com/grafana/river/internal/reflectutil" +) + +// The Defaulter interface allows a type to implement default functionality +// in River evaluation. +// +// Defaulter will be called only on block and body river types. +// +// When using nested blocks, the wrapping type must also implement +// Defaulter to propagate the defaults of the wrapped type. Otherwise, +// defaults used for the wrapped type become inconsistent: +// +// - If the wrapped block is NOT defined in the River config, the wrapping +// type's defaults are used. +// - If the wrapped block IS defined in the River config, the wrapped type's +// defaults are used. +type Defaulter interface { + // SetToDefault is called when evaluating a block or body to set the value + // to its defaults. + SetToDefault() +} + +// Unmarshaler is a custom type which can be used to hook into the decoder. +type Unmarshaler interface { + // UnmarshalRiver is called when decoding a value. f should be invoked to + // continue decoding with a value to decode into. + UnmarshalRiver(f func(v interface{}) error) error +} + +// The Validator interface allows a type to implement validation functionality +// in River evaluation. +type Validator interface { + // Validate is called when evaluating a block or body to enforce the + // value is valid. + Validate() error +} + +// Decode assigns a Value val to a Go pointer target. Pointers will be +// allocated as necessary when decoding. +// +// As a performance optimization, the underlying Go value of val will be +// assigned directly to target if the Go types match. This means that pointers, +// slices, and maps will be passed by reference. Callers should take care not +// to modify any Values after decoding, unless it is expected by the contract +// of the type (i.e., when the type exposes a goroutine-safe API). In other +// cases, new maps and slices will be allocated as necessary. Call DecodeCopy +// to make a copy of val instead. +// +// When a direct assignment is not done, Decode first checks to see if target +// implements the Unmarshaler or text.Unmarshaler interface, invoking methods +// as appropriate. It will also use time.ParseDuration if target is +// *time.Duration. +// +// Next, Decode will attempt to convert val to the type expected by target for +// assignment. If val or target implement ConvertibleCapsule, conversion +// between values will be attempted by calling ConvertFrom and ConvertInto as +// appropriate. If val cannot be converted, an error is returned. +// +// River null values will decode into a nil Go pointer or the zero value for +// the non-pointer type. +// +// Decode will panic if target is not a pointer. +func Decode(val Value, target interface{}) error { + rt := reflect.ValueOf(target) + if rt.Kind() != reflect.Pointer { + panic("river/value: Decode called with non-pointer value") + } + + var d decoder + return d.decode(val, rt) +} + +// DecodeCopy is like Decode but a deep copy of val is always made. +// +// Unlike Decode, DecodeCopy will always invoke Unmarshaler and +// text.Unmarshaler interfaces (if implemented by target). +func DecodeCopy(val Value, target interface{}) error { + rt := reflect.ValueOf(target) + if rt.Kind() != reflect.Pointer { + panic("river/value: Decode called with non-pointer value") + } + + d := decoder{makeCopy: true} + return d.decode(val, rt) +} + +type decoder struct { + makeCopy bool +} + +func (d *decoder) decode(val Value, into reflect.Value) (err error) { + // If everything has decoded successfully, run Validate if implemented. + defer func() { + if err == nil { + if into.CanAddr() && into.Addr().Type().Implements(goRiverValidator) { + err = into.Addr().Interface().(Validator).Validate() + } else if into.Type().Implements(goRiverValidator) { + err = into.Interface().(Validator).Validate() + } + } + }() + + // Store the raw value from val and try to address it so we can do underlying + // type match assignment. + rawValue := val.rv + if rawValue.CanAddr() { + rawValue = rawValue.Addr() + } + + // Fully deference into and allocate pointers as necessary. + for into.Kind() == reflect.Pointer { + // Check for direct assignments before allocating pointers and dereferencing. + // This preserves pointer addresses when decoding an *int into an *int. + switch { + case into.CanSet() && val.Type() == TypeNull: + into.Set(reflect.Zero(into.Type())) + return nil + case into.CanSet() && d.canDirectlyAssign(rawValue.Type(), into.Type()): + into.Set(rawValue) + return nil + case into.CanSet() && d.canDirectlyAssign(val.rv.Type(), into.Type()): + into.Set(val.rv) + return nil + } + + if into.IsNil() { + into.Set(reflect.New(into.Type().Elem())) + } + into = into.Elem() + } + + // Ww need to preform the same switch statement as above after the loop to + // check for direct assignment one more time on the fully deferenced types. + // + // NOTE(rfratto): we skip the rawValue assignment check since that's meant + // for assigning pointers, and into is never a pointer when we reach here. + switch { + case into.CanSet() && val.Type() == TypeNull: + into.Set(reflect.Zero(into.Type())) + return nil + case into.CanSet() && d.canDirectlyAssign(val.rv.Type(), into.Type()): + into.Set(val.rv) + return nil + } + + // Special decoding rules: + // + // 1. If into is an interface{}, go through decodeAny so it gets assigned + // predictable types. + // 2. If into implements a supported interface, use the interface for + // decoding instead. + if into.Type() == goAny { + return d.decodeAny(val, into) + } else if ok, err := d.decodeFromInterface(val, into); ok { + return err + } + + if into.CanAddr() && into.Addr().Type().Implements(goRiverDefaulter) { + into.Addr().Interface().(Defaulter).SetToDefault() + } else if into.Type().Implements(goRiverDefaulter) { + into.Interface().(Defaulter).SetToDefault() + } + + targetType := RiverType(into.Type()) + + // Track a value to use for decoding. This value will be updated if + // conversion is necessary. + // + // NOTE(rfratto): we don't reassign to val here, since Go 1.18 thinks that + // means it escapes the heap. We need to create a local variable to avoid + // extra allocations. + convVal := val + + // Convert the value. + switch { + case val.rv.Type() == goByteSlice && into.Type() == goString: // []byte -> string + into.Set(val.rv.Convert(goString)) + return nil + case val.rv.Type() == goString && into.Type() == goByteSlice: // string -> []byte + into.Set(val.rv.Convert(goByteSlice)) + return nil + case convVal.Type() != targetType: + converted, err := tryCapsuleConvert(convVal, into, targetType) + if err != nil { + return err + } else if converted { + return nil + } + + convVal, err = convertValue(convVal, targetType) + if err != nil { + return err + } + } + + // Slowest case: recursive decoding. Once we've reached this point, we know + // that convVal.rv and into are compatible Go types. + switch convVal.Type() { + case TypeNumber: + into.Set(convertGoNumber(convVal.Number(), into.Type())) + return nil + case TypeString: + // Call convVal.Text() to get the final string value, since convVal.rv + // might not be a string. + into.Set(reflect.ValueOf(convVal.Text())) + return nil + case TypeBool: + into.Set(reflect.ValueOf(convVal.Bool())) + return nil + case TypeArray: + return d.decodeArray(convVal, into) + case TypeObject: + return d.decodeObject(convVal, into) + case TypeFunction: + // The Go types for two functions must be the same. + // + // TODO(rfratto): we may want to consider being more lax here, potentially + // creating an adapter between the two functions. + if convVal.rv.Type() == into.Type() { + into.Set(convVal.rv) + return nil + } + + return Error{ + Value: val, + Inner: fmt.Errorf("expected function(%s), got function(%s)", into.Type(), convVal.rv.Type()), + } + case TypeCapsule: + // The Go types for the capsules must be the same or able to be converted. + if convVal.rv.Type() == into.Type() { + into.Set(convVal.rv) + return nil + } + + converted, err := tryCapsuleConvert(convVal, into, targetType) + if err != nil { + return err + } else if converted { + return nil + } + + // TODO(rfratto): return a TypeError for this instead. TypeError isn't + // appropriate at the moment because it would just print "capsule", which + // doesn't contain all the information the user would want to know (e.g., a + // capsule of what inner type?). + return Error{ + Value: val, + Inner: fmt.Errorf("expected capsule(%q), got %s", into.Type(), convVal.Describe()), + } + default: + panic("river/value: unexpected kind " + convVal.Type().String()) + } +} + +// canDirectlyAssign returns true if the `from` type can be directly asssigned +// to the `into` type. This always returns false if the decoder is set to make +// copies or into contains an interface{} type anywhere in its type definition +// to allow for decoding interfaces{} into a set of known types. +func (d *decoder) canDirectlyAssign(from reflect.Type, into reflect.Type) bool { + if d.makeCopy { + return false + } + if from != into { + return false + } + return !containsAny(into) +} + +// containsAny recursively traverses through into, returning true if it +// contains an interface{} value anywhere in its structure. +func containsAny(into reflect.Type) bool { + // TODO(rfratto): cache result of this function? + + if into == goAny { + return true + } + + switch into.Kind() { + case reflect.Array, reflect.Pointer, reflect.Slice: + return containsAny(into.Elem()) + case reflect.Map: + if into.Key() == goString { + return containsAny(into.Elem()) + } + return false + + case reflect.Struct: + for i := 0; i < into.NumField(); i++ { + if containsAny(into.Field(i).Type) { + return true + } + } + return false + + default: + // Other kinds are not River types where the decodeAny check applies. + return false + } +} + +func (d *decoder) decodeFromInterface(val Value, into reflect.Value) (ok bool, err error) { + // into may only implement interface types for a pointer receiver, so we want + // to address into if possible. + if into.CanAddr() { + into = into.Addr() + } + + switch { + case into.Type() == goDurationPtr: + var s string + err := d.decode(val, reflect.ValueOf(&s)) + if err != nil { + return true, err + } + dur, err := time.ParseDuration(s) + if err != nil { + return true, Error{Value: val, Inner: err} + } + *into.Interface().(*time.Duration) = dur + return true, nil + + case into.Type().Implements(goRiverDecoder): + err := into.Interface().(Unmarshaler).UnmarshalRiver(func(v interface{}) error { + return d.decode(val, reflect.ValueOf(v)) + }) + if err != nil { + // TODO(rfratto): we need to detect if error is one of the error types + // from this package and only wrap it in an Error if it isn't. + return true, Error{Value: val, Inner: err} + } + return true, nil + + case into.Type().Implements(goTextUnmarshaler): + var s string + err := d.decode(val, reflect.ValueOf(&s)) + if err != nil { + return true, err + } + err = into.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(s)) + if err != nil { + return true, Error{Value: val, Inner: err} + } + return true, nil + } + + return false, nil +} + +func tryCapsuleConvert(from Value, into reflect.Value, intoType Type) (ok bool, err error) { + // Check to see if we can use capsule conversion. + if from.Type() == TypeCapsule { + cc, ok := from.Interface().(ConvertibleIntoCapsule) + if ok { + // It's always possible to Addr the reflect.Value below since we expect + // it to be a settable non-pointer value. + err := cc.ConvertInto(into.Addr().Interface()) + if err == nil { + return true, nil + } else if err != nil && !errors.Is(err, ErrNoConversion) { + return false, Error{Value: from, Inner: err} + } + } + } + + if intoType == TypeCapsule { + cc, ok := into.Addr().Interface().(ConvertibleFromCapsule) + if ok { + err := cc.ConvertFrom(from.Interface()) + if err == nil { + return true, nil + } else if err != nil && !errors.Is(err, ErrNoConversion) { + return false, Error{Value: from, Inner: err} + } + } + } + + // Last attempt: allow converting two capsules if the Go types are compatible + // and the into kind is an interface. + // + // TODO(rfratto): we may consider expanding this to allowing conversion to + // any compatible Go type in the future (not just interfaces). + if from.Type() == TypeCapsule && intoType == TypeCapsule && into.Kind() == reflect.Interface { + // We try to convert a pointer to from first to avoid making unnecessary + // copies. + if from.Reflect().CanAddr() && from.Reflect().Addr().CanConvert(into.Type()) { + val := from.Reflect().Addr().Convert(into.Type()) + into.Set(val) + return true, nil + } else if from.Reflect().CanConvert(into.Type()) { + val := from.Reflect().Convert(into.Type()) + into.Set(val) + return true, nil + } + } + + return false, nil +} + +// decodeAny is invoked by decode when into is an interface{}. We assign the +// interface{} a known type based on the River value being decoded: +// +// Null values: nil +// Number values: float64, int, int64, or uint64. +// If the underlying type is a float, always decode to a float64. +// For non-floats the order of preference is int -> int64 -> uint64. +// Arrays: []interface{} +// Objects: map[string]interface{} +// Bool: bool +// String: string +// Function: Passthrough of the underlying function value +// Capsule: Passthrough of the underlying capsule value +// +// In the cases where we do not pass through the underlying value, we create a +// value of that type, recursively call decode to populate that new value, and +// then store that value into the interface{}. +func (d *decoder) decodeAny(val Value, into reflect.Value) error { + var ptr reflect.Value + + switch val.Type() { + case TypeNull: + into.Set(reflect.Zero(into.Type())) + return nil + + case TypeNumber: + + switch val.Number().Kind() { + case NumberKindFloat: + var v float64 + ptr = reflect.ValueOf(&v) + case NumberKindUint: + uint64Val := val.Uint() + if uint64Val <= math.MaxInt { + var v int + ptr = reflect.ValueOf(&v) + } else if uint64Val <= math.MaxInt64 { + var v int64 + ptr = reflect.ValueOf(&v) + } else { + var v uint64 + ptr = reflect.ValueOf(&v) + } + case NumberKindInt: + int64Val := val.Int() + if math.MinInt <= int64Val && int64Val <= math.MaxInt { + var v int + ptr = reflect.ValueOf(&v) + } else { + var v int64 + ptr = reflect.ValueOf(&v) + } + + default: + panic("river/value: unreachable") + } + + case TypeArray: + var v []interface{} + ptr = reflect.ValueOf(&v) + + case TypeObject: + var v map[string]interface{} + ptr = reflect.ValueOf(&v) + + case TypeBool: + var v bool + ptr = reflect.ValueOf(&v) + + case TypeString: + var v string + ptr = reflect.ValueOf(&v) + + case TypeFunction, TypeCapsule: + // Functions and capsules must be directly assigned since there's no + // "generic" representation for either. + // + // We retain the pointer if we were given a pointer. + + if val.rv.CanAddr() { + into.Set(val.rv.Addr()) + return nil + } + + into.Set(val.rv) + return nil + + default: + panic("river/value: unreachable") + } + + if err := d.decode(val, ptr); err != nil { + return err + } + into.Set(ptr.Elem()) + return nil +} + +func (d *decoder) decodeArray(val Value, rt reflect.Value) error { + switch rt.Kind() { + case reflect.Slice: + res := reflect.MakeSlice(rt.Type(), val.Len(), val.Len()) + for i := 0; i < val.Len(); i++ { + // Decode the original elements into the new elements. + if err := d.decode(val.Index(i), res.Index(i)); err != nil { + return ElementError{Value: val, Index: i, Inner: err} + } + } + rt.Set(res) + + case reflect.Array: + res := reflect.New(rt.Type()).Elem() + + if val.Len() != res.Len() { + return Error{ + Value: val, + Inner: fmt.Errorf("array must have exactly %d elements, got %d", res.Len(), val.Len()), + } + } + + for i := 0; i < val.Len(); i++ { + if err := d.decode(val.Index(i), res.Index(i)); err != nil { + return ElementError{Value: val, Index: i, Inner: err} + } + } + rt.Set(res) + + default: + panic(fmt.Sprintf("river/value: unexpected array type %s", val.rv.Kind())) + } + + return nil +} + +func (d *decoder) decodeObject(val Value, rt reflect.Value) error { + switch rt.Kind() { + case reflect.Struct: + targetTags := getCachedTags(rt.Type()) + return d.decodeObjectToStruct(val, rt, targetTags, false) + + case reflect.Slice, reflect.Array: // Slice of labeled blocks + keys := val.Keys() + + var res reflect.Value + + if rt.Kind() == reflect.Slice { + res = reflect.MakeSlice(rt.Type(), len(keys), len(keys)) + } else { // Array + res = reflect.New(rt.Type()).Elem() + + if res.Len() != len(keys) { + return Error{ + Value: val, + Inner: fmt.Errorf("object must have exactly %d keys, got %d", res.Len(), len(keys)), + } + } + } + + fields := getCachedTags(rt.Type().Elem()) + labelField, _ := fields.LabelField() + + for i, key := range keys { + // First decode the key into the label. + elem := res.Index(i) + reflectutil.GetOrAlloc(elem, labelField).Set(reflect.ValueOf(key)) + + // Now decode the inner object. + value, _ := val.Key(key) + if err := d.decodeObjectToStruct(value, elem, fields, true); err != nil { + return FieldError{Value: val, Field: key, Inner: err} + } + } + rt.Set(res) + + case reflect.Map: + if rt.Type().Key() != goString { + // Maps with non-string types are treated as capsules and can't be + // decoded from maps. + return TypeError{Value: val, Expected: RiverType(rt.Type())} + } + + res := reflect.MakeMapWithSize(rt.Type(), val.Len()) + + // Create a shared value to decode each element into. This will be zeroed + // out for each key, and then copied when setting the map index. + into := reflect.New(rt.Type().Elem()).Elem() + intoZero := reflect.Zero(into.Type()) + + for i, key := range val.Keys() { + // We ignore the ok value because we know it exists. + value, _ := val.Key(key) + + // Zero out the value if it was decoded in the previous loop. + if i > 0 { + into.Set(intoZero) + } + // Decode into our element. + if err := d.decode(value, into); err != nil { + return FieldError{Value: val, Field: key, Inner: err} + } + + // Then set the map index. + res.SetMapIndex(reflect.ValueOf(key), into) + } + + rt.Set(res) + + default: + panic(fmt.Sprintf("river/value: unexpected target type %s", rt.Kind())) + } + + return nil +} + +func (d *decoder) decodeObjectToStruct(val Value, rt reflect.Value, fields *objectFields, decodedLabel bool) error { + // TODO(rfratto): this needs to check for required keys being set + + for _, key := range val.Keys() { + // We ignore the ok value because we know it exists. + value, _ := val.Key(key) + + // Struct labels should be decoded first, since objects are wrapped in + // labels. If we have yet to decode the label, decode it now. + if lf, ok := fields.LabelField(); ok && !decodedLabel { + // Safety check: if the inner field isn't an object, there's something + // wrong here. It's unclear if a user can craft an expression that hits + // this case, but it's left in for safety. + if value.Type() != TypeObject { + return FieldError{ + Value: val, + Field: key, + Inner: TypeError{Value: value, Expected: TypeObject}, + } + } + + // Decode the key into the label. + reflectutil.GetOrAlloc(rt, lf).Set(reflect.ValueOf(key)) + + // ...and then code the rest of the object. + if err := d.decodeObjectToStruct(value, rt, fields, true); err != nil { + return err + } + continue + } + + switch fields.Has(key) { + case objectKeyTypeInvalid: + return MissingKeyError{Value: value, Missing: key} + case objectKeyTypeNestedField: // Block with multiple name fragments + next, _ := fields.NestedField(key) + // Recurse the call with the inner value. + if err := d.decodeObjectToStruct(value, rt, next, decodedLabel); err != nil { + return err + } + case objectKeyTypeField: // Single-name fragment + targetField, _ := fields.Field(key) + targetValue := reflectutil.GetOrAlloc(rt, targetField) + + if err := d.decode(value, targetValue); err != nil { + return FieldError{Value: val, Field: key, Inner: err} + } + } + } + + return nil +} diff --git a/syntax/internal/value/decode_benchmarks_test.go b/syntax/internal/value/decode_benchmarks_test.go new file mode 100644 index 0000000000..9a33239329 --- /dev/null +++ b/syntax/internal/value/decode_benchmarks_test.go @@ -0,0 +1,90 @@ +package value_test + +import ( + "fmt" + "testing" + + "github.com/grafana/river/internal/value" +) + +func BenchmarkObjectDecode(b *testing.B) { + b.StopTimer() + + // Create a value with 20 keys. + source := make(map[string]string, 20) + for i := 0; i < 20; i++ { + var ( + key = fmt.Sprintf("key_%d", i+1) + value = fmt.Sprintf("value_%d", i+1) + ) + source[key] = value + } + + sourceVal := value.Encode(source) + + b.StartTimer() + for i := 0; i < b.N; i++ { + var dst map[string]string + _ = value.Decode(sourceVal, &dst) + } +} + +func BenchmarkObject(b *testing.B) { + b.Run("Non-capsule", func(b *testing.B) { + b.StopTimer() + + vals := make(map[string]value.Value) + for i := 0; i < 20; i++ { + vals[fmt.Sprintf("%d", i)] = value.Int(int64(i)) + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + _ = value.Object(vals) + } + }) + + b.Run("Capsule", func(b *testing.B) { + b.StopTimer() + + vals := make(map[string]value.Value) + for i := 0; i < 20; i++ { + vals[fmt.Sprintf("%d", i)] = value.Encapsulate(make(chan int)) + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + _ = value.Object(vals) + } + }) +} + +func BenchmarkArray(b *testing.B) { + b.Run("Non-capsule", func(b *testing.B) { + b.StopTimer() + + var vals []value.Value + for i := 0; i < 20; i++ { + vals = append(vals, value.Int(int64(i))) + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + _ = value.Array(vals...) + } + }) + + b.Run("Capsule", func(b *testing.B) { + b.StopTimer() + + var vals []value.Value + for i := 0; i < 20; i++ { + vals = append(vals, value.Encapsulate(make(chan int))) + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + _ = value.Array(vals...) + } + }) +} diff --git a/syntax/internal/value/decode_test.go b/syntax/internal/value/decode_test.go new file mode 100644 index 0000000000..5b84838bd0 --- /dev/null +++ b/syntax/internal/value/decode_test.go @@ -0,0 +1,761 @@ +package value_test + +import ( + "fmt" + "math" + "reflect" + "testing" + "time" + "unsafe" + + "github.com/grafana/river/internal/value" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecode_Numbers(t *testing.T) { + // There's a lot of values that can represent numbers, so we construct a + // matrix dynamically of all the combinations here. + vals := []interface{}{ + int(15), int8(15), int16(15), int32(15), int64(15), + uint(15), uint8(15), uint16(15), uint32(15), uint64(15), + float32(15), float64(15), + string("15"), // string holding a valid number (which can be converted to a number) + } + + for _, input := range vals { + for _, expect := range vals { + val := value.Encode(input) + + name := fmt.Sprintf( + "%s to %s", + reflect.TypeOf(input), + reflect.TypeOf(expect), + ) + + t.Run(name, func(t *testing.T) { + vPtr := reflect.New(reflect.TypeOf(expect)).Interface() + require.NoError(t, value.Decode(val, vPtr)) + + actual := reflect.ValueOf(vPtr).Elem().Interface() + require.Equal(t, expect, actual) + }) + } + } +} + +func TestDecode(t *testing.T) { + // Declare some types to use for testing. Person2 is used as a struct + // equivalent to Person, but with a different Go type to force casting. + type Person struct { + Name string `river:"name,attr"` + } + + type Person2 struct { + Name string `river:"name,attr"` + } + + tt := []struct { + input, expect interface{} + }{ + {nil, (*int)(nil)}, + + // Non-number primitives. + {string("Hello!"), string("Hello!")}, + {bool(true), bool(true)}, + + // Arrays + {[]int{1, 2, 3}, []int{1, 2, 3}}, + {[]int{1, 2, 3}, [...]int{1, 2, 3}}, + {[...]int{1, 2, 3}, []int{1, 2, 3}}, + {[...]int{1, 2, 3}, [...]int{1, 2, 3}}, + + // Maps + {map[string]int{"year": 2022}, map[string]uint{"year": 2022}}, + {map[string]string{"name": "John"}, map[string]string{"name": "John"}}, + {map[string]string{"name": "John"}, Person{Name: "John"}}, + {Person{Name: "John"}, map[string]string{"name": "John"}}, + {Person{Name: "John"}, Person{Name: "John"}}, + {Person{Name: "John"}, Person2{Name: "John"}}, + {Person2{Name: "John"}, Person{Name: "John"}}, + + // NOTE(rfratto): we don't test capsules or functions here because they're + // not comparable in the same way as we do the other tests. + // + // See TestDecode_Functions and TestDecode_Capsules for specific decoding + // tests of those types. + } + + for _, tc := range tt { + val := value.Encode(tc.input) + + name := fmt.Sprintf( + "%s (%s) to %s", + val.Type(), + reflect.TypeOf(tc.input), + reflect.TypeOf(tc.expect), + ) + + t.Run(name, func(t *testing.T) { + vPtr := reflect.New(reflect.TypeOf(tc.expect)).Interface() + require.NoError(t, value.Decode(val, vPtr)) + + actual := reflect.ValueOf(vPtr).Elem().Interface() + + require.Equal(t, tc.expect, actual) + }) + } +} + +// TestDecode_PreservePointer ensures that pointer addresses can be preserved +// when decoding. +func TestDecode_PreservePointer(t *testing.T) { + num := 5 + val := value.Encode(&num) + + var nump *int + require.NoError(t, value.Decode(val, &nump)) + require.Equal(t, unsafe.Pointer(nump), unsafe.Pointer(&num)) +} + +// TestDecode_PreserveMapReference ensures that map references can be preserved +// when decoding. +func TestDecode_PreserveMapReference(t *testing.T) { + m := make(map[string]string) + val := value.Encode(m) + + var actual map[string]string + require.NoError(t, value.Decode(val, &actual)) + + // We can't check to see if the pointers of m and actual match, but we can + // modify m to see if actual is also modified. + m["foo"] = "bar" + require.Equal(t, "bar", actual["foo"]) +} + +// TestDecode_PreserveSliceReference ensures that slice references can be +// preserved when decoding. +func TestDecode_PreserveSliceReference(t *testing.T) { + s := make([]string, 3) + val := value.Encode(s) + + var actual []string + require.NoError(t, value.Decode(val, &actual)) + + // We can't check to see if the pointers of m and actual match, but we can + // modify s to see if actual is also modified. + s[0] = "Hello, world!" + require.Equal(t, "Hello, world!", actual[0]) +} +func TestDecode_Functions(t *testing.T) { + val := value.Encode(func() int { return 15 }) + + var f func() int + require.NoError(t, value.Decode(val, &f)) + require.Equal(t, 15, f()) +} + +func TestDecode_Capsules(t *testing.T) { + expect := make(chan int, 5) + + var actual chan int + require.NoError(t, value.Decode(value.Encode(expect), &actual)) + require.Equal(t, expect, actual) +} + +type ValueInterface interface{ SomeMethod() } + +type Value1 struct{ test string } + +func (c Value1) SomeMethod() {} + +// TestDecode_CapsuleInterface tests that we are able to decode when +// the target `into` is an interface. +func TestDecode_CapsuleInterface(t *testing.T) { + tt := []struct { + name string + value ValueInterface + expected ValueInterface + }{ + { + name: "Capsule to Capsule", + value: Value1{test: "true"}, + expected: Value1{test: "true"}, + }, + { + name: "Capsule Pointer to Capsule", + value: &Value1{test: "true"}, + expected: &Value1{test: "true"}, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + var actual ValueInterface + require.NoError(t, value.Decode(value.Encode(tc.value), &actual)) + + // require.Same validates the memory address matches after Decode. + if reflect.TypeOf(tc.value).Kind() == reflect.Pointer { + require.Same(t, tc.value, actual) + } + + // We use tc.expected to validate the properties of actual match the + // original tc.value properties (nothing has mutated them during the test). + require.Equal(t, tc.expected, actual) + }) + } +} + +// TestDecode_CapsulesError tests that we are unable to decode when +// the target `into` is not an interface. +func TestDecode_CapsulesError(t *testing.T) { + type Capsule1 struct{ test string } + type Capsule2 Capsule1 + + v := Capsule1{test: "true"} + actual := Capsule2{} + + require.EqualError(t, value.Decode(value.Encode(v), &actual), `expected capsule("value_test.Capsule2"), got capsule("value_test.Capsule1")`) +} + +// TestDecodeCopy_SliceCopy ensures that copies are made during decoding +// instead of setting values directly. +func TestDecodeCopy_SliceCopy(t *testing.T) { + orig := []int{1, 2, 3} + + var res []int + require.NoError(t, value.DecodeCopy(value.Encode(orig), &res)) + + res[0] = 10 + require.Equal(t, []int{1, 2, 3}, orig, "Original slice should not have been modified") +} + +// TestDecodeCopy_ArrayCopy ensures that copies are made during decoding +// instead of setting values directly. +func TestDecode_ArrayCopy(t *testing.T) { + orig := [...]int{1, 2, 3} + + var res [3]int + require.NoError(t, value.DecodeCopy(value.Encode(orig), &res)) + + res[0] = 10 + require.Equal(t, [3]int{1, 2, 3}, orig, "Original array should not have been modified") +} + +func TestDecode_CustomTypes(t *testing.T) { + t.Run("object to Unmarshaler", func(t *testing.T) { + var actual customUnmarshaler + require.NoError(t, value.Decode(value.Object(nil), &actual)) + require.True(t, actual.UnmarshalCalled, "UnmarshalRiver was not invoked") + require.True(t, actual.DefaultCalled, "SetToDefault was not invoked") + require.True(t, actual.ValidateCalled, "Validate was not invoked") + }) + + t.Run("TextMarshaler to TextUnmarshaler", func(t *testing.T) { + now := time.Now() + + var actual time.Time + require.NoError(t, value.Decode(value.Encode(now), &actual)) + require.True(t, now.Equal(actual)) + }) + + t.Run("time.Duration to time.Duration", func(t *testing.T) { + dur := 15 * time.Second + + var actual time.Duration + require.NoError(t, value.Decode(value.Encode(dur), &actual)) + require.Equal(t, dur, actual) + }) + + t.Run("string to TextUnmarshaler", func(t *testing.T) { + now := time.Now() + nowBytes, _ := now.MarshalText() + + var actual time.Time + require.NoError(t, value.Decode(value.String(string(nowBytes)), &actual)) + + actualBytes, _ := actual.MarshalText() + require.Equal(t, nowBytes, actualBytes) + }) + + t.Run("string to time.Duration", func(t *testing.T) { + dur := 15 * time.Second + + var actual time.Duration + require.NoError(t, value.Decode(value.String(dur.String()), &actual)) + require.Equal(t, dur.String(), actual.String()) + }) +} + +type customUnmarshaler struct { + UnmarshalCalled bool `river:"unmarshal_called,attr,optional"` + DefaultCalled bool `river:"default_called,attr,optional"` + ValidateCalled bool `river:"validate_called,attr,optional"` +} + +func (cu *customUnmarshaler) UnmarshalRiver(f func(interface{}) error) error { + cu.UnmarshalCalled = true + return f((*customUnmarshalerTarget)(cu)) +} + +type customUnmarshalerTarget customUnmarshaler + +func (s *customUnmarshalerTarget) SetToDefault() { + s.DefaultCalled = true +} + +func (s *customUnmarshalerTarget) Validate() error { + s.ValidateCalled = true + return nil +} + +type textEnumType bool + +func (et *textEnumType) UnmarshalText(text []byte) error { + *et = false + + switch string(text) { + case "accepted_value": + *et = true + return nil + default: + return fmt.Errorf("unrecognized value %q", string(text)) + } +} + +func TestDecode_TextUnmarshaler(t *testing.T) { + t.Run("valid type and value", func(t *testing.T) { + var et textEnumType + require.NoError(t, value.Decode(value.String("accepted_value"), &et)) + require.Equal(t, textEnumType(true), et) + }) + + t.Run("invalid type", func(t *testing.T) { + var et textEnumType + err := value.Decode(value.Bool(true), &et) + require.EqualError(t, err, "expected string, got bool") + }) + + t.Run("invalid value", func(t *testing.T) { + var et textEnumType + err := value.Decode(value.String("bad_value"), &et) + require.EqualError(t, err, `unrecognized value "bad_value"`) + }) + + t.Run("unmarshaler nested in other value", func(t *testing.T) { + input := value.Array( + value.String("accepted_value"), + value.String("accepted_value"), + value.String("accepted_value"), + ) + + var ett []textEnumType + require.NoError(t, value.Decode(input, &ett)) + require.Equal(t, []textEnumType{true, true, true}, ett) + }) +} + +func TestDecode_ErrorChain(t *testing.T) { + type Target struct { + Key struct { + Object struct { + Field1 []int `river:"field1,attr"` + } `river:"object,attr"` + } `river:"key,attr"` + } + + val := value.Object(map[string]value.Value{ + "key": value.Object(map[string]value.Value{ + "object": value.Object(map[string]value.Value{ + "field1": value.Array( + value.Int(15), + value.Int(30), + value.String("Hello, world!"), + ), + }), + }), + }) + + // NOTE(rfratto): strings of errors from the value package are fairly limited + // in the amount of information they show, since the value package doesn't + // have a great way to pretty-print the chain of errors. + // + // For example, with the error below, the message doesn't explain where the + // string is coming from, even though the error values hold that context. + // + // Callers consuming errors should print the error chain with extra context + // so it's more useful to users. + err := value.Decode(val, &Target{}) + expectErr := `expected number, got string` + require.EqualError(t, err, expectErr) +} + +type boolish int + +var _ value.ConvertibleFromCapsule = (*boolish)(nil) +var _ value.ConvertibleIntoCapsule = (boolish)(0) + +func (b boolish) RiverCapsule() {} + +func (b *boolish) ConvertFrom(src interface{}) error { + switch v := src.(type) { + case bool: + if v { + *b = 1 + } else { + *b = 0 + } + return nil + } + + return value.ErrNoConversion +} + +func (b boolish) ConvertInto(dst interface{}) error { + switch d := dst.(type) { + case *bool: + if b == 0 { + *d = false + } else { + *d = true + } + return nil + } + + return value.ErrNoConversion +} + +func TestDecode_CustomConvert(t *testing.T) { + t.Run("compatible type to custom", func(t *testing.T) { + var b boolish + err := value.Decode(value.Bool(true), &b) + require.NoError(t, err) + require.Equal(t, boolish(1), b) + }) + + t.Run("custom to compatible type", func(t *testing.T) { + var b bool + err := value.Decode(value.Encapsulate(boolish(10)), &b) + require.NoError(t, err) + require.Equal(t, true, b) + }) + + t.Run("incompatible type to custom", func(t *testing.T) { + var b boolish + err := value.Decode(value.String("true"), &b) + require.EqualError(t, err, "expected capsule, got string") + }) + + t.Run("custom to incompatible type", func(t *testing.T) { + src := boolish(10) + + var s string + err := value.Decode(value.Encapsulate(&src), &s) + require.EqualError(t, err, "expected string, got capsule") + }) +} + +func TestDecode_SquashedFields(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr,optional"` + InnerField2 string `river:"inner_field_2,attr,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner InnerStruct `river:",squash"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + var ( + in = map[string]string{ + "outer_field_1": "value1", + "outer_field_2": "value2", + "inner_field_1": "value3", + "inner_field_2": "value4", + } + expect = OuterStruct{ + OuterField1: "value1", + Inner: InnerStruct{ + InnerField1: "value3", + InnerField2: "value4", + }, + OuterField2: "value2", + } + ) + + var out OuterStruct + err := value.Decode(value.Encode(in), &out) + require.NoError(t, err) + require.Equal(t, expect, out) +} + +func TestDecode_SquashedFields_Pointer(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr,optional"` + InnerField2 string `river:"inner_field_2,attr,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner *InnerStruct `river:",squash"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + var ( + in = map[string]string{ + "outer_field_1": "value1", + "outer_field_2": "value2", + "inner_field_1": "value3", + "inner_field_2": "value4", + } + expect = OuterStruct{ + OuterField1: "value1", + Inner: &InnerStruct{ + InnerField1: "value3", + InnerField2: "value4", + }, + OuterField2: "value2", + } + ) + + var out OuterStruct + err := value.Decode(value.Encode(in), &out) + require.NoError(t, err) + require.Equal(t, expect, out) +} + +func TestDecode_Slice(t *testing.T) { + type Block struct { + Attr int `river:"attr,attr"` + } + + type Struct struct { + Blocks []Block `river:"block.a,block,optional"` + } + + var ( + in = map[string]interface{}{ + "block": map[string]interface{}{ + "a": []map[string]interface{}{ + {"attr": 1}, + {"attr": 2}, + {"attr": 3}, + {"attr": 4}, + }, + }, + } + expect = Struct{ + Blocks: []Block{ + {Attr: 1}, + {Attr: 2}, + {Attr: 3}, + {Attr: 4}, + }, + } + ) + + var out Struct + err := value.Decode(value.Encode(in), &out) + require.NoError(t, err) + require.Equal(t, expect, out) +} + +func TestDecode_SquashedSlice(t *testing.T) { + type Block struct { + Attr int `river:"attr,attr"` + } + + type InnerStruct struct { + BlockA Block `river:"a,block,optional"` + BlockB Block `river:"b,block,optional"` + BlockC Block `river:"c,block,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner []InnerStruct `river:"block,enum"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + var ( + in = map[string]interface{}{ + "outer_field_1": "value1", + "outer_field_2": "value2", + + "block": []map[string]interface{}{ + {"a": map[string]interface{}{"attr": 1}}, + {"b": map[string]interface{}{"attr": 2}}, + {"c": map[string]interface{}{"attr": 3}}, + {"a": map[string]interface{}{"attr": 4}}, + }, + } + expect = OuterStruct{ + OuterField1: "value1", + OuterField2: "value2", + + Inner: []InnerStruct{ + {BlockA: Block{Attr: 1}}, + {BlockB: Block{Attr: 2}}, + {BlockC: Block{Attr: 3}}, + {BlockA: Block{Attr: 4}}, + }, + } + ) + + var out OuterStruct + err := value.Decode(value.Encode(in), &out) + require.NoError(t, err) + require.Equal(t, expect, out) +} + +func TestDecode_SquashedSlice_Pointer(t *testing.T) { + type Block struct { + Attr int `river:"attr,attr"` + } + + type InnerStruct struct { + BlockA *Block `river:"a,block,optional"` + BlockB *Block `river:"b,block,optional"` + BlockC *Block `river:"c,block,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner []InnerStruct `river:"block,enum"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + var ( + in = map[string]interface{}{ + "outer_field_1": "value1", + "outer_field_2": "value2", + + "block": []map[string]interface{}{ + {"a": map[string]interface{}{"attr": 1}}, + {"b": map[string]interface{}{"attr": 2}}, + {"c": map[string]interface{}{"attr": 3}}, + {"a": map[string]interface{}{"attr": 4}}, + }, + } + expect = OuterStruct{ + OuterField1: "value1", + OuterField2: "value2", + + Inner: []InnerStruct{ + {BlockA: &Block{Attr: 1}}, + {BlockB: &Block{Attr: 2}}, + {BlockC: &Block{Attr: 3}}, + {BlockA: &Block{Attr: 4}}, + }, + } + ) + + var out OuterStruct + err := value.Decode(value.Encode(in), &out) + require.NoError(t, err) + require.Equal(t, expect, out) +} + +// TestDecode_KnownTypes_Any asserts that decoding River values into an +// any/interface{} results in known types. +func TestDecode_KnownTypes_Any(t *testing.T) { + tt := []struct { + input any + expect any + }{ + // expect "int" + {int(0), 0}, + {int(-1), -1}, + {int(15), 15}, + {int8(15), 15}, + {int16(15), 15}, + {int32(15), 15}, + {int64(15), 15}, + {uint(0), 0}, + {uint(15), 15}, + {uint8(15), 15}, + {uint16(15), 15}, + {uint32(15), 15}, + {uint64(15), 15}, + {int64(math.MinInt64), math.MinInt64}, + {int64(math.MaxInt64), math.MaxInt64}, + // expect "uint" + {uint64(math.MaxInt64 + 1), uint64(math.MaxInt64 + 1)}, + {uint64(math.MaxUint64), uint64(math.MaxUint64)}, + // expect "float" + {float32(2.5), float64(2.5)}, + {float64(2.5), float64(2.5)}, + {float64(math.MinInt64) - 10, float64(math.MinInt64) - 10}, + {float64(math.MaxInt64) + 10, float64(math.MaxInt64) + 10}, + + {bool(true), bool(true)}, + {string("Hello"), string("Hello")}, + + { + input: []int{1, 2, 3}, + expect: []any{1, 2, 3}, + }, + + { + input: map[string]int{"number": 15}, + expect: map[string]any{"number": 15}, + }, + { + input: struct { + Name string `river:"name,attr"` + }{Name: "John"}, + + expect: map[string]any{"name": "John"}, + }, + } + + t.Run("basic types", func(t *testing.T) { + for _, tc := range tt { + var actual any + err := value.Decode(value.Encode(tc.input), &actual) + + if assert.NoError(t, err) { + assert.Equal(t, tc.expect, actual, + "Expected %[1]v (%[1]T) to transcode to %[2]v (%[2]T)", tc.input, tc.expect) + } + } + }) + + t.Run("inside maps", func(t *testing.T) { + for _, tc := range tt { + input := map[string]any{ + "key": tc.input, + } + + var actual map[string]any + err := value.Decode(value.Encode(input), &actual) + + if assert.NoError(t, err) { + assert.Equal(t, tc.expect, actual["key"], + "Expected %[1]v (%[1]T) to transcode to %[2]v (%[2]T) inside a map", tc.input, tc.expect) + } + } + }) +} + +func TestRetainCapsulePointer(t *testing.T) { + capsuleVal := &capsule{} + + in := map[string]any{ + "foo": capsuleVal, + } + + var actual map[string]any + err := value.Decode(value.Encode(in), &actual) + require.NoError(t, err) + + expect := map[string]any{ + "foo": capsuleVal, + } + require.Equal(t, expect, actual) +} + +type capsule struct{} + +func (*capsule) RiverCapsule() {} diff --git a/syntax/internal/value/errors.go b/syntax/internal/value/errors.go new file mode 100644 index 0000000000..79f22378b3 --- /dev/null +++ b/syntax/internal/value/errors.go @@ -0,0 +1,107 @@ +package value + +import "fmt" + +// Error is used for reporting on a value-level error. It is the most general +// type of error for a value. +type Error struct { + Value Value + Inner error +} + +// TypeError is used for reporting on a value having an unexpected type. +type TypeError struct { + // Value which caused the error. + Value Value + Expected Type +} + +// Error returns the string form of the TypeError. +func (te TypeError) Error() string { + return fmt.Sprintf("expected %s, got %s", te.Expected, te.Value.Type()) +} + +// Error returns the message of the decode error. +func (de Error) Error() string { return de.Inner.Error() } + +// MissingKeyError is used for reporting that a value is missing a key. +type MissingKeyError struct { + Value Value + Missing string +} + +// Error returns the string form of the MissingKeyError. +func (mke MissingKeyError) Error() string { + return fmt.Sprintf("key %q does not exist", mke.Missing) +} + +// ElementError is used to report on an error inside of an array. +type ElementError struct { + Value Value // The Array value + Index int // The index of the element with the issue + Inner error // The error from the element +} + +// Error returns the text of the inner error. +func (ee ElementError) Error() string { return ee.Inner.Error() } + +// FieldError is used to report on an invalid field inside an object. +type FieldError struct { + Value Value // The Object value + Field string // The field name with the issue + Inner error // The error from the field +} + +// Error returns the text of the inner error. +func (fe FieldError) Error() string { return fe.Inner.Error() } + +// ArgError is used to report on an invalid argument to a function. +type ArgError struct { + Function Value + Argument Value + Index int + Inner error +} + +// Error returns the text of the inner error. +func (ae ArgError) Error() string { return ae.Inner.Error() } + +// WalkError walks err for all value-related errors in this package. +// WalkError returns false if err is not an error from this package. +func WalkError(err error, f func(err error)) bool { + var foundOne bool + + nextError := err + for nextError != nil { + switch ne := nextError.(type) { + case Error: + f(nextError) + nextError = ne.Inner + foundOne = true + case TypeError: + f(nextError) + nextError = nil + foundOne = true + case MissingKeyError: + f(nextError) + nextError = nil + foundOne = true + case ElementError: + f(nextError) + nextError = ne.Inner + foundOne = true + case FieldError: + f(nextError) + nextError = ne.Inner + foundOne = true + case ArgError: + f(nextError) + nextError = ne.Inner + foundOne = true + default: + nextError = nil + } + } + + return foundOne +} diff --git a/syntax/internal/value/number_value.go b/syntax/internal/value/number_value.go new file mode 100644 index 0000000000..c40fbbc802 --- /dev/null +++ b/syntax/internal/value/number_value.go @@ -0,0 +1,135 @@ +package value + +import ( + "math" + "reflect" + "strconv" +) + +var ( + nativeIntBits = reflect.TypeOf(int(0)).Bits() + nativeUintBits = reflect.TypeOf(uint(0)).Bits() +) + +// NumberKind categorizes a type of Go number. +type NumberKind uint8 + +const ( + // NumberKindInt represents an int-like type (e.g., int, int8, etc.). + NumberKindInt NumberKind = iota + // NumberKindUint represents a uint-like type (e.g., uint, uint8, etc.). + NumberKindUint + // NumberKindFloat represents both float32 and float64. + NumberKindFloat +) + +// makeNumberKind converts a Go kind to a River kind. +func makeNumberKind(k reflect.Kind) NumberKind { + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return NumberKindInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return NumberKindUint + case reflect.Float32, reflect.Float64: + return NumberKindFloat + default: + panic("river/value: makeNumberKind called with unsupported Kind value") + } +} + +// Number is a generic representation of Go numbers. It is intended to be +// created on the fly for numerical operations when the real number type is not +// known. +type Number struct { + // Value holds the raw data for the number. Note that for numberKindFloat, + // value is the raw bits of the float64 and must be converted back to a + // float64 before it can be used. + value uint64 + + bits uint8 // 8, 16, 32, 64, used for overflow checking + k NumberKind // int, uint, float +} + +func newNumberValue(v reflect.Value) Number { + var ( + val uint64 + bits int + nk NumberKind + ) + + switch v.Kind() { + case reflect.Int: + val, bits, nk = uint64(v.Int()), nativeIntBits, NumberKindInt + case reflect.Int8: + val, bits, nk = uint64(v.Int()), 8, NumberKindInt + case reflect.Int16: + val, bits, nk = uint64(v.Int()), 16, NumberKindInt + case reflect.Int32: + val, bits, nk = uint64(v.Int()), 32, NumberKindInt + case reflect.Int64: + val, bits, nk = uint64(v.Int()), 64, NumberKindInt + case reflect.Uint: + val, bits, nk = v.Uint(), nativeUintBits, NumberKindUint + case reflect.Uint8: + val, bits, nk = v.Uint(), 8, NumberKindUint + case reflect.Uint16: + val, bits, nk = v.Uint(), 16, NumberKindUint + case reflect.Uint32: + val, bits, nk = v.Uint(), 32, NumberKindUint + case reflect.Uint64: + val, bits, nk = v.Uint(), 64, NumberKindUint + case reflect.Float32: + val, bits, nk = math.Float64bits(v.Float()), 32, NumberKindFloat + case reflect.Float64: + val, bits, nk = math.Float64bits(v.Float()), 64, NumberKindFloat + default: + panic("river/value: unrecognized Go number type " + v.Kind().String()) + } + + return Number{val, uint8(bits), nk} +} + +// Kind returns the Number's NumberKind. +func (nv Number) Kind() NumberKind { return nv.k } + +// Int converts the Number into an int64. +func (nv Number) Int() int64 { + if nv.k == NumberKindFloat { + return int64(math.Float64frombits(nv.value)) + } + return int64(nv.value) +} + +// Uint converts the Number into a uint64. +func (nv Number) Uint() uint64 { + if nv.k == NumberKindFloat { + return uint64(math.Float64frombits(nv.value)) + } + return nv.value +} + +// Float converts the Number into a float64. +func (nv Number) Float() float64 { + switch nv.k { + case NumberKindInt: + // Convert nv.value to an int64 before converting to a float64 so the sign + // flag gets handled correctly. + return float64(int64(nv.value)) + case NumberKindFloat: + return math.Float64frombits(nv.value) + } + return float64(nv.value) +} + +// ToString converts the Number to a string. +func (nv Number) ToString() string { + switch nv.k { + case NumberKindUint: + return strconv.FormatUint(nv.value, 10) + case NumberKindInt: + return strconv.FormatInt(int64(nv.value), 10) + case NumberKindFloat: + return strconv.FormatFloat(math.Float64frombits(nv.value), 'f', -1, 64) + } + panic("river/value: unreachable") +} diff --git a/syntax/internal/value/raw_function.go b/syntax/internal/value/raw_function.go new file mode 100644 index 0000000000..bf25da916d --- /dev/null +++ b/syntax/internal/value/raw_function.go @@ -0,0 +1,9 @@ +package value + +// RawFunction allows creating function implementations using raw River values. +// This is useful for functions which wish to operate over dynamic types while +// avoiding decoding to interface{} for performance reasons. +// +// The func value itself is provided as an argument so error types can be +// filled. +type RawFunction func(funcValue Value, args ...Value) (Value, error) diff --git a/syntax/internal/value/tag_cache.go b/syntax/internal/value/tag_cache.go new file mode 100644 index 0000000000..2bce16209d --- /dev/null +++ b/syntax/internal/value/tag_cache.go @@ -0,0 +1,121 @@ +package value + +import ( + "reflect" + + "github.com/grafana/river/internal/rivertags" +) + +// tagsCache caches the river tags for a struct type. This is never cleared, +// but since most structs will be statically created throughout the lifetime +// of the process, this will consume a negligible amount of memory. +var tagsCache = make(map[reflect.Type]*objectFields) + +func getCachedTags(t reflect.Type) *objectFields { + if t.Kind() != reflect.Struct { + panic("getCachedTags called with non-struct type") + } + + if entry, ok := tagsCache[t]; ok { + return entry + } + + ff := rivertags.Get(t) + + // Build a tree of keys. + tree := &objectFields{ + fields: make(map[string]rivertags.Field), + nestedFields: make(map[string]*objectFields), + keys: []string{}, + } + + for _, f := range ff { + if f.Flags&rivertags.FlagLabel != 0 { + // Skip over label tags. + tree.labelField = f + continue + } + + node := tree + for i, name := range f.Name { + // Add to the list of keys if this is a new key. + if node.Has(name) == objectKeyTypeInvalid { + node.keys = append(node.keys, name) + } + + if i+1 == len(f.Name) { + // Last fragment, add as a field. + node.fields[name] = f + continue + } + + inner, ok := node.nestedFields[name] + if !ok { + inner = &objectFields{ + fields: make(map[string]rivertags.Field), + nestedFields: make(map[string]*objectFields), + keys: []string{}, + } + node.nestedFields[name] = inner + } + node = inner + } + } + + tagsCache[t] = tree + return tree +} + +// objectFields is a parsed tree of fields in rivertags. It forms a tree where +// leaves are nested fields (e.g., for block names that have multiple name +// fragments) and nodes are the fields themselves. +type objectFields struct { + fields map[string]rivertags.Field + nestedFields map[string]*objectFields + keys []string // Combination of fields + nestedFields + labelField rivertags.Field +} + +type objectKeyType int + +const ( + objectKeyTypeInvalid objectKeyType = iota + objectKeyTypeField + objectKeyTypeNestedField +) + +// Has returns whether name exists as a field or a nested key inside keys. +// Returns objectKeyTypeInvalid if name does not exist as either. +func (of *objectFields) Has(name string) objectKeyType { + if _, ok := of.fields[name]; ok { + return objectKeyTypeField + } + if _, ok := of.nestedFields[name]; ok { + return objectKeyTypeNestedField + } + return objectKeyTypeInvalid +} + +// Len returns the number of named keys. +func (of *objectFields) Len() int { return len(of.keys) } + +// Keys returns all named keys (fields and nested fields). +func (of *objectFields) Keys() []string { return of.keys } + +// Field gets a non-nested field. Returns false if name is a nested field. +func (of *objectFields) Field(name string) (rivertags.Field, bool) { + f, ok := of.fields[name] + return f, ok +} + +// NestedField gets a named nested field entry. Returns false if name is not a +// nested field. +func (of *objectFields) NestedField(name string) (*objectFields, bool) { + nk, ok := of.nestedFields[name] + return nk, ok +} + +// LabelField returns the field used for the label (if any). +func (of *objectFields) LabelField() (rivertags.Field, bool) { + return of.labelField, of.labelField.Index != nil +} diff --git a/syntax/internal/value/type.go b/syntax/internal/value/type.go new file mode 100644 index 0000000000..e79715cbbf --- /dev/null +++ b/syntax/internal/value/type.go @@ -0,0 +1,157 @@ +package value + +import ( + "fmt" + "reflect" +) + +// Type represents the type of a River value loosely. For example, a Value may +// be TypeArray, but this does not imply anything about the type of that +// array's elements (all of which may be any type). +// +// TypeCapsule is a special type which encapsulates arbitrary Go values. +type Type uint8 + +// Supported Type values. +const ( + TypeNull Type = iota + TypeNumber + TypeString + TypeBool + TypeArray + TypeObject + TypeFunction + TypeCapsule +) + +var typeStrings = [...]string{ + TypeNull: "null", + TypeNumber: "number", + TypeString: "string", + TypeBool: "bool", + TypeArray: "array", + TypeObject: "object", + TypeFunction: "function", + TypeCapsule: "capsule", +} + +// String returns the name of t. +func (t Type) String() string { + if int(t) < len(typeStrings) { + return typeStrings[t] + } + return fmt.Sprintf("Type(%d)", t) +} + +// GoString returns the name of t. +func (t Type) GoString() string { return t.String() } + +// RiverType returns the River type from the Go type. +// +// Go types map to River types using the following rules: +// +// 1. Go numbers (ints, uints, floats) map to a River number. +// 2. Go strings map to a River string. +// 3. Go bools map to a River bool. +// 4. Go arrays and slices map to a River array. +// 5. Go map[string]T map to a River object. +// 6. Go structs map to a River object, provided they have at least one field +// with a river tag. +// 7. Valid Go functions map to a River function. +// 8. Go interfaces map to a River capsule. +// 9. All other Go values map to a River capsule. +// +// Go functions are only valid for River if they have one non-error return type +// (the first return type) and one optional error return type (the second +// return type). Other function types are treated as capsules. +// +// As an exception, any type which implements the Capsule interface is forced +// to be a capsule. +func RiverType(t reflect.Type) Type { + // We don't know if the RiverCapsule interface is implemented for a pointer + // or non-pointer type, so we have to check before and after dereferencing. + + for t.Kind() == reflect.Pointer { + switch { + case t.Implements(goCapsule): + return TypeCapsule + case t.Implements(goTextMarshaler): + return TypeString + } + + t = t.Elem() + } + + switch { + case t.Implements(goCapsule): + return TypeCapsule + case t.Implements(goTextMarshaler): + return TypeString + case t == goDuration: + return TypeString + } + + switch t.Kind() { + case reflect.Invalid: + return TypeNull + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return TypeNumber + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return TypeNumber + case reflect.Float32, reflect.Float64: + return TypeNumber + + case reflect.String: + return TypeString + + case reflect.Bool: + return TypeBool + + case reflect.Array, reflect.Slice: + if inner := t.Elem(); inner.Kind() == reflect.Struct { + if _, labeled := getCachedTags(inner).LabelField(); labeled { + // An slice/array of labeled blocks is an object, where each label is a + // top-level key. + return TypeObject + } + } + return TypeArray + + case reflect.Map: + if t.Key() != goString { + // Objects must be keyed by string. Anything else is forced to be a + // Capsule. + return TypeCapsule + } + return TypeObject + + case reflect.Struct: + if getCachedTags(t).Len() == 0 { + return TypeCapsule + } + return TypeObject + + case reflect.Func: + switch t.NumOut() { + case 1: + if t.Out(0) == goError { + return TypeCapsule + } + return TypeFunction + case 2: + if t.Out(0) == goError || t.Out(1) != goError { + return TypeCapsule + } + return TypeFunction + default: + return TypeCapsule + } + + case reflect.Interface: + return TypeCapsule + + default: + return TypeCapsule + } +} diff --git a/syntax/internal/value/type_test.go b/syntax/internal/value/type_test.go new file mode 100644 index 0000000000..10ee04bc75 --- /dev/null +++ b/syntax/internal/value/type_test.go @@ -0,0 +1,80 @@ +package value_test + +import ( + "reflect" + "testing" + + "github.com/grafana/river/internal/value" + "github.com/stretchr/testify/require" +) + +type customCapsule bool + +var _ value.Capsule = (customCapsule)(false) + +func (customCapsule) RiverCapsule() {} + +var typeTests = []struct { + input interface{} + expect value.Type +}{ + {int(0), value.TypeNumber}, + {int8(0), value.TypeNumber}, + {int16(0), value.TypeNumber}, + {int32(0), value.TypeNumber}, + {int64(0), value.TypeNumber}, + {uint(0), value.TypeNumber}, + {uint8(0), value.TypeNumber}, + {uint16(0), value.TypeNumber}, + {uint32(0), value.TypeNumber}, + {uint64(0), value.TypeNumber}, + {float32(0), value.TypeNumber}, + {float64(0), value.TypeNumber}, + + {string(""), value.TypeString}, + + {bool(false), value.TypeBool}, + + {[...]int{0, 1, 2}, value.TypeArray}, + {[]int{0, 1, 2}, value.TypeArray}, + + // Struct with no River tags is a capsule. + {struct{}{}, value.TypeCapsule}, + + // A slice of labeled blocks should be an object. + {[]struct { + Label string `river:",label"` + }{}, value.TypeObject}, + + {map[string]interface{}{}, value.TypeObject}, + + // Go functions must have one non-error return type and one optional error + // return type to be River functions. Everything else is a capsule. + {(func() int)(nil), value.TypeFunction}, + {(func() (int, error))(nil), value.TypeFunction}, + {(func())(nil), value.TypeCapsule}, // Must have non-error return type + {(func() error)(nil), value.TypeCapsule}, // First return type must be non-error + {(func() (error, int))(nil), value.TypeCapsule}, // First return type must be non-error + {(func() (error, error))(nil), value.TypeCapsule}, // First return type must be non-error + {(func() (int, int))(nil), value.TypeCapsule}, // Second return type must be error + {(func() (int, int, int))(nil), value.TypeCapsule}, // Can only have 1 or 2 return types + + {make(chan struct{}), value.TypeCapsule}, + {map[bool]interface{}{}, value.TypeCapsule}, // Maps with non-string types are capsules + + // Types with capsule markers should be capsules. + {customCapsule(false), value.TypeCapsule}, + {(*customCapsule)(nil), value.TypeCapsule}, + {(**customCapsule)(nil), value.TypeCapsule}, +} + +func Test_RiverType(t *testing.T) { + for _, tc := range typeTests { + rt := reflect.TypeOf(tc.input) + + t.Run(rt.String(), func(t *testing.T) { + actual := value.RiverType(rt) + require.Equal(t, tc.expect, actual, "Unexpected type for %#v", tc.input) + }) + } +} diff --git a/syntax/internal/value/value.go b/syntax/internal/value/value.go new file mode 100644 index 0000000000..bdd8492c09 --- /dev/null +++ b/syntax/internal/value/value.go @@ -0,0 +1,556 @@ +// Package value holds the internal representation for River values. River +// values act as a lightweight wrapper around reflect.Value. +package value + +import ( + "encoding" + "fmt" + "reflect" + "strconv" + "strings" + "time" + + "github.com/grafana/river/internal/reflectutil" +) + +// Go types used throughout the package. +var ( + goAny = reflect.TypeOf((*interface{})(nil)).Elem() + goString = reflect.TypeOf(string("")) + goByteSlice = reflect.TypeOf([]byte(nil)) + goError = reflect.TypeOf((*error)(nil)).Elem() + goTextMarshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + goTextUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + goStructWrapper = reflect.TypeOf(structWrapper{}) + goCapsule = reflect.TypeOf((*Capsule)(nil)).Elem() + goDuration = reflect.TypeOf((time.Duration)(0)) + goDurationPtr = reflect.TypeOf((*time.Duration)(nil)) + goRiverDefaulter = reflect.TypeOf((*Defaulter)(nil)).Elem() + goRiverDecoder = reflect.TypeOf((*Unmarshaler)(nil)).Elem() + goRiverValidator = reflect.TypeOf((*Validator)(nil)).Elem() + goRawRiverFunc = reflect.TypeOf((RawFunction)(nil)) + goRiverValue = reflect.TypeOf(Null) +) + +// NOTE(rfratto): This package is extremely sensitive to performance, so +// changes should be made with caution; run benchmarks when changing things. +// +// Value is optimized to be as small as possible and exist fully on the stack. +// This allows many values to avoid allocations, with the exception of creating +// arrays and objects. + +// Value represents a River value. +type Value struct { + rv reflect.Value + ty Type +} + +// Null is the null value. +var Null = Value{} + +// Uint returns a Value from a uint64. +func Uint(u uint64) Value { return Value{rv: reflect.ValueOf(u), ty: TypeNumber} } + +// Int returns a Value from an int64. +func Int(i int64) Value { return Value{rv: reflect.ValueOf(i), ty: TypeNumber} } + +// Float returns a Value from a float64. +func Float(f float64) Value { return Value{rv: reflect.ValueOf(f), ty: TypeNumber} } + +// String returns a Value from a string. +func String(s string) Value { return Value{rv: reflect.ValueOf(s), ty: TypeString} } + +// Bool returns a Value from a bool. +func Bool(b bool) Value { return Value{rv: reflect.ValueOf(b), ty: TypeBool} } + +// Object returns a new value from m. A copy of m is made for producing the +// Value. +func Object(m map[string]Value) Value { + return Value{ + rv: reflect.ValueOf(m), + ty: TypeObject, + } +} + +// Array creates an array from the given values. A copy of the vv slice is made +// for producing the Value. +func Array(vv ...Value) Value { + return Value{ + rv: reflect.ValueOf(vv), + ty: TypeArray, + } +} + +// Func makes a new function Value from f. Func panics if f does not map to a +// River function. +func Func(f interface{}) Value { + rv := reflect.ValueOf(f) + if RiverType(rv.Type()) != TypeFunction { + panic("river/value: Func called with non-function type") + } + return Value{rv: rv, ty: TypeFunction} +} + +// Encapsulate creates a new Capsule value from v. Encapsulate panics if v does +// not map to a River capsule. +func Encapsulate(v interface{}) Value { + rv := reflect.ValueOf(v) + if RiverType(rv.Type()) != TypeCapsule { + panic("river/value: Capsule called with non-capsule type") + } + return Value{rv: rv, ty: TypeCapsule} +} + +// Encode creates a new Value from v. If v is a pointer, v must be considered +// immutable and not change while the Value is used. +func Encode(v interface{}) Value { + if v == nil { + return Null + } + return makeValue(reflect.ValueOf(v)) +} + +// FromRaw converts a reflect.Value into a River Value. It is useful to prevent +// downcasting an interface into an any. +func FromRaw(v reflect.Value) Value { + return makeValue(v) +} + +// Type returns the River type for the value. +func (v Value) Type() Type { return v.ty } + +// Describe returns a descriptive type name for the value. For capsule values, +// this prints the underlying Go type name. For other values, it prints the +// normal River type. +func (v Value) Describe() string { + if v.ty != TypeCapsule { + return v.ty.String() + } + return fmt.Sprintf("capsule(%q)", v.rv.Type()) +} + +// Bool returns the boolean value for v. It panics if v is not a bool. +func (v Value) Bool() bool { + if v.ty != TypeBool { + panic("river/value: Bool called on non-bool type") + } + return v.rv.Bool() +} + +// Number returns a Number value for v. It panics if v is not a Number. +func (v Value) Number() Number { + if v.ty != TypeNumber { + panic("river/value: Number called on non-number type") + } + return newNumberValue(v.rv) +} + +// Int returns an int value for v. It panics if v is not a number. +func (v Value) Int() int64 { + if v.ty != TypeNumber { + panic("river/value: Int called on non-number type") + } + switch makeNumberKind(v.rv.Kind()) { + case NumberKindInt: + return v.rv.Int() + case NumberKindUint: + return int64(v.rv.Uint()) + case NumberKindFloat: + return int64(v.rv.Float()) + } + panic("river/value: unreachable") +} + +// Uint returns an uint value for v. It panics if v is not a number. +func (v Value) Uint() uint64 { + if v.ty != TypeNumber { + panic("river/value: Uint called on non-number type") + } + switch makeNumberKind(v.rv.Kind()) { + case NumberKindInt: + return uint64(v.rv.Int()) + case NumberKindUint: + return v.rv.Uint() + case NumberKindFloat: + return uint64(v.rv.Float()) + } + panic("river/value: unreachable") +} + +// Float returns a float value for v. It panics if v is not a number. +func (v Value) Float() float64 { + if v.ty != TypeNumber { + panic("river/value: Float called on non-number type") + } + switch makeNumberKind(v.rv.Kind()) { + case NumberKindInt: + return float64(v.rv.Int()) + case NumberKindUint: + return float64(v.rv.Uint()) + case NumberKindFloat: + return v.rv.Float() + } + panic("river/value: unreachable") +} + +// Text returns a string value of v. It panics if v is not a string. +func (v Value) Text() string { + if v.ty != TypeString { + panic("river/value: Text called on non-string type") + } + + // Attempt to get an address to v.rv for interface checking. + // + // The normal v.rv value is used for other checks. + addrRV := v.rv + if addrRV.CanAddr() { + addrRV = addrRV.Addr() + } + switch { + case addrRV.Type().Implements(goTextMarshaler): + // TODO(rfratto): what should we do if this fails? + text, _ := addrRV.Interface().(encoding.TextMarshaler).MarshalText() + return string(text) + + case v.rv.Type() == goDuration: + // Special case: v.rv is a duration and its String method should be used. + return v.rv.Interface().(time.Duration).String() + + default: + return v.rv.String() + } +} + +// Len returns the length of v. Panics if v is not an array or object. +func (v Value) Len() int { + switch v.ty { + case TypeArray: + return v.rv.Len() + case TypeObject: + switch { + case v.rv.Type() == goStructWrapper: + return v.rv.Interface().(structWrapper).Len() + case v.rv.Kind() == reflect.Array, v.rv.Kind() == reflect.Slice: // Array of labeled blocks + return v.rv.Len() + case v.rv.Kind() == reflect.Struct: + return getCachedTags(v.rv.Type()).Len() + case v.rv.Kind() == reflect.Map: + return v.rv.Len() + } + } + panic("river/value: Len called on non-array and non-object value") +} + +// Index returns index i of the Value. Panics if the value is not an array or +// if it is out of bounds of the array's size. +func (v Value) Index(i int) Value { + if v.ty != TypeArray { + panic("river/value: Index called on non-array value") + } + return makeValue(v.rv.Index(i)) +} + +// Interface returns the underlying Go value for the Value. +func (v Value) Interface() interface{} { + if v.ty == TypeNull { + return nil + } + return v.rv.Interface() +} + +// Reflect returns the raw reflection value backing v. +func (v Value) Reflect() reflect.Value { return v.rv } + +// makeValue converts a reflect value into a Value, dereferencing any pointers or +// interface{} values. +func makeValue(v reflect.Value) Value { + // Early check: if v is interface{}, we need to deference it to get the + // concrete value. + if v.IsValid() && v.Type() == goAny { + v = v.Elem() + } + + // Special case: a reflect.Value may be a value.Value when it's coming from a + // River array or object. We can unwrap the inner value here before + // continuing. + if v.IsValid() && v.Type() == goRiverValue { + // Unwrap the inner value. + v = v.Interface().(Value).rv + } + + // Before we get the River type of the Value, we need to see if it's possible + // to get a pointer to v. This ensures that if v is a non-pointer field of an + // addressable struct, still detect the type of v as if it was a pointer. + if v.CanAddr() { + v = v.Addr() + } + + if !v.IsValid() { + return Null + } + riverType := RiverType(v.Type()) + + // Finally, deference the pointer fully and use the type we detected. + for v.Kind() == reflect.Pointer { + if v.IsNil() { + return Null + } + v = v.Elem() + } + return Value{rv: v, ty: riverType} +} + +// OrderedKeys reports if v represents an object with consistently ordered +// keys. It panics if v is not an object. +func (v Value) OrderedKeys() bool { + if v.ty != TypeObject { + panic("river/value: OrderedKeys called on non-object value") + } + + // Maps are the only type of unordered River object, since their keys can't + // be iterated over in a deterministic order. Every other type of River + // object comes from a struct or a slice where the order of keys stays the + // same. + return v.rv.Kind() != reflect.Map +} + +// Keys returns the keys in v in unspecified order. It panics if v is not an +// object. +func (v Value) Keys() []string { + if v.ty != TypeObject { + panic("river/value: Keys called on non-object value") + } + + switch { + case v.rv.Type() == goStructWrapper: + return v.rv.Interface().(structWrapper).Keys() + + case v.rv.Kind() == reflect.Struct: + return wrapStruct(v.rv, true).Keys() + + case v.rv.Kind() == reflect.Array, v.rv.Kind() == reflect.Slice: + // List of labeled blocks. + labelField, _ := getCachedTags(v.rv.Type().Elem()).LabelField() + + keys := make([]string, v.rv.Len()) + for i := range keys { + keys[i] = reflectutil.Get(v.rv.Index(i), labelField).String() + } + return keys + + case v.rv.Kind() == reflect.Map: + reflectKeys := v.rv.MapKeys() + res := make([]string, len(reflectKeys)) + for i, rk := range reflectKeys { + res[i] = rk.String() + } + return res + } + + panic("river/value: unreachable") +} + +// Key returns the value for a key in v. It panics if v is not an object. ok +// will be false if the key did not exist in the object. +func (v Value) Key(key string) (index Value, ok bool) { + if v.ty != TypeObject { + panic("river/value: Key called on non-object value") + } + + switch { + case v.rv.Type() == goStructWrapper: + return v.rv.Interface().(structWrapper).Key(key) + case v.rv.Kind() == reflect.Struct: + // We return the struct with the label intact. + return wrapStruct(v.rv, true).Key(key) + case v.rv.Kind() == reflect.Map: + val := v.rv.MapIndex(reflect.ValueOf(key)) + if !val.IsValid() { + return Null, false + } + return makeValue(val), true + + case v.rv.Kind() == reflect.Slice, v.rv.Kind() == reflect.Array: + // List of labeled blocks. + labelField, _ := getCachedTags(v.rv.Type().Elem()).LabelField() + + for i := 0; i < v.rv.Len(); i++ { + elem := v.rv.Index(i) + + label := reflectutil.Get(elem, labelField).String() + if label == key { + // We discard the label since the key here represents the label value. + ws := wrapStruct(elem, false) + return ws.Value(), true + } + } + default: + panic("river/value: unreachable") + } + + return +} + +// Call invokes a function value with the provided arguments. It panics if v is +// not a function. If v is a variadic function, args should be the full flat +// list of arguments. +// +// An ArgError will be returned if one of the arguments is invalid. An Error +// will be returned if the function call returns an error or if the number of +// arguments doesn't match. +func (v Value) Call(args ...Value) (Value, error) { + if v.ty != TypeFunction { + panic("river/value: Call called on non-function type") + } + + if v.rv.Type() == goRawRiverFunc { + return v.rv.Interface().(RawFunction)(v, args...) + } + + var ( + variadic = v.rv.Type().IsVariadic() + expectedArgs = v.rv.Type().NumIn() + ) + + if variadic && len(args) < expectedArgs-1 { + return Null, Error{ + Value: v, + Inner: fmt.Errorf("expected at least %d args, got %d", expectedArgs-1, len(args)), + } + } else if !variadic && len(args) != expectedArgs { + return Null, Error{ + Value: v, + Inner: fmt.Errorf("expected %d args, got %d", expectedArgs, len(args)), + } + } + + reflectArgs := make([]reflect.Value, len(args)) + for i, arg := range args { + var argVal reflect.Value + if variadic && i >= expectedArgs-1 { + argType := v.rv.Type().In(expectedArgs - 1).Elem() + argVal = reflect.New(argType).Elem() + } else { + argType := v.rv.Type().In(i) + argVal = reflect.New(argType).Elem() + } + + var d decoder + if err := d.decode(arg, argVal); err != nil { + return Null, ArgError{ + Function: v, + Argument: arg, + Index: i, + Inner: err, + } + } + + reflectArgs[i] = argVal + } + + outs := v.rv.Call(reflectArgs) + switch len(outs) { + case 1: + return makeValue(outs[0]), nil + case 2: + // When there's 2 return values, the second is always an error. + err, _ := outs[1].Interface().(error) + if err != nil { + return Null, Error{Value: v, Inner: err} + } + return makeValue(outs[0]), nil + + default: + // It's not possible to reach here; we enforce that function values always + // have 1 or 2 return values. + panic("river/value: unreachable") + } +} + +func convertValue(val Value, toType Type) (Value, error) { + // TODO(rfratto): Use vm benchmarks to see if making this a method on Value + // changes anything. + + fromType := val.Type() + + if fromType == toType { + // no-op: val is already the right kind. + return val, nil + } + + switch fromType { + case TypeNumber: + switch toType { + case TypeString: // number -> string + strVal := newNumberValue(val.rv).ToString() + return makeValue(reflect.ValueOf(strVal)), nil + } + + case TypeString: + sourceStr := val.rv.String() + + switch toType { + case TypeNumber: // string -> number + switch { + case sourceStr == "": + return Null, TypeError{Value: val, Expected: toType} + + case sourceStr[0] == '-': + // String starts with a -; parse as a signed int. + parsed, err := strconv.ParseInt(sourceStr, 10, 64) + if err != nil { + return Null, TypeError{Value: val, Expected: toType} + } + return Int(parsed), nil + case strings.ContainsAny(sourceStr, ".eE"): + // String contains something that a floating-point number would use; + // convert. + parsed, err := strconv.ParseFloat(sourceStr, 64) + if err != nil { + return Null, TypeError{Value: val, Expected: toType} + } + return Float(parsed), nil + default: + // Otherwise, treat the number as an unsigned int. + parsed, err := strconv.ParseUint(sourceStr, 10, 64) + if err != nil { + return Null, TypeError{Value: val, Expected: toType} + } + return Uint(parsed), nil + } + } + } + + return Null, TypeError{Value: val, Expected: toType} +} + +func convertGoNumber(nval Number, target reflect.Type) reflect.Value { + switch target.Kind() { + case reflect.Int: + return reflect.ValueOf(int(nval.Int())) + case reflect.Int8: + return reflect.ValueOf(int8(nval.Int())) + case reflect.Int16: + return reflect.ValueOf(int16(nval.Int())) + case reflect.Int32: + return reflect.ValueOf(int32(nval.Int())) + case reflect.Int64: + return reflect.ValueOf(nval.Int()) + case reflect.Uint: + return reflect.ValueOf(uint(nval.Uint())) + case reflect.Uint8: + return reflect.ValueOf(uint8(nval.Uint())) + case reflect.Uint16: + return reflect.ValueOf(uint16(nval.Uint())) + case reflect.Uint32: + return reflect.ValueOf(uint32(nval.Uint())) + case reflect.Uint64: + return reflect.ValueOf(nval.Uint()) + case reflect.Float32: + return reflect.ValueOf(float32(nval.Float())) + case reflect.Float64: + return reflect.ValueOf(nval.Float()) + } + + panic("unsupported number conversion") +} diff --git a/syntax/internal/value/value_object.go b/syntax/internal/value/value_object.go new file mode 100644 index 0000000000..6a642cb22f --- /dev/null +++ b/syntax/internal/value/value_object.go @@ -0,0 +1,119 @@ +package value + +import ( + "reflect" + + "github.com/grafana/river/internal/reflectutil" +) + +// structWrapper allows for partially traversing structs which contain fields +// representing blocks. This is required due to how block names and labels +// change the object representation. +// +// If a block name is a.b.c, then it is represented as three nested objects: +// +// { +// a = { +// b = { +// c = { /* block contents */ }, +// }, +// } +// } +// +// Similarly, if a block name is labeled (a.b.c "label"), then the label is the +// top-level key after c. +// +// structWrapper exposes Len, Keys, and Key methods similar to Value to allow +// traversing through the synthetic object. The values it returns are +// structWrappers. +// +// Code in value.go MUST check to see if a struct is a structWrapper *before* +// checking the value kind to ensure the appropriate methods are invoked. +type structWrapper struct { + structVal reflect.Value + fields *objectFields + label string // Non-empty string if this struct is wrapped in a label. +} + +func wrapStruct(val reflect.Value, keepLabel bool) structWrapper { + if val.Kind() != reflect.Struct { + panic("river/value: wrapStruct called on non-struct value") + } + + fields := getCachedTags(val.Type()) + + var label string + if f, ok := fields.LabelField(); ok && keepLabel { + label = reflectutil.Get(val, f).String() + } + + return structWrapper{ + structVal: val, + fields: fields, + label: label, + } +} + +// Value turns sw into a value. +func (sw structWrapper) Value() Value { + return Value{ + rv: reflect.ValueOf(sw), + ty: TypeObject, + } +} + +func (sw structWrapper) Len() int { + if len(sw.label) > 0 { + return 1 + } + return sw.fields.Len() +} + +func (sw structWrapper) Keys() []string { + if len(sw.label) > 0 { + return []string{sw.label} + } + return sw.fields.Keys() +} + +func (sw structWrapper) Key(key string) (index Value, ok bool) { + if len(sw.label) > 0 { + if key != sw.label { + return + } + next := reflect.ValueOf(structWrapper{ + structVal: sw.structVal, + fields: sw.fields, + // Unset the label now that we've traversed it + }) + return Value{rv: next, ty: TypeObject}, true + } + + keyType := sw.fields.Has(key) + + switch keyType { + case objectKeyTypeInvalid: + return // No such key + + case objectKeyTypeNestedField: + // Continue traversing. + nextNode, _ := sw.fields.NestedField(key) + return Value{ + rv: reflect.ValueOf(structWrapper{ + structVal: sw.structVal, + fields: nextNode, + }), + ty: TypeObject, + }, true + + case objectKeyTypeField: + f, _ := sw.fields.Field(key) + val, err := sw.structVal.FieldByIndexErr(f.Index) + if err != nil { + return Null, true + } + return makeValue(val), true + } + + panic("river/value: unreachable") +} diff --git a/syntax/internal/value/value_object_test.go b/syntax/internal/value/value_object_test.go new file mode 100644 index 0000000000..56d72a6102 --- /dev/null +++ b/syntax/internal/value/value_object_test.go @@ -0,0 +1,205 @@ +package value_test + +import ( + "testing" + + "github.com/grafana/river/internal/value" + "github.com/stretchr/testify/require" +) + +// TestBlockRepresentation ensures that the struct tags for blocks are +// represented correctly. +func TestBlockRepresentation(t *testing.T) { + type UnlabledBlock struct { + Value int `river:"value,attr"` + } + type LabeledBlock struct { + Value int `river:"value,attr"` + Label string `river:",label"` + } + type OuterBlock struct { + Attr1 string `river:"attr_1,attr"` + Attr2 string `river:"attr_2,attr"` + + UnlabledBlock1 UnlabledBlock `river:"unlabeled.a,block"` + UnlabledBlock2 UnlabledBlock `river:"unlabeled.b,block"` + UnlabledBlock3 UnlabledBlock `river:"other_unlabeled,block"` + + LabeledBlock1 LabeledBlock `river:"labeled.a,block"` + LabeledBlock2 LabeledBlock `river:"labeled.b,block"` + LabeledBlock3 LabeledBlock `river:"other_labeled,block"` + } + + val := OuterBlock{ + Attr1: "value_1", + Attr2: "value_2", + UnlabledBlock1: UnlabledBlock{ + Value: 1, + }, + UnlabledBlock2: UnlabledBlock{ + Value: 2, + }, + UnlabledBlock3: UnlabledBlock{ + Value: 3, + }, + LabeledBlock1: LabeledBlock{ + Value: 4, + Label: "label_a", + }, + LabeledBlock2: LabeledBlock{ + Value: 5, + Label: "label_b", + }, + LabeledBlock3: LabeledBlock{ + Value: 6, + Label: "label_c", + }, + } + + t.Run("Map decode", func(t *testing.T) { + var m map[string]interface{} + require.NoError(t, value.Decode(value.Encode(val), &m)) + + type object = map[string]interface{} + + expect := object{ + "attr_1": "value_1", + "attr_2": "value_2", + "unlabeled": object{ + "a": object{"value": 1}, + "b": object{"value": 2}, + }, + "other_unlabeled": object{"value": 3}, + "labeled": object{ + "a": object{ + "label_a": object{"value": 4}, + }, + "b": object{ + "label_b": object{"value": 5}, + }, + }, + "other_labeled": object{ + "label_c": object{"value": 6}, + }, + } + + require.Equal(t, m, expect) + }) + + t.Run("Object decode from other object", func(t *testing.T) { + // Decode into a separate type which is structurally identical but not + // literally the same. + type OuterBlock2 OuterBlock + + var actualVal OuterBlock2 + require.NoError(t, value.Decode(value.Encode(val), &actualVal)) + require.Equal(t, val, OuterBlock(actualVal)) + }) +} + +// TestSquashedBlockRepresentation ensures that the struct tags for squashed +// blocks are represented correctly. +func TestSquashedBlockRepresentation(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr,optional"` + InnerField2 string `river:"inner_field_2,attr,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner InnerStruct `river:",squash"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + val := OuterStruct{ + OuterField1: "value1", + Inner: InnerStruct{ + InnerField1: "value3", + InnerField2: "value4", + }, + OuterField2: "value2", + } + + t.Run("Map decode", func(t *testing.T) { + var m map[string]interface{} + require.NoError(t, value.Decode(value.Encode(val), &m)) + + type object = map[string]interface{} + + expect := object{ + "outer_field_1": "value1", + "inner_field_1": "value3", + "inner_field_2": "value4", + "outer_field_2": "value2", + } + + require.Equal(t, m, expect) + }) +} + +func TestSliceOfBlocks(t *testing.T) { + type UnlabledBlock struct { + Value int `river:"value,attr"` + } + type LabeledBlock struct { + Value int `river:"value,attr"` + Label string `river:",label"` + } + type OuterBlock struct { + Attr1 string `river:"attr_1,attr"` + Attr2 string `river:"attr_2,attr"` + + Unlabeled []UnlabledBlock `river:"unlabeled,block"` + Labeled []LabeledBlock `river:"labeled,block"` + } + + val := OuterBlock{ + Attr1: "value_1", + Attr2: "value_2", + Unlabeled: []UnlabledBlock{ + {Value: 1}, + {Value: 2}, + {Value: 3}, + }, + Labeled: []LabeledBlock{ + {Label: "label_a", Value: 4}, + {Label: "label_b", Value: 5}, + {Label: "label_c", Value: 6}, + }, + } + + t.Run("Map decode", func(t *testing.T) { + var m map[string]interface{} + require.NoError(t, value.Decode(value.Encode(val), &m)) + + type object = map[string]interface{} + type list = []interface{} + + expect := object{ + "attr_1": "value_1", + "attr_2": "value_2", + "unlabeled": list{ + object{"value": 1}, + object{"value": 2}, + object{"value": 3}, + }, + "labeled": object{ + "label_a": object{"value": 4}, + "label_b": object{"value": 5}, + "label_c": object{"value": 6}, + }, + } + + require.Equal(t, m, expect) + }) + + t.Run("Object decode from other object", func(t *testing.T) { + // Decode into a separate type which is structurally identical but not + // literally the same. + type OuterBlock2 OuterBlock + + var actualVal OuterBlock2 + require.NoError(t, value.Decode(value.Encode(val), &actualVal)) + require.Equal(t, val, OuterBlock(actualVal)) + }) +} diff --git a/syntax/internal/value/value_test.go b/syntax/internal/value/value_test.go new file mode 100644 index 0000000000..4583e5196d --- /dev/null +++ b/syntax/internal/value/value_test.go @@ -0,0 +1,243 @@ +package value_test + +import ( + "fmt" + "io" + "testing" + + "github.com/grafana/river/internal/value" + "github.com/stretchr/testify/require" +) + +// TestEncodeKeyLookup tests where Go values are retained correctly +// throughout values with a key lookup. +func TestEncodeKeyLookup(t *testing.T) { + type Body struct { + Data pointerMarshaler `river:"data,attr"` + } + + tt := []struct { + name string + encodeTarget any + key string + + expectBodyType value.Type + expectKeyExists bool + expectKeyValue value.Value + expectKeyType value.Type + }{ + { + name: "Struct Encode data Key", + encodeTarget: &Body{}, + key: "data", + expectBodyType: value.TypeObject, + expectKeyExists: true, + expectKeyValue: value.String("Hello, world!"), + expectKeyType: value.TypeString, + }, + { + name: "Struct Encode Missing Key", + encodeTarget: &Body{}, + key: "missing", + expectBodyType: value.TypeObject, + expectKeyExists: false, + expectKeyValue: value.Null, + expectKeyType: value.TypeNull, + }, + { + name: "Map Encode data Key", + encodeTarget: map[string]string{"data": "Hello, world!"}, + key: "data", + expectBodyType: value.TypeObject, + expectKeyExists: true, + expectKeyValue: value.String("Hello, world!"), + expectKeyType: value.TypeString, + }, + { + name: "Map Encode Missing Key", + encodeTarget: map[string]string{"data": "Hello, world!"}, + key: "missing", + expectBodyType: value.TypeObject, + expectKeyExists: false, + expectKeyValue: value.Null, + expectKeyType: value.TypeNull, + }, + { + name: "Map Encode empty value Key", + encodeTarget: map[string]string{"data": ""}, + key: "data", + expectBodyType: value.TypeObject, + expectKeyExists: true, + expectKeyValue: value.String(""), + expectKeyType: value.TypeString, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + bodyVal := value.Encode(tc.encodeTarget) + require.Equal(t, tc.expectBodyType, bodyVal.Type()) + + val, ok := bodyVal.Key(tc.key) + require.Equal(t, tc.expectKeyExists, ok) + require.Equal(t, tc.expectKeyType, val.Type()) + switch val.Type() { + case value.TypeString: + require.Equal(t, tc.expectKeyValue.Text(), val.Text()) + case value.TypeNull: + require.Equal(t, tc.expectKeyValue, val) + default: + require.Fail(t, "unexpected value type (this switch can be expanded)") + } + }) + } +} + +// TestEncodeNoKeyLookup tests where Go values are retained correctly +// throughout values without a key lookup. +func TestEncodeNoKeyLookup(t *testing.T) { + tt := []struct { + name string + encodeTarget any + key string + + expectBodyType value.Type + expectBodyText string + }{ + { + name: "Encode", + encodeTarget: &pointerMarshaler{}, + expectBodyType: value.TypeString, + expectBodyText: "Hello, world!", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + bodyVal := value.Encode(tc.encodeTarget) + require.Equal(t, tc.expectBodyType, bodyVal.Type()) + require.Equal(t, "Hello, world!", bodyVal.Text()) + }) + } +} + +type pointerMarshaler struct{} + +func (*pointerMarshaler) MarshalText() ([]byte, error) { + return []byte("Hello, world!"), nil +} + +func TestValue_Call(t *testing.T) { + t.Run("simple", func(t *testing.T) { + add := func(a, b int) int { return a + b } + addVal := value.Encode(add) + + res, err := addVal.Call( + value.Int(15), + value.Int(43), + ) + require.NoError(t, err) + require.Equal(t, int64(15+43), res.Int()) + }) + + t.Run("fully variadic", func(t *testing.T) { + add := func(nums ...int) int { + var sum int + for _, num := range nums { + sum += num + } + return sum + } + addVal := value.Encode(add) + + t.Run("no args", func(t *testing.T) { + res, err := addVal.Call() + require.NoError(t, err) + require.Equal(t, int64(0), res.Int()) + }) + + t.Run("one arg", func(t *testing.T) { + res, err := addVal.Call(value.Int(32)) + require.NoError(t, err) + require.Equal(t, int64(32), res.Int()) + }) + + t.Run("many args", func(t *testing.T) { + res, err := addVal.Call( + value.Int(32), + value.Int(59), + value.Int(12), + ) + require.NoError(t, err) + require.Equal(t, int64(32+59+12), res.Int()) + }) + }) + + t.Run("partially variadic", func(t *testing.T) { + add := func(firstNum int, nums ...int) int { + sum := firstNum + for _, num := range nums { + sum += num + } + return sum + } + addVal := value.Encode(add) + + t.Run("no variadic args", func(t *testing.T) { + res, err := addVal.Call(value.Int(52)) + require.NoError(t, err) + require.Equal(t, int64(52), res.Int()) + }) + + t.Run("one variadic arg", func(t *testing.T) { + res, err := addVal.Call(value.Int(52), value.Int(32)) + require.NoError(t, err) + require.Equal(t, int64(52+32), res.Int()) + }) + + t.Run("many variadic args", func(t *testing.T) { + res, err := addVal.Call( + value.Int(32), + value.Int(59), + value.Int(12), + ) + require.NoError(t, err) + require.Equal(t, int64(32+59+12), res.Int()) + }) + }) + + t.Run("returns error", func(t *testing.T) { + failWhenTrue := func(val bool) (int, error) { + if val { + return 0, fmt.Errorf("function failed for a very good reason") + } + return 0, nil + } + funcVal := value.Encode(failWhenTrue) + + t.Run("no error", func(t *testing.T) { + res, err := funcVal.Call(value.Bool(false)) + require.NoError(t, err) + require.Equal(t, int64(0), res.Int()) + }) + + t.Run("error", func(t *testing.T) { + _, err := funcVal.Call(value.Bool(true)) + require.EqualError(t, err, "function failed for a very good reason") + }) + }) +} + +func TestValue_Interface_In_Array(t *testing.T) { + type Container struct { + Field io.Closer `river:"field,attr"` + } + + val := value.Encode(Container{Field: io.NopCloser(nil)}) + fieldVal, ok := val.Key("field") + require.True(t, ok, "field not found in object") + require.Equal(t, value.TypeCapsule, fieldVal.Type()) + + arrVal := value.Array(fieldVal) + require.Equal(t, value.TypeCapsule, arrVal.Index(0).Type()) +} diff --git a/syntax/parser/error_test.go b/syntax/parser/error_test.go new file mode 100644 index 0000000000..feb8602e31 --- /dev/null +++ b/syntax/parser/error_test.go @@ -0,0 +1,148 @@ +package parser + +import ( + "os" + "path/filepath" + "regexp" + "strings" + "testing" + + "github.com/grafana/river/diag" + "github.com/grafana/river/scanner" + "github.com/grafana/river/token" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// This file implements a parser test harness. The files in the testdata +// directory are parsed and the errors reported are compared against the error +// messages expected in the test files. +// +// Expected errors are indicated in the test files by putting a comment of the +// form /* ERROR "rx" */ immediately following an offending token. The harness +// will verify that an error matching the regular expression rx is reported at +// that source position. + +// ERROR comments must be of the form /* ERROR "rx" */ and rx is a regular +// expression that matches the expected error message. The special form +// /* ERROR HERE "rx" */ must be used for error messages that appear immediately +// after a token rather than at a token's position. +var errRx = regexp.MustCompile(`^/\* *ERROR *(HERE)? *"([^"]*)" *\*/$`) + +// expectedErrors collects the regular expressions of ERROR comments found in +// files and returns them as a map of error positions to error messages. +func expectedErrors(file *token.File, src []byte) map[token.Pos]string { + errors := make(map[token.Pos]string) + + s := scanner.New(file, src, nil, scanner.IncludeComments) + + var ( + prev token.Pos // Position of last non-comment, non-terminator token + here token.Pos // Position following after token at prev + ) + + for { + pos, tok, lit := s.Scan() + switch tok { + case token.EOF: + return errors + case token.COMMENT: + s := errRx.FindStringSubmatch(lit) + if len(s) == 3 { + pos := prev + if s[1] == "HERE" { + pos = here + } + errors[pos] = s[2] + } + case token.TERMINATOR: + if lit == "\n" { + break + } + fallthrough + default: + prev = pos + var l int // Token length + if isLiteral(tok) { + l = len(lit) + } else { + l = len(tok.String()) + } + here = prev.Add(l) + } + } +} + +func isLiteral(t token.Token) bool { + switch t { + case token.IDENT, token.NUMBER, token.FLOAT, token.STRING: + return true + } + return false +} + +// compareErrors compares the map of expected error messages with the list of +// found errors and reports mismatches. +func compareErrors(t *testing.T, file *token.File, expected map[token.Pos]string, found diag.Diagnostics) { + t.Helper() + + for _, checkError := range found { + pos := file.Pos(checkError.StartPos.Offset) + + if msg, found := expected[pos]; found { + // We expect a message at pos; check if it matches + rx, err := regexp.Compile(msg) + if !assert.NoError(t, err) { + continue + } + assert.True(t, + rx.MatchString(checkError.Message), + "%s: %q does not match %q", + checkError.StartPos, checkError.Message, msg, + ) + delete(expected, pos) // Eliminate consumed error + } else { + assert.Fail(t, + "Unexpected error", + "unexpected error: %s: %s", checkError.StartPos.String(), checkError.Message, + ) + } + } + + // There should be no expected errors left + if len(expected) > 0 { + t.Errorf("%d errors not reported:", len(expected)) + for pos, msg := range expected { + t.Errorf("%s: %s\n", file.PositionFor(pos), msg) + } + } +} + +func TestErrors(t *testing.T) { + list, err := os.ReadDir("testdata") + require.NoError(t, err) + + for _, d := range list { + name := d.Name() + if d.IsDir() || !strings.HasSuffix(name, ".river") { + continue + } + + t.Run(name, func(t *testing.T) { + checkErrors(t, filepath.Join("testdata", name)) + }) + } +} + +func checkErrors(t *testing.T, filename string) { + t.Helper() + + src, err := os.ReadFile(filename) + require.NoError(t, err) + + p := newParser(filename, src) + _ = p.ParseFile() + + expected := expectedErrors(p.file, src) + compareErrors(t, p.file, expected, p.diags) +} diff --git a/syntax/parser/internal.go b/syntax/parser/internal.go new file mode 100644 index 0000000000..1a8b7b7467 --- /dev/null +++ b/syntax/parser/internal.go @@ -0,0 +1,714 @@ +package parser + +import ( + "fmt" + "strings" + + "github.com/grafana/river/ast" + "github.com/grafana/river/diag" + "github.com/grafana/river/scanner" + "github.com/grafana/river/token" +) + +// parser implements the River parser. +// +// It is only safe for callers to use exported methods as entrypoints for +// parsing. +// +// Each Parse* and parse* method will describe the EBNF grammar being used for +// parsing that non-terminal. The EBNF grammar will be written as LL(1) and +// should directly represent the code. +// +// The parser will continue on encountering errors to allow a more complete +// list of errors to be returned to the user. The resulting AST should be +// discarded if errors were encountered during parsing. +type parser struct { + file *token.File + diags diag.Diagnostics + scanner *scanner.Scanner + comments []ast.CommentGroup + + pos token.Pos // Current token position + tok token.Token // Current token + lit string // Current token literal + + // Position of the last error written. Two parse errors on the same line are + // ignored. + lastError token.Position +} + +// newParser creates a new parser which will parse the provided src. +func newParser(filename string, src []byte) *parser { + file := token.NewFile(filename) + + p := &parser{ + file: file, + } + + p.scanner = scanner.New(file, src, func(pos token.Pos, msg string) { + p.diags.Add(diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: file.PositionFor(pos), + Message: msg, + }) + }, scanner.IncludeComments) + + p.next() + return p +} + +// next advances the parser to the next non-comment token. +func (p *parser) next() { + p.next0() + + for p.tok == token.COMMENT { + p.consumeCommentGroup() + } +} + +// next0 advances the parser to the next token. next0 should not be used +// directly by parse methods; call next instead. +func (p *parser) next0() { p.pos, p.tok, p.lit = p.scanner.Scan() } + +// consumeCommentGroup consumes a group of adjacent comments, adding it to p's +// comment list. +func (p *parser) consumeCommentGroup() { + var list []*ast.Comment + + endline := p.pos.Position().Line + for p.tok == token.COMMENT && p.pos.Position().Line <= endline+1 { + var comment *ast.Comment + comment, endline = p.consumeComment() + list = append(list, comment) + } + + p.comments = append(p.comments, ast.CommentGroup(list)) +} + +// consumeComment consumes a comment and returns it with the line number it +// ends on. +func (p *parser) consumeComment() (comment *ast.Comment, endline int) { + endline = p.pos.Position().Line + + if p.lit[1] == '*' { + // Block comments may end on a different line than where they start. Scan + // the comment for newlines and adjust endline accordingly. + // + // NOTE: don't use range here, since range will unnecessarily decode + // Unicode code points and slow down the parser. + for i := 0; i < len(p.lit); i++ { + if p.lit[i] == '\n' { + endline++ + } + } + } + + comment = &ast.Comment{StartPos: p.pos, Text: p.lit} + p.next0() + return +} + +// advance consumes tokens up to (but not including) the specified token. +// advance will stop consuming tokens if EOF is reached before to. +func (p *parser) advance(to token.Token) { + for p.tok != token.EOF { + if p.tok == to { + return + } + p.next() + } +} + +// advanceAny consumes tokens up to (but not including) any of the tokens in +// the to set. +func (p *parser) advanceAny(to map[token.Token]struct{}) { + for p.tok != token.EOF { + if _, inSet := to[p.tok]; inSet { + return + } + p.next() + } +} + +// expect consumes the next token. It records an error if the consumed token +// was not t. +func (p *parser) expect(t token.Token) (pos token.Pos, tok token.Token, lit string) { + pos, tok, lit = p.pos, p.tok, p.lit + if tok != t { + p.addErrorf("expected %s, got %s", t, p.tok) + } + p.next() + return +} + +func (p *parser) addErrorf(format string, args ...interface{}) { + pos := p.file.PositionFor(p.pos) + + // Ignore errors which occur on the same line. + if p.lastError.Line == pos.Line { + return + } + p.lastError = pos + + p.diags.Add(diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: pos, + Message: fmt.Sprintf(format, args...), + }) +} + +// ParseFile parses an entire file. +// +// File = Body +func (p *parser) ParseFile() *ast.File { + body := p.parseBody(token.EOF) + + return &ast.File{ + Name: p.file.Name(), + Body: body, + Comments: p.comments, + } +} + +// parseBody parses a series of statements up to and including the "until" +// token, which terminates the body. +// +// Body = [ Statement { terminator Statement } ] +func (p *parser) parseBody(until token.Token) ast.Body { + var body ast.Body + + for p.tok != until && p.tok != token.EOF { + stmt := p.parseStatement() + if stmt != nil { + body = append(body, stmt) + } + + if p.tok == until { + break + } + + if p.tok != token.TERMINATOR { + p.addErrorf("expected %s, got %s", token.TERMINATOR, p.tok) + p.consumeStatement() + } + p.next() + } + + return body +} + +// consumeStatement consumes tokens for the remainder of a statement (i.e., up +// to but not including a terminator). consumeStatement will keep track of the +// number of {}, [], and () pairs, only returning after the count of pairs is +// <= 0. +func (p *parser) consumeStatement() { + var curlyPairs, brackPairs, parenPairs int + + for p.tok != token.EOF { + switch p.tok { + case token.LCURLY: + curlyPairs++ + case token.RCURLY: + curlyPairs-- + case token.LBRACK: + brackPairs++ + case token.RBRACK: + brackPairs-- + case token.LPAREN: + parenPairs++ + case token.RPAREN: + parenPairs-- + } + + if p.tok == token.TERMINATOR { + // Only return after we've consumed all pairs. It's possible for pairs to + // be less than zero if our statement started in a surrounding pair. + if curlyPairs <= 0 && brackPairs <= 0 && parenPairs <= 0 { + return + } + } + + p.next() + } +} + +// parseStatement parses an individual statement within a body. +// +// Statement = Attribute | Block +// Attribute = identifier "=" Expression +// Block = BlockName "{" Body "}" +func (p *parser) parseStatement() ast.Stmt { + blockName := p.parseBlockName() + if blockName == nil { + // parseBlockName failed; skip to the next identifier which would start a + // new Statement. + p.advance(token.IDENT) + return nil + } + + // p.tok is now the first token after the identifier in the attribute or + // block name. + switch p.tok { + case token.ASSIGN: // Attribute + p.next() // Consume "=" + + if len(blockName.Fragments) != 1 { + attrName := strings.Join(blockName.Fragments, ".") + p.diags.Add(diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: blockName.Start.Position(), + EndPos: blockName.Start.Add(len(attrName) - 1).Position(), + Message: `attribute names may only consist of a single identifier with no "."`, + }) + } else if blockName.LabelPos != token.NoPos { + p.diags.Add(diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: blockName.LabelPos.Position(), + // Add 1 to the end position to add in the end quote, which is stripped from the label value. + EndPos: blockName.LabelPos.Add(len(blockName.Label) + 1).Position(), + Message: `attribute names may not have labels`, + }) + } + + return &ast.AttributeStmt{ + Name: &ast.Ident{ + Name: blockName.Fragments[0], + NamePos: blockName.Start, + }, + Value: p.ParseExpression(), + } + + case token.LCURLY: // Block + block := &ast.BlockStmt{ + Name: blockName.Fragments, + NamePos: blockName.Start, + Label: blockName.Label, + LabelPos: blockName.LabelPos, + } + + block.LCurlyPos, _, _ = p.expect(token.LCURLY) + block.Body = p.parseBody(token.RCURLY) + block.RCurlyPos, _, _ = p.expect(token.RCURLY) + + return block + + default: + if blockName.ValidAttribute() { + // The blockname could be used for an attribute or a block (no label, + // only one name fragment), so inform the user of both cases. + p.addErrorf("expected attribute assignment or block body, got %s", p.tok) + } else { + p.addErrorf("expected block body, got %s", p.tok) + } + + // Give up on this statement and skip to the next identifier. + p.advance(token.IDENT) + return nil + } +} + +// parseBlockName parses the name used for a block. +// +// BlockName = identifier { "." identifier } [ string ] +func (p *parser) parseBlockName() *blockName { + if p.tok != token.IDENT { + p.addErrorf("expected identifier, got %s", p.tok) + return nil + } + + var bn blockName + + bn.Fragments = append(bn.Fragments, p.lit) // Append first identifier + bn.Start = p.pos + p.next() + + // { "." identifier } + for p.tok == token.DOT { + p.next() // consume "." + + if p.tok != token.IDENT { + p.addErrorf("expected identifier, got %s", p.tok) + + // Continue here to parse as much as possible, even though the block name + // will be malformed. + } + + bn.Fragments = append(bn.Fragments, p.lit) + p.next() + } + + // [ string ] + if p.tok != token.ASSIGN && p.tok != token.LCURLY { + if p.tok == token.STRING { + // Only allow double-quoted strings for block labels. + if p.lit[0] != '"' { + p.addErrorf("expected block label to be a double quoted string, but got %q", p.lit) + } + + // Strip the quotes if it's non-empty. We then require any non-empty + // label to be a valid identifier. + if len(p.lit) > 2 { + bn.Label = p.lit[1 : len(p.lit)-1] + if !scanner.IsValidIdentifier(bn.Label) { + p.addErrorf("expected block label to be a valid identifier, but got %q", bn.Label) + } + } + bn.LabelPos = p.pos + } else { + p.addErrorf("expected block label, got %s", p.tok) + } + p.next() + } + + return &bn +} + +type blockName struct { + Fragments []string // Name fragments (i.e., `a.b.c`) + Label string // Optional user label + + Start token.Pos + LabelPos token.Pos +} + +// ValidAttribute returns true if the blockName can be used as an attribute +// name. +func (n blockName) ValidAttribute() bool { + return len(n.Fragments) == 1 && n.Label == "" +} + +// ParseExpression parses a single expression. +// +// Expression = BinOpExpr +func (p *parser) ParseExpression() ast.Expr { + return p.parseBinOp(1) +} + +// parseBinOp is the entrypoint for binary expressions. If there is no binary +// expressions in the current state, a single operand will be returned instead. +// +// BinOpExpr = OrExpr +// OrExpr = AndExpr { "||" AndExpr } +// AndExpr = CmpExpr { "&&" CmpExpr } +// CmpExpr = AddExpr { cmp_op AddExpr } +// AddExpr = MulExpr { add_op MulExpr } +// MulExpr = PowExpr { mul_op PowExpr } +// +// parseBinOp avoids the need for multiple non-terminal functions by providing +// context for operator precedence in recursive calls. inPrec specifies the +// incoming operator precedence. On the first call to parseBinOp, inPrec should +// be 1. +// +// parseBinOp can only handle left-associative operators, so PowExpr is handled +// by parsePowExpr. +func (p *parser) parseBinOp(inPrec int) ast.Expr { + // The EBNF documented by the function can be generalized into: + // + // CurPrecExpr = NextPrecExpr { cur_prec_ops NextPrecExpr } + // + // The code below implements this specific grammar, continually collecting + // everything at the same precedence level into the LHS of the expression + // while recursively calling parseBinOp for higher-precedence operations. + + lhs := p.parsePowExpr() + + for { + tok, pos, prec := p.tok, p.pos, p.tok.BinaryPrecedence() + if prec < inPrec { + // The next operator is lower precedence; drop up a level in our call + // stack. + return lhs + } + p.next() // Consume the operator + + // Recurse with a higher precedence level, which ensures that operators at + // the same precedence level don't get handled in the recursive call. + rhs := p.parseBinOp(prec + 1) + + lhs = &ast.BinaryExpr{ + Left: lhs, + Kind: tok, + KindPos: pos, + Right: rhs, + } + } +} + +// parsePowExpr is like parseBinOp but handles the right-associative pow +// operator. +// +// PowExpr = UnaryExpr [ "^" PowExpr ] +func (p *parser) parsePowExpr() ast.Expr { + lhs := p.parseUnaryExpr() + + if p.tok == token.POW { + pos := p.pos + p.next() // Consume ^ + + return &ast.BinaryExpr{ + Left: lhs, + Kind: token.POW, + KindPos: pos, + Right: p.parsePowExpr(), + } + } + + return lhs +} + +// parseUnaryExpr parses a unary expression. +// +// UnaryExpr = OperExpr | unary_op UnaryExpr +// +// OperExpr = PrimaryExpr { AccessExpr | IndexExpr | CallExpr } +// AccessExpr = "." identifier +// IndexExpr = "[" Expression "]" +// CallExpr = "(" [ ExpressionList ] ")" +func (p *parser) parseUnaryExpr() ast.Expr { + if isUnaryOp(p.tok) { + op, pos := p.tok, p.pos + p.next() // Consume op + + return &ast.UnaryExpr{ + Kind: op, + KindPos: pos, + Value: p.parseUnaryExpr(), + } + } + + primary := p.parsePrimaryExpr() + +NextOper: + for { + switch p.tok { + case token.DOT: // AccessExpr + p.next() + namePos, _, name := p.expect(token.IDENT) + + primary = &ast.AccessExpr{ + Value: primary, + Name: &ast.Ident{ + Name: name, + NamePos: namePos, + }, + } + + case token.LBRACK: // IndexExpr + lBrack, _, _ := p.expect(token.LBRACK) + index := p.ParseExpression() + rBrack, _, _ := p.expect(token.RBRACK) + + primary = &ast.IndexExpr{ + Value: primary, + LBrackPos: lBrack, + Index: index, + RBrackPos: rBrack, + } + + case token.LPAREN: // CallExpr + var args []ast.Expr + + lParen, _, _ := p.expect(token.LPAREN) + if p.tok != token.RPAREN { + args = p.parseExpressionList(token.RPAREN) + } + rParen, _, _ := p.expect(token.RPAREN) + + primary = &ast.CallExpr{ + Value: primary, + LParenPos: lParen, + Args: args, + RParenPos: rParen, + } + + case token.STRING, token.LCURLY: + // A user might be trying to assign a block to an attribute. let's + // attempt to parse the remainder as a block to tell them something is + // wrong. + // + // If we can't parse the remainder of the expression as a block, we give + // up and parse the remainder of the entire statement. + if p.tok == token.STRING { + p.next() + } + if _, tok, _ := p.expect(token.LCURLY); tok != token.LCURLY { + p.consumeStatement() + return primary + } + p.parseBody(token.RCURLY) + + end, tok, _ := p.expect(token.RCURLY) + if tok != token.RCURLY { + p.consumeStatement() + return primary + } + + p.diags.Add(diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(primary).Position(), + EndPos: end.Position(), + Message: "cannot use a block as an expression", + }) + + default: + break NextOper + } + } + + return primary +} + +func isUnaryOp(tok token.Token) bool { + switch tok { + case token.NOT, token.SUB: + return true + default: + return false + } +} + +// parsePrimaryExpr parses a primary expression. +// +// PrimaryExpr = LiteralValue | ArrayExpr | ObjectExpr +// +// LiteralValue = identifier | string | number | float | bool | null | +// "(" Expression ")" +// +// ArrayExpr = "[" [ ExpressionList ] "]" +// ObjectExpr = "{" [ FieldList ] "}" +func (p *parser) parsePrimaryExpr() ast.Expr { + switch p.tok { + case token.IDENT: + res := &ast.IdentifierExpr{ + Ident: &ast.Ident{ + Name: p.lit, + NamePos: p.pos, + }, + } + p.next() + return res + + case token.STRING, token.NUMBER, token.FLOAT, token.BOOL, token.NULL: + res := &ast.LiteralExpr{ + Kind: p.tok, + Value: p.lit, + ValuePos: p.pos, + } + p.next() + return res + + case token.LPAREN: + lParen, _, _ := p.expect(token.LPAREN) + expr := p.ParseExpression() + rParen, _, _ := p.expect(token.RPAREN) + + return &ast.ParenExpr{ + LParenPos: lParen, + Inner: expr, + RParenPos: rParen, + } + + case token.LBRACK: + var res ast.ArrayExpr + + res.LBrackPos, _, _ = p.expect(token.LBRACK) + if p.tok != token.RBRACK { + res.Elements = p.parseExpressionList(token.RBRACK) + } + res.RBrackPos, _, _ = p.expect(token.RBRACK) + return &res + + case token.LCURLY: + var res ast.ObjectExpr + + res.LCurlyPos, _, _ = p.expect(token.LCURLY) + if p.tok != token.RBRACK { + res.Fields = p.parseFieldList(token.RCURLY) + } + res.RCurlyPos, _, _ = p.expect(token.RCURLY) + return &res + } + + p.addErrorf("expected expression, got %s", p.tok) + res := &ast.LiteralExpr{Kind: token.NULL, Value: "null", ValuePos: p.pos} + p.advanceAny(statementEnd) // Eat up the rest of the line + return res +} + +var statementEnd = map[token.Token]struct{}{ + token.TERMINATOR: {}, + token.RPAREN: {}, + token.RCURLY: {}, + token.RBRACK: {}, + token.COMMA: {}, +} + +// parseExpressionList parses a list of expressions. +// +// ExpressionList = Expression { "," Expression } [ "," ] +func (p *parser) parseExpressionList(until token.Token) []ast.Expr { + var exprs []ast.Expr + + for p.tok != until && p.tok != token.EOF { + exprs = append(exprs, p.ParseExpression()) + + if p.tok == until { + break + } + if p.tok != token.COMMA { + p.addErrorf("missing ',' in expression list") + } + p.next() + } + + return exprs +} + +// parseFieldList parses a list of fields in an object. +// +// FieldList = Field { "," Field } [ "," ] +func (p *parser) parseFieldList(until token.Token) []*ast.ObjectField { + var fields []*ast.ObjectField + + for p.tok != until && p.tok != token.EOF { + fields = append(fields, p.parseField()) + + if p.tok == until { + break + } + if p.tok != token.COMMA { + p.addErrorf("missing ',' in field list") + } + p.next() + } + + return fields +} + +// parseField parses a field in an object. +// +// Field = ( string | identifier ) "=" Expression +func (p *parser) parseField() *ast.ObjectField { + var field ast.ObjectField + + if p.tok == token.STRING || p.tok == token.IDENT { + field.Name = &ast.Ident{ + Name: p.lit, + NamePos: p.pos, + } + if p.tok == token.STRING && len(p.lit) > 2 { + // The field name is a string literal; unwrap the quotes. + field.Name.Name = p.lit[1 : len(p.lit)-1] + field.Quoted = true + } + p.next() // Consume field name + } else { + p.addErrorf("expected field name (string or identifier), got %s", p.tok) + p.advance(token.ASSIGN) + } + + p.expect(token.ASSIGN) + + field.Value = p.ParseExpression() + return &field +} diff --git a/syntax/parser/internal_test.go b/syntax/parser/internal_test.go new file mode 100644 index 0000000000..c3be1e7581 --- /dev/null +++ b/syntax/parser/internal_test.go @@ -0,0 +1,22 @@ +package parser + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestObjectFieldName(t *testing.T) { + tt := []string{ + `field_a = 5`, + `"field_a" = 5`, // Quotes should be removed from the field name + } + + for _, tc := range tt { + p := newParser(t.Name(), []byte(tc)) + + res := p.parseField() + + assert.Equal(t, "field_a", res.Name.Name) + } +} diff --git a/syntax/parser/parser.go b/syntax/parser/parser.go new file mode 100644 index 0000000000..66d2199d4b --- /dev/null +++ b/syntax/parser/parser.go @@ -0,0 +1,43 @@ +// Package parser implements utilities for parsing River configuration files. +package parser + +import ( + "github.com/grafana/river/ast" + "github.com/grafana/river/token" +) + +// ParseFile parses an entire River configuration file. The data parameter +// should hold the file contents to parse, while the filename parameter is used +// for reporting errors. +// +// If an error was encountered during parsing, the returned AST will be nil and +// err will be an diag.Diagnostics all the errors encountered during parsing. +func ParseFile(filename string, data []byte) (*ast.File, error) { + p := newParser(filename, data) + + f := p.ParseFile() + if len(p.diags) > 0 { + return nil, p.diags + } + return f, nil +} + +// ParseExpression parses a single River expression from expr. +// +// If an error was encountered during parsing, the returned expression will be +// nil and err will be an ErrorList with all the errors encountered during +// parsing. +func ParseExpression(expr string) (ast.Expr, error) { + p := newParser("", []byte(expr)) + + e := p.ParseExpression() + + // If the current token is not a TERMINATOR then the parsing did not complete + // in full and there are still parts of the string left unparsed. + p.expect(token.TERMINATOR) + + if len(p.diags) > 0 { + return nil, p.diags + } + return e, nil +} diff --git a/syntax/parser/parser_test.go b/syntax/parser/parser_test.go new file mode 100644 index 0000000000..f567c4650a --- /dev/null +++ b/syntax/parser/parser_test.go @@ -0,0 +1,123 @@ +package parser + +import ( + "io/fs" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func FuzzParser(f *testing.F) { + filepath.WalkDir("./testdata/valid", func(path string, d fs.DirEntry, _ error) error { + if d.IsDir() { + return nil + } + + bb, err := os.ReadFile(path) + require.NoError(f, err) + f.Add(bb) + return nil + }) + + f.Fuzz(func(t *testing.T, input []byte) { + p := newParser(t.Name(), input) + + _ = p.ParseFile() + if len(p.diags) > 0 { + t.SkipNow() + } + }) +} + +// TestValid parses every *.river file in testdata, which is expected to be +// valid. +func TestValid(t *testing.T) { + filepath.WalkDir("./testdata/valid", func(path string, d fs.DirEntry, _ error) error { + if d.IsDir() { + return nil + } + + t.Run(filepath.Base(path), func(t *testing.T) { + bb, err := os.ReadFile(path) + require.NoError(t, err) + + p := newParser(path, bb) + + res := p.ParseFile() + require.NotNil(t, res) + require.Len(t, p.diags, 0) + }) + + return nil + }) +} + +func TestParseExpressions(t *testing.T) { + tt := map[string]string{ + "literal number": `10`, + "literal float": `15.0`, + "literal string": `"Hello, world!"`, + "literal ident": `some_ident`, + "literal null": `null`, + "literal true": `true`, + "literal false": `false`, + + "empty array": `[]`, + "array one element": `[1]`, + "array many elements": `[0, 1, 2, 3]`, + "array trailing comma": `[0, 1, 2, 3,]`, + "nested array": `[[0, 1, 2], [3, 4, 5]]`, + "array multiline": `[ + 0, + 1, + 2, + ]`, + + "empty object": `{}`, + "object one field": `{ field_a = 5 }`, + "object multiple fields": `{ field_a = 5, field_b = 10 }`, + "object trailing comma": `{ field_a = 5, field_b = 10, }`, + "nested objects": `{ field_a = { nested_field = 100 } }`, + "object multiline": `{ + field_a = 5, + field_b = 10, + }`, + + "unary not": `!true`, + "unary neg": `-5`, + + "math": `1 + 2 - 3 * 4 / 5 % 6`, + "compare ops": `1 == 2 != 3 < 4 > 5 <= 6 >= 7`, + "logical ops": `true || false && true`, + "pow operator": "1 ^ 2 ^ 3", + + "field access": `a.b.c.d`, + "element access": `a[0][1][2]`, + + "call no args": `a()`, + "call one arg": `a(1)`, + "call multiple args": `a(1,2,3)`, + "call with trailing comma": `a(1,2,3,)`, + "call multiline": `a( + 1, + 2, + 3, + )`, + + "parens": `(1 + 5) * 100`, + + "mixed expression": `(a.b.c)(1, 3 * some_list[magic_index * 2]).resulting_field`, + } + + for name, input := range tt { + t.Run(name, func(t *testing.T) { + p := newParser(name, []byte(input)) + + res := p.ParseExpression() + require.NotNil(t, res) + require.Len(t, p.diags, 0) + }) + } +} diff --git a/syntax/parser/testdata/assign_block_to_attr.river b/syntax/parser/testdata/assign_block_to_attr.river new file mode 100644 index 0000000000..e291308599 --- /dev/null +++ b/syntax/parser/testdata/assign_block_to_attr.river @@ -0,0 +1,32 @@ +rw = prometheus/* ERROR "cannot use a block as an expression" */.remote_write "default" { + endpoint { + url = "some_url" + basic_auth { + username = "username" + password = "password" + } + } +} + +attr_1 = 15 +attr_2 = 51 + +block { + rw_2 = prometheus/* ERROR "cannot use a block as an expression" */.remote_write "other" { + endpoint { + url = "other_url" + basic_auth { + username = "username_2" + password = "password_2" + } + } + } +} + +other_block { + // This is an expression which looks like it might be a block at first, but + // then isn't. + rw_3 = prometheus.remote_write "other" "other" /* ERROR "expected {, got STRING" */ 12345 +} + +attr_3 = 15 diff --git a/syntax/parser/testdata/attribute_names.river b/syntax/parser/testdata/attribute_names.river new file mode 100644 index 0000000000..1e18c60850 --- /dev/null +++ b/syntax/parser/testdata/attribute_names.river @@ -0,0 +1,7 @@ +valid_attr = 15 + +// The parser parses block names for both blocks and attributes, and later +// validates that the attribute name is just a single identifier with no label. + +invalid/* ERROR "attribute names may only consist of a single identifier" */.attr = 20 +invalid "label" /* ERROR "attribute names may not have labels" */ = 20 diff --git a/syntax/parser/testdata/block_names.river b/syntax/parser/testdata/block_names.river new file mode 100644 index 0000000000..cc0f2040f3 --- /dev/null +++ b/syntax/parser/testdata/block_names.river @@ -0,0 +1,25 @@ +valid_block { + +} + +valid_block "labeled" { + +} + +invalid_block bad_label_name /* ERROR "expected block label, got IDENT" */ { + +} + +other_valid_block { + nested_block { + + } + + nested_block "labeled" { + + } +} + +invalid_block "with space" /* ERROR "expected block label to be a valid identifier" */ { + +} diff --git a/syntax/parser/testdata/commas.river b/syntax/parser/testdata/commas.river new file mode 100644 index 0000000000..a43bb5c873 --- /dev/null +++ b/syntax/parser/testdata/commas.river @@ -0,0 +1,13 @@ +// Test that missing trailing commas for multiline expressions get reported. + +field = [ + 0, + 1, + 2/* ERROR HERE "missing ',' in expression list" */ +] + +obj = { + field_a = 0, + field_b = 1, + field_c = 2/* ERROR HERE "missing ',' in field list" */ +} diff --git a/syntax/parser/testdata/fuzz/FuzzParser/1a39f4e358facc21678b16fad53537b46efdaa76e024a5ef4955d01a68bdac37 b/syntax/parser/testdata/fuzz/FuzzParser/1a39f4e358facc21678b16fad53537b46efdaa76e024a5ef4955d01a68bdac37 new file mode 100644 index 0000000000..6151b52921 --- /dev/null +++ b/syntax/parser/testdata/fuzz/FuzzParser/1a39f4e358facc21678b16fad53537b46efdaa76e024a5ef4955d01a68bdac37 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("A0000000000000000") diff --git a/syntax/parser/testdata/fuzz/FuzzParser/248cf4391f6c48550b7d2cf4c6c80f4ba9099c21ffa2b6869e75e99565dce037 b/syntax/parser/testdata/fuzz/FuzzParser/248cf4391f6c48550b7d2cf4c6c80f4ba9099c21ffa2b6869e75e99565dce037 new file mode 100644 index 0000000000..cc252cd81b --- /dev/null +++ b/syntax/parser/testdata/fuzz/FuzzParser/248cf4391f6c48550b7d2cf4c6c80f4ba9099c21ffa2b6869e75e99565dce037 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("A={A!0A\"") diff --git a/syntax/parser/testdata/fuzz/FuzzParser/b919fa00ebca318001778477c839a06204b55f2636597901d8d7878150d8580a b/syntax/parser/testdata/fuzz/FuzzParser/b919fa00ebca318001778477c839a06204b55f2636597901d8d7878150d8580a new file mode 100644 index 0000000000..ff4ab488f8 --- /dev/null +++ b/syntax/parser/testdata/fuzz/FuzzParser/b919fa00ebca318001778477c839a06204b55f2636597901d8d7878150d8580a @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("A\"") diff --git a/syntax/parser/testdata/invalid_exprs.river b/syntax/parser/testdata/invalid_exprs.river new file mode 100644 index 0000000000..c7ea3e4385 --- /dev/null +++ b/syntax/parser/testdata/invalid_exprs.river @@ -0,0 +1,4 @@ +attr = 1 + + /* ERROR "expected expression, got +" */ 2 + +invalid_func_call = a(() /* ERROR "expected expression, got \)" */) +invalid_access = a.true /* ERROR "expected IDENT, got BOOL" */ diff --git a/syntax/parser/testdata/invalid_object_key.river b/syntax/parser/testdata/invalid_object_key.river new file mode 100644 index 0000000000..1dd5a4ba8c --- /dev/null +++ b/syntax/parser/testdata/invalid_object_key.river @@ -0,0 +1,9 @@ +obj { + map = { + "string_field" = "foo", + identifier_string = "bar", + 1337 /* ERROR "expected field name \(string or identifier\), got NUMBER" */ = "baz", + "another_field" = "qux", + } +} + diff --git a/syntax/parser/testdata/valid/attribute.river b/syntax/parser/testdata/valid/attribute.river new file mode 100644 index 0000000000..9cc731b27f --- /dev/null +++ b/syntax/parser/testdata/valid/attribute.river @@ -0,0 +1 @@ +number_field = 1 diff --git a/syntax/parser/testdata/valid/blocks.river b/syntax/parser/testdata/valid/blocks.river new file mode 100644 index 0000000000..74d79bfe67 --- /dev/null +++ b/syntax/parser/testdata/valid/blocks.river @@ -0,0 +1,36 @@ +one_ident { + number_field = 1 +} + +one_ident "labeled" { + number_field = 1 +} + +multiple.idents { + number_field = 1 +} + +multiple.idents "labeled" { + number_field = 1 +} + +chain.of.idents { + number_field = 1 +} + +chain.of.idents "labeled" { + number_field = 1 +} + +one_ident_inline { number_field = 1 } +one_ident_inline "labeled" { number_field = 1 } +multiple.idents_inline { number_field = 1 } +multiple.idents_inline "labeled" { number_field = 1 } +chain.of.idents { number_field = 1 } +chain.of.idents "labeled" { number_field = 1 } + +nested_block { + inner_block { + some_field = true + } +} diff --git a/syntax/parser/testdata/valid/comments.river b/syntax/parser/testdata/valid/comments.river new file mode 100644 index 0000000000..9dafbd7346 --- /dev/null +++ b/syntax/parser/testdata/valid/comments.river @@ -0,0 +1 @@ +// Hello, world! diff --git a/syntax/parser/testdata/valid/empty.river b/syntax/parser/testdata/valid/empty.river new file mode 100644 index 0000000000..e69de29bb2 diff --git a/syntax/parser/testdata/valid/expressions.river b/syntax/parser/testdata/valid/expressions.river new file mode 100644 index 0000000000..78af48cc22 --- /dev/null +++ b/syntax/parser/testdata/valid/expressions.river @@ -0,0 +1,81 @@ +// Literals +lit_number = 10 +lit_float = 15.0 +lit_string = "Hello, world!" +lit_ident = other_ident +lit_null = null +lit_true = true +lit_false = false + +// Arrays +array_expr_empty = [] +array_expr_one_element = [0] +array_expr = [0, 1, 2, 3] +array_expr_trailing = [0, 1, 2, 3,] +array_expr_multiline = [ + 0, + 1, + 2, + 3, +] +array_expr_nested = [[1]] + +// Objects +object_expr_empty = {} +object_expr_one_field = { field_a = 1 } +object_expr = { field_a = 1, field_b = 2 } +object_expr_trailing = { field_a = 1, field_b = 2, } +object_expr_multiline = { + field_a = 1, + field_b = 2, +} +object_expr_nested = { field_a = { nested_field_a = 1 } } + +// Unary ops +not_something = !true +neg_number = -5 + +// Math binops +binop_sum = 1 + 2 +binop_sub = 1 - 2 +binop_mul = 1 * 2 +binop_div = 1 / 2 +binop_mod = 1 % 2 +binop_pow = 1 ^ 2 ^ 3 + +// Compare binops +binop_eq = 1 == 2 +binop_neq = 1 != 2 +binop_lt = 1 < 2 +binop_lte = 1 <= 2 +binop_gt = 1 > 2 +binop_gte = 1 >= 2 + +// Logical binops +binop_or = true || false +binop_and = true && false + + +// Mixed math operations +math = 1 + 2 - 3 * 4 / 5 % 6 +compare_ops = 1 == 2 != 3 < 4 > 5 <= 6 >= 7 +logical_ops = true || false && true +mixed_assoc = 1 * 3 + 5 ^ 3 - 2 % 1 // Test with both left- and right- associative operators +expr_parens = (5 * 2) + 5 + +// Accessors +field_access = a.b.c.d +element_access = a[0][1][2] + +// Function calls +call_no_args = a() +call_one_arg = a(1) +call_multiple_args = a(1,2,3) +call_trailing_comma = a(1,2,3,) +call_multiline = a( + 1, + 2, + 3, +) + +mixed_expr = (a.b.c)(1, 3 * some_list[magic_index * 2]).resulting_field diff --git a/syntax/printer/printer.go b/syntax/printer/printer.go new file mode 100644 index 0000000000..8faeb22006 --- /dev/null +++ b/syntax/printer/printer.go @@ -0,0 +1,556 @@ +// Package printer contains utilities for pretty-printing River ASTs. +package printer + +import ( + "fmt" + "io" + "math" + "text/tabwriter" + + "github.com/grafana/river/ast" + "github.com/grafana/river/token" +) + +// Config configures behavior of the printer. +type Config struct { + Indent int // Indentation to apply to all emitted code. Default 0. +} + +// Fprint pretty-prints the specified node to w. The Node type must be an +// *ast.File, ast.Body, or a type that implements ast.Stmt or ast.Expr. +func (c *Config) Fprint(w io.Writer, node ast.Node) (err error) { + var p printer + p.Init(c) + + // Pass all of our text through a trimmer to ignore trailing whitespace. + w = &trimmer{next: w} + + if err = (&walker{p: &p}).Walk(node); err != nil { + return + } + + // Call flush one more time to write trailing comments. + p.flush(token.Position{ + Offset: math.MaxInt, + Line: math.MaxInt, + Column: math.MaxInt, + }, token.EOF) + + w = tabwriter.NewWriter(w, 0, 8, 1, ' ', tabwriter.DiscardEmptyColumns|tabwriter.TabIndent) + + if _, err = w.Write(p.output); err != nil { + return + } + if tw, _ := w.(*tabwriter.Writer); tw != nil { + // Flush tabwriter if defined + err = tw.Flush() + } + + return +} + +// Fprint pretty-prints the specified node to w. The Node type must be an +// *ast.File, ast.Body, or a type that implements ast.Stmt or ast.Expr. +func Fprint(w io.Writer, node ast.Node) error { + c := &Config{} + return c.Fprint(w, node) +} + +// The printer writes lexical tokens and whitespace to an internal buffer. +// Comments are written by the printer itself, while all other tokens and +// formatting characters are sent through calls to Write. +// +// Internally, printer depends on a tabwriter for formatting text and aligning +// runs of characters. Horizontal '\t' and vertical '\v' tab characters are +// used to introduce new columns in the row. Runs of characters are stopped +// be either introducing a linefeed '\f' or by having a line with a different +// number of columns from the previous line. See the text/tabwriter package for +// more information on the elastic tabstop algorithm it uses for formatting +// text. +type printer struct { + cfg Config + + // State variables + + output []byte + indent int // Current indentation level + lastTok token.Token // Last token printed (token.LITERAL if it's whitespace) + + // Whitespace holds a buffer of whitespace characters to print prior to the + // next non-whitespace token. Whitespace is held in a buffer to avoid + // printing unnecessary whitespace at the end of a file. + whitespace []whitespace + + // comments stores comments to be processed as elements are printed. + comments commentInfo + + // pos is an approximation of the current position in AST space, and is used + // to determine space between AST elements (e.g., if a comment should come + // before a token). pos automatically as elements are written and can be manually + // set to guarantee an accurate position by passing a token.Pos to Write. + pos token.Position + last token.Position // Last pos written to output (through writeString) + + // out is an accurate representation of the current position in output space, + // used to inject extra formatting like indentation based on the output + // position. + // + // out may differ from pos in terms of whitespace. + out token.Position +} + +type commentInfo struct { + list []ast.CommentGroup + idx int + cur ast.CommentGroup + pos token.Pos +} + +func (ci *commentInfo) commentBefore(next token.Position) bool { + return ci.pos != token.NoPos && ci.pos.Offset() <= next.Offset +} + +// nextComment preloads the next comment. +func (ci *commentInfo) nextComment() { + for ci.idx < len(ci.list) { + c := ci.list[ci.idx] + ci.idx++ + if len(c) > 0 { + ci.cur = c + ci.pos = ast.StartPos(c[0]) + return + } + } + ci.pos = token.NoPos +} + +// Init initializes the printer for printing. Init is intended to be called +// once per printer and doesn't fully reset its state. +func (p *printer) Init(cfg *Config) { + p.cfg = *cfg + p.pos = token.Position{Line: 1, Column: 1} + p.out = token.Position{Line: 1, Column: 1} + // Capacity is set low since most whitespace sequences are short. + p.whitespace = make([]whitespace, 0, 16) +} + +// SetComments set the comments to use. +func (p *printer) SetComments(comments []ast.CommentGroup) { + p.comments = commentInfo{ + list: comments, + idx: 0, + pos: token.NoPos, + } + p.comments.nextComment() +} + +// Write writes a list of writable arguments to the printer. +// +// Arguments can be one of the types described below: +// +// If arg is a whitespace value, it is accumulated into a buffer and flushed +// only after a non-whitespace value is processed. The whitespace buffer will +// be forcibly flushed if the buffer becomes full without writing a +// non-whitespace token. +// +// If arg is an *ast.IdentifierExpr, *ast.LiteralExpr, or a token.Token, the +// human-readable representation of that value will be written. +// +// When writing text, comments which need to appear before that text in +// AST-space are written first, followed by leftover whitespace and then the +// text to write. The written text will update the AST-space position. +// +// If arg is a token.Pos, the AST-space position of the printer is updated to +// the provided Pos. Writing token.Pos values can help make sure the printer's +// AST-space position is accurate, as AST-space position is otherwise an +// estimation based on written data. +func (p *printer) Write(args ...interface{}) { + for _, arg := range args { + var ( + data string + isLit bool + ) + + switch arg := arg.(type) { + case whitespace: + // Whitespace token; add it to our token buffer. Note that a whitespace + // token is different than the actual whitespace which will get written + // (e.g., wsIndent increases indentation level by one instead of setting + // it to one.) + if arg == wsIgnore { + continue + } + i := len(p.whitespace) + if i == cap(p.whitespace) { + // We built up too much whitespace; this can happen if too many calls + // to Write happen without appending a non-comment token. We will + // force-flush the existing whitespace to avoid a panic. + // + // Ideally this line is never hit based on how we walk the AST, but + // it's kept for safety. + p.writeWritespace(i) + i = 0 + } + p.whitespace = p.whitespace[0 : i+1] + p.whitespace[i] = arg + p.lastTok = token.LITERAL + continue + + case *ast.Ident: + data = arg.Name + p.lastTok = token.IDENT + + case *ast.LiteralExpr: + data = arg.Value + p.lastTok = arg.Kind + + case token.Pos: + if arg.Valid() { + p.pos = arg.Position() + } + // Don't write anything; token.Pos is an instruction and doesn't include + // any text to write. + continue + + case token.Token: + s := arg.String() + data = s + + // We will need to inject whitespace if the previous token and the + // current token would combine into a single token when re-scanned. This + // ensures that the sequence of tokens emitted by the output of the + // printer match the sequence of tokens from the input. + if mayCombine(p.lastTok, s[0]) { + if len(p.whitespace) != 0 { + // It shouldn't be possible for the whitespace buffer to be not empty + // here; p.lastTok would've had to been a non-whitespace token and so + // whitespace would've been flushed when it was written to the output + // buffer. + panic("whitespace buffer not empty") + } + p.whitespace = p.whitespace[0:1] + p.whitespace[0] = ' ' + } + p.lastTok = arg + + default: + panic(fmt.Sprintf("printer: unsupported argument %v (%T)\n", arg, arg)) + } + + next := p.pos + + p.flush(next, p.lastTok) + p.writeString(next, data, isLit) + } +} + +// mayCombine returns true if two tokes must not be combined, because combining +// them would format in a different token sequence being generated. +func mayCombine(prev token.Token, next byte) (b bool) { + switch prev { + case token.NUMBER: + return next == '.' // 1. + case token.DIV: + return next == '*' // /* + default: + return false + } +} + +// flush prints any pending comments and whitespace occurring textually before +// the position of the next token tok. The flush result indicates if a newline +// was written or if a formfeed \f character was dropped from the whitespace +// buffer. +func (p *printer) flush(next token.Position, tok token.Token) { + if p.comments.commentBefore(next) { + p.injectComments(next, tok) + } else if tok != token.EOF { + // Write all remaining whitespace. + p.writeWritespace(len(p.whitespace)) + } +} + +func (p *printer) injectComments(next token.Position, tok token.Token) { + var lastComment *ast.Comment + + for p.comments.commentBefore(next) { + for _, c := range p.comments.cur { + p.writeCommentPrefix(next, c) + p.writeComment(next, c) + lastComment = c + } + p.comments.nextComment() + } + + p.writeCommentSuffix(next, tok, lastComment) +} + +// writeCommentPrefix writes whitespace that should appear before c. +func (p *printer) writeCommentPrefix(next token.Position, c *ast.Comment) { + if len(p.output) == 0 { + // The comment is the first thing written to the output. Don't write any + // whitespace before it. + return + } + + cPos := c.StartPos.Position() + + if cPos.Line == p.last.Line { + // Our comment is on the same line as the last token. Write a separator + // between the last token and the comment. + separator := byte('\t') + if cPos.Line == next.Line { + // The comment is on the same line as the next token, which means it has + // to be a block comment (since line comments run to the end of the + // line.) Use a space as the separator instead since a tab in the middle + // of a line between comments would look weird. + separator = byte(' ') + } + p.writeByte(separator, 1) + } else { + // Our comment is on a different line from the last token. First write + // pending whitespace from the last token up to the first newline. + var wsCount int + + for i, ws := range p.whitespace { + switch ws { + case wsBlank, wsVTab: + // Drop any whitespace before the comment. + p.whitespace[i] = wsIgnore + case wsIndent, wsUnindent: + // Allow indentation to be applied. + continue + case wsNewline, wsFormfeed: + // Drop the whitespace since we're about to write our own. + p.whitespace[i] = wsIgnore + } + wsCount = i + break + } + p.writeWritespace(wsCount) + + var newlines int + if cPos.Valid() && p.last.Valid() { + newlines = cPos.Line - p.last.Line + } + if newlines > 0 { + p.writeByte('\f', newlineLimit(newlines)) + } + } +} + +func (p *printer) writeComment(_ token.Position, c *ast.Comment) { + p.writeString(c.StartPos.Position(), c.Text, true) +} + +// writeCommentSuffix writes any whitespace necessary between the last comment +// and next. lastComment should be the final comment written. +func (p *printer) writeCommentSuffix(next token.Position, tok token.Token, lastComment *ast.Comment) { + if tok == token.EOF { + // We don't want to add any blank newlines before the end of the file; + // return early. + return + } + + var droppedFF bool + + // If our final comment is a block comment and is on the same line as the + // next token, add a space as a suffix to separate them. + lastCommentPos := ast.EndPos(lastComment).Position() + if lastComment.Text[1] == '*' && next.Line == lastCommentPos.Line { + p.writeByte(' ', 1) + } + + newlines := next.Line - p.last.Line + + for i, ws := range p.whitespace { + switch ws { + case wsBlank, wsVTab: + p.whitespace[i] = wsIgnore + case wsIndent, wsUnindent: + continue + case wsNewline, wsFormfeed: + if ws == wsFormfeed { + droppedFF = true + } + p.whitespace[i] = wsIgnore + } + } + + p.writeWritespace(len(p.whitespace)) + + // Write newlines as long as the next token isn't EOF (so that there's no + // blank newlines at the end of the file). + if newlines > 0 { + ch := byte('\n') + if droppedFF { + // If we dropped a formfeed while writing comments, we should emit a new + // one. + ch = byte('\f') + } + p.writeByte(ch, newlineLimit(newlines)) + } +} + +// writeString writes the literal string s into the printer's output. +// Formatting characters in s such as '\t' and '\n' will be interpreted by +// underlying tabwriter unless isLit is set. +func (p *printer) writeString(pos token.Position, s string, isLit bool) { + if p.out.Column == 1 { + // We haven't written any text to this line yet; prepend our indentation + // for the line. + p.writeIndent() + } + + if pos.Valid() { + // Update p.pos if pos is valid. This is done *after* handling indentation + // since we want to interpret pos as the literal position for s (and + // writeIndent will update p.pos). + p.pos = pos + } + + if isLit { + // Wrap our literal string in tabwriter.Escape if it's meant to be written + // without interpretation by the tabwriter. + p.output = append(p.output, tabwriter.Escape) + + defer func() { + p.output = append(p.output, tabwriter.Escape) + }() + } + + p.output = append(p.output, s...) + + var ( + newlines int + lastNewlineIdx int + ) + + for i := 0; i < len(s); i++ { + if ch := s[i]; ch == '\n' || ch == '\f' { + newlines++ + lastNewlineIdx = i + } + } + + p.pos.Offset += len(s) + + if newlines > 0 { + p.pos.Line += newlines + p.out.Line += newlines + + newColumn := len(s) - lastNewlineIdx + p.pos.Column = newColumn + p.out.Column = newColumn + } else { + p.pos.Column += len(s) + p.out.Column += len(s) + } + + p.last = p.pos +} + +func (p *printer) writeIndent() { + depth := p.cfg.Indent + p.indent + for i := 0; i < depth; i++ { + p.output = append(p.output, '\t') + } + + p.pos.Offset += depth + p.pos.Column += depth + p.out.Column += depth +} + +// writeByte writes ch n times to the output, updating the position of the +// printer. writeByte is only used for writing whitespace characters. +func (p *printer) writeByte(ch byte, n int) { + if p.out.Column == 1 { + p.writeIndent() + } + + for i := 0; i < n; i++ { + p.output = append(p.output, ch) + } + + // Update positions. + p.pos.Offset += n + if ch == '\n' || ch == '\f' { + p.pos.Line += n + p.out.Line += n + p.pos.Column = 1 + p.out.Column = 1 + return + } + p.pos.Column += n + p.out.Column += n +} + +// writeWhitespace writes the first n whitespace entries in the whitespace +// buffer. +// +// writeWritespace is only safe to be called when len(p.whitespace) >= n. +func (p *printer) writeWritespace(n int) { + for i := 0; i < n; i++ { + switch ch := p.whitespace[i]; ch { + case wsIgnore: // no-op + case wsIndent: + p.indent++ + case wsUnindent: + p.indent-- + if p.indent < 0 { + panic("printer: negative indentation") + } + default: + p.writeByte(byte(ch), 1) + } + } + + // Shift remaining entries down + l := copy(p.whitespace, p.whitespace[n:]) + p.whitespace = p.whitespace[:l] +} + +const maxNewlines = 2 + +// newlineLimit limits a newline count to maxNewlines. +func newlineLimit(count int) int { + if count > maxNewlines { + count = maxNewlines + } + return count +} + +// whitespace represents a whitespace token to write to the printer's internal +// buffer. +type whitespace byte + +const ( + wsIgnore = whitespace(0) + wsBlank = whitespace(' ') + wsVTab = whitespace('\v') + wsNewline = whitespace('\n') + wsFormfeed = whitespace('\f') + wsIndent = whitespace('>') + wsUnindent = whitespace('<') +) + +func (ws whitespace) String() string { + switch ws { + case wsIgnore: + return "wsIgnore" + case wsBlank: + return "wsBlank" + case wsVTab: + return "wsVTab" + case wsNewline: + return "wsNewline" + case wsFormfeed: + return "wsFormfeed" + case wsIndent: + return "wsIndent" + case wsUnindent: + return "wsUnindent" + default: + return fmt.Sprintf("whitespace(%d)", ws) + } +} diff --git a/syntax/printer/printer_test.go b/syntax/printer/printer_test.go new file mode 100644 index 0000000000..38f69217b1 --- /dev/null +++ b/syntax/printer/printer_test.go @@ -0,0 +1,77 @@ +package printer_test + +import ( + "bytes" + "io/fs" + "os" + "path/filepath" + "strings" + "testing" + "unicode" + + "github.com/grafana/river/parser" + "github.com/grafana/river/printer" + "github.com/stretchr/testify/require" +) + +func TestPrinter(t *testing.T) { + filepath.WalkDir("testdata", func(path string, d fs.DirEntry, _ error) error { + if d.IsDir() { + return nil + } + + if strings.HasSuffix(path, ".in") { + inputFile := path + expectFile := strings.TrimSuffix(path, ".in") + ".expect" + expectErrorFile := strings.TrimSuffix(path, ".in") + ".error" + + caseName := filepath.Base(path) + caseName = strings.TrimSuffix(caseName, ".in") + + t.Run(caseName, func(t *testing.T) { + testPrinter(t, inputFile, expectFile, expectErrorFile) + }) + } + + return nil + }) +} + +func testPrinter(t *testing.T, inputFile string, expectFile string, expectErrorFile string) { + inputBB, err := os.ReadFile(inputFile) + require.NoError(t, err) + + f, err := parser.ParseFile(t.Name()+".rvr", inputBB) + if expectedError := getExpectedErrorMessage(t, expectErrorFile); expectedError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), expectedError) + return + } + + expectBB, err := os.ReadFile(expectFile) + require.NoError(t, err) + + var buf bytes.Buffer + require.NoError(t, printer.Fprint(&buf, f)) + + trimmed := strings.TrimRightFunc(string(expectBB), unicode.IsSpace) + require.Equal(t, trimmed, buf.String(), "%s", buf.String()) +} + +// getExpectedErrorMessage will retrieve an optional expected error message for the test. +func getExpectedErrorMessage(t *testing.T, errorFile string) string { + if _, err := os.Stat(errorFile); err == nil { + errorBytes, err := os.ReadFile(errorFile) + require.NoError(t, err) + errorsString := string(normalizeLineEndings(errorBytes)) + return errorsString + } + + return "" +} + +// normalizeLineEndings will replace '\r\n' with '\n'. +func normalizeLineEndings(data []byte) []byte { + normalized := bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'}) + return normalized +} diff --git a/syntax/printer/testdata/.gitattributes b/syntax/printer/testdata/.gitattributes new file mode 100644 index 0000000000..7949f2b32c --- /dev/null +++ b/syntax/printer/testdata/.gitattributes @@ -0,0 +1 @@ +* -text eol=lf diff --git a/syntax/printer/testdata/array_comments.expect b/syntax/printer/testdata/array_comments.expect new file mode 100644 index 0000000000..a9e921b6dc --- /dev/null +++ b/syntax/printer/testdata/array_comments.expect @@ -0,0 +1,17 @@ +// array_comments.in expects that comments in arrays are formatted to +// retain the indentation level of elements within the arrays. + +attr = [ // Inline comment + 0, 1, 2, // Inline comment + 3, 4, 5, // Inline comment + // Trailing comment +] + +attr = [ + 0, + // Element-level comment + 1, + // Element-level comment + 2, + // Trailing comment +] diff --git a/syntax/printer/testdata/array_comments.in b/syntax/printer/testdata/array_comments.in new file mode 100644 index 0000000000..e088e83749 --- /dev/null +++ b/syntax/printer/testdata/array_comments.in @@ -0,0 +1,17 @@ +// array_comments.in expects that comments in arrays are formatted to +// retain the indentation level of elements within the arrays. + +attr = [ // Inline comment + 0, 1, 2, // Inline comment + 3, 4, 5, // Inline comment + // Trailing comment +] + +attr = [ + 0, + // Element-level comment + 1, + // Element-level comment + 2, + // Trailing comment +] diff --git a/syntax/printer/testdata/block_comments.expect b/syntax/printer/testdata/block_comments.expect new file mode 100644 index 0000000000..ca60e0796c --- /dev/null +++ b/syntax/printer/testdata/block_comments.expect @@ -0,0 +1,62 @@ +// block_comments.in expects that comments within blocks are formatted to +// remain within the block with the proper indentation. + +// +// Unlabeled blocks +// + +// Comment is on same line as empty block header. +block { // comment +} + +// Comment is on same line as non-empty block header. +block { // comment + attr = 5 +} + +// Comment is alone in block body. +block { + // comment +} + +// Comment is before a statement. +block { + // comment + attr = 5 +} + +// Comment is after a statement. +block { + attr = 5 + // comment +} + +// +// Labeled blocks +// + +// Comment is on same line as empty block header. +block "label" { // comment +} + +// Comment is on same line as non-empty block header. +block "label" { // comment + attr = 5 +} + +// Comment is alone in block body. +block "label" { + // comment +} + +// Comment is before a statement. +block "label" { + // comment + attr = 5 +} + +// Comment is after a statement. +block "label" { + attr = 5 + // comment +} diff --git a/syntax/printer/testdata/block_comments.in b/syntax/printer/testdata/block_comments.in new file mode 100644 index 0000000000..d60512c9a3 --- /dev/null +++ b/syntax/printer/testdata/block_comments.in @@ -0,0 +1,64 @@ +// block_comments.in expects that comments within blocks are formatted to +// remain within the block with the proper indentation. + +// +// Unlabeled blocks +// + +// Comment is on same line as empty block header. +block { // comment +} + +// Comment is on same line as non-empty block header. +block { // comment + attr = 5 +} + +// Comment is alone in block body. +block { +// comment +} + +// Comment is before a statement. +block { +// comment + attr = 5 +} + +// Comment is after a statement. +block { + attr = 5 +// comment +} + +// +// Labeled blocks +// + +// Comment is on same line as empty block header. +block "label" { // comment +} + +// Comment is on same line as non-empty block header. +block "label" { // comment + attr = 5 +} + +// Comment is alone in block body. +block "label" { +// comment +} + +// Comment is before a statement. +block "label" { +// comment + attr = 5 +} + +// Comment is after a statement. +block "label" { + attr = 5 +// comment +} + + diff --git a/syntax/printer/testdata/example.expect b/syntax/printer/testdata/example.expect new file mode 100644 index 0000000000..dd6ab8e1f2 --- /dev/null +++ b/syntax/printer/testdata/example.expect @@ -0,0 +1,60 @@ +// This file tests a little bit of everything that the formatter should do. For +// example, this block of comments itself ensures that the output retains +// comments found in the source file. + +// +// Whitespace tests +// + +// Attributes should be given whitespace +attr_1 = 15 +attr_2 = 30 * 2 + 5 +attr_3 = field.access * 2 + +// Blocks with nothing inside of them should be truncated. +empty.block { } + +empty.block "labeled" { } + +// +// Alignment tests +// + +// Sequences of attributes which aren't separated by a blank line should have +// the equal sign aligned. +short_name = true +really_long_name = true + +extremely_long_name = true + +// Sequences of comments on aligned lines should also be aligned. +short_name = "short value" // Align me +really_long_name = "really long value" // Align me + +extremely_long_name = true // Unaligned + +// +// Indentation tests +// + +// Array literals, object literals, and blocks should all be indented properly. +multiline_array = [ + 0, + 1, +] + +mulitiline_object = { + foo = "bar", +} + +some_block { + attr = 15 + + inner_block { + attr = 20 + } +} + +// Trailing comments should be retained in the output. If this comment gets +// trimmed out, it usually indicates that a final flush is missing after +// traversing the AST. diff --git a/syntax/printer/testdata/example.in b/syntax/printer/testdata/example.in new file mode 100644 index 0000000000..efce00a8a3 --- /dev/null +++ b/syntax/printer/testdata/example.in @@ -0,0 +1,64 @@ +// This file tests a little bit of everything that the formatter should do. For +// example, this block of comments itself ensures that the output retains +// comments found in the source file. + +// +// Whitespace tests +// + +// Attributes should be given whitespace +attr_1=15 +attr_2=30*2+5 +attr_3=field.access*2 + +// Blocks with nothing inside of them should be truncated. +empty.block { + +} + +empty.block "labeled" { + +} + +// +// Alignment tests +// + +// Sequences of attributes which aren't separated by a blank line should have +// the equal sign aligned. +short_name = true +really_long_name = true + +extremely_long_name = true + +// Sequences of comments on aligned lines should also be aligned. +short_name = "short value" // Align me +really_long_name = "really long value" // Align me + +extremely_long_name = true // Unaligned + +// +// Indentation tests +// + +// Array literals, object literals, and blocks should all be indented properly. +multiline_array = [ +0, +1, +] + +mulitiline_object = { +foo = "bar", +} + +some_block { +attr = 15 + +inner_block { +attr = 20 +} +} + +// Trailing comments should be retained in the output. If this comment gets +// trimmed out, it usually indicates that a final flush is missing after +// traversing the AST. diff --git a/syntax/printer/testdata/func_call.expect b/syntax/printer/testdata/func_call.expect new file mode 100644 index 0000000000..e42c7acaf7 --- /dev/null +++ b/syntax/printer/testdata/func_call.expect @@ -0,0 +1,17 @@ +one_line = some_func(1, 2, 3, 4) + +multi_line = some_func(1, + 2, 3, + 4) + +multi_line_pretty = some_func( + 1, + 2, + 3, + 4, +) + +func_with_obj = some_func({ + key1 = "value1", + key2 = "value2", +}) diff --git a/syntax/printer/testdata/func_call.in b/syntax/printer/testdata/func_call.in new file mode 100644 index 0000000000..141a6057cf --- /dev/null +++ b/syntax/printer/testdata/func_call.in @@ -0,0 +1,17 @@ +one_line = some_func(1, 2, 3, 4) + +multi_line = some_func(1, +2, 3, +4) + +multi_line_pretty = some_func( +1, +2, +3, +4, +) + +func_with_obj = some_func({ + key1 = "value1", + key2 = "value2", +}) diff --git a/syntax/printer/testdata/mixed_list.expect b/syntax/printer/testdata/mixed_list.expect new file mode 100644 index 0000000000..7ce6f28dfb --- /dev/null +++ b/syntax/printer/testdata/mixed_list.expect @@ -0,0 +1,16 @@ +mixed_list = [0, true, { + key_1 = true, + key_2 = true, + key_3 = true, +}, "Hello!"] + +mixed_list_2 = [ + 0, + true, + { + key_1 = true, + key_2 = true, + key_3 = true, + }, + "Hello!", +] diff --git a/syntax/printer/testdata/mixed_list.in b/syntax/printer/testdata/mixed_list.in new file mode 100644 index 0000000000..7ce6f28dfb --- /dev/null +++ b/syntax/printer/testdata/mixed_list.in @@ -0,0 +1,16 @@ +mixed_list = [0, true, { + key_1 = true, + key_2 = true, + key_3 = true, +}, "Hello!"] + +mixed_list_2 = [ + 0, + true, + { + key_1 = true, + key_2 = true, + key_3 = true, + }, + "Hello!", +] diff --git a/syntax/printer/testdata/mixed_object.expect b/syntax/printer/testdata/mixed_object.expect new file mode 100644 index 0000000000..8d301e1f0e --- /dev/null +++ b/syntax/printer/testdata/mixed_object.expect @@ -0,0 +1,8 @@ +mixed_object = { + key_1 = true, + key_2 = [0, true, { + inner_1 = true, + inner_2 = true, + }], +} + diff --git a/syntax/printer/testdata/mixed_object.in b/syntax/printer/testdata/mixed_object.in new file mode 100644 index 0000000000..9334cfafa1 --- /dev/null +++ b/syntax/printer/testdata/mixed_object.in @@ -0,0 +1,7 @@ +mixed_object = { + key_1 = true, + key_2 = [0, true, { + inner_1 = true, + inner_2 = true, + }], +} diff --git a/syntax/printer/testdata/object_align.expect b/syntax/printer/testdata/object_align.expect new file mode 100644 index 0000000000..631536a437 --- /dev/null +++ b/syntax/printer/testdata/object_align.expect @@ -0,0 +1,11 @@ +block { + some_object = { + key_1 = 5, + long_key = 10, + longer_key = { + inner_key = true, + inner_key_2 = false, + }, + other_key = [0, 1, 2], + } +} diff --git a/syntax/printer/testdata/object_align.in b/syntax/printer/testdata/object_align.in new file mode 100644 index 0000000000..b618f61dfc --- /dev/null +++ b/syntax/printer/testdata/object_align.in @@ -0,0 +1,11 @@ +block { + some_object = { + key_1 = 5, + long_key = 10, + longer_key = { + inner_key = true, + inner_key_2 = false, + }, + other_key = [0, 1, 2], + } +} diff --git a/syntax/printer/testdata/oneline_block.expect b/syntax/printer/testdata/oneline_block.expect new file mode 100644 index 0000000000..8cd2c69d25 --- /dev/null +++ b/syntax/printer/testdata/oneline_block.expect @@ -0,0 +1,11 @@ +block { } + +block { } + +block { } + +block { } + +block { + // Comments should be kept. +} diff --git a/syntax/printer/testdata/oneline_block.in b/syntax/printer/testdata/oneline_block.in new file mode 100644 index 0000000000..2c1f74363a --- /dev/null +++ b/syntax/printer/testdata/oneline_block.in @@ -0,0 +1,14 @@ +block {} + +block { } + +block { +} + +block { + +} + +block { + // Comments should be kept. +} diff --git a/syntax/printer/testdata/raw_string.expect b/syntax/printer/testdata/raw_string.expect new file mode 100644 index 0000000000..5837439569 --- /dev/null +++ b/syntax/printer/testdata/raw_string.expect @@ -0,0 +1,15 @@ +block "label" { + attr = `'\"attr` +} + +block "multi_line" { + attr = `'\"this +is +a +multi_line +attr'\"` +} + +block "json" { + attr = `{ "key": "value" }` +} \ No newline at end of file diff --git a/syntax/printer/testdata/raw_string.in b/syntax/printer/testdata/raw_string.in new file mode 100644 index 0000000000..5837439569 --- /dev/null +++ b/syntax/printer/testdata/raw_string.in @@ -0,0 +1,15 @@ +block "label" { + attr = `'\"attr` +} + +block "multi_line" { + attr = `'\"this +is +a +multi_line +attr'\"` +} + +block "json" { + attr = `{ "key": "value" }` +} \ No newline at end of file diff --git a/syntax/printer/testdata/raw_string_label_error.error b/syntax/printer/testdata/raw_string_label_error.error new file mode 100644 index 0000000000..dd3f7f7c8b --- /dev/null +++ b/syntax/printer/testdata/raw_string_label_error.error @@ -0,0 +1 @@ +expected block label to be a double quoted string, but got "`multi_line`" \ No newline at end of file diff --git a/syntax/printer/testdata/raw_string_label_error.in b/syntax/printer/testdata/raw_string_label_error.in new file mode 100644 index 0000000000..1e16f9ae07 --- /dev/null +++ b/syntax/printer/testdata/raw_string_label_error.in @@ -0,0 +1,15 @@ +block "label" { + attr = `'\"attr` +} + +block `multi_line` { + attr = `'\"this +is +a +multi_line +attr'\"` +} + +block `json` { + attr = `{ "key": "value" }` +} \ No newline at end of file diff --git a/syntax/printer/trimmer.go b/syntax/printer/trimmer.go new file mode 100644 index 0000000000..5a76c0c79d --- /dev/null +++ b/syntax/printer/trimmer.go @@ -0,0 +1,115 @@ +package printer + +import ( + "io" + "text/tabwriter" +) + +// A trimmer is an io.Writer which filters tabwriter.Escape characters, +// trailing blanks and tabs from lines, and converting \f and \v characters +// into \n and \t (if no text/tabwriter is used when printing). +// +// Text wrapped by tabwriter.Escape characters is written to the underlying +// io.Writer unmodified. +type trimmer struct { + next io.Writer + state int + space []byte +} + +const ( + trimStateSpace = iota // Trimmer is reading space characters + trimStateEscape // Trimmer is reading escaped characters + trimStateText // Trimmer is reading text +) + +func (t *trimmer) discardWhitespace() { + t.state = trimStateSpace + t.space = t.space[0:0] +} + +func (t *trimmer) Write(data []byte) (n int, err error) { + // textStart holds the index of the start of a chunk of text not containing + // whitespace. It is reset every time a new chunk of text is encountered. + var textStart int + + for off, b := range data { + // Convert \v to \t + if b == '\v' { + b = '\t' + } + + switch t.state { + case trimStateSpace: + // Accumulate tabs and spaces in t.space until finding a non-tab or + // non-space character. + // + // If we find a newline, we write it directly and discard our pending + // whitespace (so that trailing whitespace up to the newline is ignored). + // + // If we find a tabwriter.Escape or text character we transition states. + switch b { + case '\t', ' ': + t.space = append(t.space, b) + case '\n', '\f': + // Discard all unwritten whitespace before the end of the line and write + // a newline. + t.discardWhitespace() + _, err = t.next.Write([]byte("\n")) + case tabwriter.Escape: + _, err = t.next.Write(t.space) + t.state = trimStateEscape + textStart = off + 1 // Skip escape character + default: + // Non-space character. Write our pending whitespace + // and then move to text state. + _, err = t.next.Write(t.space) + t.state = trimStateText + textStart = off + } + + case trimStateText: + // We're reading a chunk of text. Accumulate characters in the chunk + // until we find a whitespace character or a tabwriter.Escape. + switch b { + case '\t', ' ': + _, err = t.next.Write(data[textStart:off]) + t.discardWhitespace() + t.space = append(t.space, b) + case '\n', '\f': + _, err = t.next.Write(data[textStart:off]) + t.discardWhitespace() + if err == nil { + _, err = t.next.Write([]byte("\n")) + } + case tabwriter.Escape: + _, err = t.next.Write(data[textStart:off]) + t.state = trimStateEscape + textStart = off + 1 // +1: skip tabwriter.Escape + } + + case trimStateEscape: + // Accumulate everything until finding the closing tabwriter.Escape. + if b == tabwriter.Escape { + _, err = t.next.Write(data[textStart:off]) + t.discardWhitespace() + } + + default: + panic("unreachable") + } + if err != nil { + return off, err + } + } + n = len(data) + + // Flush the remainder of the text (as long as it's not whitespace). + switch t.state { + case trimStateEscape, trimStateText: + _, err = t.next.Write(data[textStart:n]) + t.discardWhitespace() + } + + return +} diff --git a/syntax/printer/walker.go b/syntax/printer/walker.go new file mode 100644 index 0000000000..01f71b21bd --- /dev/null +++ b/syntax/printer/walker.go @@ -0,0 +1,338 @@ +package printer + +import ( + "fmt" + "strings" + + "github.com/grafana/river/ast" + "github.com/grafana/river/token" +) + +// A walker walks an AST and sends lexical tokens and formatting information to +// a printer. +type walker struct { + p *printer +} + +func (w *walker) Walk(node ast.Node) error { + switch node := node.(type) { + case *ast.File: + w.walkFile(node) + case ast.Body: + w.walkStmts(node) + case ast.Stmt: + w.walkStmt(node) + case ast.Expr: + w.walkExpr(node) + default: + return fmt.Errorf("unsupported node type %T", node) + } + + return nil +} + +func (w *walker) walkFile(f *ast.File) { + w.p.SetComments(f.Comments) + w.walkStmts(f.Body) +} + +func (w *walker) walkStmts(ss []ast.Stmt) { + for i, s := range ss { + var addedSpacing bool + + // Two blocks should always be separated by a blank line. + if _, isBlock := s.(*ast.BlockStmt); i > 0 && isBlock { + w.p.Write(wsFormfeed) + addedSpacing = true + } + + // A blank line should always be added if there is a blank line in the + // source between two statements. + if i > 0 && !addedSpacing { + var ( + prevLine = ast.EndPos(ss[i-1]).Position().Line + curLine = ast.StartPos(ss[i-0]).Position().Line + + lineDiff = curLine - prevLine + ) + + if lineDiff > 1 { + w.p.Write(wsFormfeed) + } + } + + w.walkStmt(s) + + // Statements which cross multiple lines don't belong to the same row run. + // Add a formfeed to start a new row run if the node crossed more than one + // line, otherwise add the normal newline. + if nodeLines(s) > 1 { + w.p.Write(wsFormfeed) + } else { + w.p.Write(wsNewline) + } + } +} + +func nodeLines(n ast.Node) int { + var ( + startLine = ast.StartPos(n).Position().Line + endLine = ast.EndPos(n).Position().Line + ) + + return endLine - startLine + 1 +} + +func (w *walker) walkStmt(s ast.Stmt) { + switch s := s.(type) { + case *ast.AttributeStmt: + w.walkAttributeStmt(s) + case *ast.BlockStmt: + w.walkBlockStmt(s) + } +} + +func (w *walker) walkAttributeStmt(s *ast.AttributeStmt) { + w.p.Write(s.Name.NamePos, s.Name, wsVTab, token.ASSIGN, wsBlank) + w.walkExpr(s.Value) +} + +func (w *walker) walkBlockStmt(s *ast.BlockStmt) { + joined := strings.Join(s.Name, ".") + + w.p.Write( + s.NamePos, + &ast.Ident{Name: joined, NamePos: s.NamePos}, + ) + + if s.Label != "" { + label := fmt.Sprintf("%q", s.Label) + + w.p.Write( + wsBlank, + s.LabelPos, + &ast.LiteralExpr{Kind: token.STRING, Value: label}, + ) + } + + w.p.Write( + wsBlank, + s.LCurlyPos, token.LCURLY, wsIndent, + ) + + if len(s.Body) > 0 { + // Add a formfeed to start a new row run before writing any statements. + w.p.Write(wsFormfeed) + w.walkStmts(s.Body) + } else { + // There's no statements, but add a blank line between the left and right + // curly anyway. + w.p.Write(wsBlank) + } + + w.p.Write(wsUnindent, s.RCurlyPos, token.RCURLY) +} + +func (w *walker) walkExpr(e ast.Expr) { + switch e := e.(type) { + case *ast.LiteralExpr: + w.p.Write(e.ValuePos, e) + + case *ast.ArrayExpr: + w.walkArrayExpr(e) + + case *ast.ObjectExpr: + w.walkObjectExpr(e) + + case *ast.IdentifierExpr: + w.p.Write(e.Ident.NamePos, e.Ident) + + case *ast.AccessExpr: + w.walkExpr(e.Value) + w.p.Write(token.DOT, e.Name) + + case *ast.IndexExpr: + w.walkExpr(e.Value) + w.p.Write(e.LBrackPos, token.LBRACK) + w.walkExpr(e.Index) + w.p.Write(e.RBrackPos, token.RBRACK) + + case *ast.CallExpr: + w.walkCallExpr(e) + + case *ast.UnaryExpr: + w.p.Write(e.KindPos, e.Kind) + w.walkExpr(e.Value) + + case *ast.BinaryExpr: + // TODO(rfratto): + // + // 1. allow RHS to be on a new line + // + // 2. remove spacing between some operators to make precedence + // clearer like Go does + w.walkExpr(e.Left) + w.p.Write(wsBlank, e.KindPos, e.Kind, wsBlank) + w.walkExpr(e.Right) + + case *ast.ParenExpr: + w.p.Write(token.LPAREN) + w.walkExpr(e.Inner) + w.p.Write(token.RPAREN) + } +} + +func (w *walker) walkArrayExpr(e *ast.ArrayExpr) { + w.p.Write(e.LBrackPos, token.LBRACK) + prevPos := e.LBrackPos + + for i := 0; i < len(e.Elements); i++ { + var addedNewline bool + + elementPos := ast.StartPos(e.Elements[i]) + + // Add a newline if this element starts on a different line than the last + // element ended. + if differentLines(prevPos, elementPos) { + // Indent elements inside the array on different lines. The indent is + // done *before* the newline to make sure comments written before the + // newline are indented properly. + w.p.Write(wsIndent, wsFormfeed) + addedNewline = true + } else if i > 0 { + // Make sure a space is injected before the next element if two + // successive elements are on the same line. + w.p.Write(wsBlank) + } + prevPos = ast.EndPos(e.Elements[i]) + + // Write the expression. + w.walkExpr(e.Elements[i]) + + // Always add commas in between successive elements. + if i+1 < len(e.Elements) { + w.p.Write(token.COMMA) + } + + if addedNewline { + w.p.Write(wsUnindent) + } + } + + var addedSuffixNewline bool + + // If the closing bracket is on a different line than the final element, + // we need to add a trailing comma. + if len(e.Elements) > 0 && differentLines(prevPos, e.RBrackPos) { + // We add an indentation here so comments after the final element are + // indented. + w.p.Write(token.COMMA, wsIndent, wsFormfeed) + addedSuffixNewline = true + } + + if addedSuffixNewline { + w.p.Write(wsUnindent) + } + w.p.Write(e.RBrackPos, token.RBRACK) +} + +func (w *walker) walkObjectExpr(e *ast.ObjectExpr) { + w.p.Write(e.LCurlyPos, token.LCURLY, wsIndent) + + prevPos := e.LCurlyPos + + for i := 0; i < len(e.Fields); i++ { + field := e.Fields[i] + elementPos := ast.StartPos(field.Name) + + // Add a newline if this element starts on a different line than the last + // element ended. + if differentLines(prevPos, elementPos) { + // We want to align the equal sign for object attributes if the previous + // field only crossed one line. + if i > 0 && nodeLines(e.Fields[i-1].Value) == 1 { + w.p.Write(wsNewline) + } else { + w.p.Write(wsFormfeed) + } + } else if i > 0 { + // Make sure a space is injected before the next element if two successive + // elements are on the same line. + w.p.Write(wsBlank) + } + prevPos = ast.EndPos(field.Name) + + w.p.Write(field.Name.NamePos) + + // Write the field. + if field.Quoted { + w.p.Write(&ast.LiteralExpr{ + Kind: token.STRING, + ValuePos: field.Name.NamePos, + Value: fmt.Sprintf("%q", field.Name.Name), + }) + } else { + w.p.Write(field.Name) + } + + w.p.Write(wsVTab, token.ASSIGN, wsBlank) + w.walkExpr(field.Value) + + // Always add commas in between successive elements. + if i+1 < len(e.Fields) { + w.p.Write(token.COMMA) + } + } + + // If the closing bracket is on a different line than the final element, + // we need to add a trailing comma. + if len(e.Fields) > 0 && differentLines(prevPos, e.RCurlyPos) { + w.p.Write(token.COMMA, wsFormfeed) + } + + w.p.Write(wsUnindent, e.RCurlyPos, token.RCURLY) +} + +func (w *walker) walkCallExpr(e *ast.CallExpr) { + w.walkExpr(e.Value) + w.p.Write(token.LPAREN) + + prevPos := e.LParenPos + + for i, arg := range e.Args { + var addedNewline bool + + argPos := ast.StartPos(arg) + + // Add a newline if this element starts on a different line than the last + // element ended. + if differentLines(prevPos, argPos) { + w.p.Write(wsFormfeed, wsIndent) + addedNewline = true + } + + w.walkExpr(arg) + prevPos = ast.EndPos(arg) + + if i+1 < len(e.Args) { + w.p.Write(token.COMMA, wsBlank) + } + + if addedNewline { + w.p.Write(wsUnindent) + } + } + + // Add a final comma if the final argument is on a different line than the + // right parenthesis. + if differentLines(prevPos, e.RParenPos) { + w.p.Write(token.COMMA, wsFormfeed) + } + + w.p.Write(token.RPAREN) +} + +// differentLines returns true if a and b are on different lines. +func differentLines(a, b token.Pos) bool { + return a.Position().Line != b.Position().Line +} diff --git a/syntax/river.go b/syntax/river.go new file mode 100644 index 0000000000..0944a9e3be --- /dev/null +++ b/syntax/river.go @@ -0,0 +1,346 @@ +// Package river implements a high-level API for decoding and encoding River +// configuration files. The mapping between River and Go values is described in +// the documentation for the Unmarshal and Marshal functions. +// +// Lower-level APIs which give more control over configuration evaluation are +// available in the inner packages. The implementation of this package is +// minimal and serves as a reference for how to consume the lower-level +// packages. +package river + +import ( + "bytes" + "io" + + "github.com/grafana/river/parser" + "github.com/grafana/river/token/builder" + "github.com/grafana/river/vm" +) + +// Marshal returns the pretty-printed encoding of v as a River configuration +// file. v must be a Go struct with river struct tags which determine the +// structure of the resulting file. +// +// Marshal traverses the value v recursively, encoding each struct field as a +// River block or River attribute, based on the flags provided to the river +// struct tag. +// +// When a struct field represents a River block, Marshal creates a new block +// and recursively encodes the value as the body of the block. The name of the +// created block is taken from the name specified by the river struct tag. +// +// Struct fields which represent River blocks must be either a Go struct or a +// slice of Go structs. When the field is a Go struct, its value is encoded as +// a single block. When the field is a slice of Go structs, a block is created +// for each element in the slice. +// +// When encoding a block, if the inner Go struct has a struct field +// representing a River block label, the value of that field is used as the +// label name for the created block. Fields used for River block labels must be +// the string type. When specified, there must not be more than one struct +// field which represents a block label. +// +// The river tag specifies a name, possibly followed by a comma-separated list +// of options. The name must be empty if the provided options do not support a +// name being defined. The following provides examples for all supported struct +// field tags with their meanings: +// +// // Field appears as a block named "example". It will always appear in the +// // resulting encoding. When decoding, "example" is treated as a required +// // block and must be present in the source text. +// Field struct{...} `river:"example,block"` +// +// // Field appears as a set of blocks named "example." It will appear in the +// // resulting encoding if there is at least one element in the slice. When +// // decoding, "example" is treated as a required block and at least one +// // "example" block must be present in the source text. +// Field []struct{...} `river:"example,block"` +// +// // Field appears as block named "example." It will always appear in the +// // resulting encoding. When decoding, "example" is treated as an optional +// // block and can be omitted from the source text. +// Field struct{...} `river:"example,block,optional"` +// +// // Field appears as a set of blocks named "example." It will appear in the +// // resulting encoding if there is at least one element in the slice. When +// // decoding, "example" is treated as an optional block and can be omitted +// // from the source text. +// Field []struct{...} `river:"example,block,optional"` +// +// // Field appears as an attribute named "example." It will always appear in +// // the resulting encoding. When decoding, "example" is treated as a +// // required attribute and must be present in the source text. +// Field bool `river:"example,attr"` +// +// // Field appears as an attribute named "example." If the field's value is +// // the Go zero value, "example" is omitted from the resulting encoding. +// // When decoding, "example" is treated as an optional attribute and can be +// // omitted from the source text. +// Field bool `river:"example,attr,optional"` +// +// // The value of Field appears as the block label for the struct being +// // converted into a block. When decoding, a block label must be provided. +// Field string `river:",label"` +// +// // The inner attributes and blocks of Field are exposed as top-level +// // attributes and blocks of the outer struct. +// Field struct{...} `river:",squash"` +// +// // Field appears as a set of blocks starting with "example.". Only the +// // first set element in the struct will be encoded. Each field in struct +// // must be a block. The name of the block is prepended to the enum name. +// // When decoding, enum blocks are treated as optional blocks and can be +// // omitted from the source text. +// Field []struct{...} `river:"example,enum"` +// +// // Field is equivalent to `river:"example,enum"`. +// Field []struct{...} `river:"example,enum,optional"` +// +// If a river tag specifies a required or optional block, the name is permitted +// to contain period `.` characters. +// +// Marshal will panic if it encounters a struct with invalid river tags. +// +// When a struct field represents a River attribute, Marshal encodes the struct +// value as a River value. The attribute name will be taken from the name +// specified by the river struct tag. See MarshalValue for the rules used to +// convert a Go value into a River value. +func Marshal(v interface{}) ([]byte, error) { + var buf bytes.Buffer + if err := NewEncoder(&buf).Encode(v); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// MarshalValue returns the pretty-printed encoding of v as a River value. +// +// MarshalValue traverses the value v recursively. If an encountered value +// implements the encoding.TextMarshaler interface, MarshalValue calls its +// MarshalText method and encodes the result as a River string. If a value +// implements the Capsule interface, it always encodes as a River capsule +// value. +// +// Otherwise, MarshalValue uses the following type-dependent default encodings: +// +// Boolean values encode to River bools. +// +// Floating point, integer, and Number values encode to River numbers. +// +// String values encode to River strings. +// +// Array and slice values encode to River arrays, except that []byte is +// converted into a River string. Nil slices encode as an empty array and nil +// []byte slices encode as an empty string. +// +// Structs encode to River objects, using Go struct field tags to determine the +// resulting structure of the River object. Each exported struct field with a +// river tag becomes an object field, using the tag name as the field name. +// Other struct fields are ignored. If no struct field has a river tag, the +// struct encodes to a River capsule instead. +// +// Function values encode to River functions, which appear in the resulting +// text as strings formatted as "function(GO_TYPE)". +// +// All other Go values encode to River capsules, which appear in the resulting +// text as strings formatted as "capsule(GO_TYPE)". +// +// The river tag specifies the field name, possibly followed by a +// comma-separated list of options. The following provides examples for all +// supported struct field tags with their meanings: +// +// // Field appears as an object field named "my_name". It will always +// // appear in the resulting encoding. When decoding, "my_name" is treated +// // as a required attribute and must be present in the source text. +// Field bool `river:"my_name,attr"` +// +// // Field appears as an object field named "my_name". If the field's value +// // is the Go zero value, "example" is omitted from the resulting encoding. +// // When decoding, "my_name" is treated as an optional attribute and can be +// // omitted from the source text. +// Field bool `river:"my_name,attr,optional"` +func MarshalValue(v interface{}) ([]byte, error) { + var buf bytes.Buffer + if err := NewEncoder(&buf).EncodeValue(v); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// Encoder writes River configuration to an output stream. Call NewEncoder to +// create instances of Encoder. +type Encoder struct { + w io.Writer +} + +// NewEncoder returns a new Encoder which writes configuration to w. +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{w: w} +} + +// Encode converts the value pointed to by v into a River configuration file +// and writes the result to the Decoder's output stream. +// +// See the documentation for Marshal for details about the conversion of Go +// values into River configuration. +func (enc *Encoder) Encode(v interface{}) error { + f := builder.NewFile() + f.Body().AppendFrom(v) + + _, err := f.WriteTo(enc.w) + return err +} + +// EncodeValue converts the value pointed to by v into a River value and writes +// the result to the Decoder's output stream. +// +// See the documentation for MarshalValue for details about the conversion of +// Go values into River values. +func (enc *Encoder) EncodeValue(v interface{}) error { + expr := builder.NewExpr() + expr.SetValue(v) + + _, err := expr.WriteTo(enc.w) + return err +} + +// Unmarshal converts the River configuration file specified by in and stores +// it in the struct value pointed to by v. If v is nil or not a pointer, +// Unmarshal panics. The configuration specified by in may use expressions to +// compute values while unmarshaling. Refer to the River language documentation +// for the list of valid formatting and expression rules. +// +// Unmarshal uses the inverse of the encoding rules that Marshal uses, +// allocating maps, slices, and pointers as necessary. +// +// To unmarshal a River body into a map[string]T, Unmarshal assigns each +// attribute to a key in the map, and decodes the attribute's value as the +// value for the map entry. Only attribute statements are allowed when +// unmarshaling into a map. +// +// To unmarshal a River body into a struct, Unmarshal matches incoming +// attributes and blocks to the river struct tags specified by v. Incoming +// attribute and blocks which do not match to a river struct tag cause a +// decoding error. Additionally, any attribute or block marked as required by +// the river struct tag that are not present in the source text will generate a +// decoding error. +// +// To unmarshal a list of River blocks into a slice, Unmarshal resets the slice +// length to zero and then appends each element to the slice. +// +// To unmarshal a list of River blocks into a Go array, Unmarshal decodes each +// block into the corresponding Go array element. If the number of River blocks +// does not match the length of the Go array, a decoding error is returned. +// +// Unmarshal follows the rules specified by UnmarshalValue when unmarshaling +// the value of an attribute. +func Unmarshal(in []byte, v interface{}) error { + dec := NewDecoder(bytes.NewReader(in)) + return dec.Decode(v) +} + +// UnmarshalValue converts the River configuration file specified by in and +// stores it in the value pointed to by v. If v is nil or not a pointer, +// UnmarshalValue panics. The configuration specified by in may use expressions +// to compute values while unmarshaling. Refer to the River language +// documentation for the list of valid formatting and expression rules. +// +// Unmarshal uses the inverse of the encoding rules that MarshalValue uses, +// allocating maps, slices, and pointers as necessary, with the following +// additional rules: +// +// After converting a River value into its Go value counterpart, the Go value +// may be converted into a capsule if the capsule type implements +// ConvertibleIntoCapsule. +// +// To unmarshal a River object into a struct, UnmarshalValue matches incoming +// object fields to the river struct tags specified by v. Incoming object +// fields which do not match to a river struct tag cause a decoding error. +// Additionally, any object field marked as required by the river struct +// tag that are not present in the source text will generate a decoding error. +// +// To unmarshal River into an interface value, Unmarshal stores one of the +// following: +// +// - bool, for River bools +// - float64, for floating point River numbers +// and integers which are too big to fit in either of int/int64/uint64 +// - int/int64/uint64, in this order of preference, for signed and unsigned +// River integer numbers, depending on how big they are +// - string, for River strings +// - []interface{}, for River arrays +// - map[string]interface{}, for River objects +// +// Capsule and function types will retain their original type when decoding +// into an interface value. +// +// To unmarshal a River array into a slice, Unmarshal resets the slice length +// to zero and then appends each element to the slice. +// +// To unmarshal a River array into a Go array, Unmarshal decodes River array +// elements into the corresponding Go array element. If the number of River +// elements does not match the length of the Go array, a decoding error is +// returned. +// +// To unmarshal a River object into a Map, Unmarshal establishes a map to use. +// If the map is nil, Unmarshal allocates a new map. Otherwise, Unmarshal +// reuses the existing map, keeping existing entries. Unmarshal then stores +// key-value pairs from the River object into the map. The map's key type must +// be string. +func UnmarshalValue(in []byte, v interface{}) error { + dec := NewDecoder(bytes.NewReader(in)) + return dec.DecodeValue(v) +} + +// Decoder reads River configuration from an input stream. Call NewDecoder to +// create instances of Decoder. +type Decoder struct { + r io.Reader +} + +// NewDecoder returns a new Decoder which reads configuration from r. +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{r: r} +} + +// Decode reads the River-encoded file from the Decoder's input and stores it +// in the value pointed to by v. Data will be read from the Decoder's input +// until EOF is reached. +// +// See the documentation for Unmarshal for details about the conversion of River +// configuration into Go values. +func (dec *Decoder) Decode(v interface{}) error { + bb, err := io.ReadAll(dec.r) + if err != nil { + return err + } + + f, err := parser.ParseFile("", bb) + if err != nil { + return err + } + + eval := vm.New(f) + return eval.Evaluate(nil, v) +} + +// DecodeValue reads the River-encoded expression from the Decoder's input and +// stores it in the value pointed to by v. Data will be read from the Decoder's +// input until EOF is reached. +// +// See the documentation for UnmarshalValue for details about the conversion of +// River values into Go values. +func (dec *Decoder) DecodeValue(v interface{}) error { + bb, err := io.ReadAll(dec.r) + if err != nil { + return err + } + + f, err := parser.ParseExpression(string(bb)) + if err != nil { + return err + } + + eval := vm.New(f) + return eval.Evaluate(nil, v) +} diff --git a/syntax/river_test.go b/syntax/river_test.go new file mode 100644 index 0000000000..99247f54da --- /dev/null +++ b/syntax/river_test.go @@ -0,0 +1,152 @@ +package river_test + +import ( + "fmt" + "os" + + river "github.com/grafana/river" +) + +func ExampleUnmarshal() { + // Character is our block type which holds an individual character from a + // book. + type Character struct { + // Name of the character. The name is decoded from the block label. + Name string `river:",label"` + // Age of the character. The age is a required attribute within the block, + // and must be set in the config. + Age int `river:"age,attr"` + // Location the character lives in. The location is an optional attribute + // within the block. Optional attributes do not have to bet set. + Location string `river:"location,attr,optional"` + } + + // Book is our overall type where we decode the overall River file into. + type Book struct { + // Title of the book (required attribute). + Title string `river:"title,attr"` + // List of characters. Each character is a labeled block. The optional tag + // means that it is valid not provide a character block. Decoding into a + // slice permits there to be multiple specified character blocks. + Characters []*Character `river:"character,block,optional"` + } + + // Create our book with two characters. + input := ` + title = "Wheel of Time" + + character "Rand" { + age = 19 + location = "Two Rivers" + } + + character "Perrin" { + age = 19 + location = "Two Rivers" + } + ` + + // Unmarshal the config into our Book type and print out the data. + var b Book + if err := river.Unmarshal([]byte(input), &b); err != nil { + panic(err) + } + + fmt.Printf("%s characters:\n", b.Title) + + for _, c := range b.Characters { + if c.Location != "" { + fmt.Printf("\t%s (age %d, location %s)\n", c.Name, c.Age, c.Location) + } else { + fmt.Printf("\t%s (age %d)\n", c.Name, c.Age) + } + } + + // Output: + // Wheel of Time characters: + // Rand (age 19, location Two Rivers) + // Perrin (age 19, location Two Rivers) +} + +// This example shows how functions may be called within user configurations. +// We focus on the `env` function from the standard library, which retrieves a +// value from an environment variable. +func ExampleUnmarshal_functions() { + // Set an environment variable to use in the test. + _ = os.Setenv("EXAMPLE", "Jane Doe") + + type Data struct { + String string `river:"string,attr"` + } + + input := ` + string = env("EXAMPLE") + ` + + var d Data + if err := river.Unmarshal([]byte(input), &d); err != nil { + panic(err) + } + + fmt.Println(d.String) + // Output: Jane Doe +} + +func ExampleUnmarshalValue() { + input := `3 + 5` + + var num int + if err := river.UnmarshalValue([]byte(input), &num); err != nil { + panic(err) + } + + fmt.Println(num) + // Output: 8 +} + +func ExampleMarshal() { + type Person struct { + Name string `river:"name,attr"` + Age int `river:"age,attr"` + Location string `river:"location,attr,optional"` + } + + p := Person{ + Name: "John Doe", + Age: 43, + } + + bb, err := river.Marshal(p) + if err != nil { + panic(err) + } + + fmt.Println(string(bb)) + // Output: + // name = "John Doe" + // age = 43 +} + +func ExampleMarshalValue() { + type Person struct { + Name string `river:"name,attr"` + Age int `river:"age,attr"` + } + + p := Person{ + Name: "John Doe", + Age: 43, + } + + bb, err := river.MarshalValue(p) + if err != nil { + panic(err) + } + + fmt.Println(string(bb)) + // Output: + // { + // name = "John Doe", + // age = 43, + // } +} diff --git a/syntax/rivertypes/optional_secret.go b/syntax/rivertypes/optional_secret.go new file mode 100644 index 0000000000..75648af046 --- /dev/null +++ b/syntax/rivertypes/optional_secret.go @@ -0,0 +1,84 @@ +package rivertypes + +import ( + "fmt" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/token" + "github.com/grafana/river/token/builder" +) + +// OptionalSecret holds a potentially sensitive value. When IsSecret is true, +// the OptionalSecret's Value will be treated as sensitive and will be hidden +// from users when rendering River. +// +// OptionalSecrets may be converted from river strings and the Secret type, +// which will set IsSecret accordingly. +// +// Additionally, OptionalSecrets may be converted into the Secret type +// regardless of the value of IsSecret. OptionalSecret can be converted into a +// string as long as IsSecret is false. +type OptionalSecret struct { + IsSecret bool + Value string +} + +var ( + _ value.Capsule = OptionalSecret{} + _ value.ConvertibleIntoCapsule = OptionalSecret{} + _ value.ConvertibleFromCapsule = (*OptionalSecret)(nil) + + _ builder.Tokenizer = OptionalSecret{} +) + +// RiverCapsule marks OptionalSecret as a RiverCapsule. +func (s OptionalSecret) RiverCapsule() {} + +// ConvertInto converts the OptionalSecret and stores it into the Go value +// pointed at by dst. OptionalSecrets can always be converted into *Secret. +// OptionalSecrets can only be converted into *string if IsSecret is false. In +// other cases, this method will return an explicit error or +// river.ErrNoConversion. +func (s OptionalSecret) ConvertInto(dst interface{}) error { + switch dst := dst.(type) { + case *Secret: + *dst = Secret(s.Value) + return nil + case *string: + if s.IsSecret { + return fmt.Errorf("secrets may not be converted into strings") + } + *dst = s.Value + return nil + } + + return value.ErrNoConversion +} + +// ConvertFrom converts the src value and stores it into the OptionalSecret s. +// Secrets and strings can be converted into an OptionalSecret. In other +// cases, this method will return river.ErrNoConversion. +func (s *OptionalSecret) ConvertFrom(src interface{}) error { + switch src := src.(type) { + case Secret: + *s = OptionalSecret{IsSecret: true, Value: string(src)} + return nil + case string: + *s = OptionalSecret{Value: src} + return nil + } + + return value.ErrNoConversion +} + +// RiverTokenize returns a set of custom tokens to represent this value in +// River text. +func (s OptionalSecret) RiverTokenize() []builder.Token { + if s.IsSecret { + return []builder.Token{{Tok: token.LITERAL, Lit: "(secret)"}} + } + return []builder.Token{{ + Tok: token.STRING, + Lit: fmt.Sprintf("%q", s.Value), + }} +} diff --git a/syntax/rivertypes/optional_secret_test.go b/syntax/rivertypes/optional_secret_test.go new file mode 100644 index 0000000000..bd8a0baeea --- /dev/null +++ b/syntax/rivertypes/optional_secret_test.go @@ -0,0 +1,92 @@ +package rivertypes_test + +import ( + "testing" + + "github.com/grafana/river/rivertypes" + "github.com/grafana/river/token/builder" + "github.com/stretchr/testify/require" +) + +func TestOptionalSecret(t *testing.T) { + t.Run("non-sensitive conversion to string is allowed", func(t *testing.T) { + input := rivertypes.OptionalSecret{IsSecret: false, Value: "testval"} + + var s string + err := decodeTo(t, input, &s) + require.NoError(t, err) + require.Equal(t, "testval", s) + }) + + t.Run("sensitive conversion to string is disallowed", func(t *testing.T) { + input := rivertypes.OptionalSecret{IsSecret: true, Value: "testval"} + + var s string + err := decodeTo(t, input, &s) + require.NotNil(t, err) + require.Contains(t, err.Error(), "secrets may not be converted into strings") + }) + + t.Run("non-sensitive conversion to secret is allowed", func(t *testing.T) { + input := rivertypes.OptionalSecret{IsSecret: false, Value: "testval"} + + var s rivertypes.Secret + err := decodeTo(t, input, &s) + require.NoError(t, err) + require.Equal(t, rivertypes.Secret("testval"), s) + }) + + t.Run("sensitive conversion to secret is allowed", func(t *testing.T) { + input := rivertypes.OptionalSecret{IsSecret: true, Value: "testval"} + + var s rivertypes.Secret + err := decodeTo(t, input, &s) + require.NoError(t, err) + require.Equal(t, rivertypes.Secret("testval"), s) + }) + + t.Run("conversion from string is allowed", func(t *testing.T) { + var s rivertypes.OptionalSecret + err := decodeTo(t, string("Hello, world!"), &s) + require.NoError(t, err) + + expect := rivertypes.OptionalSecret{ + IsSecret: false, + Value: "Hello, world!", + } + require.Equal(t, expect, s) + }) + + t.Run("conversion from secret is allowed", func(t *testing.T) { + var s rivertypes.OptionalSecret + err := decodeTo(t, rivertypes.Secret("Hello, world!"), &s) + require.NoError(t, err) + + expect := rivertypes.OptionalSecret{ + IsSecret: true, + Value: "Hello, world!", + } + require.Equal(t, expect, s) + }) +} + +func TestOptionalSecret_Write(t *testing.T) { + tt := []struct { + name string + value interface{} + expect string + }{ + {"non-sensitive", rivertypes.OptionalSecret{Value: "foobar"}, `"foobar"`}, + {"sensitive", rivertypes.OptionalSecret{IsSecret: true, Value: "foobar"}, `(secret)`}, + {"non-sensitive pointer", &rivertypes.OptionalSecret{Value: "foobar"}, `"foobar"`}, + {"sensitive pointer", &rivertypes.OptionalSecret{IsSecret: true, Value: "foobar"}, `(secret)`}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + be := builder.NewExpr() + be.SetValue(tc.value) + require.Equal(t, tc.expect, string(be.Bytes())) + }) + } +} diff --git a/syntax/rivertypes/secret.go b/syntax/rivertypes/secret.go new file mode 100644 index 0000000000..c2eb357d03 --- /dev/null +++ b/syntax/rivertypes/secret.go @@ -0,0 +1,65 @@ +package rivertypes + +import ( + "fmt" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/token" + "github.com/grafana/river/token/builder" +) + +// Secret is a River capsule holding a sensitive string. The contents of a +// Secret are never displayed to the user when rendering River. +// +// Secret allows itself to be converted from a string River value, but never +// the inverse. This ensures that a user can't accidentally leak a sensitive +// value. +type Secret string + +var ( + _ value.Capsule = Secret("") + _ value.ConvertibleIntoCapsule = Secret("") + _ value.ConvertibleFromCapsule = (*Secret)(nil) + + _ builder.Tokenizer = Secret("") +) + +// RiverCapsule marks Secret as a RiverCapsule. +func (s Secret) RiverCapsule() {} + +// ConvertInto converts the Secret and stores it into the Go value pointed at +// by dst. Secrets can be converted into *OptionalSecret. In other cases, this +// method will return an explicit error or river.ErrNoConversion. +func (s Secret) ConvertInto(dst interface{}) error { + switch dst := dst.(type) { + case *OptionalSecret: + *dst = OptionalSecret{IsSecret: true, Value: string(s)} + return nil + case *string: + return fmt.Errorf("secrets may not be converted into strings") + } + + return value.ErrNoConversion +} + +// ConvertFrom converts the src value and stores it into the Secret s. +// OptionalSecrets and strings can be converted into a Secret. In other cases, +// this method will return river.ErrNoConversion. +func (s *Secret) ConvertFrom(src interface{}) error { + switch src := src.(type) { + case OptionalSecret: + *s = Secret(src.Value) + return nil + case string: + *s = Secret(src) + return nil + } + + return value.ErrNoConversion +} + +// RiverTokenize returns a set of custom tokens to represent this value in +// River text. +func (s Secret) RiverTokenize() []builder.Token { + return []builder.Token{{Tok: token.LITERAL, Lit: "(secret)"}} +} diff --git a/syntax/rivertypes/secret_test.go b/syntax/rivertypes/secret_test.go new file mode 100644 index 0000000000..cade74647b --- /dev/null +++ b/syntax/rivertypes/secret_test.go @@ -0,0 +1,47 @@ +package rivertypes_test + +import ( + "testing" + + "github.com/grafana/river/parser" + "github.com/grafana/river/rivertypes" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +func TestSecret(t *testing.T) { + t.Run("strings can be converted to secret", func(t *testing.T) { + var s rivertypes.Secret + err := decodeTo(t, string("Hello, world!"), &s) + require.NoError(t, err) + require.Equal(t, rivertypes.Secret("Hello, world!"), s) + }) + + t.Run("secrets cannot be converted to strings", func(t *testing.T) { + var s string + err := decodeTo(t, rivertypes.Secret("Hello, world!"), &s) + require.NotNil(t, err) + require.Contains(t, err.Error(), "secrets may not be converted into strings") + }) + + t.Run("secrets can be passed to secrets", func(t *testing.T) { + var s rivertypes.Secret + err := decodeTo(t, rivertypes.Secret("Hello, world!"), &s) + require.NoError(t, err) + require.Equal(t, rivertypes.Secret("Hello, world!"), s) + }) +} + +func decodeTo(t *testing.T, input interface{}, target interface{}) error { + t.Helper() + + expr, err := parser.ParseExpression("val") + require.NoError(t, err) + + eval := vm.New(expr) + return eval.Evaluate(&vm.Scope{ + Variables: map[string]interface{}{ + "val": input, + }, + }, target) +} diff --git a/syntax/scanner/identifier.go b/syntax/scanner/identifier.go new file mode 100644 index 0000000000..ed2239e060 --- /dev/null +++ b/syntax/scanner/identifier.go @@ -0,0 +1,60 @@ +package scanner + +import ( + "fmt" + + "github.com/grafana/river/token" +) + +// IsValidIdentifier returns true if the given string is a valid river +// identifier. +func IsValidIdentifier(in string) bool { + s := New(token.NewFile(""), []byte(in), nil, 0) + _, tok, lit := s.Scan() + return tok == token.IDENT && lit == in +} + +// SanitizeIdentifier will return the given string mutated into a valid river +// identifier. If the given string is already a valid identifier, it will be +// returned unchanged. +// +// This should be used with caution since the different inputs can result in +// identical outputs. +func SanitizeIdentifier(in string) (string, error) { + if in == "" { + return "", fmt.Errorf("cannot generate a new identifier for an empty string") + } + + if IsValidIdentifier(in) { + return in, nil + } + + newValue := generateNewIdentifier(in) + if !IsValidIdentifier(newValue) { + panic(fmt.Errorf("invalid identifier %q generated for `%q`", newValue, in)) + } + + return newValue, nil +} + +// generateNewIdentifier expects a valid river prefix and replacement +// string and returns a new identifier based on the given input. +func generateNewIdentifier(in string) string { + newValue := "" + for i, c := range in { + if i == 0 { + if isDigit(c) { + newValue = "_" + } + } + + if !(isLetter(c) || isDigit(c)) { + newValue += "_" + continue + } + + newValue += string(c) + } + + return newValue +} diff --git a/syntax/scanner/identifier_test.go b/syntax/scanner/identifier_test.go new file mode 100644 index 0000000000..e1dfead833 --- /dev/null +++ b/syntax/scanner/identifier_test.go @@ -0,0 +1,92 @@ +package scanner_test + +import ( + "testing" + + "github.com/grafana/river/scanner" + "github.com/stretchr/testify/require" +) + +var validTestCases = []struct { + name string + identifier string + expect bool +}{ + {"empty", "", false}, + {"start_number", "0identifier_1", false}, + {"start_char", "identifier_1", true}, + {"start_underscore", "_identifier_1", true}, + {"special_chars", "!@#$%^&*()", false}, + {"special_char", "identifier_1!", false}, + {"spaces", "identifier _ 1", false}, +} + +func TestIsValidIdentifier(t *testing.T) { + for _, tc := range validTestCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.expect, scanner.IsValidIdentifier(tc.identifier)) + }) + } +} + +func BenchmarkIsValidIdentifier(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, tc := range validTestCases { + _ = scanner.IsValidIdentifier(tc.identifier) + } + } +} + +var sanitizeTestCases = []struct { + name string + identifier string + expectIdentifier string + expectErr string +}{ + {"empty", "", "", "cannot generate a new identifier for an empty string"}, + {"start_number", "0identifier_1", "_0identifier_1", ""}, + {"start_char", "identifier_1", "identifier_1", ""}, + {"start_underscore", "_identifier_1", "_identifier_1", ""}, + {"special_chars", "!@#$%^&*()", "__________", ""}, + {"special_char", "identifier_1!", "identifier_1_", ""}, + {"spaces", "identifier _ 1", "identifier___1", ""}, +} + +func TestSanitizeIdentifier(t *testing.T) { + for _, tc := range sanitizeTestCases { + t.Run(tc.name, func(t *testing.T) { + newIdentifier, err := scanner.SanitizeIdentifier(tc.identifier) + if tc.expectErr != "" { + require.EqualError(t, err, tc.expectErr) + return + } + + require.NoError(t, err) + require.Equal(t, tc.expectIdentifier, newIdentifier) + }) + } +} + +func BenchmarkSanitizeIdentifier(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, tc := range sanitizeTestCases { + _, _ = scanner.SanitizeIdentifier(tc.identifier) + } + } +} + +func FuzzSanitizeIdentifier(f *testing.F) { + for _, tc := range sanitizeTestCases { + f.Add(tc.identifier) + } + + f.Fuzz(func(t *testing.T, input string) { + newIdentifier, err := scanner.SanitizeIdentifier(input) + if input == "" { + require.EqualError(t, err, "cannot generate a new identifier for an empty string") + return + } + require.NoError(t, err) + require.True(t, scanner.IsValidIdentifier(newIdentifier)) + }) +} diff --git a/syntax/scanner/scanner.go b/syntax/scanner/scanner.go new file mode 100644 index 0000000000..e637a785b9 --- /dev/null +++ b/syntax/scanner/scanner.go @@ -0,0 +1,704 @@ +// Package scanner implements a lexical scanner for River source files. +package scanner + +import ( + "fmt" + "unicode" + "unicode/utf8" + + "github.com/grafana/river/token" +) + +// EBNF for the scanner: +// +// letter = /* any unicode letter class character */ | "_" +// number = /* any unicode number class character */ +// digit = /* ASCII characters 0 through 9 */ +// digits = digit { digit } +// string_character = /* any unicode character that isn't '"' */ +// +// COMMENT = line_comment | block_comment +// line_comment = "//" { character } +// block_comment = "/*" { character | newline } "*/" +// +// IDENT = letter { letter | number } +// NULL = "null" +// BOOL = "true" | "false" +// NUMBER = digits +// FLOAT = ( digits | "." digits ) [ "e" [ "+" | "-" ] digits ] +// STRING = '"' { string_character | escape_sequence } '"' +// OR = "||" +// AND = "&&" +// NOT = "!" +// NEQ = "!=" +// ASSIGN = "=" +// EQ = "==" +// LT = "<" +// LTE = "<=" +// GT = ">" +// GTE = ">=" +// ADD = "+" +// SUB = "-" +// MUL = "*" +// DIV = "/" +// MOD = "%" +// POW = "^" +// LCURLY = "{" +// RCURLY = "}" +// LPAREN = "(" +// RPAREN = ")" +// LBRACK = "[" +// RBRACK = "]" +// COMMA = "," +// DOT = "." +// +// The EBNF for escape_sequence is currently undocumented; see scanEscape for +// details. The escape sequences supported by River are the same as the escape +// sequences supported by Go, except that it is always valid to use \' in +// strings (which in Go, is only valid to use in character literals). + +// ErrorHandler is invoked whenever there is an error. +type ErrorHandler func(pos token.Pos, msg string) + +// Mode is a set of bitwise flags which control scanner behavior. +type Mode uint + +const ( + // IncludeComments will cause comments to be returned as comment tokens. + // Otherwise, comments are ignored. + IncludeComments Mode = 1 << iota + + // Avoids automatic insertion of terminators (for testing only). + dontInsertTerms +) + +const ( + bom = 0xFEFF // byte order mark, permitted as very first character + eof = -1 // end of file +) + +// Scanner holds the internal state for the tokenizer while processing configs. +type Scanner struct { + file *token.File // Config file handle for tracking line offsets + input []byte // Input config + err ErrorHandler // Error reporting (may be nil) + mode Mode + + // scanning state variables: + + ch rune // Current character + offset int // Byte offset of ch + readOffset int // Byte offset of first character *after* ch + insertTerm bool // Insert a newline before the next newline + numErrors int // Number of errors encountered during scanning +} + +// New creates a new scanner to tokenize the provided input config. The scanner +// uses the provided file for adding line information for each token. The mode +// parameter customizes scanner behavior. +// +// Calls to Scan will invoke the error handler eh when a lexical error is found +// if eh is not nil. +func New(file *token.File, input []byte, eh ErrorHandler, mode Mode) *Scanner { + s := &Scanner{ + file: file, + input: input, + err: eh, + mode: mode, + } + + // Preload first character. + s.next() + if s.ch == bom { + s.next() // Ignore BOM if it's the first character. + } + return s +} + +// peek gets the next byte after the current character without advancing the +// scanner. Returns 0 if the scanner is at EOF. +func (s *Scanner) peek() byte { + if s.readOffset < len(s.input) { + return s.input[s.readOffset] + } + return 0 +} + +// next advances the scanner and reads the next Unicode character into s.ch. +// s.ch == eof indicates end of file. +func (s *Scanner) next() { + if s.readOffset >= len(s.input) { + s.offset = len(s.input) + if s.ch == '\n' { + // Make sure we track final newlines at the end of the file + s.file.AddLine(s.offset) + } + s.ch = eof + return + } + + s.offset = s.readOffset + if s.ch == '\n' { + s.file.AddLine(s.offset) + } + + r, width := rune(s.input[s.readOffset]), 1 + switch { + case r == 0: + s.onError(s.offset, "illegal character NUL") + case r >= utf8.RuneSelf: + r, width = utf8.DecodeRune(s.input[s.readOffset:]) + if r == utf8.RuneError && width == 1 { + s.onError(s.offset, "illegal UTF-8 encoding") + } else if r == bom && s.offset > 0 { + s.onError(s.offset, "illegal byte order mark") + } + } + s.readOffset += width + s.ch = r +} + +func (s *Scanner) onError(offset int, msg string) { + if s.err != nil { + s.err(s.file.Pos(offset), msg) + } + s.numErrors++ +} + +// NumErrors returns the current number of errors encountered during scanning. +// This is useful as a fallback to detect errors when no ErrorHandler was +// provided to the scanner. +func (s *Scanner) NumErrors() int { return s.numErrors } + +// Scan scans the next token and returns the token's position, the token +// itself, and the token's literal string (when applicable). The end of the +// input is indicated by token.EOF. +// +// If the returned token is a literal (such as token.STRING), then lit contains +// the corresponding literal text (including surrounding quotes). +// +// If the returned token is a keyword, lit is the keyword text that was +// scanned. +// +// If the returned token is token.TERMINATOR, lit will contain "\n". +// +// If the returned token is token.ILLEGAL, lit contains the offending +// character. +// +// In all other cases, lit will be an empty string. +// +// For more tolerant parsing, Scan returns a valid token character whenever +// possible when a syntax error was encountered. Callers must check NumErrors +// or the number of times the provided ErrorHandler was invoked to ensure there +// were no errors found during scanning. +// +// Scan will inject line information to the file provided by NewScanner. +// Returned token positions are relative to that file. +func (s *Scanner) Scan() (pos token.Pos, tok token.Token, lit string) { +scanAgain: + s.skipWhitespace() + + // Start of current token. + pos = s.file.Pos(s.offset) + + var insertTerm bool + + // Determine token value + switch ch := s.ch; { + case isLetter(ch): + lit = s.scanIdentifier() + if len(lit) > 1 { // Keywords are always > 1 char + tok = token.Lookup(lit) + switch tok { + case token.IDENT, token.NULL, token.BOOL: + insertTerm = true + } + } else { + insertTerm = true + tok = token.IDENT + } + + case isDecimal(ch) || (ch == '.' && isDecimal(rune(s.peek()))): + insertTerm = true + tok, lit = s.scanNumber() + + default: + s.next() // Make progress + + // ch is now the first character in a sequence and s.ch is the second + // character. + + switch ch { + case eof: + if s.insertTerm { + s.insertTerm = false // Consumed EOF + return pos, token.TERMINATOR, "\n" + } + tok = token.EOF + + case '\n': + // This case is only reachable when s.insertTerm is true, since otherwise + // skipWhitespace consumes all other newlines. + s.insertTerm = false // Consumed newline + return pos, token.TERMINATOR, "\n" + + case '\'': + s.onError(pos.Offset(), "illegal single-quoted string; use double quotes") + insertTerm = true + tok = token.ILLEGAL + lit = s.scanString('\'', true, false) + + case '"': + insertTerm = true + tok = token.STRING + lit = s.scanString('"', true, false) + + case '`': + insertTerm = true + tok = token.STRING + lit = s.scanString('`', false, true) + + case '|': + if s.ch != '|' { + s.onError(s.offset, "missing second | in ||") + } else { + s.next() // consume second '|' + } + tok = token.OR + case '&': + if s.ch != '&' { + s.onError(s.offset, "missing second & in &&") + } else { + s.next() // consume second '&' + } + tok = token.AND + + case '!': // !, != + tok = s.switch2(token.NOT, token.NEQ, '=') + case '=': // =, == + tok = s.switch2(token.ASSIGN, token.EQ, '=') + case '<': // <, <= + tok = s.switch2(token.LT, token.LTE, '=') + case '>': // >, >= + tok = s.switch2(token.GT, token.GTE, '=') + case '+': + tok = token.ADD + case '-': + tok = token.SUB + case '*': + tok = token.MUL + case '/': + if s.ch == '/' || s.ch == '*' { + // //- or /*-style comment. + // + // If we're expected to inject a terminator, we can only do so if our + // comment goes to the end of the line. Otherwise, the terminator will + // have to be injected after the comment token. + if s.insertTerm && s.findLineEnd() { + // Reset position to the beginning of the comment. + s.ch = '/' + s.offset = pos.Offset() + s.readOffset = s.offset + 1 + s.insertTerm = false // Consumed newline + return pos, token.TERMINATOR, "\n" + } + comment := s.scanComment() + if s.mode&IncludeComments == 0 { + // Skip over comment + s.insertTerm = false // Consumed newline + goto scanAgain + } + tok = token.COMMENT + lit = comment + } else { + tok = token.DIV + } + + case '%': + tok = token.MOD + case '^': + tok = token.POW + case '{': + tok = token.LCURLY + case '}': + insertTerm = true + tok = token.RCURLY + case '(': + tok = token.LPAREN + case ')': + insertTerm = true + tok = token.RPAREN + case '[': + tok = token.LBRACK + case ']': + insertTerm = true + tok = token.RBRACK + case ',': + tok = token.COMMA + case '.': + // NOTE: Fractions starting with '.' are handled by outer switch + tok = token.DOT + + default: + // s.next() reports invalid BOMs so we don't need to repeat the error. + if ch != bom { + s.onError(pos.Offset(), fmt.Sprintf("illegal character %#U", ch)) + } + insertTerm = s.insertTerm // Preserve previous s.insertTerm state + tok = token.ILLEGAL + lit = string(ch) + } + } + + if s.mode&dontInsertTerms == 0 { + s.insertTerm = insertTerm + } + return +} + +func (s *Scanner) skipWhitespace() { + for s.ch == ' ' || s.ch == '\t' || s.ch == '\r' || (s.ch == '\n' && !s.insertTerm) { + s.next() + } +} + +func isLetter(ch rune) bool { + // We check for ASCII first as an optimization, and leave checking unicode + // (the slowest) to the very end. + return (lower(ch) >= 'a' && lower(ch) <= 'z') || + ch == '_' || + (ch >= utf8.RuneSelf && unicode.IsLetter(ch)) +} + +func lower(ch rune) rune { return ('a' - 'A') | ch } +func isDecimal(ch rune) bool { return '0' <= ch && ch <= '9' } +func isDigit(ch rune) bool { + return isDecimal(ch) || (ch >= utf8.RuneSelf && unicode.IsDigit(ch)) +} + +// scanIdentifier reads the string of valid identifier characters starting at +// s.offet. It must only be called when s.ch is a valid character which starts +// an identifier. +// +// scanIdentifier is highly optimized for identifiers are modifications must be +// made carefully. +func (s *Scanner) scanIdentifier() string { + off := s.offset + + // Optimize for common case of ASCII identifiers. + // + // Ranging over s.input[s.readOffset:] avoids bounds checks and avoids + // conversions to runes. + // + // We'll fall back to the slower path if we find a non-ASCII character. + for readOffset, b := range s.input[s.readOffset:] { + if (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || b == '_' || (b >= '0' && b <= '9') { + // Common case: ASCII character; don't assign a rune. + continue + } + s.readOffset += readOffset + if b > 0 && b < utf8.RuneSelf { + // Optimization: ASCII character that isn't a letter or number; we've + // reached the end of the identifier sequence and can terminate. We avoid + // the call to s.next() and the corresponding setup. + // + // This optimization only works because we know that s.ch (the current + // character when scanIdentifier was called) is never '\n' since '\n' + // cannot start an identifier. + s.ch = rune(b) + s.offset = s.readOffset + s.readOffset++ + goto exit + } + + // The preceding character is valid for an identifier because + // scanIdentifier is only called when s.ch is a letter; calling s.next() at + // s.readOffset will reset the scanner state. + s.next() + for isLetter(s.ch) || isDigit(s.ch) { + s.next() + } + + // No more valid characters for the identifier; terminate. + goto exit + } + + s.offset = len(s.input) + s.readOffset = len(s.input) + s.ch = eof + +exit: + return string(s.input[off:s.offset]) +} + +func (s *Scanner) scanNumber() (tok token.Token, lit string) { + tok = token.NUMBER + off := s.offset + + // Integer part of number + if s.ch != '.' { + s.digits() + } + + // Fractional part of number + if s.ch == '.' { + tok = token.FLOAT + + s.next() + s.digits() + } + + // Exponent + if lower(s.ch) == 'e' { + tok = token.FLOAT + + s.next() + if s.ch == '+' || s.ch == '-' { + s.next() + } + + if s.digits() == 0 { + s.onError(off, "exponent has no digits") + } + } + + return tok, string(s.input[off:s.offset]) +} + +// digits scans a sequence of digits. +func (s *Scanner) digits() (count int) { + for isDecimal(s.ch) { + s.next() + count++ + } + return +} + +func (s *Scanner) scanString(until rune, escape bool, multiline bool) string { + // subtract 1 to account for the opening '"' which was already consumed by + // the scanner forcing progress. + off := s.offset - 1 + + for { + ch := s.ch + if (!multiline && ch == '\n') || ch == eof { + s.onError(off, "string literal not terminated") + break + } + s.next() + if ch == until { + break + } + if escape && ch == '\\' { + s.scanEscape() + } + } + + return string(s.input[off:s.offset]) +} + +// scanEscape parses an escape sequence. In case of a syntax error, scanEscape +// stops at the offending character without consuming it. +func (s *Scanner) scanEscape() { + off := s.offset + + var ( + n int + base, max uint32 + ) + + switch s.ch { + case 'a', 'b', 'f', 'n', 'r', 't', 'v', '\\', '"': + s.next() + return + case '0', '1', '2', '3', '4', '5', '6', '7': + n, base, max = 3, 8, 255 + case 'x': + s.next() + n, base, max = 2, 16, 255 + case 'u': + s.next() + n, base, max = 4, 16, unicode.MaxRune + case 'U': + s.next() + n, base, max = 8, 16, unicode.MaxRune + default: + msg := "unknown escape sequence" + if s.ch == eof { + msg = "escape sequence not terminated" + } + s.onError(off, msg) + return + } + + var x uint32 + for n > 0 { + d := uint32(digitVal(s.ch)) + if d >= base { + msg := fmt.Sprintf("illegal character %#U in escape sequence", s.ch) + if s.ch == eof { + msg = "escape sequence not terminated" + } + s.onError(off, msg) + return + } + x = x*base + d + s.next() + n-- + } + + if x > max || x >= 0xD800 && x < 0xE000 { + s.onError(off, "escape sequence is invalid Unicode code point") + } +} + +func digitVal(ch rune) int { + switch { + case ch >= '0' && ch <= '9': + return int(ch - '0') + case lower(ch) >= 'a' && lower(ch) <= 'f': + return int(lower(ch) - 'a' + 10) + } + return 16 // Larger than any legal digit val +} + +func (s *Scanner) scanComment() string { + // The initial character in the comment was already consumed from the scanner + // forcing progress. + // + // slashComment will be true when the comment is a //- or /*-style comment. + + var ( + off = s.offset - 1 // Offset of initial character + numCR = 0 + + blockComment = false + ) + + if s.ch == '/' { // NOTE: s.ch is second character in comment sequence + // //-style comment. + // + // The final '\n' is not considered to be part of the comment. + if s.ch == '/' { + s.next() // Consume second '/' + } + + for s.ch != '\n' && s.ch != eof { + if s.ch == '\r' { + numCR++ + } + s.next() + } + + goto exit + } + + // /*-style comment. + blockComment = true + s.next() + for s.ch != eof { + ch := s.ch + if ch == '\r' { + numCR++ + } + s.next() + if ch == '*' && s.ch == '/' { + s.next() + goto exit + } + } + + s.onError(off, "block comment not terminated") + +exit: + lit := s.input[off:s.offset] + + // On Windows, a single comment line may end in "\r\n". We want to remove the + // final \r. + if numCR > 0 && len(lit) >= 1 && lit[len(lit)-1] == '\r' { + lit = lit[:len(lit)-1] + numCR-- + } + + if numCR > 0 { + lit = stripCR(lit, blockComment) + } + + return string(lit) +} + +func stripCR(b []byte, blockComment bool) []byte { + c := make([]byte, len(b)) + i := 0 + + for j, ch := range b { + if ch != '\r' || blockComment && i > len("/*") && c[i-1] == '*' && j+1 < len(b) && b[j+1] == '/' { + c[i] = ch + i++ + } + } + + return c[:i] +} + +// findLineEnd checks to see if a comment runs to the end of the line. +func (s *Scanner) findLineEnd() bool { + // NOTE: initial '/' is already consumed by forcing the scanner to progress. + + defer func(off int) { + // Reset scanner state to where it was upon calling findLineEnd. + s.ch = '/' + s.offset = off + s.readOffset = off + 1 + s.next() // Consume initial starting '/' again + }(s.offset - 1) + + // Read ahead until a newline, EOF, or non-comment token is found. + // We loop to consume multiple sequences of comment tokens. + for s.ch == '/' || s.ch == '*' { + if s.ch == '/' { + // //-style comments always contain newlines. + return true + } + + // We're looking at a /*-style comment; look for its newline. + s.next() + for s.ch != eof { + ch := s.ch + if ch == '\n' { + return true + } + s.next() + if ch == '*' && s.ch == '/' { // End of block comment + s.next() + break + } + } + + // Check to see if there's a newline after the block comment. + s.skipWhitespace() // s.insertTerm is set + if s.ch == eof || s.ch == '\n' { + return true + } + if s.ch != '/' { + // Non-comment token + return false + } + s.next() // Consume '/' at the end of the /* style-comment + } + + return false +} + +// switch2 returns a if s.ch is next, b otherwise. The scanner will be advanced +// if b is returned. +// +// This is used for tokens which can either be a single character but also are +// the starting character for a 2-length token (i.e., = and ==). +func (s *Scanner) switch2(a, b token.Token, next rune) token.Token { //nolint:unparam + if s.ch == next { + s.next() + return b + } + return a +} diff --git a/syntax/scanner/scanner_test.go b/syntax/scanner/scanner_test.go new file mode 100644 index 0000000000..38ddcf58ca --- /dev/null +++ b/syntax/scanner/scanner_test.go @@ -0,0 +1,272 @@ +package scanner + +import ( + "path/filepath" + "testing" + + "github.com/grafana/river/token" + "github.com/stretchr/testify/assert" +) + +type tokenExample struct { + tok token.Token + lit string +} + +var tokens = []tokenExample{ + // Special tokens + {token.COMMENT, "/* a comment */"}, + {token.COMMENT, "// a comment \n"}, + {token.COMMENT, "/*\r*/"}, + {token.COMMENT, "/**\r/*/"}, // golang/go#11151 + {token.COMMENT, "/**\r\r/*/"}, + {token.COMMENT, "//\r\n"}, + + // Identifiers and basic type literals + {token.IDENT, "foobar"}, + {token.IDENT, "a۰۱۸"}, + {token.IDENT, "foo६४"}, + {token.IDENT, "bar9876"}, + {token.IDENT, "ŝ"}, // golang/go#4000 + {token.IDENT, "ŝfoo"}, // golang/go#4000 + {token.NUMBER, "0"}, + {token.NUMBER, "1"}, + {token.NUMBER, "123456789012345678890"}, + {token.NUMBER, "01234567"}, + {token.FLOAT, "0."}, + {token.FLOAT, ".0"}, + {token.FLOAT, "3.14159265"}, + {token.FLOAT, "1e0"}, + {token.FLOAT, "1e+100"}, + {token.FLOAT, "1e-100"}, + {token.FLOAT, "2.71828e-1000"}, + {token.STRING, `"Hello, world!"`}, + {token.STRING, "`Hello, world!\\\\`"}, + + // Operators and delimiters + {token.ADD, "+"}, + {token.SUB, "-"}, + {token.MUL, "*"}, + {token.DIV, "/"}, + {token.MOD, "%"}, + {token.POW, "^"}, + + {token.AND, "&&"}, + {token.OR, "||"}, + + {token.EQ, "=="}, + {token.LT, "<"}, + {token.GT, ">"}, + {token.ASSIGN, "="}, + {token.NOT, "!"}, + + {token.NEQ, "!="}, + {token.LTE, "<="}, + {token.GTE, ">="}, + + {token.LPAREN, "("}, + {token.LBRACK, "["}, + {token.LCURLY, "{"}, + {token.COMMA, ","}, + {token.DOT, "."}, + + {token.RPAREN, ")"}, + {token.RBRACK, "]"}, + {token.RCURLY, "}"}, + + // Keywords + {token.NULL, "null"}, + {token.BOOL, "true"}, + {token.BOOL, "false"}, +} + +const whitespace = " \t \n\n\n" // Various whitespace to separate tokens + +var source = func() []byte { + var src []byte + for _, t := range tokens { + src = append(src, t.lit...) + src = append(src, whitespace...) + } + return src +}() + +// FuzzScanner ensures that the scanner will always be able to reach EOF +// regardless of input. +func FuzzScanner(f *testing.F) { + // Add each token into the corpus + for _, t := range tokens { + f.Add([]byte(t.lit)) + } + // Then add the entire source + f.Add(source) + + f.Fuzz(func(t *testing.T, input []byte) { + f := token.NewFile(t.Name()) + + s := New(f, input, nil, IncludeComments) + + for { + _, tok, _ := s.Scan() + if tok == token.EOF { + break + } + } + }) +} + +func TestScanner_Scan(t *testing.T) { + whitespaceLinecount := newlineCount(whitespace) + + var eh ErrorHandler = func(_ token.Pos, msg string) { + t.Errorf("ErrorHandler called (msg = %s)", msg) + } + + f := token.NewFile(t.Name()) + s := New(f, source, eh, IncludeComments|dontInsertTerms) + + // Configure expected position + expectPos := token.Position{ + Filename: t.Name(), + Offset: 0, + Line: 1, + Column: 1, + } + + index := 0 + for { + pos, tok, lit := s.Scan() + + // Check position + checkPos(t, lit, tok, pos, expectPos) + + // Check token + e := tokenExample{token.EOF, ""} + if index < len(tokens) { + e = tokens[index] + index++ + } + assert.Equal(t, e.tok, tok) + + // Check literal + expectLit := "" + switch e.tok { + case token.COMMENT: + // no CRs in comments + expectLit = string(stripCR([]byte(e.lit), e.lit[1] == '*')) + if expectLit[1] == '/' { + // Line comment literals doesn't contain newline + expectLit = expectLit[0 : len(expectLit)-1] + } + case token.IDENT: + expectLit = e.lit + case token.NUMBER, token.FLOAT, token.STRING, token.NULL, token.BOOL: + expectLit = e.lit + } + assert.Equal(t, expectLit, lit) + + if tok == token.EOF { + break + } + + // Update position + expectPos.Offset += len(e.lit) + len(whitespace) + expectPos.Line += newlineCount(e.lit) + whitespaceLinecount + } + + if s.NumErrors() != 0 { + assert.Zero(t, s.NumErrors(), "expected number of scanning errors to be 0") + } +} + +func newlineCount(s string) int { + var n int + for i := 0; i < len(s); i++ { + if s[i] == '\n' { + n++ + } + } + return n +} + +func checkPos(t *testing.T, lit string, tok token.Token, p token.Pos, expected token.Position) { + t.Helper() + + pos := p.Position() + + // Check cleaned filenames so that we don't have to worry about different + // os.PathSeparator values. + if pos.Filename != expected.Filename && filepath.Clean(pos.Filename) != filepath.Clean(expected.Filename) { + assert.Equal(t, expected.Filename, pos.Filename, "Bad filename for %s (%q)", tok, lit) + } + + assert.Equal(t, expected.Offset, pos.Offset, "Bad offset for %s (%q)", tok, lit) + assert.Equal(t, expected.Line, pos.Line, "Bad line for %s (%q)", tok, lit) + assert.Equal(t, expected.Column, pos.Column, "Bad column for %s (%q)", tok, lit) +} + +var errorTests = []struct { + input string + tok token.Token + pos int + lit string + err string +}{ + {"\a", token.ILLEGAL, 0, "", "illegal character U+0007"}, + {`…`, token.ILLEGAL, 0, "", "illegal character U+2026 '…'"}, + {"..", token.DOT, 0, "", ""}, // two periods, not invalid token (golang/go#28112) + {`'illegal string'`, token.ILLEGAL, 0, "", "illegal single-quoted string; use double quotes"}, + {`""`, token.STRING, 0, `""`, ""}, + {`"abc`, token.STRING, 0, `"abc`, "string literal not terminated"}, + {"\"abc\n", token.STRING, 0, `"abc`, "string literal not terminated"}, + {"\"abc\n ", token.STRING, 0, `"abc`, "string literal not terminated"}, + {"\"abc\x00def\"", token.STRING, 4, "\"abc\x00def\"", "illegal character NUL"}, + {"\"abc\x80def\"", token.STRING, 4, "\"abc\x80def\"", "illegal UTF-8 encoding"}, + {"\ufeff\ufeff", token.ILLEGAL, 3, "\ufeff\ufeff", "illegal byte order mark"}, // only first BOM is ignored + {"//\ufeff", token.COMMENT, 2, "//\ufeff", "illegal byte order mark"}, // only first BOM is ignored + {`"` + "abc\ufeffdef" + `"`, token.STRING, 4, `"` + "abc\ufeffdef" + `"`, "illegal byte order mark"}, // only first BOM is ignored + {"abc\x00def", token.IDENT, 3, "abc", "illegal character NUL"}, + {"abc\x00", token.IDENT, 3, "abc", "illegal character NUL"}, + {"10E", token.FLOAT, 0, "10E", "exponent has no digits"}, +} + +func TestScanner_Scan_Errors(t *testing.T) { + for _, e := range errorTests { + checkError(t, e.input, e.tok, e.pos, e.lit, e.err) + } +} + +func checkError(t *testing.T, src string, tok token.Token, pos int, lit, err string) { + t.Helper() + + var ( + actualErrors int + latestError string + latestPos token.Pos + ) + + eh := func(pos token.Pos, msg string) { + actualErrors++ + latestError = msg + latestPos = pos + } + + f := token.NewFile(t.Name()) + s := New(f, []byte(src), eh, IncludeComments|dontInsertTerms) + + _, actualTok, actualLit := s.Scan() + + assert.Equal(t, tok, actualTok) + if actualTok != token.ILLEGAL { + assert.Equal(t, lit, actualLit) + } + + expectErrors := 0 + if err != "" { + expectErrors = 1 + } + + assert.Equal(t, expectErrors, actualErrors, "Unexpected error count in src %q", src) + assert.Equal(t, err, latestError, "Unexpected error message in src %q", src) + assert.Equal(t, pos, latestPos.Offset(), "Unexpected offset in src %q", src) +} diff --git a/syntax/token/builder/builder.go b/syntax/token/builder/builder.go new file mode 100644 index 0000000000..1dc9b5d62b --- /dev/null +++ b/syntax/token/builder/builder.go @@ -0,0 +1,419 @@ +// Package builder exposes an API to create a River configuration file by +// constructing a set of tokens. +package builder + +import ( + "bytes" + "fmt" + "io" + "reflect" + "strings" + + "github.com/grafana/river/internal/reflectutil" + "github.com/grafana/river/internal/rivertags" + "github.com/grafana/river/internal/value" + "github.com/grafana/river/token" +) + +var goRiverDefaulter = reflect.TypeOf((*value.Defaulter)(nil)).Elem() + +// An Expr represents a single River expression. +type Expr struct { + rawTokens []Token +} + +// NewExpr creates a new Expr. +func NewExpr() *Expr { return &Expr{} } + +// Tokens returns the Expr as a set of Tokens. +func (e *Expr) Tokens() []Token { return e.rawTokens } + +// SetValue sets the Expr to a River value converted from a Go value. The Go +// value is encoded using the normal Go to River encoding rules. If any value +// reachable from goValue implements Tokenizer, the printed tokens will instead +// be retrieved by calling the RiverTokenize method. +func (e *Expr) SetValue(goValue interface{}) { + e.rawTokens = tokenEncode(goValue) +} + +// WriteTo renders and formats the File, writing the contents to w. +func (e *Expr) WriteTo(w io.Writer) (int64, error) { + n, err := printExprTokens(w, e.Tokens()) + return int64(n), err +} + +// Bytes renders the File to a formatted byte slice. +func (e *Expr) Bytes() []byte { + var buf bytes.Buffer + _, _ = e.WriteTo(&buf) + return buf.Bytes() +} + +// A File represents a River configuration file. +type File struct { + body *Body +} + +// NewFile creates a new File. +func NewFile() *File { return &File{body: newBody()} } + +// Tokens returns the File as a set of Tokens. +func (f *File) Tokens() []Token { return f.Body().Tokens() } + +// Body returns the Body contents of the file. +func (f *File) Body() *Body { return f.body } + +// WriteTo renders and formats the File, writing the contents to w. +func (f *File) WriteTo(w io.Writer) (int64, error) { + n, err := printFileTokens(w, f.Tokens()) + return int64(n), err +} + +// Bytes renders the File to a formatted byte slice. +func (f *File) Bytes() []byte { + var buf bytes.Buffer + _, _ = f.WriteTo(&buf) + return buf.Bytes() +} + +// Body is a list of block and attribute statements. A Body cannot be manually +// created, but is retrieved from a File or Block. +type Body struct { + nodes []tokenNode + valueOverrideHook ValueOverrideHook +} + +type ValueOverrideHook = func(val interface{}) interface{} + +// SetValueOverrideHook sets a hook to override the value that will be token +// encoded. The hook can mutate the value to be encoded or should return it +// unmodified. This hook can be skipped by leaving it nil or setting it to nil. +func (b *Body) SetValueOverrideHook(valueOverrideHook ValueOverrideHook) { + b.valueOverrideHook = valueOverrideHook +} + +func (b *Body) Nodes() []tokenNode { + return b.nodes +} + +// A tokenNode is a structural element which can be converted into a set of +// Tokens. +type tokenNode interface { + // Tokens builds the set of Tokens from the node. + Tokens() []Token +} + +func newBody() *Body { + return &Body{} +} + +// Tokens returns the File as a set of Tokens. +func (b *Body) Tokens() []Token { + var rawToks []Token + for i, node := range b.nodes { + rawToks = append(rawToks, node.Tokens()...) + + if i+1 < len(b.nodes) { + // Append a terminator between each statement in the Body. + rawToks = append(rawToks, Token{ + Tok: token.LITERAL, + Lit: "\n", + }) + } + } + return rawToks +} + +// AppendTokens appends raw tokens to the Body. +func (b *Body) AppendTokens(tokens []Token) { + b.nodes = append(b.nodes, tokensSlice(tokens)) +} + +// AppendBlock adds a new block inside of the Body. +func (b *Body) AppendBlock(block *Block) { + b.nodes = append(b.nodes, block) +} + +// AppendFrom sets attributes and appends blocks defined by goValue into the +// Body. If any value reachable from goValue implements Tokenizer, the printed +// tokens will instead be retrieved by calling the RiverTokenize method. +// +// Optional attributes and blocks set to default values are trimmed. +// If goValue implements Defaulter, default values are retrieved by +// calling SetToDefault against a copy. Otherwise, default values are +// the zero value of the respective Go types. +// +// goValue must be a struct or a pointer to a struct that contains River struct +// tags. +func (b *Body) AppendFrom(goValue interface{}) { + if goValue == nil { + return + } + + rv := reflect.ValueOf(goValue) + b.encodeFields(rv) +} + +// getBlockLabel returns the label for a given block. +func getBlockLabel(rv reflect.Value) string { + tags := rivertags.Get(rv.Type()) + for _, tag := range tags { + if tag.Flags&rivertags.FlagLabel != 0 { + return reflectutil.Get(rv, tag).String() + } + } + + return "" +} + +func (b *Body) encodeFields(rv reflect.Value) { + for rv.Kind() == reflect.Pointer { + if rv.IsNil() { + return + } + rv = rv.Elem() + } + if rv.Kind() != reflect.Struct { + panic(fmt.Sprintf("river/token/builder: can only encode struct values to bodies, got %s", rv.Type())) + } + + fields := rivertags.Get(rv.Type()) + defaults := reflect.New(rv.Type()).Elem() + if defaults.CanAddr() && defaults.Addr().Type().Implements(goRiverDefaulter) { + defaults.Addr().Interface().(value.Defaulter).SetToDefault() + } + + for _, field := range fields { + fieldVal := reflectutil.Get(rv, field) + fieldValDefault := reflectutil.Get(defaults, field) + + // Check if the values are exactly equal or if they're both equal to the + // zero value. Checking for both fields being zero handles the case where + // an empty and nil map are being compared (which are not equal, but are + // both zero values). + matchesDefault := reflect.DeepEqual(fieldVal.Interface(), fieldValDefault.Interface()) + isZero := fieldValDefault.IsZero() && fieldVal.IsZero() + + if field.IsOptional() && (matchesDefault || isZero) { + continue + } + + b.encodeField(nil, field, fieldVal) + } +} + +func (b *Body) encodeField(prefix []string, field rivertags.Field, fieldValue reflect.Value) { + fieldName := strings.Join(field.Name, ".") + + for fieldValue.Kind() == reflect.Pointer { + if fieldValue.IsNil() { + break + } + fieldValue = fieldValue.Elem() + } + + switch { + case field.IsAttr(): + b.SetAttributeValue(fieldName, fieldValue.Interface()) + + case field.IsBlock(): + fullName := mergeStringSlice(prefix, field.Name) + + switch { + case fieldValue.Kind() == reflect.Map: + // Iterate over the map and add each element as an attribute into it. + if fieldValue.Type().Key().Kind() != reflect.String { + panic("river/token/builder: unsupported map type for block; expected map[string]T, got " + fieldValue.Type().String()) + } + + inner := NewBlock(fullName, "") + inner.body.SetValueOverrideHook(b.valueOverrideHook) + b.AppendBlock(inner) + + iter := fieldValue.MapRange() + for iter.Next() { + mapKey, mapValue := iter.Key(), iter.Value() + inner.body.SetAttributeValue(mapKey.String(), mapValue.Interface()) + } + + case fieldValue.Kind() == reflect.Slice, fieldValue.Kind() == reflect.Array: + for i := 0; i < fieldValue.Len(); i++ { + elem := fieldValue.Index(i) + + // Recursively call encodeField for each element in the slice/array for + // non-zero blocks. The recursive call will hit the case below and add + // a new block for each field encountered. + if field.Flags&rivertags.FlagOptional != 0 && elem.IsZero() { + continue + } + b.encodeField(prefix, field, elem) + } + + case fieldValue.Kind() == reflect.Struct: + inner := NewBlock(fullName, getBlockLabel(fieldValue)) + inner.body.SetValueOverrideHook(b.valueOverrideHook) + inner.Body().encodeFields(fieldValue) + b.AppendBlock(inner) + } + + case field.IsEnum(): + // Blocks within an enum have a prefix set. + newPrefix := mergeStringSlice(prefix, field.Name) + + switch { + case fieldValue.Kind() == reflect.Slice, fieldValue.Kind() == reflect.Array: + for i := 0; i < fieldValue.Len(); i++ { + b.encodeEnumElement(newPrefix, fieldValue.Index(i)) + } + + default: + panic(fmt.Sprintf("river/token/builder: unrecognized enum kind %s", fieldValue.Kind())) + } + } +} + +func mergeStringSlice(a, b []string) []string { + if len(a) == 0 { + return b + } else if len(b) == 0 { + return a + } + + res := make([]string, 0, len(a)+len(b)) + res = append(res, a...) + res = append(res, b...) + return res +} + +func (b *Body) encodeEnumElement(prefix []string, enumElement reflect.Value) { + for enumElement.Kind() == reflect.Pointer { + if enumElement.IsNil() { + return + } + enumElement = enumElement.Elem() + } + + fields := rivertags.Get(enumElement.Type()) + + // Find the first non-zero field and encode it. + for _, field := range fields { + fieldVal := reflectutil.Get(enumElement, field) + if !fieldVal.IsValid() || fieldVal.IsZero() { + continue + } + + b.encodeField(prefix, field, fieldVal) + break + } +} + +// SetAttributeTokens sets an attribute to the Body whose value is a set of raw +// tokens. If the attribute was previously set, its value tokens are updated. +// +// Attributes will be written out in the order they were initially created. +func (b *Body) SetAttributeTokens(name string, tokens []Token) { + attr := b.getOrCreateAttribute(name) + attr.RawTokens = tokens +} + +func (b *Body) getOrCreateAttribute(name string) *attribute { + for _, n := range b.nodes { + if attr, ok := n.(*attribute); ok && attr.Name == name { + return attr + } + } + + newAttr := &attribute{Name: name} + b.nodes = append(b.nodes, newAttr) + return newAttr +} + +// SetAttributeValue sets an attribute in the Body whose value is converted +// from a Go value to a River value. The Go value is encoded using the normal +// Go to River encoding rules. If any value reachable from goValue implements +// Tokenizer, the printed tokens will instead be retrieved by calling the +// RiverTokenize method. +// +// If the attribute was previously set, its value tokens are updated. +// +// Attributes will be written out in the order they were initially crated. +func (b *Body) SetAttributeValue(name string, goValue interface{}) { + attr := b.getOrCreateAttribute(name) + + if b.valueOverrideHook != nil { + attr.RawTokens = tokenEncode(b.valueOverrideHook(goValue)) + } else { + attr.RawTokens = tokenEncode(goValue) + } +} + +type attribute struct { + Name string + RawTokens []Token +} + +func (attr *attribute) Tokens() []Token { + var toks []Token + + toks = append(toks, Token{Tok: token.IDENT, Lit: attr.Name}) + toks = append(toks, Token{Tok: token.ASSIGN}) + toks = append(toks, attr.RawTokens...) + + return toks +} + +// A Block encapsulates a body within a named and labeled River block. Blocks +// must be created by calling NewBlock, but its public struct fields may be +// safely modified by callers. +type Block struct { + // Public fields, safe to be changed by callers: + + Name []string + Label string + + // Private fields: + + body *Body +} + +// NewBlock returns a new Block with the given name and label. The name/label +// can be updated later by modifying the Block's public fields. +func NewBlock(name []string, label string) *Block { + return &Block{ + Name: name, + Label: label, + + body: newBody(), + } +} + +// Tokens returns the File as a set of Tokens. +func (b *Block) Tokens() []Token { + var toks []Token + + for i, frag := range b.Name { + toks = append(toks, Token{Tok: token.IDENT, Lit: frag}) + if i+1 < len(b.Name) { + toks = append(toks, Token{Tok: token.DOT}) + } + } + + toks = append(toks, Token{Tok: token.LITERAL, Lit: " "}) + + if b.Label != "" { + toks = append(toks, Token{Tok: token.STRING, Lit: fmt.Sprintf("%q", b.Label)}) + } + + toks = append(toks, Token{Tok: token.LCURLY}, Token{Tok: token.LITERAL, Lit: "\n"}) + toks = append(toks, b.body.Tokens()...) + toks = append(toks, Token{Tok: token.LITERAL, Lit: "\n"}, Token{Tok: token.RCURLY}) + + return toks +} + +// Body returns the Body contained within the Block. +func (b *Block) Body() *Body { return b.body } + +type tokensSlice []Token + +func (tn tokensSlice) Tokens() []Token { return []Token(tn) } diff --git a/syntax/token/builder/builder_test.go b/syntax/token/builder/builder_test.go new file mode 100644 index 0000000000..d363556929 --- /dev/null +++ b/syntax/token/builder/builder_test.go @@ -0,0 +1,411 @@ +package builder_test + +import ( + "bytes" + "fmt" + "testing" + "time" + + "github.com/grafana/river/parser" + "github.com/grafana/river/printer" + "github.com/grafana/river/token" + "github.com/grafana/river/token/builder" + "github.com/stretchr/testify/require" +) + +func TestBuilder_File(t *testing.T) { + f := builder.NewFile() + + f.Body().SetAttributeTokens("attr_1", []builder.Token{{Tok: token.NUMBER, Lit: "15"}}) + f.Body().SetAttributeTokens("attr_2", []builder.Token{{Tok: token.BOOL, Lit: "true"}}) + + b1 := builder.NewBlock([]string{"test", "block"}, "") + b1.Body().SetAttributeTokens("inner_attr", []builder.Token{{Tok: token.STRING, Lit: `"block 1"`}}) + f.Body().AppendBlock(b1) + + b2 := builder.NewBlock([]string{"test", "block"}, "labeled") + b2.Body().SetAttributeTokens("inner_attr", []builder.Token{{Tok: token.STRING, Lit: "`\"block 2`"}}) + f.Body().AppendBlock(b2) + + expect := format(t, ` + attr_1 = 15 + attr_2 = true + + test.block { + inner_attr = "block 1" + } + + test.block "labeled" { + inner_attr = `+"`\"block 2`"+` + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func TestBuilder_GoEncode(t *testing.T) { + f := builder.NewFile() + + f.Body().AppendTokens([]builder.Token{{token.COMMENT, "// Hello, world!"}}) + f.Body().SetAttributeValue("null_value", nil) + f.Body().AppendTokens([]builder.Token{{token.LITERAL, "\n"}}) + + f.Body().SetAttributeValue("num", 15) + f.Body().SetAttributeValue("string", "Hello, world!") + f.Body().SetAttributeValue("bool", true) + f.Body().SetAttributeValue("list", []int{0, 1, 2}) + f.Body().SetAttributeValue("func", func(int, int) int { return 0 }) + f.Body().SetAttributeValue("capsule", make(chan int)) + f.Body().AppendTokens([]builder.Token{{token.LITERAL, "\n"}}) + + f.Body().SetAttributeValue("map", map[string]interface{}{"foo": "bar"}) + f.Body().SetAttributeValue("map_2", map[string]interface{}{"non ident": "bar"}) + f.Body().AppendTokens([]builder.Token{{token.LITERAL, "\n"}}) + + f.Body().SetAttributeValue("mixed_list", []interface{}{ + 0, + true, + map[string]interface{}{"key": true}, + "Hello!", + }) + + expect := format(t, ` + // Hello, world! + null_value = null + + num = 15 + string = "Hello, world!" + bool = true + list = [0, 1, 2] + func = function + capsule = capsule("chan int") + + map = { + foo = "bar", + } + map_2 = { + "non ident" = "bar", + } + + mixed_list = [0, true, { + key = true, + }, "Hello!"] + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +// TestBuilder_GoEncode_SortMapKeys ensures that object literals from unordered +// values (i.e., Go maps) are printed in a deterministic order by sorting the +// keys lexicographically. Other object literals should be printed in the order +// the keys are reported in (i.e., in the order presented by the Go structs). +func TestBuilder_GoEncode_SortMapKeys(t *testing.T) { + f := builder.NewFile() + + type Ordered struct { + SomeKey string `river:"some_key,attr"` + OtherKey string `river:"other_key,attr"` + } + + // Maps are unordered because you can't iterate over their keys in a + // consistent order. + var unordered = map[string]interface{}{ + "key_a": 1, + "key_c": 3, + "key_b": 2, + } + + f.Body().SetAttributeValue("ordered", Ordered{SomeKey: "foo", OtherKey: "bar"}) + f.Body().SetAttributeValue("unordered", unordered) + + expect := format(t, ` + ordered = { + some_key = "foo", + other_key = "bar", + } + unordered = { + key_a = 1, + key_b = 2, + key_c = 3, + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func TestBuilder_AppendFrom(t *testing.T) { + type InnerBlock struct { + Number int `river:"number,attr"` + } + + type Structure struct { + Field string `river:"field,attr"` + + Block InnerBlock `river:"block,block"` + OtherBlocks []InnerBlock `river:"other_block,block"` + } + + f := builder.NewFile() + f.Body().AppendFrom(Structure{ + Field: "some_value", + + Block: InnerBlock{Number: 1}, + OtherBlocks: []InnerBlock{ + {Number: 2}, + {Number: 3}, + }, + }) + + expect := format(t, ` + field = "some_value" + + block { + number = 1 + } + + other_block { + number = 2 + } + + other_block { + number = 3 + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func TestBuilder_AppendFrom_EnumSlice(t *testing.T) { + type InnerBlock struct { + Number int `river:"number,attr"` + } + + type EnumBlock struct { + BlockA InnerBlock `river:"a,block,optional"` + BlockB InnerBlock `river:"b,block,optional"` + BlockC InnerBlock `river:"c,block,optional"` + } + + type Structure struct { + Field string `river:"field,attr"` + + OtherBlocks []EnumBlock `river:"block,enum"` + } + + f := builder.NewFile() + f.Body().AppendFrom(Structure{ + Field: "some_value", + OtherBlocks: []EnumBlock{ + {BlockC: InnerBlock{Number: 1}}, + {BlockB: InnerBlock{Number: 2}}, + {BlockC: InnerBlock{Number: 3}}, + }, + }) + + expect := format(t, ` + field = "some_value" + + block.c { + number = 1 + } + + block.b { + number = 2 + } + + block.c { + number = 3 + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func TestBuilder_AppendFrom_EnumSlice_Pointer(t *testing.T) { + type InnerBlock struct { + Number int `river:"number,attr"` + } + + type EnumBlock struct { + BlockA *InnerBlock `river:"a,block,optional"` + BlockB *InnerBlock `river:"b,block,optional"` + BlockC *InnerBlock `river:"c,block,optional"` + } + + type Structure struct { + Field string `river:"field,attr"` + + OtherBlocks []EnumBlock `river:"block,enum"` + } + + f := builder.NewFile() + f.Body().AppendFrom(Structure{ + Field: "some_value", + OtherBlocks: []EnumBlock{ + {BlockC: &InnerBlock{Number: 1}}, + {BlockB: &InnerBlock{Number: 2}}, + {BlockC: &InnerBlock{Number: 3}}, + }, + }) + + expect := format(t, ` + field = "some_value" + + block.c { + number = 1 + } + + block.b { + number = 2 + } + + block.c { + number = 3 + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func TestBuilder_SkipOptional(t *testing.T) { + type Structure struct { + OptFieldA string `river:"opt_field_a,attr,optional"` + OptFieldB string `river:"opt_field_b,attr,optional"` + ReqFieldA string `river:"req_field_a,attr"` + ReqFieldB string `river:"req_field_b,attr"` + } + + f := builder.NewFile() + f.Body().AppendFrom(Structure{ + OptFieldA: "some_value", + OptFieldB: "", // Zero value + ReqFieldA: "some_value", + ReqFieldB: "", // Zero value + }) + + expect := format(t, ` + opt_field_a = "some_value" + req_field_a = "some_value" + req_field_b = "" + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func format(t *testing.T, in string) string { + t.Helper() + + f, err := parser.ParseFile(t.Name(), []byte(in)) + require.NoError(t, err) + + var buf bytes.Buffer + require.NoError(t, printer.Fprint(&buf, f)) + + return buf.String() +} + +type CustomTokenizer bool + +var _ builder.Tokenizer = (CustomTokenizer)(false) + +func (ct CustomTokenizer) RiverTokenize() []builder.Token { + return []builder.Token{{Tok: token.LITERAL, Lit: "CUSTOM_TOKENS"}} +} + +func TestBuilder_GoEncode_Tokenizer(t *testing.T) { + t.Run("Tokenizer", func(t *testing.T) { + f := builder.NewFile() + f.Body().SetAttributeValue("value", CustomTokenizer(true)) + + expect := format(t, `value = CUSTOM_TOKENS`) + require.Equal(t, expect, string(f.Bytes())) + }) + + t.Run("TextMarshaler", func(t *testing.T) { + now := time.Now() + expectBytes, err := now.MarshalText() + require.NoError(t, err) + + f := builder.NewFile() + f.Body().SetAttributeValue("value", now) + + expect := format(t, fmt.Sprintf(`value = %q`, string(expectBytes))) + require.Equal(t, expect, string(f.Bytes())) + }) + + t.Run("Duration", func(t *testing.T) { + dur := 15 * time.Second + + f := builder.NewFile() + f.Body().SetAttributeValue("value", dur) + + expect := format(t, fmt.Sprintf(`value = %q`, dur.String())) + require.Equal(t, expect, string(f.Bytes())) + }) +} + +func TestBuilder_ValueOverrideHook(t *testing.T) { + type InnerBlock struct { + AnotherField string `river:"another_field,attr"` + } + + type Structure struct { + Field string `river:"field,attr"` + + Block InnerBlock `river:"block,block"` + OtherBlocks []InnerBlock `river:"other_block,block"` + } + + f := builder.NewFile() + f.Body().SetValueOverrideHook(func(val interface{}) interface{} { + return "some other value" + }) + f.Body().AppendFrom(Structure{ + Field: "some_value", + + Block: InnerBlock{AnotherField: "some_value"}, + OtherBlocks: []InnerBlock{ + {AnotherField: "some_value"}, + {AnotherField: "some_value"}, + }, + }) + + expect := format(t, ` + field = "some other value" + + block { + another_field = "some other value" + } + + other_block { + another_field = "some other value" + } + + other_block { + another_field = "some other value" + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func TestBuilder_MapBlocks(t *testing.T) { + type block struct { + Value map[string]any `river:"block,block,optional"` + } + + f := builder.NewFile() + f.Body().AppendFrom(block{ + Value: map[string]any{ + "field": "value", + }, + }) + + expect := format(t, ` + block { + field = "value" + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} diff --git a/syntax/token/builder/nested_defaults_test.go b/syntax/token/builder/nested_defaults_test.go new file mode 100644 index 0000000000..1fd8122b28 --- /dev/null +++ b/syntax/token/builder/nested_defaults_test.go @@ -0,0 +1,233 @@ +package builder_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/grafana/river/ast" + "github.com/grafana/river/parser" + "github.com/grafana/river/token/builder" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +const ( + defaultNumber = 123 + otherDefaultNumber = 321 +) + +var testCases = []struct { + name string + input interface{} + expectedRiver string +}{ + { + name: "struct propagating default - input matching default", + input: StructPropagatingDefault{Inner: AttrWithDefault{Number: defaultNumber}}, + expectedRiver: "", + }, + { + name: "struct propagating default - input with zero-value struct", + input: StructPropagatingDefault{}, + expectedRiver: ` + inner { + number = 0 + } + `, + }, + { + name: "struct propagating default - input with non-default value", + input: StructPropagatingDefault{Inner: AttrWithDefault{Number: 42}}, + expectedRiver: ` + inner { + number = 42 + } + `, + }, + { + name: "pointer propagating default - input matching default", + input: PtrPropagatingDefault{Inner: &AttrWithDefault{Number: defaultNumber}}, + expectedRiver: "", + }, + { + name: "pointer propagating default - input with zero value", + input: PtrPropagatingDefault{Inner: &AttrWithDefault{}}, + expectedRiver: ` + inner { + number = 0 + } + `, + }, + { + name: "pointer propagating default - input with non-default value", + input: PtrPropagatingDefault{Inner: &AttrWithDefault{Number: 42}}, + expectedRiver: ` + inner { + number = 42 + } + `, + }, + { + name: "zero default - input with zero value", + input: ZeroDefault{Inner: &AttrWithDefault{}}, + expectedRiver: "", + }, + { + name: "zero default - input with non-default value", + input: ZeroDefault{Inner: &AttrWithDefault{Number: 42}}, + expectedRiver: ` + inner { + number = 42 + } + `, + }, + { + name: "no default - input with zero value", + input: NoDefaultDefined{Inner: &AttrWithDefault{}}, + expectedRiver: ` + inner { + number = 0 + } + `, + }, + { + name: "no default - input with non-default value", + input: NoDefaultDefined{Inner: &AttrWithDefault{Number: 42}}, + expectedRiver: ` + inner { + number = 42 + } + `, + }, + { + name: "mismatching default - input matching outer default", + input: MismatchingDefault{Inner: &AttrWithDefault{Number: otherDefaultNumber}}, + expectedRiver: "", + }, + { + name: "mismatching default - input matching inner default", + input: MismatchingDefault{Inner: &AttrWithDefault{Number: defaultNumber}}, + expectedRiver: "inner { }", + }, + { + name: "mismatching default - input with non-default value", + input: MismatchingDefault{Inner: &AttrWithDefault{Number: 42}}, + expectedRiver: ` + inner { + number = 42 + } + `, + }, +} + +func TestNestedDefaults(t *testing.T) { + for _, tc := range testCases { + t.Run(fmt.Sprintf("%T/%s", tc.input, tc.name), func(t *testing.T) { + f := builder.NewFile() + f.Body().AppendFrom(tc.input) + actualRiver := string(f.Bytes()) + expected := format(t, tc.expectedRiver) + require.Equal(t, expected, actualRiver, "generated river didn't match expected") + + // Now decode the River produced above and make sure it's the same as the input. + eval := vm.New(parseBlock(t, actualRiver)) + vPtr := reflect.New(reflect.TypeOf(tc.input)).Interface() + require.NoError(t, eval.Evaluate(nil, vPtr), "river evaluation error") + + actualOut := reflect.ValueOf(vPtr).Elem().Interface() + require.Equal(t, tc.input, actualOut, "Invariant violated: encoded and then decoded block didn't match the original value") + }) + } +} + +func TestPtrPropagatingDefaultWithNil(t *testing.T) { + // This is a special case - when defaults are correctly defined, the `Inner: nil` should mean to use defaults. + // Encoding will encode to empty string and decoding will produce the default value - `Inner: {Number: 123}`. + input := PtrPropagatingDefault{} + expectedEncodedRiver := "" + expectedDecoded := PtrPropagatingDefault{Inner: &AttrWithDefault{Number: 123}} + + f := builder.NewFile() + f.Body().AppendFrom(input) + actualRiver := string(f.Bytes()) + expected := format(t, expectedEncodedRiver) + require.Equal(t, expected, actualRiver, "generated river didn't match expected") + + // Now decode the River produced above and make sure it's the same as the input. + eval := vm.New(parseBlock(t, actualRiver)) + vPtr := reflect.New(reflect.TypeOf(input)).Interface() + require.NoError(t, eval.Evaluate(nil, vPtr), "river evaluation error") + + actualOut := reflect.ValueOf(vPtr).Elem().Interface() + require.Equal(t, expectedDecoded, actualOut) +} + +// StructPropagatingDefault has the outer defaults matching the inner block's defaults. The inner block is a struct. +type StructPropagatingDefault struct { + Inner AttrWithDefault `river:"inner,block,optional"` +} + +func (o *StructPropagatingDefault) SetToDefault() { + inner := &AttrWithDefault{} + inner.SetToDefault() + *o = StructPropagatingDefault{Inner: *inner} +} + +// PtrPropagatingDefault has the outer defaults matching the inner block's defaults. The inner block is a pointer. +type PtrPropagatingDefault struct { + Inner *AttrWithDefault `river:"inner,block,optional"` +} + +func (o *PtrPropagatingDefault) SetToDefault() { + inner := &AttrWithDefault{} + inner.SetToDefault() + *o = PtrPropagatingDefault{Inner: inner} +} + +// MismatchingDefault has the outer defaults NOT matching the inner block's defaults. The inner block is a pointer. +type MismatchingDefault struct { + Inner *AttrWithDefault `river:"inner,block,optional"` +} + +func (o *MismatchingDefault) SetToDefault() { + *o = MismatchingDefault{Inner: &AttrWithDefault{ + Number: otherDefaultNumber, + }} +} + +// ZeroDefault has the outer defaults setting to zero values. The inner block is a pointer. +type ZeroDefault struct { + Inner *AttrWithDefault `river:"inner,block,optional"` +} + +func (o *ZeroDefault) SetToDefault() { + *o = ZeroDefault{Inner: &AttrWithDefault{}} +} + +// NoDefaultDefined has no defaults defined. The inner block is a pointer. +type NoDefaultDefined struct { + Inner *AttrWithDefault `river:"inner,block,optional"` +} + +// AttrWithDefault has a default value of a non-zero number. +type AttrWithDefault struct { + Number int `river:"number,attr,optional"` +} + +func (i *AttrWithDefault) SetToDefault() { + *i = AttrWithDefault{Number: defaultNumber} +} + +func parseBlock(t *testing.T, input string) *ast.BlockStmt { + t.Helper() + + input = fmt.Sprintf("test { %s }", input) + res, err := parser.ParseFile("", []byte(input)) + require.NoError(t, err) + require.Len(t, res.Body, 1) + + stmt, ok := res.Body[0].(*ast.BlockStmt) + require.True(t, ok, "Expected stmt to be a ast.BlockStmt, got %T", res.Body[0]) + return stmt +} diff --git a/syntax/token/builder/token.go b/syntax/token/builder/token.go new file mode 100644 index 0000000000..390b968959 --- /dev/null +++ b/syntax/token/builder/token.go @@ -0,0 +1,81 @@ +package builder + +import ( + "bytes" + "io" + + "github.com/grafana/river/parser" + "github.com/grafana/river/printer" + "github.com/grafana/river/token" +) + +// A Token is a wrapper around token.Token which contains the token type +// alongside its literal. Use LiteralTok as the Tok field to write literal +// characters such as whitespace. +type Token struct { + Tok token.Token + Lit string +} + +// printFileTokens prints out the tokens as River text and formats them, writing +// the final result to w. +func printFileTokens(w io.Writer, toks []Token) (int, error) { + var raw bytes.Buffer + for _, tok := range toks { + switch { + case tok.Tok == token.LITERAL: + raw.WriteString(tok.Lit) + case tok.Tok == token.COMMENT: + raw.WriteString(tok.Lit) + case tok.Tok.IsLiteral() || tok.Tok.IsKeyword(): + raw.WriteString(tok.Lit) + default: + raw.WriteString(tok.Tok.String()) + } + } + + f, err := parser.ParseFile("", raw.Bytes()) + if err != nil { + return 0, err + } + + wc := &writerCount{w: w} + err = printer.Fprint(wc, f) + return wc.n, err +} + +// printExprTokens prints out the tokens as River text and formats them, +// writing the final result to w. +func printExprTokens(w io.Writer, toks []Token) (int, error) { + var raw bytes.Buffer + for _, tok := range toks { + switch { + case tok.Tok == token.LITERAL: + raw.WriteString(tok.Lit) + case tok.Tok.IsLiteral() || tok.Tok.IsKeyword(): + raw.WriteString(tok.Lit) + default: + raw.WriteString(tok.Tok.String()) + } + } + + expr, err := parser.ParseExpression(raw.String()) + if err != nil { + return 0, err + } + + wc := &writerCount{w: w} + err = printer.Fprint(wc, expr) + return wc.n, err +} + +type writerCount struct { + w io.Writer + n int +} + +func (wc *writerCount) Write(p []byte) (n int, err error) { + n, err = wc.w.Write(p) + wc.n += n + return +} diff --git a/syntax/token/builder/value_tokens.go b/syntax/token/builder/value_tokens.go new file mode 100644 index 0000000000..c73e34f7b6 --- /dev/null +++ b/syntax/token/builder/value_tokens.go @@ -0,0 +1,95 @@ +package builder + +import ( + "fmt" + "sort" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/scanner" + "github.com/grafana/river/token" +) + +// TODO(rfratto): check for optional values + +// Tokenizer is any value which can return a raw set of tokens. +type Tokenizer interface { + // RiverTokenize returns the raw set of River tokens which are used when + // printing out the value with river/token/builder. + RiverTokenize() []Token +} + +func tokenEncode(val interface{}) []Token { + return valueTokens(value.Encode(val)) +} + +func valueTokens(v value.Value) []Token { + var toks []Token + + // If v is a Tokenizer, allow it to override what tokens get generated. + if tk, ok := v.Interface().(Tokenizer); ok { + return tk.RiverTokenize() + } + + switch v.Type() { + case value.TypeNull: + toks = append(toks, Token{token.NULL, "null"}) + + case value.TypeNumber: + toks = append(toks, Token{token.NUMBER, v.Number().ToString()}) + + case value.TypeString: + toks = append(toks, Token{token.STRING, fmt.Sprintf("%q", v.Text())}) + + case value.TypeBool: + toks = append(toks, Token{token.STRING, fmt.Sprintf("%v", v.Bool())}) + + case value.TypeArray: + toks = append(toks, Token{token.LBRACK, ""}) + elems := v.Len() + for i := 0; i < elems; i++ { + elem := v.Index(i) + + toks = append(toks, valueTokens(elem)...) + if i+1 < elems { + toks = append(toks, Token{token.COMMA, ""}) + } + } + toks = append(toks, Token{token.RBRACK, ""}) + + case value.TypeObject: + toks = append(toks, Token{token.LCURLY, ""}, Token{token.LITERAL, "\n"}) + + keys := v.Keys() + + // If v isn't an ordered object (i.e., a go map), sort the keys so they + // have a deterministic print order. + if !v.OrderedKeys() { + sort.Strings(keys) + } + + for i := 0; i < len(keys); i++ { + if scanner.IsValidIdentifier(keys[i]) { + toks = append(toks, Token{token.IDENT, keys[i]}) + } else { + toks = append(toks, Token{token.STRING, fmt.Sprintf("%q", keys[i])}) + } + + field, _ := v.Key(keys[i]) + toks = append(toks, Token{token.ASSIGN, ""}) + toks = append(toks, valueTokens(field)...) + toks = append(toks, Token{token.COMMA, ""}, Token{token.LITERAL, "\n"}) + } + toks = append(toks, Token{token.RCURLY, ""}) + + case value.TypeFunction: + toks = append(toks, Token{token.LITERAL, v.Describe()}) + + case value.TypeCapsule: + toks = append(toks, Token{token.LITERAL, v.Describe()}) + + default: + panic(fmt.Sprintf("river/token/builder: unrecognized value type %q", v.Type())) + } + + return toks +} diff --git a/syntax/token/file.go b/syntax/token/file.go new file mode 100644 index 0000000000..419dbaa57c --- /dev/null +++ b/syntax/token/file.go @@ -0,0 +1,142 @@ +package token + +import ( + "fmt" + "sort" + "strconv" +) + +// NoPos is the zero value for Pos. It has no file or line information +// associated with it, and NoPos.Valid is false. +var NoPos = Pos{} + +// Pos is a compact representation of a position within a file. It can be +// converted into a Position for a more convenient, but larger, representation. +type Pos struct { + file *File + off int +} + +// String returns the string form of the Pos (the offset). +func (p Pos) String() string { return strconv.Itoa(p.off) } + +// File returns the file used by the Pos. This will be nil for invalid +// positions. +func (p Pos) File() *File { return p.file } + +// Position converts the Pos into a Position. +func (p Pos) Position() Position { return p.file.PositionFor(p) } + +// Add creates a new Pos relative to p. +func (p Pos) Add(n int) Pos { + return Pos{ + file: p.file, + off: p.off + n, + } +} + +// Offset returns the byte offset associated with Pos. +func (p Pos) Offset() int { return p.off } + +// Valid reports whether the Pos is valid. +func (p Pos) Valid() bool { return p != NoPos } + +// Position holds full position information for a location within an individual +// file. +type Position struct { + Filename string // Filename (if any) + Offset int // Byte offset (starting at 0) + Line int // Line number (starting at 1) + Column int // Offset from start of line (starting at 1) +} + +// Valid reports whether the position is valid. Valid positions must have a +// Line value greater than 0. +func (pos *Position) Valid() bool { return pos.Line > 0 } + +// String returns a string in one of the following forms: +// +// file:line:column Valid position with file name +// file:line Valid position with file name but no column +// line:column Valid position with no file name +// line Valid position with no file name or column +// file Invalid position with file name +// - Invalid position with no file name +func (pos Position) String() string { + s := pos.Filename + + if pos.Valid() { + if s != "" { + s += ":" + } + s += fmt.Sprintf("%d", pos.Line) + if pos.Column != 0 { + s += fmt.Sprintf(":%d", pos.Column) + } + } + + if s == "" { + s = "-" + } + return s +} + +// File holds position information for a specific file. +type File struct { + filename string + lines []int // Byte offset of each line number (first element is always 0). +} + +// NewFile creates a new File for storing position information. +func NewFile(filename string) *File { + return &File{ + filename: filename, + lines: []int{0}, + } +} + +// Pos returns a Pos given a byte offset. Pos panics if off is < 0. +func (f *File) Pos(off int) Pos { + if off < 0 { + panic("Pos: illegal offset") + } + return Pos{file: f, off: off} +} + +// Name returns the name of the file. +func (f *File) Name() string { return f.filename } + +// AddLine tracks a new line from a byte offset. The line offset must be larger +// than the offset for the previous line, otherwise the line offset is ignored. +func (f *File) AddLine(offset int) { + lines := len(f.lines) + if f.lines[lines-1] < offset { + f.lines = append(f.lines, offset) + } +} + +// PositionFor returns a Position from an offset. +func (f *File) PositionFor(p Pos) Position { + if p == NoPos { + return Position{} + } + + // Search through our line offsets to find the line/column info. The else + // case should never happen here, but if it does, Position.Valid will return + // false. + var line, column int + if i := searchInts(f.lines, p.off); i >= 0 { + line, column = i+1, p.off-f.lines[i]+1 + } + + return Position{ + Filename: f.filename, + Offset: p.off, + Line: line, + Column: column, + } +} + +func searchInts(a []int, x int) int { + return sort.Search(len(a), func(i int) bool { return a[i] > x }) - 1 +} diff --git a/syntax/token/token.go b/syntax/token/token.go new file mode 100644 index 0000000000..74dabc9de2 --- /dev/null +++ b/syntax/token/token.go @@ -0,0 +1,174 @@ +// Package token defines the lexical elements of a River config and utilities +// surrounding their position. +package token + +// Token is an individual River lexical token. +type Token int + +// List of all lexical tokens and examples that represent them. +// +// LITERAL is used by token/builder to represent literal strings for writing +// tokens, but never used for reading (so scanner never returns a +// token.LITERAL). +const ( + ILLEGAL Token = iota // Invalid token. + LITERAL // Literal text. + EOF // End-of-file. + COMMENT // // Hello, world! + + literalBeg + IDENT // foobar + NUMBER // 1234 + FLOAT // 1234.0 + STRING // "foobar" + literalEnd + + keywordBeg + BOOL // true + NULL // null + keywordEnd + + operatorBeg + OR // || + AND // && + NOT // ! + + ASSIGN // = + + EQ // == + NEQ // != + LT // < + LTE // <= + GT // > + GTE // >= + + ADD // + + SUB // - + MUL // * + DIV // / + MOD // % + POW // ^ + + LCURLY // { + RCURLY // } + LPAREN // ( + RPAREN // ) + LBRACK // [ + RBRACK // ] + COMMA // , + DOT // . + operatorEnd + + TERMINATOR // \n +) + +var tokenNames = [...]string{ + ILLEGAL: "ILLEGAL", + LITERAL: "LITERAL", + EOF: "EOF", + COMMENT: "COMMENT", + + IDENT: "IDENT", + NUMBER: "NUMBER", + FLOAT: "FLOAT", + STRING: "STRING", + BOOL: "BOOL", + NULL: "NULL", + + OR: "||", + AND: "&&", + NOT: "!", + + ASSIGN: "=", + EQ: "==", + NEQ: "!=", + LT: "<", + LTE: "<=", + GT: ">", + GTE: ">=", + + ADD: "+", + SUB: "-", + MUL: "*", + DIV: "/", + MOD: "%", + POW: "^", + + LCURLY: "{", + RCURLY: "}", + LPAREN: "(", + RPAREN: ")", + LBRACK: "[", + RBRACK: "]", + COMMA: ",", + DOT: ".", + + TERMINATOR: "TERMINATOR", +} + +// Lookup maps a string to its keyword token or IDENT if it's not a keyword. +func Lookup(ident string) Token { + switch ident { + case "true", "false": + return BOOL + case "null": + return NULL + default: + return IDENT + } +} + +// String returns the string representation corresponding to the token. +func (t Token) String() string { + if int(t) >= len(tokenNames) { + return "ILLEGAL" + } + + name := tokenNames[t] + if name == "" { + return "ILLEGAL" + } + return name +} + +// GoString returns the %#v format of t. +func (t Token) GoString() string { return t.String() } + +// IsKeyword returns true if the token corresponds to a keyword. +func (t Token) IsKeyword() bool { return t > keywordBeg && t < keywordEnd } + +// IsLiteral returns true if the token corresponds to a literal token or +// identifier. +func (t Token) IsLiteral() bool { return t > literalBeg && t < literalEnd } + +// IsOperator returns true if the token corresponds to an operator or +// delimiter. +func (t Token) IsOperator() bool { return t > operatorBeg && t < operatorEnd } + +// BinaryPrecedence returns the operator precedence of the binary operator t. +// If t is not a binary operator, the result is LowestPrecedence. +func (t Token) BinaryPrecedence() int { + switch t { + case OR: + return 1 + case AND: + return 2 + case EQ, NEQ, LT, LTE, GT, GTE: + return 3 + case ADD, SUB: + return 4 + case MUL, DIV, MOD: + return 5 + case POW: + return 6 + } + + return LowestPrecedence +} + +// Levels of precedence for operator tokens. +const ( + LowestPrecedence = 0 // non-operators + UnaryPrecedence = 7 + HighestPrecedence = 8 +) diff --git a/syntax/types.go b/syntax/types.go new file mode 100644 index 0000000000..b0123010b8 --- /dev/null +++ b/syntax/types.go @@ -0,0 +1,97 @@ +package river + +import "github.com/grafana/river/internal/value" + +// Our types in this file are re-implementations of interfaces from +// value.Capsule. They are *not* defined as type aliases, since pkg.go.dev +// would show the type alias instead of the contents of that type (which IMO is +// a frustrating user experience). +// +// The types below must be kept in sync with the internal package, and the +// checks below ensure they're compatible. +var ( + _ value.Defaulter = (Defaulter)(nil) + _ value.Unmarshaler = (Unmarshaler)(nil) + _ value.Validator = (Validator)(nil) + _ value.Capsule = (Capsule)(nil) + _ value.ConvertibleFromCapsule = (ConvertibleFromCapsule)(nil) + _ value.ConvertibleIntoCapsule = (ConvertibleIntoCapsule)(nil) +) + +// The Unmarshaler interface allows a type to hook into River decoding and +// decode into another type or provide pre-decoding logic. +type Unmarshaler interface { + // UnmarshalRiver is invoked when decoding a River value into a Go value. f + // should be called with a pointer to a value to decode into. UnmarshalRiver + // will not be called on types which are squashed into the parent struct + // using `river:",squash"`. + UnmarshalRiver(f func(v interface{}) error) error +} + +// The Defaulter interface allows a type to implement default functionality +// in River evaluation. +type Defaulter interface { + // SetToDefault is called when evaluating a block or body to set the value + // to its defaults. SetToDefault will not be called on types which are + // squashed into the parent struct using `river:",squash"`. + SetToDefault() +} + +// The Validator interface allows a type to implement validation functionality +// in River evaluation. +type Validator interface { + // Validate is called when evaluating a block or body to enforce the + // value is valid. Validate will not be called on types which are + // squashed into the parent struct using `river:",squash"`. + Validate() error +} + +// Capsule is an interface marker which tells River that a type should always +// be treated as a "capsule type" instead of the default type River would +// assign. +// +// Capsule types are useful for passing around arbitrary Go values in River +// expressions and for declaring new synthetic types with custom conversion +// rules. +// +// By default, only two capsule values of the same underlying Go type are +// compatible. Types which implement ConvertibleFromCapsule or +// ConvertibleToCapsule can provide custom logic for conversions from and to +// other types. +type Capsule interface { + // RiverCapsule marks the type as a Capsule. RiverCapsule is never invoked by + // River. + RiverCapsule() +} + +// ErrNoConversion is returned by implementations of ConvertibleFromCapsule and +// ConvertibleToCapsule when a conversion with a specific type is unavailable. +// +// Returning this error causes River to fall back to default conversion rules. +var ErrNoConversion = value.ErrNoConversion + +// ConvertibleFromCapsule is a Capsule which supports custom conversion from +// any Go type which is not the same as the capsule type. +type ConvertibleFromCapsule interface { + Capsule + + // ConvertFrom updates the ConvertibleFromCapsule value based on the value of + // src. src may be any Go value, not just other capsules. + // + // ConvertFrom should return ErrNoConversion if no conversion is available + // from src. Other errors are treated as a River decoding error. + ConvertFrom(src interface{}) error +} + +// ConvertibleIntoCapsule is a Capsule which supports custom conversion into +// any Go type which is not the same as the capsule type. +type ConvertibleIntoCapsule interface { + Capsule + + // ConvertInto should convert its value and store it into dst. dst will be a + // pointer to a Go value of any type. + // + // ConvertInto should return ErrNoConversion if no conversion into dst is + // available. Other errors are treated as a River decoding error. + ConvertInto(dst interface{}) error +} diff --git a/syntax/vm/constant.go b/syntax/vm/constant.go new file mode 100644 index 0000000000..d2e54c717d --- /dev/null +++ b/syntax/vm/constant.go @@ -0,0 +1,64 @@ +package vm + +import ( + "fmt" + "strconv" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/token" +) + +func valueFromLiteral(lit string, tok token.Token) (value.Value, error) { + // NOTE(rfratto): this function should never return an error, since the + // parser only produces valid tokens; it can only fail if a user hand-builds + // an AST with invalid literals. + + switch tok { + case token.NULL: + return value.Null, nil + + case token.NUMBER: + intVal, err1 := strconv.ParseInt(lit, 0, 64) + if err1 == nil { + return value.Int(intVal), nil + } + + uintVal, err2 := strconv.ParseUint(lit, 0, 64) + if err2 == nil { + return value.Uint(uintVal), nil + } + + floatVal, err3 := strconv.ParseFloat(lit, 64) + if err3 == nil { + return value.Float(floatVal), nil + } + + return value.Null, err3 + + case token.FLOAT: + v, err := strconv.ParseFloat(lit, 64) + if err != nil { + return value.Null, err + } + return value.Float(v), nil + + case token.STRING: + v, err := strconv.Unquote(lit) + if err != nil { + return value.Null, err + } + return value.String(v), nil + + case token.BOOL: + switch lit { + case "true": + return value.Bool(true), nil + case "false": + return value.Bool(false), nil + default: + return value.Null, fmt.Errorf("invalid boolean literal %q", lit) + } + default: + panic(fmt.Sprintf("%v is not a valid token", tok)) + } +} diff --git a/syntax/vm/error.go b/syntax/vm/error.go new file mode 100644 index 0000000000..c82a5b418e --- /dev/null +++ b/syntax/vm/error.go @@ -0,0 +1,106 @@ +package vm + +import ( + "fmt" + "strings" + + "github.com/grafana/river/ast" + "github.com/grafana/river/diag" + "github.com/grafana/river/internal/value" + "github.com/grafana/river/printer" + "github.com/grafana/river/token/builder" +) + +// makeDiagnostic tries to convert err into a diag.Diagnostic. err must be an +// error from the river/internal/value package, otherwise err will be returned +// unmodified. +func makeDiagnostic(err error, assoc map[value.Value]ast.Node) error { + var ( + node ast.Node + expr strings.Builder + message string + cause value.Value + + // Until we find a node, we're not a literal error. + literal = false + ) + + isValueError := value.WalkError(err, func(err error) { + var val value.Value + + switch ne := err.(type) { + case value.Error: + message = ne.Error() + val = ne.Value + case value.TypeError: + message = fmt.Sprintf("should be %s, got %s", ne.Expected, ne.Value.Type()) + val = ne.Value + case value.MissingKeyError: + message = fmt.Sprintf("does not have field named %q", ne.Missing) + val = ne.Value + case value.ElementError: + fmt.Fprintf(&expr, "[%d]", ne.Index) + val = ne.Value + case value.FieldError: + fmt.Fprintf(&expr, ".%s", ne.Field) + val = ne.Value + } + + cause = val + + if foundNode, ok := assoc[val]; ok { + // If we just found a direct node, we can reset the expression buffer so + // we don't unnecessarily print element and field accesses for we can see + // directly in the file. + if literal { + expr.Reset() + } + + node = foundNode + literal = true + } else { + literal = false + } + }) + if !isValueError { + return err + } + + if node != nil { + var nodeText strings.Builder + if err := printer.Fprint(&nodeText, node); err != nil { + // This should never panic; printer.Fprint only fails when given an + // unexpected type, which we never do here. + panic(err) + } + + // Merge the node text with the expression together (which will be relative + // accesses to the expression). + message = fmt.Sprintf("%s%s %s", nodeText.String(), expr.String(), message) + } else { + message = fmt.Sprintf("%s %s", expr.String(), message) + } + + // Render the underlying problematic value as a string. + var valueText string + if cause != value.Null { + be := builder.NewExpr() + be.SetValue(cause.Interface()) + valueText = string(be.Bytes()) + } + if literal { + // Hide the value if the node itself has the error we were worried about. + valueText = "" + } + + d := diag.Diagnostic{ + Severity: diag.SeverityLevelError, + Message: message, + Value: valueText, + } + if node != nil { + d.StartPos = ast.StartPos(node).Position() + d.EndPos = ast.EndPos(node).Position() + } + return d +} diff --git a/syntax/vm/op_binary.go b/syntax/vm/op_binary.go new file mode 100644 index 0000000000..e329f6fc49 --- /dev/null +++ b/syntax/vm/op_binary.go @@ -0,0 +1,360 @@ +package vm + +import ( + "fmt" + "math" + "reflect" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/rivertypes" + "github.com/grafana/river/token" +) + +func evalBinop(lhs value.Value, op token.Token, rhs value.Value) (value.Value, error) { + // Original parameters of lhs and rhs used for returning errors. + var ( + origLHS = lhs + origRHS = rhs + ) + + // Hack to allow OptionalSecrets to be used in binary operations. + // + // TODO(rfratto): be more flexible in the future with broader definitions of + // how capsules can be converted to other types for the purposes of doing a + // binop. + if lhs.Type() == value.TypeCapsule { + lhs = tryUnwrapOptionalSecret(lhs) + } + if rhs.Type() == value.TypeCapsule { + rhs = tryUnwrapOptionalSecret(rhs) + } + + // TODO(rfratto): evalBinop should check for underflows and overflows + + // We have special handling for EQ and NEQ since it's valid to attempt to + // compare values of any two types. + switch op { + case token.EQ: + return value.Bool(valuesEqual(lhs, rhs)), nil + case token.NEQ: + return value.Bool(!valuesEqual(lhs, rhs)), nil + } + + // The type of lhs and rhs must be acceptable for the binary operator. + if !acceptableBinopType(lhs, op) { + return value.Null, value.Error{ + Value: origLHS, + Inner: fmt.Errorf("should be one of %v for binop %s, got %s", binopAllowedTypes[op], op, lhs.Type()), + } + } else if !acceptableBinopType(rhs, op) { + return value.Null, value.Error{ + Value: origRHS, + Inner: fmt.Errorf("should be one of %v for binop %s, got %s", binopAllowedTypes[op], op, rhs.Type()), + } + } + + // At this point, regardless of the operator, lhs and rhs must have the same + // type. + if lhs.Type() != rhs.Type() { + return value.Null, value.TypeError{Value: rhs, Expected: lhs.Type()} + } + + switch op { + case token.OR: // bool || bool + return value.Bool(lhs.Bool() || rhs.Bool()), nil + case token.AND: // bool && Bool + return value.Bool(lhs.Bool() && rhs.Bool()), nil + + case token.ADD: // number + number, string + string + if lhs.Type() == value.TypeString { + return value.String(lhs.Text() + rhs.Text()), nil + } + + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Uint(lhsNum.Uint() + rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Int(lhsNum.Int() + rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Float(lhsNum.Float() + rhsNum.Float()), nil + } + + case token.SUB: // number - number + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Uint(lhsNum.Uint() - rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Int(lhsNum.Int() - rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Float(lhsNum.Float() - rhsNum.Float()), nil + } + + case token.MUL: // number * number + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Uint(lhsNum.Uint() * rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Int(lhsNum.Int() * rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Float(lhsNum.Float() * rhsNum.Float()), nil + } + + case token.DIV: // number / number + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Uint(lhsNum.Uint() / rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Int(lhsNum.Int() / rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Float(lhsNum.Float() / rhsNum.Float()), nil + } + + case token.MOD: // number % number + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Uint(lhsNum.Uint() % rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Int(lhsNum.Int() % rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Float(math.Mod(lhsNum.Float(), rhsNum.Float())), nil + } + + case token.POW: // number ^ number + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Uint(intPow(lhsNum.Uint(), rhsNum.Uint())), nil + case value.NumberKindInt: + return value.Int(intPow(lhsNum.Int(), rhsNum.Int())), nil + case value.NumberKindFloat: + return value.Float(math.Pow(lhsNum.Float(), rhsNum.Float())), nil + } + + case token.LT: // number < number, string < string + // Check string first. + if lhs.Type() == value.TypeString { + return value.Bool(lhs.Text() < rhs.Text()), nil + } + + // Not a string; must be a number. + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Bool(lhsNum.Uint() < rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Bool(lhsNum.Int() < rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Bool(lhsNum.Float() < rhsNum.Float()), nil + } + + case token.GT: // number > number, string > string + // Check string first. + if lhs.Type() == value.TypeString { + return value.Bool(lhs.Text() > rhs.Text()), nil + } + + // Not a string; must be a number. + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Bool(lhsNum.Uint() > rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Bool(lhsNum.Int() > rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Bool(lhsNum.Float() > rhsNum.Float()), nil + } + + case token.LTE: // number <= number, string <= string + // Check string first. + if lhs.Type() == value.TypeString { + return value.Bool(lhs.Text() <= rhs.Text()), nil + } + + // Not a string; must be a number. + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Bool(lhsNum.Uint() <= rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Bool(lhsNum.Int() <= rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Bool(lhsNum.Float() <= rhsNum.Float()), nil + } + + case token.GTE: // number >= number, string >= string + // Check string first. + if lhs.Type() == value.TypeString { + return value.Bool(lhs.Text() >= rhs.Text()), nil + } + + // Not a string; must be a number. + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Bool(lhsNum.Uint() >= rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Bool(lhsNum.Int() >= rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Bool(lhsNum.Float() >= rhsNum.Float()), nil + } + } + + panic("river/vm: unreachable") +} + +// tryUnwrapOptionalSecret accepts a value and, if it is a +// rivertypes.OptionalSecret where IsSecret is false, returns a string value +// instead. +// +// If val is not a rivertypes.OptionalSecret or IsSecret is true, +// tryUnwrapOptionalSecret returns the input value unchanged. +func tryUnwrapOptionalSecret(val value.Value) value.Value { + optSecret, ok := val.Interface().(rivertypes.OptionalSecret) + if !ok || optSecret.IsSecret { + return val + } + + return value.String(optSecret.Value) +} + +// valuesEqual returns true if two River Values are equal. +func valuesEqual(lhs value.Value, rhs value.Value) bool { + if lhs.Type() != rhs.Type() { + // Two values with different types are never equal. + return false + } + + switch lhs.Type() { + case value.TypeNull: + // Nothing to compare here: both lhs and rhs have the null type, + // so they're equal. + return true + + case value.TypeNumber: + // Two numbers are equal if they have equal values. However, we have to + // determine what comparison we want to do and upcast the values to a + // different Go type as needed (so that 3 == 3.0 is true). + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return lhsNum.Uint() == rhsNum.Uint() + case value.NumberKindInt: + return lhsNum.Int() == rhsNum.Int() + case value.NumberKindFloat: + return lhsNum.Float() == rhsNum.Float() + } + + case value.TypeString: + return lhs.Text() == rhs.Text() + + case value.TypeBool: + return lhs.Bool() == rhs.Bool() + + case value.TypeArray: + // Two arrays are equal if they have equal elements. + if lhs.Len() != rhs.Len() { + return false + } + for i := 0; i < lhs.Len(); i++ { + if !valuesEqual(lhs.Index(i), rhs.Index(i)) { + return false + } + } + return true + + case value.TypeObject: + // Two objects are equal if they have equal elements. + if lhs.Len() != rhs.Len() { + return false + } + for _, key := range lhs.Keys() { + lhsElement, _ := lhs.Key(key) + rhsElement, inRHS := rhs.Key(key) + if !inRHS { + return false + } + if !valuesEqual(lhsElement, rhsElement) { + return false + } + } + return true + + case value.TypeFunction: + // Two functions are never equal. We can't compare functions in Go, so + // there's no way to compare them in River right now. + return false + + case value.TypeCapsule: + // Two capsules are only equal if the underlying values are deeply equal. + return reflect.DeepEqual(lhs.Interface(), rhs.Interface()) + } + + panic("river/vm: unreachable") +} + +// binopAllowedTypes maps what type of values are permitted for a specific +// binary operation. +// +// token.EQ and token.NEQ are not included as they're handled separately from +// other binary ops. +var binopAllowedTypes = map[token.Token][]value.Type{ + token.OR: {value.TypeBool}, + token.AND: {value.TypeBool}, + + token.ADD: {value.TypeNumber, value.TypeString}, + token.SUB: {value.TypeNumber}, + token.MUL: {value.TypeNumber}, + token.DIV: {value.TypeNumber}, + token.MOD: {value.TypeNumber}, + token.POW: {value.TypeNumber}, + + token.LT: {value.TypeNumber, value.TypeString}, + token.GT: {value.TypeNumber, value.TypeString}, + token.LTE: {value.TypeNumber, value.TypeString}, + token.GTE: {value.TypeNumber, value.TypeString}, +} + +func acceptableBinopType(val value.Value, op token.Token) bool { + allowed, ok := binopAllowedTypes[op] + if !ok { + panic("river/vm: unexpected binop type") + } + + actualType := val.Type() + for _, allowType := range allowed { + if allowType == actualType { + return true + } + } + return false +} + +func fitNumberKinds(a, b value.NumberKind) value.NumberKind { + aPrec, bPrec := numberKindPrec[a], numberKindPrec[b] + if aPrec > bPrec { + return a + } + return b +} + +var numberKindPrec = map[value.NumberKind]int{ + value.NumberKindUint: 0, + value.NumberKindInt: 1, + value.NumberKindFloat: 2, +} + +func intPow[Number int64 | uint64](n, m Number) Number { + if m == 0 { + return 1 + } + result := n + for i := Number(2); i <= m; i++ { + result *= n + } + return result +} diff --git a/syntax/vm/op_binary_test.go b/syntax/vm/op_binary_test.go new file mode 100644 index 0000000000..45367777e4 --- /dev/null +++ b/syntax/vm/op_binary_test.go @@ -0,0 +1,94 @@ +package vm_test + +import ( + "reflect" + "testing" + + "github.com/grafana/river/parser" + "github.com/grafana/river/rivertypes" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +func TestVM_OptionalSecret_Conversion(t *testing.T) { + scope := &vm.Scope{ + Variables: map[string]any{ + "string_val": "hello", + "non_secret_val": rivertypes.OptionalSecret{IsSecret: false, Value: "world"}, + "secret_val": rivertypes.OptionalSecret{IsSecret: true, Value: "secret"}, + }, + } + + tt := []struct { + name string + input string + expect interface{} + expectError string + }{ + { + name: "string + capsule", + input: `string_val + non_secret_val`, + expect: string("helloworld"), + }, + { + name: "capsule + string", + input: `non_secret_val + string_val`, + expect: string("worldhello"), + }, + { + name: "string == capsule", + input: `"world" == non_secret_val`, + expect: bool(true), + }, + { + name: "capsule == string", + input: `non_secret_val == "world"`, + expect: bool(true), + }, + { + name: "capsule (secret) == capsule (secret)", + input: `secret_val == secret_val`, + expect: bool(true), + }, + { + name: "capsule (non secret) == capsule (non secret)", + input: `non_secret_val == non_secret_val`, + expect: bool(true), + }, + { + name: "capsule (non secret) == capsule (secret)", + input: `non_secret_val == secret_val`, + expect: bool(false), + }, + { + name: "secret + string", + input: `secret_val + string_val`, + expectError: "secret_val should be one of [number string] for binop +", + }, + { + name: "string + secret", + input: `string_val + secret_val`, + expectError: "secret_val should be one of [number string] for binop +", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + expectTy := reflect.TypeOf(tc.expect) + if expectTy == nil { + expectTy = reflect.TypeOf((*any)(nil)).Elem() + } + rv := reflect.New(expectTy) + + if err := vm.New(expr).Evaluate(scope, rv.Interface()); tc.expectError == "" { + require.NoError(t, err) + require.Equal(t, tc.expect, rv.Elem().Interface()) + } else { + require.ErrorContains(t, err, tc.expectError) + } + }) + } +} diff --git a/syntax/vm/op_unary.go b/syntax/vm/op_unary.go new file mode 100644 index 0000000000..bc116d58bc --- /dev/null +++ b/syntax/vm/op_unary.go @@ -0,0 +1,33 @@ +package vm + +import ( + "github.com/grafana/river/internal/value" + "github.com/grafana/river/token" +) + +func evalUnaryOp(op token.Token, val value.Value) (value.Value, error) { + switch op { + case token.NOT: + if val.Type() != value.TypeBool { + return value.Null, value.TypeError{Value: val, Expected: value.TypeBool} + } + return value.Bool(!val.Bool()), nil + + case token.SUB: + if val.Type() != value.TypeNumber { + return value.Null, value.TypeError{Value: val, Expected: value.TypeNumber} + } + + valNum := val.Number() + switch valNum.Kind() { + case value.NumberKindInt, value.NumberKindUint: + // It doesn't make much sense to invert a uint, so we always cast to an + // int and return an int. + return value.Int(-valNum.Int()), nil + case value.NumberKindFloat: + return value.Float(-valNum.Float()), nil + } + } + + panic("river/vm: unreachable") +} diff --git a/syntax/vm/struct_decoder.go b/syntax/vm/struct_decoder.go new file mode 100644 index 0000000000..99a0a2358a --- /dev/null +++ b/syntax/vm/struct_decoder.go @@ -0,0 +1,323 @@ +package vm + +import ( + "fmt" + "reflect" + "strings" + + "github.com/grafana/river/ast" + "github.com/grafana/river/diag" + "github.com/grafana/river/internal/reflectutil" + "github.com/grafana/river/internal/rivertags" + "github.com/grafana/river/internal/value" +) + +// structDecoder decodes a series of AST statements into a Go value. +type structDecoder struct { + VM *Evaluator + Scope *Scope + Assoc map[value.Value]ast.Node + TagInfo *tagInfo +} + +// Decode decodes the list of statements into the struct value specified by rv. +func (st *structDecoder) Decode(stmts ast.Body, rv reflect.Value) error { + // TODO(rfratto): potentially loosen this restriction and allow decoding into + // an interface{} or map[string]interface{}. + if rv.Kind() != reflect.Struct { + panic(fmt.Sprintf("river/vm: structDecoder expects struct, got %s", rv.Kind())) + } + + state := decodeOptions{ + Tags: st.TagInfo.TagLookup, + EnumBlocks: st.TagInfo.EnumLookup, + SeenAttrs: make(map[string]struct{}), + SeenBlocks: make(map[string]struct{}), + SeenEnums: make(map[string]struct{}), + + BlockCount: make(map[string]int), + BlockIndex: make(map[*ast.BlockStmt]int), + + EnumCount: make(map[string]int), + EnumIndex: make(map[*ast.BlockStmt]int), + } + + // Iterate over the set of blocks to populate block count and block index. + // Block index is its index in the set of blocks with the *same name*. + // + // If the block belongs to an enum, we populate enum count and enum index + // instead. The enum index is the index on the set of blocks for the *same + // enum*. + for _, stmt := range stmts { + switch stmt := stmt.(type) { + case *ast.BlockStmt: + fullName := strings.Join(stmt.Name, ".") + + if enumTf, isEnum := st.TagInfo.EnumLookup[fullName]; isEnum { + enumName := strings.Join(enumTf.EnumField.Name, ".") + state.EnumIndex[stmt] = state.EnumCount[enumName] + state.EnumCount[enumName]++ + } else { + state.BlockIndex[stmt] = state.BlockCount[fullName] + state.BlockCount[fullName]++ + } + } + } + + for _, stmt := range stmts { + switch stmt := stmt.(type) { + case *ast.AttributeStmt: + // TODO(rfratto): append to list of diagnostics instead of aborting early. + if err := st.decodeAttr(stmt, rv, &state); err != nil { + return err + } + + case *ast.BlockStmt: + // TODO(rfratto): append to list of diagnostics instead of aborting early. + if err := st.decodeBlock(stmt, rv, &state); err != nil { + return err + } + + default: + panic(fmt.Sprintf("river/vm: unrecognized node type %T", stmt)) + } + } + + for _, tf := range st.TagInfo.Tags { + // Ignore any optional tags. + if tf.IsOptional() { + continue + } + + fullName := strings.Join(tf.Name, ".") + + switch { + case tf.IsAttr(): + if _, consumed := state.SeenAttrs[fullName]; !consumed { + // TODO(rfratto): change to diagnostics. + return fmt.Errorf("missing required attribute %q", fullName) + } + + case tf.IsBlock(): + if _, consumed := state.SeenBlocks[fullName]; !consumed { + // TODO(rfratto): change to diagnostics. + return fmt.Errorf("missing required block %q", fullName) + } + } + } + + return nil +} + +type decodeOptions struct { + Tags map[string]rivertags.Field + EnumBlocks map[string]enumBlock + + SeenAttrs, SeenBlocks, SeenEnums map[string]struct{} + + // BlockCount and BlockIndex are used to determine: + // + // * How big a slice of blocks should be for a block of a given name (BlockCount) + // * Which element within that slice is a given block assigned to (BlockIndex) + // + // This is used for decoding a series of rule blocks for prometheus.relabel, + // where 5 rules would have a "rule" key in BlockCount with a value of 5, and + // where the first block would be index 0, the second block would be index 1, + // and so on. + // + // The index in BlockIndex is relative to a block name; the first block named + // "hello.world" and the first block named "fizz.buzz" both have index 0. + + BlockCount map[string]int // Number of times a block by full name is seen. + BlockIndex map[*ast.BlockStmt]int // Index of a block within a set of blocks with the same name. + + // EnumCount and EnumIndex are similar to BlockCount/BlockIndex, but instead + // reference the number of blocks assigned to the same enum (EnumCount) and + // the index of a block within that enum slice (EnumIndex). + + EnumCount map[string]int // Number of times an enum group is seen by enum name. + EnumIndex map[*ast.BlockStmt]int // Index of a block within a set of enum blocks of the same enum. +} + +func (st *structDecoder) decodeAttr(attr *ast.AttributeStmt, rv reflect.Value, state *decodeOptions) error { + fullName := attr.Name.Name + if _, seen := state.SeenAttrs[fullName]; seen { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(attr).Position(), + EndPos: ast.EndPos(attr).Position(), + Message: fmt.Sprintf("attribute %q may only be provided once", fullName), + }} + } + state.SeenAttrs[fullName] = struct{}{} + + tf, ok := state.Tags[fullName] + if !ok { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(attr).Position(), + EndPos: ast.EndPos(attr).Position(), + Message: fmt.Sprintf("unrecognized attribute name %q", fullName), + }} + } else if tf.IsBlock() { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(attr).Position(), + EndPos: ast.EndPos(attr).Position(), + Message: fmt.Sprintf("%q must be a block, but is used as an attribute", fullName), + }} + } + + // Decode the attribute. + val, err := st.VM.evaluateExpr(st.Scope, st.Assoc, attr.Value) + if err != nil { + // TODO(rfratto): get error as diagnostics. + return err + } + + // We're reconverting our reflect.Value back into an interface{}, so we + // need to also turn it back into a pointer for decoding. + field := reflectutil.GetOrAlloc(rv, tf) + if err := value.Decode(val, field.Addr().Interface()); err != nil { + // TODO(rfratto): get error as diagnostics. + return err + } + + return nil +} + +func (st *structDecoder) decodeBlock(block *ast.BlockStmt, rv reflect.Value, state *decodeOptions) error { + fullName := block.GetBlockName() + + if _, isEnum := state.EnumBlocks[fullName]; isEnum { + return st.decodeEnumBlock(fullName, block, rv, state) + } + return st.decodeNormalBlock(fullName, block, rv, state) +} + +// decodeNormalBlock decodes a standard (non-enum) block. +func (st *structDecoder) decodeNormalBlock(fullName string, block *ast.BlockStmt, rv reflect.Value, state *decodeOptions) error { + tf, isBlock := state.Tags[fullName] + if !isBlock { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(block).Position(), + EndPos: ast.EndPos(block).Position(), + Message: fmt.Sprintf("unrecognized block name %q", fullName), + }} + } else if tf.IsAttr() { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(block).Position(), + EndPos: ast.EndPos(block).Position(), + Message: fmt.Sprintf("%q must be an attribute, but is used as a block", fullName), + }} + } + + field := reflectutil.GetOrAlloc(rv, tf) + decodeField := prepareDecodeValue(field) + + switch decodeField.Kind() { + case reflect.Slice: + // If this is the first time we've seen the block, reset its length to + // zero. + if _, seen := state.SeenBlocks[fullName]; !seen { + count := state.BlockCount[fullName] + decodeField.Set(reflect.MakeSlice(decodeField.Type(), count, count)) + } + + blockIndex, ok := state.BlockIndex[block] + if !ok { + panic("river/vm: block not found in index lookup table") + } + decodeElement := prepareDecodeValue(decodeField.Index(blockIndex)) + err := st.VM.evaluateBlockOrBody(st.Scope, st.Assoc, block, decodeElement) + if err != nil { + // TODO(rfratto): get error as diagnostics. + return err + } + + case reflect.Array: + if decodeField.Len() != state.BlockCount[fullName] { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(block).Position(), + EndPos: ast.EndPos(block).Position(), + Message: fmt.Sprintf( + "block %q must be specified exactly %d times, but was specified %d times", + fullName, + decodeField.Len(), + state.BlockCount[fullName], + ), + }} + } + + blockIndex, ok := state.BlockIndex[block] + if !ok { + panic("river/vm: block not found in index lookup table") + } + decodeElement := prepareDecodeValue(decodeField.Index(blockIndex)) + err := st.VM.evaluateBlockOrBody(st.Scope, st.Assoc, block, decodeElement) + if err != nil { + // TODO(rfratto): get error as diagnostics. + return err + } + + default: + if state.BlockCount[fullName] > 1 { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(block).Position(), + EndPos: ast.EndPos(block).Position(), + Message: fmt.Sprintf("block %q may only be specified once", fullName), + }} + } + + err := st.VM.evaluateBlockOrBody(st.Scope, st.Assoc, block, decodeField) + if err != nil { + // TODO(rfratto): get error as diagnostics. + return err + } + } + + state.SeenBlocks[fullName] = struct{}{} + return nil +} + +func (st *structDecoder) decodeEnumBlock(fullName string, block *ast.BlockStmt, rv reflect.Value, state *decodeOptions) error { + tf, ok := state.EnumBlocks[fullName] + if !ok { + // decodeEnumBlock should only ever be called from decodeBlock, so this + // should never happen. + panic("decodeEnumBlock called with a non-enum block") + } + + enumName := strings.Join(tf.EnumField.Name, ".") + enumField := reflectutil.GetOrAlloc(rv, tf.EnumField) + decodeField := prepareDecodeValue(enumField) + + if decodeField.Kind() != reflect.Slice { + panic("river/vm: enum field must be a slice kind, got " + decodeField.Kind().String()) + } + + // If this is the first time we've seen the enum, reset its length to zero. + if _, seen := state.SeenEnums[enumName]; !seen { + count := state.EnumCount[enumName] + decodeField.Set(reflect.MakeSlice(decodeField.Type(), count, count)) + } + state.SeenEnums[enumName] = struct{}{} + + // Prepare the enum element to decode into. + enumIndex, ok := state.EnumIndex[block] + if !ok { + panic("river/vm: enum block not found in index lookup table") + } + enumElement := prepareDecodeValue(decodeField.Index(enumIndex)) + + // Prepare the block field to decode into. + enumBlock := reflectutil.GetOrAlloc(enumElement, tf.BlockField) + decodeBlock := prepareDecodeValue(enumBlock) + + // Decode into the block field. + return st.VM.evaluateBlockOrBody(st.Scope, st.Assoc, block, decodeBlock) +} diff --git a/syntax/vm/tag_cache.go b/syntax/vm/tag_cache.go new file mode 100644 index 0000000000..f9c1b69c56 --- /dev/null +++ b/syntax/vm/tag_cache.go @@ -0,0 +1,80 @@ +package vm + +import ( + "reflect" + "strings" + "sync" + + "github.com/grafana/river/internal/rivertags" +) + +// tagsCache caches the river tags for a struct type. This is never cleared, +// but since most structs will be statically created throughout the lifetime +// of the process, this will consume a negligible amount of memory. +var tagsCache sync.Map + +func getCachedTagInfo(t reflect.Type) *tagInfo { + if t.Kind() != reflect.Struct { + panic("getCachedTagInfo called with non-struct type") + } + + if entry, ok := tagsCache.Load(t); ok { + return entry.(*tagInfo) + } + + tfs := rivertags.Get(t) + ti := &tagInfo{ + Tags: tfs, + TagLookup: make(map[string]rivertags.Field, len(tfs)), + EnumLookup: make(map[string]enumBlock), // The length is not known ahead of time + } + + for _, tf := range tfs { + switch { + case tf.IsAttr(), tf.IsBlock(): + fullName := strings.Join(tf.Name, ".") + ti.TagLookup[fullName] = tf + + case tf.IsEnum(): + fullName := strings.Join(tf.Name, ".") + + // Find all the blocks that match to the enum, and inject them into the + // EnumLookup table. + enumFieldType := t.FieldByIndex(tf.Index).Type + enumBlocksInfo := getCachedTagInfo(deferenceType(enumFieldType.Elem())) + for _, blockField := range enumBlocksInfo.TagLookup { + // The full name of the enum block is the name of the enum plus the + // name of the block, separated by '.' + enumBlockName := fullName + "." + strings.Join(blockField.Name, ".") + ti.EnumLookup[enumBlockName] = enumBlock{ + EnumField: tf, + BlockField: blockField, + } + } + } + } + + tagsCache.Store(t, ti) + return ti +} + +func deferenceType(ty reflect.Type) reflect.Type { + for ty.Kind() == reflect.Pointer { + ty = ty.Elem() + } + return ty +} + +type tagInfo struct { + Tags []rivertags.Field + TagLookup map[string]rivertags.Field + + // EnumLookup maps enum blocks to the enum field. For example, an enum block + // called "foo.foo" and "foo.bar" will both map to the "foo" enum field. + EnumLookup map[string]enumBlock +} + +type enumBlock struct { + EnumField rivertags.Field // Field in the parent struct of the enum slice + BlockField rivertags.Field // Field in the enum struct for the enum block +} diff --git a/syntax/vm/vm.go b/syntax/vm/vm.go new file mode 100644 index 0000000000..a9c6481593 --- /dev/null +++ b/syntax/vm/vm.go @@ -0,0 +1,486 @@ +// Package vm provides a River expression evaluator. +package vm + +import ( + "fmt" + "reflect" + "strings" + + "github.com/grafana/river/ast" + "github.com/grafana/river/diag" + "github.com/grafana/river/internal/reflectutil" + "github.com/grafana/river/internal/rivertags" + "github.com/grafana/river/internal/stdlib" + "github.com/grafana/river/internal/value" +) + +// Evaluator evaluates River AST nodes into Go values. Each Evaluator is bound +// to a single AST node. To evaluate the node, call Evaluate. +type Evaluator struct { + // node for the AST. + // + // Each Evaluator is bound to a single node to allow for future performance + // optimizations, allowing for precomputing and storing the result of + // anything that is constant. + node ast.Node +} + +// New creates a new Evaluator for the given AST node. The given node must be +// either an *ast.File, *ast.BlockStmt, ast.Body, or assignable to an ast.Expr. +func New(node ast.Node) *Evaluator { + return &Evaluator{node: node} +} + +// Evaluate evaluates the Evaluator's node into a River value and decodes that +// value into the Go value v. +// +// Each call to Evaluate may provide a different scope with new values for +// available variables. If a variable used by the Evaluator's node isn't +// defined in scope or any of the parent scopes, Evaluate will return an error. +func (vm *Evaluator) Evaluate(scope *Scope, v interface{}) (err error) { + // Track a map that allows us to associate values with ast.Nodes so we can + // return decorated error messages. + assoc := make(map[value.Value]ast.Node) + + defer func() { + if err != nil { + // Decorate the error on return. + err = makeDiagnostic(err, assoc) + } + }() + + switch node := vm.node.(type) { + case *ast.BlockStmt, ast.Body: + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer { + panic(fmt.Sprintf("river/vm: expected pointer, got %s", rv.Kind())) + } + return vm.evaluateBlockOrBody(scope, assoc, node, rv) + case *ast.File: + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer { + panic(fmt.Sprintf("river/vm: expected pointer, got %s", rv.Kind())) + } + return vm.evaluateBlockOrBody(scope, assoc, node.Body, rv) + default: + expr, ok := node.(ast.Expr) + if !ok { + panic(fmt.Sprintf("river/vm: unexpected value type %T", node)) + } + val, err := vm.evaluateExpr(scope, assoc, expr) + if err != nil { + return err + } + return value.Decode(val, v) + } +} + +func (vm *Evaluator) evaluateBlockOrBody(scope *Scope, assoc map[value.Value]ast.Node, node ast.Node, rv reflect.Value) error { + // Before decoding the block, we need to temporarily take the address of rv + // to handle the case of it implementing the unmarshaler interface. + if rv.CanAddr() { + rv = rv.Addr() + } + + if err, unmarshaled := vm.evaluateUnmarshalRiver(scope, assoc, node, rv); unmarshaled || err != nil { + return err + } + + if ru, ok := rv.Interface().(value.Defaulter); ok { + ru.SetToDefault() + } + + if err := vm.evaluateDecode(scope, assoc, node, rv); err != nil { + return err + } + + if ru, ok := rv.Interface().(value.Validator); ok { + if err := ru.Validate(); err != nil { + return err + } + } + + return nil +} + +func (vm *Evaluator) evaluateUnmarshalRiver(scope *Scope, assoc map[value.Value]ast.Node, node ast.Node, rv reflect.Value) (error, bool) { + if ru, ok := rv.Interface().(value.Unmarshaler); ok { + return ru.UnmarshalRiver(func(v interface{}) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer { + panic(fmt.Sprintf("river/vm: expected pointer, got %s", rv.Kind())) + } + return vm.evaluateBlockOrBody(scope, assoc, node, rv.Elem()) + }), true + } + + return nil, false +} + +func (vm *Evaluator) evaluateDecode(scope *Scope, assoc map[value.Value]ast.Node, node ast.Node, rv reflect.Value) error { + // TODO(rfratto): the errors returned by this function are missing context to + // be able to print line numbers. We need to return decorated error types. + + // Fully deference rv and allocate pointers as necessary. + for rv.Kind() == reflect.Pointer { + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + rv = rv.Elem() + } + + if rv.Kind() == reflect.Interface { + var anyMap map[string]interface{} + into := reflect.MakeMap(reflect.TypeOf(anyMap)) + if err := vm.evaluateMap(scope, assoc, node, into); err != nil { + return err + } + + rv.Set(into) + return nil + } else if rv.Kind() == reflect.Map { + return vm.evaluateMap(scope, assoc, node, rv) + } else if rv.Kind() != reflect.Struct { + panic(fmt.Sprintf("river/vm: can only evaluate blocks into structs, got %s", rv.Kind())) + } + + ti := getCachedTagInfo(rv.Type()) + + var stmts ast.Body + switch node := node.(type) { + case *ast.BlockStmt: + // Decode the block label first. + if err := vm.evaluateBlockLabel(node, ti.Tags, rv); err != nil { + return err + } + stmts = node.Body + case ast.Body: + stmts = node + default: + panic(fmt.Sprintf("river/vm: unrecognized node type %T", node)) + } + + sd := structDecoder{ + VM: vm, + Scope: scope, + Assoc: assoc, + TagInfo: ti, + } + return sd.Decode(stmts, rv) +} + +// evaluateMap evaluates a block or a body into a map. +func (vm *Evaluator) evaluateMap(scope *Scope, assoc map[value.Value]ast.Node, node ast.Node, rv reflect.Value) error { + var stmts ast.Body + + switch node := node.(type) { + case *ast.BlockStmt: + if node.Label != "" { + return diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: node.NamePos.Position(), + EndPos: node.LCurlyPos.Position(), + Message: fmt.Sprintf("block %q requires non-empty label", strings.Join(node.Name, ".")), + } + } + stmts = node.Body + case ast.Body: + stmts = node + default: + panic(fmt.Sprintf("river/vm: unrecognized node type %T", node)) + } + + if rv.IsNil() { + rv.Set(reflect.MakeMap(rv.Type())) + } + + for _, stmt := range stmts { + switch stmt := stmt.(type) { + case *ast.AttributeStmt: + val, err := vm.evaluateExpr(scope, assoc, stmt.Value) + if err != nil { + // TODO(rfratto): get error as diagnostics. + return err + } + + target := reflect.New(rv.Type().Elem()).Elem() + if err := value.Decode(val, target.Addr().Interface()); err != nil { + // TODO(rfratto): get error as diagnostics. + return err + } + rv.SetMapIndex(reflect.ValueOf(stmt.Name.Name), target) + + case *ast.BlockStmt: + // TODO(rfratto): potentially relax this restriction where nested blocks + // are permitted when decoding to a map. + return diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(stmt).Position(), + EndPos: ast.EndPos(stmt).Position(), + Message: "nested blocks not supported here", + } + + default: + panic(fmt.Sprintf("river/vm: unrecognized node type %T", stmt)) + } + } + + return nil +} + +func (vm *Evaluator) evaluateBlockLabel(node *ast.BlockStmt, tfs []rivertags.Field, rv reflect.Value) error { + var ( + labelField rivertags.Field + foundField bool + ) + for _, tf := range tfs { + if tf.Flags&rivertags.FlagLabel != 0 { + labelField = tf + foundField = true + break + } + } + + // Check for user errors first. + // + // We return parser.Error here to restrict the position of the error to just + // the name. We might be able to clean this up in the future by extending + // ValueError to have an explicit position. + switch { + case node.Label == "" && foundField: // No user label, but struct expects one + return diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: node.NamePos.Position(), + EndPos: node.LCurlyPos.Position(), + Message: fmt.Sprintf("block %q requires non-empty label", strings.Join(node.Name, ".")), + } + case node.Label != "" && !foundField: // User label, but struct doesn't expect one + return diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: node.NamePos.Position(), + EndPos: node.LCurlyPos.Position(), + Message: fmt.Sprintf("block %q does not support specifying labels", strings.Join(node.Name, ".")), + } + } + + if node.Label == "" { + // no-op: no labels to set. + return nil + } + + var ( + field = reflectutil.GetOrAlloc(rv, labelField) + fieldType = field.Type() + ) + if !reflect.TypeOf(node.Label).AssignableTo(fieldType) { + // The Label struct field needs to be a string. + panic(fmt.Sprintf("river/vm: cannot assign block label to non-string type %s", fieldType)) + } + field.Set(reflect.ValueOf(node.Label)) + return nil +} + +// prepareDecodeValue prepares v for decoding. Pointers will be fully +// dereferenced until finding a non-pointer value. nil pointers will be +// allocated. +func prepareDecodeValue(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Pointer { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} + +func (vm *Evaluator) evaluateExpr(scope *Scope, assoc map[value.Value]ast.Node, expr ast.Expr) (v value.Value, err error) { + defer func() { + if v != value.Null { + assoc[v] = expr + } + }() + + switch expr := expr.(type) { + case *ast.LiteralExpr: + return valueFromLiteral(expr.Value, expr.Kind) + + case *ast.BinaryExpr: + lhs, err := vm.evaluateExpr(scope, assoc, expr.Left) + if err != nil { + return value.Null, err + } + rhs, err := vm.evaluateExpr(scope, assoc, expr.Right) + if err != nil { + return value.Null, err + } + return evalBinop(lhs, expr.Kind, rhs) + + case *ast.ArrayExpr: + vals := make([]value.Value, len(expr.Elements)) + for i, element := range expr.Elements { + val, err := vm.evaluateExpr(scope, assoc, element) + if err != nil { + return value.Null, err + } + vals[i] = val + } + return value.Array(vals...), nil + + case *ast.ObjectExpr: + fields := make(map[string]value.Value, len(expr.Fields)) + for _, field := range expr.Fields { + val, err := vm.evaluateExpr(scope, assoc, field.Value) + if err != nil { + return value.Null, err + } + fields[field.Name.Name] = val + } + return value.Object(fields), nil + + case *ast.IdentifierExpr: + val, found := scope.Lookup(expr.Ident.Name) + if !found { + return value.Null, diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(expr).Position(), + EndPos: ast.EndPos(expr).Position(), + Message: fmt.Sprintf("identifier %q does not exist", expr.Ident.Name), + } + } + return value.Encode(val), nil + + case *ast.AccessExpr: + val, err := vm.evaluateExpr(scope, assoc, expr.Value) + if err != nil { + return value.Null, err + } + + switch val.Type() { + case value.TypeObject: + res, ok := val.Key(expr.Name.Name) + if !ok { + return value.Null, diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(expr.Name).Position(), + EndPos: ast.EndPos(expr.Name).Position(), + Message: fmt.Sprintf("field %q does not exist", expr.Name.Name), + } + } + return res, nil + default: + return value.Null, value.Error{ + Value: val, + Inner: fmt.Errorf("cannot access field %q on value of type %s", expr.Name.Name, val.Type()), + } + } + + case *ast.IndexExpr: + val, err := vm.evaluateExpr(scope, assoc, expr.Value) + if err != nil { + return value.Null, err + } + idx, err := vm.evaluateExpr(scope, assoc, expr.Index) + if err != nil { + return value.Null, err + } + + switch val.Type() { + case value.TypeArray: + // Arrays are indexed with a number. + if idx.Type() != value.TypeNumber { + return value.Null, value.TypeError{Value: idx, Expected: value.TypeNumber} + } + intIndex := int(idx.Int()) + + if intIndex < 0 || intIndex >= val.Len() { + return value.Null, value.Error{ + Value: idx, + Inner: fmt.Errorf("index %d is out of range of array with length %d", intIndex, val.Len()), + } + } + return val.Index(intIndex), nil + + case value.TypeObject: + // Objects are indexed with a string. + if idx.Type() != value.TypeString { + return value.Null, value.TypeError{Value: idx, Expected: value.TypeString} + } + + field, ok := val.Key(idx.Text()) + if !ok { + // If a key doesn't exist in an object accessed with [], return null. + return value.Null, nil + } + return field, nil + + default: + return value.Null, value.Error{ + Value: val, + Inner: fmt.Errorf("expected object or array, got %s", val.Type()), + } + } + + case *ast.ParenExpr: + return vm.evaluateExpr(scope, assoc, expr.Inner) + + case *ast.UnaryExpr: + val, err := vm.evaluateExpr(scope, assoc, expr.Value) + if err != nil { + return value.Null, err + } + return evalUnaryOp(expr.Kind, val) + + case *ast.CallExpr: + funcVal, err := vm.evaluateExpr(scope, assoc, expr.Value) + if err != nil { + return funcVal, err + } + if funcVal.Type() != value.TypeFunction { + return value.Null, value.TypeError{Value: funcVal, Expected: value.TypeFunction} + } + + args := make([]value.Value, len(expr.Args)) + for i := 0; i < len(expr.Args); i++ { + args[i], err = vm.evaluateExpr(scope, assoc, expr.Args[i]) + if err != nil { + return value.Null, err + } + } + return funcVal.Call(args...) + + default: + panic(fmt.Sprintf("river/vm: unexpected ast.Expr type %T", expr)) + } +} + +// A Scope exposes a set of variables available to use during evaluation. +type Scope struct { + // Parent optionally points to a parent Scope containing more variable. + // Variables defined in children scopes take precedence over variables of the + // same name found in parent scopes. + Parent *Scope + + // Variables holds the list of available variable names that can be used when + // evaluating a node. + // + // Values in the Variables map should be considered immutable after passed to + // Evaluate; maps and slices will be copied by reference for performance + // optimizations. + Variables map[string]interface{} +} + +// Lookup looks up a named identifier from the scope, all of the scope's +// parents, and the stdlib. +func (s *Scope) Lookup(name string) (interface{}, bool) { + // Traverse the scope first, then fall back to stdlib. + for s != nil { + if val, ok := s.Variables[name]; ok { + return val, true + } + s = s.Parent + } + if ident, ok := stdlib.Identifiers[name]; ok { + return ident, true + } + return nil, false +} diff --git a/syntax/vm/vm_benchmarks_test.go b/syntax/vm/vm_benchmarks_test.go new file mode 100644 index 0000000000..e5530ccb37 --- /dev/null +++ b/syntax/vm/vm_benchmarks_test.go @@ -0,0 +1,106 @@ +package vm_test + +import ( + "fmt" + "math" + "reflect" + "testing" + + "github.com/grafana/river/parser" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +func BenchmarkExprs(b *testing.B) { + // Shared scope across all tests below + scope := &vm.Scope{ + Variables: map[string]interface{}{ + "foobar": int(42), + }, + } + + tt := []struct { + name string + input string + expect interface{} + }{ + // Binops + {"or", `false || true`, bool(true)}, + {"and", `true && false`, bool(false)}, + {"math/eq", `3 == 5`, bool(false)}, + {"math/neq", `3 != 5`, bool(true)}, + {"math/lt", `3 < 5`, bool(true)}, + {"math/lte", `3 <= 5`, bool(true)}, + {"math/gt", `3 > 5`, bool(false)}, + {"math/gte", `3 >= 5`, bool(false)}, + {"math/add", `3 + 5`, int(8)}, + {"math/sub", `3 - 5`, int(-2)}, + {"math/mul", `3 * 5`, int(15)}, + {"math/div", `3 / 5`, int(0)}, + {"math/mod", `5 % 3`, int(2)}, + {"math/pow", `3 ^ 5`, int(243)}, + {"binop chain", `3 + 5 * 2`, int(13)}, // Chain multiple binops + + // Identifier + {"ident lookup", `foobar`, int(42)}, + + // Arrays + {"array", `[0, 1, 2]`, []int{0, 1, 2}}, + + // Objects + {"object to map", `{ a = 5, b = 10 }`, map[string]int{"a": 5, "b": 10}}, + { + name: "object to struct", + input: `{ + name = "John Doe", + age = 42, + }`, + expect: struct { + Name string `river:"name,attr"` + Age int `river:"age,attr"` + Country string `river:"country,attr,optional"` + }{ + Name: "John Doe", + Age: 42, + }, + }, + + // Access + {"access", `{ a = 15 }.a`, int(15)}, + {"nested access", `{ a = { b = 12 } }.a.b`, int(12)}, + + // Indexing + {"index", `[0, 1, 2][1]`, int(1)}, + {"nested index", `[[1,2,3]][0][2]`, int(3)}, + + // Paren + {"paren", `(15)`, int(15)}, + + // Unary + {"unary not", `!true`, bool(false)}, + {"unary neg", `-15`, int(-15)}, + {"unary float", `-15.0`, float64(-15.0)}, + {"unary int64", fmt.Sprintf("%v", math.MaxInt64), math.MaxInt64}, + {"unary uint64", fmt.Sprintf("%v", uint64(math.MaxInt64)+1), uint64(math.MaxInt64) + 1}, + // math.MaxUint64 + 1 = 18446744073709551616 + {"unary float64 from overflowing uint", "18446744073709551616", float64(18446744073709551616)}, + } + + for _, tc := range tt { + b.Run(tc.name, func(b *testing.B) { + b.StopTimer() + expr, err := parser.ParseExpression(tc.input) + require.NoError(b, err) + + eval := vm.New(expr) + b.StartTimer() + + expectType := reflect.TypeOf(tc.expect) + + for i := 0; i < b.N; i++ { + vPtr := reflect.New(expectType).Interface() + _ = eval.Evaluate(scope, vPtr) + } + }) + } +} diff --git a/syntax/vm/vm_block_test.go b/syntax/vm/vm_block_test.go new file mode 100644 index 0000000000..ebc2ff0e6b --- /dev/null +++ b/syntax/vm/vm_block_test.go @@ -0,0 +1,802 @@ +package vm_test + +import ( + "fmt" + "math" + "reflect" + "testing" + + "github.com/grafana/river/ast" + "github.com/grafana/river/parser" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +// This file contains tests for decoding blocks. + +func TestVM_File(t *testing.T) { + type block struct { + String string `river:"string,attr"` + Number int `river:"number,attr,optional"` + } + type file struct { + SettingA int `river:"setting_a,attr"` + SettingB int `river:"setting_b,attr,optional"` + + Block block `river:"some_block,block,optional"` + } + + input := ` + setting_a = 15 + + some_block { + string = "Hello, world!" + } + ` + + expect := file{ + SettingA: 15, + Block: block{ + String: "Hello, world!", + }, + } + + res, err := parser.ParseFile(t.Name(), []byte(input)) + require.NoError(t, err) + + eval := vm.New(res) + + var actual file + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) +} + +func TestVM_Block_Attributes(t *testing.T) { + t.Run("Decodes attributes", func(t *testing.T) { + type block struct { + Number int `river:"number,attr"` + String string `river:"string,attr"` + } + + input := `some_block { + number = 15 + string = "Hello, world!" + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 15, actual.Number) + require.Equal(t, "Hello, world!", actual.String) + }) + + t.Run("Fails if attribute used as block", func(t *testing.T) { + type block struct { + Number int `river:"number,attr"` + } + + input := `some_block { + number {} + }` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `2:4: "number" must be an attribute, but is used as a block`) + }) + + t.Run("Fails if required attributes are not present", func(t *testing.T) { + type block struct { + Number int `river:"number,attr"` + String string `river:"string,attr"` + } + + input := `some_block { + number = 15 + }` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `missing required attribute "string"`) + }) + + t.Run("Succeeds if optional attributes are not present", func(t *testing.T) { + type block struct { + Number int `river:"number,attr"` + String string `river:"string,attr,optional"` + } + + input := `some_block { + number = 15 + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 15, actual.Number) + require.Equal(t, "", actual.String) + }) + + t.Run("Fails if attribute is not defined in struct", func(t *testing.T) { + type block struct { + Number int `river:"number,attr"` + } + + input := `some_block { + number = 15 + invalid = "This attribute does not exist!" + }` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `3:4: unrecognized attribute name "invalid"`) + }) + + t.Run("Tests decoding into an interface", func(t *testing.T) { + type block struct { + Anything interface{} `river:"anything,attr"` + } + + tests := []struct { + testName string + val string + expectedValType reflect.Kind + }{ + {testName: "test_int_1", val: "15", expectedValType: reflect.Int}, + {testName: "test_int_2", val: "-15", expectedValType: reflect.Int}, + {testName: "test_int_3", val: fmt.Sprintf("%v", math.MaxInt64), expectedValType: reflect.Int}, + {testName: "test_int_4", val: fmt.Sprintf("%v", math.MinInt64), expectedValType: reflect.Int}, + {testName: "test_uint_1", val: fmt.Sprintf("%v", uint64(math.MaxInt64)+1), expectedValType: reflect.Uint64}, + {testName: "test_uint_2", val: fmt.Sprintf("%v", uint64(math.MaxUint64)), expectedValType: reflect.Uint64}, + {testName: "test_float_1", val: fmt.Sprintf("%v9", math.MinInt64), expectedValType: reflect.Float64}, + {testName: "test_float_2", val: fmt.Sprintf("%v9", uint64(math.MaxUint64)), expectedValType: reflect.Float64}, + {testName: "test_float_3", val: "16.0", expectedValType: reflect.Float64}, + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + input := fmt.Sprintf(`some_block { + anything = %s + }`, tt.val) + eval := vm.New(parseBlock(t, input)) + + var actual block + err := eval.Evaluate(nil, &actual) + require.NoError(t, err) + require.Equal(t, tt.expectedValType.String(), reflect.TypeOf(actual.Anything).Kind().String()) + }) + } + }) + + t.Run("Supports arbitrarily nested struct pointer fields", func(t *testing.T) { + type block struct { + NumberA int `river:"number_a,attr"` + NumberB *int `river:"number_b,attr"` + NumberC **int `river:"number_c,attr"` + NumberD ***int `river:"number_d,attr"` + } + + input := `some_block { + number_a = 15 + number_b = 20 + number_c = 25 + number_d = 30 + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 15, actual.NumberA) + require.Equal(t, 20, *actual.NumberB) + require.Equal(t, 25, **actual.NumberC) + require.Equal(t, 30, ***actual.NumberD) + }) + + t.Run("Supports squashed attributes", func(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr,optional"` + InnerField2 string `river:"inner_field_2,attr,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner InnerStruct `river:",squash"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + var ( + input = `some_block { + outer_field_1 = "value1" + outer_field_2 = "value2" + inner_field_1 = "value3" + inner_field_2 = "value4" + }` + + expect = OuterStruct{ + OuterField1: "value1", + Inner: InnerStruct{ + InnerField1: "value3", + InnerField2: "value4", + }, + OuterField2: "value2", + } + ) + eval := vm.New(parseBlock(t, input)) + + var actual OuterStruct + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) + + t.Run("Supports squashed attributes in pointers", func(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr,optional"` + InnerField2 string `river:"inner_field_2,attr,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner *InnerStruct `river:",squash"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + var ( + input = `some_block { + outer_field_1 = "value1" + outer_field_2 = "value2" + inner_field_1 = "value3" + inner_field_2 = "value4" + }` + + expect = OuterStruct{ + OuterField1: "value1", + Inner: &InnerStruct{ + InnerField1: "value3", + InnerField2: "value4", + }, + OuterField2: "value2", + } + ) + eval := vm.New(parseBlock(t, input)) + + var actual OuterStruct + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) +} + +func TestVM_Block_Children_Blocks(t *testing.T) { + type childBlock struct { + Attr bool `river:"attr,attr"` + } + + t.Run("Decodes children blocks", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Child childBlock `river:"child.block,block"` + } + + input := `some_block { + value = 15 + + child.block { attr = true } + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 15, actual.Value) + require.True(t, actual.Child.Attr) + }) + + t.Run("Decodes multiple instances of children blocks", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Children []childBlock `river:"child.block,block"` + } + + input := `some_block { + value = 10 + + child.block { attr = true } + child.block { attr = false } + child.block { attr = true } + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 10, actual.Value) + require.Len(t, actual.Children, 3) + require.Equal(t, true, actual.Children[0].Attr) + require.Equal(t, false, actual.Children[1].Attr) + require.Equal(t, true, actual.Children[2].Attr) + }) + + t.Run("Decodes multiple instances of children blocks into an array", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Children [3]childBlock `river:"child.block,block"` + } + + input := `some_block { + value = 15 + + child.block { attr = true } + child.block { attr = false } + child.block { attr = true } + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 15, actual.Value) + require.Equal(t, true, actual.Children[0].Attr) + require.Equal(t, false, actual.Children[1].Attr) + require.Equal(t, true, actual.Children[2].Attr) + }) + + t.Run("Fails if block used as an attribute", func(t *testing.T) { + type block struct { + Child childBlock `river:"child,block"` + } + + input := `some_block { + child = 15 + }` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `2:4: "child" must be a block, but is used as an attribute`) + }) + + t.Run("Fails if required children blocks are not present", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Child childBlock `river:"child.block,block"` + } + + input := `some_block { + value = 15 + }` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `missing required block "child.block"`) + }) + + t.Run("Succeeds if optional children blocks are not present", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Child childBlock `river:"child.block,block,optional"` + } + + input := `some_block { + value = 15 + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 15, actual.Value) + }) + + t.Run("Fails if child block is not defined in struct", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + } + + input := `some_block { + value = 15 + + child.block { attr = true } + }` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `4:4: unrecognized block name "child.block"`) + }) + + t.Run("Supports arbitrarily nested struct pointer fields", func(t *testing.T) { + type block struct { + BlockA childBlock `river:"block_a,block"` + BlockB *childBlock `river:"block_b,block"` + BlockC **childBlock `river:"block_c,block"` + BlockD ***childBlock `river:"block_d,block"` + } + + input := `some_block { + block_a { attr = true } + block_b { attr = false } + block_c { attr = true } + block_d { attr = false } + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, true, (actual.BlockA).Attr) + require.Equal(t, false, (*actual.BlockB).Attr) + require.Equal(t, true, (**actual.BlockC).Attr) + require.Equal(t, false, (***actual.BlockD).Attr) + }) + + t.Run("Supports squashed blocks", func(t *testing.T) { + type InnerStruct struct { + Inner1 childBlock `river:"inner_block_1,block"` + Inner2 childBlock `river:"inner_block_2,block"` + } + + type OuterStruct struct { + Outer1 childBlock `river:"outer_block_1,block"` + Inner InnerStruct `river:",squash"` + Outer2 childBlock `river:"outer_block_2,block"` + } + + var ( + input = `some_block { + outer_block_1 { attr = true } + outer_block_2 { attr = false } + inner_block_1 { attr = true } + inner_block_2 { attr = false } + }` + + expect = OuterStruct{ + Outer1: childBlock{Attr: true}, + Outer2: childBlock{Attr: false}, + Inner: InnerStruct{ + Inner1: childBlock{Attr: true}, + Inner2: childBlock{Attr: false}, + }, + } + ) + eval := vm.New(parseBlock(t, input)) + + var actual OuterStruct + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) + + t.Run("Supports squashed blocks in pointers", func(t *testing.T) { + type InnerStruct struct { + Inner1 *childBlock `river:"inner_block_1,block"` + Inner2 *childBlock `river:"inner_block_2,block"` + } + + type OuterStruct struct { + Outer1 childBlock `river:"outer_block_1,block"` + Inner *InnerStruct `river:",squash"` + Outer2 childBlock `river:"outer_block_2,block"` + } + + var ( + input = `some_block { + outer_block_1 { attr = true } + outer_block_2 { attr = false } + inner_block_1 { attr = true } + inner_block_2 { attr = false } + }` + + expect = OuterStruct{ + Outer1: childBlock{Attr: true}, + Outer2: childBlock{Attr: false}, + Inner: &InnerStruct{ + Inner1: &childBlock{Attr: true}, + Inner2: &childBlock{Attr: false}, + }, + } + ) + eval := vm.New(parseBlock(t, input)) + + var actual OuterStruct + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) + + // TODO(rfratto): decode all blocks into a []*ast.BlockStmt field. +} + +func TestVM_Block_Enum_Block(t *testing.T) { + type childBlock struct { + Attr int `river:"attr,attr"` + } + + type enumBlock struct { + BlockA *childBlock `river:"a,block,optional"` + BlockB *childBlock `river:"b,block,optional"` + BlockC *childBlock `river:"c,block,optional"` + BlockD *childBlock `river:"d,block,optional"` + } + + t.Run("Decodes enum blocks", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Blocks []*enumBlock `river:"child,enum,optional"` + } + + input := `some_block { + value = 15 + + child.a { attr = 1 } + }` + eval := vm.New(parseBlock(t, input)) + + expect := block{ + Value: 15, + Blocks: []*enumBlock{ + {BlockA: &childBlock{Attr: 1}}, + }, + } + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) + + t.Run("Decodes multiple enum blocks", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Blocks []*enumBlock `river:"child,enum,optional"` + } + + input := `some_block { + value = 15 + + child.b { attr = 1 } + child.a { attr = 2 } + child.c { attr = 3 } + }` + eval := vm.New(parseBlock(t, input)) + + expect := block{ + Value: 15, + Blocks: []*enumBlock{ + {BlockB: &childBlock{Attr: 1}}, + {BlockA: &childBlock{Attr: 2}}, + {BlockC: &childBlock{Attr: 3}}, + }, + } + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) + + t.Run("Decodes multiple enum blocks with repeating blocks", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Blocks []*enumBlock `river:"child,enum,optional"` + } + + input := `some_block { + value = 15 + + child.a { attr = 1 } + child.b { attr = 2 } + child.c { attr = 3 } + child.a { attr = 4 } + }` + eval := vm.New(parseBlock(t, input)) + + expect := block{ + Value: 15, + Blocks: []*enumBlock{ + {BlockA: &childBlock{Attr: 1}}, + {BlockB: &childBlock{Attr: 2}}, + {BlockC: &childBlock{Attr: 3}}, + {BlockA: &childBlock{Attr: 4}}, + }, + } + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) +} + +func TestVM_Block_Label(t *testing.T) { + t.Run("Decodes label into string field", func(t *testing.T) { + type block struct { + Label string `river:",label"` + } + + input := `some_block "label_value_1" {}` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, "label_value_1", actual.Label) + }) + + t.Run("Struct must have label field if block is labeled", func(t *testing.T) { + type block struct{} + + input := `some_block "label_value_2" {}` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `1:1: block "some_block" does not support specifying labels`) + }) + + t.Run("Block must have label if struct accepts label", func(t *testing.T) { + type block struct { + Label string `river:",label"` + } + + input := `some_block {}` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `1:1: block "some_block" requires non-empty label`) + }) + + t.Run("Block must have non-empty label if struct accepts label", func(t *testing.T) { + type block struct { + Label string `river:",label"` + } + + input := `some_block "" {}` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `1:1: block "some_block" requires non-empty label`) + }) +} + +func TestVM_Block_Unmarshaler(t *testing.T) { + type OuterBlock struct { + FieldA string `river:"field_a,attr"` + Settings Setting `river:"some.settings,block"` + } + + input := ` + field_a = "foobar" + some.settings { + field_a = "fizzbuzz" + field_b = "helloworld" + } + ` + + file, err := parser.ParseFile(t.Name(), []byte(input)) + require.NoError(t, err) + + eval := vm.New(file) + + var actual OuterBlock + require.NoError(t, eval.Evaluate(nil, &actual)) + require.True(t, actual.Settings.UnmarshalCalled, "UnmarshalRiver did not get invoked") + require.True(t, actual.Settings.DefaultCalled, "SetToDefault did not get invoked") + require.True(t, actual.Settings.ValidateCalled, "Validate did not get invoked") +} + +func TestVM_Block_UnmarshalToMap(t *testing.T) { + type OuterBlock struct { + Settings map[string]interface{} `river:"some.settings,block"` + } + + tt := []struct { + name string + input string + expect OuterBlock + expectError string + }{ + { + name: "decodes successfully", + input: ` + some.settings { + field_a = 12345 + field_b = "helloworld" + } + `, + expect: OuterBlock{ + Settings: map[string]interface{}{ + "field_a": 12345, + "field_b": "helloworld", + }, + }, + }, + { + name: "rejects labeled blocks", + input: ` + some.settings "foo" { + field_a = 12345 + } + `, + expectError: `block "some.settings" requires non-empty label`, + }, + + { + name: "rejects nested maps", + input: ` + some.settings { + inner_map { + field_a = 12345 + } + } + `, + expectError: "nested blocks not supported here", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + file, err := parser.ParseFile(t.Name(), []byte(tc.input)) + require.NoError(t, err) + + eval := vm.New(file) + + var actual OuterBlock + err = eval.Evaluate(nil, &actual) + + if tc.expectError == "" { + require.NoError(t, err) + require.Equal(t, tc.expect, actual) + } else { + require.ErrorContains(t, err, tc.expectError) + } + }) + } +} + +func TestVM_Block_UnmarshalToAny(t *testing.T) { + type OuterBlock struct { + Settings any `river:"some.settings,block"` + } + + input := ` + some.settings { + field_a = 12345 + field_b = "helloworld" + } + ` + + file, err := parser.ParseFile(t.Name(), []byte(input)) + require.NoError(t, err) + + eval := vm.New(file) + + var actual OuterBlock + require.NoError(t, eval.Evaluate(nil, &actual)) + + expect := map[string]interface{}{ + "field_a": 12345, + "field_b": "helloworld", + } + require.Equal(t, expect, actual.Settings) +} + +type Setting struct { + FieldA string `river:"field_a,attr"` + FieldB string `river:"field_b,attr"` + + UnmarshalCalled bool + DefaultCalled bool + ValidateCalled bool +} + +func (s *Setting) UnmarshalRiver(f func(interface{}) error) error { + s.UnmarshalCalled = true + return f((*settingUnmarshalTarget)(s)) +} + +type settingUnmarshalTarget Setting + +func (s *settingUnmarshalTarget) SetToDefault() { + s.DefaultCalled = true +} + +func (s *settingUnmarshalTarget) Validate() error { + s.ValidateCalled = true + return nil +} + +func parseBlock(t *testing.T, input string) *ast.BlockStmt { + t.Helper() + + res, err := parser.ParseFile("", []byte(input)) + require.NoError(t, err) + require.Len(t, res.Body, 1) + + stmt, ok := res.Body[0].(*ast.BlockStmt) + require.True(t, ok, "Expected stmt to be a ast.BlockStmt, got %T", res.Body[0]) + return stmt +} diff --git a/syntax/vm/vm_errors_test.go b/syntax/vm/vm_errors_test.go new file mode 100644 index 0000000000..87acdd7b1b --- /dev/null +++ b/syntax/vm/vm_errors_test.go @@ -0,0 +1,80 @@ +package vm_test + +import ( + "testing" + + "github.com/grafana/river/parser" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +func TestVM_ExprErrors(t *testing.T) { + type Target struct { + Key struct { + Object struct { + Field1 []int `river:"field1,attr"` + } `river:"object,attr"` + } `river:"key,attr"` + } + + tt := []struct { + name string + input string + into interface{} + scope *vm.Scope + expect string + }{ + { + name: "basic wrong type", + input: `key = true`, + into: &Target{}, + expect: "test:1:7: true should be object, got bool", + }, + { + name: "deeply nested literal", + input: ` + key = { + object = { + field1 = [15, 30, "Hello, world!"], + }, + } + `, + into: &Target{}, + expect: `test:4:25: "Hello, world!" should be number, got string`, + }, + { + name: "deeply nested indirect", + input: `key = key_value`, + into: &Target{}, + scope: &vm.Scope{ + Variables: map[string]interface{}{ + "key_value": map[string]interface{}{ + "object": map[string]interface{}{ + "field1": []interface{}{15, 30, "Hello, world!"}, + }, + }, + }, + }, + expect: `test:1:7: key_value.object.field1[2] should be number, got string`, + }, + { + name: "complex expr", + input: `key = [0, 1, 2]`, + into: &struct { + Key string `river:"key,attr"` + }{}, + expect: `test:1:7: [0, 1, 2] should be string, got array`, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + res, err := parser.ParseFile("test", []byte(tc.input)) + require.NoError(t, err) + + eval := vm.New(res) + err = eval.Evaluate(tc.scope, tc.into) + require.EqualError(t, err, tc.expect) + }) + } +} diff --git a/syntax/vm/vm_stdlib_test.go b/syntax/vm/vm_stdlib_test.go new file mode 100644 index 0000000000..f395f199e5 --- /dev/null +++ b/syntax/vm/vm_stdlib_test.go @@ -0,0 +1,232 @@ +package vm_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/parser" + "github.com/grafana/river/rivertypes" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +func TestVM_Stdlib(t *testing.T) { + t.Setenv("TEST_VAR", "Hello!") + + tt := []struct { + name string + input string + expect interface{} + }{ + {"env", `env("TEST_VAR")`, string("Hello!")}, + {"concat", `concat([true, "foo"], [], [false, 1])`, []interface{}{true, "foo", false, 1}}, + {"json_decode object", `json_decode("{\"foo\": \"bar\"}")`, map[string]interface{}{"foo": "bar"}}, + {"json_decode array", `json_decode("[0, 1, 2]")`, []interface{}{float64(0), float64(1), float64(2)}}, + {"json_decode nil field", `json_decode("{\"foo\": null}")`, map[string]interface{}{"foo": nil}}, + {"json_decode nil array element", `json_decode("[0, null]")`, []interface{}{float64(0), nil}}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + rv := reflect.New(reflect.TypeOf(tc.expect)) + require.NoError(t, eval.Evaluate(nil, rv.Interface())) + require.Equal(t, tc.expect, rv.Elem().Interface()) + }) + } +} + +func TestStdlibCoalesce(t *testing.T) { + t.Setenv("TEST_VAR2", "Hello!") + + tt := []struct { + name string + input string + expect interface{} + }{ + {"coalesce()", `coalesce()`, value.Null}, + {"coalesce(string)", `coalesce("Hello!")`, string("Hello!")}, + {"coalesce(string, string)", `coalesce(env("TEST_VAR2"), "World!")`, string("Hello!")}, + {"(string, string) with fallback", `coalesce(env("NON_DEFINED"), "World!")`, string("World!")}, + {"coalesce(list, list)", `coalesce([], ["fallback"])`, []string{"fallback"}}, + {"coalesce(list, list) with fallback", `coalesce(concat(["item"]), ["fallback"])`, []string{"item"}}, + {"coalesce(int, int, int)", `coalesce(0, 1, 2)`, 1}, + {"coalesce(bool, int, int)", `coalesce(false, 1, 2)`, 1}, + {"coalesce(bool, bool)", `coalesce(false, true)`, true}, + {"coalesce(list, bool)", `coalesce(json_decode("[]"), true)`, true}, + {"coalesce(object, true) and return true", `coalesce(json_decode("{}"), true)`, true}, + {"coalesce(object, false) and return false", `coalesce(json_decode("{}"), false)`, false}, + {"coalesce(list, nil)", `coalesce([],null)`, value.Null}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + rv := reflect.New(reflect.TypeOf(tc.expect)) + require.NoError(t, eval.Evaluate(nil, rv.Interface())) + require.Equal(t, tc.expect, rv.Elem().Interface()) + }) + } +} + +func TestStdlibJsonPath(t *testing.T) { + tt := []struct { + name string + input string + expect interface{} + }{ + {"json_path with simple json", `json_path("{\"a\": \"b\"}", ".a")`, []string{"b"}}, + {"json_path with simple json without results", `json_path("{\"a\": \"b\"}", ".nonexists")`, []string{}}, + {"json_path with json array", `json_path("[{\"name\": \"Department\",\"value\": \"IT\"},{\"name\":\"ReferenceNumber\",\"value\":\"123456\"},{\"name\":\"TestStatus\",\"value\":\"Pending\"}]", "[?(@.name == \"Department\")].value")`, []string{"IT"}}, + {"json_path with simple json and return first", `json_path("{\"a\": \"b\"}", ".a")[0]`, "b"}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + rv := reflect.New(reflect.TypeOf(tc.expect)) + require.NoError(t, eval.Evaluate(nil, rv.Interface())) + require.Equal(t, tc.expect, rv.Elem().Interface()) + }) + } +} + +func TestStdlib_Nonsensitive(t *testing.T) { + scope := &vm.Scope{ + Variables: map[string]any{ + "secret": rivertypes.Secret("foo"), + "optionalSecret": rivertypes.OptionalSecret{Value: "bar"}, + }, + } + + tt := []struct { + name string + input string + expect interface{} + }{ + {"secret to string", `nonsensitive(secret)`, string("foo")}, + {"optional secret to string", `nonsensitive(optionalSecret)`, string("bar")}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + rv := reflect.New(reflect.TypeOf(tc.expect)) + require.NoError(t, eval.Evaluate(scope, rv.Interface())) + require.Equal(t, tc.expect, rv.Elem().Interface()) + }) + } +} +func TestStdlib_StringFunc(t *testing.T) { + scope := &vm.Scope{ + Variables: map[string]any{}, + } + + tt := []struct { + name string + input string + expect interface{} + }{ + {"to_lower", `to_lower("String")`, "string"}, + {"to_upper", `to_upper("string")`, "STRING"}, + {"trimspace", `trim_space(" string \n\n")`, "string"}, + {"trimspace+to_upper+trim", `to_lower(to_upper(trim_space(" String ")))`, "string"}, + {"split", `split("/aaa/bbb/ccc/ddd", "/")`, []string{"", "aaa", "bbb", "ccc", "ddd"}}, + {"split+index", `split("/aaa/bbb/ccc/ddd", "/")[0]`, ""}, + {"join+split", `join(split("/aaa/bbb/ccc/ddd", "/"), "/")`, "/aaa/bbb/ccc/ddd"}, + {"join", `join(["foo", "bar", "baz"], ", ")`, "foo, bar, baz"}, + {"join w/ int", `join([0, 0, 1], ", ")`, "0, 0, 1"}, + {"format", `format("Hello %s", "World")`, "Hello World"}, + {"format+int", `format("%#v", 1)`, "1"}, + {"format+bool", `format("%#v", true)`, "true"}, + {"format+quote", `format("%q", "hello")`, `"hello"`}, + {"replace", `replace("Hello World", " World", "!")`, "Hello!"}, + {"trim", `trim("?!hello?!", "!?")`, "hello"}, + {"trim2", `trim(" hello! world.! ", "! ")`, "hello! world."}, + {"trim_prefix", `trim_prefix("helloworld", "hello")`, "world"}, + {"trim_suffix", `trim_suffix("helloworld", "world")`, "hello"}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + rv := reflect.New(reflect.TypeOf(tc.expect)) + require.NoError(t, eval.Evaluate(scope, rv.Interface())) + require.Equal(t, tc.expect, rv.Elem().Interface()) + }) + } +} + +func BenchmarkConcat(b *testing.B) { + // There's a bit of setup work to do here: we want to create a scope holding + // a slice of the Person type, which has a fair amount of data in it. + // + // We then want to pass it through concat. + // + // If the code path is fully optimized, there will be no intermediate + // translations to interface{}. + type Person struct { + Name string `river:"name,attr"` + Attrs map[string]string `river:"attrs,attr"` + } + type Body struct { + Values []Person `river:"values,attr"` + } + + in := `values = concat(values_ref)` + f, err := parser.ParseFile("", []byte(in)) + require.NoError(b, err) + + eval := vm.New(f) + + valuesRef := make([]Person, 0, 20) + for i := 0; i < 20; i++ { + data := make(map[string]string, 20) + for j := 0; j < 20; j++ { + var ( + key = fmt.Sprintf("key_%d", i+1) + value = fmt.Sprintf("value_%d", i+1) + ) + data[key] = value + } + valuesRef = append(valuesRef, Person{ + Name: "Test Person", + Attrs: data, + }) + } + scope := &vm.Scope{ + Variables: map[string]interface{}{ + "values_ref": valuesRef, + }, + } + + // Reset timer before running the actual test + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var b Body + _ = eval.Evaluate(scope, &b) + } +} diff --git a/syntax/vm/vm_test.go b/syntax/vm/vm_test.go new file mode 100644 index 0000000000..5591b08d88 --- /dev/null +++ b/syntax/vm/vm_test.go @@ -0,0 +1,277 @@ +package vm_test + +import ( + "reflect" + "strings" + "testing" + "unicode" + + "github.com/grafana/river/parser" + "github.com/grafana/river/scanner" + "github.com/grafana/river/token" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +func TestVM_Evaluate_Literals(t *testing.T) { + tt := map[string]struct { + input string + expect interface{} + }{ + "number to int": {`12`, int(12)}, + "number to int8": {`13`, int8(13)}, + "number to int16": {`14`, int16(14)}, + "number to int32": {`15`, int32(15)}, + "number to int64": {`16`, int64(16)}, + "number to uint": {`17`, uint(17)}, + "number to uint8": {`18`, uint8(18)}, + "number to uint16": {`19`, uint16(19)}, + "number to uint32": {`20`, uint32(20)}, + "number to uint64": {`21`, uint64(21)}, + "number to float32": {`22`, float32(22)}, + "number to float64": {`23`, float64(23)}, + "number to string": {`24`, string("24")}, + + "float to float32": {`3.2`, float32(3.2)}, + "float to float64": {`3.5`, float64(3.5)}, + "float to string": {`3.9`, string("3.9")}, + + "float with dot to float32": {`.2`, float32(0.2)}, + "float with dot to float64": {`.5`, float64(0.5)}, + "float with dot to string": {`.9`, string("0.9")}, + + "string to string": {`"Hello, world!"`, string("Hello, world!")}, + "string to int": {`"12"`, int(12)}, + "string to float64": {`"12"`, float64(12)}, + } + + for name, tc := range tt { + t.Run(name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + vPtr := reflect.New(reflect.TypeOf(tc.expect)).Interface() + require.NoError(t, eval.Evaluate(nil, vPtr)) + + actual := reflect.ValueOf(vPtr).Elem().Interface() + require.Equal(t, tc.expect, actual) + }) + } +} + +func TestVM_Evaluate(t *testing.T) { + // Shared scope across all tests below + scope := &vm.Scope{ + Variables: map[string]interface{}{ + "foobar": int(42), + }, + } + + tt := []struct { + input string + expect interface{} + }{ + // Binops + {`true || false`, bool(true)}, + {`false || false`, bool(false)}, + {`true && false`, bool(false)}, + {`true && true`, bool(true)}, + {`3 == 5`, bool(false)}, + {`3 == 3`, bool(true)}, + {`3 != 5`, bool(true)}, + {`3 < 5`, bool(true)}, + {`3 <= 5`, bool(true)}, + {`3 > 5`, bool(false)}, + {`3 >= 5`, bool(false)}, + {`3 + 5`, int(8)}, + {`3 - 5`, int(-2)}, + {`3 * 5`, int(15)}, + {`3.0 / 5.0`, float64(0.6)}, + {`5 % 3`, int(2)}, + {`3 ^ 5`, int(243)}, + {`3 + 5 * 2`, int(13)}, // Chain multiple binops + {`42.0^-2`, float64(0.0005668934240362812)}, + + // Identifier + {`foobar`, int(42)}, + + // Arrays + {`[]`, []int{}}, + {`[0, 1, 2]`, []int{0, 1, 2}}, + {`[true, false]`, []bool{true, false}}, + + // Objects + {`{ a = 5, b = 10 }`, map[string]int{"a": 5, "b": 10}}, + { + input: `{ + name = "John Doe", + age = 42, + }`, + expect: struct { + Name string `river:"name,attr"` + Age int `river:"age,attr"` + Country string `river:"country,attr,optional"` + }{ + Name: "John Doe", + Age: 42, + }, + }, + + // Access + {`{ a = 15 }.a`, int(15)}, + {`{ a = { b = 12 } }.a.b`, int(12)}, + {`{}["foo"]`, nil}, + + // Indexing + {`[0, 1, 2][1]`, int(1)}, + {`[[1,2,3]][0][2]`, int(3)}, + {`[true, false][0]`, bool(true)}, + + // Paren + {`(15)`, int(15)}, + + // Unary + {`!true`, bool(false)}, + {`!false`, bool(true)}, + {`-15`, int(-15)}, + } + + for _, tc := range tt { + name := trimWhitespace(tc.input) + + t.Run(name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + var vPtr any + if tc.expect != nil { + vPtr = reflect.New(reflect.TypeOf(tc.expect)).Interface() + } else { + // Create a new any pointer. + vPtr = reflect.New(reflect.TypeOf((*any)(nil)).Elem()).Interface() + } + + require.NoError(t, eval.Evaluate(scope, vPtr)) + + actual := reflect.ValueOf(vPtr).Elem().Interface() + require.Equal(t, tc.expect, actual) + }) + } +} + +func TestVM_Evaluate_Null(t *testing.T) { + expr, err := parser.ParseExpression("null") + require.NoError(t, err) + + eval := vm.New(expr) + + var v interface{} + require.NoError(t, eval.Evaluate(nil, &v)) + require.Nil(t, v) +} + +func TestVM_Evaluate_IdentifierExpr(t *testing.T) { + t.Run("Valid lookup", func(t *testing.T) { + scope := &vm.Scope{ + Variables: map[string]interface{}{ + "foobar": 15, + }, + } + + expr, err := parser.ParseExpression(`foobar`) + require.NoError(t, err) + + eval := vm.New(expr) + + var actual int + require.NoError(t, eval.Evaluate(scope, &actual)) + require.Equal(t, 15, actual) + }) + + t.Run("Invalid lookup", func(t *testing.T) { + expr, err := parser.ParseExpression(`foobar`) + require.NoError(t, err) + + eval := vm.New(expr) + + var v interface{} + err = eval.Evaluate(nil, &v) + require.EqualError(t, err, `1:1: identifier "foobar" does not exist`) + }) +} + +func TestVM_Evaluate_AccessExpr(t *testing.T) { + t.Run("Lookup optional field", func(t *testing.T) { + type Person struct { + Name string `river:"name,attr,optional"` + } + + scope := &vm.Scope{ + Variables: map[string]interface{}{ + "person": Person{}, + }, + } + + expr, err := parser.ParseExpression(`person.name`) + require.NoError(t, err) + + eval := vm.New(expr) + + var actual string + require.NoError(t, eval.Evaluate(scope, &actual)) + require.Equal(t, "", actual) + }) + + t.Run("Invalid lookup 1", func(t *testing.T) { + expr, err := parser.ParseExpression(`{ a = 15 }.b`) + require.NoError(t, err) + + eval := vm.New(expr) + + var v interface{} + err = eval.Evaluate(nil, &v) + require.EqualError(t, err, `1:12: field "b" does not exist`) + }) + + t.Run("Invalid lookup 2", func(t *testing.T) { + _, err := parser.ParseExpression(`{ a = 15 }.7`) + require.EqualError(t, err, `1:11: expected TERMINATOR, got FLOAT`) + }) + + t.Run("Invalid lookup 3", func(t *testing.T) { + _, err := parser.ParseExpression(`{ a = { b = 12 }.7 }.a.b`) + require.EqualError(t, err, `1:17: missing ',' in field list`) + }) + + t.Run("Invalid lookup 4", func(t *testing.T) { + _, err := parser.ParseExpression(`{ a = { b = 12 } }.a.b.7`) + require.EqualError(t, err, `1:23: expected TERMINATOR, got FLOAT`) + }) +} + +func trimWhitespace(in string) string { + f := token.NewFile("") + + s := scanner.New(f, []byte(in), nil, 0) + + var out strings.Builder + + for { + _, tok, lit := s.Scan() + if tok == token.EOF { + break + } + + if lit != "" { + _, _ = out.WriteString(lit) + } else { + _, _ = out.WriteString(tok.String()) + } + } + + return strings.TrimFunc(out.String(), unicode.IsSpace) +}