Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Wildcard support for routes / methods #115

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion lib/agent/aikido_types/stats.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package aikido_types

import "sync"
import (
"regexp"
"sync"
)

type StatsDataType struct {
StatsMutex sync.Mutex
Expand Down Expand Up @@ -32,3 +35,8 @@ type RateLimitingValue struct {
UserCounts map[string]*RateLimitingCounts
IpCounts map[string]*RateLimitingCounts
}

type RateLimitingWildcardValue struct {
RouteRegex *regexp.Regexp
RateLimitingValue *RateLimitingValue
}
21 changes: 20 additions & 1 deletion lib/agent/cloud/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
. "main/globals"
"main/log"
"main/utils"
"regexp"
"strings"
"time"
)

Expand Down Expand Up @@ -42,6 +44,10 @@ func ResetHeartbeatTicker() {
}
}

func isWildcardEndpoint(route string) bool {
return strings.Contains(route, "*")
}

func UpdateRateLimitingConfig() {
globals.RateLimitingMutex.Lock()
defer globals.RateLimitingMutex.Unlock()
Expand All @@ -62,6 +68,7 @@ func UpdateRateLimitingConfig() {

log.Infof("Rate limiting endpoint config has changed: %v", newEndpointConfig)
delete(globals.RateLimitingMap, k)
delete(globals.RateLimitingWildcardMap, k)
}

if !newEndpointConfig.RateLimiting.Enabled {
Expand All @@ -76,20 +83,32 @@ func UpdateRateLimitingConfig() {
}

log.Infof("Got new rate limiting endpoint config and storing to map: %v", newEndpointConfig)
globals.RateLimitingMap[k] = &RateLimitingValue{
rateLimitingValue := &RateLimitingValue{
Config: RateLimitingConfig{
MaxRequests: newEndpointConfig.RateLimiting.MaxRequests,
WindowSizeInMinutes: newEndpointConfig.RateLimiting.WindowSizeInMS / MinRateLimitingIntervalInMs},
UserCounts: make(map[string]*RateLimitingCounts),
IpCounts: make(map[string]*RateLimitingCounts),
}

if isWildcardEndpoint(k.Route) {
routeRegex, err := regexp.Compile(k.Route)
if err != nil {
log.Warnf("Route regex is not compiling: %s", k.Route)
} else {
globals.RateLimitingWildcardMap[k] = &RateLimitingWildcardValue{RouteRegex: routeRegex, RateLimitingValue: rateLimitingValue}
}
}

globals.RateLimitingMap[k] = rateLimitingValue
}

for k := range globals.RateLimitingMap {
_, exists := UpdatedEndpoints[k]
if !exists {
log.Infof("Removed rate limiting entry as it is no longer part of the config: %v", k)
delete(globals.RateLimitingMap, k)
delete(globals.RateLimitingWildcardMap, k)
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions lib/agent/globals/globals.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,14 @@ var RoutesMutex sync.Mutex
var StatsData StatsDataType

// Rate limiting map, which holds the current rate limiting state for each configured route
// map[(method, route)] -> RateLimitingValue
var RateLimitingMap = make(map[RateLimitingKey]*RateLimitingValue)

// Rate limiting wildcard map, which holds the current rate limiting state for each configured wildcard route
// map[method] -> (RouteRegex, RateLimitingValue)
// method can also be '*'
var RateLimitingWildcardMap = make(map[RateLimitingKey]*RateLimitingWildcardValue)

// Rate limiting mutex used to sync access across the go routines
var RateLimitingMutex sync.RWMutex

Expand Down
2 changes: 1 addition & 1 deletion lib/agent/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ toolchain go1.23.3

require (
github.com/stretchr/testify v1.9.0
google.golang.org/grpc v1.69.0
google.golang.org/grpc v1.69.2
google.golang.org/protobuf v1.35.1
)

Expand Down
2 changes: 2 additions & 0 deletions lib/agent/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ google.golang.org/grpc v1.68.1 h1:oI5oTa11+ng8r8XMMN7jAOmWfPZWbYpCFaMUTACxkM0=
google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe00waw=
google.golang.org/grpc v1.69.0 h1:quSiOM1GJPmPH5XtU+BCoVXcDVJJAzNcoyfC2cCjGkI=
google.golang.org/grpc v1.69.0/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4=
google.golang.org/grpc v1.69.2 h1:U3S9QEtbXC0bYNvRtcoklF3xGtLViumSYxWykJS+7AU=
google.golang.org/grpc v1.69.2/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
Expand Down
55 changes: 39 additions & 16 deletions lib/agent/grpc/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,26 +118,49 @@ func isRateLimitingThresholdExceeded(config *RateLimitingConfig, countsMap map[s
return counts.TotalNumberOfRequests >= config.MaxRequests
}

func getRateLimitingStatus(method string, route string, user string, ip string) *protos.RateLimitingStatus {
globals.RateLimitingMutex.RLock()
defer globals.RateLimitingMutex.RUnlock()

rateLimitingDataForRoute, exists := globals.RateLimitingMap[RateLimitingKey{Method: method, Route: route}]
func getRateLimitingValue(method, route string) []*RateLimitingValue {
rateLimitingDataForEndpoint, exists := globals.RateLimitingMap[RateLimitingKey{Method: method, Route: route}]
if !exists {
return &protos.RateLimitingStatus{Block: false}
return []*RateLimitingValue{}
}
return []*RateLimitingValue{rateLimitingDataForEndpoint}
}

func getWildcardRateLimitingValues(method, route string) []*RateLimitingValue {
wildcardRatelimitingValues := []*RateLimitingValue{}

if user != "" {
// If the user exists, we only try to rate limit by user
if isRateLimitingThresholdExceeded(&rateLimitingDataForRoute.Config, rateLimitingDataForRoute.UserCounts, user) {
log.Infof("Rate limited request for user %s - %s %s - %v", user, method, route, rateLimitingDataForRoute.UserCounts[user])
return &protos.RateLimitingStatus{Block: true, Trigger: "user"}
for key, r := range globals.RateLimitingWildcardMap {
if key.Method != "*" && key.Method != method {
continue
}
if r.RouteRegex.MatchString(route) {
wildcardRatelimitingValues = append(wildcardRatelimitingValues, r.RateLimitingValue)
}
} else {
// Otherwise, we rate limit by ip
if isRateLimitingThresholdExceeded(&rateLimitingDataForRoute.Config, rateLimitingDataForRoute.IpCounts, ip) {
log.Infof("Rate limited request for ip %s - %s %s - %v", ip, method, route, rateLimitingDataForRoute.IpCounts[ip])
return &protos.RateLimitingStatus{Block: true, Trigger: "ip"}
}
return wildcardRatelimitingValues
}

func getRateLimitingStatus(method, route, user, ip string) *protos.RateLimitingStatus {
globals.RateLimitingMutex.RLock()
defer globals.RateLimitingMutex.RUnlock()

rateLimitingDataArray := getRateLimitingValue(method, route)
rateLimitingDataArray = append(rateLimitingDataArray, getRateLimitingValue("*", route)...)
rateLimitingDataArray = append(rateLimitingDataArray, getWildcardRateLimitingValues(method, route)...)

for _, rateLimitingDataForRoute := range rateLimitingDataArray {
if user != "" {
// If the user exists, we only try to rate limit by user
if isRateLimitingThresholdExceeded(&rateLimitingDataForRoute.Config, rateLimitingDataForRoute.UserCounts, user) {
log.Infof("Rate limited request for user %s - %s %s - %v", user, method, route, rateLimitingDataForRoute.UserCounts[user])
return &protos.RateLimitingStatus{Block: true, Trigger: "user"}
}
} else {
// Otherwise, we rate limit by ip
if isRateLimitingThresholdExceeded(&rateLimitingDataForRoute.Config, rateLimitingDataForRoute.IpCounts, ip) {
log.Infof("Rate limited request for ip %s - %s %s - %v", ip, method, route, rateLimitingDataForRoute.IpCounts[ip])
return &protos.RateLimitingStatus{Block: true, Trigger: "ip"}
}
}
}

Expand Down
29 changes: 22 additions & 7 deletions lib/request-processor/aikido_types/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package aikido_types

import "github.com/seancfoley/ipaddress-go/ipaddr"
import (
"regexp"

"github.com/seancfoley/ipaddress-go/ipaddr"
)

type EnvironmentConfigData struct {
SocketPath string `json:"socket_path"` // '/run/aikido-{version}/aikido-{datetime}-{randint}.sock'
Expand Down Expand Up @@ -31,6 +35,16 @@ type EndpointData struct {
AllowedIPAddresses map[string]bool
}

type EndpointDataStatus struct {
Data EndpointData
Found bool
}

type WildcardEndpointData struct {
RouteRegex *regexp.Regexp
Data EndpointData
}

type EndpointKey struct {
Method string
Route string
Expand All @@ -43,10 +57,11 @@ type IpBlockList struct {
}

type CloudConfigData struct {
ConfigUpdatedAt int64
Endpoints map[EndpointKey]EndpointData
BlockedUserIds map[string]bool
BypassedIps map[string]bool
BlockedIps map[string]IpBlockList
Block int
ConfigUpdatedAt int64
Endpoints map[EndpointKey]EndpointData
WildcardEndpoints map[string][]WildcardEndpointData
BlockedUserIds map[string]bool
BypassedIps map[string]bool
BlockedIps map[string]IpBlockList
Block int
}
109 changes: 100 additions & 9 deletions lib/request-processor/context/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package context
// #include "../../API.h"
import "C"
import (
. "main/aikido_types"
"main/helpers"
"main/log"
"main/utils"
Expand Down Expand Up @@ -123,18 +124,108 @@ func ContextSetUserName() {
ContextSetString(C.CONTEXT_USER_NAME, &Context.UserName)
}

func ContextSetIsProtectionTurnedOff() {
if Context.IsProtectionTurnedOff != nil {
func ContextSetEndpointConfig() {
if Context.EndpointConfig != nil {
return
}

method := GetMethod()
route := GetParsedRoute()
endpointConfig, endpointConfigFound := utils.GetEndpointConfig(GetMethod(), GetParsedRoute())
Context.EndpointConfig = &EndpointDataStatus{Data: endpointConfig, Found: endpointConfigFound}
}

func ContextSetWildcardEndpointsConfigs() {
if Context.WildcardEndpointsConfigs != nil {
return
}

wildcardEndpointsConfigs := utils.GetWildcardEndpointsConfigs(GetMethod(), GetParsedRoute())
Context.WildcardEndpointsConfigs = &wildcardEndpointsConfigs
}

endpointConfig, err := utils.GetEndpointConfig(method, route)
isProtectionTurnedOff := false
if err == nil {
isProtectionTurnedOff = endpointConfig.ForceProtectionOff
func ContextSetIsEndpointProtectionTurnedOff() {
if Context.IsEndpointProtectionTurnedOff != nil {
return
}

isEndpointProtectionTurnedOff := false

endpointConfig, found := GetEndpointConfig()
if found {
isEndpointProtectionTurnedOff = endpointConfig.ForceProtectionOff
}
if !isEndpointProtectionTurnedOff {
for _, wildcardEndpointConfig := range GetWildcardEndpointsConfig() {
if wildcardEndpointConfig.ForceProtectionOff {
isEndpointProtectionTurnedOff = true
break
}
}
}
Context.IsEndpointProtectionTurnedOff = &isEndpointProtectionTurnedOff
}

func ContextSetIsEndpointConfigured() {
if Context.IsEndpointConfigured != nil {
return
}

IsEndpointConfigured := false

_, found := GetEndpointConfig()
if found {
IsEndpointConfigured = true
}
if !IsEndpointConfigured {
if len(GetWildcardEndpointsConfig()) != 0 {
IsEndpointConfigured = true
}
}
Context.IsEndpointConfigured = &IsEndpointConfigured
}

func ContextSetIsEndpointRateLimitingEnabled() {
if Context.IsEndpointRateLimitingEnabled != nil {
return
}

IsEndpointRateLimitingEnabled := false

endpointData, found := GetEndpointConfig()
if found {
IsEndpointRateLimitingEnabled = endpointData.RateLimiting.Enabled
}
if !IsEndpointRateLimitingEnabled {
for _, wildcardEndpointConfig := range GetWildcardEndpointsConfig() {
if wildcardEndpointConfig.RateLimiting.Enabled {
IsEndpointRateLimitingEnabled = true
break
}
}
}
Context.IsEndpointRateLimitingEnabled = &IsEndpointRateLimitingEnabled
}

func ContextSetIsEndpointIpAllowed() {
if Context.IsEndpointIpAllowed != nil {
return
}

ip := GetIp()

isEndpointIpAllowed := true

endpointData, found := GetEndpointConfig()
if found {
isEndpointIpAllowed = utils.IsIpAllowed(endpointData.AllowedIPAddresses, ip)
}

if isEndpointIpAllowed {
for _, wildcardEndpointConfig := range GetWildcardEndpointsConfig() {
if !utils.IsIpAllowed(wildcardEndpointConfig.AllowedIPAddresses, ip) {
isEndpointIpAllowed = false
break
}
}
}
Context.IsProtectionTurnedOff = &isProtectionTurnedOff
Context.IsEndpointIpAllowed = &isEndpointIpAllowed
}
Loading
Loading