Skip to content

Commit

Permalink
Merge pull request #142 from AikidoSec/AIK-3458
Browse files Browse the repository at this point in the history
AIK-3458 Report max x attacks per timeframe to core
  • Loading branch information
willem-delbare authored Sep 4, 2024
2 parents 11dd9b6 + 27bced0 commit 653013a
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 2 deletions.
10 changes: 8 additions & 2 deletions aikido_firewall/background_process/aikido_background_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
)
from aikido_firewall.helpers.check_env_for_blocking import check_env_for_blocking
from aikido_firewall.helpers.token import get_token_from_env
from aikido_firewall.background_process.api.http_api import ReportingApiHTTP
from aikido_firewall.background_process.api.http_api_ratelimited import (
ReportingApiHTTPRatelimited,
)
from .commands import process_incoming_command

EMPTY_QUEUE_INTERVAL = 5 # 5 seconds
Expand Down Expand Up @@ -70,7 +72,11 @@ def reporting_thread(self):
) # Create an event scheduler
self.send_to_connection_manager(event_scheduler)

api = ReportingApiHTTP("https://guard.aikido.dev/")
api = ReportingApiHTTPRatelimited(
"https://guard.aikido.dev/",
max_events_per_interval=100,
interval_in_ms=60 * 60 * 1000,
)
# We need to pass along the scheduler so that the heartbeat also gets sent
self.connection_manager = CloudConnectionManager(
block=check_env_for_blocking(),
Expand Down
33 changes: 33 additions & 0 deletions aikido_firewall/background_process/api/http_api_ratelimited.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Exports ReportingApiHTTPRatelimited
"""

import aikido_firewall.background_process.api.http_api as http_api
import aikido_firewall.helpers.get_current_unixtime_ms as t


class ReportingApiHTTPRatelimited(http_api.ReportingApiHTTP):
"""HTTP Reporting API that has ratelimiting support"""

def __init__(self, reporting_url, max_events_per_interval, interval_in_ms):
super().__init__(reporting_url)
self.interval_in_ms = interval_in_ms
self.max_events_per_interval = max_events_per_interval
self.events = []

def report(self, token, event, timeout_in_sec):
if event["type"] == "detected_attack":
# Remove all outdated events :
current_time = t.get_unixtime_ms()

def event_in_interval_filter(e):
return e["time"] > current_time - self.interval_in_ms

self.events = list(filter(event_in_interval_filter, self.events))

# Check if interval is exceeded :
if len(self.events) >= self.max_events_per_interval:
return {"success": False, "error": "max_attacks_reached"}

self.events.append(event)
return super().report(token, event, timeout_in_sec)
207 changes: 207 additions & 0 deletions aikido_firewall/background_process/api/http_api_ratelimited_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
import pytest
from unittest.mock import patch
from aikido_firewall.background_process.api.http_api import ReportingApiHTTP
from aikido_firewall.helpers.get_current_unixtime_ms import get_unixtime_ms
from .http_api_ratelimited import ReportingApiHTTPRatelimited


@pytest.fixture
def reporting_api():
"""Fixture to create an instance of ReportingApiHTTPRatelimited."""
return ReportingApiHTTPRatelimited(
reporting_url="http://example.com",
max_events_per_interval=3,
interval_in_ms=10000,
)


def test_report_within_limit(reporting_api):
"""Test reporting within the rate limit."""
event = {"type": "detected_attack", "time": 1000}

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=2000,
):
response = reporting_api.report(
token="token", event=event, timeout_in_sec=5
)
assert response == {"success": True}
assert len(reporting_api.events) == 1
mock_report.assert_called_once_with("token", event, 5)


def test_report_exceeds_limit(reporting_api):
"""Test reporting when the limit is exceeded."""
event = {"type": "detected_attack", "time": 1000}

# Simulate adding events to reach the limit
reporting_api.events = [
{"type": "detected_attack", "time": 1000},
{"type": "detected_attack", "time": 2000},
{"type": "detected_attack", "time": 3000},
]

with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=4000,
):
response = reporting_api.report(token="token", event=event, timeout_in_sec=5)
assert response == {"success": False, "error": "max_attacks_reached"}
assert len(reporting_api.events) == 3 # Should not add the new event


def test_report_within_limit_after_expiry(reporting_api):
"""Test reporting after some events have expired."""
event1 = {"type": "detected_attack", "time": 1000}
event2 = {"type": "detected_attack", "time": 2001}

# Add events to the list
reporting_api.events = [event1, event2]

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=12000,
):
event3 = {"type": "detected_attack", "time": 11000}
response = reporting_api.report(
token="token", event=event3, timeout_in_sec=5
)
assert response == {"success": True}
assert (
len(reporting_api.events) == 2
) # One event should have expired, and the new one is added


def test_report_with_non_attack_event(reporting_api):
"""Test reporting with a non-attack event."""
event = {"type": "other_event", "time": 1000}

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
response = reporting_api.report(token="token", event=event, timeout_in_sec=5)
assert response == {"success": True}
assert len(reporting_api.events) == 0 # Non-attack events should not be stored
mock_report.assert_called_once_with("token", event, 5)


def test_report_multiple_events_within_limit(reporting_api):
"""Test reporting multiple events within the rate limit."""
events = [
{"type": "detected_attack", "time": 1000},
{"type": "detected_attack", "time": 2000},
{"type": "detected_attack", "time": 3000},
]

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=4000,
):
for event in events:
response = reporting_api.report(
token="token", event=event, timeout_in_sec=5
)
assert response == {"success": True}
assert len(reporting_api.events) == 3
assert mock_report.call_count == 3


def test_report_mixed_event_types(reporting_api):
"""Test reporting with a mix of attack and non-attack events."""
attack_event = {"type": "detected_attack", "time": 1000}
non_attack_event = {"type": "other_event", "time": 2000}

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
response = reporting_api.report(
token="token", event=attack_event, timeout_in_sec=5
)
assert response == {"success": True}
assert len(reporting_api.events) == 1

response = reporting_api.report(
token="token", event=non_attack_event, timeout_in_sec=5
)
assert response == {"success": True}
assert len(reporting_api.events) == 1 # Non-attack event should not be stored


def test_report_event_expiry(reporting_api):
"""Test that events expire correctly based on the time interval."""
event1 = {"type": "detected_attack", "time": 1000}
event2 = {"type": "detected_attack", "time": 2000}
reporting_api.events = [event1, event2]

# Simulate time passing
with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=12000,
):
event3 = {"type": "detected_attack", "time": 11000}
response = reporting_api.report(token="token", event=event3, timeout_in_sec=5)
assert response == {"error": "timeout", "success": False}
assert (
len(reporting_api.events) == 1
) # One event should have expired, and the new one is added


def test_report_event_at_boundary(reporting_api):
"""Test reporting an event at the boundary of the interval."""
event1 = {"type": "detected_attack", "time": 1000}
reporting_api.events = [event1]

with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=10000,
):
event2 = {"type": "detected_attack", "time": 10000} # Exactly at the boundary
response = reporting_api.report(token="token", event=event2, timeout_in_sec=5)
assert response == {"error": "timeout", "success": False}
assert (
len(reporting_api.events) == 2
) # Should be added since it's at the boundary


def test_report_invalid_event_type(reporting_api):
"""Test reporting with an invalid event type."""
event = {"type": "invalid_event", "time": 1000}

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
response = reporting_api.report(token="token", event=event, timeout_in_sec=5)
assert response == {"success": True}
assert (
len(reporting_api.events) == 0
) # Invalid event types should not be stored
mock_report.assert_called_once_with("token", event, 5)


def test_report_no_events(reporting_api):
"""Test reporting when no events have been reported yet."""
event = {"type": "detected_attack", "time": 1000}

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=2000,
):
response = reporting_api.report(
token="token", event=event, timeout_in_sec=5
)
assert response == {"success": True}
assert len(reporting_api.events) == 1 # Should add the first event
mock_report.assert_called_once_with("token", event, 5)

0 comments on commit 653013a

Please sign in to comment.