Skip to content

Commit

Permalink
temp
Browse files Browse the repository at this point in the history
  • Loading branch information
KimMachineGun committed Dec 8, 2024
1 parent 944c495 commit a2338a2
Show file tree
Hide file tree
Showing 7 changed files with 1,093 additions and 825 deletions.
382 changes: 382 additions & 0 deletions memlimit/cgroups.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
package memlimit

import (
"bufio"
"errors"
"fmt"
"io"
"math"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
)

var (
Expand All @@ -10,3 +19,376 @@ var (
// ErrCgroupsNotSupported is returned when the system does not support cgroups.
ErrCgroupsNotSupported = errors.New("cgroups is not supported on this system")
)

func fromCgroup(versionDetector func(mis []mountInfo) (bool, bool)) (uint64, error) {
mf, err := os.Open("/proc/self/mountinfo")
if err != nil {
return 0, fmt.Errorf("failed to open /proc/self/mountinfo: %w", err)
}
defer mf.Close()

mis, err := parseMountInfo(mf)
if err != nil {
return 0, fmt.Errorf("failed to parse mountinfo: %w", err)
}

v1, v2 := versionDetector(mis)
if !(v1 || v2) {
return 0, ErrNoCgroup
}

cf, err := os.Open("/proc/self/cgroup")
if err != nil {
return 0, fmt.Errorf("failed to open /proc/self/cgroup: %w", err)
}
defer cf.Close()

chs, err := parseCgroupFile(cf)
if err != nil {
return 0, fmt.Errorf("failed to parse cgroup file: %w", err)
}

if v2 {
limit, err := getMemoryLimitV2(chs, mis)
if err == nil {
return limit, nil
} else if !v1 {
return 0, err
}
}

return getMemoryLimitV1(chs, mis)
}

func detectCgroupVersion(mis []mountInfo) (bool, bool) {
var v1, v2 bool
for _, mi := range mis {
switch mi.FilesystemType {
case "cgroup":
v1 = true
case "cgroup2":
v2 = true
}
}
return v1, v2
}

// getMemoryLimitV2 retrieves the memory limit for cgroup v2.
func getMemoryLimitV2(chs []cgroupHierarchy, mis []mountInfo) (uint64, error) {
idx := slices.IndexFunc(chs, func(ch cgroupHierarchy) bool {
return ch.HierarchyID == "0" && ch.ControllerList == ""
})
if idx == -1 {
return 0, errors.New("cgroup v2 path not found")
}
relPath := chs[idx].CgroupPath

idx = slices.IndexFunc(mis, func(mi mountInfo) bool {
return mi.FilesystemType == "cgroup2"
})
if idx == -1 {
return 0, errors.New("cgroup v2 mountpoint not found")
}
root, mountPoint := mis[idx].Root, mis[idx].MountPoint

// Resolve the actual cgroup path
cgroupPath, err := resolveCgroupPath(mountPoint, root, relPath)
if err != nil {
return 0, err
}

// Construct the path to memory.max
memoryMaxPath := filepath.Join(cgroupPath, "memory.max")

// Read the memory limit from memory.max
return readMemoryLimitV2FromPath(memoryMaxPath)
}

// getMemoryLimitV1 retrieves the memory limit for cgroup v1.
func getMemoryLimitV1(chs []cgroupHierarchy, mis []mountInfo) (uint64, error) {
idx := slices.IndexFunc(chs, func(ch cgroupHierarchy) bool {
return slices.Contains(strings.Split(ch.ControllerList, ","), "memory")
})
if idx == -1 {
return 0, errors.New("cgroup v1 path for memory controller not found")
}
relPath := chs[idx].CgroupPath

idx = slices.IndexFunc(mis, func(mi mountInfo) bool {
return mi.FilesystemType == "cgroup" && slices.Contains(strings.Split(mi.SuperOptions, ","), "memory")
})
if idx == -1 {
return 0, errors.New("cgroup v1 mountpoint for memory controller not found")
}
root, mountPoint := mis[idx].Root, mis[idx].MountPoint

// Resolve the actual cgroup path
cgroupPath, err := resolveCgroupPath(mountPoint, root, relPath)
if err != nil {
return 0, err
}

// Retrieve the memory limit
return readMemoryLimitV1FromPath(cgroupPath)
}

// readMemoryLimitV2FromPath reads the memory limit from the memory.max file for cgroup v2.
func readMemoryLimitV2FromPath(path string) (uint64, error) {
b, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return 0, ErrNoLimit
}
return 0, fmt.Errorf("failed to read memory.max: %w", err)
}

