Skip to content

Commit 0259154

Browse files
committed
feat(data_classes): add API Gateway Websocket event
Pydantic models were added for API Gateway WebSocket events here: https://docs.powertools.aws.dev/lambda/python/latest/utilities/parser/#built-in-models Add the corresponding data model class.
1 parent 702b8a9 commit 0259154

File tree

3 files changed

+225
-0
lines changed

3 files changed

+225
-0
lines changed

aws_lambda_powertools/utilities/data_classes/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .alb_event import ALBEvent
66
from .api_gateway_proxy_event import APIGatewayProxyEvent, APIGatewayProxyEventV2
7+
from .api_gateway_websocket_event import APIGatewayWebSocketEvent
78
from .appsync_resolver_event import AppSyncResolverEvent
89
from .aws_config_rule_event import AWSConfigRuleEvent
910
from .bedrock_agent_event import BedrockAgentEvent
@@ -51,6 +52,7 @@
5152
__all__ = [
5253
"APIGatewayProxyEvent",
5354
"APIGatewayProxyEventV2",
55+
"APIGatewayWebSocketEvent",
5456
"SecretsManagerEvent",
5557
"AppSyncResolverEvent",
5658
"ALBEvent",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
from functools import cached_property
5+
from typing import Any
6+
7+
from aws_lambda_powertools.utilities.data_classes.common import (
8+
CaseInsensitiveDict,
9+
DictWrapper,
10+
)
11+
12+
13+
class APIGatewayWebSocketEventIdentity(DictWrapper):
14+
@property
15+
def source_ip(self) -> str:
16+
return self["sourceIp"]
17+
18+
@property
19+
def user_agent(self) -> str | None:
20+
return self.get("userAgent")
21+
22+
23+
class APIGatewayWebSocketEventRequestContext(DictWrapper):
24+
@property
25+
def route_key(self) -> str:
26+
return self["routeKey"]
27+
28+
@property
29+
def disconnect_status_code(self) -> int | None:
30+
return self.get("disconnectStatusCode")
31+
32+
@property
33+
def message_id(self) -> str | None:
34+
return self.get("messageId")
35+
36+
@property
37+
def event_type(self) -> str:
38+
return self["eventType"]
39+
40+
@property
41+
def extended_request_id(self) -> str:
42+
return self["extendedRequestId"]
43+
44+
@property
45+
def request_time(self) -> str:
46+
return self["requestTime"]
47+
48+
@property
49+
def message_direction(self) -> str:
50+
return self["messageDirection"]
51+
52+
@property
53+
def disconnect_reason(self) -> str | None:
54+
return self.get("disconnectReason")
55+
56+
@property
57+
def stage(self) -> str:
58+
return self["stage"]
59+
60+
@property
61+
def connected_at(self) -> int:
62+
return self["connectedAt"]
63+
64+
@property
65+
def request_time_epoch(self) -> int:
66+
return self["requestTimeEpoch"]
67+
68+
@property
69+
def identity(self) -> APIGatewayWebSocketEventIdentity:
70+
return APIGatewayWebSocketEventIdentity(self["identity"])
71+
72+
@property
73+
def request_id(self) -> str:
74+
return self["requestId"]
75+
76+
@property
77+
def domain_name(self) -> str:
78+
return self["domainName"]
79+
80+
@property
81+
def connection_id(self) -> str:
82+
return self["connectionId"]
83+
84+
@property
85+
def api_id(self) -> str:
86+
return self["apiId"]
87+
88+
89+
class APIGatewayWebSocketEvent(DictWrapper):
90+
"""AWS proxy integration event for WebSocket API
91+
92+
Documentation:
93+
--------------
94+
- https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-websocket-api-integration-requests.html
95+
"""
96+
97+
@property
98+
def is_base64_encoded(self) -> bool:
99+
return self["isBase64Encoded"]
100+
101+
@property
102+
def body(self) -> str | None:
103+
return self.get("body")
104+
105+
@cached_property
106+
def decoded_body(self) -> str | None:
107+
body = self.body
108+
if self.is_base64_encoded and body:
109+
return base64.b64decode(body.encode()).decode()
110+
return body
111+
112+
@cached_property
113+
def json_body(self) -> Any:
114+
if self.decoded_body:
115+
return self._json_deserializer(self.decoded_body)
116+
return None
117+
118+
@property
119+
def headers(self) -> dict[str, str]:
120+
return CaseInsensitiveDict(self.get("headers"))
121+
122+
@property
123+
def multi_value_headers(self) -> dict[str, list[str]]:
124+
return CaseInsensitiveDict(self.get("multiValueHeaders"))
125+
126+
@property
127+
def request_context(self) -> APIGatewayWebSocketEventRequestContext:
128+
return APIGatewayWebSocketEventRequestContext(self["requestContext"])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from aws_lambda_powertools.utilities.data_classes import APIGatewayWebSocketEvent
2+
from tests.functional.utils import load_event
3+
4+
5+
def test_connect_api_gateway_websocket_event():
6+
raw_event = load_event("apiGatewayWebSocketApiConnect.json")
7+
parsed_event = APIGatewayWebSocketEvent(raw_event)
8+
9+
assert parsed_event.body is None
10+
assert parsed_event.headers == raw_event["headers"]
11+
assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"]
12+
13+
request_context = parsed_event.request_context
14+
request_context_raw = raw_event["requestContext"]
15+
assert request_context.route_key == request_context_raw["routeKey"]
16+
assert request_context.disconnect_status_code is None
17+
assert request_context.message_id is None
18+
assert request_context.event_type == request_context_raw["eventType"]
19+
assert request_context.extended_request_id == request_context_raw["extendedRequestId"]
20+
assert request_context.request_time == request_context_raw["requestTime"]
21+
assert request_context.message_direction == request_context_raw["messageDirection"]
22+
assert request_context.disconnect_reason is None
23+
assert request_context.stage == request_context_raw["stage"]
24+
assert request_context.connected_at == request_context_raw["connectedAt"]
25+
assert request_context.request_time_epoch == request_context_raw["requestTimeEpoch"]
26+
assert request_context.request_id == request_context_raw["requestId"]
27+
assert request_context.domain_name == request_context_raw["domainName"]
28+
assert request_context.connection_id == request_context_raw["connectionId"]
29+
assert request_context.api_id == request_context_raw["apiId"]
30+
31+
identity = request_context.identity
32+
identity_raw = request_context_raw["identity"]
33+
assert identity.source_ip == identity_raw["sourceIp"]
34+
35+
36+
def test_disconnect_api_gateway_websocket_event():
37+
raw_event = load_event("apiGatewayWebSocketApiDisconnect.json")
38+
parsed_event = APIGatewayWebSocketEvent(raw_event)
39+
40+
assert parsed_event.body is None
41+
assert parsed_event.headers == raw_event["headers"]
42+
assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"]
43+
44+
request_context = parsed_event.request_context
45+
request_context_raw = raw_event["requestContext"]
46+
assert request_context.route_key == request_context_raw["routeKey"]
47+
assert request_context.disconnect_status_code == request_context_raw["disconnectStatusCode"]
48+
assert request_context.message_id is None
49+
assert request_context.event_type == request_context_raw["eventType"]
50+
assert request_context.extended_request_id == request_context_raw["extendedRequestId"]
51+
assert request_context.request_time == request_context_raw["requestTime"]
52+
assert request_context.message_direction == request_context_raw["messageDirection"]
53+
assert request_context.disconnect_reason == request_context_raw["disconnectReason"]
54+
assert request_context.stage == request_context_raw["stage"]
55+
assert request_context.connected_at == request_context_raw["connectedAt"]
56+
assert request_context.request_time_epoch == request_context_raw["requestTimeEpoch"]
57+
assert request_context.request_id == request_context_raw["requestId"]
58+
assert request_context.domain_name == request_context_raw["domainName"]
59+
assert request_context.connection_id == request_context_raw["connectionId"]
60+
assert request_context.api_id == request_context_raw["apiId"]
61+
62+
identity = request_context.identity
63+
identity_raw = request_context_raw["identity"]
64+
assert identity.source_ip == identity_raw["sourceIp"]
65+
66+
67+
def test_message_api_gateway_websocket_event():
68+
raw_event = load_event("apiGatewayWebSocketApiMessage.json")
69+
parsed_event = APIGatewayWebSocketEvent(raw_event)
70+
71+
assert parsed_event.body == raw_event["body"]
72+
assert parsed_event.headers == {}
73+
assert parsed_event.multi_value_headers == {}
74+
75+
request_context = parsed_event.request_context
76+
request_context_raw = raw_event["requestContext"]
77+
assert request_context.route_key == request_context_raw["routeKey"]
78+
assert request_context.disconnect_status_code is None
79+
assert request_context.message_id == request_context_raw["messageId"]
80+
assert request_context.event_type == request_context_raw["eventType"]
81+
assert request_context.extended_request_id == request_context_raw["extendedRequestId"]
82+
assert request_context.request_time == request_context_raw["requestTime"]
83+
assert request_context.message_direction == request_context_raw["messageDirection"]
84+
assert request_context.disconnect_reason is None
85+
assert request_context.stage == request_context_raw["stage"]
86+
assert request_context.connected_at == request_context_raw["connectedAt"]
87+
assert request_context.request_time_epoch == request_context_raw["requestTimeEpoch"]
88+
assert request_context.request_id == request_context_raw["requestId"]
89+
assert request_context.domain_name == request_context_raw["domainName"]
90+
assert request_context.connection_id == request_context_raw["connectionId"]
91+
assert request_context.api_id == request_context_raw["apiId"]
92+
93+
identity = request_context.identity
94+
identity_raw = request_context_raw["identity"]
95+
assert identity.source_ip == identity_raw["sourceIp"]

0 commit comments

Comments
 (0)