diff --git a/memlimit/memlimit_test.go b/memlimit/memlimit_test.go index 21d7f22..2072531 100644 --- a/memlimit/memlimit_test.go +++ b/memlimit/memlimit_test.go @@ -9,7 +9,9 @@ import ( "math" "os" "runtime/debug" + "sync/atomic" "testing" + "time" "github.com/containerd/cgroups/v3" ) @@ -236,3 +238,68 @@ func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { }) } } + +func TestSetGoMemLimitWithOpts_WithRefreshInterval(t *testing.T) { + t.Cleanup(func() { + debug.SetMemoryLimit(math.MaxInt64) + }) + var limit atomic.Int64 + output, err := SetGoMemLimitWithOpts( + WithProvider(func() (uint64, error) { + l := limit.Load() + if l == 0 { + return 0, ErrNoLimit + } + return uint64(l), nil + }), + WithRatio(1), + WithRefreshInterval(10*time.Millisecond), + ) + if err != nil { + t.Errorf("SetGoMemLimitWithOpts() error = %v", err) + } else if output != limit.Load() { + t.Errorf("SetGoMemLimitWithOpts() got = %v, want %v", output, limit.Load()) + } + + // 1. no limit + curr := debug.SetMemoryLimit(-1) + if curr != math.MaxInt64 { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, limit.Load()) + } + + // 2. max limit + limit.Add(math.MaxInt64) + time.Sleep(100 * time.Millisecond) + + curr = debug.SetMemoryLimit(-1) + if curr != math.MaxInt64 { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt64) + } + + // 3. adjust limit + limit.Add(-1024) + time.Sleep(100 * time.Millisecond) + + curr = debug.SetMemoryLimit(-1) + if curr != math.MaxInt64-1024 { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt64-1024) + } + + // 4. no limit again (don't change the limit) + limit.Store(0) + time.Sleep(100 * time.Millisecond) + + curr = debug.SetMemoryLimit(-1) + if curr != math.MaxInt64-1024 { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt64-1024) + } + + // 5. new limit + limit.Store(math.MaxInt32) + time.Sleep(100 * time.Millisecond) + + curr = debug.SetMemoryLimit(-1) + if curr != math.MaxInt32 { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt32) + } +} diff --git a/memlimit/provider.go b/memlimit/provider.go index 32cc1ee..4f83770 100644 --- a/memlimit/provider.go +++ b/memlimit/provider.go @@ -16,6 +16,9 @@ func Limit(limit uint64) func() (uint64, error) { // ApplyRationA is a helper Provider function that applies the given ratio to the given provider. func ApplyRatio(provider Provider, ratio float64) Provider { + if ratio == 1 { + return provider + } return func() (uint64, error) { if ratio <= 0 || ratio > 1 { return 0, fmt.Errorf("invalid ratio: %f, ratio should be in the range (0.0,1.0]", ratio)