diff --git a/cmd/reset-admin-password/main.go b/cmd/reset-admin-password/main.go index a3bb79d..a5cfb55 100644 --- a/cmd/reset-admin-password/main.go +++ b/cmd/reset-admin-password/main.go @@ -29,7 +29,7 @@ func main() { os.Exit(1) } if !newAdminUserCreated { - resetMFA, err := getLine("Also remove MFA if present? [Y/n] ") + resetMFA, err := getLine("\nAlso remove MFA if present? [Y/n] ") if err != nil { fmt.Printf("Failed to changed admin password: %s", err) os.Exit(1) diff --git a/docs/release-notes.md b/docs/release-notes.md index 698137b..0f53d4f 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,5 +1,14 @@ # Release Notes +## Version v1.1.4 +* Improved setup flow for AWS & DigitalOcean + +## Version v1.1.3 +* New Feature: Log packets traversing the VPN Server. This release supports logging TCP / DNS / HTTP / HTTPS packets and inspecting the destination of http/https packets. + +## Version v1.1.2 +* UI: fixes in user creation + ## Version v1.1.0 * UI: change VPN configuration within the admin UI * UI: ability to reload WireGuard® configuration diff --git a/latest b/latest index 99a4aef..c641220 100644 --- a/latest +++ b/latest @@ -1 +1 @@ -v1.1.3 +v1.1.4 diff --git a/pkg/commands/resetpassword.go b/pkg/commands/resetpassword.go index af8ac99..fbef43e 100644 --- a/pkg/commands/resetpassword.go +++ b/pkg/commands/resetpassword.go @@ -20,6 +20,9 @@ func ResetPassword(appDir, password string) (bool, error) { if err != nil { return adminCreated, fmt.Errorf("config retrieval error: %s", err) } + c.Storage = &rest.Storage{ + Client: localstorage, + } c.UserStore, err = users.NewUserStore(localstorage, -1) if err != nil { return adminCreated, fmt.Errorf("userstore initialization error: %s", err) diff --git a/pkg/license/aws.go b/pkg/license/aws.go index 1a81e87..4b507e7 100644 --- a/pkg/license/aws.go +++ b/pkg/license/aws.go @@ -13,7 +13,7 @@ import ( const AWS_PRODUCT_CODE = "7h7h3bnutjn0ziamv7npi8a69" func getMetadataToken(client http.Client) string { - metadataEndpoint := "http://" + metadataIP + "/latest/api/token" + metadataEndpoint := "http://" + MetadataIP + "/latest/api/token" req, err := http.NewRequest("PUT", metadataEndpoint, nil) if err != nil { @@ -62,7 +62,7 @@ func isOnAWS(client http.Client) bool { func getInstanceIdentityDocument(client http.Client, token string) (InstanceIdentityDocument, error) { var instanceIdentityDocument InstanceIdentityDocument - endpoint := "http://" + metadataIP + "/2022-09-24/dynamic/instance-identity/document" + endpoint := "http://" + MetadataIP + "/2022-09-24/dynamic/instance-identity/document" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return instanceIdentityDocument, err @@ -145,7 +145,7 @@ func getLicense(client http.Client, key string) (License, error) { } func getLicenseFromMetaData(token string, client http.Client) (string, error) { - endpoint := "http://" + metadataIP + "/2022-09-24/meta-data/tags/instance/license" + endpoint := "http://" + MetadataIP + "/2022-09-24/meta-data/tags/instance/license" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return "", err @@ -173,7 +173,7 @@ func getLicenseFromMetaData(token string, client http.Client) (string, error) { func getAWSInstanceType(client http.Client) string { token := getMetadataToken(client) - endpoint := "http://" + metadataIP + "/latest/meta-data/instance-type" + endpoint := "http://" + MetadataIP + "/latest/meta-data/instance-type" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return "" @@ -194,6 +194,30 @@ func getAWSInstanceType(client http.Client) string { return "" } +func GetAWSInstanceID(client http.Client) (string, error) { + token := getMetadataToken(client) + + endpoint := "http://" + MetadataIP + "/latest/meta-data/instance-id" + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return "", err + } + if token != "" { + req.Header.Add("X-aws-ec2-metadata-token", token) + } + + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode == 200 { + bodyBytes, _ := io.ReadAll(resp.Body) + return string(bodyBytes), err + } + return "", fmt.Errorf("received statuscode %d from aws metadata api", resp.StatusCode) +} + func GetMaxUsersAWS(instanceType string) int { if instanceType == "" { return 3 diff --git a/pkg/license/azure.go b/pkg/license/azure.go index 468cc6f..18ab9de 100644 --- a/pkg/license/azure.go +++ b/pkg/license/azure.go @@ -9,7 +9,7 @@ import ( ) func isOnAzure(client http.Client) bool { - req, err := http.NewRequest("GET", "http://"+metadataIP+"/metadata/versions", nil) + req, err := http.NewRequest("GET", "http://"+MetadataIP+"/metadata/versions", nil) if err != nil { return false } @@ -51,7 +51,7 @@ func GetMaxUsersAzure(instanceType string) int { return 3 } func getAzureInstanceType(client http.Client) string { - metadataEndpoint := "http://" + metadataIP + "/metadata/instance?api-version=2021-02-01" + metadataEndpoint := "http://" + MetadataIP + "/metadata/instance?api-version=2021-02-01" req, err := http.NewRequest("GET", metadataEndpoint, nil) if err != nil { return "" diff --git a/pkg/license/digitalocean.go b/pkg/license/digitalocean.go index 04b6a05..5be4c07 100644 --- a/pkg/license/digitalocean.go +++ b/pkg/license/digitalocean.go @@ -1,6 +1,7 @@ package license import ( + "bufio" "fmt" "io" "net/http" @@ -11,7 +12,7 @@ import ( ) func isOnDigitalOcean(client http.Client) bool { - endpoint := "http://" + metadataIP + "/metadata/v1/interfaces/private/0/type" + endpoint := "http://" + MetadataIP + "/metadata/v1/interfaces/private/0/type" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return false @@ -60,7 +61,7 @@ func getDigitalOceanLicenseKey(storage storage.ReadWriter, client http.Client) ( func getDigitalOceanIdentifier(client http.Client) (string, error) { id := "" - endpoint := "http://" + metadataIP + "/metadata/v1/id" + endpoint := "http://" + MetadataIP + "/metadata/v1/id" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return id, err @@ -82,3 +83,39 @@ func getDigitalOceanIdentifier(client http.Client) (string, error) { return strings.TrimSpace(string(body)), nil } + +func HasDigitalOceanTagSet(client http.Client, tag string) (bool, error) { + endpoint := "http://" + MetadataIP + "/metadata/v1/tags" + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return false, err + } + + resp, err := client.Do(req) + if err != nil { + return false, err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return false, err + } + return false, fmt.Errorf("wrong statuscode returned: %d; body: %s", resp.StatusCode, body) + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + if tag == strings.TrimSpace(scanner.Text()) { + return true, nil + } + } + + if err := scanner.Err(); err != nil { + return false, err + } + + return false, nil + +} diff --git a/pkg/license/gcp.go b/pkg/license/gcp.go index 8428ef9..5e69c54 100644 --- a/pkg/license/gcp.go +++ b/pkg/license/gcp.go @@ -11,7 +11,7 @@ import ( ) func isOnGCP(client http.Client) bool { - endpoint := "http://" + metadataIP + "/computeMetadata/v1/" + endpoint := "http://" + MetadataIP + "/computeMetadata/v1/" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return false @@ -62,7 +62,7 @@ func getGCPLicenseKey(storage storage.ReadWriter, client http.Client) (string, e func getGCPIdentifier(client http.Client) (string, error) { id := "" - endpoint := "http://" + metadataIP + "/computeMetadata/v1/project/project-id" + endpoint := "http://" + MetadataIP + "/computeMetadata/v1/project/project-id" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return id, err diff --git a/pkg/license/gcp_test.go b/pkg/license/gcp_test.go index 37d2672..f52b047 100644 --- a/pkg/license/gcp_test.go +++ b/pkg/license/gcp_test.go @@ -22,7 +22,7 @@ func TestGuessInfrastructureGCP(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) infra := guessInfrastructure() @@ -50,7 +50,7 @@ func TestGetMaxUsersGCPBYOL(t *testing.T) { defer ts.Close() licenseURL = ts.URL - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) mockStorage := &memorystorage.MockMemoryStorage{} err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license")) diff --git a/pkg/license/license.go b/pkg/license/license.go index 5ddf077..b3cfb77 100644 --- a/pkg/license/license.go +++ b/pkg/license/license.go @@ -11,7 +11,7 @@ import ( randomutils "github.com/in4it/wireguard-server/pkg/utils/random" ) -var metadataIP = "169.254.169.254" +var MetadataIP = "169.254.169.254" var licenseURL = "https://in4it-vpn-server.s3.amazonaws.com/licenses" func guessInfrastructure() string { diff --git a/pkg/license/license_test.go b/pkg/license/license_test.go index 9f41356..87172fa 100644 --- a/pkg/license/license_test.go +++ b/pkg/license/license_test.go @@ -96,7 +96,7 @@ func TestGetMaxUsersAWSBYOL(t *testing.T) { "t3.xlarge": 50, } licenseURL = ts.URL - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) for _, v := range testCases { if v2 := GetMaxUsersAWSBYOL(http.Client{Timeout: 5 * time.Second}, &memorystorage.MockMemoryStorage{}); v2 != v { t.Fatalf("Wrong output: %d vs %d", v2, v) @@ -127,7 +127,7 @@ func TestGuessInfrastructureAzure(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) infra := guessInfrastructure() @@ -170,7 +170,7 @@ func TestGuessInfrastructureAWSMarketplace(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) infra := guessInfrastructure() @@ -204,7 +204,7 @@ func TestGuessInfrastructureAWS(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) infra := guessInfrastructure() @@ -223,7 +223,7 @@ func TestGuessInfrastructureOther(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) infra := guessInfrastructure() @@ -248,7 +248,7 @@ func TestGetAzureInstanceType(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) usersPerVCPU := 25 @@ -275,7 +275,7 @@ func TestGetAWSInstanceType(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) users := GetMaxUsersAWS(getAWSInstanceType(http.Client{Timeout: 5 * time.Second})) @@ -294,7 +294,7 @@ func TestGuessInfrastructureDigitalOcean(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) infra := guessInfrastructure() @@ -330,7 +330,7 @@ func TestGetMaxUsersDigitalOceanBYOL(t *testing.T) { defer ts.Close() licenseURL = ts.URL - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) mockStorage := &memorystorage.MockMemoryStorage{} err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license")) @@ -383,7 +383,7 @@ func TestGetLicenseKey(t *testing.T) { w.WriteHeader(http.StatusNotFound) })) - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) logging.Loglevel = logging.LOG_DEBUG + logging.LOG_ERROR key := GetLicenseKey(&memorystorage.MockMemoryStorage{}, "") diff --git a/pkg/rest/context.go b/pkg/rest/context.go index ddea413..5343ba3 100644 --- a/pkg/rest/context.go +++ b/pkg/rest/context.go @@ -110,6 +110,7 @@ func getEmptyContext(appDir string) (*Context, error) { TokenRenewalTimeMinutes: oidcrenewal.DEFAULT_RENEWAL_TIME_MINUTES, LogLevel: logging.LOG_ERROR, SCIM: &SCIM{EnableSCIM: false}, + SAML: &SAML{Providers: &[]saml.Provider{}}, } return c, nil } diff --git a/pkg/rest/setup.go b/pkg/rest/setup.go index d100c2c..7149b0d 100644 --- a/pkg/rest/setup.go +++ b/pkg/rest/setup.go @@ -18,6 +18,7 @@ import ( "github.com/google/uuid" "github.com/in4it/wireguard-server/pkg/auth/oidc" "github.com/in4it/wireguard-server/pkg/auth/saml" + "github.com/in4it/wireguard-server/pkg/license" "github.com/in4it/wireguard-server/pkg/users" "github.com/in4it/wireguard-server/pkg/wireguard" ) @@ -35,14 +36,51 @@ func (c *Context) contextHandler(w http.ResponseWriter, r *http.Request) { c.SetupCompleted = true } if !c.SetupCompleted { - localSecret, err := c.Storage.Client.ReadFile(SETUP_CODE_FILE) - if err != nil { - c.returnError(w, fmt.Errorf("secret file read error: %s", err), http.StatusBadRequest) - return + // check if tag hash is chosen + accessGranted := false + switch c.CloudType { + case "digitalocean": // check if the hashtag is set + if contextReq.TagHash != "" { + if !strings.HasPrefix(contextReq.TagHash, "vpnsecret-") { + c.returnError(w, fmt.Errorf("tag doesn't have the correct prefix. The tag needs to start with 'vpnsecret-'"), http.StatusUnauthorized) + return + } + accessGranted, err = license.HasDigitalOceanTagSet(http.Client{Timeout: 5 * time.Second}, contextReq.TagHash) + if err != nil { + c.returnError(w, fmt.Errorf("could not retrieve tags at this time: %s", err), http.StatusUnauthorized) + return + } + if !accessGranted { + c.returnError(w, fmt.Errorf("tag not found. Make sure the correct tag is attached to the droplet"), http.StatusUnauthorized) + return + } + } + case "aws": // check if the instance id is set + if contextReq.InstanceID != "" { + instanceID, err := license.GetAWSInstanceID(http.Client{Timeout: 5 * time.Second}) + if err != nil { + c.returnError(w, fmt.Errorf("could not retrieve instance id at this time: %s", err), http.StatusUnauthorized) + return + } + if strings.TrimPrefix(instanceID, "i-") == strings.TrimPrefix(contextReq.InstanceID, "i-") { + accessGranted = true + } else { + c.returnError(w, fmt.Errorf("instance id doesn't match"), http.StatusUnauthorized) + return + } + } } - if strings.TrimSpace(string(localSecret)) != contextReq.Secret { - c.returnError(w, fmt.Errorf("wrong secret provided"), http.StatusUnauthorized) - return + // check secret + if !accessGranted { + localSecret, err := c.Storage.Client.ReadFile(SETUP_CODE_FILE) + if err != nil { + c.returnError(w, fmt.Errorf("secret file read error: %s", err), http.StatusBadRequest) + return + } + if strings.TrimSpace(string(localSecret)) != contextReq.Secret { + c.returnError(w, fmt.Errorf("wrong secret provided"), http.StatusUnauthorized) + return + } } if contextReq.AdminPassword != "" { adminUser := users.User{ @@ -95,10 +133,8 @@ func (c *Context) contextHandler(w http.ResponseWriter, r *http.Request) { } } } - cOut := Context{ - SetupCompleted: c.SetupCompleted, - } - out, err := json.Marshal(cOut) + + out, err := json.Marshal(ContextSetupResponse{SetupCompleted: c.SetupCompleted, CloudType: c.CloudType}) if err != nil { c.returnError(w, err, http.StatusBadRequest) return diff --git a/pkg/rest/setup_test.go b/pkg/rest/setup_test.go new file mode 100644 index 0000000..2e14bca --- /dev/null +++ b/pkg/rest/setup_test.go @@ -0,0 +1,293 @@ +package rest + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/in4it/wireguard-server/pkg/license" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" + "github.com/in4it/wireguard-server/pkg/users" +) + +func TestContextHandlerSetupSecret(t *testing.T) { + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + + payload := ContextRequest{ + Secret: "secret setup code", + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if !contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to be completed") + } +} + +func TestContextHandlerSetupWrongSecret(t *testing.T) { + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + + payload := ContextRequest{ + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 401 { + t.Fatalf("status code is not 401: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to not be completed") + } +} +func TestContextHandlerSetupWrongSecretPartial(t *testing.T) { + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + + payload := ContextRequest{ + Secret: "secret setup cod", + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 401 { + t.Fatalf("status code is not 401: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to not be completed") + } +} + +func TestContextHandlerSetupAWSInstanceID(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI == "/latest/api/token" { + w.Write([]byte("this is a test token")) + return + } + if r.RequestURI == "/latest/meta-data/instance-id" { + w.Write([]byte("i-012aaaaaaaaaaaaa1")) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + defer ts.Close() + license.MetadataIP = strings.TrimPrefix(ts.URL, "http://") + + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + c.CloudType = "aws" + + payload := ContextRequest{ + InstanceID: "i-012aaaaaaaaaaaaa1", + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if !contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to be completed") + } +} +func TestContextHandlerSetupDigitalOceanTag(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI == "/metadata/v1/tags" { + w.Write([]byte("vpnsecret-this-is-a-secret-tag")) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + defer ts.Close() + license.MetadataIP = strings.TrimPrefix(ts.URL, "http://") + + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + c.CloudType = "digitalocean" + + payload := ContextRequest{ + TagHash: "vpnsecret-this-is-a-secret-tag", + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if !contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to be completed") + } +} diff --git a/pkg/rest/types.go b/pkg/rest/types.go index dbbcd8e..133b2fc 100644 --- a/pkg/rest/types.go +++ b/pkg/rest/types.go @@ -61,10 +61,16 @@ type Storage struct { type ContextRequest struct { Secret string `json:"secret"` + TagHash string `json:"tagHash"` + InstanceID string `json:"instanceID"` AdminPassword string `json:"adminPassword"` Hostname string `json:"hostname"` Protocol string `json:"protocol"` } +type ContextSetupResponse struct { + SetupCompleted bool `json:"setupCompleted"` + CloudType string `json:"cloudType"` +} type AuthMethodsResponse struct { LocalAuthDisabled bool `json:"localAuthDisabled"` diff --git a/webapp/src/AppInit/AppInit.tsx b/webapp/src/AppInit/AppInit.tsx index e1ac419..3c86f40 100644 --- a/webapp/src/AppInit/AppInit.tsx +++ b/webapp/src/AppInit/AppInit.tsx @@ -27,7 +27,7 @@ import { AppSettings } from '../Constants/Constants'; } if (!setupCompleted) { - return + return } else { return children } diff --git a/webapp/src/AppInit/SetAdminPassword.tsx b/webapp/src/AppInit/SetAdminPassword.tsx index a9ef519..dd37bc9 100644 --- a/webapp/src/AppInit/SetAdminPassword.tsx +++ b/webapp/src/AppInit/SetAdminPassword.tsx @@ -5,26 +5,23 @@ import axios from 'axios'; import { AppSettings } from '../Constants/Constants'; import { useMutation, - useQueryClient, } from '@tanstack/react-query' type Props = { onChangeStep: (newType: number) => void; - secret: string + secrets: SetupResponse }; -export function SetAdminPassword({onChangeStep, secret}: Props) { - const queryClient = useQueryClient() +export function SetAdminPassword({onChangeStep, secrets}: Props) { const [password, setPassword] = useState(""); const [password2, setPassword2] = useState(""); const [passwordError, setPasswordError] = useState(""); const [password2Error, setPassword2Error] = useState(""); const passwordMutation = useMutation({ mutationFn: (newPassword: string) => { - return axios.post(AppSettings.url + '/context', {secret: secret, adminPassword: newPassword, hostname: window.location.host, protocol: window.location.protocol}) + return axios.post(AppSettings.url + '/context', {...secrets, adminPassword: newPassword, hostname: window.location.host, protocol: window.location.protocol}) }, onSuccess: () => { - queryClient.invalidateQueries({ queryKey: ['context'] }) onChangeStep(2) }, onError: (error) => { @@ -40,9 +37,17 @@ export function SetAdminPassword({onChangeStep, secret}: Props) { } if(password === "") { setPasswordError("admin password cannot be blank") + return } passwordMutation.mutate(password) } + const captureEnter = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + if(password !== "" && password2 !== "") { + changePassword() + } + } + } return (
@@ -51,16 +56,20 @@ export function SetAdminPassword({onChangeStep, secret}: Props) { Set a password for the admin user. At the next screen you'll be able to login with the username "admin" and the password you'll set now. {passwordMutation.isPending ? ( -
Setting Password...
+
Setting Password for user 'admin'...
) : (
Your password - setPassword(event.currentTarget.value)} value={password} error={passwordError} + onKeyDown={(e) => captureEnter(e)} /> Repeat password @@ -68,9 +77,11 @@ export function SetAdminPassword({onChangeStep, secret}: Props) { setPassword2(event.currentTarget.value)} value={password2} - error={password2Error} + error={password2Error} + onKeyDown={(e) => captureEnter(e)} />
diff --git a/webapp/src/AppInit/SetSecret.tsx b/webapp/src/AppInit/SetSecret.tsx index a726e74..9624204 100644 --- a/webapp/src/AppInit/SetSecret.tsx +++ b/webapp/src/AppInit/SetSecret.tsx @@ -1,44 +1,86 @@ -import { Text, Title, TextInput, Button } from '@mantine/core'; +import { Text, Title, TextInput, Button, Card, Grid, Container, Center, Alert } from '@mantine/core'; import classes from './SetupBanner.module.css'; import {useState} from 'react'; -import axios from 'axios'; +import axios, { AxiosError } from 'axios'; import { AppSettings } from '../Constants/Constants'; import { - useQueryClient, useMutation, } from '@tanstack/react-query' +import { TbInfoCircle } from 'react-icons/tb'; type Props = { onChangeStep: (newType: number) => void; - onChangeSecret: (newType: string) => void; - }; + onChangeSecrets: (newType: SetupResponse) => void; + cloudType: string; +}; -export function SetSecret({onChangeStep, onChangeSecret}: Props) { - const queryClient = useQueryClient() - const [secret, setSecret] = useState(""); +type SetupResponseError = { + error: string; +} + +const randomHex = (length:number) => { + const bytes = window.crypto.getRandomValues(new Uint8Array(length)) + var hexstring='', h; + for(var i=0; i({secret: "", tagHash: "", instanceID: ""}); const [secretError, setSecretError] = useState(""); + const [randomHexValue] = useState(randomHex(16)) const secretMutation = useMutation({ - mutationFn: (newSecret: string) => { + mutationFn: (setupResponseParam: SetupResponse) => { setSecretError("") - return axios.post(AppSettings.url + '/context', {secret: newSecret}) + return axios.post(AppSettings.url + '/context', setupResponseParam) }, - onSuccess: () => { - queryClient.invalidateQueries({ queryKey: ['context'] }) - onChangeSecret(secret) + onSuccess: (_, setupResponseParam) => { + onChangeSecrets(setupResponseParam) onChangeStep(1) }, - onError: (error) => { - if(error.message.includes("status code 403")) { - setSecretError("Invalid secret") - } else { + onError: (error:AxiosError) => { + const errorMessage = error.response?.data as SetupResponseError + if(errorMessage?.error === undefined) { setSecretError("Error: "+ error.message) + } else { + setSecretError(errorMessage.error) } - } + }, }) + const captureEnter = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + secretMutation.mutate(setupResponse) + } + } + const alertIcon = + const hasMoreOptions = cloudType === "aws" || cloudType === "digitalocean" ? true : false + const colSpanWithSSH = hasMoreOptions ? 3 : 6 + return ( -
-
- Start Setup... + +
+ Start Setup +
+ {secretError !== "" ? + + + + {secretError} + + + : + null + } + + + + + {hasMoreOptions ? "Option 1: " : ""}With SSH Access Enter the secret to start the setup. @@ -57,14 +99,64 @@ export function SetSecret({onChangeStep, onChangeSecret}: Props) { setSecret(event.currentTarget.value)} - value={secret} - error={secretError} + onChange={(event) => setSetupResponse({ ...setupResponse, secret: event.currentTarget.value})} + value={setupResponse.secret} + onKeyDown={(e) => captureEnter(e)} /> - +
)} -
-
+ + + {cloudType === "aws" ? + + + {hasMoreOptions ? "Option 2: " : ""}Without SSH Access + + + Enter the EC2 Instance ID of the VPN Server + + {secretMutation.isPending ? ( +
Checking Instance ID...
+ ) : ( +
+ setSetupResponse({ ...setupResponse, instanceID: event.currentTarget.value})} + value={setupResponse.instanceID} + onKeyDown={(e) => captureEnter(e)} + /> + +
+ )} +
+
+ : null } + {cloudType === "digitalocean" ? + + + {hasMoreOptions ? "Option 2: " : ""}Without SSH Access + + + Add the following tag to the droplet by going to the droplet settings and opening the Tags page. You can remove the tag once the setup is complete. + + {secretMutation.isPending ? ( +
Checking tag...
+ ) : ( +
+ + +
+ )} +
+
+ : null } + + ); } \ No newline at end of file diff --git a/webapp/src/AppInit/SetupBanner.module.css b/webapp/src/AppInit/SetupBanner.module.css index fc2debc..37e8dfd 100644 --- a/webapp/src/AppInit/SetupBanner.module.css +++ b/webapp/src/AppInit/SetupBanner.module.css @@ -2,9 +2,6 @@ display: flex; align-items: center; padding: calc(var(--mantine-spacing-xl) * 2); - border-radius: var(--mantine-radius-md); - background-color: light-dark(var(--mantine-color-white), var(--mantine-color-dark-8)); - border: rem(1px) solid light-dark(var(--mantine-color-gray-3), var(--mantine-color-dark-8)); @media (max-width: $mantine-breakpoint-sm) { flex-direction: column-reverse; @@ -58,4 +55,7 @@ border-top-left-radius: 0; border-bottom-left-radius: 0; } - \ No newline at end of file + + .error:first-letter { + text-transform: capitalize + } \ No newline at end of file diff --git a/webapp/src/AppInit/SetupBanner.tsx b/webapp/src/AppInit/SetupBanner.tsx index 76b99f6..9ebf57c 100644 --- a/webapp/src/AppInit/SetupBanner.tsx +++ b/webapp/src/AppInit/SetupBanner.tsx @@ -5,11 +5,12 @@ import React from 'react'; type Props = { onCompleted: (newType: boolean) => void; + cloudType: string; }; -export function SetupBanner({onCompleted}:Props) { +export function SetupBanner({onCompleted, cloudType}:Props) { const [step, setStep] = useState(0); - const [secret, setSecret] = useState(""); + const [secrets, setSecrets] = useState({secret: "", tagHash: "", instanceID: ""}); React.useEffect(() => { if(step === 2) { @@ -18,8 +19,8 @@ export function SetupBanner({onCompleted}:Props) { }, [step]); if(step === 0) { - return + return } else if(step === 1) { - return + return } } \ No newline at end of file diff --git a/webapp/src/types/SetupRequest.tsx b/webapp/src/types/SetupRequest.tsx new file mode 100644 index 0000000..e7f1bf5 --- /dev/null +++ b/webapp/src/types/SetupRequest.tsx @@ -0,0 +1,5 @@ +type SetupResponse = { + secret: string; + tagHash: string; + instanceID: string; + } \ No newline at end of file