Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CLOUDS-4012] Fix service tag setting in lambda forwarder #714

Merged
merged 12 commits into from
Dec 5, 2023
2 changes: 1 addition & 1 deletion aws/logs_monitoring/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
*.zip
tools/layers
.forwarder
.forwarder
10 changes: 8 additions & 2 deletions aws/logs_monitoring/lambda_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,16 @@ def get(self, key):
Returns:
lambda_tags (str[]): the list of "key:value" Datadog tag strings
"""
if not self.should_fetch_tags():
tyrcho marked this conversation as resolved.
Show resolved Hide resolved
logger.debug(
"Not fetching lambda function tags because the env variable DD_FETCH_LAMBDA_TAGS is "
"not set to true"
)
return []

if self._is_expired():
send_forwarder_internal_metrics("local_cache_expired")
logger.debug("Local cache expired, fetching cache from S3")
self._refresh()

function_tags = self.tags_by_id.get(key, [])
return function_tags
return self.tags_by_id.get(key, [])
39 changes: 31 additions & 8 deletions aws/logs_monitoring/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,21 @@ def add_metadata_to_lambda_log(event):
# Get custom tags of the Lambda function
custom_lambda_tags = get_enriched_lambda_log_tags(event)

# Set the `service` tag and metadata field. If the Lambda function is
tyrcho marked this conversation as resolved.
Show resolved Hide resolved
# tagged with a `service` tag, use it, otherwise use the function name.
service_tag = next(
(tag for tag in custom_lambda_tags if tag.startswith("service:")),
f"service:{function_name}",
)
tags.append(service_tag)
event[DD_SERVICE] = service_tag.split(":")[1]
# If the service was not set during parsing, or still equals the source
# (its default value), set the service tag from the Lambda tags cache,
# falling back to the function name. Otherwise, remove any service tag from
# the custom Lambda tags to avoid duplicating it.
if not event[DD_SERVICE] or event[DD_SERVICE] == event[DD_SOURCE]:
service_tag = next(
(tag for tag in custom_lambda_tags if tag.startswith("service:")),
f"service:{function_name}",
)
if service_tag:
tags.append(service_tag)
event[DD_SERVICE] = service_tag.split(":")[1]
else:
custom_lambda_tags = [
tag for tag in custom_lambda_tags if not tag.startswith("service:")
]

# Check if one of the Lambda's custom tags is env
# If an env tag exists, remove the env:none placeholder
Expand Down Expand Up @@ -319,6 +326,22 @@ def extract_ddtags_from_message(event):
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Failed to extract ddtags from: {event}")
return

# Extract service tag from message.ddtags if exists
if "service" in extracted_ddtags:
event[DD_SERVICE] = next(
tag[8:]
for tag in extracted_ddtags.split(",")
if tag.startswith("service:")
)
event[DD_CUSTOM_TAGS] = ",".join(
[
tag
for tag in event[DD_CUSTOM_TAGS].split(",")
if not tag.startswith("service")
]
)

event[DD_CUSTOM_TAGS] = f"{event[DD_CUSTOM_TAGS]},{extracted_ddtags}"


Expand Down
22 changes: 14 additions & 8 deletions aws/logs_monitoring/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def s3_handler(event, context, metadata):
source = "transitgateway"
metadata[DD_SOURCE] = source

metadata[DD_SERVICE] = get_service_from_tags(metadata)
metadata[DD_SERVICE] = get_service_from_tags_and_remove_duplicates(metadata)

##Get the ARN of the service and set it as the hostname
hostname = parse_service_arn(source, key, bucket, context)
Expand Down Expand Up @@ -242,15 +242,21 @@ def s3_handler(event, context, metadata):
yield structured_line


def get_service_from_tags_and_remove_duplicates(metadata):
    """Resolve the service name from the custom tags and drop duplicate service tags.

    The comma-separated tag string stored under ``metadata[DD_CUSTOM_TAGS]``
    is scanned for ``service:`` tags. The first one with a non-empty value
    wins; every later ``service:`` tag is removed, and the cleaned tag string
    is written back to ``metadata[DD_CUSTOM_TAGS]``.

    Args:
        metadata (dict): event metadata; must contain ``DD_CUSTOM_TAGS`` and
            ``DD_SOURCE`` entries. It is modified in place.

    Returns:
        str: the value of the first ``service:`` tag, or
        ``metadata[DD_SOURCE]`` when no service tag with a value is present.
    """
    service = ""
    kept_tags = []
    # Build a new list rather than deleting from the list being iterated:
    # deleting in place shifts the remaining items and makes the loop skip
    # the element that follows each removed tag.
    for tag in metadata[DD_CUSTOM_TAGS].split(","):
        if tag.startswith("service:"):
            if service:
                # A service was already picked; this one is a duplicate.
                continue
            service = tag[8:]
        kept_tags.append(tag)

    metadata[DD_CUSTOM_TAGS] = ",".join(kept_tags)

    # Default service to source value
    return service if service else metadata[DD_SOURCE]


def parse_event_source(event, key):
Expand Down Expand Up @@ -530,7 +536,7 @@ def awslogs_handler(event, context, metadata):

# Set service from custom tags, which may include the tags set on the log group
# Returns DD_SOURCE by default
metadata[DD_SERVICE] = get_service_from_tags(metadata)
metadata[DD_SERVICE] = get_service_from_tags_and_remove_duplicates(metadata)

# Set host as log group where cloudwatch is source
if metadata[DD_SOURCE] == "cloudwatch" or metadata.get(DD_HOST, None) == None:
Expand Down Expand Up @@ -640,7 +646,7 @@ def cwevent_handler(event, metadata):
else:
metadata[DD_SOURCE] = "cloudwatch"

metadata[DD_SERVICE] = get_service_from_tags(metadata)
metadata[DD_SERVICE] = get_service_from_tags_and_remove_duplicates(metadata)

yield data

Expand Down
212 changes: 197 additions & 15 deletions aws/logs_monitoring/tests/test_lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from time import time
from botocore.exceptions import ClientError
from approvaltests.approvals import verify_as_json
from importlib import reload

sys.modules["trace_forwarder.connection"] = MagicMock()
sys.modules["datadog_lambda.wrapper"] = MagicMock()
Expand All @@ -34,6 +35,7 @@
enrich,
transform,
split,
extract_ddtags_from_message,
)
from parsing import parse, parse_event_type

Expand Down Expand Up @@ -130,12 +132,8 @@ def create_cloudwatch_log_event_from_data(data):


class TestLambdaFunctionEndToEnd(unittest.TestCase):
@patch("cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.get")
@patch("base_tags_cache.send_forwarder_internal_metrics")
@patch("enhanced_lambda_metrics.LambdaTagsCache.get_cache_from_s3")
def test_datadog_forwarder(
self, mock_get_s3_cache, mock_forward_metrics, cw_logs_tags_get
):
def test_datadog_forwarder(self, mock_get_s3_cache):
mock_get_s3_cache.return_value = (
{
"arn:aws:lambda:sa-east-1:601427279990:function:inferred-spans-python-dev-initsender": [
Expand All @@ -149,15 +147,7 @@ def test_datadog_forwarder(
time(),
)
context = Context()
my_path = os.path.abspath(os.path.dirname(__file__))
path = os.path.join(my_path, "events/cloudwatch_logs.json")

with open(
path,
"r",
) as input_file:
input_data = input_file.read()

input_data = self._get_input_data()
event = {"awslogs": {"data": create_cloudwatch_log_event_from_data(input_data)}}
os.environ["DD_FETCH_LAMBDA_TAGS"] = "True"

Expand All @@ -170,7 +160,7 @@ def test_datadog_forwarder(

verify_as_json(transformed_events)

metrics, logs, trace_payloads = split(transformed_events)
_, _, trace_payloads = split(transformed_events)
self.assertEqual(len(trace_payloads), 1)

trace_payload = json.loads(trace_payloads[0]["message"])
Expand Down Expand Up @@ -204,6 +194,98 @@ def test_datadog_forwarder(

del os.environ["DD_FETCH_LAMBDA_TAGS"]

@patch("cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.get")
def test_setting_service_tag_from_log_group_cache(self, cw_logs_tags_get):
    """Every forwarded log carries the service tag returned by the log group cache.

    ``CloudwatchLogGroupTagsCache.get`` is patched to return a single
    ``service:log_group_service`` tag before the fixture event is run through
    parse -> enrich -> transform -> split.
    """
    # Presumably reloaded so module-level settings are re-read from the
    # current environment -- confirm against settings.py / parsing.py.
    for module_name in ("settings", "parsing"):
        reload(sys.modules[module_name])
    cw_logs_tags_get.return_value = ["service:log_group_service"]

    context = Context()
    payload = create_cloudwatch_log_event_from_data(self._get_input_data())
    event = {"awslogs": {"data": payload}}

    transformed_events = transform(enrich(parse(event, context)))

    _, logs, _ = split(transformed_events)
    self.assertEqual(len(logs), 16)
    services = [log["service"] for log in logs]
    self.assertEqual(services, ["log_group_service"] * 16)

@patch.dict(os.environ, {"DD_TAGS": "service:dd_tag_service"}, clear=True)
@patch("cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.get")
def test_service_override_from_dd_tags(self, cw_logs_tags_get):
    """A service set through the ``DD_TAGS`` env variable wins over the log group's.

    The environment is replaced with only ``DD_TAGS=service:dd_tag_service``
    while ``CloudwatchLogGroupTagsCache.get`` is patched to return
    ``service:log_group_service``; all forwarded logs are expected to report
    ``dd_tag_service``.
    """
    # Presumably reloaded so the patched environment is picked up by
    # module-level settings -- confirm against settings.py / parsing.py.
    for module_name in ("settings", "parsing"):
        reload(sys.modules[module_name])
    cw_logs_tags_get.return_value = ["service:log_group_service"]

    context = Context()
    payload = create_cloudwatch_log_event_from_data(self._get_input_data())
    event = {"awslogs": {"data": payload}}

    transformed_events = transform(enrich(parse(event, context)))

    _, logs, _ = split(transformed_events)
    self.assertEqual(len(logs), 16)
    services = [log["service"] for log in logs]
    self.assertEqual(services, ["dd_tag_service"] * 16)

@patch("lambda_cache.LambdaTagsCache.get")
@patch("cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.get")
def test_overrding_service_tag_from_lambda_cache(
    self, lambda_tags_get, cw_logs_tags_get
):
    """All forwarded logs are expected to report ``lambda_service``.

    Both ``LambdaTagsCache.get`` and ``CloudwatchLogGroupTagsCache.get`` are
    patched, each returning a different ``service:`` tag, before the fixture
    event is run through parse -> enrich -> transform -> split.
    """
    # NOTE(review): stacked @patch decorators pass their mocks bottom-up, so
    # the first mock argument comes from the CloudwatchLogGroupTagsCache.get
    # patch and the second from LambdaTagsCache.get. The parameter names here
    # look swapped relative to that order -- confirm which cache is really
    # meant to return which service tag.
    lambda_tags_get.return_value = ["service:lambda_service"]
    cw_logs_tags_get.return_value = ["service:log_group_service"]

    context = Context()
    payload = create_cloudwatch_log_event_from_data(self._get_input_data())
    event = {"awslogs": {"data": payload}}

    transformed_events = transform(enrich(parse(event, context)))

    _, logs, _ = split(transformed_events)
    self.assertEqual(len(logs), 16)
    services = [log["service"] for log in logs]
    self.assertEqual(services, ["lambda_service"] * 16)

@patch.dict(os.environ, {"DD_TAGS": "service:dd_tag_service"}, clear=True)
@patch("lambda_cache.LambdaTagsCache.get")
@patch("cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.get")
def test_overrding_service_tag_from_lambda_cache_when_dd_tags_is_set(
    self, lambda_tags_get, cw_logs_tags_get
):
    """All forwarded logs are expected to report ``lambda_service`` even with ``DD_TAGS`` set.

    The environment is replaced with only ``DD_TAGS=service:dd_tag_service``,
    and both ``LambdaTagsCache.get`` and ``CloudwatchLogGroupTagsCache.get``
    are patched to return different ``service:`` tags.
    """
    # NOTE(review): stacked @patch decorators pass their mocks bottom-up, so
    # the first mock argument comes from the CloudwatchLogGroupTagsCache.get
    # patch and the second from LambdaTagsCache.get. The parameter names here
    # look swapped relative to that order -- confirm which cache is really
    # meant to return which service tag.
    lambda_tags_get.return_value = ["service:lambda_service"]
    cw_logs_tags_get.return_value = ["service:log_group_service"]

    context = Context()
    payload = create_cloudwatch_log_event_from_data(self._get_input_data())
    event = {"awslogs": {"data": payload}}

    transformed_events = transform(enrich(parse(event, context)))

    _, logs, _ = split(transformed_events)
    self.assertEqual(len(logs), 16)
    services = [log["service"] for log in logs]
    self.assertEqual(services, ["lambda_service"] * 16)

def _get_input_data(self):
    """Return the text of the ``events/cloudwatch_logs.json`` fixture.

    The path is resolved relative to the directory containing this test
    module.
    """
    tests_dir = os.path.abspath(os.path.dirname(__file__))
    fixture_path = os.path.join(tests_dir, "events/cloudwatch_logs.json")

    with open(fixture_path, "r") as fixture:
        return fixture.read()


class TestLambdaFunctionExtractTracePayload(unittest.TestCase):
def test_extract_trace_payload_none_no_trace(self):
Expand Down Expand Up @@ -234,5 +316,105 @@ def test_extract_trace_payload_valid_trace(self):
)


class TestMergeMessageTags(unittest.TestCase):
    """Checks how ``extract_ddtags_from_message`` merges message-level ``ddtags``.

    Each test builds an event with a ``message`` (JSON string or dict) that
    carries its own ``ddtags``, runs ``extract_ddtags_from_message`` on it and
    verifies the resulting ``ddtags`` string and ``service`` field.
    """

    # ddtags embedded in the log message itself
    message_tags = '{"ddtags":"service:my_application_service,custom_tag_1:value1"}'
    # ddtags already present on the event before extraction
    custom_tags = "custom_tag_2:value2,service:my_custom_service"

    def _extract_and_check(self, event, expected_ddtags, expected_service):
        """Run the extraction on ``event`` and assert its ddtags, then its service."""
        extract_ddtags_from_message(event)

        self.assertEqual(event["ddtags"], expected_ddtags)
        self.assertEqual(event["service"], expected_service)

    def test_extract_ddtags_from_message_str(self):
        # Message supplied as a raw JSON string.
        self._extract_and_check(
            {
                "message": self.message_tags,
                "ddtags": self.custom_tags,
                "service": "my_service",
            },
            "custom_tag_2:value2,service:my_application_service,custom_tag_1:value1",
            "my_application_service",
        )

    def test_extract_ddtags_from_message_dict(self):
        # Message supplied as an already-decoded dict.
        self._extract_and_check(
            {
                "message": json.loads(self.message_tags),
                "ddtags": self.custom_tags,
                "service": "my_service",
            },
            "custom_tag_2:value2,service:my_application_service,custom_tag_1:value1",
            "my_application_service",
        )

    def test_extract_ddtags_from_message_service_tag_setting(self):
        # Strip the service tag from the message so the event's own one is kept.
        message = json.loads(self.message_tags)
        message["ddtags"] = ",".join(
            tag
            for tag in message["ddtags"].split(",")
            if not tag.startswith("service:")
        )

        self._extract_and_check(
            {
                "message": message,
                "ddtags": self.custom_tags,
                "service": "my_custom_service",
            },
            "custom_tag_2:value2,service:my_custom_service,custom_tag_1:value1",
            "my_custom_service",
        )

    def test_extract_ddtags_from_message_multiple_service_tag_values(self):
        # The event starts with two service tags and no "service" field.
        self._extract_and_check(
            {
                "message": self.message_tags,
                "ddtags": self.custom_tags + ",service:my_custom_service_2",
            },
            "custom_tag_2:value2,service:my_application_service,custom_tag_1:value1",
            "my_application_service",
        )

    def test_extract_ddtags_from_message_multiple_values_tag(self):
        # A non-service tag present on both sides is expected to keep both values.
        message = json.loads(self.message_tags)
        message["ddtags"] += ",custom_tag_3:value4"

        self._extract_and_check(
            {
                "message": message,
                "ddtags": self.custom_tags + ",custom_tag_3:value3",
            },
            "custom_tag_2:value2,custom_tag_3:value3,service:my_application_service,custom_tag_1:value1,custom_tag_3:value4",
            "my_application_service",
        )


if __name__ == "__main__":
unittest.main()
Loading