slimit := strings.TrimSpace(string(b))
if slimit == "max" {
return 0, ErrNoLimit
}

limit, err := strconv.ParseUint(slimit, 10, 64)
if err != nil {
return 0, fmt.Errorf("failed to parse memory.max value: %w", err)
}

return limit, nil
}

func getCgroupV1NoLimit() uint64 {
ps := uint64(os.Getpagesize())
return math.MaxInt64 / ps * ps
}

// readMemoryLimitV1FromPath reads the memory limit for cgroup v1 from the given path.
func readMemoryLimitV1FromPath(cgroupPath string) (uint64, error) {
hml, err := readHierarchicalMemoryLimit(filepath.Join(cgroupPath, "memory.stats"))
if err != nil && !errors.Is(err, os.ErrNotExist) {
return 0, fmt.Errorf("failed to read hierarchical_memory_limit: %w", err)
} else if hml == 0 {
hml = math.MaxUint64
}

b, err := os.ReadFile(filepath.Join(cgroupPath, "memory.limit_in_bytes"))
if err != nil && !errors.Is(err, os.ErrNotExist) {
return 0, fmt.Errorf("failed to read memory.limit_in_bytes: %w", err)
}
lib, err := strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64)
if err != nil {
return 0, fmt.Errorf("failed to parse memory.limit_in_bytes value: %w", err)
} else if lib == 0 {
hml = math.MaxUint64
}

limit := min(hml, lib)
if limit >= getCgroupV1NoLimit() {
return 0, ErrNoLimit
}

return limit, nil
}

// readHierarchicalMemoryLimit extracts hierarchical_memory_limit from memory.stats for cgroup v1.
func readHierarchicalMemoryLimit(statPath string) (uint64, error) {
file, err := os.Open(statPath)
if err != nil {
return 0, err
}
defer file.Close()

scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()

fields := strings.Split(line, " ")
if len(fields) < 2 {
return 0, fmt.Errorf("failed to parse memory.stat %q: not enough fields", line)
}

if fields[0] == "hierarchical_memory_limit" {
if len(fields) > 2 {
return 0, fmt.Errorf("failed to parse memory.stat %q: too many fields for hierarchical_memory_limit", line)
}
return strconv.ParseUint(fields[1], 10, 64)
}
}
if err := scanner.Err(); err != nil {
return 0, err
}

return 0, nil
}

// https://www.man7.org/linux/man-pages/man5/proc_pid_mountinfo.5.html
// 731 771 0:59 /sysrq-trigger /proc/sysrq-trigger ro,nosuid,nodev,noexec,relatime - proc proc rw
//
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
//
// (1) mount ID: a unique ID for the mount (may be reused after umount(2)).
// (2) parent ID: the ID of the parent mount (or of self for the root of this mount namespace's mount tree).
// (3) major:minor: the value of st_dev for files on this filesystem (see stat(2)).
// (4) root: the pathname of the directory in the filesystem which forms the root of this mount.
// (5) mount point: the pathname of the mount point relative to the process's root directory.
// (6) mount options: per-mount options (see mount(2)).
// (7) optional fields: zero or more fields of the form "tag[:value]"; see below.
// (8) separator: the end of the optional fields is marked by a single hyphen.
// (9) filesystem type: the filesystem type in the form "type[.subtype]".
// (10) mount source: filesystem-specific information or "none".
// (11) super options: per-superblock options (see mount(2)).
type mountInfo struct {
Root string
MountPoint string
FilesystemType string
SuperOptions string
}

