Skip to content
This repository has been archived by the owner on Dec 8, 2020. It is now read-only.

Commit

Permalink
Merge pull request #25 from puppetlabs/tasks/add-cors-middleware-func
Browse files Browse the repository at this point in the history
Update: Adds missing middleware wrapper to httputil/api.CORSBuilder
  • Loading branch information
kyleterry authored Jun 11, 2020
2 parents d863b25 + 2ad1cce commit 0985913
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 11 deletions.
66 changes: 57 additions & 9 deletions httputil/api/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,37 @@ func corsMatch(ms map[corsMatchable]struct{}, s string) bool {
return false
}

type corsHandler struct {
type corsMiddleware struct {
allowedOrigins map[corsMatchable]struct{}
defaultAllowedOrigin string
next http.Handler
}

func (cm *corsMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(cm.allowedOrigins) > 0 {
origin := r.Header.Get("origin")

if corsMatch(cm.allowedOrigins, origin) {
w.Header().Set("access-control-allow-origin", origin)
} else {
w.Header().Set("access-control-allow-origin", cm.defaultAllowedOrigin)
}

w.Header().Set("vary", "Origin")
}

cm.next.ServeHTTP(w, r)
}

type corsPreflightHandler struct {
allowedHeaders map[corsMatchable]struct{}
allowedMethods map[corsMatchable]struct{}
allowedMethodsHeader string
allowedOrigins map[corsMatchable]struct{}
defaultAllowedOrigin string
}

func (ch *corsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (ch *corsPreflightHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
requestedMethod := strings.ToUpper(r.Header.Get("access-control-request-method"))
if requestedMethod == "" {
w.WriteHeader(http.StatusBadRequest)
Expand Down Expand Up @@ -103,6 +125,8 @@ func (ch *corsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} else {
w.Header().Set("access-control-allow-origin", ch.defaultAllowedOrigin)
}

w.Header().Set("vary", "Origin")
}
}

Expand Down Expand Up @@ -165,10 +189,36 @@ func (cb *CORSBuilder) AllowOrigins(origins ...string) *CORSBuilder {
return cb
}

// PreflightHandler returns an http.Handler that can set Access-Control-Allow-* headers
// for preflight-requests (OPTIONS).
func (cb *CORSBuilder) PreflightHandler() http.Handler {
return cb.Build()
}

// Middleware wraps an http.Handler to set ACAO headers on responses.
func (cb *CORSBuilder) Middleware(next http.Handler) http.Handler {
cm := &corsMiddleware{
allowedOrigins: make(map[corsMatchable]struct{}),
defaultAllowedOrigin: cb.defaultAllowedOrigin,
}

for origin := range cb.allowedOrigins {
cm.allowedOrigins[corsMatchableString(origin)] = corsValue
}

cm.defaultAllowedOrigin = cb.defaultAllowedOrigin

cm.next = next

return cm
}

// Build returns an http.Handler that can set Access-Control-Allow-* headers
// based on requests it receives.
//
// DEPRECATED use PreflightHandler.
func (cb *CORSBuilder) Build() http.Handler {
ch := &corsHandler{
ch := &corsPreflightHandler{
allowedHeaders: make(map[corsMatchable]struct{}),
allowedOrigins: make(map[corsMatchable]struct{}),
}
Expand Down Expand Up @@ -198,14 +248,12 @@ func (cb *CORSBuilder) Build() http.Handler {
ch.allowedMethodsHeader = strings.Join(allowedMethods, ", ")
}

if len(cb.allowedOrigins) > 0 {
for origin := range cb.allowedOrigins {
ch.allowedOrigins[corsMatchableString(origin)] = corsValue
}

ch.defaultAllowedOrigin = cb.defaultAllowedOrigin
for origin := range cb.allowedOrigins {
ch.allowedOrigins[corsMatchableString(origin)] = corsValue
}

ch.defaultAllowedOrigin = cb.defaultAllowedOrigin

return ch
}

Expand Down
28 changes: 26 additions & 2 deletions httputil/api/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (
"github.com/stretchr/testify/require"
)

func TestCORSBuilder(t *testing.T) {
func TestCORSBuilderPreflightHandler(t *testing.T) {
handler := NewCORSBuilder().
AllowOrigins("http://example.com", "http://app.example.com").
AllowHeaderPrefix("horsehead-").
AllowHeaders("X-Custom-Header").Build()
AllowHeaders("X-Custom-Header").PreflightHandler()

req, err := http.NewRequest(http.MethodOptions, "http://example.com", nil)
require.NoError(t, err)
Expand All @@ -28,6 +28,7 @@ func TestCORSBuilder(t *testing.T) {

require.Equal(t, http.StatusOK, result.StatusCode)
require.Equal(t, "http://app.example.com", result.Header.Get("Access-Control-Allow-Origin"))
require.Equal(t, "Origin", result.Header.Get("Vary"))
require.Equal(t, "Horsehead-Custom-Header, X-Custom-Header", result.Header.Get("Access-Control-Allow-Headers"))
require.Equal(t, strings.Join(corsDefaultAllowedMethods, ", "), result.Header.Get("Access-Control-Allow-Methods"))

Expand All @@ -45,3 +46,26 @@ func TestCORSBuilder(t *testing.T) {
require.Equal(t, http.StatusMethodNotAllowed, result.StatusCode)
}
}

func TestCORSBuilderMiddleware(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

cm := NewCORSBuilder().
AllowOrigins("http://example.com", "http://app.example.com").
AllowHeaderPrefix("horsehead-").
AllowHeaders("X-Custom-Header").Middleware(handler)

req, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
require.NoError(t, err)

req.Header.Set("Origin", "http://app.example.com")

resp := httptest.NewRecorder()

cm.ServeHTTP(resp, req)
result := resp.Result()

require.Equal(t, http.StatusOK, result.StatusCode)
require.Equal(t, "http://app.example.com", result.Header.Get("Access-Control-Allow-Origin"))
require.Equal(t, "Origin", result.Header.Get("Vary"))
}

0 comments on commit 0985913

Please sign in to comment.