Skip to content

Commit fdc3a29

Browse files
committed
Reduce allocations in trigger.py.
1 parent 3414cfb commit fdc3a29

File tree

1 file changed

+32
-26
lines changed

1 file changed

+32
-26
lines changed

datadog_lambda/trigger.py

+32-26
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,10 @@ def get_first_record(event):
110110

111111
def parse_event_source(event: dict) -> _EventSource:
112112
"""Determines the source of the trigger event"""
113-
if type(event) is not dict:
113+
if not isinstance(event, dict):
114114
return _EventSource(EventTypes.UNKNOWN)
115115

116-
event_source = _EventSource(EventTypes.UNKNOWN)
116+
event_source = None
117117

118118
request_context = event.get("requestContext")
119119
if request_context and request_context.get("stage"):
@@ -126,7 +126,7 @@ def parse_event_source(event: dict) -> _EventSource:
126126
event_source.subtype = EventSubtypes.API_GATEWAY
127127
if "routeKey" in event:
128128
event_source.subtype = EventSubtypes.HTTP_API
129-
if event.get("requestContext", {}).get("messageDirection"):
129+
if request_context.get("messageDirection"):
130130
event_source.subtype = EventSubtypes.WEBSOCKET
131131

132132
if request_context and request_context.get("elb"):
@@ -151,10 +151,7 @@ def parse_event_source(event: dict) -> _EventSource:
151151

152152
event_record = get_first_record(event)
153153
if event_record:
154-
aws_event_source = event_record.get(
155-
"eventSource", event_record.get("EventSource")
156-
)
157-
154+
aws_event_source = event_record.get("eventSource") or event_record.get("EventSource")
158155
if aws_event_source == "aws:dynamodb":
159156
event_source = _EventSource(EventTypes.DYNAMODB)
160157
if aws_event_source == "aws:kinesis":
@@ -165,11 +162,10 @@ def parse_event_source(event: dict) -> _EventSource:
165162
event_source = _EventSource(EventTypes.SNS)
166163
if aws_event_source == "aws:sqs":
167164
event_source = _EventSource(EventTypes.SQS)
168-
169165
if event_record.get("cf"):
170166
event_source = _EventSource(EventTypes.CLOUDFRONT)
171167

172-
return event_source
168+
return event_source or _EventSource(EventTypes.UNKNOWN)
173169

174170

175171
def detect_lambda_function_url_domain(domain: str) -> bool:
@@ -193,11 +189,19 @@ def parse_event_source_arn(source: _EventSource, event: dict, context: Any) -> s
193189
event_record = get_first_record(event)
194190
# e.g. arn:aws:s3:::lambda-xyz123-abc890
195191
if source.to_string() == "s3":
196-
return event_record.get("s3", {}).get("bucket", {}).get("arn")
192+
s3 = event_record.get("s3")
193+
if s3:
194+
bucket = s3.get("bucket")
195+
if bucket:
196+
return bucket.get("arn")
197+
return None
197198

198199
# e.g. arn:aws:sns:us-east-1:123456789012:sns-lambda
199200
if source.to_string() == "sns":
200-
return event_record.get("Sns", {}).get("TopicArn")
201+
sns = event_record.get("Sns")
202+
if sns:
203+
return sns.get("TopicArn")
204+
return None
201205

202206
# e.g. arn:aws:cloudfront::123456789012:distribution/ABC123XYZ
203207
if source.event_type == EventTypes.CLOUDFRONT:
@@ -228,7 +232,11 @@ def parse_event_source_arn(source: _EventSource, event: dict, context: Any) -> s
228232
# e.g. arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/lambda-xyz/123
229233
if source.event_type == EventTypes.ALB:
230234
request_context = event.get("requestContext")
231-
return request_context.get("elb", {}).get("targetGroupArn")
235+
if request_context:
236+
elb = request_context.get("elb")
237+
if elb:
238+
return elb.get("targetGroupArn")
239+
return None
232240

233241
# e.g. arn:aws:logs:us-west-1:123456789012:log-group:/my-log-group-xyz
234242
if source.event_type == EventTypes.CLOUDWATCH_LOGS:
@@ -292,6 +300,13 @@ def extract_http_tags(event):
292300
return http_tags
293301

294302

303+
_http_event_types = (
304+
EventTypes.API_GATEWAY,
305+
EventTypes.ALB,
306+
EventTypes.LAMBDA_FUNCTION_URL,
307+
)
308+
309+
295310
def extract_trigger_tags(event: dict, context: Any) -> dict:
296311
"""
297312
Parses the trigger event object to get tags to be added to the span metadata
@@ -305,16 +320,15 @@ def extract_trigger_tags(event: dict, context: Any) -> dict:
305320
if event_source_arn:
306321
trigger_tags["function_trigger.event_source_arn"] = event_source_arn
307322

308-
if event_source.event_type in [
309-
EventTypes.API_GATEWAY,
310-
EventTypes.ALB,
311-
EventTypes.LAMBDA_FUNCTION_URL,
312-
]:
323+
if event_source.event_type in _http_event_types:
313324
trigger_tags.update(extract_http_tags(event))
314325

315326
return trigger_tags
316327

317328

329+
_str_http_triggers = [et.value for et in _http_event_types]
330+
331+
318332
def extract_http_status_code_tag(trigger_tags, response):
319333
"""
320334
If the Lambda was triggered by API Gateway, Lambda Function URL, or ALB,
@@ -325,15 +339,7 @@ def extract_http_status_code_tag(trigger_tags, response):
325339
str_event_source = trigger_tags.get("function_trigger.event_source")
326340
# it would be cleaner if each event type was a constant object that
327341
# knew some properties about itself like this.
328-
str_http_triggers = [
329-
et.value
330-
for et in [
331-
EventTypes.API_GATEWAY,
332-
EventTypes.LAMBDA_FUNCTION_URL,
333-
EventTypes.ALB,
334-
]
335-
]
336-
if str_event_source not in str_http_triggers:
342+
if str_event_source not in _str_http_triggers:
337343
return
338344

339345
status_code = "200"

0 commit comments

Comments
 (0)