@@ -110,10 +110,10 @@ def get_first_record(event):
110
110
111
111
def parse_event_source (event : dict ) -> _EventSource :
112
112
"""Determines the source of the trigger event"""
113
- if type (event ) is not dict :
113
+ if not isinstance (event , dict ) :
114
114
return _EventSource (EventTypes .UNKNOWN )
115
115
116
- event_source = _EventSource ( EventTypes . UNKNOWN )
116
+ event_source = None
117
117
118
118
request_context = event .get ("requestContext" )
119
119
if request_context and request_context .get ("stage" ):
@@ -126,7 +126,7 @@ def parse_event_source(event: dict) -> _EventSource:
126
126
event_source .subtype = EventSubtypes .API_GATEWAY
127
127
if "routeKey" in event :
128
128
event_source .subtype = EventSubtypes .HTTP_API
129
- if event . get ( "requestContext" , {}) .get ("messageDirection" ):
129
+ if request_context .get ("messageDirection" ):
130
130
event_source .subtype = EventSubtypes .WEBSOCKET
131
131
132
132
if request_context and request_context .get ("elb" ):
@@ -151,10 +151,9 @@ def parse_event_source(event: dict) -> _EventSource:
151
151
152
152
event_record = get_first_record (event )
153
153
if event_record :
154
- aws_event_source = event_record .get (
155
- "eventSource" , event_record . get ( " EventSource")
154
+ aws_event_source = event_record .get ("eventSource" ) or event_record . get (
155
+ "EventSource"
156
156
)
157
-
158
157
if aws_event_source == "aws:dynamodb" :
159
158
event_source = _EventSource (EventTypes .DYNAMODB )
160
159
if aws_event_source == "aws:kinesis" :
@@ -165,11 +164,10 @@ def parse_event_source(event: dict) -> _EventSource:
165
164
event_source = _EventSource (EventTypes .SNS )
166
165
if aws_event_source == "aws:sqs" :
167
166
event_source = _EventSource (EventTypes .SQS )
168
-
169
167
if event_record .get ("cf" ):
170
168
event_source = _EventSource (EventTypes .CLOUDFRONT )
171
169
172
- return event_source
170
+ return event_source or _EventSource ( EventTypes . UNKNOWN )
173
171
174
172
175
173
def detect_lambda_function_url_domain (domain : str ) -> bool :
@@ -193,11 +191,19 @@ def parse_event_source_arn(source: _EventSource, event: dict, context: Any) -> s
193
191
event_record = get_first_record (event )
194
192
# e.g. arn:aws:s3:::lambda-xyz123-abc890
195
193
if source .to_string () == "s3" :
196
- return event_record .get ("s3" , {}).get ("bucket" , {}).get ("arn" )
194
+ s3 = event_record .get ("s3" )
195
+ if s3 :
196
+ bucket = s3 .get ("bucket" )
197
+ if bucket :
198
+ return bucket .get ("arn" )
199
+ return None
197
200
198
201
# e.g. arn:aws:sns:us-east-1:123456789012:sns-lambda
199
202
if source .to_string () == "sns" :
200
- return event_record .get ("Sns" , {}).get ("TopicArn" )
203
+ sns = event_record .get ("Sns" )
204
+ if sns :
205
+ return sns .get ("TopicArn" )
206
+ return None
201
207
202
208
# e.g. arn:aws:cloudfront::123456789012:distribution/ABC123XYZ
203
209
if source .event_type == EventTypes .CLOUDFRONT :
@@ -228,7 +234,11 @@ def parse_event_source_arn(source: _EventSource, event: dict, context: Any) -> s
228
234
# e.g. arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/lambda-xyz/123
229
235
if source .event_type == EventTypes .ALB :
230
236
request_context = event .get ("requestContext" )
231
- return request_context .get ("elb" , {}).get ("targetGroupArn" )
237
+ if request_context :
238
+ elb = request_context .get ("elb" )
239
+ if elb :
240
+ return elb .get ("targetGroupArn" )
241
+ return None
232
242
233
243
# e.g. arn:aws:logs:us-west-1:123456789012:log-group:/my-log-group-xyz
234
244
if source .event_type == EventTypes .CLOUDWATCH_LOGS :
@@ -292,6 +302,13 @@ def extract_http_tags(event):
292
302
return http_tags
293
303
294
304
305
+ _http_event_types = (
306
+ EventTypes .API_GATEWAY ,
307
+ EventTypes .ALB ,
308
+ EventTypes .LAMBDA_FUNCTION_URL ,
309
+ )
310
+
311
+
295
312
def extract_trigger_tags (event : dict , context : Any ) -> dict :
296
313
"""
297
314
Parses the trigger event object to get tags to be added to the span metadata
@@ -305,16 +322,15 @@ def extract_trigger_tags(event: dict, context: Any) -> dict:
305
322
if event_source_arn :
306
323
trigger_tags ["function_trigger.event_source_arn" ] = event_source_arn
307
324
308
- if event_source .event_type in [
309
- EventTypes .API_GATEWAY ,
310
- EventTypes .ALB ,
311
- EventTypes .LAMBDA_FUNCTION_URL ,
312
- ]:
325
+ if event_source .event_type in _http_event_types :
313
326
trigger_tags .update (extract_http_tags (event ))
314
327
315
328
return trigger_tags
316
329
317
330
331
+ _str_http_triggers = [et .value for et in _http_event_types ]
332
+
333
+
318
334
def extract_http_status_code_tag (trigger_tags , response ):
319
335
"""
320
336
If the Lambda was triggered by API Gateway, Lambda Function URL, or ALB,
@@ -325,15 +341,7 @@ def extract_http_status_code_tag(trigger_tags, response):
325
341
str_event_source = trigger_tags .get ("function_trigger.event_source" )
326
342
# it would be cleaner if each event type was a constant object that
327
343
# knew some properties about itself like this.
328
- str_http_triggers = [
329
- et .value
330
- for et in [
331
- EventTypes .API_GATEWAY ,
332
- EventTypes .LAMBDA_FUNCTION_URL ,
333
- EventTypes .ALB ,
334
- ]
335
- ]
336
- if str_event_source not in str_http_triggers :
344
+ if str_event_source not in _str_http_triggers :
337
345
return
338
346
339
347
status_code = "200"
0 commit comments