forked from Ulexus/traefik-plugin-geoblock
-
Notifications
You must be signed in to change notification settings - Fork 0
/
plugin.go
120 lines (101 loc) · 2.65 KB
/
plugin.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package geoblock
import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"strings"

	"github.com/ip2location/ip2location-go"
)
// Config is the configuration accepted by New.
type Config struct {
	// DatabaseFilePath is the path of the ip2location database file that New opens.
	DatabaseFilePath string
	// AllowedCountries lists the country codes whose clients are let through;
	// they are compared against the Country_short value returned by the database.
	AllowedCountries []string `yaml:"allowed_countries"`
}
// CreateConfig returns a pointer to a zero-valued Config: no database path and
// no allowed countries.
func CreateConfig() *Config {
	cfg := new(Config)
	return cfg
}
// Plugin is the HTTP handler that admits or rejects requests based on the
// country resolved for their client IPs.
type Plugin struct {
	// next receives every request that passes the country check.
	next http.Handler
	// name is the instance name, used as the prefix of log messages.
	name string
	// db is the opened ip2location database queried by Lookup.
	db *ip2location.DB
	// allowedCountries holds the country codes that are permitted.
	allowedCountries []string
}
// New builds the geoblock handler. It opens the ip2location database at
// config.DatabaseFilePath and returns a Plugin that forwards permitted
// requests to next.
//
// It returns an error when config is nil, when no database path is set, or
// when the database cannot be opened. An empty AllowedCountries list is
// accepted but logged, because with it every checked client IP is rejected.
func New(_ context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
	if config == nil {
		return nil, errors.New("missing configuration")
	}
	// Fail with an explicit message instead of an opaque "open : no such file" error.
	if config.DatabaseFilePath == "" {
		return nil, errors.New("missing database file path")
	}
	db, err := ip2location.OpenDB(config.DatabaseFilePath)
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}
	if len(config.AllowedCountries) == 0 {
		log.Printf("%s: no allowed countries configured; every checked client IP will be denied", name)
	}
	// Copy the list so later mutation of the caller's Config cannot change
	// the running plugin's allow list.
	allowed := append([]string(nil), config.AllowedCountries...)
	return &Plugin{next: next, name: name, db: db, allowedCountries: allowed}, nil
}
// ServeHTTP enforces the country allow list for req.
//
// Every IP returned by GetRemoteIPs is checked with CheckAllowed. If any
// lookup fails, or any IP resolves to a country outside the allow list, the
// request is answered with 403 Forbidden and not forwarded. Otherwise it is
// passed to the next handler.
//
// When the forwarding headers yield no IP, the host part of req.RemoteAddr is
// checked instead. If no client IP can be determined at all, the request is
// rejected.
func (p Plugin) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	ips := p.GetRemoteIPs(req)
	if len(ips) == 0 {
		// Without this fallback a request carrying no X-Forwarded-For or
		// X-Real-IP header would skip the check loop entirely (fail open).
		if ip := remoteAddrIP(req.RemoteAddr); ip != "" {
			ips = []string{ip}
		}
	}
	if len(ips) == 0 {
		log.Printf("%s: access denied: unable to determine client IP", p.name)
		rw.WriteHeader(http.StatusForbidden)
		return
	}
	for _, ip := range ips {
		allowed, country, err := p.CheckAllowed(ip)
		if err != nil {
			log.Printf("%s: %v", p.name, err)
			rw.WriteHeader(http.StatusForbidden)
			return
		}
		if !allowed {
			log.Printf("%s: access denied for %s (%s)", p.name, ip, country)
			rw.WriteHeader(http.StatusForbidden)
			return
		}
	}
	p.next.ServeHTTP(rw, req)
}

// remoteAddrIP returns the host part of an "host:port" address such as
// http.Request.RemoteAddr, or the trimmed input unchanged if it has no port.
func remoteAddrIP(addr string) string {
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return strings.TrimSpace(addr)
	}
	return host
}
// GetRemoteIPs collects the distinct remote IPs listed in the X-Forwarded-For
// and X-Real-IP headers of req.
//
// How the headers are read:
//   - Every line of each header is read, since a header may legitimately be repeated.
//   - Each value is split on commas and each entry is trimmed of surrounding whitespace.
//   - Empty entries are skipped.
//
// Each IP is reported once, in the order it was first seen: X-Forwarded-For
// entries come before X-Real-IP entries. The result is nil when neither header
// yields an IP.
func (p Plugin) GetRemoteIPs(req *http.Request) (ips []string) {
	seen := make(map[string]struct{})
	for _, header := range []string{"x-forwarded-for", "x-real-ip"} {
		// Header.Values (unlike Header.Get) returns all lines of a repeated
		// header, so no forwarded address escapes the check.
		for _, value := range req.Header.Values(header) {
			for _, ip := range strings.Split(value, ",") {
				ip = strings.TrimSpace(ip)
				if ip == "" {
					continue
				}
				if _, dup := seen[ip]; dup {
					continue
				}
				seen[ip] = struct{}{}
				// Appending on first sight keeps the order deterministic,
				// unlike ranging over the map.
				ips = append(ips, ip)
			}
		}
	}
	return ips
}
// CheckAllowed reports whether ip resolves to one of the configured allowed
// countries.
//
// It returns three values:
//   - the verdict;
//   - the country code produced by Lookup;
//   - an error if the lookup failed, in which case the verdict is false and
//     the country is empty.
//
// Configured codes are trimmed and compared case-insensitively, so a
// configured "us" matches a database value of "US".
func (p Plugin) CheckAllowed(ip string) (bool, string, error) {
	country, err := p.Lookup(ip)
	if err != nil {
		return false, "", fmt.Errorf("lookup of %s failed: %w", ip, err)
	}
	for _, allowed := range p.allowedCountries {
		// Hand-written config values may differ in case or surrounding
		// whitespace from the codes stored in the database.
		if strings.EqualFold(strings.TrimSpace(allowed), country) {
			return true, country, nil
		}
	}
	return false, country, nil
}
// Lookup returns the Country_short value the ip2location database holds for ip.
//
// An error from the database call is returned as-is. A Country_short value
// that begins with "invalid" (in any letter case) is treated as a failure
// message rather than a country code, and is returned as an error whose text
// is that value.
func (p Plugin) Lookup(ip string) (string, error) {
	rec, lookupErr := p.db.Get_country_short(ip)
	if lookupErr != nil {
		return "", lookupErr
	}
	code := rec.Country_short
	// NOTE(review): presumably the library reports some problems (e.g. an
	// unparsable address) in the record text instead of via the error.
	if lowered := strings.ToLower(code); strings.HasPrefix(lowered, "invalid") {
		return "", errors.New(code)
	}
	return code, nil
}