Skip to content

Commit

Permalink
handle private addresses
Browse files Browse the repository at this point in the history
Signed-off-by: nscuro <[email protected]>
  • Loading branch information
nscuro committed Oct 4, 2021
1 parent 6932146 commit 7d03639
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 12 deletions.
48 changes: 37 additions & 11 deletions plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import (
)

type Config struct {
DatabaseFilePath string
DatabaseFilePath string `yaml:"database_file"`
AllowedCountries []string `yaml:"allowed_countries"`
AllowPrivate bool `yaml:"allow_private"`
}

func CreateConfig() *Config {
Expand All @@ -25,6 +26,7 @@ type Plugin struct {
name string
db *ip2location.DB
allowedCountries []string
allowPrivate bool
}

func New(_ context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
Expand All @@ -33,27 +35,39 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt
return nil, fmt.Errorf("failed to open database: %w", err)
}

return &Plugin{next: next, name: name, db: db, allowedCountries: config.AllowedCountries}, nil
return &Plugin{
next: next,
name: name,
db: db,
allowedCountries: config.AllowedCountries,
allowPrivate: config.AllowPrivate,
}, nil
}

func (p Plugin) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
ips := p.GetRemoteIPs(req)

for _, ip := range ips {
allowed, country, err := p.CheckAllowed(ip)
if err != nil {
log.Printf("%s: %v", p.name, err)
rw.WriteHeader(http.StatusForbidden)
country, err := p.CheckAllowed(ip)
if err == nil {
p.next.ServeHTTP(rw, req)
return
}
if !allowed {
if errors.Is(err, ErrPrivate) && p.allowPrivate {
log.Printf("%s: allowed for private address %s", p.name, ip)
p.next.ServeHTTP(rw, req)
return
} else if errors.Is(err, ErrNotAllowed) {
log.Printf("%s: access denied for %s (%s)", p.name, ip, country)
rw.WriteHeader(http.StatusForbidden)
return
} else {
log.Printf("%s: %v", p.name, err)
rw.WriteHeader(http.StatusForbidden)
return
}
}

p.next.ServeHTTP(rw, req)
}

// GetRemoteIPs collects the remote IPs from the X-Forwarded-For and X-Real-IP headers.
Expand Down Expand Up @@ -86,11 +100,20 @@ func (p Plugin) GetRemoteIPs(req *http.Request) (ips []string) {
return
}

var (
ErrNotAllowed = errors.New("not allowed")
ErrPrivate = errors.New("private address")
)

// CheckAllowed checks whether a given IP address is allowed according to the configured allowed countries.
func (p Plugin) CheckAllowed(ip string) (bool, string, error) {
func (p Plugin) CheckAllowed(ip string) (string, error) {
country, err := p.Lookup(ip)
if err != nil {
return false, "", fmt.Errorf("lookup of %s failed: %w", ip, err)
return "", fmt.Errorf("lookup of %s failed: %w", ip, err)
}

if country == "-" { // Private address
return country, ErrPrivate
}

var allowed bool
Expand All @@ -100,8 +123,11 @@ func (p Plugin) CheckAllowed(ip string) (bool, string, error) {
break
}
}
if !allowed {
return country, ErrNotAllowed
}

return allowed, country, nil
return country, nil
}

// Lookup queries the ip2location database for a given IP address.
Expand Down
30 changes: 29 additions & 1 deletion plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func (n noopHandler) ServeHTTP(rw http.ResponseWriter, _ *http.Request) {
rw.WriteHeader(http.StatusTeapot)
}

func TestLookup(t *testing.T) {
func TestPlugin_ServeHTTP(t *testing.T) {
t.Run("Allowed", func(t *testing.T) {
cfg := &Config{DatabaseFilePath: dbFilePath, AllowedCountries: []string{"US"}}
plugin, err := New(context.TODO(), &noopHandler{}, cfg, "geoblock")
Expand All @@ -31,6 +31,20 @@ func TestLookup(t *testing.T) {
require.Equal(t, http.StatusTeapot, rr.Code)
})

t.Run("AllowedPrivate", func(t *testing.T) {
cfg := &Config{DatabaseFilePath: dbFilePath, AllowedCountries: []string{}, AllowPrivate: true}
plugin, err := New(context.TODO(), &noopHandler{}, cfg, "geoblock")
require.NoError(t, err)

req := httptest.NewRequest(http.MethodGet, "/foobar", nil)
req.Header.Set("X-Real-IP", "192.168.178.66")

rr := httptest.NewRecorder()
plugin.ServeHTTP(rr, req)

require.Equal(t, http.StatusTeapot, rr.Code)
})

t.Run("Disallowed", func(t *testing.T) {
cfg := &Config{DatabaseFilePath: dbFilePath, AllowedCountries: []string{"DE"}}
plugin, err := New(context.TODO(), &noopHandler{}, cfg, "geoblock")
Expand All @@ -44,4 +58,18 @@ func TestLookup(t *testing.T) {

require.Equal(t, http.StatusForbidden, rr.Code)
})

t.Run("DisallowedPrivate", func(t *testing.T) {
cfg := &Config{DatabaseFilePath: dbFilePath, AllowedCountries: []string{}, AllowPrivate: false}
plugin, err := New(context.TODO(), &noopHandler{}, cfg, "geoblock")
require.NoError(t, err)

req := httptest.NewRequest(http.MethodGet, "/foobar", nil)
req.Header.Set("X-Real-IP", "192.168.178.66")

rr := httptest.NewRecorder()
plugin.ServeHTTP(rr, req)

require.Equal(t, http.StatusForbidden, rr.Code)
})
}

0 comments on commit 7d03639

Please sign in to comment.