Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rebase to upstream and update go.mod #1

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/stale.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
daysUntilStale: 60
daysUntilClose: 7
# Issues with these labels will never be considered stale
exemptLabels:
- v2
- needs-review
- work-required
staleLabel: stale
markComment: >
This issue has been automatically marked as stale because it hasn't seen
a recent update. It'll be automatically closed in a few days.
closeComment: false
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ with Go's `net/http` package (or any framework supporting `http.Handler`), inclu
* [**RecoveryHandler**](https://godoc.org/github.com/gorilla/handlers#RecoveryHandler) for recovering from unexpected panics.

Other handlers are documented [on the Gorilla
website](http://www.gorillatoolkit.org/pkg/handlers).
website](https://www.gorillatoolkit.org/pkg/handlers).

## Example

Expand Down
54 changes: 46 additions & 8 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ type cors struct {
maxAge int
ignoreOptions bool
allowCredentials bool
optionStatusCode int
}

// OriginValidator takes an origin string and returns whether or not that origin is allowed.
type OriginValidator func(string) bool

var (
defaultCorsMethods = []string{"GET", "HEAD", "POST"}
defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
defaultCorsOptionStatusCode = 200
defaultCorsMethods = []string{"GET", "HEAD", "POST"}
defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
// (WebKit/Safari v9 sends the Origin header by default in AJAX requests)
)

Expand All @@ -48,7 +50,10 @@ const (
func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get(corsOriginHeader)
if !ch.isOriginAllowed(origin) {
ch.h.ServeHTTP(w, r)
if r.Method != corsOptionMethod || ch.ignoreOptions {
ch.h.ServeHTTP(w, r)
}

return
}

Expand Down Expand Up @@ -110,9 +115,24 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set(corsVaryHeader, corsOriginHeader)
}

w.Header().Set(corsAllowOriginHeader, origin)
returnOrigin := origin
if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 {
returnOrigin = "*"
} else {
for _, o := range ch.allowedOrigins {
// A configuration of * is different than explicitly setting an allowed
// origin. Returning arbitrary origin headers in an access control allow
// origin header is unsafe and is not required by any use case.
if o == corsOriginMatchAll {
returnOrigin = "*"
break
}
}
}
w.Header().Set(corsAllowOriginHeader, returnOrigin)

if r.Method == corsOptionMethod {
w.WriteHeader(ch.optionStatusCode)
return
}
ch.h.ServeHTTP(w, r)
Expand Down Expand Up @@ -147,9 +167,10 @@ func CORS(opts ...CORSOption) func(http.Handler) http.Handler {

func parseCORSOptions(opts ...CORSOption) *cors {
ch := &cors{
allowedMethods: defaultCorsMethods,
allowedHeaders: defaultCorsHeaders,
allowedOrigins: []string{corsOriginMatchAll},
allowedMethods: defaultCorsMethods,
allowedHeaders: defaultCorsHeaders,
allowedOrigins: []string{},
optionStatusCode: defaultCorsOptionStatusCode,
}

for _, option := range opts {
Expand Down Expand Up @@ -234,7 +255,20 @@ func AllowedOriginValidator(fn OriginValidator) CORSOption {
}
}

// ExposeHeaders can be used to specify headers that are available
// OptionStatusCode sets a custom status code on the OPTIONS requests.
// Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory
// and can be used if you need a custom status code (i.e 204).
//
// More informations on the spec:
// https://fetch.spec.whatwg.org/#cors-preflight-fetch
func OptionStatusCode(code int) CORSOption {
return func(ch *cors) error {
ch.optionStatusCode = code
return nil
}
}

// ExposedHeaders can be used to specify headers that are available
// and will not be stripped out by the user-agent.
func ExposedHeaders(headers []string) CORSOption {
return func(ch *cors) error {
Expand Down Expand Up @@ -297,6 +331,10 @@ func (ch *cors) isOriginAllowed(origin string) bool {
return ch.allowedOriginValidator(origin)
}

if len(ch.allowedOrigins) == 0 {
return true
}

for _, allowedOrigin := range ch.allowedOrigins {
if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll {
return true
Expand Down
76 changes: 75 additions & 1 deletion cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,43 @@ func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) {
}
}

func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) {
statusCode := 204
r := newRequest("OPTIONS", "http://www.example.com/")
r.Header.Set("Origin", r.URL.String())
r.Header.Set(corsRequestMethodHeader, "GET")

rr := httptest.NewRecorder()

testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("Options request must not be passed to next handler")
})

CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r)

if status := rr.Code; status != statusCode {
t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
}
}

func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) {
r := newRequest("OPTIONS", "http://www.example.com/")
r.Header.Set("Origin", r.URL.String())
r.Header.Set(corsRequestMethodHeader, "GET")

rr := httptest.NewRecorder()

testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("Options request must not be passed to next handler")
})

CORS(AllowedOrigins([]string{}))(testHandler).ServeHTTP(rr, r)

if status := rr.Code; status != http.StatusOK {
t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
}
}

func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) {
r := newRequest("OPTIONS", "http://www.example.com/")
r.Header.Set("Origin", r.URL.String())
Expand Down Expand Up @@ -313,7 +350,7 @@ func TestCORSWithMultipleHandlers(t *testing.T) {
}
}

func TestCORSHandlerWithCustomValidator(t *testing.T) {
func TestCORSOriginValidatorWithImplicitStar(t *testing.T) {
r := newRequest("GET", "http://a.example.com")
r.Header.Set("Origin", r.URL.String())
rr := httptest.NewRecorder()
Expand All @@ -332,5 +369,42 @@ func TestCORSHandlerWithCustomValidator(t *testing.T) {
if header != r.URL.String() {
t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, r.URL.String(), header)
}
}

func TestCORSOriginValidatorWithExplicitStar(t *testing.T) {
r := newRequest("GET", "http://a.example.com")
r.Header.Set("Origin", r.URL.String())
rr := httptest.NewRecorder()

testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

originValidator := func(origin string) bool {
if strings.HasSuffix(origin, ".example.com") {
return true
}
return false
}

CORS(
AllowedOriginValidator(originValidator),
AllowedOrigins([]string{"*"}),
)(testHandler).ServeHTTP(rr, r)
header := rr.HeaderMap.Get(corsAllowOriginHeader)
if header != "*" {
t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, "*", header)
}
}

func TestCORSAllowStar(t *testing.T) {
r := newRequest("GET", "http://a.example.com")
r.Header.Set("Origin", r.URL.String())
rr := httptest.NewRecorder()

testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

CORS()(testHandler).ServeHTTP(rr, r)
header := rr.HeaderMap.Get(corsAllowOriginHeader)
if header != "*" {
t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, "*", header)
}
}
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module github.com/ef-ctx/handlers

go 1.12
Loading