Skip to content

Commit

Permalink
feat(): custom real ip header
Browse files Browse the repository at this point in the history
  • Loading branch information
n33pm committed Apr 9, 2024
1 parent 288bb4f commit 3433509
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 22 deletions.
56 changes: 34 additions & 22 deletions middleware/realip.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import (
"strings"
)

var cfConnectionIP = http.CanonicalHeaderKey("CF-Connecting-IP")
var trueClientIP = http.CanonicalHeaderKey("True-Client-IP")
var xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For")
var xRealIP = http.CanonicalHeaderKey("X-Real-IP")
var DefaultRealIPHeaders = []string{
"CF-Connecting-IP", // Cloudflare free plan
"True-Client-IP", // Cloudflare Enterprise plan
"X-Real-IP",
"X-Forwarded-For",
}

// RealIP is a middleware that sets a http.Request's RemoteAddr to the results
// of parsing either the CF-Connecting-IP, True-Client-IP, X-Real-IP or the X-Forwarded-For headers
Expand All @@ -31,7 +33,7 @@ var xRealIP = http.CanonicalHeaderKey("X-Real-IP")
// how you're using RemoteAddr, vulnerable to an attack of some sort).
func RealIP(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if rip := realIP(r); rip != "" {
if rip := getRealIP(r, DefaultRealIPHeaders); rip != "" {
r.RemoteAddr = rip
}
h.ServeHTTP(w, r)
Expand All @@ -40,24 +42,34 @@ func RealIP(h http.Handler) http.Handler {
return http.HandlerFunc(fn)
}

func realIP(r *http.Request) string {
var ip string

if cfcip := r.Header.Get(cfConnectionIP); cfcip != "" {
ip = cfcip
} else if tcip := r.Header.Get(trueClientIP); tcip != "" {
ip = tcip
} else if xrip := r.Header.Get(xRealIP); xrip != "" {
ip = xrip
} else if xff := r.Header.Get(xForwardedFor); xff != "" {
i := strings.Index(xff, ",")
if i == -1 {
i = len(xff)
// RealIPCustomHeader is a middleware that sets a http.Request's RemoteAddr to the results
// of parsing the custom headers.
//
// usage:
// r.Use(RealIPCustomHeader([]string{"X-CUSTOM-IP"}))
// r.Use(RealIPCustomHeader(append(DefaultRealIPHeaders, "X-CUSTOM-IP")))
func RealIPCustomHeader(realIPHeaders []string) func(http.Handler) http.Handler {
f := func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if rip := getRealIP(r, realIPHeaders); rip != "" {
r.RemoteAddr = rip
}
h.ServeHTTP(w, r)
}
ip = xff[:i]
return http.HandlerFunc(fn)
}
if ip == "" || net.ParseIP(ip) == nil {
return ""
return f
}

func getRealIP(r *http.Request, realIPHeaders []string) string {
for _, header := range realIPHeaders {
if ip := r.Header.Get(header); ip != "" {
ips := strings.Split(ip, ",")
if ips[0] == "" || net.ParseIP(ips[0]) == nil {
continue
}
return ips[0]
}
}
return ip
return ""
}
49 changes: 49 additions & 0 deletions middleware/realip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,52 @@ func TestInvalidIP(t *testing.T) {
t.Fatal("Invalid IP used.")
}
}

func TestCustomIPHeader(t *testing.T) {
var customHeaderKey = "X-CUSTOM-IP"
req, _ := http.NewRequest("GET", "/", nil)
req.Header.Add(customHeaderKey, "100.100.100.100")
w := httptest.NewRecorder()

r := chi.NewRouter()
r.Use(RealIPCustomHeader([]string{customHeaderKey}))

realIP := ""
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
realIP = r.RemoteAddr
w.Write([]byte("Hello World"))
})
r.ServeHTTP(w, req)

if w.Code != 200 {
t.Fatal("Response Code should be 200")
}

if realIP != "100.100.100.100" {
t.Fatal("Test get real IP precedence error.")
}
}

func TestCustomIPHeaderWithoutDefault(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
req.Header.Add("X-REAL-IP", "100.100.100.100")
w := httptest.NewRecorder()

r := chi.NewRouter()
r.Use(RealIPCustomHeader([]string{"X-CUSTOM-IP"}))

realIP := ""
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
realIP = r.RemoteAddr
w.Write([]byte("Hello World"))
})
r.ServeHTTP(w, req)

if w.Code != 200 {
t.Fatal("Response Code should be 200")
}

if realIP != "" {
t.Fatal("Invalid IP used.")
}
}

0 comments on commit 3433509

Please sign in to comment.