diff --git a/utils/utils.go b/utils/utils.go index e09daa4b9..de3f4aae9 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -602,3 +602,26 @@ func ExtractSha256FromResponseBody(body []byte) (string, error) { func Pointer[K any](val K) *K { return &val } + +func SetEnvWithResetCallback(key, value string) func() { + oldValue, exist := os.LookupEnv(key) + errMsg := "failed %s %s as environment variable. Cause: %s" + + if err := os.Setenv(key, value); err != nil { + log.Debug(fmt.Sprintf(errMsg, "setting", key, err.Error())) + return func() {} + } + + if exist { + return func() { + if err := os.Setenv(key, oldValue); err != nil { + log.Debug(fmt.Sprintf(errMsg, "setting", key, err.Error())) + } + } + } + return func() { + if err := os.Unsetenv(key); err != nil { + log.Debug(fmt.Sprintf(errMsg, "unsetting", key, err.Error())) + } + } +} diff --git a/utils/utils_test.go b/utils/utils_test.go index ac61ad82c..ac7cdb06e 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -2,6 +2,7 @@ package utils import ( "fmt" + "os" "reflect" "sort" "testing" @@ -287,3 +288,47 @@ func TestValidateMinimumVersion(t *testing.T) { }) } } + +func TestSetEnvWithResetCallback(t *testing.T) { + type args struct { + key string + value string + } + tests := []struct { + name string + args args + init func() + finish func() + }{ + { + name: "existing environment variable", + args: args{key: "TEST_KEY", value: "test_value"}, + init: func() { + assert.NoError(t, os.Setenv("TEST_KEY", "test-init-value")) + }, + finish: func() { + assert.Equal(t, os.Getenv("TEST_KEY"), "test-init-value") + }, + }, + { + name: "non-existing environment variable", + args: args{key: "NEW_TEST_KEY", value: "test_value"}, + init: func() { + + }, + finish: func() { + _, exist := os.LookupEnv("NEW_TEST_KEY") + assert.False(t, exist) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.init() + resetCallback := SetEnvWithResetCallback(tt.args.key, tt.args.value) + assert.Equal(t, tt.args.value, os.Getenv(tt.args.key)) + resetCallback() + tt.finish() + }) + } +}