diff --git a/lib/agent/aikido_types/stats.go b/lib/agent/aikido_types/stats.go index 0fc0f937..6bcd4e85 100644 --- a/lib/agent/aikido_types/stats.go +++ b/lib/agent/aikido_types/stats.go @@ -1,6 +1,9 @@ package aikido_types -import "sync" +import ( + "regexp" + "sync" +) type StatsDataType struct { StatsMutex sync.Mutex @@ -32,3 +35,8 @@ type RateLimitingValue struct { UserCounts map[string]*RateLimitingCounts IpCounts map[string]*RateLimitingCounts } + +type RateLimitingWildcardValue struct { + RouteRegex *regexp.Regexp + RateLimitingValue *RateLimitingValue +} diff --git a/lib/agent/cloud/common.go b/lib/agent/cloud/common.go index 738f2d08..6156754e 100644 --- a/lib/agent/cloud/common.go +++ b/lib/agent/cloud/common.go @@ -7,6 +7,8 @@ import ( . "main/globals" "main/log" "main/utils" + "regexp" + "strings" "time" ) @@ -42,6 +44,10 @@ func ResetHeartbeatTicker() { } } +func isWildcardEndpoint(route string) bool { + return strings.Contains(route, "*") +} + func UpdateRateLimitingConfig() { globals.RateLimitingMutex.Lock() defer globals.RateLimitingMutex.Unlock() @@ -62,6 +68,7 @@ func UpdateRateLimitingConfig() { log.Infof("Rate limiting endpoint config has changed: %v", newEndpointConfig) delete(globals.RateLimitingMap, k) + delete(globals.RateLimitingWildcardMap, k) } if !newEndpointConfig.RateLimiting.Enabled { @@ -76,13 +83,24 @@ func UpdateRateLimitingConfig() { } log.Infof("Got new rate limiting endpoint config and storing to map: %v", newEndpointConfig) - globals.RateLimitingMap[k] = &RateLimitingValue{ + rateLimitingValue := &RateLimitingValue{ Config: RateLimitingConfig{ MaxRequests: newEndpointConfig.RateLimiting.MaxRequests, WindowSizeInMinutes: newEndpointConfig.RateLimiting.WindowSizeInMS / MinRateLimitingIntervalInMs}, UserCounts: make(map[string]*RateLimitingCounts), IpCounts: make(map[string]*RateLimitingCounts), } + + if isWildcardEndpoint(k.Route) { + routeRegex, err := regexp.Compile(k.Route) + if err != nil { + log.Warnf("Route regex is not compiling: %s", k.Route) + } else { + globals.RateLimitingWildcardMap[k] = &RateLimitingWildcardValue{RouteRegex: routeRegex, RateLimitingValue: rateLimitingValue} + } + } + + globals.RateLimitingMap[k] = rateLimitingValue } for k := range globals.RateLimitingMap { @@ -90,6 +108,7 @@ func UpdateRateLimitingConfig() { if !exists { log.Infof("Removed rate limiting entry as it is no longer part of the config: %v", k) delete(globals.RateLimitingMap, k) + delete(globals.RateLimitingWildcardMap, k) } } } diff --git a/lib/agent/globals/globals.go b/lib/agent/globals/globals.go index 30ab0a72..a89d9751 100644 --- a/lib/agent/globals/globals.go +++ b/lib/agent/globals/globals.go @@ -37,8 +37,14 @@ var RoutesMutex sync.Mutex var StatsData StatsDataType // Rate limiting map, which holds the current rate limiting state for each configured route +// map[(method, route)] -> RateLimitingValue var RateLimitingMap = make(map[RateLimitingKey]*RateLimitingValue) +// Rate limiting wildcard map, which holds the current rate limiting state for each configured wildcard route +// map[method] -> (RouteRegex, RateLimitingValue) +// method can also be '*' +var RateLimitingWildcardMap = make(map[RateLimitingKey]*RateLimitingWildcardValue) + // Rate limiting mutex used to sync access across the go routines var RateLimitingMutex sync.RWMutex diff --git a/lib/agent/go.mod b/lib/agent/go.mod index c71b80b5..602e2251 100644 --- a/lib/agent/go.mod +++ b/lib/agent/go.mod @@ -6,7 +6,7 @@ toolchain go1.23.3 require ( github.com/stretchr/testify v1.9.0 - google.golang.org/grpc v1.69.0 + google.golang.org/grpc v1.69.2 google.golang.org/protobuf v1.35.1 ) diff --git a/lib/agent/go.sum b/lib/agent/go.sum index 1f0c485c..d634f6d8 100644 --- a/lib/agent/go.sum +++ b/lib/agent/go.sum @@ -38,6 +38,8 @@ google.golang.org/grpc v1.68.1 h1:oI5oTa11+ng8r8XMMN7jAOmWfPZWbYpCFaMUTACxkM0= google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe00waw= google.golang.org/grpc v1.69.0 h1:quSiOM1GJPmPH5XtU+BCoVXcDVJJAzNcoyfC2cCjGkI= google.golang.org/grpc v1.69.0/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4= +google.golang.org/grpc v1.69.2 h1:U3S9QEtbXC0bYNvRtcoklF3xGtLViumSYxWykJS+7AU= +google.golang.org/grpc v1.69.2/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= diff --git a/lib/agent/grpc/request.go b/lib/agent/grpc/request.go index f399cf3e..0814ec17 100644 --- a/lib/agent/grpc/request.go +++ b/lib/agent/grpc/request.go @@ -118,26 +118,49 @@ func isRateLimitingThresholdExceeded(config *RateLimitingConfig, countsMap map[s return counts.TotalNumberOfRequests >= config.MaxRequests } -func getRateLimitingStatus(method string, route string, user string, ip string) *protos.RateLimitingStatus { - globals.RateLimitingMutex.RLock() - defer globals.RateLimitingMutex.RUnlock() - - rateLimitingDataForRoute, exists := globals.RateLimitingMap[RateLimitingKey{Method: method, Route: route}] +func getRateLimitingValue(method, route string) []*RateLimitingValue { + rateLimitingDataForEndpoint, exists := globals.RateLimitingMap[RateLimitingKey{Method: method, Route: route}] if !exists { - return &protos.RateLimitingStatus{Block: false} + return []*RateLimitingValue{} } + return []*RateLimitingValue{rateLimitingDataForEndpoint} +} + +func getWildcardRateLimitingValues(method, route string) []*RateLimitingValue { + wildcardRatelimitingValues := []*RateLimitingValue{} - if user != "" { - // If the user exists, we only try to rate limit by user - if isRateLimitingThresholdExceeded(&rateLimitingDataForRoute.Config, rateLimitingDataForRoute.UserCounts, user) { - log.Infof("Rate limited request for user %s - %s %s - %v", user, method, route, rateLimitingDataForRoute.UserCounts[user]) - return &protos.RateLimitingStatus{Block: true, Trigger: "user"} + for key, r := range globals.RateLimitingWildcardMap { + if key.Method != "*" && key.Method != method { + continue + } + if r.RouteRegex.MatchString(route) { + wildcardRatelimitingValues = append(wildcardRatelimitingValues, r.RateLimitingValue) } - } else { - // Otherwise, we rate limit by ip - if isRateLimitingThresholdExceeded(&rateLimitingDataForRoute.Config, rateLimitingDataForRoute.IpCounts, ip) { - log.Infof("Rate limited request for ip %s - %s %s - %v", ip, method, route, rateLimitingDataForRoute.IpCounts[ip]) - return &protos.RateLimitingStatus{Block: true, Trigger: "ip"} + } + return wildcardRatelimitingValues +} + +func getRateLimitingStatus(method, route, user, ip string) *protos.RateLimitingStatus { + globals.RateLimitingMutex.RLock() + defer globals.RateLimitingMutex.RUnlock() + + rateLimitingDataArray := getRateLimitingValue(method, route) + rateLimitingDataArray = append(rateLimitingDataArray, getRateLimitingValue("*", route)...) + rateLimitingDataArray = append(rateLimitingDataArray, getWildcardRateLimitingValues(method, route)...) + + for _, rateLimitingDataForRoute := range rateLimitingDataArray { + if user != "" { + // If the user exists, we only try to rate limit by user + if isRateLimitingThresholdExceeded(&rateLimitingDataForRoute.Config, rateLimitingDataForRoute.UserCounts, user) { + log.Infof("Rate limited request for user %s - %s %s - %v", user, method, route, rateLimitingDataForRoute.UserCounts[user]) + return &protos.RateLimitingStatus{Block: true, Trigger: "user"} + } + } else { + // Otherwise, we rate limit by ip + if isRateLimitingThresholdExceeded(&rateLimitingDataForRoute.Config, rateLimitingDataForRoute.IpCounts, ip) { + log.Infof("Rate limited request for ip %s - %s %s - %v", ip, method, route, rateLimitingDataForRoute.IpCounts[ip]) + return &protos.RateLimitingStatus{Block: true, Trigger: "ip"} + } } } diff --git a/lib/request-processor/aikido_types/config.go b/lib/request-processor/aikido_types/config.go index 5fccfcdc..368a8234 100644 --- a/lib/request-processor/aikido_types/config.go +++ b/lib/request-processor/aikido_types/config.go @@ -1,6 +1,10 @@ package aikido_types -import "github.com/seancfoley/ipaddress-go/ipaddr" +import ( + "regexp" + + "github.com/seancfoley/ipaddress-go/ipaddr" +) type EnvironmentConfigData struct { SocketPath string `json:"socket_path"` // '/run/aikido-{version}/aikido-{datetime}-{randint}.sock' @@ -31,6 +35,16 @@ type EndpointData struct { AllowedIPAddresses map[string]bool } +type EndpointDataStatus struct { + Data EndpointData + Found bool +} + +type WildcardEndpointData struct { + RouteRegex *regexp.Regexp + Data EndpointData +} + type EndpointKey struct { Method string Route string @@ -43,10 +57,11 @@ type IpBlockList struct { } type CloudConfigData struct { - ConfigUpdatedAt int64 - Endpoints map[EndpointKey]EndpointData - BlockedUserIds map[string]bool - BypassedIps map[string]bool - BlockedIps map[string]IpBlockList - Block int + ConfigUpdatedAt int64 + Endpoints map[EndpointKey]EndpointData + WildcardEndpoints map[string][]WildcardEndpointData + BlockedUserIds map[string]bool + BypassedIps map[string]bool + BlockedIps map[string]IpBlockList + Block int } diff --git a/lib/request-processor/context/cache.go b/lib/request-processor/context/cache.go index b5f9e1fc..a266b591 100644 --- a/lib/request-processor/context/cache.go +++ b/lib/request-processor/context/cache.go @@ -3,6 +3,7 @@ package context // #include "../../API.h" import "C" import ( + . "main/aikido_types" "main/helpers" "main/log" "main/utils" @@ -123,18 +124,108 @@ func ContextSetUserName() { ContextSetString(C.CONTEXT_USER_NAME, &Context.UserName) } -func ContextSetIsProtectionTurnedOff() { - if Context.IsProtectionTurnedOff != nil { +func ContextSetEndpointConfig() { + if Context.EndpointConfig != nil { return } - method := GetMethod() - route := GetParsedRoute() + endpointConfig, endpointConfigFound := utils.GetEndpointConfig(GetMethod(), GetParsedRoute()) + Context.EndpointConfig = &EndpointDataStatus{Data: endpointConfig, Found: endpointConfigFound} +} + +func ContextSetWildcardEndpointsConfigs() { + if Context.WildcardEndpointsConfigs != nil { + return + } + + wildcardEndpointsConfigs := utils.GetWildcardEndpointsConfigs(GetMethod(), GetParsedRoute()) + Context.WildcardEndpointsConfigs = &wildcardEndpointsConfigs +} - endpointConfig, err := utils.GetEndpointConfig(method, route) - isProtectionTurnedOff := false - if err == nil { - isProtectionTurnedOff = endpointConfig.ForceProtectionOff +func ContextSetIsEndpointProtectionTurnedOff() { + if Context.IsEndpointProtectionTurnedOff != nil { + return + } + + isEndpointProtectionTurnedOff := false + + endpointConfig, found := GetEndpointConfig() + if found { + isEndpointProtectionTurnedOff = endpointConfig.ForceProtectionOff + } + if !isEndpointProtectionTurnedOff { + for _, wildcardEndpointConfig := range GetWildcardEndpointsConfig() { + if wildcardEndpointConfig.ForceProtectionOff { + isEndpointProtectionTurnedOff = true + break + } + } + } + Context.IsEndpointProtectionTurnedOff = &isEndpointProtectionTurnedOff +} + +func ContextSetIsEndpointConfigured() { + if Context.IsEndpointConfigured != nil { + return + } + + IsEndpointConfigured := false + + _, found := GetEndpointConfig() + if found { + IsEndpointConfigured = true + } + if !IsEndpointConfigured { + if len(GetWildcardEndpointsConfig()) != 0 { + IsEndpointConfigured = true + } + } + Context.IsEndpointConfigured = &IsEndpointConfigured +} + +func ContextSetIsEndpointRateLimitingEnabled() { + if Context.IsEndpointRateLimitingEnabled != nil { + return + } + + IsEndpointRateLimitingEnabled := false + + endpointData, found := GetEndpointConfig() + if found { + IsEndpointRateLimitingEnabled = endpointData.RateLimiting.Enabled + } + if !IsEndpointRateLimitingEnabled { + for _, wildcardEndpointConfig := range GetWildcardEndpointsConfig() { + if wildcardEndpointConfig.RateLimiting.Enabled { + IsEndpointRateLimitingEnabled = true + break + } + } + } + Context.IsEndpointRateLimitingEnabled = &IsEndpointRateLimitingEnabled +} + +func ContextSetIsEndpointIpAllowed() { + if Context.IsEndpointIpAllowed != nil { + return + } + + ip := GetIp() + + isEndpointIpAllowed := true + + endpointData, found := GetEndpointConfig() + if found { + isEndpointIpAllowed = utils.IsIpAllowed(endpointData.AllowedIPAddresses, ip) + } + + if isEndpointIpAllowed { + for _, wildcardEndpointConfig := range GetWildcardEndpointsConfig() { + if !utils.IsIpAllowed(wildcardEndpointConfig.AllowedIPAddresses, ip) { + isEndpointIpAllowed = false + break + } + } } - Context.IsProtectionTurnedOff = &isProtectionTurnedOff + Context.IsEndpointIpAllowed = &isEndpointIpAllowed } diff --git a/lib/request-processor/context/request_context.go b/lib/request-processor/context/request_context.go index bc7f3381..b05113ff 100644 --- a/lib/request-processor/context/request_context.go +++ b/lib/request-processor/context/request_context.go @@ -3,6 +3,7 @@ package context // #include "../../API.h" import "C" import ( + . "main/aikido_types" "main/log" ) @@ -10,27 +11,32 @@ type CallbackFunction func(int) string /* Request level context cache (changes on each PHP request) */ type RequestContextData struct { - Callback CallbackFunction // callback to access data from the PHP layer (C++ extension) about the current request and current event - Method *string - Route *string - RouteParsed *string - URL *string - StatusCode *int - IP *string - IsIpBypassed *bool - IsProtectionTurnedOff *bool - UserAgent *string - UserId *string - UserName *string - BodyRaw *string - BodyParsed *map[string]interface{} - BodyParsedFlattened *map[string]string - QueryParsed *map[string]interface{} - QueryParsedFlattened *map[string]string - CookiesParsed *map[string]interface{} - CookiesParsedFlattened *map[string]string - HeadersParsed *map[string]interface{} - HeadersParsedFlattened *map[string]string + Callback CallbackFunction // callback to access data from the PHP layer (C++ extension) about the current request and current event + Method *string + Route *string + RouteParsed *string + URL *string + StatusCode *int + IP *string + EndpointConfig *EndpointDataStatus + WildcardEndpointsConfigs *[]EndpointData + IsIpBypassed *bool + IsEndpointConfigured *bool + IsEndpointRateLimitingEnabled *bool + IsEndpointProtectionTurnedOff *bool + IsEndpointIpAllowed *bool + UserAgent *string + UserId *string + UserName *string + BodyRaw *string + BodyParsed *map[string]interface{} + BodyParsedFlattened *map[string]string + QueryParsed *map[string]interface{} + QueryParsedFlattened *map[string]string + CookiesParsed *map[string]interface{} + CookiesParsedFlattened *map[string]string + HeadersParsed *map[string]interface{} + HeadersParsedFlattened *map[string]string } var Context RequestContextData @@ -137,6 +143,27 @@ func GetUserName() string { return GetFromCache(ContextSetUserName, &Context.UserName) } -func IsProtectionTurnedOff() bool { - return GetFromCache(ContextSetIsProtectionTurnedOff, &Context.IsProtectionTurnedOff) +func GetEndpointConfig() (EndpointData, bool) { + endpointDataStatus := GetFromCache(ContextSetEndpointConfig, &Context.EndpointConfig) + return endpointDataStatus.Data, endpointDataStatus.Found +} + +func GetWildcardEndpointsConfig() []EndpointData { + return GetFromCache(ContextSetWildcardEndpointsConfigs, &Context.WildcardEndpointsConfigs) +} + +func IsEndpointConfigured() bool { + return GetFromCache(ContextSetIsEndpointConfigured, &Context.IsEndpointConfigured) +} + +func IsEndpointRateLimitingEnabled() bool { + return GetFromCache(ContextSetIsEndpointRateLimitingEnabled, &Context.IsEndpointRateLimitingEnabled) +} + +func IsEndpointIpAllowed() bool { + return GetFromCache(ContextSetIsEndpointIpAllowed, &Context.IsEndpointIpAllowed) +} + +func IsEndpointProtectionTurnedOff() bool { + return GetFromCache(ContextSetIsEndpointProtectionTurnedOff, &Context.IsEndpointProtectionTurnedOff) } diff --git a/lib/request-processor/go.mod b/lib/request-processor/go.mod index 4e25fc26..a3e9848c 100644 --- a/lib/request-processor/go.mod +++ b/lib/request-processor/go.mod @@ -6,7 +6,7 @@ toolchain go1.23.3 require ( github.com/stretchr/testify v1.10.0 - google.golang.org/grpc v1.69.0 + google.golang.org/grpc v1.69.2 google.golang.org/protobuf v1.35.1 ) diff --git a/lib/request-processor/go.sum b/lib/request-processor/go.sum index 8a26c38d..c5d90604 100644 --- a/lib/request-processor/go.sum +++ b/lib/request-processor/go.sum @@ -47,6 +47,8 @@ google.golang.org/grpc v1.68.1 h1:oI5oTa11+ng8r8XMMN7jAOmWfPZWbYpCFaMUTACxkM0= google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe00waw= google.golang.org/grpc v1.69.0 h1:quSiOM1GJPmPH5XtU+BCoVXcDVJJAzNcoyfC2cCjGkI= google.golang.org/grpc v1.69.0/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4= +google.golang.org/grpc v1.69.2 h1:U3S9QEtbXC0bYNvRtcoklF3xGtLViumSYxWykJS+7AU= +google.golang.org/grpc v1.69.2/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= diff --git a/lib/request-processor/grpc/config.go b/lib/request-processor/grpc/config.go index c8c36839..7e873b18 100644 --- a/lib/request-processor/grpc/config.go +++ b/lib/request-processor/grpc/config.go @@ -5,6 +5,8 @@ import ( "main/globals" "main/ipc/protos" "main/log" + "regexp" + "strings" "time" "github.com/seancfoley/ipaddress-go/ipaddr" @@ -41,6 +43,42 @@ func buildIpBlocklist(name, description string, ipsList []string) IpBlockList { return ipBlocklist } +func getEndpointData(ep *protos.Endpoint) EndpointData { + endpointData := EndpointData{ + ForceProtectionOff: ep.ForceProtectionOff, + RateLimiting: RateLimiting{ + Enabled: ep.RateLimiting.Enabled, + }, + AllowedIPAddresses: map[string]bool{}, + } + for _, ip := range ep.AllowedIPAddresses { + endpointData.AllowedIPAddresses[ip] = true + } + return endpointData +} + +func storeEndpointConfig(ep *protos.Endpoint) { + globals.CloudConfig.Endpoints[EndpointKey{Method: ep.Method, Route: ep.Route}] = getEndpointData(ep) +} + +func storeWildcardEndpointConfig(ep *protos.Endpoint) { + wildcardRouteCompiled, err := regexp.Compile(ep.Route) + if err != nil { + return + } + + wildcardRoutes, exists := globals.CloudConfig.WildcardEndpoints[ep.Method] + if !exists { + globals.CloudConfig.WildcardEndpoints[ep.Method] = []WildcardEndpointData{} + } + + globals.CloudConfig.WildcardEndpoints[ep.Method] = append(wildcardRoutes, WildcardEndpointData{RouteRegex: wildcardRouteCompiled, Data: getEndpointData(ep)}) +} + +func isWildcardEndpoint(method, route string) bool { + return method == "*" || strings.Contains(route, "*") +} + func setCloudConfig(cloudConfigFromAgent *protos.CloudConfig) { if cloudConfigFromAgent == nil { return @@ -52,18 +90,14 @@ func setCloudConfig(cloudConfigFromAgent *protos.CloudConfig) { globals.CloudConfig.ConfigUpdatedAt = cloudConfigFromAgent.ConfigUpdatedAt globals.CloudConfig.Endpoints = map[EndpointKey]EndpointData{} + globals.CloudConfig.WildcardEndpoints = map[string][]WildcardEndpointData{} + for _, ep := range cloudConfigFromAgent.Endpoints { - endpointData := EndpointData{ - ForceProtectionOff: ep.ForceProtectionOff, - RateLimiting: RateLimiting{ - Enabled: ep.RateLimiting.Enabled, - }, - AllowedIPAddresses: map[string]bool{}, - } - for _, ip := range ep.AllowedIPAddresses { - endpointData.AllowedIPAddresses[ip] = true + if isWildcardEndpoint(ep.Method, ep.Route) { + storeWildcardEndpointConfig(ep) + } else { + storeEndpointConfig(ep) } - globals.CloudConfig.Endpoints[EndpointKey{Method: ep.Method, Route: ep.Route}] = endpointData } globals.CloudConfig.BlockedUserIds = map[string]bool{} diff --git a/lib/request-processor/handle_blocking_request.go b/lib/request-processor/handle_blocking_request.go index 9744644b..89117f85 100644 --- a/lib/request-processor/handle_blocking_request.go +++ b/lib/request-processor/handle_blocking_request.go @@ -48,13 +48,12 @@ func OnGetBlockingStatus() string { return "" } - endpointData, err := utils.GetEndpointConfig(method, route) - if err != nil { + if context.IsEndpointConfigured() { log.Debugf("Method+route is not configured in endpoints! Skipping checks...") return "" } - if endpointData.RateLimiting.Enabled { + if context.IsEndpointRateLimitingEnabled() { if !context.IsIpBypassed() { // If request is monitored for rate limiting and the IP is not bypassed, // do a sync call via gRPC to see if the request should be blocked or not @@ -68,8 +67,8 @@ func OnGetBlockingStatus() string { } } - if !utils.IsIpAllowed(endpointData.AllowedIPAddresses, ip) { - log.Infof("IP \"%s\" is not allowd to access this endpoint!", ip) + if !context.IsEndpointIpAllowed() { + log.Infof("IP \"%s\" is not allowed to access this endpoint!", ip) return GetStoreAction("blocked", "ip", "not allowed by config to access this endpoint", ip) } return "" diff --git a/lib/request-processor/handle_path_traversal.go b/lib/request-processor/handle_path_traversal.go index 89c4c1f6..5267066b 100644 --- a/lib/request-processor/handle_path_traversal.go +++ b/lib/request-processor/handle_path_traversal.go @@ -16,7 +16,7 @@ func OnPrePathAccessed() string { return "" } - if context.IsProtectionTurnedOff() { + if context.IsEndpointProtectionTurnedOff() { log.Infof("Protection is turned off -> will not run detection logic!") return "" } diff --git a/lib/request-processor/handle_pdo.go b/lib/request-processor/handle_pdo.go index 47cd24ac..2d0fa371 100644 --- a/lib/request-processor/handle_pdo.go +++ b/lib/request-processor/handle_pdo.go @@ -16,7 +16,7 @@ func OnPreSqlQueryExecuted() string { } log.Info("Got PDO query: ", query, " dialect: ", dialect) - if context.IsProtectionTurnedOff() { + if context.IsEndpointProtectionTurnedOff() { log.Infof("Protection is turned off -> will not run detection logic!") return "" } diff --git a/lib/request-processor/handle_shell_execution.go b/lib/request-processor/handle_shell_execution.go index e78c11ad..f120c555 100644 --- a/lib/request-processor/handle_shell_execution.go +++ b/lib/request-processor/handle_shell_execution.go @@ -16,7 +16,7 @@ func OnPreShellExecuted() string { log.Info("Got shell command: ", cmd) - if context.IsProtectionTurnedOff() { + if context.IsEndpointProtectionTurnedOff() { log.Infof("Protection is turned off -> will not run detection logic!") return "" } diff --git a/lib/request-processor/handle_urls.go b/lib/request-processor/handle_urls.go index 1175558c..5d4be48b 100644 --- a/lib/request-processor/handle_urls.go +++ b/lib/request-processor/handle_urls.go @@ -21,7 +21,7 @@ Protects both curl and fopen wrapper functions (file_get_contents, etc...). func OnPreOutgoingRequest() string { defer context.ResetEventContext() - if context.IsProtectionTurnedOff() { + if context.IsEndpointProtectionTurnedOff() { log.Infof("Protection is turned off -> will not run detection logic!") return "" } @@ -72,7 +72,7 @@ func OnPostOutgoingRequest() string { go grpc.OnDomain(effectiveHostname, effectivePort) } - if context.IsProtectionTurnedOff() { + if context.IsEndpointProtectionTurnedOff() { log.Infof("Protection is turned off -> will not run detection logic!") return "" } diff --git a/lib/request-processor/utils/config.go b/lib/request-processor/utils/config.go index 2bb400b0..1593e61b 100644 --- a/lib/request-processor/utils/config.go +++ b/lib/request-processor/utils/config.go @@ -1,21 +1,44 @@ package utils import ( - "errors" . "main/aikido_types" "main/globals" ) -func GetEndpointConfig(method string, route string) (EndpointData, error) { +func GetWildcardEndpointsConfigsForMethod(method string) []WildcardEndpointData { + wildcardRoutesForMethod, found := globals.CloudConfig.WildcardEndpoints[method] + if !found { + return []WildcardEndpointData{} + } + return wildcardRoutesForMethod +} + +func GetWildcardEndpointsConfigs(method string, route string) []EndpointData { + globals.CloudConfigMutex.Lock() + defer globals.CloudConfigMutex.Unlock() + + wildcardRoutes := GetWildcardEndpointsConfigsForMethod(method) + wildcardRoutes = append(wildcardRoutes, GetWildcardEndpointsConfigsForMethod("*")...) + + matchedEndpointsData := []EndpointData{} + for _, wildcardEndpointData := range wildcardRoutes { + if wildcardEndpointData.RouteRegex.MatchString(route) { + matchedEndpointsData = append(matchedEndpointsData, wildcardEndpointData.Data) + } + } + return matchedEndpointsData +} + +func GetEndpointConfig(method string, route string) (EndpointData, bool) { globals.CloudConfigMutex.Lock() defer globals.CloudConfigMutex.Unlock() endpointData, exists := globals.CloudConfig.Endpoints[EndpointKey{Method: method, Route: route}] if !exists { - return EndpointData{}, errors.New("endpoint does not exist") + return EndpointData{}, false } - return endpointData, nil + return endpointData, true } func GetCloudConfigUpdatedAt() int64 {