Skip to content

Commit

Permalink
refactor: proxy and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
phillebaba committed Aug 6, 2024
1 parent ee60eb8 commit 0aa2700
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 155 deletions.
12 changes: 10 additions & 2 deletions src/cmd/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ var agentCmd = &cobra.Command{
Short: lang.CmdInternalAgentShort,
Long: lang.CmdInternalAgentLong,
RunE: func(cmd *cobra.Command, _ []string) error {
return agent.StartWebhook(cmd.Context())
cluster, err := cluster.NewCluster()
if err != nil {
return err
}
return agent.StartWebhook(cmd.Context(), cluster)
},
}

Expand All @@ -51,7 +55,11 @@ var httpProxyCmd = &cobra.Command{
Short: lang.CmdInternalProxyShort,
Long: lang.CmdInternalProxyLong,
RunE: func(cmd *cobra.Command, _ []string) error {
return agent.StartHTTPProxy(cmd.Context())
cluster, err := cluster.NewCluster()
if err != nil {
return err
}
return agent.StartHTTPProxy(cmd.Context(), cluster)
},
}

Expand Down
1 change: 0 additions & 1 deletion src/config/lang/english.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,6 @@ const (
AgentErrMarshallJSONPatch = "unable to marshall the json patch"
AgentErrMarshalResponse = "unable to marshal the response"
AgentErrNilReq = "malformed admission review: request is nil"
AgentErrUnableTransform = "unable to transform the provided request; see zarf http proxy logs for more details"
)

// Package create
Expand Down
97 changes: 28 additions & 69 deletions src/internal/agent/http/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package http

import (
"context"
"crypto/tls"
"fmt"
"io"
Expand All @@ -14,51 +13,43 @@ import (
"net/url"
"strings"

"github.com/zarf-dev/zarf/src/config/lang"
"github.com/zarf-dev/zarf/src/pkg/cluster"
"github.com/zarf-dev/zarf/src/pkg/message"
"github.com/zarf-dev/zarf/src/pkg/transform"
"github.com/zarf-dev/zarf/src/types"
)

// ProxyHandler constructs a new httputil.ReverseProxy and returns an http handler.
func ProxyHandler() http.HandlerFunc {
func ProxyHandler(cluster *cluster.Cluster) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
err := proxyRequestTransform(r)
state, err := cluster.LoadZarfState(r.Context())
if err != nil {
message.Debugf("%#v", err)
w.WriteHeader(http.StatusInternalServerError)
//nolint: errcheck // ignore
w.Write([]byte(lang.AgentErrUnableTransform))
w.Write([]byte("unable to load Zarf state, see the Zarf HTTP proxy logs for more details"))
return
}
err = proxyRequestTransform(r, state)
if err != nil {
message.Debugf("%#v", err)
w.WriteHeader(http.StatusInternalServerError)
//nolint: errcheck // ignore
w.Write([]byte("unable to transform the provided request, see the Zarf HTTP proxy logs for more details"))
return
}

proxy := &httputil.ReverseProxy{Director: func(_ *http.Request) {}, ModifyResponse: proxyResponseTransform}
proxy.ServeHTTP(w, r)
}
}

