From 34335091b8597eb5745c0cb5adda562e7c4a4681 Mon Sep 17 00:00:00 2001 From: n33pm <12273891+n33pm@users.noreply.github.com> Date: Tue, 9 Apr 2024 15:44:50 +0200 Subject: [PATCH] feat(): custom real ip header #908 --- middleware/realip.go | 56 ++++++++++++++++++++++++--------------- middleware/realip_test.go | 49 ++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 22 deletions(-) diff --git a/middleware/realip.go b/middleware/realip.go index 11c0348a..48057c11 100644 --- a/middleware/realip.go +++ b/middleware/realip.go @@ -9,10 +9,12 @@ import ( "strings" ) -var cfConnectionIP = http.CanonicalHeaderKey("CF-Connecting-IP") -var trueClientIP = http.CanonicalHeaderKey("True-Client-IP") -var xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") -var xRealIP = http.CanonicalHeaderKey("X-Real-IP") +var DefaultRealIPHeaders = []string{ + "CF-Connecting-IP", // Cloudflare free plan + "True-Client-IP", // Cloudflare Enterprise plan + "X-Real-IP", + "X-Forwarded-For", +} // RealIP is a middleware that sets a http.Request's RemoteAddr to the results // of parsing either the CF-Connecting-IP, True-Client-IP, X-Real-IP or the X-Forwarded-For headers @@ -31,7 +33,7 @@ var xRealIP = http.CanonicalHeaderKey("X-Real-IP") // how you're using RemoteAddr, vulnerable to an attack of some sort). func RealIP(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { - if rip := realIP(r); rip != "" { + if rip := getRealIP(r, DefaultRealIPHeaders); rip != "" { r.RemoteAddr = rip } h.ServeHTTP(w, r) @@ -40,24 +42,34 @@ func RealIP(h http.Handler) http.Handler { return http.HandlerFunc(fn) } -func realIP(r *http.Request) string { - var ip string - - if cfcip := r.Header.Get(cfConnectionIP); cfcip != "" { - ip = cfcip - } else if tcip := r.Header.Get(trueClientIP); tcip != "" { - ip = tcip - } else if xrip := r.Header.Get(xRealIP); xrip != "" { - ip = xrip - } else if xff := r.Header.Get(xForwardedFor); xff != "" { - i := strings.Index(xff, ",") - if i == -1 { - i = len(xff) +// RealIPCustomHeader is a middleware that sets a http.Request's RemoteAddr to the results +// of parsing the custom headers. +// +// usage: +// r.Use(RealIPCustomHeader([]string{"X-CUSTOM-IP"})) +// r.Use(RealIPCustomHeader(append(DefaultRealIPHeaders, "X-CUSTOM-IP"))) +func RealIPCustomHeader(realIPHeaders []string) func(http.Handler) http.Handler { + f := func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if rip := getRealIP(r, realIPHeaders); rip != "" { + r.RemoteAddr = rip + } + h.ServeHTTP(w, r) } - ip = xff[:i] + return http.HandlerFunc(fn) } - if ip == "" || net.ParseIP(ip) == nil { - return "" + return f +} + +func getRealIP(r *http.Request, realIPHeaders []string) string { + for _, header := range realIPHeaders { + if ip := r.Header.Get(header); ip != "" { + ips := strings.Split(ip, ",") + if ips[0] == "" || net.ParseIP(ips[0]) == nil { + continue + } + return ips[0] + } } - return ip + return "" } diff --git a/middleware/realip_test.go b/middleware/realip_test.go index 1ab5e95e..cc75dfcf 100644 --- a/middleware/realip_test.go +++ b/middleware/realip_test.go @@ -113,3 +113,52 @@ func TestInvalidIP(t *testing.T) { t.Fatal("Invalid IP used.") } } + +func TestCustomIPHeader(t *testing.T) { + var customHeaderKey = "X-CUSTOM-IP" + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Add(customHeaderKey, "100.100.100.100") + w := httptest.NewRecorder() + + r := chi.NewRouter() + r.Use(RealIPCustomHeader([]string{customHeaderKey})) + + realIP := "" + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + realIP = r.RemoteAddr + w.Write([]byte("Hello World")) + }) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatal("Response Code should be 200") + } + + if realIP != "100.100.100.100" { + t.Fatal("Test get real IP precedence error.") + } +} + +func TestCustomIPHeaderWithoutDefault(t *testing.T) { + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Add("X-REAL-IP", "100.100.100.100") + w := httptest.NewRecorder() + + r := chi.NewRouter() + r.Use(RealIPCustomHeader([]string{"X-CUSTOM-IP"})) + + realIP := "" + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + realIP = r.RemoteAddr + w.Write([]byte("Hello World")) + }) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatal("Response Code should be 200") + } + + if realIP != "" { + t.Fatal("Invalid IP used.") + } +}