From 81ce9d5dc53e07dd802e4aeb20e7fcec60caffbb Mon Sep 17 00:00:00 2001 From: Edward Viaene Date: Sun, 8 Sep 2024 21:03:52 -0500 Subject: [PATCH] refactor memorystorage, add goroutine for logratation) --- pkg/auth/oidc/store/discovery_test.go | 4 +- pkg/auth/oidc/store/jwks_test.go | 4 +- pkg/auth/oidc/store/save_test.go | 4 +- pkg/auth/provisioning/scim/users_test.go | 18 +-- pkg/configmanager/start_darwin.go | 2 + pkg/configmanager/start_linux.go | 2 + pkg/license/gcp_test.go | 10 +- pkg/license/license_test.go | 33 +++-- pkg/rest/auth_test.go | 14 +-- pkg/rest/rsa_test.go | 4 +- pkg/rest/stats_test.go | 4 +- pkg/rest/users_test.go | 6 +- pkg/storage/local/path.go | 2 +- .../mocks => storage/memory}/storage.go | 105 +++++----------- pkg/wireguard/packetlogger.go | 17 ++- pkg/wireguard/packetlogger_test.go | 119 +++++++++++++++++- pkg/wireguard/wireguardclientconfig_test.go | 28 ++--- pkg/wireguard/wireguardserverconfig_test.go | 4 +- 18 files changed, 229 insertions(+), 151 deletions(-) rename pkg/{testing/mocks => storage/memory}/storage.go (51%) diff --git a/pkg/auth/oidc/store/discovery_test.go b/pkg/auth/oidc/store/discovery_test.go index 3d5037e..31b6d9b 100644 --- a/pkg/auth/oidc/store/discovery_test.go +++ b/pkg/auth/oidc/store/discovery_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/in4it/wireguard-server/pkg/auth/oidc" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestGetDiscovery(t *testing.T) { @@ -27,7 +27,7 @@ func TestGetDiscovery(t *testing.T) { })) defer ts.Close() - store, err := NewStore(&testingmocks.MockMemoryStorage{}) + store, err := NewStore(&memorystorage.MockMemoryStorage{}) if err != nil { t.Fatalf("new store error: %s", err) } diff --git a/pkg/auth/oidc/store/jwks_test.go b/pkg/auth/oidc/store/jwks_test.go index 74698ff..94c808a 100644 --- a/pkg/auth/oidc/store/jwks_test.go +++ b/pkg/auth/oidc/store/jwks_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/in4it/wireguard-server/pkg/auth/oidc" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestGetJwks(t *testing.T) { @@ -32,7 +32,7 @@ func TestGetJwks(t *testing.T) { })) defer ts.Close() - store, err := NewStore(&testingmocks.MockMemoryStorage{}) + store, err := NewStore(&memorystorage.MockMemoryStorage{}) if err != nil { t.Fatalf("new store error: %s", err) } diff --git a/pkg/auth/oidc/store/save_test.go b/pkg/auth/oidc/store/save_test.go index bc56eb2..9e33d53 100644 --- a/pkg/auth/oidc/store/save_test.go +++ b/pkg/auth/oidc/store/save_test.go @@ -4,11 +4,11 @@ import ( "testing" "github.com/in4it/wireguard-server/pkg/auth/oidc" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestSave(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} store, err := NewStore(storage) if err != nil { t.Fatalf("error: %s", err) diff --git a/pkg/auth/provisioning/scim/users_test.go b/pkg/auth/provisioning/scim/users_test.go index 773eb9f..2c399ac 100644 --- a/pkg/auth/provisioning/scim/users_test.go +++ b/pkg/auth/provisioning/scim/users_test.go @@ -11,7 +11,7 @@ import ( "path" "testing" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" "github.com/in4it/wireguard-server/pkg/users" "github.com/in4it/wireguard-server/pkg/wireguard" ) @@ -19,7 +19,7 @@ import ( const USERSTORE_MAX_USERS = 1000 func TestUsersGetCount100EmptyResult(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { @@ -48,7 +48,7 @@ func TestUsersGetCount100EmptyResult(t *testing.T) { } func TestUsersGetCount10(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { t.Fatalf("cannot create new user store") @@ -88,7 +88,7 @@ func TestUsersGetCount10(t *testing.T) { func TestUsersGetCount10Start5(t *testing.T) { count := 10 start := 5 - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { t.Fatalf("cannot create new user store") @@ -138,12 +138,12 @@ func TestUsersGetCount10Start5(t *testing.T) { } func TestUsersGetNonExistentUser(t *testing.T) { - userStore, err := users.NewUserStore(&testingmocks.MockMemoryStorage{}, USERSTORE_MAX_USERS) + userStore, err := users.NewUserStore(&memorystorage.MockMemoryStorage{}, USERSTORE_MAX_USERS) if err != nil { t.Fatalf("cannot create new user stoer") } - s := New(&testingmocks.MockMemoryStorage{}, userStore, "token") + s := New(&memorystorage.MockMemoryStorage{}, userStore, "token") req := httptest.NewRequest("GET", "http://example.com/api/scim/v2/Users?filter=userName+eq+%22ward%40in4it.io%22&", nil) w := httptest.NewRecorder() s.getUsersHandler(w, req) @@ -161,7 +161,7 @@ func TestUsersGetNonExistentUser(t *testing.T) { } func TestAddUser(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { t.Fatalf("cannot create new user store: %s", err) @@ -208,7 +208,7 @@ func TestAddUser(t *testing.T) { } func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { t.Fatalf("cannot create new user store: %s", err) @@ -328,7 +328,7 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { } } func TestCreateUserConnectionSuspendUserFlow(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} userStore, err := users.NewUserStore(storage, USERSTORE_MAX_USERS) if err != nil { diff --git a/pkg/configmanager/start_darwin.go b/pkg/configmanager/start_darwin.go index db6fba3..07ea124 100644 --- a/pkg/configmanager/start_darwin.go +++ b/pkg/configmanager/start_darwin.go @@ -26,4 +26,6 @@ func startStats(storage storage.Iface) { func startPacketLogger(storage storage.Iface, clientCache *wireguard.ClientCache, vpnConfig *wireguard.VPNConfig) { go wireguard.RunPacketLogger(storage, clientCache, vpnConfig) + // run cleanup + go wireguard.PacketLoggerLogRotation(storage) } diff --git a/pkg/configmanager/start_linux.go b/pkg/configmanager/start_linux.go index 65e0639..033ee02 100644 --- a/pkg/configmanager/start_linux.go +++ b/pkg/configmanager/start_linux.go @@ -31,4 +31,6 @@ func startStats(storage storage.Iface) { func startPacketLogger(storage storage.Iface, clientCache *wireguard.ClientCache, vpnConfig *wireguard.VPNConfig) { // run statistics go routine go wireguard.RunPacketLogger(storage, clientCache, vpnConfig) + // run cleanup + go wireguard.PacketLoggerLogRotation(storage) } diff --git a/pkg/license/gcp_test.go b/pkg/license/gcp_test.go index 33cc566..37d2672 100644 --- a/pkg/license/gcp_test.go +++ b/pkg/license/gcp_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestGuessInfrastructureGCP(t *testing.T) { @@ -52,10 +52,10 @@ func TestGetMaxUsersGCPBYOL(t *testing.T) { licenseURL = ts.URL metadataIP = strings.Replace(ts.URL, "http://", "", -1) - mockStorage := &testingmocks.MockReadWriter{ - Data: map[string][]byte{ - "config/license.key": []byte("license-1234556-license"), - }, + mockStorage := &memorystorage.MockMemoryStorage{} + err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license")) + if err != nil { + t.Fatalf("writefile error: %s", err) } for _, v := range []int{50} { diff --git a/pkg/license/license_test.go b/pkg/license/license_test.go index 8a41601..9f41356 100644 --- a/pkg/license/license_test.go +++ b/pkg/license/license_test.go @@ -11,7 +11,7 @@ import ( "time" "github.com/in4it/wireguard-server/pkg/logging" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestGetMaxUsersAzure(t *testing.T) { @@ -39,7 +39,7 @@ func TestGetMaxUsersAWSMarketplace(t *testing.T) { "t3.xlarge": 250, } for instanceType, v := range testCases { - if getMaxUsers(&testingmocks.MockReadWriter{}, "aws-marketplace", instanceType) != v { + if getMaxUsers(&memorystorage.MockMemoryStorage{}, "aws-marketplace", instanceType) != v { t.Fatalf("Wrong output: %d vs %d", GetMaxUsersAWS(instanceType), v) } } @@ -98,7 +98,7 @@ func TestGetMaxUsersAWSBYOL(t *testing.T) { licenseURL = ts.URL metadataIP = strings.Replace(ts.URL, "http://", "", -1) for _, v := range testCases { - if v2 := GetMaxUsersAWSBYOL(http.Client{Timeout: 5 * time.Second}, &testingmocks.MockMemoryStorage{}); v2 != v { + if v2 := GetMaxUsersAWSBYOL(http.Client{Timeout: 5 * time.Second}, &memorystorage.MockMemoryStorage{}); v2 != v { t.Fatalf("Wrong output: %d vs %d", v2, v) } } @@ -111,7 +111,7 @@ func TestGetMaxUsersAWS(t *testing.T) { "t3.xlarge": 3, } for instanceType, v := range testCases { - if getMaxUsers(&testingmocks.MockReadWriter{}, "aws", instanceType) != v { + if getMaxUsers(&memorystorage.MockMemoryStorage{}, "aws", instanceType) != v { t.Fatalf("Wrong output: %d vs %d", GetMaxUsersAWS(instanceType), v) } } @@ -212,7 +212,7 @@ func TestGuessInfrastructureAWS(t *testing.T) { t.Fatalf("wrong infra returned: %s", infra) } - if getMaxUsers(&testingmocks.MockReadWriter{}, infra, "t3.large") != 3 { + if getMaxUsers(&memorystorage.MockMemoryStorage{}, infra, "t3.large") != 3 { t.Fatalf("wrong users returned") } } @@ -252,7 +252,7 @@ func TestGetAzureInstanceType(t *testing.T) { usersPerVCPU := 25 - users := getMaxUsers(&testingmocks.MockReadWriter{}, "azure", getAzureInstanceType(http.Client{Timeout: 5 * time.Second})) + users := getMaxUsers(&memorystorage.MockMemoryStorage{}, "azure", getAzureInstanceType(http.Client{Timeout: 5 * time.Second})) if users != usersPerVCPU*2 { t.Fatalf("Wrong user count returned") @@ -302,7 +302,7 @@ func TestGuessInfrastructureDigitalOcean(t *testing.T) { t.Fatalf("wrong infra returned: %s", infra) } - if getMaxUsers(&testingmocks.MockReadWriter{}, infra, "t3.large") != 3 { + if getMaxUsers(&memorystorage.MockMemoryStorage{}, infra, "t3.large") != 3 { t.Fatalf("wrong users returned") } } @@ -332,12 +332,11 @@ func TestGetMaxUsersDigitalOceanBYOL(t *testing.T) { licenseURL = ts.URL metadataIP = strings.Replace(ts.URL, "http://", "", -1) - mockStorage := &testingmocks.MockReadWriter{ - Data: map[string][]byte{ - "config/license.key": []byte("license-1234556-license"), - }, + mockStorage := &memorystorage.MockMemoryStorage{} + err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license")) + if err != nil { + t.Fatalf("writefile error: %s", err) } - for _, v := range []int{50} { if v2 := GetMaxUsersDigitalOceanBYOL(http.Client{Timeout: 5 * time.Second}, mockStorage); v2 != v { t.Fatalf("Wrong output: %d vs %d", v2, v) @@ -387,19 +386,19 @@ func TestGetLicenseKey(t *testing.T) { metadataIP = strings.Replace(ts.URL, "http://", "", -1) logging.Loglevel = logging.LOG_DEBUG + logging.LOG_ERROR - key := GetLicenseKey(&testingmocks.MockMemoryStorage{}, "") + key := GetLicenseKey(&memorystorage.MockMemoryStorage{}, "") if key == "" { t.Fatalf("key is empty") } - key = GetLicenseKey(&testingmocks.MockMemoryStorage{}, "aws") + key = GetLicenseKey(&memorystorage.MockMemoryStorage{}, "aws") if key == "" { t.Fatalf("aws key is empty") } - key = GetLicenseKey(&testingmocks.MockMemoryStorage{}, "digitalocean") + key = GetLicenseKey(&memorystorage.MockMemoryStorage{}, "digitalocean") if key == "" { t.Fatalf("digitalocean key is empty") } - key = GetLicenseKey(&testingmocks.MockMemoryStorage{}, "gcp") + key = GetLicenseKey(&memorystorage.MockMemoryStorage{}, "gcp") if key == "" { t.Fatalf("gcp key is empty") } @@ -407,7 +406,7 @@ func TestGetLicenseKey(t *testing.T) { func TestGetLicenseKeyNoCloudProvider(t *testing.T) { logging.Loglevel = logging.LOG_DEBUG + logging.LOG_ERROR - key := GetLicenseKey(&testingmocks.MockMemoryStorage{}, "") + key := GetLicenseKey(&memorystorage.MockMemoryStorage{}, "") if key == "" { t.Fatalf("key is empty") } diff --git a/pkg/rest/auth_test.go b/pkg/rest/auth_test.go index 17078e9..848e8f3 100644 --- a/pkg/rest/auth_test.go +++ b/pkg/rest/auth_test.go @@ -23,7 +23,7 @@ import ( "github.com/in4it/wireguard-server/pkg/auth/saml" "github.com/in4it/wireguard-server/pkg/logging" "github.com/in4it/wireguard-server/pkg/rest/login" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" "github.com/in4it/wireguard-server/pkg/users" "github.com/russellhaering/gosaml2/types" dsigtypes "github.com/russellhaering/goxmldsig/types" @@ -80,7 +80,7 @@ c7tL1QjbfAUHAQYwmHkWgPP+T2wAv0pOt36GgMCM` } func TestAuthHandler(t *testing.T) { - c, err := newContext(&testingmocks.MockMemoryStorage{}, SERVER_TYPE_VPN) + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) if err != nil { t.Fatalf("Cannot create context: %s", err) } @@ -130,7 +130,7 @@ func TestAuthHandler(t *testing.T) { func TestNewSAMLConnection(t *testing.T) { // generate new keypair - kp := saml.NewKeyPair(&testingmocks.MockMemoryStorage{}, "www.idp.inv") + kp := saml.NewKeyPair(&memorystorage.MockMemoryStorage{}, "www.idp.inv") _, cert, err := kp.GetKeyPair() if err != nil { t.Fatalf("Can't generate new keypair: %s", err) @@ -191,7 +191,7 @@ func TestNewSAMLConnection(t *testing.T) { defer l.Close() // first create a new user - c, err := newContext(&testingmocks.MockMemoryStorage{}, SERVER_TYPE_VPN) + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) if err != nil { t.Fatalf("Cannot create context") } @@ -419,7 +419,7 @@ func TestNewSAMLConnection(t *testing.T) { } func TestAddModifyDeleteNewSAMLConnection(t *testing.T) { - c, err := newContext(&testingmocks.MockMemoryStorage{}, SERVER_TYPE_VPN) + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) if err != nil { t.Fatalf("Cannot create context") } @@ -592,7 +592,7 @@ func TestAddModifyDeleteNewSAMLConnection(t *testing.T) { } func TestSAMLCallback(t *testing.T) { - c, err := newContext(&testingmocks.MockMemoryStorage{}, SERVER_TYPE_VPN) + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) if err != nil { t.Fatalf("Cannot create context") } @@ -701,7 +701,7 @@ func TestOIDCFlow(t *testing.T) { } // first create a new user - c, err := newContext(&testingmocks.MockMemoryStorage{}, SERVER_TYPE_VPN) + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) if err != nil { t.Fatalf("Cannot create context") } diff --git a/pkg/rest/rsa_test.go b/pkg/rest/rsa_test.go index 8ac5844..ce23943 100644 --- a/pkg/rest/rsa_test.go +++ b/pkg/rest/rsa_test.go @@ -6,11 +6,11 @@ import ( "encoding/pem" "testing" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestGetJWTKeys(t *testing.T) { - mockStorage := testingmocks.MockMemoryStorage{} + mockStorage := memorystorage.MockMemoryStorage{} keys, err := getJWTKeys(&mockStorage) if err != nil { t.Fatalf("error: %s", err) diff --git a/pkg/rest/stats_test.go b/pkg/rest/stats_test.go index b7c2b2a..b1da1fc 100644 --- a/pkg/rest/stats_test.go +++ b/pkg/rest/stats_test.go @@ -8,13 +8,13 @@ import ( "testing" "time" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" "github.com/in4it/wireguard-server/pkg/wireguard" ) func TestUserStatsHandler(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} c, err := newContext(storage, SERVER_TYPE_VPN) if err != nil { diff --git a/pkg/rest/users_test.go b/pkg/rest/users_test.go index 99652d6..c516691 100644 --- a/pkg/rest/users_test.go +++ b/pkg/rest/users_test.go @@ -13,7 +13,7 @@ import ( "strings" "testing" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" "github.com/in4it/wireguard-server/pkg/users" "github.com/in4it/wireguard-server/pkg/wireguard" ) @@ -50,7 +50,7 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { defer l.Close() // first create a new user - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} c, err := newContext(storage, SERVER_TYPE_VPN) if err != nil { @@ -161,7 +161,7 @@ func TestCreateUserConnectionDeleteUserFlow(t *testing.T) { func TestCreateUser(t *testing.T) { // first create a new user - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} c, err := newContext(storage, SERVER_TYPE_VPN) if err != nil { diff --git a/pkg/storage/local/path.go b/pkg/storage/local/path.go index c4b5773..f1e8d8b 100644 --- a/pkg/storage/local/path.go +++ b/pkg/storage/local/path.go @@ -83,5 +83,5 @@ func (l *LocalStorage) Remove(name string) error { } func (l *LocalStorage) Rename(oldName, newName string) error { - return os.Rename(oldName, newName) + return os.Rename(path.Join(l.path, oldName), path.Join(l.path, newName)) } diff --git a/pkg/testing/mocks/storage.go b/pkg/storage/memory/storage.go similarity index 51% rename from pkg/testing/mocks/storage.go rename to pkg/storage/memory/storage.go index 4bf82c8..2246e63 100644 --- a/pkg/testing/mocks/storage.go +++ b/pkg/storage/memory/storage.go @@ -1,4 +1,4 @@ -package testingmocks +package memorystorage import ( "bufio" @@ -18,72 +18,18 @@ func (mwc *MyWriteCloser) Close() error { return nil } -type MockReadWriter struct { - Data map[string][]byte -} +type MockReadWriterData []byte -func (m *MockReadWriter) ConfigPath(filename string) string { - return path.Join("config", filename) -} -func (m *MockReadWriter) FileExists(name string) bool { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - _, ok := m.Data[name] - return ok -} - -func (m *MockReadWriter) ReadFile(name string) ([]byte, error) { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - val, ok := m.Data[name] - if !ok { - return val, fmt.Errorf("file does not exist") - } - return val, nil -} -func (m *MockReadWriter) WriteFile(name string, data []byte) error { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - m.Data[name] = data +func (m *MockReadWriterData) Close() error { return nil } - -func (m *MockReadWriter) OpenFile(name string) (io.ReadCloser, error) { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - val, ok := m.Data[name] - if !ok { - return nil, fmt.Errorf("file does not exist") - } - - return io.NopCloser(bytes.NewBuffer(val)), nil -} -func (m *MockReadWriter) OpenFileForWriting(name string) (io.WriteCloser, error) { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - buf := bufio.NewWriter(bytes.NewBuffer(m.Data[name])) - return &MyWriteCloser{buf}, nil -} -func (m *MockReadWriter) Rename(oldName, newName string) error { - if m.Data == nil { - m.Data = make(map[string][]byte) - } - _, ok := m.Data[oldName] - if !ok { - return fmt.Errorf("file doesn't exist") - } - m.Data[newName] = m.Data[oldName] - delete(m.Data, oldName) - return nil +func (m *MockReadWriterData) Write(p []byte) (nn int, err error) { + *m = append(*m, p...) + return len(p), nil } type MockMemoryStorage struct { - Data map[string][]byte + Data map[string]*MockReadWriterData } func (m *MockMemoryStorage) ConfigPath(filename string) string { @@ -91,7 +37,7 @@ func (m *MockMemoryStorage) ConfigPath(filename string) string { } func (m *MockMemoryStorage) Rename(oldName, newName string) error { if m.Data == nil { - m.Data = make(map[string][]byte) + m.Data = make(map[string]*MockReadWriterData) } _, ok := m.Data[oldName] if !ok { @@ -103,7 +49,7 @@ func (m *MockMemoryStorage) Rename(oldName, newName string) error { } func (m *MockMemoryStorage) FileExists(name string) bool { if m.Data == nil { - m.Data = make(map[string][]byte) + m.Data = make(map[string]*MockReadWriterData) } _, ok := m.Data[name] return ok @@ -111,26 +57,31 @@ func (m *MockMemoryStorage) FileExists(name string) bool { func (m *MockMemoryStorage) ReadFile(name string) ([]byte, error) { if m.Data == nil { - m.Data = make(map[string][]byte) + m.Data = make(map[string]*MockReadWriterData) } val, ok := m.Data[name] if !ok { - return val, fmt.Errorf("file does not exist") + return nil, fmt.Errorf("file does not exist") } - return val, nil + return *val, nil } func (m *MockMemoryStorage) WriteFile(name string, data []byte) error { if m.Data == nil { - m.Data = make(map[string][]byte) + m.Data = make(map[string]*MockReadWriterData) } - m.Data[name] = data + m.Data[name] = (*MockReadWriterData)(&data) return nil } func (m *MockMemoryStorage) AppendFile(name string, data []byte) error { if m.Data == nil { - m.Data = make(map[string][]byte) + m.Data = make(map[string]*MockReadWriterData) } - m.Data[name] = append(m.Data[name], data...) + if m.Data[name] == nil { + m.Data[name] = (*MockReadWriterData)(&data) + } else { + *m.Data[name] = append(*m.Data[name], data...) + } + return nil } @@ -149,7 +100,7 @@ func (m *MockMemoryStorage) EnsureOwnership(filename, login string) error { func (m *MockMemoryStorage) ReadDir(path string) ([]string, error) { if m.Data == nil { - m.Data = make(map[string][]byte) + m.Data = make(map[string]*MockReadWriterData) } res := []string{} for k := range m.Data { @@ -162,7 +113,7 @@ func (m *MockMemoryStorage) ReadDir(path string) ([]string, error) { func (m *MockMemoryStorage) Remove(name string) error { if m.Data == nil { - m.Data = make(map[string][]byte) + m.Data = make(map[string]*MockReadWriterData) } delete(m.Data, name) return nil @@ -173,19 +124,19 @@ func (m *MockMemoryStorage) OpenFilesFromPos(names []string, pos int64) ([]io.Re } func (m *MockMemoryStorage) OpenFile(name string) (io.ReadCloser, error) { if m.Data == nil { - m.Data = make(map[string][]byte) + m.Data = make(map[string]*MockReadWriterData) } val, ok := m.Data[name] if !ok { return nil, fmt.Errorf("file does not exist") } - return io.NopCloser(bytes.NewBuffer(val)), nil + return io.NopCloser(bytes.NewBuffer(*val)), nil } func (m *MockMemoryStorage) OpenFileForWriting(name string) (io.WriteCloser, error) { if m.Data == nil { - m.Data = make(map[string][]byte) + m.Data = make(map[string]*MockReadWriterData) } - buf := bufio.NewWriter(bytes.NewBuffer(m.Data[name])) - return &MyWriteCloser{buf}, nil + m.Data[name] = (*MockReadWriterData)(&[]byte{}) + return m.Data[name], nil } diff --git a/pkg/wireguard/packetlogger.go b/pkg/wireguard/packetlogger.go index b1fbe13..4d977fb 100644 --- a/pkg/wireguard/packetlogger.go +++ b/pkg/wireguard/packetlogger.go @@ -301,8 +301,19 @@ func checkDiskSpace() error { // Packet log rotation func PacketLoggerLogRotation(storage storage.Iface) { - err := packetLoggerLogRotation(storage) - logging.ErrorLog(fmt.Errorf("packet logger log rotation error: %s", err)) + for { + err := packetLoggerLogRotation(storage) + if err != nil { + logging.ErrorLog(fmt.Errorf("packet logger log rotation error: %s", err)) + } + time.Sleep(getTimeUntilTomorrowStartOfDay()) // sleep until tomorrow + } +} + +func getTimeUntilTomorrowStartOfDay() time.Duration { + tomorrow := time.Now().AddDate(0, 0, 1) + tomorrowStartOfDay := time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), 0, 5, 0, 0, time.Local) + return time.Until(tomorrowStartOfDay) } func packetLoggerLogRotation(storage storage.Iface) error { @@ -314,7 +325,7 @@ func packetLoggerLogRotation(storage storage.Iface) error { for _, filename := range files { filenameSplit := strings.Split(strings.TrimSuffix(filename, ".log"), "-") if len(filenameSplit) > 3 { - dateParsed, err := time.Parse("2006-01-02", filenameSplit[len(filenameSplit)-3]) + dateParsed, err := time.Parse("2006-01-02", strings.Join(filenameSplit[len(filenameSplit)-3:], "-")) if err == nil { if !dateutils.DateEqual(dateParsed, time.Now()) { err := packetLoggerCompressLog(storage, filename) diff --git a/pkg/wireguard/packetlogger_test.go b/pkg/wireguard/packetlogger_test.go index 4728578..edc00bb 100644 --- a/pkg/wireguard/packetlogger_test.go +++ b/pkg/wireguard/packetlogger_test.go @@ -1,18 +1,25 @@ package wireguard import ( + "bytes" + "compress/gzip" "encoding/hex" + "fmt" + "io" "net" + "os" "path" "strings" "testing" "time" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + localstorage "github.com/in4it/wireguard-server/pkg/storage/local" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" + dateutils "github.com/in4it/wireguard-server/pkg/utils/date" ) func TestParsePacket(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} clientCache := &ClientCache{ Addresses: []ClientCacheAddresses{ { @@ -75,7 +82,7 @@ func TestParsePacket(t *testing.T) { } func TestParsePacketSNI(t *testing.T) { - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} clientCache := &ClientCache{ Addresses: []ClientCacheAddresses{ { @@ -169,3 +176,109 @@ func TestCheckDiskSpace(t *testing.T) { t.Fatalf("disk space error: %s", err) } } + +func TestPacketLoggerLogRotation(t *testing.T) { + prefix := path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR) + key1 := path.Join(prefix, fmt.Sprintf("1-2-3-4-%s.log", time.Now().AddDate(0, 0, -1).Format("2006-01-02"))) + value1 := []byte(time.Now().Format(TIMESTAMP_FORMAT) + `,https,10.189.184.2,64.233.180.104,60496,443,www.google.com`) + key2 := path.Join(prefix, fmt.Sprintf("1-2-3-4-%s.log", time.Now().Format("2006-01-02"))) + value2 := []byte(time.Now().Format(TIMESTAMP_FORMAT) + `,https,10.189.184.3,64.233.180.104,12345,443,www.google.com`) + + storage := &memorystorage.MockMemoryStorage{ + Data: map[string]*memorystorage.MockReadWriterData{}, + } + err := storage.WriteFile(key1, value1) + if err != nil { + t.Fatalf("write file error: %s", err) + } + err = storage.WriteFile(key2, value2) + if err != nil { + t.Fatalf("write file error: %s", err) + } + + err = packetLoggerLogRotation(storage) + if err != nil { + t.Fatalf("packetLoggerRotation error: %s", err) + } + body, err := storage.ReadFile(key1 + ".gz") + if err != nil { + t.Fatalf("can't read compressed file") + } + reader, err := gzip.NewReader(bytes.NewReader(body)) + if err != nil { + t.Fatalf("can't open gzip reader") + } + bodyDecoded, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("can't read gzip data") + } + if string(bodyDecoded) != string(value1) { + t.Fatalf("unexpected output. Got %s, expected: %s", bodyDecoded, value1) + } +} + +func TestPacketLoggerLogRotationLocalStorage(t *testing.T) { + prefix := path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR) + key1 := path.Join(prefix, fmt.Sprintf("1-2-3-4-%s.log", time.Now().AddDate(0, 0, -1).Format("2006-01-02"))) + value1 := []byte(time.Now().Format(TIMESTAMP_FORMAT) + `,https,10.189.184.2,64.233.180.104,60496,443,www.google.com`) + key2 := path.Join(prefix, fmt.Sprintf("1-2-3-4-%s.log", time.Now().Format("2006-01-02"))) + value2 := []byte(time.Now().Format(TIMESTAMP_FORMAT) + `,https,10.189.184.3,64.233.180.104,12345,443,www.google.com`) + + pwd, err := os.Executable() + if err != nil { + t.Fatalf("os Executable error: %s", err) + } + storage, err := localstorage.NewWithPath(path.Dir(pwd)) + if err != nil { + t.Fatalf("localstorage error: %s", err) + } + err = storage.EnsurePath(VPN_STATS_DIR) + if err != nil { + t.Fatalf("could not ensure path: %s", err) + } + storage.EnsurePath(path.Join(VPN_STATS_DIR, VPN_PACKETLOGGER_DIR)) + if err != nil { + t.Fatalf("could not ensure path: %s", err) + } + err = storage.WriteFile(key1, value1) + if err != nil { + t.Fatalf("write file error: %s", err) + } + err = storage.WriteFile(key2, value2) + if err != nil { + t.Fatalf("write file error: %s", err) + } + t.Cleanup(func() { + os.Remove(path.Join(path.Dir(pwd), key1)) + os.Remove(path.Join(path.Dir(pwd), key1+".gz.tmp")) + os.Remove(path.Join(path.Dir(pwd), key1+".gz")) + os.Remove(path.Join(path.Dir(pwd), key2)) + }) + + err = packetLoggerLogRotation(storage) + if err != nil { + t.Fatalf("packetLoggerRotation error: %s", err) + } + body, err := storage.ReadFile(key1 + ".gz") + if err != nil { + t.Fatalf("can't read compressed file") + } + reader, err := gzip.NewReader(bytes.NewReader(body)) + if err != nil { + t.Fatalf("can't open gzip reader") + } + bodyDecoded, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("can't read gzip data") + } + if string(bodyDecoded) != string(value1) { + t.Fatalf("unexpected output. Got %s, expected: %s", bodyDecoded, value1) + } +} + +func TestGetTimeUntilTomorrowStartOfDay(t *testing.T) { + duration := getTimeUntilTomorrowStartOfDay() + if !dateutils.DateEqual(time.Now().Add(duration), time.Now().AddDate(0, 0, 1)) { + t.Fatalf("date is not tomorrow") + } +} diff --git a/pkg/wireguard/wireguardclientconfig_test.go b/pkg/wireguard/wireguardclientconfig_test.go index df266de..d6d38c2 100644 --- a/pkg/wireguard/wireguardclientconfig_test.go +++ b/pkg/wireguard/wireguardclientconfig_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestGetNextFreeIPFromList(t *testing.T) { @@ -97,7 +97,7 @@ func TestWriteConfig(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -174,7 +174,7 @@ func TestWriteConfigMultipleClients(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -253,7 +253,7 @@ func TestCreateAndDeleteAllClientConfig(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -287,7 +287,7 @@ func TestCreateAndDeleteAllClientConfig(t *testing.T) { t.Fatalf("couldn't find peer config file written in storage") } - err = json.Unmarshal(writtenPeerconfig, &peerConfig) + err = json.Unmarshal(*writtenPeerconfig, &peerConfig) if err != nil { t.Fatalf("unmarshal error: %s", err) } @@ -346,7 +346,7 @@ func TestCreateAndDeleteClientConfig(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -380,7 +380,7 @@ func TestCreateAndDeleteClientConfig(t *testing.T) { t.Fatalf("couldn't find peer config file written in storage") } - err = json.Unmarshal(writtenPeerconfig, &peerConfig) + err = json.Unmarshal(*writtenPeerconfig, &peerConfig) if err != nil { t.Fatalf("unmarshal error: %s", err) } @@ -440,7 +440,7 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -474,7 +474,7 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) { t.Fatalf("couldn't find peer config file written in storage") } - err = json.Unmarshal(writtenPeerconfig, &peerConfig) + err = json.Unmarshal(*writtenPeerconfig, &peerConfig) if err != nil { t.Fatalf("unmarshal error: %s", err) } @@ -496,7 +496,7 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) { t.Fatalf("couldn't find peer config file written in storage") } - err = json.Unmarshal(writtenPeerconfig, &peerConfig) + err = json.Unmarshal(*writtenPeerconfig, &peerConfig) if err != nil { t.Fatalf("unmarshal error: %s", err) } @@ -517,7 +517,7 @@ func TestCreateAndDisableAllClientConfig(t *testing.T) { t.Fatalf("couldn't find peer config file written in storage") } - err = json.Unmarshal(writtenPeerconfig, &peerConfig) + err = json.Unmarshal(*writtenPeerconfig, &peerConfig) if err != nil { t.Fatalf("unmarshal error: %s", err) } @@ -569,7 +569,7 @@ func TestUpdateClientConfig(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -657,7 +657,7 @@ func TestUpdateClientConfigNewAddressRange(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) @@ -774,7 +774,7 @@ func TestUpdateClientConfigNewClientAddressPrefix(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage) diff --git a/pkg/wireguard/wireguardserverconfig_test.go b/pkg/wireguard/wireguardserverconfig_test.go index 40f2a2c..3552bf6 100644 --- a/pkg/wireguard/wireguardserverconfig_test.go +++ b/pkg/wireguard/wireguardserverconfig_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - testingmocks "github.com/in4it/wireguard-server/pkg/testing/mocks" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" ) func TestWriteWireGuardServerConfig(t *testing.T) { @@ -53,7 +53,7 @@ func TestWriteWireGuardServerConfig(t *testing.T) { defer ts.Close() defer l.Close() - storage := &testingmocks.MockMemoryStorage{} + storage := &memorystorage.MockMemoryStorage{} // first create a new vpn config vpnconfig, err := CreateNewVPNConfig(storage)