Skip to content

Commit

Permalink
pf: properly flush outgoing connections when adding a banned ip (#389)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc authored Dec 30, 2024
1 parent 4b99c16 commit ca4cc00
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: v1.61
version: v1.62
args: --issues-exit-code=1 --timeout 10m
only-new-issues: false

Expand Down
2 changes: 0 additions & 2 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ linters:
#
# DEPRECATED by golangi-lint
#
- execinquery
- exportloopref

#
Expand Down Expand Up @@ -131,7 +130,6 @@ linters:
# Recommended? (requires some work)
#

- gomnd # An analyzer to detect magic numbers.
- ireturn # Accept Interfaces, Return Concrete Types
- mnd # An analyzer to detect magic numbers.
- unparam # Reports unused function parameters
Expand Down
80 changes: 58 additions & 22 deletions pkg/pf/pf_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ func (ctx *pfContext) shutDown() error {
return nil
}

// getStateIPs returns a list of IPs that are currently in the state table.
func getStateIPs() (map[string]bool, error) {
ret := make(map[string]bool)
// getStatesToKill returns the states of the connections that must be terminated.
func getStatesToKill(banned map[string]struct{}) (map[string]map[string]struct{}, error) {
ret := make(map[string]map[string]struct{})

cmd := exec.Command(pfctlCmd, "-s", "states")

Expand All @@ -73,24 +73,49 @@ func getStateIPs() (map[string]bool, error) {
continue
}

// right side
ip := fields[4]
if strings.Contains(ip, ":") {
ip = strings.Split(ip, ":")[0]
left := fields[2]
if strings.Contains(left, ":") {
left = strings.Split(left, ":")[0]
}

ret[ip] = true
right := fields[4]
if strings.Contains(right, ":") {
right = strings.Split(right, ":")[0]
}

// Don't know the direction, is left or right the origin of the connection?
// We either look at the arrow direction, or don't need to care and will treat both cases.
//
// The banned ip will be associated to an empty map (will call pfctl -k <banned_ip>)
// The other ip will be associated to a map where the keys are the banned ips with an existing connection.
// (i.e. pfctl -k <other_ip> -k <banned_ip>) so we don't have to terminate ALL connections from other_ip.

var bannedIP, otherIP string

// left side
ip = fields[2]
if strings.Contains(ip, ":") {
ip = strings.Split(ip, ":")[0]
if _, ok := banned[left]; ok {
bannedIP = left
otherIP = right
}

ret[ip] = true
}
if _, ok := banned[right]; ok {
bannedIP = right
otherIP = left
}

if bannedIP == "" {
continue
}

// will call "pfctl -k <banned_ip>"
ret[bannedIP] = make(map[string]struct{})

log.Tracef("Found IPs in state table: %v", len(ret))
// will call "pfctl -k <other_ip> -k <banned_ip>"
if _, ok := ret[otherIP]; !ok {
ret[otherIP] = make(map[string]struct{})
}

ret[otherIP][bannedIP] = struct{}{}
}

return ret, nil
}
Expand All @@ -103,9 +128,9 @@ func (ctx *pfContext) add(decisions []*models.Decision) error {
}
}

bannedIPs := make(map[string]bool)
bannedIPs := make(map[string]struct{})
for _, d := range decisions {
bannedIPs[*d.Value] = true
bannedIPs[*d.Value] = struct{}{}
}

if len(bannedIPs) == 0 {
Expand All @@ -115,16 +140,27 @@ func (ctx *pfContext) add(decisions []*models.Decision) error {

log.Tracef("New banned IPs: %v", bannedIPs)

stateIPs, err := getStateIPs()
// Get the states of connections
// - from a banned IP
// - from any IP to a banned IP

states, err := getStatesToKill(bannedIPs)
if err != nil {
return fmt.Errorf("error while getting state IPs: %w", err)
}

// Reset the states of connections coming from an IP if it's both in stateIPs and bannedIPs
for source := range states {
targets := states[source]
if len(targets) == 0 {
cmd := execPfctl("", "-k", source)
if out, err := cmd.CombinedOutput(); err != nil {
log.Errorf("Error while flushing state (%s): %v --> %s", cmd, err, out)
}
continue
}

for ip := range bannedIPs {
if stateIPs[ip] {
cmd := execPfctl("", "-k", ip)
for target := range targets {
cmd := execPfctl("", "-k", source, "-k", target)
if out, err := cmd.CombinedOutput(); err != nil {
log.Errorf("Error while flushing state (%s): %v --> %s", cmd, err, out)
}
Expand Down

0 comments on commit ca4cc00

Please sign in to comment.