Skip to content

Commit

Permalink
Feat/migrate aws health functions (#601)
Browse files Browse the repository at this point in the history
* Refactor SRE command function for better readability and maintainability

* fix: add missing test

* feat: update module to use organizations list accounts fn()

* fix: patch the constant in the test

* feat: move security hub functions into integrations

* feat: move security hub functions into integrations

* feat: move guard duty functions to integrations

* feat: move config functions to integrations

* feat: move cost explorer to integrations

* chore: fmt

* feat: remove unused assume function

* chore: fmt

* fix: make kwargs case conversion optional
  • Loading branch information
gcharest authored Jul 31, 2024
1 parent 5aa41de commit 4701d99
Show file tree
Hide file tree
Showing 14 changed files with 558 additions and 208 deletions.
10 changes: 6 additions & 4 deletions app/integrations/aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import logging
from functools import wraps
import boto3 # type: ignore
from botocore.exceptions import BotoCoreError, ClientError # type: ignore
from botocore.exceptions import BotoCoreError, ClientError
from botocore.client import BaseClient # type: ignore
from dotenv import load_dotenv
from integrations.utils.api import convert_kwargs_to_pascal_case

Expand Down Expand Up @@ -116,8 +117,9 @@ def execute_aws_api_call(
ValueError: If the role_arn is not provided.
"""
config = kwargs.pop("config", dict(region_name=AWS_REGION))
convert_kwargs = kwargs.pop("convert_kwargs", True)
client = get_aws_service_client(service_name, role_arn, **config)
if kwargs:
if kwargs and convert_kwargs:
kwargs = convert_kwargs_to_pascal_case(kwargs)
api_method = getattr(client, method)
if paginated:
Expand All @@ -126,11 +128,11 @@ def execute_aws_api_call(
return api_method(**kwargs)


def paginator(client, operation, keys=None, **kwargs):
def paginator(client: BaseClient, operation, keys=None, **kwargs):
"""Generic paginator for AWS operations
Args:
client (botocore.client.BaseClient): The service client.
client (BaseClient): The service client.
operation (str): The operation to paginate.
keys (list, optional): The keys to extract from the paginated results.
**kwargs: Additional keyword arguments for the operation.
Expand Down
31 changes: 31 additions & 0 deletions app/integrations/aws/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os
from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors

AUDIT_ROLE_ARN = os.environ.get("AWS_AUDIT_ACCOUNT_ROLE_ARN")


@handle_aws_api_errors
def describe_aggregate_compliance_by_config_rules(config_aggregator_name, filters):
"""Retrieves the aggregate compliance of AWS Config rules for an account.
Args:
config_aggregator_name (str): The name of the AWS Config aggregator.
filters (dict): Filters to apply to the compliance results.
Returns:
list: A list of compliance objects
"""
params = {
"ConfigurationAggregatorName": config_aggregator_name,
"Filters": filters,
}
response = execute_aws_api_call(
"config",
"describe_aggregate_compliance_by_config_rules",
paginated=True,
keys=["AggregateComplianceByConfigRules"],
role_arn=AUDIT_ROLE_ARN,
convert_kwargs=False,
**params,
)
return response if response else []
27 changes: 27 additions & 0 deletions app/integrations/aws/cost_explorer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from logging import getLogger
import os
from .client import execute_aws_api_call, handle_aws_api_errors

logger = getLogger(__name__)
ORG_ROLE_ARN = os.environ.get("AWS_ORG_ACCOUNT_ROLE_ARN")


@handle_aws_api_errors
def get_cost_and_usage(time_period, granularity, metrics, filter=None, group_by=None):
params = {
"TimePeriod": time_period,
"Granularity": granularity,
"Metrics": metrics,
}
if filter:
params["Filter"] = filter
if group_by:
params["GroupBy"] = group_by

return execute_aws_api_call(
"ce",
"get_cost_and_usage",
role_arn=ORG_ROLE_ARN,
convert_kwargs=False,
**params,
)
51 changes: 51 additions & 0 deletions app/integrations/aws/guard_duty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors

LOGGING_ROLE_ARN = os.environ.get("AWS_LOGGING_ACCOUNT_ROLE_ARN")


@handle_aws_api_errors
def list_detectors():
"""Retrieves all detectors from AWS GuardDuty
Returns:
list: A list of detector objects.
"""
response = execute_aws_api_call(
"guardduty",
"list_detectors",
paginated=True,
keys=["DetectorIds"],
role_arn=LOGGING_ROLE_ARN,
)
return response if response else []


@handle_aws_api_errors
def get_findings_statistics(detector_id, finding_criteria=None):
"""Retrieves the findings statistics for a given detector
Args:
detector_id (str): The ID of the detector.
finding_criteria (dict, optional): The criteria to use to filter the findings
Returns:
dict: The findings statistics.
"""

params = {
"DetectorId": detector_id,
"FindingStatisticTypes": ["COUNT_BY_SEVERITY"],
}
if finding_criteria:
params["FindingCriteria"] = finding_criteria

response = execute_aws_api_call(
"guardduty",
"get_findings_statistics",
role_arn=LOGGING_ROLE_ARN,
convert_kwargs=False,
**params,
)

return response if response else {}
24 changes: 24 additions & 0 deletions app/integrations/aws/security_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os
from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors

LOGGING_ROLE_ARN = os.environ.get("AWS_LOGGING_ACCOUNT_ROLE_ARN")


@handle_aws_api_errors
def get_findings(filters):
"""Retrieves all findings from AWS Security Hub
Args:
filters (dict): Filters to apply to the findings.
Returns:
list: A list of finding objects.
"""
response = execute_aws_api_call(
"securityhub",
"get_findings",
paginated=True,
role_arn=LOGGING_ROLE_ARN,
filters=filters,
)
return response
176 changes: 69 additions & 107 deletions app/modules/aws/aws_account_health.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,15 @@
import arrow
import boto3
import os
from slack_bolt import Ack
from slack_sdk import WebClient
from logging import Logger

AUDIT_ROLE_ARN = os.environ["AWS_AUDIT_ACCOUNT_ROLE_ARN"]
LOGGING_ROLE_ARN = os.environ.get("AWS_LOGGING_ACCOUNT_ROLE_ARN")
ORG_ROLE_ARN = os.environ.get("AWS_ORG_ACCOUNT_ROLE_ARN")


def assume_role_client(client_type, role=ORG_ROLE_ARN, region="ca-central-1"):
client = boto3.client("sts")

response = client.assume_role(
RoleArn=role, RoleSessionName="SREBot_Org_Account_Role"
)

session = boto3.Session(
aws_access_key_id=response["Credentials"]["AccessKeyId"],
aws_secret_access_key=response["Credentials"]["SecretAccessKey"],
aws_session_token=response["Credentials"]["SessionToken"],
)

return session.client(client_type, region_name=region)


def get_accounts():
client = assume_role_client("organizations")
response = client.list_accounts()
accounts = {}
# Loop response for NextToken
while True:
for account in response["Accounts"]:
accounts[account["Id"]] = account["Name"]
if "NextToken" in response:
response = client.list_accounts(NextToken=response["NextToken"])
else:
break
return dict(sorted(accounts.items(), key=lambda item: item[1]))
from integrations.aws import (
organizations,
security_hub,
guard_duty,
config,
cost_explorer,
)


def get_account_health(account_id):
Expand Down Expand Up @@ -77,15 +51,13 @@ def get_account_health(account_id):


def get_account_spend(account_id, start_date, end_date):
client = assume_role_client("ce")
response = client.get_cost_and_usage(
TimePeriod={"Start": start_date, "End": end_date},
Granularity="MONTHLY",
Metrics=["UnblendedCost"],
GroupBy=[
{"Type": "DIMENSION", "Key": "LINKED_ACCOUNT"},
],
Filter={"Dimensions": {"Key": "LINKED_ACCOUNT", "Values": [account_id]}},
time_period = {"Start": start_date, "End": end_date}
granularity = "MONTHLY"
metrics = ["UnblendedCost"]
group_by = [{"Type": "DIMENSION", "Key": "LINKED_ACCOUNT"}]
filter = {"Dimensions": {"Key": "LINKED_ACCOUNT", "Values": [account_id]}}
response = cost_explorer.get_cost_and_usage(
time_period, granularity, metrics, filter, group_by
)
if "Groups" in response["ResultsByTime"][0]:
return "{:0,.2f}".format(
Expand All @@ -100,71 +72,57 @@ def get_account_spend(account_id, start_date, end_date):


def get_config_summary(account_id):
client = assume_role_client("config", role=AUDIT_ROLE_ARN)
response = client.describe_aggregate_compliance_by_config_rules(
ConfigurationAggregatorName="aws-controltower-GuardrailsComplianceAggregator",
Filters={
"AccountId": account_id,
"ComplianceType": "NON_COMPLIANT",
},
config_name = "aws-controltower-GuardrailsComplianceAggregator"
filters = {
"AccountId": account_id,
"ComplianceType": "NON_COMPLIANT",
}
return len(
config.describe_aggregate_compliance_by_config_rules(config_name, filters)
)
return len(response["AggregateComplianceByConfigRules"])


def get_guardduty_summary(account_id):
client = assume_role_client("guardduty", role=LOGGING_ROLE_ARN)
detector_id = client.list_detectors()["DetectorIds"][0]
response = client.get_findings_statistics(
DetectorId=detector_id,
FindingStatisticTypes=[
"COUNT_BY_SEVERITY",
],
FindingCriteria={
"Criterion": {
"accountId": {"Eq": [account_id]},
"service.archived": {"Eq": ["false", "false"]},
"severity": {"Gte": 7},
}
},
)

detector_ids = guard_duty.list_detectors()
finding_criteria = {
"Criterion": {
"accountId": {"Eq": [account_id]},
"service.archived": {"Eq": ["false", "false"]},
"severity": {"Gte": 7},
}
}
response = guard_duty.get_findings_statistics(detector_ids[0], finding_criteria)
return sum(response["FindingStatistics"]["CountBySeverity"].values())


def get_securityhub_summary(account_id):
client = assume_role_client("securityhub", role=LOGGING_ROLE_ARN)
response = client.get_findings(
Filters={
"AwsAccountId": [{"Value": account_id, "Comparison": "EQUALS"}],
"ComplianceStatus": [
{"Value": "FAILED", "Comparison": "EQUALS"},
],
"RecordState": [
{"Value": "ACTIVE", "Comparison": "EQUALS"},
],
"SeverityProduct": [
{
"Gte": 70,
"Lte": 100,
},
],
"Title": get_ignored_security_hub_issues(),
"UpdatedAt": [
{"DateRange": {"Value": 1, "Unit": "DAYS"}},
],
"WorkflowStatus": [
{"Value": "NEW", "Comparison": "EQUALS"},
],
}
)
filters = {
"AwsAccountId": [{"Value": account_id, "Comparison": "EQUALS"}],
"ComplianceStatus": [
{"Value": "FAILED", "Comparison": "EQUALS"},
],
"RecordState": [
{"Value": "ACTIVE", "Comparison": "EQUALS"},
],
"SeverityProduct": [
{
"Gte": 70,
"Lte": 100,
},
],
"Title": get_ignored_security_hub_issues(),
"UpdatedAt": [
{"DateRange": {"Value": 1, "Unit": "DAYS"}},
],
"WorkflowStatus": [
{"Value": "NEW", "Comparison": "EQUALS"},
],
}
response = security_hub.get_findings(filters)
issues = 0
# Loop response for NextToken
while True:
issues += len(response["Findings"])
if "NextToken" in response:
response = client.get_findings(NextToken=response["NextToken"])
else:
break
if response:
for res in response:
issues += len(res["Findings"])
return issues


Expand All @@ -177,7 +135,7 @@ def get_ignored_security_hub_issues():
return list(map(lambda t: {"Value": t, "Comparison": "NOT_EQUALS"}, ignored_issues))


def health_view_handler(ack, body, logger, client):
def health_view_handler(ack: Ack, body, logger: Logger, client: WebClient):
ack()

account_id = body["view"]["state"]["values"]["account"]["account"][
Expand Down Expand Up @@ -239,15 +197,19 @@ def health_view_handler(ack, body, logger, client):
)


def request_health_modal(client, body):
accounts = get_accounts()
def request_health_modal(client: WebClient, body):
accounts = organizations.list_organization_accounts()
options = [
{
"text": {"type": "plain_text", "text": value},
"value": key,
"text": {
"type": "plain_text",
"text": f"{account['Name']} ({account['Id']})",
},
"value": account["Id"],
}
for key, value in accounts.items()
for account in accounts
]
options.sort(key=lambda x: x["text"]["text"].lower())
client.views_open(
trigger_id=body["trigger_id"],
view={
Expand Down
Loading

0 comments on commit 4701d99

Please sign in to comment.