diff --git a/datadog_lambda/trigger.py b/datadog_lambda/trigger.py index 0bb26d59..06a4507c 100644 --- a/datadog_lambda/trigger.py +++ b/datadog_lambda/trigger.py @@ -110,10 +110,10 @@ def get_first_record(event): def parse_event_source(event: dict) -> _EventSource: """Determines the source of the trigger event""" - if type(event) is not dict: + if not isinstance(event, dict): return _EventSource(EventTypes.UNKNOWN) - event_source = _EventSource(EventTypes.UNKNOWN) + event_source = None request_context = event.get("requestContext") if request_context and request_context.get("stage"): @@ -126,7 +126,7 @@ def parse_event_source(event: dict) -> _EventSource: event_source.subtype = EventSubtypes.API_GATEWAY if "routeKey" in event: event_source.subtype = EventSubtypes.HTTP_API - if event.get("requestContext", {}).get("messageDirection"): + if request_context.get("messageDirection"): event_source.subtype = EventSubtypes.WEBSOCKET if request_context and request_context.get("elb"): @@ -151,10 +151,9 @@ def parse_event_source(event: dict) -> _EventSource: event_record = get_first_record(event) if event_record: - aws_event_source = event_record.get( - "eventSource", event_record.get("EventSource") + aws_event_source = event_record.get("eventSource") or event_record.get( + "EventSource" ) - if aws_event_source == "aws:dynamodb": event_source = _EventSource(EventTypes.DYNAMODB) if aws_event_source == "aws:kinesis": @@ -165,11 +164,10 @@ def parse_event_source(event: dict) -> _EventSource: event_source = _EventSource(EventTypes.SNS) if aws_event_source == "aws:sqs": event_source = _EventSource(EventTypes.SQS) - if event_record.get("cf"): event_source = _EventSource(EventTypes.CLOUDFRONT) - return event_source + return event_source or _EventSource(EventTypes.UNKNOWN) def detect_lambda_function_url_domain(domain: str) -> bool: @@ -193,11 +191,19 @@ def parse_event_source_arn(source: _EventSource, event: dict, context: Any) -> s event_record = get_first_record(event) # e.g. arn:aws:s3:::lambda-xyz123-abc890 if source.to_string() == "s3": - return event_record.get("s3", {}).get("bucket", {}).get("arn") + s3 = event_record.get("s3") + if s3: + bucket = s3.get("bucket") + if bucket: + return bucket.get("arn") + return None # e.g. arn:aws:sns:us-east-1:123456789012:sns-lambda if source.to_string() == "sns": - return event_record.get("Sns", {}).get("TopicArn") + sns = event_record.get("Sns") + if sns: + return sns.get("TopicArn") + return None # e.g. arn:aws:cloudfront::123456789012:distribution/ABC123XYZ if source.event_type == EventTypes.CLOUDFRONT: @@ -228,7 +234,11 @@ def parse_event_source_arn(source: _EventSource, event: dict, context: Any) -> s # e.g. arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/lambda-xyz/123 if source.event_type == EventTypes.ALB: request_context = event.get("requestContext") - return request_context.get("elb", {}).get("targetGroupArn") + if request_context: + elb = request_context.get("elb") + if elb: + return elb.get("targetGroupArn") + return None # e.g. arn:aws:logs:us-west-1:123456789012:log-group:/my-log-group-xyz if source.event_type == EventTypes.CLOUDWATCH_LOGS: @@ -292,6 +302,13 @@ def extract_http_tags(event): return http_tags +_http_event_types = ( + EventTypes.API_GATEWAY, + EventTypes.ALB, + EventTypes.LAMBDA_FUNCTION_URL, +) + + def extract_trigger_tags(event: dict, context: Any) -> dict: """ Parses the trigger event object to get tags to be added to the span metadata @@ -305,16 +322,15 @@ def extract_trigger_tags(event: dict, context: Any) -> dict: if event_source_arn: trigger_tags["function_trigger.event_source_arn"] = event_source_arn - if event_source.event_type in [ - EventTypes.API_GATEWAY, - EventTypes.ALB, - EventTypes.LAMBDA_FUNCTION_URL, - ]: + if event_source.event_type in _http_event_types: trigger_tags.update(extract_http_tags(event)) return trigger_tags +_str_http_triggers = [et.value for et in _http_event_types] + + def extract_http_status_code_tag(trigger_tags, response): """ If the Lambda was triggered by API Gateway, Lambda Function URL, or ALB, @@ -325,15 +341,7 @@ def extract_http_status_code_tag(trigger_tags, response): str_event_source = trigger_tags.get("function_trigger.event_source") # it would be cleaner if each event type was a constant object that # knew some properties about itself like this. - str_http_triggers = [ - et.value - for et in [ - EventTypes.API_GATEWAY, - EventTypes.LAMBDA_FUNCTION_URL, - EventTypes.ALB, - ] - ] - if str_event_source not in str_http_triggers: + if str_event_source not in _str_http_triggers: return status_code = "200"