diff --git a/examples/dynamic/go.mod b/examples/dynamic/go.mod new file mode 100644 index 0000000..af179e7 --- /dev/null +++ b/examples/dynamic/go.mod @@ -0,0 +1,23 @@ +module github.com/KimMachineGun/automemlimit/examples/dynamic + +go 1.21 + +toolchain go1.21.0 + +require github.com/KimMachineGun/automemlimit v0.0.0 + +require ( + github.com/cilium/ebpf v0.9.1 // indirect + github.com/containerd/cgroups/v3 v3.0.1 // indirect + github.com/coreos/go-systemd/v22 v22.3.2 // indirect + github.com/docker/go-units v0.4.0 // indirect + github.com/godbus/dbus/v5 v5.0.4 // indirect + github.com/google/go-cmp v0.5.9 // indirect + github.com/opencontainers/runtime-spec v1.0.2 // indirect + github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect + github.com/sirupsen/logrus v1.8.1 // indirect + golang.org/x/sys v0.13.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect +) + +replace github.com/KimMachineGun/automemlimit => ../../ diff --git a/examples/dynamic/go.sum b/examples/dynamic/go.sum new file mode 100644 index 0000000..da57856 --- /dev/null +++ b/examples/dynamic/go.sum @@ -0,0 +1,42 @@ +github.com/cilium/ebpf v0.9.1 h1:64sn2K3UKw8NbP/blsixRpF3nXuyhz/VjRlRzvlBRu4= +github.com/cilium/ebpf v0.9.1/go.mod h1:+OhNOIXx/Fnu1IE8bJz2dzOA+VSfyTfdNUVdlQnxUFY= +github.com/containerd/cgroups/v3 v3.0.1 h1:4hfGvu8rfGIwVIDd+nLzn/B9ZXx4BcCjzt5ToenJRaE= +github.com/containerd/cgroups/v3 v3.0.1/go.mod h1:/vtwk1VXrtoa5AaZLkypuOJgA/6DyPMZHJPGQNtlHnw= +github.com/coreos/go-systemd/v22 v22.3.2 h1:D9/bQk5vlXQFZ6Kwuu6zaiXJ9oTPe68++AzAJc1DzSI= +github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= +github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/frankban/quicktest v1.14.0 h1:+cqqvzZV87b4adx/5ayVOaYZ2CrvM4ejQvUdBzPPUss= +github.com/frankban/quicktest v1.14.0/go.mod h1:NeW+ay9A/U67EYXNFA1nPE8e/tnQv/09mUdL/ijj8og= +github.com/godbus/dbus/v5 v5.0.4 h1:9349emZab16e7zQvpmsbtjc18ykshndd8y2PG3sgJbA= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/opencontainers/runtime-spec v1.0.2 h1:UfAcuLBJB9Coz72x1hgl8O5RVzTdNiaglX6v2DM6FI0= +github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0= +github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= +github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/dynamic/limit.txt b/examples/dynamic/limit.txt new file mode 100644 index 0000000..7df3e13 --- /dev/null +++ b/examples/dynamic/limit.txt @@ -0,0 +1 @@ +4294967296 diff --git a/examples/dynamic/main.go b/examples/dynamic/main.go new file mode 100644 index 0000000..557babc --- /dev/null +++ b/examples/dynamic/main.go @@ -0,0 +1,52 @@ +package main + +import ( + "bytes" + "errors" + "log/slog" + "os" + "os/signal" + "strconv" + "time" + + "github.com/KimMachineGun/automemlimit/memlimit" +) + +func init() { + slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stderr, nil))) + + memlimit.SetGoMemLimitWithOpts( + memlimit.WithProvider( + FileProvider("limit.txt"), + ), + memlimit.WithRefreshInterval(5*time.Second), + memlimit.WithLogger(slog.Default()), + ) +} + +func main() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + + s := <-c + slog.Info("signal captured", slog.Any("signal", s)) +} + +func FileProvider(path string) memlimit.Provider { + return func() (uint64, error) { + b, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return memlimit.ApplyFallback(memlimit.FromCgroup, memlimit.FromSystem)() + } + return 0, err + } + + b = bytes.TrimSpace(b) + if len(b) == 0 { + return memlimit.ApplyFallback(memlimit.FromCgroup, memlimit.FromSystem)() + } + + return strconv.ParseUint(string(b), 10, 64) + } +} diff --git a/memlimit/memlimit.go b/memlimit/memlimit.go index 89404b3..946e851 100644 --- a/memlimit/memlimit.go +++ b/memlimit/memlimit.go @@ -8,6 +8,7 @@ import ( "os" "runtime/debug" "strconv" + "time" ) const ( @@ -19,15 +20,14 @@ const ( defaultAUTOMEMLIMIT = 0.9 ) -var ( - // ErrNoLimit is returned when the memory limit is not set. - ErrNoLimit = errors.New("memory is not limited") -) +// ErrNoLimit is returned when the memory limit is not set. +var ErrNoLimit = errors.New("memory is not limited") type config struct { logger *slog.Logger ratio float64 provider Provider + refresh time.Duration } // Option is a function that configures the behavior of SetGoMemLimitWithOptions. @@ -61,6 +61,20 @@ func WithLogger(logger *slog.Logger) Option { } } +// WithRefreshInterval configures the refresh interval for automemlimit. +// If a refresh interval is greater than 0, automemlimit periodically fetches +// the memory limit from the provider and reapplies it if it has changed. +// If the provider returns an error, it logs the error and continues. +// Since ErrNoLimit is also considered as an error (but not logged), +// you should return math.MaxInt64 if you want to unset the limit. +// +// Default: 0 (no refresh) +func WithRefreshInterval(refresh time.Duration) Option { + return func(cfg *config) { + cfg.refresh = refresh + } +} + // WithEnv configures whether to use environment variables. // // Default: false @@ -80,7 +94,7 @@ func memlimitLogger(logger *slog.Logger) *slog.Logger { // SetGoMemLimitWithOpts sets GOMEMLIMIT with options and environment variables. // // You can configure how much memory of the cgroup's memory limit to set as GOMEMLIMIT -// through AUTOMEMLIMIT envrironment variable in the half-open range (0.0,1.0]. +// through AUTOMEMLIMIT environment variable in the half-open range (0.0,1.0]. // // If AUTOMEMLIMIT is not set, it defaults to 0.9. (10% is the headroom for memory sources the Go runtime is unaware of.) // If GOMEMLIMIT is already set or AUTOMEMLIMIT=off, this function does nothing. @@ -128,20 +142,9 @@ func SetGoMemLimitWithOpts(opts ...Option) (_ int64, _err error) { cfg.provider = ApplyFallback(cfg.provider, FromSystem) } - // capture the current GOMEMLIMIT for rollback in case of panic + // rollback to previous memory limit on panic snapshot := debug.SetMemoryLimit(-1) - defer func() { - panicErr := recover() - if panicErr != nil { - if _err != nil { - cfg.logger.Error("failed to set GOMEMLIMIT", slog.Any("error", _err)) - } - _err = fmt.Errorf("panic during setting the Go's memory limit, rolling back to previous limit %d: %v", - snapshot, panicErr, - ) - debug.SetMemoryLimit(snapshot) - } - }() + defer rollbackOnPanic(cfg.logger, snapshot, &_err) // check if GOMEMLIMIT is already set if val, ok := os.LookupEnv(envGOMEMLIMIT); ok { @@ -156,26 +159,87 @@ func SetGoMemLimitWithOpts(opts ...Option) (_ int64, _err error) { cfg.logger.Info("AUTOMEMLIMIT is set to off, skipping") return 0, nil } - _ratio, err := strconv.ParseFloat(val, 64) + ratio, err = strconv.ParseFloat(val, 64) if err != nil { return 0, fmt.Errorf("cannot parse AUTOMEMLIMIT: %s", val) } - ratio = _ratio } - // set GOMEMLIMIT - limit, err := setGoMemLimit(ApplyRatio(cfg.provider, ratio)) + // get the memory limit from the provider + provider := capProvider(ApplyRatio(cfg.provider, ratio)) + + // set the memory limit and start refresh + limit, err := updateGoMemLimit(uint64(snapshot), provider, cfg.logger) + go refresh(provider, cfg.logger, cfg.refresh) if err != nil { if errors.Is(err, ErrNoLimit) { cfg.logger.Info("memory is not limited, skipping") + // TODO: consider returning the snapshot return 0, nil } - return 0, fmt.Errorf("failed to set GOMEMLIMIT: %w", err) + return 0, fmt.Errorf("failed to get memory limit: %w", err) + } + + return int64(limit), nil +} + +// updateGoMemLimit updates the Go's memory limit, if it has changed. +func updateGoMemLimit(currLimit uint64, provider Provider, logger *slog.Logger) (uint64, error) { + newLimit, err := provider() + if err != nil { + return 0, err } - cfg.logger.Info("GOMEMLIMIT is updated", slog.Int64(envGOMEMLIMIT, limit)) + if newLimit == currLimit { + logger.Debug("GOMEMLIMIT is not changed, skipping", slog.Uint64(envGOMEMLIMIT, newLimit)) + return newLimit, nil + } - return limit, nil + debug.SetMemoryLimit(int64(newLimit)) + logger.Info("GOMEMLIMIT is updated", slog.Uint64(envGOMEMLIMIT, newLimit), slog.Uint64("previous", currLimit)) + + return newLimit, nil +} + +// refresh periodically fetches the memory limit from the provider and reapplies it if it has changed. +// See more details in the documentation of WithRefreshInterval. +func refresh(provider Provider, logger *slog.Logger, refresh time.Duration) { + if refresh == 0 { + return + } + + t := time.NewTicker(refresh) + for range t.C { + err := func() (_err error) { + snapshot := debug.SetMemoryLimit(-1) + defer rollbackOnPanic(logger, snapshot, &_err) + + _, err := updateGoMemLimit(uint64(snapshot), provider, logger) + if err != nil && !errors.Is(err, ErrNoLimit) { + return err + } + + return nil + }() + if err != nil { + logger.Error("failed to refresh GOMEMLIMIT", slog.Any("error", err)) + } + } +} + +// rollbackOnPanic rollbacks to the snapshot on panic. +// Since it uses recover, it should be called in a deferred function. +func rollbackOnPanic(logger *slog.Logger, snapshot int64, err *error) { + panicErr := recover() + if panicErr != nil { + if *err != nil { + logger.Error("failed to set GOMEMLIMIT", slog.Any("error", *err)) + } + *err = fmt.Errorf("panic during setting the Go's memory limit, rolling back to previous limit %d: %v", + snapshot, panicErr, + ) + debug.SetMemoryLimit(snapshot) + } } // SetGoMemLimitWithEnv sets GOMEMLIMIT with the value from the environment variables. @@ -195,19 +259,24 @@ func SetGoMemLimitWithProvider(provider Provider, ratio float64) (int64, error) return SetGoMemLimitWithOpts(WithProvider(provider), WithRatio(ratio)) } -func setGoMemLimit(provider Provider) (int64, error) { - limit, err := provider() - if err != nil { - return 0, err +func capProvider(provider Provider) Provider { + return func() (uint64, error) { + limit, err := provider() + if err != nil { + return 0, err + } else if limit > math.MaxInt64 { + return math.MaxInt64, nil + } + return limit, nil } - capped := cappedU64ToI64(limit) - debug.SetMemoryLimit(capped) - return capped, nil } -func cappedU64ToI64(limit uint64) int64 { - if limit > math.MaxInt64 { - return math.MaxInt64 +func maxInt64OnNoLimit(provider Provider) Provider { + return func() (uint64, error) { + limit, err := provider() + if errors.Is(err, ErrNoLimit) { + return math.MaxInt64, nil + } + return limit, err } - return int64(limit) } diff --git a/memlimit/memlimit_common_test.go b/memlimit/memlimit_common_test.go index 361d117..d5310ac 100644 --- a/memlimit/memlimit_common_test.go +++ b/memlimit/memlimit_common_test.go @@ -1,7 +1,9 @@ package memlimit import ( + "fmt" "math" + "runtime/debug" "testing" ) @@ -68,10 +70,11 @@ func TestSetGoMemLimitWithProvider(t *testing.T) { ratio float64 } tests := []struct { - name string - args args - want int64 - wantErr error + name string + args args + want int64 + wantErr error + gomemlimit int64 }{ { name: "Limit_0.5", @@ -79,8 +82,9 @@ func TestSetGoMemLimitWithProvider(t *testing.T) { provider: Limit(1024 * 1024 * 1024), ratio: 0.5, }, - want: 536870912, - wantErr: nil, + want: 536870912, + wantErr: nil, + gomemlimit: 536870912, }, { name: "Limit_0.9", @@ -88,8 +92,9 @@ func TestSetGoMemLimitWithProvider(t *testing.T) { provider: Limit(1024 * 1024 * 1024), ratio: 0.9, }, - want: 966367641, - wantErr: nil, + want: 966367641, + wantErr: nil, + gomemlimit: 966367641, }, { name: "Limit_0.9_math.MaxUint64", @@ -97,8 +102,9 @@ func TestSetGoMemLimitWithProvider(t *testing.T) { provider: Limit(math.MaxUint64), ratio: 0.9, }, - want: math.MaxInt64, - wantErr: nil, + want: math.MaxInt64, + wantErr: nil, + gomemlimit: math.MaxInt64, }, { name: "Limit_0.9_math.MaxUint64", @@ -106,8 +112,9 @@ func TestSetGoMemLimitWithProvider(t *testing.T) { provider: Limit(math.MaxUint64), ratio: 0.9, }, - want: math.MaxInt64, - wantErr: nil, + want: math.MaxInt64, + wantErr: nil, + gomemlimit: math.MaxInt64, }, { name: "Limit_0.45_math.MaxUint64", @@ -115,12 +122,16 @@ func TestSetGoMemLimitWithProvider(t *testing.T) { provider: Limit(math.MaxUint64), ratio: 0.45, }, - want: 8301034833169298432, - wantErr: nil, + want: 8301034833169298432, + wantErr: nil, + gomemlimit: 8301034833169298432, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Cleanup(func() { + debug.SetMemoryLimit(math.MaxInt64) + }) got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) if err != tt.wantErr { t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) @@ -129,6 +140,68 @@ func TestSetGoMemLimitWithProvider(t *testing.T) { if got != tt.want { t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) } + if debug.SetMemoryLimit(-1) != tt.gomemlimit { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) + } + }) + } +} + +func TestSetGoMemLimitWithOpts(t *testing.T) { + tests := []struct { + name string + opts []Option + want int64 + wantErr error + gomemlimit int64 + }{ + { + name: "unknown error", + opts: []Option{ + WithProvider(func() (uint64, error) { + return 0, fmt.Errorf("unknown error") + }), + }, + want: 0, + wantErr: fmt.Errorf("failed to get memory limit: unknown error"), + gomemlimit: math.MaxInt64, + }, + { + name: "ErrNoLimit", + opts: []Option{ + WithProvider(func() (uint64, error) { + return 0, ErrNoLimit + }), + }, + want: 0, + wantErr: nil, + gomemlimit: math.MaxInt64, + }, + { + name: "wrapped ErrNoLimit", + opts: []Option{ + WithProvider(func() (uint64, error) { + return 0, fmt.Errorf("wrapped: %w", ErrNoLimit) + }), + }, + want: 0, + wantErr: nil, + gomemlimit: math.MaxInt64, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := SetGoMemLimitWithOpts(tt.opts...) + if tt.wantErr != nil && err.Error() != tt.wantErr.Error() { + t.Errorf("SetGoMemLimitWithOpts() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("SetGoMemLimitWithOpts() got = %v, want %v", got, tt.want) + } + if debug.SetMemoryLimit(-1) != tt.gomemlimit { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) + } }) } } diff --git a/memlimit/memlimit_test.go b/memlimit/memlimit_test.go index 7414594..21d7f22 100644 --- a/memlimit/memlimit_test.go +++ b/memlimit/memlimit_test.go @@ -6,7 +6,9 @@ package memlimit import ( "flag" "log" + "math" "os" + "runtime/debug" "testing" "github.com/containerd/cgroups/v3" @@ -34,38 +36,42 @@ func TestSetGoMemLimit(t *testing.T) { ratio float64 } tests := []struct { - name string - args args - want int64 - wantErr error - skip bool + name string + args args + want int64 + wantErr error + gomemlimit int64 + skip bool }{ { name: "0.5", args: args{ ratio: 0.5, }, - want: int64(float64(expected) * 0.5), - wantErr: nil, - skip: expected == 0 || cgVersion == cgroups.Unavailable, + want: int64(float64(expected) * 0.5), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.5), + skip: expected == 0 || cgVersion == cgroups.Unavailable, }, { name: "0.9", args: args{ ratio: 0.9, }, - want: int64(float64(expected) * 0.9), - wantErr: nil, - skip: expected == 0 || cgVersion == cgroups.Unavailable, + want: int64(float64(expected) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.9), + skip: expected == 0 || cgVersion == cgroups.Unavailable, }, { name: "Unavailable", args: args{ ratio: 0.9, }, - want: 0, - wantErr: ErrCgroupsNotSupported, - skip: cgVersion != cgroups.Unavailable, + want: 0, + wantErr: ErrCgroupsNotSupported, + gomemlimit: math.MaxInt64, + skip: cgVersion != cgroups.Unavailable, }, } for _, tt := range tests { @@ -73,6 +79,9 @@ func TestSetGoMemLimit(t *testing.T) { if tt.skip { t.Skip() } + t.Cleanup(func() { + debug.SetMemoryLimit(math.MaxInt64) + }) got, err := SetGoMemLimit(tt.args.ratio) if err != tt.wantErr { t.Errorf("SetGoMemLimit() error = %v, wantErr %v", err, tt.wantErr) @@ -81,6 +90,9 @@ func TestSetGoMemLimit(t *testing.T) { if got != tt.want { t.Errorf("SetGoMemLimit() got = %v, want %v", got, tt.want) } + if debug.SetMemoryLimit(-1) != tt.gomemlimit { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) + } }) } } @@ -91,11 +103,12 @@ func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { ratio float64 } tests := []struct { - name string - args args - want int64 - wantErr error - skip bool + name string + args args + want int64 + wantErr error + gomemlimit int64 + skip bool }{ { name: "FromCgroup", @@ -103,9 +116,10 @@ func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { provider: FromCgroup, ratio: 0.9, }, - want: int64(float64(expected) * 0.9), - wantErr: nil, - skip: expected == 0 || cgVersion == cgroups.Unavailable, + want: int64(float64(expected) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.9), + skip: expected == 0 || cgVersion == cgroups.Unavailable, }, { name: "FromCgroup_Unavaliable", @@ -113,9 +127,10 @@ func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { provider: FromCgroup, ratio: 0.9, }, - want: 0, - wantErr: ErrNoCgroup, - skip: expected == 0 || cgVersion != cgroups.Unavailable, + want: 0, + wantErr: ErrNoCgroup, + gomemlimit: math.MaxInt64, + skip: expected == 0 || cgVersion != cgroups.Unavailable, }, { name: "FromCgroupV1", @@ -123,9 +138,10 @@ func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { provider: FromCgroupV1, ratio: 0.9, }, - want: int64(float64(expected) * 0.9), - wantErr: nil, - skip: expected == 0 || cgVersion != cgroups.Legacy, + want: int64(float64(expected) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.9), + skip: expected == 0 || cgVersion != cgroups.Legacy, }, { name: "FromCgroupHybrid", @@ -133,9 +149,10 @@ func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { provider: FromCgroupHybrid, ratio: 0.9, }, - want: int64(float64(expected) * 0.9), - wantErr: nil, - skip: expected == 0 || cgVersion != cgroups.Hybrid, + want: int64(float64(expected) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.9), + skip: expected == 0 || cgVersion != cgroups.Hybrid, }, { name: "FromCgroupV2", @@ -143,9 +160,10 @@ func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { provider: FromCgroupV2, ratio: 0.9, }, - want: int64(float64(expected) * 0.9), - wantErr: nil, - skip: expected == 0 || cgVersion != cgroups.Unified, + want: int64(float64(expected) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.9), + skip: expected == 0 || cgVersion != cgroups.Unified, }, } for _, tt := range tests { @@ -153,6 +171,9 @@ func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { if tt.skip { t.Skip() } + t.Cleanup(func() { + debug.SetMemoryLimit(math.MaxInt64) + }) got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) if err != tt.wantErr { t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) @@ -161,6 +182,9 @@ func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { if got != tt.want { t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) } + if debug.SetMemoryLimit(-1) != tt.gomemlimit { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) + } }) } } @@ -171,11 +195,12 @@ func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { ratio float64 } tests := []struct { - name string - args args - want int64 - wantErr error - skip bool + name string + args args + want int64 + wantErr error + gomemlimit int64 + skip bool }{ { name: "FromSystem", @@ -183,9 +208,10 @@ func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { provider: FromSystem, ratio: 0.9, }, - want: int64(float64(expectedSystem) * 0.9), - wantErr: nil, - skip: expectedSystem == 0, + want: int64(float64(expectedSystem) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expectedSystem) * 0.9), + skip: expectedSystem == 0, }, } for _, tt := range tests { @@ -193,6 +219,9 @@ func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { if tt.skip { t.Skip() } + t.Cleanup(func() { + debug.SetMemoryLimit(math.MaxInt64) + }) got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) if err != tt.wantErr { t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) @@ -201,6 +230,9 @@ func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { if got != tt.want { t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) } + if debug.SetMemoryLimit(-1) != tt.gomemlimit { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) + } }) } } diff --git a/memlimit/memlimit_unsupported_test.go b/memlimit/memlimit_unsupported_test.go index e55ceb8..330f93b 100644 --- a/memlimit/memlimit_unsupported_test.go +++ b/memlimit/memlimit_unsupported_test.go @@ -6,13 +6,13 @@ package memlimit import ( "errors" "flag" + "math" "os" + "runtime/debug" "testing" ) -var ( - expected uint64 -) +var expected uint64 func TestMain(m *testing.M) { flag.Uint64Var(&expected, "expected", 0, "Expected memory limit") @@ -26,30 +26,36 @@ func TestSetGoMemLimit(t *testing.T) { ratio float64 } tests := []struct { - name string - args args - want int64 - wantErr error + name string + args args + want int64 + wantErr error + gomemlimit int64 }{ { name: "0.5", args: args{ ratio: 0.5, }, - want: 0, - wantErr: ErrCgroupsNotSupported, + want: 0, + wantErr: ErrCgroupsNotSupported, + gomemlimit: math.MaxInt64, }, { name: "0.9", args: args{ ratio: 0.9, }, - want: 0, - wantErr: ErrCgroupsNotSupported, + want: 0, + wantErr: ErrCgroupsNotSupported, + gomemlimit: math.MaxInt64, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Cleanup(func() { + debug.SetMemoryLimit(math.MaxInt64) + }) got, err := SetGoMemLimit(tt.args.ratio) if !errors.Is(err, tt.wantErr) { t.Errorf("SetGoMemLimit() error = %v, wantErr %v", err, tt.wantErr) @@ -58,6 +64,9 @@ func TestSetGoMemLimit(t *testing.T) { if got != tt.want { t.Errorf("SetGoMemLimit() got = %v, want %v", got, tt.want) } + if debug.SetMemoryLimit(-1) != tt.gomemlimit { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) + } }) } } @@ -130,11 +139,12 @@ func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { ratio float64 } tests := []struct { - name string - args args - want int64 - wantErr error - skip bool + name string + args args + want int64 + wantErr error + gomemlimit int64 + skip bool }{ { name: "FromSystem", @@ -142,9 +152,10 @@ func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { provider: FromSystem, ratio: 0.9, }, - want: int64(float64(expected) * 0.9), - wantErr: nil, - skip: expected == 0, + want: int64(float64(expected) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.9), + skip: expected == 0, }, } for _, tt := range tests { @@ -152,6 +163,9 @@ func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { if tt.skip { t.Skip() } + t.Cleanup(func() { + debug.SetMemoryLimit(math.MaxInt64) + }) got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) if !errors.Is(err, tt.wantErr) { t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) @@ -160,6 +174,9 @@ func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { if got != tt.want { t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) } + if debug.SetMemoryLimit(-1) != tt.gomemlimit { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) + } }) } }