diff --git a/go.mod b/go.mod index 36028d6..a4cf984 100644 --- a/go.mod +++ b/go.mod @@ -1,18 +1,20 @@ module github.com/in4it/wireguard-server -go 1.22.2 +go 1.23.0 require ( github.com/go-jose/go-jose/v4 v4.0.2 github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 + github.com/gopacket/gopacket v1.3.0 github.com/mdlayher/genetlink v1.3.2 github.com/mdlayher/netlink v1.7.2 + github.com/packetcap/go-pcap v0.0.0-20240528124601-8c87ecf5dbc5 github.com/russellhaering/gosaml2 v0.9.1 github.com/russellhaering/goxmldsig v1.4.0 - golang.org/x/crypto v0.25.0 - golang.org/x/sys v0.22.0 - golang.org/x/term v0.22.0 + golang.org/x/crypto v0.26.0 + golang.org/x/sys v0.24.0 + golang.org/x/term v0.23.0 ) require ( @@ -22,9 +24,8 @@ require ( github.com/josharian/native v1.1.0 // indirect github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect github.com/mdlayher/socket v0.5.1 // indirect - github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect - github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect - golang.org/x/net v0.27.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/text v0.16.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + golang.org/x/net v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/text v0.17.0 // indirect ) diff --git a/go.sum b/go.sum index 43a3316..e3965de 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,3 @@ -github.com/beevik/etree v1.1.0 h1:T0xke/WvNtMoCqgzPhkX2r4rjY3GDZFi+FjpRZY2Jbs= github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= github.com/beevik/etree v1.4.0 h1:oz1UedHRepuY3p4N5OjE0nK1WLCqtzHf25bxplKOHLs= github.com/beevik/etree v1.4.0/go.mod h1:cyWiXwGoasx60gHvtnEh5x8+uIjUVnjWqBvEnhnqKDA= @@ -14,8 +13,9 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopacket/gopacket v1.3.0 h1:MouZCc+ej0vnqzB0WeiaO/6+tGvb+KU7UczxoQ+X0Yc= +github.com/gopacket/gopacket v1.3.0/go.mod h1:WnFrU1Xkf5lWKV38uKNR9+yYtppn+ZYzOyNqMeH4oNE= github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= -github.com/jonboulle/clockwork v0.3.0 h1:9BSCMi8C+0qdApAp4auwX0RkLGUjs956h0EkuQymUhg= github.com/jonboulle/clockwork v0.3.0/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4= github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc= @@ -35,6 +35,8 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/packetcap/go-pcap v0.0.0-20240528124601-8c87ecf5dbc5 h1:p4VuaitqUAqSZSomd7Wb4BPV/Jj7Hno2/iqtfX7DZJI= +github.com/packetcap/go-pcap v0.0.0-20240528124601-8c87ecf5dbc5/go.mod h1:zIAoVKeWP0mz4zXY50UYQt6NLg2uwKRswMDcGEqOms4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -42,38 +44,29 @@ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/russellhaering/gosaml2 v0.9.1 h1:H/whrl8NuSoxyW46Ww5lKPskm+5K+qYLw9afqJ/Zef0= github.com/russellhaering/gosaml2 v0.9.1/go.mod h1:ja+qgbayxm+0mxBRLMSUuX3COqy+sb0RRhIGun/W2kc= -github.com/russellhaering/goxmldsig v1.3.0 h1:DllIWUgMy0cRUMfGiASiYEa35nsieyD3cigIwLonTPM= github.com/russellhaering/goxmldsig v1.3.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= github.com/russellhaering/goxmldsig v1.4.0 h1:8UcDh/xGyQiyrW+Fq5t8f+l2DLB1+zlhYzkPUJ7Qhys= github.com/russellhaering/goxmldsig v1.4.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= -github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= -github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= -github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= -golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= -golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= -golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= -golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= +golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/latest b/latest index 0f1acbd..99a4aef 100644 --- a/latest +++ b/latest @@ -1 +1 @@ -v1.1.2 +v1.1.3 diff --git a/pkg/auth/oidc/store/discovery_test.go b/pkg/auth/oidc/store/discovery_test.go index 3d5037e..31b6d9b 100644 --- a/pkg/auth/oidc/store/discovery_test.go +++ b/pkg/auth/oidc/store/discovery_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/in4it/wireguard-server/pkg/auth/oidc" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestGetDiscovery(t *testing.T) { @@ -27,7 +27,7 @@ func TestGetDiscovery(t *testing.T) { })) defer ts.Close() - store, err := NewStore(&testingmocks.MockMemoryStorage{}) + store, err := NewStore(&memorystorage.MockMemoryStorage{}) if err != nil { t.Fatalf("new store error: %s", err) } diff --git a/pkg/auth/oidc/store/jwks_test.go b/pkg/auth/oidc/store/jwks_test.go index 74698ff..94c808a 100644 --- a/pkg/auth/oidc/store/jwks_test.go +++ b/pkg/auth/oidc/store/jwks_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/in4it/wireguard-server/pkg/auth/oidc" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestGetJwks(t *testing.T) { @@ -32,7 +32,7 @@ func TestGetJwks(t *testing.T) { })) defer ts.Close() - store, err := NewStore(&testingmocks.MockMemoryStorage{}) + store, err := NewStore(&memorystorage.MockMemoryStorage{}) if err != nil { t.Fatalf("new store error: %s", err) } diff --git a/pkg/auth/oidc/store/save_test.go b/pkg/auth/oidc/store/save_test.go index bc56eb2..9e33d53 100644 --- a/pkg/auth/oidc/store/save_test.go +++ b/pkg/auth/oidc/store/save_test.go @@ -4,11 +4,11 @@ import ( "testing" "github.com/in4it/wireguard-server/pkg/auth/oidc" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestSave(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} store, err := NewStore(storage) if err != nil { t.Fatalf("error: %s", err) diff --git a/pkg/auth/provisioning/scim/users_test.go b/pkg/auth/provisioning/scim/users_test.go index e02cd61..2c399ac 100644 --- a/pkg/auth/provisioning/scim/users_test.go +++ b/pkg/auth/provisioning/scim/users_test.go @@ -11,7 +11,7 @@ import ( "path" "testing" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" "github.com/in4it/wireguard-server/pkg/users" "github.com/in4it/wireguard-server/pkg/wireguard" ) @@ -19,7 +19,7 @@ import ( const USERSTORE_MAX_USERS = 1000 func TestUsersGetCount100EmptyResult(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { @@ -48,7 +48,7 @@ func TestUsersGetCount100EmptyResult(t *testing.T) { } func TestUsersGetCount10(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { t.Fatalf("cannot create new user store") @@ -88,7 +88,7 @@ func TestUsersGetCount10(t *testing.T) { func TestUsersGetCount10Start5(t *testing.T) { count := 10 start := 5 - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { t.Fatalf("cannot create new user store") @@ -138,12 +138,12 @@ func TestUsersGetCount10Start5(t *testing.T) { } func TestUsersGetNonExistentUser(t *testing.T) { - userStore, err := users.NewUserStore(&testingmocks.MockMemoryStorage{}, USERSTORE_MAX_USERS) + userStore, err := users.NewUserStore(&memorystorage.MockMemoryStorage{}, USERSTORE_MAX_USERS) if err != nil { t.Fatalf("cannot create new user stoer") } - s := New(&testingmocks.MockMemoryStorage{}, userStore, "token") + s := New(&memorystorage.MockMemoryStorage{}, userStore, "token") req := httptest.NewRequest("GET", "http://example.com/api/scim/v2/Users?filter=userName+eq+%22ward%40in4it.io%22&", nil) w := httptest.NewRecorder() s.getUsersHandler(w, req) @@ -161,7 +161,7 @@ func TestUsersGetNonExistentUser(t *testing.T) { } func TestAddUser(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { t.Fatalf("cannot create new user store: %s", err) @@ -208,7 +208,7 @@ func TestAddUser(t *testing.T) { } func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { t.Fatalf("cannot create new user store: %s", err) @@ -232,6 +232,11 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { w.Write([]byte("OK")) return } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } w.WriteHeader(http.StatusBadRequest) default: w.WriteHeader(http.StatusBadRequest) @@ -323,7 +328,7 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { } } func TestCreateUserConnectionSuspendUserFlow(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { @@ -348,6 +353,11 @@ func TestCreateUserConnectionSuspendUserFlow(t *testing.T) { w.Write([]byte("OK")) return } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } w.WriteHeader(http.StatusBadRequest) default: w.WriteHeader(http.StatusBadRequest) diff --git a/pkg/configmanager/handlers.go b/pkg/configmanager/handlers.go index f66f7dd..c0b5314 100644 --- a/pkg/configmanager/handlers.go +++ b/pkg/configmanager/handlers.go @@ -59,13 +59,13 @@ func (c *ConfigManager) refreshClients(w http.ResponseWriter, r *http.Request) { } switch payload.Action { case wireguard.ACTION_ADD: - err = syncClient(c.Storage, filename) + err = syncClient(c.Storage, filename, c.ClientCache) if err != nil { returnError(w, fmt.Errorf("syncClient error: %s", err), http.StatusBadRequest) return } case wireguard.ACTION_DELETE: - err = deleteClient(c.Storage, filename) + err = deleteClient(c.Storage, filename, c.ClientCache) if err != nil { returnError(w, fmt.Errorf("deleteClient error: %s", err), http.StatusBadRequest) return @@ -77,6 +77,28 @@ func (c *ConfigManager) refreshClients(w http.ResponseWriter, r *http.Request) { returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) } } +func (c *ConfigManager) refreshServerConfig(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + vpnConfig, err := wireguard.GetVPNConfig(c.Storage) + if err != nil { + returnError(w, fmt.Errorf("get vpn config error: %s", err), http.StatusBadRequest) + return + } + startPacketLogger := false + if vpnConfig.EnablePacketLogs && !c.VPNConfig.EnablePacketLogs { + startPacketLogger = true + } + c.VPNConfig.EnablePacketLogs = vpnConfig.EnablePacketLogs + c.VPNConfig.PacketLogsTypes = vpnConfig.PacketLogsTypes + if startPacketLogger { + go wireguard.RunPacketLogger(c.Storage, c.ClientCache, c.VPNConfig) + } + w.WriteHeader(http.StatusAccepted) + default: + returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} func (c *ConfigManager) upgrade(w http.ResponseWriter, r *http.Request) { switch r.Method { @@ -131,7 +153,7 @@ func (c *ConfigManager) restartVpn(w http.ResponseWriter, r *http.Request) { returnError(w, fmt.Errorf("vpn start error: %s", err), http.StatusBadRequest) return } - err = refreshAllClientsAndServer(c.Storage) + err = refreshAllClientsAndServer(c.Storage, c.ClientCache) if err != nil { returnError(w, fmt.Errorf("could not refresh all clients: %s", err), http.StatusBadRequest) return diff --git a/pkg/configmanager/refresh_darwin.go b/pkg/configmanager/refresh_darwin.go index 3c4300a..6ffb396 100644 --- a/pkg/configmanager/refresh_darwin.go +++ b/pkg/configmanager/refresh_darwin.go @@ -4,17 +4,49 @@ package configmanager import ( + "errors" "fmt" + "os" "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/wireguard-server/pkg/wireguard" ) -func refreshAllClientsAndServer(storage storage.Iface) error { +func refreshAllClientsAndServer(storage storage.Iface, clientCache *wireguard.ClientCache) error { + peerConfigPath := storage.ConfigPath(wireguard.VPN_CLIENTS_DIR) + + if _, err := os.Stat(peerConfigPath); errors.Is(err, os.ErrNotExist) { + return nil // directory doesn't exist, so no configs to be read + } + + entries, err := storage.ReadDir(peerConfigPath) + if err != nil { + return fmt.Errorf("can not list clients from dir %s: %s", peerConfigPath, err) + } + + for _, filename := range entries { + peerConfig, err := wireguard.GetPeerConfigByFilename(storage, filename) + if err != nil { + return fmt.Errorf("getClientFile error: %s", err) + } + err = wireguard.UpdateClientCache(peerConfig, clientCache) + if err != nil { + return fmt.Errorf("update client cache error: %s", err) + } + } fmt.Printf("Warning: not refreshAllClients supported on darwin\n") return nil } -func syncClient(storage storage.Iface, filename string) error { +func syncClient(storage storage.Iface, filename string, clientCache *wireguard.ClientCache) error { + peerConfig, err := wireguard.GetPeerConfigByFilename(storage, filename) + if err != nil { + return fmt.Errorf("getClientFile error: %s", err) + } + err = wireguard.UpdateClientCache(peerConfig, clientCache) + if err != nil { + return fmt.Errorf("update client cache error: %s", err) + } fmt.Printf("Warning: syncClient not supported on darwin. Cannot sync: %s\n", filename) return nil } @@ -24,7 +56,15 @@ func cleanupClients(storage storage.Iface) error { return nil } -func deleteClient(storage storage.Iface, filename string) error { +func deleteClient(storage storage.Iface, filename string, clientCache *wireguard.ClientCache) error { + peerConfig, err := wireguard.GetPeerConfigByFilename(storage, filename) + if err != nil { + return fmt.Errorf("getClientFile error: %s", err) + } + err = wireguard.UpdateClientCache(peerConfig, clientCache) + if err != nil { + return fmt.Errorf("update client cache error: %s", err) + } fmt.Printf("Warning: deleteClient not supported on darwin. Cannot delete: %s\n", filename) return nil } diff --git a/pkg/configmanager/refresh_linux.go b/pkg/configmanager/refresh_linux.go index 8a4cf8f..b8311c8 100644 --- a/pkg/configmanager/refresh_linux.go +++ b/pkg/configmanager/refresh_linux.go @@ -13,12 +13,28 @@ import ( syncclients "github.com/in4it/wireguard-server/pkg/wireguard/linux/syncclients" ) -func syncClient(storage storage.Iface, filename string) error { - go syncclients.SyncClientsAndCleanup(storage, filename) +func syncClient(storage storage.Iface, filename string, clientCache *wireguard.ClientCache) error { + peerConfig, err := wireguard.GetPeerConfigByFilename(storage, filename) + if err != nil { + return fmt.Errorf("getClientFile error: %s", err) + } + err = wireguard.UpdateClientCache(peerConfig, clientCache) + if err != nil { + return fmt.Errorf("update client cache error: %s", err) + } + go syncclients.SyncClientsAndCleanup(storage, peerConfig) return nil } -func deleteClient(storage storage.Iface, filename string) error { - go syncclients.DeleteClient(storage, filename) +func deleteClient(storage storage.Iface, filename string, clientCache *wireguard.ClientCache) error { + peerConfig, err := wireguard.GetPeerConfigByFilename(storage, filename) + if err != nil { + return fmt.Errorf("getClientFile error: %s", err) + } + err = wireguard.UpdateClientCache(peerConfig, clientCache) + if err != nil { + return fmt.Errorf("update client cache error: %s", err) + } + go syncclients.DeleteClient(peerConfig) return nil } func cleanupClients(storage storage.Iface) error { @@ -26,7 +42,7 @@ func cleanupClients(storage storage.Iface) error { return nil } -func refreshAllClientsAndServer(storage storage.Iface) error { +func refreshAllClientsAndServer(storage storage.Iface, clientCache *wireguard.ClientCache) error { peerConfigPath := storage.ConfigPath(wireguard.VPN_CLIENTS_DIR) if _, err := os.Stat(peerConfigPath); errors.Is(err, os.ErrNotExist) { @@ -39,7 +55,15 @@ func refreshAllClientsAndServer(storage storage.Iface) error { } for _, e := range entries { - err = syncclients.SyncClients(storage, e) + peerConfig, err := wireguard.GetPeerConfigByFilename(storage, e) + if err != nil { + return fmt.Errorf("getClientFile error: %s", err) + } + err = wireguard.UpdateClientCache(peerConfig, clientCache) + if err != nil { + return fmt.Errorf("update client cache error: %s", err) + } + err = syncclients.SyncClients(storage, peerConfig) if err != nil { return fmt.Errorf("SyncClients error: %s", err) } diff --git a/pkg/configmanager/router.go b/pkg/configmanager/router.go index 59a32d8..7b4f2b2 100644 --- a/pkg/configmanager/router.go +++ b/pkg/configmanager/router.go @@ -7,6 +7,7 @@ func (c *ConfigManager) getRouter() *http.ServeMux { mux.Handle("/pubkey", http.HandlerFunc(c.getPubKey)) mux.Handle("/refresh-clients", http.HandlerFunc(c.refreshClients)) + mux.Handle("/refresh-server-config", http.HandlerFunc(c.refreshServerConfig)) mux.Handle("/upgrade", http.HandlerFunc(c.upgrade)) mux.Handle("/restart-vpn", http.HandlerFunc(c.restartVpn)) mux.Handle("/version", http.HandlerFunc(c.version)) diff --git a/pkg/configmanager/server.go b/pkg/configmanager/server.go index 4ea37f9..3ef2580 100644 --- a/pkg/configmanager/server.go +++ b/pkg/configmanager/server.go @@ -35,12 +35,14 @@ func StartServer(port int) { } // refresh all clients - err = refreshAllClientsAndServer(localStorage) + err = refreshAllClientsAndServer(localStorage, c.ClientCache) if err != nil { log.Fatalf("could not refresh all clients: %s", err) } - startStats(localStorage) // start gathering of wireguard stats + // start goroutines + startStats(localStorage) // start gathering of wireguard stats + startPacketLogger(localStorage, c.ClientCache, c.VPNConfig) // start packet logger (optional) log.Printf("Starting localhost http server at port %d\n", port) log.Fatal(http.ListenAndServe(fmt.Sprintf("127.0.0.1:%d", port), c.getRouter())) @@ -49,6 +51,9 @@ func StartServer(port int) { func initConfigManager(storage storage.Iface) (*ConfigManager, error) { c := &ConfigManager{ Storage: storage, + ClientCache: &wireguard.ClientCache{ + Addresses: []wireguard.ClientCacheAddresses{}, + }, } vpnConfig, err := wireguard.GetVPNConfig(storage) @@ -67,5 +72,7 @@ func initConfigManager(storage storage.Iface) (*ConfigManager, error) { } } + c.VPNConfig = &vpnConfig + return c, nil } diff --git a/pkg/configmanager/start_darwin.go b/pkg/configmanager/start_darwin.go index 2f7bb3e..07ea124 100644 --- a/pkg/configmanager/start_darwin.go +++ b/pkg/configmanager/start_darwin.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/wireguard-server/pkg/wireguard" ) func startVPN(storage storage.Iface) error { @@ -22,3 +23,9 @@ func stopVPN() error { func startStats(storage storage.Iface) { fmt.Printf("Warning: startStats is not implemented in darwin\n") } + +func startPacketLogger(storage storage.Iface, clientCache *wireguard.ClientCache, vpnConfig *wireguard.VPNConfig) { + go wireguard.RunPacketLogger(storage, clientCache, vpnConfig) + // run cleanup + go wireguard.PacketLoggerLogRotation(storage) +} diff --git a/pkg/configmanager/start_linux.go b/pkg/configmanager/start_linux.go index 334f6df..033ee02 100644 --- a/pkg/configmanager/start_linux.go +++ b/pkg/configmanager/start_linux.go @@ -27,3 +27,10 @@ func startStats(storage storage.Iface) { // run statistics go routine go wireguard.RunStats(storage) } + +func startPacketLogger(storage storage.Iface, clientCache *wireguard.ClientCache, vpnConfig *wireguard.VPNConfig) { + // run statistics go routine + go wireguard.RunPacketLogger(storage, clientCache, vpnConfig) + // run cleanup + go wireguard.PacketLoggerLogRotation(storage) +} diff --git a/pkg/configmanager/types.go b/pkg/configmanager/types.go index b7d8d54..0546805 100644 --- a/pkg/configmanager/types.go +++ b/pkg/configmanager/types.go @@ -2,12 +2,15 @@ package configmanager import ( "github.com/in4it/wireguard-server/pkg/storage" + "github.com/in4it/wireguard-server/pkg/wireguard" ) type ConfigManager struct { - PrivateKey string - PublicKey string - Storage storage.Iface + PrivateKey string + PublicKey string + Storage storage.Iface + ClientCache *wireguard.ClientCache + VPNConfig *wireguard.VPNConfig } type UpgradeResponse struct { diff --git a/pkg/license/gcp_test.go b/pkg/license/gcp_test.go index 33cc566..37d2672 100644 --- a/pkg/license/gcp_test.go +++ b/pkg/license/gcp_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestGuessInfrastructureGCP(t *testing.T) { @@ -52,10 +52,10 @@ func TestGetMaxUsersGCPBYOL(t *testing.T) { licenseURL = ts.URL metadataIP = strings.Replace(ts.URL, "http://", "", -1) - mockStorage := &testingmocks.MockReadWriter{ - Data: map[string][]byte{ - "config/license.key": []byte("license-1234556-license"), - }, + mockStorage := &memorystorage.MockMemoryStorage{} + err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license")) + if err != nil { + t.Fatalf("writefile error: %s", err) } for _, v := range []int{50} { diff --git a/pkg/license/license_test.go b/pkg/license/license_test.go index 8a41601..9f41356 100644 --- a/pkg/license/license_test.go +++ b/pkg/license/license_test.go @@ -11,7 +11,7 @@ import ( "time" "github.com/in4it/wireguard-server/pkg/logging" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestGetMaxUsersAzure(t *testing.T) { @@ -39,7 +39,7 @@ func TestGetMaxUsersAWSMarketplace(t *testing.T) { "t3.xlarge": 250, } for instanceType, v := range testCases { - if getMaxUsers(&testingmocks.MockReadWriter{}, "aws-marketplace", instanceType) != v { + if getMaxUsers(&memorystorage.MockMemoryStorage{}, "aws-marketplace", instanceType) != v { t.Fatalf("Wrong output: %d vs %d", GetMaxUsersAWS(instanceType), v) } } @@ -98,7 +98,7 @@ func TestGetMaxUsersAWSBYOL(t *testing.T) { licenseURL = ts.URL metadataIP = strings.Replace(ts.URL, "http://", "", -1) for _, v := range testCases { - if v2 := GetMaxUsersAWSBYOL(http.Client{Timeout: 5 * time.Second}, &testingmocks.MockMemoryStorage{}); v2 != v { + if v2 := GetMaxUsersAWSBYOL(http.Client{Timeout: 5 * time.Second}, &memorystorage.MockMemoryStorage{}); v2 != v { t.Fatalf("Wrong output: %d vs %d", v2, v) } } @@ -111,7 +111,7 @@ func TestGetMaxUsersAWS(t *testing.T) { "t3.xlarge": 3, } for instanceType, v := range testCases { - if getMaxUsers(&testingmocks.MockReadWriter{}, "aws", instanceType) != v { + if getMaxUsers(&memorystorage.MockMemoryStorage{}, "aws", instanceType) != v { t.Fatalf("Wrong output: %d vs %d", GetMaxUsersAWS(instanceType), v) } } @@ -212,7 +212,7 @@ func TestGuessInfrastructureAWS(t *testing.T) { t.Fatalf("wrong infra returned: %s", infra) } - if getMaxUsers(&testingmocks.MockReadWriter{}, infra, "t3.large") != 3 { + if getMaxUsers(&memorystorage.MockMemoryStorage{}, infra, "t3.large") != 3 { t.Fatalf("wrong users returned") } } @@ -252,7 +252,7 @@ func TestGetAzureInstanceType(t *testing.T) { usersPerVCPU := 25 - users := getMaxUsers(&testingmocks.MockReadWriter{}, "azure", getAzureInstanceType(http.Client{Timeout: 5 * time.Second})) + users := getMaxUsers(&memorystorage.MockMemoryStorage{}, "azure", getAzureInstanceType(http.Client{Timeout: 5 * time.Second})) if users != usersPerVCPU*2 { t.Fatalf("Wrong user count returned") @@ -302,7 +302,7 @@ func TestGuessInfrastructureDigitalOcean(t *testing.T) { t.Fatalf("wrong infra returned: %s", infra) } - if getMaxUsers(&testingmocks.MockReadWriter{}, infra, "t3.large") != 3 { + if getMaxUsers(&memorystorage.MockMemoryStorage{}, infra, "t3.large") != 3 { t.Fatalf("wrong users returned") } } @@ -332,12 +332,11 @@ func TestGetMaxUsersDigitalOceanBYOL(t *testing.T) { licenseURL = ts.URL metadataIP = strings.Replace(ts.URL, "http://", "", -1) - mockStorage := &testingmocks.MockReadWriter{ - Data: map[string][]byte{ - "config/license.key": []byte("license-1234556-license"), - }, + mockStorage := &memorystorage.MockMemoryStorage{} + err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license")) + if err != nil { + t.Fatalf("writefile error: %s", err) } - for _, v := range []int{50} { if v2 := GetMaxUsersDigitalOceanBYOL(http.Client{Timeout: 5 * time.Second}, mockStorage); v2 != v { t.Fatalf("Wrong output: %d vs %d", v2, v) @@ -387,19 +386,19 @@ func TestGetLicenseKey(t *testing.T) { metadataIP = strings.Replace(ts.URL, "http://", "", -1) logging.Loglevel = logging.LOG_DEBUG + logging.LOG_ERROR - key := GetLicenseKey(&testingmocks.MockMemoryStorage{}, "") + key := GetLicenseKey(&memorystorage.MockMemoryStorage{}, "") if key == "" { t.Fatalf("key is empty") } - key = GetLicenseKey(&testingmocks.MockMemoryStorage{}, "aws") + key = GetLicenseKey(&memorystorage.MockMemoryStorage{}, "aws") if key == "" { t.Fatalf("aws key is empty") } - key = GetLicenseKey(&testingmocks.MockMemoryStorage{}, "digitalocean") + key = GetLicenseKey(&memorystorage.MockMemoryStorage{}, "digitalocean") if key == "" { t.Fatalf("digitalocean key is empty") } - key = GetLicenseKey(&testingmocks.MockMemoryStorage{}, "gcp") + key = GetLicenseKey(&memorystorage.MockMemoryStorage{}, "gcp") if key == "" { t.Fatalf("gcp key is empty") } @@ -407,7 +406,7 @@ func TestGetLicenseKey(t *testing.T) { func TestGetLicenseKeyNoCloudProvider(t *testing.T) { logging.Loglevel = logging.LOG_DEBUG + logging.LOG_ERROR - key := GetLicenseKey(&testingmocks.MockMemoryStorage{}, "") + key := GetLicenseKey(&memorystorage.MockMemoryStorage{}, "") if key == "" { t.Fatalf("key is empty") } diff --git a/pkg/rest/auth_test.go b/pkg/rest/auth_test.go index 17078e9..848e8f3 100644 --- a/pkg/rest/auth_test.go +++ b/pkg/rest/auth_test.go @@ -23,7 +23,7 @@ import ( "github.com/in4it/wireguard-server/pkg/auth/saml" "github.com/in4it/wireguard-server/pkg/logging" "github.com/in4it/wireguard-server/pkg/rest/login" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" "github.com/in4it/wireguard-server/pkg/users" "github.com/russellhaering/gosaml2/types" dsigtypes "github.com/russellhaering/goxmldsig/types" @@ -80,7 +80,7 @@ c7tL1QjbfAUHAQYwmHkWgPP+T2wAv0pOt36GgMCM` } func TestAuthHandler(t *testing.T) { - c, err := newContext(&testingmocks.MockMemoryStorage{}, SERVER_TYPE_VPN) + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) if err != nil { t.Fatalf("Cannot create context: %s", err) } @@ -130,7 +130,7 @@ func TestAuthHandler(t *testing.T) { func TestNewSAMLConnection(t *testing.T) { // generate new keypair - kp := saml.NewKeyPair(&testingmocks.MockMemoryStorage{}, "www.idp.inv") + kp := saml.NewKeyPair(&memorystorage.MockMemoryStorage{}, "www.idp.inv") _, cert, err := kp.GetKeyPair() if err != nil { t.Fatalf("Can't generate new keypair: %s", err) @@ -191,7 +191,7 @@ func TestNewSAMLConnection(t *testing.T) { defer l.Close() // first create a new user - c, err := newContext(&testingmocks.MockMemoryStorage{}, SERVER_TYPE_VPN) + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) if err != nil { t.Fatalf("Cannot create context") } @@ -419,7 +419,7 @@ func TestNewSAMLConnection(t *testing.T) { } func TestAddModifyDeleteNewSAMLConnection(t *testing.T) { - c, err := newContext(&testingmocks.MockMemoryStorage{}, SERVER_TYPE_VPN) + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) if err != nil { t.Fatalf("Cannot create context") } @@ -592,7 +592,7 @@ func TestAddModifyDeleteNewSAMLConnection(t *testing.T) { } func TestSAMLCallback(t *testing.T) { - c, err := newContext(&testingmocks.MockMemoryStorage{}, SERVER_TYPE_VPN) + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) if err != nil { t.Fatalf("Cannot create context") } @@ -701,7 +701,7 @@ func TestOIDCFlow(t *testing.T) { } // first create a new user - c, err := newContext(&testingmocks.MockMemoryStorage{}, SERVER_TYPE_VPN) + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) if err != nil { t.Fatalf("Cannot create context") } diff --git a/pkg/rest/router.go b/pkg/rest/router.go index 1206d01..98a1e8d 100644 --- a/pkg/rest/router.go +++ b/pkg/rest/router.go @@ -63,6 +63,7 @@ func (c *Context) getRouter(assets fs.FS, indexHtml []byte) *http.ServeMux { mux.Handle("/api/users", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.usersHandler))))) mux.Handle("/api/user/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.userHandler))))) mux.Handle("/api/stats/user/{date}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.userStatsHandler))))) + mux.Handle("/api/stats/packetlogs/{user}/{date}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.packetLogsHandler))))) return mux } diff --git a/pkg/rest/rsa_test.go b/pkg/rest/rsa_test.go index 8ac5844..ce23943 100644 --- a/pkg/rest/rsa_test.go +++ b/pkg/rest/rsa_test.go @@ -6,11 +6,11 @@ import ( "encoding/pem" "testing" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestGetJWTKeys(t *testing.T) { - mockStorage := testingmocks.MockMemoryStorage{} + mockStorage := memorystorage.MockMemoryStorage{} keys, err := getJWTKeys(&mockStorage) if err != nil { t.Fatalf("error: %s", err) diff --git a/pkg/rest/setup.go b/pkg/rest/setup.go index 236c1c9..4354dbf 100644 --- a/pkg/rest/setup.go +++ b/pkg/rest/setup.go @@ -9,6 +9,8 @@ import ( "net/http" "net/netip" "reflect" + "slices" + "sort" "strconv" "strings" "time" @@ -167,6 +169,12 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { } switch r.Method { case http.MethodGet: + packetLogTypes := []string{} + for k, enabled := range vpnConfig.PacketLogsTypes { + if enabled { + packetLogTypes = append(packetLogTypes, k) + } + } setupRequest := VPNSetupRequest{ Routes: strings.Join(vpnConfig.ClientRoutes, ", "), VPNEndpoint: vpnConfig.Endpoint, @@ -176,6 +184,8 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { ExternalInterface: vpnConfig.ExternalInterface, Nameservers: strings.Join(vpnConfig.Nameservers, ","), DisableNAT: vpnConfig.DisableNAT, + EnablePacketLogs: vpnConfig.EnablePacketLogs, + PacketLogsTypes: packetLogTypes, } out, err := json.Marshal(setupRequest) if err != nil { @@ -258,6 +268,28 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { vpnConfig.DisableNAT = setupRequest.DisableNAT writeVPNConfig = true } + if setupRequest.EnablePacketLogs != vpnConfig.EnablePacketLogs { + vpnConfig.EnablePacketLogs = setupRequest.EnablePacketLogs + writeVPNConfig = true + } + // packetlogtypes + packetLogTypes := []string{} + for k, enabled := range vpnConfig.PacketLogsTypes { + if enabled { + packetLogTypes = append(packetLogTypes, k) + } + } + sort.Strings(setupRequest.PacketLogsTypes) + sort.Strings(packetLogTypes) + if !slices.Equal(setupRequest.PacketLogsTypes, packetLogTypes) { + vpnConfig.PacketLogsTypes = make(map[string]bool) + for _, v := range setupRequest.PacketLogsTypes { + if v == "http+https" || v == "dns" || v == "tcp" { + vpnConfig.PacketLogsTypes[v] = true + } + } + writeVPNConfig = true + } // write vpn config if config has changed if writeVPNConfig { diff --git a/pkg/rest/stats.go b/pkg/rest/stats.go index ae6ccd0..fc66f4a 100644 --- a/pkg/rest/stats.go +++ b/pkg/rest/stats.go @@ -13,9 +13,13 @@ import ( "strings" "time" + "github.com/in4it/wireguard-server/pkg/storage" + dateutils "github.com/in4it/wireguard-server/pkg/utils/date" "github.com/in4it/wireguard-server/pkg/wireguard" ) +const MAX_LOG_OUTPUT_LINES = 100 + func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) { if r.PathValue("date") == "" { c.returnError(w, fmt.Errorf("no date supplied"), http.StatusBadRequest) @@ -54,7 +58,7 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) { path.Join(wireguard.VPN_STATS_DIR, "user-"+date.AddDate(0, 0, -1).Format("2006-01-02")+".log"), path.Join(wireguard.VPN_STATS_DIR, "user-"+date.Format("2006-01-02")+".log"), } - if !dateEqual(time.Now(), date) { + if !dateutils.DateEqual(time.Now(), date) { statsFiles = append(statsFiles, path.Join(wireguard.VPN_STATS_DIR, "user-"+date.AddDate(0, 0, 1).Format("2006-01-02")+".log")) } logData := bytes.NewBuffer([]byte{}) @@ -108,7 +112,7 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) { timestamp, err := time.Parse(wireguard.TIMESTAMP_FORMAT, inputSplit[0]) if err == nil { timestamp = timestamp.Add(time.Duration(offset) * time.Minute) - if dateEqual(timestamp, date) { + if dateutils.DateEqual(timestamp, date) { receiveBytesData[userID] = append(receiveBytesData[userID], UserStatsDataPoint{X: timestamp.Format(wireguard.TIMESTAMP_FORMAT), Y: value}) } } @@ -122,7 +126,7 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) { timestamp, err := time.Parse(wireguard.TIMESTAMP_FORMAT, inputSplit[0]) if err == nil { timestamp = timestamp.Add(time.Duration(offset) * time.Minute) - if dateEqual(timestamp, date) { + if dateutils.DateEqual(timestamp, date) { transmitBytesData[userID] = append(transmitBytesData[userID], UserStatsDataPoint{X: timestamp.Format(wireguard.TIMESTAMP_FORMAT), Y: value}) } } @@ -130,7 +134,7 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) { handshake, err := time.Parse(wireguard.TIMESTAMP_FORMAT, inputSplit[5]) if err == nil { handshake = handshake.Add(time.Duration(offset) * time.Minute) - if dateEqual(handshake, date) && !handshake.Equal(handshakeLast[userID]) { + if dateutils.DateEqual(handshake, date) && !handshake.Equal(handshakeLast[userID]) { if _, ok := handshakeData[userID]; !ok { handshakeData[userID] = []UserStatsDataPoint{} } @@ -210,6 +214,155 @@ func (c *Context) userStatsHandler(w http.ResponseWriter, r *http.Request) { c.write(w, out) } +func (c *Context) packetLogsHandler(w http.ResponseWriter, r *http.Request) { + vpnConfig, err := wireguard.GetVPNConfig(c.Storage.Client) + if err != nil { + c.returnError(w, fmt.Errorf("get vpn config error: %s", err), http.StatusBadRequest) + return + } + if !vpnConfig.EnablePacketLogs { // packet logs is disabled + out, err := json.Marshal(LogDataResponse{Enabled: false}) + if err != nil { + c.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + return + } + userID := r.PathValue("user") + if userID == "" { + c.returnError(w, fmt.Errorf("no user supplied"), http.StatusBadRequest) + return + } + if r.PathValue("date") == "" { + c.returnError(w, fmt.Errorf("no date supplied"), http.StatusBadRequest) + return + } + date, err := time.Parse("2006-01-02", r.PathValue("date")) + if err != nil { + c.returnError(w, fmt.Errorf("invalid date: %s", err), http.StatusBadRequest) + return + } + offset := 0 + if r.FormValue("offset") != "" { + i, err := strconv.Atoi(r.FormValue("offset")) + if err == nil { + offset = i + } + } + pos := int64(0) + if r.FormValue("pos") != "" { + i, err := strconv.ParseInt(r.FormValue("pos"), 10, 0) + if err == nil { + pos = i + } + } + search := r.FormValue("search") + // get all users + users := c.UserStore.ListUsers() + userMap := make(map[string]string) + for _, user := range users { + userMap[user.ID] = user.Login + } + // get filter + logTypeFilterQueryString := r.URL.Query().Get("logtype") + logTypeFilter := strings.Split(logTypeFilterQueryString, ",") + // initialize response + logData := LogData{ + Schema: LogSchema{ + Columns: map[string]string{ + "Protocol": "string", + "Source IP": "string", + "Destination IP": "string", + "Source Port": "string", + "Destination Port": "string", + "Destination": "string", + }, + }, + Data: []LogRow{}, + } + // logs + statsFiles := []string{ + path.Join(wireguard.VPN_STATS_DIR, wireguard.VPN_PACKETLOGGER_DIR, userID+"-"+date.Format("2006-01-02")+".log"), + } + if !dateutils.DateEqual(time.Now(), date) { // date is in local timezone, and we are UTC, so also read next file + statsFiles = append(statsFiles, path.Join(wireguard.VPN_STATS_DIR, wireguard.VPN_PACKETLOGGER_DIR, userID+"-"+date.AddDate(0, 0, 1).Format("2006-01-02")+".log")) + } + statsFiles = filterNonExistentFiles(c.Storage.Client, statsFiles) + fileReaders, err := c.Storage.Client.OpenFilesFromPos(statsFiles, pos) + if err != nil { + c.returnError(w, fmt.Errorf("error while reading files: %s", err), http.StatusBadRequest) + return + } + for _, fileReader := range fileReaders { + defer fileReader.Close() + } + + for _, logInputData := range fileReaders { // read multiple files + if len(logData.Data) >= MAX_LOG_OUTPUT_LINES { + break + } + scanner := bufio.NewScanner(logInputData) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + advance, token, err = bufio.ScanLines(data, atEOF) + pos += int64(advance) + return + }) + for scanner.Scan() && len(logData.Data) < MAX_LOG_OUTPUT_LINES { // read multiple lines + inputSplit := strings.Split(scanner.Text(), ",") + timestamp, err := time.Parse(wireguard.TIMESTAMP_FORMAT, inputSplit[0]) + if err != nil { + continue // invalid record + } + timestamp = timestamp.Add(time.Duration(offset) * time.Minute) + if dateutils.DateEqual(timestamp, date) { + if !filterLogRecord(logTypeFilter, inputSplit[1]) && matchesSearch(search, inputSplit) { + row := LogRow{ + Timestamp: timestamp.Format("2006-01-02 15:04:05"), + Data: inputSplit[1:], + } + logData.Data = append(logData.Data, row) + } + } + } + if err := scanner.Err(); err != nil { + c.returnError(w, fmt.Errorf("log file read (scanner) error: %s", err), http.StatusBadRequest) + return + } + } + if len(logData.Data) < MAX_LOG_OUTPUT_LINES { + pos = -1 // no more records + } + + // set position + logData.NextPos = pos + + // logtypes + packetLogTypes := []string{} + for k, enabled := range vpnConfig.PacketLogsTypes { + if enabled { + packetLogTypes = append(packetLogTypes, k) + } + } + + out, err := json.Marshal(LogDataResponse{Enabled: true, LogData: logData, LogTypes: packetLogTypes, Users: userMap}) + if err != nil { + c.returnError(w, fmt.Errorf("user stats response marshal error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) +} + +func filterNonExistentFiles(storage storage.Iface, files []string) []string { + res := []string{} + for _, file := range files { + if storage.FileExists(file) { + res = append(res, file) + } + } + return res +} + func getColor(i int) string { colors := []string{ "#DEEFB7", @@ -236,8 +389,36 @@ func getColor(i int) string { return colors[i%len(colors)] } -func dateEqual(date1, date2 time.Time) bool { - y1, m1, d1 := date1.Date() - y2, m2, d2 := date2.Date() - return y1 == y2 && m1 == m2 && d1 == d2 +func filterLogRecord(logTypeFilter []string, logType string) bool { + if len(logTypeFilter) > 0 && logTypeFilter[0] != "" { + for _, logTypeFilterItem := range logTypeFilter { + if logType == logTypeFilterItem { + return false + } + + if logTypeFilterItem == "dns" && logType == "udp" { + return false + } + + splitLogTypes := strings.Split(logTypeFilterItem, "+") + for _, splitLogType := range splitLogTypes { + if splitLogType == logType { + return false + } + } + } + return true + } + return false +} +func matchesSearch(search string, data []string) bool { + if search == "" { + return true + } + for _, element := range data { + if strings.Contains(element, search) { + return true + } + } + return false } diff --git a/pkg/rest/stats_test.go b/pkg/rest/stats_test.go index 447aa69..b1da1fc 100644 --- a/pkg/rest/stats_test.go +++ b/pkg/rest/stats_test.go @@ -8,13 +8,13 @@ import ( "testing" "time" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" "github.com/in4it/wireguard-server/pkg/wireguard" ) func TestUserStatsHandler(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} c, err := newContext(storage, SERVER_TYPE_VPN) if err != nil { @@ -67,3 +67,14 @@ func TestUserStatsHandler(t *testing.T) { } } + +func TestFilterLogRecord(t *testing.T) { + logTypeFilter := []string{"tcp", "http+https"} + expected := []bool{false, false, true, false} + for k, v := range []string{"tcp", "http", "udp", "https"} { + res := filterLogRecord(logTypeFilter, v) + if res != expected[k] { + t.Fatalf("unexpected result: %v, got: %v", res, expected[k]) + } + } +} diff --git a/pkg/rest/types.go b/pkg/rest/types.go index ff7fdc0..fe72db0 100644 --- a/pkg/rest/types.go +++ b/pkg/rest/types.go @@ -102,14 +102,16 @@ type GeneralSetupRequest struct { } type VPNSetupRequest struct { - Routes string `json:"routes"` - VPNEndpoint string `json:"vpnEndpoint"` - AddressRange string `json:"addressRange"` - ClientAddressPrefix string `json:"clientAddressPrefix"` - Port string `json:"port"` - ExternalInterface string `json:"externalInterface"` - Nameservers string `json:"nameservers"` - DisableNAT bool `json:"disableNAT"` + Routes string `json:"routes"` + VPNEndpoint string `json:"vpnEndpoint"` + AddressRange string `json:"addressRange"` + ClientAddressPrefix string `json:"clientAddressPrefix"` + Port string `json:"port"` + ExternalInterface string `json:"externalInterface"` + Nameservers string `json:"nameservers"` + DisableNAT bool `json:"disableNAT"` + EnablePacketLogs bool `json:"enablePacketLogs"` + PacketLogsTypes []string `json:"packetLogsTypes"` } type TemplateSetupRequest struct { @@ -205,3 +207,23 @@ type NewUserRequest struct { Role string `json:"role"` Password string `json:"password,omitempty"` } + +type LogDataResponse struct { + LogData LogData `json:"logData"` + Enabled bool `json:"enabled"` + LogTypes []string `json:"logTypes"` + Users map[string]string `json:"users"` +} + +type LogData struct { + Schema LogSchema `json:"schema"` + Data []LogRow `json:"rows"` + NextPos int64 `json:"nextPos"` +} +type LogSchema struct { + Columns map[string]string `json:"columns"` +} +type LogRow struct { + Timestamp string `json:"t"` + Data []string `json:"d"` +} diff --git a/pkg/rest/users_test.go b/pkg/rest/users_test.go index c95e920..c516691 100644 --- a/pkg/rest/users_test.go +++ b/pkg/rest/users_test.go @@ -13,7 +13,7 @@ import ( "strings" "testing" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" "github.com/in4it/wireguard-server/pkg/users" "github.com/in4it/wireguard-server/pkg/wireguard" ) @@ -32,6 +32,11 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { w.Write([]byte("OK")) return } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } w.WriteHeader(http.StatusBadRequest) default: w.WriteHeader(http.StatusBadRequest) @@ -45,7 +50,7 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { defer l.Close() // first create a new user - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} c, err := newContext(storage, SERVER_TYPE_VPN) if err != nil { @@ -156,7 +161,7 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { func TestCreateUser(t *testing.T) { // first create a new user - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} c, err := newContext(storage, SERVER_TYPE_VPN) if err != nil { diff --git a/pkg/storage/iface.go b/pkg/storage/iface.go index 7cc8367..212b1a4 100644 --- a/pkg/storage/iface.go +++ b/pkg/storage/iface.go @@ -1,13 +1,17 @@ package storage +import "io" + type Iface interface { GetPath() string EnsurePath(path string) error EnsureOwnership(filename, login string) error ReadDir(name string) ([]string, error) Remove(name string) error + Rename(oldName, newName string) error AppendFile(name string, data []byte) error ReadWriter + Seeker } type ReadWriter interface { @@ -15,4 +19,10 @@ type ReadWriter interface { WriteFile(name string, data []byte) error FileExists(filename string) bool ConfigPath(filename string) string + OpenFile(name string) (io.ReadCloser, error) + OpenFileForWriting(name string) (io.WriteCloser, error) +} + +type Seeker interface { + OpenFilesFromPos(names []string, pos int64) ([]io.ReadCloser, error) } diff --git a/pkg/storage/local/path.go b/pkg/storage/local/path.go index 03dee39..f1e8d8b 100644 --- a/pkg/storage/local/path.go +++ b/pkg/storage/local/path.go @@ -81,3 +81,7 @@ func (l *LocalStorage) ReadDir(pathname string) ([]string, error) { func (l *LocalStorage) Remove(name string) error { return os.Remove(path.Join(l.path, name)) } + +func (l *LocalStorage) Rename(oldName, newName string) error { + return os.Rename(path.Join(l.path, oldName), path.Join(l.path, newName)) +} diff --git a/pkg/storage/local/read.go b/pkg/storage/local/read.go index 435f273..3a2e6ce 100644 --- a/pkg/storage/local/read.go +++ b/pkg/storage/local/read.go @@ -1,6 +1,8 @@ package localstorage import ( + "fmt" + "io" "os" "path" ) @@ -8,3 +10,39 @@ import ( func (l *LocalStorage) ReadFile(name string) ([]byte, error) { return os.ReadFile(path.Join(l.path, name)) } + +func (l *LocalStorage) OpenFilesFromPos(names []string, pos int64) ([]io.ReadCloser, error) { + readers := []io.ReadCloser{} + if pos < 0 { + return readers, nil + } + for _, name := range names { + file, err := os.Open(path.Join(l.path, name)) + if err != nil { + return nil, fmt.Errorf("cannot open file (%s): %s", name, err) + } + stat, err := file.Stat() + if err != nil { + return nil, fmt.Errorf("cannot get file stat (%s): %s", name, err) + } + if stat.Size() <= pos { + pos -= stat.Size() + } else { + _, err := file.Seek(pos, 0) + if err != nil { + return nil, fmt.Errorf("could not seek to pos (file: %s): %s", name, err) + } + pos = 0 + readers = append(readers, file) + } + } + return readers, nil +} + +func (l *LocalStorage) OpenFile(name string) (io.ReadCloser, error) { + file, err := os.Open(path.Join(l.path, name)) + if err != nil { + return nil, fmt.Errorf("cannot open file (%s): %s", name, err) + } + return file, nil +} diff --git a/pkg/storage/local/read_test.go b/pkg/storage/local/read_test.go new file mode 100644 index 0000000..c40e6b4 --- /dev/null +++ b/pkg/storage/local/read_test.go @@ -0,0 +1,91 @@ +package localstorage + +import ( + "bytes" + "io" + "os" + "path" + "testing" +) + +func TestOpenFilesFromPos(t *testing.T) { + pwd, err := os.Executable() + if err != nil { + t.Fatalf("os Executable error: %s", err) + } + l := LocalStorage{ + path: path.Dir(pwd), + } + contents1 := []byte(`this is the first file`) + contents2 := []byte(`this is the second file`) + err = l.WriteFile("1.txt", contents1) + if err != nil { + t.Fatalf("write file error: %s", err) + } + err = l.WriteFile("2.txt", contents2) + if err != nil { + t.Fatalf("write file error: %s", err) + } + t.Cleanup(func() { + err = os.Remove(path.Join(l.path, "1.txt")) + if err != nil { + t.Fatalf("file delete error: %s", err) + } + err = os.Remove(path.Join(l.path, "2.txt")) + if err != nil { + t.Fatalf("file delete error: %s", err) + } + }) + expected := []string{ + "this is the first filethis is the second file", + "is the first filethis is the second file", + "this is the second file", + "ethis is the second file", + "his is the second file", + "", + "", + "", + } + expextedOpenFiles := []int{ + 2, + 2, + 1, + 2, + 1, + 0, + 0, + 0, + } + tests := []int64{ + 0, + 5, + int64(len(contents1)), + int64(len(contents1) - 1), + int64(len(contents1) + 1), + int64(len(contents1) + len(contents2)), + int64(len(contents1) + len(contents2) + 1), + -5, + } + for k, pos := range tests { + files, err := l.OpenFilesFromPos([]string{"1.txt", "2.txt"}, pos) + if err != nil { + t.Fatalf("open file error: %s", err) + } + contents := bytes.NewBuffer([]byte{}) + for _, file := range files { + defer file.Close() + body, err := io.ReadAll(file) + if err != nil { + t.Fatalf("could not read file: %s", err) + } + contents.Write(body) + } + if expected[k] != contents.String() { + t.Fatalf("unexpected output: expected '%s' got '%s'", expected[k], contents.String()) + } + if expextedOpenFiles[k] != len(files) { + t.Fatalf("unexpected open files: expected %d got %d", expextedOpenFiles[k], len(files)) + } + } + +} diff --git a/pkg/storage/local/write.go b/pkg/storage/local/write.go index 00e135c..91a88b1 100644 --- a/pkg/storage/local/write.go +++ b/pkg/storage/local/write.go @@ -1,6 +1,8 @@ package localstorage import ( + "fmt" + "io" "os" "path" ) @@ -21,3 +23,11 @@ func (l *LocalStorage) AppendFile(name string, data []byte) error { return nil } + +func (l *LocalStorage) OpenFileForWriting(name string) (io.WriteCloser, error) { + file, err := os.Create(path.Join(l.path, name)) + if err != nil { + return nil, fmt.Errorf("cannot open file (%s): %s", name, err) + } + return file, nil +} diff --git a/pkg/storage/memory/storage.go b/pkg/storage/memory/storage.go new file mode 100644 index 0000000..2246e63 --- /dev/null +++ b/pkg/storage/memory/storage.go @@ -0,0 +1,142 @@ +package memorystorage + +import ( + "bufio" + "bytes" + "fmt" + "io" + "os" + "path" + "strings" +) + +type MyWriteCloser struct { + *bufio.Writer +} + +func (mwc *MyWriteCloser) Close() error { + return nil +} + +type MockReadWriterData []byte + +func (m *MockReadWriterData) Close() error { + return nil +} +func (m *MockReadWriterData) Write(p []byte) (nn int, err error) { + *m = append(*m, p...) + return len(p), nil +} + +type MockMemoryStorage struct { + Data map[string]*MockReadWriterData +} + +func (m *MockMemoryStorage) ConfigPath(filename string) string { + return path.Join("config", filename) +} +func (m *MockMemoryStorage) Rename(oldName, newName string) error { + if m.Data == nil { + m.Data = make(map[string]*MockReadWriterData) + } + _, ok := m.Data[oldName] + if !ok { + return fmt.Errorf("file doesn't exist") + } + m.Data[newName] = m.Data[oldName] + delete(m.Data, oldName) + return nil +} +func (m *MockMemoryStorage) FileExists(name string) bool { + if m.Data == nil { + m.Data = make(map[string]*MockReadWriterData) + } + _, ok := m.Data[name] + return ok +} + +func (m *MockMemoryStorage) ReadFile(name string) ([]byte, error) { + if m.Data == nil { + m.Data = make(map[string]*MockReadWriterData) + } + val, ok := m.Data[name] + if !ok { + return nil, fmt.Errorf("file does not exist") + } + return *val, nil +} +func (m *MockMemoryStorage) WriteFile(name string, data []byte) error { + if m.Data == nil { + m.Data = make(map[string]*MockReadWriterData) + } + m.Data[name] = (*MockReadWriterData)(&data) + return nil +} +func (m *MockMemoryStorage) AppendFile(name string, data []byte) error { + if m.Data == nil { + m.Data = make(map[string]*MockReadWriterData) + } + if m.Data[name] == nil { + m.Data[name] = (*MockReadWriterData)(&data) + } else { + *m.Data[name] = append(*m.Data[name], data...) + } + + return nil +} + +func (m *MockMemoryStorage) GetPath() string { + pwd, _ := os.Executable() + return path.Dir(pwd) +} + +func (m *MockMemoryStorage) EnsurePath(pathname string) error { + return nil +} + +func (m *MockMemoryStorage) EnsureOwnership(filename, login string) error { + return nil +} + +func (m *MockMemoryStorage) ReadDir(path string) ([]string, error) { + if m.Data == nil { + m.Data = make(map[string]*MockReadWriterData) + } + res := []string{} + for k := range m.Data { + if strings.HasPrefix(k, path+"/") { + res = append(res, strings.ReplaceAll(k, path+"/", "")) + } + } + return res, nil +} + +func (m *MockMemoryStorage) Remove(name string) error { + if m.Data == nil { + m.Data = make(map[string]*MockReadWriterData) + } + delete(m.Data, name) + return nil +} + +func (m *MockMemoryStorage) OpenFilesFromPos(names []string, pos int64) ([]io.ReadCloser, error) { + return nil, fmt.Errorf("not implemented") +} +func (m *MockMemoryStorage) OpenFile(name string) (io.ReadCloser, error) { + if m.Data == nil { + m.Data = make(map[string]*MockReadWriterData) + } + val, ok := m.Data[name] + if !ok { + return nil, fmt.Errorf("file does not exist") + } + + return io.NopCloser(bytes.NewBuffer(*val)), nil +} +func (m *MockMemoryStorage) OpenFileForWriting(name string) (io.WriteCloser, error) { + if m.Data == nil { + m.Data = make(map[string]*MockReadWriterData) + } + m.Data[name] = (*MockReadWriterData)(&[]byte{}) + return m.Data[name], nil +} diff --git a/pkg/testing/mocks/storage.go b/pkg/testing/mocks/storage.go deleted file mode 100644 index 774823a..0000000 --- a/pkg/testing/mocks/storage.go +++ /dev/null @@ -1,115 +0,0 @@ -package testingmocks - -import ( - "fmt" - "os" - "path" - "strings" -) - -type MockReadWriter struct { - Data map[string][]byte -} - -func (m *MockReadWriter) ConfigPath(filename string) string { - return path.Join("config", filename) -} -func (m *MockReadWriter) FileExists(name string) bool { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - _, ok := m.Data[name] - return ok -} - -func (m *MockReadWriter) ReadFile(name string) ([]byte, error) { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - val, ok := m.Data[name] - if !ok { - return val, fmt.Errorf("file does not exist") - } - return val, nil -} -func (m *MockReadWriter) WriteFile(name string, data []byte) error { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - m.Data[name] = data - return nil -} - -type MockMemoryStorage struct { - Data map[string][]byte -} - -func (m *MockMemoryStorage) ConfigPath(filename string) string { - return path.Join("config", filename) -} -func (m *MockMemoryStorage) FileExists(name string) bool { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - _, ok := m.Data[name] - return ok -} - -func (m *MockMemoryStorage) ReadFile(name string) ([]byte, error) { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - val, ok := m.Data[name] - if !ok { - return val, fmt.Errorf("file does not exist") - } - return val, nil -} -func (m *MockMemoryStorage) WriteFile(name string, data []byte) error { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - m.Data[name] = data - return nil -} -func (m *MockMemoryStorage) AppendFile(name string, data []byte) error { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - m.Data[name] = append(m.Data[name], data...) - return nil -} - -func (m *MockMemoryStorage) GetPath() string { - pwd, _ := os.Executable() - return path.Dir(pwd) -} - -func (m *MockMemoryStorage) EnsurePath(pathname string) error { - return nil -} - -func (m *MockMemoryStorage) EnsureOwnership(filename, login string) error { - return nil -} - -func (m *MockMemoryStorage) ReadDir(path string) ([]string, error) { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - res := []string{} - for k := range m.Data { - if strings.HasPrefix(k, path+"/") { - res = append(res, strings.ReplaceAll(k, path+"/", "")) - } - } - return res, nil -} - -func (m *MockMemoryStorage) Remove(name string) error { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - delete(m.Data, name) - return nil -} diff --git a/pkg/utils/date/compare.go b/pkg/utils/date/compare.go new file mode 100644 index 0000000..b6a901e --- /dev/null +++ b/pkg/utils/date/compare.go @@ -0,0 +1,9 @@ +package dateutils + +import "time" + +func DateEqual(date1, date2 time.Time) bool { + y1, m1, d1 := date1.Date() + y2, m2, d2 := date2.Date() + return y1 == y2 && m1 == m2 && d1 == d2 +} diff --git a/pkg/wireguard/cache.go b/pkg/wireguard/cache.go new file mode 100644 index 0000000..2214540 --- /dev/null +++ b/pkg/wireguard/cache.go @@ -0,0 +1,32 @@ +package wireguard + +import ( + "fmt" + "net" +) + +func UpdateClientCache(peerConfig PeerConfig, clientCache *ClientCache) error { + _, peerConfigAddressParsed, err := net.ParseCIDR(peerConfig.Address) + if err != nil { + return fmt.Errorf("cannot parse peerConfig's address: %s", err) + } + found := false + for k, addressesItem := range clientCache.Addresses { + if addressesItem.ClientID == peerConfig.ID { + found = true + if addressesItem.Address.String() != peerConfig.Address { + clientCache.Addresses[k].Address = *peerConfigAddressParsed + return nil + } + } + } + + if !found { + clientCache.Addresses = append(clientCache.Addresses, ClientCacheAddresses{ + Address: *peerConfigAddressParsed, + ClientID: peerConfig.ID, + }) + } + + return nil +} diff --git a/pkg/wireguard/constants.go b/pkg/wireguard/constants.go index e2bce0e..26beba2 100644 --- a/pkg/wireguard/constants.go +++ b/pkg/wireguard/constants.go @@ -8,6 +8,7 @@ const VPN_CONFIG_NAME = "vpn-config.json" const IP_LIST_PATH = "config/iplist.json" const VPN_CLIENTS_DIR = "clients" const VPN_STATS_DIR = "stats" +const VPN_PACKETLOGGER_DIR = "packetlogs" const VPN_SERVER_SECRETS_PATH = "secrets" const VPN_PRIVATE_KEY_FILENAME = "priv.key" const PRESHARED_KEY_FILENAME = "preshared.key" diff --git a/pkg/wireguard/linux/syncclients/process.go b/pkg/wireguard/linux/syncclients/process.go index b310e74..0da83f2 100644 --- a/pkg/wireguard/linux/syncclients/process.go +++ b/pkg/wireguard/linux/syncclients/process.go @@ -4,30 +4,24 @@ package processpeerconfig import ( - "encoding/json" "fmt" "log" - "path" "github.com/in4it/wireguard-server/pkg/storage" "github.com/in4it/wireguard-server/pkg/wireguard" wireguardlinux "github.com/in4it/wireguard-server/pkg/wireguard/linux" ) -func SyncClients(storage storage.Iface, filename string) error { - peerConfig, peerConfigFilename, err := getClientFile(storage, filename) +func SyncClients(storage storage.Iface, peerConfig wireguard.PeerConfig) error { + err := processPeerConfig(storage, peerConfig) if err != nil { - return fmt.Errorf("getClientFile error: %s", err) - } - err = processPeerConfig(storage, peerConfig) - if err != nil { - return fmt.Errorf("could not process peerconfig (%s): %s", peerConfigFilename, err) + return fmt.Errorf("could not process peerconfig (%s): %s", peerConfig.ID, err) } return nil } -func SyncClientsAndCleanup(storage storage.Iface, filename string) { - if err := SyncClients(storage, filename); err != nil { +func SyncClientsAndCleanup(storage storage.Iface, peerConfig wireguard.PeerConfig) { + if err := SyncClients(storage, peerConfig); err != nil { returnErrorInGoRoutine(err) return } @@ -37,34 +31,14 @@ func SyncClientsAndCleanup(storage storage.Iface, filename string) { } } -func DeleteClient(storage storage.Iface, filename string) { - peerConfig, peerConfigFilename, err := getClientFile(storage, filename) - if err != nil { - returnErrorInGoRoutine(fmt.Errorf("getClientFile error: %s", err)) - return - } - err = processDeleteOfPeerConfig(peerConfig) +func DeleteClient(peerConfig wireguard.PeerConfig) { + err := processDeleteOfPeerConfig(peerConfig) if err != nil { - returnErrorInGoRoutine(fmt.Errorf("could not process delete of peerconfig (%s): %s", peerConfigFilename, err)) + returnErrorInGoRoutine(fmt.Errorf("could not process delete of peerconfig (%s): %s", peerConfig.ID, err)) return } } -func getClientFile(storage storage.Iface, filename string) (wireguard.PeerConfig, string, error) { - var peerConfig wireguard.PeerConfig - - peerConfigFilename := storage.ConfigPath(path.Join(wireguard.VPN_CLIENTS_DIR, filename)) - peerConfigData, err := storage.ReadFile(peerConfigFilename) - if err != nil { - return peerConfig, peerConfigFilename, fmt.Errorf("could not read clients filename (%s): %s", peerConfigFilename, err) - } - err = json.Unmarshal(peerConfigData, &peerConfig) - if err != nil { - return peerConfig, peerConfigFilename, fmt.Errorf("could not read unmarshal peerconfig(%s): %s", peerConfigFilename, err) - } - return peerConfig, peerConfigFilename, nil -} - func processPeerConfig(storage storage.Iface, peerConfig wireguard.PeerConfig) error { c, available, err := wireguardlinux.New() if err != nil { diff --git a/pkg/wireguard/packetlogger.go b/pkg/wireguard/packetlogger.go new file mode 100644 index 0000000..4d977fb --- /dev/null +++ b/pkg/wireguard/packetlogger.go @@ -0,0 +1,383 @@ +package wireguard + +import ( + "bufio" + "bytes" + "compress/gzip" + "encoding/binary" + "fmt" + "io" + "net" + "net/http" + "os" + "path" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "github.com/gopacket/gopacket" + "github.com/gopacket/gopacket/layers" + "github.com/in4it/wireguard-server/pkg/logging" + "github.com/in4it/wireguard-server/pkg/storage" + dateutils "github.com/in4it/wireguard-server/pkg/utils/date" + "github.com/packetcap/go-pcap" + "golang.org/x/sys/unix" +) + +var ( + PacketLoggerIsRunning sync.Mutex +) + +func RunPacketLogger(storage storage.Iface, clientCache *ClientCache, vpnConfig *VPNConfig) { + if !vpnConfig.EnablePacketLogs { + return + } + fmt.Printf("starting packetlogger") + // ensure we only run a single instance of the packet logger + PacketLoggerIsRunning.Lock() + defer PacketLoggerIsRunning.Unlock() + // ensure logs dir is created + err := storage.EnsurePath(VPN_STATS_DIR) + if err != nil { + logging.ErrorLog(fmt.Errorf("could not create stats path: %s. Stats disabled", err)) + return + } + err = storage.EnsureOwnership(VPN_STATS_DIR, "vpn") + if err != nil { + logging.ErrorLog(fmt.Errorf("could not ensure ownership of stats path: %s. Stats disabled", err)) + return + } + err = storage.EnsurePath(path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR)) + if err != nil { + logging.ErrorLog(fmt.Errorf("could not create stats path: %s. Stats disabled", err)) + return + } + err = storage.EnsureOwnership(path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR), "vpn") + if err != nil { + logging.ErrorLog(fmt.Errorf("could not ensure ownership of stats path: %s. Stats disabled", err)) + return + } + + useSyscalls := false + if runtime.GOOS == "darwin" { + useSyscalls = true + } + handle, err := pcap.OpenLive(VPN_INTERFACE_NAME, 1600, false, 0, useSyscalls) + if err != nil { + logging.ErrorLog(fmt.Errorf("can't start packet inspector: %s", err)) + return + } + defer handle.Close() + i := 0 + for { + err := readPacket(storage, handle, clientCache) + if err != nil { + logging.DebugLog(fmt.Errorf("readPacket error: %s", err)) + } + if !vpnConfig.EnablePacketLogs { + logging.InfoLog("disabling packetlogs") + return + } + if i%1000 == 0 { + if err := checkDiskSpace(); err != nil { + logging.ErrorLog(fmt.Errorf("disk space error: %s", err)) + return + } + i = 0 + } + i++ + } +} +func readPacket(storage storage.Iface, handle *pcap.Handle, clientCache *ClientCache) error { + data, _, err := handle.ReadPacketData() + if err != nil { + return fmt.Errorf("read packet error: %s", err) + } + return parsePacket(storage, data, clientCache) +} +func parsePacket(storage storage.Iface, data []byte, clientCache *ClientCache) error { + packet := gopacket.NewPacket(data, layers.IPProtocolIPv4, gopacket.DecodeOptions{Lazy: true, DecodeStreamsAsDatagrams: true}) + var ( + ip4 *layers.IPv4 + ip6 *layers.IPv6 + srcIP net.IP + dstIP net.IP + ) + + if ipv4Layer := packet.Layer(layers.LayerTypeIPv4); ipv4Layer != nil { + ip4 = ipv4Layer.(*layers.IPv4) + srcIP = ip4.SrcIP + dstIP = ip4.DstIP + } + if ipv6Layer := packet.Layer(layers.LayerTypeIPv6); ipv6Layer != nil { + ip6 = ipv6Layer.(*layers.IPv6) + srcIP = ip6.SrcIP + dstIP = ip6.DstIP + } + if ip4 == nil && ip6 == nil { + return fmt.Errorf("got packet which is not ipv4/ipv6") + } + + clientID := "" + for _, address := range clientCache.Addresses { + if address.Address.Contains(srcIP) { + clientID = address.ClientID + } + } + if clientID == "" { // doesn't match a client ID + return nil + } + now := time.Now() + filename := path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR, clientID+"-"+now.Format("2006-01-02")+".log") + if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil { + tcpPacket, _ := tcpLayer.(*layers.TCP) + if tcpPacket.SYN { + storage.AppendFile(filename, []byte(strings.Join([]string{ + time.Now().Format(TIMESTAMP_FORMAT), + "tcp", + srcIP.String(), + dstIP.String(), + strconv.FormatUint(uint64(tcpPacket.SrcPort), 10), + strconv.FormatUint(uint64(tcpPacket.DstPort), 10)}, + ",")+"\n", + )) + } + switch tcpPacket.DstPort { + case 80: + if tcpPacket.DstPort == 80 { + appLayer := packet.ApplicationLayer() + if appLayer != nil { + req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(appLayer.Payload()))) + if err != nil { + fmt.Printf("debug: can't parse http packet: %s", err) + } else { + storage.AppendFile(filename, []byte(strings.Join([]string{ + time.Now().Format(TIMESTAMP_FORMAT), + "http", + srcIP.String(), + dstIP.String(), + strconv.FormatUint(uint64(tcpPacket.SrcPort), 10), + strconv.FormatUint(uint64(tcpPacket.DstPort), 10), + "http://" + req.Host + req.URL.RequestURI()}, + ",")+"\n", + )) + } + } + } + case 443: + if tls, ok := packet.Layer(layers.LayerTypeTLS).(*layers.TLS); ok { + for _, handshake := range tls.Handshake { + if sni := parseTLSExtensionSNI([]byte(handshake.ClientHello.Extensions)); sni != nil { + storage.AppendFile(filename, []byte(strings.Join([]string{ + time.Now().Format(TIMESTAMP_FORMAT), + "https", + srcIP.String(), + dstIP.String(), + strconv.FormatUint(uint64(tcpPacket.SrcPort), 10), + strconv.FormatUint(uint64(tcpPacket.DstPort), 10), + string(sni)}, + ",")+"\n", + )) + } + } + } + } + } + if udpLayer := packet.Layer(layers.LayerTypeUDP); udpLayer != nil { + udp, _ := udpLayer.(*layers.UDP) + + if udp.NextLayerType().Contains(layers.LayerTypeDNS) { + dnsPacket := packet.Layer(layers.LayerTypeDNS) + if dnsPacket != nil { + udpDNS := dnsPacket.(*layers.DNS) + questions := []string{} + for k := range udpDNS.Questions { + found := false + for _, question := range questions { + if question == string(udpDNS.Questions[k].Name) { + found = true + } + } + if !found { + questions = append(questions, string(udpDNS.Questions[k].Name)) + } + + } + storage.AppendFile(filename, []byte(strings.Join([]string{ + time.Now().Format(TIMESTAMP_FORMAT), + "udp", + srcIP.String(), + dstIP.String(), + strconv.FormatUint(uint64(udp.SrcPort), 10), + strconv.FormatUint(uint64(udp.DstPort), 10), + strings.Join(questions, "#")}, + ",")+"\n")) + } + } + } + + return nil +} + +// TLS Extensions http://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml +type TLSExtension uint16 + +const ( + ExtServerName TLSExtension = 0 + ExtMaxFragLen TLSExtension = 1 + ExtClientCertURL TLSExtension = 2 + ExtTrustedCAKeys TLSExtension = 3 + ExtTruncatedHMAC TLSExtension = 4 + ExtStatusRequest TLSExtension = 5 + ExtUserMapping TLSExtension = 6 + ExtClientAuthz TLSExtension = 7 + ExtServerAuthz TLSExtension = 8 + ExtCertType TLSExtension = 9 + ExtSupportedGroups TLSExtension = 10 + ExtECPointFormats TLSExtension = 11 + ExtSRP TLSExtension = 12 + ExtSignatureAlgs TLSExtension = 13 + ExtUseSRTP TLSExtension = 14 + ExtHeartbeat TLSExtension = 15 + ExtALPN TLSExtension = 16 + ExtStatusRequestV2 TLSExtension = 17 + ExtSignedCertTS TLSExtension = 18 + ExtClientCertType TLSExtension = 19 + ExtServerCertType TLSExtension = 20 + ExtPadding TLSExtension = 21 + ExtEncryptThenMAC TLSExtension = 22 + ExtExtendedMasterSecret TLSExtension = 23 + ExtSessionTicket TLSExtension = 35 + ExtNPN TLSExtension = 13172 + ExtRenegotiationInfo TLSExtension = 65281 +) + +func parseTLSExtensionSNI(data []byte) []byte { + for len(data) > 0 { + if len(data) < 4 { + break + } + extensionType := binary.BigEndian.Uint16(data[:2]) + length := binary.BigEndian.Uint16(data[2:4]) + if len(data) < 4+int(length) { + break + } + if TLSExtension(extensionType) == ExtServerName && len(data) > 6 { + serverNameExtensionLength := binary.BigEndian.Uint16(data[4:6]) + entryType := data[6] + + if serverNameExtensionLength > 0 && entryType == 0 && len(data) > 8 { // 0 = DNS hostname + hostnameLength := binary.BigEndian.Uint16(data[7:9]) + if len(data) > int(8+hostnameLength) { + return data[9 : 9+hostnameLength] + } + } + } + data = data[4+length:] + } + return nil +} + +func checkDiskSpace() error { + var stat unix.Statfs_t + + wd, err := os.Getwd() + if err != nil { + fmt.Printf("cannot get cwd: %s", err) + } + unix.Statfs(wd, &stat) + if stat.Blocks*uint64(stat.Bsize) == 0 { + return fmt.Errorf("no blocks available") + } + freeDiskSpace := float64(stat.Bfree) / float64(stat.Blocks) + if freeDiskSpace < 0.10 { + return fmt.Errorf("not enough disk free disk space: %f", freeDiskSpace) + } + + return nil +} + +// Packet log rotation +func PacketLoggerLogRotation(storage storage.Iface) { + for { + err := packetLoggerLogRotation(storage) + if err != nil { + logging.ErrorLog(fmt.Errorf("packet logger log rotation error: %s", err)) + } + time.Sleep(getTimeUntilTomorrowStartOfDay()) // sleep until tomorrow + } +} + +func getTimeUntilTomorrowStartOfDay() time.Duration { + tomorrow := time.Now().AddDate(0, 0, 1) + tomorrowStartOfDay := time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), 0, 5, 0, 0, time.Local) + return time.Until(tomorrowStartOfDay) +} + +func packetLoggerLogRotation(storage storage.Iface) error { + logDir := path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR) + files, err := storage.ReadDir(logDir) + if err != nil { + return fmt.Errorf("readDir error: %s", err) + } + for _, filename := range files { + filenameSplit := strings.Split(strings.TrimSuffix(filename, ".log"), "-") + if len(filenameSplit) > 3 { + dateParsed, err := time.Parse("2006-01-02", strings.Join(filenameSplit[len(filenameSplit)-3:], "-")) + if err == nil { + if !dateutils.DateEqual(dateParsed, time.Now()) { + err := packetLoggerCompressLog(storage, filename) + if err != nil { + return fmt.Errorf("rotate log error: %s", err) + } + err = packetLoggerRenameLog(storage, filename) + if err != nil { + return fmt.Errorf("rotate log error (rename): %s", err) + } + } + + } + } + } + return nil +} + +func packetLoggerCompressLog(storage storage.Iface, filename string) error { + reader, err := storage.OpenFile(path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR, filename)) + if err != nil { + return fmt.Errorf("open file error (%s): %s", filename, err) + } + writer, err := storage.OpenFileForWriting(path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR, filename+".gz.tmp")) + if err != nil { + return fmt.Errorf("write file error (%s): %s", filename+".gz.tmp", err) + } + defer reader.Close() + defer writer.Close() + + gzipWriter, err := gzip.NewWriterLevel(writer, gzip.DefaultCompression) + if err != nil { + return fmt.Errorf("gzip writer error: %s", err) + } + _, err = io.Copy(gzipWriter, reader) + if err != nil { + return fmt.Errorf("copy error: %s", err) + } + err = gzipWriter.Close() + if err != nil { + return fmt.Errorf("file close error (gzip): %s", err) + } + return nil +} +func packetLoggerRenameLog(storage storage.Iface, filename string) error { + err := storage.Rename(path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR, filename+".gz.tmp"), path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR, filename+".gz")) + if err != nil { + return fmt.Errorf("rename error: %s", err) + } + err = storage.Remove(path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR, filename)) + if err != nil { + return fmt.Errorf("delete log error: %s", err) + } + return nil +} diff --git a/pkg/wireguard/packetlogger_test.go b/pkg/wireguard/packetlogger_test.go new file mode 100644 index 0000000..edc00bb --- /dev/null +++ b/pkg/wireguard/packetlogger_test.go @@ -0,0 +1,284 @@ +package wireguard + +import ( + "bytes" + "compress/gzip" + "encoding/hex" + "fmt" + "io" + "net" + "os" + "path" + "strings" + "testing" + "time" + + localstorage "github.com/in4it/wireguard-server/pkg/storage/local" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" + dateutils "github.com/in4it/wireguard-server/pkg/utils/date" +) + +func TestParsePacket(t *testing.T) { + storage := &memorystorage.MockMemoryStorage{} + clientCache := &ClientCache{ + Addresses: []ClientCacheAddresses{ + { + Address: net.IPNet{ + IP: net.ParseIP("10.189.184.2"), + Mask: net.IPMask(net.ParseIP("255.255.255.255").To4()), + }, + ClientID: "1-2-3-4", + }, + { + Address: net.IPNet{ + IP: net.ParseIP("10.189.184.3"), + Mask: net.IPMask(net.ParseIP("255.255.255.255").To4()), + }, + ClientID: "1-2-3-5", + }, + }, + } + input := []string{ + // DNS reqs + "45000037e04900004011cdab0abdb8020a000002e60d00350023d6861e1501000001000000000000056170706c6503636f6d0000010001", + "4500004092d1000040111b1b0abdb8020a000002c73b0035002c4223b28e01000001000000000000037777770a676f6f676c656170697303636f6d0000410001", + "450000e300004000fe11af480a0000020abdb8020035dbb500cffccbad65818000010000000100000975732d656173742d310470726f6402707209616e616c797469637307636f6e736f6c65036177730361327a03636f6d00001c00010975732d656173742d310470726f6402707209616e616c797469637307636f6e736f6c65036177730361327a03636f6d00000600010000014b004b076e732d3136333709617773646e732d313202636f02756b0011617773646e732d686f73746d617374657206616d617a6f6e03636f6d000000000100001c20000003840012750000015180", + "450000a100004000fe11af8a0a0000020abdb8020035e136008db8bd155f81830001000000010000026462075f646e732d7364045f756470086174746c6f63616c036e657400000c0001c01c00060001000003c0004b046f726375026f72026272026e7007656c732d676d7303617474c0250d726d2d686f73746d617374657203656d730361747403636f6d0000000001000151800000271000093a8000015180", + // http req (SYN + Data) + "450000400000400040066ced0abdb8020a00010cc7b000507216cbdd00000000b0c2ffff008f000002040564010303060101080a69fbf8410000000004020000", + "450200810000400040066caa0abdb8020a00010cc7b000507216cbde4845afad80180804449900000101080a69fbf873eddf46d7474554202f6c6f67696e20485454502f312e310d0a486f73743a2031302e302e312e31320d0a557365722d4167656e743a206375726c2f382e372e310d0a4163636570743a202a2f2a0d0a0d0a", + // https req + "450000400000400040066ced0abdb8020a00010cf24a01bb510f111000000000b0c2ffffe119000002040564010303060101080a327dff040000000004020000", + "450000340000400040066cf90abdb8020a00010cf24a01bb510f1111c4b4fb4b801008046b8700000101080a327dff34edeeff9e", + "4502017d0000400040066bae0abdb8020a00010cf24a01bb510f1111c4b4fb4b801808041b1500000101080a327dff36edeeff9e1603010144010001400303e3b233de9dcd3f71f4c6e3d0d45ec25144e2fcdf8c676e52ff5cfc021123786020056eefe25e5b4e9abec2953b5fa9bc1f68dd09d7ad4ddce858476b4aaaa029b80062130313021301cca9cca8ccaac030c02cc028c024c014c00a009f006b0039ff8500c400880081009d003d003500c00084c02fc02bc027c023c013c009009e0067003300be0045009c003c002f00ba0041c011c00700050004c012c0080016000a00ff01000095002b0009080304030303020301003300260024001d0020dc2b5e4f0741b2ff9982fe2bfa6641e22fe80e5b50811780b82aafae96570c2400000018001600001376706e2d7365727665722e696e3469742e696f000b00020100000a000a0008001d001700180019000d00180016080606010603080505010503080404010403020102030010000e000c02683208687474702f312e31", + "450000340000400040066cf90abdb8020a00010cf24a01bb510f125ac4b500a3801007ee649700000101080a327dff66edeeffd1", + "450000340000400040066cf90abdb8020a00010cf24a01bb510f125ac4b504a1801007f0609600000101080a327dff67edeeffd1", + "4502003a0000400040066cf10abdb8020a00010cf24a01bb510f125ac4b504a180180800487100000101080a327dff6aedeeffd1140303000101", + "450201050000400040066c260abdb8020a00010cf24a01bb510f1260c4b504a180180800ea1300000101080a327dffc0edef002a1703030035131e32cc93174219580748842686d43e1cbb73501f643eaa49b3b7ba50a9f0a97e19ec926f8b5b141b363067d9a31061b146010d8f17030300511611c04909f5346b580fe1a95c68b2a62389ca6ed7e2f31ddb38cb191cf0997e16b5efaa9248a213e621869d071af7339ddafaee642953538a03d89cb3896ecf6756f5fb80f1866671282da72dce691169170303003c3bd012039a27a373dd1b4e7509e0e9aaefc4cfae6adcae6f670501e2577e20c98233761878d9f64355a89aa389f56480517bada888a2625ef211cb5e", + } + now := time.Now() + for _, s := range input { + + data, err := hex.DecodeString(s) + if err != nil { + t.Fatalf("hex decode error: %s", err) + } + err = parsePacket(storage, data, clientCache) + if err != nil { + t.Fatalf("parse error: %s", err) + } + } + + out, err := storage.ReadFile(path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR, "1-2-3-4-"+now.Format("2006-01-02")+".log")) + if err != nil { + t.Fatalf("read file error: %s", err) + } + if !strings.Contains(string(out), `,udp,10.189.184.2,10.0.0.2,58893,53,apple.com`) { + t.Fatalf("unexpected output. Expected udp record") + } + if !strings.Contains(string(out), `,https,10.189.184.2,10.0.1.12,62026,443,vpn-server.in4it.io`) { + t.Fatalf("unexpected output. Expected https record") + } +} + +func TestParsePacketSNI(t *testing.T) { + storage := &memorystorage.MockMemoryStorage{} + clientCache := &ClientCache{ + Addresses: []ClientCacheAddresses{ + { + Address: net.IPNet{ + IP: net.ParseIP("10.189.184.2"), + Mask: net.IPMask(net.ParseIP("255.255.255.255").To4()), + }, + ClientID: "1-2-3-4", + }, + { + Address: net.IPNet{ + IP: net.ParseIP("10.189.184.3"), + Mask: net.IPMask(net.ParseIP("255.255.255.255").To4()), + }, + ClientID: "1-2-3-5", + }, + }, + } + input := []string{ + `450000d100004000400682160abdb80240e9b468ec5001bb4f71ed891a93673d8018080468f400000101080a1329f7772c5410131603010098010000940301f1d62f57f05cc00fc8fb984e7fc381a26adc301ec143b9bab6d36f3f1b15c97200002ec014c00a0039ff850088008100350084c013c00900330045002f0041c011c00700050004c012c0080016000a00ff0100003d00000013001100000e7777772e676f6f676c652e636f6d000b00020100000a000a0008001d0017001800190010000e000c02683208687474702f312e31`, + } + now := time.Now() + for _, s := range input { + + data, err := hex.DecodeString(s) + if err != nil { + t.Fatalf("hex decode error: %s", err) + } + err = parsePacket(storage, data, clientCache) + if err != nil { + t.Fatalf("parse error: %s", err) + } + } + + out, err := storage.ReadFile(path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR, "1-2-3-4-"+now.Format("2006-01-02")+".log")) + if err != nil { + t.Fatalf("read file error: %s", err) + } + if !strings.Contains(string(out), `,https,10.189.184.2,64.233.180.104,60496,443,www.google.com`) { + t.Fatalf("unexpected output. Expected https record") + } +} + +func TestParseTLSExtensionSNI(t *testing.T) { + input := []string{ + "00000013001100000e7777772e676f6f676c652e636f6d000b00020100000a000a0008001d0017001800190010000e000c02683208687474702f312e31", + "00000018001600001376706e2d7365727665722e696e3469742e696f", + "00000018001600001376706e2d7365727665722e696e3469742e696f000b00020100000a000a0008001d001700180019000d00180016080606010603080505010503080404010403020102030010000e000c02683208687474702f312e31", + } + match := []string{ + "www.google.com", + "vpn-server.in4it.io", + "vpn-server.in4it.io", + } + for k, s := range input { + + data, err := hex.DecodeString(s) + if err != nil { + t.Fatalf("hex decode error: %s", err) + } + if sni := parseTLSExtensionSNI(data); sni != nil { + if string(sni) != match[k] { + t.Fatalf("got SNI, but expected different hostname. Got: %s", sni) + } + } else { + t.Fatalf("no SNI found") + } + } +} +func TestParseTLSExtensionSNINoMatch(t *testing.T) { + input := []string{ + "0010000e000c02", + "000d00180016080606010603080505010503080404010403020102030010000e000c02683208687474702f312e31", + "00", + } + for _, s := range input { + + data, err := hex.DecodeString(s) + if err != nil { + t.Fatalf("hex decode error: %s", err) + } + if sni := parseTLSExtensionSNI(data); sni != nil { + t.Fatalf("got match, expected no match. Got: %s", sni) + } + } +} + +func TestCheckDiskSpace(t *testing.T) { + err := checkDiskSpace() + if err != nil { + t.Fatalf("disk space error: %s", err) + } +} + +func TestPacketLoggerLogRotation(t *testing.T) { + prefix := path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR) + key1 := path.Join(prefix, fmt.Sprintf("1-2-3-4-%s.log", time.Now().AddDate(0, 0, -1).Format("2006-01-02"))) + value1 := []byte(time.Now().Format(TIMESTAMP_FORMAT) + `,https,10.189.184.2,64.233.180.104,60496,443,www.google.com`) + key2 := path.Join(prefix, fmt.Sprintf("1-2-3-4-%s.log", time.Now().Format("2006-01-02"))) + value2 := []byte(time.Now().Format(TIMESTAMP_FORMAT) + `,https,10.189.184.3,64.233.180.104,12345,443,www.google.com`) + + storage := &memorystorage.MockMemoryStorage{ + Data: map[string]*memorystorage.MockReadWriterData{}, + } + err := storage.WriteFile(key1, value1) + if err != nil { + t.Fatalf("write file error: %s", err) + } + err = storage.WriteFile(key2, value2) + if err != nil { + t.Fatalf("write file error: %s", err) + } + + err = packetLoggerLogRotation(storage) + if err != nil { + t.Fatalf("packetLoggerRotation error: %s", err) + } + body, err := storage.ReadFile(key1 + ".gz") + if err != nil { + t.Fatalf("can't read compressed file") + } + reader, err := gzip.NewReader(bytes.NewReader(body)) + if err != nil { + t.Fatalf("can't open gzip reader") + } + bodyDecoded, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("can't read gzip data") + } + if string(bodyDecoded) != string(value1) { + t.Fatalf("unexpected output. Got %s, expected: %s", bodyDecoded, value1) + } +} + +func TestPacketLoggerLogRotationLocalStorage(t *testing.T) { + prefix := path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR) + key1 := path.Join(prefix, fmt.Sprintf("1-2-3-4-%s.log", time.Now().AddDate(0, 0, -1).Format("2006-01-02"))) + value1 := []byte(time.Now().Format(TIMESTAMP_FORMAT) + `,https,10.189.184.2,64.233.180.104,60496,443,www.google.com`) + key2 := path.Join(prefix, fmt.Sprintf("1-2-3-4-%s.log", time.Now().Format("2006-01-02"))) + value2 := []byte(time.Now().Format(TIMESTAMP_FORMAT) + `,https,10.189.184.3,64.233.180.104,12345,443,www.google.com`) + + pwd, err := os.Executable() + if err != nil { + t.Fatalf("os Executable error: %s", err) + } + storage, err := localstorage.NewWithPath(path.Dir(pwd)) + if err != nil { + t.Fatalf("localstorage error: %s", err) + } + err = storage.EnsurePath(VPN_STATS_DIR) + if err != nil { + t.Fatalf("could not ensure path: %s", err) + } + storage.EnsurePath(path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR)) + if err != nil { + t.Fatalf("could not ensure path: %s", err) + } + err = storage.WriteFile(key1, value1) + if err != nil { + t.Fatalf("write file error: %s", err) + } + err = storage.WriteFile(key2, value2) + if err != nil { + t.Fatalf("write file error: %s", err) + } + t.Cleanup(func() { + os.Remove(path.Join(path.Dir(pwd), key1)) + os.Remove(path.Join(path.Dir(pwd), key1+".gz.tmp")) + os.Remove(path.Join(path.Dir(pwd), key1+".gz")) + os.Remove(path.Join(path.Dir(pwd), key2)) + }) + + err = packetLoggerLogRotation(storage) + if err != nil { + t.Fatalf("packetLoggerRotation error: %s", err) + } + body, err := storage.ReadFile(key1 + ".gz") + if err != nil { + t.Fatalf("can't read compressed file") + } + reader, err := gzip.NewReader(bytes.NewReader(body)) + if err != nil { + t.Fatalf("can't open gzip reader") + } + bodyDecoded, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("can't read gzip data") + } + if string(bodyDecoded) != string(value1) { + t.Fatalf("unexpected output. Got %s, expected: %s", bodyDecoded, value1) + } +} + +func TestGetTimeUntilTomorrowStartOfDay(t *testing.T) { + duration := getTimeUntilTomorrowStartOfDay() + if !dateutils.DateEqual(time.Now().Add(duration), time.Now().AddDate(0, 0, 1)) { + t.Fatalf("date is not tomorrow") + } +} diff --git a/pkg/wireguard/types.go b/pkg/wireguard/types.go index dcd4cc0..7632da3 100644 --- a/pkg/wireguard/types.go +++ b/pkg/wireguard/types.go @@ -1,6 +1,7 @@ package wireguard import ( + "net" "net/netip" "time" ) @@ -31,16 +32,18 @@ type VPNServerClient struct { } type VPNConfig struct { - AddressRange netip.Prefix `json:"addressRange"` - ClientAddressPrefix string `json:"clientAddressPrefix"` - PublicKey string `json:"publicKey"` - PresharedKey string `json:"presharedKey"` - Endpoint string `json:"endpoint"` - Port int `json:"port"` - ExternalInterface string `json:"externalInterface"` - Nameservers []string `json:"nameservers"` - DisableNAT bool `json:"disableNAT"` - ClientRoutes []string `json:"clientRoutes"` + AddressRange netip.Prefix `json:"addressRange"` + ClientAddressPrefix string `json:"clientAddressPrefix"` + PublicKey string `json:"publicKey"` + PresharedKey string `json:"presharedKey"` + Endpoint string `json:"endpoint"` + Port int `json:"port"` + ExternalInterface string `json:"externalInterface"` + Nameservers []string `json:"nameservers"` + DisableNAT bool `json:"disableNAT"` + ClientRoutes []string `json:"clientRoutes"` + EnablePacketLogs bool `json:"enablePacketLogs"` + PacketLogsTypes map[string]bool `json:"packetLogsTypes"` } type PubKeyExchange struct { @@ -71,3 +74,13 @@ type StatsEntry struct { ReceiveBytes int64 TransmitBytes int64 } + +// client cache + +type ClientCache struct { + Addresses []ClientCacheAddresses +} +type ClientCacheAddresses struct { + Address net.IPNet + ClientID string +} diff --git a/pkg/wireguard/vpnconfig.go b/pkg/wireguard/vpnconfig.go index 2d819a5..e56fe3b 100644 --- a/pkg/wireguard/vpnconfig.go +++ b/pkg/wireguard/vpnconfig.go @@ -44,16 +44,26 @@ func GetVPNConfig(storage storage.Iface) (VPNConfig, error) { if err != nil { return vpnConfig, fmt.Errorf("decode input error: %s", err) } + + if vpnConfig.PacketLogsTypes == nil { + vpnConfig.PacketLogsTypes = make(map[string]bool) + } + return vpnConfig, nil } func getEmptyVPNConfig() (VPNConfig, error) { - vpnConfig := VPNConfig{} + vpnConfig := VPNConfig{ + PacketLogsTypes: make(map[string]bool), + } return vpnConfig, nil } func CreateNewVPNConfig(storage storage.Iface) (VPNConfig, error) { - vpnConfig := VPNConfig{} + vpnConfig, err := getEmptyVPNConfig() + if err != nil { + return vpnConfig, fmt.Errorf("get empty vpn config error: %s", err) + } prefix, err := netip.ParsePrefix(DEFAULT_VPN_PREFIX) if err != nil { return vpnConfig, fmt.Errorf("ParsePrefix error: %s", err) @@ -135,6 +145,19 @@ func WriteVPNConfig(storage storage.Iface, vpnConfig VPNConfig) error { } } + // notify configmanager + client := http.Client{ + Timeout: 10 * time.Second, + } + + resp, err := client.Post("http://"+CONFIGMANAGER_URI+"/refresh-server-config", "application/json", nil) + if err != nil { + return fmt.Errorf("configmanager post error: %s", err) + } + if resp.StatusCode != http.StatusAccepted { + return fmt.Errorf("configmanager post error: received status code %d", resp.StatusCode) + } + return nil } diff --git a/pkg/wireguard/wireguardclientconfig.go b/pkg/wireguard/wireguardclientconfig.go index 09a29f6..e934852 100644 --- a/pkg/wireguard/wireguardclientconfig.go +++ b/pkg/wireguard/wireguardclientconfig.go @@ -70,12 +70,19 @@ func NewEmptyClientConfig(storage storage.Iface, userID string) (PeerConfig, err return PeerConfig{}, fmt.Errorf("getClientAllowedIPs error: %s", err) } + // validate address + address := nextFreeIP.String() + vpnConfig.ClientAddressPrefix + _, _, err = net.ParseCIDR(address) + if err != nil { + return PeerConfig{}, fmt.Errorf("cannot parse client address: %s", err) + } + peerConfig := PeerConfig{ ID: fmt.Sprintf("%s-%d", userID, newConfigNumber), DNS: strings.Join(vpnConfig.Nameservers, ", "), Name: fmt.Sprintf("connection-%d", newConfigNumber), - Address: nextFreeIP.String() + vpnConfig.ClientAddressPrefix, - ServerAllowedIPs: []string{nextFreeIP.String() + vpnConfig.ClientAddressPrefix}, + Address: address, + ServerAllowedIPs: []string{address}, ClientAllowedIPs: clientAllowedIPs, } @@ -161,10 +168,10 @@ func UpdateClientsConfig(storage storage.Iface) error { } func getPeerConfig(storage storage.Iface, connectionID string) (PeerConfig, error) { - return getPeerConfigByFilename(storage, fmt.Sprintf("%s.json", connectionID)) + return GetPeerConfigByFilename(storage, fmt.Sprintf("%s.json", connectionID)) } -func getPeerConfigByFilename(storage storage.Iface, filename string) (PeerConfig, error) { +func GetPeerConfigByFilename(storage storage.Iface, filename string) (PeerConfig, error) { var peerConfig PeerConfig peerConfigFilename := storage.ConfigPath(path.Join(VPN_CLIENTS_DIR, filename)) peerConfigBytes, err := storage.ReadFile(peerConfigFilename) @@ -187,7 +194,7 @@ func GetAllPeerConfigs(storage storage.Iface) ([]PeerConfig, error) { } peerConfigs := make([]PeerConfig, len(entries)) for k, entry := range entries { - peerConfig, err := getPeerConfigByFilename(storage, entry) + peerConfig, err := GetPeerConfigByFilename(storage, entry) if err != nil { return peerConfigs, fmt.Errorf("cnanot get peer config (%s): %s", entry, err) } diff --git a/pkg/wireguard/wireguardclientconfig_test.go b/pkg/wireguard/wireguardclientconfig_test.go index 5c1d346..d6d38c2 100644 --- a/pkg/wireguard/wireguardclientconfig_test.go +++ b/pkg/wireguard/wireguardclientconfig_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestGetNextFreeIPFromList(t *testing.T) { @@ -80,6 +80,11 @@ func TestWriteConfig(t *testing.T) { w.Write([]byte("OK")) return } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } w.WriteHeader(http.StatusBadRequest) default: w.WriteHeader(http.StatusBadRequest) @@ -92,7 +97,7 @@ func TestWriteConfig(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -152,6 +157,11 @@ func TestWriteConfigMultipleClients(t *testing.T) { w.Write([]byte("OK")) return } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } w.WriteHeader(http.StatusBadRequest) default: w.WriteHeader(http.StatusBadRequest) @@ -164,7 +174,7 @@ func TestWriteConfigMultipleClients(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -226,6 +236,11 @@ func TestCreateAndDeleteAllClientConfig(t *testing.T) { w.Write([]byte("OK")) return } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } w.WriteHeader(http.StatusBadRequest) default: w.WriteHeader(http.StatusBadRequest) @@ -238,7 +253,7 @@ func TestCreateAndDeleteAllClientConfig(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -272,7 +287,7 @@ func TestCreateAndDeleteAllClientConfig(t *testing.T) { t.Fatalf("couldn't find peer config file written in storage") } - err = json.Unmarshal(writtenPeerconfig, &peerConfig) + err = json.Unmarshal(*writtenPeerconfig, &peerConfig) if err != nil { t.Fatalf("unmarshal error: %s", err) } @@ -314,6 +329,11 @@ func TestCreateAndDeleteClientConfig(t *testing.T) { w.Write([]byte("OK")) return } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } w.WriteHeader(http.StatusBadRequest) default: w.WriteHeader(http.StatusBadRequest) @@ -326,7 +346,7 @@ func TestCreateAndDeleteClientConfig(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -360,7 +380,7 @@ func TestCreateAndDeleteClientConfig(t *testing.T) { t.Fatalf("couldn't find peer config file written in storage") } - err = json.Unmarshal(writtenPeerconfig, &peerConfig) + err = json.Unmarshal(*writtenPeerconfig, &peerConfig) if err != nil { t.Fatalf("unmarshal error: %s", err) } @@ -403,6 +423,11 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) { w.Write([]byte("OK")) return } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } w.WriteHeader(http.StatusBadRequest) default: w.WriteHeader(http.StatusBadRequest) @@ -415,7 +440,7 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -449,7 +474,7 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) { t.Fatalf("couldn't find peer config file written in storage") } - err = json.Unmarshal(writtenPeerconfig, &peerConfig) + err = json.Unmarshal(*writtenPeerconfig, &peerConfig) if err != nil { t.Fatalf("unmarshal error: %s", err) } @@ -471,7 +496,7 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) { t.Fatalf("couldn't find peer config file written in storage") } - err = json.Unmarshal(writtenPeerconfig, &peerConfig) + err = json.Unmarshal(*writtenPeerconfig, &peerConfig) if err != nil { t.Fatalf("unmarshal error: %s", err) } @@ -492,7 +517,7 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) { t.Fatalf("couldn't find peer config file written in storage") } - err = json.Unmarshal(writtenPeerconfig, &peerConfig) + err = json.Unmarshal(*writtenPeerconfig, &peerConfig) if err != nil { t.Fatalf("unmarshal error: %s", err) } @@ -527,6 +552,11 @@ func TestUpdateClientConfig(t *testing.T) { w.Write([]byte("OK")) return } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } w.WriteHeader(http.StatusBadRequest) default: w.WriteHeader(http.StatusBadRequest) @@ -539,7 +569,7 @@ func TestUpdateClientConfig(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -610,6 +640,11 @@ func TestUpdateClientConfigNewAddressRange(t *testing.T) { w.Write([]byte("OK")) return } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } w.WriteHeader(http.StatusBadRequest) default: w.WriteHeader(http.StatusBadRequest) @@ -622,7 +657,7 @@ func TestUpdateClientConfigNewAddressRange(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -722,6 +757,11 @@ func TestUpdateClientConfigNewClientAddressPrefix(t *testing.T) { w.Write([]byte("OK")) return } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } w.WriteHeader(http.StatusBadRequest) default: w.WriteHeader(http.StatusBadRequest) @@ -734,7 +774,7 @@ func TestUpdateClientConfigNewClientAddressPrefix(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) diff --git a/pkg/wireguard/wireguardserverconfig_test.go b/pkg/wireguard/wireguardserverconfig_test.go index f791d75..3552bf6 100644 --- a/pkg/wireguard/wireguardserverconfig_test.go +++ b/pkg/wireguard/wireguardserverconfig_test.go @@ -1,14 +1,59 @@ package wireguard import ( + "net" + "net/http" + "net/http/httptest" "strings" "testing" + "time" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestWriteWireGuardServerConfig(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + var ( + l net.Listener + err error + ) + for { + l, err = net.Listen("tcp", CONFIGMANAGER_URI) + if err != nil { + if !strings.HasSuffix(err.Error(), "address already in use") { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + } else { + break + } + } + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + if r.RequestURI == "/refresh-clients" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } + if r.RequestURI == "/refresh-server-config" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } + w.WriteHeader(http.StatusBadRequest) + default: + w.WriteHeader(http.StatusBadRequest) + } + })) + + ts.Listener.Close() + ts.Listener = l + ts.Start() + defer ts.Close() + defer l.Close() + + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) diff --git a/webapp/package-lock.json b/webapp/package-lock.json index 7c054e3..120f729 100644 --- a/webapp/package-lock.json +++ b/webapp/package-lock.json @@ -12,7 +12,6 @@ "@mantine/dates": "^7.12.1", "@mantine/form": "^7.10.0", "@mantine/hooks": "^7.9.2", - "@tabler/icons-react": "^3.4.0", "@tanstack/react-query": "^5.36.2", "axios": "^1.7.4", "base32-encode": "^2.0.0", @@ -24,6 +23,7 @@ "react-cookie": "^7.1.4", "react-dom": "^18.2.0", "react-hook-qrcode-svg": "^1.5.1", + "react-icons": "^5.3.0", "react-router-dom": "^6.23.1" }, "devDependencies": { @@ -1356,30 +1356,6 @@ "win32" ] }, - "node_modules/@tabler/icons": { - "version": "3.5.0", - "resolved": "https://registry.npmjs.org/@tabler/icons/-/icons-3.5.0.tgz", - "integrity": "sha512-I53dC3ZSHQ2MZFGvDYJelfXm91L2bTTixS4w5jTAulLhHbCZso5Bih4Rk/NYZxlngLQMKHvEYwZQ+6w/WluKiA==", - "funding": { - "type": "github", - "url": "https://github.com/sponsors/codecalm" - } - }, - "node_modules/@tabler/icons-react": { - "version": "3.5.0", - "resolved": "https://registry.npmjs.org/@tabler/icons-react/-/icons-react-3.5.0.tgz", - "integrity": "sha512-bn05XKZV3ZfOv5Jr1FCTmVPOQGBVJoA4NefrnR919rqg6WGXAa08NovONHJGSuMxXUMV3b9Cni85diIW/E9yuw==", - "dependencies": { - "@tabler/icons": "3.5.0" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/codecalm" - }, - "peerDependencies": { - "react": ">= 16" - } - }, "node_modules/@tanstack/query-core": { "version": "5.45.0", "resolved": "https://registry.npmjs.org/@tanstack/query-core/-/query-core-5.45.0.tgz", @@ -3733,6 +3709,15 @@ "react": ">=18.0.0" } }, + "node_modules/react-icons": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/react-icons/-/react-icons-5.3.0.tgz", + "integrity": "sha512-DnUk8aFbTyQPSkCfF8dbX6kQjXA9DktMeJqfjrg6cK9vwQVMxmcA3BfP4QoiztVmEHtwlTgLFsPuH2NskKT6eg==", + "license": "MIT", + "peerDependencies": { + "react": "*" + } + }, "node_modules/react-is": { "version": "16.13.1", "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", diff --git a/webapp/package.json b/webapp/package.json index b33abd9..41d6964 100644 --- a/webapp/package.json +++ b/webapp/package.json @@ -15,7 +15,6 @@ "@mantine/dates": "^7.12.1", "@mantine/form": "^7.10.0", "@mantine/hooks": "^7.9.2", - "@tabler/icons-react": "^3.4.0", "@tanstack/react-query": "^5.36.2", "axios": "^1.7.4", "base32-encode": "^2.0.0", @@ -27,6 +26,7 @@ "react-cookie": "^7.1.4", "react-dom": "^18.2.0", "react-hook-qrcode-svg": "^1.5.1", + "react-icons": "^5.3.0", "react-router-dom": "^6.23.1" }, "devDependencies": { diff --git a/webapp/src/App.tsx b/webapp/src/App.tsx index ed795a8..ecc98fa 100644 --- a/webapp/src/App.tsx +++ b/webapp/src/App.tsx @@ -18,6 +18,7 @@ import { Users } from "./Routes/Users/Users"; import { Profile } from "./Routes/Profile/Profile"; import { Upgrade } from "./Routes/Upgrade/Upgrade"; import { GetMoreLicenses } from "./Routes/Licenses/GetMoreLicenses"; +import { PacketLogs } from "./Routes/PacketLogs/PacketLogs"; const queryClient = new QueryClient() @@ -44,14 +45,16 @@ export default function App() { } /> } /> } /> + } /> } /> + } /> + } /> + } /> } /> } /> } /> } /> } /> - } /> - } /> diff --git a/webapp/src/Auth/AuthBanner.tsx b/webapp/src/Auth/AuthBanner.tsx index 74eb62e..54d3d21 100644 --- a/webapp/src/Auth/AuthBanner.tsx +++ b/webapp/src/Auth/AuthBanner.tsx @@ -18,7 +18,7 @@ import { AppSettings } from '../Constants/Constants'; import { useAuthContext } from './Auth'; import { AuthError } from './AuthError'; import { MFAInput } from './MFAInput'; -import { IconInfoCircle } from '@tabler/icons-react'; +import { TbInfoCircle } from "react-icons/tb"; type LoginResponse = { @@ -117,7 +117,7 @@ export function AuthBanner() { authenticate.mutate({login, password, factorResponse}) } } - const alertIcon = + const alertIcon = if (error) return 'An backend error has occurred: ' + error.message diff --git a/webapp/src/Auth/AuthError.tsx b/webapp/src/Auth/AuthError.tsx index 3e8e7c0..be91b21 100644 --- a/webapp/src/Auth/AuthError.tsx +++ b/webapp/src/Auth/AuthError.tsx @@ -1,9 +1,9 @@ import { Alert } from "@mantine/core"; -import { IconInfoCircle } from "@tabler/icons-react"; +import { TbInfoCircle } from "react-icons/tb"; import { useSearchParams } from "react-router-dom"; export function AuthError() { - const alertIcon = ; + const alertIcon = ; let [ searchParams, _ ] = useSearchParams(); if(!searchParams.has("error")) return '' diff --git a/webapp/src/NavBar/NavBar.tsx b/webapp/src/NavBar/NavBar.tsx index 5616b8d..d0437fe 100644 --- a/webapp/src/NavBar/NavBar.tsx +++ b/webapp/src/NavBar/NavBar.tsx @@ -1,15 +1,18 @@ import { useState } from 'react'; import { Group, Code } from '@mantine/core'; import { - IconBellRinging, - IconSettings, - IconLogout, - IconUser, - IconPlugConnected, - IconCloudDataConnection, - IconBook, - IconUserCircle, -} from '@tabler/icons-react'; + TbBellRinging, + TbSettings, + TbLogout, + TbUser, + TbPlugConnected, + TbCloudDataConnection, + TbBook, + TbUserCircle, + +} from 'react-icons/tb'; +import { FaStream } from "react-icons/fa"; + import classes from './Navbar.module.css'; import { NavLink, useLocation } from 'react-router-dom'; import { useAuthContext } from '../Auth/Auth'; @@ -22,16 +25,17 @@ export function NavBar() { const [active, setActive] = useState(pathname); const data = authInfo.role === "admin" ? [ - { link: '/', label: 'Status', icon: IconBellRinging }, - { link: '/connection', label: 'VPN Connections', icon: IconPlugConnected }, - { link: '/users', label: 'Users', icon: IconUser }, - { link: '/setup', label: 'VPN Setup', icon: IconSettings }, - { link: '/auth-setup', label: 'Authentication & Provisioning', icon: IconCloudDataConnection }, - { link: 'https://vpn-documentation.in4it.com', label: 'Documentation', icon: IconBook }, + { link: '/', label: 'Status', icon: TbBellRinging }, + { link: '/connection', label: 'VPN Connections', icon: TbPlugConnected }, + { link: '/users', label: 'Users', icon: TbUser }, + { link: '/setup', label: 'VPN Setup', icon: TbSettings }, + { link: '/auth-setup', label: 'Authentication & Provisioning', icon: TbCloudDataConnection }, + { link: '/packetlogs', label: 'Logging', icon: FaStream }, + { link: 'https://vpn-documentation.in4it.com', label: 'Documentation', icon: TbBook }, ] : [ - { link: '/connection', label: 'VPN Connections', icon: IconPlugConnected }, - { link: 'https://vpn-documentation.in4it.com', label: 'Documentation', icon: IconBook }, + { link: '/connection', label: 'VPN Connections', icon: TbPlugConnected }, + { link: 'https://vpn-documentation.in4it.com', label: 'Documentation', icon: TbBook }, ]; const links = data.map((item) => ( @@ -45,7 +49,7 @@ export function NavBar() { setActive(item.link); }} > - + {item.label} )); @@ -62,14 +66,14 @@ export function NavBar() {
{authInfo.userType == "local" ? { setActive("/profile"); }} data-active={"/profile" === active || undefined}> - + Profile : null } - + Logout
diff --git a/webapp/src/NavBar/Navbar.module.css b/webapp/src/NavBar/Navbar.module.css index d2db1b8..ed6b460 100644 --- a/webapp/src/NavBar/Navbar.module.css +++ b/webapp/src/NavBar/Navbar.module.css @@ -60,4 +60,5 @@ margin-right: var(--mantine-spacing-sm); width: rem(25px); height: rem(25px); + stroke-width: 1.25px; } \ No newline at end of file diff --git a/webapp/src/Routes/Connection/ListConnections.tsx b/webapp/src/Routes/Connection/ListConnections.tsx index 40e2399..cf584cd 100644 --- a/webapp/src/Routes/Connection/ListConnections.tsx +++ b/webapp/src/Routes/Connection/ListConnections.tsx @@ -4,7 +4,7 @@ import { Table, ScrollArea, Button } from '@mantine/core'; import classes from './ListConnections.module.css'; import { useQueryClient, useMutation, useQuery } from '@tanstack/react-query'; import { AppSettings } from '../../Constants/Constants'; -import { IconTrash } from '@tabler/icons-react'; +import { TbTrash } from 'react-icons/tb'; import axios from 'axios'; import { useAuthContext } from '../../Auth/Auth'; import { Download } from './Download'; @@ -49,7 +49,7 @@ export function ListConnections() { {row.name} - + )); diff --git a/webapp/src/Routes/Connection/NewConnection.tsx b/webapp/src/Routes/Connection/NewConnection.tsx index 3315a41..0197e50 100644 --- a/webapp/src/Routes/Connection/NewConnection.tsx +++ b/webapp/src/Routes/Connection/NewConnection.tsx @@ -4,13 +4,13 @@ import axios, { AxiosError } from "axios" import { useAuthContext } from "../../Auth/Auth"; import { useState } from "react"; import { Alert, Button } from "@mantine/core"; -import { IconInfoCircle } from "@tabler/icons-react"; +import { TbInfoCircle } from "react-icons/tb"; export function NewConnection() { const queryClient = useQueryClient() const {authInfo} = useAuthContext(); const [newConnectionError, setError] = useState("") - const alertIcon = + const alertIcon = const newConnection = useMutation({ mutationFn: () => { return axios.post(AppSettings.url + '/connections', {}, { diff --git a/webapp/src/Routes/Home/Home.tsx b/webapp/src/Routes/Home/Home.tsx index 86dac66..dcb0fdf 100644 --- a/webapp/src/Routes/Home/Home.tsx +++ b/webapp/src/Routes/Home/Home.tsx @@ -6,7 +6,7 @@ import classes from './Home.module.css'; import { AppSettings } from '../../Constants/Constants'; import { useQuery } from '@tanstack/react-query'; import { UpgradeAlert } from './UpgradeAlert'; -import { IconPaperBag } from '@tabler/icons-react'; +import { TbPaperBag } from 'react-icons/tb'; import { UserStats } from './UserStats'; export function Home() { @@ -54,7 +54,7 @@ export function Home() { {isPending || data.cloudType === "aws-marketplace" || data.cloudType === "azure" ? null : - diff --git a/webapp/src/Routes/Home/UpgradeAlert.tsx b/webapp/src/Routes/Home/UpgradeAlert.tsx index 98d145b..f4ef6b8 100644 --- a/webapp/src/Routes/Home/UpgradeAlert.tsx +++ b/webapp/src/Routes/Home/UpgradeAlert.tsx @@ -4,7 +4,7 @@ import { Link } from 'react-router-dom'; import { AppSettings } from '../../Constants/Constants'; import { useQuery } from '@tanstack/react-query'; -import { IconInfoCircle } from '@tabler/icons-react'; +import { TbInfoCircle } from "react-icons/tb"; export function UpgradeAlert() { const {authInfo} = useAuthContext() @@ -26,7 +26,7 @@ export function UpgradeAlert() { if (error) return '' if (isPending) return '' - const alertIcon = + const alertIcon = if (!data.newVersionAvailable) return '' diff --git a/webapp/src/Routes/Licenses/GetMoreLicenses.tsx b/webapp/src/Routes/Licenses/GetMoreLicenses.tsx index 9de1121..9fc3ede 100644 --- a/webapp/src/Routes/Licenses/GetMoreLicenses.tsx +++ b/webapp/src/Routes/Licenses/GetMoreLicenses.tsx @@ -4,7 +4,7 @@ import classes from './GetMoreLicenses.module.css'; import { AppSettings } from '../../Constants/Constants'; import { useQuery, useQueryClient } from '@tanstack/react-query'; -import { IconInfoCircle } from '@tabler/icons-react'; +import { TbInfoCircle } from "react-icons/tb"; import { useState } from 'react'; export function GetMoreLicenses() { @@ -36,7 +36,7 @@ export function GetMoreLicenses() { if (error) return 'cannot retrieve licensed users' if (isPending) return 'Loading...' - const alertIcon = ; + const alertIcon = ; return ( diff --git a/webapp/src/Routes/PacketLogs/PacketLogs.tsx b/webapp/src/Routes/PacketLogs/PacketLogs.tsx new file mode 100644 index 0000000..8f2b681 --- /dev/null +++ b/webapp/src/Routes/PacketLogs/PacketLogs.tsx @@ -0,0 +1,223 @@ +import { Card, Container, Text, Table, Title, Button, Grid, Select, MultiSelect, Popover, Group, TextInput, rem, ActionIcon, Highlight} from "@mantine/core"; +import { AppSettings } from "../../Constants/Constants"; +import { useInfiniteQuery } from "@tanstack/react-query"; +import { useAuthContext } from "../../Auth/Auth"; +import { Link, useSearchParams } from "react-router-dom"; +import { TbArrowRight, TbSearch, TbSettings } from "react-icons/tb"; +import { DatePickerInput } from "@mantine/dates"; +import { useEffect, useState } from "react"; +import React from "react"; + +type LogsDataResponse = { + enabled: boolean; + logData: LogData; + logTypes: string[]; + users: UserMap; +} +type LogData = { + schema: LogDataSchema; + rows: LogRow[]; + nextPos: number; +} +type LogDataSchema = { + columns: string[]; +} +type LogRow = { + t: string; + d: string[]; +} +type UserMap = { + [key: string]: string; +} + +function getDate(date:Date) { + var dd = String(date.getDate()).padStart(2, '0'); + var mm = String(date.getMonth() + 1).padStart(2, '0'); //January is 0! + var yyyy = date.getFullYear(); + return yyyy + "-" + mm + '-' + dd; +} + +export function PacketLogs() { + const {authInfo} = useAuthContext(); + const timezoneOffset = new Date().getTimezoneOffset() * -1 + const [currentQueryParameters] = useSearchParams(); + const dateParam = currentQueryParameters.get("date") + const userParam = currentQueryParameters.get("user") + const [logType, setLogType] = useState([]) + const [search, setSearch] = useState("") + const [searchParam, setSearchParam] = useState("") + const [logsDate, setLogsDate] = useState(dateParam === null ? new Date() : new Date(dateParam)); + const [user, setUser] = useState(userParam === null ? "all" : userParam) + const { isPending, fetchNextPage, hasNextPage, error, data } = useInfiniteQuery({ + queryKey: ['packetlogs', user, logsDate, logType, searchParam], + queryFn: async ({ pageParam }) => + fetch(AppSettings.url + '/stats/packetlogs/'+(user === undefined || user === "" ? "all" : user)+'/'+(logsDate == undefined ? getDate(new Date()) : getDate(logsDate)) + "?pos="+pageParam+"&offset="+timezoneOffset+"&logtype="+encodeURIComponent(logType.join(","))+"&search="+encodeURIComponent(searchParam), { + headers: { + "Content-Type": "application/json", + "Authorization": "Bearer " + authInfo.token + }, + }).then((res) => { + return res.json() + } + ), + initialPageParam: 0, + getNextPageParam: (lastRequest) => lastRequest.logData.nextPos === -1 ? null : lastRequest.logData.nextPos, + }) + + const captureEnter = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + setSearchParam(search) + } + } + + useEffect(() => { + const handleScroll = () => { + const { scrollTop, clientHeight, scrollHeight } = + document.documentElement; + if (scrollTop + clientHeight >= scrollHeight - 20) { + fetchNextPage(); + } + }; + + window.addEventListener("scroll", handleScroll); + return () => { + window.removeEventListener("scroll", handleScroll); + }; + }, [fetchNextPage]) + + if(isPending) return "Loading..." + if(error) return 'A backend error has occurred: ' + error.message + + if(data.pages.length === 0 || !data.pages[0].enabled || data.pages[0].logTypes.length == 0) { // show disabled page if not enabled + return ( + + + Packet Logs + + + + { !data.pages[0].enabled ? + "Packet Logs are not activated. Activate packet logging in the VPN Settings." + : + data.pages[0].logTypes.length == 0 ? "Packet logs are activated, but no packet logging types are selected. Select at least one packet log type." : null + } + + + + + + + + + ) + } + + const rows = data.pages.map((group, groupIndex) => ( + + {group.logData.rows.map((row, i) => ( + + {row.t} + {row.d.map((element, y) => { + return ( + {searchParam === "" ? element : {element}} + ) + })} + + ))} + + )); + return ( + + + Packet Logs + + + + } + rightSection={ + setSearchParam(search)}> + + + } + onKeyDown={(e) => captureEnter(e)} + onChange={(e) => setSearch(e.currentTarget.value)} + value={search} + /> + + + + + +