diff --git a/.golangci.yml b/.golangci.yml index ddeb1fd2..4767636a 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -102,7 +102,6 @@ linters: # # DEPRECATED by golangi-lint # - - execinquery - exportloopref # @@ -131,7 +130,6 @@ linters: # Recommended? (requires some work) # - - gomnd # An analyzer to detect magic numbers. - ireturn # Accept Interfaces, Return Concrete Types - mnd # An analyzer to detect magic numbers. - unparam # Reports unused function parameters diff --git a/pkg/pf/pf_context.go b/pkg/pf/pf_context.go index 11f1d2d2..36649b2c 100644 --- a/pkg/pf/pf_context.go +++ b/pkg/pf/pf_context.go @@ -55,9 +55,9 @@ func (ctx *pfContext) shutDown() error { return nil } -// getStateIPs returns a list of IPs that are currently in the state table. -func getStateIPs() (map[string]bool, error) { - ret := make(map[string]bool) +// getStatesToKill returns the states of the connections that must be terminated. +func getStatesToKill(banned map[string]struct{}) (map[string]map[string]struct{}, error) { + ret := make(map[string]map[string]struct{}) cmd := exec.Command(pfctlCmd, "-s", "states") @@ -73,24 +73,49 @@ func getStateIPs() (map[string]bool, error) { continue } - // right side - ip := fields[4] - if strings.Contains(ip, ":") { - ip = strings.Split(ip, ":")[0] + left := fields[2] + if strings.Contains(left, ":") { + left = strings.Split(left, ":")[0] } - ret[ip] = true + right := fields[4] + if strings.Contains(right, ":") { + right = strings.Split(right, ":")[0] + } + + // Don't know the direction, is left or right the origin of the connection? + // We either look at the arrow direction, or don't need to care and will treat both cases. + // + // The banned ip will be associated to an empty map (will call pfctl -k ) + // The other ip will be associated to a map where the keys are the banned ips with an existing connection. + // (i.e. pfctl -k -k ) so we don't have to terminate ALL connections from other_ip. + + var bannedIP, otherIP string - // left side - ip = fields[2] - if strings.Contains(ip, ":") { - ip = strings.Split(ip, ":")[0] + if _, ok := banned[left]; ok { + bannedIP = left + otherIP = right } - ret[ip] = true - } + if _, ok := banned[right]; ok { + bannedIP = right + otherIP = left + } + + if bannedIP == "" { + continue + } + + // will call "pfctl -k " + ret[bannedIP] = make(map[string]struct{}) - log.Tracef("Found IPs in state table: %v", len(ret)) + // will call "pfctl -k -k " + if _, ok := ret[otherIP]; !ok { + ret[otherIP] = make(map[string]struct{}) + } + + ret[otherIP][bannedIP] = struct{}{} + } return ret, nil } @@ -103,9 +128,9 @@ func (ctx *pfContext) add(decisions []*models.Decision) error { } } - bannedIPs := make(map[string]bool) + bannedIPs := make(map[string]struct{}) for _, d := range decisions { - bannedIPs[*d.Value] = true + bannedIPs[*d.Value] = struct{}{} } if len(bannedIPs) == 0 { @@ -115,16 +140,27 @@ func (ctx *pfContext) add(decisions []*models.Decision) error { log.Tracef("New banned IPs: %v", bannedIPs) - stateIPs, err := getStateIPs() + // Get the states of connections + // - from a banned IP + // - from any IP to a banned IP + + states, err := getStatesToKill(bannedIPs) if err != nil { return fmt.Errorf("error while getting state IPs: %w", err) } - // Reset the states of connections coming from an IP if it's both in stateIPs and bannedIPs + for source := range states { + targets := states[source] + if len(targets) == 0 { + cmd := execPfctl("", "-k", source) + if out, err := cmd.CombinedOutput(); err != nil { + log.Errorf("Error while flushing state (%s): %v --> %s", cmd, err, out) + } + continue + } - for ip := range bannedIPs { - if stateIPs[ip] { - cmd := execPfctl("", "-k", ip) + for target := range targets { + cmd := execPfctl("", "-k", source, "-k", target) if out, err := cmd.CombinedOutput(); err != nil { log.Errorf("Error while flushing state (%s): %v --> %s", cmd, err, out) }