Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add error handling for nil RouteContext in URLFormat #841

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion middleware/url_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (
"github.com/go-chi/chi/v5"
)

const (
errRouteContextNil = "RouteContext was nil."
)

var (
// URLFormatCtxKey is the context.Context key to store the URL format data
// for a request.
Expand Down Expand Up @@ -52,7 +56,12 @@ func URLFormat(next http.Handler) http.Handler {
path := r.URL.Path

rctx := chi.RouteContext(r.Context())
if rctx != nil && rctx.RoutePath != "" {
if rctx == nil {
http.Error(w, errRouteContextNil, http.StatusInternalServerError)
return
}

if rctx.RoutePath != "" {
path = rctx.RoutePath
}

Expand Down
29 changes: 29 additions & 0 deletions middleware/url_format_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package middleware

import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -67,3 +69,30 @@ func TestURLFormatInSubRouter(t *testing.T) {
t.Fatalf(resp)
}
}

func TestURLFormatWithoutChiRouteContext(t *testing.T) {
r := chi.NewRouter()

r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
newCtx := context.WithValue(r.Context(), chi.RouteCtxKey, nil)
Copy link
Contributor

@VojtechVitek VojtechVitek Sep 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Taaipi,

Thanks for the PR!

It looks like this test cases injects a nil RouteContext down the chain intentionally.

Do you have an example how could this happen in real-world use case? I wonder how is the middleware.URLFormat ever called without chi route context unless we set it to nil explicitly.

Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess based on #718, it'd be better to create a test case with some sub-routers (instead of setting nil). Right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you using middleware.URLFormat outside of chi router?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VojtechVitek
Thank you for the review.

Are you using middleware.URLFormat outside of chi router?

Yes, this issue could arise if middleware.URLFormat is used outside of the chi router or if RouteContext is explicitly set to nil (though I haven’t fully grasped the intent behind such a scenario based on issue #839). However, since encountering such a problem is conceivable, even though it remains an edge case, I’ve added error handling to catch and address it early.

For now, I will omit the case where nil is intentionally assigned. Instead, I’m considering adding a test case for error handling when the router is not used. What do you think?

next.ServeHTTP(w, r.WithContext(newCtx))
})
})
r.Use(URLFormat)
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
})

ts := httptest.NewServer(r)
defer ts.Close()

resp, respBody := testRequest(t, ts, "GET", "/", nil)
if resp.StatusCode != http.StatusInternalServerError {
t.Fatalf("non 500 response: %v", resp.StatusCode)
}

if strings.TrimSpace(respBody) != errRouteContextNil {
t.Fatalf("Expected error message: %s, but got: %s", errRouteContextNil, respBody)
}
}