Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip: impl unix sockets and impl trusted proxies #88

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions config/crowdsec-blocklist-mirror.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ blocklists:
- ::1

listen_uri: 127.0.0.1:41412
# listen_socket: /var/run/crowdsec-blocklist-mirror.sock
# trusted_proxies:
# - 127.0.0.1
# - 127.0.0.1/32
# trusted_header: X-Forwarded-For
tls:
cert_file:
key_file:
Expand Down
72 changes: 59 additions & 13 deletions pkg/cfg/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import (
"errors"
"fmt"
"io"
"net"
"os"
"strings"

"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/slices"
"gopkg.in/yaml.v3"

Expand Down Expand Up @@ -56,14 +57,18 @@ type TLSConfig struct {
}

type Config struct {
CrowdsecConfig CrowdsecConfig `yaml:"crowdsec_config"`
Blocklists []*BlockListConfig `yaml:"blocklists"`
ListenURI string `yaml:"listen_uri"`
TLS TLSConfig `yaml:"tls"`
Metrics MetricConfig `yaml:"metrics"`
Logging LoggingConfig `yaml:",inline"`
ConfigVersion string `yaml:"config_version"`
EnableAccessLogs bool `yaml:"enable_access_logs"`
CrowdsecConfig CrowdsecConfig `yaml:"crowdsec_config"`
Blocklists []*BlockListConfig `yaml:"blocklists"`
ListenURI string `yaml:"listen_uri"`
ListenSocket string `yaml:"listen_socket"`
TrustedProxies []string `yaml:"trusted_proxies"`
ParsedTrustedProxies []*net.IPNet `yaml:"-"`
TrustedHeader string `yaml:"trusted_header"`
TLS TLSConfig `yaml:"tls"`
Metrics MetricConfig `yaml:"metrics"`
Logging LoggingConfig `yaml:",inline"`
ConfigVersion string `yaml:"config_version"`
EnableAccessLogs bool `yaml:"enable_access_logs"`
}

func (cfg *Config) ValidateAndSetDefaults() error {
Expand All @@ -80,19 +85,19 @@ func (cfg *Config) ValidateAndSetDefaults() error {
}

if cfg.CrowdsecConfig.UpdateFrequency == "" {
logrus.Warn("update_frequency is not provided")
log.Warn("update_frequency is not provided")

cfg.CrowdsecConfig.UpdateFrequency = "10s"
}

if cfg.ConfigVersion == "" {
logrus.Warn("config version is not provided; assuming v1.0")
log.Warn("config version is not provided; assuming v1.0")

cfg.ConfigVersion = "v1.0"
}

if cfg.ListenURI == "" {
logrus.Warn("listen_uri is not provided ; assuming 127.0.0.1:41412")
if cfg.ListenURI == "" && cfg.ListenSocket == "" {
log.Warn("listen_uri is not provided ; assuming 127.0.0.1:41412")

cfg.ListenURI = "127.0.0.1:41412"
}
Expand Down Expand Up @@ -125,9 +130,50 @@ func (cfg *Config) ValidateAndSetDefaults() error {
}
}

cfg.ParsedTrustedProxies = make([]*net.IPNet, 0, len(cfg.TrustedProxies))
for _, ip := range cfg.TrustedProxies {
if !strings.Contains(ip, "/") {
log.Debug("no CIDR provided attempting to add /32 or /128; ", ip)
parsedIP := parseIP(ip)
if parsedIP == nil {
return fmt.Errorf("invalid IP address: %s", ip)
}
switch len(parsedIP) {
case net.IPv4len:
ip += "/32"
case net.IPv6len:
ip += "/128"
}
log.Debug("added CIDR to IP: ", ip)
}
_, ipNet, err := net.ParseCIDR(ip)
if err != nil {
return fmt.Errorf("invalid IP address: %s", ip)
}
log.Info("adding trusted proxy: ", ip)
cfg.ParsedTrustedProxies = append(cfg.ParsedTrustedProxies, ipNet)
}

if cfg.TrustedHeader == "" {
log.Info("trusted_header is not provided; assuming X-Forwarded-For")
cfg.TrustedHeader = "X-Forwarded-For"
}

if len(cfg.ParsedTrustedProxies) == 0 {
log.Info("no trusted proxies provided so trusted_header is ignored")
}

return nil
}

func parseIP(ip string) net.IP {
parsedIP := net.ParseIP(ip)
if ipv4 := parsedIP.To4(); ipv4 != nil {
return ipv4
}
return parsedIP
}

func MergedConfig(configPath string) ([]byte, error) {
patcher := yamlpatch.NewPatcher(configPath, ".local")

Expand Down
136 changes: 100 additions & 36 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func RunServer(ctx context.Context, g *errgroup.Group, config cfg.Config) error
return err
}

http.HandleFunc(blockListCFG.Endpoint, f)
http.HandleFunc(blockListCFG.Endpoint, globalMiddleware(config, f))
log.Infof("serving blocklist in format %s at endpoint %s", blockListCFG.Format, blockListCFG.Endpoint)
}

Expand All @@ -51,16 +51,37 @@ func RunServer(ctx context.Context, g *errgroup.Group, config cfg.Config) error
}

server := &http.Server{
Addr: config.ListenURI,
Handler: logHandler,
}

g.Go(func() error {
err := listenAndServe(server, config)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
if config.ListenSocket != "" {
log.Info("listening on unix socket: ", config.ListenSocket)
listener, err := net.Listen("unix", config.ListenSocket)
if err != nil {
return err
}
defer listener.Close()
if err := listenAndServe(server, listener, config); !errors.Is(err, http.ErrServerClosed) {
return err
}
}
return nil
})

