Skip to content

Commit

Permalink
Refactor/webhooks dynamodb (#657)
Browse files Browse the repository at this point in the history
* fix: remove duplicate dependency

* fix: handling of custom client config

* fix: update webhooks to leverage aws integration

* feat: hoist dynamodb methods

* fix: remove use of kwargs function in execute aws api call

* fix: remove use of kwargs function in execute aws api call

* fix: remove use of kwargs function and assert statement

* fix: expected arguments in test

* fix: use proper argument case

* chore: cleanup unused imports

* chore: cleanup unused imports

* feat: migrate the slack webhooks methods into the modules.slack package

* feat: add dev command for slack tests

* fix: update aws dev command with dynamodb integration

* fix: remove commented code
  • Loading branch information
gcharest authored Sep 19, 2024
1 parent 3ee8c4c commit 84c115c
Show file tree
Hide file tree
Showing 21 changed files with 458 additions and 318 deletions.
66 changes: 53 additions & 13 deletions app/integrations/aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import logging
from functools import wraps
import boto3 # type: ignore
from botocore.exceptions import BotoCoreError, ClientError
from botocore.exceptions import BotoCoreError, ClientError # type: ignore
from botocore.client import BaseClient # type: ignore
from dotenv import load_dotenv
from integrations.utils.api import convert_kwargs_to_pascal_case

load_dotenv()

Expand Down Expand Up @@ -74,7 +73,11 @@ def assume_role_session(role_arn, session_name="DefaultSession"):

@handle_aws_api_errors
def get_aws_service_client(
service_name, role_arn=None, session_name="DefaultSession", **config
service_name,
role_arn=None,
session_name="DefaultSession",
session_config=None,
client_config=None,
):
"""Get an AWS service client. If a role_arn is provided in the config, assume the role to get temporary credentials.
Expand All @@ -85,12 +88,16 @@ def get_aws_service_client(
Returns:
botocore.client.BaseClient: The service client.
"""
if session_config is None:
session_config = {}
if client_config is None:
client_config = {}

if role_arn:
session = assume_role_session(role_arn, session_name)
else:
session = boto3.Session(**config)
return session.client(service_name)
session = boto3.Session(**session_config)
return session.client(service_name, **client_config)


def execute_aws_api_call(
Expand All @@ -99,6 +106,8 @@ def execute_aws_api_call(
paginated=False,
keys=None,
role_arn=None,
session_config=None,
client_config=None,
**kwargs,
):
"""Execute an AWS API call.
Expand All @@ -116,16 +125,34 @@ def execute_aws_api_call(
Raises:
ValueError: If the role_arn is not provided.
"""
config = kwargs.pop("config", dict(region_name=AWS_REGION))
convert_kwargs = kwargs.pop("convert_kwargs", True)
client = get_aws_service_client(service_name, role_arn, **config)
if kwargs and convert_kwargs:
kwargs = convert_kwargs_to_pascal_case(kwargs)
if session_config is None:
session_config = {"region_name": AWS_REGION}
if client_config is None:
client_config = {"region_name": AWS_REGION}

client = get_aws_service_client(
service_name,
role_arn,
session_config=session_config,
client_config=client_config,
)
api_method = getattr(client, method)
if paginated:
return paginator(client, method, keys, **kwargs)
results = paginator(client, method, keys, **kwargs)
else:
return api_method(**kwargs)
results = api_method(**kwargs)

if (
"ResponseMetadata" in results
and results["ResponseMetadata"]["HTTPStatusCode"] != 200
):
logger.error(
f"API call to {service_name}.{method} failed with status code {results['ResponseMetadata']['HTTPStatusCode']}"
)
raise Exception(
f"API call to {service_name}.{method} failed with status code {results['ResponseMetadata']['HTTPStatusCode']}"
)
return results


def paginator(client: BaseClient, operation, keys=None, **kwargs):
Expand All @@ -147,7 +174,20 @@ def paginator(client: BaseClient, operation, keys=None, **kwargs):

for page in paginator.paginate(**kwargs):
if keys is None:
results.append(page)
for key, value in page.items():
if key != "ResponseMetadata":
if isinstance(value, list):
results.extend(value)
else:
results.append(value)
else:
if key == "ResponseMetadata" and value["HTTPStatusCode"] != 200:
logger.error(
f"API call to {client.meta.service_model.service_name}.{operation} failed with status code {value['HTTPStatusCode']}"
)
raise Exception(
f"API call to {client.meta.service_model.service_name}.{operation} failed with status code {value['HTTPStatusCode']}"
)
else:
for key in keys:
if key in page:
Expand Down
1 change: 0 additions & 1 deletion app/integrations/aws/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def describe_aggregate_compliance_by_config_rules(config_aggregator_name, filter
paginated=True,
keys=["AggregateComplianceByConfigRules"],
role_arn=AUDIT_ROLE_ARN,
convert_kwargs=False,
**params,
)
return response if response else []
1 change: 0 additions & 1 deletion app/integrations/aws/cost_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,5 @@ def get_cost_and_usage(time_period, granularity, metrics, filter=None, group_by=
"ce",
"get_cost_and_usage",
role_arn=ORG_ROLE_ARN,
convert_kwargs=False,
**params,
)
92 changes: 79 additions & 13 deletions app/integrations/aws/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,95 @@

@handle_aws_api_errors
def query(
table_name,
key_condition_expression=None,
expression_attribute_values=None,
TableName,
**kwargs,
):
params = {"TableName": table_name, "config": client_config}
if key_condition_expression:
params["KeyConditionExpression"] = key_condition_expression
if expression_attribute_values:
params["ExpressionAttributeValues"] = expression_attribute_values
params = {
"TableName": TableName,
}
if params:
params.update(kwargs)
response = execute_aws_api_call(
"dynamodb", "query", paginated=True, client_config=client_config, **params
)
return response


@handle_aws_api_errors
def scan(TableName, **kwargs):
params = {
"TableName": TableName,
}
if kwargs:
params.update(kwargs)
response = execute_aws_api_call("dynamodb", "query", **params)
response = execute_aws_api_call(
"dynamodb",
"scan",
paginated=True,
keys=["Items"],
client_config=client_config,
**params,
)
return response


@handle_aws_api_errors
def scan(table, **kwargs):
def put_item(TableName, **kwargs):
params = {
"TableName": table,
"config": client_config,
"TableName": TableName,
}
if kwargs:
params.update(kwargs)
response = execute_aws_api_call("dynamodb", "scan", **params)
response = execute_aws_api_call(
"dynamodb", "put_item", client_config=client_config, **params
)
return response


@handle_aws_api_errors
def get_item(TableName, **kwargs):
params = {
"TableName": TableName,
}
if kwargs:
params.update(kwargs)
response = execute_aws_api_call(
"dynamodb", "get_item", client_config=client_config, **params
)
if response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 200:
return response


@handle_aws_api_errors
def update_item(TableName, **kwargs):
params = {
"TableName": TableName,
}
if kwargs:
params.update(kwargs)
response = execute_aws_api_call(
"dynamodb", "update_item", client_config=client_config, **params
)
if response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 200:
return response


@handle_aws_api_errors
def delete_item(TableName, **kwargs):
params = {
"TableName": TableName,
}
if kwargs:
params.update(kwargs)
response = execute_aws_api_call(
"dynamodb", "delete_item", client_config=client_config, **params
)
return response


@handle_aws_api_errors
def list_tables(**kwargs):
response = execute_aws_api_call(
"dynamodb", "list_tables", client_config=client_config, **kwargs
)
return response
1 change: 0 additions & 1 deletion app/integrations/aws/guard_duty.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def get_findings_statistics(detector_id, finding_criteria=None):
"guardduty",
"get_findings_statistics",
role_arn=LOGGING_ROLE_ARN,
convert_kwargs=False,
**params,
)

Expand Down
2 changes: 1 addition & 1 deletion app/integrations/aws/security_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ def get_findings(filters):
"get_findings",
paginated=True,
role_arn=LOGGING_ROLE_ARN,
filters=filters,
Filters=filters,
)
return response
Loading

0 comments on commit 84c115c

Please sign in to comment.