diff --git a/README.md b/README.md index 7058e58..a623de6 100644 --- a/README.md +++ b/README.md @@ -53,4 +53,6 @@ http: allowPrivate: true # HTTP status code to return for disallowed requests (default: 403) disallowedStatusCode: 204 + # Add CIDR to be whitelisted, even if in a non-allowed country + allowedIPBlocks: ["66.249.64.0/19"] ``` \ No newline at end of file diff --git a/plugin.go b/plugin.go index 383887f..d27fb04 100644 --- a/plugin.go +++ b/plugin.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log" + "net" "net/http" "strings" @@ -21,6 +22,7 @@ type Config struct { AllowedCountries []string // Whitelist of countries to allow (ISO 3166-1 alpha-2) AllowPrivate bool // Allow requests from private / internal networks? DisallowedStatusCode int // HTTP status code to return for disallowed requests + AllowedIPBlocks []string // List of whitelist CIDR } // CreateConfig creates the default plugin configuration. @@ -36,6 +38,7 @@ type Plugin struct { allowedCountries []string allowPrivate bool disallowedStatusCode int + allowedIPBlocks []*net.IPNet } // New creates a new plugin instance. @@ -69,6 +72,7 @@ func New(_ context.Context, next http.Handler, cfg *Config, name string) (http.H if err != nil { return nil, fmt.Errorf("%s: failed to open database: %w", name, err) } + allowedIPBlocks := initAllowedIPBlocks(cfg.AllowedIPBlocks) return &Plugin{ next: next, @@ -78,6 +82,7 @@ func New(_ context.Context, next http.Handler, cfg *Config, name string) (http.H allowedCountries: cfg.AllowedCountries, allowPrivate: cfg.AllowPrivate, disallowedStatusCode: cfg.DisallowedStatusCode, + allowedIPBlocks: allowedIPBlocks, }, nil } @@ -154,10 +159,20 @@ func (p Plugin) CheckAllowed(ip string) (bool, string, error) { var allowed bool for _, allowedCountry := range p.allowedCountries { if allowedCountry == country { - allowed = true - break + // allowed = true + return true, country, nil + //break } } + + allowed, err = p.isAllowedIPBlocks(ip) + if err != nil { + return false, "", fmt.Errorf("checking if %s is part of an allowed range failed: %w", ip, err) + } + if allowed { + return true, country, nil + } + if !allowed { return false, country, nil } @@ -179,3 +194,34 @@ func (p Plugin) Lookup(ip string) (string, error) { return record.Country_short, nil } + +func initAllowedIPBlocks(allowedIPBlocks []string) []*net.IPNet { + + var allowedIPBlocksNet []*net.IPNet + + for _, cidr := range allowedIPBlocks { + _, block, err := net.ParseCIDR(cidr) + if err != nil { + panic(fmt.Errorf("parse error on %q: %v", cidr, err)) + } + allowedIPBlocksNet = append(allowedIPBlocksNet, block) + } + + return allowedIPBlocksNet +} + +func (p Plugin) isAllowedIPBlocks(ip string) (bool, error) { + var ipAddress net.IP = net.ParseIP(ip) + + if ipAddress == nil { + return false, fmt.Errorf("unable parse IP address from address [%s]", ip) + } + + for _, block := range p.allowedIPBlocks { + if block.Contains(ipAddress) { + return true, nil + } + } + + return false, nil +}