From 1c713a259b4c515425297714a5a0f5588a0cf8d8 Mon Sep 17 00:00:00 2001 From: nhas Date: Wed, 30 Oct 2024 10:37:40 +1300 Subject: [PATCH] Fix isAuthed issues --- internal/router/firewall.go | 6 +----- internal/router/firewall_test.go | 16 +--------------- internal/router/init.go | 8 ++++++++ internal/routetypes/policy.go | 7 +------ 4 files changed, 11 insertions(+), 26 deletions(-) diff --git a/internal/router/firewall.go b/internal/router/firewall.go index 84b6196c..61e939ac 100644 --- a/internal/router/firewall.go +++ b/internal/router/firewall.go @@ -433,11 +433,6 @@ func (f *Firewall) IsAuthed(address string) bool { func (f *Firewall) isAuthed(addr netip.Addr) bool { - ok := f.userIsLocked[addr.String()] - if !ok { - return false - } - device, ok := f.addressToDevice[addr] if !ok { return false @@ -449,6 +444,7 @@ func (f *Firewall) isAuthed(addr netip.Addr) bool { // If the device has been inactive if device.lastPacketTime.Add(f.inactivityTimeout).Before(time.Now()) { + return false } diff --git a/internal/router/firewall_test.go b/internal/router/firewall_test.go index cd98f5f0..3002b2ea 100644 --- a/internal/router/firewall_test.go +++ b/internal/router/firewall_test.go @@ -66,7 +66,7 @@ func TestAddNewDevices(t *testing.T) { for address, device := range testFw.addressToDevice { - if device.lastPacketTime.IsZero() || device.sessionExpiry.IsZero() { + if !device.lastPacketTime.IsZero() || !device.sessionExpiry.IsZero() { t.Fatal("timers were not 0 immediately after device add") } found[address.String()] = true @@ -787,9 +787,6 @@ func TestPortRestrictions(t *testing.T) { for _, rule := range rules { for _, policy := range rule.Values { - if policy.Is(routetypes.STOP) { - break - } // If we've got an any single port rule e.g 55/any, make sure that the proto is something that has ports otherwise the test fails successProto := policy.Proto @@ -895,9 +892,6 @@ func TestAgnosticRuleOrdering(t *testing.T) { for _, rule := range rules { for _, policy := range rule.Values { - if policy.Is(routetypes.STOP) { - break - } // If we've got an any single port rule e.g 55/any, make sure that the proto is something that has ports otherwise the test fails successProto := policy.Proto @@ -985,10 +979,6 @@ func TestLookupDifferentKeyTypesInMap(t *testing.T) { t.Fatal("policy was not marked as allow all despite having no rules defined") } - if !policies[1].Is(routetypes.STOP) { - t.Fatal("policy should only contain one any/any rule") - } - k = routetypes.Key{ IP: []byte{3, 3, 3, 3}, Prefixlen: 32, @@ -1009,10 +999,6 @@ func TestLookupDifferentKeyTypesInMap(t *testing.T) { t.Fatal("policy had incorrect proto and port defintions") } - if !policies[1].Is(routetypes.STOP) { - t.Fatal("policy should only contain one any/any rule") - } - } func addDevices(fw *Firewall) error { diff --git a/internal/router/init.go b/internal/router/init.go index 3a8f0c81..02d2d93b 100644 --- a/internal/router/init.go +++ b/internal/router/init.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "net/netip" + "time" "github.com/NHAS/wag/internal/config" "github.com/NHAS/wag/internal/data" @@ -36,6 +37,13 @@ func newFw(testing, iptables bool, testDev tun.Device) (*Firewall, error) { hasIptables: iptables, } + inactivityTimeoutInt, err := data.GetSessionInactivityTimeoutMinutes() + if err != nil { + return nil, fmt.Errorf("failed to get session inactivity timeout: %s", err) + } + + fw.inactivityTimeout = time.Duration(inactivityTimeoutInt) * time.Minute + fw.nodeID = data.GetServerID() fw.deviceName = config.Values.Wireguard.DevName diff --git a/internal/routetypes/policy.go b/internal/routetypes/policy.go index 3bbf5aa0..9bb327cc 100644 --- a/internal/routetypes/policy.go +++ b/internal/routetypes/policy.go @@ -9,8 +9,7 @@ import ( type PolicyType uint16 const ( - ANY = 0 - STOP = 0 // Special directive, stop searching through the array, this is the end + ANY = 0 PUBLIC = 1 << PolicyType(iota) RANGE @@ -80,10 +79,6 @@ func (r Policy) String() string { restrictionType = "deny" } - if r.Is(STOP) { - return "stop" - } - if r.Is(SINGLE) { port := fmt.Sprintf("%d", r.LowerPort) if r.LowerPort == 0 {