From 30c8e4a7bf127b2a743ef14dbe1ffdfa012b5a27 Mon Sep 17 00:00:00 2001 From: Ludovic Cleroux Date: Fri, 13 May 2022 11:18:01 +0200 Subject: [PATCH] Database Replication POC --- .prow/push-core-backend-test.yaml | 2 +- cmd/serve_db.go | 37 + go.mod | 73 +- go.sum | 12 +- pkg/server/data/api/changestream.go | 33 + pkg/server/data/api/const.go | 13 + pkg/server/data/api/errors.go | 102 +++ pkg/server/data/api/i.go | 13 + pkg/server/data/api/interface.go | 83 +++ pkg/server/data/api/record.go | 169 +++++ pkg/server/data/api/record_test.go | 172 +++++ pkg/server/data/api/revision.go | 109 +++ pkg/server/data/api/revision_test.go | 106 +++ pkg/server/data/api/utils.go | 51 ++ pkg/server/data/api/utils_test.go | 38 + pkg/server/data/api/valuekind_string.go | 27 + pkg/server/data/api/values.go | 734 +++++++++++++++++++ pkg/server/data/client/client.go | 180 +++++ pkg/server/data/client/client_http.go | 69 ++ pkg/server/data/engine/.gitignore | 2 + pkg/server/data/engine/engine.go | 850 ++++++++++++++++++++++ pkg/server/data/engine/engine_test.go | 279 +++++++ pkg/server/data/engine/pkg.go | 47 ++ pkg/server/data/engine/reconciler.go | 208 ++++++ pkg/server/data/engine/reconciler_test.go | 100 +++ pkg/server/data/engine/utils.go | 61 ++ pkg/server/data/engine/validation.go | 64 ++ pkg/server/data/engine/zzz_test.go | 97 +++ pkg/server/data/handler/get_changes.go | 41 ++ pkg/server/data/handler/get_row.go | 45 ++ pkg/server/data/handler/handler.go | 118 +++ pkg/server/data/handler/put_row.go | 46 ++ pkg/server/data/handler/put_table.go | 43 ++ pkg/server/data/server.go | 69 ++ pkg/server/data/server_test.go | 148 ++++ pkg/server/data/test/assertions.go | 9 + pkg/server/data/test/db_recorder.go | 50 ++ pkg/server/data/test/db_utils.go | 44 ++ pkg/server/data/test/mock_clock.go | 32 + pkg/server/data/test/mock_uuid.go | 10 + pkg/server/data/utils/clock.go | 9 + pkg/server/data/utils/db_result_reader.go | 83 +++ pkg/server/data/utils/md5_generator.go | 32 + pkg/server/data/utils/transaction.go | 56 ++ pkg/server/data/utils/uuid_generator.go | 49 ++ pkg/server/generic/server.go | 18 +- pkg/server/options/options.go | 1 + 47 files changed, 4616 insertions(+), 18 deletions(-) create mode 100644 cmd/serve_db.go create mode 100644 pkg/server/data/api/changestream.go create mode 100644 pkg/server/data/api/const.go create mode 100644 pkg/server/data/api/errors.go create mode 100644 pkg/server/data/api/i.go create mode 100644 pkg/server/data/api/interface.go create mode 100644 pkg/server/data/api/record.go create mode 100644 pkg/server/data/api/record_test.go create mode 100644 pkg/server/data/api/revision.go create mode 100644 pkg/server/data/api/revision_test.go create mode 100644 pkg/server/data/api/utils.go create mode 100644 pkg/server/data/api/utils_test.go create mode 100644 pkg/server/data/api/valuekind_string.go create mode 100644 pkg/server/data/api/values.go create mode 100644 pkg/server/data/client/client.go create mode 100644 pkg/server/data/client/client_http.go create mode 100644 pkg/server/data/engine/.gitignore create mode 100644 pkg/server/data/engine/engine.go create mode 100644 pkg/server/data/engine/engine_test.go create mode 100644 pkg/server/data/engine/pkg.go create mode 100644 pkg/server/data/engine/reconciler.go create mode 100644 pkg/server/data/engine/reconciler_test.go create mode 100644 pkg/server/data/engine/utils.go create mode 100644 pkg/server/data/engine/validation.go create mode 100644 pkg/server/data/engine/zzz_test.go create mode 100644 pkg/server/data/handler/get_changes.go create mode 100644 pkg/server/data/handler/get_row.go create mode 100644 pkg/server/data/handler/handler.go create mode 100644 pkg/server/data/handler/put_row.go create mode 100644 pkg/server/data/handler/put_table.go create mode 100644 pkg/server/data/server.go create mode 100644 pkg/server/data/server_test.go create mode 100644 pkg/server/data/test/assertions.go create mode 100644 pkg/server/data/test/db_recorder.go create mode 100644 pkg/server/data/test/db_utils.go create mode 100644 pkg/server/data/test/mock_clock.go create mode 100644 pkg/server/data/test/mock_uuid.go create mode 100644 pkg/server/data/utils/clock.go create mode 100644 pkg/server/data/utils/db_result_reader.go create mode 100644 pkg/server/data/utils/md5_generator.go create mode 100644 pkg/server/data/utils/transaction.go create mode 100644 pkg/server/data/utils/uuid_generator.go diff --git a/.prow/push-core-backend-test.yaml b/.prow/push-core-backend-test.yaml index 578e4a8e4..6c2899c60 100644 --- a/.prow/push-core-backend-test.yaml +++ b/.prow/push-core-backend-test.yaml @@ -4,7 +4,7 @@ presubmits: decorate: true spec: containers: - - image: golang:1.16 + - image: golang:1.18 command: [ "bash", "-c" ] args: - > diff --git a/cmd/serve_db.go b/cmd/serve_db.go new file mode 100644 index 000000000..6df92c175 --- /dev/null +++ b/cmd/serve_db.go @@ -0,0 +1,37 @@ +package cmd + +import ( + "context" + + "github.com/nrc-no/core/pkg/server/data" + "github.com/spf13/cobra" +) + +// serveDataCmd represents the data command +var serveDataCmd = &cobra.Command{ + Use: "data", + Short: "starts the data server", + RunE: func(cmd *cobra.Command, args []string) error { + if err := serveDb(ctx, + data.Options{ + ServerOptions: coreOptions.Serve.Login, + }); err != nil { + return err + } + <-doneSignal + return nil + }, +} + +func init() { + serveCmd.AddCommand(serveDataCmd) +} + +func serveDb(ctx context.Context, options data.Options) error { + server, err := data.NewServer(options) + if err != nil { + return err + } + server.Start(ctx) + return nil +} diff --git a/go.mod b/go.mod index f4078e5b9..622040832 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,9 @@ module github.com/nrc-no/core -go 1.16 +go 1.18 require ( github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff - github.com/cheekybits/is v0.0.0-20150225183255-68e9c0620927 // indirect github.com/coreos/go-oidc/v3 v3.0.0 github.com/dustinkirkland/golang-petname v0.0.0-20191129215211-8e5a1ed0cff0 github.com/emicklei/go-restful-openapi/v2 v2.6.0 @@ -20,13 +19,13 @@ require ( github.com/gorilla/securecookie v1.1.1 github.com/gorilla/sessions v1.2.1 github.com/jackc/pgconn v1.10.1 + github.com/jmoiron/sqlx v1.3.5 github.com/juju/fslock v0.0.0-20160525022230-4d5c94c67b4b github.com/lib/pq v1.10.4 github.com/lithammer/shortuuid/v3 v3.0.7 github.com/looplab/fsm v0.3.0 github.com/manifoldco/promptui v0.9.0 - github.com/matryer/try v0.0.0-20161228173917-9ac251b645a2 // indirect - github.com/mattn/go-sqlite3 v1.14.9 + github.com/mattn/go-sqlite3 v1.14.12 github.com/ory/hydra-client-go v1.10.6 github.com/rs/cors v1.8.0 github.com/satori/go.uuid v1.2.0 @@ -34,16 +33,76 @@ require ( github.com/spf13/cobra v1.2.1 github.com/spf13/viper v1.9.0 github.com/stretchr/testify v1.7.0 - go.mongodb.org/mongo-driver v1.5.3 // indirect go.uber.org/zap v1.19.1 golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c - golang.org/x/sys v0.0.0-20211023085530-d6a326fbbf70 // indirect golang.org/x/tools v0.1.8 gopkg.in/matryer/try.v1 v1.0.0-20150601225556-312d2599e12e - gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b gorm.io/driver/postgres v1.2.1 gorm.io/driver/sqlite v1.2.3 gorm.io/gorm v1.22.4 ) + +require ( + github.com/PuerkitoBio/purell v1.1.1 // indirect + github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect + github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef // indirect + github.com/cheekybits/is v0.0.0-20150225183255-68e9c0620927 // indirect + github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-openapi/analysis v0.20.0 // indirect + github.com/go-openapi/errors v0.20.1 // indirect + github.com/go-openapi/jsonpointer v0.19.5 // indirect + github.com/go-openapi/jsonreference v0.19.5 // indirect + github.com/go-openapi/loads v0.20.2 // indirect + github.com/go-openapi/spec v0.20.3 // indirect + github.com/go-openapi/swag v0.19.15 // indirect + github.com/go-openapi/validate v0.20.2 // indirect + github.com/go-stack/stack v1.8.0 // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/google/uuid v1.2.0 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/inconshreveable/mousetrap v1.0.0 // indirect + github.com/jackc/chunkreader/v2 v2.0.1 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgproto3/v2 v2.1.1 // indirect + github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect + github.com/jackc/pgtype v1.8.1 // indirect + github.com/jackc/pgx/v4 v4.13.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.3 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/json-iterator/go v1.1.11 // indirect + github.com/magiconair/properties v1.8.5 // indirect + github.com/mailru/easyjson v0.7.6 // indirect + github.com/matryer/try v0.0.0-20161228173917-9ac251b645a2 // indirect + github.com/mitchellh/mapstructure v1.4.2 // indirect + github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect + github.com/modern-go/reflect2 v1.0.1 // indirect + github.com/oklog/ulid v1.3.1 // indirect + github.com/opentracing/opentracing-go v1.2.0 // indirect + github.com/pelletier/go-toml v1.9.4 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/afero v1.6.0 // indirect + github.com/spf13/cast v1.4.1 // indirect + github.com/spf13/jwalterweatherman v1.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/subosito/gotenv v1.2.0 // indirect + github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect + go.mongodb.org/mongo-driver v1.5.3 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.6.0 // indirect + golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 // indirect + golang.org/x/mod v0.5.1 // indirect + golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f // indirect + golang.org/x/sys v0.0.0-20211023085530-d6a326fbbf70 // indirect + golang.org/x/text v0.3.7 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + google.golang.org/appengine v1.6.7 // indirect + google.golang.org/protobuf v1.27.1 // indirect + gopkg.in/ini.v1 v1.63.2 // indirect + gopkg.in/square/go-jose.v2 v2.6.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect +) diff --git a/go.sum b/go.sum index a68f5ce24..615f8f2b1 100644 --- a/go.sum +++ b/go.sum @@ -237,6 +237,8 @@ github.com/go-openapi/validate v0.20.2/go.mod h1:e7OJoKNgd0twXZwIn0A43tHbvIcr/rZ github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM= github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0= @@ -390,7 +392,6 @@ github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1: github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= @@ -412,7 +413,6 @@ github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5W github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= @@ -445,6 +445,8 @@ github.com/jinzhu/now v1.1.3 h1:PlHq1bSCSZL9K0wUhbm2pGLoTWs2GwVhsP6emvGV/ZI= github.com/jinzhu/now v1.1.3/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= @@ -513,8 +515,10 @@ github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2y github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.12 h1:TJ1bhYJPV44phC+IMu1u2K/i5RriLTPe+yc68XDJ1Z0= +github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= @@ -637,7 +641,6 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs= go.etcd.io/etcd/client/pkg/v3 v3.5.0/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g= @@ -884,7 +887,6 @@ golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211023085530-d6a326fbbf70 h1:SeSEfdIxyvwGJliREIJhRPPXvW6sDlLT+UQ3B0hD0NA= golang.org/x/sys v0.0.0-20211023085530-d6a326fbbf70/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= diff --git a/pkg/server/data/api/changestream.go b/pkg/server/data/api/changestream.go new file mode 100644 index 000000000..d7826a59e --- /dev/null +++ b/pkg/server/data/api/changestream.go @@ -0,0 +1,33 @@ +package api + +import ( + "encoding/json" + "fmt" +) + +type Changes struct { + Items []ChangeItem `json:"items"` +} + +func (c Changes) String() string { + jsonBytes, err := json.Marshal(c) + if err != nil { + return fmt.Sprintf("%v", err) + } + return string(jsonBytes) +} + +type ChangeItem struct { + Sequence int64 `json:"sequence"` + TableName string `json:"table_name"` + RecordID string `json:"record_id"` + RecordRevision Revision `json:"record_revision"` +} + +func (c ChangeItem) String() string { + jsonBytes, err := json.Marshal(c) + if err != nil { + return fmt.Sprintf("%v", err) + } + return string(jsonBytes) +} diff --git a/pkg/server/data/api/const.go b/pkg/server/data/api/const.go new file mode 100644 index 000000000..563c45c45 --- /dev/null +++ b/pkg/server/data/api/const.go @@ -0,0 +1,13 @@ +package api + +const ( + KeyRecordID = "_id" + KeyRevision = "_rev" + KeyPrevision = "_prev" + KeyDeleted = "_deleted" + ChangeStreamTableName = "_changes" + KeyCSSequence = "seq" + KeyCSTableName = "table_name" + KeyCSRecordID = "record_id" + KeyCSRecordRevision = "record_rev" +) diff --git a/pkg/server/data/api/errors.go b/pkg/server/data/api/errors.go new file mode 100644 index 000000000..a9e2e8af2 --- /dev/null +++ b/pkg/server/data/api/errors.go @@ -0,0 +1,102 @@ +package api + +type Error struct { + Message string + Code ErrorCode +} + +type ErrorCode uint8 + +const ( + ErrCodeInvalidRevision ErrorCode = iota + ErrCodeInvalidPrevision + ErrCodeInvalidRecordID + ErrCodeDuplicateField + ErrCodeMissingRevision + ErrCodeInvalidTable + ErrCodeInvalidColumnType + ErrCodeEmptyTableColumns + ErrCodeDuplicateColumnName + ErrCodeInvalidColumnName + ErrCodeRecordNotFound + ErrCodeFieldNotFound + ErrCodeInternalError + ErrCodeTableAlreadyExists + ErrCodeUnsupportedDialect + ErrCodeInvalidTimestamp + ErrCodeMissingId +) + +func (e *Error) Error() string { + return e.Message +} + +func (e *Error) ErrorCode() ErrorCode { + return e.Code +} + +func (e *Error) Is(other error) bool { + if other == nil { + return false + } + if e == other { + return true + } + if o, ok := other.(*Error); ok { + return e.Code == o.Code + } + return false +} + +func IsError(err error, code ErrorCode) bool { + if err == nil { + return false + } + if e, ok := err.(*Error); ok { + return e.Code == code + } + return false +} + +func NewError(code ErrorCode, message string) *Error { + return &Error{ + Message: message, + Code: code, + } +} + +var ( + ErrInvalidRevision = NewError(ErrCodeInvalidRevision, "invalid revision") + ErrInvalidPrevision = NewError(ErrCodeInvalidPrevision, "invalid previous revision") + ErrInvalidRecordID = NewError(ErrCodeInvalidRecordID, "invalid record id") + ErrDuplicateField = NewError(ErrCodeDuplicateField, "duplicate field") + ErrMissingRevision = NewError(ErrCodeMissingRevision, "missing revision") + ErrInvalidTableName = NewError(ErrCodeInvalidTable, "invalid table name") + ErrEmptyColumns = NewError(ErrCodeEmptyTableColumns, "empty columns") + ErrDuplicateColumnName = NewError(ErrCodeDuplicateColumnName, "duplicate column name") + ErrInvalidColumnName = NewError(ErrCodeInvalidColumnName, "invalid column name") + ErrRecordNotFound = NewError(ErrCodeRecordNotFound, "record not found") + ErrFieldNotFound = NewError(ErrCodeFieldNotFound, "field not found") + ErrInvalidColumnType = NewError(ErrCodeInvalidColumnType, "invalid column type") + ErrTableAlreadyExists = NewError(ErrCodeTableAlreadyExists, "table already exists") + ErrUnsupportedDialect = NewError(ErrCodeUnsupportedDialect, "unsupported dialect") + ErrInvalidValueType = NewError(ErrCodeInternalError, "invalid value type") + ErrInvalidTimestamp = NewError(ErrCodeInvalidTimestamp, "invalid timestamp") + ErrMissingId = NewError(ErrCodeMissingId, "missing id") +) + +func NewDuplicateColumnNameErr(name string) *Error { + return NewError(ErrCodeDuplicateColumnName, "duplicate column name: "+name) +} + +func NewInvalidColumnNameErr(name string) *Error { + return NewError(ErrCodeInvalidColumnName, "invalid column name: "+name) +} + +func NewInvalidColumnTypeErr(name string) *Error { + return NewError(ErrCodeInvalidColumnType, "invalid column type: "+name) +} + +func NewTableAlreadyExistsErr(name string) *Error { + return NewError(ErrCodeTableAlreadyExists, "table already exists: "+name) +} diff --git a/pkg/server/data/api/i.go b/pkg/server/data/api/i.go new file mode 100644 index 000000000..8113e2dce --- /dev/null +++ b/pkg/server/data/api/i.go @@ -0,0 +1,13 @@ +package api + +type PutRecordOptions struct { + IsNew bool +} + +type PutRecordOption func(o *PutRecordOptions) + +var IsNew = func(isNew bool) PutRecordOption { + return func(o *PutRecordOptions) { + o.IsNew = isNew + } +} diff --git a/pkg/server/data/api/interface.go b/pkg/server/data/api/interface.go new file mode 100644 index 000000000..0dbe61138 --- /dev/null +++ b/pkg/server/data/api/interface.go @@ -0,0 +1,83 @@ +package api + +import "context" + +type GetRecordRequest struct { + TableName string + RecordID string + Revision Revision +} + +type GetChangesRequest struct { + Since int64 +} + +type ReadInterface interface { + // GetRecord gets a single record from the database. + // If the record does not exist, an error is returned. + GetRecord(ctx context.Context, request GetRecordRequest) (Record, error) + // GetChangeStream gets a change stream for a table + GetChangeStream(ctx context.Context, request GetChangesRequest) (Changes, error) +} + +type PutRecordRequest struct { + Record Record + IsReplication bool +} + +type WriteInterface interface { + // PutRecord puts a single record inside the database. + PutRecord(ctx context.Context, request PutRecordRequest) (Record, error) + // CreateTable creates a new table in the database. + CreateTable(ctx context.Context, table Table) (Table, error) +} + +type Engine interface { + ReadInterface + WriteInterface +} + +type TxFactory func(ctx context.Context) (Transaction, error) + +type Transaction interface { + Query(ctx context.Context, query string, args []interface{}) (ResultReader, error) + Exec(ctx context.Context, query string, args []interface{}) (interface{}, error) + Commit() error + Rollback() error +} + +// Rand generates random bytes +type Rand interface { + // Read puts random bytes into the given buffer. + Read(b []byte) (n int, err error) +} + +// UUIDGenerator generates UUIDs +type UUIDGenerator interface { + // Generate generates a UUID + Generate() (string, error) +} + +// ResultReader reads results from a query +type ResultReader interface { + // Next returns the next result + Next() bool + // Read returns the current result + Read(columnKinds []ValueKind) (map[string]Value, error) + // Close closes the reader + Close() error + // Err returns the last error + Err() error +} + +// RevisionGenerator generates revision hashes +type RevisionGenerator interface { + // Generate generates a revision Hash + Generate(num int, data map[string]interface{}) Revision +} + +// Clock provides a clock +type Clock interface { + // Now returns the current time + Now() int64 +} diff --git a/pkg/server/data/api/record.go b/pkg/server/data/api/record.go new file mode 100644 index 000000000..f11dc75d4 --- /dev/null +++ b/pkg/server/data/api/record.go @@ -0,0 +1,169 @@ +package api + +import ( + "encoding/json" +) + +type Attributes map[string]Value + +func NewAttributes() Attributes { + return make(Attributes) +} + +func (a Attributes) WithString(key string, value string) Attributes { + a[key] = NewStringValue(value, true) + return a +} + +func (a Attributes) WithInt(key string, value int64) Attributes { + a[key] = NewIntValue(value, true) + return a +} + +func (a Attributes) WithFloat(key string, value float64) Attributes { + a[key] = NewFloatValue(value, true) + return a +} + +func (a Attributes) WithBool(key string, value bool) Attributes { + a[key] = NewBoolValue(value, true) + return a +} + +// Record represents a record in a database +type Record struct { + ID string `json:"id"` + Table string `json:"table"` + Revision Revision `json:"revision"` + PreviousRevision Revision `json:"-"` + Attributes Attributes `json:"attributes"` +} + +func (r *Record) UnmarshalJSON(data []byte) error { + type record struct { + ID string `json:"id"` + Table string `json:"table"` + Revision Revision `json:"revision"` + Attributes Attributes `json:"attributes"` + } + var rr record + if err := json.Unmarshal(data, &rr); err != nil { + return err + } + r.ID = rr.ID + r.Table = rr.Table + r.Revision = rr.Revision + r.Attributes = rr.Attributes + if r.Attributes == nil { + r.Attributes = make(Attributes) + } + return nil +} + +// String returns a string representation of the record +func (r Record) String() string { + jsonBytes, err := json.Marshal(r) + if err != nil { + return "" + } + return string(jsonBytes) +} + +// SetFieldValue sets the value of a field +func (r Record) SetFieldValue(name string, value Value) Record { + if r.Attributes == nil { + r.Attributes = make(map[string]Value) + } + r.Attributes[name] = value + return r +} + +// HasField returns true if the record has a field with the given name +func (r Record) HasField(name string) bool { + _, ok := r.Attributes[name] + return ok +} + +// GetFieldValue returns the value of the field with the given name +// Or an error if the field does not exist +func (r Record) GetFieldValue(name string) (Value, error) { + if r.Attributes == nil { + return Value{}, ErrFieldNotFound + } + value, ok := r.Attributes[name] + if !ok { + return Value{}, ErrFieldNotFound + } + return value, nil +} + +// GetID returns the ID of the record +// or empty string if the record does not have an ID field +func (r Record) GetID() string { + return r.ID +} + +// GetRevision returns the revision of the record +// or empty string if the record does not have a revision field +func (r Record) GetRevision() Revision { + return r.Revision +} + +// Table represents a database table +type Table struct { + // Name of the table + Name string `json:"name"` + // Columns of the table + Columns []Column `json:"columns"` + // Constraints of the table + Constraints []TableConstraint `json:"constraints"` +} + +func (t Table) String() string { + jsonBytes, err := json.Marshal(t) + if err != nil { + return "" + } + return string(jsonBytes) +} + +// TableList represents a list of tables +type TableList struct { + // Items is the list of tables + Items []Table `json:"items"` +} + +// TableConstraint represents a SQL table constraint +type TableConstraint struct { + PrimaryKey *PrimaryKeyTableConstraint `json:"primary_key"` +} + +// PrimaryKeyTableConstraint represents a primary key table constraint +type PrimaryKeyTableConstraint struct { + // Columns of the primary key + Columns []string `json:"columns"` +} + +// Column represents a database column +type Column struct { + // Name of the column + Name string `json:"name"` + // Type is the data type of the column + Type string `json:"type"` + // Default value of the column + Default string `json:"default"` + // Constraints of the column + Constraints []ColumnConstraint `json:"constraints"` +} + +// ColumnConstraint represents a SQL column constraint +type ColumnConstraint struct { + NotNull *NotNullColumnConstraint `json:"not_null"` + PrimaryKey *PrimaryKeyColumnConstraint `json:"primary_key"` +} + +// NotNullColumnConstraint represents a not null column constraint +type NotNullColumnConstraint struct{} + +// PrimaryKeyColumnConstraint represents a primary key column constraint +type PrimaryKeyColumnConstraint struct{} diff --git a/pkg/server/data/api/record_test.go b/pkg/server/data/api/record_test.go new file mode 100644 index 000000000..0218f69b0 --- /dev/null +++ b/pkg/server/data/api/record_test.go @@ -0,0 +1,172 @@ +package api + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Record_UnmarshalJSON(t *testing.T) { + type testCase struct { + name string + input []byte + expected Record + expectErr assert.ErrorAssertionFunc + } + testCases := []testCase{ + { + name: "empty", + input: []byte(`{}`), + expected: Record{ + Attributes: Attributes{}, + }, + }, { + name: "with id", + input: []byte(`{"id":"abc"}`), + expected: Record{ + ID: "abc", + Attributes: Attributes{}, + }, + }, { + name: "with revision", + input: []byte(`{"revision": "1-96fc52d8fbf5d2adc6d139cb5b2ea099"}`), + expected: Record{ + Revision: Revision{ + Num: 1, + Hash: "96fc52d8fbf5d2adc6d139cb5b2ea099", + }, + Attributes: Attributes{}, + }, + }, { + name: "with data", + input: []byte(` + { + "attributes": { + "foo":{"string":"bar"} + } + }`), + expected: Record{ + Attributes: Attributes{ + "foo": NewStringValue("bar", true), + }, + }, + }, { + name: "full", + input: []byte(`{ + "id": "abc", + "revision": "1-96fc52d8fbf5d2adc6d139cb5b2ea099", + "attributes": { + "foo": {"string":"bar"}, + "bar": {"null":true} + } + }`), + expected: Record{ + ID: "abc", + Revision: Revision{ + Num: 1, + Hash: "96fc52d8fbf5d2adc6d139cb5b2ea099", + }, + Attributes: Attributes{ + "foo": NewStringValue("bar", true), + "bar": NewNullValue(), + }, + }, + }, + } + for _, tc := range testCases { + testCase := tc + t.Run(testCase.name, func(t *testing.T) { + record := Record{} + err := json.Unmarshal(testCase.input, &record) + if testCase.expectErr == nil { + testCase.expectErr = assert.NoError + } + if !testCase.expectErr(t, err) { + return + } + if err != nil { + return + } + assert.Equal(t, testCase.expected, record) + }) + } +} + +func Test_Record_MarshalJSON(t *testing.T) { + type testCase struct { + name string + columns map[string]ValueKind + input Record + expected string + expectErr assert.ErrorAssertionFunc + } + testCases := []testCase{ + { + name: "empty", + input: Record{}, + expected: `{"id":"","table":"","revision":"","attributes":null}`, + }, { + name: "with id", + input: Record{ + ID: "abc", + }, + expected: `{"id":"abc","table":"","revision":"","attributes":null}`, + }, { + name: "with revision", + input: Record{ + Revision: Revision{ + Num: 1, + Hash: "96fc52d8fbf5d2adc6d139cb5b2ea099", + }, + }, + expected: `{"id":"","table":"","revision":"1-96fc52d8fbf5d2adc6d139cb5b2ea099","attributes":null}`, + }, { + name: "with table", + input: Record{ + Table: "abc", + }, + expected: `{"id":"","table":"abc","revision":"","attributes":null}`, + }, { + name: "with fields", + input: Record{ + Attributes: Attributes{ + "bar": NewStringValue("", false), + "foo": NewStringValue("bar", true), + }, + }, + expected: `{"id":"","table":"","revision":"","attributes":{"bar":{"null":true},"foo":{"string":"bar"}}}`, + }, { + name: "full", + input: Record{ + ID: "abc", + Revision: Revision{ + Num: 1, + Hash: "96fc52d8fbf5d2adc6d139cb5b2ea099", + }, + Table: "foo", + Attributes: Attributes{ + "bar": NewNullValue(), + "foo": NewStringValue("bar", true), + }, + }, + expected: `{"id":"abc","table":"foo","revision":"1-96fc52d8fbf5d2adc6d139cb5b2ea099","attributes":{"bar":{"null":true},"foo":{"string":"bar"}}}`, + }, + } + for _, tc := range testCases { + testCase := tc + t.Run(testCase.name, func(t *testing.T) { + actual, err := json.Marshal(testCase.input) + if testCase.expectErr == nil { + testCase.expectErr = assert.NoError + } + if !testCase.expectErr(t, err) { + return + } + if err != nil { + return + } + assert.Equal(t, testCase.expected, string(actual)) + }) + } +} diff --git a/pkg/server/data/api/revision.go b/pkg/server/data/api/revision.go new file mode 100644 index 000000000..2f2c8ad3d --- /dev/null +++ b/pkg/server/data/api/revision.go @@ -0,0 +1,109 @@ +package api + +import "encoding/json" + +var digitTable = [10]string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"} + +// Revision represents the revision of a record +// It contains a revision number and a Hash +// A valid format a revision string +// is "revision:-" +type Revision struct { + // Num is the revision number + Num int + // Hash is the Hash of the record + Hash string +} + +// MarshalJSON implements the json.Marshaler interface +func (r Revision) MarshalJSON() ([]byte, error) { + return json.Marshal(r.String()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (r *Revision) UnmarshalJSON(data []byte) error { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + rev, err := ParseRevision(str) + if err != nil { + return err + } + *r = rev + return nil +} + +// EmptyRevision is the zero value of a Revision +var EmptyRevision = Revision{} + +// NewRevision creates a new revision with the +// given number and hash +func NewRevision(num int, hash string) Revision { + return Revision{num, hash} +} + +// String returns the string representation of the revision +// Format: "-" +func (r Revision) String() string { + if len(r.Hash) == 0 { + return "" + } + var ret string + digits := getDigits(r.Num) + for _, d := range digits { + ret += digitTable[d] + } + ret += "-" + r.Hash + return ret +} + +func (r Revision) IsEmpty() bool { + return r.Num == 0 && len(r.Hash) == 0 +} + +// ParseRevision parses a revision string +// the format is - +func ParseRevision(str string) (rev Revision, err error) { + if len(str) == 0 { + return EmptyRevision, nil + } + // The length of the hash is 32 (md5) + // The length of the separator is 1 + // The length of the number is at least 1 + // so the minimum length is 34 + if len(str) < 34 { + return rev, NewError(ErrCodeInvalidRevision, "Invalid revision length") + } + hash := str[len(str)-32:] + // the hash can only contain hex characters + for _, c := range hash { + if c < '0' || c > 'f' { + return rev, NewError(ErrCodeInvalidRevision, "Invalid character in hash") + } + } + // the number is from position 0 to position len(str)-33 + numStr := str[:len(str)-33] + for i, c := range numStr { + // the number cannot start with a zero + if i == 0 && c == '0' { + return rev, NewError(ErrCodeInvalidRevision, "Invalid character in version") + } + // the number can only contain digits + if c < '0' || c > '9' { + return rev, NewError(ErrCodeInvalidRevision, "Invalid character in version") + } + } + // the number is parsed as an int + num, err := parseInt(numStr) + if err != nil { + return rev, NewError(ErrCodeInvalidRevision, "Invalid version") + } + // check that the separator is correct + if str[len(numStr)] != '-' { + return rev, NewError(ErrCodeInvalidRevision, "Invalid revision") + } + rev.Num = num + rev.Hash = hash + return rev, nil +} diff --git a/pkg/server/data/api/revision_test.go b/pkg/server/data/api/revision_test.go new file mode 100644 index 000000000..c6ac6c16f --- /dev/null +++ b/pkg/server/data/api/revision_test.go @@ -0,0 +1,106 @@ +package api + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_parseRevision(t *testing.T) { + tests := []struct { + name string + rev string + want Revision + wantErr assert.ErrorAssertionFunc + }{ + { + name: "valid revision", + rev: "1-96fc52d8fbf5d2adc6d139cb5b2ea099", + want: Revision{ + Num: 1, + Hash: "96fc52d8fbf5d2adc6d139cb5b2ea099", + }, + }, { + name: "another valid revision", + rev: "1232-96fc52d8fbf5d2adc6d139cb5b2ea099", + want: Revision{ + Num: 1232, + Hash: "96fc52d8fbf5d2adc6d139cb5b2ea099", + }, + }, { + name: "extra chars", + rev: "1-96fc52d8fbf5d2adc6d139cb5b2ea099-", + wantErr: assert.Error, + }, { + name: "missing hash", + rev: "1-", + wantErr: assert.Error, + }, { + name: "invalid hash chars", + rev: "1-9z9z9z9z9z9z9z9z9z9z9z9z9z9z9z9z", + wantErr: assert.Error, + }, { + name: "leading 0", + rev: "0123-96fc52d8fbf5d2adc6d139cb5b2ea099", + wantErr: assert.Error, + }, { + name: "invalid hash length", + rev: "1-9f2", + wantErr: assert.Error, + }, { + name: "invalid version number", + rev: "abc-9f2", + wantErr: assert.Error, + }, + } + for _, tt := range tests { + if tt.wantErr == nil { + tt.wantErr = assert.NoError + } + t.Run(tt.name, func(t *testing.T) { + got, err := ParseRevision(tt.rev) + if !tt.wantErr(t, err, fmt.Sprintf("ParseRevision(%v)", tt.rev)) { + return + } + assert.Equalf(t, tt.want, got, "ParseRevision(%v)", tt.rev) + }) + } +} + +func Test_revision_String(t *testing.T) { + type fields struct { + num int + hash string + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "revision1", + fields: fields{ + num: 1, + hash: "96fc52d8fbf5d2adc6d139cb5b2ea099", + }, + want: "1-96fc52d8fbf5d2adc6d139cb5b2ea099", + }, { + name: "revision2", + fields: fields{ + num: 1232, + hash: "96fc52d8fbf5d2adc6d139cb5b2ea099", + }, + want: "1232-96fc52d8fbf5d2adc6d139cb5b2ea099", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := Revision{ + Num: tt.fields.num, + Hash: tt.fields.hash, + } + assert.Equalf(t, tt.want, r.String(), "String()") + }) + } +} diff --git a/pkg/server/data/api/utils.go b/pkg/server/data/api/utils.go new file mode 100644 index 000000000..1636377a9 --- /dev/null +++ b/pkg/server/data/api/utils.go @@ -0,0 +1,51 @@ +package api + +// getDigits returns an array of individual digits of a number +// useful for formatting numbers to strings +func getDigits(i int) []int { + if i == 0 { + return []int{0} + } + var ret []int + for i > 0 { + ret = append(ret, i%10) + i /= 10 + } + return reverseInts(ret) +} + +// reverseInts returns a reversed copy of an array of integers +func reverseInts(ints []int) []int { + var ret []int + for i := len(ints) - 1; i >= 0; i-- { + ret = append(ret, ints[i]) + } + return ret +} + +// parseInt parses a string into an integer +// it is a zero-dependant version of strconv.Atoi(s) +func parseInt(s string) (int, error) { + if s == "" { + return 0, NewError(ErrCodeInternalError, "cannot parse empty string") + } + var ret int + for i := len(s) - 1; i >= 0; i-- { + c := s[i] + if c < '0' || c > '9' { + return 0, NewError(ErrCodeInternalError, "cannot parse string:"+s) + } + ret += int(c-'0') * pow(10, len(s)-i-1) + } + return ret, nil +} + +// pow returns x raised to the power of y +// it is a zero-dependant version of math.Pow(x, y) +func pow(x, y int) int { + ret := 1 + for i := 0; i < y; i++ { + ret *= x + } + return ret +} diff --git a/pkg/server/data/api/utils_test.go b/pkg/server/data/api/utils_test.go new file mode 100644 index 000000000..be3b1ebeb --- /dev/null +++ b/pkg/server/data/api/utils_test.go @@ -0,0 +1,38 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_getDigits(t *testing.T) { + tests := []struct { + name string + i int + want []int + }{ + { + name: "0", + i: 0, + want: []int{0}, + }, { + name: "1", + i: 1, + want: []int{1}, + }, { + name: "10", + i: 10, + want: []int{1, 0}, + }, { + name: "105", + i: 105, + want: []int{1, 0, 5}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, getDigits(tt.i), "getDigits(%v)", tt.i) + }) + } +} diff --git a/pkg/server/data/api/valuekind_string.go b/pkg/server/data/api/valuekind_string.go new file mode 100644 index 000000000..fd0dba873 --- /dev/null +++ b/pkg/server/data/api/valuekind_string.go @@ -0,0 +1,27 @@ +// Code generated by "stringer -type=ValueKind"; DO NOT EDIT. + +package api + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[ValueKindNull-0] + _ = x[ValueKindInt-1] + _ = x[ValueKindFloat-2] + _ = x[ValueKindString-3] + _ = x[ValueKindBool-4] +} + +const _ValueKind_name = "ValueKindNullValueKindIntValueKindFloatValueKindStringValueKindBool" + +var _ValueKind_index = [...]uint8{0, 13, 25, 39, 54, 67} + +func (i ValueKind) String() string { + if i >= ValueKind(len(_ValueKind_index)-1) { + return "ValueKind(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _ValueKind_name[_ValueKind_index[i]:_ValueKind_index[i+1]] +} diff --git a/pkg/server/data/api/values.go b/pkg/server/data/api/values.go new file mode 100644 index 000000000..566ee8919 --- /dev/null +++ b/pkg/server/data/api/values.go @@ -0,0 +1,734 @@ +package api + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "math" + "reflect" + "strconv" +) + +var nullBytes = []byte("null") + +// String is a nullable string. +// It does not consider empty strings to be null. +// It will decode to null, not "" when null. +// It implements json.Marshaler and json.Unmarshaler. +// It also implements sql.Scanner and sql.Valuer to marshal and unmarshal itself. +// So it is both database and json compatible. +type String struct { + sql.NullString +} + +// StringFrom creates a new String that will always be non-null. +func StringFrom(s string) String { + return NewString(s, true) +} + +// StringFromPtr creates a new String that be null if s is nil. +func StringFromPtr(s *string) String { + if s == nil { + return NewString("", false) + } + return NewString(*s, true) +} + +// ValueOrZero returns the inner value if valid, otherwise empty string +func (ns String) ValueOrZero() string { + if !ns.Valid { + return "" + } + return ns.String +} + +// UnmarshalJSON implements json.Unmarshaler. +func (ns *String) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, nullBytes) { + ns.Valid = false + return nil + } + ns.Valid = true + if err := json.Unmarshal(data, &ns.String); err != nil { + return err + } + return nil +} + +// MarshalJSON implements json.Marshaler. +func (ns String) MarshalJSON() ([]byte, error) { + if !ns.Valid { + return []byte("null"), nil + } + return json.Marshal(ns.String) +} + +// NewString creates a new String +func NewString(s string, valid bool) String { + return String{ + NullString: sql.NullString{ + String: s, + Valid: valid, + }, + } +} + +// MarshalText implements encoding.TextMarshaler. +func (ns String) MarshalText() ([]byte, error) { + if ns.Valid { + return []byte(ns.String), nil + } + return nil, nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (ns *String) UnmarshalText(text []byte) error { + ns.String = string(text) + ns.Valid = true + return nil +} + +// SetValid changes this String's value and also sets it to be non-null. +func (ns *String) SetValid(v string) { + ns.String = v + ns.Valid = true +} + +// IsZero returns true for invalid Strings, for Go omitempty tag support. +func (ns *String) IsZero() bool { + return !ns.Valid +} + +// Equal returns true if the other String is equal to this one. +func (ns String) Equal(other String) bool { + if !ns.Valid && !other.Valid { + return true + } + if !ns.Valid || !other.Valid { + return false + } + return ns.String == other.String +} + +// Ptr returns a pointer to this String's value, or a nil pointer if this String is invalid. +func (ns String) Ptr() *string { + if !ns.Valid { + return nil + } + return &ns.String +} + +// Bool is a nullable bool. +// It does not default to false +// It will decode to null, not false when null. +// It implements json.Marshaler and json.Unmarshaler. +// It also implements sql.Scanner and sql.Valuer to marshal and unmarshal itself. +// So it is both database and json compatible. +type Bool struct { + sql.NullBool +} + +// BoolFrom creates a new Bool that will always be non-null. +func BoolFrom(s bool) Bool { + return NewBool(s, true) +} + +// BoolFromPtr creates a new Bool that be null if s is nil. +func BoolFromPtr(b *bool) Bool { + if b == nil { + return NewBool(false, false) + } + return NewBool(*b, true) +} + +// ValueOrZero returns the inner value if valid, otherwise false +func (b Bool) ValueOrZero() bool { + if !b.Valid { + return false + } + return b.Bool +} + +// UnmarshalJSON implements json.Unmarshaler. +func (b *Bool) UnmarshalJSON(data []byte) error { + str := string(data) + switch str { + case "", "null": + b.Valid = false + return nil + case "true": + b.Valid = true + b.Bool = true + return nil + case "false": + b.Valid = true + b.Bool = false + return nil + default: + return fmt.Errorf("invalid boolean value: %s", str) + } +} + +// MarshalJSON implements json.Marshaler. +func (b Bool) MarshalJSON() ([]byte, error) { + if !b.Valid { + return []byte("null"), nil + } + if b.Bool { + return []byte("true"), nil + } + return []byte("false"), nil +} + +// NewBool creates a new Bool +func NewBool(b bool, valid bool) Bool { + return Bool{ + NullBool: sql.NullBool{ + Bool: b, + Valid: valid, + }, + } +} + +// MarshalText implements encoding.TextMarshaler. +func (b Bool) MarshalText() ([]byte, error) { + if !b.Valid { + return []byte{}, nil + } + if b.Bool { + return []byte("true"), nil + } + return []byte("false"), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (b *Bool) UnmarshalText(text []byte) error { + str := string(text) + switch str { + case "", "null": + b.Valid = false + return nil + case "true": + b.Valid = true + b.Bool = true + return nil + case "false": + b.Valid = true + b.Bool = false + return nil + default: + return fmt.Errorf("invalid boolean value: %s", str) + } +} + +// SetValid changes this Bool's value and also sets it to be non-null. +func (b *Bool) SetValid(v bool) { + b.Bool = v + b.Valid = true +} + +// IsZero returns true for invalid Bools, for Go omitempty tag support. +func (b *Bool) IsZero() bool { + return !b.Valid +} + +// Equal returns true if the other Bool is equal to this one. +func (b Bool) Equal(other Bool) bool { + if !b.Valid && !other.Valid { + return true + } + if !b.Valid || !other.Valid { + return false + } + return b.Bool == other.Bool +} + +// Ptr returns a pointer to this Bool's value, or a nil pointer if this Bool is invalid. +func (b Bool) Ptr() *bool { + if !b.Valid { + return nil + } + return &b.Bool +} + +// Int is a nullable int. +// It does not default to 0 +// It will decode to null, not 0 when null. +// It implements json.Marshaler and json.Unmarshaler. +// It also implements sql.Scanner and sql.Valuer to marshal and unmarshal itself. +// So it is both database and json compatible. +type Int struct { + sql.NullInt64 +} + +// IntFrom creates a new Int that will always be non-null. +func IntFrom(s int64) Int { + return NewInt(s, true) +} + +// IntFromPtr creates a new Int that be null if s is nil. +func IntFromPtr(b *int64) Int { + if b == nil { + return NewInt(0, false) + } + return NewInt(*b, true) +} + +// ValueOrZero returns the inner value if valid, otherwise false +func (b Int) ValueOrZero() int64 { + if !b.Valid { + return 0 + } + return b.Int64 +} + +// UnmarshalJSON implements json.Unmarshaler. +func (b *Int) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, nullBytes) { + b.Valid = false + return nil + } + if err := json.Unmarshal(data, &b.Int64); err != nil { + var typeError *json.UnmarshalTypeError + if errors.As(err, &typeError) { + if typeError.Value != "string" { + return err + } + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + b.Int64, err = strconv.ParseInt(str, 10, 64) + if err != nil { + return err + } + b.Valid = true + return nil + } + return err + } + b.Valid = true + return nil +} + +// MarshalJSON implements json.Marshaler. +func (b Int) MarshalJSON() ([]byte, error) { + if !b.Valid { + return []byte("null"), nil + } + return []byte(strconv.FormatInt(b.Int64, 10)), nil +} + +// NewInt creates a new Int +func NewInt(i int64, valid bool) Int { + return Int{ + NullInt64: sql.NullInt64{ + Int64: i, + Valid: valid, + }, + } +} + +// MarshalText implements encoding.TextMarshaler. +func (b Int) MarshalText() ([]byte, error) { + if !b.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatInt(b.Int64, 10)), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (b *Int) UnmarshalText(text []byte) error { + if len(text) == 0 { + b.Valid = false + return nil + } + var err error + b.Int64, err = strconv.ParseInt(string(text), 10, 64) + if err != nil { + return err + } + b.Valid = true + return nil +} + +// SetValid changes this Int's value and also sets it to be non-null. +func (b *Int) SetValid(v int64) { + b.Int64 = v + b.Valid = true +} + +// IsZero returns true for invalid Ints, for Go omitempty tag support. +func (b *Int) IsZero() bool { + return !b.Valid +} + +// Equal returns true if the other Int is equal to this one. +func (b Int) Equal(other Int) bool { + if !b.Valid && !other.Valid { + return true + } + if !b.Valid || !other.Valid { + return false + } + return b.Int64 == other.Int64 +} + +// Ptr returns a point64er to this Int's value, or a nil point64er if this Int is invalid. +func (b Int) Ptr() *int64 { + if !b.Valid { + return nil + } + return &b.Int64 +} + +// Float is a nullable float. +// It does not default to 0 +// It will decode to null, not 0 when null. +// It implements json.Marshaler and json.Unmarshaler. +// It also implements sql.Scanner and sql.Valuer to marshal and unmarshal itself. +// So it is both database and json compatible. +type Float struct { + sql.NullFloat64 +} + +// FloatFrom creates a new Float that will always be non-null. +func FloatFrom(s float64) Float { + return NewFloat(s, true) +} + +// FloatFromPtr creates a new Float that be null if s is nil. +func FloatFromPtr(b *float64) Float { + if b == nil { + return NewFloat(0, false) + } + return NewFloat(*b, true) +} + +// ValueOrZero returns the inner value if valid, otherwise false +func (b Float) ValueOrZero() float64 { + if !b.Valid { + return 0 + } + return b.Float64 +} + +// UnmarshalJSON implements json.Unmarshaler. +func (b *Float) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, nullBytes) { + b.Valid = false + return nil + } + if err := json.Unmarshal(data, &b.Float64); err != nil { + var typeError *json.UnmarshalTypeError + if errors.As(err, &typeError) { + if typeError.Value != "string" { + return err + } + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + b.Float64, err = strconv.ParseFloat(str, 64) + if err != nil { + return err + } + b.Valid = true + return nil + } + return err + } + b.Valid = true + return nil +} + +// MarshalJSON implements json.Marshaler. +func (b Float) MarshalJSON() ([]byte, error) { + if !b.Valid { + return []byte("null"), nil + } + if math.IsInf(b.Float64, 0) || math.IsNaN(b.Float64) { + return nil, &json.UnsupportedValueError{ + Value: reflect.ValueOf(b.Float64), + Str: strconv.FormatFloat(b.Float64, 'g', -1, 64), + } + } + return []byte(strconv.FormatFloat(b.Float64, 'f', -1, 64)), nil +} + +// NewFloat creates a new Float +func NewFloat(i float64, valid bool) Float { + return Float{ + NullFloat64: sql.NullFloat64{ + Float64: i, + Valid: valid, + }, + } +} + +// MarshalText implements encoding.TextMarshaler. +func (b Float) MarshalText() ([]byte, error) { + if !b.Valid { + return []byte{}, nil + } + return []byte(strconv.FormatFloat(b.Float64, 'f', -1, 64)), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (b *Float) UnmarshalText(text []byte) error { + if len(text) == 0 { + b.Valid = false + return nil + } + var err error + b.Float64, err = strconv.ParseFloat(string(text), 64) + if err != nil { + return err + } + b.Valid = true + return nil +} + +// SetValid changes this Float's value and also sets it to be non-null. +func (b *Float) SetValid(v float64) { + b.Float64 = v + b.Valid = true +} + +// IsZero returns true for invalid Floats, for Go omitempty tag support. +func (b *Float) IsZero() bool { + return !b.Valid +} + +// Equal returns true if the other Float is equal to this one. +func (b Float) Equal(other Float) bool { + if !b.Valid && !other.Valid { + return true + } + if !b.Valid || !other.Valid { + return false + } + return b.Float64 == other.Float64 +} + +// Ptr returns a pofloat64er to this Float's value, or a nil pofloat64er if this Float is invalid. +func (b Float) Ptr() *float64 { + if !b.Valid { + return nil + } + return &b.Float64 +} + +// ValueKind represents the kind of value stored in a Value. +type ValueKind uint8 + +//go:generate go run golang.org/x/tools/cmd/stringer -type=ValueKind + +const ( + // ValueKindNull represents a null value. + ValueKindNull ValueKind = iota + // ValueKindInt is a ValueKind representing an int64. + ValueKindInt + // ValueKindFloat is a ValueKind representing a float64. + ValueKindFloat + // ValueKindString is a ValueKind representing a string. + ValueKindString + // ValueKindBool is a ValueKind representing a bool. + ValueKindBool +) + +// Value represents a value of any kind. +type Value struct { + // Kind is the kind of value this is. + Kind ValueKind `json:"-"` + // String is the string value if Kind is ValueKindString. + String *String `json:"string,omitempty"` + // Bool is the bool value if Kind is ValueKindBool. + Bool *Bool `json:"bool,omitempty"` + // Int is the int64 value if Kind is ValueKindInt. + Int *Int `json:"int,omitempty"` + // Float is the float64 value if Kind is ValueKindFloat. + Float *Float `json:"float,omitempty"` +} + +// Scan implements the sql.Scanner interface. +func (v *Value) Scan(value interface{}) error { + switch v.Kind { + case ValueKindInt: + v.Int = &Int{} + return v.Int.Scan(value) + case ValueKindFloat: + v.Float = &Float{} + return v.Float.Scan(value) + case ValueKindString: + v.String = &String{} + return v.String.Scan(value) + case ValueKindBool: + v.Bool = &Bool{} + return v.Bool.Scan(value) + default: + return fmt.Errorf("unknown ValueKind %d", v.Kind) + } +} + +// Value implements the driver.Valuer interface. +func (v *Value) Value() (driver.Value, error) { + switch v.Kind { + case ValueKindNull: + return nil, nil + case ValueKindInt: + return v.Int.Value() + case ValueKindFloat: + return v.Float.Value() + case ValueKindString: + return v.String.Value() + case ValueKindBool: + return v.Bool.Value() + default: + return nil, fmt.Errorf("unknown ValueKind %d", v.Kind) + } +} + +func (v *Value) UnmarshalJSON(data []byte) error { + v.String = nil + v.Bool = nil + v.Int = nil + v.Float = nil + type value struct { + Null bool `json:"null"` + String string `json:"string,omitempty"` + Int string `json:"int,omitempty"` + Float string `json:"float,omitempty"` + Bool string `json:"bool,omitempty"` + } + var val value + if err := json.Unmarshal(data, &val); err != nil { + return err + } + if val.Null { + v.Kind = ValueKindNull + return nil + } + if val.String != "" { + v.Kind = ValueKindString + v.String = &String{} + return v.String.UnmarshalText([]byte(val.String)) + } + if val.Int != "" { + v.Kind = ValueKindInt + v.Int = &Int{} + return v.Int.UnmarshalText([]byte(val.Int)) + } + if val.Float != "" { + v.Kind = ValueKindFloat + v.Float = &Float{} + return v.Float.UnmarshalText([]byte(val.Float)) + } + if val.Bool != "" { + v.Kind = ValueKindBool + v.Bool = &Bool{} + return v.Bool.UnmarshalText([]byte(val.Bool)) + } + return fmt.Errorf("invalid value: %s", data) +} + +// MarshalJSON implements json.Marshaler. +func (v Value) MarshalJSON() ([]byte, error) { + var ret = map[string]interface{}{} + switch v.Kind { + case ValueKindNull: + ret["null"] = true + return json.Marshal(ret) + case ValueKindInt: + if v.Int.IsZero() { + ret["null"] = true + } else { + text, err := v.Int.MarshalText() + if err != nil { + return nil, err + } + ret["int"] = string(text) + } + return json.Marshal(ret) + case ValueKindFloat: + if v.Float.IsZero() { + ret["null"] = true + } else { + text, err := v.Float.MarshalText() + if err != nil { + return nil, err + } + ret["float"] = string(text) + } + return json.Marshal(ret) + case ValueKindString: + if v.String.IsZero() { + ret["null"] = true + } else { + text, err := v.String.MarshalText() + if err != nil { + return nil, err + } + ret["string"] = string(text) + } + return json.Marshal(ret) + case ValueKindBool: + if v.Bool.IsZero() { + ret["null"] = true + } else { + text, err := v.Bool.MarshalText() + if err != nil { + return nil, err + } + ret["bool"] = string(text) + } + return json.Marshal(ret) + } + return nil, fmt.Errorf("unsupported value kind %d", v.Kind) +} + +// NewStringValue creates a new Value with a string value. +func NewStringValue(s string, valid bool) Value { + val := NewString(s, valid) + return Value{ + Kind: ValueKindString, + String: &val, + } +} + +// NewBoolValue creates a new Value with a bool value. +func NewBoolValue(b bool, valid bool) Value { + val := NewBool(b, valid) + return Value{ + Kind: ValueKindBool, + Bool: &val, + } +} + +// NewIntValue creates a new Value with an int64 value. +func NewIntValue(i int64, valid bool) Value { + val := NewInt(i, valid) + return Value{ + Kind: ValueKindInt, + Int: &val, + } +} + +// NewFloatValue creates a new Value with a float64 value. +func NewFloatValue(f float64, valid bool) Value { + val := NewFloat(f, valid) + return Value{ + Kind: ValueKindFloat, + Float: &val, + } +} + +// NewNullValue creates a new null Value +func NewNullValue() Value { + return Value{ + Kind: ValueKindNull, + } +} diff --git a/pkg/server/data/client/client.go b/pkg/server/data/client/client.go new file mode 100644 index 000000000..c3d8291d5 --- /dev/null +++ b/pkg/server/data/client/client.go @@ -0,0 +1,180 @@ +package client + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" +) + +type client struct { + client *http.Client +} + +func (c *client) Method(method string) *request { + return &request{ + client: c, + method: method, + } +} + +func (c *client) Post() *request { + return c.Method("POST") +} + +func (c *client) Get() *request { + return c.Method("GET") +} + +func (c *client) Put() *request { + return c.Method("PUT") +} + +func (c *client) Delete() *request { + return c.Method("DELETE") +} + +type request struct { + err error + method string + client *client + queryParams map[string]string + headers map[string]string + body []byte + url string +} + +func (r *request) URL(url string) *request { + r.url = url + return r +} + +func (r *request) WithMethod(method string) *request { + r.method = method + return r +} + +func (r *request) WithQueryParam(key, value string) *request { + if r.err != nil { + return r + } + if r.queryParams == nil { + r.queryParams = make(map[string]string) + } + r.queryParams[key] = value + return r +} + +func (r *request) WithHeader(key, value string) *request { + if r.err != nil { + return r + } + if r.headers == nil { + r.headers = make(map[string]string) + } + r.headers[key] = value + return r +} + +func (r *request) WithBody(body interface{}) *request { + if r.err != nil { + return r + } + switch b := body.(type) { + case []byte: + r.body = b + case string: + r.body = []byte(b) + default: + jsonBytes, err := json.Marshal(b) + if err != nil { + r.err = err + return r + } + r.body = jsonBytes + } + return r +} + +func (r *request) Do(ctx context.Context) *response { + if r.err != nil { + return &response{err: r.err} + } + var body []byte + if r.body != nil { + body = r.body + } + req, err := http.NewRequestWithContext(ctx, r.method, r.url, bytes.NewBuffer(body)) + if err != nil { + return &response{err: r.err} + } + if req.Header == nil { + req.Header = make(http.Header) + } + for k, v := range r.headers { + req.Header.Set(k, v) + } + if len(r.body) > 0 && (r.headers == nil || r.headers["Content-Type"] == "") { + req.Header.Set("Content-Type", "application/json") + } + if r.method == "POST" || r.method == "PUT" { + if r.headers["Accept"] == "" { + req.Header.Set("Accept", "application/json") + } + } + for k, v := range r.queryParams { + q := req.URL.Query() + q.Add(k, v) + req.URL.RawQuery = q.Encode() + } + resp, err := r.client.client.Do(req) + if err != nil { + return &response{err: r.err} + } + defer func() { + if err := resp.Body.Close(); err != nil { + r.err = err + } + }() + body, err = ioutil.ReadAll(resp.Body) + if err != nil { + return &response{err: r.err} + } + if resp.StatusCode < 200 || resp.StatusCode > 299 { + err := fmt.Errorf("Error occured with status code %d: %s\n", resp.StatusCode, string(body)) + return &response{ + err: err, + statusCode: resp.StatusCode, + body: body, + url: resp.Request.URL.String(), + headers: resp.Header, + } + } + return &response{ + statusCode: resp.StatusCode, + body: body, + headers: resp.Header, + url: resp.Request.URL.String(), + } + +} + +type response struct { + err error + statusCode int + headers http.Header + body []byte + url string +} + +func (r *response) Into(v interface{}) error { + if r.err != nil { + return r.err + } + if r.body == nil { + return nil + } + return json.Unmarshal(r.body, v) +} diff --git a/pkg/server/data/client/client_http.go b/pkg/server/data/client/client_http.go new file mode 100644 index 000000000..b28cd3702 --- /dev/null +++ b/pkg/server/data/client/client_http.go @@ -0,0 +1,69 @@ +package client + +import ( + "context" + "net/http" + "strconv" + + "github.com/nrc-no/core/pkg/server/data/api" +) + +type httpClient struct { + baseURL string + client *client +} + +type HTTPClient interface { + GetRecord(ctx context.Context, request api.GetRecordRequest) (api.Record, error) + PutRecord(ctx context.Context, request api.PutRecordRequest) (api.Record, error) + CreateTable(ctx context.Context, request api.Table) (api.Table, error) + GetChanges(ctx context.Context, request api.GetChangesRequest) (api.Changes, error) +} + +func NewClient(baseURL string) HTTPClient { + return &httpClient{ + client: &client{client: http.DefaultClient}, + baseURL: baseURL, + } +} + +func (c *httpClient) GetRecord(ctx context.Context, request api.GetRecordRequest) (api.Record, error) { + var response api.Record + err := c.client.Get(). + URL(c.baseURL+"/apis/data.nrc.no/v1/tables/"+request.TableName+"/records/"+request.RecordID). + WithHeader("Accept", "application/json"). + WithQueryParam("revision", request.Revision.String()). + Do(ctx).Into(&response) + return response, err +} + +func (c *httpClient) GetChanges(ctx context.Context, request api.GetChangesRequest) (api.Changes, error) { + var response api.Changes + err := c.client.Get(). + URL(c.baseURL+"/apis/data.nrc.no/v1/changes"). + WithHeader("Accept", "application/json"). + WithQueryParam("since", strconv.FormatInt(request.Since, 10)). + Do(ctx).Into(&response) + return response, err +} + +func (c *httpClient) PutRecord(ctx context.Context, request api.PutRecordRequest) (api.Record, error) { + var response api.Record + err := c.client.Put(). + URL(c.baseURL+"/apis/data.nrc.no/v1/tables/"+request.Record.Table+"/records/"+request.Record.ID). + WithQueryParam("replication", strconv.FormatBool(request.IsReplication)). + WithBody(request). + Do(ctx).Into(&response) + return response, err +} + +func (c *httpClient) CreateTable(ctx context.Context, request api.Table) (api.Table, error) { + var response api.Table + err := c.client.Put(). + URL(c.baseURL+"/apis/data.nrc.no/v1/tables/"+request.Name). + WithHeader("Accept", "application/json"). + WithHeader("Content-Type", "application/json"). + WithBody(request). + Do(ctx).Into(&response) + return response, err +} diff --git a/pkg/server/data/engine/.gitignore b/pkg/server/data/engine/.gitignore new file mode 100644 index 000000000..cb510ca1b --- /dev/null +++ b/pkg/server/data/engine/.gitignore @@ -0,0 +1,2 @@ +test.db +test.*.db diff --git a/pkg/server/data/engine/engine.go b/pkg/server/data/engine/engine.go new file mode 100644 index 000000000..b27f40707 --- /dev/null +++ b/pkg/server/data/engine/engine.go @@ -0,0 +1,850 @@ +package engine + +import ( + "context" + "fmt" + "strings" + + "github.com/nrc-no/core/pkg/server/data/api" +) + +type engine struct { + // txFactory is the txFactory factory + txFactory api.TxFactory + // uuidGenerator generates uuids + uuidGenerator api.UUIDGenerator + // revisionGenerator generates revision hashes + revisionGenerator api.RevisionGenerator + // clock is used to get the current time + clock api.Clock + // dialect is the dialect used by the engine + // available dialects are: + // - "sqlite" + dialect string +} + +func NewEngine( + ctx context.Context, + txFactory api.TxFactory, + uuidGenerator api.UUIDGenerator, + revisionGenerator api.RevisionGenerator, + clock api.Clock, + dialect string, +) (api.Engine, error) { + e := &engine{ + txFactory: txFactory, + uuidGenerator: uuidGenerator, + revisionGenerator: revisionGenerator, + clock: clock, + dialect: dialect, + } + if err := e.Init(ctx); err != nil { + return nil, err + } + if dialect != "sqlite" { + return nil, api.ErrUnsupportedDialect + } + return e, nil +} + +// Init initializes the engine +// It creates supporting tables if they don't exist +func (e *engine) Init(ctx context.Context) error { + _, err := e.doTransaction(ctx, func(t api.Transaction) (interface{}, error) { + if err := e.initChangesTable(ctx, t); err != nil { + return nil, err + } + return nil, nil + }) + return err +} + +// PutRecord implements Engine.PutRecord +func (e *engine) PutRecord(ctx context.Context, request api.PutRecordRequest) (api.Record, error) { + ret, err := e.doTransaction(ctx, func(tx api.Transaction) (interface{}, error) { + var recPtr = &request.Record + if err := e.putRecordInternal(ctx, tx, recPtr, request.IsReplication); err != nil { + return nil, err + } + return *recPtr, nil + }) + if err != nil { + return api.Record{}, err + } + return ret.(api.Record), nil +} + +// GetRecord implements Engine.GetRecord +func (e *engine) GetRecord(ctx context.Context, request api.GetRecordRequest) (api.Record, error) { + ret, err := e.doTransaction(ctx, func(tx api.Transaction) (interface{}, error) { + if request.Revision.IsEmpty() { + found, err := e.findRecordInternal(ctx, tx, request.TableName, request.RecordID) + if err != nil { + return api.Record{}, err + } + if found == nil { + return api.Record{}, api.ErrRecordNotFound + } + return *found, nil + } else { + found, err := e.findRecordRevision(ctx, tx, request.TableName, request.RecordID, request.Revision) + if err != nil { + return api.Record{}, err + } + if found == nil { + return api.Record{}, api.ErrRecordNotFound + } + return *found, nil + } + }) + if err != nil { + return api.Record{}, err + } + return ret.(api.Record), nil +} + +// CreateTable implements Engine.CreateTable +func (e *engine) CreateTable(ctx context.Context, table api.Table) (api.Table, error) { + _, err := e.doTransaction(ctx, func(tx api.Transaction) (interface{}, error) { + err := e.createTable(ctx, tx, table) + if err != nil { + return nil, err + } + return nil, nil + }) + return table, err +} + +// GetChangeStream implements Engine.GetChanges +func (e *engine) GetChangeStream(ctx context.Context, request api.GetChangesRequest) (api.Changes, error) { + ret, err := e.doTransaction(ctx, func(tx api.Transaction) (interface{}, error) { + return e.getChangeStreamInternal(tx, ctx, request.Since) + }) + if err != nil { + return api.Changes{}, err + } + return ret.(api.Changes), nil +} + +func (e *engine) getChangeStreamInternal(tx api.Transaction, ctx context.Context, checkpoint int64) (api.Changes, error) { + + // retrieve the information about the table + columnNames, columnKinds, err := e.getColumnInfoForTable(ctx, tx, api.ChangeStreamTableName) + if err != nil { + return api.Changes{}, err + } + + // build the query + sqlQuery := ` +SELECT ` + joinStrings(columnNames, ",") + ` FROM "` + api.ChangeStreamTableName + `" +WHERE "` + api.KeyCSSequence + `" > ? +ORDER BY "` + api.KeyCSSequence + `" ASC; +` + // execute the query + rows, err := tx.Query(ctx, sqlQuery, []interface{}{checkpoint}) + if err != nil { + return api.Changes{}, err + } + defer closeRows(rows) + + // prepare the result + var records = make([]api.ChangeItem, 0) + + // iterate over the rows + for rows.Next() { + + var values map[string]api.Value + var rec api.Record + + // read the values + if values, err = rows.Read(columnKinds); err != nil { + return api.Changes{}, err + } + // create the record + if rec, err = readInRecord(api.ChangeStreamTableName, values); err != nil { + return api.Changes{}, err + } + + // parse the record + changeItem, err := parseChangeStreamItem(rec) + if err != nil { + return api.Changes{}, err + } + + // add the record to the result + records = append(records, changeItem) + } + + // check if there was an error while iterating + if err := rows.Err(); err != nil { + return api.Changes{}, err + } + + return api.Changes{ + Items: records, + }, nil +} + +func parseChangeStreamItem(rec api.Record) (api.ChangeItem, error) { + var changeItem api.ChangeItem + + recordIDValue, err := rec.GetFieldValue(api.KeyCSRecordID) + if err != nil { + return api.ChangeItem{}, err + } + if recordIDValue.Kind != api.ValueKindString { + return api.ChangeItem{}, fmt.Errorf("recordID is not a string") + } + changeItem.RecordID = recordIDValue.String.ValueOrZero() + + tableNameValue, err := rec.GetFieldValue(api.KeyCSTableName) + if err != nil { + return api.ChangeItem{}, err + } + if tableNameValue.Kind != api.ValueKindString { + return api.ChangeItem{}, fmt.Errorf("table name is not a string") + } + changeItem.TableName = tableNameValue.String.ValueOrZero() + + revisionValue, err := rec.GetFieldValue(api.KeyCSRecordRevision) + if err != nil { + return api.ChangeItem{}, err + } + if revisionValue.Kind != api.ValueKindString { + return api.ChangeItem{}, fmt.Errorf("revision is not a string") + } + revision, err := api.ParseRevision(revisionValue.String.ValueOrZero()) + if err != nil { + return api.ChangeItem{}, err + } + changeItem.RecordRevision = revision + + sequenceValue, err := rec.GetFieldValue(api.KeyCSSequence) + if err != nil { + return api.ChangeItem{}, err + } + if sequenceValue.Kind != api.ValueKindInt { + return api.ChangeItem{}, fmt.Errorf("sequence is not an int") + } + changeItem.Sequence = sequenceValue.Int.ValueOrZero() + + return changeItem, nil +} + +// getColumnTypesInternal returns the column types for the given table +func (e *engine) getColumnTypesInternal(ctx context.Context, tx api.Transaction, table string) (map[string]api.ValueKind, error) { + + // these will be the returned column types + columnKinds := []api.ValueKind{ + api.ValueKindString, + api.ValueKindString, + } + + // build the query + sql := `SELECT name, type FROM PRAGMA_TABLE_INFO(?);` + + // execute the query + rows, err := tx.Query(ctx, sql, []interface{}{table}) + if err != nil { + return nil, err + } + defer closeRows(rows) + + // prepare the result + columns := make(map[string]api.ValueKind) + for rows.Next() { + + // read the values + row, err := rows.Read(columnKinds) + if err != nil { + return nil, err + } + + // get the name of the column + name := row["name"].String.ValueOrZero() + // get the type of the column + kind := strings.ToLower(row["type"].String.ValueOrZero()) + + // map to the correct kind + switch kind { + case "integer": + columns[name] = api.ValueKindInt + case "real": + columns[name] = api.ValueKindFloat + case "text", "varchar": + columns[name] = api.ValueKindString + case "bool": + columns[name] = api.ValueKindBool + default: + return nil, fmt.Errorf("unknown column type: %s", kind) + } + } + // check if there was an error while iterating + if err := rows.Err(); err != nil { + return nil, err + } + + // if there are no columns, that means that the table does not exist + if len(columns) == 0 { + return nil, fmt.Errorf("table %s does not exist", table) + } + + return columns, nil +} + +// FindRecord finds a record by id within the given transaction +func (e *engine) findRecordInternal(ctx context.Context, tx api.Transaction, table string, id string) (*api.Record, error) { + + // retrieve the column information + columnNames, columnKinds, err := e.getColumnInfoForTable(ctx, tx, table) + if err != nil { + return nil, err + } + + // build the query + sqlQuery := `SELECT ` + joinStrings(columnNames, ",") + ` FROM "` + table + `" WHERE "` + api.KeyRecordID + `" = ? ORDER BY "` + api.KeyRevision + `" DESC LIMIT 1;` + + // execute the query + rows, err := tx.Query(ctx, sqlQuery, []interface{}{id}) + if err != nil { + return nil, err + } + defer closeRows(rows) + + // if there are no rows, that means that the record does not exist + // but this method will return nil instead of an error + if !rows.Next() { + return nil, nil + } + + // read the values + data, err := rows.Read(columnKinds) + if err != nil { + return nil, err + } + + // create the record + var record api.Record + if record, err = readInRecord(table, data); err != nil { + return nil, err + } + + return &record, nil +} + +// getColumnInfoForTable returns the column names and types for the given table +func (e *engine) getColumnInfoForTable(ctx context.Context, tx api.Transaction, table string) ([]string, []api.ValueKind, error) { + columnTypes, err := e.getColumnTypesInternal(ctx, tx, table) + if err != nil { + return nil, nil, err + } + var columnNames []string + var columnKinds []api.ValueKind + for columnName := range columnTypes { + columnNames = append(columnNames, columnName) + } + sortStrings(columnNames) + for _, columnName := range columnNames { + columnKinds = append(columnKinds, columnTypes[columnName]) + } + return columnNames, columnKinds, nil +} + +// putRecordInternal inserts a record into the database with the given timestamp +func (e *engine) putRecordInternal(ctx context.Context, tx api.Transaction, record *api.Record, isReplication bool) error { + + // insert the record into the database + if err := e.appendToHistory(ctx, tx, record, isReplication); err != nil { + return err + } + + // save to change stream + if !isLocalTable(record.Table) { + if err := e.appendToChangeStream(ctx, tx, record.Table, record.ID, record.Revision.String()); err != nil { + return err + } + } + + // update the view + if err := e.updateView(ctx, tx, record); err != nil { + return err + } + + return nil +} + +// appendToHistory adds the given record to the history table +func (e *engine) appendToHistory(ctx context.Context, tx api.Transaction, record *api.Record, isReplication bool) error { + + // record must have an id + if len(record.ID) == 0 { + return fmt.Errorf("record id is empty") + } + + // is the record is marked as not new, then it must have a revision + if isReplication && record.Revision.IsEmpty() { + return fmt.Errorf("record revision is empty") + } + + // if the record is new, then the record revision must found in the database + // it basically means that this record is the next revision of revision.Revision + if !isReplication && !record.Revision.IsEmpty() { + // find previous revision + _, err := e.findRecordRevision(ctx, tx, record.Table, record.ID, record.Revision) + if err != nil { + return err + } + } + + // generate a new revision is this is a new record + if !isReplication { + record.PreviousRevision = record.Revision + newRevision := generateRevision(e.revisionGenerator, record.ID, record.Revision, record.Attributes) + record.Revision = newRevision + } + + // build the query + sqlBuilder := &StringBuilder{} + fields, placeHolders, params := getUpdateRecordSQLArgs(record.ID, record.PreviousRevision, record.Revision, record.Attributes) + sqlBuilder.WriteString("INSERT INTO \"" + getHistoryTableName(record.Table) + "\" (" + joinStrings(fields, ", ") + ") VALUES (" + joinStrings(placeHolders, ", ") + ");") + + // execute the query + if _, err := tx.Exec(ctx, sqlBuilder.String(), params); err != nil { + return err + } + + return nil +} + +// updateView updates the reconciled view for the given record +func (e *engine) updateView(ctx context.Context, tx api.Transaction, rec *api.Record) error { + + // retrieve information about the table + columnNames, columnTypes, err := e.getColumnInfoForTable(ctx, tx, rec.Table) + if err != nil { + return err + } + + // build the query + sqlQuery := `SELECT ` + joinStrings(columnNames, ",") + ` FROM "` + getHistoryTableName(rec.Table) + `" WHERE "` + api.KeyDeleted + `" = false AND "` + api.KeyRecordID + `" = ? ORDER BY "` + api.KeyRevision + `" DESC LIMIT 1;` + + // execute the query + rows, err := tx.Query(ctx, sqlQuery, []interface{}{rec.ID}) + if err != nil { + return err + } + defer closeRows(rows) + + // it there are no rows, that means that the last version of the record is deleted, + // or that there was no history. We need to delete the record from the view to + // reflect this + if !rows.Next() { + _, err := tx.Exec(ctx, `DELETE FROM "`+rec.Table+`" WHERE "`+api.KeyRecordID+`" = ?`, []interface{}{rec.ID}) + return err + } + + // read the record from the query result + data, err := rows.Read(columnTypes) + if err != nil { + return err + } + + // build the update query + var fields []string + var placeholders []string + var values []interface{} + for k := range data { + if k == api.KeyPrevision || k == api.KeyDeleted { + continue + } + fields = append(fields, k) + placeholders = append(placeholders, "?") + value := data[k] + values = append(values, &value) + } + sqlBuilder := &StringBuilder{} + sqlBuilder.WriteString(`INSERT INTO "` + rec.Table + `" ("` + strings.Join(fields, `", "`) + `") VALUES (` + strings.Join(placeholders, ", ") + `)`) + sqlBuilder.WriteString(` ON CONFLICT ("` + api.KeyRecordID + `") DO UPDATE SET `) + var i int + for _, k := range fields { + if k == api.KeyRecordID { + continue + } + if i != 0 { + sqlBuilder.WriteString(", ") + } + sqlBuilder.WriteString(`"` + k + `" = excluded."` + k + `"`) + i++ + } + sqlBuilder.WriteString(`;`) + + // execute the query + _, err = tx.Exec(ctx, sqlBuilder.String(), values) + return err +} + +// findRecordRevision finds the revision of a record +func (e *engine) findRecordRevision(ctx context.Context, tx api.Transaction, table string, id string, revision api.Revision) (*api.Record, error) { + + // get information about the table + columnNames, columnTypes, err := e.getColumnInfoForTable(ctx, tx, table) + if err != nil { + return nil, err + } + + // build the query + sqlQuery := &StringBuilder{} + sqlQuery.WriteString("SELECT " + joinStrings(columnNames, ",") + " FROM \"" + getHistoryTableName(table) + "\" WHERE \"" + api.KeyRecordID + "\" = ? AND \"" + api.KeyRevision + "\" = ?") + + // execute the query + result, err := tx.Query(ctx, sqlQuery.String(), []interface{}{id, revision.String()}) + if err != nil { + return nil, err + } + defer closeResult(result) + + // if there are no rows, return nil + if !result.Next() { + return nil, nil + } + + // if there was an error while iterating, return it + if err := result.Err(); err != nil { + return nil, err + } + + // get the values + values, err := result.Read(columnTypes) + if err != nil { + return nil, err + } + + // build the record + ret, err := readInRecord(table, values) + if err != nil { + return nil, err + } + + return &ret, nil +} + +// createTable creates a table +func (e *engine) createTable(ctx context.Context, tx api.Transaction, table api.Table) error { + + // validate the table structure + if err := validateTable(table); err != nil { + return err + } + + // check that a table with the same name does not already exist + if err := e.checkTableDoesNotExist(ctx, tx, table.Name); err != nil { + return err + } + + // create the table + if err := e.createTableInternal(ctx, tx, table, false); err != nil { + return err + } + + // create the history table + if err := e.createTableInternal(ctx, tx, table, true); err != nil { + return err + } + + return nil +} + +// createTableInternal creates a table +func (e *engine) createTableInternal(ctx context.Context, tx api.Transaction, table api.Table, isHistoryTable bool) error { + + // build the query + sqlBuilder := &StringBuilder{} + sqlBuilder.WriteString("CREATE TABLE IF NOT EXISTS \"") + if isHistoryTable { + sqlBuilder.WriteString(getHistoryTableName(table.Name)) + } else { + sqlBuilder.WriteString(table.Name) + } + sqlBuilder.WriteString("\" (") + columns := make([]api.Column, 0) + + // append the primary key column + columns = append(columns, api.Column{ + Name: api.KeyRecordID, + Type: "varchar", + }) + + // append the previous revision column if this is the history table + if isHistoryTable { + columns = append(columns, api.Column{ + Name: api.KeyPrevision, + Type: "varchar", + Constraints: []api.ColumnConstraint{ + {NotNull: &api.NotNullColumnConstraint{}}, + }, + }) + } + + // append the revision column + columns = append(columns, api.Column{ + Name: api.KeyRevision, + Type: "varchar", + Constraints: []api.ColumnConstraint{ + {NotNull: &api.NotNullColumnConstraint{}}, + }, + }) + + // build the constraints + var tableConstraints []api.TableConstraint + for _, c := range table.Constraints { + tableConstraints = append(tableConstraints, c) + } + + if isHistoryTable { + // add the deleted column + columns = append(columns, api.Column{ + Name: api.KeyDeleted, + Type: "boolean", + Default: "false", + Constraints: []api.ColumnConstraint{ + {NotNull: &api.NotNullColumnConstraint{}}, + }, + }) + // add the primary key constraint for the history table + tableConstraints = append(tableConstraints, api.TableConstraint{ + PrimaryKey: &api.PrimaryKeyTableConstraint{ + Columns: []string{api.KeyRecordID, api.KeyRevision}, + }, + }) + } else { + // add the primary key constraint for the view table + tableConstraints = append(tableConstraints, api.TableConstraint{ + PrimaryKey: &api.PrimaryKeyTableConstraint{ + Columns: []string{api.KeyRecordID}, + }, + }) + } + + for _, c := range table.Columns { + columns = append(columns, c) + } + + // write the SQL for the columns + for i, column := range columns { + writeColumnDefinition(sqlBuilder, column) + if i < len(columns)-1 { + sqlBuilder.WriteString(", ") + } + } + + // write the SQL for the constraints + for _, c := range tableConstraints { + sqlBuilder.WriteString(", ") + writeConstraintDefinition(sqlBuilder, c) + } + + sqlBuilder.WriteString(")") + + // execute the query + if _, err := tx.Exec(ctx, sqlBuilder.String(), []interface{}{}); err != nil { + return err + } + + return nil +} + +// checkTableDoesNotExist checks that a table does not exist +func (e *engine) checkTableDoesNotExist(ctx context.Context, tx api.Transaction, name string) error { + if e.dialect == "sqlite" { + sql := `SELECT "name" FROM "sqlite_master" WHERE "type" = 'table' AND "name" = ?` + res, err := tx.Query(ctx, sql, []interface{}{name}) + if err != nil { + return err + } + defer closeResult(res) + if !res.Next() { + return nil + } + if err := res.Err(); err != nil { + return err + } + return api.NewTableAlreadyExistsErr(name) + } else { + return api.NewError(api.ErrCodeInternalError, "not implemented") + } +} + +// doTransaction is a helper method that wraps SQL operations within a transaction +func (e *engine) doTransaction(ctx context.Context, fn func(t api.Transaction) (interface{}, error)) (interface{}, error) { + tr, err := e.txFactory(ctx) + if err != nil { + return nil, err + } + errored := false + defer func() { + if !errored { + handleCommit(tr) + } else { + handleRollback(tr) + } + }() + ret, err := fn(tr) + if err != nil { + errored = true + } + return ret, err +} + +// initChangesTable creates the changes table if it does not exist +func (e *engine) initChangesTable(ctx context.Context, t api.Transaction) error { + sqlStatement := ` +CREATE TABLE IF NOT EXISTS "` + api.ChangeStreamTableName + `" ( + "` + api.KeyCSSequence + `" integer PRIMARY KEY AUTOINCREMENT, + "` + api.KeyCSTableName + `" varchar NOT NULL, + "` + api.KeyCSRecordID + `" varchar NOT NULL, + "` + api.KeyCSRecordRevision + `" varchar NOT NULL +);` + _, err := t.Exec(ctx, sqlStatement, nil) + if err != nil { + return err + } + return nil +} + +// appendToChangeStream appends a change to the change stream +func (e *engine) appendToChangeStream(ctx context.Context, t api.Transaction, tableName string, recordId, revision string) error { + sqlStatement := ` +INSERT INTO "` + api.ChangeStreamTableName + `" ("` + api.KeyCSTableName + `", "` + api.KeyCSRecordID + `", "` + api.KeyCSRecordRevision + `") +VALUES (?, ?, ?) +` + _, err := t.Exec(ctx, sqlStatement, []interface{}{tableName, recordId, revision}) + if err != nil { + return err + } + return nil +} + +// generateRevision generates a new revision for a record +func generateRevision(revisionGenerator api.RevisionGenerator, recordId string, previousRevision api.Revision, data map[string]api.Value) api.Revision { + revisionData := map[string]interface{}{ + api.KeyRecordID: recordId, + } + if previousRevision != api.EmptyRevision { + revisionData[api.KeyPrevision] = previousRevision + } + for k, v := range data { + revisionData[k] = v + } + return revisionGenerator.Generate(previousRevision.Num+1, revisionData) +} + +// readInRecord builds a record from a database result +func readInRecord(table string, data map[string]api.Value) (api.Record, error) { + var err error + var record = api.Record{ + Attributes: make(map[string]api.Value), + } + for columnName, columnValue := range data { + switch columnName { + case api.KeyRecordID: + record.ID = columnValue.String.ValueOrZero() + continue + case api.KeyRevision: + record.Revision, err = api.ParseRevision(columnValue.String.ValueOrZero()) + if err != nil { + return api.Record{}, err + } + case api.KeyPrevision: + value := columnValue.String.ValueOrZero() + if len(value) == 0 { + record.PreviousRevision = api.EmptyRevision + } else { + record.PreviousRevision, err = api.ParseRevision(value) + if err != nil { + return api.Record{}, err + } + } + default: + record.Attributes[columnName] = columnValue + } + } + record.Table = table + return record, nil +} + +func writeColumnDefinition(sqlBuilder *StringBuilder, column api.Column) { + sqlBuilder.WriteString("\"" + column.Name + "\" " + column.Type) + if column.Default != "" { + sqlBuilder.WriteString(" DEFAULT " + column.Default) + } + for _, constraint := range column.Constraints { + if constraint.NotNull != nil { + sqlBuilder.WriteString(" NOT NULL") + } + if constraint.PrimaryKey != nil { + sqlBuilder.WriteString(" PRIMARY KEY") + } + } +} + +func writeConstraintDefinition(builder *StringBuilder, constraint api.TableConstraint) { + if constraint.PrimaryKey != nil { + builder.WriteString("PRIMARY KEY (") + for i, column := range constraint.PrimaryKey.Columns { + builder.WriteString("\"" + column + "\"") + if i < len(constraint.PrimaryKey.Columns)-1 { + builder.WriteString(", ") + } + } + builder.WriteString(")") + } +} + +func handleCommit(tr api.Transaction) { + if err := tr.Commit(); err != nil { + fmt.Printf("Error on commit: %v\n", err) + handleRollback(tr) + } +} + +func handleRollback(tr api.Transaction) { + if err := tr.Rollback(); err != nil { + fmt.Printf("error while rolling back transaction: %v", err) + } +} + +func closeRows(rows api.ResultReader) { + if err := rows.Close(); err != nil { + fmt.Printf("error closing rows: %v\n", err) + } +} + +func closeResult(result api.ResultReader) { + if err := result.Close(); err != nil { + fmt.Printf("error closing result: %v\n", err) + } +} + +func isLocalTable(tableName string) bool { + if len(tableName) <= 6 { + return false + } + return tableName[len(tableName)-6:] == "local_" +} + +func getUpdateRecordSQLArgs(id string, previousRevision, currentRevision api.Revision, data map[string]api.Value) (fields, placeholders []string, values []interface{}) { + fields = append(fields, `"`+api.KeyRecordID+`"`, `"`+api.KeyRevision+`"`, `"`+api.KeyPrevision+`"`) + placeholders = append(placeholders, "?", "?", "?") + values = append(values, id, currentRevision.String(), previousRevision.String()) + + for field, value := range data { + fields = append(fields, `"`+field+`"`) + placeholders = append(placeholders, "?") + values = append(values, &value) + } + return +} + +func getHistoryTableName(tableName string) string { + return tableName + "_history" +} diff --git a/pkg/server/data/engine/engine_test.go b/pkg/server/data/engine/engine_test.go new file mode 100644 index 000000000..8aef50e14 --- /dev/null +++ b/pkg/server/data/engine/engine_test.go @@ -0,0 +1,279 @@ +package engine + +import ( + "context" + "fmt" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/nrc-no/core/pkg/server/data/api" + "github.com/nrc-no/core/pkg/server/data/test" + "github.com/nrc-no/core/pkg/server/data/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type Suite struct { + suite.Suite + Bench *Bench +} + +func (s *Suite) SetupSuite() { + s.Bench = NewTestBench("file::memory:?cache=shared") +} + +func (s *Suite) TearDownSuite() { + if err := s.Bench.TearDown(); err != nil { + panic(err) + } +} + +func (s *Suite) SetupTest() { + if err := s.Bench.Reset(); err != nil { + s.T().Fatal(err) + } +} + +func TestSuite(t *testing.T) { + suite.Run(t, new(Suite)) +} + +func (s *Suite) Test_Engine_CreateTable() { + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + type testCase struct { + // name of the test case + name string + // description of the test case + description string + // table to create + table api.Table + // error assertion function + expectError assert.ErrorAssertionFunc + // expected sql statements + expectedStatements []test.ExpectedStatement + // optional setup function + doBefore func() error + } + + testCases := []testCase{ + { + name: "table with no name", + description: "Should not allow to create a table with an empty name", + table: api.Table{ + Name: "", + Columns: []api.Column{ + {Name: "field1", Type: "varchar"}, + }, + }, + expectError: test.ErrorIs(api.ErrInvalidTableName), + }, + { + name: "table with no columns", + description: "Should not allow to create a table with no columns", + table: api.Table{ + Name: "mock_table_1", + Columns: []api.Column{}, + }, + expectError: test.ErrorIs(api.ErrEmptyColumns), + }, + { + name: "table with duplicate column name", + description: "Should not allow to create a table with duplicate column names", + expectError: test.ErrorIs(api.ErrDuplicateColumnName), + table: api.Table{ + Name: "mock_table_2", + Columns: []api.Column{ + {Name: "field1", Type: "varchar"}, + {Name: "field1", Type: "varchar"}, + }, + }, + }, + { + name: "table with invalid column name", + description: "Should not allow to create a table with an invalid column name", + expectError: test.ErrorIs(api.ErrInvalidColumnName), + table: api.Table{ + Name: "mock_table_3", + Columns: []api.Column{ + {Name: "field1", Type: "varchar"}, + {Name: " field2 ", Type: "varchar"}, + }, + }, + }, + { + name: "table with invalid column type", + description: "Should not allow to create a table with an invalid column type", + expectError: test.ErrorIs(api.ErrInvalidColumnType), + table: api.Table{ + Name: "mock_table_4", + Columns: []api.Column{{Name: "field1", Type: "Bla"}}, + }, + }, + { + name: "already exists", + description: "Should not allow to create a table that already exists", + expectError: test.ErrorIs(api.ErrTableAlreadyExists), + doBefore: func() error { + _, err := s.Bench.DB.Exec(`CREATE TABLE "mock_table_6" ("bla" varchar)`) + return err + }, + table: api.Table{ + Name: "mock_table_6", + Columns: []api.Column{ + {Name: "field1", Type: "varchar"}, + }, + }, + expectedStatements: []test.ExpectedStatement{ + { + SQL: `SELECT "name" FROM "sqlite_master" WHERE "type" = 'table' AND "name" = ?`, + Params: []interface{}{"mock_table_6"}, + }, + }, + }, + { + name: "valid table", + description: "Should allow to create a valid table", + table: api.Table{ + Name: "test_create_valid_table", + Columns: []api.Column{ + {Name: "field1", Type: "varchar"}, + {Name: "field2", Type: "varchar"}, + }, + }, + expectedStatements: []test.ExpectedStatement{ + { + SQL: `SELECT "name" FROM "sqlite_master" WHERE "type" = 'table' AND "name" = ?`, + Params: []interface{}{"test_create_valid_table"}, + }, { + SQL: `CREATE TABLE IF NOT EXISTS "test_create_valid_table" ("_id" varchar, "_rev" varchar NOT NULL, "field1" varchar, "field2" varchar, PRIMARY KEY ("_id"))`, + Params: []interface{}{}, + }, { + SQL: `CREATE TABLE IF NOT EXISTS "test_create_valid_table_history" ("_id" varchar, "_prev" varchar NOT NULL, "_rev" varchar NOT NULL, "_deleted" boolean DEFAULT false NOT NULL, "field1" varchar, "field2" varchar, PRIMARY KEY ("_id", "_rev"))`, + Params: []interface{}{}, + }, + }, + }, + } + for _, tc := range testCases { + tc := tc + s.T().Run(tc.name, func(t *testing.T) { + if tc.doBefore != nil { + err := tc.doBefore() + if !assert.NoError(t, err) { + return + } + } + s.Bench.Recorder.Reset() + _, err := s.Bench.Engine.CreateTable(ctx, tc.table) + if tc.expectError != nil { + tc.expectError(t, err) + } else { + assert.NoError(t, err) + } + if tc.expectedStatements != nil { + s.Bench.Recorder.AssertStatementsExecuted(t, tc.expectedStatements) + } + }) + } +} + +func (s *Suite) Test_Engine_GetRecord_Exists() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if _, err := s.Bench.Engine.CreateTable(ctx, api.Table{ + Name: "test_get_record_exists", + Columns: []api.Column{ + {Name: "field1", Type: "varchar"}, + }, + }); !assert.NoError(s.T(), err) { + return + } + + _, err := s.Bench.Engine.PutRecord(ctx, api.PutRecordRequest{ + Record: api.Record{ + Table: "test_get_record_exists", + ID: "mock_id_1", + Attributes: map[string]api.Value{ + "field1": api.NewStringValue("value1", true), + }, + }, + }) + assert.NoError(s.T(), err) + + found, err := s.Bench.Engine.GetRecord(ctx, api.GetRecordRequest{ + TableName: "test_get_record_exists", + RecordID: "mock_id_1", + }) + assert.NoError(s.T(), err) + + s.T().Log(found) +} + +func (s *Suite) Test_Engine_GetRecord_DoesNotExist() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := s.Bench.Engine.CreateTable(ctx, api.Table{ + Name: "mock_table_2", + Columns: []api.Column{ + {Name: "field1", Type: "varchar"}, + }, + }) + assert.NoError(s.T(), err) + + _, err = s.Bench.Engine.GetRecord(ctx, api.GetRecordRequest{ + TableName: "mock_table_2", + RecordID: "bla", + }) + assert.ErrorIs(s.T(), err, api.ErrRecordNotFound) +} + +func (s *Suite) Test_Engine_PutRecord_MultipleRevisions() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resetClock := s.Bench.Clock.UseClock(&utils.Clock{}) + defer resetClock() + + _, err := s.Bench.Engine.CreateTable(ctx, api.Table{ + Name: "mock_table_3", + Columns: []api.Column{ + {Name: "field1", Type: "varchar"}, + {Name: "field2", Type: "varchar"}, + }, + }) + + request := api.PutRecordRequest{ + Record: api.Record{ + Table: "mock_table_3", + ID: "mock_id_1", + Attributes: map[string]api.Value{ + "field1": api.NewStringValue("value1", true), + "field2": api.NewStringValue("value2", true), + }, + }, + } + + rec, err := s.Bench.Engine.PutRecord(ctx, request) + assert.NoError(s.T(), err) + + s.T().Log(rec) + + var revisions []api.Revision + revisions = append(revisions, rec.GetRevision()) + + for i := 0; i < 5; i++ { + record := request.Record + record = record.SetFieldValue("field1", api.NewStringValue(fmt.Sprintf("value1_%d", i), true)) + rec2, err := s.Bench.Engine.PutRecord(ctx, api.PutRecordRequest{Record: record}) + assert.NoError(s.T(), err) + revisions = append(revisions, rec2.GetRevision()) + rec = rec2 + } + + s.T().Log(revisions) +} diff --git a/pkg/server/data/engine/pkg.go b/pkg/server/data/engine/pkg.go new file mode 100644 index 000000000..a74da98bb --- /dev/null +++ b/pkg/server/data/engine/pkg.go @@ -0,0 +1,47 @@ +// Package engine provides the core database engine +// +// This is a *headless* database engine that can be used to store and retrieve +// records. It is headless because it does not communicate with the database +// directly. Instead, it generates SQL statements that are to be executed. +// +// The engine is designed to be used in standalone mode or as a part of a +// peer-to-peer network. +// +// # Introduction +// The engine is designed to store and retrieve records in a table. +// A table is a collection of records. +// A record is a collection of fields. +// A record has a unique id. +// A field is a named value. +// A value is a string, a number, or a boolean. +// +// The engine is designed to be used in standalone mode or as a part of a +// peer-to-peer network. +// +// For each table, the engine internally creates two distinct tables: +// * +// *
_history +// +// The
table is used to store the reconciled state of the records. +// The
_history table is used to store the history of the records. +// +// When a new version of a record is created, the engine will append this +// version inside the
_history table. The
table will be +// updated to reflect the new version. +// +// When multiple versions of a record are in a conflict, the engine will +// choose the version with the highest revision number. This is the version +// of the record that will be reflected in the
table. +// +// The revision of a record is a string in the format -. +// The is an incrementing number. +// The is a hash of the record. +// +// Two records might have the same number but different hashes. +// In this case, the record with the alphabetically higher hash is the +// winner. +// +// The engine also stores the previous revision of a record. +// Which allows to build the history of a record, where the history +// is a branching tree. +package engine diff --git a/pkg/server/data/engine/reconciler.go b/pkg/server/data/engine/reconciler.go new file mode 100644 index 000000000..75b32e57e --- /dev/null +++ b/pkg/server/data/engine/reconciler.go @@ -0,0 +1,208 @@ +package engine + +import ( + "context" + "errors" + + "github.com/nrc-no/core/pkg/server/data/api" +) + +const checkpointTableName = "local_reconciler_changes" +const checkpointKey = "checkpoint" + +// reconciler is a type that reconciles a source and a target database +type reconciler struct{} + +// initTable creates the necessary table for the reconciler to work +// the table contains the reconciled revision of the source database +// so that we don't pull tons of data from the source database every +// time we run the reconciler. +// +// The database change stream maintains a sequence number for each +// operation to the database. This incremental sequence number is used +// to determine which operations have already been reconciled. +// +// It is basically a map -> +// that exists in the destination database. +func initTable(ctx context.Context, destination api.Engine) error { + _, err := destination.CreateTable(ctx, api.Table{ + Name: checkpointTableName, + Columns: []api.Column{ + { + Name: checkpointKey, + Type: "integer", + Constraints: []api.ColumnConstraint{ + { + NotNull: &api.NotNullColumnConstraint{}, + }, + }, + }, + }, + }) + if api.IsError(err, api.ErrCodeTableAlreadyExists) { + return nil + } + if err != nil { + return err + } + return nil +} + +// Reconcile reconciles the source and target database +// +// It must work in a peer to peer environment, where there are no authority, but +// where both the clients and the servers are both sources of truth. +// +// Usually, the target database would be the locally running database, and the +// source database would be the remote database. So each database could technically +// run a reconciler alongside itself, and pull the data from the other databases. +// +// The logic of the reconciler is as follows: +// +// Given a source and a target database, we want to reconcile +// the source database with the target database. The goal is to pull the changes +// from the source database that are not already in the target database. +// +// For each source database, the reconciler maintains a sequence number +// corresponding to the last change that has been reconciled from that source. +// That sequence number is maintained in the destination database in a +// 'local_reconciler_changes' table. It is simply a pair of the source database ID +// and the last seen sequence number for that source. +// +// E.g. +// source seq +// foo 101 +// bar 200 +// baz 0 +// +// The reconciler will query the source database to get the changes that +// happened since that known sequence number. And for each change, the +// reconciler will apply the change to the target database. +// +// If the destination already has the change, the reconciler will skip it. +// If the destination does not have the change, the reconciler will apply the change. +// +// Once the changes are applied, the reconciler will update the sequence number +// in the 'local_reconciler_changes' table to reflect the last change that was reconciled +// from that source database. +// +// E.g. +// source seq +// foo 150 +// bar 210 +// baz 5 +// +// This mechanism allows multiple databases to be reconciled in between each other. +// The source and target databases can be accessed locally, or via other protocols +// such as HTTP, GRPC, etc. As long as the protocol implements the Engine interface. +// +// All the reconciler does is to synchronize the source history tables into the destination's history tables. +// The engine in the destination will take care of electing the winning records in the event +// of a conflict. This is also a deterministic process, so we can be sure that the +// results will be the same regardless of the order in which the changes are applied, +// regardless of the database that is the source and the database that is the target. +func (r *reconciler) Reconcile(ctx context.Context, source api.ReadInterface, destination api.Engine) error { + + // TODO assign unique id to destination + var sourceId = "dest1234" + var checkpoint int64 = -1 + var checkpointRec *api.Record + var err error + + // create the checkpoint table if it doesn't exist + if err = initTable(ctx, destination); err != nil { + return err + } + + // get the last checkpoint from the destination database for the source database + foundCheckpointRec, err := destination.GetRecord(ctx, api.GetRecordRequest{ + TableName: checkpointTableName, + RecordID: sourceId, + }) + if err != nil { + if !errors.Is(err, api.ErrRecordNotFound) { + return err + } + } else { + checkpointRec = &foundCheckpointRec + } + + if checkpointRec != nil { + // already has a checkpoint for the source database + // retrieve the checkpoint + var ( + fieldValue interface{} + ok bool + ) + if fieldValue, err = checkpointRec.GetFieldValue(checkpointKey); err != nil { + return err + } + if checkpoint, ok = fieldValue.(int64); !ok { + return errors.New("checkpoint is not an int64") + } + } else { + checkpointRec = &api.Record{ + Table: checkpointTableName, + ID: sourceId, + Attributes: map[string]api.Value{ + checkpointKey: api.NewIntValue(checkpoint, true), + }, + } + } + + changes, err := source.GetChangeStream(ctx, api.GetChangesRequest{Since: checkpoint}) + if err != nil { + return err + } + // todo: Wrap this in a transaction somehow + // probably have to create a WithTransaction(fn func(e Engine) error) method + // But that would only work locally, grpc or through websockets, not through HTTP + // Perhaps it's not so bad that reconciliation is partial, but it's still not ideal + // todo: batch this for sure + for _, change := range changes.Items { + // for each change in the source change stream + + // skip if this checkpoint has already been reconciled + if change.Sequence == checkpoint { + continue + } + + // check if the record revision already exist in the destination + // if so, we don't need to do anything + _, err = destination.GetRecord(ctx, api.GetRecordRequest{ + TableName: change.TableName, + RecordID: change.RecordID, + Revision: change.RecordRevision, + }) + if !api.IsError(err, api.ErrCodeRecordNotFound) { + return err + } + + // get the revision from the source + sourceRevisionRec, err := source.GetRecord(ctx, api.GetRecordRequest{ + TableName: change.TableName, + RecordID: change.RecordID, + Revision: change.RecordRevision, + }) + if err != nil { + return err + } + + // insert the revision into the destination + if _, err := destination.PutRecord(ctx, api.PutRecordRequest{ + Record: sourceRevisionRec, + IsReplication: true, + }); err != nil { + return err + } + checkpointRec.SetFieldValue(checkpointKey, api.NewIntValue(checkpoint, true)) + } + + if _, err := destination.PutRecord(ctx, api.PutRecordRequest{ + Record: *checkpointRec, + }); err != nil { + return err + } + + return nil +} diff --git a/pkg/server/data/engine/reconciler_test.go b/pkg/server/data/engine/reconciler_test.go new file mode 100644 index 000000000..9200f30ba --- /dev/null +++ b/pkg/server/data/engine/reconciler_test.go @@ -0,0 +1,100 @@ +package engine + +import ( + "context" + "testing" + + "github.com/nrc-no/core/pkg/server/data/api" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type ReconcilerSuite struct { + suite.Suite + ctx context.Context + cancel context.CancelFunc + source *Bench + destination *Bench +} + +func (s *ReconcilerSuite) SetupSuite() { + s.source = NewTestBench("file::memory:") + s.destination = NewTestBench("file::memory:") +} + +func (s *ReconcilerSuite) SetupTest() { + s.ctx, s.cancel = context.WithCancel(context.Background()) + if err := s.source.Reset(); err != nil { + panic(err) + } + if err := s.destination.Reset(); err != nil { + panic(err) + } +} + +func (s *ReconcilerSuite) TearDownTest() { + s.cancel() +} + +func (s *ReconcilerSuite) TestReconcileEmptySourceAndDestination() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := &reconciler{} + if err := r.Reconcile(ctx, s.source.Engine, s.destination.Engine); !assert.NoError(s.T(), err) { + return + } +} + +func (s *ReconcilerSuite) TestReconcileRecordFromSource() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := &reconciler{} + + for _, e := range []api.Engine{s.source.Engine, s.destination.Engine} { + if err := e.CreateTable(ctx, api.Table{ + Name: "test", + Columns: []api.Column{ + { + Name: "field1", + Type: "varchar", + }, + }, + }); !assert.NoError(s.T(), err) { + return + } + } + + srcRec, err := s.source.Engine.PutRecord(ctx, api.PutRecordRequest{ + Record: api.Record{ + Table: "test", + ID: "mock_id_1", + Attributes: map[string]api.Value{ + "field1": api.NewStringValue("mock_value_1", true), + }, + }, + }) + if !assert.NoError(s.T(), err) { + return + } + + if err := r.Reconcile(ctx, s.source.Engine, s.destination.Engine); !assert.NoError(s.T(), err) { + return + } + + // find record in destination + destRec, err := s.source.Engine.GetRecord(ctx, api.GetRecordRequest{ + TableName: "test", + RecordID: srcRec.GetID(), + }) + if !assert.NoError(s.T(), err) { + return + } + + assert.Equal(s.T(), srcRec, destRec) + + s.T().Log(srcRec, destRec) +} + +func TestReconcilerSuite(t *testing.T) { + suite.Run(t, new(ReconcilerSuite)) +} diff --git a/pkg/server/data/engine/utils.go b/pkg/server/data/engine/utils.go new file mode 100644 index 000000000..7c87d1099 --- /dev/null +++ b/pkg/server/data/engine/utils.go @@ -0,0 +1,61 @@ +package engine + +import ( + "bytes" +) + +// joinStrings joins a list of strings into a single string with a separator +// it is a zero-dependant version of strings.Join(strings, sep) +func joinStrings(strings []string, separator string) string { + if len(strings) == 0 { + return "" + } + if len(strings) == 1 { + return strings[0] + } + n := len(separator) * (len(strings) - 1) + for i := 0; i < len(strings); i++ { + n += len(strings[i]) + } + var b = &StringBuilder{} + b.buf.Grow(n) + b.WriteString(strings[0]) + for _, s := range strings[1:] { + b.WriteString(separator) + b.WriteString(s) + } + return b.String() +} + +// StringBuilder is a simple string builder +// it is a zero-dependant version of strings.Builder +type StringBuilder struct { + buf bytes.Buffer +} + +// WriteString writes a string to the buffer +func (b *StringBuilder) WriteString(s string) { + b.buf.WriteString(s) +} + +// String returns the string representation of the buffer +func (b *StringBuilder) String() string { + return b.buf.String() +} + +// Reset resets the buffer +func (b *StringBuilder) Reset() { + b.buf.Reset() +} + +// sortStrings uses bubble sort to sort a list of strings +// it is a zero-dependant version of sort.Strings(strings) +func sortStrings(strings []string) { + for i := 0; i < len(strings); i++ { + for j := 0; j < len(strings)-1; j++ { + if strings[j] > strings[j+1] { + strings[j], strings[j+1] = strings[j+1], strings[j] + } + } + } +} diff --git a/pkg/server/data/engine/validation.go b/pkg/server/data/engine/validation.go new file mode 100644 index 000000000..af03657f7 --- /dev/null +++ b/pkg/server/data/engine/validation.go @@ -0,0 +1,64 @@ +package engine + +import ( + "github.com/nrc-no/core/pkg/server/data/api" +) + +// validateColumnType checks that the data type is a valid column type. +func validateColumnType(typeName string) bool { + var allowedTypes = map[string]bool{ + "varchar": true, + "integer": true, + "timestamp": true, + "boolean": true, + } + return allowedTypes[typeName] +} + +// validateIdentifier validates that a given name is a valid identifier +// for use in the database, such as a table or column name. +func validateIdentifier(name string) bool { + if len(name) == 0 { + return false + } + var prohibitedNames = map[string]bool{ + api.KeyRecordID: true, + api.KeyRevision: true, + api.KeyPrevision: true, + } + if prohibitedNames[name] { + return false + } + for _, c := range name { + if c != '_' && (c < '0' || c > 'z') { + return false + } + } + return true +} + +// validateTable validates that a given table definition is valid +func validateTable(table api.Table) error { + if !validateIdentifier(table.Name) { + return api.ErrInvalidTableName + } + + if len(table.Columns) == 0 { + return api.ErrEmptyColumns + } + + var columnNames = map[string]bool{} + for _, column := range table.Columns { + if columnNames[column.Name] { + return api.NewDuplicateColumnNameErr(column.Name) + } + if !validateIdentifier(column.Name) { + return api.NewInvalidColumnNameErr(column.Name) + } + if !validateColumnType(column.Type) { + return api.NewInvalidColumnTypeErr(column.Type) + } + columnNames[column.Name] = true + } + return nil +} diff --git a/pkg/server/data/engine/zzz_test.go b/pkg/server/data/engine/zzz_test.go new file mode 100644 index 000000000..a20bcaa1d --- /dev/null +++ b/pkg/server/data/engine/zzz_test.go @@ -0,0 +1,97 @@ +package engine + +import ( + "context" + "fmt" + "time" + + "github.com/jmoiron/sqlx" + "github.com/nrc-no/core/pkg/server/data/api" + "github.com/nrc-no/core/pkg/server/data/test" + "github.com/nrc-no/core/pkg/server/data/utils" +) + +// Bench is a testing utility that makes it easy to setup a test database +// with all the dependencies +type Bench struct { + DBName string + Engine api.Engine + DB *sqlx.DB + Recorder *test.DBRecorder + Clock *test.MockClock + UUIDGenerator *test.MockUUIDGenerator + RevisionGenerator *utils.Md5RevGenerator + Cancel context.CancelFunc + Ctx context.Context +} + +// TearDown closes the database connection and cleans up the test database +func (b *Bench) TearDown() error { + if b.Cancel != nil { + b.Cancel() + } + if b.DB != nil { + if err := b.DB.Close(); err != nil { + return err + } + } + return nil +} + +// Reset resets the database to a clean state +func (b *Bench) Reset() error { + + var err error + if err = b.TearDown(); err != nil { + return err + } + + b.Ctx, b.Cancel = context.WithCancel(context.Background()) + + // create the database connection + b.DB, err = sqlx.ConnectContext(b.Ctx, "sqlite3", b.DBName) + if err != nil { + return err + } + + // drop all data + if err := test.DropAll(b.DB); err != nil { + panic(err) + } + + b.Engine, err = NewEngine( + context.Background(), + TxFactory(b.DB, b.Recorder), + b.UUIDGenerator, + b.RevisionGenerator, + b.Clock, + "sqlite", + ) + return err +} + +// NewTestBench creates a new test bench +func NewTestBench(dbName string) *Bench { + b := &Bench{ + DBName: dbName, + Recorder: &test.DBRecorder{}, + Clock: &test.MockClock{ + TheTime: time.Now().Unix(), + }, + UUIDGenerator: &test.MockUUIDGenerator{ + ReturnUUID: "12345678-1234-1234-1234-123456789012", + }, + RevisionGenerator: &utils.Md5RevGenerator{}, + } + return b +} + +// TxFactory creates a new transaction factory +func TxFactory(conn *sqlx.DB, recorder *test.DBRecorder) func(ctx context.Context) (api.Transaction, error) { + return func(ctx context.Context) (api.Transaction, error) { + return utils.NewTransaction(ctx, conn, func(qry string, args []interface{}) { + fmt.Println(qry, args) + recorder.Record(qry, args) + }) + } +} diff --git a/pkg/server/data/handler/get_changes.go b/pkg/server/data/handler/get_changes.go new file mode 100644 index 000000000..a9fb186f4 --- /dev/null +++ b/pkg/server/data/handler/get_changes.go @@ -0,0 +1,41 @@ +package handler + +import ( + "encoding/json" + "net/http" + "strconv" + + "github.com/emicklei/go-restful/v3" + "github.com/nrc-no/core/pkg/server/data/api" +) + +func getChanges(e api.Engine, request api.GetChangesRequest) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + changeStream, err := e.GetChangeStream(req.Context(), request) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + jsonBytes, err := json.Marshal(changeStream) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(jsonBytes) + }) +} + +func restfulGetChanges(e api.Engine) func(req *restful.Request, resp *restful.Response) { + return func(req *restful.Request, resp *restful.Response) { + var request api.GetChangesRequest + sinceStr := req.QueryParameter(queryParamSince) + since, err := strconv.ParseInt(sinceStr, 10, 64) + if err != nil { + resp.WriteErrorString(http.StatusBadRequest, err.Error()) + return + } + request.Since = since + getChanges(e, request).ServeHTTP(resp.ResponseWriter, req.Request) + } +} diff --git a/pkg/server/data/handler/get_row.go b/pkg/server/data/handler/get_row.go new file mode 100644 index 000000000..676835fd7 --- /dev/null +++ b/pkg/server/data/handler/get_row.go @@ -0,0 +1,45 @@ +package handler + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/emicklei/go-restful/v3" + "github.com/nrc-no/core/pkg/server/data/api" +) + +func getRow(e api.Engine, request api.GetRecordRequest) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + rec, err := e.GetRecord(req.Context(), request) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + responseBytes, err := json.Marshal(rec) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + header := w.Header() + header.Set("Content-Type", "application/json") + header.Set("ETag", fmt.Sprintf("%s", rec.Revision.String())) + w.Write(responseBytes) + }) +} + +func restfulGetRow(e api.Engine) func(req *restful.Request, resp *restful.Response) { + return func(req *restful.Request, resp *restful.Response) { + rev, err := api.ParseRevision(req.QueryParameter(queryParamRev)) + if err != nil { + http.Error(resp.ResponseWriter, err.Error(), http.StatusBadRequest) + return + } + var request = api.GetRecordRequest{ + RecordID: req.PathParameter(pathParamId), + TableName: req.PathParameter(pathParamTable), + Revision: rev, + } + getRow(e, request).ServeHTTP(resp.ResponseWriter, req.Request) + } +} diff --git a/pkg/server/data/handler/handler.go b/pkg/server/data/handler/handler.go new file mode 100644 index 000000000..19fbc9f20 --- /dev/null +++ b/pkg/server/data/handler/handler.go @@ -0,0 +1,118 @@ +package handler + +import ( + "fmt" + "net/http" + + "github.com/emicklei/go-restful/v3" + "github.com/nrc-no/core/pkg/server/data/api" +) + +const ( + stringDataType = "string" + intDataType = "int" + booleanDataType = "boolean" + pathParamTable = "table" + pathParamId = "id" + queryParamRev = "rev" + queryParamSince = "since" + queryParamIsReplication = "replication" +) + +// Handler is the HTTP Handler for the database API +// It is used to serve the following endpoints: +// CreateTable: PUT /api/v1/
+// GetTable: GET /api/v1/tables/{table} +// GetTables: GET /api/v1/tables +// PutRow: PUT /api/v1/tables/{table}/rows/{row} +// GetRecord: GET /api/v1/tables/{table}/rows/{row}?revision={revision} +// GetRows: GET /api/v1/tables/{table}/rows +// GetChanges: GET /api/v1/changes?since={seq} +type Handler struct { + engine api.Engine + ws *restful.WebService +} + +func (h *Handler) WebService() *restful.WebService { + return h.ws +} + +func NewHandler(engine api.Engine) *Handler { + + ws := new(restful.WebService). + Path("/apis/data.nrc.no/v1"). + Doc("data.nrc.no API") + + ws.Route(ws.PUT(fmt.Sprintf("/tables/{%s}", pathParamTable)). + Operation("PutTable"). + Doc("Creates or Updates a table"). + Reads(api.Table{}). + Writes(api.Table{}). + Consumes("application/json"). + Produces("application/json"). + Param(ws. + PathParameter(pathParamTable, "table name"). + DataType(stringDataType). + Required(true)). + To(restfulPutTable(engine)). + Returns(http.StatusOK, "OK", api.Table{})) + + ws.Route(ws.GET(fmt.Sprintf("/tables/{%s}/records/{%s}", pathParamTable, pathParamId)). + Operation("GetRecord"). + Doc("Gets a record"). + Writes(api.Record{}). + Produces("application/json"). + Param(ws. + PathParameter(pathParamTable, "table name"). + DataType(stringDataType). + Required(true)). + Param(ws. + PathParameter(pathParamId, "record id"). + DataType(stringDataType). + Required(true)). + Param(ws. + QueryParameter(queryParamRev, "revision"). + DataType(stringDataType). + Required(false)). + To(restfulGetRow(engine)). + Returns(http.StatusOK, "OK", api.Record{})) + + ws.Route(ws.PUT(fmt.Sprintf(`/tables/{%s}/records/{%s}`, pathParamTable, pathParamId)). + Operation("PutRow"). + Doc("Puts a record in a table"). + Reads(api.PutRecordRequest{}). + Writes(api.Record{}). + Consumes("application/json"). + Produces("application/json"). + Param(ws. + PathParameter(pathParamTable, "table name"). + DataType(stringDataType). + Required(true)). + Param(ws. + PathParameter(pathParamId, "row id"). + DataType(stringDataType). + Required(true)). + Param(ws. + QueryParameter(queryParamIsReplication, "is this a new record?"). + DataType(booleanDataType). + Required(false)). + To(restfulPutRow(engine)). + Returns(http.StatusOK, "OK", api.Record{})) + + ws.Route(ws.GET("/changes"). + Operation("GetChanges"). + Doc("Get changes"). + Writes(api.Changes{}). + Produces("application/json"). + Param(ws. + PathParameter(queryParamSince, "checkpoint"). + DataType(intDataType). + Required(true)). + To(restfulGetChanges(engine)). + Returns(http.StatusOK, "OK", api.Changes{})) + + return &Handler{ + engine: engine, + ws: ws, + } +} diff --git a/pkg/server/data/handler/put_row.go b/pkg/server/data/handler/put_row.go new file mode 100644 index 000000000..4e8381958 --- /dev/null +++ b/pkg/server/data/handler/put_row.go @@ -0,0 +1,46 @@ +package handler + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + + "github.com/emicklei/go-restful/v3" + "github.com/nrc-no/core/pkg/server/data/api" +) + +func putRow(e api.Engine) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var request api.PutRecordRequest + bodyBytes, err := ioutil.ReadAll(req.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := json.Unmarshal(bodyBytes, &request); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + ret, err := e.PutRecord(req.Context(), request) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + responseBytes, err := json.Marshal(ret) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + header := w.Header() + header.Set("Content-Type", "application/json") + header.Set("ETag", fmt.Sprintf("%s", ret.Revision.String())) + w.Write(responseBytes) + }) +} + +func restfulPutRow(engine api.Engine) func(req *restful.Request, resp *restful.Response) { + return func(req *restful.Request, resp *restful.Response) { + putRow(engine).ServeHTTP(resp.ResponseWriter, req.Request) + } +} diff --git a/pkg/server/data/handler/put_table.go b/pkg/server/data/handler/put_table.go new file mode 100644 index 000000000..44a4eec2f --- /dev/null +++ b/pkg/server/data/handler/put_table.go @@ -0,0 +1,43 @@ +package handler + +import ( + "encoding/json" + "io/ioutil" + "net/http" + + "github.com/emicklei/go-restful/v3" + "github.com/nrc-no/core/pkg/server/data/api" +) + +func putTable(e api.Engine) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var table api.Table + bodyBytes, err := ioutil.ReadAll(req.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := json.Unmarshal(bodyBytes, &table); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if _, err := e.CreateTable(req.Context(), table); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + responseBytes, err := json.Marshal(table) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + header := w.Header() + header.Set("Content-Type", "application/json") + w.Write(responseBytes) + }) +} + +func restfulPutTable(engine api.Engine) func(req *restful.Request, resp *restful.Response) { + return func(req *restful.Request, resp *restful.Response) { + putTable(engine).ServeHTTP(resp.ResponseWriter, req.Request) + } +} diff --git a/pkg/server/data/server.go b/pkg/server/data/server.go new file mode 100644 index 000000000..748a5006b --- /dev/null +++ b/pkg/server/data/server.go @@ -0,0 +1,69 @@ +package data + +import ( + "context" + + "github.com/jmoiron/sqlx" + "github.com/nrc-no/core/pkg/server/data/api" + "github.com/nrc-no/core/pkg/server/data/engine" + "github.com/nrc-no/core/pkg/server/data/handler" + "github.com/nrc-no/core/pkg/server/data/utils" + "github.com/nrc-no/core/pkg/server/generic" + "github.com/nrc-no/core/pkg/server/options" +) + +type Server struct { + *generic.Server + options Options +} + +type Options struct { + options.ServerOptions +} + +func NewServer(options Options) (*Server, error) { + ctx := context.Background() + + // create the generic server + genericServer, err := generic.NewGenericServer(options.ServerOptions, "data") + if err != nil { + return nil, err + } + + // create the database connection + db, err := sqlx.ConnectContext(ctx, "sqlite3", ":memory:") + if err != nil { + return nil, err + } + + // create the engine + e, err := engine.NewEngine( + ctx, + func(ctx context.Context) (api.Transaction, error) { + return utils.NewTransaction(ctx, db, nil) + }, + &utils.UUIDGenerator{}, + &utils.Md5RevGenerator{}, + &utils.Clock{}, + "sqlite", + ) + if err != nil { + return nil, err + } + + // create the handlers + h := handler.NewHandler(e) + genericServer.GoRestfulContainer.Add(h.WebService()) + + // create the server + s := &Server{ + options: options, + Server: genericServer, + } + + return s, nil +} + +func (s *Server) Start(ctx context.Context) { + s.Server.Start(ctx) +} diff --git a/pkg/server/data/server_test.go b/pkg/server/data/server_test.go new file mode 100644 index 000000000..6a80476ca --- /dev/null +++ b/pkg/server/data/server_test.go @@ -0,0 +1,148 @@ +package data + +import ( + "context" + "fmt" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/nrc-no/core/pkg/server/data/api" + "github.com/nrc-no/core/pkg/server/data/client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type Suite struct { + suite.Suite + ctx context.Context + cancel context.CancelFunc + cli client.HTTPClient +} + +func (s *Suite) SetupSuite() { + server, err := NewServer(Options{}) + if err != nil { + s.T().Fatal(err) + } + s.ctx, s.cancel = context.WithCancel(context.Background()) + server.Start(s.ctx) + + s.cli = client.NewClient(fmt.Sprintf("http://localhost:%d", server.Port())) +} + +func (s *Suite) MustCreateTable(table api.Table) api.Table { + ret, err := s.cli.CreateTable(s.ctx, table) + if !s.NoError(err) { + s.T().Fatal(err) + } + return ret +} + +func (s *Suite) MustPutRecord(record api.PutRecordRequest) api.Record { + ret, err := s.cli.PutRecord(s.ctx, record) + if !s.NoError(err) { + s.T().Fatal(err) + } + return ret +} + +func (s *Suite) MustGetRecord(request api.GetRecordRequest) api.Record { + ret, err := s.cli.GetRecord(s.ctx, request) + if !s.NoError(err) { + s.T().Fatal(err) + } + return ret +} + +func (s *Suite) MustGetChanges(request api.GetChangesRequest) api.Changes { + ret, err := s.cli.GetChanges(s.ctx, request) + if !s.NoError(err) { + s.T().Fatal(err) + } + return ret +} + +func (s *Suite) AssertHasChangesSince(since int64, count int) { + changes := s.MustGetChanges(api.GetChangesRequest{Since: since}) + assert.Equal(s.T(), count, len(changes.Items)) + s.T().Log(changes) +} + +func (s *Suite) TearDownSuite() { + s.cancel() +} + +func (s *Suite) TestServer() { + + var ( + table api.Table + createdRecord api.Record + updatedRecord api.Record + foundRecord api.Record + changes api.Changes + ) + + s.AssertHasChangesSince(0, 0) + + // Create table + table = s.MustCreateTable(api.Table{ + Name: "bla", + Columns: []api.Column{{Name: "bli", Type: "varchar"}}, + }) + + s.T().Log(table) + + // Create record + createdRecord = s.MustPutRecord(api.PutRecordRequest{ + Record: api.Record{ + ID: "1", + Table: table.Name, + Attributes: api.NewAttributes().WithString("bli", "bla"), + }, + }) + assert.Equal(s.T(), 1, createdRecord.Revision.Num) + + s.AssertHasChangesSince(0, 1) + s.T().Log(createdRecord) + + changes = s.MustGetChanges(api.GetChangesRequest{Since: 0}) + assert.Equal(s.T(), 1, len(changes.Items)) + s.T().Log(changes) + + // Get record + foundRecord = s.MustGetRecord(api.GetRecordRequest{ + RecordID: createdRecord.ID, + TableName: table.Name, + }) + + s.AssertHasChangesSince(0, 1) + s.T().Log(foundRecord) + + assert.Equal(s.T(), createdRecord, foundRecord) + + // Update the record + foundRecord.Attributes.WithString("bli", "blub") + updatedRecord = s.MustPutRecord(api.PutRecordRequest{ + Record: foundRecord, + }) + assert.Equal(s.T(), 2, updatedRecord.Revision.Num) + assert.Equal(s.T(), api.NewStringValue("blub", true), updatedRecord.Attributes["bli"]) + + s.AssertHasChangesSince(0, 2) + s.T().Log(updatedRecord) + + // Get record + foundRecord = s.MustGetRecord(api.GetRecordRequest{ + RecordID: createdRecord.ID, + TableName: table.Name, + }) + + s.T().Log(foundRecord) + + assert.Equal(s.T(), updatedRecord, foundRecord) + +} + +func TestSuite(t *testing.T) { + suite.Run(t, new(Suite)) +} diff --git a/pkg/server/data/test/assertions.go b/pkg/server/data/test/assertions.go new file mode 100644 index 000000000..6b591270a --- /dev/null +++ b/pkg/server/data/test/assertions.go @@ -0,0 +1,9 @@ +package test + +import "github.com/stretchr/testify/assert" + +var ErrorIs = func(expect error) assert.ErrorAssertionFunc { + return func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.ErrorIs(t, err, expect, i...) + } +} diff --git a/pkg/server/data/test/db_recorder.go b/pkg/server/data/test/db_recorder.go new file mode 100644 index 000000000..be2a5e3c6 --- /dev/null +++ b/pkg/server/data/test/db_recorder.go @@ -0,0 +1,50 @@ +package test + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// DBRecorder is a testing utility to record SQL statements and parameters +// that were issued by the engine +type DBRecorder struct { + statements []string + params [][]interface{} +} + +// GetStatements returns the recorded statements +func (s *DBRecorder) GetStatements() []string { + return s.statements +} + +// GetParams returns the recorded parameters +func (s *DBRecorder) GetParams() [][]interface{} { + return s.params +} + +func (s *DBRecorder) Record(stmt string, params []interface{}) { + s.statements = append(s.statements, stmt) + s.params = append(s.params, params) +} + +// Reset resets the recorder +func (s *DBRecorder) Reset() { + s.statements = []string{} + s.params = [][]interface{}{} +} + +type ExpectedStatement struct { + SQL string + Params []interface{} +} + +func (s *DBRecorder) AssertStatementsExecuted(t *testing.T, expectStatements []ExpectedStatement) { + actualStatements := s.GetStatements() + actualParams := s.GetParams() + assert.Equal(t, len(expectStatements), len(actualStatements)) + for i, statement := range expectStatements { + assert.Equal(t, statement.SQL, actualStatements[i]) + assert.Equal(t, statement.Params, actualParams[i]) + } +} diff --git a/pkg/server/data/test/db_utils.go b/pkg/server/data/test/db_utils.go new file mode 100644 index 000000000..28f7900f6 --- /dev/null +++ b/pkg/server/data/test/db_utils.go @@ -0,0 +1,44 @@ +package test + +import ( + "database/sql" + "fmt" +) + +type dbIntf interface { + Query(string, ...interface{}) (*sql.Rows, error) + Exec(string, ...interface{}) (sql.Result, error) +} + +// DropAll is a testing utility that resets a database +func DropAll(db dbIntf) error { + rows, err := db.Query(`select "name" from "sqlite_master" where "type" = 'table'`) + if err != nil { + return err + } + var dropStatements []string + defer func() { + if err := rows.Close(); err != nil { + fmt.Println("Error closing rows:", err) + } + }() + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return err + } + if name == "sqlite_sequence" { + continue + } + dropStatements = append(dropStatements, fmt.Sprintf(`DROP TABLE IF EXISTS "%s"`, name)) + } + if err := rows.Err(); err != nil { + return err + } + for _, statement := range dropStatements { + if _, err := db.Exec(statement); err != nil { + return err + } + } + return nil +} diff --git a/pkg/server/data/test/mock_clock.go b/pkg/server/data/test/mock_clock.go new file mode 100644 index 000000000..189e8bfad --- /dev/null +++ b/pkg/server/data/test/mock_clock.go @@ -0,0 +1,32 @@ +package test + +import ( + "time" + + "github.com/nrc-no/core/pkg/server/data/api" +) + +// MockClock is a mock clock that can be used for testing. +type MockClock struct { + clock api.Clock + TheTime int64 +} + +func (c *MockClock) Now() int64 { + if c.clock != nil { + return c.clock.Now() + } + return c.TheTime +} + +func (c *MockClock) Tick(d time.Duration) { + c.TheTime += int64(d) +} + +func (c *MockClock) UseClock(clock api.Clock) func() { + oldClock := c.clock + c.clock = clock + return func() { + c.clock = oldClock + } +} diff --git a/pkg/server/data/test/mock_uuid.go b/pkg/server/data/test/mock_uuid.go new file mode 100644 index 000000000..37ea76473 --- /dev/null +++ b/pkg/server/data/test/mock_uuid.go @@ -0,0 +1,10 @@ +package test + +// MockUUIDGenerator is a mock implementation of UUIDGenerator. +type MockUUIDGenerator struct { + ReturnUUID string +} + +func (g *MockUUIDGenerator) Generate() (string, error) { + return g.ReturnUUID, nil +} diff --git a/pkg/server/data/utils/clock.go b/pkg/server/data/utils/clock.go new file mode 100644 index 000000000..cfeeaf642 --- /dev/null +++ b/pkg/server/data/utils/clock.go @@ -0,0 +1,9 @@ +package utils + +import "time" + +type Clock struct{} + +func (c *Clock) Now() int64 { + return time.Now().Unix() +} diff --git a/pkg/server/data/utils/db_result_reader.go b/pkg/server/data/utils/db_result_reader.go new file mode 100644 index 000000000..c0e55790e --- /dev/null +++ b/pkg/server/data/utils/db_result_reader.go @@ -0,0 +1,83 @@ +package utils + +import ( + "database/sql" + "fmt" + + "github.com/nrc-no/core/pkg/server/data/api" +) + +// SQLResultReader is a wrapper around sql.Rows that implements the ResultReader api. +// It takes care of deserializing the rows into a map of column name to value. +type SQLResultReader struct { + rows *sql.Rows + columns []string + values []api.Value + pointers []interface{} +} + +// Err returns the error, if any, that was encountered during iteration. +func (r SQLResultReader) Err() error { + return r.rows.Err() +} + +// Close closes the Rows, preventing further enumeration. If Next is called +func (r SQLResultReader) Close() error { + return r.rows.Close() +} + +// Next prepares the next result row for reading. It returns true if there is +func (r SQLResultReader) Next() bool { + return r.rows.Next() +} + +// Read reads the next result row +func (r SQLResultReader) Read(columnKinds []api.ValueKind) (map[string]api.Value, error) { + var values []api.Value + var pointers []interface{} + for _, kind := range columnKinds { + switch kind { + case api.ValueKindString: + values = append(values, api.Value{Kind: kind, String: &api.String{}}) + pointers = append(pointers, values[len(values)-1].String) + case api.ValueKindInt: + values = append(values, api.Value{Kind: kind, Int: &api.Int{}}) + pointers = append(pointers, values[len(values)-1].Int) + case api.ValueKindFloat: + values = append(values, api.Value{Kind: kind, Float: &api.Float{}}) + pointers = append(pointers, values[len(values)-1].Float) + case api.ValueKindBool: + values = append(values, api.Value{Kind: kind, Bool: &api.Bool{}}) + pointers = append(pointers, values[len(values)-1].Bool) + default: + return nil, fmt.Errorf("unsupported value kind: %v", kind) + } + } + if err := r.rows.Scan(pointers...); err != nil { + return nil, err + } + result := make(map[string]api.Value) + for i, column := range r.columns { + result[column] = values[i] + } + return result, nil +} + +// NewSQLResultReader creates a new SQLResultReader from the given sql.Rows. +func NewSQLResultReader(rows *sql.Rows) (*SQLResultReader, error) { + columns, err := rows.Columns() + if err != nil { + return nil, err + } + values := make([]api.Value, len(columns)) + pointers := make([]interface{}, len(columns)) + for i := range values { + pointers[i] = &values[i] + } + return &SQLResultReader{ + rows: rows, + columns: columns, + values: values, + pointers: pointers, + }, nil +} diff --git a/pkg/server/data/utils/md5_generator.go b/pkg/server/data/utils/md5_generator.go new file mode 100644 index 000000000..1f3167fa5 --- /dev/null +++ b/pkg/server/data/utils/md5_generator.go @@ -0,0 +1,32 @@ +package utils + +import ( + "crypto/md5" + "encoding/hex" + "fmt" + "sort" + "strconv" + + "github.com/nrc-no/core/pkg/server/data/api" +) + +// Md5RevGenerator implements the RevisionGenerator api. +// It creates a md5 hash out of a given data set with a number prefix +// It produces a valid Revision string +type Md5RevGenerator struct { +} + +// Generate implements the RevisionGenerator.Generate +func (r Md5RevGenerator) Generate(num int, data map[string]interface{}) api.Revision { + h := md5.New() + var sortedFields []string + for key := range data { + sortedFields = append(sortedFields, key) + } + sort.Strings(sortedFields) + for i, sortedField := range sortedFields { + fieldValue, _ := data[sortedField] + h.Write([]byte(sortedField + ":" + strconv.Itoa(i) + ":" + fmt.Sprintf("%v", fieldValue))) + } + return api.NewRevision(num, hex.EncodeToString(h.Sum(nil))) +} diff --git a/pkg/server/data/utils/transaction.go b/pkg/server/data/utils/transaction.go new file mode 100644 index 000000000..4f2cfb7dc --- /dev/null +++ b/pkg/server/data/utils/transaction.go @@ -0,0 +1,56 @@ +package utils + +import ( + "context" + + "github.com/jmoiron/sqlx" + "github.com/nrc-no/core/pkg/server/data/api" +) + +type Transaction struct { + tx *sqlx.Tx + onQuery func(qry string, args []interface{}) +} + +func (t Transaction) Query(ctx context.Context, query string, args []interface{}) (api.ResultReader, error) { + if t.onQuery != nil { + t.onQuery(query, args) + } + res, err := t.tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + r, err := NewSQLResultReader(res) + if err != nil { + return nil, err + } + return r, nil +} + +func (t Transaction) Exec(ctx context.Context, query string, args []interface{}) (interface{}, error) { + if t.onQuery != nil { + t.onQuery(query, args) + } + return t.tx.ExecContext(ctx, query, args...) +} + +func (t Transaction) Commit() error { + return t.tx.Commit() +} + +func (t Transaction) Rollback() error { + return t.tx.Rollback() +} + +var _ api.Transaction = &Transaction{} + +func NewTransaction(ctx context.Context, db *sqlx.DB, onQuery func(qry string, args []interface{})) (api.Transaction, error) { + tx, err := db.BeginTxx(ctx, nil) + if err != nil { + return nil, err + } + return &Transaction{ + tx: tx, + onQuery: onQuery, + }, nil +} diff --git a/pkg/server/data/utils/uuid_generator.go b/pkg/server/data/utils/uuid_generator.go new file mode 100644 index 000000000..2f1777359 --- /dev/null +++ b/pkg/server/data/utils/uuid_generator.go @@ -0,0 +1,49 @@ +package utils + +import ( + "github.com/nrc-no/core/pkg/server/data/api" +) + +// UUIDGenerator implements the UUIDGenerator api. +// it is a zero-dependency struct that can only generate V4 uuids +type UUIDGenerator struct { + rand api.Rand +} + +// Generate implements UUIDGenerator.Generate +func (g *UUIDGenerator) Generate() (string, error) { + var u [16]byte + if _, err := g.rand.Read(u[:]); err != nil { + return "", err + } + u[6] = (u[6] & 0x0f) | (4 << 4) + u[8] = u[8]&(0xff>>2) | (0x02 << 6) + buf := make([]byte, 36) + encodeHex(buf[0:8], u[0:4]) + buf[8] = '-' + encodeHex(buf[9:13], u[4:6]) + buf[13] = '-' + encodeHex(buf[14:18], u[6:8]) + buf[18] = '-' + encodeHex(buf[19:23], u[8:10]) + buf[23] = '-' + encodeHex(buf[24:], u[10:]) + return string(buf), nil +} + +const hexTable = "0123456789abcdef" + +// EncodeHex encodes a byte array to hexadecimal string +// dst is the destination buffer +// src is the source buffer +// it returns the number of bytes encoded in dst +// it is a zero-dependent version of hex.Encode +func encodeHex(dst, src []byte) int { + j := 0 + for _, v := range src { + dst[j] = hexTable[v>>4] + dst[j+1] = hexTable[v&0x0f] + j += 2 + } + return len(src) * 2 +} diff --git a/pkg/server/generic/server.go b/pkg/server/generic/server.go index 49b10e828..c8bea26a5 100644 --- a/pkg/server/generic/server.go +++ b/pkg/server/generic/server.go @@ -4,6 +4,13 @@ import ( "context" "errors" "fmt" + "net" + "net/http" + "reflect" + "strconv" + "strings" + "time" + "github.com/boj/redistore" restfulspec "github.com/emicklei/go-restful-openapi/v2" "github.com/emicklei/go-restful/v3" @@ -18,11 +25,6 @@ import ( "github.com/rs/cors" "go.uber.org/zap" "gopkg.in/matryer/try.v1" - "net" - "net/http" - "reflect" - "strings" - "time" ) type Server struct { @@ -248,6 +250,12 @@ func (g Server) Address() string { return g.address } +func (g Server) Port() int { + parts := strings.Split(g.address, ":") + port, _ := strconv.Atoi(parts[len(parts)-1]) + return port +} + func (g Server) Start(ctx context.Context) { ctx = logging.WithServerName(ctx, g.name) diff --git a/pkg/server/options/options.go b/pkg/server/options/options.go index 13007a75e..282419b55 100644 --- a/pkg/server/options/options.go +++ b/pkg/server/options/options.go @@ -56,6 +56,7 @@ type ServeOptions struct { AuthnzApi ServerOptions `mapstructure:"authnz_api"` Login ServerOptions `mapstructure:"login"` AuthnzBouncer ServerOptions `mapstructure:"authnz_bouncer"` + Data ServerOptions `mapstructure:"data"` } type CertOptions struct {