Skip to content

Commit

Permalink
improvement: application firewall now supports IPv4, IPv6, CIDR notation
Browse files Browse the repository at this point in the history
  • Loading branch information
pilinux committed Oct 2, 2024
1 parent 598e04c commit 4942008
Show file tree
Hide file tree
Showing 4 changed files with 401 additions and 40 deletions.
4 changes: 2 additions & 2 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ ACTIVATE_FIREWALL=yes
# Block one or several IPs [LISTTYPE=blacklist | IP=x.x.x.x]
LISTTYPE=whitelist
# LISTTYPE=blacklist
# IP - comma-separated list
# IP=192.168.0.1,10.0.0.1
# IP - comma-separated list, IPv4, IPv6, CIDR
# IP=192.168.0.1,10.0.0.1,172.16.0.0/12,2400:cb00::/32
IP=*

#
Expand Down
4 changes: 2 additions & 2 deletions example/.env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ ACTIVATE_FIREWALL=yes
# Block one or several IPs [LISTTYPE=blacklist | IP=x.x.x.x]
LISTTYPE=whitelist
# LISTTYPE=blacklist
# IP - comma-separated list
# IP=192.168.0.1,10.0.0.1
# IP - comma-separated list, IPv4, IPv6, CIDR
# IP=192.168.0.1,10.0.0.1,172.16.0.0/12,2400:cb00::/32
IP=*

#
Expand Down
131 changes: 127 additions & 4 deletions lib/middleware/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,77 @@ package middleware
// Copyright (c) 2022 pilinux

import (
"fmt"
"net"
"net/http"
"strings"
"sync"

"github.com/gin-gonic/gin"
)

// firewall package-level variables
var (
parsedOnce sync.Once
ipNets []*net.IPNet
ipListMap map[string]bool
ipCIDR bool
)

// Firewall - whitelist/blacklist IPs
func Firewall(listType string, ipList string) gin.HandlerFunc {
return func(c *gin.Context) {
// Get the real client IP
clientIP := c.ClientIP()
// parse the IP list only once
parsedOnce.Do(func() {
parseIPList(listType, ipList)
})

// get the real client IP
clientNetIP := net.ParseIP(c.ClientIP())
if clientNetIP == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, "IP invalid")
return
}
clientIP := clientNetIP.String()

if !strings.Contains(ipList, "*") {
if listType == "whitelist" {
if !strings.Contains(ipList, clientIP) {
var allowIP bool
if len(ipListMap) > 0 {
if _, ok := ipListMap[clientIP]; ok {
allowIP = true
}
}
if !allowIP && ipCIDR {
for _, ipNet := range ipNets {
if ipNet.Contains(clientNetIP) {
allowIP = true
break
}
}
}
if !allowIP {
c.AbortWithStatusJSON(http.StatusUnauthorized, "IP blocked")
return
}
}

if listType == "blacklist" {
if strings.Contains(ipList, clientIP) {
var blockIP bool
if len(ipListMap) > 0 {
if _, ok := ipListMap[clientIP]; ok {
blockIP = true
}
}
if !blockIP && ipCIDR {
for _, ipNet := range ipNets {
if ipNet.Contains(clientNetIP) {
blockIP = true
break
}
}
}
if blockIP {
c.AbortWithStatusJSON(http.StatusUnauthorized, "IP blocked")
return
}
Expand All @@ -43,3 +92,77 @@ func Firewall(listType string, ipList string) gin.HandlerFunc {
c.Next()
}
}

// helper function to parse the IP list and CIDR notations
func parseIPList(listType, ipList string) {
ipListMap = make(map[string]bool)

// split the list by comma and trim spaces
ipListSlice := strings.Split(ipList, ",")
for _, ip := range ipListSlice {
ip = strings.TrimSpace(ip)
if ip == "" {
continue
}
if strings.Contains(ip, "/") {
// parse CIDR notations
_, ipNet, err := net.ParseCIDR(ip)
if err == nil {
ipNets = append(ipNets, ipNet)
}
} else {
ipListMap[ip] = true
}
}

// if any CIDR notations were found, set ipCIDR to true
if len(ipNets) > 0 {
ipCIDR = true
}

var validIPs string
var validCIDRs string
for ip := range ipListMap {
validIPs += ip + ", "
}
for _, ipNet := range ipNets {
validCIDRs += ipNet.String() + ", "
}
// remove the trailing comma and space
validIPs = strings.TrimSuffix(validIPs, ", ")
validCIDRs = strings.TrimSuffix(validCIDRs, ", ")

fmt.Println("application firewall initialized")
if listType == "whitelist" {
if strings.Contains(validIPs, "*") {
fmt.Println("whitelisted IPs: *")
} else {
if len(validIPs) > 0 {
fmt.Println("whitelisted IPs:", validIPs)
}
if len(validCIDRs) > 0 {
fmt.Println("whitelisted CIDRs:", validCIDRs)
}
}
}
if listType == "blacklist" {
if strings.Contains(validIPs, "*") {
fmt.Println("blacklisted IPs: *")
} else {
if len(validIPs) > 0 {
fmt.Println("blacklisted IPs:", validIPs)
}
if len(validCIDRs) > 0 {
fmt.Println("blacklisted CIDRs:", validCIDRs)
}
}
}
}

// ResetFirewallState - helper function to reset firewall package-level variables
func ResetFirewallState() {
parsedOnce = sync.Once{}
ipNets = nil
ipListMap = nil
ipCIDR = false
}
Loading

0 comments on commit 4942008

Please sign in to comment.