Skip to content

Commit 0b85fc9

Browse files
authored
Reduce allocations in trigger.py. (#470)
1 parent 3414cfb commit 0b85fc9

File tree

1 file changed

+33
-25
lines changed

1 file changed

+33
-25
lines changed

datadog_lambda/trigger.py

+33-25
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,10 @@ def get_first_record(event):
110110

111111
def parse_event_source(event: dict) -> _EventSource:
112112
"""Determines the source of the trigger event"""
113-
if type(event) is not dict:
113+
if not isinstance(event, dict):
114114
return _EventSource(EventTypes.UNKNOWN)
115115

116-
event_source = _EventSource(EventTypes.UNKNOWN)
116+
event_source = None
117117

118118
request_context = event.get("requestContext")
119119
if request_context and request_context.get("stage"):
@@ -126,7 +126,7 @@ def parse_event_source(event: dict) -> _EventSource:
126126
event_source.subtype = EventSubtypes.API_GATEWAY
127127
if "routeKey" in event:
128128
event_source.subtype = EventSubtypes.HTTP_API
129-
if event.get("requestContext", {}).get("messageDirection"):
129+
if request_context.get("messageDirection"):
130130
event_source.subtype = EventSubtypes.WEBSOCKET
131131

132132
if request_context and request_context.get("elb"):
@@ -151,10 +151,9 @@ def parse_event_source(event: dict) -> _EventSource:
151151

152152
event_record = get_first_record(event)
153153
if event_record:
154-
aws_event_source = event_record.get(
155-
"eventSource", event_record.get("EventSource")
154+
aws_event_source = event_record.get("eventSource") or event_record.get(
155+
"EventSource"
156156
)
157-
158157
if aws_event_source == "aws:dynamodb":
159158
event_source = _EventSource(EventTypes.DYNAMODB)
160159
if aws_event_source == "aws:kinesis":
@@ -165,11 +164,10 @@ def parse_event_source(event: dict) -> _EventSource:
165164
event_source = _EventSource(EventTypes.SNS)
166165
if aws_event_source == "aws:sqs":
167166
event_source = _EventSource(EventTypes.SQS)
168-
169167
if event_record.get("cf"):
170168
event_source = _EventSource(EventTypes.CLOUDFRONT)
171169

172-
return event_source
170+
return event_source or _EventSource(EventTypes.UNKNOWN)
173171

174172

175173
def detect_lambda_function_url_domain(domain: str) -> bool:
@@ -193,11 +191,19 @@ def parse_event_source_arn(source: _EventSource, event: dict, context: Any) -> s
193191
event_record = get_first_record(event)
194192
# e.g. arn:aws:s3:::lambda-xyz123-abc890
195193
if source.to_string() == "s3":
196-
return event_record.get("s3", {}).get("bucket", {}).get("arn")
194+
s3 = event_record.get("s3")
195+
if s3:
196+
bucket = s3.get("bucket")
197+
if bucket:
198+
return bucket.get("arn")
199+
return None
197200

198201
# e.g. arn:aws:sns:us-east-1:123456789012:sns-lambda
199202
if source.to_string() == "sns":
200-
return event_record.get("Sns", {}).get("TopicArn")
203+
sns = event_record.get("Sns")
204+
if sns:
205+
return sns.get("TopicArn")
206+
return None
201207

202208
# e.g. arn:aws:cloudfront::123456789012:distribution/ABC123XYZ
203209
if source.event_type == EventTypes.CLOUDFRONT:
@@ -228,7 +234,11 @@ def parse_event_source_arn(source: _EventSource, event: dict, context: Any) -> s
228234
# e.g. arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/lambda-xyz/123
229235
if source.event_type == EventTypes.ALB:
230236
request_context = event.get("requestContext")
231-
return request_context.get("elb", {}).get("targetGroupArn")
237+
if request_context:
238+
elb = request_context.get("elb")
239+
if elb:
240+
return elb.get("targetGroupArn")
241+
return None
232242

233243
# e.g. arn:aws:logs:us-west-1:123456789012:log-group:/my-log-group-xyz
234244
if source.event_type == EventTypes.CLOUDWATCH_LOGS:
@@ -292,6 +302,13 @@ def extract_http_tags(event):
292302
return http_tags
293303

294304

305+
_http_event_types = (
306+
EventTypes.API_GATEWAY,
307+
EventTypes.ALB,
308+
EventTypes.LAMBDA_FUNCTION_URL,
309+
)
310+
311+
295312
def extract_trigger_tags(event: dict, context: Any) -> dict:
296313
"""
297314
Parses the trigger event object to get tags to be added to the span metadata
@@ -305,16 +322,15 @@ def extract_trigger_tags(event: dict, context: Any) -> dict:
305322
if event_source_arn:
306323
trigger_tags["function_trigger.event_source_arn"] = event_source_arn
307324

308-
if event_source.event_type in [
309-
EventTypes.API_GATEWAY,
310-
EventTypes.ALB,
311-
EventTypes.LAMBDA_FUNCTION_URL,
312-
]:
325+
if event_source.event_type in _http_event_types:
313326
trigger_tags.update(extract_http_tags(event))
314327

315328
return trigger_tags
316329

317330

331+
_str_http_triggers = [et.value for et in _http_event_types]
332+
333+
318334
def extract_http_status_code_tag(trigger_tags, response):
319335
"""
320336
If the Lambda was triggered by API Gateway, Lambda Function URL, or ALB,
@@ -325,15 +341,7 @@ def extract_http_status_code_tag(trigger_tags, response):
325341
str_event_source = trigger_tags.get("function_trigger.event_source")
326342
# it would be cleaner if each event type was a constant object that
327343
# knew some properties about itself like this.
328-
str_http_triggers = [
329-
et.value
330-
for et in [
331-
EventTypes.API_GATEWAY,
332-
EventTypes.LAMBDA_FUNCTION_URL,
333-
EventTypes.ALB,
334-
]
335-
]
336-
if str_event_source not in str_http_triggers:
344+
if str_event_source not in _str_http_triggers:
337345
return
338346

339347
status_code = "200"

0 commit comments

Comments
 (0)