From ae074408c2b553e5bc8c60422302f29068115aed Mon Sep 17 00:00:00 2001 From: Geon Kim Date: Sun, 8 Dec 2024 15:49:31 +0900 Subject: [PATCH] test(memlimit): add tests for new cgroup implementation --- .github/workflows/test.yml | 12 +- memlimit/cgroups.go | 400 +++++++++++++++++++++++++++++++ memlimit/cgroups_linux.go | 371 +--------------------------- memlimit/cgroups_linux_test.go | 70 ++++++ memlimit/cgroups_test.go | 272 +++++++++++++++++---- memlimit/memlimit_common_test.go | 299 ----------------------- memlimit/memlimit_linux_test.go | 234 ++++++++++++++++++ memlimit/memlimit_test.go | 301 +++++++++++++---------- 8 files changed, 1124 insertions(+), 835 deletions(-) create mode 100644 memlimit/cgroups_linux_test.go delete mode 100644 memlimit/memlimit_common_test.go create mode 100644 memlimit/memlimit_linux_test.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b805685..c019a7b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,15 +19,15 @@ jobs: - name: Run tests in Go container (1000m) run: | - docker run --rm -v=$(pwd):/app -w=/app -m=1000m golang:1.22 go test -v ./... -expected=1048576000 + docker run --rm -v=$(pwd):/app -w=/app -m=1000m golang:1.22 go test -v ./... -expected=1048576000 -cgroup-version 1 - name: Run tests in Go container (4321m) run: | - docker run --rm -v=$(pwd):/app -w=/app -m=4321m golang:1.22 go test -v ./... -expected=4530896896 + docker run --rm -v=$(pwd):/app -w=/app -m=4321m golang:1.22 go test -v ./... -expected=4530896896 -cgroup-version 1 - name: Run tests in Go container (system memory limit) run: | - docker run --rm -v=$(pwd):/app -w=/app golang:1.22 go test -v ./... -expected-system=$(($(awk '/MemTotal/ {print $2}' /proc/meminfo) * 1024)) + docker run --rm -v=$(pwd):/app -w=/app golang:1.22 go test -v ./... -expected-system=$(($(awk '/MemTotal/ {print $2}' /proc/meminfo) * 1024)) -cgroup-version 1 test-ubuntu-22_04: runs-on: ubuntu-22.04 @@ -45,12 +45,12 @@ jobs: - name: Run tests in Go container (1000m) run: | - docker run --rm -v=$(pwd):/app -w=/app -m=1000m golang:1.22 go test -v ./... -expected=1048576000 + docker run --rm -v=$(pwd):/app -w=/app -m=1000m golang:1.22 go test -v ./... -expected=1048576000 -cgroup-version 2 - name: Run tests in Go container (4321m) run: | - docker run --rm -v=$(pwd):/app -w=/app -m=4321m golang:1.22 go test -v ./... -expected=4530896896 + docker run --rm -v=$(pwd):/app -w=/app -m=4321m golang:1.22 go test -v ./... -expected=4530896896 -cgroup-version 2 - name: Run tests in Go container (system memory limit) run: | - docker run --rm -v=$(pwd):/app -w=/app golang:1.22 go test -v ./... -expected-system=$(($(awk '/MemTotal/ {print $2}' /proc/meminfo) * 1024)) + docker run --rm -v=$(pwd):/app -w=/app golang:1.22 go test -v ./... -expected-system=$(($(awk '/MemTotal/ {print $2}' /proc/meminfo) * 1024)) -cgroup-version 2 diff --git a/memlimit/cgroups.go b/memlimit/cgroups.go index 979bd39..4e27f5e 100644 --- a/memlimit/cgroups.go +++ b/memlimit/cgroups.go @@ -1,7 +1,16 @@ package memlimit import ( + "bufio" "errors" + "fmt" + "io" + "math" + "os" + "path/filepath" + "slices" + "strconv" + "strings" ) var ( @@ -10,3 +19,394 @@ var ( // ErrCgroupsNotSupported is returned when the system does not support cgroups. ErrCgroupsNotSupported = errors.New("cgroups is not supported on this system") ) + +// fromCgroup retrieves the memory limit from the cgroup. +// The versionDetector function is used to detect the cgroup version from the mountinfo. +func fromCgroup(versionDetector func(mis []mountInfo) (bool, bool)) (uint64, error) { + mf, err := os.Open("/proc/self/mountinfo") + if err != nil { + return 0, fmt.Errorf("failed to open /proc/self/mountinfo: %w", err) + } + defer mf.Close() + + mis, err := parseMountInfo(mf) + if err != nil { + return 0, fmt.Errorf("failed to parse mountinfo: %w", err) + } + + v1, v2 := versionDetector(mis) + if !(v1 || v2) { + return 0, ErrNoCgroup + } + + cf, err := os.Open("/proc/self/cgroup") + if err != nil { + return 0, fmt.Errorf("failed to open /proc/self/cgroup: %w", err) + } + defer cf.Close() + + chs, err := parseCgroupFile(cf) + if err != nil { + return 0, fmt.Errorf("failed to parse cgroup file: %w", err) + } + + if v2 { + limit, err := getMemoryLimitV2(chs, mis) + if err == nil { + return limit, nil + } else if !v1 { + return 0, err + } + } + + return getMemoryLimitV1(chs, mis) +} + +// detectCgroupVersion detects the cgroup version from the mountinfo. +func detectCgroupVersion(mis []mountInfo) (bool, bool) { + var v1, v2 bool + for _, mi := range mis { + switch mi.FilesystemType { + case "cgroup": + v1 = true + case "cgroup2": + v2 = true + } + } + return v1, v2 +} + +// getMemoryLimitV2 retrieves the memory limit from the cgroup v2 controller. +func getMemoryLimitV2(chs []cgroupHierarchy, mis []mountInfo) (uint64, error) { + // find the cgroup v2 path for the memory controller. + // in cgroup v2, the paths are unified and the controller list is empty. + idx := slices.IndexFunc(chs, func(ch cgroupHierarchy) bool { + return ch.HierarchyID == "0" && ch.ControllerList == "" + }) + if idx == -1 { + return 0, errors.New("cgroup v2 path not found") + } + relPath := chs[idx].CgroupPath + + // find the mountpoint for the cgroup v2 controller. + idx = slices.IndexFunc(mis, func(mi mountInfo) bool { + return mi.FilesystemType == "cgroup2" + }) + if idx == -1 { + return 0, errors.New("cgroup v2 mountpoint not found") + } + root, mountPoint := mis[idx].Root, mis[idx].MountPoint + + // resolve the actual cgroup path + cgroupPath, err := resolveCgroupPath(mountPoint, root, relPath) + if err != nil { + return 0, err + } + + // retrieve the memory limit from the memory.max file + return readMemoryLimitV2FromPath(filepath.Join(cgroupPath, "memory.max")) +} + +// readMemoryLimitV2FromPath reads the memory limit for cgroup v2 from the given path. +// this function expects the path to be memory.max file. +func readMemoryLimitV2FromPath(path string) (uint64, error) { + b, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return 0, ErrNoLimit + } + return 0, fmt.Errorf("failed to read memory.max: %w", err) + } + + slimit := strings.TrimSpace(string(b)) + if slimit == "max" { + return 0, ErrNoLimit + } + + limit, err := strconv.ParseUint(slimit, 10, 64) + if err != nil { + return 0, fmt.Errorf("failed to parse memory.max value: %w", err) + } + + return limit, nil +} + +// getMemoryLimitV1 retrieves the memory limit from the cgroup v1 controller. +func getMemoryLimitV1(chs []cgroupHierarchy, mis []mountInfo) (uint64, error) { + // find the cgroup v1 path for the memory controller. + idx := slices.IndexFunc(chs, func(ch cgroupHierarchy) bool { + return slices.Contains(strings.Split(ch.ControllerList, ","), "memory") + }) + if idx == -1 { + return 0, errors.New("cgroup v1 path for memory controller not found") + } + relPath := chs[idx].CgroupPath + + // find the mountpoint for the cgroup v1 controller. + idx = slices.IndexFunc(mis, func(mi mountInfo) bool { + return mi.FilesystemType == "cgroup" && slices.Contains(strings.Split(mi.SuperOptions, ","), "memory") + }) + if idx == -1 { + return 0, errors.New("cgroup v1 mountpoint for memory controller not found") + } + root, mountPoint := mis[idx].Root, mis[idx].MountPoint + + // resolve the actual cgroup path + cgroupPath, err := resolveCgroupPath(mountPoint, root, relPath) + if err != nil { + return 0, err + } + + // retrieve the memory limit from the memory.stats and memory.limit_in_bytes files. + return readMemoryLimitV1FromPath(cgroupPath) +} + +// getCgroupV1NoLimit returns the maximum value that is used to represent no limit in cgroup v1. +// the max memory limit is max int64, but it should be multiple of the page size. +func getCgroupV1NoLimit() uint64 { + ps := uint64(os.Getpagesize()) + return math.MaxInt64 / ps * ps +} + +// readMemoryLimitV1FromPath reads the memory limit for cgroup v1 from the given path. +// this function expects the path to be the cgroup directory. +func readMemoryLimitV1FromPath(cgroupPath string) (uint64, error) { + // read hierarchical_memory_limit and memory.limit_in_bytes files. + // but if hierarchical_memory_limit is not available, then use the max value as a fallback. + hml, err := readHierarchicalMemoryLimit(filepath.Join(cgroupPath, "memory.stats")) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return 0, fmt.Errorf("failed to read hierarchical_memory_limit: %w", err) + } else if hml == 0 { + hml = math.MaxUint64 + } + + // read memory.limit_in_bytes file. + b, err := os.ReadFile(filepath.Join(cgroupPath, "memory.limit_in_bytes")) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return 0, fmt.Errorf("failed to read memory.limit_in_bytes: %w", err) + } + lib, err := strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64) + if err != nil { + return 0, fmt.Errorf("failed to parse memory.limit_in_bytes value: %w", err) + } else if lib == 0 { + hml = math.MaxUint64 + } + + // use the minimum value between hierarchical_memory_limit and memory.limit_in_bytes. + // if the limit is the maximum value, then it is considered as no limit. + limit := min(hml, lib) + if limit >= getCgroupV1NoLimit() { + return 0, ErrNoLimit + } + + return limit, nil +} + +// readHierarchicalMemoryLimit extracts hierarchical_memory_limit from memory.stats. +// this function expects the path to be memory.stats file. +func readHierarchicalMemoryLimit(path string) (uint64, error) { + file, err := os.Open(path) + if err != nil { + return 0, err + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + + fields := strings.Split(line, " ") + if len(fields) < 2 { + return 0, fmt.Errorf("failed to parse memory.stat %q: not enough fields", line) + } + + if fields[0] == "hierarchical_memory_limit" { + if len(fields) > 2 { + return 0, fmt.Errorf("failed to parse memory.stat %q: too many fields for hierarchical_memory_limit", line) + } + return strconv.ParseUint(fields[1], 10, 64) + } + } + if err := scanner.Err(); err != nil { + return 0, err + } + + return 0, nil +} + +// https://www.man7.org/linux/man-pages/man5/proc_pid_mountinfo.5.html +// 731 771 0:59 /sysrq-trigger /proc/sysrq-trigger ro,nosuid,nodev,noexec,relatime - proc proc rw +// +// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue +// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11) +// +// (1) mount ID: a unique ID for the mount (may be reused after umount(2)). +// (2) parent ID: the ID of the parent mount (or of self for the root of this mount namespace's mount tree). +// (3) major:minor: the value of st_dev for files on this filesystem (see stat(2)). +// (4) root: the pathname of the directory in the filesystem which forms the root of this mount. +// (5) mount point: the pathname of the mount point relative to the process's root directory. +// (6) mount options: per-mount options (see mount(2)). +// (7) optional fields: zero or more fields of the form "tag[:value]"; see below. +// (8) separator: the end of the optional fields is marked by a single hyphen. +// (9) filesystem type: the filesystem type in the form "type[.subtype]". +// (10) mount source: filesystem-specific information or "none". +// (11) super options: per-superblock options (see mount(2)). +type mountInfo struct { + Root string + MountPoint string + FilesystemType string + SuperOptions string +} + +// parseMountInfoLine parses a line from the mountinfo file. +func parseMountInfoLine(line string) (mountInfo, error) { + if line == "" { + return mountInfo{}, errors.New("empty line") + } + + fieldss := strings.SplitN(line, " - ", 2) + if len(fieldss) != 2 { + return mountInfo{}, fmt.Errorf("invalid separator") + } + + fields1 := strings.Split(fieldss[0], " ") + if len(fields1) < 6 { + return mountInfo{}, fmt.Errorf("not enough fields before separator: %v", fields1) + } else if len(fields1) > 7 { + return mountInfo{}, fmt.Errorf("too many fields before separator: %v", fields1) + } else if len(fields1) == 6 { + fields1 = append(fields1, "") + } + + fields2 := strings.Split(fieldss[1], " ") + if len(fields2) < 3 { + return mountInfo{}, fmt.Errorf("not enough fields after separator: %v", fields2) + } else if len(fields2) > 3 { + return mountInfo{}, fmt.Errorf("too many fields after separator: %v", fields2) + } + + return mountInfo{ + Root: fields1[3], + MountPoint: fields1[4], + FilesystemType: fields2[0], + SuperOptions: fields2[2], + }, nil +} + +// parseMountInfo parses the mountinfo file. +func parseMountInfo(r io.Reader) ([]mountInfo, error) { + var ( + s = bufio.NewScanner(r) + mis []mountInfo + ) + for s.Scan() { + line := s.Text() + + mi, err := parseMountInfoLine(line) + if err != nil { + return nil, fmt.Errorf("failed to parse mountinfo file %q: %w", line, err) + } + + mis = append(mis, mi) + } + if err := s.Err(); err != nil { + return nil, err + } + + return mis, nil +} + +// https://www.man7.org/linux/man-pages/man7/cgroups.7.html +// +// 5:cpuacct,cpu,cpuset:/daemons +// (1) (2) (3) +// +// (1) hierarchy ID: +// +// cgroups version 1 hierarchies, this field +// contains a unique hierarchy ID number that can be +// matched to a hierarchy ID in /proc/cgroups. For the +// cgroups version 2 hierarchy, this field contains the +// value 0. +// +// (2) controller list: +// +// For cgroups version 1 hierarchies, this field +// contains a comma-separated list of the controllers +// bound to the hierarchy. For the cgroups version 2 +// hierarchy, this field is empty. +// +// (3) cgroup path: +// +// This field contains the pathname of the control group +// in the hierarchy to which the process belongs. This +// pathname is relative to the mount point of the +// hierarchy. +type cgroupHierarchy struct { + HierarchyID string + ControllerList string + CgroupPath string +} + +// parseCgroupHierarchyLine parses a line from the cgroup file. +func parseCgroupHierarchyLine(line string) (cgroupHierarchy, error) { + if line == "" { + return cgroupHierarchy{}, errors.New("empty line") + } + + fields := strings.Split(line, ":") + if len(fields) < 3 { + return cgroupHierarchy{}, fmt.Errorf("not enough fields: %v", fields) + } else if len(fields) > 3 { + return cgroupHierarchy{}, fmt.Errorf("too many fields: %v", fields) + } + + return cgroupHierarchy{ + HierarchyID: fields[0], + ControllerList: fields[1], + CgroupPath: fields[2], + }, nil +} + +// parseCgroupFile parses the cgroup file. +func parseCgroupFile(r io.Reader) ([]cgroupHierarchy, error) { + var ( + s = bufio.NewScanner(r) + chs []cgroupHierarchy + ) + for s.Scan() { + line := s.Text() + + ch, err := parseCgroupHierarchyLine(line) + if err != nil { + return nil, fmt.Errorf("failed to parse cgroup file %q: %w", line, err) + } + + chs = append(chs, ch) + } + if err := s.Err(); err != nil { + return nil, err + } + + return chs, nil +} + +// resolveCgroupPath resolves the actual cgroup path from the mountpoint, root, and cgroupRelPath. +func resolveCgroupPath(mountpoint, root, cgroupRelPath string) (string, error) { + rel, err := filepath.Rel(root, cgroupRelPath) + if err != nil { + return "", err + } + + // if the relative path is ".", then the cgroupRelPath is the root itself. + if rel == "." { + return mountpoint, nil + } + + // if the relative path starts with "..", then it is outside the root. + if strings.HasPrefix(rel, "..") { + return "", fmt.Errorf("invalid cgroup path: %s is not under root %s", cgroupRelPath, root) + } + + return filepath.Join(mountpoint, rel), nil +} diff --git a/memlimit/cgroups_linux.go b/memlimit/cgroups_linux.go index 8aa8cab..fd2c7e4 100644 --- a/memlimit/cgroups_linux.go +++ b/memlimit/cgroups_linux.go @@ -3,385 +3,30 @@ package memlimit -import ( - "bufio" - "errors" - "fmt" - "io" - "math" - "os" - "path/filepath" - "slices" - "strconv" - "strings" -) - -// GetMemoryLimit retrieves the memory limit for the current cgroup, supporting: -// - cgroup v1 -// - cgroup v2 -// - Hybrid mode (fallback to v1 if v2 fails) +// FromCgroup retrieves the memory limit from the cgroup. func FromCgroup() (uint64, error) { return fromCgroup(detectCgroupVersion) } +// FromCgroupV1 retrieves the memory limit from the cgroup v1 controller. +// After v1.0.0, this function could be removed and FromCgroup should be used instead. func FromCgroupV1() (uint64, error) { return fromCgroup(func(_ []mountInfo) (bool, bool) { return true, false }) } +// FromCgroupHybrid retrieves the memory limit from the cgroup v2 and v1 controller sequentially, +// basically, it is equivalent to FromCgroup. +// After v1.0.0, this function could be removed and FromCgroup should be used instead. func FromCgroupHybrid() (uint64, error) { return FromCgroup() } +// FromCgroupV2 retrieves the memory limit from the cgroup v2 controller. +// After v1.0.0, this function could be removed and FromCgroup should be used instead. func FromCgroupV2() (uint64, error) { return fromCgroup(func(_ []mountInfo) (bool, bool) { return false, true }) } - -func fromCgroup(versionDetector func(mis []mountInfo) (bool, bool)) (uint64, error) { - mf, err := os.Open("/proc/self/mountinfo") - if err != nil { - return 0, fmt.Errorf("failed to open /proc/self/mountinfo: %w", err) - } - defer mf.Close() - - mis, err := parseMountInfo(mf) - if err != nil { - return 0, fmt.Errorf("failed to parse mountinfo: %w", err) - } - - v1, v2 := versionDetector(mis) - if !(v1 || v2) { - return 0, ErrNoCgroup - } - - cf, err := os.Open("/proc/self/cgroup") - if err != nil { - return 0, fmt.Errorf("failed to open /proc/self/cgroup: %w", err) - } - defer cf.Close() - - chs, err := parseCgroupFile(cf) - if err != nil { - return 0, fmt.Errorf("failed to parse cgroup file: %w", err) - } - - if v2 { - limit, err := getMemoryLimitV2(chs, mis) - if err == nil { - return limit, nil - } else if !v1 { - return 0, err - } - } - - return getMemoryLimitV1(chs, mis) -} - -func detectCgroupVersion(mis []mountInfo) (bool, bool) { - var v1, v2 bool - for _, mi := range mis { - switch mi.FilesystemType { - case "cgroup": - v1 = true - case "cgroup2": - v2 = true - } - } - return v1, v2 -} - -// getMemoryLimitV2 retrieves the memory limit for cgroup v2. -func getMemoryLimitV2(chs []cgroupHierarchy, mis []mountInfo) (uint64, error) { - idx := slices.IndexFunc(chs, func(ch cgroupHierarchy) bool { - return ch.HierarchyID == "0" && ch.ControllerList == "" - }) - if idx == -1 { - return 0, errors.New("cgroup v2 path not found") - } - relPath := chs[idx].CgroupPath - - idx = slices.IndexFunc(mis, func(mi mountInfo) bool { - return mi.FilesystemType == "cgroup2" - }) - if idx == -1 { - return 0, errors.New("cgroup v2 mountpoint not found") - } - root, mountPoint := mis[idx].Root, mis[idx].MountPoint - - // Resolve the actual cgroup path - cgroupPath, err := resolveCgroupPath(mountPoint, root, relPath) - if err != nil { - return 0, err - } - - // Construct the path to memory.max - memoryMaxPath := filepath.Join(cgroupPath, "memory.max") - - // Read the memory limit from memory.max - return readMemoryLimitV2FromPath(memoryMaxPath) -} - -// getMemoryLimitV1 retrieves the memory limit for cgroup v1. -func getMemoryLimitV1(chs []cgroupHierarchy, mis []mountInfo) (uint64, error) { - idx := slices.IndexFunc(chs, func(ch cgroupHierarchy) bool { - return slices.Contains(strings.Split(ch.ControllerList, ","), "memory") - }) - if idx == -1 { - return 0, errors.New("cgroup v1 path for memory controller not found") - } - relPath := chs[idx].CgroupPath - - idx = slices.IndexFunc(mis, func(mi mountInfo) bool { - return mi.FilesystemType == "cgroup" && slices.Contains(strings.Split(mi.SuperOptions, ","), "memory") - }) - if idx == -1 { - return 0, errors.New("cgroup v1 mountpoint for memory controller not found") - } - root, mountPoint := mis[idx].Root, mis[idx].MountPoint - - // Resolve the actual cgroup path - cgroupPath, err := resolveCgroupPath(mountPoint, root, relPath) - if err != nil { - return 0, err - } - - // Retrieve the memory limit - return readMemoryLimitV1FromPath(cgroupPath) -} - -// readMemoryLimitV2FromPath reads the memory limit from the memory.max file for cgroup v2. -func readMemoryLimitV2FromPath(path string) (uint64, error) { - b, err := os.ReadFile(path) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return 0, ErrNoLimit - } - return 0, fmt.Errorf("failed to read memory.max: %w", err) - } - - slimit := strings.TrimSpace(string(b)) - if slimit == "max" { - return 0, ErrNoLimit - } - - limit, err := strconv.ParseUint(slimit, 10, 64) - if err != nil { - return 0, fmt.Errorf("failed to parse memory.max value: %w", err) - } - - return limit, nil -} - -func getCgroupV1NoLimit() uint64 { - ps := uint64(os.Getpagesize()) - return math.MaxInt64 / ps * ps -} - -// readMemoryLimitV1FromPath reads the memory limit for cgroup v1 from the given path. -func readMemoryLimitV1FromPath(cgroupPath string) (uint64, error) { - hml, err := readHierarchicalMemoryLimit(filepath.Join(cgroupPath, "memory.stats")) - if err != nil && !errors.Is(err, os.ErrNotExist) { - return 0, fmt.Errorf("failed to read hierarchical_memory_limit: %w", err) - } else if hml == 0 { - hml = math.MaxUint64 - } - - b, err := os.ReadFile(filepath.Join(cgroupPath, "memory.limit_in_bytes")) - if err != nil && !errors.Is(err, os.ErrNotExist) { - return 0, fmt.Errorf("failed to read memory.limit_in_bytes: %w", err) - } - lib, err := strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64) - if err != nil { - return 0, fmt.Errorf("failed to parse memory.limit_in_bytes value: %w", err) - } else if lib == 0 { - hml = math.MaxUint64 - } - - limit := min(hml, lib) - if limit >= getCgroupV1NoLimit() { - return 0, ErrNoLimit - } - - return limit, nil -} - -// readHierarchicalMemoryLimit extracts hierarchical_memory_limit from memory.stats for cgroup v1. -func readHierarchicalMemoryLimit(statPath string) (uint64, error) { - file, err := os.Open(statPath) - if err != nil { - return 0, err - } - defer file.Close() - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - - fields := strings.Split(line, " ") - if len(fields) < 2 { - return 0, fmt.Errorf("failed to parse memory.stat %q: not enough fields", line) - } - - if fields[0] == "hierarchical_memory_limit" { - if len(fields) > 2 { - return 0, fmt.Errorf("failed to parse memory.stat %q: too many fields for hierarchical_memory_limit", line) - } - return strconv.ParseUint(fields[1], 10, 64) - } - } - if err := scanner.Err(); err != nil { - return 0, err - } - - return 0, nil -} - -// https://www.man7.org/linux/man-pages/man5/proc_pid_mountinfo.5.html -// 731 771 0:59 /sysrq-trigger /proc/sysrq-trigger ro,nosuid,nodev,noexec,relatime - proc proc rw -// -// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue -// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11) -// -// (1) mount ID: a unique ID for the mount (may be reused after umount(2)). -// (2) parent ID: the ID of the parent mount (or of self for the root of this mount namespace's mount tree). -// (3) major:minor: the value of st_dev for files on this filesystem (see stat(2)). -// (4) root: the pathname of the directory in the filesystem which forms the root of this mount. -// (5) mount point: the pathname of the mount point relative to the process's root directory. -// (6) mount options: per-mount options (see mount(2)). -// (7) optional fields: zero or more fields of the form "tag[:value]"; see below. -// (8) separator: the end of the optional fields is marked by a single hyphen. -// (9) filesystem type: the filesystem type in the form "type[.subtype]". -// (10) mount source: filesystem-specific information or "none". -// (11) super options: per-superblock options (see mount(2)). -type mountInfo struct { - Root string - MountPoint string - FilesystemType string - SuperOptions string -} - -func parseMountInfo(r io.Reader) ([]mountInfo, error) { - var ( - s = bufio.NewScanner(r) - mis []mountInfo - ) - for s.Scan() { - line := s.Text() - - fieldss := strings.SplitN(line, " - ", 2) - if len(fieldss) != 2 { - return nil, fmt.Errorf("failed to parse mountinfo %q: invalid separator", line) - } - - fields1 := strings.Split(fieldss[0], " ") - if len(fields1) < 6 { - return nil, fmt.Errorf("failed to parse mountinfo %q: not enough fields1 %v", line, fields1) - } else if len(fields1) > 7 { - return nil, fmt.Errorf("failed to parse mountinfo %q: too many fields", line) - } else if len(fields1) == 6 { - fields1 = append(fields1, "") - } - - fields2 := strings.Split(fieldss[1], " ") - if len(fields2) < 3 { - return nil, fmt.Errorf("failed to parse mountinfo %q: not enough fields2 %v", line, fields2) - } else if len(fields2) > 3 { - return nil, fmt.Errorf("failed to parse mountinfo %q: too many fields", line) - } - - mis = append(mis, mountInfo{ - Root: fields1[3], - MountPoint: fields1[4], - FilesystemType: fields2[0], - SuperOptions: fields2[2], - }) - } - if err := s.Err(); err != nil { - return nil, err - } - - return mis, nil -} - -// https://www.man7.org/linux/man-pages/man7/cgroups.7.html -// -// 5:cpuacct,cpu,cpuset:/daemons -// (1) (2) (3) -// -// (1) hierarchy ID: -// -// cgroups version 1 hierarchies, this field -// contains a unique hierarchy ID number that can be -// matched to a hierarchy ID in /proc/cgroups. For the -// cgroups version 2 hierarchy, this field contains the -// value 0. -// -// (2) controller list: -// -// For cgroups version 1 hierarchies, this field -// contains a comma-separated list of the controllers -// bound to the hierarchy. For the cgroups version 2 -// hierarchy, this field is empty. -// -// (3) cgroup path: -// -// This field contains the pathname of the control group -// in the hierarchy to which the process belongs. This -// pathname is relative to the mount point of the -// hierarchy. -type cgroupHierarchy struct { - HierarchyID string - ControllerList string - CgroupPath string -} - -func parseCgroupFile(r io.Reader) ([]cgroupHierarchy, error) { - var ( - s = bufio.NewScanner(r) - chs []cgroupHierarchy - ) - for s.Scan() { - line := s.Text() - - fields := strings.Split(line, ":") - if len(fields) != 3 { - return nil, fmt.Errorf("failed to parse cgroup file %q: invalid separator", line) - } - - chs = append(chs, cgroupHierarchy{ - HierarchyID: fields[0], - ControllerList: fields[1], - CgroupPath: fields[2], - }) - } - if err := s.Err(); err != nil { - return nil, err - } - - return chs, nil -} - -func resolveCgroupPath(mountpoint, root, cgroupRelPath string) (string, error) { - root = filepath.Clean(strings.TrimPrefix(root, "/")) - cgroupRelPath = filepath.Clean(strings.TrimPrefix(cgroupRelPath, "/")) - - if root == cgroupRelPath || (root == "." && cgroupRelPath == ".") { - return mountpoint, nil - } - - if strings.HasPrefix(cgroupRelPath, root) { - relativePath := strings.TrimPrefix(cgroupRelPath, root) - finalPath := filepath.Join(mountpoint, relativePath) - - if _, err := os.Stat(finalPath); os.IsNotExist(err) { - return "", fmt.Errorf("resolved cgroup path does not exist: %s", finalPath) - } - - return finalPath, nil - } - - return "", fmt.Errorf("invalid cgroup path: %s is not under root %s", cgroupRelPath, root) -} diff --git a/memlimit/cgroups_linux_test.go b/memlimit/cgroups_linux_test.go new file mode 100644 index 0000000..4c61cfb --- /dev/null +++ b/memlimit/cgroups_linux_test.go @@ -0,0 +1,70 @@ +//go:build linux +// +build linux + +package memlimit + +import ( + "testing" +) + +func TestFromCgroup(t *testing.T) { + if expected == 0 { + t.Skip() + } + + limit, err := FromCgroup() + if cgVersion == 0 && err != ErrNoCgroup { + t.Fatalf("FromCgroup() error = %v, wantErr %v", err, ErrNoCgroup) + } + + if err != nil { + t.Fatalf("FromCgroup() error = %v, wantErr %v", err, nil) + } + if limit != expected { + t.Fatalf("FromCgroup() got = %v, want %v", limit, expected) + } +} + +func TestFromCgroupHybrid(t *testing.T) { + if expected == 0 { + t.Skip() + } + + limit, err := FromCgroupHybrid() + if cgVersion == 0 && err != ErrNoCgroup { + t.Fatalf("FromCgroup() error = %v, wantErr %v", err, ErrNoCgroup) + } + + if err != nil { + t.Fatalf("FromCgroup() error = %v, wantErr %v", err, nil) + } + if limit != expected { + t.Fatalf("FromCgroup() got = %v, want %v", limit, expected) + } +} + +func TestFromCgroupV1(t *testing.T) { + if expected == 0 || cgVersion != 1 { + t.Skip() + } + limit, err := FromCgroupV1() + if err != nil { + t.Fatalf("FromCgroupV1() error = %v, wantErr %v", err, nil) + } + if limit != expected { + t.Fatalf("FromCgroupV1() got = %v, want %v", limit, expected) + } +} + +func TestFromCgroupV2(t *testing.T) { + if expected == 0 || cgVersion != 2 { + t.Skip() + } + limit, err := FromCgroupV2() + if err != nil { + t.Fatalf("FromCgroupV2() error = %v, wantErr %v", err, nil) + } + if limit != expected { + t.Fatalf("FromCgroupV2() got = %v, want %v", limit, expected) + } +} diff --git a/memlimit/cgroups_test.go b/memlimit/cgroups_test.go index 537b764..56e3122 100644 --- a/memlimit/cgroups_test.go +++ b/memlimit/cgroups_test.go @@ -1,67 +1,245 @@ -//go:build linux -// +build linux - package memlimit import ( + "reflect" "testing" - - "github.com/containerd/cgroups/v3" ) -func TestFromCgroup(t *testing.T) { - if expected == 0 { - t.Skip() +func TestParseMountInfoLine(t *testing.T) { + tests := []struct { + name string + input string + want mountInfo + wantErr string + }{ + { + name: "valid line with optional field", + input: "36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue", + want: mountInfo{ + Root: "/mnt1", + MountPoint: "/mnt2", + FilesystemType: "ext3", + SuperOptions: "rw,errors=continue", + }, + }, + { + name: "valid line without optional field", + input: "731 771 0:59 /sysrq-trigger /proc/sysrq-trigger ro,nosuid,nodev,noexec,relatime - proc proc rw", + want: mountInfo{ + Root: "/sysrq-trigger", + MountPoint: "/proc/sysrq-trigger", + FilesystemType: "proc", + SuperOptions: "rw", + }, + }, + { + name: "valid line with minimal fields (no optional fields)", + input: "25 1 0:22 / /dev rw - devtmpfs udev rw", + want: mountInfo{ + Root: "/", + MountPoint: "/dev", + FilesystemType: "devtmpfs", + SuperOptions: "rw", + }, + }, + { + name: "no separator", + input: "36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 ext3 /dev/root rw,errors=continue", + wantErr: `invalid separator`, + }, + { + name: "not enough fields on left side", + input: "36 35 98:0 /mnt1 /mnt2 - ext3 /dev/root rw,errors=continue", + wantErr: `not enough fields before separator: [36 35 98:0 /mnt1 /mnt2]`, + }, + { + name: "not enough fields on right side", + input: "36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3", + wantErr: `not enough fields after separator: [ext3]`, + }, + { + name: "too many fields on left side", + input: "36 35 98:0 /mnt1 /mnt2 rw,noatime extra master:1 - ext3 /dev/root rw,errors=continue", + wantErr: `too many fields before separator: [36 35 98:0 /mnt1 /mnt2 rw,noatime extra master:1]`, + }, + { + name: "too many fields on right side", + input: "36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw extra", + wantErr: `too many fields after separator: [ext3 /dev/root rw extra]`, + }, + { + name: "empty line", + input: "", + wantErr: `empty line`, + }, + { + name: "6 fields on left side (no optional field), should add empty optional field", + input: "100 1 8:2 / /data rw - ext4 /dev/sda2 rw,relatime", + want: mountInfo{ + Root: "/", + MountPoint: "/data", + FilesystemType: "ext4", + SuperOptions: "rw,relatime", + }, + }, } - limit, err := FromCgroup() - if cgVersion == cgroups.Unavailable && err != ErrNoCgroup { - t.Fatalf("FromCgroup() error = %v, wantErr %v", err, ErrNoCgroup) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseMountInfoLine(tt.input) + if tt.wantErr != "" { + if err == nil { + t.Fatalf("expected an error containing %q, got nil", tt.wantErr) + } + if err.Error() != tt.wantErr { + t.Fatalf("expected error containing %q, got %q", tt.wantErr, err.Error()) + } + return + } - if err != nil { - t.Fatalf("FromCgroup() error = %v, wantErr %v", err, nil) - } - if limit != expected { - t.Fatalf("FromCgroup() got = %v, want %v", limit, expected) - } -} + if err != nil { + t.Fatalf("unexpected error: %v", err) + } -func TestFromCgroupV1(t *testing.T) { - if expected == 0 || cgVersion != cgroups.Legacy { - t.Skip() - } - limit, err := FromCgroupV1() - if err != nil { - t.Fatalf("FromCgroupV1() error = %v, wantErr %v", err, nil) - } - if limit != expected { - t.Fatalf("FromCgroupV1() got = %v, want %v", limit, expected) + if !reflect.DeepEqual(got, tt.want) { + t.Fatalf("expected %+v, got %+v", tt.want, got) + } + }) } } -func TestFromCgroupHybrid(t *testing.T) { - if expected == 0 || cgVersion != cgroups.Hybrid { - t.Skip() - } - limit, err := FromCgroupHybrid() - if err != nil { - t.Fatalf("FromCgroupHybrid() error = %v, wantErr %v", err, nil) +func TestParseCgroupHierarchyLine(t *testing.T) { + tests := []struct { + name string + input string + want cgroupHierarchy + wantErr string + }{ + { + name: "valid line with multiple controllers", + input: "5:cpuacct,cpu,cpuset:/daemons", + want: cgroupHierarchy{ + HierarchyID: "5", + ControllerList: "cpuacct,cpu,cpuset", + CgroupPath: "/daemons", + }, + }, + { + name: "valid line with no controllers (cgroup v2)", + input: "0::/system.slice/docker.service", + want: cgroupHierarchy{ + HierarchyID: "0", + ControllerList: "", + CgroupPath: "/system.slice/docker.service", + }, + }, + { + name: "invalid line - only two fields", + input: "5:cpuacct,cpu,cpuset", + wantErr: "not enough fields: [5 cpuacct,cpu,cpuset]", + }, + { + name: "invalid line - too many fields", + input: "5:cpuacct,cpu:cpuset:/daemons:extra", + wantErr: "too many fields: [5 cpuacct,cpu cpuset /daemons extra]", + }, + { + name: "empty line", + input: "", + wantErr: "empty line", + }, + { + name: "line with empty controller list but valid fields", + input: "2::/my_cgroup", + want: cgroupHierarchy{ + HierarchyID: "2", + ControllerList: "", + CgroupPath: "/my_cgroup", + }, + }, } - if limit != expected { - t.Fatalf("FromCgroupHybrid() got = %v, want %v", limit, expected) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseCgroupHierarchyLine(tt.input) + if tt.wantErr != "" { + if err == nil { + t.Fatalf("expected an error containing %q, got nil", tt.wantErr) + } + if err.Error() != tt.wantErr { + t.Fatalf("expected error containing %q, got %q", tt.wantErr, err.Error()) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !reflect.DeepEqual(got, tt.want) { + t.Fatalf("expected %+v, got %+v", tt.want, got) + } + }) } } -func TestFromCgroupV2(t *testing.T) { - if expected == 0 || cgVersion != cgroups.Unified { - t.Skip() +func TestResolveCgroupPath(t *testing.T) { + tests := []struct { + name string + mountpoint string + root string + cgroupRelPath string + want string + wantErr string + }{ + { + name: "exact match with both root and cgroupRelPath as '/'", + mountpoint: "/fake/mount", + root: "/", + cgroupRelPath: "/", + want: "/fake/mount", + }, + { + name: "exact match with a non-root path", + mountpoint: "/fake/mount", + root: "/container0", + cgroupRelPath: "/container0", + want: "/fake/mount", + }, + { + name: "valid subpath under root", + mountpoint: "/fake/mount", + root: "/container0", + cgroupRelPath: "/container0/group1", + want: "/fake/mount/group1", + }, + { + name: "invalid cgroup path outside root", + mountpoint: "/fake/mount", + root: "/container0", + cgroupRelPath: "/other_container", + wantErr: "invalid cgroup path: /other_container is not under root /container0", + }, } - limit, err := FromCgroupV2() - if err != nil { - t.Fatalf("FromCgroupV2() error = %v, wantErr %v", err, nil) - } - if limit != expected { - t.Fatalf("FromCgroupV2() got = %v, want %v", limit, expected) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := resolveCgroupPath(tt.mountpoint, tt.root, tt.cgroupRelPath) + if tt.wantErr != "" { + if err == nil { + t.Fatalf("expected an error containing %q, got nil", tt.wantErr) + } + if err.Error() != tt.wantErr { + t.Fatalf("expected error containing %q, got %q", tt.wantErr, err.Error()) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got != tt.want { + t.Fatalf("expected path %q, got %q", tt.want, got) + } + }) } } diff --git a/memlimit/memlimit_common_test.go b/memlimit/memlimit_common_test.go deleted file mode 100644 index e65aa00..0000000 --- a/memlimit/memlimit_common_test.go +++ /dev/null @@ -1,299 +0,0 @@ -package memlimit - -import ( - "fmt" - "math" - "runtime/debug" - "sync/atomic" - "testing" - "time" -) - -func TestLimit(t *testing.T) { - type args struct { - limit uint64 - } - tests := []struct { - name string - args args - want uint64 - wantErr error - }{ - { - name: "0bytes", - args: args{ - limit: 0, - }, - want: 0, - wantErr: nil, - }, - { - name: "1kib", - args: args{ - limit: 1024, - }, - want: 1024, - wantErr: nil, - }, - { - name: "1mib", - args: args{ - limit: 1024 * 1024, - }, - want: 1024 * 1024, - wantErr: nil, - }, - { - name: "1gib", - args: args{ - limit: 1024 * 1024 * 1024, - }, - want: 1024 * 1024 * 1024, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := Limit(tt.args.limit)() - if err != tt.wantErr { - t.Errorf("Limit() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("Limit() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestSetGoMemLimitWithProvider(t *testing.T) { - type args struct { - provider Provider - ratio float64 - } - tests := []struct { - name string - args args - want int64 - wantErr error - gomemlimit int64 - }{ - { - name: "Limit_0.5", - args: args{ - provider: Limit(1024 * 1024 * 1024), - ratio: 0.5, - }, - want: 536870912, - wantErr: nil, - gomemlimit: 536870912, - }, - { - name: "Limit_0.9", - args: args{ - provider: Limit(1024 * 1024 * 1024), - ratio: 0.9, - }, - want: 966367641, - wantErr: nil, - gomemlimit: 966367641, - }, - { - name: "Limit_0.9_math.MaxUint64", - args: args{ - provider: Limit(math.MaxUint64), - ratio: 0.9, - }, - want: math.MaxInt64, - wantErr: nil, - gomemlimit: math.MaxInt64, - }, - { - name: "Limit_0.9_math.MaxUint64", - args: args{ - provider: Limit(math.MaxUint64), - ratio: 0.9, - }, - want: math.MaxInt64, - wantErr: nil, - gomemlimit: math.MaxInt64, - }, - { - name: "Limit_0.45_math.MaxUint64", - args: args{ - provider: Limit(math.MaxUint64), - ratio: 0.45, - }, - want: 8301034833169298432, - wantErr: nil, - gomemlimit: 8301034833169298432, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Cleanup(func() { - debug.SetMemoryLimit(math.MaxInt64) - }) - got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) - if err != tt.wantErr { - t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) - } - if debug.SetMemoryLimit(-1) != tt.gomemlimit { - t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) - } - }) - } -} - -func TestSetGoMemLimitWithOpts(t *testing.T) { - tests := []struct { - name string - opts []Option - want int64 - wantErr error - gomemlimit int64 - }{ - { - name: "unknown error", - opts: []Option{ - WithProvider(func() (uint64, error) { - return 0, fmt.Errorf("unknown error") - }), - }, - want: 0, - wantErr: fmt.Errorf("failed to set GOMEMLIMIT: unknown error"), - gomemlimit: math.MaxInt64, - }, - { - name: "ErrNoLimit", - opts: []Option{ - WithProvider(func() (uint64, error) { - return 0, ErrNoLimit - }), - }, - want: 0, - wantErr: nil, - gomemlimit: math.MaxInt64, - }, - { - name: "wrapped ErrNoLimit", - opts: []Option{ - WithProvider(func() (uint64, error) { - return 0, fmt.Errorf("wrapped: %w", ErrNoLimit) - }), - }, - want: 0, - wantErr: nil, - gomemlimit: math.MaxInt64, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := SetGoMemLimitWithOpts(tt.opts...) - if tt.wantErr != nil && err.Error() != tt.wantErr.Error() { - t.Errorf("SetGoMemLimitWithOpts() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("SetGoMemLimitWithOpts() got = %v, want %v", got, tt.want) - } - if debug.SetMemoryLimit(-1) != tt.gomemlimit { - t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) - } - }) - } -} - -func TestSetGoMemLimitWithOpts_rollbackOnPanic(t *testing.T) { - t.Cleanup(func() { - debug.SetMemoryLimit(math.MaxInt64) - }) - - limit := int64(987654321) - _ = debug.SetMemoryLimit(987654321) - _, err := SetGoMemLimitWithOpts( - WithProvider(func() (uint64, error) { - debug.SetMemoryLimit(123456789) - panic("panic") - }), - WithRatio(1), - ) - if err == nil { - t.Error("SetGoMemLimtWithOpts() error = nil, want panic") - } - - curr := debug.SetMemoryLimit(-1) - if curr != limit { - t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, limit) - } -} - -func TestSetGoMemLimitWithOpts_WithRefreshInterval(t *testing.T) { - t.Cleanup(func() { - debug.SetMemoryLimit(math.MaxInt64) - }) - - var limit atomic.Int64 - output, err := SetGoMemLimitWithOpts( - WithProvider(func() (uint64, error) { - l := limit.Load() - if l == 0 { - return 0, ErrNoLimit - } - return uint64(l), nil - }), - WithRatio(1), - WithRefreshInterval(10*time.Millisecond), - ) - if err != nil { - t.Errorf("SetGoMemLimitWithOpts() error = %v", err) - } else if output != limit.Load() { - t.Errorf("SetGoMemLimitWithOpts() got = %v, want %v", output, limit.Load()) - } - - // 1. no limit - curr := debug.SetMemoryLimit(-1) - if curr != math.MaxInt64 { - t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, limit.Load()) - } - - // 2. max limit - limit.Add(math.MaxInt64) - time.Sleep(100 * time.Millisecond) - - curr = debug.SetMemoryLimit(-1) - if curr != math.MaxInt64 { - t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt64) - } - - // 3. adjust limit - limit.Add(-1024) - time.Sleep(100 * time.Millisecond) - - curr = debug.SetMemoryLimit(-1) - if curr != math.MaxInt64-1024 { - t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt64-1024) - } - - // 4. no limit again (don't change the limit) - limit.Store(0) - time.Sleep(100 * time.Millisecond) - - curr = debug.SetMemoryLimit(-1) - if curr != math.MaxInt64-1024 { - t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt64-1024) - } - - // 5. new limit - limit.Store(math.MaxInt32) - time.Sleep(100 * time.Millisecond) - - curr = debug.SetMemoryLimit(-1) - if curr != math.MaxInt32 { - t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt32) - } -} diff --git a/memlimit/memlimit_linux_test.go b/memlimit/memlimit_linux_test.go new file mode 100644 index 0000000..2aab0d5 --- /dev/null +++ b/memlimit/memlimit_linux_test.go @@ -0,0 +1,234 @@ +//go:build linux +// +build linux + +package memlimit + +import ( + "flag" + "log" + "math" + "os" + "runtime/debug" + "testing" +) + +var ( + cgVersion uint64 + expected uint64 + expectedSystem uint64 +) + +func TestMain(m *testing.M) { + flag.Uint64Var(&expected, "expected", 0, "Expected cgroup's memory limit") + flag.Uint64Var(&expectedSystem, "expected-system", 0, "Expected system memory limit") + flag.Uint64Var(&cgVersion, "cgroup-version", 0, "Cgroup version") + flag.Parse() + + os.Exit(m.Run()) +} + +func TestSetGoMemLimit(t *testing.T) { + type args struct { + ratio float64 + } + tests := []struct { + name string + args args + want int64 + wantErr error + gomemlimit int64 + skip bool + }{ + { + name: "0.5", + args: args{ + ratio: 0.5, + }, + want: int64(float64(expected) * 0.5), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.5), + skip: expected == 0 || cgVersion == 0, + }, + { + name: "0.9", + args: args{ + ratio: 0.9, + }, + want: int64(float64(expected) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.9), + skip: expected == 0 || cgVersion == 0, + }, + { + name: "Unavailable", + args: args{ + ratio: 0.9, + }, + want: 0, + wantErr: ErrCgroupsNotSupported, + gomemlimit: math.MaxInt64, + skip: cgVersion != 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skip { + t.Skip() + } + t.Cleanup(func() { + debug.SetMemoryLimit(math.MaxInt64) + }) + got, err := SetGoMemLimit(tt.args.ratio) + if err != tt.wantErr { + t.Errorf("SetGoMemLimit() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("SetGoMemLimit() got = %v, want %v", got, tt.want) + } + if debug.SetMemoryLimit(-1) != tt.gomemlimit { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) + } + }) + } +} + +func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { + type args struct { + provider Provider + ratio float64 + } + tests := []struct { + name string + args args + want int64 + wantErr error + gomemlimit int64 + skip bool + }{ + { + name: "FromCgroup", + args: args{ + provider: FromCgroup, + ratio: 0.9, + }, + want: int64(float64(expected) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.9), + skip: expected == 0 || cgVersion == 0, + }, + { + name: "FromCgroup_Unavaliable", + args: args{ + provider: FromCgroup, + ratio: 0.9, + }, + want: 0, + wantErr: ErrNoCgroup, + gomemlimit: math.MaxInt64, + skip: expected == 0 || cgVersion != 0, + }, + { + name: "FromCgroupV1", + args: args{ + provider: FromCgroupV1, + ratio: 0.9, + }, + want: int64(float64(expected) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.9), + skip: expected == 0 || cgVersion != 1, + }, + { + name: "FromCgroupHybrid", + args: args{ + provider: FromCgroupHybrid, + ratio: 0.9, + }, + want: int64(float64(expected) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.9), + skip: expected == 0 || cgVersion != 1, + }, + { + name: "FromCgroupV2", + args: args{ + provider: FromCgroupV2, + ratio: 0.9, + }, + want: int64(float64(expected) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expected) * 0.9), + skip: expected == 0 || cgVersion != 2, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skip { + t.Skip() + } + t.Cleanup(func() { + debug.SetMemoryLimit(math.MaxInt64) + }) + got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) + if err != tt.wantErr { + t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) + } + if debug.SetMemoryLimit(-1) != tt.gomemlimit { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) + } + }) + } +} + +func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { + type args struct { + provider Provider + ratio float64 + } + tests := []struct { + name string + args args + want int64 + wantErr error + gomemlimit int64 + skip bool + }{ + { + name: "FromSystem", + args: args{ + provider: FromSystem, + ratio: 0.9, + }, + want: int64(float64(expectedSystem) * 0.9), + wantErr: nil, + gomemlimit: int64(float64(expectedSystem) * 0.9), + skip: expectedSystem == 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skip { + t.Skip() + } + t.Cleanup(func() { + debug.SetMemoryLimit(math.MaxInt64) + }) + got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) + if err != tt.wantErr { + t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) + } + if debug.SetMemoryLimit(-1) != tt.gomemlimit { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) + } + }) + } +} diff --git a/memlimit/memlimit_test.go b/memlimit/memlimit_test.go index 21d7f22..e65aa00 100644 --- a/memlimit/memlimit_test.go +++ b/memlimit/memlimit_test.go @@ -1,103 +1,72 @@ -//go:build linux -// +build linux - package memlimit import ( - "flag" - "log" + "fmt" "math" - "os" "runtime/debug" + "sync/atomic" "testing" - - "github.com/containerd/cgroups/v3" -) - -var ( - cgVersion cgroups.CGMode - expected uint64 - expectedSystem uint64 + "time" ) -func TestMain(m *testing.M) { - flag.Uint64Var(&expected, "expected", 0, "Expected cgroup's memory limit") - flag.Uint64Var(&expectedSystem, "expected-system", 0, "Expected system memory limit") - flag.Parse() - - cgVersion = cgroups.Mode() - log.Println("Cgroups version:", cgVersion) - - os.Exit(m.Run()) -} - -func TestSetGoMemLimit(t *testing.T) { +func TestLimit(t *testing.T) { type args struct { - ratio float64 + limit uint64 } tests := []struct { - name string - args args - want int64 - wantErr error - gomemlimit int64 - skip bool + name string + args args + want uint64 + wantErr error }{ { - name: "0.5", + name: "0bytes", args: args{ - ratio: 0.5, + limit: 0, }, - want: int64(float64(expected) * 0.5), - wantErr: nil, - gomemlimit: int64(float64(expected) * 0.5), - skip: expected == 0 || cgVersion == cgroups.Unavailable, + want: 0, + wantErr: nil, }, { - name: "0.9", + name: "1kib", args: args{ - ratio: 0.9, + limit: 1024, }, - want: int64(float64(expected) * 0.9), - wantErr: nil, - gomemlimit: int64(float64(expected) * 0.9), - skip: expected == 0 || cgVersion == cgroups.Unavailable, + want: 1024, + wantErr: nil, }, { - name: "Unavailable", + name: "1mib", args: args{ - ratio: 0.9, + limit: 1024 * 1024, }, - want: 0, - wantErr: ErrCgroupsNotSupported, - gomemlimit: math.MaxInt64, - skip: cgVersion != cgroups.Unavailable, + want: 1024 * 1024, + wantErr: nil, + }, + { + name: "1gib", + args: args{ + limit: 1024 * 1024 * 1024, + }, + want: 1024 * 1024 * 1024, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.skip { - t.Skip() - } - t.Cleanup(func() { - debug.SetMemoryLimit(math.MaxInt64) - }) - got, err := SetGoMemLimit(tt.args.ratio) + got, err := Limit(tt.args.limit)() if err != tt.wantErr { - t.Errorf("SetGoMemLimit() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("Limit() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { - t.Errorf("SetGoMemLimit() got = %v, want %v", got, tt.want) - } - if debug.SetMemoryLimit(-1) != tt.gomemlimit { - t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) + t.Errorf("Limit() got = %v, want %v", got, tt.want) } }) } } -func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { +func TestSetGoMemLimitWithProvider(t *testing.T) { type args struct { provider Provider ratio float64 @@ -108,69 +77,60 @@ func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { want int64 wantErr error gomemlimit int64 - skip bool }{ { - name: "FromCgroup", + name: "Limit_0.5", args: args{ - provider: FromCgroup, - ratio: 0.9, + provider: Limit(1024 * 1024 * 1024), + ratio: 0.5, }, - want: int64(float64(expected) * 0.9), + want: 536870912, wantErr: nil, - gomemlimit: int64(float64(expected) * 0.9), - skip: expected == 0 || cgVersion == cgroups.Unavailable, + gomemlimit: 536870912, }, { - name: "FromCgroup_Unavaliable", + name: "Limit_0.9", args: args{ - provider: FromCgroup, + provider: Limit(1024 * 1024 * 1024), ratio: 0.9, }, - want: 0, - wantErr: ErrNoCgroup, - gomemlimit: math.MaxInt64, - skip: expected == 0 || cgVersion != cgroups.Unavailable, + want: 966367641, + wantErr: nil, + gomemlimit: 966367641, }, { - name: "FromCgroupV1", + name: "Limit_0.9_math.MaxUint64", args: args{ - provider: FromCgroupV1, + provider: Limit(math.MaxUint64), ratio: 0.9, }, - want: int64(float64(expected) * 0.9), + want: math.MaxInt64, wantErr: nil, - gomemlimit: int64(float64(expected) * 0.9), - skip: expected == 0 || cgVersion != cgroups.Legacy, + gomemlimit: math.MaxInt64, }, { - name: "FromCgroupHybrid", + name: "Limit_0.9_math.MaxUint64", args: args{ - provider: FromCgroupHybrid, + provider: Limit(math.MaxUint64), ratio: 0.9, }, - want: int64(float64(expected) * 0.9), + want: math.MaxInt64, wantErr: nil, - gomemlimit: int64(float64(expected) * 0.9), - skip: expected == 0 || cgVersion != cgroups.Hybrid, + gomemlimit: math.MaxInt64, }, { - name: "FromCgroupV2", + name: "Limit_0.45_math.MaxUint64", args: args{ - provider: FromCgroupV2, - ratio: 0.9, + provider: Limit(math.MaxUint64), + ratio: 0.45, }, - want: int64(float64(expected) * 0.9), + want: 8301034833169298432, wantErr: nil, - gomemlimit: int64(float64(expected) * 0.9), - skip: expected == 0 || cgVersion != cgroups.Unified, + gomemlimit: 8301034833169298432, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.skip { - t.Skip() - } t.Cleanup(func() { debug.SetMemoryLimit(math.MaxInt64) }) @@ -189,46 +149,57 @@ func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { } } -func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { - type args struct { - provider Provider - ratio float64 - } +func TestSetGoMemLimitWithOpts(t *testing.T) { tests := []struct { name string - args args + opts []Option want int64 wantErr error gomemlimit int64 - skip bool }{ { - name: "FromSystem", - args: args{ - provider: FromSystem, - ratio: 0.9, + name: "unknown error", + opts: []Option{ + WithProvider(func() (uint64, error) { + return 0, fmt.Errorf("unknown error") + }), }, - want: int64(float64(expectedSystem) * 0.9), + want: 0, + wantErr: fmt.Errorf("failed to set GOMEMLIMIT: unknown error"), + gomemlimit: math.MaxInt64, + }, + { + name: "ErrNoLimit", + opts: []Option{ + WithProvider(func() (uint64, error) { + return 0, ErrNoLimit + }), + }, + want: 0, wantErr: nil, - gomemlimit: int64(float64(expectedSystem) * 0.9), - skip: expectedSystem == 0, + gomemlimit: math.MaxInt64, + }, + { + name: "wrapped ErrNoLimit", + opts: []Option{ + WithProvider(func() (uint64, error) { + return 0, fmt.Errorf("wrapped: %w", ErrNoLimit) + }), + }, + want: 0, + wantErr: nil, + gomemlimit: math.MaxInt64, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.skip { - t.Skip() - } - t.Cleanup(func() { - debug.SetMemoryLimit(math.MaxInt64) - }) - got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) - if err != tt.wantErr { - t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) + got, err := SetGoMemLimitWithOpts(tt.opts...) + if tt.wantErr != nil && err.Error() != tt.wantErr.Error() { + t.Errorf("SetGoMemLimitWithOpts() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { - t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) + t.Errorf("SetGoMemLimitWithOpts() got = %v, want %v", got, tt.want) } if debug.SetMemoryLimit(-1) != tt.gomemlimit { t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) @@ -236,3 +207,93 @@ func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { }) } } + +func TestSetGoMemLimitWithOpts_rollbackOnPanic(t *testing.T) { + t.Cleanup(func() { + debug.SetMemoryLimit(math.MaxInt64) + }) + + limit := int64(987654321) + _ = debug.SetMemoryLimit(987654321) + _, err := SetGoMemLimitWithOpts( + WithProvider(func() (uint64, error) { + debug.SetMemoryLimit(123456789) + panic("panic") + }), + WithRatio(1), + ) + if err == nil { + t.Error("SetGoMemLimtWithOpts() error = nil, want panic") + } + + curr := debug.SetMemoryLimit(-1) + if curr != limit { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, limit) + } +} + +func TestSetGoMemLimitWithOpts_WithRefreshInterval(t *testing.T) { + t.Cleanup(func() { + debug.SetMemoryLimit(math.MaxInt64) + }) + + var limit atomic.Int64 + output, err := SetGoMemLimitWithOpts( + WithProvider(func() (uint64, error) { + l := limit.Load() + if l == 0 { + return 0, ErrNoLimit + } + return uint64(l), nil + }), + WithRatio(1), + WithRefreshInterval(10*time.Millisecond), + ) + if err != nil { + t.Errorf("SetGoMemLimitWithOpts() error = %v", err) + } else if output != limit.Load() { + t.Errorf("SetGoMemLimitWithOpts() got = %v, want %v", output, limit.Load()) + } + + // 1. no limit + curr := debug.SetMemoryLimit(-1) + if curr != math.MaxInt64 { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, limit.Load()) + } + + // 2. max limit + limit.Add(math.MaxInt64) + time.Sleep(100 * time.Millisecond) + + curr = debug.SetMemoryLimit(-1) + if curr != math.MaxInt64 { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt64) + } + + // 3. adjust limit + limit.Add(-1024) + time.Sleep(100 * time.Millisecond) + + curr = debug.SetMemoryLimit(-1) + if curr != math.MaxInt64-1024 { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt64-1024) + } + + // 4. no limit again (don't change the limit) + limit.Store(0) + time.Sleep(100 * time.Millisecond) + + curr = debug.SetMemoryLimit(-1) + if curr != math.MaxInt64-1024 { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt64-1024) + } + + // 5. new limit + limit.Store(math.MaxInt32) + time.Sleep(100 * time.Millisecond) + + curr = debug.SetMemoryLimit(-1) + if curr != math.MaxInt32 { + t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt32) + } +}