diff --git a/pkg/api/rule.go b/pkg/api/rule.go index 652beef..ce6d410 100644 --- a/pkg/api/rule.go +++ b/pkg/api/rule.go @@ -26,9 +26,15 @@ type RuleConditions struct { Methods []string `json:"methods"` } type RuleActions struct { - Proxy RuleActionsProxy `json:"proxy"` + Proxy RuleActionsProxy `json:"proxy"` + DirectResponse RuleActionsDirectResponse `json:"directResponse" yaml:"directResponse"` } type RuleActionsProxy struct { Hostname string `json:"hostname"` Port int64 `json:"port"` } + +type RuleActionsDirectResponse struct { + Status uint32 `json:"status"` + Body string `json:"body"` +} diff --git a/pkg/envoy/listener.go b/pkg/envoy/listener.go index a32fb7f..8777926 100644 --- a/pkg/envoy/listener.go +++ b/pkg/envoy/listener.go @@ -2,6 +2,7 @@ package envoy import ( "fmt" + "reflect" "sort" "strings" @@ -150,9 +151,10 @@ func (l *Listener) getListenerRouteSpecifier(manager hcm.HttpConnectionManager) return routeSpecifier, nil } -func (l *Listener) getVirtualHost(hostname, targetHostname, targetPrefix, clusterName, virtualHostName string, methods []string, matchType string) *route.VirtualHost { +func (l *Listener) getVirtualHost(hostname, targetHostname, targetPrefix, clusterName, virtualHostName string, methods []string, matchType string, directResponse DirectResponse) *route.VirtualHost { var hostRewriteSpecifier *route.RouteAction_HostRewrite var routes []*route.Route + var routeAction *route.Route_Route if hostname == "" { hostname = "*" @@ -162,15 +164,16 @@ func (l *Listener) getVirtualHost(hostname, targetHostname, targetPrefix, cluste hostRewriteSpecifier = &route.RouteAction_HostRewrite{ HostRewrite: targetHostname, } - } - - routeAction := &route.Route_Route{ - Route: &route.RouteAction{ - HostRewriteSpecifier: hostRewriteSpecifier, - ClusterSpecifier: &route.RouteAction_Cluster{ - Cluster: clusterName, + routeAction = &route.Route_Route{ + Route: &route.RouteAction{ + HostRewriteSpecifier: hostRewriteSpecifier, + ClusterSpecifier: &route.RouteAction_Cluster{ + Cluster: clusterName, + }, }, - }, + } + } else { + routeAction = &route.Route_Route{} } var headers []*route.HeaderMatcher @@ -185,7 +188,8 @@ func (l *Listener) getVirtualHost(hostname, targetHostname, targetPrefix, cluste }) } } - if matchType == "prefix" { + switch matchType { + case "prefix": if len(headers) == 0 { routes = append(routes, &route.Route{ Match: &route.RouteMatch{ @@ -208,7 +212,7 @@ func (l *Listener) getVirtualHost(hostname, targetHostname, targetPrefix, cluste }) } } - } else if matchType == "path" { + case "path": if len(headers) == 0 { routes = append(routes, &route.Route{ Match: &route.RouteMatch{ @@ -231,7 +235,7 @@ func (l *Listener) getVirtualHost(hostname, targetHostname, targetPrefix, cluste }) } } - } else if matchType == "regex" { + case "regex": if len(headers) == 0 { routes = append(routes, &route.Route{ Match: &route.RouteMatch{ @@ -256,6 +260,22 @@ func (l *Listener) getVirtualHost(hostname, targetHostname, targetPrefix, cluste } } + // fill out directresponse action if defined + if directResponse.Status > 0 { + for routeKey := range routes { + routes[routeKey].Action = &route.Route_DirectResponse{ + DirectResponse: &route.DirectResponseAction{ + Status: directResponse.Status, + Body: &core.DataSource{ + Specifier: &core.DataSource_InlineString{ + InlineString: directResponse.Body, + }, + }, + }, + } + } + } + return &route.VirtualHost{ Name: virtualHostName, Domains: []string{hostname}, @@ -341,7 +361,7 @@ func (l *Listener) updateListener(cache *WorkQueueCache, params ListenerParams, } // create new virtualhost - v := l.getVirtualHost(params.Conditions.Hostname, params.TargetHostname, targetPrefix, params.Name, virtualHostname, params.Conditions.Methods, matchType) + v := l.getVirtualHost(params.Conditions.Hostname, params.TargetHostname, targetPrefix, params.Name, virtualHostname, params.Conditions.Methods, matchType, params.DirectResponse) // check if we need to overwrite the virtualhost virtualHostKey := -1 @@ -541,7 +561,7 @@ func (l *Listener) DeleteRoute(cache *WorkQueueCache, params ListenerParams, par return err } - v := l.getVirtualHost(params.Conditions.Hostname, params.TargetHostname, targetPrefix, params.Name, virtualHostname, params.Conditions.Methods, matchType) + v := l.getVirtualHost(params.Conditions.Hostname, params.TargetHostname, targetPrefix, params.Name, virtualHostname, params.Conditions.Methods, matchType, params.DirectResponse) virtualHostKey := -1 for k, curVirtualHost := range routeSpecifier.RouteConfig.VirtualHosts { @@ -604,17 +624,28 @@ func (l *Listener) validateListeners(listeners []cache.Resource, clusterNames [] } for _, virtualHost := range routeSpecifier.RouteConfig.VirtualHosts { for _, virtualHostRoute := range virtualHost.Routes { - clusterFound := false - virtualHostRouteClusterName := virtualHostRoute.Action.(*route.Route_Route).Route.ClusterSpecifier.(*route.RouteAction_Cluster).Cluster - for _, clusterName := range clusterNames { - if clusterName == virtualHostRouteClusterName { - clusterFound = true + if virtualHostRoute.Action != nil { + switch reflect.TypeOf(virtualHostRoute.Action).String() { + case "*route.Route_Route": + clusterFound := false + virtualHostRouteClusterName := virtualHostRoute.Action.(*route.Route_Route).Route.ClusterSpecifier.(*route.RouteAction_Cluster).Cluster + for _, clusterName := range clusterNames { + if clusterName == virtualHostRouteClusterName { + clusterFound = true + } + } + if !clusterFound { + return false, fmt.Errorf("Cluster not found: %s", virtualHostRouteClusterName) + } + case "*route.Route_DirectResponse": + logger.Debugf("Validation: DirectResponse, no cluster validation necessary") + // no validation necessary + default: + return false, fmt.Errorf("Route action type is unknown: %s", reflect.TypeOf(virtualHostRoute.Action).String()) } + } else { + return false, fmt.Errorf("Validation: no route action found for virtualhost: %+v", virtualHost) } - if !clusterFound { - return false, fmt.Errorf("Cluster not found: %s", virtualHostRouteClusterName) - } - } } } @@ -636,3 +667,50 @@ func (l *Listener) updateDefaultTracingSetting(tracing TracingParams) { func (l *Listener) newHTTPRouterFilter() []*hcm.HttpFilter { return l.httpFilter } + +func (l *Listener) printListener(cache *WorkQueueCache) (string, error) { + var res string + for _, listener := range cache.listeners { + ll := listener.(*api.Listener) + res += "Listener: " + ll.Name + "\n" + manager, err := getListenerHTTPConnectionManager(ll) + if err != nil { + return "", err + } + routeSpecifier, err := l.getListenerRouteSpecifier(manager) + if err != nil { + return "", err + } + for _, virtualHost := range routeSpecifier.RouteConfig.VirtualHosts { + res += "Virtualhost: " + virtualHost.GetName() + "\n" + for _, virtualHostRoute := range virtualHost.Routes { + if virtualHostRoute.Match != nil { + if virtualHostRoute.Match.GetPath() != "" { + res += "Match path: " + virtualHostRoute.Match.GetPath() + "\n" + } + if virtualHostRoute.Match.GetPrefix() != "" { + res += "Match prefix: " + virtualHostRoute.Match.GetPrefix() + "\n" + } + if virtualHostRoute.Match.GetRegex() != "" { + res += "Match regex: " + virtualHostRoute.Match.GetRegex() + "\n" + } + } + if virtualHostRoute.Action != nil { + switch reflect.TypeOf(virtualHostRoute.Action).String() { + case "*route.Route_Route": + res += "Route action (cluster): " + virtualHostRoute.Action.(*route.Route_Route).Route.ClusterSpecifier.(*route.RouteAction_Cluster).Cluster + "\n" + case "*route.Route_DirectResponse": + res += "Route action (directResponse): " + res += fmt.Sprint(virtualHostRoute.Action.(*route.Route_DirectResponse).DirectResponse.GetStatus()) + " " + res += virtualHostRoute.Action.(*route.Route_DirectResponse).DirectResponse.Body.GetInlineString() + "\n" + default: + return "", fmt.Errorf("Route action type is unknown: %s", reflect.TypeOf(virtualHostRoute.Action).String()) + } + } else { + return "", fmt.Errorf("Validation: no route action found for virtualhost: %+v", virtualHost) + } + } + } + } + return res, nil +} diff --git a/pkg/envoy/listener_test.go b/pkg/envoy/listener_test.go index 5a822ba..2f7c47d 100644 --- a/pkg/envoy/listener_test.go +++ b/pkg/envoy/listener_test.go @@ -299,6 +299,17 @@ func TestUpdateListener(t *testing.T) { FailureModeAllow: false, }, } + params10 := ListenerParams{ + Name: "directResponseTest", + Conditions: Conditions{ + Path: "/directresponse", + Methods: []string{"GET"}, + }, + DirectResponse: DirectResponse{ + Status: 200, + Body: "OK", + }, + } listener := l.createListener(params1, paramsTLS1) cache.listeners = append(cache.listeners, listener) @@ -498,7 +509,17 @@ func TestUpdateListener(t *testing.T) { if err := validateNewHTTPRouterFilter(l.newHTTPRouterFilter(), params9); err != nil { t.Errorf("Validation failed: %s", err) return + } + // update listener with domain 10 + if err := l.updateListener(&cache, params10, paramsTLS1); err != nil { + t.Errorf("Error: %s", err) + return + } + // validate domain 10 + if err := validateDomain(cache.listeners, params10); err != nil { + t.Errorf("Validation failed: %s", err) + return } } @@ -661,6 +682,7 @@ func validateAttributes(manager hcm.HttpConnectionManager, params ListenerParams prefixFound := false pathFound := false regexFound := false + directResponseFound := false methodsFound := make(map[string]bool) if params.Conditions.Hostname == "" { @@ -701,6 +723,17 @@ func validateAttributes(manager hcm.HttpConnectionManager, params ListenerParams } } } + switch reflect.TypeOf(r.Action).String() { + case "*route.Route_Route": + // do nothing here + case "*route.Route_DirectResponse": + d := r.Action.(*route.Route_DirectResponse).DirectResponse + if params.DirectResponse.Status == d.GetStatus() && params.DirectResponse.Body == d.GetBody().GetInlineString() { + directResponseFound = true + } + default: + return fmt.Errorf("Type is %s", reflect.TypeOf(r.Action).String()) + } } } } @@ -737,6 +770,12 @@ func validateAttributes(manager hcm.HttpConnectionManager, params ListenerParams logger.Debugf("Methods found: %s", strings.Join(params.Conditions.Methods, ",")) } + if params.DirectResponse.Status > 0 && !directResponseFound { + return fmt.Errorf("Got directresponse parameter, but no directresponse found") + } else { + logger.Debugf("Directresponse found found: %d : %s", params.DirectResponse.Status, params.DirectResponse.Body) + } + return validateJWT(manager, params) } diff --git a/pkg/envoy/testdata/test-directresponse.yaml b/pkg/envoy/testdata/test-directresponse.yaml new file mode 100644 index 0000000..3752011 --- /dev/null +++ b/pkg/envoy/testdata/test-directresponse.yaml @@ -0,0 +1,11 @@ +api: proxy.in4it.io/v1 +kind: rule +metadata: + name: healthcheck +spec: + conditions: + - path: /.roxprox/health + actions: + - directResponse: + status: 200 + body: "OK" diff --git a/pkg/envoy/types.go b/pkg/envoy/types.go index 92d5277..84a2b89 100644 --- a/pkg/envoy/types.go +++ b/pkg/envoy/types.go @@ -47,6 +47,7 @@ type ListenerParams struct { Conditions Conditions Auth Auth Authz Authz + DirectResponse DirectResponse } type ChallengeParams struct { @@ -77,9 +78,10 @@ type Auth struct { } type Action struct { - RuleName string - Type string - Proxy ActionProxy + RuleName string + Type string + Proxy ActionProxy + DirectResponse DirectResponseAction } type ActionProxy struct { @@ -101,3 +103,12 @@ type TracingParams struct { RandomSampling float64 OverallSampling float64 } + +type DirectResponse struct { + Status uint32 + Body string +} +type DirectResponseAction struct { + Status uint32 + Body string +} diff --git a/pkg/envoy/workqueue.go b/pkg/envoy/workqueue.go index a403286..6cd615d 100644 --- a/pkg/envoy/workqueue.go +++ b/pkg/envoy/workqueue.go @@ -79,6 +79,7 @@ func (w *WorkQueue) Submit(items []WorkQueueItem) (string, error) { for k, item := range items { itemID := uuid.New().String() items[k].id = itemID + logger.Tracef("WorkQueue: processing item: %s", item.Action) switch item.Action { case "createCluster": if element, err := w.cluster.findCluster(w.cache.clusters, item.ClusterParams); err == nil { @@ -115,6 +116,18 @@ func (w *WorkQueue) Submit(items []WorkQueueItem) (string, error) { } updateXds = true } + case "createRuleWithoutCluster": + if len(w.cache.listeners) == 0 { + w.cache.listeners = append(w.cache.listeners, w.listener.createListener(item.ListenerParams, item.TLSParams)) + } + err := w.listener.updateListener(&w.cache, item.ListenerParams, item.TLSParams) + if err != nil { + logger.Errorf("createRule error: %s", err) + item.state = "error" + } else { + item.state = "finished" + } + updateXds = true case "createJwtRule": err := w.jwtProvider.UpdateJwtRule(&w.cache, item.ListenerParams, item.TLSParams) if err != nil { diff --git a/pkg/envoy/xds.go b/pkg/envoy/xds.go index 6ab83d4..d03c81b 100644 --- a/pkg/envoy/xds.go +++ b/pkg/envoy/xds.go @@ -375,21 +375,46 @@ func (x *XDS) getAction(ruleName string, actions []pkgApi.RuleActions) Action { action.RuleName = ruleName action.Proxy.TargetHostname = ruleAction.Proxy.Hostname action.Proxy.Port = ruleAction.Proxy.Port + } else if ruleAction.DirectResponse.Status > 0 { + action.Type = "directResponse" + action.RuleName = ruleName + action.DirectResponse.Status = ruleAction.DirectResponse.Status + action.DirectResponse.Body = ruleAction.DirectResponse.Body } } return action } func (x *XDS) getListenerParams(action Action, condition pkgApi.RuleConditions) ListenerParams { - return ListenerParams{ - Name: action.RuleName, - TargetHostname: action.Proxy.TargetHostname, - Conditions: Conditions{ - Hostname: condition.Hostname, - Prefix: condition.Prefix, - Path: condition.Path, - Regex: condition.Regex, - Methods: condition.Methods, - }, + switch action.Type { + case "proxy": + return ListenerParams{ + Name: action.RuleName, + TargetHostname: action.Proxy.TargetHostname, + Conditions: Conditions{ + Hostname: condition.Hostname, + Prefix: condition.Prefix, + Path: condition.Path, + Regex: condition.Regex, + Methods: condition.Methods, + }, + } + case "directResponse": + return ListenerParams{ + Name: action.RuleName, + DirectResponse: DirectResponse{ + Status: action.DirectResponse.Status, + Body: action.DirectResponse.Body, + }, + Conditions: Conditions{ + Hostname: condition.Hostname, + Prefix: condition.Prefix, + Path: condition.Path, + Regex: condition.Regex, + Methods: condition.Methods, + }, + } + default: + return ListenerParams{} } } func (x *XDS) getClusterParams(action Action) ClusterParams { @@ -411,77 +436,91 @@ func (x *XDS) getAuthParams(jwtProviderName string, jwtProvider pkgApi.JwtProvid func (x *XDS) ImportRule(rule pkgApi.Rule) ([]WorkQueueItem, error) { var workQueueItems []WorkQueueItem action := x.getAction(rule.Metadata.Name, rule.Spec.Actions) - if action.Type == "proxy" { - // create cluster + createRuleType := "" + // create cluster + switch action.Type { + case "proxy": workQueueItem := WorkQueueItem{ Action: "createCluster", ClusterParams: x.getClusterParams(action), } workQueueItems = append(workQueueItems, workQueueItem) - // create listener that proxies to targetHostname - for _, condition := range rule.Spec.Conditions { - // validation - if rule.Spec.Certificate != "" && condition.Hostname == "" { - return []WorkQueueItem{}, fmt.Errorf("Validation error: rule with certificate, but without a hostname condition - ignoring rule") - - } - if condition.Hostname != "" || condition.Prefix != "" || condition.Path != "" || condition.Regex != "" { - listenerParams := x.getListenerParams(action, condition) - if rule.Spec.Auth.JwtProvider != "" { - object, err := x.getObject("jwtProvider", rule.Spec.Auth.JwtProvider) - if err != nil { - logger.Infof("Could not set Auth parameters: jwtprovider not found (error: %s)", err) - return workQueueItems, err - } else { - listenerParams.Auth = x.getAuthParams(rule.Spec.Auth.JwtProvider, object.Data.(pkgApi.JwtProvider)) - } - workQueueItems = append(workQueueItems, []WorkQueueItem{ - { - Action: "createRule", - ListenerParams: listenerParams, - TLSParams: TLSParams{}, - }, - { - Action: "updateListenerWithJwtProvider", - ListenerParams: listenerParams, - }, - { - Action: "createJwtRule", - ListenerParams: listenerParams, - TLSParams: TLSParams{}, - }, - }...) + createRuleType = "createRule" + case "directResponse": + createRuleType = "createRuleWithoutCluster" + default: + logger.Debugf("Rule without action: %+v", rule) + } + // create listener that proxies to targetHostname + for _, condition := range rule.Spec.Conditions { + // validation + if rule.Spec.Certificate != "" && condition.Hostname == "" { + return []WorkQueueItem{}, fmt.Errorf("Validation error: rule with certificate, but without a hostname condition - ignoring rule") + } + if condition.Hostname != "" || condition.Prefix != "" || condition.Path != "" || condition.Regex != "" { + listenerParams := x.getListenerParams(action, condition) + if rule.Spec.Auth.JwtProvider != "" { + object, err := x.getObject("jwtProvider", rule.Spec.Auth.JwtProvider) + if err != nil { + logger.Infof("Could not set Auth parameters: jwtprovider not found (error: %s)", err) + return workQueueItems, err } else { - workQueueItems = append(workQueueItems, WorkQueueItem{ - Action: "createRule", + listenerParams.Auth = x.getAuthParams(rule.Spec.Auth.JwtProvider, object.Data.(pkgApi.JwtProvider)) + } + workQueueItems = append(workQueueItems, []WorkQueueItem{ + { + Action: createRuleType, // createRule or createRuleWithoutCluster ListenerParams: listenerParams, TLSParams: TLSParams{}, - }) - } + }, + { + Action: "updateListenerWithJwtProvider", + ListenerParams: listenerParams, + }, + { + Action: "createJwtRule", + ListenerParams: listenerParams, + TLSParams: TLSParams{}, + }, + }...) + } else { + workQueueItems = append(workQueueItems, WorkQueueItem{ + Action: createRuleType, // createRule or createRuleWithoutCluster + ListenerParams: listenerParams, + TLSParams: TLSParams{}, + }) + } - if rule.Spec.Certificate == "letsencrypt" { - // TLS listener - certBundle, err := x.s.GetCertBundle(rule.Metadata.Name) - if err != nil && err != x.s.GetError("errNotExist") { + if rule.Spec.Certificate == "letsencrypt" { + // TLS listener + certBundle, err := x.s.GetCertBundle(rule.Metadata.Name) + if err != nil && err != x.s.GetError("errNotExist") { + return workQueueItems, err + } + if err != nil && err == x.s.GetError("errNotExist") { + // TODO: add to list for creation + logger.Debugf("Certificate not found, needs to be created") + } + if err == nil { + logger.Debugf("Certificate found, adding TLS") + privateKeyPem, err := x.s.GetPrivateKeyPem(rule.Metadata.Name) + if err != nil { return workQueueItems, err } - if err != nil && err == x.s.GetError("errNotExist") { - // TODO: add to list for creation - logger.Debugf("Certificate not found, needs to be created") - } - if err == nil { - logger.Debugf("Certificate found, adding TLS") - privateKeyPem, err := x.s.GetPrivateKeyPem(rule.Metadata.Name) - if err != nil { - return workQueueItems, err + createRuleKey := -1 + for k, v := range workQueueItems { + if v.Action == "createRule" { + createRuleKey = k } - workQueueItemTLS := workQueueItem - workQueueItemTLS.Action = "createRule" + } + if createRuleKey != -1 { + workQueueItemTLS := workQueueItems[createRuleKey] + workQueueItemTLS.Action = createRuleType // createRule or createRuleWithoutCluster workQueueItemTLS.TLSParams = TLSParams{ Name: rule.Metadata.Name, CertBundle: certBundle, PrivateKey: privateKeyPem, - Domain: workQueueItem.ListenerParams.Conditions.Hostname, + Domain: workQueueItems[createRuleKey].ListenerParams.Conditions.Hostname, } workQueueItems = append(workQueueItems, workQueueItemTLS) } diff --git a/pkg/envoy/xds_test.go b/pkg/envoy/xds_test.go index 08fba77..feb15bd 100644 --- a/pkg/envoy/xds_test.go +++ b/pkg/envoy/xds_test.go @@ -471,3 +471,32 @@ func TestTracingObject(t *testing.T) { } } + +func TestDirectResponseObject(t *testing.T) { + logger.SetLogLevel(loggo.DEBUG) + s, err := initStorage() + if err != nil { + t.Errorf("Couldn't initialize storage: %s", err) + return + } + x := NewXDS(s, "", "") + ObjectFileNames := []string{"test-directresponse.yaml"} + for _, filename := range ObjectFileNames { + newItems, err := x.putObject(filename) + if err != nil { + t.Errorf("PutObject failed: %s", err) + return + } + _, err = x.workQueue.Submit(newItems) + if err != nil { + t.Errorf("WorkQueue error: %s", err) + return + } + } + out, err := x.workQueue.listener.printListener(&x.workQueue.cache) + if err != nil { + t.Errorf("listener print error: %s", err) + return + } + fmt.Printf("%s\n", out) +} diff --git a/terraform/lb.tf b/terraform/lb.tf index 477efb5..4a9dc48 100644 --- a/terraform/lb.tf +++ b/terraform/lb.tf @@ -77,7 +77,8 @@ resource "aws_lb_target_group" "envoy-proxy-http" { healthy_threshold = 2 unhealthy_threshold = 2 protocol = var.loadbalancer == "alb" ? "HTTP" : "TCP" - matcher = var.loadbalancer == "alb" ? "200,404,301,302" : "" + matcher = var.loadbalancer == "alb" ? var.loadbalancer_healthcheck_matcher : "" + path = var.loadbalancer == "alb" ? var.loadbalancer_healthcheck_path : "" interval = 30 } } diff --git a/terraform/variables.tf b/terraform/variables.tf index 94a0254..6be338b 100644 --- a/terraform/variables.tf +++ b/terraform/variables.tf @@ -65,6 +65,16 @@ variable "loadbalancer" { default = "nlb" } +variable "loadbalancer_healthcheck_matcher" { + description = "loadbalancer healthcheck matcher to use" + default = "200,404,301,302" +} + +variable "loadbalancer_healthcheck_path" { + description = "loadbalancer healthcheck path to use" + default = "/" +} + variable "loadbalancer_alb_cert" { description = "loadbalancer alb certificate to use" default = ""