func parseMountInfoLine(line string) (mountInfo, error) {
if line == "" {
return mountInfo{}, errors.New("empty line")
}

fieldss := strings.SplitN(line, " - ", 2)
if len(fieldss) != 2 {
return mountInfo{}, fmt.Errorf("invalid separator")
}

fields1 := strings.Split(fieldss[0], " ")
if len(fields1) < 6 {
return mountInfo{}, fmt.Errorf("not enough fields before separator: %v", fields1)
} else if len(fields1) > 7 {
return mountInfo{}, fmt.Errorf("too many fields before separator: %v", fields1)
} else if len(fields1) == 6 {
fields1 = append(fields1, "")
}

fields2 := strings.Split(fieldss[1], " ")
if len(fields2) < 3 {
return mountInfo{}, fmt.Errorf("not enough fields after separator: %v", fields2)
} else if len(fields2) > 3 {
return mountInfo{}, fmt.Errorf("too many fields after separator: %v", fields2)
}

return mountInfo{
Root: fields1[3],
MountPoint: fields1[4],
FilesystemType: fields2[0],
SuperOptions: fields2[2],
}, nil
}

func parseMountInfo(r io.Reader) ([]mountInfo, error) {
var (
s = bufio.NewScanner(r)
mis []mountInfo
)
for s.Scan() {
line := s.Text()

mi, err := parseMountInfoLine(line)
if err != nil {
return nil, fmt.Errorf("failed to parse mountinfo file %q: %w", line, err)
}

mis = append(mis, mi)
}
if err := s.Err(); err != nil {
return nil, err
}

return mis, nil
}

// https://www.man7.org/linux/man-pages/man7/cgroups.7.html
//
// 5:cpuacct,cpu,cpuset:/daemons
// (1) (2) (3)
//
// (1) hierarchy ID:
//
// cgroups version 1 hierarchies, this field
// contains a unique hierarchy ID number that can be
// matched to a hierarchy ID in /proc/cgroups. For the
// cgroups version 2 hierarchy, this field contains the
// value 0.
//
// (2) controller list:
//
// For cgroups version 1 hierarchies, this field
// contains a comma-separated list of the controllers
// bound to the hierarchy. For the cgroups version 2
// hierarchy, this field is empty.
//
// (3) cgroup path:
//
// This field contains the pathname of the control group
// in the hierarchy to which the process belongs. This
// pathname is relative to the mount point of the
// hierarchy.
type cgroupHierarchy struct {
HierarchyID string
ControllerList string
CgroupPath string
}

func parseCgroupHierarchyLine(line string) (cgroupHierarchy, error) {
if line == "" {
return cgroupHierarchy{}, errors.New("empty line")
}

fields := strings.Split(line, ":")
if len(fields) < 3 {
return cgroupHierarchy{}, fmt.Errorf("not enough fields: %v", fields)
} else if len(fields) > 3 {
return cgroupHierarchy{}, fmt.Errorf("too many fields: %v", fields)
}

return cgroupHierarchy{
HierarchyID: fields[0],
ControllerList: fields[1],
CgroupPath: fields[2],
}, nil
}

func parseCgroupFile(r io.Reader) ([]cgroupHierarchy, error) {
var (
s = bufio.NewScanner(r)
chs []cgroupHierarchy
)
for s.Scan() {
line := s.Text()

ch, err := parseCgroupHierarchyLine(line)
if err != nil {
return nil, fmt.Errorf("failed to parse cgroup file %q: %w", line, err)
}

chs = append(chs, ch)
}
if err := s.Err(); err != nil {
return nil, err
}

return chs, nil
}

// resolveCgroupPath resolves the actual cgroup path from the mountpoint, root, and cgroupRelPath.
func resolveCgroupPath(mountpoint, root, cgroupRelPath string) (string, error) {
rel, err := filepath.Rel(root, cgroupRelPath)
if err != nil {
return "", err
}

// if the relative path is ".", then the cgroupRelPath is the root itself.
if rel == "." {
return mountpoint, nil
}

// if the relative path starts with "..", then it is outside the root.
if strings.HasPrefix(rel, "..") {
return "", fmt.Errorf("invalid cgroup path: %s is not under root %s", cgroupRelPath, root)
}

return filepath.Join(mountpoint, rel), nil

}
Loading

0 comments on commit a2338a2

Please sign in to comment.