func proxyRequestTransform(r *http.Request) error {
message.Debugf("Before Req %#v", r)
message.Debugf("Before Req URL %#v", r.URL)

func proxyRequestTransform(r *http.Request, state *types.ZarfState) error {
// We add this so that we can use it to rewrite urls in the response if needed
r.Header.Add("X-Forwarded-Host", r.Host)

// We remove this so that go will encode and decode on our behalf (see https://pkg.go.dev/net/http#Transport DisableCompression)
r.Header.Del("Accept-Encoding")

c, err := cluster.NewCluster()
if err != nil {
return err
}
ctx := context.Background()
state, err := c.LoadZarfState(ctx)
if err != nil {
return err
}

var targetURL *url.URL

// Setup authentication for each type of service based on User Agent
switch {
case isGitUserAgent(r.UserAgent()):
Expand All @@ -70,6 +61,8 @@ func proxyRequestTransform(r *http.Request) error {
}

// Transform the URL; if we see the NoTransform prefix, strip it; otherwise, transform the URL based on User Agent
var err error
var targetURL *url.URL
if strings.HasPrefix(r.URL.Path, transform.NoTransform) {
switch {
case isGitUserAgent(r.UserAgent()):
Expand All @@ -89,7 +82,6 @@ func proxyRequestTransform(r *http.Request) error {
targetURL, err = transform.GenTransformURL(state.ArtifactServer.Address, getTLSScheme(r.TLS)+r.Host+r.URL.String())
}
}

if err != nil {
return err
}
Expand All @@ -98,19 +90,12 @@ func proxyRequestTransform(r *http.Request) error {
r.URL = targetURL
r.RequestURI = getRequestURI(targetURL.Path, targetURL.RawQuery, targetURL.Fragment)

message.Debugf("After Req %#v", r)
message.Debugf("After Req URL%#v", r.URL)

return nil
}

func proxyResponseTransform(resp *http.Response) error {
message.Debugf("Before Resp %#v", resp)

// Handle redirection codes (3xx) by adding a marker to let Zarf know this has been redirected
if resp.StatusCode/100 == 3 {
message.Debugf("Before Resp Location %#v", resp.Header.Get("Location"))

locationURL, err := url.Parse(resp.Header.Get("Location"))
if err != nil {
return err
Expand All @@ -119,72 +104,46 @@ func proxyResponseTransform(resp *http.Response) error {
locationURL.Host = resp.Request.Header.Get("X-Forwarded-Host")

resp.Header.Set("Location", locationURL.String())

message.Debugf("After Resp Location %#v", resp.Header.Get("Location"))
}

contentType := resp.Header.Get("Content-Type")

// Handle text content returns that may contain links
contentType := resp.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "text") || strings.HasPrefix(contentType, "application/json") || strings.HasPrefix(contentType, "application/xml") {
err := replaceBodyLinks(resp)

forwardedPrefix := fmt.Sprintf("%s%s%s", getTLSScheme(resp.Request.TLS), resp.Request.Header.Get("X-Forwarded-Host"), transform.NoTransform)
targetPrefix := fmt.Sprintf("%s%s", getTLSScheme(resp.TLS), resp.Request.Host)
b, err := io.ReadAll(resp.Body)
if err != nil {
message.Debugf("%#v", err)
return err
}
}

message.Debugf("After Resp %#v", resp)

return nil
}

func replaceBodyLinks(resp *http.Response) error {
message.Debugf("Resp Request: %#v", resp.Request)

// Create the forwarded (online) and target (offline) URL prefixes to replace
forwardedPrefix := fmt.Sprintf("%s%s%s", getTLSScheme(resp.Request.TLS), resp.Request.Header.Get("X-Forwarded-Host"), transform.NoTransform)
targetPrefix := fmt.Sprintf("%s%s", getTLSScheme(resp.TLS), resp.Request.Host)
err = resp.Body.Close()
if err != nil {
return err
}
bodyString := strings.ReplaceAll(string(b), targetPrefix, forwardedPrefix)

b, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
err = resp.Body.Close()
if err != nil {
return err
resp.Body = io.NopCloser(strings.NewReader(bodyString))
resp.ContentLength = int64(len(bodyString))
resp.Header.Set("Content-Length", fmt.Sprint(int64(len(bodyString))))
}
bodyString := strings.ReplaceAll(string(b), targetPrefix, forwardedPrefix)

// Setup the new reader, and correct the content length
resp.Body = io.NopCloser(strings.NewReader(bodyString))
resp.ContentLength = int64(len(bodyString))
resp.Header.Set("Content-Length", fmt.Sprint(int64(len(bodyString))))

return nil
}

func getTLSScheme(tls *tls.ConnectionState) string {
scheme := "https://"

if tls == nil {
scheme = "http://"
}

return scheme
}

func getRequestURI(path, query, fragment string) string {
uri := path

if query != "" {
uri += "?" + query
}

if fragment != "" {
uri += "#" + fragment
}

return uri
}

Expand Down
Loading

0 comments on commit 0aa2700

Please sign in to comment.