From 2ff16d2936abe794c2e4b2e0110e73c640f3456c Mon Sep 17 00:00:00 2001 From: Robert Fratto Date: Fri, 1 Mar 2024 11:53:01 -0500 Subject: [PATCH] misc: port github.com/grafana/river into syntax (#17) This commit brings the github.com/grafana/river code into the syntax package, which is intended to be a submodule. There are a few reasons to do this: 1. "River" is being sunset as a term in favor of the "Alloy configuration syntax," so it no longer makes sense for River to exist on its own. 2. With the transition to Alloy configuration syntax, error messages and the like should remove references to the word "River." 3. It is likely that the Alloy configuration syntax will be receiving a stream of updates soon after the 1.0 release, where having it in a separate repo would slow us down and risk desyncing documentation. 4. There are projects which depend on River, so it must be importable if we bring it into Alloy. However, since we don't want to mark the Go API as 1.0 yet, constraining the Alloy configuration syntax to a submodule allows us to version it separately. 
--- go.mod | 13 +- go.sum | 10 - .../configs/otel-metrics-gen/Dockerfile | 1 + .../configs/prom-gen/Dockerfile | 1 + syntax/ast/ast.go | 328 +++++++ syntax/ast/walk.go | 73 ++ syntax/cmd/riverfmt/main.go | 103 +++ syntax/diag/diag.go | 95 +++ syntax/diag/printer.go | 266 ++++++ syntax/diag/printer_test.go | 222 +++++ syntax/encoding/riverjson/riverjson.go | 313 +++++++ syntax/encoding/riverjson/riverjson_test.go | 363 ++++++++ syntax/encoding/riverjson/types.go | 41 + syntax/go.mod | 18 + syntax/go.sum | 22 + syntax/internal/reflectutil/walk.go | 89 ++ syntax/internal/reflectutil/walk_test.go | 72 ++ syntax/internal/rivertags/rivertags.go | 346 ++++++++ syntax/internal/rivertags/rivertags_test.go | 182 ++++ syntax/internal/stdlib/constants.go | 19 + syntax/internal/stdlib/stdlib.go | 132 +++ syntax/internal/value/capsule.go | 53 ++ syntax/internal/value/decode.go | 674 +++++++++++++++ .../internal/value/decode_benchmarks_test.go | 90 ++ syntax/internal/value/decode_test.go | 761 +++++++++++++++++ syntax/internal/value/errors.go | 107 +++ syntax/internal/value/number_value.go | 135 +++ syntax/internal/value/raw_function.go | 9 + syntax/internal/value/tag_cache.go | 121 +++ syntax/internal/value/type.go | 157 ++++ syntax/internal/value/type_test.go | 80 ++ syntax/internal/value/value.go | 556 ++++++++++++ syntax/internal/value/value_object.go | 119 +++ syntax/internal/value/value_object_test.go | 205 +++++ syntax/internal/value/value_test.go | 243 ++++++ syntax/parser/error_test.go | 148 ++++ syntax/parser/internal.go | 714 ++++++++++++++++ syntax/parser/internal_test.go | 22 + syntax/parser/parser.go | 43 + syntax/parser/parser_test.go | 123 +++ .../testdata/assign_block_to_attr.river | 32 + syntax/parser/testdata/attribute_names.river | 7 + syntax/parser/testdata/block_names.river | 25 + syntax/parser/testdata/commas.river | 13 + ...fad53537b46efdaa76e024a5ef4955d01a68bdac37 | 2 + ...f4c6c80f4ba9099c21ffa2b6869e75e99565dce037 | 2 + 
...77c839a06204b55f2636597901d8d7878150d8580a | 2 + syntax/parser/testdata/invalid_exprs.river | 4 + .../parser/testdata/invalid_object_key.river | 9 + syntax/parser/testdata/valid/attribute.river | 1 + syntax/parser/testdata/valid/blocks.river | 36 + syntax/parser/testdata/valid/comments.river | 1 + syntax/parser/testdata/valid/empty.river | 0 .../parser/testdata/valid/expressions.river | 81 ++ syntax/printer/printer.go | 556 ++++++++++++ syntax/printer/printer_test.go | 77 ++ syntax/printer/testdata/.gitattributes | 1 + syntax/printer/testdata/array_comments.expect | 17 + syntax/printer/testdata/array_comments.in | 17 + syntax/printer/testdata/block_comments.expect | 62 ++ syntax/printer/testdata/block_comments.in | 64 ++ syntax/printer/testdata/example.expect | 60 ++ syntax/printer/testdata/example.in | 64 ++ syntax/printer/testdata/func_call.expect | 17 + syntax/printer/testdata/func_call.in | 17 + syntax/printer/testdata/mixed_list.expect | 16 + syntax/printer/testdata/mixed_list.in | 16 + syntax/printer/testdata/mixed_object.expect | 8 + syntax/printer/testdata/mixed_object.in | 7 + syntax/printer/testdata/object_align.expect | 11 + syntax/printer/testdata/object_align.in | 11 + syntax/printer/testdata/oneline_block.expect | 11 + syntax/printer/testdata/oneline_block.in | 14 + syntax/printer/testdata/raw_string.expect | 15 + syntax/printer/testdata/raw_string.in | 15 + .../testdata/raw_string_label_error.error | 1 + .../testdata/raw_string_label_error.in | 15 + syntax/printer/trimmer.go | 115 +++ syntax/printer/walker.go | 338 ++++++++ syntax/river.go | 346 ++++++++ syntax/river_test.go | 152 ++++ syntax/rivertypes/optional_secret.go | 84 ++ syntax/rivertypes/optional_secret_test.go | 92 ++ syntax/rivertypes/secret.go | 65 ++ syntax/rivertypes/secret_test.go | 47 + syntax/scanner/identifier.go | 60 ++ syntax/scanner/identifier_test.go | 92 ++ syntax/scanner/scanner.go | 704 +++++++++++++++ syntax/scanner/scanner_test.go | 272 ++++++ 
syntax/token/builder/builder.go | 419 +++++++++ syntax/token/builder/builder_test.go | 411 +++++++++ syntax/token/builder/nested_defaults_test.go | 233 +++++ syntax/token/builder/token.go | 81 ++ syntax/token/builder/value_tokens.go | 95 +++ syntax/token/file.go | 142 ++++ syntax/token/token.go | 174 ++++ syntax/types.go | 97 +++ syntax/vm/constant.go | 64 ++ syntax/vm/error.go | 106 +++ syntax/vm/op_binary.go | 360 ++++++++ syntax/vm/op_binary_test.go | 94 ++ syntax/vm/op_unary.go | 33 + syntax/vm/struct_decoder.go | 323 +++++++ syntax/vm/tag_cache.go | 80 ++ syntax/vm/vm.go | 486 +++++++++++ syntax/vm/vm_benchmarks_test.go | 106 +++ syntax/vm/vm_block_test.go | 802 ++++++++++++++++++ syntax/vm/vm_errors_test.go | 80 ++ syntax/vm/vm_stdlib_test.go | 232 +++++ syntax/vm/vm_test.go | 277 ++++++ 110 files changed, 15423 insertions(+), 15 deletions(-) create mode 100644 syntax/ast/ast.go create mode 100644 syntax/ast/walk.go create mode 100644 syntax/cmd/riverfmt/main.go create mode 100644 syntax/diag/diag.go create mode 100644 syntax/diag/printer.go create mode 100644 syntax/diag/printer_test.go create mode 100644 syntax/encoding/riverjson/riverjson.go create mode 100644 syntax/encoding/riverjson/riverjson_test.go create mode 100644 syntax/encoding/riverjson/types.go create mode 100644 syntax/go.mod create mode 100644 syntax/go.sum create mode 100644 syntax/internal/reflectutil/walk.go create mode 100644 syntax/internal/reflectutil/walk_test.go create mode 100644 syntax/internal/rivertags/rivertags.go create mode 100644 syntax/internal/rivertags/rivertags_test.go create mode 100644 syntax/internal/stdlib/constants.go create mode 100644 syntax/internal/stdlib/stdlib.go create mode 100644 syntax/internal/value/capsule.go create mode 100644 syntax/internal/value/decode.go create mode 100644 syntax/internal/value/decode_benchmarks_test.go create mode 100644 syntax/internal/value/decode_test.go create mode 100644 syntax/internal/value/errors.go create mode 100644 
syntax/internal/value/number_value.go create mode 100644 syntax/internal/value/raw_function.go create mode 100644 syntax/internal/value/tag_cache.go create mode 100644 syntax/internal/value/type.go create mode 100644 syntax/internal/value/type_test.go create mode 100644 syntax/internal/value/value.go create mode 100644 syntax/internal/value/value_object.go create mode 100644 syntax/internal/value/value_object_test.go create mode 100644 syntax/internal/value/value_test.go create mode 100644 syntax/parser/error_test.go create mode 100644 syntax/parser/internal.go create mode 100644 syntax/parser/internal_test.go create mode 100644 syntax/parser/parser.go create mode 100644 syntax/parser/parser_test.go create mode 100644 syntax/parser/testdata/assign_block_to_attr.river create mode 100644 syntax/parser/testdata/attribute_names.river create mode 100644 syntax/parser/testdata/block_names.river create mode 100644 syntax/parser/testdata/commas.river create mode 100644 syntax/parser/testdata/fuzz/FuzzParser/1a39f4e358facc21678b16fad53537b46efdaa76e024a5ef4955d01a68bdac37 create mode 100644 syntax/parser/testdata/fuzz/FuzzParser/248cf4391f6c48550b7d2cf4c6c80f4ba9099c21ffa2b6869e75e99565dce037 create mode 100644 syntax/parser/testdata/fuzz/FuzzParser/b919fa00ebca318001778477c839a06204b55f2636597901d8d7878150d8580a create mode 100644 syntax/parser/testdata/invalid_exprs.river create mode 100644 syntax/parser/testdata/invalid_object_key.river create mode 100644 syntax/parser/testdata/valid/attribute.river create mode 100644 syntax/parser/testdata/valid/blocks.river create mode 100644 syntax/parser/testdata/valid/comments.river create mode 100644 syntax/parser/testdata/valid/empty.river create mode 100644 syntax/parser/testdata/valid/expressions.river create mode 100644 syntax/printer/printer.go create mode 100644 syntax/printer/printer_test.go create mode 100644 syntax/printer/testdata/.gitattributes create mode 100644 syntax/printer/testdata/array_comments.expect create mode 
100644 syntax/printer/testdata/array_comments.in create mode 100644 syntax/printer/testdata/block_comments.expect create mode 100644 syntax/printer/testdata/block_comments.in create mode 100644 syntax/printer/testdata/example.expect create mode 100644 syntax/printer/testdata/example.in create mode 100644 syntax/printer/testdata/func_call.expect create mode 100644 syntax/printer/testdata/func_call.in create mode 100644 syntax/printer/testdata/mixed_list.expect create mode 100644 syntax/printer/testdata/mixed_list.in create mode 100644 syntax/printer/testdata/mixed_object.expect create mode 100644 syntax/printer/testdata/mixed_object.in create mode 100644 syntax/printer/testdata/object_align.expect create mode 100644 syntax/printer/testdata/object_align.in create mode 100644 syntax/printer/testdata/oneline_block.expect create mode 100644 syntax/printer/testdata/oneline_block.in create mode 100644 syntax/printer/testdata/raw_string.expect create mode 100644 syntax/printer/testdata/raw_string.in create mode 100644 syntax/printer/testdata/raw_string_label_error.error create mode 100644 syntax/printer/testdata/raw_string_label_error.in create mode 100644 syntax/printer/trimmer.go create mode 100644 syntax/printer/walker.go create mode 100644 syntax/river.go create mode 100644 syntax/river_test.go create mode 100644 syntax/rivertypes/optional_secret.go create mode 100644 syntax/rivertypes/optional_secret_test.go create mode 100644 syntax/rivertypes/secret.go create mode 100644 syntax/rivertypes/secret_test.go create mode 100644 syntax/scanner/identifier.go create mode 100644 syntax/scanner/identifier_test.go create mode 100644 syntax/scanner/scanner.go create mode 100644 syntax/scanner/scanner_test.go create mode 100644 syntax/token/builder/builder.go create mode 100644 syntax/token/builder/builder_test.go create mode 100644 syntax/token/builder/nested_defaults_test.go create mode 100644 syntax/token/builder/token.go create mode 100644 syntax/token/builder/value_tokens.go 
create mode 100644 syntax/token/file.go create mode 100644 syntax/token/token.go create mode 100644 syntax/types.go create mode 100644 syntax/vm/constant.go create mode 100644 syntax/vm/error.go create mode 100644 syntax/vm/op_binary.go create mode 100644 syntax/vm/op_binary_test.go create mode 100644 syntax/vm/op_unary.go create mode 100644 syntax/vm/struct_decoder.go create mode 100644 syntax/vm/tag_cache.go create mode 100644 syntax/vm/vm.go create mode 100644 syntax/vm/vm_benchmarks_test.go create mode 100644 syntax/vm/vm_block_test.go create mode 100644 syntax/vm/vm_errors_test.go create mode 100644 syntax/vm/vm_stdlib_test.go create mode 100644 syntax/vm/vm_test.go diff --git a/go.mod b/go.mod index 81e677a27a..23368dd02b 100644 --- a/go.mod +++ b/go.mod @@ -29,14 +29,13 @@ require ( github.com/docker/go-connections v0.4.0 github.com/drone/envsubst/v2 v2.0.0-20210730161058-179042472c46 github.com/fatih/color v1.15.0 - github.com/fatih/structs v1.1.0 github.com/fortytw2/leaktest v1.3.0 github.com/fsnotify/fsnotify v1.6.0 github.com/github/smimesign v0.2.0 github.com/go-git/go-git/v5 v5.11.0 github.com/go-kit/log v0.2.1 github.com/go-logfmt/logfmt v0.6.0 - github.com/go-logr/logr v1.4.1 + github.com/go-logr/logr v1.4.1 // indirect github.com/go-sourcemap/sourcemap v2.1.3+incompatible github.com/go-sql-driver/mysql v1.7.1 github.com/gogo/protobuf v1.3.2 @@ -45,7 +44,6 @@ require ( github.com/google/cadvisor v0.47.0 github.com/google/dnsmasq_exporter v0.2.1-0.20230620100026-44b14480804a github.com/google/go-cmp v0.6.0 - github.com/google/go-jsonnet v0.18.0 github.com/google/pprof v0.0.0-20240117000934-35fc243c5815 github.com/google/renameio/v2 v2.0.0 github.com/google/uuid v1.4.0 @@ -162,7 +160,6 @@ require ( github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.8.4 github.com/testcontainers/testcontainers-go v0.25.0 - github.com/testcontainers/testcontainers-go/modules/k3s v0.0.0-20230615142642-c175df34bd1d github.com/uber/jaeger-client-go 
v2.30.0+incompatible github.com/vincent-petithory/dataurl v1.0.0 github.com/webdevops/azure-metrics-exporter v0.0.0-20230717202958-8701afc2b013 @@ -226,7 +223,7 @@ require ( gopkg.in/yaml.v3 v3.0.1 gotest.tools v2.2.0+incompatible k8s.io/api v0.28.3 - k8s.io/apiextensions-apiserver v0.28.0 + k8s.io/apiextensions-apiserver v0.28.0 // indirect k8s.io/client-go v0.28.3 k8s.io/component-base v0.28.1 k8s.io/klog/v2 v2.100.1 @@ -766,3 +763,9 @@ exclude ( ) replace github.com/github/smimesign => github.com/grafana/smimesign v0.2.1-0.20220408144937-2a5adf3481d3 + +// Submodules. +// TODO(rfratto): Change all imports of github.com/grafana/river in favor of +// importing github.com/grafana/alloy/syntax and change module and package +// names to remove references of "river". +replace github.com/grafana/river => ./syntax diff --git a/go.sum b/go.sum index 95dd6c93ba..93cd5ab33b 100644 --- a/go.sum +++ b/go.sum @@ -666,12 +666,10 @@ github.com/fatih/camelcase v1.0.0 h1:hxNvNX/xYBp0ovncs8WyWZrOrpBNub/JfaMvbURyft8 github.com/fatih/camelcase v1.0.0/go.mod h1:yN2Sb0lFhZJUdVvtELVWefmrXpuZESvPmqwoZc+/fpc= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= -github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= github.com/fatih/structs v0.0.0-20180123065059-ebf56d35bba7/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= -github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= 
github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= @@ -957,8 +955,6 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-github/v32 v32.1.0/go.mod h1:rIEpZD9CTDQwDK9GDrtMTycQNA4JU3qBsCizh3q2WCI= -github.com/google/go-jsonnet v0.18.0 h1:/6pTy6g+Jh1a1I2UMoAODkqELFiVIdOxbNwv0DDzoOg= -github.com/google/go-jsonnet v0.18.0/go.mod h1:C3fTzyVJDslXdiTqw/bTFk7vSGyCtH3MGRbDfvEwGd0= github.com/google/go-querystring v0.0.0-20170111101155-53e6ce116135/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= @@ -1084,8 +1080,6 @@ github.com/grafana/pyroscope/ebpf v0.4.3 h1:gPfm2FKabdycRfFIej/s0awSzsbAaoSefaeh github.com/grafana/pyroscope/ebpf v0.4.3/go.mod h1:Iv66aj9WsDWR8bGMPQzCQPCgVgCru0KizGrbcR3YmLk= github.com/grafana/regexp v0.0.0-20221123153739-15dc172cd2db h1:7aN5cccjIqCLTzedH7MZzRZt5/lsAHch6Z3L2ZGn5FA= github.com/grafana/regexp v0.0.0-20221123153739-15dc172cd2db/go.mod h1:M5qHK+eWfAv8VR/265dIuEpL3fNfeC21tXXp9itM24A= -github.com/grafana/river v0.3.1-0.20240123144725-960753160cd1 h1:mCOKdWkLv8n9X0ORWrPR+W/zLOAa1o6iM+Dfy0ofQUs= -github.com/grafana/river v0.3.1-0.20240123144725-960753160cd1/go.mod h1:tAiNX2zt3HUsNyPNUDSvE6AgQ4+kqJvljBI+ACppMtM= github.com/grafana/smimesign v0.2.1-0.20220408144937-2a5adf3481d3 h1:UPkAxuhlAcRmJT3/qd34OMTl+ZU7BLLfOO2+NXBlJpY= github.com/grafana/smimesign v0.2.1-0.20220408144937-2a5adf3481d3/go.mod h1:iZiiwNT4HbtGRVqCQu7uJPEZCuEE5sfSSttcnePkDl4= github.com/grafana/snowflake-prometheus-exporter v0.0.0-20221213150626-862cad8e9538 h1:tkT0yha3JzB5S5VNjfY4lT0cJAe20pU8XGt3Nuq73rM= @@ -1529,7 +1523,6 
@@ github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaO github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -2074,7 +2067,6 @@ github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtr github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY= github.com/sercand/kuberesolver/v5 v5.1.1 h1:CYH+d67G0sGBj7q5wLK61yzqJJ8gLLC8aeprPTHb6yY= github.com/sercand/kuberesolver/v5 v5.1.1/go.mod h1:Fs1KbKhVRnB2aDWN12NjKCB+RgYMWZJ294T3BtmVCpQ= -github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/shirou/gopsutil v0.0.0-20181107111621-48177ef5f880/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= @@ -2196,8 +2188,6 @@ github.com/tencentcloud/tencentcloud-sdk-go v1.0.162/go.mod h1:asUz5BPXxgoPGaRgZ github.com/tent/http-link-go v0.0.0-20130702225549-ac974c61c2f9/go.mod h1:RHkNRtSLfOK7qBTHaeSX1D6BNpI3qw7NTxsmNr4RvN8= github.com/testcontainers/testcontainers-go v0.25.0 h1:erH6cQjsaJrH+rJDU9qIf89KFdhK0Bft0aEZHlYC3Vs= github.com/testcontainers/testcontainers-go v0.25.0/go.mod h1:4sC9SiJyzD1XFi59q8umTQYWxnkweEc5OjVtTUlJzqQ= -github.com/testcontainers/testcontainers-go/modules/k3s v0.0.0-20230615142642-c175df34bd1d 
h1:KyYCHo9iBoQYw5AzcozD/77uNbFlRjTmMTA7QjSxHOQ= -github.com/testcontainers/testcontainers-go/modules/k3s v0.0.0-20230615142642-c175df34bd1d/go.mod h1:Pa91ahCbzRB6d9FBi6UAjurTEm7WmyBVeuklLkwAKKs= github.com/tg123/go-htpasswd v1.2.1 h1:i4wfsX1KvvkyoMiHZzjS0VzbAPWfxzI8INcZAKtutoU= github.com/tg123/go-htpasswd v1.2.1/go.mod h1:erHp1B86KXdwQf1X5ZrLb7erXZnWueEQezb2dql4q58= github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= diff --git a/internal/cmd/integration-tests/configs/otel-metrics-gen/Dockerfile b/internal/cmd/integration-tests/configs/otel-metrics-gen/Dockerfile index 0270edbd0b..bc0c2cf3a9 100644 --- a/internal/cmd/integration-tests/configs/otel-metrics-gen/Dockerfile +++ b/internal/cmd/integration-tests/configs/otel-metrics-gen/Dockerfile @@ -1,6 +1,7 @@ FROM golang:1.21 as build WORKDIR /app/ COPY go.mod go.sum ./ +COPY syntax/go.mod syntax/go.sum ./syntax/ RUN go mod download COPY ./internal/cmd/integration-tests/configs/otel-metrics-gen/ ./ RUN CGO_ENABLED=0 go build -o main main.go diff --git a/internal/cmd/integration-tests/configs/prom-gen/Dockerfile b/internal/cmd/integration-tests/configs/prom-gen/Dockerfile index d1e0bfdcaf..875b7bad7e 100644 --- a/internal/cmd/integration-tests/configs/prom-gen/Dockerfile +++ b/internal/cmd/integration-tests/configs/prom-gen/Dockerfile @@ -1,6 +1,7 @@ FROM golang:1.21 as build WORKDIR /app/ COPY go.mod go.sum ./ +COPY syntax/go.mod syntax/go.sum ./syntax/ RUN go mod download COPY ./internal/cmd/integration-tests/configs/prom-gen/ ./ RUN CGO_ENABLED=0 go build -o main main.go diff --git a/syntax/ast/ast.go b/syntax/ast/ast.go new file mode 100644 index 0000000000..992ee0c71a --- /dev/null +++ b/syntax/ast/ast.go @@ -0,0 +1,328 @@ +// Package ast exposes AST elements used by River. +// +// The various interfaces exposed by ast are all closed; only types within this +// package can satisfy an AST interface. 
+package ast + +import ( + "fmt" + "reflect" + "strings" + + "github.com/grafana/river/token" +) + +// Node represents any node in the AST. +type Node interface { + astNode() +} + +// Stmt is a type of statement within the body of a file or block. +type Stmt interface { + Node + astStmt() +} + +// Expr is an expression within the AST. +type Expr interface { + Node + astExpr() +} + +// File is a parsed file. +type File struct { + Name string // Filename provided to parser + Body Body // Content of File + Comments []CommentGroup // List of all comments in the File +} + +// Body is a list of statements. +type Body []Stmt + +// A CommentGroup represents a sequence of comments that are not separated by +// any empty lines or other non-comment tokens. +type CommentGroup []*Comment + +// A Comment represents a single line or block comment. +// +// The Text field contains the comment text without any carriage returns (\r) +// that may have been present in the source. Since carriage returns get +// removed, EndPos will not be accurate for any comment which contained +// carriage returns. +type Comment struct { + StartPos token.Pos // Starting position of comment + // Text of the comment. Text will not contain '\n' for line comments. + Text string +} + +// AttributeStmt is a key-value pair being set in a Body or BlockStmt. +type AttributeStmt struct { + Name *Ident + Value Expr +} + +// BlockStmt declares a block. +type BlockStmt struct { + Name []string + NamePos token.Pos + Label string + LabelPos token.Pos + Body Body + + LCurlyPos, RCurlyPos token.Pos +} + +// Ident holds an identifier with its position. +type Ident struct { + Name string + NamePos token.Pos +} + +// IdentifierExpr refers to a named value. +type IdentifierExpr struct { + Ident *Ident +} + +// LiteralExpr is a constant value of a specific token kind. +type LiteralExpr struct { + Kind token.Token + ValuePos token.Pos + + // Value holds the unparsed literal value. 
For example, if Kind == + // token.STRING, then Value would be wrapped in the original quotes (e.g., + // `"foobar"`). + Value string +} + +// ArrayExpr is an array of values. +type ArrayExpr struct { + Elements []Expr + LBrackPos, RBrackPos token.Pos +} + +// ObjectExpr declares an object of key-value pairs. +type ObjectExpr struct { + Fields []*ObjectField + LCurlyPos, RCurlyPos token.Pos +} + +// ObjectField defines an individual key-value pair within an object. +// ObjectField does not implement Node. +type ObjectField struct { + Name *Ident + Quoted bool // True if the name was wrapped in quotes + Value Expr +} + +// AccessExpr accesses a field in an object value by name. +type AccessExpr struct { + Value Expr + Name *Ident +} + +// IndexExpr accesses an index in an array value. +type IndexExpr struct { + Value, Index Expr + LBrackPos, RBrackPos token.Pos +} + +// CallExpr invokes a function value with a set of arguments. +type CallExpr struct { + Value Expr + Args []Expr + + LParenPos, RParenPos token.Pos +} + +// UnaryExpr performs a unary operation on a single value. +type UnaryExpr struct { + Kind token.Token + KindPos token.Pos + Value Expr +} + +// BinaryExpr performs a binary operation against two values. +type BinaryExpr struct { + Kind token.Token + KindPos token.Pos + Left, Right Expr +} + +// ParenExpr represents an expression wrapped in parentheses. 
+type ParenExpr struct { + Inner Expr + LParenPos, RParenPos token.Pos +} + +// Type assertions + +var ( + _ Node = (*File)(nil) + _ Node = (*Body)(nil) + _ Node = (*AttributeStmt)(nil) + _ Node = (*BlockStmt)(nil) + _ Node = (*Ident)(nil) + _ Node = (*IdentifierExpr)(nil) + _ Node = (*LiteralExpr)(nil) + _ Node = (*ArrayExpr)(nil) + _ Node = (*ObjectExpr)(nil) + _ Node = (*AccessExpr)(nil) + _ Node = (*IndexExpr)(nil) + _ Node = (*CallExpr)(nil) + _ Node = (*UnaryExpr)(nil) + _ Node = (*BinaryExpr)(nil) + _ Node = (*ParenExpr)(nil) + + _ Stmt = (*AttributeStmt)(nil) + _ Stmt = (*BlockStmt)(nil) + + _ Expr = (*IdentifierExpr)(nil) + _ Expr = (*LiteralExpr)(nil) + _ Expr = (*ArrayExpr)(nil) + _ Expr = (*ObjectExpr)(nil) + _ Expr = (*AccessExpr)(nil) + _ Expr = (*IndexExpr)(nil) + _ Expr = (*CallExpr)(nil) + _ Expr = (*UnaryExpr)(nil) + _ Expr = (*BinaryExpr)(nil) + _ Expr = (*ParenExpr)(nil) +) + +func (n *File) astNode() {} +func (n Body) astNode() {} +func (n CommentGroup) astNode() {} +func (n *Comment) astNode() {} +func (n *AttributeStmt) astNode() {} +func (n *BlockStmt) astNode() {} +func (n *Ident) astNode() {} +func (n *IdentifierExpr) astNode() {} +func (n *LiteralExpr) astNode() {} +func (n *ArrayExpr) astNode() {} +func (n *ObjectExpr) astNode() {} +func (n *AccessExpr) astNode() {} +func (n *IndexExpr) astNode() {} +func (n *CallExpr) astNode() {} +func (n *UnaryExpr) astNode() {} +func (n *BinaryExpr) astNode() {} +func (n *ParenExpr) astNode() {} + +func (n *AttributeStmt) astStmt() {} +func (n *BlockStmt) astStmt() {} + +func (n *IdentifierExpr) astExpr() {} +func (n *LiteralExpr) astExpr() {} +func (n *ArrayExpr) astExpr() {} +func (n *ObjectExpr) astExpr() {} +func (n *AccessExpr) astExpr() {} +func (n *IndexExpr) astExpr() {} +func (n *CallExpr) astExpr() {} +func (n *UnaryExpr) astExpr() {} +func (n *BinaryExpr) astExpr() {} +func (n *ParenExpr) astExpr() {} + +// StartPos returns the position of the first character belonging to a Node. 
+func StartPos(n Node) token.Pos { + if n == nil || reflect.ValueOf(n).IsZero() { + return token.NoPos + } + switch n := n.(type) { + case *File: + return StartPos(n.Body) + case Body: + if len(n) == 0 { + return token.NoPos + } + return StartPos(n[0]) + case CommentGroup: + if len(n) == 0 { + return token.NoPos + } + return StartPos(n[0]) + case *Comment: + return n.StartPos + case *AttributeStmt: + return StartPos(n.Name) + case *BlockStmt: + return n.NamePos + case *Ident: + return n.NamePos + case *IdentifierExpr: + return StartPos(n.Ident) + case *LiteralExpr: + return n.ValuePos + case *ArrayExpr: + return n.LBrackPos + case *ObjectExpr: + return n.LCurlyPos + case *AccessExpr: + return StartPos(n.Value) + case *IndexExpr: + return StartPos(n.Value) + case *CallExpr: + return StartPos(n.Value) + case *UnaryExpr: + return n.KindPos + case *BinaryExpr: + return StartPos(n.Left) + case *ParenExpr: + return n.LParenPos + default: + panic(fmt.Sprintf("Unhandled Node type %T", n)) + } +} + +// EndPos returns the position of the final character in a Node. 
+func EndPos(n Node) token.Pos { + if n == nil || reflect.ValueOf(n).IsZero() { + return token.NoPos + } + switch n := n.(type) { + case *File: + return EndPos(n.Body) + case Body: + if len(n) == 0 { + return token.NoPos + } + return EndPos(n[len(n)-1]) + case CommentGroup: + if len(n) == 0 { + return token.NoPos + } + return EndPos(n[len(n)-1]) + case *Comment: + return n.StartPos.Add(len(n.Text) - 1) + case *AttributeStmt: + return EndPos(n.Value) + case *BlockStmt: + return n.RCurlyPos + case *Ident: + return n.NamePos.Add(len(n.Name) - 1) + case *IdentifierExpr: + return EndPos(n.Ident) + case *LiteralExpr: + return n.ValuePos.Add(len(n.Value) - 1) + case *ArrayExpr: + return n.RBrackPos + case *ObjectExpr: + return n.RCurlyPos + case *AccessExpr: + return EndPos(n.Name) + case *IndexExpr: + return n.RBrackPos + case *CallExpr: + return n.RParenPos + case *UnaryExpr: + return EndPos(n.Value) + case *BinaryExpr: + return EndPos(n.Right) + case *ParenExpr: + return n.RParenPos + default: + panic(fmt.Sprintf("Unhandled Node type %T", n)) + } +} + +// GetBlockName retrieves the "." delimited block name. +func (block *BlockStmt) GetBlockName() string { + return strings.Join(block.Name, ".") +} diff --git a/syntax/ast/walk.go b/syntax/ast/walk.go new file mode 100644 index 0000000000..df3f82d9a3 --- /dev/null +++ b/syntax/ast/walk.go @@ -0,0 +1,73 @@ +package ast + +import "fmt" + +// A Visitor has its Visit method invoked for each node encountered by Walk. If +// the resulting visitor w is not nil, Walk visits each of the children of node +// with the visitor w, followed by a call of w.Visit(nil). +type Visitor interface { + Visit(node Node) (w Visitor) +} + +// Walk traverses an AST in depth-first order: it starts by calling +// v.Visit(node); node must not be nil. If the visitor w returned by +// v.Visit(node) is not nil, Walk is invoked recursively with visitor w for +// each of the non-nil children of node, followed by a call of w.Visit(nil). 
+func Walk(v Visitor, node Node) { + if v = v.Visit(node); v == nil { + return + } + + // Walk children. The order of the cases matches the declared order of nodes + // in ast.go. + switch n := node.(type) { + case *File: + Walk(v, n.Body) + case Body: + for _, s := range n { + Walk(v, s) + } + case *AttributeStmt: + Walk(v, n.Name) + Walk(v, n.Value) + case *BlockStmt: + Walk(v, n.Body) + case *Ident: + // Nothing to do + case *IdentifierExpr: + Walk(v, n.Ident) + case *LiteralExpr: + // Nothing to do + case *ArrayExpr: + for _, e := range n.Elements { + Walk(v, e) + } + case *ObjectExpr: + for _, f := range n.Fields { + Walk(v, f.Name) + Walk(v, f.Value) + } + case *AccessExpr: + Walk(v, n.Value) + Walk(v, n.Name) + case *IndexExpr: + Walk(v, n.Value) + Walk(v, n.Index) + case *CallExpr: + Walk(v, n.Value) + for _, a := range n.Args { + Walk(v, a) + } + case *UnaryExpr: + Walk(v, n.Value) + case *BinaryExpr: + Walk(v, n.Left) + Walk(v, n.Right) + case *ParenExpr: + Walk(v, n.Inner) + default: + panic(fmt.Sprintf("river/ast: unexpected node type %T", n)) + } + + v.Visit(nil) +} diff --git a/syntax/cmd/riverfmt/main.go b/syntax/cmd/riverfmt/main.go new file mode 100644 index 0000000000..d7b4433f5a --- /dev/null +++ b/syntax/cmd/riverfmt/main.go @@ -0,0 +1,103 @@ +package main + +import ( + "bytes" + "errors" + "flag" + "fmt" + "io" + "os" + + "github.com/grafana/river/diag" + "github.com/grafana/river/parser" + "github.com/grafana/river/printer" +) + +func main() { + err := run() + + var diags diag.Diagnostics + if errors.As(err, &diags) { + for _, diag := range diags { + fmt.Fprintln(os.Stderr, diag) + } + os.Exit(1) + } else if err != nil { + fmt.Fprintf(os.Stderr, "error: %s\n", err) + os.Exit(1) + } +} + +func run() error { + var ( + write bool + ) + + fs := flag.NewFlagSet("riverfmt", flag.ExitOnError) + fs.BoolVar(&write, "w", write, "write result to (source) file instead of stdout") + + if err := fs.Parse(os.Args[1:]); err != nil { + return err + } + + args 
:= fs.Args() + switch len(args) { + case 0: + if write { + return fmt.Errorf("cannot use -w with standard input") + } + return format("", nil, os.Stdin, write) + + case 1: + fi, err := os.Stat(args[0]) + if err != nil { + return err + } + if fi.IsDir() { + return fmt.Errorf("cannot format a directory") + } + f, err := os.Open(args[0]) + if err != nil { + return err + } + defer f.Close() + return format(args[0], fi, f, write) + + default: + return fmt.Errorf("can only format one file") + } +} + +func format(filename string, fi os.FileInfo, r io.Reader, write bool) error { + bb, err := io.ReadAll(r) + if err != nil { + return err + } + + f, err := parser.ParseFile(filename, bb) + if err != nil { + return err + } + + var buf bytes.Buffer + if err := printer.Fprint(&buf, f); err != nil { + return err + } + + // Add a newline at the end + _, _ = buf.Write([]byte{'\n'}) + + if !write { + _, err := io.Copy(os.Stdout, &buf) + return err + } + + wf, err := os.OpenFile(filename, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, fi.Mode().Perm()) + if err != nil { + return err + } + defer wf.Close() + + _, err = io.Copy(wf, &buf) + return err +} diff --git a/syntax/diag/diag.go b/syntax/diag/diag.go new file mode 100644 index 0000000000..a49487af61 --- /dev/null +++ b/syntax/diag/diag.go @@ -0,0 +1,95 @@ +// Package diag exposes error types used throughout River and a method to +// pretty-print them to the screen. +package diag + +import ( + "fmt" + + "github.com/grafana/river/token" +) + +// Severity denotes the severity level of a diagnostic. The zero value of +// severity is invalid. +type Severity int + +// Supported severity levels. +const ( + SeverityLevelWarn Severity = iota + 1 + SeverityLevelError +) + +// Diagnostic is an individual diagnostic message. Diagnostic messages can have +// different levels of severities. +type Diagnostic struct { + // Severity holds the severity level of this Diagnostic. 
+ Severity Severity + + // StartPos refers to a position in a file where this Diagnostic starts. + StartPos token.Position + + // EndPos refers to an optional position in a file where this Diagnostic + // ends. If EndPos is the zero value, the Diagnostic should be treated as + // only covering a single character (i.e., StartPos == EndPos). + // + // When defined, EndPos must have the same Filename value as the StartPos. + EndPos token.Position + + Message string + Value string +} + +// As allows d to be interpreted as a list of Diagnostics. +func (d Diagnostic) As(v interface{}) bool { + switch v := v.(type) { + case *Diagnostics: + *v = Diagnostics{d} + return true + } + + return false +} + +// Error implements error. +func (d Diagnostic) Error() string { + return fmt.Sprintf("%s: %s", d.StartPos, d.Message) +} + +// Diagnostics is a collection of diagnostic messages. +type Diagnostics []Diagnostic + +// Add adds an individual Diagnostic to the diagnostics list. +func (ds *Diagnostics) Add(d Diagnostic) { + *ds = append(*ds, d) +} + +// Error implements error. +func (ds Diagnostics) Error() string { + switch len(ds) { + case 0: + return "no errors" + case 1: + return ds[0].Error() + default: + return fmt.Sprintf("%s (and %d more diagnostics)", ds[0], len(ds)-1) + } +} + +// ErrorOrNil returns an error interface if the list diagnostics is non-empty, +// nil otherwise. +func (ds Diagnostics) ErrorOrNil() error { + if len(ds) == 0 { + return nil + } + return ds +} + +// HasErrors reports whether the list of Diagnostics contain any error-level +// diagnostic. 
+func (ds Diagnostics) HasErrors() bool { + for _, d := range ds { + if d.Severity == SeverityLevelError { + return true + } + } + return false +} diff --git a/syntax/diag/printer.go b/syntax/diag/printer.go new file mode 100644 index 0000000000..03994d68cf --- /dev/null +++ b/syntax/diag/printer.go @@ -0,0 +1,266 @@ +package diag + +import ( + "bufio" + "fmt" + "io" + "strconv" + "strings" + + "github.com/fatih/color" + "github.com/grafana/river/token" +) + +const tabWidth = 4 + +// PrinterConfig controls different settings for the Printer. +type PrinterConfig struct { + // When Color is true, the printer will output with color and special + // formatting characters (such as underlines). + // + // This should be disabled when not printing to a terminal. + Color bool + + // ContextLinesBefore and ContextLinesAfter controls how many context lines + // before and after the range of the diagnostic are printed. + ContextLinesBefore, ContextLinesAfter int +} + +// A Printer pretty-prints Diagnostics. +type Printer struct { + cfg PrinterConfig +} + +// NewPrinter creates a new diagnostics Printer with the provided config. +func NewPrinter(cfg PrinterConfig) *Printer { + return &Printer{cfg: cfg} +} + +// Fprint creates a Printer with default settings and prints diagnostics to the +// provided writer. files is used to look up file contents by name for printing +// diagnostics context. files may be set to nil to avoid printing context. +func Fprint(w io.Writer, files map[string][]byte, diags Diagnostics) error { + p := NewPrinter(PrinterConfig{ + Color: false, + ContextLinesBefore: 1, + ContextLinesAfter: 1, + }) + return p.Fprint(w, files, diags) +} + +// Fprint pretty-prints errors to a writer. files is used to look up file +// contents by name when printing context. files may be nil to avoid printing +// context. 
+func (p *Printer) Fprint(w io.Writer, files map[string][]byte, diags Diagnostics) error { + // Create a buffered writer since we'll have many small calls to Write while + // we print errors. + // + // Buffers writers track the first write error received and will return it + // (if any) when flushing, so we can ignore write errors throughout the code + // until the very end. + bw := bufio.NewWriter(w) + + for i, diag := range diags { + p.printDiagnosticHeader(bw, diag) + + // If there's no ending position, set the ending position to be the same as + // the start. + if !diag.EndPos.Valid() { + diag.EndPos = diag.StartPos + } + + // We can print the file context if it was found. + fileContents, foundFile := files[diag.StartPos.Filename] + if foundFile && diag.StartPos.Filename == diag.EndPos.Filename { + p.printRange(bw, fileContents, diag) + } + + // Print a blank line to separate diagnostics. + if i+1 < len(diags) { + fmt.Fprintf(bw, "\n") + } + } + + return bw.Flush() +} + +func (p *Printer) printDiagnosticHeader(w io.Writer, diag Diagnostic) { + if p.cfg.Color { + switch diag.Severity { + case SeverityLevelError: + cw := color.New(color.FgRed, color.Bold) + _, _ = cw.Fprintf(w, "Error: ") + case SeverityLevelWarn: + cw := color.New(color.FgYellow, color.Bold) + _, _ = cw.Fprintf(w, "Warning: ") + } + + cw := color.New(color.Bold) + _, _ = cw.Fprintf(w, "%s: %s\n", diag.StartPos, diag.Message) + return + } + + switch diag.Severity { + case SeverityLevelError: + _, _ = fmt.Fprintf(w, "Error: ") + case SeverityLevelWarn: + _, _ = fmt.Fprintf(w, "Warning: ") + } + fmt.Fprintf(w, "%s: %s\n", diag.StartPos, diag.Message) +} + +func (p *Printer) printRange(w io.Writer, file []byte, diag Diagnostic) { + var ( + start = diag.StartPos + end = diag.EndPos + ) + + fmt.Fprintf(w, "\n") + + var ( + lines = strings.Split(string(file), "\n") + + startLine = max(start.Line-p.cfg.ContextLinesBefore, 1) + endLine = min(end.Line+p.cfg.ContextLinesAfter, len(lines)) + + multiline = 
end.Line-start.Line > 0 + ) + + prefixWidth := len(strconv.Itoa(endLine)) + + for lineNum := startLine; lineNum <= endLine; lineNum++ { + line := lines[lineNum-1] + + // Print line number and margin. + printPaddedNumber(w, prefixWidth, lineNum) + fmt.Fprintf(w, " | ") + + if multiline { + // Use 0 for the column number so we never consider the starting line for + // showing |. + if inRange(lineNum, 0, start, end) { + fmt.Fprint(w, "| ") + } else { + fmt.Fprint(w, " ") + } + } + + // Print the line, but filter out any \r and replace tabs with spaces. + for _, ch := range line { + if ch == '\r' { + continue + } + if ch == '\t' || ch == '\v' { + printCh(w, tabWidth, ' ') + continue + } + fmt.Fprintf(w, "%c", ch) + } + + fmt.Fprintf(w, "\n") + + // Print the focus indicator if we're on a line that needs it. + // + // The focus indicator line must preserve whitespace present in the line + // above it prior to the focus '^' characters. Tab characters are replaced + // with spaces for consistent printing. + if lineNum == start.Line || (multiline && lineNum == end.Line) { + printCh(w, prefixWidth, ' ') // Add empty space where line number would be + + // Print the margin after the blank line number. On multi-line errors, + // the arrow is printed all the way to the margin, with straight + // lines going down in between the lines. + switch { + case multiline && lineNum == start.Line: + // |_ would look like an incorrect right angle, so the second bar + // is dropped. + fmt.Fprintf(w, " | _") + case multiline && lineNum == end.Line: + fmt.Fprintf(w, " | |_") + default: + fmt.Fprintf(w, " | ") + } + + p.printFocus(w, line, lineNum, diag) + fmt.Fprintf(w, "\n") + } + } +} + +// printFocus prints the focus indicator for the line number specified by line. +// The contents of the line should be represented by data so whitespace can be +// retained (injecting spaces where a tab should be, etc.). 
+func (p *Printer) printFocus(w io.Writer, data string, line int, diag Diagnostic) { + for i, ch := range data { + column := i + 1 + + if line == diag.EndPos.Line && column > diag.EndPos.Column { + // Stop printing the formatting line after printing all the ^. + break + } + + blank := byte(' ') + if diag.EndPos.Line-diag.StartPos.Line > 0 { + blank = byte('_') + } + + switch { + case ch == '\t' || ch == '\v': + printCh(w, tabWidth, blank) + case inRange(line, column, diag.StartPos, diag.EndPos): + fmt.Fprintf(w, "%c", '^') + default: + // Print a space. + fmt.Fprintf(w, "%c", blank) + } + } +} + +func inRange(line, col int, start, end token.Position) bool { + if line < start.Line || line > end.Line { + return false + } + + switch line { + case start.Line: + // If the current line is on the starting line, we have to be past the + // starting column. + return col >= start.Column + case end.Line: + // If the current line is on the ending line, we have to be before the + // final column. + return col <= end.Column + default: + // Otherwise, every column across all the lines in between + // is in the range. 
+ return true + } +} + +func printPaddedNumber(w io.Writer, width int, num int) { + numStr := strconv.Itoa(num) + for i := 0; i < width-len(numStr); i++ { + _, _ = w.Write([]byte{' '}) + } + _, _ = w.Write([]byte(numStr)) +} + +func printCh(w io.Writer, count int, ch byte) { + for i := 0; i < count; i++ { + _, _ = w.Write([]byte{ch}) + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/syntax/diag/printer_test.go b/syntax/diag/printer_test.go new file mode 100644 index 0000000000..4558e666ef --- /dev/null +++ b/syntax/diag/printer_test.go @@ -0,0 +1,222 @@ +package diag_test + +import ( + "bytes" + "fmt" + "testing" + + "github.com/grafana/river/diag" + "github.com/grafana/river/token" + "github.com/stretchr/testify/require" +) + +func TestFprint(t *testing.T) { + // In all tests below, the filename is "testfile" and the severity is an + // error. + + tt := []struct { + name string + input string + start, end token.Position + diag diag.Diagnostic + expect string + }{ + { + name: "highlight on same line", + start: token.Position{Line: 2, Column: 2}, + end: token.Position{Line: 2, Column: 5}, + input: `test.block "label" { + attr = 1 + other_attr = 2 +}`, + expect: `Error: testfile:2:2: synthetic error + +1 | test.block "label" { +2 | attr = 1 + | ^^^^ +3 | other_attr = 2 +`, + }, + + { + name: "end positions should be optional", + start: token.Position{Line: 1, Column: 4}, + input: `foo,bar`, + expect: `Error: testfile:1:4: synthetic error + +1 | foo,bar + | ^ +`, + }, + + { + name: "padding should be inserted to fit line numbers of different lengths", + start: token.Position{Line: 9, Column: 1}, + end: token.Position{Line: 9, Column: 6}, + input: `LINE_1 +LINE_2 +LINE_3 +LINE_4 +LINE_5 +LINE_6 +LINE_7 +LINE_8 +LINE_9 +LINE_10 +LINE_11`, + expect: `Error: testfile:9:1: synthetic error + + 8 | LINE_8 + 9 | LINE_9 + | ^^^^^^ +10 | LINE_10 +`, + }, + + { + name: 
"errors which cross multiple lines can be printed from start of line", + start: token.Position{Line: 2, Column: 1}, + end: token.Position{Line: 6, Column: 7}, + input: `FILE_BEGIN +START +TEXT + TEXT + TEXT + DONE after +FILE_END`, + expect: `Error: testfile:2:1: synthetic error + +1 | FILE_BEGIN +2 | START + | _^^^^^ +3 | | TEXT +4 | | TEXT +5 | | TEXT +6 | | DONE after + | |_____________^^^^ +7 | FILE_END +`, + }, + + { + name: "errors which cross multiple lines can be printed from middle of line", + start: token.Position{Line: 2, Column: 8}, + end: token.Position{Line: 6, Column: 7}, + input: `FILE_BEGIN +before START +TEXT + TEXT + TEXT + DONE after +FILE_END`, + expect: `Error: testfile:2:8: synthetic error + +1 | FILE_BEGIN +2 | before START + | ________^^^^^ +3 | | TEXT +4 | | TEXT +5 | | TEXT +6 | | DONE after + | |_____________^^^^ +7 | FILE_END +`, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + files := map[string][]byte{ + "testfile": []byte(tc.input), + } + + tc.start.Filename = "testfile" + tc.end.Filename = "testfile" + + diags := diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: tc.start, + EndPos: tc.end, + Message: "synthetic error", + }} + + var buf bytes.Buffer + _ = diag.Fprint(&buf, files, diags) + requireEqualStrings(t, tc.expect, buf.String()) + }) + } +} + +func TestFprint_MultipleDiagnostics(t *testing.T) { + fileA := `old_field = 15 +3 & 4` + fileB := `old_field = 22` + + files := map[string][]byte{ + "file_a": []byte(fileA), + "file_b": []byte(fileB), + } + + diags := diag.Diagnostics{ + { + Severity: diag.SeverityLevelWarn, + StartPos: token.Position{Filename: "file_a", Line: 1, Column: 1}, + EndPos: token.Position{Filename: "file_a", Line: 1, Column: 9}, + Message: "old_field is deprecated", + }, + { + Severity: diag.SeverityLevelError, + StartPos: token.Position{Filename: "file_a", Line: 2, Column: 3}, + Message: "unrecognized operator &", + }, + { + Severity: diag.SeverityLevelWarn, + 
StartPos: token.Position{Filename: "file_b", Line: 1, Column: 1}, + EndPos: token.Position{Filename: "file_b", Line: 1, Column: 9}, + Message: "old_field is deprecated", + }, + } + + expect := `Warning: file_a:1:1: old_field is deprecated + +1 | old_field = 15 + | ^^^^^^^^^ +2 | 3 & 4 + +Error: file_a:2:3: unrecognized operator & + +1 | old_field = 15 +2 | 3 & 4 + | ^ + +Warning: file_b:1:1: old_field is deprecated + +1 | old_field = 22 + | ^^^^^^^^^ +` + + var buf bytes.Buffer + _ = diag.Fprint(&buf, files, diags) + requireEqualStrings(t, expect, buf.String()) +} + +// requireEqualStrings is like require.Equal with two strings but it +// pretty-prints multiline strings to make it easier to compare. +func requireEqualStrings(t *testing.T, expected, actual string) { + if expected == actual { + return + } + + msg := fmt.Sprintf( + "Not equal:\n"+ + "raw expected: %#v\n"+ + "raw actual : %#v\n"+ + "\n"+ + "expected:\n%s\n"+ + "actual:\n%s\n", + expected, actual, + expected, actual, + ) + + require.Fail(t, msg) +} diff --git a/syntax/encoding/riverjson/riverjson.go b/syntax/encoding/riverjson/riverjson.go new file mode 100644 index 0000000000..fc69882918 --- /dev/null +++ b/syntax/encoding/riverjson/riverjson.go @@ -0,0 +1,313 @@ +// Package riverjson encodes River as JSON. +package riverjson + +import ( + "encoding/json" + "fmt" + "reflect" + "sort" + "strings" + + "github.com/grafana/river/internal/reflectutil" + "github.com/grafana/river/internal/rivertags" + "github.com/grafana/river/internal/value" + "github.com/grafana/river/token/builder" +) + +var goRiverDefaulter = reflect.TypeOf((*value.Defaulter)(nil)).Elem() + +// MarshalBody marshals the provided Go value to a JSON representation of +// River. MarshalBody panics if not given a struct with River tags or a map[string]any. 
+func MarshalBody(val interface{}) ([]byte, error) { + rv := reflect.ValueOf(val) + return json.Marshal(encodeStructAsBody(rv)) +} + +func encodeStructAsBody(rv reflect.Value) jsonBody { + for rv.Kind() == reflect.Pointer { + if rv.IsNil() { + return jsonBody{} + } + rv = rv.Elem() + } + + if rv.Kind() == reflect.Invalid { + return jsonBody{} + } + + body := jsonBody{} + + switch rv.Kind() { + case reflect.Struct: + fields := rivertags.Get(rv.Type()) + defaults := reflect.New(rv.Type()).Elem() + if defaults.CanAddr() && defaults.Addr().Type().Implements(goRiverDefaulter) { + defaults.Addr().Interface().(value.Defaulter).SetToDefault() + } + + for _, field := range fields { + fieldVal := reflectutil.Get(rv, field) + fieldValDefault := reflectutil.Get(defaults, field) + + isEqual := fieldVal.Comparable() && fieldVal.Equal(fieldValDefault) + isZero := fieldValDefault.IsZero() && fieldVal.IsZero() + + if field.IsOptional() && (isEqual || isZero) { + continue + } + + body = append(body, encodeFieldAsStatements(nil, field, fieldVal)...) + } + + case reflect.Map: + if rv.Type().Key().Kind() != reflect.String { + panic("river/encoding/riverjson: unsupported map type; expected map[string]T, got " + rv.Type().String()) + } + + iter := rv.MapRange() + for iter.Next() { + mapKey, mapValue := iter.Key(), iter.Value() + + body = append(body, jsonAttr{ + Name: mapKey.String(), + Type: "attr", + Value: buildJSONValue(value.FromRaw(mapValue)), + }) + } + + default: + panic(fmt.Sprintf("river/encoding/riverjson: can only encode struct or map[string]T values to bodies, got %s", rv.Kind())) + } + + return body +} + +// encodeFieldAsStatements encodes an individual field from a struct as a set +// of statements. One field may map to multiple statements in the case of a +// slice of blocks. 
+func encodeFieldAsStatements(prefix []string, field rivertags.Field, fieldValue reflect.Value) []jsonStatement { + fieldName := strings.Join(field.Name, ".") + + for fieldValue.Kind() == reflect.Pointer { + if fieldValue.IsNil() { + break + } + fieldValue = fieldValue.Elem() + } + + switch { + case field.IsAttr(): + return []jsonStatement{jsonAttr{ + Name: fieldName, + Type: "attr", + Value: buildJSONValue(value.FromRaw(fieldValue)), + }} + + case field.IsBlock(): + fullName := mergeStringSlice(prefix, field.Name) + + switch { + case fieldValue.Kind() == reflect.Map: + // Iterate over the map and add each element as an attribute into it. + + if fieldValue.Type().Key().Kind() != reflect.String { + panic("river/encoding/riverjson: unsupported map type for block; expected map[string]T, got " + fieldValue.Type().String()) + } + + statements := []jsonStatement{} + + iter := fieldValue.MapRange() + for iter.Next() { + mapKey, mapValue := iter.Key(), iter.Value() + + statements = append(statements, jsonAttr{ + Name: mapKey.String(), + Type: "attr", + Value: buildJSONValue(value.FromRaw(mapValue)), + }) + } + + return []jsonStatement{jsonBlock{ + Name: strings.Join(fullName, "."), + Type: "block", + Body: statements, + }} + + case fieldValue.Kind() == reflect.Slice, fieldValue.Kind() == reflect.Array: + statements := []jsonStatement{} + + for i := 0; i < fieldValue.Len(); i++ { + elem := fieldValue.Index(i) + + // Recursively call encodeField for each element in the slice/array. + // The recursive call will hit the case below and add a new block for + // each field encountered. + statements = append(statements, encodeFieldAsStatements(prefix, field, elem)...) + } + + return statements + + case fieldValue.Kind() == reflect.Struct: + if fieldValue.IsZero() { + // It shouldn't be possible to have a required block which is unset, but + // we'll encode something anyway. 
+ return []jsonStatement{jsonBlock{ + Name: strings.Join(fullName, "."), + Type: "block", + + // Never set this to nil, since the API contract always expects blocks + // to have an array value for the body. + Body: []jsonStatement{}, + }} + } + + return []jsonStatement{jsonBlock{ + Name: strings.Join(fullName, "."), + Type: "block", + Label: getBlockLabel(fieldValue), + Body: encodeStructAsBody(fieldValue), + }} + } + + case field.IsEnum(): + // Blocks within an enum have a prefix set. + newPrefix := mergeStringSlice(prefix, field.Name) + + switch { + case fieldValue.Kind() == reflect.Slice, fieldValue.Kind() == reflect.Array: + statements := []jsonStatement{} + for i := 0; i < fieldValue.Len(); i++ { + statements = append(statements, encodeEnumElementToStatements(newPrefix, fieldValue.Index(i))...) + } + return statements + + default: + panic(fmt.Sprintf("river/encoding/riverjson: unrecognized enum kind %s", fieldValue.Kind())) + } + } + + return nil +} + +func mergeStringSlice(a, b []string) []string { + if len(a) == 0 { + return b + } else if len(b) == 0 { + return a + } + + res := make([]string, 0, len(a)+len(b)) + res = append(res, a...) + res = append(res, b...) + return res +} + +// getBlockLabel returns the label for a given block. +func getBlockLabel(rv reflect.Value) string { + tags := rivertags.Get(rv.Type()) + for _, tag := range tags { + if tag.Flags&rivertags.FlagLabel != 0 { + return reflectutil.Get(rv, tag).String() + } + } + + return "" +} + +func encodeEnumElementToStatements(prefix []string, enumElement reflect.Value) []jsonStatement { + for enumElement.Kind() == reflect.Pointer { + if enumElement.IsNil() { + return nil + } + enumElement = enumElement.Elem() + } + + fields := rivertags.Get(enumElement.Type()) + + statements := []jsonStatement{} + + // Find the first non-zero field and encode it. 
+ for _, field := range fields { + fieldVal := reflectutil.Get(enumElement, field) + if !fieldVal.IsValid() || fieldVal.IsZero() { + continue + } + + statements = append(statements, encodeFieldAsStatements(prefix, field, fieldVal)...) + break + } + + return statements +} + +// MarshalValue marshals the provided Go value to a JSON representation of +// River. +func MarshalValue(val interface{}) ([]byte, error) { + riverValue := value.Encode(val) + return json.Marshal(buildJSONValue(riverValue)) +} + +func buildJSONValue(v value.Value) jsonValue { + if tk, ok := v.Interface().(builder.Tokenizer); ok { + return jsonValue{ + Type: "capsule", + Value: tk.RiverTokenize()[0].Lit, + } + } + + switch v.Type() { + case value.TypeNull: + return jsonValue{Type: "null"} + + case value.TypeNumber: + return jsonValue{Type: "number", Value: v.Number().Float()} + + case value.TypeString: + return jsonValue{Type: "string", Value: v.Text()} + + case value.TypeBool: + return jsonValue{Type: "bool", Value: v.Bool()} + + case value.TypeArray: + elements := []interface{}{} + + for i := 0; i < v.Len(); i++ { + element := v.Index(i) + + elements = append(elements, buildJSONValue(element)) + } + + return jsonValue{Type: "array", Value: elements} + + case value.TypeObject: + keys := v.Keys() + + // If v isn't an ordered object (i.e., a go map), sort the keys so they + // have a deterministic print order. 
+ if !v.OrderedKeys() { + sort.Strings(keys) + } + + fields := []jsonObjectField{} + + for i := 0; i < len(keys); i++ { + field, _ := v.Key(keys[i]) + + fields = append(fields, jsonObjectField{ + Key: keys[i], + Value: buildJSONValue(field), + }) + } + + return jsonValue{Type: "object", Value: fields} + + case value.TypeFunction: + return jsonValue{Type: "function", Value: v.Describe()} + + case value.TypeCapsule: + return jsonValue{Type: "capsule", Value: v.Describe()} + + default: + panic(fmt.Sprintf("river/encoding/riverjson: unrecognized value type %q", v.Type())) + } +} diff --git a/syntax/encoding/riverjson/riverjson_test.go b/syntax/encoding/riverjson/riverjson_test.go new file mode 100644 index 0000000000..0eeb321b59 --- /dev/null +++ b/syntax/encoding/riverjson/riverjson_test.go @@ -0,0 +1,363 @@ +package riverjson_test + +import ( + "testing" + + river "github.com/grafana/river" + "github.com/grafana/river/encoding/riverjson" + "github.com/grafana/river/rivertypes" + "github.com/stretchr/testify/require" +) + +func TestValues(t *testing.T) { + tt := []struct { + name string + input interface{} + expectJSON string + }{ + { + name: "null", + input: nil, + expectJSON: `{ "type": "null", "value": null }`, + }, + { + name: "number", + input: 54, + expectJSON: `{ "type": "number", "value": 54 }`, + }, + { + name: "string", + input: "Hello, world!", + expectJSON: `{ "type": "string", "value": "Hello, world!" 
}`, + }, + { + name: "bool", + input: true, + expectJSON: `{ "type": "bool", "value": true }`, + }, + { + name: "simple array", + input: []int{0, 1, 2, 3, 4}, + expectJSON: `{ + "type": "array", + "value": [ + { "type": "number", "value": 0 }, + { "type": "number", "value": 1 }, + { "type": "number", "value": 2 }, + { "type": "number", "value": 3 }, + { "type": "number", "value": 4 } + ] + }`, + }, + { + name: "nested array", + input: []interface{}{"testing", []int{0, 1, 2}}, + expectJSON: `{ + "type": "array", + "value": [ + { "type": "string", "value": "testing" }, + { + "type": "array", + "value": [ + { "type": "number", "value": 0 }, + { "type": "number", "value": 1 }, + { "type": "number", "value": 2 } + ] + } + ] + }`, + }, + { + name: "object", + input: map[string]any{"foo": "bar", "fizz": "buzz", "year": 2023}, + expectJSON: `{ + "type": "object", + "value": [ + { "key": "fizz", "value": { "type": "string", "value": "buzz" }}, + { "key": "foo", "value": { "type": "string", "value": "bar" }}, + { "key": "year", "value": { "type": "number", "value": 2023 }} + ] + }`, + }, + { + name: "function", + input: func(i int) int { return i * 2 }, + expectJSON: `{ "type": "function", "value": "function" }`, + }, + { + name: "capsule", + input: rivertypes.Secret("foo"), + expectJSON: `{ "type": "capsule", "value": "(secret)" }`, + }, + { + // nil arrays and objects must always be [] instead of null as that's + // what the API definition says they should be. + name: "nil array", + input: ([]any)(nil), + expectJSON: `{ "type": "array", "value": [] }`, + }, + { + // nil arrays and objects must always be [] instead of null as that's + // what the API definition says they should be. 
+ name: "nil object", + input: (map[string]any)(nil), + expectJSON: `{ "type": "object", "value": [] }`, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + actual, err := riverjson.MarshalValue(tc.input) + require.NoError(t, err) + require.JSONEq(t, tc.expectJSON, string(actual)) + }) + } +} + +func TestBlock(t *testing.T) { + // Zero values should be omitted from result. + + val := testBlock{ + Number: 5, + Array: []any{1, 2, 3}, + Labeled: []labeledBlock{ + { + TestBlock: testBlock{Boolean: true}, + Label: "label_a", + }, + { + TestBlock: testBlock{String: "foo"}, + Label: "label_b", + }, + }, + Blocks: []testBlock{ + {String: "hello"}, + {String: "world"}, + }, + } + + expect := `[ + { + "name": "number", + "type": "attr", + "value": { "type": "number", "value": 5 } + }, + { + "name": "array", + "type": "attr", + "value": { + "type": "array", + "value": [ + { "type": "number", "value": 1 }, + { "type": "number", "value": 2 }, + { "type": "number", "value": 3 } + ] + } + }, + { + "name": "labeled_block", + "type": "block", + "label": "label_a", + "body": [{ + "name": "boolean", + "type": "attr", + "value": { "type": "bool", "value": true } + }] + }, + { + "name": "labeled_block", + "type": "block", + "label": "label_b", + "body": [{ + "name": "string", + "type": "attr", + "value": { "type": "string", "value": "foo" } + }] + }, + { + "name": "inner_block", + "type": "block", + "body": [{ + "name": "string", + "type": "attr", + "value": { "type": "string", "value": "hello" } + }] + }, + { + "name": "inner_block", + "type": "block", + "body": [{ + "name": "string", + "type": "attr", + "value": { "type": "string", "value": "world" } + }] + } + ]` + + actual, err := riverjson.MarshalBody(val) + require.NoError(t, err) + require.JSONEq(t, expect, string(actual)) +} + +func TestBlock_Empty_Required_Block_Slice(t *testing.T) { + type wrapper struct { + Blocks []testBlock `river:"some_block,block"` + } + + tt := []struct { + name string + val any 
+ }{ + {"nil block slice", wrapper{Blocks: nil}}, + {"empty block slice", wrapper{Blocks: []testBlock{}}}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expect := `[]` + + actual, err := riverjson.MarshalBody(tc.val) + require.NoError(t, err) + require.JSONEq(t, expect, string(actual)) + }) + } +} + +type testBlock struct { + Number int `river:"number,attr,optional"` + String string `river:"string,attr,optional"` + Boolean bool `river:"boolean,attr,optional"` + Array []any `river:"array,attr,optional"` + Object map[string]any `river:"object,attr,optional"` + + Labeled []labeledBlock `river:"labeled_block,block,optional"` + Blocks []testBlock `river:"inner_block,block,optional"` +} + +type labeledBlock struct { + TestBlock testBlock `river:",squash"` + Label string `river:",label"` +} + +func TestNilBody(t *testing.T) { + actual, err := riverjson.MarshalBody(nil) + require.NoError(t, err) + require.JSONEq(t, `[]`, string(actual)) +} + +func TestEmptyBody(t *testing.T) { + type block struct{} + + actual, err := riverjson.MarshalBody(block{}) + require.NoError(t, err) + require.JSONEq(t, `[]`, string(actual)) +} + +func TestHideDefaults(t *testing.T) { + tt := []struct { + name string + val defaultsBlock + expectJSON string + }{ + { + name: "no defaults", + val: defaultsBlock{ + Name: "Jane", + Age: 41, + }, + expectJSON: `[ + { "name": "name", "type": "attr", "value": { "type": "string", "value": "Jane" }}, + { "name": "age", "type": "attr", "value": { "type": "number", "value": 41 }} + ]`, + }, + { + name: "some defaults", + val: defaultsBlock{ + Name: "John Doe", + Age: 41, + }, + expectJSON: `[ + { "name": "age", "type": "attr", "value": { "type": "number", "value": 41 }} + ]`, + }, + { + name: "all defaults", + val: defaultsBlock{ + Name: "John Doe", + Age: 35, + }, + expectJSON: `[]`, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + actual, err := riverjson.MarshalBody(tc.val) + require.NoError(t, err) + 
require.JSONEq(t, tc.expectJSON, string(actual)) + }) + } +} + +type defaultsBlock struct { + Name string `river:"name,attr,optional"` + Age int `river:"age,attr,optional"` +} + +var _ river.Defaulter = (*defaultsBlock)(nil) + +func (d *defaultsBlock) SetToDefault() { + *d = defaultsBlock{ + Name: "John Doe", + Age: 35, + } +} + +func TestMapBlocks(t *testing.T) { + type block struct { + Value map[string]any `river:"block,block,optional"` + } + val := block{Value: map[string]any{"field": "value"}} + + expect := `[{ + "name": "block", + "type": "block", + "body": [{ + "name": "field", + "type": "attr", + "value": { "type": "string", "value": "value" } + }] + }]` + + bb, err := riverjson.MarshalBody(val) + require.NoError(t, err) + require.JSONEq(t, expect, string(bb)) +} + +func TestRawMap(t *testing.T) { + val := map[string]any{"field": "value"} + + expect := `[{ + "name": "field", + "type": "attr", + "value": { "type": "string", "value": "value" } + }]` + + bb, err := riverjson.MarshalBody(val) + require.NoError(t, err) + require.JSONEq(t, expect, string(bb)) +} + +func TestRawMap_Capsule(t *testing.T) { + val := map[string]any{"capsule": rivertypes.Secret("foo")} + + expect := `[{ + "name": "capsule", + "type": "attr", + "value": { "type": "capsule", "value": "(secret)" } + }]` + + bb, err := riverjson.MarshalBody(val) + require.NoError(t, err) + require.JSONEq(t, expect, string(bb)) +} diff --git a/syntax/encoding/riverjson/types.go b/syntax/encoding/riverjson/types.go new file mode 100644 index 0000000000..3170331e46 --- /dev/null +++ b/syntax/encoding/riverjson/types.go @@ -0,0 +1,41 @@ +package riverjson + +// Various concrete types used to marshal River values. +type ( + // jsonStatement is a statement within a River body. + jsonStatement interface{ isStatement() } + + // A jsonBody is a collection of statements. + jsonBody = []jsonStatement + + // jsonBlock represents a River block as JSON. jsonBlock is a jsonStatement. 
+ jsonBlock struct { + Name string `json:"name"` + Type string `json:"type"` // Always "block" + Label string `json:"label,omitempty"` + Body []jsonStatement `json:"body"` + } + + // jsonAttr represents a River attribute as JSON. jsonAttr is a + // jsonStatement. + jsonAttr struct { + Name string `json:"name"` + Type string `json:"type"` // Always "attr" + Value jsonValue `json:"value"` + } + + // jsonValue represents a single River value as JSON. + jsonValue struct { + Type string `json:"type"` + Value interface{} `json:"value"` + } + + // jsonObjectField represents a field within a River object. + jsonObjectField struct { + Key string `json:"key"` + Value interface{} `json:"value"` + } +) + +func (jsonBlock) isStatement() {} +func (jsonAttr) isStatement() {} diff --git a/syntax/go.mod b/syntax/go.mod new file mode 100644 index 0000000000..6b85c2e5da --- /dev/null +++ b/syntax/go.mod @@ -0,0 +1,18 @@ +module github.com/grafana/river + +go 1.21.0 + +require ( + github.com/fatih/color v1.15.0 + github.com/ohler55/ojg v1.20.1 + github.com/stretchr/testify v1.8.4 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.17 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.6.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/syntax/go.sum b/syntax/go.sum new file mode 100644 index 0000000000..a972b9838b --- /dev/null +++ b/syntax/go.sum @@ -0,0 +1,22 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= +github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod 
h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/ohler55/ojg v1.20.1 h1:Io65sHjMjYPI7yuhUr8VdNmIQdYU6asKeFhOs8xgBnY= +github.com/ohler55/ojg v1.20.1/go.mod h1:uHcD1ErbErC27Zhb5Df2jUjbseLLcmOCo6oxSr3jZxo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/syntax/internal/reflectutil/walk.go b/syntax/internal/reflectutil/walk.go new file mode 100644 index 0000000000..17cbb25d49 --- /dev/null +++ b/syntax/internal/reflectutil/walk.go @@ -0,0 +1,89 @@ +package reflectutil + +import ( + "reflect" + + "github.com/grafana/river/internal/rivertags" +) + +// GetOrAlloc returns the nested field of value corresponding to index. +// GetOrAlloc panics if not given a struct. 
// GetOrAllocIndex returns the nested field of value corresponding to index.
// GetOrAllocIndex panics if not given a struct.
//
// It is similar to [reflect.Value.FieldByIndex] but can handle traversing
// through nil pointers: every nil pointer met between two elements of index is
// replaced with a newly allocated value before descending into it. Because
// that allocation is stored with reflect.Value.Set, value has to be settable
// whenever an intermediate pointer is nil.
//
// The field selected by the last element of index is returned as-is; if it is
// a pointer it is neither dereferenced nor allocated.
func GetOrAllocIndex(value reflect.Value, index []int) reflect.Value {
	// Validate the kind first so that every call, including the single-index
	// fast path below, fails with the same descriptive panic.
	if value.Kind() != reflect.Struct {
		panic("GetOrAlloc must be given a Struct, but found " + value.Kind().String())
	}

	// Fast path: a direct field needs no pointer traversal.
	if len(index) == 1 {
		return value.Field(index[0])
	}

	for _, next := range index {
		value = deferencePointer(value).Field(next)
	}

	return value
}

// deferencePointer follows value through any number of pointer levels and
// returns the first non-pointer value. Nil pointers found on the way are set
// to a freshly allocated zero value of their element type.
func deferencePointer(value reflect.Value) reflect.Value {
	for value.Kind() == reflect.Pointer {
		if value.IsNil() {
			value.Set(reflect.New(value.Type().Elem()))
		}
		value = value.Elem()
	}

	return value
}
// getZero computes the type reached by following index from value's type and
// returns the zero value of that type. Pointer types met along the way are
// replaced by their element type before each field lookup. The result comes
// from reflect.Zero and is therefore not settable.
func getZero(value reflect.Value, index []int) reflect.Value {
	// elemOf strips every pointer layer from t.
	elemOf := func(t reflect.Type) reflect.Type {
		for t.Kind() == reflect.Pointer {
			t = t.Elem()
		}
		return t
	}

	cur := value.Type()
	for pos := range index {
		cur = elemOf(cur).Field(index[pos]).Type
	}

	return reflect.Zero(cur)
}
innerValue.CanSet()) + assert.Equal(t, reflect.String, innerValue.Kind()) +} diff --git a/syntax/internal/rivertags/rivertags.go b/syntax/internal/rivertags/rivertags.go new file mode 100644 index 0000000000..8186a644e0 --- /dev/null +++ b/syntax/internal/rivertags/rivertags.go @@ -0,0 +1,346 @@ +// Package rivertags decodes a struct type into river object +// and structural tags. +package rivertags + +import ( + "fmt" + "reflect" + "strings" +) + +// Flags is a bitmap of flags associated with a field on a struct. +type Flags uint + +// Valid flags. +const ( + FlagAttr Flags = 1 << iota // FlagAttr treats a field as attribute + FlagBlock // FlagBlock treats a field as a block + FlagEnum // FlagEnum treats a field as an enum of blocks + + FlagOptional // FlagOptional marks a field optional for decoding/encoding + FlagLabel // FlagLabel will store block labels in the field + FlagSquash // FlagSquash will expose inner fields from a struct as outer fields. +) + +// String returns the flags as a string. +func (f Flags) String() string { + attrs := make([]string, 0, 5) + + if f&FlagAttr != 0 { + attrs = append(attrs, "attr") + } + if f&FlagBlock != 0 { + attrs = append(attrs, "block") + } + if f&FlagEnum != 0 { + attrs = append(attrs, "enum") + } + if f&FlagOptional != 0 { + attrs = append(attrs, "optional") + } + if f&FlagLabel != 0 { + attrs = append(attrs, "label") + } + if f&FlagSquash != 0 { + attrs = append(attrs, "squash") + } + + return fmt.Sprintf("Flags(%s)", strings.Join(attrs, ",")) +} + +// GoString returns the %#v format of Flags. +func (f Flags) GoString() string { return f.String() } + +// Field is a tagged field within a struct. +type Field struct { + Name []string // Name of tagged field. + Index []int // Index into field. Use [reflectutil.GetOrAlloc] to retrieve a Value. + Flags Flags // Flags assigned to field. +} + +// Equals returns true if two fields are equal. 
+func (f Field) Equals(other Field) bool { + // Compare names + { + if len(f.Name) != len(other.Name) { + return false + } + + for i := 0; i < len(f.Name); i++ { + if f.Name[i] != other.Name[i] { + return false + } + } + } + + // Compare index. + { + if len(f.Index) != len(other.Index) { + return false + } + + for i := 0; i < len(f.Index); i++ { + if f.Index[i] != other.Index[i] { + return false + } + } + } + + // Finally, compare flags. + return f.Flags == other.Flags +} + +// IsAttr returns whether f is for an attribute. +func (f Field) IsAttr() bool { return f.Flags&FlagAttr != 0 } + +// IsBlock returns whether f is for a block. +func (f Field) IsBlock() bool { return f.Flags&FlagBlock != 0 } + +// IsEnum returns whether f represents an enum of blocks, where only one block +// is set at a time. +func (f Field) IsEnum() bool { return f.Flags&FlagEnum != 0 } + +// IsOptional returns whether f is optional. +func (f Field) IsOptional() bool { return f.Flags&FlagOptional != 0 } + +// IsLabel returns whether f is label. +func (f Field) IsLabel() bool { return f.Flags&FlagLabel != 0 } + +// Get returns the list of tagged fields for some struct type ty. Get panics if +// ty is not a struct type. +// +// Get examines each tagged field in ty for a river key. The river key is then +// parsed as containing a name for the field, followed by a required +// comma-separated list of options. The name may be empty for fields which do +// not require a name. Get will ignore any field that is not tagged with a +// river key. +// +// Get will treat anonymous struct fields as if the inner fields were fields in +// the outer struct. +// +// Examples of struct field tags and their meanings: +// +// // Field is used as a required block named "my_block". +// Field struct{} `river:"my_block,block"` +// +// // Field is used as an optional block named "my_block". +// Field struct{} `river:"my_block,block,optional"` +// +// // Field is used as a required attribute named "my_attr". 
+// Field string `river:"my_attr,attr"` +// +// // Field is used as an optional attribute named "my_attr". +// Field string `river:"my_attr,attr,optional"` +// +// // Field is used for storing the label of the block which the struct +// // represents. +// Field string `river:",label"` +// +// // Attributes and blocks inside of Field are exposed as top-level fields. +// Field struct{} `river:",squash"` +// +// Blocks []struct{} `river:"my_block_prefix,enum"` +// +// With the exception of the `river:",label"` and `river:",squash" tags, all +// tagged fields must have a unique name. +// +// The type of tagged fields may be any Go type, with the exception of +// `river:",label"` tags, which must be strings. +func Get(ty reflect.Type) []Field { + if k := ty.Kind(); k != reflect.Struct { + panic(fmt.Sprintf("rivertags: Get requires struct kind, got %s", k)) + } + + var ( + fields []Field + + usedNames = make(map[string][]int) + usedLabelField = []int(nil) + ) + + for _, field := range reflect.VisibleFields(ty) { + // River does not support embedding of fields + if field.Anonymous { + panic(fmt.Sprintf("river: anonymous fields not supported %s", printPathToField(ty, field.Index))) + } + + tag, tagged := field.Tag.Lookup("river") + if !tagged { + continue + } + + if !field.IsExported() { + panic(fmt.Sprintf("river: river tag found on unexported field at %s", printPathToField(ty, field.Index))) + } + + options := strings.SplitN(tag, ",", 2) + if len(options) == 0 { + panic(fmt.Sprintf("river: unsupported empty tag at %s", printPathToField(ty, field.Index))) + } + if len(options) != 2 { + panic(fmt.Sprintf("river: field %s tag is missing options", printPathToField(ty, field.Index))) + } + + fullName := options[0] + + tf := Field{ + Name: strings.Split(fullName, "."), + Index: field.Index, + } + + if first, used := usedNames[fullName]; used && fullName != "" { + panic(fmt.Sprintf("river: field name %s already used by %s", fullName, printPathToField(ty, first))) + } + 
usedNames[fullName] = tf.Index + + flags, ok := parseFlags(options[1]) + if !ok { + panic(fmt.Sprintf("river: unrecognized river tag format %q at %s", tag, printPathToField(ty, tf.Index))) + } + tf.Flags = flags + + if len(tf.Name) > 1 && tf.Flags&(FlagBlock|FlagEnum) == 0 { + panic(fmt.Sprintf("river: field names with `.` may only be used by blocks or enums (found at %s)", printPathToField(ty, tf.Index))) + } + + if tf.Flags&FlagEnum != 0 { + if err := validateEnum(field); err != nil { + panic(err) + } + } + + if tf.Flags&FlagLabel != 0 { + if fullName != "" { + panic(fmt.Sprintf("river: label field at %s must not have a name", printPathToField(ty, tf.Index))) + } + if field.Type.Kind() != reflect.String { + panic(fmt.Sprintf("river: label field at %s must be a string", printPathToField(ty, tf.Index))) + } + + if usedLabelField != nil { + panic(fmt.Sprintf("river: label field already used by %s", printPathToField(ty, tf.Index))) + } + usedLabelField = tf.Index + } + + if tf.Flags&FlagSquash != 0 { + if fullName != "" { + panic(fmt.Sprintf("river: squash field at %s must not have a name", printPathToField(ty, tf.Index))) + } + + innerType := deferenceType(field.Type) + + switch { + case isStructType(innerType): // Squashed struct + // Get the inner fields from the squashed struct and append each of them. + // The index of the squashed field is prepended to the index of the inner + // struct. 
// printPathToField renders a human-readable path to a struct field for use in
// panic messages. The result is structTy's name followed by the name of every
// field selected by path, joined with dots (for example "pkg.Outer.Inner.Leaf").
//
// path is a field index sequence in the form used by
// [reflect.StructField.Index]: each element selects a field of the type reached
// by the previous element. Pointer types are stepped through to their element
// type before a field is looked up. printPathToField panics (via reflect) if a
// lookup is attempted on a type that is not a struct after pointers are
// removed.
func printPathToField(structTy reflect.Type, path []int) string {
	var sb strings.Builder

	sb.WriteString(structTy.String())
	sb.WriteString(".")

	cur := structTy
	for i, elem := range path {
		if i > 0 {
			sb.WriteString(".")
		}

		// A field reached through a pointer lives on the pointed-to struct.
		for cur.Kind() == reflect.Pointer {
			cur = cur.Elem()
		}

		field := cur.Field(elem)
		sb.WriteString(field.Name)

		// Descend into the field that was just printed. The index to follow
		// is elem (the field number), not i (the position within path).
		cur = field.Type
	}

	return sb.String()
}
+func validateEnum(field reflect.StructField) error { + kind := field.Type.Kind() + if kind != reflect.Slice && kind != reflect.Array { + return fmt.Errorf("enum fields can only be slices or arrays") + } + + elementType := deferenceType(field.Type.Elem()) + if elementType.Kind() != reflect.Struct { + return fmt.Errorf("enum fields can only be a slice or array of structs") + } + + enumElementFields := Get(elementType) + for _, field := range enumElementFields { + if !field.IsBlock() { + return fmt.Errorf("fields in an enum element may only be blocks, got " + field.Flags.String()) + } + + fieldType := deferenceType(elementType.FieldByIndex(field.Index).Type) + if fieldType.Kind() != reflect.Struct { + return fmt.Errorf("blocks in an enum element may only be structs, got " + fieldType.Kind().String()) + } + } + + return nil +} diff --git a/syntax/internal/rivertags/rivertags_test.go b/syntax/internal/rivertags/rivertags_test.go new file mode 100644 index 0000000000..43370b33b4 --- /dev/null +++ b/syntax/internal/rivertags/rivertags_test.go @@ -0,0 +1,182 @@ +package rivertags_test + +import ( + "reflect" + "testing" + + "github.com/grafana/river/internal/rivertags" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Get(t *testing.T) { + type Struct struct { + IgnoreMe bool + + ReqAttr string `river:"req_attr,attr"` + OptAttr string `river:"opt_attr,attr,optional"` + ReqBlock struct{} `river:"req_block,block"` + OptBlock struct{} `river:"opt_block,block,optional"` + ReqEnum []struct{} `river:"req_enum,enum"` + OptEnum []struct{} `river:"opt_enum,enum,optional"` + Label string `river:",label"` + } + + fs := rivertags.Get(reflect.TypeOf(Struct{})) + + expect := []rivertags.Field{ + {[]string{"req_attr"}, []int{1}, rivertags.FlagAttr}, + {[]string{"opt_attr"}, []int{2}, rivertags.FlagAttr | rivertags.FlagOptional}, + {[]string{"req_block"}, []int{3}, rivertags.FlagBlock}, + {[]string{"opt_block"}, []int{4}, rivertags.FlagBlock | 
rivertags.FlagOptional}, + {[]string{"req_enum"}, []int{5}, rivertags.FlagEnum}, + {[]string{"opt_enum"}, []int{6}, rivertags.FlagEnum | rivertags.FlagOptional}, + {[]string{""}, []int{7}, rivertags.FlagLabel}, + } + + require.Equal(t, expect, fs) +} + +func TestEmbedded(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr"` + InnerField2 string `river:"inner_field_2,attr"` + } + + type Struct struct { + Field1 string `river:"parent_field_1,attr"` + InnerStruct + Field2 string `river:"parent_field_2,attr"` + } + require.PanicsWithValue(t, "river: anonymous fields not supported rivertags_test.Struct.InnerStruct", func() { rivertags.Get(reflect.TypeOf(Struct{})) }) +} + +func TestSquash(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr"` + InnerField2 string `river:"inner_field_2,attr"` + } + + type Struct struct { + Field1 string `river:"parent_field_1,attr"` + Inner InnerStruct `river:",squash"` + Field2 string `river:"parent_field_2,attr"` + } + + type StructWithPointer struct { + Field1 string `river:"parent_field_1,attr"` + Inner *InnerStruct `river:",squash"` + Field2 string `river:"parent_field_2,attr"` + } + + expect := []rivertags.Field{ + { + Name: []string{"parent_field_1"}, + Index: []int{0}, + Flags: rivertags.FlagAttr, + }, + { + Name: []string{"inner_field_1"}, + Index: []int{1, 0}, + Flags: rivertags.FlagAttr, + }, + { + Name: []string{"inner_field_2"}, + Index: []int{1, 1}, + Flags: rivertags.FlagAttr, + }, + { + Name: []string{"parent_field_2"}, + Index: []int{2}, + Flags: rivertags.FlagAttr, + }, + } + + structActual := rivertags.Get(reflect.TypeOf(Struct{})) + assert.Equal(t, expect, structActual) + + structPointerActual := rivertags.Get(reflect.TypeOf(StructWithPointer{})) + assert.Equal(t, expect, structPointerActual) +} + +func TestDeepSquash(t *testing.T) { + type Inner2Struct struct { + InnerField1 string `river:"inner_field_1,attr"` + InnerField2 string 
`river:"inner_field_2,attr"` + } + + type InnerStruct struct { + Inner2Struct Inner2Struct `river:",squash"` + } + + type Struct struct { + Inner InnerStruct `river:",squash"` + } + + expect := []rivertags.Field{ + { + Name: []string{"inner_field_1"}, + Index: []int{0, 0, 0}, + Flags: rivertags.FlagAttr, + }, + { + Name: []string{"inner_field_2"}, + Index: []int{0, 0, 1}, + Flags: rivertags.FlagAttr, + }, + } + + structActual := rivertags.Get(reflect.TypeOf(Struct{})) + assert.Equal(t, expect, structActual) +} + +func Test_Get_Panics(t *testing.T) { + expectPanic := func(t *testing.T, expect string, v interface{}) { + t.Helper() + require.PanicsWithValue(t, expect, func() { + _ = rivertags.Get(reflect.TypeOf(v)) + }) + } + + t.Run("Tagged fields must be exported", func(t *testing.T) { + type Struct struct { + attr string `river:"field,attr"` // nolint:unused //nolint:rivertags + } + expect := `river: river tag found on unexported field at rivertags_test.Struct.attr` + expectPanic(t, expect, Struct{}) + }) + + t.Run("Options are required", func(t *testing.T) { + type Struct struct { + Attr string `river:"field"` //nolint:rivertags + } + expect := `river: field rivertags_test.Struct.Attr tag is missing options` + expectPanic(t, expect, Struct{}) + }) + + t.Run("Field names must be unique", func(t *testing.T) { + type Struct struct { + Attr string `river:"field1,attr"` + Block string `river:"field1,block,optional"` //nolint:rivertags + } + expect := `river: field name field1 already used by rivertags_test.Struct.Attr` + expectPanic(t, expect, Struct{}) + }) + + t.Run("Name is required for non-label field", func(t *testing.T) { + type Struct struct { + Attr string `river:",attr"` //nolint:rivertags + } + expect := `river: non-empty field name required at rivertags_test.Struct.Attr` + expectPanic(t, expect, Struct{}) + }) + + t.Run("Only one label field may exist", func(t *testing.T) { + type Struct struct { + Label1 string `river:",label"` + Label2 string 
`river:",label"` + } + expect := `river: label field already used by rivertags_test.Struct.Label2` + expectPanic(t, expect, Struct{}) + }) +} diff --git a/syntax/internal/stdlib/constants.go b/syntax/internal/stdlib/constants.go new file mode 100644 index 0000000000..89525f855f --- /dev/null +++ b/syntax/internal/stdlib/constants.go @@ -0,0 +1,19 @@ +package stdlib + +import ( + "os" + "runtime" +) + +var constants = map[string]string{ + "hostname": "", // Initialized via init function + "os": runtime.GOOS, + "arch": runtime.GOARCH, +} + +func init() { + hostname, err := os.Hostname() + if err == nil { + constants["hostname"] = hostname + } +} diff --git a/syntax/internal/stdlib/stdlib.go b/syntax/internal/stdlib/stdlib.go new file mode 100644 index 0000000000..e73b950af8 --- /dev/null +++ b/syntax/internal/stdlib/stdlib.go @@ -0,0 +1,132 @@ +// Package stdlib contains standard library functions exposed to River configs. +package stdlib + +import ( + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/rivertypes" + "github.com/ohler55/ojg/jp" + "github.com/ohler55/ojg/oj" +) + +// Identifiers holds a list of stdlib identifiers by name. All interface{} +// values are River-compatible values. +// +// Function identifiers are Go functions with exactly one non-error return +// value, with an optionally supported error return value as the second return +// value. +var Identifiers = map[string]interface{}{ + // See constants.go for the definition. + "constants": constants, + + "env": os.Getenv, + + "nonsensitive": func(secret rivertypes.Secret) string { + return string(secret) + }, + + // concat is implemented as a raw function so it can bypass allocations + // converting arguments into []interface{}. concat is optimized to allow it + // to perform well when it is in the hot path for combining targets from many + // other blocks. 
+ "concat": value.RawFunction(func(funcValue value.Value, args ...value.Value) (value.Value, error) { + if len(args) == 0 { + return value.Array(), nil + } + + // finalSize is the final size of the resulting concatenated array. We type + // check our arguments while computing what finalSize will be. + var finalSize int + for i, arg := range args { + if arg.Type() != value.TypeArray { + return value.Null, value.ArgError{ + Function: funcValue, + Argument: arg, + Index: i, + Inner: value.TypeError{ + Value: arg, + Expected: value.TypeArray, + }, + } + } + + finalSize += arg.Len() + } + + // Optimization: if there's only one array, we can just return it directly. + // This is done *after* the previous loop to ensure that args[0] is a River + // array. + if len(args) == 1 { + return args[0], nil + } + + raw := make([]value.Value, 0, finalSize) + for _, arg := range args { + for i := 0; i < arg.Len(); i++ { + raw = append(raw, arg.Index(i)) + } + } + + return value.Array(raw...), nil + }), + + "json_decode": func(in string) (interface{}, error) { + var res interface{} + err := json.Unmarshal([]byte(in), &res) + if err != nil { + return nil, err + } + return res, nil + }, + + "json_path": func(jsonString string, path string) (interface{}, error) { + jsonPathExpr, err := jp.ParseString(path) + if err != nil { + return nil, err + } + + jsonExpr, err := oj.ParseString(jsonString) + if err != nil { + return nil, err + } + + return jsonPathExpr.Get(jsonExpr), nil + }, + + "coalesce": value.RawFunction(func(funcValue value.Value, args ...value.Value) (value.Value, error) { + if len(args) == 0 { + return value.Null, nil + } + + for _, arg := range args { + if arg.Type() == value.TypeNull { + continue + } + + if !arg.Reflect().IsZero() { + if argType := value.RiverType(arg.Reflect().Type()); (argType == value.TypeArray || argType == value.TypeObject) && arg.Len() == 0 { + continue + } + + return arg, nil + } + } + + return args[len(args)-1], nil + }), + + "format": fmt.Sprintf, 
+ "join": strings.Join, + "replace": strings.ReplaceAll, + "split": strings.Split, + "to_lower": strings.ToLower, + "to_upper": strings.ToUpper, + "trim": strings.Trim, + "trim_prefix": strings.TrimPrefix, + "trim_suffix": strings.TrimSuffix, + "trim_space": strings.TrimSpace, +} diff --git a/syntax/internal/value/capsule.go b/syntax/internal/value/capsule.go new file mode 100644 index 0000000000..7a522ff337 --- /dev/null +++ b/syntax/internal/value/capsule.go @@ -0,0 +1,53 @@ +package value + +import ( + "fmt" +) + +// Capsule is a marker interface for Go values which forces a type to be +// represented as a River capsule. This is useful for types whose underlying +// value is not a capsule, such as: +// +// // Secret is a secret value. It would normally be a River string since the +// // underlying Go type is string, but it's a capsule since it implements +// // the Capsule interface. +// type Secret string +// +// func (s Secret) RiverCapsule() {} +// +// Extension interfaces are used to describe additional behaviors for Capsules. +// ConvertibleCapsule allows defining custom conversion rules to convert +// between other Go values. +type Capsule interface { + RiverCapsule() +} + +// ErrNoConversion is returned by implementations of ConvertibleCapsule to +// denote that a custom conversion from or to a specific type is unavailable. +var ErrNoConversion = fmt.Errorf("no custom capsule conversion available") + +// ConvertibleFromCapsule is a Capsule which supports custom conversion rules +// from any Go type which is not the same as the capsule type. +type ConvertibleFromCapsule interface { + Capsule + + // ConvertFrom should modify the ConvertibleCapsule value based on the value + // of src. + // + // ConvertFrom should return ErrNoConversion if no conversion is available + // from src. 
+ ConvertFrom(src interface{}) error +} + +// ConvertibleIntoCapsule is a Capsule which supports custom conversion rules +// into any Go type which is not the same as the capsule type. +type ConvertibleIntoCapsule interface { + Capsule + + // ConvertInto should convert its value and store it into dst. dst will be a + // pointer to a value which ConvertInto is expected to update. + // + // ConvertInto should return ErrNoConversion if no conversion into dst is + // available. + ConvertInto(dst interface{}) error +} diff --git a/syntax/internal/value/decode.go b/syntax/internal/value/decode.go new file mode 100644 index 0000000000..20df78eb6a --- /dev/null +++ b/syntax/internal/value/decode.go @@ -0,0 +1,674 @@ +package value + +import ( + "encoding" + "errors" + "fmt" + "math" + "reflect" + "time" + + "github.com/grafana/river/internal/reflectutil" +) + +// The Defaulter interface allows a type to implement default functionality +// in River evaluation. +// +// Defaulter will be called only on block and body river types. +// +// When using nested blocks, the wrapping type must also implement +// Defaulter to propagate the defaults of the wrapped type. Otherwise, +// defaults used for the wrapped type become inconsistent: +// +// - If the wrapped block is NOT defined in the River config, the wrapping +// type's defaults are used. +// - If the wrapped block IS defined in the River config, the wrapped type's +// defaults are used. +type Defaulter interface { + // SetToDefault is called when evaluating a block or body to set the value + // to its defaults. + SetToDefault() +} + +// Unmarshaler is a custom type which can be used to hook into the decoder. +type Unmarshaler interface { + // UnmarshalRiver is called when decoding a value. f should be invoked to + // continue decoding with a value to decode into. + UnmarshalRiver(f func(v interface{}) error) error +} + +// The Validator interface allows a type to implement validation functionality +// in River evaluation. 
+type Validator interface { + // Validate is called when evaluating a block or body to enforce the + // value is valid. + Validate() error +} + +// Decode assigns a Value val to a Go pointer target. Pointers will be +// allocated as necessary when decoding. +// +// As a performance optimization, the underlying Go value of val will be +// assigned directly to target if the Go types match. This means that pointers, +// slices, and maps will be passed by reference. Callers should take care not +// to modify any Values after decoding, unless it is expected by the contract +// of the type (i.e., when the type exposes a goroutine-safe API). In other +// cases, new maps and slices will be allocated as necessary. Call DecodeCopy +// to make a copy of val instead. +// +// When a direct assignment is not done, Decode first checks to see if target +// implements the Unmarshaler or text.Unmarshaler interface, invoking methods +// as appropriate. It will also use time.ParseDuration if target is +// *time.Duration. +// +// Next, Decode will attempt to convert val to the type expected by target for +// assignment. If val or target implement ConvertibleCapsule, conversion +// between values will be attempted by calling ConvertFrom and ConvertInto as +// appropriate. If val cannot be converted, an error is returned. +// +// River null values will decode into a nil Go pointer or the zero value for +// the non-pointer type. +// +// Decode will panic if target is not a pointer. +func Decode(val Value, target interface{}) error { + rt := reflect.ValueOf(target) + if rt.Kind() != reflect.Pointer { + panic("river/value: Decode called with non-pointer value") + } + + var d decoder + return d.decode(val, rt) +} + +// DecodeCopy is like Decode but a deep copy of val is always made. +// +// Unlike Decode, DecodeCopy will always invoke Unmarshaler and +// text.Unmarshaler interfaces (if implemented by target). 
+func DecodeCopy(val Value, target interface{}) error { + rt := reflect.ValueOf(target) + if rt.Kind() != reflect.Pointer { + panic("river/value: Decode called with non-pointer value") + } + + d := decoder{makeCopy: true} + return d.decode(val, rt) +} + +type decoder struct { + makeCopy bool +} + +func (d *decoder) decode(val Value, into reflect.Value) (err error) { + // If everything has decoded successfully, run Validate if implemented. + defer func() { + if err == nil { + if into.CanAddr() && into.Addr().Type().Implements(goRiverValidator) { + err = into.Addr().Interface().(Validator).Validate() + } else if into.Type().Implements(goRiverValidator) { + err = into.Interface().(Validator).Validate() + } + } + }() + + // Store the raw value from val and try to address it so we can do underlying + // type match assignment. + rawValue := val.rv + if rawValue.CanAddr() { + rawValue = rawValue.Addr() + } + + // Fully deference into and allocate pointers as necessary. + for into.Kind() == reflect.Pointer { + // Check for direct assignments before allocating pointers and dereferencing. + // This preserves pointer addresses when decoding an *int into an *int. + switch { + case into.CanSet() && val.Type() == TypeNull: + into.Set(reflect.Zero(into.Type())) + return nil + case into.CanSet() && d.canDirectlyAssign(rawValue.Type(), into.Type()): + into.Set(rawValue) + return nil + case into.CanSet() && d.canDirectlyAssign(val.rv.Type(), into.Type()): + into.Set(val.rv) + return nil + } + + if into.IsNil() { + into.Set(reflect.New(into.Type().Elem())) + } + into = into.Elem() + } + + // Ww need to preform the same switch statement as above after the loop to + // check for direct assignment one more time on the fully deferenced types. + // + // NOTE(rfratto): we skip the rawValue assignment check since that's meant + // for assigning pointers, and into is never a pointer when we reach here. 
+ switch { + case into.CanSet() && val.Type() == TypeNull: + into.Set(reflect.Zero(into.Type())) + return nil + case into.CanSet() && d.canDirectlyAssign(val.rv.Type(), into.Type()): + into.Set(val.rv) + return nil + } + + // Special decoding rules: + // + // 1. If into is an interface{}, go through decodeAny so it gets assigned + // predictable types. + // 2. If into implements a supported interface, use the interface for + // decoding instead. + if into.Type() == goAny { + return d.decodeAny(val, into) + } else if ok, err := d.decodeFromInterface(val, into); ok { + return err + } + + if into.CanAddr() && into.Addr().Type().Implements(goRiverDefaulter) { + into.Addr().Interface().(Defaulter).SetToDefault() + } else if into.Type().Implements(goRiverDefaulter) { + into.Interface().(Defaulter).SetToDefault() + } + + targetType := RiverType(into.Type()) + + // Track a value to use for decoding. This value will be updated if + // conversion is necessary. + // + // NOTE(rfratto): we don't reassign to val here, since Go 1.18 thinks that + // means it escapes the heap. We need to create a local variable to avoid + // extra allocations. + convVal := val + + // Convert the value. + switch { + case val.rv.Type() == goByteSlice && into.Type() == goString: // []byte -> string + into.Set(val.rv.Convert(goString)) + return nil + case val.rv.Type() == goString && into.Type() == goByteSlice: // string -> []byte + into.Set(val.rv.Convert(goByteSlice)) + return nil + case convVal.Type() != targetType: + converted, err := tryCapsuleConvert(convVal, into, targetType) + if err != nil { + return err + } else if converted { + return nil + } + + convVal, err = convertValue(convVal, targetType) + if err != nil { + return err + } + } + + // Slowest case: recursive decoding. Once we've reached this point, we know + // that convVal.rv and into are compatible Go types. 
+ switch convVal.Type() { + case TypeNumber: + into.Set(convertGoNumber(convVal.Number(), into.Type())) + return nil + case TypeString: + // Call convVal.Text() to get the final string value, since convVal.rv + // might not be a string. + into.Set(reflect.ValueOf(convVal.Text())) + return nil + case TypeBool: + into.Set(reflect.ValueOf(convVal.Bool())) + return nil + case TypeArray: + return d.decodeArray(convVal, into) + case TypeObject: + return d.decodeObject(convVal, into) + case TypeFunction: + // The Go types for two functions must be the same. + // + // TODO(rfratto): we may want to consider being more lax here, potentially + // creating an adapter between the two functions. + if convVal.rv.Type() == into.Type() { + into.Set(convVal.rv) + return nil + } + + return Error{ + Value: val, + Inner: fmt.Errorf("expected function(%s), got function(%s)", into.Type(), convVal.rv.Type()), + } + case TypeCapsule: + // The Go types for the capsules must be the same or able to be converted. + if convVal.rv.Type() == into.Type() { + into.Set(convVal.rv) + return nil + } + + converted, err := tryCapsuleConvert(convVal, into, targetType) + if err != nil { + return err + } else if converted { + return nil + } + + // TODO(rfratto): return a TypeError for this instead. TypeError isn't + // appropriate at the moment because it would just print "capsule", which + // doesn't contain all the information the user would want to know (e.g., a + // capsule of what inner type?). + return Error{ + Value: val, + Inner: fmt.Errorf("expected capsule(%q), got %s", into.Type(), convVal.Describe()), + } + default: + panic("river/value: unexpected kind " + convVal.Type().String()) + } +} + +// canDirectlyAssign returns true if the `from` type can be directly assigned +// to the `into` type. This always returns false if the decoder is set to make +// copies or into contains an interface{} type anywhere in its type definition +// to allow for decoding interface{} values into a set of known types. 
+func (d *decoder) canDirectlyAssign(from reflect.Type, into reflect.Type) bool { + if d.makeCopy { + return false + } + if from != into { + return false + } + return !containsAny(into) +} + +// containsAny recursively traverses through into, returning true if it +// contains an interface{} value anywhere in its structure. +func containsAny(into reflect.Type) bool { + // TODO(rfratto): cache result of this function? + + if into == goAny { + return true + } + + switch into.Kind() { + case reflect.Array, reflect.Pointer, reflect.Slice: + return containsAny(into.Elem()) + case reflect.Map: + if into.Key() == goString { + return containsAny(into.Elem()) + } + return false + + case reflect.Struct: + for i := 0; i < into.NumField(); i++ { + if containsAny(into.Field(i).Type) { + return true + } + } + return false + + default: + // Other kinds are not River types where the decodeAny check applies. + return false + } +} + +func (d *decoder) decodeFromInterface(val Value, into reflect.Value) (ok bool, err error) { + // into may only implement interface types for a pointer receiver, so we want + // to address into if possible. + if into.CanAddr() { + into = into.Addr() + } + + switch { + case into.Type() == goDurationPtr: + var s string + err := d.decode(val, reflect.ValueOf(&s)) + if err != nil { + return true, err + } + dur, err := time.ParseDuration(s) + if err != nil { + return true, Error{Value: val, Inner: err} + } + *into.Interface().(*time.Duration) = dur + return true, nil + + case into.Type().Implements(goRiverDecoder): + err := into.Interface().(Unmarshaler).UnmarshalRiver(func(v interface{}) error { + return d.decode(val, reflect.ValueOf(v)) + }) + if err != nil { + // TODO(rfratto): we need to detect if error is one of the error types + // from this package and only wrap it in an Error if it isn't. 
+ return true, Error{Value: val, Inner: err} + } + return true, nil + + case into.Type().Implements(goTextUnmarshaler): + var s string + err := d.decode(val, reflect.ValueOf(&s)) + if err != nil { + return true, err + } + err = into.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(s)) + if err != nil { + return true, Error{Value: val, Inner: err} + } + return true, nil + } + + return false, nil +} + +func tryCapsuleConvert(from Value, into reflect.Value, intoType Type) (ok bool, err error) { + // Check to see if we can use capsule conversion. + if from.Type() == TypeCapsule { + cc, ok := from.Interface().(ConvertibleIntoCapsule) + if ok { + // It's always possible to Addr the reflect.Value below since we expect + // it to be a settable non-pointer value. + err := cc.ConvertInto(into.Addr().Interface()) + if err == nil { + return true, nil + } else if err != nil && !errors.Is(err, ErrNoConversion) { + return false, Error{Value: from, Inner: err} + } + } + } + + if intoType == TypeCapsule { + cc, ok := into.Addr().Interface().(ConvertibleFromCapsule) + if ok { + err := cc.ConvertFrom(from.Interface()) + if err == nil { + return true, nil + } else if err != nil && !errors.Is(err, ErrNoConversion) { + return false, Error{Value: from, Inner: err} + } + } + } + + // Last attempt: allow converting two capsules if the Go types are compatible + // and the into kind is an interface. + // + // TODO(rfratto): we may consider expanding this to allowing conversion to + // any compatible Go type in the future (not just interfaces). + if from.Type() == TypeCapsule && intoType == TypeCapsule && into.Kind() == reflect.Interface { + // We try to convert a pointer to from first to avoid making unnecessary + // copies. 
+ if from.Reflect().CanAddr() && from.Reflect().Addr().CanConvert(into.Type()) { + val := from.Reflect().Addr().Convert(into.Type()) + into.Set(val) + return true, nil + } else if from.Reflect().CanConvert(into.Type()) { + val := from.Reflect().Convert(into.Type()) + into.Set(val) + return true, nil + } + } + + return false, nil +} + +// decodeAny is invoked by decode when into is an interface{}. We assign the +// interface{} a known type based on the River value being decoded: +// +// Null values: nil +// Number values: float64, int, int64, or uint64. +// If the underlying type is a float, always decode to a float64. +// For non-floats the order of preference is int -> int64 -> uint64. +// Arrays: []interface{} +// Objects: map[string]interface{} +// Bool: bool +// String: string +// Function: Passthrough of the underlying function value +// Capsule: Passthrough of the underlying capsule value +// +// In the cases where we do not pass through the underlying value, we create a +// value of that type, recursively call decode to populate that new value, and +// then store that value into the interface{}. 
+func (d *decoder) decodeAny(val Value, into reflect.Value) error { + var ptr reflect.Value + + switch val.Type() { + case TypeNull: + into.Set(reflect.Zero(into.Type())) + return nil + + case TypeNumber: + + switch val.Number().Kind() { + case NumberKindFloat: + var v float64 + ptr = reflect.ValueOf(&v) + case NumberKindUint: + uint64Val := val.Uint() + if uint64Val <= math.MaxInt { + var v int + ptr = reflect.ValueOf(&v) + } else if uint64Val <= math.MaxInt64 { + var v int64 + ptr = reflect.ValueOf(&v) + } else { + var v uint64 + ptr = reflect.ValueOf(&v) + } + case NumberKindInt: + int64Val := val.Int() + if math.MinInt <= int64Val && int64Val <= math.MaxInt { + var v int + ptr = reflect.ValueOf(&v) + } else { + var v int64 + ptr = reflect.ValueOf(&v) + } + + default: + panic("river/value: unreachable") + } + + case TypeArray: + var v []interface{} + ptr = reflect.ValueOf(&v) + + case TypeObject: + var v map[string]interface{} + ptr = reflect.ValueOf(&v) + + case TypeBool: + var v bool + ptr = reflect.ValueOf(&v) + + case TypeString: + var v string + ptr = reflect.ValueOf(&v) + + case TypeFunction, TypeCapsule: + // Functions and capsules must be directly assigned since there's no + // "generic" representation for either. + // + // We retain the pointer if we were given a pointer. + + if val.rv.CanAddr() { + into.Set(val.rv.Addr()) + return nil + } + + into.Set(val.rv) + return nil + + default: + panic("river/value: unreachable") + } + + if err := d.decode(val, ptr); err != nil { + return err + } + into.Set(ptr.Elem()) + return nil +} + +func (d *decoder) decodeArray(val Value, rt reflect.Value) error { + switch rt.Kind() { + case reflect.Slice: + res := reflect.MakeSlice(rt.Type(), val.Len(), val.Len()) + for i := 0; i < val.Len(); i++ { + // Decode the original elements into the new elements. 
+ if err := d.decode(val.Index(i), res.Index(i)); err != nil { + return ElementError{Value: val, Index: i, Inner: err} + } + } + rt.Set(res) + + case reflect.Array: + res := reflect.New(rt.Type()).Elem() + + if val.Len() != res.Len() { + return Error{ + Value: val, + Inner: fmt.Errorf("array must have exactly %d elements, got %d", res.Len(), val.Len()), + } + } + + for i := 0; i < val.Len(); i++ { + if err := d.decode(val.Index(i), res.Index(i)); err != nil { + return ElementError{Value: val, Index: i, Inner: err} + } + } + rt.Set(res) + + default: + panic(fmt.Sprintf("river/value: unexpected array type %s", val.rv.Kind())) + } + + return nil +} + +func (d *decoder) decodeObject(val Value, rt reflect.Value) error { + switch rt.Kind() { + case reflect.Struct: + targetTags := getCachedTags(rt.Type()) + return d.decodeObjectToStruct(val, rt, targetTags, false) + + case reflect.Slice, reflect.Array: // Slice of labeled blocks + keys := val.Keys() + + var res reflect.Value + + if rt.Kind() == reflect.Slice { + res = reflect.MakeSlice(rt.Type(), len(keys), len(keys)) + } else { // Array + res = reflect.New(rt.Type()).Elem() + + if res.Len() != len(keys) { + return Error{ + Value: val, + Inner: fmt.Errorf("object must have exactly %d keys, got %d", res.Len(), len(keys)), + } + } + } + + fields := getCachedTags(rt.Type().Elem()) + labelField, _ := fields.LabelField() + + for i, key := range keys { + // First decode the key into the label. + elem := res.Index(i) + reflectutil.GetOrAlloc(elem, labelField).Set(reflect.ValueOf(key)) + + // Now decode the inner object. + value, _ := val.Key(key) + if err := d.decodeObjectToStruct(value, elem, fields, true); err != nil { + return FieldError{Value: val, Field: key, Inner: err} + } + } + rt.Set(res) + + case reflect.Map: + if rt.Type().Key() != goString { + // Maps with non-string types are treated as capsules and can't be + // decoded from maps. 
+ return TypeError{Value: val, Expected: RiverType(rt.Type())} + } + + res := reflect.MakeMapWithSize(rt.Type(), val.Len()) + + // Create a shared value to decode each element into. This will be zeroed + // out for each key, and then copied when setting the map index. + into := reflect.New(rt.Type().Elem()).Elem() + intoZero := reflect.Zero(into.Type()) + + for i, key := range val.Keys() { + // We ignore the ok value because we know it exists. + value, _ := val.Key(key) + + // Zero out the value if it was decoded in the previous loop. + if i > 0 { + into.Set(intoZero) + } + // Decode into our element. + if err := d.decode(value, into); err != nil { + return FieldError{Value: val, Field: key, Inner: err} + } + + // Then set the map index. + res.SetMapIndex(reflect.ValueOf(key), into) + } + + rt.Set(res) + + default: + panic(fmt.Sprintf("river/value: unexpected target type %s", rt.Kind())) + } + + return nil +} + +func (d *decoder) decodeObjectToStruct(val Value, rt reflect.Value, fields *objectFields, decodedLabel bool) error { + // TODO(rfratto): this needs to check for required keys being set + + for _, key := range val.Keys() { + // We ignore the ok value because we know it exists. + value, _ := val.Key(key) + + // Struct labels should be decoded first, since objects are wrapped in + // labels. If we have yet to decode the label, decode it now. + if lf, ok := fields.LabelField(); ok && !decodedLabel { + // Safety check: if the inner field isn't an object, there's something + // wrong here. It's unclear if a user can craft an expression that hits + // this case, but it's left in for safety. + if value.Type() != TypeObject { + return FieldError{ + Value: val, + Field: key, + Inner: TypeError{Value: value, Expected: TypeObject}, + } + } + + // Decode the key into the label. + reflectutil.GetOrAlloc(rt, lf).Set(reflect.ValueOf(key)) + + // ...and then decode the rest of the object. 
+ if err := d.decodeObjectToStruct(value, rt, fields, true); err != nil { + return err + } + continue + } + + switch fields.Has(key) { + case objectKeyTypeInvalid: + return MissingKeyError{Value: value, Missing: key} + case objectKeyTypeNestedField: // Block with multiple name fragments + next, _ := fields.NestedField(key) + // Recurse the call with the inner value. + if err := d.decodeObjectToStruct(value, rt, next, decodedLabel); err != nil { + return err + } + case objectKeyTypeField: // Single-name fragment + targetField, _ := fields.Field(key) + targetValue := reflectutil.GetOrAlloc(rt, targetField) + + if err := d.decode(value, targetValue); err != nil { + return FieldError{Value: val, Field: key, Inner: err} + } + } + } + + return nil +} diff --git a/syntax/internal/value/decode_benchmarks_test.go b/syntax/internal/value/decode_benchmarks_test.go new file mode 100644 index 0000000000..9a33239329 --- /dev/null +++ b/syntax/internal/value/decode_benchmarks_test.go @@ -0,0 +1,90 @@ +package value_test + +import ( + "fmt" + "testing" + + "github.com/grafana/river/internal/value" +) + +func BenchmarkObjectDecode(b *testing.B) { + b.StopTimer() + + // Create a value with 20 keys. 
+ source := make(map[string]string, 20) + for i := 0; i < 20; i++ { + var ( + key = fmt.Sprintf("key_%d", i+1) + value = fmt.Sprintf("value_%d", i+1) + ) + source[key] = value + } + + sourceVal := value.Encode(source) + + b.StartTimer() + for i := 0; i < b.N; i++ { + var dst map[string]string + _ = value.Decode(sourceVal, &dst) + } +} + +func BenchmarkObject(b *testing.B) { + b.Run("Non-capsule", func(b *testing.B) { + b.StopTimer() + + vals := make(map[string]value.Value) + for i := 0; i < 20; i++ { + vals[fmt.Sprintf("%d", i)] = value.Int(int64(i)) + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + _ = value.Object(vals) + } + }) + + b.Run("Capsule", func(b *testing.B) { + b.StopTimer() + + vals := make(map[string]value.Value) + for i := 0; i < 20; i++ { + vals[fmt.Sprintf("%d", i)] = value.Encapsulate(make(chan int)) + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + _ = value.Object(vals) + } + }) +} + +func BenchmarkArray(b *testing.B) { + b.Run("Non-capsule", func(b *testing.B) { + b.StopTimer() + + var vals []value.Value + for i := 0; i < 20; i++ { + vals = append(vals, value.Int(int64(i))) + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + _ = value.Array(vals...) + } + }) + + b.Run("Capsule", func(b *testing.B) { + b.StopTimer() + + var vals []value.Value + for i := 0; i < 20; i++ { + vals = append(vals, value.Encapsulate(make(chan int))) + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + _ = value.Array(vals...) 
+ } + }) +} diff --git a/syntax/internal/value/decode_test.go b/syntax/internal/value/decode_test.go new file mode 100644 index 0000000000..5b84838bd0 --- /dev/null +++ b/syntax/internal/value/decode_test.go @@ -0,0 +1,761 @@ +package value_test + +import ( + "fmt" + "math" + "reflect" + "testing" + "time" + "unsafe" + + "github.com/grafana/river/internal/value" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecode_Numbers(t *testing.T) { + // There's a lot of values that can represent numbers, so we construct a + // matrix dynamically of all the combinations here. + vals := []interface{}{ + int(15), int8(15), int16(15), int32(15), int64(15), + uint(15), uint8(15), uint16(15), uint32(15), uint64(15), + float32(15), float64(15), + string("15"), // string holding a valid number (which can be converted to a number) + } + + for _, input := range vals { + for _, expect := range vals { + val := value.Encode(input) + + name := fmt.Sprintf( + "%s to %s", + reflect.TypeOf(input), + reflect.TypeOf(expect), + ) + + t.Run(name, func(t *testing.T) { + vPtr := reflect.New(reflect.TypeOf(expect)).Interface() + require.NoError(t, value.Decode(val, vPtr)) + + actual := reflect.ValueOf(vPtr).Elem().Interface() + require.Equal(t, expect, actual) + }) + } + } +} + +func TestDecode(t *testing.T) { + // Declare some types to use for testing. Person2 is used as a struct + // equivalent to Person, but with a different Go type to force casting. + type Person struct { + Name string `river:"name,attr"` + } + + type Person2 struct { + Name string `river:"name,attr"` + } + + tt := []struct { + input, expect interface{} + }{ + {nil, (*int)(nil)}, + + // Non-number primitives. 
+ {string("Hello!"), string("Hello!")}, + {bool(true), bool(true)}, + + // Arrays + {[]int{1, 2, 3}, []int{1, 2, 3}}, + {[]int{1, 2, 3}, [...]int{1, 2, 3}}, + {[...]int{1, 2, 3}, []int{1, 2, 3}}, + {[...]int{1, 2, 3}, [...]int{1, 2, 3}}, + + // Maps + {map[string]int{"year": 2022}, map[string]uint{"year": 2022}}, + {map[string]string{"name": "John"}, map[string]string{"name": "John"}}, + {map[string]string{"name": "John"}, Person{Name: "John"}}, + {Person{Name: "John"}, map[string]string{"name": "John"}}, + {Person{Name: "John"}, Person{Name: "John"}}, + {Person{Name: "John"}, Person2{Name: "John"}}, + {Person2{Name: "John"}, Person{Name: "John"}}, + + // NOTE(rfratto): we don't test capsules or functions here because they're + // not comparable in the same way as we do the other tests. + // + // See TestDecode_Functions and TestDecode_Capsules for specific decoding + // tests of those types. + } + + for _, tc := range tt { + val := value.Encode(tc.input) + + name := fmt.Sprintf( + "%s (%s) to %s", + val.Type(), + reflect.TypeOf(tc.input), + reflect.TypeOf(tc.expect), + ) + + t.Run(name, func(t *testing.T) { + vPtr := reflect.New(reflect.TypeOf(tc.expect)).Interface() + require.NoError(t, value.Decode(val, vPtr)) + + actual := reflect.ValueOf(vPtr).Elem().Interface() + + require.Equal(t, tc.expect, actual) + }) + } +} + +// TestDecode_PreservePointer ensures that pointer addresses can be preserved +// when decoding. +func TestDecode_PreservePointer(t *testing.T) { + num := 5 + val := value.Encode(&num) + + var nump *int + require.NoError(t, value.Decode(val, &nump)) + require.Equal(t, unsafe.Pointer(nump), unsafe.Pointer(&num)) +} + +// TestDecode_PreserveMapReference ensures that map references can be preserved +// when decoding. 
+func TestDecode_PreserveMapReference(t *testing.T) { + m := make(map[string]string) + val := value.Encode(m) + + var actual map[string]string + require.NoError(t, value.Decode(val, &actual)) + + // We can't check to see if the pointers of m and actual match, but we can + // modify m to see if actual is also modified. + m["foo"] = "bar" + require.Equal(t, "bar", actual["foo"]) +} + +// TestDecode_PreserveSliceReference ensures that slice references can be +// preserved when decoding. +func TestDecode_PreserveSliceReference(t *testing.T) { + s := make([]string, 3) + val := value.Encode(s) + + var actual []string + require.NoError(t, value.Decode(val, &actual)) + + // We can't check to see if the pointers of s and actual match, but we can + // modify s to see if actual is also modified. + s[0] = "Hello, world!" + require.Equal(t, "Hello, world!", actual[0]) +} +func TestDecode_Functions(t *testing.T) { + val := value.Encode(func() int { return 15 }) + + var f func() int + require.NoError(t, value.Decode(val, &f)) + require.Equal(t, 15, f()) +} + +func TestDecode_Capsules(t *testing.T) { + expect := make(chan int, 5) + + var actual chan int + require.NoError(t, value.Decode(value.Encode(expect), &actual)) + require.Equal(t, expect, actual) +} + +type ValueInterface interface{ SomeMethod() } + +type Value1 struct{ test string } + +func (c Value1) SomeMethod() {} + +// TestDecode_CapsuleInterface tests that we are able to decode when +// the target `into` is an interface. 
+func TestDecode_CapsuleInterface(t *testing.T) { + tt := []struct { + name string + value ValueInterface + expected ValueInterface + }{ + { + name: "Capsule to Capsule", + value: Value1{test: "true"}, + expected: Value1{test: "true"}, + }, + { + name: "Capsule Pointer to Capsule", + value: &Value1{test: "true"}, + expected: &Value1{test: "true"}, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + var actual ValueInterface + require.NoError(t, value.Decode(value.Encode(tc.value), &actual)) + + // require.Same validates the memory address matches after Decode. + if reflect.TypeOf(tc.value).Kind() == reflect.Pointer { + require.Same(t, tc.value, actual) + } + + // We use tc.expected to validate the properties of actual match the + // original tc.value properties (nothing has mutated them during the test). + require.Equal(t, tc.expected, actual) + }) + } +} + +// TestDecode_CapsulesError tests that we are unable to decode when +// the target `into` is not an interface. +func TestDecode_CapsulesError(t *testing.T) { + type Capsule1 struct{ test string } + type Capsule2 Capsule1 + + v := Capsule1{test: "true"} + actual := Capsule2{} + + require.EqualError(t, value.Decode(value.Encode(v), &actual), `expected capsule("value_test.Capsule2"), got capsule("value_test.Capsule1")`) +} + +// TestDecodeCopy_SliceCopy ensures that copies are made during decoding +// instead of setting values directly. +func TestDecodeCopy_SliceCopy(t *testing.T) { + orig := []int{1, 2, 3} + + var res []int + require.NoError(t, value.DecodeCopy(value.Encode(orig), &res)) + + res[0] = 10 + require.Equal(t, []int{1, 2, 3}, orig, "Original slice should not have been modified") +} + +// TestDecode_ArrayCopy ensures that copies are made during decoding +// instead of setting values directly. 
+func TestDecode_ArrayCopy(t *testing.T) { + orig := [...]int{1, 2, 3} + + var res [3]int + require.NoError(t, value.DecodeCopy(value.Encode(orig), &res)) + + res[0] = 10 + require.Equal(t, [3]int{1, 2, 3}, orig, "Original array should not have been modified") +} + +func TestDecode_CustomTypes(t *testing.T) { + t.Run("object to Unmarshaler", func(t *testing.T) { + var actual customUnmarshaler + require.NoError(t, value.Decode(value.Object(nil), &actual)) + require.True(t, actual.UnmarshalCalled, "UnmarshalRiver was not invoked") + require.True(t, actual.DefaultCalled, "SetToDefault was not invoked") + require.True(t, actual.ValidateCalled, "Validate was not invoked") + }) + + t.Run("TextMarshaler to TextUnmarshaler", func(t *testing.T) { + now := time.Now() + + var actual time.Time + require.NoError(t, value.Decode(value.Encode(now), &actual)) + require.True(t, now.Equal(actual)) + }) + + t.Run("time.Duration to time.Duration", func(t *testing.T) { + dur := 15 * time.Second + + var actual time.Duration + require.NoError(t, value.Decode(value.Encode(dur), &actual)) + require.Equal(t, dur, actual) + }) + + t.Run("string to TextUnmarshaler", func(t *testing.T) { + now := time.Now() + nowBytes, _ := now.MarshalText() + + var actual time.Time + require.NoError(t, value.Decode(value.String(string(nowBytes)), &actual)) + + actualBytes, _ := actual.MarshalText() + require.Equal(t, nowBytes, actualBytes) + }) + + t.Run("string to time.Duration", func(t *testing.T) { + dur := 15 * time.Second + + var actual time.Duration + require.NoError(t, value.Decode(value.String(dur.String()), &actual)) + require.Equal(t, dur.String(), actual.String()) + }) +} + +type customUnmarshaler struct { + UnmarshalCalled bool `river:"unmarshal_called,attr,optional"` + DefaultCalled bool `river:"default_called,attr,optional"` + ValidateCalled bool `river:"validate_called,attr,optional"` +} + +func (cu *customUnmarshaler) UnmarshalRiver(f func(interface{}) error) error { + cu.UnmarshalCalled = 
true + return f((*customUnmarshalerTarget)(cu)) +} + +type customUnmarshalerTarget customUnmarshaler + +func (s *customUnmarshalerTarget) SetToDefault() { + s.DefaultCalled = true +} + +func (s *customUnmarshalerTarget) Validate() error { + s.ValidateCalled = true + return nil +} + +type textEnumType bool + +func (et *textEnumType) UnmarshalText(text []byte) error { + *et = false + + switch string(text) { + case "accepted_value": + *et = true + return nil + default: + return fmt.Errorf("unrecognized value %q", string(text)) + } +} + +func TestDecode_TextUnmarshaler(t *testing.T) { + t.Run("valid type and value", func(t *testing.T) { + var et textEnumType + require.NoError(t, value.Decode(value.String("accepted_value"), &et)) + require.Equal(t, textEnumType(true), et) + }) + + t.Run("invalid type", func(t *testing.T) { + var et textEnumType + err := value.Decode(value.Bool(true), &et) + require.EqualError(t, err, "expected string, got bool") + }) + + t.Run("invalid value", func(t *testing.T) { + var et textEnumType + err := value.Decode(value.String("bad_value"), &et) + require.EqualError(t, err, `unrecognized value "bad_value"`) + }) + + t.Run("unmarshaler nested in other value", func(t *testing.T) { + input := value.Array( + value.String("accepted_value"), + value.String("accepted_value"), + value.String("accepted_value"), + ) + + var ett []textEnumType + require.NoError(t, value.Decode(input, &ett)) + require.Equal(t, []textEnumType{true, true, true}, ett) + }) +} + +func TestDecode_ErrorChain(t *testing.T) { + type Target struct { + Key struct { + Object struct { + Field1 []int `river:"field1,attr"` + } `river:"object,attr"` + } `river:"key,attr"` + } + + val := value.Object(map[string]value.Value{ + "key": value.Object(map[string]value.Value{ + "object": value.Object(map[string]value.Value{ + "field1": value.Array( + value.Int(15), + value.Int(30), + value.String("Hello, world!"), + ), + }), + }), + }) + + // NOTE(rfratto): strings of errors from the value 
package are fairly limited + // in the amount of information they show, since the value package doesn't + // have a great way to pretty-print the chain of errors. + // + // For example, with the error below, the message doesn't explain where the + // string is coming from, even though the error values hold that context. + // + // Callers consuming errors should print the error chain with extra context + // so it's more useful to users. + err := value.Decode(val, &Target{}) + expectErr := `expected number, got string` + require.EqualError(t, err, expectErr) +} + +type boolish int + +var _ value.ConvertibleFromCapsule = (*boolish)(nil) +var _ value.ConvertibleIntoCapsule = (boolish)(0) + +func (b boolish) RiverCapsule() {} + +func (b *boolish) ConvertFrom(src interface{}) error { + switch v := src.(type) { + case bool: + if v { + *b = 1 + } else { + *b = 0 + } + return nil + } + + return value.ErrNoConversion +} + +func (b boolish) ConvertInto(dst interface{}) error { + switch d := dst.(type) { + case *bool: + if b == 0 { + *d = false + } else { + *d = true + } + return nil + } + + return value.ErrNoConversion +} + +func TestDecode_CustomConvert(t *testing.T) { + t.Run("compatible type to custom", func(t *testing.T) { + var b boolish + err := value.Decode(value.Bool(true), &b) + require.NoError(t, err) + require.Equal(t, boolish(1), b) + }) + + t.Run("custom to compatible type", func(t *testing.T) { + var b bool + err := value.Decode(value.Encapsulate(boolish(10)), &b) + require.NoError(t, err) + require.Equal(t, true, b) + }) + + t.Run("incompatible type to custom", func(t *testing.T) { + var b boolish + err := value.Decode(value.String("true"), &b) + require.EqualError(t, err, "expected capsule, got string") + }) + + t.Run("custom to incompatible type", func(t *testing.T) { + src := boolish(10) + + var s string + err := value.Decode(value.Encapsulate(&src), &s) + require.EqualError(t, err, "expected string, got capsule") + }) +} + +func TestDecode_SquashedFields(t 
*testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr,optional"` + InnerField2 string `river:"inner_field_2,attr,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner InnerStruct `river:",squash"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + var ( + in = map[string]string{ + "outer_field_1": "value1", + "outer_field_2": "value2", + "inner_field_1": "value3", + "inner_field_2": "value4", + } + expect = OuterStruct{ + OuterField1: "value1", + Inner: InnerStruct{ + InnerField1: "value3", + InnerField2: "value4", + }, + OuterField2: "value2", + } + ) + + var out OuterStruct + err := value.Decode(value.Encode(in), &out) + require.NoError(t, err) + require.Equal(t, expect, out) +} + +func TestDecode_SquashedFields_Pointer(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr,optional"` + InnerField2 string `river:"inner_field_2,attr,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner *InnerStruct `river:",squash"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + var ( + in = map[string]string{ + "outer_field_1": "value1", + "outer_field_2": "value2", + "inner_field_1": "value3", + "inner_field_2": "value4", + } + expect = OuterStruct{ + OuterField1: "value1", + Inner: &InnerStruct{ + InnerField1: "value3", + InnerField2: "value4", + }, + OuterField2: "value2", + } + ) + + var out OuterStruct + err := value.Decode(value.Encode(in), &out) + require.NoError(t, err) + require.Equal(t, expect, out) +} + +func TestDecode_Slice(t *testing.T) { + type Block struct { + Attr int `river:"attr,attr"` + } + + type Struct struct { + Blocks []Block `river:"block.a,block,optional"` + } + + var ( + in = map[string]interface{}{ + "block": map[string]interface{}{ + "a": []map[string]interface{}{ + {"attr": 1}, + {"attr": 2}, + {"attr": 3}, + {"attr": 4}, + }, + 
}, + } + expect = Struct{ + Blocks: []Block{ + {Attr: 1}, + {Attr: 2}, + {Attr: 3}, + {Attr: 4}, + }, + } + ) + + var out Struct + err := value.Decode(value.Encode(in), &out) + require.NoError(t, err) + require.Equal(t, expect, out) +} + +func TestDecode_SquashedSlice(t *testing.T) { + type Block struct { + Attr int `river:"attr,attr"` + } + + type InnerStruct struct { + BlockA Block `river:"a,block,optional"` + BlockB Block `river:"b,block,optional"` + BlockC Block `river:"c,block,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner []InnerStruct `river:"block,enum"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + var ( + in = map[string]interface{}{ + "outer_field_1": "value1", + "outer_field_2": "value2", + + "block": []map[string]interface{}{ + {"a": map[string]interface{}{"attr": 1}}, + {"b": map[string]interface{}{"attr": 2}}, + {"c": map[string]interface{}{"attr": 3}}, + {"a": map[string]interface{}{"attr": 4}}, + }, + } + expect = OuterStruct{ + OuterField1: "value1", + OuterField2: "value2", + + Inner: []InnerStruct{ + {BlockA: Block{Attr: 1}}, + {BlockB: Block{Attr: 2}}, + {BlockC: Block{Attr: 3}}, + {BlockA: Block{Attr: 4}}, + }, + } + ) + + var out OuterStruct + err := value.Decode(value.Encode(in), &out) + require.NoError(t, err) + require.Equal(t, expect, out) +} + +func TestDecode_SquashedSlice_Pointer(t *testing.T) { + type Block struct { + Attr int `river:"attr,attr"` + } + + type InnerStruct struct { + BlockA *Block `river:"a,block,optional"` + BlockB *Block `river:"b,block,optional"` + BlockC *Block `river:"c,block,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner []InnerStruct `river:"block,enum"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + var ( + in = map[string]interface{}{ + "outer_field_1": "value1", + "outer_field_2": "value2", + + "block": []map[string]interface{}{ + {"a": 
map[string]interface{}{"attr": 1}}, + {"b": map[string]interface{}{"attr": 2}}, + {"c": map[string]interface{}{"attr": 3}}, + {"a": map[string]interface{}{"attr": 4}}, + }, + } + expect = OuterStruct{ + OuterField1: "value1", + OuterField2: "value2", + + Inner: []InnerStruct{ + {BlockA: &Block{Attr: 1}}, + {BlockB: &Block{Attr: 2}}, + {BlockC: &Block{Attr: 3}}, + {BlockA: &Block{Attr: 4}}, + }, + } + ) + + var out OuterStruct + err := value.Decode(value.Encode(in), &out) + require.NoError(t, err) + require.Equal(t, expect, out) +} + +// TestDecode_KnownTypes_Any asserts that decoding River values into an +// any/interface{} results in known types. +func TestDecode_KnownTypes_Any(t *testing.T) { + tt := []struct { + input any + expect any + }{ + // expect "int" + {int(0), 0}, + {int(-1), -1}, + {int(15), 15}, + {int8(15), 15}, + {int16(15), 15}, + {int32(15), 15}, + {int64(15), 15}, + {uint(0), 0}, + {uint(15), 15}, + {uint8(15), 15}, + {uint16(15), 15}, + {uint32(15), 15}, + {uint64(15), 15}, + {int64(math.MinInt64), math.MinInt64}, + {int64(math.MaxInt64), math.MaxInt64}, + // expect "uint" + {uint64(math.MaxInt64 + 1), uint64(math.MaxInt64 + 1)}, + {uint64(math.MaxUint64), uint64(math.MaxUint64)}, + // expect "float" + {float32(2.5), float64(2.5)}, + {float64(2.5), float64(2.5)}, + {float64(math.MinInt64) - 10, float64(math.MinInt64) - 10}, + {float64(math.MaxInt64) + 10, float64(math.MaxInt64) + 10}, + + {bool(true), bool(true)}, + {string("Hello"), string("Hello")}, + + { + input: []int{1, 2, 3}, + expect: []any{1, 2, 3}, + }, + + { + input: map[string]int{"number": 15}, + expect: map[string]any{"number": 15}, + }, + { + input: struct { + Name string `river:"name,attr"` + }{Name: "John"}, + + expect: map[string]any{"name": "John"}, + }, + } + + t.Run("basic types", func(t *testing.T) { + for _, tc := range tt { + var actual any + err := value.Decode(value.Encode(tc.input), &actual) + + if assert.NoError(t, err) { + assert.Equal(t, tc.expect, actual, + "Expected 
%[1]v (%[1]T) to transcode to %[2]v (%[2]T)", tc.input, tc.expect) + } + } + }) + + t.Run("inside maps", func(t *testing.T) { + for _, tc := range tt { + input := map[string]any{ + "key": tc.input, + } + + var actual map[string]any + err := value.Decode(value.Encode(input), &actual) + + if assert.NoError(t, err) { + assert.Equal(t, tc.expect, actual["key"], + "Expected %[1]v (%[1]T) to transcode to %[2]v (%[2]T) inside a map", tc.input, tc.expect) + } + } + }) +} + +func TestRetainCapsulePointer(t *testing.T) { + capsuleVal := &capsule{} + + in := map[string]any{ + "foo": capsuleVal, + } + + var actual map[string]any + err := value.Decode(value.Encode(in), &actual) + require.NoError(t, err) + + expect := map[string]any{ + "foo": capsuleVal, + } + require.Equal(t, expect, actual) +} + +type capsule struct{} + +func (*capsule) RiverCapsule() {} diff --git a/syntax/internal/value/errors.go b/syntax/internal/value/errors.go new file mode 100644 index 0000000000..79f22378b3 --- /dev/null +++ b/syntax/internal/value/errors.go @@ -0,0 +1,107 @@ +package value + +import "fmt" + +// Error is used for reporting on a value-level error. It is the most general +// type of error for a value. +type Error struct { + Value Value + Inner error +} + +// TypeError is used for reporting on a value having an unexpected type. +type TypeError struct { + // Value which caused the error. + Value Value + Expected Type +} + +// Error returns the string form of the TypeError. +func (te TypeError) Error() string { + return fmt.Sprintf("expected %s, got %s", te.Expected, te.Value.Type()) +} + +// Error returns the message of the decode error. +func (de Error) Error() string { return de.Inner.Error() } + +// MissingKeyError is used for reporting that a value is missing a key. +type MissingKeyError struct { + Value Value + Missing string +} + +// Error returns the string form of the MissingKeyError. 
+func (mke MissingKeyError) Error() string {
+	return fmt.Sprintf("key %q does not exist", mke.Missing)
+}
+
+// ElementError is used to report on an error inside of an array.
+type ElementError struct {
+	Value Value // The Array value
+	Index int   // The index of the element with the issue
+	Inner error // The error from the element
+}
+
+// Error returns only the inner error's text; Index is not part of the message.
+func (ee ElementError) Error() string { return ee.Inner.Error() }
+
+// FieldError is used to report on an invalid field inside an object.
+type FieldError struct {
+	Value Value  // The Object value
+	Field string // The field name with the issue
+	Inner error  // The error from the field
+}
+
+// Error returns only the inner error's text; Field is not part of the message.
+func (fe FieldError) Error() string { return fe.Inner.Error() }
+
+// ArgError is used to report on an invalid argument to a function.
+type ArgError struct {
+	Function Value // The function value that received the argument
+	Argument Value // The argument value with the issue
+	Index    int   // Position of the argument; presumably 0-based — confirm at call sites
+	Inner    error // The error from the argument
+}
+
+// Error returns only the inner error's text; Index is not part of the message.
+func (ae ArgError) Error() string { return ae.Inner.Error() }
+
+// WalkError walks err for all value-related errors in this package.
+// WalkError returns false if err is not an error from this package.
+func WalkError(err error, f func(err error)) bool {
+	var foundOne bool
+
+	nextError := err
+	for nextError != nil {
+		switch ne := nextError.(type) {
+		case Error:
+			f(nextError)
+			nextError = ne.Inner
+			foundOne = true
+		case TypeError:
+			f(nextError)
+			nextError = nil // TypeError has no Inner error, so the walk ends here.
+			foundOne = true
+		case MissingKeyError:
+			f(nextError)
+			nextError = nil // MissingKeyError has no Inner error, so the walk ends here.
+			foundOne = true
+		case ElementError:
+			f(nextError)
+			nextError = ne.Inner
+			foundOne = true
+		case FieldError:
+			f(nextError)
+			nextError = ne.Inner
+			foundOne = true
+		case ArgError:
+			f(nextError)
+			nextError = ne.Inner
+			foundOne = true
+		default:
+			nextError = nil // Not one of this package's error types: stop without calling f.
+		}
+	}
+
+	return foundOne
+}
diff --git a/syntax/internal/value/number_value.go b/syntax/internal/value/number_value.go
new file mode 100644
index 0000000000..c40fbbc802
--- /dev/null
+++ b/syntax/internal/value/number_value.go
@@ -0,0 +1,135 @@
+package value
+
+import (
+	"math"
+	"reflect"
+	"strconv"
+)
+
+var (
+	nativeIntBits  = reflect.TypeOf(int(0)).Bits()
+	nativeUintBits = reflect.TypeOf(uint(0)).Bits()
+)
+
+// NumberKind categorizes a type of Go number.
+type NumberKind uint8
+
+const (
+	// NumberKindInt represents an int-like type (e.g., int, int8, etc.).
+	NumberKindInt NumberKind = iota
+	// NumberKindUint represents a uint-like type (e.g., uint, uint8, etc.).
+	NumberKindUint
+	// NumberKindFloat represents both float32 and float64.
+	NumberKindFloat
+)
+
+// makeNumberKind maps a numeric reflect.Kind to its NumberKind; it panics otherwise.
+func makeNumberKind(k reflect.Kind) NumberKind {
+	switch k {
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		return NumberKindInt
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		return NumberKindUint
+	case reflect.Float32, reflect.Float64:
+		return NumberKindFloat
+	default:
+		panic("river/value: makeNumberKind called with unsupported Kind value")
+	}
+}
+
+// Number is a generic representation of Go numbers. It is intended to be
+// created on the fly for numerical operations when the real number type is not
+// known.
+type Number struct {
+	// value holds the raw data for the number. Note that for NumberKindFloat,
+	// value is the raw bits of the float64 and must be converted back to a
+	// float64 before it can be used.
+	value uint64
+
+	bits uint8      // 8, 16, 32, 64, used for overflow checking
+	k    NumberKind // int, uint, float
+}
+
+func newNumberValue(v reflect.Value) Number { // Panics if v is not an int, uint, or float kind.
+	var (
+		val  uint64
+		bits int
+		nk   NumberKind
+	)
+
+	switch v.Kind() {
+	case reflect.Int:
+		val, bits, nk = uint64(v.Int()), nativeIntBits, NumberKindInt
+	case reflect.Int8:
+		val, bits, nk = uint64(v.Int()), 8, NumberKindInt
+	case reflect.Int16:
+		val, bits, nk = uint64(v.Int()), 16, NumberKindInt
+	case reflect.Int32:
+		val, bits, nk = uint64(v.Int()), 32, NumberKindInt
+	case reflect.Int64:
+		val, bits, nk = uint64(v.Int()), 64, NumberKindInt
+	case reflect.Uint:
+		val, bits, nk = v.Uint(), nativeUintBits, NumberKindUint
+	case reflect.Uint8:
+		val, bits, nk = v.Uint(), 8, NumberKindUint
+	case reflect.Uint16:
+		val, bits, nk = v.Uint(), 16, NumberKindUint
+	case reflect.Uint32:
+		val, bits, nk = v.Uint(), 32, NumberKindUint
+	case reflect.Uint64:
+		val, bits, nk = v.Uint(), 64, NumberKindUint
+	case reflect.Float32:
+		val, bits, nk = math.Float64bits(v.Float()), 32, NumberKindFloat // Stored as float64 bits; bits records the source width.
+	case reflect.Float64:
+		val, bits, nk = math.Float64bits(v.Float()), 64, NumberKindFloat
+	default:
+		panic("river/value: unrecognized Go number type " + v.Kind().String())
+	}
+
+	return Number{val, uint8(bits), nk}
+}
+
+// Kind returns the Number's NumberKind.
+func (nv Number) Kind() NumberKind { return nv.k }
+
+// Int converts the Number into an int64; a float's fractional part is discarded.
+func (nv Number) Int() int64 {
+	if nv.k == NumberKindFloat {
+		return int64(math.Float64frombits(nv.value))
+	}
+	return int64(nv.value)
+}
+
+// Uint converts the Number into a uint64.
+func (nv Number) Uint() uint64 {
+	if nv.k == NumberKindFloat {
+		return uint64(math.Float64frombits(nv.value))
+	}
+	return nv.value
+}
+
+// Float converts the Number into a float64.
+func (nv Number) Float() float64 {
+	switch nv.k {
+	case NumberKindInt:
+		// Convert nv.value to an int64 before converting to a float64 so the sign
+		// flag gets handled correctly.
+		return float64(int64(nv.value))
+	case NumberKindFloat:
+		return math.Float64frombits(nv.value)
+	}
+	return float64(nv.value) // Remaining kind is NumberKindUint: value is used as-is.
+}
+
+// ToString converts the Number to a string.
+func (nv Number) ToString() string {
+	switch nv.k {
+	case NumberKindUint:
+		return strconv.FormatUint(nv.value, 10)
+	case NumberKindInt:
+		return strconv.FormatInt(int64(nv.value), 10)
+	case NumberKindFloat:
+		return strconv.FormatFloat(math.Float64frombits(nv.value), 'f', -1, 64) // 'f' with -1: no exponent, fewest digits that round-trip.
+	}
+	panic("river/value: unreachable")
+}
diff --git a/syntax/internal/value/raw_function.go b/syntax/internal/value/raw_function.go
new file mode 100644
index 0000000000..bf25da916d
--- /dev/null
+++ b/syntax/internal/value/raw_function.go
@@ -0,0 +1,9 @@
+package value
+
+// RawFunction allows creating function implementations using raw River values.
+// This is useful for functions which wish to operate over dynamic types while
+// avoiding decoding to interface{} for performance reasons.
+//
+// The func value itself is provided as an argument so error types can be
+// filled.
+type RawFunction func(funcValue Value, args ...Value) (Value, error)
diff --git a/syntax/internal/value/tag_cache.go b/syntax/internal/value/tag_cache.go
new file mode 100644
index 0000000000..2bce16209d
--- /dev/null
+++ b/syntax/internal/value/tag_cache.go
@@ -0,0 +1,121 @@
+package value
+
+import (
+	"reflect"
+
+	"github.com/grafana/river/internal/rivertags"
+)
+
+// tagsCache caches the river tags for a struct type. This is never cleared,
+// but since most structs will be statically created throughout the lifetime
+// of the process, this will consume a negligible amount of memory.
+var tagsCache = make(map[reflect.Type]*objectFields) // NOTE(review): plain map, read and written by getCachedTags with no lock — confirm callers never run concurrently.
+
+func getCachedTags(t reflect.Type) *objectFields {
+	if t.Kind() != reflect.Struct {
+		panic("getCachedTags called with non-struct type")
+	}
+
+	if entry, ok := tagsCache[t]; ok {
+		return entry
+	}
+
+	ff := rivertags.Get(t)
+
+	// Build a tree of keys.
+	tree := &objectFields{
+		fields:       make(map[string]rivertags.Field),
+		nestedFields: make(map[string]*objectFields),
+		keys:         []string{},
+	}
+
+	for _, f := range ff {
+		if f.Flags&rivertags.FlagLabel != 0 {
+			// Label tags are not keys: remember the field and move on.
+			tree.labelField = f
+			continue
+		}
+
+		node := tree
+		for i, name := range f.Name {
+			// Add to the list of keys if this is a new key.
+			if node.Has(name) == objectKeyTypeInvalid {
+				node.keys = append(node.keys, name)
+			}
+
+			if i+1 == len(f.Name) {
+				// Last fragment, add as a field.
+				node.fields[name] = f
+				continue
+			}
+
+			inner, ok := node.nestedFields[name]
+			if !ok {
+				inner = &objectFields{
+					fields:       make(map[string]rivertags.Field),
+					nestedFields: make(map[string]*objectFields),
+					keys:         []string{},
+				}
+				node.nestedFields[name] = inner
+			}
+			node = inner
+		}
+	}
+
+	tagsCache[t] = tree
+	return tree
+}
+
+// objectFields is a parsed tree of fields in rivertags. Interior nodes are
+// nested fields (e.g., for block names that have multiple name fragments) and
+// leaves are the fields themselves.
+type objectFields struct {
+	fields       map[string]rivertags.Field
+	nestedFields map[string]*objectFields
+	keys         []string // Combination of fields + nestedFields
+	labelField   rivertags.Field
+}
+
+type objectKeyType int
+
+const (
+	objectKeyTypeInvalid objectKeyType = iota
+	objectKeyTypeField
+	objectKeyTypeNestedField
+)
+
+// Has returns whether name exists as a field or a nested key inside keys.
+// Returns objectKeyTypeInvalid if name does not exist as either.
+func (of *objectFields) Has(name string) objectKeyType {
+	if _, ok := of.fields[name]; ok {
+		return objectKeyTypeField
+	}
+	if _, ok := of.nestedFields[name]; ok {
+		return objectKeyTypeNestedField
+	}
+	return objectKeyTypeInvalid
+}
+
+// Len returns the number of named keys.
+func (of *objectFields) Len() int { return len(of.keys) }
+
+// Keys returns all named keys (fields and nested fields).
+func (of *objectFields) Keys() []string { return of.keys }
+
+// Field gets a non-nested field. Returns false if name is unknown or a nested field.
+func (of *objectFields) Field(name string) (rivertags.Field, bool) {
+	f, ok := of.fields[name]
+	return f, ok
+}
+
+// NestedField gets a named nested field entry. Returns false if name is not a
+// nested field.
+func (of *objectFields) NestedField(name string) (*objectFields, bool) {
+	nk, ok := of.nestedFields[name]
+	return nk, ok
+}
+
+// LabelField returns the label field; the bool is false when its Index is nil (none recorded).
+func (of *objectFields) LabelField() (rivertags.Field, bool) {
+	return of.labelField, of.labelField.Index != nil
+}
diff --git a/syntax/internal/value/type.go b/syntax/internal/value/type.go
new file mode 100644
index 0000000000..e79715cbbf
--- /dev/null
+++ b/syntax/internal/value/type.go
@@ -0,0 +1,157 @@
+package value
+
+import (
+	"fmt"
+	"reflect"
+)
+
+// Type represents the type of a River value loosely. For example, a Value may
+// be TypeArray, but this does not imply anything about the type of that
+// array's elements (all of which may be any type).
+//
+// TypeCapsule is a special type which encapsulates arbitrary Go values.
+type Type uint8
+
+// Supported Type values.
+const (
+	TypeNull Type = iota
+	TypeNumber
+	TypeString
+	TypeBool
+	TypeArray
+	TypeObject
+	TypeFunction
+	TypeCapsule
+)
+
+var typeStrings = [...]string{
+	TypeNull:     "null",
+	TypeNumber:   "number",
+	TypeString:   "string",
+	TypeBool:     "bool",
+	TypeArray:    "array",
+	TypeObject:   "object",
+	TypeFunction: "function",
+	TypeCapsule:  "capsule",
+}
+
+// String returns the name of t, or "Type(N)" if t is not one of the known Types.
+func (t Type) String() string {
+	if int(t) < len(typeStrings) {
+		return typeStrings[t]
+	}
+	return fmt.Sprintf("Type(%d)", t)
+}
+
+// GoString returns the name of t.
+func (t Type) GoString() string { return t.String() }
+
+// RiverType returns the River type from the Go type.
+//
+// Go types map to River types using the following rules:
+//
+//  1. Go numbers (ints, uints, floats) map to a River number.
+//  2. Go strings map to a River string.
+//  3. Go bools map to a River bool.
+//  4. Go arrays and slices map to a River array (an object if elements are labeled structs).
+//  5. Go map[string]T map to a River object.
+//  6. Go structs map to a River object, provided they have at least one field
+//     with a river tag.
+//  7. Valid Go functions map to a River function.
+//  8. Go interfaces map to a River capsule.
+//  9. All other Go values map to a River capsule.
+//
+// Go functions are only valid for River if they have one non-error return type
+// (the first return type) and one optional error return type (the second
+// return type). Other function types are treated as capsules.
+//
+// As exceptions, any type which implements the Capsule interface is forced to be
+// a capsule; otherwise encoding.TextMarshaler types and time.Duration are strings.
+func RiverType(t reflect.Type) Type {
+	// We don't know if the RiverCapsule interface is implemented for a pointer
+	// or non-pointer type, so we have to check before and after dereferencing.
+
+	for t.Kind() == reflect.Pointer {
+		switch {
+		case t.Implements(goCapsule):
+			return TypeCapsule
+		case t.Implements(goTextMarshaler):
+			return TypeString
+		}
+
+		t = t.Elem()
+	}
+
+	switch {
+	case t.Implements(goCapsule):
+		return TypeCapsule
+	case t.Implements(goTextMarshaler):
+		return TypeString
+	case t == goDuration:
+		return TypeString
+	}
+
+	switch t.Kind() {
+	case reflect.Invalid:
+		return TypeNull
+
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		return TypeNumber
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		return TypeNumber
+	case reflect.Float32, reflect.Float64:
+		return TypeNumber
+
+	case reflect.String:
+		return TypeString
+
+	case reflect.Bool:
+		return TypeBool
+
+	case reflect.Array, reflect.Slice:
+		if inner := t.Elem(); inner.Kind() == reflect.Struct {
+			if _, labeled := getCachedTags(inner).LabelField(); labeled {
+				// A slice/array of labeled blocks is an object, where each label is a
+				// top-level key.
+				return TypeObject
+			}
+		}
+		return TypeArray
+
+	case reflect.Map:
+		if t.Key() != goString {
+			// Objects must be keyed by string. Anything else is forced to be a
+			// Capsule.
+			return TypeCapsule
+		}
+		return TypeObject
+
+	case reflect.Struct:
+		if getCachedTags(t).Len() == 0 {
+			return TypeCapsule
+		}
+		return TypeObject
+
+	case reflect.Func:
+		switch t.NumOut() {
+		case 1:
+			if t.Out(0) == goError {
+				return TypeCapsule
+			}
+			return TypeFunction
+		case 2:
+			if t.Out(0) == goError || t.Out(1) != goError {
+				return TypeCapsule
+			}
+			return TypeFunction
+		default:
+			return TypeCapsule
+		}
+
+	case reflect.Interface:
+		return TypeCapsule
+
+	default:
+		return TypeCapsule
+	}
+}
diff --git a/syntax/internal/value/type_test.go b/syntax/internal/value/type_test.go
new file mode 100644
index 0000000000..10ee04bc75
--- /dev/null
+++ b/syntax/internal/value/type_test.go
@@ -0,0 +1,80 @@
+package value_test
+
+import (
+	"reflect"
+	"testing"
+
+	"github.com/grafana/river/internal/value"
+	"github.com/stretchr/testify/require"
+)
+
+type customCapsule bool
+
+var _ value.Capsule = (customCapsule)(false)
+
+func (customCapsule) RiverCapsule() {}
+
+var typeTests = []struct {
+	input  interface{}
+	expect value.Type
+}{
+	{int(0), value.TypeNumber},
+	{int8(0), value.TypeNumber},
+	{int16(0), value.TypeNumber},
+	{int32(0), value.TypeNumber},
+	{int64(0), value.TypeNumber},
+	{uint(0), value.TypeNumber},
+	{uint8(0), value.TypeNumber},
+	{uint16(0), value.TypeNumber},
+	{uint32(0), value.TypeNumber},
+	{uint64(0), value.TypeNumber},
+	{float32(0), value.TypeNumber},
+	{float64(0), value.TypeNumber},
+
+	{string(""), value.TypeString},
+
+	{bool(false), value.TypeBool},
+
+	{[...]int{0, 1, 2}, value.TypeArray},
+	{[]int{0, 1, 2}, value.TypeArray},
+
+	// Struct with no River tags is a capsule.
+	{struct{}{}, value.TypeCapsule},
+
+	// A slice of labeled blocks should be an object.
+	{[]struct {
+		Label string `river:",label"`
+	}{}, value.TypeObject},
+
+	{map[string]interface{}{}, value.TypeObject},
+
+	// Go functions must have one non-error return type and one optional error
+	// return type to be River functions. Everything else is a capsule.
+	{(func() int)(nil), value.TypeFunction},
+	{(func() (int, error))(nil), value.TypeFunction},
+	{(func())(nil), value.TypeCapsule},                 // Must have non-error return type
+	{(func() error)(nil), value.TypeCapsule},           // First return type must be non-error
+	{(func() (error, int))(nil), value.TypeCapsule},    // First return type must be non-error
+	{(func() (error, error))(nil), value.TypeCapsule},  // First return type must be non-error
+	{(func() (int, int))(nil), value.TypeCapsule},      // Second return type must be error
+	{(func() (int, int, int))(nil), value.TypeCapsule}, // Can only have 1 or 2 return types
+
+	{make(chan struct{}), value.TypeCapsule},
+	{map[bool]interface{}{}, value.TypeCapsule}, // Maps with non-string keys are capsules
+
+	// Types with capsule markers should be capsules.
+	{customCapsule(false), value.TypeCapsule},
+	{(*customCapsule)(nil), value.TypeCapsule},
+	{(**customCapsule)(nil), value.TypeCapsule},
+}
+
+func Test_RiverType(t *testing.T) {
+	for _, tc := range typeTests {
+		rt := reflect.TypeOf(tc.input)
+
+		t.Run(rt.String(), func(t *testing.T) {
+			actual := value.RiverType(rt)
+			require.Equal(t, tc.expect, actual, "Unexpected type for %#v", tc.input)
+		})
+	}
+}
diff --git a/syntax/internal/value/value.go b/syntax/internal/value/value.go
new file mode 100644
index 0000000000..bdd8492c09
--- /dev/null
+++ b/syntax/internal/value/value.go
@@ -0,0 +1,556 @@
+// Package value holds the internal representation for River values. River
+// values act as a lightweight wrapper around reflect.Value.
+package value
+
+import (
+	"encoding"
+	"fmt"
+	"reflect"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/grafana/river/internal/reflectutil"
+)
+
+// Go types used throughout the package.
+var (
+	goAny             = reflect.TypeOf((*interface{})(nil)).Elem()
+	goString          = reflect.TypeOf(string(""))
+	goByteSlice       = reflect.TypeOf([]byte(nil))
+	goError           = reflect.TypeOf((*error)(nil)).Elem()
+	goTextMarshaler   = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
+	goTextUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
+	goStructWrapper   = reflect.TypeOf(structWrapper{})
+	goCapsule         = reflect.TypeOf((*Capsule)(nil)).Elem()
+	goDuration        = reflect.TypeOf((time.Duration)(0))
+	goDurationPtr     = reflect.TypeOf((*time.Duration)(nil))
+	goRiverDefaulter  = reflect.TypeOf((*Defaulter)(nil)).Elem()
+	goRiverDecoder    = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
+	goRiverValidator  = reflect.TypeOf((*Validator)(nil)).Elem()
+	goRawRiverFunc    = reflect.TypeOf((RawFunction)(nil))
+	goRiverValue      = reflect.TypeOf(Null)
+)
+
+// NOTE(rfratto): This package is extremely sensitive to performance, so
+// changes should be made with caution; run benchmarks when changing things.
+//
+// Value is optimized to be as small as possible and exist fully on the stack.
+// This allows many values to avoid allocations, with the exception of creating
+// arrays and objects.
+
+// Value represents a River value.
+type Value struct {
+	rv reflect.Value
+	ty Type
+}
+
+// Null is the null value.
+var Null = Value{}
+
+// Uint returns a Value from a uint64.
+func Uint(u uint64) Value { return Value{rv: reflect.ValueOf(u), ty: TypeNumber} }
+
+// Int returns a Value from an int64.
+func Int(i int64) Value { return Value{rv: reflect.ValueOf(i), ty: TypeNumber} }
+
+// Float returns a Value from a float64.
+func Float(f float64) Value { return Value{rv: reflect.ValueOf(f), ty: TypeNumber} }
+
+// String returns a Value from a string.
+func String(s string) Value { return Value{rv: reflect.ValueOf(s), ty: TypeString} }
+
+// Bool returns a Value from a bool.
+func Bool(b bool) Value { return Value{rv: reflect.ValueOf(b), ty: TypeBool} }
+
+// Object returns a new value from m. The entries of m are not copied: the Value
+// wraps the same map, so later writes to m are visible through the Value.
+func Object(m map[string]Value) Value {
+	return Value{
+		rv: reflect.ValueOf(m),
+		ty: TypeObject,
+	}
+}
+
+// Array creates an array from the given values. The elements are not copied: when
+// called as Array(s...), the Value shares s's backing array with the caller.
+func Array(vv ...Value) Value {
+	return Value{
+		rv: reflect.ValueOf(vv),
+		ty: TypeArray,
+	}
+}
+
+// Func makes a new function Value from f. Func panics if f does not map to a
+// River function.
+func Func(f interface{}) Value {
+	rv := reflect.ValueOf(f)
+	if RiverType(rv.Type()) != TypeFunction {
+		panic("river/value: Func called with non-function type")
+	}
+	return Value{rv: rv, ty: TypeFunction}
+}
+
+// Encapsulate creates a new Capsule value from v. Encapsulate panics if v does
+// not map to a River capsule.
+func Encapsulate(v interface{}) Value {
+	rv := reflect.ValueOf(v)
+	if RiverType(rv.Type()) != TypeCapsule {
+		panic("river/value: Encapsulate called with non-capsule type")
+	}
+	return Value{rv: rv, ty: TypeCapsule}
+}
+
+// Encode creates a new Value from v. If v is a pointer, v must be considered
+// immutable and not change while the Value is used.
+func Encode(v interface{}) Value {
+	if v == nil {
+		return Null
+	}
+	return makeValue(reflect.ValueOf(v))
+}
+
+// FromRaw converts a reflect.Value into a River Value. It is useful to prevent
+// downcasting an interface into an any.
+func FromRaw(v reflect.Value) Value {
+	return makeValue(v)
+}
+
+// Type returns the River type for the value.
+func (v Value) Type() Type { return v.ty }
+
+// Describe returns a descriptive type name for the value. For capsule values,
+// this prints the underlying Go type name. For other values, it prints the
+// normal River type.
+func (v Value) Describe() string {
+	if v.ty != TypeCapsule {
+		return v.ty.String()
+	}
+	return fmt.Sprintf("capsule(%q)", v.rv.Type())
+}
+
+// Bool returns the boolean value for v. It panics if v is not a bool.
+func (v Value) Bool() bool { + if v.ty != TypeBool { + panic("river/value: Bool called on non-bool type") + } + return v.rv.Bool() +} + +// Number returns a Number value for v. It panics if v is not a Number. +func (v Value) Number() Number { + if v.ty != TypeNumber { + panic("river/value: Number called on non-number type") + } + return newNumberValue(v.rv) +} + +// Int returns an int value for v. It panics if v is not a number. +func (v Value) Int() int64 { + if v.ty != TypeNumber { + panic("river/value: Int called on non-number type") + } + switch makeNumberKind(v.rv.Kind()) { + case NumberKindInt: + return v.rv.Int() + case NumberKindUint: + return int64(v.rv.Uint()) + case NumberKindFloat: + return int64(v.rv.Float()) + } + panic("river/value: unreachable") +} + +// Uint returns an uint value for v. It panics if v is not a number. +func (v Value) Uint() uint64 { + if v.ty != TypeNumber { + panic("river/value: Uint called on non-number type") + } + switch makeNumberKind(v.rv.Kind()) { + case NumberKindInt: + return uint64(v.rv.Int()) + case NumberKindUint: + return v.rv.Uint() + case NumberKindFloat: + return uint64(v.rv.Float()) + } + panic("river/value: unreachable") +} + +// Float returns a float value for v. It panics if v is not a number. +func (v Value) Float() float64 { + if v.ty != TypeNumber { + panic("river/value: Float called on non-number type") + } + switch makeNumberKind(v.rv.Kind()) { + case NumberKindInt: + return float64(v.rv.Int()) + case NumberKindUint: + return float64(v.rv.Uint()) + case NumberKindFloat: + return v.rv.Float() + } + panic("river/value: unreachable") +} + +// Text returns a string value of v. It panics if v is not a string. +func (v Value) Text() string { + if v.ty != TypeString { + panic("river/value: Text called on non-string type") + } + + // Attempt to get an address to v.rv for interface checking. + // + // The normal v.rv value is used for other checks. 
+ addrRV := v.rv + if addrRV.CanAddr() { + addrRV = addrRV.Addr() + } + switch { + case addrRV.Type().Implements(goTextMarshaler): + // TODO(rfratto): what should we do if this fails? + text, _ := addrRV.Interface().(encoding.TextMarshaler).MarshalText() + return string(text) + + case v.rv.Type() == goDuration: + // Special case: v.rv is a duration and its String method should be used. + return v.rv.Interface().(time.Duration).String() + + default: + return v.rv.String() + } +} + +// Len returns the length of v. Panics if v is not an array or object. +func (v Value) Len() int { + switch v.ty { + case TypeArray: + return v.rv.Len() + case TypeObject: + switch { + case v.rv.Type() == goStructWrapper: + return v.rv.Interface().(structWrapper).Len() + case v.rv.Kind() == reflect.Array, v.rv.Kind() == reflect.Slice: // Array of labeled blocks + return v.rv.Len() + case v.rv.Kind() == reflect.Struct: + return getCachedTags(v.rv.Type()).Len() + case v.rv.Kind() == reflect.Map: + return v.rv.Len() + } + } + panic("river/value: Len called on non-array and non-object value") +} + +// Index returns index i of the Value. Panics if the value is not an array or +// if it is out of bounds of the array's size. +func (v Value) Index(i int) Value { + if v.ty != TypeArray { + panic("river/value: Index called on non-array value") + } + return makeValue(v.rv.Index(i)) +} + +// Interface returns the underlying Go value for the Value. +func (v Value) Interface() interface{} { + if v.ty == TypeNull { + return nil + } + return v.rv.Interface() +} + +// Reflect returns the raw reflection value backing v. +func (v Value) Reflect() reflect.Value { return v.rv } + +// makeValue converts a reflect value into a Value, dereferencing any pointers or +// interface{} values. +func makeValue(v reflect.Value) Value { + // Early check: if v is interface{}, we need to deference it to get the + // concrete value. 
+ if v.IsValid() && v.Type() == goAny { + v = v.Elem() + } + + // Special case: a reflect.Value may be a value.Value when it's coming from a + // River array or object. We can unwrap the inner value here before + // continuing. + if v.IsValid() && v.Type() == goRiverValue { + // Unwrap the inner value. + v = v.Interface().(Value).rv + } + + // Before we get the River type of the Value, we need to see if it's possible + // to get a pointer to v. This ensures that if v is a non-pointer field of an + // addressable struct, still detect the type of v as if it was a pointer. + if v.CanAddr() { + v = v.Addr() + } + + if !v.IsValid() { + return Null + } + riverType := RiverType(v.Type()) + + // Finally, deference the pointer fully and use the type we detected. + for v.Kind() == reflect.Pointer { + if v.IsNil() { + return Null + } + v = v.Elem() + } + return Value{rv: v, ty: riverType} +} + +// OrderedKeys reports if v represents an object with consistently ordered +// keys. It panics if v is not an object. +func (v Value) OrderedKeys() bool { + if v.ty != TypeObject { + panic("river/value: OrderedKeys called on non-object value") + } + + // Maps are the only type of unordered River object, since their keys can't + // be iterated over in a deterministic order. Every other type of River + // object comes from a struct or a slice where the order of keys stays the + // same. + return v.rv.Kind() != reflect.Map +} + +// Keys returns the keys in v in unspecified order. It panics if v is not an +// object. +func (v Value) Keys() []string { + if v.ty != TypeObject { + panic("river/value: Keys called on non-object value") + } + + switch { + case v.rv.Type() == goStructWrapper: + return v.rv.Interface().(structWrapper).Keys() + + case v.rv.Kind() == reflect.Struct: + return wrapStruct(v.rv, true).Keys() + + case v.rv.Kind() == reflect.Array, v.rv.Kind() == reflect.Slice: + // List of labeled blocks. 
+ labelField, _ := getCachedTags(v.rv.Type().Elem()).LabelField() + + keys := make([]string, v.rv.Len()) + for i := range keys { + keys[i] = reflectutil.Get(v.rv.Index(i), labelField).String() + } + return keys + + case v.rv.Kind() == reflect.Map: + reflectKeys := v.rv.MapKeys() + res := make([]string, len(reflectKeys)) + for i, rk := range reflectKeys { + res[i] = rk.String() + } + return res + } + + panic("river/value: unreachable") +} + +// Key returns the value for a key in v. It panics if v is not an object. ok +// will be false if the key did not exist in the object. +func (v Value) Key(key string) (index Value, ok bool) { + if v.ty != TypeObject { + panic("river/value: Key called on non-object value") + } + + switch { + case v.rv.Type() == goStructWrapper: + return v.rv.Interface().(structWrapper).Key(key) + case v.rv.Kind() == reflect.Struct: + // We return the struct with the label intact. + return wrapStruct(v.rv, true).Key(key) + case v.rv.Kind() == reflect.Map: + val := v.rv.MapIndex(reflect.ValueOf(key)) + if !val.IsValid() { + return Null, false + } + return makeValue(val), true + + case v.rv.Kind() == reflect.Slice, v.rv.Kind() == reflect.Array: + // List of labeled blocks. + labelField, _ := getCachedTags(v.rv.Type().Elem()).LabelField() + + for i := 0; i < v.rv.Len(); i++ { + elem := v.rv.Index(i) + + label := reflectutil.Get(elem, labelField).String() + if label == key { + // We discard the label since the key here represents the label value. + ws := wrapStruct(elem, false) + return ws.Value(), true + } + } + default: + panic("river/value: unreachable") + } + + return +} + +// Call invokes a function value with the provided arguments. It panics if v is +// not a function. If v is a variadic function, args should be the full flat +// list of arguments. +// +// An ArgError will be returned if one of the arguments is invalid. An Error +// will be returned if the function call returns an error or if the number of +// arguments doesn't match. 
+func (v Value) Call(args ...Value) (Value, error) { + if v.ty != TypeFunction { + panic("river/value: Call called on non-function type") + } + + if v.rv.Type() == goRawRiverFunc { + return v.rv.Interface().(RawFunction)(v, args...) + } + + var ( + variadic = v.rv.Type().IsVariadic() + expectedArgs = v.rv.Type().NumIn() + ) + + if variadic && len(args) < expectedArgs-1 { + return Null, Error{ + Value: v, + Inner: fmt.Errorf("expected at least %d args, got %d", expectedArgs-1, len(args)), + } + } else if !variadic && len(args) != expectedArgs { + return Null, Error{ + Value: v, + Inner: fmt.Errorf("expected %d args, got %d", expectedArgs, len(args)), + } + } + + reflectArgs := make([]reflect.Value, len(args)) + for i, arg := range args { + var argVal reflect.Value + if variadic && i >= expectedArgs-1 { + argType := v.rv.Type().In(expectedArgs - 1).Elem() + argVal = reflect.New(argType).Elem() + } else { + argType := v.rv.Type().In(i) + argVal = reflect.New(argType).Elem() + } + + var d decoder + if err := d.decode(arg, argVal); err != nil { + return Null, ArgError{ + Function: v, + Argument: arg, + Index: i, + Inner: err, + } + } + + reflectArgs[i] = argVal + } + + outs := v.rv.Call(reflectArgs) + switch len(outs) { + case 1: + return makeValue(outs[0]), nil + case 2: + // When there's 2 return values, the second is always an error. + err, _ := outs[1].Interface().(error) + if err != nil { + return Null, Error{Value: v, Inner: err} + } + return makeValue(outs[0]), nil + + default: + // It's not possible to reach here; we enforce that function values always + // have 1 or 2 return values. + panic("river/value: unreachable") + } +} + +func convertValue(val Value, toType Type) (Value, error) { + // TODO(rfratto): Use vm benchmarks to see if making this a method on Value + // changes anything. + + fromType := val.Type() + + if fromType == toType { + // no-op: val is already the right kind. 
+ return val, nil + } + + switch fromType { + case TypeNumber: + switch toType { + case TypeString: // number -> string + strVal := newNumberValue(val.rv).ToString() + return makeValue(reflect.ValueOf(strVal)), nil + } + + case TypeString: + sourceStr := val.rv.String() + + switch toType { + case TypeNumber: // string -> number + switch { + case sourceStr == "": + return Null, TypeError{Value: val, Expected: toType} + + case sourceStr[0] == '-': + // String starts with a -; parse as a signed int. + parsed, err := strconv.ParseInt(sourceStr, 10, 64) + if err != nil { + return Null, TypeError{Value: val, Expected: toType} + } + return Int(parsed), nil + case strings.ContainsAny(sourceStr, ".eE"): + // String contains something that a floating-point number would use; + // convert. + parsed, err := strconv.ParseFloat(sourceStr, 64) + if err != nil { + return Null, TypeError{Value: val, Expected: toType} + } + return Float(parsed), nil + default: + // Otherwise, treat the number as an unsigned int. 
+ parsed, err := strconv.ParseUint(sourceStr, 10, 64) + if err != nil { + return Null, TypeError{Value: val, Expected: toType} + } + return Uint(parsed), nil + } + } + } + + return Null, TypeError{Value: val, Expected: toType} +} + +func convertGoNumber(nval Number, target reflect.Type) reflect.Value { + switch target.Kind() { + case reflect.Int: + return reflect.ValueOf(int(nval.Int())) + case reflect.Int8: + return reflect.ValueOf(int8(nval.Int())) + case reflect.Int16: + return reflect.ValueOf(int16(nval.Int())) + case reflect.Int32: + return reflect.ValueOf(int32(nval.Int())) + case reflect.Int64: + return reflect.ValueOf(nval.Int()) + case reflect.Uint: + return reflect.ValueOf(uint(nval.Uint())) + case reflect.Uint8: + return reflect.ValueOf(uint8(nval.Uint())) + case reflect.Uint16: + return reflect.ValueOf(uint16(nval.Uint())) + case reflect.Uint32: + return reflect.ValueOf(uint32(nval.Uint())) + case reflect.Uint64: + return reflect.ValueOf(nval.Uint()) + case reflect.Float32: + return reflect.ValueOf(float32(nval.Float())) + case reflect.Float64: + return reflect.ValueOf(nval.Float()) + } + + panic("unsupported number conversion") +} diff --git a/syntax/internal/value/value_object.go b/syntax/internal/value/value_object.go new file mode 100644 index 0000000000..6a642cb22f --- /dev/null +++ b/syntax/internal/value/value_object.go @@ -0,0 +1,119 @@ +package value + +import ( + "reflect" + + "github.com/grafana/river/internal/reflectutil" +) + +// structWrapper allows for partially traversing structs which contain fields +// representing blocks. This is required due to how block names and labels +// change the object representation. +// +// If a block name is a.b.c, then it is represented as three nested objects: +// +// { +// a = { +// b = { +// c = { /* block contents */ }, +// }, +// } +// } +// +// Similarly, if a block name is labeled (a.b.c "label"), then the label is the +// top-level key after c. 
+// +// structWrapper exposes Len, Keys, and Key methods similar to Value to allow +// traversing through the synthetic object. The values it returns are +// structWrappers. +// +// Code in value.go MUST check to see if a struct is a structWrapper *before* +// checking the value kind to ensure the appropriate methods are invoked. +type structWrapper struct { + structVal reflect.Value + fields *objectFields + label string // Non-empty string if this struct is wrapped in a label. +} + +func wrapStruct(val reflect.Value, keepLabel bool) structWrapper { + if val.Kind() != reflect.Struct { + panic("river/value: wrapStruct called on non-struct value") + } + + fields := getCachedTags(val.Type()) + + var label string + if f, ok := fields.LabelField(); ok && keepLabel { + label = reflectutil.Get(val, f).String() + } + + return structWrapper{ + structVal: val, + fields: fields, + label: label, + } +} + +// Value turns sw into a value. +func (sw structWrapper) Value() Value { + return Value{ + rv: reflect.ValueOf(sw), + ty: TypeObject, + } +} + +func (sw structWrapper) Len() int { + if len(sw.label) > 0 { + return 1 + } + return sw.fields.Len() +} + +func (sw structWrapper) Keys() []string { + if len(sw.label) > 0 { + return []string{sw.label} + } + return sw.fields.Keys() +} + +func (sw structWrapper) Key(key string) (index Value, ok bool) { + if len(sw.label) > 0 { + if key != sw.label { + return + } + next := reflect.ValueOf(structWrapper{ + structVal: sw.structVal, + fields: sw.fields, + // Unset the label now that we've traversed it + }) + return Value{rv: next, ty: TypeObject}, true + } + + keyType := sw.fields.Has(key) + + switch keyType { + case objectKeyTypeInvalid: + return // No such key + + case objectKeyTypeNestedField: + // Continue traversing. 
+ nextNode, _ := sw.fields.NestedField(key) + return Value{ + rv: reflect.ValueOf(structWrapper{ + structVal: sw.structVal, + fields: nextNode, + }), + ty: TypeObject, + }, true + + case objectKeyTypeField: + f, _ := sw.fields.Field(key) + val, err := sw.structVal.FieldByIndexErr(f.Index) + if err != nil { + return Null, true + } + return makeValue(val), true + } + + panic("river/value: unreachable") +} diff --git a/syntax/internal/value/value_object_test.go b/syntax/internal/value/value_object_test.go new file mode 100644 index 0000000000..56d72a6102 --- /dev/null +++ b/syntax/internal/value/value_object_test.go @@ -0,0 +1,205 @@ +package value_test + +import ( + "testing" + + "github.com/grafana/river/internal/value" + "github.com/stretchr/testify/require" +) + +// TestBlockRepresentation ensures that the struct tags for blocks are +// represented correctly. +func TestBlockRepresentation(t *testing.T) { + type UnlabledBlock struct { + Value int `river:"value,attr"` + } + type LabeledBlock struct { + Value int `river:"value,attr"` + Label string `river:",label"` + } + type OuterBlock struct { + Attr1 string `river:"attr_1,attr"` + Attr2 string `river:"attr_2,attr"` + + UnlabledBlock1 UnlabledBlock `river:"unlabeled.a,block"` + UnlabledBlock2 UnlabledBlock `river:"unlabeled.b,block"` + UnlabledBlock3 UnlabledBlock `river:"other_unlabeled,block"` + + LabeledBlock1 LabeledBlock `river:"labeled.a,block"` + LabeledBlock2 LabeledBlock `river:"labeled.b,block"` + LabeledBlock3 LabeledBlock `river:"other_labeled,block"` + } + + val := OuterBlock{ + Attr1: "value_1", + Attr2: "value_2", + UnlabledBlock1: UnlabledBlock{ + Value: 1, + }, + UnlabledBlock2: UnlabledBlock{ + Value: 2, + }, + UnlabledBlock3: UnlabledBlock{ + Value: 3, + }, + LabeledBlock1: LabeledBlock{ + Value: 4, + Label: "label_a", + }, + LabeledBlock2: LabeledBlock{ + Value: 5, + Label: "label_b", + }, + LabeledBlock3: LabeledBlock{ + Value: 6, + Label: "label_c", + }, + } + + t.Run("Map decode", func(t 
*testing.T) { + var m map[string]interface{} + require.NoError(t, value.Decode(value.Encode(val), &m)) + + type object = map[string]interface{} + + expect := object{ + "attr_1": "value_1", + "attr_2": "value_2", + "unlabeled": object{ + "a": object{"value": 1}, + "b": object{"value": 2}, + }, + "other_unlabeled": object{"value": 3}, + "labeled": object{ + "a": object{ + "label_a": object{"value": 4}, + }, + "b": object{ + "label_b": object{"value": 5}, + }, + }, + "other_labeled": object{ + "label_c": object{"value": 6}, + }, + } + + require.Equal(t, m, expect) + }) + + t.Run("Object decode from other object", func(t *testing.T) { + // Decode into a separate type which is structurally identical but not + // literally the same. + type OuterBlock2 OuterBlock + + var actualVal OuterBlock2 + require.NoError(t, value.Decode(value.Encode(val), &actualVal)) + require.Equal(t, val, OuterBlock(actualVal)) + }) +} + +// TestSquashedBlockRepresentation ensures that the struct tags for squashed +// blocks are represented correctly. 
+func TestSquashedBlockRepresentation(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr,optional"` + InnerField2 string `river:"inner_field_2,attr,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner InnerStruct `river:",squash"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + val := OuterStruct{ + OuterField1: "value1", + Inner: InnerStruct{ + InnerField1: "value3", + InnerField2: "value4", + }, + OuterField2: "value2", + } + + t.Run("Map decode", func(t *testing.T) { + var m map[string]interface{} + require.NoError(t, value.Decode(value.Encode(val), &m)) + + type object = map[string]interface{} + + expect := object{ + "outer_field_1": "value1", + "inner_field_1": "value3", + "inner_field_2": "value4", + "outer_field_2": "value2", + } + + require.Equal(t, m, expect) + }) +} + +func TestSliceOfBlocks(t *testing.T) { + type UnlabledBlock struct { + Value int `river:"value,attr"` + } + type LabeledBlock struct { + Value int `river:"value,attr"` + Label string `river:",label"` + } + type OuterBlock struct { + Attr1 string `river:"attr_1,attr"` + Attr2 string `river:"attr_2,attr"` + + Unlabeled []UnlabledBlock `river:"unlabeled,block"` + Labeled []LabeledBlock `river:"labeled,block"` + } + + val := OuterBlock{ + Attr1: "value_1", + Attr2: "value_2", + Unlabeled: []UnlabledBlock{ + {Value: 1}, + {Value: 2}, + {Value: 3}, + }, + Labeled: []LabeledBlock{ + {Label: "label_a", Value: 4}, + {Label: "label_b", Value: 5}, + {Label: "label_c", Value: 6}, + }, + } + + t.Run("Map decode", func(t *testing.T) { + var m map[string]interface{} + require.NoError(t, value.Decode(value.Encode(val), &m)) + + type object = map[string]interface{} + type list = []interface{} + + expect := object{ + "attr_1": "value_1", + "attr_2": "value_2", + "unlabeled": list{ + object{"value": 1}, + object{"value": 2}, + object{"value": 3}, + }, + "labeled": object{ + "label_a": 
object{"value": 4}, + "label_b": object{"value": 5}, + "label_c": object{"value": 6}, + }, + } + + require.Equal(t, m, expect) + }) + + t.Run("Object decode from other object", func(t *testing.T) { + // Decode into a separate type which is structurally identical but not + // literally the same. + type OuterBlock2 OuterBlock + + var actualVal OuterBlock2 + require.NoError(t, value.Decode(value.Encode(val), &actualVal)) + require.Equal(t, val, OuterBlock(actualVal)) + }) +} diff --git a/syntax/internal/value/value_test.go b/syntax/internal/value/value_test.go new file mode 100644 index 0000000000..4583e5196d --- /dev/null +++ b/syntax/internal/value/value_test.go @@ -0,0 +1,243 @@ +package value_test + +import ( + "fmt" + "io" + "testing" + + "github.com/grafana/river/internal/value" + "github.com/stretchr/testify/require" +) + +// TestEncodeKeyLookup tests where Go values are retained correctly +// throughout values with a key lookup. +func TestEncodeKeyLookup(t *testing.T) { + type Body struct { + Data pointerMarshaler `river:"data,attr"` + } + + tt := []struct { + name string + encodeTarget any + key string + + expectBodyType value.Type + expectKeyExists bool + expectKeyValue value.Value + expectKeyType value.Type + }{ + { + name: "Struct Encode data Key", + encodeTarget: &Body{}, + key: "data", + expectBodyType: value.TypeObject, + expectKeyExists: true, + expectKeyValue: value.String("Hello, world!"), + expectKeyType: value.TypeString, + }, + { + name: "Struct Encode Missing Key", + encodeTarget: &Body{}, + key: "missing", + expectBodyType: value.TypeObject, + expectKeyExists: false, + expectKeyValue: value.Null, + expectKeyType: value.TypeNull, + }, + { + name: "Map Encode data Key", + encodeTarget: map[string]string{"data": "Hello, world!"}, + key: "data", + expectBodyType: value.TypeObject, + expectKeyExists: true, + expectKeyValue: value.String("Hello, world!"), + expectKeyType: value.TypeString, + }, + { + name: "Map Encode Missing Key", + encodeTarget: 
map[string]string{"data": "Hello, world!"}, + key: "missing", + expectBodyType: value.TypeObject, + expectKeyExists: false, + expectKeyValue: value.Null, + expectKeyType: value.TypeNull, + }, + { + name: "Map Encode empty value Key", + encodeTarget: map[string]string{"data": ""}, + key: "data", + expectBodyType: value.TypeObject, + expectKeyExists: true, + expectKeyValue: value.String(""), + expectKeyType: value.TypeString, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + bodyVal := value.Encode(tc.encodeTarget) + require.Equal(t, tc.expectBodyType, bodyVal.Type()) + + val, ok := bodyVal.Key(tc.key) + require.Equal(t, tc.expectKeyExists, ok) + require.Equal(t, tc.expectKeyType, val.Type()) + switch val.Type() { + case value.TypeString: + require.Equal(t, tc.expectKeyValue.Text(), val.Text()) + case value.TypeNull: + require.Equal(t, tc.expectKeyValue, val) + default: + require.Fail(t, "unexpected value type (this switch can be expanded)") + } + }) + } +} + +// TestEncodeNoKeyLookup tests where Go values are retained correctly +// throughout values without a key lookup. 
+func TestEncodeNoKeyLookup(t *testing.T) { + tt := []struct { + name string + encodeTarget any + key string + + expectBodyType value.Type + expectBodyText string + }{ + { + name: "Encode", + encodeTarget: &pointerMarshaler{}, + expectBodyType: value.TypeString, + expectBodyText: "Hello, world!", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + bodyVal := value.Encode(tc.encodeTarget) + require.Equal(t, tc.expectBodyType, bodyVal.Type()) + require.Equal(t, "Hello, world!", bodyVal.Text()) + }) + } +} + +type pointerMarshaler struct{} + +func (*pointerMarshaler) MarshalText() ([]byte, error) { + return []byte("Hello, world!"), nil +} + +func TestValue_Call(t *testing.T) { + t.Run("simple", func(t *testing.T) { + add := func(a, b int) int { return a + b } + addVal := value.Encode(add) + + res, err := addVal.Call( + value.Int(15), + value.Int(43), + ) + require.NoError(t, err) + require.Equal(t, int64(15+43), res.Int()) + }) + + t.Run("fully variadic", func(t *testing.T) { + add := func(nums ...int) int { + var sum int + for _, num := range nums { + sum += num + } + return sum + } + addVal := value.Encode(add) + + t.Run("no args", func(t *testing.T) { + res, err := addVal.Call() + require.NoError(t, err) + require.Equal(t, int64(0), res.Int()) + }) + + t.Run("one arg", func(t *testing.T) { + res, err := addVal.Call(value.Int(32)) + require.NoError(t, err) + require.Equal(t, int64(32), res.Int()) + }) + + t.Run("many args", func(t *testing.T) { + res, err := addVal.Call( + value.Int(32), + value.Int(59), + value.Int(12), + ) + require.NoError(t, err) + require.Equal(t, int64(32+59+12), res.Int()) + }) + }) + + t.Run("partially variadic", func(t *testing.T) { + add := func(firstNum int, nums ...int) int { + sum := firstNum + for _, num := range nums { + sum += num + } + return sum + } + addVal := value.Encode(add) + + t.Run("no variadic args", func(t *testing.T) { + res, err := addVal.Call(value.Int(52)) + require.NoError(t, err) + 
require.Equal(t, int64(52), res.Int()) + }) + + t.Run("one variadic arg", func(t *testing.T) { + res, err := addVal.Call(value.Int(52), value.Int(32)) + require.NoError(t, err) + require.Equal(t, int64(52+32), res.Int()) + }) + + t.Run("many variadic args", func(t *testing.T) { + res, err := addVal.Call( + value.Int(32), + value.Int(59), + value.Int(12), + ) + require.NoError(t, err) + require.Equal(t, int64(32+59+12), res.Int()) + }) + }) + + t.Run("returns error", func(t *testing.T) { + failWhenTrue := func(val bool) (int, error) { + if val { + return 0, fmt.Errorf("function failed for a very good reason") + } + return 0, nil + } + funcVal := value.Encode(failWhenTrue) + + t.Run("no error", func(t *testing.T) { + res, err := funcVal.Call(value.Bool(false)) + require.NoError(t, err) + require.Equal(t, int64(0), res.Int()) + }) + + t.Run("error", func(t *testing.T) { + _, err := funcVal.Call(value.Bool(true)) + require.EqualError(t, err, "function failed for a very good reason") + }) + }) +} + +func TestValue_Interface_In_Array(t *testing.T) { + type Container struct { + Field io.Closer `river:"field,attr"` + } + + val := value.Encode(Container{Field: io.NopCloser(nil)}) + fieldVal, ok := val.Key("field") + require.True(t, ok, "field not found in object") + require.Equal(t, value.TypeCapsule, fieldVal.Type()) + + arrVal := value.Array(fieldVal) + require.Equal(t, value.TypeCapsule, arrVal.Index(0).Type()) +} diff --git a/syntax/parser/error_test.go b/syntax/parser/error_test.go new file mode 100644 index 0000000000..feb8602e31 --- /dev/null +++ b/syntax/parser/error_test.go @@ -0,0 +1,148 @@ +package parser + +import ( + "os" + "path/filepath" + "regexp" + "strings" + "testing" + + "github.com/grafana/river/diag" + "github.com/grafana/river/scanner" + "github.com/grafana/river/token" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// This file implements a parser test harness. 
The files in the testdata
// directory are parsed and the errors reported are compared against the error
// messages expected in the test files.
//
// Expected errors are indicated in the test files by putting a comment of the
// form /* ERROR "rx" */ immediately following an offending token. The harness
// will verify that an error matching the regular expression rx is reported at
// that source position.

// ERROR comments must be of the form /* ERROR "rx" */ and rx is a regular
// expression that matches the expected error message. The special form
// /* ERROR HERE "rx" */ must be used for error messages that appear immediately
// after a token rather than at a token's position.
var errRx = regexp.MustCompile(`^/\* *ERROR *(HERE)? *"([^"]*)" *\*/$`)

// expectedErrors scans src for ERROR comments and returns a map from the
// position each error is expected at to the regular expression its message
// must match.
func expectedErrors(file *token.File, src []byte) map[token.Pos]string {
	expected := make(map[token.Pos]string)

	sc := scanner.New(file, src, nil, scanner.IncludeComments)

	var (
		lastTok   token.Pos // Position of last non-comment, non-terminator token
		afterLast token.Pos // Position directly following the token at lastTok
	)

	for {
		pos, tok, lit := sc.Scan()

		if tok == token.EOF {
			return expected
		}

		if tok == token.COMMENT {
			match := errRx.FindStringSubmatch(lit)
			if len(match) != 3 {
				continue // Not an ERROR comment.
			}

			target := lastTok
			if match[1] == "HERE" {
				target = afterLast
			}
			expected[target] = match[2]
			continue
		}

		// Newline terminators do not count as tokens an error can attach to.
		if tok == token.TERMINATOR && lit == "\n" {
			continue
		}

		lastTok = pos

		var width int // Token length
		if isLiteral(tok) {
			width = len(lit)
		} else {
			width = len(tok.String())
		}
		afterLast = lastTok.Add(width)
	}
}

// isLiteral reports whether t is an identifier, number, float, or string
// token.
func isLiteral(t token.Token) bool {
	return t == token.IDENT ||
		t == token.NUMBER ||
		t == token.FLOAT ||
		t == token.STRING
}

// compareErrors compares the map of expected error messages with the list of
// found errors and reports mismatches.
+func compareErrors(t *testing.T, file *token.File, expected map[token.Pos]string, found diag.Diagnostics) { + t.Helper() + + for _, checkError := range found { + pos := file.Pos(checkError.StartPos.Offset) + + if msg, found := expected[pos]; found { + // We expect a message at pos; check if it matches + rx, err := regexp.Compile(msg) + if !assert.NoError(t, err) { + continue + } + assert.True(t, + rx.MatchString(checkError.Message), + "%s: %q does not match %q", + checkError.StartPos, checkError.Message, msg, + ) + delete(expected, pos) // Eliminate consumed error + } else { + assert.Fail(t, + "Unexpected error", + "unexpected error: %s: %s", checkError.StartPos.String(), checkError.Message, + ) + } + } + + // There should be no expected errors left + if len(expected) > 0 { + t.Errorf("%d errors not reported:", len(expected)) + for pos, msg := range expected { + t.Errorf("%s: %s\n", file.PositionFor(pos), msg) + } + } +} + +func TestErrors(t *testing.T) { + list, err := os.ReadDir("testdata") + require.NoError(t, err) + + for _, d := range list { + name := d.Name() + if d.IsDir() || !strings.HasSuffix(name, ".river") { + continue + } + + t.Run(name, func(t *testing.T) { + checkErrors(t, filepath.Join("testdata", name)) + }) + } +} + +func checkErrors(t *testing.T, filename string) { + t.Helper() + + src, err := os.ReadFile(filename) + require.NoError(t, err) + + p := newParser(filename, src) + _ = p.ParseFile() + + expected := expectedErrors(p.file, src) + compareErrors(t, p.file, expected, p.diags) +} diff --git a/syntax/parser/internal.go b/syntax/parser/internal.go new file mode 100644 index 0000000000..1a8b7b7467 --- /dev/null +++ b/syntax/parser/internal.go @@ -0,0 +1,714 @@ +package parser + +import ( + "fmt" + "strings" + + "github.com/grafana/river/ast" + "github.com/grafana/river/diag" + "github.com/grafana/river/scanner" + "github.com/grafana/river/token" +) + +// parser implements the River parser. 
+// +// It is only safe for callers to use exported methods as entrypoints for +// parsing. +// +// Each Parse* and parse* method will describe the EBNF grammar being used for +// parsing that non-terminal. The EBNF grammar will be written as LL(1) and +// should directly represent the code. +// +// The parser will continue on encountering errors to allow a more complete +// list of errors to be returned to the user. The resulting AST should be +// discarded if errors were encountered during parsing. +type parser struct { + file *token.File + diags diag.Diagnostics + scanner *scanner.Scanner + comments []ast.CommentGroup + + pos token.Pos // Current token position + tok token.Token // Current token + lit string // Current token literal + + // Position of the last error written. Two parse errors on the same line are + // ignored. + lastError token.Position +} + +// newParser creates a new parser which will parse the provided src. +func newParser(filename string, src []byte) *parser { + file := token.NewFile(filename) + + p := &parser{ + file: file, + } + + p.scanner = scanner.New(file, src, func(pos token.Pos, msg string) { + p.diags.Add(diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: file.PositionFor(pos), + Message: msg, + }) + }, scanner.IncludeComments) + + p.next() + return p +} + +// next advances the parser to the next non-comment token. +func (p *parser) next() { + p.next0() + + for p.tok == token.COMMENT { + p.consumeCommentGroup() + } +} + +// next0 advances the parser to the next token. next0 should not be used +// directly by parse methods; call next instead. +func (p *parser) next0() { p.pos, p.tok, p.lit = p.scanner.Scan() } + +// consumeCommentGroup consumes a group of adjacent comments, adding it to p's +// comment list. 
+func (p *parser) consumeCommentGroup() { + var list []*ast.Comment + + endline := p.pos.Position().Line + for p.tok == token.COMMENT && p.pos.Position().Line <= endline+1 { + var comment *ast.Comment + comment, endline = p.consumeComment() + list = append(list, comment) + } + + p.comments = append(p.comments, ast.CommentGroup(list)) +} + +// consumeComment consumes a comment and returns it with the line number it +// ends on. +func (p *parser) consumeComment() (comment *ast.Comment, endline int) { + endline = p.pos.Position().Line + + if p.lit[1] == '*' { + // Block comments may end on a different line than where they start. Scan + // the comment for newlines and adjust endline accordingly. + // + // NOTE: don't use range here, since range will unnecessarily decode + // Unicode code points and slow down the parser. + for i := 0; i < len(p.lit); i++ { + if p.lit[i] == '\n' { + endline++ + } + } + } + + comment = &ast.Comment{StartPos: p.pos, Text: p.lit} + p.next0() + return +} + +// advance consumes tokens up to (but not including) the specified token. +// advance will stop consuming tokens if EOF is reached before to. +func (p *parser) advance(to token.Token) { + for p.tok != token.EOF { + if p.tok == to { + return + } + p.next() + } +} + +// advanceAny consumes tokens up to (but not including) any of the tokens in +// the to set. +func (p *parser) advanceAny(to map[token.Token]struct{}) { + for p.tok != token.EOF { + if _, inSet := to[p.tok]; inSet { + return + } + p.next() + } +} + +// expect consumes the next token. It records an error if the consumed token +// was not t. +func (p *parser) expect(t token.Token) (pos token.Pos, tok token.Token, lit string) { + pos, tok, lit = p.pos, p.tok, p.lit + if tok != t { + p.addErrorf("expected %s, got %s", t, p.tok) + } + p.next() + return +} + +func (p *parser) addErrorf(format string, args ...interface{}) { + pos := p.file.PositionFor(p.pos) + + // Ignore errors which occur on the same line. 
+ if p.lastError.Line == pos.Line { + return + } + p.lastError = pos + + p.diags.Add(diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: pos, + Message: fmt.Sprintf(format, args...), + }) +} + +// ParseFile parses an entire file. +// +// File = Body +func (p *parser) ParseFile() *ast.File { + body := p.parseBody(token.EOF) + + return &ast.File{ + Name: p.file.Name(), + Body: body, + Comments: p.comments, + } +} + +// parseBody parses a series of statements up to and including the "until" +// token, which terminates the body. +// +// Body = [ Statement { terminator Statement } ] +func (p *parser) parseBody(until token.Token) ast.Body { + var body ast.Body + + for p.tok != until && p.tok != token.EOF { + stmt := p.parseStatement() + if stmt != nil { + body = append(body, stmt) + } + + if p.tok == until { + break + } + + if p.tok != token.TERMINATOR { + p.addErrorf("expected %s, got %s", token.TERMINATOR, p.tok) + p.consumeStatement() + } + p.next() + } + + return body +} + +// consumeStatement consumes tokens for the remainder of a statement (i.e., up +// to but not including a terminator). consumeStatement will keep track of the +// number of {}, [], and () pairs, only returning after the count of pairs is +// <= 0. +func (p *parser) consumeStatement() { + var curlyPairs, brackPairs, parenPairs int + + for p.tok != token.EOF { + switch p.tok { + case token.LCURLY: + curlyPairs++ + case token.RCURLY: + curlyPairs-- + case token.LBRACK: + brackPairs++ + case token.RBRACK: + brackPairs-- + case token.LPAREN: + parenPairs++ + case token.RPAREN: + parenPairs-- + } + + if p.tok == token.TERMINATOR { + // Only return after we've consumed all pairs. It's possible for pairs to + // be less than zero if our statement started in a surrounding pair. + if curlyPairs <= 0 && brackPairs <= 0 && parenPairs <= 0 { + return + } + } + + p.next() + } +} + +// parseStatement parses an individual statement within a body. 
+// +// Statement = Attribute | Block +// Attribute = identifier "=" Expression +// Block = BlockName "{" Body "}" +func (p *parser) parseStatement() ast.Stmt { + blockName := p.parseBlockName() + if blockName == nil { + // parseBlockName failed; skip to the next identifier which would start a + // new Statement. + p.advance(token.IDENT) + return nil + } + + // p.tok is now the first token after the identifier in the attribute or + // block name. + switch p.tok { + case token.ASSIGN: // Attribute + p.next() // Consume "=" + + if len(blockName.Fragments) != 1 { + attrName := strings.Join(blockName.Fragments, ".") + p.diags.Add(diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: blockName.Start.Position(), + EndPos: blockName.Start.Add(len(attrName) - 1).Position(), + Message: `attribute names may only consist of a single identifier with no "."`, + }) + } else if blockName.LabelPos != token.NoPos { + p.diags.Add(diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: blockName.LabelPos.Position(), + // Add 1 to the end position to add in the end quote, which is stripped from the label value. + EndPos: blockName.LabelPos.Add(len(blockName.Label) + 1).Position(), + Message: `attribute names may not have labels`, + }) + } + + return &ast.AttributeStmt{ + Name: &ast.Ident{ + Name: blockName.Fragments[0], + NamePos: blockName.Start, + }, + Value: p.ParseExpression(), + } + + case token.LCURLY: // Block + block := &ast.BlockStmt{ + Name: blockName.Fragments, + NamePos: blockName.Start, + Label: blockName.Label, + LabelPos: blockName.LabelPos, + } + + block.LCurlyPos, _, _ = p.expect(token.LCURLY) + block.Body = p.parseBody(token.RCURLY) + block.RCurlyPos, _, _ = p.expect(token.RCURLY) + + return block + + default: + if blockName.ValidAttribute() { + // The blockname could be used for an attribute or a block (no label, + // only one name fragment), so inform the user of both cases. 
+ p.addErrorf("expected attribute assignment or block body, got %s", p.tok) + } else { + p.addErrorf("expected block body, got %s", p.tok) + } + + // Give up on this statement and skip to the next identifier. + p.advance(token.IDENT) + return nil + } +} + +// parseBlockName parses the name used for a block. +// +// BlockName = identifier { "." identifier } [ string ] +func (p *parser) parseBlockName() *blockName { + if p.tok != token.IDENT { + p.addErrorf("expected identifier, got %s", p.tok) + return nil + } + + var bn blockName + + bn.Fragments = append(bn.Fragments, p.lit) // Append first identifier + bn.Start = p.pos + p.next() + + // { "." identifier } + for p.tok == token.DOT { + p.next() // consume "." + + if p.tok != token.IDENT { + p.addErrorf("expected identifier, got %s", p.tok) + + // Continue here to parse as much as possible, even though the block name + // will be malformed. + } + + bn.Fragments = append(bn.Fragments, p.lit) + p.next() + } + + // [ string ] + if p.tok != token.ASSIGN && p.tok != token.LCURLY { + if p.tok == token.STRING { + // Only allow double-quoted strings for block labels. + if p.lit[0] != '"' { + p.addErrorf("expected block label to be a double quoted string, but got %q", p.lit) + } + + // Strip the quotes if it's non-empty. We then require any non-empty + // label to be a valid identifier. + if len(p.lit) > 2 { + bn.Label = p.lit[1 : len(p.lit)-1] + if !scanner.IsValidIdentifier(bn.Label) { + p.addErrorf("expected block label to be a valid identifier, but got %q", bn.Label) + } + } + bn.LabelPos = p.pos + } else { + p.addErrorf("expected block label, got %s", p.tok) + } + p.next() + } + + return &bn +} + +type blockName struct { + Fragments []string // Name fragments (i.e., `a.b.c`) + Label string // Optional user label + + Start token.Pos + LabelPos token.Pos +} + +// ValidAttribute returns true if the blockName can be used as an attribute +// name. 
+func (n blockName) ValidAttribute() bool { + return len(n.Fragments) == 1 && n.Label == "" +} + +// ParseExpression parses a single expression. +// +// Expression = BinOpExpr +func (p *parser) ParseExpression() ast.Expr { + return p.parseBinOp(1) +} + +// parseBinOp is the entrypoint for binary expressions. If there is no binary +// expressions in the current state, a single operand will be returned instead. +// +// BinOpExpr = OrExpr +// OrExpr = AndExpr { "||" AndExpr } +// AndExpr = CmpExpr { "&&" CmpExpr } +// CmpExpr = AddExpr { cmp_op AddExpr } +// AddExpr = MulExpr { add_op MulExpr } +// MulExpr = PowExpr { mul_op PowExpr } +// +// parseBinOp avoids the need for multiple non-terminal functions by providing +// context for operator precedence in recursive calls. inPrec specifies the +// incoming operator precedence. On the first call to parseBinOp, inPrec should +// be 1. +// +// parseBinOp can only handle left-associative operators, so PowExpr is handled +// by parsePowExpr. +func (p *parser) parseBinOp(inPrec int) ast.Expr { + // The EBNF documented by the function can be generalized into: + // + // CurPrecExpr = NextPrecExpr { cur_prec_ops NextPrecExpr } + // + // The code below implements this specific grammar, continually collecting + // everything at the same precedence level into the LHS of the expression + // while recursively calling parseBinOp for higher-precedence operations. + + lhs := p.parsePowExpr() + + for { + tok, pos, prec := p.tok, p.pos, p.tok.BinaryPrecedence() + if prec < inPrec { + // The next operator is lower precedence; drop up a level in our call + // stack. + return lhs + } + p.next() // Consume the operator + + // Recurse with a higher precedence level, which ensures that operators at + // the same precedence level don't get handled in the recursive call. 
+ rhs := p.parseBinOp(prec + 1) + + lhs = &ast.BinaryExpr{ + Left: lhs, + Kind: tok, + KindPos: pos, + Right: rhs, + } + } +} + +// parsePowExpr is like parseBinOp but handles the right-associative pow +// operator. +// +// PowExpr = UnaryExpr [ "^" PowExpr ] +func (p *parser) parsePowExpr() ast.Expr { + lhs := p.parseUnaryExpr() + + if p.tok == token.POW { + pos := p.pos + p.next() // Consume ^ + + return &ast.BinaryExpr{ + Left: lhs, + Kind: token.POW, + KindPos: pos, + Right: p.parsePowExpr(), + } + } + + return lhs +} + +// parseUnaryExpr parses a unary expression. +// +// UnaryExpr = OperExpr | unary_op UnaryExpr +// +// OperExpr = PrimaryExpr { AccessExpr | IndexExpr | CallExpr } +// AccessExpr = "." identifier +// IndexExpr = "[" Expression "]" +// CallExpr = "(" [ ExpressionList ] ")" +func (p *parser) parseUnaryExpr() ast.Expr { + if isUnaryOp(p.tok) { + op, pos := p.tok, p.pos + p.next() // Consume op + + return &ast.UnaryExpr{ + Kind: op, + KindPos: pos, + Value: p.parseUnaryExpr(), + } + } + + primary := p.parsePrimaryExpr() + +NextOper: + for { + switch p.tok { + case token.DOT: // AccessExpr + p.next() + namePos, _, name := p.expect(token.IDENT) + + primary = &ast.AccessExpr{ + Value: primary, + Name: &ast.Ident{ + Name: name, + NamePos: namePos, + }, + } + + case token.LBRACK: // IndexExpr + lBrack, _, _ := p.expect(token.LBRACK) + index := p.ParseExpression() + rBrack, _, _ := p.expect(token.RBRACK) + + primary = &ast.IndexExpr{ + Value: primary, + LBrackPos: lBrack, + Index: index, + RBrackPos: rBrack, + } + + case token.LPAREN: // CallExpr + var args []ast.Expr + + lParen, _, _ := p.expect(token.LPAREN) + if p.tok != token.RPAREN { + args = p.parseExpressionList(token.RPAREN) + } + rParen, _, _ := p.expect(token.RPAREN) + + primary = &ast.CallExpr{ + Value: primary, + LParenPos: lParen, + Args: args, + RParenPos: rParen, + } + + case token.STRING, token.LCURLY: + // A user might be trying to assign a block to an attribute. 
let's + // attempt to parse the remainder as a block to tell them something is + // wrong. + // + // If we can't parse the remainder of the expression as a block, we give + // up and parse the remainder of the entire statement. + if p.tok == token.STRING { + p.next() + } + if _, tok, _ := p.expect(token.LCURLY); tok != token.LCURLY { + p.consumeStatement() + return primary + } + p.parseBody(token.RCURLY) + + end, tok, _ := p.expect(token.RCURLY) + if tok != token.RCURLY { + p.consumeStatement() + return primary + } + + p.diags.Add(diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(primary).Position(), + EndPos: end.Position(), + Message: "cannot use a block as an expression", + }) + + default: + break NextOper + } + } + + return primary +} + +func isUnaryOp(tok token.Token) bool { + switch tok { + case token.NOT, token.SUB: + return true + default: + return false + } +} + +// parsePrimaryExpr parses a primary expression. +// +// PrimaryExpr = LiteralValue | ArrayExpr | ObjectExpr +// +// LiteralValue = identifier | string | number | float | bool | null | +// "(" Expression ")" +// +// ArrayExpr = "[" [ ExpressionList ] "]" +// ObjectExpr = "{" [ FieldList ] "}" +func (p *parser) parsePrimaryExpr() ast.Expr { + switch p.tok { + case token.IDENT: + res := &ast.IdentifierExpr{ + Ident: &ast.Ident{ + Name: p.lit, + NamePos: p.pos, + }, + } + p.next() + return res + + case token.STRING, token.NUMBER, token.FLOAT, token.BOOL, token.NULL: + res := &ast.LiteralExpr{ + Kind: p.tok, + Value: p.lit, + ValuePos: p.pos, + } + p.next() + return res + + case token.LPAREN: + lParen, _, _ := p.expect(token.LPAREN) + expr := p.ParseExpression() + rParen, _, _ := p.expect(token.RPAREN) + + return &ast.ParenExpr{ + LParenPos: lParen, + Inner: expr, + RParenPos: rParen, + } + + case token.LBRACK: + var res ast.ArrayExpr + + res.LBrackPos, _, _ = p.expect(token.LBRACK) + if p.tok != token.RBRACK { + res.Elements = p.parseExpressionList(token.RBRACK) + } + 
res.RBrackPos, _, _ = p.expect(token.RBRACK) + return &res + + case token.LCURLY: + var res ast.ObjectExpr + + res.LCurlyPos, _, _ = p.expect(token.LCURLY) + if p.tok != token.RBRACK { + res.Fields = p.parseFieldList(token.RCURLY) + } + res.RCurlyPos, _, _ = p.expect(token.RCURLY) + return &res + } + + p.addErrorf("expected expression, got %s", p.tok) + res := &ast.LiteralExpr{Kind: token.NULL, Value: "null", ValuePos: p.pos} + p.advanceAny(statementEnd) // Eat up the rest of the line + return res +} + +var statementEnd = map[token.Token]struct{}{ + token.TERMINATOR: {}, + token.RPAREN: {}, + token.RCURLY: {}, + token.RBRACK: {}, + token.COMMA: {}, +} + +// parseExpressionList parses a list of expressions. +// +// ExpressionList = Expression { "," Expression } [ "," ] +func (p *parser) parseExpressionList(until token.Token) []ast.Expr { + var exprs []ast.Expr + + for p.tok != until && p.tok != token.EOF { + exprs = append(exprs, p.ParseExpression()) + + if p.tok == until { + break + } + if p.tok != token.COMMA { + p.addErrorf("missing ',' in expression list") + } + p.next() + } + + return exprs +} + +// parseFieldList parses a list of fields in an object. +// +// FieldList = Field { "," Field } [ "," ] +func (p *parser) parseFieldList(until token.Token) []*ast.ObjectField { + var fields []*ast.ObjectField + + for p.tok != until && p.tok != token.EOF { + fields = append(fields, p.parseField()) + + if p.tok == until { + break + } + if p.tok != token.COMMA { + p.addErrorf("missing ',' in field list") + } + p.next() + } + + return fields +} + +// parseField parses a field in an object. +// +// Field = ( string | identifier ) "=" Expression +func (p *parser) parseField() *ast.ObjectField { + var field ast.ObjectField + + if p.tok == token.STRING || p.tok == token.IDENT { + field.Name = &ast.Ident{ + Name: p.lit, + NamePos: p.pos, + } + if p.tok == token.STRING && len(p.lit) > 2 { + // The field name is a string literal; unwrap the quotes. 
+ field.Name.Name = p.lit[1 : len(p.lit)-1] + field.Quoted = true + } + p.next() // Consume field name + } else { + p.addErrorf("expected field name (string or identifier), got %s", p.tok) + p.advance(token.ASSIGN) + } + + p.expect(token.ASSIGN) + + field.Value = p.ParseExpression() + return &field +} diff --git a/syntax/parser/internal_test.go b/syntax/parser/internal_test.go new file mode 100644 index 0000000000..c3be1e7581 --- /dev/null +++ b/syntax/parser/internal_test.go @@ -0,0 +1,22 @@ +package parser + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestObjectFieldName(t *testing.T) { + tt := []string{ + `field_a = 5`, + `"field_a" = 5`, // Quotes should be removed from the field name + } + + for _, tc := range tt { + p := newParser(t.Name(), []byte(tc)) + + res := p.parseField() + + assert.Equal(t, "field_a", res.Name.Name) + } +} diff --git a/syntax/parser/parser.go b/syntax/parser/parser.go new file mode 100644 index 0000000000..66d2199d4b --- /dev/null +++ b/syntax/parser/parser.go @@ -0,0 +1,43 @@ +// Package parser implements utilities for parsing River configuration files. +package parser + +import ( + "github.com/grafana/river/ast" + "github.com/grafana/river/token" +) + +// ParseFile parses an entire River configuration file. The data parameter +// should hold the file contents to parse, while the filename parameter is used +// for reporting errors. +// +// If an error was encountered during parsing, the returned AST will be nil and +// err will be an diag.Diagnostics all the errors encountered during parsing. +func ParseFile(filename string, data []byte) (*ast.File, error) { + p := newParser(filename, data) + + f := p.ParseFile() + if len(p.diags) > 0 { + return nil, p.diags + } + return f, nil +} + +// ParseExpression parses a single River expression from expr. 
+// +// If an error was encountered during parsing, the returned expression will be +// nil and err will be an ErrorList with all the errors encountered during +// parsing. +func ParseExpression(expr string) (ast.Expr, error) { + p := newParser("", []byte(expr)) + + e := p.ParseExpression() + + // If the current token is not a TERMINATOR then the parsing did not complete + // in full and there are still parts of the string left unparsed. + p.expect(token.TERMINATOR) + + if len(p.diags) > 0 { + return nil, p.diags + } + return e, nil +} diff --git a/syntax/parser/parser_test.go b/syntax/parser/parser_test.go new file mode 100644 index 0000000000..f567c4650a --- /dev/null +++ b/syntax/parser/parser_test.go @@ -0,0 +1,123 @@ +package parser + +import ( + "io/fs" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func FuzzParser(f *testing.F) { + filepath.WalkDir("./testdata/valid", func(path string, d fs.DirEntry, _ error) error { + if d.IsDir() { + return nil + } + + bb, err := os.ReadFile(path) + require.NoError(f, err) + f.Add(bb) + return nil + }) + + f.Fuzz(func(t *testing.T, input []byte) { + p := newParser(t.Name(), input) + + _ = p.ParseFile() + if len(p.diags) > 0 { + t.SkipNow() + } + }) +} + +// TestValid parses every *.river file in testdata, which is expected to be +// valid. 
+func TestValid(t *testing.T) { + filepath.WalkDir("./testdata/valid", func(path string, d fs.DirEntry, _ error) error { + if d.IsDir() { + return nil + } + + t.Run(filepath.Base(path), func(t *testing.T) { + bb, err := os.ReadFile(path) + require.NoError(t, err) + + p := newParser(path, bb) + + res := p.ParseFile() + require.NotNil(t, res) + require.Len(t, p.diags, 0) + }) + + return nil + }) +} + +func TestParseExpressions(t *testing.T) { + tt := map[string]string{ + "literal number": `10`, + "literal float": `15.0`, + "literal string": `"Hello, world!"`, + "literal ident": `some_ident`, + "literal null": `null`, + "literal true": `true`, + "literal false": `false`, + + "empty array": `[]`, + "array one element": `[1]`, + "array many elements": `[0, 1, 2, 3]`, + "array trailing comma": `[0, 1, 2, 3,]`, + "nested array": `[[0, 1, 2], [3, 4, 5]]`, + "array multiline": `[ + 0, + 1, + 2, + ]`, + + "empty object": `{}`, + "object one field": `{ field_a = 5 }`, + "object multiple fields": `{ field_a = 5, field_b = 10 }`, + "object trailing comma": `{ field_a = 5, field_b = 10, }`, + "nested objects": `{ field_a = { nested_field = 100 } }`, + "object multiline": `{ + field_a = 5, + field_b = 10, + }`, + + "unary not": `!true`, + "unary neg": `-5`, + + "math": `1 + 2 - 3 * 4 / 5 % 6`, + "compare ops": `1 == 2 != 3 < 4 > 5 <= 6 >= 7`, + "logical ops": `true || false && true`, + "pow operator": "1 ^ 2 ^ 3", + + "field access": `a.b.c.d`, + "element access": `a[0][1][2]`, + + "call no args": `a()`, + "call one arg": `a(1)`, + "call multiple args": `a(1,2,3)`, + "call with trailing comma": `a(1,2,3,)`, + "call multiline": `a( + 1, + 2, + 3, + )`, + + "parens": `(1 + 5) * 100`, + + "mixed expression": `(a.b.c)(1, 3 * some_list[magic_index * 2]).resulting_field`, + } + + for name, input := range tt { + t.Run(name, func(t *testing.T) { + p := newParser(name, []byte(input)) + + res := p.ParseExpression() + require.NotNil(t, res) + require.Len(t, p.diags, 0) + }) + } +} diff 
--git a/syntax/parser/testdata/assign_block_to_attr.river b/syntax/parser/testdata/assign_block_to_attr.river new file mode 100644 index 0000000000..e291308599 --- /dev/null +++ b/syntax/parser/testdata/assign_block_to_attr.river @@ -0,0 +1,32 @@ +rw = prometheus/* ERROR "cannot use a block as an expression" */.remote_write "default" { + endpoint { + url = "some_url" + basic_auth { + username = "username" + password = "password" + } + } +} + +attr_1 = 15 +attr_2 = 51 + +block { + rw_2 = prometheus/* ERROR "cannot use a block as an expression" */.remote_write "other" { + endpoint { + url = "other_url" + basic_auth { + username = "username_2" + password = "password_2" + } + } + } +} + +other_block { + // This is an expression which looks like it might be a block at first, but + // then isn't. + rw_3 = prometheus.remote_write "other" "other" /* ERROR "expected {, got STRING" */ 12345 +} + +attr_3 = 15 diff --git a/syntax/parser/testdata/attribute_names.river b/syntax/parser/testdata/attribute_names.river new file mode 100644 index 0000000000..1e18c60850 --- /dev/null +++ b/syntax/parser/testdata/attribute_names.river @@ -0,0 +1,7 @@ +valid_attr = 15 + +// The parser parses block names for both blocks and attributes, and later +// validates that the attribute name is just a single identifier with no label. 
+ +invalid/* ERROR "attribute names may only consist of a single identifier" */.attr = 20 +invalid "label" /* ERROR "attribute names may not have labels" */ = 20 diff --git a/syntax/parser/testdata/block_names.river b/syntax/parser/testdata/block_names.river new file mode 100644 index 0000000000..cc0f2040f3 --- /dev/null +++ b/syntax/parser/testdata/block_names.river @@ -0,0 +1,25 @@ +valid_block { + +} + +valid_block "labeled" { + +} + +invalid_block bad_label_name /* ERROR "expected block label, got IDENT" */ { + +} + +other_valid_block { + nested_block { + + } + + nested_block "labeled" { + + } +} + +invalid_block "with space" /* ERROR "expected block label to be a valid identifier" */ { + +} diff --git a/syntax/parser/testdata/commas.river b/syntax/parser/testdata/commas.river new file mode 100644 index 0000000000..a43bb5c873 --- /dev/null +++ b/syntax/parser/testdata/commas.river @@ -0,0 +1,13 @@ +// Test that missing trailing commas for multiline expressions get reported. + +field = [ + 0, + 1, + 2/* ERROR HERE "missing ',' in expression list" */ +] + +obj = { + field_a = 0, + field_b = 1, + field_c = 2/* ERROR HERE "missing ',' in field list" */ +} diff --git a/syntax/parser/testdata/fuzz/FuzzParser/1a39f4e358facc21678b16fad53537b46efdaa76e024a5ef4955d01a68bdac37 b/syntax/parser/testdata/fuzz/FuzzParser/1a39f4e358facc21678b16fad53537b46efdaa76e024a5ef4955d01a68bdac37 new file mode 100644 index 0000000000..6151b52921 --- /dev/null +++ b/syntax/parser/testdata/fuzz/FuzzParser/1a39f4e358facc21678b16fad53537b46efdaa76e024a5ef4955d01a68bdac37 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("A0000000000000000") diff --git a/syntax/parser/testdata/fuzz/FuzzParser/248cf4391f6c48550b7d2cf4c6c80f4ba9099c21ffa2b6869e75e99565dce037 b/syntax/parser/testdata/fuzz/FuzzParser/248cf4391f6c48550b7d2cf4c6c80f4ba9099c21ffa2b6869e75e99565dce037 new file mode 100644 index 0000000000..cc252cd81b --- /dev/null +++ 
b/syntax/parser/testdata/fuzz/FuzzParser/248cf4391f6c48550b7d2cf4c6c80f4ba9099c21ffa2b6869e75e99565dce037 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("A={A!0A\"") diff --git a/syntax/parser/testdata/fuzz/FuzzParser/b919fa00ebca318001778477c839a06204b55f2636597901d8d7878150d8580a b/syntax/parser/testdata/fuzz/FuzzParser/b919fa00ebca318001778477c839a06204b55f2636597901d8d7878150d8580a new file mode 100644 index 0000000000..ff4ab488f8 --- /dev/null +++ b/syntax/parser/testdata/fuzz/FuzzParser/b919fa00ebca318001778477c839a06204b55f2636597901d8d7878150d8580a @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("A\"") diff --git a/syntax/parser/testdata/invalid_exprs.river b/syntax/parser/testdata/invalid_exprs.river new file mode 100644 index 0000000000..c7ea3e4385 --- /dev/null +++ b/syntax/parser/testdata/invalid_exprs.river @@ -0,0 +1,4 @@ +attr = 1 + + /* ERROR "expected expression, got +" */ 2 + +invalid_func_call = a(() /* ERROR "expected expression, got \)" */) +invalid_access = a.true /* ERROR "expected IDENT, got BOOL" */ diff --git a/syntax/parser/testdata/invalid_object_key.river b/syntax/parser/testdata/invalid_object_key.river new file mode 100644 index 0000000000..1dd5a4ba8c --- /dev/null +++ b/syntax/parser/testdata/invalid_object_key.river @@ -0,0 +1,9 @@ +obj { + map = { + "string_field" = "foo", + identifier_string = "bar", + 1337 /* ERROR "expected field name \(string or identifier\), got NUMBER" */ = "baz", + "another_field" = "qux", + } +} + diff --git a/syntax/parser/testdata/valid/attribute.river b/syntax/parser/testdata/valid/attribute.river new file mode 100644 index 0000000000..9cc731b27f --- /dev/null +++ b/syntax/parser/testdata/valid/attribute.river @@ -0,0 +1 @@ +number_field = 1 diff --git a/syntax/parser/testdata/valid/blocks.river b/syntax/parser/testdata/valid/blocks.river new file mode 100644 index 0000000000..74d79bfe67 --- /dev/null +++ b/syntax/parser/testdata/valid/blocks.river @@ -0,0 +1,36 @@ +one_ident { + number_field = 1 +} + +one_ident 
"labeled" { + number_field = 1 +} + +multiple.idents { + number_field = 1 +} + +multiple.idents "labeled" { + number_field = 1 +} + +chain.of.idents { + number_field = 1 +} + +chain.of.idents "labeled" { + number_field = 1 +} + +one_ident_inline { number_field = 1 } +one_ident_inline "labeled" { number_field = 1 } +multiple.idents_inline { number_field = 1 } +multiple.idents_inline "labeled" { number_field = 1 } +chain.of.idents { number_field = 1 } +chain.of.idents "labeled" { number_field = 1 } + +nested_block { + inner_block { + some_field = true + } +} diff --git a/syntax/parser/testdata/valid/comments.river b/syntax/parser/testdata/valid/comments.river new file mode 100644 index 0000000000..9dafbd7346 --- /dev/null +++ b/syntax/parser/testdata/valid/comments.river @@ -0,0 +1 @@ +// Hello, world! diff --git a/syntax/parser/testdata/valid/empty.river b/syntax/parser/testdata/valid/empty.river new file mode 100644 index 0000000000..e69de29bb2 diff --git a/syntax/parser/testdata/valid/expressions.river b/syntax/parser/testdata/valid/expressions.river new file mode 100644 index 0000000000..78af48cc22 --- /dev/null +++ b/syntax/parser/testdata/valid/expressions.river @@ -0,0 +1,81 @@ +// Literals +lit_number = 10 +lit_float = 15.0 +lit_string = "Hello, world!" 
+lit_ident = other_ident +lit_null = null +lit_true = true +lit_false = false + +// Arrays +array_expr_empty = [] +array_expr_one_element = [0] +array_expr = [0, 1, 2, 3] +array_expr_trailing = [0, 1, 2, 3,] +array_expr_multiline = [ + 0, + 1, + 2, + 3, +] +array_expr_nested = [[1]] + +// Objects +object_expr_empty = {} +object_expr_one_field = { field_a = 1 } +object_expr = { field_a = 1, field_b = 2 } +object_expr_trailing = { field_a = 1, field_b = 2, } +object_expr_multiline = { + field_a = 1, + field_b = 2, +} +object_expr_nested = { field_a = { nested_field_a = 1 } } + +// Unary ops +not_something = !true +neg_number = -5 + +// Math binops +binop_sum = 1 + 2 +binop_sub = 1 - 2 +binop_mul = 1 * 2 +binop_div = 1 / 2 +binop_mod = 1 % 2 +binop_pow = 1 ^ 2 ^ 3 + +// Compare binops +binop_eq = 1 == 2 +binop_neq = 1 != 2 +binop_lt = 1 < 2 +binop_lte = 1 <= 2 +binop_gt = 1 > 2 +binop_gte = 1 >= 2 + +// Logical binops +binop_or = true || false +binop_and = true && false + + +// Mixed math operations +math = 1 + 2 - 3 * 4 / 5 % 6 +compare_ops = 1 == 2 != 3 < 4 > 5 <= 6 >= 7 +logical_ops = true || false && true +mixed_assoc = 1 * 3 + 5 ^ 3 - 2 % 1 // Test with both left- and right- associative operators +expr_parens = (5 * 2) + 5 + +// Accessors +field_access = a.b.c.d +element_access = a[0][1][2] + +// Function calls +call_no_args = a() +call_one_arg = a(1) +call_multiple_args = a(1,2,3) +call_trailing_comma = a(1,2,3,) +call_multiline = a( + 1, + 2, + 3, +) + +mixed_expr = (a.b.c)(1, 3 * some_list[magic_index * 2]).resulting_field diff --git a/syntax/printer/printer.go b/syntax/printer/printer.go new file mode 100644 index 0000000000..8faeb22006 --- /dev/null +++ b/syntax/printer/printer.go @@ -0,0 +1,556 @@ +// Package printer contains utilities for pretty-printing River ASTs. +package printer + +import ( + "fmt" + "io" + "math" + "text/tabwriter" + + "github.com/grafana/river/ast" + "github.com/grafana/river/token" +) + +// Config configures behavior of the printer. 
+type Config struct { + Indent int // Indentation to apply to all emitted code. Default 0. +} + +// Fprint pretty-prints the specified node to w. The Node type must be an +// *ast.File, ast.Body, or a type that implements ast.Stmt or ast.Expr. +func (c *Config) Fprint(w io.Writer, node ast.Node) (err error) { + var p printer + p.Init(c) + + // Pass all of our text through a trimmer to ignore trailing whitespace. + w = &trimmer{next: w} + + if err = (&walker{p: &p}).Walk(node); err != nil { + return + } + + // Call flush one more time to write trailing comments. + p.flush(token.Position{ + Offset: math.MaxInt, + Line: math.MaxInt, + Column: math.MaxInt, + }, token.EOF) + + w = tabwriter.NewWriter(w, 0, 8, 1, ' ', tabwriter.DiscardEmptyColumns|tabwriter.TabIndent) + + if _, err = w.Write(p.output); err != nil { + return + } + if tw, _ := w.(*tabwriter.Writer); tw != nil { + // Flush tabwriter if defined + err = tw.Flush() + } + + return +} + +// Fprint pretty-prints the specified node to w. The Node type must be an +// *ast.File, ast.Body, or a type that implements ast.Stmt or ast.Expr. +func Fprint(w io.Writer, node ast.Node) error { + c := &Config{} + return c.Fprint(w, node) +} + +// The printer writes lexical tokens and whitespace to an internal buffer. +// Comments are written by the printer itself, while all other tokens and +// formatting characters are sent through calls to Write. +// +// Internally, printer depends on a tabwriter for formatting text and aligning +// runs of characters. Horizontal '\t' and vertical '\v' tab characters are +// used to introduce new columns in the row. Runs of characters are stopped +// be either introducing a linefeed '\f' or by having a line with a different +// number of columns from the previous line. See the text/tabwriter package for +// more information on the elastic tabstop algorithm it uses for formatting +// text. 
+type printer struct { + cfg Config + + // State variables + + output []byte + indent int // Current indentation level + lastTok token.Token // Last token printed (token.LITERAL if it's whitespace) + + // Whitespace holds a buffer of whitespace characters to print prior to the + // next non-whitespace token. Whitespace is held in a buffer to avoid + // printing unnecessary whitespace at the end of a file. + whitespace []whitespace + + // comments stores comments to be processed as elements are printed. + comments commentInfo + + // pos is an approximation of the current position in AST space, and is used + // to determine space between AST elements (e.g., if a comment should come + // before a token). pos automatically as elements are written and can be manually + // set to guarantee an accurate position by passing a token.Pos to Write. + pos token.Position + last token.Position // Last pos written to output (through writeString) + + // out is an accurate representation of the current position in output space, + // used to inject extra formatting like indentation based on the output + // position. + // + // out may differ from pos in terms of whitespace. + out token.Position +} + +type commentInfo struct { + list []ast.CommentGroup + idx int + cur ast.CommentGroup + pos token.Pos +} + +func (ci *commentInfo) commentBefore(next token.Position) bool { + return ci.pos != token.NoPos && ci.pos.Offset() <= next.Offset +} + +// nextComment preloads the next comment. +func (ci *commentInfo) nextComment() { + for ci.idx < len(ci.list) { + c := ci.list[ci.idx] + ci.idx++ + if len(c) > 0 { + ci.cur = c + ci.pos = ast.StartPos(c[0]) + return + } + } + ci.pos = token.NoPos +} + +// Init initializes the printer for printing. Init is intended to be called +// once per printer and doesn't fully reset its state. 
func (p *printer) Init(cfg *Config) {
	// cfg is copied, so later changes to the caller's Config don't affect p.
	p.cfg = *cfg
	p.pos = token.Position{Line: 1, Column: 1}
	p.out = token.Position{Line: 1, Column: 1}
	// Capacity is set low since most whitespace sequences are short.
	p.whitespace = make([]whitespace, 0, 16)
}

// SetComments sets the comment groups to interleave with printed elements and
// preloads the first non-empty group.
func (p *printer) SetComments(comments []ast.CommentGroup) {
	p.comments = commentInfo{
		list: comments,
		idx:  0,
		pos:  token.NoPos,
	}
	p.comments.nextComment()
}

// Write writes a list of writable arguments to the printer.
//
// Arguments can be one of the types described below:
//
// If arg is a whitespace value, it is accumulated into a buffer and flushed
// only after a non-whitespace value is processed. The whitespace buffer will
// be forcibly flushed if the buffer becomes full without writing a
// non-whitespace token. wsIgnore values are skipped entirely.
//
// If arg is an *ast.Ident, *ast.LiteralExpr, or a token.Token, the
// human-readable representation of that value will be written.
//
// When writing text, comments which need to appear before that text in
// AST-space are written first, followed by leftover whitespace and then the
// text to write. The written text will update the AST-space position.
//
// If arg is a token.Pos, the AST-space position of the printer is updated to
// the provided Pos (invalid positions are ignored). Writing token.Pos values
// can help make sure the printer's AST-space position is accurate, as
// AST-space position is otherwise an estimation based on written data.
//
// Write panics if arg is of any other type.
func (p *printer) Write(args ...interface{}) {
	for _, arg := range args {
		var (
			data string
			// isLit is never set to true in this function, so text written
			// through Write is never wrapped in tabwriter.Escape.
			isLit bool
		)

		switch arg := arg.(type) {
		case whitespace:
			// Whitespace token; add it to our token buffer. Note that a whitespace
			// token is different than the actual whitespace which will get written
			// (e.g., wsIndent increases indentation level by one instead of setting
			// it to one.)
			if arg == wsIgnore {
				continue
			}
			i := len(p.whitespace)
			if i == cap(p.whitespace) {
				// We built up too much whitespace; this can happen if too many calls
				// to Write happen without appending a non-whitespace token. We will
				// force-flush the existing whitespace to avoid a panic.
				//
				// Ideally this line is never hit based on how we walk the AST, but
				// it's kept for safety.
				p.writeWritespace(i)
				i = 0
			}
			// Grow the buffer by one entry within its existing capacity.
			p.whitespace = p.whitespace[0 : i+1]
			p.whitespace[i] = arg
			p.lastTok = token.LITERAL
			continue

		case *ast.Ident:
			data = arg.Name
			p.lastTok = token.IDENT

		case *ast.LiteralExpr:
			data = arg.Value
			p.lastTok = arg.Kind

		case token.Pos:
			if arg.Valid() {
				p.pos = arg.Position()
			}
			// Don't write anything; token.Pos is an instruction and doesn't include
			// any text to write.
			continue

		case token.Token:
			s := arg.String()
			data = s

			// We will need to inject whitespace if the previous token and the
			// current token would combine into a single token when re-scanned. This
			// ensures that the sequence of tokens emitted by the output of the
			// printer match the sequence of tokens from the input.
			//
			// NOTE(review): s[0] assumes arg.String() is never empty — confirm
			// against the token package.
			if mayCombine(p.lastTok, s[0]) {
				if len(p.whitespace) != 0 {
					// It shouldn't be possible for the whitespace buffer to be not empty
					// here; p.lastTok would've had to been a non-whitespace token and so
					// whitespace would've been flushed when it was written to the output
					// buffer.
					panic("whitespace buffer not empty")
				}
				p.whitespace = p.whitespace[0:1]
				p.whitespace[0] = ' '
			}
			p.lastTok = arg

		default:
			panic(fmt.Sprintf("printer: unsupported argument %v (%T)\n", arg, arg))
		}

		// Only text-producing arguments reach this point; whitespace and
		// token.Pos arguments continue the loop above.
		next := p.pos

		p.flush(next, p.lastTok)
		p.writeString(next, data, isLit)
	}
}

// mayCombine reports whether a token starting with the byte next must be
// separated from the previous token prev, because writing them next to each
// other would cause a different token sequence to be scanned.
func mayCombine(prev token.Token, next byte) (b bool) {
	// The named result b is never assigned; every path returns explicitly.
	switch prev {
	case token.NUMBER:
		return next == '.' // 1.
	case token.DIV:
		return next == '*' // /*
	default:
		return false
	}
}

// flush prints any pending comments and whitespace occurring textually before
// the position of the next token tok.
//
// If a pending comment starts at or before next, the comments are injected
// (which also deals with the buffered whitespace). Otherwise the entire
// whitespace buffer is written, unless tok is token.EOF, in which case the
// buffered whitespace is left unwritten.
func (p *printer) flush(next token.Position, tok token.Token) {
	if p.comments.commentBefore(next) {
		p.injectComments(next, tok)
	} else if tok != token.EOF {
		// Write all remaining whitespace.
		p.writeWritespace(len(p.whitespace))
	}
}

// injectComments writes every pending comment group which starts at or before
// next, and then writes the whitespace separating the final comment from tok.
func (p *printer) injectComments(next token.Position, tok token.Token) {
	var lastComment *ast.Comment

	for p.comments.commentBefore(next) {
		for _, c := range p.comments.cur {
			p.writeCommentPrefix(next, c)
			p.writeComment(next, c)
			lastComment = c
		}
		p.comments.nextComment()
	}

	p.writeCommentSuffix(next, tok, lastComment)
}

// writeCommentPrefix writes whitespace that should appear before c.
func (p *printer) writeCommentPrefix(next token.Position, c *ast.Comment) {
	if len(p.output) == 0 {
		// The comment is the first thing written to the output. Don't write any
		// whitespace before it.
		return
	}

	cPos := c.StartPos.Position()

	if cPos.Line == p.last.Line {
		// Our comment is on the same line as the last token. Write a separator
		// between the last token and the comment.
		separator := byte('\t')
		if cPos.Line == next.Line {
			// The comment is on the same line as the next token, which means it has
			// to be a block comment (since line comments run to the end of the
			// line.) Use a space as the separator instead since a tab in the middle
			// of a line between comments would look weird.
			separator = byte(' ')
		}
		p.writeByte(separator, 1)
	} else {
		// Our comment is on a different line from the last token. First write
		// pending whitespace from the last token up to the first newline.
		var wsCount int

		for i, ws := range p.whitespace {
			switch ws {
			case wsBlank, wsVTab:
				// Drop any whitespace before the comment.
				p.whitespace[i] = wsIgnore
			case wsIndent, wsUnindent:
				// Allow indentation to be applied.
				continue
			case wsNewline, wsFormfeed:
				// Drop the whitespace since we're about to write our own.
				p.whitespace[i] = wsIgnore
			}
			// The break below sits outside the switch, so it ends the for loop:
			// scanning stops at the first entry which isn't an indent or
			// unindent, and only the entries before it are written. If every
			// entry is an indent or unindent, wsCount stays 0.
			wsCount = i
			break
		}
		p.writeWritespace(wsCount)

		var newlines int
		if cPos.Valid() && p.last.Valid() {
			newlines = cPos.Line - p.last.Line
		}
		if newlines > 0 {
			p.writeByte('\f', newlineLimit(newlines))
		}
	}
}

// writeComment writes the text of c at the comment's own start position. The
// text is written as a literal (isLit is true) so that it is wrapped in
// tabwriter.Escape. The first argument is unused.
func (p *printer) writeComment(_ token.Position, c *ast.Comment) {
	p.writeString(c.StartPos.Position(), c.Text, true)
}

// writeCommentSuffix writes any whitespace necessary between the last comment
// and next. lastComment should be the final comment written and must not be
// nil. Nothing is written when tok is token.EOF.
func (p *printer) writeCommentSuffix(next token.Position, tok token.Token, lastComment *ast.Comment) {
	if tok == token.EOF {
		// We don't want to add any blank newlines before the end of the file;
		// return early.
		return
	}

	var droppedFF bool

	// If our final comment is a block comment and is on the same line as the
	// next token, add a space as a suffix to separate them.
	//
	// NOTE(review): Text[1] assumes the comment text is at least two bytes
	// long (i.e., starts with "//" or "/*") — confirm against the scanner.
	lastCommentPos := ast.EndPos(lastComment).Position()
	if lastComment.Text[1] == '*' && next.Line == lastCommentPos.Line {
		p.writeByte(' ', 1)
	}

	// Computed before the buffered whitespace is written below, since writing
	// doesn't touch p.last but the value is only used afterwards.
	newlines := next.Line - p.last.Line

	// Strip blanks and line breaks from the buffer, keeping only indentation
	// changes. Line breaks are re-added below based on the source positions.
	for i, ws := range p.whitespace {
		switch ws {
		case wsBlank, wsVTab:
			p.whitespace[i] = wsIgnore
		case wsIndent, wsUnindent:
			continue
		case wsNewline, wsFormfeed:
			if ws == wsFormfeed {
				droppedFF = true
			}
			p.whitespace[i] = wsIgnore
		}
	}

	p.writeWritespace(len(p.whitespace))

	// Write line breaks to separate the comment from the next token. The count
	// is capped by newlineLimit.
	if newlines > 0 {
		ch := byte('\n')
		if droppedFF {
			// If we dropped a formfeed while writing comments, we should emit a new
			// one.
			ch = byte('\f')
		}
		p.writeByte(ch, newlineLimit(newlines))
	}
}

// writeString writes the literal string s into the printer's output.
// Formatting characters in s such as '\t' and '\n' will be interpreted by
// underlying tabwriter unless isLit is set.
//
// If pos is valid it replaces the printer's AST-space position before s is
// accounted for. Both the AST-space and output-space positions are advanced
// past s, and p.last is set to the resulting AST-space position.
func (p *printer) writeString(pos token.Position, s string, isLit bool) {
	if p.out.Column == 1 {
		// We haven't written any text to this line yet; prepend our indentation
		// for the line.
		p.writeIndent()
	}

	if pos.Valid() {
		// Update p.pos if pos is valid. This is done *after* handling indentation
		// since we want to interpret pos as the literal position for s (and
		// writeIndent will update p.pos).
		p.pos = pos
	}

	if isLit {
		// Wrap our literal string in tabwriter.Escape if it's meant to be written
		// without interpretation by the tabwriter.
		p.output = append(p.output, tabwriter.Escape)

		// The closing escape is appended when writeString returns, i.e. after s
		// has been appended below.
		defer func() {
			p.output = append(p.output, tabwriter.Escape)
		}()
	}

	p.output = append(p.output, s...)

	var (
		newlines       int
		lastNewlineIdx int
	)

	// Count line breaks in s and remember where the last one is so the column
	// can be recomputed.
	for i := 0; i < len(s); i++ {
		if ch := s[i]; ch == '\n' || ch == '\f' {
			newlines++
			lastNewlineIdx = i
		}
	}

	p.pos.Offset += len(s)

	if newlines > 0 {
		p.pos.Line += newlines
		p.out.Line += newlines

		// Columns are 1-based: text ending right after the final line break
		// yields column 1.
		newColumn := len(s) - lastNewlineIdx
		p.pos.Column = newColumn
		p.out.Column = newColumn
	} else {
		p.pos.Column += len(s)
		p.out.Column += len(s)
	}

	p.last = p.pos
}

// writeIndent writes one tab per indentation level (the configured base
// indent plus the current indent) and advances the positions accordingly.
func (p *printer) writeIndent() {
	depth := p.cfg.Indent + p.indent
	for i := 0; i < depth; i++ {
		p.output = append(p.output, '\t')
	}

	p.pos.Offset += depth
	p.pos.Column += depth
	p.out.Column += depth
}

// writeByte writes ch n times to the output, updating the position of the
// printer. writeByte is only used for writing whitespace characters.
+func (p *printer) writeByte(ch byte, n int) { + if p.out.Column == 1 { + p.writeIndent() + } + + for i := 0; i < n; i++ { + p.output = append(p.output, ch) + } + + // Update positions. + p.pos.Offset += n + if ch == '\n' || ch == '\f' { + p.pos.Line += n + p.out.Line += n + p.pos.Column = 1 + p.out.Column = 1 + return + } + p.pos.Column += n + p.out.Column += n +} + +// writeWhitespace writes the first n whitespace entries in the whitespace +// buffer. +// +// writeWritespace is only safe to be called when len(p.whitespace) >= n. +func (p *printer) writeWritespace(n int) { + for i := 0; i < n; i++ { + switch ch := p.whitespace[i]; ch { + case wsIgnore: // no-op + case wsIndent: + p.indent++ + case wsUnindent: + p.indent-- + if p.indent < 0 { + panic("printer: negative indentation") + } + default: + p.writeByte(byte(ch), 1) + } + } + + // Shift remaining entries down + l := copy(p.whitespace, p.whitespace[n:]) + p.whitespace = p.whitespace[:l] +} + +const maxNewlines = 2 + +// newlineLimit limits a newline count to maxNewlines. +func newlineLimit(count int) int { + if count > maxNewlines { + count = maxNewlines + } + return count +} + +// whitespace represents a whitespace token to write to the printer's internal +// buffer. 
+type whitespace byte + +const ( + wsIgnore = whitespace(0) + wsBlank = whitespace(' ') + wsVTab = whitespace('\v') + wsNewline = whitespace('\n') + wsFormfeed = whitespace('\f') + wsIndent = whitespace('>') + wsUnindent = whitespace('<') +) + +func (ws whitespace) String() string { + switch ws { + case wsIgnore: + return "wsIgnore" + case wsBlank: + return "wsBlank" + case wsVTab: + return "wsVTab" + case wsNewline: + return "wsNewline" + case wsFormfeed: + return "wsFormfeed" + case wsIndent: + return "wsIndent" + case wsUnindent: + return "wsUnindent" + default: + return fmt.Sprintf("whitespace(%d)", ws) + } +} diff --git a/syntax/printer/printer_test.go b/syntax/printer/printer_test.go new file mode 100644 index 0000000000..38f69217b1 --- /dev/null +++ b/syntax/printer/printer_test.go @@ -0,0 +1,77 @@ +package printer_test + +import ( + "bytes" + "io/fs" + "os" + "path/filepath" + "strings" + "testing" + "unicode" + + "github.com/grafana/river/parser" + "github.com/grafana/river/printer" + "github.com/stretchr/testify/require" +) + +func TestPrinter(t *testing.T) { + filepath.WalkDir("testdata", func(path string, d fs.DirEntry, _ error) error { + if d.IsDir() { + return nil + } + + if strings.HasSuffix(path, ".in") { + inputFile := path + expectFile := strings.TrimSuffix(path, ".in") + ".expect" + expectErrorFile := strings.TrimSuffix(path, ".in") + ".error" + + caseName := filepath.Base(path) + caseName = strings.TrimSuffix(caseName, ".in") + + t.Run(caseName, func(t *testing.T) { + testPrinter(t, inputFile, expectFile, expectErrorFile) + }) + } + + return nil + }) +} + +func testPrinter(t *testing.T, inputFile string, expectFile string, expectErrorFile string) { + inputBB, err := os.ReadFile(inputFile) + require.NoError(t, err) + + f, err := parser.ParseFile(t.Name()+".rvr", inputBB) + if expectedError := getExpectedErrorMessage(t, expectErrorFile); expectedError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), expectedError) + return + } + 
+ expectBB, err := os.ReadFile(expectFile) + require.NoError(t, err) + + var buf bytes.Buffer + require.NoError(t, printer.Fprint(&buf, f)) + + trimmed := strings.TrimRightFunc(string(expectBB), unicode.IsSpace) + require.Equal(t, trimmed, buf.String(), "%s", buf.String()) +} + +// getExpectedErrorMessage will retrieve an optional expected error message for the test. +func getExpectedErrorMessage(t *testing.T, errorFile string) string { + if _, err := os.Stat(errorFile); err == nil { + errorBytes, err := os.ReadFile(errorFile) + require.NoError(t, err) + errorsString := string(normalizeLineEndings(errorBytes)) + return errorsString + } + + return "" +} + +// normalizeLineEndings will replace '\r\n' with '\n'. +func normalizeLineEndings(data []byte) []byte { + normalized := bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'}) + return normalized +} diff --git a/syntax/printer/testdata/.gitattributes b/syntax/printer/testdata/.gitattributes new file mode 100644 index 0000000000..7949f2b32c --- /dev/null +++ b/syntax/printer/testdata/.gitattributes @@ -0,0 +1 @@ +* -text eol=lf diff --git a/syntax/printer/testdata/array_comments.expect b/syntax/printer/testdata/array_comments.expect new file mode 100644 index 0000000000..a9e921b6dc --- /dev/null +++ b/syntax/printer/testdata/array_comments.expect @@ -0,0 +1,17 @@ +// array_comments.in expects that comments in arrays are formatted to +// retain the indentation level of elements within the arrays. 
+ +attr = [ // Inline comment + 0, 1, 2, // Inline comment + 3, 4, 5, // Inline comment + // Trailing comment +] + +attr = [ + 0, + // Element-level comment + 1, + // Element-level comment + 2, + // Trailing comment +] diff --git a/syntax/printer/testdata/array_comments.in b/syntax/printer/testdata/array_comments.in new file mode 100644 index 0000000000..e088e83749 --- /dev/null +++ b/syntax/printer/testdata/array_comments.in @@ -0,0 +1,17 @@ +// array_comments.in expects that comments in arrays are formatted to +// retain the indentation level of elements within the arrays. + +attr = [ // Inline comment + 0, 1, 2, // Inline comment + 3, 4, 5, // Inline comment + // Trailing comment +] + +attr = [ + 0, + // Element-level comment + 1, + // Element-level comment + 2, + // Trailing comment +] diff --git a/syntax/printer/testdata/block_comments.expect b/syntax/printer/testdata/block_comments.expect new file mode 100644 index 0000000000..ca60e0796c --- /dev/null +++ b/syntax/printer/testdata/block_comments.expect @@ -0,0 +1,62 @@ +// block_comments.in expects that comments within blocks are formatted to +// remain within the block with the proper indentation. + +// +// Unlabeled blocks +// + +// Comment is on same line as empty block header. +block { // comment +} + +// Comment is on same line as non-empty block header. +block { // comment + attr = 5 +} + +// Comment is alone in block body. +block { + // comment +} + +// Comment is before a statement. +block { + // comment + attr = 5 +} + +// Comment is after a statement. +block { + attr = 5 + // comment +} + +// +// Labeled blocks +// + +// Comment is on same line as empty block header. +block "label" { // comment +} + +// Comment is on same line as non-empty block header. +block "label" { // comment + attr = 5 +} + +// Comment is alone in block body. +block "label" { + // comment +} + +// Comment is before a statement. +block "label" { + // comment + attr = 5 +} + +// Comment is after a statement. 
+block "label" { + attr = 5 + // comment +} diff --git a/syntax/printer/testdata/block_comments.in b/syntax/printer/testdata/block_comments.in new file mode 100644 index 0000000000..d60512c9a3 --- /dev/null +++ b/syntax/printer/testdata/block_comments.in @@ -0,0 +1,64 @@ +// block_comments.in expects that comments within blocks are formatted to +// remain within the block with the proper indentation. + +// +// Unlabeled blocks +// + +// Comment is on same line as empty block header. +block { // comment +} + +// Comment is on same line as non-empty block header. +block { // comment + attr = 5 +} + +// Comment is alone in block body. +block { +// comment +} + +// Comment is before a statement. +block { +// comment + attr = 5 +} + +// Comment is after a statement. +block { + attr = 5 +// comment +} + +// +// Labeled blocks +// + +// Comment is on same line as empty block header. +block "label" { // comment +} + +// Comment is on same line as non-empty block header. +block "label" { // comment + attr = 5 +} + +// Comment is alone in block body. +block "label" { +// comment +} + +// Comment is before a statement. +block "label" { +// comment + attr = 5 +} + +// Comment is after a statement. +block "label" { + attr = 5 +// comment +} + + diff --git a/syntax/printer/testdata/example.expect b/syntax/printer/testdata/example.expect new file mode 100644 index 0000000000..dd6ab8e1f2 --- /dev/null +++ b/syntax/printer/testdata/example.expect @@ -0,0 +1,60 @@ +// This file tests a little bit of everything that the formatter should do. For +// example, this block of comments itself ensures that the output retains +// comments found in the source file. + +// +// Whitespace tests +// + +// Attributes should be given whitespace +attr_1 = 15 +attr_2 = 30 * 2 + 5 +attr_3 = field.access * 2 + +// Blocks with nothing inside of them should be truncated. 
+empty.block { } + +empty.block "labeled" { } + +// +// Alignment tests +// + +// Sequences of attributes which aren't separated by a blank line should have +// the equal sign aligned. +short_name = true +really_long_name = true + +extremely_long_name = true + +// Sequences of comments on aligned lines should also be aligned. +short_name = "short value" // Align me +really_long_name = "really long value" // Align me + +extremely_long_name = true // Unaligned + +// +// Indentation tests +// + +// Array literals, object literals, and blocks should all be indented properly. +multiline_array = [ + 0, + 1, +] + +mulitiline_object = { + foo = "bar", +} + +some_block { + attr = 15 + + inner_block { + attr = 20 + } +} + +// Trailing comments should be retained in the output. If this comment gets +// trimmed out, it usually indicates that a final flush is missing after +// traversing the AST. diff --git a/syntax/printer/testdata/example.in b/syntax/printer/testdata/example.in new file mode 100644 index 0000000000..efce00a8a3 --- /dev/null +++ b/syntax/printer/testdata/example.in @@ -0,0 +1,64 @@ +// This file tests a little bit of everything that the formatter should do. For +// example, this block of comments itself ensures that the output retains +// comments found in the source file. + +// +// Whitespace tests +// + +// Attributes should be given whitespace +attr_1=15 +attr_2=30*2+5 +attr_3=field.access*2 + +// Blocks with nothing inside of them should be truncated. +empty.block { + +} + +empty.block "labeled" { + +} + +// +// Alignment tests +// + +// Sequences of attributes which aren't separated by a blank line should have +// the equal sign aligned. +short_name = true +really_long_name = true + +extremely_long_name = true + +// Sequences of comments on aligned lines should also be aligned. 
+short_name = "short value" // Align me +really_long_name = "really long value" // Align me + +extremely_long_name = true // Unaligned + +// +// Indentation tests +// + +// Array literals, object literals, and blocks should all be indented properly. +multiline_array = [ +0, +1, +] + +mulitiline_object = { +foo = "bar", +} + +some_block { +attr = 15 + +inner_block { +attr = 20 +} +} + +// Trailing comments should be retained in the output. If this comment gets +// trimmed out, it usually indicates that a final flush is missing after +// traversing the AST. diff --git a/syntax/printer/testdata/func_call.expect b/syntax/printer/testdata/func_call.expect new file mode 100644 index 0000000000..e42c7acaf7 --- /dev/null +++ b/syntax/printer/testdata/func_call.expect @@ -0,0 +1,17 @@ +one_line = some_func(1, 2, 3, 4) + +multi_line = some_func(1, + 2, 3, + 4) + +multi_line_pretty = some_func( + 1, + 2, + 3, + 4, +) + +func_with_obj = some_func({ + key1 = "value1", + key2 = "value2", +}) diff --git a/syntax/printer/testdata/func_call.in b/syntax/printer/testdata/func_call.in new file mode 100644 index 0000000000..141a6057cf --- /dev/null +++ b/syntax/printer/testdata/func_call.in @@ -0,0 +1,17 @@ +one_line = some_func(1, 2, 3, 4) + +multi_line = some_func(1, +2, 3, +4) + +multi_line_pretty = some_func( +1, +2, +3, +4, +) + +func_with_obj = some_func({ + key1 = "value1", + key2 = "value2", +}) diff --git a/syntax/printer/testdata/mixed_list.expect b/syntax/printer/testdata/mixed_list.expect new file mode 100644 index 0000000000..7ce6f28dfb --- /dev/null +++ b/syntax/printer/testdata/mixed_list.expect @@ -0,0 +1,16 @@ +mixed_list = [0, true, { + key_1 = true, + key_2 = true, + key_3 = true, +}, "Hello!"] + +mixed_list_2 = [ + 0, + true, + { + key_1 = true, + key_2 = true, + key_3 = true, + }, + "Hello!", +] diff --git a/syntax/printer/testdata/mixed_list.in b/syntax/printer/testdata/mixed_list.in new file mode 100644 index 0000000000..7ce6f28dfb --- /dev/null +++ 
b/syntax/printer/testdata/mixed_list.in @@ -0,0 +1,16 @@ +mixed_list = [0, true, { + key_1 = true, + key_2 = true, + key_3 = true, +}, "Hello!"] + +mixed_list_2 = [ + 0, + true, + { + key_1 = true, + key_2 = true, + key_3 = true, + }, + "Hello!", +] diff --git a/syntax/printer/testdata/mixed_object.expect b/syntax/printer/testdata/mixed_object.expect new file mode 100644 index 0000000000..8d301e1f0e --- /dev/null +++ b/syntax/printer/testdata/mixed_object.expect @@ -0,0 +1,8 @@ +mixed_object = { + key_1 = true, + key_2 = [0, true, { + inner_1 = true, + inner_2 = true, + }], +} + diff --git a/syntax/printer/testdata/mixed_object.in b/syntax/printer/testdata/mixed_object.in new file mode 100644 index 0000000000..9334cfafa1 --- /dev/null +++ b/syntax/printer/testdata/mixed_object.in @@ -0,0 +1,7 @@ +mixed_object = { + key_1 = true, + key_2 = [0, true, { + inner_1 = true, + inner_2 = true, + }], +} diff --git a/syntax/printer/testdata/object_align.expect b/syntax/printer/testdata/object_align.expect new file mode 100644 index 0000000000..631536a437 --- /dev/null +++ b/syntax/printer/testdata/object_align.expect @@ -0,0 +1,11 @@ +block { + some_object = { + key_1 = 5, + long_key = 10, + longer_key = { + inner_key = true, + inner_key_2 = false, + }, + other_key = [0, 1, 2], + } +} diff --git a/syntax/printer/testdata/object_align.in b/syntax/printer/testdata/object_align.in new file mode 100644 index 0000000000..b618f61dfc --- /dev/null +++ b/syntax/printer/testdata/object_align.in @@ -0,0 +1,11 @@ +block { + some_object = { + key_1 = 5, + long_key = 10, + longer_key = { + inner_key = true, + inner_key_2 = false, + }, + other_key = [0, 1, 2], + } +} diff --git a/syntax/printer/testdata/oneline_block.expect b/syntax/printer/testdata/oneline_block.expect new file mode 100644 index 0000000000..8cd2c69d25 --- /dev/null +++ b/syntax/printer/testdata/oneline_block.expect @@ -0,0 +1,11 @@ +block { } + +block { } + +block { } + +block { } + +block { + // Comments should be kept. 
+} diff --git a/syntax/printer/testdata/oneline_block.in b/syntax/printer/testdata/oneline_block.in new file mode 100644 index 0000000000..2c1f74363a --- /dev/null +++ b/syntax/printer/testdata/oneline_block.in @@ -0,0 +1,14 @@ +block {} + +block { } + +block { +} + +block { + +} + +block { + // Comments should be kept. +} diff --git a/syntax/printer/testdata/raw_string.expect b/syntax/printer/testdata/raw_string.expect new file mode 100644 index 0000000000..5837439569 --- /dev/null +++ b/syntax/printer/testdata/raw_string.expect @@ -0,0 +1,15 @@ +block "label" { + attr = `'\"attr` +} + +block "multi_line" { + attr = `'\"this +is +a +multi_line +attr'\"` +} + +block "json" { + attr = `{ "key": "value" }` +} \ No newline at end of file diff --git a/syntax/printer/testdata/raw_string.in b/syntax/printer/testdata/raw_string.in new file mode 100644 index 0000000000..5837439569 --- /dev/null +++ b/syntax/printer/testdata/raw_string.in @@ -0,0 +1,15 @@ +block "label" { + attr = `'\"attr` +} + +block "multi_line" { + attr = `'\"this +is +a +multi_line +attr'\"` +} + +block "json" { + attr = `{ "key": "value" }` +} \ No newline at end of file diff --git a/syntax/printer/testdata/raw_string_label_error.error b/syntax/printer/testdata/raw_string_label_error.error new file mode 100644 index 0000000000..dd3f7f7c8b --- /dev/null +++ b/syntax/printer/testdata/raw_string_label_error.error @@ -0,0 +1 @@ +expected block label to be a double quoted string, but got "`multi_line`" \ No newline at end of file diff --git a/syntax/printer/testdata/raw_string_label_error.in b/syntax/printer/testdata/raw_string_label_error.in new file mode 100644 index 0000000000..1e16f9ae07 --- /dev/null +++ b/syntax/printer/testdata/raw_string_label_error.in @@ -0,0 +1,15 @@ +block "label" { + attr = `'\"attr` +} + +block `multi_line` { + attr = `'\"this +is +a +multi_line +attr'\"` +} + +block `json` { + attr = `{ "key": "value" }` +} \ No newline at end of file diff --git a/syntax/printer/trimmer.go 
b/syntax/printer/trimmer.go
new file mode 100644
index 0000000000..5a76c0c79d
--- /dev/null
+++ b/syntax/printer/trimmer.go
@@ -0,0 +1,115 @@
package printer

import (
	"io"
	"text/tabwriter"
)

// A trimmer is an io.Writer which filters out tabwriter.Escape characters,
// removes trailing blanks and tabs from lines, and converts \f and \v
// characters into \n and \t (if no text/tabwriter is used when printing).
//
// Text wrapped by tabwriter.Escape characters is written to the underlying
// io.Writer unmodified.
type trimmer struct {
	next  io.Writer // Destination for the filtered output
	state int       // One of the trimState* constants
	space []byte    // Pending blanks and tabs not yet written to next
}

const (
	trimStateSpace  = iota // Trimmer is reading space characters
	trimStateEscape        // Trimmer is reading escaped characters
	trimStateText          // Trimmer is reading text
)

// discardWhitespace drops any pending whitespace and returns the trimmer to
// the space state. The backing array of t.space is kept for reuse.
func (t *trimmer) discardWhitespace() {
	t.state = trimStateSpace
	t.space = t.space[0:0]
}

// Write filters data and forwards the result to the underlying writer.
//
// On success it returns len(data) and a nil error. If a write to the
// underlying writer fails, it returns the offset of the byte being processed
// at the time, along with the error.
func (t *trimmer) Write(data []byte) (n int, err error) {
	// textStart holds the index of the start of a chunk of text not containing
	// whitespace. It is reset every time a new chunk of text is encountered.
	var textStart int

	for off, b := range data {
		// Convert \v to \t. Only the local copy changes: it affects how the byte
		// is classified and what gets appended to t.space, not data itself.
		if b == '\v' {
			b = '\t'
		}

		switch t.state {
		case trimStateSpace:
			// Accumulate tabs and spaces in t.space until finding a non-tab or
			// non-space character.
			//
			// If we find a newline, we write it directly and discard our pending
			// whitespace (so that trailing whitespace up to the newline is ignored).
			//
			// If we find a tabwriter.Escape or text character we transition states.
			switch b {
			case '\t', ' ':
				t.space = append(t.space, b)
			case '\n', '\f':
				// Discard all unwritten whitespace before the end of the line and write
				// a newline.
				t.discardWhitespace()
				_, err = t.next.Write([]byte("\n"))
			case tabwriter.Escape:
				_, err = t.next.Write(t.space)
				t.state = trimStateEscape
				textStart = off + 1 // Skip escape character
			default:
				// Non-space character. Write our pending whitespace
				// and then move to text state.
				_, err = t.next.Write(t.space)
				t.state = trimStateText
				textStart = off
			}

		case trimStateText:
			// We're reading a chunk of text. Accumulate characters in the chunk
			// until we find a whitespace character or a tabwriter.Escape.
			switch b {
			case '\t', ' ':
				_, err = t.next.Write(data[textStart:off])
				t.discardWhitespace()
				t.space = append(t.space, b)
			case '\n', '\f':
				_, err = t.next.Write(data[textStart:off])
				t.discardWhitespace()
				// Only write the newline if the text before it was written
				// successfully, so the first error is the one reported.
				if err == nil {
					_, err = t.next.Write([]byte("\n"))
				}
			case tabwriter.Escape:
				_, err = t.next.Write(data[textStart:off])
				t.state = trimStateEscape
				textStart = off + 1 // +1: skip tabwriter.Escape
			}

		case trimStateEscape:
			// Accumulate everything until finding the closing tabwriter.Escape.
			if b == tabwriter.Escape {
				_, err = t.next.Write(data[textStart:off])
				t.discardWhitespace()
			}

		default:
			panic("unreachable")
		}
		if err != nil {
			return off, err
		}
	}
	n = len(data)

	// Flush the remainder of the text (as long as it's not whitespace). Pending
	// whitespace in t.space is kept for the next call to Write.
	//
	// Note that discardWhitespace resets the state to trimStateSpace here, even
	// if data ended inside an escaped sequence which was never closed.
	switch t.state {
	case trimStateEscape, trimStateText:
		_, err = t.next.Write(data[textStart:n])
		t.discardWhitespace()
	}

	return
}
diff --git a/syntax/printer/walker.go b/syntax/printer/walker.go
new file mode 100644
index 0000000000..01f71b21bd
--- /dev/null
+++ b/syntax/printer/walker.go
@@ -0,0 +1,338 @@
package printer

import (
	"fmt"
	"strings"

	"github.com/grafana/river/ast"
	"github.com/grafana/river/token"
)

// A walker walks an AST and sends lexical tokens and formatting information to
// a printer.
+type walker struct { + p *printer +} + +func (w *walker) Walk(node ast.Node) error { + switch node := node.(type) { + case *ast.File: + w.walkFile(node) + case ast.Body: + w.walkStmts(node) + case ast.Stmt: + w.walkStmt(node) + case ast.Expr: + w.walkExpr(node) + default: + return fmt.Errorf("unsupported node type %T", node) + } + + return nil +} + +func (w *walker) walkFile(f *ast.File) { + w.p.SetComments(f.Comments) + w.walkStmts(f.Body) +} + +func (w *walker) walkStmts(ss []ast.Stmt) { + for i, s := range ss { + var addedSpacing bool + + // Two blocks should always be separated by a blank line. + if _, isBlock := s.(*ast.BlockStmt); i > 0 && isBlock { + w.p.Write(wsFormfeed) + addedSpacing = true + } + + // A blank line should always be added if there is a blank line in the + // source between two statements. + if i > 0 && !addedSpacing { + var ( + prevLine = ast.EndPos(ss[i-1]).Position().Line + curLine = ast.StartPos(ss[i-0]).Position().Line + + lineDiff = curLine - prevLine + ) + + if lineDiff > 1 { + w.p.Write(wsFormfeed) + } + } + + w.walkStmt(s) + + // Statements which cross multiple lines don't belong to the same row run. + // Add a formfeed to start a new row run if the node crossed more than one + // line, otherwise add the normal newline. 
+ if nodeLines(s) > 1 { + w.p.Write(wsFormfeed) + } else { + w.p.Write(wsNewline) + } + } +} + +func nodeLines(n ast.Node) int { + var ( + startLine = ast.StartPos(n).Position().Line + endLine = ast.EndPos(n).Position().Line + ) + + return endLine - startLine + 1 +} + +func (w *walker) walkStmt(s ast.Stmt) { + switch s := s.(type) { + case *ast.AttributeStmt: + w.walkAttributeStmt(s) + case *ast.BlockStmt: + w.walkBlockStmt(s) + } +} + +func (w *walker) walkAttributeStmt(s *ast.AttributeStmt) { + w.p.Write(s.Name.NamePos, s.Name, wsVTab, token.ASSIGN, wsBlank) + w.walkExpr(s.Value) +} + +func (w *walker) walkBlockStmt(s *ast.BlockStmt) { + joined := strings.Join(s.Name, ".") + + w.p.Write( + s.NamePos, + &ast.Ident{Name: joined, NamePos: s.NamePos}, + ) + + if s.Label != "" { + label := fmt.Sprintf("%q", s.Label) + + w.p.Write( + wsBlank, + s.LabelPos, + &ast.LiteralExpr{Kind: token.STRING, Value: label}, + ) + } + + w.p.Write( + wsBlank, + s.LCurlyPos, token.LCURLY, wsIndent, + ) + + if len(s.Body) > 0 { + // Add a formfeed to start a new row run before writing any statements. + w.p.Write(wsFormfeed) + w.walkStmts(s.Body) + } else { + // There's no statements, but add a blank line between the left and right + // curly anyway. 
+ w.p.Write(wsBlank) + } + + w.p.Write(wsUnindent, s.RCurlyPos, token.RCURLY) +} + +func (w *walker) walkExpr(e ast.Expr) { + switch e := e.(type) { + case *ast.LiteralExpr: + w.p.Write(e.ValuePos, e) + + case *ast.ArrayExpr: + w.walkArrayExpr(e) + + case *ast.ObjectExpr: + w.walkObjectExpr(e) + + case *ast.IdentifierExpr: + w.p.Write(e.Ident.NamePos, e.Ident) + + case *ast.AccessExpr: + w.walkExpr(e.Value) + w.p.Write(token.DOT, e.Name) + + case *ast.IndexExpr: + w.walkExpr(e.Value) + w.p.Write(e.LBrackPos, token.LBRACK) + w.walkExpr(e.Index) + w.p.Write(e.RBrackPos, token.RBRACK) + + case *ast.CallExpr: + w.walkCallExpr(e) + + case *ast.UnaryExpr: + w.p.Write(e.KindPos, e.Kind) + w.walkExpr(e.Value) + + case *ast.BinaryExpr: + // TODO(rfratto): + // + // 1. allow RHS to be on a new line + // + // 2. remove spacing between some operators to make precedence + // clearer like Go does + w.walkExpr(e.Left) + w.p.Write(wsBlank, e.KindPos, e.Kind, wsBlank) + w.walkExpr(e.Right) + + case *ast.ParenExpr: + w.p.Write(token.LPAREN) + w.walkExpr(e.Inner) + w.p.Write(token.RPAREN) + } +} + +func (w *walker) walkArrayExpr(e *ast.ArrayExpr) { + w.p.Write(e.LBrackPos, token.LBRACK) + prevPos := e.LBrackPos + + for i := 0; i < len(e.Elements); i++ { + var addedNewline bool + + elementPos := ast.StartPos(e.Elements[i]) + + // Add a newline if this element starts on a different line than the last + // element ended. + if differentLines(prevPos, elementPos) { + // Indent elements inside the array on different lines. The indent is + // done *before* the newline to make sure comments written before the + // newline are indented properly. + w.p.Write(wsIndent, wsFormfeed) + addedNewline = true + } else if i > 0 { + // Make sure a space is injected before the next element if two + // successive elements are on the same line. + w.p.Write(wsBlank) + } + prevPos = ast.EndPos(e.Elements[i]) + + // Write the expression. 
+ w.walkExpr(e.Elements[i]) + + // Always add commas in between successive elements. + if i+1 < len(e.Elements) { + w.p.Write(token.COMMA) + } + + if addedNewline { + w.p.Write(wsUnindent) + } + } + + var addedSuffixNewline bool + + // If the closing bracket is on a different line than the final element, + // we need to add a trailing comma. + if len(e.Elements) > 0 && differentLines(prevPos, e.RBrackPos) { + // We add an indentation here so comments after the final element are + // indented. + w.p.Write(token.COMMA, wsIndent, wsFormfeed) + addedSuffixNewline = true + } + + if addedSuffixNewline { + w.p.Write(wsUnindent) + } + w.p.Write(e.RBrackPos, token.RBRACK) +} + +func (w *walker) walkObjectExpr(e *ast.ObjectExpr) { + w.p.Write(e.LCurlyPos, token.LCURLY, wsIndent) + + prevPos := e.LCurlyPos + + for i := 0; i < len(e.Fields); i++ { + field := e.Fields[i] + elementPos := ast.StartPos(field.Name) + + // Add a newline if this element starts on a different line than the last + // element ended. + if differentLines(prevPos, elementPos) { + // We want to align the equal sign for object attributes if the previous + // field only crossed one line. + if i > 0 && nodeLines(e.Fields[i-1].Value) == 1 { + w.p.Write(wsNewline) + } else { + w.p.Write(wsFormfeed) + } + } else if i > 0 { + // Make sure a space is injected before the next element if two successive + // elements are on the same line. + w.p.Write(wsBlank) + } + prevPos = ast.EndPos(field.Name) + + w.p.Write(field.Name.NamePos) + + // Write the field. + if field.Quoted { + w.p.Write(&ast.LiteralExpr{ + Kind: token.STRING, + ValuePos: field.Name.NamePos, + Value: fmt.Sprintf("%q", field.Name.Name), + }) + } else { + w.p.Write(field.Name) + } + + w.p.Write(wsVTab, token.ASSIGN, wsBlank) + w.walkExpr(field.Value) + + // Always add commas in between successive elements. 
+ if i+1 < len(e.Fields) { + w.p.Write(token.COMMA) + } + } + + // If the closing bracket is on a different line than the final element, + // we need to add a trailing comma. + if len(e.Fields) > 0 && differentLines(prevPos, e.RCurlyPos) { + w.p.Write(token.COMMA, wsFormfeed) + } + + w.p.Write(wsUnindent, e.RCurlyPos, token.RCURLY) +} + +func (w *walker) walkCallExpr(e *ast.CallExpr) { + w.walkExpr(e.Value) + w.p.Write(token.LPAREN) + + prevPos := e.LParenPos + + for i, arg := range e.Args { + var addedNewline bool + + argPos := ast.StartPos(arg) + + // Add a newline if this element starts on a different line than the last + // element ended. + if differentLines(prevPos, argPos) { + w.p.Write(wsFormfeed, wsIndent) + addedNewline = true + } + + w.walkExpr(arg) + prevPos = ast.EndPos(arg) + + if i+1 < len(e.Args) { + w.p.Write(token.COMMA, wsBlank) + } + + if addedNewline { + w.p.Write(wsUnindent) + } + } + + // Add a final comma if the final argument is on a different line than the + // right parenthesis. + if differentLines(prevPos, e.RParenPos) { + w.p.Write(token.COMMA, wsFormfeed) + } + + w.p.Write(token.RPAREN) +} + +// differentLines returns true if a and b are on different lines. +func differentLines(a, b token.Pos) bool { + return a.Position().Line != b.Position().Line +} diff --git a/syntax/river.go b/syntax/river.go new file mode 100644 index 0000000000..0944a9e3be --- /dev/null +++ b/syntax/river.go @@ -0,0 +1,346 @@ +// Package river implements a high-level API for decoding and encoding River +// configuration files. The mapping between River and Go values is described in +// the documentation for the Unmarshal and Marshal functions. +// +// Lower-level APIs which give more control over configuration evaluation are +// available in the inner packages. The implementation of this package is +// minimal and serves as a reference for how to consume the lower-level +// packages. 
+package river + +import ( + "bytes" + "io" + + "github.com/grafana/river/parser" + "github.com/grafana/river/token/builder" + "github.com/grafana/river/vm" +) + +// Marshal returns the pretty-printed encoding of v as a River configuration +// file. v must be a Go struct with river struct tags which determine the +// structure of the resulting file. +// +// Marshal traverses the value v recursively, encoding each struct field as a +// River block or River attribute, based on the flags provided to the river +// struct tag. +// +// When a struct field represents a River block, Marshal creates a new block +// and recursively encodes the value as the body of the block. The name of the +// created block is taken from the name specified by the river struct tag. +// +// Struct fields which represent River blocks must be either a Go struct or a +// slice of Go structs. When the field is a Go struct, its value is encoded as +// a single block. When the field is a slice of Go structs, a block is created +// for each element in the slice. +// +// When encoding a block, if the inner Go struct has a struct field +// representing a River block label, the value of that field is used as the +// label name for the created block. Fields used for River block labels must be +// the string type. When specified, there must not be more than one struct +// field which represents a block label. +// +// The river tag specifies a name, possibly followed by a comma-separated list +// of options. The name must be empty if the provided options do not support a +// name being defined. The following provides examples for all supported struct +// field tags with their meanings: +// +// // Field appears as a block named "example". It will always appear in the +// // resulting encoding. When decoding, "example" is treated as a required +// // block and must be present in the source text. +// Field struct{...} `river:"example,block"` +// +// // Field appears as a set of blocks named "example." 
It will appear in the +// // resulting encoding if there is at least one element in the slice. When +// // decoding, "example" is treated as a required block and at least one +// // "example" block must be present in the source text. +// Field []struct{...} `river:"example,block"` +// +// // Field appears as block named "example." It will always appear in the +// // resulting encoding. When decoding, "example" is treated as an optional +// // block and can be omitted from the source text. +// Field struct{...} `river:"example,block,optional"` +// +// // Field appears as a set of blocks named "example." It will appear in the +// // resulting encoding if there is at least one element in the slice. When +// // decoding, "example" is treated as an optional block and can be omitted +// // from the source text. +// Field []struct{...} `river:"example,block,optional"` +// +// // Field appears as an attribute named "example." It will always appear in +// // the resulting encoding. When decoding, "example" is treated as a +// // required attribute and must be present in the source text. +// Field bool `river:"example,attr"` +// +// // Field appears as an attribute named "example." If the field's value is +// // the Go zero value, "example" is omitted from the resulting encoding. +// // When decoding, "example" is treated as an optional attribute and can be +// // omitted from the source text. +// Field bool `river:"example,attr,optional"` +// +// // The value of Field appears as the block label for the struct being +// // converted into a block. When decoding, a block label must be provided. +// Field string `river:",label"` +// +// // The inner attributes and blocks of Field are exposed as top-level +// // attributes and blocks of the outer struct. +// Field struct{...} `river:",squash"` +// +// // Field appears as a set of blocks starting with "example.". Only the +// // first set element in the struct will be encoded. Each field in struct +// // must be a block. 
The name of the block is prepended to the enum name. +// // When decoding, enum blocks are treated as optional blocks and can be +// // omitted from the source text. +// Field []struct{...} `river:"example,enum"` +// +// // Field is equivalent to `river:"example,enum"`. +// Field []struct{...} `river:"example,enum,optional"` +// +// If a river tag specifies a required or optional block, the name is permitted +// to contain period `.` characters. +// +// Marshal will panic if it encounters a struct with invalid river tags. +// +// When a struct field represents a River attribute, Marshal encodes the struct +// value as a River value. The attribute name will be taken from the name +// specified by the river struct tag. See MarshalValue for the rules used to +// convert a Go value into a River value. +func Marshal(v interface{}) ([]byte, error) { + var buf bytes.Buffer + if err := NewEncoder(&buf).Encode(v); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// MarshalValue returns the pretty-printed encoding of v as a River value. +// +// MarshalValue traverses the value v recursively. If an encountered value +// implements the encoding.TextMarshaler interface, MarshalValue calls its +// MarshalText method and encodes the result as a River string. If a value +// implements the Capsule interface, it always encodes as a River capsule +// value. +// +// Otherwise, MarshalValue uses the following type-dependent default encodings: +// +// Boolean values encode to River bools. +// +// Floating point, integer, and Number values encode to River numbers. +// +// String values encode to River strings. +// +// Array and slice values encode to River arrays, except that []byte is +// converted into a River string. Nil slices encode as an empty array and nil +// []byte slices encode as an empty string. +// +// Structs encode to River objects, using Go struct field tags to determine the +// resulting structure of the River object. 
Each exported struct field with a +// river tag becomes an object field, using the tag name as the field name. +// Other struct fields are ignored. If no struct field has a river tag, the +// struct encodes to a River capsule instead. +// +// Function values encode to River functions, which appear in the resulting +// text as strings formatted as "function(GO_TYPE)". +// +// All other Go values encode to River capsules, which appear in the resulting +// text as strings formatted as "capsule(GO_TYPE)". +// +// The river tag specifies the field name, possibly followed by a +// comma-separated list of options. The following provides examples for all +// supported struct field tags with their meanings: +// +// // Field appears as an object field named "my_name". It will always +// // appear in the resulting encoding. When decoding, "my_name" is treated +// // as a required attribute and must be present in the source text. +// Field bool `river:"my_name,attr"` +// +// // Field appears as an object field named "my_name". If the field's value +// // is the Go zero value, "example" is omitted from the resulting encoding. +// // When decoding, "my_name" is treated as an optional attribute and can be +// // omitted from the source text. +// Field bool `river:"my_name,attr,optional"` +func MarshalValue(v interface{}) ([]byte, error) { + var buf bytes.Buffer + if err := NewEncoder(&buf).EncodeValue(v); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// Encoder writes River configuration to an output stream. Call NewEncoder to +// create instances of Encoder. +type Encoder struct { + w io.Writer +} + +// NewEncoder returns a new Encoder which writes configuration to w. +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{w: w} +} + +// Encode converts the value pointed to by v into a River configuration file +// and writes the result to the Decoder's output stream. 
+// +// See the documentation for Marshal for details about the conversion of Go +// values into River configuration. +func (enc *Encoder) Encode(v interface{}) error { + f := builder.NewFile() + f.Body().AppendFrom(v) + + _, err := f.WriteTo(enc.w) + return err +} + +// EncodeValue converts the value pointed to by v into a River value and writes +// the result to the Decoder's output stream. +// +// See the documentation for MarshalValue for details about the conversion of +// Go values into River values. +func (enc *Encoder) EncodeValue(v interface{}) error { + expr := builder.NewExpr() + expr.SetValue(v) + + _, err := expr.WriteTo(enc.w) + return err +} + +// Unmarshal converts the River configuration file specified by in and stores +// it in the struct value pointed to by v. If v is nil or not a pointer, +// Unmarshal panics. The configuration specified by in may use expressions to +// compute values while unmarshaling. Refer to the River language documentation +// for the list of valid formatting and expression rules. +// +// Unmarshal uses the inverse of the encoding rules that Marshal uses, +// allocating maps, slices, and pointers as necessary. +// +// To unmarshal a River body into a map[string]T, Unmarshal assigns each +// attribute to a key in the map, and decodes the attribute's value as the +// value for the map entry. Only attribute statements are allowed when +// unmarshaling into a map. +// +// To unmarshal a River body into a struct, Unmarshal matches incoming +// attributes and blocks to the river struct tags specified by v. Incoming +// attribute and blocks which do not match to a river struct tag cause a +// decoding error. Additionally, any attribute or block marked as required by +// the river struct tag that are not present in the source text will generate a +// decoding error. +// +// To unmarshal a list of River blocks into a slice, Unmarshal resets the slice +// length to zero and then appends each element to the slice. 
+// +// To unmarshal a list of River blocks into a Go array, Unmarshal decodes each +// block into the corresponding Go array element. If the number of River blocks +// does not match the length of the Go array, a decoding error is returned. +// +// Unmarshal follows the rules specified by UnmarshalValue when unmarshaling +// the value of an attribute. +func Unmarshal(in []byte, v interface{}) error { + dec := NewDecoder(bytes.NewReader(in)) + return dec.Decode(v) +} + +// UnmarshalValue converts the River configuration file specified by in and +// stores it in the value pointed to by v. If v is nil or not a pointer, +// UnmarshalValue panics. The configuration specified by in may use expressions +// to compute values while unmarshaling. Refer to the River language +// documentation for the list of valid formatting and expression rules. +// +// Unmarshal uses the inverse of the encoding rules that MarshalValue uses, +// allocating maps, slices, and pointers as necessary, with the following +// additional rules: +// +// After converting a River value into its Go value counterpart, the Go value +// may be converted into a capsule if the capsule type implements +// ConvertibleIntoCapsule. +// +// To unmarshal a River object into a struct, UnmarshalValue matches incoming +// object fields to the river struct tags specified by v. Incoming object +// fields which do not match to a river struct tag cause a decoding error. +// Additionally, any object field marked as required by the river struct +// tag that are not present in the source text will generate a decoding error. 
+// +// To unmarshal River into an interface value, Unmarshal stores one of the +// following: +// +// - bool, for River bools +// - float64, for floating point River numbers +// and integers which are too big to fit in either of int/int64/uint64 +// - int/int64/uint64, in this order of preference, for signed and unsigned +// River integer numbers, depending on how big they are +// - string, for River strings +// - []interface{}, for River arrays +// - map[string]interface{}, for River objects +// +// Capsule and function types will retain their original type when decoding +// into an interface value. +// +// To unmarshal a River array into a slice, Unmarshal resets the slice length +// to zero and then appends each element to the slice. +// +// To unmarshal a River array into a Go array, Unmarshal decodes River array +// elements into the corresponding Go array element. If the number of River +// elements does not match the length of the Go array, a decoding error is +// returned. +// +// To unmarshal a River object into a Map, Unmarshal establishes a map to use. +// If the map is nil, Unmarshal allocates a new map. Otherwise, Unmarshal +// reuses the existing map, keeping existing entries. Unmarshal then stores +// key-value pairs from the River object into the map. The map's key type must +// be string. +func UnmarshalValue(in []byte, v interface{}) error { + dec := NewDecoder(bytes.NewReader(in)) + return dec.DecodeValue(v) +} + +// Decoder reads River configuration from an input stream. Call NewDecoder to +// create instances of Decoder. +type Decoder struct { + r io.Reader +} + +// NewDecoder returns a new Decoder which reads configuration from r. +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{r: r} +} + +// Decode reads the River-encoded file from the Decoder's input and stores it +// in the value pointed to by v. Data will be read from the Decoder's input +// until EOF is reached. 
+// +// See the documentation for Unmarshal for details about the conversion of River +// configuration into Go values. +func (dec *Decoder) Decode(v interface{}) error { + bb, err := io.ReadAll(dec.r) + if err != nil { + return err + } + + f, err := parser.ParseFile("", bb) + if err != nil { + return err + } + + eval := vm.New(f) + return eval.Evaluate(nil, v) +} + +// DecodeValue reads the River-encoded expression from the Decoder's input and +// stores it in the value pointed to by v. Data will be read from the Decoder's +// input until EOF is reached. +// +// See the documentation for UnmarshalValue for details about the conversion of +// River values into Go values. +func (dec *Decoder) DecodeValue(v interface{}) error { + bb, err := io.ReadAll(dec.r) + if err != nil { + return err + } + + f, err := parser.ParseExpression(string(bb)) + if err != nil { + return err + } + + eval := vm.New(f) + return eval.Evaluate(nil, v) +} diff --git a/syntax/river_test.go b/syntax/river_test.go new file mode 100644 index 0000000000..99247f54da --- /dev/null +++ b/syntax/river_test.go @@ -0,0 +1,152 @@ +package river_test + +import ( + "fmt" + "os" + + river "github.com/grafana/river" +) + +func ExampleUnmarshal() { + // Character is our block type which holds an individual character from a + // book. + type Character struct { + // Name of the character. The name is decoded from the block label. + Name string `river:",label"` + // Age of the character. The age is a required attribute within the block, + // and must be set in the config. + Age int `river:"age,attr"` + // Location the character lives in. The location is an optional attribute + // within the block. Optional attributes do not have to bet set. + Location string `river:"location,attr,optional"` + } + + // Book is our overall type where we decode the overall River file into. + type Book struct { + // Title of the book (required attribute). + Title string `river:"title,attr"` + // List of characters. 
Each character is a labeled block. The optional tag + // means that it is valid not provide a character block. Decoding into a + // slice permits there to be multiple specified character blocks. + Characters []*Character `river:"character,block,optional"` + } + + // Create our book with two characters. + input := ` + title = "Wheel of Time" + + character "Rand" { + age = 19 + location = "Two Rivers" + } + + character "Perrin" { + age = 19 + location = "Two Rivers" + } + ` + + // Unmarshal the config into our Book type and print out the data. + var b Book + if err := river.Unmarshal([]byte(input), &b); err != nil { + panic(err) + } + + fmt.Printf("%s characters:\n", b.Title) + + for _, c := range b.Characters { + if c.Location != "" { + fmt.Printf("\t%s (age %d, location %s)\n", c.Name, c.Age, c.Location) + } else { + fmt.Printf("\t%s (age %d)\n", c.Name, c.Age) + } + } + + // Output: + // Wheel of Time characters: + // Rand (age 19, location Two Rivers) + // Perrin (age 19, location Two Rivers) +} + +// This example shows how functions may be called within user configurations. +// We focus on the `env` function from the standard library, which retrieves a +// value from an environment variable. +func ExampleUnmarshal_functions() { + // Set an environment variable to use in the test. 
+ _ = os.Setenv("EXAMPLE", "Jane Doe") + + type Data struct { + String string `river:"string,attr"` + } + + input := ` + string = env("EXAMPLE") + ` + + var d Data + if err := river.Unmarshal([]byte(input), &d); err != nil { + panic(err) + } + + fmt.Println(d.String) + // Output: Jane Doe +} + +func ExampleUnmarshalValue() { + input := `3 + 5` + + var num int + if err := river.UnmarshalValue([]byte(input), &num); err != nil { + panic(err) + } + + fmt.Println(num) + // Output: 8 +} + +func ExampleMarshal() { + type Person struct { + Name string `river:"name,attr"` + Age int `river:"age,attr"` + Location string `river:"location,attr,optional"` + } + + p := Person{ + Name: "John Doe", + Age: 43, + } + + bb, err := river.Marshal(p) + if err != nil { + panic(err) + } + + fmt.Println(string(bb)) + // Output: + // name = "John Doe" + // age = 43 +} + +func ExampleMarshalValue() { + type Person struct { + Name string `river:"name,attr"` + Age int `river:"age,attr"` + } + + p := Person{ + Name: "John Doe", + Age: 43, + } + + bb, err := river.MarshalValue(p) + if err != nil { + panic(err) + } + + fmt.Println(string(bb)) + // Output: + // { + // name = "John Doe", + // age = 43, + // } +} diff --git a/syntax/rivertypes/optional_secret.go b/syntax/rivertypes/optional_secret.go new file mode 100644 index 0000000000..75648af046 --- /dev/null +++ b/syntax/rivertypes/optional_secret.go @@ -0,0 +1,84 @@ +package rivertypes + +import ( + "fmt" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/token" + "github.com/grafana/river/token/builder" +) + +// OptionalSecret holds a potentially sensitive value. When IsSecret is true, +// the OptionalSecret's Value will be treated as sensitive and will be hidden +// from users when rendering River. +// +// OptionalSecrets may be converted from river strings and the Secret type, +// which will set IsSecret accordingly. 
+// +// Additionally, OptionalSecrets may be converted into the Secret type +// regardless of the value of IsSecret. OptionalSecret can be converted into a +// string as long as IsSecret is false. +type OptionalSecret struct { + IsSecret bool + Value string +} + +var ( + _ value.Capsule = OptionalSecret{} + _ value.ConvertibleIntoCapsule = OptionalSecret{} + _ value.ConvertibleFromCapsule = (*OptionalSecret)(nil) + + _ builder.Tokenizer = OptionalSecret{} +) + +// RiverCapsule marks OptionalSecret as a RiverCapsule. +func (s OptionalSecret) RiverCapsule() {} + +// ConvertInto converts the OptionalSecret and stores it into the Go value +// pointed at by dst. OptionalSecrets can always be converted into *Secret. +// OptionalSecrets can only be converted into *string if IsSecret is false. In +// other cases, this method will return an explicit error or +// river.ErrNoConversion. +func (s OptionalSecret) ConvertInto(dst interface{}) error { + switch dst := dst.(type) { + case *Secret: + *dst = Secret(s.Value) + return nil + case *string: + if s.IsSecret { + return fmt.Errorf("secrets may not be converted into strings") + } + *dst = s.Value + return nil + } + + return value.ErrNoConversion +} + +// ConvertFrom converts the src value and stores it into the OptionalSecret s. +// Secrets and strings can be converted into an OptionalSecret. In other +// cases, this method will return river.ErrNoConversion. +func (s *OptionalSecret) ConvertFrom(src interface{}) error { + switch src := src.(type) { + case Secret: + *s = OptionalSecret{IsSecret: true, Value: string(src)} + return nil + case string: + *s = OptionalSecret{Value: src} + return nil + } + + return value.ErrNoConversion +} + +// RiverTokenize returns a set of custom tokens to represent this value in +// River text. 
+func (s OptionalSecret) RiverTokenize() []builder.Token { + if s.IsSecret { + return []builder.Token{{Tok: token.LITERAL, Lit: "(secret)"}} + } + return []builder.Token{{ + Tok: token.STRING, + Lit: fmt.Sprintf("%q", s.Value), + }} +} diff --git a/syntax/rivertypes/optional_secret_test.go b/syntax/rivertypes/optional_secret_test.go new file mode 100644 index 0000000000..bd8a0baeea --- /dev/null +++ b/syntax/rivertypes/optional_secret_test.go @@ -0,0 +1,92 @@ +package rivertypes_test + +import ( + "testing" + + "github.com/grafana/river/rivertypes" + "github.com/grafana/river/token/builder" + "github.com/stretchr/testify/require" +) + +func TestOptionalSecret(t *testing.T) { + t.Run("non-sensitive conversion to string is allowed", func(t *testing.T) { + input := rivertypes.OptionalSecret{IsSecret: false, Value: "testval"} + + var s string + err := decodeTo(t, input, &s) + require.NoError(t, err) + require.Equal(t, "testval", s) + }) + + t.Run("sensitive conversion to string is disallowed", func(t *testing.T) { + input := rivertypes.OptionalSecret{IsSecret: true, Value: "testval"} + + var s string + err := decodeTo(t, input, &s) + require.NotNil(t, err) + require.Contains(t, err.Error(), "secrets may not be converted into strings") + }) + + t.Run("non-sensitive conversion to secret is allowed", func(t *testing.T) { + input := rivertypes.OptionalSecret{IsSecret: false, Value: "testval"} + + var s rivertypes.Secret + err := decodeTo(t, input, &s) + require.NoError(t, err) + require.Equal(t, rivertypes.Secret("testval"), s) + }) + + t.Run("sensitive conversion to secret is allowed", func(t *testing.T) { + input := rivertypes.OptionalSecret{IsSecret: true, Value: "testval"} + + var s rivertypes.Secret + err := decodeTo(t, input, &s) + require.NoError(t, err) + require.Equal(t, rivertypes.Secret("testval"), s) + }) + + t.Run("conversion from string is allowed", func(t *testing.T) { + var s rivertypes.OptionalSecret + err := decodeTo(t, string("Hello, world!"), &s) + 
require.NoError(t, err) + + expect := rivertypes.OptionalSecret{ + IsSecret: false, + Value: "Hello, world!", + } + require.Equal(t, expect, s) + }) + + t.Run("conversion from secret is allowed", func(t *testing.T) { + var s rivertypes.OptionalSecret + err := decodeTo(t, rivertypes.Secret("Hello, world!"), &s) + require.NoError(t, err) + + expect := rivertypes.OptionalSecret{ + IsSecret: true, + Value: "Hello, world!", + } + require.Equal(t, expect, s) + }) +} + +func TestOptionalSecret_Write(t *testing.T) { + tt := []struct { + name string + value interface{} + expect string + }{ + {"non-sensitive", rivertypes.OptionalSecret{Value: "foobar"}, `"foobar"`}, + {"sensitive", rivertypes.OptionalSecret{IsSecret: true, Value: "foobar"}, `(secret)`}, + {"non-sensitive pointer", &rivertypes.OptionalSecret{Value: "foobar"}, `"foobar"`}, + {"sensitive pointer", &rivertypes.OptionalSecret{IsSecret: true, Value: "foobar"}, `(secret)`}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + be := builder.NewExpr() + be.SetValue(tc.value) + require.Equal(t, tc.expect, string(be.Bytes())) + }) + } +} diff --git a/syntax/rivertypes/secret.go b/syntax/rivertypes/secret.go new file mode 100644 index 0000000000..c2eb357d03 --- /dev/null +++ b/syntax/rivertypes/secret.go @@ -0,0 +1,65 @@ +package rivertypes + +import ( + "fmt" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/token" + "github.com/grafana/river/token/builder" +) + +// Secret is a River capsule holding a sensitive string. The contents of a +// Secret are never displayed to the user when rendering River. +// +// Secret allows itself to be converted from a string River value, but never +// the inverse. This ensures that a user can't accidentally leak a sensitive +// value. 
+type Secret string + +var ( + _ value.Capsule = Secret("") + _ value.ConvertibleIntoCapsule = Secret("") + _ value.ConvertibleFromCapsule = (*Secret)(nil) + + _ builder.Tokenizer = Secret("") +) + +// RiverCapsule marks Secret as a RiverCapsule. +func (s Secret) RiverCapsule() {} + +// ConvertInto converts the Secret and stores it into the Go value pointed at +// by dst. Secrets can be converted into *OptionalSecret. In other cases, this +// method will return an explicit error or river.ErrNoConversion. +func (s Secret) ConvertInto(dst interface{}) error { + switch dst := dst.(type) { + case *OptionalSecret: + *dst = OptionalSecret{IsSecret: true, Value: string(s)} + return nil + case *string: + return fmt.Errorf("secrets may not be converted into strings") + } + + return value.ErrNoConversion +} + +// ConvertFrom converts the src value and stores it into the Secret s. +// OptionalSecrets and strings can be converted into a Secret. In other cases, +// this method will return river.ErrNoConversion. +func (s *Secret) ConvertFrom(src interface{}) error { + switch src := src.(type) { + case OptionalSecret: + *s = Secret(src.Value) + return nil + case string: + *s = Secret(src) + return nil + } + + return value.ErrNoConversion +} + +// RiverTokenize returns a set of custom tokens to represent this value in +// River text. 
+func (s Secret) RiverTokenize() []builder.Token { + return []builder.Token{{Tok: token.LITERAL, Lit: "(secret)"}} +} diff --git a/syntax/rivertypes/secret_test.go b/syntax/rivertypes/secret_test.go new file mode 100644 index 0000000000..cade74647b --- /dev/null +++ b/syntax/rivertypes/secret_test.go @@ -0,0 +1,47 @@ +package rivertypes_test + +import ( + "testing" + + "github.com/grafana/river/parser" + "github.com/grafana/river/rivertypes" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +func TestSecret(t *testing.T) { + t.Run("strings can be converted to secret", func(t *testing.T) { + var s rivertypes.Secret + err := decodeTo(t, string("Hello, world!"), &s) + require.NoError(t, err) + require.Equal(t, rivertypes.Secret("Hello, world!"), s) + }) + + t.Run("secrets cannot be converted to strings", func(t *testing.T) { + var s string + err := decodeTo(t, rivertypes.Secret("Hello, world!"), &s) + require.NotNil(t, err) + require.Contains(t, err.Error(), "secrets may not be converted into strings") + }) + + t.Run("secrets can be passed to secrets", func(t *testing.T) { + var s rivertypes.Secret + err := decodeTo(t, rivertypes.Secret("Hello, world!"), &s) + require.NoError(t, err) + require.Equal(t, rivertypes.Secret("Hello, world!"), s) + }) +} + +func decodeTo(t *testing.T, input interface{}, target interface{}) error { + t.Helper() + + expr, err := parser.ParseExpression("val") + require.NoError(t, err) + + eval := vm.New(expr) + return eval.Evaluate(&vm.Scope{ + Variables: map[string]interface{}{ + "val": input, + }, + }, target) +} diff --git a/syntax/scanner/identifier.go b/syntax/scanner/identifier.go new file mode 100644 index 0000000000..ed2239e060 --- /dev/null +++ b/syntax/scanner/identifier.go @@ -0,0 +1,60 @@ +package scanner + +import ( + "fmt" + + "github.com/grafana/river/token" +) + +// IsValidIdentifier returns true if the given string is a valid river +// identifier. 
+func IsValidIdentifier(in string) bool { + s := New(token.NewFile(""), []byte(in), nil, 0) + _, tok, lit := s.Scan() + return tok == token.IDENT && lit == in +} + +// SanitizeIdentifier will return the given string mutated into a valid river +// identifier. If the given string is already a valid identifier, it will be +// returned unchanged. +// +// This should be used with caution since the different inputs can result in +// identical outputs. +func SanitizeIdentifier(in string) (string, error) { + if in == "" { + return "", fmt.Errorf("cannot generate a new identifier for an empty string") + } + + if IsValidIdentifier(in) { + return in, nil + } + + newValue := generateNewIdentifier(in) + if !IsValidIdentifier(newValue) { + panic(fmt.Errorf("invalid identifier %q generated for `%q`", newValue, in)) + } + + return newValue, nil +} + +// generateNewIdentifier expects a valid river prefix and replacement +// string and returns a new identifier based on the given input. +func generateNewIdentifier(in string) string { + newValue := "" + for i, c := range in { + if i == 0 { + if isDigit(c) { + newValue = "_" + } + } + + if !(isLetter(c) || isDigit(c)) { + newValue += "_" + continue + } + + newValue += string(c) + } + + return newValue +} diff --git a/syntax/scanner/identifier_test.go b/syntax/scanner/identifier_test.go new file mode 100644 index 0000000000..e1dfead833 --- /dev/null +++ b/syntax/scanner/identifier_test.go @@ -0,0 +1,92 @@ +package scanner_test + +import ( + "testing" + + "github.com/grafana/river/scanner" + "github.com/stretchr/testify/require" +) + +var validTestCases = []struct { + name string + identifier string + expect bool +}{ + {"empty", "", false}, + {"start_number", "0identifier_1", false}, + {"start_char", "identifier_1", true}, + {"start_underscore", "_identifier_1", true}, + {"special_chars", "!@#$%^&*()", false}, + {"special_char", "identifier_1!", false}, + {"spaces", "identifier _ 1", false}, +} + +func TestIsValidIdentifier(t *testing.T) 
{ + for _, tc := range validTestCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.expect, scanner.IsValidIdentifier(tc.identifier)) + }) + } +} + +func BenchmarkIsValidIdentifier(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, tc := range validTestCases { + _ = scanner.IsValidIdentifier(tc.identifier) + } + } +} + +var sanitizeTestCases = []struct { + name string + identifier string + expectIdentifier string + expectErr string +}{ + {"empty", "", "", "cannot generate a new identifier for an empty string"}, + {"start_number", "0identifier_1", "_0identifier_1", ""}, + {"start_char", "identifier_1", "identifier_1", ""}, + {"start_underscore", "_identifier_1", "_identifier_1", ""}, + {"special_chars", "!@#$%^&*()", "__________", ""}, + {"special_char", "identifier_1!", "identifier_1_", ""}, + {"spaces", "identifier _ 1", "identifier___1", ""}, +} + +func TestSanitizeIdentifier(t *testing.T) { + for _, tc := range sanitizeTestCases { + t.Run(tc.name, func(t *testing.T) { + newIdentifier, err := scanner.SanitizeIdentifier(tc.identifier) + if tc.expectErr != "" { + require.EqualError(t, err, tc.expectErr) + return + } + + require.NoError(t, err) + require.Equal(t, tc.expectIdentifier, newIdentifier) + }) + } +} + +func BenchmarkSanitizeIdentifier(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, tc := range sanitizeTestCases { + _, _ = scanner.SanitizeIdentifier(tc.identifier) + } + } +} + +func FuzzSanitizeIdentifier(f *testing.F) { + for _, tc := range sanitizeTestCases { + f.Add(tc.identifier) + } + + f.Fuzz(func(t *testing.T, input string) { + newIdentifier, err := scanner.SanitizeIdentifier(input) + if input == "" { + require.EqualError(t, err, "cannot generate a new identifier for an empty string") + return + } + require.NoError(t, err) + require.True(t, scanner.IsValidIdentifier(newIdentifier)) + }) +} diff --git a/syntax/scanner/scanner.go b/syntax/scanner/scanner.go new file mode 100644 index 0000000000..e637a785b9 --- /dev/null +++ 
b/syntax/scanner/scanner.go @@ -0,0 +1,704 @@ +// Package scanner implements a lexical scanner for River source files. +package scanner + +import ( + "fmt" + "unicode" + "unicode/utf8" + + "github.com/grafana/river/token" +) + +// EBNF for the scanner: +// +// letter = /* any unicode letter class character */ | "_" +// number = /* any unicode number class character */ +// digit = /* ASCII characters 0 through 9 */ +// digits = digit { digit } +// string_character = /* any unicode character that isn't '"' */ +// +// COMMENT = line_comment | block_comment +// line_comment = "//" { character } +// block_comment = "/*" { character | newline } "*/" +// +// IDENT = letter { letter | number } +// NULL = "null" +// BOOL = "true" | "false" +// NUMBER = digits +// FLOAT = ( digits | "." digits ) [ "e" [ "+" | "-" ] digits ] +// STRING = '"' { string_character | escape_sequence } '"' +// OR = "||" +// AND = "&&" +// NOT = "!" +// NEQ = "!=" +// ASSIGN = "=" +// EQ = "==" +// LT = "<" +// LTE = "<=" +// GT = ">" +// GTE = ">=" +// ADD = "+" +// SUB = "-" +// MUL = "*" +// DIV = "/" +// MOD = "%" +// POW = "^" +// LCURLY = "{" +// RCURLY = "}" +// LPAREN = "(" +// RPAREN = ")" +// LBRACK = "[" +// RBRACK = "]" +// COMMA = "," +// DOT = "." +// +// The EBNF for escape_sequence is currently undocumented; see scanEscape for +// details. The escape sequences supported by River are the same as the escape +// sequences supported by Go, except that it is always valid to use \' in +// strings (which in Go, is only valid to use in character literals). + +// ErrorHandler is invoked whenever there is an error. +type ErrorHandler func(pos token.Pos, msg string) + +// Mode is a set of bitwise flags which control scanner behavior. +type Mode uint + +const ( + // IncludeComments will cause comments to be returned as comment tokens. + // Otherwise, comments are ignored. + IncludeComments Mode = 1 << iota + + // Avoids automatic insertion of terminators (for testing only). 
// Scanner holds the internal state for the tokenizer while processing configs.
//
// A Scanner is created with New and then driven by repeated calls to Scan.
type Scanner struct {
	file  *token.File  // Config file handle for tracking line offsets
	input []byte       // Input config
	err   ErrorHandler // Error reporting (may be nil)
	mode  Mode         // Flags controlling scanner behavior

	// scanning state variables:

	ch         rune // Current character (eof once the input is exhausted)
	offset     int  // Byte offset of ch
	readOffset int  // Byte offset of first character *after* ch
	insertTerm bool // Emit a TERMINATOR token at the next newline or at EOF
	numErrors  int  // Number of errors encountered during scanning
}
+func (s *Scanner) next() { + if s.readOffset >= len(s.input) { + s.offset = len(s.input) + if s.ch == '\n' { + // Make sure we track final newlines at the end of the file + s.file.AddLine(s.offset) + } + s.ch = eof + return + } + + s.offset = s.readOffset + if s.ch == '\n' { + s.file.AddLine(s.offset) + } + + r, width := rune(s.input[s.readOffset]), 1 + switch { + case r == 0: + s.onError(s.offset, "illegal character NUL") + case r >= utf8.RuneSelf: + r, width = utf8.DecodeRune(s.input[s.readOffset:]) + if r == utf8.RuneError && width == 1 { + s.onError(s.offset, "illegal UTF-8 encoding") + } else if r == bom && s.offset > 0 { + s.onError(s.offset, "illegal byte order mark") + } + } + s.readOffset += width + s.ch = r +} + +func (s *Scanner) onError(offset int, msg string) { + if s.err != nil { + s.err(s.file.Pos(offset), msg) + } + s.numErrors++ +} + +// NumErrors returns the current number of errors encountered during scanning. +// This is useful as a fallback to detect errors when no ErrorHandler was +// provided to the scanner. +func (s *Scanner) NumErrors() int { return s.numErrors } + +// Scan scans the next token and returns the token's position, the token +// itself, and the token's literal string (when applicable). The end of the +// input is indicated by token.EOF. +// +// If the returned token is a literal (such as token.STRING), then lit contains +// the corresponding literal text (including surrounding quotes). +// +// If the returned token is a keyword, lit is the keyword text that was +// scanned. +// +// If the returned token is token.TERMINATOR, lit will contain "\n". +// +// If the returned token is token.ILLEGAL, lit contains the offending +// character. +// +// In all other cases, lit will be an empty string. +// +// For more tolerant parsing, Scan returns a valid token character whenever +// possible when a syntax error was encountered. 
// Scan scans the next token and returns the token's position, the token
// itself, and the token's literal string (when applicable). The end of the
// input is indicated by token.EOF.
//
// If the returned token is a literal (such as token.STRING), then lit contains
// the corresponding literal text (including surrounding quotes).
//
// If the returned token is a keyword, lit is the keyword text that was
// scanned.
//
// If the returned token is token.COMMENT (only returned when the scanner was
// created with IncludeComments), lit contains the comment text.
//
// If the returned token is token.TERMINATOR, lit will contain "\n".
//
// If the returned token is token.ILLEGAL, lit contains the offending
// character, or the whole quoted text for an illegal single-quoted string.
//
// In all other cases, lit will be an empty string.
//
// For more tolerant parsing, Scan returns a valid token character whenever
// possible when a syntax error was encountered. Callers must check NumErrors
// or the number of times the provided ErrorHandler was invoked to ensure there
// were no errors found during scanning.
//
// Scan will inject line information to the file provided to New. Returned
// token positions are relative to that file.
func (s *Scanner) Scan() (pos token.Pos, tok token.Token, lit string) {
scanAgain:
	s.skipWhitespace()

	// Start of current token.
	pos = s.file.Pos(s.offset)

	// insertTerm records whether the token scanned by this call may end a
	// statement; it is copied into s.insertTerm at the end of the function so
	// that the next newline (or EOF) is reported as a TERMINATOR.
	var insertTerm bool

	// Determine token value
	switch ch := s.ch; {
	case isLetter(ch):
		lit = s.scanIdentifier()
		if len(lit) > 1 { // Keywords are always > 1 char
			tok = token.Lookup(lit)
			switch tok {
			case token.IDENT, token.NULL, token.BOOL:
				insertTerm = true
			}
		} else {
			insertTerm = true
			tok = token.IDENT
		}

	case isDecimal(ch) || (ch == '.' && isDecimal(rune(s.peek()))):
		insertTerm = true
		tok, lit = s.scanNumber()

	default:
		s.next() // Make progress

		// ch is now the first character in a sequence and s.ch is the second
		// character.

		switch ch {
		case eof:
			// A pending terminator is flushed first; the following call to Scan
			// then returns EOF.
			if s.insertTerm {
				s.insertTerm = false // Consumed EOF
				return pos, token.TERMINATOR, "\n"
			}
			tok = token.EOF

		case '\n':
			// This case is only reachable when s.insertTerm is true, since otherwise
			// skipWhitespace consumes all other newlines.
			s.insertTerm = false // Consumed newline
			return pos, token.TERMINATOR, "\n"

		case '\'':
			// Single-quoted strings are rejected, but still scanned to their end
			// so that scanning can resume after them.
			s.onError(pos.Offset(), "illegal single-quoted string; use double quotes")
			insertTerm = true
			tok = token.ILLEGAL
			lit = s.scanString('\'', true, false)

		case '"':
			insertTerm = true
			tok = token.STRING
			lit = s.scanString('"', true, false)

		case '`':
			// Backtick strings are scanned without escape processing and may span
			// multiple lines.
			insertTerm = true
			tok = token.STRING
			lit = s.scanString('`', false, true)

		case '|':
			// A lone '|' is reported as an error but still yields token.OR.
			if s.ch != '|' {
				s.onError(s.offset, "missing second | in ||")
			} else {
				s.next() // consume second '|'
			}
			tok = token.OR
		case '&':
			// A lone '&' is reported as an error but still yields token.AND.
			if s.ch != '&' {
				s.onError(s.offset, "missing second & in &&")
			} else {
				s.next() // consume second '&'
			}
			tok = token.AND

		case '!': // !, !=
			tok = s.switch2(token.NOT, token.NEQ, '=')
		case '=': // =, ==
			tok = s.switch2(token.ASSIGN, token.EQ, '=')
		case '<': // <, <=
			tok = s.switch2(token.LT, token.LTE, '=')
		case '>': // >, >=
			tok = s.switch2(token.GT, token.GTE, '=')
		case '+':
			tok = token.ADD
		case '-':
			tok = token.SUB
		case '*':
			tok = token.MUL
		case '/':
			if s.ch == '/' || s.ch == '*' {
				// //- or /*-style comment.
				//
				// If we're expected to inject a terminator, we can only do so if our
				// comment goes to the end of the line. Otherwise, the terminator will
				// have to be injected after the comment token.
				if s.insertTerm && s.findLineEnd() {
					// Reset position to the beginning of the comment.
					s.ch = '/'
					s.offset = pos.Offset()
					s.readOffset = s.offset + 1
					s.insertTerm = false // Consumed newline
					return pos, token.TERMINATOR, "\n"
				}
				comment := s.scanComment()
				if s.mode&IncludeComments == 0 {
					// Skip over comment
					s.insertTerm = false // Consumed newline
					goto scanAgain
				}
				tok = token.COMMENT
				lit = comment
			} else {
				tok = token.DIV
			}

		case '%':
			tok = token.MOD
		case '^':
			tok = token.POW
		case '{':
			tok = token.LCURLY
		case '}':
			insertTerm = true
			tok = token.RCURLY
		case '(':
			tok = token.LPAREN
		case ')':
			insertTerm = true
			tok = token.RPAREN
		case '[':
			tok = token.LBRACK
		case ']':
			insertTerm = true
			tok = token.RBRACK
		case ',':
			tok = token.COMMA
		case '.':
			// NOTE: Fractions starting with '.' are handled by outer switch
			tok = token.DOT

		default:
			// s.next() reports invalid BOMs so we don't need to repeat the error.
			if ch != bom {
				s.onError(pos.Offset(), fmt.Sprintf("illegal character %#U", ch))
			}
			insertTerm = s.insertTerm // Preserve previous s.insertTerm state
			tok = token.ILLEGAL
			lit = string(ch)
		}
	}

	// dontInsertTerms (testing only) keeps s.insertTerm permanently false so
	// that no TERMINATOR tokens are produced.
	if s.mode&dontInsertTerms == 0 {
		s.insertTerm = insertTerm
	}
	return
}
// scanIdentifier reads the string of valid identifier characters starting at
// s.offset and returns it. It must only be called when s.ch is a valid
// character which starts an identifier. On return, s.ch holds the first
// character after the identifier (or eof).
//
// scanIdentifier is highly optimized for ASCII identifiers, and modifications
// must be made carefully.
func (s *Scanner) scanIdentifier() string {
	off := s.offset

	// Optimize for common case of ASCII identifiers.
	//
	// Ranging over s.input[s.readOffset:] avoids bounds checks and avoids
	// conversions to runes.
	//
	// We'll fall back to the slower path if we find a non-ASCII character.
	for readOffset, b := range s.input[s.readOffset:] {
		if (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || b == '_' || (b >= '0' && b <= '9') {
			// Common case: ASCII character; don't assign a rune.
			continue
		}
		// readOffset is relative to the sub-slice; make s.readOffset point at b.
		s.readOffset += readOffset
		if b > 0 && b < utf8.RuneSelf {
			// Optimization: ASCII character that isn't a letter or number; we've
			// reached the end of the identifier sequence and can terminate. We avoid
			// the call to s.next() and the corresponding setup.
			//
			// This optimization only works because we know that s.ch (the current
			// character when scanIdentifier was called) is never '\n' since '\n'
			// cannot start an identifier.
			s.ch = rune(b)
			s.offset = s.readOffset
			s.readOffset++
			goto exit
		}

		// b is either NUL or the first byte of a non-ASCII sequence; both are
		// handed to s.next(), which reports NUL bytes and decodes UTF-8.
		//
		// The preceding character is valid for an identifier because
		// scanIdentifier is only called when s.ch is a letter; calling s.next() at
		// s.readOffset will reset the scanner state.
		s.next()
		for isLetter(s.ch) || isDigit(s.ch) {
			s.next()
		}

		// No more valid characters for the identifier; terminate.
		goto exit
	}

	// The identifier runs to the end of the input.
	s.offset = len(s.input)
	s.readOffset = len(s.input)
	s.ch = eof

exit:
	return string(s.input[off:s.offset])
}
{ + tok = token.FLOAT + + s.next() + s.digits() + } + + // Exponent + if lower(s.ch) == 'e' { + tok = token.FLOAT + + s.next() + if s.ch == '+' || s.ch == '-' { + s.next() + } + + if s.digits() == 0 { + s.onError(off, "exponent has no digits") + } + } + + return tok, string(s.input[off:s.offset]) +} + +// digits scans a sequence of digits. +func (s *Scanner) digits() (count int) { + for isDecimal(s.ch) { + s.next() + count++ + } + return +} + +func (s *Scanner) scanString(until rune, escape bool, multiline bool) string { + // subtract 1 to account for the opening '"' which was already consumed by + // the scanner forcing progress. + off := s.offset - 1 + + for { + ch := s.ch + if (!multiline && ch == '\n') || ch == eof { + s.onError(off, "string literal not terminated") + break + } + s.next() + if ch == until { + break + } + if escape && ch == '\\' { + s.scanEscape() + } + } + + return string(s.input[off:s.offset]) +} + +// scanEscape parses an escape sequence. In case of a syntax error, scanEscape +// stops at the offending character without consuming it. 
+func (s *Scanner) scanEscape() { + off := s.offset + + var ( + n int + base, max uint32 + ) + + switch s.ch { + case 'a', 'b', 'f', 'n', 'r', 't', 'v', '\\', '"': + s.next() + return + case '0', '1', '2', '3', '4', '5', '6', '7': + n, base, max = 3, 8, 255 + case 'x': + s.next() + n, base, max = 2, 16, 255 + case 'u': + s.next() + n, base, max = 4, 16, unicode.MaxRune + case 'U': + s.next() + n, base, max = 8, 16, unicode.MaxRune + default: + msg := "unknown escape sequence" + if s.ch == eof { + msg = "escape sequence not terminated" + } + s.onError(off, msg) + return + } + + var x uint32 + for n > 0 { + d := uint32(digitVal(s.ch)) + if d >= base { + msg := fmt.Sprintf("illegal character %#U in escape sequence", s.ch) + if s.ch == eof { + msg = "escape sequence not terminated" + } + s.onError(off, msg) + return + } + x = x*base + d + s.next() + n-- + } + + if x > max || x >= 0xD800 && x < 0xE000 { + s.onError(off, "escape sequence is invalid Unicode code point") + } +} + +func digitVal(ch rune) int { + switch { + case ch >= '0' && ch <= '9': + return int(ch - '0') + case lower(ch) >= 'a' && lower(ch) <= 'f': + return int(lower(ch) - 'a' + 10) + } + return 16 // Larger than any legal digit val +} + +func (s *Scanner) scanComment() string { + // The initial character in the comment was already consumed from the scanner + // forcing progress. + // + // slashComment will be true when the comment is a //- or /*-style comment. + + var ( + off = s.offset - 1 // Offset of initial character + numCR = 0 + + blockComment = false + ) + + if s.ch == '/' { // NOTE: s.ch is second character in comment sequence + // //-style comment. + // + // The final '\n' is not considered to be part of the comment. + if s.ch == '/' { + s.next() // Consume second '/' + } + + for s.ch != '\n' && s.ch != eof { + if s.ch == '\r' { + numCR++ + } + s.next() + } + + goto exit + } + + // /*-style comment. 
// stripCR returns a copy of b with carriage returns removed.
//
// When blockComment is true, a carriage return that sits directly between a
// '*' and a '/' is kept, provided more than the two bytes of the opening "/*"
// have been written so far. Dropping it would join the two characters into
// "*/" inside the comment text.
func stripCR(b []byte, blockComment bool) []byte {
	out := make([]byte, 0, len(b))

	for idx, cur := range b {
		if cur == '\r' {
			guardsTerminator := blockComment &&
				len(out) > len("/*") &&
				out[len(out)-1] == '*' &&
				idx+1 < len(b) &&
				b[idx+1] == '/'
			if !guardsTerminator {
				continue
			}
		}
		out = append(out, cur)
	}

	return out
}
+ s.skipWhitespace() // s.insertTerm is set + if s.ch == eof || s.ch == '\n' { + return true + } + if s.ch != '/' { + // Non-comment token + return false + } + s.next() // Consume '/' at the end of the /* style-comment + } + + return false +} + +// switch2 returns a if s.ch is next, b otherwise. The scanner will be advanced +// if b is returned. +// +// This is used for tokens which can either be a single character but also are +// the starting character for a 2-length token (i.e., = and ==). +func (s *Scanner) switch2(a, b token.Token, next rune) token.Token { //nolint:unparam + if s.ch == next { + s.next() + return b + } + return a +} diff --git a/syntax/scanner/scanner_test.go b/syntax/scanner/scanner_test.go new file mode 100644 index 0000000000..38ddcf58ca --- /dev/null +++ b/syntax/scanner/scanner_test.go @@ -0,0 +1,272 @@ +package scanner + +import ( + "path/filepath" + "testing" + + "github.com/grafana/river/token" + "github.com/stretchr/testify/assert" +) + +type tokenExample struct { + tok token.Token + lit string +} + +var tokens = []tokenExample{ + // Special tokens + {token.COMMENT, "/* a comment */"}, + {token.COMMENT, "// a comment \n"}, + {token.COMMENT, "/*\r*/"}, + {token.COMMENT, "/**\r/*/"}, // golang/go#11151 + {token.COMMENT, "/**\r\r/*/"}, + {token.COMMENT, "//\r\n"}, + + // Identifiers and basic type literals + {token.IDENT, "foobar"}, + {token.IDENT, "a۰۱۸"}, + {token.IDENT, "foo६४"}, + {token.IDENT, "bar9876"}, + {token.IDENT, "ŝ"}, // golang/go#4000 + {token.IDENT, "ŝfoo"}, // golang/go#4000 + {token.NUMBER, "0"}, + {token.NUMBER, "1"}, + {token.NUMBER, "123456789012345678890"}, + {token.NUMBER, "01234567"}, + {token.FLOAT, "0."}, + {token.FLOAT, ".0"}, + {token.FLOAT, "3.14159265"}, + {token.FLOAT, "1e0"}, + {token.FLOAT, "1e+100"}, + {token.FLOAT, "1e-100"}, + {token.FLOAT, "2.71828e-1000"}, + {token.STRING, `"Hello, world!"`}, + {token.STRING, "`Hello, world!\\\\`"}, + + // Operators and delimiters + {token.ADD, "+"}, + {token.SUB, 
"-"}, + {token.MUL, "*"}, + {token.DIV, "/"}, + {token.MOD, "%"}, + {token.POW, "^"}, + + {token.AND, "&&"}, + {token.OR, "||"}, + + {token.EQ, "=="}, + {token.LT, "<"}, + {token.GT, ">"}, + {token.ASSIGN, "="}, + {token.NOT, "!"}, + + {token.NEQ, "!="}, + {token.LTE, "<="}, + {token.GTE, ">="}, + + {token.LPAREN, "("}, + {token.LBRACK, "["}, + {token.LCURLY, "{"}, + {token.COMMA, ","}, + {token.DOT, "."}, + + {token.RPAREN, ")"}, + {token.RBRACK, "]"}, + {token.RCURLY, "}"}, + + // Keywords + {token.NULL, "null"}, + {token.BOOL, "true"}, + {token.BOOL, "false"}, +} + +const whitespace = " \t \n\n\n" // Various whitespace to separate tokens + +var source = func() []byte { + var src []byte + for _, t := range tokens { + src = append(src, t.lit...) + src = append(src, whitespace...) + } + return src +}() + +// FuzzScanner ensures that the scanner will always be able to reach EOF +// regardless of input. +func FuzzScanner(f *testing.F) { + // Add each token into the corpus + for _, t := range tokens { + f.Add([]byte(t.lit)) + } + // Then add the entire source + f.Add(source) + + f.Fuzz(func(t *testing.T, input []byte) { + f := token.NewFile(t.Name()) + + s := New(f, input, nil, IncludeComments) + + for { + _, tok, _ := s.Scan() + if tok == token.EOF { + break + } + } + }) +} + +func TestScanner_Scan(t *testing.T) { + whitespaceLinecount := newlineCount(whitespace) + + var eh ErrorHandler = func(_ token.Pos, msg string) { + t.Errorf("ErrorHandler called (msg = %s)", msg) + } + + f := token.NewFile(t.Name()) + s := New(f, source, eh, IncludeComments|dontInsertTerms) + + // Configure expected position + expectPos := token.Position{ + Filename: t.Name(), + Offset: 0, + Line: 1, + Column: 1, + } + + index := 0 + for { + pos, tok, lit := s.Scan() + + // Check position + checkPos(t, lit, tok, pos, expectPos) + + // Check token + e := tokenExample{token.EOF, ""} + if index < len(tokens) { + e = tokens[index] + index++ + } + assert.Equal(t, e.tok, tok) + + // Check literal + 
// newlineCount reports how many '\n' bytes s contains.
func newlineCount(s string) int {
	count := 0
	for _, b := range []byte(s) {
		if b == '\n' {
			count++
		}
	}
	return count
}
+ if pos.Filename != expected.Filename && filepath.Clean(pos.Filename) != filepath.Clean(expected.Filename) { + assert.Equal(t, expected.Filename, pos.Filename, "Bad filename for %s (%q)", tok, lit) + } + + assert.Equal(t, expected.Offset, pos.Offset, "Bad offset for %s (%q)", tok, lit) + assert.Equal(t, expected.Line, pos.Line, "Bad line for %s (%q)", tok, lit) + assert.Equal(t, expected.Column, pos.Column, "Bad column for %s (%q)", tok, lit) +} + +var errorTests = []struct { + input string + tok token.Token + pos int + lit string + err string +}{ + {"\a", token.ILLEGAL, 0, "", "illegal character U+0007"}, + {`…`, token.ILLEGAL, 0, "", "illegal character U+2026 '…'"}, + {"..", token.DOT, 0, "", ""}, // two periods, not invalid token (golang/go#28112) + {`'illegal string'`, token.ILLEGAL, 0, "", "illegal single-quoted string; use double quotes"}, + {`""`, token.STRING, 0, `""`, ""}, + {`"abc`, token.STRING, 0, `"abc`, "string literal not terminated"}, + {"\"abc\n", token.STRING, 0, `"abc`, "string literal not terminated"}, + {"\"abc\n ", token.STRING, 0, `"abc`, "string literal not terminated"}, + {"\"abc\x00def\"", token.STRING, 4, "\"abc\x00def\"", "illegal character NUL"}, + {"\"abc\x80def\"", token.STRING, 4, "\"abc\x80def\"", "illegal UTF-8 encoding"}, + {"\ufeff\ufeff", token.ILLEGAL, 3, "\ufeff\ufeff", "illegal byte order mark"}, // only first BOM is ignored + {"//\ufeff", token.COMMENT, 2, "//\ufeff", "illegal byte order mark"}, // only first BOM is ignored + {`"` + "abc\ufeffdef" + `"`, token.STRING, 4, `"` + "abc\ufeffdef" + `"`, "illegal byte order mark"}, // only first BOM is ignored + {"abc\x00def", token.IDENT, 3, "abc", "illegal character NUL"}, + {"abc\x00", token.IDENT, 3, "abc", "illegal character NUL"}, + {"10E", token.FLOAT, 0, "10E", "exponent has no digits"}, +} + +func TestScanner_Scan_Errors(t *testing.T) { + for _, e := range errorTests { + checkError(t, e.input, e.tok, e.pos, e.lit, e.err) + } +} + +func checkError(t *testing.T, src 
import (
	"bytes"
	"fmt"
	"io"
	"reflect"
	"sort"
	"strings"

	"github.com/grafana/river/internal/reflectutil"
	"github.com/grafana/river/internal/rivertags"
	"github.com/grafana/river/internal/value"
	"github.com/grafana/river/token"
)
If any value +// reachable from goValue implements Tokenizer, the printed tokens will instead +// be retrieved by calling the RiverTokenize method. +func (e *Expr) SetValue(goValue interface{}) { + e.rawTokens = tokenEncode(goValue) +} + +// WriteTo renders and formats the File, writing the contents to w. +func (e *Expr) WriteTo(w io.Writer) (int64, error) { + n, err := printExprTokens(w, e.Tokens()) + return int64(n), err +} + +// Bytes renders the File to a formatted byte slice. +func (e *Expr) Bytes() []byte { + var buf bytes.Buffer + _, _ = e.WriteTo(&buf) + return buf.Bytes() +} + +// A File represents a River configuration file. +type File struct { + body *Body +} + +// NewFile creates a new File. +func NewFile() *File { return &File{body: newBody()} } + +// Tokens returns the File as a set of Tokens. +func (f *File) Tokens() []Token { return f.Body().Tokens() } + +// Body returns the Body contents of the file. +func (f *File) Body() *Body { return f.body } + +// WriteTo renders and formats the File, writing the contents to w. +func (f *File) WriteTo(w io.Writer) (int64, error) { + n, err := printFileTokens(w, f.Tokens()) + return int64(n), err +} + +// Bytes renders the File to a formatted byte slice. +func (f *File) Bytes() []byte { + var buf bytes.Buffer + _, _ = f.WriteTo(&buf) + return buf.Bytes() +} + +// Body is a list of block and attribute statements. A Body cannot be manually +// created, but is retrieved from a File or Block. +type Body struct { + nodes []tokenNode + valueOverrideHook ValueOverrideHook +} + +type ValueOverrideHook = func(val interface{}) interface{} + +// SetValueOverrideHook sets a hook to override the value that will be token +// encoded. The hook can mutate the value to be encoded or should return it +// unmodified. This hook can be skipped by leaving it nil or setting it to nil. 
+func (b *Body) SetValueOverrideHook(valueOverrideHook ValueOverrideHook) { + b.valueOverrideHook = valueOverrideHook +} + +func (b *Body) Nodes() []tokenNode { + return b.nodes +} + +// A tokenNode is a structural element which can be converted into a set of +// Tokens. +type tokenNode interface { + // Tokens builds the set of Tokens from the node. + Tokens() []Token +} + +func newBody() *Body { + return &Body{} +} + +// Tokens returns the File as a set of Tokens. +func (b *Body) Tokens() []Token { + var rawToks []Token + for i, node := range b.nodes { + rawToks = append(rawToks, node.Tokens()...) + + if i+1 < len(b.nodes) { + // Append a terminator between each statement in the Body. + rawToks = append(rawToks, Token{ + Tok: token.LITERAL, + Lit: "\n", + }) + } + } + return rawToks +} + +// AppendTokens appends raw tokens to the Body. +func (b *Body) AppendTokens(tokens []Token) { + b.nodes = append(b.nodes, tokensSlice(tokens)) +} + +// AppendBlock adds a new block inside of the Body. +func (b *Body) AppendBlock(block *Block) { + b.nodes = append(b.nodes, block) +} + +// AppendFrom sets attributes and appends blocks defined by goValue into the +// Body. If any value reachable from goValue implements Tokenizer, the printed +// tokens will instead be retrieved by calling the RiverTokenize method. +// +// Optional attributes and blocks set to default values are trimmed. +// If goValue implements Defaulter, default values are retrieved by +// calling SetToDefault against a copy. Otherwise, default values are +// the zero value of the respective Go types. +// +// goValue must be a struct or a pointer to a struct that contains River struct +// tags. +func (b *Body) AppendFrom(goValue interface{}) { + if goValue == nil { + return + } + + rv := reflect.ValueOf(goValue) + b.encodeFields(rv) +} + +// getBlockLabel returns the label for a given block. 
+func getBlockLabel(rv reflect.Value) string { + tags := rivertags.Get(rv.Type()) + for _, tag := range tags { + if tag.Flags&rivertags.FlagLabel != 0 { + return reflectutil.Get(rv, tag).String() + } + } + + return "" +} + +func (b *Body) encodeFields(rv reflect.Value) { + for rv.Kind() == reflect.Pointer { + if rv.IsNil() { + return + } + rv = rv.Elem() + } + if rv.Kind() != reflect.Struct { + panic(fmt.Sprintf("river/token/builder: can only encode struct values to bodies, got %s", rv.Type())) + } + + fields := rivertags.Get(rv.Type()) + defaults := reflect.New(rv.Type()).Elem() + if defaults.CanAddr() && defaults.Addr().Type().Implements(goRiverDefaulter) { + defaults.Addr().Interface().(value.Defaulter).SetToDefault() + } + + for _, field := range fields { + fieldVal := reflectutil.Get(rv, field) + fieldValDefault := reflectutil.Get(defaults, field) + + // Check if the values are exactly equal or if they're both equal to the + // zero value. Checking for both fields being zero handles the case where + // an empty and nil map are being compared (which are not equal, but are + // both zero values). + matchesDefault := reflect.DeepEqual(fieldVal.Interface(), fieldValDefault.Interface()) + isZero := fieldValDefault.IsZero() && fieldVal.IsZero() + + if field.IsOptional() && (matchesDefault || isZero) { + continue + } + + b.encodeField(nil, field, fieldVal) + } +} + +func (b *Body) encodeField(prefix []string, field rivertags.Field, fieldValue reflect.Value) { + fieldName := strings.Join(field.Name, ".") + + for fieldValue.Kind() == reflect.Pointer { + if fieldValue.IsNil() { + break + } + fieldValue = fieldValue.Elem() + } + + switch { + case field.IsAttr(): + b.SetAttributeValue(fieldName, fieldValue.Interface()) + + case field.IsBlock(): + fullName := mergeStringSlice(prefix, field.Name) + + switch { + case fieldValue.Kind() == reflect.Map: + // Iterate over the map and add each element as an attribute into it. NOTE(review): MapRange order is unspecified, so attributes are appended in a non-deterministic order when the map has several keys; confirm whether keys should be sorted first.
+ if fieldValue.Type().Key().Kind() != reflect.String { + panic("river/token/builder: unsupported map type for block; expected map[string]T, got " + fieldValue.Type().String()) + } + + inner := NewBlock(fullName, "") + inner.body.SetValueOverrideHook(b.valueOverrideHook) + b.AppendBlock(inner) + + iter := fieldValue.MapRange() + for iter.Next() { + mapKey, mapValue := iter.Key(), iter.Value() + inner.body.SetAttributeValue(mapKey.String(), mapValue.Interface()) + } + + case fieldValue.Kind() == reflect.Slice, fieldValue.Kind() == reflect.Array: + for i := 0; i < fieldValue.Len(); i++ { + elem := fieldValue.Index(i) + + // Recursively call encodeField for each element in the slice/array for + // non-zero blocks. The recursive call will hit the case below and add + // a new block for each field encountered. + if field.Flags&rivertags.FlagOptional != 0 && elem.IsZero() { + continue + } + b.encodeField(prefix, field, elem) + } + + case fieldValue.Kind() == reflect.Struct: + inner := NewBlock(fullName, getBlockLabel(fieldValue)) + inner.body.SetValueOverrideHook(b.valueOverrideHook) + inner.Body().encodeFields(fieldValue) + b.AppendBlock(inner) + } + + case field.IsEnum(): + // Blocks within an enum have a prefix set. + newPrefix := mergeStringSlice(prefix, field.Name) + + switch { + case fieldValue.Kind() == reflect.Slice, fieldValue.Kind() == reflect.Array: + for i := 0; i < fieldValue.Len(); i++ { + b.encodeEnumElement(newPrefix, fieldValue.Index(i)) + } + + default: + panic(fmt.Sprintf("river/token/builder: unrecognized enum kind %s", fieldValue.Kind())) + } + } +} + +func mergeStringSlice(a, b []string) []string { + if len(a) == 0 { + return b + } else if len(b) == 0 { + return a + } + + res := make([]string, 0, len(a)+len(b)) + res = append(res, a...) + res = append(res, b...)
+ return res +} + +func (b *Body) encodeEnumElement(prefix []string, enumElement reflect.Value) { + for enumElement.Kind() == reflect.Pointer { + if enumElement.IsNil() { + return + } + enumElement = enumElement.Elem() + } + + fields := rivertags.Get(enumElement.Type()) + + // Find the first non-zero field and encode it. + for _, field := range fields { + fieldVal := reflectutil.Get(enumElement, field) + if !fieldVal.IsValid() || fieldVal.IsZero() { + continue + } + + b.encodeField(prefix, field, fieldVal) + break + } +} + +// SetAttributeTokens sets an attribute to the Body whose value is a set of raw +// tokens. If the attribute was previously set, its value tokens are updated. +// +// Attributes will be written out in the order they were initially created. +func (b *Body) SetAttributeTokens(name string, tokens []Token) { + attr := b.getOrCreateAttribute(name) + attr.RawTokens = tokens +} + +func (b *Body) getOrCreateAttribute(name string) *attribute { + for _, n := range b.nodes { + if attr, ok := n.(*attribute); ok && attr.Name == name { + return attr + } + } + + newAttr := &attribute{Name: name} + b.nodes = append(b.nodes, newAttr) + return newAttr +} + +// SetAttributeValue sets an attribute in the Body whose value is converted +// from a Go value to a River value. The Go value is encoded using the normal +// Go to River encoding rules. If any value reachable from goValue implements +// Tokenizer, the printed tokens will instead be retrieved by calling the +// RiverTokenize method. +// +// If the attribute was previously set, its value tokens are updated. +// +// Attributes will be written out in the order they were initially created.
+func (b *Body) SetAttributeValue(name string, goValue interface{}) { + attr := b.getOrCreateAttribute(name) + + if b.valueOverrideHook != nil { + attr.RawTokens = tokenEncode(b.valueOverrideHook(goValue)) + } else { + attr.RawTokens = tokenEncode(goValue) + } +} + +type attribute struct { + Name string + RawTokens []Token +} + +func (attr *attribute) Tokens() []Token { + var toks []Token + + toks = append(toks, Token{Tok: token.IDENT, Lit: attr.Name}) + toks = append(toks, Token{Tok: token.ASSIGN}) + toks = append(toks, attr.RawTokens...) + + return toks +} + +// A Block encapsulates a body within a named and labeled River block. Blocks +// must be created by calling NewBlock, but its public struct fields may be +// safely modified by callers. +type Block struct { + // Public fields, safe to be changed by callers: + + Name []string + Label string + + // Private fields: + + body *Body +} + +// NewBlock returns a new Block with the given name and label. The name/label +// can be updated later by modifying the Block's public fields. +func NewBlock(name []string, label string) *Block { + return &Block{ + Name: name, + Label: label, + + body: newBody(), + } +} + +// Tokens returns the Block as a set of Tokens: its dot-separated name, its quoted label (when non-empty), and its body wrapped in curly braces. +func (b *Block) Tokens() []Token { + var toks []Token + + for i, frag := range b.Name { + toks = append(toks, Token{Tok: token.IDENT, Lit: frag}) + if i+1 < len(b.Name) { + toks = append(toks, Token{Tok: token.DOT}) + } + } + + toks = append(toks, Token{Tok: token.LITERAL, Lit: " "}) + + if b.Label != "" { + toks = append(toks, Token{Tok: token.STRING, Lit: fmt.Sprintf("%q", b.Label)}) + } + + toks = append(toks, Token{Tok: token.LCURLY}, Token{Tok: token.LITERAL, Lit: "\n"}) + toks = append(toks, b.body.Tokens()...) + toks = append(toks, Token{Tok: token.LITERAL, Lit: "\n"}, Token{Tok: token.RCURLY}) + + return toks +} + +// Body returns the Body contained within the Block.
+func (b *Block) Body() *Body { return b.body } + +type tokensSlice []Token + +func (tn tokensSlice) Tokens() []Token { return []Token(tn) } diff --git a/syntax/token/builder/builder_test.go b/syntax/token/builder/builder_test.go new file mode 100644 index 0000000000..d363556929 --- /dev/null +++ b/syntax/token/builder/builder_test.go @@ -0,0 +1,411 @@ +package builder_test + +import ( + "bytes" + "fmt" + "testing" + "time" + + "github.com/grafana/river/parser" + "github.com/grafana/river/printer" + "github.com/grafana/river/token" + "github.com/grafana/river/token/builder" + "github.com/stretchr/testify/require" +) + +func TestBuilder_File(t *testing.T) { + f := builder.NewFile() + + f.Body().SetAttributeTokens("attr_1", []builder.Token{{Tok: token.NUMBER, Lit: "15"}}) + f.Body().SetAttributeTokens("attr_2", []builder.Token{{Tok: token.BOOL, Lit: "true"}}) + + b1 := builder.NewBlock([]string{"test", "block"}, "") + b1.Body().SetAttributeTokens("inner_attr", []builder.Token{{Tok: token.STRING, Lit: `"block 1"`}}) + f.Body().AppendBlock(b1) + + b2 := builder.NewBlock([]string{"test", "block"}, "labeled") + b2.Body().SetAttributeTokens("inner_attr", []builder.Token{{Tok: token.STRING, Lit: "`\"block 2`"}}) + f.Body().AppendBlock(b2) + + expect := format(t, ` + attr_1 = 15 + attr_2 = true + + test.block { + inner_attr = "block 1" + } + + test.block "labeled" { + inner_attr = `+"`\"block 2`"+` + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func TestBuilder_GoEncode(t *testing.T) { + f := builder.NewFile() + + f.Body().AppendTokens([]builder.Token{{token.COMMENT, "// Hello, world!"}}) + f.Body().SetAttributeValue("null_value", nil) + f.Body().AppendTokens([]builder.Token{{token.LITERAL, "\n"}}) + + f.Body().SetAttributeValue("num", 15) + f.Body().SetAttributeValue("string", "Hello, world!") + f.Body().SetAttributeValue("bool", true) + f.Body().SetAttributeValue("list", []int{0, 1, 2}) + f.Body().SetAttributeValue("func", func(int, int) int { return 0 }) 
+ f.Body().SetAttributeValue("capsule", make(chan int)) + f.Body().AppendTokens([]builder.Token{{token.LITERAL, "\n"}}) + + f.Body().SetAttributeValue("map", map[string]interface{}{"foo": "bar"}) + f.Body().SetAttributeValue("map_2", map[string]interface{}{"non ident": "bar"}) + f.Body().AppendTokens([]builder.Token{{token.LITERAL, "\n"}}) + + f.Body().SetAttributeValue("mixed_list", []interface{}{ + 0, + true, + map[string]interface{}{"key": true}, + "Hello!", + }) + + expect := format(t, ` + // Hello, world! + null_value = null + + num = 15 + string = "Hello, world!" + bool = true + list = [0, 1, 2] + func = function + capsule = capsule("chan int") + + map = { + foo = "bar", + } + map_2 = { + "non ident" = "bar", + } + + mixed_list = [0, true, { + key = true, + }, "Hello!"] + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +// TestBuilder_GoEncode_SortMapKeys ensures that object literals from unordered +// values (i.e., Go maps) are printed in a deterministic order by sorting the +// keys lexicographically. Other object literals should be printed in the order +// the keys are reported in (i.e., in the order presented by the Go structs). +func TestBuilder_GoEncode_SortMapKeys(t *testing.T) { + f := builder.NewFile() + + type Ordered struct { + SomeKey string `river:"some_key,attr"` + OtherKey string `river:"other_key,attr"` + } + + // Maps are unordered because you can't iterate over their keys in a + // consistent order. 
+ var unordered = map[string]interface{}{ + "key_a": 1, + "key_c": 3, + "key_b": 2, + } + + f.Body().SetAttributeValue("ordered", Ordered{SomeKey: "foo", OtherKey: "bar"}) + f.Body().SetAttributeValue("unordered", unordered) + + expect := format(t, ` + ordered = { + some_key = "foo", + other_key = "bar", + } + unordered = { + key_a = 1, + key_b = 2, + key_c = 3, + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func TestBuilder_AppendFrom(t *testing.T) { + type InnerBlock struct { + Number int `river:"number,attr"` + } + + type Structure struct { + Field string `river:"field,attr"` + + Block InnerBlock `river:"block,block"` + OtherBlocks []InnerBlock `river:"other_block,block"` + } + + f := builder.NewFile() + f.Body().AppendFrom(Structure{ + Field: "some_value", + + Block: InnerBlock{Number: 1}, + OtherBlocks: []InnerBlock{ + {Number: 2}, + {Number: 3}, + }, + }) + + expect := format(t, ` + field = "some_value" + + block { + number = 1 + } + + other_block { + number = 2 + } + + other_block { + number = 3 + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func TestBuilder_AppendFrom_EnumSlice(t *testing.T) { + type InnerBlock struct { + Number int `river:"number,attr"` + } + + type EnumBlock struct { + BlockA InnerBlock `river:"a,block,optional"` + BlockB InnerBlock `river:"b,block,optional"` + BlockC InnerBlock `river:"c,block,optional"` + } + + type Structure struct { + Field string `river:"field,attr"` + + OtherBlocks []EnumBlock `river:"block,enum"` + } + + f := builder.NewFile() + f.Body().AppendFrom(Structure{ + Field: "some_value", + OtherBlocks: []EnumBlock{ + {BlockC: InnerBlock{Number: 1}}, + {BlockB: InnerBlock{Number: 2}}, + {BlockC: InnerBlock{Number: 3}}, + }, + }) + + expect := format(t, ` + field = "some_value" + + block.c { + number = 1 + } + + block.b { + number = 2 + } + + block.c { + number = 3 + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func TestBuilder_AppendFrom_EnumSlice_Pointer(t *testing.T) { + 
type InnerBlock struct { + Number int `river:"number,attr"` + } + + type EnumBlock struct { + BlockA *InnerBlock `river:"a,block,optional"` + BlockB *InnerBlock `river:"b,block,optional"` + BlockC *InnerBlock `river:"c,block,optional"` + } + + type Structure struct { + Field string `river:"field,attr"` + + OtherBlocks []EnumBlock `river:"block,enum"` + } + + f := builder.NewFile() + f.Body().AppendFrom(Structure{ + Field: "some_value", + OtherBlocks: []EnumBlock{ + {BlockC: &InnerBlock{Number: 1}}, + {BlockB: &InnerBlock{Number: 2}}, + {BlockC: &InnerBlock{Number: 3}}, + }, + }) + + expect := format(t, ` + field = "some_value" + + block.c { + number = 1 + } + + block.b { + number = 2 + } + + block.c { + number = 3 + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func TestBuilder_SkipOptional(t *testing.T) { + type Structure struct { + OptFieldA string `river:"opt_field_a,attr,optional"` + OptFieldB string `river:"opt_field_b,attr,optional"` + ReqFieldA string `river:"req_field_a,attr"` + ReqFieldB string `river:"req_field_b,attr"` + } + + f := builder.NewFile() + f.Body().AppendFrom(Structure{ + OptFieldA: "some_value", + OptFieldB: "", // Zero value + ReqFieldA: "some_value", + ReqFieldB: "", // Zero value + }) + + expect := format(t, ` + opt_field_a = "some_value" + req_field_a = "some_value" + req_field_b = "" + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func format(t *testing.T, in string) string { + t.Helper() + + f, err := parser.ParseFile(t.Name(), []byte(in)) + require.NoError(t, err) + + var buf bytes.Buffer + require.NoError(t, printer.Fprint(&buf, f)) + + return buf.String() +} + +type CustomTokenizer bool + +var _ builder.Tokenizer = (CustomTokenizer)(false) + +func (ct CustomTokenizer) RiverTokenize() []builder.Token { + return []builder.Token{{Tok: token.LITERAL, Lit: "CUSTOM_TOKENS"}} +} + +func TestBuilder_GoEncode_Tokenizer(t *testing.T) { + t.Run("Tokenizer", func(t *testing.T) { + f := builder.NewFile() + 
f.Body().SetAttributeValue("value", CustomTokenizer(true)) + + expect := format(t, `value = CUSTOM_TOKENS`) + require.Equal(t, expect, string(f.Bytes())) + }) + + t.Run("TextMarshaler", func(t *testing.T) { + now := time.Now() + expectBytes, err := now.MarshalText() + require.NoError(t, err) + + f := builder.NewFile() + f.Body().SetAttributeValue("value", now) + + expect := format(t, fmt.Sprintf(`value = %q`, string(expectBytes))) + require.Equal(t, expect, string(f.Bytes())) + }) + + t.Run("Duration", func(t *testing.T) { + dur := 15 * time.Second + + f := builder.NewFile() + f.Body().SetAttributeValue("value", dur) + + expect := format(t, fmt.Sprintf(`value = %q`, dur.String())) + require.Equal(t, expect, string(f.Bytes())) + }) +} + +func TestBuilder_ValueOverrideHook(t *testing.T) { + type InnerBlock struct { + AnotherField string `river:"another_field,attr"` + } + + type Structure struct { + Field string `river:"field,attr"` + + Block InnerBlock `river:"block,block"` + OtherBlocks []InnerBlock `river:"other_block,block"` + } + + f := builder.NewFile() + f.Body().SetValueOverrideHook(func(val interface{}) interface{} { + return "some other value" + }) + f.Body().AppendFrom(Structure{ + Field: "some_value", + + Block: InnerBlock{AnotherField: "some_value"}, + OtherBlocks: []InnerBlock{ + {AnotherField: "some_value"}, + {AnotherField: "some_value"}, + }, + }) + + expect := format(t, ` + field = "some other value" + + block { + another_field = "some other value" + } + + other_block { + another_field = "some other value" + } + + other_block { + another_field = "some other value" + } + `) + + require.Equal(t, expect, string(f.Bytes())) +} + +func TestBuilder_MapBlocks(t *testing.T) { + type block struct { + Value map[string]any `river:"block,block,optional"` + } + + f := builder.NewFile() + f.Body().AppendFrom(block{ + Value: map[string]any{ + "field": "value", + }, + }) + + expect := format(t, ` + block { + field = "value" + } + `) + + require.Equal(t, expect, 
string(f.Bytes())) +} diff --git a/syntax/token/builder/nested_defaults_test.go b/syntax/token/builder/nested_defaults_test.go new file mode 100644 index 0000000000..1fd8122b28 --- /dev/null +++ b/syntax/token/builder/nested_defaults_test.go @@ -0,0 +1,233 @@ +package builder_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/grafana/river/ast" + "github.com/grafana/river/parser" + "github.com/grafana/river/token/builder" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +const ( + defaultNumber = 123 + otherDefaultNumber = 321 +) + +var testCases = []struct { + name string + input interface{} + expectedRiver string +}{ + { + name: "struct propagating default - input matching default", + input: StructPropagatingDefault{Inner: AttrWithDefault{Number: defaultNumber}}, + expectedRiver: "", + }, + { + name: "struct propagating default - input with zero-value struct", + input: StructPropagatingDefault{}, + expectedRiver: ` + inner { + number = 0 + } + `, + }, + { + name: "struct propagating default - input with non-default value", + input: StructPropagatingDefault{Inner: AttrWithDefault{Number: 42}}, + expectedRiver: ` + inner { + number = 42 + } + `, + }, + { + name: "pointer propagating default - input matching default", + input: PtrPropagatingDefault{Inner: &AttrWithDefault{Number: defaultNumber}}, + expectedRiver: "", + }, + { + name: "pointer propagating default - input with zero value", + input: PtrPropagatingDefault{Inner: &AttrWithDefault{}}, + expectedRiver: ` + inner { + number = 0 + } + `, + }, + { + name: "pointer propagating default - input with non-default value", + input: PtrPropagatingDefault{Inner: &AttrWithDefault{Number: 42}}, + expectedRiver: ` + inner { + number = 42 + } + `, + }, + { + name: "zero default - input with zero value", + input: ZeroDefault{Inner: &AttrWithDefault{}}, + expectedRiver: "", + }, + { + name: "zero default - input with non-default value", + input: ZeroDefault{Inner: 
&AttrWithDefault{Number: 42}}, + expectedRiver: ` + inner { + number = 42 + } + `, + }, + { + name: "no default - input with zero value", + input: NoDefaultDefined{Inner: &AttrWithDefault{}}, + expectedRiver: ` + inner { + number = 0 + } + `, + }, + { + name: "no default - input with non-default value", + input: NoDefaultDefined{Inner: &AttrWithDefault{Number: 42}}, + expectedRiver: ` + inner { + number = 42 + } + `, + }, + { + name: "mismatching default - input matching outer default", + input: MismatchingDefault{Inner: &AttrWithDefault{Number: otherDefaultNumber}}, + expectedRiver: "", + }, + { + name: "mismatching default - input matching inner default", + input: MismatchingDefault{Inner: &AttrWithDefault{Number: defaultNumber}}, + expectedRiver: "inner { }", + }, + { + name: "mismatching default - input with non-default value", + input: MismatchingDefault{Inner: &AttrWithDefault{Number: 42}}, + expectedRiver: ` + inner { + number = 42 + } + `, + }, +} + +func TestNestedDefaults(t *testing.T) { + for _, tc := range testCases { + t.Run(fmt.Sprintf("%T/%s", tc.input, tc.name), func(t *testing.T) { + f := builder.NewFile() + f.Body().AppendFrom(tc.input) + actualRiver := string(f.Bytes()) + expected := format(t, tc.expectedRiver) + require.Equal(t, expected, actualRiver, "generated river didn't match expected") + + // Now decode the River produced above and make sure it's the same as the input. + eval := vm.New(parseBlock(t, actualRiver)) + vPtr := reflect.New(reflect.TypeOf(tc.input)).Interface() + require.NoError(t, eval.Evaluate(nil, vPtr), "river evaluation error") + + actualOut := reflect.ValueOf(vPtr).Elem().Interface() + require.Equal(t, tc.input, actualOut, "Invariant violated: encoded and then decoded block didn't match the original value") + }) + } +} + +func TestPtrPropagatingDefaultWithNil(t *testing.T) { + // This is a special case - when defaults are correctly defined, the `Inner: nil` should mean to use defaults. 
+ // Encoding will encode to empty string and decoding will produce the default value - `Inner: {Number: 123}`. + input := PtrPropagatingDefault{} + expectedEncodedRiver := "" + expectedDecoded := PtrPropagatingDefault{Inner: &AttrWithDefault{Number: 123}} + + f := builder.NewFile() + f.Body().AppendFrom(input) + actualRiver := string(f.Bytes()) + expected := format(t, expectedEncodedRiver) + require.Equal(t, expected, actualRiver, "generated river didn't match expected") + + // Now decode the River produced above and make sure it's the same as the input. + eval := vm.New(parseBlock(t, actualRiver)) + vPtr := reflect.New(reflect.TypeOf(input)).Interface() + require.NoError(t, eval.Evaluate(nil, vPtr), "river evaluation error") + + actualOut := reflect.ValueOf(vPtr).Elem().Interface() + require.Equal(t, expectedDecoded, actualOut) +} + +// StructPropagatingDefault has the outer defaults matching the inner block's defaults. The inner block is a struct. +type StructPropagatingDefault struct { + Inner AttrWithDefault `river:"inner,block,optional"` +} + +func (o *StructPropagatingDefault) SetToDefault() { + inner := &AttrWithDefault{} + inner.SetToDefault() + *o = StructPropagatingDefault{Inner: *inner} +} + +// PtrPropagatingDefault has the outer defaults matching the inner block's defaults. The inner block is a pointer. +type PtrPropagatingDefault struct { + Inner *AttrWithDefault `river:"inner,block,optional"` +} + +func (o *PtrPropagatingDefault) SetToDefault() { + inner := &AttrWithDefault{} + inner.SetToDefault() + *o = PtrPropagatingDefault{Inner: inner} +} + +// MismatchingDefault has the outer defaults NOT matching the inner block's defaults. The inner block is a pointer. 
+type MismatchingDefault struct { + Inner *AttrWithDefault `river:"inner,block,optional"` +} + +func (o *MismatchingDefault) SetToDefault() { + *o = MismatchingDefault{Inner: &AttrWithDefault{ + Number: otherDefaultNumber, + }} +} + +// ZeroDefault has the outer defaults setting to zero values. The inner block is a pointer. +type ZeroDefault struct { + Inner *AttrWithDefault `river:"inner,block,optional"` +} + +func (o *ZeroDefault) SetToDefault() { + *o = ZeroDefault{Inner: &AttrWithDefault{}} +} + +// NoDefaultDefined has no defaults defined. The inner block is a pointer. +type NoDefaultDefined struct { + Inner *AttrWithDefault `river:"inner,block,optional"` +} + +// AttrWithDefault has a default value of a non-zero number. +type AttrWithDefault struct { + Number int `river:"number,attr,optional"` +} + +func (i *AttrWithDefault) SetToDefault() { + *i = AttrWithDefault{Number: defaultNumber} +} + +func parseBlock(t *testing.T, input string) *ast.BlockStmt { + t.Helper() + + input = fmt.Sprintf("test { %s }", input) + res, err := parser.ParseFile("", []byte(input)) + require.NoError(t, err) + require.Len(t, res.Body, 1) + + stmt, ok := res.Body[0].(*ast.BlockStmt) + require.True(t, ok, "Expected stmt to be a ast.BlockStmt, got %T", res.Body[0]) + return stmt +} diff --git a/syntax/token/builder/token.go b/syntax/token/builder/token.go new file mode 100644 index 0000000000..390b968959 --- /dev/null +++ b/syntax/token/builder/token.go @@ -0,0 +1,81 @@ +package builder + +import ( + "bytes" + "io" + + "github.com/grafana/river/parser" + "github.com/grafana/river/printer" + "github.com/grafana/river/token" +) + +// A Token is a wrapper around token.Token which contains the token type +// alongside its literal. Use token.LITERAL as the Tok field to write literal +// characters such as whitespace. +type Token struct { + Tok token.Token + Lit string +} + +// printFileTokens prints out the tokens as River text and formats them, writing +// the final result to w.
+func printFileTokens(w io.Writer, toks []Token) (int, error) { + var raw bytes.Buffer + for _, tok := range toks { + switch { + case tok.Tok == token.LITERAL: + raw.WriteString(tok.Lit) + case tok.Tok == token.COMMENT: + raw.WriteString(tok.Lit) + case tok.Tok.IsLiteral() || tok.Tok.IsKeyword(): + raw.WriteString(tok.Lit) + default: + raw.WriteString(tok.Tok.String()) + } + } + + f, err := parser.ParseFile("", raw.Bytes()) + if err != nil { + return 0, err + } + + wc := &writerCount{w: w} + err = printer.Fprint(wc, f) + return wc.n, err +} + +// printExprTokens prints out the tokens as River text and formats them, +// writing the final result to w. +func printExprTokens(w io.Writer, toks []Token) (int, error) { + var raw bytes.Buffer + for _, tok := range toks { + switch { + case tok.Tok == token.LITERAL: + raw.WriteString(tok.Lit) + case tok.Tok.IsLiteral() || tok.Tok.IsKeyword(): + raw.WriteString(tok.Lit) + default: + raw.WriteString(tok.Tok.String()) + } + } + + expr, err := parser.ParseExpression(raw.String()) + if err != nil { + return 0, err + } + + wc := &writerCount{w: w} + err = printer.Fprint(wc, expr) + return wc.n, err +} + +type writerCount struct { + w io.Writer + n int +} + +func (wc *writerCount) Write(p []byte) (n int, err error) { + n, err = wc.w.Write(p) + wc.n += n + return +} diff --git a/syntax/token/builder/value_tokens.go b/syntax/token/builder/value_tokens.go new file mode 100644 index 0000000000..c73e34f7b6 --- /dev/null +++ b/syntax/token/builder/value_tokens.go @@ -0,0 +1,95 @@ +package builder + +import ( + "fmt" + "sort" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/scanner" + "github.com/grafana/river/token" +) + +// TODO(rfratto): check for optional values + +// Tokenizer is any value which can return a raw set of tokens. +type Tokenizer interface { + // RiverTokenize returns the raw set of River tokens which are used when + // printing out the value with river/token/builder.
+ RiverTokenize() []Token +} + +func tokenEncode(val interface{}) []Token { + return valueTokens(value.Encode(val)) +} + +func valueTokens(v value.Value) []Token { + var toks []Token + + // If v is a Tokenizer, allow it to override what tokens get generated. + if tk, ok := v.Interface().(Tokenizer); ok { + return tk.RiverTokenize() + } + + switch v.Type() { + case value.TypeNull: + toks = append(toks, Token{token.NULL, "null"}) + + case value.TypeNumber: + toks = append(toks, Token{token.NUMBER, v.Number().ToString()}) + + case value.TypeString: + toks = append(toks, Token{token.STRING, fmt.Sprintf("%q", v.Text())}) + + case value.TypeBool: + toks = append(toks, Token{token.BOOL, fmt.Sprintf("%v", v.Bool())}) // BOOL, not STRING: the literal is the unquoted keyword true/false. + + case value.TypeArray: + toks = append(toks, Token{token.LBRACK, ""}) + elems := v.Len() + for i := 0; i < elems; i++ { + elem := v.Index(i) + + toks = append(toks, valueTokens(elem)...) + if i+1 < elems { + toks = append(toks, Token{token.COMMA, ""}) + } + } + toks = append(toks, Token{token.RBRACK, ""}) + + case value.TypeObject: + toks = append(toks, Token{token.LCURLY, ""}, Token{token.LITERAL, "\n"}) + + keys := v.Keys() + + // If v isn't an ordered object (i.e., a go map), sort the keys so they + // have a deterministic print order. + if !v.OrderedKeys() { + sort.Strings(keys) + } + + for i := 0; i < len(keys); i++ { + if scanner.IsValidIdentifier(keys[i]) { + toks = append(toks, Token{token.IDENT, keys[i]}) + } else { + toks = append(toks, Token{token.STRING, fmt.Sprintf("%q", keys[i])}) + } + + field, _ := v.Key(keys[i]) + toks = append(toks, Token{token.ASSIGN, ""}) + toks = append(toks, valueTokens(field)...)
+ toks = append(toks, Token{token.COMMA, ""}, Token{token.LITERAL, "\n"}) + } + toks = append(toks, Token{token.RCURLY, ""}) + + case value.TypeFunction: + toks = append(toks, Token{token.LITERAL, v.Describe()}) + + case value.TypeCapsule: + toks = append(toks, Token{token.LITERAL, v.Describe()}) + + default: + panic(fmt.Sprintf("river/token/builder: unrecognized value type %q", v.Type())) + } + + return toks +} diff --git a/syntax/token/file.go b/syntax/token/file.go new file mode 100644 index 0000000000..419dbaa57c --- /dev/null +++ b/syntax/token/file.go @@ -0,0 +1,142 @@ +package token + +import ( + "fmt" + "sort" + "strconv" +) + +// NoPos is the zero value for Pos. It has no file or line information +// associated with it, and NoPos.Valid is false. +var NoPos = Pos{} + +// Pos is a compact representation of a position within a file. It can be +// converted into a Position for a more convenient, but larger, representation. +type Pos struct { + file *File + off int +} + +// String returns the string form of the Pos (the offset). +func (p Pos) String() string { return strconv.Itoa(p.off) } + +// File returns the file used by the Pos. This will be nil for invalid +// positions. +func (p Pos) File() *File { return p.file } + +// Position converts the Pos into a Position. +func (p Pos) Position() Position { return p.file.PositionFor(p) } + +// Add creates a new Pos relative to p. +func (p Pos) Add(n int) Pos { + return Pos{ + file: p.file, + off: p.off + n, + } +} + +// Offset returns the byte offset associated with Pos. +func (p Pos) Offset() int { return p.off } + +// Valid reports whether the Pos is valid. +func (p Pos) Valid() bool { return p != NoPos } + +// Position holds full position information for a location within an individual +// file.
+type Position struct { + Filename string // Filename (if any) + Offset int // Byte offset (starting at 0) + Line int // Line number (starting at 1) + Column int // Offset from start of line (starting at 1) +} + +// Valid reports whether the position is valid. Valid positions must have a +// Line value greater than 0. +func (pos *Position) Valid() bool { return pos.Line > 0 } + +// String returns a string in one of the following forms: +// +// file:line:column Valid position with file name +// file:line Valid position with file name but no column +// line:column Valid position with no file name +// line Valid position with no file name or column +// file Invalid position with file name +// - Invalid position with no file name +func (pos Position) String() string { + s := pos.Filename + + if pos.Valid() { + if s != "" { + s += ":" + } + s += fmt.Sprintf("%d", pos.Line) + if pos.Column != 0 { + s += fmt.Sprintf(":%d", pos.Column) + } + } + + if s == "" { + s = "-" + } + return s +} + +// File holds position information for a specific file. +type File struct { + filename string + lines []int // Byte offset of each line number (first element is always 0). +} + +// NewFile creates a new File for storing position information. +func NewFile(filename string) *File { + return &File{ + filename: filename, + lines: []int{0}, + } +} + +// Pos returns a Pos given a byte offset. Pos panics if off is < 0. +func (f *File) Pos(off int) Pos { + if off < 0 { + panic("Pos: illegal offset") + } + return Pos{file: f, off: off} +} + +// Name returns the name of the file. +func (f *File) Name() string { return f.filename } + +// AddLine tracks a new line from a byte offset. The line offset must be larger +// than the offset for the previous line, otherwise the line offset is ignored. AddLine requires f.lines to be non-empty, which NewFile guarantees by seeding it with 0; on a zero-value File the index below panics. +func (f *File) AddLine(offset int) { + lines := len(f.lines) + if f.lines[lines-1] < offset { + f.lines = append(f.lines, offset) + } +} + +// PositionFor returns a Position from an offset.
+func (f *File) PositionFor(p Pos) Position { + if p == NoPos { + return Position{} + } + + // Search through our line offsets to find the line/column info. The else + // case should never happen here, but if it does, Position.Valid will return + // false. + var line, column int + if i := searchInts(f.lines, p.off); i >= 0 { + line, column = i+1, p.off-f.lines[i]+1 + } + + return Position{ + Filename: f.filename, + Offset: p.off, + Line: line, + Column: column, + } +} + +func searchInts(a []int, x int) int { + return sort.Search(len(a), func(i int) bool { return a[i] > x }) - 1 +} diff --git a/syntax/token/token.go b/syntax/token/token.go new file mode 100644 index 0000000000..74dabc9de2 --- /dev/null +++ b/syntax/token/token.go @@ -0,0 +1,174 @@ +// Package token defines the lexical elements of a River config and utilities +// surrounding their position. +package token + +// Token is an individual River lexical token. +type Token int + +// List of all lexical tokens and examples that represent them. +// +// LITERAL is used by token/builder to represent literal strings for writing +// tokens, but never used for reading (so scanner never returns a +// token.LITERAL). +const ( + ILLEGAL Token = iota // Invalid token. + LITERAL // Literal text. + EOF // End-of-file. + COMMENT // // Hello, world! + + literalBeg + IDENT // foobar + NUMBER // 1234 + FLOAT // 1234.0 + STRING // "foobar" + literalEnd + + keywordBeg + BOOL // true + NULL // null + keywordEnd + + operatorBeg + OR // || + AND // && + NOT // ! + + ASSIGN // = + + EQ // == + NEQ // != + LT // < + LTE // <= + GT // > + GTE // >= + + ADD // + + SUB // - + MUL // * + DIV // / + MOD // % + POW // ^ + + LCURLY // { + RCURLY // } + LPAREN // ( + RPAREN // ) + LBRACK // [ + RBRACK // ] + COMMA // , + DOT // .
+ operatorEnd + + TERMINATOR // \n +) + +var tokenNames = [...]string{ + ILLEGAL: "ILLEGAL", + LITERAL: "LITERAL", + EOF: "EOF", + COMMENT: "COMMENT", + + IDENT: "IDENT", + NUMBER: "NUMBER", + FLOAT: "FLOAT", + STRING: "STRING", + BOOL: "BOOL", + NULL: "NULL", + + OR: "||", + AND: "&&", + NOT: "!", + + ASSIGN: "=", + EQ: "==", + NEQ: "!=", + LT: "<", + LTE: "<=", + GT: ">", + GTE: ">=", + + ADD: "+", + SUB: "-", + MUL: "*", + DIV: "/", + MOD: "%", + POW: "^", + + LCURLY: "{", + RCURLY: "}", + LPAREN: "(", + RPAREN: ")", + LBRACK: "[", + RBRACK: "]", + COMMA: ",", + DOT: ".", + + TERMINATOR: "TERMINATOR", +} + +// Lookup maps a string to its keyword token or IDENT if it's not a keyword. +func Lookup(ident string) Token { + switch ident { + case "true", "false": + return BOOL + case "null": + return NULL + default: + return IDENT + } +} + +// String returns the string representation corresponding to the token. Tokens outside the range of tokenNames, or with no name assigned, yield "ILLEGAL". +func (t Token) String() string { + if t < 0 || int(t) >= len(tokenNames) { // Token is a signed int: negative values must not index tokenNames. + return "ILLEGAL" + } + + name := tokenNames[t] + if name == "" { + return "ILLEGAL" + } + return name +} + +// GoString returns the %#v format of t. +func (t Token) GoString() string { return t.String() } + +// IsKeyword returns true if the token corresponds to a keyword. +func (t Token) IsKeyword() bool { return t > keywordBeg && t < keywordEnd } + +// IsLiteral returns true if the token corresponds to a literal token or +// identifier. +func (t Token) IsLiteral() bool { return t > literalBeg && t < literalEnd } + +// IsOperator returns true if the token corresponds to an operator or +// delimiter. +func (t Token) IsOperator() bool { return t > operatorBeg && t < operatorEnd } + +// BinaryPrecedence returns the operator precedence of the binary operator t. +// If t is not a binary operator, the result is LowestPrecedence.
+func (t Token) BinaryPrecedence() int { + switch t { + case OR: + return 1 + case AND: + return 2 + case EQ, NEQ, LT, LTE, GT, GTE: + return 3 + case ADD, SUB: + return 4 + case MUL, DIV, MOD: + return 5 + case POW: + return 6 + } + + return LowestPrecedence +} + +// Levels of precedence for operator tokens. +const ( + LowestPrecedence = 0 // non-operators + UnaryPrecedence = 7 + HighestPrecedence = 8 +) diff --git a/syntax/types.go b/syntax/types.go new file mode 100644 index 0000000000..b0123010b8 --- /dev/null +++ b/syntax/types.go @@ -0,0 +1,97 @@ +package river + +import "github.com/grafana/river/internal/value" + +// Our types in this file are re-implementations of interfaces from +// value.Capsule. They are *not* defined as type aliases, since pkg.go.dev +// would show the type alias instead of the contents of that type (which IMO is +// a frustrating user experience). +// +// The types below must be kept in sync with the internal package, and the +// checks below ensure they're compatible. +var ( + _ value.Defaulter = (Defaulter)(nil) + _ value.Unmarshaler = (Unmarshaler)(nil) + _ value.Validator = (Validator)(nil) + _ value.Capsule = (Capsule)(nil) + _ value.ConvertibleFromCapsule = (ConvertibleFromCapsule)(nil) + _ value.ConvertibleIntoCapsule = (ConvertibleIntoCapsule)(nil) +) + +// The Unmarshaler interface allows a type to hook into River decoding and +// decode into another type or provide pre-decoding logic. +type Unmarshaler interface { + // UnmarshalRiver is invoked when decoding a River value into a Go value. f + // should be called with a pointer to a value to decode into. UnmarshalRiver + // will not be called on types which are squashed into the parent struct + // using `river:",squash"`. + UnmarshalRiver(f func(v interface{}) error) error +} + +// The Defaulter interface allows a type to implement default functionality +// in River evaluation. 
+type Defaulter interface { + // SetToDefault is called when evaluating a block or body to set the value + // to its defaults. SetToDefault will not be called on types which are + // squashed into the parent struct using `river:",squash"`. + SetToDefault() +} + +// The Validator interface allows a type to implement validation functionality +// in River evaluation. +type Validator interface { + // Validate is called when evaluating a block or body to enforce the + // value is valid. Validate will not be called on types which are + // squashed into the parent struct using `river:",squash"`. + Validate() error +} + +// Capsule is an interface marker which tells River that a type should always +// be treated as a "capsule type" instead of the default type River would +// assign. +// +// Capsule types are useful for passing around arbitrary Go values in River +// expressions and for declaring new synthetic types with custom conversion +// rules. +// +// By default, only two capsule values of the same underlying Go type are +// compatible. Types which implement ConvertibleFromCapsule or +// ConvertibleToCapsule can provide custom logic for conversions from and to +// other types. +type Capsule interface { + // RiverCapsule marks the type as a Capsule. RiverCapsule is never invoked by + // River. + RiverCapsule() +} + +// ErrNoConversion is returned by implementations of ConvertibleFromCapsule and +// ConvertibleToCapsule when a conversion with a specific type is unavailable. +// +// Returning this error causes River to fall back to default conversion rules. +var ErrNoConversion = value.ErrNoConversion + +// ConvertibleFromCapsule is a Capsule which supports custom conversion from +// any Go type which is not the same as the capsule type. +type ConvertibleFromCapsule interface { + Capsule + + // ConvertFrom updates the ConvertibleFromCapsule value based on the value of + // src. src may be any Go value, not just other capsules. 
+ // + // ConvertFrom should return ErrNoConversion if no conversion is available + // from src. Other errors are treated as a River decoding error. + ConvertFrom(src interface{}) error +} + +// ConvertibleIntoCapsule is a Capsule which supports custom conversion into +// any Go type which is not the same as the capsule type. +type ConvertibleIntoCapsule interface { + Capsule + + // ConvertInto should convert its value and store it into dst. dst will be a + // pointer to a Go value of any type. + // + // ConvertInto should return ErrNoConversion if no conversion into dst is + // available. Other errors are treated as a River decoding error. + ConvertInto(dst interface{}) error +} diff --git a/syntax/vm/constant.go b/syntax/vm/constant.go new file mode 100644 index 0000000000..d2e54c717d --- /dev/null +++ b/syntax/vm/constant.go @@ -0,0 +1,64 @@ +package vm + +import ( + "fmt" + "strconv" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/token" +) + +func valueFromLiteral(lit string, tok token.Token) (value.Value, error) { + // NOTE(rfratto): this function should never return an error, since the + // parser only produces valid tokens; it can only fail if a user hand-builds + // an AST with invalid literals. 
+ + switch tok { + case token.NULL: + return value.Null, nil + + case token.NUMBER: + intVal, err1 := strconv.ParseInt(lit, 0, 64) + if err1 == nil { + return value.Int(intVal), nil + } + + uintVal, err2 := strconv.ParseUint(lit, 0, 64) + if err2 == nil { + return value.Uint(uintVal), nil + } + + floatVal, err3 := strconv.ParseFloat(lit, 64) + if err3 == nil { + return value.Float(floatVal), nil + } + + return value.Null, err3 + + case token.FLOAT: + v, err := strconv.ParseFloat(lit, 64) + if err != nil { + return value.Null, err + } + return value.Float(v), nil + + case token.STRING: + v, err := strconv.Unquote(lit) + if err != nil { + return value.Null, err + } + return value.String(v), nil + + case token.BOOL: + switch lit { + case "true": + return value.Bool(true), nil + case "false": + return value.Bool(false), nil + default: + return value.Null, fmt.Errorf("invalid boolean literal %q", lit) + } + default: + panic(fmt.Sprintf("%v is not a valid token", tok)) + } +} diff --git a/syntax/vm/error.go b/syntax/vm/error.go new file mode 100644 index 0000000000..c82a5b418e --- /dev/null +++ b/syntax/vm/error.go @@ -0,0 +1,106 @@ +package vm + +import ( + "fmt" + "strings" + + "github.com/grafana/river/ast" + "github.com/grafana/river/diag" + "github.com/grafana/river/internal/value" + "github.com/grafana/river/printer" + "github.com/grafana/river/token/builder" +) + +// makeDiagnostic tries to convert err into a diag.Diagnostic. err must be an +// error from the river/internal/value package, otherwise err will be returned +// unmodified. +func makeDiagnostic(err error, assoc map[value.Value]ast.Node) error { + var ( + node ast.Node + expr strings.Builder + message string + cause value.Value + + // Until we find a node, we're not a literal error. 
+ literal = false + ) + + isValueError := value.WalkError(err, func(err error) { + var val value.Value + + switch ne := err.(type) { + case value.Error: + message = ne.Error() + val = ne.Value + case value.TypeError: + message = fmt.Sprintf("should be %s, got %s", ne.Expected, ne.Value.Type()) + val = ne.Value + case value.MissingKeyError: + message = fmt.Sprintf("does not have field named %q", ne.Missing) + val = ne.Value + case value.ElementError: + fmt.Fprintf(&expr, "[%d]", ne.Index) + val = ne.Value + case value.FieldError: + fmt.Fprintf(&expr, ".%s", ne.Field) + val = ne.Value + } + + cause = val + + if foundNode, ok := assoc[val]; ok { + // If we just found a direct node, we can reset the expression buffer so + // we don't unnecessarily print element and field accesses for we can see + // directly in the file. + if literal { + expr.Reset() + } + + node = foundNode + literal = true + } else { + literal = false + } + }) + if !isValueError { + return err + } + + if node != nil { + var nodeText strings.Builder + if err := printer.Fprint(&nodeText, node); err != nil { + // This should never panic; printer.Fprint only fails when given an + // unexpected type, which we never do here. + panic(err) + } + + // Merge the node text with the expression together (which will be relative + // accesses to the expression). + message = fmt.Sprintf("%s%s %s", nodeText.String(), expr.String(), message) + } else { + message = fmt.Sprintf("%s %s", expr.String(), message) + } + + // Render the underlying problematic value as a string. + var valueText string + if cause != value.Null { + be := builder.NewExpr() + be.SetValue(cause.Interface()) + valueText = string(be.Bytes()) + } + if literal { + // Hide the value if the node itself has the error we were worried about. 
+ valueText = "" + } + + d := diag.Diagnostic{ + Severity: diag.SeverityLevelError, + Message: message, + Value: valueText, + } + if node != nil { + d.StartPos = ast.StartPos(node).Position() + d.EndPos = ast.EndPos(node).Position() + } + return d +} diff --git a/syntax/vm/op_binary.go b/syntax/vm/op_binary.go new file mode 100644 index 0000000000..e329f6fc49 --- /dev/null +++ b/syntax/vm/op_binary.go @@ -0,0 +1,360 @@ +package vm + +import ( + "fmt" + "math" + "reflect" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/rivertypes" + "github.com/grafana/river/token" +) + +func evalBinop(lhs value.Value, op token.Token, rhs value.Value) (value.Value, error) { + // Original parameters of lhs and rhs used for returning errors. + var ( + origLHS = lhs + origRHS = rhs + ) + + // Hack to allow OptionalSecrets to be used in binary operations. + // + // TODO(rfratto): be more flexible in the future with broader definitions of + // how capsules can be converted to other types for the purposes of doing a + // binop. + if lhs.Type() == value.TypeCapsule { + lhs = tryUnwrapOptionalSecret(lhs) + } + if rhs.Type() == value.TypeCapsule { + rhs = tryUnwrapOptionalSecret(rhs) + } + + // TODO(rfratto): evalBinop should check for underflows and overflows + + // We have special handling for EQ and NEQ since it's valid to attempt to + // compare values of any two types. + switch op { + case token.EQ: + return value.Bool(valuesEqual(lhs, rhs)), nil + case token.NEQ: + return value.Bool(!valuesEqual(lhs, rhs)), nil + } + + // The type of lhs and rhs must be acceptable for the binary operator. 
+ if !acceptableBinopType(lhs, op) { + return value.Null, value.Error{ + Value: origLHS, + Inner: fmt.Errorf("should be one of %v for binop %s, got %s", binopAllowedTypes[op], op, lhs.Type()), + } + } else if !acceptableBinopType(rhs, op) { + return value.Null, value.Error{ + Value: origRHS, + Inner: fmt.Errorf("should be one of %v for binop %s, got %s", binopAllowedTypes[op], op, rhs.Type()), + } + } + + // At this point, regardless of the operator, lhs and rhs must have the same + // type. + if lhs.Type() != rhs.Type() { + return value.Null, value.TypeError{Value: rhs, Expected: lhs.Type()} + } + + switch op { + case token.OR: // bool || bool + return value.Bool(lhs.Bool() || rhs.Bool()), nil + case token.AND: // bool && Bool + return value.Bool(lhs.Bool() && rhs.Bool()), nil + + case token.ADD: // number + number, string + string + if lhs.Type() == value.TypeString { + return value.String(lhs.Text() + rhs.Text()), nil + } + + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Uint(lhsNum.Uint() + rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Int(lhsNum.Int() + rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Float(lhsNum.Float() + rhsNum.Float()), nil + } + + case token.SUB: // number - number + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Uint(lhsNum.Uint() - rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Int(lhsNum.Int() - rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Float(lhsNum.Float() - rhsNum.Float()), nil + } + + case token.MUL: // number * number + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Uint(lhsNum.Uint() * rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Int(lhsNum.Int() * 
rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Float(lhsNum.Float() * rhsNum.Float()), nil + } + + case token.DIV: // number / number + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Uint(lhsNum.Uint() / rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Int(lhsNum.Int() / rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Float(lhsNum.Float() / rhsNum.Float()), nil + } + + case token.MOD: // number % number + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Uint(lhsNum.Uint() % rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Int(lhsNum.Int() % rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Float(math.Mod(lhsNum.Float(), rhsNum.Float())), nil + } + + case token.POW: // number ^ number + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Uint(intPow(lhsNum.Uint(), rhsNum.Uint())), nil + case value.NumberKindInt: + return value.Int(intPow(lhsNum.Int(), rhsNum.Int())), nil + case value.NumberKindFloat: + return value.Float(math.Pow(lhsNum.Float(), rhsNum.Float())), nil + } + + case token.LT: // number < number, string < string + // Check string first. + if lhs.Type() == value.TypeString { + return value.Bool(lhs.Text() < rhs.Text()), nil + } + + // Not a string; must be a number. + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Bool(lhsNum.Uint() < rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Bool(lhsNum.Int() < rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Bool(lhsNum.Float() < rhsNum.Float()), nil + } + + case token.GT: // number > number, string > string + // Check string first. 
+ if lhs.Type() == value.TypeString { + return value.Bool(lhs.Text() > rhs.Text()), nil + } + + // Not a string; must be a number. + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Bool(lhsNum.Uint() > rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Bool(lhsNum.Int() > rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Bool(lhsNum.Float() > rhsNum.Float()), nil + } + + case token.LTE: // number <= number, string <= string + // Check string first. + if lhs.Type() == value.TypeString { + return value.Bool(lhs.Text() <= rhs.Text()), nil + } + + // Not a string; must be a number. + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Bool(lhsNum.Uint() <= rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Bool(lhsNum.Int() <= rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Bool(lhsNum.Float() <= rhsNum.Float()), nil + } + + case token.GTE: // number >= number, string >= string + // Check string first. + if lhs.Type() == value.TypeString { + return value.Bool(lhs.Text() >= rhs.Text()), nil + } + + // Not a string; must be a number. + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return value.Bool(lhsNum.Uint() >= rhsNum.Uint()), nil + case value.NumberKindInt: + return value.Bool(lhsNum.Int() >= rhsNum.Int()), nil + case value.NumberKindFloat: + return value.Bool(lhsNum.Float() >= rhsNum.Float()), nil + } + } + + panic("river/vm: unreachable") +} + +// tryUnwrapOptionalSecret accepts a value and, if it is a +// rivertypes.OptionalSecret where IsSecret is false, returns a string value +// instead. +// +// If val is not a rivertypes.OptionalSecret or IsSecret is true, +// tryUnwrapOptionalSecret returns the input value unchanged. 
+func tryUnwrapOptionalSecret(val value.Value) value.Value { + optSecret, ok := val.Interface().(rivertypes.OptionalSecret) + if !ok || optSecret.IsSecret { + return val + } + + return value.String(optSecret.Value) +} + +// valuesEqual returns true if two River Values are equal. +func valuesEqual(lhs value.Value, rhs value.Value) bool { + if lhs.Type() != rhs.Type() { + // Two values with different types are never equal. + return false + } + + switch lhs.Type() { + case value.TypeNull: + // Nothing to compare here: both lhs and rhs have the null type, + // so they're equal. + return true + + case value.TypeNumber: + // Two numbers are equal if they have equal values. However, we have to + // determine what comparison we want to do and upcast the values to a + // different Go type as needed (so that 3 == 3.0 is true). + lhsNum, rhsNum := lhs.Number(), rhs.Number() + switch fitNumberKinds(lhsNum.Kind(), rhsNum.Kind()) { + case value.NumberKindUint: + return lhsNum.Uint() == rhsNum.Uint() + case value.NumberKindInt: + return lhsNum.Int() == rhsNum.Int() + case value.NumberKindFloat: + return lhsNum.Float() == rhsNum.Float() + } + + case value.TypeString: + return lhs.Text() == rhs.Text() + + case value.TypeBool: + return lhs.Bool() == rhs.Bool() + + case value.TypeArray: + // Two arrays are equal if they have equal elements. + if lhs.Len() != rhs.Len() { + return false + } + for i := 0; i < lhs.Len(); i++ { + if !valuesEqual(lhs.Index(i), rhs.Index(i)) { + return false + } + } + return true + + case value.TypeObject: + // Two objects are equal if they have equal elements. + if lhs.Len() != rhs.Len() { + return false + } + for _, key := range lhs.Keys() { + lhsElement, _ := lhs.Key(key) + rhsElement, inRHS := rhs.Key(key) + if !inRHS { + return false + } + if !valuesEqual(lhsElement, rhsElement) { + return false + } + } + return true + + case value.TypeFunction: + // Two functions are never equal. 
We can't compare functions in Go, so + // there's no way to compare them in River right now. + return false + + case value.TypeCapsule: + // Two capsules are only equal if the underlying values are deeply equal. + return reflect.DeepEqual(lhs.Interface(), rhs.Interface()) + } + + panic("river/vm: unreachable") +} + +// binopAllowedTypes maps what type of values are permitted for a specific +// binary operation. +// +// token.EQ and token.NEQ are not included as they're handled separately from +// other binary ops. +var binopAllowedTypes = map[token.Token][]value.Type{ + token.OR: {value.TypeBool}, + token.AND: {value.TypeBool}, + + token.ADD: {value.TypeNumber, value.TypeString}, + token.SUB: {value.TypeNumber}, + token.MUL: {value.TypeNumber}, + token.DIV: {value.TypeNumber}, + token.MOD: {value.TypeNumber}, + token.POW: {value.TypeNumber}, + + token.LT: {value.TypeNumber, value.TypeString}, + token.GT: {value.TypeNumber, value.TypeString}, + token.LTE: {value.TypeNumber, value.TypeString}, + token.GTE: {value.TypeNumber, value.TypeString}, +} + +func acceptableBinopType(val value.Value, op token.Token) bool { + allowed, ok := binopAllowedTypes[op] + if !ok { + panic("river/vm: unexpected binop type") + } + + actualType := val.Type() + for _, allowType := range allowed { + if allowType == actualType { + return true + } + } + return false +} + +func fitNumberKinds(a, b value.NumberKind) value.NumberKind { + aPrec, bPrec := numberKindPrec[a], numberKindPrec[b] + if aPrec > bPrec { + return a + } + return b +} + +var numberKindPrec = map[value.NumberKind]int{ + value.NumberKindUint: 0, + value.NumberKindInt: 1, + value.NumberKindFloat: 2, +} + +func intPow[Number int64 | uint64](n, m Number) Number { + if m == 0 { + return 1 + } + result := n + for i := Number(2); i <= m; i++ { + result *= n + } + return result +} diff --git a/syntax/vm/op_binary_test.go b/syntax/vm/op_binary_test.go new file mode 100644 index 0000000000..45367777e4 --- /dev/null +++ 
b/syntax/vm/op_binary_test.go @@ -0,0 +1,94 @@ +package vm_test + +import ( + "reflect" + "testing" + + "github.com/grafana/river/parser" + "github.com/grafana/river/rivertypes" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +func TestVM_OptionalSecret_Conversion(t *testing.T) { + scope := &vm.Scope{ + Variables: map[string]any{ + "string_val": "hello", + "non_secret_val": rivertypes.OptionalSecret{IsSecret: false, Value: "world"}, + "secret_val": rivertypes.OptionalSecret{IsSecret: true, Value: "secret"}, + }, + } + + tt := []struct { + name string + input string + expect interface{} + expectError string + }{ + { + name: "string + capsule", + input: `string_val + non_secret_val`, + expect: string("helloworld"), + }, + { + name: "capsule + string", + input: `non_secret_val + string_val`, + expect: string("worldhello"), + }, + { + name: "string == capsule", + input: `"world" == non_secret_val`, + expect: bool(true), + }, + { + name: "capsule == string", + input: `non_secret_val == "world"`, + expect: bool(true), + }, + { + name: "capsule (secret) == capsule (secret)", + input: `secret_val == secret_val`, + expect: bool(true), + }, + { + name: "capsule (non secret) == capsule (non secret)", + input: `non_secret_val == non_secret_val`, + expect: bool(true), + }, + { + name: "capsule (non secret) == capsule (secret)", + input: `non_secret_val == secret_val`, + expect: bool(false), + }, + { + name: "secret + string", + input: `secret_val + string_val`, + expectError: "secret_val should be one of [number string] for binop +", + }, + { + name: "string + secret", + input: `string_val + secret_val`, + expectError: "secret_val should be one of [number string] for binop +", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + expectTy := reflect.TypeOf(tc.expect) + if expectTy == nil { + expectTy = reflect.TypeOf((*any)(nil)).Elem() + } + rv := 
reflect.New(expectTy) + + if err := vm.New(expr).Evaluate(scope, rv.Interface()); tc.expectError == "" { + require.NoError(t, err) + require.Equal(t, tc.expect, rv.Elem().Interface()) + } else { + require.ErrorContains(t, err, tc.expectError) + } + }) + } +} diff --git a/syntax/vm/op_unary.go b/syntax/vm/op_unary.go new file mode 100644 index 0000000000..bc116d58bc --- /dev/null +++ b/syntax/vm/op_unary.go @@ -0,0 +1,33 @@ +package vm + +import ( + "github.com/grafana/river/internal/value" + "github.com/grafana/river/token" +) + +func evalUnaryOp(op token.Token, val value.Value) (value.Value, error) { + switch op { + case token.NOT: + if val.Type() != value.TypeBool { + return value.Null, value.TypeError{Value: val, Expected: value.TypeBool} + } + return value.Bool(!val.Bool()), nil + + case token.SUB: + if val.Type() != value.TypeNumber { + return value.Null, value.TypeError{Value: val, Expected: value.TypeNumber} + } + + valNum := val.Number() + switch valNum.Kind() { + case value.NumberKindInt, value.NumberKindUint: + // It doesn't make much sense to invert a uint, so we always cast to an + // int and return an int. + return value.Int(-valNum.Int()), nil + case value.NumberKindFloat: + return value.Float(-valNum.Float()), nil + } + } + + panic("river/vm: unreachable") +} diff --git a/syntax/vm/struct_decoder.go b/syntax/vm/struct_decoder.go new file mode 100644 index 0000000000..99a0a2358a --- /dev/null +++ b/syntax/vm/struct_decoder.go @@ -0,0 +1,323 @@ +package vm + +import ( + "fmt" + "reflect" + "strings" + + "github.com/grafana/river/ast" + "github.com/grafana/river/diag" + "github.com/grafana/river/internal/reflectutil" + "github.com/grafana/river/internal/rivertags" + "github.com/grafana/river/internal/value" +) + +// structDecoder decodes a series of AST statements into a Go value. 
+type structDecoder struct { + VM *Evaluator + Scope *Scope + Assoc map[value.Value]ast.Node + TagInfo *tagInfo +} + +// Decode decodes the list of statements into the struct value specified by rv. +func (st *structDecoder) Decode(stmts ast.Body, rv reflect.Value) error { + // TODO(rfratto): potentially loosen this restriction and allow decoding into + // an interface{} or map[string]interface{}. + if rv.Kind() != reflect.Struct { + panic(fmt.Sprintf("river/vm: structDecoder expects struct, got %s", rv.Kind())) + } + + state := decodeOptions{ + Tags: st.TagInfo.TagLookup, + EnumBlocks: st.TagInfo.EnumLookup, + SeenAttrs: make(map[string]struct{}), + SeenBlocks: make(map[string]struct{}), + SeenEnums: make(map[string]struct{}), + + BlockCount: make(map[string]int), + BlockIndex: make(map[*ast.BlockStmt]int), + + EnumCount: make(map[string]int), + EnumIndex: make(map[*ast.BlockStmt]int), + } + + // Iterate over the set of blocks to populate block count and block index. + // Block index is its index in the set of blocks with the *same name*. + // + // If the block belongs to an enum, we populate enum count and enum index + // instead. The enum index is the index on the set of blocks for the *same + // enum*. + for _, stmt := range stmts { + switch stmt := stmt.(type) { + case *ast.BlockStmt: + fullName := strings.Join(stmt.Name, ".") + + if enumTf, isEnum := st.TagInfo.EnumLookup[fullName]; isEnum { + enumName := strings.Join(enumTf.EnumField.Name, ".") + state.EnumIndex[stmt] = state.EnumCount[enumName] + state.EnumCount[enumName]++ + } else { + state.BlockIndex[stmt] = state.BlockCount[fullName] + state.BlockCount[fullName]++ + } + } + } + + for _, stmt := range stmts { + switch stmt := stmt.(type) { + case *ast.AttributeStmt: + // TODO(rfratto): append to list of diagnostics instead of aborting early. 
+ if err := st.decodeAttr(stmt, rv, &state); err != nil { + return err + } + + case *ast.BlockStmt: + // TODO(rfratto): append to list of diagnostics instead of aborting early. + if err := st.decodeBlock(stmt, rv, &state); err != nil { + return err + } + + default: + panic(fmt.Sprintf("river/vm: unrecognized node type %T", stmt)) + } + } + + for _, tf := range st.TagInfo.Tags { + // Ignore any optional tags. + if tf.IsOptional() { + continue + } + + fullName := strings.Join(tf.Name, ".") + + switch { + case tf.IsAttr(): + if _, consumed := state.SeenAttrs[fullName]; !consumed { + // TODO(rfratto): change to diagnostics. + return fmt.Errorf("missing required attribute %q", fullName) + } + + case tf.IsBlock(): + if _, consumed := state.SeenBlocks[fullName]; !consumed { + // TODO(rfratto): change to diagnostics. + return fmt.Errorf("missing required block %q", fullName) + } + } + } + + return nil +} + +type decodeOptions struct { + Tags map[string]rivertags.Field + EnumBlocks map[string]enumBlock + + SeenAttrs, SeenBlocks, SeenEnums map[string]struct{} + + // BlockCount and BlockIndex are used to determine: + // + // * How big a slice of blocks should be for a block of a given name (BlockCount) + // * Which element within that slice is a given block assigned to (BlockIndex) + // + // This is used for decoding a series of rule blocks for prometheus.relabel, + // where 5 rules would have a "rule" key in BlockCount with a value of 5, and + // where the first block would be index 0, the second block would be index 1, + // and so on. + // + // The index in BlockIndex is relative to a block name; the first block named + // "hello.world" and the first block named "fizz.buzz" both have index 0. + + BlockCount map[string]int // Number of times a block by full name is seen. + BlockIndex map[*ast.BlockStmt]int // Index of a block within a set of blocks with the same name. 
+ + // EnumCount and EnumIndex are similar to BlockCount/BlockIndex, but instead + // reference the number of blocks assigned to the same enum (EnumCount) and + // the index of a block within that enum slice (EnumIndex). + + EnumCount map[string]int // Number of times an enum group is seen by enum name. + EnumIndex map[*ast.BlockStmt]int // Index of a block within a set of enum blocks of the same enum. +} + +func (st *structDecoder) decodeAttr(attr *ast.AttributeStmt, rv reflect.Value, state *decodeOptions) error { + fullName := attr.Name.Name + if _, seen := state.SeenAttrs[fullName]; seen { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(attr).Position(), + EndPos: ast.EndPos(attr).Position(), + Message: fmt.Sprintf("attribute %q may only be provided once", fullName), + }} + } + state.SeenAttrs[fullName] = struct{}{} + + tf, ok := state.Tags[fullName] + if !ok { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(attr).Position(), + EndPos: ast.EndPos(attr).Position(), + Message: fmt.Sprintf("unrecognized attribute name %q", fullName), + }} + } else if tf.IsBlock() { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(attr).Position(), + EndPos: ast.EndPos(attr).Position(), + Message: fmt.Sprintf("%q must be a block, but is used as an attribute", fullName), + }} + } + + // Decode the attribute. + val, err := st.VM.evaluateExpr(st.Scope, st.Assoc, attr.Value) + if err != nil { + // TODO(rfratto): get error as diagnostics. + return err + } + + // We're reconverting our reflect.Value back into an interface{}, so we + // need to also turn it back into a pointer for decoding. + field := reflectutil.GetOrAlloc(rv, tf) + if err := value.Decode(val, field.Addr().Interface()); err != nil { + // TODO(rfratto): get error as diagnostics. 
+ return err + } + + return nil +} + +func (st *structDecoder) decodeBlock(block *ast.BlockStmt, rv reflect.Value, state *decodeOptions) error { + fullName := block.GetBlockName() + + if _, isEnum := state.EnumBlocks[fullName]; isEnum { + return st.decodeEnumBlock(fullName, block, rv, state) + } + return st.decodeNormalBlock(fullName, block, rv, state) +} + +// decodeNormalBlock decodes a standard (non-enum) block. +func (st *structDecoder) decodeNormalBlock(fullName string, block *ast.BlockStmt, rv reflect.Value, state *decodeOptions) error { + tf, isBlock := state.Tags[fullName] + if !isBlock { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(block).Position(), + EndPos: ast.EndPos(block).Position(), + Message: fmt.Sprintf("unrecognized block name %q", fullName), + }} + } else if tf.IsAttr() { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(block).Position(), + EndPos: ast.EndPos(block).Position(), + Message: fmt.Sprintf("%q must be an attribute, but is used as a block", fullName), + }} + } + + field := reflectutil.GetOrAlloc(rv, tf) + decodeField := prepareDecodeValue(field) + + switch decodeField.Kind() { + case reflect.Slice: + // If this is the first time we've seen the block, reset its length to + // zero. + if _, seen := state.SeenBlocks[fullName]; !seen { + count := state.BlockCount[fullName] + decodeField.Set(reflect.MakeSlice(decodeField.Type(), count, count)) + } + + blockIndex, ok := state.BlockIndex[block] + if !ok { + panic("river/vm: block not found in index lookup table") + } + decodeElement := prepareDecodeValue(decodeField.Index(blockIndex)) + err := st.VM.evaluateBlockOrBody(st.Scope, st.Assoc, block, decodeElement) + if err != nil { + // TODO(rfratto): get error as diagnostics. 
+ return err + } + + case reflect.Array: + if decodeField.Len() != state.BlockCount[fullName] { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(block).Position(), + EndPos: ast.EndPos(block).Position(), + Message: fmt.Sprintf( + "block %q must be specified exactly %d times, but was specified %d times", + fullName, + decodeField.Len(), + state.BlockCount[fullName], + ), + }} + } + + blockIndex, ok := state.BlockIndex[block] + if !ok { + panic("river/vm: block not found in index lookup table") + } + decodeElement := prepareDecodeValue(decodeField.Index(blockIndex)) + err := st.VM.evaluateBlockOrBody(st.Scope, st.Assoc, block, decodeElement) + if err != nil { + // TODO(rfratto): get error as diagnostics. + return err + } + + default: + if state.BlockCount[fullName] > 1 { + return diag.Diagnostics{{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(block).Position(), + EndPos: ast.EndPos(block).Position(), + Message: fmt.Sprintf("block %q may only be specified once", fullName), + }} + } + + err := st.VM.evaluateBlockOrBody(st.Scope, st.Assoc, block, decodeField) + if err != nil { + // TODO(rfratto): get error as diagnostics. + return err + } + } + + state.SeenBlocks[fullName] = struct{}{} + return nil +} + +func (st *structDecoder) decodeEnumBlock(fullName string, block *ast.BlockStmt, rv reflect.Value, state *decodeOptions) error { + tf, ok := state.EnumBlocks[fullName] + if !ok { + // decodeEnumBlock should only ever be called from decodeBlock, so this + // should never happen. + panic("decodeEnumBlock called with a non-enum block") + } + + enumName := strings.Join(tf.EnumField.Name, ".") + enumField := reflectutil.GetOrAlloc(rv, tf.EnumField) + decodeField := prepareDecodeValue(enumField) + + if decodeField.Kind() != reflect.Slice { + panic("river/vm: enum field must be a slice kind, got " + decodeField.Kind().String()) + } + + // If this is the first time we've seen the enum, reset its length to zero. 
+ if _, seen := state.SeenEnums[enumName]; !seen { + count := state.EnumCount[enumName] + decodeField.Set(reflect.MakeSlice(decodeField.Type(), count, count)) + } + state.SeenEnums[enumName] = struct{}{} + + // Prepare the enum element to decode into. + enumIndex, ok := state.EnumIndex[block] + if !ok { + panic("river/vm: enum block not found in index lookup table") + } + enumElement := prepareDecodeValue(decodeField.Index(enumIndex)) + + // Prepare the block field to decode into. + enumBlock := reflectutil.GetOrAlloc(enumElement, tf.BlockField) + decodeBlock := prepareDecodeValue(enumBlock) + + // Decode into the block field. + return st.VM.evaluateBlockOrBody(st.Scope, st.Assoc, block, decodeBlock) +} diff --git a/syntax/vm/tag_cache.go b/syntax/vm/tag_cache.go new file mode 100644 index 0000000000..f9c1b69c56 --- /dev/null +++ b/syntax/vm/tag_cache.go @@ -0,0 +1,80 @@ +package vm + +import ( + "reflect" + "strings" + "sync" + + "github.com/grafana/river/internal/rivertags" +) + +// tagsCache caches the river tags for a struct type. This is never cleared, +// but since most structs will be statically created throughout the lifetime +// of the process, this will consume a negligible amount of memory. +var tagsCache sync.Map + +func getCachedTagInfo(t reflect.Type) *tagInfo { + if t.Kind() != reflect.Struct { + panic("getCachedTagInfo called with non-struct type") + } + + if entry, ok := tagsCache.Load(t); ok { + return entry.(*tagInfo) + } + + tfs := rivertags.Get(t) + ti := &tagInfo{ + Tags: tfs, + TagLookup: make(map[string]rivertags.Field, len(tfs)), + EnumLookup: make(map[string]enumBlock), // The length is not known ahead of time + } + + for _, tf := range tfs { + switch { + case tf.IsAttr(), tf.IsBlock(): + fullName := strings.Join(tf.Name, ".") + ti.TagLookup[fullName] = tf + + case tf.IsEnum(): + fullName := strings.Join(tf.Name, ".") + + // Find all the blocks that match to the enum, and inject them into the + // EnumLookup table. 
+ enumFieldType := t.FieldByIndex(tf.Index).Type + enumBlocksInfo := getCachedTagInfo(deferenceType(enumFieldType.Elem())) + for _, blockField := range enumBlocksInfo.TagLookup { + // The full name of the enum block is the name of the enum plus the + // name of the block, separated by '.' + enumBlockName := fullName + "." + strings.Join(blockField.Name, ".") + ti.EnumLookup[enumBlockName] = enumBlock{ + EnumField: tf, + BlockField: blockField, + } + } + } + } + + tagsCache.Store(t, ti) + return ti +} + +func deferenceType(ty reflect.Type) reflect.Type { + for ty.Kind() == reflect.Pointer { + ty = ty.Elem() + } + return ty +} + +type tagInfo struct { + Tags []rivertags.Field + TagLookup map[string]rivertags.Field + + // EnumLookup maps enum blocks to the enum field. For example, an enum block + // called "foo.foo" and "foo.bar" will both map to the "foo" enum field. + EnumLookup map[string]enumBlock +} + +type enumBlock struct { + EnumField rivertags.Field // Field in the parent struct of the enum slice + BlockField rivertags.Field // Field in the enum struct for the enum block +} diff --git a/syntax/vm/vm.go b/syntax/vm/vm.go new file mode 100644 index 0000000000..a9c6481593 --- /dev/null +++ b/syntax/vm/vm.go @@ -0,0 +1,486 @@ +// Package vm provides a River expression evaluator. +package vm + +import ( + "fmt" + "reflect" + "strings" + + "github.com/grafana/river/ast" + "github.com/grafana/river/diag" + "github.com/grafana/river/internal/reflectutil" + "github.com/grafana/river/internal/rivertags" + "github.com/grafana/river/internal/stdlib" + "github.com/grafana/river/internal/value" +) + +// Evaluator evaluates River AST nodes into Go values. Each Evaluator is bound +// to a single AST node. To evaluate the node, call Evaluate. +type Evaluator struct { + // node for the AST. + // + // Each Evaluator is bound to a single node to allow for future performance + // optimizations, allowing for precomputing and storing the result of + // anything that is constant. 
+ node ast.Node +} + +// New creates a new Evaluator for the given AST node. The given node must be +// either an *ast.File, *ast.BlockStmt, ast.Body, or assignable to an ast.Expr. +func New(node ast.Node) *Evaluator { + return &Evaluator{node: node} +} + +// Evaluate evaluates the Evaluator's node into a River value and decodes that +// value into the Go value v. +// +// Each call to Evaluate may provide a different scope with new values for +// available variables. If a variable used by the Evaluator's node isn't +// defined in scope or any of the parent scopes, Evaluate will return an error. +func (vm *Evaluator) Evaluate(scope *Scope, v interface{}) (err error) { + // Track a map that allows us to associate values with ast.Nodes so we can + // return decorated error messages. + assoc := make(map[value.Value]ast.Node) + + defer func() { + if err != nil { + // Decorate the error on return. + err = makeDiagnostic(err, assoc) + } + }() + + switch node := vm.node.(type) { + case *ast.BlockStmt, ast.Body: + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer { + panic(fmt.Sprintf("river/vm: expected pointer, got %s", rv.Kind())) + } + return vm.evaluateBlockOrBody(scope, assoc, node, rv) + case *ast.File: + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer { + panic(fmt.Sprintf("river/vm: expected pointer, got %s", rv.Kind())) + } + return vm.evaluateBlockOrBody(scope, assoc, node.Body, rv) + default: + expr, ok := node.(ast.Expr) + if !ok { + panic(fmt.Sprintf("river/vm: unexpected value type %T", node)) + } + val, err := vm.evaluateExpr(scope, assoc, expr) + if err != nil { + return err + } + return value.Decode(val, v) + } +} + +func (vm *Evaluator) evaluateBlockOrBody(scope *Scope, assoc map[value.Value]ast.Node, node ast.Node, rv reflect.Value) error { + // Before decoding the block, we need to temporarily take the address of rv + // to handle the case of it implementing the unmarshaler interface. 
+ if rv.CanAddr() { + rv = rv.Addr() + } + + if err, unmarshaled := vm.evaluateUnmarshalRiver(scope, assoc, node, rv); unmarshaled || err != nil { + return err + } + + if ru, ok := rv.Interface().(value.Defaulter); ok { + ru.SetToDefault() + } + + if err := vm.evaluateDecode(scope, assoc, node, rv); err != nil { + return err + } + + if ru, ok := rv.Interface().(value.Validator); ok { + if err := ru.Validate(); err != nil { + return err + } + } + + return nil +} + +func (vm *Evaluator) evaluateUnmarshalRiver(scope *Scope, assoc map[value.Value]ast.Node, node ast.Node, rv reflect.Value) (error, bool) { + if ru, ok := rv.Interface().(value.Unmarshaler); ok { + return ru.UnmarshalRiver(func(v interface{}) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer { + panic(fmt.Sprintf("river/vm: expected pointer, got %s", rv.Kind())) + } + return vm.evaluateBlockOrBody(scope, assoc, node, rv.Elem()) + }), true + } + + return nil, false +} + +func (vm *Evaluator) evaluateDecode(scope *Scope, assoc map[value.Value]ast.Node, node ast.Node, rv reflect.Value) error { + // TODO(rfratto): the errors returned by this function are missing context to + // be able to print line numbers. We need to return decorated error types. + + // Fully deference rv and allocate pointers as necessary. 
+ for rv.Kind() == reflect.Pointer { + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + rv = rv.Elem() + } + + if rv.Kind() == reflect.Interface { + var anyMap map[string]interface{} + into := reflect.MakeMap(reflect.TypeOf(anyMap)) + if err := vm.evaluateMap(scope, assoc, node, into); err != nil { + return err + } + + rv.Set(into) + return nil + } else if rv.Kind() == reflect.Map { + return vm.evaluateMap(scope, assoc, node, rv) + } else if rv.Kind() != reflect.Struct { + panic(fmt.Sprintf("river/vm: can only evaluate blocks into structs, got %s", rv.Kind())) + } + + ti := getCachedTagInfo(rv.Type()) + + var stmts ast.Body + switch node := node.(type) { + case *ast.BlockStmt: + // Decode the block label first. + if err := vm.evaluateBlockLabel(node, ti.Tags, rv); err != nil { + return err + } + stmts = node.Body + case ast.Body: + stmts = node + default: + panic(fmt.Sprintf("river/vm: unrecognized node type %T", node)) + } + + sd := structDecoder{ + VM: vm, + Scope: scope, + Assoc: assoc, + TagInfo: ti, + } + return sd.Decode(stmts, rv) +} + +// evaluateMap evaluates a block or a body into a map. +func (vm *Evaluator) evaluateMap(scope *Scope, assoc map[value.Value]ast.Node, node ast.Node, rv reflect.Value) error { + var stmts ast.Body + + switch node := node.(type) { + case *ast.BlockStmt: + if node.Label != "" { + return diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: node.NamePos.Position(), + EndPos: node.LCurlyPos.Position(), + Message: fmt.Sprintf("block %q requires non-empty label", strings.Join(node.Name, ".")), + } + } + stmts = node.Body + case ast.Body: + stmts = node + default: + panic(fmt.Sprintf("river/vm: unrecognized node type %T", node)) + } + + if rv.IsNil() { + rv.Set(reflect.MakeMap(rv.Type())) + } + + for _, stmt := range stmts { + switch stmt := stmt.(type) { + case *ast.AttributeStmt: + val, err := vm.evaluateExpr(scope, assoc, stmt.Value) + if err != nil { + // TODO(rfratto): get error as diagnostics. 
+ return err + } + + target := reflect.New(rv.Type().Elem()).Elem() + if err := value.Decode(val, target.Addr().Interface()); err != nil { + // TODO(rfratto): get error as diagnostics. + return err + } + rv.SetMapIndex(reflect.ValueOf(stmt.Name.Name), target) + + case *ast.BlockStmt: + // TODO(rfratto): potentially relax this restriction where nested blocks + // are permitted when decoding to a map. + return diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(stmt).Position(), + EndPos: ast.EndPos(stmt).Position(), + Message: "nested blocks not supported here", + } + + default: + panic(fmt.Sprintf("river/vm: unrecognized node type %T", stmt)) + } + } + + return nil +} + +func (vm *Evaluator) evaluateBlockLabel(node *ast.BlockStmt, tfs []rivertags.Field, rv reflect.Value) error { + var ( + labelField rivertags.Field + foundField bool + ) + for _, tf := range tfs { + if tf.Flags&rivertags.FlagLabel != 0 { + labelField = tf + foundField = true + break + } + } + + // Check for user errors first. + // + // We return parser.Error here to restrict the position of the error to just + // the name. We might be able to clean this up in the future by extending + // ValueError to have an explicit position. + switch { + case node.Label == "" && foundField: // No user label, but struct expects one + return diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: node.NamePos.Position(), + EndPos: node.LCurlyPos.Position(), + Message: fmt.Sprintf("block %q requires non-empty label", strings.Join(node.Name, ".")), + } + case node.Label != "" && !foundField: // User label, but struct doesn't expect one + return diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: node.NamePos.Position(), + EndPos: node.LCurlyPos.Position(), + Message: fmt.Sprintf("block %q does not support specifying labels", strings.Join(node.Name, ".")), + } + } + + if node.Label == "" { + // no-op: no labels to set. 
+ return nil + } + + var ( + field = reflectutil.GetOrAlloc(rv, labelField) + fieldType = field.Type() + ) + if !reflect.TypeOf(node.Label).AssignableTo(fieldType) { + // The Label struct field needs to be a string. + panic(fmt.Sprintf("river/vm: cannot assign block label to non-string type %s", fieldType)) + } + field.Set(reflect.ValueOf(node.Label)) + return nil +} + +// prepareDecodeValue prepares v for decoding. Pointers will be fully +// dereferenced until finding a non-pointer value. nil pointers will be +// allocated. +func prepareDecodeValue(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Pointer { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + return v +} + +func (vm *Evaluator) evaluateExpr(scope *Scope, assoc map[value.Value]ast.Node, expr ast.Expr) (v value.Value, err error) { + defer func() { + if v != value.Null { + assoc[v] = expr + } + }() + + switch expr := expr.(type) { + case *ast.LiteralExpr: + return valueFromLiteral(expr.Value, expr.Kind) + + case *ast.BinaryExpr: + lhs, err := vm.evaluateExpr(scope, assoc, expr.Left) + if err != nil { + return value.Null, err + } + rhs, err := vm.evaluateExpr(scope, assoc, expr.Right) + if err != nil { + return value.Null, err + } + return evalBinop(lhs, expr.Kind, rhs) + + case *ast.ArrayExpr: + vals := make([]value.Value, len(expr.Elements)) + for i, element := range expr.Elements { + val, err := vm.evaluateExpr(scope, assoc, element) + if err != nil { + return value.Null, err + } + vals[i] = val + } + return value.Array(vals...), nil + + case *ast.ObjectExpr: + fields := make(map[string]value.Value, len(expr.Fields)) + for _, field := range expr.Fields { + val, err := vm.evaluateExpr(scope, assoc, field.Value) + if err != nil { + return value.Null, err + } + fields[field.Name.Name] = val + } + return value.Object(fields), nil + + case *ast.IdentifierExpr: + val, found := scope.Lookup(expr.Ident.Name) + if !found { + return value.Null, diag.Diagnostic{ + Severity: 
diag.SeverityLevelError, + StartPos: ast.StartPos(expr).Position(), + EndPos: ast.EndPos(expr).Position(), + Message: fmt.Sprintf("identifier %q does not exist", expr.Ident.Name), + } + } + return value.Encode(val), nil + + case *ast.AccessExpr: + val, err := vm.evaluateExpr(scope, assoc, expr.Value) + if err != nil { + return value.Null, err + } + + switch val.Type() { + case value.TypeObject: + res, ok := val.Key(expr.Name.Name) + if !ok { + return value.Null, diag.Diagnostic{ + Severity: diag.SeverityLevelError, + StartPos: ast.StartPos(expr.Name).Position(), + EndPos: ast.EndPos(expr.Name).Position(), + Message: fmt.Sprintf("field %q does not exist", expr.Name.Name), + } + } + return res, nil + default: + return value.Null, value.Error{ + Value: val, + Inner: fmt.Errorf("cannot access field %q on value of type %s", expr.Name.Name, val.Type()), + } + } + + case *ast.IndexExpr: + val, err := vm.evaluateExpr(scope, assoc, expr.Value) + if err != nil { + return value.Null, err + } + idx, err := vm.evaluateExpr(scope, assoc, expr.Index) + if err != nil { + return value.Null, err + } + + switch val.Type() { + case value.TypeArray: + // Arrays are indexed with a number. + if idx.Type() != value.TypeNumber { + return value.Null, value.TypeError{Value: idx, Expected: value.TypeNumber} + } + intIndex := int(idx.Int()) + + if intIndex < 0 || intIndex >= val.Len() { + return value.Null, value.Error{ + Value: idx, + Inner: fmt.Errorf("index %d is out of range of array with length %d", intIndex, val.Len()), + } + } + return val.Index(intIndex), nil + + case value.TypeObject: + // Objects are indexed with a string. + if idx.Type() != value.TypeString { + return value.Null, value.TypeError{Value: idx, Expected: value.TypeString} + } + + field, ok := val.Key(idx.Text()) + if !ok { + // If a key doesn't exist in an object accessed with [], return null. 
+ return value.Null, nil + } + return field, nil + + default: + return value.Null, value.Error{ + Value: val, + Inner: fmt.Errorf("expected object or array, got %s", val.Type()), + } + } + + case *ast.ParenExpr: + return vm.evaluateExpr(scope, assoc, expr.Inner) + + case *ast.UnaryExpr: + val, err := vm.evaluateExpr(scope, assoc, expr.Value) + if err != nil { + return value.Null, err + } + return evalUnaryOp(expr.Kind, val) + + case *ast.CallExpr: + funcVal, err := vm.evaluateExpr(scope, assoc, expr.Value) + if err != nil { + return funcVal, err + } + if funcVal.Type() != value.TypeFunction { + return value.Null, value.TypeError{Value: funcVal, Expected: value.TypeFunction} + } + + args := make([]value.Value, len(expr.Args)) + for i := 0; i < len(expr.Args); i++ { + args[i], err = vm.evaluateExpr(scope, assoc, expr.Args[i]) + if err != nil { + return value.Null, err + } + } + return funcVal.Call(args...) + + default: + panic(fmt.Sprintf("river/vm: unexpected ast.Expr type %T", expr)) + } +} + +// A Scope exposes a set of variables available to use during evaluation. +type Scope struct { + // Parent optionally points to a parent Scope containing more variable. + // Variables defined in children scopes take precedence over variables of the + // same name found in parent scopes. + Parent *Scope + + // Variables holds the list of available variable names that can be used when + // evaluating a node. + // + // Values in the Variables map should be considered immutable after passed to + // Evaluate; maps and slices will be copied by reference for performance + // optimizations. + Variables map[string]interface{} +} + +// Lookup looks up a named identifier from the scope, all of the scope's +// parents, and the stdlib. +func (s *Scope) Lookup(name string) (interface{}, bool) { + // Traverse the scope first, then fall back to stdlib. 
+ for s != nil { + if val, ok := s.Variables[name]; ok { + return val, true + } + s = s.Parent + } + if ident, ok := stdlib.Identifiers[name]; ok { + return ident, true + } + return nil, false +} diff --git a/syntax/vm/vm_benchmarks_test.go b/syntax/vm/vm_benchmarks_test.go new file mode 100644 index 0000000000..e5530ccb37 --- /dev/null +++ b/syntax/vm/vm_benchmarks_test.go @@ -0,0 +1,106 @@ +package vm_test + +import ( + "fmt" + "math" + "reflect" + "testing" + + "github.com/grafana/river/parser" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +func BenchmarkExprs(b *testing.B) { + // Shared scope across all tests below + scope := &vm.Scope{ + Variables: map[string]interface{}{ + "foobar": int(42), + }, + } + + tt := []struct { + name string + input string + expect interface{} + }{ + // Binops + {"or", `false || true`, bool(true)}, + {"and", `true && false`, bool(false)}, + {"math/eq", `3 == 5`, bool(false)}, + {"math/neq", `3 != 5`, bool(true)}, + {"math/lt", `3 < 5`, bool(true)}, + {"math/lte", `3 <= 5`, bool(true)}, + {"math/gt", `3 > 5`, bool(false)}, + {"math/gte", `3 >= 5`, bool(false)}, + {"math/add", `3 + 5`, int(8)}, + {"math/sub", `3 - 5`, int(-2)}, + {"math/mul", `3 * 5`, int(15)}, + {"math/div", `3 / 5`, int(0)}, + {"math/mod", `5 % 3`, int(2)}, + {"math/pow", `3 ^ 5`, int(243)}, + {"binop chain", `3 + 5 * 2`, int(13)}, // Chain multiple binops + + // Identifier + {"ident lookup", `foobar`, int(42)}, + + // Arrays + {"array", `[0, 1, 2]`, []int{0, 1, 2}}, + + // Objects + {"object to map", `{ a = 5, b = 10 }`, map[string]int{"a": 5, "b": 10}}, + { + name: "object to struct", + input: `{ + name = "John Doe", + age = 42, + }`, + expect: struct { + Name string `river:"name,attr"` + Age int `river:"age,attr"` + Country string `river:"country,attr,optional"` + }{ + Name: "John Doe", + Age: 42, + }, + }, + + // Access + {"access", `{ a = 15 }.a`, int(15)}, + {"nested access", `{ a = { b = 12 } }.a.b`, int(12)}, + + // Indexing + 
{"index", `[0, 1, 2][1]`, int(1)}, + {"nested index", `[[1,2,3]][0][2]`, int(3)}, + + // Paren + {"paren", `(15)`, int(15)}, + + // Unary + {"unary not", `!true`, bool(false)}, + {"unary neg", `-15`, int(-15)}, + {"unary float", `-15.0`, float64(-15.0)}, + {"unary int64", fmt.Sprintf("%v", math.MaxInt64), math.MaxInt64}, + {"unary uint64", fmt.Sprintf("%v", uint64(math.MaxInt64)+1), uint64(math.MaxInt64) + 1}, + // math.MaxUint64 + 1 = 18446744073709551616 + {"unary float64 from overflowing uint", "18446744073709551616", float64(18446744073709551616)}, + } + + for _, tc := range tt { + b.Run(tc.name, func(b *testing.B) { + b.StopTimer() + expr, err := parser.ParseExpression(tc.input) + require.NoError(b, err) + + eval := vm.New(expr) + b.StartTimer() + + expectType := reflect.TypeOf(tc.expect) + + for i := 0; i < b.N; i++ { + vPtr := reflect.New(expectType).Interface() + _ = eval.Evaluate(scope, vPtr) + } + }) + } +} diff --git a/syntax/vm/vm_block_test.go b/syntax/vm/vm_block_test.go new file mode 100644 index 0000000000..ebc2ff0e6b --- /dev/null +++ b/syntax/vm/vm_block_test.go @@ -0,0 +1,802 @@ +package vm_test + +import ( + "fmt" + "math" + "reflect" + "testing" + + "github.com/grafana/river/ast" + "github.com/grafana/river/parser" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +// This file contains tests for decoding blocks. + +func TestVM_File(t *testing.T) { + type block struct { + String string `river:"string,attr"` + Number int `river:"number,attr,optional"` + } + type file struct { + SettingA int `river:"setting_a,attr"` + SettingB int `river:"setting_b,attr,optional"` + + Block block `river:"some_block,block,optional"` + } + + input := ` + setting_a = 15 + + some_block { + string = "Hello, world!" 
+ } + ` + + expect := file{ + SettingA: 15, + Block: block{ + String: "Hello, world!", + }, + } + + res, err := parser.ParseFile(t.Name(), []byte(input)) + require.NoError(t, err) + + eval := vm.New(res) + + var actual file + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) +} + +func TestVM_Block_Attributes(t *testing.T) { + t.Run("Decodes attributes", func(t *testing.T) { + type block struct { + Number int `river:"number,attr"` + String string `river:"string,attr"` + } + + input := `some_block { + number = 15 + string = "Hello, world!" + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 15, actual.Number) + require.Equal(t, "Hello, world!", actual.String) + }) + + t.Run("Fails if attribute used as block", func(t *testing.T) { + type block struct { + Number int `river:"number,attr"` + } + + input := `some_block { + number {} + }` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `2:4: "number" must be an attribute, but is used as a block`) + }) + + t.Run("Fails if required attributes are not present", func(t *testing.T) { + type block struct { + Number int `river:"number,attr"` + String string `river:"string,attr"` + } + + input := `some_block { + number = 15 + }` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `missing required attribute "string"`) + }) + + t.Run("Succeeds if optional attributes are not present", func(t *testing.T) { + type block struct { + Number int `river:"number,attr"` + String string `river:"string,attr,optional"` + } + + input := `some_block { + number = 15 + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 15, actual.Number) + require.Equal(t, "", actual.String) + }) + + t.Run("Fails if attribute is not defined in struct", 
func(t *testing.T) { + type block struct { + Number int `river:"number,attr"` + } + + input := `some_block { + number = 15 + invalid = "This attribute does not exist!" + }` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `3:4: unrecognized attribute name "invalid"`) + }) + + t.Run("Tests decoding into an interface", func(t *testing.T) { + type block struct { + Anything interface{} `river:"anything,attr"` + } + + tests := []struct { + testName string + val string + expectedValType reflect.Kind + }{ + {testName: "test_int_1", val: "15", expectedValType: reflect.Int}, + {testName: "test_int_2", val: "-15", expectedValType: reflect.Int}, + {testName: "test_int_3", val: fmt.Sprintf("%v", math.MaxInt64), expectedValType: reflect.Int}, + {testName: "test_int_4", val: fmt.Sprintf("%v", math.MinInt64), expectedValType: reflect.Int}, + {testName: "test_uint_1", val: fmt.Sprintf("%v", uint64(math.MaxInt64)+1), expectedValType: reflect.Uint64}, + {testName: "test_uint_2", val: fmt.Sprintf("%v", uint64(math.MaxUint64)), expectedValType: reflect.Uint64}, + {testName: "test_float_1", val: fmt.Sprintf("%v9", math.MinInt64), expectedValType: reflect.Float64}, + {testName: "test_float_2", val: fmt.Sprintf("%v9", uint64(math.MaxUint64)), expectedValType: reflect.Float64}, + {testName: "test_float_3", val: "16.0", expectedValType: reflect.Float64}, + } + + for _, tt := range tests { + t.Run(tt.testName, func(t *testing.T) { + input := fmt.Sprintf(`some_block { + anything = %s + }`, tt.val) + eval := vm.New(parseBlock(t, input)) + + var actual block + err := eval.Evaluate(nil, &actual) + require.NoError(t, err) + require.Equal(t, tt.expectedValType.String(), reflect.TypeOf(actual.Anything).Kind().String()) + }) + } + }) + + t.Run("Supports arbitrarily nested struct pointer fields", func(t *testing.T) { + type block struct { + NumberA int `river:"number_a,attr"` + NumberB *int `river:"number_b,attr"` + NumberC **int 
`river:"number_c,attr"` + NumberD ***int `river:"number_d,attr"` + } + + input := `some_block { + number_a = 15 + number_b = 20 + number_c = 25 + number_d = 30 + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 15, actual.NumberA) + require.Equal(t, 20, *actual.NumberB) + require.Equal(t, 25, **actual.NumberC) + require.Equal(t, 30, ***actual.NumberD) + }) + + t.Run("Supports squashed attributes", func(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr,optional"` + InnerField2 string `river:"inner_field_2,attr,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner InnerStruct `river:",squash"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + var ( + input = `some_block { + outer_field_1 = "value1" + outer_field_2 = "value2" + inner_field_1 = "value3" + inner_field_2 = "value4" + }` + + expect = OuterStruct{ + OuterField1: "value1", + Inner: InnerStruct{ + InnerField1: "value3", + InnerField2: "value4", + }, + OuterField2: "value2", + } + ) + eval := vm.New(parseBlock(t, input)) + + var actual OuterStruct + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) + + t.Run("Supports squashed attributes in pointers", func(t *testing.T) { + type InnerStruct struct { + InnerField1 string `river:"inner_field_1,attr,optional"` + InnerField2 string `river:"inner_field_2,attr,optional"` + } + + type OuterStruct struct { + OuterField1 string `river:"outer_field_1,attr,optional"` + Inner *InnerStruct `river:",squash"` + OuterField2 string `river:"outer_field_2,attr,optional"` + } + + var ( + input = `some_block { + outer_field_1 = "value1" + outer_field_2 = "value2" + inner_field_1 = "value3" + inner_field_2 = "value4" + }` + + expect = OuterStruct{ + OuterField1: "value1", + Inner: &InnerStruct{ + InnerField1: "value3", + InnerField2: "value4", + 
}, + OuterField2: "value2", + } + ) + eval := vm.New(parseBlock(t, input)) + + var actual OuterStruct + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) +} + +func TestVM_Block_Children_Blocks(t *testing.T) { + type childBlock struct { + Attr bool `river:"attr,attr"` + } + + t.Run("Decodes children blocks", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Child childBlock `river:"child.block,block"` + } + + input := `some_block { + value = 15 + + child.block { attr = true } + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 15, actual.Value) + require.True(t, actual.Child.Attr) + }) + + t.Run("Decodes multiple instances of children blocks", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Children []childBlock `river:"child.block,block"` + } + + input := `some_block { + value = 10 + + child.block { attr = true } + child.block { attr = false } + child.block { attr = true } + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 10, actual.Value) + require.Len(t, actual.Children, 3) + require.Equal(t, true, actual.Children[0].Attr) + require.Equal(t, false, actual.Children[1].Attr) + require.Equal(t, true, actual.Children[2].Attr) + }) + + t.Run("Decodes multiple instances of children blocks into an array", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Children [3]childBlock `river:"child.block,block"` + } + + input := `some_block { + value = 15 + + child.block { attr = true } + child.block { attr = false } + child.block { attr = true } + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 15, actual.Value) + require.Equal(t, true, actual.Children[0].Attr) + require.Equal(t, false, 
actual.Children[1].Attr) + require.Equal(t, true, actual.Children[2].Attr) + }) + + t.Run("Fails if block used as an attribute", func(t *testing.T) { + type block struct { + Child childBlock `river:"child,block"` + } + + input := `some_block { + child = 15 + }` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `2:4: "child" must be a block, but is used as an attribute`) + }) + + t.Run("Fails if required children blocks are not present", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Child childBlock `river:"child.block,block"` + } + + input := `some_block { + value = 15 + }` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `missing required block "child.block"`) + }) + + t.Run("Succeeds if optional children blocks are not present", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Child childBlock `river:"child.block,block,optional"` + } + + input := `some_block { + value = 15 + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, 15, actual.Value) + }) + + t.Run("Fails if child block is not defined in struct", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + } + + input := `some_block { + value = 15 + + child.block { attr = true } + }` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `4:4: unrecognized block name "child.block"`) + }) + + t.Run("Supports arbitrarily nested struct pointer fields", func(t *testing.T) { + type block struct { + BlockA childBlock `river:"block_a,block"` + BlockB *childBlock `river:"block_b,block"` + BlockC **childBlock `river:"block_c,block"` + BlockD ***childBlock `river:"block_d,block"` + } + + input := `some_block { + block_a { attr = true } + block_b { attr = false } + block_c { attr = true } + 
block_d { attr = false } + }` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, true, (actual.BlockA).Attr) + require.Equal(t, false, (*actual.BlockB).Attr) + require.Equal(t, true, (**actual.BlockC).Attr) + require.Equal(t, false, (***actual.BlockD).Attr) + }) + + t.Run("Supports squashed blocks", func(t *testing.T) { + type InnerStruct struct { + Inner1 childBlock `river:"inner_block_1,block"` + Inner2 childBlock `river:"inner_block_2,block"` + } + + type OuterStruct struct { + Outer1 childBlock `river:"outer_block_1,block"` + Inner InnerStruct `river:",squash"` + Outer2 childBlock `river:"outer_block_2,block"` + } + + var ( + input = `some_block { + outer_block_1 { attr = true } + outer_block_2 { attr = false } + inner_block_1 { attr = true } + inner_block_2 { attr = false } + }` + + expect = OuterStruct{ + Outer1: childBlock{Attr: true}, + Outer2: childBlock{Attr: false}, + Inner: InnerStruct{ + Inner1: childBlock{Attr: true}, + Inner2: childBlock{Attr: false}, + }, + } + ) + eval := vm.New(parseBlock(t, input)) + + var actual OuterStruct + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) + + t.Run("Supports squashed blocks in pointers", func(t *testing.T) { + type InnerStruct struct { + Inner1 *childBlock `river:"inner_block_1,block"` + Inner2 *childBlock `river:"inner_block_2,block"` + } + + type OuterStruct struct { + Outer1 childBlock `river:"outer_block_1,block"` + Inner *InnerStruct `river:",squash"` + Outer2 childBlock `river:"outer_block_2,block"` + } + + var ( + input = `some_block { + outer_block_1 { attr = true } + outer_block_2 { attr = false } + inner_block_1 { attr = true } + inner_block_2 { attr = false } + }` + + expect = OuterStruct{ + Outer1: childBlock{Attr: true}, + Outer2: childBlock{Attr: false}, + Inner: &InnerStruct{ + Inner1: &childBlock{Attr: true}, + Inner2: &childBlock{Attr: false}, + }, + } + ) + eval := 
vm.New(parseBlock(t, input)) + + var actual OuterStruct + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) + + // TODO(rfratto): decode all blocks into a []*ast.BlockStmt field. +} + +func TestVM_Block_Enum_Block(t *testing.T) { + type childBlock struct { + Attr int `river:"attr,attr"` + } + + type enumBlock struct { + BlockA *childBlock `river:"a,block,optional"` + BlockB *childBlock `river:"b,block,optional"` + BlockC *childBlock `river:"c,block,optional"` + BlockD *childBlock `river:"d,block,optional"` + } + + t.Run("Decodes enum blocks", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Blocks []*enumBlock `river:"child,enum,optional"` + } + + input := `some_block { + value = 15 + + child.a { attr = 1 } + }` + eval := vm.New(parseBlock(t, input)) + + expect := block{ + Value: 15, + Blocks: []*enumBlock{ + {BlockA: &childBlock{Attr: 1}}, + }, + } + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) + + t.Run("Decodes multiple enum blocks", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Blocks []*enumBlock `river:"child,enum,optional"` + } + + input := `some_block { + value = 15 + + child.b { attr = 1 } + child.a { attr = 2 } + child.c { attr = 3 } + }` + eval := vm.New(parseBlock(t, input)) + + expect := block{ + Value: 15, + Blocks: []*enumBlock{ + {BlockB: &childBlock{Attr: 1}}, + {BlockA: &childBlock{Attr: 2}}, + {BlockC: &childBlock{Attr: 3}}, + }, + } + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) + + t.Run("Decodes multiple enum blocks with repeating blocks", func(t *testing.T) { + type block struct { + Value int `river:"value,attr"` + Blocks []*enumBlock `river:"child,enum,optional"` + } + + input := `some_block { + value = 15 + + child.a { attr = 1 } + child.b { attr = 2 } + child.c { attr = 3 } + child.a { attr = 4 } + }` + eval := 
vm.New(parseBlock(t, input)) + + expect := block{ + Value: 15, + Blocks: []*enumBlock{ + {BlockA: &childBlock{Attr: 1}}, + {BlockB: &childBlock{Attr: 2}}, + {BlockC: &childBlock{Attr: 3}}, + {BlockA: &childBlock{Attr: 4}}, + }, + } + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, expect, actual) + }) +} + +func TestVM_Block_Label(t *testing.T) { + t.Run("Decodes label into string field", func(t *testing.T) { + type block struct { + Label string `river:",label"` + } + + input := `some_block "label_value_1" {}` + eval := vm.New(parseBlock(t, input)) + + var actual block + require.NoError(t, eval.Evaluate(nil, &actual)) + require.Equal(t, "label_value_1", actual.Label) + }) + + t.Run("Struct must have label field if block is labeled", func(t *testing.T) { + type block struct{} + + input := `some_block "label_value_2" {}` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `1:1: block "some_block" does not support specifying labels`) + }) + + t.Run("Block must have label if struct accepts label", func(t *testing.T) { + type block struct { + Label string `river:",label"` + } + + input := `some_block {}` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `1:1: block "some_block" requires non-empty label`) + }) + + t.Run("Block must have non-empty label if struct accepts label", func(t *testing.T) { + type block struct { + Label string `river:",label"` + } + + input := `some_block "" {}` + eval := vm.New(parseBlock(t, input)) + + err := eval.Evaluate(nil, &block{}) + require.EqualError(t, err, `1:1: block "some_block" requires non-empty label`) + }) +} + +func TestVM_Block_Unmarshaler(t *testing.T) { + type OuterBlock struct { + FieldA string `river:"field_a,attr"` + Settings Setting `river:"some.settings,block"` + } + + input := ` + field_a = "foobar" + some.settings { + field_a = "fizzbuzz" + field_b = "helloworld" + } 
+ ` + + file, err := parser.ParseFile(t.Name(), []byte(input)) + require.NoError(t, err) + + eval := vm.New(file) + + var actual OuterBlock + require.NoError(t, eval.Evaluate(nil, &actual)) + require.True(t, actual.Settings.UnmarshalCalled, "UnmarshalRiver did not get invoked") + require.True(t, actual.Settings.DefaultCalled, "SetToDefault did not get invoked") + require.True(t, actual.Settings.ValidateCalled, "Validate did not get invoked") +} + +func TestVM_Block_UnmarshalToMap(t *testing.T) { + type OuterBlock struct { + Settings map[string]interface{} `river:"some.settings,block"` + } + + tt := []struct { + name string + input string + expect OuterBlock + expectError string + }{ + { + name: "decodes successfully", + input: ` + some.settings { + field_a = 12345 + field_b = "helloworld" + } + `, + expect: OuterBlock{ + Settings: map[string]interface{}{ + "field_a": 12345, + "field_b": "helloworld", + }, + }, + }, + { + name: "rejects labeled blocks", + input: ` + some.settings "foo" { + field_a = 12345 + } + `, + expectError: `block "some.settings" requires non-empty label`, + }, + + { + name: "rejects nested maps", + input: ` + some.settings { + inner_map { + field_a = 12345 + } + } + `, + expectError: "nested blocks not supported here", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + file, err := parser.ParseFile(t.Name(), []byte(tc.input)) + require.NoError(t, err) + + eval := vm.New(file) + + var actual OuterBlock + err = eval.Evaluate(nil, &actual) + + if tc.expectError == "" { + require.NoError(t, err) + require.Equal(t, tc.expect, actual) + } else { + require.ErrorContains(t, err, tc.expectError) + } + }) + } +} + +func TestVM_Block_UnmarshalToAny(t *testing.T) { + type OuterBlock struct { + Settings any `river:"some.settings,block"` + } + + input := ` + some.settings { + field_a = 12345 + field_b = "helloworld" + } + ` + + file, err := parser.ParseFile(t.Name(), []byte(input)) + require.NoError(t, err) + + eval := vm.New(file) 
+ + var actual OuterBlock + require.NoError(t, eval.Evaluate(nil, &actual)) + + expect := map[string]interface{}{ + "field_a": 12345, + "field_b": "helloworld", + } + require.Equal(t, expect, actual.Settings) +} + +type Setting struct { + FieldA string `river:"field_a,attr"` + FieldB string `river:"field_b,attr"` + + UnmarshalCalled bool + DefaultCalled bool + ValidateCalled bool +} + +func (s *Setting) UnmarshalRiver(f func(interface{}) error) error { + s.UnmarshalCalled = true + return f((*settingUnmarshalTarget)(s)) +} + +type settingUnmarshalTarget Setting + +func (s *settingUnmarshalTarget) SetToDefault() { + s.DefaultCalled = true +} + +func (s *settingUnmarshalTarget) Validate() error { + s.ValidateCalled = true + return nil +} + +func parseBlock(t *testing.T, input string) *ast.BlockStmt { + t.Helper() + + res, err := parser.ParseFile("", []byte(input)) + require.NoError(t, err) + require.Len(t, res.Body, 1) + + stmt, ok := res.Body[0].(*ast.BlockStmt) + require.True(t, ok, "Expected stmt to be a ast.BlockStmt, got %T", res.Body[0]) + return stmt +} diff --git a/syntax/vm/vm_errors_test.go b/syntax/vm/vm_errors_test.go new file mode 100644 index 0000000000..87acdd7b1b --- /dev/null +++ b/syntax/vm/vm_errors_test.go @@ -0,0 +1,80 @@ +package vm_test + +import ( + "testing" + + "github.com/grafana/river/parser" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +func TestVM_ExprErrors(t *testing.T) { + type Target struct { + Key struct { + Object struct { + Field1 []int `river:"field1,attr"` + } `river:"object,attr"` + } `river:"key,attr"` + } + + tt := []struct { + name string + input string + into interface{} + scope *vm.Scope + expect string + }{ + { + name: "basic wrong type", + input: `key = true`, + into: &Target{}, + expect: "test:1:7: true should be object, got bool", + }, + { + name: "deeply nested literal", + input: ` + key = { + object = { + field1 = [15, 30, "Hello, world!"], + }, + } + `, + into: &Target{}, + expect: 
`test:4:25: "Hello, world!" should be number, got string`, + }, + { + name: "deeply nested indirect", + input: `key = key_value`, + into: &Target{}, + scope: &vm.Scope{ + Variables: map[string]interface{}{ + "key_value": map[string]interface{}{ + "object": map[string]interface{}{ + "field1": []interface{}{15, 30, "Hello, world!"}, + }, + }, + }, + }, + expect: `test:1:7: key_value.object.field1[2] should be number, got string`, + }, + { + name: "complex expr", + input: `key = [0, 1, 2]`, + into: &struct { + Key string `river:"key,attr"` + }{}, + expect: `test:1:7: [0, 1, 2] should be string, got array`, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + res, err := parser.ParseFile("test", []byte(tc.input)) + require.NoError(t, err) + + eval := vm.New(res) + err = eval.Evaluate(tc.scope, tc.into) + require.EqualError(t, err, tc.expect) + }) + } +} diff --git a/syntax/vm/vm_stdlib_test.go b/syntax/vm/vm_stdlib_test.go new file mode 100644 index 0000000000..f395f199e5 --- /dev/null +++ b/syntax/vm/vm_stdlib_test.go @@ -0,0 +1,232 @@ +package vm_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/grafana/river/internal/value" + "github.com/grafana/river/parser" + "github.com/grafana/river/rivertypes" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +func TestVM_Stdlib(t *testing.T) { + t.Setenv("TEST_VAR", "Hello!") + + tt := []struct { + name string + input string + expect interface{} + }{ + {"env", `env("TEST_VAR")`, string("Hello!")}, + {"concat", `concat([true, "foo"], [], [false, 1])`, []interface{}{true, "foo", false, 1}}, + {"json_decode object", `json_decode("{\"foo\": \"bar\"}")`, map[string]interface{}{"foo": "bar"}}, + {"json_decode array", `json_decode("[0, 1, 2]")`, []interface{}{float64(0), float64(1), float64(2)}}, + {"json_decode nil field", `json_decode("{\"foo\": null}")`, map[string]interface{}{"foo": nil}}, + {"json_decode nil array element", `json_decode("[0, null]")`, 
[]interface{}{float64(0), nil}}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + rv := reflect.New(reflect.TypeOf(tc.expect)) + require.NoError(t, eval.Evaluate(nil, rv.Interface())) + require.Equal(t, tc.expect, rv.Elem().Interface()) + }) + } +} + +func TestStdlibCoalesce(t *testing.T) { + t.Setenv("TEST_VAR2", "Hello!") + + tt := []struct { + name string + input string + expect interface{} + }{ + {"coalesce()", `coalesce()`, value.Null}, + {"coalesce(string)", `coalesce("Hello!")`, string("Hello!")}, + {"coalesce(string, string)", `coalesce(env("TEST_VAR2"), "World!")`, string("Hello!")}, + {"(string, string) with fallback", `coalesce(env("NON_DEFINED"), "World!")`, string("World!")}, + {"coalesce(list, list)", `coalesce([], ["fallback"])`, []string{"fallback"}}, + {"coalesce(list, list) with fallback", `coalesce(concat(["item"]), ["fallback"])`, []string{"item"}}, + {"coalesce(int, int, int)", `coalesce(0, 1, 2)`, 1}, + {"coalesce(bool, int, int)", `coalesce(false, 1, 2)`, 1}, + {"coalesce(bool, bool)", `coalesce(false, true)`, true}, + {"coalesce(list, bool)", `coalesce(json_decode("[]"), true)`, true}, + {"coalesce(object, true) and return true", `coalesce(json_decode("{}"), true)`, true}, + {"coalesce(object, false) and return false", `coalesce(json_decode("{}"), false)`, false}, + {"coalesce(list, nil)", `coalesce([],null)`, value.Null}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + rv := reflect.New(reflect.TypeOf(tc.expect)) + require.NoError(t, eval.Evaluate(nil, rv.Interface())) + require.Equal(t, tc.expect, rv.Elem().Interface()) + }) + } +} + +func TestStdlibJsonPath(t *testing.T) { + tt := []struct { + name string + input string + expect interface{} + }{ + {"json_path with simple json", 
`json_path("{\"a\": \"b\"}", ".a")`, []string{"b"}}, + {"json_path with simple json without results", `json_path("{\"a\": \"b\"}", ".nonexists")`, []string{}}, + {"json_path with json array", `json_path("[{\"name\": \"Department\",\"value\": \"IT\"},{\"name\":\"ReferenceNumber\",\"value\":\"123456\"},{\"name\":\"TestStatus\",\"value\":\"Pending\"}]", "[?(@.name == \"Department\")].value")`, []string{"IT"}}, + {"json_path with simple json and return first", `json_path("{\"a\": \"b\"}", ".a")[0]`, "b"}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + rv := reflect.New(reflect.TypeOf(tc.expect)) + require.NoError(t, eval.Evaluate(nil, rv.Interface())) + require.Equal(t, tc.expect, rv.Elem().Interface()) + }) + } +} + +func TestStdlib_Nonsensitive(t *testing.T) { + scope := &vm.Scope{ + Variables: map[string]any{ + "secret": rivertypes.Secret("foo"), + "optionalSecret": rivertypes.OptionalSecret{Value: "bar"}, + }, + } + + tt := []struct { + name string + input string + expect interface{} + }{ + {"secret to string", `nonsensitive(secret)`, string("foo")}, + {"optional secret to string", `nonsensitive(optionalSecret)`, string("bar")}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + rv := reflect.New(reflect.TypeOf(tc.expect)) + require.NoError(t, eval.Evaluate(scope, rv.Interface())) + require.Equal(t, tc.expect, rv.Elem().Interface()) + }) + } +} +func TestStdlib_StringFunc(t *testing.T) { + scope := &vm.Scope{ + Variables: map[string]any{}, + } + + tt := []struct { + name string + input string + expect interface{} + }{ + {"to_lower", `to_lower("String")`, "string"}, + {"to_upper", `to_upper("string")`, "STRING"}, + {"trimspace", `trim_space(" string \n\n")`, "string"}, + {"trimspace+to_upper+trim", 
`to_lower(to_upper(trim_space(" String ")))`, "string"}, + {"split", `split("/aaa/bbb/ccc/ddd", "/")`, []string{"", "aaa", "bbb", "ccc", "ddd"}}, + {"split+index", `split("/aaa/bbb/ccc/ddd", "/")[0]`, ""}, + {"join+split", `join(split("/aaa/bbb/ccc/ddd", "/"), "/")`, "/aaa/bbb/ccc/ddd"}, + {"join", `join(["foo", "bar", "baz"], ", ")`, "foo, bar, baz"}, + {"join w/ int", `join([0, 0, 1], ", ")`, "0, 0, 1"}, + {"format", `format("Hello %s", "World")`, "Hello World"}, + {"format+int", `format("%#v", 1)`, "1"}, + {"format+bool", `format("%#v", true)`, "true"}, + {"format+quote", `format("%q", "hello")`, `"hello"`}, + {"replace", `replace("Hello World", " World", "!")`, "Hello!"}, + {"trim", `trim("?!hello?!", "!?")`, "hello"}, + {"trim2", `trim(" hello! world.! ", "! ")`, "hello! world."}, + {"trim_prefix", `trim_prefix("helloworld", "hello")`, "world"}, + {"trim_suffix", `trim_suffix("helloworld", "world")`, "hello"}, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + rv := reflect.New(reflect.TypeOf(tc.expect)) + require.NoError(t, eval.Evaluate(scope, rv.Interface())) + require.Equal(t, tc.expect, rv.Elem().Interface()) + }) + } +} + +func BenchmarkConcat(b *testing.B) { + // There's a bit of setup work to do here: we want to create a scope holding + // a slice of the Person type, which has a fair amount of data in it. + // + // We then want to pass it through concat. + // + // If the code path is fully optimized, there will be no intermediate + // translations to interface{}. 
+ type Person struct { + Name string `river:"name,attr"` + Attrs map[string]string `river:"attrs,attr"` + } + type Body struct { + Values []Person `river:"values,attr"` + } + + in := `values = concat(values_ref)` + f, err := parser.ParseFile("", []byte(in)) + require.NoError(b, err) + + eval := vm.New(f) + + valuesRef := make([]Person, 0, 20) + for i := 0; i < 20; i++ { + data := make(map[string]string, 20) + for j := 0; j < 20; j++ { + var ( + key = fmt.Sprintf("key_%d", i+1) + value = fmt.Sprintf("value_%d", i+1) + ) + data[key] = value + } + valuesRef = append(valuesRef, Person{ + Name: "Test Person", + Attrs: data, + }) + } + scope := &vm.Scope{ + Variables: map[string]interface{}{ + "values_ref": valuesRef, + }, + } + + // Reset timer before running the actual test + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var b Body + _ = eval.Evaluate(scope, &b) + } +} diff --git a/syntax/vm/vm_test.go b/syntax/vm/vm_test.go new file mode 100644 index 0000000000..5591b08d88 --- /dev/null +++ b/syntax/vm/vm_test.go @@ -0,0 +1,277 @@ +package vm_test + +import ( + "reflect" + "strings" + "testing" + "unicode" + + "github.com/grafana/river/parser" + "github.com/grafana/river/scanner" + "github.com/grafana/river/token" + "github.com/grafana/river/vm" + "github.com/stretchr/testify/require" +) + +func TestVM_Evaluate_Literals(t *testing.T) { + tt := map[string]struct { + input string + expect interface{} + }{ + "number to int": {`12`, int(12)}, + "number to int8": {`13`, int8(13)}, + "number to int16": {`14`, int16(14)}, + "number to int32": {`15`, int32(15)}, + "number to int64": {`16`, int64(16)}, + "number to uint": {`17`, uint(17)}, + "number to uint8": {`18`, uint8(18)}, + "number to uint16": {`19`, uint16(19)}, + "number to uint32": {`20`, uint32(20)}, + "number to uint64": {`21`, uint64(21)}, + "number to float32": {`22`, float32(22)}, + "number to float64": {`23`, float64(23)}, + "number to string": {`24`, string("24")}, + + "float to float32": {`3.2`, 
float32(3.2)}, + "float to float64": {`3.5`, float64(3.5)}, + "float to string": {`3.9`, string("3.9")}, + + "float with dot to float32": {`.2`, float32(0.2)}, + "float with dot to float64": {`.5`, float64(0.5)}, + "float with dot to string": {`.9`, string("0.9")}, + + "string to string": {`"Hello, world!"`, string("Hello, world!")}, + "string to int": {`"12"`, int(12)}, + "string to float64": {`"12"`, float64(12)}, + } + + for name, tc := range tt { + t.Run(name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + vPtr := reflect.New(reflect.TypeOf(tc.expect)).Interface() + require.NoError(t, eval.Evaluate(nil, vPtr)) + + actual := reflect.ValueOf(vPtr).Elem().Interface() + require.Equal(t, tc.expect, actual) + }) + } +} + +func TestVM_Evaluate(t *testing.T) { + // Shared scope across all tests below + scope := &vm.Scope{ + Variables: map[string]interface{}{ + "foobar": int(42), + }, + } + + tt := []struct { + input string + expect interface{} + }{ + // Binops + {`true || false`, bool(true)}, + {`false || false`, bool(false)}, + {`true && false`, bool(false)}, + {`true && true`, bool(true)}, + {`3 == 5`, bool(false)}, + {`3 == 3`, bool(true)}, + {`3 != 5`, bool(true)}, + {`3 < 5`, bool(true)}, + {`3 <= 5`, bool(true)}, + {`3 > 5`, bool(false)}, + {`3 >= 5`, bool(false)}, + {`3 + 5`, int(8)}, + {`3 - 5`, int(-2)}, + {`3 * 5`, int(15)}, + {`3.0 / 5.0`, float64(0.6)}, + {`5 % 3`, int(2)}, + {`3 ^ 5`, int(243)}, + {`3 + 5 * 2`, int(13)}, // Chain multiple binops + {`42.0^-2`, float64(0.0005668934240362812)}, + + // Identifier + {`foobar`, int(42)}, + + // Arrays + {`[]`, []int{}}, + {`[0, 1, 2]`, []int{0, 1, 2}}, + {`[true, false]`, []bool{true, false}}, + + // Objects + {`{ a = 5, b = 10 }`, map[string]int{"a": 5, "b": 10}}, + { + input: `{ + name = "John Doe", + age = 42, + }`, + expect: struct { + Name string `river:"name,attr"` + Age int `river:"age,attr"` + Country string 
`river:"country,attr,optional"` + }{ + Name: "John Doe", + Age: 42, + }, + }, + + // Access + {`{ a = 15 }.a`, int(15)}, + {`{ a = { b = 12 } }.a.b`, int(12)}, + {`{}["foo"]`, nil}, + + // Indexing + {`[0, 1, 2][1]`, int(1)}, + {`[[1,2,3]][0][2]`, int(3)}, + {`[true, false][0]`, bool(true)}, + + // Paren + {`(15)`, int(15)}, + + // Unary + {`!true`, bool(false)}, + {`!false`, bool(true)}, + {`-15`, int(-15)}, + } + + for _, tc := range tt { + name := trimWhitespace(tc.input) + + t.Run(name, func(t *testing.T) { + expr, err := parser.ParseExpression(tc.input) + require.NoError(t, err) + + eval := vm.New(expr) + + var vPtr any + if tc.expect != nil { + vPtr = reflect.New(reflect.TypeOf(tc.expect)).Interface() + } else { + // Create a new any pointer. + vPtr = reflect.New(reflect.TypeOf((*any)(nil)).Elem()).Interface() + } + + require.NoError(t, eval.Evaluate(scope, vPtr)) + + actual := reflect.ValueOf(vPtr).Elem().Interface() + require.Equal(t, tc.expect, actual) + }) + } +} + +func TestVM_Evaluate_Null(t *testing.T) { + expr, err := parser.ParseExpression("null") + require.NoError(t, err) + + eval := vm.New(expr) + + var v interface{} + require.NoError(t, eval.Evaluate(nil, &v)) + require.Nil(t, v) +} + +func TestVM_Evaluate_IdentifierExpr(t *testing.T) { + t.Run("Valid lookup", func(t *testing.T) { + scope := &vm.Scope{ + Variables: map[string]interface{}{ + "foobar": 15, + }, + } + + expr, err := parser.ParseExpression(`foobar`) + require.NoError(t, err) + + eval := vm.New(expr) + + var actual int + require.NoError(t, eval.Evaluate(scope, &actual)) + require.Equal(t, 15, actual) + }) + + t.Run("Invalid lookup", func(t *testing.T) { + expr, err := parser.ParseExpression(`foobar`) + require.NoError(t, err) + + eval := vm.New(expr) + + var v interface{} + err = eval.Evaluate(nil, &v) + require.EqualError(t, err, `1:1: identifier "foobar" does not exist`) + }) +} + +func TestVM_Evaluate_AccessExpr(t *testing.T) { + t.Run("Lookup optional field", func(t *testing.T) { + 
type Person struct { + Name string `river:"name,attr,optional"` + } + + scope := &vm.Scope{ + Variables: map[string]interface{}{ + "person": Person{}, + }, + } + + expr, err := parser.ParseExpression(`person.name`) + require.NoError(t, err) + + eval := vm.New(expr) + + var actual string + require.NoError(t, eval.Evaluate(scope, &actual)) + require.Equal(t, "", actual) + }) + + t.Run("Invalid lookup 1", func(t *testing.T) { + expr, err := parser.ParseExpression(`{ a = 15 }.b`) + require.NoError(t, err) + + eval := vm.New(expr) + + var v interface{} + err = eval.Evaluate(nil, &v) + require.EqualError(t, err, `1:12: field "b" does not exist`) + }) + + t.Run("Invalid lookup 2", func(t *testing.T) { + _, err := parser.ParseExpression(`{ a = 15 }.7`) + require.EqualError(t, err, `1:11: expected TERMINATOR, got FLOAT`) + }) + + t.Run("Invalid lookup 3", func(t *testing.T) { + _, err := parser.ParseExpression(`{ a = { b = 12 }.7 }.a.b`) + require.EqualError(t, err, `1:17: missing ',' in field list`) + }) + + t.Run("Invalid lookup 4", func(t *testing.T) { + _, err := parser.ParseExpression(`{ a = { b = 12 } }.a.b.7`) + require.EqualError(t, err, `1:23: expected TERMINATOR, got FLOAT`) + }) +} + +func trimWhitespace(in string) string { + f := token.NewFile("") + + s := scanner.New(f, []byte(in), nil, 0) + + var out strings.Builder + + for { + _, tok, lit := s.Scan() + if tok == token.EOF { + break + } + + if lit != "" { + _, _ = out.WriteString(lit) + } else { + _, _ = out.WriteString(tok.String()) + } + } + + return strings.TrimFunc(out.String(), unicode.IsSpace) +}