diff --git a/internal/router/firewall.go b/internal/router/firewall.go index 812ba27f..675665c4 100644 --- a/internal/router/firewall.go +++ b/internal/router/firewall.go @@ -89,7 +89,11 @@ func (f *Firewall) SetInactivityTimeout(inactivityTimeoutMinutes int) error { return errors.New("firewall instance has been closed") } - f.inactivityTimeout = time.Duration(inactivityTimeoutMinutes) * time.Minute + if inactivityTimeoutMinutes < 0 { + f.inactivityTimeout = -1 + } else { + f.inactivityTimeout = time.Duration(inactivityTimeoutMinutes) * time.Minute + } return nil } @@ -164,7 +168,7 @@ func (f *Firewall) Evaluate(src, dst netip.AddrPort, proto uint16) bool { // It doesnt matter if this gets race conditioned device := f.addressToDevice[deviceAddr.Addr()] - if device != nil && time.Since(device.lastPacketTime) < f.inactivityTimeout { + if device != nil && (f.inactivityTimeout == -1 || time.Since(device.lastPacketTime) < f.inactivityTimeout) { device.lastPacketTime = time.Now() } else { authorized = false @@ -174,7 +178,6 @@ func (f *Firewall) Evaluate(src, dst netip.AddrPort, proto uint16) bool { action := false for _, decision := range *policy { - // ANY = 0 // If we match the protocol, // If type is SINGLE and the port is either any, or equal @@ -259,7 +262,16 @@ func (f *Firewall) SetAuthorized(address string, node types.ID) error { return err } - device.sessionExpiry = time.Now().Add(time.Duration(maxSession) * time.Minute) + device.disableSessionExpiry = maxSession < 0 + + timeToSet := maxSession + if !device.disableSessionExpiry { + // when the session expiry is set, it doesnt matter what we set this to, it just cant be the time.Time{} zero value ( as that indicates unauthed) + timeToSet = 1 + } + + device.sessionExpiry = time.Now().Add(time.Duration(timeToSet) * time.Minute) + device.lastPacketTime = time.Now() device.associatedNode = node @@ -362,7 +374,11 @@ func (f *Firewall) RefreshConfiguration() []error { return []error{err} } - f.inactivityTimeout = time.Duration(inactivityTimeoutMinutes) * time.Minute + if inactivityTimeoutMinutes < 0 { + f.inactivityTimeout = -1 + } else { + f.inactivityTimeout = time.Duration(inactivityTimeoutMinutes) * time.Minute + } var allErrors []error for _, user := range allUsers { @@ -442,8 +458,7 @@ func (f *Firewall) isAuthed(addr netip.Addr) bool { } // If the device has been inactive - if device.lastPacketTime.Add(f.inactivityTimeout).Before(time.Now()) { - + if f.inactivityTimeout > 0 && device.lastPacketTime.Add(f.inactivityTimeout).Before(time.Now()) { return false } @@ -533,7 +548,9 @@ type FirewallDevice struct { address netip.Addr lastPacketTime time.Time - sessionExpiry time.Time + + disableSessionExpiry bool + sessionExpiry time.Time associatedNode types.ID @@ -551,8 +568,9 @@ func (fwd *FirewallDevice) toDTO() fwDevice { func (d *FirewallDevice) isAuthed() bool { t := time.Now() + return !d.sessionExpiry.Equal(time.Time{}) && - t.Before(d.sessionExpiry) + (t.Before(d.sessionExpiry) || d.disableSessionExpiry) } diff --git a/internal/router/firewall_test.go b/internal/router/firewall_test.go index 8e076342..4fdf7cad 100644 --- a/internal/router/firewall_test.go +++ b/internal/router/firewall_test.go @@ -557,86 +557,77 @@ func TestCompositeRules(t *testing.T) { } -// func TestDisabledSlidingWindow(t *testing.T) { - -// err := data.SetSessionInactivityTimeoutMinutes(-1) -// if err != nil { -// t.Fatal(err) -// } - -// timeout, err := data.GetSessionInactivityTimeoutMinutes() -// if err != nil { -// t.Fatal(err) -// } - -// err = testFw.SetInactivityTimeout(timeout) -// if err != nil { -// t.Fatal(err) -// } - -// var timeoutFromMap uint64 -// err = xdpObjects.InactivityTimeoutMinutes.Lookup(uint32(0), &timeoutFromMap) -// if err != nil { -// t.Fatal(err) -// } - -// if timeoutFromMap != math.MaxUint64 { -// t.Fatalf("the inactivity timeout was not set to max uint64, was %d (maxuint64 %d)", timeoutFromMap, uint64(math.MaxUint64)) -// } - -// err = SetAuthorized(devices["tester"].Address, devices["tester"].Username, uint64(data.GetServerID())) -// if err != nil { -// t.Fatal(err) -// } - -// if !IsAuthed(devices["tester"].Address) { -// t.Fatal("after setting user as authorized it should be.... authorized") -// } - -// ip, _, err := net.ParseCIDR(data.GetEffectiveAcl(devices["tester"].Username).Mfa[0]) -// if err != nil { -// t.Fatal("could not parse ip: ", err) -// } - -// testAuthorizedPacket := ipv4.Header{ -// Version: 4, -// Dst: ip, -// Src: net.ParseIP(devices["tester"].Address), -// Len: ipv4.HeaderLen, -// } - -// if testAuthorizedPacket.Src == nil || testAuthorizedPacket.Dst == nil { -// t.Fatal("could not parse ip") -// } - -// packet, err := testAuthorizedPacket.Marshal() -// if err != nil { -// t.Fatal(err) -// } - -// t.Logf("Now doing timing test for disabled sliding window waiting...") - -// elapsed := 0 -// for { -// time.Sleep(15 * time.Second) -// elapsed += 15 - -// value, _, err := xdpObjects.bpfPrograms.XdpWagFirewall.Test(packet) -// if err != nil { -// t.Fatalf("program failed %s", err) -// } - -// if value == 1 { -// if elapsed < config.Values.MaxSessionLifetimeMinutes*60 { -// t.Fatal("epbf kernel blocking valid traffic early") -// } else { -// break -// } - -// } -// } - -// } +func TestDisabledSlidingWindow(t *testing.T) { + + err := data.SetSessionInactivityTimeoutMinutes(-1) + if err != nil { + t.Fatal(err) + } + + // no op to give etcd time to update the value + data.GetSessionInactivityTimeoutMinutes() + + maxSessionLife, _ := data.GetSessionLifetimeMinutes() + + if testFw.inactivityTimeout != -1 { + t.Fatalf("the inactivity timeout was not set to -1, was %d", testFw.inactivityTimeout) + } + + err = testFw.SetAuthorized(devices["tester"].Address, data.GetServerID()) + if err != nil { + t.Fatal(err) + } + + if !testFw.IsAuthed(devices["tester"].Address) { + t.Fatal("after setting user as authorized it should be.... authorized") + } + + ip, _, err := net.ParseCIDR(data.GetEffectiveAcl(devices["tester"].Username).Mfa[0]) + if err != nil { + t.Fatal("could not parse ip: ", err) + } + + testAuthorizedPacket := ipv4.Header{ + Version: 4, + Dst: ip, + Src: net.ParseIP(devices["tester"].Address), + Len: ipv4.HeaderLen, + } + + if testAuthorizedPacket.Src == nil || testAuthorizedPacket.Dst == nil { + t.Fatal("could not parse ip") + } + + packet, err := testAuthorizedPacket.Marshal() + if err != nil { + t.Fatal(err) + } + + t.Logf("Now doing timing test for disabled sliding window waiting...") + + elapsed := 0 + for { + + value := testFw.Test(packet) + if err != nil { + t.Fatalf("program failed %s", err) + } + + if !value { + if elapsed < maxSessionLife*60 { + t.Fatal("blocking valid traffic early: ", elapsed) + } else { + break + } + + } + + time.Sleep(15 * time.Second) + elapsed += 15 + + } + +} func TestMaxSessionLifetime(t *testing.T) { @@ -687,84 +678,77 @@ func TestMaxSessionLifetime(t *testing.T) { } } -// func TestDisablingMaxLifetime(t *testing.T) { - -// // Disable session max lifetime -// err := data.SetSessionLifetimeMinutes(-1) -// if err != nil { -// t.Fatal(err) -// } - -// err = testFw.SetAuthorized(devices["tester"].Address, data.GetServerID()) -// if err != nil { -// t.Fatal(err) -// } - -// if !testFw.IsAuthed(devices["tester"].Address) { -// t.Fatal("after setting user as authorized it should be.... authorized") -// } - -// var maxSessionLifeDevice fwentry -// deviceBytes, err := xdpObjects.Devices.LookupBytes(net.ParseIP(devices["tester"].Address).To4()) -// if err != nil { -// t.Fatal(err) -// } - -// err = maxSessionLifeDevice.Unpack(deviceBytes) -// if err != nil { -// t.Fatal(err) -// } - -// if maxSessionLifeDevice.sessionExpiry != math.MaxUint64 { -// t.Fatalf("lifetime was not set to max uint64, was %d (maxuint64 %d)", maxSessionLifeDevice.sessionExpiry, uint64(math.MaxUint64)) -// } - -// ip, _, err := net.ParseCIDR(data.GetEffectiveAcl(devices["tester"].Username).Mfa[0]) -// if err != nil { -// t.Fatal("could not parse ip: ", err) -// } - -// testAuthorizedPacket := ipv4.Header{ -// Version: 4, -// Dst: ip, -// Src: net.ParseIP(devices["tester"].Address), -// Len: ipv4.HeaderLen, -// } - -// if testAuthorizedPacket.Src == nil || testAuthorizedPacket.Dst == nil { -// t.Fatal("could not parse ip") -// } - -// packet, err := testAuthorizedPacket.Marshal() -// if err != nil { -// t.Fatal(err) -// } -// t.Log(GetRoutes("tester")) -// t.Logf("Now doing timing test for disabled sliding window waiting...") - -// elapsed := 0 -// for { -// time.Sleep(15 * time.Second) -// elapsed += 15 - -// t.Logf("waiting %d sec...", elapsed) - -// value, _, err := xdpObjects.bpfPrograms.XdpWagFirewall.Test(packet) -// if err != nil { -// t.Fatalf("program failed %s", err) -// } - -// if value == 1 { -// t.Fatalf("should not block traffic") -// } - -// if elapsed > 30 { -// break -// } - -// } - -// } +func TestDisablingMaxLifetime(t *testing.T) { + + // Disable session max lifetime + err := data.SetSessionLifetimeMinutes(-1) + if err != nil { + t.Fatal(err) + } + + err = testFw.SetAuthorized(devices["tester"].Address, data.GetServerID()) + if err != nil { + t.Fatal(err) + } + + if !testFw.IsAuthed(devices["tester"].Address) { + t.Fatal("after setting user as authorized it should be.... authorized") + } + + addr, err := netip.ParseAddr(devices["tester"].Address) + if err != nil { + t.Fatal(err) + } + + device := testFw.addressToDevice[addr] + + if !device.disableSessionExpiry { + t.Fatalf("session expiry not disabled") + } + + ip, _, err := net.ParseCIDR(data.GetEffectiveAcl(devices["tester"].Username).Mfa[0]) + if err != nil { + t.Fatal("could not parse ip: ", err) + } + + testAuthorizedPacket := ipv4.Header{ + Version: 4, + Dst: ip, + Src: net.ParseIP(devices["tester"].Address), + Len: ipv4.HeaderLen, + } + + if testAuthorizedPacket.Src == nil || testAuthorizedPacket.Dst == nil { + t.Fatal("could not parse ip") + } + + packet, err := testAuthorizedPacket.Marshal() + if err != nil { + t.Fatal(err) + } + t.Log(testFw.GetRoutes("tester")) + t.Logf("Now doing timing test for disabled sliding window waiting...") + + elapsed := 0 + for { + time.Sleep(15 * time.Second) + elapsed += 15 + + t.Logf("waiting %d sec...", elapsed) + + value := testFw.Test(packet) + + if !value { + t.Fatalf("should not block traffic") + } + + if elapsed > 30 { + break + } + + } + +} func TestPortRestrictions(t *testing.T) { diff --git a/internal/router/init.go b/internal/router/init.go index 02d2d93b..417f897e 100644 --- a/internal/router/init.go +++ b/internal/router/init.go @@ -42,7 +42,11 @@ func newFw(testing, iptables bool, testDev tun.Device) (*Firewall, error) { return nil, fmt.Errorf("failed to get session inactivity timeout: %s", err) } - fw.inactivityTimeout = time.Duration(inactivityTimeoutInt) * time.Minute + if inactivityTimeoutInt > 0 { + fw.inactivityTimeout = time.Duration(inactivityTimeoutInt) * time.Minute + } else { + fw.inactivityTimeout = -1 + } fw.nodeID = data.GetServerID() fw.deviceName = config.Values.Wireguard.DevName