g.Go(func() error {
if config.ListenURI != "" {
log.Info("listening on tcp server: ", config.ListenURI)
listener, err := net.Listen("tcp", config.ListenURI)
if err != nil {
return err
}
defer listener.Close()

if err := listenAndServe(server, listener, config); !errors.Is(err, http.ErrServerClosed) {
return err
}
}
return nil
})

Expand All @@ -73,15 +94,57 @@ func RunServer(ctx context.Context, g *errgroup.Group, config cfg.Config) error
return nil
}

func listenAndServe(server *http.Server, config cfg.Config) error {
if config.TLS.CertFile != "" && config.TLS.KeyFile != "" {
log.Infof("Starting server with TLS at %s", config.ListenURI)
return server.ListenAndServeTLS(config.TLS.CertFile, config.TLS.KeyFile)
/*
Global middlewares are middlewares that are applied to all routes and are not specific to a blocklist.
*/
func globalMiddleware(config cfg.Config, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
//Parsed unix socket request
if r.RemoteAddr == "@" {
r.RemoteAddr = "127.0.0.1:65535"
}
//Trusted proxies
header := r.Header.Get(config.TrustedHeader)
// If there is no header then we don't need to do anything
if header != "" {
headerSplit := strings.Split(header, ",")
ip, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr))
if err != nil {
log.Errorf("error while spliting hostport for %s: %v", r.RemoteAddr, err)
Dismissed Show dismissed Hide dismissed
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
//Loop over the parsed trusted proxies
for _, trustedProxy := range config.ParsedTrustedProxies {
//check if the remote address is in the trusted proxies
if trustedProxy.Contains(net.ParseIP(ip)) {
// Loop over the header values in reverse order
for i := len(headerSplit) - 1; i >= 0; i-- {
ipStr := strings.TrimSpace(headerSplit[i])
ip := net.ParseIP(ipStr)
if ip == nil {
break
}
// If the IP is not in the trusted proxies, set the remote address to the IP
if (i == 0) || (!trustedProxy.Contains(ip)) {
r.RemoteAddr = ipStr
break
}
}
}
}
}

next.ServeHTTP(w, r)
}
}

log.Infof("Starting server at %s", config.ListenURI)
func listenAndServe(server *http.Server, listener net.Listener, config cfg.Config) error {
if config.TLS.CertFile != "" && config.TLS.KeyFile != "" {
return server.ServeTLS(listener, config.TLS.CertFile, config.TLS.KeyFile)
}

return server.ListenAndServe()
return server.Serve(listener)
}

var RouteHits = prometheus.NewCounterVec(
Expand Down Expand Up @@ -132,7 +195,7 @@ func toValidCIDR(ip string) string {
}

func getTrustedIPs(ips []string) ([]net.IPNet, error) {
trustedIPs := make([]net.IPNet, 0)
trustedIPs := make([]net.IPNet, 0, len(ips))

for _, ip := range ips {
cidr := toValidCIDR(ip)
Expand Down Expand Up @@ -183,36 +246,37 @@ func decisionMiddleware(next http.HandlerFunc) func(w http.ResponseWriter, r *ht

func authMiddleware(blockListCfg *cfg.BlockListConfig, next http.HandlerFunc) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
log.Errorf("error while spliting hostport for %s: %v", r.RemoteAddr, err)
http.Error(w, "internal error", http.StatusInternalServerError)

return
}

trustedIPs, err := getTrustedIPs(blockListCfg.Authentication.TrustedIPs)
if err != nil {
log.Errorf("error while parsing trusted IPs: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)
authType := strings.ToLower(blockListCfg.Authentication.Type)

// If auth != none then we implement checks if not bypass them to the next handler
if authType != "none" {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
// If we can't parse the IP, we use the remote address as is as it most likely been set by the trusted proxies middleware
if err != nil {
ip = r.RemoteAddr
}

return
}
trustedIPs, err := getTrustedIPs(blockListCfg.Authentication.TrustedIPs)
if err != nil {
log.Errorf("error while parsing trusted IPs: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)

switch strings.ToLower(blockListCfg.Authentication.Type) {
case "ip_based":
if !networksContainIP(trustedIPs, ip) {
http.Error(w, "access denied", http.StatusForbidden)
return
}
case "basic":
if !satisfiesBasicAuth(r, blockListCfg.Authentication.User, blockListCfg.Authentication.Password) {
http.Error(w, "access denied", http.StatusForbidden)
return

switch authType {
case "ip_based":
if !networksContainIP(trustedIPs, ip) {
http.Error(w, "access denied", http.StatusForbidden)
return
}
case "basic":
if !satisfiesBasicAuth(r, blockListCfg.Authentication.User, blockListCfg.Authentication.Password) {
http.Error(w, "access denied", http.StatusForbidden)
return
}
}
case "", "none":
}

next.ServeHTTP(w, r)
}
}
Expand Down
4 changes: 2 additions & 2 deletions test/bouncer/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_tls_server(crowdsec, certs_dir, api_key_factory, bouncer, bm_cfg_factor
with bouncer(cfg) as bm:
bm.wait_for_lines_fnmatch([
"*Using API key auth*",
"*Starting server at 127.0.0.1:*"
"*listening on tcp server: 127.0.0.1:*"
])


Expand Down Expand Up @@ -94,7 +94,7 @@ def test_tls_mutual(crowdsec, certs_dir, bouncer, bm_cfg_factory, bouncer_under_
"*Starting crowdsec-blocklist-mirror*",
"*Using CA cert*",
"*Using cert auth with cert * and key *",
"*Starting server at 127.0.0.1:*"
"*listening on tcp server: 127.0.0.1:*"
])

# check that the bouncer is registered
Expand Down
Loading