diff --git a/pkg/api/v2/consts.go b/pkg/api/v2/consts.go new file mode 100644 index 0000000..2fb8a82 --- /dev/null +++ b/pkg/api/v2/consts.go @@ -0,0 +1,6 @@ +package v2 + +const ( + // CAPIAuthTokenHeader is the header used to pass the CAPI auth token. + CAPIAuthTokenHeader = "capi-auth-token" +) diff --git a/pkg/api/v2/register.go b/pkg/api/v2/register.go index 2178edf..115bd46 100644 --- a/pkg/api/v2/register.go +++ b/pkg/api/v2/register.go @@ -53,4 +53,27 @@ func (a *API) RegisterServer(server *http.ServeMux, middleware func(f http.Handl } httputil.Response(w, map[string]string{"status": "OK"}) })) + + // POST v2/dqlite/remove + server.HandleFunc(fmt.Sprintf("%s/dqlite/remove", HTTPPrefix), middleware(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + req := RemoveFromDqliteRequest{} + if err := httputil.UnmarshalJSON(r, &req); err != nil { + httputil.Error(w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal JSON: %w", err)) + return + } + + token := r.Header.Get(CAPIAuthTokenHeader) + + if rc, err := a.RemoveFromDqlite(r.Context(), req, token); err != nil { + httputil.Error(w, rc, fmt.Errorf("failed to remove from dqlite: %w", err)) + return + } + + httputil.Response(w, nil) + })) } diff --git a/pkg/api/v2/remove.go b/pkg/api/v2/remove.go new file mode 100644 index 0000000..0fd00c0 --- /dev/null +++ b/pkg/api/v2/remove.go @@ -0,0 +1,33 @@ +package v2 + +import ( + "context" + "fmt" + "net/http" + + snaputil "github.com/canonical/microk8s-cluster-agent/pkg/snap/util" +) + +// RemoveFromDqliteRequest represents a request to remove a node from the dqlite cluster. +type RemoveFromDqliteRequest struct { + // RemoveEndpoint is the endpoint of the node to remove from the dqlite cluster. + RemoveEndpoint string `json:"remove_endpoint"` +} + +// RemoveFromDqlite implements the "POST /v2/dqlite/remove" endpoint and removes a node from the dqlite cluster. +func (a *API) RemoveFromDqlite(ctx context.Context, req RemoveFromDqliteRequest, token string) (int, error) { + isValid, err := a.Snap.IsCAPIAuthTokenValid(token) + if err != nil { + return http.StatusInternalServerError, fmt.Errorf("failed to validate CAPI auth token: %w", err) + } + + if !isValid { + return http.StatusUnauthorized, fmt.Errorf("invalid CAPI auth token %q", token) + } + + if err := snaputil.RemoveNodeFromDqlite(ctx, a.Snap, req.RemoveEndpoint); err != nil { + return http.StatusInternalServerError, fmt.Errorf("failed to remove node from dqlite: %w", err) + } + + return http.StatusOK, nil +} diff --git a/pkg/api/v2/remove_test.go b/pkg/api/v2/remove_test.go new file mode 100644 index 0000000..2371b4d --- /dev/null +++ b/pkg/api/v2/remove_test.go @@ -0,0 +1,74 @@ +package v2_test + +import ( + "context" + "errors" + "net/http" + "testing" + + . "github.com/onsi/gomega" + + v2 "github.com/canonical/microk8s-cluster-agent/pkg/api/v2" + "github.com/canonical/microk8s-cluster-agent/pkg/snap/mock" +) + +func TestRemove(t *testing.T) { + t.Run("RemoveFails", func(t *testing.T) { + cmdErr := errors.New("failed to run command") + apiv2 := &v2.API{ + Snap: &mock.Snap{ + RunCommandErr: cmdErr, + CAPIAuthTokenValid: true, + }, + } + + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") + + g := NewWithT(t) + g.Expect(err).To(MatchError(cmdErr)) + g.Expect(rc).To(Equal(http.StatusInternalServerError)) + }) + + t.Run("InvalidToken", func(t *testing.T) { + apiv2 := &v2.API{ + Snap: &mock.Snap{ + CAPIAuthTokenValid: false, // explicitly set to false + }, + } + + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") + + g := NewWithT(t) + g.Expect(err).To(HaveOccurred()) + g.Expect(rc).To(Equal(http.StatusUnauthorized)) + }) + + t.Run("TokenFileNotFound", func(t *testing.T) { + tokenErr := errors.New("token file not found") + apiv2 := &v2.API{ + Snap: &mock.Snap{ + CAPIAuthTokenError: tokenErr, + }, + } + + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") + + g := NewWithT(t) + g.Expect(err).To(MatchError(tokenErr)) + g.Expect(rc).To(Equal(http.StatusInternalServerError)) + }) + + t.Run("RemovesSuccessfully", func(t *testing.T) { + apiv2 := &v2.API{ + Snap: &mock.Snap{ + CAPIAuthTokenValid: true, + }, + } + + rc, err := apiv2.RemoveFromDqlite(context.Background(), v2.RemoveFromDqliteRequest{RemoveEndpoint: "1.1.1.1:1234"}, "token") + + g := NewWithT(t) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(rc).To(Equal(http.StatusOK)) + }) +} diff --git a/pkg/snap/interface.go b/pkg/snap/interface.go index e7a85e8..35c863d 100644 --- a/pkg/snap/interface.go +++ b/pkg/snap/interface.go @@ -7,6 +7,18 @@ import ( // Snap is how the cluster agent interacts with the snap. type Snap interface { + // GetSnapPath returns the path to a file or directory in the snap directory. + GetSnapPath(parts ...string) string + // GetSnapDataPath returns the path to a file or directory in the snap's data directory. + GetSnapDataPath(parts ...string) string + // GetSnapCommonPath returns the path to a file or directory in the snap's common directory. + GetSnapCommonPath(parts ...string) string + // GetCAPIPath returns the path to a file or directory in the CAPI directory. + GetCAPIPath(parts ...string) string + + // RunCommand runs a shell command. + RunCommand(ctx context.Context, commands ...string) error + // GetGroupName is the group microk8s is using. // The group name is "microk8s" for classic snaps and "snap_microk8s" for strict snaps. GetGroupName() string @@ -88,6 +100,9 @@ type Snap interface { // GetKnownToken returns the token for a known user from the known_users.csv file. GetKnownToken(username string) (string, error) + // IsCAPIAuthTokenValid returns true if token is a valid CAPI auth token. + IsCAPIAuthTokenValid(token string) (bool, error) + // SignCertificate signs the certificate signing request, and returns the certificate in PEM format. SignCertificate(ctx context.Context, csrPEM []byte) ([]byte, error) diff --git a/pkg/snap/mock/mock.go b/pkg/snap/mock/mock.go index dd6a996..54f244d 100644 --- a/pkg/snap/mock/mock.go +++ b/pkg/snap/mock/mock.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "path/filepath" "strings" "github.com/canonical/microk8s-cluster-agent/pkg/snap" @@ -23,8 +24,21 @@ type JoinClusterCall struct { Worker bool } +// RunCommandCall contains the arguments passed to a specific call of the RunCommand method. +type RunCommandCall struct { + Commands []string +} + // Snap is a generic mock for the snap.Snap interface. type Snap struct { + SnapDir string + SnapDataDir string + SnapCommonDir string + CAPIDir string + + RunCommandCalledWith []RunCommandCall + RunCommandErr error + GroupName string EnableAddonCalledWith []string @@ -72,6 +86,9 @@ type Snap struct { KubeletTokens map[string]string // map hostname to token KnownTokens map[string]string // map username to token + CAPIAuthTokenValid bool + CAPIAuthTokenError error + SignCertificateCalledWith []string // string(csrPEM) SignedCertificate string @@ -86,6 +103,32 @@ type Snap struct { JoinClusterCalledWith []JoinClusterCall } +// GetSnapPath is a mock implementation for the snap.Snap interface. +func (s *Snap) GetSnapPath(parts ...string) string { + return filepath.Join(append([]string{s.SnapDir}, parts...)...) +} + +// GetSnapDataPath is a mock implementation for the snap.Snap interface. +func (s *Snap) GetSnapDataPath(parts ...string) string { + return filepath.Join(append([]string{s.SnapDataDir}, parts...)...) +} + +// GetSnapCommonPath is a mock implementation for the snap.Snap interface. +func (s *Snap) GetSnapCommonPath(parts ...string) string { + return filepath.Join(append([]string{s.SnapCommonDir}, parts...)...) +} + +// GetCAPIPath is a mock implementation for the snap.Snap interface. +func (s *Snap) GetCAPIPath(parts ...string) string { + return filepath.Join(append([]string{s.CAPIDir}, parts...)...) +} + +// RunCommand is a mock implementation for the snap.Snap interface. +func (s *Snap) RunCommand(_ context.Context, commands ...string) error { + s.RunCommandCalledWith = append(s.RunCommandCalledWith, RunCommandCall{Commands: commands}) + return s.RunCommandErr +} + // GetGroupName is a mock implementation for the snap.Snap interface. func (s *Snap) GetGroupName() string { return s.GroupName @@ -284,6 +327,11 @@ func (s *Snap) GetKnownToken(username string) (string, error) { return "", fmt.Errorf("no known token for user %s", username) } +// IsCAPIAuthTokenValid is a mock implementation for the snap.Snap interface. +func (s *Snap) IsCAPIAuthTokenValid(token string) (bool, error) { + return s.CAPIAuthTokenValid, s.CAPIAuthTokenError +} + // RunUpgrade is a mock implementation for the snap.Snap interface. func (s *Snap) RunUpgrade(ctx context.Context, upgrade string, phase string) error { s.RunUpgradeCalledWith = append(s.RunUpgradeCalledWith, fmt.Sprintf("%s %s", upgrade, phase)) diff --git a/pkg/snap/options.go b/pkg/snap/options.go index a439e41..f58456c 100644 --- a/pkg/snap/options.go +++ b/pkg/snap/options.go @@ -22,3 +22,10 @@ func WithCommandRunner(f func(context.Context, ...string) error) func(s *snap) { s.runCommand = f } } + +// WithCAPIPath configures the path to the CAPI directory. +func WithCAPIPath(path string) func(s *snap) { + return func(s *snap) { + s.capiPath = path + } +} diff --git a/pkg/snap/snap.go b/pkg/snap/snap.go index 16e018d..3405646 100644 --- a/pkg/snap/snap.go +++ b/pkg/snap/snap.go @@ -23,6 +23,7 @@ type snap struct { snapDir string snapDataDir string snapCommonDir string + capiPath string runCommand func(context.Context, ...string) error clusterTokensMu sync.Mutex @@ -34,6 +35,10 @@ type snap struct { applyCNIBackoff time.Duration } +const ( + defaultCAPIPath = "/capi" +) + // NewSnap creates a new interface with the MicroK8s snap. // NewSnap accepts the $SNAP, $SNAP_DATA and $SNAP_COMMON, directories, and a number of options. func NewSnap(snapDir, snapDataDir, snapCommonDir string, options ...func(s *snap)) Snap { @@ -41,6 +46,7 @@ func NewSnap(snapDir, snapDataDir, snapCommonDir string, options ...func(s *snap snapDir: snapDir, snapDataDir: snapDataDir, snapCommonDir: snapCommonDir, + capiPath: defaultCAPIPath, runCommand: util.RunCommand, } @@ -51,16 +57,23 @@ func NewSnap(snapDir, snapDataDir, snapCommonDir string, options ...func(s *snap } -func (s *snap) snapPath(parts ...string) string { +func (s *snap) RunCommand(ctx context.Context, commands ...string) error { + return s.runCommand(ctx, commands...) +} + +func (s *snap) GetSnapPath(parts ...string) string { return filepath.Join(append([]string{s.snapDir}, parts...)...) } -func (s *snap) snapDataPath(parts ...string) string { +func (s *snap) GetSnapDataPath(parts ...string) string { return filepath.Join(append([]string{s.snapDataDir}, parts...)...) } -func (s *snap) snapCommonPath(parts ...string) string { +func (s *snap) GetSnapCommonPath(parts ...string) string { return filepath.Join(append([]string{s.snapCommonDir}, parts...)...) } +func (s *snap) GetCAPIPath(parts ...string) string { + return filepath.Join(append([]string{s.capiPath}, parts...)...) +} func (s *snap) GetGroupName() string { if s.isStrict() { @@ -70,11 +83,11 @@ func (s *snap) GetGroupName() string { } func (s *snap) EnableAddon(ctx context.Context, addon string, args ...string) error { - return s.runCommand(ctx, append([]string{s.snapPath("microk8s-enable.wrapper"), addon}, args...)...) + return s.runCommand(ctx, append([]string{s.GetSnapPath("microk8s-enable.wrapper"), addon}, args...)...) } func (s *snap) DisableAddon(ctx context.Context, addon string, args ...string) error { - return s.runCommand(ctx, append([]string{s.snapPath("microk8s-disable.wrapper"), addon}, args...)...) + return s.runCommand(ctx, append([]string{s.GetSnapPath("microk8s-disable.wrapper"), addon}, args...)...) } type snapcraftYml struct { @@ -83,7 +96,7 @@ type snapcraftYml struct { func (s *snap) isStrict() bool { var meta snapcraftYml - contents, err := util.ReadFile(s.snapPath("meta", "snapcraft.yaml")) + contents, err := util.ReadFile(s.GetSnapPath("meta", "snapcraft.yaml")) if err != nil { return false } @@ -122,7 +135,7 @@ func (s *snap) RunUpgrade(ctx context.Context, upgrade string, phase string) err default: return fmt.Errorf("unknown upgrade phase %q", phase) } - scriptName := s.snapPath("upgrade-scripts", upgrade, fmt.Sprintf("%s-node.sh", phase)) + scriptName := s.GetSnapPath("upgrade-scripts", upgrade, fmt.Sprintf("%s-node.sh", phase)) if !util.FileExists(scriptName) { return fmt.Errorf("could not find script %s", scriptName) } @@ -133,41 +146,41 @@ func (s *snap) RunUpgrade(ctx context.Context, upgrade string, phase string) err } func (s *snap) ReadCA() (string, error) { - return util.ReadFile(s.snapDataPath("certs", "ca.crt")) + return util.ReadFile(s.GetSnapDataPath("certs", "ca.crt")) } func (s *snap) ReadCAKey() (string, error) { - return util.ReadFile(s.snapDataPath("certs", "ca.key")) + return util.ReadFile(s.GetSnapDataPath("certs", "ca.key")) } func (s *snap) GetCAPath() string { - return s.snapDataPath("certs", "ca.crt") + return s.GetSnapDataPath("certs", "ca.crt") } func (s *snap) GetCAKeyPath() string { - return s.snapDataPath("certs", "ca.key") + return s.GetSnapDataPath("certs", "ca.key") } func (s *snap) ReadServiceAccountKey() (string, error) { - return util.ReadFile(s.snapDataPath("certs", "serviceaccount.key")) + return util.ReadFile(s.GetSnapDataPath("certs", "serviceaccount.key")) } func (s *snap) GetCNIYamlPath() string { - return s.snapDataPath("args", "cni-network", "cni.yaml") + return s.GetSnapDataPath("args", "cni-network", "cni.yaml") } func (s *snap) ReadCNIYaml() (string, error) { - return util.ReadFile(s.snapDataPath("args", "cni-network", "cni.yaml")) + return util.ReadFile(s.GetSnapDataPath("args", "cni-network", "cni.yaml")) } func (s *snap) WriteCNIYaml(cniManifest []byte) error { - return os.WriteFile(s.snapDataPath("args", "cni-network", "cni.yaml"), []byte(cniManifest), 0660) + return os.WriteFile(s.GetSnapDataPath("args", "cni-network", "cni.yaml"), []byte(cniManifest), 0660) } func (s *snap) ApplyCNI(ctx context.Context) error { var err error for i := 0; i < s.applyCNIRetries; i++ { - if err = s.runCommand(ctx, s.snapPath("microk8s-kubectl.wrapper"), "apply", "-f", s.GetCNIYamlPath()); err == nil { + if err = s.runCommand(ctx, s.GetSnapPath("microk8s-kubectl.wrapper"), "apply", "-f", s.GetCNIYamlPath()); err == nil { return nil } time.Sleep(s.applyCNIBackoff) @@ -176,61 +189,61 @@ func (s *snap) ApplyCNI(ctx context.Context) error { } func (s *snap) ReadDqliteCert() (string, error) { - return util.ReadFile(s.snapDataPath("var", "kubernetes", "backend", "cluster.crt")) + return util.ReadFile(s.GetSnapDataPath("var", "kubernetes", "backend", "cluster.crt")) } func (s *snap) ReadDqliteKey() (string, error) { - return util.ReadFile(s.snapDataPath("var", "kubernetes", "backend", "cluster.key")) + return util.ReadFile(s.GetSnapDataPath("var", "kubernetes", "backend", "cluster.key")) } func (s *snap) ReadDqliteInfoYaml() (string, error) { - return util.ReadFile(s.snapDataPath("var", "kubernetes", "backend", "info.yaml")) + return util.ReadFile(s.GetSnapDataPath("var", "kubernetes", "backend", "info.yaml")) } func (s *snap) ReadDqliteClusterYaml() (string, error) { - return util.ReadFile(s.snapDataPath("var", "kubernetes", "backend", "cluster.yaml")) + return util.ReadFile(s.GetSnapDataPath("var", "kubernetes", "backend", "cluster.yaml")) } func (s *snap) WriteDqliteUpdateYaml(updateYaml []byte) error { - return os.WriteFile(s.snapDataPath("var", "kubernetes", "backend", "update.yaml"), updateYaml, 0660) + return os.WriteFile(s.GetSnapDataPath("var", "kubernetes", "backend", "update.yaml"), updateYaml, 0660) } func (s *snap) GetKubeconfigFile() string { - return s.snapDataPath("credentials", "client.config") + return s.GetSnapDataPath("credentials", "client.config") } func (s *snap) HasKubeliteLock() bool { - return util.FileExists(s.snapDataPath("var", "lock", "lite.lock")) + return util.FileExists(s.GetSnapDataPath("var", "lock", "lite.lock")) } func (s *snap) HasDqliteLock() bool { - return util.FileExists(s.snapDataPath("var", "lock", "ha-cluster")) + return util.FileExists(s.GetSnapDataPath("var", "lock", "ha-cluster")) } func (s *snap) HasNoCertsReissueLock() bool { - return util.FileExists(s.snapDataPath("var", "lock", "no-cert-reissue")) + return util.FileExists(s.GetSnapDataPath("var", "lock", "no-cert-reissue")) } func (s *snap) CreateNoCertsReissueLock() error { - _, err := os.OpenFile(s.snapDataPath("var", "lock", "no-cert-reissue"), os.O_CREATE, 0600) + _, err := os.OpenFile(s.GetSnapDataPath("var", "lock", "no-cert-reissue"), os.O_CREATE, 0600) return err } func (s *snap) ReadServiceArguments(serviceName string) (string, error) { - return util.ReadFile(s.snapDataPath("args", serviceName)) + return util.ReadFile(s.GetSnapDataPath("args", serviceName)) } func (s *snap) WriteServiceArguments(serviceName string, arguments []byte) error { - return os.WriteFile(s.snapDataPath("args", serviceName), arguments, 0660) + return os.WriteFile(s.GetSnapDataPath("args", serviceName), arguments, 0660) } func (s *snap) ConsumeClusterToken(token string) bool { s.clusterTokensMu.Lock() defer s.clusterTokensMu.Unlock() - if isValid, _ := util.IsValidToken(token, s.snapDataPath("credentials", "persistent-cluster-tokens.txt")); isValid { + if isValid, _ := util.IsValidToken(token, s.GetSnapDataPath("credentials", "persistent-cluster-tokens.txt")); isValid { return true } - clusterTokensFile := s.snapDataPath("credentials", "cluster-tokens.txt") + clusterTokensFile := s.GetSnapDataPath("credentials", "cluster-tokens.txt") isValid, hasTTL := util.IsValidToken(token, clusterTokensFile) if isValid && !hasTTL { if err := util.RemoveToken(token, clusterTokensFile, s.GetGroupName()); err != nil { @@ -243,7 +256,7 @@ func (s *snap) ConsumeClusterToken(token string) bool { func (s *snap) ConsumeCertificateRequestToken(token string) bool { s.certTokensMu.Lock() defer s.certTokensMu.Unlock() - certRequestTokensFile := s.snapDataPath("credentials", "certs-request-tokens.txt") + certRequestTokensFile := s.GetSnapDataPath("credentials", "certs-request-tokens.txt") isValid, _ := util.IsValidToken(token, certRequestTokensFile) if isValid { if err := util.RemoveToken(token, certRequestTokensFile, s.GetGroupName()); err != nil { @@ -254,32 +267,32 @@ func (s *snap) ConsumeCertificateRequestToken(token string) bool { } func (s *snap) ConsumeSelfCallbackToken(token string) bool { - valid, _ := util.IsValidToken(token, s.snapDataPath("credentials", "callback-token.txt")) + valid, _ := util.IsValidToken(token, s.GetSnapDataPath("credentials", "callback-token.txt")) return valid } func (s *snap) AddPersistentClusterToken(token string) error { s.certTokensMu.Lock() defer s.certTokensMu.Unlock() - return util.AppendToken(token, s.snapDataPath("credentials", "persistent-cluster-tokens.txt"), s.GetGroupName()) + return util.AppendToken(token, s.GetSnapDataPath("credentials", "persistent-cluster-tokens.txt"), s.GetGroupName()) } func (s *snap) AddCertificateRequestToken(token string) error { s.certTokensMu.Lock() defer s.certTokensMu.Unlock() - return util.AppendToken(token, s.snapDataPath("credentials", "certs-request-tokens.txt"), s.GetGroupName()) + return util.AppendToken(token, s.GetSnapDataPath("credentials", "certs-request-tokens.txt"), s.GetGroupName()) } func (s *snap) AddCallbackToken(clusterAgentEndpoint string, token string) error { s.callbackTokensMu.Lock() defer s.callbackTokensMu.Unlock() - return util.AppendToken(fmt.Sprintf("%s %s", clusterAgentEndpoint, token), s.snapDataPath("credentials", "callback-tokens.txt"), s.GetGroupName()) + return util.AppendToken(fmt.Sprintf("%s %s", clusterAgentEndpoint, token), s.GetSnapDataPath("credentials", "callback-tokens.txt"), s.GetGroupName()) } func (s *snap) GetOrCreateSelfCallbackToken() (string, error) { s.callbackTokensMu.Lock() defer s.callbackTokensMu.Unlock() - callbackTokenFile := s.snapDataPath("credentials", "callback-token.txt") + callbackTokenFile := s.GetSnapDataPath("credentials", "callback-token.txt") c, err := util.ReadFile(callbackTokenFile) if err != nil { token := util.NewRandomString(util.Alpha, 64) @@ -303,7 +316,7 @@ func (s *snap) GetOrCreateKubeletToken(hostname string) (string, error) { s.knownTokensMu.Lock() defer s.knownTokensMu.Unlock() - if err := util.AppendToken(fmt.Sprintf("%s,%s,kubelet-%s,\"system:nodes\"", token, user, uid), s.snapDataPath("credentials", "known_tokens.csv"), s.GetGroupName()); err != nil { + if err := util.AppendToken(fmt.Sprintf("%s,%s,kubelet-%s,\"system:nodes\"", token, user, uid), s.GetSnapDataPath("credentials", "known_tokens.csv"), s.GetGroupName()); err != nil { return "", fmt.Errorf("failed to add new kubelet token for %s: %w", user, err) } @@ -313,7 +326,7 @@ func (s *snap) GetOrCreateKubeletToken(hostname string) (string, error) { func (s *snap) GetKnownToken(username string) (string, error) { s.knownTokensMu.Lock() defer s.knownTokensMu.Unlock() - allTokens, err := util.ReadFile(s.snapDataPath("credentials", "known_tokens.csv")) + allTokens, err := util.ReadFile(s.GetSnapDataPath("credentials", "known_tokens.csv")) if err != nil { return "", fmt.Errorf("failed to retrieve known token for user %s: %w", username, err) } @@ -327,12 +340,21 @@ func (s *snap) GetKnownToken(username string) (string, error) { return "", fmt.Errorf("no known token found for user %s", username) } +// IsCAPIAuthTokenValid checks if the given CAPI auth token is valid. +func (s *snap) IsCAPIAuthTokenValid(token string) (bool, error) { + contents, err := util.ReadFile(s.GetCAPIPath("etc", "token")) + if err != nil { + return false, fmt.Errorf("failed to read token file: %w", err) + } + return strings.TrimSpace(contents) == token, nil +} + func (s *snap) SignCertificate(ctx context.Context, csrPEM []byte) ([]byte, error) { // TODO: consider using crypto/x509 for this instead of relying on openssl commands. // NOTE(neoaggelos): x509.CreateCertificate() has some hardcoded fields that are incompatible with MicroK8s. signCmd := exec.CommandContext(ctx, "openssl", "x509", "-sha256", "-req", - "-CA", s.snapDataPath("certs", "ca.crt"), "-CAkey", s.snapDataPath("certs", "ca.key"), + "-CA", s.GetSnapDataPath("certs", "ca.crt"), "-CAkey", s.GetSnapDataPath("certs", "ca.key"), "-CAcreateserial", "-days", "3650", ) signCmd.Stdin = bytes.NewBuffer(csrPEM) @@ -346,9 +368,9 @@ func (s *snap) SignCertificate(ctx context.Context, csrPEM []byte) ([]byte, erro func (s *snap) ImportImage(ctx context.Context, reader io.Reader) error { importCmd := exec.CommandContext(ctx, - s.snapPath("bin", "ctr"), + s.GetSnapPath("bin", "ctr"), "--namespace", "k8s.io", - "--address", s.snapCommonPath("run", "containerd.sock"), + "--address", s.GetSnapCommonPath("run", "containerd.sock"), "image", "import", "--platform", runtime.GOARCH, @@ -365,11 +387,11 @@ func (s *snap) ImportImage(ctx context.Context, reader io.Reader) error { } func (s *snap) WriteCSRConfig(csrConf []byte) error { - return os.WriteFile(s.snapDataPath("certs", "csr.conf.template"), csrConf, 0660) + return os.WriteFile(s.GetSnapDataPath("certs", "csr.conf.template"), csrConf, 0660) } func (s *snap) UpdateContainerdRegistryConfigs(configs map[string][]byte) error { - relativeHostsDir := s.snapDataPath("args", "certs.d") + relativeHostsDir := s.GetSnapDataPath("args", "certs.d") hostsDir, err := filepath.Abs(relativeHostsDir) if err != nil { return fmt.Errorf("failed to get absolute directory for registry configurations: %w", err) @@ -397,7 +419,7 @@ func (s *snap) UpdateContainerdRegistryConfigs(configs map[string][]byte) error } func (s *snap) AddAddonsRepository(ctx context.Context, name, url, reference string, force bool) error { - cmd := []string{filepath.Join(s.snapPath("microk8s-addons.wrapper")), "repo", "add", name, url} + cmd := []string{filepath.Join(s.GetSnapPath("microk8s-addons.wrapper")), "repo", "add", name, url} if reference != "" { cmd = append(cmd, "--reference", reference) } @@ -411,7 +433,7 @@ func (s *snap) AddAddonsRepository(ctx context.Context, name, url, reference str } func (s *snap) JoinCluster(ctx context.Context, url string, worker bool) error { - cmd := []string{filepath.Join(s.snapPath("microk8s-join.wrapper")), url} + cmd := []string{filepath.Join(s.GetSnapPath("microk8s-join.wrapper")), url} if worker { cmd = append(cmd, "--worker") } diff --git a/pkg/snap/snap_capi_token_test.go b/pkg/snap/snap_capi_token_test.go new file mode 100644 index 0000000..23bdb96 --- /dev/null +++ b/pkg/snap/snap_capi_token_test.go @@ -0,0 +1,37 @@ +package snap_test + +import ( + "os" + "path/filepath" + "testing" + + . "github.com/onsi/gomega" + + "github.com/canonical/microk8s-cluster-agent/pkg/snap" +) + +func TestCAPIAuthToken(t *testing.T) { + capiTestPath := "./capi-test" + os.RemoveAll(capiTestPath) + s := snap.NewSnap("", "", "", snap.WithCAPIPath(capiTestPath)) + token := "token123" + + g := NewWithT(t) + + isValid, err := s.IsCAPIAuthTokenValid(token) + g.Expect(err).To(MatchError(os.ErrNotExist)) + g.Expect(isValid).To(BeFalse()) + + capiEtc := filepath.Join(capiTestPath, "etc") + defer os.RemoveAll(capiTestPath) + g.Expect(os.MkdirAll(capiEtc, 0755)).To(Succeed()) + g.Expect(os.WriteFile("./capi-test/etc/token", []byte(token), 0600)).To(Succeed()) + + isValid, err = s.IsCAPIAuthTokenValid("random-token") + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(isValid).To(BeFalse()) + + isValid, err = s.IsCAPIAuthTokenValid(token) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(isValid).To(BeTrue()) +} diff --git a/pkg/snap/util/dqlite.go b/pkg/snap/util/dqlite.go index 2d32c69..f8523a8 100644 --- a/pkg/snap/util/dqlite.go +++ b/pkg/snap/util/dqlite.go @@ -95,3 +95,18 @@ func WaitForDqliteCluster(ctx context.Context, s snap.Snap, f func(DqliteCluster } } } + +// RemoveNodeFromDqlite uses the Dqlite binary to remove a node from the Dqlite cluster. +func RemoveNodeFromDqlite(ctx context.Context, snap snap.Snap, removeEp string) error { + binPath := snap.GetSnapPath("bin", "dqlite") + clusterYamlPath := snap.GetSnapDataPath("var", "kubernetes", "backend", "cluster.yaml") + clusterCrtPath := snap.GetSnapDataPath("var", "kubernetes", "backend", "cluster.crt") + clusterKeyPath := snap.GetSnapDataPath("var", "kubernetes", "backend", "cluster.key") + + // NOTE(Hue): The last two arguments (.remove
) should be a single string. Otherwise Dqlite throws an error. + if err := snap.RunCommand(ctx, binPath, "-s", "file://"+clusterYamlPath, "-c", clusterCrtPath, "-k", clusterKeyPath, "-f", "json", "k8s", fmt.Sprintf(".remove %s", removeEp)); err != nil { + return fmt.Errorf("failed to run remove command: %w", err) + } + + return nil +} diff --git a/pkg/snap/util/dqlite_test.go b/pkg/snap/util/dqlite_test.go index 5dfdcdf..e5a3b68 100644 --- a/pkg/snap/util/dqlite_test.go +++ b/pkg/snap/util/dqlite_test.go @@ -2,10 +2,14 @@ package snaputil_test import ( "context" + "errors" + "fmt" "reflect" "testing" "time" + . "github.com/onsi/gomega" + "github.com/canonical/microk8s-cluster-agent/pkg/snap/mock" snaputil "github.com/canonical/microk8s-cluster-agent/pkg/snap/util" ) @@ -84,3 +88,33 @@ Role: 0`, }) } + +func TestRemoveNodeFromDqlite(t *testing.T) { + t.Run("CommandFails", func(t *testing.T) { + cmdErr := errors.New("failed to run command") + s := &mock.Snap{ + RunCommandErr: cmdErr, + } + + err := snaputil.RemoveNodeFromDqlite(context.Background(), s, "1.1.1.1:1234") + + g := NewWithT(t) + g.Expect(err).To(MatchError(cmdErr)) + }) + + t.Run("CommandRunsSuccessfully", func(t *testing.T) { + snapDir := "/snapDir" + snapDataDir := "/snapDataDir" + removeEp := "1.1.1.1:1234" + + s := &mock.Snap{ + SnapDir: snapDir, + SnapDataDir: snapDataDir, + } + + g := NewWithT(t) + g.Expect(snaputil.RemoveNodeFromDqlite(context.Background(), s, removeEp)).To(Succeed()) + g.Expect(s.RunCommandCalledWith).To(HaveLen(1)) + g.Expect(s.RunCommandCalledWith[0].Commands).To(ContainElements(ContainSubstring(snapDir), ContainSubstring(snapDataDir), fmt.Sprintf(".remove %s", removeEp))) + }) +} diff --git a/pkg/util/token.go b/pkg/util/token.go index 6c6ed72..95d7e5d 100644 --- a/pkg/util/token.go +++ b/pkg/util/token.go @@ -40,8 +40,8 @@ func NewRandomString(letters RandomCharacters, length int) string { // A token may optionally have a TTL, which is appended at the end of the token. // For example, the tokens file may look like this: // -// token1 -// token2|35616531876 +// token1 +// token2|35616531876 // // In the file above, token1 is a valid token. token2 is valid until the unix timestamp 35616531876. func IsValidToken(token string, tokensFile string) (isValidToken, hasTTL bool) {