diff --git a/sdk/component/conf.go b/sdk/component/conf.go index 912bad072..c479ff7b5 100644 --- a/sdk/component/conf.go +++ b/sdk/component/conf.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/go-errors/errors" + "github.com/smithy-security/pkg/env" "github.com/smithy-security/smithy/sdk" "github.com/smithy-security/smithy/sdk/component/internal/uuid" @@ -191,12 +192,12 @@ func newRunnerConfig() (*RunnerConfig, error) { return nil, errors.Errorf("could not construct panic handler: %w", err) } - componentName, err := fromEnvOrDefault(envVarKeyComponentName, "", withFallbackToDefaultOnError(true)) + componentName, err := env.GetOrDefault(envVarKeyComponentName, "", env.WithDefaultOnError(true)) if err != nil { return nil, errors.Errorf("could not lookup environment for '%s': %w", envVarKeyComponentName, err) } - instanceIDStr, err := fromEnvOrDefault(envVarKeyInstanceID, "", withFallbackToDefaultOnError(true)) + instanceIDStr, err := env.GetOrDefault(envVarKeyInstanceID, "", env.WithDefaultOnError(true)) if err != nil { return nil, errors.Errorf("could not lookup environment for '%s': %w", envVarKeyInstanceID, err) } @@ -208,10 +209,10 @@ func newRunnerConfig() (*RunnerConfig, error) { // --- END - BASIC ENV - END --- // --- BEGIN - LOGGING ENV - BEGIN --- - logLevel, err := fromEnvOrDefault( + logLevel, err := env.GetOrDefault( envVarKeyLoggingLogLevel, RunnerConfigLoggingLevelDebug.String(), - withFallbackToDefaultOnError(true), + env.WithDefaultOnError(true), ) if err != nil { return nil, errors.Errorf("could not lookup environment for '%s': %w", envVarKeyLoggingLogLevel, err) @@ -224,7 +225,7 @@ func newRunnerConfig() (*RunnerConfig, error) { // --- END - LOGGING ENV - END --- // --- BEGIN - STORER ENV - BEGIN --- - st, err := fromEnvOrDefault(envVarKeyBackendStoreType, "", withFallbackToDefaultOnError(true)) + st, err := env.GetOrDefault(envVarKeyBackendStoreType, "", env.WithDefaultOnError(true)) if err != nil { return nil, errors.Errorf("could not lookup environment for '%s': %w", envVarKeyBackendStoreType, err) } @@ -248,10 +249,10 @@ func newRunnerConfig() (*RunnerConfig, error) { conf.storerConfig.storeType = storageType - dbDSN, err := fromEnvOrDefault( + dbDSN, err := env.GetOrDefault( envVarKeyBackendStoreDSN, "smithy.db", - withFallbackToDefaultOnError(true), + env.WithDefaultOnError(true), ) if err != nil { return nil, errors.Errorf("could not lookup environment for '%s': %w", envVarKeyBackendStoreDSN, err) diff --git a/sdk/component/env.go b/sdk/component/env.go deleted file mode 100644 index c8fad30d2..000000000 --- a/sdk/component/env.go +++ /dev/null @@ -1,127 +0,0 @@ -package component - -import ( - "net/url" - "os" - "strconv" - "time" - - "github.com/go-errors/errors" - - "github.com/smithy-security/smithy/sdk/component/internal/uuid" -) - -type ( - // parseableEnvTypes represents the types the parser is capable of handling. - // TODO: extend with slices if needed. - parseableEnvTypes interface { - string | bool | int | uint | int64 | uint64 | float64 | time.Duration | time.Time | url.URL | uuid.UUID - } - - // envLoader is an alias for a function that loads values from the env. It mirrors the signature of os.Getenv. - envLoader func(key string) string - - envParseOpts struct { - envLoader envLoader - defaultOnError bool - timeLayout string - sensitive bool - } - - // envParseOption is a means to customize parse options via variadic parameters. - envParseOption func(o *envParseOpts) error -) - -var ( - defaultEnvParseOptions = envParseOpts{ - envLoader: os.Getenv, - defaultOnError: false, - timeLayout: time.RFC3339, - } -) - -// withEnvLoader allows overriding how env vars are loaded. -// -// Primarily used for testing. -func withEnvLoader(loader envLoader) envParseOption { - return func(o *envParseOpts) error { - if loader == nil { - return errors.New("env loader function cannot be nil") - } - - o.envLoader = loader - return nil - } -} - -// withFallbackToDefaultOnError informs the parser that if an error is encountered during parsing, it should fallback to the default value. -func withFallbackToDefaultOnError(fallback bool) envParseOption { - return func(o *envParseOpts) error { - o.defaultOnError = fallback - return nil - } -} - -// fromEnvOrDefault attempts to parse the environment variable provided. If it is empty or missing, the default value is used. -// -// If an error is encountered, depending on whether the `withFallbackToDefaultOnError` option is provided it will either -// fallback or return the error back to the client. -func fromEnvOrDefault[T parseableEnvTypes](envVar string, defaultVal T, opts ...envParseOption) (dest T, err error) { - parseOpts := &defaultEnvParseOptions - for _, opt := range opts { - if err := opt(parseOpts); err != nil { - return dest, errors.Errorf("option error: %w", err) - } - } - - envStr := parseOpts.envLoader(envVar) - if envStr == "" { - if !parseOpts.defaultOnError { - return dest, errors.Errorf("required env variable '%s' not found", envVar) - } - return defaultVal, nil - } - - var v any - - switch any(dest).(type) { - case string: - v = envStr - case bool: - v, err = strconv.ParseBool(envStr) - case int: - v, err = strconv.Atoi(envStr) - case uint: - var i uint64 - i, err = strconv.ParseUint(envStr, 10, 64) - v = uint(i) - case int64: - v, err = strconv.ParseInt(envStr, 10, 64) - case uint64: - v, err = strconv.ParseUint(envStr, 10, 64) - case float64: - v, err = strconv.ParseFloat(envStr, 64) - case time.Duration: - v, err = time.ParseDuration(envStr) - case time.Time: - v, err = time.Parse(parseOpts.timeLayout, envStr) - case url.URL: - v, err = url.Parse(envStr) - case uuid.UUID: - v, err = uuid.Parse(envStr) - } - if err != nil { - if parseOpts.defaultOnError { - return defaultVal, nil - } - - return dest, errors.Errorf("failed to parse env %s to %T: %v", envVar, dest, err) - } - - dest, ok := v.(T) - if !ok { - return dest, errors.Errorf("failed to cast env %s to %T", envVar, dest) - } - - return dest, nil -} diff --git a/sdk/component/env_test.go b/sdk/component/env_test.go deleted file mode 100644 index 5d93f92c8..000000000 --- a/sdk/component/env_test.go +++ /dev/null @@ -1,366 +0,0 @@ -package component_test - -import ( - "fmt" - "math/rand" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/smithy-security/smithy/sdk/component" - "github.com/smithy-security/smithy/sdk/component/internal/uuid" -) - -func TestParsesParseable(t *testing.T) { - var makeLoader = func(envs map[string]string) component.EnvLoader { - return func(key string) string { - return envs[key] - } - } - - t.Run("string", func(t *testing.T) { - const defaultVal = "default" - var ( - loader = makeLoader(map[string]string{"KNOWN_STRING": "a string"}) - cases = []struct { - searchEnv string - expected string - fallBackOnErr bool - }{ - {searchEnv: "KNOWN_STRING", expected: "a string", fallBackOnErr: true}, - {searchEnv: "UNKNOWN_ENV", expected: defaultVal, fallBackOnErr: true}, - } - ) - for _, tt := range cases { - t.Run(fmt.Sprintf("with env var %s", tt.searchEnv), func(t *testing.T) { - ret, err := component.FromEnvOrDefault( - tt.searchEnv, - defaultVal, - component.WithEnvLoader(loader), - component.WithFallbackToDefaultOnError(tt.fallBackOnErr), - ) - if err != nil { - require.NoError(t, err) - } - require.Equal(t, tt.expected, ret) - }) - } - }) - - t.Run("bool", func(t *testing.T) { - const defaultVal = false - var ( - loader = makeLoader(map[string]string{"KNOWN_BOOL": "true", "NOT_BOOL": "abcd"}) - cases = []struct { - searchEnv string - expected bool - expectedErrContains string - fallBackOnErr bool - }{ - {searchEnv: "KNOWN_BOOL", expected: true, fallBackOnErr: true}, - {searchEnv: "UNKNOWN_ENV", expected: defaultVal, fallBackOnErr: true}, - {searchEnv: "NOT_BOOL", expected: false, expectedErrContains: "invalid syntax", fallBackOnErr: true}, - } - ) - for _, tt := range cases { - t.Run(fmt.Sprintf("with env var %s", tt.searchEnv), func(t *testing.T) { - ret, err := component.FromEnvOrDefault( - tt.searchEnv, - defaultVal, - component.WithEnvLoader(loader), - component.WithFallbackToDefaultOnError(tt.fallBackOnErr), - ) - if err != nil && tt.expectedErrContains != "" { - require.True(t, strings.Contains(err.Error(), tt.expectedErrContains)) - return - } - - require.NoError(t, err) - require.Equal(t, tt.expected, ret) - }) - } - }) - - t.Run("int", func(t *testing.T) { - var defaultVal = rand.Int() - var ( - loader = makeLoader(map[string]string{"KNOWN_INT": "123", "NOT_INT": "abcd"}) - cases = []struct { - searchEnv string - expected int - expectedErrContains string - fallBackOnErr bool - }{ - {searchEnv: "KNOWN_INT", expected: 123, fallBackOnErr: true}, - {searchEnv: "UNKNOWN_ENV", expected: defaultVal, fallBackOnErr: true}, - {searchEnv: "NOT_INT", expectedErrContains: "invalid syntax", fallBackOnErr: false}, - } - ) - for _, tt := range cases { - t.Run(fmt.Sprintf("with env var %s", tt.searchEnv), func(t *testing.T) { - ret, err := component.FromEnvOrDefault( - tt.searchEnv, - defaultVal, - component.WithEnvLoader(loader), - component.WithFallbackToDefaultOnError(tt.fallBackOnErr), - ) - if err != nil && tt.expectedErrContains != "" { - require.True(t, strings.Contains(err.Error(), tt.expectedErrContains)) - return - } - - require.NoError(t, err) - require.Equal(t, tt.expected, ret) - }) - } - }) - - t.Run("uint", func(t *testing.T) { - const defaultVal = uint(555) - var ( - loader = makeLoader(map[string]string{"KNOWN_UINT": "123", "NOT_UINT": "abcd"}) - cases = []struct { - searchEnv string - expected uint - expectedErrContains string - fallBackOnErr bool - }{ - {searchEnv: "KNOWN_UINT", expected: 123, fallBackOnErr: true}, - {searchEnv: "UNKNOWN_ENV", expected: defaultVal, fallBackOnErr: true}, - {searchEnv: "NOT_UINT", expectedErrContains: "invalid syntax", fallBackOnErr: false}, - } - ) - for _, tt := range cases { - t.Run(fmt.Sprintf("with env var %s", tt.searchEnv), func(t *testing.T) { - ret, err := component.FromEnvOrDefault( - tt.searchEnv, - defaultVal, - component.WithEnvLoader(loader), - component.WithFallbackToDefaultOnError(tt.fallBackOnErr), - ) - if err != nil && tt.expectedErrContains != "" { - require.True(t, strings.Contains(err.Error(), tt.expectedErrContains)) - return - } - - require.NoError(t, err) - require.Equal(t, tt.expected, ret) - }) - } - }) - - t.Run("int64", func(t *testing.T) { - var ( - defaultVal = rand.Int63() - loader = makeLoader(map[string]string{"KNOWN_INT": "8675309", "NOT_INT": "abcd"}) - cases = []struct { - searchEnv string - expected int64 - expectedErrContains string - fallBackOnErr bool - }{ - {searchEnv: "KNOWN_INT", expected: 8675309, fallBackOnErr: true}, - {searchEnv: "UNKNOWN_ENV", expected: defaultVal, fallBackOnErr: true}, - {searchEnv: "NOT_INT", expectedErrContains: "invalid syntax", fallBackOnErr: false}, - } - ) - for _, tt := range cases { - t.Run(fmt.Sprintf("with env var %s", tt.searchEnv), func(t *testing.T) { - ret, err := component.FromEnvOrDefault( - tt.searchEnv, - defaultVal, - component.WithEnvLoader(loader), - component.WithFallbackToDefaultOnError(tt.fallBackOnErr), - ) - if err != nil && tt.expectedErrContains != "" { - require.True(t, strings.Contains(err.Error(), tt.expectedErrContains)) - return - } - - require.NoError(t, err) - require.Equal(t, tt.expected, ret) - }) - } - }) - - t.Run("uint64", func(t *testing.T) { - var ( - defaultVal = rand.Uint64() - loader = makeLoader(map[string]string{"KNOWN_UINT": "5555555", "NOT_UINT": "abcd"}) - cases = []struct { - searchEnv string - expected uint64 - expectedErrContains string - fallBackOnErr bool - }{ - {searchEnv: "KNOWN_UINT", expected: 5555555, fallBackOnErr: true}, - {searchEnv: "UNKNOWN_ENV", expected: defaultVal, fallBackOnErr: true}, - {searchEnv: "NOT_UINT", expectedErrContains: "invalid syntax", fallBackOnErr: false}, - } - ) - for _, tt := range cases { - t.Run(fmt.Sprintf("with env var %s", tt.searchEnv), func(t *testing.T) { - ret, err := component.FromEnvOrDefault( - tt.searchEnv, - defaultVal, - component.WithEnvLoader(loader), - component.WithFallbackToDefaultOnError(tt.fallBackOnErr), - ) - if err != nil && tt.expectedErrContains != "" { - require.True(t, strings.Contains(err.Error(), tt.expectedErrContains)) - return - } - - require.NoError(t, err) - require.Equal(t, tt.expected, ret) - }) - } - }) - - t.Run("float64", func(t *testing.T) { - var ( - defaultVal = rand.Float64() - loader = makeLoader(map[string]string{"KNOWN_FLOAT": "69.69", "NOT_FLOAT": "abcd"}) - cases = []struct { - searchEnv string - expected float64 - expectedErrContains string - fallBackOnErr bool - }{ - {searchEnv: "KNOWN_FLOAT", expected: 69.69, fallBackOnErr: true}, - {searchEnv: "UNKNOWN_ENV", expected: defaultVal, fallBackOnErr: true}, - {searchEnv: "NOT_FLOAT", expectedErrContains: "invalid syntax", fallBackOnErr: false}, - } - ) - for _, tt := range cases { - t.Run(fmt.Sprintf("with env var %s", tt.searchEnv), func(t *testing.T) { - ret, err := component.FromEnvOrDefault( - tt.searchEnv, - defaultVal, - component.WithEnvLoader(loader), - component.WithFallbackToDefaultOnError(tt.fallBackOnErr), - ) - if err != nil && tt.expectedErrContains != "" { - require.True(t, strings.Contains(err.Error(), tt.expectedErrContains)) - return - } - - require.NoError(t, err) - require.Equal(t, tt.expected, ret) - }) - } - }) - - t.Run("time.Duration", func(t *testing.T) { - var ( - defaultVal = time.Minute * 5 - loader = makeLoader(map[string]string{"KNOWN_DURATION": "10s", "NOT_DURATION": "abcd"}) - cases = []struct { - searchEnv string - expected time.Duration - expectedErrContains string - fallBackOnErr bool - }{ - {searchEnv: "KNOWN_DURATION", expected: time.Second * 10, fallBackOnErr: true}, - {searchEnv: "UNKNOWN_ENV", expected: defaultVal, fallBackOnErr: true}, - {searchEnv: "NOT_DURATION", expectedErrContains: "invalid duration", fallBackOnErr: false}, - } - ) - for _, tt := range cases { - t.Run(fmt.Sprintf("with env var %s", tt.searchEnv), func(t *testing.T) { - ret, err := component.FromEnvOrDefault( - tt.searchEnv, - defaultVal, - component.WithEnvLoader(loader), - component.WithFallbackToDefaultOnError(tt.fallBackOnErr), - ) - if err != nil && tt.expectedErrContains != "" { - require.True(t, strings.Contains(err.Error(), tt.expectedErrContains)) - return - } - - require.NoError(t, err) - require.Equal(t, tt.expected, ret) - }) - } - }) - - t.Run("time.Time", func(t *testing.T) { - var ( - defaultVal = time.Date(2021, time.January, 1, 0, 0, 0, 0, time.UTC) - loader = makeLoader(map[string]string{"KNOWN_TIME": "2021-01-01T00:00:00Z", "NOT_TIME": "abcd"}) - cases = []struct { - searchEnv string - expected time.Time - expectedErrContains string - fallBackOnErr bool - }{ - {searchEnv: "KNOWN_TIME", expected: time.Date(2021, time.January, 1, 0, 0, 0, 0, time.UTC), fallBackOnErr: true}, - {searchEnv: "UNKNOWN_ENV", expected: defaultVal, fallBackOnErr: true}, - {searchEnv: "NOT_TIME", expectedErrContains: "parsing time", fallBackOnErr: false}, - } - ) - for _, tt := range cases { - t.Run(fmt.Sprintf("with env var %s", tt.searchEnv), func(t *testing.T) { - ret, err := component.FromEnvOrDefault( - tt.searchEnv, - defaultVal, - component.WithEnvLoader(loader), - component.WithFallbackToDefaultOnError(tt.fallBackOnErr), - ) - if err != nil { - require.Error(t, err) - if tt.expectedErrContains != "" { - require.True(t, strings.Contains(err.Error(), tt.expectedErrContains)) - } - return - } - - require.NoError(t, err) - require.Equal(t, tt.expected, ret) - }) - } - }) - - t.Run("UUID", func(t *testing.T) { - const uuidStr = "3b778fbc-93e5-4f85-bda4-c7b56e964695" - - defaultVal, err := uuid.Parse(uuidStr) - require.NoError(t, err) - - var ( - loader = makeLoader(map[string]string{"VALID_ID": uuidStr, "INVALID_ID": "abcd"}) - cases = []struct { - searchEnv string - expected uuid.UUID - expectedErrContains string - fallBackOnErr bool - }{ - {searchEnv: "VALID_ID", expected: defaultVal, fallBackOnErr: true}, - {searchEnv: "INVALID_ID", expectedErrContains: "invalid UUID string", fallBackOnErr: false}, - } - ) - for _, tt := range cases { - t.Run(fmt.Sprintf("with env var %s", tt.searchEnv), func(t *testing.T) { - ret, err := component.FromEnvOrDefault( - tt.searchEnv, - defaultVal, - component.WithEnvLoader(loader), - component.WithFallbackToDefaultOnError(tt.fallBackOnErr), - ) - if err != nil { - require.Error(t, err) - if tt.expectedErrContains != "" { - require.True(t, strings.Contains(err.Error(), tt.expectedErrContains)) - } - return - } - - require.NoError(t, err) - require.Equal(t, tt.expected, ret) - }) - } - }) -} diff --git a/sdk/component/export_test.go b/sdk/component/export_test.go deleted file mode 100644 index 6ac30bacc..000000000 --- a/sdk/component/export_test.go +++ /dev/null @@ -1,20 +0,0 @@ -// This file is ignored by the go build tool as it ends with the '_test.go' postifx. -// It can be used to export unexported symbols for unit testing while not leaking these to the public API. -package component - -// -- START env.go exports -- -type EnvLoader envLoader - -func WithEnvLoader(loader EnvLoader) envParseOption { - return withEnvLoader(envLoader(loader)) -} - -func WithFallbackToDefaultOnError(v bool) envParseOption { - return withFallbackToDefaultOnError(v) -} - -func FromEnvOrDefault[T parseableEnvTypes](envVar string, defaultVal T, opts ...envParseOption) (dest T, err error) { - return fromEnvOrDefault[T](envVar, defaultVal, opts...) -} - -// -- END env.go exports -- diff --git a/sdk/component/internal/storer/local/sqlite/export_test.go b/sdk/component/internal/storer/local/sqlite/export_test.go deleted file mode 100644 index 02fcc135d..000000000 --- a/sdk/component/internal/storer/local/sqlite/export_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package sqlite - -import "github.com/go-errors/errors" - -// CreateTable is used to create a table in testing settings. -func (m *manager) CreateTable() error { - stmt, err := m.db.Prepare(` - CREATE TABLE finding ( - id INTEGER PRIMARY KEY AUTOINCREMENT UNIQUE, - instance_id UUID NOT NULL UNIQUE, - findings TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ); - `) - if err != nil { - return errors.Errorf("could not prepare statement for creating table: %w", err) - } - - if _, err := stmt.Exec(); err != nil { - return errors.Errorf("could not create table: %w", err) - } - - return stmt.Close() -} diff --git a/sdk/component/internal/storer/local/sqlite/sqlite.go b/sdk/component/internal/storer/local/sqlite/sqlite.go index 20a2f72ff..a05a1ab95 100644 --- a/sdk/component/internal/storer/local/sqlite/sqlite.go +++ b/sdk/component/internal/storer/local/sqlite/sqlite.go @@ -74,6 +74,10 @@ func NewManager(dsn string, opts ...managerOption) (*manager, error) { } } + if err := mgr.migrate(); err != nil { + return nil, errors.Errorf("could not apply migrations: %w", err) + } + return mgr, nil } @@ -211,6 +215,29 @@ func (m *manager) Close(ctx context.Context) error { return nil } +// TODO: potentially leverage migrations here but this is simple enough for now for local setup. +// Tracked here https://linear.app/smithy/issue/OCU-274/automigrate-on-sqlite-storage. +func (m *manager) migrate() error { + stmt, err := m.db.Prepare(` + CREATE TABLE IF NOT EXISTS finding ( + id INTEGER PRIMARY KEY AUTOINCREMENT UNIQUE, + instance_id UUID NOT NULL UNIQUE, + findings TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + `) + if err != nil { + return fmt.Errorf("could not prepare statement for creating table: %w", err) + } + + if _, err := stmt.Exec(); err != nil { + return fmt.Errorf("could not create table: %w", err) + } + + return stmt.Close() +} + func (m *manager) marshalFindings(findings []*ocsf.VulnerabilityFinding) (string, error) { var rawFindings []json.RawMessage for _, finding := range findings { diff --git a/sdk/component/internal/storer/local/sqlite/sqlite_test.go b/sdk/component/internal/storer/local/sqlite/sqlite_test.go index bd483276b..2f50dc323 100644 --- a/sdk/component/internal/storer/local/sqlite/sqlite_test.go +++ b/sdk/component/internal/storer/local/sqlite/sqlite_test.go @@ -29,8 +29,6 @@ type ( component.Reader component.Updater component.Writer - - CreateTable() error } ManagerTestSuite struct { @@ -54,7 +52,6 @@ func (mts *ManagerTestSuite) SetupTest() { mts.manager, err = sqlite.NewManager("smithy.db", sqlite.ManagerWithClock(clock)) require.NoError(mts.t, err) - require.NoError(mts.T(), mts.manager.CreateTable()) } func (mts *ManagerTestSuite) TearDownTest() { diff --git a/sdk/go.mod b/sdk/go.mod index cbb560afc..f2a2a81d1 100644 --- a/sdk/go.mod +++ b/sdk/go.mod @@ -1,6 +1,6 @@ module github.com/smithy-security/smithy/sdk -go 1.23.0 +go 1.23.2 require ( github.com/abice/go-enum v0.6.0 @@ -8,6 +8,7 @@ require ( github.com/google/uuid v1.6.0 github.com/jonboulle/clockwork v0.4.0 github.com/mattn/go-sqlite3 v1.14.24 + github.com/smithy-security/pkg/env v0.0.1 github.com/stretchr/testify v1.9.0 go.uber.org/mock v0.5.0 golang.org/x/sync v0.8.0 diff --git a/sdk/go.sum b/sdk/go.sum index 989a59131..6967d9fed 100644 --- a/sdk/go.sum +++ b/sdk/go.sum @@ -59,6 +59,8 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/smithy-security/pkg/env v0.0.1 h1:uwLTMLdNN/dv3x4zat75JahEBQDpdBeldjEE8El4OiM= +github.com/smithy-security/pkg/env v0.0.1/go.mod h1:VIJfDqeAbQQcmohaXcZI6grjeJC9Y8CmqR4ITpdngZE= github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/sdk/vendor/github.com/smithy-security/pkg/env/LICENSE b/sdk/vendor/github.com/smithy-security/pkg/env/LICENSE new file mode 100644 index 000000000..2f5da2d1a --- /dev/null +++ b/sdk/vendor/github.com/smithy-security/pkg/env/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 smithy-security + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/sdk/vendor/github.com/smithy-security/pkg/env/README.md b/sdk/vendor/github.com/smithy-security/pkg/env/README.md new file mode 100644 index 000000000..f8f245151 --- /dev/null +++ b/sdk/vendor/github.com/smithy-security/pkg/env/README.md @@ -0,0 +1,31 @@ +# Env + +Minimalistic package for environment variable lookup of type defined in `Parseable`. + +## Example usage + +```go +package main + +import ( + "github.com/smithy-security/pkg/env" +) + +func main() { + // Will error if not defined or a valid integer. + intVar, err := env.GetOrDefault("MY_INT_ENV_VAR", 10) + if err != nil { + ... + } + + // Will return the default value 10 if not defined or on error + anotherIntVar, err := env.GetOrDefault("MY_OTHER_INT_ENV_VAR", 10, env.WithDefaultOnError(true)) + if err != nil { + ... + } +} +``` + +## On testing + +Customise `Loader` to mock your environment. Check the examples in `env_test.go`. diff --git a/sdk/vendor/github.com/smithy-security/pkg/env/env.go b/sdk/vendor/github.com/smithy-security/pkg/env/env.go new file mode 100644 index 000000000..de10bcbfb --- /dev/null +++ b/sdk/vendor/github.com/smithy-security/pkg/env/env.go @@ -0,0 +1,124 @@ +package env + +import ( + "errors" + "fmt" + "os" + "strconv" + "time" +) + +type ( + // Parseable represents the types the parser is capable of handling. + // TODO: extend with slices if needed. + Parseable interface { + string | bool | int | uint | int64 | uint64 | float64 | time.Duration | time.Time + } + + // ParseOption is a means to customize parse options via variadic parameters. + ParseOption func(o *parseOpts) error + + // Loader is an alias for a function that loads values from the env. + // It mirrors the signature of os.Getenv. + Loader func(key string) string + + parseOpts struct { + envLoader Loader + defaultOnError bool + timeLayout string + } +) + +var ( + defaultEnvParseOptions = parseOpts{ + envLoader: os.Getenv, + defaultOnError: false, + timeLayout: time.RFC3339, + } +) + +// WithLoader allows overriding how env vars are loaded. +// +// Primarily used for testing. +func WithLoader(loader Loader) ParseOption { + return func(o *parseOpts) error { + if loader == nil { + return errors.New("env loader function cannot be nil") + } + + o.envLoader = loader + return nil + } +} + +// WithDefaultOnError informs the parser that if an error is encountered during parsing, it should fallback to the default value. +func WithDefaultOnError(fallback bool) ParseOption { + return func(o *parseOpts) error { + o.defaultOnError = fallback + return nil + } +} + +// GetOrDefault attempts to parse the environment variable provided. If it is empty or missing, the default value is used. +// +// If an error is encountered, depending on whether the `WithDefaultOnError` option is provided it will either +// fall back or return the error back to the client. +func GetOrDefault[T Parseable](envVar string, defaultVal T, opts ...ParseOption) (dest T, err error) { + if envVar == "" { + return dest, errors.New("environment variable cannot be blank") + } + + defaultOpts := &defaultEnvParseOptions + for _, opt := range opts { + if err := opt(defaultOpts); err != nil { + return dest, fmt.Errorf("option error: %w", err) + } + } + + envStr := defaultOpts.envLoader(envVar) + if envStr == "" { + if !defaultOpts.defaultOnError { + return dest, fmt.Errorf("required environment variable '%s' not found", envVar) + } + return defaultVal, nil + } + + var v any + + switch any(dest).(type) { + case string: + v = envStr + case bool: + v, err = strconv.ParseBool(envStr) + case int: + v, err = strconv.Atoi(envStr) + case uint: + var i uint64 + i, err = strconv.ParseUint(envStr, 10, 64) + v = uint(i) + case int64: + v, err = strconv.ParseInt(envStr, 10, 64) + case uint64: + v, err = strconv.ParseUint(envStr, 10, 64) + case float64: + v, err = strconv.ParseFloat(envStr, 64) + case time.Duration: + v, err = time.ParseDuration(envStr) + case time.Time: + v, err = time.Parse(defaultOpts.timeLayout, envStr) + } + if err != nil { + if defaultOpts.defaultOnError { + return defaultVal, nil + } + + return dest, fmt.Errorf("failed to parse environment variable '%s' to '%T': %w", envVar, dest, err) + } + + dest, ok := v.(T) + if !ok { + return dest, fmt.Errorf("failed to cast environment variable '%s' to '%T'", envVar, dest) + } + + return dest, nil +} diff --git a/sdk/vendor/modules.txt b/sdk/vendor/modules.txt index c3ce9ed59..66d19ff30 100644 --- a/sdk/vendor/modules.txt +++ b/sdk/vendor/modules.txt @@ -70,6 +70,9 @@ github.com/russross/blackfriday/v2 # github.com/shopspring/decimal v1.2.0 ## explicit; go 1.13 github.com/shopspring/decimal +# github.com/smithy-security/pkg/env v0.0.1 +## explicit; go 1.23.2 +github.com/smithy-security/pkg/env # github.com/spf13/cast v1.3.1 ## explicit github.com/spf13/cast