From 974353448c6ee094a3c05b0a00d2708d2a78de2a Mon Sep 17 00:00:00 2001 From: Edward Viaene Date: Wed, 11 Sep 2024 10:51:05 -0500 Subject: [PATCH] backend work for improved setup flow --- docs/release-notes.md | 9 ++ pkg/license/aws.go | 32 +++- pkg/license/azure.go | 4 +- pkg/license/digitalocean.go | 41 ++++- pkg/license/gcp.go | 4 +- pkg/license/gcp_test.go | 4 +- pkg/license/license.go | 2 +- pkg/license/license_test.go | 20 +-- pkg/rest/context.go | 1 + pkg/rest/setup.go | 51 +++++-- pkg/rest/setup_test.go | 293 ++++++++++++++++++++++++++++++++++++ pkg/rest/types.go | 5 + 12 files changed, 432 insertions(+), 34 deletions(-) create mode 100644 pkg/rest/setup_test.go diff --git a/docs/release-notes.md b/docs/release-notes.md index 698137b..0f53d4f 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,5 +1,14 @@ # Release Notes +## Version v1.1.4 +* Improved setup flow for AWS & DigitalOcean + +## Version v1.1.3 +* New Feature: Log packets traversing the VPN Server. This release supports logging TCP / DNS / HTTP / HTTPS packets and inspecting the destination of http/https packets. + +## Version v1.1.2 +* UI: fixes in user creation + ## Version v1.1.0 * UI: change VPN configuration within the admin UI * UI: ability to reload WireGuard® configuration diff --git a/pkg/license/aws.go b/pkg/license/aws.go index 1a81e87..4b507e7 100644 --- a/pkg/license/aws.go +++ b/pkg/license/aws.go @@ -13,7 +13,7 @@ import ( const AWS_PRODUCT_CODE = "7h7h3bnutjn0ziamv7npi8a69" func getMetadataToken(client http.Client) string { - metadataEndpoint := "http://" + metadataIP + "/latest/api/token" + metadataEndpoint := "http://" + MetadataIP + "/latest/api/token" req, err := http.NewRequest("PUT", metadataEndpoint, nil) if err != nil { @@ -62,7 +62,7 @@ func isOnAWS(client http.Client) bool { func getInstanceIdentityDocument(client http.Client, token string) (InstanceIdentityDocument, error) { var instanceIdentityDocument InstanceIdentityDocument - endpoint := "http://" + metadataIP + "/2022-09-24/dynamic/instance-identity/document" + endpoint := "http://" + MetadataIP + "/2022-09-24/dynamic/instance-identity/document" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return instanceIdentityDocument, err @@ -145,7 +145,7 @@ func getLicense(client http.Client, key string) (License, error) { } func getLicenseFromMetaData(token string, client http.Client) (string, error) { - endpoint := "http://" + metadataIP + "/2022-09-24/meta-data/tags/instance/license" + endpoint := "http://" + MetadataIP + "/2022-09-24/meta-data/tags/instance/license" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return "", err @@ -173,7 +173,7 @@ func getLicenseFromMetaData(token string, client http.Client) (string, error) { func getAWSInstanceType(client http.Client) string { token := getMetadataToken(client) - endpoint := "http://" + metadataIP + "/latest/meta-data/instance-type" + endpoint := "http://" + MetadataIP + "/latest/meta-data/instance-type" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return "" @@ -194,6 +194,30 @@ func getAWSInstanceType(client http.Client) string { return "" } +func GetAWSInstanceID(client http.Client) (string, error) { + token := getMetadataToken(client) + + endpoint := "http://" + MetadataIP + "/latest/meta-data/instance-id" + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return "", err + } + if token != "" { + req.Header.Add("X-aws-ec2-metadata-token", token) + } + + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode == 200 { + bodyBytes, _ := io.ReadAll(resp.Body) + return string(bodyBytes), err + } + return "", fmt.Errorf("received statuscode %d from aws metadata api", resp.StatusCode) +} + func GetMaxUsersAWS(instanceType string) int { if instanceType == "" { return 3 diff --git a/pkg/license/azure.go b/pkg/license/azure.go index 468cc6f..18ab9de 100644 --- a/pkg/license/azure.go +++ b/pkg/license/azure.go @@ -9,7 +9,7 @@ import ( ) func isOnAzure(client http.Client) bool { - req, err := http.NewRequest("GET", "http://"+metadataIP+"/metadata/versions", nil) + req, err := http.NewRequest("GET", "http://"+MetadataIP+"/metadata/versions", nil) if err != nil { return false } @@ -51,7 +51,7 @@ func GetMaxUsersAzure(instanceType string) int { return 3 } func getAzureInstanceType(client http.Client) string { - metadataEndpoint := "http://" + metadataIP + "/metadata/instance?api-version=2021-02-01" + metadataEndpoint := "http://" + MetadataIP + "/metadata/instance?api-version=2021-02-01" req, err := http.NewRequest("GET", metadataEndpoint, nil) if err != nil { return "" diff --git a/pkg/license/digitalocean.go b/pkg/license/digitalocean.go index 04b6a05..5be4c07 100644 --- a/pkg/license/digitalocean.go +++ b/pkg/license/digitalocean.go @@ -1,6 +1,7 @@ package license import ( + "bufio" "fmt" "io" "net/http" @@ -11,7 +12,7 @@ import ( ) func isOnDigitalOcean(client http.Client) bool { - endpoint := "http://" + metadataIP + "/metadata/v1/interfaces/private/0/type" + endpoint := "http://" + MetadataIP + "/metadata/v1/interfaces/private/0/type" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return false @@ -60,7 +61,7 @@ func getDigitalOceanLicenseKey(storage storage.ReadWriter, client http.Client) ( func getDigitalOceanIdentifier(client http.Client) (string, error) { id := "" - endpoint := "http://" + metadataIP + "/metadata/v1/id" + endpoint := "http://" + MetadataIP + "/metadata/v1/id" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return id, err @@ -82,3 +83,39 @@ func getDigitalOceanIdentifier(client http.Client) (string, error) { return strings.TrimSpace(string(body)), nil } + +func HasDigitalOceanTagSet(client http.Client, tag string) (bool, error) { + endpoint := "http://" + MetadataIP + "/metadata/v1/tags" + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return false, err + } + + resp, err := client.Do(req) + if err != nil { + return false, err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return false, err + } + return false, fmt.Errorf("wrong statuscode returned: %d; body: %s", resp.StatusCode, body) + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + if tag == strings.TrimSpace(scanner.Text()) { + return true, nil + } + } + + if err := scanner.Err(); err != nil { + return false, err + } + + return false, nil + +} diff --git a/pkg/license/gcp.go b/pkg/license/gcp.go index 8428ef9..5e69c54 100644 --- a/pkg/license/gcp.go +++ b/pkg/license/gcp.go @@ -11,7 +11,7 @@ import ( ) func isOnGCP(client http.Client) bool { - endpoint := "http://" + metadataIP + "/computeMetadata/v1/" + endpoint := "http://" + MetadataIP + "/computeMetadata/v1/" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return false @@ -62,7 +62,7 @@ func getGCPLicenseKey(storage storage.ReadWriter, client http.Client) (string, e func getGCPIdentifier(client http.Client) (string, error) { id := "" - endpoint := "http://" + metadataIP + "/computeMetadata/v1/project/project-id" + endpoint := "http://" + MetadataIP + "/computeMetadata/v1/project/project-id" req, err := http.NewRequest("GET", endpoint, nil) if err != nil { return id, err diff --git a/pkg/license/gcp_test.go b/pkg/license/gcp_test.go index 37d2672..f52b047 100644 --- a/pkg/license/gcp_test.go +++ b/pkg/license/gcp_test.go @@ -22,7 +22,7 @@ func TestGuessInfrastructureGCP(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) infra := guessInfrastructure() @@ -50,7 +50,7 @@ func TestGetMaxUsersGCPBYOL(t *testing.T) { defer ts.Close() licenseURL = ts.URL - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) mockStorage := &memorystorage.MockMemoryStorage{} err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license")) diff --git a/pkg/license/license.go b/pkg/license/license.go index 5ddf077..b3cfb77 100644 --- a/pkg/license/license.go +++ b/pkg/license/license.go @@ -11,7 +11,7 @@ import ( randomutils "github.com/in4it/wireguard-server/pkg/utils/random" ) -var metadataIP = "169.254.169.254" +var MetadataIP = "169.254.169.254" var licenseURL = "https://in4it-vpn-server.s3.amazonaws.com/licenses" func guessInfrastructure() string { diff --git a/pkg/license/license_test.go b/pkg/license/license_test.go index 9f41356..87172fa 100644 --- a/pkg/license/license_test.go +++ b/pkg/license/license_test.go @@ -96,7 +96,7 @@ func TestGetMaxUsersAWSBYOL(t *testing.T) { "t3.xlarge": 50, } licenseURL = ts.URL - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) for _, v := range testCases { if v2 := GetMaxUsersAWSBYOL(http.Client{Timeout: 5 * time.Second}, &memorystorage.MockMemoryStorage{}); v2 != v { t.Fatalf("Wrong output: %d vs %d", v2, v) @@ -127,7 +127,7 @@ func TestGuessInfrastructureAzure(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) infra := guessInfrastructure() @@ -170,7 +170,7 @@ func TestGuessInfrastructureAWSMarketplace(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) infra := guessInfrastructure() @@ -204,7 +204,7 @@ func TestGuessInfrastructureAWS(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) infra := guessInfrastructure() @@ -223,7 +223,7 @@ func TestGuessInfrastructureOther(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) infra := guessInfrastructure() @@ -248,7 +248,7 @@ func TestGetAzureInstanceType(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) usersPerVCPU := 25 @@ -275,7 +275,7 @@ func TestGetAWSInstanceType(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) users := GetMaxUsersAWS(getAWSInstanceType(http.Client{Timeout: 5 * time.Second})) @@ -294,7 +294,7 @@ func TestGuessInfrastructureDigitalOcean(t *testing.T) { })) defer ts.Close() - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) infra := guessInfrastructure() @@ -330,7 +330,7 @@ func TestGetMaxUsersDigitalOceanBYOL(t *testing.T) { defer ts.Close() licenseURL = ts.URL - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) mockStorage := &memorystorage.MockMemoryStorage{} err := mockStorage.WriteFile("config/license.key", []byte("license-1234556-license")) @@ -383,7 +383,7 @@ func TestGetLicenseKey(t *testing.T) { w.WriteHeader(http.StatusNotFound) })) - metadataIP = strings.Replace(ts.URL, "http://", "", -1) + MetadataIP = strings.Replace(ts.URL, "http://", "", -1) logging.Loglevel = logging.LOG_DEBUG + logging.LOG_ERROR key := GetLicenseKey(&memorystorage.MockMemoryStorage{}, "") diff --git a/pkg/rest/context.go b/pkg/rest/context.go index ddea413..5343ba3 100644 --- a/pkg/rest/context.go +++ b/pkg/rest/context.go @@ -110,6 +110,7 @@ func getEmptyContext(appDir string) (*Context, error) { TokenRenewalTimeMinutes: oidcrenewal.DEFAULT_RENEWAL_TIME_MINUTES, LogLevel: logging.LOG_ERROR, SCIM: &SCIM{EnableSCIM: false}, + SAML: &SAML{Providers: &[]saml.Provider{}}, } return c, nil } diff --git a/pkg/rest/setup.go b/pkg/rest/setup.go index d100c2c..454796a 100644 --- a/pkg/rest/setup.go +++ b/pkg/rest/setup.go @@ -18,6 +18,7 @@ import ( "github.com/google/uuid" "github.com/in4it/wireguard-server/pkg/auth/oidc" "github.com/in4it/wireguard-server/pkg/auth/saml" + "github.com/in4it/wireguard-server/pkg/license" "github.com/in4it/wireguard-server/pkg/users" "github.com/in4it/wireguard-server/pkg/wireguard" ) @@ -35,14 +36,44 @@ func (c *Context) contextHandler(w http.ResponseWriter, r *http.Request) { c.SetupCompleted = true } if !c.SetupCompleted { - localSecret, err := c.Storage.Client.ReadFile(SETUP_CODE_FILE) - if err != nil { - c.returnError(w, fmt.Errorf("secret file read error: %s", err), http.StatusBadRequest) - return + // check if tag hash is chosen + accessGranted := false + switch c.CloudType { + case "digitalocean": // check if the hashtag is set + if contextReq.TagHash != "" { + accessGranted, err = license.HasDigitalOceanTagSet(http.Client{Timeout: 5 * time.Second}, contextReq.TagHash) + if err != nil { + c.returnError(w, fmt.Errorf("could not retrieve tags at this time: %s", err), http.StatusUnauthorized) + return + } + if !accessGranted { + c.returnError(w, fmt.Errorf("tag not found. Make sure the correct tag is attached to the droplet"), http.StatusUnauthorized) + return + } + } + case "aws": // check if the instance id is set + if contextReq.InstanceID != "" { + instanceID, err := license.GetAWSInstanceID(http.Client{Timeout: 5 * time.Second}) + if err != nil { + c.returnError(w, fmt.Errorf("could not retrieve instance id at this time: %s", err), http.StatusUnauthorized) + return + } + if strings.TrimPrefix(instanceID, "i-") == strings.TrimPrefix(contextReq.InstanceID, "i-") { + accessGranted = true + } + } } - if strings.TrimSpace(string(localSecret)) != contextReq.Secret { - c.returnError(w, fmt.Errorf("wrong secret provided"), http.StatusUnauthorized) - return + // check secret + if !accessGranted { + localSecret, err := c.Storage.Client.ReadFile(SETUP_CODE_FILE) + if err != nil { + c.returnError(w, fmt.Errorf("secret file read error: %s", err), http.StatusBadRequest) + return + } + if strings.TrimSpace(string(localSecret)) != contextReq.Secret { + c.returnError(w, fmt.Errorf("wrong secret provided"), http.StatusUnauthorized) + return + } } if contextReq.AdminPassword != "" { adminUser := users.User{ @@ -95,10 +126,8 @@ func (c *Context) contextHandler(w http.ResponseWriter, r *http.Request) { } } } - cOut := Context{ - SetupCompleted: c.SetupCompleted, - } - out, err := json.Marshal(cOut) + + out, err := json.Marshal(ContextSetupResponse{SetupCompleted: c.SetupCompleted}) if err != nil { c.returnError(w, err, http.StatusBadRequest) return diff --git a/pkg/rest/setup_test.go b/pkg/rest/setup_test.go new file mode 100644 index 0000000..0465b7b --- /dev/null +++ b/pkg/rest/setup_test.go @@ -0,0 +1,293 @@ +package rest + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/in4it/wireguard-server/pkg/license" + memorystorage "github.com/in4it/wireguard-server/pkg/storage/memory" + "github.com/in4it/wireguard-server/pkg/users" +) + +func TestContextHandlerSetupSecret(t *testing.T) { + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + + payload := ContextRequest{ + Secret: "secret setup code", + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if !contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to be completed") + } +} + +func TestContextHandlerSetupWrongSecret(t *testing.T) { + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + + payload := ContextRequest{ + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 401 { + t.Fatalf("status code is not 401: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to not be completed") + } +} +func TestContextHandlerSetupWrongSecretPartial(t *testing.T) { + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + + payload := ContextRequest{ + Secret: "secret setup cod", + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 401 { + t.Fatalf("status code is not 401: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to not be completed") + } +} + +func TestContextHandlerSetupAWSInstanceID(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI == "/latest/api/token" { + w.Write([]byte("this is a test token")) + return + } + if r.RequestURI == "/latest/meta-data/instance-id" { + w.Write([]byte("i-012aaaaaaaaaaaaa1")) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + defer ts.Close() + license.MetadataIP = strings.TrimPrefix(ts.URL, "http://") + + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + c.CloudType = "aws" + + payload := ContextRequest{ + InstanceID: "i-012aaaaaaaaaaaaa1", + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if !contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to be completed") + } +} +func TestContextHandlerSetupDigitalOceanTag(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI == "/metadata/v1/tags" { + w.Write([]byte("this-is-a-secret-tag")) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + defer ts.Close() + license.MetadataIP = strings.TrimPrefix(ts.URL, "http://") + + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + c.CloudType = "digitalocean" + + payload := ContextRequest{ + TagHash: "this-is-a-secret-tag", + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if !contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to be completed") + } +} diff --git a/pkg/rest/types.go b/pkg/rest/types.go index dbbcd8e..66af301 100644 --- a/pkg/rest/types.go +++ b/pkg/rest/types.go @@ -61,10 +61,15 @@ type Storage struct { type ContextRequest struct { Secret string `json:"secret"` + TagHash string `json:"tagHash"` + InstanceID string `json:"instanceID"` AdminPassword string `json:"adminPassword"` Hostname string `json:"hostname"` Protocol string `json:"protocol"` } +type ContextSetupResponse struct { + SetupCompleted bool `json:"setupCompleted"` +} type AuthMethodsResponse struct { LocalAuthDisabled bool `json:"localAuthDisabled"`