diff --git a/src/cmd/internal.go b/src/cmd/internal.go index ab14a9d5d6..d4196aa23e 100644 --- a/src/cmd/internal.go +++ b/src/cmd/internal.go @@ -42,7 +42,11 @@ var agentCmd = &cobra.Command{ Short: lang.CmdInternalAgentShort, Long: lang.CmdInternalAgentLong, RunE: func(cmd *cobra.Command, _ []string) error { - return agent.StartWebhook(cmd.Context()) + cluster, err := cluster.NewCluster() + if err != nil { + return err + } + return agent.StartWebhook(cmd.Context(), cluster) }, } @@ -51,7 +55,11 @@ var httpProxyCmd = &cobra.Command{ Short: lang.CmdInternalProxyShort, Long: lang.CmdInternalProxyLong, RunE: func(cmd *cobra.Command, _ []string) error { - return agent.StartHTTPProxy(cmd.Context()) + cluster, err := cluster.NewCluster() + if err != nil { + return err + } + return agent.StartHTTPProxy(cmd.Context(), cluster) }, } diff --git a/src/config/lang/english.go b/src/config/lang/english.go index 5cb2db2d61..e637651c61 100644 --- a/src/config/lang/english.go +++ b/src/config/lang/english.go @@ -611,7 +611,6 @@ const ( AgentErrMarshallJSONPatch = "unable to marshall the json patch" AgentErrMarshalResponse = "unable to marshal the response" AgentErrNilReq = "malformed admission review: request is nil" - AgentErrUnableTransform = "unable to transform the provided request; see zarf http proxy logs for more details" ) // Package create diff --git a/src/internal/agent/http/proxy.go b/src/internal/agent/http/proxy.go index 33709dfff7..760ba709ec 100644 --- a/src/internal/agent/http/proxy.go +++ b/src/internal/agent/http/proxy.go @@ -5,7 +5,6 @@ package http import ( - "context" "crypto/tls" "fmt" "io" @@ -14,51 +13,43 @@ import ( "net/url" "strings" - "github.com/zarf-dev/zarf/src/config/lang" "github.com/zarf-dev/zarf/src/pkg/cluster" "github.com/zarf-dev/zarf/src/pkg/message" "github.com/zarf-dev/zarf/src/pkg/transform" + "github.com/zarf-dev/zarf/src/types" ) // ProxyHandler constructs a new httputil.ReverseProxy and returns an http handler. -func ProxyHandler() http.HandlerFunc { +func ProxyHandler(cluster *cluster.Cluster) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - err := proxyRequestTransform(r) + state, err := cluster.LoadZarfState(r.Context()) if err != nil { message.Debugf("%#v", err) w.WriteHeader(http.StatusInternalServerError) //nolint: errcheck // ignore - w.Write([]byte(lang.AgentErrUnableTransform)) + w.Write([]byte("unable to load Zarf state, see the Zarf HTTP proxy logs for more details")) + return + } + err = proxyRequestTransform(r, state) + if err != nil { + message.Debugf("%#v", err) + w.WriteHeader(http.StatusInternalServerError) + //nolint: errcheck // ignore + w.Write([]byte("unable to transform the provided request, see the Zarf HTTP proxy logs for more details")) return } - proxy := &httputil.ReverseProxy{Director: func(_ *http.Request) {}, ModifyResponse: proxyResponseTransform} proxy.ServeHTTP(w, r) } } -func proxyRequestTransform(r *http.Request) error { - message.Debugf("Before Req %#v", r) - message.Debugf("Before Req URL %#v", r.URL) - +func proxyRequestTransform(r *http.Request, state *types.ZarfState) error { // We add this so that we can use it to rewrite urls in the response if needed r.Header.Add("X-Forwarded-Host", r.Host) // We remove this so that go will encode and decode on our behalf (see https://pkg.go.dev/net/http#Transport DisableCompression) r.Header.Del("Accept-Encoding") - c, err := cluster.NewCluster() - if err != nil { - return err - } - ctx := context.Background() - state, err := c.LoadZarfState(ctx) - if err != nil { - return err - } - - var targetURL *url.URL - // Setup authentication for each type of service based on User Agent switch { case isGitUserAgent(r.UserAgent()): @@ -70,6 +61,8 @@ func proxyRequestTransform(r *http.Request) error { } // Transform the URL; if we see the NoTransform prefix, strip it; otherwise, transform the URL based on User Agent + var err error + var targetURL *url.URL if strings.HasPrefix(r.URL.Path, transform.NoTransform) { switch { case isGitUserAgent(r.UserAgent()): @@ -89,7 +82,6 @@ func proxyRequestTransform(r *http.Request) error { targetURL, err = transform.GenTransformURL(state.ArtifactServer.Address, getTLSScheme(r.TLS)+r.Host+r.URL.String()) } } - if err != nil { return err } @@ -98,19 +90,12 @@ func proxyRequestTransform(r *http.Request) error { r.URL = targetURL r.RequestURI = getRequestURI(targetURL.Path, targetURL.RawQuery, targetURL.Fragment) - message.Debugf("After Req %#v", r) - message.Debugf("After Req URL%#v", r.URL) - return nil } func proxyResponseTransform(resp *http.Response) error { - message.Debugf("Before Resp %#v", resp) - // Handle redirection codes (3xx) by adding a marker to let Zarf know this has been redirected if resp.StatusCode/100 == 3 { - message.Debugf("Before Resp Location %#v", resp.Header.Get("Location")) - locationURL, err := url.Parse(resp.Header.Get("Location")) if err != nil { return err @@ -119,72 +104,46 @@ func proxyResponseTransform(resp *http.Response) error { locationURL.Host = resp.Request.Header.Get("X-Forwarded-Host") resp.Header.Set("Location", locationURL.String()) - - message.Debugf("After Resp Location %#v", resp.Header.Get("Location")) } - contentType := resp.Header.Get("Content-Type") - // Handle text content returns that may contain links + contentType := resp.Header.Get("Content-Type") if strings.HasPrefix(contentType, "text") || strings.HasPrefix(contentType, "application/json") || strings.HasPrefix(contentType, "application/xml") { - err := replaceBodyLinks(resp) - + forwardedPrefix := fmt.Sprintf("%s%s%s", getTLSScheme(resp.Request.TLS), resp.Request.Header.Get("X-Forwarded-Host"), transform.NoTransform) + targetPrefix := fmt.Sprintf("%s%s", getTLSScheme(resp.TLS), resp.Request.Host) + b, err := io.ReadAll(resp.Body) if err != nil { - message.Debugf("%#v", err) + return err } - } - - message.Debugf("After Resp %#v", resp) - - return nil -} - -func replaceBodyLinks(resp *http.Response) error { - message.Debugf("Resp Request: %#v", resp.Request) - - // Create the forwarded (online) and target (offline) URL prefixes to replace - forwardedPrefix := fmt.Sprintf("%s%s%s", getTLSScheme(resp.Request.TLS), resp.Request.Header.Get("X-Forwarded-Host"), transform.NoTransform) - targetPrefix := fmt.Sprintf("%s%s", getTLSScheme(resp.TLS), resp.Request.Host) + err = resp.Body.Close() + if err != nil { + return err + } + bodyString := strings.ReplaceAll(string(b), targetPrefix, forwardedPrefix) - b, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - err = resp.Body.Close() - if err != nil { - return err + resp.Body = io.NopCloser(strings.NewReader(bodyString)) + resp.ContentLength = int64(len(bodyString)) + resp.Header.Set("Content-Length", fmt.Sprint(int64(len(bodyString)))) } - bodyString := strings.ReplaceAll(string(b), targetPrefix, forwardedPrefix) - - // Setup the new reader, and correct the content length - resp.Body = io.NopCloser(strings.NewReader(bodyString)) - resp.ContentLength = int64(len(bodyString)) - resp.Header.Set("Content-Length", fmt.Sprint(int64(len(bodyString)))) - return nil } func getTLSScheme(tls *tls.ConnectionState) string { scheme := "https://" - if tls == nil { scheme = "http://" } - return scheme } func getRequestURI(path, query, fragment string) string { uri := path - if query != "" { uri += "?" + query } - if fragment != "" { uri += "#" + fragment } - return uri } diff --git a/src/internal/agent/http/proxy_test.go b/src/internal/agent/http/proxy_test.go new file mode 100644 index 0000000000..16448d74e0 --- /dev/null +++ b/src/internal/agent/http/proxy_test.go @@ -0,0 +1,189 @@ +package http + +import ( + "crypto/tls" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/zarf-dev/zarf/src/types" +) + +func TestProxyRequestTransform(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + target string + state *types.ZarfState + expectedPath string + }{ + { + name: "basic request", + target: "http://example.com/zarf-3xx-no-transform/test", + state: &types.ZarfState{ + ArtifactServer: types.ArtifactServerInfo{ + PushUsername: "push-user", + PushToken: "push-token", + }, + }, + expectedPath: "/test", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, tt.target, nil) + req.Header.Set("Accept-Encoding", "foo") + err := proxyRequestTransform(req, tt.state) + require.NoError(t, err) + + require.Empty(t, req.Header.Get("Accept-Encoding")) + + username, password, ok := req.BasicAuth() + require.True(t, ok) + require.Equal(t, tt.state.ArtifactServer.PushUsername, username) + require.Equal(t, tt.state.ArtifactServer.PushToken, password) + + require.Equal(t, tt.expectedPath, req.URL.Path) + }) + } +} + +func TestGetTLSScheme(t *testing.T) { + t.Parallel() + + scheme := getTLSScheme(nil) + require.Equal(t, "http://", scheme) + scheme = getTLSScheme(&tls.ConnectionState{}) + require.Equal(t, "https://", scheme) +} + +func TestGetRequestURI(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + query string + fragment string + expected string + }{ + { + name: "basic", + path: "/foo", + query: "", + fragment: "", + expected: "/foo", + }, + { + name: "query", + path: "/foo", + query: "key=value", + fragment: "", + expected: "/foo?key=value", + }, + { + name: "fragment", + path: "/foo", + query: "", + fragment: "bar", + expected: "/foo#bar", + }, + { + name: "query and fragment", + path: "/foo", + query: "key=value", + fragment: "bar", + expected: "/foo?key=value#bar", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + uri := getRequestURI(tt.path, tt.query, tt.fragment) + require.Equal(t, tt.expected, uri) + }) + } +} + +func TestUserAgent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + userAgent string + expectedGit bool + expectedPip bool + expectedNpm bool + }{ + { + name: "unknown user agent", + userAgent: "Firefox", + expectedGit: false, + expectedPip: false, + expectedNpm: false, + }, + { + name: "git user agent", + userAgent: "git/2.0.0", + expectedGit: true, + expectedPip: false, + expectedNpm: false, + }, + { + name: "pip user agent", + userAgent: "pip/1.2.3", + expectedGit: false, + expectedPip: true, + expectedNpm: false, + }, + { + name: "twine user agent", + userAgent: "twine/1.8.1", + expectedGit: false, + expectedPip: true, + expectedNpm: false, + }, + { + name: "npm user agent", + userAgent: "npm/1.0.0", + expectedGit: false, + expectedPip: false, + expectedNpm: true, + }, + { + name: "pnpm user agent", + userAgent: "pnpm/1.0.0", + expectedGit: false, + expectedPip: false, + expectedNpm: true, + }, + { + name: "yarn user agent", + userAgent: "yar/1.0.0", + expectedGit: false, + expectedPip: false, + expectedNpm: true, + }, + { + name: "bun user agent", + userAgent: "bun/1.0.0", + expectedGit: false, + expectedPip: false, + expectedNpm: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + require.Equal(t, tt.expectedGit, isGitUserAgent(tt.userAgent)) + require.Equal(t, tt.expectedPip, isPipUserAgent(tt.userAgent)) + require.Equal(t, tt.expectedNpm, isNpmUserAgent(tt.userAgent)) + }) + } +} diff --git a/src/internal/agent/http/server.go b/src/internal/agent/http/server.go deleted file mode 100644 index 6a79aaa449..0000000000 --- a/src/internal/agent/http/server.go +++ /dev/null @@ -1,74 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2021-Present The Zarf Authors - -// Package http provides a http server for the webhook and proxy. -package http - -import ( - "context" - "fmt" - "net/http" - "time" - - "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/zarf-dev/zarf/src/internal/agent/hooks" - "github.com/zarf-dev/zarf/src/internal/agent/http/admission" - "github.com/zarf-dev/zarf/src/pkg/cluster" -) - -// NewAdmissionServer creates a http.Server for the mutating webhook admission handler. -func NewAdmissionServer(ctx context.Context, port string) (*http.Server, error) { - c, err := cluster.NewCluster() - if err != nil { - return nil, err - } - - // Routers - admissionHandler := admission.NewHandler() - podsMutation := hooks.NewPodMutationHook(ctx, c) - fluxGitRepositoryMutation := hooks.NewGitRepositoryMutationHook(ctx, c) - argocdApplicationMutation := hooks.NewApplicationMutationHook(ctx, c) - argocdRepositoryMutation := hooks.NewRepositorySecretMutationHook(ctx, c) - fluxHelmRepositoryMutation := hooks.NewHelmRepositoryMutationHook(ctx, c) - fluxOCIRepositoryMutation := hooks.NewOCIRepositoryMutationHook(ctx, c) - - // Routers - mux := http.NewServeMux() - mux.Handle("/healthz", healthz()) - mux.Handle("/mutate/pod", admissionHandler.Serve(podsMutation)) - mux.Handle("/mutate/flux-gitrepository", admissionHandler.Serve(fluxGitRepositoryMutation)) - mux.Handle("/mutate/flux-helmrepository", admissionHandler.Serve(fluxHelmRepositoryMutation)) - mux.Handle("/mutate/flux-ocirepository", admissionHandler.Serve(fluxOCIRepositoryMutation)) - mux.Handle("/mutate/argocd-application", admissionHandler.Serve(argocdApplicationMutation)) - mux.Handle("/mutate/argocd-repository", admissionHandler.Serve(argocdRepositoryMutation)) - mux.Handle("/metrics", promhttp.Handler()) - - srv := &http.Server{ - Addr: fmt.Sprintf(":%s", port), - Handler: mux, - ReadHeaderTimeout: 5 * time.Second, // Set ReadHeaderTimeout to avoid Slowloris attacks - } - return srv, nil -} - -// NewProxyServer creates and returns an http proxy server. -func NewProxyServer(port string) *http.Server { - mux := http.NewServeMux() - mux.Handle("/healthz", healthz()) - mux.Handle("/", ProxyHandler()) - mux.Handle("/metrics", promhttp.Handler()) - - return &http.Server{ - Addr: fmt.Sprintf(":%s", port), - Handler: mux, - ReadHeaderTimeout: 5 * time.Second, // Set ReadHeaderTimeout to avoid Slowloris attacks - } -} - -func healthz() http.HandlerFunc { - return func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) - //nolint: errcheck // ignore - w.Write([]byte("ok")) - } -} diff --git a/src/internal/agent/start.go b/src/internal/agent/start.go index 80a04c642f..e001022140 100644 --- a/src/internal/agent/start.go +++ b/src/internal/agent/start.go @@ -7,13 +7,18 @@ package agent import ( "context" "errors" + "fmt" "net/http" "time" + "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/sync/errgroup" "github.com/zarf-dev/zarf/src/config/lang" + "github.com/zarf-dev/zarf/src/internal/agent/hooks" agentHttp "github.com/zarf-dev/zarf/src/internal/agent/http" + "github.com/zarf-dev/zarf/src/internal/agent/http/admission" + "github.com/zarf-dev/zarf/src/pkg/cluster" "github.com/zarf-dev/zarf/src/pkg/message" ) @@ -28,20 +33,48 @@ const ( ) // StartWebhook launches the Zarf agent mutating webhook in the cluster. -func StartWebhook(ctx context.Context) error { - srv, err := agentHttp.NewAdmissionServer(ctx, httpPort) - if err != nil { - return err - } - return startServer(ctx, srv) +func StartWebhook(ctx context.Context, cluster *cluster.Cluster) error { + // Routers + admissionHandler := admission.NewHandler() + podsMutation := hooks.NewPodMutationHook(ctx, cluster) + fluxGitRepositoryMutation := hooks.NewGitRepositoryMutationHook(ctx, cluster) + argocdApplicationMutation := hooks.NewApplicationMutationHook(ctx, cluster) + argocdRepositoryMutation := hooks.NewRepositorySecretMutationHook(ctx, cluster) + fluxHelmRepositoryMutation := hooks.NewHelmRepositoryMutationHook(ctx, cluster) + fluxOCIRepositoryMutation := hooks.NewOCIRepositoryMutationHook(ctx, cluster) + + // Routers + mux := http.NewServeMux() + mux.Handle("/mutate/pod", admissionHandler.Serve(podsMutation)) + mux.Handle("/mutate/flux-gitrepository", admissionHandler.Serve(fluxGitRepositoryMutation)) + mux.Handle("/mutate/flux-helmrepository", admissionHandler.Serve(fluxHelmRepositoryMutation)) + mux.Handle("/mutate/flux-ocirepository", admissionHandler.Serve(fluxOCIRepositoryMutation)) + mux.Handle("/mutate/argocd-application", admissionHandler.Serve(argocdApplicationMutation)) + mux.Handle("/mutate/argocd-repository", admissionHandler.Serve(argocdRepositoryMutation)) + + return startServer(ctx, httpPort, mux) } // StartHTTPProxy launches the zarf agent proxy in the cluster. -func StartHTTPProxy(ctx context.Context) error { - return startServer(ctx, agentHttp.NewProxyServer(httpPort)) +func StartHTTPProxy(ctx context.Context, cluster *cluster.Cluster) error { + mux := http.NewServeMux() + mux.Handle("/", agentHttp.ProxyHandler(cluster)) + return startServer(ctx, httpPort, mux) } -func startServer(ctx context.Context, srv *http.Server) error { +func startServer(ctx context.Context, port string, mux *http.ServeMux) error { + mux.Handle("/metrics", promhttp.Handler()) + mux.Handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + //nolint: errcheck // ignore + w.Write([]byte("ok")) + })) + srv := &http.Server{ + Addr: fmt.Sprintf(":%s", port), + Handler: mux, + ReadHeaderTimeout: 5 * time.Second, // Set ReadHeaderTimeout to avoid Slowloris attacks + } + g, gCtx := errgroup.WithContext(ctx) g.Go(func() error { err := srv.ListenAndServeTLS(tlsCert, tlsKey)