From 042e19bc4a03bd46e9ad82f7bdaf726419b3adfd Mon Sep 17 00:00:00 2001 From: Jonathan Hall Date: Thu, 14 Dec 2023 12:06:08 +0100 Subject: [PATCH] Add function to expose allowed methods for use in custom 405-handlers Fixes #870 --- context.go | 12 ++++++++++++ context_test.go | 28 +++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/context.go b/context.go index aacf6eff..b74975a8 100644 --- a/context.go +++ b/context.go @@ -133,6 +133,18 @@ func (x *Context) RoutePattern() string { return routePattern } +// WithRouteContext returns the list of methods allowed for the current +// request, based on the current routing context. +func (x *Context) AllowedMethods() []string { + result := make([]string, 0, len(x.methodsAllowed)) + for _, method := range x.methodsAllowed { + if method := methodTypString(method); method != "" { + result = append(result, method) + } + } + return result +} + // replaceWildcards takes a route pattern and recursively replaces all // occurrences of "/*/" to "/". func replaceWildcards(p string) string { diff --git a/context_test.go b/context_test.go index fa3c9f5b..9d32f871 100644 --- a/context_test.go +++ b/context_test.go @@ -1,6 +1,9 @@ package chi -import "testing" +import ( + "strings" + "testing" +) // TestRoutePattern tests correct in-the-middle wildcard removals. // If user organizes a router like this: @@ -91,3 +94,26 @@ func TestRoutePattern(t *testing.T) { t.Fatalf("unexpected non-empty route pattern for nil context: %q", p) } } + +func TestAllowedMethods(t *testing.T) { + t.Run("expected methods", func(t *testing.T) { + want := "GET HEAD" + rctx := &Context{ + methodsAllowed: []methodTyp{mGET, mHEAD}, + } + got := strings.Join(rctx.AllowedMethods(), " ") + if want != got { + t.Errorf("Unexpected allowed methods: %s, want: %s", got, want) + } + }) + t.Run("unexpected methods", func(t *testing.T) { + want := "GET HEAD" + rctx := &Context{ + methodsAllowed: []methodTyp{mGET, mHEAD, 9000}, + } + got := strings.Join(rctx.AllowedMethods(), " ") + if want != got { + t.Errorf("Unexpected allowed methods: %s, want: %s", got, want) + } + }) +}