Skip to content

Commit

Permalink
Refactor environment parsing
Browse files Browse the repository at this point in the history
Systemd environment variables LISTEN* are unset by default and saved for
future calls
  • Loading branch information
balki committed May 1, 2023
1 parent bbd21d7 commit bbef4be
Showing 1 changed file with 55 additions and 27 deletions.
82 changes: 55 additions & 27 deletions anyhttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os"
"strconv"
"strings"
"sync"
"syscall"
)

Expand Down Expand Up @@ -39,6 +40,39 @@ func NewUnixSocketConfig(socketPath string) UnixSocketConfig {
return usc
}

type sysdEnvData struct {
pid int
fdNames []string
fdNamesStr string
numFds int
}

var sysdEnvParser = struct {
sysdOnce sync.Once
data sysdEnvData
err error
}{}

func parse() (sysdEnvData, error) {
p := &sysdEnvParser
p.sysdOnce.Do(func() {
p.data.pid, p.err = strconv.Atoi(os.Getenv("LISTEN_PID"))
if p.err != nil {
p.err = fmt.Errorf("invalid LISTEN_PID, err: %w", p.err)
return
}
p.data.numFds, p.err = strconv.Atoi(os.Getenv("LISTEN_FDS"))
if p.err != nil {
p.err = fmt.Errorf("invalid LISTEN_FDS, err: %w", p.err)
return
}
p.data.fdNamesStr = os.Getenv("LISTEN_FDNAMES")
p.data.fdNames = strings.Split(p.data.fdNamesStr, ":")

})
return p.data, p.err
}

// SysdConfig has the configuration for the socket activated fd
type SysdConfig struct {
// Integer value starting at 0. Either index or name is required
Expand All @@ -54,7 +88,7 @@ type SysdConfig struct {
// DefaultSysdConfig has the default values for SysdConfig
var DefaultSysdConfig = SysdConfig{
CheckPID: true,
UnsetEnv: false,
UnsetEnv: true,
}

// NewSysDConfigWithFDIdx creates SysdConfig with defaults and fdIdx
Expand Down Expand Up @@ -112,53 +146,47 @@ func (s *SysdConfig) GetListener() (net.Listener, error) {
defer UnsetSystemdListenVars()
}

if s.CheckPID {
pid, err := strconv.Atoi(os.Getenv("LISTEN_PID"))
if err != nil {
return nil, fmt.Errorf("invalid LISTEN_PID, err: %w", err)
}
if pid != os.Getpid() {
return nil, fmt.Errorf("unexpected PID, current:%v, LISTEN_PID: %v", os.Getpid(), pid)
}
}

numFds, err := strconv.Atoi(os.Getenv("LISTEN_FDS"))
envData, err := parse()
if err != nil {
return nil, fmt.Errorf("invalid LISTEN_FDS, err: %w", err)
return nil, err
}

fdNames := strings.Split(os.Getenv("LISTEN_FDNAMES"), ":")
if s.CheckPID {
if envData.pid != os.Getpid() {
return nil, fmt.Errorf("unexpected PID, current:%v, LISTEN_PID: %v", os.Getpid(), envData.pid)
}
}

if s.FDIndex != nil {
idx := *s.FDIndex
if idx < 0 || idx >= numFds {
return nil, fmt.Errorf("invalid fd index, expected between 0 and %v, got: %v", numFds, idx)
if idx < 0 || idx >= envData.numFds {
return nil, fmt.Errorf("invalid fd index, expected between 0 and %v, got: %v", envData.numFds, idx)
}
fd := StartFD + idx
if idx < len(fdNames) {
return makeFdListener(fd, fdNames[idx])
if idx < len(envData.fdNames) {
return makeFdListener(fd, envData.fdNames[idx])
}
return makeFdListener(fd, fmt.Sprintf("sysdfd_%d", fd))
}

if s.FDName != nil {
for idx, name := range fdNames {
for idx, name := range envData.fdNames {
if name == *s.FDName {
fd := StartFD + idx
return makeFdListener(fd, name)
}
}
return nil, fmt.Errorf("fdName not found: %q, LISTEN_FDNAMES:%q", *s.FDName, os.Getenv("LISTEN_FDNAMES"))
return nil, fmt.Errorf("fdName not found: %q, LISTEN_FDNAMES:%q", *s.FDName, envData.fdNamesStr)
}

return nil, errors.New("neither FDIndex nor FDName set")
}

// UnknownAddress Error is returned when address does not match any known syntax
type UnknownAddress string
type UnknownAddress struct{}

func (u UnknownAddress) Error() string {
return fmt.Sprintf("unknown address: %q", string(u))
return "unknown address"
}

// GetListener gets a unix or systemd socket listener
Expand All @@ -182,15 +210,15 @@ func GetListener(addr string) (net.Listener, error) {
return sysdc.GetListener()
}

return nil, UnknownAddress(addr)
return nil, UnknownAddress{}
}

// ListenAndServe is the drop-in replacement for `http.ListenAndServe`.
// Supports unix and systemd sockets in addition
func ListenAndServe(addr string, h http.Handler) error {

listener, err := GetListener(addr)
if _, ok := err.(UnknownAddress); err != nil && !ok {
if _, isUnknown := err.(UnknownAddress); err != nil && !isUnknown {
return err
}

Expand All @@ -211,7 +239,7 @@ func ListenAndServe(addr string, h http.Handler) error {

// UnsetSystemdListenVars unsets the LISTEN* environment variables so they are not passed to any child processes
func UnsetSystemdListenVars() {
os.Unsetenv("LISTEN_PID")
os.Unsetenv("LISTEN_FDS")
os.Unsetenv("LISTEN_FDNAMES")
_ = os.Unsetenv("LISTEN_PID")
_ = os.Unsetenv("LISTEN_FDS")
_ = os.Unsetenv("LISTEN_FDNAMES")
}

0 comments on commit bbef4be

Please sign in to comment.