diff --git a/httputil/api/cors.go b/httputil/api/cors.go index cbdd18e..3757303 100644 --- a/httputil/api/cors.go +++ b/httputil/api/cors.go @@ -52,7 +52,29 @@ func corsMatch(ms map[corsMatchable]struct{}, s string) bool { return false } -type corsHandler struct { +type corsMiddleware struct { + allowedOrigins map[corsMatchable]struct{} + defaultAllowedOrigin string + next http.Handler +} + +func (cm *corsMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if len(cm.allowedOrigins) > 0 { + origin := r.Header.Get("origin") + + if corsMatch(cm.allowedOrigins, origin) { + w.Header().Set("access-control-allow-origin", origin) + } else { + w.Header().Set("access-control-allow-origin", cm.defaultAllowedOrigin) + } + + w.Header().Set("vary", "Origin") + } + + cm.next.ServeHTTP(w, r) +} + +type corsPreflightHandler struct { allowedHeaders map[corsMatchable]struct{} allowedMethods map[corsMatchable]struct{} allowedMethodsHeader string @@ -60,7 +82,7 @@ type corsHandler struct { defaultAllowedOrigin string } -func (ch *corsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (ch *corsPreflightHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { requestedMethod := strings.ToUpper(r.Header.Get("access-control-request-method")) if requestedMethod == "" { w.WriteHeader(http.StatusBadRequest) @@ -103,6 +125,8 @@ func (ch *corsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else { w.Header().Set("access-control-allow-origin", ch.defaultAllowedOrigin) } + + w.Header().Set("vary", "Origin") } } @@ -165,10 +189,36 @@ func (cb *CORSBuilder) AllowOrigins(origins ...string) *CORSBuilder { return cb } +// PreflightHandler returns an http.Handler that can set Access-Control-Allow-* headers +// for preflight-requests (OPTIONS). +func (cb *CORSBuilder) PreflightHandler() http.Handler { + return cb.Build() +} + +// Middleware wraps an http.Handler to set ACAO headers on responses. +func (cb *CORSBuilder) Middleware(next http.Handler) http.Handler { + cm := &corsMiddleware{ + allowedOrigins: make(map[corsMatchable]struct{}), + defaultAllowedOrigin: cb.defaultAllowedOrigin, + } + + for origin := range cb.allowedOrigins { + cm.allowedOrigins[corsMatchableString(origin)] = corsValue + } + + cm.defaultAllowedOrigin = cb.defaultAllowedOrigin + + cm.next = next + + return cm +} + // Build returns an http.Handler that can set Access-Control-Allow-* headers // based on requests it receives. +// +// DEPRECATED use PreflightHandler. func (cb *CORSBuilder) Build() http.Handler { - ch := &corsHandler{ + ch := &corsPreflightHandler{ allowedHeaders: make(map[corsMatchable]struct{}), allowedOrigins: make(map[corsMatchable]struct{}), } @@ -198,14 +248,12 @@ func (cb *CORSBuilder) Build() http.Handler { ch.allowedMethodsHeader = strings.Join(allowedMethods, ", ") } - if len(cb.allowedOrigins) > 0 { - for origin := range cb.allowedOrigins { - ch.allowedOrigins[corsMatchableString(origin)] = corsValue - } - - ch.defaultAllowedOrigin = cb.defaultAllowedOrigin + for origin := range cb.allowedOrigins { + ch.allowedOrigins[corsMatchableString(origin)] = corsValue } + ch.defaultAllowedOrigin = cb.defaultAllowedOrigin + return ch } diff --git a/httputil/api/cors_test.go b/httputil/api/cors_test.go index 4795ce0..4c26a21 100644 --- a/httputil/api/cors_test.go +++ b/httputil/api/cors_test.go @@ -9,11 +9,11 @@ import ( "github.com/stretchr/testify/require" ) -func TestCORSBuilder(t *testing.T) { +func TestCORSBuilderPreflightHandler(t *testing.T) { handler := NewCORSBuilder(). AllowOrigins("http://example.com", "http://app.example.com"). AllowHeaderPrefix("horsehead-"). - AllowHeaders("X-Custom-Header").Build() + AllowHeaders("X-Custom-Header").PreflightHandler() req, err := http.NewRequest(http.MethodOptions, "http://example.com", nil) require.NoError(t, err) @@ -28,6 +28,7 @@ func TestCORSBuilder(t *testing.T) { require.Equal(t, http.StatusOK, result.StatusCode) require.Equal(t, "http://app.example.com", result.Header.Get("Access-Control-Allow-Origin")) + require.Equal(t, "Origin", result.Header.Get("Vary")) require.Equal(t, "Horsehead-Custom-Header, X-Custom-Header", result.Header.Get("Access-Control-Allow-Headers")) require.Equal(t, strings.Join(corsDefaultAllowedMethods, ", "), result.Header.Get("Access-Control-Allow-Methods")) @@ -45,3 +46,26 @@ func TestCORSBuilder(t *testing.T) { require.Equal(t, http.StatusMethodNotAllowed, result.StatusCode) } } + +func TestCORSBuilderMiddleware(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + cm := NewCORSBuilder(). + AllowOrigins("http://example.com", "http://app.example.com"). + AllowHeaderPrefix("horsehead-"). + AllowHeaders("X-Custom-Header").Middleware(handler) + + req, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + require.NoError(t, err) + + req.Header.Set("Origin", "http://app.example.com") + + resp := httptest.NewRecorder() + + cm.ServeHTTP(resp, req) + result := resp.Result() + + require.Equal(t, http.StatusOK, result.StatusCode) + require.Equal(t, "http://app.example.com", result.Header.Get("Access-Control-Allow-Origin")) + require.Equal(t, "Origin", result.Header.Get("Vary")) +}