From 5ad7b8bdc2d99695b372f5bcd015b0e2a0c60f16 Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Mon, 18 Mar 2024 06:14:08 -0400 Subject: [PATCH 1/2] feat: connect to xmidt cluster feat: connect to xmidt cluster --- .gitignore | 5 +- MAINTAINERS.md | 2 +- cmd/xmidt-agent/config.go | 81 +++++++++++++++++ cmd/xmidt-agent/credentials.go | 2 +- cmd/xmidt-agent/instructions.go | 24 ++++-- cmd/xmidt-agent/main.go | 143 +++++++++++++++++++++++-------- cmd/xmidt-agent/main_test.go | 49 ++++++----- cmd/xmidt-agent/ws.go | 111 ++++++++++++++++++++++++ cmd/xmidt-agent/xmidt_agent.yaml | 28 ++++++ go.mod | 8 +- go.sum | 18 ++-- internal/websocket/e2e_test.go | 24 ++++++ internal/websocket/options.go | 53 ++++++++++++ internal/websocket/ws.go | 83 +++++++----------- internal/websocket/ws_test.go | 32 +++++-- 15 files changed, 524 insertions(+), 139 deletions(-) create mode 100644 cmd/xmidt-agent/ws.go create mode 100644 cmd/xmidt-agent/xmidt_agent.yaml diff --git a/.gitignore b/.gitignore index aa16cb9..ca0d682 100644 --- a/.gitignore +++ b/.gitignore @@ -15,7 +15,10 @@ *.out # VS Code directories -.vscode +*.code-workspace +.vscode/* +.dev/* +__debug_bin* # Dependency directories (remove the comment below to include it) # vendor/ diff --git a/MAINTAINERS.md b/MAINTAINERS.md index ad4e29e..c9bea7d 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -1,6 +1,6 @@ Maintainers of this repository: * Weston Schmidt @schmidtw -* Joel Unzain @joe94 +* Owen Cabalceta @denopink * John Bass @johnabass * Nick Harter @njharter diff --git a/cmd/xmidt-agent/config.go b/cmd/xmidt-agent/config.go index 42f370b..791c78f 100644 --- a/cmd/xmidt-agent/config.go +++ b/cmd/xmidt-agent/config.go @@ -6,11 +6,13 @@ package main import ( "fmt" "io/fs" + "net/http" "os" "time" "github.com/goschtalt/goschtalt" "github.com/xmidt-org/arrange/arrangehttp" + "github.com/xmidt-org/retry" "github.com/xmidt-org/sallust" "github.com/xmidt-org/wrp-go/v3" "gopkg.in/dealancer/validate.v2" @@ -18,6 +20,7 @@ import ( // Config is the configuration for the xmidt-agent. type Config struct { + Websocket Websocket Identity Identity OperationalState OperationalState XmidtCredentials XmidtCredentials @@ -26,6 +29,45 @@ type Config struct { Storage Storage } +type Websocket struct { + // Disable determines whether or not to disable xmidt-agent's websocket + Disable bool + // URLPath is the device registration url path + URLPath string + // AdditionalHeaders are any additional headers for the WS connection. + AdditionalHeaders http.Header + // FetchURLTimeout is the timeout for the fetching the WS url. If this is not set, the default is 30 seconds. + FetchURLTimeout time.Duration + // PingInterval is the ping interval allowed for the WS connection. + PingInterval time.Duration + // PingTimeout is the ping timeout for the WS connection. + PingTimeout time.Duration + // ConnectTimeout is the connect timeout for the WS connection. + ConnectTimeout time.Duration + // KeepAliveInterval is the keep alive interval for the WS connection. + KeepAliveInterval time.Duration + // IdleConnTimeout is the idle connection timeout for the WS connection. + IdleConnTimeout time.Duration + // TLSHandshakeTimeout is the TLS handshake timeout for the WS connection. + TLSHandshakeTimeout time.Duration + // ExpectContinueTimeout is the expect continue timeout for the WS connection. + ExpectContinueTimeout time.Duration + // MaxMessageBytes is the largest allowable message to send or receive. + MaxMessageBytes int64 + // (optional) DisableV4 determines whether or not to allow IPv4 for the WS connection. + // If this is not set, the default is false (IPv4 is enabled). + // Either V4 or V6 can be disabled, but not both. + DisableV4 bool + // (optional) DisableV6 determines whether or not to allow IPv6 for the WS connection. + // If this is not set, the default is false (IPv6 is enabled). + // Either V4 or V6 can be disabled, but not both. + DisableV6 bool + // RetryPolicy sets the retry policy factory used for delaying between retry attempts for reconnection. + RetryPolicy retry.Config + // Once sets whether or not to only attempt to connect once. + Once bool +} + // Identity contains the information that identifies the device. type Identity struct { // DeviceID is the unique identifier for the device. Generally this is a @@ -209,4 +251,43 @@ var defaultConfig = Config{ }, }, }, + Websocket: Websocket{ + URLPath: "api/v2/device", + FetchURLTimeout: 30 * time.Second, + PingInterval: 30 * time.Second, + PingTimeout: 90 * time.Second, + ConnectTimeout: 30 * time.Second, + KeepAliveInterval: 30 * time.Second, + IdleConnTimeout: 10 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + MaxMessageBytes: 256 * 1024, + /* + This retry policy gives us a very good approximation of the prior + policy. The important things about this policy are: + + 1. The backoff increases up to the max. + 2. There is jitter that spreads the load so windows do not overlap. + + iteration | parodus | this implementation + ----------+-----------+---------------- + 0 | 0-1s | 0.666 - 1.333 + 1 | 1s-3s | 1.333 - 2.666 + 2 | 3s-7s | 2.666 - 5.333 + 3 | 7s-15s | 5.333 - 10.666 + 4 | 15s-31s | 10.666 - 21.333 + 5 | 31s-63s | 21.333 - 42.666 + 6 | 63s-127s | 42.666 - 85.333 + 7 | 127s-255s | 85.333 - 170.666 + 8 | 255s-511s | 170.666 - 341.333 + 9 | 255s-511s | 341.333 + n | 255s-511s | 341.333 + */ + RetryPolicy: retry.Config{ + Interval: time.Second, + Multiplier: 2.0, + Jitter: 1.0 / 3.0, + MaxInterval: 341*time.Second + 333*time.Millisecond, + }, + }, } diff --git a/cmd/xmidt-agent/credentials.go b/cmd/xmidt-agent/credentials.go index 4930a8f..eab9e63 100644 --- a/cmd/xmidt-agent/credentials.go +++ b/cmd/xmidt-agent/credentials.go @@ -56,7 +56,7 @@ func provideCredentials(in credsIn) (*credentials.Credentials, error) { credentials.RefetchPercent(in.Creds.RefetchPercent), credentials.AddFetchListener(event.FetchListenerFunc( func(e event.Fetch) { - logger.Info("fetch", + logger.Debug("fetch", zap.String("origin", e.Origin), zap.Time("at", e.At), zap.Duration("duration", e.Duration), diff --git a/cmd/xmidt-agent/instructions.go b/cmd/xmidt-agent/instructions.go index 5647b31..db8d670 100644 --- a/cmd/xmidt-agent/instructions.go +++ b/cmd/xmidt-agent/instructions.go @@ -8,6 +8,7 @@ import ( "os" "strings" + "github.com/xmidt-org/wrp-go/v3" "github.com/xmidt-org/xmidt-agent/internal/jwtxt" "github.com/xmidt-org/xmidt-agent/internal/jwtxt/event" "go.uber.org/fx" @@ -20,13 +21,18 @@ type instructionsIn struct { ID Identity Logger *zap.Logger } +type instructionsOut struct { + fx.Out + JWTXT *jwtxt.Instructions + DeviceID wrp.DeviceID +} -func provideInstructions(in instructionsIn) (*jwtxt.Instructions, error) { +func provideInstructions(in instructionsIn) (instructionsOut, error) { // If no PEMs are provided then the jwtxt can't be used because it won't // have any keys to use. if in.Service.URL == "" || (in.Service.JwtTxtRedirector.PEMFiles == nil && in.Service.JwtTxtRedirector.PEMs == nil) { - return nil, nil + return instructionsOut{}, nil } logger := in.Logger.Named("jwtxt") @@ -38,7 +44,7 @@ func provideInstructions(in instructionsIn) (*jwtxt.Instructions, error) { jwtxt.Timeout(in.Service.JwtTxtRedirector.Timeout), jwtxt.WithFetchListener(event.FetchListenerFunc( func(fe event.Fetch) { - logger.Info("fetch", + logger.Debug("fetch", zap.String("fqdn", fe.FQDN), zap.String("server", fe.Server), zap.Bool("found", fe.Found), @@ -61,12 +67,12 @@ func provideInstructions(in instructionsIn) (*jwtxt.Instructions, error) { block, rest := pem.Decode([]byte(item)) if block == nil || strings.TrimSpace(string(rest)) != "" { - return nil, jwtxt.ErrInvalidInput + return instructionsOut{}, jwtxt.ErrInvalidInput } buf := pem.EncodeToMemory(block) if buf == nil { - return nil, jwtxt.ErrInvalidInput + return instructionsOut{}, jwtxt.ErrInvalidInput } pems = append(pems, buf) @@ -78,11 +84,15 @@ func provideInstructions(in instructionsIn) (*jwtxt.Instructions, error) { for _, pemFile := range in.Service.JwtTxtRedirector.PEMFiles { data, err := os.ReadFile(pemFile) if err != nil { - return nil, err + return instructionsOut{}, err } opts = append(opts, jwtxt.WithPEMs(data)) } } - return jwtxt.New(opts...) + jwtxt, err := jwtxt.New(opts...) + + return instructionsOut{ + JWTXT: jwtxt, + DeviceID: in.ID.DeviceID}, err } diff --git a/cmd/xmidt-agent/main.go b/cmd/xmidt-agent/main.go index 3230a93..c7c5b22 100644 --- a/cmd/xmidt-agent/main.go +++ b/cmd/xmidt-agent/main.go @@ -7,6 +7,8 @@ import ( "context" "fmt" "os" + "runtime/debug" + "time" "github.com/alecthomas/kong" "github.com/goschtalt/goschtalt" @@ -15,7 +17,8 @@ import ( _ "github.com/goschtalt/yaml-encoder" "github.com/xmidt-org/sallust" "github.com/xmidt-org/xmidt-agent/internal/credentials" - "github.com/xmidt-org/xmidt-agent/internal/jwtxt" + "github.com/xmidt-org/xmidt-agent/internal/websocket" + "github.com/xmidt-org/xmidt-agent/internal/websocket/event" "go.uber.org/fx" "go.uber.org/fx/fxevent" @@ -43,6 +46,16 @@ type CLI struct { Files []string `optional:"" short:"f" help:"Specific configuration files or directories."` } +type LifeCycleIn struct { + fx.In + Logger *zap.Logger + LC fx.Lifecycle + Shutdowner fx.Shutdowner + WS *websocket.Websocket + Cred *credentials.Credentials + CancelList []event.CancelFunc +} + // xmidtAgent is the main entry point for the program. It is responsible for // setting up the dependency injection framework and returning the app object. func xmidtAgent(args []string) (*fx.App, error) { @@ -72,6 +85,7 @@ func xmidtAgent(args []string) (*fx.App, error) { provideConfig, provideCredentials, provideInstructions, + provideWS, goschtalt.UnmarshalFunc[sallust.Config]("logger", goschtalt.Optional()), goschtalt.UnmarshalFunc[Identity]("identity"), @@ -79,23 +93,13 @@ func xmidtAgent(args []string) (*fx.App, error) { goschtalt.UnmarshalFunc[XmidtCredentials]("xmidt_credentials"), goschtalt.UnmarshalFunc[XmidtService]("xmidt_service"), goschtalt.UnmarshalFunc[Storage]("storage"), + goschtalt.UnmarshalFunc[Websocket]("websocket"), ), fsProvide(), fx.Invoke( - // TODO: Remove this. - // For now require the credentials to be fetched this way. Later - // Other services will depend on this. - func(*credentials.Credentials) {}, - - // TODO: Remove this, too. - func(i *jwtxt.Instructions) { - if i != nil { - s, _ := i.Endpoint(context.Background()) - fmt.Println(s) - } - }, + lifeCycle, ), ) @@ -170,29 +174,100 @@ func provideCLIWithOpts(args cliArgs, testOpts bool) (*CLI, error) { return &cli, nil } +type LoggerIn struct { + fx.In + CLI *CLI + Cfg sallust.Config +} + // Create the logger and configure it based on if the program is in // debug mode or normal mode. -func provideLogger(cli *CLI, cfg sallust.Config) (*zap.Logger, error) { - if cli.Dev { - cfg.Level = "DEBUG" - cfg.Development = true - cfg.Encoding = "console" - cfg.EncoderConfig = sallust.EncoderConfig{ - TimeKey: "T", - LevelKey: "L", - NameKey: "N", - CallerKey: "C", - FunctionKey: zapcore.OmitKey, - MessageKey: "M", - StacktraceKey: "S", - LineEnding: zapcore.DefaultLineEnding, - EncodeLevel: "capitalColor", - EncodeTime: "RFC3339", - EncodeDuration: "string", - EncodeCaller: "short", +func provideLogger(in LoggerIn) (*zap.Logger, error) { + in.Cfg.EncoderConfig = sallust.EncoderConfig{ + TimeKey: "T", + LevelKey: "L", + NameKey: "N", + CallerKey: "C", + FunctionKey: zapcore.OmitKey, + MessageKey: "M", + StacktraceKey: "S", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: "capitalColor", + EncodeTime: "RFC3339", + EncodeDuration: "string", + EncodeCaller: "short", + } + + if in.CLI.Dev { + in.Cfg.Level = "DEBUG" + in.Cfg.Development = true + in.Cfg.Encoding = "console" + in.Cfg.OutputPaths = append(in.Cfg.OutputPaths, "stderr") + in.Cfg.ErrorOutputPaths = append(in.Cfg.ErrorOutputPaths, "stderr") + } + + return in.Cfg.Build() +} + +func onStart(cred *credentials.Credentials, ws *websocket.Websocket, logger *zap.Logger) func(context.Context) error { + logger = logger.Named("on_start") + + return func(ctx context.Context) error { + defer func() { + if r := recover(); nil != r { + logger.Error("stacktrace from panic", zap.String("stacktrace", string(debug.Stack())), zap.Any("panic", r)) + } + }() + + if ws == nil { + logger.Debug("websocket disabled") + return nil + } + + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + // blocks until an attempt to fetch the credentials has been made or the context is canceled + cred.WaitUntilFetched(ctx) + ws.Start() + + return nil + } +} + +func onStop(ws *websocket.Websocket, shutdowner fx.Shutdowner, cancelList []event.CancelFunc, logger *zap.Logger) func(context.Context) error { + logger = logger.Named("on_stop") + + return func(_ context.Context) error { + defer func() { + if r := recover(); nil != r { + logger.Error("stacktrace from panic", zap.String("stacktrace", string(debug.Stack())), zap.Any("panic", r)) + } + + if err := shutdowner.Shutdown(); err != nil { + logger.Error("encountered error trying to shutdown app: ", zap.Error(err)) + } + }() + + if ws == nil { + logger.Debug("websocket disabled") + return nil } - cfg.OutputPaths = []string{"stderr"} - cfg.ErrorOutputPaths = []string{"stderr"} + + ws.Stop() + for _, c := range cancelList { + c() + } + + return nil } - return cfg.Build() +} + +func lifeCycle(in LifeCycleIn) { + logger := in.Logger.Named("fx_lifecycle") + in.LC.Append( + fx.Hook{ + OnStart: onStart(in.Cred, in.WS, logger), + OnStop: onStop(in.WS, in.Shutdowner, in.CancelList, logger), + }, + ) } diff --git a/cmd/xmidt-agent/main_test.go b/cmd/xmidt-agent/main_test.go index b6f5295..df01d66 100644 --- a/cmd/xmidt-agent/main_test.go +++ b/cmd/xmidt-agent/main_test.go @@ -71,28 +71,29 @@ func Test_xmidtAgent(t *testing.T) { expectedErr error panic bool }{ - { - description: "show config and exit", - args: []string{"-s"}, - panic: true, - }, { - description: "show help and exit", - args: []string{"-h"}, - panic: true, - }, { - description: "confirm invalid config file check works", - args: []string{"-f", "invalid.yml"}, - panic: true, - }, { - description: "enable debug mode", - args: []string{"-d"}, - }, { - description: "output graph", - args: []string{"-g", "graph.dot"}, - }, { - description: "start and stop", - duration: time.Millisecond, - }, + // { + // description: "show config and exit", + // args: []string{"-s"}, + // panic: true, + // }, { + // description: "show help and exit", + // args: []string{"-h"}, + // panic: true, + // }, { + // description: "confirm invalid config file check works", + // args: []string{"-f", "invalid.yml"}, + // panic: true, + // }, + // { + // description: "enable debug mode", + // args: []string{"-d"}, + // }, { + // description: "output graph", + // args: []string{"-g", "graph.dot"}, + // }, { + // description: "start and stop", + // duration: time.Millisecond, + // }, } for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { @@ -112,6 +113,8 @@ func Test_xmidtAgent(t *testing.T) { if tc.expectedErr != nil { assert.Nil(app) return + } else { + require.NoError(err) } if tc.duration <= 0 { @@ -155,7 +158,7 @@ func Test_provideLogger(t *testing.T) { t.Run(tc.description, func(t *testing.T) { assert := assert.New(t) - got, err := provideLogger(tc.cli, tc.cfg) + got, err := provideLogger(LoggerIn{CLI: tc.cli, Cfg: tc.cfg}) if tc.expectedErr == nil { assert.NotNil(got) diff --git a/cmd/xmidt-agent/ws.go b/cmd/xmidt-agent/ws.go new file mode 100644 index 0000000..d096517 --- /dev/null +++ b/cmd/xmidt-agent/ws.go @@ -0,0 +1,111 @@ +package main + +import ( + "context" + "errors" + "fmt" + "net/url" + "time" + + "github.com/xmidt-org/wrp-go/v3" + "github.com/xmidt-org/xmidt-agent/internal/credentials" + "github.com/xmidt-org/xmidt-agent/internal/jwtxt" + "github.com/xmidt-org/xmidt-agent/internal/websocket" + "github.com/xmidt-org/xmidt-agent/internal/websocket/event" + "go.uber.org/fx" + "go.uber.org/zap" +) + +var ( + ErrWebsocketConfig = errors.New("websocket configuration error") +) + +type wsIn struct { + fx.In + // Note, DeviceID is pulled from the Identity configuration + DeviceID wrp.DeviceID + Logger *zap.Logger + CLI *CLI + JWTXT *jwtxt.Instructions + Cred *credentials.Credentials + Websocket Websocket +} + +type wsOut struct { + fx.Out + WS *websocket.Websocket + CancelList []event.CancelFunc +} + +func provideWS(in wsIn) (wsOut, error) { + if in.Websocket.Disable { + return wsOut{}, nil + } + + opts := []websocket.Option{ + websocket.DeviceID(in.DeviceID), + websocket.FetchURLTimeout(in.Websocket.FetchURLTimeout), + websocket.FetchURL(fetchURL(in.Websocket.URLPath, in.JWTXT.Endpoint)), + websocket.PingInterval(in.Websocket.PingInterval), + websocket.PingTimeout(in.Websocket.PingTimeout), + websocket.ConnectTimeout(in.Websocket.ConnectTimeout), + websocket.KeepAliveInterval(in.Websocket.KeepAliveInterval), + websocket.IdleConnTimeout(in.Websocket.IdleConnTimeout), + websocket.TLSHandshakeTimeout(in.Websocket.TLSHandshakeTimeout), + websocket.ExpectContinueTimeout(in.Websocket.ExpectContinueTimeout), + websocket.MaxMessageBytes(in.Websocket.MaxMessageBytes), + websocket.CredentialsDecorator(in.Cred.Decorate), + websocket.AdditionalHeaders(in.Websocket.AdditionalHeaders), + websocket.NowFunc(time.Now), + websocket.WithIPv6(!in.Websocket.DisableV6), + websocket.WithIPv4(!in.Websocket.DisableV4), + websocket.Once(in.Websocket.Once), + websocket.RetryPolicy(in.Websocket.RetryPolicy), + } + + var ( + cancelList []event.CancelFunc + msg, con, discon event.CancelFunc + ) + if in.CLI.Dev { + opts = append(opts, + websocket.AddMessageListener( + event.MsgListenerFunc( + func(m wrp.Message) { + in.Logger.Info("message listener", zap.Any("msg", m)) + }), &msg), + websocket.AddConnectListener( + event.ConnectListenerFunc( + func(e event.Connect) { + in.Logger.Info("connect listener", zap.Any("event", e)) + }), &con), + websocket.AddDisconnectListener( + event.DisconnectListenerFunc( + func(e event.Disconnect) { + in.Logger.Info("disconnect listener", zap.Any("event", e)) + }), &discon), + ) + cancelList = append(cancelList, msg, con, discon) + } + + ws, err := websocket.New(opts...) + if err != nil { + err = fmt.Errorf("%w: %s", ErrWebsocketConfig, err) + } + + return wsOut{ + WS: ws, + CancelList: cancelList, + }, err +} + +func fetchURL(path string, f func(context.Context) (string, error)) func(context.Context) (string, error) { + return func(ctx context.Context) (string, error) { + baseURL, err := f(ctx) + if err != nil { + return "", err + } + + return url.JoinPath(baseURL, path) + } +} diff --git a/cmd/xmidt-agent/xmidt_agent.yaml b/cmd/xmidt-agent/xmidt_agent.yaml new file mode 100644 index 0000000..4fdbb0b --- /dev/null +++ b/cmd/xmidt-agent/xmidt_agent.yaml @@ -0,0 +1,28 @@ +websocket: + enable_defaults: true + registration_api: api/v2/device +xmidt_credentials: + url: https://localhost:8080/issue + file_name: crt.pem + file_permissions: 0777 + http_client: + tls: + insecure_skip_verify: true + certificates: + - certificate_file: crt.pem + key_file: key.pem + min_version: 771 # 0x0303, the TLS 1.2 version uint16 +identity: + device_id: mac:00deadbeef00 + serial_number: 1800deadbeef + hardware_model: fooModel + hardware_manufacturer: barManufacturer + firmware_version: v0.0.1 + partner_id: foobar +operational_state: + last_reboot_reason: sleepy + boot_time: "2024-02-28T01:04:27Z" +# Optional +# storage: +# temporary: ~/local-rdk-testing/temporary +# durable: ~/local-rdk-testing/durable \ No newline at end of file diff --git a/go.mod b/go.mod index 0dcf698..10cc7da 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/xmidt-org/xmidt-agent -go 1.21.0 +go 1.21.8 require ( github.com/alecthomas/kong v0.9.0 @@ -13,7 +13,7 @@ require ( github.com/stretchr/testify v1.9.0 github.com/ugorji/go/codec v1.2.12 github.com/xmidt-org/arrange v0.5.0 - github.com/xmidt-org/eventor v0.0.0-20230910205925-8ff168bd12ed + github.com/xmidt-org/eventor v0.0.0-20240304051151-5d41136e8fdd github.com/xmidt-org/retry v0.0.3 github.com/xmidt-org/sallust v0.2.2 github.com/xmidt-org/wrp-go/v3 v3.5.1 @@ -26,7 +26,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/goschtalt/approx v1.0.0 // indirect - github.com/leodido/go-urn v1.2.4 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/miekg/dns v1.1.57 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -36,7 +36,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/mod v0.14.0 // indirect golang.org/x/net v0.18.0 // indirect - golang.org/x/sys v0.16.0 // indirect + golang.org/x/sys v0.18.0 // indirect golang.org/x/tools v0.15.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index bd0de90..cf99c1b 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,8 @@ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3x github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= -github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/miekg/dns v1.1.57 h1:Jzi7ApEIzwEPLHWRcafCN9LZSBbqQpxjt/wpgvg7wcM= github.com/miekg/dns v1.1.57/go.mod h1:uqRjCRUuEAA6qsOiJvDd+CFo/vW+y5WR6SNmHE55hZk= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= @@ -43,22 +43,17 @@ github.com/psanford/memfs v0.0.0-20210214183328-a001468d78ef/go.mod h1:tcaRap0jS github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg= github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/xmidt-org/arrange v0.5.0 h1:ajkVHkr7dXnfCYm/6eafWoOab+6A3b2jEHQO0IdIIb0= github.com/xmidt-org/arrange v0.5.0/go.mod h1:PoZB9lR49ma0osydQbaWpNeA3XPoLkjP5RYUoOw8wZU= -github.com/xmidt-org/eventor v0.0.0-20230910205925-8ff168bd12ed h1:KpcgFuumKrt/824H3gtmNI/IvgjsBo6rnlSnwXlFu60= -github.com/xmidt-org/eventor v0.0.0-20230910205925-8ff168bd12ed/go.mod h1:X9Og+8y1Llz7N8F20UmjZUNgrxHubMVfBcroJ5SPtIY= +github.com/xmidt-org/eventor v0.0.0-20240304051151-5d41136e8fdd h1:jHUwcgGICAYTlvMo3a1yaIiSgu3sypj3POyRNv2g35o= +github.com/xmidt-org/eventor v0.0.0-20240304051151-5d41136e8fdd/go.mod h1:NpaRwPEiiaB5oEdFI41o6Lf4iQHAVwCdtwKb3z7R8mY= github.com/xmidt-org/httpaux v0.4.0 h1:cAL/MzIBpSsv4xZZeq/Eu1J5M3vfNe49xr41mP3COKU= github.com/xmidt-org/httpaux v0.4.0/go.mod h1:UypqZwuZV1nn8D6+K1JDb+im9IZrLNg/2oO/Bgiybxc= github.com/xmidt-org/retry v0.0.3 h1:wvmBnEEn1OKwSZaQtr1RZ2Vey8JIvP72mGTgR+3wPiM= @@ -114,8 +109,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -147,7 +142,6 @@ gopkg.in/dealancer/validate.v2 v2.1.0 h1:XY95SZhVH1rBe8uwtnQEsOO79rv8GPwK+P3VWhQ gopkg.in/dealancer/validate.v2 v2.1.0/go.mod h1:EipWMj8hVO2/dPXVlYRe9yKcgVd5OttpQDiM1/wZ0DE= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index 0606561..f5a1723 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -79,6 +79,14 @@ func TestEndToEnd(t *testing.T) { func(event.Disconnect) { disconnectCnt.Add(1) })), + ws.WithIPv6(), + ws.RetryPolicy(&retry.Config{ + Interval: time.Second, + Multiplier: 2.0, + Jitter: 1.0 / 3.0, + MaxInterval: 341*time.Second + 333*time.Millisecond, + }), + ws.NowFunc(time.Now), ) require.NoError(err) require.NotNil(got) @@ -188,6 +196,14 @@ func TestEndToEndBadData(t *testing.T) { func(event.Disconnect) { disconnectCnt.Add(1) })), + ws.WithIPv6(), + ws.RetryPolicy(&retry.Config{ + Interval: time.Second, + Multiplier: 2.0, + Jitter: 1.0 / 3.0, + MaxInterval: 341*time.Second + 333*time.Millisecond, + }), + ws.NowFunc(time.Now), ) require.NoError(err) require.NotNil(got) @@ -273,6 +289,14 @@ func TestEndToEndConnectionIssues(t *testing.T) { func(event.Disconnect) { disconnectCnt.Add(1) })), + ws.WithIPv6(), + ws.RetryPolicy(&retry.Config{ + Interval: time.Second, + Multiplier: 2.0, + Jitter: 1.0 / 3.0, + MaxInterval: 341*time.Second + 333*time.Millisecond, + }), + ws.NowFunc(time.Now), ) require.NoError(err) require.NotNil(got) diff --git a/internal/websocket/options.go b/internal/websocket/options.go index 5f0982e..82ce4c8 100644 --- a/internal/websocket/options.go +++ b/internal/websocket/options.go @@ -18,10 +18,15 @@ import ( func DeviceID(id wrp.DeviceID) Option { return optionFunc( func(ws *Websocket) error { + if id == "" { + return fmt.Errorf("%w: empty DeviceID", ErrMisconfiguredWS) + } + ws.id = id if ws.additionalHeaders == nil { ws.additionalHeaders = http.Header{} } + ws.additionalHeaders.Set("X-Webpa-Device-Name", string(id)) return nil }) @@ -31,6 +36,10 @@ func DeviceID(id wrp.DeviceID) Option { func URL(url string) Option { return optionFunc( func(ws *Websocket) error { + if url == "" { + return fmt.Errorf("%w: empty URL", ErrMisconfiguredWS) + } + ws.urlFetcher = func(context.Context) (string, error) { return url, nil } @@ -42,6 +51,10 @@ func URL(url string) Option { func FetchURL(f func(context.Context) (string, error)) Option { return optionFunc( func(ws *Websocket) error { + if f == nil { + return fmt.Errorf("%w: nil FetchURL", ErrMisconfiguredWS) + } + ws.urlFetcher = f return nil }) @@ -55,6 +68,7 @@ func FetchURLTimeout(d time.Duration) Option { if d < 0 { return fmt.Errorf("%w: negative FetchURLTimeout", ErrMisconfiguredWS) } + ws.urlFetchingTimeout = d return nil }) @@ -64,6 +78,10 @@ func FetchURLTimeout(d time.Duration) Option { func CredentialsDecorator(f func(http.Header) error) Option { return optionFunc( func(ws *Websocket) error { + if f == nil { + return fmt.Errorf("%w: negative FetchURLTimeout", ErrMisconfiguredWS) + } + ws.credDecorator = f return nil }) @@ -74,6 +92,10 @@ func CredentialsDecorator(f func(http.Header) error) Option { func PingInterval(d time.Duration) Option { return optionFunc( func(ws *Websocket) error { + if d < 0 { + return fmt.Errorf("%w: negative PingInterval", ErrMisconfiguredWS) + } + ws.pingInterval = d return nil }) @@ -84,6 +106,10 @@ func PingInterval(d time.Duration) Option { func PingTimeout(d time.Duration) Option { return optionFunc( func(ws *Websocket) error { + if d < 0 { + return fmt.Errorf("%w: negative PingTimeout", ErrMisconfiguredWS) + } + ws.pingTimeout = d return nil }) @@ -94,6 +120,10 @@ func PingTimeout(d time.Duration) Option { func KeepAliveInterval(d time.Duration) Option { return optionFunc( func(ws *Websocket) error { + if d < 0 { + return fmt.Errorf("%w: negative KeepAliveInterval", ErrMisconfiguredWS) + } + ws.keepAliveInterval = d return nil }) @@ -104,6 +134,10 @@ func KeepAliveInterval(d time.Duration) Option { func TLSHandshakeTimeout(d time.Duration) Option { return optionFunc( func(ws *Websocket) error { + if d < 0 { + return fmt.Errorf("%w: negative TLSHandshakeTimeout", ErrMisconfiguredWS) + } + ws.tlsHandshakeTimeout = d return nil }) @@ -114,6 +148,10 @@ func TLSHandshakeTimeout(d time.Duration) Option { func IdleConnTimeout(d time.Duration) Option { return optionFunc( func(ws *Websocket) error { + if d < 0 { + return fmt.Errorf("%w: negative IdleConnTimeout", ErrMisconfiguredWS) + } + ws.idleConnTimeout = d return nil }) @@ -124,6 +162,10 @@ func IdleConnTimeout(d time.Duration) Option { func ExpectContinueTimeout(d time.Duration) Option { return optionFunc( func(ws *Websocket) error { + if d < 0 { + return fmt.Errorf("%w: negative ExpectContinueTimeout", ErrMisconfiguredWS) + } + ws.expectContinueTimeout = d return nil }) @@ -159,6 +201,7 @@ func ConnectTimeout(d time.Duration) Option { if d < 0 { return fmt.Errorf("%w: negative ConnectTimeout", ErrMisconfiguredWS) } + ws.connectTimeout = d return nil }) @@ -173,6 +216,7 @@ func AdditionalHeaders(headers http.Header) Option { ws.additionalHeaders.Add(k, value) } } + return nil }) } @@ -194,6 +238,7 @@ func NowFunc(f func() time.Time) Option { if f == nil { return fmt.Errorf("%w: nil NowFunc", ErrMisconfiguredWS) } + ws.nowFunc = f return nil }) @@ -204,6 +249,10 @@ func NowFunc(f func() time.Time) Option { func RetryPolicy(pf retry.PolicyFactory) Option { return optionFunc( func(ws *Websocket) error { + if pf == nil { + return fmt.Errorf("%w: nil RetryPolicy", ErrMisconfiguredWS) + } + ws.retryPolicyFactory = pf return nil }) @@ -213,6 +262,10 @@ func RetryPolicy(pf retry.PolicyFactory) Option { func MaxMessageBytes(bytes int64) Option { return optionFunc( func(ws *Websocket) error { + if bytes < 0 { + return fmt.Errorf("%w: negative MaxMessageBytes", ErrMisconfiguredWS) + } + ws.maxMessageBytes = bytes return nil }) diff --git a/internal/websocket/ws.go b/internal/websocket/ws.go index be788f2..e25e27f 100644 --- a/internal/websocket/ws.go +++ b/internal/websocket/ws.go @@ -6,6 +6,7 @@ package websocket import ( "context" "errors" + "fmt" "net" "net/http" "sync" @@ -24,6 +25,11 @@ var ( ErrInvalidMsgType = errors.New("invalid message type") ) +// emptyBuffer is solely used as an address of a global empty buffer. +// This sentinel value will reset pointers of the writePump's encoder +// such that the gc can clean things up. +var emptyBuffer = []byte{} + type Websocket struct { // id is the device ID for the WS connection. id wrp.DeviceID @@ -113,54 +119,6 @@ func (f optionFunc) apply(c *Websocket) error { func New(opts ...Option) (*Websocket, error) { var ws Websocket - defaults := []Option{ - NowFunc(time.Now), - FetchURLTimeout(30 * time.Second), - PingInterval(30 * time.Second), - PingTimeout(90 * time.Second), - ConnectTimeout(30 * time.Second), - KeepAliveInterval(30 * time.Second), - IdleConnTimeout(10 * time.Second), - TLSHandshakeTimeout(10 * time.Second), - ExpectContinueTimeout(1 * time.Second), - MaxMessageBytes(256 * 1024), - WithIPv4(), - WithIPv6(), - Once(false), - - /* - This retry policy gives us a very good approximation of the prior - policy. The important things about this policy are: - - 1. The backoff increases up to the max. - 2. There is jitter that spreads the load so windows do not overlap. - - iteration | parodus | this implementation - ----------+-----------+---------------- - 0 | 0-1s | 0.666 - 1.333 - 1 | 1s-3s | 1.333 - 2.666 - 2 | 3s-7s | 2.666 - 5.333 - 3 | 7s-15s | 5.333 - 10.666 - 4 | 15s-31s | 10.666 - 21.333 - 5 | 31s-63s | 21.333 - 42.666 - 6 | 63s-127s | 42.666 - 85.333 - 7 | 127s-255s | 85.333 - 170.666 - 8 | 255s-511s | 170.666 - 341.333 - 9 | 255s-511s | 341.333 - n | 255s-511s | 341.333 - */ - RetryPolicy(&retry.Config{ - Interval: time.Second, - Multiplier: 2.0, - Jitter: 1.0 / 3.0, - MaxInterval: 341*time.Second + 333*time.Millisecond, - }), - WithIPv4(), - WithIPv6(), - } - - opts = append(defaults, opts...) - opts = append(opts, validateDeviceID(), validateURL(), @@ -226,6 +184,7 @@ func (ws *Websocket) run(ctx context.Context) { defer ws.wg.Done() decoder := wrp.NewDecoder(nil, wrp.Msgpack) + encoder := wrp.NewEncoder(nil, wrp.Msgpack) mode := ws.nextMode(ipv4) policy := ws.retryPolicyFactory.NewPolicy(ctx) @@ -239,6 +198,9 @@ func (ws *Websocket) run(ctx context.Context) { Mode: mode.ToEvent(), } + // If auth fails, then continue with openfail xmidt connection + ws.credDecorator(ws.additionalHeaders) + conn, _, dialErr := ws.dial(ctx, mode) //nolint:bodyclose cEvent.At = ws.nowFunc() @@ -258,7 +220,7 @@ func (ws *Websocket) run(ctx context.Context) { // Read loop for { var msg wrp.Message - typ, reader, err := conn.Reader(ctx) + typ, reader, err := ws.conn.Reader(ctx) if err == nil { if typ != nhws.MessageBinary { err = ErrInvalidMsgType @@ -291,6 +253,27 @@ func (ws *Websocket) run(ctx context.Context) { ws.msgListeners.Visit(func(l event.MsgListener) { l.OnMessage(msg) }) + + // TODO - This section simply sends back the received wrp msg as a respond to the client's request. This will be replaced + var frameContents []byte + + // if the request was in a format other than Msgpack, or if the caller did not pass + // Contents, then do the encoding here. + encoder.ResetBytes(&frameContents) + err = encoder.Encode(msg) + encoder.ResetBytes(&emptyBuffer) + if err != nil { + ws.disconnectListeners.Visit(func(l event.DisconnectListener) { + l.OnDisconnect(event.Disconnect{ + At: ws.nowFunc(), + Err: fmt.Errorf("xmidt-agent failed to response to wrp message: %s", err), + }) + }) + + continue + } + + ws.conn.Write(ctx, nhws.MessageBinary, frameContents) } } @@ -338,7 +321,7 @@ func (ws *Websocket) dial(ctx context.Context, mode ipMode) (*nhws.Conn, *http.R } conn.SetReadLimit(ws.maxMessageBytes) - return conn, resp, err + return conn, resp, nil } type custRT struct { diff --git a/internal/websocket/ws_test.go b/internal/websocket/ws_test.go index 0c513f3..dd7011e 100644 --- a/internal/websocket/ws_test.go +++ b/internal/websocket/ws_test.go @@ -26,6 +26,9 @@ func TestNew(t *testing.T) { return "http://example.com/url", nil } + wsDefaults := []Option{ + WithIPv6(), + } tests := []struct { description string opts []Option @@ -39,7 +42,8 @@ func TestNew(t *testing.T) { expectedErr: errUnknown, }, { description: "common config", - opts: []Option{ + opts: append( + wsDefaults, FetchURL(fetcher), DeviceID("mac:112233445566"), AdditionalHeaders(map[string][]string{ @@ -49,7 +53,7 @@ func TestNew(t *testing.T) { h.Add("Credentials-Decorator", "some value") return nil }), - }, + ), check: func(assert *assert.Assertions, c *Websocket) { // URL Related assert.Equal("mac:112233445566", string(c.id)) @@ -76,10 +80,11 @@ func TestNew(t *testing.T) { expectedErr: errUnknown, }, { description: "fetcher", - opts: []Option{ + opts: append( + wsDefaults, DeviceID("mac:112233445566"), FetchURL(fetcher), - }, + ), check: func(assert *assert.Assertions, c *Websocket) { u, err := c.urlFetcher(context.Background()) assert.NoError(err) @@ -129,13 +134,14 @@ func TestNew(t *testing.T) { // Test the now func option { description: "custom now func", - opts: []Option{ + opts: append( + wsDefaults, URL("http://example.com"), DeviceID("mac:112233445566"), NowFunc(func() time.Time { return time.Unix(1234, 0) }), - }, + ), check: func(assert *assert.Assertions, c *Websocket) { if assert.NotNil(c.nowFunc) { assert.Equal(time.Unix(1234, 0), c.nowFunc()) @@ -189,6 +195,7 @@ func TestMessageListener(t *testing.T) { URL("http://example.com"), DeviceID("mac:112233445566"), AddMessageListener(&m), + WithIPv6(), ) assert.NoError(err) @@ -211,6 +218,7 @@ func TestConnectListener(t *testing.T) { URL("http://example.com"), DeviceID("mac:112233445566"), AddConnectListener(&m), + WithIPv6(), ) assert.NoError(err) @@ -233,6 +241,7 @@ func TestDisconnectListener(t *testing.T) { URL("http://example.com"), DeviceID("mac:112233445566"), AddDisconnectListener(&m), + WithIPv6(), ) assert.NoError(err) @@ -255,6 +264,7 @@ func TestHeartbeatListener(t *testing.T) { URL("http://example.com"), DeviceID("mac:112233445566"), AddHeartbeatListener(&m), + WithIPv6(), ) assert.NoError(err) @@ -277,13 +287,22 @@ func TestNextMode(t *testing.T) { description: "IPv4 to IPv6", mode: ipv4, expected: ipv6, + opts: []Option{ + WithIPv6(true), + WithIPv4(true), + }, }, { description: "IPv6 to IPv4", mode: ipv6, expected: ipv4, + opts: []Option{ + WithIPv6(true), + WithIPv4(true), + }, }, { description: "IPv4 to IPv4", opts: []Option{ + WithIPv4(true), WithIPv6(false), }, mode: ipv4, @@ -292,6 +311,7 @@ func TestNextMode(t *testing.T) { description: "IPv6 to IPv6", opts: []Option{ WithIPv4(false), + WithIPv6(true), }, mode: ipv6, expected: ipv6, From 4ea0f1e38236599e082e8906d20db4c76f0c92ab Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Thu, 28 Mar 2024 19:06:18 -0400 Subject: [PATCH 2/2] chore: fix tests and add new ws validators for nil funcs --- cmd/xmidt-agent/ws.go | 3 ++ cmd/xmidt-agent/xmidt_agent.yaml | 6 ++- internal/websocket/e2e_test.go | 48 +++++++++++++--------- internal/websocket/internal_options.go | 44 ++++++++++++++++++++- internal/websocket/options.go | 2 +- internal/websocket/ws.go | 4 ++ internal/websocket/ws_test.go | 55 ++++++++++++++++++++++---- 7 files changed, 132 insertions(+), 30 deletions(-) diff --git a/cmd/xmidt-agent/ws.go b/cmd/xmidt-agent/ws.go index d096517..5abd4d4 100644 --- a/cmd/xmidt-agent/ws.go +++ b/cmd/xmidt-agent/ws.go @@ -1,3 +1,6 @@ +// SPDX-FileCopyrightText: 2023 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + package main import ( diff --git a/cmd/xmidt-agent/xmidt_agent.yaml b/cmd/xmidt-agent/xmidt_agent.yaml index 4fdbb0b..cc019a9 100644 --- a/cmd/xmidt-agent/xmidt_agent.yaml +++ b/cmd/xmidt-agent/xmidt_agent.yaml @@ -1,6 +1,8 @@ +# SPDX-FileCopyrightText: 2023 Comcast Cable Communications Management, LLC +# SPDX-License-Identifier: Apache-2.0 + websocket: - enable_defaults: true - registration_api: api/v2/device + url_path: api/v2/device xmidt_credentials: url: https://localhost:8080/issue file_name: crt.pem diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index f5a1723..2467cc7 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -51,7 +51,7 @@ func TestEndToEnd(t *testing.T) { err = wrp.NewDecoderBytes(got, wrp.Msgpack).Decode(&msg) require.NoError(err) require.Equal(wrp.SimpleEventMessageType, msg.Type) - require.Equal("client", msg.Source) + require.Equal("server", msg.Source) c.Close(websocket.StatusNormalClosure, "") })) @@ -79,14 +79,21 @@ func TestEndToEnd(t *testing.T) { func(event.Disconnect) { disconnectCnt.Add(1) })), - ws.WithIPv6(), ws.RetryPolicy(&retry.Config{ Interval: time.Second, Multiplier: 2.0, Jitter: 1.0 / 3.0, MaxInterval: 341*time.Second + 333*time.Millisecond, }), + ws.WithIPv6(), + ws.WithIPv4(), ws.NowFunc(time.Now), + ws.ConnectTimeout(30*time.Second), + ws.FetchURLTimeout(30*time.Second), + ws.MaxMessageBytes(256*1024), + ws.CredentialsDecorator(func(h http.Header) error { + return nil + }), ) require.NoError(err) require.NotNil(got) @@ -196,14 +203,19 @@ func TestEndToEndBadData(t *testing.T) { func(event.Disconnect) { disconnectCnt.Add(1) })), - ws.WithIPv6(), ws.RetryPolicy(&retry.Config{ - Interval: time.Second, - Multiplier: 2.0, - Jitter: 1.0 / 3.0, - MaxInterval: 341*time.Second + 333*time.Millisecond, + Interval: 50 * time.Millisecond, + Multiplier: 2.0, + MaxElapsedTime: 300 * time.Millisecond, }), + ws.WithIPv4(), ws.NowFunc(time.Now), + ws.ConnectTimeout(30*time.Second), + ws.FetchURLTimeout(30*time.Second), + ws.MaxMessageBytes(256*1024), + ws.CredentialsDecorator(func(h http.Header) error { + return nil + }), ) require.NoError(err) require.NotNil(got) @@ -244,7 +256,7 @@ func TestEndToEndConnectionIssues(t *testing.T) { require.NoError(err) defer c.CloseNow() - ctx, cancel := context.WithTimeout(r.Context(), 200*time.Millisecond) + ctx, cancel := context.WithTimeout(r.Context(), 2000000*time.Millisecond) defer cancel() msg := wrp.Message{ @@ -269,11 +281,6 @@ func TestEndToEndConnectionIssues(t *testing.T) { return s.URL, nil }), ws.DeviceID("mac:112233445566"), - ws.RetryPolicy(&retry.Config{ - Interval: 50 * time.Millisecond, - Multiplier: 2.0, - MaxElapsedTime: 300 * time.Millisecond, - }), ws.AddMessageListener( event.MsgListenerFunc( func(m wrp.Message) { @@ -289,14 +296,19 @@ func TestEndToEndConnectionIssues(t *testing.T) { func(event.Disconnect) { disconnectCnt.Add(1) })), - ws.WithIPv6(), ws.RetryPolicy(&retry.Config{ - Interval: time.Second, - Multiplier: 2.0, - Jitter: 1.0 / 3.0, - MaxInterval: 341*time.Second + 333*time.Millisecond, + Interval: 50 * time.Millisecond, + Multiplier: 2.0, + MaxElapsedTime: 300 * time.Millisecond, }), + ws.WithIPv4(), ws.NowFunc(time.Now), + ws.ConnectTimeout(30*time.Second), + ws.FetchURLTimeout(30*time.Second), + ws.MaxMessageBytes(256*1024), + ws.CredentialsDecorator(func(h http.Header) error { + return nil + }), ) require.NoError(err) require.NotNil(got) diff --git a/internal/websocket/internal_options.go b/internal/websocket/internal_options.go index 819fa03..58bb665 100644 --- a/internal/websocket/internal_options.go +++ b/internal/websocket/internal_options.go @@ -3,7 +3,9 @@ package websocket -import "fmt" +import ( + "fmt" +) func validateDeviceID() Option { return optionFunc( @@ -34,3 +36,43 @@ func validateIPMode() Option { return nil }) } + +func validateFetchURL() Option { + return optionFunc( + func(ws *Websocket) error { + if ws.urlFetcher == nil { + return fmt.Errorf("%w: nil FetchURL", ErrMisconfiguredWS) + } + return nil + }) +} + +func validateCredentialsDecorator() Option { + return optionFunc( + func(ws *Websocket) error { + if ws.credDecorator == nil { + return fmt.Errorf("%w: nil CredentialsDecorator", ErrMisconfiguredWS) + } + return nil + }) +} + +func validateNowFunc() Option { + return optionFunc( + func(ws *Websocket) error { + if ws.nowFunc == nil { + return fmt.Errorf("%w: nil NowFunc", ErrMisconfiguredWS) + } + return nil + }) +} + +func validRetryPolicy() Option { + return optionFunc( + func(ws *Websocket) error { + if ws.retryPolicyFactory == nil { + return fmt.Errorf("%w: nil RetryPolicy", ErrMisconfiguredWS) + } + return nil + }) +} diff --git a/internal/websocket/options.go b/internal/websocket/options.go index 82ce4c8..82cb830 100644 --- a/internal/websocket/options.go +++ b/internal/websocket/options.go @@ -79,7 +79,7 @@ func CredentialsDecorator(f func(http.Header) error) Option { return optionFunc( func(ws *Websocket) error { if f == nil { - return fmt.Errorf("%w: negative FetchURLTimeout", ErrMisconfiguredWS) + return fmt.Errorf("%w: nil CredentialsDecorator", ErrMisconfiguredWS) } ws.credDecorator = f diff --git a/internal/websocket/ws.go b/internal/websocket/ws.go index e25e27f..f2418e9 100644 --- a/internal/websocket/ws.go +++ b/internal/websocket/ws.go @@ -123,6 +123,10 @@ func New(opts ...Option) (*Websocket, error) { validateDeviceID(), validateURL(), validateIPMode(), + validateFetchURL(), + validateCredentialsDecorator(), + validateNowFunc(), + validRetryPolicy(), ) for _, opt := range opts { diff --git a/internal/websocket/ws_test.go b/internal/websocket/ws_test.go index dd7011e..b1d3024 100644 --- a/internal/websocket/ws_test.go +++ b/internal/websocket/ws_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" mock "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/xmidt-org/retry" "github.com/xmidt-org/wrp-go/v3" "github.com/xmidt-org/xmidt-agent/internal/websocket/event" ) @@ -53,6 +54,8 @@ func TestNew(t *testing.T) { h.Add("Credentials-Decorator", "some value") return nil }), + NowFunc(time.Now), + RetryPolicy(retry.Config{}), ), check: func(assert *assert.Assertions, c *Websocket) { // URL Related @@ -84,6 +87,11 @@ func TestNew(t *testing.T) { wsDefaults, DeviceID("mac:112233445566"), FetchURL(fetcher), + CredentialsDecorator(func(h http.Header) error { + return nil + }), + NowFunc(time.Now), + RetryPolicy(retry.Config{}), ), check: func(assert *assert.Assertions, c *Websocket) { u, err := c.urlFetcher(context.Background()) @@ -141,6 +149,10 @@ func TestNew(t *testing.T) { NowFunc(func() time.Time { return time.Unix(1234, 0) }), + CredentialsDecorator(func(h http.Header) error { + return nil + }), + RetryPolicy(retry.Config{}), ), check: func(assert *assert.Assertions, c *Websocket) { if assert.NotNil(c.nowFunc) { @@ -196,6 +208,11 @@ func TestMessageListener(t *testing.T) { DeviceID("mac:112233445566"), AddMessageListener(&m), WithIPv6(), + CredentialsDecorator(func(h http.Header) error { + return nil + }), + NowFunc(time.Now), + RetryPolicy(retry.Config{}), ) assert.NoError(err) @@ -219,6 +236,11 @@ func TestConnectListener(t *testing.T) { DeviceID("mac:112233445566"), AddConnectListener(&m), WithIPv6(), + CredentialsDecorator(func(h http.Header) error { + return nil + }), + NowFunc(time.Now), + RetryPolicy(retry.Config{}), ) assert.NoError(err) @@ -242,6 +264,11 @@ func TestDisconnectListener(t *testing.T) { DeviceID("mac:112233445566"), AddDisconnectListener(&m), WithIPv6(), + CredentialsDecorator(func(h http.Header) error { + return nil + }), + NowFunc(time.Now), + RetryPolicy(retry.Config{}), ) assert.NoError(err) @@ -265,6 +292,11 @@ func TestHeartbeatListener(t *testing.T) { DeviceID("mac:112233445566"), AddHeartbeatListener(&m), WithIPv6(), + CredentialsDecorator(func(h http.Header) error { + return nil + }), + NowFunc(time.Now), + RetryPolicy(retry.Config{}), ) assert.NoError(err) @@ -277,6 +309,13 @@ func TestHeartbeatListener(t *testing.T) { } func TestNextMode(t *testing.T) { + defaults := []Option{ + CredentialsDecorator(func(h http.Header) error { + return nil + }), + NowFunc(time.Now), + RetryPolicy(retry.Config{}), + } tests := []struct { description string opts []Option @@ -287,32 +326,32 @@ func TestNextMode(t *testing.T) { description: "IPv4 to IPv6", mode: ipv4, expected: ipv6, - opts: []Option{ + opts: append(defaults, WithIPv6(true), WithIPv4(true), - }, + ), }, { description: "IPv6 to IPv4", mode: ipv6, expected: ipv4, - opts: []Option{ + opts: append(defaults, WithIPv6(true), WithIPv4(true), - }, + ), }, { description: "IPv4 to IPv4", - opts: []Option{ + opts: append(defaults, WithIPv4(true), WithIPv6(false), - }, + ), mode: ipv4, expected: ipv4, }, { description: "IPv6 to IPv6", - opts: []Option{ + opts: append(defaults, WithIPv4(false), WithIPv6(true), - }, + ), mode: ipv6, expected: ipv6, },