Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIK-3458 Report max x attacks per timeframe to core #142

Merged
merged 5 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions aikido_firewall/background_process/aikido_background_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
)
from aikido_firewall.helpers.check_env_for_blocking import check_env_for_blocking
from aikido_firewall.helpers.token import get_token_from_env
from aikido_firewall.background_process.api.http_api import ReportingApiHTTP
from aikido_firewall.background_process.api.http_api_ratelimited import (
ReportingApiHTTPRatelimited,
)
from .commands import process_incoming_command

EMPTY_QUEUE_INTERVAL = 5 # 5 seconds
Expand Down Expand Up @@ -70,7 +72,11 @@
) # Create an event scheduler
self.send_to_connection_manager(event_scheduler)

api = ReportingApiHTTP("https://guard.aikido.dev/")
api = ReportingApiHTTPRatelimited(

Check warning on line 75 in aikido_firewall/background_process/aikido_background_process.py

View check run for this annotation

Codecov / codecov/patch

aikido_firewall/background_process/aikido_background_process.py#L75

Added line #L75 was not covered by tests
"https://guard.aikido.dev/",
max_events_per_interval=100,
interval_in_ms=60 * 60 * 1000,
)
# We need to pass along the scheduler so that the heartbeat also gets sent
self.connection_manager = CloudConnectionManager(
block=check_env_for_blocking(),
Expand Down
33 changes: 33 additions & 0 deletions aikido_firewall/background_process/api/http_api_ratelimited.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Exports ReportingApiHTTPRatelimited
"""

import aikido_firewall.background_process.api.http_api as http_api
import aikido_firewall.helpers.get_current_unixtime_ms as t


class ReportingApiHTTPRatelimited(http_api.ReportingApiHTTP):
"""HTTP Reporting API that has ratelimiting support"""

def __init__(self, reporting_url, max_events_per_interval, interval_in_ms):
super().__init__(reporting_url)
self.interval_in_ms = interval_in_ms
self.max_events_per_interval = max_events_per_interval
self.events = []

def report(self, token, event, timeout_in_sec):
bitterpanda63 marked this conversation as resolved.
Show resolved Hide resolved
if event["type"] == "detected_attack":
# Remove all outdated events :
current_time = t.get_unixtime_ms()

def event_in_interval_filter(e):
return e["time"] > current_time - self.interval_in_ms

self.events = list(filter(event_in_interval_filter, self.events))

# Check if interval is exceeded :
if len(self.events) >= self.max_events_per_interval:
return {"success": False, "error": "max_attacks_reached"}

self.events.append(event)
return super().report(token, event, timeout_in_sec)
207 changes: 207 additions & 0 deletions aikido_firewall/background_process/api/http_api_ratelimited_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
import pytest
from unittest.mock import patch
from aikido_firewall.background_process.api.http_api import ReportingApiHTTP
from aikido_firewall.helpers.get_current_unixtime_ms import get_unixtime_ms
from .http_api_ratelimited import ReportingApiHTTPRatelimited


@pytest.fixture
def reporting_api():
"""Fixture to create an instance of ReportingApiHTTPRatelimited."""
return ReportingApiHTTPRatelimited(
reporting_url="http://example.com",
max_events_per_interval=3,
interval_in_ms=10000,
)


def test_report_within_limit(reporting_api):
"""Test reporting within the rate limit."""
event = {"type": "detected_attack", "time": 1000}

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=2000,
):
response = reporting_api.report(
token="token", event=event, timeout_in_sec=5
)
assert response == {"success": True}
assert len(reporting_api.events) == 1
mock_report.assert_called_once_with("token", event, 5)


def test_report_exceeds_limit(reporting_api):
"""Test reporting when the limit is exceeded."""
event = {"type": "detected_attack", "time": 1000}

# Simulate adding events to reach the limit
reporting_api.events = [
{"type": "detected_attack", "time": 1000},
{"type": "detected_attack", "time": 2000},
{"type": "detected_attack", "time": 3000},
]

with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=4000,
):
response = reporting_api.report(token="token", event=event, timeout_in_sec=5)
assert response == {"success": False, "error": "max_attacks_reached"}
assert len(reporting_api.events) == 3 # Should not add the new event


def test_report_within_limit_after_expiry(reporting_api):
"""Test reporting after some events have expired."""
event1 = {"type": "detected_attack", "time": 1000}
event2 = {"type": "detected_attack", "time": 2001}

# Add events to the list
reporting_api.events = [event1, event2]

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=12000,
):
event3 = {"type": "detected_attack", "time": 11000}
response = reporting_api.report(
token="token", event=event3, timeout_in_sec=5
)
assert response == {"success": True}
assert (
len(reporting_api.events) == 2
) # One event should have expired, and the new one is added


def test_report_with_non_attack_event(reporting_api):
"""Test reporting with a non-attack event."""
event = {"type": "other_event", "time": 1000}

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
response = reporting_api.report(token="token", event=event, timeout_in_sec=5)
assert response == {"success": True}
assert len(reporting_api.events) == 0 # Non-attack events should not be stored
mock_report.assert_called_once_with("token", event, 5)


def test_report_multiple_events_within_limit(reporting_api):
"""Test reporting multiple events within the rate limit."""
events = [
{"type": "detected_attack", "time": 1000},
{"type": "detected_attack", "time": 2000},
{"type": "detected_attack", "time": 3000},
]

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=4000,
):
for event in events:
response = reporting_api.report(
token="token", event=event, timeout_in_sec=5
)
assert response == {"success": True}
assert len(reporting_api.events) == 3
assert mock_report.call_count == 3


def test_report_mixed_event_types(reporting_api):
"""Test reporting with a mix of attack and non-attack events."""
attack_event = {"type": "detected_attack", "time": 1000}
non_attack_event = {"type": "other_event", "time": 2000}

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
response = reporting_api.report(
token="token", event=attack_event, timeout_in_sec=5
)
assert response == {"success": True}
assert len(reporting_api.events) == 1

response = reporting_api.report(
token="token", event=non_attack_event, timeout_in_sec=5
)
assert response == {"success": True}
assert len(reporting_api.events) == 1 # Non-attack event should not be stored


def test_report_event_expiry(reporting_api):
"""Test that events expire correctly based on the time interval."""
event1 = {"type": "detected_attack", "time": 1000}
event2 = {"type": "detected_attack", "time": 2000}
reporting_api.events = [event1, event2]

# Simulate time passing
with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=12000,
):
event3 = {"type": "detected_attack", "time": 11000}
response = reporting_api.report(token="token", event=event3, timeout_in_sec=5)
assert response == {"error": "timeout", "success": False}
assert (
len(reporting_api.events) == 1
) # One event should have expired, and the new one is added


def test_report_event_at_boundary(reporting_api):
"""Test reporting an event at the boundary of the interval."""
event1 = {"type": "detected_attack", "time": 1000}
reporting_api.events = [event1]

with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=10000,
):
event2 = {"type": "detected_attack", "time": 10000} # Exactly at the boundary
response = reporting_api.report(token="token", event=event2, timeout_in_sec=5)
assert response == {"error": "timeout", "success": False}
assert (
len(reporting_api.events) == 2
) # Should be added since it's at the boundary


def test_report_invalid_event_type(reporting_api):
"""Test reporting with an invalid event type."""
event = {"type": "invalid_event", "time": 1000}

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
response = reporting_api.report(token="token", event=event, timeout_in_sec=5)
assert response == {"success": True}
assert (
len(reporting_api.events) == 0
) # Invalid event types should not be stored
mock_report.assert_called_once_with("token", event, 5)


def test_report_no_events(reporting_api):
"""Test reporting when no events have been reported yet."""
event = {"type": "detected_attack", "time": 1000}

with patch.object(
ReportingApiHTTP, "report", return_value={"success": True}
) as mock_report:
with patch(
"aikido_firewall.helpers.get_current_unixtime_ms.get_unixtime_ms",
return_value=2000,
):
response = reporting_api.report(
token="token", event=event, timeout_in_sec=5
)
assert response == {"success": True}
assert len(reporting_api.events) == 1 # Should add the first event
mock_report.assert_called_once_with("token", event, 5)
Loading