@@ -110,10 +110,10 @@ def get_first_record(event):
110
110
111
111
def parse_event_source (event : dict ) -> _EventSource :
112
112
"""Determines the source of the trigger event"""
113
- if type (event ) is not dict :
113
+ if not isinstance (event , dict ) :
114
114
return _EventSource (EventTypes .UNKNOWN )
115
115
116
- event_source = _EventSource ( EventTypes . UNKNOWN )
116
+ event_source = None
117
117
118
118
request_context = event .get ("requestContext" )
119
119
if request_context and request_context .get ("stage" ):
@@ -126,7 +126,7 @@ def parse_event_source(event: dict) -> _EventSource:
126
126
event_source .subtype = EventSubtypes .API_GATEWAY
127
127
if "routeKey" in event :
128
128
event_source .subtype = EventSubtypes .HTTP_API
129
- if event . get ( "requestContext" , {}) .get ("messageDirection" ):
129
+ if request_context .get ("messageDirection" ):
130
130
event_source .subtype = EventSubtypes .WEBSOCKET
131
131
132
132
if request_context and request_context .get ("elb" ):
@@ -151,10 +151,7 @@ def parse_event_source(event: dict) -> _EventSource:
151
151
152
152
event_record = get_first_record (event )
153
153
if event_record :
154
- aws_event_source = event_record .get (
155
- "eventSource" , event_record .get ("EventSource" )
156
- )
157
-
154
+ aws_event_source = event_record .get ("eventSource" ) or event_record .get ("EventSource" )
158
155
if aws_event_source == "aws:dynamodb" :
159
156
event_source = _EventSource (EventTypes .DYNAMODB )
160
157
if aws_event_source == "aws:kinesis" :
@@ -165,11 +162,10 @@ def parse_event_source(event: dict) -> _EventSource:
165
162
event_source = _EventSource (EventTypes .SNS )
166
163
if aws_event_source == "aws:sqs" :
167
164
event_source = _EventSource (EventTypes .SQS )
168
-
169
165
if event_record .get ("cf" ):
170
166
event_source = _EventSource (EventTypes .CLOUDFRONT )
171
167
172
- return event_source
168
+ return event_source or _EventSource ( EventTypes . UNKNOWN )
173
169
174
170
175
171
def detect_lambda_function_url_domain (domain : str ) -> bool :
@@ -193,11 +189,19 @@ def parse_event_source_arn(source: _EventSource, event: dict, context: Any) -> s
193
189
event_record = get_first_record (event )
194
190
# e.g. arn:aws:s3:::lambda-xyz123-abc890
195
191
if source .to_string () == "s3" :
196
- return event_record .get ("s3" , {}).get ("bucket" , {}).get ("arn" )
192
+ s3 = event_record .get ("s3" )
193
+ if s3 :
194
+ bucket = s3 .get ("bucket" )
195
+ if bucket :
196
+ return bucket .get ("arn" )
197
+ return None
197
198
198
199
# e.g. arn:aws:sns:us-east-1:123456789012:sns-lambda
199
200
if source .to_string () == "sns" :
200
- return event_record .get ("Sns" , {}).get ("TopicArn" )
201
+ sns = event_record .get ("Sns" )
202
+ if sns :
203
+ return sns .get ("TopicArn" )
204
+ return None
201
205
202
206
# e.g. arn:aws:cloudfront::123456789012:distribution/ABC123XYZ
203
207
if source .event_type == EventTypes .CLOUDFRONT :
@@ -228,7 +232,11 @@ def parse_event_source_arn(source: _EventSource, event: dict, context: Any) -> s
228
232
# e.g. arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/lambda-xyz/123
229
233
if source .event_type == EventTypes .ALB :
230
234
request_context = event .get ("requestContext" )
231
- return request_context .get ("elb" , {}).get ("targetGroupArn" )
235
+ if request_context :
236
+ elb = request_context .get ("elb" )
237
+ if elb :
238
+ return elb .get ("targetGroupArn" )
239
+ return None
232
240
233
241
# e.g. arn:aws:logs:us-west-1:123456789012:log-group:/my-log-group-xyz
234
242
if source .event_type == EventTypes .CLOUDWATCH_LOGS :
@@ -292,6 +300,13 @@ def extract_http_tags(event):
292
300
return http_tags
293
301
294
302
303
+ _http_event_types = (
304
+ EventTypes .API_GATEWAY ,
305
+ EventTypes .ALB ,
306
+ EventTypes .LAMBDA_FUNCTION_URL ,
307
+ )
308
+
309
+
295
310
def extract_trigger_tags (event : dict , context : Any ) -> dict :
296
311
"""
297
312
Parses the trigger event object to get tags to be added to the span metadata
@@ -305,16 +320,15 @@ def extract_trigger_tags(event: dict, context: Any) -> dict:
305
320
if event_source_arn :
306
321
trigger_tags ["function_trigger.event_source_arn" ] = event_source_arn
307
322
308
- if event_source .event_type in [
309
- EventTypes .API_GATEWAY ,
310
- EventTypes .ALB ,
311
- EventTypes .LAMBDA_FUNCTION_URL ,
312
- ]:
323
+ if event_source .event_type in _http_event_types :
313
324
trigger_tags .update (extract_http_tags (event ))
314
325
315
326
return trigger_tags
316
327
317
328
329
+ _str_http_triggers = [et .value for et in _http_event_types ]
330
+
331
+
318
332
def extract_http_status_code_tag (trigger_tags , response ):
319
333
"""
320
334
If the Lambda was triggered by API Gateway, Lambda Function URL, or ALB,
@@ -325,15 +339,7 @@ def extract_http_status_code_tag(trigger_tags, response):
325
339
str_event_source = trigger_tags .get ("function_trigger.event_source" )
326
340
# it would be cleaner if each event type was a constant object that
327
341
# knew some properties about itself like this.
328
- str_http_triggers = [
329
- et .value
330
- for et in [
331
- EventTypes .API_GATEWAY ,
332
- EventTypes .LAMBDA_FUNCTION_URL ,
333
- EventTypes .ALB ,
334
- ]
335
- ]
336
- if str_event_source not in str_http_triggers :
342
+ if str_event_source not in _str_http_triggers :
337
343
return
338
344
339
345
status_code = "200"
0 commit comments